/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

      Filter    highpass
      Filter    lowpass
      Filter    bandpass
*/

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#ifdef HAVE_LIBFFTW3
#include <fftw3.h>
#endif

#include <cdi.h>

#include "process_int.h"
#include "cdo_vlist.h"
#include "param_conversion.h"
#include "statistic.h"
#include "cdo_options.h"
#include "datetime.h"
#include "cimdOmp.h"


// Builds the pass mask for the spectrum of a time series with nts steps.
//
//   nts    number of time steps (length of the series and of the spectrum)
//   fdata  sampling frequency of the series, in the same unit as fmin/fmax
//   fmin   lower bound of the frequency band to keep
//   fmax   upper bound of the frequency band to keep
//   fmasc  mask to fill; must hold at least nts entries, entries of kept bins are set to 1
//
// Aborts via cdoAbort() if one of the resulting bin indices is outside [0, nts).
static void
create_fmasc(int nts, double fdata, double fmin, double fmax, std::vector<int> &fmasc)
{
  // Band limits expressed as (fractional) spectral bin numbers.
  const double lowerBin = nts * fmin / fdata;
  const double upperBin = nts * fmax / fdata;

  // The upper limit is rounded up and capped at bin nts/2.
  const int halfLength = nts / 2;
  const double upperCeil = std::ceil(upperBin);

  int imin;
  if (lowerBin < 0)
    imin = 0;
  else
    imin = (int) std::floor(lowerBin);

  int imax;
  if (upperCeil > halfLength)
    imax = halfLength;
  else
    imax = (int) upperCeil;

  if (imin < 0 || imin >= nts) cdoAbort("Parameter fmin=%g: timestep %d out of bounds (1-%d)!", fmin, imin + 1, nts);
  if (imax < 0 || imax >= nts) cdoAbort("Parameter fmax=%g: timestep %d out of bounds (1-%d)!", fmax, imax + 1, nts);

  // The lowest bin is kept on its own; every further bin up to imax is kept
  // together with its mirror bin nts - k.
  fmasc[imin] = 1;
  for (int k = imin + 1; k <= imax; ++k)
    {
      fmasc[nts - k] = 1;
      fmasc[k] = 1;
    }
}

#ifdef HAVE_LIBFFTW3
// Filters one time series in frequency space using two prepared FFTW plans.
//
//   nts      number of spectral coefficients
//   fmasc    pass mask; coefficients whose entry is 0 are removed
//   fft_out  spectrum buffer that is masked between the two plan executions
//   p_T2S    plan executed first (expected to fill fft_out)
//   p_S2T    plan executed after masking (expected to read fft_out)
//
// The function does not rescale the result; that is left to the caller.
static void
filter_fftw(int nts, const std::vector<int> &fmasc, fftw_complex *fft_out, fftw_plan *p_T2S, fftw_plan *p_S2T)
{
  fftw_execute(*p_T2S);

  // Clear real and imaginary part of every coefficient rejected by the mask.
  for (int k = 0; k < nts; ++k)
    {
      if (fmasc[k]) continue;

      fft_out[k][0] = 0;
      fft_out[k][1] = 0;
    }

  fftw_execute(*p_S2T);
}
#endif

// Filters one time series in place with the built-in Fourier transforms.
//
//   nts    length of the series
//   fmasc  pass mask; coefficients whose entry is 0 are removed
//   real   in: real part of the series, out: real part of the filtered series
//   imag   in: imaginary part of the series, out: imaginary part of the filtered series
//
// cdo::fft() is used when nts has at most one bit set (a power of two);
// every other length goes through cdo::ft_r(), which needs two scratch arrays.
static void
filter_intrinsic(int nts, const std::vector<int> &fmasc, double *real, double *imag)
{
  const bool usePlainFFT = ((nts & (nts - 1)) == 0);

  // Scratch space is only allocated for the cdo::ft_r() path.
  Varray<double> work_r, work_i;
  if (!usePlainFFT)
    {
      work_r.resize(nts);
      work_i.resize(nts);
    }

  // Runs the transform matching the series length; sign selects the direction.
  const auto transform = [&](int sign) {
    if (usePlainFFT)
      cdo::fft(real, imag, nts, sign);
    else
      cdo::ft_r(real, imag, nts, sign, work_r.data(), work_i.data());
  };

  transform(1);

  // Remove all coefficients that are not part of the pass band.
  for (int k = 0; k < nts; ++k)
    {
      if (fmasc[k]) continue;

      imag[k] = 0;
      real[k] = 0;
    }

  transform(-1);
}

/*
  Entry point of the operators bandpass, highpass and lowpass.

  Reads all timesteps of input stream 0 into memory, derives the sampling
  frequency (time steps per 365-day year) from the first time increment,
  builds the frequency pass mask and filters the time series of every grid
  point in frequency space. FFTW is used if requested and compiled in,
  otherwise the intrinsic Fourier transform. The result is written to
  output stream 1 with the time stamps of the input.

  Fields that were not read in every timestep cannot be treated as a time
  series; they are written unfiltered.

  process: handle passed on to cdoInitialize().
  Returns nullptr.
*/
void *
Filter(void *process)
{
  enum
  {
    BANDPASS,
    HIGHPASS,
    LOWPASS
  };
  // Name of each time unit and number of such units per 365-day year; both indexed by the time unit.
  const char *tunits[] = { "second", "minute", "hour", "day", "month", "year" };
  const int iunits[] = { 31536000, 525600, 8760, 365, 12, 1 };
  int nrecs;
  double fdata = 0;
  TimeIncrement timeIncr0 = { 0, TimeUnit::SECONDS };
  DateTimeList dtlist;
  // Per-thread work space for the transforms.
  struct FourierMemory
  {
    Varray<double> real;
    Varray<double> imag;
#ifdef HAVE_LIBFFTW3
    fftw_complex *in_fft;
    fftw_complex *out_fft;
    fftw_plan p_T2S;
    fftw_plan p_S2T;
#endif
  };

  cdoInitialize(process);

  cdoOperatorAdd("bandpass", BANDPASS, 0, nullptr);
  cdoOperatorAdd("highpass", HIGHPASS, 0, nullptr);
  cdoOperatorAdd("lowpass", LOWPASS, 0, nullptr);

  const auto operatorID = cdoOperatorID();
  const auto operfunc = cdoOperatorF1(operatorID);

  bool use_fftw = false;
  if (Options::Use_FFTW)
    {
#ifdef HAVE_LIBFFTW3
      if (Options::cdoVerbose) cdoPrint("Using fftw3 lib");
      use_fftw = true;
#else
      if (Options::cdoVerbose) cdoPrint("LIBFFTW3 support not compiled in!");
#endif
    }

  if (Options::cdoVerbose && !use_fftw) cdoPrint("Using intrinsic FFT function!");

  const auto streamID1 = cdoOpenRead(0);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = vlistDuplicate(vlistID1);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);

  const auto calendar = taxisInqCalendar(taxisID1);

  VarList varList;
  varListInit(varList, vlistID1);

  const auto nvars = vlistNvars(vlistID1);
  FieldVector3D vars;

  // Read the complete input; vars[timestep][variable][level] holds one field each.
  int tsID = 0;
  while ((nrecs = cdoStreamInqTimestep(streamID1, tsID)))
    {
      // Grow the timestep dimension in chunks.
      constexpr size_t NALLOC_INC = 1024;
      if ((size_t) tsID >= vars.size()) vars.resize(vars.size() + NALLOC_INC);

      dtlist.taxisInqTimestep(taxisID1, tsID);

      fieldsFromVlist(vlistID1, vars[tsID]);

      for (int recID = 0; recID < nrecs; recID++)
        {
          size_t nmiss;
          int varID, levelID;
          cdoInqRecord(streamID1, &varID, &levelID);
          const auto gridsize = varList[varID].gridsize;
          // Only fields that occur as a record in this timestep get storage.
          vars[tsID][varID][levelID].resize(gridsize);
          cdoReadRecord(streamID1, vars[tsID][varID][levelID].vec_d.data(), &nmiss);
          vars[tsID][varID][levelID].nmiss = nmiss;
          if (nmiss) cdoAbort("Missing value support for operators in module Filter not added yet!");
        }

      // get and check time increment
      if (tsID > 0)
        {
          const auto vdate0 = dtlist.getVdate(tsID - 1);
          const auto vtime0 = dtlist.getVtime(tsID - 1);
          const auto vdate = dtlist.getVdate(tsID);
          const auto vtime = dtlist.getVtime(tsID);

          int year0, month0, day0;
          cdiDecodeDate(vdate0, &year0, &month0, &day0);
          int year, month, day;
          cdiDecodeDate(vdate, &year, &month, &day);

          const auto juldate0 = julianDateEncode(calendar, vdate0, vtime0);
          const auto juldate = julianDateEncode(calendar, vdate, vtime);
          const auto jdelta = julianDateToSeconds(julianDateSub(juldate, juldate0));

          const auto timeIncr = getTimeIncrement(jdelta, vdate0, vdate);

          if (tsID == 1)
            {
              // The first increment defines the sampling frequency of the whole series.
              timeIncr0 = timeIncr;
              if (timeIncr.period == 0) cdoAbort("Time step must be different from zero!");
              if (Options::cdoVerbose) cdoPrint("Time step %lld %s", timeIncr.period, tunits[(int) timeIncr.unit]);
              fdata = 1. * iunits[(int) timeIncr.unit] / timeIncr.period;
            }

          // fdata assumes 365-day years, so a 29th of February in sub-monthly data shifts the series.
          if (calendar != CALENDAR_360DAYS && calendar != CALENDAR_365DAYS && calendar != CALENDAR_366DAYS
              && timeIncr0.unit < TimeUnit::MONTHS && month == 2 && day == 29 && (day0 != day || month0 != month || year0 != year))
            {
              cdoWarning("Filtering of multi-year times series doesn't works properly with a standard calendar.");
              cdoWarning("  Please delete the day %i-02-29 (cdo del29feb)", year);
            }

          if (timeIncr.period != timeIncr0.period || timeIncr.unit != timeIncr0.unit)
            cdoWarning("Time increment in step %d (%lld%s) differs from step 1 (%lld%s)!", tsID + 1, timeIncr.period,
                       tunits[(int) timeIncr.unit], timeIncr0.period, tunits[(int) timeIncr0.unit]);
        }

      tsID++;
    }

  const auto nts = tsID;
  if (nts <= 1) cdoAbort("Number of time steps <= 1!");

  // One work space per OpenMP thread. FFTW plans are created here, outside of any parallel region.
  std::vector<FourierMemory> ompmem(Threading::ompNumThreads);

  if (use_fftw)
    {
#ifdef HAVE_LIBFFTW3
      for (int i = 0; i < Threading::ompNumThreads; i++)
        {
          ompmem[i].in_fft = fftw_alloc_complex(nts);
          ompmem[i].out_fft = fftw_alloc_complex(nts);
          ompmem[i].p_T2S = fftw_plan_dft_1d(nts, ompmem[i].in_fft, ompmem[i].out_fft, 1, FFTW_ESTIMATE);
          ompmem[i].p_S2T = fftw_plan_dft_1d(nts, ompmem[i].out_fft, ompmem[i].in_fft, -1, FFTW_ESTIMATE);
        }
#endif
    }
  else
    {
      for (int i = 0; i < Threading::ompNumThreads; i++)
        {
          ompmem[i].real.resize(nts);
          ompmem[i].imag.resize(nts);
        }
    }

  // Frequency band requested by the user.
  double fmin = 0, fmax = 0;
  switch (operfunc)
    {
    case BANDPASS:
      {
        operatorInputArg("lower and upper bound of frequency band");
        operatorCheckArgc(2);
        fmin = parameter2double(cdoOperatorArgv(0));
        fmax = parameter2double(cdoOperatorArgv(1));
        break;
      }
    case HIGHPASS:
      {
        operatorInputArg("lower bound of frequency pass");
        operatorCheckArgc(1);
        fmin = parameter2double(cdoOperatorArgv(0));
        fmax = fdata;
        break;
      }
    case LOWPASS:
      {
        operatorInputArg("upper bound of frequency pass");
        operatorCheckArgc(1);
        fmin = 0;
        fmax = parameter2double(cdoOperatorArgv(0));
        break;
      }
    }

  if (Options::cdoVerbose) cdoPrint("fmin=%g  fmax=%g", fmin, fmax);

  std::vector<int> fmasc(nts, 0);
  create_fmasc(nts, fdata, fmin, fmax, fmasc);

  // True if the field (varID, levelID) holds data in each of the nts timesteps.
  // Fields without a record in some timestep were never resized and must not be indexed.
  const auto hasAllTimesteps = [&vars, nts](int varID, int levelID) -> bool {
    for (int t = 0; t < nts; t++)
      if (vars[t][varID][levelID].empty()) return false;
    return true;
  };

  for (int varID = 0; varID < nvars; varID++)
    {
      const auto gridsize = varList[varID].gridsize;
      for (int levelID = 0; levelID < varList[varID].nlevels; levelID++)
        {
          // Incomplete time series are passed through unfiltered.
          if (!hasAllTimesteps(varID, levelID)) continue;

          if (use_fftw)
            {
#ifdef HAVE_LIBFFTW3
#ifdef _OPENMP
#pragma omp parallel for default(shared)
#endif
              for (size_t i = 0; i < gridsize; i++)
                {
                  const auto ompthID = cdo_omp_get_thread_num();
                  auto &mem = ompmem[ompthID];

                  // Gather the time series of grid point i.
                  for (int t = 0; t < nts; t++)
                    {
                      mem.in_fft[t][0] = vars[t][varID][levelID].vec_d[i];
                      mem.in_fft[t][1] = 0;
                    }

                  filter_fftw(nts, fmasc, mem.out_fft, &mem.p_T2S, &mem.p_S2T);

                  // FFTW transforms are unnormalized: scale the round trip by 1/nts.
                  for (int t = 0; t < nts; t++) vars[t][varID][levelID].vec_d[i] = mem.in_fft[t][0] / nts;
                }
#endif
            }
          else
            {
#ifdef _OPENMP
#pragma omp parallel for default(shared)
#endif
              for (size_t i = 0; i < gridsize; i++)
                {
                  const auto ompthID = cdo_omp_get_thread_num();
                  auto &mem = ompmem[ompthID];

                  // Gather the time series of grid point i.
                  for (int t = 0; t < nts; t++) mem.real[t] = vars[t][varID][levelID].vec_d[i];

                  varrayFill(mem.imag, 0.0);

                  filter_intrinsic(nts, fmasc, mem.real.data(), mem.imag.data());

                  for (int t = 0; t < nts; t++) vars[t][varID][levelID].vec_d[i] = mem.real[t];
                }
            }
        }
    }

#ifdef HAVE_LIBFFTW3
  if (use_fftw)
    {
      for (int i = 0; i < Threading::ompNumThreads; i++)
        {
          fftw_free(ompmem[i].in_fft);
          fftw_free(ompmem[i].out_fft);
          fftw_destroy_plan(ompmem[i].p_T2S);
          fftw_destroy_plan(ompmem[i].p_S2T);
        }
      fftw_cleanup();
    }
#endif

  const auto streamID2 = cdoOpenWrite(1);
  cdoDefVlist(streamID2, vlistID2);

  for (tsID = 0; tsID < nts; tsID++)
    {
      dtlist.taxisDefTimestep(taxisID2, tsID);
      cdoDefTimestep(streamID2, tsID);

      for (int varID = 0; varID < nvars; varID++)
        {
          const auto nlevels = varList[varID].nlevels;
          for (int levelID = 0; levelID < nlevels; levelID++)
            {
              // Write only the fields that were present in this timestep of the input.
              if (!vars[tsID][varID][levelID].empty())
                {
                  const auto nmiss = vars[tsID][varID][levelID].nmiss;
                  cdoDefRecord(streamID2, varID, levelID);
                  cdoWriteRecord(streamID2, vars[tsID][varID][levelID].vec_d.data(), nmiss);
                }
            }
        }
    }

  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  cdoFinish();

  return nullptr;
}
