OTB  9.0.0
Orfeo Toolbox
otbMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbMachineLearningModel_hxx
22 #define otbMachineLearningModel_hxx
23 
24 #ifdef _OPENMP
25 #include <omp.h>
26 #endif
27 
29 
30 #include "itkMultiThreader.h"
31 
32 namespace otb
33 {
34 
// Default constructor of MachineLearningModel.
// NOTE(review): the signature line (original listing line 36) is missing from
// this page; the cross-reference index places the definition at line 36.
// Every flag member below starts out false and m_Dimension starts at 0. In
// particular m_IsRegressionSupported is false, so requesting regression mode
// through the setter defined below throws until a model marks it as supported.
35 template <class TInputValue, class TOutputValue, class TConfidenceValue>
37  : m_RegressionMode(false),
38  m_IsRegressionSupported(false),
39  m_ConfidenceIndex(false),
40  m_ProbaIndex(false),
41  m_IsDoPredictBatchMultiThreaded(false),
42  m_Dimension(0)
43 {
44 }
45 
// Setter for the regression/classification mode flag.
// NOTE(review): the signature line (original listing line 47) is missing from
// this page; presumably SetRegressionMode(bool flag) -- confirm against the header.
// Throws (itkGenericExceptionMacro, message "Regression mode not implemented.")
// when regression is requested (flag == true) but m_IsRegressionSupported is
// false. Otherwise stores flag and calls Modified() only if the value changed.
46 template <class TInputValue, class TOutputValue, class TConfidenceValue>
48 {
 // Reject regression mode for models that do not declare support for it.
49  if (flag && !m_IsRegressionSupported)
50  {
51  itkGenericExceptionMacro(<< "Regression mode not implemented.");
52  }
 // Only call Modified() when the mode actually changes.
53  if (m_RegressionMode != flag)
54  {
55  m_RegressionMode = flag;
56  this->Modified();
57  }
58 }
58 }
59 
// Single-sample prediction entry point.
// NOTE(review): the first signature lines (original listing lines 61-62) are
// missing from this page; presumably Predict(input, quality, proba) returning
// a TargetSampleType -- confirm against the header.
// Forwards input and the optional quality / proba output pointers unchanged to
// DoPredict() and returns its result. quality and proba may be null, matching
// how DoPredictBatch calls DoPredict with and without them.
60 template <class TInputValue, class TOutputValue, class TConfidenceValue>
63  ProbaSampleType* proba) const
64 {
65  // Call protected specialization entry point
66  return this->DoPredict(input, quality, proba);
67 }
68 
69 
// Batch prediction over a whole input list.
// NOTE(review): the first signature lines (original listing lines 71-72) are
// missing from this page; only the last parameter line is visible below.
// Allocates and returns a new target list sized like input. If quality / proba
// are non-null they are cleared and resized to input->Size() before being filled.
// - When m_IsDoPredictBatchMultiThreaded is set, DoPredictBatch is called once
//   over the whole range.
// - Otherwise, when built with OpenMP, the input is split into contiguous
//   batches (one per OpenMP thread, never more batches than samples) and each
//   thread runs DoPredictBatch on its own disjoint index range.
// - Without OpenMP, DoPredictBatch is called once over the whole range.
70 template <class TInputValue, class TOutputValue, class TConfidenceValue>
73  ProbaListSampleType* proba) const
74 {
76  typename TargetListSampleType::Pointer targets = TargetListSampleType::New();
77  targets->Resize(input->Size());
78 
79  if (quality != nullptr)
80  {
81  quality->Clear();
82  quality->Resize(input->Size());
83  }
84  if (proba != nullptr)
85  {
86  proba->Clear();
87  proba->Resize(input->Size());
88  }
89  if (m_IsDoPredictBatchMultiThreaded)
90  {
91  // Simply calls DoPredictBatch
92  this->DoPredictBatch(input, 0, input->Size(), targets, quality, proba);
93  return targets;
94  }
95  else
96  {
97 #ifdef _OPENMP
98  // OpenMP threading here
 // Honor the number of threads configured with ITK. omp_set_num_threads()
 // called *inside* a parallel region cannot change that region's team size,
 // so the request is attached to the region itself via num_threads().
 const int requested_threads = std::max(1, static_cast<int>(itk::MultiThreader::GetGlobalDefaultNumberOfThreads()));
 const unsigned int nb_samples = static_cast<unsigned int>(input->Size());
100 
101 #pragma omp parallel num_threads(requested_threads)
102  {
 // Per-thread values, declared inside the region so no thread writes shared state.
105  const unsigned int nb_threads = omp_get_num_threads();
106  const unsigned int threadId = omp_get_thread_num();
107  const unsigned int nb_batches = std::min(nb_threads, nb_samples);
108  // Ensure that we do not spawn unnecessary threads
109  if (threadId < nb_batches)
110  {
111  unsigned int batch_size = nb_samples / nb_batches;
112  unsigned int batch_start = threadId * batch_size;
 // The last batch absorbs the remainder of the integer division.
113  if (threadId == nb_batches - 1)
114  {
115  batch_size += nb_samples % nb_batches;
116  }
117 
118  this->DoPredictBatch(input, batch_start, batch_size, targets, quality, proba);
119  }
120  }
121 #else
122  this->DoPredictBatch(input, 0, input->Size(), targets, quality, proba);
123 #endif
124  return targets;
125  }
126 }
127 
128 
// Predict the samples of input whose indices lie in [startIndex, startIndex + size).
// NOTE(review): the first signature line (original listing line 130) is missing
// from this page; only the trailing parameter lines are visible below.
// Labels are written to targets. Confidences are written to quality and
// probability vectors to proba, each only when that list is non-null. All
// output lists must already be sized like input (asserted in debug builds).
// Only the requested index range is written, which is what lets PredictBatch
// call this on disjoint ranges from several OpenMP threads.
// Throws (itkExceptionMacro) when the requested range exceeds input->Size().
129 template <class TInputValue, class TOutputValue, class TConfidenceValue>
131  const unsigned int& size, TargetListSampleType* targets,
132  ConfidenceListSampleType* quality, ProbaListSampleType* proba) const
133 {
134  assert(input != nullptr);
135  assert(targets != nullptr);
136 
137  assert(input->Size() == targets->Size() && "Input sample list and target label list do not have the same size.");
138  assert(((quality == nullptr) || (quality->Size() == input->Size())) &&
139  "Quality samples list is not null and does not have the same size as input samples list");
140  assert(((proba == nullptr) || (input->Size() == proba->Size())) && "Proba sample list and target label list do not have the same size.");
141 
 // Overflow-safe form of (startIndex + size > input->Size()): the unsigned int
 // sum could wrap around and let an out-of-range request through.
142  if (startIndex > input->Size() || size > input->Size() - startIndex)
143  {
144  itkExceptionMacro(<< "requested range [" << startIndex << ", " << startIndex + size << "[ partially outside input sample list range.[0," << input->Size()
145  << "[");
146  }
147 
148  if (proba != nullptr)
149  {
150  for (unsigned int id = startIndex; id < startIndex + size; ++id)
151  {
152  ProbaSampleType prob;
153  ConfidenceValueType confidence = 0;
154  const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id), &confidence, &prob);
 // quality is optional even when proba is requested: guard it instead of
 // dereferencing a possibly null pointer.
155  if (quality != nullptr)
 {
 quality->SetMeasurementVector(id, confidence);
 }
156  proba->SetMeasurementVector(id, prob);
157  targets->SetMeasurementVector(id, target);
158  }
159  }
160  else if (quality != nullptr)
161  {
162  for (unsigned int id = startIndex; id < startIndex + size; ++id)
163  {
164  ConfidenceValueType confidence = 0;
165  const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id), &confidence);
166  quality->SetMeasurementVector(id, confidence);
167  targets->SetMeasurementVector(id, target);
168  }
169  }
170  else
171  {
172  for (unsigned int id = startIndex; id < startIndex + size; ++id)
173  {
174  const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id));
175  targets->SetMeasurementVector(id, target);
176  }
177  }
178 }
179 
// Print the object state to os.
// NOTE(review): the signature line (original listing line 181) is missing from
// this page; presumably PrintSelf(std::ostream& os, itk::Indent indent) const.
// Delegates entirely to Superclass::PrintSelf. None of this class's own members
// (e.g. m_RegressionMode, m_Dimension) are printed here.
180 template <class TInputValue, class TOutputValue, class TConfidenceValue>
182 {
183  // Call superclass implementation
184  Superclass::PrintSelf(os, indent);
185 }
186 }
187 
188 #endif
otb::MachineLearningModel< TInputValue, TTargetValue >::TargetListSampleType
itk::Statistics::ListSample< TargetSampleType > TargetListSampleType
Definition: otbMachineLearningModel.h:92
otb::MachineLearningModel< TInputValue, TTargetValue >::ProbaListSampleType
itk::Statistics::ListSample< ProbaSampleType > ProbaListSampleType
Definition: otbMachineLearningModel.h:102
otb::MachineLearningModel< TInputValue, TTargetValue >::InputListSampleType
itk::Statistics::ListSample< InputSampleType > InputListSampleType
Definition: otbMachineLearningModel.h:85
otb::MachineLearningModel< TInputValue, TTargetValue >::ConfidenceValueType
MLMTargetTraits< double >::ValueType ConfidenceValueType
Definition: otbMachineLearningModel.h:96
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::MachineLearningModel< TInputValue, TTargetValue >::ProbaSampleType
itk::VariableLengthVector< double > ProbaSampleType
Definition: otbMachineLearningModel.h:101
otb::MachineLearningModel< TInputValue, TTargetValue >::ConfidenceListSampleType
itk::Statistics::ListSample< ConfidenceSampleType > ConfidenceListSampleType
Definition: otbMachineLearningModel.h:98
otb::MachineLearningModel::MachineLearningModel
MachineLearningModel()
Definition: otbMachineLearningModel.hxx:36
otbMachineLearningModel.h
otb::MachineLearningModel
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests,...
Definition: otbMachineLearningModel.h:70
otb::MachineLearningModel< TInputValue, TTargetValue >::InputSampleType
MLMSampleTraits< TInputValue >::SampleType InputSampleType
Definition: otbMachineLearningModel.h:84
otb::MachineLearningModel::TargetSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
Definition: otbMachineLearningModel.h:91