OTB  5.0.0
Orfeo Toolbox
otbSVMSampleListModelEstimator.txx
Go to the documentation of this file.
1 /*=========================================================================
2 
3  Program: ORFEO Toolbox
4  Language: C++
5  Date: $Date$
6  Version: $Revision$
7 
8 
9  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
10  See OTBCopyright.txt for details.
11 
12 
13  This software is distributed WITHOUT ANY WARRANTY; without even
14  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
15  PURPOSE. See the above copyright notices for more information.
16 
17 =========================================================================*/
18 
19 #ifndef __otbSVMSampleListModelEstimator_txx
20 #define __otbSVMSampleListModelEstimator_txx
21 
23 #include "itkCommand.h"
24 #include "otbMacro.h"
25 
26 namespace otb
27 {
28 template<class TInputSampleList,
29  class TTrainingSampleList, class TMeasurementFunctor>
31 ::SVMSampleListModelEstimator(void) : SVMModelEstimator<typename TInputSampleList::MeasurementType,
32  typename TTrainingSampleList::MeasurementType>()
33 {
34  this->SetNumberOfRequiredInputs(2);
35 }
36 
37 template<class TInputSampleList,
38  class TTrainingSampleList, class TMeasurementFunctor>
41 {}
42 
43 //Set the input sample list
44 template<class TInputSampleList,
45  class TTrainingSampleList,
46  class TMeasurementFunctor>
47 void
49 ::SetInputSampleList( const InputSampleListType* inputSampleList )
50 {
51  // Process object is not const-correct so the const_cast is required here
53  const_cast< InputSampleListType* >(inputSampleList) );
54 }
55 
56 
57 // Set the label sample list
58 template<class TInputSampleList,
59  class TTrainingSampleList,
60  class TMeasurementFunctor>
61 void
64 {
65  // Process object is not const-correct so the const_cast is required here
67  const_cast<TrainingSampleListType*>(trainingSampleList) );
68 }
69 
70 // Get the input sample list
71 template<class TInputSampleList,
72  class TTrainingSampleList,
73  class TMeasurementFunctor>
74 const typename SVMSampleListModelEstimator<TInputSampleList,
75  TTrainingSampleList,
76  TMeasurementFunctor>::InputSampleListType *
79 {
80  if (this->GetNumberOfInputs() < 2)
81  {
82  return 0;
83  }
84 
85  return static_cast<const InputSampleListType* >
86  (this->itk::ProcessObject::GetInput(0) );
87 }
88 
89 // Get the input label sample list
90 template<class TInputSampleList,
91  class TTrainingSampleList,
92  class TMeasurementFunctor>
93 const typename SVMSampleListModelEstimator<TInputSampleList,
94  TTrainingSampleList,
95  TMeasurementFunctor>::TrainingSampleListType *
98 {
99  if (this->GetNumberOfInputs() < 2)
100  {
101  return 0;
102  }
103 
104  return static_cast<const TrainingSampleListType* >
105  (this->itk::ProcessObject::GetInput(1));
106 }
107 
108 
109 /*
110  * PrintSelf
111  */
112 template<class TInputSampleList,
113  class TTrainingSampleList, class TMeasurementFunctor>
114 void
116 ::PrintSelf(std::ostream& os, itk::Indent indent) const
117 {
118  Superclass::PrintSelf(os, indent);
119 } // end PrintSelf
120 
/**
 * Fill the SVM model with training samples.
 *
 * Reads the input sample list and the training sample list through their
 * getters and checks that both contain the same number of samples. If they
 * differ, an exception is raised through itkExceptionMacro. It then clears
 * the samples already stored in the model and walks both lists in lockstep,
 * adding one sample per pair: the input measurement vector transformed by a
 * MeasurementFunctorType instance, labelled with the first component of the
 * matching training measurement vector.
 *
 * NOTE(review): this listing lacks source lines 121-123, 127-128 and 134.
 * Lines 127-128 presumably hold the qualified class name and the member name
 * (likely "::PrepareData()"). Confirm against otbSVMSampleListModelEstimator.h
 * before editing the code below.
 */
124 template<class TInputSampleList,
125  class TTrainingSampleList, class TMeasurementFunctor>
126 void
129 {
130  // Fetch both sample lists (the getters return const pointers, hence the const_cast) and the model to fill
131  InputSampleListPointer inputSampleList = const_cast<InputSampleListType*>(this->GetInputSampleList());
132  TrainingSampleListPointer trainingSampleList = const_cast<TrainingSampleListType*>(this->GetTrainingSampleList());
 // NOTE(review): either getter can return a null pointer when its input is
 // missing, and both pointers are dereferenced below without a check.
 // An explicit itkExceptionMacro guard here would be safer.
133  typename Superclass::ModelType * model = this->GetModel();
135 
 // NOTE(review): Size() is stored in an int. If it returns a wider unsigned
 // type, this narrows for very large lists (confirm).
136  int inputSampleListSize = inputSampleList->Size();
137  int trainingSampleListSize = trainingSampleList->Size();
138 
139  // Both lists must hold the same number of samples because they are paired element by element
140  if (inputSampleListSize != trainingSampleListSize)
141  {
 // Legacy explicit throw, kept commented out; itkExceptionMacro below replaces it.
142  /*throw itk::ExceptionObject(
143  __FILE__,
144  __LINE__,
145  "Input pointset size is not the same as the training pointset size.",
146  ITK_LOCATION); */
147  itkExceptionMacro(<< "Input pointset size is not the same as the training pointset size ("
148  << inputSampleListSize << " vs "<< trainingSampleListSize << ").");
149  }
150 
151  // Iterators over the input and training sample lists
152  InputSampleListIteratorType inIt = inputSampleList->Begin();
153  TrainingSampleListIteratorType trIt = trainingSampleList->Begin();
154 
155  InputSampleListIteratorType inEnd = inputSampleList->End();
156  TrainingSampleListIteratorType trEnd = trainingSampleList->End();
157 
158  // Remove any samples already stored in the model
159  model->ClearSamples();
160 
161  otbMsgDebugMacro(<< " Input nb points " << inputSampleListSize);
162  otbMsgDebugMacro(<< " Training nb points " << trainingSampleListSize);
163 
 // Walk both lists in lockstep. The size check above means they reach their
 // ends together; testing both ends is an extra safeguard.
164  MeasurementFunctorType mfunctor;
165  while (inIt != inEnd && trIt != trEnd)
166  {
 // The class label is the first component of the training measurement vector.
167  typename TTrainingSampleList::MeasurementType label =
168  trIt.GetMeasurementVector()[0];
169  typename TInputSampleList::MeasurementVectorType value =
170  inIt.GetMeasurementVector();
171  model->AddSample(mfunctor(value), label);
172  ++inIt;
173  ++trIt;
174  }
175 }
176 } //End namespace OTB
177 #endif
TTrainingSampleList::Pointer TrainingSampleListPointer
virtual void PrintSelf(std::ostream &os, itk::Indent indent) const
void SetInputSampleList(const InputSampleListType *inputSampleList)
TInputSampleList::ConstIterator InputSampleListIteratorType
TTrainingSampleList::ConstIterator TrainingSampleListIteratorType
const TrainingSampleListType * GetTrainingSampleList()
#define otbMsgDebugMacro(x)
Definition: otbMacro.h:55
Class for SVM model estimation from SampleLists used for classification.
DataObject * GetInput(const DataObjectIdentifierType &key)
Class for SVM model estimation from images used for classification.
const InputSampleListType * GetInputSampleList()
virtual void SetNthInput(DataObjectPointerArraySizeType num, DataObject *input)
void SetTrainingSampleList(const TrainingSampleListType *trainingSampleList)