/*
 * Copyright (C) 2005-2024 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef otbGenericInterpolateImageFunction_hxx
#define otbGenericInterpolateImageFunction_hxx
#include "otbGenericInterpolateImageFunction.h"
#include "vnl/vnl_math.h"

namespace otb
{

/** Constructor.
 * Starts with no tables allocated, weight normalization disabled, and a
 * kernel radius of 1 (which also sets m_WindowSize through SetRadius()).
 * Initialize() must still be called before evaluating. */
template <class TInputImage, class TFunction, class TBoundaryCondition, class TCoordRep>
GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::GenericInterpolateImageFunction()
{
  // Give every member a defined value before anything can read it;
  // m_OffsetTableSize in particular was previously left indeterminate.
  m_OffsetTable             = nullptr;
  m_WeightOffsetTable       = nullptr;
  m_OffsetTableSize         = 0;
  m_TablesHaveBeenGenerated = false;
  m_NormalizeWeight         = false;

  // SetRadius() configures the kernel function and sets m_WindowSize (2 * radius).
  this->SetRadius(1);
}

/** Destructor.
 * Frees the raw offset / weight-offset arrays allocated by InitializeTables(). */
template <class TInputImage, class TFunction, class TBoundaryCondition, class TCoordRep>
GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::~GenericInterpolateImageFunction()
{
  ResetOffsetTable();
}

/** Release the offset table and every weight-offset table.
 * Afterwards both pointers are null, the recorded table size is 0, and the
 * tables are flagged as not generated. EvaluateAtContinuousIndex() will
 * therefore throw instead of dereferencing freed memory until Initialize()
 * is called again. */
template <class TInputImage, class TFunction, class TBoundaryCondition, class TCoordRep>
void GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::ResetOffsetTable()
{
  // delete[] on a null pointer is a no-op, so no guard is needed here.
  delete[] m_OffsetTable;
  m_OffsetTable = nullptr;

  // Clear the weight tables: one row per offset-table entry.
  if (m_WeightOffsetTable != nullptr)
  {
    for (unsigned int i = 0; i < m_OffsetTableSize; ++i)
    {
      delete[] m_WeightOffsetTable[i];
    }
    delete[] m_WeightOffsetTable;
    m_WeightOffsetTable = nullptr;
  }

  // Keep the bookkeeping consistent with the (now empty) tables.
  m_OffsetTableSize         = 0;
  m_TablesHaveBeenGenerated = false;
}

/** Set the kernel radius.
 * Forwards the radius to the kernel function and sets the per-dimension
 * window to 2 * rad samples. Modified() is called, which invalidates the
 * precomputed tables, so Initialize() is required afterwards. */
template <class TInputImage, class TFunction, class TBoundaryCondition, class TCoordRep>
void GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::SetRadius(unsigned int rad)
{
  // The window spans rad samples on each side of the interpolated point.
  m_WindowSize = 2 * rad;
  this->GetFunction().SetRadius(rad);
  this->Modified();
}

/** Update the modification time, then mark the precomputed tables as stale.
 * EvaluateAtContinuousIndex() refuses to run until Initialize() regenerates them. */
template <class TInputImage, class TFunction, class TBoundaryCondition, class TCoordRep>
void GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::Modified() const
{
  Superclass::Modified();
  this->m_TablesHaveBeenGenerated = false;
}

/** Allocate the offset and weight-offset tables.
 * There is one entry per sample of an m_WindowSize^ImageDimension window.
 * Each weight-offset row holds ImageDimension indices. The contents are
 * filled later by FillWeightOffsetTable(). */
template <class TInputImage, class TFunction, class TBoundaryCondition, class TCoordRep>
void GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::InitializeTables()
{
  // Release any previous allocation first; this must happen before
  // m_OffsetTableSize is recomputed, because the release loop relies on it.
  this->ResetOffsetTable();

  // Compute the offset table size: m_WindowSize samples per dimension.
  m_OffsetTableSize = 1;
  for (unsigned dim = 0; dim < ImageDimension; ++dim)
  {
    m_OffsetTableSize *= m_WindowSize;
  }

  // Allocate the offset table (value-initialized to 0).
  m_OffsetTable = new unsigned int[m_OffsetTableSize]();

  // Allocate the weight tables. The row pointers are value-initialized to
  // nullptr, so that if a row allocation below throws, ResetOffsetTable()
  // (e.g. from the destructor) only ever deletes valid or null pointers.
  m_WeightOffsetTable = new unsigned int*[m_OffsetTableSize]();
  for (unsigned int i = 0; i < m_OffsetTableSize; ++i)
  {
    m_WeightOffsetTable[i] = new unsigned int[ImageDimension]();
  }
}

/** Fill the offset and weight-offset tables.
 * A neighborhood of radius r = GetRadius().front() has offsets -r..r per
 * dimension. Every offset that reaches -r in any dimension is skipped; the
 * remaining ones (-r+1..r) are recorded:
 *  - m_OffsetTable[k] gets the neighborhood position of the k-th kept offset;
 *  - m_WeightOffsetTable[k][d] gets offset[d] + r - 1, i.e. the index into
 *    the per-dimension weight array (0 .. 2r-1).
 * Throws if no input image has been set. */
template <class TInputImage, class TFunction, class TBoundaryCondition, class TCoordRep>
void GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::FillWeightOffsetTable()
{
  if (this->GetInputImage() == nullptr)
  {
    itkExceptionMacro(<< "An input has to be set");
  }

  const auto rad = this->GetRadius().front();

  SizeType neighborhoodRadius;
  neighborhoodRadius.Fill(rad);
  IteratorType neighborhood(neighborhoodRadius, this->GetInputImage(), this->GetInputImage()->GetBufferedRegion());

  // Offsets equal to this value in any dimension are excluded from the tables.
  const int skippedOffset = -static_cast<int>(rad);

  unsigned int tableEntry = 0;
  for (unsigned int pos = 0; pos < neighborhood.Size(); ++pos)
  {
    const typename IteratorType::OffsetType offset = neighborhood.GetOffset(pos);

    bool skip = false;
    for (unsigned int d = 0; d < ImageDimension && !skip; ++d)
    {
      skip = (offset[d] == skippedOffset);
    }
    if (skip)
    {
      continue;
    }

    m_OffsetTable[tableEntry] = pos;
    for (unsigned int d = 0; d < ImageDimension; ++d)
    {
      m_WeightOffsetTable[tableEntry][d] = offset[d] + rad - 1;
    }
    ++tableEntry;
  }
}

/** (Re)build the interpolation tables; must be called explicitly before evaluating.
 * Frees any existing tables, allocates new ones, and fills them from the
 * current input image. Only on success are the tables flagged as generated;
 * if FillWeightOffsetTable() throws (no input image set), the flag stays
 * false, so EvaluateAtContinuousIndex() cannot read unfilled tables. */
template <class TInputImage, class TFunction, class TBoundaryCondition, class TCoordRep>
void GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::Initialize()
{
  // Invalidate first: the existing tables are about to be released.
  m_TablesHaveBeenGenerated = false;
  // Delete existing tables
  this->ResetOffsetTable();
  // Tables allocation
  this->InitializeTables();
  // Fill the offset and weight tables
  this->FillWeightOffsetTable();
  m_TablesHaveBeenGenerated = true;
}

/** Evaluate the interpolated value at a continuous index.
 * The pixel at baseIndex + o (baseIndex = floor(index), o in -r+1..r per
 * dimension) is weighted by the product over dimensions of
 * m_Function(distance - o), where distance = index - baseIndex.
 * When m_NormalizeWeight is set, each per-dimension weight set is rescaled
 * to sum to 1. This is skipped if the set already sums to exactly 1, and
 * also if it sums to 0, to avoid a division by zero.
 * Throws if Initialize() has not been called since the last modification. */
template <class TInputImage, class TFunction, class TBoundaryCondition, class TCoordRep>
typename GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::OutputType
GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::EvaluateAtContinuousIndex(const ContinuousIndexType& index) const
{
  if (!m_TablesHaveBeenGenerated)
  {
    itkExceptionMacro(<< "The Interpolation functor need to be explicitly intanciated with the method Initialize()");
  }

  // Split the continuous index into its integer floor and the fractional distance.
  IndexType baseIndex;
  double    distance[ImageDimension];
  for (unsigned int dim = 0; dim < ImageDimension; ++dim)
  {
    // Equivalent to std::floor(index[dim]) without the libm call: truncation
    // toward zero is one too high for negative non-integral values.
    long tIndex = static_cast<long>(index[dim]);
    if (index[dim] < 0.0 && static_cast<double>(tIndex) != index[dim])
    {
      --tIndex;
    }
    baseIndex[dim] = tIndex;
    distance[dim]  = index[dim] - static_cast<double>(baseIndex[dim]);
  }

  // Position the neighborhood at the index of interest
  const auto radius = this->GetRadius().front();
  SizeType   neighborhoodRadius;
  neighborhoodRadius.Fill(radius);
  IteratorType nit(neighborhoodRadius, this->GetInputImage(), this->GetInputImage()->GetBufferedRegion());
  nit.SetLocation(baseIndex);

  // xWeight[dim][i] is the kernel weight for window sample i along dim. It is
  // sized by m_WindowSize because that is the number of samples written below.
  std::vector<std::vector<double>> xWeight(ImageDimension, std::vector<double>(m_WindowSize, 0.0));
  for (unsigned int dim = 0; dim < ImageDimension; ++dim)
  {
    // Kernel argument for sample i is distance + radius - 1 - i, i.e. it runs
    // through (dist + rad - 1, ..., dist - rad).
    double x = distance[dim] + radius;
    for (unsigned int i = 0; i < m_WindowSize; ++i)
    {
      x -= 1.0;
      xWeight[dim][i] = m_Function(x);
    }
  }

  if (m_NormalizeWeight)
  {
    for (auto& weights : xWeight)
    {
      double sum = 0.;
      for (const double w : weights)
      {
        sum += w;
      }
      // Leave already-normalized sets untouched, and never divide by zero:
      // that would turn every weight into inf/NaN.
      if (sum != 1. && sum != 0.)
      {
        for (double& w : weights)
        {
          w /= sum;
        }
      }
    }
  }

  // Accumulate the weighted neighbors; each neighbor's weight is the product
  // of its per-dimension weights, looked up through the precomputed tables.
  RealType xPixelValue;
  itk::NumericTraits<RealType>::SetLength(xPixelValue, this->GetInputImage()->GetNumberOfComponentsPerPixel());
  xPixelValue = static_cast<RealType>(0.0);

  for (unsigned int j = 0; j < m_OffsetTableSize; ++j)
  {
    RealType xVal = nit.GetPixel(m_OffsetTable[j]);
    for (unsigned int dim = 0; dim < ImageDimension; ++dim)
    {
      xVal *= xWeight[dim][m_WeightOffsetTable[j][dim]];
    }
    xPixelValue += xVal;
  }

  return static_cast<OutputType>(xPixelValue);
}

/** Print the superclass state followed by this interpolator's own settings. */
template <class TInputImage, class TFunction, class TBoundaryCondition, class TCoordRep>
void GenericInterpolateImageFunction<TInputImage, TFunction, TBoundaryCondition, TCoordRep>::PrintSelf(std::ostream& os, itk::Indent indent) const
{
  Superclass::PrintSelf(os, indent);
  os << indent << "WindowSize: " << m_WindowSize << std::endl;
  os << indent << "NormalizeWeight: " << (m_NormalizeWeight ? "On" : "Off") << std::endl;
  os << indent << "TablesHaveBeenGenerated: " << (m_TablesHaveBeenGenerated ? "true" : "false") << std::endl;
}

} // namespace otb

#endif
