#! /usr/bin/env python
# Copyright 2015 Martin C. Frith

import gzip
import math, optparse, os, random, signal, subprocess, sys, tempfile

def myOpen(fileName):  # faster than fileinput
    """Return a readable file object for fileName.

    "-" means standard input; a name ending in ".gz" is opened with
    gzip.open; anything else with the built-in open.
    """
    if fileName == "-":
        return sys.stdin
    opener = gzip.open if fileName.endswith(".gz") else open
    return opener(fileName)

def randomSample(things, sampleSize):
    """Return a uniform random sample of up to sampleSize items.

    Uses reservoir sampling, so `things` may be any iterable and is
    traversed exactly once.  If it holds fewer than sampleSize items,
    all of them are returned in their original order.
    """
    kept = []
    for seen, item in enumerate(things, 1):
        if seen <= sampleSize:
            # still filling the reservoir
            kept.append(item)
            continue
        # replace a random slot with probability sampleSize / seen
        slot = random.randrange(seen)
        if slot < sampleSize:
            kept[slot] = item
    return kept

def writeWords(outFile, words):
    """Write the strings in `words` to outFile as one space-separated line."""
    line = " ".join(words)
    outFile.write(line + "\n")

def seqInput(fileNames):
    """Yield (sequence, quality) string pairs from fasta or fastq files.

    With no file names, standard input ("-") is read.  Each file's format
    is decided by the first line that starts with ">" (fasta) or "@"
    (fastq); any lines before that are ignored.  Fasta records are
    yielded with an empty quality string.
    """
    for name in (fileNames or ["-"]):
        handle = myOpen(name)
        fmt = None  # unknown until a header line is seen
        for line in handle:
            if fmt == "fasta":
                if line[0] == ">":
                    # a new header ends the previous record
                    yield "".join(pieces), ""
                    pieces = []
                else:
                    pieces.append(line.rstrip())
            elif fmt == "fastq":
                # phase cycles 0=header, 1=sequence, 2="+" line, 3=quality
                if phase == 1:
                    letters = line.rstrip()
                elif phase == 3:
                    yield letters, line.rstrip()
                phase = (phase + 1) % 4
            elif line[0] == ">":
                fmt = "fasta"
                pieces = []
            elif line[0] == "@":
                fmt = "fastq"
                phase = 1  # the next line is the sequence
        if fmt == "fasta":
            # the last fasta record has no following header to flush it
            yield "".join(pieces), ""
        handle.close()

def isGoodChunk(chunk):
    """Return True if any segment's sequence has a letter other than N/n.

    Each segment is a tuple whose item 3 is the sequence string.
    """
    return any(letter not in "Nn"
               for segment in chunk
               for letter in segment[3])

def chunkInput(opts, sequences):
    """Cut the input sequences into chunks of opts.sample_length letters.

    `sequences` yields (sequence, quality) pairs.  Sequences made only of
    N/n (including empty ones) are skipped, though they still consume an
    index.  A chunk is a list of segments
    (sequenceIndex, beg, end, letters, qualities), and may span several
    consecutive sequences.  Full-length chunks are yielded only if
    isGoodChunk accepts them.  A trailing short chunk is yielded only if
    fewer than opts.sample_number chunks were yielded before it.
    """
    numYielded = 0
    chunk = []
    needed = opts.sample_length  # letters still missing from this chunk
    for seqNum, (seq, qual) in enumerate(sequences):
        if all(letter in "Nn" for letter in seq):
            continue
        total = len(seq)
        beg = 0
        while beg < total:
            end = min(beg + needed, total)
            chunk.append((seqNum, beg, end, seq[beg:end], qual[beg:end]))
            needed -= end - beg
            if needed == 0:
                if isGoodChunk(chunk):
                    yield chunk
                    numYielded += 1
                chunk = []
                needed = opts.sample_length
            beg = end
    if chunk and numYielded < opts.sample_number:
        yield chunk

def writeSegment(outfile, segment):
    """Write one segment to outfile as a fasta or fastq record.

    `segment` is (index, beg, end, sequence, quality); a false value
    (e.g. None) writes nothing.  The record is named "index:beg".  It is
    written as fastq when the quality string is non-empty, else as fasta.
    """
    if not segment:
        return
    seqNum, beg, end, seq, qual = segment
    title = "%s:%s" % (seqNum, beg)
    if qual:
        pieces = ["@", title, "\n", seq, "\n+\n", qual, "\n"]
    else:
        pieces = [">", title, "\n", seq, "\n"]
    outfile.write("".join(pieces))

def getSeqSample(opts, queryFiles, outfile):
    """Write a random sample of sequence chunks from queryFiles to outfile.

    Up to opts.sample_number chunks are picked, sorted, and written with
    writeSegment.  Consecutive segments that come from the same sequence
    and abut (one ends where the next begins) are merged into a single
    record first.
    """
    chunks = chunkInput(opts, seqInput(queryFiles))
    sample = sorted(randomSample(chunks, opts.sample_number))
    pending = None  # segment waiting to be written, possibly after merging
    for chunk in sample:
        for seg in chunk:
            adjoins = (pending and seg[0] == pending[0]
                       and seg[1] == pending[2])
            if adjoins:
                pending = (pending[0], pending[1], seg[2],
                           pending[3] + seg[3], pending[4] + seg[4])
            else:
                writeSegment(outfile, pending)
                pending = seg
    writeSegment(outfile, pending)

def versionFromHeader(lines):
    """Return the last word of the first line containing the word "version".

    Lines are consumed from `lines` only up to and including that line.
    Raises Exception if no such line is found.
    """
    # generator (not map) so that no line beyond the match is consumed
    for fields in (line.split() for line in lines):
        if "version" in fields:
            return fields[-1]
    raise Exception("couldn't read the version")

def scaleFromHeader(lines):
    """Return the number after "t=" in the first word that starts with it.

    Lines are consumed from `lines` only up to and including the line
    holding that word.  Raises Exception if there is no such word.
    """
    for line in lines:
        for word in line.split():
            if word[:2] == "t=":
                return float(word[2:])
    raise Exception("couldn't read the scale")

def scoreMatrixFromHeader(lines):
    """Read a score matrix from header lines, as a list of word lists.

    A matrix line has more than two words, and its second word is a
    single character; its first word is dropped.  Reading stops at the
    first non-matrix line after the matrix has started (that line is
    consumed).  Returns an empty list if no matrix line is found.
    """
    rows = []
    for line in lines:
        fields = line.split()
        looksLikeRow = len(fields) > 2 and len(fields[1]) == 1
        if looksLikeRow:
            rows.append(fields[1:])
        elif rows:
            break
    return rows

def scaledMatrix(matrix, scaleIncrease):
    """Return a copy of a lettered score matrix with scores multiplied.

    The first row (column letters) and the first item of every other row
    (row letter) are kept as they are; the remaining items are converted
    with int() and multiplied by scaleIncrease.
    """
    scaled = matrix[:1]
    for row in matrix[1:]:
        scores = [int(s) * scaleIncrease for s in row[1:]]
        scaled.append(row[:1] + scores)
    return scaled

def countsFromLastOutput(lines, opts):
    """Accumulate substitution and gap counts from alignment output lines.

    Only lines starting with "s" and "c" are used.  Field 4 of the most
    recent "s" line is taken as the strand.  Each "c" line holds a
    flattened square substitution-count matrix followed by 10 more
    numbers, of which the first five are read as: aligned letter pairs,
    deleted letters, inserted letters, deletion opens, insertion opens.

    Alignments whose identities exceed opts.pid percent of
    (pairs + deletes + inserts) are skipped.  Counts are added in
    reverse-complement orientation unless the strand is "+" or
    opts.S == "0".

    Returns (matrix, gapCounts), where gapCounts is
    (matches, deletes, inserts, delOpens, insOpens, alignments).
    Everything except `alignments` starts from a pseudocount.  `matrix`
    is empty if no "c" line was seen.
    """
    matrix = []
    # use +1 pseudocounts as a kludge to mitigate numerical problems:
    matches = 1.0
    deletes = 2.0  # 1 open + 1 extension
    inserts = 2.0  # 1 open + 1 extension
    delOpens = 1.0
    insOpens = 1.0
    alignments = 0  # no pseudocount here
    for line in lines:
        if line[0] == "s":
            strand = line.split()[4]  # slow?
        elif line[0] == "c":
            # a real list: it is indexed and measured below, which fails
            # wherever map() returns an iterator
            c = [float(i) for i in line.split()[1:]]
            if not matrix:
                matrixSize = int(math.sqrt(len(c) - 10))
                matrix = [[1.0] * matrixSize for i in range(matrixSize)]
            identities = sum(c[i * matrixSize + i] for i in range(matrixSize))
            alignmentLength = c[-10] + c[-9] + c[-8]
            if 100 * identities > opts.pid * alignmentLength: continue
            # the orientation is the same for the whole alignment
            isForward = strand == "+" or opts.S == "0"
            for i in range(matrixSize):
                for j in range(matrixSize):
                    if isForward:
                        matrix[i][j]       += c[i * matrixSize + j]
                    else:
                        matrix[-1-i][-1-j] += c[i * matrixSize + j]
            matches += c[-10]
            deletes += c[-9]
            inserts += c[-8]
            delOpens += c[-7]
            insOpens += c[-6]
            alignments += 1
    gapCounts = matches, deletes, inserts, delOpens, insOpens, alignments
    return matrix, gapCounts

def scoreFromProb(scale, prob):
    """Return scale * ln(prob), rounded to an integer.

    For prob <= 0, the logarithm is replaced by -800.
    """
    # exp(-800) is exactly zero, on my computer
    logProb = math.log(prob) if prob > 0 else -800
    return int(round(scale * logProb))

def costFromProb(scale, prob):
    """Return the cost for a probability: the negated scoreFromProb."""
    score = scoreFromProb(scale, prob)
    return -score

def guessAlphabet(matrixSize):
    """Return the alphabet string for a matrix of the given size.

    Size 4 gives "ACGT" and size 20 gives the 20 amino-acid letters.
    Raises Exception for any other size.
    """
    knownAlphabets = ((4, "ACGT"), (20, "ACDEFGHIKLMNPQRSTVWY"))
    for size, letters in knownAlphabets:
        if matrixSize == size:
            return letters
    raise Exception("can't handle unusual alphabets")

def matrixWithLetters(matrix):
    """Return the matrix with alphabet labels attached.

    The result's first item is the alphabet string (from guessAlphabet);
    each following row is the matching letter followed by that row of
    `matrix`.
    """
    alphabet = guessAlphabet(len(matrix))
    labelled = [alphabet]
    for letter, row in zip(alphabet, matrix):
        labelled.append([letter] + row)
    return labelled

def writeMatrixHead(outFile, prefix, alphabet, formatString):
    """Write the column-letter line of a matrix.

    The line starts with prefix plus one space (in place of a row
    letter), followed by each item of `alphabet` formatted with
    formatString.
    """
    cells = [formatString % letter for letter in alphabet]
    writeWords(outFile, [prefix + " "] + cells)

def writeMatrixBody(outFile, prefix, alphabet, matrix, formatString):
    """Write the rows of a matrix, one line each.

    Each line starts with prefix plus the row's letter, followed by the
    row's items formatted with formatString.
    """
    for letter, row in zip(alphabet, matrix):
        cells = [formatString % value for value in row]
        writeWords(outFile, [prefix + letter] + cells)

def writeCountMatrix(outFile, matrix, prefix):
    """Write a count matrix with letter labels, in 14-wide left-justified cells."""
    letters = guessAlphabet(len(matrix))
    cellFormat = "%-14s"
    writeMatrixHead(outFile, prefix, letters, cellFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, cellFormat)

def writeProbMatrix(outFile, matrix, prefix):
    """Write a probability matrix with letter labels.

    Cells are 14 wide and left-justified; values use %g formatting.
    """
    letters = guessAlphabet(len(matrix))
    headFormat = "%-14s"
    bodyFormat = "%-14g"
    writeMatrixHead(outFile, prefix, letters, headFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, bodyFormat)

def writeScoreMatrix(outFile, matrix, prefix):
    """Write a score matrix with letter labels, in 6-wide right-justified cells."""
    letters = guessAlphabet(len(matrix))
    cellFormat = "%6s"
    writeMatrixHead(outFile, prefix, letters, cellFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, cellFormat)

def writeMatrixWithLetters(outFile, matrix, prefix):
    """Write a matrix that carries its own labels.

    matrix[0] holds the column letters; every later row starts with its
    row letter followed by the values.  Cells are 6 wide and
    right-justified.
    """
    columnLetters = matrix[0]
    rowLetters = []
    values = []
    for row in matrix[1:]:
        rowLetters.append(row[0])
        values.append(row[1:])
    cellFormat = "%6s"
    writeMatrixHead(outFile, prefix, columnLetters, cellFormat)
    writeMatrixBody(outFile, prefix, rowLetters, values, cellFormat)

def matProbsFromCounts(counts, opts):
    """Turn a substitution count matrix into probabilities, and report both.

    counts: square matrix (list of rows) of substitution counts.
    opts.revsym: if true, add each count to its reverse-complement cell.
    opts.matsym: if true, add each count to its transposed cell.

    Writes the percent identity, the (possibly symmetrized) count matrix
    and the probability matrix to standard output as "#" comment lines.
    Returns the probability matrix, i.e. each count divided by the total.
    """
    r = range(len(counts))
    if opts.revsym:  # add complement (reverse strand) substitutions
        counts = [[counts[i][j] + counts[-1-i][-1-j] for j in r] for i in r]
    if opts.matsym:  # symmetrize the substitution matrix
        counts = [[counts[i][j] + counts[j][i] for j in r] for i in r]
    identities = sum(counts[i][i] for i in r)
    # float() keeps the divisions below from truncating if the counts
    # happen to be integers
    total = float(sum(map(sum, counts)))
    probs = [[j / total for j in i] for i in counts]

    # explicit writes give the same output as print statements, but this
    # form is also valid where print is a function
    out = sys.stdout
    out.write("# substitution percent identity: %g\n"
              % (100 * identities / total))
    out.write("\n")
    out.write("# count matrix (query letters = columns, reference letters = rows):\n")
    writeCountMatrix(out, counts, "# ")
    out.write("\n")
    out.write("# probability matrix (query letters = columns, reference letters = rows):\n")
    writeProbMatrix(out, probs, "# ")
    out.write("\n")

    return probs

def gapProbsFromCounts(counts, opts):
    """Estimate gap probabilities from gap counts, and report them.

    counts: (matches, deletes, inserts, delOpens, insOpens, alignments).
    opts.gapsym: if true, pool insertion and deletion counts so that both
    get the same probabilities.

    Writes the counts and the estimated probabilities to standard output
    as "#" comment lines.

    Returns (delExistProb, insExistProb, delExtendProb, insExtendProb).
    Raises Exception if `alignments` is zero.
    """
    matches, deletes, inserts, delOpens, insOpens, alignments = counts
    if not alignments: raise Exception("no alignments")
    gaps = deletes + inserts
    gapOpens = delOpens + insOpens
    denominator = matches + gapOpens + (alignments + 1)  # +1 pseudocount
    if opts.gapsym:
        delOpenProb = gapOpens / denominator / 2
        insOpenProb = gapOpens / denominator / 2
        delExtendProb = (gaps - gapOpens) / gaps
        insExtendProb = (gaps - gapOpens) / gaps
    else:
        delOpenProb = delOpens / denominator
        insOpenProb = insOpens / denominator
        delExtendProb = (deletes - delOpens) / deletes
        insExtendProb = (inserts - insOpens) / inserts

    # explicit writes give the same output as print statements, but this
    # form is also valid where print is a function
    report = [
        "# aligned letter pairs: " + str(matches),
        "# deletes: " + str(deletes),
        "# inserts: " + str(inserts),
        "# delOpens: " + str(delOpens),
        "# insOpens: " + str(insOpens),
        "# alignments: " + str(alignments),
        "# mean delete size: %g" % (deletes / delOpens),
        "# mean insert size: %g" % (inserts / insOpens),
        "# delOpenProb: %g" % delOpenProb,
        "# insOpenProb: %g" % insOpenProb,
        "# delExtendProb: %g" % delExtendProb,
        "# insExtendProb: %g" % insExtendProb,
        "",
    ]
    sys.stdout.write("\n".join(report) + "\n")

    delCloseProb = 1 - delExtendProb
    insCloseProb = 1 - insExtendProb
    firstDelProb = delOpenProb * delCloseProb
    firstInsProb = insOpenProb * insCloseProb

    # If we define "an alignment" to mean "a set of indistinguishable
    # paths", then:
    #delExtendProb += firstDelProb
    #insExtendProb += firstInsProb
    # Else, this ensures gap existence cost >= 0:
    delExtendProb = max(delExtendProb, firstDelProb)
    insExtendProb = max(insExtendProb, firstInsProb)

    delExistProb = firstDelProb / delExtendProb
    insExistProb = firstInsProb / insExtendProb

    return delExistProb, insExistProb, delExtendProb, insExtendProb

def scoreFromLetterProbs(scale, pairProb, prob1, prob2):
    """Return the score for a letter pair.

    The score is scoreFromProb applied to the ratio of the pair's
    probability to the product of the two letters' own probabilities.
    """
    expected = prob1 * prob2
    return scoreFromProb(scale, pairProb / expected)

def matScoresFromProbs(scale, probs):
    """Convert a matrix of pair probabilities to a matrix of scores.

    Each cell is scored with scoreFromLetterProbs against the sum of its
    row and the sum of its column.
    """
    rowTotals = [sum(row) for row in probs]
    colTotals = [sum(col) for col in zip(*probs)]
    scores = []
    for row, rowTotal in zip(probs, rowTotals):
        scores.append([scoreFromLetterProbs(scale, pairProb, rowTotal, colTotal)
                       for pairProb, colTotal in zip(row, colTotals)])
    return scores

def gapCostsFromProbs(scale, probs):
    """Convert gap probabilities to integer gap costs.

    probs: (delExistProb, insExistProb, delExtendProb, insExtendProb).
    Returns the four costs in the same order.  An extension cost of zero
    is replaced by 1.
    """
    delExistProb, insExistProb, delExtendProb, insExtendProb = probs
    delExtendCost = costFromProb(scale, delExtendProb) or 1
    insExtendCost = costFromProb(scale, insExtendProb) or 1
    delExistCost = costFromProb(scale, delExistProb)
    insExistCost = costFromProb(scale, insExistProb)
    return delExistCost, insExistCost, delExtendCost, insExtendCost

def writeLine(out, *things):
    """Write the str() of each argument to `out` as one space-separated line."""
    words = [str(thing) for thing in things]
    out.write(" ".join(words) + "\n")

def writeGapCosts(gapCosts, out):
    """Write the four gap costs to `out` as "#last" option lines.

    gapCosts: (delExistCost, insExistCost, delExtendCost, insExtendCost),
    written as the -a, -A, -b and -B options respectively.
    """
    delExistCost, insExistCost, delExtendCost, insExtendCost = gapCosts
    optionLines = (
        ("#last -a", delExistCost),
        ("#last -A", insExistCost),
        ("#last -b", delExtendCost),
        ("#last -B", insExtendCost),
    )
    for label, cost in optionLines:
        writeLine(out, label, cost)

def printGapCosts(gapCosts):
    """Write the four gap costs to standard output as "#" comment lines.

    gapCosts: (delExistCost, insExistCost, delExtendCost, insExtendCost).
    A blank line follows the four cost lines.
    """
    delExistCost, insExistCost, delExtendCost, insExtendCost = gapCosts
    # explicit writes give the same output as print statements, but this
    # form is also valid where print is a function
    report = [
        "# delExistCost: " + str(delExistCost),
        "# insExistCost: " + str(insExistCost),
        "# delExtendCost: " + str(delExtendCost),
        "# insExtendCost: " + str(insExtendCost),
        "",
    ]
    sys.stdout.write("\n".join(report) + "\n")

def tryToMakeChildProgramsFindable():
    """Prepend the "src" directory next to this script's directory to PATH.

    The directory is <directory of this file>/../src.
    """
    scriptDir = os.path.dirname(__file__)
    srcDir = os.path.join(scriptDir, os.pardir, "src")
    # put srcDir first, to avoid getting older versions of LAST:
    os.environ["PATH"] = os.pathsep.join([srcDir, os.environ["PATH"]])

def readLastalProgName(lastdbIndexName):
    """Choose the lastal program that matches a lastdb index.

    Reads <lastdbIndexName>.prj and returns "lastal8" if its (last)
    "integersize=" line says 64, otherwise "lastal".  A missing
    integersize line counts as 32.
    """
    with open(lastdbIndexName + ".prj") as prjFile:
        sizes = [line.split("=")[1].strip() for line in prjFile
                 if line.startswith("integersize=")]
    bitsPerInt = sizes[-1] if sizes else "32"
    if bitsPerInt == "64":
        return "lastal8"
    return "lastal"

def fixedLastalArgs(opts, lastalProgName):
    """Build the lastal command-line arguments that never change.

    Returns a list starting with the program name and "-j7", followed by
    "-X<value>" for each of the options D, E, s, S, C, T, m, k, P, Q
    (in that order) that has a true value in opts.
    """
    command = [lastalProgName, "-j7"]
    for letter in "DEsSCTmkPQ":
        value = getattr(opts, letter)
        if value:
            command.append("-" + letter + value)
    return command

def doTraining(opts, args):
    """Iteratively re-estimate alignment score parameters and print them.

    Runs lastal, piped through "last-split -n", on `args`; counts the
    substitutions and gaps in the output; converts the counts to gap
    costs (and, unless opts.Q is set, to a score matrix); then re-runs
    lastal with those parameters.  The loop ends when a parameter set
    that was seen before comes up again.  Progress is written to
    standard output as "#" comment lines, followed by the final
    parameters at the scale read from the first lastal header.

    opts: the parsed command-line options.
    args: lastdb index name followed by the sequence file name(s).
    """
    tryToMakeChildProgramsFindable()
    lastalProgName = readLastalProgName(args[0])
    scaleIncrease = 20  # while training, up-scale the scores by this amount
    # the first run uses the user's initial score parameters, if any
    x = fixedLastalArgs(opts, lastalProgName)
    if opts.r: x.append("-r" + opts.r)
    if opts.q: x.append("-q" + opts.q)
    if opts.p: x.append("-p" + opts.p)
    if opts.a: x.append("-a" + opts.a)
    if opts.b: x.append("-b" + opts.b)
    if opts.A: x.append("-A" + opts.A)
    if opts.B: x.append("-B" + opts.B)
    x += args
    y = ["last-split", "-n"]
    p = subprocess.Popen(x, stdout=subprocess.PIPE)
    q = subprocess.Popen(y, stdin=p.stdout, stdout=subprocess.PIPE)
    # the header readers below all consume the same stream, one after
    # another, so each picks up where the previous one stopped
    lastalVersion = versionFromHeader(q.stdout)
    externalScale = scaleFromHeader(q.stdout)
    internalScale = externalScale * scaleIncrease
    # NOTE(review): "-Q0" is a non-empty string, so it is true here and
    # the matrix is kept fixed, whereas the __main__ block treats "0"
    # like an unset -Q -- confirm that this is intended
    if opts.Q:
        externalMatrix = scoreMatrixFromHeader(q.stdout)
        internalMatrix = scaledMatrix(externalMatrix, scaleIncrease)
    oldParameters = []  # every parameter set tried so far

    print "# lastal version:", lastalVersion
    print "# maximum percent identity:", opts.pid
    print "# scale of score parameters:", externalScale
    print "# scale used while training:", internalScale
    print

    while True:
        print "#", " ".join(x)
        print
        sys.stdout.flush()
        matCounts, gapCounts = countsFromLastOutput(q.stdout, opts)
        gapProbs = gapProbsFromCounts(gapCounts, opts)
        gapCosts = gapCostsFromProbs(internalScale, gapProbs)
        printGapCosts(gapCosts)
        if opts.Q:
            # only the gap costs are re-estimated in this case
            if gapCosts in oldParameters: break
            oldParameters.append(gapCosts)
        else:
            matProbs = matProbsFromCounts(matCounts, opts)
            matScores = matScoresFromProbs(internalScale, matProbs)
            print "# score matrix (query letters = columns, reference letters = rows):"
            writeScoreMatrix(sys.stdout, matScores, "# ")
            print
            parameters = gapCosts, matScores
            if parameters in oldParameters: break
            oldParameters.append(parameters)
            internalMatrix = matrixWithLetters(matScores)
        # re-run lastal, feeding it the new parameters on its stdin
        # ("-p-") as "#last" lines followed by the score matrix
        x = fixedLastalArgs(opts, lastalProgName)
        x.append("-p-")
        x += args
        p = subprocess.Popen(x, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        writeGapCosts(gapCosts, p.stdin)
        writeMatrixWithLetters(p.stdin, internalMatrix, "")
        p.stdin.close()
        # in python2.6, the next line must come after p.stdin.close()
        q = subprocess.Popen(y, stdin=p.stdout, stdout=subprocess.PIPE)

    # final output: the same probabilities, converted at the original
    # (not up-scaled) scale
    gapCosts = gapCostsFromProbs(externalScale, gapProbs)
    writeGapCosts(gapCosts, sys.stdout)
    if opts.s: writeLine(sys.stdout, "#last -s", opts.s)
    if opts.S: writeLine(sys.stdout, "#last -S", opts.S)
    if not opts.Q:
        matScores = matScoresFromProbs(externalScale, matProbs)
        externalMatrix = matrixWithLetters(matScores)
    print "# score matrix (query letters = columns, reference letters = rows):"
    writeMatrixWithLetters(sys.stdout, externalMatrix, "")

def lastTrain(opts, args):
    """Train on a random sample of the queries, or on all of them.

    If opts.sample_number is non-zero, a sample of the query sequences
    (args[1:]) is written to a temporary file, training runs on that
    file, and the file is removed afterwards.  Otherwise training runs
    directly on `args`.

    opts: the parsed command-line options.
    args: lastdb index name followed by the sequence file name(s).
    """
    if opts.sample_number:
        random.seed(math.pi)  # fixed seed, so the sample is reproducible
        refName = args[0]
        queryFiles = args[1:]
        # Create the file before entering the try block: if creation
        # fails there is nothing to remove, and the real error must not
        # be masked by a NameError in the cleanup.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            with f:
                getSeqSample(opts, queryFiles, f)
            doTraining(opts, [refName, f.name])
        finally:
            os.remove(f.name)
    else:
        doTraining(opts, args)

# Command-line entry point: parse the options, fill in default initial
# score parameters where appropriate, and run the training.
if __name__ == "__main__":
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)  # avoid silly error message
    usage = "%prog [options] lastdb-name sequence-file(s)"
    description = "Try to find suitable score parameters for aligning the given sequences."
    op = optparse.OptionParser(usage=usage, description=description)
    og = optparse.OptionGroup(op, "Training options")
    og.add_option("--revsym", action="store_true",
                  help="force reverse-complement symmetry")
    og.add_option("--matsym", action="store_true",
                  help="force symmetric substitution matrix")
    og.add_option("--gapsym", action="store_true",
                  help="force insertion/deletion symmetry")
    og.add_option("--pid", type="float", default=100, help=
                  "skip alignments with > PID% identity (default: %default)")
    og.add_option("--sample-number", type="int", default=500, metavar="N",
                  help="number of random sequence samples (default: %default)")
    og.add_option("--sample-length", type="int", default=2000, metavar="L",
                  help="length of each sample (default: %default)")
    op.add_option_group(og)
    og = optparse.OptionGroup(op, "Initial parameter options")
    og.add_option("-r", metavar="SCORE",
                  help="match score (default: 6 if Q>0, else 5)")
    og.add_option("-q", metavar="COST",
                  help="mismatch cost (default: 18 if Q>0, else 5)")
    og.add_option("-p", metavar="NAME", help="match/mismatch score matrix")
    og.add_option("-a", metavar="COST",
                  help="gap existence cost (default: 21 if Q>0, else 15)")
    og.add_option("-b", metavar="COST",
                  help="gap extension cost (default: 9 if Q>0, else 3)")
    og.add_option("-A", metavar="COST", help="insertion existence cost")
    og.add_option("-B", metavar="COST", help="insertion extension cost")
    op.add_option_group(og)
    og = optparse.OptionGroup(op, "Alignment options")
    og.add_option("-D", metavar="LENGTH",
                  help="query letters per random alignment (default: 1e6)")
    og.add_option("-E", metavar="EG2",
                  help="maximum expected alignments per square giga")
    og.add_option("-s", metavar="STRAND", help=
                  "0=reverse, 1=forward, 2=both (default: 2 if DNA, else 1)")
    og.add_option("-S", metavar="NUMBER", default="1", help=
                  "score matrix applies to forward strand of: " +
                  "0=reference, 1=query (default: %default)")
    og.add_option("-C", metavar="COUNT", help=
                  "omit gapless alignments in COUNT others with > score-per-length")
    og.add_option("-T", metavar="NUMBER",
                  help="type of alignment: 0=local, 1=overlap (default: 0)")
    og.add_option("-m", metavar="COUNT", help=
                  "maximum initial matches per query position (default: 10)")
    og.add_option("-k", metavar="STEP", help="use initial matches starting at "
                  "every STEP-th position in each query (default: 1)")
    og.add_option("-P", metavar="THREADS",
                  help="number of parallel threads")
    og.add_option("-Q", metavar="NUMBER",
                  help="input format: 0=fasta, 1=fastq-sanger")
    op.add_option_group(og)
    (opts, args) = op.parse_args()
    if len(args) < 1:
        op.error("I need a lastdb index and query sequences")
    # without sampling, the query files are passed straight to lastal
    # on every training round
    if not opts.sample_number and (len(args) < 2 or "-" in args):
        op.error("sorry, can't use stdin when --sample-number=0")
    # fasta input without a user-supplied matrix: fill in any initial
    # score parameters that were not given
    if not opts.p and (not opts.Q or opts.Q == "0"):
        if not opts.r: opts.r = "5"
        if not opts.q: opts.q = "5"
        if not opts.a: opts.a = "15"
        if not opts.b: opts.b = "3"

    try: lastTrain(opts, args)
    except KeyboardInterrupt: pass  # avoid silly error message
    # "as" form: valid from Python 2.6 on, unlike "except Exception, e"
    # which later versions reject
    except Exception as e:
        prog = os.path.basename(sys.argv[0])
        sys.exit(prog + ": error: " + str(e))
