/*=========================================================================

  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$


  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef __otbGPUFineCorrelationImageFilter_txx
#define __otbGPUFineCorrelationImageFilter_txx

#include "otbGPUFineCorrelationImageFilter.h"

#include <cassert>
#include <iostream>
#include <vector>

#include "itkMatrix.h"
#include "itkVector.h"

#include "itkLinearInterpolateImageFunction.h"
#include "itkImageRegionIteratorWithIndex.h"

/**
 * Correlation routine with C linkage; its definition is not part of this file.
 *
 * As used by ThreadedGenerateData() in this file, it is invoked once every
 * nBlocks pixels with:
 *  - fixedD, movingD, corrD: the three pointers filled in by loadIntoGPU(),
 *  - fixWidth, movWidth: the widths of the fixed and moving working buffers,
 *  - fixStride, movStride: linear pixel offsets (index[0] + index[1]*width),
 *  - corr: a host buffer that the caller afterwards reads as nBlocks
 *    consecutive (2*searchRadius+1) x (2*searchRadius+1) score windows,
 *  - patchRadius, searchRadius: the first component of the filter radius and
 *    of the search radius respectively.
 *
 * NOTE(review): the precise meaning of each argument is defined by the
 * external implementation -- confirm against it.
 */
extern "C" void fineCorrelationProcessing(float * fixedD, float * movingD, float * corrD,
                                          int fixWidth, int movWidth,
                                          int fixStride, int movStride,
                                          float* corr, int nBlocks,
                                          int patchRadius, int searchRadius);

/**
 * Set-up routine with C linkage; its definition is not part of this file.
 *
 * As used by ThreadedGenerateData() in this file, it receives the host pixel
 * buffers of the fixed and moving working images, the addresses of three
 * pointers (initialised to NULL by the caller) that are later handed to
 * fineCorrelationProcessing() and freeGPU(), the width/height of both working
 * buffers and the number of blocks.
 *
 * NOTE(review): the only call in this file passes the *search* radius in the
 * patchRadius slot -- confirm which radius the implementation expects.
 */
extern "C" void loadIntoGPU(float* fixed, float* moving, float** fixedD, float** movingD, float** corrD,
                            int fixedWidth, int fixedHeight, int movingWidth, int movingHeight,
                            int nBlocks, int patchRadius);

/**
 * Tear-down routine with C linkage; its definition is not part of this file.
 * ThreadedGenerateData() calls it last, with the three pointers previously
 * filled in by loadIntoGPU(). Presumably it releases what loadIntoGPU()
 * allocated -- confirm against the implementation.
 */
extern "C" void freeGPU(float* fixedD, float* movingD, float* corrD);

namespace otb
{

/**
 * Constructor.
 *
 * Restricts the filter to a single thread, so that ThreadedGenerateData()
 * is not run concurrently by several threads of this filter.
 */
template <class TInputImage, class TOutputCorrelation, class TOutputDeformationField>
GPUFineCorrelationImageFilter<TInputImage,TOutputCorrelation,TOutputDeformationField>
::GPUFineCorrelationImageFilter()
{
  this->SetNumberOfThreads(1);
}

/**
 * Compute, for every pixel of outputRegionForThread, the best correlation
 * score over the search window and the corresponding deformation.
 *
 * The fixed and moving inputs are first copied into zero-filled working
 * images whose buffered regions are padded by the patch radius (fixed) and by
 * patch + search radius (moving). Those buffers are handed to loadIntoGPU();
 * fineCorrelationProcessing() is then called once every nBlocks pixels and
 * fills the host buffer with nBlocks consecutive score windows. For each
 * pixel the maximum of its window is located and optionally refined
 * (LSQR_QUADFIT or SUBPIXEL) before being written to both outputs.
 *
 * \param outputRegionForThread region of the outputs to produce
 * \param threadId              thread identifier, forwarded to the progress reporter
 */
template <class TInputImage, class TOutputCorrelation, class TOutputDeformationField>
void
GPUFineCorrelationImageFilter<TInputImage,TOutputCorrelation,TOutputDeformationField>
::ThreadedGenerateData( const OutputImageRegionType &outputRegionForThread,
                        int threadId)
{
  // Get the image pointers
  const TInputImage * fixedPtr = this->GetFixedInput();
  const TInputImage * movingPtr = this->GetMovingInput();
  TOutputCorrelation * outputPtr = this->GetOutput();
  TOutputDeformationField * outputDfPtr = this->GetOutputDeformationField();

  // Iterators
  NeighborhoodIteratorType fixedIt(this->GetRadius(), fixedPtr, outputRegionForThread);
  NeighborhoodIteratorType movingIt(this->GetSearchRadius() + this->GetRadius(), movingPtr, outputRegionForThread);
  itk::ImageRegionIteratorWithIndex<TOutputCorrelation> outputIt(outputPtr, outputRegionForThread);
  itk::ImageRegionIterator<TOutputDeformationField> outputDfIt(outputDfPtr, outputRegionForThread);

  // Boundary conditions
  itk::ZeroFluxNeumannBoundaryCondition<TInputImage> fixedNbc;
  itk::ZeroFluxNeumannBoundaryCondition<TInputImage> movingNbc;
  fixedIt.OverrideBoundaryCondition(&fixedNbc);
  movingIt.OverrideBoundaryCondition(&movingNbc);

  // support progress methods/callbacks
  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());

  // Correlation window of the current pixel. The storage is owned by a
  // std::vector so that it is released on every exit path (it used to be a
  // new[] without a matching delete[]). The window always holds at least one
  // element, hence taking the address of element 0 is valid.
  std::vector<typename NeighborhoodType::PixelType> correlationMapBuffer(
    (2*this->GetSearchRadius()[0]+1) * (2*this->GetSearchRadius()[1]+1));
  typename NeighborhoodType::PixelType* correlationMap = &correlationMapBuffer[0];

  // Setup for GPU: number of pixels handled by one fineCorrelationProcessing() call
  int nBlocks = 2*outputRegionForThread.GetSize()[0];
  // NOTE(review): the figure printed below uses radius*radius per block,
  // whereas the host buffer further down holds (2*radius+1)^2 floats per
  // block -- confirm which one reflects the actual allocation.
  std::cout << "\nMemory used in GPU for correlation: " << nBlocks*this->GetSearchRadius()[0]*this->GetSearchRadius()[1]*4./1000000. << " Mbytes" << std::endl;

  // Working copy of the fixed image, zero-padded by the patch radius
  typename InputImageType::Pointer fixedIntermediatePtr;
  fixedIntermediatePtr = InputImageType::New();
  fixedIntermediatePtr->CopyInformation(fixedPtr);
  InputImageRegionType extendedFixedRegion = outputRegionForThread;
  extendedFixedRegion.PadByRadius( this->GetRadius() );
  fixedIntermediatePtr->SetBufferedRegion(extendedFixedRegion);
  fixedIntermediatePtr->Allocate();
  fixedIntermediatePtr->FillBuffer(0);

  ImageConstIterator fixedItTmp = ImageConstIterator(fixedPtr, outputRegionForThread);
  ImageIterator fixedItInterm = ImageIterator(fixedIntermediatePtr, outputRegionForThread);
  for (fixedItTmp.GoToBegin(), fixedItInterm.GoToBegin(); !fixedItTmp.IsAtEnd(); ++fixedItTmp, ++fixedItInterm)
    {
    fixedItInterm.Set(fixedItTmp.Get());
    }

  // Working copy of the moving image, zero-padded by patch + search radius
  typename InputImageType::Pointer movingIntermediatePtr;
  movingIntermediatePtr = InputImageType::New();
  movingIntermediatePtr->CopyInformation(movingPtr);
  InputImageRegionType extendedMovingRegion = movingPtr->GetRequestedRegion();
  extendedMovingRegion.PadByRadius( this->GetRadius() + this->GetSearchRadius());
  movingIntermediatePtr->SetBufferedRegion(extendedMovingRegion);
  movingIntermediatePtr->Allocate();
  movingIntermediatePtr->FillBuffer(0);
  ImageConstIterator movingItTmp = ImageConstIterator(movingPtr, outputRegionForThread);
  ImageIterator movingItInterm = ImageIterator(movingIntermediatePtr, outputRegionForThread);
  for (movingItTmp.GoToBegin(), movingItInterm.GoToBegin(); !movingItTmp.IsAtEnd(); ++movingItTmp, ++movingItInterm)
    {
    movingItInterm.Set(movingItTmp.Get());
    }

  // Pointers filled in by loadIntoGPU() and released by freeGPU()
  float* fixedD = NULL;
  float* movingD = NULL;
  float* corrD = NULL;
  int fixedWidth = fixedIntermediatePtr->GetBufferedRegion().GetSize()[0];
  int fixedHeight = fixedIntermediatePtr ->GetBufferedRegion().GetSize()[1];
  int movingWidth = movingIntermediatePtr->GetBufferedRegion().GetSize()[0];
  int movingHeight = movingIntermediatePtr ->GetBufferedRegion().GetSize()[1];

  loadIntoGPU(fixedIntermediatePtr->GetBufferPointer(),
              movingIntermediatePtr->GetBufferPointer(),
              &fixedD, &movingD, &corrD,
              fixedWidth, fixedHeight, movingWidth, movingHeight, nBlocks,
              this->GetSearchRadius()[0]);//FIXME assuming square search

  // GoToBegin
  fixedIt.GoToBegin();
  movingIt.GoToBegin();
  outputIt.GoToBegin();
  outputDfIt.GoToBegin();

  // Offset
  OffsetType offset;

  // Max correl and its position
  float maxCorrel;
  OffsetType maxOffset;

  int count = 0;
  int searchWindowSizeX =  2*this->GetSearchRadius()[0]+1;
  int searchWindowSizeY =  2*this->GetSearchRadius()[1]+1;

  // Host-side scores: nBlocks consecutive search windows. Owned by a
  // std::vector so the memory is released even if an exception is thrown.
  std::vector<float> correlBuffer(searchWindowSizeX*searchWindowSizeY*nBlocks);
  float * correl = NULL;
  if (!correlBuffer.empty())
    {
    correl = &correlBuffer[0];
    }

  // Walk the images
  while (!fixedIt.IsAtEnd() && !movingIt.IsAtEnd() && !outputIt.IsAtEnd() && !outputDfIt.IsAtEnd())
    {
    // Initialize. maxOffset is reset for every pixel so that it never keeps
    // an indeterminate or stale value when no score exceeds the initial one.
    maxCorrel = -1.0;
    maxOffset.Fill(0);
    typename NeighborhoodIteratorType::IndexType index = fixedIt.GetIndex();
    DeformationValueType deformationValue;

    // Refill the host score buffer at the start of every batch of nBlocks pixels
    if ((count % nBlocks) == 0)
      {
      // NOTE(review): the offsets passed here do not include the padding of
      // the working buffers -- confirm this matches what the kernel expects.
      fineCorrelationProcessing(fixedD, movingD, corrD,
                              fixedWidth, movingWidth,
                              index[0]+index[1]*fixedWidth,
                              index[0]+index[1]*movingWidth,
                              correl, nBlocks,
                              this->GetRadius()[0], this->GetSearchRadius()[0]);//FIXME assuming square search and patch
      }

    // Copy the window of the current pixel and look for its maximum
    int correlIndex = 0;
    int radiusX = this->GetSearchRadius()[0];
    int radiusY = this->GetSearchRadius()[1];
    for (int j = -radiusY; j <= radiusY; ++j)
      {
      for (int i = -radiusX; i <= radiusX; ++i)
        {
        // Update offset
        offset[0] = i;
        offset[1] = j;

        // Position of (i, j) inside the window of block (count % nBlocks)
        int correlIndexTmp = i+radiusX + (j+radiusY+ (count % nBlocks)*searchWindowSizeY)*searchWindowSizeX;
        assert(correlIndexTmp  < searchWindowSizeX * searchWindowSizeY * nBlocks);
        // Check for maximum
        correlationMap[correlIndex] = correl[correlIndexTmp ];
        if (correlationMap[correlIndex] > maxCorrel)
          {
          maxCorrel = correlationMap[correlIndex];
          maxOffset = offset;
          }
        ++correlIndex;
        }
      }

    // Perform LSQR QUADFIT refinement
    if (this->GetRefinementMode() == LSQR_QUADFIT)
      {
      //FIXME adaptation of the correlationMap structure required
      maxCorrel = this->RefineLocation(correlationMap, maxOffset, deformationValue);
      }
    // Perform SUBPIXEL refinement
    else if (this->GetRefinementMode() == SUBPIXEL)
      {
      // Get the neighborhood
      NeighborhoodType fixedN = fixedIt.GetNeighborhood();

      // Compute the interpolated fine grid
      NeighborhoodType subPixelN = this->ComputeSubPixelNeighborhood(fixedIt.GetIndex() + maxOffset,
                                                                     this->GetSubPixelPrecision());

      // The fine offset
      OffsetType fineOffset, maxFineOffset;
      maxFineOffset.Fill(0);

      // Compute the correlation at each fine location
      for (int i = -(int) this->GetSubPixelPrecision() + 1; i < (int) this->GetSubPixelPrecision(); ++i)
        {
        for (int j = -(int) this->GetSubPixelPrecision() + 1; j < (int) this->GetSubPixelPrecision(); ++j)
          {
          // Update the fine offset
          fineOffset[0] = i;
          fineOffset[1] = j;

          // Compute the correlation (named so as not to hide the host score buffer)
          double fineCorrel = this->Correlation(fixedN, subPixelN, fineOffset, this->GetSubPixelPrecision());

          // If correlation is better
          if (fineCorrel > maxCorrel)
            {
            // Update values
            maxCorrel = fineCorrel;
            maxFineOffset = fineOffset;
            }
          }
        }

      // Finally, update deformation values
      deformationValue[0] = maxOffset[0] + (double) maxFineOffset[0] / (double) this->GetSubPixelPrecision();
      deformationValue[1] = maxOffset[1] + (double) maxFineOffset[1] / (double) this->GetSubPixelPrecision();

      }
    // Default and COARSE case
    else
      {
      deformationValue[0] = maxOffset[0];
      deformationValue[1] = maxOffset[1];
      }

    // Store the offset and the correlation value
    outputIt.Set(maxCorrel);
    outputDfIt.Set(deformationValue);

    // Update iterators
    ++fixedIt;
    ++movingIt;
    ++outputIt;
    ++outputDfIt;

    // Update progress
    progress.CompletedPixel();
    ++count;
    }

  // Host buffers are released by their std::vector owners
  freeGPU(fixedD, movingD, corrD);
}

} // end namespace otb

#endif
