otbTrainSharkRandomForests.hxx (Orfeo Toolbox, OTB 6.7.0)
/*
 * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 * https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef otbTrainSharkRandomForests_hxx
#define otbTrainSharkRandomForests_hxx

#include "otbLearningApplicationBase.h"
#include "otbSharkRandomForestsMachineLearningModel.h"

namespace otb
{
namespace Wrapper
{

template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue, TOutputValue>
::InitSharkRandomForestsParams()
{

  AddChoice("classifier.sharkrf", "Shark Random Forests classifier");
  SetParameterDescription("classifier.sharkrf",
                          "See http://image.diku.dk/shark/doxygen_pages/html/classshark_1_1_r_f_trainer.html.\n"
                          "Note that training is parallel.");
  // MaxNumberOfTrees
  AddParameter(ParameterType_Int, "classifier.sharkrf.nbtrees",
               "Maximum number of trees in the forest");
  SetParameterInt("classifier.sharkrf.nbtrees", 100);
  SetParameterDescription(
      "classifier.sharkrf.nbtrees",
      "The maximum number of trees in the forest. Typically, the more trees you have, the better the accuracy. "
      "However, the improvement in accuracy generally diminishes and reaches an asymptote for a certain number of trees. "
      "Keep in mind that increasing the number of trees also increases prediction time linearly.");

  // NodeSize
  AddParameter(ParameterType_Int, "classifier.sharkrf.nodesize", "Min size of the node for a split");
  SetParameterInt("classifier.sharkrf.nodesize", 25);
  SetParameterDescription(
      "classifier.sharkrf.nodesize",
      "If the number of samples in a node is smaller than this parameter, "
      "then the node will not be split. A reasonable value is a small percentage of the total data, e.g. 1 percent.");

  // MTry
  AddParameter(ParameterType_Int, "classifier.sharkrf.mtry", "Number of features tested at each node");
  SetParameterInt("classifier.sharkrf.mtry", 0);
  SetParameterDescription(
      "classifier.sharkrf.mtry",
      "The number of features (variables) which will be tested at each node in "
      "order to compute the split. If set to zero, the square root of the number of "
      "features is used.");
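  // Worked note (added for illustration, not in the original file): with the default
  // mtry = 0 and, for example, 16 input features, each split would test about
  // sqrt(16) = 4 randomly selected features.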

  // OOB Ratio
  AddParameter(ParameterType_Float, "classifier.sharkrf.oobr", "Out of bag ratio");
  SetParameterFloat("classifier.sharkrf.oobr", 0.66);
  SetParameterDescription("classifier.sharkrf.oobr",
                          "Set the fraction of the original training dataset to use as the out of bag sample. "
                          "A good default value is 0.66.");
}
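// Illustrative sketch (added, not part of the original file): the parameter keys
// registered above are exposed by OTB applications built on LearningApplicationBase,
// such as TrainImagesClassifier. A command line along these lines is assumed to select
// this classifier and override the defaults; the io.* keys and file names are examples
// only and are not taken from this file:
//
//   otbcli_TrainImagesClassifier -io.il image.tif -io.vd training.shp \
//     -classifier sharkrf \
//     -classifier.sharkrf.nbtrees 200 \
//     -classifier.sharkrf.nodesize 25 \
//     -classifier.sharkrf.mtry 0 \
//     -classifier.sharkrf.oobr 0.66 \
//     -io.out model_sharkrf.txt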

template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue, TOutputValue>
::TrainSharkRandomForests(typename ListSampleType::Pointer trainingListSample,
                          typename TargetListSampleType::Pointer trainingLabeledListSample,
                          std::string modelPath)
{
  // Instantiate the Shark Random Forests model and feed it the training samples
  typename SharkRandomForestType::Pointer classifier = SharkRandomForestType::New();
  classifier->SetRegressionMode(this->m_RegressionFlag);
  classifier->SetInputListSample(trainingListSample);
  classifier->SetTargetListSample(trainingLabeledListSample);

  // Forward the application parameters to the model
  classifier->SetNodeSize(GetParameterInt("classifier.sharkrf.nodesize"));
  classifier->SetOobRatio(GetParameterFloat("classifier.sharkrf.oobr"));
  classifier->SetNumberOfTrees(GetParameterInt("classifier.sharkrf.nbtrees"));
  classifier->SetMTry(GetParameterInt("classifier.sharkrf.mtry"));

  // Train the forest and serialize it to the requested model file
  classifier->Train();
  classifier->Save(modelPath);
}
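// Illustrative sketch (added, not part of the original file): the model serialized by
// Save() above can later be reloaded for prediction through the generic
// otb::MachineLearningModel interface; sample construction is omitted and the exact
// variable names below are assumptions:
//
//   typename SharkRandomForestType::Pointer model = SharkRandomForestType::New();
//   model->Load(modelPath);
//   // 'sample' is an input measurement vector filled by the caller
//   auto predictedLabel = model->Predict(sample);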

} // end namespace Wrapper
} // end namespace otb

#endif