OTB  9.0.0
Orfeo Toolbox
otbTrainDecisionTree.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbTrainDecisionTree_hxx
22 #define otbTrainDecisionTree_hxx
25 
26 namespace otb
27 {
28 namespace Wrapper
29 {
30 
31 template <class TInputValue, class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitDecisionTreeParams()
33 {
34  AddChoice("classifier.dt", "Decision Tree classifier");
35  SetParameterDescription("classifier.dt", "http://docs.opencv.org/modules/ml/doc/decision_trees.html");
36  // MaxDepth
37  AddParameter(ParameterType_Int, "classifier.dt.max", "Maximum depth of the tree");
38  SetParameterInt("classifier.dt.max", 10);
39  SetParameterDescription("classifier.dt.max",
40  "The training algorithm attempts to split each node while its depth is smaller "
41  "than the maximum possible depth of the tree. The actual depth may be smaller "
42  "if the other termination criteria are met, and/or if the tree is pruned.");
43 
44  // MinSampleCount
45  AddParameter(ParameterType_Int, "classifier.dt.min", "Minimum number of samples in each node");
46  SetParameterInt("classifier.dt.min", 10);
47  SetParameterDescription("classifier.dt.min",
48  "If the number of samples in a node is smaller "
49  "than this parameter, then this node will not be split.");
50 
51  // RegressionAccuracy
52  AddParameter(ParameterType_Float, "classifier.dt.ra", "Termination criteria for regression tree");
53  SetParameterFloat("classifier.dt.ra", 0.01);
54  SetParameterDescription("classifier.dt.ra",
55  "If all absolute differences between an estimated value in a node "
56  "and the values of the train samples in this node are smaller than this "
57  "regression accuracy parameter, then the node will not be split further.");
58 
59  // UseSurrogates : don't need to be exposed !
60  // SetParameterDescription("classifier.dt.sur","These splits allow working with missing data and compute variable importance correctly.");
61 
62  // MaxCategories
63  AddParameter(ParameterType_Int, "classifier.dt.cat",
64  "Cluster possible values of a categorical variable into K <= cat clusters to find a "
65  "suboptimal split");
66  SetParameterInt("classifier.dt.cat", 10);
67  SetParameterDescription("classifier.dt.cat",
68  "Cluster possible values of a categorical variable into K <= cat clusters to find a "
69  "suboptimal split.");
70 
71  // Use1seRule
72  AddParameter(ParameterType_Bool, "classifier.dt.r", "Set Use1seRule flag to false");
73  SetParameterDescription("classifier.dt.r",
74  "If true, then a pruning will be harsher. This will make a tree more compact and more "
75  "resistant to the training data noise but a bit less accurate.");
76 
77  // TruncatePrunedTree
78  AddParameter(ParameterType_Bool, "classifier.dt.t", "Set TruncatePrunedTree flag to false");
79  SetParameterDescription("classifier.dt.t", "If true, then pruned branches are physically removed from the tree.");
80 
81  // Priors are not exposed.
82 }
83 
84 template <class TInputValue, class TOutputValue>
85 void LearningApplicationBase<TInputValue, TOutputValue>::TrainDecisionTree(typename ListSampleType::Pointer trainingListSample,
86  typename TargetListSampleType::Pointer trainingLabeledListSample,
87  std::string modelPath)
88 {
90  typename DecisionTreeType::Pointer classifier = DecisionTreeType::New();
91  classifier->SetRegressionMode(this->m_RegressionFlag);
92  classifier->SetInputListSample(trainingListSample);
93  classifier->SetTargetListSample(trainingLabeledListSample);
94  classifier->SetMaxDepth(GetParameterInt("classifier.dt.max"));
95  classifier->SetMinSampleCount(GetParameterInt("classifier.dt.min"));
96  classifier->SetRegressionAccuracy(GetParameterFloat("classifier.dt.ra"));
97  classifier->SetMaxCategories(GetParameterInt("classifier.dt.cat"));
98 
99  if (GetParameterInt("classifier.dt.r"))
100  {
101  classifier->SetUse1seRule(false);
102  }
103  if (GetParameterInt("classifier.dt.t"))
104  {
105  classifier->SetTruncatePrunedTree(false);
106  }
107  classifier->Train();
108  classifier->Save(modelPath);
109 }
110 
111 } // end namespace wrapper
112 } // end namespace otb
113 
114 #endif
otb::Wrapper::ParameterType_Bool
@ ParameterType_Bool
Definition: otbWrapperTypes.h:60
otb::DecisionTreeMachineLearningModel
Definition: otbDecisionTreeMachineLearningModel.h:35
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbLearningApplicationBase.h
otb::Wrapper::ParameterType_Int
@ ParameterType_Int
Definition: otbWrapperTypes.h:38
otb::Wrapper::ParameterType_Float
@ ParameterType_Float
Definition: otbWrapperTypes.h:39
otbDecisionTreeMachineLearningModel.h