from __future__ import with_statement
import sys

try:
    import numpy
    import os
    import re
    import reflex
    from pipeline_product import PipelineProduct
    import pipeline_display
    import reflex_plot_widgets
    import matplotlib.gridspec as gridspec
    from matplotlib.text import Text
    from pylab import *

    import_success = True
except ImportError:
    import_success = False
    print "Error importing modules pyfits, wx, matplotlib, numpy"

def paragraph(text, width=None):
    """ turn an indented, multi-line string into paragraph text
       text:  string to format; its common leading indentation is removed
       width: None -> newlines are replaced by spaces and the result is
                      stripped, giving a single line
              int  -> the dedented text is wrapped to that many columns
                      (not recommended for tooltips as they are wrapped
                      by wxWidgets by default)
    """
    from textwrap import dedent, fill

    dedented = dedent(text)
    if width is not None:
        return fill(dedented, width=width)

    single_line = dedented.replace('\n', ' ')
    return single_line.strip()

class DataPlotterManager(object):
    """Plot manager for the interactive ``kmos_combine`` window.

    readFitsData() collects the combined cubes (one mean spectrum per cube
    plane) and an optional OH spectrum. The plotting methods then show a
    file selector and the spectrum of the selected extension, or a
    "no data" message when no combined cube was found.

    NOTE(review): the camelCase methods are presumably the hooks called by
    the Reflex interactive application -- confirm against that library.
    """
    # static members
    recipe_name = "kmos_combine"
    combined_cat = "COMBINE_SCI_RECONSTRUCTED"
    oh_spec_cat = "OH_SPEC"

    def setWindowTitle(self):
        """Return the window title: the recipe name plus "_interactive"."""
        return self.recipe_name+"_interactive"

    def setInteractiveParameters(self):
        """Return the list of recipe parameters offered in the window.

        Each entry is a reflex.RecipeParameter bound to this recipe, with
        its display name, its group ("Comb." or "Other") and a description.
        """
        return [

            reflex.RecipeParameter(recipe=self.recipe_name, displayName="fmethod",
                    group="Comb.", description="Fitting Method (gauss, moffat)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="flux",
                    group="Comb.", description="Apply flux conservation"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="edge_nan",
                    group="Comb.", description="Set borders of cubes to NaN before combining them"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="cmethod",
                    group="Comb.", description="Combination Method (average, median, sum, min_max, ksigma)"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="cpos_rej",
                    group="Comb.", description="The positive rejection threshold"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="cneg_rej",
                    group="Comb.", description="The negative rejection threshold"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="citer",
                    group="Comb.", description="The number of iterations"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="cmax",
                    group="Comb.", description="The number of maximum pixel values"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="cmin",
                    group="Comb.", description="The number of minimum pixel values"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="skipped_frames",
                    group="Comb.", description="Comma-separated list of indexes of the frames to skip. n means the nth RAW frames will be ignored. Empty[default] skips none"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="ifus",
                    group="Other", description="The indices of the IFUs to combine"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="method",
                    group="Other", description="The shifting method"),
            reflex.RecipeParameter(recipe=self.recipe_name, displayName="filename",
                    group="Other", description="The path to the file with the shift vectors"),
        ]

    def readFitsData(self, fitsFiles):
        """Load the data to plot and select the plotting functions.

        fitsFiles: iterable of objects exposing ``category`` and ``name``
                   and accepted by PipelineProduct.

        Fills ``self.oh_spec`` (key "Found" tells whether an OH_SPEC file
        was seen) and ``self.files`` as
        ``{file basename: {EXTNAME: {keywords, "Spectrum", "Collapsed"}}}``.
        Only extension 1 of each combined cube is read. A cube whose
        extension 1 has no EXTNAME is not registered at all, so it can
        never be selected for plotting.
        """
        # Initialise
        self.files = dict()
        self.oh_spec = dict()
        self.oh_spec["Found"] = False

        # Loop on all FITS files
        for f in fitsFiles:
            # Use OH_SPEC if found
            if f.category == self.oh_spec_cat:
                oh_spec_ext = PipelineProduct(f).all_hdu[1]
                self.oh_spec["Found"] = True
                self.oh_spec["CRVAL1"] = oh_spec_ext.header["CRVAL1"]
                self.oh_spec["CDELT1"] = oh_spec_ext.header["CDELT1"]
                self.oh_spec["Spectrum"] = oh_spec_ext.data

            # For each reconstructed image (the category may carry a suffix
            # after the COMBINE_SCI_RECONSTRUCTED prefix)
            if f.category.startswith(self.combined_cat):
                recons_ext = PipelineProduct(f).all_hdu[1]

                # Without EXTNAME the extension cannot be identified: skip
                # the file before creating any entry for it
                try:
                    extname = recons_ext.header['EXTNAME']
                except KeyError:
                    continue

                # Create a dictionary per file, with one entry for the
                # extension that was read
                ext_info = dict()
                self.files[os.path.basename(f.name)] = {extname: ext_info}

                # Get the IFU number from extname; None when the name
                # contains no digits
                m = re.search(r"\d+", extname)
                ext_info["IFU_NUMBER"] = int(m.group()) if m else None

                if recons_ext.header['NAXIS'] == 3:
                    # Get Keyword infos
                    ext_info["CRPIX3"] = recons_ext.header['CRPIX3']
                    ext_info["CRVAL3"] = recons_ext.header['CRVAL3']
                    ext_info["CDELT3"] = recons_ext.header['CDELT3']
                    ext_info["UNIT"] = recons_ext.header['ESO QC CUBE_UNIT']

                    # Fill Spectrum: mean of the non-NaN pixels of each
                    # plane, NaN when a plane holds no valid pixel
                    spectrum = []
                    for cube_plane in recons_ext.data:
                        valid_pixels = cube_plane[~numpy.isnan(cube_plane)]
                        if len(valid_pixels) > 0:
                            spectrum.append(valid_pixels.mean())
                        else:
                            spectrum.append(numpy.nan)
                    ext_info["Spectrum"] = spectrum

                    # Fill Collapsed Image
                    collapsed_frame_ext = self._get_collapsed_ext(fitsFiles, f, extname)
                    if collapsed_frame_ext is not None:
                        ext_info["Collapsed"] = collapsed_frame_ext.data

        # Set the plotting functions. Both branches assign both attributes
        # so that a later call with different inputs switches layout too.
        if self.files:
            self._add_subplots = self._add_data_subplots
            self._plot = self._data_plot
        else:
            self._add_subplots = self._add_nodata_subplots
            self._plot = self._nodata_plot

    # Inputs : all files, reconstructed file, extension
    # Returns the corresponding extension in the corresponding collapsed file
    # The corresponding file must have the same file name as ref_file, prefixed with make_image_
    # The corresponding extension must have the same EXTNAME keyword
    def _get_collapsed_ext(self, fitsFiles, ref_file, extname):
        """Return the matching extension of the collapsed image, or None."""
        wanted_name = "make_image_" + os.path.basename(ref_file.name)
        # Loop on all FITS files
        for f in fitsFiles:
            if os.path.basename(f.name) != wanted_name:
                continue
            for collapsed_frame_ext in PipelineProduct(f).all_hdu:
                # EXTNAME is missing in the primary header - Skip it anyway
                try:
                    coll_extname = collapsed_frame_ext.header['EXTNAME']
                except KeyError:
                    continue
                if coll_extname == extname:
                    return collapsed_frame_ext
        return None

    def addSubplots(self, figure):
        """Create the subplots on ``figure`` with the selected layout."""
        self._add_subplots(figure)

    def plotProductsGraphics(self):
        """Draw with the function selected by readFitsData()."""
        self._plot()

    def plotWidgets(self):
        """Return the list of interactive widgets (the file selector).

        The list is empty when no file was loaded: the no-data layout does
        not create the ``files_selector`` axes the radio buttons need.
        """
        widgets = list()
        if not self.files:
            return widgets

        # Files Selector radiobutton
        self.radiobutton = reflex_plot_widgets.InteractiveRadioButtons(self.files_selector, self.setFSCallback, list(self.files.keys()), 0,
                title='Files Selection (Left mouse button)')
        widgets.append(self.radiobutton)

        return widgets

    def extension_has_spectrum(self, filename, extname):
        """Tell whether ``extname`` of ``filename`` holds a "Spectrum".

        Returns False (instead of raising KeyError) when the file or the
        extension is unknown.
        """
        return "Spectrum" in self.files.get(filename, {}).get(extname, {})

    def setIFUSCallback(self, point):
        """Select the IFU extension under a click and redraw the spectrum.

        ``point`` must expose ``xdata`` and ``ydata``. Clicks with
        1 < ydata < 3 map to "IFU.<n>.DATA" with n = int(xdata/2 + 0.5);
        the selection only changes if that extension has a spectrum.
        This callback is not registered by plotWidgets() in this class.
        """
        # presumably xdata/ydata are None for clicks outside the axes --
        # TODO confirm; ignoring them avoids comparing None with numbers
        if point.xdata is None or point.ydata is None:
            return
        if 1 < point.ydata < 3:
            extname = "IFU."+str(int((point.xdata/2)+0.5))+".DATA"
            if self.extension_has_spectrum(self.selected_file, extname):
                # Update selected extension
                self.selected_extension = extname
                # Redraw spectrum
                self._plot_spectrum()

    def setFSCallback(self, filename):
        """Select ``filename`` and redraw the spectrum.

        The current extension is kept when the new file also has a
        spectrum for it; otherwise the first valid extension is used.
        """
        # Keep track of the selected file
        self.selected_file = filename

        # Check that the new file currently selected extension is valid
        if not self.extension_has_spectrum(self.selected_file, self.selected_extension):
            self.selected_extension = self._get_first_valid_extname(self.selected_file)
        # Redraw spectrum
        self._plot_spectrum()

    def _add_data_subplots(self, figure):
        """Create the file-selector and spectrum axes on a 3x2 grid."""
        gs = gridspec.GridSpec(3, 2)
        self.files_selector = figure.add_subplot(gs[0,:])
        self.spec_plot = figure.add_subplot(gs[2,0])

    # Layout used by addSubplots() until readFitsData() selects one
    _add_subplots = _add_data_subplots

    def _data_plot_get_tooltip(self):
        """Return the tooltip text: "<file> [<extension>]"."""
        return self.selected_file+" ["+self.selected_extension+"]"

    def _plot_spectrum(self):
        """Draw the spectrum of the selected file/extension.

        The axes are left empty when the selection has no spectrum (for
        instance when no valid extension exists in the selected file).
        The OH spectrum is overplotted in red when one was loaded.
        """
        extension_dict = self.files.get(self.selected_file, {}).get(
            self.selected_extension)

        self.spec_plot.clear()
        if not extension_dict or "Spectrum" not in extension_dict:
            # Nothing to plot: do not fail on the missing keys
            return

        # Define wave. FITS pixel coordinates are 1-based, so the 0-based
        # index i sits at CRVAL3 + (i + 1 - CRPIX3) * CDELT3; this equals
        # CRVAL3 + i * CDELT3 when CRPIX3 is 1.
        pix = numpy.arange(len(extension_dict["Spectrum"]))
        wave = extension_dict["CRVAL3"] + \
            (pix + 1 - extension_dict["CRPIX3"]) * extension_dict["CDELT3"]

        # Plot Spectrum
        specdisp = pipeline_display.SpectrumDisplay()
        specdisp.setLabels("Wavelength (microns)", extension_dict["UNIT"])
        specdisp.display(self.spec_plot, "Spectrum", self._data_plot_get_tooltip(), wave, extension_dict["Spectrum"])

        if self.oh_spec["Found"]:
            # Overplot the OH spectrum
            # NOTE(review): CRPIX1 is not read for the OH spectrum, so the
            # reference pixel is taken to be the first one -- confirm.
            pix = numpy.arange(len(self.oh_spec["Spectrum"]))
            wave = self.oh_spec["CRVAL1"] + pix * self.oh_spec["CDELT1"]
            specdisp.overplot(self.spec_plot, wave, self.oh_spec["Spectrum"], 'red')
            self.spec_plot.legend(('Observed', 'OH'))

    def _get_first_valid_extname(self, filename):
        """Return the first extension (sorted by name) of ``filename``
        that has a spectrum, or "" when there is none."""
        for extname in sorted(self.files.get(filename, {})):
            if self.extension_has_spectrum(filename, extname):
                return extname
        return ""

    def _data_plot(self):
        """Select the first file and its first valid extension, then draw."""
        # Initial file is the first one (same order as the radio buttons)
        self.selected_file = list(self.files.keys())[0]
        self.selected_extension = self._get_first_valid_extname(self.selected_file)

        # Draw Spectrum
        self._plot_spectrum()

    def _add_nodata_subplots(self, figure):
        """Create the single axes used to show the "no data" message."""
        self.img_plot = figure.add_subplot(1,1,1)

    def _nodata_plot(self):
        """Show a message naming the expected input category."""
        # could be moved to reflex library?
        self.img_plot.set_axis_off()
        text_nodata = "Data not found. Input files should contain this" \
                       " type:\n%s" % self.combined_cat
        self.img_plot.text(0.1, 0.6, text_nodata, color='#11557c',
                      fontsize=18, ha='left', va='center', alpha=1.0)
        self.img_plot.tooltip = 'No data found'

# Script entry point: build the Reflex interactive application, parse the
# command line, then print the outputs and exit.
if __name__ == '__main__':
    from reflex_interactive_app import PipelineInteractiveApp

    # Create interactive application
    interactive_app = PipelineInteractiveApp(enable_init_sop=True)

    # get inputs from the command line
    interactive_app.parse_args()

    # Disable the GUI when the guarded imports at the top of the file failed
    if not import_success:
        interactive_app.setEnableGUI(False)

    # Open the interactive window if enabled
    # NOTE(review): the GUI is hard-disabled here. The original condition is
    # kept commented out just below, and `if False:` always takes the else
    # branch, so DataPlotterManager is never instantiated and
    # set_continue_mode() is always called. Confirm this is intentional
    # before restoring the isGUIEnabled() test.
#    if interactive_app.isGUIEnabled():
    if False:
        # Get the specific functions for this window
        dataPlotManager = DataPlotterManager()

        interactive_app.setPlotManager(dataPlotManager)
        interactive_app.showGUI()
    else:
        interactive_app.set_continue_mode()

    # Print outputs. This is parsed by the Reflex python actor to
    # get the results. Do not remove
    interactive_app.print_outputs()
    sys.exit()


