#!/usr/bin/env python
# Copyright 2012 Canonical Ltd.  This software is licensed under the
# GNU Lesser General Public License version 2.1 (see the file COPYING).

"""Display the ip address of a container."""

import argparse
from contextlib import closing, contextmanager
import ctypes
import fcntl
from functools import wraps
import os
import socket
import struct
import subprocess
import sys


# Namespace type handed to setns(2): here it asks for the calling thread
# to be reassociated with a network namespace.
CLONE_NEWNET = 0x40000000

# Human-readable messages, keyed by the error codes passed to _error().
ERRORS = dict(
    not_connected='unable to find the container ip address',
    not_found='the container does not exist or is not running',
    not_installed='lxc does not seem to be installed',
    not_root='you must be root',
)

# ioctl(2) request number used to read a network interface's address.
SIOCGIFADDR = 0x8915


def _parse_args():
    """Parse the command line arguments."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '-n', '--name', required=True,
        help='The name of the container. ')
    parser.add_argument(
        '-i', '--interface',
        help='Display the ip address of the specified network interface.')
    namespace = parser.parse_args()
    return namespace.name, namespace.interface


def _error(code):
    """Return an OSError containing given `msg`."""
    return OSError(
        '{}: error: {}'.format(os.path.basename(sys.argv[0]), ERRORS[code]))


def _output(interfaces, ip_addresses, short):
    """Format the output displaying the ip addresses of the container."""
    if short:
        return ip_addresses[0]
    interface_ip_map = zip(interfaces, ip_addresses)
    return '\n'.join('{}: {}'.format(*i) for i in interface_ip_map)


def _load_library(name, loader=None):
    """Load a shared library into the process and return it.

    Search the library `name` inside `/usr/lib` and, if *$DEB_HOST_MULTIARCH*
    is retrievable, inside `/usr/lib/$DEB_HOST_MULTIARCH/`.

    The optional argument `loader` is the callable used to load the library:
    if None, `ctypes.cdll.LoadLibrary` is used.

    Raise OSError if the library is not found.
    """
    if loader is None:
        loader = ctypes.cdll.LoadLibrary
    try:
        return loader(os.path.join('/usr/lib/', name))
    except OSError:
        # Search the library in `/usr/lib/$DEB_HOST_MULTIARCH/`:
        # see https://wiki.ubuntu.com/MultiarchSpec.
        process = subprocess.Popen(
            ['dpkg-architecture', '-qDEB_HOST_MULTIARCH'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if not process.returncode:
            output, _ = process.communicate()
            return loader(os.path.join('/usr/lib/', output.strip(), name))
        raise


def _wrap(function, error_code):
    """Add error handling to the given C `function`.

    If the function returns an error, the wrapped function raises an
    OSError using a message corresponding to the given `error_code`.
    """
    def errcheck(result, func, arguments):
        if result < 0:
            raise _error(error_code)
        return result
    function.errcheck = errcheck
    return function


@contextmanager
def redirect_stderr(path):
    """Redirect system stderr to `path`.

    This context manager does not use normal sys.std* Python redirection
    because we also want to intercept and redirect stderr written by
    underlying C functions called using ctypes: the file descriptor behind
    ``sys.stderr`` is itself pointed at `path` with dup2(2).

    `path` is opened write-only without O_CREAT, so it must already exist.
    On exit the original descriptor is restored, and every temporary
    descriptor is closed, also when the body raises.
    """
    fd = sys.stderr.fileno()
    # Duplicate of the current stderr, used to put it back afterwards.
    backup = os.dup(fd)
    try:
        new_fd = os.open(path, os.O_WRONLY)
        try:
            # Flush Python-level buffers so pending output reaches the
            # original destination rather than `path`.
            sys.stderr.flush()
            os.dup2(new_fd, fd)
        finally:
            # `fd` now refers to `path` by itself; drop the extra handle.
            os.close(new_fd)
        try:
            yield
        finally:
            sys.stderr.flush()
            os.dup2(backup, fd)
    finally:
        # Release the duplicate; otherwise every use leaks a descriptor.
        os.close(backup)


def root_required(func):
    """Decorator that lets `func` run only for the root user.

    The effective user id is checked on every call; when it is not 0 an
    OSError (error code 'not_root') is raised instead of calling `func`.
    """
    @wraps(func)
    def _guarded(*args, **kwargs):
        is_root = os.geteuid() == 0
        if not is_root:
            raise _error('not_root')
        return func(*args, **kwargs)
    return _guarded


class SetNamespace(object):
    """A context manager to switch the network namespace for this thread.

    A namespace is one of the entries in /proc/[pid]/ns/. Entering the
    context moves the calling thread into the network namespace of `pid`
    via setns(2). Leaving it moves the thread back into the namespace it
    was in before entering.

    Raise OSError (built by `_error`) if a namespace file cannot be opened
    ('not_found') or if setns(2) reports a failure ('not_connected').
    """
    def __init__(self, pid, nstype=CLONE_NEWNET):
        libc = ctypes.cdll.LoadLibrary('libc.so.6')
        self._pid = pid
        self._nstype = nstype
        self._setns = _wrap(libc.setns, 'not_connected')
        # Descriptor of the namespace active before __enter__; __exit__
        # uses it to switch back.
        self._original_fd = None

    def __enter__(self):
        """Remember the current namespace, then switch to the target one."""
        # Returning to the namespace we actually came from (instead of
        # assuming pid 1's) stays correct when this process does not run
        # in the host's initial network namespace.
        # NOTE(review): /proc/self is the thread group leader; that is fine
        # for this single-threaded script, but threaded callers may need
        # /proc/thread-self instead.
        original_fd = self._open('self')
        try:
            self.set(self._pid)
        except BaseException:
            os.close(original_fd)
            raise
        self._original_fd = original_fd
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the namespace that was active before __enter__."""
        fd, self._original_fd = self._original_fd, None
        try:
            self._setns(fd, self._nstype)
        finally:
            os.close(fd)

    def set(self, pid):
        """Move the calling thread into the network namespace of `pid`."""
        fd = self._open(pid)
        try:
            self._setns(fd, self._nstype)
        finally:
            # Close the namespace descriptor even when setns(2) fails.
            os.close(fd)

    @staticmethod
    def _open(pid):
        """Open /proc/`pid`/ns/net read-only and return the descriptor."""
        # NOTE(review): the entry is always ns/net even though `nstype` is
        # configurable; other namespace types would need another entry.
        try:
            return os.open('/proc/{}/ns/net'.format(pid), os.O_RDONLY)
        except OSError:
            raise _error('not_found')


@root_required
def get_pid(name):
    """Look up the init pid of the LXC container called `name`.

    Raise OSError if liblxc cannot be loaded ('not_installed') or if
    liblxc's get_init_pid reports a negative result ('not_found').
    """
    try:
        lxc = _load_library('lxc/liblxc.so.0')
    except OSError:
        raise _error('not_installed')
    init_pid = _wrap(lxc.get_init_pid, 'not_found')
    # liblxc writes its own complaint to stderr when the container is
    # missing; silence it at the file-descriptor level during the call.
    with redirect_stderr('/dev/null'):
        pid = init_pid(name)
    return pid


@root_required
def get_interfaces(pid, exclude=()):
    """List the network interfaces visible inside the container `pid`.

    Interfaces are the directory entries of the container's
    /sys/class/net; names listed in `exclude` are skipped. Raise OSError
    if the directory cannot be listed ('not_found') or if nothing is left
    after filtering ('not_connected').
    """
    net_dir = '/proc/{}/root/sys/class/net/'.format(pid)
    try:
        entries = os.listdir(net_dir)
    except OSError:
        raise _error('not_found')
    found = []
    for entry in entries:
        if entry in exclude:
            continue
        if os.path.isdir(os.path.join(net_dir, entry)):
            found.append(entry)
    if not found:
        raise _error('not_connected')
    return found


@root_required
def get_ip_addresses(pid, interfaces):
    """Return the IPv4 address of each name in `interfaces` for container `pid`.

    The addresses are returned in the same order as `interfaces`. Raise
    OSError if the container's namespace cannot be entered, or
    'not_connected' if any address cannot be retrieved.

    Note that `socket.gethostbyname` is not usable in this context: it uses
    the system's dns resolver that by default does not resolve lxc names.
    """
    addresses = []
    with SetNamespace(pid):
        # The socket is opened only after switching namespace, so the
        # ioctl below queries the container's interfaces.
        # Original from http://code.activestate.com/recipes/
        #     439094-get-the-ip-address-associated-with-a-network-inter/
        with closing(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) as sock:
            sock_fd = sock.fileno()
            for iface in interfaces:
                # Linux holds interface names in a 16-byte buffer including
                # the terminating NUL (see /usr/include/linux/if.h), so at
                # most 15 characters are sent.
                request = struct.pack('256s', iface[:15])
                try:
                    reply = fcntl.ioctl(sock_fd, SIOCGIFADDR, request)
                    # Bytes 20-24 of the reply are the packed IPv4 address;
                    # render them in dotted-quad form.
                    address = socket.inet_ntoa(reply[20:24])
                except (IOError, socket.error):
                    raise _error('not_connected')
                addresses.append(address)
    return addresses


def main():
    """Print the ip address(es) of the container named on the command line.

    Return None on success, or the OSError/KeyboardInterrupt that stopped
    the lookup, so the caller can pass it to sys.exit().
    """
    name, interface = _parse_args()
    short = interface is not None
    try:
        pid = get_pid(name)
        interfaces = [interface] if short else get_interfaces(pid, exclude=['lo'])
        ip_addresses = get_ip_addresses(pid, interfaces)
    except (KeyboardInterrupt, OSError) as err:
        return err
    print(_output(interfaces, ip_addresses, short))


if __name__ == '__main__':
    # main() returns None on success or an exception object to report;
    # sys.exit() turns either into the process exit status.
    exit_status = main()
    sys.exit(exit_status)
