#!/usr/bin/env python

# u1trial: Test runner for Python unit tests needing DBus
#
# Author: Rodney Dawes <rodney.dawes@canonical.com>
#
# Copyright 2009-2010 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Test runner that uses a private dbus session and glib main loop."""

import coverage
import gc
import inspect
import os
import re
import sys
import unittest

from twisted.trial.runner import TrialRunner

# Put the current working directory first on the import path, so test
# modules under it can be imported by the dotted names the runner derives
# from their relative file paths.
sys.path.insert(0, os.path.abspath(os.curdir))


class TestRunner(TrialRunner):
    """The test runner implementation.

    A TrialRunner that installs the glib2 reactor, points XDG_CACHE_HOME
    and ROOTDIR at a private area under the current directory, and starts
    any services requested by the loaded test classes for the duration of
    the run.
    """

    def __init__(self):
        # install the glib2reactor before any import of the reactor to avoid
        # using the default SelectReactor and be able to run the dbus tests
        from twisted.internet import glib2reactor
        glib2reactor.install()
        from twisted.trial.reporter import TreeReporter

        # Everything produced by the run lives under ./_trial_temp.
        self.tempdir = os.path.join(os.getcwd(), "_trial_temp")

        # setup a custom XDG_CACHE_HOME and make sure the directory exists
        xdg_cache = os.path.join(self.tempdir, "xdg_cache")
        os.environ["XDG_CACHE_HOME"] = xdg_cache
        # setup the ROOTDIR env var
        os.environ['ROOTDIR'] = os.getcwd()
        if not os.path.exists(xdg_cache):
            os.makedirs(xdg_cache)

        working_dir = os.path.join(self.tempdir, 'tmp')
        super(TestRunner, self).__init__(reporterFactory=TreeReporter,
                                         realTimeErrors=True,
                                         workingDirectory=working_dir,
                                         forceGarbageCollection=True)
        # Service factories requested by the loaded test classes (no
        # duplicates; each is instantiated and started once by run()).
        self.required_services = []
        # Python files seen while collecting tests, for the coverage report.
        self.source_files = []

    def _load_unittest(self, relpath):
        """Load unit tests from a Python module with the given 'relpath'.

        Returns None when the file name does not start with 'test_'.
        Otherwise the module is imported by a dotted name derived from
        'relpath' (e.g. 'tests/test_foo.py' -> 'tests.test_foo'). The
        services its classes require are recorded in
        self.required_services. The module's tests are then returned as a
        suite, taken from its 'suite()' or 'test_suite()' function when it
        has one, or else built by unittest's default loader.
        """
        assert relpath.endswith(".py"), (
            "%s does not appear to be a Python module" % relpath)
        if not os.path.basename(relpath).startswith('test_'):
            return None
        # normpath drops a leading './' and doubled separators, which would
        # otherwise produce an invalid dotted module name.
        modpath = os.path.normpath(relpath).replace(os.path.sep, ".")[:-3]
        module = __import__(modpath, None, None, [""])

        # If the module specifies required_services, make sure we get them.
        # Several classes (or modules) may ask for the same service; record
        # it once so it is not started several times.
        members = [x[1] for x in inspect.getmembers(module, inspect.isclass)]
        for member_type in members:
            if hasattr(member_type, 'required_services'):
                member = member_type()
                for service in member.required_services():
                    if service not in self.required_services:
                        self.required_services.append(service)
                del member
        gc.collect()

        # If the module has a 'suite' or 'test_suite' function, use that
        # to load the tests.
        if hasattr(module, "suite"):
            return module.suite()
        elif hasattr(module, "test_suite"):
            return module.test_suite()
        else:
            return unittest.defaultTestLoader.loadTestsFromModule(module)

    @classmethod
    def _iter_leaf_tests(cls, suite):
        """Yield the individual tests in 'suite', descending into sub-suites."""
        for test in suite:
            if isinstance(test, unittest.TestSuite):
                for leaf in cls._iter_leaf_tests(test):
                    yield leaf
            else:
                yield test

    def _add_module_tests(self, suite, module_suite, pattern):
        """Add the tests of 'module_suite' to 'suite'.

        'module_suite' may be None (a non-test module), in which case
        nothing is added. With a compiled 'pattern' only the leaf tests
        whose id() matches it are added; without one the whole module
        suite is added unchanged.
        """
        if module_suite is None:
            return
        if pattern is None:
            suite.addTests(module_suite)
            return
        for test in self._iter_leaf_tests(module_suite):
            if pattern.match(test.id()):
                suite.addTest(test)

    def _collect_tests(self, path, test_pattern):
        """Return the set of unittests found under 'path'.

        'path' is either a single Python module or a directory that is
        walked recursively for 'test_*.py' modules. Each '.py' file seen
        is appended to self.source_files. When 'test_pattern' is given,
        only tests whose id contains a match for it (as a regular
        expression) are kept.
        """
        suite = unittest.TestSuite()
        if test_pattern:
            pattern = re.compile('.*%s.*' % test_pattern)
        else:
            pattern = None

        if not path:
            sys.stderr.write('Path should be defined.\n')
            sys.exit(1)

        # Decide file-vs-directory explicitly rather than by catching the
        # AssertionError from _load_unittest: asserts vanish under -O, and
        # catching it also hid genuine assertion failures in test modules.
        if path.endswith(".py"):
            self.source_files.append(path)
            self._add_module_tests(suite, self._load_unittest(path), pattern)
            return suite

        for root, _dirs, files in os.walk(path):
            for filename in files:
                if not filename.endswith(".py"):
                    continue
                filepath = os.path.join(root, filename)
                self.source_files.append(filepath)
                if filename.startswith("test_"):
                    self._add_module_tests(
                        suite, self._load_unittest(filepath), pattern)
        return suite

    # pylint: disable=E0202
    def run(self, path, options=None):
        """Run the tests under 'path', then exit the process.

        'options' is the optparse result built by main(); its 'coverage',
        'test' and 'loops' attributes are read. When it is None, no
        coverage, no filtering and a single loop are used. Always ends
        with sys.exit: 0 if the run was successful, 1 otherwise.
        """
        use_coverage = bool(getattr(options, 'coverage', False))
        test_pattern = getattr(options, 'test', None)
        loops = getattr(options, 'loops', 1)

        success = False
        running_services = []
        if use_coverage:
            coverage.erase()
            coverage.start()

        try:
            suite = self._collect_tests(path, test_pattern)
            if loops:
                old_suite = suite
                suite = unittest.TestSuite()
                for _ in range(loops):
                    suite.addTest(old_suite)

            # Start any required services
            for service in self.required_services:
                runner = service()
                runner.start_service(tempdir=self.tempdir)
                running_services.append(runner)

            result = super(TestRunner, self).run(suite)
            success = result.wasSuccessful()
        finally:
            # Stop all the running services
            for runner in running_services:
                runner.stop_service()

        if use_coverage:
            coverage.stop()
            coverage.report(self.source_files, ignore_errors=True,
                            show_missing=False)

        sys.exit(0 if success else 1)


def main():
    """Do the deed: parse the command line and run the tests under a path.

    With no path argument the help is printed and the process exits with
    status 2. A path that does not exist is reported on stderr and the
    process exits with status 1. Otherwise the tests are handed to
    TestRunner.run.
    """
    from optparse import OptionParser
    usage = '%prog [options] path'
    parser = OptionParser(usage=usage)
    parser.add_option("-t", "--test", dest="test",
                      help="run specific tests, e.g: className.methodName")
    parser.add_option("-l", "--loop", dest="loops", type="int", default=1,
                      help="loop selected tests LOOPS number of times",
                      metavar="LOOPS")
    parser.add_option("-c", "--coverage", action="store_true", dest="coverage",
                      help="print a coverage report when finished")

    (options, args) = parser.parse_args()
    if not args:
        parser.print_help()
        sys.exit(2)

    testpath = args[0]
    if not os.path.exists(testpath):
        # Diagnostics go to stderr so they don't mix with normal output.
        sys.stderr.write("the path to test does not exist: %s\n" % testpath)
        sys.exit(1)

    TestRunner().run(testpath, options)

if __name__ == "__main__":
    # Run only when executed as a script, never on import.
    main()
