OTB  9.0.0
Orfeo Toolbox
otbSharkRandomForestsMachineLearningModel.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbSharkRandomForestsMachineLearningModel_h
22 #define otbSharkRandomForestsMachineLearningModel_h
23 
24 #include "itkLightObject.h"
26 
27 // Quiet a deprecation warning
28 #define BOOST_BIND_GLOBAL_PLACEHOLDERS
29 
30 #if defined(__GNUC__) || defined(__clang__)
31 #pragma GCC diagnostic push
32 
33 #if (defined (__GNUC__) && (__GNUC__ >= 9)) || (defined (__clang__) && (__clang_major__ >= 10))
34 #pragma GCC diagnostic ignored "-Wdeprecated-copy"
35 #endif
36 #pragma GCC diagnostic ignored "-Wshadow"
37 #pragma GCC diagnostic ignored "-Wunused-parameter"
38 #pragma GCC diagnostic ignored "-Woverloaded-virtual"
39 #pragma GCC diagnostic ignored "-Wignored-qualifiers"
40 #pragma GCC diagnostic ignored "-Wsign-compare"
41 #pragma GCC diagnostic ignored "-Wcast-align"
42 #pragma GCC diagnostic ignored "-Wunknown-pragmas"
43 #pragma GCC diagnostic ignored "-Wmissing-field-initializers"
44 #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
45 #if defined(__clang__)
46 #pragma clang diagnostic ignored "-Wheader-guard"
47 #pragma clang diagnostic ignored "-Wexpansion-to-defined"
48 #else
49 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
50 #endif
51 #endif
52 #include <shark/Models/Classifier.h>
53 #include "otb_shark.h"
54 #include "shark/Algorithms/Trainers/RFTrainer.h"
55 #if defined(__GNUC__) || defined(__clang__)
56 #pragma GCC diagnostic pop
57 #endif
58 
59 
74 namespace otb
75 {
// SharkRandomForestsMachineLearningModel: a MachineLearningModel<TInputValue, TTargetValue>
// specialization that holds a Shark random-forest classifier (m_RFModel, a
// shark::RFClassifier<unsigned int>) and its trainer (m_RFTrainer, a
// shark::RFTrainer<unsigned int>). The member function definitions are not in this file;
// they live in otbSharkRandomForestsMachineLearningModel.hxx, which is included unless
// OTB_MANUAL_INSTANTIATION is defined.
//
// NOTE(review): this is a Doxygen listing, not compilable C++ as-is. Each line is prefixed
// with its header line number, and the numbers are not contiguous because several lines are
// elided. According to the Doxygen tooltips, the elided lines include:
//   - the Self (81) and Superclass (82) typedefs;
//   - the InputValueType (86), TargetValueType (89), TargetSampleType (90),
//     ConfidenceValueType (92) and ConfidenceSampleType (93) typedefs;
//   - the bool members m_NormalizeClassLabels (187) and m_ComputeMargin (193).
// Consult the real header before editing anything below.
76 template <class TInputValue, class TTargetValue>
77 class ITK_EXPORT SharkRandomForestsMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue>
78 {
79 public:
// ITK smart-pointer aliases. Self itself is declared on an elided line (81).
83  typedef itk::SmartPointer<Self> Pointer;
84  typedef itk::SmartPointer<const Self> ConstPointer;
85 
// Sample / list-sample types re-exported from the MachineLearningModel superclass.
// `typename` is required because Superclass depends on the template parameters.
87  typedef typename Superclass::InputSampleType InputSampleType;
88  typedef typename Superclass::InputListSampleType InputListSampleType;
91  typedef typename Superclass::TargetListSampleType TargetListSampleType;
94  typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
95  typedef typename Superclass::ProbaSampleType ProbaSampleType;
96  typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
// ITK object-factory macro: provides the static New() used to create instances.
98  itkNewMacro(Self);
101 
// Overrides of the MachineLearningModel interface (definitions are in the .hxx).
// NOTE(review): `virtual` is redundant when `override` is present; consider dropping it.
// Train the forest; presumably via m_RFTrainer into m_RFModel. TODO confirm in the .hxx.
103  virtual void Train() override;
104 
// Write the model to `filename`. `name` defaults to the empty string.
106  virtual void Save(const std::string& filename, const std::string& name = "") override;
107 
// Read the model from `filename`. `name` defaults to the empty string.
109  virtual void Load(const std::string& filename, const std::string& name = "") override;
110 
113 
// File-compatibility probes: report whether a model file can be read / written by this
// class. The exact criteria are in the .hxx (not shown).
115  virtual bool CanReadFile(const std::string&) override;
116 
118  virtual bool CanWriteFile(const std::string&) override;
120 
// Training hyper-parameters exposed through ITK Get/Set macros, backed by
// m_NumberOfTrees, m_MTry, m_NodeSize and m_OobRatio. They are presumably forwarded to
// m_RFTrainer when Train() runs. TODO confirm in the .hxx.
122  itkGetMacro(NumberOfTrees, unsigned int);
123 
125  itkSetMacro(NumberOfTrees, unsigned int);
126 
// MTry: presumably the number of candidate features drawn per split, following Shark RF
// conventions. Verify against RFTrainer.
128  itkGetMacro(MTry, unsigned int);
129 
131  itkSetMacro(MTry, unsigned int);
132 
136  itkGetMacro(NodeSize, unsigned int);
137 
141  itkSetMacro(NodeSize, unsigned int);
142 
// NOTE(review): itkSetMacro does no range checking. If OobRatio must lie in a fixed
// interval (e.g. (0, 1]), consider itkSetClampMacro. Confirm the intended range.
146  itkGetMacro(OobRatio, float);
147 
151  itkSetMacro(OobRatio, float);
152 
// ComputeMargin flag (member m_ComputeMargin is declared on elided line 193). The private
// ComputeConfidence() takes a `computeMargin` argument, presumably fed from this flag.
154  itkGetMacro(ComputeMargin, bool);
155 
157  itkSetMacro(ComputeMargin, bool);
158 
// NormalizeClassLabels flag (member m_NormalizeClassLabels is declared on elided line 187).
// It presumably controls remapping of labels through m_ClassDictionary. TODO confirm.
160  itkGetMacro(NormalizeClassLabels, bool);
161  itkSetMacro(NormalizeClassLabels, bool);
163 
164 protected:
// The constructor declaration is on an elided line (around 165-168).
167 
169  ~SharkRandomForestsMachineLearningModel() override = default;
170 
// Predict one sample. `quality` (confidence) and `proba` are optional out-parameters;
// nullptr means the caller does not want them.
172  TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const override;
173 
// Batch prediction over `size` samples starting at `startIndex` of the input list, writing
// into the target list. The confidence and probability lists are optional (nullptr = skip).
174  void DoPredictBatch(const InputListSampleType*, const unsigned int& startIndex, const unsigned int& size, TargetListSampleType*,
175  ConfidenceListSampleType* = nullptr, ProbaListSampleType* = nullptr) const override;
176 
178  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
179 
180 private:
// Copy assignment is disabled (the copy constructor is presumably deleted on elided line 181).
182  void operator=(const Self&) = delete;
183 
// Shark random-forest classifier and trainer, both specialized for unsigned int labels.
184  shark::RFClassifier<unsigned int> m_RFModel;
185  shark::RFTrainer<unsigned int> m_RFTrainer;
// Class-label dictionary; presumably maps internal label indices back to user labels.
// TODO confirm in the .hxx.
186  std::vector<unsigned int> m_ClassDictionary;
188 
// NOTE(review): these scalar members have no in-class initializers here. They are presumably
// set by the (elided) constructor; default member initializers would make that explicit.
189  unsigned int m_NumberOfTrees;
190  unsigned int m_MTry;
191  unsigned int m_NodeSize;
192  float m_OobRatio;
194 
// Confidence derived from a vector of class probabilities; `computeMargin` selects the
// margin-based variant.
// NOTE(review): `probas` is a non-const reference in a const member function. If the
// implementation does not modify it, `const shark::RealVector&` would state that contract.
196  ConfidenceValueType ComputeConfidence(shark::RealVector& probas, bool computeMargin) const;
197 };
198 } // end namespace otb
199 
200 #ifndef OTB_MANUAL_INSTANTIATION
202 #endif
203 
204 #endif
otb::SharkRandomForestsMachineLearningModel::m_RFTrainer
shark::RFTrainer< unsigned int > m_RFTrainer
Definition: otbSharkRandomForestsMachineLearningModel.h:185
otb::SharkRandomForestsMachineLearningModel::m_NormalizeClassLabels
bool m_NormalizeClassLabels
Definition: otbSharkRandomForestsMachineLearningModel.h:187
otb::SharkRandomForestsMachineLearningModel::m_OobRatio
float m_OobRatio
Definition: otbSharkRandomForestsMachineLearningModel.h:192
otb::SharkRandomForestsMachineLearningModel::InputListSampleType
Superclass::InputListSampleType InputListSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:88
otb::SharkRandomForestsMachineLearningModel::ProbaListSampleType
Superclass::ProbaListSampleType ProbaListSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:96
otb::SharkRandomForestsMachineLearningModel::m_ClassDictionary
std::vector< unsigned int > m_ClassDictionary
Definition: otbSharkRandomForestsMachineLearningModel.h:186
otb::MachineLearningModel< TInputValue, TTargetValue >::InputValueType
MLMSampleTraits< TInputValue >::ValueType InputValueType
Definition: otbMachineLearningModel.h:83
otb::MachineLearningModel::ConfidenceValueType
MLMTargetTraits< TConfidenceValue >::ValueType ConfidenceValueType
Definition: otbMachineLearningModel.h:96
otb::SharkRandomForestsMachineLearningModel::m_NumberOfTrees
unsigned int m_NumberOfTrees
Definition: otbSharkRandomForestsMachineLearningModel.h:189
otb::SharkRandomForestsMachineLearningModel::Superclass
MachineLearningModel< TInputValue, TTargetValue > Superclass
Definition: otbSharkRandomForestsMachineLearningModel.h:82
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::SharkRandomForestsMachineLearningModel::m_NodeSize
unsigned int m_NodeSize
Definition: otbSharkRandomForestsMachineLearningModel.h:191
otb::SharkRandomForestsMachineLearningModel::m_MTry
unsigned int m_MTry
Definition: otbSharkRandomForestsMachineLearningModel.h:190
otb::SharkRandomForestsMachineLearningModel::ConstPointer
itk::SmartPointer< const Self > ConstPointer
Definition: otbSharkRandomForestsMachineLearningModel.h:84
otb::SharkRandomForestsMachineLearningModel::Self
SharkRandomForestsMachineLearningModel Self
Definition: otbSharkRandomForestsMachineLearningModel.h:81
otb::SharkRandomForestsMachineLearningModel::TargetValueType
Superclass::TargetValueType TargetValueType
Definition: otbSharkRandomForestsMachineLearningModel.h:89
otb::SharkRandomForestsMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbSharkRandomForestsMachineLearningModel.h:92
otb::SharkRandomForestsMachineLearningModel::m_RFModel
shark::RFClassifier< unsigned int > m_RFModel
Definition: otbSharkRandomForestsMachineLearningModel.h:184
otbMachineLearningModel.h
otb::SharkRandomForestsMachineLearningModel::m_ComputeMargin
bool m_ComputeMargin
Definition: otbSharkRandomForestsMachineLearningModel.h:193
otb::MachineLearningModel< TInputValue, TTargetValue >::ConfidenceSampleType
MLMTargetTraits< double >::SampleType ConfidenceSampleType
Definition: otbMachineLearningModel.h:97
otb::MachineLearningModel
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests,...
Definition: otbMachineLearningModel.h:70
otb::SharkRandomForestsMachineLearningModel::InputValueType
Superclass::InputValueType InputValueType
Definition: otbSharkRandomForestsMachineLearningModel.h:86
otb::SharkRandomForestsMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:90
otb::MachineLearningModel< TInputValue, TTargetValue >::TargetValueType
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
Definition: otbMachineLearningModel.h:90
otb::SharkRandomForestsMachineLearningModel::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbSharkRandomForestsMachineLearningModel.h:83
otb::SharkRandomForestsMachineLearningModel::ConfidenceListSampleType
Superclass::ConfidenceListSampleType ConfidenceListSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:94
otb::SharkRandomForestsMachineLearningModel::ConfidenceSampleType
Superclass::ConfidenceSampleType ConfidenceSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:93
otb::SharkRandomForestsMachineLearningModel::TargetListSampleType
Superclass::TargetListSampleType TargetListSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:91
otb::SharkRandomForestsMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:95
otb::SharkRandomForestsMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbSharkRandomForestsMachineLearningModel.h:87
otb::SharkRandomForestsMachineLearningModel
Definition: otbSharkRandomForestsMachineLearningModel.h:77
otb::MachineLearningModel::TargetSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
Definition: otbMachineLearningModel.h:91
otbSharkRandomForestsMachineLearningModel.hxx