OTB  9.0.0
Orfeo Toolbox
otbVectorPrediction.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbVectorPrediction_hxx
22 #define otbVectorPrediction_hxx
23 
24 #include "otbVectorPrediction.h"
25 
26 namespace otb
27 {
28 namespace Wrapper
29 {
30 
// DoInit(): declares the application parameters by delegating to
// DoInitSpecialization() (one implementation per RegressionMode, not shown here),
// then checks in assert-enabled builds that the parameter keys used by the rest
// of this file were actually declared.
31 template <bool RegressionMode>
33 {
34  DoInitSpecialization();
35 
36  // Assert that all the needed parameters have been defined in DoInitSpecialization
37  assert(GetParameterByKey("in") != nullptr);
38  assert(GetParameterByKey("instat") != nullptr);
39  assert(GetParameterByKey("model") != nullptr);
40  assert(GetParameterByKey("cfield") != nullptr);
41  assert(GetParameterByKey("feat") != nullptr);
42  assert(GetParameterByKey("out") != nullptr);
43 }
44 
// DoUpdateParameters(): when "in" has a value, rebuilds the "feat" choice list
// from the fields of layer 0 of the input data source (ogrDS). A field becomes a
// choice if the "feat" type filter is empty or contains the field's OGR type.
45 template <bool RegressionMode>
47 {
48  if (HasValue("in"))
49  {
50  auto shapefileName = GetParameterString("in");
51 
53  auto layer = ogrDS->GetLayer(0);
54  OGRFeatureDefn& layerDefn = layer.GetLayerDefn();
55 
56  ClearChoices("feat");
57 
58  FieldParameter::TypeFilterType typeFilter = GetTypeFilter("feat");
59  for (int iField = 0; iField < layerDefn.GetFieldCount(); iField++)
60  {
61  auto fieldDefn = layerDefn.GetFieldDefn(iField);
62  std::string item = fieldDefn->GetNameRef();
63  std::string key(item);
// Choice key = field name with every non-alphanumeric character removed, then lower-cased.
// NOTE(review): std::isalnum / tolower receive a plain char; a negative value (non-ASCII
// name on a signed-char platform) is undefined behavior - consider casting to unsigned char.
// NOTE(review): names differing only by case/punctuation collapse to the same key - confirm
// how AddChoice handles duplicates.
64  key.erase(std::remove_if(key.begin(), key.end(), [](char c) { return !std::isalnum(c); }), key.end());
65  std::transform(key.begin(), key.end(), key.begin(), tolower);
66  auto fieldType = fieldDefn->GetType();
67 
68  if (typeFilter.empty() || std::find(typeFilter.begin(), typeFilter.end(), fieldType) != std::end(typeFilter))
69  {
// The choice is keyed "feat.<key>" but labelled with the original field name.
70  std::string tmpKey = "feat." + key;
71  AddChoice(tmpKey, item);
72  }
73  }
74  }
75 }
76 
// ReadInputListSample(layer): builds the samples to predict. It produces one
// measurement vector per feature of `layer`. Each vector holds the values of the
// fields selected in "feat", in selection order, converted to ValueType.
77 template <bool RegressionMode>
81 {
82  typename ListSampleType::Pointer input = ListSampleType::New();
83 
84  const auto nbFeatures = GetSelectedItems("feat").size();
85  input->SetMeasurementVectorSize(nbFeatures);
86  std::vector<int> featureFieldIndex(nbFeatures, -1);
87 
// Resolve each selected field name to a field index once, using the first feature.
// NOTE(review): layer.cbegin() is dereferenced without checking that the layer has
// at least one feature.
88  ogr::Layer::const_iterator it_feat = layer.cbegin();
89  for (unsigned int i = 0; i < nbFeatures; i++)
90  {
91  try
92  {
93  featureFieldIndex[i] = (*it_feat).GetFieldIndex(GetChoiceNames("feat")[GetSelectedItems("feat")[i]]);
94  }
95  catch(...)
96  {
// Any exception raised by the lookup is reported as a missing field.
97  otbAppLogFATAL("The field name for feature " << GetChoiceNames("feat")[GetSelectedItems("feat")[i]] << " has not been found" << std::endl);
98  }
99  }
100 
101  for (auto const& feature : layer)
102  {
103  MeasurementType mv(nbFeatures);
104  for (unsigned int idx = 0; idx < nbFeatures; ++idx)
105  {
106  auto field = feature[featureFieldIndex[idx]];
107  switch (field.GetType())
108  {
109  case OFTInteger:
110  case OFTInteger64:
// NOTE(review): 64-bit integer fields are read through GetValue<int>; values outside
// the int range may be truncated - confirm.
111  mv[idx] = static_cast<ValueType>(field.template GetValue<int>());
112  break;
113  case OFTReal:
114  mv[idx] = static_cast<ValueType>(field.template GetValue<double>());
115  break;
116  default:
// Only integer and real fields are accepted as features.
117  itkExceptionMacro(<< "incorrect field type: " << field.GetType() << ".");
118  }
119  }
120  input->PushBack(mv);
121  }
122  return input;
123 }
124 
125 
// NormalizeListSample(input): passes `input` through a ShiftScaleFilter and
// returns the filter's output.
// - If "instat" has a value and is enabled, the shifts/scales are the "mean" and
//   "stddev" vectors read from that statistics XML file.
// - Otherwise shift = 0 and scale = 1 for every selected feature.
// NOTE(review): the sizes of the vectors read from "instat" are not checked
// against the number of selected features.
126 template <bool RegressionMode>
128 {
129  const int nbFeatures = GetSelectedItems("feat").size();
130 
131  // Statistics for shift/scale
132  MeasurementType meanMeasurementVector;
133  MeasurementType stddevMeasurementVector;
134  if (HasValue("instat") && IsParameterEnabled("instat"))
135  {
136  typename StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
137  std::string XMLfile = GetParameterString("instat");
138  statisticsReader->SetFileName(XMLfile);
139  meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
140  stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
141  }
142  else
143  {
// No statistics file: identity normalization.
144  meanMeasurementVector.SetSize(nbFeatures);
145  meanMeasurementVector.Fill(0.);
146  stddevMeasurementVector.SetSize(nbFeatures);
147  stddevMeasurementVector.Fill(1.);
148  }
149 
150  typename ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
151  trainingShiftScaleFilter->SetInput(input);
152  trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
153  trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
154  trainingShiftScaleFilter->Update();
155  otbAppLogINFO("mean used: " << meanMeasurementVector);
156  otbAppLogINFO("standard deviation used: " << stddevMeasurementVector);
157 
// NOTE(review): DoExecute loads the model before calling this function, so this
// message is logged after "Model loaded" and looks misplaced.
158  otbAppLogINFO("Loading model");
159 
160  return trainingShiftScaleFilter->GetOutput();
161 }
162 
163 
// ReopenDataSourceInUpdateMode(source, layer, buffer): used when predictions are
// written back into the input data source.
// - Copies `layer` into `buffer` under the name "Buffer".
// - Rebinds the caller's `layer` (passed by reference) to that copy, so the input
//   features remain readable after `source` is cleared.
// - Returns `output`, presumably the input data source reopened for update
//   (confirm against the reopening statement).
// The caller must keep `buffer` alive for as long as `layer` is used.
164 template <bool RegressionMode>
167 {
169  // Update mode
170  otbAppLogINFO("Update input vector data.");
171  // fill temporary buffer for the transfer
172  otb::ogr::Layer inputLayer = layer;
173  layer = buffer->CopyLayer(inputLayer, std::string("Buffer"));
174  // close input data source
175  source->Clear();
176  // Re-open input data source in update mode
178  return output;
179 }
180 
// CreateOutputDataSource(layer): sets up a fresh output when "out" is used.
// - Creates (overwriting) the data source named by "out".
// - Creates one layer, also named after "out", with the spatial reference and
//   geometry type of `layer`.
// - Copies every field definition of `layer` into the new layer. Features are
//   not copied here.
// Returns the new data source.
181 template <bool RegressionMode>
183 {
185  // Create new OGRDataSource
186  output = ogr::DataSource::New(GetParameterString("out"), ogr::DataSource::Modes::Overwrite);
// const_cast: GetSpatialRef() returns a const pointer.
187  otb::ogr::Layer newLayer = output->CreateLayer(GetParameterString("out"), const_cast<OGRSpatialReference*>(layer.GetSpatialRef()), layer.GetGeomType());
188  // Copy existing fields
189  OGRFeatureDefn& inLayerDefn = layer.GetLayerDefn();
190  for (int k = 0; k < inLayerDefn.GetFieldCount(); k++)
191  {
192  OGRFieldDefn fieldDefn(inLayerDefn.GetFieldDefn(k));
193  newLayer.CreateField(fieldDefn);
194  }
195  return output;
196 }
197 
198 template <bool RegressionMode>
199 void VectorPrediction<RegressionMode>::AddPredictionField(otb::ogr::Layer& outLayer, otb::ogr::Layer const& layer, bool computeConfidenceMap)
200 {
201  OGRFeatureDefn& layerDefn = layer.GetLayerDefn();
202 
203  const OGRFieldType labelType = RegressionMode ? OFTReal : OFTInteger;
204 
205  int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str());
206  if (idx >= 0)
207  {
208  if (layerDefn.GetFieldDefn(idx)->GetType() != labelType)
209  itkExceptionMacro("Field name " << GetParameterString("cfield") << " already exists with a different type!");
210  }
211  else
212  {
213  OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), labelType);
214  ogr::FieldDefn predictedFieldDef(predictedField);
215  outLayer.CreateField(predictedFieldDef);
216  }
217 
218  // Add confidence field in the output layer
219  if (computeConfidenceMap)
220  {
221  idx = layerDefn.GetFieldIndex(confFieldName.c_str());
222  if (idx >= 0)
223  {
224  if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal)
225  itkExceptionMacro("Field name " << confFieldName << " already exists with a different type!");
226  }
227  else
228  {
229  OGRFieldDefn confidenceField(confFieldName.c_str(), OFTReal);
230  confidenceField.SetWidth(confidenceField.GetWidth());
231  confidenceField.SetPrecision(confidenceField.GetPrecision());
232  ogr::FieldDefn confFieldDefn(confidenceField);
233  outLayer.CreateField(confFieldDefn);
234  }
235  }
236 }
237 
238 template <bool RegressionMode>
239 void VectorPrediction<RegressionMode>::FillOutputLayer(otb::ogr::Layer& outLayer, otb::ogr::Layer const& layer, typename LabelListSampleType::Pointer target,
240  typename ConfidenceListSampleType::Pointer quality, bool updateMode, bool computeConfidenceMap)
241 {
242  unsigned int count = 0;
243  std::string classfieldname = GetParameterString("cfield");
244  for (auto const& feature : layer)
245  {
246  ogr::Feature dstFeature(outLayer.GetLayerDefn());
247  dstFeature.SetFrom(feature, TRUE);
248  dstFeature.SetFID(feature.GetFID());
249  auto field = dstFeature[classfieldname];
250  switch (field.GetType())
251  {
252  case OFTInteger64:
253  case OFTInteger:
254  field.template SetValue<int>(target->GetMeasurementVector(count)[0]);
255  break;
256  case OFTReal:
257  field.template SetValue<double>(target->GetMeasurementVector(count)[0]);
258  break;
259  case OFTString:
260  field.template SetValue<std::string>(std::to_string(target->GetMeasurementVector(count)[0]));
261  break;
262  default:
263  itkExceptionMacro(<< "incorrect field type: " << field.GetType() << ".");
264  }
265  if (computeConfidenceMap)
266  dstFeature[confFieldName].template SetValue<double>(quality->GetMeasurementVector(count)[0]);
267  if (updateMode)
268  {
269  outLayer.SetFeature(dstFeature);
270  }
271  else
272  {
273  outLayer.CreateFeature(dstFeature);
274  }
275  count++;
276  }
277 }
278 
// DoExecute(): runs the prediction pipeline.
//  1. Loads the model file given by "model". A fatal error is raised if the
//     factory cannot create a model for it. The model is switched to
//     RegressionMode before loading.
//  2. Reads layer 0 of the input data source, builds the samples from the "feat"
//     fields and normalizes them.
//  3. Predicts labels, plus confidence values if shouldComputeConfidenceMap().
//  4. Writes the results, either into a new data source ("out" enabled and set)
//     or back into the input data source (update mode).
//  5. Adds the output fields, fills the features and commits the transaction
//     when the layer reports the "Transactions" capability.
//  6. Syncs the output to disk.
279 template <bool RegressionMode>
281 {
282  m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"), MachineLearningModelFactoryType::ReadMode);
283 
284  if (m_Model.IsNull())
285  {
286  otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
287  }
288 
289  m_Model->SetRegressionMode(RegressionMode);
290 
291  m_Model->Load(GetParameterString("model"));
292  otbAppLogINFO("Model loaded");
293 
294  auto shapefileName = GetParameterString("in");
295 
297  auto layer = source->GetLayer(0);
298  auto input = ReadInputListSample(layer);
299 
300  ListSampleType::Pointer listSample = NormalizeListSample(input);
301  typename LabelListSampleType::Pointer target;
302 
303  // The quality listSample containing confidence values is defined here, but is only used when
304  // computeConfidenceMap evaluates to true. This listSample is also used in FillOutputLayer(...)
305  const bool computeConfidenceMap = shouldComputeConfidenceMap();
306  typename ConfidenceListSampleType::Pointer quality;
307 
308  if (computeConfidenceMap)
309  {
310  quality = ConfidenceListSampleType::New();
311  target = m_Model->PredictBatch(listSample, quality);
312  }
313  else
314  {
315  target = m_Model->PredictBatch(listSample);
316  }
317 
// Update mode (write back into the input) unless "out" is both enabled and set.
318  const bool updateMode = !(IsParameterEnabled("out") && HasValue("out"));
319 
322 
323  if (updateMode)
324  {
325  // in update mode, output is added to input data source.
326  // buffer needs to be allocated here, as its life-cycle is bound to "layer"
327  buffer = ogr::DataSource::New();
328  output = ReopenDataSourceInUpdateMode(source, layer, buffer);
329  }
330  else
331  {
332  output = CreateOutputDataSource(layer);
333  }
334 
335  otb::ogr::Layer outLayer = output->GetLayer(0);
336 
// NOTE(review): the transaction is started unconditionally, while the commit below is
// guarded by TestCapability("Transactions"); presumably StartTransaction is a no-op
// on drivers without transaction support - confirm.
337  OGRErr errStart = outLayer.ogr().StartTransaction();
338  if (errStart != OGRERR_NONE)
339  {
340  itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << ".");
341  }
342 
343  AddPredictionField(outLayer, layer, computeConfidenceMap);
344  FillOutputLayer(outLayer, layer, target, quality, updateMode, computeConfidenceMap);
345 
346  if (outLayer.ogr().TestCapability("Transactions"))
347  {
348  const OGRErr errCommitX = outLayer.ogr().CommitTransaction();
349  if (errCommitX != OGRERR_NONE)
350  {
351  itkExceptionMacro(<< "Unable to commit transaction for OGR layer " << outLayer.ogr().GetName() << ".");
352  }
353  }
354 
355  output->SyncToDisk();
356 }
357 
358 } // end namespace wrapper
359 } // end namespace otb
360 
361 #endif
otb::ogr::Layer::SetFeature
void SetFeature(Feature feature)
otb::ogr::FieldDefn
Encapsulation of OGRFieldDefn: field definition.
Definition: otbOGRFieldWrapper.h:60
otb::ogr::Layer::GetSpatialRef
OGRSpatialReference const * GetSpatialRef() const
otb::ogr::Layer::GetGeomType
OGRwkbGeometryType GetGeomType() const
otb::Wrapper::VectorPrediction::ReadInputListSample
ListSampleType::Pointer ReadInputListSample(otb::ogr::Layer const &layer)
Definition: otbVectorPrediction.hxx:80
otb::ogr::Layer::cbegin
const_iterator cbegin() const
otbAppLogFATAL
#define otbAppLogFATAL(x)
Definition: otbWrapperMacros.h:25
otb::find
string_view find(string_view const &haystack, string_view const &needle)
Definition: otbStringUtilities.h:305
otb::Wrapper::VectorPrediction::DoInit
void DoInit() override
Definition: otbVectorPrediction.hxx:32
otb::Wrapper::VectorPrediction
Definition: otbVectorPrediction.h:48
otb::Wrapper::VectorPrediction::FillOutputLayer
void FillOutputLayer(otb::ogr::Layer &outLayer, otb::ogr::Layer const &layer, typename LabelListSampleType::Pointer target, typename ConfidenceListSampleType::Pointer quality, bool updateMode, bool computeConfidenceMap)
Definition: otbVectorPrediction.hxx:239
otb::ogr::Feature::SetFrom
void SetFrom(Feature const &rhs, int *map, bool mustForgive=true)
Definition: otbOGRFeatureWrapper.hxx:51
otb::Wrapper::FieldParameter::TypeFilterType
std::vector< OGRFieldType > TypeFilterType
Definition: otbWrapperFieldParameter.h:51
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otbVectorPrediction.h
otb::ogr::Layer::CreateFeature
void CreateFeature(Feature feature)
otb::Wrapper::VectorPrediction::DoUpdateParameters
void DoUpdateParameters() override
Definition: otbVectorPrediction.hxx:46
otb::Wrapper::VectorPrediction::DoExecute
void DoExecute() override
Definition: otbVectorPrediction.hxx:280
otb::Wrapper::VectorPrediction::AddPredictionField
void AddPredictionField(otb::ogr::Layer &outLayer, otb::ogr::Layer const &layer, bool computeConfidenceMap)
Definition: otbVectorPrediction.hxx:199
otbAppLogINFO
#define otbAppLogINFO(x)
Definition: otbWrapperMacros.h:52
otb::ogr::DataSource::Modes::Overwrite
@ Overwrite
Definition: otbOGRDataSourceWrapper.h:127
otb::Wrapper::VectorPrediction::ReopenDataSourceInUpdateMode
otb::ogr::DataSource::Pointer ReopenDataSourceInUpdateMode(ogr::DataSource::Pointer source, ogr::Layer &layer, ogr::DataSource::Pointer buffer)
Definition: otbVectorPrediction.hxx:165
otb::ogr::Layer::ogr
OGRLayer & ogr()
otb::ogr::Layer::feature_iter
Implementation class for Feature iterator. This iterator is a single pass iterator....
Definition: otbOGRLayerWrapper.h:348
otb::Statistics::ShiftScaleSampleListFilter::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbShiftScaleSampleListFilter.h:57
otb::ogr::Layer::CreateField
void CreateField(FieldDefn const &field, bool bApproxOK=true)
otb::ogr::DataSource::Modes::Update_LayerUpdate
@ Update_LayerUpdate
Definition: otbOGRDataSourceWrapper.h:138
otb::ogr::DataSource::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbOGRDataSourceWrapper.h:90
otb::Wrapper::VectorPrediction::NormalizeListSample
ListSampleType::Pointer NormalizeListSample(ListSampleType::Pointer input)
Definition: otbVectorPrediction.hxx:127
otb::ogr::DataSource::Modes::Read
@ Read
Open data source in read-only mode.
Definition: otbOGRDataSourceWrapper.h:126
otb::ogr::Feature::SetFID
void SetFID(long fid)
Definition: otbOGRFeatureWrapper.hxx:109
otb::Wrapper::VectorPrediction::MeasurementType
itk::VariableLengthVector< ValueType > MeasurementType
Definition: otbVectorPrediction.h:76
otb::ogr::Feature
Geometric object with descriptive fields.
Definition: otbOGRFeatureWrapper.h:63
otb::Wrapper::VectorPrediction::ValueType
float ValueType
Definition: otbVectorPrediction.h:63
otb::StatisticsXMLFileReader::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbStatisticsXMLFileReader.h:48
otb::ogr::Layer
Layer of geometric objects.
Definition: otbOGRLayerWrapper.h:80
otb::Wrapper::VectorPrediction::CreateOutputDataSource
otb::ogr::DataSource::Pointer CreateOutputDataSource(ogr::Layer &layer)
Definition: otbVectorPrediction.hxx:182
otb::ogr::DataSource::New
static Pointer New()
otb::ogr::Layer::GetLayerDefn
OGRFeatureDefn & GetLayerDefn() const