OTB  6.7.0
Orfeo Toolbox
otbLearningApplicationBase.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbLearningApplicationBase_hxx
22 #define otbLearningApplicationBase_hxx
23 
25 // only need this filter as a dummy process object
26 #include "otbRGBAPixelConverter.h"
27 
28 namespace otb
29 {
30 namespace Wrapper
31 {
32 
33 template <class TInputValue, class TOutputValue>
35 ::LearningApplicationBase() : m_RegressionFlag(false)
36 
37 {
38 }
39 
40 template <class TInputValue, class TOutputValue>
43 {
44  ModelFactoryType::CleanFactories();
45 }
46 
47 template <class TInputValue, class TOutputValue>
48 void
51 {
52  AddDocTag(Tags::Learning);
53 
54  // main choice parameter that will contain all machine learning options
55  AddParameter(ParameterType_Choice, "classifier", "Classifier to use for the training");
56  SetParameterDescription("classifier", "Choice of the classifier to use for the training.");
57 
58  InitSupervisedClassifierParams();
59  m_SupervisedClassifier = GetChoiceKeys("classifier");
60 
61  InitUnsupervisedClassifierParams();
62  std::vector<std::string> allClassifier = GetChoiceKeys("classifier");
63  // Check for empty unsupervised classifier
64  if( allClassifier.size() > m_UnsupervisedClassifier.size() )
65  m_UnsupervisedClassifier.assign( allClassifier.begin() + m_SupervisedClassifier.size(), allClassifier.end() );
66 }
67 
68 template <class TInputValue, class TOutputValue>
72 {
73  if( m_UnsupervisedClassifier.empty() )
74  {
75  return Supervised;
76  }
77  else
78  {
79  bool foundUnsupervised = std::find( m_UnsupervisedClassifier.begin(), m_UnsupervisedClassifier.end(),
80  GetParameterString( "classifier" ) ) != m_UnsupervisedClassifier.end();
81  return foundUnsupervised ? Unsupervised : Supervised;
82  }
83 }
84 
85 template <class TInputValue, class TOutputValue>
86 void
89 {
90 
91  //Group LibSVM
92 #ifdef OTB_USE_LIBSVM
93  InitLibSVMParams();
94 #endif
95 
96 #ifdef OTB_USE_OPENCV
97  // OpenCV SVM implementation is buggy with linear kernel
98  // Users should use the libSVM implementation instead.
99  // InitSVMParams();
100  if (!m_RegressionFlag)
101  {
102  InitBoostParams(); // Regression not supported
103  }
104  InitDecisionTreeParams();
105  InitNeuralNetworkParams();
106  if (!m_RegressionFlag)
107  {
108  InitNormalBayesParams(); // Regression not supported
109  }
110  InitRandomForestsParams();
111  InitKNNParams();
112 #endif
113 
114 #ifdef OTB_USE_SHARK
115  InitSharkRandomForestsParams();
116 #endif
117 }
118 
119 template <class TInputValue, class TOutputValue>
120 void
123 {
124 #ifdef OTB_USE_SHARK
125  if (!m_RegressionFlag)
126  {
127  InitSharkKMeansParams(); // Regression not supported
128  }
129 #endif
130 }
131 
132 template <class TInputValue, class TOutputValue>
134 ::TargetListSampleType::Pointer
136 ::Classify(typename ListSampleType::Pointer validationListSample,
137  std::string modelPath)
138 {
139  // Setup fake reporter
142  dummyFilter->SetProgress(0.0f);
143  this->AddProcess(dummyFilter,"Validation...");
144  dummyFilter->InvokeEvent(itk::StartEvent());
145 
146  // load a machine learning model from file and predict the input sample list
147  ModelPointerType model = ModelFactoryType::CreateMachineLearningModel(modelPath,
148  ModelFactoryType::ReadMode);
149 
150  if (model.IsNull())
151  {
152  otbAppLogFATAL(<< "Error when loading model " << modelPath);
153  }
154 
155  model->Load(modelPath);
156  model->SetRegressionMode(this->m_RegressionFlag);
157 
158  typename TargetListSampleType::Pointer predictedList = model->PredictBatch(validationListSample, NULL);
159 
160  // update reporter
161  dummyFilter->UpdateProgress(1.0f);
162  dummyFilter->InvokeEvent(itk::EndEvent());
163 
164  return predictedList;
165 }
166 
167 template <class TInputValue, class TOutputValue>
168 void
170 ::Train(typename ListSampleType::Pointer trainingListSample,
171  typename TargetListSampleType::Pointer trainingLabeledListSample,
172  std::string modelPath)
173 {
174  otbAppLogINFO("Computing model file : "<<modelPath);
175  // Setup fake reporter
178  dummyFilter->SetProgress(0.0f);
179  this->AddProcess(dummyFilter,"Training model...");
180  dummyFilter->InvokeEvent(itk::StartEvent());
181 
182  // get the name of the chosen machine learning model
183  const std::string modelName = GetParameterString("classifier");
184  // call specific train function
185  if (modelName == "libsvm")
186  {
187  #ifdef OTB_USE_LIBSVM
188  TrainLibSVM(trainingListSample, trainingLabeledListSample, modelPath);
189  #else
190  otbAppLogFATAL("Module LIBSVM is not installed. You should consider turning OTB_USE_LIBSVM on during cmake configuration.");
191  #endif
192  }
193  if(modelName == "sharkrf")
194  {
195  #ifdef OTB_USE_SHARK
196  TrainSharkRandomForests(trainingListSample,trainingLabeledListSample,modelPath);
197  #else
198  otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
199  #endif
200  }
201  else if(modelName == "sharkkm")
202  {
203  #ifdef OTB_USE_SHARK
204  TrainSharkKMeans( trainingListSample, trainingLabeledListSample, modelPath );
205  #else
206  otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
207  #endif
208  }
209  else if (modelName == "svm")
210  {
211  #ifdef OTB_USE_OPENCV
212  TrainSVM(trainingListSample, trainingLabeledListSample, modelPath);
213  #else
214  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
215  #endif
216  }
217  else if (modelName == "boost")
218  {
219  #ifdef OTB_USE_OPENCV
220  TrainBoost(trainingListSample, trainingLabeledListSample, modelPath);
221  #else
222  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
223  #endif
224  }
225  else if (modelName == "dt")
226  {
227  #ifdef OTB_USE_OPENCV
228  TrainDecisionTree(trainingListSample, trainingLabeledListSample, modelPath);
229  #else
230  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
231  #endif
232  }
233  else if (modelName == "ann")
234  {
235  #ifdef OTB_USE_OPENCV
236  TrainNeuralNetwork(trainingListSample, trainingLabeledListSample, modelPath);
237  #else
238  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
239  #endif
240  }
241  else if (modelName == "bayes")
242  {
243  #ifdef OTB_USE_OPENCV
244  TrainNormalBayes(trainingListSample, trainingLabeledListSample, modelPath);
245  #else
246  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
247  #endif
248  }
249  else if (modelName == "rf")
250  {
251  #ifdef OTB_USE_OPENCV
252  TrainRandomForests(trainingListSample, trainingLabeledListSample, modelPath);
253  #else
254  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
255  #endif
256  }
257  else if (modelName == "knn")
258  {
259  #ifdef OTB_USE_OPENCV
260  TrainKNN(trainingListSample, trainingLabeledListSample, modelPath);
261  #else
262  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
263  #endif
264  }
265 
266  // update reporter
267  dummyFilter->UpdateProgress(1.0f);
268  dummyFilter->InvokeEvent(itk::EndEvent());
269 }
270 
271 }
272 }
273 
274 #endif
TargetListSampleType::Pointer Classify(typename ListSampleType::Pointer validationListSample, std::string modelPath)
#define otbAppLogFATAL(x)
LearningApplicationBase is the base class for applications that use machine learning models...
#define otbAppLogINFO(x)
static Pointer New()
void Train(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
bool IsNull() const
static const std::string Learning