#include <cmath>
#include <cstdlib>
#include <ctime>
#include <exception>
#include <iomanip>
#include <iostream>
#include <string>

#include "otbVectorImage.h"
#include "otbImage.h"
#include "otbImageFileReader.h"
#include "otbStreamingImageFileWriter.h"
#include "otbSVMSampleListModelEstimator.h"
#include "otbSVMImageClassificationFilter.h"
#include "itkImageRegionSplitter.h"
#include "otbStreamingTraits.h"
#include "itkImageRegionConstIterator.h"
#include "itkVariableSizeMatrix.h"
#include "otbCommandLineArgumentParser.h"
#include "itkListSample.h"
#include "itkTimeProbe.h"
#include "otbVectorRescaleIntensityImageFilter.h"

// Example usage: ./svmCPU -in ../data/W_Spot.tif -tm ../data/GloucesterROIs.tif -out testCPUSpot.tif -sl 20 -c 1.0

int main(int argc, char * argv[])
{

  // Parse command line parameters
  typedef otb::CommandLineArgumentParser ParserType;
  ParserType::Pointer parser = ParserType::New();

  parser->SetProgramDescription("Supervised SVM image classification with random training and validation set");
  parser->AddInputImage();
  parser->AddOutputImage();
  parser->AddOption("--TrainingMap","Labeled training map","-tm",1,true);
  parser->AddOption("--StreamingNumberOfLines","Number of lined for each streaming block","-sl",1,true);
  parser->AddOption("--C","C parameter","-c",1,true);
    


  typedef otb::CommandLineArgumentParseResult ParserResultType;
  ParserResultType::Pointer  parseResult = ParserResultType::New();

  try
  {
    parser->ParseCommandLine(argc,argv,parseResult);
  }
  catch ( itk::ExceptionObject & err )
  {

    std::string descriptionException = err.GetDescription();
    if (descriptionException.find("ParseCommandLine(): Help Parser") != std::string::npos)
    {
      return EXIT_SUCCESS;
    }
    if (descriptionException.find("ParseCommandLine(): Version Parser") != std::string::npos)
    {
      return EXIT_SUCCESS;
    }
    return EXIT_FAILURE;
  }

  // initiating random number generation
  srand(time(NULL));

  std::string infname = parseResult->GetInputImage();
  std::string labelfname = parseResult->GetParameterString("--TrainingMap",0);
  std::string outfname = parseResult->GetOutputImage();
  const unsigned int nbLinesForStreaming = parseResult->GetParameterUInt("--StreamingNumberOfLines");
  const float cparam = parseResult->GetParameterFloat("--C");


  typedef float                  PixelType;
  typedef unsigned short LabeledPixelType;

  typedef otb::VectorImage<PixelType,2> ImageType;
  typedef otb::Image<LabeledPixelType,2> LabeledImageType;
  typedef otb::ImageFileReader<ImageType> ImageReaderType;
  typedef otb::ImageFileReader<LabeledImageType> LabeledImageReaderType;
  typedef otb::StreamingImageFileWriter<LabeledImageType> WriterType;

  typedef otb::StreamingTraits<ImageType> StreamingTraitsType;
  typedef itk::ImageRegionSplitter<2>  SplitterType;
  typedef ImageType::RegionType RegionType;

  typedef itk::ImageRegionConstIterator<ImageType> IteratorType;
  typedef itk::ImageRegionConstIterator<LabeledImageType> LabeledIteratorType;

  typedef itk::VariableLengthVector<PixelType>                                    SampleType;
  typedef itk::Statistics::ListSample<SampleType>                                 ListSampleType;
  typedef itk::FixedArray<LabeledPixelType,1>                                     TrainingSampleType;
  typedef itk::Statistics::ListSample<TrainingSampleType>                         TrainingListSampleType;
  typedef otb::SVMSampleListModelEstimator<ListSampleType,TrainingListSampleType> EstimatorType;
  typedef otb::SVMImageClassificationFilter<ImageType,LabeledImageType>         ClassificationFilterType;
  typedef otb::VectorRescaleIntensityImageFilter<ImageType,ImageType>             RescalerType;

  typedef std::map<LabeledPixelType,unsigned int>                                 ClassesMapType;

  typedef itk::VariableSizeMatrix<double>                                         ConfusionMatrixType;

  ImageReaderType::Pointer reader = ImageReaderType::New();
  LabeledImageReaderType::Pointer labeledReader = LabeledImageReaderType::New();

  reader->SetFileName(infname);
  labeledReader->SetFileName(labelfname);

  /*******************************************/
  /*           Sampling data                 */
  /*******************************************/

  std::cout<<"-- SAMPLING DATA --"<<std::endl;
  std::cout<<std::endl;

  // Update input images information
  reader->GenerateOutputInformation();
  labeledReader->GenerateOutputInformation();

  if (reader->GetOutput()->GetLargestPossibleRegion()
      != labeledReader->GetOutput()->GetLargestPossibleRegion()
     )
  {
    std::cerr<<"Label image, mask image and input image have different sizes."<<std::endl;
    return EXIT_FAILURE;
  }

  RegionType largestRegion = reader->GetOutput()->GetLargestPossibleRegion();

  // Setting up local streaming capabilities
  SplitterType::Pointer splitter = SplitterType::New();
  unsigned int numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(reader->GetOutput(),
                                         largestRegion,
                                         splitter,
                                         otb::SET_BUFFER_NUMBER_OF_LINES,
                                         0,0,nbLinesForStreaming);

  std::cout<<"The images will be streamed into "<<numberOfStreamDivisions<<" parts."<<std::endl;

  // Training sample lists
  ListSampleType::Pointer sampleList = ListSampleType::New();
  TrainingListSampleType::Pointer trainingSampleList = TrainingListSampleType::New();

  // Sample dimension
  unsigned int sampleSize = reader->GetOutput()->GetNumberOfComponentsPerPixel();

  SampleType min(sampleSize),max(sampleSize);
  bool firstTime = true;

  std::cout<<"The following sample size will be used: "<<sampleSize<<std::endl;
  std::cout<<std::endl;
  // local streaming variables
  unsigned int piece = 0;
  RegionType streamingRegion;

  // Information on the different classes

  ClassesMapType classesMap;
  ClassesMapType valClassesMap;
  ClassesMapType totalMap;
  ClassesMapType indexMap;

  unsigned short currentIndex = 0;

  // For each streamed part
  for (piece = 0;
       piece < numberOfStreamDivisions;
       piece++)
  {
    streamingRegion = splitter->GetSplit(piece,numberOfStreamDivisions,largestRegion);

//    std::cout<<"Processing region: "<<streamingRegion<<std::endl;

    reader->GetOutput()->SetRequestedRegion(streamingRegion);
    reader->GetOutput()->PropagateRequestedRegion();
    reader->GetOutput()->UpdateOutputData();

    labeledReader->GetOutput()->SetRequestedRegion(streamingRegion);
    labeledReader->GetOutput()->PropagateRequestedRegion();
    labeledReader->GetOutput()->UpdateOutputData();


    IteratorType it(reader->GetOutput(),streamingRegion);
    LabeledIteratorType labeledIt(labeledReader->GetOutput(),streamingRegion);

    it.GoToBegin();
    labeledIt.GoToBegin();
    
    if(firstTime)
      {
      min = it.Get();
      max = it.Get();
      firstTime = false;
      }

    unsigned int totalSamples = 0;
    unsigned int totalValidationSamples = 0;

    // Loop on the image
    while (!it.IsAtEnd()&&!labeledIt.IsAtEnd())
    {
      // If the current pixel is labeled
      if (labeledIt.Get() !=0)
      {
          SampleType newSample(sampleSize);
          TrainingSampleType newTrainingSample;

          // build the sample
          newSample.Fill(0);
          for (unsigned int i = 0;i<sampleSize;++i)
          {
            newSample[i]=it.Get()[i];

	    if(min[i]>it.Get()[i])
	      {
	      min[i]=it.Get()[i];
	      }

	    if(max[i]<it.Get()[i])
	      {
	      max[i]=it.Get()[i];
	      }
	    
          }

          // build the training sample
          newTrainingSample[0]=labeledIt.Get();

          // Update the the sample lists
          sampleList->PushBack(newSample);
          trainingSampleList->PushBack(newTrainingSample);
          ++totalSamples;
        }
      ++it;
      ++labeledIt;
    }

  }



  // Normalization
  ListSampleType::Pointer nSampleList = ListSampleType::New();

  for(ListSampleType::Iterator lit = sampleList->Begin();lit!=sampleList->End();++lit)
    {
    SampleType sample = lit.GetMeasurementVector();

    for(unsigned int i = 0; i < sampleSize; ++i)
      {
      sample[i] = (sample[i]-min[i])/(max[i]-min[i]);
      }
    nSampleList->PushBack(sample);

    }

  std::cout<<nSampleList->Size()<<" samples added to the training set."<<std::endl;
  std::cout<<std::endl;

  /*******************************************/
  /*           Learning                      */
  /*******************************************/

  std::cout<<"-- LEARNING --"<<std::endl;
  std::cout<<std::endl;

  EstimatorType::Pointer estimator = EstimatorType::New();

  estimator->SetInputSampleList(nSampleList);
  estimator->SetTrainingSampleList(trainingSampleList);
  estimator->ParametersOptimizationOff();
  estimator->SetKernelType(RBF);
  estimator->SetC(cparam);
  
  timespec startClock, endClock;
  time_t startTime = time(NULL);
  clock_t startNdvi = clock();
  clock_gettime(CLOCK_REALTIME, &startClock);

  estimator->Update();

   clock_gettime(CLOCK_REALTIME, &endClock);
  clock_t endNdvi = clock();
  time_t endTime = time(NULL);
  std::cout<<std::endl;
  std::cout << "Time1: " << std::setprecision(15)
            << (endNdvi-startNdvi)/((float) CLOCKS_PER_SEC) << std::endl;
  std::cout << "Time2: " << std::setprecision(15)
            << (endTime-startTime) << std::endl;
  std::cout << "Time3: " << std::setprecision(15)
            << (endClock.tv_sec-startClock.tv_sec) + (endClock.tv_nsec-startClock.tv_nsec)/1000000000. << std::endl;

  std::cout<<"-- CLASSIFICATION --"<<std::endl;
  std::cout<<std::endl;

 RescalerType::Pointer rescaler = RescalerType::New();
  rescaler->SetInput(reader->GetOutput());
  rescaler->AutomaticInputMinMaxComputationOff();
  rescaler->SetInputMinimum(min);
  rescaler->SetInputMaximum(max);

    SampleType outMin = min;
  outMin.Fill(0);
  
  SampleType outMax = max;
  outMax.Fill(1);

  rescaler->SetOutputMinimum(outMin);
  rescaler->SetOutputMaximum(outMax);





  ClassificationFilterType::Pointer classifier = ClassificationFilterType::New();
  classifier->SetInput(rescaler->GetOutput());
  classifier->SetModel(estimator->GetModel());

  WriterType::Pointer writer = WriterType::New();
  writer->SetInput(classifier->GetOutput());
  writer->SetBufferNumberOfLinesDivisions(nbLinesForStreaming);
  writer->SetFileName(outfname);

  timespec startClock2, endClock2;
  time_t startTime2 = time(NULL);
  clock_t startNdvi2 = clock();
  clock_gettime(CLOCK_REALTIME, &startClock2);

  writer->Update();

  clock_gettime(CLOCK_REALTIME, &endClock2);
  clock_t endNdvi2 = clock();
  time_t endTime2 = time(NULL);

  std::cout<<std::endl;

  std::cout << "Time1: " << std::setprecision(15)
            << (endNdvi2-startNdvi2)/((float) CLOCKS_PER_SEC) << std::endl;
  std::cout << "Time2: " << std::setprecision(15)
            << (endTime2-startTime2) << std::endl;
  std::cout << "Time3: " << std::setprecision(15)
            << (endClock2.tv_sec-startClock2.tv_sec) + (endClock2.tv_nsec-startClock2.tv_nsec)/1000000000. << std::endl;


  return EXIT_SUCCESS;
}
