/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

      Smoothstat       smooth9             running 9-point-average
*/

#include <cdi.h>

#include "process_int.h"
#include "param_conversion.h"
#include "cdo_wtime.h"
#include <mpim_grid.h>
#include "gridreference.h"
#include "constants.h"  // planet radius
#include "pmlist.h"
#include "cdo_options.h"
#include "progress.h"
#include "cimdOmp.h"
#include "grid_point_search.h"

// Weighting forms accepted by the "form" parameter of the smooth operator.
// Form[] holds the printable name of each form, indexed by the enum value.
enum
{
  FORM_LINEAR = 0
};

static const char *Form[] = { /* FORM_LINEAR */ "linear" };

// Settings of the distance-weighted "smooth" operator, filled in by Smooth()/smoothGetParameter().
struct smoothpoint_t
{
  size_t maxpoints;         // upper limit of neighbours per point (clipped to the grid size)
  int form;                 // weighting form, one of the FORM_* values
  double radius;            // search radius: parsed in degrees, converted to radians before use
  double weight0, weightR;  // weight parameters handed to computeWeights()
};

static void
smooth(int gridID, double missval, const double *restrict array1, double *restrict array2, size_t *nmiss, smoothpoint_t spoint)
{
  auto gridID0 = gridID;
  auto gridsize = gridInqSize(gridID);
  auto numNeighbors = spoint.maxpoints;
  if (numNeighbors > gridsize) numNeighbors = gridsize;

  std::vector<uint8_t> mask(gridsize);
  for (size_t i = 0; i < gridsize; ++i) mask[i] = !DBL_IS_EQUAL(array1[i], missval);

  auto gridtype = gridInqType(gridID);

  if (gridtype == GRID_GME) gridID = gridToUnstructured(gridID, 0);

  if (gridtype != GRID_UNSTRUCTURED && gridtype != GRID_CURVILINEAR) gridID = gridToCurvilinear(gridID, 0);

  if (gridtype == GRID_UNSTRUCTURED)
    {
      if (gridInqYvals(gridID, nullptr) == 0 || gridInqXvals(gridID, nullptr) == 0)
        {
          int number = 0;
          cdiInqKeyInt(gridID, CDI_GLOBAL, CDI_KEY_NUMBEROFGRIDUSED, &number);
          if (number > 0)
            {
              gridID = referenceToGrid(gridID);
              if (gridID == -1) cdoAbort("Reference to source grid not found!");
            }
        }
    }

  if (gridInqYvals(gridID, nullptr) == 0 || gridInqXvals(gridID, nullptr) == 0) cdoAbort("Cell center coordinates missing!");

  Varray<double> xvals(gridsize), yvals(gridsize);
  gridInqXvals(gridID, &xvals[0]);
  gridInqYvals(gridID, &yvals[0]);

  // Convert lat/lon units if required
  cdo_grid_to_radian(gridID, CDI_XAXIS, gridsize, &xvals[0], "grid center lon");
  cdo_grid_to_radian(gridID, CDI_YAXIS, gridsize, &yvals[0], "grid center lat");

  std::vector<knnWeightsType> knnWeights;
  for (int i = 0; i < Threading::ompNumThreads; ++i) knnWeights.push_back(knnWeightsType(numNeighbors));

  double start = Options::cdoVerbose ? cdo_get_wtime() : 0;

  bool xIsCyclic = false;
  size_t dims[2] = { gridsize, 0 };
  GridPointSearch gps;
  gridPointSearchCreate(gps, xIsCyclic, dims, gridsize, &xvals[0], &yvals[0]);

  gps.searchRadius = spoint.radius;

  if (Options::cdoVerbose) cdoPrint("Point search created: %.2f seconds", cdo_get_wtime() - start);

  if (Options::cdoVerbose) progress::init();

  start = Options::cdoVerbose ? cdo_get_wtime() : 0;

  size_t naddsMin = gridsize, naddsMax = 0;
  size_t nmissx = 0;
  double findex = 0;

#ifdef HAVE_OPENMP4
#pragma omp parallel for schedule(dynamic) default(none) reduction(+ : nmissx) reduction(min : naddsMin) reduction(max : naddsMax) \
  shared(findex, Options::cdoVerbose, knnWeights, spoint, mask, array1, array2, xvals, yvals, gps, gridsize, missval)
#endif
  for (size_t i = 0; i < gridsize; ++i)
    {
      const auto ompthID = cdo_omp_get_thread_num();

#ifdef _OPENMP
#pragma omp atomic
#endif
      findex++;
      if (Options::cdoVerbose && cdo_omp_get_thread_num() == 0) progress::update(0, 1, findex / gridsize);

      gridSearchPointSmooth(gps, xvals[i], yvals[i], knnWeights[ompthID]);

      // Compute weights based on inverse distance if mask is false, eliminate those points
      const auto nadds = knnWeights[ompthID].computeWeights(mask, spoint.radius, spoint.weight0, spoint.weightR);
      if (nadds < naddsMin) naddsMin = nadds;
      if (nadds > naddsMax) naddsMax = nadds;
      if (nadds)
        {
          array2[i] = knnWeights[ompthID].arrayWeightsSum(array1);
        }
      else
        {
          nmissx++;
          array2[i] = missval;
        }
    }

  *nmiss = nmissx;

  progress::update(0, 1, 1);

  if (Options::cdoVerbose) cdoPrint("Point search nearest: %.2f seconds", cdo_get_wtime() - start);
  if (Options::cdoVerbose) cdoPrint("Min/Max points found: %zu/%zu", naddsMin, naddsMax);

  gridPointSearchDelete(gps);

  if (gridID0 != gridID) gridDestroy(gridID);
}

/*
 * Adds neighbour ij with weight sfac to the running weighted sum (*avg) and to the
 * running sum of weights (*divavg). Neighbours whose mask entry is 0 (missing) are ignored.
 */
static inline void
smooth9_sum(size_t ij, uint8_t *mask, double sfac, const double *restrict array, double *avg, double *divavg)
{
  if (!mask[ij]) return;

  *divavg += sfac;
  *avg += sfac * array[ij];
}

/*
 * 9-point smoothing of one field on an nlon x nlat grid (operator "smooth9").
 *
 * Every valid point becomes the weighted mean of itself (weight 1), its 4 edge neighbours
 * (weight 0.5) and its 4 diagonal neighbours (weight 0.3); neighbours holding missval are
 * ignored. Neighbours outside the grid are ignored, except that the x index wraps around
 * when gridIsCircular() reports a circular grid (the y index never wraps). Points that are
 * missing on input are missing on output.
 *
 * array1  input field, value (i, j) at index j + nlon * i (j: x index, i: y index)
 * array2  output field, same layout
 * nmiss   receives the number of missing values written to array2
 */
static void
smooth9(int gridID, double missval, const double *restrict array1, double *restrict array2, size_t *nmiss)
{
  auto gridsize = gridInqSize(gridID);
  auto nlon = gridInqXsize(gridID);
  auto nlat = gridInqYsize(gridID);
  auto grid_is_cyclic = gridIsCircular(gridID);

  // 1 = valid input value, 0 = missing value
  std::vector<uint8_t> vmask(gridsize);
  auto mask = vmask.data();

  for (size_t i = 0; i < gridsize; ++i) mask[i] = !DBL_IS_EQUAL(missval, array1[i]);

  // Weight of the neighbour at offset (di, dj): 0.3 on the diagonals, 0.5 along the axes.
  auto neighbourWeight = [](int di, int dj) { return (di != 0 && dj != 0) ? 0.3 : 0.5; };

  *nmiss = 0;
  for (size_t i = 0; i < nlat; i++)
    {
      for (size_t j = 0; j < nlon; j++)
        {
          const size_t ij = j + nlon * i;
          double avg = 0;
          double divavg = 0;

          if (mask[ij])
            {
              avg += array1[ij];
              divavg += 1;

              const bool isBorder = (i == 0) || (j == 0) || (i == (nlat - 1)) || (j == (nlon - 1));
              if (!isBorder)
                {
                  // Interior point: all 8 neighbours exist.
                  smooth9_sum(((i - 1) * nlon) + j - 1, mask, 0.3, array1, &avg, &divavg);
                  smooth9_sum(((i - 1) * nlon) + j, mask, 0.5, array1, &avg, &divavg);
                  smooth9_sum(((i - 1) * nlon) + j + 1, mask, 0.3, array1, &avg, &divavg);
                  smooth9_sum((i * nlon) + j - 1, mask, 0.5, array1, &avg, &divavg);
                  smooth9_sum((i * nlon) + j + 1, mask, 0.5, array1, &avg, &divavg);
                  smooth9_sum(((i + 1) * nlon) + j - 1, mask, 0.3, array1, &avg, &divavg);
                  smooth9_sum(((i + 1) * nlon) + j, mask, 0.5, array1, &avg, &divavg);
                  smooth9_sum(((i + 1) * nlon) + j + 1, mask, 0.3, array1, &avg, &divavg);
                }
              else
                {
                  // Border point: visit the neighbours in the same order as above (row above,
                  // left/right, row below) so the floating-point sums are identical. Rows
                  // outside the grid are skipped; columns wrap only on cyclic grids.
                  for (int di = -1; di <= 1; ++di)
                    {
                      if (di < 0 && i == 0) continue;
                      if (di > 0 && i == (nlat - 1)) continue;
                      const size_t ii = (di < 0) ? i - 1 : ((di > 0) ? i + 1 : i);

                      for (int dj = -1; dj <= 1; ++dj)
                        {
                          if (di == 0 && dj == 0) continue;  // centre already added

                          size_t jj = j;
                          if (dj < 0)
                            {
                              if (j != 0)
                                jj = j - 1;
                              else if (grid_is_cyclic)
                                jj = nlon - 1;
                              else
                                continue;
                            }
                          else if (dj > 0)
                            {
                              if (j != (nlon - 1))
                                jj = j + 1;
                              else if (grid_is_cyclic)
                                jj = 0;
                              else
                                continue;
                            }

                          smooth9_sum(jj + nlon * ii, mask, neighbourWeight(di, dj), array1, &avg, &divavg);
                        }
                    }
                }
            }

          if (std::fabs(divavg) > 0)
            {
              array2[ij] = avg / divavg;
            }
          else
            {
              array2[ij] = missval;
              (*nmiss)++;
            }
        }
    }
}

/*
 * Converts an angular radius in degrees to a great-circle distance, using the
 * circumference 2*pi*PlanetRadius; the division by 1000 presumably turns metres
 * into kilometres (TODO confirm the unit of PlanetRadius).
 */
static double
radiusInKm(const double radiusInDeg)
{
  const double circumference = 2 * PlanetRadius * M_PI;
  return radiusInDeg * circumference / (360. * 1000.);
}

/*
 * Maps the value of the "form" parameter to a FORM_* constant.
 * "linear" is the only accepted name; anything else aborts via cdoAbort().
 */
static int
convert_form(const std::string &formstr)
{
  if (formstr != "linear") cdoAbort("form=%s unsupported!", formstr.c_str());

  return FORM_LINEAR;
}

static void
smoothGetParameter(int *xnsmooth, smoothpoint_t *spoint)
{
  const auto pargc = operatorArgc();
  if (pargc)
    {
      const auto pargv = cdoGetOperArgv();

      KVList kvlist;
      kvlist.name = "SMOOTH";
      if (kvlist.parseArguments(pargc, pargv) != 0) cdoAbort("Parse error!");
      if (Options::cdoVerbose) kvlist.print();

      for (const auto &kv : kvlist)
        {
          const auto &key = kv.key;
          if (kv.nvalues > 1) cdoAbort("Too many values for parameter key >%s<!", key.c_str());
          if (kv.nvalues < 1) cdoAbort("Missing value for parameter key >%s<!", key.c_str());
          const auto &value = kv.values[0];

          // clang-format off
          if      (key == "nsmooth")   *xnsmooth = parameter2int(value);
          else if (key == "maxpoints") spoint->maxpoints = parameter2sizet(value);
          else if (key == "weight0")   spoint->weight0 = parameter2double(value);
          else if (key == "weightR")   spoint->weightR = parameter2double(value);
          else if (key == "radius")    spoint->radius = radius_str_to_deg(value.c_str());
          else if (key == "form")      spoint->form = convert_form(value);
          else cdoAbort("Invalid parameter key >%s<!", key.c_str());
          // clang-format on
        }
    }
}

/*
 * Operator entry point for "smooth" (distance-weighted smoothing) and "smooth9"
 * (9-point smoothing).
 *
 * Every record of variables on a supported grid is smoothed nsmooth times (default 1;
 * only "smooth" parses parameters) and written. Records of all other variables are
 * copied unchanged, with a warning per unsupported variable.
 */
void *
Smooth(void *process)
{
  int nrecs;
  int xnsmooth = 1;
  smoothpoint_t spoint;
  spoint.maxpoints = SIZE_MAX;
  spoint.radius = 1;
  spoint.form = FORM_LINEAR;
  spoint.weight0 = 0.25;
  spoint.weightR = 0.25;

  cdoInitialize(process);

  // clang-format off
  const auto SMOOTH  = cdoOperatorAdd("smooth",   0,   0, nullptr);
  const auto SMOOTH9 = cdoOperatorAdd("smooth9",  0,   0, nullptr);
  // clang-format on

  const auto operatorID = cdoOperatorID();

  if (operatorID == SMOOTH) smoothGetParameter(&xnsmooth, &spoint);

  if (spoint.radius < 0 || spoint.radius > 180) cdoAbort("%s=%g out of bounds (0-180 deg)!", "radius", spoint.radius);

  const auto streamID1 = cdoOpenRead(0);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = vlistDuplicate(vlistID1);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);

  // varIDs[varID] is true when the variable's grid is supported by the selected operator
  const auto nvars = vlistNvars(vlistID1);
  std::vector<bool> varIDs(nvars);

  for (int varID = 0; varID < nvars; ++varID)
    {
      const auto gridID = vlistInqVarGrid(vlistID1, varID);
      const auto gridtype = gridInqType(gridID);
      if (gridtype == GRID_GAUSSIAN || gridtype == GRID_LONLAT || gridtype == GRID_CURVILINEAR || gridtype == GRID_PROJECTION
          || (operatorID == SMOOTH9 && gridtype == GRID_GENERIC && gridInqXsize(gridID) > 0 && gridInqYsize(gridID) > 0))
        {
          varIDs[varID] = true;
        }
      else if (operatorID == SMOOTH && gridtype == GRID_UNSTRUCTURED)
        {
          varIDs[varID] = true;
        }
      else
        {
          char varname[CDI_MAX_NAME];
          vlistInqVarName(vlistID1, varID, varname);
          varIDs[varID] = false;
          cdoWarning("Unsupported grid for variable %s", varname);
        }
    }

  auto gridsizemax = vlistGridsizeMax(vlistID1);
  if (gridsizemax < spoint.maxpoints) spoint.maxpoints = gridsizemax;
  if (Options::cdoVerbose && operatorID == SMOOTH)
    cdoPrint("nsmooth = %d, maxpoints = %zu, radius = %gdeg(%gkm), form = %s, weight0 = %g, weightR = %g", xnsmooth,
             spoint.maxpoints, spoint.radius, radiusInKm(spoint.radius), Form[spoint.form], spoint.weight0, spoint.weightR);

  // smooth() works with radians
  spoint.radius *= DEG2RAD;

  Varray<double> array1(gridsizemax);
  Varray<double> array2(gridsizemax);

  const auto streamID2 = cdoOpenWrite(1);
  cdoDefVlist(streamID2, vlistID2);

  int tsID = 0;
  while ((nrecs = cdoStreamInqTimestep(streamID1, tsID)))
    {
      taxisCopyTimestep(taxisID2, taxisID1);
      cdoDefTimestep(streamID2, tsID);

      for (int recID = 0; recID < nrecs; recID++)
        {
          int varID, levelID;
          size_t nmiss;

          cdoInqRecord(streamID1, &varID, &levelID);
          cdoReadRecord(streamID1, array1.data(), &nmiss);

          if (varIDs[varID])
            {
              const auto missval = vlistInqVarMissval(vlistID1, varID);
              const auto gridID = vlistInqVarGrid(vlistID1, varID);

              // Each pass smooths array1 into array2 and copies the result back into
              // array1, which is the input of the next pass.
              for (int i = 0; i < xnsmooth; ++i)
                {
                  if (operatorID == SMOOTH)
                    smooth(gridID, missval, array1.data(), array2.data(), &nmiss, spoint);
                  else if (operatorID == SMOOTH9)
                    smooth9(gridID, missval, array1.data(), array2.data(), &nmiss);

                  varrayCopy(gridsizemax, array2, array1);
                }
            }

          // array1 holds the smoothed field after the passes above, and the unmodified input
          // for unsupported variables or nsmooth < 1. In the latter case array2 was never
          // written for this record, so it must not be what gets written out.
          cdoDefRecord(streamID2, varID, levelID);
          cdoWriteRecord(streamID2, array1.data(), nmiss);
        }

      tsID++;
    }

  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  cdoFinish();

  return nullptr;
}
