import * as z from "zod";
import { convertBase64ImageToBlob } from "~/domain/network";
import type { ImageSegmentationInferenceResults } from "./types";

/**
 * Parses a raw image-segmentation inference response into
 * {@link ImageSegmentationInferenceResults}.
 *
 * Accepts either an already-parsed value or a JSON string, which is parsed
 * first. Every base64 `mask` is converted with `convertBase64ImageToBlob`. If
 * any conversion yields `null`/`undefined`, the parse fails with
 * "Not all masks were valid". Each segment gets a deterministic color derived
 * from its label.
 *
 * The transform is async, so this schema must be used with
 * `parseAsync`/`safeParseAsync`.
 */
export const imageSegmentationsSchema = z.preprocess(
  (arg, ctx) => {
    // Non-string input is handed straight to the object schema for validation.
    if (typeof arg !== "string") {
      return arg;
    }

    try {
      return JSON.parse(arg) as unknown;
    } catch {
      ctx.addIssue({
        code: z.ZodIssueCode.custom,
        message: "Unable to parse raw data",
        fatal: true,
      });

      return;
    }
  },
  z
    .object({
      meta: z.object({
        image: z.object({
          width: z.number(),
          height: z.number(),
        }),
      }),
      output: z.array(
        z.object({
          label: z.string().nullable(),
          score: z.number().nullable(),
          mask: z.string(),
        }),
      ),
    })
    .transform(
      async (
        { meta, output },
        ctx,
      ): Promise<ImageSegmentationInferenceResults> => {
        const {
          image: { width: imageWidth, height: imageHeight },
        } = meta;

        const maskImages = await Promise.all(
          output.map(({ mask }) => convertBase64ImageToBlob(mask)),
        );

        // Explicit type predicate so `maskImages` is narrowed to non-null
        // entries after this check, independent of TS predicate inference.
        if (
          !maskImages.every(
            (mask): mask is NonNullable<(typeof maskImages)[number]> =>
              mask != null,
          )
        ) {
          ctx.addIssue({
            code: z.ZodIssueCode.custom,
            message: "Not all masks were valid",
            fatal: true,
          });

          return z.NEVER;
        }

        // Cache the color *promise* per label rather than the resolved color.
        // Every segment callback below reaches its first `await` before any
        // digest settles. A cache of resolved values would therefore never be
        // hit, and each segment would hash its label again.
        const colorCache = new Map<
          string | null,
          Promise<[r: number, g: number, b: number]>
        >();
        const colorForLabel = (
          label: string | null,
        ): Promise<[r: number, g: number, b: number]> => {
          let color = colorCache.get(label);
          if (color === undefined) {
            color = generateSegmentationColor(label);
            colorCache.set(label, color);
          }
          return color;
        };

        return {
          imageWidth,
          imageHeight,
          segmentations: await Promise.all(
            output.map(async ({ label, score }, index) => ({
              label,
              score,
              color: await colorForLabel(label),
              mask: maskImages[index],
            })),
          ),
        };
      },
    ),
);

/**
 * Generates a deterministic RGB color for a segmentation label.
 */
/**
 * Generates a deterministic RGB color for a segmentation label.
 *
 * The label is UTF-8 encoded and hashed with SHA-1. The 1st, 11th, and 20th
 * bytes of the digest become the R, G, and B channels, respectively.
 *
 * @param label - Segmentation label. `null` is hashed as the empty string, so
 *   it shares a color with `""`.
 * @returns An `[r, g, b]` tuple with each channel an integer in 0–255.
 * @throws Error if the Web Crypto `SubtleCrypto` API is unavailable (for
 *   example, when the page is not served from a secure context).
 */
async function generateSegmentationColor(
  label: string | null,
): Promise<[r: number, g: number, b: number]> {
  // The Web Crypto API's `subtle` object is typically a danger zone. The
  // `digest()` method is fine here because it is not used for anything
  // cryptographic. It only deterministically hashes a label, and bytes from
  // the hash are used as the RGB pixel. Hashing gives a wide variety of colors,
  // so users can easily tell masks apart, while keeping each label's color
  // consistent between results.
  //
  // `globalThis` is used instead of `window` so this also works where `window`
  // is not defined (e.g. Web Workers). In a browser window they are the same
  // object.
  const subtle = globalThis.crypto?.subtle;
  if (subtle === undefined) {
    throw new Error(
      "Web Crypto SubtleCrypto is unavailable; cannot generate segmentation colors",
    );
  }

  const buff = await subtle.digest(
    // Not cryptographically secure but doesn't need to be
    "SHA-1",
    // String has to be encoded to UTF-8
    new TextEncoder().encode(label ?? ""),
  );

  // SHA-1 returns a 160-bit hash, so the view has exactly 20 elements.
  const view = new Uint8Array(buff);
  // The `?? 0` fallback never triggers for a 20-byte digest. It only keeps the
  // element type `number` under `noUncheckedIndexedAccess`.
  const byteAt = (index: number): number => view[index] ?? 0;

  // Deterministically grab 3 bytes from the start, middle, and end to use as
  // the R, G, and B channels, respectively.
  return [byteAt(0), byteAt(10), byteAt(19)];
}
