"""
Tile 2048x2048 images into 512x512 crops and export road graphs as GeoJSON.

For each region_X:
  - Splits region_X_sat.png and region_X_gt.png into a 4x4 grid of 512x512 crops
  - Skeletonizes each GT crop to 1-pixel-wide lines (default; disable with
    --no-skeletonize-gt)
  - Clips the road graph (region_X_refine_gt_graph.p) to each tile
  - Writes GeoJSON LineStrings in tile-local pixel coordinates

Output structure mirrors input splits, e.g.:
  <output_dir>/train/1/region_0_tile_r0_c0_sat.png
  <output_dir>/train/1/region_0_tile_r0_c0_gt.png
  <output_dir>/train/1/region_0_tile_r0_c0_roads.geojson
"""

import argparse
import json
import multiprocessing as mp
import pickle
import sys
from pathlib import Path

import numpy as np
from PIL import Image
from tqdm import tqdm

# Tiling geometry, all values in pixels.
TILE_SIZE: int = 512   # side length of each output crop
FULL_SIZE: int = 2048  # expected side length of each input region image
GRID: int = FULL_SIZE // TILE_SIZE  # tiles per side: 2048 // 512 == 4, i.e. 16 tiles per region


def zhang_suen_thinning(binary_mask: np.ndarray) -> np.ndarray:
    """Thin a binary mask to a 1-pixel-wide skeleton with Zhang-Suen thinning.

    Every pixel with value > 0 is foreground. The two Zhang-Suen sub-iterations
    alternate. Each one deletes, all at once, every pixel that qualifies under
    the image state at the start of that sub-iteration. This repeats until a
    full pass (both sub-iterations) deletes nothing.

    Args:
        binary_mask: 2-D array; any value > 0 counts as foreground.

    Returns:
        A new uint8 array of the same shape: 1 on skeleton pixels, 0 elsewhere.
    """
    skel = (binary_mask > 0).astype(np.uint8)

    # Neighbour offsets (d_row, d_col) in Zhang-Suen order P2..P9:
    # N, NE, E, SE, S, SW, W, NW (clockwise, starting from north).
    offsets = ((-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1))

    def shifted(padded, dr, dc):
        # View of the 1-pixel zero-padded image aligned so that element [r, c]
        # is the neighbour at (r + dr, c + dc) of original pixel [r, c].
        rows = slice(1 + dr, padded.shape[0] - 1 + dr)
        cols = slice(1 + dc, padded.shape[1] - 1 + dc)
        return padded[rows, cols]

    while True:
        removed_any = False
        for first_subiter in (True, False):
            padded = np.pad(skel, 1, mode="constant")
            center = shifted(padded, 0, 0)
            ring = np.stack([shifted(padded, dr, dc) for dr, dc in offsets])
            p2, _, p4, _, p6, _, p8, _ = ring

            # B(P1): number of foreground neighbours.
            n_fg = ring.sum(axis=0)
            # A(P1): number of 0 -> 1 transitions around the closed ring P2..P9, P2.
            n_trans = ((ring == 0) & (np.roll(ring, -1, axis=0) == 1)).sum(axis=0)

            deletable = (center == 1) & (n_fg >= 2) & (n_fg <= 6) & (n_trans == 1)
            if first_subiter:
                deletable &= ((p2 & p4 & p6) == 0) & ((p4 & p6 & p8) == 0)
            else:
                deletable &= ((p2 & p4 & p8) == 0) & ((p2 & p6 & p8) == 0)

            if deletable.any():
                skel[deletable] = 0
                removed_any = True

        if not removed_any:
            return skel


def skeletonize_mask_image(mask_img: Image.Image) -> Image.Image:
    """Binarize a mask tile and thin it to a 1-pixel-wide skeleton.

    Pixels with value > 0 are foreground. For multi-channel images only the
    first channel is inspected.

    Args:
        mask_img: Mask crop to skeletonize.

    Returns:
        An 8-bit single-channel image with skeleton pixels set to 255 and all
        other pixels set to 0.
    """
    arr = np.array(mask_img)
    if arr.ndim == 3:
        # NOTE(review): only channel 0 is used; this assumes the GT mask is
        # grayscale-like (all colour channels equal) — confirm for the dataset.
        arr = arr[..., 0]
    skel = zhang_suen_thinning((arr > 0).astype(np.uint8))
    # A 2-D uint8 array is mapped to mode "L" automatically. Passing mode=
    # explicitly to Image.fromarray is deprecated in recent Pillow releases.
    return Image.fromarray((skel * 255).astype(np.uint8))


def clip_edge_to_tile(p1, p2, x0, y0, x1, y1):
    """
    Clip line segment p1->p2 to the closed axis-aligned box [x0, x1] x [y0, y1].

    Points lying exactly on the box boundary count as inside. As a result, a
    segment lying on the shared border of two adjacent boxes is kept by both.
    The box must be closed for the algorithm to work: an endpoint that has just
    been clipped onto x1 or y1 has to test as inside, otherwise the loop below
    would keep re-clipping it forever.

    Uses the Cohen-Sutherland algorithm.

    Args:
        p1, p2: (x, y) endpoints of the segment.
        x0, y0, x1, y1: Box bounds; requires x0 <= x1 and y0 <= y1.

    Returns:
        ((ax, ay), (bx, by)) — the clipped endpoints in the same order as p1 and
        p2. Coordinates become floats where an intersection was computed.
        Returns None if the segment lies entirely outside the box.
    """
    # Outcode bits. In pixel space y grows downward, so "ABOVE" means y < y0.
    LEFT, RIGHT, ABOVE, BELOW = 1, 2, 4, 8

    def outcode(x, y):
        c = 0
        if x < x0:
            c |= LEFT
        elif x > x1:
            c |= RIGHT
        if y < y0:
            c |= ABOVE
        elif y > y1:
            c |= BELOW
        return c

    ax, ay = p1
    bx, by = p2
    ca, cb = outcode(ax, ay), outcode(bx, by)

    while True:
        if not (ca | cb):  # trivial accept: both endpoints inside
            return (ax, ay), (bx, by)
        if ca & cb:  # trivial reject: both endpoints share an outside region
            return None
        # Move one outside endpoint onto the boundary line it violates.
        # The zero-divisor guards are defensive only: if bx == ax while this
        # endpoint is LEFT/RIGHT (or by == ay while ABOVE/BELOW), the other
        # endpoint shares that region and the trivial reject already returned.
        c = ca if ca else cb
        if c & LEFT:
            t = (x0 - ax) / (bx - ax) if bx != ax else 0
            x, y = x0, ay + t * (by - ay)
        elif c & RIGHT:
            t = (x1 - ax) / (bx - ax) if bx != ax else 0
            x, y = x1, ay + t * (by - ay)
        elif c & ABOVE:
            t = (y0 - ay) / (by - ay) if by != ay else 0
            x, y = ax + t * (bx - ax), y0
        else:  # BELOW
            t = (y1 - ay) / (by - ay) if by != ay else 0
            x, y = ax + t * (bx - ax), y1
        if ca:
            ax, ay = x, y
            ca = outcode(ax, ay)
        else:
            bx, by = x, y
            cb = outcode(bx, by)


def graph_to_geojson_tile(graph, row, col):
    """
    Build a GeoJSON FeatureCollection of road edges clipped to tile (row, col).

    Each edge is treated as undirected and emitted at most once, as a two-point
    LineString. Coordinates are tile-local pixels: (0, 0) is the top-left of
    the tile, and positions are written as [x, y].

    Args:
        graph: Mapping ``node -> iterable of neighbour nodes``. Each node is a
            coordinate pair in full-image pixels. The first component is used as
            x (column) and the second as y (row).
            NOTE(review): confirm this (x, y) ordering against the dataset's
            graph pickles; a (row, col) ordering would transpose every tile.
        row, col: Tile indices in the TILE_SIZE grid.

    Returns:
        A GeoJSON FeatureCollection dict. Edges that miss the tile entirely are
        omitted. So are edges that only touch it at a single point (e.g. an edge
        ending exactly on the tile border from outside).
    """
    x0 = col * TILE_SIZE
    y0 = row * TILE_SIZE
    x1 = x0 + TILE_SIZE
    y1 = y0 + TILE_SIZE

    features = []
    seen = set()

    for node, neighbors in graph.items():
        for nb in neighbors:
            # Undirected: (a, b) and (b, a) describe the same edge.
            edge_key = (min(node, nb), max(node, nb))
            if edge_key in seen:
                continue
            seen.add(edge_key)

            clipped = clip_edge_to_tile(node, nb, x0, y0, x1, y1)
            if clipped is None:
                continue
            pa, pb = clipped
            # A zero-length result means the edge merely touches the tile at a
            # point (or is itself degenerate). It carries no road geometry and
            # would only produce a degenerate LineString.
            if pa == pb:
                continue
            # Convert to tile-local coords; GeoJSON uses [x, y]
            local_pa = [pa[0] - x0, pa[1] - y0]
            local_pb = [pb[0] - x0, pb[1] - y0]
            features.append({
                "type": "Feature",
                "geometry": {
                    "type": "LineString",
                    "coordinates": [local_pa, local_pb]
                },
                "properties": {}
            })

    return {"type": "FeatureCollection", "features": features}


def process_region(graph_path: Path, sat_path: Path, gt_path: Path, out_dir: Path, skeletonize_gt: bool):
    """
    Cut one region into GRID x GRID tiles and write the per-tile outputs.

    For every tile (row, col), three files are written into ``out_dir``:
      - ``<prefix>_tile_r{row}_c{col}_sat.png``: satellite crop
      - ``<prefix>_tile_r{row}_c{col}_gt.png``: GT mask crop, skeletonized when
        ``skeletonize_gt`` is True
      - ``<prefix>_tile_r{row}_c{col}_roads.geojson``: road edges clipped to the tile

    ``<prefix>`` is the satellite file stem with "_sat" removed.

    Args:
        graph_path: Pickled road graph for the region.
        sat_path: Satellite image.
        gt_path: Ground-truth mask image.
        out_dir: Existing directory to write the tiles into.
        skeletonize_gt: Whether to thin each GT crop before saving it.
    """
    prefix = sat_path.stem.replace("_sat", "")  # e.g. region_0

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only run this on graph files from a trusted dataset source.
    with open(graph_path, "rb") as f:
        graph = pickle.load(f)

    # Open both images in context managers so their file handles are released
    # as soon as this region is done. Otherwise they stay open until garbage
    # collection inside a long-lived pool worker.
    with Image.open(sat_path) as sat, Image.open(gt_path) as gt:
        for row in range(GRID):
            for col in range(GRID):
                x0 = col * TILE_SIZE
                y0 = row * TILE_SIZE
                box = (x0, y0, x0 + TILE_SIZE, y0 + TILE_SIZE)  # (left, upper, right, lower)
                tag = f"{prefix}_tile_r{row}_c{col}"

                sat.crop(box).save(out_dir / f"{tag}_sat.png")

                gt_tile = gt.crop(box)
                if skeletonize_gt:
                    gt_tile = skeletonize_mask_image(gt_tile)
                gt_tile.save(out_dir / f"{tag}_gt.png")

                geojson = graph_to_geojson_tile(graph, row, col)
                with open(out_dir / f"{tag}_roads.geojson", "w", encoding="utf-8") as f:
                    json.dump(geojson, f)


def process_split(split_root: Path, out_root: Path, workers: int = 8, skeletonize_gt: bool = True):
    """
    Tile every complete region found under ``split_root`` into ``out_root``.

    Each immediate subdirectory of ``split_root`` (e.g. a city) is scanned for
    ``*_sat.png`` files. A region is processed only if its matching
    ``<prefix>_gt.png`` and ``<prefix>_refine_gt_graph.p`` also exist.
    Incomplete regions are reported on stderr and skipped.

    Outputs go to ``out_root/<subdir name>/``. That directory is created only
    when the subdirectory has at least one complete region.

    Args:
        split_root: Directory of one dataset split (e.g. ``train``).
        out_root: Output directory for this split.
        workers: Number of worker processes in the pool.
        skeletonize_gt: Forwarded to ``process_region``.
    """
    jobs = []
    for city_dir in sorted(split_root.iterdir()):
        if not city_dir.is_dir():
            continue
        out_city = out_root / city_dir.name

        for sat_path in sorted(city_dir.glob("*_sat.png")):
            prefix = sat_path.stem.replace("_sat", "")
            gt_path = city_dir / f"{prefix}_gt.png"
            graph_path = city_dir / f"{prefix}_refine_gt_graph.p"
            missing = [p.name for p in (gt_path, graph_path) if not p.exists()]
            if missing:
                # Report instead of silently dropping the region.
                print(f"  [skip] {sat_path}: missing {', '.join(missing)}", file=sys.stderr)
                continue
            out_city.mkdir(parents=True, exist_ok=True)
            jobs.append((graph_path, sat_path, gt_path, out_city, skeletonize_gt))

    if not jobs:
        # Nothing to do; don't spin up a worker pool for an empty job list.
        print(f"  [skip] {split_root}: no complete regions found", file=sys.stderr)
        return

    with mp.Pool(workers) as pool:
        list(tqdm(pool.imap_unordered(_process_region_star, jobs), total=len(jobs), desc=split_root.name))


def _process_region_star(args):
    try:
        process_region(*args)
    except Exception as e:
        print(f"  [error] {args[1]}: {e}", file=sys.stderr)


def main(argv=None):
    """
    Parse command-line arguments and tile each requested dataset split.

    Splits whose directory does not exist under ``dataset_root`` are reported
    on stderr and skipped.

    Args:
        argv: Arguments to parse. None (the default) parses ``sys.argv[1:]``,
            so a plain ``main()`` call behaves exactly as before.
    """
    parser = argparse.ArgumentParser(description="Tile Global-Scale dataset into 512x512 crops with GeoJSON road graphs.")
    parser.add_argument("dataset_root", type=Path, help="Path to dataset root (contains train/, val/, etc.)")
    parser.add_argument("output_root", type=Path, help="Path to write tiled outputs")
    parser.add_argument("--splits", nargs="+", default=["train", "val", "in-domain-test", "out_of_domain"],
                        help="Which splits to process (default: all)")
    parser.add_argument("--workers", type=int, default=8, help="Number of parallel worker processes (default: 8)")
    parser.add_argument("--no-skeletonize-gt", action="store_true",
                        help="Disable skeletonization and keep raw GT crops")
    args = parser.parse_args(argv)
    if args.workers < 1:
        # multiprocessing.Pool rejects a process count below 1. Fail up front
        # with a usage message (exit status 2) instead of a worker traceback.
        parser.error("--workers must be at least 1")

    skeletonize_gt = not args.no_skeletonize_gt
    for split in args.splits:
        split_path = args.dataset_root / split
        if not split_path.exists():
            print(f"[skip] {split_path} not found", file=sys.stderr)
            continue
        out_path = args.output_root / split
        print(f"Processing {split} -> {out_path}")
        process_split(split_path, out_path, workers=args.workers, skeletonize_gt=skeletonize_gt)

    print("Done.")


if __name__ == "__main__":
    # The guard stops worker processes that re-import this module (e.g. under
    # the "spawn" start method) from re-running the CLI.
    # main() returns None, so the exit status is 0 on normal completion.
    sys.exit(main())
