#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright (C) 2009 The Tegaki project contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

# Contributors to this file:
# - Christoph Burgmer, cburgmer ÂT ira DÔT uka DÔT de (main author)
# - Mathieu Blondel

"""
Tries to bootstrap a character collection using similar collections and
knowledge about the characters shape data.

Basically all characters whose components have data can be interpolated by
scaling and shifting the component data and merging it together.

Example:
黴 is made up from the component structure ⿲彳⿳山一黑攵 or graphically
　山
彳一攵
　黑
If we have handwriting data of the five components, we can easily integrate that
into the character in question.

To bootstrap a Traditional Chinese collection run:
$ tegaki-bootstrap --domain=BIG5 --locale=T --max-samples=1 \
    handwriting-zh_TW.xml -d ~/xml/ \
    -t tegaki-zinnia-simplified-chinese/handwriting-zh_CN.xml \
    -t tegaki-zinnia-japanese/handwriting-ja.xml

TODO:
    - Use bounding box when merging components. For example above 彳 probably
      has a bounding box of width=0.5, which only needs to be scaled by 3/5,
      and 一 has a bounding box of height=0.2, which allows scaling the upper
      and lower components less aggressively.
    - Map collection sources to a character locale, so we can exclude data that
      does not fit to the target locale. Furthermore instead of doing exact
      transformations for source characters from the wrong locale, try using
      component data.
    - Try similar characters if no exact character can be found?
    - Offer easy way to provide additional variations, e.g. stroke order
"""
import sys
import locale
from optparse import OptionParser

from tegaki.character import CharacterCollection, Writing, Character

from tegakitools.tomoe import tomoe_dict_to_character_collection
from tegakitools.kuchibue import kuchibue_to_character_collection
from tegakitools.charcol import *

# cjklib provides the character domain and decomposition data this script
# is built on, so there is nothing useful to do without it.
try:
    from cjklib.characterlookup import CharacterLookup
    from cjklib.exception import UnsupportedError
except ImportError:
    # Report on stderr (like the error handling at the end of this script)
    # so the message never ends up in output redirected from stdout.
    sys.stderr.write(
        "Please install cjklib (http://code.google.com/p/cjklib)\n")
    sys.exit(1)

VERSION = '0.3'

class TegakiBootstrapError(Exception):
    """Raised when the bootstrapping cannot be carried out."""

class TegakiBootstrap(object):
    """
    Builds a character collection for a target character domain from the
    handwriting data of other collections, either by copying exact matches
    or by merging the writings of a character's components.
    """

    # Maps an IDS (ideographic description sequence) operator to one
    # transformation per component, listed in the order of the components in
    # the decomposition. Each transformation is a tuple of
    # xrate, yrate, dx, dy
    # where xrate/yrate are the scale factors handed to Writing.resize() and
    # dx/dy are offsets given as a fraction of the writing's width/height
    # (see merge_writing_objects()).
    # The entries for ⿴ and ⿷ are disabled: _get_writing_from_entry() skips
    # them (and ⿻) because no stroke order assumption can be made for them.
    COMPONENT_TRANSFORMATION = {
        u'⿰': [(0.5, 1, 0, 0), (0.5, 1, 0.5, 0)],
        u'⿱': [(1, 0.5, 0, 0), (1, 0.5, 0, 0.5)],
        #u'⿴': [(1, 1, 0, 0), (0.5, 0.5, 0.25, 0.25)],
        u'⿵': [(1, 1, 0, 0), (0.5, 0.75, 0.25, 0.25)],
        u'⿶': [(1, 1, 0, 0), (0.5, 0.75, 0.25, 0)],
        #u'⿷': [(1, 1, 0, 0), (0.75, 0.5, 0.25, 0.25)],
        u'⿸': [(1, 1, 0, 0), (0.75, 0.75, 0.25, 0.25)],
        u'⿹': [(1, 1, 0, 0), (0.75, 0.75, 0, 0.25)],
        u'⿺': [(1, 1, 0, 0), (0.75, 0.75, 0.25, 0)],
        u'⿲': [(0.33, 1, 0, 0), (0.33, 1, 0.33, 0), (0.33, 1, 0.66, 0)],
        u'⿳': [(1, 0.33, 0, 0), (1, 0.33, 0, 0.33), (1, 0.33, 0, 0.66)],
        }

    # Radical forms listed here are never replaced by their equivalent
    # character when looking for source data or reporting missing components.
    RADICALS_NON_VISUAL_EQUIVALENCE = set([u'⺄', u'⺆', u'⺇', u'⺈', u'⺊',
        u'⺌', u'⺍', u'⺎', u'⺑', u'⺗', u'⺜', u'⺥', u'⺧', u'⺪', u'⺫', u'⺮',
        u'⺳', u'⺴', u'⺶', u'⺷', u'⺻', u'⺼', u'⻏', u'⻕'])
    """
    Radical forms that have a equivalent character form with no visual
    resemblance.
    """

    # Out-domain components needed by fewer characters than this are not
    # reported as missing components by bootstrap(); the characters that need
    # them are listed as missing single characters instead.
    MIN_COMPONENT_PRODUCTIVITY = 2
    """
    Min productivity when reporting out-domain components that could help boost
    the in-domain set.
    """

    def __init__(self, options, args):
        """
        Store the command line settings and set up the character lookup.

        options -- parsed option values; the attributes read are
            directories, charcols, tomoe, kuchibue, include, exclude,
            max_samples, locale, character_domain and quiet
        args -- positional arguments; the first one, if present, is the
            output path
        """
        self._directories = options.directories
        self._charcols = options.charcols
        self._tomoe = options.tomoe
        self._kuchibue = options.kuchibue
        self._include = options.include
        self._exclude = options.exclude
        self._max_samples = options.max_samples
        self._locale = options.locale
        self._character_domain = options.character_domain
        self._quiet = options.quiet

        # The output path is optional (run() prints to stdout without one).
        # Test for it explicitly instead of hiding every possible exception
        # behind a bare "except:".
        if args:
            self._output_path = args[0]
        else:
            self._output_path = None

        self._cjk = CharacterLookup(self._locale, self._character_domain)

    def _get_charcol(self):
        if not hasattr(self, '_charcol'):
            self._charcol = get_aggregated_charcol(
                                  ((TYPE_CHARCOL, self._charcols),
                                   (TYPE_DIRECTORY, self._directories),
                                   (TYPE_TOMOE, self._tomoe),
                                   (TYPE_KUCHIBUE, self._kuchibue)))

            self._charcol.include_characters_from_files(self._include)
            self._charcol.exclude_characters_from_files(self._exclude)

            # max samples
            if self._max_samples:
                self._charcol.remove_samples(keep_at_most=self._max_samples)

        return self._charcol

    def run(self):
        charcol = self._get_charcol()

        if charcol.get_total_n_characters() == 0:
            raise TegakiBootstrapError("Empty input collection provided")

        # do the bootstrapping
        to_charcol = self.bootstrap(charcol)

        # max samples
        if self._max_samples:
            to_charcol.remove_samples(keep_at_most=self._max_samples)

        # output
        if not self._output_path:
            # outputs to stdout if no output path specified
            print to_charcol.to_xml()
        else:
            gzip = False; bz2 = False
            if self._output_path.endswith(".gz"): gzip = True
            if self._output_path.endswith(".bz2"): bz2 = True
            to_charcol.write(self._output_path, gzip=gzip, bz2=bz2)

    def bootstrap(self, charcol):
        exact_transformations = 0
        decomposition_transformations = 0
        decomposition_fertilities = []
        missing_transformations = 0

        to_charcol = CharacterCollection()

        missing_char_dict = {}
        missing_single_characters = []
        # iterate through all characters of the target character set
        count = 0
        for target_char in self._cjk.getDomainCharacterIterator():
        #for target_char in iter([u'亄', u'乿', u'仜', u'伳']): # DEBUG
            count += 1
            if count % 100 == 0:
                sys.stdout.write('.')
                sys.stdout.flush()
            charSet = target_char.encode('utf8')
            source_character_lookup = self._get_source_character_lookup()
            if target_char in source_character_lookup:
                to_charcol.add_set(charSet)
                for character in source_character_lookup[target_char]:
                    to_charcol.append_character(charSet, character)

                exact_transformations += 1
            else:
                writing_objects, missing_chars \
                    = self.get_writings_from_decomposition(target_char)
                if writing_objects:
                    for writing in writing_objects:
                        character = Character()
                        character.set_writing(writing)
                        character.set_unicode(target_char)
                        to_charcol.append_character(charSet, character)

                    decomposition_transformations += 1
                    decomposition_fertilities.append(len(writing_objects))
                else:
                    if missing_chars:
                        for missing in missing_chars:
                            if missing not in missing_char_dict:
                                missing_char_dict[missing] = []
                            missing_char_dict[missing].append(target_char)
                    else:
                        missing_single_characters.append(target_char)

                    missing_transformations += 1

        sys.stdout.write('\n')

        if not self._quiet:
            _, default_encoding = locale.getdefaultlocale()
            total = exact_transformations + decomposition_transformations \
                + missing_transformations
            print 'Exact transformation count: %d (%d%%)' \
                % (exact_transformations, 100 * exact_transformations / total)
            print 'Decomposition transformation count: %d (%d%%)' \
                % (decomposition_transformations,
                    100 * decomposition_transformations / total)
            if decomposition_fertilities:
                decomposition_fertility = (sum(decomposition_fertilities) \
                    / len(decomposition_fertilities))
            else:
                decomposition_fertility = 1
            print 'Decomposition fertility: %d' % decomposition_fertility
            print 'Missing transformations: %d (%d%%)' \
                % (missing_transformations,
                    100 * missing_transformations / total)

            # missing single characters
            # Extend by those with components, that have a component with low
            #   productivity.
            in_domain_components = set(
                self._cjk.filterDomainCharacters(missing_char_dict.keys()))

            low_productivity_component_chars = []
            for component, chars in missing_char_dict.items():
                if component not in in_domain_components \
                    and len(chars) < self.MIN_COMPONENT_PRODUCTIVITY:
                    low_productivity_component_chars.extend(chars)
                    del missing_char_dict[component]
            missing_single_characters.extend(low_productivity_component_chars)

            print 'Missing single characters:',
            print ''.join(missing_single_characters).encode(default_encoding)

            # remove characters that we already placed in "single"
            _missing_single_characters = set(missing_single_characters)
            for component, chars in missing_char_dict.items():
                missing_char_dict[component] = list(
                    set(chars) - _missing_single_characters)
                if not missing_char_dict[component]:
                    del missing_char_dict[component]

            # missing components

            missingComponents = sorted(missing_char_dict.items(),
                key=lambda (x,y): len(y))
            missingComponents.reverse()

            in_domain_component_list = [(component, chars) \
                for component, chars in missingComponents \
                if component in in_domain_components]
            # only show "out-domain" components if they have productivity > 1
            out_domain_component_list = [(component, chars) \
                for component, chars in missingComponents \
                if component not in in_domain_components and len(chars) > 1]

            print 'Missing components: %d' % (len(in_domain_component_list) \
                + len(out_domain_component_list))
            print 'Missing in-domain components:',
            print ', '.join(['%s (%s)' % (component, ''.join(chars)) \
                for component, chars in in_domain_component_list])\
                .encode(default_encoding)
            print 'Missing out-domain components:',
            print ', '.join(['%s (%s)' % (component, ''.join(chars)) \
                for component, chars in out_domain_component_list])\
                .encode(default_encoding)

        return to_charcol

    def _get_source_character_lookup(self):
        if not hasattr(self, '_source_character_lookup'):
            self._source_character_lookup = {}
            for character in self._get_charcol().get_all_characters():
                char = character.get_utf8().decode('utf8')
                if char not in self._source_character_lookup:
                    self._source_character_lookup[char] = []
                self._source_character_lookup[char].append(character)

        return self._source_character_lookup

    def get_writings_from_decomposition(self, char):
        writing_objects = []
        if char in self._get_source_character_lookup():
            writing_objects = [character.get_writing() \
                for character in self._get_source_character_lookup()[char]]
        elif (CharacterLookup.isRadicalChar(char)
            and char not in self.RADICALS_NON_VISUAL_EQUIVALENCE):
            try:
                equivChar = self._cjk.getRadicalFormEquivalentCharacter(char)
                if equivChar in self._get_source_character_lookup():
                    writing_objects = [character.get_writing() for character \
                        in self._get_source_character_lookup()[equivChar]]
            except UnsupportedError:
                pass

        missing_chars = []
        if not writing_objects:
            decompositions = self._cjk.getDecompositionEntries(char)
            for decomposition in decompositions:
                writing_objs, _, missing = self._get_writing_from_entry(
                    decomposition)
                if not writing_objs:
                    missing_chars.extend(missing)
                writing_objects.extend(writing_objs)

        if writing_objects:
            missing_chars = []

        return writing_objects, missing_chars

    def _get_writing_from_entry(self, decomposition, index=0):
        """Goes through a decomposition"""
        if type(decomposition[index]) != type(()):
            # IDS operator
            character = decomposition[index]
            writing_objects_list = []
            missing_chars = []
            if CharacterLookup.isBinaryIDSOperator(character):
                # check for IDS operators we can't make any order
                # assumption about
                if character in [u'⿴', u'⿻', u'⿷']:
                    return [], index, []
                else:
                    # Get stroke order for both components
                    for _ in range(0, 2):
                        writing_objs, index, missing \
                            = self._get_writing_from_entry(decomposition,
                                index+1)
                        if not writing_objs:
                            missing_chars.extend(missing)
                            #return [], index
                        writing_objects_list.append(writing_objs)

            elif CharacterLookup.isTrinaryIDSOperator(character):
                # Get stroke order for three components
                for _ in range(0, 3):
                    writing_objs, index, missing = self._get_writing_from_entry(
                        decomposition, index+1)
                    if not writing_objs:
                        missing_chars.extend(missing)
                        #return [], index
                    writing_objects_list.append(writing_objs)
            else:
                assert False, 'not an IDS character'

            # merge
            writing_objects = []
            if not missing_chars:
                for writing_objs in TegakiBootstrap.cross(
                    *writing_objects_list):
                    writing = self.merge_writing_objects(character,
                        writing_objs)
                    writing_objects.append(writing)
            return writing_objects, index, missing_chars
        else:
            # no IDS operator but character
            char, _ = decomposition[index]
            # if the character is unknown or there is none raise
            if char == u'？':
                return [], index, []
            else:
                # recursion
                writing_objs, missing_chars \
                    = self.get_writings_from_decomposition(char)
                if not writing_objs and not missing_chars:
                    if (CharacterLookup.isRadicalChar(char)
                        and char not in self.RADICALS_NON_VISUAL_EQUIVALENCE):
                        try:
                            char = self._cjk.getRadicalFormEquivalentCharacter(
                                char)
                        except UnsupportedError:
                            pass
                    missing_chars = [char]
                return writing_objs, index, missing_chars

        assert False

    @classmethod
    def merge_writing_objects(cls, idsChar, writing_objects):
        if idsChar not in cls.COMPONENT_TRANSFORMATION:
            raise ValueError("Not supported") # [u'⿴', u'⿻', u'⿷']

        assert (CharacterLookup.isBinaryIDSOperator(idsChar) \
            and len(writing_objects) == 2) \
            or (CharacterLookup.isTrinaryIDSOperator(idsChar) \
            and len(writing_objects) == 3)
        assert len(cls.COMPONENT_TRANSFORMATION[idsChar]) \
            == len(writing_objects)

        transformations = cls.COMPONENT_TRANSFORMATION[idsChar]
        # reverse transformations where inner part is written first
        if idsChar in [u'⿺', u'⿶']:
            writing_objects.reverse()
            transformations = transformations[:]
            transformations.reverse()

        resultingWriting = Writing()
        for idx, transformation in enumerate(transformations):
            xrate, yrate, dx, dy = transformation

            obj = writing_objects[idx].copy()
            obj.resize(xrate, yrate)
            obj.move_rel(dx * obj.get_width(), dy* obj.get_height())
            obj.resize(resultingWriting.get_width() / obj.get_width(),
                resultingWriting.get_height() / obj.get_height())
            for s in obj.get_strokes(True):
                resultingWriting.append_stroke(s)

        return resultingWriting

    @staticmethod
    def cross(*args):
        ans = [[]]
        for arg in args:
            ans = [x+[y] for x in ans for y in arg]
        return ans


parser = OptionParser(usage="usage: %prog [options] [output-path]",
                      version="%prog " + VERSION)

parser.add_option("-d", "--directory",
                  action="append", type="string", dest="directories",
                  default=[],
                  help="Directory containing individual XML character files")
parser.add_option("-c", "--charcol",
                  action="append", type="string", dest="charcols",
                  default=[],
                  help="character collection XML files")
parser.add_option("-t", "--tomoe-dict",
                  action="append", type="string", dest="tomoe",
                  default=[],
                  help="Tomoe XML dictionary files")
parser.add_option("-k", "--kuchibue",
                  action="append", type="string", dest="kuchibue",
                  default=[],
                  help="Kuchibue unipen database")


parser.add_option("-i", "--include",
                  action="append", type="string", dest="include",
                  default=[],
                  help="File containing characters to include")
parser.add_option("-e", "--exclude",
                  action="append", type="string", dest="exclude",
                  default=[],
                  help="File containing characters to exclude")
parser.add_option("-m", "--max-samples",
                  type="int", dest="max_samples",
                  help="Maximum number of samples per character")


parser.add_option("-l", "--locale",
                  type="string", dest="locale",
                  default='T',
                  help="Character locale of target characters")
parser.add_option("--domain",
                  type="string", dest="character_domain",
                  default='Unicode',
                  help="Character domain of target characters")


parser.add_option("-q", "--quiet", dest="quiet",
                  action="store_true",
                  help="Don't print any statistics")

(options, args) = parser.parse_args()

try:
    TegakiBootstrap(options, args).run()
except TegakiBootstrapError, e:
    sys.stderr.write(str(e) + "\n\n")
    parser.print_help()
    sys.exit(1)
