OTB  9.0.0
Orfeo Toolbox
otbRandomForestsMachineLearningModel.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbRandomForestsMachineLearningModel_h
22 #define otbRandomForestsMachineLearningModel_h
23 
24 #include "otbRequiresOpenCVCheck.h"
25 
26 #include "itkLightObject.h"
27 #include "itkFixedArray.h"
29 #include "itkVariableSizeMatrix.h"
30 #include "otbCvRTreesWrapper.h"
31 
32 namespace otb
33 {
34 
// RandomForestsMachineLearningModel: a MachineLearningModel<TInputValue, TTargetValue>
// subclass that overrides Train/Save/Load/CanReadFile/CanWriteFile/DoPredict/PrintSelf
// and keeps an OpenCV random-trees wrapper in m_RFModel (cv::Ptr<CvRTreesWrapper>).
// NOTE(review): the method bodies are not in this header; presumably they live in
// otbRandomForestsMachineLearningModel.hxx — TODO confirm.
35 template <class TInputValue, class TTargetValue>
36 class ITK_EXPORT RandomForestsMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue>
37 {
38 public:
42  typedef itk::SmartPointer<Self> Pointer;
43  typedef itk::SmartPointer<const Self> ConstPointer;
44 
46  typedef typename Superclass::InputSampleType InputSampleType;
51  typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
52  typedef typename Superclass::ProbaSampleType ProbaSampleType;
53  // Other
     // Matrix type returned by GetVariableImportance().
54  typedef itk::VariableSizeMatrix<float> VariableImportanceMatrixType;
55 
56 
57  // opencv typedef
59 
     // ITK object-factory macro (in ITK this declares the static New() that returns a Pointer).
61  itkNewMacro(Self);
64 
66  void Train() override;
67 
     // Save/Load: 'name' defaults to an empty string.
69  void Save(const std::string& filename, const std::string& name = "") override;
70 
72  void Load(const std::string& filename, const std::string& name = "") override;
73 
76 
78  bool CanReadFile(const std::string&) override;
79 
81  bool CanWriteFile(const std::string&) override;
83 
84  // Getters/setters of RT parameters (documentation taken from the OpenCV 2.4 doxygen)
85  itkGetMacro(MaxDepth, int);
86  itkSetMacro(MaxDepth, int);
87 
88  itkGetMacro(MinSampleCount, int);
89  itkSetMacro(MinSampleCount, int);
90 
     // NOTE(review): these accessors use double, but m_RegressionAccuracy is declared as
     // float (see "float m_RegressionAccuracy" in the member index). SetRegressionAccuracy
     // therefore narrows its argument to float, GetRegressionAccuracy returns the
     // float-rounded value, and the change check generated by itkSetMacro compares the
     // float member against the double argument, so re-setting a value that is not exactly
     // representable as float may still mark the object modified. Consider making the
     // accessor type and the member type agree.
91  itkGetMacro(RegressionAccuracy, double);
92  itkSetMacro(RegressionAccuracy, double);
93 
94  itkGetMacro(ComputeSurrogateSplit, bool);
95  itkSetMacro(ComputeSurrogateSplit, bool);
96 
97  itkGetMacro(MaxNumberOfCategories, int);
98  itkSetMacro(MaxNumberOfCategories, int);
99 
     // Returns a copy of m_Priors.
100  std::vector<float> GetPriors() const
101  {
102  return m_Priors;
103  }
104 
     // Replaces m_Priors with a copy of 'priors'. Unlike the itkSetMacro setters above,
     // this does not call Modified().
105  void SetPriors(const std::vector<float>& priors)
106  {
107  m_Priors = priors;
108  }
109 
110  itkGetMacro(CalculateVariableImportance, bool);
111  itkSetMacro(CalculateVariableImportance, bool);
112 
113  itkGetMacro(MaxNumberOfVariables, int);
114  itkSetMacro(MaxNumberOfVariables, int);
115 
116  itkGetMacro(MaxNumberOfTrees, int);
117  itkSetMacro(MaxNumberOfTrees, int);
118 
119  itkGetMacro(ForestAccuracy, float);
120  itkSetMacro(ForestAccuracy, float);
121 
122  itkGetMacro(TerminationCriteria, int);
123  itkSetMacro(TerminationCriteria, int);
124 
125  itkGetMacro(ComputeMargin, bool);
126  itkSetMacro(ComputeMargin, bool);
127 
     // Variable-importance matrix and training error of the model. NOTE(review): presumably
     // only meaningful after Train() (and, for importance, with CalculateVariableImportance
     // enabled); the implementation is not shown — TODO confirm.
129  VariableImportanceMatrixType GetVariableImportance();
130 
131  float GetTrainError();
132 
133 protected:
136 
     // Destructor is protected: instances are not deleted directly but through Pointer.
138  ~RandomForestsMachineLearningModel() override = default;
139 
     // Predicts the target for one input sample. 'quality' and 'proba' are optional output
     // pointers defaulting to nullptr; presumably they are written only when non-null —
     // verify in the implementation.
141  TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const override;
142 
144  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
145 
     // The declarations below are commented out and are not compiled.
146  /* /\** Input list sample *\/ */
147  /* typename InputListSampleType::Pointer m_InputListSample; */
148 
149  /* /\** Target list sample *\/ */
150  /* typename TargetListSampleType::Pointer m_TargetListSample; */
151 
152 private:
     // Copy construction and assignment are disabled; the object is handled via Pointer/ConstPointer.
153  RandomForestsMachineLearningModel(const Self&) = delete;
154  void operator=(const Self&) = delete;
155 
     // OpenCV random-trees wrapper held through cv::Ptr.
156  cv::Ptr<CvRTreesWrapper> m_RFModel;
157 
162 
166 
172 
188 
     // Storage for GetPriors()/SetPriors().
203  std::vector<float> m_Priors;
204 
208 
213 
220 
223 
226 
231 };
232 } // end namespace otb
234 
235 #ifndef OTB_MANUAL_INSTANTIATION
237 #endif
238 
239 #endif
otbCvRTreesWrapper.h
otb::MachineLearningModel< TInputValue, TTargetValue >::TargetListSampleType
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
Definition: otbMachineLearningModel.h:92
otb::RandomForestsMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbRandomForestsMachineLearningModel.h:52
otb::RandomForestsMachineLearningModel::ConstPointer
itk::SmartPointer< const Self > ConstPointer
Definition: otbRandomForestsMachineLearningModel.h:43
otb::RandomForestsMachineLearningModel::m_MaxNumberOfVariables
int m_MaxNumberOfVariables
Definition: otbRandomForestsMachineLearningModel.h:212
otb::RandomForestsMachineLearningModel::VariableImportanceMatrixType
itk::VariableSizeMatrix< float > VariableImportanceMatrixType
Definition: otbRandomForestsMachineLearningModel.h:54
otb::MachineLearningModel< TInputValue, TTargetValue >::InputListSampleType
itk::Statistics::ListSample< InputSampleType > InputListSampleType
Definition: otbMachineLearningModel.h:85
otb::MachineLearningModel< TInputValue, TTargetValue >::InputValueType
MLMSampleTraits< TInputValue >::ValueType InputValueType
Definition: otbMachineLearningModel.h:83
otb::RandomForestsMachineLearningModel::RFType
CvRTreesWrapper RFType
Definition: otbRandomForestsMachineLearningModel.h:58
otb::RandomForestsMachineLearningModel::InputListSampleType
Superclass::InputListSampleType InputListSampleType
Definition: otbRandomForestsMachineLearningModel.h:47
otb::RandomForestsMachineLearningModel::SetPriors
void SetPriors(const std::vector< float > &priors)
Definition: otbRandomForestsMachineLearningModel.h:105
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::RandomForestsMachineLearningModel::m_MaxDepth
int m_MaxDepth
Definition: otbRandomForestsMachineLearningModel.h:161
otb::RandomForestsMachineLearningModel::m_Priors
std::vector< float > m_Priors
Definition: otbRandomForestsMachineLearningModel.h:203
otb::RandomForestsMachineLearningModel::TargetValueType
Superclass::TargetValueType TargetValueType
Definition: otbRandomForestsMachineLearningModel.h:48
otb::RandomForestsMachineLearningModel::m_CalculateVariableImportance
bool m_CalculateVariableImportance
Definition: otbRandomForestsMachineLearningModel.h:207
otb::RandomForestsMachineLearningModel::m_TerminationCriteria
int m_TerminationCriteria
Definition: otbRandomForestsMachineLearningModel.h:225
otb::RandomForestsMachineLearningModel::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbRandomForestsMachineLearningModel.h:42
otb::RandomForestsMachineLearningModel::m_MaxNumberOfTrees
int m_MaxNumberOfTrees
Definition: otbRandomForestsMachineLearningModel.h:219
otb::RandomForestsMachineLearningModel::m_MaxNumberOfCategories
int m_MaxNumberOfCategories
Definition: otbRandomForestsMachineLearningModel.h:187
otb::RandomForestsMachineLearningModel::Self
RandomForestsMachineLearningModel Self
Definition: otbRandomForestsMachineLearningModel.h:40
otbMachineLearningModel.h
otb::RandomForestsMachineLearningModel::m_ForestAccuracy
float m_ForestAccuracy
Definition: otbRandomForestsMachineLearningModel.h:222
otb::CvRTreesWrapper
Wrapper for OpenCV Random Trees.
Definition: otbCvRTreesWrapper.h:35
otb::RandomForestsMachineLearningModel
Definition: otbRandomForestsMachineLearningModel.h:36
otb::MachineLearningModel
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests, ...).
Definition: otbMachineLearningModel.h:70
otb::RandomForestsMachineLearningModel::GetPriors
std::vector< float > GetPriors() const
Definition: otbRandomForestsMachineLearningModel.h:100
otb::RandomForestsMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbRandomForestsMachineLearningModel.h:49
otb::RandomForestsMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbRandomForestsMachineLearningModel.h:46
otb::MachineLearningModel< TInputValue, TTargetValue >::TargetValueType
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
Definition: otbMachineLearningModel.h:90
otb::RandomForestsMachineLearningModel::m_RFModel
cv::Ptr< CvRTreesWrapper > m_RFModel
Definition: otbRandomForestsMachineLearningModel.h:156
otb::RandomForestsMachineLearningModel::m_RegressionAccuracy
float m_RegressionAccuracy
Definition: otbRandomForestsMachineLearningModel.h:170
otb::RandomForestsMachineLearningModel::TargetListSampleType
Superclass::TargetListSampleType TargetListSampleType
Definition: otbRandomForestsMachineLearningModel.h:50
otb::RandomForestsMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbRandomForestsMachineLearningModel.h:51
otbRequiresOpenCVCheck.h
otb::RandomForestsMachineLearningModel::m_MinSampleCount
int m_MinSampleCount
Definition: otbRandomForestsMachineLearningModel.h:165
otbRandomForestsMachineLearningModel.hxx
otb::RandomForestsMachineLearningModel::m_ComputeSurrogateSplit
bool m_ComputeSurrogateSplit
Definition: otbRandomForestsMachineLearningModel.h:171
otb::RandomForestsMachineLearningModel::m_ComputeMargin
bool m_ComputeMargin
Definition: otbRandomForestsMachineLearningModel.h:230
otb::MachineLearningModel::TargetSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
Definition: otbMachineLearningModel.h:91
otb::RandomForestsMachineLearningModel::Superclass
MachineLearningModel< TInputValue, TTargetValue > Superclass
Definition: otbRandomForestsMachineLearningModel.h:41
otb::RandomForestsMachineLearningModel::InputValueType
Superclass::InputValueType InputValueType
Definition: otbRandomForestsMachineLearningModel.h:45