#!/usr/bin/env python3
"""
Convert datapure polygon annotations to PASCAL VOC format.
Processes all batches and creates semantic and instance segmentation masks.

This script normalizes class IDs across all batches to a unified mapping,
assigns proper z-order for correct layering, and uses distinctive colors.
"""

import json
import os
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw

# Unified class mapping: class_name -> (new_id, z_order, rgb_color)
# Consistent across all batches, with proper z-order and distinctive colors
UNIFIED_CLASS_MAPPING = {
    '_background_': (0, -1, (0, 0, 0)),
    # Background/Environment (z=0-9)
    'Sky': (1, 0, (135, 206, 235)),
    'Vegetation': (2, 1, (34, 139, 34)),
    'Ground': (3, 2, (139, 90, 43)),
    # Large structures (z=10-19)
    'Building': (4, 10, (160, 82, 45)),
    'Facade': (5, 11, (205, 133, 63)),
    'Roof': (6, 12, (178, 34, 34)),
    # Large features (z=20-29)
    'Balcony': (7, 20, (70, 130, 180)),
    'Vehicle': (8, 21, (128, 128, 128)),
    'Shop/billboards': (9, 22, (255, 140, 0)),
    # Architectural details (z=30-39)
    'Column': (10, 30, (147, 112, 219)),
    'Cornice': (11, 31, (219, 112, 147)),
    'Molding': (12, 32, (255, 182, 193)),
    'Sill': (13, 33, (216, 191, 216)),
    'Chimney': (14, 34, (139, 0, 0)),
    # Openings - on top of building/facade (z=40-49)
    'Windows': (15, 40, (0, 0, 255)),
    'Door': (16, 41, (0, 255, 255)),
    'Bay Window': (17, 42, (65, 105, 225)),
    'Blind': (18, 43, (176, 196, 222)),
    'Blind/Window': (19, 43, (173, 216, 230)),
    # Decorative details (z=50-59)
    'Deco': (20, 50, (255, 215, 0)),
    # Occlusion markers (z=60-69)
    'Occluded/W': (21, 60, (200, 200, 255)),
    'Occluded/V': (22, 61, (200, 255, 200)),
    'Occluded/B': (23, 62, (255, 220, 200)),
    'Occluded/D': (24, 63, (200, 255, 255)),
    'Occluded/C': (25, 64, (255, 200, 255)),
    # Special classes (z=70+)
    'Person': (26, 70, (255, 0, 0)),
    'ROI': (27, 71, (255, 255, 0)),
    '_ignore_': (28, 72, (50, 50, 50)),
}
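# Example lookup: UNIFIED_CLASS_MAPPING['Windows'] == (15, 40, (0, 0, 255)),
# i.e. unified id 15, drawn at z-order 40, rendered blue in the palette.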

# Build palette array from the unified mapping, indexed by new_id
PALETTE = [None] * (max(v[0] for v in UNIFIED_CLASS_MAPPING.values()) + 1)
for class_name, (new_id, z_order, rgb) in UNIFIED_CLASS_MAPPING.items():
    PALETTE[new_id] = rgb

def load_classes(batch_dir):
    """Load class definitions from classes.json"""
    classes_file = batch_dir / 'classes.json'
    with open(classes_file, 'r') as f:
        classes = json.load(f)

    # Create mapping from class ID to name
    class_map = {0: '_background_'}
    for cls in classes:
        class_map[cls['id']] = cls['name']

    return class_map

def get_class_info(class_name):
    """
    Get unified class info (new_id, z_order, rgb_color) for a class name.
    Returns None if class not found in mapping.
    """
    return UNIFIED_CLASS_MAPPING.get(class_name, None)

def get_z_order(class_name):
    """
    Get z-order for a class name. Lower numbers = background (drawn first),
    higher numbers = foreground (drawn last).
    """
    info = get_class_info(class_name)
    if info:
        return info[1]  # z_order is second element
    return 999  # Unknown classes go to the end

def polygon_to_mask(polygon_points, img_size):
    """Convert polygon points to binary mask"""
    mask = Image.new('L', img_size, 0)
    draw = ImageDraw.Draw(mask)

    # Convert flat list to list of tuples
    points = [(polygon_points[i], polygon_points[i+1])
              for i in range(0, len(polygon_points), 2)]

    if len(points) >= 3:
        draw.polygon(points, fill=255)

    return np.array(mask)
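
# Example (illustrative coordinates):
#   polygon_to_mask([10, 10, 50, 10, 30, 40], (64, 64))
# rasterizes the triangle (10,10)-(50,10)-(30,40) into a 64x64 uint8 array,
# 255 inside the polygon and 0 elsewhere; polygons with fewer than three
# points yield an all-zero mask.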

def process_batch(batch_dir, output_base):
    """Process a single batch of annotations"""
    print(f"Processing {batch_dir.name}...")

    # Load annotations
    annotations_file = batch_dir / 'annotations.json'
    if not annotations_file.exists():
        print(f"  No annotations.json found, skipping")
        return 0

    with open(annotations_file, 'r') as f:
        annotations = json.load(f)

    # Load classes
    class_map = load_classes(batch_dir)

    # Find images directory
    images_dir = batch_dir / 'images'
    if not images_dir.exists():
        print(f"  No images directory found, skipping")
        return 0

    processed_count = 0

    # Process each image
    for img_filename, img_annotations in annotations.items():
        # Skip metadata keys
        if img_filename.startswith('___'):
            continue

        # Find the actual image file
        img_path = images_dir / img_filename
        if not img_path.exists():
            print(f"  Warning: {img_filename} not found")
            continue

        # Load image to get dimensions
        try:
            img = Image.open(img_path)
            img_size = img.size

            # Save image to JPEGImages (re-encoded via PIL rather than copied, to avoid permission issues)
            output_img_path = output_base / 'JPEGImages' / img_filename
            img.save(output_img_path)

            # Create masks
            class_mask = np.zeros((img_size[1], img_size[0]), dtype=np.uint8)
            instance_mask = np.zeros((img_size[1], img_size[0]), dtype=np.uint16)

            # Handle different annotation formats
            # Some batches have {instances: [...]} format, others have direct list
            anns_list = img_annotations.get('instances', img_annotations) if isinstance(img_annotations, dict) else img_annotations
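            # Assumed per-annotation shape (inferred from the fields read below;
            # values are illustrative):
            #   {"type": "polygon", "classId": 3, "points": [x1, y1, x2, y2, ...]}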

            # Sort annotations by z-order to ensure proper layering
            # Lower z-order values are drawn first (background), higher values last (foreground)
            def get_ann_z_order(ann):
                if not isinstance(ann, dict):
                    return 999  # Put invalid annotations at the end
                class_id = ann.get('classId', -1)
                class_name = class_map.get(class_id, '')
                return get_z_order(class_name)

            anns_list_sorted = sorted(anns_list, key=get_ann_z_order)

            # Process each annotation in z-order
            instance_id = 1
            for ann in anns_list_sorted:
                if not isinstance(ann, dict) or ann.get('type') != 'polygon':
                    continue

                old_class_id = ann.get('classId', -1)

                # Map old class ID to class name, then to new unified ID
                class_name = class_map.get(old_class_id, None)
                if not class_name:
                    continue  # Skip unknown classes

                # Get unified class info
                class_info = get_class_info(class_name)
                if not class_info:
                    # Class not in unified mapping, skip
                    continue

                new_class_id = class_info[0]  # new_id is first element

                points = ann.get('points', [])
                if not points:
                    continue

                # Create polygon mask
                poly_mask = polygon_to_mask(points, img_size)

                # Update class mask with NEW unified class ID (semantic segmentation)
                class_mask[poly_mask > 0] = new_class_id

                # Update instance mask (instance segmentation)
                instance_mask[poly_mask > 0] = instance_id
                instance_id += 1

            # Save class segmentation
            base_name = img_filename.rsplit('.', 1)[0]
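            # Mode 'P' keeps the unified class id as the pixel value; putpalette()
            # expects a flat [R, G, B, R, G, B, ...] sequence, hence the flattening.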
            class_img = Image.fromarray(class_mask, mode='P')
            class_img.putpalette([c for rgb in PALETTE for c in rgb])
            class_img.save(output_base / 'SegmentationClass' / f'{base_name}.png')

            # Save an identical copy to SegmentationClassPNG
            class_img.save(output_base / 'SegmentationClassPNG' / f'{base_name}.png')

            # Save class visualization
            class_vis = class_img.convert('RGB')
            class_vis.save(output_base / 'SegmentationClassVisualization' / f'{base_name}.png')

            # Save instance segmentation
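            # Stored as 16-bit PNG ('I;16') so images with more than 255
            # instances can still be represented losslessly.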
            instance_img = Image.fromarray(instance_mask, mode='I;16')
            instance_img.save(output_base / 'SegmentationObject' / f'{base_name}.png')

            # Build an RGB visualization of the instance mask, cycling through palette colors
            instance_vis = np.zeros((img_size[1], img_size[0], 3), dtype=np.uint8)
            unique_ids = np.unique(instance_mask)
            for idx, inst_id in enumerate(unique_ids):
                if inst_id == 0:
                    continue
                color_idx = (idx % (len(PALETTE) - 1)) + 1
                mask = instance_mask == inst_id
                instance_vis[mask] = PALETTE[color_idx]

            instance_vis_img = Image.fromarray(instance_vis)
            instance_vis_img.save(output_base / 'SegmentationObjectPNG' / f'{base_name}.png')
            instance_vis_img.save(output_base / 'SegmentationObjectVisualization' / f'{base_name}.png')

            processed_count += 1

        except Exception as e:
            print(f"  Error processing {img_filename}: {e}")
            continue

    return processed_count

def main():
    # Setup paths
    base_dir = Path('/mnt/vision/data/kaust/annotations/crop/datapure')
    data_dir = base_dir / 'raw' / 'Data'
    output_dir = base_dir / 'psv-datapure-voc'

    # Create output directory structure
    output_dirs = [
        'JPEGImages',
        'SegmentationClass',
        'SegmentationClassPNG',
        'SegmentationClassVisualization',
        'SegmentationObject',
        'SegmentationObjectPNG',
        'SegmentationObjectVisualization'
    ]

    print("Creating output directory structure...")
    for dir_name in output_dirs:
        (output_dir / dir_name).mkdir(parents=True, exist_ok=True)

    # Get all batch directories
    batch_dirs = sorted([d for d in data_dir.iterdir() if d.is_dir() and d.name.startswith('Batch')])
    print(f"Found {len(batch_dirs)} batches")

    # Process each batch
    total_processed = 0
    for batch_dir in batch_dirs:
        count = process_batch(batch_dir, output_dir)
        total_processed += count

    # Create class_names.txt using unified mapping
    print("\nCreating class_names.txt...")

    # Sort classes by their new unified ID
    sorted_classes = sorted(UNIFIED_CLASS_MAPPING.items(), key=lambda x: x[1][0])
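    # One class name per line, ordered by unified id, so the 0-based line
    # number in class_names.txt matches the pixel value in SegmentationClass.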

    with open(output_dir / 'class_names.txt', 'w') as f:
        for class_name, (new_id, z_order, rgb) in sorted_classes:
            f.write(f"{class_name}\n")

    print(f"\nConversion complete!")
    print(f"Total images processed: {total_processed}")
    print(f"Output directory: {output_dir}")

if __name__ == '__main__':
    main()
