import { CircularProgress, MenuItem, Select } from "@mui/material";
import { endOfMonth, format, startOfMonth } from "date-fns";
import { useEffect, useMemo, useState } from "react";
import { useTeamTokenUsage } from "../../../../../../api/queryHooks";
import { formatDateInputValue } from "../../../../../../utils/datetime";
import { TokenUsageGraph } from "./TokenGraphUtils/TokenUsageGraph";
import { ChartEndpointEntry } from "./TokenGraphUtils/types";
import {
  tokenListToCharEndpointData,
  usersUsageToAggregateCharEndpointData,
} from "./TokenGraphUtils/utils";

/**
 * Monthly token allowance for the whole team.
 * TODO: this is hard coded for the moment.
 */
const MONTHLY_TOKEN_QUOTA = 75000000;

/** Headline numbers shown above the team aggregate chart. */
type TokenSummary = {
  totalTokens: number;
  tokensUsed: number;
  tokensLeft: number;
};

/**
 * Converts each user's raw token list into chart entries, keyed by username.
 *
 * Uses a plain loop instead of `reduce` + object spread, which copied the
 * accumulator on every iteration (O(n^2) in the number of users).
 */
const buildPerUserChartData = (
  usersUsage: ReadonlyArray<{
    username: string;
    token_usage: {
      token_list: Parameters<typeof tokenListToCharEndpointData>[0];
    };
  }>,
): Record<string, ChartEndpointEntry[]> => {
  const result: Record<string, ChartEndpointEntry[]> = {};
  for (const userUsage of usersUsage) {
    result[userUsage.username] = tokenListToCharEndpointData(
      userUsage.token_usage.token_list,
    );
  }
  return result;
};

/**
 * Team token usage dashboard.
 *
 * Two independent queries are issued:
 * - a month-scoped query (driven by the month picker) that feeds the team
 *   aggregate chart and the "tokens remaining" summary;
 * - a free date-range query (start/end date inputs) that feeds the
 *   per-user chart for the selected user.
 */
export const TeamTokenUsageComponent = () => {
  const [selectedUser, setSelectedUser] = useState<string>("");

  // Per-user section: arbitrary start/end range, defaulting to month-to-date.
  const [dateRange, setDateRange] = useState<[Date, Date]>(() => [
    startOfMonth(new Date()),
    new Date(),
  ]);
  const { isLoading, data } = useTeamTokenUsage(dateRange[0], dateRange[1]);

  // Team aggregate section: whole calendar month chosen via the month picker.
  const [dateRangeMonth, setDateRangeMonth] = useState<[Date, Date]>(() => [
    startOfMonth(new Date()),
    new Date(),
  ]);
  const { isLoading: isLoadingMonth, data: dataMonth } = useTeamTokenUsage(
    dateRangeMonth[0],
    dateRangeMonth[1],
  );
  const [month, setMonth] = useState<string>(() =>
    format(new Date(), "yyyy-MM"),
  );
  const thisMonth = format(new Date(), "yyyy-MM");

  const handleMonthChange = (event: React.ChangeEvent<HTMLInputElement>) => {
    // A cleared <input type="month"> reports "", which would produce NaN dates.
    if (!event.target.value) {
      return;
    }
    const [year, monthNumber] = event.target.value.split("-").map(Number);
    if (!Number.isFinite(year) || !Number.isFinite(monthNumber)) {
      return;
    }
    const firstOfMonth = new Date(year, monthNumber - 1, 1); // Months are 0-indexed in JavaScript Date
    setMonth(`${year}-${monthNumber.toString().padStart(2, "0")}`);
    setDateRangeMonth([startOfMonth(firstOfMonth), endOfMonth(firstOfMonth)]);
  };

  // NOTE(review): `new Date("YYYY-MM-DD")` parses date-only strings as UTC
  // midnight, which can display as the previous day in negative-offset time
  // zones — confirm this matches how formatDateInputValue formats dates.
  const handleStartDateChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    const value = e.target.value;
    // A cleared date input reports ""; ignore it instead of storing an Invalid Date.
    if (!value) {
      return;
    }
    setDateRange(([, oldEnd]) => [new Date(value), oldEnd]);
  };

  const handleEndDateChange = (e: React.ChangeEvent<HTMLInputElement>) => {
    const value = e.target.value;
    if (!value) {
      return;
    }
    setDateRange(([oldStart]) => [oldStart, new Date(value)]);
  };

  const users = useMemo(() => {
    if (!data) {
      return [];
    }

    return data.users_usage.map((userUsage) => ({
      username: userUsage.username,
      email: userUsage.email,
    }));
  }, [data]);

  const perUserChartData = useMemo(
    () => (data ? buildPerUserChartData(data.users_usage) : {}),
    [data],
  );

  const perUserChartMonthData = useMemo(
    () => (dataMonth ? buildPerUserChartData(dataMonth.users_usage) : {}),
    [dataMonth],
  );

  // Derived purely from the month query. No setState here: state updates
  // inside useMemo run during render and are not allowed.
  const { teamAggregateChartData, tokenSummary } = useMemo((): {
    teamAggregateChartData: ChartEndpointEntry[];
    tokenSummary: TokenSummary | null;
  } => {
    if (!dataMonth) {
      return { teamAggregateChartData: [], tokenSummary: null };
    }

    const perEndpoint: ChartEndpointEntry[] =
      usersUsageToAggregateCharEndpointData(perUserChartMonthData);
    if (perEndpoint.length === 0) {
      return { teamAggregateChartData: [], tokenSummary: null };
    }

    const tokensUsed = perEndpoint.reduce(
      (sum, entry) => sum + entry.totalTokens,
      0,
    );
    const tokensLeft = Math.max(0, MONTHLY_TOKEN_QUOTA - tokensUsed);
    const totalEntry: ChartEndpointEntry = {
      endpoint: "Total Tokens",
      promptTokens: 0,
      completionTokens: 0,
      totalTokens: MONTHLY_TOKEN_QUOTA,
      tokensLeft,
      tokensUsed,
      models: {},
    };

    return {
      // Prepend without mutating the helper's returned array.
      teamAggregateChartData: [totalEntry, ...perEndpoint],
      tokenSummary: {
        totalTokens: MONTHLY_TOKEN_QUOTA,
        tokensUsed,
        tokensLeft,
      },
    };
  }, [dataMonth, perUserChartMonthData]);

  // Default the user picker to the first user once per-user data is available.
  useEffect(() => {
    if (isLoading || !data || selectedUser) {
      return;
    }
    const firstUser = data.users_usage[0];
    if (firstUser) {
      setSelectedUser(firstUser.username);
    }
  }, [isLoading, data, selectedUser]);

  return (
    <div className="p-8 bg-white rounded-xl shadow-lg border-2 border-gray-200 transition-colors duration-300">
      <h2 className="text-3xl font-bold text-gray-800 mb-6 pb-2 border-b-2 border-gray-200">
        Team Token Usage
      </h2>

      <div className="flex flex-col md:flex-row justify-between items-center gap-6 mb-8">
        <div className="w-full md:w-1/3">
          <label className="block text-gray-700 font-medium mb-2">Month:</label>
          <input
            className="w-full px-4 py-2 border-2 border-gray-300 rounded-lg focus:ring-2 focus:ring-primary focus:border-primary transition duration-200 ease-in-out"
            type="month"
            value={month}
            onChange={handleMonthChange}
          />
        </div>
      </div>

      {tokenSummary && (
        <div className="bg-gray-50 rounded-lg p-6 border-2 border-gray-200 mb-8">
          <div className="flex flex-col items-center">
            {thisMonth === month && (
              <div className="text-2xl font-bold text-primary mb-2">
                {(
                  (tokenSummary.tokensLeft / tokenSummary.totalTokens) *
                  100
                ).toFixed(2)}
                % Tokens Remaining
              </div>
            )}
            <div className="text-lg text-gray-700">
              {tokenSummary.tokensUsed.toLocaleString()} /{" "}
              {tokenSummary.totalTokens.toLocaleString()} Tokens Used
            </div>
          </div>
        </div>
      )}

      {isLoadingMonth ? (
        <div className="flex justify-center items-center py-20">
          <CircularProgress className="text-primary" />
        </div>
      ) : (
        <div className="bg-gray-50 rounded-lg p-6 border-2 border-gray-200 mb-8">
          <TokenUsageGraph data={teamAggregateChartData} />
        </div>
      )}

      <h3 className="text-2xl font-bold text-gray-800 mb-6 pb-2 border-b-2 border-gray-200">
        Individual User Usage
      </h3>

      <div className="flex flex-col md:flex-row justify-between items-center gap-6 mb-8">
        <div className="w-full md:w-1/3">
          <label className="block text-gray-700 font-medium mb-2">
            Start Date:
          </label>
          <input
            type="date"
            disabled={isLoading}
            className="w-full px-4 py-2 border-2 border-gray-300 rounded-lg focus:ring-2 focus:ring-primary focus:border-primary transition duration-200 ease-in-out"
            value={formatDateInputValue(dateRange[0])}
            onChange={handleStartDateChange}
            max={formatDateInputValue(dateRange[1])}
          />
        </div>
        <div className="w-full md:w-1/3">
          <label className="block text-gray-700 font-medium mb-2">
            End Date:
          </label>
          <input
            type="date"
            disabled={isLoading}
            className="w-full px-4 py-2 border-2 border-gray-300 rounded-lg focus:ring-2 focus:ring-primary focus:border-primary transition duration-200 ease-in-out"
            value={formatDateInputValue(dateRange[1])}
            onChange={handleEndDateChange}
            min={formatDateInputValue(dateRange[0])}
            max={formatDateInputValue(new Date())}
          />
        </div>
        <div className="w-full md:w-1/3">
          <label className="block text-gray-700 font-medium mb-2">
            Select User
          </label>
          <Select
            disabled={isLoading}
            className="w-full border-2 border-gray-300 rounded-lg focus:ring-2 focus:ring-primary focus:border-primary transition duration-200 ease-in-out"
            value={selectedUser}
            onChange={(e) => setSelectedUser(e.target.value as string)}
          >
            {users.map((userInfo) => (
              <MenuItem key={userInfo.username} value={userInfo.username}>
                {userInfo.email}
              </MenuItem>
            ))}
          </Select>
        </div>
      </div>

      {isLoading ? (
        <div className="flex justify-center items-center py-20">
          <CircularProgress className="text-primary" />
        </div>
      ) : (
        selectedUser && (
          <div className="bg-gray-50 rounded-lg p-6 border-2 border-gray-200">
            {/* The selected user may have no usage in a newly chosen range. */}
            <TokenUsageGraph data={perUserChartData[selectedUser] ?? []} />
          </div>
        )
      )}
    </div>
  );
};
