#!/usr/bin/env python3
"""
Memory-efficient road extraction that processes by highway type to reduce memory usage.
"""

import argparse
import sys
import gc
from pathlib import Path
import requests
from tqdm import tqdm
from pyrosm import OSM
import geopandas as gpd
import pandas as pd

from scripts.geofabrik_registry import get_geofabrik_url

def _download_pbf(url, destination, block_size=8192, min_bytes=1024 * 100):
    """Stream *url* to *destination*, publishing the file only once it looks complete.

    The data is first written to a sibling ``<name>.part`` file and renamed
    onto *destination* only after the size check passes. An interrupted or
    truncated download therefore never leaves a file at *destination* that a
    later run would mistake for a valid, already-downloaded extract.

    Args:
        url: HTTP(S) address of the file to fetch.
        destination: ``pathlib.Path`` where the finished file must end up.
        block_size: Number of bytes requested per streamed chunk.
        min_bytes: Smallest size accepted as a plausible download.

    Raises:
        ValueError: If the downloaded file is smaller than *min_bytes*.
        Exception: Any error raised by ``requests`` or by file I/O is
            propagated unchanged to the caller.
    """
    partial = destination.with_name(destination.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=120) as r:
            r.raise_for_status()
            # A missing Content-Length must become None: tqdm treats None as
            # "unknown length", whereas 0 would render a meaningless 0-byte bar.
            total_size = int(r.headers.get("content-length", 0)) or None

            with open(partial, "wb") as f, tqdm(
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
                desc=destination.name,
            ) as bar:
                for chunk in r.iter_content(chunk_size=block_size):
                    if chunk:
                        f.write(chunk)
                        bar.update(len(chunk))

        if partial.stat().st_size < min_bytes:
            raise ValueError("Downloaded file seems too small — possible download error.")

        # Publish the finished download under its final name in one step.
        partial.replace(destination)
    finally:
        # After a successful rename the temporary file no longer exists;
        # in every other case discard the incomplete data.
        if partial.exists():
            try:
                partial.unlink()
            except OSError:
                pass  # best-effort cleanup; never mask the original error


def _extract_highway_group(pbf_path, highway_types):
    """Read the ways whose ``highway`` tag is one of *highway_types*.

    A fresh ``OSM`` reader is created per call and goes out of scope when the
    function returns, so it is not kept alive between groups.

    Args:
        pbf_path: Path of the ``.osm.pbf`` file to read.
        highway_types: List of ``highway`` tag values to keep.

    Returns:
        Whatever ``OSM.get_data_by_custom_criteria`` returns for the filter
        (the caller handles both ``None`` and an empty frame).
    """
    osm = OSM(str(pbf_path))
    return osm.get_data_by_custom_criteria(
        custom_filter={"highway": highway_types},
        filter_type="keep",
        keep_nodes=False,
        keep_relations=False,
        extra_attributes=[
            "highway", "name", "lanes", "width", "surface",
            "material", "bridge", "tunnel", "layer"
        ]
    )


def extract_roads_by_type(location: str):
    """Extract roads by processing different highway types separately.

    Downloads the Geofabrik extract for *location* into
    ``./output/<location>/`` (unless the file is already there), reads the
    road network one highway group at a time, merges the groups, reprojects
    the result to EPSG:4326 and writes it to
    ``./output/<location>/<location>_roads.gpkg``.

    Args:
        location: Location name understood by ``get_geofabrik_url``
            (case-insensitive; it is lower-cased before use).

    Raises:
        SystemExit: With status 1 when the download fails, when no group
            yields any road, when the groups cannot be combined, or when the
            output file cannot be written.
    """
    location = location.lower()
    location_dir = Path.cwd() / "output" / location
    location_dir.mkdir(parents=True, exist_ok=True)

    # Get download URL
    url = get_geofabrik_url(location)
    print(f"🔗 Geofabrik URL: {url}")

    filename = location_dir / f"{location}.osm.pbf"

    # Download if needed
    try:
        if filename.exists():
            print(f"✅ Found existing file: {filename}")
        else:
            print(f"⬇️ Downloading OSM data for {location.capitalize()}...\n")
            _download_pbf(url, filename)
            print(f"\n✅ Download complete: {filename}")
    except Exception as e:
        print(f"❌ Download failed: {e}")
        sys.exit(1)

    # Process by highway type groups to manage memory
    print("\n🚧 Extracting roads by highway type groups...")

    # Group highway types by importance to process separately
    highway_groups = {
        "major": ["motorway", "trunk", "primary", "secondary", "tertiary",
                  "motorway_link", "trunk_link", "primary_link", "secondary_link"],
        "local": ["residential", "unclassified", "service", "living_street"],
        "paths": ["footway", "cycleway", "path", "pedestrian", "track", "bridleway"],
        "other": True  # Placeholder for all remaining highway types
    }

    all_roads = []

    for group_name, highway_types in highway_groups.items():
        print(f"\n🛣️ Processing {group_name} roads...")

        # The catch-all group is not extracted; decide that before opening
        # the PBF so no reader is created just to be thrown away.
        if highway_types is True:
            print("  ⚪ Skipping 'other' group to conserve memory")
            continue

        roads = None
        try:
            roads = _extract_highway_group(filename, highway_types)

            if roads is not None and not roads.empty:
                print(f"  ✅ Extracted {len(roads):,} {group_name} roads")
                all_roads.append(roads)
            else:
                print(f"  ⚪ No {group_name} roads found")
        except Exception as e:
            # A failing group must not abort the remaining groups.
            print(f"  ❌ Error processing {group_name} roads: {e}")
        finally:
            # Drop the local reference and collect on success and failure
            # alike, before the next group is read.
            roads = None
            gc.collect()

    if not all_roads:
        print("❌ No roads extracted from any groups")
        sys.exit(1)

    # Combine results
    print("\n🔗 Combining all road types...")
    try:
        combined_roads = pd.concat(all_roads, ignore_index=True)
        combined_roads = gpd.GeoDataFrame(combined_roads)
        print(f"✅ Total extracted: {len(combined_roads):,} roads")

        # Remove duplicates if any
        original_count = len(combined_roads)
        combined_roads = combined_roads.drop_duplicates()
        if len(combined_roads) < original_count:
            print(f"🧹 Removed {original_count - len(combined_roads):,} duplicates")

    except Exception as e:
        print(f"❌ Error combining road types: {e}")
        sys.exit(1)

    # The per-group frames are no longer needed once merged.
    del all_roads
    gc.collect()

    # Reproject (best effort: a failure is reported but the data is still saved)
    try:
        combined_roads = combined_roads.to_crs(epsg=4326)
        print("🌐 Reprojected to WGS84 (EPSG:4326).")
    except Exception as e:
        print(f"⚠️ Reprojection failed: {e}")

    # Save
    output_path = location_dir / f"{location}_roads.gpkg"
    try:
        combined_roads.to_file(output_path, layer="roads", driver="GPKG")
    except Exception as e:
        print(f"❌ Failed to save file: {e}")
        sys.exit(1)
    print(f"💾 Saved extracted roads to: {output_path}")

    # Show summary. This is kept apart from the save step so that a problem
    # while reporting is not announced as a failed save after the file has
    # already been written.
    try:
        print("\n📊 Summary:")
        print(f"   Total roads: {len(combined_roads):,}")
        print(f"   File size: {output_path.stat().st_size / 1024**2:.1f} MB")
        if "highway" in combined_roads.columns:
            print("   Highway types found:")
            highway_counts = combined_roads['highway'].value_counts().head(10)
            for highway_type, count in highway_counts.items():
                print(f"     {highway_type}: {count:,}")
    except Exception as e:
        print(f"⚠️ Could not build summary: {e}")

def main():
    """Command-line entry point: read the location argument and run the extraction."""
    cli = argparse.ArgumentParser(description="Extract roads from OSM data by type")
    cli.add_argument("location", help="Location name (e.g., france, spain)")

    # The single positional argument is handed straight to the extractor.
    extract_roads_by_type(cli.parse_args().location)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()