OTB  6.7.0
Orfeo Toolbox
otbSVMMarginSampler.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSVMMarginSampler_hxx
22 #define otbSVMMarginSampler_hxx
23 
#include "otbSVMMarginSampler.h"

#include <algorithm>
#include <cmath>
#include <cstddef>

#include "otbMacro.h"
26 
27 namespace otb
28 {
29 
30 template< class TSample, class TModel >
33 {
34  m_NumberOfCandidates = 10;
35 }
36 
37 template< class TSample, class TModel >
38 void
40 ::PrintSelf(std::ostream& os, itk::Indent indent) const
41 {
42  Superclass::PrintSelf(os, indent);
43 }
44 
45 template< class TSample, class TModel >
46 void
49 {
50  if(!m_Model)
51  {
52  itkExceptionMacro("No model, can not do classification.");
53  }
54 
55  if(m_Model->GetNumberOfSupportVectors() == 0)
56  {
57  itkExceptionMacro(<<"SVM model does not contain any support vector, can not perform margin sampling.");
58  }
59 
60  OutputType* output = const_cast<OutputType*>(this->GetOutput());
61  output->SetSample(this->GetInput());
62 
63  this->DoMarginSampling();
64 }
65 
66 template< class TSample, class TModel >
67 void
70 {
71  IndexAndDistanceVectorType idDistVector;
72  OutputType* output = const_cast<OutputType*>(this->GetOutput());
73 
74  typename TSample::ConstIterator iter = this->GetInput()->Begin();
75  typename TSample::ConstIterator end = this->GetInput()->End();
76 
77  typename OutputType::ConstIterator iterO = output->Begin();
78  typename OutputType::ConstIterator endO = output->End();
79  typename TSample::MeasurementVectorType measurements;
80 
81  m_Model->SetConfidenceMode(TModel::CM_HYPER);
82 
83  int numberOfComponentsPerSample = iter.GetMeasurementVector().Size();
84 
85  int nbClass = static_cast<int>(m_Model->GetNumberOfClasses());
86  std::vector<double> hdistances(nbClass * (nbClass - 1) / 2);
87 
88  otbMsgDevMacro( << "Starting iterations " );
89  while (iter != end && iterO != endO)
90  {
91  int i = 0;
92  typename SVMModelType::InputSampleType modelMeasurement(numberOfComponentsPerSample);
93 
94  measurements = iter.GetMeasurementVector();
95  // otbMsgDevMacro( << "Loop on components " << svm_type );
96  for (i=0; i<numberOfComponentsPerSample; ++i)
97  {
98  modelMeasurement[i] = measurements[i];
99  }
100 
101  // Get distances to the hyperplanes
102  m_Model->Predict(modelMeasurement, &(hdistances[0]));
103  double minDistance = std::abs(hdistances[0]);
104 
105  // Compute th min distances
106  for(unsigned int j = 1; j<hdistances.size(); ++j)
107  {
108  if(std::abs(hdistances[j])<minDistance)
109  {
110  minDistance = std::abs(hdistances[j]);
111  }
112  }
113  // Keep index and min distance
114  IndexAndDistanceType value(iter.GetInstanceIdentifier(), minDistance);
115  idDistVector.push_back(value);
116 
117  ++iter;
118  ++iterO;
119  }
120 
121  // Sort index by increasing distances
122  sort(idDistVector.begin(), idDistVector.end(), &Compare);
123 
124  // Display the first 10 values
125  otbMsgDevMacro( <<" Margin Sampling: " );
126 
127  // Clear previous margin samples
128  m_MarginSamples.clear();
129 
130  for(unsigned int i = 0; i<m_NumberOfCandidates && i<idDistVector.size(); ++i)
131  {
132  otbMsgDevMacro( "Sample "<<idDistVector[i].first<<" (distance= "<<idDistVector[i].second<<")" )
133  m_MarginSamples.push_back(idDistVector[i].first);
134  }
135 
136 }
137 
138 } // end of namespace otb
139 
140 #endif
std::vector< IndexAndDistanceType > IndexAndDistanceVectorType
void PrintSelf(std::ostream &os, itk::Indent indent) const override
std::pair< unsigned int, double > IndexAndDistanceType
Superclass::MembershipSampleType OutputType
#define otbMsgDevMacro(x)
Definition: otbMacro.h:66