OTB  9.0.0
Orfeo Toolbox
otbSVMMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSVMMachineLearningModel_hxx
22 #define otbSVMMachineLearningModel_hxx
23 
#include <cfloat>
#include <fstream>
#include <iostream>
#include "itkMacro.h"
#include "otbSVMMachineLearningModel.h"
#include "otbOpenCVUtils.h"
28 
29 namespace otb
30 {
31 
32 template <class TInputValue, class TOutputValue>
34  :
35  m_SVMModel(cv::ml::SVM::create()),
36  m_SVMType(CvSVM::C_SVC),
37  m_KernelType(CvSVM::RBF),
38  m_Degree(0),
39  m_Gamma(1),
40  m_Coef0(0),
41  m_C(1),
42  m_Nu(0),
43  m_P(0),
44  m_TermCriteriaType(CV_TERMCRIT_ITER),
45  m_MaxIter(1000),
46  m_Epsilon(FLT_EPSILON),
47  m_ParameterOptimization(false),
48  m_OutputDegree(0),
49  m_OutputGamma(1),
50  m_OutputCoef0(0),
51  m_OutputC(1),
52  m_OutputNu(0),
53  m_OutputP(0)
54 {
55  this->m_ConfidenceIndex = true;
56  this->m_IsRegressionSupported = true;
57 }
58 
60 template <class TInputValue, class TOutputValue>
62 {
63  // Check that the SVM type is compatible with the chosen mode (classif/regression)
64  if (bool(m_SVMType == CvSVM::NU_SVR || m_SVMType == CvSVM::EPS_SVR) != this->m_RegressionMode)
65  {
66  itkGenericExceptionMacro(
67  "SVM type incompatible with chosen mode (classification or regression."
68  "SVM types for classification are C_SVC, NU_SVC, ONE_CLASS. "
69  "SVM types for regression are NU_SVR, EPS_SVR");
70  }
71 
72  // convert listsample to opencv matrix
73  cv::Mat samples;
74  otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
75 
76  cv::Mat labels;
77  otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
78 
79  cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
80  var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL)); // all inputs are numerical
81 
82  if (!this->m_RegressionMode) // Classification
83  var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
84 
85  m_SVMModel->setType(m_SVMType);
86  m_SVMModel->setKernel(m_KernelType);
87  m_SVMModel->setDegree(m_Degree);
88  m_SVMModel->setGamma(m_Gamma);
89  m_SVMModel->setCoef0(m_Coef0);
90  m_SVMModel->setC(m_C);
91  m_SVMModel->setNu(m_Nu);
92  m_SVMModel->setP(m_P);
93  m_SVMModel->setTermCriteria(cv::TermCriteria(m_TermCriteriaType, m_MaxIter, m_Epsilon));
94 
95  if (!m_ParameterOptimization)
96  {
97  m_SVMModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
98  }
99  else
100  {
101  m_SVMModel->trainAuto(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
102  }
103 
104  m_OutputDegree = m_SVMModel->getDegree();
105  m_OutputGamma = m_SVMModel->getGamma();
106  m_OutputCoef0 = m_SVMModel->getCoef0();
107  m_OutputC = m_SVMModel->getC();
108  m_OutputNu = m_SVMModel->getNu();
109  m_OutputP = m_SVMModel->getP();
110 }
111 
112 template <class TInputValue, class TOutputValue>
115 {
116  TargetSampleType target;
117  // convert listsample to Mat
118  cv::Mat sample;
119 
120  otb::SampleToMat<InputSampleType>(input, sample);
121 
122  double result = m_SVMModel->predict(sample);
123 
124  target[0] = static_cast<TOutputValue>(result);
125 
126  if (quality != nullptr)
127  {
128  (*quality) = m_SVMModel->predict(sample, cv::noArray(), cv::ml::StatModel::RAW_OUTPUT);
129  }
130  if (proba != nullptr && !this->m_ProbaIndex)
131  itkExceptionMacro("Probability per class not available for this classifier !");
132 
133  return target;
134 }
135 
136 template <class TInputValue, class TOutputValue>
137 void SVMMachineLearningModel<TInputValue, TOutputValue>::Save(const std::string& filename, const std::string& name)
138 {
139  cv::FileStorage fs(filename, cv::FileStorage::WRITE);
140  fs << (name.empty() ? m_SVMModel->getDefaultName() : cv::String(name)) << "{";
141  m_SVMModel->write(fs);
142  fs << "}";
143  fs.release();
144 }
145 
146 template <class TInputValue, class TOutputValue>
147 void SVMMachineLearningModel<TInputValue, TOutputValue>::Load(const std::string& filename, const std::string& name)
148 {
149  cv::FileStorage fs(filename, cv::FileStorage::READ);
150  m_SVMModel->read(name.empty() ? fs.getFirstTopLevelNode() : fs[name]);
151 }
152 
153 template <class TInputValue, class TOutputValue>
155 {
156  std::ifstream ifs;
157  ifs.open(file);
158 
159  if (!ifs)
160  {
161  std::cerr << "Could not read file " << file << std::endl;
162  return false;
163  }
164 
165  while (!ifs.eof())
166  {
167  std::string line;
168  std::getline(ifs, line);
169 
170  // if (line.find(m_SVMModel->getName()) != std::string::npos)
171  if (line.find(CV_TYPE_NAME_ML_SVM) != std::string::npos || line.find(m_SVMModel->getDefaultName()) != std::string::npos)
172  {
173  return true;
174  }
175  }
176  ifs.close();
177  return false;
178 }
179 
180 template <class TInputValue, class TOutputValue>
182 {
183  return false;
184 }
185 
186 template <class TInputValue, class TOutputValue>
187 void SVMMachineLearningModel<TInputValue, TOutputValue>::PrintSelf(std::ostream& os, itk::Indent indent) const
188 {
189  // Call superclass implementation
190  Superclass::PrintSelf(os, indent);
191 }
192 
193 } // end namespace otb
194 
195 #endif
otb::SVMMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbSVMMachineLearningModel.h:59
CV_VAR_NUMERICAL
#define CV_VAR_NUMERICAL
Definition: otbOpenCVUtils.h:62
CvSVM
#define CvSVM
Definition: otbOpenCVUtils.h:57
otbSVMMachineLearningModel.h
otb::SVMMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbSVMMachineLearningModel.h:56
otb::SVMMachineLearningModel::Load
void Load(const std::string &filename, const std::string &name="") override
Definition: otbSVMMachineLearningModel.hxx:147
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::SVMMachineLearningModel::CanReadFile
bool CanReadFile(const std::string &) override
Definition: otbSVMMachineLearningModel.hxx:154
otb::MachineLearningModel< TInputValue, TTargetValue >::m_IsRegressionSupported
bool m_IsRegressionSupported
Definition: otbMachineLearningModel.h:225
otb::SVMMachineLearningModel::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbSVMMachineLearningModel.hxx:187
otb::SVMMachineLearningModel::Save
void Save(const std::string &filename, const std::string &name="") override
Definition: otbSVMMachineLearningModel.hxx:137
otb::SVMMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbSVMMachineLearningModel.h:53
otb::SVMMachineLearningModel::Train
void Train() override
Definition: otbSVMMachineLearningModel.hxx:61
CV_TYPE_NAME_ML_SVM
#define CV_TYPE_NAME_ML_SVM
Definition: otbOpenCVUtils.h:50
otb::SVMMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbSVMMachineLearningModel.h:58
CV_VAR_CATEGORICAL
#define CV_VAR_CATEGORICAL
Definition: otbOpenCVUtils.h:63
otbOpenCVUtils.h
otb::SVMMachineLearningModel::CanWriteFile
bool CanWriteFile(const std::string &) override
Definition: otbSVMMachineLearningModel.hxx:181
otb::MachineLearningModel< TInputValue, TTargetValue >::m_ConfidenceIndex
bool m_ConfidenceIndex
Definition: otbMachineLearningModel.h:228
otb::SVMMachineLearningModel::DoPredict
TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Definition: otbSVMMachineLearningModel.hxx:114
otb::SVMMachineLearningModel::SVMMachineLearningModel
SVMMachineLearningModel()
Definition: otbSVMMachineLearningModel.hxx:33