#!/usr/bin/python
########################################################################
### FILE:	greylistd.py
### PURPOSE:	Simple greylisting daemon.  See "greylistd(8)".
###		For an introduction to greylisting, see:
### 		http://projects.puremagic.com/greylisting/
###
### 		This program listens for connections on a UNIX domain
###		socket, presumably from an MTA such as Exim.  Nominally, 
###		it reads an identifier (referred to as a "triplet"),
###		and returns a single word ("white" or "grey") depending
###             on prior knowledge of said identifier.
###
### Copyright (C) 2004, Tor Slettnes <tor@slett.net>
###
### This program is free software; you can redistribute it and/or modify
### it under the terms of the GNU General Public License as published by
### the Free Software Foundation; either version 2 of the License, or
### (at your option) any later version.
###
### This program is distributed in the hope that it will be useful,
### but WITHOUT ANY WARRANTY; without even the implied warranty of
### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
### GNU General Public License for more details.
###
### On Debian GNU/Linux systems, the complete text of the GNU General
### Public License can be found in `/usr/share/common-licenses/GPL'.
### It is also available at: http://www.gnu.org/licenses/gpl.html
########################################################################

from time         import time, ctime, localtime, strftime
from socket       import socket, AF_UNIX, SOCK_STREAM, error as SocketError
from os           import remove, rename, chmod, chown, getuid, getpid, isatty
from pwd          import getpwnam
from grp          import getgrnam
from os.path      import join, exists
from sys          import version as PyVersion, stdout, stderr, exit, exc_info
from signal       import signal, SIGTERM, SIGHUP, SIGUSR1, SIG_IGN, SIG_DFL
from syslog       import openlog, syslog, LOG_NOTICE, LOG_WARNING, LOG_ERR
from select       import select, error as SelectError
from inspect      import getargspec
from errno        import EINTR
from stat         import S_IRUSR, S_IWUSR


### Ensure that we can run this program.  Compare the numeric version
### tuple: comparing the sys.version string is lexicographic, so a
### release such as "2.10" would wrongly sort before "2.3".
from sys import version_info

if version_info < (2, 3):
    stderr.write("This program requires Python 2.3 or newer\n")
    exit(1)


### Configuration file: the [data] section and its items
DATA           = "data"
STATEFILE      = "statefile"
TRIPLETFILE    = "tripletfile"
SAVETRIPLETS   = "savetriplets"
UPDATEINTERVAL = "update"

### Configuration file: the [socket] section and its items
SOCKET         = "socket"
SOCKPATH       = "path"
SOCKOWNER      = "owner"
SOCKMODE       = "mode"

### Configuration file: the [timeouts] section and its items
TIMEOUTS       = "timeouts"
RETRYMIN       = "retryMin"
RETRYMAX       = "retryMax"
EXPIRE         = "expire"


### Location of the configuration file
conffile = "/etc/greylistd/config"

### Built-in defaults.  Settings read from `conffile' replace these, and
### the type of each default value decides how the corresponding setting
### is converted when it is read (see typeConvert()).
config = {}

config[DATA] = {
    STATEFILE      : "/var/lib/greylistd/states",
    TRIPLETFILE    : "/var/lib/greylistd/triplets",
    SAVETRIPLETS   : True,
    UPDATEINTERVAL : 300,               # seconds between periodic saves
}

config[SOCKET] = {
    SOCKPATH       : "/var/run/greylistd/socket",
    SOCKOWNER      : "greylist:greylist",
    SOCKMODE       : "0660",            # permission bits, as octal text
}

config[TIMEOUTS] = {                    # all values in seconds
    RETRYMIN       : 60 * 60,           # 1 hour
    RETRYMAX       : 60 * 60 * 8,       # 8 hours
    EXPIRE         : 60 * 60 * 24 * 60, # 60 days
}


### Names of the lists ("states") a triplet can be on
WHITE = "white"
GREY  = "grey"
BLACK = "black"

### Extra section of the state file, and its items
STATS    = "statistics"
START    = "start"
LASTSAVE = "lastsave"

### In-memory greylist data: one dictionary per list, keyed by triplet
### hash, plus a section of counters and timestamps.
data = {
    WHITE : {},
    GREY  : {},
    BLACK : {},
    STATS : {
        WHITE    : 0,
        GREY     : 0,
        BLACK    : 0,
        START    : 0,
        LASTSAVE : 0,
    },
}

### Conversion applied to each (key, value) pair of a list section when
### the state file is read: an integer key and a value of three integers.
### (Tuples are immutable, so sharing one descriptor between lists is safe.)
datatypes = dict.fromkeys((WHITE, GREY, BLACK), (int, (int, int, int)))

### Positions within a list entry: last seen, first seen, hit count
IDX_LAST, IDX_FIRST, IDX_COUNT = 0, 1, 2

### Original (unhashed) triplet text by hash, not yet written to disk
newTriplets = dict()


class RunException (Exception):
    """A daemon operation failed.  runCommand() reports the message to
    the client and also logs it."""


class CommandError (RunException):
    """The client's request was invalid.  runCommand() reports the
    message to the client without logging it."""


def expireKeys (now):
    """Remove list entries that have not been seen for too long.

    Greylist entries are dropped `retryMax' seconds after they were last
    seen; white- and blacklist entries after `expire' seconds.  `now' is
    the current time in whole seconds, as produced by int(time()).
    """
    lifetimes = ((GREY, RETRYMAX), (WHITE, EXPIRE), (BLACK, EXPIRE))

    for listKey, timeoutKey in lifetimes:
        entries  = data[listKey]
        lifetime = config[TIMEOUTS][timeoutKey]

        # Collect first and delete afterwards, so the dictionary is never
        # modified while it is being iterated.
        stale = [k for (k, v) in entries.items()
                 if v[IDX_LAST] + lifetime < now]
        for k in stale:
            del entries[k]


def listStatus (searchkey):
    """Return the name of the list holding `searchkey', or None if it is
    on no list."""
    holders = [name for name in datatypes if searchkey in data[name]]
    if holders:
        return holders[0]
    return None



def log (message, priority=LOG_NOTICE):
    """Report `message': write it to stderr when stderr is a terminal,
    otherwise send it to syslog at `priority'."""
    toTerminal = isatty(stderr.fileno())
    if not toTerminal:
        syslog(priority, message)
        return

    stderr.write("%s\n" % (message,))


def duration (secs):
    """Return a human-readable rendering of a span of `secs' seconds.

    - Under a minute: seconds only.
    - Under an hour: minutes and seconds.
    - Under a day (after rounding to the nearest minute): hours and
      minutes.
    - Otherwise: days and hours, rounded to the nearest hour.

    A zero second component is omitted, e.g. duration(125) ==
    "2 minutes and 5 seconds" and duration(3600) == "1 hour".
    """
    def amount (count, unit):
        # Only a count of exactly one is singular ("0 seconds", "1 hour").
        if count == 1:
            return "%d %s" % (count, unit)
        return "%d %ss" % (count, unit)

    def pair (major, majorUnit, minor, minorUnit):
        # "<major> and <minor>", dropping the minor part when it is zero.
        text = amount(major, majorUnit)
        if minor:
            text = "%s and %s" % (text, amount(minor, minorUnit))
        return text

    if secs < 60:
        return amount(secs, "second")

    if secs < 60 * 60:
        mins, secs = divmod(secs, 60)
        return pair(mins, "minute", secs, "second")

    if secs + 30 < 60 * 60 * 24:
        mins = (secs + 30) // 60            # rounded to the nearest minute
        return pair(mins // 60, "hour", mins % 60, "minute")

    hrs = (secs + 1800) // 3600             # rounded to the nearest hour
    return pair(hrs // 24, "day", hrs % 24, "hour")



def typeConvert(typeobject, string):
    """Convert `string' to the type described by `typeobject'.

    `typeobject' may be:
      - a tuple or list of descriptors: `string' is split on whitespace
        into that many words (the last word takes the remainder), each
        word is converted with the matching descriptor, and a list of the
        results is returned.  Too few words raise IndexError.
      - a value that is not a type: the type of that value is used.
      - bool: an integer (non-zero means true) or, case-insensitively,
        yes/true/on or no/false/off; anything else raises ValueError.
      - any other type: called with `string' as its only argument.
    """
    if type(typeobject) in (tuple, list):
        words     = string.split(None, len(typeobject) - 1)
        converted = []
        for position in range(len(typeobject)):
            converted.append(typeConvert(typeobject[position],
                                         words[position]))
        return converted

    if type(typeobject) is not type:
        return typeConvert(type(typeobject), string)

    if typeobject is not bool:
        return typeobject(string)

    try:
        return int(string) != 0
    except ValueError:
        pass

    word = string.lower()
    if word in ("yes", "true", "on"):
        return True
    if word in ("no", "false", "off"):
        return False
    raise ValueError("Not a valid boolean: '%s'" % string)




def loadFromFile (datafile, dictionary, typelist=None):
    """Merge the INI-style file `datafile' into `dictionary'.

    A "[section]" header selects dictionary[section] (name lowercased),
    which must already exist.  Each following "key = value" line is
    stored as dictionary[section][key] after conversion via typeConvert():

      - if `typelist' has an entry for the section, it is a
        (keytype, valuetype) pair applied to every line of that section;
      - otherwise the key must already exist in the section; it stays a
        string and the value is converted to the type of its current value.

    Lines starting with "#" are ignored.  Unknown sections or keys,
    unconvertible values and malformed lines are logged and skipped.

    Raises RunException if the file cannot be read.
    """
    logPfx = 'In %s:' % datafile

    try:
        fp = open(datafile)
        try:
            section = None

            for line in fp:
                line = line.strip()

                if line.startswith("#"):
                    continue

                elif (line[0:1] == '[') and (']' in line):
                    section = line[1:line.find(']')].strip().lower()

                    if section not in dictionary:
                        log("%s Invalid or obsolete section: [%s]" %
                            (logPfx, section))
                        section = None

                elif section and ('=' in line):
                    # Named `value' (not `data') to avoid shadowing the
                    # module-level greylist data.
                    key, value = [s.strip() for s in line.split('=', 1)]

                    if typelist and (section in typelist):
                        keytype, valuetype = typelist[section]

                    elif key in dictionary[section]:
                        keytype   = str
                        valuetype = dictionary[section][key]

                    else:
                        log("%s Invalid or obsolete key: [%s] %s" %
                            (logPfx, section, key))
                        continue

                    try:
                        key = typeConvert(keytype, key)
                        dictionary[section][key] = typeConvert(valuetype,
                                                               value)

                    except ValueError:
                        log("%s Invalid value for [%s] %s: '%s'" %
                            (logPfx, section, key, value))

                    except IndexError:
                        # typeConvert() ran out of words for a multi-value
                        # type; report how many were given vs. expected.
                        if type(valuetype) in (tuple, list):
                            expected = len(valuetype)
                        else:
                            expected = 1
                        log("%s Too few values for [%s] %s (%d, should be %d)" %
                            (logPfx, section, key, len(value.split()),
                             expected))

                elif line:
                    log("%s Invalid line: '%s'" % (logPfx, line))

        finally:
            # Release the file even if parsing fails part-way.
            fp.close()

    except IOError:
        raise RunException("Cannot read from '%s': %s" %
                           (datafile, exc_info()[1].strerror))



def saveToFile (datafile, dictionary, perm=S_IRUSR | S_IWUSR):
    """Write `dictionary' to `datafile' in the format read by loadFromFile().

    Each top-level key becomes a "[section]" header, followed by one
    "key = value" line per entry of its sub-dictionary; list and tuple
    values are written as space-separated words.  The file is truncated
    and its mode set to `perm' (default 0600: owner read/write only)
    before anything is written to it.

    Raises RunException if the file cannot be written or its mode set.
    """
    try:
        fp = open(datafile, 'w')
    except IOError:
        raise RunException("Cannot write to %s: %s" %
                           (datafile, exc_info()[1].strerror))

    try:
        try:
            chmod(datafile, perm)
        except OSError:
            raise RunException("Cannot set mode 0%o on %s: %s" %
                               (perm, datafile, exc_info()[1].strerror))

        try:
            for (section, subdict) in dictionary.items():
                fp.write("[%s]\n" % section)

                for (key, value) in subdict.items():
                    if type(value) in (list, tuple):
                        value = " ".join(map(str, value))
                    fp.write("%s = %s\n" % (key, value))

                fp.write("\n")

            fp.close()      # flushes, so it can fail too (e.g. disk full)

        except IOError:
            raise RunException("Cannot write to %s: %s" %
                               (datafile, exc_info()[1].strerror))

    finally:
        # Already closed on success.  On failure, release the descriptor
        # without letting a second error mask the one being raised.
        try:
            fp.close()
        except IOError:
            pass



def loadConfigAndData ():
    """Read the configuration file, then the state file it names.

    If the configuration file cannot be read, that is logged and the
    current settings are kept.  If the state file cannot be read, a new
    statistics period starts now.  Afterwards, stale entries are expired
    and the load time is recorded as the last save.
    """
    now = int(time())

    try:
        loadFromFile(conffile, config)
    except RunException:
        log(str(exc_info()[1]))

    statefile = config[DATA][STATEFILE]
    try:
        loadFromFile(statefile, data, datatypes)
    except RunException:
        data[STATS][START] = now

    expireKeys(now)
    data[STATS][LASTSAVE] = now
    


def saveData (datafile):
    """Write the in-memory greylist data to `datafile'.

    The data is written to "<datafile>.<pid>" first and then renamed over
    `datafile', so `datafile' itself is never left half-written.  On
    failure, the temporary file is removed (best effort) and the
    RunException is re-raised.
    """
    tempname = "%s.%s" % (datafile, getpid())

    try:
        saveToFile(tempname, data)

        try:
            rename(tempname, datafile)
        except OSError:
            raise RunException("Cannot rename %s to %s: %s" %
                               (tempname, datafile, exc_info()[1].strerror))

    except RunException:
        failure = exc_info()[1]

        # Don't leave stale temporary files behind; failing to remove one
        # must not mask the original error.
        if exists(tempname):
            try:
                remove(tempname)
            except OSError:
                pass

        raise failure



def _copyTriplets (infile, outfile):
    """Write to `outfile' the triplet lines that should survive a sync.

    A line of `infile' (which may be None) is kept if its hash is still
    on some list and has no replacement in `newTriplets'; malformed lines
    are dropped.  Then every entry of `newTriplets' whose hash is still
    listed is appended as "<hash> = <triplet>".
    """
    if infile:
        for line in infile:
            try:
                (key, value) = line.split(" = ", 1)
                key          = int(key)
            except ValueError:
                continue

            if listStatus(key) and key not in newTriplets:
                outfile.write(line)

    for (key, triplet) in newTriplets.items():
        if listStatus(key):
            outfile.write("%d = %s\n" % (key, triplet))


def syncTriplets (datafile, perm=S_IRUSR | S_IWUSR):
    """Merge the pending unhashed triplets (`newTriplets') into `datafile'.

    `datafile' holds "<hash> = <triplet>" lines.  The merged contents are
    written to "<datafile>.<pid>" with mode `perm' (default 0600) and then
    renamed over `datafile'.  `newTriplets' is emptied only after the
    rename succeeds, so a failed sync is retried on the next save.

    Raises RunException if an existing `datafile' cannot be read, or if
    the temporary file cannot be written, have its mode set, or be renamed.
    """
    source = datafile
    target = "%s.%s" % (source, getpid())

    try:
        infile = open(source, "r")
    except IOError:
        # A missing triplet file just means there is nothing to merge yet.
        # An existing one that cannot be read must not be replaced, since
        # that would silently discard its contents.
        if exists(source):
            raise RunException("Cannot read from '%s': %s" %
                               (source, exc_info()[1].strerror))
        infile = None

    outfile = None
    try:
        try:
            outfile = open(target, "w")
        except IOError:
            raise RunException("Could not write to %s: %s" %
                               (target, exc_info()[1].strerror))

        # Restrict access before any triplet is written.
        try:
            chmod(target, perm)
        except OSError:
            raise RunException("Cannot set mode 0%o on %s: %s" %
                               (perm, target, exc_info()[1].strerror))

        try:
            _copyTriplets(infile, outfile)
            outfile.close()     # flushes, so it can fail too
        except IOError:
            raise RunException("Could not write to %s: %s" %
                               (target, exc_info()[1].strerror))

    finally:
        if infile:
            infile.close()
        if outfile:
            # Already closed on success.  On failure, release the
            # descriptor without masking the error being raised.
            try:
                outfile.close()
            except IOError:
                pass

    try:
        rename(target, source)
    except OSError:
        raise RunException("Could not rename %s to %s: %s" %
                           (target, source, exc_info()[1].strerror))

    # Forget the pending triplets only once they are safely on disk.
    newTriplets.clear()



def listTriplets (fp, socket, options):
    """Send the retained triplets of the requested lists to a client.

    fp      -- the open triplet file ("<hash> = <triplet>" lines); it is
               re-read from the start for each list reported
    socket  -- the client connection the report is sent to
    options -- names of the lists to report; all lists if empty

    Raises CommandError for a name that is not a list.  Malformed lines
    in `fp' are logged during the first list reported and skipped.
    """
    parseErrors = False
    firstPass   = True
    listformat  = "  %-20s %5s  %s\n"

    for listkey in (options or datatypes):
        # Check against the lists themselves: `data' also contains the
        # statistics section, which holds counters rather than entries.
        if listkey not in datatypes:
            raise CommandError("Invalid list: %s" % listkey)

        listdata = data[listkey]
        if not listdata:
            continue

        # Write to the connection we were given (not a global), using
        # sendall(), which unlike send() does not stop after a partial
        # write.
        title = "%slist data:" % listkey.capitalize()
        socket.sendall("\n%s\n%s\n" % (title, "=" * len(title)))
        socket.sendall(listformat % ("Last Seen", "Count", "Data"))

        fp.seek(0)
        for line in fp:
            try:
                (key, value) = line.split(" = ", 1)
                key          = int(key)

                if key in listdata:
                    (last, first, num) = listdata[key]
                    ldate = strftime("%Y-%m-%d %H:%M:%S", localtime(last))
                    socket.sendall(listformat % (ldate, num, value.strip()))

            except ValueError:
                if not parseErrors:
                    log("While reading triplets from %s:" %
                        (config[DATA][TRIPLETFILE],))
                    parseErrors = True

                if firstPass:
                    log("Invalid line: '%s'" % line)

        firstPass = False



def do_add (options, key):
    """Command "add [-<list>] <triplet>": put a triplet on a list (the
    whitelist unless another list is given as an option).

    Re-adding to the list it is already on refreshes its last-seen time
    and increments its count.  Otherwise the triplet is removed from any
    other list, starts a fresh entry with a count of 1, and the target
    list's counter in the statistics is incremented.  Raises CommandError
    for an unknown list.
    """
    state = WHITE
    if options:
        state = options[0]
        if state not in datatypes:
            raise CommandError("No such list: '%s'" % state)

    now      = int(time())
    oldstate = listStatus(key)

    if state != oldstate:
        data[STATS][state] += 1
        if oldstate:
            del data[oldstate][key]
        firstseen, count = now, 0
    else:
        lastseen, firstseen, count = data[state][key]

    data[state][key] = (now, firstseen, count + 1)
    return "Added to %slist" % state



def do_delete (options, key):
    """Command "delete <triplet>": remove the triplet from whichever list
    holds it.  Raises CommandError if it is on no list."""
    listkey = listStatus(key)
    if not listkey:
        raise CommandError("Not found")

    del data[listkey][key]
    return "Removed from %slist" % listkey


def do_check (options, key, update=False):
    """Command "check [-<list>] <triplet>": report the triplet's state.

    Expired entries are purged first.  An unknown triplet is "grey"; a
    greylisted one whose first sighting is more than `retryMin' seconds
    ago becomes "white".  If `update' is true, the result is recorded
    with do_add().

    Without an option, the state name is returned.  With a list name as
    option, "true" or "false" says whether the state equals that list.
    Raises CommandError if the option is not a list name.
    """
    truthtest = None
    if options:
        truthtest = options[0]
        if truthtest not in datatypes:
            raise CommandError("'%s' is not a known state" % truthtest)

    now = int(time())
    expireKeys(now)

    state = listStatus(key)
    if state is None:
        state = GREY
    elif state == GREY:
        retryAfter = data[GREY][key][IDX_FIRST] + config[TIMEOUTS][RETRYMIN]
        if retryAfter < now:
            state = WHITE

    if update:
        do_add([state], key)

    if not truthtest:
        return state
    if state == truthtest:
        return "true"
    return "false"



def do_update (options, key):
    """Command "update [-<list>] <triplet>", also used for requests that
    do not start with a command word: like "check", but the resulting
    state is also recorded."""
    record = True
    return do_check(options, key, record)



def do_stats ():
    """Command "stats": return a multi-line report of entries and hits
    per list, and of what became of items that were greylisted.
    Expired entries are purged first so they are not counted."""
    now       = int(time())
    stats     = data[STATS]
    starttime = stats.get(START, None)
    expireKeys(now)

    if starttime:
        title = "Statistics since %s (%s ago)" % (
            ctime(starttime), duration(now - starttime))
    else:
        title = "Statistics"

    text = [title, "-" * len(title)]

    items = {}
    hits  = {}
    for listkey in datatypes:
        entries        = data[listkey]
        items[listkey] = len(entries)
        hits[listkey]  = sum([entry[IDX_COUNT] for entry in entries.values()])

    # Column widths are the same for every row, so compute them once.
    hitdigits  = len(str(max(hits.values())))
    itemdigits = len(str(max(items.values())))

    for listkey in datatypes:
        text.append("%s items, matching %s requests, are currently %slisted" %
                    (str(items[listkey]).rjust(itemdigits),
                     str(hits[listkey]).rjust(hitdigits),
                     listkey))

    previousGrey = stats[GREY] - len(data[GREY])
    expiredGrey  = previousGrey - stats[WHITE]

    # Only report when some items have actually left the greylist; a
    # negative count (e.g. from inconsistent saved statistics) would
    # produce a meaningless report.
    if previousGrey > 0:
        digits = len(str(previousGrey))

        text.append("")
        text.append("Of %s items that were initially greylisted:" %
                    str(previousGrey).rjust(digits))

        text.append(" - %s (%5.1f%%) became whitelisted" %
                    (str(stats[WHITE]).rjust(digits),
                     100.0 * stats[WHITE] / previousGrey))

        text.append(" - %s (%5.1f%%) expired from the greylist" %
                    (str(expiredGrey).rjust(digits),
                     100.0 * expiredGrey / previousGrey))

    text.append('')
    return "\n".join(text)



def do_list (options, socket):
    """Command "list [-<list> ...]": save pending data, then send the
    retained triplets of the given lists (all lists by default) to
    `socket'.

    Raises CommandError if original triplet text is not being retained
    or the triplet file cannot be read.
    """
    tripletfile = config[DATA][TRIPLETFILE]
    if not tripletfile or not config[DATA][SAVETRIPLETS]:
        raise CommandError("Original triplet data is not retained.")

    # Flush pending triplets so that the file is complete.
    do_save()

    try:
        infile = open(tripletfile, "r")
        try:
            listTriplets(infile, socket, options)
        finally:
            # Close the file even when listTriplets() raises.
            infile.close()

    except IOError:
        raise CommandError("Cannot read from '%s': %s\n" %
                           (tripletfile, exc_info()[1].strerror))
    

def do_clear (options):
    """Command "clear [-<list> ...]": empty the given lists (all lists by
    default) and reset their statistics counters.  Clearing all lists
    also restarts the statistics period.

    Every name is validated before anything is cleared, so an invalid
    name raises CommandError and leaves the data untouched.
    """
    listkeys = options or datatypes

    # Only real lists may be cleared.  `data' also holds the statistics
    # section, and clearing that would delete counters that other
    # commands and the main loop index unconditionally.
    for listkey in listkeys:
        if listkey not in datatypes:
            raise CommandError("Invalid list: '%s'" % listkey)

    for listkey in listkeys:
        data[listkey].clear()
        data[STATS][listkey] = 0

    if not options:
        data[STATS][START] = int(time())

    return "data and statistics cleared"


def do_reload ():
    """Command "reload" (also run on SIGHUP): save current data, then
    re-read the configuration file and the state file."""
    do_save()               # persist first so the reload loses nothing
    loadConfigAndData()

    message = "configuration and data reloaded"
    return message


def do_save ():
    """Command "save": expire stale entries, write the state file and
    any pending unhashed triplets, then record the save time.  Errors
    from writing propagate as RunException."""
    timestamp = int(time())
    expireKeys(timestamp)

    # Hashed entries, timestamps and counters
    saveData(config[DATA][STATEFILE])

    # Original triplet text, only when there is something new
    if len(newTriplets) > 0:
        syncTriplets(config[DATA][TRIPLETFILE])

    data[STATS][LASTSAVE] = timestamp
    return "greylistd data has been saved"



def nodata ():
    """Handler for an empty request: always fails with CommandError."""
    message = "No data received"
    raise CommandError(message)


def runCommand (line, client):
    """Parse and execute one client request, returning the reply text.

    The request is lowercased and split into words.
    - An empty request is handled by nodata().
    - If the first word names a "do_<word>" function, that function
      handles the request.
    - Anything else is taken to be a triplet for do_update().

    Leading words that start with "-" become options (dashes stripped).
    The handler's arguments are chosen by its parameter names:

      options -- the list of options
      key     -- hash of the remaining words (the triplet)
      args    -- the remaining words themselves
      socket  -- the client connection, `client'

    Any other parameter ends argument binding, so it keeps its default
    (e.g. do_check's `update').  A RunException raised while handling
    the request is returned as an "error: ..." reply; anything but a
    CommandError is also logged.
    """
    words   = line.lower().split()
    options = []

    if not words:
        function = nodata
    elif "do_" + words[0] in globals():
        function = globals()["do_" + words.pop(0)]
    else:
        function = do_update

    args = getargspec(function)[0]

    while words and words[0].startswith("-"):
        options.append(words.pop(0).lstrip("-"))

    arglist = []
    key     = None
    useargs = False

    for arg in args:
        if arg == "options":
            arglist.append(options)

        elif arg == "key":
            # NOTE(review): the built-in hash() of a string is not stable
            # across interpreter builds (e.g. 32- vs 64-bit) or, on
            # Python 3, across processes, yet these keys are persisted in
            # the state file.
            key = hash(" ".join(words))
            useargs = True
            arglist.append(key)

        elif arg == "args":
            useargs = True
            arglist.append(words)

        elif arg == "socket":
            arglist.append(client)

        else:
            break

    try:
        if useargs and not words:
            raise CommandError("Missing argument")

        if words and not useargs:
            raise CommandError("Too many arguments")

        # Compare with None: 0 is a perfectly valid hash value.
        if (key is not None and config[DATA][TRIPLETFILE]
                and config[DATA][SAVETRIPLETS]):
            newTriplets[key] = " ".join(words)      # already lowercased

        return function(*arglist) or ""

    except RunException:
        error = exc_info()[1]
        if not isinstance(error, CommandError):
            log(str(error))

        return "error: %s" % str(error).replace("\n", "\n       ")



def createSocket (path, owner, mode):
    """Create a listening UNIX-domain stream socket at `path'.

    owner -- "user" or "user:group" to own the socket file.  Applied only
             when running as root; without a group, the user's group from
             the password database is used.
    mode  -- permission bits for the socket file as an octal string,
             e.g. "0660"

    Returns the listening socket.  Raises RunException if the owner or
    mode is invalid or cannot be applied.  Socket errors from bind() or
    listen() propagate unchanged.  On any error the socket is closed; the
    socket file is not removed here.
    """
    sock  = socket(AF_UNIX, SOCK_STREAM)
    ready = False

    try:
        sock.bind(path)

        if getuid() == 0:
            if ":" in owner:
                user, group = owner.split(":", 1)
            else:
                user, group = owner, None

            try:
                account  = getpwnam(user)
                uid, gid = account.pw_uid, account.pw_gid
                if group:
                    gid = getgrnam(group).gr_gid
            except KeyError:
                raise RunException(
                    "Invalid owner specified in configuration file: %s: %s" %
                    (owner, exc_info()[1]))

            try:
                chown(path, uid, gid)
            except OSError:
                raise RunException(
                    "Could not change ownership of socket %s: %s" %
                    (path, exc_info()[1].strerror))

        # Parse the mode first so the chmod error below can format it as
        # a number (the configured value is a string).
        try:
            perm = int(mode, 8)
        except ValueError:
            raise RunException(
                "Specified socket mode '%s' is not a valid octal number" % mode)

        try:
            chmod(path, perm)
        except OSError:
            raise RunException("Could not set mode 0%o on socket %s: %s" %
                               (perm, path, exc_info()[1].strerror))

        sock.listen(5)
        ready = True

    finally:
        if not ready:
            sock.close()

    return sock



def startup ():
    """Prepare the daemon before the main loop runs.

    Installs the SIGTERM/SIGHUP handlers, opens syslog, loads the
    configuration and data, and creates the listening socket.  The
    socket is stored in the global `listener', and `sockets' is set to
    a list holding it.  Exits with status -1 if the socket cannot be
    set up.
    """
    global listener, sockets
    listener = None
    sockets  = []

    for (signum, handler) in ((SIGTERM, term), (SIGHUP, hangup)):
        signal(signum, handler)

    openlog("greylistd")
    loadConfigAndData()

    socketconf = config[SOCKET]
    try:
        listener = createSocket(**socketconf)

    except SocketError:
        problem = exc_info()[1]
        log("Could not bind/listen to socket %s: %s" %
            (socketconf[SOCKPATH], str(problem)))
        exit(-1)

    except RunException:
        log(str(exc_info()[1]))
        cleanup(False)
        exit(-1)

    sockets = [listener]




def cleanup (save=True):
    """Shut-down housekeeping.

    Removes the socket file and, if `save' is true, saves all data.
    Errors are logged rather than raised: this runs while the daemon is
    going down, where an exception would only turn an orderly shutdown
    into a crash report.
    """
    sockpath = config[SOCKET][SOCKPATH]

    if exists(sockpath):
        try:
            remove(sockpath)
        except OSError:
            log("Could not remove socket %s: %s" %
                (sockpath, exc_info()[1].strerror), LOG_WARNING)

    if save:
        try:
            do_save()
        except RunException:
            log(str(exc_info()[1]), LOG_ERR)


def term (signum=None, frame=None):
    """SIGTERM handler: remove the socket file, save the data and exit
    with status 0."""
    cleanup(save=True)
    exit(0)


def hangup (signum=None, frame=None):
    """SIGHUP handler: save the data, then reload configuration and data.

    This runs asynchronously with respect to the main loop.  A failure
    is therefore logged here instead of being raised into the loop,
    where it would be treated as fatal and stop the daemon.
    """
    try:
        do_reload()
    except RunException:
        log("Reload failed: %s" % (exc_info()[1],), LOG_ERR)


startup()

try:
    while sockets:
        interval = config[DATA][UPDATEINTERVAL]
        lastsave = data[STATS][LASTSAVE]

        # Once a save is due, only poll (timeout 0) so that an idle daemon
        # still saves; otherwise block until there is activity.
        try:
            if interval and (lastsave + interval < time()):
                (inlist, outlist, errlist) = select(sockets, [], [], 0)
            else:
                (inlist, outlist, errlist) = select(sockets, [], [])

        except SelectError:
            # A signal (e.g. SIGHUP) interrupted select(); its handler has
            # already run, so simply wait again instead of dying.
            reason = exc_info()[1]
            if reason.args and reason.args[0] == EINTR:
                continue
            raise

        if not inlist:
            try:
                do_save()
            except RunException:
                log(str(exc_info()[1]))
                # Retry after a full interval instead of immediately;
                # otherwise a persistent failure spins this loop and
                # floods the log.
                data[STATS][LASTSAVE] = int(time())

        elif inlist[0] is listener:
            (client, addr) = listener.accept()
            sockets.append(client)

        else:
            client = inlist[0]

            try:
                line  = client.recv(16384)
                reply = runCommand(line, client)
                # sendall() keeps writing until the whole reply is sent.
                client.sendall(reply)

            except SocketError:
                log("Socket error: %s" % (exc_info()[1],))

            client.close()
            sockets.remove(client)


except SystemExit:
    pass

except KeyboardInterrupt:
    cleanup()

except Exception:
    # Report where the fatal exception was raised, from the innermost
    # traceback entry.  The names excType/excValue avoid shadowing the
    # built-in type() at module level.
    (excType, excValue, tb) = exc_info()
    while tb.tb_next:
        tb = tb.tb_next

    filename = tb.tb_frame.f_code.co_filename
    line     = tb.tb_lineno

    log("### Fatal event in %s, line %d:" % (filename, line), LOG_ERR)
    log(">>> %s" % excValue, LOG_ERR)

    cleanup(save=False)
