OTB  9.0.0
Orfeo Toolbox
otbNormalBayesMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbNormalBayesMachineLearningModel_hxx
22 #define otbNormalBayesMachineLearningModel_hxx
23 
24 #include <fstream>
25 #include "itkMacro.h"
27 #include "otbOpenCVUtils.h"
28 
29 namespace otb
30 {
31 
32 template <class TInputValue, class TOutputValue>
34  : m_NormalBayesModel(cv::ml::NormalBayesClassifier::create())
35 {
36 }
37 
// Train(): converts the input and target list samples to OpenCV matrices and
// fits m_NormalBayesModel on them (one sample per row, cv::ml::ROW_SAMPLE).
// NOTE(review): the qualified signature line of this definition is not shown in this listing.
39 template <class TInputValue, class TOutputValue>
41 {
42  // convert listsample to opencv matrix
43  cv::Mat samples;
44  otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
46 
 // Responses (class labels) as an OpenCV matrix.
47  cv::Mat labels;
48  otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
49 
 // Variable-type descriptor: one CV_8U entry per input feature plus one extra
 // entry (the last) for the response, which is flagged as categorical.
50  cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U);
51  var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL)); // all inputs are numerical
52  var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
53 
 // No sample index / variable index / sample weights are passed (cv::noArray()).
54  m_NormalBayesModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels, cv::noArray(), cv::noArray(), cv::noArray(), var_type));
55 }
56 
57 template <class TInputValue, class TOutputValue>
60 {
61  TargetSampleType target;
62 
63  // convert listsample to Mat
64  cv::Mat sample;
65 
66  otb::SampleToMat<InputSampleType>(input, sample);
67 
68  cv::Mat missing = cv::Mat(1, input.Size(), CV_8U);
69  missing.setTo(0);
70  double result = m_NormalBayesModel->predict(sample);
71 
72  target[0] = static_cast<TOutputValue>(result);
73 
74  if (quality != nullptr)
75  {
76  if (!this->HasConfidenceIndex())
77  {
78  itkExceptionMacro("Confidence index not available for this classifier !");
79  }
80  }
81  if (proba != nullptr && !this->m_ProbaIndex)
82  itkExceptionMacro("Probability per class not available for this classifier !");
83 
84  return target;
85 }
86 
87 template <class TInputValue, class TOutputValue>
88 void NormalBayesMachineLearningModel<TInputValue, TOutputValue>::Save(const std::string& filename, const std::string& name)
89 {
90  cv::FileStorage fs(filename, cv::FileStorage::WRITE);
91  fs << (name.empty() ? m_NormalBayesModel->getDefaultName() : cv::String(name)) << "{";
92  m_NormalBayesModel->write(fs);
93  fs << "}";
94  fs.release();
95 }
96 
97 template <class TInputValue, class TOutputValue>
98 void NormalBayesMachineLearningModel<TInputValue, TOutputValue>::Load(const std::string& filename, const std::string& name)
99 {
100  cv::FileStorage fs(filename, cv::FileStorage::READ);
101  m_NormalBayesModel->read(name.empty() ? fs.getFirstTopLevelNode() : fs[name]);
102 }
103 
104 template <class TInputValue, class TOutputValue>
106 {
107  std::ifstream ifs;
108  ifs.open(file);
109 
110  if (!ifs)
111  {
112  std::cerr << "Could not read file " << file << std::endl;
113  return false;
114  }
115 
116  while (!ifs.eof())
117  {
118  std::string line;
119  std::getline(ifs, line);
120 
121  if (line.find(CV_TYPE_NAME_ML_NBAYES) != std::string::npos || line.find(m_NormalBayesModel->getDefaultName()) != std::string::npos)
122  {
123  return true;
124  }
125  }
126  ifs.close();
127  return false;
128 }
129 
130 template <class TInputValue, class TOutputValue>
132 {
133  return false;
134 }
135 
136 template <class TInputValue, class TOutputValue>
137 void NormalBayesMachineLearningModel<TInputValue, TOutputValue>::PrintSelf(std::ostream& os, itk::Indent indent) const
138 {
139  // Call superclass implementation
140  Superclass::PrintSelf(os, indent);
141 }
142 
143 } // end namespace otb
144 
145 #endif
otb::NormalBayesMachineLearningModel::Train
void Train() override
Definition: otbNormalBayesMachineLearningModel.hxx:40
otb::NormalBayesMachineLearningModel::CanReadFile
bool CanReadFile(const std::string &) override
Definition: otbNormalBayesMachineLearningModel.hxx:105
CV_VAR_NUMERICAL
#define CV_VAR_NUMERICAL
Definition: otbOpenCVUtils.h:62
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::NormalBayesMachineLearningModel::DoPredict
TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Definition: otbNormalBayesMachineLearningModel.hxx:59
CV_TYPE_NAME_ML_NBAYES
#define CV_TYPE_NAME_ML_NBAYES
Definition: otbOpenCVUtils.h:54
otb::NormalBayesMachineLearningModel::Load
void Load(const std::string &filename, const std::string &name="") override
Definition: otbNormalBayesMachineLearningModel.hxx:98
otb::NormalBayesMachineLearningModel::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbNormalBayesMachineLearningModel.hxx:137
otb::NormalBayesMachineLearningModel::CanWriteFile
bool CanWriteFile(const std::string &) override
Definition: otbNormalBayesMachineLearningModel.hxx:131
otb::NormalBayesMachineLearningModel::Save
void Save(const std::string &filename, const std::string &name="") override
Definition: otbNormalBayesMachineLearningModel.hxx:88
otb::NormalBayesMachineLearningModel::NormalBayesMachineLearningModel
NormalBayesMachineLearningModel()
Definition: otbNormalBayesMachineLearningModel.hxx:33
otb::NormalBayesMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbNormalBayesMachineLearningModel.h:50
CV_VAR_CATEGORICAL
#define CV_VAR_CATEGORICAL
Definition: otbOpenCVUtils.h:63
otbNormalBayesMachineLearningModel.h
otb::NormalBayesMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbNormalBayesMachineLearningModel.h:51
otbOpenCVUtils.h
otb::NormalBayesMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbNormalBayesMachineLearningModel.h:45
otb::NormalBayesMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbNormalBayesMachineLearningModel.h:48