import {PyodideClient} from 'components/pyodide/PyodideProvider';
import {
  GradingScripts,
  type AdvancedConsoleUnitTest,
} from '../../models/unitTests/advancedConsole';
import {ProjectFilesCode, ProjectFileStructure} from 'projects/types';
import {getAllFileNames} from 'ide/utils/general';
import Swal from 'sweetalert2';
import {geometricMLE} from './distributions/geometric';
import {binomialMLE} from './distributions/binomial';

/**
 * Constants indicating the outcome of an advanced autograder run, and for
 * what reason it returned a failure.
 *
 * PythonError - Pyodide failed to interpret the student code and raised an exception
 * AdvancedAutograderError - Anything raised by the Python grading script through `autograde_error(message)`
 * RandomVarError - Any issue with the distribution of a random variable checked by the JS `verifyRandomVariables`
 * Success - the only non-failure member; presumably reported when none of the failures above occurred
 *
 * String-valued so the values are readable when logged or serialized.
 */
export enum AdvancedAutograderResultTypes {
  PythonError = 'PythonError',
  AdvancedAutograderError = 'AdvancedAutograderError',
  RandomVarError = 'RandomVarError',
  Success = 'Success',
}

/**
 * Data loaded from the server for running advanced console unit tests.
 */
export interface LoadedAdvancedConsoleUnitTestData {
  /** The unit tests to run (each may carry its own grading scripts — confirm against `AdvancedConsoleUnitTest`). */
  unitTests: AdvancedConsoleUnitTest[];
}

/**
 * Random variables registered by the grading script, keyed by variable name.
 * Each value is itself a Map of variable info; the keys read in this module
 * are `dist`, `params` (a Map), `outputs` (observed values) and the optional
 * `error_str` override.
 */
export type RandomVars = Map<string, any>;

/** Error state reported by the grading script; contains at least the `has_error` flag. */
export type ErrorState = Map<string, any>;

/** Autograder state produced by one run of the student's program. */
export interface AdvancedAutograderState {
  registeredVars: RandomVars;
  errorState: ErrorState;
  stopRequested: boolean;
}

/**
 * Empty autograder state: no registered variables, no error, no stop request.
 *
 * Annotated with `AdvancedAutograderState` so the default is checked against
 * the interface it stands in for.
 *
 * NOTE: this is a single shared instance whose Maps are mutable — copy the
 * Maps before mutating them.
 */
export const DEFAULT_ADVANCED_AUTOGRADER_STATE: AdvancedAutograderState = {
  registeredVars: new Map(),
  errorState: new Map([['has_error', false]]),
  stopRequested: false,
};

// Toast displayed to the user while the advanced console unit tests are running.
// This is a sweetalert2 mixin preconfigured as a top-end info toast with no
// confirm button; it is presumably shown/closed by callers elsewhere
// (NOTE(review): confirm call sites).
export const ADVANCED_CONSOLE_UNIT_TEST_RUNNING_TOAST = Swal.mixin({
  title: 'Running unit tests...',
  toast: true,
  position: 'top-end',
  showConfirmButton: false,
  icon: 'info',
});

/**
 * Builds the message shown to the student when a random variable fails
 * verification.
 *
 * A TA-supplied `error_str` on the variable always takes precedence.
 * Otherwise a distribution-specific default is used: for `uniform` variables
 * the expected bounds are compared against the observed bounds; every other
 * distribution falls back to a generic message.
 *
 * @param randomVar - registered variable info (`dist`, `params` Map,
 *   `outputs`, optional `error_str`)
 * @returns the message to display
 */
const getErrorMessage = (randomVar: Map<string, any>): string => {
  const customMessage = randomVar.get('error_str');
  if (customMessage) return customMessage;

  // Default error strings for different random variables
  switch (randomVar.get('dist')) {
    case 'uniform': {
      // Braces scope these declarations to this case (no-case-declarations).
      const params = randomVar.get('params');
      const min = params.get('min');
      const max = params.get('max');

      // Folding with Math.min/Math.max matches Math.min(...outputs) exactly
      // (including NaN propagation and Infinity/-Infinity for no outputs) but
      // does not spread the whole array as call arguments, which can overflow
      // the call stack for very large observation arrays.
      let observedMin = Infinity;
      let observedMax = -Infinity;
      for (const value of randomVar.get('outputs')) {
        observedMin = Math.min(observedMin, value);
        observedMax = Math.max(observedMax, value);
      }

      return `The bounds of one of your random function calls are incorrect. We expected (${min}, ${max}) but found (${observedMin}, ${observedMax}).`;
    }
    default:
      return 'One or more of your random function calls are incorrect. Double check how you create and use your random variables!';
  }
};

/**
 * Compares the distribution of each random variable observed over all trials
 * to the distribution expected for that variable.
 *
 * - `uniform`: passes when the observed min and max exactly equal
 *   `params.min` / `params.max`.
 * - `geometric` / `binomial`: passes when the MLE of `p` computed from the
 *   observed outputs is within `params.epsilon` of the expected `params.p`.
 * - Any other `dist` is logged and counted as a failure.
 *
 * Every variable is checked (a failure does not stop the loop); when several
 * variables fail, `errorMessage` describes the last failing one in iteration
 * order.
 *
 * @param randomVars - expected distribution info and individual observations
 * for each RV, keyed by variable name
 * @returns `allCorrect` - true if every variable matched its expected
 * distribution. `errorMessage` - error message if the test failed, null
 * otherwise
 */
export const verifyRandomVariables = (
  randomVars: RandomVars,
): {
  allCorrect: boolean;
  errorMessage: string | null;
} => {
  let allCorrect = true;
  let errorMessage: string | null = null;

  // Records a failure for `rv`; only called when a check actually fails, so a
  // passing variable never overwrites (or invents) an error message.
  const recordFailure = (rv: Map<string, any>): void => {
    allCorrect = false;
    errorMessage = getErrorMessage(rv);
  };

  randomVars.forEach((rv, varName) => {
    const dist = rv.get('dist');
    const params = rv.get('params');

    if (dist === 'uniform') {
      // Method: Determine whether we see exactly the expected min and max.
      // Fold instead of spreading `outputs` so large arrays can't overflow the stack.
      let observedMin = Infinity;
      let observedMax = -Infinity;
      for (const value of rv.get('outputs')) {
        observedMin = Math.min(observedMin, value);
        observedMax = Math.max(observedMax, value);
      }

      if (observedMin !== params.get('min') || observedMax !== params.get('max')) {
        recordFailure(rv);
      }
    } else if (dist === 'geometric' || dist === 'binomial') {
      // Method: MLE -- estimate p for the achieved distribution and check it
      // is within epsilon of the expected p.
      const mle =
        dist === 'geometric'
          ? geometricMLE(rv.get('outputs'))
          : binomialMLE(rv.get('outputs'));

      if (Math.abs(mle - params.get('p')) > params.get('epsilon')) {
        recordFailure(rv);
      }
    } else {
      // NOTE: despite "Aborting" in the log text, only this variable is
      // skipped; remaining variables are still checked.
      console.error(
        `Unrecognized distribution ${rv.get(
          'dist',
        )} for ${varName}! Aborting.`,
      );
      // Fail with a message (custom `error_str` or the generic default) so a
      // failed result never carries a null errorMessage.
      recordFailure(rv);
    }
  });

  return {allCorrect, errorMessage};
};

/**
 * Calculates how many trials of the student's program to run so that every
 * registered random variable can be verified with enough observations.
 *
 * - `uniform`: enough observations that the probability of never seeing one
 *   particular value of [min, max] is below `params.tolerance`
 *   (i.e. q^n < tolerance with q = 1 - 1/#values), divided by the number of
 *   observations a single trial produces (`outputs.length`), rounded up.
 * - `geometric` / `binomial`: normal-approximation sample size so that the
 *   MLE of `p` lands within `params.epsilon` of `p` with ~95% confidence.
 *
 * Unrecognized distributions are logged and ignored.
 *
 * @param vars - distribution information of the random variables
 * @returns number of trials to run: a positive integer, at least 1 (which is
 * also the result when there are no random variables)
 */
export const calculateMaxTrials = (vars: RandomVars): number => {
  // Two-sided 95% critical value of the standard normal distribution.
  const Z_SCORE_95 = 1.96;

  // If there are no random variables, only need 1 trial
  let maxTrials = 1;

  vars.forEach(v => {
    const dist = v.get('dist');
    const params = v.get('params');

    if (dist === 'uniform') {
      const numValues = params.get('max') - params.get('min') + 1;
      // Probability that a single observation is NOT one particular value.
      const q = 1 - 1 / numValues;

      const numObservations = Math.ceil(
        Math.log(params.get('tolerance')) / Math.log(q),
      );

      // Divide observations by # times observed per trial. Guard against an
      // empty `outputs` array (which would otherwise yield Infinity trials)
      // and round up so the result is always a whole number of trials.
      const observationsPerTrial = Math.max(1, v.get('outputs')?.length ?? 0);
      const numTrials = Math.ceil(numObservations / observationsPerTrial);

      maxTrials = Math.max(maxTrials, numTrials);
    } else if (dist === 'geometric' || dist === 'binomial') {
      // Sample size n for 95% confidence that p_mle is within epsilon of p:
      //   n = z^2 * Var / epsilon^2
      // with Var = p(1 - p). For the binomial MLE this is the Bernoulli
      // variance. NOTE(review): for the geometric MLE the delta-method
      // asymptotic variance is p^2(1 - p); p(1 - p) is never smaller for p in
      // (0, 1], so this errs on the side of more trials.
      const epsilon = params.get('epsilon');
      const p = params.get('p');
      const pVariance = p * (1 - p);

      const sampleSize =
        (Z_SCORE_95 * Z_SCORE_95 * pVariance) / (epsilon * epsilon);

      maxTrials = Math.max(maxTrials, Math.ceil(sampleSize));
    } else {
      console.error(
        `ERROR: Unrecognized distribution ${dist} encountered in variable map`,
      );
    }
  });

  return maxTrials;
};

/**
 * Run a single trial of the student's `main.py` with the given autograding
 * scripts.
 *
 * NOTE: Updates `randomVars` in place with the random variables observed in
 * this trial: a variable seen for the first time is stored as-is; for a
 * variable already present, this trial's outputs are appended to its
 * `outputs` array.
 * NOTE(review): a first-seen variable is stored by reference, so later trials
 * append into the `outputs` array of the state returned by that earlier
 * trial — confirm no caller relies on that earlier state staying unchanged.
 *
 * @param pyodideClient - Pyodide runner
 * @param filesCode - contents of the student's project files, indexed by file id
 * @param fileStructure - the project's file tree; used to locate `main.py`
 * @param gradingScripts - TA supplied `grade_input` and `grade_output`
 * functions
 * @param randomVars - JS object aggregating observed random variables across
 * all trials (mutated)
 * @returns Object of `advancedAutograderState` - random variables and error
 * state from autograder, and `error` - errors from the actual Python
 * interpreter itself. If `main.py` is missing or has no loaded contents, the
 * Python code is not run: a fresh copy of the default (empty) state is
 * returned together with a single error message.
 */
export const runTrial = async (
  pyodideClient: PyodideClient,
  filesCode: ProjectFilesCode,
  fileStructure: ProjectFileStructure,
  gradingScripts: GradingScripts,
  randomVars: RandomVars,
): Promise<{
  advancedAutograderState: AdvancedAutograderState;
  error: string[];
}> => {
  // Copy the default state rather than returning the shared module-level
  // object, so callers that mutate the returned Maps cannot corrupt it.
  const freshDefaultState = (): AdvancedAutograderState => ({
    registeredVars: new Map(DEFAULT_ADVANCED_AUTOGRADER_STATE.registeredVars),
    errorState: new Map(DEFAULT_ADVANCED_AUTOGRADER_STATE.errorState),
    stopRequested: DEFAULT_ADVANCED_AUTOGRADER_STATE.stopRequested,
  });

  const allFileNames = getAllFileNames(fileStructure);
  const mainFile = allFileNames.find(file => file.name === 'main.py');
  if (!mainFile) {
    return {
      advancedAutograderState: freshDefaultState(),
      error: ['No main.py file found'],
    };
  }

  // The file tree and the loaded contents can disagree; fail with a message
  // instead of a TypeError on `.content`.
  const mainFileCode = filesCode[mainFile.id];
  if (!mainFileCode) {
    return {
      advancedAutograderState: freshDefaultState(),
      error: ['Could not load the contents of main.py'],
    };
  }

  const {advancedAutograderState, error} = await pyodideClient.testCode(
    mainFileCode.content,
    mainFile,
    [],
    gradingScripts,
  );

  // Merge the observed vars from this trial with `randomVars` (persists over all trials)
  // If key is not in there, transfer all data + outputs, otherwise append observation to outputs
  advancedAutograderState.registeredVars.forEach((observed, varName) => {
    if (randomVars.has(varName)) {
      randomVars
        .get(varName)
        .get('outputs')
        .push(...observed.get('outputs'));
    } else {
      randomVars.set(varName, observed);
    }
  });

  return {advancedAutograderState, error};
};
