OTB  6.7.0
Orfeo Toolbox
otbBoostMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbBoostMachineLearningModel_hxx
22 #define otbBoostMachineLearningModel_hxx
23 
25 #include "otbOpenCVUtils.h"
26 
27 #include <fstream>
28 #include "itkMacro.h"
29 
30 namespace otb
31 {
32 
33 template <class TInputValue, class TOutputValue>
36 #ifdef OTB_OPENCV_3
37  m_BoostModel(cv::ml::Boost::create()),
38 #else
39  m_BoostModel (new CvBoost),
40 #endif
41  m_BoostType(CvBoost::REAL),
42  m_WeakCount(100),
43  m_WeightTrimRate(0.95),
44 #ifdef OTB_OPENCV_3
45  m_SplitCrit(0), // not used in OpenCV 3.x
46 #else
47  m_SplitCrit(CvBoost::DEFAULT),
48 #endif
49  m_MaxDepth(1)
50 {
51  this->m_ConfidenceIndex = true;
52 }
53 
54 
55 template <class TInputValue, class TOutputValue>
58 {
59 #ifndef OTB_OPENCV_3
60  delete m_BoostModel;
61 #endif
62 }
63 
65 template <class TInputValue, class TOutputValue>
66 void
69 {
70  //convert listsample to opencv matrix
71  cv::Mat samples;
72  otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
74 
75  cv::Mat labels;
76  otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(),labels);
77 
78  cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U );
79  var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical
80  var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
81 
82 #ifdef OTB_OPENCV_3
83  m_BoostModel->setBoostType(m_BoostType);
84  m_BoostModel->setWeakCount(m_WeakCount);
85  m_BoostModel->setWeightTrimRate(m_WeightTrimRate);
86  m_BoostModel->setMaxDepth(m_MaxDepth);
87  m_BoostModel->setUseSurrogates(false);
88  m_BoostModel->setPriors(cv::Mat());
89  m_BoostModel->train(cv::ml::TrainData::create(
90  samples,
91  cv::ml::ROW_SAMPLE,
92  labels,
93  cv::noArray(),
94  cv::noArray(),
95  cv::noArray(),
96  var_type));
97 #else
98  CvBoostParams params = CvBoostParams(m_BoostType, m_WeakCount, m_WeightTrimRate, m_MaxDepth, false, nullptr);
99  params.split_criteria = m_SplitCrit;
100  m_BoostModel->train(samples,CV_ROW_SAMPLE,labels,cv::Mat(),cv::Mat(),var_type,cv::Mat(),params);
101 #endif
102 }
103 
104 template <class TInputValue, class TOutputValue>
106 ::TargetSampleType
108 ::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const
109 {
110  TargetSampleType target;
111 
112  //convert listsample to Mat
113  cv::Mat sample;
114 
115  otb::SampleToMat<InputSampleType>(input,sample);
116  double result = 0.;
117 
118 #ifdef OTB_OPENCV_3
119  result = m_BoostModel->predict(sample);
120 #else
121  cv::Mat missing = cv::Mat(1,input.Size(), CV_8U );
122  missing.setTo(0);
123  result = m_BoostModel->predict(sample,missing);
124 #endif
125 
126  if (quality != nullptr)
127  {
128  (*quality) = static_cast<ConfidenceValueType>(
129 #ifdef OTB_OPENCV_3
130  m_BoostModel->predict(sample,cv::noArray(), cv::ml::StatModel::RAW_OUTPUT)
131 #else
132  m_BoostModel->predict(sample,missing,cv::Range::all(),false,true)
133 #endif
134  );
135  }
136  if (proba != nullptr && !this->m_ProbaIndex)
137  itkExceptionMacro("Probability per class not available for this classifier !");
138 
139  target[0] = static_cast<TOutputValue>(result);
140  return target;
141 }
142 
143 template <class TInputValue, class TOutputValue>
144 void
146 ::Save(const std::string & filename, const std::string & name)
147 {
148 #ifdef OTB_OPENCV_3
149  cv::FileStorage fs(filename, cv::FileStorage::WRITE);
150  fs << (name.empty() ? m_BoostModel->getDefaultName() : cv::String(name)) << "{";
151  m_BoostModel->write(fs);
152  fs << "}";
153  fs.release();
154 #else
155  if (name == "")
156  m_BoostModel->save(filename.c_str(), nullptr);
157  else
158  m_BoostModel->save(filename.c_str(), name.c_str());
159 #endif
160 }
161 
162 template <class TInputValue, class TOutputValue>
163 void
165 ::Load(const std::string & filename, const std::string & name)
166 {
167 #ifdef OTB_OPENCV_3
168  cv::FileStorage fs(filename, cv::FileStorage::READ);
169  m_BoostModel->read(name.empty() ? fs.getFirstTopLevelNode() : fs[name]);
170 #else
171  if (name == "")
172  m_BoostModel->load(filename.c_str(), nullptr);
173  else
174  m_BoostModel->load(filename.c_str(), name.c_str());
175 #endif
176 }
177 
178 template <class TInputValue, class TOutputValue>
179 bool
181 ::CanReadFile(const std::string & file)
182 {
183  std::ifstream ifs;
184  ifs.open(file);
185 
186  if(!ifs)
187  {
188  std::cerr<<"Could not read file "<<file<<std::endl;
189  return false;
190  }
191 
192  while (!ifs.eof())
193  {
194  std::string line;
195  std::getline(ifs, line);
196 
197  //if (line.find(m_SVMModel->getName()) != std::string::npos)
198  if (line.find(CV_TYPE_NAME_ML_BOOSTING) != std::string::npos
199 #ifdef OTB_OPENCV_3
200  || line.find(m_BoostModel->getDefaultName()) != std::string::npos
201 #endif
202  )
203  {
204  //std::cout<<"Reading a "<<CV_TYPE_NAME_ML_BOOSTING<<" model"<<std::endl;
205  return true;
206  }
207  }
208  ifs.close();
209  return false;
210 }
211 
212 template <class TInputValue, class TOutputValue>
213 bool
215 ::CanWriteFile(const std::string & itkNotUsed(file))
216 {
217  return false;
218 }
219 
220 template <class TInputValue, class TOutputValue>
221 void
223 ::PrintSelf(std::ostream& os, itk::Indent indent) const
224 {
225  // Call superclass implementation
226  Superclass::PrintSelf(os,indent);
227 }
228 
229 } //end namespace otb
230 
231 #endif
Superclass::ConfidenceValueType ConfidenceValueType
Superclass::InputSampleType InputSampleType
bool CanWriteFile(const std::string &) override
bool CanReadFile(const std::string &) override
TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Superclass::TargetSampleType TargetSampleType
void Load(const std::string &filename, const std::string &name="") override
void Save(const std::string &filename, const std::string &name="") override
Superclass::ProbaSampleType ProbaSampleType
void PrintSelf(std::ostream &os, itk::Indent indent) const override