import tkinter as tk
from tkinter import ttk
from tkinter import messagebox
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from itertools import count
import numpy as np

class CausalProcess:
    """Greedy causal rewriting engine plus comparison and visualisation helpers.

    Rules are stored as ``(lhs_tuple, rhs_string, name)`` triples, where
    ``lhs_tuple`` is the left-hand side split into single characters.
    """

    def __init__(self, rules):
        """Normalise ``rules`` into ``(tuple(lhs), rhs, name)`` triples.

        Each rule may be a string ``"LHS -> RHS"`` (given the name ``'?'``) or
        a 3-tuple ``(lhs, rhs, name)``.

        Raises:
            ValueError: if a rule is in neither form, a string rule contains
                more than one ``->``, or the left-hand side is empty (an empty
                pattern would match at every position).
        """
        self.rules = []
        for rule in rules:
            if isinstance(rule, str) and '->' in rule:
                parts = rule.split('->')
                if len(parts) != 2:
                    raise ValueError(f"Invalid rule format: {rule}")
                lhs, rhs = (part.strip() for part in parts)
                name = '?'
            elif isinstance(rule, tuple) and len(rule) == 3:
                lhs, rhs, name = rule
            else:
                raise ValueError(f"Invalid rule format: {rule}")
            if not lhs:
                raise ValueError(f"Invalid rule format: {rule}")
            self.rules.append((tuple(lhs), rhs, name))  # lhs stored as a tuple of characters
        # Source of unique ids for every Element this instance creates.
        self.id_counter = count()

    class Element:
        """A single object in the causal rewriting process."""

        def __init__(self, char, marked, gen, gid):
            self.char = char        # Symbol (e.g. 'A', 'B'); a rule name for intermediate labels
            self.marked = marked    # True if marked for transformation
            self.gen = gen          # Generation it was created in
            self.gid = gid          # Unique id (-1 for rule-label placeholders)

        def __repr__(self):
            return f"{self.char}{'*' if self.marked else ''}(g{self.gen}, id{self.gid})"

    def parse_string(self, labeled_str):
        """Parse a marked string into a list of ``(char, marked)`` pairs.

        Spaces are ignored. A character immediately followed by an apostrophe
        is marked (``"A'B"`` -> ``[('A', True), ('B', False)]``).
        """
        labeled_str = labeled_str.replace(" ", "")
        result = []
        i = 0
        while i < len(labeled_str):
            if i + 1 < len(labeled_str) and labeled_str[i + 1] == "'":
                result.append((labeled_str[i], True))   # Character is marked
                i += 2
            else:
                result.append((labeled_str[i], False))  # Character is unmarked
                i += 1
        return result

    def _expand_rule_names(self, text):
        """Replace every rule name in ``text`` with its marked LHS (e.g. ``"A'B'"``).

        Names are substituted longest-first so that a name which is a prefix
        of another (``r1`` vs ``r10``) cannot corrupt the longer one. Empty
        names are skipped because ``str.replace('', ...)`` would insert text
        between every character.
        """
        rule_lhs = {name: "'".join(lhs) + "'" for lhs, _, name in self.rules}
        for name in sorted(rule_lhs, key=len, reverse=True):
            if name:
                text = text.replace(name, rule_lhs[name])
        return text

    def mark_step(self, prev_result, user_marked_str):
        """Apply the marks in ``user_marked_str`` to the elements of ``prev_result``.

        Each parsed character is matched to the first not-yet-used element
        with the same ``char``, and that element's ``marked`` flag is set from
        the input. Elements are mutated in place; the returned list is a
        shallow copy of ``prev_result``.

        Raises:
            ValueError: if an input character has no unused element to match.
        """
        user_seq = self.parse_string(user_marked_str)
        result_copy = prev_result[:]
        used = [False] * len(result_copy)  # Which elements have already been matched

        for idx, (char, marked) in enumerate(user_seq):
            matched = False
            for i, el in enumerate(result_copy):
                if not used[i] and el.char == char:
                    el.marked = marked
                    used[i] = True
                    matched = True
                    break
            if not matched:
                raise ValueError(f"Could not match input character '{char}' at position {idx}")
        return result_copy

    @staticmethod
    def _splice(seq, ids_to_replace, new_items):
        """Return a copy of ``seq`` with the matched run replaced by ``new_items``.

        Assumes the elements whose gids are in ``ids_to_replace`` sit
        contiguously in ``seq``: on meeting any of them the whole run of
        ``len(ids_to_replace)`` elements is skipped, and ``new_items`` is
        emitted only at the element carrying ``ids_to_replace[0]``.
        """
        id_set = set(ids_to_replace)
        first_id = ids_to_replace[0]
        run_len = len(ids_to_replace)
        out = []
        j = 0
        while j < len(seq):
            el = seq[j]
            if el.gid in id_set:
                if el.gid == first_id:
                    out.extend(new_items)
                j += run_len
            else:
                out.append(el)
                j += 1
        return out

    def apply_greedy_rules(self, labeled_seq, steps, intermediate):
        """Greedily rewrite every fully-marked LHS occurrence in ``labeled_seq``.

        Scans left to right. At each position the first rule (in
        ``self.rules`` order) whose LHS matches and whose matched elements
        are all marked is applied, and the same position is retried. Matched
        elements are unmarked and replaced by fresh unmarked elements one
        generation newer than the newest matched element.

        Side effects:
            * ``steps[g]`` is the sequence for generation ``g``. The target
              generation's entry is created (or spliced if it exists), and
              every later recorded generation is spliced the same way.
            * ``intermediate[g - 1]`` has the matched run replaced by one
              placeholder element whose ``char`` is the rule name (gid -1).

        Returns:
            The rewritten sequence (a new list).
        """
        i = 0
        while i < len(labeled_seq):
            matched = False  # Whether any rule fired at position i
            for pattern, replacement, name in self.rules:
                pat_len = len(pattern)
                if i + pat_len > len(labeled_seq):
                    continue

                candidate_segment = labeled_seq[i:i + pat_len]
                lhs_segment = ''.join(el.char for el in candidate_segment)
                all_true = all(el.marked for el in candidate_segment)  # Every element must be marked
                if lhs_segment != ''.join(pattern) or not all_true:
                    continue

                # New elements are one generation newer than the newest matched element.
                target_gen = max(el.gen for el in candidate_segment) + 1

                for el in candidate_segment:
                    el.marked = False

                new_items = [
                    self.Element(c, False, target_gen, next(self.id_counter))
                    for c in replacement
                ]

                labeled_seq = labeled_seq[:i] + new_items + labeled_seq[i + pat_len:]

                while len(steps) <= target_gen:
                    steps.append([])
                ids_to_replace = [el.gid for el in candidate_segment]
                if not steps[target_gen]:
                    steps[target_gen] = labeled_seq[:]
                else:
                    steps[target_gen] = self._splice(steps[target_gen], ids_to_replace, new_items)

                # Generations already recorded after the target must show the same rewrite.
                for future_gen in range(target_gen + 1, len(steps)):
                    steps[future_gen] = self._splice(steps[future_gen], ids_to_replace, new_items)

                # Intermediate row for the generation being rewritten: copy from steps if missing.
                while len(intermediate) <= target_gen - 1:
                    intermediate.append(list(steps[len(intermediate)]))

                inter_step = intermediate[target_gen - 1]
                first_gid = candidate_segment[0].gid
                new_inter = []
                j = 0
                while j < len(inter_step):
                    if inter_step[j].gid == first_gid:
                        new_inter.append(self.Element(name, False, target_gen, -1))  # Rule label
                        j += len(candidate_segment)
                    else:
                        new_inter.append(inter_step[j])
                        j += 1
                intermediate[target_gen - 1] = new_inter

                matched = True
                break  # Retry the same position with the rewritten sequence

            if not matched:
                i += 1

        return labeled_seq

    def run_greedy_process(self, input_strings, process_id="Process"):
        """Run a whole user-described process through the greedy rewriter.

        ``input_strings[0]`` is the marked initial state. Each later string
        re-marks the current state and may use rule names as shorthand for
        their marked LHS. Each step's unmarked characters must equal the
        current state.

        Returns:
            ``(steps, intermediate)``: ``steps[g]`` is the sequence at
            generation ``g`` (index 0 is an unused empty list), and
            ``intermediate`` holds rows with rule-name placeholders.

        Raises:
            ValueError: if there are no steps, a step does not follow from
                the current state, or a mark cannot be matched.
        """
        if not input_strings:
            raise ValueError(f"{process_id} has no steps.")
        steps = []
        intermediate = []
        initial_labeled = [
            self.Element(c, marked, 1, next(self.id_counter))
            for c, marked in self.parse_string(self._expand_rule_names(input_strings[0]))
        ]

        steps.append([])                  # Index 0 unused so generation g lives at steps[g]
        steps.append(initial_labeled[:])
        intermediate.append(initial_labeled[:])
        current = self.apply_greedy_rules(initial_labeled[:], steps, intermediate)

        for step_index in range(1, len(input_strings)):
            actual_current_string = self.labelled_to_string(current)

            # Validation and marking both use the expanded text so they agree.
            raw_input = self._expand_rule_names(input_strings[step_index])
            expected_start = ''.join(c for c, _ in self.parse_string(raw_input))

            if expected_start != actual_current_string:
                raise ValueError(
                    f"{process_id} - Step {step_index} does not logically follow from previous steps.\n"
                    f"You marked this step as starting from: {expected_start}\n"
                    f"But the actual current state is: {actual_current_string}"
                )

            current = self.mark_step(current, raw_input)
            before = self.labelled_to_string(current)
            current = self.apply_greedy_rules(current, steps, intermediate)

            after = self.labelled_to_string(current)
            if after != before:
                current_gen = max((el.gen for el in current), default=0)
                while len(steps) <= current_gen:
                    steps.append([])
                if not steps[current_gen]:
                    steps[current_gen] = current[:]

            while len(intermediate) < len(steps) - 1:
                intermediate.append(list(steps[len(intermediate)]))

        return steps, intermediate

    def labelled_to_string(self, labeled_seq):
        """Concatenate the ``char`` of every element in ``labeled_seq``."""
        return ''.join(el.char for el in labeled_seq)

    def expand_named_rules(self, raw_inputs, rule_lhs, rules):
        """Return the plain (unmarked, space-free) state strings for ``raw_inputs``.

        Rule names are expanded longest-first. If greedily rewriting the last
        (marked) input produces a string different from the last state, that
        result is appended as a final state.

        Args:
            raw_inputs: the user's step strings, possibly containing rule names.
            rule_lhs: mapping of rule name to marked LHS text, used for the states.
            rules: ``(lhs, rhs, name)`` triples, used to mark the last input.
        """
        states = []
        if not raw_inputs:
            return states

        ordered_lhs = sorted(rule_lhs.items(), key=lambda item: len(item[0]), reverse=True)
        for step in raw_inputs:
            for name, expansion in ordered_lhs:
                if name:
                    step = step.replace(name, expansion)
            # Drop marks and spaces so states compare equal to rewritten strings.
            states.append(step.replace("'", "").replace(" ", ""))

        last_intermediate = raw_inputs[-1]
        for lhs, _, name in sorted(rules, key=lambda rule: len(rule[2]), reverse=True):
            if name:
                last_intermediate = last_intermediate.replace(name, "'".join(lhs) + "'")

        # Ids come from the shared counter so they never collide with rewrite-created ids.
        labeled = [
            self.Element(c, marked, 1, next(self.id_counter))
            for c, marked in self.parse_string(last_intermediate)
        ]

        temp_steps = [[], labeled[:]]
        temp_inter = [labeled[:]]
        result = self.apply_greedy_rules(labeled[:], temp_steps, temp_inter)
        final_string = self.labelled_to_string(result)
        if final_string != states[-1]:
            states.append(final_string)

        return states

    def compare_processes(self, input1, input2, raw_input1, raw_input2, equivalence_message=None):
        """Run both processes, judge equivalence, and show the comparison figure.

        Errors from either process are reported in a message box and abort
        the comparison. ``equivalence_message`` is accepted for compatibility
        but ignored; the computed verdict is always displayed.
        """
        try:
            steps1, inter1 = self.run_greedy_process(input1, process_id="Process 1")
        except ValueError as e:
            messagebox.showerror("Marking Error", f" Error in Process 1:\n\n{e}")
            return
        try:
            steps2, inter2 = self.run_greedy_process(input2, process_id="Process 2")
        except ValueError as e:
            messagebox.showerror("Marking Error", f" Error in Process 2:\n\n{e}")
            return

        # Recreate the original states (rule names expanded) for the top row of plots.
        rule_lhs = {name: "'".join(lhs) + "'" for lhs, _, name in self.rules}
        original_states1 = self.expand_named_rules(raw_input1, rule_lhs, self.rules)
        original_states2 = self.expand_named_rules(raw_input2, rule_lhs, self.rules)

        final_step1 = steps1[-1]
        final_step2 = steps2[-1]

        final_string1 = self.labelled_to_string(final_step1)
        final_string2 = self.labelled_to_string(final_step2)
        strings_match = final_string1 == final_string2

        # Same multiset of (symbol, generation) pairs in the final step.
        structure_match = len(final_step1) == len(final_step2) and sorted(
            (e.char, e.gen) for e in final_step1
        ) == sorted(
            (e.char, e.gen) for e in final_step2
        )

        if strings_match and structure_match:
            result_text = "The processes ARE causally equivalent."
        elif strings_match:
            result_text = "The final strings match, but the causal history is ordered differently."
        else:
            result_text = "The processes are not causally equivalent."

        self.visualise_dual(
            steps1, inter1, steps2, inter2,
            original_states1, original_states2,
            raw_input1, raw_input2,
            equivalence_message=result_text
        )

    def visualise_dual(self, steps1, inter1, steps2, inter2, states1, states2, raw1, raw2, equivalence_message=None):
        """Draw a 2x2 figure: user input (top) and greedy output (bottom) per process.

        Cells are coloured white (padding), red ('A'), blue ('B') or gray
        (any other symbol, including rule-name placeholders).
        """
        def colour_code(c):
            # 1 = 'A', 2 = 'B', 3 = anything else; 0 is reserved for padding.
            return 1 if c == 'A' else 2 if c == 'B' else 3

        def build_grid(steps, inter):
            # Interleave each generation with its intermediate row (if any).
            all_steps, labels = [], []
            for i in range(1, len(steps)):
                all_steps.append(steps[i])
                labels.append(''.join(e.char for e in steps[i]))
                if i < len(inter):
                    all_steps.append(inter[i])
                    labels.append(''.join(e.char for e in inter[i]))
            max_len = max((len(s) for s in all_steps), default=0)
            grid = [[colour_code(e.char) for e in row] + [0] * (max_len - len(row)) for row in all_steps]
            return grid, labels

        def build_raw(states, intermediates):
            # Interleave each raw user line with the plain state that follows it.
            rows, labels = [], []
            rows.append([colour_code(c) for c in states[0]])
            labels.append(states[0])
            for i, raw_line in enumerate(intermediates):
                inter = raw_line.replace(' ', '')
                rows.append([colour_code(c) for c in inter])
                labels.append(inter)
                if i + 1 < len(states):
                    rows.append([colour_code(c) for c in states[i + 1]])
                    labels.append(states[i + 1])
            max_len = max(len(row) for row in rows)
            padded = [row + [0] * (max_len - len(row)) for row in rows]
            return padded, labels

        colormap = ListedColormap(['white', 'red', 'blue', 'gray'])

        def plot(ax, grid, labels, title, sublabel):
            # Fixed vmin/vmax so code k always maps to colour k, whatever codes are present.
            ax.imshow(grid, cmap=colormap, vmin=0, vmax=3, aspect='auto')
            ax.set_yticks(np.arange(len(labels)))
            ax.set_yticklabels(labels)
            ax.set_xticks([])
            ax.set_title(f"{title}\n{sublabel}", fontsize=10)

        fig, axs = plt.subplots(2, 2, figsize=(14, 10))
        raw_grid1, raw_labels1 = build_raw(states1, raw1)
        raw_grid2, raw_labels2 = build_raw(states2, raw2)
        plot(axs[0, 0], raw_grid1, raw_labels1, "Original Input 1", '')
        plot(axs[0, 1], raw_grid2, raw_labels2, "Original Input 2", '')
        grid1, labels1 = build_grid(steps1, inter1)
        grid2, labels2 = build_grid(steps2, inter2)
        plot(axs[1, 0], grid1, labels1, "Greedy Output 1", '')
        plot(axs[1, 1], grid2, labels2, "Greedy Output 2", '')
        rule_text = '\n'.join([f"{name}: {''.join(lhs)} -> {rhs}" for lhs, rhs, name in self.rules])
        fig.text(0.91, 0.95, "Rules:\n" + rule_text, fontsize=10, ha='left', va='top', bbox=dict(facecolor='white', edgecolor='black'))
        if equivalence_message:
            fig.text(0.5, 0.01, equivalence_message, ha='center', va='center', fontsize=14,
                     color='green' if 'ARE' in equivalence_message else 'red',
                     bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.5'))

        plt.tight_layout()
        plt.show()

class CausalApp:
    """Three-page Tkinter wizard: rule count -> rule definitions -> process steps."""

    def __init__(self, root):
        self.root = root
        self.root.title("Causal Process Equivalence Checker")
        self.rules = []            # (lhs, rhs, name) string triples entered by the user
        self.rule_entries = []     # (lhs, rhs, name) Entry widgets shown on the rule page
        self.rule_count = 0        # Number of rules the user wants to define
        self.welcome_page()

    def welcome_page(self):
        """Show page 1: ask how many rewrite rules the user wants to define."""
        self.clear_window()
        label = ttk.Label(self.root, text="How many rewrite rules do you want to define?")
        label.pack(pady=10)
        self.rule_count_var = tk.IntVar()
        rule_entry = ttk.Entry(self.root, textvariable=self.rule_count_var)
        rule_entry.pack(pady=5)
        submit_button = ttk.Button(self.root, text="Submit", command=self.rule_input_page)
        submit_button.pack(pady=10)

    def rule_input_page(self):
        """Validate the rule count (1-10), then show one row of entries per rule."""
        try:
            # IntVar.get() raises TclError when the entry does not hold a number.
            self.rule_count = self.rule_count_var.get()
            if self.rule_count <= 0 or self.rule_count > 10:
                raise ValueError
        except (tk.TclError, ValueError):
            messagebox.showerror("Error", "Please enter a number between 1 and 10 for the number of rules.")
            return

        self.clear_window()
        self.rule_entries = []

        for i in range(self.rule_count):
            frame = ttk.Frame(self.root)
            frame.pack(pady=5)
            ttk.Label(frame, text=f"Rule {i+1}:").grid(row=0, column=0)
            lhs_entry = ttk.Entry(frame)
            lhs_entry.grid(row=0, column=1)
            ttk.Label(frame, text="->").grid(row=0, column=2)
            rhs_entry = ttk.Entry(frame)
            rhs_entry.grid(row=0, column=3)
            ttk.Label(frame, text="Name:").grid(row=0, column=4)
            name_entry = ttk.Entry(frame)
            name_entry.grid(row=0, column=5)
            self.rule_entries.append((lhs_entry, rhs_entry, name_entry))

        proceed_button = ttk.Button(self.root, text="Proceed to Input Sequences", command=self.process_input_page)
        proceed_button.pack(pady=10)

    def process_input_page(self):
        """Collect the rules (all fields required), then show the two step editors."""
        self.rules.clear()
        for lhs_entry, rhs_entry, name_entry in self.rule_entries:
            lhs = lhs_entry.get().strip()
            rhs = rhs_entry.get().strip()
            name = name_entry.get().strip()
            if not lhs or not rhs or not name:
                messagebox.showerror("Error", "Please fill in all fields for all rules.")
                return
            self.rules.append((lhs, rhs, name))

        self.clear_window()

        ttk.Label(self.root, text="Enter process 1 steps (one per line):").pack()
        self.process1_text = tk.Text(self.root, height=10, width=50)
        self.process1_text.pack(pady=5)
        ttk.Label(self.root, text="Enter process 2 steps (one per line):").pack()
        self.process2_text = tk.Text(self.root, height=10, width=50)
        self.process2_text.pack(pady=5)
        run_button = ttk.Button(self.root, text="Check Equivalence", command=self.run_equivalence_check)
        run_button.pack(pady=10)

    def run_equivalence_check(self):
        """Expand rule names in both processes and hand them to ``CausalProcess``.

        Blank lines are ignored; a blank step could never match the current
        state and would only produce a confusing error.
        """
        raw1 = [line.strip() for line in self.process1_text.get("1.0", tk.END).splitlines() if line.strip()]
        raw2 = [line.strip() for line in self.process2_text.get("1.0", tk.END).splitlines() if line.strip()]
        if not raw1 or not raw2:
            messagebox.showerror("Error", "Both processes must contain at least one step.")
            return

        named_rules = {name: "'".join(lhs) + "'" for lhs, _, name in self.rules}
        # Longest names first so that e.g. "r1" cannot clobber part of "r10".
        ordered = sorted(named_rules.items(), key=lambda item: len(item[0]), reverse=True)

        def expand(step):
            for name, rep in ordered:
                step = step.replace(name, rep)
            return step

        expanded1 = [expand(step) for step in raw1]
        expanded2 = [expand(step) for step in raw2]

        # compare_processes runs both greedy processes itself and reports any
        # marking error in a dialog, so no separate validation pass is needed.
        cp = CausalProcess(list(self.rules))
        cp.compare_processes(expanded1, expanded2, raw1, raw2)

    def clear_window(self):
        """Destroy every widget currently attached to the root window."""
        for widget in self.root.winfo_children():
            widget.destroy()

if __name__ == "__main__":
    # Build the top-level window, attach the wizard, and hand control to Tk's event loop.
    window = tk.Tk()
    app = CausalApp(window)  # keep a module-level reference to the running app
    window.mainloop()
