OTB  9.0.0
Orfeo Toolbox
otbKNearestNeighborsMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbKNearestNeighborsMachineLearningModel_hxx
22 #define otbKNearestNeighborsMachineLearningModel_hxx
23 
26 #include "otbOpenCVUtils.h"
27 
28 #include <fstream>
29 #include <set>
30 #include "itkMacro.h"
31 
32 namespace otb
33 {
34 
35 template <class TInputValue, class TTargetValue>
37  :
38  m_KNearestModel(cv::ml::KNearest::create()),
39  m_K(32),
40  m_DecisionRule(KNN_VOTING)
41 {
42  this->m_ConfidenceIndex = true;
43  this->m_IsRegressionSupported = true;
44 }
45 
47 template <class TInputValue, class TTargetValue>
49 {
50  // convert listsample to opencv matrix
51  cv::Mat samples;
52  otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
54 
55  cv::Mat labels;
56  otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);
57 
58  // update decision rule if needed
59  if (this->m_RegressionMode)
60  {
61  if (this->m_DecisionRule == KNN_VOTING)
62  {
63  this->SetDecisionRule(KNN_MEAN);
64  }
65  }
66  else
67  {
68  if (this->m_DecisionRule != KNN_VOTING)
69  {
70  this->SetDecisionRule(KNN_VOTING);
71  }
72  }
73 
74  m_KNearestModel->setDefaultK(m_K);
75  // would be nice to expose KDTree mode ( maybe in a different classifier)
76  m_KNearestModel->setAlgorithmType(cv::ml::KNearest::BRUTE_FORCE);
77  m_KNearestModel->setIsClassifier(!this->m_RegressionMode);
78  // setEmax() ?
79  m_KNearestModel->train(cv::ml::TrainData::create(samples, cv::ml::ROW_SAMPLE, labels));
80 }
81 
82 template <class TInputValue, class TTargetValue>
85  ProbaSampleType* proba) const
86 {
87  TargetSampleType target;
88 
89  // convert listsample to Mat
90  cv::Mat sample;
91  otb::SampleToMat<InputSampleType>(input, sample);
92 
93  float result;
94  cv::Mat nearest(1, m_K, CV_32FC1);
95  result = m_KNearestModel->findNearest(sample, m_K, cv::noArray(), nearest, cv::noArray());
96 
97  // compute quality if asked (only happens in classification mode)
98  if (quality != nullptr)
99  {
100  assert(!this->m_RegressionMode);
101  unsigned int accuracy = 0;
102  for (int k = 0; k < m_K; ++k)
103  {
104  if (nearest.at<float>(0, k) == result)
105  {
106  accuracy++;
107  }
108  }
109  (*quality) = static_cast<ConfidenceValueType>(accuracy);
110  }
111  if (proba != nullptr && !this->m_ProbaIndex)
112  itkExceptionMacro("Probability per class not available for this classifier !");
113 
114  // Decision rule :
115  // VOTING is OpenCV default behaviour for classification
116  // MEAN is OpenCV default behaviour for regression
117  // MEDIAN : only case that must be handled here
118  if (this->m_DecisionRule == KNN_MEDIAN)
119  {
120  std::multiset<float> values;
121  for (int k = 0; k < m_K; ++k)
122  {
123  values.insert(nearest.at<float>(0, k));
124  }
125  std::multiset<float>::iterator median = values.begin();
126  int pos = (m_K >> 1);
127  for (int k = 0; k < pos; ++k, ++median)
128  {
129  }
130  result = *median;
131  }
132 
133  target[0] = static_cast<TTargetValue>(result);
134  return target;
135 }
136 
137 template <class TInputValue, class TTargetValue>
138 void KNearestNeighborsMachineLearningModel<TInputValue, TTargetValue>::Save(const std::string& filename, const std::string& name)
139 {
140  cv::FileStorage fs(filename, cv::FileStorage::WRITE);
141  fs << (name.empty() ? m_KNearestModel->getDefaultName() : cv::String(name)) << "{";
142  m_KNearestModel->write(fs);
143  fs << "DecisionRule" << m_DecisionRule;
144  fs << "}";
145  fs.release();
146 }
147 
148 template <class TInputValue, class TTargetValue>
149 void KNearestNeighborsMachineLearningModel<TInputValue, TTargetValue>::Load(const std::string& filename, const std::string& itkNotUsed(name))
150 {
151  std::ifstream ifs(filename);
152  if (!ifs)
153  {
154  itkExceptionMacro(<< "Could not read file " << filename);
155  }
156  // try to load with the 3.x syntax
157  bool isKNNv3 = false;
158  while (!ifs.eof())
159  {
160  std::string line;
161  std::getline(ifs, line);
162  if (line.find(m_KNearestModel->getDefaultName()) != std::string::npos)
163  {
164  isKNNv3 = true;
165  break;
166  }
167  }
168  ifs.close();
169  if (isKNNv3)
170  {
171  cv::FileStorage fs(filename, cv::FileStorage::READ);
172  m_KNearestModel->read(fs.getFirstTopLevelNode());
173  m_DecisionRule = (int)(fs.getFirstTopLevelNode()["DecisionRule"]);
174  m_K = m_KNearestModel->getDefaultK();
175  return;
176  }
177  ifs.open(filename);
178  // there is no m_KNearestModel->load(filename.c_str(), name.c_str());
179 
180  // first line is the K parameter of this algorithm.
181  std::string line;
182  std::getline(ifs, line);
183  std::istringstream iss(line);
184  if (line.find("K") == std::string::npos)
185  {
186  itkExceptionMacro(<< "Could not read file " << filename);
187  }
188  std::string::size_type pos = line.find_first_of("=", 0);
189  std::string::size_type nextpos = line.find_first_of(" \n\r", pos + 1);
190  this->SetK(boost::lexical_cast<int>(line.substr(pos + 1, nextpos - pos - 1)));
191 
192  // second line is the IsRegression parameter
193  std::getline(ifs, line);
194  if (line.find("IsRegression") == std::string::npos)
195  {
196  itkExceptionMacro(<< "Could not read file " << filename);
197  }
198  pos = line.find_first_of("=", 0);
199  nextpos = line.find_first_of(" \n\r", pos + 1);
200  this->SetRegressionMode(boost::lexical_cast<bool>(line.substr(pos + 1, nextpos - pos - 1)));
201  // third line is the DecisionRule parameter (only for regression)
202  if (this->m_RegressionMode)
203  {
204  std::getline(ifs, line);
205  pos = line.find_first_of("=", 0);
206  nextpos = line.find_first_of(" \n\r", pos + 1);
207  this->SetDecisionRule(boost::lexical_cast<int>(line.substr(pos + 1, nextpos - pos - 1)));
208  }
209  // Clear previous listSample (if any)
210  typename InputListSampleType::Pointer samples = InputListSampleType::New();
211  typename TargetListSampleType::Pointer labels = TargetListSampleType::New();
212 
213  // Read a txt file. First column is the label, other columns are the sample data.
214  unsigned int nbFeatures = 0;
215  while (!ifs.eof())
216  {
217  std::getline(ifs, line);
218 
219  if (nbFeatures == 0)
220  {
221  nbFeatures = std::count(line.begin(), line.end(), ' ');
222  }
223 
224  if (line.size() > 1)
225  {
226  // Parse label
227  pos = line.find_first_of(" ", 0);
228  TargetSampleType label;
229  label[0] = static_cast<TargetValueType>(boost::lexical_cast<unsigned int>(line.substr(0, pos)));
230  // Parse sample features
231  InputSampleType sample(nbFeatures);
232  sample.Fill(0);
233  unsigned int id = 0;
234  nextpos = line.find_first_of(" ", pos + 1);
235  while (nextpos != std::string::npos)
236  {
237  nextpos = line.find_first_of(" \n\r", pos + 1);
238  std::string subline = line.substr(pos + 1, nextpos - pos - 1);
239  // sample[id] = static_cast<InputValueType>(boost::lexical_cast<float>(subline));
240  sample[id] = atof(subline.c_str());
241  pos = nextpos;
242  id++;
243  }
244  samples->SetMeasurementVectorSize(itk::NumericTraits<InputSampleType>::GetLength(sample));
245  samples->PushBack(sample);
246  labels->PushBack(label);
247  }
248  }
249  ifs.close();
250 
251  this->SetInputListSample(samples);
252  this->SetTargetListSample(labels);
253  this->Train();
254 }
255 
256 template <class TInputValue, class TTargetValue>
258 {
259  try
260  {
261  this->Load(file);
262  }
263  catch (...)
264  {
265  return false;
266  }
267  return true;
268 }
269 
270 template <class TInputValue, class TTargetValue>
272 {
273  return false;
274 }
275 
276 
277 template <class TInputValue, class TTargetValue>
279 {
280  // Call superclass implementation
281  Superclass::PrintSelf(os, indent);
282 }
283 
284 } // end namespace otb
285 
286 #endif
otb::KNearestNeighborsMachineLearningModel::TargetValueType
Superclass::TargetValueType TargetValueType
Definition: otbKNearestNeighborsMachineLearningModel.h:47
otb::KNearestNeighborsMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbKNearestNeighborsMachineLearningModel.h:51
otb_boost_lexicalcast_header.h
otb::KNearestNeighborsMachineLearningModel::DoPredict
TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Definition: otbKNearestNeighborsMachineLearningModel.hxx:84
otb::KNearestNeighborsMachineLearningModel::Save
void Save(const std::string &filename, const std::string &name="") override
Definition: otbKNearestNeighborsMachineLearningModel.hxx:138
otb::KNearestNeighborsMachineLearningModel::CanReadFile
bool CanReadFile(const std::string &) override
Definition: otbKNearestNeighborsMachineLearningModel.hxx:257
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::MachineLearningModel< TInputValue, TTargetValue >::m_IsRegressionSupported
bool m_IsRegressionSupported
Definition: otbMachineLearningModel.h:225
otb::KNearestNeighborsMachineLearningModel::CanWriteFile
bool CanWriteFile(const std::string &) override
Definition: otbKNearestNeighborsMachineLearningModel.hxx:271
otb::KNearestNeighborsMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbKNearestNeighborsMachineLearningModel.h:50
otb::KNearestNeighborsMachineLearningModel::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbKNearestNeighborsMachineLearningModel.hxx:278
otb::KNearestNeighborsMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbKNearestNeighborsMachineLearningModel.h:45
otb::median
Definition: otbParserXPlugins.h:324
otb::KNearestNeighborsMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbKNearestNeighborsMachineLearningModel.h:48
otb::KNearestNeighborsMachineLearningModel::KNearestNeighborsMachineLearningModel
KNearestNeighborsMachineLearningModel()
Definition: otbKNearestNeighborsMachineLearningModel.hxx:36
otbOpenCVUtils.h
otb::KNearestNeighborsMachineLearningModel::Load
void Load(const std::string &filename, const std::string &name="") override
Definition: otbKNearestNeighborsMachineLearningModel.hxx:149
otb::KNearestNeighborsMachineLearningModel::Train
void Train() override
Definition: otbKNearestNeighborsMachineLearningModel.hxx:48
otbKNearestNeighborsMachineLearningModel.h
otb::MachineLearningModel< TInputValue, TTargetValue >::m_ConfidenceIndex
bool m_ConfidenceIndex
Definition: otbMachineLearningModel.h:228