#!/usr/bin/env python3
"""
Extract smooth hotspot contours from heatmap GeoTIFFs using continuous raster analysis.

Instead of merging rectilinear grid cells (which creates blocky boundaries),
this script treats the raster as continuous and uses skimage.measure.find_contours
to extract smooth contour polygons at a threshold that captures the target area.

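Outputs (written next to each input raster):
 - PNG overlay showing contours on the heatmap
 - GeoJSON with smooth polygon boundaries (not blocky grid cells),
   reprojected to WGS84 when the raster CRS is known
 - ESRI Shapefile with the same polygons in the raster's own CRS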
"""

import argparse
import math
import os
from pathlib import Path

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio.features import shapes
from shapely.geometry import shape, MultiPolygon, Polygon
from shapely.ops import unary_union
from skimage import measure

# Names of the environment variables that resolve_target_criteria() falls
# back to when neither --target-area nor --target-percentile is given.
# Setting both at once is rejected there.
TARGET_AREA_ENV_VAR = "HOTSPOT_TARGET_AREA_KM2"
TARGET_PERCENTILE_ENV_VAR = "HOTSPOT_TARGET_PERCENTILE"


def _parse_target_area(raw_value, invalid_message):
    """Convert *raw_value* to a finite, strictly positive area.

    Args:
        raw_value: Text (or number) supplied by the user.
        invalid_message: Message identifying the offending input; used as-is
            for unparsable values and as a prefix for out-of-range values.

    Raises:
        ValueError: If the value is not a number, is NaN/infinite, or is <= 0.
    """
    try:
        area = float(raw_value)
    except (TypeError, ValueError) as exc:
        raise ValueError(invalid_message) from exc
    # Validate outside the ``try`` so the specific reason is not swallowed by
    # the handler above. ``nan <= 0`` is False, hence the explicit isfinite().
    if not math.isfinite(area) or area <= 0:
        raise ValueError(
            f"{invalid_message} Target area must be a finite number greater than zero."
        )
    return area


def _parse_target_percentile(raw_value, invalid_message):
    """Convert *raw_value* to a percentile strictly between 0 and 100.

    Args:
        raw_value: Text (or number) supplied by the user.
        invalid_message: Message identifying the offending input; used as-is
            for unparsable values and as a prefix for out-of-range values.

    Raises:
        ValueError: If the value is not a number or lies outside (0, 100).
    """
    try:
        percentile = float(raw_value)
    except (TypeError, ValueError) as exc:
        raise ValueError(invalid_message) from exc
    # NaN fails this chained comparison as well, so it is rejected here too.
    if not 0 < percentile < 100:
        raise ValueError(f"{invalid_message} Percentile must be between 0 and 100.")
    return percentile


def resolve_target_criteria(area_arg, percentile_arg):
    """Return the requested hotspot criteria (area or percentile).

    Explicit arguments take precedence; the environment variables named by
    ``TARGET_AREA_ENV_VAR`` and ``TARGET_PERCENTILE_ENV_VAR`` are only read
    when both arguments are ``None``.

    Args:
        area_arg: Target area given on the command line, or ``None``.
        percentile_arg: Target percentile given on the command line, or ``None``.

    Returns:
        ``{"type": "area", "value": <float>}`` or
        ``{"type": "percentile", "value": <float>}``.

    Raises:
        ValueError: If both an area and a percentile are supplied from the
            same source, if a supplied value is not a usable number, or if
            no criteria are supplied at all.
    """
    if area_arg is not None and percentile_arg is not None:
        raise ValueError("Cannot specify both --target-area and --target-percentile.")

    if area_arg is not None:
        area = _parse_target_area(area_arg, f"Invalid target area '{area_arg}'.")
        return {"type": "area", "value": area}

    if percentile_arg is not None:
        percentile = _parse_target_percentile(
            percentile_arg, f"Invalid percentile '{percentile_arg}'."
        )
        return {"type": "percentile", "value": percentile}

    # No explicit arguments: fall back to the environment.
    area_env = os.getenv(TARGET_AREA_ENV_VAR)
    percentile_env = os.getenv(TARGET_PERCENTILE_ENV_VAR)

    if area_env is not None and percentile_env is not None:
        raise ValueError(f"Cannot set both {TARGET_AREA_ENV_VAR} and {TARGET_PERCENTILE_ENV_VAR}.")

    if area_env is not None:
        area = _parse_target_area(
            area_env, f"Invalid area in {TARGET_AREA_ENV_VAR}: '{area_env}'."
        )
        return {"type": "area", "value": area}

    if percentile_env is not None:
        percentile = _parse_target_percentile(
            percentile_env,
            f"Invalid percentile in {TARGET_PERCENTILE_ENV_VAR}: '{percentile_env}'.",
        )
        return {"type": "percentile", "value": percentile}

    raise ValueError(
        f"Target criteria must be provided via --target-area, --target-percentile, "
        f"{TARGET_AREA_ENV_VAR}, or {TARGET_PERCENTILE_ENV_VAR}."
    )


def find_heatmap_rasters(location_dir, extra_patterns=None):
    """Collect the heatmap GeoTIFFs found directly inside *location_dir*.

    The built-in pattern ``*road_density_*km.tif`` is always searched;
    any *extra_patterns* are globbed in addition to it.

    Args:
        location_dir: Directory (a ``Path``) to search with ``glob``.
        extra_patterns: Optional iterable of additional glob patterns.

    Returns:
        A sorted list of matching paths, without duplicates, containing
        only entries for which ``is_file()`` is true.
    """
    glob_patterns = ["*road_density_*km.tif", *(extra_patterns or [])]
    # A set collapses paths matched by more than one pattern.
    matches = {
        candidate
        for glob_pattern in glob_patterns
        for candidate in location_dir.glob(glob_pattern)
    }
    return [candidate for candidate in sorted(matches) if candidate.is_file()]


def parse_size_label(raster_path: Path):
    """Extract the grid size label (e.g., '2km') and the cell size in meters.

    The label is the text after the last underscore of the file stem, so
    ``alabama_road_density_2km.tif`` yields ``("2km", 2000.0)``.

    Args:
        raster_path: Path whose stem ends in ``_<number>km``.

    Returns:
        A ``(size_label, cell_size_m)`` tuple.

    Raises:
        ValueError: If the label does not end in ``km`` or its numeric part
            is not a finite number greater than zero.
    """
    # str.split always yields at least one element, so [-1] cannot fail.
    size_label = raster_path.stem.split("_")[-1]
    if not size_label.endswith("km"):
        raise ValueError(f"Grid label '{size_label}' does not end with 'km'.")
    try:
        cell_km = float(size_label[:-2])
    except ValueError as exc:
        raise ValueError(f"Invalid grid label '{size_label}'.") from exc
    # float() happily parses "nan", "inf" and negatives; none is a usable size.
    if not math.isfinite(cell_km) or cell_km <= 0:
        raise ValueError(f"Invalid grid label '{size_label}'.")
    return size_label, cell_km * 1000.0


def calculate_threshold_for_target_area(data, transform, target_area_km2):
    """Calculate the density threshold whose exceedance covers ~the target area.

    The target area is converted to a pixel count N (rounded up) and the
    N-th largest finite value is returned, so that ``data >= threshold``
    selects N pixels when there are no ties.

    Args:
        data: Raster values as a NumPy array; non-finite cells are ignored.
        transform: Affine transform, indexable as ``(a, b, c, d, e, f)``.
            NOTE(review): the km² conversion assumes the transform is
            expressed in metres — confirm the raster CRS is projected.
        target_area_km2: Desired hotspot area in square kilometres.

    Returns:
        The threshold value taken from *data*.

    Raises:
        ValueError: If the transform has zero pixel area or *data* contains
            no finite values.
    """
    # |a*e - b*d| is the area of one pixel under the affine map; it reduces
    # to |a| * |e| for the usual north-up raster where b == d == 0.
    pixel_area_m2 = abs(transform[0] * transform[4] - transform[1] * transform[3])
    if pixel_area_m2 == 0:
        raise ValueError("Raster transform has zero pixel area")
    pixel_area_km2 = pixel_area_m2 / 1_000_000

    target_pixels = target_area_km2 / pixel_area_km2

    valid_data = data[np.isfinite(data)]
    if valid_data.size == 0:
        raise ValueError("No valid data in raster")

    sorted_values = np.sort(valid_data)[::-1]

    # The N-th largest value sits at index N - 1; clamp to the available range
    # so tiny targets select one pixel and huge targets select everything.
    wanted_pixels = math.ceil(target_pixels)
    threshold_idx = min(max(wanted_pixels - 1, 0), sorted_values.size - 1)
    threshold = sorted_values[threshold_idx]

    print(f"   📊 Threshold: {threshold:.1f} (captures ~{target_area_km2:.1f} km²)")
    return threshold


def calculate_threshold_for_percentile(data, percentile):
    """Return the value at *percentile* among the positive, finite cells.

    Zeros, negatives, NaN and infinities are excluded before the percentile
    is taken. Because the result is the ``percentile``-th percentile itself,
    cells at or above it make up roughly the top ``100 - percentile`` percent
    (e.g. ``percentile=90`` selects about the top 10%).

    Args:
        data: Raster values as a NumPy array.
        percentile: Percentile passed straight to ``np.percentile``.

    Returns:
        The threshold value.

    Raises:
        ValueError: If *data* has no finite value greater than zero.
    """
    positive_values = data[np.isfinite(data) & (data > 0)]
    total_nonzero = len(positive_values)
    if total_nonzero == 0:
        raise ValueError("No valid non-zero data in raster")

    threshold = np.percentile(positive_values, percentile)

    # Report how many of the positive cells actually reach the threshold.
    pixel_count = np.sum(positive_values >= threshold)
    actual_percent = (pixel_count / total_nonzero) * 100

    print(f"   📊 Threshold: {threshold:.1f} (captures {actual_percent:.1f}% = {pixel_count:,}/{total_nonzero:,} non-zero pixels)")
    return threshold


def calculate_threshold(data, transform, criteria):
    """Dispatch to the area- or percentile-based threshold calculation.

    Args:
        data: Raster values as a NumPy array.
        transform: Affine transform of the raster (used for area criteria only).
        criteria: Mapping with a ``"type"`` of ``"area"`` or ``"percentile"``
            and the corresponding ``"value"``.

    Returns:
        The threshold computed by the selected strategy.

    Raises:
        ValueError: If ``criteria["type"]`` is neither ``"area"`` nor
            ``"percentile"``.
    """
    kind = criteria["type"]
    if kind == "area":
        return calculate_threshold_for_target_area(data, transform, criteria["value"])
    if kind == "percentile":
        return calculate_threshold_for_percentile(data, criteria["value"])
    raise ValueError(f"Unknown criteria type: {criteria['type']}")


def _empty_hotspots(crs):
    """Return an empty GeoDataFrame with the hotspot columns and the given CRS."""
    return gpd.GeoDataFrame(columns=["cluster_id", "area_km2", "perimeter_km"], geometry=[], crs=crs)


def extract_contours_from_raster(data, transform, threshold, crs):
    """Extract smooth contours from raster using skimage.measure.find_contours.

    Cells that are finite, positive and ``>= threshold`` form a binary mask;
    its 0.5-level contours are converted to polygons in the raster CRS and
    dissolved with ``unary_union``.

    NOTE: every contour (including one that outlines a gap inside a hotspot)
    becomes a filled polygon before the union, so interior holes end up
    filled in the result.

    Args:
        data: Raster values as a 2-D NumPy array.
        transform: Affine transform, indexable as ``(a, b, c, d, e, f)``.
        threshold: Minimum value for a cell to count as hotspot.
        crs: CRS assigned to the returned GeoDataFrame.

    Returns:
        A GeoDataFrame with ``cluster_id``, ``area_km2``, ``perimeter_km``
        and ``geometry`` columns; empty when no polygon could be built.
        NOTE(review): the km conversions assume CRS units are metres.
    """
    # Exclude zeros/nulls consistently with the percentile threshold logic.
    hotspot_mask = (np.isfinite(data) & (data > 0) & (data >= threshold)).astype(np.uint8)

    # Surround the mask with a ring of zeros so that hotspots touching the
    # raster border still produce closed contours instead of open ones that
    # would have to be closed with an arbitrary straight chord.
    pad = 1
    padded_mask = np.pad(hotspot_mask, pad, mode="constant", constant_values=0)

    contours = measure.find_contours(padded_mask, level=0.5)
    if not contours:
        return _empty_hotspots(crs)

    polygons = []
    for contour in contours:
        if len(contour) < 3:
            continue
        # find_contours reports (row, col) positions where integer values are
        # cell *centres*, while the affine transform maps cell *corners*.
        # Undo the padding and add the half-cell offset before transforming.
        rows = contour[:, 0] - pad + 0.5
        cols = contour[:, 1] - pad + 0.5
        xs = transform[2] + cols * transform[0] + rows * transform[1]
        ys = transform[5] + cols * transform[3] + rows * transform[4]
        try:
            poly = Polygon(list(zip(xs.tolist(), ys.tolist())))
        except Exception:
            # Best effort: a degenerate ring is skipped rather than aborting
            # the whole raster.
            continue
        if poly.is_valid and not poly.is_empty and poly.area > 0:
            polygons.append(poly)

    if not polygons:
        return _empty_hotspots(crs)

    # Dissolve to ensure contiguous regions merge (helps nesting across percentiles)
    dissolved = unary_union(polygons)
    if isinstance(dissolved, Polygon):
        dissolved_parts = [dissolved]
    elif isinstance(dissolved, MultiPolygon):
        dissolved_parts = list(dissolved.geoms)
    else:
        # Unexpected geometry type from the union: keep the undissolved polygons.
        dissolved_parts = polygons

    clusters = [
        {
            "cluster_id": idx,
            "area_km2": geom.area / 1_000_000,
            "perimeter_km": geom.length / 1_000,
            "geometry": geom,
        }
        for idx, geom in enumerate(dissolved_parts, start=1)
    ]

    return gpd.GeoDataFrame(clusters, geometry="geometry", crs=crs)


def load_raster(path: Path):
    """Read band 1 of a raster as float32 with missing cells set to NaN.

    When the dataset declares a nodata value, cells equal to it become NaN;
    otherwise every non-finite cell becomes NaN.

    Args:
        path: Location of the raster file to open with rasterio.

    Returns:
        A ``(data, bounds, crs, transform)`` tuple taken from the dataset.
    """
    with rasterio.open(path) as src:
        band = src.read(1).astype("float32")
        nodata_value = src.nodata
        if nodata_value is None:
            missing = ~np.isfinite(band)
        else:
            missing = band == nodata_value
        band = np.where(missing, np.nan, band)
        # Read the metadata while the dataset is still open.
        return band, src.bounds, src.crs, src.transform


def plot_hotspots(raster_path, data, bounds, raster_crs, hotspots_gdf, criteria):
    """Plot smooth contour hotspots overlay on the heatmap.

    Args:
        raster_path: Path of the source raster; the PNG is written next to it
            as ``<stem>_hotspots.png``.
        data: Raster values to draw with the ``hot`` colormap.
        bounds: Raster bounds exposing ``left``, ``right``, ``bottom``, ``top``.
        raster_crs: CRS of the raster; hotspots are reprojected to it if needed.
        hotspots_gdf: GeoDataFrame of hotspot polygons (may be empty).
        criteria: Mapping with ``"type"`` and ``"value"`` used for the title.

    Returns:
        Path of the PNG file that was written.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # Plot raster
        extent = (bounds.left, bounds.right, bounds.bottom, bounds.top)
        ax.imshow(data, cmap="hot", extent=extent, origin="upper", aspect="auto")

        # Plot smooth contours
        if not hotspots_gdf.empty:
            if hotspots_gdf.crs != raster_crs:
                hotspots_plot = hotspots_gdf.to_crs(raster_crs)
            else:
                hotspots_plot = hotspots_gdf
            # Plot filled polygons with transparency
            hotspots_plot.plot(ax=ax, facecolor="cyan", edgecolor="white",
                               alpha=0.3, linewidth=2)

        # ":g" keeps fractional criteria readable (99.5 stays "99.5" instead
        # of being rounded to "100") and matches the area branch.
        if criteria["type"] == "area":
            title_suffix = f"Top {criteria['value']:g} km²"
        else:  # percentile
            title_suffix = f"≥{criteria['value']:g}th percentile"

        ax.set_title(
            f"{raster_path.stem} — {title_suffix} (smooth contours)",
            fontsize=12,
            pad=10,
        )
        ax.set_xlim(bounds.left, bounds.right)
        ax.set_ylim(bounds.bottom, bounds.top)
        ax.axis("off")
        fig.tight_layout()
        out_path = raster_path.with_name(f"{raster_path.stem}_hotspots.png")
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails, so
        # figures do not accumulate across rasters.
        plt.close(fig)
    return out_path


def _format_criteria_number(value):
    """Render *value* without trailing zeros (500.0 -> '500', 0.5 -> '0.5')."""
    # Fixed-point with six decimals avoids both scientific notation for large
    # values and float noise such as 100 - 99.9 == 0.0999999999999943.
    return f"{value:.6f}".rstrip("0").rstrip(".")


def save_outputs(hotspots_gdf, raster_path, criteria):
    """Save hotspots as both GeoJSON and Shapefile formats.

    Files are written next to *raster_path* as
    ``<stem>_hotspots_<N>km2.*`` for area criteria or
    ``<stem>_hotspots_top<100 - P>pct.*`` for percentile criteria.
    Fractional criteria keep their decimals so that, for example, 0.5 km²
    and 0.9 km² do not both map to the same file name.

    Args:
        hotspots_gdf: GeoDataFrame of hotspot polygons.
        raster_path: Path of the source raster.
        criteria: Mapping with ``"type"`` and ``"value"``.

    Returns:
        A ``(geojson_path, shp_path)`` tuple.
    """
    if criteria["type"] == "area":
        suffix = f"hotspots_{_format_criteria_number(criteria['value'])}km2"
    else:  # percentile
        suffix = f"hotspots_top{_format_criteria_number(100 - criteria['value'])}pct"

    base_name = f"{raster_path.stem}_{suffix}"

    # Save GeoJSON (WGS84); left untouched when the CRS is unknown or already 4326.
    geojson_path = raster_path.with_name(f"{base_name}.geojson")
    needs_reprojection = hotspots_gdf.crs and hotspots_gdf.crs.to_epsg() != 4326
    hotspots_wgs84 = hotspots_gdf.to_crs(epsg=4326) if needs_reprojection else hotspots_gdf
    hotspots_wgs84.to_file(geojson_path, driver="GeoJSON")

    # Save Shapefile (keep original projection for accuracy)
    shp_path = raster_path.with_name(f"{base_name}.shp")
    hotspots_gdf.to_file(shp_path, driver="ESRI Shapefile")

    return geojson_path, shp_path


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Besides the positional ``location`` and ``--extra-pattern``, each target
    option is registered twice: once under its regular flag and once under a
    hidden flag spelled like the matching environment variable, stored in
    ``<dest>_alias``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    parser = argparse.ArgumentParser(
        description="Extract smooth hotspot contours from density GeoTIFFs using continuous raster analysis."
    )
    parser.add_argument("location", help="Location name (e.g., alabama, egypt, thailand).")

    # (flag, dest, help text, hidden env-style flag such as --HOTSPOT_TARGET_AREA_KM2=500)
    target_options = (
        (
            "--target-area",
            "target_area",
            "Target hotspot area in square kilometers. "
            f"Defaults to env var {TARGET_AREA_ENV_VAR}.",
            "--HOTSPOT_TARGET_AREA_KM2",
        ),
        (
            "--target-percentile",
            "target_percentile",
            "Target percentile threshold (e.g., 90 for top 10%%). "
            f"Defaults to env var {TARGET_PERCENTILE_ENV_VAR}.",
            "--HOTSPOT_TARGET_PERCENTILE",
        ),
    )
    for flag, dest, help_text, env_style_flag in target_options:
        parser.add_argument(flag, dest=dest, help=help_text)
        parser.add_argument(env_style_flag, dest=f"{dest}_alias", help=argparse.SUPPRESS)

    parser.add_argument(
        "--extra-pattern",
        action="append",
        dest="extra_patterns",
        help="Additional glob pattern(s) for locating GeoTIFFs.",
    )
    return parser.parse_args()


def main():
    """Run hotspot extraction for every heatmap raster of one location.

    Resolves the target criteria, locates the location directory
    (``output/germany/<location>`` is preferred over ``output/<location>``),
    then processes each raster in turn. A failure in any stage of a single
    raster is reported and that raster is skipped; the remaining rasters are
    still processed.

    Raises:
        SystemExit: If the criteria are invalid, the location directory does
            not exist, or it contains no heatmap GeoTIFFs.
    """
    args = parse_args()
    location = args.location.lower()
    # Allow env-style alias flags if provided
    area_arg = args.target_area if args.target_area is not None else args.target_area_alias
    pct_arg = args.target_percentile if args.target_percentile is not None else args.target_percentile_alias

    try:
        criteria = resolve_target_criteria(area_arg, pct_arg)
    except ValueError as err:
        raise SystemExit(f"❌ {err}") from err

    # Check if location is in germany subdirectory first
    germany_dir = Path("output") / "germany" / location
    location_dir = germany_dir if germany_dir.exists() else Path("output") / location

    if not location_dir.exists():
        raise SystemExit(f"❌ Location directory not found: {location_dir}")

    heatmap_rasters = find_heatmap_rasters(location_dir, args.extra_patterns)
    if not heatmap_rasters:
        raise SystemExit(f"⚠️ No heatmap GeoTIFFs found in {location_dir}")

    print(f"🎯 Processing {len(heatmap_rasters)} heatmaps for {location.capitalize()}...")
    print("🔬 Using continuous raster analysis with smooth contours")

    for raster_path in heatmap_rasters:
        print(f"\n🗺️ {raster_path.name}")

        # Load raster data
        try:
            data, bounds, raster_crs, transform = load_raster(raster_path)
        except Exception as err:
            print(f"   ⚠️ Could not load raster: {err}")
            continue

        # Calculate threshold
        try:
            threshold = calculate_threshold(data, transform, criteria)
        except ValueError as err:
            print(f"   ⚠️ {err}")
            continue

        # Extract smooth contours
        try:
            hotspots = extract_contours_from_raster(data, transform, threshold, raster_crs)
        except Exception as err:
            print(f"   ⚠️ Could not extract contours: {err}")
            continue

        if hotspots.empty:
            print("   ⚠️ No hotspots identified.")
            continue

        print(f"   ✅ Found {len(hotspots)} hotspot region(s)")
        total_area = hotspots['area_km2'].sum()
        total_perimeter = hotspots['perimeter_km'].sum() if 'perimeter_km' in hotspots.columns else 0
        print(f"   📍 Total area: {total_area:.1f} km²")
        if total_perimeter > 0:
            print(f"   📏 Total perimeter: {total_perimeter:.1f} km")

        # Save outputs. Like the earlier stages, a failure here only skips
        # this raster instead of aborting the remaining ones.
        try:
            png_path = plot_hotspots(raster_path, data, bounds, raster_crs, hotspots, criteria)
            geojson_path, shp_path = save_outputs(hotspots, raster_path, criteria)
        except Exception as err:
            print(f"   ⚠️ Could not save outputs: {err}")
            continue
        print(f"   💾 Saved overlay: {png_path.name}")
        print(f"   💾 Saved GeoJSON: {geojson_path.name}")
        print(f"   💾 Saved Shapefile: {shp_path.name}")

    print("\n✅ Smooth contour extraction complete.")


# Run the command-line workflow only when executed as a script, so that
# importing this module has no side effects.
if __name__ == "__main__":
    main()
