#!/usr/bin/env python3
"""
Generate rasterized road density heatmap (GeoTIFF + PNG) for a location.

✅ Outputs:
 - GeoTIFF raster ({location}_road_density_{size}.tif) for GIS analysis
 - PNG visualization ({location}_road_density_{size}.png) for reporting

✅ Features:
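 - Loads {location}_road_density_{size}.gpkg (layer road_density_{size}) from
   output/{location}/ or output/germany/{location}/
 - Rasterizes the density grid's length_m values into a float32 GeoTIFF (NaN = no data)
 - Supports multiple grid resolutions (2km, 5km, etc.)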
"""

import argparse
import math
import os
import sys
from pathlib import Path

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio.features import rasterize
from rasterio.transform import from_origin

# Environment variable consulted when no cell size is given on the command line.
CELL_SIZE_ENV_VAR = "GRID_CELL_SIZE_METERS"


def _is_missing(value):
    """Return True for ``None`` or a whitespace-only string."""
    return value is None or (isinstance(value, str) and not value.strip())


def resolve_cell_size(explicit_value):
    """Resolve the grid cell size (in meters) from a CLI value or the environment.

    Args:
        explicit_value: Value passed on the command line (usually a string).
            ``None`` or a blank string falls back to the environment variable
            named by ``CELL_SIZE_ENV_VAR``.

    Returns:
        The cell size as a strictly positive ``int``.

    Raises:
        ValueError: If no value is available, the value is not an integer, or
            it is not strictly positive.
    """
    # Only fall back to the environment when the caller gave nothing usable;
    # a falsy-but-present value such as 0 must be validated, not ignored.
    raw_value = explicit_value
    if _is_missing(raw_value):
        raw_value = os.getenv(CELL_SIZE_ENV_VAR)
    if _is_missing(raw_value):
        raise ValueError(
            f"Cell size must be passed as an argument or provided via {CELL_SIZE_ENV_VAR}."
        )
    try:
        cell_size = int(raw_value)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Invalid cell size '{raw_value}'. Provide an integer or set {CELL_SIZE_ENV_VAR}."
        ) from exc
    # A zero or negative size would later cause a division by zero or a
    # negative raster size when the grid is rasterized.
    if cell_size <= 0:
        raise ValueError(
            f"Invalid cell size '{raw_value}'. Cell size must be a positive number of meters."
        )
    return cell_size


def plot_heatmap(
    grid,
    column="length_m",
    cmap="hot",
    title=None,
    save_path=None,
    legend_label="Total Road Length (meters)",
):
    """Plot ``grid`` colored by ``column`` and save it to disk or show it.

    Args:
        grid: GeoDataFrame-like object whose ``.plot`` accepts geopandas'
            ``column``/``cmap``/``ax``/``legend``/``legend_kwds`` arguments.
        column: Attribute column used to color the cells.
        cmap: Matplotlib colormap name.
        title: Figure title; defaults to ``"Road Density (<column>)"``.
        save_path: If truthy, the figure is written there at 300 dpi and then
            closed; otherwise it is displayed with ``plt.show()``.
        legend_label: Colorbar label (defaults to the road-length label).
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    grid.plot(
        column=column,
        cmap=cmap,
        ax=ax,
        legend=True,
        linewidth=0,
        legend_kwds={
            "label": legend_label,
            "shrink": 0.6,  # make the bar shorter
            "pad": 0.02,  # small gap between map and bar
            "orientation": "vertical",
        },
    )

    ax.set_title(title or f"Road Density ({column})", fontsize=14, pad=12)
    ax.axis("off")
    fig.tight_layout()  # ensures full use of figure space

    if save_path:
        try:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
        finally:
            # pyplot keeps every figure alive until it is closed; release it
            # so repeated calls do not accumulate open figures.
            plt.close(fig)
        print(f"💾 Saved heatmap to: {save_path}")
    else:
        plt.show()


def _cells_spanning(extent, cell_size, tolerance=1e-9):
    """Return how many ``cell_size`` cells cover ``extent`` (0 if not finite)."""
    if not math.isfinite(extent):
        return 0
    # Subtract a tiny tolerance so float noise in the bounds (e.g. 10.0000000001
    # cells) does not add a spurious, entirely empty row or column.
    return int(math.ceil(extent / cell_size - tolerance))


def _collect_shapes(geometries, values):
    """Pair each non-empty geometry with its value as float, skipping missing/NaN values."""
    shapes = []
    for geom, value in zip(geometries, values):
        if geom is None or geom.is_empty or value is None:
            continue
        numeric = float(value)
        if math.isnan(numeric):
            continue
        shapes.append((geom, numeric))
    return shapes


def save_density_geotiff(grid, cell_size, out_path, column="length_m"):
    """Rasterize ``grid[column]`` onto a regular grid and write it as a GeoTIFF.

    The raster's top-left corner is the top-left of ``grid.total_bounds``, and
    each pixel is ``cell_size`` x ``cell_size`` CRS units. Pixels not covered by
    any burned feature are filled with NaN, which is also the band's nodata value.

    Args:
        grid: GeoDataFrame whose geometries are burned into the raster
            (presumably the square density cells; confirm against the producer).
        cell_size: Pixel size in CRS units. Callers pass meters, so the CRS
            must be projected rather than geographic.
        out_path: Destination path of the single-band float32 GeoTIFF.
        column: Numeric column whose values are burned into the pixels.

    Raises:
        ValueError: If the column or CRS is missing, the CRS is geographic,
            ``cell_size`` is not positive, the bounds are degenerate, or no
            feature has both a geometry and a numeric value.
    """
    if column not in grid.columns:
        raise ValueError(f"Column '{column}' not found in grid.")
    if grid.crs is None:
        raise ValueError("Grid CRS is undefined; cannot export GeoTIFF.")
    if grid.crs.is_geographic:
        # cell_size is a distance; dividing degree spans by it would silently
        # collapse the raster to a single pixel.
        raise ValueError(
            "Grid CRS is geographic (degrees); reproject to a projected CRS "
            "before exporting a GeoTIFF."
        )
    if cell_size <= 0:
        raise ValueError(f"Cell size must be positive, got {cell_size}.")

    minx, miny, maxx, maxy = grid.total_bounds
    width = _cells_spanning(maxx - minx, cell_size)
    height = _cells_spanning(maxy - miny, cell_size)
    if width <= 0 or height <= 0:
        raise ValueError("Invalid grid bounds; cannot derive raster dimensions.")

    shapes = _collect_shapes(grid.geometry, grid[column])
    if not shapes:
        raise ValueError(
            f"No features with both a geometry and a numeric '{column}' value to rasterize."
        )

    # North-up affine transform anchored at the top-left corner of the bounds.
    transform = from_origin(minx, maxy, cell_size, cell_size)
    raster = rasterize(
        shapes=shapes,
        out_shape=(height, width),
        transform=transform,
        fill=np.nan,
        dtype="float32",
    )

    with rasterio.open(
        out_path,
        "w",
        driver="GTiff",
        height=height,
        width=width,
        count=1,
        dtype="float32",
        crs=grid.crs.to_wkt(),
        transform=transform,
        nodata=np.nan,
    ) as dst:
        dst.write(raster, 1)

def _find_base_dir(location, grid_size_label):
    """Return output/germany/<location> if it holds this location's data, else output/<location>."""
    germany_dir = Path("output") / "germany" / location
    density_file = germany_dir / f"{location}_road_density_{grid_size_label}.gpkg"
    roads_file = germany_dir / f"{location}_roads_projected.gpkg"
    if density_file.exists() or roads_file.exists():
        return germany_dir
    return Path("output") / location


def main(argv=None):
    """Run the CLI: load the density GeoPackage, then write the PNG and GeoTIFF heatmaps.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on success, or if only the GeoTIFF export fails.
        1 on bad input, a missing file, or an empty grid.
    """
    parser = argparse.ArgumentParser(
        description="Generate rasterized road density heatmap (GeoTIFF + PNG)."
    )
    parser.add_argument("location", help="Location name (e.g., alabama, egypt, thailand).")
    parser.add_argument(
        "cell_size",
        nargs="?",
        help=("Grid cell size in meters. "
              f"Defaults to env var {CELL_SIZE_ENV_VAR} if unset."),
    )
    args = parser.parse_args(argv)

    location = args.location.lower()
    try:
        cell_size = resolve_cell_size(args.cell_size)
    except ValueError as err:
        print(f"❌ {err}", file=sys.stderr)
        return 1
    # NOTE(review): sizes that are not multiples of 1000 m are truncated
    # (2500 -> "2km", 500 -> "0km"). Kept as-is because the label must match
    # the file names written by the upstream density step.
    grid_size_label = f"{cell_size//1000}km"

    base_dir = _find_base_dir(location, grid_size_label)
    stem = f"{location}_road_density_{grid_size_label}"
    input_path = base_dir / f"{stem}.gpkg"

    if not input_path.exists():
        print(f"❌ File not found: {input_path}", file=sys.stderr)
        return 1

    print(f"📦 Loading density data from {input_path.name}")
    grid = gpd.read_file(input_path, layer=f"road_density_{grid_size_label}")

    if grid.empty:
        print("⚠️ No features found to plot.", file=sys.stderr)
        return 1

    title = f"{location.capitalize()} Road Density ({grid_size_label} grid)"
    output_png = base_dir / f"{stem}.png"
    output_tif = base_dir / f"{stem}.tif"

    plot_heatmap(grid, column="length_m", cmap="hot", title=title, save_path=output_png)
    try:
        save_density_geotiff(grid, cell_size, output_tif, column="length_m")
    except ValueError as err:
        # GeoTIFF export is best-effort: the PNG has already been written.
        print(f"⚠️ Could not save GeoTIFF: {err}", file=sys.stderr)
    else:
        print(f"🗺️ Saved GeoTIFF heatmap to: {output_tif}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
