import sys
import argparse
import tkinter as tk
import tkinter.ttk
sys.path.append('./pyuiutils/')
import pyuiutils.uiutils as uiutils

import argparse
import numpy as np
import cv2
import filtering

class LaplacianUIFrame(tk.Frame):
    """ Frame to contain the laplacian image editing GUI.
    Contains two sub-frames - an ImageFrame and a SliderFrame.

    Args:
        parent: Tk container this frame is placed in.
        root: the application's Tk root; forwarded to both sub-frames.
        num_levels: number of pyramid levels. Stored on the frame so the
            ImageFrame child can read it (as ``self.parent.num_levels``)
            when it builds the pyramid for a newly loaded image.
    """

    def __init__(self, parent, root, num_levels):
        tk.Frame.__init__(self, parent)
        self.num_levels = num_levels

        # Image column gets twice the spare width of the slider column.
        self.grid_columnconfigure(0, weight=2)
        self.grid_columnconfigure(1, weight=1)
        self.grid_rowconfigure(0, weight=1)

        self.image_frame = ImageFrame(self, root)
        self.image_frame.grid(column=0, row=0, sticky=tk.NSEW)

        self.slider_frame = SliderFrame(self, root)
        self.slider_frame.grid(row=0, column=1, sticky=tk.NSEW)


    def update_img(self, *args):
        """ Re-filter the loaded image with the current slider values and redraw it.

        Used as a Tk event handler, so ``*args`` receives (and ignores) the
        event object. Does nothing when no image has been loaded yet.
        """
        gpyr = getattr(self.image_frame, "gpyr", None)
        if gpyr is None:
            # A slider can be released before any image is loaded; there is
            # nothing to filter, so bail out instead of raising AttributeError.
            return

        sliders = self.slider_frame.sliders
        sigma = sliders["sigma"].get()
        alpha = sliders["alpha"].get()
        beta = sliders["beta"].get()

        # The pyramid is built from L*/100, so the filtered output is scaled
        # by 100 to return to L*'s 0-100 range before writing it back.
        filtered_luminance = filtering.fast_llf(gpyr, sigma, alpha, beta) * 100
        self.image_frame.img_lab[:, :, 0] = filtered_luminance

        # Despite the attribute name, this conversion yields BGR channel order.
        self.image_frame.img_rgb = cv2.cvtColor(self.image_frame.img_lab, cv2.COLOR_LAB2BGR)

        img = (self.image_frame.img_rgb.clip(0, 1) * 255).astype(np.uint8)
        self.image_frame.image_widget.draw_cv_image(img)


class ImageFrame(uiutils.BaseFrame):
    """ Frame to load, display, and save out a reconstructed image.

    Working-image state (all ``None`` until ``load_img`` succeeds):

    - ``img_rgb``: float32 copy of the loaded image divided by 255 (i.e.
      [0, 1] for an 8-bit image). Channel order is OpenCV's BGR despite
      the name.
    - ``img_lab``: ``img_rgb`` converted to L*a*b*.
    - ``gpyr``: pyramid built from the L* channel normalized to [0, 1],
      with ``self.parent.num_levels`` levels.
    - ``image_name``: the name returned by ``ask_for_image``.
    """
    def __init__(self, parent, root):
        uiutils.BaseFrame.__init__(self, parent, root, 3, 1)
        self.config(highlightthickness=1, highlightbackground="black")

        # Declare the image state up front so callers can test for
        # "nothing loaded yet" (e.g. ``frame.gpyr is None``) instead of
        # hitting an AttributeError.
        self.img_rgb = None
        self.img_lab = None
        self.gpyr = None
        self.image_name = None

        # Row 0: load button; row 1: image display (absorbs all spare
        # vertical space); row 2: save button.
        tk.Button(self, text='Load Image', command=self.load_img).grid(
                row=0, column=0, sticky=tk.N)

        self.image_widget = uiutils.ImageWidget(self)
        self.image_widget.grid(row=1, column=0, sticky=tk.NSEW)
        self.rowconfigure(0, weight=0)
        self.rowconfigure(1, weight=1)

        tk.Button(self, text='Save Image',
                command=self.save_image).grid(row=2, column=0, sticky=tk.W + tk.E)

    def load_img(self, img_name=None):
        """ Load an image, display it, and precompute its luminance pyramid.

        Args:
            img_name: image to load, passed through to ``ask_for_image``.
                The button calls this with None; presumably that makes
                ``ask_for_image`` prompt the user (TODO confirm in uiutils).

        If ``ask_for_image`` returns no image, the current state is left
        untouched.
        """
        img_name, img = self.ask_for_image(img_name)
        if img is None:
            return

        self.image_widget.draw_cv_image(img)
        # Assumes 8-bit input, so values land in [0, 1]; order is BGR.
        self.img_rgb = img.astype(np.float32) / 255
        self.img_lab = cv2.cvtColor(self.img_rgb, cv2.COLOR_BGR2LAB)

        # Precompute the Gaussian pyramid once per image. For float input
        # the L* channel runs from 0 to 100, so normalize it to [0, 1].
        self.gpyr = filtering.construct_gaussian(
            self.img_lab[:, :, 0] / 100, self.parent.num_levels)
        self.image_name = img_name

    def save_image(self):
        """ Ask for a destination path and write the displayed image to it.

        Does nothing if the save dialog returns None. The ``False`` passed to
        ``write_to_file`` is a uiutils option whose meaning is not visible
        here (TODO document from uiutils).
        """
        f = uiutils.ask_for_image_path_to_save(self)
        if f is not None:
            self.image_widget.write_to_file(f, False)

class SliderFrame(uiutils.BaseFrame):
    """ Frame holding one slider per local Laplacian filter parameter
    (``sigma``, ``alpha``, ``beta``).

    Each slider starts at 1.0. Releasing a slider with the left mouse
    button calls the parent's ``update_img`` so the image is re-filtered
    with the new values. The ``tk.Scale`` widgets are exposed as
    ``self.sliders``, keyed by parameter name.
    """
    def __init__(self, parent, root):
        uiutils.BaseFrame.__init__(self, parent, root, 3, 2)
        self.config(highlightthickness=1, highlightbackground="black")

        # Range and step for each parameter; insertion order sets row order.
        slider_data = {
                "sigma": {"from_": 0.0, "to": 1.0, "resolution": 0.01},
                "alpha": {"from_": 0.0, "to": 4.0, "resolution": 0.01},
                "beta":  {"from_": 0.0, "to": 2.0, "resolution": 0.01},
                }

        self.sliders = {}
        for row, (param, slider_info) in enumerate(slider_data.items()):
            tk.Label(self, text=param).grid(row=row, column=0, sticky=tk.E)
            slider = tk.Scale(self, orient=tk.HORIZONTAL, **slider_info)
            self.sliders[param] = slider
            slider.grid(row=row, column=1)
            slider.set(1.0)
            # Re-filter once per drag (on mouse release) rather than on
            # every intermediate value the slider passes through.
            slider.bind('<ButtonRelease-1>', self.parent.update_img)
            
if __name__ == "__main__":
    # Parse arguments before creating the Tk root so that `--help` and
    # argument errors work without opening a window (or needing a display).
    # The text is passed as `description`; the first positional argument of
    # ArgumentParser is `prog`, which would corrupt the usage line.
    parser = argparse.ArgumentParser(
        description='Run the Laplacian Image Editing GUI.')

    parser.add_argument('--image', '-i', help='An image to load.', default=None)
    parser.add_argument('--levels', '-l', help='Levels of laplacian pyramid', type=int, default=5)
    args = parser.parse_args()

    root = tk.Tk()
    root.title('CSCI 476 G1 - Local Laplacian Filtering')
    # Fill the screen, leaving 50px of height spare (presumably for a
    # taskbar/panel).
    w, h = root.winfo_screenwidth(), root.winfo_screenheight() - 50
    root.geometry('{}x{}+0+0'.format(w, h))
    root.grid_columnconfigure(0, weight=1)
    root.grid_rowconfigure(0, weight=1)
    app = LaplacianUIFrame(root, root, args.levels)
    app.grid(row=0, column=0, sticky=tk.NSEW)

    if args.image:
        app.image_frame.load_img(args.image)

    root.mainloop()


