OTB  9.0.0
Orfeo Toolbox
otbTrainBoost.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbTrainBoost_hxx
22 #define otbTrainBoost_hxx
25 
26 namespace otb
27 {
28 namespace Wrapper
29 {
30 
31 template <class TInputValue, class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitBoostParams()
33 {
34  AddChoice("classifier.boost", "Boost classifier");
35  SetParameterDescription("classifier.boost", "http://docs.opencv.org/modules/ml/doc/boosting.html");
36  // BoostType
37  AddParameter(ParameterType_Choice, "classifier.boost.t", "Boost Type");
38  AddChoice("classifier.boost.t.discrete", "Discrete AdaBoost");
39  SetParameterDescription("classifier.boost.t.discrete",
40  "This procedure trains the classifiers on weighted versions of the training "
41  "sample, giving higher weight to cases that are currently misclassified. "
42  "This is done for a sequence of weighter samples, and then the final "
43  "classifier is defined as a linear combination of the classifier from "
44  "each stage.");
45  AddChoice("classifier.boost.t.real",
46  "Real AdaBoost (technique using confidence-rated predictions "
47  "and working well with categorical data)");
48  SetParameterDescription("classifier.boost.t.real", "Adaptation of the Discrete Adaboost algorithm with Real value");
49  AddChoice("classifier.boost.t.logit", "LogitBoost (technique producing good regression fits)");
50  SetParameterDescription("classifier.boost.t.logit",
51  "This procedure is an adaptive Newton algorithm for fitting an additive "
52  "logistic regression model. Beware it can produce numeric instability.");
53  AddChoice("classifier.boost.t.gentle",
54  "Gentle AdaBoost (technique setting less weight on outlier data points "
55  "and, for that reason, being often good with regression data)");
56  SetParameterDescription("classifier.boost.t.gentle",
57  "A modified version of the Real Adaboost algorithm, using Newton stepping "
58  "rather than exact optimization at each step.");
59  SetParameterString("classifier.boost.t", "real");
60  SetParameterDescription("classifier.boost.t", "Type of Boosting algorithm.");
61  // WeakCount
62  AddParameter(ParameterType_Int, "classifier.boost.w", "Weak count");
63  SetParameterInt("classifier.boost.w", 100);
64  SetParameterDescription("classifier.boost.w", "The number of weak classifiers.");
65  // WeightTrimRate
66  AddParameter(ParameterType_Float, "classifier.boost.r", "Weight Trim Rate");
67  SetParameterFloat("classifier.boost.r", 0.95);
68  SetParameterDescription("classifier.boost.r",
69  "A threshold between 0 and 1 used to save computational time. "
70  "Samples with summary weight <= (1 - weight_trim_rate) do not participate in"
71  " the next iteration of training. Set this parameter to 0 to turn off this "
72  "functionality.");
73  // MaxDepth : Not sure that this parameter has to be exposed.
74  AddParameter(ParameterType_Int, "classifier.boost.m", "Maximum depth of the tree");
75  SetParameterInt("classifier.boost.m", 1);
76  SetParameterDescription("classifier.boost.m", "Maximum depth of the tree.");
77 }
78 
79 template <class TInputValue, class TOutputValue>
80 void LearningApplicationBase<TInputValue, TOutputValue>::TrainBoost(typename ListSampleType::Pointer trainingListSample,
81  typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
82 {
84  typename BoostType::Pointer boostClassifier = BoostType::New();
85  boostClassifier->SetRegressionMode(this->m_RegressionFlag);
86  boostClassifier->SetInputListSample(trainingListSample);
87  boostClassifier->SetTargetListSample(trainingLabeledListSample);
88  boostClassifier->SetBoostType(GetParameterInt("classifier.boost.t"));
89  boostClassifier->SetWeakCount(GetParameterInt("classifier.boost.w"));
90  boostClassifier->SetWeightTrimRate(GetParameterFloat("classifier.boost.r"));
91  boostClassifier->SetMaxDepth(GetParameterInt("classifier.boost.m"));
92 
93  boostClassifier->Train();
94  boostClassifier->Save(modelPath);
95 }
96 
97 } // end namespace wrapper
98 } // end namespace otb
99 
100 #endif
otb::Wrapper::ParameterType_Choice
@ ParameterType_Choice
Definition: otbWrapperTypes.h:47
otb::BoostMachineLearningModel
Definition: otbBoostMachineLearningModel.h:35
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbLearningApplicationBase.h
otb::Wrapper::ParameterType_Int
@ ParameterType_Int
Definition: otbWrapperTypes.h:38
otb::Wrapper::ParameterType_Float
@ ParameterType_Float
Definition: otbWrapperTypes.h:39
otbBoostMachineLearningModel.h