OTB  9.0.0
Orfeo Toolbox
otbMultivariateAlterationDetectorImageFilter.hxx
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2005-2022 Centre National d'Etudes Spatiales (CNES)
3  *
4  * This file is part of Orfeo Toolbox
5  *
6  * https://www.orfeo-toolbox.org/
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  */
20 
21 #ifndef otbMultivariateAlterationDetectorImageFilter_hxx
22 #define otbMultivariateAlterationDetectorImageFilter_hxx
23 
25 #include "otbMath.h"
26 
27 #include "vnl/algo/vnl_matrix_inverse.h"
28 #include "vnl/algo/vnl_generalized_eigensystem.h"
29 
30 #include "itkImageRegionIterator.h"
31 #include "itkProgressReporter.h"
32 
33 namespace otb
34 {
35 template <class TInputImage, class TOutputImage>
37 {
38  this->SetNumberOfRequiredInputs(2);
39  m_CovarianceEstimator = CovarianceEstimatorType::New();
40 }
41 
42 template <class TInputImage, class TOutputImage>
44 {
45  // Process object is not const-correct so the const casting is required.
46  this->SetNthInput(0, const_cast<TInputImage*>(image1));
47 }
48 
49 template <class TInputImage, class TOutputImage>
52 {
53  if (this->GetNumberOfInputs() < 1)
54  {
55  return nullptr;
56  }
57  return static_cast<const TInputImage*>(this->itk::ProcessObject::GetInput(0));
58 }
59 
60 template <class TInputImage, class TOutputImage>
62 {
63  // Process object is not const-correct so the const casting is required.
64  this->SetNthInput(1, const_cast<TInputImage*>(image2));
65 }
66 
67 template <class TInputImage, class TOutputImage>
70 {
71  if (this->GetNumberOfInputs() < 2)
72  {
73  return nullptr;
74  }
75  return static_cast<const TInputImage*>(this->itk::ProcessObject::GetInput(1));
76 }
77 
78 
79 template <class TInputImage, class TOutputImage>
81 {
82  // Call superclass implementation
83  Superclass::GenerateOutputInformation();
84 
85  // Retrieve input images pointers
86  const TInputImage* input1Ptr = this->GetInput1();
87  const TInputImage* input2Ptr = this->GetInput2();
88  TOutputImage* outputPtr = const_cast<TOutputImage*>(this->GetOutput());
89 
90  // Get the number of components for each image
91  unsigned int nbComp1 = input1Ptr->GetNumberOfComponentsPerPixel();
92  unsigned int nbComp2 = input2Ptr->GetNumberOfComponentsPerPixel();
93  unsigned int outNbComp = std::max(nbComp1, nbComp2);
94 
95  outputPtr->SetNumberOfComponentsPerPixel(outNbComp);
96 
97  if (input1Ptr->GetLargestPossibleRegion() != input2Ptr->GetLargestPossibleRegion())
98  {
99  itkExceptionMacro(<< "Input images does not have the same size!");
100  }
101 
102  // First concatenate both images
103  ConcatenateImageFilterPointer concatenateFilter = ConcatenateImageFilterType::New();
104  concatenateFilter->SetInput1(input1Ptr);
105  concatenateFilter->SetInput2(input2Ptr);
106 
107  // The compute covariance matrix
108  // TODO: implement switch off of this computation
109  m_CovarianceEstimator->SetInput(concatenateFilter->GetOutput());
110  m_CovarianceEstimator->Update();
111  m_CovarianceMatrix = m_CovarianceEstimator->GetCovariance();
112  m_MeanValues = m_CovarianceEstimator->GetMean();
113 
114  // Extract sub-matrices of the covariance matrix
115  VnlMatrixType s11 = m_CovarianceMatrix.GetVnlMatrix().extract(nbComp1, nbComp1);
116  VnlMatrixType s22 = m_CovarianceMatrix.GetVnlMatrix().extract(nbComp2, nbComp2, nbComp1, nbComp1);
117  VnlMatrixType s12 = m_CovarianceMatrix.GetVnlMatrix().extract(nbComp1, nbComp2, 0, nbComp1);
118  VnlMatrixType s21 = s12.transpose();
119 
120  // Extract means
121  m_Mean1 = VnlVectorType(nbComp1, 0);
122  m_Mean2 = VnlVectorType(nbComp2, 0);
123 
124  for (unsigned int i = 0; i < nbComp1; ++i)
125  {
126  m_Mean1[i] = m_MeanValues[i];
127  }
128 
129  for (unsigned int i = 0; i < nbComp2; ++i)
130  {
131  m_Mean2[i] = m_MeanValues[nbComp1 + i];
132  }
133 
134  if (nbComp1 == nbComp2)
135  {
136  // Case where nbbands1 == nbbands2
137 
138  VnlMatrixType invs22 = vnl_matrix_inverse<RealType>(s22);
139 
140  // Build the generalized eigensystem
141  VnlMatrixType s12s22is21 = s12 * invs22 * s21;
142 
143  vnl_generalized_eigensystem ges(s12s22is21, s11);
144 
145  m_V1 = ges.V;
146 
147  // Compute canonical correlation matrix
148  m_Rho = ges.D.get_diagonal();
149  m_Rho = m_Rho.apply(&std::sqrt);
150 
151  // We do not need to scale v1 since the
152  // vnl_generalized_eigensystem already gives unit variance
153 
154  VnlMatrixType invstderr1 = s11.apply(&std::sqrt);
155  invstderr1 = invstderr1.apply(&InverseValue);
156  VnlVectorType diag1 = invstderr1.get_diagonal();
157  invstderr1.fill(0);
158  invstderr1.set_diagonal(diag1);
159 
160  VnlMatrixType sign1 = VnlMatrixType(nbComp1, nbComp1, 0);
161 
162  VnlMatrixType aux4 = invstderr1 * s11 * m_V1;
163 
164  VnlVectorType aux5 = VnlVectorType(nbComp1, 0);
165 
166  for (unsigned int i = 0; i < nbComp1; ++i)
167  {
168  aux5 = aux5 + aux4.get_row(i);
169  }
170 
171  sign1.set_diagonal(aux5);
172  sign1 = sign1.apply(&SignOfValue);
173 
174  m_V1 = m_V1 * sign1;
175 
176  m_V2 = invs22 * s21 * m_V1;
177 
178  // Scale v2 for unit variance
179  VnlMatrixType aux1 = m_V2.transpose() * (s22 * m_V2);
180  VnlVectorType aux2 = aux1.get_diagonal();
181  aux2 = aux2.apply(&std::sqrt);
182  aux2 = aux2.apply(&InverseValue);
183  VnlMatrixType aux3 = VnlMatrixType(aux2.size(), aux2.size(), 0);
184  aux3.fill(0);
185  aux3.set_diagonal(aux2);
186  m_V2 = m_V2 * aux3;
187  }
188  else
189  {
190  VnlMatrixType sl(nbComp1 + nbComp2, nbComp1 + nbComp2, 0);
191  VnlMatrixType sr(nbComp1 + nbComp2, nbComp1 + nbComp2, 0);
192 
193  sl.update(s12, 0, nbComp1);
194  sl.update(s21, nbComp1, 0);
195  sr.update(s11, 0, 0);
196  sr.update(s22, nbComp1, nbComp1);
197 
198  vnl_generalized_eigensystem ges(sl, sr);
199 
200  VnlMatrixType V = ges.V;
201 
202  V.fliplr();
203 
204  m_V1 = V.extract(nbComp1, nbComp1);
205  m_V2 = V.extract(nbComp2, nbComp2, nbComp1, 0);
206 
207  m_Rho = ges.D.get_diagonal().flip().extract(std::max(nbComp1, nbComp2), 0);
208 
209  // Scale v1 to get a unit variance
210  VnlMatrixType aux1 = m_V1.transpose() * (s11 * m_V1);
211 
212  VnlVectorType aux2 = aux1.get_diagonal();
213  aux2 = aux2.apply(&std::sqrt);
214  aux2 = aux2.apply(&InverseValue);
215 
216  VnlMatrixType aux3 = VnlMatrixType(aux2.size(), aux2.size(), 0);
217  aux3.set_diagonal(aux2);
218  m_V1 = m_V1 * aux3;
219 
220  VnlMatrixType invstderr1 = s11.apply(&std::sqrt);
221  invstderr1 = invstderr1.apply(&InverseValue);
222  VnlVectorType diag1 = invstderr1.get_diagonal();
223  invstderr1.fill(0);
224  invstderr1.set_diagonal(diag1);
225 
226  VnlMatrixType sign1 = VnlMatrixType(nbComp1, nbComp1, 0);
227 
228  VnlMatrixType aux4 = invstderr1 * s11 * m_V1;
229 
230  VnlVectorType aux5 = VnlVectorType(nbComp1, 0);
231 
232  for (unsigned int i = 0; i < nbComp1; ++i)
233  {
234  aux5 = aux5 + aux4.get_row(i);
235  }
236 
237  sign1.set_diagonal(aux5);
238  sign1 = sign1.apply(&SignOfValue);
239 
240  m_V1 = m_V1 * sign1;
241 
242  // Scale v2 for unit variance
243  aux1 = m_V2.transpose() * (s22 * m_V2);
244  aux2 = aux1.get_diagonal();
245  aux2 = aux2.apply(&std::sqrt);
246  aux2 = aux2.apply(&InverseValue);
247  aux3 = VnlMatrixType(aux2.size(), aux2.size(), 0);
248  aux3.fill(0);
249  aux3.set_diagonal(aux2);
250  m_V2 = m_V2 * aux3;
251 
252  VnlMatrixType sign2 = VnlMatrixType(nbComp2, nbComp2, 0);
253 
254  aux5 = (m_V1.transpose() * s12 * m_V2).transpose().get_diagonal();
255  sign2.set_diagonal(aux5);
256  sign2 = sign2.apply(&SignOfValue);
257  m_V2 = m_V2 * sign2;
258 
259  m_Rho.flip();
260  }
261 }
262 
263 template <class TInputImage, class TOutputImage>
265  itk::ThreadIdType threadId)
266 {
267  // Retrieve input images pointers
268  const TInputImage* input1Ptr = this->GetInput1();
269  const TInputImage* input2Ptr = this->GetInput2();
270  TOutputImage* outputPtr = this->GetOutput();
271 
272 
273  typedef itk::ImageRegionConstIterator<InputImageType> ConstIteratorType;
274  typedef itk::ImageRegionIterator<OutputImageType> IteratorType;
275 
276  IteratorType outIt(outputPtr, outputRegionForThread);
277  ConstIteratorType inIt1(input1Ptr, outputRegionForThread);
278  ConstIteratorType inIt2(input2Ptr, outputRegionForThread);
279 
280  inIt1.GoToBegin();
281  inIt2.GoToBegin();
282  outIt.GoToBegin();
283 
284  // Get the number of components for each image
285  unsigned int nbComp1 = input1Ptr->GetNumberOfComponentsPerPixel();
286  unsigned int nbComp2 = input2Ptr->GetNumberOfComponentsPerPixel();
287  unsigned int outNbComp = outputPtr->GetNumberOfComponentsPerPixel();
288 
289 
290  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
291 
292  while (!inIt1.IsAtEnd() && !inIt2.IsAtEnd() && !outIt.IsAtEnd())
293  {
294  VnlVectorType x1(nbComp1, 0);
295  VnlVectorType x1bis(outNbComp, 0);
296  VnlVectorType x2(nbComp2, 0);
297  VnlVectorType x2bis(outNbComp, 0);
298  VnlVectorType mad(outNbComp, 0);
299 
300  for (unsigned int i = 0; i < nbComp1; ++i)
301  {
302  x1[i] = inIt1.Get()[i];
303  }
304 
305  for (unsigned int i = 0; i < nbComp2; ++i)
306  {
307  x2[i] = inIt2.Get()[i];
308  }
309 
310  VnlVectorType first = (x1 - m_Mean1) * m_V1;
311  VnlVectorType second = (x2 - m_Mean2) * m_V2;
312 
313  for (unsigned int i = 0; i < nbComp1; ++i)
314  {
315  x1bis[i] = first[i];
316  }
317 
318  for (unsigned int i = 0; i < nbComp2; ++i)
319  {
320  x2bis[i] = second[i];
321  }
322 
323  mad = x1bis - x2bis;
324 
325  typename OutputImageType::PixelType outPixel(outNbComp);
326 
327  if (nbComp1 == nbComp2)
328  {
329  for (unsigned int i = 0; i < outNbComp; ++i)
330  {
331  outPixel[i] = mad[i];
332  }
333  }
334  else
335  {
336  for (unsigned int i = 0; i < outNbComp; ++i)
337  {
338  outPixel[i] = mad[outNbComp - i - 1];
339 
340  if (i < outNbComp - std::min(nbComp1, nbComp2))
341  {
342  outPixel[i] *= std::sqrt(2.);
343  }
344  }
345  }
346 
347  outIt.Set(outPixel);
348 
349  ++inIt1;
350  ++inIt2;
351  ++outIt;
352  progress.CompletedPixel();
353  }
354 }
355 }
356 
357 #endif
otb::MultivariateAlterationDetectorImageFilter::MultivariateAlterationDetectorImageFilter
MultivariateAlterationDetectorImageFilter()
Definition: otbMultivariateAlterationDetectorImageFilter.hxx:36
otb::MultivariateAlterationDetectorImageFilter::SetInput1
void SetInput1(const TInputImage *image1)
Definition: otbMultivariateAlterationDetectorImageFilter.hxx:43
otb::MultivariateAlterationDetectorImageFilter::VnlMatrixType
vnl_matrix< RealType > VnlMatrixType
Definition: otbMultivariateAlterationDetectorImageFilter.h:121
otb::MultivariateAlterationDetectorImageFilter::OutputImageRegionType
OutputImageType::RegionType OutputImageRegionType
Definition: otbMultivariateAlterationDetectorImageFilter.h:104
otb::MultivariateAlterationDetectorImageFilter::ThreadedGenerateData
void ThreadedGenerateData(const OutputImageRegionType &outputRegionForThread, itk::ThreadIdType threadId) override
Definition: otbMultivariateAlterationDetectorImageFilter.hxx:264
otbMath.h
otb
The "otb" namespace contains all Orfeo Toolbox (OTB) classes.
Definition: otbJoinContainer.h:32
otb::MultivariateAlterationDetectorImageFilter::VnlVectorType
vnl_vector< RealType > VnlVectorType
Definition: otbMultivariateAlterationDetectorImageFilter.h:120
otb::InverseValue
T InverseValue(const T &value)
Definition: otbMath.h:96
otb::MultivariateAlterationDetectorImageFilter::GetInput1
const TInputImage * GetInput1()
Definition: otbMultivariateAlterationDetectorImageFilter.hxx:51
otb::MultivariateAlterationDetectorImageFilter::SetInput2
void SetInput2(const TInputImage *image2)
Definition: otbMultivariateAlterationDetectorImageFilter.hxx:61
otb::MultivariateAlterationDetectorImageFilter::GenerateOutputInformation
void GenerateOutputInformation() override
Definition: otbMultivariateAlterationDetectorImageFilter.hxx:80
otb::MultivariateAlterationDetectorImageFilter::GetInput2
const TInputImage * GetInput2()
Definition: otbMultivariateAlterationDetectorImageFilter.hxx:69
otb::MultivariateAlterationDetectorImageFilter::InputImageType
TInputImage InputImageType
Definition: otbMultivariateAlterationDetectorImageFilter.h:94
otbMultivariateAlterationDetectorImageFilter.h
otb::MultivariateAlterationDetectorImageFilter::ConcatenateImageFilterPointer
ConcatenateImageFilterType::Pointer ConcatenateImageFilterPointer
Definition: otbMultivariateAlterationDetectorImageFilter.h:111
otb::SignOfValue
T SignOfValue(const T &value)
Definition: otbMath.h:102