#!/usr/bin/env python3
"""
Extracts only road features (highways, streets, etc.) from an
OpenStreetMap .pbf file downloaded from Geofabrik.

✅ Features:
 - Automatic folder creation per location
 - Downloads the .pbf from Geofabrik (reuses an existing download if present)
 - Falls back to grid-based regional extraction for large files (> 500 MB)
   or when a MemoryError occurs
 - Saves standardized GeoPackage to: output/{location}/{location}_roads.gpkg
"""

import argparse
import sys
from pathlib import Path
import requests
from tqdm import tqdm
from pyrosm import OSM
import geopandas as gpd
import pandas as pd
import os

# Put the parent of this script's directory at the front of sys.path so the
# project-local `scripts.geofabrik_registry` import below resolves, e.g. when
# this file is executed directly rather than imported as part of a package.
sys.path.insert(0, str(Path(__file__).parent.parent))

# Maps a location name to its Geofabrik download URL (raises ValueError for
# unknown locations, as handled in extract_roads).
from scripts.geofabrik_registry import get_geofabrik_url

# ------------------------------
# MEMORY-EFFICIENT EXTRACTION
# ------------------------------
def _grid_bboxes(bounds, grid_size):
    """Split a ``(min_lon, min_lat, max_lon, max_lat)`` extent into grid cells.

    Returns a list of ``(i, j, bbox)`` tuples in i-major order, where ``i``
    indexes longitude columns, ``j`` indexes latitude rows, and ``bbox`` is
    ``(min_lon, min_lat, max_lon, max_lat)`` for that cell. The outermost
    edges are pinned to the exact input maxima so floating-point rounding
    cannot leave an uncovered sliver along the east/north border.
    """
    min_lon, min_lat, max_lon, max_lat = bounds
    lon_step = (max_lon - min_lon) / grid_size
    lat_step = (max_lat - min_lat) / grid_size
    lon_edges = [min_lon + k * lon_step for k in range(grid_size)] + [max_lon]
    lat_edges = [min_lat + k * lat_step for k in range(grid_size)] + [max_lat]
    return [
        (i, j, (lon_edges[i], lat_edges[j], lon_edges[i + 1], lat_edges[j + 1]))
        for i in range(grid_size)
        for j in range(grid_size)
    ]


def _dataset_bounds(filename):
    """Return the total bounds of the boundaries found in ``filename``.

    Uses ``OSM.get_boundaries()`` and returns ``(minx, miny, maxx, maxy)`` as
    floats (presumably lon/lat, EPSG:4326 — TODO confirm against pyrosm), or
    ``None`` when no boundaries were found or the bounds are not finite.
    """
    import gc
    import math

    osm = OSM(str(filename))
    boundaries_gdf = osm.get_boundaries()
    if boundaries_gdf is None or boundaries_gdf.empty:
        return None
    bounds = tuple(float(v) for v in boundaries_gdf.total_bounds)
    # Release the full-file objects before the per-region passes start;
    # keeping them alive would defeat the point of regional processing.
    del boundaries_gdf, osm
    gc.collect()
    if not all(math.isfinite(v) for v in bounds):
        return None
    return bounds


def extract_roads_by_regions(filename, location_dir, location, grid_size=8):
    """Extract roads by processing geographic regions of a PBF file separately.

    The dataset extent is split into a ``grid_size`` x ``grid_size`` grid and
    each cell is read with its own bbox-restricted pyrosm ``OSM`` object, so
    only one cell's data is materialized at a time.

    Args:
        filename: Path to the ``.osm.pbf`` file.
        location_dir: Output directory for the location (currently unused;
            kept for interface compatibility).
        location: Location name (currently unused; kept for interface
            compatibility).
        grid_size: Number of cells along each axis (default 8 -> 64 regions).

    Returns:
        A GeoDataFrame combining the roads of every region that succeeded, or
        ``None`` if the extent could not be determined or no region produced
        any roads.

    Raises:
        ValueError: if ``grid_size`` is smaller than 1.
    """
    import gc

    if grid_size < 1:
        raise ValueError(f"grid_size must be >= 1, got {grid_size}")

    print("🏗️ Processing large dataset by geographic regions...")

    bounds = _dataset_bounds(filename)
    if bounds is None:
        print("❌ Could not determine dataset bounds (no boundaries found)")
        return None
    print(f"📍 Dataset bounds: {bounds}")

    cells = _grid_bboxes(bounds, grid_size)
    total = len(cells)
    print(f"🧮 Using {grid_size}x{grid_size} grid ({total} regions)")

    all_roads = []
    failed = 0

    for idx, (i, j, region_bbox) in enumerate(cells, start=1):
        label = f"{i+1}-{j+1}"
        print(f"🔍 Processing region {label} ({idx}/{total}): {region_bbox}")
        try:
            # A bbox-restricted OSM object only loads this cell
            # (pyrosm expects the bounding box as a list).
            region_osm = OSM(str(filename), bounding_box=list(region_bbox))
            region_roads = region_osm.get_data_by_custom_criteria(
                custom_filter={"highway": True},  # All highway types
                filter_type="keep",
                keep_nodes=False,
                keep_relations=False,
                extra_attributes=[
                    "highway", "name", "lanes", "width", "surface",
                    "material", "bridge", "tunnel", "layer"
                ]
            )
            del region_osm

            if region_roads is not None and not region_roads.empty:
                print(f"  ✅ Found {len(region_roads):,} roads in region {label}")
                all_roads.append(region_roads)
            else:
                print(f"  ⚪ No roads in region {label}")
        except Exception as e:
            # Best-effort: skip a failing region but continue with the rest.
            failed += 1
            print(f"  ❌ Error in region {label}: {e}")
        finally:
            # Reclaim per-region memory even when the region failed.
            gc.collect()

    if failed:
        print(f"⚠️ {failed}/{total} regions failed and were skipped")

    if not all_roads:
        return None

    # NOTE(review): a way crossing a cell border may be returned by more than
    # one region; consider de-duplicating on the OSM "id" column once pyrosm's
    # bounding_box semantics (clip vs. keep whole way) are confirmed.
    print("🧩 Combining all regions...")
    combined_roads = gpd.GeoDataFrame(
        pd.concat(all_roads, ignore_index=True), crs=all_roads[0].crs
    )
    print(f"✅ Total extracted: {len(combined_roads):,} roads from all regions")
    return combined_roads

# ------------------------------
# MAIN LOGIC
# ------------------------------
def _download_pbf(url, filename, location):
    """Stream ``url`` into ``filename`` unless that file already exists.

    Data is written to a sibling ``<name>.part`` file and only renamed onto
    ``filename`` after the transfer finished and passed the size sanity
    check. An interrupted or truncated download is therefore never mistaken
    for a valid cached file on the next run.

    Raises:
        requests.RequestException: on network / HTTP errors.
        ValueError: if the downloaded file is smaller than 100 KiB.
    """
    if filename.exists():
        print(f"✅ Found existing file: {filename}")
        return

    print(f"⬇️ Downloading OSM data for {location.capitalize()}...\n")
    partial = filename.with_name(filename.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=120) as r:
            r.raise_for_status()
            # A missing content-length becomes None so tqdm shows an
            # open-ended bar instead of a bogus total of 0.
            total_size = int(r.headers.get("content-length", 0)) or None
            with open(partial, "wb") as f, tqdm(
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                desc=f"{location}.osm.pbf",
            ) as bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    bar.update(len(chunk))

        if partial.stat().st_size < 1024 * 100:
            raise ValueError("Downloaded file seems too small — possible download error.")
        partial.replace(filename)
    except BaseException:
        # Never leave a half-written file behind (also on Ctrl-C).
        partial.unlink(missing_ok=True)
        raise
    print(f"\n✅ Download complete: {filename}")


def _read_all_roads(filename, location):
    """Load every highway way from ``filename`` in a single pyrosm pass.

    Raises:
        ValueError: if pyrosm returns no road data.
        MemoryError: propagated so the caller can fall back to regions.
    """
    osm = OSM(str(filename))
    print("🔍 Reading OSM metadata...")

    roads = osm.get_data_by_custom_criteria(
        custom_filter={"highway": True},
        filter_type="keep",
        keep_nodes=False,
        keep_relations=False,
        extra_attributes=[
            "highway", "name", "lanes", "width", "surface",
            "material", "bridge", "tunnel", "layer"
        ]
    )
    if roads is None or roads.empty:
        raise ValueError("No road data extracted")

    print(f"✅ Extracted {len(roads):,} roads from {location.capitalize()} data.")
    print(f"📊 Memory usage: {roads.memory_usage(deep=True).sum() / 1024**2:.1f} MB")
    return roads


def _extract_roads_chunked(filename, location_dir, location):
    """Run the regional extraction and raise if it produced nothing."""
    roads = extract_roads_by_regions(filename, location_dir, location)
    if roads is None or roads.empty:
        raise ValueError("Chunked extraction returned no data")
    return roads


def extract_roads(location: str) -> None:
    """Download the Geofabrik extract for ``location`` and save its roads.

    Pipeline:
    1. Resolve the Geofabrik URL for ``location`` (case-insensitive).
    2. Download ``output/<location>/<location>.osm.pbf`` if it is not cached.
    3. Extract ``highway=*`` ways. Files above 500 MB go through the
       regional (grid) extraction, and smaller files fall back to it on
       MemoryError.
    4. Reproject to EPSG:4326 and write ``<location>_roads.gpkg`` (layer
       "roads").

    Exits the process with status 1 on any unrecoverable failure.
    """
    import gc

    location = location.lower()

    # Resolve the URL first so an unknown location fails before any
    # output folder is created.
    try:
        url = get_geofabrik_url(location)
        print(f"🔗 Geofabrik URL: {url}")
    except ValueError as e:
        print(f"❌ {e}")
        sys.exit(1)

    location_dir = Path.cwd() / "output" / location
    location_dir.mkdir(parents=True, exist_ok=True)
    filename = location_dir / f"{location}.osm.pbf"
    output_path = location_dir / f"{location}_roads.gpkg"

    # ------------------------------
    # STEP 1: DOWNLOAD FILE
    # ------------------------------
    try:
        _download_pbf(url, filename, location)
    except Exception as e:
        print(f"❌ Download failed: {e}")
        sys.exit(1)

    # ------------------------------
    # STEP 2: EXTRACT ROADS (Memory-efficient)
    # ------------------------------
    print("\n🚧 Extracting roads...")
    file_size_mb = filename.stat().st_size / (1024 ** 2)
    print(f"📊 PBF file size: {file_size_mb:.1f} MB")

    # Use chunked extraction for files > 500 MB to avoid OOM
    if file_size_mb > 500:
        print(f"💡 Large file detected ({file_size_mb:.0f} MB) - using chunked extraction with 8x8 grid")
        try:
            roads = _extract_roads_chunked(filename, location_dir, location)
        except Exception as e:
            print(f"❌ Chunked extraction failed: {e}")
            sys.exit(1)
    else:
        roads = None
        try:
            roads = _read_all_roads(filename, location)
        except MemoryError as e:
            print(f"❌ Memory error: {e}")
        except Exception as e:
            print(f"❌ Road extraction failed: {e}")
            sys.exit(1)

        if roads is None:
            # Run the fallback only after leaving the except block. While
            # the handler is active, the traceback keeps the failed
            # full-file frame (and its data) alive.
            gc.collect()
            print("🔄 Attempting chunked extraction...")
            try:
                roads = _extract_roads_chunked(filename, location_dir, location)
            except Exception as chunk_e:
                print(f"❌ Chunked extraction also failed: {chunk_e}")
                sys.exit(1)

    # ------------------------------
    # STEP 3: REPROJECT
    # ------------------------------
    try:
        roads = roads.to_crs(epsg=4326)
        print("🌐 Reprojected to WGS84 (EPSG:4326).")
    except Exception as e:
        # Non-fatal: the data is saved in its original CRS instead.
        print(f"⚠️ Reprojection failed: {e}")

    # ------------------------------
    # STEP 4: SAVE CLEANED DATA
    # ------------------------------
    try:
        roads.to_file(output_path, layer="roads", driver="GPKG")
        print(f"💾 Saved extracted roads to: {output_path}")
    except Exception as e:
        print(f"❌ Saving failed: {e}")
        sys.exit(1)

    # ------------------------------
    # STEP 5: CLEAN UP (OPTIONAL)
    # ------------------------------
    DELETE_DOWNLOADED_PBF = False  # 👈 toggle this if you want to keep raw files

    try:
        if DELETE_DOWNLOADED_PBF and filename.exists():
            filename.unlink()  # delete only the PBF that was used
            print(f"🧹 Deleted processed PBF file: {filename.name}")
        else:
            print(f"📦 Kept PBF files (DELETE_DOWNLOADED_PBF={DELETE_DOWNLOADED_PBF})")
    except Exception as e:
        print(f"⚠️ Could not delete {filename}: {e}")

    # ------------------------------
    # STEP 6: SUMMARY
    # ------------------------------
    print("\n✅ Process complete!")
    print(f"📁 Folder: {location_dir}")
    print(f"📥 Input: {filename.name}")
    print(f"💾 Output: {output_path.name}")
    print(f"🛣️ Roads extracted: {len(roads):,}")
    print("-" * 60)
    print(roads.head())


if __name__ == "__main__":
    # CLI entry point: a single positional location name, forwarded as-is.
    cli = argparse.ArgumentParser(
        description="Download and extract OSM road data for a location."
    )
    cli.add_argument(
        "location",
        help="Location name (e.g., alabama, egypt, thailand).",
    )
    extract_roads(cli.parse_args().location)
