#!/usr/bin/python3
#
# This test uses lxd to create a container for each supported Ubuntu release,
# and tests if sshuttle works from the local testbed (using the sshuttle under test)
# to the remote container.
#
# This also tests the reverse, by creating a container matching the testbed's release,
# and connecting from each supported Ubuntu release's container. Note that the reverse
# direction tests *do not* test the sshuttle under test by this autopkgtest, since
# on a "remote" system sshuttle is not involved at all, and does not even need to be
# installed; the reverse direction test primarily tests if anything *else* has changed
# that breaks sshuttle (most likely, changes in python).

import apt_pkg
import ipaddress
import json
import os
import re
import sys
import subprocess
import tempfile
import time
import unittest

from aptsources.distro import get_distro
from contextlib import suppress
from distro_info import UbuntuDistroInfo
from pathlib import Path


# Ubuntu release metadata, used below for the supported-release list and for
# version lookups by codename.
DISTROINFO = UbuntuDistroInfo()
# Every release a remote container may be created for: the ESM-supported
# releases, the regularly supported releases, and the development release.
VALID_RELEASES = set(DISTROINFO.supported_esm() + DISTROINFO.supported() + [DISTROINFO.devel()])
# Shared testbed state (lxd bridge details, private addresses, ssh key/config);
# populated by the init_*() helpers that setUpModule() calls.
TESTBED = {}
# Optional subset of releases selected through set_releases(); while this is
# empty, load_tests() falls back to VALID_RELEASES.
RELEASES = []

# really silly that users need to call this...especially just to use version_compare
apt_pkg.init_system()

def is_expected_failure(src_release, dst_release, python, sshuttle_version):
    """Tell whether a failed sshuttle run is a known, expected failure.

    Args:
        src_release: codename of the release running sshuttle.
        dst_release: codename of the release sshuttle connects to.
        python: interpreter name passed to ``sshuttle --python``.
        sshuttle_version: package version of sshuttle on the source side.

    Returns:
        True when the failure is expected, False otherwise.
    """
    # only python3 targets have any known failures
    if python != 'python3':
        return False

    return bool(is_expected_failure_py3(src_release, dst_release, sshuttle_version))

def is_expected_failure_py3(src_release, dst_release, sshuttle_version):
    """Tell whether a python3 sshuttle failure from src to dst is expected.

    Args:
        src_release: codename of the release running sshuttle.
        dst_release: codename of the release sshuttle connects to.
        sshuttle_version: package version of sshuttle on the source side.

    Returns:
        True when the failure is expected, False otherwise.
    """
    # expected failure: trusty -> any
    # since trusty is now ESM only, this isn't expected to be fixed
    if src_release == 'trusty':
        return True

    # Last sshuttle version per source release that is still known to fail
    # against a py3.8 (or later) target; should be fixed in the version after
    # each of these; LP: #1873368
    last_failing_version = {
        'xenial': '0.76-1ubuntu1',
        'bionic': '0.78.3-1ubuntu1',
        'focal': '0.78.5-1ubuntu1',
    }

    # py3.8 (or later) is the default py3 in focal (or later)
    dst_is_focal_or_later = DISTROINFO.version(dst_release) >= DISTROINFO.version('focal')
    if dst_is_focal_or_later and src_release in last_failing_version:
        return apt_pkg.version_compare(sshuttle_version, last_failing_version[src_release]) <= 0

    # otherwise, we don't expect failure
    return False

def set_releases(releases):
    """Restrict the remote releases under test to the valid ones in ``releases``.

    Names not present in VALID_RELEASES are reported and ignored.  RELEASES is
    only replaced (in place) when at least one valid name was given.
    """
    requested = set(releases)

    rejected = list(requested - VALID_RELEASES)
    if rejected:
        print(f'ignoring invalid release(s): {", ".join(rejected)}')

    accepted = list(requested & VALID_RELEASES)
    if not accepted:
        return

    print(f'limiting remote release(s) to: {", ".join(accepted)}')
    # mutate the shared list in place so existing references see the change
    RELEASES[:] = accepted

def load_tests(loader, standard_tests, pattern):
    """unittest ``load_tests`` hook: build one SshuttleTest subclass per release.

    Each generated class is named ``SshuttleTest_<release>`` and carries the
    release codename in its ``RELEASE`` attribute.  The releases come from
    RELEASES when it is non-empty, otherwise from VALID_RELEASES.
    """
    per_release_suite = unittest.TestSuite()
    selected = RELEASES or VALID_RELEASES
    for codename in sorted(selected):
        case_cls = type(f'SshuttleTest_{codename}', (SshuttleTest,), {'RELEASE': codename})
        per_release_suite.addTests(loader.loadTestsFromTestCase(case_cls))
    return per_release_suite

def setUpModule():
    """Module fixture: initialise lxd, then run every testbed init step in order."""
    subprocess.run(['lxd', 'init', '--auto'], check=True)

    # order matters: later steps read TESTBED entries written by earlier ones
    for init_step in (init_lxd, init_private_addrs, init_ssh_config, init_base_test_class):
        init_step()

def tearDownModule():
    """Module fixture teardown: undo the ssh config and route changes, drop the shared remote."""
    for cleanup_step in (remove_ssh_config, remove_private_subnets):
        cleanup_step()
    # dropping the class attribute releases the reference to the shared Remote
    delattr(SshuttleTest, 'reverse_remote')

def init_lxd():
    """Record the single lxd-managed network in TESTBED and point resolvectl at it.

    Stores the bridge name, its IPv4 address (without prefix length) and its
    DNS domain (default ``lxd``), then configures per-link DNS and a routing
    domain for the bridge via ``resolvectl``.

    Raises:
        unittest.SkipTest: if lxd does not report exactly one managed network.
    """
    listing = subprocess.check_output(['lxc', 'network', 'list', '--format', 'json'],
                                      encoding='utf-8')
    managed_networks = [net for net in json.loads(listing) if net.get('managed')]
    if len(managed_networks) != 1:
        raise unittest.SkipTest(f'Expected only 1 lxd-managed network, found {len(managed_networks)}:\n{managed_networks}')
    (bridge,) = managed_networks

    TESTBED['lxdbr_name'] = bridge.get('name')
    # the configured address looks like ADDR/PREFIX; keep only the address part
    cidr = bridge.get('config').get('ipv4.address')
    TESTBED['lxdbr_addr'] = cidr.partition('/')[0]
    TESTBED['lxdbr_domain'] = bridge.get('config').get('dns.domain', 'lxd')

    link = TESTBED['lxdbr_name']
    subprocess.run(['resolvectl', 'dns', link, TESTBED['lxdbr_addr']])
    subprocess.run(['resolvectl', 'domain', link, f'~{TESTBED["lxdbr_domain"]}'])

def init_private_addrs():
    """Pick two 10.x.0.1 addresses for the remotes and route their /24s to ``lo``.

    Candidates are tried from 10.254.0.1 down to 10.0.0.1; a candidate is
    accepted when the output of ``ip r get`` for it contains ``via``.  The
    first accepted address is used for the remote, the second for the reverse
    remote.  The results are stored in TESTBED together with their subnets.

    Raises:
        Exception: if fewer than two usable candidates were found.
    """
    TESTBED['remote_private_addr'] = None
    TESTBED['reverse_remote_private_addr'] = None

    for second_octet in range(254, -1, -1):
        candidate = f'10.{second_octet}.0.1'
        route = subprocess.run(['ip', 'r', 'get', candidate], encoding='utf-8',
                               stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if 'via' not in route.stdout:
            continue
        if not TESTBED['remote_private_addr']:
            TESTBED['remote_private_addr'] = candidate
        elif not TESTBED['reverse_remote_private_addr']:
            TESTBED['reverse_remote_private_addr'] = candidate
            # both addresses found, stop searching
            break
    else:
        raise Exception('Could not find any 10.* subnet to use for private addresses')

    TESTBED['remote_private_subnet'] = private_addr_to_subnet(TESTBED['remote_private_addr'])
    TESTBED['reverse_remote_private_subnet'] = private_addr_to_subnet(TESTBED['reverse_remote_private_addr'])

    # Force the private addrs unreachable so we don't try to reach them out our normal gateway
    for subnet_key in ('remote_private_subnet', 'reverse_remote_private_subnet'):
        subprocess.run(['ip', 'r', 'add', TESTBED[subnet_key], 'dev', 'lo'])

def remove_private_subnets():
    """Delete the ``lo`` routes for both private subnets recorded in TESTBED."""
    for subnet_key in ('remote_private_subnet', 'reverse_remote_private_subnet'):
        subprocess.run(['ip', 'r', 'del', TESTBED[subnet_key], 'dev', 'lo'])

def init_ssh_config():
    """Ensure root has an ssh key and append a test ``Host`` stanza to root's ssh config.

    The public key and the stanza text are stored in TESTBED as ``ssh_key``
    and ``ssh_config``.  The stanza covers the lxd domain, the bridge address
    and both private addresses, and disables host key checking for them.
    """
    private_key = Path('/root/.ssh/id_rsa')
    if not private_key.exists():
        subprocess.run(['ssh-keygen', '-f', str(private_key), '-P', ''], check=True)
    public_key = private_key.with_suffix('.pub')
    TESTBED['ssh_key'] = public_key.read_text(encoding='utf-8')

    host_patterns = (f'*.{TESTBED["lxdbr_domain"]}',
                     TESTBED['lxdbr_addr'],
                     TESTBED['remote_private_addr'],
                     TESTBED['reverse_remote_private_addr'])
    stanza_lines = ('Host ' + ' '.join(host_patterns),
                    '  StrictHostKeyChecking no',
                    '  UserKnownHostsFile /dev/null',
                    '  ConnectTimeout 5',
                    '')
    TESTBED['ssh_config'] = '\n'.join(stanza_lines)

    ssh_config_file = Path('/root/.ssh/config')
    existing = ssh_config_file.read_text(encoding='utf-8') if ssh_config_file.exists() else ''
    # make sure the stanza starts on its own line
    if existing and not existing.endswith('\n'):
        existing += '\n'
    ssh_config_file.write_text(existing + TESTBED['ssh_config'], encoding='utf-8')

def remove_ssh_config():
    """Strip the stanza stored in TESTBED['ssh_config'] from root's ssh config file."""
    ssh_config_file = Path('/root/.ssh/config')
    current = ssh_config_file.read_text(encoding='utf-8')
    cleaned = current.replace(TESTBED['ssh_config'], '')
    ssh_config_file.write_text(cleaned, encoding='utf-8')

def init_base_test_class():
    """Attach the shared testbed data and the reverse remote to SshuttleTest.

    Sets the testbed ssh key/config and the testbed's own release codename on
    the class, then creates a container of that same release, prepares it
    (ssh key, ssh config, private address on ``lo`` at start), snapshots it,
    and stores it as ``SshuttleTest.reverse_remote``.
    """
    SshuttleTest.testbed_ssh_key = TESTBED['ssh_key']
    SshuttleTest.testbed_ssh_config = TESTBED['ssh_config']
    SshuttleTest.release = get_distro().codename

    container = Remote(f'reverse-remote-{SshuttleTest.release}', SshuttleTest.release)
    container.add_ssh_key(SshuttleTest.testbed_ssh_key)
    container.add_ssh_config(SshuttleTest.testbed_ssh_config)
    container.private_addr = TESTBED["reverse_remote_private_addr"]
    container.private_subnet = TESTBED["reverse_remote_private_subnet"]
    container.add_start_cmd(f'ip a add {container.private_subnet} dev lo')
    container.snapshot_create()

    SshuttleTest.reverse_remote = container

def private_addr_to_subnet(addr):
    """Return the /24 network containing ``addr`` in ``ADDRESS/PREFIX`` form."""
    # strict=False lets a host address (host bits set) be given
    containing_network = ipaddress.ip_network(f'{addr}/24', strict=False)
    return containing_network.with_prefixlen


class Remote(object):
    """An lxd container running one Ubuntu release, prepared as an sshuttle test peer.

    Constructing an instance (re)creates the container, waits for its network,
    generates an ssh key inside it, mirrors the testbed's PPAs, enables the
    -proposed pocket, upgrades packages, installs sshuttle and the python
    interpreters, and finally stops the container.  The container is force
    deleted again when the object is garbage collected.

    Callers in this file additionally assign ``private_addr`` and
    ``private_subnet`` attributes on instances after construction.
    """

    def __init__(self, name, release):
        """Create, provision and then stop the container ``name`` for ``release``.

        Raises:
            subprocess.CalledProcessError: if ``lxc launch`` fails.
            Exception: if networking never comes up or a required package
                could not be installed.
        """
        self.name = name
        self.fqdn = f'{name}.{TESTBED["lxdbr_domain"]}'
        self.release = release
        self._start_cmds = []

        image = f'ubuntu-daily:{release}'
        # best effort removal of a leftover container with the same name
        subprocess.run(['lxc', 'delete', '-f', self.name], stderr=subprocess.DEVNULL)
        subprocess.run(['lxc', 'launch', image, self.name], check=True)

        self._wait_for_networking()
        self._create_ssh_key()
        self._add_local_ppas()
        self._add_proposed()
        self._apt_update_upgrade()
        self._install_sshuttle()
        self._install_python()
        self.stop()

    def log(self, msg):
        """Print ``msg`` prefixed with the container name."""
        print(f'{self.name}: {msg}')

    def _wait_for_networking(self):
        """Poll (120 tries, 0.5s apart) until the container has a default route.

        Raises:
            Exception: if no default route showed up within all tries.
        """
        self.log(f'Waiting for {self.name} to finish starting')
        for _ in range(120):
            if 'via' in self.lxc_exec('ip r show default'.split()).stdout:
                break
            time.sleep(0.5)
        else:
            raise Exception(f'Timed out waiting for remote {self.name} networking')

    def _create_ssh_key(self):
        """Generate a passphrase-less root ssh key in the container and remember its public half."""
        self.log('creating ssh key')
        self.lxc_exec(['ssh-keygen', '-f', '/root/.ssh/id_rsa', '-P', ''])
        self._ssh_key = self.lxc_exec('cat /root/.ssh/id_rsa.pub'.split()).stdout

    @staticmethod
    def _ppa_names(sources_text):
        """Return ``ppa:TEAM/NAME`` for every Launchpad PPA ``deb`` line in ``sources_text``."""
        # Launchpad team and PPA names may contain '-', '.' and '+', not only
        # word characters, so match those explicitly.
        pattern = re.compile(
            r'^deb .*ppa\.launchpad\.net/(?P<team>[\w.+-]+)/(?P<ppa>[\w.+-]+)/ubuntu')
        names = []
        for line in sources_text.splitlines():
            match = pattern.match(line)
            if match:
                names.append(f'ppa:{match.group("team")}/{match.group("ppa")}')
        return names

    def _add_local_ppas(self):
        """Add every PPA configured in the testbed's apt sources to the container."""
        paths = list(Path('/etc/apt/sources.list.d').glob('*.list'))
        paths.append(Path('/etc/apt/sources.list'))
        ppas = []
        for path in paths:
            ppas.extend(self._ppa_names(path.read_text(encoding='utf-8')))
        for ppa in ppas:
            self.log(f'adding PPA {ppa}')
            self.lxc_exec(['add-apt-repository', '-y', ppa])

    @staticmethod
    def _proposed_source_line(release, sources_text):
        """Build the ``RELEASE-proposed`` apt line from the first ``deb URI RELEASE main`` entry.

        Returns:
            A ``(uri, deb_line)`` tuple, or None when ``sources_text`` has no
            matching entry.
        """
        for line in sources_text.splitlines():
            match = re.match(rf'^deb (?P<uri>\S+) {release} main.*', line)
            if match:
                uri = match.group('uri')
                components = 'main universe restricted multiverse'
                return uri, f'deb {uri} {release}-proposed {components}'
        return None

    def _add_proposed(self):
        """Enable the -proposed pocket in the container, reusing its configured archive URI.

        Does nothing when the container's sources.list has no matching entry.
        """
        with tempfile.TemporaryDirectory() as d:
            f = Path(d) / 'tempfile'
            self.lxc_file_pull('/etc/apt/sources.list', str(f))
            found = self._proposed_source_line(self.release, f.read_text(encoding='utf-8'))
        if found is None:
            return
        uri, proposed_line = found
        self.log(f'adding {self.release}-proposed using {uri}')
        self.lxc_exec(['add-apt-repository', '-y', proposed_line])

    def _apt_update_upgrade(self):
        """Refresh the package lists and upgrade all packages in the container."""
        self.log('upgrading packages')
        self.lxc_exec('apt update'.split())
        self.lxc_exec('apt upgrade -y'.split())

    def _install_sshuttle(self):
        """Install sshuttle in the container and record its package version.

        Raises:
            Exception: if no ``sshuttle`` executable is found afterwards.
        """
        self.log('installing sshuttle')
        self.lxc_exec(['apt', 'install', '-y', 'sshuttle'])
        if self.lxc_exec(['which', 'sshuttle']).returncode != 0:
            raise Exception('could not install sshuttle')
        self.sshuttle_version = self.lxc_exec('dpkg-query -f ${Version} -W sshuttle'.split()).stdout

    def _install_python(self):
        """Install the python interpreters and require python2 and python3 to be present.

        Raises:
            Exception: if ``python2`` or ``python3`` is not found afterwards.
        """
        self.log('installing python')
        # each package is installed separately so one failing does not block the rest
        for python in ['python', 'python2', 'python3']:
            self.lxc_exec(['apt', 'install', '-y', python])
        for python in ['python2', 'python3']:
            if self.lxc_exec(['which', python]).returncode != 0:
                raise Exception(f'could not install {python}')

    def snapshot_create(self, name='default', stop=True):
        """Snapshot the container as ``name``, stopping it first unless ``stop`` is false."""
        self.log(f'creating snapshot: {name}')
        if stop:
            self.stop()
        subprocess.run(['lxc', 'snapshot', self.name, name], check=True)

    def snapshot_restore(self, name='default', start=True):
        """Stop the container, restore snapshot ``name``, and start it again if ``start``."""
        self.log(f'restoring snapshot: {name}')
        self.stop()
        subprocess.run(['lxc', 'restore', self.name, name], check=True)
        if start:
            self.start()

    def snapshot_update(self, name='default', stop=True):
        """Replace snapshot ``name`` with a fresh snapshot of the current state."""
        self.log(f'updating snapshot: {name}')
        subprocess.run(['lxc', 'delete', f'{self.name}/{name}'], check=True)
        self.snapshot_create(name=name, stop=stop)

    @property
    def ssh_key(self):
        """Public ssh key generated inside the container."""
        return self._ssh_key

    def add_start_cmd(self, cmd):
        """Register ``cmd`` (a space-separated command string) to run after every start()."""
        self.log(f'adding start cmd: {cmd}')
        self._start_cmds.append(cmd)

    def add_file_content(self, path, content):
        """Append ``content`` to the container file ``path`` unless it is already in there.

        A missing remote file is treated as empty.
        """
        with tempfile.TemporaryDirectory() as d:
            localfile = Path(d) / Path(path).name
            self.lxc_file_pull(path, str(localfile))
            existing_content = localfile.read_text(encoding='utf-8') or ''
            if content not in existing_content:
                if existing_content and not existing_content.endswith('\n'):
                    existing_content += '\n'
                existing_content += content
                # write with the same explicit encoding the file was read with
                localfile.write_text(existing_content, encoding='utf-8')
                self.lxc_file_push(str(localfile), path)

    def add_ssh_key(self, key):
        """Authorize public ``key`` for root in the container."""
        self.log(f'adding ssh key: {key.strip()}')
        self.add_file_content('/root/.ssh/authorized_keys', key)

    def add_ssh_config(self, config):
        """Append ``config`` to root's ssh config in the container."""
        self.log('adding ssh config')
        self.add_file_content('/root/.ssh/config', config)

    def lxc_exec(self, cmd, **kwargs):
        """Run the argument list ``cmd`` inside the container via ``lxc exec``.

        stdout and stderr are captured as utf-8 text unless overridden through
        ``kwargs``, which are passed on to subprocess.run().

        Returns:
            The subprocess.CompletedProcess of the ``lxc exec`` call.
        """
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', subprocess.PIPE)
        kwargs.setdefault('encoding', 'utf-8')
        return subprocess.run(['lxc', 'exec', self.name, '--'] + cmd, **kwargs)

    def lxc_file_pull(self, remote, local, fail_if_missing=False):
        """Copy container file ``remote`` to the local path ``local``.

        When the pull fails and ``fail_if_missing`` is false, an empty local
        file is created instead.

        Raises:
            subprocess.CalledProcessError: if the pull fails and
                ``fail_if_missing`` is true.
        """
        remote = f'{self.name}{remote}'
        self.log(f'{local} <- {remote}')
        try:
            subprocess.run(['lxc', 'file', 'pull', remote, local], check=True, encoding='utf-8',
                           stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except subprocess.CalledProcessError:
            if fail_if_missing:
                raise
            localpath = Path(local)
            if localpath.is_dir():
                localpath = localpath / Path(remote).name
            localpath.touch()
            self.log(f'remote file missing, created empty file {localpath}')

    def lxc_file_push(self, local, remote):
        """Copy the local file ``local`` to the container path ``remote``.

        Raises:
            subprocess.CalledProcessError: if the push fails.
        """
        remote = f'{self.name}{remote}'
        self.log(f'{local} -> {remote}')
        subprocess.run(['lxc', 'file', 'push', local, remote], check=True)

    @property
    def json(self):
        """This container's entry from ``lxc list --format json``.

        Raises:
            Exception: if the listing does not contain exactly one entry with
                this container's name.
        """
        listjson = subprocess.run(['lxc', 'list', '--format', 'json'], encoding='utf-8',
                                  stdout=subprocess.PIPE).stdout
        filtered = [entry for entry in json.loads(listjson) if entry['name'] == self.name]
        if len(filtered) != 1:
            raise Exception(f'Expected only 1 lxc list entry for {self.name}, found {len(filtered)}:\n{listjson}')
        return filtered[0]

    @property
    def is_running(self):
        """True unless lxd reports the container status as ``Stopped``."""
        return self.json['status'] != 'Stopped'

    def start(self):
        """Start the container if needed, wait for networking, then run the start commands."""
        if not self.is_running:
            subprocess.run(['lxc', 'start', self.name])
            self._wait_for_networking()
            for cmd in self._start_cmds:
                self.lxc_exec(cmd.split())

    def stop(self):
        """Stop the container if it is not already stopped."""
        if self.is_running:
            subprocess.run(['lxc', 'stop', self.name])

    def __del__(self):
        """Force-delete the container, ignoring any output."""
        subprocess.run(['lxc', 'delete', '-f', self.name], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)


class VerboseAssertionError(AssertionError):
    """AssertionError whose message also includes the log lines queued beforehand.

    Log lines are collected at class level with add_log().  Creating an
    instance joins its arguments and all queued lines with newlines into the
    exception message and empties the queue.
    """
    # class-level queue of pending log lines, shared by all instances
    __logs = []

    def __init__(self, *args):
        # Arguments are not guaranteed to be strings (unittest's fail() passes
        # None when called without a message), so convert before joining.
        lines = [str(arg) for arg in args] + self.read_log()
        super(VerboseAssertionError, self).__init__('\n'.join(lines))

    @classmethod
    def add_log(cls, msg):
        """Queue ``str(msg)`` for inclusion in the next exception message."""
        cls.__logs.append(str(msg))

    @classmethod
    def read_log(cls):
        """Return the queued log lines and reset the queue to empty."""
        log = cls.__logs
        cls.__logs = []
        return log

    @classmethod
    def clear_log(cls):
        """Discard all queued log lines."""
        cls.read_log()


class SshuttleTest(unittest.TestCase):
    """sshuttle tests between the testbed, a per-release remote and the reverse remote.

    Concrete subclasses set ``RELEASE``; load_tests() generates one per
    release.  ``release``, ``reverse_remote``, ``testbed_ssh_key`` and
    ``testbed_ssh_config`` are set on this class by setUpModule().
    """
    RELEASE = None
    failureException = VerboseAssertionError

    @classmethod
    def setUpClass(cls):
        """Create the remote for ``cls.RELEASE`` and authorize its key on the reverse remote."""
        # note that some of the cls attrs used here are set by setUpModule()

        # this is set by the subclass, and required
        assert cls.RELEASE

        container = Remote(f'remote-{cls.RELEASE}', cls.RELEASE)
        container.add_ssh_key(cls.testbed_ssh_key)
        container.add_ssh_config(cls.testbed_ssh_config)
        container.private_addr = TESTBED["remote_private_addr"]
        container.private_subnet = TESTBED["remote_private_subnet"]
        container.add_start_cmd(f'ip a add {container.private_subnet} dev lo')
        container.snapshot_create()
        cls.remote = container

        # let the new remote ssh into the reverse remote, and keep that in its snapshot
        cls.reverse_remote.snapshot_restore()
        cls.reverse_remote.add_ssh_key(cls.remote.ssh_key)
        cls.reverse_remote.snapshot_update()

    @classmethod
    def tearDownClass(cls):
        """Drop the class reference to the per-release remote."""
        del cls.remote

    def setUp(self):
        """Restore both containers from their snapshots and reset the per-test sshuttle state."""
        self.reverse_remote.snapshot_restore()
        self.remote.snapshot_restore()
        self.sshuttle_process = None
        self.sshuttle_log = tempfile.NamedTemporaryFile()
        self.failureException.clear_log()

    def tearDown(self):
        """Stop sshuttle, close its log file and stop both containers."""
        self.sshuttle_stop()
        self.sshuttle_log.close()
        self.remote.stop()
        self.reverse_remote.stop()

    def add_failure_log_test_detail(self, src, dst, python, sshuttle_version):
        """Queue a line describing this test's parameters for any later failure message."""
        self.failureException.add_log(f'Test detail: {src} sshuttle {sshuttle_version} to {dst} {python}')

    @property
    def sshuttle_version(self):
        """Version of the sshuttle package installed on the testbed, as printed by dpkg-query."""
        query = ['dpkg-query', '-f', '${Version}', '-W', 'sshuttle']
        return subprocess.run(query, encoding='utf-8', stdout=subprocess.PIPE).stdout

    def sshuttle_start(self, src, dst, python, remote):
        """Start sshuttle towards ``dst`` for ``remote``'s private subnet and wait for it to connect.

        sshuttle runs on the testbed, or inside ``self.remote`` when ``remote``
        is the reverse remote.  Up to 60 one-second polls are made for the
        'client: Connected.' log line.  The sshuttle log is queued for failure
        messages, and the test fails if the process has already exited.
        """
        command = f'sshuttle --python {python} -r {dst} {remote.private_subnet}'
        if remote is self.reverse_remote:
            command = f'lxc exec {self.remote.name} -- {command}'
        print(f'running: {command}')
        self.sshuttle_process = subprocess.Popen(command.split(), encoding='utf-8',
                                                 stdout=self.sshuttle_log, stderr=self.sshuttle_log)

        log_path = Path(self.sshuttle_log.name)
        print('waiting for sshuttle to start...', end='', flush=True)
        for _ in range(60):
            if self.sshuttle_process.poll() is not None:
                print('sshuttle failed :-(')
                break
            if 'client: Connected.' in log_path.read_text(encoding='utf-8'):
                print('started')
                break
            time.sleep(1)
            print('.', end='', flush=True)
        else:
            print("WARNING: timed out waiting for sshuttle to start, the test may fail")

        self.failureException.add_log(log_path.read_text(encoding='utf-8'))
        if self.sshuttle_process.poll() is not None:
            self.fail('sshuttle process failed to start')

    def sshuttle_stop(self):
        """Terminate a still-running sshuttle process, killing it if it ignores the request."""
        process = self.sshuttle_process
        if not process or process.poll() is not None:
            return
        print('stopping local sshuttle')
        process.terminate()
        try:
            process.communicate(timeout=30)
        except subprocess.TimeoutExpired:
            process.kill()
            process.communicate(timeout=30)

    def ssh_local_to_remote(self):
        """Return True if ssh from the testbed to the remote's private address succeeds."""
        print(f'local {self.release} ssh to remote {self.remote.release} {self.remote.private_addr}')
        try:
            subprocess.run(['ssh', self.remote.private_addr, 'true'], check=True)
        except subprocess.CalledProcessError:
            return False
        return True

    def test_local_to_remote_py2(self):
        self._test_local_to_remote('python2')

    def test_local_to_remote_py3(self):
        self._test_local_to_remote('python3')

    def _test_local_to_remote(self, python):
        """Check ssh to the remote's private address fails without sshuttle and works with it.

        A failure listed by is_expected_failure() turns the test into a skip.
        """
        self.add_failure_log_test_detail(self.release, self.remote.release, python, self.sshuttle_version)
        try:
            print('this ssh connection should timeout:')
            self.assertFalse(self.ssh_local_to_remote())
            self.sshuttle_start('local', self.remote.fqdn, python, self.remote)
            print('this ssh connection should not timeout:')
            self.assertTrue(self.ssh_local_to_remote())
        except AssertionError:
            if not is_expected_failure(self.release, self.remote.release, python, self.sshuttle_version):
                raise
            self.skipTest('This is an expected failure, ignoring test failure')

    def ssh_remote_to_reverse_remote(self):
        """Return True if ssh from the remote to the reverse remote's private address succeeds."""
        print(f'remote {self.remote.release} ssh to remote {self.reverse_remote.release} {self.reverse_remote.private_addr}')
        try:
            self.remote.lxc_exec(['ssh', self.reverse_remote.private_addr, 'true'],
                                 stdout=sys.stdout, stderr=sys.stderr, check=True)
        except subprocess.CalledProcessError:
            return False
        return True

    def test_remote_to_reverse_remote_py2(self):
        self._test_remote_to_reverse_remote('python2')

    def test_remote_to_reverse_remote_py3(self):
        self._test_remote_to_reverse_remote('python3')

    def _test_remote_to_reverse_remote(self, python):
        """Check ssh from the remote to the reverse remote fails without sshuttle and works with it.

        A failure listed by is_expected_failure() turns the test into a skip.
        """
        self.add_failure_log_test_detail(self.remote.release, self.reverse_remote.release, python, self.remote.sshuttle_version)
        try:
            print('this ssh connection should timeout:')
            self.assertFalse(self.ssh_remote_to_reverse_remote())
            self.sshuttle_start(self.remote.name, self.reverse_remote.fqdn, python, self.reverse_remote)
            print('this ssh connection should not timeout:')
            self.assertTrue(self.ssh_remote_to_reverse_remote())
        except AssertionError:
            if not is_expected_failure(self.remote.release, self.reverse_remote.release, python, self.remote.sshuttle_version):
                raise
            self.skipTest('This is an expected failure, ignoring test failure')


if __name__ == '__main__':
    # Any command line arguments are release names to limit the run to; they
    # are removed from sys.argv so unittest does not try to interpret them.
    requested_releases = sys.argv[1:]
    if requested_releases:
        set_releases(requested_releases)
        del sys.argv[1:]
    unittest.main(verbosity=2)
