/*=========================================================================

  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$


  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#ifndef __otbCuSVMImageClassificationFilter_h
#define __otbCuSVMImageClassificationFilter_h

#include "itkInPlaceImageFilter.h"
#include "itkImageRegionConstIterator.h"
#include "itkImageRegionIterator.h"

#include <vector>

/** GPU SVM prediction routine, implemented outside this header (C linkage).
 *
 *  As called by CuSVMImageClassificationFilter::GenerateData():
 *   - m          : number of samples to classify (pixels of the region),
 *   - n          : number of support vectors,
 *   - k          : number of features per sample,
 *   - Test       : m*k sample values in feature-major order
 *                  (feature i of sample p at Test[p + i*m]),
 *   - Svs, alphas: support vectors and their coefficients,
 *   - prediction : buffer of m floats read back by the caller after the call
 *                  (presumably filled with one predicted value per sample).
 *  NOTE(review): the exact meaning of kernelwidth, beta and isregression
 *  (and the expected layout of Svs) is defined by the GPU implementation —
 *  confirm there.
 */
extern "C"
void GPUPredictWrapper (int m, int n, int k, float kernelwidth, const float *Test, const float *Svs, float * alphas,float *prediction, float beta,float isregression);

namespace otb
{
/** \class CuSVMImageClassificationFilter
 *  \brief Labels each pixel of an image with an SVM evaluated on the GPU
 *  through GPUPredictWrapper().
 *
 *  Each input pixel is one sample. Its components are rescaled with the
 *  per-feature bounds given by SetMinimum()/SetMaximum(), i.e.
 *  (x - min) / (max - min); a feature whose bounds are equal is mapped to 0.
 *  All samples of the requested region are sent in a single call to
 *  GPUPredictWrapper(), together with the support vectors, alphas, beta and
 *  kernel width configured on the filter. Pixels whose predicted value is
 *  negative are labelled 1, all other pixels are labelled 2.
 *
 *  The support-vector and alpha buffers are NOT owned by the filter: the
 *  caller must keep them alive for as long as the filter may execute.
 *
 * \sa SVMClassifier
 * \ingroup Streamed
 */
template <class TInputImage, class TOutputImage>
class ITK_EXPORT CuSVMImageClassificationFilter
  : public itk::InPlaceImageFilter<TInputImage, TOutputImage>
{
public:
  /** Standard typedefs */
  typedef CuSVMImageClassificationFilter                     Self;
  typedef itk::InPlaceImageFilter<TInputImage, TOutputImage> Superclass;
  typedef itk::SmartPointer<Self>                            Pointer;
  typedef itk::SmartPointer<const Self>                      ConstPointer;

  /** Method for creation through the object factory. */
  itkNewMacro(Self);

  /** Run-time type information (and related methods). */
  itkTypeMacro(CuSVMImageClassificationFilter, InPlaceImageFilter);

  /** Input image typedefs. Each input pixel is one sample to classify. */
  typedef TInputImage                                InputImageType;
  typedef typename InputImageType::ConstPointer      InputImageConstPointerType;
  typedef typename InputImageType::InternalPixelType ValueType;
  typedef typename InputImageType::PixelType         SampleType;

  /** Output (label) image typedefs. */
  typedef TOutputImage                         OutputImageType;
  typedef typename OutputImageType::Pointer    OutputImagePointerType;
  typedef typename OutputImageType::RegionType OutputImageRegionType;
  typedef typename OutputImageType::PixelType  LabelType;

  /** Set the per-feature lower bounds used to rescale the samples.
   *  Must have at least as many components as an input pixel. */
  void SetMinimum(const SampleType & min)
  {
    m_Minimum = min;
    this->Modified();
  }

  /** Set the per-feature upper bounds used to rescale the samples.
   *  Must have at least as many components as an input pixel. */
  void SetMaximum(const SampleType & max)
  {
    m_Maximum = max;
    this->Modified();
  }

  /** Set the support-vector buffer passed to GPUPredictWrapper().
   *  The buffer is not copied nor owned: it must outlive the filter's use. */
  void SetSupportVectors(float * svs)
  {
    m_SupportVectors = svs;
    this->Modified();
  }

  /** Set the alpha coefficients buffer passed to GPUPredictWrapper().
   *  The buffer is not copied nor owned: it must outlive the filter's use. */
  void SetAlphas(float * alphas)
  {
    m_Alphas = alphas;
    this->Modified();
  }

  /** Set the beta value passed to GPUPredictWrapper(). */
  void SetBeta(const float & beta)
  {
    m_Beta = beta;
    this->Modified();
  }

  /** Kernel width passed to GPUPredictWrapper(). Stored as float (the type
   *  the GPU routine expects) so that fractional widths are not truncated. */
  itkSetMacro(KernelWidth, float);

  /** Number of support vectors held in the support-vector buffer. */
  itkSetMacro(NumberOfSupportVectors, unsigned int);

protected:
  /** Constructor */
  CuSVMImageClassificationFilter()
    : m_SupportVectors(NULL),
      m_Alphas(NULL),
      m_Beta(0.0f),
      m_KernelWidth(1.0f),
      m_NumberOfSupportVectors(0)
  {
    this->SetInPlace(false);
  }

  /** Destructor */
  virtual ~CuSVMImageClassificationFilter() {}

  /** Rescale every sample of the requested region, classify all of them in a
   *  single GPU call and write the resulting labels (1 or 2).
   *  \throw itk::ExceptionObject if the support vectors/alphas are not set or
   *  if the min/max vectors have fewer components than an input pixel. */
  virtual void GenerateData()
  {
    OutputImagePointerType     outputPtr = this->GetOutput();
    InputImageConstPointerType inputPtr  = this->GetInput();

    const OutputImageRegionType outputRegion = outputPtr->GetRequestedRegion();

    outputPtr->SetBufferedRegion(outputRegion);
    outputPtr->Allocate();

    const std::size_t numberOfPixels = outputRegion.GetNumberOfPixels();

    // Nothing to classify; also avoids dereferencing an iterator on an empty region.
    if (numberOfPixels == 0)
      {
      return;
      }

    if (m_SupportVectors == NULL || m_Alphas == NULL)
      {
      itkExceptionMacro(<< "Support vectors and alphas must be set before running the filter.");
      }

    itk::ImageRegionConstIterator<InputImageType> inIt(inputPtr, outputRegion);
    inIt.GoToBegin();

    const unsigned int sampleSize = inIt.Get().Size();

    // The rescaling below reads m_Minimum[i] / m_Maximum[i] for every feature.
    if (m_Minimum.Size() < sampleSize || m_Maximum.Size() < sampleSize)
      {
      itkExceptionMacro(<< "Minimum and maximum must have at least " << sampleSize << " components.");
      }

    // Per-feature rescaling parameters, computed once and in double precision
    // so that integer pixel types are not normalized with integer division.
    std::vector<double> lower(sampleSize);
    std::vector<double> range(sampleSize);
    for (unsigned int i = 0; i < sampleSize; ++i)
      {
      lower[i] = static_cast<double>(m_Minimum[i]);
      range[i] = static_cast<double>(m_Maximum[i]) - lower[i];
      }

    // Feature-major layout handed to the GPU wrapper:
    // feature i of sample p is stored at tests[p + i * numberOfPixels].
    std::vector<float> tests(sampleSize * numberOfPixels);
    std::vector<float> predictions(numberOfPixels, 0.0f);

    std::size_t sampleIndex = 0;
    for (; !inIt.IsAtEnd(); ++inIt, ++sampleIndex)
      {
      const SampleType sample = inIt.Get();
      for (unsigned int i = 0; i < sampleSize; ++i)
        {
        const double value = static_cast<double>(sample[i]);
        // A constant feature (max == min) would otherwise produce inf/NaN.
        tests[sampleIndex + i * numberOfPixels] =
          (range[i] != 0.0) ? static_cast<float>((value - lower[i]) / range[i]) : 0.0f;
        }
      }

    // A zero-component pixel yields an empty sample buffer.
    float * testsData = tests.empty() ? NULL : &tests[0];

    // Call GPUs (classification mode: isregression = 0).
    GPUPredictWrapper(static_cast<int>(numberOfPixels),
                      static_cast<int>(m_NumberOfSupportVectors),
                      static_cast<int>(sampleSize),
                      m_KernelWidth,
                      testsData,
                      m_SupportVectors,
                      m_Alphas,
                      &predictions[0],
                      m_Beta,
                      0.0f);

    itk::ImageRegionIterator<OutputImageType> outIt(outputPtr, outputRegion);
    sampleIndex = 0;
    for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt, ++sampleIndex)
      {
      // Negative predicted value -> label 1, otherwise (including NaN) -> label 2.
      outIt.Set(predictions[sampleIndex] < 0 ? 1 : 2);
      }
  }

  /** PrintSelf method */
  virtual void PrintSelf(std::ostream& os, itk::Indent indent) const
  {
    Superclass::PrintSelf(os, indent);
    os << indent << "NumberOfSupportVectors: " << m_NumberOfSupportVectors << std::endl;
    os << indent << "KernelWidth: " << m_KernelWidth << std::endl;
    os << indent << "Beta: " << m_Beta << std::endl;
  }

private:
  CuSVMImageClassificationFilter(const Self &); //purposely not implemented
  void operator =(const Self&); //purposely not implemented

  /** Non-owning pointer to the support vectors (caller keeps it alive). */
  float * m_SupportVectors;
  /** Non-owning pointer to the alpha coefficients (caller keeps it alive). */
  float * m_Alphas;
  /** Beta value forwarded to GPUPredictWrapper(). */
  float m_Beta;
  /** Kernel width forwarded to GPUPredictWrapper(). */
  float m_KernelWidth;
  /** Number of support vectors in m_SupportVectors. */
  unsigned int m_NumberOfSupportVectors;
  /** Per-feature lower bounds used to rescale samples. */
  SampleType m_Minimum;
  /** Per-feature upper bounds used to rescale samples. */
  SampleType m_Maximum;

};
} // End namespace otb


#endif
