#!/usr/bin/python3

# Author: Daniel Kahn Gillmor <dkg@fifthhorseman.net>
# Copyright: February 2019
# License: GPLv3+ (see https://www.gnu.org/licenses/gpl.html)

import os
import sys
import logging
import ipaddress
import socket
import select

# Name the root logger after the invoked program (sys.argv[0]); the
# module-level logging.debug()/logging.warning() calls used throughout this
# file all go through the root logger.
logging.getLogger().name = sys.argv[0]

# Global list of socket objects whose file descriptor already is the one the
# child is expected to inherit.  addlistener() appends to it so the objects
# stay referenced (and are therefore not closed by Python's garbage
# collection) until the exec().
keepers = []

def connectinet(options, address, socktype):
    """Create an IPv4/IPv6 socket of type *socktype* bound to *address*.

    *address* is either ``ADDR/PORT`` (a literal IPv4 or IPv6 address, a
    slash and a port number) or a bare ``PORT``, in which case the socket is
    an IPv4 one bound to the empty (wildcard) host.

    *options* is not inspected here; it is handed back untouched.

    Returns a ``(socket, options)`` tuple.  Exits with status 100 if no
    address family can be derived from a literal address.
    """
    host, slash, rest = address.partition('/')
    if slash:
        port = int(rest)
        host = ipaddress.ip_address(host)
        # map the IP version of the literal onto the matching socket family
        family = {4: socket.AF_INET, 6: socket.AF_INET6}.get(host.version)
        if family is None:
            logging.warning(f'could not determine address family for {address}')
            sys.exit(100)
    else:
        # bare port number: IPv4, no specific host
        port = int(address)
        family = socket.AF_INET
        host = ''

    bound = socket.socket(family=family, type=socktype)
    bound.bind((str(host), port))
    return (bound, options)

def connectunix(options, address, socktype):
    """Create a unix-domain socket of type *socktype* bound to path *address*.

    Recognised *options* (all optional):

    ``mode``
        octal permission bits for the socket's filesystem entry
    ``user`` / ``group``
        numeric uid / gid that should own the socket's filesystem entry;
        whichever of the two is omitted defaults to the uid / gid of the
        current process

    These three keys are removed from *options* (the dict is modified in
    place) so the caller only sees the options it still has to handle.

    Returns a ``(socket, options)`` tuple.  Raises ValueError if ``mode``,
    ``user`` or ``group`` is not a valid number; OSError if binding or
    changing permissions/ownership fails.
    """
    change_owner = 'user' in options or 'group' in options
    mode = options.pop('mode', None)
    user = options.pop('user', None)
    group = options.pop('group', None)

    # Convert everything up front so a malformed value is reported before a
    # socket file has been created on disk.
    if mode is not None:
        mode = int(mode, base=8)
    uid = os.getuid() if user is None else int(user)
    gid = os.getgid() if group is None else int(group)

    sock = socket.socket(family=socket.AF_UNIX, type=socktype)
    if mode is None:
        sock.bind(address)
    else:
        # The umask lists the bits to *clear*, so it has to be the
        # complement of the wanted mode; that way the socket never exists
        # with permissions wider than requested.
        oldumask = os.umask(0o777 & ~mode)
        try:
            sock.bind(address)
        finally:
            os.umask(oldumask)
        # Operate on the path, not on sock.fileno(): on Linux fchmod() on a
        # socket descriptor does not touch the filesystem entry.
        os.chmod(address, mode)

    if change_owner:
        # Same reasoning as for chmod above: change the filesystem entry.
        os.chown(address, uid, gid)

    return (sock, options)

def addlistener(fd, family, options, address):
    """Open a socket for *family* / *address* and make it available as *fd*.

    *family* selects both the function that creates and binds the socket and
    the socket type: ``udp``, ``tcp``, ``unix`` or ``unix-dgram``.  Stream
    sockets are put into listening state, honouring an optional ``backlog``
    option (default: socket.SOMAXCONN).

    After family-specific handling the only option allowed to remain in
    *options* is ``label``.

    Returns the value of the ``label`` option, or ``''`` when there is none.
    Exits with status 100 on an unknown family or on unknown options.
    """
    logging.debug(f'trying to listen for FD {fd} on {family}, address: {address}, options: {options}')

    # family name -> (function that creates and binds the socket, socket type)
    handlers = {
        'udp': (connectinet, socket.SOCK_DGRAM),
        'tcp': (connectinet, socket.SOCK_STREAM),
        'unix': (connectunix, socket.SOCK_STREAM),
        'unix-dgram': (connectunix, socket.SOCK_DGRAM),
    }
    entry = handlers.get(family)
    if entry is None:
        logging.warning(f'unknown family: {family}')
        sys.exit(100)

    connect, socktype = entry
    listener, options = connect(options, address, socktype)
    if socktype == socket.SOCK_STREAM:
        listener.listen(int(options.pop('backlog', socket.SOMAXCONN)))

    # anything besides 'label' that is still present was not understood
    leftover = set(options).difference({'label'})
    if leftover:
        oname = 'options' if len(leftover) > 1 else 'option'
        logging.warning(f'unknown {oname} for family {family}: {leftover}')
        sys.exit(100)

    current = listener.fileno()
    if current != fd:
        logging.debug(f'moving new FD {current} to {fd}')
        os.dup2(current, fd)
    else:
        # The socket already sits on the wanted descriptor: keep the object
        # referenced so that garbage collection does not close it, and make
        # sure the descriptor is inherited across the exec().
        keepers.append(listener)
        os.set_inheritable(current, True)

    return options.get('label', '')

def parse_args(args):
    """Consume the leading options in *args*, opening a listener per spec.

    Recognised arguments:

    ``-v`` / ``--verbose``
        set the root logger's level to DEBUG
    ``--FAMILY:OPTIONS:ADDRESS``
        open a listener through addlistener(); OPTIONS is a (possibly
        empty) comma-separated list of ``key=value`` pairs.  Listeners are
        given consecutive file descriptor numbers starting at 3.
    ``--``
        stop parsing; the rest of *args* is left alone

    Parsing also stops at the first argument that does not start with
    ``-``; that argument stays at the front of *args*.  Any other argument
    starting with a single ``-`` is ignored with a warning.

    *args* is consumed in place.  Returns ``(listeners, args)``: the values
    returned by addlistener() in order, and whatever is left of *args*.

    Exits with status 100 on a malformed listener specification, an option
    without a value, or a repeated option.
    """
    listeners = []
    nextfd = 3
    while args:
        arg = args.pop(0)
        logging.debug(f'parsing arg {arg}')
        if arg in ('-v', '--verbose'):
            logging.getLogger().setLevel(logging.DEBUG)
        elif arg == '--':
            break
        elif not arg.startswith('-'):
            # startswith() instead of arg[0] so that an empty-string
            # argument is treated as a non-option rather than crashing.
            args.insert(0, arg)
            break
        elif arg.startswith('--'):
            try:
                (family, opts, address) = arg[2:].split(':', maxsplit=2)
            except ValueError:
                logging.warning(f'malformed listener specification (expected --FAMILY:OPTIONS:ADDRESS): {arg}')
                sys.exit(100)
            options = {}
            for o in opts.split(','):
                if o == '':
                    # tolerate empty entries, e.g. "--tcp::80" or "a=1,,b=2"
                    continue
                (k, sep, v) = o.partition('=')
                if not sep:
                    logging.warning(f'option has no value: {o}')
                    sys.exit(100)
                if k in options:
                    logging.warning(f'repeated option: {k}')
                    sys.exit(100)
                options[k] = v
            listeners.append(addlistener(nextfd, family, options, address))
            nextfd += 1
        else:
            # unknown single-dash option: still ignored, but no longer
            # silently
            logging.warning(f'ignoring unknown option: {arg}')
    return (listeners, args)

def main():
    """Open the requested listeners, wait for activity, then exec the command.

    The command line (sys.argv[1:]) is handed to parse_args(); what it
    leaves over is the command to run.  LISTEN_PID, LISTEN_FDS and
    LISTEN_FDNAMES are exported in the environment (the variables read by
    sd_listen_fds(3)) before the process blocks in select() on the listener
    descriptors, which are assumed to be numbered 3, 4, ... in order.  As
    soon as any of them shows activity the command is exec()ed in place of
    this process, so LISTEN_PID stays valid.

    Exits with status 100 if no command was given.
    """
    (listeners, args) = parse_args(sys.argv[1:])
    if not args:
        # without this check os.execvp(args[0], ...) fails with IndexError
        logging.warning('no command to execute')
        sys.exit(100)

    os.environ['LISTEN_PID'] = str(os.getpid())
    os.environ['LISTEN_FDS'] = str(len(listeners))
    os.environ['LISTEN_FDNAMES'] = ':'.join(listeners)

    fds = list(range(3, 3+len(listeners)))
    if fds:
        # wait until something happens to do the exec:
        dname = 'descriptor'
        if len(listeners) > 1:
            dname = 'descriptors'
        logging.debug(f'select()ing for some inbound activity on {len(listeners)} file {dname} before exec()ing...')
        (r,w,x) = select.select(fds, [], fds)
        logging.debug(f'saw activity on r={r}, w={w}, x={x}')
    # With no listeners there is nothing to wait for; select() on three
    # empty lists would block forever, so go straight to the exec.
    logging.debug(f'executing {args}')
    os.execvp(args[0], args)

# Only run when executed as a script, so that importing this module (for
# instance from a test) does not bind sockets or exec() anything.
if __name__ == '__main__':
    main()
