OTB  9.0.0
Orfeo Toolbox
otbDecisionTreeMachineLearningModel.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbDecisionTreeMachineLearningModel_h
22 #define otbDecisionTreeMachineLearningModel_h
23 
#include "otbRequiresOpenCVCheck.h"

#include <string>
#include <vector>

#include "itkLightObject.h"
#include "itkFixedArray.h"
#include "otbMachineLearningModel.h"

#include "otbOpenCVUtils.h"
31 
32 namespace otb
33 {
34 template <class TInputValue, class TTargetValue>
35 class ITK_EXPORT DecisionTreeMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue>
36 {
37 public:
41  typedef itk::SmartPointer<Self> Pointer;
42  typedef itk::SmartPointer<const Self> ConstPointer;
43 
45  typedef typename Superclass::InputSampleType InputSampleType;
50  typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
51  typedef typename Superclass::ProbaSampleType ProbaSampleType;
53  itkNewMacro(Self);
56 
64  itkGetMacro(MaxDepth, int);
65  itkSetMacro(MaxDepth, int);
67 
73  itkGetMacro(MinSampleCount, int);
74  itkSetMacro(MinSampleCount, int);
76 
84  itkGetMacro(RegressionAccuracy, double);
85  itkSetMacro(RegressionAccuracy, double);
87 
93  itkGetMacro(UseSurrogates, bool);
94  itkSetMacro(UseSurrogates, bool);
96 
108  itkGetMacro(MaxCategories, int);
109  itkSetMacro(MaxCategories, int);
111 
117  itkGetMacro(Use1seRule, bool);
118  itkSetMacro(Use1seRule, bool);
120 
127  itkGetMacro(TruncatePrunedTree, bool);
128  itkSetMacro(TruncatePrunedTree, bool);
130 
131 
132  /* The array of a priori class probabilities, sorted by the class label
133  * value. The parameter can be used to tune the decision tree preferences toward
134  * a certain class. For example, if you want to detect some rare anomaly
135  * occurrence, the training base will likely contain much more normal cases than
136  * anomalies, so a very good classification performance will be achieved just by
137  * considering every case as normal. To avoid this, the priors can be specified,
138  * where the anomaly probability is artificially increased (up to 0.5 or even
139  * greater), so the weight of the misclassified anomalies becomes much bigger,
140  * and the tree is adjusted properly. You can also think about this parameter as
141  * weights of prediction categories which determine relative weights that you
142  * give to misclassification. That is, if the weight of the first category is 1
143  * and the weight of the second category is 10, then each mistake in predicting
144  * the second category is equivalent to making 10 mistakes in predicting the
145  first category. */
146 
147  std::vector<float> GetPriors() const
148  {
149  return m_Priors;
150  }
151 
153  void Train() override;
154 
156  void Save(const std::string& filename, const std::string& name = "") override;
157 
159  void Load(const std::string& filename, const std::string& name = "") override;
160 
163 
165  bool CanReadFile(const std::string&) override;
166 
168  bool CanWriteFile(const std::string&) override;
170 
171 protected:
174 
176  ~DecisionTreeMachineLearningModel() override = default;
177 
179  TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const override;
180 
182  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
183 
184 private:
185  DecisionTreeMachineLearningModel(const Self&) = delete;
186  void operator=(const Self&) = delete;
187 
188  cv::Ptr<cv::ml::DTrees> m_DTreeModel;
189 
197  std::vector<float> m_Priors;
198 };
199 } // end namespace otb
200 
201 #ifndef OTB_MANUAL_INSTANTIATION
203 #endif
204 
205 #endif
otb::MachineLearningModel< TInputValue, TTargetValue >::TargetListSampleType
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
Definition: otbMachineLearningModel.h:92
otb::DecisionTreeMachineLearningModel
Definition: otbDecisionTreeMachineLearningModel.h:35
otb::MachineLearningModel< TInputValue, TTargetValue >::InputListSampleType
itk::Statistics::ListSample< InputSampleType > InputListSampleType
Definition: otbMachineLearningModel.h:85
otb::MachineLearningModel< TInputValue, TTargetValue >::InputValueType
MLMSampleTraits< TInputValue >::ValueType InputValueType
Definition: otbMachineLearningModel.h:83
otb::DecisionTreeMachineLearningModel::m_Use1seRule
bool m_Use1seRule
Definition: otbDecisionTreeMachineLearningModel.h:195
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbDecisionTreeMachineLearningModel.hxx
otb::DecisionTreeMachineLearningModel::m_DTreeModel
cv::Ptr< cv::ml::DTrees > m_DTreeModel
Definition: otbDecisionTreeMachineLearningModel.h:188
otb::DecisionTreeMachineLearningModel::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbDecisionTreeMachineLearningModel.h:41
otb::DecisionTreeMachineLearningModel::InputListSampleType
Superclass::InputListSampleType InputListSampleType
Definition: otbDecisionTreeMachineLearningModel.h:46
otb::DecisionTreeMachineLearningModel::TargetValueType
Superclass::TargetValueType TargetValueType
Definition: otbDecisionTreeMachineLearningModel.h:47
otbMachineLearningModel.h
otb::DecisionTreeMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbDecisionTreeMachineLearningModel.h:48
otb::DecisionTreeMachineLearningModel::ConstPointer
itk::SmartPointer< const Self > ConstPointer
Definition: otbDecisionTreeMachineLearningModel.h:42
otb::DecisionTreeMachineLearningModel::Self
DecisionTreeMachineLearningModel Self
Definition: otbDecisionTreeMachineLearningModel.h:39
otb::DecisionTreeMachineLearningModel::GetPriors
std::vector< float > GetPriors() const
Definition: otbDecisionTreeMachineLearningModel.h:147
otb::DecisionTreeMachineLearningModel::m_TruncatePrunedTree
bool m_TruncatePrunedTree
Definition: otbDecisionTreeMachineLearningModel.h:196
otb::DecisionTreeMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbDecisionTreeMachineLearningModel.h:50
otb::DecisionTreeMachineLearningModel::m_MaxDepth
int m_MaxDepth
Definition: otbDecisionTreeMachineLearningModel.h:190
otb::DecisionTreeMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbDecisionTreeMachineLearningModel.h:51
otb::DecisionTreeMachineLearningModel::Superclass
MachineLearningModel< TInputValue, TTargetValue > Superclass
Definition: otbDecisionTreeMachineLearningModel.h:40
otb::MachineLearningModel
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests,...
Definition: otbMachineLearningModel.h:70
otbOpenCVUtils.h
otb::MachineLearningModel< TInputValue, TTargetValue >::TargetValueType
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
Definition: otbMachineLearningModel.h:90
otbRequiresOpenCVCheck.h
otb::DecisionTreeMachineLearningModel::m_MaxCategories
int m_MaxCategories
Definition: otbDecisionTreeMachineLearningModel.h:194
otb::DecisionTreeMachineLearningModel::m_MinSampleCount
int m_MinSampleCount
Definition: otbDecisionTreeMachineLearningModel.h:191
otb::DecisionTreeMachineLearningModel::InputValueType
Superclass::InputValueType InputValueType
Definition: otbDecisionTreeMachineLearningModel.h:44
otb::DecisionTreeMachineLearningModel::m_Priors
std::vector< float > m_Priors
Definition: otbDecisionTreeMachineLearningModel.h:197
otb::DecisionTreeMachineLearningModel::TargetListSampleType
Superclass::TargetListSampleType TargetListSampleType
Definition: otbDecisionTreeMachineLearningModel.h:49
otb::DecisionTreeMachineLearningModel::m_RegressionAccuracy
double m_RegressionAccuracy
Definition: otbDecisionTreeMachineLearningModel.h:192
otb::MachineLearningModel::TargetSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
Definition: otbMachineLearningModel.h:91
otb::DecisionTreeMachineLearningModel::m_UseSurrogates
bool m_UseSurrogates
Definition: otbDecisionTreeMachineLearningModel.h:193
otb::DecisionTreeMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbDecisionTreeMachineLearningModel.h:45