OTB  9.0.0
Orfeo Toolbox
otbSharkKMeansMachineLearningModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbSharkKMeansMachineLearningModel_hxx
21 #define otbSharkKMeansMachineLearningModel_hxx
22 
23 #include <fstream>
24 #include <utility>
25 
26 #include "itkMacro.h"
28 
29 #if defined(__GNUC__) || defined(__clang__)
30 #pragma GCC diagnostic push
31 
32 #if (defined (__GNUC__) && (__GNUC__ >= 9)) || (defined (__clang__) && (__clang_major__ >= 10))
33 #pragma GCC diagnostic ignored "-Wdeprecated-copy"
34 #endif
35 
36 #pragma GCC diagnostic ignored "-Wshadow"
37 #pragma GCC diagnostic ignored "-Wunused-parameter"
38 #pragma GCC diagnostic ignored "-Woverloaded-virtual"
39 #pragma GCC diagnostic ignored "-Wignored-qualifiers"
40 #endif
41 
42 #include "otb_shark.h"
43 #include "otbSharkUtils.h"
44 #include "shark/Algorithms/KMeans.h" //k-means algorithm
45 #include "shark/Models/Clustering/HardClusteringModel.h"
46 #include "shark/Models/Clustering/SoftClusteringModel.h"
47 #include <shark/Data/Csv.h> //load the csv file
48 
49 #if defined(__GNUC__) || defined(__clang__)
50 #pragma GCC diagnostic pop
51 #endif
52 
53 
54 namespace otb
55 {
56 template <class TInputValue, class TOutputValue>
58 {
59  // Default set HardClusteringModel
60  this->m_ConfidenceIndex = true;
61  m_ClusteringModel = std::make_shared<ClusteringModelType>(&m_Centroids);
62 }
63 
64 
65 template <class TInputValue, class TOutputValue>
67 {
68 }
69 
71 template <class TInputValue, class TOutputValue>
73 {
74  // Parse input data and convert to Shark Data
75  std::vector<shark::RealVector> vector_data;
76  otb::Shark::ListSampleToSharkVector(this->GetInputListSample(), vector_data);
77  shark::Data<shark::RealVector> data = shark::createDataFromRange(vector_data);
79 
80  // Use a Hard Clustering Model for classification
81  shark::kMeans(data, m_K, m_Centroids, m_MaximumNumberOfIterations);
82  m_ClusteringModel = std::make_shared<ClusteringModelType>(&m_Centroids);
83 }
84 
85 template <class TInputValue, class TOutputValue>
88 {
89  shark::RealVector data(value.Size());
90  for (size_t i = 0; i < value.Size(); i++)
91  {
92  data.push_back(value[i]);
93  }
94 
95  // Change quality measurement only if SoftClustering or other clustering method is used.
96  if (quality != nullptr)
97  {
98  // unsigned int probas = (*m_ClusteringModel)( data );
99  (*quality) = ConfidenceValueType(1.);
100  }
101 
102  if (proba != nullptr)
103  {
104  if (!this->m_ProbaIndex)
105  {
106  itkExceptionMacro("Probability per class not available for this classifier !");
107  }
108  }
109  TargetSampleType target;
110  ClusteringOutputType predictedValue = (*m_ClusteringModel)(data);
111  target[0] = static_cast<TOutputValue>(predictedValue);
112  return target;
113 }
114 
115 template <class TInputValue, class TOutputValue>
117  const unsigned int& size, TargetListSampleType* targets,
118  ConfidenceListSampleType* quality, ProbaListSampleType* proba) const
119 {
120 
121  // Perform check on input values
122  assert(input != nullptr);
123  assert(targets != nullptr);
124 
125  // input list sample and target list sample should be initialized and without
126  assert(input->Size() == targets->Size() && "Input sample list and target label list do not have the same size.");
127  assert(((quality == nullptr) || (quality->Size() == input->Size())) &&
128  "Quality samples list is not null and does not have the same size as input samples list");
129  if (startIndex + size > input->Size())
130  {
131  itkExceptionMacro(<< "requested range [" << startIndex << ", " << startIndex + size << "[ partially outside input sample list range.[0," << input->Size()
132  << "[");
133  }
134 
135  // Convert input list of features to shark data format
136  std::vector<shark::RealVector> features;
137  otb::Shark::ListSampleRangeToSharkVector(input, features, startIndex, size);
138  shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
139 
140  shark::Data<ClusteringOutputType> clusters;
141  try
142  {
143  clusters = (*m_ClusteringModel)(inputSamples);
144  }
145  catch (...)
146  {
147  itkExceptionMacro(
148  "Failed to run clustering classification. "
149  "The number of features of input samples and the model could differ.");
150  }
151 
152  unsigned int id = startIndex;
153  for (const auto& p : clusters.elements())
154  {
155  TargetSampleType target;
156  target[0] = static_cast<TOutputValue>(p);
157  targets->SetMeasurementVector(id, target);
158  ++id;
159  }
160 
161  // Change quality measurement only if SoftClustering or other clustering method is used.
162  if (quality != nullptr)
163  {
164  for (unsigned int qid = startIndex; qid < startIndex + size; ++qid)
165  {
166  quality->SetMeasurementVector(qid, static_cast<ConfidenceValueType>(1.));
167  }
168  }
169  if (proba != nullptr && !this->m_ProbaIndex)
170  {
171  itkExceptionMacro("Probability per class not available for this classifier !");
172  }
173 }
174 
175 
176 template <class TInputValue, class TOutputValue>
177 void SharkKMeansMachineLearningModel<TInputValue, TOutputValue>::Save(const std::string& filename, const std::string& itkNotUsed(name))
178 {
179  std::ofstream ofs(filename);
180  if (!ofs)
181  {
182  itkExceptionMacro(<< "Error opening " << filename.c_str());
183  }
184  ofs << "#" << m_ClusteringModel->name() << std::endl;
185  shark::TextOutArchive oa(ofs);
186  m_ClusteringModel->save(oa, 1);
187 }
188 
189 template <class TInputValue, class TOutputValue>
190 void SharkKMeansMachineLearningModel<TInputValue, TOutputValue>::Load(const std::string& filename, const std::string& itkNotUsed(name))
191 {
192  m_CanRead = false;
193  std::ifstream ifs(filename);
194  if (ifs.good())
195  {
196  // Check if first line contains model name
197  std::string line;
198  std::getline(ifs, line);
199  m_CanRead = line.find(m_ClusteringModel->name()) != std::string::npos;
200  }
201 
202  if (!m_CanRead)
203  return;
204 
205  shark::TextInArchive ia(ifs);
206  m_ClusteringModel->load(ia, 0);
207  ifs.close();
208 }
209 
210 template <class TInputValue, class TOutputValue>
212 {
213  try
214  {
215  m_CanRead = true;
216  this->Load(file);
217  }
218  catch (...)
219  {
220  return false;
221  }
222  return m_CanRead;
223 }
224 
225 template <class TInputValue, class TOutputValue>
227 {
228  return true;
229 }
230 
231 template <class TInputValue, class TOutputValue>
233 {
234  shark::exportCSV(m_Centroids.centroids(), filename, ' ');
235 }
236 
237 template <class TInputValue, class TOutputValue>
238 void SharkKMeansMachineLearningModel<TInputValue, TOutputValue>::PrintSelf(std::ostream& os, itk::Indent indent) const
239 {
240  // Call superclass implementation
241  Superclass::PrintSelf(os, indent);
242 }
243 } // end namespace otb
244 
245 #endif
otb::SharkKMeansMachineLearningModel::ClusteringOutputType
ClusteringModelType::OutputType ClusteringOutputType
Definition: otbSharkKMeansMachineLearningModel.h:103
otbSharkKMeansMachineLearningModel.h
otb::SharkKMeansMachineLearningModel::Load
virtual void Load(const std::string &filename, const std::string &name="") override
Definition: otbSharkKMeansMachineLearningModel.hxx:190
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::SharkKMeansMachineLearningModel::ProbaListSampleType
Superclass::ProbaListSampleType ProbaListSampleType
Definition: otbSharkKMeansMachineLearningModel.h:101
otb::SharkKMeansMachineLearningModel::m_ClusteringModel
std::shared_ptr< ClusteringModelType > m_ClusteringModel
Definition: otbSharkKMeansMachineLearningModel.h:180
otb::SharkKMeansMachineLearningModel::CanWriteFile
virtual bool CanWriteFile(const std::string &) override
Definition: otbSharkKMeansMachineLearningModel.hxx:226
otb::SharkKMeansMachineLearningModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbSharkKMeansMachineLearningModel.h:92
otb::SharkKMeansMachineLearningModel::Save
virtual void Save(const std::string &filename, const std::string &name="") override
Definition: otbSharkKMeansMachineLearningModel.hxx:177
otb::SharkKMeansMachineLearningModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbSharkKMeansMachineLearningModel.h:100
otb::SharkKMeansMachineLearningModel::TargetListSampleType
Superclass::TargetListSampleType TargetListSampleType
Definition: otbSharkKMeansMachineLearningModel.h:96
otb::SharkKMeansMachineLearningModel::Train
virtual void Train() override
Definition: otbSharkKMeansMachineLearningModel.hxx:72
otb::SharkKMeansMachineLearningModel::ExportCentroids
void ExportCentroids(const std::string &filename)
Definition: otbSharkKMeansMachineLearningModel.hxx:232
otb::SharkKMeansMachineLearningModel::CanReadFile
virtual bool CanReadFile(const std::string &) override
Definition: otbSharkKMeansMachineLearningModel.hxx:211
otb::SharkKMeansMachineLearningModel::~SharkKMeansMachineLearningModel
virtual ~SharkKMeansMachineLearningModel()
Definition: otbSharkKMeansMachineLearningModel.hxx:66
otb::SharkKMeansMachineLearningModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbSharkKMeansMachineLearningModel.h:97
otb::SharkKMeansMachineLearningModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbSharkKMeansMachineLearningModel.h:95
otb::SharkKMeansMachineLearningModel::PrintSelf
void PrintSelf(std::ostream &os, itk::Indent indent) const override
Definition: otbSharkKMeansMachineLearningModel.hxx:238
otb::SharkKMeansMachineLearningModel::InputListSampleType
Superclass::InputListSampleType InputListSampleType
Definition: otbSharkKMeansMachineLearningModel.h:93
otb::SharkKMeansMachineLearningModel::m_Centroids
shark::Centroids m_Centroids
Definition: otbSharkKMeansMachineLearningModel.h:177
otb::SharkKMeansMachineLearningModel::SharkKMeansMachineLearningModel
SharkKMeansMachineLearningModel()
Definition: otbSharkKMeansMachineLearningModel.hxx:57
otb::SharkKMeansMachineLearningModel::ConfidenceListSampleType
Superclass::ConfidenceListSampleType ConfidenceListSampleType
Definition: otbSharkKMeansMachineLearningModel.h:99
otb::SharkKMeansMachineLearningModel::DoPredictBatch
virtual void DoPredictBatch(const InputListSampleType *, const unsigned int &startIndex, const unsigned int &size, TargetListSampleType *, ConfidenceListSampleType *=nullptr, ProbaListSampleType *=nullptr) const override
Definition: otbSharkKMeansMachineLearningModel.hxx:116
otb::MachineLearningModel< TInputValue, TTargetValue >::m_ConfidenceIndex
bool m_ConfidenceIndex
Definition: otbMachineLearningModel.h:228
otb::SharkKMeansMachineLearningModel::DoPredict
virtual TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Definition: otbSharkKMeansMachineLearningModel.hxx:87