#!/usr/bin/env python3
"""
Generate rasterized density heatmaps (GeoTIFF + PNG) for roads with specific attribute values.

Instead of aggregating all roads per cell, this creates SEPARATE density heatmaps for each
distinct attribute value, showing where roads with that specific characteristic are concentrated.

Outputs:
 - {location}_{attribute}_{value}_density_{size}.tif - Raster density map
 - {location}_{attribute}_{value}_density_{size}.png - Visualization
 - {location}_{attribute}_{value}_density_{size}_hotspots_{area}km2.geojson - Contours

Examples:
 - egypt_highway_motorway_density_2km.tif (density of motorway roads)
 - egypt_highway_primary_density_2km.tif (density of primary roads)
 - egypt_bridge_yes_density_2km.tif (density of bridge roads)
 - egypt_lanes_2_density_2km.tif (density of 2-lane roads)
"""

import argparse
import math
import os
import sys
from pathlib import Path
import geopandas as gpd
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio.features import rasterize
from rasterio.transform import from_origin
from matplotlib import colormaps
from skimage import measure
from shapely.geometry import Polygon

CELL_SIZE_ENV_VAR = "GRID_CELL_SIZE_METERS"


def resolve_cell_size(explicit_value):
    """Resolve the grid cell size in meters.

    Precedence: an explicitly supplied value wins; otherwise the
    GRID_CELL_SIZE_METERS environment variable is consulted.

    Raises:
        ValueError: if no value is available, the value cannot be
            parsed as an integer, or the value is not strictly positive.
    """
    raw_value = explicit_value or os.getenv(CELL_SIZE_ENV_VAR)
    if raw_value is None:
        raise ValueError(
            f"Cell size must be passed as an argument or provided via {CELL_SIZE_ENV_VAR}."
        )
    try:
        cell_size = int(raw_value)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Invalid cell size '{raw_value}'. Provide an integer or set {CELL_SIZE_ENV_VAR}."
        ) from exc
    # A non-positive cell size would later cause division-by-zero (raster
    # dimension math divides by cell_size), so reject it up front.
    if cell_size <= 0:
        raise ValueError(f"Cell size must be a positive number of meters, got {cell_size}.")
    return cell_size


# --------------------------------
# Helper: Calculate density for roads with specific attribute value
# --------------------------------
def calculate_density_for_value(roads, grid, attr, value):
    """Compute road-length density per grid cell for one attribute value.

    Filters `roads` to rows where `attr` equals `value` (NaN-aware),
    intersects them with `grid`, and sums intersected segment lengths
    per cell.

    Returns:
        tuple: (grid cropped to cells with length > 0, density column
        name "length_m", full grid with the density column), or
        (None, None, None) when there is nothing to compute.
    """
    print(f"📊 Computing density for {attr}={value}...")

    # NaN needs isna(); an `== NaN` comparison would match nothing.
    if pd.isna(value):
        filtered_roads = roads[roads[attr].isna()]
    else:
        filtered_roads = roads[roads[attr] == value]

    if filtered_roads.empty:
        print(f"   ⚠️ No roads found with {attr}={value}")
        return None, None, None

    print(f"   ✅ Found {len(filtered_roads):,} roads with {attr}={value}")

    # Assign positional cell ids directly. The previous
    # reset_index().rename(columns={"index": "cell_id"}) silently failed
    # when the grid index was named (the reset column is then not called
    # "index"), which made the later "cell_id" lookup raise KeyError.
    grid_local = grid.copy()
    grid_local["cell_id"] = range(len(grid_local))

    # Clip road geometries to grid cells.
    try:
        intersections = gpd.overlay(
            filtered_roads[["geometry"]],
            grid_local[["cell_id", "geometry"]],
            how="intersection",
        )
    except Exception as exc:
        print(f"   ⚠️ Failed to compute intersections: {exc}")
        return None, None, None

    # Length (in CRS units — assumed meters for a projected CRS) of each
    # clipped road segment.
    intersections["length_m"] = intersections.geometry.length

    # Total road length per cell.
    grouped = intersections.groupby("cell_id")["length_m"].sum()

    # Map sums back onto the grid; cells with no roads get 0. The original
    # index is preserved because we never reset it.
    grid_local["length_m"] = grid_local["cell_id"].map(grouped).fillna(0)
    grid_result = grid_local.drop(columns=["cell_id"])

    # Crop to cells that actually contain this road type.
    grid_with_roads = grid_result[grid_result["length_m"] > 0].copy()

    if grid_with_roads.empty:
        print(f"   ⚠️ No grid cells contain {attr}={value} roads")
        return None, None, None

    print(f"   📍 Cropped to {len(grid_with_roads):,} cells (out of {len(grid_result):,} total)")

    return grid_with_roads, "length_m", grid_result


# --------------------------------
# Helper: Save attribute as GeoTIFF
# --------------------------------
def save_attribute_geotiff(grid, attr_col, cell_size, out_path):
    """Rasterize one grid attribute column and write it as a GeoTIFF."""
    if attr_col not in grid.columns:
        raise ValueError(f"Column '{attr_col}' not found in grid.")
    if grid.crs is None:
        raise ValueError("Grid CRS is undefined; cannot export GeoTIFF.")

    # Derive raster dimensions from the grid extent and cell size.
    x_min, y_min, x_max, y_max = grid.total_bounds
    n_cols = int(math.ceil((x_max - x_min) / cell_size))
    n_rows = int(math.ceil((y_max - y_min) / cell_size))
    if n_cols <= 0 or n_rows <= 0:
        raise ValueError("Invalid grid bounds; cannot derive raster dimensions.")

    # Anchor the affine transform at the top-left corner of the extent.
    transform = from_origin(x_min, y_max, cell_size, cell_size)

    # Keep only (geometry, value) pairs with a usable numeric value;
    # None/NaN cells are left as nodata.
    shapes = [
        (geom, float(val))
        for geom, val in zip(grid.geometry, grid[attr_col])
        if val is not None and not (isinstance(val, float) and math.isnan(val))
    ]

    if not shapes:
        print(f"⚠️ No valid data to rasterize for {attr_col}")
        return

    raster = rasterize(
        shapes=shapes,
        out_shape=(n_rows, n_cols),
        transform=transform,
        fill=np.nan,
        dtype="float32",
    )

    # Write a single-band float32 GeoTIFF with NaN as nodata.
    profile = {
        "driver": "GTiff",
        "height": n_rows,
        "width": n_cols,
        "count": 1,
        "dtype": "float32",
        "crs": grid.crs.to_wkt(),
        "transform": transform,
        "nodata": np.nan,
    }
    with rasterio.open(out_path, "w", **profile) as dst:
        dst.write(raster.astype("float32"), 1)


# --------------------------------
# Helper: Extract contours from GeoTIFF
# --------------------------------
def extract_contours_from_geotiff(tif_path, target_area_km2):
    """Extract hotspot contour polygons from a density GeoTIFF.

    Picks the value threshold whose exceedance region covers roughly
    `target_area_km2`, binarizes the raster at that threshold, traces
    the boundary contours, and georeferences them.

    Returns:
        GeoDataFrame (reprojected to WGS84) with cluster_id, area_km2
        and geometry columns, or None when the raster has no valid data
        or no valid polygon can be built.
    """
    with rasterio.open(tif_path) as src:
        data = src.read(1).astype("float32")
        nodata = src.nodata
        # Normalize invalid pixels to NaN. The non-finite mask is applied
        # unconditionally because a NaN nodata value (which this pipeline
        # writes) can never be matched with `==` — NaN != NaN — so the
        # old `data == nodata` branch silently did nothing in that case.
        data = np.where(np.isfinite(data), data, np.nan)
        if nodata is not None and math.isfinite(nodata):
            data = np.where(data == nodata, np.nan, data)
        transform = src.transform
        crs = src.crs

    # Threshold such that roughly target_area_km2 worth of pixels exceed it.
    pixel_width = abs(transform[0])
    pixel_height = abs(transform[4])
    pixel_area_m2 = pixel_width * pixel_height
    pixel_area_km2 = pixel_area_m2 / 1_000_000
    target_pixels = target_area_km2 / pixel_area_km2

    valid_data = data[np.isfinite(data)]
    if len(valid_data) == 0:
        return None

    # Sort descending: the Nth largest value is the threshold.
    sorted_values = np.sort(valid_data)[::-1]
    threshold_idx = min(int(target_pixels), len(sorted_values) - 1)
    threshold = sorted_values[threshold_idx]

    # Binary mask of above-threshold pixels; trace its 0.5-level contours.
    binary_mask = np.where(np.isfinite(data) & (data >= threshold), 1, 0).astype(np.uint8)
    contours = measure.find_contours(binary_mask, level=0.5)

    if not contours:
        return None

    # Convert pixel-space (row, col) contours to georeferenced polygons.
    polygons = []
    for contour in contours:
        coords = []
        for row, col in contour:
            x = transform[2] + col * transform[0]
            y = transform[5] + row * transform[4]
            coords.append((x, y))

        if len(coords) >= 3:
            try:
                poly = Polygon(coords)
                if poly.is_valid and not poly.is_empty:
                    polygons.append(poly)
            except Exception:
                # Skip degenerate rings that Shapely cannot build.
                continue

    if not polygons:
        return None

    # Wrap polygons with per-cluster area in km² (CRS units assumed meters).
    clusters = []
    for idx, geom in enumerate(polygons, start=1):
        area_km2 = geom.area / 1_000_000
        clusters.append({
            "cluster_id": idx,
            "area_km2": area_km2,
            "geometry": geom,
        })

    gdf = gpd.GeoDataFrame(clusters, geometry="geometry", crs=crs)

    # GeoJSON convention: deliver in WGS84.
    if gdf.crs and gdf.crs.to_epsg() != 4326:
        gdf = gdf.to_crs(epsg=4326)

    return gdf


# --------------------------------
# Helper: Plot and save density heatmap for specific attribute value
# --------------------------------
def plot_and_save_density_heatmap(
    grid, attr, value, location, save_dir, size_label, cell_size
):
    """Plot and save all outputs for roads with one attribute value.

    Writes a PNG heatmap, a GeoTIFF density raster, and (when hotspot
    contours can be extracted) a GeoJSON, all cropped to the extent of
    the supplied grid.
    """
    # Sanitize the value so it is safe inside a filename.
    value_str = str(value).replace("/", "_").replace(" ", "_").replace(":", "_")
    base_name = f"{location}_{attr}_{value_str}_density_{size_label}"

    fig, ax = plt.subplots(figsize=(12, 10))
    fig.patch.set_facecolor("white")
    ax.set_facecolor("white")

    # Outline of the actual (cropped) data extent.
    extent_outline = grid.dissolve().boundary

    title = f"{location.capitalize()} — {attr}={value} Road Density ({size_label})"

    # Color scale clipped to the 5th–95th percentile. Compute the valid
    # values once instead of calling dropna() three times on one line.
    valid_lengths = grid["length_m"].dropna()
    if valid_lengths.empty:
        vmin, vmax = 0, 1
    else:
        vmin, vmax = np.nanpercentile(valid_lengths, [5, 95])
    cmap_obj = colormaps.get_cmap("hot").copy()
    cmap_obj.set_bad(color="white")

    grid.plot(
        column="length_m",
        cmap=cmap_obj,
        linewidth=0,
        ax=ax,
        legend=True,
        vmin=vmin,
        vmax=vmax,
        legend_kwds={"label": "Road Length (m)", "shrink": 0.6, "pad": 0.02},
        missing_kwds={"color": "white"}
    )

    # Draw the extent outline on top of the heatmap.
    extent_outline.plot(ax=ax, color='black', linewidth=1.0)

    # Crop axes to the data bounds with a 5% margin.
    bounds = grid.total_bounds
    margin = 0.05 * max(bounds[2] - bounds[0], bounds[3] - bounds[1])
    ax.set_xlim(bounds[0] - margin, bounds[2] + margin)
    ax.set_ylim(bounds[1] - margin, bounds[3] + margin)

    ax.set_title(title, fontsize=14, pad=10)
    ax.axis("off")
    plt.tight_layout()

    # Save PNG
    out_png = save_dir / f"{base_name}.png"
    plt.savefig(out_png, dpi=300, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    print(f"   💾 Saved PNG: {out_png.name}")

    # Save GeoTIFF and, if possible, hotspot contours as GeoJSON.
    out_tif = save_dir / f"{base_name}.tif"
    try:
        save_attribute_geotiff(grid, "length_m", cell_size, out_tif)
        print(f"   🗺️ Saved GeoTIFF: {out_tif.name}")

        # Hotspot target area (km²) is configurable via env var.
        target_area = float(os.getenv("HOTSPOT_TARGET_AREA_KM2", "500"))
        contours_gdf = extract_contours_from_geotiff(out_tif, target_area)
        if contours_gdf is not None and not contours_gdf.empty:
            out_geojson = save_dir / f"{base_name}_hotspots_{int(target_area)}km2.geojson"
            contours_gdf.to_file(out_geojson, driver="GeoJSON")
            print(f"   📍 Saved GeoJSON: {out_geojson.name} ({len(contours_gdf)} regions)")
    except Exception as e:
        # Best-effort: raster/contour export failure must not abort the run.
        print(f"   ⚠️ Could not save GeoTIFF/GeoJSON: {e}")


# --------------------------------
# Main execution
# --------------------------------
if __name__ == "__main__":
    # --- CLI parsing -----------------------------------------------------
    parser = argparse.ArgumentParser(
        description="Generate grid-based attribute heatmaps (GeoTIFF + PNG) for a location."
    )
    parser.add_argument("location", help="Location name (e.g., alabama, egypt, thailand).")
    parser.add_argument(
        "cell_size",
        nargs="?",
        help=("Grid cell size in meters. "
              f"Defaults to env var {CELL_SIZE_ENV_VAR} if unset."),
    )
    args = parser.parse_args()

    location = args.location.lower()
    try:
        cell_size = resolve_cell_size(args.cell_size)
    except ValueError as err:
        print(f"❌ {err}")
        sys.exit(1)
    # NOTE(review): integer-km label — a 2500 m cell would be labeled
    # "2km"; confirm cell sizes are whole-km multiples.
    size_label = f"{cell_size//1000}km"

    # --- Locate input files ----------------------------------------------
    # Some locations live under output/germany/<location>; prefer that
    # layout when a projected roads file exists there.
    germany_dir = Path("output") / "germany" / location
    if (germany_dir / f"{location}_roads_projected.gpkg").exists():
        base_dir = germany_dir
    else:
        base_dir = Path("output") / location
    
    gdf_path = base_dir / f"{location}_roads_projected.gpkg"
    grid_path = base_dir / f"{location}_grid_{size_label}.gpkg"
    # All outputs are written next to the roads file.
    save_dir = gdf_path.parent

    if not gdf_path.exists():
        print(f"❌ Roads file not found: {gdf_path}")
        sys.exit(1)

    if not grid_path.exists():
        print(f"❌ Grid file not found: {grid_path}")
        sys.exit(1)

    print(f"📦 Loading road and grid data for {location.capitalize()}...")
    roads = gpd.read_file(gdf_path, layer="roads_projected")
    grid = gpd.read_file(grid_path, layer=f"grid_{size_label}")

    # Define surface categories based on OSM wiki
    # https://wiki.openstreetmap.org/wiki/Key:surface
    def categorize_surface(surface_val):
        """Categorize an OSM surface value as 'paved'/'unpaved'/'unknown'.

        Returns None for missing values. Matching is case-insensitive
        substring matching against OSM surface keywords.
        """
        if pd.isna(surface_val):
            return None
        
        surface_lower = str(surface_val).lower().strip()
        
        # Paved indicators (hard, sealed surfaces)
        paved_keywords = [
            'asphalt', 'concrete', 'cement', 'paved', 'paving_stones', 'sett',
            'cobblestone', 'brick', 'metal', 'wood', 'chipseal', 'tartan',
            'boardwalk', 'tile', 'plastic', 'rubber', 'acrylic', 'composite'
        ]
        
        # Unpaved indicators (natural, unsealed surfaces)
        unpaved_keywords = [
            'unpaved', 'dirt', 'gravel', 'sand', 'ground', 'earth', 'soil',
            'compacted', 'rock', 'stone', 'grass', 'mud', 'clay', 'scree',
            'pebble', 'shale', 'woodchip', 'mulch', 'shells', 'cinder',
            'decomposed_granite', 'crushed_stone', 'fine_gravel', 'trail'
        ]
        
        # BUG FIX: 'unpaved' contains the substring 'paved', so it must be
        # handled before the paved-keyword scan or surface=unpaved would be
        # misclassified as 'paved'.
        if 'unpaved' in surface_lower:
            return 'unpaved'
        
        # Paved keywords take precedence (e.g. 'cobblestone' over 'stone').
        for keyword in paved_keywords:
            if keyword in surface_lower:
                return 'paved'
        
        # Then check for unpaved keywords
        for keyword in unpaved_keywords:
            if keyword in surface_lower:
                return 'unpaved'
        
        # Default to unknown for unrecognized values
        return 'unknown'
    
    # Derive a coarse paved/unpaved/unknown category from the raw OSM
    # 'surface' tag (skipped when the column is absent from this extract).
    if 'surface' in roads.columns:
        roads['surface_category'] = roads['surface'].apply(categorize_surface)
        print(f"📊 Surface categories: {roads['surface_category'].value_counts().to_dict()}")
    
    # Add derived highway_category column for major network types
    # Based on OSM highway classification: https://wiki.openstreetmap.org/wiki/Key:highway
    def categorize_highway(highway_val):
        """Map an OSM highway tag onto one of the major network types.

        Returns 'drive', 'bike', 'walk', 'hike', or 'unknown'; None when
        the input value is missing.
        """
        if pd.isna(highway_val):
            return None

        tag = str(highway_val).lower().strip()

        # Category -> member highway tags, per the OSM highway
        # classification: https://wiki.openstreetmap.org/wiki/Key:highway
        categories = {
            # Drive: motorized vehicle roads
            'drive': {
                'motorway', 'motorway_link', 'trunk', 'trunk_link',
                'primary', 'primary_link', 'secondary', 'secondary_link',
                'tertiary', 'tertiary_link', 'unclassified', 'residential',
                'service', 'road',
            },
            # Bike: cycling infrastructure (may allow pedestrians too)
            'bike': {
                'cycleway', 'bicycle_road',
            },
            # Walk: pedestrian paths (paved, urban)
            'walk': {
                'pedestrian', 'footway', 'steps', 'corridor', 'crossing',
                'sidewalk', 'bridleway',
            },
            # Hike: trails and unpaved paths (rural, recreational)
            'hike': {
                'path', 'track', 'trail',
            },
        }

        for category, members in categories.items():
            if tag in members:
                return category
        return 'unknown'
    
    # Derive the drive/bike/walk/hike network category from the raw OSM
    # 'highway' tag (skipped when the column is absent from this extract).
    if 'highway' in roads.columns:
        roads['highway_category'] = roads['highway'].apply(categorize_highway)
        print(f"📊 Highway categories: {roads['highway_category'].value_counts().to_dict()}")

    # ONLY generate heatmaps for major categorical groupings
    # This keeps output manageable and focuses on useful high-level patterns
    attributes_to_analyze = [
        "highway_category",  # drive/bike/walk/hike/unknown
        "surface_category",  # paved/unpaved/unknown
    ]

    print(f"🎯 Generating density heatmaps for major categories:")
    print(f"   - highway_category: network type (drive/bike/walk/hike)")
    print(f"   - surface_category: surface type (paved/unpaved)")
    
    # Process each categorical attribute
    for attr in attributes_to_analyze:
        # A derived column is missing when its source tag was absent above.
        if attr not in roads.columns:
            print(f"\n⚠️ Skipping {attr} - not found in data")
            continue
            
        print(f"\n{'='*60}")
        print(f"🔍 Processing attribute: {attr}")
        print(f"{'='*60}")
        
        # Get unique values for this category (excluding NaN)
        unique_values = roads[attr].dropna().unique()
        value_counts = roads[attr].value_counts()
        
        print(f"   📋 Categories found ({len(unique_values)}):")
        for val in sorted(unique_values):
            count = value_counts.get(val, 0)
            print(f"      - {val}: {count:,} roads")
        
        # Generate density heatmap for each category value
        for value in unique_values:
            print(f"\n   🎯 Processing: {attr}={value}")
            
            # Calculate density for roads with this specific value
            # Returns cropped grid (only cells with roads), column name, and full grid
            # grid.copy() keeps the shared grid pristine across iterations;
            # `col` and `grid_full` are currently unused but kept for API symmetry.
            grid_cropped, col, grid_full = calculate_density_for_value(roads, grid.copy(), attr, value)
            
            if grid_cropped is None:
                continue
            
            # Plot and save outputs using cropped extent
            plot_and_save_density_heatmap(
                grid_cropped, attr, value, location, save_dir, size_label, cell_size
            )

    print("\n" + "="*60)
    print("✅ All attribute-specific density heatmaps generated successfully!")
    print("="*60)
