#!/usr/bin/python

# Copyright (C) 2015 Linaro Limited
#
# Author: Remi Duraffort <remi.duraffort@linaro.org>
#
# This file is part of LAVA Dispatcher.
#
# LAVA Coordinator is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# LAVA Coordinator is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses>.

"""
Start the lava dispatcher and the zmq messenger.

Slaves are allowed to connect over ZMQ, but devices can only
be assigned to known slaves by the admin of the instance
(by selecting the worker_host for each pipeline device).
Initially, the details of the workers will be configured
via the current dispatcher support.
"""

import argparse
import atexit
import errno
import fcntl
import logging
import os
import re
import signal
import socket
import subprocess
import sys
import tempfile
import time
import traceback
import yaml
import zmq
import zmq.auth

# pylint: disable=no-member
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments
# pylint: disable=too-many-branches
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements

# Default values for:
# - TIMEOUT: the poll timeout (in seconds), used as the default value of the
#   --timeout command line option
# - SEND_QUEUE: the zmq socket send high water mark (zmq.SNDHWM)
TIMEOUT = 5
SEND_QUEUE = 10

# FIXME: This is a temporary fix until the overlay is sent to the master
# The job.yaml and device.yaml are retained so that lava-dispatch can be re-run manually
# (at least until the slave is rebooted).
# Each job gets its own sub-directory (named after the job id) in TMP_DIR.
TMP_DIR = os.path.join(tempfile.gettempdir(), "lava-dispatcher/slave/")

# Create the logger that will be configured after arguments parsing
# (the handler, the format and the level are set by configure_logger)
FORMAT = "%(asctime)-15s %(levelname)s %(message)s"
LOG = logging.getLogger("dispatcher-slave")


def mkdir(path):
    """Make sure that the directory `path` exists.

    Missing parent directories are created as well. Finding the directory
    already in place is not an error.

    :param path: the directory to create
    :raise OSError: when the creation fails for any reason other than the
    directory already being there
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_there = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_there:
            raise


class Master(object):
    """Keep track of what is known about the master.

    last_msg and last_ping are timestamps as returned by time.time(); both
    are 0 until the first message is received (resp. the first ping is
    recorded).
    """
    def __init__(self):
        self.online = False
        self.last_ping = 0
        self.last_msg = 0

    def received_msg(self):
        """Record that a valid message from the master just arrived.

        The master is flagged as online; the transition from offline to
        online is logged.
        """
        was_online = self.online
        self.last_msg = time.time()
        self.online = True
        if not was_online:
            LOG.info("Master is ONLINE")


class Job(object):
    """Wrapper around a job process.

    Every job owns a directory (TMP_DIR/<job_id>/) where the job and device
    definitions are stored along with the stdout ("out") and stderr ("err")
    of the lava-dispatch process.
    """
    def __init__(self, job_id, definition, device_definition, env,
                 log_socket, master_cert, slave_cert, env_dut=None):
        """Create the job directory and write the definitions to disk.

        :param job_id: the job identifier (also the name of the directory)
        :param definition: the job definition, written to job.yaml
        :param device_definition: the device definition, written to
        device.yaml (the file is left empty when this is empty)
        :param env: YAML document describing the process environment
        (see create_environ)
        :param log_socket: value of the --socket-addr option of lava-dispatch
        :param master_cert: the master certificate file (or None)
        :param slave_cert: the slave certificate file (or None)
        :param env_dut: optional data, dumped to a temporary file whose path
        is given to lava-dispatch with --env-dut-path
        """
        self.job_id = job_id
        self.log_socket = log_socket
        self.master_cert = master_cert
        self.slave_cert = slave_cert
        self.env = env
        self.env_dut = env_dut
        self.proc = None
        self.is_running = False
        self.base_dir = os.path.join(TMP_DIR, "%s/" % self.job_id)
        mkdir(self.base_dir)

        # Write back the job and device configuration
        with open(os.path.join(self.base_dir, "job.yaml"), "w") as f_job:
            f_job.write(definition)
        with open(os.path.join(self.base_dir, "device.yaml"), "w") as f_device:
            if device_definition:  # an empty file for secondary connections
                f_device.write(device_definition)

    def create_environ(self):
        """Generate the env variables for the job.

        The configuration (self.env) is a YAML mapping with the optional
        keys:
        * purge: when true, start from an empty environment instead of a
          copy of the current one
        * removes: the variables to drop (missing ones are ignored)
        * overrides: mapping of variables to set

        An empty configuration leaves the current environment untouched.

        :return: the environment, as a dictionary
        """
        # The configuration comes from the network: restrict the parser to
        # the safe YAML subset (yaml.load is able to build arbitrary python
        # objects and requires an explicit Loader with recent PyYAML).
        conf = yaml.safe_load(self.env) if self.env else None
        if not conf:
            # Empty document: nothing to purge, remove or override
            conf = {}

        if conf.get("purge", False):
            environ = {}
        else:
            environ = dict(os.environ)

        # Remove some variables (that might not exist)
        for var in conf.get("removes") or {}:
            environ.pop(var, None)

        # Override
        environ.update(conf.get("overrides") or {})
        return environ

    def log_errors(self):
        """Log the content of the stderr file of the job.

        :return: the content of the "err" file or None if the file is empty
        :raise OSError: if the "err" file does not exist
        """
        err_file = os.path.join(self.base_dir, "err")
        if os.stat(err_file).st_size == 0:
            return None
        with open(err_file, 'r') as errlog:
            msg = errlog.read()
        # Not called from an exception handler: LOG.exception would append
        # a meaningless traceback to the message.
        LOG.error(msg)
        return msg

    def start(self):
        """Start the process.

        lava-dispatch is spawned with its stdout and stderr redirected to the
        "out" and "err" files of the job directory. Any error raised while
        starting is logged, appended to the "err" file and the job is
        canceled: this function does not propagate it.
        """
        out_file = os.path.join(self.base_dir, "out")
        err_file = os.path.join(self.base_dir, "err")

        try:
            LOG.debug("[%d] START", self.job_id)
            env = self.create_environ()
            args = [
                "lava-dispatch",
                "--target",
                os.path.join(self.base_dir, "device.yaml"),
                os.path.join(self.base_dir, "job.yaml"),
                "--output-dir=%s" % os.path.join(self.base_dir, "logs/"),
                "--job-id=%s" % self.job_id,
                "--socket-addr=%s" % self.log_socket
            ]
            # Use certificates if defined
            if self.master_cert is not None and self.slave_cert is not None:
                args.extend(["--master-cert", self.master_cert,
                             "--slave-cert", self.slave_cert])

            # Dump the environment variables in the tmp file. The file is
            # kept because lava-dispatch reads it after this function
            # returns.
            if self.env_dut:
                env_dut_fd, env_dut_tmp_path = tempfile.mkstemp()
                with os.fdopen(env_dut_fd, 'wb') as f_env_dut:
                    f_env_dut.write(self.env_dut)
                args.append("--env-dut-path=%s" % env_dut_tmp_path)

            # The child process gets its own copy of the descriptors: the
            # ones of the daemon are closed as soon as the process is
            # spawned to avoid leaking two descriptors per job.
            with open(out_file, "w") as f_out, open(err_file, "w") as f_err:
                self.proc = subprocess.Popen(
                    args, stdout=f_out, stderr=f_err, env=env)
            self.is_running = True
        except Exception as exc:  # pylint: disable=broad-except
            # daemon must always continue running even if the job crashes
            if hasattr(exc, "child_traceback"):
                LOG.exception(
                    {exc.strerror: exc.child_traceback.split("\n")})
            else:
                LOG.exception(exc)
            with open(err_file, "a") as errlog:
                errlog.write("%s\n%s\n" % (exc, traceback.format_exc()))
            self.cancel()

    def cancel(self):
        """Cancel the job and kill the process.

        The process (if any) is terminated and waited for; the job is then
        flagged as not running.
        """
        if self.proc is not None:
            self.proc.terminate()
            # TODO: be sure not to block here
            self.proc.wait()
            self.proc = None
        self.is_running = False


def get_fqdn():
    """Look up the fully qualified domain name of this host.

    :return: the name reported by socket.getfqdn()
    :raise ValueError: when the name is not exclusively made of letters,
    digits, "-", "_" and "."
    """
    fqdn = socket.getfqdn()
    if re.match("[-_a-zA-Z0-9.]+$", fqdn) is None:
        raise ValueError("Your FQDN contains invalid characters")
    return fqdn


def create_zmq_context(master_uri, hostname, send_queue, encrypt,
                       master_cert, slave_cert):
    """Create the ZMQ context and necessary accessories.

    A DEALER socket, identified by the hostname, is connected to the master.
    The handlers of SIGINT, SIGTERM and SIGQUIT are replaced by no-op
    handlers and a wakeup pipe is installed, so that signals can be detected
    by polling the read pipe along with the socket.

    When encryption is requested and a certificate cannot be read, the error
    is logged and the process exits with status 1.

    :param master_uri: The URI where the socket should be connected.
    :type master_uri: string
    :param hostname: The name of this host (used as the socket identity).
    :type hostname: string
    :param send_queue: How many object should be in the send queue.
    :type send_queue: int
    :param encrypt: encrypt (or not) the zmq messages
    :param master_cert: the master certificate file
    :param slave_cert: the slave certificate file
    :return: A tuple with: the zmq context, the zmq socket, the zmq poller, a
    read pipe and a write pipe.
    """
    LOG.info("Creating ZMQ context and socket connections")
    # Connect to the master dispatcher.
    context = zmq.Context()
    sock = context.socket(zmq.DEALER)
    sock.setsockopt(zmq.IDENTITY, hostname)
    sock.setsockopt(zmq.SNDHWM, send_queue)

    # If needed, load certificates
    if encrypt:
        LOG.info("Starting encryption")
        try:
            # The slave certificate provides both the public and the secret
            # keys of this socket.
            LOG.debug("Opening slave certificate: %s", slave_cert)
            (client_public, client_private) = zmq.auth.load_certificate(slave_cert)
            sock.curve_publickey = client_public
            sock.curve_secretkey = client_private
            # Only the public key of the master is used.
            LOG.debug("Opening master certificate: %s", master_cert)
            (server_public, _) = zmq.auth.load_certificate(master_cert)
            sock.curve_serverkey = server_public
        except IOError as err:
            # Unreadable certificate: release the zmq resources and leave
            LOG.error(err)
            sock.close()
            context.term()
            sys.exit(1)

    sock.connect(master_uri)

    # Poll on the socket and the pipe (signal).
    poller = zmq.Poller()
    poller.register(sock, zmq.POLLIN)

    # Mask signals and create a pipe that will receive a byte for each signal
    # received. Poll the pipe along with the zmq socket so that we can only be
    # interrupted while reading data.
    (read_pipe, write_pipe) = os.pipe()
    # The write end is made non-blocking before being used as wakeup fd
    flags = fcntl.fcntl(write_pipe, fcntl.F_GETFL, 0) | os.O_NONBLOCK
    fcntl.fcntl(write_pipe, fcntl.F_SETFL, flags)
    signal.set_wakeup_fd(write_pipe)
    # The handlers do nothing by themselves: the signals are only noticed
    # through the pipe.
    signal.signal(signal.SIGINT, lambda x, y: None)
    signal.signal(signal.SIGTERM, lambda x, y: None)
    signal.signal(signal.SIGQUIT, lambda x, y: None)
    poller.register(read_pipe, zmq.POLLIN)

    return context, sock, poller, read_pipe, write_pipe


def destroy_zmq_context(context, sock, read_pipe, write_pipe):
    """Clean up function to close ZMQ and related objects.

    The socket is closed with a linger of 0: the messages that are still
    queued are dropped.

    :param context: The zmq context to terminate.
    :param sock: The zmq socket to close.
    :param read_pipe: The read pipe to close.
    :param write_pipe: The write pipe to close.
    """
    LOG.info("Closing sock and pipes, dropping messages")
    # Close each end on its own: a failure while closing the read pipe must
    # not leave the write pipe open.
    for pipe_fd in (read_pipe, write_pipe):
        try:
            os.close(pipe_fd)
        except OSError:
            # Silently ignore possible errors (best effort at exit).
            pass
    sock.close(linger=0)
    context.term()


def configure_logger(log_file, level):
    """Attach a handler to the slave logger and set its verbosity.

    :param log_file: the file to append the logs to, or "-" to log to
    sys.stdout
    :param level: the log level, one of "ERROR", "WARN" or "INFO"; any
    other value selects DEBUG
    """
    known_levels = {
        "ERROR": logging.ERROR,
        "WARN": logging.WARN,
        "INFO": logging.INFO,
    }

    # Where should the logs go?
    to_stdout = log_file == "-"
    if to_stdout:
        log_handler = logging.StreamHandler(sys.stdout)
    else:
        log_handler = logging.FileHandler(log_file, "a")
    log_handler.setFormatter(logging.Formatter(FORMAT))
    LOG.addHandler(log_handler)

    # How verbose should the logs be?
    LOG.setLevel(known_levels.get(level, logging.DEBUG))


def connect_to_master(master, poller, pipe_r, sock, timeout):
    """Wait once for the answer of the master to our greeting.

    Unless the master answered with HELLO_OK, a HELLO_RETRY message is sent
    before returning so that the caller can simply call this function again.
    The process exits when a signal is received (status 0) or when the poll
    fails (status 1).

    :param master: the master structure to fill when done
    :param poller: the structure to poll for messages
    :param pipe_r: the read pipe for signals
    :param sock: the zmq socket
    :param timeout: the poll timeout (in seconds)
    :return: True if the master had answered, False otherwise
    """
    LOG.info("Waiting for the master to reply")
    try:
        events = dict(poller.poll(timeout * 1000))
    except zmq.error.ZMQError:
        # TODO: tests needed to understand cases where ZMQError is raised.
        LOG.error("Received an error, interrupted")
        sys.exit(1)

    # A signal was received: leave right now
    if events.get(pipe_r) == zmq.POLLIN:
        LOG.info("Received a signal, leaving")
        sys.exit(0)

    if events.get(sock) == zmq.POLLIN:
        frames = sock.recv_multipart()
        try:
            reply = frames[0]
            LOG.debug("The master replied: %s", frames)
        except (IndexError, TypeError):
            LOG.error("Invalid message from the master: %s", frames)
        else:
            if reply == "HELLO_OK":
                LOG.info("Connection with the master established")
                # Mark the master as alive.
                master.received_msg()
                return True
            LOG.info("Unexpected message from the master: %s", reply)

    # Timeout, invalid or unexpected message: greet the master again
    LOG.debug("Sending new %s message to the master", "HELLO_RETRY")
    sock.send_multipart(["HELLO_RETRY"])
    return False


def listen_to_master(master, jobs, poller, pipe_r, socket_addr, master_cert,
                     slave_cert, sock, timeout):
    """Listen for master orders

    Poll once (for at most `timeout` seconds) and handle at most one message
    from the master. The first frame of a message is the action:

    * HELLO_OK: nothing to do
    * PONG: the master is alive
    * START <job_id> <job> <device> <env> [<env_dut>]: create and start the
      job, answer START_OK (or END if the job is known and already ended)
    * CANCEL <job_id>: cancel the job if it is running, answer END
    * END_OK <job_id>: the master acked the END, forget the job
    * STATUS <job_id>: answer START_OK if the job is running, END otherwise

    The process exits (status 0) when a signal is received.

    :param master: the master structure
    :param jobs: the jobs, a dictionary indexed by job id (modified in place)
    :param poller: the structure to poll for messages
    :param pipe_r: the read pipe for signals
    :param socket_addr: address of the logging socket
    :param master_cert: the master certificate
    :param slave_cert: the slave certificate
    :param sock: the zmq socket
    :param timeout: the poll timeout
    """
    try:
        sockets = dict(poller.poll(timeout * 1000))
    except zmq.error.ZMQError:
        # TODO: tests needed to understand cases where ZMQError is raised.
        return

    # A byte in the pipe means that a signal was received
    if sockets.get(pipe_r) == zmq.POLLIN:
        LOG.info("Received a signal, leaving")
        sys.exit(0)

    if sockets.get(sock) == zmq.POLLIN:
        msg = sock.recv_multipart()

        # 1: the action
        try:
            action = msg[0]
        except (IndexError, TypeError):
            LOG.error("Invalid message from the master: %s", msg)
            return
        LOG.debug("Received action=%s", action)

        # Parse the action
        if action == "HELLO_OK":
            LOG.debug(
                "Received HELLO_OK from the master - nothing do to")

        elif action == "PONG":
            LOG.debug("Connection to master OK")

            # Mark the master as alive
            master.received_msg()

        elif action == "START":
            # 2..5: job id, job definition, device definition and env
            # 6: env-dut (optional)
            try:
                job_id = int(msg[1])
                job_definition = msg[2]
                device_definition = msg[3]
                env = msg[4]
                env_dut = msg[5] if len(msg) == 6 else None
            except (IndexError, ValueError) as exc:
                LOG.error("Invalid message '%s'. length=%d. %s", msg, len(msg), exc)
                return

            LOG.info("[%d] Starting job", job_id)
            LOG.debug("[%d]        : %s", job_id, job_definition)
            LOG.debug("[%d] device : %s", job_id, device_definition)
            LOG.debug("[%d] env    : %s", job_id, env)
            LOG.debug("[%d] env-dut: %s", job_id, env_dut)

            # Check if the job is known and started. In this case, send
            # back the right signal (ignoring the duplication or signaling
            # the end of the job).
            if job_id in jobs:
                if jobs[job_id].is_running:
                    LOG.info(
                        "[%d] Job has already been started", job_id)
                    sock.send_multipart(["START_OK", str(job_id)])
                else:
                    LOG.warning("[%d] Job has already ended", job_id)
                    sock.send_multipart(["END", str(job_id), "0"])
            else:
                # New job: START_OK is sent whatever the outcome of start()
                # (errors are handled inside Job.start).
                jobs[job_id] = Job(job_id, job_definition, device_definition,
                                   env, socket_addr, master_cert, slave_cert,
                                   env_dut=env_dut)
                jobs[job_id].start()
                sock.send_multipart(["START_OK", str(job_id)])

            # Mark the master as alive
            master.received_msg()

        elif action == "CANCEL":
            try:
                job_id = int(msg[1])
            except (IndexError, ValueError):
                LOG.error("Invalid message '%s'", msg)
                return
            LOG.info("[%d] Canceling", job_id)

            # Check if the job is known and started. In this case, send
            # back the right signal (ignoring the duplication or signaling
            # the end of the job).
            if job_id in jobs:
                if jobs[job_id].is_running:
                    jobs[job_id].cancel()
                else:
                    LOG.info(
                        "[%d] Job has already been canceled", job_id)
            else:
                # Keep a placeholder (not running) job until the master
                # acks the END message with END_OK.
                LOG.debug("[%d] Unknown job, sending END", job_id)
                jobs[job_id] = Job(job_id, "", "", None, None, None, None)
                jobs[job_id].is_running = False
            # Send the END message anyway
            sock.send_multipart(["END", str(job_id), "0"])

            # Mark the master as alive
            master.received_msg()

        elif action == "END_OK":
            try:
                job_id = int(msg[1])
            except (IndexError, ValueError):
                LOG.error("Invalid message '%s'", msg)
                return
            # The master knows that the job ended: it can be forgotten
            if job_id in jobs:
                LOG.debug("[%d] Job END acked", job_id)
                del jobs[job_id]
            else:
                LOG.debug("[%d] Unknown job END acked", job_id)

            # Do not mark the master as alive. In fact we are not sending
            # back any data so the master will not be able to mark the
            # slave as alive.

        elif action == "STATUS":
            try:
                job_id = int(msg[1])
            except (IndexError, ValueError):
                LOG.error("Invalid message '%s'", msg)
                return
            if job_id in jobs:
                if jobs[job_id].is_running:
                    # The job is still running
                    sock.send_multipart(["START_OK", str(job_id)])
                else:
                    # The job has already ended
                    sock.send_multipart(["END", str(job_id), "0"])
            else:
                # Unknown job: return END anyway
                # (a placeholder job is kept until END_OK is received)
                LOG.debug(
                    "[%d] Unknown job, sending END after STATUS", job_id)
                jobs[job_id] = Job(job_id, "", "", None, None, None, None)
                jobs[job_id].is_running = False
                sock.send_multipart(["END", str(job_id), "0"])

            # Mark the master as alive
            master.received_msg()

        else:
            LOG.error(
                "Unknown action: '%s', args=(%s)", action, msg[1:])
            # Do not tag the master as alive as the message does not mean
            # anything.


def check_job_status(jobs, sock):
    """Report to the master the running jobs whose process has finished.

    For each finished job an END message carrying the return code is sent.
    When the return code is non-zero and the job logged errors, an ERROR
    message with these errors is sent first.

    :param jobs: the jobs, a dictionary indexed by job id
    :param sock: the zmq socket
    """
    # Snapshot of the jobs that are flagged as running
    running = [(job_id, job) for (job_id, job) in jobs.items()
               if job.is_running]

    for (job_id, job) in running:
        if job.proc.poll() is None:
            # The process is still alive
            continue

        LOG.info("[%d] Job END", job_id)
        job_status = job.proc.returncode
        if job_status:
            LOG.info("[%d] Job returned non-zero", job_id)
            errs = job.log_errors()
            if errs:
                sock.send_multipart(["ERROR", str(job_id), str(errs)])

        job.is_running = False
        sock.send_multipart(["END", str(job_id), str(job_status)])


def ping_master(master, sock, timeout):
    """PING the master whenever needed

    A PING is sent only if, for more than `timeout` seconds, no message was
    received from the master and no PING was sent. The master is flagged as
    offline once its last message is more than 4 * `timeout` seconds old.

    :param master: the master structure
    :param sock: the zmq socket
    :param timeout: the time to wait to flag the master as offline
    """
    now = time.time()
    last_activity = max(master.last_msg, master.last_ping)
    if now - last_activity <= timeout:
        # Something happened recently: no need to ping
        return

    silence = now - master.last_msg

    # Is the master offline ?
    if master.online and silence > 4 * timeout:
        LOG.warning("Master goes OFFLINE")
        master.online = False

    LOG.debug(
        "Sending PING to the master (last message %ss ago)",
        int(silence))

    sock.send_multipart(["PING"])
    master.last_ping = now


def main():
    """Set up and start the dispatcher slave.

    Parse the command line, configure the logger, connect to the master and
    then loop forever: handle the orders of the master, report the finished
    jobs and ping the master when needed. The loop is only left through
    sys.exit() (for instance when a signal is received).

    :raise ValueError: if --hostname is not given and the FQDN of this host
    contains invalid characters
    """
    parser = argparse.ArgumentParser(description="LAVA Dispatcher Slave")
    # The FQDN is only looked up (after the parsing) when --hostname is not
    # given: the lookup can fail and must not prevent the admin from
    # providing a name explicitly.
    parser.add_argument(
        "--hostname", default=None, type=str, help="Name of the slave")
    parser.add_argument(
        "--master", type=str, help="Main master socket", required=True)
    parser.add_argument(
        "--socket-addr", type=str, help="Log socket", required=True)
    parser.add_argument(
        "--log-file", type=str, help="Log file for the slave logs",
        default="/var/log/lava-dispatcher/lava-slave.log")
    parser.add_argument(
        "--level", "-l",
        type=str,
        default="INFO",
        choices=["DEBUG", "ERROR", "INFO", "WARN"],
        help="Log level (DEBUG, ERROR, INFO, WARN); default to INFO"
    )
    parser.add_argument(
        "--timeout", "-t",
        type=int,
        default=TIMEOUT,
        help="Socket connection timeout in seconds; default to %d" % TIMEOUT,
    )

    parser.add_argument(
        "--encrypt", default=False, action="store_true",
        help="Encrypt messages"
    )
    parser.add_argument(
        "--master-cert", type=str,
        default="/etc/lava-dispatcher/certificates.d/master.key",
        help="Master certificate file",
    )
    parser.add_argument(
        "--slave-cert", type=str,
        default="/etc/lava-dispatcher/certificates.d/slave.key_secret",
        help="Slave certificate file",
    )
    args = parser.parse_args()

    # Parse the command line
    timeout = args.timeout
    host_name = args.hostname if args.hostname is not None else get_fqdn()
    master_uri = args.master

    # configure logger
    configure_logger(args.log_file, args.level)

    # Create the zmq context
    context, sock, poller, pipe_r, pipe_w = create_zmq_context(
        master_uri, host_name, SEND_QUEUE, args.encrypt,
        args.master_cert, args.slave_cert)

    # Register cleanup function to be run at exit.
    atexit.register(destroy_zmq_context, context, sock, pipe_r, pipe_w)

    # Collect every server data and list of jobs
    master = Master()
    jobs = {}

    # Connect to the master and wait for the reply
    LOG.info("Connecting to master as <%s>", host_name)
    if args.encrypt:
        LOG.info("Connection is encrypted using %s", args.slave_cert)
    else:
        # Set to None to disable encryption in the logger
        args.slave_cert = None
        args.master_cert = None
        LOG.info("Connection is not encrypted")

    hello_msg = "HELLO"
    LOG.debug("Greeting the master => '%s'", hello_msg)
    sock.send_multipart([hello_msg])

    # connect_to_master blocks (up to the timeout) in the poll and sends the
    # retry message by itself: nothing to do in the loop body.
    while not connect_to_master(master, poller, pipe_r, sock, timeout):
        pass

    # Loop for server instructions
    LOG.info("Waiting for master instructions")
    while True:
        listen_to_master(master, jobs, poller, pipe_r, args.socket_addr,
                         args.master_cert, args.slave_cert, sock, timeout)
        check_job_status(jobs, sock)
        ping_master(master, sock, timeout)


if __name__ == "__main__":
    main()
