OTB  6.7.0
Orfeo Toolbox
otbKMeansAttributesLabelMapFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbKMeansAttributesLabelMapFilter_hxx
22 #define otbKMeansAttributesLabelMapFilter_hxx
23 
25 #include "itkNumericTraits.h"
27 
28 namespace otb {
29 
30 template <class TInputImage>
33  : m_LabelMapToSampleListFilter(LabelMapToSampleListFilterType::New()),
34  m_NumberOfClasses(1)
35 {
36 }
37 
38 template<class TInputImage>
39 void
42 {
43  m_LabelMapToSampleListFilter->SetInputLabelMap(m_InputLabelMap);
44  m_LabelMapToSampleListFilter->Update();
45 
46  typename ListSampleType::Pointer listSamples = const_cast<ListSampleType*>(m_LabelMapToSampleListFilter->GetOutputSampleList());
47  typename TrainingListSampleType::Pointer trainingSamples = const_cast<TrainingListSampleType*>(m_LabelMapToSampleListFilter->GetOutputTrainingSampleList());
48 
49  // Build the Kd Tree
50  typename TreeGeneratorType::Pointer kdTreeGenerator = TreeGeneratorType::New();
51  kdTreeGenerator->SetSample(listSamples);
52  kdTreeGenerator->SetBucketSize(100);
53  kdTreeGenerator->Update();
54  // Randomly pick the initial means among the classes
55  unsigned int sampleSize = listSamples->GetMeasurementVector(0).Size();
56  const unsigned int OneClassNbCentroids = 10;
57  unsigned int numberOfCentroids = (m_NumberOfClasses == 1 ? OneClassNbCentroids : m_NumberOfClasses);
58  typename EstimatorType::ParametersType initialMeans(sampleSize * m_NumberOfClasses);
59  initialMeans.Fill(0.);
60 
61  if (m_NumberOfClasses > 1)
62  {
63  // For each class, choose a centroid as the first sample of this class encountered
64  for (ClassLabelType classLabel = 0; classLabel < m_NumberOfClasses; ++classLabel)
65  {
66  typename TrainingListSampleType::ConstIterator it = trainingSamples->Begin();
67  // Iterate on the label list and stop when classLabel is found
68  // TODO: add random initialization ?
69  for (it = trainingSamples->Begin(); it != trainingSamples->End(); ++it)
70  {
71  std::cout <<" Training Samples is "<< it.GetMeasurementVector()[0] << std::endl;
72  if (it.GetMeasurementVector()[0] == classLabel)
73  break;
74  }
75  if (it == trainingSamples->End())
76  {
77  itkExceptionMacro(<<"Unable to find a sample with class label "<< classLabel);
78  }
79 
81  const typename ListSampleType::MeasurementVectorType& centroid = listSamples->GetMeasurementVector(identifier);
82  for (unsigned int i = 0; i < centroid.Size(); ++i)
83  {
84  initialMeans[classLabel * sampleSize + i] = centroid[i];
85  }
86  }
87  }
88  else
89  {
91  RandomGeneratorType::Pointer randomGenerator = RandomGeneratorType::GetInstance();
92  unsigned int nbLabelObjects = listSamples->Size();
93 
94  // Choose arbitrarily OneClassNbCentroids centroids among all available LabelObject
95  for (unsigned int centroidId = 0; centroidId < numberOfCentroids; ++centroidId)
96  {
97  typename ListSampleType::InstanceIdentifier identifier = randomGenerator->GetIntegerVariate(nbLabelObjects - 1);
98  const typename ListSampleType::MeasurementVectorType& centroid = listSamples->GetMeasurementVector(identifier);
99  for (unsigned int i = 0; i < centroid.Size(); ++i)
100  {
101  initialMeans[centroidId * sampleSize + i] = centroid[i];
102  }
103  }
104  }
105 
106  // Run the KMeans algorithm
107  // Do KMeans estimation
108  typename EstimatorType::Pointer estimator = EstimatorType::New();
109  estimator->SetParameters(initialMeans);
110  estimator->SetKdTree(kdTreeGenerator->GetOutput());
111  estimator->SetMaximumIteration(10000);
112  estimator->SetCentroidPositionChangesThreshold(0.00001);
113  estimator->StartOptimization();
114 
115  // Retrieve final centroids
116  m_Centroids.clear();
117 
118  for(unsigned int cId = 0; cId < numberOfCentroids; ++cId)
119  {
120  VectorType newCenter(sampleSize);
121  for(unsigned int i = 0; i < sampleSize; ++i)
122  {
123  newCenter[i] = estimator->GetParameters()[cId * sampleSize + i];
124  }
125  m_Centroids.push_back(newCenter);
126  }
127 }
128 
129 }// end namespace otb
130 #endif
This class converts a LabelObjectMap whose objects carry class labels into a SampleList and a TrainingSampleList.
Superclass::MeasurementVectorType MeasurementVectorType
const MeasurementVectorType & GetMeasurementVector() const
InstanceIdentifier GetInstanceIdentifier() const
void Fill(TValue const &v)
Superclass::InstanceIdentifier InstanceIdentifier