import { createSlice, PayloadAction } from "@reduxjs/toolkit";
import { RootState } from "./store";
import { extractMustacheVariables } from "utils";
import {
  Benchmark,
  Benchmarks,
  CriteriaType,
  Evaluator,
  OutputType,
  RunTemplateCaseResult,
  RunTemplateResult,
  Template,
  TemplateInputField,
  TemplateLogs,
  TemplateVariant,
} from "generated/graphql";

// Built-in evaluator criteria mapped to whether each one starts out active.
// A type listed here counts as a "default" evaluator in sanitizeEvaluators.
const DefaultEvalutorTypes: Record<string, boolean> = {
  [CriteriaType.Equivalence]: true,
  [CriteriaType.Safety]: true,
  [CriteriaType.Fairness]: false,
  [CriteriaType.Privacy]: false,
  [CriteriaType.Equal]: false,
  // [CriteriaType.Truthfulness]: false,
  // [CriteriaType.Completeness]: false,
};

/**
 * Default evaluator list, one entry per key of `DefaultEvalutorTypes` (in key
 * order). None are marked as production; `active` mirrors the table value.
 */
export const DefaultEvalutors = Object.entries(DefaultEvalutorTypes).map(
  ([type, active]) => ({ type, production: false, active: Boolean(active) } as Evaluator)
);

/**
 * Appends a newly fetched page of log entries to the ones already loaded.
 *
 * Entries in `newResponses` whose `id` already appears in `responses` are
 * skipped. Neither input is modified; a new array is returned.
 *
 * @param responses entries already held in state
 * @param newResponses the freshly loaded page
 * @returns `responses` followed by the previously unseen entries of `newResponses`
 */
export function appendLogs(
  responses: RunTemplateCaseResult[],
  newResponses: RunTemplateCaseResult[]
): RunTemplateCaseResult[] {
  const knownIds = new Set(responses.map(({ id }) => id));
  const merged = responses.slice();
  for (const candidate of newResponses) {
    if (!knownIds.has(candidate.id)) {
      merged.push(candidate);
    }
  }
  return merged;
}

/**
 * Appends a newly fetched page of benchmarks to the ones already loaded.
 *
 * Benchmarks in `newResponses` whose `id` is already present in `responses`
 * are dropped. The inputs are left untouched and a new array is returned.
 */
export function appendBenchmarks(responses: Benchmark[], newResponses: Benchmark[]): Benchmark[] {
  const knownIds = new Set(responses.map(({ id }) => id));
  const isUnseen = (candidate: Benchmark) => !knownIds.has(candidate.id);
  return responses.concat(newResponses.filter(isUnseen));
}

/** State of the `eval` slice: the template under evaluation plus everything derived from it. */
export interface EvalState {
  /** True until `templateChanged` loads a template; set again by `refreshTemplate`. */
  needsRefresh: boolean;
  /** Template currently loaded; `variantChanged` upserts variants into it. */
  template?: Template;
  /** Mustache variables referenced by the variants' system/message texts. */
  usedVariables: string[];
  /** Names of declared template inputs that no variant message references. */
  unusedVariables: string[];
  /** Input values keyed by input name (`valueChanged`); cleared by `templateChanged`. */
  values: { [key: string]: string };
  /** Template input field definitions keyed by input name. */
  fields: { [key: string]: TemplateInputField };
  /** Last variant passed to `variantChanged`. */
  variant?: TemplateVariant;
  /** Set via `setEvaluating`; reset to false by `setTemplateResult` and `templateChanged`. */
  evaluating: boolean;
  /** Run results, newest first (`setTemplateResult` prepends). */
  runResults: RunTemplateResult[];
  /** Evaluators after sanitizing: all default criteria first, then other valid ones. */
  evaluators: Evaluator[];
  /** Traffic percentage per variant id. */
  percentages: { [key: string]: number };
  /** Whether `percentages` passed the last `processPercentages` check. */
  percentagesValid: boolean;
  /** Set by `variantPercentageChanged`; reset by `templateChanged`. */
  percentagesChanged: boolean;

  /** Log entries accumulated across loaded pages, de-duplicated by id. */
  logs: RunTemplateCaseResult[];
  /** Cursor of the last page appended to `logs`. */
  logsCursor?: string;
  /** Set once a page with fewer than `logsLimit` entries arrives. */
  logsLoaded: boolean;
  /** Page size threshold; NOTE(review): presumably matches the fetch limit used by callers — confirm. */
  logsLimit: number;

  /** Benchmarks accumulated across loaded pages, de-duplicated by id. */
  benchmarks: Benchmark[];
  /** Cursor of the last page appended to `benchmarks`. */
  benchmarksCursor?: string;
  /** Set once a page with fewer than `benchmarksLimit` entries arrives. */
  benchmarksLoaded: boolean;
  /** Page size threshold; NOTE(review): presumably matches the fetch limit used by callers — confirm. */
  benchmarksLimit: number;
}

// Starting state: no template loaded yet (so `needsRefresh` is true), the
// default evaluator list, and empty paginated logs/benchmarks.
const initialState: EvalState = {
  needsRefresh: true,
  usedVariables: [],
  unusedVariables: [],
  values: {},
  fields: {},
  evaluating: false,
  runResults: [],
  evaluators: DefaultEvalutors,
  percentages: {},
  percentagesValid: true,
  percentagesChanged: false,

  logs: [],
  logsLoaded: false,
  // a page with fewer entries than this marks the logs as fully loaded
  logsLimit: 20,

  benchmarks: [],
  benchmarksLoaded: false,
  // a page with fewer entries than this marks the benchmarks as fully loaded
  benchmarksLimit: 10,
};

/**
 * Derives variable and input-field information from a template.
 *
 * Gathers each variant's GPT system message (or "" when absent) and every
 * message text, extracts the mustache variables they reference, and compares
 * those against the template's declared inputs.
 *
 * @returns `usedVariables` — variables found in the variant texts;
 *   `unusedVariables` — declared input names not referenced by any text;
 *   `fields` — declared inputs keyed by name (a later duplicate wins).
 */
function processTemplate(template: Template) {
  const texts: string[] = [];
  for (const variant of template?.variants ?? []) {
    const gpt = variant.fields?.gpt;
    texts.push(gpt?.systemMessage ?? "");
    for (const message of gpt?.messages ?? []) {
      texts.push(message.text ?? "");
    }
  }
  const usedVariables = extractMustacheVariables(texts.join(" "));
  const referenced = new Set<string>(usedVariables);

  const unusedVariables: string[] = [];
  const fields: { [key: string]: TemplateInputField } = {};
  for (const input of template?.fields?.inputs ?? []) {
    fields[input.name] = input;
    // inputs without a name can never be referenced, so they are not reported
    if (input.name && !referenced.has(input.name)) {
      unusedVariables.push(input.name);
    }
  }

  return { usedVariables, unusedVariables, fields };
}

/**
 * Normalizes per-variant traffic percentages against the current variant list.
 *
 * - No variants: the result is empty and valid.
 * - One variant: that variant receives 100% and the result is valid.
 * - Otherwise: entries whose key is not the id of a current variant are
 *   dropped, and the remaining split is valid only when it sums to 100
 *   (within a tiny floating-point tolerance).
 *
 * @param oldPercentages percentages keyed by variant id (may contain stale ids)
 * @param variants the variants currently on the template
 * @returns the filtered percentages and whether they form a valid split
 */
function processPercentages(
  oldPercentages: { [key: string]: number },
  variants: TemplateVariant[]
): {
  percentagesValid: boolean;
  percentages: { [key: string]: number };
} {
  if (variants.length === 0) {
    return { percentagesValid: true, percentages: {} };
  }

  if (variants.length === 1) {
    return { percentagesValid: true, percentages: { [variants[0].id]: 100 } };
  }

  // Drop entries for variants that no longer exist so stale ids neither count
  // towards the total nor linger in state (previously only the total ignored them).
  const variantIds = new Set(variants.map((v) => v.id));
  const percentages: { [key: string]: number } = {};
  for (const [key, value] of Object.entries(oldPercentages)) {
    if (variantIds.has(key)) {
      percentages[key] = value;
    }
  }

  const total = Object.values(percentages).reduce((acc, value) => acc + value, 0);
  // Fractional splits can miss 100 by a rounding error; accept those as valid.
  const epsilon = 1e-9;

  return {
    percentagesValid: Math.abs(total - 100) < epsilon,
    percentages,
  };
}

/**
 * Returns the member names of an enum-like object.
 *
 * Numeric TypeScript enums carry reverse mappings (`{ 0: "A", A: 0 }`), so keys
 * that parse as numbers are skipped, leaving only the declared names.
 */
function enumKeys<O extends object, K extends keyof O = keyof O>(obj: O): K[] {
  const names: string[] = [];
  for (const key of Object.keys(obj)) {
    if (Number.isNaN(Number(key))) {
      names.push(key);
    }
  }
  return names as K[];
}

// Lookup of every string value of CriteriaType (values, not member names).
// Built once at module load; sanitizeEvaluators uses it to reject unknown types.
const ValidEvaluators: { [key: string]: boolean } = {};
enumKeys(CriteriaType).forEach((member) => {
  const value = CriteriaType[member];
  if (typeof value === "string") {
    ValidEvaluators[value] = true;
  }
});

/**
 * Normalizes an evaluator list against the known criteria.
 *
 * Evaluators whose type is not a CriteriaType value are discarded. Every
 * default evaluator is always present, in default order, with any matching
 * entry from `evals` spread over it (the last matching entry wins). Valid
 * non-default evaluators follow, in their original order.
 */
function sanitizeEvaluators(evals: Evaluator[]): Evaluator[] {
  const overrides: { [key: string]: Evaluator } = {};
  const extras: Evaluator[] = [];

  for (const candidate of evals) {
    if (!ValidEvaluators[candidate.type]) {
      continue;
    }
    if (DefaultEvalutorTypes[candidate.type] === undefined) {
      extras.push(candidate);
    } else {
      overrides[candidate.type] = candidate;
    }
  }

  const defaults = DefaultEvalutors.map((base) => ({ ...base, ...(overrides[base.type] ?? {}) }));
  return [...defaults, ...extras];
}

/**
 * Redux slice for the template evaluation view.
 *
 * Holds the loaded template and data derived from it (variables, input fields,
 * evaluators, traffic percentages), user input values, run results, and
 * paginated logs/benchmarks. Reducers rely on Redux Toolkit's Immer integration:
 * each either mutates the draft `state` or returns a complete new state.
 */
export const evalSlice = createSlice({
  name: "eval",
  initialState,
  reducers: {
    /** Flags the template as stale; `templateChanged` clears the flag again. */
    refreshTemplate: (state) => {
      state.needsRefresh = true;
    },
    /**
     * Loads a template: recomputes all derived state and resets input values,
     * run results, the percentage-change flag and logs/benchmarks pagination.
     */
    templateChanged: (state, action: PayloadAction<Template>) => {
      const template = action.payload;
      // Stored traffic split, keyed by variant id.
      const percentages = (template?.fields?.percentages ?? []).reduce((acc, p) => {
        acc[p.variantID] = p.percentage;
        return acc;
      }, {} as { [key: string]: number });

      return {
        ...state,
        ...processTemplate(template),
        template,
        values: {},
        evaluating: false,
        runResults: [],
        // `fields` is optional-chained as in the percentages lookup above, so a
        // template without fields falls back to the defaults instead of throwing.
        evaluators: sanitizeEvaluators(template?.fields?.evaluators ?? DefaultEvalutors),
        ...processPercentages(percentages, template?.variants ?? []),
        needsRefresh: false,
        percentagesChanged: false,

        logs: [],
        logsCursor: undefined,
        logsLoaded: false,

        benchmarks: [],
        benchmarksCursor: undefined,
        benchmarksLoaded: false,
      };
    },
    /**
     * Replaces the template variant with the payload's id, or appends the
     * payload when no such variant exists. Then recomputes variable/field info
     * and resets logs/benchmarks pagination. Ignored while no template is loaded.
     */
    variantChanged: (state, action: PayloadAction<TemplateVariant>) => {
      if (!state.template) {
        return;
      }

      const incoming = action.payload;
      let exist = false;
      const newVariants = (state.template.variants ?? []).map((v) => {
        if (v.id === incoming.id) {
          exist = true;
          return incoming;
        }
        return v;
      });

      if (!exist) {
        newVariants.push(incoming);
      }

      const newTemplate = { ...state.template, variants: newVariants };
      return {
        ...state,
        variant: incoming,
        template: newTemplate,
        ...processTemplate(newTemplate),

        logs: [],
        logsCursor: undefined,
        logsLoaded: false,

        benchmarks: [],
        benchmarksCursor: undefined,
        benchmarksLoaded: false,
      };
    },
    /** Stores the value typed for one template input. */
    valueChanged: (state, action: PayloadAction<{ name: string; value: string }>) => {
      state.values[action.payload.name] = action.payload.value;
    },
    /** Replaces the evaluator configuration after sanitizing it. */
    evaluatorsChanged: (state, action: PayloadAction<Evaluator[]>) => {
      state.evaluators = sanitizeEvaluators(action.payload);
    },
    /** Replaces one input field definition, keyed by its name. */
    inputFieldChanged: (state, action: PayloadAction<TemplateInputField>) => {
      state.fields[action.payload.name] = action.payload;
    },
    /** Sets the template's output type, when the template has an output section. */
    outputTypeChanged: (state, action: PayloadAction<OutputType>) => {
      if (state.template?.fields?.output) {
        state.template.fields.output.type = action.payload;
      }
    },
    /** Sets the template's output schema, but only when the payload parses as JSON. */
    outputSchemaChanged: (state, action: PayloadAction<string>) => {
      if (state.template?.fields?.output) {
        try {
          JSON.parse(action.payload);
          state.template.fields.output.schema = action.payload;
        } catch {
          // Invalid JSON: deliberately keep the previous schema.
        }
      }
    },
    /** Sets the `evaluating` flag. */
    setEvaluating: (state, action: PayloadAction<boolean>) => {
      state.evaluating = action.payload;
    },
    /** Records a run result (newest first) and clears the `evaluating` flag. */
    setTemplateResult: (state, action: PayloadAction<RunTemplateResult>) => {
      state.evaluating = false;
      state.runResults = [action.payload, ...state.runResults];
    },
    /** Updates one variant's traffic share and re-validates the whole split. */
    variantPercentageChanged: (state, action: PayloadAction<{ id: string; traffic: number }>) => {
      const { percentagesValid, percentages } = processPercentages(
        { ...state.percentages, [action.payload.id]: action.payload.traffic },
        state.template?.variants ?? []
      );
      state.percentagesValid = percentagesValid;
      state.percentages = percentages;
      state.percentagesChanged = true;
    },
    /**
     * Appends a page of logs. A page is appended only when it is non-empty and
     * carries a cursor different from the last one seen, so re-delivering the
     * same page is a no-op. A page shorter than `logsLimit` marks logs as fully
     * loaded; that includes an empty or missing payload.
     */
    addLogs: (state, action: PayloadAction<TemplateLogs | null | undefined>) => {
      const newCursor = action.payload?.cursor;
      const page = [...(action.payload?.logs ?? [])];
      if (newCursor && newCursor !== state.logsCursor && page.length > 0) {
        state.logs = appendLogs(state.logs, page);
        state.logsCursor = newCursor;
      }

      if (page.length < state.logsLimit) {
        state.logsLoaded = true;
      }
    },
    /**
     * Appends a page of benchmarks. This follows the same cursor and
     * short-page rules as `addLogs`, using `benchmarksCursor` and
     * `benchmarksLimit`.
     */
    addBenchmarks: (state, action: PayloadAction<Benchmarks | null | undefined>) => {
      const newCursor = action.payload?.cursor;
      const page = [...(action.payload?.benchmarks ?? [])];
      if (newCursor && newCursor !== state.benchmarksCursor && page.length > 0) {
        state.benchmarks = appendBenchmarks(state.benchmarks, page);
        state.benchmarksCursor = newCursor;
      }

      if (page.length < state.benchmarksLimit) {
        state.benchmarksLoaded = true;
      }
    },
  },
});

// Action creators generated by the slice, grouped by the state they affect.
export const {
  // template lifecycle
  refreshTemplate,
  templateChanged,
  variantChanged,
  // inputs and output configuration
  valueChanged,
  inputFieldChanged,
  outputTypeChanged,
  outputSchemaChanged,
  // evaluation runs and evaluator configuration
  setEvaluating,
  setTemplateResult,
  evaluatorsChanged,
  // traffic split between variants
  variantPercentageChanged,
  // pagination
  addLogs,
  addBenchmarks,
} = evalSlice.actions;

/** Returns the `eval` slice of the root store state. */
export function selectEval(state: RootState) {
  return state.eval;
}

export default evalSlice.reducer;
