#!/usr/bin/env python3
"""
Calculates total road length per grid cell for a given state.

✅ Features:
 - Reads {state}_roads_projected.gpkg and {state}_grid_<size>.gpkg
 - Computes total road length per tile (in meters)
 - Writes {state}_road_density_<size>.gpkg
"""

import argparse
import os
import sys
from pathlib import Path
import geopandas as gpd

CELL_SIZE_ENV_VAR = "GRID_CELL_SIZE_METERS"


def resolve_cell_size(explicit_value):
    """Resolve the grid cell size from the CLI value or the environment.

    Args:
        explicit_value: Cell size supplied by the caller (string or int).
            ``None`` or an empty string means "not provided", in which case
            the environment variable named by ``CELL_SIZE_ENV_VAR`` is used.

    Returns:
        The cell size as a strictly positive ``int``.

    Raises:
        ValueError: If no value is available, the value cannot be converted
            to an integer, or the resulting size is zero or negative.
    """
    # Only None / "" count as "not provided"; an explicit 0 must be validated
    # (and rejected) rather than silently replaced by the environment value.
    if explicit_value not in (None, ""):
        raw_value = explicit_value
    else:
        raw_value = os.getenv(CELL_SIZE_ENV_VAR)

    if raw_value is None:
        raise ValueError(
            f"Cell size must be passed as an argument or provided via {CELL_SIZE_ENV_VAR}."
        )

    try:
        cell_size = int(raw_value)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Invalid cell size '{raw_value}'. Provide an integer or set {CELL_SIZE_ENV_VAR}."
        ) from exc

    # A non-positive size cannot describe a grid and would produce labels
    # such as "0km" or "-1km" further down the pipeline.
    if cell_size <= 0:
        raise ValueError(
            f"Cell size must be a positive integer, got {cell_size}."
        )
    return cell_size


def calculate_density(roads, grid):
    """Calculate total road length per grid cell using in-cell lengths.

    Roads are clipped to the grid cells with an intersection overlay, the
    length of every clipped piece is measured, and the lengths are summed per
    cell. Cells that no road touches get a length of 0.

    Args:
        roads: GeoDataFrame of road geometries; only its ``geometry`` column
            is used.
        grid: GeoDataFrame of grid cells.

    Returns:
        A copy of ``grid`` (same rows, index and columns) with an added or
        overwritten ``length_m`` column holding the summed clipped length.
        Lengths are in the units of the layers' CRS — presumably meters,
        since the script reads a "*_roads_projected" layer; this is not
        verified here.

    Raises:
        RuntimeError: If the intersection overlay fails.
    """
    print("📊 Computing road segments within each grid cell...")

    # Identify cells by position rather than by the grid's index. This keeps
    # the result correct when the index is named, non-unique or multi-level,
    # and when the grid already has columns called "index" or "cell_id".
    cell_ids = list(range(len(grid)))
    cells = gpd.GeoDataFrame(
        {"cell_id": cell_ids},
        # .values drops the index so geometries are assigned positionally
        # instead of being aligned on index labels.
        geometry=grid.geometry.values,
        crs=grid.crs,
    )

    try:
        pieces = gpd.overlay(roads[["geometry"]], cells, how="intersection")
    except Exception as exc:
        raise RuntimeError(f"Failed to compute road/grid intersections: {exc}") from exc

    print(f"✅ Created {len(pieces):,} road pieces clipped to cells.")

    pieces["length_m"] = pieces.geometry.length

    print("🧮 Summing clipped lengths per grid cell...")
    length_by_cell = pieces.groupby("cell_id")["length_m"].sum()

    # Work on a copy so the caller's grid is left untouched; reindexing over
    # every cell ID gives cells without roads a NaN that becomes 0.
    grid_result = grid.copy()
    grid_result["length_m"] = length_by_cell.reindex(cell_ids).fillna(0).to_numpy()

    print("✅ Density calculated with in-cell lengths.")
    return grid_result


def parse_args():
    """Parse the script's command-line arguments.

    Returns:
        An ``argparse.Namespace`` with ``state`` (required positional) and
        ``cell_size`` (optional positional; ``None`` when omitted).
    """
    cell_size_help = (
        "Grid cell size in meters. "
        f"Defaults to env var {CELL_SIZE_ENV_VAR} if unset."
    )

    cli = argparse.ArgumentParser(
        description="Calculate road length per grid cell for a given state.",
    )
    cli.add_argument("state", help="State name (e.g., alabama).")
    # nargs="?" makes the positional optional so the env var can supply it.
    cli.add_argument("cell_size", nargs="?", help=cell_size_help)
    return cli.parse_args()


if __name__ == "__main__":
    # --- Resolve inputs from the command line / environment ---------------
    cli_args = parse_args()
    state = cli_args.state.lower()

    try:
        cell_size = resolve_cell_size(cli_args.cell_size)
    except ValueError as err:
        print(f"❌ {err}")
        sys.exit(1)

    # File and layer names carry the cell size in whole kilometers.
    size_label = f"{cell_size // 1000}km"
    roads_filename = f"{state}_roads_projected.gpkg"

    # --- Locate the working directory --------------------------------------
    # Prefer output/germany/<state> when the roads file exists there,
    # otherwise fall back to output/<state>.
    nested_dir = Path("output", "germany", state)
    flat_dir = Path("output", state)
    base_dir = nested_dir if (nested_dir / roads_filename).exists() else flat_dir

    roads_path = base_dir / roads_filename
    grid_path = base_dir / f"{state}_grid_{size_label}.gpkg"
    output_path = base_dir / f"{state}_road_density_{size_label}.gpkg"

    # Stop at the first required input that is missing (roads is checked first).
    missing = next(
        (candidate for candidate in (roads_path, grid_path) if not candidate.exists()),
        None,
    )
    if missing is not None:
        print(f"❌ Missing input file: {missing}")
        sys.exit(1)

    grid_km = cell_size / 1000
    print(f"📂 Working on {state.capitalize()} ({grid_km:.1f} km grid)")
    print(f"📥 Roads: {roads_path.name}")
    print(f"📥 Grid:  {grid_path.name}")

    # --- Load, compute, save ------------------------------------------------
    roads = gpd.read_file(roads_path, layer="roads_projected")
    grid = gpd.read_file(grid_path, layer=f"grid_{size_label}")

    if any(frame.empty for frame in (roads, grid)):
        print("⚠️ One of the input files is empty. Aborting.")
        sys.exit(1)

    density = calculate_density(roads, grid)

    density.to_file(
        output_path,
        layer=f"road_density_{size_label}",
        driver="GPKG",
    )
    print(f"💾 Saved density results to: {output_path.resolve()}")
