OTB  9.0.0
Orfeo Toolbox
otbKMeansImageClassificationFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbKMeansImageClassificationFilter_hxx
22 #define otbKMeansImageClassificationFilter_hxx
23 
25 #include "itkImageRegionIterator.h"
26 #include "itkNumericTraits.h"
27 
28 namespace otb
29 {
33 template <class TInputImage, class TOutputImage, unsigned int VMaxSampleDimension, class TMaskImage>
35 {
36  this->SetNumberOfRequiredInputs(2);
37  this->SetNumberOfRequiredInputs(1);
38  m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
39 }
41 
42 template <class TInputImage, class TOutputImage, unsigned int VMaxSampleDimension, class TMaskImage>
44 {
45  this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType*>(mask));
46 }
47 
48 template <class TInputImage, class TOutputImage, unsigned int VMaxSampleDimension, class TMaskImage>
51 {
52  if (this->GetNumberOfInputs() < 2)
53  {
54  return nullptr;
55  }
56  return static_cast<const MaskImageType*>(this->itk::ProcessObject::GetInput(1));
57 }
58 
59 template <class TInputImage, class TOutputImage, unsigned int VMaxSampleDimension, class TMaskImage>
61 {
62  unsigned int sample_size = MaxSampleDimension;
63  unsigned int nb_classes = m_Centroids.Size() / sample_size;
64 
65  for (LabelType label = 1; label <= static_cast<LabelType>(nb_classes); ++label)
66  {
67  SampleType new_centroid;
68  new_centroid.Fill(0);
69  m_CentroidsMap[label] = new_centroid;
70 
71  for (unsigned int i = 0; i < MaxSampleDimension; ++i)
72  {
73  m_CentroidsMap[label][i] = static_cast<ValueType>(m_Centroids[MaxSampleDimension * (static_cast<unsigned int>(label) - 1) + i]);
74  }
75  }
76 }
77 
78 template <class TInputImage, class TOutputImage, unsigned int VMaxSampleDimension, class TMaskImage>
80  const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType itkNotUsed(threadId))
81 {
82  InputImageConstPointerType inputPtr = this->GetInput();
83  MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
84  OutputImagePointerType outputPtr = this->GetOutput();
85 
86  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
87  typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
88  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
89 
90  InputIteratorType inIt(inputPtr, outputRegionForThread);
91  OutputIteratorType outIt(outputPtr, outputRegionForThread);
92 
93  MaskIteratorType maskIt;
94  if (inputMaskPtr)
95  {
96  maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
97  maskIt.GoToBegin();
98  }
99  unsigned int maxDimension = SampleType::Dimension;
100  unsigned int sampleSize = std::min(inputPtr->GetNumberOfComponentsPerPixel(), maxDimension);
101 
102  bool validPoint = true;
103 
104  while (!outIt.IsAtEnd())
105  {
106  outIt.Set(m_DefaultLabel);
107  ++outIt;
108  }
109 
110  outIt.GoToBegin();
111 
112  validPoint = true;
113 
114  typename DistanceType::Pointer distance = DistanceType::New();
115 
116  while (!outIt.IsAtEnd() && (!inIt.IsAtEnd()))
117  {
118  if (inputMaskPtr)
119  {
120  validPoint = maskIt.Get() > 0;
121  ++maskIt;
122  }
123  if (validPoint)
124  {
125  LabelType label = 1;
126  LabelType current_label = 1;
127  SampleType pixel;
128  pixel.Fill(0);
129  for (unsigned int i = 0; i < sampleSize; ++i)
130  {
131  pixel[i] = inIt.Get()[i];
132  }
133 
134  double current_distance = distance->Evaluate(pixel, m_CentroidsMap[label]);
135 
136  for (label = 2; label <= static_cast<LabelType>(m_CentroidsMap.size()); ++label)
137  {
138  double tmp_dist = distance->Evaluate(pixel, m_CentroidsMap[label]);
139  if (tmp_dist < current_distance)
140  {
141  current_label = label;
142  current_distance = tmp_dist;
143  }
144  }
145  outIt.Set(current_label);
146  }
147  ++outIt;
148  ++inIt;
149  }
150 }
otb::KMeansImageClassificationFilter::SetInputMask
void SetInputMask(const MaskImageType *mask)
otb::KMeansImageClassificationFilter::ThreadedGenerateData
void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId) override
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::KMeansImageClassificationFilter::BeforeThreadedGenerateData
void BeforeThreadedGenerateData() override
otb::KMeansImageClassificationFilter::KMeansImageClassificationFilter
KMeansImageClassificationFilter()
otb::sampleAugmentation::SampleType
std::vector< double > SampleType
Definition: otbSampleAugmentation.h:41
otbKMeansImageClassificationFilter.h
otb::KMeansImageClassificationFilter::MaskImageType
TMaskImage MaskImageType
Definition: otbKMeansImageClassificationFilter.h:70
otb::KMeansImageClassificationFilter::GetInputMask
const MaskImageType * GetInputMask(void)