OTB  9.0.0
Orfeo Toolbox
otbSharkKMeansMachineLearningModel.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbSharkKMeansMachineLearningModel_h
21 #define otbSharkKMeansMachineLearningModel_h
22 
23 #include <memory>
24 
25 #include "itkLightObject.h"
27 
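28 // Define BOOST_BIND_GLOBAL_PLACEHOLDERS to quiet Boost.Bind's deprecation warning about placeholders (_1, _2, ...) in the global namespace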
29 #define BOOST_BIND_GLOBAL_PLACEHOLDERS
30 
31 #if defined(__GNUC__) || defined(__clang__)
32 #pragma GCC diagnostic push
33 
34 #if (defined (__GNUC__) && (__GNUC__ >= 9)) || (defined (__clang__) && (__clang_major__ >= 10))
35 #pragma GCC diagnostic ignored "-Wdeprecated-copy"
36 #endif
37 
38 #pragma GCC diagnostic ignored "-Wshadow"
39 #pragma GCC diagnostic ignored "-Wunused-parameter"
40 #pragma GCC diagnostic ignored "-Woverloaded-virtual"
41 #pragma GCC diagnostic ignored "-Wignored-qualifiers"
42 #pragma GCC diagnostic ignored "-Wsign-compare"
43 #pragma GCC diagnostic ignored "-Wcast-align"
44 #pragma GCC diagnostic ignored "-Wunknown-pragmas"
45 #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
46 #if defined(__clang__)
47 #pragma clang diagnostic ignored "-Wheader-guard"
48 #pragma clang diagnostic ignored "-Wexpansion-to-defined"
49 #else
50 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
51 #endif
52 #endif
53 
54 #include "otb_shark.h"
55 #include "shark/Models/Clustering/HardClusteringModel.h"
56 #include "shark/Models/Clustering/SoftClusteringModel.h"
57 #include "shark/Models/Clustering/Centroids.h"
58 #include "shark/Models/Clustering/ClusteringModel.h"
59 #include "shark/Algorithms/KMeans.h"
60 #include "shark/Models/Normalizer.h"
61 
62 #if defined(__GNUC__) || defined(__clang__)
63 #pragma GCC diagnostic pop
64 #endif
65 
79 namespace otb
80 {
// Clustering model that plugs Shark's hard-clustering / centroid types into
// the otb::MachineLearningModel<TInputValue, TTargetValue> interface.
//
// NOTE(review): this listing's embedded line numbers skip some lines (e.g.
// 85-87, 91, 94-95, 98, 152-156, 173). Names used below but not declared here
// (Self, Superclass, TargetSampleType, m_MaximumNumberOfIterations, ...) are
// listed in the cross-reference index of this file as being defined on those
// skipped lines.
81 template <class TInputValue, class TTargetValue>
82 class ITK_EXPORT SharkKMeansMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue>
83 {
84 public:
    // ITK smart-pointer aliases for this class.
88  typedef itk::SmartPointer<Self> Pointer;
89  typedef itk::SmartPointer<const Self> ConstPointer;
90 
    // Sample / list / confidence / probability types re-exported from the superclass.
92  typedef typename Superclass::InputSampleType InputSampleType;
93  typedef typename Superclass::InputListSampleType InputListSampleType;
96  typedef typename Superclass::TargetListSampleType TargetListSampleType;
97  typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
99  typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
100  typedef typename Superclass::ProbaSampleType ProbaSampleType;
101  typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
    // Shark hard-clustering model over real-valued vectors, and its output type.
102  typedef shark::HardClusteringModel<shark::RealVector> ClusteringModelType;
103  typedef ClusteringModelType::OutputType ClusteringOutputType;
104 
    // ITK object-creation macro (see ITK's itkNewMacro documentation).
106  itkNewMacro(Self);
109 
    // Train the model. Declared only; the implementation is not part of this listing.
    // NOTE(review): presumably runs k-means using m_K and m_MaximumNumberOfIterations - confirm in the .hxx.
111  virtual void Train() override;
112 
    // Save the model to `filename`. `name` is optional and defaults to the empty string.
114  virtual void Save(const std::string& filename, const std::string& name = "") override;
115 
    // Load the model from `filename`. `name` is optional and defaults to the empty string.
117  virtual void Load(const std::string& filename, const std::string& name = "") override;
118 
121 
    // Tell whether the given file can be read by this model (criteria defined in the implementation).
123  virtual bool CanReadFile(const std::string&) override;
124 
    // Tell whether the given file can be written by this model (criteria defined in the implementation).
126  virtual bool CanWriteFile(const std::string&) override;
128 
    // Getter for MaximumNumberOfIterations (unsigned), generated by the ITK macro.
130  itkGetMacro(MaximumNumberOfIterations, unsigned);
131 
    // Setter for MaximumNumberOfIterations (unsigned), generated by the ITK macro.
133  itkSetMacro(MaximumNumberOfIterations, unsigned);
134 
    // Getter for K (unsigned), generated by the ITK macro.
136  itkGetMacro(K, unsigned);
137 
    // Setter for K (unsigned), generated by the ITK macro.
139  itkSetMacro(K, unsigned);
140 
    // Replace the centroids stored in m_Centroids with `data`, then call
    // Modified() so the change is recorded on the object.
142  void SetCentroidsFromData(const shark::Data<shark::RealVector>& data)
143  {
144  m_Centroids.setCentroids(data);
145  this->Modified();
146  }
148 
    // Export the centroids to `filename`. Declared only; the output format is
    // defined by the implementation, which is not visible here.
149  void ExportCentroids(const std::string& filename);
150 
151 protected:
154 
157 
    // Predict the target for a single `input` sample. `quality` and `proba` are
    // optional pointers defaulting to nullptr.
    // NOTE(review): presumably filled only when non-null - confirm in the .hxx.
159  virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType* quality = nullptr, ProbaSampleType* proba = nullptr) const override;
160 
    // Batch prediction over an input list, parameterised by `startIndex` and
    // `size`; results go to the target list. The confidence and probability
    // lists are optional and default to nullptr.
161  virtual void DoPredictBatch(const InputListSampleType*, const unsigned int& startIndex, const unsigned int& size, TargetListSampleType*,
162  ConfidenceListSampleType* = nullptr, ProbaListSampleType* = nullptr) const override;
163 
    // Print the object's state to `os` using the given indentation.
165  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
166 
167 private:
    // Copy construction and copy assignment are deleted: the model is non-copyable.
168  SharkKMeansMachineLearningModel(const Self&) = delete;
169  void operator=(const Self&) = delete;
170 
171  // Parameters set by the user
    // NOTE(review): presumably the number of clusters, accessed through the K get/set macros above - confirm.
172  unsigned int m_K;
    // NOTE(review): presumably tracks whether a model file was recognised as readable - confirm against CanReadFile/Load.
174  bool m_CanRead;
175 
    // Shark centroid container; overwritten by SetCentroidsFromData().
177  shark::Centroids m_Centroids;
178 
    // Shared pointer to the Shark hard-clustering model (see ClusteringModelType).
180  std::shared_ptr<ClusteringModelType> m_ClusteringModel;
181 };
182 } // end namespace otb
183 
184 #ifndef OTB_MANUAL_INSTANTIATION
185 
187 
188 #endif
189 
190 #endif
otb::SharkKMeansMachineLearningModel::ClusteringOutputType
ClusteringModelType::OutputType ClusteringOutputType
Definition: otbSharkKMeansMachineLearningModel.h:103
otb::SharkKMeansMachineLearningModel::ConfidenceSampleType
Superclass::ConfidenceSampleType ConfidenceSampleType
Definition: otbSharkKMeansMachineLearningModel.h:98
otb::SharkKMeansMachineLearningModel::ConstPointer
itk::SmartPointer< const Self > ConstPointer
Definition: otbSharkKMeansMachineLearningModel.h:89
otb::SharkKMeansMachineLearningModel::m_K
unsigned int m_K
Definition: otbSharkKMeansMachineLearningModel.h:172
otb::SharkKMeansMachineLearningModel::m_MaximumNumberOfIterations
unsigned int m_MaximumNumberOfIterations
Definition: otbSharkKMeansMachineLearningModel.h:173
otb::MachineLearningModel< TInputValue, TTargetValue >::InputValueType
MLMSampleTraits< TInputValue >::ValueType InputValueType
Definition: otbMachineLearningModel.h:83
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::SharkKMeansMachineLearningModel::ProbaListSampleType
Superclass::ProbaListSampleType ProbaListSampleType
Definition: otbSharkKMeansMachineLearningModel.h:101
otb::SharkKMeansMachineLearningModel::m_ClusteringModel
std::shared_ptr< ClusteringModelType > m_ClusteringModel
Definition: otbSharkKMeansMachineLearningModel.h:180
otb::SharkKMeansMachineLearningModel::m_CanRead
bool m_CanRead
Definition: otbSharkKMeansMachineLearningModel.h:174
otb::SharkKMeansMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbSharkKMeansMachineLearningModel.h:92
otb::SharkKMeansMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbSharkKMeansMachineLearningModel.h:100
otb::SharkKMeansMachineLearningModel::TargetListSampleType
Superclass::TargetListSampleType TargetListSampleType
Definition: otbSharkKMeansMachineLearningModel.h:96
otb::SharkKMeansMachineLearningModel::Pointer
itk::SmartPointer< Self > Pointer
Definition: otbSharkKMeansMachineLearningModel.h:88
otbSharkKMeansMachineLearningModel.hxx
otb::SharkKMeansMachineLearningModel::Superclass
MachineLearningModel< TInputValue, TTargetValue > Superclass
Definition: otbSharkKMeansMachineLearningModel.h:87
otb::SharkKMeansMachineLearningModel
Definition: otbSharkKMeansMachineLearningModel.h:82
otbMachineLearningModel.h
otb::SharkKMeansMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbSharkKMeansMachineLearningModel.h:97
otb::SharkKMeansMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbSharkKMeansMachineLearningModel.h:95
otb::SharkKMeansMachineLearningModel::Self
SharkKMeansMachineLearningModel Self
Definition: otbSharkKMeansMachineLearningModel.h:86
otb::MachineLearningModel< TInputValue, TTargetValue >::ConfidenceSampleType
MLMTargetTraits< double >::SampleType ConfidenceSampleType
Definition: otbMachineLearningModel.h:97
otb::SharkKMeansMachineLearningModel::ClusteringModelType
shark::HardClusteringModel< shark::RealVector > ClusteringModelType
Definition: otbSharkKMeansMachineLearningModel.h:102
otb::SharkKMeansMachineLearningModel::TargetValueType
Superclass::TargetValueType TargetValueType
Definition: otbSharkKMeansMachineLearningModel.h:94
otb::SharkKMeansMachineLearningModel::InputListSampleType
Superclass::InputListSampleType InputListSampleType
Definition: otbSharkKMeansMachineLearningModel.h:93
otb::MachineLearningModel
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests,...
Definition: otbMachineLearningModel.h:70
otb::SharkKMeansMachineLearningModel::m_Centroids
shark::Centroids m_Centroids
Definition: otbSharkKMeansMachineLearningModel.h:177
otb::SharkKMeansMachineLearningModel::SetCentroidsFromData
void SetCentroidsFromData(const shark::Data< shark::RealVector > &data)
Definition: otbSharkKMeansMachineLearningModel.h:142
otb::MachineLearningModel< TInputValue, TTargetValue >::TargetValueType
MLMTargetTraits< TTargetValue >::ValueType TargetValueType
Definition: otbMachineLearningModel.h:90
otb::SharkKMeansMachineLearningModel::ConfidenceListSampleType
Superclass::ConfidenceListSampleType ConfidenceListSampleType
Definition: otbSharkKMeansMachineLearningModel.h:99
otb::SharkKMeansMachineLearningModel::InputValueType
Superclass::InputValueType InputValueType
Definition: otbSharkKMeansMachineLearningModel.h:91
otb::MachineLearningModel::TargetSampleType
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType
Definition: otbMachineLearningModel.h:91