OTB  9.0.0
Orfeo Toolbox
otbListSampleGenerator.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbListSampleGenerator_hxx
22 #define otbListSampleGenerator_hxx
23 
24 #include "otbListSampleGenerator.h"
25 
26 #include "itkImageRegionConstIteratorWithIndex.h"
28 
29 #include "otbMacro.h"
30 
31 namespace otb
32 {
33 
34 /*template <class TVectorData>
35 void printVectorData(TVectorData * vectorData, string msg = "")
36 {
37  typedef TVectorData VectorDataType;
38  typedef itk::PreOrderTreeIterator<typename VectorDataType::DataTreeType> TreeIteratorType;
39 
40  TreeIteratorType itVector(vectorData->GetDataTree());
41  itVector.GoToBegin();
42 
43  if (!msg.empty())
44  {
45  std::cout<< msg << std::endl;
46  }
47 
48  while (!itVector.IsAtEnd())
49  {
50  if (itVector.Get()->IsPolygonFeature())
51  {
52  std::cout << itVector.Get()->GetNodeTypeAsString() << std::endl;
53  for (unsigned int itPoints = 0; itPoints < itVector.Get()->GetPolygonExteriorRing()->GetVertexList()->Size(); itPoints++)
54  {
55  std::cout << "vertex[" << itPoints << "]: " << itVector.Get()->GetPolygonExteriorRing()->GetVertexList()->GetElement(itPoints) <<std::endl;
56  }
57  std::cout << "Polygon bounding region:\n" << itVector.Get()->GetPolygonExteriorRing()->GetBoundingRegion() << std::endl;
58  }
59  ++itVector;
60  }
61 }*/
62 
63 template <class TImage, class TVectorData>
65  : m_MaxTrainingSize(-1),
66  m_MaxValidationSize(-1),
67  m_ValidationTrainingProportion(0.0),
68  m_BoundByMin(true),
69  m_PolygonEdgeInclusion(false),
70  m_NumberOfClasses(0),
71  m_ClassKey("Class"),
72  m_ClassMinSize(-1)
73 {
74  this->SetNumberOfRequiredInputs(2);
75  this->SetNumberOfRequiredOutputs(4);
76 
77  // Register the outputs
78  this->itk::ProcessObject::SetNthOutput(0, this->MakeOutput(0).GetPointer());
79  this->itk::ProcessObject::SetNthOutput(1, this->MakeOutput(1).GetPointer());
80  this->itk::ProcessObject::SetNthOutput(2, this->MakeOutput(2).GetPointer());
81  this->itk::ProcessObject::SetNthOutput(3, this->MakeOutput(3).GetPointer());
82 
83  m_RandomGenerator = RandomGeneratorType::GetInstance();
84 }
85 
86 template <class TImage, class TVectorData>
88 {
89  this->ProcessObject::SetNthInput(0, const_cast<ImageType*>(image));
90 }
91 
92 template <class TImage, class TVectorData>
94 {
95  if (this->GetNumberOfInputs() < 1)
96  {
97  return nullptr;
98  }
99 
100  return static_cast<const ImageType*>(this->ProcessObject::GetInput(0));
101 }
102 
103 template <class TImage, class TVectorData>
105 {
106  this->ProcessObject::SetNthInput(1, const_cast<VectorDataType*>(vectorData));
107 
108  // printVectorData(vectorData);
109 }
110 
111 template <class TImage, class TVectorData>
113 {
114  if (this->GetNumberOfInputs() < 2)
115  {
116  return nullptr;
117  }
118 
119  return static_cast<const VectorDataType*>(this->ProcessObject::GetInput(1));
120 }
121 
122 template <class TImage, class TVectorData>
124 {
125  DataObjectPointer output;
126  switch (idx)
127  {
128  case 0:
129  output = static_cast<itk::DataObject*>(ListSampleType::New().GetPointer());
130  break;
131  case 1:
132  output = static_cast<itk::DataObject*>(ListLabelType::New().GetPointer());
133  break;
134  case 2:
135  output = static_cast<itk::DataObject*>(ListSampleType::New().GetPointer());
136  break;
137  case 3:
138  output = static_cast<itk::DataObject*>(ListLabelType::New().GetPointer());
139  break;
140  default:
141  output = static_cast<itk::DataObject*>(ListSampleType::New().GetPointer());
142  break;
143  }
144  return output;
145 }
146 // Get the Training ListSample
147 template <class TImage, class TVectorData>
149 {
150  return dynamic_cast<ListSampleType*>(this->itk::ProcessObject::GetOutput(0));
151 }
152 // Get the Training label ListSample
153 template <class TImage, class TVectorData>
155 {
156  return dynamic_cast<ListLabelType*>(this->itk::ProcessObject::GetOutput(1));
157 }
158 
159 // Get the validation ListSample
160 template <class TImage, class TVectorData>
162 {
163  return dynamic_cast<ListSampleType*>(this->itk::ProcessObject::GetOutput(2));
164 }
165 
166 
167 // Get the validation label ListSample
168 template <class TImage, class TVectorData>
170 {
171  return dynamic_cast<ListLabelType*>(this->itk::ProcessObject::GetOutput(3));
172 }
173 
174 template <class TImage, class TVectorData>
176 {
177  ImagePointerType img = static_cast<ImageType*>(this->ProcessObject::GetInput(0));
178 
179  if (img.IsNotNull())
180  {
181 
182  // Requested regions will be generated during GenerateData
183  // call. For now request an empty region so as to avoid requesting
184  // the largest possible region (fixes bug #943 )
185  typename ImageType::RegionType dummyRegion;
186  typename ImageType::SizeType dummySize;
187  dummySize.Fill(0);
188  dummyRegion.SetSize(dummySize);
189  img->SetRequestedRegion(dummyRegion);
190  }
191 }
192 
193 
194 template <class TImage, class TVectorData>
196 {
197  // Get the inputs
198  ImagePointerType image = const_cast<ImageType*>(this->GetInput());
199  VectorDataPointerType vectorData = const_cast<VectorDataType*>(this->GetInputVectorData());
200 
201  // Get the outputs
202  ListSamplePointerType trainingListSample = this->GetTrainingListSample();
203  ListLabelPointerType trainingListLabel = this->GetTrainingListLabel();
204  ListSamplePointerType validationListSample = this->GetValidationListSample();
205  ListLabelPointerType validationListLabel = this->GetValidationListLabel();
206 
207  // Gather some information about the relative size of the classes
208  // We would like to have the same number of samples per class
209  this->GenerateClassStatistics();
210 
211  this->ComputeClassSelectionProbability();
212 
213  // Clear the sample lists
214  trainingListSample->Clear();
215  trainingListLabel->Clear();
216  validationListSample->Clear();
217  validationListLabel->Clear();
218 
219  // Set MeasurementVectorSize for each sample list
220  trainingListSample->SetMeasurementVectorSize(image->GetNumberOfComponentsPerPixel());
221  // stores label as integers,so put the size to 1
222  trainingListLabel->SetMeasurementVectorSize(1);
223  validationListSample->SetMeasurementVectorSize(image->GetNumberOfComponentsPerPixel());
224  // stores label as integers,so put the size to 1
225  validationListLabel->SetMeasurementVectorSize(1);
226 
227  m_ClassesSamplesNumberTraining.clear();
228  m_ClassesSamplesNumberValidation.clear();
229 
230  typename ImageType::RegionType imageLargestRegion = image->GetLargestPossibleRegion();
231 
232  TreeIteratorType itVector(vectorData->GetDataTree());
233  for (itVector.GoToBegin(); !itVector.IsAtEnd(); ++itVector)
234  {
235  if (itVector.Get()->IsPolygonFeature())
236  {
237  PolygonPointerType exteriorRing = itVector.Get()->GetPolygonExteriorRing();
238 
239  typename ImageType::RegionType polygonRegion = otb::TransformPhysicalRegionToIndexRegion(exteriorRing->GetBoundingRegion(), image.GetPointer());
240 
241  const bool hasIntersection = polygonRegion.Crop(imageLargestRegion);
242  if (!hasIntersection)
243  {
244  continue;
245  }
246 
247  image->SetRequestedRegion(polygonRegion);
248  image->PropagateRequestedRegion();
249  image->UpdateOutputData();
250 
251  typedef itk::ImageRegionConstIteratorWithIndex<ImageType> IteratorType;
252  IteratorType it(image, polygonRegion);
253 
254  for (it.GoToBegin(); !it.IsAtEnd(); ++it)
255  {
256  itk::ContinuousIndex<double, 2> point;
257  image->TransformIndexToPhysicalPoint(it.GetIndex(), point);
258 
259  if (exteriorRing->IsInside(point) || (this->GetPolygonEdgeInclusion() && exteriorRing->IsOnEdge(point)))
260  {
261  PolygonListPointerType interiorRings = itVector.Get()->GetPolygonInteriorRings();
262 
263  bool isInsideInteriorRing = false;
264  for (typename PolygonListType::Iterator interiorRing = interiorRings->Begin(); interiorRing != interiorRings->End(); ++interiorRing)
265  {
266  if (interiorRing.Get()->IsInside(point) || (this->GetPolygonEdgeInclusion() && interiorRing.Get()->IsOnEdge(point)))
267  {
268  isInsideInteriorRing = true;
269  break;
270  }
271  }
272  if (isInsideInteriorRing)
273  {
274  continue; // skip this pixel and continue
275  }
276 
277  double randomValue = m_RandomGenerator->GetUniformVariate(0.0, 1.0);
278  if (randomValue < m_ClassesProbTraining[itVector.Get()->GetFieldAsInt(m_ClassKey)])
279  {
280  // Add the sample to the training list
281  trainingListSample->PushBack(it.Get());
282  trainingListLabel->PushBack(itVector.Get()->GetFieldAsInt(m_ClassKey));
283  m_ClassesSamplesNumberTraining[itVector.Get()->GetFieldAsInt(m_ClassKey)] += 1;
284  }
285  else if (randomValue <
286  m_ClassesProbTraining[itVector.Get()->GetFieldAsInt(m_ClassKey)] + m_ClassesProbValidation[itVector.Get()->GetFieldAsInt(m_ClassKey)])
287  {
288  // Add the sample to the validation list
289  validationListSample->PushBack(it.Get());
290  validationListLabel->PushBack(itVector.Get()->GetFieldAsInt(m_ClassKey));
291  m_ClassesSamplesNumberValidation[itVector.Get()->GetFieldAsInt(m_ClassKey)] += 1;
292  }
293  // Note: some samples may not be used at all
294  }
295  }
296  }
297  }
298 
299  assert(trainingListSample->Size() == trainingListLabel->Size());
300  assert(validationListSample->Size() == validationListLabel->Size());
301  this->UpdateProgress(1.0f);
302 }
303 
304 template <class TImage, class TVectorData>
306 {
307  m_ClassesSize.clear();
308 
309  ImageType* image = const_cast<ImageType*>(this->GetInput());
310  typename VectorDataType::ConstPointer vectorData = this->GetInputVectorData();
311 
312  // Compute cumulative area of all polygons of each class
313  TreeIteratorType itVector(vectorData->GetDataTree());
314  for (itVector.GoToBegin(); !itVector.IsAtEnd(); ++itVector)
315  {
316  DataNodeType* datanode = itVector.Get();
317  if (datanode->IsPolygonFeature())
318  {
319  double area = GetPolygonAreaInPixelsUnits(datanode, image);
320  m_ClassesSize[datanode->GetFieldAsInt(m_ClassKey)] += area;
321  }
322  }
323  m_NumberOfClasses = m_ClassesSize.size();
324 }
325 
326 template <class TImage, class TVectorData>
328 {
329  m_ClassesProbTraining.clear();
330  m_ClassesProbValidation.clear();
331 
332  // Sanity check
333  if (m_ClassesSize.empty())
334  {
335  itkGenericExceptionMacro(<< "No training sample found inside image");
336  }
337 
338  // Go through the classes size to find the smallest one
339  double minSize = itk::NumericTraits<double>::max();
340  for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
341  {
342  if (minSize > itmap->second)
343  {
344  minSize = itmap->second;
345  }
346  }
347 
348  // Apply the proportion between training and validation samples (all training by default)
349  double minSizeTraining = minSize * (1.0 - m_ValidationTrainingProportion);
350  double minSizeValidation = minSize * m_ValidationTrainingProportion;
351 
352  // Apply the limit if specified by the user
353  if (m_BoundByMin)
354  {
355  if ((m_MaxTrainingSize != -1) && (m_MaxTrainingSize < minSizeTraining))
356  {
357  minSizeTraining = m_MaxTrainingSize;
358  }
359  if ((m_MaxValidationSize != -1) && (m_MaxValidationSize < minSizeValidation))
360  {
361  minSizeValidation = m_MaxValidationSize;
362  }
363  }
364  // Compute the probability selection for each class
365  for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
366  {
367  if (m_BoundByMin)
368  {
369  m_ClassesProbTraining[itmap->first] = minSizeTraining / itmap->second;
370  m_ClassesProbValidation[itmap->first] = minSizeValidation / itmap->second;
371  }
372  else
373  {
374  long int maxSizeT = (itmap->second) * (1.0 - m_ValidationTrainingProportion);
375  long int maxSizeV = (itmap->second) * m_ValidationTrainingProportion;
376 
377  // Check if max sizes respect the maximum bounds
378  double correctionRatioTrain = 1.0;
379  if ((m_MaxTrainingSize > -1) && (m_MaxTrainingSize < maxSizeT))
380  {
381  correctionRatioTrain = (double)(m_MaxTrainingSize) / (double)(maxSizeT);
382  }
383  double correctionRatioValid = 1.0;
384  if ((m_MaxValidationSize > -1) && (m_MaxValidationSize < maxSizeV))
385  {
386  correctionRatioValid = (double)(m_MaxValidationSize) / (double)(maxSizeV);
387  }
388  double correctionRatio = std::min(correctionRatioTrain, correctionRatioValid);
389  m_ClassesProbTraining[itmap->first] = correctionRatio * (1.0 - m_ValidationTrainingProportion);
390  m_ClassesProbValidation[itmap->first] = correctionRatio * m_ValidationTrainingProportion;
391  }
392  }
393 }
394 template <class TImage, class TVectorData>
396 {
397  const double pixelArea = std::abs(image->GetSignedSpacing()[0] * image->GetSignedSpacing()[1]);
398 
399  // Compute area of exterior ring in pixels
400  PolygonPointerType exteriorRing = polygonDataNode->GetPolygonExteriorRing();
401  double area = exteriorRing->GetArea() / pixelArea;
402 
403  // Remove contribution of all interior rings
404  PolygonListPointerType interiorRings = polygonDataNode->GetPolygonInteriorRings();
405  for (typename PolygonListType::Iterator interiorRing = interiorRings->Begin(); interiorRing != interiorRings->End(); ++interiorRing)
406  {
407  area -= interiorRing.Get()->GetArea() / pixelArea;
408  }
409 
410  return area;
411 }
412 
413 template <class TImage, class TVectorData>
414 void ListSampleGenerator<TImage, TVectorData>::PrintSelf(std::ostream& os, itk::Indent indent) const
415 {
416  os << indent << "* MaxTrainingSize: " << m_MaxTrainingSize << "\n";
417  os << indent << "* MaxValidationSize: " << m_MaxValidationSize << "\n";
418  os << indent << "* Proportion: " << m_ValidationTrainingProportion << "\n";
419  os << indent << "* Input data:\n";
420  if (m_ClassesSize.empty())
421  {
422  os << indent << "Empty\n";
423  }
424  else
425  {
426  for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap)
427  {
428  os << indent << itmap->first << ": " << itmap->second << "\n";
429  }
430  }
431 
432  os << "\n" << indent << "* Training set:\n";
433  if (m_ClassesProbTraining.empty())
434  {
435  os << indent << "Not computed\n";
436  }
437  else
438  {
439  os << indent << "** Selection probability:\n";
440  for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesProbTraining.begin(); itmap != m_ClassesProbTraining.end(); ++itmap)
441  {
442  os << indent << itmap->first << ": " << itmap->second << "\n";
443  }
444  os << indent << "** Number of selected samples:\n";
445  for (std::map<ClassLabelType, int>::const_iterator itmap = m_ClassesSamplesNumberTraining.begin(); itmap != m_ClassesSamplesNumberTraining.end(); ++itmap)
446  {
447  os << indent << itmap->first << ": " << itmap->second << "\n";
448  }
449  }
450 
451  os << "\n" << indent << "* Validation set:\n";
452  if (m_ClassesProbValidation.empty())
453  {
454  os << indent << "Not computed\n";
455  }
456  else
457  {
458  os << indent << "** Selection probability:\n";
459  for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesProbValidation.begin(); itmap != m_ClassesProbValidation.end(); ++itmap)
460  {
461  os << indent << itmap->first << ": " << itmap->second << "\n";
462  }
463  os << indent << "** Number of selected samples:\n";
464  for (std::map<ClassLabelType, int>::const_iterator itmap = m_ClassesSamplesNumberValidation.begin(); itmap != m_ClassesSamplesNumberValidation.end();
465  ++itmap)
466  {
467  os << indent << itmap->first << ": " << itmap->second << "\n";
468  }
469  }
470 }
471 }
472 
473 #endif
otb::ListSampleGenerator::SetInput
void SetInput(const ImageType *)
Definition: otbListSampleGenerator.hxx:87
otb::ListSampleGenerator::VectorDataPointerType
VectorDataType::Pointer VectorDataPointerType
Definition: otbListSampleGenerator.h:74
otb::ListSampleGenerator::ListSampleGenerator
ListSampleGenerator()
Definition: otbListSampleGenerator.hxx:64
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::TransformPhysicalRegionToIndexRegion
ImageType::RegionType TransformPhysicalRegionToIndexRegion(const RemoteSensingRegionType &region, const ImageType *image)
Definition: otbRemoteSensingRegion.h:345
otbMacro.h
otb::ListSampleGenerator::GetTrainingListSample
ListSampleType * GetTrainingListSample()
Definition: otbListSampleGenerator.hxx:148
otb::ListSampleGenerator::PolygonListPointerType
DataNodeType::PolygonListPointerType PolygonListPointerType
Definition: otbListSampleGenerator.h:179
otb::ListSampleGenerator::GetInputVectorData
const VectorDataType * GetInputVectorData() const
Definition: otbListSampleGenerator.hxx:112
otb::ListSampleGenerator::GetPolygonAreaInPixelsUnits
double GetPolygonAreaInPixelsUnits(DataNodeType *polygonDataNode, ImageType *image)
Definition: otbListSampleGenerator.hxx:395
otb::ListSampleGenerator::DataObjectPointerArraySizeType
itk::ProcessObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType
Definition: otbListSampleGenerator.h:75
otb::ListSampleGenerator::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbListSampleGenerator.hxx:414
otb::ListSampleGenerator::ComputeClassSelectionProbability
void ComputeClassSelectionProbability()
Definition: otbListSampleGenerator.hxx:327
otb::ListSampleGenerator::GetTrainingListLabel
ListLabelType * GetTrainingListLabel()
Definition: otbListSampleGenerator.hxx:154
otb::ListSampleGenerator::ListSampleType
itk::Statistics::ListSample< SampleType > ListSampleType
Definition: otbListSampleGenerator.h:79
otb::ListSampleGenerator::DataObjectPointer
itk::DataObject::Pointer DataObjectPointer
Definition: otbListSampleGenerator.h:102
otb::ListSampleGenerator::ListSamplePointerType
ListSampleType::Pointer ListSamplePointerType
Definition: otbListSampleGenerator.h:80
otb::ListSampleGenerator::GenerateInputRequestedRegion
void GenerateInputRequestedRegion(void) override
Definition: otbListSampleGenerator.hxx:175
otb::ListSampleGenerator::MakeOutput
DataObjectPointer MakeOutput(DataObjectPointerArraySizeType idx) override
Definition: otbListSampleGenerator.hxx:123
otb::ListSampleGenerator::GenerateData
void GenerateData(void) override
Definition: otbListSampleGenerator.hxx:195
otbListSampleGenerator.h
otb::ListSampleGenerator::GenerateClassStatistics
void GenerateClassStatistics()
Definition: otbListSampleGenerator.hxx:305
otb::ListSampleGenerator::GetInput
const ImageType * GetInput() const
Definition: otbListSampleGenerator.hxx:93
otb::ListSampleGenerator::m_RandomGenerator
RandomGeneratorType::Pointer m_RandomGenerator
Definition: otbListSampleGenerator.h:210
otb::ListSampleGenerator::GetValidationListLabel
ListLabelType * GetValidationListLabel()
Definition: otbListSampleGenerator.hxx:169
otbVectorDataProjectionFilter.h
otb::ListSampleGenerator::VectorDataType
TVectorData VectorDataType
Definition: otbListSampleGenerator.h:73
otb::ListSampleGenerator::SetInputVectorData
void SetInputVectorData(const VectorDataType *)
Definition: otbListSampleGenerator.hxx:104
otb::ListSampleGenerator::GetValidationListSample
ListSampleType * GetValidationListSample()
Definition: otbListSampleGenerator.hxx:161
otb::ListSampleGenerator::TreeIteratorType
itk::PreOrderTreeIterator< typename VectorDataType::DataTreeType > TreeIteratorType
Definition: otbListSampleGenerator.h:180
otb::ListSampleGenerator::ImagePointerType
ImageType::Pointer ImagePointerType
Definition: otbListSampleGenerator.h:70
otb::ListSampleGenerator::PolygonPointerType
DataNodeType::PolygonPointerType PolygonPointerType
Definition: otbListSampleGenerator.h:177
otb::ListSampleGenerator::ListLabelPointerType
ListLabelType::Pointer ListLabelPointerType
Definition: otbListSampleGenerator.h:86
otb::ListSampleGenerator::ListLabelType
itk::Statistics::ListSample< LabelType > ListLabelType
Definition: otbListSampleGenerator.h:85
otb::ListSampleGenerator::DataNodeType
VectorDataType::DataNodeType DataNodeType
Definition: otbListSampleGenerator.h:175
otb::ListSampleGenerator::ImageType
TImage ImageType
Definition: otbListSampleGenerator.h:67