19 #ifndef __otbMeanShiftSmoothingImageFilter_txx
20 #define __otbMeanShiftSmoothingImageFilter_txx
33 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Member-initializer list (constructor) defaults visible in this excerpt:
//  - m_RangeBandwidth = 16 and m_SpatialBandwidth = 3: the range and spatial parts
//    of the joint feature vector are expressed in units of these bandwidths;
//  - m_Threshold = 1e-3: a pixel's iteration stops once the bandwidth-rescaled
//    squared norm of its mean-shift step falls below this value;
//  - m_MaxIterationNumber = 10: upper bound on mean-shift iterations per pixel;
//  - m_BucketOptimization = false: the exhaustive window search is used by default.
35 m_RangeBandwidth(16.), m_SpatialBandwidth(3)
37 , m_Threshold(1e-3), m_MaxIterationNumber(10)
44 , m_BucketOptimization(false)
// NOTE(review): the template headers below introduce member definitions whose bodies
// are not present in this excerpt, so they cannot be documented further here.
54 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
59 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
66 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
73 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
80 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
87 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
94 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
101 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
108 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
115 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Allocates every output (spatial, range, iteration and label). Each output's
// buffered region is set to that output's own requested region before allocation.
// The spatial and label output pointers are obtained on lines not shown here.
119 typename OutputImageType::Pointer rangeOutputPtr = this->GetRangeOutput();
120 typename OutputIterationImageType::Pointer iterationOutputPtr = this->GetIterationOutput();
123 spatialOutputPtr->SetBufferedRegion(spatialOutputPtr->GetRequestedRegion());
124 spatialOutputPtr->Allocate();
126 rangeOutputPtr->SetBufferedRegion(rangeOutputPtr->GetRequestedRegion());
127 rangeOutputPtr->Allocate();
129 iterationOutputPtr->SetBufferedRegion(iterationOutputPtr->GetRequestedRegion());
130 iterationOutputPtr->Allocate();
132 labelOutputPtr->SetBufferedRegion(labelOutputPtr->GetRequestedRegion());
133 labelOutputPtr->Allocate();
136 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Output-information pass. After the superclass has propagated the standard
// information, this caches the input's number of components and sets the
// per-pixel vector length of the vector-valued outputs.
139 Superclass::GenerateOutputInformation();
141 m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
// Spatial output: one component per image dimension (a per-pixel spatial displacement).
143 if (this->GetSpatialOutput())
145 this->GetSpatialOutput()->SetNumberOfComponentsPerPixel(ImageDimension);
// Range output: as many components as the input image.
147 if (this->GetRangeOutput())
149 this->GetRangeOutput()->SetNumberOfComponentsPerPixel(m_NumberOfComponentsPerPixel);
153 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Input requested-region negotiation. The input must cover the range output's
// requested region padded by the spatial kernel radius, because each output pixel
// reads neighbors within that radius.
157 Superclass::GenerateInputRequestedRegion();
// Guard against a missing input or range output (the guarded statements are elided here).
164 if (!inPtr || !outRangePtr)
171 RegionType outputRequestedRegion = outRangePtr->GetRequestedRegion();
174 RegionType inputRequestedRegion = outputRequestedRegion;
// Per-dimension pixel radius derived from the kernel and the spatial bandwidth.
177 m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));
179 inputRequestedRegion.PadByRadius(m_SpatialRadius);
// Crop the padded region to the input's largest possible region. When cropping
// succeeds, request the cropped region.
182 if (inputRequestedRegion.Crop(inPtr->GetLargestPossibleRegion()))
184 inPtr->SetRequestedRegion(inputRequestedRegion);
// Failure branch (partially elided): the requested region is still stored, then an
// exception describing the out-of-bounds request is prepared.
193 inPtr->SetRequestedRegion(inputRequestedRegion);
198 e.
SetDescription(
"Requested region is (at least partially) outside the largest possible region.");
205 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Pre-processing shared by all threads. It allocates and initializes the outputs,
// builds the joint (spatial + range) image, optionally buckets it, allocates the
// mode table and seeds the per-thread label counters.
213 typename InputImageType::ConstPointer inputPtr = this->GetInput();
214 typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
220 m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));
222 m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
225 this->AllocateOutputs();
// Iteration counts and spatial displacements start at zero.
228 iterationOutput->FillBuffer(0);
231 spatialOutput->FillBuffer(zero);
239 JointImageFunctorType;
// Build the joint image from the input. Later code converts its spatial components
// to pixel indices by multiplying by m_SpatialBandwidth, and its range components to
// range values by multiplying by m_RangeBandwidth (the Initialize argument list is
// partly elided).
241 typename JointImageFunctorType::Pointer jointImageFunctor = JointImageFunctorType::New();
243 jointImageFunctor->SetInput(inputPtr);
244 jointImageFunctor->GetFunctor().Initialize(ImageDimension, m_NumberOfComponentsPerPixel, m_SpatialBandwidth,
246 jointImageFunctor->Update();
247 m_JointImage = jointImageFunctor->GetOutput();
// Optional acceleration structure: bucket the joint image so neighbor searches only
// visit nearby buckets instead of the whole spatial window.
250 if (m_BucketOptimization)
255 m_BucketImage = BucketImageType(static_cast<typename RealVectorImageType::ConstPointer> (m_JointImage),
256 m_JointImage->GetRequestedRegion(), m_Kernel.GetRadius(m_SpatialBandwidth), 1,
// Mode table over the input's requested region, initialized to 0. Later in this file,
// 2 marks a pixel recorded on the trajectory being followed and 1 marks a pixel whose
// mode has already been assigned.
296 m_ModeTable = ModeTableImageType::New();
297 m_ModeTable->SetRegions(inputPtr->GetRequestedRegion());
298 m_ModeTable->Allocate();
299 m_ModeTable->FillBuffer(0);
// Compute how many high-order bits of LabelType are reserved for the thread id.
// The counting loop is partially elided here; at least one bit is always used.
312 unsigned int numThreads;
314 numThreads = this->GetNumberOfThreads();
315 m_ThreadIdNumberOfBits = -1;
316 unsigned int n = numThreads;
320 m_ThreadIdNumberOfBits++;
322 if (m_ThreadIdNumberOfBits == 0) m_ThreadIdNumberOfBits = 1;
// Each thread's label counter starts with its thread index in the top
// m_ThreadIdNumberOfBits bits. Labels from different threads therefore do not collide,
// as long as a thread's count stays below those bits.
323 m_NumLabels.SetSize(numThreads);
324 for (
unsigned int i = 0; i < numThreads; i++)
326 m_NumLabels[i] =
static_cast<LabelType> (i) << (
sizeof(
LabelType) * 8 - m_ThreadIdNumberOfBits);
334 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Mean-shift vector by exhaustive search. It visits every pixel of the spatial
// window of radius m_SpatialRadius around the joint pixel's position, clipped to the
// region argument, and accumulates kernel-weighted joint neighbors. The result is the
// weighted mean of the neighbors minus the joint pixel itself.
341 const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
// The caller must supply an output vector of the joint dimension; it is zeroed here.
347 assert(meanShiftVector.
GetSize() == jointDimension);
348 meanShiftVector.
Fill(0);
// Map the normalized spatial coordinates back to a pixel index and clip the window
// to the region. NOTE(review): this truncates (no +0.5), while the bucket variant and
// the mode-candidate computation round with +0.5. Confirm the difference is intended.
351 for (
unsigned int comp = 0; comp < ImageDimension; ++comp)
353 inputIndex[comp] = jointPixel[comp] * m_SpatialBandwidth;
355 regionIndex[comp] = vcl_max(static_cast<long int> (outputRegion.GetIndex().GetElement(comp)),
356 static_cast<long int> (inputIndex[comp] - m_SpatialRadius[comp]));
357 const long int indexRight = vcl_min(
358 static_cast<long int> (outputRegion.GetIndex().GetElement(comp)
359 + outputRegion.GetSize().GetElement(comp) - 1),
360 static_cast<long int> (inputIndex[comp] + m_SpatialRadius[comp]));
362 regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int> (regionIndex[comp]) + 1);
366 neighborhoodRegion.SetIndex(regionIndex);
367 neighborhoodRegion.SetSize(regionSize);
// Scratch vector holding one neighbor's joint value.
370 RealVector jointNeighbor(ImageDimension + m_NumberOfComponentsPerPixel);
// Joint-space difference between neighbor and current pixel. It feeds norm2, which is
// presumably the squared distance (its accumulation is elided here) and is passed to the kernel.
385 for (
unsigned int comp = 0; comp < jointDimension; comp++)
387 const RealType d = jointNeighbor[comp] - jointPixel[comp];
392 const RealType weight = m_Kernel(norm2);
// Accumulate the kernel-weighted neighbor.
428 for (
unsigned int comp = 0; comp < jointDimension; comp++)
430 meanShiftVector[comp] += weight * jointNeighbor[comp];
// Normalize by the total weight and subtract the current pixel. Any guard against a
// zero weightSum would be on lines not shown here.
438 for (
unsigned int comp = 0; comp < jointDimension; comp++)
440 meanShiftVector[comp] = meanShiftVector[comp] / weightSum - jointPixel[comp];
447 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Bucket-accelerated mean-shift vector. Instead of scanning the whole spatial window,
// it only visits points stored in the buckets neighboring the joint pixel's bucket.
// The result has the same form as the exhaustive version: the kernel-weighted mean of
// the neighbors minus the joint pixel.
449 const RealVector& jointPixel,
450 RealVector& meanShiftVector)
452 const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
454 RealType weightSum = 0;
// Reset the output accumulator.
456 for (
unsigned int comp = 0; comp < jointDimension; comp++)
458 meanShiftVector[comp] = 0;
461 RealVector jointNeighbor(ImageDimension + m_NumberOfComponentsPerPixel);
// Rounded pixel index of the joint pixel's spatial position, used to locate its bucket.
463 InputIndexType index;
464 for (
unsigned int dim = 0; dim < ImageDimension; ++dim)
466 index[dim] = jointPixel[dim] * m_SpatialBandwidth + 0.5;
469 const std::vector<unsigned int>
470 neighborBuckets = m_BucketImage.GetNeighborhoodBucketListIndices(
471 m_BucketImage.BucketIndexToBucketListIndex(
472 m_BucketImage.GetBucketIndex(
476 unsigned int numNeighbors = m_BucketImage.GetNumberOfNeighborBuckets();
// Visit every point stored in the neighboring buckets (empty buckets are skipped).
// jointNeighbor is pointed at each stored point's data through SetData. In ITK this
// presumably makes a non-owning view; confirm.
477 for (
unsigned int neighborIndex = 0; neighborIndex < numNeighbors; ++neighborIndex)
479 const typename BucketImageType::BucketType & bucket = m_BucketImage.GetBucket(neighborBuckets[neighborIndex]);
480 if (bucket.empty())
continue;
481 typename BucketImageType::BucketType::const_iterator it = bucket.begin();
482 while (it != bucket.end())
484 jointNeighbor.SetData(const_cast<RealType*> (*it));
// Joint-space difference feeding norm2 (accumulation elided), which is passed to the kernel.
489 for (
unsigned int comp = 0; comp < jointDimension; comp++)
491 const RealType d = jointNeighbor[comp] - jointPixel[comp];
496 const RealType weight = m_Kernel(norm2);
// Accumulate the kernel-weighted neighbor.
502 for (
unsigned int comp = 0; comp < jointDimension; comp++)
504 meanShiftVector[comp] += weight * jointNeighbor[comp];
// Normalize by the total weight and subtract the current pixel.
513 for (
unsigned int comp = 0; comp < jointDimension; comp++)
515 meanShiftVector[comp] = meanShiftVector[comp] / weightSum - jointPixel[comp];
521 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Per-thread mean-shift filtering of outputRegionForThread. For each pixel, the
// mean-shift update is iterated until convergence or m_MaxIterationNumber. The code
// then writes the range value and spatial displacement of the reached position, the
// iteration count and a label.
531 typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput();
532 typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
536 typename InputImageType::ConstPointer input = this->GetInput();
544 const unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
546 typename OutputImageType::PixelType rangePixel(m_NumberOfComponentsPerPixel);
// Per-component bandwidths used to rescale each mean-shift step before the
// convergence test: m_SpatialBandwidth for spatial components, m_RangeBandwidth for
// range components.
552 for (
unsigned int comp = 0; comp < ImageDimension; comp++)
553 bandwidth[comp] = m_SpatialBandwidth;
554 for (
unsigned int comp = ImageDimension; comp < jointDimension; comp++)
555 bandwidth[comp] = m_RangeBandwidth;
// The exhaustive neighbor search is bounded by the input's requested region.
559 RegionType const& requestedRegion = input->GetRequestedRegion();
562 JointImageIteratorType jointIt(m_JointImage, outputRegionForThread);
564 OutputIteratorType rangeIt(rangeOutput, outputRegionForThread);
565 OutputSpatialIteratorType spatialIt(spatialOutput, outputRegionForThread);
566 OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
567 OutputLabelIteratorType labelIt(labelOutput, outputRegionForThread);
570 ModeTableImageIteratorType modeTableIt(m_ModeTable, outputRegionForThread);
574 spatialIt.GoToBegin();
575 iterationIt.GoToBegin();
576 modeTableIt.GoToBegin();
579 unsigned int iteration = 0;
// With mode search enabled, pixels reached along a trajectory are recorded so the
// final result can be propagated to all of them. At most one is recorded per
// iteration, hence the m_MaxIterationNumber capacity.
586 std::vector<InputIndexType> pointList;
587 if (m_ModeSearch) pointList.resize(m_MaxIterationNumber);
590 unsigned int numBreaks = 0;
594 for (; !jointIt.IsAtEnd(); ++jointIt, ++rangeIt, ++spatialIt, ++iterationIt, ++modeTableIt, ++labelIt, progress.CompletedPixel())
// Pixels already assigned through another trajectory (mode table value 1) are
// handled separately (branch body elided here).
599 if (m_ModeSearch && currentPixelMode == 1)
605 bool hasConverged =
false;
609 jointPixel = jointIt.Get();
615 unsigned int pointCount = 0;
617 while ((iteration < m_MaxIterationNumber) && (!hasConverged))
// Rounded pixel index of the current spatial position, checked as a candidate
// for joining an already-processed pixel.
623 for (
unsigned int comp = 0; comp < ImageDimension; comp++)
625 modeCandidate[comp] = jointPixel[comp] * m_SpatialBandwidth + 0.5;
// Candidate must differ from the current pixel, must not already be on this
// trajectory (value 2) and must belong to this thread's region.
// NOTE(review): m_ModeTable->GetPixel(modeCandidate) is evaluated before the IsInside
// test. This relies on modeCandidate lying inside the mode table's buffer (the input
// requested region); testing IsInside first would be safer.
633 if (modeCandidate != currentIndex && m_ModeTable->GetPixel(modeCandidate) != 2
634 && outputRegionForThread.IsInside(modeCandidate))
// Range-component difference between the candidate and the current joint pixel
// (accumulation and test elided).
638 RealVector const& candidatePixel = m_JointImage->GetPixel(modeCandidate);
639 for (
unsigned int comp = ImageDimension; comp < jointDimension; comp++)
641 const RealType d = candidatePixel[comp] - jointPixel[comp];
// Unvisited candidate (value 0): record it on the current trajectory and mark it 2.
649 if (m_ModeTable->GetPixel(modeCandidate) == 0)
653 pointList[pointCount++] = modeCandidate;
654 m_ModeTable->SetPixel(modeCandidate, 2);
// Copy the candidate's range output into the joint pixel, normalized by
// m_RangeBandwidth (the enclosing branch is elided here).
660 rangePixel = rangeOutput->GetPixel(modeCandidate);
661 for (
unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
663 jointPixel[ImageDimension + comp] = rangePixel[comp] / m_RangeBandwidth;
// Compute the mean-shift step, with or without the bucket acceleration.
678 if (m_BucketOptimization)
680 this->CalculateMeanShiftVectorBucket(jointPixel, meanShiftVector);
685 this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, meanShiftVector);
// Apply the step. Convergence is tested on the squared norm of the step after
// rescaling each component by its bandwidth.
693 double meanShiftVectorSqNorm = 0;
694 for (
unsigned int comp = 0; comp < jointDimension; comp++)
696 const double v = meanShiftVector[comp] * bandwidth[comp];
697 meanShiftVectorSqNorm += v * v;
698 jointPixel[comp] += meanShiftVector[comp];
702 hasConverged = meanShiftVectorSqNorm < m_Threshold;
// Convert the final joint position back to output units: the range value, and the
// spatial displacement relative to the current pixel index.
706 for (
unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
708 rangePixel[comp] = jointPixel[ImageDimension + comp] * m_RangeBandwidth;
711 for (
unsigned int comp = 0; comp < ImageDimension; comp++)
713 spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth - currentIndex[comp];
716 rangeIt.Set(rangePixel);
717 spatialIt.Set(spatialPixel);
719 const typename OutputIterationImageType::PixelType iterationPixel = iteration;
720 iterationIt.Set(iterationPixel);
// A trajectory that ended on its own (converged or hit the iteration limit) gets a new
// thread-local label. Otherwise (branch structure partially elided) the label is taken
// from the mode candidate's label output.
729 if (hasConverged || iteration == m_MaxIterationNumber)
731 m_NumLabels[threadId]++;
732 label = m_NumLabels[threadId];
736 label = labelOutput->GetPixel(modeCandidate);
// Propagate the range value and label to every pixel recorded on the trajectory and
// mark them as assigned (mode table value 1).
741 for (
unsigned int i = 0; i < pointCount; i++)
743 rangeOutput->SetPixel(pointList[i], rangePixel);
744 m_ModeTable->SetPixel(pointList[i], 1);
745 labelOutput->SetPixel(pointList[i], label);
751 labelIt.Set(labelZero);
759 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Post-processing of the label output. Labels were produced per thread with the
// thread id in the high bits. Each label is shifted by the cumulative label count of
// the lower-numbered threads so that labels form one consecutive range. The
// count/thread-id masking arithmetic is partially elided here.
764 OutputLabelIteratorType labelIt(labelOutput, labelOutput->GetRequestedRegion());
773 newLabelOffset.
SetSize(this->GetNumberOfThreads());
774 newLabelOffset[0] = 0;
// NOTE(review): `int i` is compared with GetNumberOfThreads(); if that returns an
// unsigned type this triggers a signed/unsigned comparison warning.
775 for (
int i = 1; i < this->GetNumberOfThreads(); i++)
780 - m_ThreadIdNumberOfBits)) -
static_cast<LabelType> (1));
781 newLabelOffset[i] = localNumLabel + newLabelOffset[i - 1];
786 while (!labelIt.IsAtEnd())
// Recover the producing thread from the top m_ThreadIdNumberOfBits bits of the label.
791 const unsigned int threadId = label >> (
sizeof(
LabelType) * 8 - m_ThreadIdNumberOfBits);
798 newLabel += newLabelOffset[threadId];
800 labelIt.Set(newLabel);
807 template<
class TInputImage,
class TOutputImage,
class TKernel,
class TOutputIterationImage>
// Prints the superclass state, followed by the spatial and range bandwidths.
812 Superclass::PrintSelf(os, indent);
813 os << indent <<
"Spatial bandwidth: " << m_SpatialBandwidth << std::endl;
814 os << indent <<
"Range bandwidth: " << m_RangeBandwidth << std::endl;