OTB 9.0.0
Orfeo Toolbox
otbSVMCrossValidationCostFunction.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSVMCrossValidationCostFunction_hxx
22 #define otbSVMCrossValidationCostFunction_hxx
23 
25 #include "otbMacro.h"
26 
27 namespace otb
28 {
29 template <class TModel>
31 {
32 }
33 template <class TModel>
35 {
36 }
37 template <class TModel>
39 {
40  // Check the input model
41  if (!m_Model)
42  {
43  itkExceptionMacro(<< "Model is null, can not evaluate accuracy.");
44  }
45 
46  // Check for a positive and non-null C
47  if (parameters[0] <= 0)
48  {
49  return 0;
50  }
51 
52  // Updates vm_parameters according to current parameters
53  this->UpdateParameters(parameters);
54 
55  return m_Model->CrossValidation();
56 }
57 
58 template <class TModel>
60 {
61  // Set derivative size
62  derivative.SetSize(parameters.Size());
63  derivative.Fill(itk::NumericTraits<ParametersValueType>::Zero);
64 
65  for (unsigned int i = 0; i < parameters.Size(); ++i)
66  {
67  MeasureType y1, y2;
68  ParametersType x1, x2;
69 
70  x1 = parameters;
71  x1[i] -= m_DerivativeStep;
72  y1 = this->GetValue(x1);
73 
74  x2 = parameters;
75  x2[i] += m_DerivativeStep;
76  y2 = this->GetValue(x2);
77 
78  derivative[i] = (y2 - y1) / (2 * m_DerivativeStep);
79  otbMsgDevMacro(<< "x1= " << x1 << " x2= " << x2 << ", y1= " << y1 << ", y2= " << y2);
80  }
81  otbMsgDevMacro("Position: " << parameters << ", Value: " << this->GetValue(parameters) << ", Derivatives: " << derivative);
82 }
83 
84 template <class TModel>
86 {
87  if (!m_Model)
88  {
89  itkExceptionMacro(<< "Model is null, can not evaluate number of parameters.");
90  }
91  return m_Model->GetNumberOfKernelParameters();
92 }
93 
94 template <class TModel>
96 {
97  unsigned int nbParams = m_Model->GetNumberOfKernelParameters();
98  m_Model->SetC(parameters[0]);
99  if (nbParams > 1)
100  m_Model->SetKernelGamma(parameters[1]);
101  if (nbParams > 2)
102  m_Model->SetKernelCoef0(parameters[2]);
103 }
104 
105 } // namespace otb
106 
107 #endif
otb::SVMCrossValidationCostFunction::SVMCrossValidationCostFunction
SVMCrossValidationCostFunction()
Constructor.
Definition: otbSVMCrossValidationCostFunction.hxx:30
otb::SVMCrossValidationCostFunction::ParametersType
Superclass::ParametersType ParametersType
Definition: otbSVMCrossValidationCostFunction.h:71
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::SVMCrossValidationCostFunction
This function returns the cross-validation accuracy of an SVM model.
Definition: otbSVMCrossValidationCostFunction.h:53
otb::SVMCrossValidationCostFunction::DerivativeType
Superclass::DerivativeType DerivativeType
Definition: otbSVMCrossValidationCostFunction.h:73
otbMacro.h
otb::SVMCrossValidationCostFunction::UpdateParameters
void UpdateParameters(const ParametersType &parameters) const
Definition: otbSVMCrossValidationCostFunction.hxx:95
otb::SVMCrossValidationCostFunction::GetDerivative
void GetDerivative(const ParametersType &parameters, DerivativeType &derivative) const override
Definition: otbSVMCrossValidationCostFunction.hxx:59
otb::SVMCrossValidationCostFunction::GetNumberOfParameters
unsigned int GetNumberOfParameters(void) const override
Definition: otbSVMCrossValidationCostFunction.hxx:85
otb::SVMCrossValidationCostFunction::GetValue
MeasureType GetValue(const ParametersType &parameters) const override
Definition: otbSVMCrossValidationCostFunction.hxx:38
otbSVMCrossValidationCostFunction.h
otbMsgDevMacro
#define otbMsgDevMacro(x)
Definition: otbMacro.h:64
otb::SVMCrossValidationCostFunction::~SVMCrossValidationCostFunction
~SVMCrossValidationCostFunction() override
Destructor.
Definition: otbSVMCrossValidationCostFunction.hxx:34