OTB 6.7.0
Orfeo Toolbox
otbAutoencoderModel.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 #ifndef otbAutoencoderModel_h
21 #define otbAutoencoderModel_h
22 
25 #include <string>
26 
27 #if defined(__GNUC__) || defined(__clang__)
28 #pragma GCC diagnostic push
29 #pragma GCC diagnostic ignored "-Wshadow"
30 #pragma GCC diagnostic ignored "-Wunused-parameter"
31 #pragma GCC diagnostic ignored "-Woverloaded-virtual"
32 #pragma GCC diagnostic ignored "-Wsign-compare"
33 #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
34 #if defined(__clang__)
35 #pragma clang diagnostic ignored "-Wheader-guard"
36 #pragma clang diagnostic ignored "-Wdivision-by-zero"
37 #pragma clang diagnostic ignored "-Wexpansion-to-defined"
38 #else
39 #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
40 #endif
41 #endif
42 #include "otb_shark.h"
43 #include <shark/Algorithms/StoppingCriteria/AbstractStoppingCriterion.h>
44 #include <shark/Models/LinearModel.h>
45 #include <shark/Models/ConcatenatedModel.h>
46 #include <shark/Models/NeuronLayers.h>
47 #if defined(__GNUC__) || defined(__clang__)
48 #pragma GCC diagnostic pop
49 #endif
50 
51 namespace otb
52 {
60 template <class TInputValue, class NeuronType> // NeuronType: activation used by the hidden-layer type (LayerType); the output layer type (OutLayerType) always uses shark::LinearNeuron
61 class ITK_EXPORT AutoencoderModel // Autoencoder model wrapper built on Shark; input and target samples are both itk::VariableLengthVector<TInputValue>
62  : public MachineLearningModel<
63  itk::VariableLengthVector< TInputValue>,
64  itk::VariableLengthVector< TInputValue> >
65 {
66 public:
68  typedef MachineLearningModel< // NOTE(review): listing lines 69-72 are missing from this copy; they hold the rest of this Superclass typedef and, presumably, the Self/Pointer/ConstPointer typedefs (Self is used by itkNewMacro below)
73
74  typedef typename Superclass::InputValueType InputValueType;
75  typedef typename Superclass::InputSampleType InputSampleType;
76  typedef typename Superclass::InputListSampleType InputListSampleType;
77  typedef typename InputListSampleType::Pointer ListSamplePointerType;
78  typedef typename Superclass::TargetValueType TargetValueType;
80  typedef typename Superclass::TargetListSampleType TargetListSampleType; // NOTE(review): listing line 79 is missing; presumably the TargetSampleType typedef used by DoPredict below
81
83  typedef typename Superclass::ConfidenceValueType ConfidenceValueType; // Confidence map related typedefs
84  typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
85  typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
86
87  typedef typename Superclass::ProbaSampleType ProbaSampleType;
88  typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
90  typedef shark::ConcatenatedModel<shark::RealVector> ModelType; // Neural network related typedefs: type of the complete network
91  typedef shark::LinearModel<shark::RealVector,NeuronType> LayerType; // hidden layer type, parameterised by the NeuronType activation
92  typedef shark::LinearModel<shark::RealVector, shark::LinearNeuron> OutLayerType; // output layer type, always with a linear activation
93
94  itkNewMacro(Self);
95  itkTypeMacro(AutoencoderModel, DimensionalityReductionModel); // NOTE(review): names DimensionalityReductionModel as parent, but the actual base class is MachineLearningModel - confirm this is intended
96
97  itkGetMacro(NumberOfHiddenNeurons,itk::Array<unsigned int>); // presumably one neuron count per hidden layer - confirm in otbAutoencoderModel.hxx
98  itkSetMacro(NumberOfHiddenNeurons,itk::Array<unsigned int>);
99
100  itkGetMacro(NumberOfIterations,unsigned int);
101  itkSetMacro(NumberOfIterations,unsigned int);
102
103  itkGetMacro(NumberOfIterationsFineTuning,unsigned int);
104  itkSetMacro(NumberOfIterationsFineTuning,unsigned int);
105
106  itkGetMacro(Epsilon,double);
107  itkSetMacro(Epsilon,double);
108
109  itkGetMacro(InitFactor,double);
110  itkSetMacro(InitFactor,double);
111
112  itkGetMacro(Regularization,itk::Array<double>);
113  itkSetMacro(Regularization,itk::Array<double>);
114
115  itkGetMacro(Noise,itk::Array<double>);
116  itkSetMacro(Noise,itk::Array<double>);
117
118  itkGetMacro(Rho,itk::Array<double>);
119  itkSetMacro(Rho,itk::Array<double>);
120
121  itkGetMacro(Beta,itk::Array<double>);
122  itkSetMacro(Beta,itk::Array<double>);
123
124  itkGetMacro(WriteLearningCurve,bool);
125  itkSetMacro(WriteLearningCurve,bool);
126
127  itkSetMacro(WriteWeights, bool); // NOTE(review): the m_WriteWeights member backing these accessors is not visible in this copy of the listing
128  itkGetMacro(WriteWeights, bool);
129
130  itkGetMacro(LearningCurveFileName,std::string);
131  itkSetMacro(LearningCurveFileName,std::string);
132
133  bool CanReadFile(const std::string & filename) override; // overrides of the MachineLearningModel file I/O interface
134  bool CanWriteFile(const std::string & filename) override;
135
136  void Save(const std::string & filename, const std::string & name="") override; // name defaults to the empty string
137  void Load(const std::string & filename, const std::string & name="") override;
138
139  void Train() override;
140
141  template <class T>
142  void TrainOneLayer( // unnamed parameters are presumably the layer index, the training samples and the learning-curve stream - confirm in otbAutoencoderModel.hxx
143  shark::AbstractStoppingCriterion<T> & criterion,
144  unsigned int,
145  shark::Data<shark::RealVector> &,
146  std::ostream&);
147
148  template <class T>
149  void TrainOneSparseLayer( // same signature as TrainOneLayer; presumably the sparse variant using m_Rho/m_Beta - confirm
150  shark::AbstractStoppingCriterion<T> & criterion,
151  unsigned int,
152  shark::Data<shark::RealVector> &,
153  std::ostream&);
154
155  template <class T>
156  void TrainNetwork( // presumably the whole-network (fine-tuning) step, cf. m_NumberOfIterationsFineTuning - confirm
157  shark::AbstractStoppingCriterion<T> & criterion,
158  shark::Data<shark::RealVector> &,
159  std::ostream&);
160
161 protected:
163  ~AutoencoderModel() override; // NOTE(review): listing line 162 (presumably the protected constructor) is missing from this copy
164
165  virtual TargetSampleType DoPredict(
166  const InputSampleType& input,
167  ConfidenceValueType * quality = nullptr,
168  ProbaSampleType * proba = nullptr) const override;
169
170  virtual void DoPredictBatch(
171  const InputListSampleType *,
172  const unsigned int & startIndex,
173  const unsigned int & size, // NOTE(review): listing line 174 (presumably the TargetListSampleType * output parameter) is missing from this copy
175  ConfidenceListSampleType * quality = nullptr,
176  ProbaListSampleType * proba = nullptr) const override;
177
178 private:
181  std::vector<LayerType> m_InLayers; // NOTE(review): listing lines 179-180 and 182-185 are missing; per the symbol index they include m_NumberOfHiddenNeurons, and presumably the network (ModelType) and output layer (OutLayerType) members
184
186  unsigned int m_NumberOfIterations; // stop the training after a fixed number of iterations
187  unsigned int m_NumberOfIterationsFineTuning; // stop the fine tuning after a fixed number of iterations
188  double m_Epsilon; // Stops the training when the training error seems to converge
189  itk::Array<double> m_Regularization; // L2 Regularization parameter
190  itk::Array<double> m_Noise; // probability for an input to be set to 0 (denoising autoencoder)
191  itk::Array<double> m_Rho; // Sparsity parameter
192  itk::Array<double> m_Beta; // Sparsity regularization parameter
193  double m_InitFactor; // Weight initialization factor (the weights are initialized at m_InitFactor/sqrt(inputDimension) )
195
196  bool m_WriteLearningCurve; // Flag for writing the learning curve into a txt file
197  std::string m_LearningCurveFileName; // Name of the output learning curve printed after training (NOTE(review): listing lines 194 and 198 are missing from this copy)
199 };
200 } // end namespace otb
201 
202 #ifndef OTB_MANUAL_INSTANTIATION
203 #include "otbAutoencoderModel.hxx"
204 #endif
205 
206 #endif
207 
Superclass::ConfidenceValueType ConfidenceValueType
Confidence map related typedefs.
shark::LinearModel< shark::RealVector, shark::LinearNeuron > OutLayerType
itk::Array< double > m_Noise
Superclass::TargetListSampleType TargetListSampleType
MachineLearningModel is the base class for all classifier objects (SVM, KNN, Random Forests...
Superclass::TargetSampleType TargetSampleType
itk::Array< double > m_Regularization
Superclass::ConfidenceListSampleType ConfidenceListSampleType
std::vector< LayerType > m_InLayers
unsigned int m_NumberOfIterationsFineTuning
Superclass::ProbaSampleType ProbaSampleType
itk::Array< double > m_Beta
Superclass::ConfidenceSampleType ConfidenceSampleType
Superclass::ProbaListSampleType ProbaListSampleType
MachineLearningModel< itk::VariableLengthVector< TInputValue >, itk::VariableLengthVector< TInputValue > > Superclass
shark::ConcatenatedModel< shark::RealVector > ModelType
Neural network related typedefs.
Superclass::InputListSampleType InputListSampleType
Superclass::TargetValueType TargetValueType
itk::Array< unsigned int > m_NumberOfHiddenNeurons
itk::SmartPointer< const Self > ConstPointer
itk::SmartPointer< Self > Pointer
InputListSampleType::Pointer ListSamplePointerType
Superclass::InputSampleType InputSampleType
Superclass::InputValueType InputValueType
shark::LinearModel< shark::RealVector, NeuronType > LayerType
itk::Array< double > m_Rho
MLMTargetTraits< TTargetValue >::SampleType TargetSampleType