#!/usr/bin/python3

import glob
import os
import string
import subprocess
import sys
import unittest
import errno
import time
import socket
import tempfile
import binascii

import importlib.machinery
import importlib.util

import parent
parent.ensure_bots()  # NOQA: testvm lives in bots/
import testlib
import testvm

# Do not write .pyc files for modules imported from here on
# (presumably to keep the check-* files loaded in run() from leaving caches behind)
sys.dont_write_bytecode = True
# Exported through the environment, so Python child processes inherit it
# and run with unbuffered stdout/stderr
os.environ['PYTHONUNBUFFERED'] = '1'


class Test:
    """A single test invocation plus the bookkeeping the scheduler keeps for it."""

    # Class-level defaults; instances shadow them as soon as they are assigned
    process = None               # subprocess handle while/after the test ran
    retries = 0                  # number of retries performed so far
    retry_when_affected = True
    output = b""                 # captured stdout+stderr of the last run

    def __init__(self, test_id, command, timeout, nondestructive, retry_when_affected):
        # identification and how to launch the test
        self.test_id, self.command = test_id, command
        self.timeout = timeout
        # scheduling properties
        self.nondestructive = nondestructive
        self.retry_when_affected = retry_when_affected
        # batch/machine index for serial tests; None until one gets assigned
        self.serial_machine = None


def test_name(test):
    return "{0} {1} {2}{3}".format(
        test.test_id,
        test.command[0],
        test.command[-1],
        " [ND@{0}]".format(test.serial_machine) if test.nondestructive else "",
    )


def flush_stdout():
    """Flush sys.stdout, retrying every 0.1s for as long as it raises BlockingIOError."""
    flushed = False
    while not flushed:
        try:
            sys.stdout.flush()
        except BlockingIOError:
            # stdout cannot take the data right now; give the reader some time
            time.sleep(0.1)
        else:
            flushed = True


def print_test(test, print_tap=True, retry_reason="", skip_reason=""):
    """Write a test's captured output to stdout, followed by its result line.

    test: object with .output (bytes) and, when print_tap is set, .process
    print_tap: if true, finish with a TAP "ok"/"not ok" line; otherwise only
        print the given reasons
    retry_reason, skip_reason: optional annotations appended to the result line
    """
    # Raw bytes go straight to the binary layer of stdout. A partial write
    # raises BlockingIOError; continue after the part that was written.
    for pending in test.output.strip().splitlines(keepends=True):
        while pending:
            try:
                sys.stdout.buffer.write(pending)
            except BlockingIOError as exc:
                pending = pending[exc.characters_written:]
                time.sleep(0.1)
            else:
                break

    # Non-empty reasons get separated from the preceding text by one blank
    retry_reason = " " + retry_reason if retry_reason else retry_reason
    skip_reason = " " + skip_reason if skip_reason else skip_reason

    if not print_tap:
        print(retry_reason + skip_reason)
        flush_stdout()
        return

    print()  # Tap needs to start on a separate line
    if test.process.returncode == 0:
        template, annotation = "ok {0}{1}", retry_reason
    elif skip_reason:
        # a skipped test is reported as ok, with the skip reason attached
        template, annotation = "ok {0}{1}", skip_reason
    else:
        template, annotation = "not ok {0}{1}", retry_reason
    print(template.format(test_name(test), annotation))
    flush_stdout()


def finish_test(opts, test, affected_tests):
    """Evaluate a finished test, print its result, and decide whether to retry it.

    - A passing test that is "affected" (its check file matches an entry of
      affected_tests) is re-run until it has run 3 times in total, unless
      test.retry_when_affected is false.
    - Exit code 77 means the test was skipped; the last output line is taken
      as the skip reason and removed from the printed output.
    - For failures, unless opts.thorough is set, the output is piped through
      the external "test-failure-policy" command. Its exit code 77 turns the
      failure into a skip, exit code 1 supplies a retry reason.
    - Failures are retried up to 2 times if there is a retry reason and the
      output contains neither testlib.UNEXPECTED_MESSAGE nor
      testlib.PIXEL_TEST_MESSAGE.

    Return (retry_reason, exit_code). retry_reason is a non-empty string if
    the test should run again, otherwise None. exit_code is 1 for a final
    failure and 0 otherwise.
    """

    affected = any(test.command[0].endswith(t) for t in affected_tests)
    retry_reason = ""
    skip_reason = ""

    # Try affected tests 3 times
    if test.process.returncode == 0 and affected and test.retry_when_affected and test.retries < 2:
        retry_reason = "test affected tests 3 times"
        test.retries += 1
        print_test(test, not opts.list, "# RETRY {0} ({1})".format(test.retries, retry_reason))
        return retry_reason, 0

    # If test is being skipped pick up the reason
    if test.process.returncode == 77:
        lines = test.output.splitlines()
        # A skipped test may not have printed anything; don't crash on lines[-1]
        if lines:
            skip_reason = lines[-1].strip().decode("utf-8")
            test.output = b"\n".join(lines[:-1])
        # print_test() only reports "ok" for a non-zero exit code when it gets a
        # non-empty skip reason, so make sure there is one
        if not skip_reason:
            skip_reason = "# SKIP"

    if test.process.returncode in [0, 77]:
        print_test(test, not opts.list, skip_reason=skip_reason)
        return None, 0

    if not opts.thorough:
        cmd = ["test-failure-policy", testvm.DEFAULT_IMAGE]
        try:
            proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
            # the policy gets the test output plus the TAP failure line on stdin
            reason = proc.communicate(test.output + ("not ok " + test_name(test)).encode())[0].strip()

            if proc.returncode == 77:
                print_test(test, skip_reason="# SKIP {0}".format(reason.decode("utf-8")))
                return None, 0

            if proc.returncode == 1:
                retry_reason = reason.decode("utf-8")

        except OSError as ex:
            # a missing test-failure-policy executable is silently tolerated
            if ex.errno != errno.ENOENT:
                sys.stderr.write("\nCouldn't run test-failure-policy: {0}\n".format(str(ex)))

    # HACK: many tests are unstable, always retry them 3 times unless affected
    if not affected and not retry_reason:
        retry_reason = "be robust against unstable tests"

    has_unexpected_message = testlib.UNEXPECTED_MESSAGE.encode() in test.output
    has_pixel_test_message = testlib.PIXEL_TEST_MESSAGE.encode() in test.output
    if test.retries < 2 and not (has_unexpected_message or has_pixel_test_message) and retry_reason:
        test.retries += 1
        print_test(test, retry_reason="# RETRY {0} ({1})".format(test.retries, retry_reason))
        return retry_reason, 0

    # Final failure
    test.output += b"\n"
    print_test(test)
    return None, 1


def check_valid(filename):
    """Turn a test file path into a module name.

    Returns the base name of filename with "-" replaced by "_", or None if
    the base name contains anything but ASCII letters, digits, "-" and "_".
    """
    base = os.path.basename(filename)
    permitted = set(string.ascii_letters + string.digits + '-_')
    for char in base:
        if char not in permitted:
            return None
    return base.replace("-", "_")


def build_command(filename, test, opts):
    """Build the argv list for running one test from a check-* file.

    The result is [filename, <flags>, test]; flags are derived from
    opts.trace, opts.verbosity, opts.fetch and opts.list, in that order.
    """
    switches = (
        (opts.trace, "-t"),
        (opts.verbosity, "-v"),
        (not opts.fetch, "--nonet"),
        (opts.list, "-l"),
    )
    enabled_flags = [flag for enabled, flag in switches if enabled]
    return [filename] + enabled_flags + [test]


class GlobalMachine:
    """A running test VM with its own network, shared by a batch of tests.

    The constructor creates the network and the machine, pulls the image file
    if it does not exist locally, and starts the machine.

    restrict: passed on to VirtNetwork.host()
    cpus, memory_mb: passed on to VirtMachine; also used when the machine is
        re-created by reset()
    """

    def __init__(self, restrict=True, cpus=None, memory_mb=None):
        self.image = testvm.DEFAULT_IMAGE
        # remember the sizing, so that reset() creates an equivalent machine
        self.cpus = cpus
        self.memory_mb = memory_mb
        self.network = testvm.VirtNetwork(image=self.image)
        self.networking = self.network.host(restrict=restrict)
        self.machine = self._new_machine()
        if not os.path.exists(self.machine.image_file):
            self.machine.pull(self.machine.image_file)
        self.machine.start()

    def _new_machine(self):
        """Create (but do not start) a machine with this instance's settings."""
        return testvm.VirtMachine(verbose=True, networking=self.networking, image=self.image,
                                  cpus=self.cpus, memory_mb=self.memory_mb)

    def reset(self):
        """Kill the current machine and start a fresh one with the same settings."""
        # It is important to re-use self.networking here, so that the
        # machine keeps its browser and control port.
        self.machine.kill()
        self.machine = self._new_machine()
        self.machine.start()

    def kill(self):
        """Kill the machine and its network."""
        self.machine.kill()
        self.network.kill()


def _assign_machine(test, machine_id, ssh_address, web_address):
    """Point a serial test at an already running machine.

    Adds "--machine <ssh_address>" and, if web_address is given,
    "--browser <web_address>" to the test's command and records the machine
    (batch) index on the test.
    """
    extra_args = ["--machine", ssh_address]
    # Without an explicit web address leave the option out and let the test
    # pick its own default; None is not a valid argv element
    if web_address is not None:
        extra_args += ["--browser", web_address]
    # Splice the options in right before the trailing test name. Counting from
    # the end with a fixed offset would put them in front of the executable
    # for a command that consists of just [filename, test].
    test.command[-1:-1] = extra_args
    test.serial_machine = machine_id


def run(opts, image):
    """Discover, schedule and run all selected tests, printing TAP output.

    Tests are loaded from the check-* files in opts.test_dir. Nondestructive
    tests run serially in batches, each batch against one shared machine
    (--machine, or one GlobalMachine per batch); all other tests are run as
    parallel jobs. At most opts.jobs test processes run at the same time.
    Finished tests are evaluated with finish_test() and re-queued if that
    asks for a retry.

    opts: parsed command line options from main()
    image: image name; only used to pick the sort direction of serial tests

    Returns the number of tests that finally failed.
    Raises ValueError if two check-* files define a test class of the same name.
    """
    # Build the list of tests we'll parallelize and the ones we'll run serially
    test_loader = unittest.TestLoader()
    parallel_tests = []
    serial_tests = []
    test_id = 1
    result = 0
    jobs = 1 if opts.list else opts.jobs
    start_time = time.time()

    # Map of batches of serial tests. Key is the batch index, value is a map with these keys:
    # - "working" - None if the machine is idle, True while one of its tests is running
    # - "tests"   - Array of tests to run
    # - "time"    - Combined time all tests took
    # - "machine" - A GlobalMachine instance for running the tests (not set with --machine or --list)
    # - "started" - time.time() at which the most recent test of this batch was started
    batch_tests = {}

    # Make sure tests can make relative imports
    sys.path.append(os.path.realpath(opts.test_dir))

    # Build list of affected tests
    # If file from `test-dir` was changed in the PR, all tests from it are considered affected
    changed_tests = []
    test_files = glob.glob(os.path.join(opts.test_dir, "check-*"))
    if opts.base:
        # Detect affected tests from changed test files
        diff_out = subprocess.check_output(["git", "diff", "--name-only", "origin/" + opts.base, opts.test_dir])
        # Never consider 'test/verify/check-example' to be affected - our tests for tests count on that
        # This file provides only examples, there is no place for it being flaky, no need to retry
        changed_tests = [test.decode("utf-8") for test in diff_out.strip().splitlines() if not test.endswith(b"check-example")]

        # If more than 3 test files were changed don't consider any of them as affected
        # as it might be a PR that changes more unrelated things.
        if len(changed_tests) > 3:
            # If 'test/verify/check-testlib' is affected, keep just that one - our tests for tests count on that
            if "test/verify/check-testlib" in changed_tests:
                changed_tests = ["test/verify/check-testlib"]
            else:
                changed_tests = []

        # Detect affected tests from changed pkg/* subdirectories in cockpit
        # If affected tests get detected from pkg/* changes, don't apply the
        # "only do this for max. 3 check-* changes" (even if the PR also changes > 3 check-*)
        # (this does not apply to other projects)
        diff_out = subprocess.check_output(["git", "diff", "--name-only", "origin/" + opts.base, "--", "pkg/"])
        # second path component, i.e. the directory name below pkg/
        changed_pkgs = {pkg.decode("utf-8").split('/')[1] for pkg in diff_out.strip().splitlines()}
        changed_tests.extend([test for test in test_files if any(pkg in test for pkg in changed_pkgs)])

    seen_classes = {}
    for filename in test_files:
        name = check_valid(filename)
        if not name or not os.path.isfile(filename):
            continue
        # check-* files have no .py suffix, so load them explicitly as source modules
        loader = importlib.machinery.SourceFileLoader(name, filename)
        module = importlib.util.module_from_spec(importlib.util.spec_from_loader(loader.name, loader))
        loader.exec_module(module)
        for test_suite in test_loader.loadTestsFromModule(module):
            for test in test_suite:
                # ensure that test classes are unique, so that they can be selected properly
                cls = test.__class__.__name__
                if seen_classes.get(cls) not in [None, filename]:
                    raise ValueError("test class %s in %s already defined in %s" % (cls, filename, seen_classes[cls]))
                seen_classes[cls] = filename

                test_method = getattr(test.__class__, test._testMethodName)
                test_str = "{0}.{1}".format(cls, test._testMethodName)
                # most tests should take much less than 10mins, so default to that;
                # longer tests can be annotated with @timeout(seconds)
                # check the test function first, fall back to the class'es timeout
                test_timeout = getattr(test_method, "__timeout", getattr(test, "__timeout", 600))
                # positional test arguments select by substring, --exclude by exact name
                if opts.tests and not any(t in test_str for t in opts.tests):
                    continue
                if test_str in opts.exclude:
                    continue
                nd = getattr(test_method, "_testlib__non_destructive", False)
                rwa = getattr(test_method, "_testlib__retry_when_affected", True)
                test = Test(test_id, build_command(filename, test_str, opts), test_timeout, nd, rwa)
                if nd:
                    serial_tests.append(test)
                else:
                    if not opts.nondestructive:
                        parallel_tests.append(test)
                test_id += 1

    # sort serial tests by class/test name, to avoid spurious errors where failures depend on the order of execution
    # but let's make sure we always test them both ways around; hash the image name, which is robust, reproducible, and provides
    # an even distribution of both directions
    serial_tests.sort(key=lambda t: t.command[-1], reverse=bool(binascii.crc32(image.encode()) & 1))

    print("1..{0}".format(len(parallel_tests) + len(serial_tests)))
    flush_stdout()

    if serial_tests and opts.list:
        # Just build one batch for listing
        batch_tests[0] = {"working": None, "tests": [], "time": 0}
        for test in serial_tests:
            test.serial_machine = 0
            batch_tests[0]["tests"].append(test)

    if serial_tests and not opts.list:
        if opts.machine:
            # A single batch against the machine given on the command line
            batch_tests[0] = {"working": None, "tests": [], "time": 0}
            ssh_address = opts.machine
            web_address = opts.browser

            for test in serial_tests:
                batch_tests[0]["tests"].append(test)
                _assign_machine(test, 0, ssh_address, web_address)
        else:
            batch_size = len(serial_tests) // opts.batches

            for i in range(opts.batches):
                m = GlobalMachine(restrict=not opts.enable_network, cpus=opts.nondestructive_cpus, memory_mb=opts.nondestructive_memory_mb)
                batch_tests[i] = {"working": None, "tests": [], "time": 0, "machine": m}
                ssh_address = "{0}:{1}".format(m.machine.ssh_address,
                                               m.machine.ssh_port)
                web_address = "{0}:{1}".format(m.machine.web_address,
                                               m.machine.web_port)

                if i == opts.batches - 1:  # Last machine needs to resolve the rest
                    batch = serial_tests[batch_size * i:]
                else:
                    batch = serial_tests[batch_size * i: batch_size * i + batch_size]

                for test in batch:
                    batch_tests[i]["tests"].append(test)
                    _assign_machine(test, i, ssh_address, web_address)

    running_tests = []
    serial_tests_len = len(serial_tests)
    serial_remaining = sum(len(b["tests"]) for b in batch_tests.values())
    while serial_remaining or parallel_tests or running_tests:
        made_progress = False
        # Start at most one new test per iteration, if there is a free job slot
        if len(running_tests) < jobs:
            test = None
            # Find if there is a serial machine that is not busy and has some other tests to run
            for batch in batch_tests:
                if batch_tests[batch]["working"] is None and len(batch_tests[batch]["tests"]) > 0:
                    test = batch_tests[batch]["tests"].pop(0)
                    batch_tests[batch]["working"] = True
                    serial_remaining = sum(len(b["tests"]) for b in batch_tests.values())
                    batch_tests[batch]["started"] = time.time()
                    break
            else:
                # no serial test could be started; fall back to the parallel ones
                if parallel_tests:
                    test = parallel_tests.pop(0)

            if test:
                made_progress = True
                # stdout and stderr of the test get collected in a temporary file
                test.outfile = tempfile.TemporaryFile()
                test.process = subprocess.Popen(["timeout", str(test.timeout)] + test.command,
                                                stdout=test.outfile, stderr=subprocess.STDOUT)
                running_tests.append(test)

        # Collect tests which finished in the meantime
        for test in running_tests.copy():
            poll_result = test.process.poll()
            if poll_result is not None:
                made_progress = True
                test.outfile.flush()
                test.outfile.seek(0)
                test.output = test.outfile.read()
                test.outfile.close()
                running_tests.remove(test)
                retry_reason, test_result = finish_test(opts, test, changed_tests)
                result += test_result

                if test.serial_machine is not None:
                    tests_duration = (time.time() - batch_tests[test.serial_machine]["started"])
                    batch_tests[test.serial_machine]["time"] += tests_duration

                    # sometimes our global machine gets messed up; also, tests that time out don't run cleanup handlers
                    # restart it to avoid an unbounded number of test retries and follow-up errors
                    # (124 is the exit code of timeout(1) when the time limit was hit)
                    if not opts.machine and (poll_result == 124 or (retry_reason and "test harness" in retry_reason)):
                        # try hard to keep the test output consistent
                        sys.stderr.write("\nRestarting global machine %s\n" % test.serial_machine)
                        sys.stderr.flush()
                        batch_tests[test.serial_machine]["machine"].reset()

                # run again if needed
                if retry_reason:
                    test.output = None
                    test.process = None
                    # retries go to the front of their queue
                    if test.serial_machine is not None:
                        batch_tests[test.serial_machine]["tests"].insert(0, test)
                        serial_remaining = sum(len(b["tests"]) for b in batch_tests.values())
                    else:
                        parallel_tests.insert(0, test)

                # the machine of this batch is free for the next test
                if test.serial_machine is not None:
                    batch_tests[test.serial_machine]["working"] = None

        # Sleep if we didn't make progress
        if not made_progress and not opts.list:
            time.sleep(0.5)

    if not opts.list:
        for b in batch_tests.values():
            if "machine" in b:
                b["machine"].kill()

        duration = int(time.time() - start_time)
        hostname = socket.gethostname().split(".")[0]

        serial_details = []
        for batch in batch_tests:
            serial_details.append("{0}: {1}s".format(batch, int(batch_tests[batch]["time"])))

        details = "[{0}s on {1}, {2} serial tests: {3}]".format(duration, hostname, serial_tests_len, ", ".join(serial_details))
        print()
        if result > 0:
            print("# {0} TESTS FAILED {1}".format(result, details))
        else:
            print("# TESTS PASSED {0}".format(details))
        flush_stdout()

    return result


def main():
    """Parse the command line, set up the TEST_* environment, and run the tests.

    Returns the result of run(), which the script uses as its exit status.
    """
    # TEST_JOBS provides the default for --jobs and, halved, for --batches
    jobs = int(os.environ.get("TEST_JOBS", 1))

    parser = testlib.arg_parser(enable_sit=False)
    parser.add_argument('-j', '--jobs', type=int, default=jobs,
                        help="Number of concurrent jobs")
    parser.add_argument('--thorough', action='store_true',
                        help='Thorough mode, no skipping known issues')
    parser.add_argument('-n', '--nondestructive', action='store_true',
                        help='Only consider @nondestructive tests')
    parser.add_argument('--machine', metavar="hostname[:port]", default=None,
                        help="Run tests against an already running machine;  implies --nondestructive")
    parser.add_argument('--browser', metavar="hostname[:port]", default=None,
                        help="When using --machine, use this cockpit web address")
    parser.add_argument('--test-dir',
                        default=os.environ.get("TEST_DIR", testvm.TEST_DIR),
                        help="Directory in which to glob check-* files; default: %(default)s")
    parser.add_argument('--exclude', action="append", default=[],
                        metavar="TestClass.testName",
                        help="Exclude test (exact match only); can be specified multiple times")
    parser.add_argument('-b', '--batches', type=int, default=max(jobs // 2, 1),
                        help="Number of concurrent batches of nondestructive tests")
    parser.add_argument('--nondestructive-cpus', type=int, default=None,
                        help="Number of CPUs for nondestructive test global machines")
    parser.add_argument('--nondestructive-memory-mb', type=int, default=None,
                        help="RAM size for nondestructive test global machines")
    parser.add_argument('--base', default=os.environ.get("BASE_BRANCH"),
                        help="Retry affected tests compared to given base branch; default: %(default)s")
    opts = parser.parse_args()

    if opts.machine:
        # a single external machine cannot serve several tests at once
        if opts.jobs > 1:
            parser.error("--machine cannot be used with concurrent jobs")
        opts.nondestructive = True

    # Tell any subprocesses what we are testing
    if "TEST_REVISION" not in os.environ:
        rev_parse = subprocess.run(["git", "rev-parse", "HEAD"], check=False,
                                   stdout=subprocess.PIPE, universal_newlines=True)
        # outside of a git checkout the variable simply stays unset
        if rev_parse.returncode == 0:
            os.environ["TEST_REVISION"] = rev_parse.stdout.strip()

    # keep an already configured browser, otherwise use chromium
    os.environ.setdefault("TEST_BROWSER", "chromium")

    image = testvm.DEFAULT_IMAGE
    # assigns the value it just read, i.e. leaves the default image as it is
    testvm.DEFAULT_IMAGE = image
    os.environ["TEST_OS"] = image

    return run(opts, image)


# Script entry point: the return value of main() becomes the exit status
if __name__ == '__main__':
    sys.exit(main())
