#!/usr/bin/python

# Copyright (C) 2015 Linaro Limited
#
# Author: Remi Duraffort <remi.duraffort@linaro.org>
#
# This file is part of LAVA Dispatcher.
#
# LAVA Coordinator is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# LAVA Coordinator is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <http://www.gnu.org/licenses>.

"""Start the lava dispatcher and the zmq messager."""

import argparse
import atexit
import errno
import fcntl
import logging
import os
import re
import signal
import socket
import subprocess
import sys
import tempfile
import time
import traceback
import yaml
import zmq

# pylint: disable=no-member
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-arguments
# pylint: disable=too-many-branches
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements

# Default values:
# TIMEOUT: timeout in seconds, used as default of the --timeout option
#          (poll timeout on the sockets and delay before sending a PING).
# SEND_QUEUE: send high water mark (zmq.SNDHWM) of the socket connected to
#             the master.
TIMEOUT = 5
SEND_QUEUE = 10

# Base directory under which every job gets its own sub-directory.
# FIXME: This is a temporary fix and the dir contents need to be assessed
# whether they are useful (logs should go into /tmp or tmpfs) or not.
TMP_DIR = os.path.join(tempfile.gettempdir(), "lava-dispatcher/slave/")

# Set up the log: configure the format of the root handler and create the
# logger used by the whole file.
FORMAT = "%(asctime)-15s %(levelname)s %(message)s"
logging.basicConfig(format=FORMAT)
LOG = logging.getLogger("dispatcher-slave")


def mkdir(path):
    """Make sure the directory ``path`` exists, creating parents if needed.

    An already existing directory is not an error. Every other failure
    (including ``path`` existing without being a directory) is re-raised.

    :param path: The directory to create.
    :type path: string
    :raise OSError: when the directory cannot be created.
    """
    try:
        os.makedirs(path)
    except OSError as error:
        already_a_dir = error.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise


class Master(object):
    """Keep track of what we know about the master.

    ``last_msg`` and ``last_ping`` are timestamps (as returned by
    ``time.time()``, 0 until set) and ``online`` tells whether the master is
    currently considered reachable.
    """
    def __init__(self):
        self.online = False
        self.last_ping = 0
        self.last_msg = 0

    def received_msg(self):
        """Record that a valid message was just received from the master.

        The reception time is stored and the master is flagged as online;
        the transition from offline to online is logged.
        """
        self.last_msg = time.time()
        if self.online:
            # Already known as online: nothing to announce.
            return
        LOG.info("Master is ONLINE")
        self.online = True


class Job(object):
    """Wrapper around a job process.

    Every job owns a directory below ``TMP_DIR`` (named after the job id)
    that holds the job and device definitions as well as the "out" and "err"
    files receiving the output of the process.
    """
    def __init__(self, job_id, definition, device_definition, env,
                 log_socket, env_dut=None):
        """Create the job directory and store the definitions in it.

        :param job_id: Identifier of the job, also the directory name.
        :param definition: Job definition, written to "job.yaml".
        :param device_definition: Device configuration, written to
        "device.yaml".
        :param env: YAML document describing the environment of the process
        (see ``create_environ``).
        :param log_socket: Value given to the "--socket-addr" option.
        :param env_dut: Optional data dumped in a temporary file whose path
        is given to the "--env-dut-path" option.
        """
        self.job_id = job_id
        self.log_socket = log_socket
        self.env = env
        self.env_dut = env_dut
        self.proc = None
        self.running = False
        self.base_dir = os.path.join(TMP_DIR, "%s/" % self.job_id)
        mkdir(self.base_dir)

        # Write back the job and device configuration
        with open(os.path.join(self.base_dir, "job.yaml"), "w") as f_job:
            f_job.write(definition)
        with open(os.path.join(self.base_dir, "device.yaml"), "w") as f_device:
            f_device.write(device_definition)

    def create_environ(self):
        """Generate the env variables for the job.

        ``self.env`` is parsed as YAML and the following keys are used:
        "purge" (start from an empty environment instead of a copy of
        ``os.environ``), "removes" (names of the variables to drop) and
        "overrides" (variables to set, applied last).

        :return: The environment to give to the job process.
        :rtype: dict
        """
        # NOTE(review): yaml.load() is able to build arbitrary Python objects
        # and self.env is received from the master; consider yaml.safe_load()
        # if the configuration only uses plain YAML types.
        conf = yaml.load(self.env)
        if conf.get("purge", False):
            environ = {}
        else:
            environ = dict(os.environ)

        # Remove some variables (that might not exist)
        for var in conf.get("removes", {}):
            environ.pop(var, None)

        # Override
        environ.update(conf.get("overrides", {}))
        return environ

    def start(self):
        """Start the process.

        Errors raised while preparing or spawning the process are logged and
        appended to the "err" file, then the job is flagged as not running:
        they are never propagated to the caller.
        """
        out_file = os.path.join(self.base_dir, "out")
        err_file = os.path.join(self.base_dir, "err")

        try:
            LOG.debug("[%d] START", self.job_id)
            env = self.create_environ()
            args = [
                "lava-dispatch",
                "--target",
                os.path.join(self.base_dir, "device.yaml"),
                os.path.join(self.base_dir, "job.yaml"),
                "--output-dir=%s" % os.path.join(self.base_dir, "logs/"),
                "--job-id=%s" % self.job_id,
                "--socket-addr=%s" % self.log_socket
            ]

            # Dump the environment variables in the tmp file. This is done
            # inside the guarded section so that a failure to create the file
            # is handled like every other start-up error.
            # NOTE: the temporary file is not removed by this class.
            if self.env_dut:
                env_dut_handle, env_dut_tmp_path = tempfile.mkstemp()
                with os.fdopen(env_dut_handle, 'wb') as f_env:
                    f_env.write(self.env_dut)
                args.append("--env-dut-path=%s" % env_dut_tmp_path)

            # The child gets its own copies of the descriptors, so our side
            # can be closed as soon as the process has been spawned.
            with open(out_file, "w") as f_out, open(err_file, "w") as f_err:
                self.proc = subprocess.Popen(
                    args, stdout=f_out, stderr=f_err, env=env)
            self.running = True
        except Exception as exc:
            # daemon must always continue running even if the job crashes
            if hasattr(exc, "child_traceback"):
                LOG.exception(
                    {exc.strerror: exc.child_traceback.split("\n")})
            else:
                LOG.exception(exc)
            with open(err_file, "a") as errlog:
                # TODO: send something to the zmq LOG
                errlog.write("%s\n%s\n" % (exc, traceback.format_exc()))
            self.cancel()

    def cancel(self):
        """Cancel the job and kill the process.

        The process (if any) is terminated and waited for, then the job is
        flagged as not running.
        """
        if self.proc is not None:
            self.proc.terminate()
            # TODO: be sure not to block here
            self.proc.wait()
            self.proc = None
        self.running = False


def get_fqdn():
    """Return the fully qualified domain name of this host.

    The name reported by ``socket.getfqdn()`` is only accepted when it is
    made of letters, digits, '-', '_' and '.'.

    :return: The fully qualified domain name.
    :raise ValueError: when the name contains other characters.
    """
    fqdn = socket.getfqdn()
    if re.match("[-_a-zA-Z0-9.]+$", fqdn) is None:
        raise ValueError("Your FQDN contains invalid characters")
    return fqdn


def create_zmq_context(master_uri, hostname, send_queue=10):
    """Build the ZMQ objects and the signal pipe used by the main loop.

    A DEALER socket identified by ``hostname`` is connected to the master.
    SIGINT, SIGTERM and SIGQUIT get a handler that does nothing; their
    reception is instead notified through a pipe registered in the poller
    next to the socket.

    :param master_uri: The URI where the socket should be connected.
    :type master_uri: string
    :param hostname: The name of this host.
    :type hostname: string
    :param send_queue: How many objects should be in the send queue.
    :type send_queue: int
    :return: A tuple with: the zmq context, the zmq socket, the zmq poller, a
    read pipe and a write pipe.
    """
    LOG.info("Creating ZMQ context and socket connections")
    # Socket towards the master dispatcher.
    context = zmq.Context()
    sock = context.socket(zmq.DEALER)
    for option, value in ((zmq.IDENTITY, hostname), (zmq.SNDHWM, send_queue)):
        sock.setsockopt(option, value)
    sock.connect(master_uri)

    # The poller watches the socket and, below, the pipe (signal).
    poller = zmq.Poller()
    poller.register(sock, zmq.POLLIN)

    # The write end of the pipe is made non-blocking and installed as wakeup
    # fd: something is written to it for each signal received. Polling the
    # read end along with the zmq socket means that we can only be
    # interrupted while reading data.
    read_pipe, write_pipe = os.pipe()
    current_flags = fcntl.fcntl(write_pipe, fcntl.F_GETFL, 0)
    fcntl.fcntl(write_pipe, fcntl.F_SETFL, current_flags | os.O_NONBLOCK)
    signal.set_wakeup_fd(write_pipe)

    def _ignore_signal(_signum, _frame):
        """Do nothing: the signal is noticed through the pipe."""
        return None

    for signum in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(signum, _ignore_signal)
    poller.register(read_pipe, zmq.POLLIN)

    return context, sock, poller, read_pipe, write_pipe


def destroy_zmq_context(context, sock, read_pipe, write_pipe):
    """Clean up function to close ZMQ and related objects.

    The socket is closed with a linger of 0, so pending messages are dropped.

    :param context: The zmq context to terminate.
    :param sock: The zmq socket to close.
    :param read_pipe: The read pipe to close.
    :param write_pipe: The write pipe to close.
    """
    LOG.info("Closing sock and pipes, dropping messages")
    # Close each end on its own: an error on one of them must not prevent
    # the other one from being closed.
    for pipe_fd in (read_pipe, write_pipe):
        try:
            os.close(pipe_fd)
        except OSError:
            # Silently ignore possible errors.
            pass
    sock.close(linger=0)
    context.term()


def main():
    """Set up and start the dispatcher slave.

    The command line is parsed, then the slave connects to the master and
    greets it ("HELLO", then "HELLO_RETRY") until "HELLO_OK" is received.
    After that the slave loops on the master instructions (START, CANCEL,
    END_OK, STATUS and PONG), reports the finished jobs with "END" and sends
    a "PING" when no message was received and no PING was sent for more than
    the timeout. The loop is left when a signal is received.
    """
    parser = argparse.ArgumentParser(description="LAVA Dispatcher Slave")
    parser.add_argument(
        "--hostname", default=get_fqdn(), type=str, help="Name of the slave")
    parser.add_argument(
        "--master", type=str, help="Main master socket", required=True)
    parser.add_argument(
        "--socket-addr", type=str, help="Log socket", required=True)
    parser.add_argument(
        "--level", "-l",
        type=str,
        default="INFO",
        choices=["DEBUG", "ERROR", "INFO", "WARN"],
        help="Log level (DEBUG, ERROR, INFO, WARN); default to INFO"
    )
    parser.add_argument(
        "--timeout", "-t",
        type=int,
        default=TIMEOUT,
        help="Socket connection timeout in seconds; default to %d" % TIMEOUT,
    )
    args = parser.parse_args()

    timeout = args.timeout
    host_name = args.hostname
    master_uri = args.master

    # Set-up the LOG level. The parser only accepts names that are also
    # level constants of the logging module.
    LOG.setLevel(getattr(logging, args.level))

    context, sock, poller, pipe_r, pipe_w = create_zmq_context(
        master_uri, host_name, send_queue=SEND_QUEUE)

    # Register cleanup function to be run at exit.
    atexit.register(destroy_zmq_context, context, sock, pipe_r, pipe_w)

    # Collect every server data and list of jobs
    master = Master()
    jobs = {}

    LOG.info("Connecting to master as <%s>", host_name)
    LOG.debug("Greeting the master => 'HELLO'")
    sock.send_multipart(["HELLO"])

    # Greet the master until it answers with HELLO_OK.
    while True:
        try:
            LOG.info("Waiting for the master to reply")
            sockets = dict(poller.poll(timeout * 1000))
        except zmq.error.ZMQError:
            # TODO: tests needed to understand cases where ZMQError is raised.
            LOG.error("Received an error, interrupted")
            sys.exit(1)

        if sockets.get(pipe_r) == zmq.POLLIN:
            LOG.info("Received a signal, leaving")
            sys.exit(0)
        elif sockets.get(sock) == zmq.POLLIN:
            msg = sock.recv_multipart()

            try:
                message = msg[0]
                LOG.debug("The master replied: %s", msg)
            except (IndexError, TypeError):
                LOG.error("Invalid message from the master: %s", msg)
            else:
                if message == "HELLO_OK":
                    LOG.info("Connection with the master established")
                    # Mark the master as alive.
                    master.received_msg()
                    break
                else:
                    LOG.info("Unexpected message from the master: %s", message)

        # No HELLO_OK yet (timeout or unexpected message): greet again.
        LOG.debug("Sending new HELLO_RETRY message to the master")
        sock.send_multipart(["HELLO_RETRY"])

    # Loop for server instructions
    LOG.info("Waiting for master instructions")
    while True:
        try:
            sockets = dict(poller.poll(timeout * 1000))
        except zmq.error.ZMQError:
            # TODO: tests needed to understand cases where ZMQError is raised.
            continue

        if sockets.get(pipe_r) == zmq.POLLIN:
            LOG.info("Received a signal, leaving")
            break

        if sockets.get(sock) == zmq.POLLIN:
            msg = sock.recv_multipart()

            # 1: the action
            try:
                action = msg[0]
            except (IndexError, TypeError):
                LOG.error("Invalid message from the master: %s", msg)
                continue
            LOG.debug("Received action=%s, args=(%s)", action, msg[1:])

            # Parse the action
            if action == "HELLO_OK":
                LOG.debug(
                    "Received HELLO_OK from the master - nothing do to")

            elif action == "PONG":
                LOG.debug("Connection to master OK")

                # Mark the master as alive
                master.received_msg()

            elif action == "START":
                try:
                    job_id = int(msg[1])
                    job_definition = msg[2]
                    device_definition = msg[3]
                    env = msg[4]
                    env_dut = msg[5] if len(msg) == 6 else None
                except (IndexError, ValueError) as exc:
                    # The values are given as separate arguments: the logger
                    # needs one argument per placeholder.
                    LOG.error("Invalid message '%s'. length=%d. %s",
                              msg, len(msg), exc)
                    continue

                LOG.info("[%d] Starting job", job_id)
                LOG.debug("[%d]        : %s", job_id, job_definition)
                LOG.debug("[%d] device : %s", job_id, device_definition)
                LOG.debug("[%d] env    : %s", job_id, env)
                LOG.debug("[%d] env-dut: %s", job_id, env_dut)

                # Check if the job is known and started. In this case, send
                # back the right signal (ignoring the duplication or signaling
                # the end of the job).
                if job_id in jobs:
                    if jobs[job_id].running:
                        LOG.info(
                            "[%d] Job has already been started", job_id)
                        sock.send_multipart(["START_OK", str(job_id)])
                    else:
                        LOG.warning("[%d] Job has already ended", job_id)
                        sock.send_multipart(["END", str(job_id), "0"])
                else:
                    jobs[job_id] = Job(job_id, job_definition,
                                       device_definition, env,
                                       args.socket_addr, env_dut=env_dut)
                    jobs[job_id].start()
                    sock.send_multipart(["START_OK", str(job_id)])

                # Mark the master as alive
                master.received_msg()

            elif action == "CANCEL":
                try:
                    job_id = int(msg[1])
                except (IndexError, ValueError):
                    LOG.error("Invalid message '%s'", msg)
                    continue
                LOG.info("[%d] Canceling", job_id)

                # Check if the job is known and started. In this case, send
                # back the right signal (ignoring the duplication or signaling
                # the end of the job).
                if job_id in jobs:
                    if jobs[job_id].running:
                        jobs[job_id].cancel()
                    else:
                        LOG.info(
                            "[%d] Job has already been canceled", job_id)
                else:
                    LOG.debug("[%d] Unknown job, sending END", job_id)
                    # Remember the job (as not running) until the master
                    # acknowledges the END message.
                    jobs[job_id] = Job(job_id, "", "", None, None)
                    jobs[job_id].running = False
                # Send the END message anyway
                sock.send_multipart(["END", str(job_id), "0"])

                # Mark the master as alive
                master.received_msg()

            elif action == "END_OK":
                try:
                    job_id = int(msg[1])
                except (IndexError, ValueError):
                    LOG.error("Invalid message '%s'", msg)
                    continue
                if job_id in jobs:
                    LOG.debug("[%d] Job END acked", job_id)
                    del jobs[job_id]
                else:
                    LOG.debug("[%d] Unknown job END acked", job_id)

                # Do not mark the master as alive. In fact we are not sending
                # back any data so the master will not be able to mark the
                # slave as alive.

            elif action == "STATUS":
                try:
                    job_id = int(msg[1])
                except (IndexError, ValueError):
                    LOG.error("Invalid message '%s'", msg)
                    continue
                if job_id in jobs:
                    if jobs[job_id].running:
                        # The job is still running
                        sock.send_multipart(["START_OK", str(job_id)])
                    else:
                        # The job has already ended
                        sock.send_multipart(["END", str(job_id), "0"])
                else:
                    # Unknown job: return END anyway
                    LOG.debug(
                        "[%d] Unknown job, sending END after STATUS", job_id)
                    jobs[job_id] = Job(job_id, "", "", None, None)
                    jobs[job_id].running = False
                    sock.send_multipart(["END", str(job_id), "0"])

                # Mark the master as alive
                master.received_msg()

            else:
                LOG.error(
                    "Unknown action: '%s', args=(%s)", action, msg[1:])
                # Do not write the master as alive as the message does not mean
                # anything.

        # Check job status: report the running jobs whose process has ended.
        for job_id, job in jobs.items():
            if not job.running:
                continue
            job_status = job.proc.poll()
            if job_status is None:
                # The process is still alive.
                continue

            LOG.info("[%d] Job END", job_id)
            if job_status:
                LOG.info("[%d] Job returned non-zero", job_id)

            job.running = False
            sock.send_multipart(["END", str(job_id), str(job_status)])

        # Send a PING only if we haven't received a message from the master nor
        # sent a PING for a long time.
        now = time.time()
        if now - max(master.last_msg, master.last_ping) > timeout:
            # Is the master offline ?
            if master.online and now - master.last_msg > 4 * timeout:
                LOG.warning("Master goes OFFLINE")
                master.online = False

            LOG.debug(
                "Sending PING to the master (last message %ss ago)",
                int(now - master.last_msg))

            sock.send_multipart(["PING"])
            master.last_ping = now


if __name__ == "__main__":
    main()
