OTB  9.0.0
Orfeo Toolbox
otbTrainSharkRandomForests.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbTrainSharkRandomForests_hxx
22 #define otbTrainSharkRandomForests_hxx
23 
26 
27 namespace otb
28 {
29 namespace Wrapper
30 {
31 
32 template <class TInputValue, class TOutputValue>
33 void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkRandomForestsParams()
34 {
35 
36 
37  AddChoice("classifier.sharkrf", "Shark Random forests classifier");
38  SetParameterDescription("classifier.sharkrf",
39  "http://image.diku.dk/shark/doxygen_pages/html/classshark_1_1_r_f_trainer.html.\n It is noteworthy that training is parallel.");
40  // MaxNumberOfTrees
41  AddParameter(ParameterType_Int, "classifier.sharkrf.nbtrees", "Maximum number of trees in the forest");
42  SetParameterInt("classifier.sharkrf.nbtrees", 100);
43  SetParameterDescription("classifier.sharkrf.nbtrees",
44  "The maximum number of trees in the forest. Typically, the more trees you have, the better the accuracy. "
45  "However, the improvement in accuracy generally diminishes and reaches an asymptote for a certain number of trees. "
46  "Also to keep in mind, increasing the number of trees increases the prediction time linearly.");
47 
48 
49  // NodeSize
50  AddParameter(ParameterType_Int, "classifier.sharkrf.nodesize", "Min size of the node for a split");
51  SetParameterInt("classifier.sharkrf.nodesize", 25);
52  SetParameterDescription("classifier.sharkrf.nodesize",
53  "If the number of samples in a node is smaller than this parameter, "
54  "then the node will not be split. A reasonable value is a small percentage of the total data e.g. 1 percent.");
55 
56  // MTry
57  AddParameter(ParameterType_Int, "classifier.sharkrf.mtry", "Number of features tested at each node");
58  SetParameterInt("classifier.sharkrf.mtry", 0);
59  SetParameterDescription("classifier.sharkrf.mtry",
60  "The number of features (variables) which will be tested at each node in "
61  "order to compute the split. If set to zero, the square root of the number of "
62  "features is used.");
63 
64 
65  // OOB Ratio
66  AddParameter(ParameterType_Float, "classifier.sharkrf.oobr", "Out of bound ratio");
67  SetParameterFloat("classifier.sharkrf.oobr", 0.66);
68  SetParameterDescription("classifier.sharkrf.oobr",
69  "Set the fraction of the original training dataset to use as the out of bag sample."
70  "A good default value is 0.66. ");
71 }
72 
73 template <class TInputValue, class TOutputValue>
74 void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkRandomForests(typename ListSampleType::Pointer trainingListSample,
75  typename TargetListSampleType::Pointer trainingLabeledListSample,
76  std::string modelPath)
77 {
79  typename SharkRandomForestType::Pointer classifier = SharkRandomForestType::New();
80  classifier->SetRegressionMode(this->m_RegressionFlag);
81  classifier->SetInputListSample(trainingListSample);
82  classifier->SetTargetListSample(trainingLabeledListSample);
83  classifier->SetNodeSize(GetParameterInt("classifier.sharkrf.nodesize"));
84  classifier->SetOobRatio(GetParameterFloat("classifier.sharkrf.oobr"));
85  classifier->SetNumberOfTrees(GetParameterInt("classifier.sharkrf.nbtrees"));
86  classifier->SetMTry(GetParameterInt("classifier.sharkrf.mtry"));
87 
88  classifier->Train();
89  classifier->Save(modelPath);
90 }
91 
92 } // end namespace wrapper
93 } // end namespace otb
94 
95 #endif
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbLearningApplicationBase.h
otb::Wrapper::ParameterType_Int
@ ParameterType_Int
Definition: otbWrapperTypes.h:38
otb::Wrapper::ParameterType_Float
@ ParameterType_Float
Definition: otbWrapperTypes.h:39
otbSharkRandomForestsMachineLearningModel.h
otb::SharkRandomForestsMachineLearningModel
Definition: otbSharkRandomForestsMachineLearningModel.h:77