#!/usr/bin/python3
#
# This test uses lxd to create a container for each supported Ubuntu release,
# and tests whether sshuttle works from the local testbed (using the sshuttle under test)
# to the remote container.
#
# This also tests the reverse, by creating a container matching the testbed's release,
# and connecting from each supported Ubuntu release's container. Note that the reverse
# direction tests *do not* test the sshuttle under test by this autopkgtest, since
# on a "remote" system sshuttle is not involved at all, and does not even need to be
# installed; the reverse direction test primarily tests if anything *else* has changed
# that breaks sshuttle (most likely, changes in python).

import functools
import ipaddress
import json
import os
import re
import subprocess
import sys
import tempfile
import time
import types
import unittest
from contextlib import suppress
from pathlib import Path

import apt_pkg
from aptsources.distro import get_distro
from distro_info import UbuntuDistroInfo


DISTROINFO = UbuntuDistroInfo()
# releases that may be tested: ESM-supported, supported, and the development release
VALID_RELEASES = {*DISTROINFO.supported_esm(), *DISTROINFO.supported(), DISTROINFO.devel()}
# testbed details filled in by the init_* helpers called from setUpModule()
TESTBED = {}
# remote releases chosen on the command line; empty means all of VALID_RELEASES
RELEASES = []

# apt_pkg has to be initialised before apt_pkg.version_compare() can be used
apt_pkg.init_system()

@functools.lru_cache
def get_arch():
    """Return the testbed's dpkg architecture (e.g. 'amd64'); the value is cached after the first call."""
    result = run_cmd('dpkg --print-architecture')
    return result.stdout.strip()

def is_expected_failure(src, dst, python):
    """Return True if a failed sshuttle connection from ``src`` to ``dst`` is a known failure.

    ``python`` is the interpreter given to ``sshuttle --python`` (falsy for none).
    ``src`` must provide ``release`` and ``sshuttle_version``; ``dst`` must provide ``release``.
    """
    if not python:
        check = is_expected_failure_nopy
    else:
        check = {
            'python2': is_expected_failure_py2,
            'python3': is_expected_failure_py3,
        }.get(python)
    if check is not None and check(src, dst):
        return True

    # xenial requires 'netstat' installed in remote, LP: #1896299
    # (otherwise, we don't expect failure)
    return src.release == 'xenial' and apt_pkg.version_compare(src.sshuttle_version, '0.76-1ubuntu1') <= 0

def is_expected_failure_nopy(src, dst):
    """Return True if running sshuttle from ``src`` without ``--python`` is expected to fail.

    ``dst`` is unused; it is accepted so all the per-python checks share one signature.
    Always returns a bool (previously fell through to an implicit None).
    """
    # failure due to regression in patch to detect python command
    # should be fixed in version after this; LP: #1897961
    if src.release == 'xenial' and apt_pkg.version_compare(src.sshuttle_version, '0.76-1ubuntu1.1') <= 0:
        return True

    # otherwise, we don't expect failure
    return False

def is_expected_failure_py2(src, dst):
    """Return True if sshuttle ``--python python2`` from ``src`` is expected to fail.

    ``dst`` is unused; it is accepted so all the per-python checks share one signature.
    Always returns a bool (previously fell through to an implicit None).
    """
    # failure due to regression from initial fix for py3.8 fix
    # should be fixed in version after this; LP: #1873368
    if src.release == 'focal' and apt_pkg.version_compare(src.sshuttle_version, '0.78.5-1ubuntu1') <= 0:
        return True

    # otherwise, we don't expect failure
    return False

def is_expected_failure_py3(src, dst):
    """Return True if sshuttle ``--python python3`` from ``src`` to ``dst`` is expected to fail."""
    # expected failure: trusty -> any
    # since trusty is now ESM only, this isn't expected to be fixed
    if src.release == 'trusty':
        return True

    # failure with py3.8 (or later) target, which is default py3 in focal (or later).
    # NOTE: DISTROINFO.version() values are compared as strings here.
    if DISTROINFO.version(dst.release) < DISTROINFO.version('focal'):
        return False

    # last known-broken source version per release; each should be fixed in the
    # version after it; LP: #1873368
    last_broken = {
        'xenial': '0.76-1ubuntu1',
        'bionic': '0.78.3-1ubuntu1',
        'focal': '0.78.5-1ubuntu1',
    }.get(src.release)
    return last_broken is not None and apt_pkg.version_compare(src.sshuttle_version, last_broken) <= 0

def set_releases(releases):
    """Limit the remote releases under test to ``releases``.

    Names that are not in VALID_RELEASES are reported and ignored.  If at least one
    valid name remains, RELEASES is replaced with the valid names; otherwise RELEASES
    is left as is (load_tests() falls back to VALID_RELEASES when it is empty).
    Names are reported in sorted order so the log does not depend on set ordering.
    """
    requested = set(releases)
    invalid_releases = sorted(requested - VALID_RELEASES)
    if invalid_releases:
        print(f'ignoring invalid release(s): {", ".join(invalid_releases)}')
    valid_releases = sorted(requested & VALID_RELEASES)
    if valid_releases:
        print(f'limiting remote release(s) to: {", ".join(valid_releases)}')
        RELEASES.clear()
        RELEASES.extend(valid_releases)

def load_tests(loader, standard_tests, pattern):
    """unittest load_tests hook: build one SshuttleTest subclass per remote release.

    ``standard_tests`` and ``pattern`` are ignored.  Releases come from RELEASES when
    it is non-empty, otherwise from VALID_RELEASES, and are loaded in sorted order.
    """
    selected = RELEASES or VALID_RELEASES
    per_release_classes = (
        type(f'SshuttleTest_{name}', (SshuttleTest,), {'release': name})
        for name in sorted(selected)
    )
    suite = unittest.TestSuite()
    for test_class in per_release_classes:
        suite.addTests(loader.loadTestsFromTestCase(test_class))
    return suite

def setUpModule():
    """Initialise lxd, then the bridge/DNS, private subnets, ssh config and base test class."""
    lxd_init_cmd = ['lxd', 'init', '--auto']
    subprocess.run(lxd_init_cmd, check=True)

    for step in (init_lxd, init_private_addrs, init_ssh_config, init_base_test_class):
        step()
    init_base_test_class()

def tearDownModule():
    """Undo setUpModule(): remove the ssh config block and lo routes, drop the reverse remote.

    Removing the class attribute drops this module's reference to the reverse Remote,
    letting its __del__ delete the container once no other reference remains.
    """
    remove_ssh_config()
    remove_private_subnets()
    delattr(SshuttleTest, 'reverse_remote')

def init_lxd():
    """Record the lxd bridge's name, IPv4 address and DNS domain in TESTBED.

    Also points the testbed's resolver at the bridge address for that domain, so
    container names of the form ``<name>.<domain>`` can be resolved.
    Raises unittest.SkipTest unless exactly one lxd-managed network exists.
    """
    listing = subprocess.check_output(['lxc', 'network', 'list', '--format', 'json'], encoding='utf-8')
    managed_networks = [net for net in json.loads(listing) if net.get('managed')]
    count = len(managed_networks)
    if count != 1:
        raise unittest.SkipTest(f'Expected only 1 lxd-managed network, found {count}:\n{managed_networks}')
    (bridge,) = managed_networks
    bridge_config = bridge.get('config')

    TESTBED['lxdbr_name'] = bridge.get('name')
    TESTBED['lxdbr_addr'] = bridge_config.get('ipv4.address').partition('/')[0]
    TESTBED['lxdbr_domain'] = bridge_config.get('dns.domain', 'lxd')

    # route DNS lookups for the lxd domain (routing-only '~domain') to the bridge address
    subprocess.run(['resolvectl', 'dns', TESTBED['lxdbr_name'], TESTBED['lxdbr_addr']])
    subprocess.run(['resolvectl', 'domain', TESTBED['lxdbr_name'], f'~{TESTBED["lxdbr_domain"]}'])

def init_private_addrs():
    """Choose private addresses/subnets for the remote and reverse remote and blackhole them locally.

    Candidates 10.254.0.1 down to 10.0.0.1 are tried in turn.  An address is accepted
    when `ip r get` shows it routed 'via' a gateway (presumably meaning it is not on a
    directly connected network).  The first accepted address goes to the remote, and
    the second to the reverse remote.  Each /24 subnet is then routed to lo on the testbed.
    Raises Exception if two addresses cannot be found.
    """
    TESTBED['remote_private_addr'] = None
    TESTBED['reverse_remote_private_addr'] = None

    for octet in range(254, -1, -1):
        addr = f'10.{octet}.0.1'
        route = subprocess.run(['ip', 'r', 'get', addr], encoding='utf-8',
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE).stdout
        if 'via' not in route:
            continue
        if not TESTBED['remote_private_addr']:
            TESTBED['remote_private_addr'] = addr
        elif not TESTBED['reverse_remote_private_addr']:
            TESTBED['reverse_remote_private_addr'] = addr
            break
    else:
        raise Exception('Could not find any 10.* subnet to use for private addresses')

    for role in ('remote', 'reverse_remote'):
        TESTBED[f'{role}_private_subnet'] = private_addr_to_subnet(TESTBED[f'{role}_private_addr'])

    # Force the private addrs unreachable so we don't try to reach them out our normal gateway
    for subnet_key in ('remote_private_subnet', 'reverse_remote_private_subnet'):
        subprocess.run(['ip', 'r', 'add', TESTBED[subnet_key], 'dev', 'lo'])

    # Force the private addrs unreachable so we don't try to reach them out our normal gateway
    subprocess.run(['ip', 'r', 'add', TESTBED['remote_private_subnet'], 'dev', 'lo'])
    subprocess.run(['ip', 'r', 'add', TESTBED['reverse_remote_private_subnet'], 'dev', 'lo'])

def remove_private_subnets():
    """Delete the lo routes for both private subnets added by init_private_addrs()."""
    for subnet_key in ('remote_private_subnet', 'reverse_remote_private_subnet'):
        subprocess.run(['ip', 'r', 'del', TESTBED[subnet_key], 'dev', 'lo'])

def init_ssh_config():
    """Ensure root has an RSA key, and append a test Host block to /root/.ssh/config.

    The public key is stored in TESTBED['ssh_key'].  The exact appended text is stored in
    TESTBED['ssh_config'] so remove_ssh_config() can strip it again.  The block
    disables host key checking for the lxd domain, the bridge address and both
    private addresses.
    """
    key_path = Path('/root/.ssh/id_rsa')
    if not key_path.exists():
        subprocess.run(['ssh-keygen', '-f', str(key_path), '-P', ''], check=True)
    TESTBED['ssh_key'] = key_path.with_suffix('.pub').read_text(encoding='utf-8')

    host_patterns = [
        f'*.{TESTBED["lxdbr_domain"]}',
        TESTBED['lxdbr_addr'],
        TESTBED['remote_private_addr'],
        TESTBED['reverse_remote_private_addr'],
    ]
    options = [
        '  StrictHostKeyChecking no',
        '  UserKnownHostsFile /dev/null',
        '  ConnectTimeout 10',
        '  ConnectionAttempts 18',
    ]
    TESTBED['ssh_config'] = '\n'.join([f'Host {" ".join(host_patterns)}', *options, ''])

    config_path = Path('/root/.ssh/config')
    existing = config_path.read_text(encoding='utf-8') if config_path.exists() else ''
    # keep the existing last line intact before appending our block
    if existing and not existing.endswith('\n'):
        existing += '\n'
    config_path.write_text(existing + TESTBED['ssh_config'], encoding='utf-8')

def remove_ssh_config():
    """Remove the Host block that init_ssh_config() appended to /root/.ssh/config."""
    config_path = Path('/root/.ssh/config')
    current = config_path.read_text(encoding='utf-8')
    cleaned = current.replace(TESTBED['ssh_config'], '')
    config_path.write_text(cleaned, encoding='utf-8')

def init_base_test_class():
    """Attach the testbed's ssh settings and the shared reverse remote to SshuttleTest.

    The reverse remote is a container of the testbed's own release.  It is given the
    testbed's ssh key and config, and it re-adds the reverse private subnet on lo
    whenever it starts.  It is then snapshotted.
    """
    base = SshuttleTest
    base.testbed_ssh_key = TESTBED['ssh_key']
    base.testbed_ssh_config = TESTBED['ssh_config']
    base.release = get_distro().codename

    remote = Remote(f'reverse-remote-{base.release}', base.release)
    remote.add_ssh_key(base.testbed_ssh_key)
    remote.add_ssh_config(base.testbed_ssh_config)
    remote.private_addr = TESTBED['reverse_remote_private_addr']
    remote.private_subnet = TESTBED['reverse_remote_private_subnet']
    remote.add_start_cmd(f'ip a add {remote.private_subnet} dev lo')
    remote.snapshot_create()
    base.reverse_remote = remote

def private_addr_to_subnet(addr):
    """Return the /24 network containing ``addr`` in CIDR form, e.g. '10.254.0.1' -> '10.254.0.0/24'."""
    interface = ipaddress.ip_interface(f'{addr}/24')
    return interface.network.with_prefixlen

def run_cmd(cmd, **kwargs):
    """Run ``cmd`` with subprocess.run() and return the CompletedProcess.

    ``cmd`` is an argument list, or a string that is split on whitespace (no shell
    quoting is interpreted).  Unless overridden in ``kwargs``, stdout and stderr are
    captured together as UTF-8 text in ``result.stdout``.  Other ``kwargs`` are
    passed through unchanged; the exit status is only checked if ``check=True``.
    """
    if isinstance(cmd, str):
        cmd = cmd.split()
    kwargs.setdefault('stdout', subprocess.PIPE)
    kwargs.setdefault('stderr', subprocess.STDOUT)
    kwargs.setdefault('encoding', 'utf-8')
    return subprocess.run(cmd, **kwargs)


class Remote(object):
    """An lxd container used as an ssh/sshuttle endpoint by the tests.

    Creating an instance does the following:
    - force-deletes any existing container called ``name``;
    - launches a fresh ``ubuntu-daily:<release>`` container;
    - enables the testbed's Launchpad PPAs and ``<release>-proposed`` in it;
    - upgrades it and installs sshuttle and python;
    - leaves it stopped.

    When the object is garbage collected, __del__ force-deletes the container.
    """

    # One-line apt entries for Launchpad PPAs, on either the ppa.launchpad.net or the
    # ppa.launchpadcontent.net host.  Launchpad team and PPA names may contain '-',
    # '.' and '+', which \w alone does not match.
    PPA_LINE_RE = re.compile(
        r'^deb .*ppa\.launchpad(?:content)?\.net/(?P<team>[\w.+-]+)/(?P<ppa>[\w.+-]+)/ubuntu')

    # Archive components enabled for <release>-proposed.
    PROPOSED_COMPONENTS = 'main universe restricted multiverse'

    def __init__(self, name, release):
        self.name = name
        self.fqdn = f'{name}.{TESTBED["lxdbr_domain"]}'
        self.release = release
        self._start_cmds = []

        # remove any leftover container with the same name (e.g. from an earlier run)
        cmd = f'lxc delete --force {self.name}'
        self.log(cmd)
        run_cmd(cmd)

        image = f'ubuntu-daily:{release}'
        cmd = f'lxc launch --quiet {image} {self.name}'
        self.log(cmd)
        result = run_cmd(cmd)
        if result.returncode != 0:
            raise Exception(f'Could not launch {self.name}: {result.stdout}')

        self._wait_for_networking()
        self._create_ssh_key()
        self._add_local_ppas()
        self._add_proposed()
        self._apt_update_upgrade()
        self._install_sshuttle()
        self._install_python()
        self.stop(force=False)

    def log(self, msg):
        """Print ``msg`` prefixed with the container name."""
        print(f'{self.name}: {msg}')

    def _wait_for_networking(self):
        """Poll (120 tries, 0.5s apart) until the container has a default route.

        Raises Exception if no default route ever appears.
        """
        self.log(f'Waiting for {self.name} to finish starting')
        for _ in range(120):
            if 'via' in self.lxc_exec('ip r show default').stdout:
                break
            time.sleep(0.5)
        else:
            raise Exception(f'Timed out waiting for remote {self.name} networking')

    def _create_ssh_key(self):
        """Generate root's RSA key in the container and remember its public half."""
        self.log('creating ssh key')
        self.lxc_exec(['ssh-keygen', '-f', '/root/.ssh/id_rsa', '-P', ''])
        self._ssh_key = self.lxc_exec('cat /root/.ssh/id_rsa.pub').stdout

    @classmethod
    def _parse_ppas(cls, text):
        """Return ``ppa:<team>/<name>`` specs for the Launchpad PPA ``deb`` lines in ``text``."""
        ppas = []
        for line in text.splitlines():
            match = cls.PPA_LINE_RE.match(line)
            if match:
                ppas.append(f'ppa:{match.group("team")}/{match.group("ppa")}')
        return ppas

    def _add_local_ppas(self):
        """Enable in the container each Launchpad PPA configured on the testbed."""
        paths = list(Path('/etc/apt/sources.list.d').glob('*.list'))
        paths.append(Path('/etc/apt/sources.list'))
        ppas = []
        for path in paths:
            # sources.list may be absent on systems that only use deb822 .sources files
            if path.exists():
                ppas.extend(self._parse_ppas(path.read_text(encoding='utf-8')))
        # a PPA may be listed in more than one file; add each one only once
        for ppa in dict.fromkeys(ppas):
            self.log(f'adding PPA {ppa}')
            self.lxc_exec(['add-apt-repository', '-y', ppa])

    @classmethod
    def _proposed_line(cls, sources, release):
        """Build the ``<release>-proposed`` apt line from the archive used in ``sources``.

        Returns ``(uri, line)`` for the first ``deb <uri> <release> main...`` entry,
        or ``(None, None)`` when there is no such entry.
        """
        pattern = re.compile(rf'^deb (?P<uri>\S+) {re.escape(release)} main.*')
        for line in sources.splitlines():
            match = pattern.match(line)
            if match:
                uri = match.group('uri')
                return uri, f'deb {uri} {release}-proposed {cls.PROPOSED_COMPONENTS}'
        return None, None

    def _add_proposed(self):
        """Enable ``<release>-proposed`` using the archive URI from the container's sources.list."""
        with tempfile.TemporaryDirectory() as d:
            f = Path(d) / 'tempfile'
            # a missing remote file is pulled as an empty file (see lxc_file_pull)
            self.lxc_file_pull('/etc/apt/sources.list', str(f))
            sources = f.read_text(encoding='utf-8')
        uri, proposed_line = self._proposed_line(sources, self.release)
        if proposed_line:
            self.log(f'adding {self.release}-proposed using {uri}')
            self.lxc_exec(['add-apt-repository', '-y', proposed_line])

    def _apt_update_upgrade(self):
        """Run apt update and upgrade in the container."""
        self.log('upgrading packages')
        self.lxc_apt('update')
        self.lxc_apt('upgrade -y')

    def _install_sshuttle(self):
        """Install sshuttle in the container and record its package version.

        Raises Exception if no sshuttle executable is found afterwards.
        """
        self.log('installing sshuttle')
        result_install = self.lxc_apt('install -y sshuttle')
        result_which = self.lxc_exec('which sshuttle')
        if result_which.returncode != 0:
            err = result_install.stdout + result_which.stdout
            raise Exception(f'could not install sshuttle: {err}')
        # no shell is involved, so ${Version} reaches dpkg-query literally as its format
        self.sshuttle_version = self.lxc_exec('dpkg-query -f ${Version} -W sshuttle').stdout

    def _install_python(self):
        """Install python2 and python3 in the container; raise if either is then missing."""
        self.log('installing python')
        # result deliberately ignored; presumably the 'python' package is not
        # available on every release
        self.lxc_apt('install -y python')
        for python in ['python2', 'python3']:
            result_install = self.lxc_apt(['install', '-y', python])
            result_which = self.lxc_exec(['which', python])
            if result_which.returncode != 0:
                err = result_install.stdout + result_which.stdout
                raise Exception(f'could not install {python}: {err}')

    def snapshot_create(self, name='default'):
        """Stop the container cleanly and take snapshot ``name``."""
        self.log(f'creating snapshot: {name}')
        self.stop(force=False)
        subprocess.run(['lxc', 'snapshot', self.name, name], check=True)

    def snapshot_restore(self, name='default', start=True):
        """Force-stop, restore snapshot ``name``, and start again unless ``start`` is False."""
        self.log(f'restoring snapshot: {name}')
        self.stop()
        subprocess.run(['lxc', 'restore', self.name, name], check=True)
        if start:
            self.start()

    def snapshot_update(self, name='default'):
        """Replace snapshot ``name`` with one of the container's current state."""
        self.log(f'updating snapshot: {name}')
        subprocess.run(['lxc', 'delete', '--force', f'{self.name}/{name}'], check=True)
        self.snapshot_create(name)

    @functools.cached_property
    def ssh_key(self):
        """The container root's public ssh key, as read in _create_ssh_key()."""
        return self._ssh_key

    def add_start_cmd(self, cmd):
        """Register ``cmd`` to be run in the container every time start() boots it."""
        self.log(f'adding start cmd: {cmd}')
        self._start_cmds.append(cmd)

    def add_file_content(self, path, content):
        """Append ``content`` to the container file ``path`` unless it already contains it.

        A missing remote file is treated as empty, because lxc_file_pull creates an
        empty local copy.
        """
        with tempfile.TemporaryDirectory() as d:
            localfile = Path(d) / Path(path).name
            self.lxc_file_pull(path, str(localfile))
            existing_content = localfile.read_text(encoding='utf-8')
            if content in existing_content:
                return
            if existing_content and not existing_content.endswith('\n'):
                existing_content += '\n'
            localfile.write_text(existing_content + content, encoding='utf-8')
            self.lxc_file_push(str(localfile), path)

    def add_ssh_key(self, key):
        """Authorize ``key`` for root logins into the container."""
        self.log(f'adding ssh key: {key.strip()}')
        self.add_file_content('/root/.ssh/authorized_keys', key)

    def add_ssh_config(self, config):
        """Append ``config`` to root's ssh client config in the container."""
        self.log('adding ssh config')
        self.add_file_content('/root/.ssh/config', config)

    def lxc_exec(self, cmd, **kwargs):
        """Run ``cmd`` (list, or whitespace-split string) in the container via run_cmd()."""
        if isinstance(cmd, str):
            cmd = cmd.split()
        return run_cmd(['lxc', 'exec', self.name, '--'] + cmd, **kwargs)

    def lxc_apt(self, cmd, **kwargs):
        """Run ``apt <cmd>`` non-interactively in the container via run_cmd()."""
        if isinstance(cmd, str):
            cmd = cmd.split()
        return run_cmd(['lxc', 'exec', self.name, '--env', 'DEBIAN_FRONTEND=noninteractive', '--', 'apt'] + cmd, **kwargs)

    def lxc_file_pull(self, remote, local, fail_if_missing=False):
        """Copy container file ``remote`` to ``local``.

        If the pull fails and ``fail_if_missing`` is False, an empty local file is
        created instead; otherwise the CalledProcessError propagates.
        """
        remote = f'{self.name}{remote}'
        self.log(f'{local} <- {remote}')
        try:
            run_cmd(['lxc', 'file', 'pull', remote, local], check=True,
                    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except subprocess.CalledProcessError:
            if fail_if_missing:
                raise
            localpath = Path(local)
            if localpath.is_dir():
                localpath = localpath / Path(remote).name
            localpath.touch()
            self.log(f'remote file missing, created empty file {localpath}')

    def lxc_file_push(self, local, remote):
        """Copy ``local`` to container path ``remote``; raises CalledProcessError on failure."""
        remote = f'{self.name}{remote}'
        self.log(f'{local} -> {remote}')
        run_cmd(['lxc', 'file', 'push', local, remote], check=True)

    @property
    def json(self):
        """This container's entry from ``lxc list --format json``.

        Raises Exception unless exactly one entry has this name.
        """
        listjson = run_cmd('lxc list --format json').stdout
        filtered = [i for i in json.loads(listjson) if i['name'] == self.name]
        if len(filtered) != 1:
            raise Exception(f'Expected only 1 lxc list entry for {self.name}, found {len(filtered)}:\n{listjson}')
        return filtered[0]

    @property
    def is_running(self):
        """True if lxc reports the container status as 'Running'."""
        return self.json['status'] == 'Running'

    def start(self):
        """Start the container if needed, wait for networking, then run the start commands."""
        if not self.is_running:
            cmd = f'lxc start {self.name}'
            self.log(cmd)
            result = run_cmd(cmd, check=True)
            if result.stdout:
                self.log(result.stdout)
            self._wait_for_networking()
            for cmd in self._start_cmds:
                self.lxc_exec(cmd)

    def stop(self, force=True):
        """Stop the container (forcibly by default); failures are only logged."""
        cmd = 'lxc stop'
        if force:
            cmd += ' --force'
        cmd += f' {self.name}'
        self.log(cmd)
        result = run_cmd(cmd)
        if result.stdout:
            self.log(result.stdout)

    def __del__(self):
        # best-effort cleanup of the container when this object is collected
        run_cmd(['lxc', 'delete', '--force', self.name], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)


class VerboseAssertionError(AssertionError):
    """AssertionError whose message is extended with all text recorded via add_log().

    When an instance is created, the recorded log entries are appended (one per line)
    to the given message parts, and the log is cleared.
    """
    # shared, class-level log of diagnostic text; emptied by read_log()
    __logs = []

    def __init__(self, *args):
        # unittest's TestCase.fail() passes msg=None when called without a message;
        # skip None so '\n'.join() cannot raise TypeError, and stringify the rest
        parts = [str(arg) for arg in args if arg is not None] + self.read_log()
        super().__init__('\n'.join(parts))

    @classmethod
    def add_log(cls, msg):
        """Record ``msg`` (converted to str) for inclusion in the next failure message."""
        cls.__logs.append(str(msg))

    @classmethod
    def read_log(cls):
        """Return the recorded log entries and start a new, empty log."""
        log = cls.__logs
        cls.__logs = []
        return log

    @classmethod
    def clear_log(cls):
        """Discard all recorded log entries."""
        cls.read_log()


class SshuttleTest(unittest.TestCase):
    """Base class for the sshuttle tests; load_tests() derives one subclass per release.

    Each generated subclass sets ``release`` to the release of its remote container.
    setUpModule() attaches the following to this class:
    - ``testbed_ssh_key`` and ``testbed_ssh_config``;
    - the shared ``reverse_remote``, a container of the testbed's own release.
    """
    release = None
    # assertion failures include the text collected via VerboseAssertionError.add_log()
    failureException = VerboseAssertionError

    @classmethod
    def is_arch_supported(cls):
        """Return False for trusty unless the testbed is amd64; True otherwise."""
        if cls.release == 'trusty':
            return get_arch() == 'amd64'
        return True

    @classmethod
    def setUpClass(cls):
        """Create and snapshot the remote for ``cls.release``, and authorize its key on the reverse remote."""
        # note that some of the cls attrs used here are set by setUpModule()

        # this is set by the subclass, and required
        assert cls.release, 'release must be set by the generated subclass'

        if not cls.is_arch_supported():
            raise unittest.SkipTest(f'Release {cls.release} not available for {get_arch()}')

        remote = Remote(f'remote-{cls.release}', cls.release)
        remote.add_ssh_key(cls.testbed_ssh_key)
        remote.add_ssh_config(cls.testbed_ssh_config)
        remote.private_addr = TESTBED["remote_private_addr"]
        remote.private_subnet = TESTBED["remote_private_subnet"]
        remote.add_start_cmd(f'ip a add {remote.private_subnet} dev lo')
        remote.snapshot_create()
        cls.remote = remote

        # let the remote ssh into the reverse remote, and keep that in its snapshot
        cls.reverse_remote.snapshot_restore()
        cls.reverse_remote.add_ssh_key(cls.remote.ssh_key)
        cls.reverse_remote.snapshot_update()

    @classmethod
    def tearDownClass(cls):
        # dropping the reference lets Remote.__del__ delete the container
        del cls.remote

    def setUp(self):
        """Restore both containers from their snapshots and reset per-test state."""
        self.name = self.testbed.name
        self.reverse_remote.snapshot_restore()
        self.remote.snapshot_restore()
        self.sshuttle_process = None
        self.sshuttle_log = tempfile.NamedTemporaryFile()
        self.failureException.clear_log()

    def tearDown(self):
        """Stop sshuttle, drop its log file, and force-stop both containers."""
        self.sshuttle_stop()
        self.sshuttle_log.close()
        self.remote.stop()
        self.reverse_remote.stop()

    @functools.cached_property
    def sshuttle_version(self):
        """Version of the sshuttle package installed on the testbed (the one under test)."""
        return run_cmd('dpkg-query -f ${Version} -W sshuttle').stdout

    @functools.cached_property
    def testbed(self):
        """Describe the local testbed as a connection source for is_expected_failure().

        ``self.release`` is the *remote* release of the generated subclass, so the
        testbed's own release is looked up separately.  Otherwise the expected-failure
        rules would be evaluated against the wrong source release.
        """
        release = get_distro().codename
        return types.SimpleNamespace(name=f'testbed-{release}', release=release,
                                     sshuttle_version=self.sshuttle_version)

    def sshuttle_start(self, dst, python):
        """Start sshuttle towards ``dst``'s private subnet and wait for it to connect.

        When ``dst`` is the reverse remote, sshuttle runs inside the remote container.
        Fails the test if the sshuttle process exits.  Only warns if it has not
        connected after 300 polls, one second apart.
        """
        sshuttle_cmd = 'sshuttle'
        if python:
            sshuttle_cmd += f' --python {python}'
        sshuttle_cmd += f' -r {dst.fqdn} {dst.private_subnet}'
        if dst is self.reverse_remote:
            sshuttle_cmd = f'lxc exec {self.remote.name} -- {sshuttle_cmd}'
        print(f'running: {sshuttle_cmd}')
        self.sshuttle_process = subprocess.Popen(sshuttle_cmd.split(), encoding='utf-8',
                                                 stdout=self.sshuttle_log, stderr=self.sshuttle_log)
        print('waiting for sshuttle to start...', end='', flush=True)
        for _ in range(300):
            if self.sshuttle_process.poll() is not None:
                print('sshuttle failed :-(', flush=True)
                break
            if 'client: Connected.' in Path(self.sshuttle_log.name).read_text(encoding='utf-8'):
                print('started', flush=True)
                break
            time.sleep(1)
            print('.', end='', flush=True)
        else:
            print("WARNING: timed out waiting for sshuttle to start, the test may fail")
        if self.sshuttle_process.poll() is not None:
            self.fail('sshuttle process failed to start')

    def sshuttle_stop(self):
        """Terminate a running sshuttle, escalating to kill; fail if it still won't exit."""
        if self.sshuttle_process and self.sshuttle_process.poll() is None:
            print('stopping sshuttle...')
            self.sshuttle_process.terminate()
            with suppress(subprocess.TimeoutExpired):
                self.sshuttle_process.communicate(timeout=30)
                print('sshuttle stopped')
                self.sshuttle_process = None
                return

            print('sshuttle did not respond, killing sshuttle...')
            self.sshuttle_process.kill()
            with suppress(subprocess.TimeoutExpired):
                self.sshuttle_process.communicate(timeout=30)
                print('sshuttle stopped')
                self.sshuttle_process = None
                return

            self.fail('sshuttle subprocess refused to stop')

    def ssh_to(self, remote):
        """Try to ssh to ``remote``'s private address; return True on success, else False.

        For the reverse remote, ssh is run from inside the remote container.  The ssh
        output is always added to the failure log.
        """
        # ssh joins the remote command with spaces, so the sh -c argument needs its own quotes
        ssh_cmd = ['ssh', '-v', remote.private_addr, '--', 'sh', '-c', '"echo connected to $(hostname)"']
        if remote is self.reverse_remote:
            ssh_cmd = ['lxc', 'exec', self.remote.name, '--'] + ssh_cmd
        print(f'running: {ssh_cmd}')
        result = run_cmd(ssh_cmd)
        self.failureException.add_log(result.stdout)
        if result.returncode == 0:
            print('ssh connected')
            return True
        if result.stdout:
            # just print the last line here; this might be an expected failure,
            # and for unexpected failures the full output will be included in the failure detail
            print(result.stdout.splitlines()[-1])
        print('ssh failed')
        return False

    def test_local_to_remote_nopy(self):
        self._test_to_remote(self.testbed, self.remote, None)

    def test_local_to_remote_py2(self):
        self._test_to_remote(self.testbed, self.remote, 'python2')

    def test_local_to_remote_py3(self):
        self._test_to_remote(self.testbed, self.remote, 'python3')

    def test_remote_to_reverse_remote_nopy(self):
        self._test_to_remote(self.remote, self.reverse_remote, None)

    def test_remote_to_reverse_remote_py2(self):
        self._test_to_remote(self.remote, self.reverse_remote, 'python2')

    def test_remote_to_reverse_remote_py3(self):
        self._test_to_remote(self.remote, self.reverse_remote, 'python3')

    def _test_to_remote(self, src, dst, python):
        """Check that ``dst`` is unreachable by ssh until sshuttle (from ``src``) is running.

        A failure after sshuttle is started is turned into a skip when
        is_expected_failure() says it is known.  Otherwise it is re-raised with the
        sshuttle output attached.
        """
        self.failureException.add_log(f'Test detail: {src.name} sshuttle {src.sshuttle_version} to {dst.name} {python if python else ""}')
        print('this ssh connection should timeout:')
        self.assertFalse(self.ssh_to(dst))
        try:
            self.sshuttle_start(dst, python)
            print('this ssh connection should not timeout:')
            self.assertTrue(self.ssh_to(dst))
        except AssertionError as e:
            if is_expected_failure(src, dst, python):
                self.skipTest('This is an expected failure, ignoring test failure')
            # The message of `e` was assembled (and the log drained) when it was
            # raised, so anything add_log()'d now would be lost; raise a new failure
            # that carries the sshuttle output explicitly.
            sshuttle_output = Path(self.sshuttle_log.name).read_text(encoding='utf-8')
            raise self.failureException(str(e), sshuttle_output) from e


if __name__ == '__main__':
    # any command-line arguments are remote release names to restrict the run to;
    # strip them so unittest.main() does not try to interpret them
    release_args = sys.argv[1:]
    if release_args:
        set_releases(release_args)
        del sys.argv[1:]
    unittest.main(verbosity=2)
