#!/usr/bin/env python3
import sys
import os
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import subprocess, threading, time

# -------------------------------
# Global sudo password storage
# -------------------------------
# Session-wide sudo password, kept in plaintext in process memory.
# It is written by DiskDigApp.save_password and piped to `sudo -S` by
# run_root_task. A falsy value (None or "") means "not set", in which case
# run_root_task refuses to run anything.
sudo_password = None

# -------------------------------
# Root helper plumbing
# -------------------------------
def run_root_task(args_list, timeout=30):
    """Run a privileged command through ``sudo -S`` using the cached password.

    Args:
        args_list: Command and its arguments as a list. They are passed
            straight to sudo; no shell is involved.
        timeout: Seconds to wait for the command before giving up.

    Returns:
        On success, the combined and stripped stdout+stderr, or
        ``"(no output)"`` if the command printed nothing. On failure this
        returns a human-readable message string instead of raising:
        no password set, timeout, sudo not startable, or the command's
        error output / exit code.
    """
    if not sudo_password:
        return "⚠ No sudo password set."
    # -p "" empties sudo's "[sudo] password for ..." prompt. Without it, -S
    # writes that prompt to stderr and it would pollute the captured output.
    cmd = ["sudo", "-S", "-p", ""] + list(args_list)
    try:
        proc = subprocess.run(
            cmd,
            input=sudo_password + "\n",
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return "Root task timed out."
    except OSError as e:
        # For example, sudo itself is not installed or not executable.
        return f"Root task could not be started: {e}"
    out = ((proc.stdout or "") + (proc.stderr or "")).strip()
    if proc.returncode != 0:
        return out or f"Root task failed with code {proc.returncode}"
    return out or "(no output)"

# -------------------------------
# Utility (unprivileged) helpers
# -------------------------------
def run_cmd(cmd, timeout=8):
    """Run an unprivileged command and return its trimmed output.

    Args:
        cmd: Either a shell command string, which is run through the shell as
            before, or an argument list, which is run directly with no shell
            so paths and device names need no quoting.
        timeout: Seconds before the command is abandoned.

    Returns:
        Stripped stdout. If stdout is empty, stripped stderr. Returns ``""``
        when there is no output, the command timed out, or it could not be
        started.
    """
    # A string keeps the historical shell behaviour; a list avoids the shell.
    use_shell = isinstance(cmd, str)
    try:
        result = subprocess.run(
            cmd, shell=use_shell, capture_output=True, text=True, timeout=timeout
        )
    except (subprocess.TimeoutExpired, OSError):
        return ""
    return result.stdout.strip() or result.stderr.strip()

def run_smartctl_bundle(dev):
    """Fetch identity (-i), attributes (-A) and overall health (-H) for *dev*.

    Everything is gathered in a single privileged smartctl call. The return
    value is whatever run_root_task reports: output text or an error message.
    """
    smartctl_flags = ("-i", "-A", "-H")
    return run_root_task(["smartctl", *smartctl_flags, dev])

def _leading_int(text):
    """Return the integer formed by the leading ASCII digits of *text*.

    Thousands commas are ignored. Returns None if *text* does not start
    with a digit.
    """
    digits = []
    for ch in text.replace(",", ""):
        if ch not in "0123456789":
            break
        digits.append(ch)
    return int("".join(digits)) if digits else None


def _smart_attr_raw(lines, *names):
    """Return the RAW_VALUE of an ATA SMART attribute row as an int.

    Rows in ``smartctl -A`` output have the columns
    ``ID# ATTRIBUTE_NAME FLAG VALUE WORST THRESH TYPE UPDATED WHEN_FAILED
    RAW_VALUE``. RAW_VALUE can carry extra annotations such as
    ``36 (Min/Max 20/45)`` or ``1234h+05m+10.000s``, so only its leading
    digits are used.

    *names* are tried in order. A row matches when its ATTRIBUTE_NAME
    starts with the name. Returns None if no row matches or the value is
    not numeric.
    """
    for name in names:
        for line in lines:
            parts = line.split()
            # parts[0] must be a numeric ID, which also skips the header row.
            if len(parts) >= 10 and parts[0].isdigit() and parts[1].startswith(name):
                return _leading_int(parts[9])
    return None


def _smart_info_value(lines, label):
    """Return the leading integer after ``label:`` on a ``Label: value`` line.

    This is the style used by NVMe output, e.g. ``Temperature:  35 Celsius``.
    Returns None if no such line exists.
    """
    prefix = label + ":"
    for line in lines:
        text = line.strip()
        if text.startswith(prefix):
            return _leading_int(text[len(prefix):].strip())
    return None


def _smart_health(lines):
    """Return the overall-health verdict printed by ``smartctl -H``.

    Typical values are ``PASSED``, ``FAILED!``, or ``OK`` for SCSI/SAS.
    Returns None if there is no verdict line.
    """
    for line in lines:
        text = line.strip()
        if ("self-assessment test result:" in text
                or text.startswith("SMART Health Status:")):
            return text.split(":", 1)[1].strip()
    return None


def interpret_smart_summary(smart_text, hot_temp_c=55):
    """Translate smartctl ``-i -A -H`` output into plain-English verdicts.

    Args:
        smart_text: Raw smartctl output (ATA attribute table and/or NVMe
            ``Label: value`` lines). Error text or None yields no findings.
        hot_temp_c: Temperature in °C above which the drive is reported as
            running hot. The default of 55 is a conservative heuristic, not a
            vendor limit.

    Returns:
        Newline-joined verdicts in this order: reallocated sectors,
        power-on hours, temperature, overall health. If none can be
        derived, returns ``"No interpretable SMART fields found."``.
    """
    lines = (smart_text or "").splitlines()
    analysis = []

    # --- Reallocated sector count ---
    realloc = _smart_attr_raw(lines, "Reallocated_Sector_Ct")
    if realloc is not None:
        if realloc == 0:
            analysis.append("Media health is excellent – no bad sectors remapped.")
        else:
            analysis.append("Drive has reallocated sectors – monitor closely, risk is rising.")

    # --- Hours on drive (ATA attribute, else NVMe "Power On Hours:") ---
    hours = _smart_attr_raw(lines, "Power_On_Hours")
    if hours is None:
        hours = _smart_info_value(lines, "Power On Hours")
    if hours is not None:
        if hours > 30000:
            analysis.append(f"Drive has {hours} hours – beyond typical design life. Higher failure risk.")
        elif hours > 20000:
            analysis.append(f"Drive has {hours} hours – mid-life, expect eventual wear.")
        else:
            analysis.append(f"Drive has {hours} hours – relatively young, low wear.")

    # --- Temperature: judge the actual reading, not just the word's presence ---
    temp = _smart_attr_raw(lines, "Temperature_Celsius", "Airflow_Temperature_Cel")
    if temp is None:
        temp = _smart_info_value(lines, "Temperature")
    # Ignore implausible readings; some drives pack extra data into the raw value.
    if temp is not None and temp <= 150:
        if temp > hot_temp_c:
            analysis.append(f"Temperature is {temp}°C – running hot. Higher failure risk.")
        else:
            analysis.append("Temperature is within a normal range.")

    # --- Overall health: read only the -H verdict line. Matching "FAILED"
    # anywhere is wrong because the -A table header contains "WHEN_FAILED". ---
    health = _smart_health(lines)
    if health is not None:
        if health.startswith("PASSED") or health == "OK":
            analysis.append("SMART firmware check: PASSED – no internal warnings.")
        elif "FAIL" in health.upper():
            analysis.append("SMART firmware check: FAILED – replace immediately.")

    return "\n".join(analysis) if analysis else "No interpretable SMART fields found."

def analyze_results(write_avg, read_avg):
    """Turn average write/read throughput (MB/s) into short verdict lines.

    The checks are:
    - write speed below 30 MB/s is flagged;
    - read speed below 10 MB/s is flagged;
    - an extra warning is added when reads are under a fifth of the write
      speed.
    """
    write_msg = "⚠ Write speed is lower than expected."
    if write_avg >= 30:
        write_msg = "✓ Write speed looks healthy."

    read_msg = "⚠ Read speed is extremely slow."
    if read_avg >= 10:
        read_msg = "✓ Read speed looks reasonable."

    verdicts = [write_msg, read_msg]
    if read_avg < write_avg / 5:
        verdicts.append("⚠ Reads disproportionately slower than writes – possible media/controller issue.")
    return "\n".join(verdicts)

# -------------------------------
# Main App Class
# -------------------------------
class DiskDigApp(tk.Tk):
    """DiskDig main window.

    The window:
    - lists mounted volumes (from ``df``);
    - shows SMART data for the disk behind a selected path;
    - runs ``dd`` read/write throughput tests in a chosen folder;
    - can unmount and power off a device through ``udisksctl``.

    Long-running work happens on daemon threads. Widget updates from those
    threads are always scheduled back onto the Tk event loop with
    ``self.after``.
    """

    # Caret/selection movement keys that stay usable in the read-only SMART box.
    _NAV_KEYSYMS = frozenset(
        {"Left", "Right", "Up", "Down", "Home", "End", "Prior", "Next"}
    )

    def __init__(self):
        super().__init__()
        try:
            self.call('tk', 'appname', 'diskdig')
        except Exception as e:
            print("appname set failed:", e)

        self.title("DiskDig Prototype")
        self.geometry("850x760")
        self.path_map = {}
        # Bumped for every drive-info lookup so late results from an older
        # selection are discarded instead of overwriting newer ones.
        self._info_request = 0
        # True while a read/write test thread is running; prevents overlapping
        # runs that would fight over the same scratch file.
        self._test_running = False

        self._build_ui()
        self.list_volumes()

    # -------------------------------
    # UI construction
    # -------------------------------
    def _build_ui(self):
        """Create and lay out all widgets and bind their callbacks."""
        frame = ttk.Frame(self, padding=10)
        frame.pack(fill="both", expand=True)

        # Options + Auth frame (top row)
        menu_frame = ttk.Frame(frame)
        menu_frame.pack(fill="x", pady=(0, 5))

        # Authorization field (left)
        auth_frame = ttk.Frame(menu_frame)
        auth_frame.pack(side="left", padx=5)

        ttk.Label(auth_frame, text="Authorization Required:").pack(side="left")

        self.pw_entry = ttk.Entry(auth_frame, width=15, show="*")
        self.pw_entry.pack(side="left", padx=2)

        save_btn = ttk.Button(auth_frame, text="Save", command=self.save_password, style="Auth.TButton")
        save_btn.pack(side="left")

        style = ttk.Style()
        style.configure("Auth.TButton", foreground="white", background="blue")

        # Style for the action buttons
        style.configure("Action.TButton",
                        background="#444444",  # dark gray
                        foreground="white",
                        font=("TkDefaultFont", 10, "bold"))
        style.map("Action.TButton",
                  background=[("active", "#666666")])

        # Options menubutton (right)
        menubtn = tk.Menubutton(menu_frame, text="Options", relief="raised")
        menu = tk.Menu(menubtn, tearoff=0)
        menu.add_command(label="Clear Results", command=self.clear_results)
        menu.add_command(label="Refresh Drives", command=self.refresh_all)
        menu.add_separator()
        menu.add_command(label="Exit", command=self.quit)
        menubtn.config(menu=menu)
        menubtn.pack(anchor="ne")

        # Mounted drives
        ttk.Label(frame, text="Select A Currently Mounted Drive", font=("TkDefaultFont", 11, "bold")).pack(anchor="w")
        vol_frame = ttk.Frame(frame)
        vol_frame.pack(fill="x", pady=(0, 5))
        self.volumes_box = tk.Listbox(vol_frame, height=4, font=("TkFixedFont", 11))
        self.volumes_box.pack(side="left", fill="x", expand=True)
        scroll = ttk.Scrollbar(vol_frame, orient="vertical", command=self.volumes_box.yview)
        scroll.pack(side="right", fill="y")
        self.volumes_box.config(yscrollcommand=scroll.set)
        self.volumes_box.bind("<<ListboxSelect>>", self.on_volume_select)

        # SMART box
        ttk.Label(frame, text="SMART Overview:", font=("TkDefaultFont", 11, "bold")).pack(anchor="w")
        smart_frame = ttk.Frame(frame)
        smart_frame.pack(fill="both", expand=True, pady=5)
        smart_scroll = tk.Scrollbar(smart_frame, orient="vertical")
        self.smart_box = tk.Text(smart_frame, wrap="word", height=10, yscrollcommand=smart_scroll.set)
        self.smart_box.pack(side="left", fill="both", expand=True)
        self.smart_box.tag_configure("warning", background="red", foreground="white")
        self.smart_box.tag_configure("good", background="green", foreground="white")
        self.smart_box.tag_configure("success", background="#d6f5d6", foreground="black")
        self.smart_box.tag_configure("fail", background="#f5d6d6", foreground="black")
        self.smart_box.tag_configure("caution", background="#fff5cc", foreground="black")
        self.smart_box.tag_configure("heading", font=("TkDefaultFont", 11, "bold"))
        smart_scroll.config(command=self.smart_box.yview)
        smart_scroll.pack(side="right", fill="y")

        # Read-only: swallow editing keys, but keep caret movement and the
        # keyboard copy shortcut working. Returning "break" for every key
        # would also block Ctrl+C and the arrow keys.
        def block_typing(event):
            if event.keysym in self._NAV_KEYSYMS:
                return None
            # 0x4 = Control, 0x8 = Mod1 (Command on macOS): let <<Copy>> run.
            if event.state & (0x4 | 0x8) and event.keysym.lower() == "c":
                return None
            return "break"
        self.smart_box.bind("<Key>", block_typing)

        # Right-click copy menu
        self.smart_menu = tk.Menu(self.smart_box, tearoff=0)
        self.smart_menu.add_command(label="Copy", command=self.copy_smart_selection)
        self.smart_menu.add_command(label="Select All", command=self.select_all_smart)
        self.smart_box.bind("<Button-3>", self.show_smart_menu)  # Right-click (Linux/Win)
        self.smart_box.bind("<Button-2>", self.show_smart_menu)  # Middle-click (macOS)

        ttk.Separator(frame, orient="horizontal").pack(fill="x", pady=8)

        # Controls
        control_frame = ttk.Frame(frame)
        control_frame.pack(anchor="w", pady=(0, 10), fill="x")
        ttk.Label(control_frame, text="Read/Write Test", font=("TkDefaultFont", 10, "bold")).pack(side="left", padx=(0, 15))
        ttk.Label(control_frame, text="Read/write file size (MB):").pack(side="left")
        self.size_entry = ttk.Entry(control_frame, width=8)
        self.size_entry.insert(0, "100")
        self.size_entry.pack(side="left", padx=(5, 20))
        ttk.Label(control_frame, text="Repeat how many times:").pack(side="left")
        self.repeats_entry = ttk.Entry(control_frame, width=5)
        self.repeats_entry.insert(0, "5")
        self.repeats_entry.pack(side="left", padx=5)

        ttk.Label(frame, text="Test path (folder to create dummy file):").pack(anchor="w")
        path_frame = ttk.Frame(frame)
        path_frame.pack(anchor="w", pady=(0, 10), fill="x")
        self.path_entry = ttk.Entry(path_frame, width=40)
        self.path_entry.pack(side="left", fill="x", expand=True)
        browse_btn = ttk.Button(path_frame, text="Browse", command=self.browse_path)
        browse_btn.pack(side="left", padx=5)

        # --- Button Row ---
        btn_row = ttk.Frame(frame)
        btn_row.pack(fill="x", pady=5)

        start_btn = ttk.Button(btn_row, text="Begin Read/Write Test",
                               command=self.start_monitor, style="Action.TButton")
        start_btn.pack(side="left")

        self.unmount_entry = ttk.Entry(btn_row, width=12)
        self.unmount_entry.insert(0, "/dev/sdX")
        self.unmount_entry.pack(side="right", padx=(5, 0))

        unmount_btn = ttk.Button(btn_row, text="Unmount Drive:",
                                 command=self.unmount_drive, style="Action.TButton")
        unmount_btn.pack(side="right", padx=(10, 5))

        # Drive info
        ttk.Label(frame, text="Drive Info / Test Status:", font=("TkDefaultFont", 11, "bold")).pack(anchor="w")
        output_frame = ttk.Frame(frame)
        output_frame.pack(fill="both", pady=5, expand=True)
        yscroll = tk.Scrollbar(output_frame, orient="vertical")
        self.output_box = tk.Text(output_frame, wrap="none", height=6, yscrollcommand=yscroll.set)
        self.output_box.pack(side="left", fill="both", expand=True)
        yscroll.config(command=self.output_box.yview)
        yscroll.pack(side="right", fill="y")

        # Analysis
        ttk.Label(frame, text="Results of Read/Write Test", font=("TkDefaultFont", 10, "bold")).pack(anchor="w")
        self.analysis_box = tk.Text(frame, wrap="word", height=6)
        self.analysis_box.pack(fill="both", expand=False, pady=5)

    # -------------------------------
    # Small shared helpers
    # -------------------------------
    @staticmethod
    def _run_args(args, timeout=8):
        """Run *args* directly, with no shell, and capture its output.

        Returns:
            ``(stdout, stderr)``, both stripped. If the command cannot start
            or times out, returns ``("", reason)``. Because no shell is used,
            paths and device names containing quotes or spaces are safe.
        """
        try:
            proc = subprocess.run(args, capture_output=True, text=True, timeout=timeout)
        except subprocess.TimeoutExpired:
            return "", f"{args[0]} timed out"
        except OSError as e:
            return "", str(e)
        return proc.stdout.strip(), proc.stderr.strip()

    @staticmethod
    def _parse_positive_int(text, default):
        """Return *text* as an int if it is a valid integer >= 1, else *default*."""
        try:
            value = int(text)
        except (TypeError, ValueError):
            return default
        return value if value > 0 else default

    @staticmethod
    def _raw_realloc(smart_output):
        """Return the raw Reallocated_Sector_Ct value from smartctl output.

        The value is the last column of the first matching line. Returns
        None if the line is absent or its last column is not an integer.
        """
        for line in smart_output.splitlines():
            if "Reallocated_Sector_Ct" in line:
                try:
                    return int(line.split()[-1])
                except ValueError:
                    return None
        return None

    @staticmethod
    def _summary_tag(line):
        """Return the smart_box text tag used to highlight a summary *line*.

        Returns None when the line should be inserted without a tag.
        """
        if line.strip() == "SUMMARY OF ANALYSIS":
            return "heading"
        if "Reallocated_Sector_Ct" in line:
            return "good" if "= 0 " in line else "warning"
        if "PASSED" in line:
            return "success"
        if "FAILED" in line:
            return "fail"
        if ("Warning" in line
                or "monitor closely" in line
                or "Higher failure risk." in line):
            return "caution"
        return None

    def _set_output(self, text):
        """Replace the Drive Info / Test Status box contents with *text* (Tk thread only)."""
        self.output_box.delete("1.0", tk.END)
        self.output_box.insert(tk.END, text)

    def _append_output(self, text):
        """Append *text* to the Drive Info / Test Status box (Tk thread only)."""
        self.output_box.insert(tk.END, text)

    # -------------------------------
    # SMART text copy helpers
    # -------------------------------
    def show_smart_menu(self, event):
        """Pop up the Copy / Select All context menu at the pointer."""
        try:
            self.smart_menu.tk_popup(event.x_root, event.y_root)
        finally:
            self.smart_menu.grab_release()

    def copy_smart_selection(self):
        """Copy the SMART box selection to the clipboard (no-op if nothing is selected)."""
        try:
            selection = self.smart_box.get("sel.first", "sel.last")
        except tk.TclError:
            return  # nothing selected
        self.clipboard_clear()
        self.clipboard_append(selection)
        self.update()

    def select_all_smart(self):
        """Select all text in the SMART box."""
        self.smart_box.tag_add("sel", "1.0", "end-1c")

    # -------------------------------
    # Authorization
    # -------------------------------
    def save_password(self):
        """Cache the entered sudo password in the module-level ``sudo_password``.

        The password is stored verbatim, not stripped, because leading or
        trailing spaces are legal in passwords. A blank or whitespace-only
        entry clears the cached value.
        """
        global sudo_password
        entered = self.pw_entry.get()
        if entered.strip():
            sudo_password = entered
            messagebox.showinfo("Authorization", "Password saved for this session.")
        else:
            sudo_password = ""
            messagebox.showwarning("Authorization", "No password entered.")

    # -------------------------------
    # Menu actions
    # -------------------------------
    def clear_results(self):
        """Empty the SMART, status and analysis text boxes."""
        self.smart_box.delete("1.0", tk.END)
        self.output_box.delete("1.0", tk.END)
        self.analysis_box.delete("1.0", tk.END)

    def refresh_all(self):
        """Clear results and the test path, then re-list mounted volumes."""
        self.clear_results()
        self.path_entry.delete(0, tk.END)
        self.list_volumes()

    # -------------------------------
    # Unmount + Power-off
    # -------------------------------
    def unmount_drive(self):
        """Unmount and power off the device typed into the unmount entry.

        Both steps go through run_root_task. That function reports failures
        as returned text rather than exceptions, so the dialog shows the
        commands' output instead of claiming success.
        """
        dev = self.unmount_entry.get().strip()
        # Reject the "/dev/sdX" placeholder and typos before invoking root.
        if not dev.startswith("/dev/") or not os.path.exists(dev):
            messagebox.showerror("No Device", "Please enter a valid device path like /dev/sdX")
            return
        try:
            out1 = run_root_task(["udisksctl", "unmount", "-b", dev])
            out2 = run_root_task(["udisksctl", "power-off", "-b", dev])
        except Exception as e:
            messagebox.showerror("Error", f"Failed to unmount/power-off {dev}\n\n{e}")
            return
        combined = (out1 + "\n" + out2).strip()
        messagebox.showinfo("Unmount / Power-off",
                            f"Requested unmount and power-off of {dev}.\n\n{combined}")

    # -------------------------------
    # Volume + SMART
    # -------------------------------
    def list_volumes(self):
        """Repopulate the volume list from ``df -hT``.

        Pseudo and network filesystems and system mountpoints are skipped.
        Each displayed entry is mapped to its mountpoint in ``path_map``.
        """
        self.volumes_box.delete(0, tk.END)
        self.path_map.clear()
        try:
            proc = subprocess.run(["df", "-hT"], capture_output=True, text=True, timeout=10)
        except (OSError, subprocess.TimeoutExpired) as e:
            self.output_box.insert(tk.END, f"Error listing volumes: {e}\n")
            return
        # df can exit non-zero because one mount is unreadable while still
        # listing all the others, so only give up if it printed nothing.
        df_lines = proc.stdout.splitlines()
        if not df_lines:
            reason = proc.stderr.strip() or "df printed nothing"
            self.output_box.insert(tk.END, f"Error listing volumes: {reason}\n")
            return
        for line in df_lines[1:]:
            # maxsplit=6 keeps mountpoints that contain spaces intact.
            parts = line.split(None, 6)
            if len(parts) < 7:
                continue
            device, fstype, size, _used, avail, _percent, mountpoint = parts
            mountpoint = mountpoint.rstrip()
            if fstype in ("tmpfs", "devtmpfs", "overlay", "cifs", "smbfs"):
                continue
            if mountpoint.startswith(("/boot", "/proc", "/sys", "/run")):
                continue
            uuid, _ = self._run_args(["blkid", "-s", "UUID", "-o", "value", device])
            uuid = uuid or "root"
            entry = f"{device:<12} [UUID:{uuid:<12}] {size:>6} total {avail:>6} free"
            self.volumes_box.insert(tk.END, entry)
            self.path_map[entry] = mountpoint

    def on_volume_select(self, event):
        """Put the selected volume's mountpoint in the test path and load its drive info."""
        sel = self.volumes_box.curselection()
        if not sel:
            return
        entry = self.volumes_box.get(sel[0])
        mountpoint = self.path_map.get(entry)
        if mountpoint:
            self.path_entry.delete(0, tk.END)
            self.path_entry.insert(0, mountpoint)
            self.fetch_drive_info_async(mountpoint)

    def get_drive_info_and_smart(self, path):
        """Resolve the whole-disk device backing *path* and collect its info.

        Returns:
            ``(base_device, details_text, smartctl_output)`` on success, or
            ``(None, warning_text, "")`` when the device cannot be resolved or
            an error occurs.
        """
        try:
            found, _ = self._run_args(["findmnt", "-n", "-o", "SOURCE", "--target", path])
            source = found.splitlines()[0] if found else ""
            # btrfs subvolumes and bind mounts report e.g. "/dev/sda2[/@home]".
            source = source.split("[", 1)[0].strip()
            if not source.startswith("/dev/"):
                return None, "⚠ Could not determine device for selected path.", ""
            parent, _ = self._run_args(["lsblk", "-no", "PKNAME", source])
            # PKNAME is empty for a whole disk and may list several parents
            # (e.g. device-mapper); use the first one.
            parent_lines = parent.splitlines()
            base_name = (parent_lines[0].strip() if parent_lines else "") or os.path.basename(source)
            base = f"/dev/{base_name}"
            out, err = self._run_args(["lsblk", "-o", "NAME,MODEL,SERIAL,VENDOR,SIZE", base])
            lsblk_info = out or err
            smart_output = run_smartctl_bundle(base)
            details = f"Drive info for {base}:\n\n--- lsblk ---\n{lsblk_info}\n\n--- smartctl (i/A/H) ---\n{smart_output}"
            return base, details, smart_output
        except Exception as e:
            return None, f"⚠ Error getting drive info: {e}", ""

    def fetch_drive_info_async(self, path):
        """Look up drive info and SMART data for *path* on a worker thread.

        When the lookup finishes, the summary, the raw smartctl output and the
        drive info are rendered on the Tk thread. Only the most recent request
        is rendered; a lookup that finishes after a newer one was started is
        discarded.
        """
        self._info_request += 1
        request_id = self._info_request

        def worker():
            base, info, smart_output = self.get_drive_info_and_smart(path)
            if base:
                analysis = interpret_smart_summary(smart_output)
                realloc_count = self._raw_realloc(smart_output)

                summary_lines = ["SUMMARY OF ANALYSIS"]
                if realloc_count is not None:
                    if realloc_count == 0:
                        summary_lines.append(f"Reallocated_Sector_Ct = {realloc_count}  ✓ No sectors remapped.")
                    else:
                        summary_lines.append(f"Reallocated_Sector_Ct = {realloc_count}  ⚠ Warning: sectors remapped!")
                if analysis:
                    summary_lines.append(analysis)

                summary_block = "\n".join(summary_lines) + "\n" + "-" * 40 + "\n\n"
            else:
                summary_block = "SUMMARY OF ANALYSIS\n⚠ SMART overview not available.\n\n"
                smart_output = ""

            def update_ui():
                if request_id != self._info_request:
                    return  # superseded by a newer selection
                self.smart_box.delete("1.0", tk.END)
                for line in summary_block.splitlines():
                    tag = self._summary_tag(line)
                    if tag:
                        self.smart_box.insert(tk.END, line + "\n", tag)
                    else:
                        self.smart_box.insert(tk.END, line + "\n")

                # Then raw smartctl output
                if smart_output:
                    self.smart_box.insert(tk.END, smart_output + "\n")

                self.output_box.delete("1.0", tk.END)
                self.output_box.insert(tk.END, info + "\n\n")

            self.after(0, update_ui)

        threading.Thread(target=worker, daemon=True).start()

    # -------------------------------
    # IO test + Monitor
    # -------------------------------
    @staticmethod
    def _timed_dd(args):
        """Run dd with *args* (no shell) and return the elapsed seconds (always > 0).

        Raises:
            RuntimeError: dd exited non-zero. The message carries dd's stderr,
                so a failed transfer is not mistaken for a very fast one.
            OSError: dd could not be started.
        """
        start = time.perf_counter()
        proc = subprocess.run(args, capture_output=True, text=True)
        elapsed = time.perf_counter() - start
        if proc.returncode != 0:
            raise RuntimeError(proc.stderr.strip() or f"dd exited with code {proc.returncode}")
        return max(elapsed, 1e-9)

    def io_test(self, path, size_mb):
        """Write and then read back a *size_mb* MB scratch file in *path* using dd.

        The write uses ``conv=fdatasync`` so its timing includes the flush to
        the device. The read uses ``iflag=direct``. The scratch file is
        removed afterwards, even when a step fails.

        Returns:
            ``(write_MBps, read_MBps)``.

        Raises:
            RuntimeError: a dd step failed (e.g. no space, permission denied,
                or the filesystem rejects direct I/O).
            OSError: dd could not be started.
        """
        test_file = os.path.join(path, "diskmon_testfile.tmp")
        try:
            write_secs = self._timed_dd([
                "dd", "if=/dev/zero", f"of={test_file}", "bs=1M",
                f"count={size_mb}", "conv=fdatasync", "status=none",
            ])
            read_secs = self._timed_dd([
                "dd", f"if={test_file}", "of=/dev/null", "bs=1M",
                "iflag=direct", "status=none",
            ])
        finally:
            # Best-effort cleanup; must not mask the real error.
            try:
                os.remove(test_file)
            except OSError:
                pass
        return size_mb / write_secs, size_mb / read_secs

    def monitor(self, size_mb, repeats, path):
        """Run io_test *repeats* times and report per-run and average speeds.

        This method is meant to run on a worker thread (see start_monitor).
        Tkinter widgets should only be touched from the main thread, so every
        widget update here is scheduled onto the Tk event loop with
        ``self.after`` rather than done directly.
        """
        try:
            self.after(0, self._set_output,
                       f"Running {repeats} test(s) with {size_mb} MB file…\n\n")
            ws_all, rs_all = [], []

            for i in range(repeats):
                try:
                    ws, rs = self.io_test(path, size_mb)
                except Exception as e:
                    self.after(0, self._append_output, f"⚠ Error during test {i+1}: {e}\n")
                    continue
                ws_all.append(ws)
                rs_all.append(rs)
                self.after(
                    0, self._append_output,
                    f"✓ Test {i+1}/{repeats} complete. Write: {ws:.1f} MB/s | Read: {rs:.1f} MB/s\n"
                )

            if ws_all and rs_all:
                write_avg = sum(ws_all) / len(ws_all)
                read_avg = sum(rs_all) / len(rs_all)
                result_text = (
                    f"\n✓ All tests finished.\n"
                    f"Avg Write: {write_avg:.1f} MB/s | Avg Read: {read_avg:.1f} MB/s\n"
                )
                analysis = analyze_results(write_avg, read_avg)
            else:
                result_text = "⚠ No valid results collected.\n"
                analysis = "⚠ Unable to analyze."

            def update_ui():
                self.output_box.insert(tk.END, result_text)
                self.analysis_box.delete("1.0", tk.END)
                self.analysis_box.insert(tk.END, analysis)

            self.after(0, update_ui)
        finally:
            self._test_running = False

    def start_monitor(self):
        """Validate the test inputs and launch monitor on a daemon thread.

        Invalid or non-positive values fall back to a 20 MB size and 1 repeat.
        Nothing is started if the path is missing or not a directory, or if a
        test is already running.
        """
        size_mb = self._parse_positive_int(self.size_entry.get(), 20)
        repeats = self._parse_positive_int(self.repeats_entry.get(), 1)
        path = self.path_entry.get().strip()
        if not path:
            self.output_box.insert(tk.END, "⚠ No test path selected!\n")
            return
        if not os.path.isdir(path):
            self.output_box.insert(tk.END, f"⚠ Test path is not a directory: {path}\n")
            return
        if self._test_running:
            self.output_box.insert(tk.END, "⚠ A read/write test is already running.\n")
            return
        self._test_running = True
        threading.Thread(
            target=self.monitor, args=(size_mb, repeats, path), daemon=True
        ).start()

    def get_initial_mount_dir(self):
        """Return the first existing directory among the usual removable-media roots, else '/'."""
        user = os.getenv("USER") or os.getenv("LOGNAME") or ""
        candidates = [
            f"/media/{user}",
            f"/run/media/{user}",
            "/mnt",
            "/media",
            "/"
        ]
        for c in candidates:
            if os.path.isdir(c):
                return c
        return "/"

    def browse_path(self):
        """Let the user pick a test directory, then load that drive's info."""
        initial = self.get_initial_mount_dir()
        chosen = filedialog.askdirectory(initialdir=initial, title="Select test directory")
        if chosen:
            self.path_entry.delete(0, tk.END)
            self.path_entry.insert(0, chosen)
            self.fetch_drive_info_async(chosen)

# -------------------------------
# Main entry
# -------------------------------
if __name__ == "__main__":
    # Build the main window and hand control to Tk's event loop.
    DiskDigApp().mainloop()
