OTB  9.0.0
Orfeo Toolbox
otbSharkRandomForestsMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSharkRandomForestsMachineLearningModel_hxx
22 #define otbSharkRandomForestsMachineLearningModel_hxx
23 
24 #include <fstream>
25 #include "itkMacro.h"
27 
28 #if defined(__GNUC__) || defined(__clang__)
29 #pragma GCC diagnostic push
30 #pragma GCC diagnostic ignored "-Wshadow"
31 #pragma GCC diagnostic ignored "-Wunused-parameter"
32 #pragma GCC diagnostic ignored "-Woverloaded-virtual"
33 #pragma GCC diagnostic ignored "-Wignored-qualifiers"
34 #endif
35 #if defined(__GNUC__) || defined(__clang__)
36 #pragma GCC diagnostic pop
37 #endif
38 
39 
40 #include "otbSharkUtils.h"
41 #include <algorithm>
42 
43 namespace otb
44 {
45 
46 template <class TInputValue, class TOutputValue>
48 {
49  this->m_ConfidenceIndex = true;
50  this->m_ProbaIndex = true;
51  this->m_IsRegressionSupported = false;
52  this->m_IsDoPredictBatchMultiThreaded = true;
53  this->m_NormalizeClassLabels = true;
54  this->m_ComputeMargin = false;
55 }
56 
58 template <class TInputValue, class TOutputValue>
60 {
61 #ifdef _OPENMP
62  omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
63 #endif
64 
65  std::vector<shark::RealVector> features;
66  std::vector<unsigned int> class_labels;
67 
68  Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
69  Shark::ListSampleToSharkVector(this->GetTargetListSample(), class_labels);
70  if (m_NormalizeClassLabels)
71  {
72  Shark::NormalizeLabelsAndGetDictionary(class_labels, m_ClassDictionary);
73  }
74  shark::ClassificationDataset TrainSamples = shark::createLabeledDataFromRange(features, class_labels);
75 
76  // Set parameters
77  m_RFTrainer.setMTry(m_MTry);
78  m_RFTrainer.setNTrees(m_NumberOfTrees);
79  m_RFTrainer.setNodeSize(m_NodeSize);
80  // m_RFTrainer.setOOBratio(m_OobRatio);
81  m_RFTrainer.train(m_RFModel, TrainSamples);
82 }
83 
84 template <class TInputValue, class TOutputValue>
87 {
88  assert(!probas.empty() && "probas vector is empty");
89  assert((!computeMargin || probas.size() > 1) && "probas size should be at least 2 if computeMargin is true");
90 
91  ConfidenceValueType conf{0};
92  if (computeMargin)
93  {
94  std::nth_element(probas.begin(), probas.begin() + 1, probas.end(), std::greater<double>());
95  conf = static_cast<ConfidenceValueType>(probas[0] - probas[1]);
96  }
97  else
98  {
99  auto max_proba = *(std::max_element(probas.begin(), probas.end()));
100  conf = static_cast<ConfidenceValueType>(max_proba);
101  }
102  return conf;
103 }
104 
105 template <class TInputValue, class TOutputValue>
108  ProbaSampleType* proba) const
109 {
110  shark::RealVector samples(value.Size());
111  for (size_t i = 0; i < value.Size(); i++)
112  {
113  samples.push_back(value[i]);
114  }
115  if (quality != nullptr || proba != nullptr)
116  {
117  shark::RealVector probas = m_RFModel.decisionFunction()(samples);
118  if (quality != nullptr)
119  {
120  (*quality) = ComputeConfidence(probas, m_ComputeMargin);
121  }
122  if (proba != nullptr)
123  {
124  for (size_t i = 0; i < probas.size(); i++)
125  {
126  // probas contain the N class probability indexed between 0 and N-1
127  (*proba)[i] = static_cast<unsigned int>(probas[i] * 1000);
128  }
129  }
130  }
131  unsigned int res{0};
132  m_RFModel.eval(samples, res);
133 
134  TargetSampleType target;
135  if (m_NormalizeClassLabels)
136  {
137  target[0] = m_ClassDictionary[static_cast<TOutputValue>(res)];
138  }
139  else
140  {
141  target[0] = static_cast<TOutputValue>(res);
142  }
143  return target;
144 }
145 
146 template <class TInputValue, class TOutputValue>
148  const unsigned int& size, TargetListSampleType* targets,
149  ConfidenceListSampleType* quality, ProbaListSampleType* proba) const
150 {
151  assert(input != nullptr);
152  assert(targets != nullptr);
153 
154  assert(input->Size() == targets->Size() && "Input sample list and target label list do not have the same size.");
155  assert(((quality == nullptr) || (quality->Size() == input->Size())) &&
156  "Quality samples list is not null and does not have the same size as input samples list");
157  assert(((proba == nullptr) || (input->Size() == proba->Size())) && "Proba sample list and target label list do not have the same size.");
158 
159  if (startIndex + size > input->Size())
160  {
161  itkExceptionMacro(<< "requested range [" << startIndex << ", " << startIndex + size << "[ partially outside input sample list range.[0," << input->Size()
162  << "[");
163  }
164 
165  std::vector<shark::RealVector> features;
166  Shark::ListSampleRangeToSharkVector(input, features, startIndex, size);
167  shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
168 
169 #ifdef _OPENMP
170  omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
171 
172 #endif
173  if (proba != nullptr || quality != nullptr)
174  {
175  shark::Data<shark::RealVector> probas = m_RFModel.decisionFunction()(inputSamples);
176  if (proba != nullptr)
177  {
178  unsigned int id = startIndex;
179  for (shark::RealVector&& p : probas.elements())
180  {
181  ProbaSampleType prob{(unsigned int)p.size()};
182  for (size_t i = 0; i < p.size(); i++)
183  {
184  prob[i] = p[i] * 1000;
185  }
186  proba->SetMeasurementVector(id, prob);
187  ++id;
188  }
189  }
190  if (quality != nullptr)
191  {
192  unsigned int id = startIndex;
193  for (shark::RealVector&& p : probas.elements())
194  {
195  ConfidenceSampleType confidence;
196  auto conf = ComputeConfidence(p, m_ComputeMargin);
197  confidence[0] = static_cast<ConfidenceValueType>(conf);
198  quality->SetMeasurementVector(id, confidence);
199  ++id;
200  }
201  }
202  }
203 
204  auto prediction = m_RFModel(inputSamples);
205  unsigned int id = startIndex;
206  for (const auto& p : prediction.elements())
207  {
208  TargetSampleType target;
209  if (m_NormalizeClassLabels)
210  {
211  target[0] = m_ClassDictionary[static_cast<TOutputValue>(p)];
212  }
213  else
214  {
215  target[0] = static_cast<TOutputValue>(p);
216  }
217  targets->SetMeasurementVector(id, target);
218  ++id;
219  }
220 }
221 
222 template <class TInputValue, class TOutputValue>
223 void SharkRandomForestsMachineLearningModel<TInputValue, TOutputValue>::Save(const std::string& filename, const std::string& itkNotUsed(name))
224 {
225  std::ofstream ofs(filename);
226  if (!ofs)
227  {
228  itkExceptionMacro(<< "Error opening " << filename.c_str());
229  }
230  // Add comment with model file name
231  ofs << "#" << m_RFModel.name();
232  if (m_NormalizeClassLabels)
233  ofs << " with_dictionary";
234  ofs << std::endl;
235  if (m_NormalizeClassLabels)
236  {
237  ofs << m_ClassDictionary.size() << " ";
238  for (const auto& l : m_ClassDictionary)
239  {
240  ofs << l << " ";
241  }
242  ofs << std::endl;
243  }
244  shark::TextOutArchive oa(ofs);
245  m_RFModel.save(oa, 0);
246 }
247 
248 template <class TInputValue, class TOutputValue>
249 void SharkRandomForestsMachineLearningModel<TInputValue, TOutputValue>::Load(const std::string& filename, const std::string& itkNotUsed(name))
250 {
251  std::ifstream ifs(filename);
252  if (ifs.good())
253  {
254  // Check if the first line is a comment and verify the name of the model in this case.
255  std::string line;
256  getline(ifs, line);
257  if (line.at(0) == '#')
258  {
259  if (line.find(m_RFModel.name()) == std::string::npos)
260  itkExceptionMacro("The model file : " + filename + " cannot be read.");
261  if (line.find("with_dictionary") == std::string::npos)
262  {
263  m_NormalizeClassLabels = false;
264  }
265  }
266  else
267  {
268  // rewind if first line is not a comment
269  ifs.clear();
270  ifs.seekg(0, std::ios::beg);
271  }
272  if (m_NormalizeClassLabels)
273  {
274  size_t nbLabels{0};
275  ifs >> nbLabels;
276  m_ClassDictionary.resize(nbLabels);
277  for (size_t i = 0; i < nbLabels; ++i)
278  {
279  unsigned int label;
280  ifs >> label;
281  m_ClassDictionary[i] = label;
282  }
283  }
284  shark::TextInArchive ia(ifs);
285  m_RFModel.load(ia, 0);
286  }
287 }
288 
289 template <class TInputValue, class TOutputValue>
291 {
292  try
293  {
294  this->Load(file);
295  m_RFModel.name();
296  }
297  catch (...)
298  {
299  return false;
300  }
301  return true;
302 }
303 
304 template <class TInputValue, class TOutputValue>
306 {
307  return true;
308 }
309 
310 template <class TInputValue, class TOutputValue>
312 {
313  // Call superclass implementation
314  Superclass::PrintSelf(os, indent);
315 }
316 
317 } // end namespace otb
318 
319 #endif
otb::SharkRandomForestsMachineLearningModel::Train
virtual void Train() override
Definition: otbSharkRandomForestsMachineLearningModel.hxx:59
otb::SharkRandomForestsMachineLearningModel::CanWriteFile
virtual bool CanWriteFile(const std::string &) override
Definition: otbSharkRandomForestsMachineLearningModel.hxx:305
otb::SharkRandomForestsMachineLearningModel::InputListSampleType
Superclass::InputListSampleType InputListSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:88
otb::SharkRandomForestsMachineLearningModel::ProbaListSampleType
Superclass::ProbaListSampleType ProbaListSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:96
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::SharkRandomForestsMachineLearningModel::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbSharkRandomForestsMachineLearningModel.hxx:311
otb::SharkRandomForestsMachineLearningModel::ComputeConfidence
ConfidenceValueType ComputeConfidence(shark::RealVector &probas, bool computeMargin) const
Definition: otbSharkRandomForestsMachineLearningModel.hxx:86
otb::SharkRandomForestsMachineLearningModel::CanReadFile
virtual bool CanReadFile(const std::string &) override
Definition: otbSharkRandomForestsMachineLearningModel.hxx:290
otb::SharkRandomForestsMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbSharkRandomForestsMachineLearningModel.h:92
otb::SharkRandomForestsMachineLearningModel::DoPredictBatch
void DoPredictBatch(const InputListSampleType *, const unsigned int &startIndex, const unsigned int &size, TargetListSampleType *, ConfidenceListSampleType *=nullptr, ProbaListSampleType *=nullptr) const override
Definition: otbSharkRandomForestsMachineLearningModel.hxx:147
otb::SharkRandomForestsMachineLearningModel::DoPredict
TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Definition: otbSharkRandomForestsMachineLearningModel.hxx:107
otb::SharkRandomForestsMachineLearningModel::Load
virtual void Load(const std::string &filename, const std::string &name="") override
Definition: otbSharkRandomForestsMachineLearningModel.hxx:249
otbSharkRandomForestsMachineLearningModel.h
otb::SharkRandomForestsMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:90
otb::SharkRandomForestsMachineLearningModel::SharkRandomForestsMachineLearningModel
SharkRandomForestsMachineLearningModel()
Definition: otbSharkRandomForestsMachineLearningModel.hxx:47
otb::SharkRandomForestsMachineLearningModel::ConfidenceListSampleType
Superclass::ConfidenceListSampleType ConfidenceListSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:94
otb::SharkRandomForestsMachineLearningModel::ConfidenceSampleType
Superclass::ConfidenceSampleType ConfidenceSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:93
otb::SharkRandomForestsMachineLearningModel::TargetListSampleType
Superclass::TargetListSampleType TargetListSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:91
otb::SharkRandomForestsMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:95
otb::SharkRandomForestsMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:87
otb::SharkRandomForestsMachineLearningModel::Save
virtual void Save(const std::string &filename, const std::string &name="") override
Definition: otbSharkRandomForestsMachineLearningModel.hxx:223