OTB  6.7.0
Orfeo Toolbox
otbTrainVectorBase.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbTrainVectorBase_hxx
21 #define otbTrainVectorBase_hxx
22 
23 #include "otbTrainVectorBase.h"
24 
25 namespace otb
26 {
27 namespace Wrapper
28 {
29 
// Declares the parameters shared by the vector-based training applications:
//  - group "io": training vector data list ("io.vd"), optional statistics XML
//    ("io.stats") and output model file name ("io.out");
//  - "layer": optional layer index, default 0;
//  - "feat": list view of the fields used as training features;
//  - group "valid": optional validation vector data and its layer index (default 0);
//  - "cfield": single-selection list view holding the class label field;
//  - "v": boolean verbose flag, set to 1 here.
// It then sets documentation example values, calls Superclass::DoInit() and
// AddRANDParameter().
// NOTE(review): listing lines 32-33 (the qualified function name) are missing from
// this extract; given the Superclass::DoInit() call below this is presumably the
// DoInit() override - confirm against the original header.
30 template <class TInputValue, class TOutputValue>
31 void
34 {
35  // Common Parameters for all Learning Application
36  this->AddParameter( ParameterType_Group, "io", "Input and output data" );
37  this->SetParameterDescription( "io",
38  "This group of parameters allows setting input and output data." );
39 
40  this->AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" );
41  this->SetParameterDescription( "io.vd",
42  "Input geometries used for training (note: all geometries from the layer will be used)" );
43 
// The statistics file is optional (MandatoryOff); GetStatistics() in this file
// falls back to zero shifts / unit scales when it is not provided.
44  this->AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" );
45  this->MandatoryOff( "io.stats" );
46  this->SetParameterDescription( "io.stats",
47  "XML file containing mean and variance of each feature." );
48 
49  this->AddParameter( ParameterType_OutputFilename, "io.out", "Output model" );
50  this->SetParameterDescription( "io.out",
51  "Output file containing the model estimated (.txt format)." );
52 
53  this->AddParameter( ParameterType_Int, "layer", "Layer Index" );
54  this->SetParameterDescription( "layer",
55  "Index of the layer to use in the input vector file." );
56  this->MandatoryOff( "layer" );
57  this->SetDefaultParameterInt( "layer", 0 );
58 
// The choices of this list view are filled by DoUpdateParameters() from the
// fields found in the input vector data.
59  this->AddParameter(ParameterType_ListView, "feat", "Field names for training features");
60  this->SetParameterDescription("feat",
61  "List of field names in the input vector data to be used as features for training.");
62 
63  // Add validation data used to compute confusion matrix or contingency table
64  this->AddParameter( ParameterType_Group, "valid", "Validation data" );
65  this->SetParameterDescription( "valid",
66  "This group of parameters defines validation data." );
67 
68  this->AddParameter( ParameterType_InputVectorDataList, "valid.vd",
69  "Validation Vector Data" );
70  this->SetParameterDescription( "valid.vd", "Geometries used for validation "
71  "(must contain the same fields used for training, all geometries from the layer will be used)" );
72  this->MandatoryOff( "valid.vd" );
73 
74  this->AddParameter( ParameterType_Int, "valid.layer", "Layer Index" );
75  this->SetParameterDescription( "valid.layer",
76  "Index of the layer to use in the validation vector file." );
77  this->MandatoryOff( "valid.layer" );
78  this->SetDefaultParameterInt( "valid.layer", 0 );
79 
80  // Add class field if we used validation
81  this->AddParameter( ParameterType_ListView, "cfield",
82  "Field containing the class integer label for supervision" );
83  this->SetParameterDescription( "cfield",
84  "Field containing the class id for supervision. "
85  "The values in this field shall be cast into integers. "
86  "Only geometries with this field available will be taken into account." );
// Only one class field may be selected at a time.
87  this->SetListViewSingleSelectionMode( "cfield", true );
88 
// Verbose flag; its value is set to 1 immediately after creation.
89  this->AddParameter(ParameterType_Bool, "v", "Verbose mode");
90  this->SetParameterDescription("v", "Verbose mode, display the contingency table result.");
91  this->SetParameterInt("v", 1);
92 
93  // Doc example parameter settings
// NOTE(review): the example model name below uses a ".svm" extension while the
// "io.out" description above mentions ".txt format" - confirm which is intended.
94  this->SetDocExampleParameterValue( "io.vd", "vectorData.shp" );
95  this->SetDocExampleParameterValue( "io.stats", "meanVar.xml" );
96  this->SetDocExampleParameterValue( "io.out", "svmModel.svm" );
97  this->SetDocExampleParameterValue( "feat", "perimeter area width" );
98  this->SetDocExampleParameterValue( "cfield", "predicted" );
99 
100 
101  // Add parameters for the classifier choice
102  Superclass::DoInit();
103 
// NOTE(review): presumably registers the random-seed parameter - confirm against
// the base application class, which is not visible in this file.
104  this->AddRANDParameter();
105 }
106 
// Rebuilds the choices of the "feat" and "cfield" list views from the fields of
// the first feature read from the selected layer of the input vector data.
//  - Integer, Integer64 and Real fields are offered as training features ("feat").
//  - String, Integer, Integer64 and Real fields are offered as class field ("cfield").
// The choice key is the field name with every character rejected by IsNotAlphaNum
// removed and the remainder lower-cased; the displayed item is the original name.
// NOTE(review): two field names that differ only by removed characters or by case
// map to the same key - confirm how AddChoice handles duplicate keys.
// NOTE(review): listing lines 109-110 (qualified function name) and 116 (the
// statement that creates `ogrDS`) are missing from this extract and are left
// untouched here.
107 template <class TInputValue, class TOutputValue>
108 void
111 {
112  // if vector data is present and updated then reload fields
113  if( this->HasValue( "io.vd" ) )
114  {
115  std::vector<std::string> vectorFileList = this->GetParameterStringList( "io.vd" );
117  ogr::Layer layer = ogrDS->GetLayer( static_cast<size_t>( this->GetParameterInt( "layer" ) ) );
118  ogr::Feature feature = layer.ogr().GetNextFeature();
119 
120  this->ClearChoices( "feat" );
121  this->ClearChoices( "cfield" );
122 
  // An empty layer yields a null feature; calling feature.ogr() on it would
  // dereference a null pointer. Same test as in ExtractSamplesWithLabel().
  // The choices stay cleared because there is no field to list.
  if( feature.addr() == 0 )
  {
  return;
  }

123  for( int iField = 0; iField < feature.ogr().GetFieldCount(); iField++ )
124  {
125  std::string key, item = feature.ogr().GetFieldDefnRef( iField )->GetNameRef();
126  key = item;
  // remove_if only compacts the kept characters to the front; `end` marks the
  // logical end of the cleaned key, hence the substr() calls below.
127  std::string::iterator end = std::remove_if( key.begin(), key.end(), IsNotAlphaNum );
128  std::transform( key.begin(), end, key.begin(), tolower );
129 
130  OGRFieldType fieldType = feature.ogr().GetFieldDefnRef( iField )->GetType();
131 
132  if( fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal )
133  {
134  std::string tmpKey = "feat." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) );
135  this->AddChoice( tmpKey, item );
136  }
137  if( fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal )
138  {
139  std::string tmpKey = "cfield." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) );
140  this->AddChoice( tmpKey, item );
141  }
142  }
143  }
144 }
145 
// Runs the training workflow:
//  1. records the available and selected feature fields in m_FeaturesInfo;
//  2. reports a fatal error if no feature is selected;
//  3. obtains the shift/scale statistics and extracts the training and
//     classification sample sets;
//  4. calls Train() with the training samples, their labels and the "io.out" value;
//  5. calls Classify() on the classification samples with the same "io.out" value
//     and stores the result in m_PredictedList.
// NOTE(review): listing lines 148-149 (qualified function name) are missing from
// this extract.
146 template <class TInputValue, class TOutputValue>
147 void
150 {
151  m_FeaturesInfo.SetFieldNames( this->GetChoiceNames( "feat" ), this->GetSelectedItems( "feat" ));
152 
153  // Check input parameters
154  if( m_FeaturesInfo.m_SelectedIdx.empty() )
155  {
156  otbAppLogFATAL( << "No features have been selected to train the classifier on!" );
157  }
158 
// GetStatistics() supplies the shifts/scales that ExtractSamplesWithLabel()
// applies to every sample set extracted below.
159  ShiftScaleParameters measurement = GetStatistics( m_FeaturesInfo.m_NbFeatures );
160  ExtractAllSamples( measurement );
161 
162  this->Train( m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, this->GetParameterString( "io.out" ) );
163 
164  m_PredictedList =
165  this->Classify( m_ClassificationSamplesWithLabel.listSample, this->GetParameterString( "io.out" ) );
166 }
167 
// Fills m_TrainingSamplesWithLabel, then m_ClassificationSamplesWithLabel, using
// the given shift/scale parameters.
// The order matters: ExtractClassificationSamplesWithLabel() in this file reads
// m_TrainingSamplesWithLabel as its fallback.
// NOTE(review): listing lines 170-171 (qualified function name) are missing from
// this extract.
168 template <class TInputValue, class TOutputValue>
169 void
172 {
173  m_TrainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement);
174  m_ClassificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement);
175 }
176 
// Extracts the training samples and labels from the "io.vd" vector data, reading
// the layer whose index is given by the "layer" parameter.
// NOTE(review): listing lines 178-180 (return type and qualified function name)
// are missing from this extract.
177 template <class TInputValue, class TOutputValue>
181 {
182  return ExtractSamplesWithLabel( "io.vd", "layer", measurement);
183 }
184 
// Returns the sample set used for classification / performance estimation.
//  - Supervised classifier: the samples extracted from "valid.vd" / "valid.layer"
//    when that set is non-empty, otherwise the training samples.
//  - Any other classifier category: the training samples.
// Must be called after m_TrainingSamplesWithLabel has been filled.
// NOTE(review): when "valid.vd" is unset, ExtractSamplesWithLabel() returns a
// default-constructed SamplesWithLabel; the Size() call below assumes that its
// labeledListSample is then non-null - confirm in the SamplesWithLabel definition.
// NOTE(review): listing lines 186-188 (return type and qualified function name)
// and 202 (the opening of the log statement whose message is on line 203) are
// missing from this extract.
185 template <class TInputValue, class TOutputValue>
189 {
190  if(this->GetClassifierCategory() == Superclass::Supervised)
191  {
192  SamplesWithLabel tmpSamplesWithLabel;
193  SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel( "valid.vd", "valid.layer", measurement );
194  //Test the input validation set size
195  if( validationSamplesWithLabel.labeledListSample->Size() != 0 )
196  {
197  tmpSamplesWithLabel.listSample = validationSamplesWithLabel.listSample;
198  tmpSamplesWithLabel.labeledListSample = validationSamplesWithLabel.labeledListSample;
199  }
200  else
201  {
203  "The validation set is empty. The performance estimation is done using the input training set in this case." );
204  tmpSamplesWithLabel.listSample = m_TrainingSamplesWithLabel.listSample;
205  tmpSamplesWithLabel.labeledListSample = m_TrainingSamplesWithLabel.labeledListSample;
206  }
207 
208  return tmpSamplesWithLabel;
209  }
210  else
211  {
212  return m_TrainingSamplesWithLabel;
213  }
214 }
215 
// Builds the shift/scale parameters applied to the samples.
//  - If "io.stats" has a value and is enabled, the "mean" and "stddev" vectors are
//    read from that XML file.
//  - Otherwise both vectors get nbFeatures entries, filled with 0 (mean) and
//    1 (stddev).
// \param nbFeatures number of features; only used to size the default vectors.
// \return the populated ShiftScaleParameters.
// NOTE(review): the vectors read from the XML file are not checked against
// nbFeatures here - confirm whether a size mismatch is handled downstream.
// NOTE(review): listing lines 217-218 (return type and class qualifier) and 221
// (presumably the declaration of `measurement`) are missing from this extract.
216 template <class TInputValue, class TOutputValue>
219 ::GetStatistics(unsigned int nbFeatures)
220 {
222  if( this->HasValue( "io.stats" ) && this->IsParameterEnabled( "io.stats" ) )
223  {
224  typename StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
225  std::string XMLfile = this->GetParameterString( "io.stats" );
226  statisticsReader->SetFileName( XMLfile );
227  measurement.meanMeasurementVector = statisticsReader->GetStatisticVectorByName( "mean" );
228  measurement.stddevMeasurementVector = statisticsReader->GetStatisticVectorByName( "stddev" );
229  }
230  else
231  {
232  measurement.meanMeasurementVector.SetSize( nbFeatures );
233  measurement.meanMeasurementVector.Fill( 0. );
234  measurement.stddevMeasurementVector.SetSize( nbFeatures );
235  measurement.stddevMeasurementVector.Fill( 1. );
236  }
237  return measurement;
238 }
239 
240 template <class TInputValue, class TOutputValue>
243 ::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer,
244  const ShiftScaleParameters &measurement)
245 {
246  SamplesWithLabel samplesWithLabel;
247  if( this->HasValue( parameterName ) && this->IsParameterEnabled( parameterName ) )
248  {
249  typename ListSampleType::Pointer input = ListSampleType::New();
250  typename TargetListSampleType::Pointer target = TargetListSampleType::New();
251  input->SetMeasurementVectorSize( m_FeaturesInfo.m_NbFeatures );
252 
253  std::vector<std::string> fileList = this->GetParameterStringList( parameterName );
254  for( unsigned int k = 0; k < fileList.size(); k++ )
255  {
256  otbAppLogINFO( "Reading vector file " << k + 1 << "/" << fileList.size() );
258  ogr::Layer layer = source->GetLayer( static_cast<size_t>(this->GetParameterInt( parameterLayer )) );
259  ogr::Feature feature = layer.ogr().GetNextFeature();
260  bool goesOn = feature.addr() != 0;
261  if( !goesOn )
262  {
263  otbAppLogWARNING( "The layer " << this->GetParameterInt( parameterLayer ) << " of " << fileList[k]
264  << " is empty, input is skipped." );
265  continue;
266  }
267 
268  // Check all needed fields are present :
269  // - check class field if we use supervised classification or if class field name is not empty
270  int cFieldIndex = feature.ogr().GetFieldIndex( m_FeaturesInfo.m_SelectedCFieldName.c_str() );
271  if( cFieldIndex < 0 && !m_FeaturesInfo.m_SelectedCFieldName.empty())
272  {
273  otbAppLogFATAL( "The field name for class label (" << m_FeaturesInfo.m_SelectedCFieldName
274  << ") has not been found in the vector file "
275  << fileList[k] );
276  }
277 
278  // - check feature fields
279  std::vector<int> featureFieldIndex( m_FeaturesInfo.m_NbFeatures, -1 );
280  for( unsigned int i = 0; i < m_FeaturesInfo.m_NbFeatures; i++ )
281  {
282  featureFieldIndex[i] = feature.ogr().GetFieldIndex( m_FeaturesInfo.m_SelectedNames[i].c_str() );
283  if( featureFieldIndex[i] < 0 )
284  otbAppLogFATAL( "The field name for feature " << m_FeaturesInfo.m_SelectedNames[i]
285  << " has not been found in the vector file "
286  << fileList[k] );
287  }
288 
289 
290  while( goesOn )
291  {
292  // Retrieve all the features for each field in the ogr layer.
293  MeasurementType mv;
294  mv.SetSize( m_FeaturesInfo.m_NbFeatures );
295  for( unsigned int idx = 0; idx < m_FeaturesInfo.m_NbFeatures; ++idx )
296  {
297  switch (feature[featureFieldIndex[idx]].GetType())
298  {
299  case OFTInteger:
300  mv[idx] = static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<int>());
301  break;
302  case OFTInteger64:
303  mv[idx] = static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<int>());
304  break;
305  case OFTReal:
306  mv[idx] = static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<double>());
307  break;
308  default:
309  itkExceptionMacro(<< "incorrect field type: " << feature[featureFieldIndex[idx]].GetType() << ".");
310  }
311  }
312 
313  input->PushBack( mv );
314 
315  if(cFieldIndex>=0 && ogr::Field(feature,cFieldIndex).HasBeenSet())
316  {
317  switch (feature[cFieldIndex].GetType())
318  {
319  case OFTInteger:
320  target->PushBack(static_cast<ValueType>(feature[cFieldIndex].GetValue<int>()));
321  break;
322  case OFTInteger64:
323  target->PushBack(static_cast<ValueType>(feature[cFieldIndex].GetValue<int>()));
324  break;
325  case OFTReal:
326  target->PushBack(static_cast<ValueType>(feature[cFieldIndex].GetValue<double>()));
327  break;
328  case OFTString:
329  target->PushBack(static_cast<ValueType>(std::stod(feature[cFieldIndex].GetValue<std::string>())));
330  break;
331  default:
332  itkExceptionMacro(<< "incorrect field type: " << feature[featureFieldIndex[cFieldIndex]].GetType() << ".");
333  }
334  }
335  else
336  target->PushBack( 0. );
337 
338  feature = layer.ogr().GetNextFeature();
339  goesOn = feature.addr() != 0;
340  }
341  }
342 
343 
344 
345  typename ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New();
346  shiftScaleFilter->SetInput( input );
347  shiftScaleFilter->SetShifts( measurement.meanMeasurementVector );
348  shiftScaleFilter->SetScales( measurement.stddevMeasurementVector );
349  shiftScaleFilter->Update();
350 
351  samplesWithLabel.listSample = shiftScaleFilter->GetOutput();
352  samplesWithLabel.labeledListSample = target;
353  samplesWithLabel.listSample->DisconnectPipeline();
354  }
355 
356  return samplesWithLabel;
357 }
358 
359 
360 }
361 }
362 
363 #endif
ShiftScaleParameters GetStatistics(unsigned int nbFeatures)
OGRFeature const * addr() const
Layer of geometric objects.
#define otbAppLogFATAL(x)
OGRFeature & ogr() const
#define otbAppLogWARNING(x)
#define otbAppLogINFO(x)
SamplesWithLabel ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, const ShiftScaleParameters &measurement)
void Fill(TValue const &v) noexcept
static Pointer New()
OTBApplicationEngine_EXPORT MDType GetType(const std::string &val)
OGRLayer & ogr()
virtual SamplesWithLabel ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement)
virtual void ExtractAllSamples(const ShiftScaleParameters &measurement)
Geometric object with descriptive fields.
Encapsulation of OGRField. Instances of Field are expected to be built from an existing Feature with w...
bool IsNotAlphaNum(char c)
void SetSize(unsigned int sz, TReallocatePolicy reallocatePolicy, TKeepValuesPolicy keepValues)
virtual SamplesWithLabel ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement)
Open data source in read-only mode.