#!/usr/bin/python3

# Copyright 2022 Julian Gilbey <jdg@debian.org>
# License: MIT License, as found in LICENSE.txt

"""Determine whether every test passes in at least one pytest run.

Some of the tests in the Spyder test suite are flaky: they succeed most
of the time but occasionally fail.  Given the size of the test suite, this
means that a pytest run will often fail because a single test failed.
Using @flaky on every test may or may not help, but it would be quite
invasive.  Furthermore, as some of the tests are dependent on others, it
is sometimes the case that @flaky does not help, whereas running the whole
test suite again succeeds.

This solution is different: we run the command given on the command line
and collect the output, for example:

    run_pytest.py xvfb-run -a -s "-screen 0 1024x768x24 +extension GLX" \
        python3 runtests.py --run-slow

We note all the tests that result in errors or failures.  We run pytest at
most MAX_PYTEST_RUNS times, exiting with success if every test has succeeded
at least once and with failure otherwise.
"""

import re
import subprocess
import sys
import threading
from enum import Flag, auto
from typing import IO, cast

from _pytest.config import ExitCode

# Upper bound on the number of times main() runs the test command.
MAX_PYTEST_RUNS = 5

# Sentinel stored in the set of failures (alongside real test names) by
# get_test_results() to record that a run did not complete normally.
PYTEST_FAILURE_MSG = "Pytest run failed (stderr or other error)"
# Progress messages printed by main(); where present, "%s" is filled in
# with the text returned by show_runs(), e.g. "3 RUNS".
SUCCESS_MSG = "*** ALL TESTS RUN PASSED/XFAILED AT LEAST ONCE AFTER %s ***"
FAILURE_MSG = "*** SOME TESTS FAILED/ERRORED EVERY RUN, ABORTING ***"
RETRY_MSG = "*** RETRYING: THESE TESTS HAVE NOT PASSED AFTER %s: ***"
RETRY_INCOMPLETE_MSG = (
    "*** RETRYING: PYTEST HAS NOT COMPLETED NORMALLY AFTER %s ***"
)


class Status(Flag):
    """Bit flags describing what went wrong (if anything) in one pytest run.

    Members can be combined with ``|`` and tested with ``&``; the empty
    combination, NOERR, is falsy.
    """

    # Nothing abnormal was observed.
    NOERR = 0
    # The process return code was not one of the accepted exit codes.
    ERREXIT = 1 << 0
    # The process wrote something (other than known noise) to stderr.
    STDERR = 1 << 1
    # The process was still running after its output ended and was killed.
    UNFINISHED = 1 << 2


def main():
    """Run the test command at most MAX_PYTEST_RUNS times and exit.

    The command to run is taken from ``sys.argv[1:]``.  After each run the
    outstanding failures are updated; as soon as there are none left the
    script exits successfully.

    The exit status will be 0 if every test passed at least once and
    non-zero otherwise (including when no command was given).
    """
    cmd = sys.argv[1:]
    if not cmd:
        # Without a command Popen would fail with an opaque traceback;
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit(f"usage: {sys.argv[0]} COMMAND [ARG ...]")

    # Until a run completes normally, the only recorded "failure" is the
    # sentinel itself.  This must be a one-element set: set(str) would
    # produce a set of the string's individual characters instead.
    failures = {PYTEST_FAILURE_MSG}
    successes: set[str] = set()
    completed_run = False

    for run_num in range(1, MAX_PYTEST_RUNS + 1):
        new_successes, new_failures = get_test_results(cmd)
        print()  # clear output line

        # On the first completed run, note the failures, but on successive
        # runs, remove successful tests from the list of failures
        if completed_run:
            failures -= new_successes
        elif PYTEST_FAILURE_MSG in new_failures:
            # We have still not had a completed run; we do record
            # successful tests
            successes |= new_successes
        else:
            # This is the first completed run, so we use this as the
            # starting point for unsuccessful tests, dropping any
            # previously successful tests.  Note that this removes
            # PYTEST_FAILURE_MSG from failures.
            failures = set(new_failures)
            failures -= successes
            completed_run = True

        if not failures:
            print(SUCCESS_MSG % show_runs(run_num))
            sys.exit()

        if completed_run:
            print(RETRY_MSG % show_runs(run_num))
            for test in sorted(failures):
                print(test)
        else:
            print(RETRY_INCOMPLETE_MSG % show_runs(run_num))

    print(FAILURE_MSG)
    sys.exit(1)


def show_runs(run_num: int) -> str:
    """Return the number of runs in the form '%d RUN(S)'.

    The noun is singular ("1 RUN") only when run_num is exactly 1.
    """
    plural_suffix = "" if run_num == 1 else "S"
    return "%d RUN%s" % (run_num, plural_suffix)


def get_test_results(cmd: list[str]) -> tuple[set[str], set[str]]:
    """Run cmd once and return (passing tests, failing tests).

    PYTEST_FAILURE_MSG is added to the failing set when the run itself
    looks broken rather than merely containing failing tests.
    """
    lines, run_status = run_tests(cmd)
    passed, failed = get_test_results_from_output(lines)

    # Stderr output or an unfinished process always marks the run as broken.
    abnormal_run = bool(run_status & (Status.STDERR | Status.UNFINISHED))
    # A bad exit status is only noteworthy when no failing test explains
    # it, as a failing test always produces a non-zero exit status.
    unexplained_exit = bool(run_status & Status.ERREXIT) and not failed

    if abnormal_run or unexplained_exit:
        failed.add(PYTEST_FAILURE_MSG)

    return passed, failed


def tee_output_from_process(proc: subprocess.Popen) -> tuple[list[str], list[str]]:
    """Print the stdout and stderr, returning stdout and stderr as lines.

    stdout lines are echoed as they arrive; stderr lines are echoed
    afterwards, each prefixed with "<stderr>".  A known-harmless sysconf
    warning on stderr is neither echoed nor returned.

    stderr is drained on a background thread while stdout is being read.
    Reading the two pipes one after the other can deadlock: a child that
    fills the stderr pipe buffer blocks on its write while we are still
    blocked waiting for more stdout.
    """
    raw_errs: list[str] = []
    drain_failure: list[BaseException] = []

    def drain_stderr() -> None:
        try:
            raw_errs.extend(cast(IO[str], proc.stderr))
        except Exception as exc:  # re-raised in the calling thread below
            drain_failure.append(exc)

    drainer = threading.Thread(target=drain_stderr, daemon=True)
    drainer.start()

    output = []
    for line in cast(IO[str], proc.stdout):
        print(line, end="")
        output.append(line)

    drainer.join()
    if drain_failure:
        raise drain_failure[0]

    errs = []
    for line in raw_errs:
        # See https://github.com/lxc/lxc/issues/4128
        # This is bizarre, but appears harmless
        if line == "sysconf(NPROCESSORS_CONF) failed: Operation not permitted\n":
            continue
        print(f"<stderr>{line}", end="")
        errs.append(line)

    return output, errs


def run_tests(
    cmd: list[str], *, exit_grace: float = 5.0
) -> tuple[list[str], Status]:
    """Run the tests and return the stdout lines and the process status.

    Args:
        cmd: The command and its arguments, passed to subprocess.Popen.
        exit_grace: Seconds to wait for the process to exit once its
            output pipes have reached end-of-file, before it is killed.

    Returns:
        A tuple of the captured stdout lines and a Status combining
        UNFINISHED (process had to be killed), STDERR (stderr output was
        seen) and ERREXIT (unexpected return code) as applicable.
    """
    with subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
    ) as process:
        output, errs = tee_output_from_process(process)

        # Check that we haven't missed a failure, for example a segfault
        status = Status.NOERR
        try:
            # End-of-file on the pipes can be seen slightly before the
            # child can be reaped, so an immediate poll() may report a
            # normally exiting process as still running.  Allow it a
            # short grace period instead.
            retcode = process.wait(timeout=exit_grace)
        except subprocess.TimeoutExpired:
            # The process has not finished, but its output has ended,
            # so we kill the process and flag the run as unfinished
            process.kill()
            retcode = None
            status |= Status.UNFINISHED
        if errs:
            status |= Status.STDERR
        if retcode not in [ExitCode.OK, ExitCode.NO_TESTS_COLLECTED]:
            status |= Status.ERREXIT

    return output, status


# Matches a verbose per-test result line such as
#     tests/test_x.py::test_y PASSED        [ 42%]
# Group 1 is the test name and group 2 the outcome.  The percentage is
# padded to three characters, so "[  5%]", "[ 42%]" and "[100%]" must all
# be accepted; requiring a space after "[" would drop the final test.
# Note that we ignore lines which say "SKIPPED"
status_re = re.compile(r"^(.*) (ERROR|FAILED|PASSED|XFAIL|XPASS) +\[ *[0-9]{1,3}%\]")


def get_test_results_from_output(output: list[str]) -> tuple[set[str], set[str]]:
    """Sort the tests named in the output lines into passing and failing sets.

    Lines that do not match status_re are ignored.  PASSED and XFAIL count
    as passing; every other matched outcome counts as failing.
    """
    passed: set[str] = set()
    failed: set[str] = set()
    for text in output:
        found = status_re.match(text)
        if found is None:
            continue
        test_name, outcome = found.group(1), found.group(2)
        bucket = passed if outcome in ["PASSED", "XFAIL"] else failed
        bucket.add(test_name)
    return passed, failed


# Only run the tests when executed as a script, not when imported.
if __name__ == "__main__":
    main()
