OTB  5.0.0
Orfeo Toolbox
otbSVMImageClassificationFilter.txx
Go to the documentation of this file.
1 /*=========================================================================
2 
3  Program: ORFEO Toolbox
4  Language: C++
5  Date: $Date$
6  Version: $Revision$
7 
8 
9  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
10  See OTBCopyright.txt for details.
11 
12 
13  This software is distributed WITHOUT ANY WARRANTY; without even
14  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
15  PURPOSE. See the above copyright notices for more information.
16 
17 =========================================================================*/
18 #ifndef __otbSVMImageClassificationFilter_txx
19 #define __otbSVMImageClassificationFilter_txx
20 
22 #include "itkImageRegionIterator.h"
23 #include "itkProgressReporter.h"
24 
25 namespace otb
26 {
30 template <class TInputImage, class TOutputImage, class TMaskImage>
33 {
34  this->SetNumberOfRequiredInputs(2);
35  this->SetNumberOfRequiredInputs(1);
37 }
39 
40 template <class TInputImage, class TOutputImage, class TMaskImage>
41 void
43 ::SetInputMask(const MaskImageType * mask)
44 {
45  this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType *>(mask));
46 }
47 
48 template <class TInputImage, class TOutputImage, class TMaskImage>
49 const typename SVMImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
50 ::MaskImageType *
53 {
54  if (this->GetNumberOfInputs() < 2)
55  {
56  return 0;
57  }
58  return static_cast<const MaskImageType *>(this->itk::ProcessObject::GetInput(1));
59 }
60 
61 template <class TInputImage, class TOutputImage, class TMaskImage>
62 void
65 {
66  if (!m_Model)
67  {
68  itkGenericExceptionMacro(<< "No model for classification");
69  }
70 }
71 
72 template <class TInputImage, class TOutputImage, class TMaskImage>
73 void
75 ::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId)
76 {
77  // Get the input pointers
78  InputImageConstPointerType inputPtr = this->GetInput();
79  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
80  OutputImagePointerType outputPtr = this->GetOutput();
81 
82  // Progress reporting
83  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
84 
85  // Define iterators
86  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
87  typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
88  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
89 
90  InputIteratorType inIt(inputPtr, outputRegionForThread);
91  OutputIteratorType outIt(outputPtr, outputRegionForThread);
92 
93  // Eventually iterate on masks
94  MaskIteratorType maskIt;
95  if (inputMaskPtr)
96  {
97  maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
98  maskIt.GoToBegin();
99  }
100 
101  bool validPoint = true;
102 
103  // Walk the part of the image
104  for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
105  {
106  // Check pixel validity
107  if (inputMaskPtr)
108  {
109  validPoint = maskIt.Get() > 0;
110  ++maskIt;
111  }
112  // If point is valid
113  if (validPoint)
114  {
115  // Classifify
116  typename ModelType::MeasurementType measure;
117  for (unsigned int i = 0; i < inIt.Get().Size(); ++i)
118  {
119  measure.push_back(inIt.Get()[i]);
120  }
121  outIt.Set(m_Model->EvaluateLabel(measure));
122  }
123  else
124  {
125  // else, set default value
126  outIt.Set(m_DefaultLabel);
127  }
128  progress.CompletedPixel();
129  }
130 
131 }
const MaskImageType * GetInputMask(void)
static T ZeroValue()
virtual void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId)
virtual void BeforeThreadedGenerateData()
DataObject * GetInput(const DataObjectIdentifierType &key)
void SetInputMask(const MaskImageType *mask)
virtual void SetNthInput(DataObjectPointerArraySizeType num, DataObject *input)
unsigned int ThreadIdType