OTB  9.0.0
Orfeo Toolbox
otbTrainSVM.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbTrainSVM_hxx
22 #define otbTrainSVM_hxx
25 
26 namespace otb
27 {
28 namespace Wrapper
29 {
30 
31 template <class TInputValue, class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitSVMParams()
33 {
34  AddChoice("classifier.svm", "SVM classifier (OpenCV)");
35  SetParameterDescription("classifier.svm", "http://docs.opencv.org/modules/ml/doc/support_vector_machines.html");
36  AddParameter(ParameterType_Choice, "classifier.svm.m", "SVM Model Type");
37  SetParameterDescription("classifier.svm.m", "Type of SVM formulation.");
38  if (this->m_RegressionFlag)
39  {
40  AddChoice("classifier.svm.m.epssvr", "Epsilon Support Vector Regression");
41  AddChoice("classifier.svm.m.nusvr", "Nu Support Vector Regression");
42  SetParameterString("classifier.svm.m", "epssvr");
43  }
44  else
45  {
46  AddChoice("classifier.svm.m.csvc", "C support vector classification");
47  AddChoice("classifier.svm.m.nusvc", "Nu support vector classification");
48  AddChoice("classifier.svm.m.oneclass", "Distribution estimation (One Class SVM)");
49  SetParameterString("classifier.svm.m", "csvc");
50  }
51  AddParameter(ParameterType_Choice, "classifier.svm.k", "SVM Kernel Type");
52  AddChoice("classifier.svm.k.linear", "Linear");
53 
54  AddChoice("classifier.svm.k.rbf", "Gaussian radial basis function");
55  AddChoice("classifier.svm.k.poly", "Polynomial");
56  AddChoice("classifier.svm.k.sigmoid", "Sigmoid");
57  SetParameterString("classifier.svm.k", "linear");
58  SetParameterDescription("classifier.svm.k", "SVM Kernel Type.");
59  AddParameter(ParameterType_Float, "classifier.svm.c", "Cost parameter C");
60  SetParameterFloat("classifier.svm.c", 1.0);
61  SetParameterDescription("classifier.svm.c",
62  "SVM models have a cost parameter C (1 by default) to control the trade-off"
63  " between training errors and forcing rigid margins.");
64  AddParameter(ParameterType_Float, "classifier.svm.nu", "Parameter nu of a SVM optimization problem (NU_SVC / ONE_CLASS)");
65  SetParameterFloat("classifier.svm.nu", 0.0);
66  SetParameterDescription("classifier.svm.nu", "Parameter nu of a SVM optimization problem.");
67  if (this->m_RegressionFlag)
68  {
69  AddParameter(ParameterType_Float, "classifier.svm.p", "Parameter epsilon of a SVM optimization problem (EPS_SVR)");
70  SetParameterFloat("classifier.svm.p", 1.0);
71  SetParameterDescription("classifier.svm.p", "Parameter epsilon of a SVM optimization problem (EPS_SVR).");
72 
73  AddParameter(ParameterType_Choice, "classifier.svm.term", "Termination criteria");
74  SetParameterDescription("classifier.svm.term", "Termination criteria for iterative algorithm");
75  AddChoice("classifier.svm.term.iter", "Stops when maximum iteration is reached.");
76  AddChoice("classifier.svm.term.eps", "Stops when accuracy is lower than epsilon.");
77  AddChoice("classifier.svm.term.all", "Stops when either iteration or epsilon criteria is true");
78 
79  AddParameter(ParameterType_Float, "classifier.svm.iter", "Maximum iteration");
80  SetParameterFloat("classifier.svm.iter", 1000);
81  SetParameterDescription("classifier.svm.iter", "Maximum number of iterations (corresponds to the termination criteria 'iter').");
82 
83  AddParameter(ParameterType_Float, "classifier.svm.eps", "Epsilon accuracy threshold");
84  SetParameterFloat("classifier.svm.eps", FLT_EPSILON);
85  SetParameterDescription("classifier.svm.eps", "Epsilon accuracy (corresponds to the termination criteria 'eps').");
86  }
87  AddParameter(ParameterType_Float, "classifier.svm.coef0", "Parameter coef0 of a kernel function (POLY / SIGMOID)");
88  SetParameterFloat("classifier.svm.coef0", 0.0);
89  SetParameterDescription("classifier.svm.coef0", "Parameter coef0 of a kernel function (POLY / SIGMOID).");
90  AddParameter(ParameterType_Float, "classifier.svm.gamma", "Parameter gamma of a kernel function (POLY / RBF / SIGMOID)");
91  SetParameterFloat("classifier.svm.gamma", 1.0);
92  SetParameterDescription("classifier.svm.gamma", "Parameter gamma of a kernel function (POLY / RBF / SIGMOID).");
93  AddParameter(ParameterType_Float, "classifier.svm.degree", "Parameter degree of a kernel function (POLY)");
94  SetParameterFloat("classifier.svm.degree", 1.0);
95  SetParameterDescription("classifier.svm.degree", "Parameter degree of a kernel function (POLY).");
96  AddParameter(ParameterType_Bool, "classifier.svm.opt", "Parameters optimization");
97  SetParameterDescription("classifier.svm.opt",
98  "SVM parameters optimization flag.\n"
99  "-If set to True, then the optimal SVM parameters will be estimated. "
100  "Parameters are considered optimal by OpenCV when the cross-validation estimate of "
101  "the test set error is minimal. Finally, the SVM training process is computed "
102  "10 times with these optimal parameters over subsets corresponding to 1/10th of "
103  "the training samples using the k-fold cross-validation (with k = 10).\n-If set "
104  "to False, the SVM classification process will be computed once with the "
105  "currently set input SVM parameters over the training samples.\n-Thus, even "
106  "with identical input SVM parameters and a similar random seed, the output "
107  "SVM models will be different according to the method used (optimized or not) "
108  "because the samples are not identically processed within OpenCV.");
109 }
110 
111 template <class TInputValue, class TOutputValue>
112 void LearningApplicationBase<TInputValue, TOutputValue>::TrainSVM(typename ListSampleType::Pointer trainingListSample,
113  typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
114 {
116  typename SVMType::Pointer SVMClassifier = SVMType::New();
117  SVMClassifier->SetRegressionMode(this->m_RegressionFlag);
118  SVMClassifier->SetInputListSample(trainingListSample);
119  SVMClassifier->SetTargetListSample(trainingLabeledListSample);
120  switch (GetParameterInt("classifier.svm.k"))
121  {
122  case 0: // LINEAR
123  SVMClassifier->SetKernelType(CvSVM::LINEAR);
124  std::cout << "CvSVM::LINEAR = " << CvSVM::LINEAR << std::endl;
125  break;
126  case 1: // RBF
127  SVMClassifier->SetKernelType(CvSVM::RBF);
128  std::cout << "CvSVM::RBF = " << CvSVM::RBF << std::endl;
129  break;
130  case 2: // POLY
131  SVMClassifier->SetKernelType(CvSVM::POLY);
132  std::cout << "CvSVM::POLY = " << CvSVM::POLY << std::endl;
133  break;
134  case 3: // SIGMOID
135  SVMClassifier->SetKernelType(CvSVM::SIGMOID);
136  std::cout << "CvSVM::SIGMOID = " << CvSVM::SIGMOID << std::endl;
137  break;
138  default: // DEFAULT = LINEAR
139  SVMClassifier->SetKernelType(CvSVM::LINEAR);
140  std::cout << "CvSVM::LINEAR = " << CvSVM::LINEAR << std::endl;
141  break;
142  }
143  if (this->m_RegressionFlag)
144  {
145  switch (GetParameterInt("classifier.svm.m"))
146  {
147  case 0: // EPS_SVR
148  SVMClassifier->SetSVMType(CvSVM::EPS_SVR);
149  std::cout << "CvSVM::EPS_SVR = " << CvSVM::EPS_SVR << std::endl;
150  break;
151  case 1: // NU_SVR
152  SVMClassifier->SetSVMType(CvSVM::NU_SVR);
153  std::cout << "CvSVM::NU_SVR = " << CvSVM::NU_SVR << std::endl;
154  break;
155  default: // DEFAULT = EPS_SVR
156  SVMClassifier->SetSVMType(CvSVM::EPS_SVR);
157  std::cout << "CvSVM::EPS_SVR = " << CvSVM::EPS_SVR << std::endl;
158  break;
159  }
160  }
161  else
162  {
163  switch (GetParameterInt("classifier.svm.m"))
164  {
165  case 0: // C_SVC
166  SVMClassifier->SetSVMType(CvSVM::C_SVC);
167  std::cout << "CvSVM::C_SVC = " << CvSVM::C_SVC << std::endl;
168  break;
169  case 1: // NU_SVC
170  SVMClassifier->SetSVMType(CvSVM::NU_SVC);
171  std::cout << "CvSVM::NU_SVC = " << CvSVM::NU_SVC << std::endl;
172  break;
173  case 2: // ONE_CLASS
174  SVMClassifier->SetSVMType(CvSVM::ONE_CLASS);
175  std::cout << "CvSVM::ONE_CLASS = " << CvSVM::ONE_CLASS << std::endl;
176  break;
177  default: // DEFAULT = C_SVC
178  SVMClassifier->SetSVMType(CvSVM::C_SVC);
179  std::cout << "CvSVM::C_SVC = " << CvSVM::C_SVC << std::endl;
180  break;
181  }
182  }
183  SVMClassifier->SetC(GetParameterFloat("classifier.svm.c"));
184  SVMClassifier->SetNu(GetParameterFloat("classifier.svm.nu"));
185  if (this->m_RegressionFlag)
186  {
187  SVMClassifier->SetP(GetParameterFloat("classifier.svm.p"));
188  switch (GetParameterInt("classifier.svm.term"))
189  {
190  case 0: // ITER
191  SVMClassifier->SetTermCriteriaType(CV_TERMCRIT_ITER);
192  break;
193  case 1: // EPS
194  SVMClassifier->SetTermCriteriaType(CV_TERMCRIT_EPS);
195  break;
196  case 2: // ITER+EPS
197  SVMClassifier->SetTermCriteriaType(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS);
198  break;
199  default:
200  SVMClassifier->SetTermCriteriaType(CV_TERMCRIT_ITER);
201  break;
202  }
203  SVMClassifier->SetMaxIter(GetParameterInt("classifier.svm.iter"));
204  SVMClassifier->SetEpsilon(GetParameterFloat("classifier.svm.eps"));
205  }
206  SVMClassifier->SetCoef0(GetParameterFloat("classifier.svm.coef0"));
207  SVMClassifier->SetGamma(GetParameterFloat("classifier.svm.gamma"));
208  SVMClassifier->SetDegree(GetParameterFloat("classifier.svm.degree"));
209  SVMClassifier->SetParameterOptimization(GetParameterInt("classifier.svm.opt"));
210  SVMClassifier->Train();
211  SVMClassifier->Save(modelPath);
212 
213  // Update the displayed parameters in the GUI after the training process, for further use of them
214  SetParameterFloat("classifier.svm.c", static_cast<float>(SVMClassifier->GetOutputC()));
215  SetParameterFloat("classifier.svm.nu", static_cast<float>(SVMClassifier->GetOutputNu()));
216  if (this->m_RegressionFlag)
217  {
218  SetParameterFloat("classifier.svm.p", static_cast<float>(SVMClassifier->GetOutputP()));
219  }
220  SetParameterFloat("classifier.svm.coef0", static_cast<float>(SVMClassifier->GetOutputCoef0()));
221  SetParameterFloat("classifier.svm.gamma", static_cast<float>(SVMClassifier->GetOutputGamma()));
222  SetParameterFloat("classifier.svm.degree", static_cast<float>(SVMClassifier->GetOutputDegree()));
223 }
224 
225 } // end namespace wrapper
226 } // end namespace otb
227 
228 #endif
otb::Wrapper::ParameterType_Bool
@ ParameterType_Bool
Definition: otbWrapperTypes.h:60
otbSVMMachineLearningModel.h
otb::Wrapper::ParameterType_Choice
@ ParameterType_Choice
Definition: otbWrapperTypes.h:47
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbLearningApplicationBase.h
otb::Wrapper::ParameterType_Float
@ ParameterType_Float
Definition: otbWrapperTypes.h:39
otb::SVMMachineLearningModel
OpenCV implementation of SVM algorithm.
Definition: otbSVMMachineLearningModel.h:42