
import os
import argparse
import xmltodict
import json
from dataclasses import dataclass
import tqdm.auto as tq

# Command-line interface: the annotation file is mandatory, the image folder
# (where the output files are written next to the images) is optional.
parser = argparse.ArgumentParser(
    description="Convert a CVAT XML annotation file into per-image labelme JSON files."
)
# required=True: without it a missing option reaches open(None) below and
# fails with a bare TypeError instead of a usage message.
parser.add_argument('--annotations', '-a', required=True, help="CVAT XML file")
parser.add_argument('--imagedir', '-d', help="Images folder", default='images')

args = parser.parse_args()

# force_list makes these elements parse as lists even when only one of them
# occurs, so the code iterating over them needs no special case.
with open(args.annotations, 'rb') as f:
    document = xmltodict.parse(f, force_list=['polygon', 'box', 'image'])
root = document['annotations']

# An annotation file without any <image> element has no 'image' key at all;
# treat that as "nothing to convert" rather than failing with KeyError.
images = root.get('image', [])
meta = root['meta']

@dataclass
class Label:
    """One label definition taken from the annotation file's task metadata."""
    # Label name; also the key under which the label is stored in ``labels``.
    name: str
    # Colour value exactly as stored in the annotation file (not validated here).
    color: str
    # Attribute definitions of the label, kept exactly as parsed.
    # NOTE(review): presumably None when the label declares no attributes -- confirm.
    attributes: dict 

# Label definitions declared in the task metadata, keyed by label name.
# 'label' is not among the parser's force_list entries, so a task with a single
# label yields one mapping instead of a list of mappings; wrap it so the
# comprehension iterates over label entries rather than over that mapping's keys.
label_entries = meta['task']['labels']['label']
if not isinstance(label_entries, list):
    label_entries = [label_entries]

# Pass the Label fields explicitly instead of unpacking the whole entry, so an
# entry carrying additional child elements does not raise TypeError. Fields the
# entry does not provide become None; a missing name still raises KeyError.
labels = {
    info['name']: Label(
        name=info['name'],
        color=info.get('color'),
        attributes=info.get('attributes'),
    )
    for info in label_entries
}


def _labelme_shape(element, shape_type, points):
    """Build one labelme shape dict from a parsed CVAT shape element.

    Args:
        element: Parsed ``<box>`` or ``<polygon>`` mapping; its ``@label``,
            ``@source``, ``@occluded`` and ``@z_order`` attributes are read.
        shape_type: Value stored under ``shape_type`` ('rectangle' or
            'polygon' in this script).
        points: List of ``[x, y]`` pairs stored under ``points``.

    Returns:
        dict: The shape entry, with the label lower-cased and the remaining
        CVAT attributes kept under ``flags``.

    Raises:
        KeyError: If one of the attributes read is missing.
        ValueError: If ``@occluded`` or ``@z_order`` is not an integer string.
    """
    return dict(
        label=element['@label'].lower(),
        points=points,
        group_id=None,
        shape_type=shape_type,
        flags=dict(
            source=element['@source'],
            occluded=int(element['@occluded']),
            z_order=int(element['@z_order']),
        ),
    )


# The label list is identical for every image, so format it once up front.
label_lines = [f'{label_name}\n' for label_name in labels]

# For every annotated image, write <stem>.json (labelme annotation) and
# <stem>.labels (one label name per line) into the image folder.
for image in tq.tqdm(images, 'Extracting labels from images'):

    name = image['@name']
    width = int(image['@width'])
    height = int(image['@height'])

    shapes = []

    # Messages go through tqdm.write so they do not break the progress bar,
    # and name the image instead of dumping its whole parsed mapping.
    if 'box' in image:
        for box in image['box']:
            # A rectangle is stored as its top-left and bottom-right corners.
            corners = [
                [float(box['@xtl']), float(box['@ytl'])],
                [float(box['@xbr']), float(box['@ybr'])],
            ]
            shapes.append(_labelme_shape(box, 'rectangle', corners))
    else:
        tq.tqdm.write(f"WARNING: {name} -- No ROI box")

    if 'polygon' in image:
        for poly in image['polygon']:
            # '@points' is "x,y;x,y;..." -- split into [x, y] float pairs.
            vertices = [
                [float(coord) for coord in point.split(',')]
                for point in poly['@points'].split(';')
            ]
            shapes.append(_labelme_shape(poly, 'polygon', vertices))
    else:
        tq.tqdm.write(f"ERROR: {name} -- No polygons in image")

    # Image bytes are not embedded (imageData is None); only the name is recorded.
    # NOTE(review): imagePath keeps the name exactly as given in the XML; if
    # names can contain sub-directories, confirm how consumers resolve it.
    labelme = dict(
        version='4.0.0',
        imageData=None,
        imagePath=name,
        imageHeight=height,
        imageWidth=width,
        flags=dict(),
        shapes=shapes,
    )

    # Both output files share the image's path with the extension replaced.
    stem = os.path.splitext(os.path.join(args.imagedir, name))[0]

    # Explicit encoding so the output does not depend on the platform default.
    json_path = stem + '.json'
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(labelme, f, indent=1)
        tq.tqdm.write(json_path)

    with open(stem + '.labels', 'w', encoding='utf-8') as f:
        f.writelines(label_lines)


