#!/usr/bin/python3

import glob
import os
import string
import subprocess
import sys
import unittest
import dataclasses
import errno
import time
import re
import socket
import tempfile
import binascii

import importlib.machinery
import importlib.util

import parent
import testlib
import testvm

# Do not write .pyc files for modules imported from here on (this includes
# the check-* scripts that run() loads as modules)
sys.dont_write_bytecode = True
# The test processes started by run() inherit this environment; ask any
# Python interpreter among them not to buffer its output
os.environ['PYTHONUNBUFFERED'] = '1'


@dataclasses.dataclass
class Test:
    """Bookkeeping for one test method that the runner executes as a subprocess."""
    # sequence number; it is the first component of the name in the TAP result line
    test_id: int
    # argv of the test: path of the check-* script, option flags, and the
    # "Class.method" test name as the last element
    command: list
    # seconds after which the test gets aborted; run() passes it to timeout(1)
    timeout: int
    # True for tests which run() executes one at a time; marked "[ND]" in the test name
    nondestructive: bool = False
    # handle of the current attempt's process; None before the first start
    # and while the test is queued for a retry
    process: subprocess.Popen = None
    # number of retries that were scheduled so far
    retries: int = 0
    # combined stdout/stderr of the last finished attempt; run() resets it to
    # None while the test is queued for a retry
    output: bytes = b""

def test_name(test):
    return "{0} {1} {2}{3}".format(test.test_id, test.command[0], test.command[-1], " [ND]" if test.nondestructive else "")


def print_test(test, print_tap=True):
    """Write the captured output of a finished test to stdout.

    test: object with the captured `output` (bytes) and, if `print_tap` is
        set, the finished `process`.
    print_tap: if true, also print the TAP result line: "ok" for exit code 0,
        "ok" followed by the skip reason for exit code 77, "ok" for any other
        exit code if the output contains "# SKIP ", and "not ok" otherwise.
    """
    # the output is raw bytes from the test process; write it to the binary
    # layer of stdout so that it is passed through without being decoded
    sys.stdout.buffer.write(test.output)
    sys.stdout.flush()

    if not print_tap:
        return

    returncode = test.process.returncode
    if returncode == 0:
        print("ok " + test_name(test))
    elif returncode == 77 or b"# SKIP " in test.output:
        # If the test was skipped, add the last line (which contains the reason
        # for the skip) to the result
        reason = ""
        if returncode == 77:
            lines = test.output.splitlines()
            # A test may exit with 77 without printing anything, and its
            # output is not guaranteed to be valid UTF-8; neither must
            # bring down the whole test run
            if lines:
                reason = lines[-1].strip().decode(errors="replace")
        print("ok {0} {1}".format(test_name(test), reason))
    else:
        print("not ok " + test_name(test))
    sys.stdout.flush()

def finish_test(opts, test):
    """Print the result of a finished test and decide whether to retry it.

    Tests which exited with 0 (success) or 77 (skipped) are printed right
    away. The output of any other test is piped through the external
    "tests-policy" command, unless opts.thorough is set; if the resulting
    output contains "# SKIP", the failure is ignored. Remaining failures are
    retried until the test has failed three times.

    opts: parsed command line options; reads `list` and `thorough`.
    test: the finished test; its `output` and `retries` get updated.

    Return (retry_reason, exit_code). retry_reason is None if the test is
    done, or a bytes string with the reason if it should run again.
    exit_code is 1 if the test finally failed, otherwise 0.
    """
    if test.process.returncode in [0, 77]:
        print_test(test, not opts.list)
        return None, 0

    if not opts.thorough:
        cmd = ["tests-policy", testvm.DEFAULT_IMAGE]
        try:
            # tests-policy gets to see the TAP result line as well; without
            # --thorough this appended line is the only one that gets printed
            test.output += ("not ok " + test_name(test)).encode()
            proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
            changed = proc.communicate(test.output)[0]
            if proc.returncode == 0:
                if test.output != changed:
                    changed += b"\n"
                test.output = changed
        except OSError as ex:
            # a missing tests-policy is fine, then the output is used as it is
            if ex.errno != errno.ENOENT:
                sys.stderr.write("Couldn't run tests-policy: {0}\n".format(str(ex)))

        if b"# SKIP" in test.output:
            print_test(test, print_tap=False)
            return None, 0

    # do we get a specific retry reason from tests-policy?
    m = re.search(rb"\s*# RETRY (.*)$", test.output, re.MULTILINE)
    if m:
        retry_reason = m.group(1)
        # remove it from test output; we must not print it after the 3rd time, and going to print it separately
        test.output = re.sub(rb"\s*# RETRY .*\n", b"", test.output)
    else:
        # HACK: many tests are unstable, always retry them 3 times
        retry_reason = b"be robust against unstable tests"

    if test.retries < 2:
        test.retries += 1
        test.output += b" # RETRY %i (%s)\n" % (test.retries, retry_reason)
        print_test(test, print_tap=opts.thorough)
        return retry_reason, 0

    # Out of retries, this is a real failure. Terminate the last line of the
    # output so that whatever gets printed next starts on a line of its own.
    if test.output and not test.output.endswith(b"\n"):
        test.output += b"\n"
    print_test(test, print_tap=opts.thorough)
    return None, 1

def check_valid(filename):
    """Derive a module name from the path of a test script.

    Returns the base name of `filename` with "-" replaced by "_", or None if
    the base name contains anything but ASCII letters, digits, "-" and "_".
    """
    base = os.path.basename(filename)
    permitted = set(string.ascii_letters + string.digits + '-_')
    if set(base) - permitted:
        return None
    return base.replace("-", "_")

def build_command(filename, test, opts):
    """Assemble the argv for running a single test of a test script.

    filename: path of the test script; becomes the first element.
    test: name of the test to run; becomes the last element.
    opts: parsed command line options; `trace`, `verbosity`, `fetch` and
        `list` decide which flags go in between.
    """
    optional_flags = (
        ("-t", opts.trace),
        ("-v", opts.verbosity),
        ("--nonet", not opts.fetch),
        ("-l", opts.list),
    )
    enabled = [flag for flag, wanted in optional_flags if wanted]
    return [filename, *enabled, test]

def _load_check_module(name, filename):
    """Execute the test script `filename` as a module called `name` and return it."""
    loader = importlib.machinery.SourceFileLoader(name, filename)
    module = importlib.util.module_from_spec(importlib.util.spec_from_loader(loader.name, loader))
    loader.exec_module(module)
    return module


def _collect_tests(opts):
    """Discover the selected tests in the check-* scripts next to this file.

    Returns (parallel_tests, serial_tests), two lists of Test objects:
    tests marked as nondestructive are serial, all others parallel. The
    tests are numbered consecutively from 1 in discovery order. Deselected
    tests do not consume a number, so that the ids always fit the TAP plan.
    """
    test_loader = unittest.TestLoader()
    parallel_tests = []
    serial_tests = []
    test_id = 1

    for filename in glob.glob(os.path.join(os.path.dirname(__file__), "check-*")):
        name = check_valid(filename)
        if not name or not os.path.isfile(filename):
            continue
        module = _load_check_module(name, filename)
        for test_suite in test_loader.loadTestsFromModule(module):
            for case in test_suite:
                test_method = getattr(case.__class__, case._testMethodName)
                test_str = "{0}.{1}".format(case.__class__.__name__, case._testMethodName)
                # most tests should take much less than 10mins, so default to that;
                # longer tests can be annotated with @timeout(seconds)
                # check the test function first, fall back to the class's timeout
                test_timeout = getattr(test_method, "__timeout", getattr(case, "__timeout", 600))
                if opts.tests and not any(t in test_str for t in opts.tests):
                    continue
                nd = getattr(test_method, "_testlib__non_destructive", False)
                if opts.nondestructive and not nd:
                    # --nondestructive deselects all other tests
                    continue
                test = Test(test_id, build_command(filename, test_str, opts), test_timeout, nd)
                if nd:
                    serial_tests.append(test)
                else:
                    parallel_tests.append(test)
                test_id += 1

    return parallel_tests, serial_tests


def _use_global_machine(tests):
    """Add the addresses of the global machine to the command of each test."""
    # NOTE(review): get_global_machine() is presumed to start the machine on
    # the first call and to hand out that same machine afterwards
    machine = testlib.MachineCase.get_global_machine()
    machine_args = [
        "--machine", "{0}:{1}".format(machine.ssh_address, machine.ssh_port),
        "--browser", "{0}:{1}".format(machine.web_address, machine.web_port),
    ]
    for test in tests:
        # The test name has to stay the last argument, so insert directly in
        # front of it. This also works for a command which has no other
        # flags, i.e. which consists only of the script path and the test name.
        test.command[-1:-1] = machine_args


def run(opts, image):
    """Run all selected tests and print their results in TAP format.

    Tests marked as nondestructive run serially, one at a time, against the
    global machine; all other tests run in parallel, with at most opts.jobs
    test processes (including the serial one) at any time. Failed tests are
    re-queued for as long as finish_test() returns a retry reason.

    opts: parsed command line options; reads `list`, `jobs`, `tests` and
        `nondestructive`, and passes them on to build_command() and
        finish_test().
    image: name of the tested image; here it only determines whether the
        serial tests run in ascending or descending order.

    Returns the number of failed tests.
    """
    result = 0
    jobs = 1 if opts.list else opts.jobs
    start_time = time.time()
    serial_tests_duration = 0

    # Build the list of tests we'll parallelize and the ones we'll run serially
    parallel_tests, serial_tests = _collect_tests(opts)

    # sort serial tests by class/test name, to avoid spurious errors where failures depend on the order of execution
    # but let's make sure we always test them both ways around; hash the image name, which is robust, reproducible, and provides
    # an even distribution of both directions
    serial_tests.sort(key=lambda t: t.command[-1], reverse=bool(binascii.crc32(image.encode()) & 1))

    print("1..{0}".format(len(parallel_tests) + len(serial_tests)))
    sys.stdout.flush()

    if serial_tests and not opts.list:
        _use_global_machine(serial_tests)

    running_tests = []
    # the serial test which is currently running, if any
    serial_test = None
    serial_tests_len = len(serial_tests)
    while serial_tests or parallel_tests or running_tests:
        made_progress = False
        # start at most one new test per iteration, preferring a serial one
        if len(running_tests) < jobs:
            test = None
            if serial_tests and serial_test is None:
                test = serial_tests.pop(0)
                serial_test = test
                serial_test_start = time.time()
            elif parallel_tests:
                test = parallel_tests.pop(0)

            if test:
                made_progress = True
                # collect the output in a file instead of a pipe, so that the
                # test never blocks on us reading it
                test.outfile = tempfile.TemporaryFile()
                test.process = subprocess.Popen(["timeout", str(test.timeout)] + test.command,
                                                stdout=test.outfile, stderr=subprocess.STDOUT)
                running_tests.append(test)

        # iterate over a copy, as finished tests get removed from the list
        for test in running_tests.copy():
            poll_result = test.process.poll()
            if poll_result is not None:
                made_progress = True
                test.outfile.flush()
                test.outfile.seek(0)
                test.output = test.outfile.read()
                test.outfile.close()
                running_tests.remove(test)
                retry_reason, test_result = finish_test(opts, test)
                result += test_result

                if test is serial_test:
                    serial_tests_duration += (time.time() - serial_test_start)

                    # sometimes our global machine gets messed up; also, tests that time out don't run cleanup handlers
                    # restart it to avoid an unbounded number of test retries and follow-up errors
                    # (124 is the exit code of timeout(1) when the time limit was hit)
                    if poll_result == 124 or (retry_reason and b"test harness" in retry_reason):
                        # try hard to keep the test output consistent
                        testlib.MachineCase.kill_global_machine()
                        testlib.MachineCase.get_global_machine()

                # run again if needed
                if retry_reason:
                    test.output = None
                    test.process = None
                    # put it at the front, so that the retry happens right away
                    if test is serial_test:
                        serial_tests.insert(0, test)
                    else:
                        parallel_tests.insert(0, test)

                if test is serial_test:
                    serial_test = None

        # Sleep if we didn't make progress
        if not made_progress and not opts.list:
            time.sleep(0.5)

    if not opts.list:
        testlib.MachineCase.kill_global_machine()

        duration = int(time.time() - start_time)
        hostname = socket.gethostname().split(".")[0]
        details = "[{0}s on {1}, {2} serial tests took {3}s]".format(duration, hostname, serial_tests_len, int(serial_tests_duration))
        print()
        if result > 0:
            print("# {0} TESTS FAILED {1}".format(result, details))
        else:
            print("# TESTS PASSED {0}".format(details))

    return result


def main():
    """Parse the command line, set up the test environment and run the tests.

    The revision, browser and image under test are exported as TEST_REVISION,
    TEST_BROWSER and TEST_OS. The revision defaults to the current git HEAD,
    the browser to "chromium".

    Returns the result of run().
    """
    parser = testlib.arg_parser(enable_sit=False)
    parser.add_argument('--publish', action='store', help="Unused")
    parser.add_argument('-j', '--jobs', type=int, default=os.environ.get("TEST_JOBS", 1),
                        help="Number of concurrent jobs")
    parser.add_argument('--thorough', action='store_true', help='Thorough mode, no skipping known issues')
    parser.add_argument('-n', '--nondestructive', action='store_true', help='Only consider @nondestructive tests')
    opts = parser.parse_args()

    image = testvm.DEFAULT_IMAGE
    # only ask git if the environment does not tell us the revision
    revision = os.environ.get("TEST_REVISION") or subprocess.check_output(
        ["git", "rev-parse", "HEAD"], universal_newlines=True).strip()

    # Tell any subprocesses what we are testing
    testvm.DEFAULT_IMAGE = image
    os.environ.update({
        "TEST_REVISION": revision,
        "TEST_BROWSER": os.environ.get("TEST_BROWSER", "chromium"),
        "TEST_OS": image,
    })

    return run(opts, image)


if __name__ == '__main__':
    # main() returns the number of failed tests, which becomes the exit
    # status. Only the low 8 bits of an exit status reach the parent process,
    # so cap the value; otherwise e.g. 256 failures would look like success.
    sys.exit(min(main(), 255))
