OTB  9.0.0
Orfeo Toolbox
otbLibSVMMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbLibSVMMachineLearningModel_hxx
22 #define otbLibSVMMachineLearningModel_hxx
23 
24 #include <fstream>
28 #include "otbMacro.h"
29 #include "otbUtils.h"
30 
31 namespace otb
32 {
33 
34 template <class TInputValue, class TOutputValue>
36 {
37  this->SetSVMType(C_SVC);
38  this->SetKernelType(LINEAR);
39  this->SetPolynomialKernelDegree(3);
40  this->SetKernelGamma(1.); // 1/k
41  this->SetKernelCoef0(1.);
42  this->SetNu(0.5);
43  this->SetC(1.0);
44  this->SetEpsilon(1e-3);
45  this->SetP(0.1);
46  this->SetDoProbabilityEstimates(false);
47  this->DoShrinking(true);
48  this->SetCacheSize(40); // MB
49  this->m_ParameterOptimization = false;
50  this->m_IsRegressionSupported = true;
51  this->SetCVFolders(5);
52  this->m_InitialCrossValidationAccuracy = 0.;
53  this->m_FinalCrossValidationAccuracy = 0.;
54  this->m_CoarseOptimizationNumberOfSteps = 5;
55  this->m_FineOptimizationNumberOfSteps = 5;
57 
58  this->m_Parameters.nr_weight = 0;
59  this->m_Parameters.weight_label = nullptr;
60  this->m_Parameters.weight = nullptr;
61 
62  this->m_Model = nullptr;
63 
64  this->m_Problem.l = 0;
65  this->m_Problem.y = nullptr;
66  this->m_Problem.x = nullptr;
67 #ifdef NDEBUG
68  svm_set_print_string_function(&otb::Utils::PrintNothing);
69 #endif
70 }
71 
72 template <class TInputValue, class TOutputValue>
74 {
75  this->DeleteModel();
76  this->DeleteProblem();
77 }
78 
80 template <class TInputValue, class TOutputValue>
82 {
83  this->DeleteProblem();
84  this->DeleteModel();
86 
87  // Build problem
88  this->BuildProblem();
89 
90  // Check consistency
91  this->ConsistencyCheck();
92 
93  // Compute accuracy and eventually optimize parameters
94  this->OptimizeParameters();
95 
96  // train the model
97  m_Model = svm_train(&m_Problem, &m_Parameters);
98 
99  this->m_ConfidenceIndex = this->HasProbabilities();
100 }
101 
102 template <class TInputValue, class TOutputValue>
105 {
106  TargetSampleType target;
107  target.Fill(0);
108 
109  // Get type and number of classes
110  int svm_type = svm_get_svm_type(m_Model);
111 
112  // Allocate nodes (/TODO if performances problems are related to too
113  // many allocations, a cache approach can be set)
114  struct svm_node* x = new struct svm_node[input.Size() + 1];
115 
116  // Fill the node
117  for (unsigned int i = 0; i < input.Size(); i++)
118  {
119  x[i].index = i + 1;
120  x[i].value = input[i];
121  }
122 
123  // terminate node
124  x[input.Size()].index = -1;
125  x[input.Size()].value = 0;
126  if (proba != nullptr && !this->m_ProbaIndex)
127  itkExceptionMacro("Probability per class not available for this classifier !");
128 
129  if (quality != nullptr)
130  {
131  if (!this->m_ConfidenceIndex)
132  {
133  itkExceptionMacro("Confidence index not available for this classifier !");
134  }
135  if (this->m_ConfidenceMode == CM_INDEX)
136  {
137  if (svm_type == C_SVC || svm_type == NU_SVC)
138  {
139  // Eventually allocate space for probabilities
140  unsigned int nr_class = svm_get_nr_class(m_Model);
141  double* prob_estimates = new double[nr_class];
142  // predict
143  target[0] = static_cast<TargetValueType>(svm_predict_probability(m_Model, x, prob_estimates));
144  double maxProb = 0.0;
145  double secProb = 0.0;
146  for (unsigned int i = 0; i < nr_class; ++i)
147  {
148  if (maxProb < prob_estimates[i])
149  {
150  secProb = maxProb;
151  maxProb = prob_estimates[i];
152  }
153  else if (secProb < prob_estimates[i])
154  {
155  secProb = prob_estimates[i];
156  }
157  }
158  (*quality) = static_cast<ConfidenceValueType>(maxProb - secProb);
159 
160  delete[] prob_estimates;
161  }
162  else
163  {
164  target[0] = static_cast<TargetValueType>(svm_predict(m_Model, x));
165  // Prob. model for test data: target value = predicted value + z
166  // z: Laplace distribution e^(-|z|/sigma)/(2sigma)
167  // sigma is output as confidence index
168  (*quality) = svm_get_svr_probability(m_Model);
169  }
170  }
171  else if (this->m_ConfidenceMode == CM_PROBA)
172  {
173  target[0] = static_cast<TargetValueType>(svm_predict_probability(m_Model, x, quality));
174  }
175  else if (this->m_ConfidenceMode == CM_HYPER)
176  {
177  target[0] = static_cast<TargetValueType>(svm_predict_values(m_Model, x, quality));
178  }
179  }
180  else
181  {
182  // default case : if the model has probabilities, we call svm_predict_probabilities()
183  // which gives different results than svm_predict()
184  if (svm_check_probability_model(m_Model))
185  {
186  unsigned int nr_class = svm_get_nr_class(m_Model);
187  double* prob_estimates = new double[nr_class];
188  target[0] = static_cast<TargetValueType>(svm_predict_probability(m_Model, x, prob_estimates));
189  delete[] prob_estimates;
190  }
191  else
192  {
193  target[0] = static_cast<TargetValueType>(svm_predict(m_Model, x));
194  }
195  }
196 
197  // Free allocated memory
198  delete[] x;
199 
200  return target;
201 }
202 
203 template <class TInputValue, class TOutputValue>
204 void LibSVMMachineLearningModel<TInputValue, TOutputValue>::Save(const std::string& filename, const std::string& itkNotUsed(name))
205 {
206  if (svm_save_model(filename.c_str(), m_Model) != 0)
207  {
208  itkExceptionMacro(<< "Problem while saving SVM model " << filename);
209  }
210 }
211 
212 template <class TInputValue, class TOutputValue>
213 void LibSVMMachineLearningModel<TInputValue, TOutputValue>::Load(const std::string& filename, const std::string& itkNotUsed(name))
214 {
215  this->DeleteModel();
216  m_Model = svm_load_model(filename.c_str());
217  if (m_Model == nullptr)
218  {
219  itkExceptionMacro(<< "Problem while loading SVM model " << filename);
220  }
221  m_Parameters = m_Model->param;
222 
223  this->m_ConfidenceIndex = this->HasProbabilities();
224 }
225 
226 template <class TInputValue, class TOutputValue>
228 {
229  // TODO: Rework.
230  std::ifstream ifs;
231  ifs.open(file);
232 
233  if (!ifs)
234  {
235  std::cerr << "Could not read file " << file << std::endl;
236  return false;
237  }
238 
239  // Read only the first line.
240  std::string line;
241  std::getline(ifs, line);
242 
243  // if (line.find(m_SVMModel->getName()) != std::string::npos)
244  if (line.find("svm_type") != std::string::npos)
245  {
246  // std::cout<<"Reading a libSVM model"<<std::endl;
247  return true;
248  }
249  ifs.close();
250  return false;
251 }
252 
253 template <class TInputValue, class TOutputValue>
255 {
256  return false;
257 }
258 
259 template <class TInputValue, class TOutputValue>
260 void LibSVMMachineLearningModel<TInputValue, TOutputValue>::PrintSelf(std::ostream& os, itk::Indent indent) const
261 {
262  // Call superclass implementation
263  Superclass::PrintSelf(os, indent);
264 }
265 
266 template <class TInputValue, class TOutputValue>
268 {
269  bool modelHasProba = static_cast<bool>(svm_check_probability_model(m_Model));
270  int type = svm_get_svm_type(m_Model);
271  int cmMode = this->m_ConfidenceMode;
272  bool ret = false;
273  if (type == EPSILON_SVR || type == NU_SVR)
274  {
275  ret = (modelHasProba && cmMode == CM_INDEX);
276  }
277  else if (type == C_SVC || type == NU_SVC)
278  {
279  ret = (modelHasProba && (cmMode == CM_INDEX || cmMode == CM_PROBA)) || (cmMode == CM_HYPER);
280  }
281  return ret;
282 }
283 
284 template <class TInputValue, class TOutputValue>
286 {
287  // Get number of samples
288  typename InputListSampleType::Pointer samples = this->GetInputListSample();
289  typename TargetListSampleType::Pointer target = this->GetTargetListSample();
290  int probl = samples->Size();
291 
292  if (probl < 1)
293  {
294  itkExceptionMacro(<< "No samples, can not build SVM problem.");
295  }
296  otbMsgDebugMacro(<< "Building problem ...");
297 
298  // Get the size of the samples
299  long int elements = samples->GetMeasurementVectorSize();
300 
301  // Allocate the problem
302  m_Problem.l = probl;
303  m_Problem.y = new double[probl];
304  m_Problem.x = new struct svm_node*[probl];
305  for (int i = 0; i < probl; ++i)
306  {
307  m_Problem.x[i] = new struct svm_node[elements + 1];
308  }
309 
310  // Iterate on the samples
311  typename InputListSampleType::ConstIterator sIt = samples->Begin();
312  typename TargetListSampleType::ConstIterator tIt = target->Begin();
313  int sampleIndex = 0;
314 
315  while (sIt != samples->End() && tIt != target->End())
316  {
317  // Set the label
318  m_Problem.y[sampleIndex] = tIt.GetMeasurementVector()[0];
319  const InputSampleType& sample = sIt.GetMeasurementVector();
320  for (int k = 0; k < elements; ++k)
321  {
322  m_Problem.x[sampleIndex][k].index = k + 1;
323  m_Problem.x[sampleIndex][k].value = sample[k];
324  }
325  // terminate node
326  m_Problem.x[sampleIndex][elements].index = -1;
327  m_Problem.x[sampleIndex][elements].value = 0;
328 
329  ++sampleIndex;
330  ++sIt;
331  ++tIt;
332  }
333 
334  // Compute the kernel gamma from number of elements if necessary
335  if (this->GetKernelGamma() == 0)
336  {
337  this->SetKernelGamma(1.0 / static_cast<double>(elements));
338  }
339 
340  // allocate buffer for cross validation
341  m_TmpTarget.resize(m_Problem.l);
342 }
343 
344 template <class TInputValue, class TOutputValue>
346 {
347  if (this->GetSVMType() == ONE_CLASS && this->GetDoProbabilityEstimates())
348  {
349  otbMsgDebugMacro(<< "Disabling SVM probability estimates for ONE_CLASS SVM type.");
350  this->SetDoProbabilityEstimates(false);
351  }
352 
353  const char* error_msg = svm_check_parameter(&m_Problem, &m_Parameters);
354 
355  if (error_msg)
356  {
357  std::string err(error_msg);
358  itkExceptionMacro("SVM parameter check failed : " << err);
359  }
360 }
361 
362 template <class TInputValue, class TOutputValue>
364 {
365  if (m_Problem.y)
366  {
367  delete[] m_Problem.y;
368  m_Problem.y = nullptr;
369  }
370  if (m_Problem.x)
371  {
372  for (int i = 0; i < m_Problem.l; ++i)
373  {
374  if (m_Problem.x[i])
375  {
376  delete[] m_Problem.x[i];
377  }
378  }
379  delete[] m_Problem.x;
380  m_Problem.x = nullptr;
381  }
382  m_Problem.l = 0;
383 }
384 
385 template <class TInputValue, class TOutputValue>
387 {
388  if (m_Model)
389  {
390  svm_free_and_destroy_model(&m_Model);
391  }
392  m_Model = nullptr;
393 }
394 
395 template <class TInputValue, class TOutputValue>
397 {
398  unsigned int nb = 1;
399  switch (this->GetKernelType())
400  {
401  case LINEAR:
402  // C
403  nb = 1;
404  break;
405  case POLY:
406  // C, gamma and coef0
407  nb = 3;
408  break;
409  case RBF:
410  // C and gamma
411  nb = 2;
412  break;
413  case SIGMOID:
414  // C, gamma and coef0
415  nb = 3;
416  break;
417  default:
418  // C
419  nb = 1;
420  break;
421  }
422  return nb;
423 }
424 
425 template <class TInputValue, class TOutputValue>
427 {
428  double accuracy = 0.0;
429  // Get the length of the problem
430  unsigned int length = m_Problem.l;
431  if (length == 0 || m_TmpTarget.size() < length)
432  return accuracy;
433 
434  // Do cross validation
435  svm_cross_validation(&m_Problem, &m_Parameters, m_CVFolders, &m_TmpTarget[0]);
436 
437  // Evaluate accuracy
438  double total_correct = 0.;
439  for (unsigned int i = 0; i < length; ++i)
440  {
441  if (m_TmpTarget[i] == m_Problem.y[i])
442  {
443  ++total_correct;
444  }
445  }
446  accuracy = total_correct / length;
447 
448  // return accuracy value
449  return accuracy;
450 }
451 
452 template <class TInputValue, class TOutputValue>
454 {
456  typename CrossValidationFunctionType::Pointer crossValidationFunction = CrossValidationFunctionType::New();
457  crossValidationFunction->SetModel(this);
458 
459  typename CrossValidationFunctionType::ParametersType initialParameters, coarseBestParameters, fineBestParameters;
460 
461  unsigned int nbParams = this->GetNumberOfKernelParameters();
462  initialParameters.SetSize(nbParams);
463  initialParameters[0] = this->GetC();
464  if (nbParams > 1)
465  initialParameters[1] = this->GetKernelGamma();
466  if (nbParams > 2)
467  initialParameters[2] = this->GetKernelCoef0();
468 
469  m_InitialCrossValidationAccuracy = crossValidationFunction->GetValue(initialParameters);
470  m_FinalCrossValidationAccuracy = m_InitialCrossValidationAccuracy;
471 
472  otbMsgDebugMacro(<< "Initial accuracy : " << m_InitialCrossValidationAccuracy << ", Parameters Optimization" << m_ParameterOptimization);
473 
474  if (m_ParameterOptimization)
475  {
476  otbMsgDebugMacro(<< "Model parameters optimization");
478  typename ExhaustiveExponentialOptimizer::StepsType coarseNbSteps(initialParameters.Size());
479  coarseNbSteps.Fill(m_CoarseOptimizationNumberOfSteps);
480 
481  coarseOptimizer->SetNumberOfSteps(coarseNbSteps);
482  coarseOptimizer->SetCostFunction(crossValidationFunction);
483  coarseOptimizer->SetInitialPosition(initialParameters);
484  coarseOptimizer->StartOptimization();
485 
486  coarseBestParameters = coarseOptimizer->GetMaximumMetricValuePosition();
487 
488  otbMsgDevMacro(<< "Coarse minimum accuracy: " << coarseOptimizer->GetMinimumMetricValue() << " " << coarseOptimizer->GetMinimumMetricValuePosition());
489  otbMsgDevMacro(<< "Coarse maximum accuracy: " << coarseOptimizer->GetMaximumMetricValue() << " " << coarseOptimizer->GetMaximumMetricValuePosition());
490 
492  typename ExhaustiveExponentialOptimizer::StepsType fineNbSteps(initialParameters.Size());
493  fineNbSteps.Fill(m_FineOptimizationNumberOfSteps);
494 
495  double stepLength = 1. / static_cast<double>(m_FineOptimizationNumberOfSteps);
496 
497  fineOptimizer->SetNumberOfSteps(fineNbSteps);
498  fineOptimizer->SetStepLength(stepLength);
499  fineOptimizer->SetCostFunction(crossValidationFunction);
500  fineOptimizer->SetInitialPosition(coarseBestParameters);
501  fineOptimizer->StartOptimization();
502 
503  otbMsgDevMacro(<< "Fine minimum accuracy: " << fineOptimizer->GetMinimumMetricValue() << " " << fineOptimizer->GetMinimumMetricValuePosition());
504  otbMsgDevMacro(<< "Fine maximum accuracy: " << fineOptimizer->GetMaximumMetricValue() << " " << fineOptimizer->GetMaximumMetricValuePosition());
505 
506  fineBestParameters = fineOptimizer->GetMaximumMetricValuePosition();
507 
508  m_FinalCrossValidationAccuracy = fineOptimizer->GetMaximumMetricValue();
509 
510  this->SetC(fineBestParameters[0]);
511  if (nbParams > 1)
512  this->SetKernelGamma(fineBestParameters[1]);
513  if (nbParams > 2)
514  this->SetKernelCoef0(fineBestParameters[2]);
515  }
516 }
517 
518 } // end namespace otb
519 
520 #endif
otb::LibSVMMachineLearningModel::CrossValidation
double CrossValidation(void)
Definition: otbLibSVMMachineLearningModel.hxx:426
otbLibSVMMachineLearningModel.h
otb::LibSVMMachineLearningModel::~LibSVMMachineLearningModel
~LibSVMMachineLearningModel() override
Definition: otbLibSVMMachineLearningModel.hxx:73
otb::LibSVMMachineLearningModel::Save
void Save(const std::string &filename, const std::string &name="") override
Definition: otbLibSVMMachineLearningModel.hxx:204
otbUtils.h
otb::LibSVMMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbLibSVMMachineLearningModel.h:49
otb::Utils::PrintNothing
void OTBCommon_EXPORT PrintNothing(const char *s)
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::LibSVMMachineLearningModel::BuildProblem
void BuildProblem(void)
Definition: otbLibSVMMachineLearningModel.hxx:285
otb::SVMCrossValidationCostFunction
This function returns the cross validation accuracy of a SVM model.
Definition: otbSVMCrossValidationCostFunction.h:53
otb::LibSVMMachineLearningModel
Definition: otbLibSVMMachineLearningModel.h:33
otb::LibSVMMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbLibSVMMachineLearningModel.h:46
otbExhaustiveExponentialOptimizer.h
otbMacro.h
otb::LibSVMMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbLibSVMMachineLearningModel.h:48
otb::LibSVMMachineLearningModel::LibSVMMachineLearningModel
LibSVMMachineLearningModel()
Definition: otbLibSVMMachineLearningModel.hxx:35
otb::ExhaustiveExponentialOptimizer::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbExhaustiveExponentialOptimizer.h:53
otb::LibSVMMachineLearningModel::DeleteProblem
void DeleteProblem(void)
Definition: otbLibSVMMachineLearningModel.hxx:363
otb::ExhaustiveExponentialOptimizer::StepsType
itk::Array< unsigned long > StepsType
Definition: otbExhaustiveExponentialOptimizer.h:56
otb::LibSVMMachineLearningModel::TargetValueType
Superclass::TargetValueType TargetValueType
Definition: otbLibSVMMachineLearningModel.h:45
otb::LibSVMMachineLearningModel::Load
void Load(const std::string &filename, const std::string &name="") override
Definition: otbLibSVMMachineLearningModel.hxx:213
otb::LibSVMMachineLearningModel::ConsistencyCheck
void ConsistencyCheck(void)
Definition: otbLibSVMMachineLearningModel.hxx:345
otb::LibSVMMachineLearningModel::CanWriteFile
bool CanWriteFile(const std::string &) override
Definition: otbLibSVMMachineLearningModel.hxx:254
otb::LibSVMMachineLearningModel::GetNumberOfKernelParameters
unsigned int GetNumberOfKernelParameters()
Definition: otbLibSVMMachineLearningModel.hxx:396
otb::LibSVMMachineLearningModel::Train
void Train() override
Definition: otbLibSVMMachineLearningModel.hxx:81
otb::ExhaustiveExponentialOptimizer::New
static Pointer New()
otb::LibSVMMachineLearningModel::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbLibSVMMachineLearningModel.hxx:260
otbMsgDebugMacro
#define otbMsgDebugMacro(x)
Definition: otbMacro.h:62
otb::LibSVMMachineLearningModel::CanReadFile
bool CanReadFile(const std::string &) override
Definition: otbLibSVMMachineLearningModel.hxx:227
otbSVMCrossValidationCostFunction.h
otb::LibSVMMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbLibSVMMachineLearningModel.h:43
otbMsgDevMacro
#define otbMsgDevMacro(x)
Definition: otbMacro.h:64
otb::LibSVMMachineLearningModel::DoPredict
TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Definition: otbLibSVMMachineLearningModel.hxx:104
otb::LibSVMMachineLearningModel::HasProbabilities
bool HasProbabilities(void) const
Definition: otbLibSVMMachineLearningModel.hxx:267
otb::LibSVMMachineLearningModel::DeleteModel
void DeleteModel(void)
Definition: otbLibSVMMachineLearningModel.hxx:386
otb::LibSVMMachineLearningModel::OptimizeParameters
void OptimizeParameters(void)
Definition: otbLibSVMMachineLearningModel.hxx:453