#include "field_trend.h"
#include "compare.h"
#include "cimdOmp.h"
#include "arithmetic.h"

// Adds one sample (coordinate zj, values in varray) to the per-point trend
// accumulators stored in work[0..4][varID][levelID].
//
// hasMissvals : when false every element is accumulated; when true elements
//               equal to missval are skipped.
// len         : number of elements of varray that are processed.
template <typename T>
static void
calc_trend_sum(FieldVector3D &work, bool hasMissvals, size_t len, const Varray<T> &varray, T missval, double zj, int varID,
               int levelID)
{
  auto &accZ = work[0][varID][levelID].vec_d;   // += zj
  auto &accZZ = work[1][varID][levelID].vec_d;  // += zj * zj
  auto &accZX = work[2][varID][levelID].vec_d;  // += zj * x
  auto &accX = work[3][varID][levelID].vec_d;   // += x
  auto &accN = work[4][varID][levelID].vec_d;   // += 1 per accepted element

  // Adds value x (already converted to double) to grid point k.
  auto accumulate = [&](auto k, double x) {
    accZ[k] += zj;
    accZZ[k] += zj * zj;
    accZX[k] += zj * x;
    accX[k] += x;
    accN[k]++;
  };

  if (!hasMissvals)
    {
      // Fast path: no element needs to be filtered.
#ifdef HAVE_OPENMP4
#pragma omp parallel for simd if (len > cdoMinLoopSize) default(shared) schedule(static)
#endif
      for (size_t k = 0; k < len; ++k) { accumulate(k, varray[k]); }
      return;
    }

  // A NaN missval is compared with dbl_is_not_equal (presumably the NaN-aware
  // comparison from compare.h); any other missval uses is_not_equal.
  if (std::isnan(missval))
    {
#ifdef _OPENMP
#pragma omp parallel for if (len > cdoMinLoopSize) default(shared) schedule(static)
#endif
      for (size_t k = 0; k < len; ++k)
        {
          if (dbl_is_not_equal(varray[k], missval)) accumulate(k, varray[k]);
        }
    }
  else
    {
#ifdef _OPENMP
#pragma omp parallel for if (len > cdoMinLoopSize) default(shared) schedule(static)
#endif
      for (size_t k = 0; k < len; ++k)
        {
          if (is_not_equal(varray[k], missval)) accumulate(k, varray[k]);
        }
    }
}

// Accumulates one field sample at coordinate zj into the trend sums in work.
// Dispatches on the field's in-memory precision; for float storage the missing
// value is converted to float so comparisons happen in the element type.
void
calc_trend_sum(FieldVector3D &work, const Field &field, double zj, int varID, int levelID)
{
  const bool hasMissvals = field.numMissVals > 0;

  if (field.memType != MemType::Float)
    {
      calc_trend_sum(work, hasMissvals, field.size, field.vec_d, field.missval, zj, varID, levelID);
      return;
    }

  calc_trend_sum(work, hasMissvals, field.size, field.vec_f, static_cast<float>(field.missval), zj, varID, levelID);
}

// Removes a linear term in place, missing-value aware:
//   v1[i] = v1[i] - (v2[i] + v3[i] * zj)   for i in [0, len)
//
// zj  : coordinate of the sample held in v1
// v1  : values to modify (element type T, overwritten)
// v2  : additive term, double precision
// v3  : factor multiplied by zj, double precision
// len : number of elements processed
// mv  : missing value as double; v1 is assumed to store it converted to T
//       (TODO confirm with callers that read float data)
template <typename T>
static void
sub_trend(double zj, Varray<T> &v1, const Varray<double> &v2, const Varray<double> &v3, size_t len, double mv)
{
  // missval1, missval2 and the lambda parameter is_EQ are the names referenced by
  // the SUBM/ADDM/MULM macros (arithmetic.h); they must keep these names.
  auto missval1 = mv;
  auto missval2 = mv;

  // v1 holds its missing values in its own element type. For T = float, float(mv)
  // promoted back to double differs from mv unless mv is exactly representable in
  // float, so the double-precision check inside SUBM would not recognise them.
  // Compare v1 in its own type first, mirroring the float cast applied to the
  // missing value when the trend sums are accumulated.
  const auto missvalT = static_cast<T>(mv);

  auto sub_kernel = [&](auto i, auto is_EQ) -> T {
    if (is_EQ(v1[i], missvalT)) return missvalT;
    return SUBM(v1[i], ADDM(v2[i], MULM(v3[i], zj)));
  };

  if (std::isnan(missval1))
    {
#ifdef _OPENMP
#pragma omp parallel for if (len > cdoMinLoopSize) default(shared) schedule(static)
#endif
      for (size_t i = 0; i < len; ++i) { v1[i] = sub_kernel(i, dbl_is_equal); }
    }
  else
    {
#ifdef _OPENMP
#pragma omp parallel for if (len > cdoMinLoopSize) default(shared) schedule(static)
#endif
      for (size_t i = 0; i < len; ++i) { v1[i] = sub_kernel(i, is_equal); }
    }
}

// Element-wise field1 = field1 - (field2 + field3 * zj), missing-value aware,
// using field1's missing value. field2 and field3 are read through their
// double-precision storage; only field1 is dispatched on its memory type.
void
sub_trend(double zj, Field &field1, const Field &field2, const Field &field3)
{
  const auto &addend = field2.vec_d;
  const auto &factor = field3.vec_d;

  if (field1.memType != MemType::Float)
    {
      sub_trend(zj, field1.vec_d, addend, factor, field1.size, field1.missval);
      return;
    }

  sub_trend(zj, field1.vec_f, addend, factor, field1.size, field1.missval);
}

// Computes per-point linear-fit parameters from the sums in work[0..4][varID][levelID]:
//   paramB = (Szx - Sz*Sx/n) / (Szz - Sz*Sz/n)     (slope)
//   paramA = Sx/n - (Sz/n) * paramB                (intercept)
// using the missing-value-aware SUBM/MULM/DIVM macros with paramA's missing value.
// Both results are written through the double-precision storage (vec_d).
void
calc_trend_param(const FieldVector3D &work, Field &paramA, Field &paramB, int varID, int levelID)
{
  // Names referenced by the arithmetic.h macros; both operand slots use paramA's missval.
  auto missval1 = paramA.missval;
  auto missval2 = paramA.missval;

  const auto &sz = work[0][varID][levelID].vec_d;   // Sz  : sum of z
  const auto &szz = work[1][varID][levelID].vec_d;  // Szz : sum of z*z
  const auto &szx = work[2][varID][levelID].vec_d;  // Szx : sum of z*x
  const auto &sx = work[3][varID][levelID].vec_d;   // Sx  : sum of x
  const auto &cnt = work[4][varID][levelID].vec_d;  // n   : number of samples

  auto &intercept = paramA.vec_d;
  auto &slope = paramB.vec_d;

  // The comparison functor must be named is_EQ for the macros.
  auto fit_point = [&](size_t k, auto is_EQ) {
    auto centeredZX = SUBM(szx[k], DIVM(MULM(sz[k], sx[k]), cnt[k]));
    auto centeredZZ = SUBM(szz[k], DIVM(MULM(sz[k], sz[k]), cnt[k]));
    auto beta = DIVM(centeredZX, centeredZZ);

    intercept[k] = SUBM(DIVM(sx[k], cnt[k]), MULM(DIVM(sz[k], cnt[k]), beta));
    slope[k] = beta;
  };

  const auto gridsize = paramA.size;
  if (!std::isnan(missval1))
    {
      for (size_t k = 0; k < gridsize; ++k) fit_point(k, is_equal);
    }
  else
    {
      for (size_t k = 0; k < gridsize; ++k) fit_point(k, dbl_is_equal);
    }
}
