OTB  9.0.0
Orfeo Toolbox
otbPCAModel.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbPCAModel_hxx
21 #define otbPCAModel_hxx
22 
23 #include "otbPCAModel.h"
24 
25 #include <fstream>
26 #include "itkMacro.h"
27 #if defined(__GNUC__) || defined(__clang__)
28 #pragma GCC diagnostic push
29 
30 #if (defined (__GNUC__) && (__GNUC__ >= 9)) || (defined (__clang__) && (__clang_major__ >= 10))
31 #pragma GCC diagnostic ignored "-Wdeprecated-copy"
32 #endif
33 
34 #pragma GCC diagnostic ignored "-Wshadow"
35 #pragma GCC diagnostic ignored "-Wunused-parameter"
36 #pragma GCC diagnostic ignored "-Woverloaded-virtual"
37 #endif
38 #include "otbSharkUtils.h"
39 // include train function
40 #include <shark/ObjectiveFunctions/ErrorFunction.h>
41 #include <shark/Algorithms/GradientDescent/Rprop.h> // the RProp optimization algorithm
42 #include <shark/ObjectiveFunctions/Loss/SquaredLoss.h> // squared loss used for regression
43 #include <shark/ObjectiveFunctions/Regularizer.h> //L2 regulariziation
44 #include <shark/ObjectiveFunctions/ErrorFunction.h>
45 #if defined(__GNUC__) || defined(__clang__)
46 #pragma GCC diagnostic pop
47 #endif
48 
49 namespace otb
50 {
51 
52 template <class TInputValue>
54 {
55  this->m_IsDoPredictBatchMultiThreaded = true;
56  this->m_Dimension = 0;
57 }
58 
59 template <class TInputValue>
61 {
62 }
63 
64 template <class TInputValue>
66 {
67  std::vector<shark::RealVector> features;
68 
69  Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
70 
71  shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
72  m_PCA.setData(inputSamples);
73  m_PCA.encoder(m_Encoder, this->m_Dimension);
74  m_PCA.decoder(m_Decoder, this->m_Dimension);
75 }
76 
77 template <class TInputValue>
78 bool PCAModel<TInputValue>::CanReadFile(const std::string& filename)
79 {
80  try
81  {
82  this->Load(filename);
83  m_Encoder.name();
84  }
85  catch (...)
86  {
87  return false;
88  }
89  return true;
90 }
91 
92 template <class TInputValue>
93 bool PCAModel<TInputValue>::CanWriteFile(const std::string& /*filename*/)
94 {
95  return true;
96 }
97 
98 template <class TInputValue>
99 void PCAModel<TInputValue>::Save(const std::string& filename, const std::string& /*name*/)
100 {
101  std::ofstream ofs(filename);
102  ofs << "pca" << std::endl; // first line
103  shark::TextOutArchive oa(ofs);
104  m_Encoder.write(oa);
105  ofs.close();
106 
107  if (this->m_WriteEigenvectors == true) // output the map vectors in a txt file
108  {
109  std::ofstream otxt(filename + ".txt");
110 
111  otxt << "Eigenvectors : " << m_PCA.eigenvectors() << std::endl;
112  otxt << "Eigenvalues : " << m_PCA.eigenvalues() << std::endl;
113 
114  std::vector<shark::RealVector> features;
115 
116  shark::SquaredLoss<shark::RealVector> loss;
117  Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
118  shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
119  otxt << "Reconstruction error : " << loss.eval(inputSamples, m_Decoder(m_Encoder(inputSamples))) << std::endl;
120  otxt.close();
121  }
122 }
123 
124 template <class TInputValue>
125 void PCAModel<TInputValue>::Load(const std::string& filename, const std::string& /*name*/)
126 {
127  std::ifstream ifs(filename);
128  char encoder[256];
129  ifs.getline(encoder, 256);
130  std::string encoderstr(encoder);
131 
132  if (encoderstr != "pca")
133  {
134  itkExceptionMacro(<< "Error opening " << filename.c_str());
135  }
136  shark::TextInArchive ia(ifs);
137  m_Encoder.read(ia);
138  ifs.close();
139  if (this->m_Dimension == 0)
140  {
141  this->m_Dimension = m_Encoder.outputShape()[0];
142  }
143 
144  auto eigenvectors = m_Encoder.matrix();
145  eigenvectors.resize(this->m_Dimension, m_Encoder.inputShape()[0]);
146 
147  m_Encoder.setStructure(eigenvectors, m_Encoder.offset());
148 }
149 
150 template <class TInputValue>
152  ProbaSampleType* /*proba*/) const
153 {
154  shark::RealVector samples(value.Size());
155  for (size_t i = 0; i < value.Size(); i++)
156  {
157  samples[i] = value[i];
158  }
159 
160  std::vector<shark::RealVector> features;
161  features.push_back(samples);
162 
163  shark::Data<shark::RealVector> data = shark::createDataFromRange(features);
164 
165  data = m_Encoder(data);
166  TargetSampleType target;
167  target.SetSize(this->m_Dimension);
168 
169  for (unsigned int a = 0; a < this->m_Dimension; ++a)
170  {
171  target[a] = data.element(0)[a];
172  }
173  return target;
174 }
175 
176 template <class TInputValue>
177 void PCAModel<TInputValue>::DoPredictBatch(const InputListSampleType* input, const unsigned int& startIndex, const unsigned int& size,
178  TargetListSampleType* targets, ConfidenceListSampleType* /*quality*/, ProbaListSampleType* /*proba*/) const
179 {
180  std::vector<shark::RealVector> features;
181  Shark::ListSampleRangeToSharkVector(input, features, startIndex, size);
182  shark::Data<shark::RealVector> data = shark::createDataFromRange(features);
183  TargetSampleType target;
184  data = m_Encoder(data);
185  unsigned int id = startIndex;
186  target.SetSize(this->m_Dimension);
187  for (const auto& p : data.elements())
188  {
189  for (unsigned int a = 0; a < this->m_Dimension; ++a)
190  {
191  target[a] = p[a];
192  }
193  targets->SetMeasurementVector(id, target);
194  ++id;
195  }
196 }
197 
198 } // namespace otb
199 #endif
otb::PCAModel::TargetSampleType
Superclass::TargetSampleType TargetSampleType
Definition: otbPCAModel.h:86
otb::PCAModel::ProbaListSampleType
Superclass::ProbaListSampleType ProbaListSampleType
Definition: otbPCAModel.h:95
otbPCAModel.h
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::PCAModel::TargetListSampleType
Superclass::TargetListSampleType TargetListSampleType
Definition: otbPCAModel.h:87
otb::PCAModel::CanWriteFile
bool CanWriteFile(const std::string &filename) override
Definition: otbPCAModel.hxx:93
otb::PCAModel::ConfidenceListSampleType
Superclass::ConfidenceListSampleType ConfidenceListSampleType
Definition: otbPCAModel.h:92
otb::PCAModel::DoPredictBatch
virtual void DoPredictBatch(const InputListSampleType *, const unsigned int &startIndex, const unsigned int &size, TargetListSampleType *, ConfidenceListSampleType *quality=nullptr, ProbaListSampleType *proba=nullptr) const override
Definition: otbPCAModel.hxx:177
otb::PCAModel::Save
void Save(const std::string &filename, const std::string &name="") override
Definition: otbPCAModel.hxx:99
otb::PCAModel::ProbaSampleType
Superclass::ProbaSampleType ProbaSampleType
Definition: otbPCAModel.h:94
otb::PCAModel::ConfidenceValueType
Superclass::ConfidenceValueType ConfidenceValueType
Definition: otbPCAModel.h:90
otb::PCAModel::DoPredict
virtual TargetSampleType DoPredict(const InputSampleType &input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override
Definition: otbPCAModel.hxx:151
otb::PCAModel::PCAModel
PCAModel()
Definition: otbPCAModel.hxx:53
otb::PCAModel::InputSampleType
Superclass::InputSampleType InputSampleType
Definition: otbPCAModel.h:82
otb::PCAModel::CanReadFile
bool CanReadFile(const std::string &filename) override
Definition: otbPCAModel.hxx:78
otb::PCAModel::Load
void Load(const std::string &filename, const std::string &name="") override
Definition: otbPCAModel.hxx:125
otb::PCAModel::InputListSampleType
Superclass::InputListSampleType InputListSampleType
Definition: otbPCAModel.h:83
otb::PCAModel::Train
void Train() override
Definition: otbPCAModel.hxx:65
otb::PCAModel::~PCAModel
~PCAModel() override
Definition: otbPCAModel.hxx:60