from __future__ import with_statement
from __future__ import print_function
import sys

try:
    import os
    import numpy as np
    import reflex
    import pipeline_display
    from pipeline_product import PipelineProduct
    from reflex_plot_widgets import *
    import_success = True

except ImportError:
    import_success = False
    print("Error importing modules pyfits, wx, matplotlib, numpy")



class DataPlotterManager(object):
    """Reflex plot manager for the interactive muse_exp_align GUI.

    The GUI shows a radio-button image selector next to an image view.
    Selecting an IMAGE_FOV file displays that image with the sources of the
    SOURCE_LIST file(s) whose DATE-OBS matches overlaid; the two extra buttons
    overlay the uncorrected or corrected source positions of all exposures.
    """

    recipeName       = "muse_exp_align"
    tagInputData     = "IMAGE_FOV"
    tagAuxiliaryData = "SOURCE_LIST"

    # Columns read from every source list table.
    _srcListColumnIds = ["Id", "X", "Y", "RA", "DEC", "RA_CORR", "DEC_CORR", "Flux"]

    _btnLabelAll     = "All detections (uncorrected)"
    _btnLabelAllCorr = "All detections (corrected)"

    # Header keyword used to pair FOV images with their source lists.
    _obsId = "DATE-OBS"

    _markerSymbols = ["o", "s", "D", "p", "*", "+", "x"]
    _markerColors  = ["#0000FF", "#00FF00", "#FF0000", "#00FFFF",
                      "#FF00FF", "#FFFF00", "#FFA500", "#800080"]

    # Marker size in points; Axes.scatter() expects its square as `s`.
    _markerSize = 8.

    def __init__(self):
        self._plt     = dict()
        self._widgets = dict()
        # >= 0: index of the displayed IMAGE_FOV frame;
        #   -1: all sources, uncorrected positions;
        #   -2: all sources, corrected positions.
        self._imageSelectorActive = 0

        super(DataPlotterManager, self).__init__()
        return


    def setWindowTitle(self):
        """Return the title of the GUI window."""
        return self.recipeName + " GUI"


    def setWindowHelp(self):
        """Return the help text of the GUI window."""
        return "Help for " + self.setWindowTitle()


    def setInteractiveParameters(self):
        """Declare the recipe parameters editable from the GUI."""
        self.parameters = [
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="rsearch", group="Offset Calculation",
                description="Search radius (in arcsec)"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="nbins", group="Offset Calculation",
                description="Number of bins of the 2D histogram"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="weight", group="Offset Calculation",
                description="Use weighting"),

            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="fwhm", group="Source Detection",
                description="FWHM of the convolution filter"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="threshold", group="Source Detection",
                description="Initial threshold for detecting point sources"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="bkgignore", group="Source Detection",
                description="Fraction of the image to be ignored"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="bkgfraction", group="Source Detection",
                description="Fraction of the image (without the ignored part) considered as background"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="step", group="Source Detection",
                description="Increment/decrement of the threshold value"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="iterations", group="Source Detection",
                description="Maximum number of iterations"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="srcmin", group="Source Detection",
                description="Minimum number of sources"),
            reflex.RecipeParameter(recipe=self.recipeName,
                displayName="srcmax", group="Source Detection",
                description="Maximum number of sources")
            ]
        return self.parameters


    def setCurrentParameterHelper(self, parameterValueGetter):
        """Store the callable used to query current parameter values."""
        self._getParameterValue = parameterValueGetter
        return


    def readFitsData(self, fitsFiles):
        """Group the input files by category and select the plot delegates.

        Each file becomes a frame dict holding its base name ("label") and its
        PipelineProduct ("datasets"). The real plot is used only when at least
        one IMAGE_FOV file is present; otherwise a placeholder is drawn.
        """
        self.frameset = dict()
        for fitsFile in fitsFiles:
            datasets = PipelineProduct(fitsFile)
            frame = {"label": os.path.basename(datasets.fits_file.name),
                     "datasets": datasets}
            self.frameset.setdefault(fitsFile.category, []).append(frame)

        if self.tagInputData in self.frameset:
            self._pltCreator = self._dataPlotCreate
            self._pltPlotter = self._dataPlotDraw
        else:
            self._pltCreator = self._dummyPlotCreate
            self._pltPlotter = self._dummyPlotDraw
        return


    def addSubplots(self, figure):
        """Attach to `figure`, clear it and create the subplots."""
        self._figure = figure
        self._figure.clear()
        self._pltCreator()
        return


    def plotWidgets(self):
        """Create the image selector (once) and return all widgets.

        The selector lists every IMAGE_FOV file plus the two "all detections"
        entries. It is not created when only the placeholder plot is shown.
        """
        if "dummy" not in self._plt and "buttons" not in self._widgets:
            buttonLabels = [frame["label"]
                            for frame in self.frameset[self.tagInputData]]
            buttonLabels.append(DataPlotterManager._btnLabelAll)
            buttonLabels.append(DataPlotterManager._btnLabelAllCorr)

            self._widgets["buttons"] = InteractiveRadioButtons(self._plt["buttons"],
                self.onImageSelected, buttonLabels, self._imageSelectorActive,
                title="Image Selector")
            self._resizeCID = self._figure.canvas.mpl_connect("resize_event",
                self.onResize)

        return list(self._widgets.values())


    def plotProductsGraphics(self):
        """Draw the plots using the delegate chosen in readFitsData()."""
        self._pltPlotter()
        return


    def readTable(self, frame, columns, dataset=1):
        """Read the given columns from HDU `dataset` of a PipelineProduct.

        Returns a dict mapping column name to column data. Returns None, after
        printing a warning, as soon as one of the columns is missing.
        """
        table = dict()
        for name in columns:
            try:
                table[name] = frame.all_hdu[dataset].data.field(name)
            except KeyError:
                print("Warning: column '" + name + "' not found in " +
                      frame.fits_file.name)
                return None
        return table


    # Event handler

    def onImageSelected(self, buttonLabel):
        """Switch the image view to the entry labelled `buttonLabel`.

        The view is redrawn only if the selection actually changed.
        """
        if buttonLabel == DataPlotterManager._btnLabelAll:
            newActive = -1
        elif buttonLabel == DataPlotterManager._btnLabelAllCorr:
            newActive = -2
        else:
            newActive = self._imageSelectorActive
            for idx, frame in enumerate(self.frameset[self.tagInputData]):
                if frame["label"] == buttonLabel:
                    newActive = idx

        if newActive != self._imageSelectorActive:
            self._imageSelectorActive = newActive
            self._dataPlotDraw()
        return


    def onResize(self, event):
        """Keep the radio button circles round when the figure is resized.

        The circle size is derived from the label font size relative to the
        selector axes. The height is scaled by the axes' physical aspect
        ratio, so the circles stay circular.
        """
        if not hasattr(self, "_widgets") or "buttons" not in self._widgets:
            return

        bbox = self._plt["buttons"].get_position().get_points()
        btns = self._widgets["buttons"].rbuttons

        # np.ptp() instead of ndarray.ptp(): the method was removed in NumPy 2.0.
        w = np.ptp(bbox[:, 0])
        h = np.ptp(bbox[:, 1])
        fh = self._figure.get_figheight()
        fw = self._figure.get_figwidth()
        vscale = (w * fw) / (h * fh)

        width = btns.labels[0].get_fontsize()
        width /= (self._figure.get_dpi() * w * fw)

        # Newer Matplotlib RadioButtons no longer expose `circles`
        # (deprecated in 3.7); there is nothing to rescale in that case.
        for circle in getattr(btns, "circles", ()):
            circle.width  = width
            circle.height = width * vscale
        return


    # Utility functions
    def _cycleMarkerProperties(self, index):
        """Return (symbol, color) for the `index`-th exposure.

        Colors cycle first; the symbol changes after each full color cycle.
        """
        ncolors  = len(DataPlotterManager._markerColors)
        nsymbols = len(DataPlotterManager._markerSymbols)

        color  = DataPlotterManager._markerColors[index % ncolors]
        symbol = DataPlotterManager._markerSymbols[(index // ncolors) % nsymbols]
        return (symbol, color)


    def _matchingSourceLists(self, imgFrame, imgName):
        """Return all readable source lists whose DATE-OBS matches imgFrame.

        `imgName` is only used in the progress message. Tables missing a
        required column are skipped. An empty list is returned when no
        SOURCE_LIST files were provided.
        """
        timestamp = imgFrame["datasets"].readKeyword(DataPlotterManager._obsId)
        tables = []
        for tblFrame in self.frameset.get(self.tagAuxiliaryData, []):
            if tblFrame["datasets"].readKeyword(DataPlotterManager._obsId) != timestamp:
                continue
            print("Reading Source list '" + tblFrame["label"] +
                  "' for FOV image '" + imgName + "'")
            tblData = self.readTable(tblFrame["datasets"], self._srcListColumnIds)
            if tblData is not None:
                tables.append(tblData)
        return tables


    # Implementation of the plot creator and plotter delegates to be used
    # if the required data is actually available.

    def _dataPlotCreate(self):
        self._plt["buttons"]   = self._figure.add_subplot(1, 2, 1)
        self._plt["imageview"] = self._figure.add_subplot(1, 2, 2)
        return


    def _dataPlotDraw(self):
        """Redraw the image view for the current image selector state."""
        if self._imageSelectorActive >= 0:
            self._drawSingleImage(self._imageSelectorActive)
        else:
            self._drawAllSources(self._imageSelectorActive == -1)
        return


    def _drawSingleImage(self, imgIdx):
        """Show FOV image `imgIdx` with its detected sources overlaid.

        If several source lists match, the last readable one is used. If none
        does, the image is shown without markers.
        """
        imgFrame = self.frameset[self.tagInputData][imgIdx]
        imgData = imgFrame["datasets"]
        imgData.readImage(1)
        imgData.read2DLinearWCS(1)

        tables = self._matchingSourceLists(imgFrame, imgData.fits_file.name)
        tblData = tables[-1] if tables else None
        numSources = len(tblData["Id"]) if tblData is not None else 0

        tooltip = imgFrame["label"] + ": " + str(numSources) + " sources detected"

        xlabel, colNameX = "X [pixel]", "X"
        ylabel, colNameY = "Y [pixel]", "Y"
        if imgData.type1.strip() == "RA---TAN":
            xlabel, colNameX = "Right Ascension [deg]", "RA"
        if imgData.type2.strip() == "DEC--TAN":
            ylabel, colNameY = "Declination [deg]", "DEC"

        axes = self._plt["imageview"]
        axes.clear()

        imageview = pipeline_display.ImageDisplay()
        imageview.setAspect("equal")
        imageview.setCmap("hot")
        imageview.setZAutoLimits(imgData.image, None)
        imageview.setLabels(xlabel, ylabel)
        imageview.setXLinearWCSAxis(imgData.crval1, imgData.cdelt1, imgData.crpix1)
        imageview.setYLinearWCSAxis(imgData.crval2, imgData.cdelt2, imgData.crpix2)
        imageview.display(axes, "Exposure FOV", tooltip, imgData.image)

        axes.autoscale(enable=False)
        if tblData is None:
            return

        xpos = tblData[colNameX]
        ypos = tblData[colNameY]
        marker, color = self._cycleMarkerProperties(imgIdx)
        markerSize = np.full(numSources, DataPlotterManager._markerSize)
        axes.scatter(xpos, ypos, facecolors="none", edgecolors=color,
                     marker=marker, s=markerSize**2)

        # Place each label at the lower-right corner of its marker.
        txtOffset = 0.5 * np.sqrt(2.) * DataPlotterManager._markerSize
        for idx, srcId in enumerate(tblData["Id"]):
            axes.annotate(str(srcId),
                          xy=(xpos[idx], ypos[idx]), xycoords="data",
                          xytext=(txtOffset, -txtOffset), textcoords="offset points",
                          ha="left", va="top", fontsize="x-small",
                          color=color, clip_on=True)
        return


    def _drawAllSources(self, uncorrected):
        """Overlay the sources of all exposures on a blank background.

        With a TAN WCS, `uncorrected` selects the RA/DEC columns and
        RA_CORR/DEC_CORR otherwise. Without one, the pixel X/Y columns are
        used. The first IMAGE_FOV file provides the display geometry and WCS.
        """
        imgData = self.frameset[self.tagInputData][0]["datasets"]
        imgData.readImage(1)
        imgData.read2DLinearWCS(1)

        # Set the pixel data of the image to 0 to clear it so that no
        # background image is displayed.
        bkgImage = np.array(imgData.image, copy=True)
        bkgImage[:] = 0

        # Keep the FOV image index with each table, so an exposure uses the
        # same marker style as in the single-image view.
        srcLists = []
        for imgIdx, imgFrame in enumerate(self.frameset[self.tagInputData]):
            for tblData in self._matchingSourceLists(imgFrame, imgFrame["label"]):
                srcLists.append((imgIdx, tblData))

        colNameX, colNameY = "X", "Y"
        xlabel, ylabel = "X [pixel]", "Y [pixel]"
        if imgData.type1.strip() == "RA---TAN":
            xlabel = "Right Ascension [deg]"
            colNameX = "RA" if uncorrected else "RA_CORR"
        if imgData.type2.strip() == "DEC--TAN":
            ylabel = "Declination [deg]"
            colNameY = "DEC" if uncorrected else "DEC_CORR"
        tooltip = "Uncorrected positions" if uncorrected else "Corrected positions"

        axes = self._plt["imageview"]
        axes.clear()

        imageview = pipeline_display.ImageDisplay()
        imageview.setAspect("equal")
        imageview.setCmap("gray")
        imageview.setZLimits([0., 1.])
        imageview.setLabels(xlabel, ylabel)
        imageview.setXLinearWCSAxis(imgData.crval1, imgData.cdelt1, imgData.crpix1)
        imageview.setYLinearWCSAxis(imgData.crval2, imgData.cdelt2, imgData.crpix2)
        imageview.display(axes, "Source Positions", tooltip, bkgImage)

        axes.autoscale(enable=False)

        for imgIdx, srcList in srcLists:
            # Size the marker array from the table being plotted; tables of
            # different lengths must not share one size array.
            markerSize = np.full(len(srcList["Id"]), DataPlotterManager._markerSize)
            marker, color = self._cycleMarkerProperties(imgIdx)
            axes.scatter(srcList[colNameX], srcList[colNameY],
                         facecolors="none", edgecolors=color,
                         marker=marker, s=markerSize**2)
        return


    # Implementation of the dummy plot creator and plotter delegates
    # follows here.

    def _dummyPlotCreate(self):
        self._plt["dummy"] = self._figure.add_subplot(1, 1, 1)
        return


    def _dummyPlotDraw(self):
        """Show a notice that the required input files are missing."""
        label = "Data not found! Input files should contain these types:\n%s" \
            % self.tagInputData

        self._plt["dummy"].set_axis_off()
        self._plt["dummy"].text(0.1, 0.6, label, color="#11557c", fontsize=18,
            horizontalalignment="left", verticalalignment="center", alpha=0.25)
        self._plt["dummy"].tooltip = "No data found"
        return


if __name__ == "__main__":

    from reflex_interactive_app import PipelineInteractiveApp

    # Wrap the Reflex interactive application and parse the command line
    # arguments handed over by the PythonActor.
    interactive_app = PipelineInteractiveApp(enable_init_sop=True)
    interactive_app.parse_args()

    # The GUI cannot be shown if the plotting modules failed to import.
    if not import_success:
        interactive_app.setEnableGUI(False)

    if not interactive_app.isGUIEnabled():
        # No GUI: let the workflow proceed without user interaction.
        interactive_app.set_continue_mode()
    else:
        dataPlotManager = DataPlotterManager()
        interactive_app.setPlotManager(dataPlotManager)
        interactive_app.showGUI()

    # NOTE: Do not remove this line! This prints the output, which is parsed
    #       by the Reflex PythonActor to get the results!
    interactive_app.print_outputs()

    sys.exit()
