#!/usr/bin/python -tt
# encoding=UTF-8

# Copyright © 2011, 2012, 2013 Jakub Wilk
#
# This program is free software.  It is distributed under the terms of the GNU
# General Public License as published by the Free Software Foundation; either
# version 2 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
# You should have received a copy of the GNU General Public License along with
# this program.  If not, you can find it on the World Wide Web at
# http://www.gnu.org/copyleft/gpl.html, or write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA.

import ast
import itertools
import os
import re
import sys
import warnings

# Module-level lookup tables.  They start out empty and are filled in place
# by load_data().

# Entries of the 'python-exceptions' data file; used by the
# 'except-shadows-builtin' check.
builtin_exception_types = set()
# Entries of the 'python-imaging-modules' data file; used by the PIL import
# checks.
pil_modules = set()
# Maps an errno number (int) to the name read next to it in the
# 'errno-constants' data file.
errno_constants = {}
# Package names for embedded code copies, in the same order as the
# parenthesised alternatives of code_copies_regex.
code_copies = []
# Compiled alternation of all embedded-code-copy patterns; stays None when
# the data file provides no patterns.
code_copies_regex = None

def load_data_file(ident):
    """Yield the meaningful lines of the data file called *ident*.

    The file is looked up as ``vendors/debian/python/data/<ident>`` below
    each directory of the colon-separated ``LINTIAN_INCLUDE_DIRS``
    environment variable; the first existing candidate is used and, when
    none exists, the last candidate is opened anyway (so the resulting error
    names that path).  When the variable is unset, the directory two levels
    above this file is used instead and a RuntimeWarning is issued.

    Lines are stripped of surrounding whitespace; blank lines and lines
    starting with ``#`` are skipped.
    """
    include_dirs = os.environ.get('LINTIAN_INCLUDE_DIRS')
    if include_dirs is None:
        here = os.path.dirname(__file__)
        include_dirs = os.path.abspath(os.path.join(here, os.pardir, os.pardir))
        warnings.warn(
            'setting LINTIAN_INCLUDE_DIRS={root}'.format(root=include_dirs),
            category=RuntimeWarning,
            stacklevel=3,
        )
    candidates = [
        '{root}/vendors/debian/python/data/{ident}'.format(root=root, ident=ident)
        for root in include_dirs.split(':')
    ]
    # str.split() always returns at least one element, so candidates[-1] exists.
    existing = (candidate for candidate in candidates if os.path.exists(candidate))
    path = next(existing, candidates[-1])
    with open(path) as data_file:
        for raw_line in data_file:
            entry = raw_line.strip()
            if entry and not entry.startswith('#'):
                yield entry

def load_data():
    """Fill the module-level lookup tables from the data files.

    Updates builtin_exception_types, pil_modules, errno_constants and
    code_copies in place, and rebinds code_copies_regex when at least one
    embedded-code-copy pattern was read.
    """
    global code_copies_regex
    builtin_exception_types.update(load_data_file('python-exceptions'))
    pil_modules.update(load_data_file('python-imaging-modules'))
    # Each line holds two whitespace-separated fields: the number, then the name.
    for entry in load_data_file('errno-constants'):
        number, name = entry.split()
        errno_constants[int(number)] = name
    # Each line has the form "<package> || <regex>"; the data file is chosen
    # by the major version of the running interpreter.
    ident = 'python{n}-embedded-code-copies'.format(n=sys.version_info[0])
    patterns = []
    for entry in load_data_file(ident):
        package, pattern = map(str.strip, entry.split('||'))
        patterns.append(pattern)
        code_copies.append(package)
    if patterns:
        # One parenthesised alternative per package, in the order of code_copies.
        combined = '|'.join('(%s)' % pattern for pattern in patterns)
        code_copies_regex = re.compile(combined, re.DOTALL)

# python_open(filename) opens a Python source file for reading as text.
if sys.version_info >= (3, 2):
    # tokenize.open() is available from Python 3.2 on.
    from tokenize import open as python_open
elif sys.version_info[0] >= 3:
    # Python 3.0/3.1: ask py_compile for the encoding, defaulting to UTF-8.
    import py_compile
    def python_open(filename, read_encoding=py_compile.read_encoding):
        return open(filename, 'rU', encoding=read_encoding(filename, 'utf-8'))
    del py_compile
else:
    # Python 2: open in universal-newline mode without decoding.
    def python_open(filename):
        return open(filename, 'rU')

def tag(*fields):
    """Return the given positional arguments unchanged, as a tuple."""
    return fields

# Literal helpers.  Since Python 3.8 the parser produces ast.Constant for
# every literal; the older ast.Str / ast.Num classes are deprecated and
# eventually removed, so they are only consulted on interpreters that still
# generate them.
if sys.version_info >= (3, 8):
    def _string_value(node):
        """Return the text of a string-literal node, or None for any other node."""
        if isinstance(node, ast.Constant) and isinstance(node.value, str):
            return node.value
        return None

    def _number_value(node):
        """Return the value of a numeric-literal node, or None for any other node."""
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int, but True/False are not numeric literals.
            if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
                return value
        return None
else:
    def _string_value(node):
        """Return the text of a string-literal node, or None for any other node."""
        if isinstance(node, ast.Str):
            return node.s
        return None

    def _number_value(node):
        """Return the value of a numeric-literal node, or None for any other node."""
        if isinstance(node, ast.Num):
            return node.n
        return None

class Visitor(ast.NodeVisitor):
    """AST visitor whose visit methods are generators of tag tuples.

    Every ``visit_*`` method yields tuples whose first element is the tag
    name.  Names starting with ``*`` are internal markers that enclosing
    nodes (see visit_TryExcept) use to make decisions.
    """

    def __init__(self):
        # A per-instance namespace; code_copy becomes True once an
        # 'embedded-code-copy' tag has been emitted, so it is emitted once.
        class state:
            code_copy = False
        self.state = state

    def generic_visit(self, node):
        """Yield every tag produced by the children of *node*."""
        for child in ast.iter_child_nodes(node):
            for t in self.visit(child):
                yield t

    def visit_Raise(self, node):
        """Flag string exceptions; mark bare ``raise`` with '*reraise'."""
        try:
            # Python 2
            ex_type = node.type
        except AttributeError:
            # Python 3
            ex_type = node.exc
        if ex_type is None:
            yield ('*reraise',)
        # For an expression such as 'msg' % x, look at the leftmost operand.
        while isinstance(ex_type, ast.BinOp):
            ex_type = ex_type.left
        if _string_value(ex_type) is not None:
            yield tag('string-exception', node.lineno)
        for t in self.generic_visit(node):
            yield t

    def visit_ExceptHandler(self, node):
        """Flag handlers whose target name is in builtin_exception_types."""
        node_name = None
        if isinstance(node.name, ast.Name):
            # Python 2
            node_name = node.name.id
        elif isinstance(node.name, str):
            # Python 3
            node_name = node.name
        if node_name in builtin_exception_types:
            yield tag('except-shadows-builtin', node.lineno, node_name)
        for t in self.generic_visit(node):
            yield t

    def visit_Assert(self, node):
        """Flag ``assert (a, b)``: a non-empty tuple is always true."""
        if isinstance(node.test, ast.Tuple) and len(node.test.elts) > 0:
            yield tag('assertion-always-true', node.lineno)
        for t in self.generic_visit(node):
            yield t

    def visit_Import(self, node):
        """Report ``import X`` (obsolete) and ``import PIL.X`` (modern marker)."""
        imp_modules = frozenset(mod.name for mod in node.names)
        imp_pil_modules = imp_modules & pil_modules
        for mod in sorted(imp_pil_modules):
            yield tag('obsolete-pil-import', node.lineno, mod)
        imp_pil_modules = (
            frozenset(mod[4:] for mod in imp_modules if mod.startswith('PIL.'))
            & pil_modules
        )
        for mod in sorted(imp_pil_modules):
            yield ('*modern-pil-import', node.lineno, mod)
        for t in self.generic_visit(node):
            yield t

    def visit_ImportFrom(self, node):
        """Report ``from X import ...`` (obsolete) and ``from PIL import X`` (modern marker)."""
        if node.level == 0 and node.module in pil_modules:
            yield tag('obsolete-pil-import', node.lineno, node.module)
        elif node.level == 0 and node.module == 'PIL':
            imp_modules = frozenset(mod.name for mod in node.names)
            imp_pil_modules = imp_modules & pil_modules
            for mod in sorted(imp_pil_modules):
                yield ('*modern-pil-import', node.lineno, mod)
        for t in self.generic_visit(node):
            yield t

    def visit_Compare(self, node):
        """Mark ``<something>.errno ==/!= <known integer>`` comparisons."""
        if len(node.ops) == 1:
            [op] = node.ops
            left, right = [node.left] + node.comparators
            # Accept the attribute on either side of the operator.
            if not isinstance(left, ast.Attribute):
                left, right = right, left
            errno_value = _number_value(right)
            hardcoded_errno = (
                isinstance(left, ast.Attribute) and
                left.attr == 'errno' and
                isinstance(op, (ast.Eq, ast.NotEq)) and
                isinstance(errno_value, int) and
                errno_value in errno_constants
            )
            if hardcoded_errno:
                yield ('*hardcoded-errno-value', node.lineno, errno_value)
        for t in self.generic_visit(node):
            yield t

    def visit_TryExcept(self, node):
        """Visit a try statement and resolve the markers raised inside it.

        * 'obsolete-pil-import' tags are held back and dropped when the
          other side (handlers for the body, body for the handlers) imports
          the same module the modern way.
        * '*hardcoded-errno-value' markers seen in a handler are converted
          into 'hardcoded-errno-value' tags.
        * A handler without an exception type that does not contain a bare
          ``raise`` gets 'except-without-exception-type'.
        """
        body_modern_pil_imp = set()
        except_modern_pil_imp = set()
        pending_body_tags = []
        pending_except_tags = []
        for child in node.body:
            for t in self.visit(child):
                if t[0] == 'obsolete-pil-import':
                    pending_body_tags.append(t)
                    continue
                if t[0] == '*modern-pil-import':
                    _, _, mod = t
                    body_modern_pil_imp.add(mod)
                yield t
        for child in node.handlers:
            reraised = False
            for t in self.visit(child):
                if t[0] == 'obsolete-pil-import':
                    pending_except_tags.append(t)
                    continue
                if t[0] == '*modern-pil-import':
                    _, _, mod = t
                    except_modern_pil_imp.add(mod)
                elif t[0] == '*hardcoded-errno-value':
                    _, lineno, n = t
                    code = errno_constants[n]
                    yield tag('hardcoded-errno-value', lineno, n, '->', 'errno.{code}'.format(code=code))
                    # The marker is consumed here; passing it on would make
                    # every enclosing handler report the same line again.
                    continue
                elif t[0] == '*reraise':
                    reraised = True
                yield t
            if child.type is None and not reraised:
                yield tag('except-without-exception-type', child.lineno)
        for t in pending_body_tags:
            _, _, mod = t
            if mod not in except_modern_pil_imp:
                yield t
        for t in pending_except_tags:
            _, _, mod = t
            if mod not in body_modern_pil_imp:
                yield t
        for child in node.orelse:
            for t in self.visit(child):
                yield t
        # Python 3's ast.Try also carries the finally clause; Python 2's
        # TryExcept has no such field.
        for child in getattr(node, 'finalbody', ()):
            for t in self.visit(child):
                yield t

    visit_Try = visit_TryExcept

    def visit_Subscript(self, node):
        """Flag ``mkstemp(...)[1]``, which discards the returned file descriptor."""
        func = None
        if isinstance(node.value, ast.Call):
            call = node.value
            if isinstance(call.func, ast.Name):
                func = call.func.id
            elif isinstance(call.func, ast.Attribute):
                func = call.func.attr
        if func == 'mkstemp':
            index = node.slice
            # Before Python 3.9 a simple index is wrapped in ast.Index.
            if isinstance(index, getattr(ast, 'Index', ())):
                index = index.value
            if _number_value(index) == 1:
                yield tag('mkstemp-file-descriptor-leak', node.lineno)
        for t in self.generic_visit(node):
            yield t

    def _check_code_copy(self, text):
        """Yield 'embedded-code-copy' if *text* matches a known pattern.

        At most one such tag is emitted per visitor instance.
        """
        if self.state.code_copy:
            return
        if code_copies_regex is None:
            return
        match = code_copies_regex.search(text)
        if match is None:
            return
        # The first group that took part in the match identifies the package.
        for group, info in zip(match.groups(), code_copies):
            if group is not None:
                break
        else:
            info = None
        if info is not None:
            yield tag('embedded-code-copy', '-', info)
            self.state.code_copy = True

    def visit_Str(self, node):
        """String literal as produced by parsers older than Python 3.8."""
        return self._check_code_copy(node.s)

    def visit_Constant(self, node):
        """Literal as produced by Python 3.8 and later; only strings are checked."""
        if isinstance(node.value, str):
            return self._check_code_copy(node.value)
        return self.generic_visit(node)

    def visit_Module(self, node):
        """Detect PLY-generated tables from their top-level assignments.

        A string assignment to ``_tabversion`` followed by ``_lr_method``
        (string) or ``_lextokens`` (dict) yields
        'missing-dependency-on-ply-virtual-package'.
        """
        python_pkg_prefix = 'python3-' if sys.version_info >= (3,) else 'python-'
        tabversion = None
        for child in node.body:
            if not isinstance(child, ast.Assign):
                continue
            if len(child.targets) != 1:
                continue
            [target] = child.targets
            if isinstance(target, ast.Name) and isinstance(target.ctx, ast.Store):
                value_text = _string_value(child.value)
                if target.id == '_tabversion' and value_text is not None:
                    tabversion = value_text
                elif tabversion is not None:
                    if target.id == '_lr_method' and value_text is not None:
                        yield tag('missing-dependency-on-ply-virtual-package', '-', python_pkg_prefix + 'ply-yacc-' + tabversion)
                        break
                    elif target.id == '_lextokens' and isinstance(child.value, ast.Dict):
                        yield tag('missing-dependency-on-ply-virtual-package', '-', python_pkg_prefix + 'ply-lex-' + tabversion)
                        break
        for t in self.generic_visit(node):
            yield t

def check_node(node):
    """Visit *node* with a fresh Visitor and return what its visit() returns."""
    visitor = Visitor()
    return visitor.visit(node)

def check_file(filename):
    """Read *filename* via python_open() and return check_contents() of its text."""
    with python_open(filename) as source_file:
        source_text = source_file.read()
    return check_contents(source_text)

def check_contents(contents, catch_tab_errors=True):
    """Parse *contents* and return an iterable of tag tuples.

    Parse failures are reported as a single 'syntax-error' tag.  When
    *catch_tab_errors* is true (Python 2 only), a TabError is reported as
    'inconsistent-use-of-tabs-and-spaces-in-indentation' and the check is
    retried once on the tab-expanded text.  Internal tags, whose name starts
    with ``*``, are filtered out of the result.
    """
    # Inconsistent use of tabs and spaces in indentation is always a fatal
    # error in Python 3.X, so recovery is never attempted there.
    recover_from_tabs = catch_tab_errors and sys.version_info < (3,)
    try:
        tree = ast.parse(contents)
    except TabError as exc:
        if not recover_from_tabs:
            return [tag('syntax-error', exc.lineno, exc.msg)]
        tab_tag = tag('inconsistent-use-of-tabs-and-spaces-in-indentation', exc.lineno)
        retried = check_contents(contents.expandtabs(), catch_tab_errors=False)
        return itertools.chain([tab_tag], retried)
    except SyntaxError as exc:
        return [tag('syntax-error', exc.lineno, exc.msg)]
    except Exception as exc:
        # Any other parser failure carries no usable line number.
        return [tag('syntax-error', '-', str(exc))]
    return (
        found
        for found in check_node(tree)
        if not found[0].startswith('*')
    )

def main():
    """Check every file named on the command line and print its tags.

    For each file a ``# <filename>`` header is printed (newlines in the name
    replaced by ``?``), followed by one line per tag with its fields joined
    by spaces.

    Raises RuntimeError when the interpreter exposes ``sys.flags.tabcheck``
    and its value is below 2.
    """
    # Interpreters without sys.flags.tabcheck skip this check.
    flags = getattr(sys, 'flags', None)
    tabcheck = getattr(flags, 'tabcheck', None)
    if tabcheck is not None and tabcheck < 2:
        raise RuntimeError('tab check disabled')
    load_data()
    for filename in sys.argv[1:]:
        header = '# {0}'.format(filename.replace('\n', '?'))
        print(header)
        for fields in check_file(filename):
            print(' '.join(str(field) for field in fields))

# Run the checker only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()

# Local Variables:
# indent-tabs-mode: nil
# End:
# vim: syntax=python sw=4 sts=4 sr et
