OTB 9.0.0
Orfeo Toolbox
otbBoostMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbBoostMachineLearningModel_hxx
22 #define otbBoostMachineLearningModel_hxx
23 
25 #include "otbOpenCVUtils.h"
26 
27 #include <fstream>
28 #include "itkMacro.h"
29 
30 namespace otb
31 {
32 
33 template <class TInputValue, class TOutputValue>
35  : m_BoostModel(cv::ml::Boost::create()),
36  m_BoostType(CvBoost::REAL),
37  m_WeakCount(100),
38  m_WeightTrimRate(0.95),
39  m_MaxDepth(1)
40 {
41  this->m_ConfidenceIndex = true;
42 }
43 
45 template <class TInputValue, class TOutputValue>
47 {
48  // convert listsample to opencv matrix
49  cv::Mat samples;
50  otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
52 
53  cv::Mat labels;
54  otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
55 
56  cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
57  var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL)); // all inputs are numerical
58  var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
59 
60  m_BoostModel->setBoostType(m_BoostType);
61  m_BoostModel->setWeakCount(m_WeakCount);
62  m_BoostModel->setWeightTrimRate(m_WeightTrimRate);
63  m_BoostModel->setMaxDepth(m_MaxDepth);
64  m_BoostModel->setUseSurrogates(false);
65  m_BoostModel->setPriors(cv::Mat());
66  m_BoostModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
67 }
68 
69 template <class TInputValue, class TOutputValue>
72 {
73  TargetSampleType target;
74 
75  // convert listsample to Mat
76  cv::Mat sample;
77 
78  otb::SampleToMat<InputSampleType>(input, sample);
79  double result = 0.;
80 
81  result = m_BoostModel->predict(sample);
82 
83  if (quality != nullptr)
84  {
85  (*quality) = static_cast<ConfidenceValueType>(m_BoostModel->predict(sample, cv::noArray(), cv::ml::StatModel::RAW_OUTPUT));
86  }
87  if (proba != nullptr && !this->m_ProbaIndex)
88  itkExceptionMacro("Probability per class not available for this classifier !");
89 
90  target[0] = static_cast<TOutputValue>(result);
91  return target;
92 }
93 
94 template <class TInputValue, class TOutputValue>
95 void BoostMachineLearningModel<TInputValue, TOutputValue>::Save(const std::string& filename, const std::string& name)
96 {
97  cv::FileStorage fs(filename, cv::FileStorage::WRITE);
98  fs << (name.empty() ? m_BoostModel->getDefaultName() : cv::String(name)) << "{";
99  m_BoostModel->write(fs);
100  fs << "}";
101  fs.release();
102 }
103 
104 template <class TInputValue, class TOutputValue>
105 void BoostMachineLearningModel<TInputValue, TOutputValue>::Load(const std::string& filename, const std::string& name)
106 {
107  cv::FileStorage fs(filename, cv::FileStorage::READ);
108  m_BoostModel->read(name.empty() ? fs.getFirstTopLevelNode() : fs[name]);
109 }
110 
111 template <class TInputValue, class TOutputValue>
113 {
114  std::ifstream ifs;
115  ifs.open(file);
116 
117  if (!ifs)
118  {
119  std::cerr << "Could not read file " << file << std::endl;
120  return false;
121  }
122 
123  while (!ifs.eof())
124  {
125  std::string line;
126  std::getline(ifs, line);
127 
128  // if (line.find(m_SVMModel->getName()) != std::string::npos)
129  if (line.find(CV_TYPE_NAME_ML_BOOSTING) != std::string::npos || line.find(m_BoostModel->getDefaultName()) != std::string::npos)
130  {
131  // std::cout<<"Reading a "<<CV_TYPE_NAME_ML_BOOSTING<<" model"<<std::endl;
132  return true;
133  }
134  }
135  ifs.close();
136  return false;
137 }
138 
139 template <class TInputValue, class TOutputValue>
141 {
142  return false;
143 }
144 
145 template <class TInputValue, class TOutputValue>
146 void BoostMachineLearningModel<TInputValue, TOutputValue>::PrintSelf(std::ostream& os, itk::Indent indent) const
147 {
148  // Call superclass implementation
149  Superclass::PrintSelf(os, indent);
150 }
151 
152 } // end namespace otb
153 
154 #endif
otb::BoostMachineLearningModel::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbBoostMachineLearningModel.hxx:146
otb::BoostMachineLearningModel::BoostMachineLearningModel
BoostMachineLearningModel()
Definition: otbBoostMachineLearningModel.hxx:34
otb::BoostMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbBoostMachineLearningModel.h:50
CV_VAR_NUMERICAL
#define CV_VAR_NUMERICAL
Definition: otbOpenCVUtils.h:62
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::BoostMachineLearningModel::DoPredict
TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Definition: otbBoostMachineLearningModel.hxx:71
otb::BoostMachineLearningModel::CanWriteFile
bool CanWriteFile(const std::string &) override
Definition: otbBoostMachineLearningModel.hxx:140
CvBoost
#define CvBoost
Definition: otbOpenCVUtils.h:60
otb::BoostMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbBoostMachineLearningModel.h:45
CV_VAR_CATEGORICAL
#define CV_VAR_CATEGORICAL
Definition: otbOpenCVUtils.h:63
otbOpenCVUtils.h
CV_TYPE_NAME_ML_BOOSTING
#define CV_TYPE_NAME_ML_BOOSTING
Definition: otbOpenCVUtils.h:52
otb::BoostMachineLearningModel::Load
void Load(const std::string &filename, const std::string &name="") override
Definition: otbBoostMachineLearningModel.hxx:105
otb::BoostMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbBoostMachineLearningModel.h:51
otbBoostMachineLearningModel.h
otb::MachineLearningModel< TInputValue, TTargetValue >::m_ConfidenceIndex
bool m_ConfidenceIndex
Definition: otbMachineLearningModel.h:228
otb::BoostMachineLearningModel::CanReadFile
bool CanReadFile(const std::string &) override
Definition: otbBoostMachineLearningModel.hxx:112
otb::BoostMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbBoostMachineLearningModel.h:48
otb::BoostMachineLearningModel::Save
void Save(const std::string &filename, const std::string &name="") override
Definition: otbBoostMachineLearningModel.hxx:95
otb::BoostMachineLearningModel::Train
void Train() override
Definition: otbBoostMachineLearningModel.hxx:46