# -------------------------------------------------------------------------
#     Copyright (C) 2005-2011 Martin Strohalm <www.mmass.org>

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 3 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE.TXT in the
#     main directory of the program
# -------------------------------------------------------------------------

# load libs
import numpy
from copy import deepcopy

# load configuration
import config

# register essential objects
import objects

# register essential modules
import common
import processing


# AVERAGINE DISTRIBUTION
# ----------------------

# average element counts of one building block; presumably the averagine
# amino-acid residue (values match the published averagine model) - confirm
averageAmino = {
    'C': 4.9384,
    'H': 7.7583,
    'N': 1.3577,
    'O': 1.4773,
    'S': 0.0417,
}

# average element counts of one building block; presumably an average
# nucleotide unit (contains phosphorus) - confirm
averageBase = {
    'C': 9.75,
    'H': 12.25,
    'N': 3.75,
    'O': 6,
    'P': 1,
}

def averagePattern(mass, composition=averageAmino):
    """Model the isotopic distribution of a molecule built from repeated blocks.
        mass: (float) neutral mass to be modeled
        composition: (dict) element symbol -> count per building block
    Returns a list of the relative intensities (item [1]) of the peaks produced
    by compound.pattern() for the modeled molecule.
    """
    
    # element counts are rounded to integers by '%0.f', so the block is scaled
    # by 100 to keep enough precision when measuring its mass
    superBlock = ''.join('%s%0.f' % (el, n * 100) for el, n in composition.items())
    superMass = objects.compound(superBlock).mass()[0]
    
    # number of blocks fitting into the requested mass (at least one)
    blocks = int(max(1, (round(mass / superMass * 100))))
    
    # build the modeled molecule from that many blocks
    molecule = objects.compound(''.join('%s%0.f' % (el, n * blocks) for el, n in composition.items()))
    
    # keep only intensities of the calculated pattern
    return [peak[1] for peak in molecule.pattern(relIntThreshold=0.005, fwhm=0.2)]
# ----



# PEAK PICKING FUNCTIONS
# ----------------------

class peakPicking():
    """Peak picking module.
    
    Detects peaks in spectrum points (local maxima refined by a centroid model),
    filters them by intensity and signal-to-noise thresholds, optionally removes
    shoulder peaks and assigns charges/isotopes using a modeled isotopic
    distribution (see averagePattern).
    Points are indexed as point[0] = m/z and point[1] = intensity.
    """
    
    def __init__(self):
        
        # common threshold parameters
        self.pickingHeight = 0.75 # (float) peak intensity height where the mass is picked in %/100
        self.snThreshold = 2.5 # (float) signal to noise threshold for picked peaks
        self.relIntThreshold = 0.0 # (float) relative basepeak intensity threshold for picked peaks in %/100
        self.absIntThreshold = 0.0 # (float) absolute intensity threshold for picked peaks
        
        # adaptive noise parameters
        self.adaptiveNoise = False # (bool) calculate noise level separately for each peak
        self.noiseWindow = 0.1 # (float) relative m/z window to use for noise calculations
        self.noiseWindowTolerance = 0.01 # (float) relative m/z interval within a same noise is used
        
        # deisotoping parameters
        self.maxCharge = 1 # (int) maximum charge searched (negative value searches negative charges)
        self.isotopeMassTolerance = 0.15 # (float) m/z tolerance for the next isotope
        self.isotopeIntTolerance = 0.5 # (float) intensity tolerance for the next isotope in %/100
        
        # shoulder peaks removal
        self.shoulderRelThreshold = 0.05 # (float) relative intensity threshold of shoulder/parent peak in %/100
        self.shoulderWindow = 2.5 # (float) fwhm window multiplier for shoulder peaks
        
        # internal parameters
        self._buildingBlock = averageAmino # building block composition used by averagePattern
        self._patterns = {} # cache of modeled patterns keyed by mass rounded to hundreds
    # ----
    
    
    def setBuildingBlock(self, composition):
        """Set building block composition for isotopic distribution modelling.
            composition: (dict) element symbol -> count per building block
        Clears the cached patterns, since they depend on the composition.
        """
        
        self._buildingBlock = deepcopy(composition)
        self._patterns = {}
    # ----
    
    
    def labelScan(self, points, calculateCharge=True, removeIsotopes=True, removeUnknown=True, removeShoulders=False):
        """Find peaks in given spectrum points.
            points: (numpy.array) spectrum points
            calculateCharge: (bool) calculate peak charges
            removeIsotopes: (bool) remove isotopes (used only if calculateCharge)
            removeUnknown: (bool) remove unknown peaks (used only if calculateCharge)
            removeShoulders: (bool) remove shoulder peaks around intense peaks
        Returns (mspy.peaklist) picked peaks with fwhm filled in.
        """
        
        # get possible peaks by local maxima
        peaklist = self._getLocalMax(points)
        
        # calculate masses by centroid model
        peaklist = self._getCentroids(points, peaklist)
        
        # filter peaklist
        peaklist = self.filterPeaklist(peaklist)
        
        # filter shoulder peaks
        if removeShoulders:
            peaklist = self.removeShoulderPeaks(points, peaklist)
        
        # calculate charge
        if calculateCharge:
            peaklist = self.calculateCharge(peaklist)
            peaklist = self.removeIsotopes(peaklist, removeIsotopes, removeUnknown)
        
        # calculate resolution (peak width at half height above baseline)
        for peak in peaklist:
            height = peak.baseline + (peak.intensity - peak.baseline) * 0.5
            peak.fwhm = self._getPeakWidth(points, peak.mz, height)
        
        return peaklist
    # ----
    
    
    def labelPeak(self, points, startMZ, stopMZ):
        """Find one peak in given m/z range.
            points: (numpy.array) spectrum points
            startMZ: (float) starting mz value
            stopMZ: (float) ending mz value
        Returns (mspy.peak) the most intense peak in the range, or None when the
        range contains no points or the peak does not rise above its baseline.
        """
        
        # get current selection
        i1 = self._getIndex(points, startMZ)
        i2 = self._getIndex(points, stopMZ)
        selection = points[i1:i2]
        
        # nothing to label in an empty range
        if len(selection) == 0:
            return None
        
        # get local max for current selection
        localMax = [0,0]
        for point in selection:
            if point[1] > localMax[1]:
                localMax = point
        
        # get noise
        if self.adaptiveNoise:
            noiseLevel, noiseWidth = processing.noise(points, mz=localMax[0], window=self.noiseWindow)
        else:
            noiseLevel, noiseWidth = processing.noise(points)
        
        # get baseline
        baseline = 0.
        if noiseLevel:
            baseline = noiseLevel
        
        # make peak
        peak = objects.peak(mz=localMax[0], intensity=localMax[1], baseline=baseline)
        
        # get centroid (keeps the raw maximum if centroiding rejects the peak)
        peaklist = self._getCentroids(points, [peak])
        if peaklist:
            peak = peaklist[0]
        
        # check peak
        if peak.intensity <= peak.baseline:
            return None
        
        # get resolution
        height = peak.baseline + (peak.intensity - peak.baseline) * 0.5
        peak.fwhm = self._getPeakWidth(points, peak.mz, height)
        
        return peak
    # ----
    
    
    def labelPoint(self, points, mz):
        """Find peak for given m/z value.
            points: (numpy.array) spectrum points
            mz: (float) m/z value
        Returns (mspy.peak) peak at the interpolated intensity, or None when mz
        lies outside the points or the intensity is not above baseline.
        """
        
        # get intensity; both neighbours are needed for interpolation and
        # _getIndex returns len(points) for mz at/after the last point
        i = self._getIndex(points, mz)
        if not (0 < i < len(points)):
            return None
        intensity = self._interpolateLine(points[i], points[i-1], mz)
        
        # get noise
        if self.adaptiveNoise:
            noiseLevel, noiseWidth = processing.noise(points, mz=mz, window=self.noiseWindow)
        else:
            noiseLevel, noiseWidth = processing.noise(points)
        
        # get baseline
        baseline = 0.
        if noiseLevel:
            baseline = noiseLevel
        
        # get S/N
        sn = None
        if noiseWidth:
            sn = (intensity-noiseLevel) / noiseWidth
            sn = round(sn,3)
        
        # check peak
        if intensity <= baseline:
            return None
        
        # get resolution
        height = baseline + (intensity - baseline) * 0.5
        fwhm = self._getPeakWidth(points, mz, height)
        
        # make peak
        peak = objects.peak(mz=mz, intensity=intensity, baseline=baseline, sn=sn, fwhm=fwhm)
        
        return peak
    # ---
    
    
    def calculateCharge(self, peaklist):
        """Calculate charges for peaks of given peaklist by isotopes differences and averagine distribution.
            peaklist: (mspy.peaklist) peaklist
        Returns a deep copy of the peaklist with charge/isotope set on every
        peak assigned to an isotopic cluster (None elsewhere).
        """
        
        # check peaklist and clear previous results
        if not isinstance(peaklist, objects.peaklist):
            peaklist = objects.peaklist(peaklist)
        peaklist = deepcopy(peaklist)
        for peak in peaklist:
            peak.charge = None
            peak.isotope = None
        
        # charges to try, highest absolute value first, sign taken from maxCharge
        sign = -1 if self.maxCharge < 0 else 1
        charges = [sign * z for z in range(abs(self.maxCharge), 0, -1)]
        
        # walk in peaklist
        for x in range(len(peaklist)):
            
            # skip peaks already assigned to a previous cluster
            if peaklist[x].isotope is not None:
                continue
            
            # try all charges
            cluster = [peaklist[x]]
            for z in charges:
                isotope = 0
                
                # get isotopic pattern (cached per mass rounded to hundreds)
                mass = common.mz(peaklist[x].mz, 0, z)
                key = round(mass,-2)
                if key in self._patterns:
                    pattern = self._patterns[key]
                else:
                    pattern = averagePattern(mass, self._buildingBlock)
                    self._patterns[key] = pattern
                
                # expected m/z spacing of neighbouring isotopes for this charge
                spacing = 1.00287/abs(z)
                
                # search for next isotope within m/z and intensity tolerance;
                # peaks are scanned while they are not beyond the tolerance window
                y = 1
                while x+y < len(peaklist) and ((peaklist[x+y].mz - cluster[-1].mz) - spacing) <= self.isotopeMassTolerance:
                    candidate = peaklist[x+y]
                    if abs((candidate.mz - cluster[-1].mz) - spacing) <= self.isotopeMassTolerance:
                        isotope += 1
                        
                        # no modeled abundance for this isotope
                        if not pattern or len(pattern) <= isotope:
                            break
                        
                        # expected intensity scaled from the last cluster member by the pattern ratio
                        calcIntens = ((cluster[-1].intensity-cluster[-1].baseline) / pattern[isotope-1]) * pattern[isotope]
                        error = candidate.intensity - candidate.baseline - calcIntens
                        
                        if abs(error) <= (calcIntens * self.isotopeIntTolerance):
                            candidate.isotope = isotope
                            candidate.charge = z
                            cluster.append(candidate)
                        # outside tolerance: still extend the cluster (unlabelled) when the
                        # peak is more intense than expected, or when it is not the first isotope
                        elif error > 0 or (error < 0 and isotope != 1):
                            cluster.append(candidate)
                        
                    y += 1
                
                # skip other charges if one isotope at least was found
                if len(cluster) > 1:
                    peaklist[x].charge = z
                    peaklist[x].isotope = 0
                    break
        
        # prevent patterns overload
        if len(self._patterns) > 100:
            self._patterns = {}
        
        return peaklist
    # ----
    
    
    def filterPeaklist(self, peaklist):
        """Remove peaks below threshold.
            peaklist: (mspy.peaklist) peaklist to be filtered
        Keeps peaks whose baseline-corrected intensity reaches the larger of the
        relative (to base peak) and absolute thresholds and whose S/N is unknown
        or reaches snThreshold.
        """
        
        # check peaklist
        if len(peaklist) == 0:
            return peaklist
        if not isinstance(peaklist, objects.peaklist):
            peaklist = objects.peaklist(peaklist)
        
        # get absolute threshold
        threshold = (peaklist.basePeak.intensity - peaklist.basePeak.baseline) * self.relIntThreshold
        threshold = max(threshold, self.absIntThreshold)
        
        # check peaks
        buff = []
        for peak in peaklist:
            if (peak.intensity - peak.baseline) >= threshold \
                and (peak.sn is None or peak.sn >= self.snThreshold):
                buff.append(peak)
        
        return objects.peaklist(buff)
    # ----
    
    
    def removeShoulderPeaks(self, points, peaklist):
        """Remove shoulder peaks from FTMS data.
            points: (numpy.array) spectrum points
            peaklist: (mspy.peaklist) peaklist to be filtered
        Peaks lying within shoulderWindow * fwhm of a strong parent peak and
        weaker than shoulderRelThreshold of it are deleted. Missing parent fwhm
        values are calculated and stored on the parent peaks.
        """
        
        # check points
        if len(points) == 0:
            return peaklist
        
        # check peaklist
        if len(peaklist) == 0:
            return peaklist
        if not isinstance(peaklist, objects.peaklist):
            peaklist = objects.peaklist(peaklist)
        
        # get possible parent peaks; presumably only parents whose relative
        # shoulder level would still exceed S/N 3 are considered
        candidates = [peak for peak in peaklist if not peak.sn or peak.sn*self.shoulderRelThreshold > 3]
        
        # filter shoulder peaks
        remove = []
        removed = set() # fast membership test, 'remove' keeps discovery order
        for parent in candidates:
            
            # get fwhm
            if not parent.fwhm:
                height = parent.baseline + (parent.intensity - parent.baseline) * 0.5
                parent.fwhm = self._getPeakWidth(points, parent.mz, height)
            if not parent.fwhm:
                continue
            
            # get shoulder range
            lowMZ = parent.mz - parent.fwhm * self.shoulderWindow
            highMZ = parent.mz + parent.fwhm * self.shoulderWindow
            intThreshold = (parent.intensity - parent.baseline) * self.shoulderRelThreshold
            
            # remember shoulder peaks to be removed
            for x, peak in enumerate(peaklist):
                if (lowMZ < peak.mz < highMZ) and ((peak.intensity - peak.baseline) < intThreshold) and (x not in removed):
                    removed.add(x)
                    remove.append(x)
                if peak.mz > highMZ:
                    break
        
        # remove peaks
        peaklist.delete(indexes=remove)
        return peaklist
    # ----
    
    
    def removeIsotopes(self, peaklist, isotopes=True, unknown=True):
        """Remove isotopes and unknown peaks.
            peaklist: (mspy.peaklist) peaklist with charges calculated
            isotopes: (bool) drop peaks labelled as non-monoisotopic (isotope != 0)
            unknown: (bool) drop peaks without assigned charge
        Returns the input unchanged when both flags are False, otherwise a new
        mspy.peaklist.
        """
        
        # do nothing
        if not isotopes and not unknown:
            return peaklist
        
        # remove isotopes
        if isotopes:
            peaklist = [peak for peak in peaklist if peak.isotope == 0 or peak.charge is None]
        
        # remove unknown peaks
        if unknown:
            peaklist = [peak for peak in peaklist if peak.charge is not None]
        
        return objects.peaklist(peaklist)
    # ----
    
    
    def _getPeakWidth(self, points, mz, intensity):
        """Get peak width for given m/z and height.
            points: (numpy.array) spectrum points
            mz: (float) peak m/z value
            intensity: (float) intensity of measurement
        Returns (float) distance between the interpolated left and right
        crossings of the given intensity, or None when mz is outside points.
        """
        
        # check data
        if len(points) == 0:
            return None
        
        # get mz index (first point to the right of mz)
        index = self._getIndex(points, mz)
        if not index or index == len(points):
            return None
        
        # walk left from the point at/below mz until intensity drops to the level
        leftIndx = index - 1
        while leftIndx > 0:
            if points[leftIndx][1] <= intensity:
                break
            leftIndx -= 1
        
        # walk right from the point above mz until intensity drops to the level
        rightIndx = index
        while rightIndx < (len(points)-1):
            if points[rightIndx][1] <= intensity:
                break
            rightIndx += 1
        
        # get mz
        leftMZ = self._interpolateLine(points[leftIndx], points[leftIndx+1], y=intensity)
        rightMZ = self._interpolateLine(points[rightIndx-1], points[rightIndx], y=intensity)
        
        return rightMZ - leftMZ
    # ----
    
    
    def _getLocalMax(self, points):
        """Get possible peaks from the points using local maxima.
            points: (numpy.array) spectrum points
        Returns (mspy.peaklist) local maxima passing intensity and S/N thresholds.
        """
        
        # nothing to pick in an empty spectrum
        if len(points) == 0:
            return objects.peaklist([])
        
        buff = []
        localMax = points[0]
        growing = True
        noiseLevel, noiseWidth = processing.noise(points)
        previous = None
        basePeak = 0
        
        # get local maxima
        for point in points[1:]:
            
            # remember last maximum
            if localMax[1] <= point[1]:
                localMax = point
                growing = True
            
            # store local maximum at the first descending point
            elif growing and localMax[1] > point[1]:
                peak = [localMax[0], localMax[1], 0., None] # mz, intensity, baseline, s/n
                
                # use adaptive baseline, recalculated only after moving far enough
                if self.adaptiveNoise and (not previous or ((peak[0] - previous) > (peak[0] * self.noiseWindowTolerance))):
                    noiseLevel, noiseWidth = processing.noise(points, mz=peak[0], window=self.noiseWindow)
                    previous = peak[0]
                
                # set baseline and sn
                if noiseLevel is not None:
                    peak[2] = noiseLevel
                    if noiseWidth:
                        peak[3] = round(((localMax[1]-noiseLevel) / noiseWidth), 3)
                
                # store peak
                buff.append(peak)
                basePeak = max(basePeak, peak[1]-peak[2])
                localMax = point
                growing = False
                
            else:
                localMax = point
        
        # filter peaklist
        threshold = max(basePeak * self.relIntThreshold, self.absIntThreshold)
        peaklist = []
        for peak in buff:
            if (peak[1] - peak[2]) >= threshold and (peak[3] is None or peak[3] >= self.snThreshold):
                peaklist.append(objects.peak(peak[0], peak[1], baseline=peak[2], sn=peak[3]))
        
        return objects.peaklist(peaklist)
    # ----
    
    
    def _getCentroids(self, points, peaklist):
        """Make centroided peaks for given peaklist.
            points: (numpy.array) spectrum points
            peaklist: (mspy.peaklist or list) peaks to centroid (modified in place)
        Each peak m/z is set to the middle of the interpolated crossings at
        pickingHeight; overlapping peaks are merged keeping the more intense.
        Returns (mspy.peaklist) centroided peaks with baseline and s/n set.
        """
        
        # check peaklist
        if not isinstance(peaklist, objects.peaklist):
            peaklist = objects.peaklist(peaklist)
        if len(peaklist) == 0:
            return peaklist
        
        # walk in peaklist
        buff = []
        previous = None
        for peak in peaklist:
            height = ((peak.intensity-peak.baseline) * self.pickingHeight) + peak.baseline
            
            # get indexes
            index = self._getIndex(points, peak.mz)
            if not (0 < index < len(points)):
                continue
            
            leftIndx = index-1
            while leftIndx > 0:
                if points[leftIndx][1] <= height:
                    break
                leftIndx -= 1
            
            rightIndx = index
            while rightIndx < (len(points)-1):
                if points[rightIndx][1] <= height:
                    break
                rightIndx += 1
            
            # get mz
            leftMZ = self._interpolateLine(points[leftIndx], points[leftIndx+1], y=height)
            rightMZ = self._interpolateLine(points[rightIndx-1], points[rightIndx], y=height)
            peak.mz = (leftMZ + rightMZ)/2
            
            # get intensity
            index = self._getIndex(points, peak.mz)
            if not (0 < index < len(points)):
                continue
            
            intensity = self._interpolateLine(points[index-1], points[index], x=peak.mz)
            if intensity <= peak.intensity:
                peak.intensity = intensity
            else:
                continue
            
            # try to group with previous peak (overlapping picking ranges)
            if previous is not None and leftMZ < previous:
                if peak.intensity > buff[-1].intensity:
                    buff[-1] = peak
                    previous = rightMZ
            else:
                buff.append(peak)
                previous = rightMZ
        
        # calculate baselines and s/n
        previous = None
        noiseLevel, noiseWidth = processing.noise(points)
        for peak in buff:
            if self.adaptiveNoise and (not previous or ((peak.mz - previous) > (peak.mz * self.noiseWindowTolerance))):
                noiseLevel, noiseWidth = processing.noise(points, mz=peak.mz, window=self.noiseWindow)
                previous = peak.mz
            if noiseLevel is not None:
                peak.baseline = noiseLevel
                if noiseWidth:
                    peak.sn = round(((peak.intensity-noiseLevel) / noiseWidth),3)
        
        return objects.peaklist(buff)
    # ----
    
    
    def _getIndex(self, points, x):
        """Get index of the first point whose m/z is greater than x.
            points: (numpy.array) spectrum points sorted by m/z
            x: (float) m/z value
        Returns (int) 0..len(points); same semantics as bisect.bisect_right.
        """
        
        lo = 0
        hi = len(points)
        while lo < hi:
            # floor division keeps mid an integer on both Python 2 and 3
            mid = (lo + hi) // 2
            if x < points[mid][0]:
                hi = mid
            else:
                lo = mid + 1
        
        return lo
    # ----
    
    
    def _interpolateLine(self, p1, p2, x=None, y=None):
        """Get line interpolated X or Y value.
            p1, p2: (point) two points defining the line
            x: (float) x value to get y for
            y: (float) y value to get x for
        For vertical segments (same x) returns max y, or p1's x respectively.
        Returns None if neither x nor y is given.
        """
        
        # check points
        if p1[0] == p2[0] and x is not None:
            return max(p1[1], p2[1])
        elif p1[0] == p2[0] and y is not None:
            return p1[0]
        
        # get equation
        m = (p2[1] - p1[1])/(p2[0] - p1[0])
        b = p1[1] - m * p1[0]
        
        # get point
        if x is not None:
            return m * x + b
        elif y is not None:
            return (y - b) / m
    # ----
    
    

