OTB  9.0.0
Orfeo Toolbox
otbSOMImageClassificationFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSOMImageClassificationFilter_hxx
22 #define otbSOMImageClassificationFilter_hxx
23 
25 #include "itkImageRegionIterator.h"
26 #include "itkNumericTraits.h"
27 
28 namespace otb
29 {
// Constructor: sets the required-input count and initializes m_DefaultLabel.
33 template <class TInputImage, class TOutputImage, class TSOMMap, class TMaskImage>
35 {
// NOTE(review): this call is immediately overridden by the next one (presumably
// last call wins), leaving a single required input: the image to classify.
// The mask (input #1, set through SetInputMask) therefore stays optional, which
// is what GetInputMask's "fewer than 2 inputs" check expects. Looks removable.
36  this->SetNumberOfRequiredInputs(2);
37  this->SetNumberOfRequiredInputs(1);
// Label written to output pixels that are not classified (e.g. pixels whose
// mask value is not > 0); defaults to zero.
38  m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
39 }
41 
// SetInputMask: registers the optional mask image as input #1 of the filter.
// During classification only pixels whose mask value is > 0 are classified.
42 template <class TInputImage, class TOutputImage, class TSOMMap, class TMaskImage>
44 {
// itk::ProcessObject::SetNthInput takes a non-const DataObject pointer, hence
// the const_cast; the filter itself only reads the mask (const iterators).
45  this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType*>(mask));
46 }
47 
// GetInputMask: returns the mask image stored as input #1, or nullptr when the
// filter has fewer than two inputs (i.e. no mask was set).
48 template <class TInputImage, class TOutputImage, class TSOMMap, class TMaskImage>
51 {
52  if (this->GetNumberOfInputs() < 2)
53  {
54  return nullptr;
55  }
// Presumably input #1 is only ever set through SetInputMask, which is why an
// unchecked static_cast is used here.
56  return static_cast<const MaskImageType*>(this->itk::ProcessObject::GetInput(1));
57 }
58 
// BeforeThreadedGenerateData: checks, before the per-thread work starts, that a
// SOM map has been provided.
// Throws an itk::ExceptionObject ("No model for classification") when m_Map is null.
59 template <class TInputImage, class TOutputImage, class TSOMMap, class TMaskImage>
61 {
62  if (!m_Map)
63  {
64  itkGenericExceptionMacro(<< "No model for classification");
65  }
66 }
67 
68 template <class TInputImage, class TOutputImage, class TSOMMap, class TMaskImage>
70  itk::ThreadIdType itkNotUsed(threadId))
71 {
72  InputImageConstPointerType inputPtr = this->GetInput();
73  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
74  OutputImagePointerType outputPtr = this->GetOutput();
75 
76  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
77  typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
78  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
79 
80  ListSamplePointerType listSample = ListSampleType::New();
81  listSample->SetMeasurementVectorSize(inputPtr->GetNumberOfComponentsPerPixel());
82 
83  InputIteratorType inIt(inputPtr, outputRegionForThread);
84 
85  MaskIteratorType maskIt;
86  if (inputMaskPtr)
87  {
88  maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
89  maskIt.GoToBegin();
90  }
91  unsigned int maxDimension = m_Map->GetNumberOfComponentsPerPixel();
92  unsigned int sampleSize = std::min(inputPtr->GetNumberOfComponentsPerPixel(), maxDimension);
93  bool validPoint = true;
94 
95  for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
96  {
97  if (inputMaskPtr)
98  {
99  validPoint = maskIt.Get() > 0;
100  ++maskIt;
101  }
102  if (validPoint)
103  {
104  SampleType sample;
105  sample.SetSize(sampleSize);
106  sample.Fill(itk::NumericTraits<ValueType>::ZeroValue());
107  for (unsigned int i = 0; i < sampleSize; ++i)
108  {
109  sample[i] = inIt.Get()[i];
110  }
111  listSample->PushBack(sample);
112  }
113  }
114  ClassifierPointerType classifier = ClassifierType::New();
115  classifier->SetMap(m_Map);
116  classifier->SetSample(listSample);
117  classifier->Update();
118 
119  typename ClassifierType::OutputType::Pointer membershipSample = classifier->GetOutput();
120  typename ClassifierType::OutputType::ConstIterator sampleIter = membershipSample->Begin();
121  typename ClassifierType::OutputType::ConstIterator sampleLast = membershipSample->End();
122 
123  OutputIteratorType outIt(outputPtr, outputRegionForThread);
124 
125  outIt.GoToBegin();
126 
127  while (!outIt.IsAtEnd() && (sampleIter != sampleLast))
128  {
129  outIt.Set(m_DefaultLabel);
130  ++outIt;
131  }
132 
133  outIt.GoToBegin();
134 
135  if (inputMaskPtr)
136  {
137  maskIt.GoToBegin();
138  }
139  validPoint = true;
140 
141  while (!outIt.IsAtEnd() && (sampleIter != sampleLast))
142  {
143  if (inputMaskPtr)
144  {
145  validPoint = maskIt.Get() > 0;
146  ++maskIt;
147  }
148  if (validPoint)
149  {
150  outIt.Set(sampleIter.GetClassLabel());
151  ++sampleIter;
152  }
153  ++outIt;
154  }
155 }
otb::SOMImageClassificationFilter::MaskImageType
TMaskImage MaskImageType
Definition: otbSOMImageClassificationFilter.h:64
otb::SOMImageClassificationFilter::SOMImageClassificationFilter
SOMImageClassificationFilter()
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::SOMImageClassificationFilter::GetInputMask
const MaskImageType * GetInputMask(void)
otbSOMImageClassificationFilter.h
otb::SOMImageClassificationFilter::SetInputMask
void SetInputMask(const MaskImageType *mask)
otb::sampleAugmentation::SampleType
std::vector< double > SampleType
Definition: otbSampleAugmentation.h:41
otb::SOMImageClassificationFilter::BeforeThreadedGenerateData
void BeforeThreadedGenerateData() override
otb::SOMImageClassificationFilter::ThreadedGenerateData
void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId) override