# -------------------------------------------------------------------------
#     Copyright (C) 2005-2011 Martin Strohalm <www.mmass.org>

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 3 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE.TXT in the
#     main directory of the program
# -------------------------------------------------------------------------

#load libs
import re
import numpy
from copy import copy, deepcopy

# load configuration
import config

# register essential objects
import blocks

# register essential modules
import common
import processing


# compile basic patterns

# formulaPattern validates a whole raw formula: a sequence of groups, each made
# of optional opening brackets, one or more atoms (element symbol, optional
# {mass number} isotope and optional count, which may be negative as in H-1)
# and optional closing brackets with an optional multiplier. An empty string
# also matches. Bracket balance is not enforced here; compound checks it.
formulaPattern = re.compile(r'''
    ^(
        ([\(])* # start parenthesis
        (
            ([A-Z][a-z]{0,2}) # atom symbol
            (\{[\d]+\})? # isotope
            (([\-][\d]+)|[\d]*) # atom count
        )+
        ([\)][\d]*)* # end parenthesis
    )*$
''', re.X)

# elementPattern tokenizes single atoms; findall() yields tuples of
# (symbol, isotope mass number or '', count string or '').
elementPattern = re.compile(r'''
            ([A-Z][a-z]{0,2}) # atom symbol
            (?:\{([\d]+)\})? # isotope
            ([\-]?[\d]*) # atom count
''', re.X)


# BASIC OBJECTS DEFINITION
# ------------------------

class compound:
    """Compound object definition.
    
    The raw formula may contain element symbols, explicit isotopes written
    as Symbol{massNumber} (e.g. C{13}), atom counts including negative ones
    (e.g. H-1) and brackets followed by an optional multiplier
    (e.g. (CH2)3). Composition and mass are computed on first request and
    cached.
    """
    
    def __init__(self, rawFormula):
        """Validate and store the formula.
            rawFormula: (str) formula to parse
        Raises ValueError for formulas not matching formulaPattern, for
        unknown elements or isotopes and for unbalanced or misordered
        brackets.
        """
        
        self._mass = None # cached (monoisotopic, average) mass
        self._composition = None # cached {atom: count} dict
        
        self.rawFormula = rawFormula
        
        # check formula
        if not formulaPattern.match(self.rawFormula):
            raise ValueError('Wrong formula! --> ' + self.rawFormula)
        
        # check elements and isotopes
        for symbol, isotope, count in elementPattern.findall(self.rawFormula):
            if not symbol in blocks.elements:
                raise ValueError('Unknown element in formula! --> ' + symbol + ' in ' + self.rawFormula)
            elif isotope and not isotope in blocks.elements[symbol].isotopes:
                raise ValueError('Unknown isotope in formula! --> ' + symbol + isotope + ' in ' + self.rawFormula)
        
        # check brackets - comparing counts alone would accept e.g. 'C)(H',
        # which passes formulaPattern and is then silently mangled by
        # _unfoldBrackets, so also require that no ')' closes a bracket
        # which was never opened
        depth = 0
        for char in self.rawFormula:
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth < 0:
                    break
        if depth != 0:
            raise ValueError('Wrong number of brackets in formula! --> ' + self.rawFormula)
        
        # user defined params
        self.userParams = {}
    # ----
    
    
    def formula(self):
        """Get formula with atoms in alphabetical order.
        Counts of 1 are omitted, e.g. {'C': 2, 'H': 1} -> 'C2H'.
        """
        
        # get composition
        comp = self.composition()
        
        # format composition
        buff = ''
        for el in sorted(comp.keys()):
            if comp[el] == 1:
                buff += el
            else:
                buff += '%s%d' % (el, comp[el])
        
        return buff
    # ----
    
    
    def composition(self):
        """Get elemental composition as dict {atom: count}.
        Isotopes are kept as separate atoms written as Symbol{massNumber};
        atoms with zero total count are dropped. The returned dict is the
        cached internal one.
        """
        
        # check composition buffer
        if self._composition is not None:
            return self._composition
        
        # unfold brackets
        unfoldedFormula = self._unfoldBrackets(self.rawFormula)
        
        # group same elements
        composition = {}
        for symbol, isotop, count in elementPattern.findall(unfoldedFormula):
            
            # make atom
            if isotop:
                atom = '%s{%s}' % (symbol, isotop)
            else:
                atom = symbol
            
            # convert counting (missing count means one atom)
            if count:
                count = int(count)
            else:
                count = 1
            
            # add atom
            composition[atom] = composition.get(atom, 0) + count
        
        # remove zeros (collected first - never delete while iterating a dict)
        for atom in [a for a in composition if composition[a] == 0]:
            del composition[atom]
        
        self._composition = composition
        return self._composition
    # ----
    
    
    def mass(self, massType=None):
        """Get mass.
            massType: monoisotopic (0 or 'mo') or average (1, 'av');
                any other value returns (monoisotopic, average) tuple
        """
        
        # calculate mass on first request
        if self._mass is None:
            massMo = 0
            massAv = 0
            
            # get mass for each atom
            for atom, count in self.composition().items():
                
                # check specified isotope and mass
                symbol, massNumber, tmp = elementPattern.match(atom).groups()
                if massNumber:
                    # explicit isotope - the same value isotope[0] is used
                    # for both monoisotopic and average mass
                    isotope = blocks.elements[symbol].isotopes[massNumber]
                    atomMass = (isotope[0], isotope[0])
                else:
                    atomMass = blocks.elements[symbol].mass
                
                # multiply atom
                massMo += atomMass[0]*count
                massAv += atomMass[1]*count
            
            # store mass in buffer
            self._mass = (massMo, massAv)
        
        # return mass
        if massType == 0 or massType == 'mo':
            return self._mass[0]
        elif massType == 1 or massType == 'av':
            return self._mass[1]
        else:
            return (self._mass[0], self._mass[1])
    # ----
    
    
    def mz(self, charge, agentFormula='H', agentCharge=1):
        """Get ion m/z calculated by common.mz() from (mo, av) mass tuple."""
        
        # get current mass and calculate mz
        return common.mz(self.mass(), charge, agentFormula=agentFormula, agentCharge=agentCharge)
    # ----
    
    
    def pattern(self, fwhm=0.1, relIntThreshold=0.01, charge=0, agentFormula='H', agentCharge=1, rounding=7):
        """Get isotopic pattern; all parameters are passed to common.pattern()."""
        
        return common.pattern(self,
            fwhm=fwhm,
            relIntThreshold=relIntThreshold,
            charge=charge,
            agentFormula=agentFormula,
            agentCharge=agentCharge,
            rounding=rounding)
    # ----
    
    
    def validate(self, charge=0, agentFormula='H', agentCharge=1):
        """Utility to check ion composition.
            charge: (int) ion charge
            agentFormula: (str or compound) charging agent, 'e' for electrons
            agentCharge: (int) charge of the agent
        Returns False if the ion (compound with charge/agentCharge agents
        added) contains a negative count of any atom, True otherwise.
        """
        
        # check agent formula
        if agentFormula != 'e' and not isinstance(agentFormula, compound):
            agentFormula = compound(agentFormula)
        
        # make ion compound; electrons do not change the composition
        if charge and agentFormula != 'e':
            ionFormula = self.rawFormula
            for atom, count in agentFormula.composition().items():
                # NOTE: for ints '/' is floor division under Python 2
                ionFormula += '%s%d' % (atom, count*(charge/agentCharge))
            ion = compound(ionFormula)
        else:
            ion = compound(self.rawFormula)
        
        # get composition
        for atom, count in ion.composition().items():
            if count < 0:
                return False
        
        return True
    # ----
    
    
    def _unfoldBrackets(self, string):
        """Return formula with all (nested) brackets expanded.
        Bracket content is repeated by the multiplier following ')',
        e.g. '(CH2)2O' -> 'CH2CH2O'.
        """
        
        unfoldedFormula = ''
        brackets = [0,0] # [opened, closed] within current top-level bracket
        enclosedFormula = ''
        
        i = 0
        while i < len(string):
            
            # handle brackets
            if string[i] == '(':
                brackets[0] += 1
            elif string[i] == ')':
                brackets[1] += 1
            
            # part outside brackets
            if brackets == [0,0]:
                unfoldedFormula += string[i]
            
            # part within brackets
            else:
                enclosedFormula += string[i]
                
                # unfold part within brackets once the top-level bracket closes
                if brackets[0] == brackets[1]:
                    enclosedFormula = self._unfoldBrackets(enclosedFormula[1:-1])
                    
                    # multiply part within brackets
                    count = ''
                    while len(string)>(i+1) and string[i+1].isdigit():
                        count += string[i+1]
                        i += 1
                    if count:
                        enclosedFormula = enclosedFormula * int(count)
                    
                    # add and clear
                    unfoldedFormula += enclosedFormula
                    enclosedFormula = ''
                    brackets = [0,0]
            
            i += 1
        return unfoldedFormula
    # ----
    


class sequence:
    """Sequence object definition."""
    
    def __init__(self, chain, title='', nTermFormula='H', cTermFormula='OH', lossFormula=''):
        """Create sequence object.
            chain: (str) one-letter amino acid chain; it is upper-cased and
                whitespace, '-', '*' and '.' characters are removed
            title: (str) sequence title
            nTermFormula: (str) N-terminal group formula
            cTermFormula: (str) C-terminal group formula
            lossFormula: (str) neutral loss formula (used for fragments)
        Raises ValueError if the chain contains an unknown amino acid.
        """
        
        # cached values
        self._mass = None # (monoisotopic, average)
        self._composition = None
        
        # normalize chain
        normalized = chain.upper()
        for ignored in ('\t', '\n', '\r', '\f', '\v', ' ', '-', '*', '.'):
            normalized = normalized.replace(ignored, '')
        self.chain = normalized.upper()
        
        # terminal groups, modifications and labels
        # (modifications/labels items: [name, position=[#|symbol], state[f|v]], f-fixed, v-variable)
        self.nTermFormula = nTermFormula
        self.cTermFormula = cTermFormula
        self.modifications = []
        self.labels = []
        
        # protein info
        self.title = title
        self.accession = None
        self.orgName = None
        self.pi = None
        self.score = None
        
        # peptide info
        self.userRange = []
        self.aaBefore = ''
        self.aaAfter = ''
        self.miscleavages = 0
        
        # fragment info
        self.fragmentSerie = None
        self.fragmentIndex = None
        self.fragmentFiltered = False
        self.lossFormula = lossFormula
        
        # user defined params
        self.userParams = {}
        
        # every residue must be a known amino acid
        unknown = [amino for amino in self.chain if amino not in blocks.aminoacids]
        if unknown:
            raise ValueError('Unknown amino acid in sequence! --> ' + unknown[0])
    # ----
    
    
    def __len__(self):
        """Get sequence length."""
        return len(self.chain)
    # ----
    
    
    def __getitem__(self, i):
        return self.chain[i]
    # ----
    
    
    def __getslice__(self, start, stop):
        """Get slice of the sequence as a new sequence object.
            start: (int) index of the first residue
            stop: (int) index after the last residue
        Modifications and labels falling into the slice are copied (integer
        positions are shifted to the new chain), residues adjacent to the
        slice are stored in aaBefore/aaAfter, and parent terminal formulas are
        kept only for slice ends coinciding with the parent chain ends.
        Raises ValueError if stop < start.
        """
        
        # check slice
        if stop < start:
            raise ValueError('Invalid sequence slice definition!')
        
        # clip slice so that 0 <= start <= stop <= len; a slice lying outside
        # the chain now gives an empty peptide instead of an IndexError when
        # looking up the adjacent residues
        length = len(self.chain)
        stop = max(0, min(stop, length))
        start = min(max(start, 0), stop)
        
        # break the links - only the mutable lists handed to the new peptide
        # need copying (one deepcopy call keeps shared items shared), which is
        # much cheaper than deep-copying the whole parent on every slice
        parentMods, parentLabels = deepcopy((self.modifications, self.labels))
        
        # make new sequence object
        peptide = sequence(self.chain[start:stop])
        
        # add modifications
        for mod in parentMods:
            if type(mod[1]) == int and mod[1] >= start and mod[1] < stop:
                mod[1] -= start
                peptide.modifications.append(mod)
            elif type(mod[1]) in (str, unicode) and mod[1] in peptide.chain:
                peptide.modifications.append(mod)
        
        # add labels
        for mod in parentLabels:
            if type(mod[1]) == int and mod[1] >= start and mod[1] < stop:
                mod[1] -= start
                peptide.labels.append(mod)
            elif type(mod[1]) in (str, unicode) and mod[1] in peptide.chain:
                peptide.labels.append(mod)
        
        # set range in user coordinates (1-based, inclusive)
        peptide.userRange = [start+1, stop]
        
        # set adjacent amino acids
        if start > 0:
            peptide.aaBefore = self.chain[start-1]
        if stop < length:
            peptide.aaAfter = self.chain[stop]
        
        # add terminal modifications
        if start == 0:
            peptide.nTermFormula = self.nTermFormula
        if stop >= length:
            peptide.cTermFormula = self.cTermFormula
        
        return peptide
    # ----
    
    
    def __setslice__(self, start, stop, value):
        """Insert sequence object."""
        
        # check slice
        if stop < start:
            raise ValueError, 'Invalid slice!'
        
        # check value
        if not isinstance(value, sequence):
            raise TypeError, 'Invalid object to instert!'
        
        # brake the links
        value = deepcopy(value)
        
        # delete slice
        if stop != start:
            del(self[start:stop])
        
        # insert sequence
        self.chain = self.chain[:start] + value.chain + self.chain[start:]
        
        # shift modifications
        for x, mod in enumerate(self.modifications):
            if type(mod[1]) == int and mod[1] >= start:
                self.modifications[x][1] += (len(value))
        
        # shift labels
        for x, mod in enumerate(self.labels):
            if type(mod[1]) == int and mod[1] >= start:
                self.labels[x][1] += (len(value))
        
        # insert modifications
        for mod in value.modifications:
            if type(mod[1]) == int:
                mod[1] += start
            self.modifications.append(mod)
        
        # insert labels
        for mod in value.labels:
            if type(mod[1]) == int:
                mod[1] += start
            self.labels.append(mod)
        
        # clear
        self._mass = None
        self._composition = None
        self.userRange = []
        self.aaBefore = ''
        self.aaAfter = ''
        self.miscleavages = 0
    # ----
    
    
    def __delslice__(self, start, stop):
        """Delete slice of the sequence"""
        
        # check slice
        if stop < start:
            raise ValueError, 'Invalid slice!'
        
        # remove sequence
        self.chain = self.chain[:start] + self.chain[stop:]
        
        # remove modifications
        keep = []
        for mod in self.modifications:
            if type(mod[1]) == int and (mod[1] < start or mod[1] >= stop):
                if mod[1] >= stop:
                    mod[1] -= (stop - start)
                keep.append(mod)
            elif type(mod[1]) in (str, unicode) and mod[1] in self.chain:
                keep.append(mod)
        self.modifications = keep
        
        # remove labels
        keep = []
        for mod in self.labels:
            if type(mod[1]) == int and (mod[1] < start or mod[1] >= stop):
                if mod[1] >= stop:
                    mod[1] -= (stop - start)
                keep.append(mod)
            elif type(mod[1]) in (str, unicode) and mod[1] in self.chain:
                keep.append(mod)
        self.labels = keep
        
        # clear
        self._mass = None
        self._composition = None
        self.userRange = []
        self.aaBefore = ''
        self.aaAfter = ''
        self.miscleavages = 0
    #----
    
    
    def __add__(self, value):
        """Return a new sequence made of this sequence followed by value.
        The result takes its C-terminal formula and neutral loss from value.
        Raises TypeError if value is not sequence.
        """
        
        if not isinstance(value, sequence):
            raise TypeError('Invalid object to join with sequence!')
        
        # copy self and append value at its end
        joined = self[:]
        joined[len(joined):] = value
        
        # C-terminus and neutral loss come from the appended sequence
        joined.cTermFormula = value.cTermFormula
        joined.lossFormula = value.lossFormula
        
        # reset cached and peptide-specific values
        joined._mass = joined._composition = None
        joined.userRange = []
        joined.aaBefore = joined.aaAfter = ''
        joined.miscleavages = 0
        
        return joined
    # ----
    
    
    def __iter__(self):
        self._index = 0
        return self
    # ----
    
    
    def next(self):
        if self._index < len(self.chain):
            self._index += 1
            return self.chain[self._index-1]
        else:
            raise StopIteration
    # ----
    
    
    def count(self, item):
        """Count item in the chain."""
        return self.chain.count(item)
    # ----
    
    
    def formula(self):
        """Get formula."""
        
        # get composition
        comp = self.composition()
        
        # format composition
        buff = ''
        for el in sorted(comp.keys()):
            if comp[el] == 1:
                buff += el
            else:
                buff += '%s%d' % (el, comp[el])
        
        return buff
    # ----
    
    
    def composition(self):
        """Get elemental composition as dict {atom: count}.
        
        Sums residue formulas, gain formulas of modifications and labels
        (multiplied by the number of matching residues for residue-symbol
        positions) and terminal formulas, then subtracts modification/label
        loss formulas and the fragment neutral loss. Atoms with zero count
        are dropped; atoms present only in the subtracted part get negative
        counts. The result is cached.
        """
        
        # check composition buffer
        if self._composition is not None:
            return self._composition
        
        plusFormula = ''
        minusFormula = ''
        
        # add amino acids to formula
        for amino in self.chain:
            plusFormula += blocks.aminoacids[amino].formula
        
        # add modifications and labels
        mods = self.modifications + self.labels
        for name, position, state in mods:
            count = 1
            if type(position) in (str, unicode) and position !='':
                count = self.chain.count(position)
            plusFormula += count*blocks.modifications[name].gainFormula
            minusFormula += count*blocks.modifications[name].lossFormula
        
        # add terminal modifications
        plusFormula += self.nTermFormula + self.cTermFormula
        
        # subtract neutral loss for fragments
        minusFormula += self.lossFormula
        
        # get compositions
        composition = compound(plusFormula).composition()
        minusComposition = compound(minusFormula).composition()
        
        # subtract minus formula - an atom missing in the plus part must
        # become negative, not be added with a positive count
        for atom in minusComposition:
            composition[atom] = composition.get(atom, 0) - minusComposition[atom]
        
        # remove zeros (collected first - never delete while iterating a dict)
        for atom in [a for a in composition if composition[a] == 0]:
            del composition[atom]
        
        self._composition = composition
        return self._composition
    # ----
    
    
    def mass(self, massType=None):
        """Get mass.
            massType: monoisotopic (0 or 'mo') or average (1, 'av')
        """
        
        # get mass
        if self._mass == None:
            comp = self.formula()
            m = compound(comp).mass()
            self._mass = (m[0], m[1])
        
        # return mass
        if massType == 0 or massType == 'mo':
            return self._mass[0]
        elif massType == 1 or massType == 'av':
            return self._mass[1]
        else:
            return (self._mass[0], self._mass[1])
    # ----
    
    
    def mz(self, charge, agentFormula='H', agentCharge=1):
        """Get ion m/z calculated by common.mz() from (mo, av) mass tuple."""
        neutralMass = self.mass()
        return common.mz(neutralMass, charge, agentFormula=agentFormula, agentCharge=agentCharge)
    # ----
    
    
    def pattern(self, fwhm=0.1, relIntThreshold=0.01, charge=0, agentFormula='H', agentCharge=1, rounding=7):
        """Get isotopic pattern; all parameters are passed to common.pattern()."""
        
        options = {
            'fwhm': fwhm,
            'relIntThreshold': relIntThreshold,
            'charge': charge,
            'agentFormula': agentFormula,
            'agentCharge': agentCharge,
            'rounding': rounding,
        }
        return common.pattern(self, **options)
    # ----
    
    
    def format(self, template='S [m]'):
        """Get formatted sequence.
            template: (str) each character found among the keys below is
                replaced by its value, any other character is copied:
                s/S - chain in lower/upper case
                N/C - N-/C-terminal formula
                b/B - residue before the sequence (lower/upper case)
                a/A - residue after the sequence (lower/upper case)
                m, M, l - result of _formatModifications() called with no
                    argument, 'all' and 'labels' respectively
                p - number of miscleavages
                r - user range as 'from-to' (empty if not set)
                f - fragment serie and index (empty if not a fragment)
        Empty '[]' and '()' pairs left after substitution are removed and
        the result is stripped.
        """
        
        # make keys
        keys = {}
        keys['s'] = self.chain.lower()
        keys['S'] = self.chain
        keys['N'] = self.nTermFormula
        keys['C'] = self.cTermFormula
        keys['b'] = self.aaBefore.lower()
        keys['B'] = self.aaBefore
        keys['a'] = self.aaAfter.lower()
        keys['A'] = self.aaAfter
        keys['m'] = self._formatModifications()
        keys['M'] = self._formatModifications('all')
        keys['l'] = self._formatModifications('labels')
        
        # miscleavages is an int and must be converted before concatenation
        keys['p'] = str(self.miscleavages)
        
        # range and fragment keys default to empty strings so unused
        # placeholders vanish (and get cleared below) instead of being copied
        keys['r'] = ''
        if self.userRange:
            keys['r'] = '%s-%s' % tuple(self.userRange)
        
        keys['f'] = ''
        if self.fragmentSerie is not None and self.fragmentIndex is not None:
            keys['f'] = '%s %s' % (self.fragmentSerie, self.fragmentIndex)
        elif self.fragmentSerie is not None:
            keys['f'] = self.fragmentSerie
        
        # format
        buff = ''.join([keys.get(char, char) for char in template])
        
        # clear format
        buff = buff.replace('[]', '')
        buff = buff.replace('()', '')
        buff = buff.strip()
        
        return buff
    # ----
    
    
    def modify(self, name, position, state='f'):
        """Apply modification to sequence.
            name: (str) modification name - must be defined in mspy.modifications
            position: (int or str) residue index, or residue symbol to modify
                every occurrence of that residue
            state: 'f' for fixed or 'v' for variable modification
        Returns False (without changes) if a residue symbol is given which is
        not present in the chain, True otherwise.
        Raises KeyError for unknown modification.
        """
        
        # check modification
        if not name in blocks.modifications:
            raise KeyError('Unknown modification! --> ' + name)
        
        # check position - integer-like values are residue indices,
        # anything else is treated as residue symbol
        try:
            position = int(position)
        except (TypeError, ValueError):
            position = str(position)
        if type(position) == str and self.chain.count(position) == 0:
            return False
        
        # add modification
        self.modifications.append([name, position, str(state)])
        self._mass = None
        self._composition = None
        
        return True
    # ----
    
    
    def unmodify(self, name=None, position=None, state='f'):
        """Remove modification from sequence."""
        
        # remove all modifications
        if name == None:
            del self.modifications[:]
        
        # remove modification
        else:
            try: mod = [name, int(position), str(state)]
            except: mod = [name, str(position), str(state)]
            while mod in self.modifications:
                i = self.modifications.index(mod)
                del self.modifications[i]
        
        self._mass = None
        self._composition = None
    # ----
    
    
    def label(self, name, position, state='f'):
        """Apply label modification to sequence.
            name: (str) modification name - must be defined in mspy.modifications
            position: (int or str) residue index, or residue symbol to label
                every occurrence of that residue
            state: 'f' for fixed or 'v' for variable label
        Returns False (without changes) if a residue symbol is given which is
        not present in the chain, True otherwise.
        Raises KeyError for unknown modification.
        """
        
        # check modification
        if not name in blocks.modifications:
            raise KeyError('Unknown modification! --> ' + name)
        
        # check position - integer-like values are residue indices,
        # anything else is treated as residue symbol
        try:
            position = int(position)
        except (TypeError, ValueError):
            position = str(position)
        if type(position) == str and self.chain.count(position) == 0:
            return False
        
        # add label (state stored as str, consistently with modify)
        self.labels.append([name, position, str(state)])
        self._mass = None
        self._composition = None
        
        return True
    # ----
    
    
    def digest(self, enzyme, miscleavage=0, allowMods=False, strict=True):
        """Digest sequence by specified enzyme.
            enzyme: (str) enzyme name - must be defined in mspy.enzymes
            miscleavage: (int) number of allowed miscleavages
            allowMods: (bool) do not care about modifications in cleavage site
            strict: (bool) do not cleave even if variable modification is in cleavage site
        Returns list of peptides (sequence objects); partials with up to
        `miscleavage` missed cleavages follow the fully cleaved peptides.
        Raises KeyError for unknown enzyme.
        """
        
        # get enzyme
        if enzyme in blocks.enzymes:
            enzyme = blocks.enzymes[enzyme]
            expression = re.compile(enzyme.expression+'$')
        else:
            raise KeyError('Unknown enzyme! -> ' + enzyme)
        
        length = len(self.chain)
        
        # get digest indices
        slices = [] # (from, to, miscleavages)
        lastIndex = 0
        prefix = ''
        for x, aa in enumerate(self.chain):
            prefix += aa
            
            # the expression is anchored at the end of the prefix; on match
            # the current peptide ends before residue x
            if expression.search(prefix):
                
                # skip not allowed modifications
                if not allowMods and self.isModified(x-1, strict) and not enzyme.modsBefore:
                    continue
                elif not allowMods and self.isModified(x, strict) and not enzyme.modsAfter:
                    continue
                else:
                    slices.append((lastIndex, x, 0))
                    lastIndex = x
        
        # add last peptide - use chain length, the loop variable does not
        # exist for an empty chain
        slices.append((lastIndex, length, 0))
        
        # add indices for partials
        fullCount = len(slices)
        for x in range(fullCount):
            for y in range(1, miscleavage+1):
                if x+y < fullCount:
                    slices.append((slices[x][0], slices[x+y][1], y))
                else:
                    break
        
        # get peptides slices from protein
        peptides = []
        for first, last, missed in slices:
            peptide = self[first:last]
            peptide.miscleavages = missed
            
            # add terminal groups for ends created by cleavage
            if first != 0:
                peptide.nTermFormula = enzyme.nTermFormula
            if last != length:
                peptide.cTermFormula = enzyme.cTermFormula
            
            peptides.append(peptide)
        
        return peptides
    # ----
    
    
    def fragment(self, serie, index=None):
        """Generate list of neutral peptide fragments from given peptide.
            serie: (str) fragment serie name - must be defined in mspy.fragments
            index: (int) fragment index; if None all fragments of the serie are made
        Returns list of sequence objects carrying fragmentSerie (and
        fragmentIndex except for internal fragments); fragmentFiltered marks
        fragments without any residue from serie.specifity or excluded by
        serie.termFilter.
        Raises KeyError for unknown serie, ValueError for unknown terminus or
        when index is given for internal fragments.
        """
        
        length = len(self.chain)
        
        # get serie definition
        if serie not in blocks.fragments:
            raise KeyError('Unknown fragment type! -> ' + serie)
        serie = blocks.fragments[serie]
        
        def dress(frag, fragIndex, setN, setC):
            # apply serie terminal groups, neutral loss and naming to a slice
            if setN:
                frag.nTermFormula = serie.nTermFormula
            if setC:
                frag.cTermFormula = serie.cTermFormula
            frag.lossFormula = serie.lossFormula
            frag.fragmentSerie = serie.name
            if fragIndex is not None:
                frag.fragmentIndex = fragIndex
            return frag
        
        frags = []
        terminus = serie.terminus
        
        # N-terminal fragments
        if terminus == 'N':
            if index is not None:
                frags.append(dress(self[:index], index, False, True))
            else:
                frags = [dress(self[:x+1], x+1, False, True) for x in range(length)]
        
        # C-terminal fragments
        elif terminus == 'C':
            if index is not None:
                frags.append(dress(self[length-index:], index, True, False))
            else:
                frags = [dress(self[length-(x+1):], x+1, True, False) for x in range(length)]
        
        # singlet fragments
        elif terminus == 'S':
            if index is not None:
                frags.append(dress(self[index-1:index], index, True, True))
            else:
                frags = [dress(self[x:x+1], x+1, True, True) for x in range(length)]
        
        # internal fragments (at least two residues, no terminal residue)
        elif terminus == 'I':
            if index is not None:
                raise ValueError('No index allowed for this serie! -> ' + serie.name)
            for x in range(1, length-1):
                for y in range(2, length-x):
                    frags.append(dress(self[x:x+y], None, True, True))
        
        else:
            raise ValueError('Unknown fragment terminus! -> ' + serie.terminus)
        
        # check fragment specifity
        for frag in frags:
            frag.fragmentFiltered = not any(aa in frag.chain for aa in serie.specifity)
        
        # filter nonsense fragments
        if frags and terminus in ('N', 'S'):
            if serie.termFilter[0] and index in (None, 1):
                frags[0].fragmentFiltered = True
            if serie.termFilter[1] and index in (None, length):
                frags[-1].fragmentFiltered = True
        elif frags and terminus == 'C':
            if serie.termFilter[0] and index in (None, length):
                frags[-1].fragmentFiltered = True
            if serie.termFilter[1] and index in (None, 1):
                frags[0].fragmentFiltered = True
        
        return frags
    # ----
    
    
    def search(self, mass, charge, tolerance, enzyme=None, tolUnits='Da', massType='mo', maxMods=1, position=False):
        """Search sequence for specified ion.
            mass: (float) m/z value to search for
            charge: (int) charge of the m/z value
            tolerance: (float) mass tolerance
            tolUnits: ('Da', 'ppm') tolerance units
            enzyme: (str) enzyme used for peptides endings, if None H/OH is used
            massType: ('mo' or 'av') mass type of the mass value
            maxMods: (int) maximum number of modifications at one residue
            position: (bool) retain position for variable modifications (much slower)
        Returns list of matching peptide variants (sequence objects).
        """
        
        matches = []
        
        # set mass type
        if massType == 'mo':
            massType = 0
        elif massType == 'av':
            massType = 1
        
        # set terminal groups for ends created by cleavage
        if enzyme:
            enzyme = blocks.enzymes[enzyme]
            nTerm = enzyme.nTermFormula
            cTerm = enzyme.cTermFormula
        else:
            nTerm = 'H'
            cTerm = 'OH'
        
        # set mass limits
        if tolUnits == 'ppm':
            lowMass = mass - (tolerance * mass/1000000)
            highMass = mass + (tolerance * mass/1000000)
        else:
            lowMass = mass - tolerance
            highMass = mass + tolerance
        
        # search sequence
        length = len(self)
        for i in range(length):
            for j in range(i+1, length+1):
                
                # get peptide; must set nTermFormula/cTermFormula (the
                # attributes composition() reads), otherwise the enzyme
                # terminal groups are silently ignored
                peptide = self[i:j]
                if i != 0:
                    peptide.nTermFormula = nTerm
                if j != length:
                    peptide.cTermFormula = cTerm
                
                # variate modifications and calculate m/z of each variant
                candidates = []
                for pep in peptide.variations(maxMods=maxMods, position=position):
                    candidates.append((pep.mz(charge)[massType], pep))
                if not candidates:
                    continue
                
                # check mass limits; the break assumes m/z grows as the
                # peptide is extended (larger j)
                mzs = [item[0] for item in candidates]
                if max(mzs) < lowMass:
                    continue
                elif min(mzs) > highMass:
                    break
                
                # search for matches
                for pepMZ, pep in candidates:
                    if lowMass <= pepMZ <= highMass:
                        matches.append(pep)
        
        return matches
    # ----
    
    
    def variations(self, maxMods=1, position=True, enzyme=None):
        """Calculate all possible combinations of variable modifications.
            maxMods: (int) maximum modifications allowed per one residue
            position: (bool) retain modifications positions (much slower)
            enzyme: (str) enzyme name to ensure that modifications are not presented in cleavage site
        
        Modifications are read from self.modifications as [name, site, type]
        items; type 'f' marks a fixed modification, site is either an int
        (sequence index) or a residue symbol. Returns a list of deep copies of
        this peptide. Each copy carries all fixed modifications plus one
        accepted combination of the variable ones, re-labelled as fixed ('f').
        With position=False the variable modifications lose their site ('').
        """
        
        variablePeptides = []
        
        # get modifications
        # split into fixed and variable ones; residue-specific variable mods are
        # expanded either into one copy per matching residue (position=False)
        # or into explicit per-index mods (position=True)
        fixedMods = []
        variableMods = []
        for mod in self.modifications:
            if mod[2] == 'f':
                fixedMods.append(mod)
            elif type(mod[1]) == int:
                variableMods.append(mod)
            else:
                if not position:
                    variableMods += [mod] * self.chain.count(mod[1])
                else:
                    for x, amino in enumerate(self.chain):
                        if amino == mod[1]:
                            variableMods.append([mod[0], x, 'v'])
        
        # make combinations of variable modifications
        # identical mods are grouped as [mod, count]; every sub-multiset of
        # them is generated as a list of [mod, usedCount] pairs
        variableMods = self._countUniqueModifications(variableMods)
        combinations = []
        for x in self._uniqueCombinations(variableMods):
            combinations.append(x)
        
        # disable positions occupied by fixed modifications
        # residue-specific fixed mods occupy each matching residue once; for an
        # int site str(site) presumably never occurs in the chain, so the count
        # falls back to 1 (NOTE(review): an empty site '' counts len(chain)+1)
        occupied = []
        for mod in fixedMods:
            count = max(1, self.chain.count(str(mod[1])))
            occupied += [mod[1]]*count
        
        # disable modifications at cleavage sites
        # terminal residues are blocked by adding maxMods occupancy entries;
        # presumably modsBefore/modsAfter tell whether the enzyme tolerates
        # modified residues next to the cleavage site -- confirm in blocks
        if enzyme:
            enz = blocks.enzymes[enzyme]
            if not enz.modsBefore and self.aaAfter:
                occupied += [len(self.chain)-1]*maxMods
            if not enz.modsAfter and self.aaBefore:
                occupied += [0]*maxMods
        
        # filter modifications
        # keep only combinations whose site occupancy passes _checkModifications
        buff = []
        for combination in combinations:
            positions = occupied[:]
            for mod in combination:
                positions += [mod[0][1]]*mod[1]
            if self._checkModifications(positions, self.chain, maxMods):
                buff.append(combination)
        combinations = buff
        
        # format modifications and filter the same
        # NOTE(review): [[...]]*count repeats the same inner list object, so the
        # duplicated modification items share identity
        buff = []
        for combination in combinations:
            mods = []
            for mod in combination:
                if position:
                    mods += [[mod[0][0], mod[0][1], 'f']]*mod[1]
                else:
                    mods += [[mod[0][0],'','f']]*mod[1]
            mods.sort()
            # quadratic duplicate check (list membership)
            if not mods in buff:
                buff.append(mods)
        combinations = buff
        
        # make new peptides
        # NOTE(review): fixedMods items are the very list objects stored in
        # self.modifications, so all variants share them with this peptide
        for combination in combinations:
            variablePeptide = deepcopy(self)
            variablePeptide.modifications = fixedMods+combination
            variablePeptide.modifications.sort()
            variablePeptides.append(variablePeptide)
        
        return variablePeptides
    # ----
    
    
    def isModified(self, position=None, strict=False):
        """Check if selected amino acid or whole sequence has any modification.
            position: (int) amino acid index, None checks the whole sequence
            strict: (bool) also take variable modifications into account
        A modification matches a position when its site equals the index or
        equals the residue found at that index.
        """
        
        for mod in self.modifications:
            
            # without strict checking only fixed modifications count
            if not strict and mod[2] != 'f':
                continue
            
            # whole sequence: any relevant modification is enough
            if position is None:
                return True
            
            # single residue: match by index or by residue symbol
            if mod[1] == position or mod[1] == self.chain[position]:
                return True
        
        return False
    # ----
    
    
    def validate(self, charge=0, agentFormula='H', agentCharge=1):
        """Utility to check ion composition.
            charge: (int) ion charge
            agentFormula: (str) charging agent formula
            agentCharge: (int) charging agent charge
        Builds a compound from self.formula() and returns its validate() result.
        """
        
        return compound(self.formula()).validate(
            charge=charge,
            agentFormula=agentFormula,
            agentCharge=agentCharge,
        )
    # ----
    
    
    def _formatModifications(self, modType='modifications'):
        """Format modifications."""
        
        # get modifications
        modifications = []
        if modType in ('modifications', 'all'):
            modifications += self.modifications
        if modType in ('labels', 'all'):
            modifications += self.labels
        
        # get modifications
        modifs = {}
        for mod in modifications:
            
            # count modification
            if mod[1] == '' or type(mod[1]) == int:
                count = 1
            elif type(mod[1]) in (str, unicode):
                count = self.chain.count(mod[1])
            
            # add modification to dic
            if count and mod[0] in modifs:
                modifs[mod[0]] += count
            elif count:
                modifs[mod[0]] = count
        
        # format modifications
        if modifs:
            mods = ''
            for mod in sorted(modifs.keys()):
                mods += '%sx%s; ' % (modifs[mod], mod)
            return '%s' % mods[:-2]
        else:
            return ''
    # ----
    
    
    def _uniqueCombinations(self, items):
        """Generate unique combinations of items."""
        
        for i in range(len(items)):
            for cc in self._uniqueCombinations(items[i+1:]):
                for j in range(items[i][1]):
                    yield [[items[i][0],items[i][1]-j]] + cc
        yield []
    # ----
    
    
    def _countUniqueModifications(self, mods):
        """Get list of unique modifications with counter."""
        
        uniqueMods = []
        modsCount = []
        for mod in mods:
            if mod in uniqueMods:
                modsCount[uniqueMods.index(mod)] +=1
            else:
                uniqueMods.append(mod)
                modsCount.append(1)
        
        modsList = []
        for x, mod in enumerate(uniqueMods):
            modsList.append([mod, modsCount[x]])
        
        return modsList
    # ----
    
    
    def _checkModifications(self, positions, chain, maxMods):
        """Check if current modification set is applicable."""
        
        for x in positions:
            count = positions.count(x)
            if type(x) == int:
                if count>maxMods:
                    return False
            elif type(x) in (str, unicode):
                available = chain.count(x)
                for y in positions:
                    if type(y) == int and chain[y] == x:
                        available -= 1
                if count>(available*maxMods):
                    return False
        
        return True
    # ----
    
    


class peak:
    """Peak object definition.
    
    Stores a single centroid: m/z, intensity and baseline (coerced to float),
    optional signal-to-noise, FWHM, charge and isotope tags, and a relative
    intensity that defaults to 1.
    """
    
    def __init__(self, mz, intensity=0., baseline=0., sn=None, charge=None, isotope=None, massType='mo', fwhm=None):
        
        # numeric core values are always kept as floats
        self.mz = float(mz)
        self.intensity = float(intensity)
        self.baseline = float(baseline)
        
        # optional descriptors
        self.sn = sn
        self.fwhm = fwhm
        self.charge = charge
        self.massType = massType
        self.isotope = isotope
        
        # relative intensity and linked scan
        self.relIntensity = 1.
        self.childScanNumber = None
        
        # free-form user defined params
        self.userParams = {}
    # ----
    
    
    def mass(self, agentFormula='H', agentCharge=1):
        """Get neutral peak mass.
            agentFormula: (str) charging agent formula passed to common.mz
            agentCharge: (int) charging agent charge passed to common.mz
        Returns None when the peak charge is unknown.
        """
        
        if self.charge is None:
            return None
        
        # recalculate the observed m/z for charge 0
        return common.mz(self.mz, 0, self.charge, agentFormula, agentCharge, self.massType)
    # ----
    
    
    def realIntensity(self):
        """Get baseline-corrected intensity."""
        
        corrected = self.intensity - self.baseline
        return corrected
    # ----
    
    
    def resolution(self):
        """Get peak resolution (m/z over FWHM), None if FWHM is unset or zero."""
        
        return (self.mz / self.fwhm) if self.fwhm else None
    # ----
    


class peaklist:
    """Peaklist object definition.
    
    Keeps peak objects sorted by m/z, remembers the base peak (the peak with
    the highest baseline-corrected intensity) and keeps each peak's
    relIntensity relative to it.
    """
    
    def __init__(self, peaks=None):
        """Create peaklist.
            peaks: (iterable of peak objects) initial peaks, None for empty
        Raises TypeError if any item is not a peak object.
        """
        
        self.basePeak = None
        
        # cursor used only by the legacy next() method
        self._index = 0
        
        # check data
        self.peaks = []
        if peaks is not None:
            for item in peaks:
                self._check(item)
                self.peaks.append(item)
        
        # add data
        self._sort()
        self._setBasePeak()
        self.refresh()
    # ----
    
    
    def __len__(self):
        return len(self.peaks)
    # ----
    
    
    def __setitem__(self, i, item):
        """Replace peak at index i, keeping order and relative intensities."""
        
        # check item
        self._check(item)
        
        replacingBase = self.peaks[i] is self.basePeak
        self.peaks[i] = item
        self._sort()
        
        # base peak replaced -> search for the new one
        if replacingBase:
            self._setBasePeak()
            self.refresh()
        
        # new peak is more intense than the current base peak
        elif item.intensity - item.baseline > self.basePeak.intensity - self.basePeak.baseline:
            self.basePeak = item
            self.refresh()
        
        # only the new peak needs its relative intensity
        else:
            item.relIntensity = self._relativeIntensity(item)
    # ----
    
    
    def __getitem__(self, i):
        return self.peaks[i]
    # ----
    
    
    def __delitem__(self, i):
        """Delete peak(s) at index or slice i, re-evaluating the base peak if removed."""
        
        removed = self.peaks[i]
        del self.peaks[i]
        
        # slices remove several peaks at once
        if not isinstance(i, slice):
            removed = [removed]
        
        # recalculate relative intensity
        if any(item is self.basePeak for item in removed):
            self._setBasePeak()
            self.refresh()
    # ----
    
    
    def __iter__(self):
        """Iterate peaks; every call gets an independent iterator so nested loops work."""
        
        # keep the legacy next() cursor behaving as before
        self._index = 0
        return iter(self.peaks)
    # ----
    
    
    def __add__(self, peaksB):
        """Return A+X: new peaklist with deep copies of peaks from both operands.
            peaksB: (peaklist or list of peak objects) peaks to merge in
        """
        
        merged = deepcopy(self.peaks)
        merged.extend(deepcopy(peaksB))
        
        return peaklist(merged)
    # ----
    
    
    def __mul__(self, x):
        """Return A*X as a plain list (not a peaklist) of deep-copied peaks
        whose intensity and baseline are multiplied by x.
        """
        
        peaks = deepcopy(self.peaks)
        for item in peaks:
            item.intensity *= x
            item.baseline *= x
        
        return peaks
    # ----
    
    
    def next(self):
        """Legacy cursor-style iteration step, kept for backward compatibility.
        for-loops do not use it; they iterate through __iter__.
        """
        
        if self._index < len(self.peaks):
            self._index += 1
            return self.peaks[self._index-1]
        else:
            raise StopIteration
    # ----
    
    
    def append(self, item):
        """Add peak, keeping m/z order and relative intensities."""
        
        # check peak
        self._check(item)
        
        # add peak and sort peaklist only if the m/z order is broken
        if self.peaks and self.peaks[-1].mz > item.mz:
            self.peaks.append(item)
            self._sort()
        else:
            self.peaks.append(item)
        
        # no base peak yet -> full recalculation (handles zero intensities)
        if self.basePeak is None:
            self._setBasePeak()
            self.refresh()
        
        # new base peak
        elif item.intensity - item.baseline > self.basePeak.intensity - self.basePeak.baseline:
            self.basePeak = item
            self.refresh()
        
        # only the new peak needs its relative intensity
        else:
            item.relIntensity = self._relativeIntensity(item)
    # ----
    
    
    def refresh(self):
        """Recalculate relative intensities of all peaks."""
        
        if self.basePeak is None:
            self._setBasePeak()
            if self.basePeak is None:
                return
        
        for item in self.peaks:
            item.relIntensity = self._relativeIntensity(item)
    # ----
    
    
    def delete(self, indexes=None):
        """Delete selected peaks.
            indexes: (list of int) indexes to delete, None deletes all peaks
        Duplicate indexes are removed once; the given list is not modified.
        """
        
        # delete all
        if indexes is None:
            del self.peaks[:]
            self.basePeak = None
            return
        
        # delete by indexes, highest first so lower indexes stay valid
        relint = False
        for i in sorted(set(indexes), reverse=True):
            if self.peaks[i] is self.basePeak:
                relint = True
            del self.peaks[i]
        
        if relint:
            self._setBasePeak()
            self.refresh()
    # ----
    
    
    def crop(self, minX, maxX):
        """Crop peaklist.
            minX: (float) lower m/z limit
            maxX: (float) upper m/z limit
        Peaks outside [minX, maxX] are deleted.
        """
        
        indexes = [x for x, item in enumerate(self.peaks) if item.mz < minX or item.mz > maxX]
        self.delete(indexes)
    # ----
    
    
    def _check(self, item):
        """Check each item to be a peak"""
        
        if not isinstance(item, peak):
            raise TypeError('Item must be a peak object!')
    # ----
    
    
    def _sort(self):
        """Sort data according to mass (stable, peaks with equal m/z keep their order)."""
        
        self.peaks = sorted(self.peaks, key=lambda item: item.mz)
    # ----
    
    
    def _setBasePeak(self):
        """Get most intensive peak (first one wins on ties), None if empty."""
        
        if not self.peaks:
            self.basePeak = None
            return
        
        self.basePeak = max(self.peaks, key=lambda item: item.intensity - item.baseline)
    # ----
    
    
    def _relativeIntensity(self, item):
        """Get baseline-corrected intensity of item relative to the base peak.
        Returns 1. when the base peak has zero net intensity.
        """
        
        maxInt = self.basePeak.intensity - self.basePeak.baseline
        if maxInt:
            return (item.intensity - item.baseline)/maxInt
        return 1.
    # ----
    
    


class scan:
    """Scan object definition.
    
    Holds spectrum data points (numpy array of [m/z, intensity] rows), a
    peaklist and scan metadata.
    """
    
    def __init__(self, points=None, peaks=None):
        """Create scan.
            points: (list or array of [mz, intensity]) spectrum points, None for empty
            peaks: (peaklist or list of peak objects) peaks, None for empty
        """
        
        # avoid shared mutable default arguments
        if points is None:
            points = []
        if peaks is None:
            peaks = []
        
        self.scanNumber = None
        self.parentScanNumber = None
        self.polarity = None
        self.msLevel = None
        self.retentionTime = None
        self.totIonCurrent = None
        self.basePeakMZ = None
        self.basePeakIntensity = None
        self.precursorMZ = None
        self.precursorIntensity = None
        self.precursorCharge = None
        
        # user defined params
        self.userParams = {}
        
        # convert points to numPy array
        self.points = numpy.array(points)
        
        # convert peaks to peaklist
        if isinstance(peaks, peaklist):
            self.peaklist = peaks
        else:
            self.peaklist = peaklist(peaks)
    # ----
    
    
    def __len__(self):
        return len(self.points)
    # ----
    
    
    def __add__(self, scanB):
        """Return A+B (points or peaklists only).
        Points are summed on a unified m/z raster when either scan has points;
        otherwise peaklists are merged. Metadata is not carried over.
        """
        
        newPoints = []
        newPeaklist = []
        
        # use spectra only
        if len(self.points) or len(scanB.points):
            
            # unify raster
            pointsA, pointsB = self._unifyRaster(self.points, scanB.points)
            
            # convert back to arrays
            pointsA = numpy.array(pointsA)
            pointsB = numpy.array(pointsB)
            
            # zero B's m/z column so only intensities are added
            pointsB[:,0] = 0
            newPoints = pointsA + pointsB
        
        # use peaklists only
        elif len(self.peaklist) or len(scanB.peaklist):
            newPeaklist = self.peaklist + scanB.peaklist
        
        return scan(newPoints, newPeaklist)
    # ----
    
    
    def __sub__(self, scanB):
        """Return A-B (points only); an empty scan if neither has points."""
        
        # empty rasters cannot be indexed as 2D arrays below
        if not len(self.points) and not len(scanB.points):
            return scan()
        
        # unify raster
        pointsA, pointsB = self._unifyRaster(self.points, scanB.points)
        
        # convert back to arrays
        pointsA = numpy.array(pointsA)
        pointsB = numpy.array(pointsB)
        
        # zero B's m/z column so only intensities are subtracted
        pointsB[:,0] = 0
        
        return scan(pointsA - pointsB)
    # ----
    
    
    def __mul__(self, x):
        """Return A*X (points and peaklist intensities multiplied by x)."""
        
        # scale intensity column only; a new array is produced
        newPoints = []
        if len(self.points):
            newPoints = self.points * numpy.array((1.0, x))
        
        # peaklist multiplication works on deep copies of the peaks
        newPeaklist = self.peaklist * x
        
        return scan(newPoints, newPeaklist)
    # ----
    
    
    def select(self, minX, maxX):
        """Get points for selected mz range (minX < mz <= maxX).
        Raises ValueError if maxX < minX.
        """
        
        # check slice
        if maxX < minX:
            raise ValueError('Invalid mz slice definition!')
        
        # crop points
        i1 = self._getIndex(self.points, minX)
        i2 = self._getIndex(self.points, maxX)
        
        return self.points[i1:i2]
    # ----
    
    
    def intensity(self, mz):
        """Get interpolated intensity for given m/z.
            mz: (float) m/z value
        Returns None if there are no points or mz is left of the first point or
        at/after the last one.
        """
        
        # check data
        if len(self.points) == 0:
            return None
        
        # get mz index
        index = self._getIndex(self.points, mz)
        if not index or index == len(self.points):
            return None
        
        # get intensity
        return self._interpolateLine(self.points[index-1], self.points[index], x=mz)
    # ----
    
    
    def width(self, mz, intensity):
        """Get peak width for given m/z and height.
            mz: (float) peak m/z value
            intensity: (float) intensity of measurement
        Returns None if there are no points, mz is left of the first point, or
        the signal does not fall to the given intensity on both sides of mz
        within the data.
        """
        
        # check data
        if len(self.points) == 0:
            return None
        
        # get indexes
        index = self._getIndex(self.points, mz)
        if index < 1:
            return None
        
        # walk outwards until the signal drops to the requested intensity
        leftIndx = index-1
        while leftIndx >= 0 and self.points[leftIndx][1] > intensity:
            leftIndx -= 1
        
        rightIndx = index
        while rightIndx < len(self.points) and self.points[rightIndx][1] > intensity:
            rightIndx += 1
        
        # height never reached on one side -> width undefined
        if leftIndx < 0 or rightIndx >= len(self.points):
            return None
        
        # get mz
        leftMZ = self._interpolateLine(self.points[leftIndx], self.points[leftIndx+1], y=intensity)
        rightMZ = self._interpolateLine(self.points[rightIndx-1], self.points[rightIndx], y=intensity)
        
        return rightMZ - leftMZ
    # ----
    
    
    def noise(self, minX=None, maxX=None, mz=None, window=0.1):
        """Get noise for specified mz range or mz value.
            minX: (float) lower m/z limit
            maxX: (float) upper m/z limit
            mz: (float) m/z value
            window: (float) percentage around specified m/z value to use for noise calculation
        Returns (None, None) if there are no points.
        """
        
        # check data
        if len(self.points) == 0:
            return None, None
        
        return processing.noise(self.points, minX=minX, maxX=maxX, mz=mz, window=window)
    # ----
    
    
    def baseline(self, segments, offset=0., smooth=True):
        """Calculate spectrum baseline.
            segments: (int) number of baseline segments
            offset: (float) intensity offset in %/100
            smooth: (bool) smooth final baseline
        Returns None if there are no points.
        """
        
        # check data
        if len(self.points) == 0:
            return None
        
        return processing.baseline(self.points, segments, offset=offset, smooth=smooth)
    # ----
    
    
    def crop(self, minX, maxX):
        """Crop data points and peaklist.
            minX: (float) lower m/z limit
            maxX: (float) upper m/z limit
        """
        
        # crop spectrum data
        i1 = self._getIndex(self.points, minX)
        i2 = self._getIndex(self.points, maxX)
        self.points = self.points[i1:i2]
        
        # crop peaklist data
        self.peaklist.crop(minX, maxX)
    # ----
    
    
    def calibrate(self, fce, params):
        """Re-calibrate data points and peaklist.
            fce: (function) calibration function called as fce(params, mz)
            params: (list or tuple) function parameters
        The peaklist is not re-sorted afterwards.
        """
        
        # calibrate spectrum data (rows are views into self.points)
        for point in self.points:
            point[0] = fce(params, point[0])
        
        # calibrate peaklist data
        for item in self.peaklist:
            item.mz = fce(params, item.mz)
    # ----
    
    
    def normalize(self):
        """Normalize spectrum points and peaklist to 0-100 intensity range.
        Data with zero intensity range are left untouched.
        """
        
        # get normalization params
        normalization = self.normalization()
        if not normalization:
            return
        
        scale, shift = normalization
        
        # flat data cannot be rescaled (would divide by zero)
        if not scale:
            return
        
        # normalize spectrum points
        if len(self.points) > 0:
            self.points = self.points - numpy.array((0, shift))
            self.points = self.points / numpy.array((1, scale))
        
        # normalize peaklist
        if len(self.peaklist) > 0:
            for item in self.peaklist:
                item.intensity = (item.intensity - shift) / scale
                item.baseline = (item.baseline - shift) / scale
            
            self.peaklist.refresh()
    # ----
    
    
    def normalization(self):
        """Get normalization params as (scale, shift), False if there is no data.
        shift is the lowest intensity/baseline, scale is (max - shift)/100.
        """
        
        # calculate range for spectrum and peaklist
        if len(self.points) > 0 and len(self.peaklist) > 0:
            spectrumMax = numpy.maximum.reduce(self.points)[1]
            spectrumMin = numpy.minimum.reduce(self.points)[1]
            peaklistMax = max([item.intensity for item in self.peaklist])
            peaklistMin = min([item.baseline for item in self.peaklist])
            shift = min(spectrumMin, peaklistMin)
            scale = (max(spectrumMax, peaklistMax)-shift)/100
        
        # calculate range for spectrum only
        elif len(self.points) > 0:
            spectrumMax = numpy.maximum.reduce(self.points)[1]
            shift = numpy.minimum.reduce(self.points)[1]
            scale = (spectrumMax-shift)/100
        
        # calculate range for peaklist only
        elif len(self.peaklist) > 0:
            peaklistMax = max([item.intensity for item in self.peaklist])
            shift = min([item.baseline for item in self.peaklist])
            scale = (peaklistMax-shift)/100
        
        # no data
        else:
            return False
        
        return scale, shift
    # ----
    
    
    def _interpolateLine(self, p1, p2, x=None, y=None):
        """Get line interpolated X or Y value between points p1 and p2.
        For identical x coordinates the higher y (x given) or the shared x
        (y given) is returned.
        """
        
        # check points
        if p1[0] == p2[0] and x is not None:
            return max(p1[1], p2[1])
        elif p1[0] == p2[0] and y is not None:
            return p1[0]
        
        # get equation
        m = (p2[1] - p1[1])/(p2[0] - p1[0])
        b = p1[1] - m * p1[0]
        
        # get point
        if x is not None:
            return m * x + b
        elif y is not None:
            return (y - b) / m
    # ----
    
    
    def _getIndex(self, points, x):
        """Get nearest higher index for selected point (first m/z > x)."""
        
        lo = 0
        hi = len(points)
        while lo < hi:
            # floor division keeps the index integral on Python 2 and 3
            mid = (lo + hi) // 2
            if x < points[mid][0]:
                hi = mid
            else:
                lo = mid + 1
        
        return lo
    # ----
    
    
    def _unifyRaster(self, pointsA, pointsB):
        """Unify x-raster of two scans.
        Returns two lists of points sharing the same x values; missing points
        are interpolated inside a scan's range and zero-filled outside it.
        """
        
        # convert arrays
        pointsA = list(pointsA)
        pointsB = list(pointsB)
        
        # merge left
        i = 0
        while i<len(pointsA) and i<len(pointsB) and pointsA[i][0] < pointsB[i][0]:
            pointsB.insert(i, [pointsA[i][0], 0.0])
            i += 1
        
        i = 0
        while i<len(pointsA) and i<len(pointsB) and pointsA[i][0] > pointsB[i][0]:
            pointsA.insert(i, [pointsB[i][0], 0.0])
            i += 1
        
        # merge middle
        # NOTE: pointsA grows while being enumerated; the inserted point is
        # skipped and the shifted original is compared again on the next pass
        for i, x in enumerate(pointsA):
            
            if i == len(pointsB):
                break
            
            if pointsA[i][0] < pointsB[i][0]:
                intens = self._interpolateLine(pointsB[i-1], pointsB[i], x=pointsA[i][0])
                pointsB.insert(i, [pointsA[i][0], intens])
                
            elif pointsA[i][0] > pointsB[i][0]:
                intens = self._interpolateLine(pointsA[i-1], pointsA[i], x=pointsB[i][0])
                pointsA.insert(i, [pointsB[i][0], intens])
        
        # merge right
        if len(pointsA) < len(pointsB):
            for x in pointsB[len(pointsA):]:
                pointsA.append([x[0],0.0])
        
        elif len(pointsA) > len(pointsB):
            for x in pointsA[len(pointsB):]:
                pointsB.append([x[0],0.0])
        
        return pointsA, pointsB
    # ----
    
    

     