OTB  9.0.0
Orfeo Toolbox
otbTrainVectorBase.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbTrainVectorBase_hxx
21 #define otbTrainVectorBase_hxx
22 
#include "otbTrainVectorBase.h"

#include <algorithm>
#include <cctype>
#include <exception>
#include <string>
#include <vector>
24 
25 namespace otb
26 {
27 namespace Wrapper
28 {
29 
30 template <class TInputValue, class TOutputValue>
32 {
33  // Common Parameters for all Learning Application
34  this->AddParameter(ParameterType_Group, "io", "Input and output data");
35  this->SetParameterDescription("io", "This group of parameters allows setting input and output data.");
36 
37  this->AddParameter(ParameterType_InputVectorDataList, "io.vd", "Input Vector Data");
38  this->SetParameterDescription("io.vd", "Input geometries used for training (note: all geometries from the layer will be used)");
39 
40  this->AddParameter(ParameterType_InputFilename, "io.stats", "Input XML image statistics file");
41  this->MandatoryOff("io.stats");
42  this->SetParameterDescription("io.stats", "XML file containing mean and variance of each feature.");
43 
44  this->AddParameter(ParameterType_OutputFilename, "io.out", "Output model");
45  this->SetParameterDescription("io.out", "Output file containing the model estimated (.txt format).");
46 
47  this->AddParameter(ParameterType_Int, "layer", "Layer Index");
48  this->SetParameterDescription("layer", "Index of the layer to use in the input vector file.");
49  this->MandatoryOff("layer");
50  this->SetDefaultParameterInt("layer", 0);
51 
52  this->AddParameter(ParameterType_Field, "feat", "Field names for training features");
53  this->SetParameterDescription("feat", "List of field names in the input vector data to be used as features for training.");
54  this->SetVectorData("feat", "io.vd");
55  this->SetTypeFilter("feat", { OFTInteger, OFTInteger64, OFTReal });
56 
57  // Add validation data used to compute confusion matrix or contingency table
58  this->AddParameter(ParameterType_Group, "valid", "Validation data");
59  this->SetParameterDescription("valid", "This group of parameters defines validation data.");
60 
61  this->AddParameter(ParameterType_InputVectorDataList, "valid.vd", "Validation Vector Data");
62  this->SetParameterDescription("valid.vd",
63  "Geometries used for validation "
64  "(must contain the same fields used for training, all geometries from the layer will be used)");
65  this->MandatoryOff("valid.vd");
66 
67  this->AddParameter(ParameterType_Int, "valid.layer", "Layer Index");
68  this->SetParameterDescription("valid.layer", "Index of the layer to use in the validation vector file.");
69  this->MandatoryOff("valid.layer");
70  this->SetDefaultParameterInt("valid.layer", 0);
71 
72  // Add class field if we used validation
73  this->AddParameter(ParameterType_Field, "cfield", "Field containing the class integer label for supervision");
74  this->SetParameterDescription("cfield",
75  "Field containing the class id for supervision. "
76  "The values in this field shall be cast into integers. "
77  "Only geometries with this field available will be taken into account.");
78  this->SetVectorData("cfield", "io.vd");
79  this->SetTypeFilter("cfield", { OFTString, OFTInteger, OFTInteger64, OFTReal });
80  this->SetListViewSingleSelectionMode("cfield", true);
81 
82  this->AddParameter(ParameterType_Bool, "v", "Verbose mode");
83  this->SetParameterDescription("v", "Verbose mode, display the contingency table result.");
84  this->SetParameterInt("v", 1);
85 
86  // Doc example parameter settings
87  this->SetDocExampleParameterValue("io.vd", "vectorData.shp");
88  this->SetDocExampleParameterValue("io.stats", "meanVar.xml");
89  this->SetDocExampleParameterValue("io.out", "svmModel.svm");
90  this->SetDocExampleParameterValue("feat", "perimeter area width");
91  this->SetDocExampleParameterValue("cfield", "predicted");
92 
93 
94  // Add parameters for the classifier choice
95  Superclass::DoInit();
96 
97  this->AddRANDParameter();
98 }
99 
100 template <class TInputValue, class TOutputValue>
102 {
103  // if vector data is present and updated then reload fields
104  if (this->HasValue("io.vd"))
105  {
106  std::vector<std::string> vectorFileList = this->GetParameterStringList("io.vd");
108  ogr::Layer layer = ogrDS->GetLayer(static_cast<size_t>(this->GetParameterInt("layer")));
109  ogr::Feature feature = layer.ogr().GetNextFeature();
110 
111  this->ClearChoices("feat");
112  this->ClearChoices("cfield");
113 
114  FieldParameter::TypeFilterType featTypeFilter = this->GetTypeFilter("feat");
115  FieldParameter::TypeFilterType cfieldTypeFilter = this->GetTypeFilter("cfield");
116  for (int iField = 0; iField < feature.ogr().GetFieldCount(); iField++)
117  {
118  std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef();
119  key = item;
120  std::string::iterator end = std::remove_if(key.begin(), key.end(), [](char c) { return !std::isalnum(c); });
121  std::transform(key.begin(), end, key.begin(), tolower);
122 
123  OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType();
124 
125  if (featTypeFilter.empty() || std::find(featTypeFilter.begin(), featTypeFilter.end(), fieldType) != std::end(featTypeFilter))
126  {
127  std::string tmpKey = "feat." + key.substr(0, static_cast<unsigned long>(end - key.begin()));
128  this->AddChoice(tmpKey, item);
129  }
130  if (cfieldTypeFilter.empty() || std::find(cfieldTypeFilter.begin(), cfieldTypeFilter.end(), fieldType) != std::end(cfieldTypeFilter))
131  {
132  std::string tmpKey = "cfield." + key.substr(0, static_cast<unsigned long>(end - key.begin()));
133  this->AddChoice(tmpKey, item);
134  }
135  }
136  }
137 }
138 
139 template <class TInputValue, class TOutputValue>
141 {
142  m_FeaturesInfo.SetFieldNames(this->GetChoiceNames("feat"), this->GetSelectedItems("feat"));
143 
144  // Check input parameters
145  if (m_FeaturesInfo.m_SelectedIdx.empty())
146  {
147  otbAppLogFATAL(<< "No features have been selected to train the classifier on!");
148  }
149 
150  ShiftScaleParameters measurement = GetStatistics(m_FeaturesInfo.m_NbFeatures);
151  ExtractAllSamples(measurement);
152 
153  this->Train(m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, this->GetParameterString("io.out"));
154 
155  m_PredictedList = this->Classify(m_ClassificationSamplesWithLabel.listSample, this->GetParameterString("io.out"));
156 }
157 
158 template <class TInputValue, class TOutputValue>
160 {
161  m_TrainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement);
162  m_ClassificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement);
163 }
164 
165 template <class TInputValue, class TOutputValue>
168 {
169  return ExtractSamplesWithLabel("io.vd", "layer", measurement);
170 }
171 
172 template <class TInputValue, class TOutputValue>
175 {
176  if (this->GetClassifierCategory() == Superclass::Supervised)
177  {
178  SamplesWithLabel tmpSamplesWithLabel;
179  SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel("valid.vd", "valid.layer", measurement);
180  // Test the input validation set size
181  if (validationSamplesWithLabel.labeledListSample->Size() != 0)
182  {
183  tmpSamplesWithLabel.listSample = validationSamplesWithLabel.listSample;
184  tmpSamplesWithLabel.labeledListSample = validationSamplesWithLabel.labeledListSample;
185  }
186  else
187  {
188  otbAppLogWARNING("The validation set is empty. The performance estimation is done using the input training set in this case.");
189  tmpSamplesWithLabel.listSample = m_TrainingSamplesWithLabel.listSample;
190  tmpSamplesWithLabel.labeledListSample = m_TrainingSamplesWithLabel.labeledListSample;
191  }
192 
193  return tmpSamplesWithLabel;
194  }
195  else
196  {
197  return m_TrainingSamplesWithLabel;
198  }
199 }
200 
201 template <class TInputValue, class TOutputValue>
203 {
205  if (this->HasValue("io.stats") && this->IsParameterEnabled("io.stats"))
206  {
207  typename StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
208  std::string XMLfile = this->GetParameterString("io.stats");
209  statisticsReader->SetFileName(XMLfile);
210  measurement.meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
211  measurement.stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
212  }
213  else
214  {
215  measurement.meanMeasurementVector.SetSize(nbFeatures);
216  measurement.meanMeasurementVector.Fill(0.);
217  measurement.stddevMeasurementVector.SetSize(nbFeatures);
218  measurement.stddevMeasurementVector.Fill(1.);
219  }
220  return measurement;
221 }
222 
223 template <class TInputValue, class TOutputValue>
225 TrainVectorBase<TInputValue, TOutputValue>::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer,
226  const ShiftScaleParameters& measurement)
227 {
228  SamplesWithLabel samplesWithLabel;
229  if (this->HasValue(parameterName) && this->IsParameterEnabled(parameterName))
230  {
231  typename ListSampleType::Pointer input = ListSampleType::New();
232  typename TargetListSampleType::Pointer target = TargetListSampleType::New();
233  input->SetMeasurementVectorSize(m_FeaturesInfo.m_NbFeatures);
234 
235  std::vector<std::string> fileList = this->GetParameterStringList(parameterName);
236  for (unsigned int k = 0; k < fileList.size(); k++)
237  {
238  otbAppLogINFO("Reading vector file " << k + 1 << "/" << fileList.size());
240  ogr::Layer layer = source->GetLayer(static_cast<size_t>(this->GetParameterInt(parameterLayer)));
241  ogr::Feature feature = layer.ogr().GetNextFeature();
242  bool goesOn = feature.addr() != 0;
243  if (!goesOn)
244  {
245  otbAppLogWARNING("The layer " << this->GetParameterInt(parameterLayer) << " of " << fileList[k] << " is empty, input is skipped.");
246  continue;
247  }
248 
249  // Check all needed fields are present :
250  // - check class field if we use supervised classification or if class field name is not empty
251  int cFieldIndex = feature.ogr().GetFieldIndex(m_FeaturesInfo.m_SelectedCFieldName.c_str());
252  if (cFieldIndex < 0 && !m_FeaturesInfo.m_SelectedCFieldName.empty())
253  {
254  otbAppLogFATAL("The field name for class label (" << m_FeaturesInfo.m_SelectedCFieldName << ") has not been found in the vector file " << fileList[k]);
255  }
256 
257  // - check feature fields
258  std::vector<int> featureFieldIndex(m_FeaturesInfo.m_NbFeatures, -1);
259  for (unsigned int i = 0; i < m_FeaturesInfo.m_NbFeatures; i++)
260  {
261  featureFieldIndex[i] = feature.ogr().GetFieldIndex(m_FeaturesInfo.m_SelectedNames[i].c_str());
262  if (featureFieldIndex[i] < 0)
263  otbAppLogFATAL("The field name for feature " << m_FeaturesInfo.m_SelectedNames[i] << " has not been found in the vector file " << fileList[k]);
264  }
265 
266 
267  while (goesOn)
268  {
269  // Retrieve all the features for each field in the ogr layer.
270  MeasurementType mv;
271  mv.SetSize(m_FeaturesInfo.m_NbFeatures);
272  for (unsigned int idx = 0; idx < m_FeaturesInfo.m_NbFeatures; ++idx)
273  {
274  switch (feature[featureFieldIndex[idx]].GetType())
275  {
276  case OFTInteger:
277  mv[idx] = static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<int>());
278  break;
279  case OFTInteger64:
280  mv[idx] = static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<int>());
281  break;
282  case OFTReal:
283  mv[idx] = static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<double>());
284  break;
285  default:
286  itkExceptionMacro(<< "incorrect field type: " << feature[featureFieldIndex[idx]].GetType() << ".");
287  }
288  }
289 
290  input->PushBack(mv);
291 
292  if (cFieldIndex >= 0 && ogr::Field(feature, cFieldIndex).HasBeenSet())
293  {
294  switch (feature[cFieldIndex].GetType())
295  {
296  case OFTInteger:
297  target->PushBack(static_cast<ValueType>(feature[cFieldIndex].GetValue<int>()));
298  break;
299  case OFTInteger64:
300  target->PushBack(static_cast<ValueType>(feature[cFieldIndex].GetValue<int>()));
301  break;
302  case OFTReal:
303  target->PushBack(static_cast<ValueType>(feature[cFieldIndex].GetValue<double>()));
304  break;
305  case OFTString:
306  target->PushBack(static_cast<ValueType>(std::stod(feature[cFieldIndex].GetValue<std::string>())));
307  break;
308  default:
309  itkExceptionMacro(<< "incorrect field type: " << feature[featureFieldIndex[cFieldIndex]].GetType() << ".");
310  }
311  }
312  else
313  target->PushBack(0.);
314 
315  feature = layer.ogr().GetNextFeature();
316  goesOn = feature.addr() != 0;
317  }
318  }
319 
320 
321  typename ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New();
322  shiftScaleFilter->SetInput(input);
323  shiftScaleFilter->SetShifts(measurement.meanMeasurementVector);
324  shiftScaleFilter->SetScales(measurement.stddevMeasurementVector);
325  shiftScaleFilter->Update();
326 
327  samplesWithLabel.listSample = shiftScaleFilter->GetOutput();
328  samplesWithLabel.labeledListSample = target;
329  samplesWithLabel.listSample->DisconnectPipeline();
330  }
331 
332  return samplesWithLabel;
333 }
334 }
335 }
336 
337 #endif
otb::ogr::Feature::addr
OGRFeature const * addr() const
Definition: otbOGRFeatureWrapper.h:338
otbAppLogWARNING
#define otbAppLogWARNING(x)
Definition: otbWrapperMacros.h:45
otbAppLogFATAL
#define otbAppLogFATAL(x)
Definition: otbWrapperMacros.h:25
otb::Wrapper::ParameterType_Bool
@ ParameterType_Bool
Definition: otbWrapperTypes.h:60
otb::find
string_view find(string_view const &haystack, string_view const &needle)
Definition: otbStringUtilities.h:305
otb::Wrapper::ParameterType_OutputFilename
@ ParameterType_OutputFilename
Definition: otbWrapperTypes.h:45
otb::Wrapper::TrainVectorBase::ShiftScaleParameters::stddevMeasurementVector
MeasurementType stddevMeasurementVector
Definition: otbTrainVectorBase.h:76
otb::Wrapper::FieldParameter::TypeFilterType
std::vector< OGRFieldType > TypeFilterType
Definition: otbWrapperFieldParameter.h:51
otb::Wrapper::TrainVectorBase::DoUpdateParameters
void DoUpdateParameters() override
Definition: otbTrainVectorBase.hxx:101
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::Wrapper::ParameterType_Field
@ ParameterType_Field
Definition: otbWrapperTypes.h:61
otb::Wrapper::TrainVectorBase::ExtractSamplesWithLabel
SamplesWithLabel ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, const ShiftScaleParameters &measurement)
Definition: otbTrainVectorBase.hxx:225
otb::Wrapper::TrainVectorBase::SamplesWithLabel::listSample
ListSampleType::Pointer listSample
Definition: otbTrainVectorBase.h:83
otb::Wrapper::TrainVectorBase::ShiftScaleParameters
Definition: otbTrainVectorBase.h:72
otb::Wrapper::ParameterType_Group
@ ParameterType_Group
Definition: otbWrapperTypes.h:55
otb::Wrapper::TrainVectorBase::ValueType
double ValueType
Definition: otbTrainVectorBase.h:63
otb::Wrapper::ParameterType_Int
@ ParameterType_Int
Definition: otbWrapperTypes.h:38
otb::Wrapper::TrainVectorBase::ExtractAllSamples
virtual void ExtractAllSamples(const ShiftScaleParameters &measurement)
Definition: otbTrainVectorBase.hxx:159
otb::Wrapper::TrainVectorBase::SamplesWithLabel
Definition: otbTrainVectorBase.h:80
otbTrainVectorBase.h
otbAppLogINFO
#define otbAppLogINFO(x)
Definition: otbWrapperMacros.h:52
otb::ogr::Layer::ogr
OGRLayer & ogr()
otb::Wrapper::TrainVectorBase::GetStatistics
ShiftScaleParameters GetStatistics(unsigned int nbFeatures)
Definition: otbTrainVectorBase.hxx:202
otb::Statistics::ShiftScaleSampleListFilter::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbShiftScaleSampleListFilter.h:57
otb::ogr::Field
Encapsulation of OGRField Instances of Field are expected to be built from an existing Feature with w...
Definition: otbOGRFieldWrapper.h:117
otb::Wrapper::TrainVectorBase::DoExecute
void DoExecute() override
Definition: otbTrainVectorBase.hxx:140
otb::Wrapper::ParameterType_InputVectorDataList
@ ParameterType_InputVectorDataList
Definition: otbWrapperTypes.h:51
otb::Wrapper::TrainVectorBase::ExtractClassificationSamplesWithLabel
virtual SamplesWithLabel ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement)
Definition: otbTrainVectorBase.hxx:174
otb::ogr::DataSource::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbOGRDataSourceWrapper.h:90
otb::Wrapper::TrainVectorBase::DoInit
void DoInit() override
Definition: otbTrainVectorBase.hxx:31
otb::ogr::DataSource::Modes::Read
@ Read
Open data source in read-only mode.
Definition: otbOGRDataSourceWrapper.h:126
otb::ogr::Feature
Geometric object with descriptive fields.
Definition: otbOGRFeatureWrapper.h:63
otb::StatisticsXMLFileReader::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbStatisticsXMLFileReader.h:48
otb::Wrapper::TrainVectorBase::ExtractTrainingSamplesWithLabel
virtual SamplesWithLabel ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement)
Definition: otbTrainVectorBase.hxx:167
otb::ogr::Layer
Layer of geometric objects.
Definition: otbOGRLayerWrapper.h:80
otb::Wrapper::MetaDataHelper::GetType
OTBApplicationEngine_EXPORT MDType GetType(const std::string &val)
otb::Wrapper::TrainVectorBase::ShiftScaleParameters::meanMeasurementVector
MeasurementType meanMeasurementVector
Definition: otbTrainVectorBase.h:75
otb::ogr::Feature::ogr
OGRFeature & ogr() const
Definition: otbOGRFeatureWrapper.hxx:158
otb::ogr::DataSource::New
static Pointer New()
otb::Wrapper::TrainVectorBase::MeasurementType
itk::VariableLengthVector< ValueType > MeasurementType
Definition: otbTrainVectorBase.h:64
otb::Wrapper::ParameterType_InputFilename
@ ParameterType_InputFilename
Definition: otbWrapperTypes.h:43
otb::Wrapper::TrainVectorBase::SamplesWithLabel::labeledListSample
TargetListSampleType::Pointer labeledListSample
Definition: otbTrainVectorBase.h:84