/*
 * otbGPUSAMFilter.txx
 *
 *  Created on: Apr 26, 2010
 *      Author: christop
 */

#ifndef __otbGPUSAMFilter_txx
#define __otbGPUSAMFilter_txx

#include "otbGPUSAMFilter.h"

extern "C" void samProcessing(float* pix, float* sam, int numBands, const float* ref, int imageWidth, int imageHeight);

namespace otb
{

// Note: GenerateInputRequestedRegion is not used to enlarge the requested
// region, because the block alignment has to hold for the region processed by
// each thread, not only for the overall requested region. A possible solution
// would be to replace the ThreaderCallback with a region splitter that
// preserves block alignment. This needs to be investigated.

/*-------------------------------------------------------
 * ThreadedGenerateData
 --------------------------------------------------------*/
/**
 * Runs the SAM computation (samProcessing) on outputRegionForThread.
 *
 * The region is first grown by PadBlockRegion so that its size is a multiple
 * of m_BlkSize. When growing was needed, the input pixels of
 * outputRegionForThread are copied into a temporary image covering the padded
 * region, the kernel runs on the temporary buffers, and the result for
 * outputRegionForThread is copied back into the output. Otherwise the kernel
 * reads the input buffer and writes the output buffer directly.
 *
 * NOTE(review): this method stores per-call state in data members
 * (m_ExtendedRegion, m_InputIntermediatePtr, m_OutputIntermediatePtr) and
 * (re)allocates the output buffer, so it is only correct when a single thread
 * executes it. Presumably the filter restricts itself to one thread elsewhere
 * (header/constructor) -- verify.
 *
 * \param outputRegionForThread Region of the output to compute.
 * \param threadId Not used.
 */
template <class TInputPointSet, class TOutputImage>
void
GPUSAMFilter<TInputPointSet, TOutputImage>
::ThreadedGenerateData(
    const OutputImageRegionType& outputRegionForThread,
    int threadId)
{
  typename InputImageType::Pointer inputPtr =
    dynamic_cast<InputImageType *>(this->itk::ProcessObject::GetInput(0));
  typename OutputImageType::Pointer outputPtr =
    dynamic_cast<OutputImageType *>(this->itk::ProcessObject::GetOutput(0));

  // Fail with a clear error instead of dereferencing a null pointer below when
  // the input/output is missing or is not of the expected image type.
  if (inputPtr.IsNull() || outputPtr.IsNull())
    {
    itkExceptionMacro(<< "Input or output image is missing or is not of the expected type.");
    }

  // Padding is needed only when the region processed here is not already a
  // whole number of blocks; what matters is the region for this call, not the
  // output's overall requested region.
  m_ExtendedRegion = PadBlockRegion(outputRegionForThread);
  const bool padding = (m_ExtendedRegion != outputRegionForThread);

  if (padding)
    {
    // Temporary output covering the padded region.
    m_OutputIntermediatePtr = OutputImageType::New();
    m_OutputIntermediatePtr->CopyInformation(outputPtr);
    m_OutputIntermediatePtr->SetBufferedRegion(m_ExtendedRegion);
    m_OutputIntermediatePtr->Allocate();

    // Temporary input covering the padded region. Only the pixels of
    // outputRegionForThread are copied; the padding pixels are left
    // uninitialized and their results are discarded below.
    m_InputIntermediatePtr = InputImageType::New();
    m_InputIntermediatePtr->CopyInformation(inputPtr);
    m_InputIntermediatePtr->SetBufferedRegion(m_ExtendedRegion);
    m_InputIntermediatePtr->Allocate();

    InputImageConstIterator it = InputImageConstIterator(inputPtr, outputRegionForThread);
    InputImageIterator itOut = InputImageIterator(m_InputIntermediatePtr, outputRegionForThread);
    for (it.GoToBegin(), itOut.GoToBegin(); !it.IsAtEnd(); ++it, ++itOut)
      {
      itOut.Set(it.Get());
      }
    }
  else
    {
    // Region already block-aligned: let the kernel work in place on the
    // real input and output buffers.
    outputPtr->SetBufferedRegion(outputRegionForThread);
    outputPtr->Allocate();
    m_OutputIntermediatePtr = outputPtr;
    m_InputIntermediatePtr = inputPtr;
    }

  this->DoProcessing();

  // In the padded case, copy the meaningful part of the result back into the
  // output. In the unpadded case the kernel already wrote into outputPtr's
  // buffer, so there is nothing left to do.
  if (padding)
    {
    outputPtr->SetBufferedRegion(outputRegionForThread);
    outputPtr->Allocate();
    OutputImageConstIterator it = OutputImageConstIterator(m_OutputIntermediatePtr, outputRegionForThread);
    OutputImageIterator itOut = OutputImageIterator(outputPtr, outputRegionForThread);
    for (it.GoToBegin(), itOut.GoToBegin(); !it.IsAtEnd(); ++it, ++itOut)
      {
      itOut.Set(it.Get());
      }
    }
}

/**
 * Returns a copy of region with the same index, whose size along the first two
 * axes is rounded up to the next multiple of m_BlkSize on that axis.
 *
 * An axis whose block size is 0 is left unchanged (it would otherwise cause a
 * division by zero).
 */
template<class TInputPointSet, class TOutputImage>
typename GPUSAMFilter<TInputPointSet, TOutputImage>::RegionType
GPUSAMFilter<TInputPointSet, TOutputImage>
::PadBlockRegion(RegionType region)
{
  typename RegionType::SizeType size = region.GetSize();
  for (unsigned int dim = 0; dim < 2; ++dim)
    {
    if (m_BlkSize[dim] == 0)
      {
      continue;
      }
    const typename RegionType::SizeType::SizeValueType remainder = size[dim] % m_BlkSize[dim];
    if (remainder != 0)
      {
      size[dim] += m_BlkSize[dim] - remainder;
      }
    }
  region.SetSize(size);
  return region;
}

/**
 * Calls samProcessing on the buffers of m_InputIntermediatePtr (input pixels)
 * and m_OutputIntermediatePtr (SAM result), using the reference pixel returned
 * by GetReferencePixel().
 *
 * Both intermediate images are expected to have buffered regions of the same
 * size (checked by assert in debug builds), because a single width/height is
 * passed to the kernel for both buffers.
 * NOTE(review): the reference pixel is presumably expected to have numBands
 * components -- not checked here; confirm.
 */
template<class TInputPointSet, class TOutputImage>
void GPUSAMFilter<TInputPointSet, TOutputImage>
::DoProcessing()
{
  assert(m_OutputIntermediatePtr->GetBufferedRegion().GetSize()[0] == m_InputIntermediatePtr->GetBufferedRegion().GetSize()[0]);
  assert(m_OutputIntermediatePtr->GetBufferedRegion().GetSize()[1] == m_InputIntermediatePtr->GetBufferedRegion().GetSize()[1]);

  float * pix = m_InputIntermediatePtr->GetBufferPointer();
  float * sam = m_OutputIntermediatePtr->GetBufferPointer();
  const int numBands = static_cast<int>(m_InputIntermediatePtr->GetNumberOfComponentsPerPixel());

  // The kernel's interface takes int dimensions.
  const int imageWidth = static_cast<int>(m_OutputIntermediatePtr->GetBufferedRegion().GetSize()[0]);
  const int imageHeight = static_cast<int>(m_OutputIntermediatePtr->GetBufferedRegion().GetSize()[1]);

  samProcessing(pix, sam, numBands, this->GetReferencePixel().GetDataPointer(), imageWidth, imageHeight);
}

} // end namespace otb

#endif
