OTB  9.0.0
Orfeo Toolbox
otbSVMMarginSampler.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSVMMarginSampler_hxx
22 #define otbSVMMarginSampler_hxx
23 
#include "otbSVMMarginSampler.h"

#include <algorithm>
#include <cmath>
#include <cstddef>

#include "otbMacro.h"
26 
27 namespace otb
28 {
29 
30 template <class TSample, class TModel>
32 {
33  m_NumberOfCandidates = 10;
34 }
35 
36 template <class TSample, class TModel>
37 void SVMMarginSampler<TSample, TModel>::PrintSelf(std::ostream& os, itk::Indent indent) const
38 {
39  Superclass::PrintSelf(os, indent);
40 }
41 
42 template <class TSample, class TModel>
44 {
45  if (!m_Model)
46  {
47  itkExceptionMacro("No model, can not do classification.");
48  }
49 
50  if (m_Model->GetNumberOfSupportVectors() == 0)
51  {
52  itkExceptionMacro(<< "SVM model does not contain any support vector, can not perform margin sampling.");
53  }
54 
55  OutputType* output = const_cast<OutputType*>(this->GetOutput());
56  output->SetSample(this->GetInput());
57 
58  this->DoMarginSampling();
59 }
60 
61 template <class TSample, class TModel>
63 {
64  IndexAndDistanceVectorType idDistVector;
65  OutputType* output = const_cast<OutputType*>(this->GetOutput());
66 
67  typename TSample::ConstIterator iter = this->GetInput()->Begin();
68  typename TSample::ConstIterator end = this->GetInput()->End();
69 
70  typename OutputType::ConstIterator iterO = output->Begin();
71  typename OutputType::ConstIterator endO = output->End();
72  typename TSample::MeasurementVectorType measurements;
73 
74  m_Model->SetConfidenceMode(TModel::CM_HYPER);
75 
76  int numberOfComponentsPerSample = iter.GetMeasurementVector().Size();
77 
78  int nbClass = static_cast<int>(m_Model->GetNumberOfClasses());
79  std::vector<double> hdistances(nbClass * (nbClass - 1) / 2);
80 
81  otbMsgDevMacro(<< "Starting iterations ");
82  while (iter != end && iterO != endO)
83  {
84  int i = 0;
85  typename SVMModelType::InputSampleType modelMeasurement(numberOfComponentsPerSample);
86 
87  measurements = iter.GetMeasurementVector();
88  // otbMsgDevMacro( << "Loop on components " << svm_type );
89  for (i = 0; i < numberOfComponentsPerSample; ++i)
90  {
91  modelMeasurement[i] = measurements[i];
92  }
93 
94  // Get distances to the hyperplanes
95  m_Model->Predict(modelMeasurement, &(hdistances[0]));
96  double minDistance = std::abs(hdistances[0]);
97 
98  // Compute th min distances
99  for (unsigned int j = 1; j < hdistances.size(); ++j)
100  {
101  if (std::abs(hdistances[j]) < minDistance)
102  {
103  minDistance = std::abs(hdistances[j]);
104  }
105  }
106  // Keep index and min distance
107  IndexAndDistanceType value(iter.GetInstanceIdentifier(), minDistance);
108  idDistVector.push_back(value);
109 
110  ++iter;
111  ++iterO;
112  }
113 
114  // Sort index by increasing distances
115  sort(idDistVector.begin(), idDistVector.end(), &Compare);
116 
117  // Display the first 10 values
118  otbMsgDevMacro(<< " Margin Sampling: ");
119 
120  // Clear previous margin samples
121  m_MarginSamples.clear();
122 
123  for (unsigned int i = 0; i < m_NumberOfCandidates && i < idDistVector.size(); ++i)
124  {
125  otbMsgDevMacro("Sample " << idDistVector[i].first << " (distance= " << idDistVector[i].second << ")") m_MarginSamples.push_back(idDistVector[i].first);
126  }
127 }
128 
129 } // end of namespace otb
130 
131 #endif
otbSVMMarginSampler.h
otb::SVMMarginSampler::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbSVMMarginSampler.hxx:37
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbMacro.h
otb::SVMMarginSampler::IndexAndDistanceVectorType
std::vector< IndexAndDistanceType > IndexAndDistanceVectorType
Definition: otbSVMMarginSampler.h:59
otb::SVMMarginSampler::SVMMarginSampler
SVMMarginSampler()
Definition: otbSVMMarginSampler.hxx:31
otb::SVMMarginSampler::DoMarginSampling
virtual void DoMarginSampling()
Definition: otbSVMMarginSampler.hxx:62
otbMsgDevMacro
#define otbMsgDevMacro(x)
Definition: otbMacro.h:64
otb::SVMMarginSampler::IndexAndDistanceType
std::pair< unsigned int, double > IndexAndDistanceType
Definition: otbSVMMarginSampler.h:58
otb::SVMMarginSampler::OutputType
Superclass::MembershipSampleType OutputType
Definition: otbSVMMarginSampler.h:49
otb::SVMMarginSampler::GenerateData
void GenerateData() override
Definition: otbSVMMarginSampler.hxx:43