#!/usr/bin/python3
# Let's stick to core python3 modules
import argparse
import gettext
import hashlib
import http.client
import io
import json
import os
import shutil
import socket
import subprocess
import sys
import tarfile
import tempfile
import urllib.request

# External dependencies:
# - gnupg
# - xz (or pxz)

_ = gettext.gettext
gettext.textdomain("lxd")


class FriendlyParser(argparse.ArgumentParser):
    """Argument parser that prints the full help text on a usage error."""

    def error(self, message):
        """Write ``message`` to stderr, print the help, exit with status 2."""
        print('error: %s' % message, file=sys.stderr)
        self.print_help()
        raise SystemExit(2)


def find_on_path(command):
    """Return True when ``command`` is an executable file in a PATH entry.

    Empty PATH entries are ignored, and a missing PATH variable yields False.
    """

    search_path = os.environ.get('PATH')
    if search_path is None:
        return False

    candidates = (os.path.join(directory, command)
                  for directory in search_path.split(os.pathsep)
                  if directory)
    return any(os.path.isfile(candidate) and os.access(candidate, os.X_OK)
               for candidate in candidates)


class UnixHTTPConnection(http.client.HTTPConnection):
    """HTTP connection that talks over a UNIX domain socket."""

    def __init__(self, path):
        """Remember the socket ``path``; 'localhost' is passed as the host."""
        super().__init__('localhost')
        self.path = path

    def connect(self):
        """Connect a stream socket to ``self.path`` and keep it as ``sock``."""
        unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        unix_socket.connect(self.path)
        # Only publish the socket once the connection has succeeded.
        self.sock = unix_socket


class LXD(object):
    """Small client for the LXD REST API, reached through its UNIX socket."""

    def __init__(self, path):
        """Create the connection object for the LXD socket at ``path``."""
        self.lxd = UnixHTTPConnection(path)

    @staticmethod
    def _query_url(path, params):
        """Return ``path`` with ``params`` appended as an encoded query."""
        from urllib.parse import urlencode

        return "%s?%s" % (path, urlencode(params))

    def rest_call(self, path, data=None, method="GET"):
        """Send a request and return ``(status, decoded_json_body)``.

        For a GET with ``data`` (a mapping), the items are sent as query
        string parameters on ``path``. For every other case ``data`` is
        passed through as the request body.
        """
        if method == "GET" and data:
            # The previous code formatted "%s?%s" with a single argument,
            # which raised TypeError and never included the path.
            self.lxd.request(method, self._query_url(path, data))
        else:
            self.lxd.request(method, path, data)

        response = self.lxd.getresponse()
        payload = json.loads(response.read().decode("utf-8"))
        return response.status, payload

    def aliases_create(self, name, target):
        """Create alias ``name`` pointing at ``target``.

        Raises:
            Exception: if the server does not answer with status 200.
        """
        body = json.dumps({"target": target,
                           "name": name})

        status, _reply = self.rest_call("/1.0/images/aliases", body, "POST")

        if status != 200:
            raise Exception("Failed to create alias: %s" % name)

    def aliases_remove(self, name):
        """Delete alias ``name``.

        Raises:
            Exception: if the server does not answer with status 200.
        """
        status, _reply = self.rest_call("/1.0/images/aliases/%s" % name,
                                        method="DELETE")

        if status != 200:
            raise Exception("Failed to remove alias: %s" % name)

    def aliases_list(self):
        """Return the alias names, stripped of their API URL prefix."""
        _status, data = self.rest_call("/1.0/images/aliases")

        return [alias.split("/1.0/images/aliases/")[-1]
                for alias in data['metadata']]

    def images_list(self):
        """Return the image identifiers, stripped of their API URL prefix."""
        _status, data = self.rest_call("/1.0/images")

        return [image.split("/1.0/images/")[-1] for image in data['metadata']]

    def images_upload(self, path, filename):
        """POST the file at ``path`` to the image store.

        ``filename`` is accepted for interface compatibility but is not
        used by this method.

        Returns:
            The ``metadata`` entry of the server's reply.

        Raises:
            Exception: if the server does not answer with status 200.
        """
        # Close the image file once the request/response cycle is over
        # instead of leaking the descriptor.
        with open(path, "rb") as image:
            status, data = self.rest_call("/1.0/images", image, "POST")

        if status != 200:
            raise Exception("Failed to upload the image: %s" % status)

        return data['metadata']


class Busybox(object):
    """Builds an image tarball around the host's /bin/busybox binary."""

    # Scratch directory. Stays None until __init__ has created it, so
    # __del__ is safe on an instance whose construction failed.
    workdir = None

    def __init__(self):
        # Create our workdir
        self.workdir = tempfile.mkdtemp()

    def __del__(self):
        # Remove the scratch directory together with any tarball in it.
        if self.workdir:
            shutil.rmtree(self.workdir)

    def create_tarball(self):
        """Create ``busybox.tar.xz`` in the workdir and return its path.

        The archive holds the busybox binary under ``rootfs/bin/busybox``,
        one symlink per path printed by ``busybox --list-full``, a few empty
        directories and a ``metadata.yaml`` file (JSON-formatted).

        Raises:
            Exception: if listing the applets or compressing the tarball
                fails.
        """
        xz = "pxz" if find_on_path("pxz") else "xz"

        destination_tar = os.path.join(self.workdir, "busybox.tar")

        arch = os.uname()[4]
        busybox_stat = os.stat("/bin/busybox")

        metadata = {'architecture': arch,
                    'creation_date': int(busybox_stat.st_ctime),
                    'properties': {
                        'os': "Busybox",
                        'architecture': arch,
                        'description': "Busybox %s" % arch,
                        'name': "busybox-%s" % arch
                        },
                    }

        # List the applets. communicate() drains the pipe while waiting;
        # calling wait() before reading can block forever once the child
        # fills the pipe buffer.
        busybox = subprocess.Popen(["/bin/busybox", "--list-full"],
                                   stdout=subprocess.PIPE,
                                   universal_newlines=True)
        applet_listing = busybox.communicate()[0]
        if busybox.returncode:
            raise Exception("Failed to list the busybox applets")

        with tarfile.open(destination_tar, "w:") as target_tarball:
            # Add busybox
            with open("/bin/busybox", "rb") as fd:
                busybox_file = tarfile.TarInfo()
                busybox_file.size = busybox_stat.st_size
                busybox_file.name = "rootfs/bin/busybox"
                busybox_file.mode = 0o755
                target_tarball.addfile(busybox_file, fd)

            # Add symlinks
            for applet in applet_listing.splitlines():
                applet = applet.strip()
                if not applet:
                    continue

                symlink_file = tarfile.TarInfo()
                symlink_file.name = "rootfs/%s" % applet
                symlink_file.type = tarfile.SYMTYPE
                symlink_file.linkname = "/bin/busybox"
                target_tarball.addfile(symlink_file)

            # Add directories
            for directory in ("dev", "mnt", "proc", "root", "sys", "tmp"):
                directory_file = tarfile.TarInfo()
                directory_file.name = "rootfs/%s" % directory
                directory_file.type = tarfile.DIRTYPE
                target_tarball.addfile(directory_file)

            # Add the metadata file
            metadata_yaml = json.dumps(metadata, sort_keys=True,
                                       indent=4, separators=(',', ': '),
                                       ensure_ascii=False
                                       ).encode('utf-8') + b"\n"

            metadata_file = tarfile.TarInfo()
            metadata_file.size = len(metadata_yaml)
            metadata_file.name = "metadata.yaml"
            target_tarball.addfile(metadata_file,
                                   io.BytesIO(metadata_yaml))

        # Compress the tarball
        r = subprocess.call([xz, "-9", destination_tar])
        if r:
            # Report the tarball itself; the old message used a stale loop
            # variable and always named "tmp".
            raise Exception("Failed to compress: %s" % destination_tar)

        return destination_tar + ".xz"


class LXCImages(object):
    """Downloads and GPG-verifies images listed in a server's index."""

    # Scratch directory. Stays None until __init__ has created it, so
    # __del__ is safe on an instance whose construction failed.
    workdir = None
    # List of image description dicts. index_update() rebinds this
    # attribute on the instance and never mutates this shared default.
    images = []

    def __init__(self,
                 server="https://images.linuxcontainers.org",
                 gpgkey="0xBAEFF88C22F6E216"):
        """Prepare a workdir, fetch the GPG key and load the image index.

        Args:
            server: base URL of the image server.
            gpgkey: id of the key the server's files are checked against.
        """

        # Create our workdir
        self.workdir = tempfile.mkdtemp()
        self.gpgdir = "%s/gpg" % self.workdir
        os.mkdir(self.gpgdir, 0o700)

        # Set variables
        self.server = server
        self.gpgkey = gpgkey

        # Get ready to work with this server
        self.gpg_update()
        self.index_update()

    def __del__(self):
        # Remove the scratch directory (keyring, index and downloads).
        if self.workdir:
            shutil.rmtree(self.workdir)

    def __grab_and_validate(self, url, dest):
        """Download ``url`` and ``url + '.asc'``, then verify the signature.

        Raises:
            Exception: if the signature verification fails.
        """
        # Main file
        urllib.request.urlretrieve(url, dest)

        # Signature
        urllib.request.urlretrieve(url + ".asc",
                                   dest + ".asc")

        # Verify the signature
        self.gpg_verify(dest + ".asc")

    def _run_gpg(self, arguments):
        """Run gpg quietly against our private keyring; return its status."""
        gpg_environ = dict(os.environ)
        gpg_environ["GNUPGHOME"] = self.gpgdir

        with open(os.devnull, "w") as devnull:
            return subprocess.call(["gpg"] + list(arguments),
                                   env=gpg_environ,
                                   stdout=devnull, stderr=devnull)

    def gpg_update(self):
        """Fetch ``self.gpgkey`` from the keyserver into the private keyring.

        Raises:
            Exception: if gpg exits with a non-zero status.
        """
        # Translate the template first, then interpolate. Formatting inside
        # _() makes gettext look up the already-formatted string.
        print(_("Downloading the GPG key for %s") % self.server)

        r = self._run_gpg(["--keyserver", "hkp://pool.sks-keyservers.net",
                           "--recv-keys", self.gpgkey])

        if r:
            raise Exception("Failed to retrieve the GPG key")

    def gpg_verify(self, path):
        """Verify the signature file at ``path`` with the private keyring.

        Raises:
            Exception: if gpg exits with a non-zero status.
        """
        print(_("Validating the GPG signature of %s") % path)

        r = self._run_gpg(["--verify", path])

        if r:
            raise Exception("GPG signature verification failed for: %s" % path)

    @staticmethod
    def _parse_index(lines):
        """Turn ``;``-separated index lines into image description dicts.

        Blank lines and lines with fewer than six fields cannot describe an
        image and are skipped instead of raising IndexError.
        """
        images = []
        for line in lines:
            fields = line.strip().split(";")
            if len(fields) < 6:
                continue

            images.append({"distribution": fields[0],
                           "release": fields[1],
                           "architecture": fields[2],
                           "variant": fields[3],
                           "version": fields[4],
                           "path": fields[5]})
        return images

    def index_update(self):
        """Download and verify the server's index, then load ``self.images``."""
        print(_("Downloading the image list for %s") % self.server)
        index = "%s/meta/1.0/index-user" % self.server
        index_path = os.path.join(self.workdir, "index.json")

        self.__grab_and_validate(index, index_path)

        with open(index_path, "r") as fd:
            self.images = self._parse_index(fd)

    def _find_image(self, distribution, release, architecture, variant):
        """Return the first index entry matching all four fields, or None."""
        for image in self.images:
            if image['distribution'] == distribution and \
                    image['release'] == release and \
                    image['architecture'] == architecture and \
                    image['variant'] == variant:
                return image

        return None

    def image_exists(self, distribution, release,
                     architecture, variant="default"):
        """Return the matching image's version, or False when none matches."""
        image = self._find_image(distribution, release,
                                 architecture, variant)
        if image is None:
            return False

        return image['version']

    def image_download(self, distribution, release,
                       architecture, variant="default"):
        """Download and verify the matching image; return its local path.

        Raises:
            Exception: if no image matches or the signature check fails.
        """
        image = self._find_image(distribution, release,
                                 architecture, variant)
        if image is None:
            raise Exception("Requested image doesn't exist")

        path = os.path.join(self.workdir,
                            "%s-%s-%s-%s-%s.tar.xz" %
                            (image['distribution'], image['release'],
                             image['architecture'], image['variant'],
                             image['version']))
        url = "%s%slxd.tar.xz" % (self.server, image['path'])

        print(_("Downloading the image: %s") % url)
        self.__grab_and_validate(url, path)
        return path


if __name__ == "__main__":
    # Locate the LXD socket; LXD_DIR overrides the default location.
    if "LXD_DIR" in os.environ:
        lxd_socket = os.path.join(os.environ['LXD_DIR'], "unix.socket")
    else:
        lxd_socket = "/var/lib/lxd/unix.socket"

    if not os.path.exists(lxd_socket):
        print(_("LXD isn't running."))
        sys.exit(1)
    lxd = LXD(lxd_socket)

    def file_fingerprint(path):
        """Return the hex SHA-256 of the file at ``path``.

        The file is hashed in chunks so that a large image is never held
        in memory in one piece.
        """
        digest = hashlib.sha256()
        with open(path, "rb") as fd:
            for chunk in iter(lambda: fd.read(1024 * 1024), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def setup_alias(parser, args, fingerprint):
        """Point every requested alias at ``fingerprint``.

        An alias that already exists is removed first and then recreated.
        """
        existing = lxd.aliases_list()

        for alias in args.alias:
            if alias in existing:
                lxd.aliases_remove(alias)
            lxd.aliases_create(alias, fingerprint)
            # Translate the template first, then interpolate. Formatting
            # inside _() makes gettext look up the formatted string.
            print(_("Setup alias: %s") % alias)

    def upload_image(parser, args, path):
        """Upload the image at ``path`` unless the store already has it.

        Exits with status 1 when the image's fingerprint is already in the
        store; otherwise uploads it and sets up the requested aliases.
        """
        fingerprint = file_fingerprint(path)

        if fingerprint in lxd.images_list():
            parser.exit(1, _("This image is already in the store.\n"))

        r = lxd.images_upload(path, os.path.basename(path))
        print(_("Image imported as: %s") % r['fingerprint'])

        setup_alias(parser, args, fingerprint)

    def import_busybox(parser, args):
        """Build a busybox image from the host binary and import it."""
        # Keep the builder referenced until the upload is done: its workdir
        # (which contains the tarball) is removed when it is destroyed.
        busybox = Busybox()
        path = busybox.create_tarball()

        upload_image(parser, args, path)

    def import_lxc(parser, args):
        """Download the requested image from the image server and import it."""
        # Keep the downloader referenced until the upload is done: its
        # workdir (which contains the download) is removed with it.
        lxc = LXCImages()

        if not lxc.image_exists(args.distribution, args.release,
                                args.architecture, args.variant):
            parser.exit(1, _("Requested image doesn't exist.\n"))

        path = lxc.image_download(args.distribution, args.release,
                                  args.architecture, args.variant)

        upload_image(parser, args, path)

    parser = FriendlyParser(
        description=_("LXD: image store helper"),
        formatter_class=argparse.RawTextHelpFormatter,
        epilog=_("""Examples:
 To import the latest Ubuntu 14.04 LTS 64bit image with some aliases:
    %s import lxc ubuntu trusty amd64 --alias ubuntu --alias ubuntu/trusty

 To import a basic busybox image:
    %s import busybox --alias busybox
""") % (sys.argv[0], sys.argv[0]))

    parser_subparsers = parser.add_subparsers(dest="action")
    parser_subparsers.required = True

    # Image import
    parser_import = parser_subparsers.add_parser(
        "import", help=_("Import images"))
    parser_import_subparsers = parser_import.add_subparsers(dest="source")
    parser_import_subparsers.required = True

    # Busybox
    parser_import_busybox = parser_import_subparsers.add_parser(
        "busybox", help=_("Busybox image"))
    parser_import_busybox.add_argument("--alias", action="append", default=[],
                                       help=_("Aliases for the image"))
    parser_import_busybox.set_defaults(func=import_busybox)

    # LXC
    parser_import_lxc = parser_import_subparsers.add_parser(
        "lxc", help=_("LXC images"))
    parser_import_lxc.add_argument("distribution", help=_("Distribution"))
    parser_import_lxc.add_argument("release", help=_("Release"))
    parser_import_lxc.add_argument("architecture", help=_("Architecture"))
    parser_import_lxc.add_argument("variant", help=_("Variant"),
                                   default="default", nargs="?")
    parser_import_lxc.add_argument("--alias", action="append", default=[],
                                   help=_("Aliases for the image"))
    parser_import_lxc.set_defaults(func=import_lxc)

    # Call the function
    args = parser.parse_args()
    args.func(parser, args)
