OTB 6.1.0
Orfeo Toolbox
otbSVMCrossValidationCostFunction.txx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSVMCrossValidationCostFunction_txx
22 #define otbSVMCrossValidationCostFunction_txx
23 
25 #include "otbMacro.h"
26 
27 namespace otb
28 {
29 template<class TModel>
31 ::SVMCrossValidationCostFunction() : m_Model(), m_DerivativeStep(0.001)
32 {}
33 template<class TModel>
36 {}
37 template<class TModel>
39 ::MeasureType
41 ::GetValue(const ParametersType& parameters) const
42 {
43  // Check the input model
44  if (!m_Model)
45  {
46  itkExceptionMacro(<< "Model is null, can not evaluate accuracy.");
47  }
48 
49  // Check for a positive and non-null C
50  if (parameters[0] <= 0)
51  {
52  return 0;
53  }
54 
55  // Updates vm_parameters according to current parameters
56  this->UpdateParameters(parameters);
57 
58  return m_Model->CrossValidation();
59 }
60 
61 template<class TModel>
62 void
64 ::GetDerivative(const ParametersType& parameters, DerivativeType& derivative) const
65 {
66  // Set derivative size
67  derivative.SetSize(parameters.Size());
69 
70  for (unsigned int i = 0; i < parameters.Size(); ++i)
71  {
72  MeasureType y1, y2;
73  ParametersType x1, x2;
74 
75  x1 = parameters;
76  x1[i] -= m_DerivativeStep;
77  y1 = this->GetValue(x1);
78 
79  x2 = parameters;
80  x2[i] += m_DerivativeStep;
81  y2 = this->GetValue(x2);
82 
83  derivative[i] = (y2 - y1) / (2 * m_DerivativeStep);
84  otbMsgDevMacro( << "x1= " << x1 << " x2= " << x2 << ", y1= " << y1 << ", y2= " << y2 );
85  }
86  otbMsgDevMacro( "Position: " << parameters << ", Value: " << this->GetValue(parameters)
87  << ", Derivatives: " << derivative );
88 }
89 
90 template<class TModel>
91 unsigned int
94 {
95  if (!m_Model)
96  {
97  itkExceptionMacro(<< "Model is null, can not evaluate number of parameters.");
98  }
99  return m_Model->GetNumberOfKernelParameters();
100 }
101 
102 template<class TModel>
103 void
105 ::UpdateParameters(const ParametersType& parameters) const
106 {
107  unsigned int nbParams = m_Model->GetNumberOfKernelParameters();
108  m_Model->SetC(parameters[0]);
109  if (nbParams > 1) m_Model->SetKernelGamma(parameters[1]);
110  if (nbParams > 2) m_Model->SetKernelCoef0(parameters[2]);
111 }
112 
113 } // namespace otb
114 
115 #endif
void UpdateParameters(const ParametersType &parameters) const
void Fill(TValue const &v)
void SetSize(SizeValueType sz)
This function returns the cross validation accuracy of a SVM model.
void GetDerivative(const ParametersType &parameters, DerivativeType &derivative) const ITK_OVERRIDE
SizeValueType Size(void) const
unsigned int GetNumberOfParameters(void) const ITK_OVERRIDE
MeasureType GetValue(const ParametersType &parameters) const ITK_OVERRIDE
#define otbMsgDevMacro(x)
Definition: otbMacro.h:98