# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy>=2.3.0,<3",
#     "pillow>=11.2.1,<12",
#     "matplotlib>=3.10.3,<4",
#     "scipy>=1.15.3,<2",
# ]
# ///

"""
Shuffled Bayer-Order Online k-means Color Reduction
Pekka Väänänen, 2026

License: CC0 https://creativecommons.org/public-domain/cc0/
"""

from PIL import Image
import argparse
import numpy as np
import os
import random

def dither_matrix_to_offsets(M):
    """
    Turn a square dither matrix into the order in which its cells are visited.

    Returns an array with one (row, column) pair per matrix cell, sorted by
    ascending value of the corresponding entry of 'M'.
    """
    assert M.shape[0] == M.shape[1]
    side = M.shape[0]
    # Flat cell indices ranked from the smallest to the largest matrix value.
    ranked = np.argsort(M.reshape(-1))
    rows, cols = np.divmod(ranked, side)
    return np.stack((rows, cols), axis=1)


def gen_matrix_pattern(M, data_shape):
    """
    Yield (y, x) coordinates of a 2D grid in dither-matrix order.

    The grid of size 'data_shape' is split into tiles the size of 'M'. For each
    matrix cell, taken in ascending order of its value, the matching pixel of
    every tile is yielded in raster order of the tiles. Rows and columns that
    do not fill a whole tile are never yielded.
    """
    side = M.shape[0]
    offsets = dither_matrix_to_offsets(M)
    tiles_down = data_shape[0] // side
    tiles_across = data_shape[1] // side
    for cell in range(side * side):
        off_y = int(offsets[cell][0])
        off_x = int(offsets[cell][1])
        for tile_y in range(tiles_down):
            y = side * tile_y + off_y
            for tile_x in range(tiles_across):
                yield y, side * tile_x + off_x

def gen_matrix_pattern_shuffled(M, data_shape, rng=None):
    """
    Yield (y, x) coordinates in dither-matrix order, shuffled within each cell.

    The grid of size 'data_shape' is split into tiles the size of 'M'. Matrix
    cells are processed in ascending order of their value; for each cell the
    matching pixel of every tile is collected, the collection is shuffled, and
    its coordinates are yielded. Rows and columns that do not fill a whole tile
    are never yielded.

    M: square dither matrix.
    data_shape: (height, width) of the grid to cover.
    rng: optional object with an in-place 'shuffle(list)' method, for example
        a seeded 'random.Random' instance, used to make the order
        reproducible. When None the module-level 'random.shuffle' is used.
    """
    S = M.shape[0]
    M_yx = dither_matrix_to_offsets(M)
    H = data_shape[0] // S
    W = data_shape[1] // S
    shuffle = random.shuffle if rng is None else rng.shuffle
    for mi in range(S*S):
        oy = int(M_yx[mi][0])
        ox = int(M_yx[mi][1])
        coords = [(S*yt + oy, S*xt + ox) for yt in range(H) for xt in range(W)]

        shuffle(coords)
        yield from coords


def maximin_init_vectorized(X, K, distance=None):
    """
    Choose initial cluster centers with the maximin heuristic.

    The first center is the mean of all points. Every further center is the
    point of 'X' whose distance to its nearest already chosen center is the
    largest.

    X: [N x D] array of points.
    K: number of centers. At least one center (the mean) is always returned,
        also for K == 0. A negative K raises ValueError.
    distance: optional callable 'distance(X, center)' giving one distance per
        point. Defaults to the squared Euclidean distance.

    Returns an array with one center per row.
    """
    if K < 0:
        raise ValueError("K must not be negative")

    centers = [np.mean(X, axis=0)]

    # Only the distance to the nearest chosen center is ever needed, so keep a
    # running per-point minimum instead of an [N x K] table of all distances.
    num_points = X.shape[0]
    closest = np.full(num_points, np.inf)
    current = np.empty(num_points)

    for _ in range(K - 1):
        newest = centers[-1]
        if distance is None:
            current[:] = np.sum((X - newest)**2, axis=1)
        else:
            current[:] = distance(X, newest)
        np.minimum(closest, current, out=closest)

        # The point farthest from all chosen centers becomes the next center.
        max_dist_i = np.argmax(closest)
        centers.append(X[max_dist_i])

    return np.array(centers)



def quantize_okm_bayer(array: np.ndarray, num_colors: int, extra_rounds=0, order='bayer4x4'):
    """
    Online k-means color quantization.
    See Amber Abernathy's 2022 thesis "The Incremental Online k-means Clustering Algorithm"
    and "Fast color quantization using MacQueen's k-means algorithm" (2019) by Thompson, Celebi, and Buck.

    This variant processes pixels in a Bayer Matrix order by default.

    array: [H x W x 3] image.
    num_colors: number of clusters requested from the maximin initialization.
    extra_rounds: number of additional passes over the image after the first.
    order: pixel visiting order: 'bayer2x2', 'bayer4x4', 'sobol', 'random' or
        'raster'. The Bayer orders require H and W to be multiples of the
        matrix size.

    Returns (indices, palette). 'indices' is an [H x W] int32 array holding,
    for every pixel, the cluster chosen at its last visit; pixels that no pass
    visited (possible with 'sobol') get their nearest final cluster instead.
    'palette' is the [K x 3] uint8 array of rounded cluster colors.

    Raises RuntimeError for an unknown 'order'.
    """
    assert array.ndim == 3
    assert array.shape[2] == 3
    H, W = array.shape[:2]

    array = array.astype(np.float32)
    X = array.reshape(-1, 3)

    # a [K x 3] array, private copy that is updated in place below
    clusters = np.array(maximin_init_vectorized(X, num_colors))
    K = clusters.shape[0]
    cluster_sizes = np.zeros(K, dtype=np.int32)
    # -1 marks pixels that have not been visited by any pass yet
    indices = np.full((H, W), -1, dtype=np.int32)

    if order == 'bayer2x2':
        M = np.array([
            [1, 3],
            [4, 2]])
        S = M.shape[0]
    elif order == 'bayer4x4':
        M = np.array([
            [ 1,  9,  3, 11],
            [13,  5, 15,  7],
            [ 4, 12,  2, 10],
            [16,  8, 14,  6]])
        S = M.shape[0]
    elif order == 'sobol':
        from scipy.stats.qmc import Sobol
        sampler = Sobol(2)
    elif order == 'random':
        rng = np.random.default_rng(seed=123)
    elif order == 'raster':
        pass
    else:
        raise RuntimeError("Unsupported matrix type")

    # refine
    for pass_idx in range(1+extra_rounds):
        # 'order' was validated above, so exactly one branch applies.
        if order in ['bayer2x2', 'bayer4x4']:
            assert (H % S == 0) and (W % S == 0), "input must be a multiple of dither matrix size"
            coords = gen_matrix_pattern_shuffled(M, (H, W))
        elif order == 'sobol':
            # NOTE: a single pass isn't guaranteed to cover the whole image
            num_samples = H*W
            coords = sampler.integers((0, 0), u_bounds=(H, W), n=num_samples)
        elif order == 'random':
            coords_linear = rng.permutation(H*W)
            coords = zip(coords_linear // W, coords_linear % W)
        else:
            coords_linear = np.arange(H*W)
            coords = zip(coords_linear // W, coords_linear % W)

        # for each point xi
        for pi, pj in coords:
            x = array[pi, pj]
            # squared distance from the pixel to every cluster
            diff = clusters - x[None, :]
            diff = np.sum(diff ** 2, axis=1)
            # find the index with minimum
            c_idx = np.argmin(diff)
            indices[pi, pj] = c_idx

            # update cluster
            t = cluster_sizes[c_idx] + 1
            cluster_sizes[c_idx] = t

            # learning rate shrinks as the cluster collects more pixels
            rt = t**(-0.5)
            ci = clusters[c_idx]

            clusters[c_idx] = (1 - rt) * ci + rt * x

    # Pixels skipped by every pass would otherwise all point at palette
    # entry 0; map them to their nearest final cluster instead.
    for pi, pj in np.argwhere(indices < 0):
        dist = np.sum((clusters - array[pi, pj]) ** 2, axis=1)
        indices[pi, pj] = np.argmin(dist)

    clusters_8bit = np.clip(np.round(clusters), 0, 255).astype(np.uint8)
    return indices, clusters_8bit


def map_pixels_to_palette_per_row(img, palette):
    """
    Assign every pixel of 'img' the index of its nearest 'palette' color.

    Nearness is the squared Euclidean distance between colors, evaluated by
    brute force against all palette entries, one image row at a time.
    Returns an int32 array with the same height and width as 'img'.
    """
    height, width, _ = img.shape
    nearest = np.zeros((height, width), dtype=np.int32)

    for row_idx, row in enumerate(img):
        # [W x 1 x 3] pixels against [1 x P x 3] palette colors -> [W x P x 3]
        pixel_colors = row.reshape(-1, 1, 3)
        palette_colors = palette.reshape(1, -1, 3).astype(float)
        offsets = pixel_colors - palette_colors
        sq_dist = np.square(offsets).sum(axis=-1)
        nearest[row_idx, ...] = sq_dist.argmin(axis=-1)

    return nearest


def run_okm(img, num_colors:int, kmeans=0, pixel_mapping='bayer', order='bayer2x2'):
    """
    Reduce 'img' to a palette of 'num_colors' colors and return the new image.

    The image is first cropped so that both sides are multiples of four.
    'kmeans' is the number of extra quantization passes and 'order' the pixel
    visiting order. With pixel_mapping 'bayer' the indices produced during
    quantization are used directly; with 'euclidean' every pixel is remapped
    to its nearest palette color. Any other value raises RuntimeError.
    """
    tile = 4
    full = np.array(img)
    # Crop the input to a multiple of four.
    usable_rows = (full.shape[0] // tile) * tile
    usable_cols = (full.shape[1] // tile) * tile
    pixels = full[:usable_rows, :usable_cols]

    indices, palette = quantize_okm_bayer(pixels, num_colors, extra_rounds=kmeans, order=order)
    if pixel_mapping == 'euclidean':
        indices = map_pixels_to_palette_per_row(pixels, palette)
    elif pixel_mapping != 'bayer':
        raise RuntimeError(f"Unsupported pixel mapping: {pixel_mapping}")

    # Look up the palette color of every pixel index.
    recolored = palette[indices]
    return Image.fromarray(recolored)



def main():
    """Parse the command line, quantize the input image and save the result."""
    parser = argparse.ArgumentParser(description='OKM image quantization')
    parser.add_argument('input', help='Input image path')
    parser.add_argument('num_colors', type=int, help='Number of colors in the palette (K)')
    parser.add_argument('--kmeans', type=int, default=0, help='Number of extra kmeans rounds (default: 0)')
    parser.add_argument('--pixel_mapping', choices=['bayer', 'euclidean'], default='bayer',
                        help='Pixel mapping strategy (default: bayer)')
    parser.add_argument('--order', choices=['raster', 'random', 'bayer2x2', 'bayer4x4', 'sobol'], default='bayer2x2',
                        help='Reduction order (default: bayer2x2)')
    parser.add_argument('--output',
                        help='Output image path (default: {input}_{order}_k{num_colors}[_euclidean]{ext} '
                             'in the current directory, keeping the input file extension)')

    args = parser.parse_args()
    if args.num_colors < 1:
        parser.error('num_colors must be at least 1')
    if args.kmeans < 0:
        parser.error('--kmeans must not be negative')

    img = Image.open(args.input).convert('RGB')
    if args.output is None:
        # The default name is built from the input file name only, so the
        # result is written relative to the current working directory.
        base, ext = os.path.splitext(os.path.basename(args.input))
        extra = ""
        if args.pixel_mapping != 'bayer':
            extra = extra + "_euclidean"
        args.output = f"{base}_{args.order}_k{args.num_colors}{extra}{ext}"

    result = run_okm(img, args.num_colors, kmeans=args.kmeans,
                     pixel_mapping=args.pixel_mapping, order=args.order)
    result.save(args.output)


if __name__ == "__main__":
    main()