OTB  6.7.0
Orfeo Toolbox
otbLearningApplicationBase.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbLearningApplicationBase_h
22 #define otbLearningApplicationBase_h
23 
#include "otbConfigure.h"

#include "otbWrapperApplication.h"

// ListSample
#include "itkListSample.h"
#include "otbVectorImage.h"

// Estimator
#include "otbMachineLearningModelFactory.h"

#include <string>
#include <vector>
36 
37 namespace otb
38 {
39 namespace Wrapper
40 {
41 
74 template <class TInputValue, class TOutputValue>
76 {
77 public:
83 
85  itkTypeMacro(LearningApplicationBase, otb::Application);
86 
87  typedef TInputValue InputValueType;
88  typedef TOutputValue OutputValueType;
89 
92 
93  // Machine Learning models
98 
101 
105 
106  itkGetConstReferenceMacro(SupervisedClassifier, std::vector<std::string>);
107  itkGetConstReferenceMacro(UnsupervisedClassifier, std::vector<std::string>);
108 
112  };
113 
120 
121 protected:
123 
124  ~LearningApplicationBase() override;
125 
128  void Train(typename ListSampleType::Pointer trainingListSample,
129  typename TargetListSampleType::Pointer trainingLabeledListSample,
130  std::string modelPath);
131 
134  typename ListSampleType::Pointer validationListSample,
135  std::string modelPath);
136 
138  void DoInit() override;
139 
143 
144 private:
149  std::vector<std::string> m_SupervisedClassifier;
150 
153  std::vector<std::string> m_UnsupervisedClassifier;
154 
156 #ifdef OTB_USE_LIBSVM
157  void InitLibSVMParams();
158 
159  void TrainLibSVM(typename ListSampleType::Pointer trainingListSample,
160  typename TargetListSampleType::Pointer trainingLabeledListSample,
161  std::string modelPath);
162 #endif
163 
164 #ifdef OTB_USE_OPENCV
165  void InitBoostParams();
166  void InitSVMParams();
167  void InitDecisionTreeParams();
168  void InitNeuralNetworkParams();
169  void InitNormalBayesParams();
170  void InitRandomForestsParams();
171  void InitKNNParams();
172 
173  void TrainBoost(typename ListSampleType::Pointer trainingListSample,
174  typename TargetListSampleType::Pointer trainingLabeledListSample,
175  std::string modelPath);
176  void TrainSVM(typename ListSampleType::Pointer trainingListSample,
177  typename TargetListSampleType::Pointer trainingLabeledListSample,
178  std::string modelPath);
179  void TrainDecisionTree(typename ListSampleType::Pointer trainingListSample,
180  typename TargetListSampleType::Pointer trainingLabeledListSample,
181  std::string modelPath);
182  void TrainNeuralNetwork(typename ListSampleType::Pointer trainingListSample,
183  typename TargetListSampleType::Pointer trainingLabeledListSample,
184  std::string modelPath);
185  void TrainNormalBayes(typename ListSampleType::Pointer trainingListSample,
186  typename TargetListSampleType::Pointer trainingLabeledListSample,
187  std::string modelPath);
188  void TrainRandomForests(typename ListSampleType::Pointer trainingListSample,
189  typename TargetListSampleType::Pointer trainingLabeledListSample,
190  std::string modelPath);
191  void TrainKNN(typename ListSampleType::Pointer trainingListSample,
192  typename TargetListSampleType::Pointer trainingLabeledListSample,
193  std::string modelPath);
194 #endif
195 
196 #ifdef OTB_USE_SHARK
197  void InitSharkRandomForestsParams();
198  void TrainSharkRandomForests(typename ListSampleType::Pointer trainingListSample,
199  typename TargetListSampleType::Pointer trainingLabeledListSample,
200  std::string modelPath);
201  void InitSharkKMeansParams();
202  void TrainSharkKMeans(typename ListSampleType::Pointer trainingListSample,
203  typename TargetListSampleType::Pointer trainingLabeledListSample,
204  std::string modelPath);
205 #endif
206 
207 };
208 
209 }
210 }
211 
212 #ifndef OTB_MANUAL_INSTANTIATION
214 #ifdef OTB_USE_OPENCV
215 #include "otbTrainBoost.hxx"
216 #include "otbTrainDecisionTree.hxx"
217 #include "otbTrainKNN.hxx"
218 #include "otbTrainNeuralNetwork.hxx"
219 #include "otbTrainNormalBayes.hxx"
220 #include "otbTrainRandomForests.hxx"
221 #include "otbTrainSVM.hxx"
222 #endif
223 #ifdef OTB_USE_LIBSVM
224 #include "otbTrainLibSVM.hxx"
225 #endif
226 #ifdef OTB_USE_SHARK
228 #include "otbTrainSharkKMeans.hxx"
229 #endif
230 #endif
231 
232 #endif
Creation of an "otb" vector image which contains metadata.
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
SmartPointer< Self > Pointer
TargetListSampleType::Pointer Classify(typename ListSampleType::Pointer validationListSample, std::string modelPath)
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests...
otb::VectorImage< InputValueType > SampleImageType
ModelFactoryType::MachineLearningModelTypePointer ModelPointerType
LearningApplicationBase is the base class for applications that use machine learning models...
ModelFactoryType::MachineLearningModelType ModelType
void Train(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
ModelType::TargetListSampleType TargetListSampleType
otb::MachineLearningModelFactory< InputValueType, OutputValueType > ModelFactoryType
ModelType::InputListSampleType ListSampleType
This class represents an application. (TODO: expand this description.)
MLMSampleTraits< TInputValue >::SampleType InputSampleType
itk::SmartPointer< const Self > ConstPointer
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
Creation of object instance using object factory.