OTB  6.7.0
Orfeo Toolbox
otbTrainVectorBase.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbTrainVectorBase_h
21 #define otbTrainVectorBase_h
22 
24 #include "otbWrapperApplication.h"
26 
28 #include "otbOGRFeatureWrapper.h"
30 
33 
34 #include "itkListSample.h"
36 
37 #include <algorithm>
38 #include <locale>
39 #include <string>
40 
41 namespace otb
42 {
43 namespace Wrapper
44 {
45 
/**
 * Predicate returning true when c is NOT an alphanumeric character
 * according to std::isalnum (i.e. the current C locale).
 *
 * @param c character to test
 * @return !std::isalnum(c), evaluated on the unsigned-char value of c
 *
 * Declared inline because it is defined in a header: a non-inline
 * definition here would be emitted in every translation unit that
 * includes this file and break the one-definition rule at link time.
 */
inline bool IsNotAlphaNum(char c)
{
  // std::isalnum requires a value representable as unsigned char (or EOF);
  // plain char may be signed, so convert explicitly to avoid UB on bytes >= 0x80.
  return !std::isalnum(static_cast<unsigned char>(c));
}
51 
52 template <class TInputValue, class TOutputValue>
53 class TrainVectorBase : public LearningApplicationBase<TInputValue, TOutputValue>
54 {
55 public:
61 
63  itkTypeMacro(Self, Superclass);
64 
65  typedef typename Superclass::SampleType SampleType;
68 
69  typedef double ValueType;
71 
73 
75 
76 protected:
77 
80  {
81  public:
84  };
85 
88  {
89  public:
93  {
94  listSample = ListSampleType::New();
96  }
97  };
99 
105  {
106  public:
107 
109  std::vector<int> m_SelectedIdx;
110 
112  std::vector<int> m_SelectedCFieldIdx;
113 
115  std::string m_SelectedCFieldName;
116 
118  std::vector <std::string> m_SelectedNames;
119  unsigned int m_NbFeatures;
120 
121  void SetFieldNames(std::vector <std::string> fieldNames, std::vector<int> selectedIdx)
122  {
123  m_SelectedIdx = selectedIdx;
124  m_NbFeatures = static_cast<unsigned int>(selectedIdx.size());
125  m_SelectedNames = std::vector<std::string>( m_NbFeatures );
126  for( unsigned int i = 0; i < m_NbFeatures; ++i )
127  {
128  m_SelectedNames[i] = fieldNames[selectedIdx[i]];
129  }
130  }
131  void SetClassFieldNames(std::vector<std::string> cFieldNames, std::vector<int> selectedCFieldIdx)
132  {
133  m_SelectedCFieldIdx = selectedCFieldIdx;
134  // Handle only one class field name, if several are provided only the first one is used.
135  if (selectedCFieldIdx.empty())
136  m_SelectedCFieldName.clear();
137  else
138  m_SelectedCFieldName = cFieldNames[selectedCFieldIdx.front()];
139  }
140  };
141 
142 
143 protected:
144 
150  virtual void ExtractAllSamples(const ShiftScaleParameters &measurement);
151 
158  virtual SamplesWithLabel ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement);
159 
166  virtual SamplesWithLabel ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement);
167 
168 
177  SamplesWithLabel
178  ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, const ShiftScaleParameters &measurement);
179 
180 
186  ShiftScaleParameters GetStatistics(unsigned int nbFeatures);
187 
192 
193  void DoInit() override;
194  void DoUpdateParameters() override;
195  void DoExecute() override;
196 };
197 
198 }
199 }
200 
201 #ifndef OTB_MANUAL_INSTANTIATION
202 #include "otbTrainVectorBase.hxx"
203 #endif
204 
205 #endif
Superclass::SampleType SampleType
ShiftScaleParameters GetStatistics(unsigned int nbFeatures)
Reads an XML file in which several statistics are stored.
SamplesWithLabel m_ClassificationSamplesWithLabel
itk::VariableLengthVector< ValueType > MeasurementType
void SetFieldNames(std::vector< std::string > fieldNames, std::vector< int > selectedIdx)
LearningApplicationBase< TInputValue, TOutputValue > Superclass
SamplesWithLabel m_TrainingSamplesWithLabel
LearningApplicationBase is the base class for applications that use machine learning models...
SamplesWithLabel ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, const ShiftScaleParameters &measurement)
This class generates a shifted and scaled version of the input sample list.
itk::SmartPointer< const Self > ConstPointer
otb::Statistics::ShiftScaleSampleListFilter< ListSampleType, ListSampleType > ShiftScaleFilterType
otb::StatisticsXMLFileReader< SampleType > StatisticsReader
virtual SamplesWithLabel ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement)
virtual void ExtractAllSamples(const ShiftScaleParameters &measurement)
Superclass::TargetListSampleType TargetListSampleType
Superclass::ListSampleType ListSampleType
itk::SmartPointer< Self > Pointer
ModelType::InputListSampleType ListSampleType
bool IsNotAlphaNum(char c)
TargetListSampleType::Pointer m_PredictedList
void SetClassFieldNames(std::vector< std::string > cFieldNames, std::vector< int > selectedCFieldIdx)
virtual SamplesWithLabel ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement)