#!/usr/bin/env python3
# Let's stick to core python3 modules
import argparse
import gettext
import hashlib
import http.client
import io
import json
import os
import shutil
import socket
import subprocess
import sys
import tarfile
import tempfile
import urllib.request
import uuid

# External dependencies:
# - gnupg
# - xz (or pxz)

# Conventional gettext alias: user-visible strings in this file are wrapped in
# _() so they can be looked up in a translation catalogue.
_ = gettext.gettext
# Select "lxd" as the message domain used for those lookups.
gettext.textdomain("lxd")


class FriendlyParser(argparse.ArgumentParser):
    """Argument parser that prints the full help text on a usage error."""

    def error(self, message):
        """Write *message* to stderr, print the help and exit with status 2."""
        line = 'error: %s\n' % message
        sys.stderr.write(line)
        self.print_help()
        raise SystemExit(2)


def find_on_path(command):
    """Return True when an executable file named *command* is on $PATH.

    Empty PATH entries are skipped; a missing PATH variable yields False.
    """
    search_path = os.environ.get('PATH')
    if search_path is None:
        return False

    candidates = (os.path.join(directory, command)
                  for directory in search_path.split(os.pathsep)
                  if directory)
    return any(os.path.isfile(candidate) and os.access(candidate, os.X_OK)
               for candidate in candidates)


def report_download(blocks_read, block_size, total_size):
    """Progress hook for urllib.request.urlretrieve.

    Prints a single, carriage-return terminated "Progress: NN %" line.

    Nothing is printed when the total size is unknown or zero (urllib
    reports an unknown size as -1), which previously caused a division by
    zero or a negative percentage.  Nothing is printed either when the
    computed value overshoots 100 %.
    """
    if total_size <= 0:
        return

    percent = blocks_read * block_size / total_size * 100
    if percent > 100:
        return

    print("Progress: %.0f %%" % percent, end='\r')


def local_architecture():
    """Return the Debian-style architecture name of the local machine.

    The value configured in APT is preferred when the optional ``apt_pkg``
    module is usable.  Otherwise the kernel machine name from ``os.uname()``
    is translated through a small table; a machine name that is not in the
    table is returned unchanged instead of raising ``KeyError``.
    """
    try:
        import apt_pkg
        apt_pkg.init()
        return apt_pkg.config.find("APT::Architecture").lower()
    except Exception:
        # apt_pkg is optional: fall back to the kernel's view.  Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        pass

    arch_tables = {'x86_64': "amd64",
                   'i686': "i386",
                   'armv7l': "armhf",
                   'aarch64': "arm64",
                   'ppc': "powerpc",
                   'ppc64le': "ppc64el"}

    kernel_arch = os.uname().machine

    return arch_tables.get(kernel_arch, kernel_arch)


class UnixHTTPConnection(http.client.HTTPConnection):
    """HTTPConnection that connects to a UNIX domain socket.

    'localhost' is handed to the base class as the host name; the actual
    endpoint is the socket file at *path*.
    """

    def __init__(self, path):
        super().__init__('localhost')
        self.path = path

    def connect(self):
        """Open a stream socket to ``self.path`` and store it as ``sock``."""
        unix_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        unix_sock.connect(self.path)
        # Only publish the socket once the connection succeeded.
        self.sock = unix_sock


class LXD(object):
    """Small client for the LXD REST API reached through its UNIX socket."""

    def __init__(self, path):
        """Create the client; *path* is the path of the LXD UNIX socket."""
        self.lxd = UnixHTTPConnection(path)

    def rest_call(self, path, data=None, method="GET", headers=None):
        """Send one request and return ``(http_status, decoded_json)``.

        For a GET request with *data*, *data* must be a mapping and is sent
        URL-encoded as the query string.  For any other request *data* is
        sent as the request body.  *headers* is an optional dict of extra
        HTTP headers.
        """
        # Function-scope import keeps the block self-contained.
        from urllib.parse import urlencode

        # Avoid sharing one mutable default dict between calls.
        headers = {} if headers is None else headers

        if method == "GET" and data:
            # The URL needs both the path and the query string, and the
            # headers must be passed by keyword: the third positional
            # argument of request() is the body.
            self.lxd.request(method, "%s?%s" % (path, urlencode(data)),
                             headers=headers)
        else:
            self.lxd.request(method, path, data, headers)

        r = self.lxd.getresponse()
        d = json.loads(r.read().decode("utf-8"))
        return r.status, d

    def aliases_create(self, name, target):
        """Create image alias *name* pointing at *target*.

        Raises Exception when the server does not answer with status 200.
        """
        data = json.dumps({"target": target,
                           "name": name})

        status, data = self.rest_call("/1.0/images/aliases", data, "POST")

        if status != 200:
            raise Exception("Failed to create alias: %s" % name)

    def aliases_remove(self, name):
        """Delete image alias *name*.

        Raises Exception when the server does not answer with status 200.
        """
        status, data = self.rest_call("/1.0/images/aliases/%s" % name,
                                      method="DELETE")

        if status != 200:
            raise Exception("Failed to remove alias: %s" % name)

    def aliases_list(self):
        """Return the alias names, with the API URL prefix stripped."""
        status, data = self.rest_call("/1.0/images/aliases")

        return [alias.split("/1.0/images/aliases/")[-1]
                for alias in data['metadata']]

    def images_list(self):
        """Return the image identifiers, with the API URL prefix stripped."""
        status, data = self.rest_call("/1.0/images")

        return [image.split("/1.0/images/")[-1] for image in data['metadata']]

    def images_upload(self, path, filename, public):
        """Upload an image and return the ``metadata`` field of the reply.

        *path* is either a single file path (str), uploaded as a raw octet
        stream, or a ``(meta_path, rootfs_path)`` pair, uploaded as a
        multipart form with the parts "metadata" and "rootfs".
        *filename* is accepted for interface compatibility but not used;
        the multipart part file names are the base names of the given
        paths.  *public* adds the ``X-LXD-public`` header when true.

        Raises Exception when the server does not answer with status 200.
        """
        headers = {}
        if public:
            headers['X-LXD-public'] = "1"

        if isinstance(path, str):
            headers['Content-Type'] = "application/octet-stream"

            # Close the image file once the request has been sent.
            with open(path, "rb") as fd:
                status, data = self.rest_call("/1.0/images", fd,
                                              "POST", headers)
        else:
            meta_path, rootfs_path = path
            boundary = str(uuid.uuid1())

            form = []
            for part_name, part_path in (("metadata", meta_path),
                                         ("rootfs", rootfs_path)):
                form.append("--%s" % boundary)
                form.append("Content-Disposition: form-data; "
                            "name=%s; filename=%s"
                            % (part_name, os.path.basename(part_path)))
                form.append("Content-Type: application/octet-stream")
                form.append("")
                with open(part_path, "rb") as fd:
                    form.append(fd.read())

            form.append("--%s--" % boundary)
            form.append("")

            # Every entry is terminated by CRLF; text entries are encoded.
            body = b"".join(
                (entry if isinstance(entry, bytes) else entry.encode())
                + b"\r\n"
                for entry in form)

            headers['Content-Type'] = "multipart/form-data; boundary=%s" \
                % boundary

            status, data = self.rest_call("/1.0/images", body, "POST", headers)

        if status != 200:
            raise Exception("Failed to upload the image: %s" % status)

        return data['metadata']


class Image(object):
    """Shared GPG and download helpers for the remote image sources.

    The methods read ``self.server``, ``self.gpgkey`` and ``self.gpgdir``;
    this class does not set them, so a subclass has to assign them before
    calling any of these helpers.
    """

    def _run_gpg(self, arguments):
        """Run gpg quietly with GNUPGHOME=self.gpgdir; return its status."""
        gpg_environ = dict(os.environ)
        gpg_environ["GNUPGHOME"] = self.gpgdir

        return subprocess.call(["gpg"] + list(arguments),
                               env=gpg_environ,
                               stdout=subprocess.DEVNULL,
                               stderr=subprocess.DEVNULL)

    def gpg_update(self):
        """Fetch ``self.gpgkey`` from the keyserver into the private keyring.

        Raises Exception when gpg exits with a non-zero status.
        """
        # Translate the template first, then substitute the value.
        print(_("Downloading the GPG key for %s") % self.server)

        r = self._run_gpg(
            ["--keyserver", "hkp://p80.pool.sks-keyservers.net:80",
             "--recv-keys", self.gpgkey])

        if r:
            raise Exception("Failed to retrieve the GPG key")

    def gpg_verify(self, path):
        """Verify the signature file *path* with ``gpg --verify``.

        Raises Exception when gpg exits with a non-zero status.
        """
        print(_("Validating the GPG signature of %s") % path)

        r = self._run_gpg(["--verify", path])

        if r:
            raise Exception("GPG signature verification failed for: %s" % path)

    def grab_and_validate(self, url, dest, gpg_suffix=".asc"):
        """Download *url* to *dest* together with its signature and verify.

        The signature is fetched from ``url + gpg_suffix`` and always stored
        as ``dest + ".asc"``.

        Raises Exception when a download fails or the signature does not
        verify.
        """
        try:
            # Main file
            urllib.request.urlretrieve(url, dest, report_download)

            # Signature
            urllib.request.urlretrieve(url + gpg_suffix,
                                       dest + ".asc", report_download)
        except (socket.timeout, IOError) as e:
            # A tuple is required here: "socket.timeout or IOError" would
            # evaluate to socket.timeout alone and let URLError escape.
            raise Exception("Failed to download \"%s\": %s" % (url, e)) from e

        # Verify the signature
        self.gpg_verify(dest + ".asc")


class Busybox(object):
    """Build an image tarball from the host's /bin/busybox.

    A temporary working directory is created on construction and removed
    when the object is finalised.
    """

    workdir = None

    def __init__(self):
        # Create our workdir
        self.workdir = tempfile.mkdtemp()

    def __del__(self):
        # Detach the attribute first so a second call (for example an
        # explicit call followed by garbage collection) is a no-op.
        workdir, self.workdir = self.workdir, None
        if workdir:
            shutil.rmtree(workdir, ignore_errors=True)

    def _add_rootfs(self, tarball, prefix):
        """Add busybox, its applet symlinks and base dirs to *tarball*.

        Every member name is prefixed with *prefix* ("" or "rootfs/").
        """
        # Add busybox
        with open("/bin/busybox", "rb") as fd:
            busybox_file = tarfile.TarInfo()
            busybox_file.size = os.stat("/bin/busybox").st_size
            busybox_file.mode = 0o755
            busybox_file.name = prefix + "bin/busybox"
            tarball.addfile(busybox_file, fd)

        # Add symlinks.  communicate() drains the pipe while waiting;
        # calling wait() before reading can block once the pipe fills up.
        busybox = subprocess.Popen(["/bin/busybox", "--list-full"],
                                   stdout=subprocess.PIPE,
                                   universal_newlines=True)
        applets = busybox.communicate()[0]

        for path in applets.split("\n"):
            if not path.strip():
                continue

            symlink_file = tarfile.TarInfo()
            symlink_file.type = tarfile.SYMTYPE
            symlink_file.linkname = "/bin/busybox"
            symlink_file.name = prefix + path.strip()
            tarball.addfile(symlink_file)

        # Add directories
        for path in ("dev", "mnt", "proc", "root", "sys", "tmp"):
            directory_file = tarfile.TarInfo()
            directory_file.type = tarfile.DIRTYPE
            directory_file.name = prefix + path
            tarball.addfile(directory_file)

    def create_tarball(self, split=False):
        """Create the xz-compressed busybox image inside the workdir.

        With ``split=False`` a single tarball is written that holds
        ``metadata.yaml`` and the root filesystem below ``rootfs/``; its
        path is returned.  With ``split=True`` the metadata tarball and a
        separate root filesystem tarball are written and the pair
        ``(metadata_path, rootfs_path)`` is returned.

        Compression uses ``pxz`` when it is on the PATH, ``xz`` otherwise.
        Raises Exception when compression fails.
        """
        xz = "pxz" if find_on_path("pxz") else "xz"
        arch = os.uname()[4]

        destination_tar = os.path.join(self.workdir, "busybox.tar")
        destination_tar_rootfs = os.path.join(self.workdir,
                                              "busybox.rootfs.tar")

        metadata = {'architecture': arch,
                    'creation_date': int(os.stat("/bin/busybox").st_ctime),
                    'properties': {
                        'os': "Busybox",
                        'architecture': arch,
                        'description': "Busybox %s" % arch,
                        'name': "busybox-%s" % arch
                        },
                    }

        metadata_yaml = json.dumps(metadata, sort_keys=True,
                                   indent=4, separators=(',', ': '),
                                   ensure_ascii=False).encode('utf-8') + b"\n"

        # The context managers close the archives even when adding a
        # member fails.
        with tarfile.open(destination_tar, "w:") as target_tarball:
            if split:
                with tarfile.open(destination_tar_rootfs,
                                  "w:") as target_tarball_rootfs:
                    self._add_rootfs(target_tarball_rootfs, "")
            else:
                self._add_rootfs(target_tarball, "rootfs/")

            # Add the metadata file
            metadata_file = tarfile.TarInfo()
            metadata_file.size = len(metadata_yaml)
            metadata_file.name = "metadata.yaml"
            target_tarball.addfile(metadata_file,
                                   io.BytesIO(metadata_yaml))

        # Compress the tarball
        r = subprocess.call([xz, "-9", destination_tar])
        if r:
            raise Exception("Failed to compress: %s" % destination_tar)

        if not split:
            return destination_tar + ".xz"

        r = subprocess.call([xz, "-9", destination_tar_rootfs])
        if r:
            raise Exception("Failed to compress: %s" %
                            destination_tar_rootfs)
        return destination_tar + ".xz", destination_tar_rootfs + ".xz"


class LXCImages(Image):
    """Image source backed by an LXC image server index.

    Construction creates a temporary workdir with a private GPG home,
    retrieves the signing key and downloads the image index.
    """

    workdir = None
    images = []

    # Order of the ';'-separated fields in one index line.
    _INDEX_FIELDS = ("distribution", "release", "architecture",
                     "variant", "version", "path")

    def __init__(self,
                 server="https://images.linuxcontainers.org",
                 gpgkey="0xBAEFF88C22F6E216"):

        # Create our workdir
        self.workdir = tempfile.mkdtemp()
        self.gpgdir = "%s/gpg" % self.workdir
        os.mkdir(self.gpgdir, 0o700)

        # Set variables
        self.server = server
        self.gpgkey = gpgkey

        # Get ready to work with this server
        self.gpg_update()
        self.index_update()

    def __del__(self):
        # Detach the attribute first so a repeated call is a no-op.
        workdir, self.workdir = self.workdir, None
        if workdir:
            shutil.rmtree(workdir, ignore_errors=True)

    def index_update(self):
        """Download, verify and parse the server's image index.

        The result is stored in ``self.images`` as a list of dicts.  Lines
        with fewer than six fields (blank lines, for instance) are skipped
        instead of raising IndexError.
        """
        # Translate the template first, then substitute the value.
        print(_("Downloading the image list for %s") % self.server)
        index = "%s/meta/1.0/index-user" % self.server
        index_path = os.path.join(self.workdir, "index.json")

        self.grab_and_validate(index, index_path)

        images = []
        with open(index_path, "r") as fd:
            for line in fd:
                fields = line.strip().split(";")
                if len(fields) < len(self._INDEX_FIELDS):
                    continue

                # zip() stops after the six known fields.
                images.append(dict(zip(self._INDEX_FIELDS, fields)))

        self.images = images

    def _find_image(self, distribution, release, architecture, variant):
        """Return the first matching index entry, or None."""
        wanted = (distribution, release, architecture, variant)
        for image in self.images:
            if (image['distribution'], image['release'],
                    image['architecture'], image['variant']) == wanted:
                return image

        return None

    def image_exists(self, distribution, release,
                     architecture, variant="default"):
        """Return the version of the matching image, or False if none."""
        image = self._find_image(distribution, release,
                                 architecture, variant)
        if image is None:
            return False

        return image['version']

    def image_download(self, distribution, release,
                       architecture, variant="default"):
        """Download and verify the matching image; return its local path.

        Raises Exception when no index entry matches, when the download
        fails or when the signature does not verify.
        """
        image = self._find_image(distribution, release,
                                 architecture, variant)
        if image is None:
            raise Exception("Requested image doesn't exist")

        path = os.path.join(self.workdir,
                            "%s-%s-%s-%s-%s.tar.xz" %
                            (image['distribution'], image['release'],
                             image['architecture'], image['variant'],
                             image['version']))
        url = "%s%slxd.tar.xz" % (self.server, image['path'])

        print(_("Downloading the image: %s") % url)
        self.grab_and_validate(url, path)
        return path


class Ubuntu(Image):
    """Image source backed by the Ubuntu cloud image streams.

    Construction creates a temporary workdir with a private GPG home and
    retrieves the signing key.
    """

    workdir = None
    server = None
    stream = None

    # FIXME: We should default to the released stream once available
    def __init__(self, server="https://cloud-images.ubuntu.com",
                 stream="daily",
                 gpgkey="7FF3F408476CF100"):

        # Create our workdir
        self.workdir = tempfile.mkdtemp()
        self.gpgdir = "%s/gpg" % self.workdir
        os.mkdir(self.gpgdir, 0o700)

        # Set variables
        self.server = server
        self.stream = stream
        self.gpgkey = gpgkey

        # Get ready to work with this server
        self.gpg_update()

    def __del__(self):
        # Detach the attribute first so a repeated call is a no-op.
        workdir, self.workdir = self.workdir, None
        if workdir:
            shutil.rmtree(workdir, ignore_errors=True)

    def _download_and_verify(self, file_to_download, prefix, pubname):
        """Download one image item and check its SHA256 checksum.

        *file_to_download* is the item entry from the index; its ``path``
        and ``sha256`` keys are used.  The file is stored in the workdir as
        ``<prefix><pubname>.tar.xz`` and that path is returned.

        Raises Exception when the download fails or the checksum differs.
        """
        path = os.path.join(self.workdir,
                            "%s%s.tar.xz" % (prefix, pubname))
        url = "%s/%s" % (self.server, file_to_download['path'])

        # Download file
        try:
            urllib.request.urlretrieve(url, path, report_download)
        except (socket.timeout, IOError) as e:
            raise Exception("Failed to download \"%s\": %s" % (url, e)) from e

        # Verify SHA256 checksum of the downloaded file.  Hash in chunks so
        # a large image is never held in memory as a whole.
        digest = hashlib.sha256()
        with open(path, 'rb') as fd:
            for chunk in iter(lambda: fd.read(1024 * 1024), b""):
                digest.update(chunk)

        if digest.hexdigest() != file_to_download['sha256']:
            raise Exception("Checksum of file does not validate")

        return path

    def image_download(self, release=None, architecture=None, version=None):
        """Download an Ubuntu image; return ``(meta_path, rootfs_path)``.

        *release* defaults to "trusty" and is compared with both the
        ``release`` and ``version`` fields of each product.  *architecture*
        defaults to the local architecture.  *version* selects a specific
        image version; without it the highest version key (by string
        ordering) is used.

        Raises Exception when the index cannot be downloaded, verified or
        parsed, when no suitable image exists, or when a downloaded file
        fails its checksum.
        """
        if not release:
            release = "trusty"

        if not architecture:
            architecture = local_architecture()

        # Download and verify GPG signature of index file.
        download_path = os.path.join(self.workdir, "download.json")
        url = "%s/%s/streams/v1/com.ubuntu.cloud:%s:download.json" % \
            (self.server, self.stream, self.stream)
        self.grab_and_validate(url, download_path, gpg_suffix=".gpg")

        try:
            # Parse JSON data
            with open(download_path, "rb") as download_file:
                index = json.loads(download_file.read().decode("utf-8"))
        except (IOError, ValueError) as e:
            # ValueError covers both invalid UTF-8 and invalid JSON.
            raise Exception("Unable to parse the image index.") from e

        image = None
        for product in index['products'].values():
            if product['release'] != release and \
                    product['version'] != release:
                continue

            if product['arch'] != architecture:
                continue

            # Only versions that ship both halves of the image qualify.
            candidates = {
                version_number: version_entry
                for version_number, version_entry
                in product['versions'].items()
                if "lxd.tar.xz" in version_entry['items']
                and "root.tar.xz" in version_entry['items']}

            if not candidates:
                raise Exception("The requested image doesn't exist.")

            if version:
                if version not in candidates:
                    raise Exception("The requested image doesn't exist.")

                image = candidates[version]
                break
            else:
                image = candidates[max(candidates)]

        if not image:
            raise Exception("The requested image doesn't exist.")

        print(_("Downloading the image."))

        # Download and verify the lxd.tar.xz file
        meta_path = self._download_and_verify(
            image['items']['lxd.tar.xz'],
            "meta-",
            image['pubname'])

        # Download and verify the root.tar.xz file
        rootfs_path = self._download_and_verify(
            image['items']['root.tar.xz'],
            "",
            image['pubname'])

        return meta_path, rootfs_path


if __name__ == "__main__":
    # Locate the LXD socket: $LXD_DIR wins over the default location.
    if "LXD_DIR" in os.environ:
        lxd_socket = os.path.join(os.environ['LXD_DIR'], "unix.socket")
    else:
        lxd_socket = "/var/lib/lxd/unix.socket"

    if not os.path.exists(lxd_socket):
        print(_("LXD isn't running."))
        sys.exit(1)
    lxd = LXD(lxd_socket)

    def file_fingerprint(*paths):
        """Return the SHA256 hex digest of the concatenated file contents.

        The files are hashed in chunks so that large images are never held
        in memory as a whole.
        """
        digest = hashlib.sha256()
        for path in paths:
            with open(path, "rb") as fd:
                for chunk in iter(lambda: fd.read(1024 * 1024), b""):
                    digest.update(chunk)
        return digest.hexdigest()

    def setup_alias(parser, args, fingerprint):
        """Point every requested alias at *fingerprint*.

        An alias that already exists is removed and created again.
        """
        existing = lxd.aliases_list()

        for alias in args.alias:
            if alias in existing:
                lxd.aliases_remove(alias)
            lxd.aliases_create(alias, fingerprint)
            # Translate the template first, then substitute the value.
            print(_("Setup alias: %s") % alias)

    def upload_image(parser, args, image):
        """Upload *image* unless it is already stored, then set aliases.

        *image* is either a single tarball path or a
        ``(meta_path, rootfs_path)`` pair.
        """
        paths = (image,) if isinstance(image, str) else tuple(image)
        fingerprint = file_fingerprint(*paths)

        if fingerprint in lxd.images_list():
            parser.exit(1, _("This image is already in the store.\n"))

        r = lxd.images_upload(image, os.path.basename(paths[0]), args.public)
        print(_("Image imported as: %s") % r['fingerprint'])

        setup_alias(parser, args, fingerprint)

    def import_busybox(parser, args):
        """Build a busybox image locally and import it."""
        busybox = Busybox()

        upload_image(parser, args,
                     busybox.create_tarball(split=True) if args.split
                     else busybox.create_tarball())

    def import_lxc(parser, args):
        """Download an image from the LXC image server and import it."""
        lxc = LXCImages()

        if not lxc.image_exists(args.distribution, args.release,
                                args.architecture, args.variant):
            parser.exit(1, _("Requested image doesn't exist.\n"))

        upload_image(parser, args,
                     lxc.image_download(args.distribution, args.release,
                                        args.architecture, args.variant))

    def import_ubuntu(parser, args):
        """Download an Ubuntu cloud image and import it."""
        ubuntu = Ubuntu()

        upload_image(parser, args,
                     ubuntu.image_download(args.release,
                                           args.architecture,
                                           args.version))

    parser = FriendlyParser(
        description=_("LXD: image store helper"),
        formatter_class=argparse.RawTextHelpFormatter,
        epilog=_("""Examples:
 To import the latest Ubuntu Cloud image with an alias:
    %s import ubuntu --alias ubuntu

 To import the latest Ubuntu 14.04 LTS 64bit image with some aliases:
    %s import lxc ubuntu trusty amd64 --alias ubuntu --alias ubuntu/trusty

 To import a basic busybox image:
    %s import busybox --alias busybox
""") % (sys.argv[0], sys.argv[0], sys.argv[0]))

    parser_subparsers = parser.add_subparsers(dest="action")
    parser_subparsers.required = True

    # Image import
    parser_import = parser_subparsers.add_parser(
        "import", help=_("Import images"))
    parser_import_subparsers = parser_import.add_subparsers(dest="source")
    parser_import_subparsers.required = True

    # # Busybox
    parser_import_busybox = parser_import_subparsers.add_parser(
        "busybox", help=_("Busybox image"))
    parser_import_busybox.add_argument("--alias", action="append", default=[],
                                       help=_("Aliases for the image"))
    parser_import_busybox.add_argument("--public", action="store_true",
                                       default=False,
                                       help=_("Make the image public"))
    parser_import_busybox.add_argument(
        "--split", action="store_true", default=False,
        help=_("Whether to create a split image"))
    parser_import_busybox.set_defaults(func=import_busybox)

    # # LXC
    parser_import_lxc = parser_import_subparsers.add_parser(
        "lxc", help=_("LXC images"))
    parser_import_lxc.add_argument("distribution", help=_("Distribution"))
    parser_import_lxc.add_argument("release", help=_("Release"))
    parser_import_lxc.add_argument("architecture", help=_("Architecture"))
    parser_import_lxc.add_argument("variant", help=_("Variant"),
                                   default="default", nargs="?")
    parser_import_lxc.add_argument("--alias", action="append", default=[],
                                   help=_("Aliases for the image"))
    parser_import_lxc.add_argument("--public", action="store_true",
                                   default=False,
                                   help=_("Make the image public"))
    parser_import_lxc.set_defaults(func=import_lxc)

    # # Ubuntu
    parser_import_ubuntu = parser_import_subparsers.add_parser(
        "ubuntu", help=_("Ubuntu images"))
    parser_import_ubuntu.add_argument("release", help=_("Release"),
                                      default=None, nargs="?")
    parser_import_ubuntu.add_argument("architecture", help=_("Architecture"),
                                      default=None, nargs="?")
    parser_import_ubuntu.add_argument("version", help=_("Version"),
                                      default=None, nargs="?")
    parser_import_ubuntu.add_argument("--alias", action="append", default=[],
                                      help=_("Aliases for the image"))
    parser_import_ubuntu.add_argument("--public", action="store_true",
                                      default=False,
                                      help=_("Make the image public"))
    parser_import_ubuntu.set_defaults(func=import_ubuntu)

    # Call the function; any failure is reported through the parser, which
    # prints the message and the help text and exits with status 2.
    args = parser.parse_args()
    try:
        args.func(parser, args)
    except Exception as e:
        parser.error(e)
