//===--- TosaProfileCompliance.cpp - Tosa Profile Compliance Validation ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaProfileCompliance.h"
#include "llvm/ADT/StringExtras.h"

using namespace mlir;
using namespace mlir::tosa;

// Constructs the compliance checker.
// Each TypeInfo local below pairs a builtin type's TypeID with its bit width
// (e.g. i1 -> 1, bf16 -> 16, f8E5M2 -> 8).
// NOTE(review): these locals appear to exist so that the generated table
// included at the end of the body can refer to them by name. That table
// presumably populates profileComplianceMap and extensionComplianceMap —
// confirm against TosaComplianceData.h.inc.
TosaProfileCompliance::TosaProfileCompliance() {
  const TypeInfo boolT = {mlir::IntegerType::getTypeID(), 1};
  const TypeInfo i4T = {mlir::IntegerType::getTypeID(), 4};
  const TypeInfo i8T = {mlir::IntegerType::getTypeID(), 8};
  const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
  const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
  const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
  const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
  const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
  const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
  const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
  const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};

// The profile-based compliance content below is auto-generated by a script
// in https://git.mlplatform.org/tosa/specification.git
#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
  // End of auto-generated metadata
}

// Accessor for the profile compliance table.
// NOTE(review): the return type appears to be the map by value, so each call
// copies the whole table. Returning a const reference would need a matching
// header change.
template <>
OperationProfileComplianceMap TosaProfileCompliance::getProfileComplianceMap() {
  return this->profileComplianceMap;
}

// Accessor for the extension compliance table (same copy caveat as above).
template <>
OperationExtensionComplianceMap
TosaProfileCompliance::getProfileComplianceMap() {
  return this->extensionComplianceMap;
}

// Generic population: records the type of each operand in order, then the
// type of `output`. Always succeeds.
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
                                                    Value output) {
  for (size_t idx = 0, count = operands.size(); idx < count; ++idx)
    addValue(operands[idx]);
  addValue(output);
  return success();
}

// Records the type of the first concat input followed by the output type.
// Only the first input is recorded; presumably the op verifier guarantees all
// inputs share its element type (NOTE(review): confirm).
// Returns failure() instead of dereferencing an empty variadic operand list.
// This mirrors the failure convention used for VariableOp.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
  auto inputs = op.getInput1();
  if (inputs.empty())
    return failure();
  addValue(inputs.front());
  addValue(op.getOutput());
  return success();
}

// Records, in this exact order, the types of:
//   input, input_zp, output_zp, acc_type, output.
// The order matters because findMatchedProfile compares entries position by
// position against the generated signatures.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
  Value leading[] = {op.getInput(), op.getInputZp(), op.getOutputZp()};
  for (Value operand : leading)
    addValue(operand);
  addType(op.getAccType());
  addValue(op.getOutput());
  return success();
}

// Shared population for the convolution family (Conv2D, Conv3D,
// TransposeConv2D, DepthwiseConv2D). It records, in this order:
//   input, weight, bias, input_zp, weight_zp, acc_type, output.
template <typename T>
LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
  Value operandsInOrder[] = {op.getInput(), op.getWeight(), op.getBias(),
                             op.getInputZp(), op.getWeightZp()};
  for (Value operand : operandsInOrder)
    addValue(operand);
  addType(op.getAccType());
  addValue(op.getOutput());
  return success();
}

// The convolution ops share one operand layout, so each forwards to the
// common helper with its concrete op type spelled out.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv2DOp op) {
  return populateProfileInfoConv<tosa::Conv2DOp>(op);
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::Conv3DOp op) {
  return populateProfileInfoConv<tosa::Conv3DOp>(op);
}

template <>
LogicalResult
ProfileInfoDepot::populateProfileInfo(tosa::TransposeConv2DOp op) {
  return populateProfileInfoConv<tosa::TransposeConv2DOp>(op);
}

template <>
LogicalResult
ProfileInfoDepot::populateProfileInfo(tosa::DepthwiseConv2DOp op) {
  return populateProfileInfoConv<tosa::DepthwiseConv2DOp>(op);
}

// Records input1, pad_const and output types, in that order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::PadOp op) {
  Value recorded[] = {op.getInput1(), op.getPadConst(), op.getOutput()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// Shared population for the data-layout ops (Reshape, Slice, Tile,
// Transpose). Only input1 and output are recorded; other operands are ignored.
template <typename T>
LogicalResult ProfileInfoDepot::populateProfileInfoDataLayout(T op) {
  Value recorded[] = {op.getInput1(), op.getOutput()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// Data-layout ops forward to the shared helper with the concrete op type
// spelled out explicitly.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ReshapeOp op) {
  return populateProfileInfoDataLayout<tosa::ReshapeOp>(op);
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SliceOp op) {
  return populateProfileInfoDataLayout<tosa::SliceOp>(op);
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TileOp op) {
  return populateProfileInfoDataLayout<tosa::TileOp>(op);
}

template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::TransposeOp op) {
  return populateProfileInfoDataLayout<tosa::TransposeOp>(op);
}

// Records the values operand and the output, in that order; other operands
// are not recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::GatherOp op) {
  Value recorded[] = {op.getValues(), op.getOutput()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// Records values_in, input and values_out, in that order; other operands are
// not recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ScatterOp op) {
  Value recorded[] = {op.getValuesIn(), op.getInput(), op.getValuesOut()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// Records input1, input2 and output, in that order; any other operand is not
// recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MulOp op) {
  Value recorded[] = {op.getInput1(), op.getInput2(), op.getOutput()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// Records only the input and output types; other operands are not recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
  Value recorded[] = {op.getInput(), op.getOutput()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// Records the real/imaginary inputs, then the real/imaginary outputs.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
  Value recorded[] = {op.getInputReal(), op.getInputImag(),
                      op.getOutputReal(), op.getOutputImag()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// Records the real input, then the real and imaginary outputs.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
  Value recorded[] = {op.getInputReal(), op.getOutputReal(),
                      op.getOutputImag()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// Records the two data inputs (input2, input3) and the output.
// input1 is deliberately not recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
  Value recorded[] = {op.getInput2(), op.getInput3(), op.getOutput()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// Records input, input_zp, output_zp and output, in that order; other
// operands are not recorded.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
  Value recorded[] = {op.getInput(), op.getInputZp(), op.getOutputZp(),
                      op.getOutput()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// Records A, B, their zero points, and the output, in that order.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
  Value recorded[] = {op.getA(), op.getB(), op.getAZp(), op.getBZp(),
                      op.getOutput()};
  for (Value v : recorded)
    addValue(v);
  return success();
}

// A variable is described by the element type of its initial-value attribute.
// Fails when the attribute is absent or is not a TypedAttr.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) {
  ::mlir::Attribute initial = op.getInitialValueAttr();
  if (!initial)
    return failure();

  auto typed = dyn_cast<TypedAttr>(initial);
  if (!typed)
    return failure();

  addType(getElementTypeOrSelf(typed));
  return success();
}

// Only the condition operand's type is recorded for tosa.cond_if.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::IfOp op) {
  auto predicate = op.getCondition();
  addValue(predicate);
  return success();
}

// Records the type of the first operand of the terminator of the condition
// region's entry block, i.e. the value the condition graph yields.
// Returns failure() when the region has no block or the terminator has no
// operand, instead of dereferencing an empty range. This matches the failure
// convention used for VariableOp.
template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::WhileOp op) {
  auto &condRegion = op.getCondGraph();
  if (condRegion.empty())
    return failure();

  Operation *terminator = condRegion.front().getTerminator();
  if (terminator->getNumOperands() == 0)
    return failure();

  addValue(terminator->getOperand(0));
  return success();
}

// Populates the depot with the types of the profile/extension-relevant values
// of `op`, selecting a population rule by operation kind.
// Returns failure() when `op` matches none of the registered rules, or when
// the selected rule fails.
LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
  // Records the type of every operand, followed by the type of the first
  // result. An operation that defines no result contributes only its
  // operands, rather than indexing op->getResult(0) out of range.
  // NOTE(review): VariableWrite is dispatched through this path; confirm
  // against its ODS definition whether it defines a result.
  auto populateAllOperands = [&]() -> LogicalResult {
    if (op->getNumResults() == 0) {
      for (auto operand : op->getOperands())
        addValue(operand);
      return success();
    }
    return populateProfileInfo(op->getOperands(), op->getResult(0));
  };

// Dispatches to the op-specific populateProfileInfo specialization, which
// records only a customised subset of the operands.
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)                                   \
  if (auto concreteOp = dyn_cast<tosa::tosaOp##Op>(op))                        \
    return populateProfileInfo(concreteOp);

// Accepts the operation without recording any type information.
#define POPULATE_PROFILE_INFO_SKIP(tosaOp)                                     \
  if (isa<tosa::tosaOp##Op>(op))                                               \
    return success();

// Records the types of all operands and of the first result, if any.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp)                                   \
  if (isa<tosa::tosaOp##Op>(op))                                               \
    return populateAllOperands();

  // These ops' specializations skip operands that are independent of, and not
  // tied to, any specific profile/extension.
  POPULATE_PROFILE_INFO_CUSTOM(AvgPool2d)
  POPULATE_PROFILE_INFO_CUSTOM(TransposeConv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Conv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
  POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
  POPULATE_PROFILE_INFO_CUSTOM(Mul)
  POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
  POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
  POPULATE_PROFILE_INFO_CUSTOM(Concat)
  POPULATE_PROFILE_INFO_CUSTOM(Pad)
  POPULATE_PROFILE_INFO_CUSTOM(Reshape)
  POPULATE_PROFILE_INFO_CUSTOM(Slice)
  POPULATE_PROFILE_INFO_CUSTOM(Tile)
  POPULATE_PROFILE_INFO_CUSTOM(Transpose)
  POPULATE_PROFILE_INFO_CUSTOM(Gather)
  POPULATE_PROFILE_INFO_CUSTOM(Scatter)
  POPULATE_PROFILE_INFO_CUSTOM(Resize)
  POPULATE_PROFILE_INFO_CUSTOM(Select)
  POPULATE_PROFILE_INFO_CUSTOM(Rescale)
  POPULATE_PROFILE_INFO_CUSTOM(MatMul)
  POPULATE_PROFILE_INFO_CUSTOM(Variable)
  POPULATE_PROFILE_INFO_CUSTOM(If)
  POPULATE_PROFILE_INFO_CUSTOM(While)

  // For most TOSA operators, all operands are profile/extension related and
  // hence are all considered in this profile-based compliance check.
  POPULATE_PROFILE_INFO_COMMON(Cast)
  POPULATE_PROFILE_INFO_COMMON(Const)
  POPULATE_PROFILE_INFO_COMMON(ArgMax)
  POPULATE_PROFILE_INFO_COMMON(Sub)
  POPULATE_PROFILE_INFO_COMMON(Maximum)
  POPULATE_PROFILE_INFO_COMMON(Minimum)
  POPULATE_PROFILE_INFO_COMMON(MaxPool2d)
  POPULATE_PROFILE_INFO_COMMON(Clamp)
  POPULATE_PROFILE_INFO_COMMON(Erf)
  POPULATE_PROFILE_INFO_COMMON(Sigmoid)
  POPULATE_PROFILE_INFO_COMMON(Tanh)
  POPULATE_PROFILE_INFO_COMMON(Add)
  POPULATE_PROFILE_INFO_COMMON(ArithmeticRightShift)
  POPULATE_PROFILE_INFO_COMMON(BitwiseAnd)
  POPULATE_PROFILE_INFO_COMMON(BitwiseNot)
  POPULATE_PROFILE_INFO_COMMON(BitwiseOr)
  POPULATE_PROFILE_INFO_COMMON(BitwiseXor)
  POPULATE_PROFILE_INFO_COMMON(LogicalLeftShift)
  POPULATE_PROFILE_INFO_COMMON(LogicalRightShift)
  POPULATE_PROFILE_INFO_COMMON(LogicalAnd)
  POPULATE_PROFILE_INFO_COMMON(LogicalNot)
  POPULATE_PROFILE_INFO_COMMON(LogicalOr)
  POPULATE_PROFILE_INFO_COMMON(LogicalXor)
  POPULATE_PROFILE_INFO_COMMON(IntDiv)
  POPULATE_PROFILE_INFO_COMMON(Pow)
  POPULATE_PROFILE_INFO_COMMON(Table)
  POPULATE_PROFILE_INFO_COMMON(Abs)
  POPULATE_PROFILE_INFO_COMMON(Ceil)
  POPULATE_PROFILE_INFO_COMMON(Clz)
  POPULATE_PROFILE_INFO_COMMON(Sin)
  POPULATE_PROFILE_INFO_COMMON(Cos)
  POPULATE_PROFILE_INFO_COMMON(Exp)
  POPULATE_PROFILE_INFO_COMMON(Floor)
  POPULATE_PROFILE_INFO_COMMON(Log)
  POPULATE_PROFILE_INFO_COMMON(Negate)
  POPULATE_PROFILE_INFO_COMMON(Reciprocal)
  POPULATE_PROFILE_INFO_COMMON(Rsqrt)
  POPULATE_PROFILE_INFO_COMMON(ReduceAll)
  POPULATE_PROFILE_INFO_COMMON(ReduceAny)
  POPULATE_PROFILE_INFO_COMMON(ReduceMax)
  POPULATE_PROFILE_INFO_COMMON(ReduceMin)
  POPULATE_PROFILE_INFO_COMMON(ReduceProduct)
  POPULATE_PROFILE_INFO_COMMON(ReduceSum)
  POPULATE_PROFILE_INFO_COMMON(Equal)
  POPULATE_PROFILE_INFO_COMMON(GreaterEqual)
  POPULATE_PROFILE_INFO_COMMON(Greater)
  POPULATE_PROFILE_INFO_COMMON(Reverse)
  POPULATE_PROFILE_INFO_COMMON(Identity)
  POPULATE_PROFILE_INFO_COMMON(VariableRead)
  POPULATE_PROFILE_INFO_COMMON(VariableWrite)

  // Type Invariant Extension, a capability extension that is independent
  // of the data type, meaning any compatible type can be used. No type
  // constraint for those operations.
  POPULATE_PROFILE_INFO_SKIP(ConstShape)
  POPULATE_PROFILE_INFO_SKIP(Yield)

  // Keep the helper macros local to this function.
#undef POPULATE_PROFILE_INFO_CUSTOM
#undef POPULATE_PROFILE_INFO_SKIP
#undef POPULATE_PROFILE_INFO_COMMON

  // No population rule is registered for this operation.
  return failure();
}

//===----------------------------------------------------------------------===//
// Tosa Profile And Extension Compliance Checker
//===----------------------------------------------------------------------===//

// Verifies that the modes (profiles when T = Profile, extensions when
// T = Extension) needed by `op` are enabled in `targetEnv`. It also checks
// that they agree with `specRequiredModeSet`, the requirement groups the
// specification declares for the op. It emits an op error and returns
// failure() on the first violation found.
template <typename T>
LogicalResult TosaProfileCompliance::checkProfileOrExtension(
    Operation *op, const tosa::TargetEnv &targetEnv,
    const SmallVector<ArrayRef<T>> &specRequiredModeSet) {
  // The specification lists no requirement for this op: nothing to check.
  if (specRequiredModeSet.empty())
    return success();

  const std::string opKey = op->getName().getStringRef().str();
  auto complianceMap = getProfileComplianceMap<T>();
  auto entry = complianceMap.find(opKey);

  if (entry == complianceMap.end()) {
    // Operators such as control-flow and shape ops carry no operand-type
    // restriction, so there is no metadata to match against. Only confirm
    // that the target enables something the specification requires.
    size_t candidateCount = 0;
    for (ArrayRef<T> candidates : specRequiredModeSet) {
      if (targetEnv.allowsAnyOf(candidates))
        return success();
      candidateCount += candidates.size();
    }

    op->emitOpError() << "illegal: requires"
                      << (candidateCount > 1 ? " any of " : " ") << "["
                      << llvm::join(stringifyProfile<T>(specRequiredModeSet),
                                    ", ")
                      << "] but not enabled in target\n";
    return failure();
  }

  // Infer the modes this op instance needs from the type signature of its
  // profile-relevant operands.
  CheckCondition condition = CheckCondition::invalid;
  SmallVector<T> required =
      findMatchedProfile<T>(op, entry->second, condition);

  // No signature matched, so there is no further restriction to enforce.
  if (required.empty())
    return success();

  if (condition == CheckCondition::allOf) {
    if (!targetEnv.allowsAllOf(required)) {
      op->emitOpError() << "illegal: requires"
                        << (required.size() > 1 ? " all of " : " ") << "["
                        << llvm::join(stringifyProfile<T>(required), ", ")
                        << "] but not enabled in target\n";
      return failure();
    }
  } else if (condition == CheckCondition::anyOf) {
    if (!targetEnv.allowsAnyOf(required)) {
      op->emitOpError() << "illegal: requires"
                        << (required.size() > 1 ? " any of " : " ") << "["
                        << llvm::join(stringifyProfile<T>(required), ", ")
                        << "] but not enabled in target\n";
      return failure();
    }
  }

  // Each extension works together with a list of cooperating profiles
  // (usually ones with the same data type); at least one must be enabled.
  if constexpr (std::is_same_v<T, Extension>) {
    for (const T &extension : required) {
      SmallVector<Profile> partners = getCooperativeProfiles(extension);
      if (targetEnv.allowsAnyOf(partners))
        continue;
      op->emitOpError() << "illegal: requires ["
                        << llvm::join(stringifyProfile<Profile>(partners),
                                      ", ")
                        << "] to work with but not enabled in target\n";
      return failure();
    }
  }

  // Cross-check the inference: every inferred mode must appear in each
  // requirement group that the specification declares.
  for (ArrayRef<T> candidates : specRequiredModeSet) {
    for (const T &mode : required) {
      if (std::find(candidates.begin(), candidates.end(), mode) !=
          candidates.end())
        continue;
      op->emitOpError() << "illegal: requires ["
                        << llvm::join(stringifyProfile<T>(required), ", ")
                        << "] but not included in the profile compliance ["
                        << llvm::join(
                               stringifyProfile<T>(specRequiredModeSet), ", ")
                        << "]\n";
      return failure();
    }
  }

  return success();
}

// Profile check. Ops that do not implement QueryProfileInterface pass
// trivially; the rest are checked against the profiles the interface reports.
LogicalResult
TosaProfileCompliance::checkProfile(Operation *op,
                                    const tosa::TargetEnv &targetEnv) {
  auto profileQuery = dyn_cast<tosa::QueryProfileInterface>(op);
  if (!profileQuery)
    return success();
  return checkProfileOrExtension<Profile>(op, targetEnv,
                                          profileQuery.getProfiles());
}

// Extension check. Ops that do not implement QueryExtensionInterface pass
// trivially; the rest are checked against the extensions the interface
// reports.
LogicalResult
TosaProfileCompliance::checkExtension(Operation *op,
                                      const tosa::TargetEnv &targetEnv) {
  auto extensionQuery = dyn_cast<tosa::QueryExtensionInterface>(op);
  if (!extensionQuery)
    return success();
  return checkProfileOrExtension<Extension>(op, targetEnv,
                                            extensionQuery.getExtensions());
}

// Find the profiles or extensions required by `op`.
// The types recorded by a ProfileInfoDepot for `op` are compared, position by
// position, against every operand type signature listed in `compInfo`. The
// modes of the first matching entry are returned, and `condition` is set to
// that entry's check condition.
// An empty result means the depot recorded no types or no signature matched.
// In that case `condition` is left untouched.
template <typename T>
SmallVector<T> TosaProfileCompliance::findMatchedProfile(
    Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
    CheckCondition &condition) {
  assert(!compInfo.empty() &&
         "profile-based compliance information is empty");

  // Populate the type of profile/extension relevant operands.
  ProfileInfoDepot depot(op);
  SmallVector<TypeInfo> present = depot.getInfo();
  if (present.empty())
    return {};

  // Iterate by const reference to avoid copying the nested signature vectors
  // for every candidate entry.
  for (const auto &info : compInfo) {
    for (const auto &expected : info.operandTypeInfoSet) {
      assert(present.size() == expected.size() &&
             "the entries for profile-based compliance do not match between "
             "the generated metadata and the type definition retrieved from "
             "the operation");
      // In builds without assertions, a length mismatch would index past the
      // end of one vector; treat it as "no match" instead.
      if (present.size() != expected.size())
        continue;

      // Compare the type signature of the operation against this entry.
      bool matched = true;
      for (size_t j = 0, e = expected.size(); matched && j < e; ++j)
        matched = isSameTypeInfo(present[j], expected[j]);

      if (matched) {
        condition = info.condition;
        return info.mode;
      }
    }
  }

  return {};
}

// Debug utilities: translate profile/extension enumerators into their
// printable names, preserving input order.
template <typename T>
SmallVector<StringRef>
TosaProfileCompliance::stringifyProfile(ArrayRef<T> profiles) {
  auto toName = [](T mode) -> StringRef {
    if constexpr (std::is_same_v<T, Profile>)
      return tosa::stringifyProfile(mode);
    else
      return tosa::stringifyExtension(mode);
  };

  SmallVector<StringRef> names;
  names.reserve(profiles.size());
  for (T mode : profiles)
    names.push_back(toName(mode));
  return names;
}

// Flattens a set of mode groups into a single list of printable names, group
// by group, in order.
template <typename T>
SmallVector<StringRef> TosaProfileCompliance::stringifyProfile(
    const SmallVector<ArrayRef<T>> &profileSet) {
  SmallVector<StringRef> flattened;
  for (ArrayRef<T> group : profileSet) {
    SmallVector<StringRef> groupNames = stringifyProfile<T>(group);
    flattened.append(groupNames.begin(), groupNames.end());
  }
  return flattened;
}
