/*=========================================================================

  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$


  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
// Include guard named after the class declared in this header
// (CuSVMListSampleEstimator). The guard previously carried the name of a
// different class (SVMModelEstimator); if that class's header uses the same
// macro, including both headers in one translation unit would silently skip
// whichever came second.
#ifndef __otbCuSVMListSampleEstimator_h
#define __otbCuSVMListSampleEstimator_h

#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

#include "itkObject.h"
#include "itkObjectFactory.h"

// Training routine implemented on the Cuda side and linked in with C linkage.
// The argument roles noted below are the ones observable at the call site in
// CuSVMListSampleEstimator::Compute(); the routine itself is not visible from
// this header, so treat them as the caller's view of the contract.
extern "C" void SVMTrain(
  float * mexalpha,    // in/out: one coefficient per sample, zero-filled by the caller
  float * beta,        // out: address of a single float (the estimator's m_Beta)
  float * y,           // in: one label per sample, each -1 or +1
  float * x,           // in: m*n values, component i of sample s at x[s + i*m]
  float   C,           // in: regularization parameter
  float   kernelwidth, // in: kernel width
  int     m,           // in: number of samples
  int     n,           // in: number of components per sample
  float   StoppingCrit // in: stopping threshold
);

namespace otb
{

/** \class CuSVMListSampleEstimator
 * \brief Trains a two-class SVM on a list of samples by delegating the
 *        optimization to the external Cuda routine SVMTrain().
 *
 * Usage: set the samples (SetSamples), the labels (SetLabels), the
 * per-component bounds used for normalization (SetMinimum / SetMaximum) and,
 * optionally, C, the kernel width and the stopping threshold; then call
 * Compute().
 *
 * Every sample component is rescaled with (value - min) / (max - min) before
 * being handed to the solver. A component whose minimum equals its maximum is
 * mapped to 0 instead of dividing by zero.
 *
 * Label convention applied by Compute(): a label whose first component equals
 * 1 is passed to the solver as -1, any other label as +1.
 *
 * Results after Compute():
 *  - GetNumberOfSupportVectors(): number of samples with a non-zero
 *    coefficient.
 *  - GetAlphas(): one value per support vector (solver coefficient multiplied
 *    by the -1/+1 label). Support vectors labelled +1 come first, those
 *    labelled -1 follow.
 *  - GetSupportVectors(): normalized support vectors, component j of support
 *    vector k being stored at index k + j * NumberOfSupportVectors.
 *  - GetBeta(): the scalar written by SVMTrain() through its beta argument.
 *
 * The pointers returned by GetAlphas() and GetSupportVectors() are owned by
 * this object; they stay valid until the next call to Compute() or until the
 * object is destroyed. They are NULL when no support vector is stored.
 *
 * \tparam TInputListSample    list sample type holding the training samples
 * \tparam TTrainingListSample list sample type holding the training labels
 *
 * \ingroup ClassificationFilters
 */
template <class TInputListSample, class TTrainingListSample>
class ITK_EXPORT CuSVMListSampleEstimator
: public itk::Object
{
public:
  /** Standard class typedefs. */
  typedef CuSVMListSampleEstimator                               Self;
  typedef itk::Object                                            Superclass;
  typedef itk::SmartPointer<Self>                                Pointer;
  typedef itk::SmartPointer<const Self>                          ConstPointer;

  /** Samples list typedef */
  typedef TInputListSample                                       ListSampleType;
  typedef typename ListSampleType::MeasurementVectorType         SampleType;
  typedef typename ListSampleType::Pointer                       ListSamplePointerType;
  typedef typename ListSampleType::MeasurementVectorType         MeasurementVectorType;

  /** Labels list typedef */
  typedef TTrainingListSample                                    TrainingListSampleType;
  typedef typename TrainingListSampleType::Pointer               TrainingListSamplePointerType;
  typedef typename TrainingListSampleType::MeasurementVectorType LabelMeasurementType;

  /** Method for creation through the object factory. */
  itkNewMacro(Self);

  /** Set the training samples and their labels (both lists must have the
   *  same size when Compute() is called). */
  itkSetObjectMacro(Samples,ListSampleType);
  itkSetObjectMacro(Labels,TrainingListSampleType);

  /** Set the per-component minimum used to normalize the samples. */
  void SetMinimum(const SampleType & min)
  {
    m_Minimum = min;
  }

  /** Set the per-component maximum used to normalize the samples. */
  void SetMaximum(const SampleType & max)
  {
    m_Maximum = max;
  }

  /** Coefficients of the support vectors (NumberOfSupportVectors values),
   *  or NULL when none is stored. */
  float * GetAlphas()
  {
    return m_Alphas.empty() ? static_cast<float *>(NULL) : &m_Alphas[0];
  }

  /** Normalized support vectors, component j of support vector k at index
   *  k + j * NumberOfSupportVectors, or NULL when none is stored. */
  float * GetSupportVectors()
  {
    return m_SupportVectors.empty() ? static_cast<float *>(NULL) : &m_SupportVectors[0];
  }

  /** Scalar written by the solver through its beta argument. */
  float GetBeta()
  {
    return m_Beta;
  }

  // The kernel width is stored and forwarded to the solver as a float, so
  // the setter takes a float as well: an unsigned int argument silently
  // truncated fractional widths. Integer arguments still convert implicitly.
  itkSetMacro(KernelWidth,float);
  itkSetMacro(C,float);
  itkSetMacro(StoppingThreshold,float);

  itkGetMacro(NumberOfSupportVectors,unsigned int);

  /** Solve the SVM problem for the current samples and labels.
   *
   *  Throws an itk::ExceptionObject (through itkExceptionMacro) when the
   *  samples or labels are missing, empty or of different sizes, when the
   *  samples have no component, when the minimum / maximum vectors are
   *  shorter than the samples, or when the problem dimensions do not fit
   *  the int arguments of SVMTrain(). Previously computed results are left
   *  untouched if anything fails before the new results are complete.
   */
  void Compute()
  {
    // Both inputs are mandatory
    if(m_Samples.IsNull())
      {
      itkExceptionMacro(<<"The training samples list has not been set");
      }
    if(m_Labels.IsNull())
      {
      itkExceptionMacro(<<"The labels list has not been set");
      }

    // Retrieve the number of measurements
    const std::size_t numberOfMeasurements = static_cast<std::size_t>(m_Samples->Size());

    // Check for some measurements
    if(numberOfMeasurements == 0)
      {
      itkExceptionMacro(<<"The training samples list is empty");
      }

    // Check for correct number of labels
    if(static_cast<std::size_t>(m_Labels->Size()) != numberOfMeasurements)
      {
      itkExceptionMacro(<<"Training samples list and labels list must have the same size.");
      }

    // Retrieve the size of the samples
    const unsigned int measurementSize = m_Samples->Begin().GetMeasurementVector().Size();

    if(measurementSize == 0)
      {
      itkExceptionMacro(<<"The training samples have no component");
      }

    // The normalization bounds are indexed by component below
    if(m_Minimum.Size() < measurementSize || m_Maximum.Size() < measurementSize)
      {
      itkExceptionMacro(<<"Minimum and maximum must have at least as many components as the samples ("
                        <<measurementSize<<").");
      }

    // SVMTrain() receives both dimensions as int
    const std::size_t maxInt = static_cast<std::size_t>(std::numeric_limits<int>::max());
    if(numberOfMeasurements > maxInt || static_cast<std::size_t>(measurementSize) > maxInt)
      {
      itkExceptionMacro(<<"Problem dimensions exceed what the Cuda solver interface can represent.");
      }

    // Per-component offset and range, computed once and in floating point so
    // that integer component types are not subject to integer division.
    std::vector<double> offset(measurementSize);
    std::vector<double> range(measurementSize);
    for(unsigned int i = 0; i < measurementSize; ++i)
      {
      offset[i] = static_cast<double>(m_Minimum[i]);
      range[i]  = static_cast<double>(m_Maximum[i]) - offset[i];
      }

    // Build inputs for Cuda (released automatically, even on exception)
    std::vector<float> x(numberOfMeasurements * measurementSize);
    std::vector<float> y(numberOfMeasurements);
    std::vector<float> alpha(numberOfMeasurements, 0.f);

    // Copy data
    typename ListSampleType::ConstIterator lsIt = m_Samples->Begin();
    typename TrainingListSampleType::ConstIterator labIt = m_Labels->Begin();

    std::size_t sampleId = 0;

    while(sampleId < numberOfMeasurements
          && lsIt != m_Samples->End() && labIt != m_Labels->End())
      {
      for(unsigned int i = 0; i < measurementSize; ++i)
        {
        // A component with an empty range is mapped to 0 rather than
        // producing a division by zero.
        double normalized = 0.;
        if(range[i] != 0.)
          {
          normalized = (static_cast<double>(lsIt.GetMeasurementVector()[i]) - offset[i]) / range[i];
          }
        x[sampleId + i * numberOfMeasurements] = static_cast<float>(normalized);
        }

      // Label 1 is handed to the solver as -1, any other label as +1
      if(labIt.GetMeasurementVector()[0] == 1)
        {
        y[sampleId] = -1;
        }
      else
        {
        y[sampleId] = 1;
        }

      ++sampleId;
      ++lsIt;
      ++labIt;
      }

    // Trigger cuda SVM optimizer
    SVMTrain(&alpha[0],&m_Beta,&y[0],&x[0],m_C,m_KernelWidth,
             static_cast<int>(numberOfMeasurements),static_cast<int>(measurementSize),
             m_StoppingThreshold);

    // Compute number of support vectors and number of positive vectors,
    // folding the label sign into the coefficient on the way
    std::size_t numSVs = 0;
    std::size_t numPosSVs = 0;

    for(std::size_t i = 0; i < numberOfMeasurements; ++i)
      {
      if(alpha[i]!=0)
        {
        ++numSVs;
        alpha[i]*=y[i];

        if(y[i]>0)
          {
          ++numPosSVs;
          }
        }
      }

    // Build alphas and support vectors in temporaries: the previous results
    // are only replaced once the new ones are complete.
    std::vector<float> newAlphas(numSVs);
    std::vector<float> newSupportVectors(numSVs * measurementSize);

    // Positive support vectors fill [0, numPosSVs), negative ones follow
    std::size_t posSvIndex = 0;
    std::size_t negSvIndex = numPosSVs;

    for(std::size_t i = 0; i < numberOfMeasurements; ++i)
      {
      if(alpha[i]!=0)
        {
        std::size_t & svIndex = (y[i]>0) ? posSvIndex : negSvIndex;

        newAlphas[svIndex] = alpha[i];

        for(unsigned int j = 0; j < measurementSize; ++j)
          {
          newSupportVectors[svIndex + j * numSVs] = x[i + j * numberOfMeasurements];
          }
        ++svIndex;
        }
      }

    // Commit the new results (swap cannot throw)
    m_Alphas.swap(newAlphas);
    m_SupportVectors.swap(newSupportVectors);
    m_NumberOfSupportVectors = static_cast<unsigned int>(numSVs);

    std::cout<<"Number of support vectors: "<<m_NumberOfSupportVectors<<std::endl;
  }

protected:
  /** Constructor */
  CuSVMListSampleEstimator() : m_C(1.0),
                               m_KernelWidth(1.0),
                               m_StoppingThreshold(0.001),
                               m_Beta(0),
                               m_NumberOfSupportVectors(0)
    {}

  /** Destructor (the result buffers release themselves) */
  virtual ~CuSVMListSampleEstimator()
  {
  }

  /** PrintSelf method */
  virtual void PrintSelf(std::ostream& os, itk::Indent indent) const
  {
    // Call superclass implementation
    Superclass::PrintSelf(os,indent);

    os<<indent<<"C: "<<m_C<<std::endl;
    os<<indent<<"KernelWidth: "<<m_KernelWidth<<std::endl;
    os<<indent<<"StoppingThreshold: "<<m_StoppingThreshold<<std::endl;
    os<<indent<<"NumberOfSupportVectors: "<<m_NumberOfSupportVectors<<std::endl;
    os<<indent<<"Beta: "<<m_Beta<<std::endl;
  }

private:
  CuSVMListSampleEstimator(const Self &); //purposely not implemented
  void operator =(const Self&); //purposely not implemented

  // Training samples
  ListSamplePointerType m_Samples;

  // Labels
  TrainingListSamplePointerType m_Labels;

  // Regularization parameter
  float m_C;

  // Kernel width
  float m_KernelWidth;

  // Stopping Threshold
  float m_StoppingThreshold;

  // Support vectors (component j of support vector k at k + j * number of SVs)
  std::vector<float> m_SupportVectors;

  // Alphas (one per support vector)
  std::vector<float> m_Alphas;

  // Beta
  float m_Beta;

  // Number of SVs
  unsigned int m_NumberOfSupportVectors;

  // Per-component bounds used to normalize the samples
  SampleType m_Minimum;
  SampleType m_Maximum;

}; // class CuSVMListSampleEstimator

} // namespace otb

#endif
