OTB  9.0.0
Orfeo Toolbox
otbRandomForestsMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbRandomForestsMachineLearningModel_hxx
22 #define otbRandomForestsMachineLearningModel_hxx
23 
24 #include <fstream>
25 #include "itkMacro.h"
27 #include "otbOpenCVUtils.h"
28 
29 namespace otb
30 {
31 
// Default constructor: creates the wrapped forest through
// CvRTreesWrapper::create() and installs the default hyper-parameters listed
// in the initializer list below.
// NOTE(review): the declarator line (listing line 33) is absent from this
// extracted listing; only the template header and initializer list are visible.
32 template <class TInputValue, class TOutputValue>
34  :
35  m_RFModel(CvRTreesWrapper::create()),
36  m_MaxDepth(5),
37  m_MinSampleCount(10),
38  m_RegressionAccuracy(0.01),
39  m_ComputeSurrogateSplit(false),
40  m_MaxNumberOfCategories(10),
41  m_CalculateVariableImportance(false),
42  m_MaxNumberOfVariables(0),
43  m_MaxNumberOfTrees(100),
44  m_ForestAccuracy(0.01),
45  m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS), // identical for v3? (unverified)
46  m_ComputeMargin(false)
47 {
 // Flags inherited from the base model: a confidence value is offered,
 // per-class probabilities are not, and regression is declared as supported.
48  this->m_ConfidenceIndex = true;
49  this->m_ProbaIndex = false;
50  this->m_IsRegressionSupported = true;
51 }
52 
// Returns the value of m_RFModel->calcError() evaluated on the input and
// target list samples currently attached to this model.
// NOTE(review): the declarator line (listing line 54) is absent from this
// extracted listing.
53 template <class TInputValue, class TOutputValue>
55 {
56  // TODO
 // Convert the input list sample into an OpenCV matrix.
57  cv::Mat samples;
58  otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
59 
 // Convert the target list sample into an OpenCV matrix.
60  cv::Mat labels;
61  otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
62 
 // One entry per input component plus one trailing entry for the response.
63  cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
64  var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL)); // all inputs are numerical
65 
 // The trailing entry is numerical in regression mode, categorical otherwise.
66  if (this->m_RegressionMode)
67  var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_NUMERICAL;
68  else
69  var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
70 
 // NOTE(review): the `false` argument presumably selects the training subset
 // (OpenCV calcError `test` flag) - confirm against the OpenCV documentation.
71  return m_RFModel->calcError(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type), false,
72  cv::noArray());
73 }
74 
// Trains m_RFModel on the input and target list samples after pushing the
// member hyper-parameters into it.
// NOTE(review): the declarator line (listing line 77) is absent from this
// extracted listing.
76 template <class TInputValue, class TOutputValue>
78 {
79  // convert listsample to opencv matrix
80  cv::Mat samples;
81  otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
83 
 // Convert the target list sample into an OpenCV matrix.
84  cv::Mat labels;
85  otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
86 
 // One entry per input component plus one trailing entry for the response.
87  cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
88  var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL)); // all inputs are numerical
89 
 // The trailing entry is numerical in regression mode, categorical otherwise.
90  if (this->m_RegressionMode)
91  var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_NUMERICAL;
92  else
93  var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
94 
95 // Mat var_type = Mat(ATTRIBUTES_PER_SAMPLE + 1, 1, CV_8U );
96 // std::cout << "priors " << m_Priors[0] << std::endl;
97 // Define random forests parameters
98 // FIXME do this in the constructor?
99  m_RFModel->setMaxDepth(m_MaxDepth);
100  m_RFModel->setMinSampleCount(m_MinSampleCount);
101  m_RFModel->setRegressionAccuracy(m_RegressionAccuracy);
102  m_RFModel->setUseSurrogates(m_ComputeSurrogateSplit);
103  m_RFModel->setMaxCategories(m_MaxNumberOfCategories);
104  m_RFModel->setPriors(cv::Mat(m_Priors)); // TODO
105  m_RFModel->setCalculateVarImportance(m_CalculateVariableImportance);
106  m_RFModel->setActiveVarCount(m_MaxNumberOfVariables);
 // The termination criteria combine the criteria type, the maximum number of
 // trees and the forest accuracy held by this object.
107  m_RFModel->setTermCriteria(cv::TermCriteria(m_TerminationCriteria, m_MaxNumberOfTrees, m_ForestAccuracy));
108  m_RFModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
109 }
110 
// Predicts the target of a single sample with m_RFModel.
// When `quality` is non-null it receives either the prediction margin
// (m_ComputeMargin set) or the prediction confidence.
// Raises an ITK exception when `proba` is non-null while m_ProbaIndex is false.
// NOTE(review): the first declarator lines (listing lines 112-113) are absent
// from this extracted listing; only the last parameter line is visible.
111 template <class TInputValue, class TOutputValue>
114  ProbaSampleType* proba) const
115 {
116  // std::cout << "Enter predict" << std::endl;
117  TargetSampleType target;
118  // convert the input sample to Mat
119  cv::Mat sample;
120 
121  otb::SampleToMat<InputSampleType>(value, sample);
122 
123  double result = m_RFModel->predict(sample);
124 
125  target[0] = static_cast<TOutputValue>(result);
126 
 // The quality output is optional; it is only computed when requested.
127  if (quality != nullptr)
128  {
129  if (m_ComputeMargin)
130  (*quality) = m_RFModel->predict_margin(sample);
131  else
132  (*quality) = m_RFModel->predict_confidence(sample);
133  }
134 
 // This check runs after the prediction above has already been computed.
135  if (proba != nullptr && !this->m_ProbaIndex)
136  itkExceptionMacro("Probability per class not available for this classifier !");
137 
138  return target[0];
139 }
140 
141 template <class TInputValue, class TOutputValue>
142 void RandomForestsMachineLearningModel<TInputValue, TOutputValue>::Save(const std::string& filename, const std::string& name)
143 {
144  cv::FileStorage fs(filename, cv::FileStorage::WRITE);
145  fs << (name.empty() ? m_RFModel->getDefaultName() : cv::String(name)) << "{";
146  m_RFModel->write(fs);
147  fs << "}";
148  fs.release();
149 }
150 
151 template <class TInputValue, class TOutputValue>
152 void RandomForestsMachineLearningModel<TInputValue, TOutputValue>::Load(const std::string& filename, const std::string& name)
153 {
154  cv::FileStorage fs(filename, cv::FileStorage::READ);
155  m_RFModel->read(name.empty() ? fs.getFirstTopLevelNode() : fs[name]);
156 }
157 
158 template <class TInputValue, class TOutputValue>
160 {
161  std::ifstream ifs;
162  ifs.open(file);
163 
164  if (!ifs)
165  {
166  std::cerr << "Could not read file " << file << std::endl;
167  return false;
168  }
169 
170 
171  while (!ifs.eof())
172  {
173  std::string line;
174  std::getline(ifs, line);
175 
176  // if (line.find(m_RFModel->getName()) != std::string::npos)
177  if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos || line.find(m_RFModel->getDefaultName()) != std::string::npos)
178  {
179  return true;
180  }
181  }
182  ifs.close();
183  return false;
184 }
185 
// Unconditionally returns false.
// NOTE(review): the declarator line (listing line 187) is absent from this
// extracted listing.
186 template <class TInputValue, class TOutputValue>
188 {
189  return false;
190 }
191 
// Returns a copy of m_RFModel->getVarImportance() as a
// VariableImportanceMatrixType with the same number of rows and columns.
// NOTE(review): the declarator lines (listing lines 193-194) are absent from
// this extracted listing.
192 template <class TInputValue, class TOutputValue>
195 {
196  cv::Mat cvMat = m_RFModel->getVarImportance();
197  VariableImportanceMatrixType itkMat(cvMat.rows, cvMat.cols);
 // Element-wise copy; elements are read as float.
 // NOTE(review): assumes the OpenCV matrix stores 32-bit floats - TODO confirm.
198  for (int i = 0; i < cvMat.rows; i++)
199  {
200  for (int j = 0; j < cvMat.cols; j++)
201  {
202  itkMat(i, j) = cvMat.at<float>(i, j);
203  }
204  }
205  return itkMat;
206 }
207 
208 
209 template <class TInputValue, class TOutputValue>
210 void RandomForestsMachineLearningModel<TInputValue, TOutputValue>::PrintSelf(std::ostream& os, itk::Indent indent) const
211 {
212  // Call superclass implementation
213  Superclass::PrintSelf(os, indent);
214 }
215 
216 } // end namespace otb
217 
218 #endif
CV_TYPE_NAME_ML_RTREES
#define CV_TYPE_NAME_ML_RTREES
Definition: otbOpenCVUtils.h:51
otb::RandomForestsMachineLearningModel::Train
void Train() override
Definition: otbRandomForestsMachineLearningModel.hxx:77
otb::RandomForestsMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbRandomForestsMachineLearningModel.h:52
otb::RandomForestsMachineLearningModel::GetTrainError
float GetTrainError()
Definition: otbRandomForestsMachineLearningModel.hxx:54
otb::MachineLearningModel< TInputValue, TTargetValue >::m_ProbaIndex
bool m_ProbaIndex
Definition: otbMachineLearningModel.h:231
CV_VAR_NUMERICAL
#define CV_VAR_NUMERICAL
Definition: otbOpenCVUtils.h:62
otb::RandomForestsMachineLearningModel::VariableImportanceMatrixType
itk::VariableSizeMatrix< float > VariableImportanceMatrixType
Definition: otbRandomForestsMachineLearningModel.h:54
otb::RandomForestsMachineLearningModel::CanReadFile
bool CanReadFile(const std::string &) override
Definition: otbRandomForestsMachineLearningModel.hxx:159
otbRandomForestsMachineLearningModel.h
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::RandomForestsMachineLearningModel::DoPredict
TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Definition: otbRandomForestsMachineLearningModel.hxx:113
otb::MachineLearningModel< TInputValue, TTargetValue >::m_IsRegressionSupported
bool m_IsRegressionSupported
Definition: otbMachineLearningModel.h:225
otb::RandomForestsMachineLearningModel::CanWriteFile
bool CanWriteFile(const std::string &) override
Definition: otbRandomForestsMachineLearningModel.hxx:187
otb::RandomForestsMachineLearningModel::Save
void Save(const std::string &filename, const std::string &name="") override
Definition: otbRandomForestsMachineLearningModel.hxx:142
otb::CvRTreesWrapper
Wrapper for OpenCV Random Trees.
Definition: otbCvRTreesWrapper.h:35
CV_VAR_CATEGORICAL
#define CV_VAR_CATEGORICAL
Definition: otbOpenCVUtils.h:63
otb::RandomForestsMachineLearningModel::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbRandomForestsMachineLearningModel.hxx:210
otb::RandomForestsMachineLearningModel::RandomForestsMachineLearningModel
RandomForestsMachineLearningModel()
Definition: otbRandomForestsMachineLearningModel.hxx:33
otbOpenCVUtils.h
otb::RandomForestsMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbRandomForestsMachineLearningModel.h:49
otb::RandomForestsMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbRandomForestsMachineLearningModel.h:46
otb::RandomForestsMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbRandomForestsMachineLearningModel.h:51
otb::RandomForestsMachineLearningModel::GetVariableImportance
VariableImportanceMatrixType GetVariableImportance()
Definition: otbRandomForestsMachineLearningModel.hxx:194
otb::MachineLearningModel< TInputValue, TTargetValue >::m_ConfidenceIndex
bool m_ConfidenceIndex
Definition: otbMachineLearningModel.h:228
otb::RandomForestsMachineLearningModel::Load
void Load(const std::string &filename, const std::string &name="") override
Definition: otbRandomForestsMachineLearningModel.hxx:152