OTB  5.0.0
Orfeo Toolbox
otbSVMModel.txx
Go to the documentation of this file.
1 /*=========================================================================
2 
3  Program: ORFEO Toolbox
4  Language: C++
5  Date: $Date$
6  Version: $Revision$
7 
8 
9  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
10  See OTBCopyright.txt for details.
11 
12 
13  This software is distributed WITHOUT ANY WARRANTY; without even
14  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
15  PURPOSE. See the above copyright notices for more information.
16 
17 =========================================================================*/
18 #ifndef __otbSVMModel_txx
19 #define __otbSVMModel_txx
#include "otbSVMModel.h"
#include "otbMacro.h"

#include <algorithm>
#include <cstdlib>
#include <string>
#include <vector>
24 
25 namespace otb
26 {
27 // TODO: Check memory allocation in this class
28 template <class TValue, class TLabel>
30 {
31  // Default parameters
32  this->SetSVMType(C_SVC);
33  this->SetKernelType(LINEAR);
34  this->SetPolynomialKernelDegree(3);
35  this->SetKernelGamma(1.); // 1/k
36  this->SetKernelCoef0(1.);
37  this->SetNu(0.5);
38  this->SetCacheSize(40);
39  this->SetC(1);
40  this->SetEpsilon(1e-3);
41  this->SetP(0.1);
42  this->DoShrinking(true);
43  this->DoProbabilityEstimates(false);
44 
45  m_Parameters.nr_weight = 0;
46  m_Parameters.weight_label = NULL;
47  m_Parameters.weight = NULL;
48 
49  m_Model = NULL;
50 
51  this->Initialize();
52 }
53 
54 template <class TValue, class TLabel>
56 {
57  this->DeleteModel();
58  this->DeleteProblem();
59 }
60 template <class TValue, class TLabel>
61 void
63 {
64  // Initialize model
65  /*
66  if (!m_Model)
67  {
68  m_Model = new struct svm_model;
69  m_Model->l = 0;
70  m_Model->nr_class = 0;
71  m_Model->SV = NULL;
72  m_Model->sv_coef = NULL;
73  m_Model->rho = NULL;
74  m_Model->label = NULL;
75  m_Model->probA = NULL;
76  m_Model->probB = NULL;
77  m_Model->nSV = NULL;
78 
79  m_ModelUpToDate = false;
80 
81  } */
82  m_ModelUpToDate = false;
83 
84  // Intialize problem
85  m_Problem.l = 0;
86  m_Problem.y = NULL;
87  m_Problem.x = NULL;
88 
89  m_ProblemUpToDate = false;
90 }
91 
92 template <class TValue, class TLabel>
93 void
95 {
96  this->DeleteProblem();
97  this->DeleteModel();
98 
99  // Clear samples
100  m_Samples.clear();
101 
102  // Initialize values
103  this->Initialize();
104 }
105 
106 template <class TValue, class TLabel>
107 void
109 {
110  if(m_Model)
111  {
112  svm_free_and_destroy_model(&m_Model);
113  }
114  m_Model = NULL;
115 }
116 
117 template <class TValue, class TLabel>
118 void
120 {
121 // Deallocate any existing problem
122  if (m_Problem.y)
123  {
124  delete[] m_Problem.y;
125  m_Problem.y = NULL;
126  }
127 
128  if (m_Problem.x)
129  {
130  for (int i = 0; i < m_Problem.l; ++i)
131  {
132  if (m_Problem.x[i])
133  {
134  delete[] m_Problem.x[i];
135  }
136  }
137  delete[] m_Problem.x;
138  m_Problem.x = NULL;
139  }
140  m_Problem.l = 0;
141  m_ProblemUpToDate = false;
142 }
143 
144 template <class TValue, class TLabel>
145 void
147 {
148  SampleType newSample(measure, label);
149  m_Samples.push_back(newSample);
150  m_ProblemUpToDate = false;
151 }
152 
153 template <class TValue, class TLabel>
154 void
156 {
157  m_Samples.clear();
158  m_ProblemUpToDate = false;
159 }
160 
161 template <class TValue, class TLabel>
162 void
164 {
165  m_Samples = samples;
166  m_ProblemUpToDate = false;
167 }
168 
169 template <class TValue, class TLabel>
170 void
172 {
173  // Check if problem is up-to-date
174  if (m_ProblemUpToDate)
175  {
176  return;
177  }
178 
179  // Get number of samples
180  int probl = m_Samples.size();
181 
182  if (probl < 1)
183  {
184  itkExceptionMacro(<< "No samples, can not build SVM problem.");
185  }
186  otbMsgDebugMacro(<< "Rebuilding problem ...");
187 
188  // Get the size of the samples
189  long int elements = m_Samples[0].first.size() + 1;
190 
191  // Deallocate any previous problem
192  this->DeleteProblem();
193 
194  // Allocate the problem
195  m_Problem.l = probl;
196  m_Problem.y = new double[probl];
197  m_Problem.x = new struct svm_node*[probl];
198 
199  for (int i = 0; i < probl; ++i)
200  {
201  // Initialize labels to 0
202  m_Problem.y[i] = 0;
203  m_Problem.x[i] = new struct svm_node[elements];
204 
205  // Intialize elements (value = 0; index = -1)
206  for (unsigned int j = 0; j < static_cast<unsigned int>(elements); ++j)
207  {
208  m_Problem.x[i][j].index = -1;
209  m_Problem.x[i][j].value = 0;
210  }
211  }
212 
213  // Iterate on the samples
214  typename SamplesVectorType::const_iterator sIt = m_Samples.begin();
215  int sampleIndex = 0;
216  int maxElementIndex = 0;
217 
218  while (sIt != m_Samples.end())
219  {
220 
221  // Get the sample measurement and label
222  MeasurementType measure = sIt->first;
223  LabelType label = sIt->second;
224 
225  // Set the label
226  m_Problem.y[sampleIndex] = label;
227 
228  int elementIndex = 0;
229 
230  // Populate the svm nodes
231  for (typename MeasurementType::const_iterator eIt = measure.begin();
232  eIt != measure.end() && elementIndex < elements; ++eIt, ++elementIndex)
233  {
234  m_Problem.x[sampleIndex][elementIndex].index = elementIndex + 1;
235  m_Problem.x[sampleIndex][elementIndex].value = (*eIt);
236  }
237 
238  // Get the max index
239  if (elementIndex > maxElementIndex)
240  {
241  maxElementIndex = elementIndex;
242  }
243 
244  ++sampleIndex;
245  ++sIt;
246  }
247 
248  // Compute the kernel gamma from maxElementIndex if necessary
249  if (this->GetKernelGamma() == 0)
250  {
251  this->SetKernelGamma(1.0 / static_cast<double>(maxElementIndex));
252  }
253 
254  // problem is up-to-date
255  m_ProblemUpToDate = true;
256 }
257 
258 template <class TValue, class TLabel>
259 double
261 {
262  // Build problem
263  this->BuildProblem();
264 
265  // Check consistency
266  this->ConsistencyCheck();
267 
268  // Get the length of the problem
269  int length = m_Problem.l;
270 
271  // Temporary memory to store cross validation results
272  double *target = new double[length];
273 
274  // Do cross validation
275  svm_cross_validation(&m_Problem, &m_Parameters, nbFolders, target);
276 
277  // Evaluate accuracy
278  int i;
279  double total_correct = 0.;
280 
281  for (i = 0; i < length; ++i)
282  {
283  if (target[i] == m_Problem.y[i])
284  {
285  ++total_correct;
286  }
287  }
288  double accuracy = total_correct / length;
289 
290  // Free temporary memory
291  delete[] target;
292 
293  // return accuracy value
294  return accuracy;
295 }
296 
297 template <class TValue, class TLabel>
298 void
300 {
301  if (m_Parameters.svm_type == ONE_CLASS && this->GetDoProbabilityEstimates())
302  {
303  otbMsgDebugMacro(<< "Disabling SVM probability estimates for ONE_CLASS SVM type.");
304  this->DoProbabilityEstimates(false);
305  }
306 
307  const char* error_msg = svm_check_parameter(&m_Problem, &m_Parameters);
308 
309  if (error_msg)
310  {
311  throw itk::ExceptionObject(__FILE__, __LINE__, error_msg, ITK_LOCATION);
312  }
313 }
314 
315 template <class TValue, class TLabel>
316 void
318 {
319  // If the model is already up-to-date, return
320  if (m_ModelUpToDate)
321  {
322  return;
323  }
324 
325  // Build problem
326  this->BuildProblem();
327 
328  // Check consistency
329  this->ConsistencyCheck();
330 
331  // train the model
332  m_Model = svm_train(&m_Problem, &m_Parameters);
333 
334  // Set the model as up-to-date
335  m_ModelUpToDate = true;
336 }
337 
338 template <class TValue, class TLabel>
341 {
342  // Check if model is up-to-date
343  if (!m_ModelUpToDate)
344  {
345  itkExceptionMacro(<< "Model is not up-to-date, can not predict label");
346  }
347 
348  // Check probability prediction
349  bool predict_probability = svm_check_probability_model(m_Model);
350 
351  if (this->GetSVMType() == ONE_CLASS)
352  {
353  predict_probability = 0;
354  }
355 
356  // Get type and number of classes
357  int svm_type = svm_get_svm_type(m_Model);
358  int nr_class = svm_get_nr_class(m_Model);
359 
360  // Allocate space for labels
361  double *prob_estimates = NULL;
362 
363  // Eventually allocate space for probabilities
364  if (predict_probability)
365  {
366  if (svm_type == NU_SVR || svm_type == EPSILON_SVR)
367  {
369  <<
370  "Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma), sigma="
371  << svm_get_svr_probability(m_Model));
372  }
373  else
374  {
375  prob_estimates = new double[nr_class];
376  }
377  }
378 
379  // Allocate nodes (/TODO if performances problems are related to too
380  // many allocations, a cache approach can be set)
381  struct svm_node * x = new struct svm_node[measure.size() + 1];
382 
383  int valueIndex = 0;
384 
385  // Fill the node
386  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
387  {
388  x[valueIndex].index = valueIndex + 1;
389  x[valueIndex].value = (*mIt);
390  }
391 
392  // terminate node
393  x[measure.size()].index = -1;
394  x[measure.size()].value = 0;
395 
396  LabelType label = 0;
397 
398  if (predict_probability && (svm_type == C_SVC || svm_type == NU_SVC))
399  {
400  label = static_cast<LabelType>(svm_predict_probability(m_Model, x, prob_estimates));
401  }
402  else
403  {
404  label = static_cast<LabelType>(svm_predict(m_Model, x));
405  }
406 
407  // Free allocated memory
408  delete[] x;
409 
410  if (prob_estimates)
411  {
412  delete[] prob_estimates;
413  }
414 
415  return label;
416 }
417 
418 template <class TValue, class TLabel>
421 {
422  // Check if model is up-to-date
423  if (!m_ModelUpToDate)
424  {
425  itkExceptionMacro(<< "Model is not up-to-date, can not predict label");
426  }
427 
428  // Allocate nodes (/TODO if performances problems are related to too
429  // many allocations, a cache approach can be set)
430  struct svm_node * x = new struct svm_node[measure.size() + 1];
431 
432  int valueIndex = 0;
433 
434  // Fill the node
435  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
436  {
437  x[valueIndex].index = valueIndex + 1;
438  x[valueIndex].value = (*mIt);
439  }
440 
441  // terminate node
442  x[measure.size()].index = -1;
443  x[measure.size()].value = 0;
444 
445  // Intialize distances vector
446  DistancesVectorType distances(m_Model->nr_class*(m_Model->nr_class - 1) / 2);
447 
448  // predict distances vector
449  svm_predict_values(m_Model, x, (double*) (distances.GetDataPointer()));
450 
451  // Free allocated memory
452  delete[] x;
453 
454  return (distances);
455 }
456 
457 template <class TValue, class TLabel>
460 {
461  // Check if model is up-to-date
462  if (!m_ModelUpToDate)
463  {
464  itkExceptionMacro(<< "Model is not up-to-date, can not predict probabilities");
465  }
466 
467  if (!this->HasProbabilities())
468  {
469  throw itk::ExceptionObject(__FILE__, __LINE__,
470  "Model does not support probability estimates", ITK_LOCATION);
471  }
472 
473  // Get number of classes
474  int nr_class = svm_get_nr_class(m_Model);
475 
476  // Allocate nodes (/TODO if performances problems are related to too
477  // many allocations, a cache approach can be set)
478  struct svm_node * x = new struct svm_node[measure.size() + 1];
479 
480  int valueIndex = 0;
481 
482  // Fill the node
483  for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
484  {
485  x[valueIndex].index = valueIndex + 1;
486  x[valueIndex].value = (*mIt);
487  }
488 
489  // Termination node
490  x[measure.size()].index = -1;
491  x[measure.size()].value = 0;
492 
493  double* dec_values = new double[nr_class];
494  svm_predict_probability(m_Model, x, dec_values);
495 
496  // Reorder values in increasing class label
497  int* labels = m_Model->label;
498  std::vector<int> orderedLabels(nr_class);
499  std::copy(labels, labels + nr_class, orderedLabels.begin());
500  std::sort(orderedLabels.begin(), orderedLabels.end());
501 
502  ProbabilitiesVectorType probabilities(nr_class);
503  for (int i = 0; i < nr_class; ++i)
504  {
505  // svm_predict_probability is such that "dec_values[i]" corresponds to label "labels[i]"
506  std::vector<int>::iterator it = std::find(orderedLabels.begin(), orderedLabels.end(), labels[i]);
507  probabilities[it - orderedLabels.begin()] = dec_values[i];
508  }
509 
510  // Free allocated memory
511  delete[] x;
512  delete[] dec_values;
513 
514  return probabilities;
515 }
516 
517 
518 template <class TValue, class TLabel>
519 void
520 SVMModel<TValue, TLabel>::SaveModel(const char* model_file_name) const
521 {
522  if (svm_save_model(model_file_name, m_Model) != 0)
523  {
524  itkExceptionMacro(<< "Problem while saving SVM model "
525  << std::string(model_file_name));
526  }
527 }
528 
529 template <class TValue, class TLabel>
530 void
531 SVMModel<TValue, TLabel>::LoadModel(const char* model_file_name)
532 {
533  this->DeleteModel();
534  m_Model = svm_load_model(model_file_name);
535  if (m_Model == 0)
536  {
537  itkExceptionMacro(<< "Problem while loading SVM model "
538  << std::string(model_file_name));
539  }
540  m_Parameters = m_Model->param;
541  m_ModelUpToDate = true;
542 }
543 
544 template <class TValue, class TLabel>
545 void
546 SVMModel<TValue, TLabel>::PrintSelf(std::ostream& os, itk::Indent indent) const
547 {
548  Superclass::PrintSelf(os, indent);
549 }
550 
551 template <class TValue, class TLabel>
552 void
553 SVMModel<TValue, TLabel>::SetSupportVectors(svm_node ** sv, int nbOfSupportVector)
554 {
555  if (!m_Model)
556  {
557  itkExceptionMacro( "Internal SVM model is empty!");
558  }
559 
560  // erase the old SV
561  // delete just the first element, it destoyes the whole pointers (cf SV filling with x_space)
562  delete[] (m_Model->SV[0]);
563 
564  for (int n = 0; n < m_Model->l; ++n)
565  {
566  m_Model->SV[n] = NULL;
567  }
568  delete[] (m_Model->SV);
569  m_Model->SV = NULL;
570 
571  m_Model->SV = new struct svm_node*[m_Model->l];
572 
573  // copy new SV values
574  svm_node **SV = m_Model->SV;
575 
576  // Compute the total number of SV elements.
577  unsigned int elements = 0;
578  for (int p = 0; p < nbOfSupportVector; ++p)
579  {
580  //std::cout << p << " ";
581  const svm_node *tempNode = sv[p];
582  //std::cout << p << " ";
583  while (tempNode->index != -1)
584  {
585  tempNode++;
586  ++elements;
587  }
588  ++elements; // for -1 values
589  }
590 
591  if (m_Model->l > 0)
592  {
593  SV[0] = new struct svm_node[elements];
594  memcpy(SV[0], sv[0], sizeof(svm_node*) * elements);
595  }
596  svm_node *x_space = SV[0];
597 
598  int j = 0;
599  for (int i = 0; i < m_Model->l; ++i)
600  {
601  // SV
602  SV[i] = &x_space[j];
603  const svm_node *p = sv[i];
604  svm_node * pCpy = SV[i];
605  while (p->index != -1)
606  {
607  pCpy->index = p->index;
608  pCpy->value = p->value;
609  ++p;
610  ++pCpy;
611  ++j;
612  }
613  pCpy->index = -1;
614  ++j;
615  }
616 
617  if (m_Model->l > 0)
618  {
619  delete[] SV[0];
620  }
621 }
622 
623 template <class TValue, class TLabel>
624 void
625 SVMModel<TValue, TLabel>::SetAlpha(double ** alpha, int nbOfSupportVector)
626 {
627  if (!m_Model)
628  {
629  itkExceptionMacro( "Internal SVM model is empty!");
630  }
631 
632  // Erase the old sv_coef
633  for (int i = 0; i < m_Model->nr_class - 1; ++i)
634  {
635  delete[] m_Model->sv_coef[i];
636  }
637  delete[] m_Model->sv_coef;
638 
639  // copy new sv_coef values
640  m_Model->sv_coef = new double*[m_Model->nr_class - 1];
641  for (int i = 0; i < m_Model->nr_class - 1; ++i)
642  m_Model->sv_coef[i] = new double[m_Model->l];
643 
644  for (int i = 0; i < m_Model->l; ++i)
645  {
646  // sv_coef
647  for (int k = 0; k < m_Model->nr_class - 1; ++k)
648  {
649  m_Model->sv_coef[k][i] = alpha[k][i];
650  }
651  }
652 }
653 
654 } // end namespace otb
655 
656 #endif
void SaveModel(const char *model_file_name) const
DistancesVectorType EvaluateHyperplanesDistances(const MeasurementType &measure) const
void SetSamples(const SamplesVectorType &samples)
void ClearSamples()
void AddSample(const MeasurementType &measure, const LabelType &label)
void BuildProblem()
void Initialize()
Definition: otbSVMModel.txx:62
void LoadModel(const char *model_file_name)
void ConsistencyCheck()
const TValue * GetDataPointer() const
#define otbMsgDebugMacro(x)
Definition: otbMacro.h:55
void SetAlpha(double **alpha, int nbOfSupportVector)
LabelType EvaluateLabel(const MeasurementType &measure) const
TLabel LabelType
Definition: otbSVMModel.h:74
double CrossValidation(unsigned int nbFolders)
void DeleteProblem()
ProbabilitiesVectorType EvaluateProbabilities(const MeasurementType &measure) const
void PrintSelf(std::ostream &os, itk::Indent indent) const
#define NULL
std::vector< SampleType > SamplesVectorType
Definition: otbSVMModel.h:77
std::vector< ValueType > MeasurementType
Definition: otbSVMModel.h:75
std::pair< MeasurementType, LabelType > SampleType
Definition: otbSVMModel.h:76
void SetSupportVectors(svm_node **sv, int nbOfSupportVector)
#define otbMsgDevMacro(x)
Definition: otbMacro.h:95
void DeleteModel()
virtual ~SVMModel()
Definition: otbSVMModel.txx:55