OTB  9.0.0
Orfeo Toolbox
otbDecisionTreeMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbDecisionTreeMachineLearningModel_hxx
22 #define otbDecisionTreeMachineLearningModel_hxx
23 
25 #include "otbOpenCVUtils.h"
26 
27 #include <fstream>
28 #include "itkMacro.h"
29 
30 namespace otb
31 {
32 
33 template <class TInputValue, class TOutputValue>
35  :
36  m_DTreeModel(cv::ml::DTrees::create()),
37  m_MaxDepth(10),
38  m_MinSampleCount(10),
39  m_RegressionAccuracy(0.01),
40  m_UseSurrogates(false),
41  m_MaxCategories(10),
42  m_Use1seRule(true),
43  m_TruncatePrunedTree(true)
44 {
45  this->m_IsRegressionSupported = true;
46 }
47 
49 template <class TInputValue, class TOutputValue>
51 {
52  // convert listsample to opencv matrix
53  cv::Mat samples;
54  otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
56 
57  cv::Mat labels;
58  otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
59 
60  cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
61  var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL)); // all inputs are numerical
62 
63  if (!this->m_RegressionMode) // Classification
64  var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
65 
66  m_DTreeModel->setMaxDepth(m_MaxDepth);
67  m_DTreeModel->setMinSampleCount(m_MinSampleCount);
68  m_DTreeModel->setRegressionAccuracy(m_RegressionAccuracy);
69  m_DTreeModel->setUseSurrogates(m_UseSurrogates);
70  // CvFold is not exposed because it crashes in OpenCV 3 and 4
71  m_DTreeModel->setCVFolds(0);
72  m_DTreeModel->setMaxCategories(m_MaxCategories);
73  m_DTreeModel->setUse1SERule(m_Use1seRule);
74  m_DTreeModel->setTruncatePrunedTree(m_TruncatePrunedTree);
75  m_DTreeModel->setPriors(cv::Mat(m_Priors));
76  m_DTreeModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
77 }
78 
79 template <class TInputValue, class TOutputValue>
82 {
83  TargetSampleType target;
84 
85  // convert listsample to Mat
86  cv::Mat sample;
87 
88  otb::SampleToMat<InputSampleType>(input, sample);
89  double result = m_DTreeModel->predict(sample);
90 
91  target[0] = static_cast<TOutputValue>(result);
92 
93  if (quality != nullptr)
94  {
95  if (!this->m_ConfidenceIndex)
96  {
97  itkExceptionMacro("Confidence index not available for this classifier !");
98  }
99  }
100  if (proba != nullptr && !this->m_ProbaIndex)
101  itkExceptionMacro("Probability per class not available for this classifier !");
102 
103  return target;
104 }
105 
106 template <class TInputValue, class TOutputValue>
107 void DecisionTreeMachineLearningModel<TInputValue, TOutputValue>::Save(const std::string& filename, const std::string& name)
108 {
109  cv::FileStorage fs(filename, cv::FileStorage::WRITE);
110  fs << (name.empty() ? m_DTreeModel->getDefaultName() : cv::String(name)) << "{";
111  m_DTreeModel->write(fs);
112  fs << "}";
113  fs.release();
114 }
115 
116 template <class TInputValue, class TOutputValue>
117 void DecisionTreeMachineLearningModel<TInputValue, TOutputValue>::Load(const std::string& filename, const std::string& name)
118 {
119  cv::FileStorage fs(filename, cv::FileStorage::READ);
120  m_DTreeModel->read(name.empty() ? fs.getFirstTopLevelNode() : fs[name]);
121 }
122 
123 template <class TInputValue, class TOutputValue>
125 {
126  std::ifstream ifs;
127  ifs.open(file);
128 
129  if (!ifs)
130  {
131  std::cerr << "Could not read file " << file << std::endl;
132  return false;
133  }
134 
135  while (!ifs.eof())
136  {
137  std::string line;
138  std::getline(ifs, line);
139 
140  // if (line.find(m_SVMModel->getName()) != std::string::npos)
141  if (line.find(CV_TYPE_NAME_ML_TREE) != std::string::npos || line.find(m_DTreeModel->getDefaultName()) != std::string::npos)
142  {
143  return true;
144  }
145  }
146  ifs.close();
147  return false;
148 }
149 
150 template <class TInputValue, class TOutputValue>
152 {
153  return false;
154 }
155 
156 template <class TInputValue, class TOutputValue>
157 void DecisionTreeMachineLearningModel<TInputValue, TOutputValue>::PrintSelf(std::ostream& os, itk::Indent indent) const
158 {
159  // Call superclass implementation
160  Superclass::PrintSelf(os, indent);
161 }
162 
163 } // end namespace otb
164 
165 #endif
CV_VAR_NUMERICAL
#define CV_VAR_NUMERICAL
Definition: otbOpenCVUtils.h:62
CV_TYPE_NAME_ML_TREE
#define CV_TYPE_NAME_ML_TREE
Definition: otbOpenCVUtils.h:55
otb::DecisionTreeMachineLearningModel::CanWriteFile
bool CanWriteFile(const std::string &) override
Definition: otbDecisionTreeMachineLearningModel.hxx:151
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::MachineLearningModel< TInputValue, TTargetValue >::m_IsRegressionSupported
bool m_IsRegressionSupported
Definition: otbMachineLearningModel.h:225
otb::DecisionTreeMachineLearningModel::DoPredict
TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Definition: otbDecisionTreeMachineLearningModel.hxx:81
otb::DecisionTreeMachineLearningModel::CanReadFile
bool CanReadFile(const std::string &) override
Definition: otbDecisionTreeMachineLearningModel.hxx:124
otb::DecisionTreeMachineLearningModel::Save
void Save(const std::string &filename, const std::string &name="") override
Definition: otbDecisionTreeMachineLearningModel.hxx:107
otb::DecisionTreeMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbDecisionTreeMachineLearningModel.h:48
otb::DecisionTreeMachineLearningModel::Load
void Load(const std::string &filename, const std::string &name="") override
Definition: otbDecisionTreeMachineLearningModel.hxx:117
otbDecisionTreeMachineLearningModel.h
otb::DecisionTreeMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbDecisionTreeMachineLearningModel.h:50
CV_VAR_CATEGORICAL
#define CV_VAR_CATEGORICAL
Definition: otbOpenCVUtils.h:63
otb::DecisionTreeMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbDecisionTreeMachineLearningModel.h:51
otb::DecisionTreeMachineLearningModel::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbDecisionTreeMachineLearningModel.hxx:157
otbOpenCVUtils.h
otb::DecisionTreeMachineLearningModel::Train
void Train() override
Definition: otbDecisionTreeMachineLearningModel.hxx:50
otb::DecisionTreeMachineLearningModel::DecisionTreeMachineLearningModel
DecisionTreeMachineLearningModel()
Definition: otbDecisionTreeMachineLearningModel.hxx:34
otb::DecisionTreeMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbDecisionTreeMachineLearningModel.h:45