#! /usr/bin/env python3
# Author: Martin C. Frith 2021
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import print_function

import argparse
import gzip
import signal
import sys

def openFile(fileName):
    """Return a readable text stream for fileName.

    "-" selects standard input, a name ending in ".gz" is opened with
    gzip in text mode, and anything else is opened as an ordinary file.
    """
    if fileName == "-":
        return sys.stdin
    isCompressed = fileName.endswith(".gz")
    if not isCompressed:
        return open(fileName)
    return gzip.open(fileName, "rt")

def alignmentInput(lines):
    """Split MAF-style lines into alignments, echoing comment lines.

    Lines beginning with "#" are printed to standard output unchanged and
    do not end the alignment being read.  Blank lines (empty or
    whitespace-only) separate alignments.

    Yields one (ranges, alnLines) pair per alignment:
      alnLines -- the alignment's non-blank, non-comment lines, verbatim
      ranges   -- one (seqName, beg, end) tuple per line beginning with
                  "s", with reverse-strand coordinates converted to
                  forward-strand coordinates

    Raises IndexError or ValueError if an "s" line has too few fields or
    non-integer coordinates.
    """
    ranges = []
    alnLines = []
    for line in lines:
        # startswith (rather than line[0]) so that an empty string is
        # handled as a blank separator instead of raising IndexError.
        if line.startswith("#"):
            print(line, end="")
            continue

        fields = line.split()
        if not fields:
            # A blank line ends the current alignment, if there is one.
            if alnLines:
                yield ranges, alnLines
            ranges = []
            alnLines = []
            continue

        alnLines.append(line)
        if line.startswith("s"):
            seqName = fields[1]
            beg = int(fields[2])
            span = int(fields[3])
            if fields[4] == "-":
                # Reverse strand: fields[5] is the sequence length, and the
                # start is counted from the other end, so flip the interval.
                end = int(fields[5]) - beg
                beg = end - span
            else:
                end = beg + span
            ranges.append((seqName, beg, end))

    # The last alignment need not be followed by a blank line.
    if alnLines:
        yield ranges, alnLines

def isNearInAllSeqs(iRanges, jRanges, maxDistance):
    """Tell whether two alignments are close to each other in every sequence.

    iRanges and jRanges are sequences of (seqName, beg, end) tuples, which
    are compared pairwise.  Returns False if any pair has different
    sequence names, or is separated by a gap greater than maxDistance;
    otherwise returns True (including when there are no pairs).
    """
    return not any(
        iSeq != jSeq
        or jBeg > iEnd + maxDistance
        or iBeg > jEnd + maxDistance
        for (iSeq, iBeg, iEnd), (jSeq, jBeg, jEnd) in zip(iRanges, jRanges)
    )

def linkOneAln(args, alns, i):
    """Record, in alns[i]'s link list, the indices of nearby alignments.

    Each item of alns is a (ranges, lines, rank, links) tuple.  Only items
    whose index is within args.tween + 1 of i are examined.  Item j is
    linked if it is near item i in all sequences (isNearInAllSeqs with
    args.distance) and its rank differs from item i's rank by at most
    args.tween + 1.

    Each scan stops early based on the first range only; these early exits
    rely on alns being ordered by first range, as main arranges.
    """
    rankLimit = args.tween + 1
    lowest = max(i - rankLimit, 0)  # smallest index to examine
    highest = min(i + rankLimit + 1, len(alns))  # one past the largest index

    myRanges, _myLines, myRank, myLinks = alns[i]
    mySeq, myBeg, myEnd = myRanges[0]

    # Scan backward from i - 1.  The loop test is written with k + 1 so
    # that it also behaves correctly when lowest is not an integer.
    k = i - 1
    while k + 1 > lowest:
        otherRanges, _otherLines, otherRank, _otherLinks = alns[k]
        otherSeq, _otherBeg, otherEnd = otherRanges[0]
        if otherSeq < mySeq or otherEnd + args.distance < myBeg:
            break
        if (isNearInAllSeqs(myRanges, otherRanges, args.distance)
                and abs(otherRank - myRank) <= rankLimit):
            myLinks.append(k)
        k -= 1

    # Scan forward from i + 1.
    k = i + 1
    while k < highest:
        otherRanges, _otherLines, otherRank, _otherLinks = alns[k]
        otherSeq, otherBeg, _otherEnd = otherRanges[0]
        if mySeq < otherSeq or myEnd + args.distance < otherBeg:
            break
        if (isNearInAllSeqs(myRanges, otherRanges, args.distance)
                and abs(otherRank - myRank) <= rankLimit):
            myLinks.append(k)
        k += 1

def connectedComponent(alns, isMarked, i):
    """Yield the indices reachable from item i by following links.

    The links of item j are read from alns[j][3].  isMarked is a list of
    booleans that is updated in place: every index that gets yielded is
    set to True, and indices that are already True are not followed.
    """
    isMarked[i] = True
    pending = [i]
    while pending:
        current = pending.pop()
        yield current
        for neighbor in alns[current][3]:
            if isMarked[neighbor]:
                continue
            isMarked[neighbor] = True
            pending.append(neighbor)

def connectedComponents(alns):
    """Yield each group of linked items as a sorted list of indices.

    Groups are found by following the links in alns[j][3], starting from
    each index (in increasing order) that is not already in a group.
    """
    isMarked = [False] * len(alns)
    for start in range(len(alns)):
        if isMarked[start]:
            continue
        members = list(connectedComponent(alns, isMarked, start))
        members.sort()
        yield members

def sortKey1(aln):
    """Sort key for a (ranges, lines) pair: its second range."""
    seqRanges, _alnLines = aln
    secondRange = seqRanges[1]
    return secondRange

def main(args):
    """Print the alignments that belong to large enough linked groups.

    Reads alignments from args.infile, links each one to nearby alignments
    (see linkOneAln), and prints every alignment in a connected group of at
    least args.count members.  Each printed alignment is followed by a
    blank line.
    """
    infile = openFile(args.infile)
    try:
        # First sort by the second sequence's range, so that each
        # alignment's position in that order can serve as its "rank".
        alns = sorted(alignmentInput(infile), key=sortKey1)
    finally:
        # Don't leak the file handle (but leave standard input alone).
        if infile is not sys.stdin:
            infile.close()

    # Re-sort by the first sequence's range, attaching to each alignment
    # its rank and an initially empty list of links.
    alns = sorted(aln + (rank, []) for rank, aln in enumerate(alns))

    for i in range(len(alns)):
        linkOneAln(args, alns, i)

    for c in connectedComponents(alns):
        if len(c) >= args.count:
            for i in c:
                ranges, lines, rank, links = alns[i]
                print(*lines, sep="")

if __name__ == "__main__":
    # Restore the default SIGPIPE action, to avoid a silly error message
    # when the output is piped to a program that exits early.  SIGPIPE is
    # not defined on every platform (e.g. Windows), so check for it first.
    if hasattr(signal, "SIGPIPE"):
        signal.signal(signal.SIGPIPE, signal.SIG_DFL)
    descr = "Omit any alignment not near (in both sequences) to other alignments."
    ap = argparse.ArgumentParser(description=descr, formatter_class=
                                 argparse.ArgumentDefaultsHelpFormatter)
    ap.add_argument("-c", "--count", type=int, default=3, metavar="C",
                    help="minimum number of linked alignments")
    ap.add_argument("-d", "--distance", type=float, default=1000000,
                    metavar="D", help="maximum distance")
    ap.add_argument("-t", "--tween", type=float, default=5, metavar="T",
                    help="maximum number of alignments in between")
    ap.add_argument("infile", help="file of sequence alignments in MAF format")
    args = ap.parse_args()
    main(args)
