# -*- coding: utf-8 -*-
#   WillowNG - Content Filtering Web Proxy 
#   Copyright (C) 2006  Travis Watkins
#
#   This library is free software; you can redistribute it and/or
#   modify it under the terms of the GNU Library General Public
#   License as published by the Free Software Foundation; either
#   version 2 of the License, or (at your option) any later version.
#
#   This library is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
#   Library General Public License for more details.
#
#   You should have received a copy of the GNU Library General Public
#   License along with this library; if not, write to the Free Software
#   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#   Based on Reverend, Copyright (C) 2003 Amir Bakhtiar

import operator
import string, re
import math
import urllib2
import gobject

class BayesData:
    """Dict-like view of one token/count table that exists in several databases.

    ``db_list`` is a sequence of DB-API connections that each contain a table
    called ``name`` (or ``name + '__cache'`` when ``cache_version`` is true)
    with ``token`` and ``count`` columns.  Reads add up the rows of every
    database; writes and commits only ever touch ``db_list[1]`` (the user DB).
    """
    # id of the pending gobject timeout that will commit the user DB, or None
    timer = None

    def __init__(self, db_list, name, cache_version=False):
        """Remember the databases and work out the table name to use."""
        self.db_list = db_list
        #awful hack, will work it out eventually
        if cache_version:
            self.name = name + '__cache'
        else:
            self.name = name

    #don't want to commit right away for performance reasons
    #don't want to make the app have to commit either
    def __requestCommit(self):
        """(Re)start the delayed commit, dropping any commit still pending."""
        if self.timer:
            gobject.source_remove(self.timer)
            self.timer = None
        # NOTE(review): timeout_add takes milliseconds, so this is a 3 ms
        # delay -- confirm that 3 seconds was not what was intended.
        self.timer = gobject.timeout_add(3, self.__doCommit)

    def __doCommit(self):
        """Timeout callback: commit the user DB once and do not repeat."""
        # The source is removed after this callback returns False, so forget
        # its id; otherwise the next __requestCommit would try to remove a
        # source that no longer exists.
        self.timer = None
        #we only commit things to the user DB, not the builtin one
        self.db_list[1].commit()
        return False

    def get(self, key, default=None):
        """Return the count of ``key`` summed over every database.

        A database without a row for ``key`` contributes nothing.  When no
        database has the token, ``int(default)`` is returned, or 0 if
        ``default`` is None (has_key() and __setitem__() rely on that 0).
        """
        total = 0
        found = False
        for db in self.db_list:
            cur = db.cursor()
            cur.execute('SELECT count FROM %s WHERE token = ?' % self.name, (key,))
            row = cur.fetchone()
            if row is not None:
                found = True
                total += row[0]
        if not found and default is not None:
            return int(default)
        return total

    def has_key(self, word):
        """Return True when the summed count of ``word`` is non-zero."""
        return self.get(word) != 0

    def __getitem__(self, key):
        """Same as get(key): unknown tokens yield 0 rather than KeyError."""
        return self.get(key)

    def __setitem__(self, key, val):
        """Store ``val`` as the count of ``key`` in the user DB.

        NOTE(review): ``val`` is written as-is, while get() adds the builtin
        DB's count on top of it -- confirm that this is the intended meaning.
        """
        cur = self.db_list[1].cursor()
        # Decide INSERT vs UPDATE from the user DB alone: a token that only
        # exists in the builtin DB still needs a fresh row here.
        cur.execute('SELECT count FROM %s WHERE token = ?' % self.name, (key,))
        if cur.fetchone() is None:
            cur.execute('INSERT INTO %s (token, count) VALUES(?, ?)' % self.name, (key, val))
        else:
            cur.execute('UPDATE %s SET count = ? WHERE token = ?' % self.name, (val, key))
        self.__requestCommit()

    def items(self):
        """Return (token, count) pairs with counts summed over every database."""
        totals = {}
        for db in self.db_list:
            cur = db.cursor()
            cur.execute('SELECT token, count FROM %s' % self.name)
            for key, value in cur.fetchall():
                if key in totals:
                    totals[key] += value
                else:
                    totals[key] = value
        return totals.items()

    def clear(self):
        """Delete every row of this table in the user DB and schedule a commit."""
        cur = self.db_list[1].cursor()
        cur.execute('DELETE FROM %s' % self.name)
        self.__requestCommit()

    def __repr__(self):
        return '<BayesDict: %s, %s tokens>' % (self.name, self.tokenCount)

    def get_token_count(self):
        """Return the number of token rows, added up over every database."""
        count = 0
        for db in self.db_list:
            cur = db.cursor()
            cur.execute('SELECT COUNT(token) FROM %s' % self.name)
            count += cur.fetchone()[0]
        return count

    tokenCount = property(get_token_count)

class Bayes(object):
    """Token-based 'good' / 'bad' text classifier stored in BayesData tables.

    Training adds token counts to a named pool and to the '__Corpus__' pool.
    Guessing turns those counts into per-token probabilities (cached in
    '<pool>__cache' tables) and combines them with the Robinson-Fisher method.
    """
    # replaced per instance by buildCache(): pool name -> cache BayesData
    cache = {}

    def __init__(self, db):
        """Set up the corpus, 'good' and 'bad' pools.

        ``db`` is handed unchanged to BayesData as its ``db_list``.
        """
        self.db = db
        self.corpus = BayesData(self.db, '__Corpus__')
        self.pools = {}
        self.pools['__Corpus__'] = self.corpus
        self.pools.setdefault('good', BayesData(self.db, 'good'))
        self.pools.setdefault('bad', BayesData(self.db, 'bad'))
        # True whenever the probability cache must be rebuilt before use
        self.dirty = True

    def cleanup(self):
        """Empty the cached probability tables."""
        #clear out caches
        BayesData(self.db, 'bad__cache').clear()
        BayesData(self.db, 'good__cache').clear()

    def buildCache(self):
        """ merges corpora and computes probabilities

        For every pool except '__Corpus__', each corpus word that the pool
        contains gets a probability written into that pool's cache table,
        unless the probability is within 0.1 of 0.5.
        """
        self.cache = {}
        # The corpus table is not written to below, so read it only once
        # instead of once per pool.
        corpusCount = self.corpus.tokenCount
        corpusItems = self.corpus.items()
        for pname, pool in self.pools.items():
            # skip our special pool
            if pname == '__Corpus__':
                continue

            poolCount = pool.tokenCount
            themCount = max(corpusCount - poolCount, 1)
            cacheDict = self.cache.setdefault(pname, BayesData(self.db, pname, True))

            for word, totCount in corpusItems:
                # for every word in the corpus
                # check to see if this pool contains this word
                thisCount = float(pool.get(word, 0.0))
                if thisCount == 0.0:
                    continue
                otherCount = float(totCount) - thisCount

                if not poolCount:
                    goodMetric = 1.0
                else:
                    goodMetric = min(1.0, otherCount/poolCount)
                badMetric = min(1.0, thisCount/themCount)
                f = badMetric / (goodMetric + badMetric)

                # PROBABILITY_THRESHOLD
                if abs(f-0.5) >= 0.1:
                    # GOOD_PROB, BAD_PROB
                    cacheDict[word] = max(0.0001, min(0.9999, f))

    def poolProbs(self):
        """Return the probability cache, rebuilding it first when dirty."""
        if self.dirty:
            self.buildCache()
            self.dirty = False
        return self.cache

    def getTokens(self, data):
        """Split HTML-ish text into lower-cased words longer than 4 characters.

        The text is cut at every '>'; pieces that start with '<' are dropped
        and the others are kept up to their first '<'.  Each space-separated
        word then loses at most one leading and one trailing '.', ',', ';'
        or ':' before the length check.
        """
        punctuation = ('.', ',', ';', ':')
        tokens = []
        for item in data.split('>'):
            if not item or item[0] == '<':
                continue
            # keep only the text in front of the next tag
            item = item.split('<')[0]
            for word in item.strip().split(' '):
                if word and word[0] in punctuation:
                    word = word[1:]
                if word and word[-1] in punctuation:
                    word = word[:-1]
                if len(word) > 4:
                    tokens.append(word.lower())
        return tokens

    def getProbs(self, pool, words):
        """ extracts the probabilities of tokens in a message

        Returns up to 2048 (word, probability) pairs, highest probability
        first; words whose value in ``pool`` is 0 or missing are left out.
        """
        probs = []
        for word in words:
            # one lookup per word; a default of 0 marks unknown words
            prob = pool.get(word, 0)
            if prob != 0:
                probs.append((word, prob))
        # stable descending sort: equal probabilities keep their input order
        probs.sort(key=operator.itemgetter(1), reverse=True)
        return probs[:2048]

    def train(self, pool, url):
        """Fetch ``url`` and count its tokens into ``pool`` and the corpus.

        ``pool`` is a pool name; an unknown name gets a new BayesData.
        Errors raised by urllib2.urlopen are not caught here.
        """
        req = urllib2.urlopen(url)
        try:
            data = req.read()
        finally:
            # always release the connection, even when the read fails
            req.close()
        tokens = self.getTokens(data)
        pool = self.pools.setdefault(pool, BayesData(self.db, pool))
        for token in tokens:
            pool[token] = pool.get(token, 0) + 1
            self.corpus[token] = self.corpus.get(token, 0) + 1
        self.dirty = True

    def isBad(self, msg):
        """Return True when 'bad' is the top-ranked pool for ``msg``, else False."""
        results = self.guess(msg)
        return bool(results) and results[0][0] == 'bad'

    def guess(self, msg):
        """Return a list of (pool name, score) pairs, best score first.

        Pools in which none of the message's tokens have a cached probability
        are omitted, so the list may be empty.
        """
        tokens = self.getTokens(msg)
        res = {}
        for pname, pprobs in self.poolProbs().items():
            p = self.getProbs(pprobs, tokens)
            if len(p) != 0:
                res[pname] = self.robinsonFisher(p, pname)
        return sorted(res.items(), key=operator.itemgetter(1), reverse=True)

    def robinsonFisher(self, probs, ignore):
        """ computes the probability of a message being spam (Robinson-Fisher method)
            See http://www.linuxjournal.com/article/6467

            ``probs`` is a sequence of (word, probability) pairs; ``ignore``
            is not used.  Returns (1 + H - S) / 2.
        """
        n = len(probs)
        # Logs are summed rather than taking the log of the product, so a
        # long run of small factors cannot underflow to 0.0 before the log.
        # A factor <= 0 still raises ValueError and yields 0.0 as before.
        try:
            H = chi2P(-2.0 * sum(math.log(p[1]) for p in probs), 2*n)
        except (OverflowError, ValueError):
            H = 0.0
        try:
            S = chi2P(-2.0 * sum(math.log(1.0-p[1]) for p in probs), 2*n)
        except (OverflowError, ValueError):
            S = 0.0
        return (1 + H - S) / 2

def chi2P(chi, df):
    """ return P(chisq >= chi, with df degree of freedom)

    df must be even

    The result is the series exp(-m) * (1 + m + m**2/2! + ...) with
    m = chi / 2 and df/2 terms, capped at 1.0.
    """
    assert df & 1 == 0
    m = chi / 2.0
    # running sum and current term of the series; the first term is exp(-m)
    total = term = math.exp(-m)
    # floor division keeps the bound an int under true division as well
    for i in range(1, df // 2):
        term *= m / i
        total += term
    # rounding can push the sum just above 1.0
    return min(total, 1.0)

