#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module contains code generation tools that emit quadrature rules (points and weights) as static C/C++ array definitions.
"""

__author__ = "Martin Sandve Alnes"
__date__   = "2008-08-13 -- 2009-03-19"
__copyright__ = "(C) 2008-2009 Martin Sandve Alnes and Simula Resarch Laboratory"
__license__  = "GNU GPL Version 2, or (at your option) any later version"

from sfc.codegeneration.codeformatting import indent

def gen_quadrature_rule_definition(rule, points_name="quad_points", weigths_name="quad_weights"):
    """Generate static C/C++ array definitions for a single quadrature rule.

    The returned code declares

        static const double <points_name>[num_points][nsd]
        static const double <weigths_name>[num_points]

    initialized with the rule's points and weights.

    Parameters:
        rule         -- quadrature rule object; this function reads its
                        num_points, nsd, points (one coordinate sequence per
                        point) and weights attributes.
        points_name  -- C identifier for the points array.
        weigths_name -- C identifier for the weights array (the misspelled
                        keyword name is kept for backward compatibility).

    Returns the generated code as a string.
    """
    def fmt(value):
        # repr(float(...)) gives a literal that round-trips to the same double.
        # Python 2's str(float) keeps only 12 significant digits, which loses
        # precision in the generated tables. Converting to float first also
        # prevents exact rational types (whose str() may look like "1/3",
        # an integer division in C) from leaking into the generated code.
        return repr(float(value))

    point_rows = ",\n".join(
        "      { %s }" % ", ".join(fmt(x) for x in p) for p in rule.points)
    weight_row = ", ".join(fmt(w) for w in rule.weights)

    parts = [
        "static const double %s[%d][%d] = \n" % (points_name, rule.num_points, rule.nsd),
        "  {\n",
        point_rows,
        "\n  };\n",
        "\n",
        "static const double %s[%d] =\n" % (weigths_name, rule.num_points),
        "  {\n",
        "    " + weight_row,
        "\n  };\n",
    ]
    return "".join(parts)

def gen_quadrature_rule_definitions(rules, points_name="facet_quad_points", weigths_name="facet_quad_weights"):
    """Generate static code for quadrature rule definitions for a list of rules.

    The returned code declares

        static const double <points_name>[nr][num_points][nsd]
        static const double <weigths_name>[nr][num_points]

    where nr is the number of rules. Every rule must have the same num_points
    and nsd, since they share one fixed-size C array.

    Parameters:
        rules        -- non-empty sequence of quadrature rule objects; this
                        function reads num_points, nsd, points and weights.
        points_name  -- C identifier for the points array.
        weigths_name -- C identifier for the weights array (the misspelled
                        keyword name is kept for backward compatibility).

    Returns the generated code as a string.

    Raises ValueError if rules is empty, or if the rules do not all have the
    same number of points and space dimension.
    """
    if not rules:
        raise ValueError("Need at least one quadrature rule to generate definitions.")

    nr = len(rules)
    nq = rules[0].num_points
    nsd = rules[0].nsd

    # All rules are written into one [nr][nq][nsd] array. A mismatch would
    # either fail to compile (too many initializers) or be silently
    # zero-padded by the C compiler, so reject it here.
    for r, rule in enumerate(rules):
        if rule.num_points != nq or rule.nsd != nsd:
            raise ValueError(
                "Quadrature rule %d has %d points in %d dimensions, but rule 0 has "
                "%d points in %d dimensions; all rules must have the same shape."
                % (r, rule.num_points, rule.nsd, nq, nsd))

    def fmt(value):
        # repr(float(...)) gives a literal that round-trips to the same double
        # (Python 2's str(float) keeps only 12 significant digits), and keeps
        # exact rational types from printing as "1/3" (integer division in C).
        return repr(float(value))

    # One brace-enclosed group of points per rule.
    point_blocks = []
    for r, rule in enumerate(rules):
        point_snippets = [indent("{ " + ", ".join(fmt(x) for x in p) + " }")
                          for p in rule.points]
        block = "{ // quadrature points for rule %d:\n" % r
        block += ",\n".join(point_snippets) + "\n"
        block += "}"
        point_blocks.append(indent(block))

    # One brace-enclosed row of weights per rule.
    weight_blocks = [
        indent("// weights for rule %d:\n" % r + "{ " + ", ".join(fmt(w) for w in rule.weights) + " }")
        for r, rule in enumerate(rules)
    ]

    parts = [
        "// [number of rules][number of points][space dimension]\n",
        "static const double %s[%d][%d][%d] = \n" % (points_name, nr, nq, nsd),
        indent("{") + "\n",
        indent(",\n".join(point_blocks)) + "\n",
        indent("};") + "\n\n",
        "// [number of rules][number of points]\n",
        "static const double %s[%d][%d] =\n" % (weigths_name, nr, nq),
        indent("{") + "\n",
        indent(",\n".join(weight_blocks)) + "\n",
        indent("};") + "\n\n",
    ]
    return "".join(parts)


if __name__ == '__main__':
    # Small manual demo: print the generated definitions for one rule, then
    # for a list containing three copies of that rule.
    # NOTE(review): find_quadrature_rule was previously used without being
    # imported (NameError). The module path below is assumed; confirm it.
    from sfc.quadrature import find_quadrature_rule

    order = 3
    polygon = "triangle"

    rule = find_quadrature_rule(polygon, order, "gauss")

    # print(...) with a single argument behaves the same on Python 2 and 3.
    code = gen_quadrature_rule_definition(rule)
    print(code)

    rules = [rule, rule, rule]
    code = gen_quadrature_rule_definitions(rules)
    print(code)
