OTB  6.7.0
Orfeo Toolbox
otbTrainImagesBase.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbTrainImagesBase_hxx
22 #define otbTrainImagesBase_hxx
23 
24 #include "otbTrainImagesBase.h"
25 
26 namespace otb
27 {
28 namespace Wrapper
29 {
31 {
32  //Group IO
33  AddParameter( ParameterType_Group, "io", "Input and output data" );
34  SetParameterDescription( "io", "This group of parameters allows setting input and output data." );
35 
36  AddParameter( ParameterType_InputImageList, "io.il", "Input Image List" );
37  SetParameterDescription( "io.il", "A list of input images." );
38  AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List" );
39  SetParameterDescription( "io.vd", "A list of vector data to select the training samples." );
40  MandatoryOn( "io.vd" );
41 
42  AddParameter( ParameterType_Bool, "cleanup", "Temporary files cleaning" );
43  SetParameterDescription( "cleanup",
44  "If activated, the application will try to clean all temporary files it created" );
45  SetParameterInt( "cleanup", 1);
46 }
47 
49 {
50  AddApplication( "PolygonClassStatistics", "polystat", "Polygon analysis" );
51  AddApplication( "MultiImageSamplingRate", "rates", "Sampling rates" );
52  AddApplication( "SampleSelection", "select", "Sample selection" );
53  AddApplication( "SampleExtraction", "extraction", "Sample extraction" );
54 
55  // Sampling settings
56  AddParameter( ParameterType_Group, "sample", "Training and validation samples parameters" );
57  SetParameterDescription( "sample",
58  "This group of parameters allows you to set training and validation sample lists parameters." );
59  AddParameter( ParameterType_Int, "sample.mt", "Maximum training sample size per class" );
60  SetDefaultParameterInt( "sample.mt", 1000 );
61  SetParameterDescription( "sample.mt", "Maximum size per class (in pixels) of "
62  "the training sample list (default = 1000) (no limit = -1). If equal to -1,"
63  " then the maximal size of the available training sample list per class "
64  "will be equal to the surface area of the smallest class multiplied by the"
65  " training sample ratio." );
66  AddParameter( ParameterType_Int, "sample.mv", "Maximum validation sample size per class" );
67  SetDefaultParameterInt( "sample.mv", 1000 );
68  SetParameterDescription( "sample.mv", "Maximum size per class (in pixels) of "
69  "the validation sample list (default = 1000) (no limit = -1). If equal to -1,"
70  " then the maximal size of the available validation sample list per class "
71  "will be equal to the surface area of the smallest class multiplied by the "
72  "validation sample ratio." );
73  AddParameter( ParameterType_Int, "sample.bm", "Bound sample number by minimum" );
74  SetDefaultParameterInt( "sample.bm", 1 );
75  SetParameterDescription( "sample.bm", "Bound the number of samples for each "
76  "class by the number of available samples by the smaller class. Proportions "
77  "between training and validation are respected. Default is true (=1)." );
78  AddParameter( ParameterType_Float, "sample.vtr", "Training and validation sample ratio" );
79  SetParameterDescription( "sample.vtr", "Ratio between training and validation samples (0.0 = all training, 1.0 = "
80  "all validation) (default = 0.5)." );
81  SetParameterFloat( "sample.vtr", 0.5);
82  SetMaximumParameterFloatValue( "sample.vtr", 1.0 );
83  SetMinimumParameterFloatValue( "sample.vtr", 0.0 );
84 
85 // AddParameter( ParameterType_Float, "sample.percent", "Percentage of sample extract from images" );
86 // SetParameterDescription( "sample.percent", "Percentage of sample extract from images for "
87 // "training and validation when only images are provided." );
88 // SetDefaultParameterFloat( "sample.percent", 1.0 );
89 // SetMinimumParameterFloatValue( "sample.percent", 0.0 );
90 // SetMaximumParameterFloatValue( "sample.percent", 1.0 );
91 
94 }
95 
97 {
98  // hide sampling parameters
99  //ShareParameter("sample.strategy","rates.strategy");
100  //ShareParameter("sample.mim","rates.mim");
101  ShareParameter( "ram", "polystat.ram" );
102  ShareParameter( "elev", "polystat.elev" );
103  ShareParameter( "sample.vfn", "polystat.field",
104  "Field containing the class integer label for supervision" ,
105  "Field containing the class id for supervision. "
106  "The values in this field shall be cast into integers.");
107 }
108 
110 {
111  Connect( "extraction.field", "polystat.field" );
112  Connect( "extraction.layer", "polystat.layer" );
113 
114  Connect( "select.ram", "polystat.ram" );
115  Connect( "extraction.ram", "polystat.ram" );
116 
117  Connect( "select.field", "polystat.field" );
118  Connect( "select.layer", "polystat.layer" );
119  Connect( "select.elev", "polystat.elev" );
120 
121  Connect( "extraction.in", "select.in" );
122  Connect( "extraction.vec", "select.out" );
123 }
124 
126 {
127  AddApplication( "TrainVectorClassifier", "training", "Model training" );
128 
129  AddParameter( ParameterType_InputVectorDataList, "io.valid", "Validation Vector Data List" );
130  SetParameterDescription( "io.valid", "A list of vector data to select the validation samples." );
131  MandatoryOff( "io.valid" );
132 
135 };
136 
138 {
139  ShareParameter( "io.imstat", "training.io.stats" );
140  ShareParameter( "io.out", "training.io.out" );
141 
142  ShareParameter( "classifier", "training.classifier" );
143  ShareParameter( "rand", "training.rand" );
144 
145  ShareParameter( "io.confmatout", "training.io.confmatout" );
146 }
147 
149 {
150  Connect( "training.cfield", "polystat.field" );
151  Connect( "select.rand", "training.rand" );
152 }
153 
155  const std::vector<std::string> &vectorFileNames,
156  const std::vector<std::string> &statisticsFileNames)
157 {
158  unsigned int nbImages = static_cast<unsigned int>(imageList->Size());
159  for( unsigned int i = 0; i < nbImages; i++ )
160  {
161  GetInternalApplication( "polystat" )->SetParameterInputImage( "in", imageList->GetNthElement( i ) );
162  GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileNames[i]);
163  GetInternalApplication( "polystat" )->SetParameterString( "out", statisticsFileNames[i]);
164  ExecuteInternal( "polystat" );
165  }
166 }
167 
168 
170 {
171  SamplingRates rates;
172  GetInternalApplication( "rates" )->SetParameterString( "mim", "proportional");
173  double vtr = GetParameterFloat( "sample.vtr" );
174  long mt = GetParameterInt( "sample.mt" );
175  long mv = GetParameterInt( "sample.mv" );
176  // compute final maximum training and final maximum validation
177  // By default take all samples (-1 means all samples)
178  rates.fmt = -1;
179  rates.fmv = -1;
180  if( GetParameterInt( "sample.bm" ) == 0 )
181  {
182  if( dedicatedValidation )
183  {
184  // fmt and fmv will be used separately
185  rates.fmt = mt;
186  rates.fmv = mv;
187  if( mt > -1 && mv <= -1 && vtr < 0.99999 )
188  {
189  rates.fmv = static_cast<long>(( double ) mt * vtr / ( 1.0 - vtr ));
190  }
191  if( mt <= -1 && mv > -1 && vtr > 0.00001 )
192  {
193  rates.fmt = static_cast<long>(( double ) mv * ( 1.0 - vtr ) / vtr);
194  }
195  }
196  else
197  {
198  // only fmt will be used for both training and validation samples
199  // So we try to compute the total number of samples given input
200  // parameters mt, mv and vtr.
201  if( mt > -1 && vtr < 0.99999 )
202  {
203  rates.fmt = static_cast<long>(( double ) mt / ( 1.0 - vtr ));
204  }
205  if( mv > -1 && vtr > 0.00001 )
206  {
207  if( rates.fmt > -1 )
208  {
209  rates.fmt = std::min( rates.fmt, static_cast<long>(( double ) mv / vtr) );
210  }
211  else
212  {
213  rates.fmt = static_cast<long>(( double ) mv / vtr);
214  }
215  }
216  }
217  }
218  return rates;
219 }
220 
221 
222 void TrainImagesBase::ComputeSamplingRate(const std::vector<std::string> &statisticsFileNames,
223  const std::string &ratesFileName, long maximum)
224 {
225  // Sampling rates
226  GetInternalApplication( "rates" )->SetParameterStringList( "il", statisticsFileNames);
227  GetInternalApplication( "rates" )->SetParameterString( "out", ratesFileName);
228  if( GetParameterInt( "sample.bm" ) != 0 )
229  {
230  GetInternalApplication( "rates" )->SetParameterString( "strategy", "smallest");
231  }
232  else
233  {
234  if( maximum > -1 )
235  {
236  std::ostringstream oss;
237  oss << maximum;
238  GetInternalApplication( "rates" )->SetParameterString( "strategy", "constant");
239  GetInternalApplication( "rates" )->SetParameterString( "strategy.constant.nb", oss.str());
240  }
241  else
242  {
243  GetInternalApplication( "rates" )->SetParameterString( "strategy", "all");
244  }
245  }
246  ExecuteInternal( "rates" );
247 }
248 
249 void
250 TrainImagesBase::TrainModel(FloatVectorImageListType *imageList, const std::vector<std::string> &sampleTrainFileNames,
251  const std::vector<std::string> &sampleValidationFileNames)
252 {
253  GetInternalApplication( "training" )->SetParameterStringList( "io.vd", sampleTrainFileNames);
254  if( !sampleValidationFileNames.empty() )
255  GetInternalApplication( "training" )->SetParameterStringList( "valid.vd", sampleValidationFileNames);
256 
257  UpdateInternalParameters( "training" );
258  // set field names
259  FloatVectorImageType::Pointer image = imageList->GetNthElement( 0 );
260  unsigned int nbBands = image->GetNumberOfComponentsPerPixel();
261  std::vector<std::string> selectedNames;
262  for( unsigned int i = 0; i < nbBands; i++ )
263  {
264  std::ostringstream oss;
265  oss << i;
266  selectedNames.push_back( "value_" + oss.str() );
267  }
268  GetInternalApplication( "training" )->SetParameterStringList( "feat", selectedNames);
269  ExecuteInternal( "training" );
270 }
271 
272 void TrainImagesBase::SelectAndExtractSamples(FloatVectorImageType *image, std::string vectorFileName,
273  std::string sampleFileName, std::string statisticsFileName,
274  std::string ratesFileName, SamplingStrategy strategy,
275  std::string selectedField)
276 {
277  GetInternalApplication( "select" )->SetParameterInputImage( "in", image );
278  GetInternalApplication( "select" )->SetParameterString( "out", sampleFileName);
279 
280  // Change the selection strategy based on selected sampling strategy
281  switch( strategy )
282  {
283 // case GEOMETRIC:
284 // GetInternalApplication( "select" )->SetParameterString( "sampler", "random");
285 // GetInternalApplication( "select" )->SetParameterString( "strategy", "percent");
286 // GetInternalApplication( "select" )->SetParameterFloat( "strategy.percent.p",
287 // GetParameterFloat( "sample.percent" ));
288 // break;
289  case CLASS:
290  default:
291  GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName);
292  GetInternalApplication( "select" )->SetParameterString( "instats", statisticsFileName);
293  GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic");
294  GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 );
295  GetInternalApplication( "select" )->SetParameterString( "strategy", "byclass");
296  GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", ratesFileName);
297  break;
298  }
299 
300  // select sample positions
301  ExecuteInternal( "select" );
302 
303  GetInternalApplication( "extraction" )->SetParameterString( "vec", sampleFileName);
304  UpdateInternalParameters( "extraction" );
305  if( !selectedField.empty() )
306  GetInternalApplication( "extraction" )->SetParameterString( "field", selectedField);
307 
308  GetInternalApplication( "extraction" )->SetParameterString( "outfield", "prefix");
309  GetInternalApplication( "extraction" )->SetParameterString( "outfield.prefix.name", "value_");
310 
311  // extract sample descriptors
312  ExecuteInternal( "extraction" );
313 }
314 
315 
317  FloatVectorImageListType *imageList,
318  std::vector<std::string> vectorFileNames, SamplingStrategy strategy,
319  std::string selectedFieldName)
320 {
321 
322  for( unsigned int i = 0; i < imageList->Size(); ++i )
323  {
324  std::string vectorFileName = vectorFileNames.empty() ? "" : vectorFileNames[i];
325  SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileName, fileNames.sampleOutputs[i],
326  fileNames.polyStatTrainOutputs[i], fileNames.ratesTrainOutputs[i], strategy,
327  selectedFieldName );
328  }
329 }
330 
331 
333  FloatVectorImageListType *imageList,
334  const std::vector<std::string> &validationVectorFileList)
335 {
336  for( unsigned int i = 0; i < imageList->Size(); ++i )
337  {
338  SelectAndExtractSamples( imageList->GetNthElement( i ), validationVectorFileList[i],
339  fileNames.sampleValidOutputs[i], fileNames.polyStatValidOutputs[i],
340  fileNames.ratesValidOutputs[i], Self::CLASS );
341  }
342 }
343 
345  FloatVectorImageListType *imageList)
346 {
347  for( unsigned int i = 0; i < imageList->Size(); ++i )
348  {
349  SplitTrainingAndValidationSamples( imageList->GetNthElement( i ), fileNames.sampleOutputs[i],
350  fileNames.sampleTrainOutputs[i], fileNames.sampleValidOutputs[i],
351  fileNames.ratesTrainOutputs[i] );
352  }
353 }
354 
356  std::string sampleTrainFileName,
357  std::string sampleValidFileName,
358  std::string ratesTrainFileName)
359 
360 {
361  // Split between training and validation
365  // read sampling rates from ratesTrainOutputs
367  rateCalculator->Read( ratesTrainFileName );
368  // Compute sampling rates for train and valid
369  const MapRateType &inputRates = rateCalculator->GetRatesByClass();
370  MapRateType trainRates;
371  MapRateType validRates;
373  for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it )
374  {
375  double vtr = GetParameterFloat( "sample.vtr" );
376  unsigned long total = std::min( it->second.Required, it->second.Tot );
377  unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr );
378  unsigned long neededTrain = total - neededValid;
379  tpt.Tot = total;
380  tpt.Required = neededTrain;
381  tpt.Rate = ( 1.0 - vtr );
382  trainRates[it->first] = tpt;
383  tpt.Tot = neededValid;
384  tpt.Required = neededValid;
385  tpt.Rate = 1.0;
386  validRates[it->first] = tpt;
387  }
388 
389  // Use an otb::OGRDataToSamplePositionFilter with 2 outputs
391  param.Offset = 0;
392  param.MaxJitter = 0;
394  splitter->SetInput( image );
395  splitter->SetOGRData( source );
396  splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 );
397  splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 );
398  splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] );
399  splitter->SetLayerIndex( 0 );
400  splitter->SetOriginFieldName( std::string( "" ) );
401  splitter->SetSamplerParameters( param );
402  splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) );
403  AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." );
404  splitter->Update();
405 }
406 }
407 }
408 
409 #endif
void TrainModel(FloatVectorImageListType *imageList, const std::vector< std::string > &sampleTrainFileNames, const std::vector< std::string > &sampleValidationFileNames)
void SetParameterStringList(std::string parameter, std::vector< std::string > values, bool hasUserValueFlag=true)
Creation of an "otb" vector image which contains metadata.
bool AddApplication(std::string appType, std::string key, std::string desc)
void SetParameterString(std::string parameter, std::string value, bool hasUserValueFlag=true)
void ComputeSamplingRate(const std::vector< std::string > &statisticsFileNames, const std::string &ratesFileName, long maximum)
void ComputePolygonStatistics(FloatVectorImageListType *imageList, const std::vector< std::string > &vectorFileNames, const std::vector< std::string > &statisticsFileNames)
ObjectPointerType GetNthElement(unsigned int index) const
void SetMinimumParameterFloatValue(std::string parameter, float value)
bool Connect(std::string fromKey, std::string toKey)
InternalContainerSizeType Size(void) const override
void SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName, std::string sampleTrainFileName, std::string sampleValidFileName, std::string ratesTrainFileName)
void UpdateInternalParameters(std::string key)
void MandatoryOff(std::string paramKey)
void SetDefaultParameterInt(std::string parameter, int value)
void AddProcess(itk::ProcessObject *object, std::string description)
otb::SamplingRateCalculator::MapRateType MapRateType
void SetParameterDescription(std::string paramKey, std::string dec)
SamplingRates ComputeFinalMaximumSamplingRates(bool dedicatedValidation)
void SelectAndExtractTrainSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, std::vector< std::string > vectorFileNames, SamplingStrategy strategy, std::string selectedFieldName="")
void SetParameterInt(std::string parameter, int value, bool hasUserValueFlag=true)
This class is a generic all-purpose wrapping around an std::vector<itk::SmartPointer<ObjectType> >...
Definition: otbObjectList.h:40
void SplitTrainingToValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList)
static Pointer New()
void SelectAndExtractValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, const std::vector< std::string > &validationVectorFileList=std::vector< std::string >())
void SetParameterFloat(std::string parameter, float value, bool hasUserValueFlag=true)
std::vector< std::string > GetParameterStringList(const std::string &parameter)
void MandatoryOn(std::string paramKey)
struct OTBSampling_EXPORT otb::SamplingRateCalculator::Triplet TripletType
void SelectAndExtractSamples(FloatVectorImageType *image, std::string vectorFileName, std::string sampleFileName, std::string statisticsFileName, std::string ratesFileName, SamplingStrategy strategy, std::string selectedField="")
Application * GetInternalApplication(std::string id)
void SetParameterInputImage(std::string parameter, ImageBaseType *inputImage)
int GetParameterInt(std::string parameter)
float GetParameterFloat(std::string parameter)
bool ShareParameter(std::string localKey, std::string internalKey, std::string name=std::string(), std::string desc=std::string())
void AddParameter(ParameterType type, std::string paramKey, std::string paramName)
void SetMaximumParameterFloatValue(std::string parameter, float value)
void ExecuteInternal(std::string key)
Open data source in read-only mode.