/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2019 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

      Regres      regres           Regression
*/

#include <cdi.h>


#include "process_int.h"
#include "cdo_vlist.h"
#include "cdo_options.h"
#include "datetime.h"
#include "pmlist.h"
#include "param_conversion.h"

// NOTE: The parameter parsing and the accumulation loop below duplicate the Trend operator; keep both in sync.

/*
 * Parse the optional key=value operator parameters of regres.
 *
 * Supported key:
 *   equal=<bool>   time steps are treated as equally spaced (sets tstepIsEqual)
 *
 * tstepIsEqual is left unchanged when no operator arguments are given.
 * Aborts via cdoAbort() on a parse error, an unknown key, or a key that does
 * not carry exactly one value.
 */
static void
regresGetParameter(bool &tstepIsEqual)
{
  const auto pargc = operatorArgc();
  if (pargc == 0) return;

  const auto pargv = operatorArgv();

  KVList kvlist;
  // Name of this parameter list (presumably shown by kvlist diagnostics such
  // as print()); it was previously copied verbatim from Trend as "TREND".
  kvlist.name = "REGRES";
  if (kvlist.parseArguments(pargc, pargv) != 0) cdoAbort("Parse error!");
  if (Options::cdoVerbose) kvlist.print();

  for (const auto &kv : kvlist)
    {
      const auto &key = kv.key;
      if (kv.nvalues > 1) cdoAbort("Too many values for parameter key >%s<!", key.c_str());
      if (kv.nvalues < 1) cdoAbort("Missing value for parameter key >%s<!", key.c_str());

      if (key == "equal")
        tstepIsEqual = parameter2bool(kv.values[0]);
      else
        cdoAbort("Invalid parameter key >%s<!", key.c_str());
    }
}

void *
Regres(void *process)
{
  int nrecs;
  int varID, levelID;
  size_t nmiss;

  cdoInitialize(process);

  auto tstepIsEqual = true;
  regresGetParameter(tstepIsEqual);

  const auto streamID1 = cdoOpenRead(0);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = vlistDuplicate(vlistID1);

  vlistDefNtsteps(vlistID2, 1);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);

  /*
  const auto streamID2 = cdoOpenWrite(1);
  cdoDefVlist(streamID2, vlistID2);
  */
  const auto streamID3 = cdoOpenWrite(1);
  cdoDefVlist(streamID3, vlistID2);

  VarList varList;
  varListInit(varList, vlistID1);

  const auto maxrecs = vlistNrecs(vlistID1);
  std::vector<RecordInfo> recList(maxrecs);

  const auto gridsizemax = vlistGridsizeMax(vlistID1);

  Field field1, field2;
  field1.resize(gridsizemax);
  field2.resize(gridsizemax);

  constexpr size_t nwork = 5;
  FieldVector2D work[nwork];
  for (auto &w : work) fieldsFromVlist(vlistID1, w, FIELD_VEC, 0);

  const auto calendar = taxisInqCalendar(taxisID1);

  CheckTimeInc checkTimeInc;
  JulianDate juldate0;
  double deltat1 = 0;
  int64_t vdate = 0;
  int vtime = 0;
  int tsID = 0;
  while ((nrecs = cdoStreamInqTimestep(streamID1, tsID)))
    {
      vdate = taxisInqVdate(taxisID1);
      vtime = taxisInqVtime(taxisID1);

      if (tstepIsEqual) checkTimeIncrement(tsID, calendar, vdate, vtime, checkTimeInc);
      const auto zj = tstepIsEqual ? (double) tsID : deltaTimeStep0(tsID, calendar, vdate, vtime, juldate0, deltat1);

      for (int recID = 0; recID < nrecs; recID++)
        {
          cdoInqRecord(streamID1, &varID, &levelID);

          if (tsID == 0)
            {
              recList[recID].varID = varID;
              recList[recID].levelID = levelID;
              recList[recID].lconst = vlistInqVarTimetype(vlistID1, varID) == TIME_CONSTANT;
            }

          cdoReadRecord(streamID1, field1.vec.data(), &nmiss);

          const auto gridsize = varList[varID].gridsize;
          const auto missval = varList[varID].missval;

          auto &sumj = work[0][varID][levelID].vec;
          auto &sumjj = work[1][varID][levelID].vec;
          auto &sumjx = work[2][varID][levelID].vec;
          auto &sumx = work[3][varID][levelID].vec;
          auto &zn = work[4][varID][levelID].vec;

          for (size_t i = 0; i < gridsize; i++)
            if (!DBL_IS_EQUAL(field1.vec[i], missval))
              {
                sumj[i] += zj;
                sumjj[i] += zj * zj;
                sumjx[i] += zj * field1.vec[i];
                sumx[i] += field1.vec[i];
                zn[i]++;
              }
        }

      tsID++;
    }

  taxisDefVdate(taxisID2, vdate);
  taxisDefVtime(taxisID2, vtime);
  /* cdoDefTimestep(streamID2, 0); */
  cdoDefTimestep(streamID3, 0);

  for (int recID = 0; recID < maxrecs; recID++)
    {
      const auto varID = recList[recID].varID;
      const auto levelID = recList[recID].levelID;
      const auto gridsize = varList[varID].gridsize;
      const auto missval = varList[varID].missval;
      const auto missval1 = missval;
      const auto missval2 = missval;
      field1.size = gridsize;
      field1.missval = missval;
      field2.size = gridsize;
      field2.missval = missval;

      const auto &sumj = work[0][varID][levelID].vec;
      const auto &sumjj = work[1][varID][levelID].vec;
      const auto &sumjx = work[2][varID][levelID].vec;
      const auto &sumx = work[3][varID][levelID].vec;
      const auto &zn = work[4][varID][levelID].vec;

      for (size_t i = 0; i < gridsize; i++)
        {
          const auto temp1 = SUBMN(sumjx[i], DIVMN(MULMN(sumj[i], sumx[i]), zn[i]));
          const auto temp2 = SUBMN(sumjj[i], DIVMN(MULMN(sumj[i], sumj[i]), zn[i]));

          field2.vec[i] = DIVMN(temp1, temp2);
          field1.vec[i] = SUBMN(DIVMN(sumx[i], zn[i]), MULMN(DIVMN(sumj[i], zn[i]), field2.vec[i]));
        }

      cdoDefRecord(streamID3, varID, levelID);
      cdoWriteRecord(streamID3, field2.vec.data(), fieldNumMiss(field2));
    }

  cdoStreamClose(streamID3);
  cdoStreamClose(streamID1);

  cdoFinish();

  return 0;
}
