#!/usr/bin/python3

import glob
import os
import string
import subprocess
import sys
import unittest
import errno
import time
import re
import socket
import tempfile
import binascii

import importlib.machinery
import importlib.util

import parent
import testlib
import testvm

# Don't write bytecode caches (__pycache__/*.pyc) when importing modules,
# including the check-* files that run() loads via SourceFileLoader.
sys.dont_write_bytecode = True
# Inherited by the test subprocesses started in run(): with unbuffered output,
# their stdout and stderr (merged into one capture file) appear in write order.
os.environ['PYTHONUNBUFFERED'] = '1'


class Test:
    """One test method that is run as a subprocess, plus its scheduling/retry state."""

    # Class-level defaults, overridden per instance as the test is run.
    process = None              # Popen of the current run; reset to None before a retry
    retries = 0                 # how many times this test has been re-queued
    retry_when_affected = True
    output = b""                # captured stdout+stderr of the last finished run

    def __init__(self, test_id, command, timeout, nondestructive, retry_when_affected):
        # what to run, and for how long at most
        self.test_id, self.command, self.timeout = test_id, command, timeout
        # scheduling knobs
        self.nondestructive = nondestructive
        self.retry_when_affected = retry_when_affected
        # index of the serial batch (shared machine) this test is bound to, if any
        self.serial_machine = None


def test_name(test):
    return "{0} {1} {2}{3}".format(test.test_id, test.command[0], test.command[-1], " [ND]" if test.nondestructive else "")


def flush_stdout():
    """Flush sys.stdout, retrying every 0.1s for as long as it raises BlockingIOError.

    Presumably stdout can be a non-blocking pipe whose reader is slower than us.
    """
    flushed = False
    while not flushed:
        try:
            sys.stdout.flush()
        except BlockingIOError:
            time.sleep(0.1)
        else:
            flushed = True


def print_test(test, print_tap=True):
    """Echo a finished test's captured output and, optionally, its TAP result line.

    test.output (bytes) is written to stdout unchanged. When *print_tap* is true,
    one TAP line follows, based on test.process.returncode:
    - 0: "ok <name>"
    - 77, or output containing "# SKIP ": "ok <name> <reason>". For returncode
      77 the reason is the last output line; otherwise it is empty.
    - anything else: "not ok <name>"
    """
    for line in test.output.splitlines(keepends=True):
        while line:
            try:
                sys.stdout.buffer.write(line)
                break
            except BlockingIOError as e:
                # a non-blocking stdout may accept only part of the line; resend the rest
                line = line[e.characters_written:]
                time.sleep(0.1)
    flush_stdout()

    if not print_tap:
        return

    returncode = test.process.returncode
    if returncode == 0:
        print("ok " + test_name(test))
    elif returncode == 77 or b"# SKIP " in test.output:
        # If the test was skipped, add the last line (which contains the reason
        # for the skip) to the result. Guard against empty output and
        # non-UTF-8 bytes, either of which used to crash the whole run.
        reason = ""
        lines = test.output.splitlines()
        if returncode == 77 and lines:
            reason = lines[-1].strip().decode(errors="replace")
        print("ok {0} {1}".format(test_name(test), reason))
    else:
        print("not ok " + test_name(test))
    flush_stdout()

def finish_test(opts, test, affected_tests):
    """Print a finished test's result and decide whether it should run again.

    - Passing tests whose check file (test.command[0]) is in *affected_tests*,
      and that allow it via retry_when_affected, are re-run until they have
      passed 3 times.
    - Other passing or skipped (returncode 77) tests are just printed.
    - Failed tests are first piped through the external ``tests-policy``
      command (unless opts.thorough). It may rewrite the output, e.g. to mark
      a known issue with "# SKIP" or to request a retry with "# RETRY <reason>".
    - Failures of affected tests are otherwise not retried. All other failures
      are retried up to 2 more times, unless the output contains
      testlib.UNEXPECTED_MESSAGE.

    Return (retry_reason, exit_code): retry_reason is None, or a bytes string
    explaining why the test is re-queued; exit_code is 1 for a final failure,
    else 0.
    """

    affected = test.command[0] in affected_tests

    # Try affected tests 3 times
    if test.process.returncode == 0 and affected and test.retry_when_affected and test.retries < 2:
        retry_reason = b"test affected tests 3 times"
        test.retries += 1
        test.output += b" # RETRY %i (%s)\n" % (test.retries, retry_reason)
        print_test(test, not opts.list)
        return retry_reason, 0
    elif test.process.returncode in [0, 77]:
        print_test(test, not opts.list)
        return None, 0

    if not opts.thorough:
        cmd = ["tests-policy", testvm.DEFAULT_IMAGE]
        try:
            # Make the TAP result line part of the output. tests-policy gets to
            # see (and possibly rewrite) it, and in this mode print_test() below
            # runs with print_tap=False, so this is the line that gets printed.
            test.output += ("not ok " + test_name(test)).encode()
            proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
            changed = proc.communicate(test.output)[0]
            if proc.returncode == 0:
                if test.output != changed:
                    changed += b"\n"
                test.output = changed
        except OSError as ex:
            # a missing tests-policy is fine; report any other failure to run it
            if ex.errno != errno.ENOENT:
                sys.stderr.write("Couldn't run tests-policy: {0}\n".format(str(ex)))

        # tests-policy marked the failure as skipped (e.g. a known issue)
        if b"# SKIP" in test.output:
            print_test(test, print_tap=False)
            return None, 0

    # do we get a specific retry reason from tests-policy?
    # (raw bytes literals: "\s" is not a valid escape in a plain bytes literal)
    m = re.search(rb"\s*# RETRY (.*)$", test.output, re.MULTILINE)
    if m:
        retry_reason = m.group(1)
        # remove it from the test output: it must not be printed after the 3rd
        # attempt, and it is re-added below together with the retry count
        test.output = re.sub(rb"\s*# RETRY .*\n", b"", test.output)
    elif affected:  # Don't retry affected failed tests
        retry_reason = None
    else:
        # HACK: many tests are unstable, always retry them 3 times
        retry_reason = b"be robust against unstable tests"

    # never retry when the output contains testlib's unexpected-message marker
    unexpected_message = testlib.UNEXPECTED_MESSAGE.encode() in test.output
    if test.retries < 2 and not unexpected_message and retry_reason:
        test.retries += 1
        test.output += b" # RETRY %i (%s)\n" % (test.retries, retry_reason)
        print_test(test, print_tap=opts.thorough)
        return retry_reason, 0

    print_test(test, print_tap=opts.thorough)
    return None, 1

def check_valid(filename):
    """Turn a check file path into a module name, or return None if the name is unusable.

    Only ASCII letters, digits, '-' and '_' are accepted in the base name;
    every '-' is replaced with '_' in the returned name.
    """
    base = os.path.basename(filename)
    permitted = set(string.ascii_letters + string.digits + "-_")
    if any(ch not in permitted for ch in base):
        return None
    return base.replace("-", "_")

def build_command(filename, test, opts):
    """Return the argv list that runs the single test *test* from check file *filename*.

    Flags are derived from *opts* in a fixed order: -t (trace), -v (verbosity),
    --nonet (unless fetching is allowed), -l (list). The test name is always last.
    """
    flag_table = (
        (opts.trace, "-t"),
        (opts.verbosity, "-v"),
        (not opts.fetch, "--nonet"),
        (opts.list, "-l"),
    )
    return [filename] + [flag for enabled, flag in flag_table if enabled] + [test]

class GlobalMachine:
    """A VM with its own virtual network, started once and reused by a batch of serial tests."""

    def __init__(self, restrict=True):
        # restrict is forwarded to VirtNetwork.host(); run() passes `not opts.enable_network`
        self.image = testvm.DEFAULT_IMAGE
        self.network = testvm.VirtNetwork(image=self.image)
        self.networking = self.network.host(restrict=restrict)
        self.machine = self._make_machine()
        if not os.path.exists(self.machine.image_file):
            # image not available locally yet: fetch it first
            self.machine.pull(self.machine.image_file)
        self.machine.start()

    def _make_machine(self):
        """Create (but don't start) a VirtMachine bound to this object's networking."""
        return testvm.VirtMachine(verbose=True, networking=self.networking, image=self.image)

    def reset(self):
        """Replace the VM with a freshly started one, e.g. after a test timed out on it."""
        # It is important to re-use self.networking here, so that the
        # machine keeps its browser and control port.
        self.machine.kill()
        self.machine = self._make_machine()
        self.machine.start()

    def kill(self):
        """Shut down the VM, then its network."""
        self.machine.kill()
        self.network.kill()


def _changed_test_files(test_dir):
    """Return the files under *test_dir* that differ between HEAD and origin/master.

    Tests from these files are treated as "affected" by the change under test
    (see finish_test). An empty list is returned if git is unavailable or
    fails, and also if more than 3 files changed.
    """
    cmd = ["git", "diff", "--name-only", "HEAD", "origin/master", test_dir]
    try:
        r = subprocess.run(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    except OSError as ex:
        sys.stderr.write("Couldn't determine changed tests: {0}\n".format(ex))
        return []
    if r.returncode != 0:
        return []
    changed = [name.decode("utf-8") for name in r.stdout.strip().splitlines()]
    # If more than 3 test files were changed don't consider any of them as affected
    # as it might be a PR that changes more unrelated things.
    if len(changed) > 3:
        return []
    return changed


def _collect_tests(opts):
    """Import every check-* file in opts.test_dir and build Test objects for the selected tests.

    Returns (parallel_tests, serial_tests). Tests marked nondestructive are
    serial; all others are parallel, and are dropped when opts.nondestructive
    is set. Ids are assigned in discovery order, starting at 1.

    Raises ValueError if two check files define a test class of the same name.
    """
    test_loader = unittest.TestLoader()
    parallel_tests = []
    serial_tests = []
    seen_classes = {}
    test_id = 1
    # sorted, so that ids and the order of parallel tests don't depend on directory order
    for filename in sorted(glob.glob(os.path.join(opts.test_dir, "check-*"))):
        name = check_valid(filename)
        if not name or not os.path.isfile(filename):
            continue
        loader = importlib.machinery.SourceFileLoader(name, filename)
        module = importlib.util.module_from_spec(importlib.util.spec_from_loader(loader.name, loader))
        loader.exec_module(module)
        for test_suite in test_loader.loadTestsFromModule(module):
            for case in test_suite:
                # ensure that test classes are unique, so that they can be selected properly
                cls = case.__class__.__name__
                if seen_classes.get(cls) not in [None, filename]:
                    raise ValueError("test class %s in %s already defined in %s" % (cls, filename, seen_classes[cls]))
                seen_classes[cls] = filename

                test_method = getattr(case.__class__, case._testMethodName)
                test_str = "{0}.{1}".format(cls, case._testMethodName)
                # most tests should take much less than 10mins, so default to that;
                # longer tests can be annotated with @timeout(seconds)
                # check the test function first, fall back to the class'es timeout
                test_timeout = getattr(test_method, "__timeout", getattr(case, "__timeout", 600))
                if opts.tests and not any(t in test_str for t in opts.tests):
                    continue
                if test_str in opts.exclude:
                    continue
                nd = getattr(test_method, "_testlib__non_destructive", False)
                rwa = getattr(test_method, "_testlib__retry_when_affected", True)
                test = Test(test_id, build_command(filename, test_str, opts), test_timeout, nd, rwa)
                if nd:
                    serial_tests.append(test)
                elif not opts.nondestructive:
                    parallel_tests.append(test)
                # ids advance even for dropped destructive tests
                test_id += 1
    return parallel_tests, serial_tests


def _bind_to_machine(test, batch_index, ssh_address, web_address):
    """Point *test*'s command at a shared machine and assign it to batch *batch_index*."""
    extra = ["--machine", ssh_address]
    # --browser is optional with --machine; never put None into argv
    if web_address is not None:
        extra += ["--browser", web_address]
    # insert right before the trailing test name, so the executable stays at command[0]
    test.command[-1:-1] = extra
    test.serial_machine = batch_index


def _queued_serial_tests(batch_tests):
    """Return the number of serial tests still waiting in any batch."""
    return sum(len(batch["tests"]) for batch in batch_tests.values())


def _setup_serial_batches(opts, serial_tests, batch_tests):
    """Distribute *serial_tests* into *batch_tests*, starting GlobalMachines as needed.

    Each value of batch_tests (keyed by batch index) is a dict with these keys:
    - "working" - None if the machine is idle
    - "tests"   - list of tests still to run
    - "time"    - combined time all tests took
    - "machine" - the GlobalMachine running the tests (absent with --list or --machine)
    - "started" - added by run() whenever a test starts

    The dict is filled in place, so that machines started before an error are
    still visible to the caller for cleanup.
    """
    if not serial_tests:
        return

    if opts.list:
        # Just build one batch for listing
        batch_tests[0] = {"working": None, "tests": [], "time": 0}
        for test in serial_tests:
            test.serial_machine = 0
            batch_tests[0]["tests"].append(test)
        return

    if opts.machine:
        batch_tests[0] = {"working": None, "tests": [], "time": 0}
        for test in serial_tests:
            batch_tests[0]["tests"].append(test)
            _bind_to_machine(test, 0, opts.machine, opts.browser)
        return

    # Never start more machines than there are tests: otherwise batch_size is 0
    # and every test ends up on the last machine. Always start at least one.
    n_batches = max(1, min(opts.batches, len(serial_tests)))
    batch_size = len(serial_tests) // n_batches
    for i in range(n_batches):
        m = GlobalMachine(restrict=not opts.enable_network)
        batch_tests[i] = {"working": None, "tests": [], "time": 0, "machine": m}
        ssh_address = "{0}:{1}".format(m.machine.ssh_address, m.machine.ssh_port)
        web_address = "{0}:{1}".format(m.machine.web_address, m.machine.web_port)

        if i == n_batches - 1:  # Last machine needs to resolve the rest
            batch = serial_tests[batch_size * i:]
        else:
            batch = serial_tests[batch_size * i:batch_size * (i + 1)]

        for test in batch:
            batch_tests[i]["tests"].append(test)
            _bind_to_machine(test, i, ssh_address, web_address)


def run(opts, image):
    """Discover the selected tests, run them, and print TAP results to stdout.

    Serial (nondestructive) tests are split into batches. Each batch is bound
    to one shared machine and runs one test at a time on it. At most `jobs`
    tests (serial ones included) run concurrently; in --list mode this is 1.
    finish_test() decides whether a finished test is re-queued.

    *image* only seeds the direction in which serial tests are sorted.

    Returns the number of tests that finally failed.
    """
    result = 0
    jobs = 1 if opts.list else opts.jobs
    start_time = time.time()

    # Make sure tests can make relative imports
    sys.path.append(os.path.realpath(opts.test_dir))

    # If a file from `test-dir` was changed in the PR, all tests from it are considered affected
    changed_tests = _changed_test_files(opts.test_dir)
    parallel_tests, serial_tests = _collect_tests(opts)

    # sort serial tests by class/test name, to avoid spurious errors where failures depend on the order of execution
    # but let's make sure we always test them both ways around; hash the image name, which is robust, reproducible, and provides
    # an even distribution of both directions
    serial_tests.sort(key=lambda t: t.command[-1], reverse=bool(binascii.crc32(image.encode()) & 1))

    print("1..{0}".format(len(parallel_tests) + len(serial_tests)))
    flush_stdout()

    serial_tests_len = len(serial_tests)
    batch_tests = {}
    running_tests = []
    try:
        _setup_serial_batches(opts, serial_tests, batch_tests)

        serial_remaining = _queued_serial_tests(batch_tests)
        while serial_remaining or parallel_tests or running_tests:
            made_progress = False
            if len(running_tests) < jobs:
                test = None
                # Prefer an idle shared machine that still has queued serial tests
                for batch in batch_tests.values():
                    if batch["working"] is None and batch["tests"]:
                        test = batch["tests"].pop(0)
                        batch["working"] = True
                        batch["started"] = time.time()
                        serial_remaining = _queued_serial_tests(batch_tests)
                        break
                else:
                    if parallel_tests:
                        test = parallel_tests.pop(0)

                if test is not None:
                    made_progress = True
                    test.outfile = tempfile.TemporaryFile()
                    # "timeout" exits with 124 when the test exceeded its time limit
                    test.process = subprocess.Popen(["timeout", str(test.timeout)] + test.command,
                                                    stdout=test.outfile, stderr=subprocess.STDOUT)
                    running_tests.append(test)

            for test in running_tests.copy():
                poll_result = test.process.poll()
                if poll_result is None:
                    continue
                made_progress = True
                test.outfile.flush()
                test.outfile.seek(0)
                test.output = test.outfile.read()
                test.outfile.close()
                running_tests.remove(test)
                retry_reason, test_result = finish_test(opts, test, changed_tests)
                result += test_result

                batch = None if test.serial_machine is None else batch_tests[test.serial_machine]
                if batch is not None:
                    batch["time"] += time.time() - batch["started"]

                    # sometimes our global machine gets messed up; also, tests that time out don't run cleanup handlers
                    # restart it to avoid an unbounded number of test retries and follow-up errors.
                    # Only machines we started ourselves can be restarted (none with --machine or --list).
                    machine = batch.get("machine")
                    if machine is not None and (poll_result == 124 or (retry_reason and b"test harness" in retry_reason)):
                        # try hard to keep the test output consistent
                        sys.stderr.write("Restarting global machine %s\n" % test.serial_machine)
                        sys.stderr.flush()
                        machine.reset()

                # run again if needed
                if retry_reason:
                    test.output = None
                    test.process = None
                    if batch is not None:
                        batch["tests"].insert(0, test)
                        serial_remaining = _queued_serial_tests(batch_tests)
                    else:
                        parallel_tests.insert(0, test)

                if batch is not None:
                    batch["working"] = None

            # Sleep if we didn't make progress
            if not made_progress and not opts.list:
                time.sleep(0.5)
    finally:
        # shut the shared machines down even if setup or the run loop failed
        for batch in batch_tests.values():
            if "machine" in batch:
                batch["machine"].kill()

    if not opts.list:
        duration = int(time.time() - start_time)
        hostname = socket.gethostname().split(".")[0]
        serial_details = ", ".join("{0}: {1}s".format(index, int(batch["time"]))
                                   for index, batch in batch_tests.items())
        details = "[{0}s on {1}, {2} serial tests: {3}]".format(duration, hostname, serial_tests_len, serial_details)
        print()
        if result > 0:
            print("# {0} TESTS FAILED {1}".format(result, details))
        else:
            print("# TESTS PASSED {0}".format(details))

    return result


def main():
    """Parse the command line, export the test environment variables, and run the tests.

    Returns the number of failed tests, as returned by run().
    """
    jobs = int(os.environ.get("TEST_JOBS", 1))
    parser = testlib.arg_parser(enable_sit=False)
    parser.add_argument('-j', '--jobs', type=int,
                        default=jobs, help="Number of concurrent jobs")
    parser.add_argument('--thorough', action='store_true',
                        help='Thorough mode, no skipping known issues')
    parser.add_argument('-n', '--nondestructive', action='store_true',
                        help='Only consider @nondestructive tests')
    parser.add_argument('--machine', metavar="hostname[:port]",
                        default=None, help="Run tests against an already running machine;  implies --nondestructive")
    parser.add_argument('--browser', metavar="hostname[:port]",
                        default=None, help="When using --machine, use this cockpit web address")
    parser.add_argument('--test-dir', default=testvm.TEST_DIR,
                        help="Directory in which to glob check-* files")
    parser.add_argument('--exclude', action="append", default=[], metavar="TestClass.testName",
                        help="Exclude test (exact match only); can be specified multiple times")
    parser.add_argument('-b', '--batches', type=int,
                        default=max(jobs // 2, 1), help="Number of concurrent batches of nondestructive tests")
    opts = parser.parse_args()

    # With no job slots the scheduler in run() never starts a test and loops
    # forever; zero batches would divide by zero when splitting serial tests.
    if opts.jobs < 1:
        parser.error("--jobs must be at least 1")
    if opts.batches < 1:
        parser.error("--batches must be at least 1")

    if opts.machine:
        if opts.jobs > 1:
            parser.error("--machine cannot be used with concurrent jobs")
        opts.nondestructive = True

    # Tell any subprocesses what we are testing
    if "TEST_REVISION" not in os.environ:
        try:
            r = subprocess.run(["git", "rev-parse", "HEAD"],
                               universal_newlines=True, check=False, stdout=subprocess.PIPE)
        except OSError:
            # git is not installed; the revision is informational only
            r = None
        if r is not None and r.returncode == 0:
            os.environ["TEST_REVISION"] = r.stdout.strip()

    os.environ["TEST_BROWSER"] = os.environ.get("TEST_BROWSER", "chromium")

    image = testvm.DEFAULT_IMAGE
    os.environ["TEST_OS"] = image

    return run(opts, image)


if __name__ == '__main__':
    # main() returns the number of failed tests. Process exit statuses are
    # truncated to 8 bits, so cap the count: otherwise 256 failures would exit 0.
    sys.exit(min(main(), 255))
