import { createSlice, PayloadAction } from "@reduxjs/toolkit";
import { AiModel, Models, ModelDatasets, AiModelDataset, AiModelType } from "generated/graphql";
import { RootState } from "./store";

/** Shape of the `models` slice: the paginated model list plus per-model dataset pages and label lookup. */
export interface ModelState {
  /** Set by `addModels`/`addModelDatasets`, reset by `clearModels`; `switchModel` is ignored while false. */
  initialized: boolean;
  /** Models accumulated from the pages received so far. */
  models: AiModel[];
  /** Dataset pagination state, keyed by model id. */
  modelDatasets: Record<string, ModelDatasetState>;
  /** Cursor of the most recently appended page of models. */
  cursor?: string | null;
  /** Set once a page of models shorter than `limit` has been received. */
  loaded: boolean;
  /** Page size that received pages are compared against to detect the last page. */
  limit: number;
  /** Display label per model id (built-in defaults plus the names of ready models). */
  modelLabels: Record<string, string>;
  /** Id of the currently selected model; persisted to localStorage by `switchModel`. */
  currentModel: string;
}

/** Pagination state for the datasets of a single model. */
export interface ModelDatasetState {
  /** Datasets accumulated from the pages received so far. */
  datasets: AiModelDataset[];
  /** Cursor of the most recently appended page of datasets. */
  cursor?: string | null;
  /** Set once a page of datasets shorter than the slice's `limit` has been received. */
  loaded: boolean;
}

/**
 * Appends `newResponses` to `responses`, skipping every item whose `id` has already been seen.
 *
 * An id counts as seen when it is present in `responses` or in an earlier element of
 * `newResponses`, so a page that repeats an id contributes only its first occurrence.
 * Neither input array is mutated.
 *
 * @param responses - Items already held; kept as-is and in their original order.
 * @param newResponses - Candidate items to append, in order.
 * @returns A new array containing `responses` followed by the unseen items of `newResponses`.
 */
export function appendResult<T extends { id: string }>(responses: T[], newResponses: T[]): T[] {
  const seenIds = new Set(responses.map((r) => r.id));
  const merged = [...responses];
  for (const candidate of newResponses) {
    if (!seenIds.has(candidate.id)) {
      // Record the id immediately so a later duplicate within the same page is skipped too.
      seenIds.add(candidate.id);
      merged.push(candidate);
    }
  }
  return merged;
}

// Display labels for the built-in models, keyed by `AiModelType`.
// Used as the initial `modelLabels` and restored by `clearModels`.
// NOTE(review): `AiModelType.Gpt4_32k` has a token limit in `getModelMaxToken` but no label here — confirm this is intentional.
const defaultModels = {
  [AiModelType.Gpt4]: "GPT-4",
  [AiModelType.Gpt4Turbo]: "GPT-4 Turbo",
  [AiModelType.Chatgpt_3_5Turbo]: "GPT-3.5 Turbo",
  [AiModelType.Chatgpt_3_5Turbo_16k]: "GPT-3.5 Turbo (16K)",
};

/**
 * Looks up the maximum token count associated with a model type.
 *
 * @param model - The model type; may be omitted.
 * @returns The limit for the given type, or 4096 when the type is omitted or not listed here.
 */
export function getModelMaxToken(model?: AiModelType) {
  if (model === AiModelType.Gpt4) {
    return 8191;
  }
  if (model === AiModelType.Gpt4_32k) {
    return 32767;
  }
  if (model === AiModelType.Gpt4Turbo) {
    return 128000;
  }
  if (model === AiModelType.Chatgpt_3_5Turbo) {
    return 4096;
  }
  if (model === AiModelType.Chatgpt_3_5Turbo_16k) {
    return 16384;
  }
  // Anything else (including `undefined`) gets the same limit as GPT-3.5 Turbo.
  return 4096;
}

/** localStorage key under which the id of the currently selected model is stored. */
export const CURRENT_AI_MODEL_LOCAL_STORAGE_KEY = "gists-default-ai-model-key";

// Starting state of the slice: nothing loaded yet, built-in labels only.
const initialState: ModelState = {
  initialized: false,
  models: [],
  modelDatasets: {},
  cursor: null,
  loaded: false,
  limit: 24,
  modelLabels: defaultModels,
  // Read once when this module is evaluated: restore the persisted selection, falling back to GPT-3.5 Turbo.
  currentModel: localStorage.getItem(CURRENT_AI_MODEL_LOCAL_STORAGE_KEY) ?? AiModelType.Chatgpt_3_5Turbo,
};

/**
 * Redux slice holding the paginated list of AI models, the paginated datasets of each model,
 * the id -> label lookup, and the currently selected model.
 */
export const modelsSlice = createSlice({
  name: "models",
  initialState,
  reducers: {
    /**
     * Merges one fetched page of models into the store and marks the slice as initialized.
     *
     * The page is appended only when the payload carries a cursor that differs from the stored
     * one and the page is non-empty.
     * NOTE(review): a non-empty page delivered without a cursor is therefore not appended —
     * confirm the API always returns a cursor alongside models.
     */
    addModels: (state, action: PayloadAction<Models | null | undefined>) => {
      state.initialized = true;
      const newCursor = action.payload?.cursor;
      const newToOld = [...(action.payload?.models ?? [])];
      if (newCursor && newCursor !== state.cursor && newToOld.length > 0) {
        // A page that has not been stored yet: append it, skipping ids already present.
        state.models = appendResult(state.models, newToOld);
        // Remember the cursor so the same page is not appended twice.
        state.cursor = newCursor;
      }

      // A page shorter than the page size means there is nothing more to fetch.
      if (newToOld.length < state.limit) {
        state.loaded = true;
      }

      // Only models flagged as ready get a label.
      action.payload?.models?.forEach((m) => {
        if (m.isReady) {
          state.modelLabels[m.id] = m.name;
        }
      });

      return state;
    },
    /**
     * Selects the model with the given id and writes the selection to localStorage.
     *
     * Does nothing until the slice is initialized. An id without a known label falls back to
     * GPT-3.5 Turbo.
     */
    switchModel: (state, action: PayloadAction<string>) => {
      // ignore until all models are loaded
      if (!state.initialized) {
        return;
      }

      if (!state.modelLabels[action.payload]) {
        state.currentModel = AiModelType.Chatgpt_3_5Turbo;
      } else {
        state.currentModel = action.payload;
      }

      // Give the selected model a placeholder label if it still has none.
      if (state.currentModel && !state.modelLabels[state.currentModel]) {
        state.modelLabels[state.currentModel] = "Loading custom model...";
      }

      localStorage.setItem(CURRENT_AI_MODEL_LOCAL_STORAGE_KEY, state.currentModel);
    },
    /**
     * Prepends a single model to the list without touching the pagination cursor, and
     * registers its label when it has an id, a name and is ready.
     */
    createModel: (state, action: PayloadAction<AiModel>) => {
      if (action.payload) {
        state.models = [action.payload, ...state.models];
      }
      if (action.payload?.id && action.payload?.name && action.payload?.isReady) {
        state.modelLabels[action.payload.id] = action.payload.name;
      }
      return state;
    },
    /**
     * Replaces the stored model that has the same id as the payload, and registers its label
     * when it has an id, a name and is ready.
     */
    updateModel: (state, action: PayloadAction<AiModel>) => {
      const newResponses = state.models.map((m) => (m.id === action.payload.id ? action.payload : m));
      state.models = newResponses;
      if (action.payload?.id && action.payload?.name && action.payload?.isReady) {
        state.modelLabels[action.payload.id] = action.payload.name;
      }
      return state;
    },
    /** Removes the models with the given ids from the list and drops their labels. */
    removeModels: (state, action: PayloadAction<string[]>) => {
      // Use one guarded list for both steps so a nullish payload is a no-op rather than a crash.
      const ids = action.payload ?? [];
      const toBeDeleted = new Set<string>(ids);
      state.models = state.models.filter((p) => !toBeDeleted.has(p.id));
      ids.forEach((id) => {
        delete state.modelLabels[id];
      });
      return state;
    },
    /**
     * Resets models, datasets, pagination flags and labels to their initial values.
     * `currentModel` and `limit` are left untouched.
     */
    clearModels: (state) => {
      state.initialized = false;
      state.models = [];
      state.modelDatasets = {};
      state.loaded = false;
      state.cursor = null;
      state.modelLabels = defaultModels;
      return state;
    },
    /**
     * Merges one fetched page of datasets into the entry of the model named by the payload.
     *
     * The entry is created on demand, so the first page of a model (or an empty first page)
     * is stored instead of throwing on a missing entry. Payloads without a model id are ignored.
     */
    addModelDatasets: (state, action: PayloadAction<ModelDatasets>) => {
      // NOTE(review): this flips `initialized` even when the payload is then ignored — confirm
      // that dataset loads are meant to unlock `switchModel`.
      state.initialized = true;
      const newCursor = action.payload?.cursor;
      const newToOld = [...(action.payload?.datasets ?? [])];

      const modelID = action?.payload?.modelID;
      if (!modelID) {
        return;
      }

      let entry = state.modelDatasets[modelID];
      const isNewPage = Boolean(newCursor) && newCursor !== entry?.cursor && newToOld.length > 0;
      const isLastPage = newToOld.length < state.limit;
      if (!isNewPage && !isLastPage) {
        // Nothing to record; do not create an empty entry for the model.
        return state;
      }

      if (!entry) {
        // First data for this model: start from an empty entry that is not yet fully loaded.
        entry = { datasets: [], cursor: null, loaded: false };
        state.modelDatasets[modelID] = entry;
      }

      if (isNewPage) {
        // A page that has not been stored yet: append it, skipping ids already present.
        entry.datasets = appendResult(entry.datasets, newToOld);
        // Remember the cursor so the same page is not appended twice.
        entry.cursor = newCursor;
      }

      // A page shorter than the page size means there is nothing more to fetch.
      if (isLastPage) {
        entry.loaded = true;
      }

      return state;
    },
  },
});

// Action creators generated by the slice, one per reducer.
export const { addModels, createModel, updateModel, removeModels, switchModel, addModelDatasets, clearModels } =
  modelsSlice.actions;

/** Selector returning the `models` slice of the root state. */
export const selectModels = (state: RootState) => {
  return state.models;
};

// The slice's reducer function is this module's default export.
export default modelsSlice.reducer;
