OTB  9.0.0
Orfeo Toolbox
otbImageClassificationFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbImageClassificationFilter_hxx
22 #define otbImageClassificationFilter_hxx
23 
25 #include "itkImageRegionIterator.h"
26 #include "itkProgressReporter.h"
27 
28 namespace otb
29 {
// Constructor.
// Inputs: index 0 is the image to classify (the only required input);
// index 1 is the optional mask stored by SetInputMask().
// Outputs: 0 = label image, 1 = confidence map (ConfidenceImageType),
// 2 = probability map (ProbaImageType).
// Defaults: default label 0, confidence and probability maps disabled,
// batch mode enabled, one class.
33 template <class TInputImage, class TOutputImage, class TMaskImage>
35 {
// Image to classify + optional mask
36  this->SetNumberOfIndexedInputs(2);
37  this->SetNumberOfRequiredInputs(1);
// Label written for pixels rejected by the mask
38  m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
40
// Output 0: labels, 1: confidence, 2: class probabilities
41  this->SetNumberOfRequiredOutputs(3);
42  this->SetNthOutput(0, TOutputImage::New());
43  this->SetNthOutput(1, ConfidenceImageType::New());
44  this->SetNthOutput(2, ProbaImageType::New());
45  m_UseConfidenceMap = false;
46  m_UseProbaMap = false;
47  m_BatchMode = true;
48  m_NumberOfClasses = 1;
49 }
50 
// Sets the optional mask as input #1. During generation, only pixels whose
// mask value is > 0 are classified; the others receive the default label.
51 template <class TInputImage, class TOutputImage, class TMaskImage>
53 {
// const_cast: the mask is received as const but is handed to
// ProcessObject::SetNthInput, presumably because that takes a non-const pointer.
54  this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType*>(mask));
55 }
56 
// Returns the optional mask (input #1, as stored by SetInputMask), or nullptr
// when the filter has fewer than two inputs.
57 template <class TInputImage, class TOutputImage, class TMaskImage>
60 {
61  if (this->GetNumberOfInputs() < 2)
62  {
63  return nullptr;
64  }
// Unchecked downcast: assumes input #1 was set through SetInputMask.
65  return static_cast<const MaskImageType*>(this->itk::ProcessObject::GetInput(1));
66 }
67 
// Returns the confidence map (output #1, created in the constructor as a
// ConfidenceImageType), or nullptr when the filter has fewer than two outputs.
68 template <class TInputImage, class TOutputImage, class TMaskImage>
71 {
72  if (this->GetNumberOfOutputs() < 2)
73  {
74  return nullptr;
75  }
76  return static_cast<ConfidenceImageType*>(this->itk::ProcessObject::GetOutput(1));
77 }
78 
79 template <class TInputImage, class TOutputImage, class TMaskImage>
82 {
83  if (this->GetNumberOfOutputs() < 2)
84  {
85  return nullptr;
86  }
87  return static_cast<ProbaImageType*>(this->itk::ProcessObject::GetOutput(2));
88 }
89 
// Runs once before the threaded generation.
// Raises an ITK exception (itkGenericExceptionMacro) when no model is set.
// In batch mode, on builds compiled with OpenMP, restricts the filter to a
// single ITK thread.
90 template <class TInputImage, class TOutputImage, class TMaskImage>
92 {
93  if (!m_Model)
94  {
95  itkGenericExceptionMacro(<< "No model for classification");
96  }
97  if (m_BatchMode)
98  {
99 #ifdef _OPENMP
// Presumably the model's batch prediction is itself OpenMP-parallel, so ITK
// threading is disabled to avoid oversubscription - TODO confirm in the model.
100  // OpenMP will take care of threading
101  this->SetNumberOfThreads(1);
102 #endif
103  }
104 }
105 
106 template <class TInputImage, class TOutputImage, class TMaskImage>
107 void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>::ClassicThreadedGenerateData(const OutputImageRegionType& outputRegionForThread,
108  itk::ThreadIdType threadId)
109 {
110  // Get the input pointers
111  InputImageConstPointerType inputPtr = this->GetInput();
112  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
113  OutputImagePointerType outputPtr = this->GetOutput();
114  ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
115  ProbaImagePointerType probaPtr = this->GetOutputProba();
116  // Progress reporting
117  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
118 
119  // Define iterators
120  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
121  typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
122  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
123  typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
124  typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
125 
126  InputIteratorType inIt(inputPtr, outputRegionForThread);
127  OutputIteratorType outIt(outputPtr, outputRegionForThread);
128 
129  // Eventually iterate on masks
130  MaskIteratorType maskIt;
131  if (inputMaskPtr)
132  {
133  maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
134  maskIt.GoToBegin();
135  }
136 
137  // setup iterator for confidence map
138  bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode());
139  ConfidenceMapIteratorType confidenceIt;
140  if (computeConfidenceMap)
141  {
142  confidenceIt = ConfidenceMapIteratorType(confidencePtr, outputRegionForThread);
143  confidenceIt.GoToBegin();
144  }
145 
146  // setup iterator for proba map
147  bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() && !m_Model->GetRegressionMode());
148 
149  ProbaMapIteratorType probaIt;
150 
151  if (computeProbaMap)
152  {
153  probaIt = ProbaMapIteratorType(probaPtr, outputRegionForThread);
154  probaIt.GoToBegin();
155  }
156 
157  bool validPoint = true;
158  double confidenceIndex = 0.0;
159  ProbaSampleType probaVector{m_NumberOfClasses};
160  probaVector.Fill(0);
161  // Walk the part of the image
162  for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
163  {
164  // Check pixel validity
165  if (inputMaskPtr)
166  {
167  validPoint = maskIt.Get() > 0;
168  ++maskIt;
169  }
170  // If point is valid
171  if (validPoint)
172  {
173  // Classifify
174  if (computeProbaMap)
175  {
176  outIt.Set(m_Model->Predict(inIt.Get(), &confidenceIndex, &probaVector)[0]);
177  }
178  else if (computeConfidenceMap)
179  {
180  outIt.Set(m_Model->Predict(inIt.Get(), &confidenceIndex)[0]);
181  }
182  else
183  {
184  outIt.Set(m_Model->Predict(inIt.Get())[0]);
185  }
186  }
187  else
188  {
189  // else, set default value
190  outIt.Set(m_DefaultLabel);
191  confidenceIndex = 0.0;
192  }
193  if (computeConfidenceMap)
194  {
195  confidenceIt.Set(confidenceIndex);
196  ++confidenceIt;
197  }
198  if (computeProbaMap)
199  {
200  probaIt.Set(probaVector);
201  ++probaIt;
202  }
203  progress.CompletedPixel();
204  }
205 }
206 
207 template <class TInputImage, class TOutputImage, class TMaskImage>
208 void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>::BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread,
209  itk::ThreadIdType threadId)
210 {
211  bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode());
212 
213  bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() && !m_Model->GetRegressionMode());
214  // Get the input pointers
215  InputImageConstPointerType inputPtr = this->GetInput();
216  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
217  OutputImagePointerType outputPtr = this->GetOutput();
218  ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
219  ProbaImagePointerType probaPtr = this->GetOutputProba();
220 
221  // Progress reporting
222  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
223 
224  // Define iterators
225  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
226  typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
227  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
228  typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
229  typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
230 
231  InputIteratorType inIt(inputPtr, outputRegionForThread);
232  OutputIteratorType outIt(outputPtr, outputRegionForThread);
233 
234  MaskIteratorType maskIt;
235  if (inputMaskPtr)
236  {
237  maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
238  maskIt.GoToBegin();
239  }
240 
241  typedef typename ModelType::InputSampleType InputSampleType;
242  typedef typename ModelType::InputListSampleType InputListSampleType;
243  typedef typename ModelType::TargetValueType TargetValueType;
244  typedef typename ModelType::TargetListSampleType TargetListSampleType;
245  typedef typename ModelType::ConfidenceListSampleType ConfidenceListSampleType;
246  typedef typename ModelType::ProbaListSampleType ProbaListSampleType;
247  typename InputListSampleType::Pointer samples = InputListSampleType::New();
248  unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel();
249  samples->SetMeasurementVectorSize(num_features);
250  InputSampleType sample(num_features);
251  // Fill the samples
252  bool validPoint = true;
253  for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
254  {
255  // Check pixel validity
256  if (inputMaskPtr)
257  {
258  validPoint = maskIt.Get() > 0;
259  ++maskIt;
260  }
261  if (validPoint)
262  {
263  typename InputImageType::PixelType pix = inIt.Get();
264  for (size_t feat = 0; feat < num_features; ++feat)
265  {
266  sample[feat] = pix[feat];
267  }
268  samples->PushBack(sample);
269  }
270  }
271  // Make the batch prediction
272  typename TargetListSampleType::Pointer labels;
273  typename ConfidenceListSampleType::Pointer confidences;
274  typename ProbaListSampleType::Pointer probas;
275  if (computeConfidenceMap)
276  confidences = ConfidenceListSampleType::New();
277 
278  if (computeProbaMap)
279  probas = ProbaListSampleType::New();
280  // This call is threadsafe
281  labels = m_Model->PredictBatch(samples, confidences, probas);
282 
283  // Set the output values
284  ConfidenceMapIteratorType confidenceIt;
285  if (computeConfidenceMap)
286  {
287  confidenceIt = ConfidenceMapIteratorType(confidencePtr, outputRegionForThread);
288  confidenceIt.GoToBegin();
289  }
290 
291  ProbaMapIteratorType probaIt;
292  if (computeProbaMap)
293  {
294  probaIt = ProbaMapIteratorType(probaPtr, outputRegionForThread);
295  probaIt.GoToBegin();
296  }
297  typename TargetListSampleType::ConstIterator labIt = labels->Begin();
298  maskIt.GoToBegin();
299  for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
300  {
301  double confidenceIndex = 0.0;
302  TargetValueType labelValue(m_DefaultLabel);
303  ProbaSampleType probaValues{m_NumberOfClasses};
304  if (inputMaskPtr)
305  {
306  validPoint = maskIt.Get() > 0;
307  ++maskIt;
308  }
309  if (validPoint && labIt != labels->End())
310  {
311  labelValue = labIt.GetMeasurementVector()[0];
312 
313  if (computeConfidenceMap)
314  {
315  confidenceIndex = confidences->GetMeasurementVector(labIt.GetInstanceIdentifier())[0];
316  }
317  if (computeProbaMap)
318  {
319  // The probas may have different size than the m_NumberOfClasses set by the user
320  auto tempProbaValues = probas->GetMeasurementVector(labIt.GetInstanceIdentifier());
321  for (unsigned int i = 0; i < m_NumberOfClasses; ++i)
322  {
323  if (i < tempProbaValues.Size())
324  probaValues[i] = tempProbaValues[i];
325  else
326  probaValues[i] = 0;
327  }
328  }
329  ++labIt;
330  }
331  else
332  {
333  labelValue = m_DefaultLabel;
334  }
335 
336  outIt.Set(labelValue);
337 
338  if (computeConfidenceMap)
339  {
340  confidenceIt.Set(confidenceIndex);
341  ++confidenceIt;
342  }
343  if (computeProbaMap)
344  {
345  probaIt.Set(probaValues);
346  ++probaIt;
347  }
348  progress.CompletedPixel();
349  }
350 }
351 template <class TInputImage, class TOutputImage, class TMaskImage>
352 void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread,
353  itk::ThreadIdType threadId)
354 {
355  if (m_BatchMode)
356  {
357  this->BatchThreadedGenerateData(outputRegionForThread, threadId);
358  }
359  else
360  {
361  this->ClassicThreadedGenerateData(outputRegionForThread, threadId);
362  }
363 }
otb::ImageClassificationFilter::ConfidenceImageType
otb::Image< double > ConfidenceImageType
Definition: otbImageClassificationFilter.h:75
otb::ImageClassificationFilter::GetInputMask
const MaskImageType * GetInputMask(void)
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::ImageClassificationFilter::GetOutputConfidence
ConfidenceImageType * GetOutputConfidence(void)
otbImageClassificationFilter.h
otb::ImageClassificationFilter::ProbaImageType
otb::VectorImage< double > ProbaImageType
Definition: otbImageClassificationFilter.h:79
otb::ImageClassificationFilter::BatchThreadedGenerateData
void BatchThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId)
otb::ImageClassificationFilter::ThreadedGenerateData
void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId) override
otb::ImageClassificationFilter::MaskImageType
TMaskImage MaskImageType
Definition: otbImageClassificationFilter.h:63
otb::ImageClassificationFilter::ImageClassificationFilter
ImageClassificationFilter()
otb::ImageClassificationFilter::SetInputMask
void SetInputMask(const MaskImageType *mask)
otb::ImageClassificationFilter::GetOutputProba
ProbaImageType * GetOutputProba(void)
otb::ImageClassificationFilter::ClassicThreadedGenerateData
void ClassicThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId)
otb::ImageClassificationFilter::BeforeThreadedGenerateData
void BeforeThreadedGenerateData() override