OTB  9.0.0
Orfeo Toolbox
otbNeuralNetworkMachineLearningModel.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbNeuralNetworkMachineLearningModel_h
22 #define otbNeuralNetworkMachineLearningModel_h
23 
#include "otbRequiresOpenCVCheck.h"
#include "otbOpenCVUtils.h"

#include "itkLightObject.h"
#include "itkFixedArray.h"
#include "otbMachineLearningModel.h"

#include <map>
#include <string>
#include <vector>

31 namespace otb
32 {
33 template <class TInputValue, class TTargetValue>
34 class ITK_EXPORT NeuralNetworkMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue>
35 {
36 public:
40  typedef itk::SmartPointer<Self> Pointer;
41  typedef itk::SmartPointer<const Self> ConstPointer;
42 
44  typedef typename Superclass::InputSampleType InputSampleType;
48  typedef typename Superclass::TargetListSampleType TargetListSampleType;
49  typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
50  typedef typename Superclass::ProbaSampleType ProbaSampleType;
51  typedef std::map<TargetValueType, unsigned int> MapOfLabelsType;
52 
54  itkNewMacro(Self);
57 
65  itkGetMacro(TrainMethod, int);
66  itkSetMacro(TrainMethod, int);
68 
74  void SetLayerSizes(const std::vector<unsigned int> layers);
75 
76 
85  itkGetMacro(ActivateFunction, int);
86  itkSetMacro(ActivateFunction, int);
88 
93  itkGetMacro(Alpha, double);
94  itkSetMacro(Alpha, double);
96 
101  itkGetMacro(Beta, double);
102  itkSetMacro(Beta, double);
104 
110  itkGetMacro(BackPropDWScale, double);
111  itkSetMacro(BackPropDWScale, double);
113 
121  itkGetMacro(BackPropMomentScale, double);
122  itkSetMacro(BackPropMomentScale, double);
124 
129  itkGetMacro(RegPropDW0, double);
130  itkSetMacro(RegPropDW0, double);
132 
137  itkGetMacro(RegPropDWMin, double);
138  itkSetMacro(RegPropDWMin, double);
140 
146  itkGetMacro(TermCriteriaType, int);
147  itkSetMacro(TermCriteriaType, int);
149 
154  itkGetMacro(MaxIter, int);
155  itkSetMacro(MaxIter, int);
157 
162  itkGetMacro(Epsilon, double);
163  itkSetMacro(Epsilon, double);
165 
167  void Train() override;
168 
170  void Save(const std::string& filename, const std::string& name = "") override;
171 
173  void Load(const std::string& filename, const std::string& name = "") override;
174 
177 
179  bool CanReadFile(const std::string&) override;
180 
182  bool CanWriteFile(const std::string&) override;
184 
185 protected:
188 
190  ~NeuralNetworkMachineLearningModel() override = default;
191 
193  TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const override;
194 
195  void LabelsToMat(const TargetListSampleType* listSample, cv::Mat& output);
196 
198  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
199 
200 private:
201  NeuralNetworkMachineLearningModel(const Self&) = delete;
202  void operator=(const Self&) = delete;
203 
204  void CreateNetwork();
205  void SetupNetworkAndTrain(cv::Mat& labels);
206  cv::Ptr<cv::ml::ANN_MLP> m_ANNModel;
209  std::vector<unsigned int> m_LayerSizes;
210 
211  double m_Alpha;
212  double m_Beta;
215  double m_RegPropDW0;
219  double m_Epsilon;
220 
223 
224 };
225 } // end namespace otb
226 
227 #ifndef OTB_MANUAL_INSTANTIATION
229 #endif
230 
231 #endif
otb::NeuralNetworkMachineLearningModel::m_Alpha
double m_Alpha
Definition: otbNeuralNetworkMachineLearningModel.h:211
otb::NeuralNetworkMachineLearningModel::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbNeuralNetworkMachineLearningModel.h:40
otb::NeuralNetworkMachineLearningModel::m_ActivateFunction
int m_ActivateFunction
Definition: otbNeuralNetworkMachineLearningModel.h:208
otb::NeuralNetworkMachineLearningModel::Superclass
MachineLearningModel< TInputValue, TTargetValue > Superclass
Definition: otbNeuralNetworkMachineLearningModel.h:39
otb::NeuralNetworkMachineLearningModel::m_BackPropMomentScale
double m_BackPropMomentScale
Definition: otbNeuralNetworkMachineLearningModel.h:214
otb::NeuralNetworkMachineLearningModel::m_TermCriteriaType
int m_TermCriteriaType
Definition: otbNeuralNetworkMachineLearningModel.h:217
otb::NeuralNetworkMachineLearningModel::InputValueType
Superclass::InputValueType InputValueType
Definition: otbNeuralNetworkMachineLearningModel.h:43
otb::MachineLearningModel< TInputValue, TTargetValue >::InputListSampleType
itk::Statistics::ListSample< InputSampleType > InputListSampleType
Definition: otbMachineLearningModel.h:85
otb::NeuralNetworkMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbNeuralNetworkMachineLearningModel.h:50
otb::NeuralNetworkMachineLearningModel::m_BackPropDWScale
double m_BackPropDWScale
Definition: otbNeuralNetworkMachineLearningModel.h:213
otb::MachineLearningModel< TInputValue, TTargetValue >::InputValueType
MLMSampleTraits< TInputValue >::ValueType InputValueType
Definition: otbMachineLearningModel.h:83
otb::NeuralNetworkMachineLearningModel::m_MaxIter
int m_MaxIter
Definition: otbNeuralNetworkMachineLearningModel.h:218
otbNeuralNetworkMachineLearningModel.hxx
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::NeuralNetworkMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbNeuralNetworkMachineLearningModel.h:49
otb::NeuralNetworkMachineLearningModel::Self
NeuralNetworkMachineLearningModel Self
Definition: otbNeuralNetworkMachineLearningModel.h:38
otb::NeuralNetworkMachineLearningModel::TargetValueType
Superclass::TargetValueType TargetValueType
Definition: otbNeuralNetworkMachineLearningModel.h:46
otb::NeuralNetworkMachineLearningModel::m_MatrixOfLabels
cv::Mat m_MatrixOfLabels
Definition: otbNeuralNetworkMachineLearningModel.h:221
otb::NeuralNetworkMachineLearningModel::ConstPointer
itk::SmartPointer< const Self > ConstPointer
Definition: otbNeuralNetworkMachineLearningModel.h:41
otb::NeuralNetworkMachineLearningModel
Definition: otbNeuralNetworkMachineLearningModel.h:34
otb::NeuralNetworkMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbNeuralNetworkMachineLearningModel.h:44
otb::NeuralNetworkMachineLearningModel::InputListSampleType
Superclass::InputListSampleType InputListSampleType
Definition: otbNeuralNetworkMachineLearningModel.h:45
otb::NeuralNetworkMachineLearningModel::m_ANNModel
cv::Ptr< cv::ml::ANN_MLP > m_ANNModel
Definition: otbNeuralNetworkMachineLearningModel.h:206
otbMachineLearningModel.h
otb::NeuralNetworkMachineLearningModel::m_MapOfLabels
MapOfLabelsType m_MapOfLabels
Definition: otbNeuralNetworkMachineLearningModel.h:222
otb::NeuralNetworkMachineLearningModel::m_Epsilon
double m_Epsilon
Definition: otbNeuralNetworkMachineLearningModel.h:219
otb::NeuralNetworkMachineLearningModel::m_LayerSizes
std::vector< unsigned int > m_LayerSizes
Definition: otbNeuralNetworkMachineLearningModel.h:209
otb::NeuralNetworkMachineLearningModel::TargetListSampleType
Superclass::TargetListSampleType TargetListSampleType
Definition: otbNeuralNetworkMachineLearningModel.h:48
otb::NeuralNetworkMachineLearningModel::MapOfLabelsType
std::map< TargetValueType, unsigned int > MapOfLabelsType
Definition: otbNeuralNetworkMachineLearningModel.h:51
otb::NeuralNetworkMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbNeuralNetworkMachineLearningModel.h:47
otb::MachineLearningModel
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests,...
Definition: otbMachineLearningModel.h:70
otbOpenCVUtils.h
otb::MachineLearningModel::TargetValueType
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
Definition: otbMachineLearningModel.h:90
otb::NeuralNetworkMachineLearningModel::m_Beta
double m_Beta
Definition: otbNeuralNetworkMachineLearningModel.h:212
otbRequiresOpenCVCheck.h
otb::NeuralNetworkMachineLearningModel::m_RegPropDW0
double m_RegPropDW0
Definition: otbNeuralNetworkMachineLearningModel.h:215
otb::NeuralNetworkMachineLearningModel::m_RegPropDWMin
double m_RegPropDWMin
Definition: otbNeuralNetworkMachineLearningModel.h:216
otb::MachineLearningModel::TargetSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
Definition: otbMachineLearningModel.h:91
otb::NeuralNetworkMachineLearningModel::m_TrainMethod
int m_TrainMethod
Definition: otbNeuralNetworkMachineLearningModel.h:207