#!/usr/bin/env python3
# Let's stick to core python3 modules
import argparse
import gettext
import hashlib
import http.client
import io
import json
import os
import shutil
import socket
import subprocess
import sys
import tarfile
import tempfile
import urllib.request
import uuid

# External dependencies:
# - gnupg
# - xz (or pxz)

# "_" is the short alias used throughout this file to mark and look up
# translatable strings; lookups happen in the "lxd" gettext text domain.
_ = gettext.gettext
gettext.textdomain("lxd")


class FriendlyParser(argparse.ArgumentParser):
    """Argument parser that prints the full help text on a usage error."""

    def error(self, message):
        """Write *message* to stderr, print the help and exit with status 2."""
        print('error: %s' % message, file=sys.stderr)
        self.print_help()
        raise SystemExit(2)


def find_on_path(command):
    """Return True if an executable file named *command* is found in $PATH.

    Empty $PATH entries are skipped; an unset $PATH yields False.
    """
    search_path = os.environ.get('PATH')
    if search_path is None:
        return False

    candidates = (os.path.join(directory, command)
                  for directory in search_path.split(os.pathsep)
                  if directory)
    return any(os.path.isfile(candidate) and os.access(candidate, os.X_OK)
               for candidate in candidates)


def report_download(blocks_read, block_size, total_size):
    """Print a one-line download progress percentage.

    The signature matches the report hook of urllib.request.urlretrieve():
    number of blocks transferred so far, size of one block in bytes and
    total size of the file.

    Nothing is printed when the total size is unknown (zero or negative),
    or when the computed percentage exceeds 100 (the last block usually
    overshoots the real file size).
    """
    # Without a usable total there is no percentage to compute; this also
    # avoids a division by zero.
    if total_size <= 0:
        return

    size_read = blocks_read * block_size
    percent = size_read/total_size*100
    if percent > 100:
        return

    # '\r' keeps the cursor on the same line so the next update overwrites
    # this one.
    print(_("Progress: %.0f %%") % percent, end='\r')


def local_architecture():
    """Return the Debian-style architecture name of this machine.

    The APT configuration is asked first (through the optional apt_pkg
    module).  When that is not possible, the kernel machine name reported
    by os.uname() is translated with a small built-in table.

    Raises:
        KeyError: the fallback was needed and the kernel machine name is
            not in the table.
    """
    try:
        import apt_pkg
        apt_pkg.init()
        return apt_pkg.config.find("APT::Architecture").lower()
    except Exception:
        # apt_pkg is missing or unusable: fall back to the kernel's answer.
        # Only ordinary errors are absorbed, so Ctrl-C and sys.exit() still
        # propagate.
        pass

    arch_tables = {'x86_64': "amd64",
                   'i686': "i386",
                   'armv7l': "armhf",
                   'aarch64': "arm64",
                   'ppc': "powerpc",
                   'ppc64le': "ppc64el"}

    return arch_tables[os.uname().machine]


class UnixHTTPConnection(http.client.HTTPConnection):
    """HTTP connection that talks over a unix domain socket."""

    def __init__(self, path):
        """Store the socket *path*; the HTTP host is always 'localhost'."""
        super().__init__('localhost')
        self.path = path

    def connect(self):
        """Connect a unix stream socket to self.path and keep it as self.sock."""
        unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        unix_socket.connect(self.path)
        # Only publish the socket once it is actually connected.
        self.sock = unix_socket


class LXD(object):
    """Small client for the LXD REST API, spoken over a unix socket."""

    def __init__(self, path):
        """Create the connection object for the unix socket at *path*."""
        self.lxd = UnixHTTPConnection(path)

    def rest_call(self, path, data=None, method="GET", headers=None):
        """Send one request and return ``(http_status, decoded_json)``.

        Args:
            path: URL path of the endpoint, e.g. "/1.0/images".
            data: for a GET, a mapping that is URL-encoded and appended to
                *path* as the query string; for any other method, the
                request body.
            method: HTTP method name.
            headers: optional mapping of extra request headers.

        The response body is always decoded as UTF-8 JSON.
        """
        # Imported here so the method does not rely on urllib.parse being
        # exposed as a side effect of the file's other imports.
        from urllib.parse import urlencode

        if headers is None:
            headers = {}

        if method == "GET" and data:
            url = "%s?%s" % (path, urlencode(data))
            # headers must be passed by keyword: the third positional
            # argument of request() is the body.
            self.lxd.request(method, url, headers=headers)
        else:
            self.lxd.request(method, path, data, headers)

        response = self.lxd.getresponse()
        decoded = json.loads(response.read().decode("utf-8"))
        return response.status, decoded

    def aliases_create(self, name, target):
        """Create alias *name* pointing at *target*.

        Raises Exception when the server does not answer with status 200.
        """
        payload = json.dumps({"target": target,
                              "name": name})

        status, _response = self.rest_call("/1.0/images/aliases", payload,
                                           "POST")

        if status != 200:
            raise Exception("Failed to create alias: %s" % name)

    def aliases_remove(self, name):
        """Delete alias *name*.

        Raises Exception when the server does not answer with status 200.
        """
        status, _response = self.rest_call("/1.0/images/aliases/%s" % name,
                                           method="DELETE")

        if status != 200:
            raise Exception("Failed to remove alias: %s" % name)

    def aliases_list(self):
        """Return the alias names, stripped of their URL prefix."""
        _status, data = self.rest_call("/1.0/images/aliases")

        return [alias.split("/1.0/images/aliases/")[-1]
                for alias in data['metadata']]

    def images_list(self):
        """Return the image identifiers, stripped of their URL prefix."""
        _status, data = self.rest_call("/1.0/images")

        return [image.split("/1.0/images/")[-1] for image in data['metadata']]

    def images_upload(self, path, filename, public):
        """Upload an image and return the 'metadata' field of the response.

        Args:
            path: either a str (a single image file, sent as a raw
                octet-stream) or a ``(meta_path, rootfs_path)`` pair (sent
                as a two-part multipart/form-data body).
            filename: accepted for compatibility with existing callers; it
                is not used to build the request.
            public: when true, the "X-LXD-public: 1" header is added.

        Raises Exception when the server does not answer with status 200.
        """
        headers = {}
        if public:
            headers['X-LXD-public'] = "1"

        if isinstance(path, str):
            headers['Content-Type'] = "application/octet-stream"

            # Keep the file open only for the duration of the request.
            with open(path, "rb") as fd:
                status, data = self.rest_call("/1.0/images", fd,
                                              "POST", headers)
        else:
            meta_path, rootfs_path = path
            boundary = str(uuid.uuid1())

            form = []
            for part_name, part_path in (("metadata", meta_path),
                                         ("rootfs", rootfs_path)):
                form.append(("--%s" % boundary).encode())
                form.append(("Content-Disposition: form-data; "
                             "name=%s; filename=%s" %
                             (part_name,
                              os.path.basename(part_path))).encode())
                form.append(b"Content-Type: application/octet-stream")
                form.append(b"")
                with open(part_path, "rb") as fd:
                    form.append(fd.read())

            form.append(("--%s--" % boundary).encode())
            form.append(b"")

            # Every entry, including the last one, is terminated by CRLF.
            body = b"\r\n".join(form) + b"\r\n"

            headers['Content-Type'] = "multipart/form-data; boundary=%s" \
                % boundary

            status, data = self.rest_call("/1.0/images", body, "POST", headers)

        if status != 200:
            raise Exception("Failed to upload the image: %s" % status)

        return data['metadata']


class Image(object):
    """Download and GPG helpers shared by image sources.

    The methods read ``self.server``, ``self.gpgdir`` and ``self.gpgkey``.
    This class does not set them; whoever uses it must do so first.
    """

    def _gpg_call(self, arguments):
        """Run ``gpg`` with *arguments* and return its exit status.

        GNUPGHOME is pointed at ``self.gpgdir`` and gpg's stdout and stderr
        are discarded.
        """
        gpg_environ = dict(os.environ)
        gpg_environ["GNUPGHOME"] = self.gpgdir

        return subprocess.call(["gpg"] + list(arguments),
                               env=gpg_environ,
                               stdout=subprocess.DEVNULL,
                               stderr=subprocess.DEVNULL)

    def gpg_update(self):
        """Fetch ``self.gpgkey`` from the keyserver into ``self.gpgdir``.

        Raises Exception when gpg exits with a non-zero status.
        """
        # Translate the template first, then fill it in; formatting before
        # the lookup would never match a catalogue entry.
        print(_("Downloading the GPG key for %s") % self.server)

        # NOTE(review): the keyserver is hard-coded; confirm this pool is
        # still reachable.
        status = self._gpg_call(
            ["--keyserver", "hkp://p80.pool.sks-keyservers.net:80",
             "--recv-keys", self.gpgkey])

        if status:
            raise Exception("Failed to retrieve the GPG key")

    def gpg_verify(self, path):
        """Check the GPG signature file at *path*.

        Raises Exception when gpg exits with a non-zero status.
        """
        print(_("Validating the GPG signature of %s") % path)

        status = self._gpg_call(["--verify", path])

        if status:
            raise Exception("GPG signature verification failed for: %s" % path)

    def grab_and_validate(self, url, dest, gpg_suffix=".asc"):
        """Download *url* to *dest* together with its signature and verify it.

        The signature is fetched from ``url + gpg_suffix`` and always
        stored as ``dest + ".asc"``.

        Raises Exception when either download fails or when the signature
        does not verify.
        """
        try:
            # Main file
            urllib.request.urlretrieve(url, dest, report_download)

            # Signature
            urllib.request.urlretrieve(url + gpg_suffix,
                                       dest + ".asc", report_download)
        except (socket.timeout, IOError) as e:
            # A tuple is required here: "A or B" would evaluate to A alone
            # and leave every other I/O error unwrapped.
            raise Exception("Failed to download \"%s\": %s" % (url, e)) from e

        # Verify the signature
        self.gpg_verify(dest + ".asc")


class Busybox(object):
    """Builds an LXD image tarball around the host's /bin/busybox."""

    # Temporary working directory; None until created and after cleanup.
    workdir = None

    def __init__(self):
        # Create our workdir
        self.workdir = tempfile.mkdtemp()

    def __del__(self):
        # Clear the attribute first so that a second call is a no-op, and
        # never raise from a destructor.
        workdir, self.workdir = self.workdir, None
        if workdir:
            shutil.rmtree(workdir, ignore_errors=True)

    def create_tarball(self, split=False):
        """Create the xz-compressed image tarball(s) inside the workdir.

        The image holds the busybox binary, one symlink per path printed by
        ``busybox --list-full``, a few empty directories and a
        ``metadata.yaml`` file.

        Args:
            split: when false, everything goes into one tarball with the
                filesystem under ``rootfs/``.  When true, the metadata
                tarball only holds ``metadata.yaml`` and the filesystem
                goes, without prefix, into a second tarball.

        Returns:
            The path of the compressed tarball, or a
            ``(metadata_tarball, rootfs_tarball)`` pair when *split*.

        Raises Exception when the compressor exits with a non-zero status.
        """
        xz = "pxz" if find_on_path("pxz") else "xz"

        destination_tar = os.path.join(self.workdir, "busybox.tar")
        destination_tar_rootfs = os.path.join(self.workdir,
                                              "busybox.rootfs.tar")

        target_tarball = tarfile.open(destination_tar, "w:")
        target_tarball_rootfs = None
        try:
            # Decide once where filesystem entries go and under what prefix.
            if split:
                target_tarball_rootfs = tarfile.open(destination_tar_rootfs,
                                                     "w:")
                rootfs_tarball, prefix = target_tarball_rootfs, ""
            else:
                rootfs_tarball, prefix = target_tarball, "rootfs/"

            arch = os.uname()[4]
            busybox_stat = os.stat("/bin/busybox")

            metadata = {'architecture': arch,
                        'creation_date': int(busybox_stat.st_ctime),
                        'properties': {
                            'os': "Busybox",
                            'architecture': arch,
                            'description': "Busybox %s" % arch,
                            'name': "busybox-%s" % arch
                            },
                        }

            # Add busybox
            busybox_file = tarfile.TarInfo()
            busybox_file.name = prefix + "bin/busybox"
            busybox_file.size = busybox_stat.st_size
            busybox_file.mode = 0o755
            with open("/bin/busybox", "rb") as fd:
                rootfs_tarball.addfile(busybox_file, fd)

            # Add symlinks
            busybox = subprocess.Popen(["/bin/busybox", "--list-full"],
                                       stdout=subprocess.PIPE,
                                       universal_newlines=True)
            # communicate() drains the pipe while waiting; calling wait()
            # before reading can block forever once the pipe buffer fills.
            listing = busybox.communicate()[0]

            for line in listing.split("\n"):
                applet_path = line.strip()
                if not applet_path:
                    continue

                symlink_file = tarfile.TarInfo()
                symlink_file.type = tarfile.SYMTYPE
                symlink_file.linkname = "/bin/busybox"
                symlink_file.name = prefix + applet_path
                rootfs_tarball.addfile(symlink_file)

            # Add directories
            for directory in ("dev", "mnt", "proc", "root", "sys", "tmp"):
                directory_file = tarfile.TarInfo()
                directory_file.type = tarfile.DIRTYPE
                directory_file.name = prefix + directory
                rootfs_tarball.addfile(directory_file)

            # Add the metadata file (always in the main tarball)
            metadata_yaml = json.dumps(metadata, sort_keys=True,
                                       indent=4, separators=(',', ': '),
                                       ensure_ascii=False
                                       ).encode('utf-8') + b"\n"

            metadata_file = tarfile.TarInfo()
            metadata_file.size = len(metadata_yaml)
            metadata_file.name = "metadata.yaml"
            target_tarball.addfile(metadata_file,
                                   io.BytesIO(metadata_yaml))
        finally:
            # Close the archives even when something above failed.
            target_tarball.close()
            if target_tarball_rootfs is not None:
                target_tarball_rootfs.close()

        # Compress the tarball
        r = subprocess.call([xz, "-9", destination_tar])
        if r:
            raise Exception("Failed to compress: %s" % destination_tar)

        if not split:
            return destination_tar + ".xz"

        r = subprocess.call([xz, "-9", destination_tar_rootfs])
        if r:
            raise Exception("Failed to compress: %s" %
                            destination_tar_rootfs)
        return destination_tar + ".xz", destination_tar_rootfs + ".xz"


class Ubuntu(Image):
    """Downloads Ubuntu images listed in a simplestreams-style JSON index."""

    # Temporary working directory; None until created and after cleanup.
    workdir = None
    # Base URL of the image server.
    server = None
    # Name of the stream to use on that server.
    stream = None

    # FIXME: We should default to the released stream once available
    def __init__(self, server="https://cloud-images.ubuntu.com",
                 stream="releases",
                 gpgkey="7FF3F408476CF100"):
        """Create the workdir and GPG home, then fetch the signing key.

        Raises Exception (from gpg_update) when the key cannot be fetched.
        """
        # Create our workdir
        self.workdir = tempfile.mkdtemp()
        self.gpgdir = "%s/gpg" % self.workdir
        os.mkdir(self.gpgdir, 0o700)

        # Set variables
        self.server = server
        self.stream = stream
        self.gpgkey = gpgkey

        # Get ready to work with this server
        self.gpg_update()

    def __del__(self):
        # Clear the attribute first so that a second call is a no-op, and
        # never raise from a destructor.
        workdir, self.workdir = self.workdir, None
        if workdir:
            shutil.rmtree(workdir, ignore_errors=True)

    def image_download(self, release=None, architecture=None, version=None):
        """Download one image and return ``(meta_path, rootfs_path)``.

        Args:
            release: value matched against each product's "release" or
                "version" field; defaults to "trusty".
            architecture: value matched against each product's "arch"
                field; defaults to local_architecture().
            version: key of the wanted entry in the product's "versions";
                when omitted the highest key (string order) is used.

        Raises Exception when the index cannot be downloaded, verified or
        parsed, when no image matches, or when a checksum does not match.
        """
        if not release:
            release = "trusty"

        if not architecture:
            architecture = local_architecture()

        # Download and verify GPG signature of index file.
        download_path = os.path.join(self.workdir, "download.json")
        url = "%s/%s/streams/v1/com.ubuntu.cloud:%s:download.json" % \
            (self.server,
             self.stream,
             "released" if self.stream == "releases" else self.stream)
        self.grab_and_validate(url, download_path, gpg_suffix=".gpg")

        try:
            # Parse JSON data
            with open(download_path, "rb") as download_file:
                index = json.loads(download_file.read().decode("utf-8"))
        except Exception:
            # Ordinary errors only: Ctrl-C and sys.exit() must not be
            # reported as a parse failure.
            raise Exception("Unable to parse the image index.")

        image = None
        for product in index.get("products", {}).values():
            if release not in (product.get("release"),
                               product.get("version")):
                continue

            if product.get("arch") != architecture:
                continue

            # Only versions that ship both halves of a split image qualify.
            candidates = {}
            for number, entry in product.get("versions", {}).items():
                items = entry.get("items", {})
                if "lxd.tar.xz" in items and "root.tar.xz" in items:
                    candidates[number] = entry

            if not candidates:
                raise Exception("The requested image doesn't exist.")

            if version:
                if version not in candidates:
                    raise Exception("The requested image doesn't exist.")

                image = candidates[version]
                break

            # No explicit version: take the highest one and keep scanning,
            # so the last matching product wins.
            image = candidates[max(candidates)]

        if not image:
            raise Exception("The requested image doesn't exist.")

        print(_("Downloading the image."))
        try:
            manifest_url = "%s/%s" % (self.server,
                                      image['items']['manifest']['path'])
        except KeyError:
            print(_("No image manifest provided."))
        else:
            print(_("Image manifest: %s") % manifest_url)

        def download_and_verify(file_to_download, prefix, pubname):
            """
                This function downloads and verifies files in the image.
            """
            path = os.path.join(self.workdir,
                                "%s%s.tar.xz" % (prefix, pubname))
            # Download file
            urllib.request.urlretrieve(
                "%s/%s" % (self.server, file_to_download['path']),
                path,
                report_download)

            # Verify SHA256 checksum of the downloaded file, reading it in
            # chunks so that a large image is never held in memory at once.
            digest = hashlib.sha256()
            with open(path, 'rb') as fd:
                for chunk in iter(lambda: fd.read(1024 * 1024), b""):
                    digest.update(chunk)

            if digest.hexdigest() != file_to_download['sha256']:
                raise Exception("Checksum of file does not validate")

            return path

        # Download and verify the lxd.tar.xz file
        meta_path = download_and_verify(
            image['items']['lxd.tar.xz'],
            "meta-",
            image['pubname'])

        # Download and verify the root.tar.xz file
        rootfs_path = download_and_verify(
            image['items']['root.tar.xz'],
            "",
            image['pubname'])

        return meta_path, rootfs_path


if __name__ == "__main__":
    # The daemon's socket lives under $LXD_DIR when that variable is set.
    if "LXD_DIR" in os.environ:
        lxd_socket = os.path.join(os.environ['LXD_DIR'], "unix.socket")
    else:
        lxd_socket = "/var/lib/lxd/unix.socket"

    if not os.path.exists(lxd_socket):
        print(_("LXD isn't running."))
        sys.exit(1)
    lxd = LXD(lxd_socket)

    def files_fingerprint(*paths):
        """Return the SHA256 hex digest of the files' concatenated content.

        The files are read in chunks, so large images are never loaded
        into memory in one piece.
        """
        digest = hashlib.sha256()
        for path in paths:
            with open(path, "rb") as fd:
                for chunk in iter(lambda: fd.read(1024 * 1024), b""):
                    digest.update(chunk)
        return digest.hexdigest()

    def setup_alias(parser, args, fingerprint):
        """Point every requested alias at *fingerprint*, replacing old ones."""
        existing = lxd.aliases_list()

        for alias in args.alias:
            if alias in existing:
                lxd.aliases_remove(alias)
            lxd.aliases_create(alias, fingerprint)
            # Translate the template first, then fill it in.
            print(_("Setup alias: %s") % alias)

    def store_image(parser, args, image, fingerprint):
        """Upload *image* unless already stored, then set up its aliases.

        *image* is a single path or a (meta_path, rootfs_path) pair.
        """
        if fingerprint in lxd.images_list():
            parser.exit(1, _("This image is already in the store.\n"))

        first_path = image if isinstance(image, str) else image[0]
        r = lxd.images_upload(image, os.path.basename(first_path),
                              args.public)
        print(_("Image imported as: %s") % r['fingerprint'])

        setup_alias(parser, args, fingerprint)

    def import_busybox(parser, args):
        """Build a busybox image and import it."""
        # Keep this object referenced until the upload is done: its
        # destructor removes the directory holding the tarballs.
        busybox = Busybox()

        if args.split:
            meta_path, rootfs_path = busybox.create_tarball(split=True)
            store_image(parser, args, (meta_path, rootfs_path),
                        files_fingerprint(meta_path, rootfs_path))
        else:
            path = busybox.create_tarball()
            store_image(parser, args, path, files_fingerprint(path))

    def import_ubuntu(parser, args):
        """Download an Ubuntu image and import it."""
        # Keep this object referenced until the upload is done: its
        # destructor removes the directory holding the downloads.
        ubuntu = Ubuntu(stream=args.stream)

        meta_path, rootfs_path = ubuntu.image_download(args.release,
                                                       args.architecture,
                                                       args.version)

        store_image(parser, args, (meta_path, rootfs_path),
                    files_fingerprint(meta_path, rootfs_path))

    parser = FriendlyParser(
        description=_("LXD: image store helper"),
        formatter_class=argparse.RawTextHelpFormatter,
        epilog=_("""Examples:
 To import the latest Ubuntu Cloud image with an alias:
    %s import ubuntu --alias ubuntu

 To import a basic busybox image:
    %s import busybox --alias busybox
""") % (sys.argv[0], sys.argv[0]))

    parser_subparsers = parser.add_subparsers(dest="action")
    parser_subparsers.required = True

    # Image import
    parser_import = parser_subparsers.add_parser(
        "import", help=_("Import images"))
    parser_import_subparsers = parser_import.add_subparsers(dest="source")
    parser_import_subparsers.required = True

    # # Busybox
    parser_import_busybox = parser_import_subparsers.add_parser(
        "busybox", help=_("Busybox image"))
    parser_import_busybox.add_argument("--alias", action="append", default=[],
                                       help=_("Aliases for the image"))
    parser_import_busybox.add_argument("--public", action="store_true",
                                       default=False,
                                       help=_("Make the image public"))
    parser_import_busybox.add_argument(
        "--split", action="store_true", default=False,
        help=_("Whether to create a split image"))
    parser_import_busybox.set_defaults(func=import_busybox)

    # # Ubuntu
    parser_import_ubuntu = parser_import_subparsers.add_parser(
        "ubuntu", help=_("Ubuntu images"))
    parser_import_ubuntu.add_argument("release", help=_("Release"),
                                      default=None, nargs="?")
    parser_import_ubuntu.add_argument("architecture", help=_("Architecture"),
                                      default=None, nargs="?")
    parser_import_ubuntu.add_argument("version", help=_("Version"),
                                      default=None, nargs="?")
    parser_import_ubuntu.add_argument("--stream", default="releases",
                                      choices=("releases", "daily"),
                                      help=_("The simplestream stream to use"))
    parser_import_ubuntu.add_argument("--alias", action="append", default=[],
                                      help=_("Aliases for the image"))
    parser_import_ubuntu.add_argument("--public", action="store_true",
                                      default=False,
                                      help=_("Make the image public"))
    parser_import_ubuntu.set_defaults(func=import_ubuntu)

    # Call the function
    args = parser.parse_args()
    try:
        args.func(parser, args)
    except Exception as e:
        # Any failure is reported as a usage-style error (exit status 2).
        parser.error(str(e))
