18 #ifndef __otbFastICAImageFilter_txx
19 #define __otbFastICAImageFilter_txx
26 #include "itkNumericTraits.h"
29 #include <vnl/vnl_matrix.h>
30 #include <vnl/algo/vnl_matrix_inverse.h>
31 #include <vnl/algo/vnl_generalized_eigensystem.h>
/**
 * Constructor.
 *
 * Declares a single required input and sets the defaults read by the other
 * methods:
 *  - no requested component count (0, which GenerateOutputInformation may
 *    replace by the input's number of components per pixel);
 *  - no user-supplied transformation matrix, held in the forward orientation;
 *  - 50 iterations and a convergence threshold of 1e-4 for the iterative
 *    estimation;
 *  - vcl_tanh as contrast function.
 * It also instantiates the internal PCA filter and the matrix-transform filter.
 */
36 template <
class TInputImage,
class TOutputImage,
41 this->SetNumberOfRequiredInputs(1);
43 m_NumberOfPrincipalComponentsRequired = 0;
45 m_GivenTransformationMatrix =
false;
// The stored matrix is considered to be in the forward orientation; the
// Forward/Reverse GenerateData methods replace it by its pseudo-inverse
// whenever the orientation does not match the requested direction.
46 m_IsTransformationForward =
true;
// Parameters of the iterative estimation in GenerateTransformationMatrix().
48 m_NumberOfIterations = 50;
49 m_ConvergenceThreshold = 1E-4;
50 m_ContrastFunction = &vcl_tanh;
// Internal PCA step, run with normalization enabled.
53 m_PCAFilter = PCAFilterType::New();
54 m_PCAFilter->SetUseNormalization(
true);
56 m_TransformFilter = TransformFilterType::New();
/**
 * GenerateOutputInformation.
 *
 * After the superclass pass, sets the number of components per pixel of the
 * output according to DirectionOfTransformation:
 *  - in one branch (presumably FORWARD; the case label is not visible in this
 *    excerpt), a requested component count of 0, or one larger than the
 *    input's component count, is reset to the input's component count. That
 *    value then becomes the output component count. This modifies
 *    m_NumberOfPrincipalComponentsRequired as a side effect.
 *  - in the other branch (presumably INVERSE), when a transformation matrix
 *    was given, the output component count is the larger of the matrix row
 *    and column counts. Otherwise (presumably) an error "Mixture matrix is
 *    required to know the output size" is raised.
 *  - any other direction ends in an error.
 *
 * NOTE(review): the last error message spells "templeted", while the same
 * check in GenerateData() says "templated". It is left unchanged here because
 * it is a runtime string.
 */
59 template <
class TInputImage,
class TOutputImage,
66 Superclass::GenerateOutputInformation();
68 switch ( static_cast<int>(DirectionOfTransformation) )
// 0 (the constructor default) or an out-of-range request means "use all
// input components".
72 if ( m_NumberOfPrincipalComponentsRequired == 0
73 || m_NumberOfPrincipalComponentsRequired
74 > this->GetInput()->GetNumberOfComponentsPerPixel() )
76 m_NumberOfPrincipalComponentsRequired =
77 this->GetInput()->GetNumberOfComponentsPerPixel();
80 this->GetOutput()->SetNumberOfComponentsPerPixel(
81 m_NumberOfPrincipalComponentsRequired );
// The output size can only be derived from a user-given matrix.
86 unsigned int theOutputDimension = 0;
87 if ( m_GivenTransformationMatrix )
89 theOutputDimension = m_TransformationMatrix.Rows() >= m_TransformationMatrix.Cols() ?
90 m_TransformationMatrix.Rows() : m_TransformationMatrix.Cols();
95 "Mixture matrix is required to know the output size",
99 this->GetOutput()->SetNumberOfComponentsPerPixel( theOutputDimension );
105 "Class should be templeted with FORWARD or INVERSE only...",
/**
 * GenerateData.
 *
 * Dispatches to ForwardGenerateData() or ReverseGenerateData() depending on
 * DirectionOfTransformation. The case labels are not visible in this excerpt;
 * they are presumably FORWARD and INVERSE respectively. Any other value ends
 * in an error carrying the message below.
 */
110 template <
class TInputImage,
class TOutputImage,
116 switch ( static_cast<int>(DirectionOfTransformation) )
119 return ForwardGenerateData();
121 return ReverseGenerateData();
124 "Class should be templated with FORWARD or INVERSE only...",
/**
 * ForwardGenerateData.
 *
 * Runs the internal PCA filter on the input, then:
 *  - if no transformation matrix was given, estimates one with
 *    GenerateTransformationMatrix();
 *  - otherwise, if the given matrix is held in the inverse orientation,
 *    replaces m_TransformationMatrix by its SVD pseudo-inverse and marks it as
 *    forward.
 * An empty matrix raises an "Empty transformation matrix" error. The matrix
 * is then applied to the PCA output by the internal transform filter. That
 * filter's output is grafted onto this filter's output.
 */
129 template <
class TInputImage,
class TOutputImage,
// NOTE(review): the initializer of inputImgPtr is on lines not visible in
// this excerpt (presumably derived from this->GetInput()).
135 typename InputImageType::Pointer inputImgPtr
138 m_PCAFilter->SetInput( inputImgPtr );
139 m_PCAFilter->Update();
141 if ( !m_GivenTransformationMatrix )
143 GenerateTransformationMatrix();
145 else if ( !m_IsTransformationForward )
// Bring a user-given inverse matrix into the forward orientation.
148 m_IsTransformationForward =
true;
149 vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
150 m_TransformationMatrix = invertor.pinverse();
153 if ( m_TransformationMatrix.GetVnlMatrix().empty() )
156 "Empty transformation matrix",
160 m_TransformFilter->SetInput( m_PCAFilter->GetOutput() );
161 m_TransformFilter->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
// Mini-pipeline: graft this filter's output onto the internal transform
// filter, update it, then graft its result back onto this filter's output.
162 m_TransformFilter->GraftOutput( this->GetOutput() );
163 m_TransformFilter->Update();
165 this->GraftOutput( m_TransformFilter->GetOutput() );
/**
 * ReverseGenerateData.
 *
 * Requires a user-given, non-empty transformation matrix. Otherwise it raises
 * "No Transformation matrix given" or "Empty transformation matrix".
 *
 * If the matrix is held in the forward orientation, m_TransformationMatrix is
 * replaced in place by its SVD pseudo-inverse and marked as inverse. After
 * this call the stored matrix is therefore the inverted one. The input is
 * transformed by that matrix and the result is fed to the internal PCA
 * filter. The PCA filter's output is grafted onto this filter's output.
 *
 * NOTE(review): source lines 198-203 are not visible in this excerpt. In this
 * direction the PCA filter is presumably the inverse-PCA variant; confirm
 * against the class declaration.
 */
168 template <
class TInputImage,
class TOutputImage,
174 if ( !m_GivenTransformationMatrix )
177 "No Transformation matrix given",
181 if ( m_TransformationMatrix.GetVnlMatrix().empty() )
184 "Empty transformation matrix",
188 if ( m_IsTransformationForward )
// Bring a forward matrix into the inverse orientation.
191 m_IsTransformationForward =
false;
192 vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
193 m_TransformationMatrix = invertor.pinverse();
196 m_TransformFilter->SetInput( this->GetInput() );
197 m_TransformFilter->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
204 m_PCAFilter->SetInput( m_TransformFilter->GetOutput() );
// Mini-pipeline: graft this filter's output onto the PCA filter, update it,
// then graft its result back onto this filter's output.
205 m_PCAFilter->GraftOutput( this->GetOutput() );
206 m_PCAFilter->Update();
208 this->GraftOutput( m_PCAFilter->GetOutput() );
/**
 * GenerateTransformationMatrix.
 *
 * Iteratively estimates the matrix W (indexed as size x size, where size is
 * the input's number of components per pixel) applied after the PCA step.
 * The loop stops when GetNumberOfIterations() is reached or when the
 * convergence measure is no longer above GetConvergenceThreshold().
 *
 * Each iteration:
 *  1. transforms the PCA output by W when W is not the identity;
 *  2. for every band, runs an optimizer and an estimator of its output, then
 *     updates row `band` of W with step m_Mu and rescales that row to unit
 *     Euclidean norm. The optimizer is fed the PCA output, the current image,
 *     W, the contrast function and the band index;
 *  3. decorrelates via an SVD of W_tmp: each diagonal entry v of valP becomes
 *     1/sqrt(v) and W_tmp = transf * valP * transf^T. W_tmp, transf and valP
 *     are set up on lines not visible here, and the way the result is folded
 *     back into W is not visible either;
 *  4. accumulates the sum of |W - W_old| over all entries into `convergence`.
 *
 * Finally m_TransformationMatrix receives W, or only its first
 * GetNumberOfPrincipalComponentsRequired() rows when that count differs from
 * size.
 */
211 template <
class TInputImage,
class TOutputImage,
219 double convergence = itk::NumericTraits<double>::max();
220 unsigned int iteration = 0;
222 const unsigned int size = this->GetInput()->GetNumberOfComponentsPerPixel();
// NOTE(review): iteration is post-incremented in the condition, so on exit it
// is one greater than the number of loop bodies actually executed. The final
// debug message reports that value.
227 while ( iteration++ < GetNumberOfIterations()
228 && convergence > GetConvergenceThreshold() )
// img starts as the filter input. When W is not the identity, the PCA output
// is transformed by W; presumably img is then replaced by the transformer
// output (that assignment is not visible in this excerpt).
232 typename InputImageType::Pointer img =
const_cast<InputImageType*
>( this->GetInput() );
234 if ( !W.is_identity() )
236 transformer->SetInput( GetPCAFilter()->GetOutput() );
237 transformer->SetMatrix( W );
238 transformer->Update();
242 for (
unsigned int band = 0; band < size; band++ )
245 <<
", convergence " << convergence );
248 optimizer->SetInput( 0, m_PCAFilter->GetOutput() );
249 optimizer->SetInput( 1, img );
250 optimizer->SetW( W );
251 optimizer->SetContrastFunction( this->GetContrastFunction() );
252 optimizer->SetCurrentBandForLoop( band );
255 estimator->SetInput( optimizer->GetOutput() );
// Update row `band` of W with step m_Mu while accumulating its squared norm,
// then rescale the row to unit Euclidean norm. norm is presumably reset to 0
// on a line not visible here.
// NOTE(review): vcl_pow(x, 2.) could simply be x * x.
259 for (
unsigned int bd = 0; bd < size; bd++ )
261 W(band, bd) -= m_Mu * ( estimator->GetMean()[bd]
262 - optimizer->GetBeta() * W(band, bd) / optimizer->GetDen() );
263 norm += vcl_pow( W(band, bd), 2. );
265 for (
unsigned int bd = 0; bd < size; bd++ )
266 W(band, bd) /= vcl_sqrt( norm );
// Decorrelation step based on an SVD of W_tmp.
// NOTE(review): a zero diagonal entry in valP leads to a division by zero in
// 1. / sqrt(...); there is no guard.
271 vnl_svd< MatrixElementType > solver ( W_tmp );
273 for (
unsigned int i = 0; i < valP.rows(); ++i )
274 valP(i, i) = 1. / vcl_sqrt( static_cast<double>( valP(i, i) ) );
276 W_tmp = transf * valP * transf.transpose();
// Convergence measure: sum of absolute entry-wise differences between W and
// W_old. convergence is presumably reset to 0 on a line not visible here.
281 for (
unsigned int i = 0; i < W.rows(); ++i )
282 for (
unsigned int j = 0; j < W.cols(); ++j )
283 convergence += vcl_abs( W(i, j) - W_old(i, j) );
// Keep only the requested number of rows when it differs from the input size.
288 if ( size != this->GetNumberOfPrincipalComponentsRequired() )
290 this->m_TransformationMatrix = W.get_n_rows( 0, this->GetNumberOfPrincipalComponentsRequired() );
294 this->m_TransformationMatrix = W;
298 <<
" after " << iteration <<
" iterations" );