OTB  6.7.0
Orfeo Toolbox
otbTrainSharkKMeans.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbTrainSharkKMeans_hxx
21 #define otbTrainSharkKMeans_hxx
22 
26 
27 namespace otb
28 {
29 namespace Wrapper
30 {
31 template<class TInputValue, class TOutputValue>
32 void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
33 {
34  AddChoice( "classifier.sharkkm", "Shark kmeans classifier" );
35  SetParameterDescription("classifier.sharkkm", "http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html ");
36 
37  // MaxNumberOfIterations
38  AddParameter(ParameterType_Int, "classifier.sharkkm.maxiter", "Maximum number of iterations for the kmeans algorithm");
39  SetParameterInt("classifier.sharkkm.maxiter", 10);
40  SetMinimumParameterIntValue("classifier.sharkkm.maxiter", 0);
41  SetParameterDescription("classifier.sharkkm.maxiter", "The maximum number of iterations for the kmeans algorithm. 0=unlimited");
42 
43  // Number of classes
44  AddParameter(ParameterType_Int, "classifier.sharkkm.k", "Number of classes for the kmeans algorithm");
45  SetParameterInt("classifier.sharkkm.k", 2);
46  SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class");
47  SetMinimumParameterIntValue("classifier.sharkkm.k", 2);
48 
49  // Centroid IO
50  AddParameter( ParameterType_Group, "classifier.sharkkm.centroids", "Centroids IO parameters" );
51  SetParameterDescription( "classifier.sharkkm.centroids", "Group of parameters for centroids IO." );
52 
53  // Input centroids
54  AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.in", "User definied input centroids");
55  SetParameterDescription("classifier.sharkkm.centroids.in", "Input text file containing centroid posistions used to initialize the algorithm. "
56  "Each centroid must be described by p parameters, p being the number of features in "
57  "the input vector data, and the number of centroids must be equal to the number of classes "
58  "(one centroid per line with values separated by spaces).");
59  MandatoryOff("classifier.sharkkm.centroids");
60 
61  // Centroid statistics
62  AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.stats", "Statistics file");
63  SetParameterDescription("classifier.sharkkm.centroids.stats", "A XML file containing mean and standard deviation to center"
64  "and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application.");
65  MandatoryOff("classifier.sharkkm.centroids.stats");
66 
67  // Output centroids
68  AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.centroids.out", "Output centroids text file");
69  SetParameterDescription("classifier.sharkkm.centroids.out", "Output text file containing centroids after the kmean algorithm.");
70  MandatoryOff("classifier.sharkkm.centroids.out");
71 
72 }
73 
74 template<class TInputValue, class TOutputValue>
75 void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
76  typename ListSampleType::Pointer trainingListSample,
77  typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
78 {
79  unsigned int nbMaxIter = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.maxiter" ) ));
80  unsigned int k = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.k" ) ));
81 
83  typename SharkKMeansType::Pointer classifier = SharkKMeansType::New();
84  classifier->SetRegressionMode( this->m_RegressionFlag );
85  classifier->SetInputListSample( trainingListSample );
86  classifier->SetTargetListSample( trainingLabeledListSample );
87  classifier->SetK( k );
88 
89  // Initialize centroids from file
90  if(IsParameterEnabled("classifier.sharkkm.centroids.in") && HasValue("classifier.sharkkm.centroids.in"))
91  {
92  shark::Data<shark::RealVector> centroidData;
93  shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroids.in"), ' ');
94  if( HasValue( "classifier.sharkkm.centroids.stats" ) )
95  {
97  statisticsReader->SetFileName(GetParameterString( "classifier.sharkkm.centroids.stats" ));
98  auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
99  auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
100 
101  // Convert itk Variable Length Vector to shark Real Vector
102  shark::RealVector offsetRV(meanMeasurementVector.Size());
103  shark::RealVector scaleRV(stddevMeasurementVector.Size());
104 
105  assert(meanMeasurementVector.Size()==stddevMeasurementVector.Size());
106  for (unsigned int i = 0; i<meanMeasurementVector.Size(); ++i)
107  {
108  scaleRV[i] = 1/stddevMeasurementVector[i];
109  // Substract the normalized mean
110  offsetRV[i] = - meanMeasurementVector[i]/stddevMeasurementVector[i];
111  }
112 
113  shark::Normalizer<> normalizer(scaleRV, offsetRV);
114  centroidData = normalizer(centroidData);
115  }
116 
117  if (centroidData.numberOfElements() != k)
118  otbAppLogWARNING( "The input centroid file will not be used because it contains " << centroidData.numberOfElements() <<
119  " points, which is different than from the requested number of class: " << k <<".");
120 
121  classifier->SetCentroidsFromData( centroidData);
122  }
123 
124  classifier->SetMaximumNumberOfIterations( nbMaxIter );
125  classifier->Train();
126  classifier->Save( modelPath );
127 
128  if( HasValue( "classifier.sharkkm.centroids.out"))
129  classifier->ExportCentroids( GetParameterString( "classifier.sharkkm.centroids.out" ));
130 }
131 
132 } //end namespace wrapper
133 } //end namespace otb
134 
135 #endif
Reads an XML file in which several statistics are stored.
#define otbAppLogWARNING(x)
bool abs(const bool x)