OTB  5.0.0
Orfeo Toolbox
otbSVMModel.h
Go to the documentation of this file.
1 /*=========================================================================
2 
3  Program: ORFEO Toolbox
4  Language: C++
5  Date: $Date$
6  Version: $Revision$
7 
8 
9  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
10  See OTBCopyright.txt for details.
11 
12 
13  This software is distributed WITHOUT ANY WARRANTY; without even
14  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
15  PURPOSE. See the above copyright notices for more information.
16 
17 =========================================================================*/
18 #ifndef __otbSVMModel_h
19 #define __otbSVMModel_h
20 
#include <string>
#include <utility>
#include <vector>

#include "itkObjectFactory.h"
#include "itkDataObject.h"
#include "itkVariableLengthVector.h"
#include "itkTimeProbe.h"
#include "svm.h"
26 
27 namespace otb
28 {
29 
60 template <class TValue, class TLabel>
61 class ITK_EXPORT SVMModel : public itk::DataObject
62 {
63 public:
65  typedef SVMModel Self;
69 
71  typedef TValue ValueType;
72 
74  typedef TLabel LabelType;
75  typedef std::vector<ValueType> MeasurementType;
76  typedef std::pair<MeasurementType, LabelType> SampleType;
77  typedef std::vector<SampleType> SamplesVectorType;
78 
80  typedef std::vector<struct svm_node *> CacheVectorType;
81 
85 
86  typedef struct svm_node * NodeCacheType;
87 
89  itkNewMacro(Self);
90  itkTypeMacro(SVMModel, itk::DataObject);
92 
94  unsigned int GetNumberOfClasses(void) const
95  {
96  if (m_Model) return (unsigned int) (m_Model->nr_class);
97  return 0;
98  }
100 
102  unsigned int GetNumberOfHyperplane(void) const
103  {
104  if (m_Model) return (unsigned int) (m_Model->nr_class * (m_Model->nr_class - 1) / 2);
105  return 0;
106  }
108 
110  const struct svm_model* GetModel()
111  {
112  return m_Model;
113  }
114 
116  struct svm_parameter& GetParameters()
117  {
118  return m_Parameters;
119  }
120 
122  const struct svm_parameter& GetParameters() const
123  {
124  return m_Parameters;
125  }
126 
128  void SaveModel(const char* model_file_name) const;
129  void SaveModel(const std::string& model_file_name) const
130  {
131  //implemented in term of const char * version
132  this->SaveModel(model_file_name.c_str());
133  }
135 
137  void LoadModel(const char* model_file_name);
138  void LoadModel(const std::string& model_file_name)
139  {
140  //implemented in term of const char * version
141  this->LoadModel(model_file_name.c_str());
142  }
144 
146  void SetSVMType(int svmtype)
147  {
148  m_Parameters.svm_type = svmtype;
149  m_ModelUpToDate = false;
150  this->Modified();
151  }
153 
155  int GetSVMType(void) const
156  {
157  return m_Parameters.svm_type;
158  }
159 
165  void SetKernelType(int kerneltype)
166  {
167  m_Parameters.kernel_type = kerneltype;
168  m_ModelUpToDate = false;
169  this->Modified();
170  }
172 
174  int GetKernelType(void) const
175  {
176  return m_Parameters.kernel_type;
177  }
178 
180  void SetPolynomialKernelDegree(int degree)
181  {
182  m_Parameters.degree = degree;
183  m_ModelUpToDate = false;
184  this->Modified();
185  }
187 
190  {
191  return m_Parameters.degree;
192  }
193 
195  virtual void SetKernelGamma(double gamma)
196  {
197  m_Parameters.gamma = gamma;
198  m_ModelUpToDate = false;
199  this->Modified();
200  }
201 
203  double GetKernelGamma(void) const
204  {
205  return m_Parameters.gamma;
206  }
207 
209  void SetKernelCoef0(double coef0)
210  {
211  m_Parameters.coef0 = coef0;
212  m_ModelUpToDate = false;
213  this->Modified();
214  }
216 
218  double GetKernelCoef0(void) const
219  {
220  //return m_Parameters.coef0;
221  return m_Parameters.coef0;
222  }
223 
225  void SetNu(double nu)
226  {
227  m_Parameters.nu = nu;
228  m_ModelUpToDate = false;
229  this->Modified();
230  }
232 
234  double GetNu(void) const
235  {
236  //return m_Parameters.nu;
237  return m_Parameters.nu;
238  }
239 
241  void SetCacheSize(int cSize)
242  {
243  m_Parameters.cache_size = static_cast<double>(cSize);
244  m_ModelUpToDate = false;
245  this->Modified();
246  }
248 
250  int GetCacheSize(void) const
251  {
252  return static_cast<int>(m_Parameters.cache_size);
253  }
254 
256  void SetC(double c)
257  {
258  m_Parameters.C = c;
259  m_ModelUpToDate = false;
260  this->Modified();
261  }
263 
265  double GetC(void) const
266  {
267  return m_Parameters.C;
268  }
269 
271  void SetEpsilon(double eps)
272  {
273  m_Parameters.eps = eps;
274  m_ModelUpToDate = false;
275  this->Modified();
276  }
278 
280  double GetEpsilon(void) const
281  {
282  return m_Parameters.eps;
283  }
284 
285  /* Set the value of p for EPSILON_SVR */
286  void SetP(double p)
287  {
288  m_Parameters.p = p;
289  m_ModelUpToDate = false;
290  this->Modified();
291  }
292 
293  /* Get the value of p for EPSILON_SVR */
294  double GetP(void) const
295  {
296  return m_Parameters.p;
297  }
298 
300  void DoShrinking(bool s)
301  {
302  m_Parameters.shrinking = static_cast<int>(s);
303  m_ModelUpToDate = false;
304  this->Modified();
305  }
307 
309  bool GetDoShrinking(void) const
310  {
311  return static_cast<bool>(m_Parameters.shrinking);
312  }
313 
315  void DoProbabilityEstimates(bool prob)
316  {
317  m_Parameters.probability = static_cast<int>(prob);
318  m_ModelUpToDate = false;
319  this->Modified();
320  }
322 
324  bool GetDoProbabilityEstimates(void) const
325  {
326  return static_cast<bool>(m_Parameters.probability);
327  }
328 
331  {
332  if (m_Model) return m_Model->l;
333  return 0;
334  }
336 
338  double * GetRho(void) const
339  {
340  if (m_Model) return m_Model->rho;
341  return NULL;
342  }
343 
345  svm_node ** GetSupportVectors(void)
346  {
347  if (m_Model) return m_Model->SV;
348  return NULL;
349  }
350 
352  void SetSupportVectors(svm_node ** sv, int nbOfSupportVector);
353 
355  double ** GetAlpha(void)
356  {
357  if (m_Model) return m_Model->sv_coef;
358  return NULL;
359  }
360 
362  void SetAlpha(double ** alpha, int nbOfSupportVector);
363 
365  int * GetLabels()
366  {
367  if (m_Model) return m_Model->label;
368  return NULL;
369  }
371 
374  {
375  if (m_Model) return m_Model->nSV;
376  return NULL;
377  }
379 
380  struct svm_problem& GetProblem()
381  {
382  return m_Problem;
383  }
384 
386  void BuildProblem();
387 
389  void ConsistencyCheck();
390 
392  void Train();
393 
395  double CrossValidation(unsigned int nbFolders);
396 
401  LabelType EvaluateLabel(const MeasurementType& measure) const;
402 
407  DistancesVectorType EvaluateHyperplanesDistances(const MeasurementType& measure) const;
408 
415  ProbabilitiesVectorType EvaluateProbabilities(const MeasurementType& measure) const;
416 
418  void AddSample(const MeasurementType& measure, const LabelType& label);
419 
421  void ClearSamples();
422 
424  void SetSamples(const SamplesVectorType& samples);
425 
428  void Reset();
429 
430 protected:
432  SVMModel();
433 
435  virtual ~SVMModel();
436 
438  void PrintSelf(std::ostream& os, itk::Indent indent) const;
439 
441  void DeleteProblem();
442 
444  void DeleteModel();
445 
447  void Initialize();
448 
449 private:
450  SVMModel(const Self &); //purposely not implemented
451  void operator =(const Self&); //purposely not implemented
452 
454  struct svm_model* m_Model;
455 
457  mutable bool m_ModelUpToDate;
458 
460  struct svm_problem m_Problem;
461 
463  struct svm_parameter m_Parameters;
464 
467 
470 }; // class SVMModel
471 
472 } // namespace otb
473 
474 #ifndef OTB_MANUAL_INSTANTIATION
475 #include "otbSVMModel.txx"
476 #endif
477 
478 #endif
void SetPolynomialKernelDegree(int degree)
Definition: otbSVMModel.h:180
bool GetDoShrinking(void) const
Definition: otbSVMModel.h:309
itk::SmartPointer< const Self > ConstPointer
Definition: otbSVMModel.h:68
itk::VariableLengthVector< double > ProbabilitiesVectorType
Definition: otbSVMModel.h:83
int GetPolynomialKernelDegree(void) const
Definition: otbSVMModel.h:189
int * GetNumberOfSVPerClasse()
Definition: otbSVMModel.h:373
double GetP(void) const
Definition: otbSVMModel.h:294
int GetSVMType(void) const
Definition: otbSVMModel.h:155
bool GetDoProbabilityEstimates(void) const
Definition: otbSVMModel.h:324
TValue ValueType
Definition: otbSVMModel.h:71
int GetCacheSize(void) const
Definition: otbSVMModel.h:250
void LoadModel(const std::string &model_file_name)
Definition: otbSVMModel.h:138
int * GetLabels()
Definition: otbSVMModel.h:365
double GetEpsilon(void) const
Definition: otbSVMModel.h:280
const struct svm_model * GetModel()
Definition: otbSVMModel.h:110
itk::DataObject Superclass
Definition: otbSVMModel.h:66
unsigned int GetNumberOfHyperplane(void) const
Definition: otbSVMModel.h:102
const struct svm_parameter & GetParameters() const
Definition: otbSVMModel.h:122
itk::SmartPointer< Self > Pointer
Definition: otbSVMModel.h:67
svm_node ** GetSupportVectors(void)
Definition: otbSVMModel.h:345
struct svm_model * m_Model
Definition: otbSVMModel.h:454
void SetKernelCoef0(double coef0)
Definition: otbSVMModel.h:209
void SetCacheSize(int cSize)
Definition: otbSVMModel.h:241
double GetC(void) const
Definition: otbSVMModel.h:265
double GetKernelGamma(void) const
Definition: otbSVMModel.h:203
unsigned int GetNumberOfClasses(void) const
Definition: otbSVMModel.h:94
void SetC(double c)
Definition: otbSVMModel.h:256
virtual void SetKernelGamma(double gamma)
Definition: otbSVMModel.h:195
Class for SVM models.
Definition: otbSVMModel.h:61
bool m_ProblemUpToDate
Definition: otbSVMModel.h:466
TLabel LabelType
Definition: otbSVMModel.h:74
double GetKernelCoef0(void) const
Definition: otbSVMModel.h:218
double ** GetAlpha(void)
Definition: otbSVMModel.h:355
void SetKernelType(int kerneltype)
Definition: otbSVMModel.h:165
void SetSVMType(int svmtype)
Definition: otbSVMModel.h:146
struct svm_problem & GetProblem()
Definition: otbSVMModel.h:380
SamplesVectorType m_Samples
Definition: otbSVMModel.h:469
struct svm_node * NodeCacheType
Definition: otbSVMModel.h:86
void SetEpsilon(double eps)
Definition: otbSVMModel.h:271
void DoProbabilityEstimates(bool prob)
Definition: otbSVMModel.h:315
int GetKernelType(void) const
Definition: otbSVMModel.h:174
itk::VariableLengthVector< double > DistancesVectorType
Definition: otbSVMModel.h:84
void SaveModel(const std::string &model_file_name) const
Definition: otbSVMModel.h:129
double * GetRho(void) const
Definition: otbSVMModel.h:338
struct svm_parameter & GetParameters()
Definition: otbSVMModel.h:116
int GetNumberOfSupportVectors(void) const
Definition: otbSVMModel.h:330
#define NULL
double GetNu(void) const
Definition: otbSVMModel.h:234
SVMModel Self
Definition: otbSVMModel.h:65
std::vector< SampleType > SamplesVectorType
Definition: otbSVMModel.h:77
bool m_ModelUpToDate
Definition: otbSVMModel.h:457
std::vector< struct svm_node * > CacheVectorType
Definition: otbSVMModel.h:80
void SetNu(double nu)
Definition: otbSVMModel.h:225
std::vector< ValueType > MeasurementType
Definition: otbSVMModel.h:75
std::pair< MeasurementType, LabelType > SampleType
Definition: otbSVMModel.h:76
void DoShrinking(bool s)
Definition: otbSVMModel.h:300
void SetP(double p)
Definition: otbSVMModel.h:286