OTB 9.0.0
Orfeo Toolbox
otbFastICAImageFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbFastICAImageFilter_hxx
22 #define otbFastICAImageFilter_hxx
23 
24 #include "otbFastICAImageFilter.h"
25 
26 
27 #include "itkNumericTraits.h"
28 #include "itkProgressReporter.h"
29 
30 #include <vnl/vnl_matrix.h>
31 #include <vnl/algo/vnl_matrix_inverse.h>
32 #include <vnl/algo/vnl_generalized_eigensystem.h>
33 
34 namespace otb
35 {
36 
37 template <class TInputImage, class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
39 {
40  this->SetNumberOfRequiredInputs(1);
41 
42  m_NumberOfPrincipalComponentsRequired = 0;
43 
44  m_GivenTransformationMatrix = false;
45  m_IsTransformationForward = true;
46 
47  m_NumberOfIterations = 50;
48  m_ConvergenceThreshold = 1E-4;
49 
50  m_NonLinearity = [](double x) { return std::tanh(x); };
51  m_NonLinearityDerivative = [](double x) { return 1 - std::pow(std::tanh(x), 2.); };
52 
53  m_Mu = 1.;
54 
55  m_PCAFilter = PCAFilterType::New();
56  m_PCAFilter->SetUseNormalization(true);
57  m_PCAFilter->SetUseVarianceForNormalization(false);
58 
59  m_TransformFilter = TransformFilterType::New();
60 }
61 
62 template <class TInputImage, class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
64 // throw itk::ExceptionObject
65 {
66  Superclass::GenerateOutputInformation();
67 
68  switch (static_cast<int>(DirectionOfTransformation))
69  {
70  case static_cast<int>(Transform::FORWARD):
71  {
72  if (m_NumberOfPrincipalComponentsRequired == 0 || m_NumberOfPrincipalComponentsRequired > this->GetInput()->GetNumberOfComponentsPerPixel())
73  {
74  m_NumberOfPrincipalComponentsRequired = this->GetInput()->GetNumberOfComponentsPerPixel();
75  }
76  m_PCAFilter->SetNumberOfPrincipalComponentsRequired(m_NumberOfPrincipalComponentsRequired);
77  this->GetOutput()->SetNumberOfComponentsPerPixel(m_NumberOfPrincipalComponentsRequired);
78  break;
79  }
80  case static_cast<int>(Transform::INVERSE):
81  {
82  unsigned int theOutputDimension = 0;
83  if (m_GivenTransformationMatrix)
84  {
85  const auto & pcaMatrix = m_PCAFilter->GetTransformationMatrix();
86  theOutputDimension = pcaMatrix.Rows() >= pcaMatrix.Cols() ? pcaMatrix.Rows() : pcaMatrix.Cols();
87  }
88  else
89  {
90  throw itk::ExceptionObject(__FILE__, __LINE__, "Mixture matrix is required to know the output size", ITK_LOCATION);
91  }
92 
93  this->GetOutput()->SetNumberOfComponentsPerPixel(theOutputDimension);
94 
95  break;
96  }
97  default:
98  throw itk::ExceptionObject(__FILE__, __LINE__, "Class should be templeted with FORWARD or INVERSE only...", ITK_LOCATION);
99  }
100 
101  switch (static_cast<int>(DirectionOfTransformation))
102  {
103  case static_cast<int>(Transform::FORWARD):
104  {
105  ForwardGenerateOutputInformation();
106  break;
107  }
108  case static_cast<int>(Transform::INVERSE):
109  {
110  ReverseGenerateOutputInformation();
111  break;
112  }
113  }
114 }
115 
116 template <class TInputImage, class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
118 {
119  typename InputImageType::Pointer inputImgPtr = const_cast<InputImageType*>(this->GetInput());
120 
121  m_PCAFilter->SetInput(inputImgPtr);
122  m_PCAFilter->GetOutput()->UpdateOutputInformation();
123 
124  if (!m_GivenTransformationMatrix)
125  {
126  GenerateTransformationMatrix();
127  }
128  else if (!m_IsTransformationForward)
129  {
130  // prevent from multiple inversion in the pipelines
131  m_IsTransformationForward = true;
132  vnl_svd<MatrixElementType> invertor(m_TransformationMatrix.GetVnlMatrix());
133  m_TransformationMatrix = invertor.pinverse();
134  }
135 
136  if (m_TransformationMatrix.GetVnlMatrix().empty())
137  {
138  throw itk::ExceptionObject(__FILE__, __LINE__, "Empty transformation matrix", ITK_LOCATION);
139  }
140 
141  m_TransformFilter->SetInput(m_PCAFilter->GetOutput());
142  m_TransformFilter->SetMatrix(m_TransformationMatrix.GetVnlMatrix());
143 }
144 
145 template <class TInputImage, class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
147 {
148  if (!m_GivenTransformationMatrix)
149  {
150  throw itk::ExceptionObject(__FILE__, __LINE__, "No Transformation matrix given", ITK_LOCATION);
151  }
152 
153  if (m_TransformationMatrix.GetVnlMatrix().empty())
154  {
155  throw itk::ExceptionObject(__FILE__, __LINE__, "Empty transformation matrix", ITK_LOCATION);
156  }
157 
158  if (m_IsTransformationForward)
159  {
160  // prevent from multiple inversion in the pipelines
161  m_IsTransformationForward = false;
162  vnl_svd<MatrixElementType> invertor(m_TransformationMatrix.GetVnlMatrix());
163  m_TransformationMatrix = invertor.pinverse();
164  }
165 
166  m_TransformFilter->SetInput(this->GetInput());
167  m_TransformFilter->SetMatrix(m_TransformationMatrix.GetVnlMatrix());
168 
169  /*
170  * PCA filter may throw exception if
171  * the mean, stdDev and transformation matrix
172  * have not been given at this point
173  */
174  m_PCAFilter->SetInput(m_TransformFilter->GetOutput());
175 }
176 
177 
178 template <class TInputImage, class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
180 {
181  switch (static_cast<int>(DirectionOfTransformation))
182  {
183  case static_cast<int>(Transform::FORWARD):
184  return ForwardGenerateData();
185  case static_cast<int>(Transform::INVERSE):
186  return ReverseGenerateData();
187  default:
188  throw itk::ExceptionObject(__FILE__, __LINE__, "Class should be templated with FORWARD or INVERSE only...", ITK_LOCATION);
189  }
190 }
191 
192 template <class TInputImage, class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
194 {
195  m_TransformFilter->GraftOutput(this->GetOutput());
196  m_TransformFilter->Update();
197 
198  this->GraftOutput(m_TransformFilter->GetOutput());
199 }
200 
201 template <class TInputImage, class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
203 {
204  m_PCAFilter->GraftOutput(this->GetOutput());
205  m_PCAFilter->Update();
206  this->GraftOutput(m_PCAFilter->GetOutput());
207 }
208 
209 template <class TInputImage, class TOutputImage, Transform::TransformDirection TDirectionOfTransformation>
211 {
212  itk::ProgressReporter reporter(this, 0, GetNumberOfIterations(), GetNumberOfIterations());
213 
214  double convergence = itk::NumericTraits<double>::max();
215  unsigned int iteration = 0;
216 
217  const unsigned int size = this->GetNumberOfPrincipalComponentsRequired();
218 
219  // transformation matrix
220  InternalMatrixType W(size, size, vnl_matrix_identity);
221 
222  while (iteration++ < GetNumberOfIterations() && convergence > GetConvergenceThreshold())
223  {
224  InternalMatrixType W_old(W);
225 
226  typename InputImageType::Pointer img = const_cast<InputImageType*>(m_PCAFilter->GetOutput());
227  TransformFilterPointerType transformer = TransformFilterType::New();
228  if (!W.is_identity())
229  {
230  transformer->SetInput(GetPCAFilter()->GetOutput());
231  transformer->SetMatrix(W);
232  transformer->Update();
233  img = const_cast<InputImageType*>(transformer->GetOutput());
234  }
235 
236  for (unsigned int band = 0; band < size; band++)
237  {
238  otbMsgDebugMacro(<< "Iteration " << iteration << ", bande " << band << ", convergence " << convergence);
239 
240  InternalOptimizerPointerType optimizer = InternalOptimizerType::New();
241  optimizer->SetInput(0, m_PCAFilter->GetOutput());
242  optimizer->SetInput(1, img);
243  optimizer->SetW(W);
244  optimizer->SetNonLinearity(this->GetNonLinearity(), this->GetNonLinearityDerivative());
245  optimizer->SetCurrentBandForLoop(band);
246 
247  MeanEstimatorFilterPointerType estimator = MeanEstimatorFilterType::New();
248  estimator->SetInput(optimizer->GetOutput());
249 
250  // Here we have a pipeline of two persistent filters, we have to manually
251  // call Reset() and Synthetize () on the first one (optimizer).
252  optimizer->Reset();
253  estimator->Update();
254  optimizer->Synthetize();
255 
256  double norm = 0.;
257  for (unsigned int bd = 0; bd < size; bd++)
258  {
259  W(bd, band) -= m_Mu * (estimator->GetMean()[bd] - optimizer->GetBeta() * W(bd, band)) / optimizer->GetDen();
260  norm += std::pow(W(bd, band), 2.);
261  }
262  for (unsigned int bd = 0; bd < size; bd++)
263  W(bd, band) /= std::sqrt(norm);
264  }
265 
266  // Decorrelation of the W vectors
267  InternalMatrixType W_tmp = W * W.transpose();
268  vnl_svd<MatrixElementType> solver(W_tmp);
269  InternalMatrixType valP = solver.W();
270  for (unsigned int i = 0; i < valP.rows(); ++i)
271  valP(i, i) = 1. / std::sqrt(static_cast<double>(valP(i, i))); // Watch for 0 or neg
272  InternalMatrixType transf = solver.U();
273  W_tmp = transf * valP * transf.transpose();
274  W = W_tmp * W;
275 
276  // Convergence evaluation
277  convergence = 0.;
278  for (unsigned int i = 0; i < W.rows(); ++i)
279  for (unsigned int j = 0; j < W.cols(); ++j)
280  convergence += std::abs(W(i, j) - W_old(i, j));
281 
282  reporter.CompletedPixel();
283  } // end of while loop
284 
285  this->m_TransformationMatrix = W;
286 
287  otbMsgDebugMacro(<< "Final convergence " << convergence << " after " << iteration << " iterations");
288 }
289 
290 } // end of namespace otb
291 
292 #endif
otb::FastICAImageFilter::ForwardGenerateData
virtual void ForwardGenerateData()
Definition: otbFastICAImageFilter.hxx:193
otb::FastICAImageFilter::InputImageType
TInputImage InputImageType
Definition: otbFastICAImageFilter.h:71
otb::FastICAImageFilter::InternalOptimizerPointerType
InternalOptimizerType::Pointer InternalOptimizerPointerType
Definition: otbFastICAImageFilter.h:87
otb::FastICAImageFilter::GenerateTransformationMatrix
virtual void GenerateTransformationMatrix()
Definition: otbFastICAImageFilter.hxx:210
otb::FastICAImageFilter::GenerateOutputInformation
void GenerateOutputInformation() override
Definition: otbFastICAImageFilter.hxx:63
otb::FastICAImageFilter::ReverseGenerateOutputInformation
void ReverseGenerateOutputInformation()
Definition: otbFastICAImageFilter.hxx:146
otb::FastICAImageFilter::ReverseGenerateData
virtual void ReverseGenerateData()
Definition: otbFastICAImageFilter.hxx:202
otb::FastICAImageFilter::InternalMatrixType
MatrixType::InternalMatrixType InternalMatrixType
Definition: otbFastICAImageFilter.h:80
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::FastICAImageFilter::TransformFilterPointerType
TransformFilterType::Pointer TransformFilterPointerType
Definition: otbFastICAImageFilter.h:84
otb::Transform::FORWARD
@ FORWARD
Definition: otbPCAImageFilter.h:36
otb::FastICAImageFilter::MeanEstimatorFilterPointerType
MeanEstimatorFilterType::Pointer MeanEstimatorFilterPointerType
Definition: otbFastICAImageFilter.h:90
otb::FastICAImageFilter::FastICAImageFilter
FastICAImageFilter()
Definition: otbFastICAImageFilter.hxx:38
otbFastICAImageFilter.h
otb::FastICAImageFilter::GenerateData
void GenerateData() override
Definition: otbFastICAImageFilter.hxx:179
otb::Transform::INVERSE
@ INVERSE
Definition: otbPCAImageFilter.h:37
otbMsgDebugMacro
#define otbMsgDebugMacro(x)
Definition: otbMacro.h:62
otb::FastICAImageFilter::ForwardGenerateOutputInformation
void ForwardGenerateOutputInformation()
Definition: otbFastICAImageFilter.hxx:117