#!/usr/bin/env python3
"""
Generate heatmaps for all key OSM road attributes (numeric and categorical).

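Enhancements:
 - Legend lists only the categories actually present in the data
 - Collapses long-tail categories into "Other" when there are more than 20
 - Uses adaptive colormaps (Set2, or tab20 when there are more than 8 categories)
 - Clips numeric color scales to the 5th–95th percentile range
 - Prints per-value counts for diagnostics
 - Draws missing values in white and skips attributes that have no data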
"""

import argparse
import sys
from pathlib import Path
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colormaps


# --------------------------------
# Helper: Plot a single attribute
# --------------------------------
def plot_attribute_map(gdf, attr, state, save_dir, cmap="inferno", categorical=False,
                       max_categories=20):
    """Render one attribute of a road GeoDataFrame to ``<save_dir>/<state>_<attr>_heatmap.png``.

    Args:
        gdf: GeoDataFrame of road geometries containing column ``attr``.
        attr: Name of the column to colour roads by.
        state: State name; used in the plot title and the output file name.
        save_dir: Directory (``Path`` or ``str``) the PNG is written to.
        cmap: Matplotlib colormap name for numeric attributes. Ignored when
            ``categorical=True``; categorical maps pick Set2 or tab20 adaptively.
        categorical: Plot ``attr`` as discrete categories instead of a
            continuous (percentile-clipped) colour scale.
        max_categories: Upper bound on legend entries for categorical maps.
            When exceeded, the ``max_categories - 1`` most frequent values are
            kept and the rest are lumped into "Other". Keep it <= 20 so tab20
            has a distinct colour per category.

    Returns:
        The ``Path`` of the written PNG, or ``None`` when the attribute had no
        usable data and was skipped.
    """
    save_dir = Path(save_dir)
    # Only the plotted column and the geometry are needed; copying every column
    # of a state-wide road network once per attribute wastes memory.
    gdf_local = gdf[[attr, gdf.geometry.name]].copy()

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        # Neutral background
        fig.patch.set_facecolor("#f2f2f2")
        ax.set_facecolor("#f2f2f2")

        print(f"\n🔎 Unique {attr} values in {state}:")
        print(gdf_local[attr].value_counts(dropna=False))

        if categorical:
            plotted = _plot_categorical(gdf_local, attr, ax, max_categories)
        else:
            plotted = _plot_numeric(gdf_local, attr, ax, cmap)
        if not plotted:
            return None

        ax.set_title(f"{state.capitalize()} — {attr.capitalize()}", fontsize=14, pad=10)
        ax.axis("off")
        fig.tight_layout()

        out_path = save_dir / f"{state}_{attr}_heatmap.png"
        fig.savefig(out_path, dpi=300, bbox_inches="tight", facecolor=fig.get_facecolor())
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)

    print(f"💾 Saved heatmap: {out_path.name}")
    return out_path


def _missing_style():
    """Return a fresh ``missing_kwds`` mapping that draws missing values in white."""
    return {"color": "white", "label": "Missing", "edgecolor": "none"}


def _lump_rare_categories(values, max_categories):
    """Keep the most frequent ``max_categories - 1`` values and relabel the rest "Other".

    Missing values stay missing so they are still drawn with the "Missing" style
    rather than being folded into "Other".
    """
    keep = values.value_counts().head(max(max_categories - 1, 1)).index
    rare = values.notna() & ~values.isin(keep)
    return values.mask(rare, "Other")


def _plot_categorical(gdf_local, attr, ax, max_categories):
    """Draw ``attr`` as discrete categories on ``ax``; return False if there is nothing to draw."""
    n_unique = gdf_local[attr].nunique(dropna=True)
    if n_unique == 0:
        print(f"⚠️ No values found for {attr}, skipping.")
        return False

    # Too many categories: keep the most common ones plus a single "Other" bucket,
    # so the total never exceeds max_categories (and tab20's 20 distinct colours).
    if n_unique > max_categories:
        print(f"⚠️ Too many categories in {attr}, showing top {max_categories - 1} + Other.")
        gdf_local[attr] = _lump_rare_categories(gdf_local[attr], max_categories)
        n_unique = gdf_local[attr].nunique(dropna=True)

    # Set2 has 8 colours; beyond that switch to tab20 so categories stay distinct.
    cmap_name = "tab20" if n_unique > 8 else "Set2"

    # No "labels" in legend_kwds: GeoPandas already labels each legend entry with
    # its category, and an explicit "labels" clashes with the handles it passes
    # to Axes.legend, producing a wrong legend.
    gdf_local.plot(
        column=attr,
        ax=ax,
        legend=True,
        categorical=True,
        cmap=cmap_name,
        linewidth=0.3,
        legend_kwds={
            "bbox_to_anchor": (1.02, 1),
            "loc": "upper left",
            "title": attr,
        },
        missing_kwds=_missing_style(),
    )
    return True


def _plot_numeric(gdf_local, attr, ax, cmap):
    """Draw ``attr`` on a continuous colour scale on ``ax``; return False if no numeric data."""
    # Non-numeric tag values (e.g. "3 m", "2;3") become NaN and are drawn as missing.
    gdf_local[attr] = pd.to_numeric(gdf_local[attr], errors="coerce")
    valid = gdf_local[attr].dropna()
    if valid.empty:
        print(f"⚠️ No numeric data for {attr}, skipping.")
        return False

    # Clip the colour scale to the 5th–95th percentiles so outliers don't wash it out.
    vmin, vmax = np.percentile(valid, [5, 95])
    if vmin == vmax:
        # Low-cardinality integer data (e.g. most roads have 2 lanes) can collapse
        # the percentile band to a single value, painting every road one colour;
        # fall back to the full data range.
        vmin, vmax = float(valid.min()), float(valid.max())

    cmap_obj = colormaps.get_cmap(cmap).copy()
    cmap_obj.set_bad(color="white")

    gdf_local.plot(
        column=attr,
        cmap=cmap_obj,
        linewidth=0.3,
        ax=ax,
        legend=True,
        vmin=vmin,
        vmax=vmax,
        legend_kwds={"label": attr, "shrink": 0.6, "pad": 0.02},
        missing_kwds=_missing_style(),
    )
    return True


# --------------------------------
# Main execution
# --------------------------------
def main(argv=None):
    """Parse CLI arguments and render one map per available road attribute.

    Reads ``<data-dir>/<state>/<state>_roads_projected.gpkg`` (layer
    ``roads_projected``) and writes the PNGs next to it.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (handy for tests).

    Returns:
        Process exit status: 1 when the input GeoPackage is missing, otherwise 0.
    """
    parser = argparse.ArgumentParser(
        description="Generate per-road attribute heatmaps for a state."
    )
    parser.add_argument("state", help="State name (e.g., alabama).")
    parser.add_argument(
        "--data-dir",
        type=Path,
        default=Path("output"),
        help="Root directory holding <state>/<state>_roads_projected.gpkg (default: output).",
    )
    args = parser.parse_args(argv)

    state = args.state.lower()
    gdf_path = args.data_dir / state / f"{state}_roads_projected.gpkg"
    save_dir = gdf_path.parent

    if not gdf_path.exists():
        print(f"❌ File not found: {gdf_path}")
        return 1

    print(f"📦 Loading road data for {state.capitalize()}...")
    roads = gpd.read_file(gdf_path, layer="roads_projected")

    numeric_attrs = ["lanes", "width"]
    categorical_attrs = ["surface", "material", "bridge", "tunnel", "layer"]

    available_numeric = [a for a in numeric_attrs if a in roads.columns]
    available_categorical = [a for a in categorical_attrs if a in roads.columns]

    print(f"📊 Numeric attributes: {available_numeric}")
    print(f"🎨 Categorical attributes: {available_categorical}")

    if not (available_numeric or available_categorical):
        print("⚠️ None of the expected road attributes are present; nothing to plot.")
        return 0

    for attr in available_numeric:
        print(f"\n📈 Plotting {attr} heatmap...")
        plot_attribute_map(roads, attr, state, save_dir, cmap="inferno", categorical=False)

    for attr in available_categorical:
        print(f"\n🎨 Plotting {attr} map...")
        plot_attribute_map(roads, attr, state, save_dir, cmap="Set2", categorical=True)

    print("\n✅ All attribute heatmaps generated successfully!")
    return 0


if __name__ == "__main__":
    sys.exit(main())
