/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

#include "process_int.h"
#include "cdo_wtime.h"
#include "remap.h"
#include "remap_store_link.h"
#include "cdo_options.h"
#include "progress.h"
#include "cimdOmp.h"

// bilinear interpolation

// Shifts a longitude difference by one PI2 when its magnitude exceeds 3 * PIH.
// PIH and PI2 are defined outside this file (presumably pi/2 and 2*pi - TODO confirm),
// which would fold differences that cross the periodic longitude boundary back by one turn.
// dphi is modified in place.
static inline void
limit_dphi_bounds(double &dphi)
{
  const double limit = 3.0 * PIH;

  if (dphi > limit)
    {
      dphi -= PI2;
    }
  if (dphi < -limit)
    {
      dphi += PI2;
    }
}

// Finds the local bilinear coordinates (i, j) of llpoint inside the quadrilateral spanned by the
// four source corners (src_lons[n], src_lats[n]), n = 0..3.
//
// The bilinear mapping is
//   lat(i,j) = lat0 + dlat1*i + dlat2*j + dlat3*i*j   (and likewise for the longitude),
// where corner 1 lies along i, corner 3 along j and corner 2 is the opposite corner.
// It is inverted with a Newton iteration starting from the cell centre (0.5, 0.5).
//
// llpoint   point to locate (its .lon and .lat members are read)
// src_lons  longitudes of the four corners
// src_lats  latitudes  of the four corners
// ig, jg    receive the last estimate of i and j (written even when the iteration did not converge)
//
// Returns true if both corrections dropped below 1.e-10 within remap_max_iter iterations.
bool
remap_find_weights(const LonLatPoint &llpoint, const double (&src_lons)[4], const double (&src_lats)[4], double *ig, double *jg)
{
  extern long remap_max_iter;
  constexpr double tolerance = 1.e-10;  // Convergence criterion

  // Latitude differences between the corners
  const auto dlat1 = src_lats[1] - src_lats[0];
  const auto dlat2 = src_lats[3] - src_lats[0];
  const auto dlat3 = src_lats[2] - src_lats[1] - dlat2;

  // Longitude differences between the corners, each one limited before it is combined
  auto dlon1 = src_lons[1] - src_lons[0];
  auto dlon2 = src_lons[3] - src_lons[0];
  auto dlon3 = src_lons[2] - src_lons[1];

  limit_dphi_bounds(dlon1);
  limit_dphi_bounds(dlon2);
  limit_dphi_bounds(dlon3);

  dlon3 -= dlon2;

  // Longitude offset of the point relative to corner 0; it does not depend on the current estimate
  auto dlonPoint = llpoint.lon - src_lons[0];
  limit_dphi_bounds(dlonPoint);

  // Current estimate of the bilinear coordinates
  double icoord = 0.5;
  double jcoord = 0.5;

  bool converged = false;
  for (long iter = 0; iter < remap_max_iter && !converged; ++iter)
    {
      // Residuals of the bilinear mapping at the current estimate
      const auto resLat = llpoint.lat - src_lats[0] - dlat1 * icoord - dlat2 * jcoord - dlat3 * icoord * jcoord;
      const auto resLon = dlonPoint - dlon1 * icoord - dlon2 * jcoord - dlon3 * icoord * jcoord;

      // Jacobian of the mapping with respect to (i, j)
      const auto jac1 = dlat1 + dlat3 * jcoord;
      const auto jac2 = dlat2 + dlat3 * icoord;
      const auto jac3 = dlon1 + dlon3 * jcoord;
      const auto jac4 = dlon2 + dlon3 * icoord;

      const auto det = jac1 * jac4 - jac2 * jac3;

      // Newton correction (2x2 system solved by Cramer's rule)
      const auto stepI = (resLat * jac4 - resLon * jac2) / det;
      const auto stepJ = (resLon * jac1 - resLat * jac3) / det;

      if (std::fabs(stepI) < tolerance && std::fabs(stepJ) < tolerance)
        {
          // Converged: the estimate is left untouched, exactly as it was before this step
          converged = true;
        }
      else
        {
          icoord += stepI;
          jcoord += stepJ;
        }
    }

  *ig = icoord;
  *jg = jcoord;

  return converged;
}

// Converts the local bilinear coordinates (iw, jw) into the weights of the four corners.
// Corner order: 0 -> (0,0), 1 -> (1,0), 2 -> (1,1), 3 -> (0,1) in (i,j) space.
static void
bilinearSetWeights(double iw, double jw, double (&weights)[4])
{
  const double iwOpposite = 1.0 - iw;
  const double jwOpposite = 1.0 - jw;

  weights[0] = iwOpposite * jwOpposite;
  weights[1] = iw * jwOpposite;
  weights[2] = iw * jw;
  weights[3] = iwOpposite * jw;
}

// Counts how many of the four source points are unmasked (mask value != 0).
// For every masked point the corresponding entry of src_lats is reset to 0,
// so that it contributes nothing when src_lats is later used as a set of weights.
// Returns the number of unmasked points (0..4).
int
num_src_points(const Varray<short> &mask, const size_t (&src_add)[4], double (&src_lats)[4])
{
  int numValid = 0;

  for (int k = 0; k < 4; ++k)
    {
      const bool isMasked = (mask[src_add[k]] == 0);
      if (isMasked)
        {
          src_lats[k] = 0.;
          continue;
        }
      ++numValid;
    }

  return numValid;
}

// Turns the four values in src_lats into weights: each weight is the absolute value
// of its entry divided by the sum of all four absolute values, so the weights sum to 1.
static void
renormalizeWeights(const double (&src_lats)[4], double (&weights)[4])
{
  // Normalisation factor: sum of the absolute values
  double total = 0.0;
  for (const auto value : src_lats) total += std::fabs(value);

  for (unsigned k = 0; k < 4; ++k)
    {
      weights[k] = std::fabs(src_lats[k]) / total;
    }
}

// Reports that the bilinear interpolation failed for some grid points.
// The warning is printed only on the first call, or on every call when Options::cdoVerbose is set.
//
// This function is called from inside the OpenMP parallel loops of this file, so the test and
// update of the function-local flag (and the output itself) are serialised with a named critical
// section. A named region is used so that it cannot clash with unnamed critical regions elsewhere.
static void
bilinearWarning()
{
  static bool lwarn = true;  // true until the warning has been printed once

#ifdef _OPENMP
#pragma omp critical(bilinear_warning)
#endif
  {
    if (Options::cdoVerbose || lwarn)
      {
        lwarn = false;
        cdo_warning("Bilinear interpolation failed for some grid points - used a distance-weighted average instead!");
      }
  }
}

// Computes the weights for a bilinear interpolation from the source grid to the target grid of rsearch.
//
// For every unmasked target cell the four surrounding source points are looked up with remap_search_square():
//   - search_result > 0 and all four source points unmasked: the local coordinates (iw, jw) are determined
//     with remap_find_weights() and converted into four bilinear weights.
//   - search_result < 0, or remap_find_weights() did not converge: the values held in src_lats are
//     normalised and used as weights of an average over the unmasked source points.
//   - otherwise (search_result == 0, e.g. one of the four points is masked): the cell gets no links.
// The per-cell links are finally handed to weight_links_to_remap_links() together with rv.
//
// Aborts (cdo_abort) if the source grid rank is not 2.
void
remap_bilinear_weights(RemapSearch &rsearch, RemapVars &rv)
{
  auto src_grid = rsearch.srcGrid;
  auto tgt_grid = rsearch.tgtGrid;

  if (Options::cdoVerbose) cdo_print("Called %s()", __func__);

  if (src_grid->rank != 2) cdo_abort("Can't do bilinear interpolation when source grid rank != 2");

  // Start time; only taken in verbose mode, where the elapsed time is printed at the end
  auto start = Options::cdoVerbose ? cdo_get_wtime() : 0.0;

  progress::init();

  // Compute mappings from source to target grid

  auto tgt_grid_size = tgt_grid->size;

  // One link container per target cell (presumably sized for 4 links each - TODO confirm in weight_links_alloc)
  std::vector<WeightLinks> weightLinks(tgt_grid_size);
  weight_links_alloc(4, tgt_grid_size, weightLinks);

  // Number of target cells visited so far; shared between the threads and used for the progress display
  auto findex = 0.0;

  // Loop over destination grid

#ifdef _OPENMP
#pragma omp parallel for default(none) schedule(static) shared(findex, rsearch, weightLinks, tgt_grid_size, src_grid, tgt_grid, rv)
#endif
  for (size_t tgt_cell_add = 0; tgt_cell_add < tgt_grid_size; ++tgt_cell_add)
    {
      // The counter is incremented atomically, but only thread 0 reports the progress.
      // NOTE(review): the read of findex in progress::update() is not an atomic read.
#ifdef _OPENMP
#pragma omp atomic
#endif
      findex++;
      if (cdo_omp_get_thread_num() == 0) progress::update(0, 1, findex / tgt_grid_size);

      // Start without links; masked target cells keep it that way
      weightLinks[tgt_cell_add].nlinks = 0;

      if (!tgt_grid->mask[tgt_cell_add]) continue;

      const auto llpoint = remapgrid_get_lonlat(tgt_grid, tgt_cell_add);

      double src_lats[4];  //  latitudes  of four bilinear corners
      double src_lons[4];  //  longitudes of four bilinear corners
      double weights[4];   //  bilinear weights for four corners
      size_t src_add[4];   //  address for the four source points

      // Find nearest square of grid points on source grid
      // (> 0: square found, < 0: use the fallback average below, 0: nothing to do)
      auto search_result = remap_search_square(rsearch, llpoint, src_add, src_lats, src_lons);

      // Check to see if points are mask points: a single masked corner discards the whole square
      if (search_result > 0)
        {
          for (unsigned n = 0; n < 4; ++n)
            if (!src_grid->mask[src_add[n]]) search_result = 0;
        }

      // If point found, find local iw,jw coordinates for weights
      if (search_result > 0)
        {
          tgt_grid->cell_frac[tgt_cell_add] = 1.0;

          double iw = 0.0, jw = 0.0;  // current guess for bilinear coordinate
          if (remap_find_weights(llpoint, src_lons, src_lats, &iw, &jw))
            {
              // Successfully found iw,jw - compute weights
              bilinearSetWeights(iw, jw, weights);
              store_weightlinks(0, 4, src_add, weights, tgt_cell_add, weightLinks);
            }
          else
            {
              // Iteration did not converge: warn and route this cell into the fallback branch below
              bilinearWarning();
              search_result = -1;
            }
        }

      /*
        Search for bilinear failed - use a distance-weighted average instead
        (this is typically near the pole) Distance was stored in src_lats!

        NOTE(review): this branch is also entered when remap_find_weights() did not converge.
        In that case src_lats still holds the values that were passed to remap_find_weights()
        as corner latitudes - confirm that they are meant to be used as weights here.
      */
      if (search_result < 0)
        {
          // num_src_points() zeroes the src_lats entries of masked points, so they get weight 0
          if (num_src_points(src_grid->mask, src_add, src_lats) > 0)
            {
              tgt_grid->cell_frac[tgt_cell_add] = 1.0;
              renormalizeWeights(src_lats, weights);
              store_weightlinks(0, 4, src_add, weights, tgt_cell_add, weightLinks);
            }
        }
    }

  progress::update(0, 1, 1);

  // Hand the per-cell links over together with rv
  weight_links_to_remap_links(0, tgt_grid_size, weightLinks, rv);

  if (Options::cdoVerbose) cdo_print("%s: %.2f seconds", __func__, cdo_get_wtime() - start);
}  // remap_bilinear_weights

// Returns the weighted sum of the four source values addressed by src_add,
// i.e. sum over n = 0..3 of src_array[src_add[n]] * weights[n], accumulated in index order.
template <typename T>
static inline T
bilinearRemap(const Varray<T> &src_array, const double (&weights)[4], const size_t (&src_add)[4])
{
  const auto term0 = src_array[src_add[0]] * weights[0];
  const auto term1 = src_array[src_add[1]] * weights[1];
  const auto term2 = src_array[src_add[2]] * weights[2];
  const auto term3 = src_array[src_add[3]] * weights[3];

  return term0 + term1 + term2 + term3;
}

// Computes and applies the weights for a bilinear interpolation in one pass.
//
// rsearch    search object providing the source grid (srcGrid) and the target grid (tgtGrid)
// src_array  field values on the source grid
// tgt_array  receives the interpolated values on the target grid
// missval    missing value: source points equal to it are treated as masked,
//            and target cells that cannot be interpolated are set to it
//
// For every unmasked target cell the four surrounding source points are looked up with remap_search_square():
//   - search_result > 0 and none of the four source values is missing: the local coordinates (iw, jw) are
//     determined with remap_find_weights() and the bilinear weights are applied to the source values.
//   - search_result < 0, or remap_find_weights() did not converge: the values held in src_lats are
//     normalised and used as weights of an average over the non-missing source points.
//   - otherwise the target cell keeps missval.
//
// Aborts (cdo_abort) if the source grid rank is not 2.
template <typename T>
static void
remap_bilinear(RemapSearch &rsearch, const Varray<T> &src_array, Varray<T> &tgt_array, T missval)
{
  auto src_grid = rsearch.srcGrid;
  auto tgt_grid = rsearch.tgtGrid;

  if (Options::cdoVerbose) cdo_print("Called %s()", __func__);

  if (src_grid->rank != 2) cdo_abort("Can't do bilinear interpolation when source grid rank != 2");

  // Start time; only taken in verbose mode, where the elapsed time is printed at the end
  auto start = Options::cdoVerbose ? cdo_get_wtime() : 0.0;

  progress::init();

  auto tgt_grid_size = tgt_grid->size;
  auto src_grid_size = src_grid->size;

  // Source mask derived from the data itself (not from src_grid->mask):
  // 0 where DBL_IS_EQUAL reports the value as equal to missval, 1 elsewhere
  Varray<short> src_grid_mask(src_grid_size);
#ifdef _OPENMP
#pragma omp parallel for default(none) schedule(static) shared(src_grid_size, src_array, src_grid_mask, missval)
#endif
  for (size_t i = 0; i < src_grid_size; ++i) src_grid_mask[i] = !DBL_IS_EQUAL(src_array[i], missval);

  // Compute mappings from source to target grid

  // Number of target cells visited so far; shared between the threads and used for the progress display
  auto findex = 0.0;

  // Loop over destination grid

#ifdef _OPENMP
#pragma omp parallel for default(none) schedule(static) \
    shared(findex, rsearch, tgt_grid_size, src_grid, tgt_grid, src_array, tgt_array, missval, src_grid_mask)
#endif
  for (size_t tgt_cell_add = 0; tgt_cell_add < tgt_grid_size; ++tgt_cell_add)
    {
      // The counter is incremented atomically, but only thread 0 reports the progress.
      // NOTE(review): the read of findex in progress::update() is not an atomic read.
#ifdef _OPENMP
#pragma omp atomic
#endif
      findex++;
      if (cdo_omp_get_thread_num() == 0) progress::update(0, 1, findex / tgt_grid_size);

      // Default result; it is only overwritten when an interpolated value can be computed
      tgt_array[tgt_cell_add] = missval;

      if (!tgt_grid->mask[tgt_cell_add]) continue;

      const auto llpoint = remapgrid_get_lonlat(tgt_grid, tgt_cell_add);

      double src_lats[4];  //  latitudes  of four bilinear corners
      double src_lons[4];  //  longitudes of four bilinear corners
      double weights[4];   //  bilinear weights for four corners
      size_t src_add[4];   //  address for the four source points

      // Find nearest square of grid points on source grid
      // (> 0: square found, < 0: use the fallback average below, 0: nothing to do)
      auto search_result = remap_search_square(rsearch, llpoint, src_add, src_lats, src_lons);

      // Check to see if points are mask points: a single missing corner discards the whole square
      if (search_result > 0)
        {
          for (unsigned n = 0; n < 4; ++n)
            if (src_grid_mask[src_add[n]] == 0) search_result = 0;
        }

      // If point found, find local iw,jw coordinates for weights
      if (search_result > 0)
        {
          tgt_grid->cell_frac[tgt_cell_add] = 1.0;

          double iw = 0.0, jw = 0.0;  // current guess for bilinear coordinate
          if (remap_find_weights(llpoint, src_lons, src_lats, &iw, &jw))
            {
              // Successfully found iw,jw - compute weights
              bilinearSetWeights(iw, jw, weights);
              // Both arrays are passed to sort_weights_n4() before use (presumably reordered together - TODO confirm)
              sort_weights_n4(src_add, weights);
              tgt_array[tgt_cell_add] = bilinearRemap(src_array, weights, src_add);
            }
          else
            {
              // Iteration did not converge: warn and route this cell into the fallback branch below
              bilinearWarning();
              search_result = -1;
            }
        }

      /*
        Search for bilinear failed - use a distance-weighted average instead
        (this is typically near the pole) Distance was stored in src_lats!

        NOTE(review): this branch is also entered when remap_find_weights() did not converge.
        In that case src_lats still holds the values that were passed to remap_find_weights()
        as corner latitudes - confirm that they are meant to be used as weights here.
      */
      if (search_result < 0)
        {
          // num_src_points() zeroes the src_lats entries of missing points, so they get weight 0
          if (num_src_points(src_grid_mask, src_add, src_lats) > 0)
            {
              tgt_grid->cell_frac[tgt_cell_add] = 1.0;
              renormalizeWeights(src_lats, weights);
              sort_weights_n4(src_add, weights);
              tgt_array[tgt_cell_add] = bilinearRemap(src_array, weights, src_add);
            }
        }
    }

  progress::update(0, 1, 1);

  if (Options::cdoVerbose) cdo_print("%s: %.2f seconds", __func__, cdo_get_wtime() - start);
}  // remap_bilinear

// Bilinear interpolation of field1 onto the target grid of rsearch; the result is written to field2.
// Dispatches on the memory type of field1: single-precision fields use the vec_f arrays with the
// missing value converted to float, all other fields use the vec_d arrays.
void
remap_bilinear(RemapSearch &rsearch, const Field &field1, Field &field2)
{
  const bool isFloatField = (field1.memType == MemType::Float);

  if (isFloatField)
    {
      const auto missvalFloat = static_cast<float>(field1.missval);
      remap_bilinear(rsearch, field1.vec_f, field2.vec_f, missvalFloat);
      return;
    }

  remap_bilinear(rsearch, field1.vec_d, field2.vec_d, field1.missval);
}
