#!/usr/bin/env python

# Copyright (c) 2016- The University of Notre Dame.
# This software is distributed under the GNU General Public License.
# See the file COPYING for details. 

# This script defines a customized Mesos executor. Compared to the default
# executor, it can send the URL of each executor's sandbox directory to the
# scheduler, so that the Mesos scheduler is able to retrieve the outputs.

import os
import sys
import json
import urllib2
import shutil
import logging
import threading
import subprocess

# Prepend the directory given as the fourth command-line argument to the
# module search path. This must happen before the mesos imports below.
# NOTE(review): presumably this is the location of the Mesos Python
# bindings -- confirm against the scheduler that launches this script.
sys.path.insert(0, sys.argv[4])

from mesos.native import MesosExecutorDriver
from mesos.interface import Executor
from mesos.interface import mesos_pb2

# Send all log records (INFO and above) to a file named after the second
# command-line argument, i.e. "<argv[2]>.log" in the current directory.
logging.basicConfig(
    filename='{0}.log'.format(sys.argv[2]),
    format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
    datefmt='%m-%d %H:%M',
    level=logging.INFO,
)

# Port used to build the agent's state.json and file-download URLs.
DEFAULT_PORT = 5051

class MakeflowMesosExecutor(Executor):
    """Mesos executor that runs one shell command and reports its sandbox.

    After the command finishes, the executor looks up its own sandbox
    directory in the agent's ``state.json`` and sends a download URL for it
    to the scheduler inside the final status update. It also reacts to the
    "abort" and "retrieve" framework messages.
    """

    def __init__(self, cmd, executor_id, framework_id, hostname):
        """Store the command to run and the identifiers used in messages.

        Args:
            cmd: shell command line executed for the task.
            executor_id: id of this executor, matched against state.json.
            framework_id: id of the owning framework (stored only).
            hostname: host used to build the agent URLs.
        """
        self.hostname = hostname
        self.port = DEFAULT_PORT
        self.cmd = cmd
        self.executor_id = executor_id
        self.framework_id = framework_id
        # Filled in by launchTask(). Initialised here because the driver may
        # invoke disconnected()/frameworkMessage() before any task is launched.
        self.task_id = None

    def disconnected(self, driver):
        """Tell the scheduler that this executor lost its agent connection."""
        driver.sendFrameworkMessage(
            "[EXECUTOR_STATE] {0} {1} disconnected".format(
                self.executor_id, self.task_id))

    def get_sandbox_dir(self, hostname):
        """Return the sandbox directory of the current task, or None.

        Fetches ``http://<hostname>:<port>/state.json`` and searches every
        framework listed there for an executor whose id equals
        ``self.executor_id`` and which lists ``self.task_id`` among its
        completed or still-running tasks.

        Args:
            hostname: host of the agent to query.

        Returns:
            The executor's ``directory`` entry, or None when no matching
            executor/task pair is present.

        Raises:
            urllib2.URLError: if the agent cannot be reached.
            ValueError: if the response is not valid JSON.
        """
        slave_state_uri = "http://{0}:{1}/state.json".format(hostname, self.port)

        response = urllib2.urlopen(slave_state_uri)
        try:
            slave_state = json.load(response)
        finally:
            # urllib2 responses are not context managers on Python 2.
            response.close()

        # Do not assume this framework is the first one listed by the agent.
        for framework in slave_state.get('frameworks', []):
            for executor_data in framework.get('executors', []):
                if executor_data['id'] != self.executor_id:
                    continue
                # The task is normally in completed_tasks; it may still be in
                # tasks if the agent has not processed the final update yet.
                known_tasks = (list(executor_data.get('completed_tasks', [])) +
                               list(executor_data.get('tasks', [])))
                for known_task in known_tasks:
                    if known_task['id'] == self.task_id:
                        return executor_data['directory']

        logging.error(
            "Task {0} does not appear in the tasks list of executor {1}.".format(
                self.task_id, self.executor_id))
        return None

    def launchTask(self, driver, task):
        """Run the task's command and report RUNNING then FINISHED/FAILED.

        The work is done in a helper thread, but this method joins it, so the
        callback returns only after the final status update has been sent.
        """
        self.task_id = task.task_id.value

        thread = threading.Thread(target=self._run_task, args=(driver, task))
        thread.start()
        thread.join()

    def _run_task(self, driver, task):
        """Execute the command and send the status updates for one task."""
        task_id = task.task_id.value

        print("Running task %s" % task_id)
        update = mesos_pb2.TaskStatus()
        update.task_id.value = task_id
        update.state = mesos_pb2.TASK_RUNNING
        driver.sendStatusUpdate(update)

        self._report_inputs(task)

        print("Sending status update...")
        update = mesos_pb2.TaskStatus()
        update.task_id.value = task_id

        try:
            # The command is a shell command line supplied by the scheduler,
            # hence shell=True.
            subprocess.check_call(self.cmd, shell=True)
        except subprocess.CalledProcessError as e:
            update.state = mesos_pb2.TASK_FAILED
            update.message = str(e.returncode)
        else:
            self._fill_finished_update(update, task_id)

        driver.sendStatusUpdate(update)
        print("Sent status update")

        driver.sendFrameworkMessage(
            "[EXECUTOR_STATE] {0} stopped".format(self.executor_id))

    def _report_inputs(self, task):
        """Write the name and size of every input URI of the task to stderr."""
        for uri in task.executor.command.uris:
            inp_fn = os.path.basename(uri.value)
            try:
                inp_fn_size = os.path.getsize(inp_fn)
            except OSError as e:
                # A missing input must not prevent the task from being run
                # and a terminal status from being reported.
                logging.warning("Cannot stat input {0} of task {1}: {2}".format(
                    uri.value, task.task_id.value, e))
                continue
            sys.stderr.write("task {0} input: {1} {2}\n".format(
                task.task_id.value, uri.value, inp_fn_size))

    def _fill_finished_update(self, update, task_id):
        """Populate a TASK_FINISHED update with stderr and the sandbox URL."""
        try:
            with open("stderr", "r") as stderr_fd:
                stderr_msg = stderr_fd.read()
        except (IOError, OSError) as e:
            logging.warning("Cannot read stderr of task {0}: {1}".format(task_id, e))
            stderr_msg = ""
        update.data = "stderr of task {0} \n{1}".format(self.task_id, stderr_msg)

        # Send the sandbox URI to the scheduler. A failed lookup is logged and
        # reported as path=None so that a terminal update is still delivered.
        try:
            sandbox_dir = self.get_sandbox_dir(self.hostname)
        except Exception:
            logging.exception(
                "Cannot determine sandbox directory of task {0}".format(task_id))
            sandbox_dir = None

        get_dir_addr = "[EXECUTOR_OUTPUT] http://{0}:{1}/files/download?path={2}".format(
            self.hostname, self.port, sandbox_dir)
        print("{0}".format(get_dir_addr))
        task_id_msg = "task_id {0}".format(task_id)
        message = "{0} {1}".format(get_dir_addr, task_id_msg)
        logging.info("Sending message: {0}".format(message))
        print("Sent output file URI")

        update.state = mesos_pb2.TASK_FINISHED
        update.message = message
        update.executor_id.value = self.executor_id

    def frameworkMessage(self, driver, message):
        """Handle a scheduler message whose second word is the action.

        "abort" acknowledges the abort and stops the driver; "retrieve"
        removes the sandbox directory and stops the driver. Messages with
        fewer than two words, or with another action, are ignored.
        """
        print("receive message {0}".format(message))
        message_list = message.split()
        if len(message_list) < 2:
            logging.warning(
                "Ignoring malformed framework message: {0!r}".format(message))
            return

        # split() already removed surrounding whitespace from each word.
        action = message_list[1]

        if action == "abort":
            logging.info("task {0} aborted".format(self.task_id))
            driver.sendFrameworkMessage("[EXECUTOR_STATE] {0} aborted {1}".format(
                self.executor_id, self.task_id))
            driver.stop()
        elif action == "retrieve":
            print("The output of task {0} has been retrieved by master".format(
                self.task_id))
            print("Removing the sandbox directory.....")
            # Cleanup is best effort: the driver is stopped even if the
            # sandbox cannot be located or removed.
            try:
                sandbox_dir = self.get_sandbox_dir(self.hostname)
                if sandbox_dir is not None:
                    shutil.rmtree(sandbox_dir)
            except Exception:
                logging.exception(
                    "Cannot remove sandbox directory of task {0}".format(
                        self.task_id))
            driver.stop()
            
if __name__ == '__main__':
    print("starting makeflow mesos executor!")

    # Positional arguments: 1 = command, 2 = executor id, 3 = framework id,
    # 5 = hostname (argument 4 is consumed at import time for sys.path).
    executor = MakeflowMesosExecutor(
        sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[5])
    driver = MesosExecutorDriver(executor)

    # Exit with 0 only when the driver stopped normally.
    if driver.run() == mesos_pb2.DRIVER_STOPPED:
        exit_code = 0
    else:
        exit_code = 1

    sys.exit(exit_code)

# vim: set noexpandtab tabstop=4:
