OTB  6.7.0
Orfeo Toolbox
otbTrainNeuralNetwork.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbTrainNeuralNetwork_hxx
22 #define otbTrainNeuralNetwork_hxx
23 #include <boost/lexical_cast.hpp>
26 
27 namespace otb
28 {
29 namespace Wrapper
30 {
31 
32 template <class TInputValue, class TOutputValue>
33 void
34 LearningApplicationBase<TInputValue,TOutputValue>
35 ::InitNeuralNetworkParams()
36 {
37  AddChoice("classifier.ann", "Artificial Neural Network classifier");
38  SetParameterDescription("classifier.ann", "http://docs.opencv.org/modules/ml/doc/neural_networks.html");
39 
40  //TrainMethod
41  AddParameter(ParameterType_Choice, "classifier.ann.t", "Train Method Type");
42  AddChoice("classifier.ann.t.back", "Back-propagation algorithm");
43  SetParameterDescription("classifier.ann.t.back",
44  "Method to compute the gradient of the loss function and adjust weights "
45  "in the network to optimize the result.");
46  AddChoice("classifier.ann.t.reg", "Resilient Back-propagation algorithm");
47  SetParameterDescription("classifier.ann.t.reg",
48  "Almost the same as the Back-prop algorithm except that it does not "
49  "take into account the magnitude of the partial derivative (coordinate "
50  "of the gradient) but only its sign.");
51 
52  SetParameterString("classifier.ann.t", "reg");
53  SetParameterDescription("classifier.ann.t",
54  "Type of training method for the multilayer perceptron (MLP) neural network.");
55 
56  //LayerSizes
57  //There is no ParameterType_IntList, so i use a ParameterType_StringList and convert it.
58  /*std::vector<std::string> layerSizes;
59  layerSizes.push_back("100");
60  layerSizes.push_back("100"); */
61  AddParameter(ParameterType_StringList, "classifier.ann.sizes",
62  "Number of neurons in each intermediate layer");
63  //SetParameterStringList("classifier.ann.sizes", layerSizes);
64  SetParameterDescription("classifier.ann.sizes",
65  "The number of neurons in each intermediate layer (excluding input and output layers).");
66 
67  //ActivateFunction
68  AddParameter(ParameterType_Choice, "classifier.ann.f",
69  "Neuron activation function type");
70  AddChoice("classifier.ann.f.ident", "Identity function");
71  AddChoice("classifier.ann.f.sig", "Symmetrical Sigmoid function");
72  AddChoice("classifier.ann.f.gau", "Gaussian function (Not completely supported)");
73  SetParameterString("classifier.ann.f", "sig");
74  SetParameterDescription("classifier.ann.f",
75  "This function determine whether the output of the node is positive or not "
76  "depending on the output of the transfert function.");
77 
78  //Alpha
79  AddParameter(ParameterType_Float, "classifier.ann.a",
80  "Alpha parameter of the activation function");
81  SetParameterFloat("classifier.ann.a",1.);
82  SetParameterDescription("classifier.ann.a",
83  "Alpha parameter of the activation function (used only with sigmoid and gaussian functions).");
84 
85  //Beta
86  AddParameter(ParameterType_Float, "classifier.ann.b",
87  "Beta parameter of the activation function");
88  SetParameterFloat("classifier.ann.b",1.);
89  SetParameterDescription("classifier.ann.b",
90  "Beta parameter of the activation function (used only with sigmoid and gaussian functions).");
91 
92  //BackPropDWScale
93  AddParameter(ParameterType_Float, "classifier.ann.bpdw",
94  "Strength of the weight gradient term in the BACKPROP method");
95  SetParameterFloat("classifier.ann.bpdw",0.1);
96  SetParameterDescription("classifier.ann.bpdw",
97  "Strength of the weight gradient term in the BACKPROP method. The "
98  "recommended value is about 0.1.");
99 
100  //BackPropMomentScale
101  AddParameter(ParameterType_Float, "classifier.ann.bpms",
102  "Strength of the momentum term (the difference between weights on the 2 previous iterations)");
103  SetParameterFloat("classifier.ann.bpms",0.1);
104  SetParameterDescription("classifier.ann.bpms",
105  "Strength of the momentum term (the difference between weights on the 2 previous "
106  "iterations). This parameter provides some inertia to smooth the random "
107  "fluctuations of the weights. It can vary from 0 (the feature is disabled) "
108  "to 1 and beyond. The value 0.1 or so is good enough.");
109 
110  //RegPropDW0
111  AddParameter(ParameterType_Float, "classifier.ann.rdw",
112  "Initial value Delta_0 of update-values Delta_{ij} in RPROP method");
113  SetParameterFloat("classifier.ann.rdw",0.1);
114  SetParameterDescription("classifier.ann.rdw",
115  "Initial value Delta_0 of update-values Delta_{ij} in RPROP method (default = 0.1).");
116 
117  //RegPropDWMin
118  AddParameter(ParameterType_Float, "classifier.ann.rdwm",
119  "Update-values lower limit Delta_{min} in RPROP method");
120  SetParameterFloat("classifier.ann.rdwm",1e-7);
121  SetParameterDescription("classifier.ann.rdwm",
122  "Update-values lower limit Delta_{min} in RPROP method. It must be positive "
123  "(default = 1e-7).");
124 
125  //TermCriteriaType
126  AddParameter(ParameterType_Choice, "classifier.ann.term", "Termination criteria");
127  AddChoice("classifier.ann.term.iter", "Maximum number of iterations");
128  SetParameterDescription("classifier.ann.term.iter",
129  "Set the number of iterations allowed to the network for its "
130  "training. Training will stop regardless of the result when this "
131  "number is reached");
132  AddChoice("classifier.ann.term.eps", "Epsilon");
133  SetParameterDescription("classifier.ann.term.eps",
134  "Training will focus on result and will stop once the precision is"
135  "at most epsilon");
136  AddChoice("classifier.ann.term.all", "Max. iterations + Epsilon");
137  SetParameterDescription("classifier.ann.term.all",
138  "Both termination criteria are used. Training stop at the first reached");
139  SetParameterString("classifier.ann.term", "all");
140  SetParameterDescription("classifier.ann.term", "Termination criteria.");
141 
142  //Epsilon
143  AddParameter(ParameterType_Float, "classifier.ann.eps",
144  "Epsilon value used in the Termination criteria");
145  SetParameterFloat("classifier.ann.eps",0.01);
146  SetParameterDescription("classifier.ann.eps",
147  "Epsilon value used in the Termination criteria.");
148 
149  //MaxIter
150  AddParameter(ParameterType_Int, "classifier.ann.iter",
151  "Maximum number of iterations used in the Termination criteria");
152  SetParameterInt("classifier.ann.iter",1000);
153  SetParameterDescription("classifier.ann.iter",
154  "Maximum number of iterations used in the Termination criteria.");
155 
156 }
157 
/**
 * Configure and train the OpenCV ANN (multilayer perceptron) classifier
 * from the values previously registered by InitNeuralNetworkParams(),
 * then save the trained model to \p modelPath.
 *
 * \param trainingListSample          Input feature vectors used for training.
 * \param trainingLabeledListSample   Target labels (classification) or values
 *                                    (regression) matching the input samples.
 * \param modelPath                   Output file path for the trained model.
 *
 * The network topology is: one input neuron per feature band, the
 * user-supplied intermediate layer sizes, and an output layer of one neuron
 * in regression mode or one neuron per distinct label otherwise.
 *
 * NOTE(review): boost::lexical_cast throws bad_lexical_cast if an element of
 * "classifier.ann.sizes" is not a valid unsigned integer — presumably handled
 * (or surfaced) by the calling application framework; verify against callers.
 */
template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::TrainNeuralNetwork(typename ListSampleType::Pointer trainingListSample,
                     typename TargetListSampleType::Pointer trainingLabeledListSample,
                     std::string modelPath)
{
  typename NeuralNetworkType::Pointer classifier = NeuralNetworkType::New();
  classifier->SetRegressionMode(this->m_RegressionFlag);
  classifier->SetInputListSample(trainingListSample);
  classifier->SetTargetListSample(trainingLabeledListSample);

  // Map the "classifier.ann.t" choice index to the OpenCV training method.
  switch (GetParameterInt("classifier.ann.t"))
    {
    case 0: // BACKPROP
      classifier->SetTrainMethod(CvANN_MLP_TrainParams::BACKPROP);
      break;
    case 1: // RPROP
      classifier->SetTrainMethod(CvANN_MLP_TrainParams::RPROP);
      break;
    default: // DEFAULT = RPROP
      classifier->SetTrainMethod(CvANN_MLP_TrainParams::RPROP);
      break;
    }

  // Build the layer-size vector: input layer first, then the user-supplied
  // intermediate layers (strings converted to unsigned integers).
  std::vector<unsigned int> layerSizes;
  std::vector<std::string> sizes = GetParameterStringList("classifier.ann.sizes");

  // Input layer: one neuron per component of the measurement vector.
  unsigned int nbImageBands = trainingListSample->GetMeasurementVectorSize();
  layerSizes.push_back(nbImageBands);
  for (unsigned int i = 0; i < sizes.size(); i++)
    {
    unsigned int nbNeurons = boost::lexical_cast<unsigned int>(sizes[i]);
    layerSizes.push_back(nbNeurons);
    }

  // Output layer: a single neuron in regression mode, otherwise one neuron
  // per distinct label found in the training targets.
  unsigned int nbClasses = 0;
  if (this->m_RegressionFlag)
    {
    layerSizes.push_back(1);
    }
  else
    {
    // Count the distinct labels by collecting them into a set.
    std::set<TargetValueType> labelSet;
    TargetSampleType currentLabel;
    for (unsigned int itLab = 0; itLab < trainingLabeledListSample->Size(); ++itLab)
      {
      currentLabel = trainingLabeledListSample->GetMeasurementVector(itLab);
      labelSet.insert(currentLabel[0]);
      }
    nbClasses = labelSet.size();
    layerSizes.push_back(nbClasses);
    }

  classifier->SetLayerSizes(layerSizes);

  // Map the "classifier.ann.f" choice index to the OpenCV activation function.
  switch (GetParameterInt("classifier.ann.f"))
    {
    case 0: // ident
      classifier->SetActivateFunction(CvANN_MLP::IDENTITY);
      break;
    case 1: // sig
      classifier->SetActivateFunction(CvANN_MLP::SIGMOID_SYM);
      break;
    case 2: // gaussian
      classifier->SetActivateFunction(CvANN_MLP::GAUSSIAN);
      break;
    default: // DEFAULT = SIGMOID_SYM (symmetrical sigmoid)
      classifier->SetActivateFunction(CvANN_MLP::SIGMOID_SYM);
      break;
    }

  // Forward the numeric tuning parameters declared in InitNeuralNetworkParams().
  classifier->SetAlpha(GetParameterFloat("classifier.ann.a"));
  classifier->SetBeta(GetParameterFloat("classifier.ann.b"));
  classifier->SetBackPropDWScale(GetParameterFloat("classifier.ann.bpdw"));
  classifier->SetBackPropMomentScale(GetParameterFloat("classifier.ann.bpms"));
  classifier->SetRegPropDW0(GetParameterFloat("classifier.ann.rdw"));
  classifier->SetRegPropDWMin(GetParameterFloat("classifier.ann.rdwm"));

  // Map the "classifier.ann.term" choice index to the OpenCV termination flags.
  switch (GetParameterInt("classifier.ann.term"))
    {
    case 0: // CV_TERMCRIT_ITER
      classifier->SetTermCriteriaType(CV_TERMCRIT_ITER);
      break;
    case 1: // CV_TERMCRIT_EPS
      classifier->SetTermCriteriaType(CV_TERMCRIT_EPS);
      break;
    case 2: // CV_TERMCRIT_ITER + CV_TERMCRIT_EPS
      classifier->SetTermCriteriaType(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS);
      break;
    default: // DEFAULT = CV_TERMCRIT_ITER + CV_TERMCRIT_EPS
      classifier->SetTermCriteriaType(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS);
      break;
    }
  classifier->SetEpsilon(GetParameterFloat("classifier.ann.eps"));
  classifier->SetMaxIter(GetParameterInt("classifier.ann.iter"));
  // Run the training and persist the resulting model.
  classifier->Train();
  classifier->Save(modelPath);
}
260 
261 } //end namespace wrapper
262 } //end namespace otb
263 
264 #endif