OTB  9.0.0
Orfeo Toolbox
otbKMeansAttributesLabelMapFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbKMeansAttributesLabelMapFilter_hxx
22 #define otbKMeansAttributesLabelMapFilter_hxx
23 
25 #include "itkNumericTraits.h"
26 #include "itkMersenneTwisterRandomVariateGenerator.h"
27 
28 namespace otb
29 {
30 
31 template <class TInputImage>
33  : m_LabelMapToSampleListFilter(LabelMapToSampleListFilterType::New()), m_NumberOfClasses(1)
34 {
35 }
36 
37 template <class TInputImage>
39 {
40  m_LabelMapToSampleListFilter->SetInputLabelMap(m_InputLabelMap);
41  m_LabelMapToSampleListFilter->Update();
42 
43  typename ListSampleType::Pointer listSamples = const_cast<ListSampleType*>(m_LabelMapToSampleListFilter->GetOutputSampleList());
44  typename TrainingListSampleType::Pointer trainingSamples = const_cast<TrainingListSampleType*>(m_LabelMapToSampleListFilter->GetOutputTrainingSampleList());
45 
46  // Build the Kd Tree
47  typename TreeGeneratorType::Pointer kdTreeGenerator = TreeGeneratorType::New();
48  kdTreeGenerator->SetSample(listSamples);
49  kdTreeGenerator->SetBucketSize(100);
50  kdTreeGenerator->Update();
51  // Randomly pick the initial means among the classes
52  unsigned int sampleSize = listSamples->GetMeasurementVector(0).Size();
53  const unsigned int OneClassNbCentroids = 10;
54  unsigned int numberOfCentroids = (m_NumberOfClasses == 1 ? OneClassNbCentroids : m_NumberOfClasses);
55  typename EstimatorType::ParametersType initialMeans(sampleSize * m_NumberOfClasses);
56  initialMeans.Fill(0.);
57 
58  if (m_NumberOfClasses > 1)
59  {
60  // For each class, choose a centroid as the first sample of this class encountered
61  for (ClassLabelType classLabel = 0; classLabel < m_NumberOfClasses; ++classLabel)
62  {
63  typename TrainingListSampleType::ConstIterator it = trainingSamples->Begin();
64  // Iterate on the label list and stop when classLabel is found
65  // TODO: add random initialization ?
66  for (it = trainingSamples->Begin(); it != trainingSamples->End(); ++it)
67  {
68  std::cout << " Training Samples is " << it.GetMeasurementVector()[0] << std::endl;
69  if (it.GetMeasurementVector()[0] == classLabel)
70  break;
71  }
72  if (it == trainingSamples->End())
73  {
74  itkExceptionMacro(<< "Unable to find a sample with class label " << classLabel);
75  }
76 
77  typename ListSampleType::InstanceIdentifier identifier = it.GetInstanceIdentifier();
78  const typename ListSampleType::MeasurementVectorType& centroid = listSamples->GetMeasurementVector(identifier);
79  for (unsigned int i = 0; i < centroid.Size(); ++i)
80  {
81  initialMeans[classLabel * sampleSize + i] = centroid[i];
82  }
83  }
84  }
85  else
86  {
87  typedef itk::Statistics::MersenneTwisterRandomVariateGenerator RandomGeneratorType;
88  RandomGeneratorType::Pointer randomGenerator = RandomGeneratorType::GetInstance();
89  unsigned int nbLabelObjects = listSamples->Size();
90 
91  // Choose arbitrarily OneClassNbCentroids centroids among all available LabelObject
92  for (unsigned int centroidId = 0; centroidId < numberOfCentroids; ++centroidId)
93  {
94  typename ListSampleType::InstanceIdentifier identifier = randomGenerator->GetIntegerVariate(nbLabelObjects - 1);
95  const typename ListSampleType::MeasurementVectorType& centroid = listSamples->GetMeasurementVector(identifier);
96  for (unsigned int i = 0; i < centroid.Size(); ++i)
97  {
98  initialMeans[centroidId * sampleSize + i] = centroid[i];
99  }
100  }
101  }
102 
103  // Run the KMeans algorithm
104  // Do KMeans estimation
105  typename EstimatorType::Pointer estimator = EstimatorType::New();
106  estimator->SetParameters(initialMeans);
107  estimator->SetKdTree(kdTreeGenerator->GetOutput());
108  estimator->SetMaximumIteration(10000);
109  estimator->SetCentroidPositionChangesThreshold(0.00001);
110  estimator->StartOptimization();
111 
112  // Retrieve final centroids
113  m_Centroids.clear();
114 
115  for (unsigned int cId = 0; cId < numberOfCentroids; ++cId)
116  {
117  VectorType newCenter(sampleSize);
118  for (unsigned int i = 0; i < sampleSize; ++i)
119  {
120  newCenter[i] = estimator->GetParameters()[cId * sampleSize + i];
121  }
122  m_Centroids.push_back(newCenter);
123  }
124 }
125 
126 } // end namespace otb
127 #endif
otb::KMeansAttributesLabelMapFilter::Compute
void Compute()
Definition: otbKMeansAttributesLabelMapFilter.hxx:38
otb::KMeansAttributesLabelMapFilter::ClassLabelType
LabelObjectType::ClassLabelType ClassLabelType
Definition: otbKMeansAttributesLabelMapFilter.h:60
otb::KMeansAttributesLabelMapFilter::TrainingListSampleType
itk::Statistics::ListSample< ClassLabelVectorType > TrainingListSampleType
Definition: otbKMeansAttributesLabelMapFilter.h:67
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::LabelMapWithClassLabelToLabeledSampleListFilter
This class converts a LabelObjectMap with some class-labeled objects into a SampleList and a TrainingSampleList.
Definition: otbLabelMapWithClassLabelToLabeledSampleListFilter.h:45
otb::KMeansAttributesLabelMapFilter::ListSampleType
itk::Statistics::ListSample< VectorType > ListSampleType
Definition: otbKMeansAttributesLabelMapFilter.h:66
otb::KMeansAttributesLabelMapFilter::KMeansAttributesLabelMapFilter
KMeansAttributesLabelMapFilter()
Definition: otbKMeansAttributesLabelMapFilter.hxx:32
otb::KMeansAttributesLabelMapFilter::VectorType
itk::VariableLengthVector< AttributesValueType > VectorType
Definition: otbKMeansAttributesLabelMapFilter.h:63
otbKMeansAttributesLabelMapFilter.h