OTB  9.0.0
Orfeo Toolbox
otbMachineLearningModel.h
Go to the documentation of this file.
/*
 * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 * https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
20 
21 #ifndef otbMachineLearningModel_h
22 #define otbMachineLearningModel_h
23 
24 #include "itkObject.h"
25 #include "itkListSample.h"
27 
28 namespace otb
29 {
30 
69 template <class TInputValue, class TTargetValue, class TConfidenceValue = double>
70 class ITK_EXPORT MachineLearningModel : public itk::Object
71 {
72 public:
76  typedef itk::Object Superclass;
77  typedef itk::SmartPointer<Self> Pointer;
78  typedef itk::SmartPointer<const Self> ConstPointer;
80 
85  typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
87 
92  typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
94 
98  typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType;
99 
100 
101  typedef itk::VariableLengthVector<double> ProbaSampleType;
102  typedef itk::Statistics::ListSample<ProbaSampleType> ProbaListSampleType;
105 
107  itkTypeMacro(MachineLearningModel, itk::Object);
109 
111  virtual void Train() = 0;
112 
119  TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const;
120 
123  itkSetMacro(Dimension, unsigned int);
124  itkGetMacro(Dimension, unsigned int);
126 
127 
136  typename TargetListSampleType::Pointer PredictBatch(const InputListSampleType* input, ConfidenceListSampleType* quality = nullptr,
137  ProbaListSampleType* proba = nullptr) const;
138 
141 
143  virtual void Save(const std::string& filename, const std::string& name = "") = 0;
144 
146  virtual void Load(const std::string& filename, const std::string& name = "") = 0;
148 
151 
153  virtual bool CanReadFile(const std::string&) = 0;
154 
156  virtual bool CanWriteFile(const std::string&) = 0;
158 
160  bool HasConfidenceIndex() const
161  {
162  return m_ConfidenceIndex;
163  }
164 
166  bool HasProbaIndex() const
167  {
168  return m_ProbaIndex;
169  }
170 
173  itkSetObjectMacro(InputListSample, InputListSampleType);
174  itkGetObjectMacro(InputListSample, InputListSampleType);
175  itkGetConstObjectMacro(InputListSample, InputListSampleType);
177 
178 
181 
183  itkSetObjectMacro(TargetListSample, TargetListSampleType);
184 
186  itkGetObjectMacro(TargetListSample, TargetListSampleType);
188 
189  itkGetObjectMacro(ConfidenceListSample, ConfidenceListSampleType);
190 
193  itkGetMacro(RegressionMode, bool);
194  void SetRegressionMode(bool flag);
196 
197 
198 protected:
201 
203  ~MachineLearningModel() override = default;
204 
206  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
207 
209  typename InputListSampleType::Pointer m_InputListSample;
210 
212  typename InputListSampleType::Pointer m_ValidationListSample;
213 
215  typename TargetListSampleType::Pointer m_TargetListSample;
216 
217  typename ConfidenceListSampleType::Pointer m_ConfidenceListSample;
218 
221 
226 
229 
232 
235 
237  unsigned int m_Dimension;
238 
239 private:
255  virtual void DoPredictBatch(const InputListSampleType* input, const unsigned int& startIndex, const unsigned int& size, TargetListSampleType* target,
256  ConfidenceListSampleType* quality = nullptr, ProbaListSampleType* proba = nullptr) const;
257 
264  virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const = 0;
265 
266  MachineLearningModel(const Self&) = delete;
267  void operator=(const Self&) = delete;
268 };
269 } // end namespace otb
270 
271 #ifndef OTB_MANUAL_INSTANTIATION
273 #endif
274 
275 #endif
otb::MachineLearningModel::TargetListSampleType
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
Definition: otbMachineLearningModel.h:92
otb::MachineLearningModel::m_IsDoPredictBatchMultiThreaded
bool m_IsDoPredictBatchMultiThreaded
Definition: otbMachineLearningModel.h:234
otb::MachineLearningModel::m_ProbaIndex
bool m_ProbaIndex
Definition: otbMachineLearningModel.h:231
otb::MachineLearningModel::Self
MachineLearningModel Self
Definition: otbMachineLearningModel.h:75
otb::MachineLearningModel::m_ConfidenceListSample
ConfidenceListSampleType::Pointer m_ConfidenceListSample
Definition: otbMachineLearningModel.h:217
otb::MachineLearningModel::HasProbaIndex
bool HasProbaIndex() const
Definition: otbMachineLearningModel.h:166
otb::MachineLearningModel::ProbaListSampleType
itk::Statistics::ListSample< ProbaSampleType > ProbaListSampleType
Definition: otbMachineLearningModel.h:102
otb::MachineLearningModel::InputListSampleType
itk::Statistics::ListSample< InputSampleType > InputListSampleType
Definition: otbMachineLearningModel.h:85
otb::MachineLearningModel::InputValueType
MLMSampleTraits< TInputValue >::ValueType InputValueType
Definition: otbMachineLearningModel.h:83
otb::MachineLearningModel::ConfidenceValueType
MLMTargetTraits< TConfidenceValue >::ValueType ConfidenceValueType
Definition: otbMachineLearningModel.h:96
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbMachineLearningModelTraits.h
otb::MachineLearningModel::ProbaSampleType
itk::VariableLengthVector< double > ProbaSampleType
Definition: otbMachineLearningModel.h:101
otb::MachineLearningModel::m_InputListSample
InputListSampleType::Pointer m_InputListSample
Definition: otbMachineLearningModel.h:209
otb::MachineLearningModel::m_IsRegressionSupported
bool m_IsRegressionSupported
Definition: otbMachineLearningModel.h:225
otb::MachineLearningModel::ConfidenceListSampleType
itk::Statistics::ListSample< ConfidenceSampleType > ConfidenceListSampleType
Definition: otbMachineLearningModel.h:98
otb::MachineLearningModel::ConstPointer
itk::SmartPointer< const Self > ConstPointer
Definition: otbMachineLearningModel.h:78
otb::MachineLearningModel::Superclass
itk::Object Superclass
Definition: otbMachineLearningModel.h:76
otb::MachineLearningModel::HasConfidenceIndex
bool HasConfidenceIndex() const
Definition: otbMachineLearningModel.h:160
otb::MachineLearningModel::m_RegressionMode
bool m_RegressionMode
Definition: otbMachineLearningModel.h:220
otb::MachineLearningModel::ConfidenceSampleType
MLMTargetTraits< TConfidenceValue >::SampleType ConfidenceSampleType
Definition: otbMachineLearningModel.h:97
otb::MLMSampleTraitsImpl
Definition: otbMachineLearningModelTraits.h:46
otb::MachineLearningModel::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbMachineLearningModel.h:77
otbMachineLearningModel.hxx
otb::MachineLearningModel
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests, ...).
Definition: otbMachineLearningModel.h:70
otb::MachineLearningModel::m_ValidationListSample
InputListSampleType::Pointer m_ValidationListSample
Definition: otbMachineLearningModel.h:212
otb::MachineLearningModel::TargetValueType
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
Definition: otbMachineLearningModel.h:90
otb::MachineLearningModel::m_Dimension
unsigned int m_Dimension
Definition: otbMachineLearningModel.h:237
otb::MLMTargetTraitsImpl
Definition: otbMachineLearningModelTraits.h:86
otb::MachineLearningModel::m_TargetListSample
TargetListSampleType::Pointer m_TargetListSample
Definition: otbMachineLearningModel.h:215
otb::MachineLearningModel::m_ConfidenceIndex
bool m_ConfidenceIndex
Definition: otbMachineLearningModel.h:228
otb::MachineLearningModel::InputSampleType
MLMSampleTraits< TInputValue >::SampleType InputSampleType
Definition: otbMachineLearningModel.h:84
otb::MachineLearningModel::TargetSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
Definition: otbMachineLearningModel.h:91