/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2021 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

      Vertint    ml2pl           Model to pressure level interpolation
      Vertint    ml2hl           Model to height level interpolation
*/

#include <cdi.h>

#include "cdo_options.h"
#include "process_int.h"
#include "cdo_vlist.h"
#include "field_vinterp.h"
#include "stdnametable.h"
#include "constants.h"
#include "util_string.h"
#include "const.h"
#include "cdo_zaxis.h"
#include "param_conversion.h"
#include "vertint_util.h"

// Reverse, in place, the element order of each half of the vertical coordinate
// table. The table holds two coefficient arrays back to back (first half / second
// half); each half is reversed on its own, the halves are not swapped.
static void
invert_vct(Varray<double> &vct)
{
  const Varray<double> original(vct);
  const size_t total = vct.size();
  const size_t half = total / 2;

  for (size_t dst = 0; dst < half; ++dst)
    {
      // first half: position dst receives its mirror image within [0, half)
      vct[dst] = original[half - 1 - dst];
      // second half: filled from the back, reading forward from index half
      vct[total - 1 - dst] = original[half + dst];
    }
}

// Returns true when the CDI z-axis type is ZAXIS_HYBRID or ZAXIS_HYBRID_HALF.
static bool
zaxis_is_hybrid(const int zaxistype)
{
  switch (zaxistype)
    {
    case ZAXIS_HYBRID:
    case ZAXIS_HYBRID_HALF: return true;
    default: return false;
    }
}

static void
change_hybrid_zaxis(int vlistID1, int vlistID2, int nvct, double *vct, int zaxisID2, int nhlevf, int nhlevh)
{
  const auto nzaxis = vlistNzaxis(vlistID1);
  for (int iz = 0; iz < nzaxis; ++iz)
    {
      const auto zaxisID = vlistZaxis(vlistID1, iz);
      const auto nlevel = zaxisInqSize(zaxisID);
      const auto zaxistype = zaxisInqType(zaxisID);

      if (zaxis_is_hybrid(zaxistype) && (nlevel == nhlevh || nlevel == nhlevf))
        {
          const auto nvct2 = zaxisInqVctSize(zaxisID);
          if (nvct2 == nvct && memcmp(vct, zaxisInqVctPtr(zaxisID), nvct * sizeof(double)) == 0)
            vlistChangeZaxisIndex(vlistID2, iz, zaxisID2);
        }
    }
}

/*
  Vertintml: interpolate data on hybrid model levels to pressure levels (ml2pl*)
  or height levels (ml2hl*).

  Operators (f1 = target level kind, f2 = interpolation type):
    ml2pl,    ml2hl       interpolation in pressure
    ml2plx,   ml2hlx      as above, extrapolation always enabled
    ml2pl_lp, ml2hl_lp    interpolation in log(pressure)
    ml2plx_lp, ml2hlx_lp  log(pressure) with extrapolation
  For the non-"x" operators extrapolation can still be enabled through
  getenv_extrapolate().

  Target levels are taken from the operator arguments. If the single argument
  "default" is given, a built-in list is used. Height levels are converted to
  pressure with height_to_pressure() before interpolation.

  Variables are recognised by code/table, standard name or variable name:
  temperature, surface pressure (or its logarithm), surface geopotential,
  geopotential and geopotential height. Temperature and geopotential height get
  dedicated interpolation routines (vertical_interp_T / vertical_interp_Z).
  All other variables on the hybrid axis use vertical_interp_X. Variables that
  are not on a matching hybrid axis are copied unchanged.
*/
void *
Vertintml(void *process)
{
  ModelMode mode(ModelMode::UNDEF);
  enum
  {
    func_pl,
    func_hl
  };
  enum
  {
    type_lin,
    type_log
  };
  bool sgeopot_needed = false;
  int sgeopotID = -1, geopotID = -1, tempID = -1, psID = -1, lnpsID = -1, gheightID = -1;
  char paramstr[32];
  // Value-initialized: when code tables are in use but no variable carries one of the
  // recognised table numbers, the name-based lookup below reads gribcodes before any
  // *_gribcodes() call has filled it.
  gribcode_t gribcodes{};

  cdo_initialize(process);

  // clang-format off
                         cdo_operator_add("ml2pl",     func_pl, type_lin, "pressure levels in pascal");
  const auto ML2PLX    = cdo_operator_add("ml2plx",    func_pl, type_lin, "pressure levels in pascal");
                         cdo_operator_add("ml2hl",     func_hl, type_lin, "height levels in meter");
  const auto ML2HLX    = cdo_operator_add("ml2hlx",    func_hl, type_lin, "height levels in meter");
                         cdo_operator_add("ml2pl_lp",  func_pl, type_log, "pressure levels in pascal");
  const auto ML2PLX_LP = cdo_operator_add("ml2plx_lp", func_pl, type_log, "pressure levels in pascal");
                         cdo_operator_add("ml2hl_lp",  func_hl, type_log, "height levels in meter");
  const auto ML2HLX_LP = cdo_operator_add("ml2hlx_lp", func_hl, type_log, "height levels in meter");
  // clang-format on

  const auto operatorID = cdo_operator_id();
  const auto useHeightLevel = (cdo_operator_f1(operatorID) == func_hl);
  const auto useLogType = (cdo_operator_f2(operatorID) == type_log);

  auto extrapolate = (operatorID == ML2PLX || operatorID == ML2HLX || operatorID == ML2PLX_LP || operatorID == ML2HLX_LP);
  if (extrapolate == false) extrapolate = getenv_extrapolate();

  operator_input_arg(cdo_operator_enter(operatorID));

  // target levels: meters for ml2hl*, pascal for ml2pl*
  std::vector<double> plev;
  if (cdo_operator_argc() == 1 && cdo_operator_argv(0) == "default")
    {
      if (useHeightLevel)
        plev = { 10, 50, 100, 500, 1000, 5000, 10000, 15000, 20000, 25000, 30000 };
      else
        plev
            = { 100000, 92500, 85000, 70000, 60000, 50000, 40000, 30000, 25000, 20000, 15000, 10000, 7000, 5000, 3000, 2000, 1000 };
    }
  else
    {
      plev = cdo_argv_to_flt(cdo_get_oper_argv());
    }

  int nplev = plev.size();

  const auto streamID1 = cdo_open_read(0);

  const auto vlistID1 = cdo_stream_inq_vlist(streamID1);
  const auto vlistID2 = vlistDuplicate(vlistID1);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);

  const auto gridsize = vlist_check_gridsize(vlistID1);

  // output axis; its levels are defined in the user's units before any pressure conversion
  const auto zaxisIDp = zaxisCreate(useHeightLevel ? ZAXIS_HEIGHT : ZAXIS_PRESSURE, nplev);
  zaxisDefLevels(zaxisIDp, plev.data());

  int nvct = 0;
  int zaxisIDh = -1;
  int nhlev = 0, nhlevf = 0, nhlevh = 0;
  Varray<double> vct;
  vlist_read_vct(vlistID1, &zaxisIDh, &nvct, &nhlev, &nhlevf, &nhlevh, vct);

  // hybrid axes sharing this VCT are replaced by the target axis in the output vlist
  change_hybrid_zaxis(vlistID1, vlistID2, nvct, vct.data(), zaxisIDp, nhlevf, nhlevh);

  int psvarID = -1;
  bool linvertvct = false;
  if (!vct.empty() && nvct && nvct % 2 == 0)
    {
      psvarID = vlist_get_psvarid(vlistID1, zaxisIDh);

      // the VCT is treated as inverted when its second half never increases
      int i;
      for (i = nvct / 2 + 1; i < nvct; i++)
        if (vct[i] > vct[i - 1]) break;
      if (i == nvct) linvertvct = true;
    }

  if (Options::cdoVerbose) cdo_print("linvertvct = %d", static_cast<int>(linvertvct));

  // an inverted VCT is flipped here; records on zaxisIDh are flipped while reading
  if (linvertvct) invert_vct(vct);

  VarList varList1;
  varListInit(varList1, vlistID1);
  varListSetUniqueMemtype(varList1);
  const auto memtype = varList1[0].memType;

  VarList varList2;
  varListInit(varList2, vlistID2);
  varListSetMemtype(varList2, memtype);

  const auto nvars = vlistNvars(vlistID1);

  std::vector<bool> vars(nvars), varinterp(nvars);
  std::vector<std::vector<size_t>> varnmiss(nvars);
  Field3DVector vardata1(nvars), vardata2(nvars);

  const auto maxlev = nhlevh > nplev ? nhlevh : nplev;

  // per-target-level missing value counts, only tracked without extrapolation
  std::vector<size_t> pnmiss;
  if (!extrapolate) pnmiss.resize(nplev);

  // check levels
  if (zaxisIDh != -1)
    {
      const auto nlev = zaxisInqSize(zaxisIDh);
      if (nlev != nhlev) cdo_abort("Internal error, wrong number of hybrid level!");
    }

  std::vector<int> vertIndex;

  // 3D pressure on full and half hybrid levels, recomputed every timestep
  Field3D fullPress, halfPress;
  if (zaxisIDh != -1 && gridsize > 0)
    {
      vertIndex.resize(gridsize * nplev);

      CdoVar var3Df, var3Dh;
      var3Df.gridsize = gridsize;
      var3Df.nlevels = nhlevf;
      var3Df.memType = memtype;
      fullPress.init(var3Df);

      var3Dh.gridsize = gridsize;
      var3Dh.nlevels = nhlevh;
      var3Dh.memType = memtype;
      halfPress.init(var3Dh);
    }
  else
    cdo_warning("No 3D variable with hybrid sigma pressure coordinate found!");

  // from here on plev holds pressure values (converted from heights for ml2hl*)
  if (useHeightLevel)
    {
      std::vector<double> phlev(nplev);
      height_to_pressure(phlev.data(), plev.data(), nplev);

      if (Options::cdoVerbose)
        for (int i = 0; i < nplev; ++i) cdo_print("level=%d  height=%g  pressure=%g", i + 1, plev[i], phlev[i]);

      plev = phlev;
    }

  if (useLogType)
    for (int k = 0; k < nplev; k++) plev[k] = std::log(plev[k]);

  // code tables are used for identification if any variable has a real table number
  bool useTable = false;
  for (int varID = 0; varID < nvars; varID++)
    {
      const auto tableNum = tableInqNum(vlistInqVarTable(vlistID1, varID));
      if (tableNum > 0 && tableNum != 255)
        {
          useTable = true;
          break;
        }
    }

  if (Options::cdoVerbose && useTable) cdo_print("Using code tables!");

  // identify the special variables and allocate the per-variable buffers
  for (int varID = 0; varID < nvars; varID++)
    {
      const auto gridID = varList1[varID].gridID;
      const auto zaxisID = varList1[varID].zaxisID;
      const auto zaxistype = zaxisInqType(zaxisID);
      const auto nlevels = varList1[varID].nlevels;
      const auto instNum = institutInqCenter(vlistInqVarInstitut(vlistID1, varID));
      const auto tableNum = tableInqNum(vlistInqVarTable(vlistID1, varID));

      auto code = varList1[varID].code;

      const auto param = varList1[varID].param;
      cdiParamToString(param, paramstr, sizeof(paramstr));
      int pnum, pcat, pdis;
      cdiDecodeParam(param, &pnum, &pcat, &pdis);
      if (pdis >= 0 && pdis < 255) code = -1;

      if (useTable)
        {
          if (tableNum == 2)
            {
              mode = ModelMode::WMO;
              wmo_gribcodes(&gribcodes);
            }
          else if (tableNum == 128 || tableNum == 0)
            {
              mode = ModelMode::ECHAM;
              echam_gribcodes(&gribcodes);
            }
          //  KNMI: HIRLAM model version 7.2 uses tableNum=1    (LAMH_D11*)
          //  KNMI: HARMONIE model version 36 uses tableNum=1   (grib*) (operational NWP version)
          //  KNMI: HARMONIE model version 38 uses tableNum=253 (grib,grib_md) and tableNum=1 (grib_sfx) (research version)
          else if (tableNum == 1 || tableNum == 253)
            {
              mode = ModelMode::HIRLAM;
              hirlam_harmonie_gribcodes(&gribcodes);
            }
        }
      else
        {
          mode = ModelMode::ECHAM;
          echam_gribcodes(&gribcodes);
        }

      if (Options::cdoVerbose)
        {
          // explicit cast: ModelMode may be a scoped enum, which must not be passed to %d unconverted
          cdo_print("Mode=%d  Center=%d  TableNum=%d  Code=%d  Param=%s  Varname=%s  varID=%d", static_cast<int>(mode), instNum,
                    tableNum, code, paramstr, varList1[varID].name, varID);
        }

      // no usable code: fall back to standard name, then to well-known variable names
      if (code <= 0 || code == 255)
        {
          char varname[CDI_MAX_NAME];
          vlistInqVarName(vlistID1, varID, varname);
          cstr_to_lower_case(varname);

          char stdname[CDI_MAX_NAME];
          int length = CDI_MAX_NAME;
          cdiInqKeyString(vlistID1, varID, CDI_KEY_STDNAME, stdname, &length);
          cstr_to_lower_case(stdname);

          code = echamcode_from_stdname(stdname);
          if (code == -1)
            {
              //                                       ECHAM                         ECMWF
              // clang-format off
              if      (sgeopotID == -1 && (cdo_cmpstr(varname, "geosp") || cdo_cmpstr(varname, "z"))) code = gribcodes.geopot;
              else if (tempID == -1    && (cdo_cmpstr(varname, "st")    || cdo_cmpstr(varname, "t"))) code = gribcodes.temp;
              else if (psID == -1      && (cdo_cmpstr(varname, "aps")   || cdo_cmpstr(varname, "sp"))) code = gribcodes.ps;
              else if (lnpsID == -1    && (cdo_cmpstr(varname, "lsp")   || cdo_cmpstr(varname, "lnsp"))) code = gribcodes.lsp;
              else if (geopotID == -1  &&  cdo_cmpstr(stdname, "geopotential_full")) code = gribcodes.geopot;
              // else if (cdo_cmpstr(varname, "geopoth")) code = 156;
              // clang-format on
            }
        }

      // the level count separates surface (1 level) from full-level variables with the same code
      if (mode == ModelMode::ECHAM)
        {
          // clang-format off
          if      (code == gribcodes.geopot  && nlevels == 1) sgeopotID = varID;
          else if (code == gribcodes.geopot  && nlevels == nhlevf) geopotID = varID;
          else if (code == gribcodes.temp    && nlevels == nhlevf) tempID = varID;
          else if (code == gribcodes.ps      && nlevels == 1) psID = varID;
          else if (code == gribcodes.lsp     && nlevels == 1) lnpsID = varID;
          else if (code == gribcodes.gheight && nlevels == nhlevf) gheightID = varID;
          // clang-format on
        }
      else if (mode == ModelMode::WMO || mode == ModelMode::HIRLAM)
        {
          // clang-format off
          if      (code == gribcodes.geopot && nlevels == 1) sgeopotID = varID;
          else if (code == gribcodes.geopot && nlevels == nhlevf) geopotID = varID;
          else if (code == gribcodes.temp   && nlevels == nhlevf) tempID = varID;
          else if (code == gribcodes.ps     && nlevels == 1) psID = varID;
          // clang-format on
        }

      if (gridInqType(gridID) == GRID_SPECTRAL && zaxis_is_hybrid(zaxistype))
        cdo_abort("Spectral data on model level unsupported!");

      if (gridInqType(gridID) == GRID_SPECTRAL) cdo_abort("Spectral data unsupported!");

      // geopotential height gets one extra level; the surface value is stored there each timestep
      if (varID == gheightID) varList1[varID].nlevels = nlevels + 1;
      vardata1[varID].init(varList1[varID]);
      if (varID == gheightID) varList1[varID].nlevels = nlevels;

      // NOTE(review): a hybrid axis with nhlevf/nhlevh levels but a different VCT is flagged here,
      // although change_hybrid_zaxis() leaves its output axis unchanged, so vardata2 is then sized
      // by the original level count rather than nplev. Confirm such input cannot occur.
      // varinterp[varID] = ( zaxis_is_hybrid(zaxistype) && zaxisIDh != -1 && nlevels == nhlev );
      varinterp[varID]
          = (zaxisID == zaxisIDh || (zaxis_is_hybrid(zaxistype) && zaxisIDh != -1 && (nlevels == nhlevh || nlevels == nhlevf)));

      if (varinterp[varID])
        {
          varnmiss[varID].resize(maxlev, 0);
          vardata2[varID].init(varList2[varID]);
        }
      else
        {
          varnmiss[varID].resize(nlevels);
          if (zaxis_is_hybrid(zaxistype) && zaxisIDh != -1 && nlevels > 1)
            cdo_warning("Parameter %d has wrong number of levels, skipped! (param=%s nlevel=%d)", varID + 1, varList1[varID].name,
                        nlevels);
        }
    }

  if (Options::cdoVerbose)
    {
      cdo_print("Found:");
      // clang-format off
      if (-1 != tempID)    cdo_print("  %s -> %s", var_stdname(air_temperature), varList1[tempID].name);
      if (-1 != psID)      cdo_print("  %s -> %s", var_stdname(surface_air_pressure), varList1[psID].name);
      if (-1 != lnpsID)    cdo_print("  LOG(%s) -> %s", var_stdname(surface_air_pressure), varList1[lnpsID].name);
      if (-1 != sgeopotID) cdo_print("  %s -> %s", var_stdname(surface_geopotential), varList1[sgeopotID].name);
      if (-1 != geopotID)  cdo_print("  %s -> %s", var_stdname(geopotential), varList1[geopotID].name);
      if (-1 != gheightID) cdo_print("  %s -> %s", var_stdname(geopotential_height), varList1[gheightID].name);
      // clang-format on
    }

  if (tempID != -1 || gheightID != -1) sgeopot_needed = true;

  if (zaxisIDh != -1 && gheightID != -1 && tempID == -1)
    cdo_abort("%s not found, needed for vertical interpolation of %s!", var_stdname(air_temperature),
              var_stdname(geopotential_height));

  // surface pressure source: ps variable referenced by the hybrid axis, else log(ps), else ps
  auto presID = lnpsID;
  if (psvarID != -1) presID = psvarID;

  if (zaxisIDh != -1 && presID == -1)
    {
      if (psID == -1)
        cdo_abort("%s not found!", var_stdname(surface_air_pressure));
      else
        presID = psID;
    }

  // presID is still -1 here when there is no hybrid axis and no (log) surface pressure variable
  if (Options::cdoVerbose && presID != -1)
    {
      if (presID == lnpsID)
        cdo_print("using LOG(%s) from %s", var_stdname(surface_air_pressure), varList1[presID].name);
      else
        cdo_print("using %s from %s", var_stdname(surface_air_pressure), varList1[presID].name);
    }

  // psProg is only read inside the zaxisIDh != -1 branch of the time loop, where presID is valid;
  // guarding avoids indexing varList1 with -1 when no surface pressure exists.
  Field psProg;
  if (presID != -1) psProg.init(varList1[presID]);

  Field sgeopot;
  if (zaxisIDh != -1 && sgeopot_needed)
    {
      sgeopot.init(varList1[presID]);
      if (sgeopotID == -1)
        {
          if (extrapolate)
            {
              if (geopotID == -1)
                cdo_warning("%s not found - set to zero!", var_stdname(surface_geopotential));
              else
                cdo_print("%s not found - using bottom layer of %s!", var_stdname(surface_geopotential), var_stdname(geopotential));
            }
          field_fill(sgeopot, 0.0);
        }
    }

  // check VCT
  if (zaxisIDh != -1)
    {
      double suma = 0.0, sumb = 0.0;
      for (int i = 0; i < nhlevh; ++i) suma += vct[i];
      for (int i = 0; i < nhlevh; ++i) sumb += vct[i + nhlevh];
      if (!(suma > 0.0 || sumb > 0.0)) cdo_warning("VCT is empty!");
    }

  const auto streamID2 = cdo_open_write(1);
  cdo_def_vlist(streamID2, vlistID2);

  int tsID = 0;
  while (true)
    {
      const auto nrecs = cdo_stream_inq_timestep(streamID1, tsID);
      if (nrecs == 0) break;

      for (int varID = 0; varID < nvars; ++varID) vars[varID] = false;

      taxisCopyTimestep(taxisID2, taxisID1);
      cdo_def_timestep(streamID2, tsID);

      for (int recID = 0; recID < nrecs; recID++)
        {
          int varID, levelID;
          cdo_inq_record(streamID1, &varID, &levelID);

          // keep levels consistent with the flipped VCT
          const auto zaxisID = varList1[varID].zaxisID;
          const auto nlevels = varList1[varID].nlevels;
          if (linvertvct && zaxisIDh != -1 && zaxisID == zaxisIDh) levelID = nlevels - 1 - levelID;

          cdo_read_record(streamID1, vardata1[varID], levelID, &varnmiss[varID][levelID]);

          vars[varID] = true;
        }

      if (zaxisIDh != -1)
        {
          if (sgeopot_needed)
            {
              if (sgeopotID != -1)
                field_copy(vardata1[sgeopotID], sgeopot);
              else if (geopotID != -1)
                field_copy(vardata1[geopotID], nhlevf - 1, sgeopot);

              // check range of surface geopot
              if (extrapolate && (sgeopotID != -1 || geopotID != -1))
                {
                  const auto mm = field_min_max(sgeopot);
                  if (mm.min < MIN_FIS || mm.max > MAX_FIS)
                    cdo_warning("Surface geopotential out of range (min=%g max=%g) [timestep:%d]!", mm.min, mm.max, tsID + 1);
                  if (gridsize > 1 && mm.min >= 0.0 && mm.max <= 9000.0 && IS_NOT_EQUAL(mm.min, mm.max))
                    cdo_warning("Surface geopotential has an unexpected range (min=%g max=%g) [timestep:%d]!", mm.min, mm.max,
                                tsID + 1);
                }
            }

          if (presID == lnpsID)
            field_transform(vardata1[lnpsID], psProg, op_exp);
          else if (presID != -1)
            field_copy(vardata1[presID], psProg);

          // check range of psProg
          const auto mm = field_min_max(psProg);
          if (mm.min < MIN_PS || mm.max > MAX_PS)
            cdo_warning("Surface pressure out of range (min=%g max=%g) [timestep:%d]!", mm.min, mm.max, tsID + 1);

          if (memtype == MemType::Float)
            vct_to_hybrid_pressure(fullPress.vec_f.data(), halfPress.vec_f.data(), vct.data(), psProg.vec_f.data(), nhlevf,
                                   gridsize);
          else
            vct_to_hybrid_pressure(fullPress.vec_d.data(), halfPress.vec_d.data(), vct.data(), psProg.vec_d.data(), nhlevf,
                                   gridsize);

          // plev was converted to log(p) above, so the source pressures must match
          if (useLogType)
            {
              field_transform(psProg, psProg, op_log);
              field_transform(halfPress, halfPress, op_log);
              field_transform(fullPress, fullPress, op_log);
            }

          gen_vert_index(vertIndex, plev, fullPress, gridsize);

          if (!extrapolate) gen_vert_index_mv(vertIndex, plev, gridsize, psProg, pnmiss);
        }

      for (int varID = 0; varID < nvars; varID++)
        {
          if (vars[varID])
            {
              if (tsID > 0 && varList1[varID].timetype == TIME_CONSTANT) continue;

              if (varinterp[varID])
                {
                  const auto nlevels = varList1[varID].nlevels;
                  if (nlevels != nhlevf && nlevels != nhlevh)
                    cdo_abort("Number of hybrid level differ from full/half level (param=%s)!", varList1[varID].name);

                  for (int levelID = 0; levelID < nlevels; levelID++)
                    {
                      if (varnmiss[varID][levelID]) cdo_abort("Missing values unsupported for this operator!");
                    }

                  if (varID == tempID)
                    {
                      if (nlevels == nhlevh) cdo_abort("Temperature on half level unsupported!");

                      if (useLogType && extrapolate) cdo_abort("Log. extrapolation of temperature unsupported!");

                      vertical_interp_T(nlevels, fullPress, halfPress, vardata1[varID], vardata2[varID], sgeopot, vertIndex, plev,
                                        gridsize);
                    }
                  else if (varID == gheightID)
                    {
                      // surface height (sgeopot / PlanetGrav) goes into the extra level reserved at init time
                      if (memtype == MemType::Float)
                        for (size_t i = 0; i < gridsize; ++i)
                          vardata1[varID].vec_f[gridsize * nlevels + i] = sgeopot.vec_f[i] / PlanetGrav;
                      else
                        for (size_t i = 0; i < gridsize; ++i)
                          vardata1[varID].vec_d[gridsize * nlevels + i] = sgeopot.vec_d[i] / PlanetGrav;

                      vertical_interp_Z(nlevels, fullPress, halfPress, vardata1[varID], vardata2[varID], vardata1[tempID], sgeopot,
                                        vertIndex, plev, gridsize);
                    }
                  else
                    {
                      vertical_interp_X(nlevels, fullPress, halfPress, vardata1[varID], vardata2[varID], vertIndex, plev, gridsize);
                    }

                  if (!extrapolate) varray_copy(nplev, pnmiss, varnmiss[varID]);
                }

              for (int levelID = 0; levelID < varList2[varID].nlevels; levelID++)
                {
                  cdo_def_record(streamID2, varID, levelID);
                  cdo_write_record(streamID2, varinterp[varID] ? vardata2[varID] : vardata1[varID], levelID,
                                   varnmiss[varID][levelID]);
                }
            }
        }

      tsID++;
    }

  cdo_stream_close(streamID2);
  cdo_stream_close(streamID1);

  cdo_finish();

  return nullptr;
}
