/*--------------------------------------------------------------------*/
/*    Copyright 2005 Sandia Corporation.                              */
/*    Under the terms of Contract DE-AC04-94AL85000, there is a       */
/*    non-exclusive license for use of this work by or on behalf      */
/*    of the U.S. Government.  Export of this program may require     */
/*    a license from the United States Government.                    */
/*--------------------------------------------------------------------*/

#ifndef _fei_MatrixTraits_Epetra_h_
#define _fei_MatrixTraits_Epetra_h_

//
//IMPORTANT NOTE: Make sure that wherever this file is included from, it
//appears BEFORE any include of fei_Vector_Impl.hpp or fei_Matrix_Impl.hpp !!!
//
#include <snl_fei_MatrixTraits.hpp>
#include <snl_fei_BlockMatrixTraits.hpp>
#include <fei_VectorTraits_Epetra.hpp>
#include <fei_Include_Trilinos.hpp>
#include <fei_Vector_Impl.hpp>

namespace snl_fei {
  /** Declare an Epetra_CrsMatrix specialization of the
      snl_fei::MatrixTraits struct.

      This allows Epetra_CrsMatrix to be used as the template parameter
      of the snl_fei::Matrix class.
  */
  template<>
  struct MatrixTraits<Epetra_CrsMatrix> {
    static const char* typeName()
      { return("Epetra_CrsMatrix"); }

    static int setValues(Epetra_CrsMatrix* mat, double scalar)
      {
        return( mat->PutScalar(scalar) );
      }

    static int getNumLocalRows(Epetra_CrsMatrix* mat, int& numRows)
    {
      numRows = mat->NumMyRows();
      return(0);
    }

    static int getRowLength(Epetra_CrsMatrix* mat, int row, int& length)
      {
	length = mat->NumGlobalEntries(row);
	if (length < 0) return(-1);
        return( 0 );
      }

    static int copyOutRow(Epetra_CrsMatrix* mat,
                      int row, int len, double* coefs, int* indices)
      {
        int dummy;
        return(mat->ExtractGlobalRowCopy(row, len, dummy, coefs, indices));
      }

    static int putValuesIn(Epetra_CrsMatrix* mat,
                     int numRows, const int* rows,
                     int numCols, const int* cols,
                     const double* const* values,
                           bool sum_into)
      {
        if (sum_into) {
          for(int i=0; i<numRows; ++i) {
            int err = mat->SumIntoGlobalValues(rows[i], numCols,
                                               (double*)values[i],
                                               (int*)cols);
            if (err != 0) {
              return(err);
            }
          }
        }
        else {
          for(int i=0; i<numRows; ++i) {
            int err = mat->ReplaceGlobalValues(rows[i], numCols,
                                               (double*)values[i],
                                               (int*)cols);
            if (err != 0) {
              return(err);
            }
          }
        }
        return(0);
      }

    static int globalAssemble(Epetra_CrsMatrix* mat)
    {
      if (!mat->Filled()) {
	int err = mat->FillComplete();
	if (err != 0) {
	  FEI_CERR << "MatrixTraits<Epetra_CrsMatrix>::globalAssemble"
		   << " ERROR in mat->FillComplete" << FEI_ENDL;
	  return(-1);
	}
      }

      if (!mat->StorageOptimized()) {
	mat->OptimizeStorage();
      }

      return( 0 );
    }

    static int matvec(Epetra_CrsMatrix* mat,
		      fei::Vector* x,
		      fei::Vector* y)
    {
      fei::Vector_Impl<Epetra_MultiVector>* evx =
	dynamic_cast<fei::Vector_Impl<Epetra_MultiVector>* >(x);
      fei::Vector_Impl<Epetra_MultiVector>* evy =
	dynamic_cast<fei::Vector_Impl<Epetra_MultiVector>* >(y);

      if (evx == NULL || evy == NULL) {
	return(-1);
      }

      Epetra_MultiVector* ex = evx->getUnderlyingVector();
      Epetra_MultiVector* ey = evy->getUnderlyingVector();

      return( mat->Multiply(false, *ex, *ey) );
    }

  };//struct MatrixTraits<Epetra_CrsMatrix>

  /** Declare an Epetra_VbrMatrix specialization of the
      snl_fei::BlockMatrixTraits struct.

      This allows Epetra_VbrMatrix to be used as the template parameter
      for the snl_fei::Matrix class.
  */
  template<>
  struct BlockMatrixTraits<Epetra_VbrMatrix> {
    static const char* typeName()
      { return("Epetra_VbrMatrix"); }

    static int putScalar(Epetra_VbrMatrix* mat, double scalar)
      {
        return( mat->PutScalar(scalar) );
      }

    static int getRowLength(Epetra_VbrMatrix* mat, int row, int& length)
      {
	length = mat->NumGlobalBlockEntries(row);
        return(0);
      }

    static int getPointRowLength(Epetra_VbrMatrix* mat, int row, int& length)
    {
      const Epetra_Map& map = mat->RowMatrixRowMap();
      int minLocalRow = map.MinMyGID();
      int localRow = row - minLocalRow;
      int error = mat->NumMyRowEntries(localRow, length);
      return(error);
    }

    static int copyOutRow(Epetra_VbrMatrix* mat,
			  int row, int numBlkCols,
			  int rowDim,
			  int* blkCols,
			  int* colDims,
			  double* coefs,
			  int coefsLen,
			  int& blkRowLength)
      {
	int checkRowDim;
	int error = mat->BeginExtractGlobalBlockRowCopy(row, numBlkCols,
							checkRowDim,
							blkRowLength,
							blkCols, colDims);
	if (error != 0 || checkRowDim != rowDim || blkRowLength != numBlkCols) {
	  return(error);
	}

	int offset = 0;
	for(int i=0; i<numBlkCols; ++i) {
	  if (offset >= coefsLen) {
	    cerr << "BlockMatrixTraits::copyOutRow ran off end of coefs array."
		 << endl;
	    return(-2);
	  }
	  int numValues = rowDim*colDims[i];
	  error = mat->ExtractEntryCopy(numValues, &(coefs[offset]),
					rowDim, false);
	  if (error != 0) {
	    return(error);
	  }
	  offset += numValues;
	}

        return(0);
      }

    static int copyOutPointRow(Epetra_VbrMatrix* mat,
			       int firstLocalOffset,
			       int row,
			       int len,
			       double* coefs,
			       int* indices,
			       int& rowLength)
      {
	int error = mat->ExtractMyRowCopy(row-firstLocalOffset,
					  len, rowLength,
					  coefs, indices);

	const Epetra_Map& colmap = mat->RowMatrixColMap();
	for(int i=0; i<len; ++i) {
	  indices[i] = colmap.GID(indices[i]);
	}

        return(error);
      }

    static int sumIn(Epetra_VbrMatrix* mat,
		     int blockRow,
		     int rowDim,
		     int numBlockCols,
		     const int* blockCols,
		     const int* colDims,
		     int LDA,
		     const double* values)
    {
      int err, *nc_cols = const_cast<int*>(blockCols);
      double* nc_values = const_cast<double*>(values);

      err = mat->BeginSumIntoGlobalValues(blockRow, numBlockCols, nc_cols);
      if (err != 0) return(err);

      int voffset = 0;
      for(int j=0; j<numBlockCols; ++j) {
	err = mat->SubmitBlockEntry(&(nc_values[voffset]), LDA,
				    rowDim, colDims[j]);
	if (err != 0) return(err);

	voffset += colDims[j]*LDA;
      }

      err = mat->EndSubmitEntries();
      if (err != 0) return(err);

      return(0);
    }

    static int copyIn(Epetra_VbrMatrix* mat,
		      int blockRow,
		      int rowDim,
		      int numBlockCols,
		      const int* blockCols,
		      const int* colDims,
		      int LDA,
		      const double* values)
    {
      int* nc_cols = const_cast<int*>(blockCols);
      double* nc_values = const_cast<double*>(values);

      int err = mat->BeginReplaceGlobalValues(blockRow, numBlockCols, nc_cols);
      if (err != 0) return(err);

      int voffset = 0;
      for(int j=0; j<numBlockCols; ++j) {
	err = mat->SubmitBlockEntry(&(nc_values[voffset]), LDA,
				    rowDim, colDims[j]);
	if (err != 0) return(err);

	voffset += colDims[j]*LDA;
      }

      err = mat->EndSubmitEntries();
      if (err != 0) return(err);

      return(0);
    }

    static int sumIn(Epetra_VbrMatrix* mat,
		     int row,
		     int rowDim,
                     int numCols,
		     const int* cols,
		     const int* LDAs,
		     const int* colDims,
                     const double* const* values)
      {
	int* nc_cols = const_cast<int*>(cols);
	double** nc_values = const_cast<double**>(values);
	int err = mat->BeginSumIntoGlobalValues(row,numCols,nc_cols);
	if (err != 0) return(err);
	for(int i=0; i<numCols; ++i) {
	  err = mat->SubmitBlockEntry(nc_values[i], LDAs[i], rowDim, colDims[i]);
	  if (err != 0) return(err);
	}
	err = mat->EndSubmitEntries();

	return(err);
      }

    static int copyIn(Epetra_VbrMatrix* mat,
		      int row,
		      int rowDim,
		      int numCols,
		      const int* cols,
		      const int* LDAs,
		      const int* colDims,
                      const double* const* values)
      {
	int* nc_cols = const_cast<int*>(cols);
	double** nc_values = const_cast<double**>(values);
	int err = mat->BeginReplaceGlobalValues(row, numCols, nc_cols);
	if (err != 0) return(err);
	for(int i=0; i<numCols; ++i) {
	  err = mat->SubmitBlockEntry(nc_values[i], LDAs[i], rowDim, colDims[i]);
	  if (err != 0) return(err);
	}
	err = mat->EndSubmitEntries();

	return(err);
      }

    static int globalAssemble(Epetra_VbrMatrix* mat)
    {
      const Epetra_Map& map = mat->RowMatrixRowMap();
      (void)map;
      return( mat->FillComplete() );
    }
  };//struct BlockMatrixTraits<Epetra_VbrMatrix>
}//namespace snl_fei
#endif // _fei_MatrixTraits_Epetra_h_
