/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2019 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

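/*
   This module contains the following operators:

      Math       abs             Absolute value
      Math       int             Integer value (truncated toward zero)
      Math       nint            Nearest integer value
      Math       sqr             Square
      Math       sqrt            Square root
      Math       exp             Exponential
      Math       ln              Natural logarithm
      Math       log10           Base 10 logarithm
      Math       sin             Sine
      Math       cos             Cosine
      Math       tan             Tangent
      Math       asin            Arc sine
      Math       acos            Arc cosine
      Math       atan            Arc tangent
      Math       pow             Power
      Math       reci            Reciprocal
      Math       not             Logical NOT (1 where the value is 0, otherwise 0)
      Math       conj            Complex conjugate
      Math       re              Real part
      Math       im              Imaginary part
      Math       arg             Argument (phase angle)
*/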

#include <cmath>
#include <complex>

#include <cdi.h>

#include "process_int.h"
#include "param_conversion.h"

/*
  Math - apply an element-wise mathematical function to every record of input stream 0
  and write the result to output stream 1.

  The function is selected by the operator name (see the cdoOperatorAdd() table below).

  Real fields: every operator except conj and im (those abort). Missing values are
  propagated. Out-of-domain inputs become missing:
    - ln/log10 of x < 0,
    - asin/acos outside [-1, 1],
    - reci of 0.

  Complex fields: sqr, sqrt, conj, re, im, abs, arg (all others abort). re/im/abs/arg
  yield real values, so the output datatype of complex variables is switched to
  FLT32/FLT64 for those operators. A complex element counts as missing if either of
  its parts equals the missing value.

  process: CDO process handle, passed to cdoInitialize().
  Returns nullptr.
*/
void *
Math(void *process)
{
  enum
  {
    ABS,
    FINT,
    FNINT,
    SQR,
    SQRT,
    EXP,
    LN,
    LOG10,
    SIN,
    COS,
    TAN,
    ASIN,
    ACOS,
    ATAN,
    POW,
    RECI,
    NOT,
    CONJ,
    RE,
    IM,
    ARG
  };

  cdoInitialize(process);

  // clang-format off
  cdoOperatorAdd("abs",   ABS,   0, nullptr);
  cdoOperatorAdd("int",   FINT,  0, nullptr);
  cdoOperatorAdd("nint",  FNINT, 0, nullptr);
  cdoOperatorAdd("sqr",   SQR,   0, nullptr);
  cdoOperatorAdd("sqrt",  SQRT,  0, nullptr);
  cdoOperatorAdd("exp",   EXP,   0, nullptr);
  cdoOperatorAdd("ln",    LN,    0, nullptr);
  cdoOperatorAdd("log10", LOG10, 0, nullptr);
  cdoOperatorAdd("sin",   SIN,   0, nullptr);
  cdoOperatorAdd("cos",   COS,   0, nullptr);
  cdoOperatorAdd("tan",   TAN,   0, nullptr);
  cdoOperatorAdd("asin",  ASIN,  0, nullptr);
  cdoOperatorAdd("acos",  ACOS,  0, nullptr);
  cdoOperatorAdd("atan",  ATAN,  0, nullptr);
  cdoOperatorAdd("pow",   POW,   0, nullptr);
  cdoOperatorAdd("reci",  RECI,  0, nullptr);
  cdoOperatorAdd("not",   NOT,   0, nullptr);
  cdoOperatorAdd("conj",  CONJ,  0, nullptr);
  cdoOperatorAdd("re",    RE,    0, nullptr);
  cdoOperatorAdd("im",    IM,    0, nullptr);
  cdoOperatorAdd("arg",   ARG,   0, nullptr);
  // clang-format on

  const auto operatorID = cdoOperatorID();
  const auto operfunc = cdoOperatorF1(operatorID);

  // Exponent of "pow", taken from the single operator argument.
  double rc = 0;
  if (operfunc == POW)
    {
      operatorInputArg("value");
      rc = parameter2double(operatorArgv()[0]);
    }

  const auto streamID1 = cdoOpenRead(0);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = vlistDuplicate(vlistID1);

  // re, im, abs and arg map a complex field onto a real one: switch complex output datatypes to real.
  const bool complexToReal = (operfunc == RE || operfunc == IM || operfunc == ABS || operfunc == ARG);
  if (complexToReal)
    {
      const auto nvars = vlistNvars(vlistID2);
      for (int vid = 0; vid < nvars; ++vid)
        {
          if (vlistInqVarDatatype(vlistID2, vid) == CDI_DATATYPE_CPX32) vlistDefVarDatatype(vlistID2, vid, CDI_DATATYPE_FLT32);
          if (vlistInqVarDatatype(vlistID2, vid) == CDI_DATATYPE_CPX64) vlistDefVarDatatype(vlistID2, vid, CDI_DATATYPE_FLT64);
        }
    }

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);

  // Complex fields are stored as interleaved (re, im) pairs and need twice the grid size.
  auto gridsizemax = vlistGridsizeMax(vlistID1);
  if (vlistNumber(vlistID1) != CDI_REAL) gridsizemax *= 2;

  std::vector<double> array1(gridsizemax);
  std::vector<double> array2(gridsizemax);

  const auto streamID2 = cdoOpenWrite(1);
  cdoDefVlist(streamID2, vlistID2);

  int nrecs;
  int tsID = 0;
  while ((nrecs = cdoStreamInqTimestep(streamID1, tsID)))
    {
      taxisCopyTimestep(taxisID2, taxisID1);
      cdoDefTimestep(streamID2, tsID);

      for (int recID = 0; recID < nrecs; recID++)
        {
          int varID, levelID;
          size_t nmiss;
          cdoInqRecord(streamID1, &varID, &levelID);
          cdoReadRecord(streamID1, array1.data(), &nmiss);

          const auto missval1 = vlistInqVarMissval(vlistID1, varID);
          // NOTE(review): missval2 is not referenced directly below; presumably the *MN macros
          // (SQRTMN, ADDMN, MULMN) expand to code that uses it -- confirm before removing.
          const auto missval2 = missval1;
          const auto n = gridInqSize(vlistInqVarGrid(vlistID1, varID));
          const auto number = vlistInqVarNumber(vlistID1, varID);

          if (number == CDI_REAL)
            {
              switch (operfunc)
                {
                case ABS:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::fabs(array1[i]);
                  break;
                case FINT:
                  // Truncate in floating point: an (int) cast is undefined for NaN and for values outside the int range.
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::trunc(array1[i]);
                  break;
                case FNINT:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::round(array1[i]);
                  break;
                case SQR:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : array1[i] * array1[i];
                  break;
                case SQRT:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : SQRTMN(array1[i]);
                  break;
                case EXP:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::exp(array1[i]);
                  break;
                case LN:
                  // NOTE(review): x == 0 is not masked and yields -inf.
                  for (size_t i = 0; i < n; i++)
                    array2[i] = DBL_IS_EQUAL(array1[i], missval1) || array1[i] < 0 ? missval1 : std::log(array1[i]);
                  break;
                case LOG10:
                  // NOTE(review): x == 0 is not masked and yields -inf.
                  for (size_t i = 0; i < n; i++)
                    array2[i] = DBL_IS_EQUAL(array1[i], missval1) || array1[i] < 0 ? missval1 : std::log10(array1[i]);
                  break;
                case SIN:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::sin(array1[i]);
                  break;
                case COS:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::cos(array1[i]);
                  break;
                case TAN:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::tan(array1[i]);
                  break;
                case ASIN:
                  for (size_t i = 0; i < n; i++)
                    array2[i] = DBL_IS_EQUAL(array1[i], missval1) || array1[i] < -1 || array1[i] > 1 ? missval1 : std::asin(array1[i]);
                  break;
                case ACOS:
                  for (size_t i = 0; i < n; i++)
                    array2[i] = DBL_IS_EQUAL(array1[i], missval1) || array1[i] < -1 || array1[i] > 1 ? missval1 : std::acos(array1[i]);
                  break;
                case ATAN:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::atan(array1[i]);
                  break;
                case POW:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::pow(array1[i], rc);
                  break;
                case RECI:
                  for (size_t i = 0; i < n; i++)
                    array2[i] = DBL_IS_EQUAL(array1[i], missval1) || DBL_IS_EQUAL(array1[i], 0.) ? missval1 : 1 / array1[i];
                  break;
                case NOT:
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : IS_EQUAL(array1[i], 0);
                  break;
                case RE:
                  // The real part of a real number is the number itself.
                  for (size_t i = 0; i < n; i++) array2[i] = array1[i];
                  break;
                case ARG:
                  // The argument of a real x is atan2(0, x): 0 for x > 0 and pi for x < 0 (same as the complex branch with im = 0).
                  for (size_t i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::atan2(0.0, array1[i]);
                  break;
                default: cdoAbort("Operator not implemented for real data!"); break;
                }

              nmiss = arrayNumMV(n, array2.data(), missval1);
            }
          else
            {
              // Element i occupies array1[2*i] (real part) and array1[2*i+1] (imaginary part).
              const auto isMissing = [&](size_t i) {
                return DBL_IS_EQUAL(array1[2 * i], missval1) || DBL_IS_EQUAL(array1[2 * i + 1], missval1);
              };
              const auto inputZ = [&](size_t i) { return std::complex<double>(array1[2 * i], array1[2 * i + 1]); };
              const auto storeComplex = [&](size_t i, const std::complex<double> &z) {
                array2[2 * i] = z.real();
                array2[2 * i + 1] = z.imag();
              };
              // Both parts are set to the missing value when a complex result is missing.
              const std::complex<double> missingZ(missval1, missval1);

              switch (operfunc)
                {
                case SQR:
                  // z^2 = (re^2 - im^2) + 2 re im i
                  for (size_t i = 0; i < n; i++) storeComplex(i, isMissing(i) ? missingZ : inputZ(i) * inputZ(i));
                  break;
                case SQRT:
                  // Principal square root; std::sqrt also handles the negative real axis and zero correctly.
                  for (size_t i = 0; i < n; i++) storeComplex(i, isMissing(i) ? missingZ : std::sqrt(inputZ(i)));
                  break;
                case CONJ:
                  for (size_t i = 0; i < n; i++) storeComplex(i, isMissing(i) ? missingZ : std::conj(inputZ(i)));
                  break;
                case RE:
                  for (size_t i = 0; i < n; i++) array2[i] = array1[i * 2];
                  break;
                case IM:
                  for (size_t i = 0; i < n; i++) array2[i] = array1[i * 2 + 1];
                  break;
                case ABS:
                  // |z| = sqrt(re^2 + im^2)
                  for (size_t i = 0; i < n; i++)
                    array2[i] = SQRTMN(ADDMN(MULMN(array1[2 * i], array1[2 * i]), MULMN(array1[2 * i + 1], array1[2 * i + 1])));
                  break;
                case ARG:
                  for (size_t i = 0; i < n; i++) array2[i] = isMissing(i) ? missval1 : std::atan2(array1[2 * i + 1], array1[2 * i]);
                  break;
                default: cdoAbort("Fields with complex numbers are not supported by this operator!"); break;
                }

              // re/im/abs/arg write n real values that may be missing: count them.
              // NOTE(review): complex-valued results (sqr, sqrt, conj) still report nmiss = 0; how CDI
              // expects missing values of complex fields to be counted is not visible here -- confirm.
              nmiss = complexToReal ? arrayNumMV(n, array2.data(), missval1) : 0;
            }

          cdoDefRecord(streamID2, varID, levelID);
          cdoWriteRecord(streamID2, array2.data(), nmiss);
        }

      tsID++;
    }

  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  vlistDestroy(vlistID2);

  cdoFinish();

  return nullptr;
}
