OTB  9.0.0
Orfeo Toolbox
otbTrainSharkKMeans.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbTrainSharkKMeans_hxx
21 #define otbTrainSharkKMeans_hxx
22 
26 
27 namespace otb
28 {
29 namespace Wrapper
30 {
31 template <class TInputValue, class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
33 {
34  AddChoice("classifier.sharkkm", "Shark kmeans classifier");
35  SetParameterDescription("classifier.sharkkm", "http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html ");
36 
37  // MaxNumberOfIterations
38  AddParameter(ParameterType_Int, "classifier.sharkkm.maxiter", "Maximum number of iterations for the kmeans algorithm");
39  SetParameterInt("classifier.sharkkm.maxiter", 10);
40  SetMinimumParameterIntValue("classifier.sharkkm.maxiter", 0);
41  SetParameterDescription("classifier.sharkkm.maxiter", "The maximum number of iterations for the kmeans algorithm. 0=unlimited");
42 
43  // Number of classes
44  AddParameter(ParameterType_Int, "classifier.sharkkm.k", "Number of classes for the kmeans algorithm");
45  SetParameterInt("classifier.sharkkm.k", 2);
46  SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class");
47  SetMinimumParameterIntValue("classifier.sharkkm.k", 2);
48 
49  // Input centroids
50  AddParameter(ParameterType_InputFilename, "classifier.sharkkm.incentroids", "User defined input centroids");
51  SetParameterDescription("classifier.sharkkm.incentroids",
52  "Input text file containing centroid posistions used to initialize the algorithm. "
53  "Each centroid must be described by p parameters, p being the number of features in "
54  "the input vector data, and the number of centroids must be equal to the number of classes "
55  "(one centroid per line with values separated by spaces).");
56  MandatoryOff("classifier.sharkkm.incentroids");
57 
58  // Centroid statistics
59  AddParameter(ParameterType_InputFilename, "classifier.sharkkm.cstats", "Statistics file");
60  SetParameterDescription("classifier.sharkkm.cstats",
61  "A XML file containing mean and standard deviation to center"
62  "and reduce the input centroids before the KMeans algorithm, produced by ComputeImagesStatistics application.");
63  MandatoryOff("classifier.sharkkm.cstats");
64 
65  // Output centroids
66  AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.outcentroids", "Output centroids text file");
67  SetParameterDescription("classifier.sharkkm.outcentroids", "Output text file containing centroids after the kmean algorithm.");
68  MandatoryOff("classifier.sharkkm.outcentroids");
69 }
70 
// Trains a Shark k-means model from the application parameters and saves it.
//
// Steps, in order:
//   1. Read "classifier.sharkkm.maxiter" and "classifier.sharkkm.k".
//   2. Create the classifier and give it the samples, the targets and k.
//   3. If an input centroid file is enabled and set, load it, optionally
//      normalize it with the mean/stddev statistics file, and pass the
//      centroids to the classifier.
//   4. Train, then save the model to modelPath.
//   5. If "classifier.sharkkm.outcentroids" has a value, export the centroids.
//
// Parameters:
//   trainingListSample        - forwarded to SetInputListSample().
//   trainingLabeledListSample - forwarded to SetTargetListSample().
//   modelPath                 - forwarded to Save().
//
// NOTE(review): this listing omits its lines 79 and 93, so the declarations of
// SharkKMeansType and statisticsReader are not visible here.
71 template <class TInputValue, class TOutputValue>
72 void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(typename ListSampleType::Pointer trainingListSample,
73  typename TargetListSampleType::Pointer trainingLabeledListSample,
74  std::string modelPath)
75 {
// Both parameters are read as int; abs() is applied before the cast to
// unsigned int.
76  unsigned int nbMaxIter = static_cast<unsigned int>(abs(GetParameterInt("classifier.sharkkm.maxiter")));
77  unsigned int k = static_cast<unsigned int>(abs(GetParameterInt("classifier.sharkkm.k")));
78 
80  typename SharkKMeansType::Pointer classifier = SharkKMeansType::New();
81  classifier->SetRegressionMode(this->m_RegressionFlag);
82  classifier->SetInputListSample(trainingListSample);
83  classifier->SetTargetListSample(trainingLabeledListSample);
84  classifier->SetK(k);
85 
86  // Initialize centroids from file
87  if (IsParameterEnabled("classifier.sharkkm.incentroids") && HasValue("classifier.sharkkm.incentroids"))
88  {
// The centroid file is parsed with a single space as the field separator.
89  shark::Data<shark::RealVector> centroidData;
90  shark::importCSV(centroidData, GetParameterString("classifier.sharkkm.incentroids"), ' ');
// Normalization is applied only when a statistics file is provided.
91  if (HasValue("classifier.sharkkm.cstats"))
92  {
94  statisticsReader->SetFileName(GetParameterString("classifier.sharkkm.cstats"));
95  auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
96  auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
97 
98  // Convert itk Variable Length Vector to shark Real Vector
99  shark::RealVector offsetRV(meanMeasurementVector.Size());
100  shark::RealVector scaleRV(stddevMeasurementVector.Size());
101 
// The size check is an assert that runs after both vectors are allocated;
// it is compiled out when NDEBUG is defined.
102  assert(meanMeasurementVector.Size() == stddevMeasurementVector.Size());
103  for (unsigned int i = 0; i < meanMeasurementVector.Size(); ++i)
104  {
// scale = 1 / stddev and offset = -mean / stddev.
// NOTE(review): presumably shark::Normalizer computes x * scale + offset
// element-wise, i.e. (x - mean) / stddev - confirm against the Shark docs.
// NOTE(review): there is no guard against a zero stddev entry.
105  scaleRV[i] = 1 / stddevMeasurementVector[i];
106  // Subtract the normalized mean
107  offsetRV[i] = -meanMeasurementVector[i] / stddevMeasurementVector[i];
108  }
109 
110  shark::Normalizer<> normalizer(scaleRV, offsetRV);
111  centroidData = normalizer(centroidData);
112  }
113 
// NOTE(review): the warning says the file "will not be used", yet
// SetCentroidsFromData() below is called regardless of the count. Presumably
// the model itself ignores centroids whose count differs from k - confirm.
114  if (centroidData.numberOfElements() != k)
115  otbAppLogWARNING("The input centroid file will not be used because it contains "
116  << centroidData.numberOfElements() << " points, which is different than from the requested number of class: " << k << ".");
117 
118  classifier->SetCentroidsFromData(centroidData);
119  }
120 
121  classifier->SetMaximumNumberOfIterations(nbMaxIter);
122  classifier->Train();
123  classifier->Save(modelPath);
124 
// Exporting the centroids is optional and happens after training.
125  if (HasValue("classifier.sharkkm.outcentroids"))
126  classifier->ExportCentroids(GetParameterString("classifier.sharkkm.outcentroids"));
127 }
128 
129 } // end namespace Wrapper
130 } // end namespace otb
131 
132 #endif
otbSharkKMeansMachineLearningModel.h
otbStatisticsXMLFileReader.h
otbAppLogWARNING
#define otbAppLogWARNING(x)
Definition: otbWrapperMacros.h:45
otb::Wrapper::ParameterType_OutputFilename
@ ParameterType_OutputFilename
Definition: otbWrapperTypes.h:45
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbLearningApplicationBase.h
otb::Wrapper::ParameterType_Int
@ ParameterType_Int
Definition: otbWrapperTypes.h:38
otb::SharkKMeansMachineLearningModel
Definition: otbSharkKMeansMachineLearningModel.h:82
otb::StatisticsXMLFileReader
Reads an XML file in which several statistics are stored.
Definition: otbStatisticsXMLFileReader.h:42
otb::Wrapper::ParameterType_InputFilename
@ ParameterType_InputFilename
Definition: otbWrapperTypes.h:43