OTB  6.7.0
Orfeo Toolbox
otbFastICAImageFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbFastICAImageFilter_hxx
22 #define otbFastICAImageFilter_hxx
23 
24 #include "otbFastICAImageFilter.h"
25 
26 
27 #include "itkNumericTraits.h"
28 #include "itkProgressReporter.h"
29 
30 #include <vnl/vnl_matrix.h>
31 #include <vnl/algo/vnl_matrix_inverse.h>
32 #include <vnl/algo/vnl_generalized_eigensystem.h>
33 
34 namespace otb
35 {
36 
37 template < class TInputImage, class TOutputImage,
38  Transform::TransformDirection TDirectionOfTransformation >
41 {
42  this->SetNumberOfRequiredInputs(1);
43 
44  m_NumberOfPrincipalComponentsRequired = 0;
45 
46  m_GivenTransformationMatrix = false;
47  m_IsTransformationForward = true;
48 
49  m_NumberOfIterations = 50;
50  m_ConvergenceThreshold = 1E-4;
51 
52  m_NonLinearity = [](double x) {return std::tanh(x);};
53  m_NonLinearityDerivative = [](double x) {return 1-std::pow( std::tanh(x), 2. );};
54 
55  m_Mu = 1.;
56 
57  m_PCAFilter = PCAFilterType::New();
58  m_PCAFilter->SetUseNormalization(true);
59  m_PCAFilter->SetUseVarianceForNormalization(false);
60 
61  m_TransformFilter = TransformFilterType::New();
62 }
63 
64 template < class TInputImage, class TOutputImage,
65  Transform::TransformDirection TDirectionOfTransformation >
66 void
69 // throw itk::ExceptionObject
70 {
71  Superclass::GenerateOutputInformation();
72 
73  switch ( static_cast<int>(DirectionOfTransformation) )
74  {
75  case static_cast<int>(Transform::FORWARD):
76  {
77  if ( m_NumberOfPrincipalComponentsRequired == 0
78  || m_NumberOfPrincipalComponentsRequired
79  > this->GetInput()->GetNumberOfComponentsPerPixel() )
80  {
81  m_NumberOfPrincipalComponentsRequired =
82  this->GetInput()->GetNumberOfComponentsPerPixel();
83  }
84  m_PCAFilter->SetNumberOfPrincipalComponentsRequired(
85  m_NumberOfPrincipalComponentsRequired);
86  this->GetOutput()->SetNumberOfComponentsPerPixel(
87  m_NumberOfPrincipalComponentsRequired );
88  break;
89  }
90  case static_cast<int>(Transform::INVERSE):
91  {
92  unsigned int theOutputDimension = 0;
93  if ( m_GivenTransformationMatrix )
94  {
95  theOutputDimension = m_TransformationMatrix.Rows() >= m_TransformationMatrix.Cols() ?
96  m_TransformationMatrix.Rows() : m_TransformationMatrix.Cols();
97  }
98  else
99  {
100  throw itk::ExceptionObject(__FILE__, __LINE__,
101  "Mixture matrix is required to know the output size",
102  ITK_LOCATION);
103  }
104 
105  this->GetOutput()->SetNumberOfComponentsPerPixel( theOutputDimension );
106 
107  break;
108  }
109  default:
110  throw itk::ExceptionObject(__FILE__, __LINE__,
111  "Class should be templeted with FORWARD or INVERSE only...",
112  ITK_LOCATION );
113  }
114 
115  switch ( static_cast<int>(DirectionOfTransformation) )
116  {
117  case static_cast<int>(Transform::FORWARD):
118  {
119  ForwardGenerateOutputInformation();
120  break;
121  }
122  case static_cast<int>(Transform::INVERSE):
123  {
124  ReverseGenerateOutputInformation();
125  break;
126  }
127  }
128 }
129 
130 template < class TInputImage, class TOutputImage,
131  Transform::TransformDirection TDirectionOfTransformation >
132 void
135 {
136  typename InputImageType::Pointer inputImgPtr
137  = const_cast<InputImageType*>( this->GetInput() );
138 
139  m_PCAFilter->SetInput( inputImgPtr );
140  m_PCAFilter->GetOutput()->UpdateOutputInformation();
141 
142  if ( !m_GivenTransformationMatrix )
143  {
144  GenerateTransformationMatrix();
145  }
146  else if ( !m_IsTransformationForward )
147  {
148  // prevent from multiple inversion in the pipelines
149  m_IsTransformationForward = true;
150  vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
151  m_TransformationMatrix = invertor.pinverse();
152  }
153 
154  if ( m_TransformationMatrix.GetVnlMatrix().empty() )
155  {
156  throw itk::ExceptionObject( __FILE__, __LINE__,
157  "Empty transformation matrix",
158  ITK_LOCATION);
159  }
160 
161  m_TransformFilter->SetInput( m_PCAFilter->GetOutput() );
162  m_TransformFilter->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
163 }
164 
165 template < class TInputImage, class TOutputImage,
166  Transform::TransformDirection TDirectionOfTransformation >
167 void
170 {
171  if ( !m_GivenTransformationMatrix )
172  {
173  throw itk::ExceptionObject( __FILE__, __LINE__,
174  "No Transformation matrix given",
175  ITK_LOCATION );
176  }
177 
178  if ( m_TransformationMatrix.GetVnlMatrix().empty() )
179  {
180  throw itk::ExceptionObject( __FILE__, __LINE__,
181  "Empty transformation matrix",
182  ITK_LOCATION);
183  }
184 
185  if ( m_IsTransformationForward )
186  {
187  // prevent from multiple inversion in the pipelines
188  m_IsTransformationForward = false;
189  vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
190  m_TransformationMatrix = invertor.pinverse();
191  }
192 
193  m_TransformFilter->SetInput( this->GetInput() );
194  m_TransformFilter->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
195 
196  /*
197  * PCA filter may throw exception if
198  * the mean, stdDev and transformation matrix
199  * have not been given at this point
200  */
201  m_PCAFilter->SetInput( m_TransformFilter->GetOutput() );
202 }
203 
204 
205 template < class TInputImage, class TOutputImage,
206  Transform::TransformDirection TDirectionOfTransformation >
207 void
210 {
211  switch ( static_cast<int>(DirectionOfTransformation) )
212  {
213  case static_cast<int>(Transform::FORWARD):
214  return ForwardGenerateData();
215  case static_cast<int>(Transform::INVERSE):
216  return ReverseGenerateData();
217  default:
218  throw itk::ExceptionObject(__FILE__, __LINE__,
219  "Class should be templated with FORWARD or INVERSE only...",
220  ITK_LOCATION );
221  }
222 }
223 
224 template < class TInputImage, class TOutputImage,
225  Transform::TransformDirection TDirectionOfTransformation >
226 void
229 {
230  m_TransformFilter->GraftOutput( this->GetOutput() );
231  m_TransformFilter->Update();
232 
233  this->GraftOutput( m_TransformFilter->GetOutput() );
234 }
235 
236 template < class TInputImage, class TOutputImage,
237  Transform::TransformDirection TDirectionOfTransformation >
238 void
241 {
242  m_PCAFilter->GraftOutput( this->GetOutput() );
243  m_PCAFilter->Update();
244  this->GraftOutput( m_PCAFilter->GetOutput() );
245 }
246 
247 template < class TInputImage, class TOutputImage,
248  Transform::TransformDirection TDirectionOfTransformation >
249 void
252 {
253  itk::ProgressReporter reporter ( this, 0, GetNumberOfIterations(), GetNumberOfIterations() );
254 
255  double convergence = itk::NumericTraits<double>::max();
256  unsigned int iteration = 0;
257 
258  const unsigned int size = this->GetNumberOfPrincipalComponentsRequired();
259 
260  // transformation matrix
261  InternalMatrixType W ( size, size, vnl_matrix_identity );
262 
263  while ( iteration++ < GetNumberOfIterations()
264  && convergence > GetConvergenceThreshold() )
265  {
266  InternalMatrixType W_old ( W );
267 
268  typename InputImageType::Pointer img = const_cast<InputImageType*>( m_PCAFilter->GetOutput() );
269  TransformFilterPointerType transformer = TransformFilterType::New();
270  if ( !W.is_identity() )
271  {
272  transformer->SetInput( GetPCAFilter()->GetOutput() );
273  transformer->SetMatrix( W );
274  transformer->Update();
275  img = const_cast<InputImageType*>( transformer->GetOutput() );
276  }
277 
278  for ( unsigned int band = 0; band < size; band++ )
279  {
280  otbMsgDebugMacro( << "Iteration " << iteration << ", bande " << band
281  << ", convergence " << convergence );
282 
283  InternalOptimizerPointerType optimizer = InternalOptimizerType::New();
284  optimizer->SetInput( 0, m_PCAFilter->GetOutput() );
285  optimizer->SetInput( 1, img );
286  optimizer->SetW( W );
287  optimizer->SetNonLinearity( this->GetNonLinearity(),
288  this->GetNonLinearityDerivative() );
289  optimizer->SetCurrentBandForLoop( band );
290 
291  MeanEstimatorFilterPointerType estimator = MeanEstimatorFilterType::New();
292  estimator->SetInput( optimizer->GetOutput() );
293 
294  // Here we have a pipeline of two persistent filters, we have to manually
295  // call Reset() and Synthetize () on the first one (optimizer).
296  optimizer->Reset();
297  estimator->Update();
298  optimizer->Synthetize();
299 
300  double norm = 0.;
301  for ( unsigned int bd = 0; bd < size; bd++ )
302  {
303  W(bd, band) -= m_Mu * ( estimator->GetMean()[bd]
304  - optimizer->GetBeta() * W(bd, band) )
305  / optimizer->GetDen();
306  norm += std::pow( W(bd, band), 2. );
307  }
308  for ( unsigned int bd = 0; bd < size; bd++ )
309  W(bd, band) /= std::sqrt( norm );
310  }
311 
312  // Decorrelation of the W vectors
313  InternalMatrixType W_tmp = W * W.transpose();
314  vnl_svd< MatrixElementType > solver ( W_tmp );
315  InternalMatrixType valP = solver.W();
316  for ( unsigned int i = 0; i < valP.rows(); ++i )
317  valP(i, i) = 1. / std::sqrt( static_cast<double>( valP(i, i) ) ); // Watch for 0 or neg
318  InternalMatrixType transf = solver.U();
319  W_tmp = transf * valP * transf.transpose();
320  W = W_tmp * W;
321 
322  // Convergence evaluation
323  convergence = 0.;
324  for ( unsigned int i = 0; i < W.rows(); ++i )
325  for ( unsigned int j = 0; j < W.cols(); ++j )
326  convergence += std::abs( W(i, j) - W_old(i, j) );
327 
328  reporter.CompletedPixel();
329  } // end of while loop
330 
331  this->m_TransformationMatrix = W;
332 
333  otbMsgDebugMacro( << "Final convergence " << convergence
334  << " after " << iteration << " iterations" );
335 }
336 
337 } // end of namespace otb
338 
339 #endif
340 
341 
virtual void GenerateTransformationMatrix()
void GenerateOutputInformation() override
static ITK_CONSTEXPR_FUNC T max(const T &)
#define otbMsgDebugMacro(x)
Definition: otbMacro.h:64
TInputImage InputImageType
MatrixType::InternalMatrixType InternalMatrixType