# -------------------------------------------------------------------------
#     Copyright (C) 2005-2012 Martin Strohalm <www.mmass.org>

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 3 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE.TXT in the
#     main directory of the program.
# -------------------------------------------------------------------------

# load libs
import re
import numpy

# load stopper
from mod_stopper import CHECK_FORCE_QUIT

# load blocks
import blocks

# load objects
import obj_compound
import obj_peaklist

# load modules
import mod_signal


# BASIC CONSTANTS
# ---------------

# Electron rest mass in Da (unified atomic mass units); used to correct ion
# masses for electrons lost or gained when a molecule is charged.
ELECTRON_MASS = 0.00054857990924

# Validates a complete formula string (anchored with ^...$). A formula is a
# sequence of groups, each made of optional opening parentheses, one or more
# atom tokens and optional closing parentheses with an optional multiplier.
# An atom token is a symbol (capital letter plus up to two lowercase letters),
# an optional isotope mass number in braces (e.g. C{13}) and an optional count,
# which may be negative (e.g. H-1).
# NOTE: a regex cannot count, so balanced parentheses are NOT checked here, and
# because the outer group is repeated with `*` the empty string also matches.
FORMULA_PATTERN = re.compile(r'''
    ^(
        ([\(])* # start parenthesis
        (
            ([A-Z][a-z]{0,2}) # atom symbol
            (\{[\d]+\})? # isotope
            (([\-][\d]+)|[\d]*) # atom count
        )+
        ([\)][\d]*)* # end parenthesis
    )*$
''', re.X)

# Splits one atom token into groups (symbol, isotope mass number or None,
# count string). The count group may be empty or carry a leading '-'.
# Applied with .match() to composition keys in rdbe() and pattern().
ELEMENT_PATTERN = re.compile(r'''
            ([A-Z][a-z]{0,2}) # atom symbol
            (?:\{([\d]+)\})? # isotope
            ([\-]?[\d]*) # atom count
''', re.X)


# BASIC FUNCTIONS
# ---------------

def delta(measuredMass, countedMass, units='ppm'):
    """Calculate error between measured Mass and counted Mass in specified units.
        measuredMass (float) - measured mass
        countedMass (float) - counted mass
        units (Da, ppm or %) - error units
    
    The error is signed: positive when measuredMass is above countedMass.
    Relative units ('ppm', '%') are expressed relative to countedMass.
    Raises ValueError for unknown units.
    """
    
    if units == 'ppm':
        # multiply by 1.0 to force true division: under Python 2 two integer
        # masses would otherwise be floor-divided (e.g. 1001 vs 1000 -> 0 ppm)
        return (measuredMass - countedMass) / (countedMass * 1.0) * 1000000
    elif units == 'Da':
        return (measuredMass - countedMass)
    elif units == '%':
        return (measuredMass - countedMass) / (countedMass * 1.0) * 100
    else:
        # str() so that non-string units still raise the intended ValueError
        # (instead of a TypeError from str + object concatenation)
        raise ValueError('Unknown units for delta! -->' + str(units))
# ----


def mz(mass, charge, currentCharge=0, agentFormula='H', agentCharge=1, massType=0):
    """Convert a mass (or m/z) from one charge state to another.
    
    The value is first brought back to the neutral mass using currentCharge
    (skipped when currentCharge is 0). It is then charged to the requested
    state by adding charge/agentCharge units of the charging agent and dividing
    by abs(charge). A final charge of 0 returns the neutral mass.
        mass (tuple of (Mo, Av) or float) - current mass
        charge (int) - final charge of ion
        currentCharge (int) - current mass charge
        agentFormula (str or mspy.compound) - charging agent formula, 'e' for electron
        agentCharge (int) - charging agent unit charge
        massType (0 or 1) - used mass type if mass value is float, 0 = monoisotopic, 1 = average
    """
    
    # the electron stays as the plain 'e' marker, everything else becomes a compound
    if agentFormula != 'e' and not isinstance(agentFormula, obj_compound.compound):
        agentFormula = obj_compound.compound(agentFormula)
    
    # (monoisotopic, average) mass of a single charging agent unit
    if agentFormula == 'e':
        agentMass = [ELECTRON_MASS, ELECTRON_MASS]
    else:
        neutral = agentFormula.mass()
        electrons = agentCharge*ELECTRON_MASS
        agentMass = (neutral[0]-electrons, neutral[1]-electrons)
    
    # both masses are carried when a (Mo, Av) pair was given
    isPair = type(mass) in (tuple, list)
    
    # back to the neutral mass
    agentCount = currentCharge/agentCharge
    if currentCharge != 0:
        scale = abs(currentCharge)
        if isPair:
            mass = (mass[0]*scale - agentMass[0]*agentCount,
                    mass[1]*scale - agentMass[1]*agentCount)
        else:
            mass = mass*scale - agentMass[massType]*agentCount
    
    if charge == 0:
        return mass
    
    # charge to the requested state
    agentCount = charge/agentCharge
    scale = abs(charge)
    if isPair:
        return ((mass[0] + agentMass[0]*agentCount)/scale,
                (mass[1] + agentMass[1]*agentCount)/scale)
    return (mass + agentMass[massType]*agentCount)/scale
# ----



# FORMULA FUNCTIONS
# -----------------

def rdbe(compound):
    """Get RDBE (Rings plus Double Bonds Equivalent) of a given compound.
        compound (str or mspy.compound) - compound
    
    RDBE = 1 + sum(count * (valence - 2)) / 2 over all elements. Isotope-labelled
    atoms are folded into their element, and elements with no (falsy) valence
    are ignored. Returns a float.
    """
    
    if not isinstance(compound, obj_compound.compound):
        compound = obj_compound.compound(compound)
    
    # unique element symbols in first-seen order (e.g. 'C' and 'C{13}' -> 'C')
    symbols = []
    for key in compound.composition():
        found = ELEMENT_PATTERN.match(key)
        if not found:
            continue
        symbol = found.group(1)
        if symbol not in symbols:
            symbols.append(symbol)
    
    # accumulate (valence - 2) * count for each element
    total = 0.
    for symbol in symbols:
        valence = blocks.elements[symbol].valence
        if not valence:
            continue
        total += (valence - 2) * compound.count(symbol, groupIsotopes=True)
    
    return total / 2. + 1.
# ----


def frules(compound, rules=('HC','NOPSC','NOPS','RDBE','RDBEInt'), HC=(0.1, 3.0), NOPSC=(4,3,2,3), RDBE=(-1,40)):
    """Check heuristic formula rules for a given compound.
        compound (str or mspy.compound) - compound
        rules (list or tuple of str) - rules to be checked:
            'HC' - H/C ratio within HC limits
            'NOPSC' - N/C, O/C, P/C and S/C ratios not above NOPSC maxima
            'NOPS' - element-count heuristics for compounds rich in N, O, P, S
            'RDBE' - RDBE within RDBE limits
            'RDBEInt' - RDBE is a whole number
        HC (tuple) - H/C (min, max) limits
        NOPSC (tuple) - N/C, O/C, P/C, S/C max values
        RDBE (tuple) - RDBE (min, max) limits
    
    Returns True if all requested rules pass, False otherwise.
    The carbon-ratio rules ('HC', 'NOPSC') are skipped for compounds without carbon.
    """
    
    # check compound
    if not isinstance(compound, obj_compound.compound):
        compound = obj_compound.compound(compound)
    
    # get element counts (isotopes grouped into their element)
    counts = {}
    for element in 'CHNOPS':
        counts[element] = float(compound.count(element, groupIsotopes=True))
    countC = counts['C']
    
    # carbon ratio rules
    if countC:
        
        # check HC rule
        if 'HC' in rules:
            ratioHC = counts['H'] / countC
            if ratioHC < HC[0] or ratioHC > HC[1]:
                return False
        
        # check NOPS/C rule (NOPSC holds the N, O, P, S maxima in that order)
        if 'NOPSC' in rules:
            for element, maximum in zip('NOPS', NOPSC):
                if counts[element] / countC > maximum:
                    return False
    
    # NOPS count heuristics: when every listed element count exceeds `trigger`,
    # each of those counts must stay below its limit (limits follow the letters).
    # NOTE(review): these appear to follow the HNOPS heuristic of Kind & Fiehn
    # (2007, "Seven Golden Rules"); the published table seems to use
    # "NOS all > 6" and "N < 4" for PSN - verify the values below are intentional.
    nopsHeuristics = (
        ('NOPS', 1, (10, 20, 4, 3)),
        ('NOP', 3, (11, 22, 6)),
        ('NOS', 1, (19, 14, 8)),
        ('NPS', 1, (3, 3, 3)),
        ('OPS', 1, (14, 3, 3)),
    )
    if 'NOPS' in rules:
        for elements, trigger, limits in nopsHeuristics:
            if all(counts[e] > trigger for e in elements):
                for element, limit in zip(elements, limits):
                    if counts[element] >= limit:
                        return False
    
    # RDBE rules (RDBE is only computed when actually needed)
    if 'RDBE' in rules or 'RDBEInt' in rules:
        rdbeValue = rdbe(compound)
        
        # check RDBE range
        if 'RDBE' in rules and (rdbeValue < RDBE[0] or rdbeValue > RDBE[1]):
            return False
        
        # check integer RDBE
        if 'RDBEInt' in rules and rdbeValue % 1:
            return False
    
    # all ok
    return True
# ----



# ISOTOPIC PATTERN
# ----------------

def pattern(compound, fwhm=0.1, threshold=0.01, charge=0, agentFormula='H', agentCharge=1):
    """Calculate isotopic pattern for given compound.
        compound (str or mspy.compound) - compound
        fwhm (float) - gaussian peak width
        threshold (float) - relative intensity threshold for isotopes (in %/100)
        charge (int) - charge to be calculated
        agentFormula (str or mspy.compound) - charging agent formula, 'e' for electron
        agentCharge (int) - charging agent unit charge
    
    Returns a list of [mass (m/z if charged), relative abundance] pairs,
    normalized to the most abundant peak (1.0), with peaks below threshold
    removed. Returns an empty list when the compound contains no atoms.
    Raises ValueError when any atom count is negative.
    """
    
    # check compound
    if not isinstance(compound, obj_compound.compound):
        compound = obj_compound.compound(compound)
    
    # check agent formula
    if agentFormula != 'e' and not isinstance(agentFormula, obj_compound.compound):
        agentFormula = obj_compound.compound(agentFormula)
    
    # add charging agent atoms to compound (electrons are handled by the
    # charge correction below)
    if charge and agentFormula != 'e':
        formula = compound.formula()
        for atom, count in agentFormula.composition().items():
            formula += '%s%d' % (atom, count*(charge/agentCharge))
        compound = obj_compound.compound(formula)
    
    # get composition and check for negative atom counts
    composition = compound.composition()
    for atom in composition:
        if composition[atom] < 0:
            raise ValueError('Pattern cannot be calculated for this formula! --> ' + compound.formula())
    
    # set internal thresholds (intermediate pruning is 100x looser than final cut)
    internalThreshold = threshold/100.
    groupingWindow = fwhm/4.
    
    # calculate pattern by convolving atoms one by one
    finalPattern = []
    for atom in composition:
        
        # get isotopic profile for current atom or specified isotope only
        atomCount = composition[atom]
        atomPattern = []
        match = ELEMENT_PATTERN.match(atom)
        symbol, massNumber = match.group(1), match.group(2)
        if massNumber:
            isotope = blocks.elements[symbol].isotopes[int(massNumber)]
            atomPattern.append([isotope[0], 1.]) # [mass, abundance]
        else:
            for isotope in blocks.elements[symbol].isotopes.values():
                if isotope[1] > 0.:
                    atomPattern.append(list(isotope)) # [mass, abundance]
        
        # add atoms
        for _ in range(atomCount):
            
            CHECK_FORCE_QUIT()
            
            # if pattern is empty (first atom) add current atom pattern
            if not finalPattern:
                finalPattern = _normalize(atomPattern)
                continue
            
            # add each isotope of current atom to each relevant peak
            currentPattern = []
            for patternIsotope in finalPattern:
                
                # skip peak under relevant abundance threshold
                if patternIsotope[1] < internalThreshold:
                    continue
                
                for atomIsotope in atomPattern:
                    mass = patternIsotope[0] + atomIsotope[0]
                    abundance = patternIsotope[1] * atomIsotope[1]
                    currentPattern.append([mass, abundance])
            
            # group isotopes and normalize pattern
            finalPattern = _groupIsotopes(currentPattern, groupingWindow)
            finalPattern = _normalize(finalPattern)
    
    # no atoms at all (e.g. every count is zero) -> nothing to group or normalize
    if not finalPattern:
        return []
    
    # correct charge: remove/add electrons and convert mass to m/z
    if charge:
        for peak in finalPattern:
            peak[0] = (peak[0] - ELECTRON_MASS*charge) / abs(charge)
    
    # group isotopes and normalize pattern
    finalPattern = _groupIsotopes(finalPattern, groupingWindow)
    finalPattern = _normalize(finalPattern)
    
    # discard peaks below threshold
    return [peak for peak in finalPattern if peak[1] >= threshold]
# ----


def profile(peaklist, fwhm=0.1, points=10, noise=None, raster=None, forceFwhm=False):
    """Make profile spectrum for given peaklist.
    
    Each peak is drawn as a gaussian peak.intensity * exp(-(x - mz)^2 / w^2),
    with w = fwhm/1.66, over +/- 5 fwhm around its m/z. The contributions are
    summed on the m/z raster and optional uniform noise is added. The result is
    then passed through mod_signal.subbase with a baseline built from the peaks'
    `base` values.
        peaklist (mspy.peaklist) - peaklist
        fwhm (float) - default peak fwhm
        points (int) - default number of points per peak width (not used if raster is given)
        noise (float) - random noise width
        raster (1D numpy array) - m/z raster; the given array is not modified
        forceFwhm (bool) - use default fwhm for all peaks
    
    Returns the array produced by mod_signal.subbase from the (n, 2)
    [m/z, intensity] data.
    """
    
    # check peaklist type
    if not isinstance(peaklist, obj_peaklist.peaklist):
        peaklist = obj_peaklist.peaklist(peaklist)
    
    # get fwhm range from peaks with a known width
    minFwhm = None
    maxFwhm = None
    if not forceFwhm:
        for peak in peaklist:
            if not peak.fwhm:
                continue
            if not minFwhm or peak.fwhm < minFwhm:
                minFwhm = peak.fwhm
            # the maximum only follows widths below twice the current maximum,
            # presumably to ignore outlier peaks
            if not maxFwhm or (peak.fwhm > maxFwhm and peak.fwhm < 2*maxFwhm):
                maxFwhm = peak.fwhm
    
    # use default fwhm range if not set
    if not minFwhm or not maxFwhm:
        minFwhm = fwhm
        maxFwhm = fwhm
    
    # get m/z raster
    # `is None` is required: for a numpy array `raster == None` compares
    # element-wise and its truth value is ambiguous (raises ValueError)
    if raster is None:
        mzRange = (peaklist[0].mz - 5*maxFwhm, peaklist[-1].mz + 5*maxFwhm)
        rasterRange = (minFwhm/points, maxFwhm/points)
        raster = _xraster(mzRange, rasterRange)
    else:
        raster = numpy.asarray(raster, dtype=float)
    
    # get intensity raster
    intensities = numpy.zeros(raster.size, float)
    
    # calculate gaussian peak for each peak
    for peak in peaklist:
        
        CHECK_FORCE_QUIT()
        
        # get peak fwhm
        if peak.fwhm and not forceFwhm:
            peakFwhm = peak.fwhm
        else:
            peakFwhm = fwhm
        
        # exp(-x^2/w^2) has FWHM = 2*sqrt(ln 2)*w ~= 1.66*w
        peakWidth = peakFwhm/1.66
        
        # add peak over its +/- 5 fwhm window (vectorized)
        i1 = mod_signal.locate(raster, (peak.mz-5.0*peakFwhm))
        i2 = mod_signal.locate(raster, (peak.mz+5.0*peakFwhm))
        window = raster[i1:i2]
        intensities[i1:i2] += peak.intensity*numpy.exp(-1*(window-peak.mz)**2/peakWidth**2)
    
    # add random noise
    if noise:
        intensities += numpy.random.uniform(-noise/2., noise/2., raster.size)
    
    # make final profile as a new (n, 2) array; the raster itself is not reshaped
    data = numpy.column_stack((raster, intensities))
    
    # make baseline (one point per distinct consecutive m/z)
    baseline = [[peaklist[0].mz, -peaklist[0].base]]
    for peak in peaklist[1:]:
        if baseline[-1][0] != peak.mz:
            baseline.append([peak.mz, -peak.base])
    baseline = numpy.array(baseline)
    
    # shift data baseline
    data = mod_signal.subbase(data, baseline)
    
    return data
# ----


def _groupIsotopes(isotopes, window):
    """Group peaks within specified window.
        isotopes: (list of [mass, abundance]) isotopes list
        window: (float) grouping window
    """
    
    isotopes.sort()
    
    buff = []
    buff.append(isotopes[0])
    
    for current in isotopes[1:]:
        previous = buff[-1]
        if (previous[0] + window) > current[0]:
            abundance = previous[1] + current[1]
            mass = (previous[0]*previous[1] + current[0]*current[1]) / abundance
            buff[-1] = [mass, abundance]
        else:
            buff.append(current)
    
    return buff
# ----


def _xraster(mzRange, rasterRange):
    """Make m/z raster as linear gradient for given m/z range and edge points differences."""
    
    m = (rasterRange[1] - rasterRange[0]) / (mzRange[1] - mzRange[0])
    b = rasterRange[0] - m * mzRange[0]
    
    size = ((mzRange[1] - mzRange[0]) / rasterRange[0]) + 2
    raster = numpy.zeros(int(size), float)
    
    i = 0
    x = mzRange[0]
    while x <= mzRange[1]:
        raster[i] = x
        x += m*x + b
        i += 1
    
    return raster[:i].copy()
# ----


def _normalize(data):
    """Normalize data."""
    
    # get maximum Y
    maximum = data[0][1]
    for item in data:
        if item[1] > maximum:
            maximum = item[1]
    
    # normalize data data
    for x in range(len(data)):
        data[x][1] /= maximum
    
    return data
# ----


