#!/usr/bin/env python3
"""
Visualize OpenStreetMap road features for a state,
colored directly by their 'layer' attribute value.

✅ Features:
 - Colors each road by its numeric layer value (OSM typically uses -5 .. +5)
 - Uses a diverging colormap, RdYlBu_r (blue→yellow→red, low→high)
 - Treats missing or non-numeric layer values as 0
 - Saves the PNG into the state's output folder
"""

import argparse
import sys
from pathlib import Path
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt

def plot_layer_map(roads, title, save_path, *, dpi=300, linewidth=0.3, cmap="RdYlBu_r"):
    """Plot roads colored by their ``layer`` value and save the figure.

    Args:
        roads: GeoDataFrame with a ``layer`` column. The caller in this file
            coerces it to numeric and fills gaps with 0 before calling.
        title: Title drawn above the map.
        save_path: Destination file. Matplotlib infers the format from the
            suffix.
        dpi: Output resolution passed to ``savefig`` (default 300).
        linewidth: Line width used for every road geometry (default 0.3).
        cmap: Matplotlib colormap name. The default ``RdYlBu_r`` maps low
            layers to blue and high layers to red.
    """
    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        # Color by the 'layer' column. legend=True makes geopandas attach a
        # colorbar, configured through legend_kwds.
        roads.plot(
            column="layer",
            cmap=cmap,
            linewidth=linewidth,
            ax=ax,
            legend=True,
            legend_kwds={
                "label": "OSM Layer Value",
                "orientation": "vertical",
                "shrink": 0.7,
                "pad": 0.02,
            },
        )

        ax.set_title(title, fontsize=14, pad=10)
        ax.axis("off")
        # Operate on this figure explicitly rather than pyplot's implicit
        # "current figure", which another caller could have changed.
        fig.tight_layout()
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
    print(f"💾 Saved colored layer map to: {save_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Visualize roads colored by their OSM 'layer' attribute."
    )
    parser.add_argument("state", help="State name (e.g., alabama).")
    args = parser.parse_args()

    state = args.state.lower()
    base_dir = Path("output") / state
    roads_path = base_dir / f"{state}_roads_projected.gpkg"

    if not roads_path.exists():
        print(f"❌ File not found: {roads_path}")
        sys.exit(1)

    print(f"📦 Loading roads for {state.capitalize()}...")
    roads = gpd.read_file(roads_path, layer="roads_projected")

    # Make sure 'layer' exists and is numeric
    if "layer" not in roads.columns:
        print("⚠️ No 'layer' column found — defaulting all to 0.")
        roads["layer"] = 0
    else:
        roads["layer"] = pd.to_numeric(roads["layer"], errors="coerce").fillna(0)

    title = f"{state.capitalize()} Road Network by OSM Layer"
    output_png = base_dir / f"{state}_road_layers.png"
    plot_layer_map(roads, title, output_png)
