#!/usr/bin/env python3
"""squircle.py — converte PNG raw em squircle 512×512 com bg color correto.

Pipeline:
1. Trim, resize and pad to 512×512 (aspect ratio preserved, transparent padding)
2. Pick the bg color: override hex > edge-ring dominant color (2-cluster split) > auto-neutral by luminance
3. Composite the logo over the solid bg (removes transparent corners)
4. Apply the alpha mask (circle by default; squircle, radius ~22%, on request)
5. Optionally scale saturation by 0.5 (only when desat is enabled; off by default)
6. Save the 512×512 PNG

Usage as module:
    from squircle import make_squircle
    make_squircle("raw.png", "out.png", override_hex="#fafafa", desat=True)

Usage as CLI:
    python squircle.py <input> <output> [--bg #RRGGBB] [--shape circle|squircle] [--desat]
"""
from __future__ import annotations
import argparse
import io
import subprocess
import sys
from pathlib import Path
from PIL import Image, ImageColor


def load_image(path: str | Path) -> Image.Image:
    """Open an image file and return it converted to RGBA.

    Files with a ``.svg`` suffix (case-insensitive) are rasterized to
    1024×1024: first through the ``rsvg-convert`` executable and, when that
    is missing or fails, through the optional ``cairosvg`` package. Any other
    file is opened directly with Pillow.

    Args:
        path: location of the PNG/SVG/other image file.

    Returns:
        The decoded image in RGBA mode.

    Raises:
        RuntimeError: for an SVG that neither rasterizer could render.
    """
    path = Path(path)
    if path.suffix.lower() != ".svg":
        # convert() returns an independent image, so the source file can be
        # closed as soon as the pixels have been read.
        with Image.open(path) as raster:
            return raster.convert("RGBA")

    # First: rsvg-convert (preferred, fast). If the executable is not
    # installed, subprocess.run raises OSError (FileNotFoundError); treat that
    # like any other rsvg failure so the cairosvg fallback still runs.
    try:
        result = subprocess.run(
            ["rsvg-convert", "-w", "1024", "-h", "1024", str(path)],
            capture_output=True,
        )
    except OSError:
        result = None
    # A near-empty stdout (<= 100 bytes) is treated as a failed render.
    if result is not None and result.returncode == 0 and len(result.stdout) > 100:
        return Image.open(io.BytesIO(result.stdout)).convert("RGBA")

    # Fallback: cairosvg (more permissive parser, accepts broken xmlns).
    # NOTE(review): unsafe=True relaxes cairosvg's input restrictions — only
    # feed trusted SVG files through this path.
    try:
        import cairosvg
        png_bytes = cairosvg.svg2png(
            url=str(path), output_width=1024, output_height=1024,
            unsafe=True,
        )
        return Image.open(io.BytesIO(png_bytes)).convert("RGBA")
    except Exception as e:
        raise RuntimeError(
            f"Both rsvg-convert and cairosvg failed for {path.name}: {e}"
        ) from e

# Directory containing this script; the mask PNGs below are resolved relative to it.
SCRIPT_DIR = Path(__file__).parent
# Default: circle mask 85% diameter (15% transparent ring outer).
# Shrinks both the white frame and the content proportionally.
MASK_PATH = SCRIPT_DIR / "circle-mask-85pct-512.png"
MASK_FULL_PATH = SCRIPT_DIR / "circle-mask-512.png"
SQUIRCLE_MASK_PATH = SCRIPT_DIR / "squircle-mask-512.png"
# Edge length, in pixels, of the square working canvas and of the saved PNG.
SIZE = 512
EDGE_MARGIN_PCT = 0.08  # outer 8% ring
CONTENT_PCT = 0.85  # content occupies 85% of the canvas (= 100% of the visible 85% frame)


def trim_image(img: Image.Image) -> Image.Image:
    """Crop fully transparent margins off the logo.

    The crop box is the bounding box of all pixels whose alpha is non-zero,
    so that the logo's real content (not the frame of the raw file) becomes
    the unit of measure for later scaling. The crop is applied only when it
    shrinks the width or the height by more than 5%; otherwise the image is
    treated as already trimmed and returned (as RGBA) uncropped.

    Opaque margins (white or solid-colour backgrounds without real alpha)
    are NOT trimmed: detecting them from the corner colour was judged too
    risky and is intentionally skipped for now.

    Args:
        img: source image in any mode; it is converted to RGBA first.

    Returns:
        The cropped RGBA image, or the uncropped RGBA image when no
        worthwhile crop exists (including a fully transparent input).
    """
    img = img.convert("RGBA")
    # Measure on the alpha band explicitly. What Image.getbbox() considers
    # "empty" on an RGBA image depends on the Pillow version (see its
    # alpha_only flag); the alpha band alone is unambiguous: non-zero alpha.
    bbox = img.getchannel("A").getbbox()
    if bbox is None:
        # Fully transparent image: nothing to anchor a crop on.
        return img
    cropped = img.crop(bbox)
    # Use the crop only if it removed at least 5% on one axis — otherwise
    # the raw image was already trimmed.
    if cropped.size[0] < img.size[0] * 0.95 or cropped.size[1] < img.size[1] * 0.95:
        return cropped
    return img


def resize_pad(img: Image.Image, size: int = SIZE, content_pct: float = CONTENT_PCT) -> Image.Image:
    """Scale the logo to fit a centred square box inside a transparent canvas.

    The image is converted to RGBA and trimmed with ``trim_image`` BEFORE
    scaling, so every logo is measured by its content rather than by the
    padding of its raw file. The longer side is scaled to
    ``int(size * content_pct)`` pixels (at least 1), the aspect ratio is kept,
    and the result is pasted in the middle of a fully transparent
    ``size``×``size`` RGBA canvas.

    Args:
        img: source image.
        size: edge length of the output canvas in pixels.
        content_pct: fraction of the canvas edge the content may occupy.

    Returns:
        A new ``size``×``size`` RGBA image.
    """
    content = trim_image(img.convert("RGBA"))  # drop transparent margins of the raw
    src_w, src_h = content.size
    box = max(1, int(size * content_pct))
    factor = box / max(src_w, src_h)
    # Never let a dimension collapse to zero for extreme aspect ratios.
    fitted_w = max(1, int(src_w * factor))
    fitted_h = max(1, int(src_h * factor))
    content = content.resize((fitted_w, fitted_h), Image.LANCZOS)
    out = Image.new("RGBA", (size, size), (0, 0, 0, 0))
    offset = ((size - fitted_w) // 2, (size - fitted_h) // 2)
    # The image doubles as its own paste mask so its alpha is respected.
    out.paste(content, offset, content)
    return out


def edge_pixels(img: Image.Image, margin_pct: float = EDGE_MARGIN_PCT) -> list[tuple[int, int, int]]:
    """Sample opaque pixels from the outer ring of *img*.

    The ring is ``margin_pct`` of the shorter side wide, but at least 8 px
    and never more than half the shorter side (so that opposite strips of a
    small image cannot overlap or run off the image). Pixels are subsampled
    by 2 along each strip, and only pixels with alpha > 200 are kept.

    Args:
        img: image whose pixels unpack as ``(r, g, b, a)`` (i.e. RGBA mode).
        margin_pct: ring width as a fraction of the shorter image side.

    Returns:
        List of ``(R, G, B)`` tuples, possibly empty.
    """
    w, h = img.size
    margin = max(int(margin_pct * min(w, h)), 8)
    # Clamp for small images: without this the top/bottom (and left/right)
    # strips overlap and double-count pixels, or start at negative indices.
    margin = min(margin, min(w, h) // 2)
    px = img.load()
    pixels = []
    # Top + bottom strips (full width).
    for y in [*range(margin), *range(h - margin, h)]:
        for x in range(0, w, 2):  # subsample by 2
            r, g, b, a = px[x, y]
            if a > 200:
                pixels.append((r, g, b))
    # Left + right strips, restricted to the rows between the top and bottom
    # strips so the corners are not counted twice.
    for x in [*range(margin), *range(w - margin, w)]:
        for y in range(margin, h - margin, 2):
            r, g, b, a = px[x, y]
            if a > 200:
                pixels.append((r, g, b))
    return pixels


def median_cut_2(pixels: list[tuple[int, int, int]]) -> tuple[int, int, int] | None:
    """Split *pixels* into two colour clusters and return the dominant centroid.

    The pixels are divided along the channel with the widest value range, at
    the midpoint of that range. The cluster holding more pixels wins (the
    lower-valued cluster on a tie) and its per-channel integer mean is
    returned.

    The split is deliberately made at the range midpoint rather than at the
    median value: a ``<= median`` split always leaves at least half of the
    pixels in the lower cluster, so a dominant colour on the high side (e.g.
    a white background with a few dark logo pixels) would get averaged
    together with the minority instead of being isolated. The function name
    is kept for backward compatibility.

    Args:
        pixels: ``(R, G, B)`` tuples.

    Returns:
        ``(R, G, B)`` centroid of the dominant cluster, or ``None`` when
        *pixels* is empty.
    """
    if not pixels:
        return None
    channels = range(3)
    lows = [min(p[c] for p in pixels) for c in channels]
    highs = [max(p[c] for p in pixels) for c in channels]
    spans = [hi - lo for lo, hi in zip(lows, highs)]
    # Channel with the greatest spread separates the two clusters best.
    axis = spans.index(max(spans))
    midpoint = (lows[axis] + highs[axis]) / 2
    low_side = [p for p in pixels if p[axis] <= midpoint]
    high_side = [p for p in pixels if p[axis] > midpoint]
    # low_side always contains the minimum pixel, so the winner is never empty.
    dominant = low_side if len(low_side) >= len(high_side) else high_side
    n = len(dominant)
    return (
        sum(p[0] for p in dominant) // n,
        sum(p[1] for p in dominant) // n,
        sum(p[2] for p in dominant) // n,
    )


def detect_bg(img: Image.Image) -> tuple[int, int, int] | None:
    """Guess the background colour from the image's outer ring.

    Opaque ring pixels are gathered with ``edge_pixels`` and reduced to one
    colour by ``median_cut_2``.

    Returns:
        An ``(R, G, B)`` tuple, or ``None`` when fewer than 100 opaque ring
        pixels were found (the logo is mostly transparent at its edges).
    """
    samples = edge_pixels(img)
    # Too few opaque samples means there is no reliable background to detect.
    return median_cut_2(samples) if len(samples) >= 100 else None


def luminance(rgb: tuple[int, int, int]) -> float:
    """Return the Rec.709-weighted luminance of an 8-bit RGB colour, scaled to 0..1."""
    red, green, blue = rgb
    weighted = 0.2126 * red + 0.7152 * green + 0.0722 * blue
    return weighted / 255.0


def auto_neutral(img: Image.Image) -> tuple[int, int, int]:
    """Choose a neutral background that contrasts with the logo.

    Every 4th pixel in each direction is inspected; pixels with alpha > 200
    contribute their ``luminance`` to an average. A light logo (average above
    0.5) gets the dark charcoal background, a dark logo gets light cream.
    With no opaque samples at all, light cream is returned.

    Args:
        img: image whose pixels unpack as ``(r, g, b, a)`` (i.e. RGBA mode).

    Returns:
        ``(10, 10, 10)`` or ``(250, 249, 245)``.
    """
    light_cream = (250, 249, 245)
    charcoal = (10, 10, 10)

    width, height = img.size
    access = img.load()
    total = 0.0
    sampled = 0
    for row in range(0, height, 4):
        for col in range(0, width, 4):
            red, green, blue, alpha = access[col, row]
            if alpha <= 200:
                continue  # ignore transparent / semi-transparent pixels
            total += luminance((red, green, blue))
            sampled += 1

    if sampled == 0:
        return light_cream  # default when nothing opaque was sampled
    # Light logo -> dark charcoal bg; dark logo -> light cream bg.
    if total / sampled > 0.5:
        return charcoal
    return light_cream


def parse_hex(hex_str: str) -> tuple[int, int, int]:
    """Parse a colour string such as ``#RRGGBB`` into an ``(R, G, B)`` tuple.

    Parsing is delegated to ``ImageColor.getrgb``; only the first three
    components of its result are kept, so any alpha component is dropped.
    """
    components = ImageColor.getrgb(hex_str)
    return components[:3]


def desaturate(img: Image.Image, factor: float = 0.5) -> Image.Image:
    """Return a copy of *img* with its colour saturation reduced.

    The colour channels are blended with their own grayscale version:
    ``factor=1.0`` leaves colours unchanged, ``factor=0.0`` yields pure
    grayscale. The alpha channel is carried over untouched.

    Args:
        img: source image; it is converted to RGBA first.
        factor: weight of the original colours in the blend.

    Returns:
        A new RGBA image.
    """
    rgba = img.convert("RGBA")
    red, green, blue, alpha = rgba.split()
    colour = Image.merge("RGB", (red, green, blue))
    grayscale = colour.convert("L").convert("RGB")
    # blend(a, b, t) = a * (1 - t) + b * t, so t is the share of real colour.
    mixed = Image.blend(grayscale, colour, factor)
    return Image.merge("RGBA", (*mixed.split(), alpha))


# Masks already loaded from disk, keyed by the requested shape name.
_mask_cache: dict[str, Image.Image] = {}


def squircle_mask(shape: str = "circle") -> Image.Image:
    """Return the alpha mask for *shape* as an ``L``-mode image.

    ``"squircle"`` selects ``SQUIRCLE_MASK_PATH``; every other value
    (including the default ``"circle"``) selects ``MASK_PATH``. Each mask is
    read from disk once and cached per shape name; callers receive a copy, so
    they cannot modify the cached image.
    """
    cached = _mask_cache.get(shape)
    if cached is None:
        mask_file = MASK_PATH
        if shape == "squircle":
            mask_file = SQUIRCLE_MASK_PATH
        cached = Image.open(mask_file).convert("L")
        _mask_cache[shape] = cached
    return cached.copy()


def make_squircle(
    raw_path: str | Path,
    output_path: str | Path,
    override_hex: str | None = None,
    desat: bool = False,
    shape: str = "circle",
) -> dict:
    """Run the full pipeline on one logo and write the result as PNG.

    Steps: load the file, trim/resize/pad it onto the ``SIZE`` canvas, choose
    a background colour, composite the logo over that solid colour, apply the
    alpha mask for *shape*, optionally desaturate by 0.5, and save.

    Background priority: ``override_hex`` if given, else the colour found by
    ``detect_bg``, else the contrasting neutral from ``auto_neutral``.

    Args:
        raw_path: source PNG/SVG/other image.
        output_path: destination PNG; missing parent directories are created.
        override_hex: explicit background colour, e.g. ``"#fafafa"``.
        desat: apply 50% desaturation (default off, colours preserved).
        shape: ``"circle"`` (default) or ``"squircle"``.

    Returns:
        Dict with keys ``input``, ``output``, ``bg`` (``#rrggbb``),
        ``bg_source`` (``override:…``, ``detected:…`` or ``neutral:…``) and
        ``desat``.
    """
    logo = resize_pad(load_image(raw_path), SIZE)

    # Choose the background colour and remember where it came from.
    if override_hex:
        bg = parse_hex(override_hex)
        bg_source = f"override:{override_hex}"
    else:
        # NOTE(review): detection runs on the padded canvas, so ring samples
        # come only from content that reaches the outer ring — confirm intended.
        bg = detect_bg(logo)
        origin = "detected"
        if not bg:
            bg = auto_neutral(logo)
            origin = "neutral"
        bg_source = f"{origin}:#{bg[0]:02x}{bg[1]:02x}{bg[2]:02x}"
    bg_hex = f"#{bg[0]:02x}{bg[1]:02x}{bg[2]:02x}"

    # Flatten the logo onto an opaque background.
    composed = Image.new("RGBA", (SIZE, SIZE), (*bg, 255))
    composed.alpha_composite(logo)

    # Cut out the requested shape (circle or squircle).
    alpha_mask = squircle_mask(shape)
    composed.putalpha(alpha_mask)

    if desat:
        composed = desaturate(composed, 0.5)
        composed.putalpha(alpha_mask)

    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    composed.save(output_path, "PNG", optimize=True)

    return {
        "input": str(raw_path),
        "output": str(output_path),
        "bg": bg_hex,
        "bg_source": bg_source,
        "desat": desat,
    }


def main():
    """CLI entry point: parse the arguments, run ``make_squircle``, print a summary line."""
    parser = argparse.ArgumentParser(description="Convert raw PNG/SVG to circle (or squircle) 512×512.")
    for positional, text in (("input", "raw PNG/SVG path"), ("output", "output PNG path")):
        parser.add_argument(positional, help=text)
    parser.add_argument("--bg", help="override bg hex (#RRGGBB)")
    parser.add_argument("--shape", choices=["circle", "squircle"], default="circle")
    # "%%" is argparse's escape for a literal percent sign in help text.
    parser.add_argument("--desat", action="store_true", help="apply 50%% desaturation")
    opts = parser.parse_args()

    summary = make_squircle(
        opts.input,
        opts.output,
        opts.bg,
        desat=opts.desat,
        shape=opts.shape,
    )
    print(f"OK: bg={summary['bg']} ({summary['bg_source']}) → {summary['output']}")


if __name__ == "__main__":
    main()
