import type { Model, Prediction } from "replicate";
import type { APIPrediction } from "./types";
import { API_BASE_URL } from "./constants";

/**
 * Error raised when a request to the prediction API fails with a
 * non-2xx status other than 429 (which raises `RateLimitError` instead).
 *
 * Carries the model identifier and prompt of the failed request so callers can
 * report or retry it. Either may be an empty string when the caller does not
 * know it (e.g. when polling an existing prediction by id).
 */
export class PredictionError extends Error {
  /** `owner/name` of the model the request targeted, or `""` if unknown. */
  model: string;
  /** Prompt submitted with the request, or `""` if unknown. */
  prompt: string;

  constructor(message: string, model: string, prompt: string) {
    super(message);
    // Give the error a distinguishable name in logs and `String(err)` output;
    // otherwise it inherits the generic "Error".
    this.name = "PredictionError";
    // Restore the prototype chain so `instanceof PredictionError` keeps
    // working if this file is ever compiled to an ES5 target.
    Object.setPrototypeOf(this, new.target.prototype);
    this.model = model;
    this.prompt = prompt;
  }
}

/**
 * Error raised when the prediction API responds with HTTP 429
 * (too many requests).
 */
export class RateLimitError extends Error {
  constructor() {
    super("Rate limit exceeded");
    // Distinguishable name for logs and `String(err)`; otherwise "Error".
    this.name = "RateLimitError";
    // Keep `instanceof RateLimitError` working under an ES5 compile target.
    Object.setPrototypeOf(this, new.target.prototype);
  }
}

/**
 * Poll the API for the current state of an existing prediction.
 *
 * @param id - Identifier of the prediction to fetch. It is URL-encoded before
 *   being placed in the query string.
 * @param signal - Abort signal forwarded to `fetch` so the caller can cancel the poll.
 * @returns The parsed JSON body of the poll endpoint.
 * @throws {Error} When `id` is missing or empty.
 * @throws {RateLimitError} When the API responds with HTTP 429.
 * @throws {PredictionError} When the API responds with any other non-2xx status.
 */
export async function fetchPrediction({
  id,
  signal,
}: {
  id?: string;
  signal: AbortSignal;
}): Promise<APIPrediction> {
  if (!id) {
    throw new Error("No prediction uuid provided");
  }

  // Encode the id so characters such as `&`, `#`, `?` or spaces cannot
  // break out of the `id` query parameter.
  const res = await fetch(
    `${API_BASE_URL}/api/poll?id=${encodeURIComponent(id)}`,
    { signal },
  );

  if (res.status === 429) {
    throw new RateLimitError();
  }

  if (!res.ok) {
    // Polling has no model/prompt context available here.
    throw new PredictionError("Failed to fetch prediction", "", "");
  }

  return await res.json();
}

/** Returns a promise that settles (with no value) after roughly `ms` milliseconds. */
function sleep(ms: number): Promise<void> {
  return new Promise<void>((done) => {
    setTimeout(done, ms);
  });
}

/**
 * Create a prediction for `model`, sending `prompt` as the input field `name`.
 *
 * If `prompt` equals the default example's value for that same input field,
 * no request is made. Instead, the default example itself is returned after a
 * random delay of 1000–2999 ms.
 *
 * @param model - Model to run. Its `owner/name` is sent as `model`, and
 *   `latest_version?.id` is sent as `version`.
 * @param prompt - Value sent as the model input under the key `name`.
 * @param name - Name of the input field the prompt is bound to.
 * @param defaultExample - Example prediction for this model. It is also
 *   forwarded in the request body.
 * @returns The prediction returned by the API, or `defaultExample` when the
 *   prompt matches it.
 * @throws {RateLimitError} When the API responds with HTTP 429.
 * @throws {PredictionError} When the API responds with any other non-2xx status.
 */
export async function createPrediction({
  model,
  prompt,
  name,
  defaultExample,
}: {
  model: Model;
  prompt: string;
  name: string;
  defaultExample: Prediction;
}): Promise<APIPrediction> {
  const modelId = `${model.owner}/${model.name}`;

  // Read the example input as a string-keyed record without resorting to `any`;
  // optional chaining tolerates an example that has no `input` at runtime.
  const exampleInput = defaultExample.input as
    | Record<string, unknown>
    | undefined;
  const isCreatingDefaultExample = prompt === exampleInput?.[name];

  if (isCreatingDefaultExample) {
    // When re-creating the default example, avoid a round trip to the server
    // and return the default example directly. The random 1000–2999 ms delay
    // presumably mimics real request latency for the UI — confirm intent.
    const simulatedLatencyMs = Math.floor(Math.random() * 2000) + 1000;
    await sleep(simulatedLatencyMs);
    return defaultExample as APIPrediction;
  }

  const res = await fetch(`${API_BASE_URL}/api/prediction`, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      model: modelId,
      version: model.latest_version?.id,
      input: {
        [name]: prompt,
      },
      defaultExample,
    }),
  });

  if (res.status === 429) {
    throw new RateLimitError();
  }

  if (!res.ok) {
    throw new PredictionError("Failed to create prediction", modelId, prompt);
  }

  return await res.json();
}
