#!/usr/bin/env python3
"""
Creates a regular grid (square tiles) covering a state's projected road data.

✅ Features:
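 - Reads each state's {state}_roads_projected.gpkg
   (from output/germany/{state}/ if present there, otherwise output/{state}/)
 - Generates grid tiles of the size given on the command line or via the
   GRID_CELL_SIZE_METERS environment variable (e.g. 2000 for 2 km)
 - Saves as {state}_grid_{N}km.gpkg, where N is the cell size in whole
   kilometres (e.g. {state}_grid_2km.gpkg for 2000 m)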
 - Logs extent and CRS info
"""

import argparse
import os
import sys
from pathlib import Path
import geopandas as gpd
import numpy as np
from shapely.geometry import box

CELL_SIZE_ENV_VAR = "GRID_CELL_SIZE_METERS"


def resolve_cell_size(explicit_value):
    """Resolve the grid cell size from an explicit value or the environment.

    Args:
        explicit_value: Cell size as supplied by the caller (string or int),
            or ``None`` / an empty string when it was not supplied, in which
            case the ``GRID_CELL_SIZE_METERS`` environment variable is used.

    Returns:
        The cell size as a positive ``int``.

    Raises:
        ValueError: If no value is available from either source, if the value
            cannot be converted to an integer, or if it is not greater than
            zero.
    """
    # Only None and "" mean "not supplied". A plain truthiness test would also
    # discard an explicit 0 and silently substitute the environment value.
    if explicit_value is None or explicit_value == "":
        raw_value = os.getenv(CELL_SIZE_ENV_VAR)
    else:
        raw_value = explicit_value

    if raw_value is None:
        raise ValueError(
            f"Cell size must be passed as an argument or provided via {CELL_SIZE_ENV_VAR}."
        )

    try:
        cell_size = int(raw_value)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Invalid cell size '{raw_value}'. Provide an integer or set {CELL_SIZE_ENV_VAR}."
        ) from exc

    # A zero step makes the grid generation fail and a negative one yields an
    # empty grid, so reject both here where the user gets a readable message.
    if cell_size <= 0:
        raise ValueError(
            f"Invalid cell size '{raw_value}'. Cell size must be a positive integer."
        )
    return cell_size


def _axis_origins(lower, upper, cell_size):
    """Return the lower-edge coordinates of the cells tiling [lower, upper]."""
    # Always emit at least one cell so a zero-width extent (a single point or
    # a perfectly vertical/horizontal line) is still covered.
    count = max(1, int(np.ceil((upper - lower) / cell_size)))
    # Derive origins from an integer index instead of a float-stepped arange,
    # whose length is not reliable when the step is not an integer.
    return lower + cell_size * np.arange(count)


def make_grid(gdf, cell_size=2000):
    """Generate a regular grid of square tiles over the extent of ``gdf``.

    The grid is anchored at the lower-left corner of ``gdf.total_bounds`` and
    extends far enough right/up to cover the upper-right corner. Tiles are
    emitted column by column (x outer, y inner). A summary line with the
    number of tiles is printed.

    Args:
        gdf: GeoDataFrame whose ``total_bounds`` and ``crs`` define the grid.
        cell_size: Tile edge length, in the units of ``gdf``'s coordinates
            (the calling script treats these as meters).

    Returns:
        A GeoDataFrame with one ``geometry`` column of square polygons and the
        same CRS as ``gdf``.

    Raises:
        ValueError: If ``cell_size`` is not positive, or if the bounds of
            ``gdf`` are not finite (for example when it has no geometries).
    """
    if cell_size <= 0:
        raise ValueError(f"cell_size must be positive, got {cell_size!r}.")

    bounds = gdf.total_bounds
    if not np.all(np.isfinite(bounds)):
        raise ValueError(
            "Cannot build a grid: the input has no finite bounds (is it empty?)."
        )
    xmin, ymin, xmax, ymax = bounds

    cols = _axis_origins(xmin, xmax, cell_size)
    rows = _axis_origins(ymin, ymax, cell_size)

    polygons = [box(x, y, x + cell_size, y + cell_size)
                for x in cols for y in rows]

    grid = gpd.GeoDataFrame({"geometry": polygons}, crs=gdf.crs)
    print(f"✅ Created {len(grid):,} grid cells.")
    return grid


def parse_args():
    """Parse the command line and return the resulting argparse namespace.

    Two positionals are defined: a required ``state`` and an optional
    ``cell_size``. No type conversion is applied here, so ``cell_size`` is
    returned as the raw string (or ``None`` when it is omitted).
    """
    cell_size_help = (
        f"Grid cell size in meters. Defaults to env var {CELL_SIZE_ENV_VAR} if unset."
    )

    cli = argparse.ArgumentParser(
        description="Create regular grid tiles for a given state's projected roads."
    )
    cli.add_argument("state", help="State name (e.g., alabama).")
    cli.add_argument("cell_size", nargs="?", help=cell_size_help)
    return cli.parse_args()


def main():
    """Command-line entry point: load a state's projected roads and save its grid.

    Resolves the cell size, locates ``{state}_roads_projected.gpkg`` (under
    ``output/germany/{state}`` if it exists there, otherwise
    ``output/{state}``), builds the grid with ``make_grid`` and writes it next
    to the input as ``{state}_grid_{N}km.gpkg``. Exits with status 1 when the
    cell size is invalid, the input file is missing, or it has no features.
    """
    args = parse_args()
    state = args.state.lower()
    try:
        cell_size = resolve_cell_size(args.cell_size)
    except ValueError as err:
        print(f"❌ {err}")
        sys.exit(1)

    # Guard here as well: a zero or negative size cannot produce a usable grid.
    if cell_size <= 0:
        print(f"❌ Cell size must be a positive number of meters, got {cell_size}.")
        sys.exit(1)

    # The file/layer label keeps the whole-kilometre scheme; warn when that
    # label does not fully describe the requested size (e.g. 500 -> "0km",
    # 1500 -> "1km", which would overwrite a real 1 km grid).
    size_label = f"{cell_size // 1000}km"
    if cell_size % 1000:
        print(
            f"⚠️ Cell size {cell_size} m is not a whole number of km; "
            f"output will be labelled '{size_label}'."
        )

    # Check if state is in germany subdirectory first
    germany_dir = Path("output") / "germany" / state
    if (germany_dir / f"{state}_roads_projected.gpkg").exists():
        base_dir = germany_dir
    else:
        base_dir = Path("output") / state

    input_path = base_dir / f"{state}_roads_projected.gpkg"
    output_path = base_dir / f"{state}_grid_{size_label}.gpkg"

    # Validate input
    if not input_path.exists():
        print(f"❌ Projected roads file not found: {input_path}")
        sys.exit(1)

    print(f"📦 Loading roads from {input_path}")
    gdf = gpd.read_file(input_path, layer="roads_projected")

    if gdf.empty:
        print(f"⚠️ No features found in {input_path}")
        sys.exit(1)

    print(f"🗺️ CRS: {gdf.crs}")
    # The cell size is applied in the layer's coordinate units, so flag inputs
    # where those units are unknown or angular rather than metric.
    if gdf.crs is None:
        print("⚠️ Input has no CRS; cell size will be applied in raw coordinate units.")
    elif not gdf.crs.is_projected:
        print("⚠️ CRS is not projected; cell size will not correspond to meters.")

    print(f"📏 Generating grid with {cell_size/1000:.1f} km cells...")

    grid = make_grid(gdf, cell_size=cell_size)

    grid.to_file(output_path, layer=f"grid_{size_label}", driver="GPKG")
    print(f"💾 Saved grid to: {output_path.resolve()}")


if __name__ == "__main__":
    main()
