Orfeo Toolbox 4.0
otbFastICAImageFilter.txx
Go to the documentation of this file.
1 /*=========================================================================
2 
3  Program: ORFEO Toolbox
4  Language: C++
5  Date: $Date$
6  Version: $Revision$
7 
8 
9  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
10  See OTBCopyright.txt for details.
11 
12 
13  This software is distributed WITHOUT ANY WARRANTY; without even
14  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
15  PURPOSE. See the above copyright notices for more information.
16 
17 =========================================================================*/
18 #ifndef __otbFastICAImageFilter_txx
19 #define __otbFastICAImageFilter_txx
20 
21 #include "otbFastICAImageFilter.h"
22 
23 #include "otbMacro.h"
24 
25 
26 #include "itkNumericTraits.h"
27 #include "itkProgressReporter.h"
28 
29 #include <vnl/vnl_matrix.h>
30 #include <vnl/algo/vnl_matrix_inverse.h>
31 #include <vnl/algo/vnl_generalized_eigensystem.h>
32 
33 namespace otb
34 {
35 
36 template < class TInputImage, class TOutputImage,
37  Transform::TransformDirection TDirectionOfTransformation >
40 {
41  this->SetNumberOfRequiredInputs(1);
42 
43  m_NumberOfPrincipalComponentsRequired = 0;
44 
45  m_GivenTransformationMatrix = false;
46  m_IsTransformationForward = true;
47 
48  m_NumberOfIterations = 50;
49  m_ConvergenceThreshold = 1E-4;
50  m_ContrastFunction = &vcl_tanh;
51  m_Mu = 1.;
52 
53  m_PCAFilter = PCAFilterType::New();
54  m_PCAFilter->SetUseNormalization(true);
55 
56  m_TransformFilter = TransformFilterType::New();
57 }
58 
59 template < class TInputImage, class TOutputImage,
60  Transform::TransformDirection TDirectionOfTransformation >
61 void
64 // throw itk::ExceptionObject
65 {
66  Superclass::GenerateOutputInformation();
67 
68  switch ( static_cast<int>(DirectionOfTransformation) )
69  {
70  case static_cast<int>(Transform::FORWARD):
71  {
72  if ( m_NumberOfPrincipalComponentsRequired == 0
73  || m_NumberOfPrincipalComponentsRequired
74  > this->GetInput()->GetNumberOfComponentsPerPixel() )
75  {
76  m_NumberOfPrincipalComponentsRequired =
77  this->GetInput()->GetNumberOfComponentsPerPixel();
78  }
79 
80  this->GetOutput()->SetNumberOfComponentsPerPixel(
81  m_NumberOfPrincipalComponentsRequired );
82  break;
83  }
84  case static_cast<int>(Transform::INVERSE):
85  {
86  unsigned int theOutputDimension = 0;
87  if ( m_GivenTransformationMatrix )
88  {
89  theOutputDimension = m_TransformationMatrix.Rows() >= m_TransformationMatrix.Cols() ?
90  m_TransformationMatrix.Rows() : m_TransformationMatrix.Cols();
91  }
92  else
93  {
94  throw itk::ExceptionObject(__FILE__, __LINE__,
95  "Mixture matrix is required to know the output size",
96  ITK_LOCATION);
97  }
98 
99  this->GetOutput()->SetNumberOfComponentsPerPixel( theOutputDimension );
100 
101  break;
102  }
103  default:
104  throw itk::ExceptionObject(__FILE__, __LINE__,
105  "Class should be templeted with FORWARD or INVERSE only...",
106  ITK_LOCATION );
107  }
108 }
109 
110 template < class TInputImage, class TOutputImage,
111  Transform::TransformDirection TDirectionOfTransformation >
112 void
115 {
116  switch ( static_cast<int>(DirectionOfTransformation) )
117  {
118  case static_cast<int>(Transform::FORWARD):
119  return ForwardGenerateData();
120  case static_cast<int>(Transform::INVERSE):
121  return ReverseGenerateData();
122  default:
123  throw itk::ExceptionObject(__FILE__, __LINE__,
124  "Class should be templated with FORWARD or INVERSE only...",
125  ITK_LOCATION );
126  }
127 }
128 
129 template < class TInputImage, class TOutputImage,
130  Transform::TransformDirection TDirectionOfTransformation >
131 void
134 {
135  typename InputImageType::Pointer inputImgPtr
136  = const_cast<InputImageType*>( this->GetInput() );
137 
138  m_PCAFilter->SetInput( inputImgPtr );
139  m_PCAFilter->Update();
140 
141  if ( !m_GivenTransformationMatrix )
142  {
143  GenerateTransformationMatrix();
144  }
145  else if ( !m_IsTransformationForward )
146  {
147  // prevent from multiple inversion in the pipelines
148  m_IsTransformationForward = true;
149  vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
150  m_TransformationMatrix = invertor.pinverse();
151  }
152 
153  if ( m_TransformationMatrix.GetVnlMatrix().empty() )
154  {
155  throw itk::ExceptionObject( __FILE__, __LINE__,
156  "Empty transformation matrix",
157  ITK_LOCATION);
158  }
159 
160  m_TransformFilter->SetInput( m_PCAFilter->GetOutput() );
161  m_TransformFilter->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
162  m_TransformFilter->GraftOutput( this->GetOutput() );
163  m_TransformFilter->Update();
164 
165  this->GraftOutput( m_TransformFilter->GetOutput() );
166 }
167 
168 template < class TInputImage, class TOutputImage,
169  Transform::TransformDirection TDirectionOfTransformation >
170 void
173 {
174  if ( !m_GivenTransformationMatrix )
175  {
176  throw itk::ExceptionObject( __FILE__, __LINE__,
177  "No Transformation matrix given",
178  ITK_LOCATION );
179  }
180 
181  if ( m_TransformationMatrix.GetVnlMatrix().empty() )
182  {
183  throw itk::ExceptionObject( __FILE__, __LINE__,
184  "Empty transformation matrix",
185  ITK_LOCATION);
186  }
187 
188  if ( m_IsTransformationForward )
189  {
190  // prevent from multiple inversion in the pipelines
191  m_IsTransformationForward = false;
192  vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
193  m_TransformationMatrix = invertor.pinverse();
194  }
195 
196  m_TransformFilter->SetInput( this->GetInput() );
197  m_TransformFilter->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
198 
199  /*
200  * PCA filter may throw exception if
201  * the mean, stdDev and transformation matrix
202  * have not been given at this point
203  */
204  m_PCAFilter->SetInput( m_TransformFilter->GetOutput() );
205  m_PCAFilter->GraftOutput( this->GetOutput() );
206  m_PCAFilter->Update();
207 
208  this->GraftOutput( m_PCAFilter->GetOutput() );
209 }
210 
211 template < class TInputImage, class TOutputImage,
212  Transform::TransformDirection TDirectionOfTransformation >
213 void
216 {
217  itk::ProgressReporter reporter ( this, 0, GetNumberOfIterations(), GetNumberOfIterations() );
218 
219  double convergence = itk::NumericTraits<double>::max();
220  unsigned int iteration = 0;
221 
222  const unsigned int size = this->GetInput()->GetNumberOfComponentsPerPixel();
223 
224  // transformation matrix
225  InternalMatrixType W ( size, size, vnl_matrix_identity );
226 
227  while ( iteration++ < GetNumberOfIterations()
228  && convergence > GetConvergenceThreshold() )
229  {
230  InternalMatrixType W_old ( W );
231 
232  typename InputImageType::Pointer img = const_cast<InputImageType*>( this->GetInput() );
233  TransformFilterPointerType transformer = TransformFilterType::New();
234  if ( !W.is_identity() )
235  {
236  transformer->SetInput( GetPCAFilter()->GetOutput() );
237  transformer->SetMatrix( W );
238  transformer->Update();
239  img = const_cast<InputImageType*>( transformer->GetOutput() );
240  }
241 
242  for ( unsigned int band = 0; band < size; band++ )
243  {
244  otbMsgDebugMacro( << "Iteration " << iteration << ", bande " << band
245  << ", convergence " << convergence );
246 
247  InternalOptimizerPointerType optimizer = InternalOptimizerType::New();
248  optimizer->SetInput( 0, m_PCAFilter->GetOutput() );
249  optimizer->SetInput( 1, img );
250  optimizer->SetW( W );
251  optimizer->SetContrastFunction( this->GetContrastFunction() );
252  optimizer->SetCurrentBandForLoop( band );
253 
254  MeanEstimatorFilterPointerType estimator = MeanEstimatorFilterType::New();
255  estimator->SetInput( optimizer->GetOutput() );
256  estimator->Update();
257 
258  double norm = 0.;
259  for ( unsigned int bd = 0; bd < size; bd++ )
260  {
261  W(band, bd) -= m_Mu * ( estimator->GetMean()[bd]
262  - optimizer->GetBeta() * W(band, bd) / optimizer->GetDen() );
263  norm += vcl_pow( W(band, bd), 2. );
264  }
265  for ( unsigned int bd = 0; bd < size; bd++ )
266  W(band, bd) /= vcl_sqrt( norm );
267  }
268 
269  // Decorrelation of the W vectors
270  InternalMatrixType W_tmp = W * W.transpose();
271  vnl_svd< MatrixElementType > solver ( W_tmp );
272  InternalMatrixType valP = solver.W();
273  for ( unsigned int i = 0; i < valP.rows(); ++i )
274  valP(i, i) = 1. / vcl_sqrt( static_cast<double>( valP(i, i) ) ); // Watch for 0 or neg
275  InternalMatrixType transf = solver.U();
276  W_tmp = transf * valP * transf.transpose();
277  W = W_tmp * W;
278 
279  // Convergence evaluation
280  convergence = 0.;
281  for ( unsigned int i = 0; i < W.rows(); ++i )
282  for ( unsigned int j = 0; j < W.cols(); ++j )
283  convergence += vcl_abs( W(i, j) - W_old(i, j) );
284 
285  reporter.CompletedPixel();
286  } // end of while loop
287 
288  if ( size != this->GetNumberOfPrincipalComponentsRequired() )
289  {
290  this->m_TransformationMatrix = W.get_n_rows( 0, this->GetNumberOfPrincipalComponentsRequired() );
291  }
292  else
293  {
294  this->m_TransformationMatrix = W;
295  }
296 
297  otbMsgDebugMacro( << "Final convergence " << convergence
298  << " after " << iteration << " iterations" );
299 }
300 
301 } // end of namespace otb
302 
303 #endif
304 
305 

Generated at Sat Mar 8 2014 15:55:43 for Orfeo Toolbox with doxygen 1.8.3.1