Orfeo Toolbox  4.0
otbMeanShiftSmoothingImageFilter.txx
Go to the documentation of this file.
1 /*=========================================================================
2 
3  Program: ORFEO Toolbox
4  Language: C++
5  Date: $Date$
6  Version: $Revision$
7 
8 
9  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
10  See OTBCopyright.txt for details.
11 
12 
13  This software is distributed WITHOUT ANY WARRANTY; without even
14  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
15  PURPOSE. See the above copyright notices for more information.
16 
17 =========================================================================*/
18 
19 #ifndef __otbMeanShiftSmoothingImageFilter_txx
20 #define __otbMeanShiftSmoothingImageFilter_txx
21 
24 #include "itkImageRegionIterator.h"
26 #include "otbMacro.h"
27 
28 #include "itkProgressReporter.h"
29 
30 
31 namespace otb
32 {
33 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
35  m_RangeBandwidth(16.), m_RangeBandwidthRamp(0), m_SpatialBandwidth(3) // ramp 0: range bandwidth does not depend on the pixel value
36  // , m_SpatialRadius: filled from m_Kernel.GetRadius(m_SpatialBandwidth) before use
37  , m_Threshold(1e-3), m_MaxIterationNumber(10) // convergence threshold on the squared shift norm; iteration cap per pixel
38  // , m_Kernel(...)
39  // , m_NumberOfComponentsPerPixel(...)
40  // , m_JointImage(0)
41  // , m_ModeTable(0)
42  , m_ModeSearch(true) // mode-search optimization; the label output is only meaningful when enabled
43 #if 0
44  , m_BucketOptimization(false)
45 #endif
46 { // NOTE(review): listing omits lines 47/49/51; outputs 1 (spatial) and 3 (label) are presumably created there
48  this->SetNthOutput(0, OutputImageType::New());
50  this->SetNthOutput(2, OutputIterationImageType::New());
52  m_GlobalShift.Fill(0); // no shift between image indices and joint-domain spatial coordinates by default
53 }
54 
55 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
57 { // NOTE(review): empty body; its signature (listing line 56) is not shown here, presumably the destructor
58 }
59 
60 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
63 { // Read-only access to output #1 (spatial displacement image); the static_cast is unchecked
64  return static_cast<const OutputSpatialImageType *> (this->itk::ProcessObject::GetOutput(1));
65 }
66 
67 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
70 { // Mutable access to output #1 (spatial displacement image); the static_cast is unchecked
71  return static_cast<OutputSpatialImageType *> (this->itk::ProcessObject::GetOutput(1));
72 }
73 
74 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
77 { // Read-only access to output #0 (smoothed range image); the static_cast is unchecked
78  return static_cast<const OutputImageType *> (this->itk::ProcessObject::GetOutput(0));
79 }
80 
81 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
84 { // Mutable access to output #0 (smoothed range image); the static_cast is unchecked
85  return static_cast<OutputImageType *> (this->itk::ProcessObject::GetOutput(0));
86 }
87 
88 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
91 { // Output #2 (iterations per pixel). NOTE(review): casts to a non-const pointer; if this is the const overload it should cast to const OutputIterationImageType *
92  return static_cast<OutputIterationImageType *> (this->itk::ProcessObject::GetOutput(2));
93 }
94 
95 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
98 { // Mutable access to output #2 (number of mean-shift iterations per pixel); unchecked cast
99  return static_cast<OutputIterationImageType *> (this->itk::ProcessObject::GetOutput(2));
100 }
101 
102 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
105 { // Output #3 (mode labels). NOTE(review): casts to a non-const pointer; if this is the const overload it should cast to const OutputLabelImageType *
106  return static_cast<OutputLabelImageType *> (this->itk::ProcessObject::GetOutput(3));
107 }
108 
109 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
112 { // Mutable access to output #3 (mode labels, 0 when mode search is disabled); unchecked cast
113  return static_cast<OutputLabelImageType *> (this->itk::ProcessObject::GetOutput(3));
114 }
115 
116 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
118 { // Buffers each of the four outputs exactly over its own requested region and allocates it
119  typename OutputSpatialImageType::Pointer spatialOutputPtr = this->GetSpatialOutput();
120  typename OutputImageType::Pointer rangeOutputPtr = this->GetRangeOutput();
121  typename OutputIterationImageType::Pointer iterationOutputPtr = this->GetIterationOutput();
122  typename OutputLabelImageType::Pointer labelOutputPtr = this->GetLabelOutput();
123 
124  spatialOutputPtr->SetBufferedRegion(spatialOutputPtr->GetRequestedRegion());
125  spatialOutputPtr->Allocate();
126 
127  rangeOutputPtr->SetBufferedRegion(rangeOutputPtr->GetRequestedRegion());
128  rangeOutputPtr->Allocate();
129 
130  iterationOutputPtr->SetBufferedRegion(iterationOutputPtr->GetRequestedRegion());
131  iterationOutputPtr->Allocate();
132 
133  labelOutputPtr->SetBufferedRegion(labelOutputPtr->GetRequestedRegion());
134  labelOutputPtr->Allocate();
135 }
136 
137 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
139 { // Propagates superclass information, then sets the per-pixel component counts of the vector outputs
140  Superclass::GenerateOutputInformation();
141 
142  m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
143 
144  if (this->GetSpatialOutput())
145  {
146  this->GetSpatialOutput()->SetNumberOfComponentsPerPixel(ImageDimension); // one displacement component per image dimension
147  }
148  if (this->GetRangeOutput())
149  {
150  this->GetRangeOutput()->SetNumberOfComponentsPerPixel(m_NumberOfComponentsPerPixel); // same band count as the input
151  }
152 }
153 
154 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
156 { // Requests the range output's region padded by (MaxIterationNumber+1) * spatial radius, cropped to the input
157  // Call superclass implementation
158  Superclass::GenerateInputRequestedRegion();
159 
160  // Retrieve the input and range-output pointers
161  InputImagePointerType inPtr = const_cast<TInputImage *> (this->GetInput());
162  OutputImagePointerType outRangePtr = this->GetRangeOutput();
163 
164  // Check pointers before using them
165  if (!inPtr || !outRangePtr)
166  {
167  return;
168  }
169 
170  // Retrieve requested region (TODO: check whether the requested regions of
171  // the other outputs must also be taken into account)
172  RegionType outputRequestedRegion = outRangePtr->GetRequestedRegion();
173 
174  // Start from the output region; it is padded by the margin below
175  RegionType inputRequestedRegion = outputRequestedRegion;
176 
177  // Initializes the spatial radius from kernel bandwidth
178  m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));
179 
180  InputSizeType margin;
181 
182  for(unsigned int comp = 0; comp < ImageDimension; ++comp)
183  {
184  margin[comp] = (m_MaxIterationNumber+1) * m_SpatialRadius[comp]; // presumably bounds the drift of a pixel over all iterations
185  }
186 
187  inputRequestedRegion.PadByRadius(margin);
188 
189  // Crop the input requested region at the input's largest possible region
190  if (inputRequestedRegion.Crop(inPtr->GetLargestPossibleRegion()))
191  {
192  inPtr->SetRequestedRegion(inputRequestedRegion);
193  return;
194  }
195  else
196  {
197  // Couldn't crop the region (requested region is outside the largest
198  // possible region). Throw an exception.
199 
200  // store what we tried to request (prior to trying to crop)
201  inPtr->SetRequestedRegion(inputRequestedRegion);
202 
203  // build an exception
204  itk::InvalidRequestedRegionError e(__FILE__, __LINE__);
205  e.SetLocation(ITK_LOCATION);
206  e.SetDescription("Requested region is (at least partially) outside the largest possible region.");
207  e.SetDataObject(inPtr);
208  throw e;
209  }
210 
211 }
212 
213 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
215 {
216  // typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType;
217  // typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType;
218 
219  OutputSpatialImagePointerType outSpatialPtr = this->GetSpatialOutput();
220  OutputImagePointerType outRangePtr = this->GetRangeOutput();
221  typename InputImageType::ConstPointer inputPtr = this->GetInput();
222  typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
223  typename OutputSpatialImageType::Pointer spatialOutput = this->GetSpatialOutput();
224 
225  //InputIndexType index;
226 
227 
228  m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));
229 
230  m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
231 
232  // Allocate output images
233  this->AllocateOutputs();
234 
235  // Initialize output images to zero
236  iterationOutput->FillBuffer(0);
237  OutputSpatialPixelType zero(spatialOutput->GetNumberOfComponentsPerPixel());
238  zero.Fill(0);
239  spatialOutput->FillBuffer(zero);
240 
241  // m_JointImage is the input data expressed in the joint spatial-range
242  // domain, i.e. spatial coordinates are concatenated to the range values.
243  // Moreover, pixel components in this image are normalized by their respective
244  // (spatial or range) bandwith.
247  JointImageFunctorType;
248 
249  typename JointImageFunctorType::Pointer jointImageFunctor = JointImageFunctorType::New();
250 
251  jointImageFunctor->SetInput(inputPtr);
252  jointImageFunctor->GetFunctor().Initialize(ImageDimension, m_NumberOfComponentsPerPixel, m_GlobalShift);
253  jointImageFunctor->GetOutput()->SetRequestedRegion(this->GetInput()->GetBufferedRegion());
254  jointImageFunctor->Update();
255  m_JointImage = jointImageFunctor->GetOutput();
256 
257 #if 0
258  if (m_BucketOptimization)
259  {
260  // Create bucket image
261  // Note: because values in the input m_JointImage are normalized, the
262  // rangeRadius argument is just 1
263  m_BucketImage = BucketImageType(static_cast<typename RealVectorImageType::ConstPointer> (m_JointImage),
264  m_JointImage->GetRequestedRegion(), m_Kernel.GetRadius(m_SpatialBandwidth), 1,
265  ImageDimension);
266  }
267 #endif
268  /*
269  // Allocate the joint domain image
270  m_JointImage = RealVectorImageType::New();
271  m_JointImage->SetNumberOfComponentsPerPixel(ImageDimension + m_NumberOfComponentsPerPixel);
272  m_JointImage->SetRegions(inputPtr->GetRequestedRegion());
273  m_JointImage->Allocate();
274 
275  InputIteratorWithIndexType inputIt(inputPtr, inputPtr->GetRequestedRegion());
276  JointImageIteratorType jointIt(m_JointImage, inputPtr->GetRequestedRegion());
277 
278  // Initialize the joint image with scaled values
279  inputIt.GoToBegin();
280  jointIt.GoToBegin();
281 
282  while (!inputIt.IsAtEnd())
283  {
284  typename InputImageType::PixelType const& inputPixel = inputIt.Get();
285  index = inputIt.GetIndex();
286 
287  RealVector & jointPixel = jointIt.Get();
288  for(unsigned int comp = 0; comp < ImageDimension; comp++)
289  {
290  jointPixel[comp] = index[comp] / m_SpatialBandwidth;
291  }
292  for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
293  {
294  jointPixel[ImageDimension + comp] = inputPixel[comp] / m_RangeBandwidth;
295  }
296  // jointIt.Set(jointPixel);
297 
298  ++inputIt;
299  ++jointIt;
300  }
301  */
302 
303  //TODO don't create mode table iterator when ModeSearch is set to false
304  m_ModeTable = ModeTableImageType::New();
305  m_ModeTable->SetRegions(inputPtr->GetRequestedRegion());
306  m_ModeTable->Allocate();
307  m_ModeTable->FillBuffer(0);
308 
309  if (m_ModeSearch)
310  {
311  // Image to store the status at each pixel:
312  // 0 : no mode has been found yet
313  // 1 : a mode has been assigned to this pixel
314  // 2 : a mode will be assigned to this pixel
315 
316 
317  // Initialize counters for mode (also used for mode labeling)
318  // Most significant bits of label counters are used to identify the thread
319  // Id.
320  unsigned int numThreads;
321 
322  numThreads = this->GetNumberOfThreads();
323  m_ThreadIdNumberOfBits = -1;
324  unsigned int n = numThreads;
325  while (n != 0)
326  {
327  n >>= 1;
328  m_ThreadIdNumberOfBits++;
329  }
330  if (m_ThreadIdNumberOfBits == 0) m_ThreadIdNumberOfBits = 1; // minimum 1 bit
331  m_NumLabels.SetSize(numThreads);
332  for (unsigned int i = 0; i < numThreads; i++)
333  {
334  m_NumLabels[i] = static_cast<LabelType> (i) << (sizeof(LabelType) * 8 - m_ThreadIdNumberOfBits);
335  }
336 
337  }
338 
339 }
340 
341 // Calculates the mean shift vector at the position given by jointPixel (kernel-weighted mean of neighbor - jointPixel)
342 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
344  const typename RealVectorImageType::Pointer jointImage, // joint spatial-range image sampled in the window
345  const RealVector& jointPixel, // current position in the joint domain
346  const OutputRegionType& outputRegion, // the sampling window is clipped to this region
347  const RealVector & bandwidth, // per-component scale applied to differences before the kernel
348  RealVector& meanShiftVector) // out: weighted mean shift; stays all zero if every weight is 0
349 {
350  const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
351 
352  InputIndexType inputIndex;
353  InputIndexType regionIndex;
354  InputSizeType regionSize;
355 
356  assert(meanShiftVector.GetSize() == jointDimension);
357  meanShiftVector.Fill(0);
358 
359  // Calculates current pixel neighborhood region (half-width m_SpatialRadius + 1), restricted to outputRegion
360  for (unsigned int comp = 0; comp < ImageDimension; ++comp)
361  {
362  inputIndex[comp] = vcl_floor(jointPixel[comp] + 0.5) - m_GlobalShift[comp];
363 
364  regionIndex[comp] = vcl_max(static_cast<long int> (outputRegion.GetIndex().GetElement(comp)),
365  static_cast<long int> (inputIndex[comp] - m_SpatialRadius[comp] - 1));
366  const long int indexRight = vcl_min(
367  static_cast<long int> (outputRegion.GetIndex().GetElement(comp)
368  + outputRegion.GetSize().GetElement(comp) - 1),
369  static_cast<long int> (inputIndex[comp] + m_SpatialRadius[comp] + 1));
370 
371  regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int> (regionIndex[comp]) + 1); // empty window if fully clipped
372  }
373 
374  RegionType neighborhoodRegion;
375  neighborhoodRegion.SetIndex(regionIndex);
376  neighborhoodRegion.SetSize(regionSize);
377 
378  RealType weightSum = 0;
379  RealVector jointNeighbor(ImageDimension + m_NumberOfComponentsPerPixel), shifts(ImageDimension + m_NumberOfComponentsPerPixel);
380 
381  // An iterator on the neighborhood of the current pixel (in joint
382  // spatial-range domain); its declaration (listing line 383) is not shown here
384  //itk::ImageRegionConstIterator<RealVectorImageType> it(jointImage, neighborhoodRegion);
385 
386  it.GoToBegin();
387  while (!it.IsAtEnd())
388  {
389  jointNeighbor = it.Get();
390 
391  // Compute the squared norm of the difference
392  // This is the L2 norm, TODO: replace by the templated norm
393  RealType norm2 = 0;
394  for (unsigned int comp = 0; comp < jointDimension; comp++)
395  {
396  shifts[comp] = jointNeighbor[comp] - jointPixel[comp];
397  double d = shifts[comp]/bandwidth[comp];
398  norm2 += d*d;
399  }
400 
401  // Compute pixel weight from kernel
402  const RealType weight = m_Kernel(norm2);
403  /*
404  // The following code is an alternative way to compute norm2 and weight
405  // It separates the norms of spatial and range elements
406  RealType spatialNorm2;
407  RealType rangeNorm2;
408  spatialNorm2 = 0;
409  for (unsigned int comp = 0; comp < ImageDimension; comp++)
410  {
411  RealType d;
412  d = jointNeighbor[comp] - jointPixel[comp];
413  spatialNorm2 += d*d;
414  }
415 
416  if(spatialNorm2 >= 1.0)
417  {
418  weight = 0;
419  }
420  else
421  {
422  rangeNorm2 = 0;
423  for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
424  {
425  RealType d;
426  d = jointNeighbor[ImageDimension + comp] - jointPixel[ImageDimension + comp];
427  rangeNorm2 += d*d;
428  }
429 
430  weight = (rangeNorm2 <= 1.0)? 1.0 : 0.0;
431  }
432  */
433 
434  // Update sum of weights
435  weightSum += weight;
436 
437  // Update mean shift vector (accumulates unscaled shifts)
438  for (unsigned int comp = 0; comp < jointDimension; comp++)
439  {
440  meanShiftVector[comp] += weight * shifts[comp];
441  }
442 
443  ++it;
444  }
445 
446  if (weightSum > 0) // otherwise leave the zero vector, which reads as "converged" to the caller's threshold test
447  {
448  for (unsigned int comp = 0; comp < jointDimension; comp++)
449  {
450  meanShiftVector[comp] = meanShiftVector[comp] / weightSum;
451  }
452  }
453 }
454 
455 #if 0 // disabled bucket-based variant, only referenced under m_BucketOptimization
456 // Calculates the mean shift vector at the position given by jointPixel (differences are not divided by a bandwidth here)
457 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
459  const RealVector& jointPixel,
460  RealVector& meanShiftVector)
461 {
462  const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
463 
464  RealType weightSum = 0;
465 
466  for (unsigned int comp = 0; comp < jointDimension; comp++)
467  {
468  meanShiftVector[comp] = 0;
469  }
470 
471  RealVector jointNeighbor(ImageDimension + m_NumberOfComponentsPerPixel);
472 
473  InputIndexType index;
474  for (unsigned int dim = 0; dim < ImageDimension; ++dim)
475  {
476  index[dim] = jointPixel[dim] * m_SpatialBandwidth + 0.5;
477  }
478 
479  const std::vector<unsigned int>
480  neighborBuckets = m_BucketImage.GetNeighborhoodBucketListIndices(
481  m_BucketImage.BucketIndexToBucketListIndex(
482  m_BucketImage.GetBucketIndex(
483  jointPixel,
484  index)));
485 
486  unsigned int numNeighbors = m_BucketImage.GetNumberOfNeighborBuckets();
487  for (unsigned int neighborIndex = 0; neighborIndex < numNeighbors; ++neighborIndex)
488  {
489  const typename BucketImageType::BucketType & bucket = m_BucketImage.GetBucket(neighborBuckets[neighborIndex]);
490  if (bucket.empty()) continue;
491  typename BucketImageType::BucketType::const_iterator it = bucket.begin();
492  while (it != bucket.end())
493  {
494  jointNeighbor.SetData(const_cast<RealType*> (*it));
495 
496  // Compute the squared norm of the difference
497  // This is the L2 norm, TODO: replace by the templated norm
498  RealType norm2 = 0;
499  for (unsigned int comp = 0; comp < jointDimension; comp++)
500  {
501  const RealType d = jointNeighbor[comp] - jointPixel[comp];
502  norm2 += d * d;
503  }
504 
505  // Compute pixel weight from kernel
506  const RealType weight = m_Kernel(norm2);
507 
508  // Update sum of weights
509  weightSum += weight;
510 
511  // Update mean shift vector (accumulates weighted neighbor positions, not shifts)
512  for (unsigned int comp = 0; comp < jointDimension; comp++)
513  {
514  meanShiftVector[comp] += weight * jointNeighbor[comp];
515  }
516 
517  ++it;
518  }
519  }
520 
521  if (weightSum > 0)
522  {
523  for (unsigned int comp = 0; comp < jointDimension; comp++)
524  {
525  meanShiftVector[comp] = meanShiftVector[comp] / weightSum - jointPixel[comp]; // weighted mean minus current position
526  }
527  }
528 }
529 #endif
530 
531 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
533 ::ThreadedGenerateData(const OutputRegionType& outputRegionForThread, itk::ThreadIdType threadId)
534 {
535  // at the first iteration
536 
537 
538  // Retrieve output images pointers
539  typename OutputSpatialImageType::Pointer spatialOutput = this->GetSpatialOutput();
540  typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput();
541  typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
542  typename OutputLabelImageType::Pointer labelOutput = this->GetLabelOutput();
543 
544  // Get input image pointer
545  typename InputImageType::ConstPointer input = this->GetInput();
546 
547  // defines input and output iterators
548  typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
549  typedef itk::ImageRegionIterator<OutputSpatialImageType> OutputSpatialIteratorType;
550  typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType;
551  typedef itk::ImageRegionIterator<OutputLabelImageType> OutputLabelIteratorType;
552 
553  const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
554 
555  typename OutputImageType::PixelType rangePixel(m_NumberOfComponentsPerPixel);
556  typename OutputSpatialImageType::PixelType spatialPixel(ImageDimension);
557 
558  RealVector jointPixel;
559 
560  RealVector bandwidth(jointDimension);
561  for (unsigned int comp = 0; comp < ImageDimension; comp++)
562  bandwidth[comp] = m_SpatialBandwidth;
563 
564  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
565 
566  RegionType const& requestedRegion = input->GetRequestedRegion();
567 
569  JointImageIteratorType jointIt(m_JointImage, outputRegionForThread);
570 
571  OutputIteratorType rangeIt(rangeOutput, outputRegionForThread);
572  OutputSpatialIteratorType spatialIt(spatialOutput, outputRegionForThread);
573  OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
574  OutputLabelIteratorType labelIt(labelOutput, outputRegionForThread);
575 
576  typedef itk::ImageRegionIterator<ModeTableImageType> ModeTableImageIteratorType;
577  ModeTableImageIteratorType modeTableIt(m_ModeTable, outputRegionForThread);
578 
579  jointIt.GoToBegin();
580  rangeIt.GoToBegin();
581  spatialIt.GoToBegin();
582  iterationIt.GoToBegin();
583  modeTableIt.GoToBegin();
584  labelIt.GoToBegin();
585 
586  unsigned int iteration = 0;
587 
588  // Mean shift vector, updating the joint pixel at each iteration
589  RealVector meanShiftVector(jointDimension);
590 
591  // Variables used by mode search optimization
592  // List of indices where the current pixel passes through
593  std::vector<InputIndexType> pointList;
594  if (m_ModeSearch) pointList.resize(m_MaxIterationNumber);
595  // Number of times an already processed candidate pixel is encountered, resulting in no
596  // further computation (Used for statistics only)
597  unsigned int numBreaks = 0;
598  // index of the current pixel updated during the mean shift loop
599  InputIndexType modeCandidate;
600 
601  for (; !jointIt.IsAtEnd(); ++jointIt, ++rangeIt, ++spatialIt, ++iterationIt, ++modeTableIt, ++labelIt, progress.CompletedPixel())
602  {
603 
604  // if pixel has been already processed (by mode search optimization), skip
605  typename ModeTableImageType::InternalPixelType const& currentPixelMode = modeTableIt.Get();
606  if (m_ModeSearch && currentPixelMode == 1)
607  {
608  numBreaks++;
609  continue;
610  }
611 
612  bool hasConverged = false;
613 
614  // get input pixel in the joint spatial-range domain (with components
615  // normalized by bandwith)
616  jointPixel = jointIt.Get(); // Pixel in the joint spatial-range domain
617 
618  for (unsigned int comp = ImageDimension; comp < jointDimension; comp++)
619  bandwidth[comp] = m_RangeBandwidthRamp*jointPixel[comp]+m_RangeBandwidth;
620 
621  // index of the currently processed output pixel
622  InputIndexType currentIndex = jointIt.GetIndex();
623 
624  // Number of points currently in the pointList
625  unsigned int pointCount = 0; // Note: used only in mode search optimization
626  iteration = 0;
627  while ((iteration < m_MaxIterationNumber) && (!hasConverged))
628  {
629 
630  if (m_ModeSearch)
631  {
632  // Find index of the pixel closest to the current jointPixel (not normalized by bandwidth)
633  for (unsigned int comp = 0; comp < ImageDimension; comp++)
634  {
635  modeCandidate[comp] = vcl_floor(jointPixel[comp] - m_GlobalShift[comp] + 0.5);
636  }
637  // Check status of candidate mode
638 
639  // If pixel candidate has status 0 (no mode assigned) or 1 (mode assigned)
640  // but not 2 (pixel in current search path), and pixel has actually moved
641  // from its initial position, and pixel candidate is inside the output
642  // region, then perform optimization tasks
643  if (modeCandidate != currentIndex && m_ModeTable->GetPixel(modeCandidate) != 2
644  && outputRegionForThread.IsInside(modeCandidate))
645  {
646  // Obtain the data point to see if it close to jointPixel
647  RealType diff = 0;
648  RealVector const& candidatePixel = m_JointImage->GetPixel(modeCandidate);
649  for (unsigned int comp = ImageDimension; comp < jointDimension; comp++)
650  {
651  const RealType d = (candidatePixel[comp] - jointPixel[comp])/bandwidth[comp];
652  diff += d * d;
653  }
654 
655  if (diff < 0.5) // Spectral value is close enough
656  {
657  // If no mode has been associated to the candidate pixel then
658  // associate it to the upcoming mode
659  if (m_ModeTable->GetPixel(modeCandidate) == 0)
660  {
661  // Add the candidate to the list of pixels that will be assigned the
662  // finally calculated mode value
663  pointList[pointCount++] = modeCandidate;
664  m_ModeTable->SetPixel(modeCandidate, 2);
665  }
666  else // == 1
667  {
668  // The candidate pixel has already been assigned to a mode
669  // Assign the same value
670  rangePixel = rangeOutput->GetPixel(modeCandidate);
671  for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
672  {
673  jointPixel[ImageDimension + comp] = rangePixel[comp];
674  }
675  // Update the mode table because pixel will be assigned just now
676  modeTableIt.Set(2); // m_ModeTable->SetPixel(currentIndex, 2);
677  // bypass further calculation
678  numBreaks++;
679  break;
680  }
681  }
682 
683  }
684  } // end if (m_ModeSearch)
685 
686  //Calculate meanShiftVector
687 #if 0
688  if (m_BucketOptimization)
689  {
690  this->CalculateMeanShiftVectorBucket(jointPixel, meanShiftVector);
691  }
692  else
693  {
694 #endif
695  this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, bandwidth, meanShiftVector);
696 
697 #if 0
698  }
699 #endif
700 
701  // Compute mean shift vector squared norm (not normalized by bandwidth)
702  // and add mean shift vector to current joint pixel
703  double meanShiftVectorSqNorm = 0;
704  for (unsigned int comp = 0; comp < jointDimension; comp++)
705  {
706  const double v = meanShiftVector[comp];
707  meanShiftVectorSqNorm += v * v;
708  jointPixel[comp] += meanShiftVector[comp];
709  }
710 
711  //TODO replace SSD Test with templated metric
712  hasConverged = meanShiftVectorSqNorm < m_Threshold;
713  iteration++;
714  }
715 
716  for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
717  {
718  rangePixel[comp] = jointPixel[ImageDimension + comp];
719  }
720 
721  for (unsigned int comp = 0; comp < ImageDimension; comp++)
722  {
723  spatialPixel[comp] = jointPixel[comp] - currentIndex[comp] - m_GlobalShift[comp];
724  }
725 
726  rangeIt.Set(rangePixel);
727  spatialIt.Set(spatialPixel);
728 
729  const typename OutputIterationImageType::PixelType iterationPixel = iteration;
730  iterationIt.Set(iterationPixel);
731 
732  if (m_ModeSearch)
733  {
734  // Update the mode table now that the current pixel has been assigned
735  modeTableIt.Set(1); // m_ModeTable->SetPixel(currentIndex, 1);
736 
737  // If the loop exited with hasConverged or too many iterations, then we have a new mode
738  LabelType label;
739  if (hasConverged || iteration == m_MaxIterationNumber)
740  {
741  m_NumLabels[threadId]++;
742  label = m_NumLabels[threadId];
743  }
744  else // the loop exited through a break. Use the already assigned mode label
745  {
746  label = labelOutput->GetPixel(modeCandidate);
747  }
748  labelIt.Set(label);
749 
750  // Also assign all points in the list to the same mode
751  for (unsigned int i = 0; i < pointCount; i++)
752  {
753  rangeOutput->SetPixel(pointList[i], rangePixel);
754  m_ModeTable->SetPixel(pointList[i], 1);
755  labelOutput->SetPixel(pointList[i], label);
756  }
757  }
758  else // if ModeSearch is not set LabelOutput can't be generated
759  {
760  LabelType labelZero = 0;
761  labelIt.Set(labelZero);
762  }
763 
764  }
765  // std::cout << "numBreaks: " << numBreaks << " Break ratio: " << numBreaks / (RealType)outputRegionForThread.GetNumberOfPixels() << std::endl;
766 }
767 
768 /* after threaded convergence test: rewrites per-thread labels as consecutive global labels */
769 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
771 {
772  typename OutputLabelImageType::Pointer labelOutput = this->GetLabelOutput();
773  typedef itk::ImageRegionIterator<OutputLabelImageType> OutputLabelIteratorType;
774  OutputLabelIteratorType labelIt(labelOutput, labelOutput->GetRequestedRegion());
775 
776  // Reassign mode labels
777  // Note: Labels are only computed when mode search optimization is enabled
778  if (m_ModeSearch)
779  {
780  // New labels will be consecutive. The following vector contains the new
781  // start label for each thread (its declaration, listing line 782, is not shown here).
783  newLabelOffset.SetSize(this->GetNumberOfThreads());
784  newLabelOffset[0] = 0;
785  for (itk::ThreadIdType i = 1; i < this->GetNumberOfThreads(); i++) // exclusive prefix sum of per-thread label counts
786  {
787  // Retrieve the number of labels in the thread by removing the threadId
788  // from the most significant bits
789  LabelType localNumLabel = m_NumLabels[i - 1] & ((static_cast<LabelType> (1) << (sizeof(LabelType) * 8
790  - m_ThreadIdNumberOfBits)) - static_cast<LabelType> (1));
791  newLabelOffset[i] = localNumLabel + newLabelOffset[i - 1];
792  }
793 
794  labelIt.GoToBegin();
795 
796  while (!labelIt.IsAtEnd())
797  {
798  LabelType const label = labelIt.Get();
799 
800  // Get threadId from most significant bits
801  const itk::ThreadIdType threadId = label >> (sizeof(LabelType) * 8 - m_ThreadIdNumberOfBits);
802 
803  // Relabeling
804  // First get the label number by removing the threadId bits
805  // Then add the label offset specific to the threadId
806  LabelType newLabel = label & ((static_cast<LabelType> (1) << (sizeof(LabelType) * 8 - m_ThreadIdNumberOfBits))
807  - static_cast<LabelType> (1));
808  newLabel += newLabelOffset[threadId];
809 
810  labelIt.Set(newLabel);
811  ++labelIt;
812  }
813  }
814 
815 }
816 
817 template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
819  std::ostream& os,
820  itk::Indent indent) const
821 { // Prints the superclass state, then the spatial and range bandwidths (other parameters are not printed)
822  Superclass::PrintSelf(os, indent);
823  os << indent << "Spatial bandwidth: " << m_SpatialBandwidth << std::endl;
824  os << indent << "Range bandwidth: " << m_RangeBandwidth << std::endl;
825 }
826 
827 } // end namespace otb
828 
829 #endif

Generated at Sat Mar 8 2014 16:07:59 for Orfeo Toolbox with doxygen 1.8.3.1