OTB  5.11.0
Orfeo Toolbox
otbSVMModel.txx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSVMModel_txx
22 #define otbSVMModel_txx
23 #include "otbSVMModel.h"
24 #include "otbMacro.h"
25 
26 #include <algorithm>
27 
28 namespace otb
29 {
30 // TODO: Check memory allocation in this class
31 template <class TValue, class TLabel>
33 {
34  // Default parameters
35  this->SetSVMType(C_SVC);
36  this->SetKernelType(LINEAR);
37  this->SetPolynomialKernelDegree(3);
38  this->SetKernelGamma(1.); // 1/k
39  this->SetKernelCoef0(1.);
40  this->SetNu(0.5);
41  this->SetCacheSize(40);
42  this->SetC(1);
43  this->SetEpsilon(1e-3);
44  this->SetP(0.1);
45  this->DoShrinking(true);
46  this->DoProbabilityEstimates(false);
47 
48  m_Parameters.nr_weight = 0;
49  m_Parameters.weight_label = ITK_NULLPTR;
50  m_Parameters.weight = ITK_NULLPTR;
51 
52  m_Model = ITK_NULLPTR;
53 
54  this->Initialize();
55 }
56 
57 template <class TValue, class TLabel>
59 {
60  this->DeleteModel();
61  this->DeleteProblem();
62 }
63 template <class TValue, class TLabel>
64 void
66 {
67  // Initialize model
68  /*
69  if (!m_Model)
70  {
71  m_Model = new struct svm_model;
72  m_Model->l = 0;
73  m_Model->nr_class = 0;
74  m_Model->SV = NULL;
75  m_Model->sv_coef = NULL;
76  m_Model->rho = NULL;
77  m_Model->label = NULL;
78  m_Model->probA = NULL;
79  m_Model->probB = NULL;
80  m_Model->nSV = NULL;
81 
82  m_ModelUpToDate = false;
83 
84  } */
85  m_ModelUpToDate = false;
86 
87  // Initialize problem
88  m_Problem.l = 0;
89  m_Problem.y = ITK_NULLPTR;
90  m_Problem.x = ITK_NULLPTR;
91 
92  m_ProblemUpToDate = false;
93 }
94 
95 template <class TValue, class TLabel>
96 void
98 {
99  this->DeleteProblem();
100  this->DeleteModel();
101 
102  // Clear samples
103  m_Samples.clear();
104 
105  // Initialize values
106  this->Initialize();
107 }
108 
109 template <class TValue, class TLabel>
110 void
112 {
113  if(m_Model)
114  {
115  svm_free_and_destroy_model(&m_Model);
116  }
117  m_Model = ITK_NULLPTR;
118 }
119 
120 template <class TValue, class TLabel>
121 void
123 {
124 // Deallocate any existing problem
125  if (m_Problem.y)
126  {
127  delete[] m_Problem.y;
128  m_Problem.y = ITK_NULLPTR;
129  }
130 
131  if (m_Problem.x)
132  {
133  for (int i = 0; i < m_Problem.l; ++i)
134  {
135  if (m_Problem.x[i])
136  {
137  delete[] m_Problem.x[i];
138  }
139  }
140  delete[] m_Problem.x;
141  m_Problem.x = ITK_NULLPTR;
142  }
143  m_Problem.l = 0;
144  m_ProblemUpToDate = false;
145 }
146 
147 template <class TValue, class TLabel>
148 void
150 {
151  SampleType newSample(measure, label);
152  m_Samples.push_back(newSample);
153  m_ProblemUpToDate = false;
154 }
155 
156 template <class TValue, class TLabel>
157 void
159 {
160  m_Samples.clear();
161  m_ProblemUpToDate = false;
162 }
163 
164 template <class TValue, class TLabel>
165 void
167 {
168  m_Samples = samples;
169  m_ProblemUpToDate = false;
170 }
171 
172 template <class TValue, class TLabel>
173 void
175 {
176  // Check if problem is up-to-date
177  if (m_ProblemUpToDate)
178  {
179  return;
180  }
181 
182  // Get number of samples
183  int probl = m_Samples.size();
184 
185  if (probl < 1)
186  {
187  itkExceptionMacro(<< "No samples, can not build SVM problem.");
188  }
189  otbMsgDebugMacro(<< "Rebuilding problem ...");
190 
191  // Get the size of the samples
192  long int elements = m_Samples[0].first.size() + 1;
193 
194  // Deallocate any previous problem
195  this->DeleteProblem();
196 
197  // Allocate the problem
198  m_Problem.l = probl;
199  m_Problem.y = new double[probl];
200  m_Problem.x = new struct svm_node*[probl];
201 
202  for (int i = 0; i < probl; ++i)
203  {
204  // Initialize labels to 0
205  m_Problem.y[i] = 0;
206  m_Problem.x[i] = new struct svm_node[elements];
207 
208  // Initialize elements (value = 0; index = -1)
209  for (unsigned int j = 0; j < static_cast<unsigned int>(elements); ++j)
210  {
211  m_Problem.x[i][j].index = -1;
212  m_Problem.x[i][j].value = 0;
213  }
214  }
215 
216  // Iterate on the samples
217  typename SamplesVectorType::const_iterator sIt = m_Samples.begin();
218  int sampleIndex = 0;
219  int maxElementIndex = 0;
220 
221  while (sIt != m_Samples.end())
222  {
223 
224  // Get the sample measurement and label
225  MeasurementType measure = sIt->first;
226  LabelType label = sIt->second;
227 
228  // Set the label
229  m_Problem.y[sampleIndex] = label;
230 
231  int elementIndex = 0;
232 
233  // Populate the svm nodes
234  for (typename MeasurementType::const_iterator eIt = measure.begin();
235  eIt != measure.end() && elementIndex < elements; ++eIt, ++elementIndex)
236  {
237  m_Problem.x[sampleIndex][elementIndex].index = elementIndex + 1;
238  m_Problem.x[sampleIndex][elementIndex].value = (*eIt);
239  }
240 
241  // Get the max index
242  if (elementIndex > maxElementIndex)
243  {
244  maxElementIndex = elementIndex;
245  }
246 
247  ++sampleIndex;
248  ++sIt;
249  }
250 
251  // Compute the kernel gamma from maxElementIndex if necessary
252  if (this->GetKernelGamma() == 0)
253  {
254  this->SetKernelGamma(1.0 / static_cast<double>(maxElementIndex));
255  }
256 
257  // problem is up-to-date
258  m_ProblemUpToDate = true;
259 }
260 
261 template <class TValue, class TLabel>
262 double
264 {
265  // Build problem
266  this->BuildProblem();
267 
268  // Check consistency
269  this->ConsistencyCheck();
270 
271  // Get the length of the problem
272  int length = m_Problem.l;
273 
274  // Temporary memory to store cross validation results
275  double *target = new double[length];
276 
277  // Do cross validation
278  svm_cross_validation(&m_Problem, &m_Parameters, nbFolders, target);
279 
280  // Evaluate accuracy
281  int i;
282  double total_correct = 0.;
283 
284  for (i = 0; i < length; ++i)
285  {
286  if (target[i] == m_Problem.y[i])
287  {
288  ++total_correct;
289  }
290  }
291  double accuracy = total_correct / length;
292 
293  // Free temporary memory
294  delete[] target;
295 
296  // return accuracy value
297  return accuracy;
298 }
299 
300 template <class TValue, class TLabel>
301 void
303 {
304  if (m_Parameters.svm_type == ONE_CLASS && this->GetDoProbabilityEstimates())
305  {
306  otbMsgDebugMacro(<< "Disabling SVM probability estimates for ONE_CLASS SVM type.");
307  this->DoProbabilityEstimates(false);
308  }
309 
310  const char* error_msg = svm_check_parameter(&m_Problem, &m_Parameters);
311 
312  if (error_msg)
313  {
314  throw itk::ExceptionObject(__FILE__, __LINE__, error_msg, ITK_LOCATION);
315  }
316 }
317 
318 template <class TValue, class TLabel>
319 void
321 {
322  // If the model is already up-to-date, return
323  if (m_ModelUpToDate)
324  {
325  return;
326  }
327 
328  // Build problem
329  this->BuildProblem();
330 
331  // Check consistency
332  this->ConsistencyCheck();
333 
334  // train the model
335  m_Model = svm_train(&m_Problem, &m_Parameters);
336 
337  // Set the model as up-to-date
338  m_ModelUpToDate = true;
339 }
340 
341 template <class TValue, class TLabel>
344 {
345  // Check if model is up-to-date
346  if (!m_ModelUpToDate)
347  {
348  itkExceptionMacro(<< "Model is not up-to-date, can not predict label");
349  }
350 
351  // Check probability prediction
352  bool predict_probability = svm_check_probability_model(m_Model);
353 
354  if (this->GetSVMType() == ONE_CLASS)
355  {
356  predict_probability = 0;
357  }
358 
359  // Get type and number of classes
360  int svm_type = svm_get_svm_type(m_Model);
361  int nr_class = svm_get_nr_class(m_Model);
362 
363  // Allocate space for labels
364  double *prob_estimates = ITK_NULLPTR;
365 
366  // Eventually allocate space for probabilities
367  if (predict_probability)
368  {
369  if (svm_type == NU_SVR || svm_type == EPSILON_SVR)
370  {
372  <<
373  "Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma), sigma="
374  << svm_get_svr_probability(m_Model));
375  }
376  else
377  {
378  prob_estimates = new double[nr_class];
379  }
380  }
381 
382  // Allocate nodes (/TODO if performances problems are related to too
383  // many allocations, a cache approach can be set)
384  struct svm_node * x = new struct svm_node[measure.size() + 1];
385 
386  int valueIndex = 0;
387 
388  // Fill the node
389  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
390  {
391  x[valueIndex].index = valueIndex + 1;
392  x[valueIndex].value = (*mIt);
393  }
394 
395  // terminate node
396  x[measure.size()].index = -1;
397  x[measure.size()].value = 0;
398 
399  LabelType label = 0;
400 
401  if (predict_probability && (svm_type == C_SVC || svm_type == NU_SVC))
402  {
403  label = static_cast<LabelType>(svm_predict_probability(m_Model, x, prob_estimates));
404  }
405  else
406  {
407  label = static_cast<LabelType>(svm_predict(m_Model, x));
408  }
409 
410  // Free allocated memory
411  delete[] x;
412 
413  if (prob_estimates)
414  {
415  delete[] prob_estimates;
416  }
417 
418  return label;
419 }
420 
421 template <class TValue, class TLabel>
424 {
425  // Check if model is up-to-date
426  if (!m_ModelUpToDate)
427  {
428  itkExceptionMacro(<< "Model is not up-to-date, can not predict label");
429  }
430 
431  // Allocate nodes (/TODO if performances problems are related to too
432  // many allocations, a cache approach can be set)
433  struct svm_node * x = new struct svm_node[measure.size() + 1];
434 
435  int valueIndex = 0;
436 
437  // Fill the node
438  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
439  {
440  x[valueIndex].index = valueIndex + 1;
441  x[valueIndex].value = (*mIt);
442  }
443 
444  // terminate node
445  x[measure.size()].index = -1;
446  x[measure.size()].value = 0;
447 
448  // Initialize distances vector
449  DistancesVectorType distances(m_Model->nr_class*(m_Model->nr_class - 1) / 2);
450 
451  // predict distances vector
452  svm_predict_values(m_Model, x, (double*) (distances.GetDataPointer()));
453 
454  // Free allocated memory
455  delete[] x;
456 
457  return (distances);
458 }
459 
460 template <class TValue, class TLabel>
463 {
464  // Check if model is up-to-date
465  if (!m_ModelUpToDate)
466  {
467  itkExceptionMacro(<< "Model is not up-to-date, can not predict probabilities");
468  }
469 
470  if (!this->HasProbabilities())
471  {
472  throw itk::ExceptionObject(__FILE__, __LINE__,
473  "Model does not support probability estimates", ITK_LOCATION);
474  }
475 
476  // Get number of classes
477  int nr_class = svm_get_nr_class(m_Model);
478 
479  // Allocate nodes (/TODO if performances problems are related to too
480  // many allocations, a cache approach can be set)
481  struct svm_node * x = new struct svm_node[measure.size() + 1];
482 
483  int valueIndex = 0;
484 
485  // Fill the node
486  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
487  {
488  x[valueIndex].index = valueIndex + 1;
489  x[valueIndex].value = (*mIt);
490  }
491 
492  // Termination node
493  x[measure.size()].index = -1;
494  x[measure.size()].value = 0;
495 
496  double* dec_values = new double[nr_class];
497  svm_predict_probability(m_Model, x, dec_values);
498 
499  // Reorder values in increasing class label
500  int* labels = m_Model->label;
501  std::vector<int> orderedLabels(nr_class);
502  std::copy(labels, labels + nr_class, orderedLabels.begin());
503  std::sort(orderedLabels.begin(), orderedLabels.end());
504 
505  ProbabilitiesVectorType probabilities(nr_class);
506  for (int i = 0; i < nr_class; ++i)
507  {
508  // svm_predict_probability is such that "dec_values[i]" corresponds to label "labels[i]"
509  std::vector<int>::iterator it = std::find(orderedLabels.begin(), orderedLabels.end(), labels[i]);
510  probabilities[it - orderedLabels.begin()] = dec_values[i];
511  }
512 
513  // Free allocated memory
514  delete[] x;
515  delete[] dec_values;
516 
517  return probabilities;
518 }
519 
520 
521 template <class TValue, class TLabel>
522 void
523 SVMModel<TValue, TLabel>::SaveModel(const char* model_file_name) const
524 {
525  if (svm_save_model(model_file_name, m_Model) != 0)
526  {
527  itkExceptionMacro(<< "Problem while saving SVM model "
528  << std::string(model_file_name));
529  }
530 }
531 
532 template <class TValue, class TLabel>
533 void
534 SVMModel<TValue, TLabel>::LoadModel(const char* model_file_name)
535 {
536  this->DeleteModel();
537  m_Model = svm_load_model(model_file_name);
538  if (m_Model == ITK_NULLPTR)
539  {
540  itkExceptionMacro(<< "Problem while loading SVM model "
541  << std::string(model_file_name));
542  }
543  m_Parameters = m_Model->param;
544  m_ModelUpToDate = true;
545 }
546 
547 template <class TValue, class TLabel>
548 void
549 SVMModel<TValue, TLabel>::PrintSelf(std::ostream& os, itk::Indent indent) const
550 {
551  Superclass::PrintSelf(os, indent);
552 }
553 
554 template <class TValue, class TLabel>
555 void
556 SVMModel<TValue, TLabel>::SetSupportVectors(svm_node ** sv, int nbOfSupportVector)
557 {
558  if (!m_Model)
559  {
560  itkExceptionMacro( "Internal SVM model is empty!");
561  }
562 
563  // erase the old SV
564  // delete just the first element, it destoyes the whole pointers (cf SV filling with x_space)
565  delete[] (m_Model->SV[0]);
566 
567  for (int n = 0; n < m_Model->l; ++n)
568  {
569  m_Model->SV[n] = ITK_NULLPTR;
570  }
571  delete[] (m_Model->SV);
572  m_Model->SV = ITK_NULLPTR;
573 
574  m_Model->SV = new struct svm_node*[m_Model->l];
575 
576  // copy new SV values
577  svm_node **SV = m_Model->SV;
578 
579  // Compute the total number of SV elements.
580  unsigned int elements = 0;
581  for (int p = 0; p < nbOfSupportVector; ++p)
582  {
583  //std::cout << p << " ";
584  const svm_node *tempNode = sv[p];
585  //std::cout << p << " ";
586  while (tempNode->index != -1)
587  {
588  tempNode++;
589  ++elements;
590  }
591  ++elements; // for -1 values
592  }
593 
594  if (m_Model->l > 0)
595  {
596  SV[0] = new struct svm_node[elements];
597  memcpy(SV[0], sv[0], sizeof(svm_node*) * elements);
598  }
599  svm_node *x_space = SV[0];
600 
601  int j = 0;
602  for (int i = 0; i < m_Model->l; ++i)
603  {
604  // SV
605  SV[i] = &x_space[j];
606  const svm_node *p = sv[i];
607  svm_node * pCpy = SV[i];
608  while (p->index != -1)
609  {
610  pCpy->index = p->index;
611  pCpy->value = p->value;
612  ++p;
613  ++pCpy;
614  ++j;
615  }
616  pCpy->index = -1;
617  ++j;
618  }
619 
620  if (m_Model->l > 0)
621  {
622  delete[] SV[0];
623  }
624 }
625 
626 template <class TValue, class TLabel>
627 void
628 SVMModel<TValue, TLabel>::SetAlpha(double ** alpha, int itkNotUsed(nbOfSupportVector))
629 {
630  if (!m_Model)
631  {
632  itkExceptionMacro( "Internal SVM model is empty!");
633  }
634 
635  // Erase the old sv_coef
636  for (int i = 0; i < m_Model->nr_class - 1; ++i)
637  {
638  delete[] m_Model->sv_coef[i];
639  }
640  delete[] m_Model->sv_coef;
641 
642  // copy new sv_coef values
643  m_Model->sv_coef = new double*[m_Model->nr_class - 1];
644  for (int i = 0; i < m_Model->nr_class - 1; ++i)
645  m_Model->sv_coef[i] = new double[m_Model->l];
646 
647  for (int i = 0; i < m_Model->l; ++i)
648  {
649  // sv_coef
650  for (int k = 0; k < m_Model->nr_class - 1; ++k)
651  {
652  m_Model->sv_coef[k][i] = alpha[k][i];
653  }
654  }
655 }
656 
657 } // end namespace otb
658 
659 #endif
void SaveModel(const char *model_file_name) const
DistancesVectorType EvaluateHyperplanesDistances(const MeasurementType &measure) const
void SetSamples(const SamplesVectorType &samples)
void ClearSamples()
void AddSample(const MeasurementType &measure, const LabelType &label)
~SVMModel() ITK_OVERRIDE
Definition: otbSVMModel.txx:58
void BuildProblem()
void LoadModel(const char *model_file_name)
void PrintSelf(std::ostream &os, itk::Indent indent) const ITK_OVERRIDE
void ConsistencyCheck()
void Initialize() ITK_OVERRIDE
Definition: otbSVMModel.txx:65
#define otbMsgDebugMacro(x)
Definition: otbMacro.h:58
void SetAlpha(double **alpha, int nbOfSupportVector)
LabelType EvaluateLabel(const MeasurementType &measure) const
TLabel LabelType
Definition: otbSVMModel.h:77
double CrossValidation(unsigned int nbFolders)
const TValue * GetDataPointer() const noexcept
void DeleteProblem()
ProbabilitiesVectorType EvaluateProbabilities(const MeasurementType &measure) const
std::vector< SampleType > SamplesVectorType
Definition: otbSVMModel.h:80
std::vector< ValueType > MeasurementType
Definition: otbSVMModel.h:78
std::pair< MeasurementType, LabelType > SampleType
Definition: otbSVMModel.h:79
void SetSupportVectors(svm_node **sv, int nbOfSupportVector)
#define otbMsgDevMacro(x)
Definition: otbMacro.h:98
void DeleteModel()