OTB  6.7.0
Orfeo Toolbox
otbMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbMachineLearningModel_hxx
22 #define otbMachineLearningModel_hxx
23 
24 #ifdef _OPENMP
25  # include <omp.h>
26 #endif
27 
29 
30 #include "itkMultiThreader.h"
31 
32 namespace otb
33 {
34 
// Constructor of the MachineLearningModel base class: sets every boolean
// member listed below to false and m_Dimension to 0.
// NOTE(review): Doxygen listing lines 36-37 were dropped from this extract.
// They presumably held the out-of-line declarator
// `MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>::MachineLearningModel() :`.
// Restore them from the real source before compiling.
35 template <class TInputValue, class TOutputValue, class TConfidenceValue>
38  m_RegressionMode(false),
39  m_IsRegressionSupported(false),
40  m_ConfidenceIndex(false),
41  m_ProbaIndex(false),
42  m_IsDoPredictBatchMultiThreaded(false),
43  m_Dimension(0)
44 {}
45 
46 
// NOTE(review): Doxygen listing lines 48-49 (the declarator) were dropped from
// this extract. Judging by its position right after the constructor and its
// empty body, this is presumably the destructor `~MachineLearningModel()`.
// Confirm against the real source.
47 template <class TInputValue, class TOutputValue, class TConfidenceValue>
50 {}
51 
// Switch the model between classification and regression mode.
// Throws (itkGenericExceptionMacro) when regression is requested (flag true)
// but m_IsRegressionSupported is false. Otherwise stores flag in
// m_RegressionMode and calls Modified() only if the value actually changed.
// NOTE(review): Doxygen listing lines 54-55 were dropped from this extract.
// They presumably held the class qualifier and
// `::SetRegressionMode(bool flag)`. Confirm against the real source.
52 template <class TInputValue, class TOutputValue, class TConfidenceValue>
53 void
56 {
// Refuse to enable regression on a model that does not support it.
57  if (flag && !m_IsRegressionSupported)
58  {
59  itkGenericExceptionMacro(<< "Regression mode not implemented.");
60  }
// Only bump the modification time on a real change of mode.
61  if (m_RegressionMode != flag)
62  {
63  m_RegressionMode = flag;
64  this->Modified();
65  }
66 }
67 
68 template <class TInputValue, class TOutputValue, class TConfidenceValue>
70 ::TargetSampleType
72 ::Predict(const InputSampleType& input, ConfidenceValueType *quality, ProbaSampleType *proba) const
73 {
74  // Call protected specialization entry point
75  return this->DoPredict(input,quality,proba);
76 }
77 
78 
79 template <class TInputValue, class TOutputValue, class TConfidenceValue>
81 ::TargetListSampleType::Pointer
84 {
85  //std::cout << "Enter batch predict" << std::endl;
86  typename TargetListSampleType::Pointer targets = TargetListSampleType::New();
87  targets->Resize(input->Size());
88 
89  if(quality!=nullptr)
90  {
91  quality->Clear();
92  quality->Resize(input->Size());
93  }
94  if(proba!=ITK_NULLPTR)
95  {
96  proba->Clear();
97  proba->Resize(input->Size());
98  }
99  if(m_IsDoPredictBatchMultiThreaded)
100  {
101  // Simply calls DoPredictBatch
102  this->DoPredictBatch(input,0,input->Size(),targets,quality,proba);
103  return targets;
104  }
105  else
106  {
107 #ifdef _OPENMP
108  // OpenMP threading here
109  unsigned int nb_threads(0), threadId(0), nb_batches(0);
110 
111 #pragma omp parallel shared(nb_threads,nb_batches) private(threadId)
112  {
113  // Get number of threads configured with ITK
114  omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
115  nb_threads = omp_get_num_threads();
116  threadId = omp_get_thread_num();
117  nb_batches = std::min(nb_threads,(unsigned int)input->Size());
118  // Ensure that we do not spawn unnecessary threads
119  if(threadId<nb_batches)
120  {
121  unsigned int batch_size = ((unsigned int)input->Size()/nb_batches);
122  unsigned int batch_start = threadId*batch_size;
123  if(threadId == nb_threads-1)
124  {
125  batch_size+=input->Size()%nb_batches;
126  }
127 
128  this->DoPredictBatch(input,batch_start,batch_size,targets,quality,proba);
129  }
130  }
131 #else
132  this->DoPredictBatch(input,0,input->Size(),targets,quality,proba);
133 #endif
134  return targets;
135  }
136 }
137 
138 
139 
140 template <class TInputValue, class TOutputValue, class TConfidenceValue>
141 void
143  ::DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality, ProbaListSampleType * proba) const
144 {
145  assert(input != nullptr);
146  assert(targets != nullptr);
147 
148  assert(input->Size() == targets->Size()
149  && "Input sample list and target label list do not have the same size.");
150  assert(((quality == nullptr) || (quality->Size() == input->Size()))
151  && "Quality samples list is not null and does not have the same size as input samples list");
152  assert(((proba == nullptr) || (input->Size() == proba->Size()))
153  && "Proba sample list and target label list do not have the same size.");
154 
155  if(startIndex+size>input->Size())
156  {
157  itkExceptionMacro(<<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[");
158  }
159 
160  if (proba != nullptr)
161  {
162  for(unsigned int id = startIndex;id<startIndex+size;++id)
163  {
164  ProbaSampleType prob;
165  ConfidenceValueType confidence = 0;
166  const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id),&confidence, &prob);
167  quality->SetMeasurementVector(id,confidence);
168  proba->SetMeasurementVector(id,prob);
169  targets->SetMeasurementVector(id,target);
170  }
171  }
172  else if(quality != ITK_NULLPTR)
173  {
174  for(unsigned int id = startIndex;id<startIndex+size;++id)
175  {
176  ConfidenceValueType confidence = 0;
177  const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id),&confidence);
178  quality->SetMeasurementVector(id,confidence);
179  targets->SetMeasurementVector(id,target);
180  }
181  }
182  else
183  {
184  for(unsigned int id = startIndex;id<startIndex+size;++id)
185  {
186  const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id));
187  targets->SetMeasurementVector(id,target);
188  }
189  }
190 }
191 
192 template <class TInputValue, class TOutputValue, class TConfidenceValue>
193 void
195  ::PrintSelf(std::ostream& os, itk::Indent indent) const
196  {
197  // Call superclass implementation
198  Superclass::PrintSelf(os,indent);
199  }
200 }
201 
202 #endif
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests...
void SetMeasurementVector(InstanceIdentifier id, const MeasurementVectorType &mv)
const MeasurementVectorType & GetMeasurementVector(InstanceIdentifier id) const override
MLMTargetTraits< TConfidenceValue >::ValueType ConfidenceValueType
virtual void DoPredictBatch(const InputListSampleType *input, const unsigned int &startIndex, const unsigned int &size, TargetListSampleType *target, ConfidenceListSampleType *quality=nullptr, ProbaListSampleType *proba=nullptr) const
TargetListSampleType::Pointer PredictBatch(const InputListSampleType *input, ConfidenceListSampleType *quality=nullptr, ProbaListSampleType *proba=nullptr) const
InstanceIdentifier Size() const override
TargetSampleType Predict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const
MLMSampleTraits< TInputValue >::SampleType InputSampleType
void PrintSelf(std::ostream &os, itk::Indent indent) const override
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
void Resize(InstanceIdentifier newsize)