#!/bin/bash

#
# load parameters
#
# The leading ./ matters: for a name without a slash `source` searches $PATH
# before the current directory, so an unrelated config.in could be picked up.
# Everything below depends on these settings, hence stop if the file is
# missing.
#
if [[ ! -r ./config.in ]]; then
  echo "ERROR: cannot read ./config.in" >&2
  exit 1
fi
source ./config.in

#
# compile the generator of small mults
#
# ${host_compile} is deliberately left unquoted so that it is split into
# words (presumably a compiler command plus its flags, set by config.in;
# TODO confirm). Each step is checked: lib_gen.x is needed to produce the
# kernel sources, so continuing after a failure here is pointless.
#
${host_compile} -c mults.f90 ||
  { echo "ERROR: compiling mults.f90 failed" >&2; exit 1; }
${host_compile} -c multrec_gen.f90 ||
  { echo "ERROR: compiling multrec_gen.f90 failed" >&2; exit 1; }
${host_compile} mults.o multrec_gen.o lib_gen.f90 -o lib_gen.x ||
  { echo "ERROR: building lib_gen.x failed" >&2; exit 1; }

#
# directory for the library
#
mkdir -p lib || { echo "ERROR: cannot create directory lib" >&2; exit 1; }

#
# generate the generic caller
#
# First pass over the requested sizes: numsize counts the entries of
# dims_small, maxsize ends up as the largest of them (-1 if there are none).
#

numsize=0
maxsize=-1

for dim in ${dims_small}; do
  numsize=$((numsize + 1))
  if ((dim > maxsize)); then
    maxsize=${dim}
  fi
done

#
# generate a translation array for the jump table
#
# eles becomes the Fortran array constructor "(/0,e1,...,e<maxsize>/)" where
# entry i is 0 if size i is not part of dims_small, and otherwise the 1-based
# position of i within dims_small (first occurrence).
#
# The position in the list is used, not a running count over ascending i,
# because the CASE labels of the generic caller are numbered by iterating
# over "0 ${dims_small}" in list order. Both numberings only coincide for a
# sorted list without duplicates; with the position the jump table selects
# the right kernel for any ordering.
#
# The arithmetic for-loop replaces `seq 1 $maxsize`; it yields no iterations
# for maxsize=-1 on every platform (some seq implementations count downwards
# in that case).
#
eles="(/0"
for ((i = 1; i <= maxsize; i++)); do
  ele=0
  pos=0
  for myn in ${dims_small}; do
    pos=$((pos + 1))
    if [[ "$myn" == "$i" ]]; then
      ele=$pos
      break
    fi
  done
  eles="$eles,$ele"
done
eles="$eles/)"

#
# translate the numeric data_type code into
#   filetype : one-letter prefix used in file and routine names
#   gemm     : name of the BLAS routine called as fallback / reference
#   strdat   : Fortran type used in the generated declarations
# any other value leaves the three variables untouched
#
case "$data_type" in
  1)
    filetype="d"
    gemm="DGEMM"
    strdat="REAL(KIND=KIND(0.0D0))"
    ;;
  2)
    filetype="s"
    gemm="SGEMM"
    strdat="REAL(KIND=KIND(0.0))"
    ;;
  3)
    filetype="z"
    gemm="ZGEMM"
    strdat="COMPLEX(KIND=KIND(0.0D0))"
    ;;
  4)
    filetype="c"
    gemm="CGEMM"
    strdat="COMPLEX(KIND=KIND(0.0))"
    ;;
esac



#
# translate the numeric transpose_flavor code into
#   filetrans : two-letter suffix (nn, tn, nt, tt)
#   ta, tb    : the transpose flags handed to the BLAS routine
#   decl      : Fortran declaration of the shapes of A and B
#   lds       : Fortran statements setting the leading dimensions LDA, LDB
# any other value leaves these variables untouched
#
case "$transpose_flavor" in
  1)
    filetrans="nn"
    ta="N"
    tb="N"
    decl="A(M,K), B(K,N)"
    lds="LDA=M ; LDB=K"
    ;;
  2)
    filetrans="tn"
    ta="T"
    tb="N"
    decl="A(K,M), B(K,N)"
    lds="LDA=K ; LDB=K"
    ;;
  3)
    filetrans="nt"
    ta="N"
    tb="T"
    decl="A(M,K), B(N,K)"
    lds="LDA=M ; LDB=N"
    ;;
  4)
    filetrans="tt"
    ta="T"
    tb="T"
    decl="A(K,M), B(N,K)"
    lds="LDA=K ; LDB=N"
    ;;
esac

# full routine tag = data type letter + transpose suffix, e.g. "dnn";
# the generic caller is written to smm_<tag>.f90
filetrans="${filetype}${filetrans}"
file="smm_${filetrans}.f90"

#
# write the generic caller lib/${file}
#
# The generated routine smm_${filetrans}(M,N,K,A,B,C) looks up M, N and K in
# the translation array indx, combines the three indices into itot and jumps
# via SELECT CASE to the specialised kernel smm_${filetrans}_<m>_<n>_<k>.
# Whenever one of the sizes has no kernel (index 0) it falls through to
# label 999 and calls the BLAS routine ${gemm} instead.
#
# The file is written through a single redirection relative to the current
# directory. This avoids the unchecked `cd lib` / `cd ..` pair (a failing cd
# would otherwise leave the script one directory too high for everything that
# follows) and opens the file once instead of once per line. All values are
# passed as printf arguments, never as part of the format string.
#
{
  printf 'SUBROUTINE smm_%s(M,N,K,A,B,C)\n INTEGER :: M,N,K,LDA,LDB\n %s :: C(M,N), %s\n' \
    "${filetrans}" "${strdat}" "${decl}"
  printf ' INTEGER, PARAMETER :: indx(0:%s)=&\n %s\n' "${maxsize}" "${eles}"
  printf ' INTEGER :: im,in,ik,itot\n'
  printf ' %s, PARAMETER :: one=1\n' "${strdat}"
  printf ' %s\n' "${lds}"

  # sizes beyond maxsize are outside of indx and get index 0 (no kernel)
  printf ' IF (M<=%s) THEN\n   im=indx(M)\n ELSE\n   im=0\n ENDIF\n' "${maxsize}"
  printf ' IF (N<=%s) THEN\n   in=indx(N)\n ELSE\n   in=0\n ENDIF\n' "${maxsize}"
  printf ' IF (K<=%s) THEN\n   ik=indx(K)\n ELSE\n   ik=0\n ENDIF\n' "${maxsize}"
  printf ' itot=(ik*(%s+1)+in)*(%s+1)+im\n' "${numsize}" "${numsize}"

  # The loop nest (k outermost, m innermost, each starting with the dummy
  # size 0) enumerates the cases in exactly the order itot is computed above.
  count=0
  printf ' SELECT CASE(itot)\n'
  for myk in 0 ${dims_small}; do
    for myn in 0 ${dims_small}; do
      for mym in 0 ${dims_small}; do
        printf ' CASE(%s)\n ' "${count}"
        if ((myk * myn * mym == 0)); then
          # at least one size without a specialised kernel
          printf '   GOTO 999\n'
        else
          printf '   CALL smm_%s_%s_%s_%s(A,B,C)\n' \
            "${filetrans}" "${mym}" "${myn}" "${myk}"
        fi
        count=$((count + 1))
      done
    done
  done
  printf ' END SELECT\n'
  printf ' RETURN\n'
  printf "999 CONTINUE \n CALL %s('%s','%s',M,N,K,one,A,LDA,B,LDB,one,C,M)\n" \
    "${gemm}" "${ta}" "${tb}"

  printf 'END SUBROUTINE smm_%s' "${filetrans}"
} > "lib/${file}"

#
# collect the name suffixes "<m>_<n>_<k>" of all kernels of the small
# library: every combination of three sizes out of dims_small, with m
# varying slowest and k fastest; each suffix is preceded by one blank
#
postfixes=""
for m in ${dims_small}; do
  for n in ${dims_small}; do
    for k in ${dims_small}; do
      postfixes+=" ${m}_${n}_${k}"
    done
  done
done

#
# for easy parallelism go via a Makefile
#
# Makefile.lib contains
#   all      -> archive
#   archive  -> compile + driver, then collects the objects in
#               lib/libsmm_${filetrans}.a
#   driver   -> compiles the generic caller lib/${file}
#   compile  -> one comp_<m>_<n>_<k> target per kernel; each of them runs
#               ./lib_gen.x to generate the source and compiles it in lib/
#
# None of the targets is a real file, so all of them are declared .PHONY;
# otherwise a stray file called e.g. "driver" or "compile" would make the
# corresponding step be skipped silently. Recipes use `cd lib/ && ...` so
# that nothing is run in the wrong directory should the cd fail. All values
# are passed as printf arguments, never as part of the format string.
#
rm -f Makefile.lib

{
  printf 'all: archive\n\n'

  printf '.PHONY: all archive driver compile'
  for pf in ${postfixes}; do
    printf ' comp_%s' "${pf}"
  done
  printf '\n\n'

  printf 'driver:\n'
  printf '\t cd lib/ && %s -c %s \n\n' "${target_compile}" "${file}"

  printf 'archive: compile driver \n'
  printf '\t cd lib/ && ar -r libsmm_%s.a smm_%s*.o \n\n' \
    "${filetrans}" "${filetrans}"

  printf 'compile:'
  for pf in ${postfixes}; do
    printf ' comp_%s' "${pf}"
  done
  printf '\n\n'

  #
  # all compile rules
  #
  for m in ${dims_small}; do
    for n in ${dims_small}; do
      for k in ${dims_small}; do
        printf 'comp_%s_%s_%s: \n' "${m}" "${n}" "${k}"
        printf '\t ./lib_gen.x %s %s %s %s %s > lib/smm_%s_%s_%s_%s.f90\n' \
          "${m}" "${n}" "${k}" "${transpose_flavor}" "${data_type}" \
          "${filetrans}" "${m}" "${n}" "${k}"
        printf '\t cd lib/ && %s -c smm_%s_%s_%s_%s.f90 \n\n' \
          "${target_compile}" "${filetrans}" "${m}" "${n}" "${k}"
      done
    done
  done
} > Makefile.lib

#
# execute the Makefile: generate and compile all kernel variants, compile
# the generic caller and collect everything in the archive
#
# ${tasks} (number of parallel make jobs) is left unquoted on purpose so
# that an empty value keeps the plain `-j` form. A failed build is fatal:
# continuing would only produce link errors further down.
#

make -j $tasks -f Makefile.lib all ||
  { echo "ERROR: make -f Makefile.lib failed, library not built" >&2; exit 1; }

#
# a final test program. Checking correctness and final performance comparison
#
# The Fortran source test_smm_${filetrans}.f90 is assembled from three parts
# written through one redirection:
#   1. module WTF (random fill for all four data types) and subroutine
#      testit, which compares smm_${filetrans} against ${gemm} on random
#      data and then times both,
#   2. one "CALL testit(m,n,k)" per combination of sizes in dims_small,
#   3. a closing call with a size of 1000 and the end of the program.
# The here-documents are unquoted: the shell variables strdat, decl, lds,
# gemm, ta, tb and filetrans are substituted into the Fortran text.
#

{
cat << EOF
MODULE WTF
  INTERFACE MYRAND
    MODULE PROCEDURE SMYRAND, DMYRAND, CMYRAND, ZMYRAND
  END INTERFACE
CONTAINS
  SUBROUTINE DMYRAND(A)
    REAL(KIND=KIND(0.0D0)), DIMENSION(:,:) :: A
    REAL(KIND=KIND(0.0)), DIMENSION(SIZE(A,1),SIZE(A,2)) :: Aeq
    CALL RANDOM_NUMBER(Aeq)
    A=Aeq
  END SUBROUTINE
  SUBROUTINE SMYRAND(A)
    REAL(KIND=KIND(0.0)), DIMENSION(:,:) :: A
    REAL(KIND=KIND(0.0)), DIMENSION(SIZE(A,1),SIZE(A,2)) :: Aeq
    CALL RANDOM_NUMBER(Aeq)
    A=Aeq
  END SUBROUTINE
  SUBROUTINE CMYRAND(A)
    COMPLEX(KIND=KIND(0.0)), DIMENSION(:,:) :: A
    REAL(KIND=KIND(0.0)), DIMENSION(SIZE(A,1),SIZE(A,2)) :: Aeq,Beq
    CALL RANDOM_NUMBER(Aeq)
    CALL RANDOM_NUMBER(Beq)
    A=CMPLX(Aeq,Beq,KIND=KIND(0.0))
  END SUBROUTINE
  SUBROUTINE ZMYRAND(A)
    COMPLEX(KIND=KIND(0.0D0)), DIMENSION(:,:) :: A
    REAL(KIND=KIND(0.0)), DIMENSION(SIZE(A,1),SIZE(A,2)) :: Aeq,Beq
    CALL RANDOM_NUMBER(Aeq)
    CALL RANDOM_NUMBER(Beq)
    A=CMPLX(Aeq,Beq,KIND=KIND(0.0D0))
  END SUBROUTINE
END MODULE
SUBROUTINE testit(M,N,K)
  USE WTF
  IMPLICIT NONE
  INTEGER :: M,N,K

  $strdat :: C1(M,N), C2(M,N)
  $strdat :: ${decl}
  $strdat, PARAMETER :: one=1
  INTEGER :: i,LDA,LDB

  REAL(KIND=KIND(0.0D0)) :: flops,gflop
  REAL :: t1,t2,t3,t4
  INTEGER :: Niter

  flops=2*REAL(M,KIND=KIND(0.0D0))*N*K
  gflop=1000.0D0*1000.0D0*1000.0D0
  ! assume we would like to do 5 Gflop for testing a subroutine
  Niter=MAX(1,CEILING(MIN(10000000.0D0,5*gflop/flops)))
  ${lds}

  DO i=1,10
     CALL MYRAND(A)
     CALL MYRAND(B)
     CALL MYRAND(C1)
     C2=C1

     CALL ${gemm}("$ta","$tb",M,N,K,one,A,LDA,B,LDB,one,C1,M) 
     CALL smm_${filetrans}(M,N,K,A,B,C2)

     IF (MAXVAL(ABS(C2-C1))>100*EPSILON(REAL(1.0,KIND=KIND(A(1,1))))) THEN
        write(6,*) "Matrix size",M,N,K
        write(6,*) "A=",A
        write(6,*) "B=",B
        write(6,*) "C1=",C1
        write(6,*) "C2=",C2
        write(6,*) "BLAS and smm yield different results : possible compiler bug... do not use the library ;-)"
        STOP
     ENDIF
  ENDDO

  A=0; B=0; C1=0 ; C2=0
 
  CALL CPU_TIME(t1) 
  DO i=1,Niter
     CALL ${gemm}("$ta","$tb",M,N,K,one,A,LDA,B,LDB,one,C1,M) 
  ENDDO
  CALL CPU_TIME(t2) 

  CALL CPU_TIME(t3)
  DO i=1,Niter
     CALL smm_${filetrans}(M,N,K,A,B,C2)
  ENDDO
  CALL CPU_TIME(t4)

  WRITE(6,'(A,I5,I5,I5,A,F6.3,A,F6.3,A,F12.3,A)') "Matrix size ",M,N,K, &
        " smm: ",Niter*flops/(t4-t3)/gflop," Gflops. Linked blas: ",Niter*flops/(t2-t1)/gflop,&
        " Gflops. Performance ratio: ",((t2-t1)/(t4-t3))*100,"%"

END SUBROUTINE 

PROGRAM tester
  IMPLICIT NONE

EOF

# one test call per kernel of the library, m varying slowest, k fastest
for m in ${dims_small}; do
  for n in ${dims_small}; do
    for k in ${dims_small}; do
      printf '   CALL testit(%s,%s,%s)\n' "${m}" "${n}" "${k}"
    done
  done
done

cat << EOF
  ! checking 'peak' performance (and a size likely outside of the library)
  CALL testit(1000,1000,1000)
END PROGRAM
EOF
} > "test_smm_${filetrans}.f90"

#
# compile the benchmarking and testing program for the smm library
#
# ${target_compile} and ${blas_linking} are left unquoted on purpose so that
# they are split into separate words. Without a test executable the library
# cannot be validated, so a failed compile/link is fatal.
#
test_name="test_smm_${filetrans}"
${target_compile} "${test_name}.f90" -o "${test_name}.x" -Llib "-lsmm_${filetrans}" ${blas_linking} ||
  { echo "ERROR: building ${test_name}.x failed, library is untested" >&2; exit 1; }

#
# run the final test, showing its output and keeping a copy in the .out file
#
# PIPESTATUS[0] is the exit status of the test program itself; $? would only
# reflect tee.
#
"./${test_name}.x" | tee "./${test_name}.out"
run_status=${PIPESTATUS[0]}

#
# We're done... protect the user from bad compilers
#
# Three outcomes:
#   - the test printed its mismatch message: the library is removed
#   - the test died with a non-zero status without that message (e.g. a
#     crash): the library is kept but not reported as usable
#   - otherwise: success
# The two failure cases make the script exit with status 1.
#
if grep -q 'BLAS and smm yield different results' "./${test_name}.out"; then
   echo "Library is miscompiled ... removing lib/libsmm_${filetrans}.a"
   rm -f "lib/libsmm_${filetrans}.a"
   exit 1
elif [[ "${run_status}" -ne 0 ]]; then
   echo "ERROR: ${test_name}.x exited with status ${run_status}; the library could not be validated" >&2
   exit 1
else
   pathhere=$(pwd)
   echo "Done... check performance looking at ${test_name}.out"
   echo "Final library can be linked as -L${pathhere}/lib -lsmm_${filetrans}"
fi

