#!/usr/bin/env python
import iam_sympy_example

from sympy import (Basic, exp, Symbol, sin, Rational, I, Mul, Matrix,
    eye, ones, zeros, sqrt, pprint, simplify, trim, Eq, sympify)

from sympy.physics import msigma, mgamma

# Gamma matrices gamma^mu for mu = 0..3, plus gamma^5, as built by mgamma().
gamma0, gamma1, gamma2, gamma3, gamma5 = (
    mgamma(index) for index in (0, 1, 2, 3, 5))

# Sigma matrices sigma_i for i = 1..3, as built by msigma().
sigma1, sigma2, sigma3 = (msigma(index) for index in (1, 2, 3))

# Real symbols shared by the spinor builders u() and v() and by pslash().
E, m = (Symbol(name, real=True) for name in ("E", "m"))

def u(p,r):
    """Build the four-component column spinor u(p, r).

    p -- momentum components (p1, p2, p3)
    r -- spin index, 1 or 2 (anything else fails the assertion below);
         r == 1 selects ksi = (1, 0), r == 2 selects ksi = (0, 1)

    Returns sqrt(E+m) * (ksi, a) stacked into a 4x1 Matrix, where
    a = (sigma . p)/(E+m) * ksi and E, m are the module-level symbols.
    """
    assert r in [1,2]
    p1,p2,p3 = p
    if r == 1:
        ksi = Matrix([ [1],[0] ])
    else:
        ksi = Matrix([ [0],[1] ])
    a = (sigma1*p1 + sigma2*p2 + sigma3*p3) / (E+m) * ksi
    # NOTE(review): fallback for the product comparing equal to scalar 0, so
    # that the a[i,0] indexing below still works; whether current Matrix
    # arithmetic can ever produce that is not visible here -- confirm.
    if a == 0:
        a = zeros((2, 1))
    return sqrt(E+m) * Matrix([ [ksi[0,0]], [ksi[1,0]], [a[0,0]], [a[1,0]] ])

def v(p,r):
    """Build the four-component column spinor v(p, r).

    p -- momentum components (p1, p2, p3)
    r -- spin index, 1 or 2 (anything else fails the assertion below);
         r == 1 selects ksi = (1, 0), r == 2 selects ksi = -(0, 1)

    Returns sqrt(E+m) * (a, ksi) stacked into a 4x1 Matrix, where
    a = (sigma . p)/(E+m) * ksi and E, m are the module-level symbols.
    Compared with u(), the two halves are swapped and the r == 2 basis
    spinor carries a minus sign.
    """
    assert r in [1,2]
    p1,p2,p3 = p
    if r == 1:
        ksi = Matrix([ [1],[0] ])
    else:
        ksi = -Matrix([ [0],[1] ])
    a = (sigma1*p1 + sigma2*p2 + sigma3*p3) / (E+m) * ksi
    # NOTE(review): fallback for the product comparing equal to scalar 0, so
    # that the a[i,0] indexing below still works; whether current Matrix
    # arithmetic can ever produce that is not visible here -- confirm.
    if a == 0:
        a = zeros((2, 1))
    return sqrt(E+m) * Matrix([ [a[0,0]], [a[1,0]], [ksi[0,0]], [ksi[1,0]] ])

def pslash(p):
    """Return gamma0*p0 - gamma1*p1 - gamma2*p2 - gamma3*p3 for momentum p.

    p is the triple (p1, p2, p3); the time component is not passed in but
    computed here as p0 = sqrt(m**2 + p1**2 + p2**2 + p3**2), using the
    module-level symbol m.
    """
    px, py, pz = p
    energy = sqrt(m**2 + px**2 + py**2 + pz**2)
    return gamma0*energy - gamma1*px - gamma2*py - gamma3*pz

def Tr(M):
    """Return the trace of M, i.e. the sum of its diagonal entries M[i, i].

    M must expose ``lines`` and ``cols`` and be square; a non-square M
    fails the assertion.  An M with zero lines yields 0.
    """
    size = M.lines
    assert size == M.cols
    return sum(M[i, i] for i in range(size))

# Generic real momentum components used only for the check just below.
a=Symbol("a", real=True)
b=Symbol("b", real=True)
c=Symbol("c", real=True)

p = (a,b,c)

# Sanity check: the r=1 and r=2 spinors from u() have a vanishing product
# under Matrix.D, in either order.
assert u(p, 1).D * u(p, 2) == Matrix(1, 1, [0])
assert u(p, 2).D * u(p, 1) == Matrix(1, 1, [0])

# Real momentum components for the four momenta p, pp, k and kp.
p1,p2,p3 =[Symbol(x, real=True) for x in ["p1","p2","p3"]]
pp1,pp2,pp3 =[Symbol(x, real=True) for x in ["pp1","pp2","pp3"]]
k1,k2,k3 =[Symbol(x, real=True) for x in ["k1","k2","k3"]]
kp1,kp2,kp3 =[Symbol(x, real=True) for x in ["kp1","kp2","kp3"]]

p = (p1,p2,p3)
pp = (pp1,pp2,pp3)

k = (k1,k2,k3)
kp = (kp1,kp2,kp3)

mu = Symbol("mu")

# "pslash +/- m" adds the scalar m times the 4x4 identity.  ones(4) is the
# all-ones matrix, which would also put m on every off-diagonal entry, so
# eye(4) is used instead.
e = (pslash(p)+m*eye(4))*(pslash(k)-m*eye(4))
f = pslash(p)+m*eye(4)
g = pslash(p)-m*eye(4)

def xprint(lhs, rhs):
    """Pretty-print the equation ``lhs = rhs``.

    lhs is passed through sympify() first, so a plain string label such as
    'Re(M)' is accepted; rhs is used as given.
    """
    equation = Eq(sympify(lhs), rhs)
    pprint(equation)

#pprint(e)
xprint( 'Tr(f*g)', Tr(f*g) )
#print Tr(pslash(p) * pslash(k)).expand()

# One 1x1 matrix per index value 0..3:
#   (v(pp,1).D * mgamma(mu) * u(p,1)) * (u(k,1).D * mgamma(mu,True) * v(kp,1))
# mgamma(mu, True) requests the lower-index form of the same gamma matrix.
# On Python 2 the comprehension variable leaks and rebinds the module-level
# name `mu`; nothing below relies on it.
M0 = [ ( v(pp, 1).D * mgamma(mu) * u(p, 1) ) * ( u(k, 1).D * mgamma(mu,True) *
        v(kp, 1) ) for mu in range(4)]
# Sum over the four index values, then pull the single entry out of the
# resulting 1x1 matrix so that M is a plain expression.
M = M0[0]+M0[1]+M0[2]+M0[3]
M = M[0]
assert isinstance(M, Basic)
#print M
#print trim(M)

d=Symbol("d", real=True) #d=E+m

xprint('M', M)
print("-"*40)
# Substitute E = d - m and multiply by d**2, expanding before and after; the
# 1/(E+m)**2 prefactor is put back only for display.
M = ((M.subs(E,d-m)).expand() * d**2 ).expand()
xprint('M2', 1/(E+m)**2 * M)
print("-"*40)
# Split into real and imaginary parts and form x**2 + y**2.
x,y= M.as_real_imag()
xprint('Re(M)', x)
xprint('Im(M)', y)
e = x**2+y**2
xprint('abs(M)**2', e)
print("-"*40)
xprint('Expand(abs(M)**2)', e.expand())

#print Pauli(1)*Pauli(1)
#print Pauli(1)**2
#print Pauli(1)*2*Pauli(1)
