#!/usr/bin/env python3
"""
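Extract polygons of above-percentile road density (90th percentile by default) by road category.

Writes vector outputs (GeoPackage, Shapefile and GeoJSON) outlining areas where
the total road length per grid cell exceeds the chosen percentile (default: 90th,
see --percentile), computed over the cells that contain at least one road of the
category, for each of these road categories: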

Categories:
 - layer_nonzero: Roads on bridges/tunnels/different layers (layer != 0)
 - footway: Pedestrian footway roads
 - links: On-ramps and off-ramps (motorway_link, trunk_link, primary_link, etc.)
 - rare_types: The N rarest highway types in the dataset (N = --rare-count, default 5)
 - general: All roads combined (>90th percentile of total density)

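Outputs (written to output/{location}/; {percentile} defaults to 90 and {size}
is the cell size in kilometres, rounded to a whole number):
 - {location}_{percentile}th_percentile_{category}_{size}km.gpkg - Vector polygons (GeoPackage)
 - {location}_{percentile}th_percentile_{category}_{size}km.shp - Vector polygons (Shapefile)
 - {location}_{percentile}th_percentile_{category}_{size}km.geojson - Vector polygons (GeoJSON)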
"""

import argparse
import os
import sys
from pathlib import Path
import geopandas as gpd
import pandas as pd
import numpy as np
from tqdm import tqdm

# Environment variable consulted when no explicit cell size (in meters) is given.
CELL_SIZE_ENV_VAR = "GRID_CELL_SIZE_METERS"


def resolve_cell_size(explicit_value):
    """Resolve the grid cell size in meters.

    An explicit value (normally from ``--cell-size``) always wins when it is
    supplied; otherwise the ``GRID_CELL_SIZE_METERS`` environment variable is
    read.

    Args:
        explicit_value: Cell size from the command line (int or numeric
            string), or ``None`` when the option was not given.

    Returns:
        The cell size as a strictly positive ``int``.

    Raises:
        ValueError: If neither source provides a value, the value is not an
            integer, or it is zero/negative.
    """
    # Test against None instead of truthiness: an explicit 0 must be rejected,
    # not silently replaced by whatever the environment variable holds.
    if explicit_value is not None:
        raw_value = explicit_value
    else:
        raw_value = os.getenv(CELL_SIZE_ENV_VAR)
    if raw_value is None:
        raise ValueError(
            f"Cell size must be passed as an argument or provided via {CELL_SIZE_ENV_VAR}."
        )
    try:
        cell_size = int(raw_value)
    except ValueError as exc:
        raise ValueError(
            f"Invalid cell size '{raw_value}'. Provide an integer or set {CELL_SIZE_ENV_VAR}."
        ) from exc
    if cell_size <= 0:
        raise ValueError(
            f"Invalid cell size '{raw_value}'. Cell size must be a positive number of meters."
        )
    return cell_size


def calculate_density_for_category(roads, grid, category_name, filter_func):
    """Sum the length of one category's roads inside every grid cell.

    Although called "density", the stored value is the total road length per
    cell; assuming equal-sized cells, this ranks cells the same way a
    length-per-area density would.

    Args:
        roads: GeoDataFrame of road geometries, in the same CRS as ``grid``.
        grid: GeoDataFrame of grid-cell polygons.
        category_name: Label used only in progress messages.
        filter_func: Callable that takes ``roads`` and returns the subset of
            rows belonging to this category.

    Returns:
        A copy of ``grid`` (re-indexed 0..n-1) with an added ``length_m``
        column (0 for cells without roads of this category; units are those
        of the CRS, presumably meters), or ``None`` when the category has no
        roads or the overlay fails.
    """
    print(f"📊 Computing density for category: {category_name}...")

    filtered_roads = filter_func(roads)
    if filtered_roads.empty:
        print(f"   ⚠️ No roads found for category {category_name}")
        return None

    print(f"   ✅ Found {len(filtered_roads):,} roads in category {category_name}")

    # Key cells by position in a private column. Renaming reset_index()'s
    # "index" column to "cell_id" would break when the grid already has a
    # "cell_id" or "index" column, or when its index is named.
    cell_key = "_density_cell_key"
    grid_local = grid.reset_index(drop=True)
    grid_local[cell_key] = np.arange(len(grid_local))

    try:
        intersections = gpd.overlay(
            filtered_roads[["geometry"]],
            grid_local[[cell_key, "geometry"]],
            how="intersection",
        )
    except Exception as exc:
        # Best effort: report and let the caller skip this category.
        print(f"   ⚠️ Failed to compute intersections: {exc}")
        return None

    # Length of every road piece clipped to a cell, then summed per cell.
    intersections["length_m"] = intersections.geometry.length
    length_per_cell = intersections.groupby(cell_key)["length_m"].sum()

    # Cells with no intersecting road get NaN from map(); report them as 0.
    grid_local["length_m"] = grid_local[cell_key].map(length_per_cell).fillna(0)
    return grid_local.drop(columns=[cell_key])


def extract_90th_percentile_polygons(density_grid, category_name, percentile=90):
    """Merge grid cells whose road length exceeds a percentile into polygons.

    The threshold is the ``percentile``-th percentile of ``length_m`` over the
    non-zero cells only. Cells strictly above it are dissolved together and
    exploded back into individual polygons.

    Args:
        density_grid: GeoDataFrame of grid cells with a ``length_m`` column,
            as produced by ``calculate_density_for_category``.
        category_name: Category label stored on every output polygon.
        percentile: Percentile (0-100) used as the threshold.

    Returns:
        GeoDataFrame with one row per polygon and the columns ``geometry``,
        ``category``, ``percentile``, ``threshold_m`` and ``area_km2``, or
        ``None`` when no cell qualifies or the dissolve fails.
    """
    print(f"🔍 Extracting >{percentile}th percentile for {category_name}...")

    # Zero cells are excluded so that large empty areas do not drag the
    # threshold down to (near) zero.
    non_zero = density_grid[density_grid["length_m"] > 0]
    if non_zero.empty:
        print(f"   ⚠️ No non-zero density cells found")
        return None

    threshold = np.percentile(non_zero["length_m"], percentile)
    print(f"   📈 {percentile}th percentile threshold: {threshold:,.2f} m")

    hotspots = density_grid[density_grid["length_m"] > threshold]
    if hotspots.empty:
        print(f"   ⚠️ No cells exceed the {percentile}th percentile threshold")
        return None

    print(f"   ✅ Found {len(hotspots):,} cells above threshold")

    print(f"   🔗 Dissolving into contiguous polygons...")
    try:
        # Dissolve the geometry alone: with dissolve()'s default
        # aggfunc="first", every other column (e.g. length_m) would carry the
        # value of one arbitrary cell onto the merged polygon.
        dissolved = hotspots[[hotspots.geometry.name]].dissolve()
        # Split the resulting (multi)polygon into separate polygon rows.
        dissolved = dissolved.explode(index_parts=False).reset_index(drop=True)
    except Exception as e:
        print(f"   ⚠️ Error during dissolve: {e}")
        return None

    if dissolved.empty:
        return None

    print(f"   ✅ Created {len(dissolved)} polygon(s)")

    dissolved["category"] = category_name
    dissolved["percentile"] = percentile
    dissolved["threshold_m"] = threshold
    # Assumes a projected CRS in meters (m² -> km²).
    dissolved["area_km2"] = dissolved.geometry.area / 1_000_000
    return dissolved


def get_rare_types(roads, top_n=5):
    """Return the ``top_n`` least frequent highway types in ``roads``.

    Counts come from ``roads['highway'].value_counts()``, which skips missing
    (None/NaN) values. Types tied at the cut-off are taken in whatever order
    ``value_counts`` returns them.

    Args:
        roads: DataFrame with a ``highway`` column.
        top_n: Number of rarest types to return; values <= 0 yield ``[]``.

    Returns:
        List of highway type values, ordered from more to less frequent.
    """
    highway_counts = roads['highway'].value_counts()
    # Series.tail(n) with a negative n returns every row *except* the first
    # |n|, which would turn a bad count into "almost every type".
    rarest = highway_counts.tail(max(top_n, 0))
    rare_types = rarest.index.tolist()
    print(f"🔬 Rarest {top_n} highway types: {rare_types}")
    print(f"   Counts: {rarest.to_dict()}")
    return rare_types


def _build_category_filters(roads, rare_count):
    """Map each category name to a function that selects its roads.

    Dict order is the processing order. The rare-type list is computed here,
    once, from the full ``roads`` frame.
    """

    def filter_layer_nonzero(roads):
        """Select roads whose ``layer`` is not 0.

        Without a ``layer`` column, roads tagged as bridge or tunnel (any
        value other than missing or 'no') are used as a proxy.
        """
        if 'layer' in roads.columns:
            # The layer tag may be stored either as text or as a number.
            return roads[(roads['layer'].notna()) & (roads['layer'] != '0') & (roads['layer'] != 0)]

        print(f"   ℹ️  No 'layer' column - using bridge/tunnel as proxy for layer_nonzero")
        has_bridge = roads['bridge'].notna() & (roads['bridge'] != 'no') if 'bridge' in roads.columns else False
        has_tunnel = roads['tunnel'].notna() & (roads['tunnel'] != 'no') if 'tunnel' in roads.columns else False

        # Both are the plain bool False only when neither proxy column exists.
        if isinstance(has_bridge, bool) and isinstance(has_tunnel, bool):
            print(f"   ❌ Cannot compute layer_nonzero: no layer, bridge, or tunnel columns")
            return roads.iloc[:0]  # Empty frame with the same columns

        return roads[has_bridge | has_tunnel]

    def filter_footway(roads):
        return roads[roads['highway'] == 'footway']

    def filter_links(roads):
        # Ramps such as motorway_link, trunk_link, primary_link, ...
        return roads[roads['highway'].str.contains('_link', na=False, regex=False)]

    rare_types = get_rare_types(roads, top_n=rare_count)

    def filter_rare(roads):
        return roads[roads['highway'].isin(rare_types)]

    return {
        'layer_nonzero': filter_layer_nonzero,
        'footway': filter_footway,
        'links': filter_links,
        'rare_types': filter_rare,
        # All roads, unfiltered; density is recomputed from scratch.
        'general': lambda roads: roads,
    }


def _write_outputs(gdf, output_stem):
    """Write ``gdf`` as GeoPackage, Shapefile and GeoJSON beside ``output_stem``."""
    # NOTE: Shapefile (dBASE) field names are limited to 10 characters, so
    # 'threshold_m' is truncated by the driver in the .shp output.
    for suffix, driver in (
        ('.gpkg', 'GPKG'),
        ('.shp', 'ESRI Shapefile'),
        ('.geojson', 'GeoJSON'),
    ):
        # Append the suffix rather than using Path.with_suffix(): a location
        # containing a dot (e.g. "st.helena") would otherwise have the rest of
        # the name replaced, making every category overwrite the same file.
        path = output_stem.parent / f"{output_stem.name}{suffix}"
        gdf.to_file(path, driver=driver)
        print(f"✅ Saved: {path}")


def main():
    """Command-line entry point.

    Loads ``output/<location>/<location>_roads_projected.gpkg`` and the
    matching grid file. For each road category it computes the road length
    per cell and writes the above-percentile hotspot polygons. Exits with
    status 1 when the cell size cannot be resolved or an input file is
    missing. Reports an argparse usage error for an out-of-range
    --percentile or a negative --rare-count.
    """
    parser = argparse.ArgumentParser(
        description="Extract >90th percentile density polygons by road type category"
    )
    parser.add_argument("location", help="Location name (e.g., 'belgium', 'california')")
    parser.add_argument(
        "--cell-size",
        type=int,
        help=f"Grid cell size in meters (default: env ${CELL_SIZE_ENV_VAR})",
    )
    parser.add_argument(
        "--percentile",
        type=int,
        default=90,
        help="Percentile threshold (default: 90)",
    )
    parser.add_argument(
        "--rare-count",
        type=int,
        default=5,
        help="Number of rarest types to analyze (default: 5)",
    )

    args = parser.parse_args()
    location = args.location.lower()
    percentile = args.percentile
    rare_count = args.rare_count

    # np.percentile rejects values outside [0, 100]; fail fast with a usage
    # message instead of crashing after the (potentially large) data loads.
    if not 0 <= percentile <= 100:
        parser.error(f"--percentile must be between 0 and 100 (got {percentile})")
    if rare_count < 0:
        parser.error(f"--rare-count must be non-negative (got {rare_count})")

    try:
        cell_size_m = resolve_cell_size(args.cell_size)
    except ValueError as e:
        print(f"❌ {e}")
        sys.exit(1)

    cell_size_km = cell_size_m / 1000
    # NOTE(review): ":.0f" rounds fractional sizes (500 m -> "0km",
    # 1500 m -> "2km"); presumably this must match the grid-creation
    # script's naming, so it is left as is.
    size_tag = f"{cell_size_km:.0f}km"

    location_dir = Path("output") / location
    roads_path = location_dir / f"{location}_roads_projected.gpkg"
    grid_path = location_dir / f"{location}_grid_{size_tag}.gpkg"

    if not roads_path.exists():
        print(f"❌ Roads file not found: {roads_path}")
        print(f"💡 Run: python -m scripts.extract_roads {location}")
        sys.exit(1)

    if not grid_path.exists():
        print(f"❌ Grid file not found: {grid_path}")
        print(f"💡 Run: GRID_CELL_SIZE_METERS={cell_size_m} python -m scripts.create_grid {location}")
        sys.exit(1)

    print(f"📂 Loading roads from: {roads_path}")
    roads = gpd.read_file(roads_path)
    print(f"   ✅ Loaded {len(roads):,} road segments")

    print(f"📂 Loading grid from: {grid_path}")
    grid = gpd.read_file(grid_path)
    print(f"   ✅ Loaded {len(grid):,} grid cells")

    # The overlay and length computations need both layers in one CRS.
    if roads.crs != grid.crs:
        print(f"⚠️ Reprojecting roads to match grid CRS")
        roads = roads.to_crs(grid.crs)

    categories = _build_category_filters(roads, rare_count)

    all_results = []
    for category_name, filter_func in tqdm(categories.items(), desc="Processing categories"):
        print(f"\n{'='*60}")
        print(f"Processing: {category_name}")
        print(f"{'='*60}")

        density_grid = calculate_density_for_category(roads, grid, category_name, filter_func)
        if density_grid is None:
            print(f"⏭️ Skipping {category_name} (no data)")
            continue

        hotspot_polygons = extract_90th_percentile_polygons(density_grid, category_name, percentile)
        if hotspot_polygons is None:
            print(f"⏭️ Skipping {category_name} (no hotspots)")
            continue

        output_stem = location_dir / f"{location}_{percentile}th_percentile_{category_name}_{size_tag}"
        _write_outputs(hotspot_polygons, output_stem)

        all_results.append({
            'category': category_name,
            'polygon_count': len(hotspot_polygons),
            'total_area_km2': hotspot_polygons['area_km2'].sum(),
            'threshold_m': hotspot_polygons['threshold_m'].iloc[0],
        })

    if all_results:
        print(f"\n{'='*60}")
        print(f"✅ SUMMARY")
        print(f"{'='*60}")
        summary_df = pd.DataFrame(all_results)
        print(summary_df.to_string(index=False))
        print(f"\n📊 Total categories processed: {len(all_results)}")
    else:
        print("\n⚠️ No hotspots extracted for any category")


if __name__ == "__main__":
    main()
