# Copyright (C) 2009-2010 Raul Jimenez
# Released under GNU LGPL 2.1
# See LICENSE.txt for more information

import sys
import threading
import logging
from operator import attrgetter

import os, sys
this_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.join(this_dir, '..')
sys.path.append(root_dir)

import core.ptime as time
import core.identifier as identifier
from core.node import Node

sys.path.pop()

# Logger used by every lookup class in this module.
logger = logging.getLogger('dht')

# Number of unstable addresses sampled from the bootstrapper when a lookup is
# started without bootstrap nodes (overlay bootstrap, see GetPeersLookup.start).
NUM_OVERLAY_BOOTSTRAP_NODES = 40

# Index into the distance-sorted list of responded nodes whose log distance is
# used as the "mark": only queued nodes strictly closer than the mark are
# queried (see _LookupQueue._pop_nodes_to_query). 2 selects the third-closest.
MARK_INDEX = 2

# Default number of closest responded nodes holding a token that are returned
# by _LookupQueue.get_closest_responded_qnodes (and hence announced to).
ANNOUNCE_REDUNDANCY = 3


class _QueuedNode(object):

    def __init__(self, node_, distance, token):
        self.node = node_
        self.distance = distance
        self.token = token

    def __cmp__(self, other):
        # nodes without log_distance (bootstrap) go first
        if self.distance is None:
            return -1
        elif other.distance is None:
            return 1
        return self.distance.__cmp__(other.distance)


class _LookupQueue(object):
    """Node bookkeeping for a single lookup towards `info_hash`.

    Two lists of _QueuedNode are maintained:
      * queued_qnodes: candidates not yet queried. Sorted by distance (nodes
        without distance first) except for unsorted overlay-bootstrap input.
      * responded_qnodes: nodes that answered, sorted by distance and capped
        at max_responded_qnodes.
    """

    def __init__(self, info_hash, queue_size):
        self.info_hash = info_hash
        # NOTE(review): queue_size and self.queue are not read anywhere in this
        # file; they are kept only for compatibility.
        self.queue_size = queue_size
        self.queue = [_QueuedNode(None, identifier.ID_SIZE_BITS+1, None)]
        # *_ips is used to prevent that many Ids are
        # claimed from a single IP address.
        self.queued_ips = set()
        self.queried_ips = set()
        self.queued_qnodes = []
        self.responded_qnodes = []

        # Only responded_qnodes is capped; queued_qnodes can grow freely.
        self.max_responded_qnodes = 16

        self.last_query_ts = time.time()

    def bootstrap(self, rnodes, max_nodes, overlay_bootstrap):
        """Queue the initial nodes and pop up to `max_nodes` to query.

        Nodes without an ID get distance None. When `overlay_bootstrap` is
        true the queue is not sorted, preserving the caller's order.
        """
        # Assume that the ips are not duplicated.
        qnodes = []
        for n in rnodes:
            if n.id:
                distance = n.id.distance(self.info_hash)
            else:
                distance = None
            qnode = _QueuedNode(n, distance, None)
            qnodes.append(qnode)
        self._add_queued_qnodes(qnodes, do_sort=not overlay_bootstrap)
        return self._pop_nodes_to_query(max_nodes)

    def on_response(self, src_node, nodes, token, max_nodes):
        ''' Nodes must not be duplicated'''
        # Record the responder (with its token), queue the nodes it returned,
        # then pop the next nodes to query.
        qnode = _QueuedNode(src_node,
                            src_node.id.distance(self.info_hash),
                            token)
        self._add_responded_qnode(qnode)
        qnodes = [_QueuedNode(n, n.id.distance(self.info_hash), None)
                  for n in nodes]
        self._add_queued_qnodes(qnodes)
        return self._pop_nodes_to_query(max_nodes)

    def on_timeout(self, max_nodes):
        """Pop up to `max_nodes` further nodes to query after a timeout."""
        return self._pop_nodes_to_query(max_nodes)
    # TODO: use STABLE nodes if all UNSTABLE nodes unreachable!!!
    # FIXME: ^^^^^^^^^^^

    on_error = on_timeout

    def get_closest_responded_qnodes(self,
                                     num_nodes=ANNOUNCE_REDUNDANCY):
        """Return up to `num_nodes` closest responded qnodes with a token."""
        closest_responded_qnodes = []
        for qnode in self.responded_qnodes:
            if qnode.token:
                closest_responded_qnodes.append(qnode)
                if len(closest_responded_qnodes) == num_nodes:
                    break
        return closest_responded_qnodes

    def _add_queried_ip(self, ip):
        # Returns True if `ip` was newly added, None otherwise.
        # NOTE(review): not called anywhere in this file.
        if ip not in self.queried_ips:
            self.queried_ips.add(ip)
            return True

    def _add_responded_qnode(self, qnode):
        # Keep only the max_responded_qnodes closest responders.
        self.responded_qnodes.append(qnode)
        self.responded_qnodes.sort(key=attrgetter('distance'))
        del self.responded_qnodes[self.max_responded_qnodes:]

    def _add_queued_qnodes(self, qnodes, do_sort=True):
        # Skip nodes whose IP is already queued or was already queried.
        for qnode in qnodes:
            if qnode.node.ip not in self.queued_ips \
                    and qnode.node.ip not in self.queried_ips:
                self.queued_qnodes.append(qnode)
                self.queued_ips.add(qnode.node.ip)
        if do_sort:
            # We do not want to sort nodes coming from bootstrapper.
            # Bootstrapper relies on nodes being contacted in the same order as
            # the given list. See bootstrapper.report_unreachable.
            self.queued_qnodes.sort()

    def _pop_nodes_to_query(self, max_nodes):
        """Pop up to `max_nodes` queued nodes closer than the current mark.

        The mark is the log distance of the responded node at MARK_INDEX
        (or ID_SIZE_BITS while fewer nodes have responded). Nodes without
        distance are always eligible. Popping stops at the first ineligible
        node at the head of the queue.
        """
        if len(self.responded_qnodes) > MARK_INDEX:
            mark = self.responded_qnodes[MARK_INDEX].distance.log
        else:
            mark = identifier.ID_SIZE_BITS
        nodes_to_query = []
        while self.queued_qnodes and len(nodes_to_query) < max_nodes:
            qnode = self.queued_qnodes[0]
            eligible = qnode.distance is None or qnode.distance.log < mark
            if not eligible:
                # The head is not popped, so re-checking it would never make
                # progress: stop here instead of spinning.
                break
            del self.queued_qnodes[0]
            self.queued_ips.remove(qnode.node.ip)
            self.queried_ips.add(qnode.node.ip)
            nodes_to_query.append(qnode.node)
        self.last_query_ts = time.time()
        return nodes_to_query


class GetPeersLookup(object):
    """Iterative get_peers lookup towards `info_hash`.

    DO NOT use underscored variables, they are thread-unsafe.
    Variables without leading underscore are thread-safe.

    All nodes in bootstrap_nodes MUST have ID.

    start() and the on_* handlers return the queries to send next; a handler
    reports the lookup as done once no queries remain in flight.
    NOTE(review): self._lock is created but never acquired in this class.
    """

    def __init__(self, msg_f, my_id,
                 lookup_id, info_hash,
                 callback_f, bt_port=0):
        self.msg_f = msg_f

        # alpha: max queries in flight; m: max new queries per event
        # (see _get_max_nodes_to_query). bootstrap_alpha is used by start().
        self.bootstrap_alpha = 4
        self.normal_alpha = 4
        self.normal_m = 1
        self.slowdown_alpha = 4
        self.slowdown_m = 1

        self.start_ts = time.time()
        logger.debug('New lookup (info_hash: %r) %d', info_hash, bt_port)
        self._my_id = my_id
        self.lookup_id = lookup_id
        self.callback_f = callback_f
        self._lookup_queue = _LookupQueue(info_hash, 20)

        self.info_hash = info_hash
        self._bt_port = bt_port
        self._lock = threading.RLock()

        self._num_parallel_queries = 0

        self.num_queries = 0
        self.num_responses = 0
        self.num_timeouts = 0
        self.num_errors = 0

        self._running = False
        self._slow_down = False
        self._msg_factory = msg_f.outgoing_get_peers_query

        self.bootstrapper = None  # do overlay bootstrap if not None

    def _get_max_nodes_to_query(self):
        # May be <= 0 when enough queries are already in flight.
        if self._slow_down:
            return min(self.slowdown_alpha - self._num_parallel_queries,
                       self.slowdown_m)
        return min(self.normal_alpha - self._num_parallel_queries,
                   self.normal_m)

    def start(self, bootstrap_rnodes, bootstrapper=None):
        """Start the lookup and return the first queries to send.

        Non-empty `bootstrap_rnodes` seed a normal lookup. Otherwise an
        overlay bootstrap is done with addresses from `bootstrapper`, which
        must then not be None.
        """
        assert not self._running
        self._running = True
        if bootstrap_rnodes:
            # Normal lookup
            self.bootstrapper = None
        else:
            self.bootstrapper = bootstrapper
            # OVERLAY BOOTSTRAP (using nodes from bootstrapper)
            addrs = self.bootstrapper.get_sample_unstable_addrs(
                NUM_OVERLAY_BOOTSTRAP_NODES)
            addrs.extend(self.bootstrapper.get_shuffled_stable_addrs())
            bootstrap_rnodes = [Node(addr) for addr in addrs]

        # Only an overlay bootstrap must keep the bootstrapper's order; a
        # normal lookup sorts its seeds by distance even if a bootstrapper
        # object was passed in.
        overlay_bootstrap = self.bootstrapper is not None
        nodes_to_query = self._lookup_queue.bootstrap(bootstrap_rnodes,
                                                      self.bootstrap_alpha,
                                                      overlay_bootstrap)
        return self._get_lookup_queries(nodes_to_query)

    def on_response_received(self, response_msg, node_):
        """Handle a response.

        Returns (queries_to_send, peers, num_parallel_queries, lookup_done).
        """
        logger.debug('response from %r\n%r', node_, response_msg)
        if self.bootstrapper:
            self.bootstrapper.report_reachable(node_.addr, 0)
        self._num_parallel_queries -= 1
        self.num_responses += 1
        token = getattr(response_msg, 'token', None)
        peers = getattr(response_msg, 'peers', None)
        if peers:
            # Once peers are found, switch to the slow-down parameters.
            self._slow_down = True

        max_nodes = self._get_max_nodes_to_query()
        nodes_to_query = self._lookup_queue.on_response(node_,
                                                        response_msg.all_nodes,
                                                        token, max_nodes)
        queries_to_send = self._get_lookup_queries(nodes_to_query)
        lookup_done = not self._num_parallel_queries
        return (queries_to_send, peers, self._num_parallel_queries,
                lookup_done)

    def on_timeout(self, node_):
        """Handle a query timeout.

        Returns (queries_to_send, num_parallel_queries, lookup_done).
        """
        logger.debug('TIMEOUT node: %r', node_)
        if self.bootstrapper:
            self.bootstrapper.report_unreachable(node_.addr)
        self._num_parallel_queries -= 1
        self.num_timeouts += 1
        self._slow_down = True

        max_nodes = self._get_max_nodes_to_query()
        nodes_to_query = self._lookup_queue.on_timeout(max_nodes)
        queries_to_send = self._get_lookup_queries(nodes_to_query)
        lookup_done = not self._num_parallel_queries
        return (queries_to_send, self._num_parallel_queries,
                lookup_done)

    def on_error_received(self, error_msg, node_addr):
        """Handle an error reply.

        Returns (queries_to_send, num_parallel_queries, lookup_done).
        """
        logger.debug('Got error from node addr: %s:%s', node_addr[0], node_addr[1])
        self._num_parallel_queries -= 1
        self.num_errors += 1

        max_nodes = self._get_max_nodes_to_query()
        nodes_to_query = self._lookup_queue.on_error(max_nodes)
        queries_to_send = self._get_lookup_queries(nodes_to_query)
        lookup_done = not self._num_parallel_queries
        return (queries_to_send, self._num_parallel_queries,
                lookup_done)

    def _get_lookup_queries(self, nodes):
        # Build one query per node, skipping ourselves; each query built is
        # counted once as in flight and once in the total.
        queries = []
        for node_ in nodes:
            if node_.id and node_.id == self._my_id:
                # Don't send to myself
                continue
            self._num_parallel_queries += 1
            self.num_queries += 1
            queries.append(self._msg_factory(node_, self.info_hash, self))
        return queries

    def announce(self):
        """Build announce_peer queries for the closest responders with token.

        Returns (queries_to_send, announce_to_myself). Nothing is announced
        without a BitTorrent port; announce_to_myself is currently always
        False.
        """
        if not self._bt_port:
            return [], False
        nodes_to_announce = self._lookup_queue.get_closest_responded_qnodes()
        announce_to_myself = False
        # TODO: is is worth it to announce to self? The problem is that I don't
        # know my own IP number. Maybe if 127.0.0.1 translates into "I (the
        # node returning 127.0.0.1) am in the swarm".
        # Disabled logic, kept for reference:
        # if len(nodes_to_announce) < ANNOUNCE_REDUNDANCY:
        #     announce_to_myself = True
        # elif (self._my_id.log_distance(self.info_hash) <
        #       nodes_to_announce[ANNOUNCE_REDUNDANCY-1].id.log_distance(
        #         self.info_hash)):
        #     nodes_to_announce = nodes_to_announce[:-1]
        #     announce_to_myself = True
        queries_to_send = []
        for qnode in nodes_to_announce:
            logger.debug('announcing to %r', qnode.node)
            query = self.msg_f.outgoing_announce_peer_query(
                qnode.node, self.info_hash, self._bt_port, qnode.token)
            queries_to_send.append(query)
        return queries_to_send, announce_to_myself

    def get_closest_responded_hexids(self):
        """Return `%r` renderings of the closest responded node IDs."""
        return ['%r' % qnode.node.id for
                qnode in self._lookup_queue.get_closest_responded_qnodes()]


class MaintenanceLookup(GetPeersLookup):
    """Lookup towards `target` that sends find_node queries.

    Built by LookupManager.maintenance_lookup. It reuses all of
    GetPeersLookup's machinery but has no lookup_id, no callback and a
    BitTorrent port of 0, so announce() has nothing to announce.
    """

    def __init__(self, msg_f, my_id, target):
        GetPeersLookup.__init__(self, msg_f, my_id,
                                lookup_id=None, info_hash=target,
                                callback_f=None, bt_port=0)
        self._target = target
        # Concurrency knobs, set explicitly for this lookup type.
        self.bootstrap_alpha = self.normal_alpha = self.slowdown_alpha = 4
        self.normal_m = self.slowdown_m = 1
        # Swap the query factory: find_node instead of get_peers.
        self._msg_factory = msg_f.outgoing_find_node_query


class LookupManager(object):
    """Factory for lookup objects sharing one node ID and message factory.

    The bootstrapper is stored on the manager but not used by its factory
    methods; lookups receive it when start() is called.
    """

    def __init__(self, my_id, msg_f, bootstrapper):
        self.my_id, self.msg_f, self.bootstrapper = my_id, msg_f, bootstrapper

    def get_peers(self, lookup_id, info_hash, callback_f, bt_port=0):
        """Return a new, not yet started, GetPeersLookup for `info_hash`."""
        return GetPeersLookup(self.msg_f, self.my_id, lookup_id, info_hash,
                              callback_f, bt_port)

    def maintenance_lookup(self, target=None):
        """Return a new MaintenanceLookup towards `target` (default: my_id)."""
        if not target:
            target = self.my_id
        return MaintenanceLookup(self.msg_f, self.my_id, target)
