OTB  9.0.0
Orfeo Toolbox
otbLearningApplicationBase.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbLearningApplicationBase_hxx
22 #define otbLearningApplicationBase_hxx
23 
25 // only need this filter as a dummy process object
26 #include "otbRGBAPixelConverter.h"
27 
28 namespace otb
29 {
30 namespace Wrapper
31 {
32 
33 template <class TInputValue, class TOutputValue>
35 
36 {
37 }
38 
39 template <class TInputValue, class TOutputValue>
41 {
42  ModelFactoryType::CleanFactories();
43 }
44 
45 template <class TInputValue, class TOutputValue>
47 {
48  AddDocTag(Tags::Learning);
49 
50  // main choice parameter that will contain all machine learning options
51  AddParameter(ParameterType_Choice, "classifier", "Classifier to use for the training");
52  SetParameterDescription("classifier", "Choice of the classifier to use for the training.");
53 
54  InitSupervisedClassifierParams();
55  m_SupervisedClassifier = GetChoiceKeys("classifier");
56 
57  InitUnsupervisedClassifierParams();
58  std::vector<std::string> allClassifier = GetChoiceKeys("classifier");
59  // Check for empty unsupervised classifier
60  if (allClassifier.size() > m_UnsupervisedClassifier.size())
61  m_UnsupervisedClassifier.assign(allClassifier.begin() + m_SupervisedClassifier.size(), allClassifier.end());
62 }
63 
64 template <class TInputValue, class TOutputValue>
66 {
67  if (m_UnsupervisedClassifier.empty())
68  {
69  return Supervised;
70  }
71  else
72  {
73  bool foundUnsupervised =
74  std::find(m_UnsupervisedClassifier.begin(), m_UnsupervisedClassifier.end(), GetParameterString("classifier")) != m_UnsupervisedClassifier.end();
75  return foundUnsupervised ? Unsupervised : Supervised;
76  }
77 }
78 
79 template <class TInputValue, class TOutputValue>
81 {
82 
83 // Group LibSVM
84 #ifdef OTB_USE_LIBSVM
85  InitLibSVMParams();
86 #endif
87 
88 #ifdef OTB_USE_OPENCV
89  // OpenCV SVM implementation is buggy with linear kernel
90  // Users should use the libSVM implementation instead.
91  // InitSVMParams();
92  if (!m_RegressionFlag)
93  {
94  InitBoostParams(); // Regression not supported
95  }
96  InitDecisionTreeParams();
97  InitNeuralNetworkParams();
98  if (!m_RegressionFlag)
99  {
100  InitNormalBayesParams(); // Regression not supported
101  }
102  InitRandomForestsParams();
103  InitKNNParams();
104 #endif
105 
106 #ifdef OTB_USE_SHARK
107  InitSharkRandomForestsParams();
108 #endif
109 }
110 
111 template <class TInputValue, class TOutputValue>
113 {
114 #ifdef OTB_USE_SHARK
115  if (!m_RegressionFlag)
116  {
117  InitSharkKMeansParams(); // Regression not supported
118  }
119 #endif
120 }
121 
122 template <class TInputValue, class TOutputValue>
124 LearningApplicationBase<TInputValue, TOutputValue>::Classify(typename ListSampleType::Pointer validationListSample, std::string modelPath)
125 {
126  // Setup fake reporter
128  dummyFilter->SetProgress(0.0f);
129  this->AddProcess(dummyFilter, "Validation...");
130  dummyFilter->InvokeEvent(itk::StartEvent());
131 
132  // load a machine learning model from file and predict the input sample list
133  ModelPointerType model = ModelFactoryType::CreateMachineLearningModel(modelPath, ModelFactoryType::ReadMode);
134 
135  if (model.IsNull())
136  {
137  otbAppLogFATAL(<< "Error when loading model " << modelPath);
138  }
139 
140  model->Load(modelPath);
141  model->SetRegressionMode(this->m_RegressionFlag);
142 
143  typename TargetListSampleType::Pointer predictedList = model->PredictBatch(validationListSample, NULL);
144 
145  // update reporter
146  dummyFilter->UpdateProgress(1.0f);
147  dummyFilter->InvokeEvent(itk::EndEvent());
148 
149  return predictedList;
150 }
151 
152 template <class TInputValue, class TOutputValue>
153 void LearningApplicationBase<TInputValue, TOutputValue>::Train(typename ListSampleType::Pointer trainingListSample,
154  typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
155 {
156  otbAppLogINFO("Computing model file : " << modelPath);
157  // Setup fake reporter
159  dummyFilter->SetProgress(0.0f);
160  this->AddProcess(dummyFilter, "Training model...");
161  dummyFilter->InvokeEvent(itk::StartEvent());
162 
163  // get the name of the chosen machine learning model
164  const std::string modelName = GetParameterString("classifier");
165  // call specific train function
166  if (modelName == "libsvm")
167  {
168 #ifdef OTB_USE_LIBSVM
169  TrainLibSVM(trainingListSample, trainingLabeledListSample, modelPath);
170 #else
171  otbAppLogFATAL("Module LIBSVM is not installed. You should consider turning OTB_USE_LIBSVM on during cmake configuration.");
172 #endif
173  }
174  if (modelName == "sharkrf")
175  {
176 #ifdef OTB_USE_SHARK
177  TrainSharkRandomForests(trainingListSample, trainingLabeledListSample, modelPath);
178 #else
179  otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
180 #endif
181  }
182  else if (modelName == "sharkkm")
183  {
184 #ifdef OTB_USE_SHARK
185  TrainSharkKMeans(trainingListSample, trainingLabeledListSample, modelPath);
186 #else
187  otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
188 #endif
189  }
190  else if (modelName == "svm")
191  {
192 #ifdef OTB_USE_OPENCV
193  TrainSVM(trainingListSample, trainingLabeledListSample, modelPath);
194 #else
195  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
196 #endif
197  }
198  else if (modelName == "boost")
199  {
200 #ifdef OTB_USE_OPENCV
201  TrainBoost(trainingListSample, trainingLabeledListSample, modelPath);
202 #else
203  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
204 #endif
205  }
206  else if (modelName == "dt")
207  {
208 #ifdef OTB_USE_OPENCV
209  TrainDecisionTree(trainingListSample, trainingLabeledListSample, modelPath);
210 #else
211  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
212 #endif
213  }
214  else if (modelName == "ann")
215  {
216 #ifdef OTB_USE_OPENCV
217  TrainNeuralNetwork(trainingListSample, trainingLabeledListSample, modelPath);
218 #else
219  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
220 #endif
221  }
222  else if (modelName == "bayes")
223  {
224 #ifdef OTB_USE_OPENCV
225  TrainNormalBayes(trainingListSample, trainingLabeledListSample, modelPath);
226 #else
227  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
228 #endif
229  }
230  else if (modelName == "rf")
231  {
232 #ifdef OTB_USE_OPENCV
233  TrainRandomForests(trainingListSample, trainingLabeledListSample, modelPath);
234 #else
235  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
236 #endif
237  }
238  else if (modelName == "knn")
239  {
240 #ifdef OTB_USE_OPENCV
241  TrainKNN(trainingListSample, trainingLabeledListSample, modelPath);
242 #else
243  otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
244 #endif
245  }
246 
247  // update reporter
248  dummyFilter->UpdateProgress(1.0f);
249  dummyFilter->InvokeEvent(itk::EndEvent());
250 }
251 }
252 }
253 
254 #endif
otb::Wrapper::LearningApplicationBase::~LearningApplicationBase
~LearningApplicationBase() override
Definition: otbLearningApplicationBase.hxx:40
otb::RGBAPixelConverter::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbRGBAPixelConverter.h:52
otb::Wrapper::LearningApplicationBase::ClassifierCategory
ClassifierCategory
Definition: otbLearningApplicationBase.h:108
otb::Wrapper::LearningApplicationBase
LearningApplicationBase is the base class for application that use machine learning model.
Definition: otbLearningApplicationBase.h:75
otb::Wrapper::LearningApplicationBase::InitUnsupervisedClassifierParams
void InitUnsupervisedClassifierParams()
Definition: otbLearningApplicationBase.hxx:112
otbAppLogFATAL
#define otbAppLogFATAL(x)
Definition: otbWrapperMacros.h:25
otb::find
string_view find(string_view const &haystack, string_view const &needle)
Definition: otbStringUtilities.h:305
otb::Wrapper::LearningApplicationBase::Classify
TargetListSampleType::Pointer Classify(typename ListSampleType::Pointer validationListSample, std::string modelPath)
Definition: otbLearningApplicationBase.hxx:124
otb::Wrapper::ParameterType_Choice
@ ParameterType_Choice
Definition: otbWrapperTypes.h:47
otb::Wrapper::LearningApplicationBase::GetClassifierCategory
ClassifierCategory GetClassifierCategory()
Definition: otbLearningApplicationBase.hxx:65
otb::RGBAPixelConverter::New
static Pointer New()
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbLearningApplicationBase.h
otb::Wrapper::LearningApplicationBase::Train
void Train(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
Definition: otbLearningApplicationBase.hxx:153
otbAppLogINFO
#define otbAppLogINFO(x)
Definition: otbWrapperMacros.h:52
otbRGBAPixelConverter.h
otb::Wrapper::LearningApplicationBase::ModelPointerType
ModelFactoryType::MachineLearningModelTypePointer ModelPointerType
Definition: otbLearningApplicationBase.h:95
otb::Wrapper::Tags::Learning
static const std::string Learning
Definition: otbWrapperTags.h:42
otb::Wrapper::LearningApplicationBase::LearningApplicationBase
LearningApplicationBase()
Definition: otbLearningApplicationBase.hxx:34
otb::Wrapper::LearningApplicationBase::DoInit
void DoInit() override
Definition: otbLearningApplicationBase.hxx:46
otb::string_view::end
const_iterator end() const
Definition: otbStringUtilities.h:196
otb::Wrapper::LearningApplicationBase::InitSupervisedClassifierParams
void InitSupervisedClassifierParams()
Definition: otbLearningApplicationBase.hxx:80