#! /usr/bin/env python
# Copyright 2015 Martin C. Frith
# SPDX-License-Identifier: GPL-3.0-or-later

# References:
# [Fri19] How sequence alignment scores correspond to probability models,
#         MC Frith, Bioinformatics, 2019.

from __future__ import division, print_function

import gzip
import math
import optparse
import os
import random
import signal
import subprocess
import sys
import tempfile

# One-letter protein codes in alphabetical order; guessAlphabet uses this
# as the row/column labels of 20x20 matrices.
proteinAlphabet20 = "ACDEFGHIKLMNPQRSTVWY"

def myOpen(fileName):  # faster than fileinput
    """Return a readable text stream for fileName.

    "-" means standard input, which is returned as-is rather than reopened.
    Names ending in ".gz" are opened through gzip in text mode. Anything
    else is opened with the built-in open().
    """
    if fileName == "-":
        stream = sys.stdin
    elif fileName.endswith(".gz"):
        stream = gzip.open(fileName, "rt")  # xxx dubious for Python2
    else:
        stream = open(fileName)
    return stream

def rootOfIncreasingFunction(func, lowerBound, upperBound, args):
    """Bisect for x where the increasing func(x, *args) crosses zero.

    A left end and a step are kept, and the step is halved every round.
    The left end moves right whenever func is still negative at the probe.
    The search stops when adding the step no longer changes the left end
    (the floating-point limit), and that final probe point is returned.
    """
    left = lowerBound
    step = upperBound - lowerBound
    while True:
        step *= 0.5
        probe = left + step
        if probe <= left:
            return probe
        if func(probe, *args) < 0:
            left = probe

def rootOfDecreasingFunction(func, lowerBound, upperBound, args):
    """Bisect for x where the decreasing func(x, *args) crosses zero.

    This mirrors rootOfIncreasingFunction: the left end moves right while
    func is still positive at the probe. The search stops once the halved
    step no longer changes the left end, and returns that probe point.
    """
    left = lowerBound
    step = upperBound - lowerBound
    while True:
        step *= 0.5
        probe = left + step
        if probe <= left:
            return probe
        if func(probe, *args) > 0:
            left = probe

def homogeneousLetterFreqs(scale, matScores):
    # Solve the simultaneous equations in Section 2.1 of [Fri19]
    # That is, solve E x = 1 for x, where E[i][j] = exp(matScores[i][j] / scale)
    # and 1 is a vector of ones. This uses Gauss-Jordan elimination with
    # partial pivoting on the augmented matrix [E | 1].
    # matScores may be any iterable of rows (e.g. a zip object); it is
    # iterated only once.
    # Returns x as a list of floats (entries are not guaranteed >= 0).
    # Raises ArithmeticError("singular matrix") if a pivot is exactly zero.
    expMat = [[math.exp(j / scale) for j in i] for i in matScores]
    m = [row[:] + [1.0] for row in expMat]  # augmented matrix
    n = len(expMat)
    for k in range(n):
        # Partial pivoting: bring the row with the largest |m[i][k]| up to row k.
        iMax = k
        for i in range(k, n):
            if abs(m[i][k]) > abs(m[iMax][k]):
                iMax = i
        if iMax > k:
            m[k], m[iMax] = m[iMax], m[k]
        if abs(m[k][k]) <= 0:
            raise ArithmeticError("singular matrix")
        # Eliminate column k from every other row. Only columns > k are
        # updated, because column k of other rows is never read again.
        for i in range(n):
            if i != k:
                mul = m[i][k] / m[k][k]
                for j in range(k + 1, n + 1):
                    m[i][j] -= m[k][j] * mul
    # The matrix is now diagonal (apart from ignored entries), so divide out.
    return [m[k][n] / m[k][k] for k in range(n)]

def randomSample(things, sampleSize):
    """Randomly get sampleSize things (or all if fewer).

    Uses reservoir sampling, so `things` is iterated once and may be a
    generator of unknown length. The first sampleSize items fill the
    reservoir. After that, the n-th item (counting from 1) replaces a
    uniformly chosen slot with probability sampleSize / n.
    """
    kept = []
    for seen, item in enumerate(things, 1):
        if seen <= sampleSize:
            kept.append(item)
            continue
        slot = random.randrange(seen)
        if slot < sampleSize:
            kept[slot] = item
    return kept

def writeWords(outFile, words):
    """Write the words, converted with str() and space-separated, as one line.

    outFile=None writes to standard output, as print() does.
    """
    line = " ".join(str(word) for word in words)
    print(line, file=outFile)

def seqInput(fileNames):
    # Yield (sequence, quality) string pairs from FASTA or FASTQ files.
    # With no file names, standard input ("-") is read.  Each file's format
    # is set by its first line starting with ">" (FASTA) or "@" (FASTQ);
    # lines before that are ignored.  FASTA records yield quality "".
    # FASTQ is read as strict 4-line records (header, sequence, "+" line,
    # quality), so multi-line FASTQ is not supported.  Names are discarded.
    if not fileNames:
        fileNames = ["-"]
    for name in fileNames:
        f = myOpen(name)
        seqType = 0  # 0 = format not known yet, 1 = FASTA, 2 = FASTQ
        for line in f:
            if seqType == 0:
                if line[0] == ">":
                    seqType = 1
                    seq = []
                elif line[0] == "@":
                    seqType = 2
                    lineType = 1  # position within the 4-line record
            elif seqType == 1:  # fasta
                if line[0] == ">":
                    yield "".join(seq), ""
                    seq = []
                else:
                    seq.append(line.rstrip())
            elif seqType == 2:  # fastq
                # lineType: 0 = header, 1 = sequence, 2 = "+", 3 = quality
                if lineType == 1:
                    seq = line.rstrip()
                elif lineType == 3:
                    yield seq, line.rstrip()
                lineType = (lineType + 1) % 4
        # Flush the last FASTA record (FASTQ records are yielded eagerly).
        if seqType == 1: yield "".join(seq), ""
        f.close()  # note: for "-" this closes sys.stdin

def isGoodChunk(chunk):
    """Return True if any segment's sequence has a letter other than N/n.

    Each segment is a tuple whose 4th item (index 3) is its sequence text.
    """
    return any(letter not in "Nn"
               for segment in chunk
               for letter in segment[3])

def chunkInput(opts, sequences):
    """Cut (seq, qual) pairs into chunks of opts.sample_length letters.

    A chunk is a list of segments (seqIndex, beg, end, seqPart, qualPart),
    where seqIndex is the input position of the sequence.  A chunk may span
    several consecutive sequences.  Sequences that are empty or all N/n are
    skipped, and chunks with no non-N letter are never yielded.  A final,
    shorter chunk is yielded only if fewer than opts.sample_number full
    chunks were yielded.

    Raises Exception if opts.sample_length is less than 1, which would
    otherwise loop forever.
    """
    wantedLength = opts.sample_length
    if wantedLength < 1:
        raise Exception("--sample-length must be at least 1")
    chunkCount = 0
    chunk = []
    for seqIndex, (seq, qual) in enumerate(sequences):
        if all(letter in "Nn" for letter in seq):
            continue
        seqLength = len(seq)
        beg = 0
        while beg < seqLength:
            length = min(wantedLength, seqLength - beg)
            end = beg + length
            segment = seqIndex, beg, end, seq[beg:end], qual[beg:end]
            chunk.append(segment)
            wantedLength -= length
            if not wantedLength:
                if isGoodChunk(chunk):
                    yield chunk
                    chunkCount += 1
                chunk = []
                wantedLength = opts.sample_length
            beg = end
    # Like full chunks, the leftover chunk must contain real (non-N) letters:
    if chunk and chunkCount < opts.sample_number and isGoodChunk(chunk):
        yield chunk

def writeSegment(outfile, segment):
    """Write one (seqIndex, beg, end, seq, qual) segment as a record.

    The record is named "seqIndex:beg".  FASTQ is written when qual is
    non-empty, otherwise FASTA.  A falsy segment (e.g. None) writes nothing.
    """
    if not segment:
        return
    seqIndex, beg, _end, seq, qual = segment
    name = str(seqIndex) + ":" + str(beg)
    if qual:
        parts = ["@", name, "\n", seq, "\n+\n", qual, "\n"]
    else:
        parts = [">", name, "\n", seq, "\n"]
    outfile.write("".join(parts))

def getSeqSample(opts, queryFiles, outfile):
    """Write a random sample of query-sequence chunks to outfile.

    Up to opts.sample_number chunks are drawn from queryFiles (standard
    input if the list is empty) and sorted back into input order.  Segments
    that abut within the same sequence (next beg == previous end) are
    merged into one record before writing.
    """
    chunks = chunkInput(opts, seqInput(queryFiles))
    sample = sorted(randomSample(chunks, opts.sample_number))
    pending = None  # segment being extended, not yet written
    for chunk in sample:
        for seg in chunk:
            if pending and seg[0] == pending[0] and seg[1] == pending[2]:
                pending = (pending[0], pending[1], seg[2],
                           pending[3] + seg[3], pending[4] + seg[4])
            else:
                writeSegment(outfile, pending)
                pending = seg
    writeSegment(outfile, pending)

def scaleFromHeader(lines):
    """Return the float after "t=" in the first word that starts with it.

    Reading stops at the first match, so the rest of an iterator is left
    unconsumed.  Raises Exception if no such word is found.
    """
    for line in lines:
        matches = [word for word in line.split() if word.startswith("t=")]
        if matches:
            return float(matches[0][2:])
    raise Exception("couldn't read the scale")

def countsFromLastOutput(lines, opts):
    """Sum substitution and transition counts over lastal "c" lines.

    "s" lines set the current strand (5th field of the most recent one).
    Each "c" line holds a flattened substitution-count matrix followed by
    five transition counts.  Alignments whose identities exceed opts.pid
    percent of the first three transition counts are skipped.  When the
    strand is not "+" and opts.S == "1", the matrix is added with both axes
    reversed.

    Returns (countMatrix, transitionCounts + [alignments]).  Pseudocounts
    are 1 per matrix cell and per transition (2 for the 2nd and 3rd
    transition counts); alignments gets none.
    Raises Exception("no alignments") if nothing was counted.
    """
    nTransitions = 5
    # +1 pseudocounts; deletes and inserts cover opens + extensions, so +2:
    tranCounts = [1.0, 2.0, 2.0, 1.0, 1.0]
    countMatrix = None
    alignments = 0  # no pseudocount here
    for line in lines:
        tag = line[0]
        if tag == "s":
            strand = line.split()[4]  # slow?
        elif tag == "c":
            counts = [float(word) for word in line.split()[1:]]
            if not countMatrix:
                matrixSize = len(counts) - nTransitions
                nCols = int(math.sqrt(matrixSize))
                nRows = matrixSize // nCols
                countMatrix = [[1.0] * nCols for _ in range(nRows)]
            identities = sum(counts[k * nCols + k] for k in range(nRows))
            alignmentLength = sum(counts[matrixSize + k] for k in range(3))
            if 100 * identities > opts.pid * alignmentLength:
                continue
            flip = strand != "+" and opts.S == "1"
            for r in range(nRows):
                for c in range(nCols):
                    value = counts[r * nCols + c]
                    if flip:
                        countMatrix[-1-r][-1-c] += value
                    else:
                        countMatrix[r][c] += value
            for k in range(nTransitions):
                tranCounts[k] += counts[matrixSize + k]
            alignments += 1
    if not alignments:
        raise Exception("no alignments")
    return countMatrix, tranCounts + [alignments]

def scoreFromProb(scale, prob):
    """Return int(round(scale * ln(prob))).

    A non-positive prob uses ln = -800 instead, because
    exp(-800) is exactly zero, on my computer.
    """
    logProb = math.log(prob) if prob > 0 else -800
    return int(round(scale * logProb))

def costFromProb(scale, prob):
    """Return the integer cost for prob: the negation of scoreFromProb."""
    score = scoreFromProb(scale, prob)
    return -score

def guessAlphabet(matrixSize):
    """Return the letters labelling a matrixSize x matrixSize matrix.

    4 means DNA ("ACGT") and 20 means protein.  Any other size raises
    Exception.
    """
    if matrixSize == 4:
        return "ACGT"
    elif matrixSize == 20:
        return proteinAlphabet20
    raise Exception("can't handle unusual alphabets")

def writeMatrixHead(outFile, prefix, alphabet, formatString):
    """Write the header line: prefix plus a space, then each letter
    rendered with formatString."""
    columnLabels = [formatString % letter for letter in alphabet]
    writeWords(outFile, [prefix + " "] + columnLabels)

def writeMatrixBody(outFile, prefix, alphabet, matrix, formatString):
    """Write one line per row: prefix + the row's letter, then each entry
    rendered with formatString.  Stops at the shorter of alphabet/matrix."""
    for rowLetter, row in zip(alphabet, matrix):
        cells = [formatString % value for value in row]
        writeWords(outFile, [prefix + rowLetter] + cells)

def writeCountMatrix(outFile, matrix, prefix):
    """Write a labelled count matrix, with up to 12 significant digits per
    entry."""
    labelFormat, cellFormat = "%-14s", "%-14.12g"
    letters = guessAlphabet(len(matrix))
    writeMatrixHead(outFile, prefix, letters, labelFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, cellFormat)

def writeProbMatrix(outFile, matrix, prefix):
    """Write a labelled probability matrix, using %g for each entry."""
    labelFormat, cellFormat = "%-14s", "%-14g"
    letters = guessAlphabet(len(matrix))
    writeMatrixHead(outFile, prefix, letters, labelFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, cellFormat)

def writeScoreMatrix(outFile, matrix, prefix):
    """Write a labelled score matrix, right-aligned in 6-wide columns."""
    cellFormat = "%6s"  # used for both the header and the entries
    letters = guessAlphabet(len(matrix))
    writeMatrixHead(outFile, prefix, letters, cellFormat)
    writeMatrixBody(outFile, prefix, letters, matrix, cellFormat)

def matProbsFromCounts(counts, opts):
    """Turn a substitution count matrix into a probability matrix.

    If opts.revsym is set, the reverse-complement counts are added first.
    If opts.matsym is set, the transpose is then added.  The percent
    identity and both matrices are printed as "#" lines.  Returns each
    count divided by the grand total.
    """
    idx = range(len(counts))
    if opts.revsym:  # add complement (reverse strand) substitutions
        counts = [[counts[a][b] + counts[-1-a][-1-b] for b in idx]
                  for a in idx]
    if opts.matsym:  # symmetrize the substitution matrix
        counts = [[counts[a][b] + counts[b][a] for b in idx] for a in idx]
    identities = sum(counts[a][a] for a in idx)
    total = sum(map(sum, counts))
    probs = [[count / total for count in row] for row in counts]
    print("# substitution percent identity: %g" % (100 * identities / total))
    print()
    print("# count matrix "
          "(query letters = columns, reference letters = rows):")
    writeCountMatrix(sys.stdout, counts, "# ")
    print()
    print("# probability matrix "
          "(query letters = columns, reference letters = rows):")
    writeProbMatrix(sys.stdout, probs, "# ")
    print()
    return probs

def probImbalance(endProb, matchProb, firstDelProb, delExtendProb,
                  firstInsProb, insExtendProb):
    """(RHS - LHS) of Equation (12) in [Fri19], as a function of endProb."""
    delTerm = firstDelProb / (endProb - delExtendProb)
    insTerm = firstInsProb / (endProb - insExtendProb)
    matchTerm = matchProb / (endProb * endProb)
    return 1 - matchTerm - delTerm - insTerm

def balancedEndProb(*args):
    """Find the endProb at which probImbalance is zero.

    args are (matchProb, firstDelProb, delExtendProb, firstInsProb,
    insExtendProb): probImbalance's arguments after endProb.  The root is
    searched between the larger extend probability and 1.
    """
    _, _, delExtendProb, _, insExtendProb = args
    return rootOfIncreasingFunction(probImbalance,
                                    max(delExtendProb, insExtendProb),
                                    1.0, args)

def gapProbsFromCounts(counts, opts):
    """Estimate match and gap probabilities from summed alignment counts.

    counts is (matches, deletes, inserts, delOpens, insOpens, alignments).
    With opts.gapsym, deletions and insertions share pooled estimates.
    The counts and estimates are printed as "#" lines.

    Returns (matchProb, (delOpenProb, delExtendProb),
             (insOpenProb, insExtendProb)).
    """
    matches, deletes, inserts, delOpens, insOpens, alignments = counts
    totalGaps = deletes + inserts
    totalOpens = delOpens + insOpens
    # alignments + 1: a +1 pseudocount on alignment ends
    denominator = matches + totalOpens + (alignments + 1)
    matchProb = matches / denominator
    if not opts.gapsym:
        delOpenProb = delOpens / denominator
        insOpenProb = insOpens / denominator
        delGrowProb = (deletes - delOpens) / deletes
        insGrowProb = (inserts - insOpens) / inserts
    else:
        delOpenProb = insOpenProb = totalOpens / denominator / 2
        delGrowProb = insGrowProb = (totalGaps - totalOpens) / totalGaps
    print("# aligned letter pairs: %.12g" % matches)
    print("# deletes: %.12g" % deletes)
    print("# inserts: %.12g" % inserts)
    print("# delOpens: %.12g" % delOpens)
    print("# insOpens: %.12g" % insOpens)
    print("# alignments:", alignments)
    print("# mean delete size: %g" % (deletes / delOpens))
    print("# mean insert size: %g" % (inserts / insOpens))
    print("# matchProb: %g" % matchProb)
    print("# delOpenProb: %g" % delOpenProb)
    print("# insOpenProb: %g" % insOpenProb)
    print("# delExtendProb: %g" % delGrowProb)
    print("# insExtendProb: %g" % insGrowProb)
    print()
    return matchProb, (delOpenProb, delGrowProb), (insOpenProb, insGrowProb)

def gapRatiosFromProbs(matchProb, delProbs, insProbs):
    """Convert match/gap probabilities into probability ratios.

    delProbs and insProbs are (openProb, extendProb) pairs.  The end
    probability is chosen by balancedEndProb; it is probably negligibly
    less than 1.  Returns (matchRatio, (firstDelRatio, delExtendRatio),
    (firstInsRatio, insExtendRatio)).
    """
    delOpenProb, delGrowProb = delProbs
    insOpenProb, insGrowProb = insProbs
    # Probability of a gap of length exactly 1: open, then close.
    firstDelProb = delOpenProb * (1 - delGrowProb)
    firstInsProb = insOpenProb * (1 - insGrowProb)
    endProb = balancedEndProb(matchProb, firstDelProb, delGrowProb,
                              firstInsProb, insGrowProb)
    matchRatio = matchProb / (endProb * endProb)
    delRatios = firstDelProb / endProb, delGrowProb / endProb
    insRatios = firstInsProb / endProb, insGrowProb / endProb
    return matchRatio, delRatios, insRatios

def scoreFromLetterProbs(scale, matchRatio, pairProb, rowProb, colProb):
    """Integer score for one letter pair: Equation (4) in [Fri19]."""
    oddsRatio = pairProb / (rowProb * colProb)
    return scoreFromProb(scale, matchRatio * oddsRatio)

def matScoresFromProbs(scale, matchRatio, matProbs, rowProbs, colProbs):
    """Build the integer score matrix, one row per entry of rowProbs."""
    scores = []
    for r, rowProb in enumerate(rowProbs):
        row = []
        for c, colProb in enumerate(colProbs):
            row.append(scoreFromLetterProbs(scale, matchRatio, matProbs[r][c],
                                            rowProb, colProb))
        scores.append(row)
    return scores

def gapCostsFromProbRatios(scale, firstGapRatio, gapExtendRatio):
    """Return (firstGapCost, gapExtendCost), each at least 1."""
    # The alignment parameter is the path parameter plus the first-gap
    # ratio, as in Supplementary section 3.1 of [Fri19]:
    alignExtendRatio = gapExtendRatio + firstGapRatio
    costs = [costFromProb(scale, ratio)
             for ratio in (firstGapRatio, alignExtendRatio)]
    return tuple(max(cost, 1) for cost in costs)

def imbalanceFromGap(scale, firstGapCost, gapExtendCost):
    """Gap contribution to the length-probability balance, from integer costs."""
    firstRatio = math.exp(-firstGapCost / scale)
    # The path parameter is the alignment parameter minus the first-gap
    # ratio, as in Supplementary section 3.1 of [Fri19]:
    pathExtendRatio = math.exp(-gapExtendCost / scale) - firstRatio
    return firstRatio / (1 - pathExtendRatio)

def scoreImbalance(scale, matScores, delCosts, insCosts):
    """C' - 1, where C' is defined in Equation (13) of [Fri19].

    delCosts and insCosts are (firstGapCost, gapExtendCost) pairs.
    """
    delTerm = imbalanceFromGap(scale, *delCosts)
    insTerm = imbalanceFromGap(scale, *insCosts)
    freqTotal = sum(homogeneousLetterFreqs(scale, matScores))
    return 1 / freqTotal + delTerm + insTerm - 1

def balancedScale(imbalanceFunc, nearScale, args):
    # Find a scale, near nearScale, with balanced length probability
    # i.e. a root of imbalanceFunc(scale, *args).
    # The search steps outwards from nearScale by factors of `bump`, probing
    # first below and then above on each round.  It stops at the first probe
    # where imbalanceFunc's sign differs from its sign at nearScale, and that
    # bracketing interval is then bisected.
    # Returns nearScale if it is already exactly balanced, or 0.0 if no sign
    # change is found before the upper probe reaches 2 * nearScale.
    bump = 1.000001
    rootFinders = rootOfDecreasingFunction, rootOfIncreasingFunction
    value = imbalanceFunc(nearScale, *args)
    if abs(value) <= 0:
        return nearScale
    oldLower = oldUpper = nearScale
    while oldUpper < 2 * nearScale:  # xxx ???
        newLower = oldLower / bump
        lowerValue = imbalanceFunc(newLower, *args)
        if (lowerValue < 0) != (value < 0):
            # Going from newLower up to oldLower, the function rises if
            # value > 0 (index 1 = increasing finder), else it falls.
            finder = rootFinders[value > 0]
            return finder(imbalanceFunc, newLower, oldLower, args)
        oldLower = newLower
        newUpper = oldUpper * bump
        upperValue = imbalanceFunc(newUpper, *args)
        if (upperValue < 0) != (value < 0):
            # Going from oldUpper up to newUpper, the function rises if
            # value < 0, else it falls.
            finder = rootFinders[value < 0]
            return finder(imbalanceFunc, oldUpper, newUpper, args)
        oldUpper = newUpper
    return 0.0

def scoresAndScale(originalScale, matParams, delRatios, insRatios):
    """Round probability ratios to integer scores, doubling the scale as needed.

    At the trial scale (starting with originalScale), integer substitution
    scores and gap costs are computed, and then a nearby balanced scale is
    sought.  If none is found, or it implies a negative letter frequency,
    the trial scale is doubled and everything is recomputed.

    Returns (matScores, delCosts, insCosts, scale).
    """
    trialScale = originalScale
    while True:
        matScores = matScoresFromProbs(trialScale, *matParams)
        delCosts = gapCostsFromProbRatios(trialScale, *delRatios)
        insCosts = gapCostsFromProbRatios(trialScale, *insRatios)
        scale = balancedScale(scoreImbalance, trialScale,
                              (matScores, delCosts, insCosts))
        if scale > 0:
            freqs = (homogeneousLetterFreqs(scale, zip(*matScores)) +
                     homogeneousLetterFreqs(scale, matScores))
            if all(freq >= 0 for freq in freqs):
                return matScores, delCosts, insCosts, scale
        print("# the integer-rounded scores are too inaccurate: "
              "doubling the scale")
        trialScale *= 2

def writeGapCosts(delCosts, insCosts, isLastFormat, outFile):
    """Print gap costs as four lines: existence costs, then extension costs.

    delCosts and insCosts are (firstGapCost, gapExtendCost) pairs; the
    existence cost is their difference.  isLastFormat selects "#last -a/-A/
    -b/-B" option lines rather than descriptive comments.  outFile=None
    means standard output.
    """
    delInit, delGrow = delCosts
    insInit, insGrow = insCosts
    values = (delInit - delGrow, insInit - insGrow, delGrow, insGrow)
    if isLastFormat:
        labels = ("#last -a", "#last -A", "#last -b", "#last -B")
    else:
        labels = ("# delExistCost:", "# insExistCost:",
                  "# delExtendCost:", "# insExtendCost:")
    for label, value in zip(labels, values):
        print(label, value, file=outFile)

def tryToMakeChildProgramsFindable():
    """Prepend this script's directory and ../src to PATH."""
    here = os.path.dirname(__file__)
    srcDir = os.path.join(here, os.pardir, "src")
    # put them first, to avoid getting older versions of LAST:
    os.environ["PATH"] = os.pathsep.join([here, srcDir, os.environ["PATH"]])

def readLastalProgName(lastdbIndexName):
    """Choose "lastal8" or "lastal" from the index's .prj file.

    The last "integersize=" line wins; if there is none, 32 is assumed.
    """
    with open(lastdbIndexName + ".prj") as prjFile:
        sizes = [line.split("=")[1].strip()
                 for line in prjFile if line.startswith("integersize=")]
    bitsPerInt = sizes[-1] if sizes else "32"
    return "lastal8" if bitsPerInt == "64" else "lastal"

def fixedLastalArgs(opts, lastalProgName):
    """Build the lastal arguments that stay the same in every training round.

    Each option in D E s S C T m k P X Q that is set is passed as
    "-<letter><value>", in that order, followed by one "v" per verbosity
    level.
    """
    args = [lastalProgName, "-j7"]
    for letter in "DEsSCTmkPXQ":
        value = getattr(opts, letter)
        if value:
            args.append("-" + letter + value)
    if opts.verbose:
        args.append("-" + "v" * opts.verbose)
    return args

def process(args, inStream):
    """Start args as a child process that reads inStream, with its stdout
    piped back to us as text."""
    popenOptions = dict(stdin=inStream, stdout=subprocess.PIPE,
                        universal_newlines=True)
    return subprocess.Popen(args, **popenOptions)

def versionFromLastal():
    """Return lastal's version: the 2nd word printed by "lastal --version".

    The child's output pipe is closed and the child is waited for, so no
    file descriptor or zombie process is left behind.
    Raises Exception if the output has fewer than two words.
    """
    args = ["lastal", "--version"]
    proc = process(args, None)
    output = proc.stdout.read()
    proc.stdout.close()
    proc.wait()
    words = output.split()
    if len(words) < 2:
        raise Exception("can't get the lastal version")
    return words[1]

def lastSplitProcess(opts, proc):
    """Pipe proc's output through last-split and, if opts.postmask, through
    last-postmask as well.

    Returns the last process in the pipeline; read its stdout for results.
    After each downstream child starts, our copy of the upstream stdout is
    closed.  That way, if the downstream program exits early, the upstream
    one gets SIGPIPE instead of blocking (the standard Popen pipeline recipe).
    """
    splitArgs = ["last-split", "-n", "-m0.01"]  # xxx ???
    upstream = proc
    proc = process(splitArgs, upstream.stdout)
    upstream.stdout.close()
    if opts.postmask:
        maskArgs = ["last-postmask"]
        upstream = proc
        proc = process(maskArgs, upstream.stdout)
        upstream.stdout.close()
    return proc

def doTraining(opts, args):
    # Iteratively re-estimate score parameters.  Each round runs lastal,
    # piped through last-split (and last-postmask if opts.postmask).  It
    # counts substitutions and gaps in the output, turns the counts into
    # probabilities and then integer scores, and reruns lastal with them.
    # Training stops when a parameter set repeats an earlier one.  The final
    # parameters, rescaled to externalScale, are then printed as "#last"
    # option lines plus a score matrix.
    # args: [lastdb name, query file name(s)...], appended to lastal's args.
    tryToMakeChildProgramsFindable()
    lastalProgName = readLastalProgName(args[0])
    scaleIncrease = 20  # while training, up-scale the scores by this amount
    lastalVersion = versionFromLastal()

    # First round: use the user's (or default) initial scoring options.
    lastalArgs = fixedLastalArgs(opts, lastalProgName)
    if opts.r: lastalArgs.append("-r" + opts.r)
    if opts.q: lastalArgs.append("-q" + opts.q)
    if opts.p: lastalArgs.append("-p" + opts.p)
    if opts.a: lastalArgs.append("-a" + opts.a)
    if opts.b: lastalArgs.append("-b" + opts.b)
    if opts.A: lastalArgs.append("-A" + opts.A)
    if opts.B: lastalArgs.append("-B" + opts.B)
    lastalArgs += args
    proc = process(lastalArgs, None)
    proc = lastSplitProcess(opts, proc)

    if opts.scale:
        # --scale is "units of 1/S bits"; dividing by ln 2 converts it to the
        # natural-log scale that the score formulas use.
        externalScale = opts.scale / math.log(2)
    else:
        # Read "t=" from the output header; countsFromLastOutput below then
        # continues reading the same stream.
        externalScale = scaleFromHeader(proc.stdout)

    internalScale = externalScale * scaleIncrease
    oldParameters = []  # parameter sets seen so far, to detect convergence

    print("# lastal version:", lastalVersion)
    print("# maximum percent identity:", opts.pid)
    print("# scale of score parameters:", externalScale)
    print("# scale used while training:", internalScale)
    print()

    while True:
        print("#", *lastalArgs)
        print()
        sys.stdout.flush()  # keep our output ordered relative to children's
        matCounts, gapCounts = countsFromLastOutput(proc.stdout, opts)
        gapProbs = gapProbsFromCounts(gapCounts, opts)
        matProbs = matProbsFromCounts(matCounts, opts)
        matchRatio, delRatios, insRatios = gapRatiosFromProbs(*gapProbs)
        rowProbs = [sum(i) for i in matProbs]  # marginals of matrix rows
        colProbs = [sum(i) for i in zip(*matProbs)]  # marginals of columns
        matParams = matchRatio, matProbs, rowProbs, colProbs
        sas = scoresAndScale(internalScale, matParams, delRatios, insRatios)
        matScores, delCosts, insCosts, scale = sas
        writeGapCosts(delCosts, insCosts, False, None)
        print()
        print("# score matrix "
              "(query letters = columns, reference letters = rows):")
        writeScoreMatrix(sys.stdout, matScores, "# ")
        print()
        parameters = delCosts, insCosts, matScores
        if parameters in oldParameters: break  # converged (or cycling)
        oldParameters.append(parameters)
        # Next round: the new gap costs and score matrix are written to
        # lastal's stdin, which "-p-" points it at.
        lastalArgs = fixedLastalArgs(opts, lastalProgName)
        lastalArgs.append("-t{0:.6}".format(scale))
        lastalArgs.append("-p-")
        lastalArgs += args
        proc = process(lastalArgs, subprocess.PIPE)
        writeGapCosts(delCosts, insCosts, True, proc.stdin)
        writeScoreMatrix(proc.stdin, matScores, "")
        proc.stdin.close()
        proc = lastSplitProcess(opts, proc)

    # Redo the final integer rounding at the user-facing (external) scale.
    sas = scoresAndScale(externalScale, matParams, delRatios, insRatios)
    matScores, delCosts, insCosts, scale = sas
    if opts.X: print("#last -X", opts.X)
    if opts.Q: print("#last -Q", opts.Q)
    print("#last -t{0:.6}".format(scale))
    writeGapCosts(delCosts, insCosts, True, None)
    if opts.s: print("#last -s", opts.s)
    if opts.S: print("#last -S", opts.S)
    print("# score matrix "
          "(query letters = columns, reference letters = rows):")
    writeScoreMatrix(sys.stdout, matScores, "")

def lastTrain(opts, args):
    """Run training on args = [lastdb name, query files...].

    If opts.sample_number is non-zero, a random sample of the query
    sequences is first written to a temporary file, and training uses that
    file instead of the query files.  The temporary file is always removed.
    """
    if not opts.sample_number:
        doTraining(opts, args)
        return
    random.seed(math.pi)  # fixed seed, so the sample is reproducible
    refName = args[0]
    queryFiles = args[1:]
    # Create the file before entering "try", so that if creation fails,
    # "finally" doesn't hit an unbound name and hide the real error:
    sampleFile = tempfile.NamedTemporaryFile("w", delete=False)
    try:
        with sampleFile:
            getSeqSample(opts, queryFiles, sampleFile)
        doTraining(opts, [refName, sampleFile.name])
    finally:
        os.remove(sampleFile.name)

if __name__ == "__main__":
    # Command-line entry point: parse the options, fill in default initial
    # scores for input without quality data, then run the training.
    # Errors become "<prog>: error: <message>" and a failing exit status.
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)  # avoid silly error message
    usage = "%prog [options] lastdb-name sequence-file(s)"
    description = "Try to find suitable score parameters for aligning the given sequences."
    op = optparse.OptionParser(usage=usage, description=description)
    op.add_option("-v", "--verbose", action="count",
                  help="show more details of intermediate steps")

    og = optparse.OptionGroup(op, "Training options")
    og.add_option("--revsym", action="store_true",
                  help="force reverse-complement symmetry")
    og.add_option("--matsym", action="store_true",
                  help="force symmetric substitution matrix")
    og.add_option("--gapsym", action="store_true",
                  help="force insertion/deletion symmetry")
    og.add_option("--pid", type="float", default=100, help=
                  "skip alignments with > PID% identity (default: %default)")
    og.add_option("--postmask", type="int", metavar="NUMBER", default=1, help=
                  "skip mostly-lowercase alignments (default=%default)")
    og.add_option("--sample-number", type="int", default=500, metavar="N",
                  help="number of random sequence samples (default: %default)")
    og.add_option("--sample-length", type="int", default=2000, metavar="L",
                  help="length of each sample (default: %default)")
    og.add_option("--scale", type="float", metavar="S",
                  help="output scores in units of 1/S bits")
    op.add_option_group(og)

    # These options are only passed to the first lastal run (see doTraining).
    og = optparse.OptionGroup(op, "Initial parameter options")
    og.add_option("-r", metavar="SCORE",
                  help="match score (default: 6 if Q>=1, else 5)")
    og.add_option("-q", metavar="COST",
                  help="mismatch cost (default: 18 if Q>=1, else 5)")
    og.add_option("-p", metavar="NAME", help="match/mismatch score matrix")
    og.add_option("-a", metavar="COST",
                  help="gap existence cost (default: 21 if Q>=1, else 15)")
    og.add_option("-b", metavar="COST",
                  help="gap extension cost (default: 9 if Q>=1, else 3)")
    og.add_option("-A", metavar="COST", help="insertion existence cost")
    og.add_option("-B", metavar="COST", help="insertion extension cost")
    op.add_option_group(og)

    # These options are passed to every lastal run (see fixedLastalArgs).
    og = optparse.OptionGroup(op, "Alignment options")
    og.add_option("-D", metavar="LENGTH",
                  help="query letters per random alignment (default: 1e6)")
    og.add_option("-E", metavar="EG2",
                  help="maximum expected alignments per square giga")
    og.add_option("-s", metavar="STRAND", help=
                  "0=reverse, 1=forward, 2=both (default: 2 if DNA, else 1)")
    og.add_option("-S", metavar="NUMBER", default="1", help=
                  "score matrix applies to forward strand of: " +
                  "0=reference, 1=query (default: %default)")
    og.add_option("-C", metavar="COUNT", help=
                  "omit gapless alignments in COUNT others with > score-per-length")
    og.add_option("-T", metavar="NUMBER",
                  help="type of alignment: 0=local, 1=overlap (default: 0)")
    og.add_option("-m", metavar="COUNT", help=
                  "maximum initial matches per query position (default: 10)")
    og.add_option("-k", metavar="STEP", help="use initial matches starting at "
                  "every STEP-th position in each query (default: 1)")
    og.add_option("-P", metavar="THREADS",
                  help="number of parallel threads")
    og.add_option("-X", metavar="NUMBER", help="N/X is ambiguous in: "
                  "0=neither sequence, 1=reference, 2=query, 3=both "
                  "(default=0)")
    og.add_option("-Q", metavar="NAME",
                  help="input format: fastx, sanger (default=fasta)")
    op.add_option_group(og)

    (opts, args) = op.parse_args()
    if len(args) < 1:
        op.error("I need a lastdb index and query sequences")
    # Without sampling, the query file names go straight to lastal, and
    # training reruns lastal on them each round, so stdin can't be used.
    if not opts.sample_number and (len(args) < 2 or "-" in args):
        op.error("sorry, can't use stdin when --sample-number=0")
    # Without a matrix (-p) and with no quality-aware -Q, use the "else"
    # defaults listed in the help texts above.
    if not opts.p and (not opts.Q or opts.Q in ("0", "fastx", "keep")):
        if not opts.r: opts.r = "5"
        if not opts.q: opts.q = "5"
        if not opts.a: opts.a = "15"
        if not opts.b: opts.b = "3"

    try: lastTrain(opts, args)
    except KeyboardInterrupt: pass  # avoid silly error message
    except Exception as e:
        prog = os.path.basename(sys.argv[0])
        sys.exit(prog + ": error: " + str(e))
