import { ReactNode, useEffect, useState } from 'react';
import { Merge, UnionToIntersection } from 'system';
import { SchemaOf, object } from 'yup';

/**
 * Arguments handed to a wizard action's `onClick` handler (used both for the
 * main action and for each entry of `splitOptions`).
 */
type WizardActionContext<T> = {
  /** Current wizard values, typed as the merged intersection of `T`. */
  values: Merge<UnionToIntersection<T>>;
  /** Requests a move to the following step. */
  handleNext: () => void;
  /** Requests a move to the preceding step. */
  handleBack: () => void;
  /** Jumps directly to the given zero-based step index. */
  setStepIndex: (index: number) => void;
  /** `true` when the current step is the final one. */
  last: boolean;
};

/** Click handler shared by actions and split options; may be sync or async. */
type WizardActionHandler<T> = (
  context: WizardActionContext<T>
) => Record<string, unknown> | Promise<Record<string, unknown>> | void | Promise<void>;

/**
 * Configuration for one action of a wizard step (presumably rendered as a
 * button — `variant`/`align`/`fullWidth` suggest so; confirm in the renderer).
 */
export type WizardAction<T> = {
  label?: string;
  onClick?: WizardActionHandler<T>;
  disabled?: boolean;
  /** When set, the consumer is expected to skip schema validation for this action. */
  skipValidation?: boolean;
  align?: 'left' | 'right';
  variant?: 'text' | 'outlined' | 'contained';
  fullWidth?: boolean;
  loading?: boolean;
  /** Secondary choices attached to this action, each with its own handler. */
  splitOptions?: Array<{
    label?: string;
    onClick?: WizardActionHandler<T>;
  }>;
};

/** Optional header metadata for a wizard step. */
type WizardStepHeader = {
  /** Either an icon identifier string or an arbitrary React node. */
  icon?: string | ReactNode;
  label?: string;
};

/**
 * One step of a wizard driven by `useWizard`.
 *
 * `schema` (when present) is concatenated by `useWizard` with the schemas of
 * all earlier steps to validate the values collected so far.
 */
export type WizardStep<T> = {
  header?: WizardStepHeader;
  schema?: SchemaOf<T>;
  /** Component rendering this step's inputs; presumably receives `props` — confirm at the call site. */
  FieldsComponent: (props?: Record<string, unknown>) => JSX.Element;
  props?: Record<string, unknown>;
  /** Actions offered on this step. */
  actions: Array<WizardAction<T>>;
};

/**
 * Clamps a candidate step index into the valid range for a wizard.
 *
 * @param index - Candidate zero-based step index.
 * @param stepCount - Number of steps in the wizard.
 * @returns `index` limited to `[0, stepCount - 1]`, or `0` when there are no steps.
 */
function clampStepIndex(index: number, stepCount: number): number {
  return Math.max(0, Math.min(index, stepCount - 1));
}

/**
 * Tracks the current step of a multi-step wizard and builds the yup schema
 * covering every step up to and including the current one.
 *
 * @param steps - Ordered wizard steps.
 * @param goToStep - Optional 1-based step to jump to whenever this value
 *   changes. `undefined`, values below 1 and `NaN` are ignored; values past
 *   the last step are clamped to it.
 * @returns Navigation helpers (`handleNext`, `handleBack`, `setStepIndex`),
 *   the current `stepIndex` / `activeStep`, `first` / `last` flags and the
 *   accumulated `schema`.
 */
export function useWizard<T extends Record<string, unknown>>({
  steps,
  goToStep,
}: {
  steps: WizardStep<T>[];
  goToStep?: number;
}) {
  const [stepIndex, setStepIndex] = useState(0);

  const activeStep = steps[stepIndex];
  const first = stepIndex === 0;
  const last = stepIndex === steps.length - 1;

  useEffect(() => {
    const requested = goToStep ?? 0;
    // `!(x >= 1)` also rejects NaN, which fails every comparison.
    if (!(requested >= 1)) return;
    // goToStep is 1-based; clamp so an out-of-range request cannot leave
    // `activeStep` undefined. steps.length is deliberately not a dependency:
    // re-running when the steps change would jump back to `goToStep` again.
    setStepIndex(clampStepIndex(Math.floor(requested) - 1, steps.length));
  }, [goToStep]);

  const handleNext = () => {
    return new Promise((resolve) => {
      // NOTE(review): the purpose of the 1 ms deferral is not visible here;
      // presumably it lets updates queued by the caller flush before the step
      // changes — confirm before removing.
      setTimeout(() => {
        // Functional update: advance from the latest index, not the one captured
        // when this closure was created. Clamping makes this a no-op on the last
        // step and keeps the index non-negative when `steps` is empty.
        setStepIndex((current) => clampStepIndex(current + 1, steps.length));
        resolve('');
      }, 1);
    });
  };

  const handleBack = () => {
    // Functional update for the same reason as handleNext; no-op on the first step.
    setStepIndex((current) => clampStepIndex(current - 1, steps.length));
  };

  // Concatenate the schemas of steps 0..stepIndex; steps without a schema add nothing.
  const schema = steps
    .slice(0, stepIndex + 1)
    .reduce(
      (currentSchema, { schema: stepSchema }) =>
        stepSchema ? currentSchema.concat(stepSchema) : currentSchema,
      object()
    );

  return {
    handleNext,
    handleBack,
    schema,
    setStepIndex,
    stepIndex,
    activeStep,
    first,
    last,
  };
}
