OTB  6.7.0
Orfeo Toolbox
otbMachineLearningModel.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbMachineLearningModel_h
22 #define otbMachineLearningModel_h
23 
24 #include "itkObject.h"
25 #include "itkListSample.h"
27 
28 namespace otb
29 {
30 
/** \class MachineLearningModel
 * \brief Abstract base class for machine learning models; the listing's
 *        summary calls it "the base class for all classifier objects
 *        (SVM, KNN, Random Forests, ...)".
 *
 * Template parameters (mappings taken from the listing's symbol table):
 *  - TInputValue: mapped through MLMSampleTraits to InputValueType and
 *    InputSampleType; input lists are itk::Statistics::ListSample<InputSampleType>.
 *  - TTargetValue: mapped through MLMTargetTraits to TargetValueType and
 *    TargetSampleType; target lists are itk::Statistics::ListSample<TargetSampleType>.
 *  - TConfidenceValue (default: double): mapped through MLMTargetTraits to
 *    ConfidenceValueType and ConfidenceSampleType.
 *
 * Concrete models must implement the pure virtual members declared below:
 * Train(), DoPredict(), Save(), Load(), CanReadFile() and CanWriteFile().
 * DoPredictBatch() is virtual but not pure, so a default definition exists
 * elsewhere (presumably otbMachineLearningModel.hxx, which is not shown here).
 */
69 template <class TInputValue, class TTargetValue, class TConfidenceValue = double >
70 class ITK_EXPORT MachineLearningModel
71  : public itk::Object
72 {
73 public:
// NOTE(review): the public typedef block (original lines 74-106) is elided
// from this listing. It declares at least Self (used below), Pointer and
// ConstPointer (itk::SmartPointer<Self> / itk::SmartPointer<const Self>),
// plus the sample typedefs. Per the listing's symbol table these include:
//   InputSampleType          = MLMSampleTraits<TInputValue>::SampleType
//   InputListSampleType      = itk::Statistics::ListSample<InputSampleType>
//   TargetSampleType         = MLMTargetTraits<TTargetValue>::SampleType
//   TargetListSampleType     = itk::Statistics::ListSample<TargetSampleType>
//   ConfidenceValueType      = MLMTargetTraits<TConfidenceValue>::ValueType
//   ConfidenceListSampleType = itk::Statistics::ListSample<ConfidenceSampleType>
//   ProbaSampleType          = itk::VariableLengthVector<double>
//   ProbaListSampleType      = itk::Statistics::ListSample<ProbaSampleType>
81 
88 
95 
100 
101 
106 
/** Run-time type information (GetNameOfClass) via the ITK macro. */
108  itkTypeMacro(MachineLearningModel, itk::Object);
110 
/** Train the model. Pure virtual: each concrete model supplies its own training.
 * NOTE(review): presumably consumes the InputListSample / TargetListSample
 * set through the accessors below; confirm in derived classes. */
112  virtual void Train() =0;
113 
/** Predict the target for a single input sample.
 * \param input   sample to classify / regress.
 * \param quality optional output for a confidence value (nullptr to skip).
 * \param proba   optional output for a probability vector (nullptr to skip).
 * \return the predicted TargetSampleType.
 * This method is not virtual. Concrete models implement the private
 * DoPredict() hook, which has the same parameters. Presumably Predict()
 * forwards to it (the implementation is not shown). Whether a non-null
 * quality/proba gets filled presumably depends on HasConfidenceIndex() /
 * HasProbaIndex(); confirm in the .hxx. */
120  TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality = nullptr, ProbaSampleType *proba = nullptr) const;
121 
/** Set/Get the value stored in m_Dimension (presumably the number of
 * components per input sample; TODO confirm). */
124  itkSetMacro(Dimension,unsigned int);
125  itkGetMacro(Dimension,unsigned int);
127 
128 
/** Predict a whole list of input samples.
 * \param input   list of samples to predict.
 * \param quality optional output list of confidence values (nullptr to skip).
 * \param proba   optional output list of probability vectors (nullptr to skip).
 * \return a smart pointer to a TargetListSampleType holding the predictions
 *         (presumably one per input sample, in order; confirm in the .hxx).
 * Not virtual; per-range work is delegated to the private virtual
 * DoPredictBatch() (presumably; the implementation is not shown). */
137  typename TargetListSampleType::Pointer PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality = nullptr, ProbaListSampleType * proba = nullptr) const;
138 
141 
/** Save the model to \a filename. \a name is an optional extra identifier
 * (default: empty string). Its meaning is model-specific; NOTE(review): confirm. */
143  virtual void Save(const std::string & filename, const std::string & name="") = 0;
144 
/** Load the model from \a filename; \a name as in Save(). */
146  virtual void Load(const std::string & filename, const std::string & name="") = 0;
148 
151 
/** Return true if this model type can read the given file. */
153  virtual bool CanReadFile(const std::string &) = 0;
154 
/** Return true if this model type can write the given file. */
156  virtual bool CanWriteFile(const std::string &) = 0;
158 
/** Return the m_ConfidenceIndex flag (member declared in the elided
 * protected section). Presumably true when the model can report a
 * confidence value through Predict()'s quality output. */
160  bool HasConfidenceIndex() const {return m_ConfidenceIndex;}
161 
/** Return the m_ProbaIndex flag (member declared in the elided protected
 * section). Presumably true when the model can report a probability vector
 * through Predict()'s proba output. */
163  bool HasProbaIndex() const {return m_ProbaIndex;}
164 
/** Set/Get (mutable and const) the input sample list held in
 * m_InputListSample (presumably the training samples used by Train()). */
167  itkSetObjectMacro(InputListSample,InputListSampleType);
168  itkGetObjectMacro(InputListSample,InputListSampleType);
169  itkGetConstObjectMacro(InputListSample,InputListSampleType);
171 
172 
175 
/** Set the target (label) list held in m_TargetListSample. */
177  itkSetObjectMacro(TargetListSample,TargetListSampleType);
178 
/** Get the target (label) list held in m_TargetListSample. */
180  itkGetObjectMacro(TargetListSample,TargetListSampleType);
182 
/** Get the confidence list held in m_ConfidenceListSample (no setter is
 * declared in this class). */
183  itkGetObjectMacro(ConfidenceListSample,ConfidenceListSampleType);
184 
/** Get m_RegressionMode via the ITK macro. The setter is hand-written
 * (defined elsewhere) rather than macro-generated. Presumably it validates
 * that the concrete model supports regression; confirm in the .hxx. */
187  itkGetMacro(RegressionMode,bool);
188  void SetRegressionMode(bool flag);
190 
191 
192 protected:
// NOTE(review): original lines 193-196 are elided (presumably the protected
// constructor).
195 
/** Destructor; overrides itk::Object's virtual destructor. */
197  ~MachineLearningModel() override;
198 
/** ITK-style printing of the object state to \a os. */
200  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
201 
// NOTE(review): protected data members (original lines 202-230) are elided.
// Per the symbol table and the accessors above they include
// m_InputListSample, m_ValidationListSample, m_TargetListSample and
// m_ConfidenceListSample. The inline accessors / ITK macros above also
// require m_ConfidenceIndex, m_ProbaIndex and m_RegressionMode.
204 
207 
210 
212 
215 
220 
223 
226 
229 
/** Backing field for Set/GetDimension(). */
231  unsigned int m_Dimension;
232 
233 private:
/** Batch prediction hook over the range [startIndex, startIndex + size) of
 * \a input. Presumably writes predictions into \a target (and, when
 * non-null, \a quality / \a proba) at the same indices; confirm in the .hxx.
 * Virtual with a default definition elsewhere; concrete models may override it.
 * NOTE(review): default arguments on a virtual function bind to the static
 * type of the call expression. Defaults declared on an override are
 * ignored when the call goes through this base class, so keep them
 * identical or omit them in overrides. */
249  virtual void DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * target, ConfidenceListSampleType * quality = nullptr, ProbaListSampleType * proba = nullptr) const;
250 
/** Per-sample prediction hook every concrete model must implement (pure
 * virtual). Same parameters and return type as Predict(). The default-argument
 * caveat from DoPredictBatch() applies here too. */
257  virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType * quality= nullptr, ProbaSampleType *proba=nullptr) const = 0;
258 
/** Copy construction and copy assignment are disabled. */
259  MachineLearningModel(const Self &) = delete;
260  void operator =(const Self&) = delete;
261 };
262 } // end namespace otb
263 
264 #ifndef OTB_MANUAL_INSTANTIATION
266 #endif
267 
268 #endif
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
MLMTargetTraits< TConfidenceValue >::SampleType ConfidenceSampleType
itk::VariableLengthVector< double > ProbaSampleType
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests, ...).
ConfidenceListSampleType::Pointer m_ConfidenceListSample
InputListSampleType::Pointer m_InputListSample
itk::SmartPointer< const Self > ConstPointer
itk::SmartPointer< Self > Pointer
itk::Statistics::ListSample< InputSampleType > InputListSampleType
InputListSampleType::Pointer m_ValidationListSample
itk::Statistics::ListSample< ProbaSampleType > ProbaListSampleType
MLMSampleTraits< TInputValue >::ValueType InputValueType
MLMTargetTraits< TConfidenceValue >::ValueType ConfidenceValueType
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
MLMSampleTraits< TInputValue >::SampleType InputSampleType
TargetListSampleType::Pointer m_TargetListSample
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
itk::Statistics::ListSample< ConfidenceSampleType > ConfidenceListSampleType