#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <math.h>
#include "FragCoords.h"
#include "DLTmath.h"
#include "pdbUtils.h"
#include "pdbStats.h"
#include "CovMat.h"


/* Reference term E0 of the Kabsch residual for two fragments:
 *     E0 = 1/2 * sum_n (|p_n|^2 + |q_n|^2)
 * where p_n and q_n are the n-th positions of coords1 and coords2.
 * coords1->fraglen positions are summed; coords2 must hold at least as many.
 * A non-positive fraglen yields 0.0. */
static double
ComputeE0Frag(const FragCoords *coords1, const FragCoords *coords2)
{
    const double   *px = (const double *) coords1->x;
    const double   *py = (const double *) coords1->y;
    const double   *pz = (const double *) coords1->z;
    const double   *qx = (const double *) coords2->x;
    const double   *qy = (const double *) coords2->y;
    const double   *qz = (const double *) coords2->z;
    const int       len = coords1->fraglen;
    double          acc = 0.0;
    int             n;

    for (n = 0; n < len; ++n)
    {
        acc += (px[n] * px[n] + qx[n] * qx[n])
             + (py[n] * py[n] + qy[n] * qy[n])
             + (pz[n] * pz[n] + qz[n] * qz[n]);
    }

    return 0.5 * acc;
}


/* Reference term E0 of the Kabsch residual for two coordinate sets:
 *     E0 = 1/2 * sum_n (|p_n|^2 + |q_n|^2)
 * where p_n and q_n are the n-th positions of coords1 and coords2.
 * coords1->vlen positions are summed; coords2 must hold at least as many.
 * A non-positive vlen yields 0.0. */
static double
ComputeE0(const Coords *coords1, const Coords *coords2)
{
    const double   *px = (const double *) coords1->x;
    const double   *py = (const double *) coords1->y;
    const double   *pz = (const double *) coords1->z;
    const double   *qx = (const double *) coords2->x;
    const double   *qy = (const double *) coords2->y;
    const double   *qz = (const double *) coords2->z;
    const int       len = coords1->vlen;
    double          acc = 0.0;
    int             n;

    for (n = 0; n < len; ++n)
    {
        acc += (px[n] * px[n] + qx[n] * qx[n])
             + (py[n] * py[n] + qy[n] * qy[n])
             + (pz[n] * pz[n] + qz[n] * qz[n]);
    }

    return 0.5 * acc;
}


/* Correlation matrix of two fragments:
 *     Rmat[r][c] = sum_n q_r(n) * p_c(n)
 * where q = coords2, p = coords1, and components r, c index (x, y, z) as 0..2.
 * coords1->fraglen positions are used from each set.
 *
 * This function assumes that the coordinates have been centered previously!!
 * Use CenMass(coords_ptr) and ApplyCenter(coords_ptr) */
static void
ComputeRFrag(const FragCoords *coords1, const FragCoords *coords2, double **Rmat)
{
    const double   *p[3] = { (const double *) coords1->x,
                             (const double *) coords1->y,
                             (const double *) coords1->z };
    const double   *q[3] = { (const double *) coords2->x,
                             (const double *) coords2->y,
                             (const double *) coords2->z };
    const int       len = coords1->fraglen;
    double          acc[3][3] = { { 0.0 } };
    int             n, r, c;

    /* Each acc[r][c] accumulates its products in position order. */
    for (n = 0; n < len; ++n)
        for (r = 0; r < 3; ++r)
            for (c = 0; c < 3; ++c)
                acc[r][c] += q[r][n] * p[c][n];

    for (r = 0; r < 3; ++r)
        for (c = 0; c < 3; ++c)
            Rmat[r][c] = acc[r][c];
}


/* Correlation matrix of two coordinate sets:
 *     Rmat[r][c] = sum_n q_r(n) * p_c(n)
 * where q = coords2, p = coords1, and components r, c index (x, y, z) as 0..2.
 * coords1->vlen positions are used from each set.
 *
 * This function assumes that the coordinates have been centered previously!!
 * Use CenMass(coords_ptr) and ApplyCenter(coords_ptr) */
static void
ComputeR(const Coords *coords1, const Coords *coords2, double **Rmat)
{
    const double   *p[3] = { (const double *) coords1->x,
                             (const double *) coords1->y,
                             (const double *) coords1->z };
    const double   *q[3] = { (const double *) coords2->x,
                             (const double *) coords2->y,
                             (const double *) coords2->z };
    const int       len = coords1->vlen;
    double          acc[3][3] = { { 0.0 } };
    int             n, r, c;

    /* Each acc[r][c] accumulates its products in position order. */
    for (n = 0; n < len; ++n)
        for (r = 0; r < 3; ++r)
            for (c = 0; c < 3; ++c)
                acc[r][c] += q[r][n] * p[c][n];

    for (r = 0; r < 3; ++r)
        for (c = 0; c < 3; ++c)
            Rmat[r][c] = acc[r][c];
}


static void
Computebs(double **R, double *evals, double **evecs, double **b)
{
    double          norm0 = 1.0 / sqrt(evals[0]);
    double          norm1 = 1.0 / sqrt(evals[1]);

    b[0][0] = norm0 * (evecs[0][0]*R[0][0] + evecs[0][1]*R[1][0] + evecs[0][2]*R[2][0]);
    b[0][1] = norm0 * (evecs[0][0]*R[0][1] + evecs[0][1]*R[1][1] + evecs[0][2]*R[2][1]);
    b[0][2] = norm0 * (evecs[0][0]*R[0][2] + evecs[0][1]*R[1][2] + evecs[0][2]*R[2][2]);

    b[1][0] = norm1 * (evecs[1][0]*R[0][0] + evecs[1][1]*R[1][0] + evecs[1][2]*R[2][0]);
    b[1][1] = norm1 * (evecs[1][0]*R[0][1] + evecs[1][1]*R[1][1] + evecs[1][2]*R[2][1]);
    b[1][2] = norm1 * (evecs[1][0]*R[0][2] + evecs[1][1]*R[1][2] + evecs[1][2]*R[2][2]);

    /* b[2] = b[0] X b[1] cross product */
    b[2][0] = b[0][1]*b[1][2] - b[0][2]*b[1][1];
    b[2][1] = b[0][2]*b[1][0] - b[0][0]*b[1][2];
    b[2][2] = b[0][0]*b[1][1] - b[0][1]*b[1][0];
}


/* Rotation matrix from the Kabsch a- and b-vectors (stored as matrix rows):
 *     U[j][i] = sum_k eigenvecs[k][i] * b[k][j]      (U = b^T * eigenvecs)
 *
 * Every element of U is written; U's previous contents are ignored.
 * Each element is accumulated in a local and then stored. This zeroes U
 * without assuming that its three rows are one contiguous block. The previous
 * memset(&U[0][0], 0, 9 * sizeof(double)) made that assumption, which a
 * row-by-row allocated double** does not satisfy. */
static void
ComputeU(double **b, double **eigenvecs, double **U)
{
    int         i, j, k;
    double      sum;

    for (i = 0; i < 3; ++i)
    {
        for (j = 0; j < 3; ++j)
        {
            sum = 0.0;
            for (k = 0; k < 3; ++k)
                sum += eigenvecs[k][i] * b[k][j];

            U[j][i] = sum;
        }
    }
}


/* Least-squares superposition of two fragments by Kabsch's eigenvector
 * method (Acta Cryst. A32:922, 1976; A34:827, 1978).
 *
 * Returns the sum of squared residuals E after superposition (never < 0);
 *     rmsd = sqrt(E/atom_num)
 *
 * The coordinates are assumed to be centered already (see ComputeRFrag).
 *
 * coords1, coords2  fragments, compared over coords1->fraglen positions
 * U                 output: 3x3 rotation matrix
 * R, evecs, b       caller-provided 3x3 work matrices (overwritten)
 * tmpevec           work vector handed to EigenSort3
 */
double
KabschFrag(const FragCoords *coords1, const FragCoords *coords2, double **U,
           double **R, double **evecs, double **b, double *tmpevec)
{
    int             i;
    double          evals[3];
    double          detR, detU, E, E0, E1;

    E0 = ComputeE0Frag(coords1, coords2);
    ComputeRFrag(coords1, coords2, R);

    /* LAPACK eigendecomposition (works, 2009-06-12).  A cyclic Jacobi solver
       (jacobi3_cyc) and a Numerical Recipes solver (eigen3) were also
       validated here; the Jacobi one appeared fastest, though not by much. */
    Mat3SqrTrans2(b, (const double **) R); /* product of R with its own transpose, into b */
    eigensym((const double **) b, evals, evecs, 3); /* dsyev puts evals in ascending order */
    Mat3TransposeIp(evecs);
    EigenSort3(evecs, evals, tmpevec); /* sort evectors according to bigger evals */

    /* These eigenvalues are non-negative in exact arithmetic. Round-off can
       make the smallest one slightly negative (e.g. planar fragments), which
       would turn the sqrt() calls below into NaN. */
    for (i = 0; i < 3; ++i)
        if (evals[i] < 0.0)
            evals[i] = 0.0;

    /* a[2] = a[0] X a[1] cross product */
    evecs[2][0] = evecs[0][1]*evecs[1][2] - evecs[0][2]*evecs[1][1];
    evecs[2][1] = evecs[0][2]*evecs[1][0] - evecs[0][0]*evecs[1][2];
    evecs[2][2] = evecs[0][0]*evecs[1][1] - evecs[0][1]*evecs[1][0];

    Computebs(R, evals, evecs, b);
    ComputeU(b, evecs, U);

    /* Defensive fix-up of U.  Since a[2] = a[0] x a[1] and b[2] = b[0] x b[1],
       U is a proper rotation by construction, so this is not expected to fire. */
    detU = Mat3Det((const double **) U);
    if (detU < 0.0)
    {
        for (i = 0; i < 3; ++i)
            U[2][i] *= -1.0;
    }

    /* For the same reason, det(U) cannot reveal that the best fit would need
       a reflection.  Per Kabsch (1978), the sign of the sqrt(evals[2]) term
       is sign(det R). Testing det(U) instead reported E ~ 0 for mirror images. */
    detR = Mat3Det((const double **) R);
    if (detR < 0.0)
        E1 = -sqrt(evals[0]) - sqrt(evals[1]) + sqrt(evals[2]);
    else
        E1 = -sqrt(evals[0]) - sqrt(evals[1]) - sqrt(evals[2]);

    /* E0 + E1 is half the residual; round-off can push it just below zero. */
    E = E0 + E1;

    if (E < 0.0)
        E = 0.0;
    else
        E *= 2.0;

    return(E);
}


/* Least-squares superposition of two coordinate sets by Kabsch's
 * eigenvector method (Acta Cryst. A32:922, 1976; A34:827, 1978), using the
 * cyclic Jacobi eigensolver (works, 2009-06-12).
 *
 * Returns the sum of squared residuals E after superposition (never < 0);
 *     rmsd = sqrt(E/atom_num)
 *
 * The coordinates are assumed to be centered already (see ComputeR).
 *
 * coords1, coords2  coordinate sets, compared over coords1->vlen positions
 * U                 output: 3x3 rotation matrix
 * R, evecs, b       caller-provided 3x3 work matrices (overwritten)
 * tmpevec           work vector handed to EigenSort3
 */
double
Kabsch(const Coords *coords1, const Coords *coords2, double **U,
          double **R, double **evecs, double **b, double *tmpevec)
{
    int             i;
    double          evals[3];
    double          detR, detU, E, E0, E1;

    E0 = ComputeE0(coords1, coords2);
    ComputeR(coords1, coords2, R);

    /* Cyclic Jacobi eigendecomposition.  LAPACK (eigensym) and Numerical
       Recipes (eigen3) solvers were also validated as drop-in alternatives. */
    Mat3SqrTrans2(b, (const double **) R); /* product of R with its own transpose, into b */
    jacobi3_cyc(b, evals, evecs, 1e-7); /* eigenvectors currently in rows of evecs */
    EigenSort3(evecs, evals, tmpevec); /* sort evectors according to bigger evals */

    /* These eigenvalues are non-negative in exact arithmetic. Round-off can
       make the smallest one slightly negative (e.g. planar structures), which
       would turn the sqrt() calls below into NaN. */
    for (i = 0; i < 3; ++i)
        if (evals[i] < 0.0)
            evals[i] = 0.0;

    /* a[2] = a[0] X a[1] cross product */
    evecs[2][0] = evecs[0][1]*evecs[1][2] - evecs[0][2]*evecs[1][1];
    evecs[2][1] = evecs[0][2]*evecs[1][0] - evecs[0][0]*evecs[1][2];
    evecs[2][2] = evecs[0][0]*evecs[1][1] - evecs[0][1]*evecs[1][0];

    Computebs(R, evals, evecs, b);
    ComputeU(b, evecs, U);

    /* Defensive fix-up of U.  Since a[2] = a[0] x a[1] and b[2] = b[0] x b[1],
       U is a proper rotation by construction, so this is not expected to fire. */
    detU = Mat3Det((const double **) U);
    if (detU < 0.0)
    {
        for (i = 0; i < 3; ++i)
            U[2][i] *= -1.0;
    }

    /* For the same reason, det(U) cannot reveal that the best fit would need
       a reflection.  Per Kabsch (1978), the sign of the sqrt(evals[2]) term
       is sign(det R). Testing det(U) instead reported E ~ 0 for mirror images. */
    detR = Mat3Det((const double **) R);
    if (detR < 0.0)
        E1 = -sqrt(evals[0]) - sqrt(evals[1]) + sqrt(evals[2]);
    else
        E1 = -sqrt(evals[0]) - sqrt(evals[1]) - sqrt(evals[2]);

    /* E0 + E1 is half the residual; round-off can push it just below zero. */
    E = E0 + E1;

    if (E < 0.0)
        E = 0.0;
    else
        E *= 2.0;

    return(E);
}


/* Least-squares superposition of two coordinate sets by Kabsch's
 * eigenvector method (Acta Cryst. A32:922, 1976; A34:827, 1978), using the
 * non-cyclic Jacobi eigensolver jacobi3 (also works, 2009-06-12).
 *
 * Returns the sum of squared residuals E after superposition (never < 0);
 *     rmsd = sqrt(E/atom_num)
 *
 * The coordinates are assumed to be centered already (see ComputeR).
 *
 * coords1, coords2  coordinate sets, compared over coords1->vlen positions
 * U                 output: 3x3 rotation matrix
 * R, evecs, b       caller-provided 3x3 work matrices (overwritten)
 * tmpevec           work vector handed to EigenSort3
 */
double
KabschJacobi(const Coords *coords1, const Coords *coords2, double **U,
             double **R, double **evecs, double **b, double *tmpevec)
{
    int             i;
    double          evals[3];
    double          E, E0, E1;

    E0 = ComputeE0(coords1, coords2);
    ComputeR(coords1, coords2, R);

    Mat3SqrTrans2(b, (const double **) R); /* product of R with its own transpose, into b */
    jacobi3(b, evals, evecs, 1e-6); /* eigenvectors currently in rows of evecs */
    EigenSort3(evecs, evals, tmpevec); /* sort evectors according to bigger evals */

    /* These eigenvalues are non-negative in exact arithmetic. Round-off can
       make the smallest one slightly negative (e.g. planar structures), which
       would turn the sqrt() calls below into NaN. */
    for (i = 0; i < 3; ++i)
        if (evals[i] < 0.0)
            evals[i] = 0.0;

    /* a[2] = a[0] X a[1] cross product */
    evecs[2][0] = evecs[0][1]*evecs[1][2] - evecs[0][2]*evecs[1][1];
    evecs[2][1] = evecs[0][2]*evecs[1][0] - evecs[0][0]*evecs[1][2];
    evecs[2][2] = evecs[0][0]*evecs[1][1] - evecs[0][1]*evecs[1][0];

    Computebs(R, evals, evecs, b);
    ComputeU(b, evecs, U);

    /* Defensive fix-up of U.  Since a[2] = a[0] x a[1] and b[2] = b[0] x b[1],
       U is a proper rotation by construction, so this is not expected to fire. */
    if (Mat3Det((const double **) U) < 0.0)
    {
        U[2][0] *= -1.0;
        U[2][1] *= -1.0;
        U[2][2] *= -1.0;
    }

    /* For the same reason, det(U) cannot reveal that the best fit would need
       a reflection.  Per Kabsch (1978), the sign of the sqrt(evals[2]) term
       is sign(det R). Testing det(U) instead reported E ~ 0 for mirror images. */
    if (Mat3Det((const double **) R) < 0.0)
        E1 = -sqrt(evals[0]) - sqrt(evals[1]) + sqrt(evals[2]);
    else
        E1 = -sqrt(evals[0]) - sqrt(evals[1]) - sqrt(evals[2]);

    /* E0 + E1 is half the residual; round-off can push it just below zero. */
    E = E0 + E1;

    if (E < 0.0)
        return(0.0);
    else
        return(2.0 * E);
}

