import {
    Table as TableInstance,
    TableOptions,
    RowData,
    useReactTable,
    flexRender,
    getCoreRowModel,
    getSortedRowModel,
    TableState,
    Column,
    SortDirection,
    SortingState,
    functionalUpdate,
    RowSelectionState,
    VisibilityState,
    getPaginationRowModel,
    PaginationState,
    Row,
    ExpandedState,
    Cell,
} from '@tanstack/react-table';
import { CSSProperties, ForwardedRef, Fragment, HTMLProps, ReactNode, useEffect } from 'react';
import { renderDefaultSorting } from './helpers';
import cn from 'classnames';

/**
 * Sorting props. `manualSorting`, `enableSortingRemoval` and `sortDescFirst` are forwarded to
 * react-table as given; `enableSorting` is coerced to a boolean by Table, so sorting is off unless
 * it is explicitly enabled.
 */
type SortingProps<TableData> = Pick<
    TableOptions<TableData>,
    'manualSorting' | 'enableSorting' | 'enableSortingRemoval' | 'sortDescFirst'
> &
    Partial<Pick<TableState, 'sorting'>> & {
        /** Receives the next sorting state, resolved against the `sorting` prop (ignored when `sorting` is omitted). */
        onSortingChange?: (newSorting: SortingState) => void;
        /** Renders sort controls after the header content of sortable columns; Table defaults to `renderDefaultSorting`. */
        renderSortingControls?: (column: Column<TableData>) => ReactNode;
        /** Directions a header click may apply; Table defaults to `['asc', 'desc']`. */
        sortDirections?: [SortDirection] | [SortDirection, SortDirection];
    };

/**
 * Row-selection props. When omitted, `enableRowSelection` and `enableMultiRowSelection` are passed
 * to react-table as `false`. `onRowSelectionChange` receives the next selection resolved against
 * the `rowSelection` prop; changes are ignored when `rowSelection` is omitted.
 */
type RowSelectionProps<TableData> = Pick<TableOptions<TableData>, 'enableRowSelection' | 'enableMultiRowSelection'> &
    Partial<Pick<TableState, 'rowSelection'>> & {
        onRowSelectionChange?: (rowSelection: RowSelectionState) => void;
    };

/**
 * Column-visibility props. `onColumnVisibilityChange` receives the next state resolved against the
 * `columnVisibility` prop; changes are ignored when `columnVisibility` is omitted.
 */
type ColumnVisibilityProps = Partial<Pick<TableState, 'columnVisibility'>> & {
    onColumnVisibilityChange?: (columnVisibility: VisibilityState) => void;
};

/**
 * Column-pinning props. Table defaults `columnPinning` to `{ left: [], right: [] }` and coerces
 * `enablePinning` / `enableColumnPinning` to booleans. There is no change callback, so pinning is
 * driven only by the `columnPinning` prop.
 */
type ColumnPinningProps<TableData> = Partial<Pick<TableState, 'columnPinning'>> &
    Pick<TableOptions<TableData>, 'enablePinning' | 'enableColumnPinning'>;

/**
 * Pagination props. Table always supplies a pagination state; when `pagination` is omitted it uses
 * a single page `{ pageIndex: 0, pageSize: data.length }`. `manualPagination` and
 * `autoResetPageIndex` are coerced to booleans, and `pageCount` is forwarded as given.
 */
type PaginationProps<TableData> = Partial<Pick<TableState, 'pagination'>> &
    Pick<TableOptions<TableData>, 'manualPagination' | 'pageCount' | 'autoResetPageIndex'> & {
        /** Receives the next pagination state resolved against the current `pagination` value. */
        onPaginationChange?: (state: PaginationState) => void;
    };

/**
 * Computes the span id of a row for one column. Rows of that column that return the same non-empty
 * id are grouped for row spanning; an empty id keeps the row out of every group.
 */
type SpanIdGetter<TableData> = (candidateRow: Row<TableData>) => string;
/** Span id getters keyed by column id. */
type RowSpanHelper<TableData> = Record<string, SpanIdGetter<TableData>>;

/** Row-level options: react-table's `getRowId`, a row click handler, and row-spanning getters. */
type OptionsProps<TableData> = {
    /** Forwarded to react-table to compute row ids. */
    getRowId?: TableOptions<TableData>['getRowId'];
    /** Invoked with the clicked row by the default body renderer. */
    onRowClick?: (clickedRow: Row<TableData>) => void;
    /** Column id -> span id getter used for row spanning. */
    rowSpanHelper?: RowSpanHelper<TableData>;
};

/**
 * Row-expansion props; every member is optional. Table coerces `enableExpanding` /
 * `manualExpanding` to booleans and forwards `expandedState` as react-table's `expanded` state.
 */
type ExpandedProps<TableData> = Partial<
    Pick<TableOptions<TableData>, 'enableExpanding' | 'manualExpanding'> & {
        /** Content of the extra full-width row the default body renderer shows under an expanded row. */
        renderSubComponent: (row: Row<TableData>) => ReactNode;
        expandedState: ExpandedState;
        /** Receives the next expanded state resolved against `expandedState` (ignored when it is omitted). */
        onExpandedChange: (expandedState: ExpandedState) => void;
    }
>;

/** Styling hooks; `getRowClassName` supplies the `className` of each body row in the default body renderer. */
type StylesProps<TableData> = Partial<{
    getRowClassName: (styledRow: Row<TableData>) => string;
}>;

/**
 * Rendering overrides. When omitted, Table uses its default cell and body renderers.
 */
type RenderProps<TableData> = {
    renderCell?: (cellParams: RenderCellParams<TableData>) => ReactNode;
    renderBody?: (bodyParams: RenderBodyProps<TableData>) => ReactNode;
};

/**
 * Props of the generic `Table` component: react-table options for each supported feature (sorting,
 * row selection, column visibility, column pinning, pagination, expansion), row-spanning and
 * rendering hooks, plus any `<table>` HTML attribute. Unrecognised props are spread onto the
 * `<table>` element; HTML's `data` attribute is omitted because `data` is the row data here.
 */
export type TableProps<TableData> = {
    /** Ref (object or callback) that receives the react-table instance after render. */
    instance?: ForwardedRef<TableInstance<TableData>>;
    /** When true, renders a `<tfoot>` from the columns' `footer` definitions. */
    shouldDisplayFooter?: boolean;
    /** Shown by the default body renderer in a full-width row when there are no rows. */
    noDataLabel?: ReactNode | undefined;
} & Omit<HTMLProps<HTMLTableElement>, 'data'> &
    Pick<TableOptions<TableData>, 'data' | 'columns' | 'meta'> &
    SortingProps<TableData> &
    RowSelectionProps<TableData> &
    ColumnVisibilityProps &
    ColumnPinningProps<TableData> &
    PaginationProps<TableData> &
    OptionsProps<TableData> &
    ExpandedProps<TableData> &
    StylesProps<TableData> &
    RenderProps<TableData>;

/**
 * Change handlers that Table passes to `useReactTable`.
 *
 * They are typed with react-table's own option types (`OnChangeFn<State>`) because react-table
 * calls them with an updater (either the next state or a function of the previous state), not
 * with plain state. `onExpandedChange` is included so every handler Table builds is covered.
 */
type Handlers<TableData> = Pick<
    TableOptions<TableData>,
    'onRowSelectionChange' | 'onSortingChange' | 'onColumnVisibilityChange' | 'onPaginationChange' | 'onExpandedChange'
>;

/**
 * Generic data table rendered with `@tanstack/react-table`.
 *
 * - Each state slice (`sorting`, `rowSelection`, `columnVisibility`, `columnPinning`, `pagination`,
 *   `expandedState`) is handed to react-table only when it is defined. `columnPinning` defaults to
 *   no pinned columns and `pagination` to a single page holding every row.
 * - For every `on*Change` callback that is provided, react-table's updater is resolved against the
 *   matching prop and the resulting state is passed to the callback. If that prop is missing the
 *   change is dropped.
 * - Sorting is opt-in: `enableSorting` is coerced to a boolean for the table and for every
 *   top-level column definition. Header clicks only ever apply directions listed in
 *   `sortDirections`.
 * - The react-table instance is exposed through `instance` after render.
 * - Remaining props are spread onto the `<table>` element.
 */
const Table = <TableData extends RowData>({
    instance,
    data,
    columns,
    sorting,
    onSortingChange,
    manualSorting,
    enableSorting,
    enableSortingRemoval,
    sortDescFirst,
    renderSortingControls = renderDefaultSorting,
    sortDirections = ['asc', 'desc'],
    enableRowSelection,
    enableMultiRowSelection,
    rowSelection,
    onRowSelectionChange,
    columnVisibility,
    onColumnVisibilityChange,
    columnPinning = { left: [], right: [] },
    enablePinning,
    enableColumnPinning,
    manualPagination,
    pageCount,
    autoResetPageIndex,
    onPaginationChange,
    pagination = { pageIndex: 0, pageSize: data.length },
    getRowId,
    shouldDisplayFooter,
    noDataLabel,
    rowSpanHelper,
    enableExpanding,
    onExpandedChange,
    manualExpanding,
    expandedState,
    renderSubComponent,
    onRowClick,
    getRowClassName,
    meta,
    renderCell = renderDefaultCell,
    renderBody = renderDefaultBody,
    ...props
}: TableProps<TableData>) => {
    // react-table invokes `on*Change` with an updater (a value or a function of the previous state).
    // Resolve it against the current prop so consumers always receive plain state. Handlers whose
    // callback prop is absent are filtered out so react-table keeps its own defaults for them.
    const handlers = Object.fromEntries(
        Object.entries({
            onSortingChange: onSortingChange
                ? (updaterOrValue) => {
                      if (!sorting) {
                          return;
                      }

                      const newSorting = functionalUpdate(updaterOrValue, sorting);
                      onSortingChange?.(newSorting);
                  }
                : undefined,
            onRowSelectionChange: onRowSelectionChange
                ? (updaterOrValue) => {
                      if (!rowSelection) {
                          return;
                      }

                      const newRowSelection = functionalUpdate(updaterOrValue, rowSelection);
                      onRowSelectionChange?.(newRowSelection);
                  }
                : undefined,
            onColumnVisibilityChange: onColumnVisibilityChange
                ? (updaterOrValue) => {
                      if (!columnVisibility) {
                          return;
                      }

                      const newColumnVisibility = functionalUpdate(updaterOrValue, columnVisibility);
                      onColumnVisibilityChange?.(newColumnVisibility);
                  }
                : undefined,
            onPaginationChange: onPaginationChange
                ? (updaterOrValue) => {
                      if (!pagination) {
                          return;
                      }

                      const newPagination = functionalUpdate(updaterOrValue, pagination);
                      onPaginationChange?.(newPagination);
                  }
                : undefined,
            onExpandedChange: onExpandedChange
                ? (updaterOrValue: ExpandedState | ((oldState: ExpandedState) => ExpandedState)) => {
                      if (!expandedState) {
                          return;
                      }

                      const newExpandedState = functionalUpdate(updaterOrValue, expandedState);
                      onExpandedChange?.(newExpandedState);
                  }
                : undefined,
        } as Handlers<TableData>).filter(([_, value]) => value !== undefined)
    );

    const table = useReactTable({
        getCoreRowModel: getCoreRowModel<TableData>(),
        getSortedRowModel: getSortedRowModel<TableData>(),
        getPaginationRowModel: getPaginationRowModel<TableData>(),
        getRowId,
        data,
        // Sorting is opt-in per column: an unset column-level `enableSorting` becomes `false`.
        columns: columns.map((c) => ({ ...c, enableSorting: Boolean(c.enableSorting) })),
        // Only defined slices are passed; react-table manages the omitted ones itself.
        state: Object.fromEntries(
            Object.entries({
                sorting,
                rowSelection,
                columnVisibility,
                columnPinning,
                pagination,
                expanded: expandedState,
            } as TableState).filter(([_, value]) => value !== undefined)
        ),
        enableSorting: Boolean(enableSorting),
        manualSorting,
        enableSortingRemoval,
        sortDescFirst,
        // Both options accept a boolean or a per-row predicate. `??` keeps a predicate intact,
        // whereas Boolean(predicate) would silently make every row selectable.
        enableRowSelection: enableRowSelection ?? false,
        enableMultiRowSelection: enableMultiRowSelection ?? false,
        enablePinning: Boolean(enablePinning),
        enableColumnPinning: Boolean(enableColumnPinning),
        manualPagination: Boolean(manualPagination),
        autoResetPageIndex: Boolean(autoResetPageIndex),
        pageCount: pageCount,
        enableExpanding: Boolean(enableExpanding),
        manualExpanding: Boolean(manualExpanding),
        meta,
        ...handlers,
    });

    // Expose the react-table instance through the forwarded ref (callback or object ref).
    useEffect(() => {
        if (!instance || !table) {
            return;
        }

        if (typeof instance === 'function') {
            instance(table);
            return;
        }

        instance.current = table;
    }, [instance, table]);

    /**
     * Sticky positioning for a pinned column: its offset is the summed width of the visible
     * columns pinned before it on the same side. Returns `undefined` for unpinned columns.
     */
    const getPinnedStyles = (column: Column<TableData, unknown>) => {
        const group = column.getIsPinned();
        if (!group) {
            return;
        }

        const index = column.getPinnedIndex();
        const groupColumns = group === 'left' ? table.getLeftLeafColumns() : table.getRightLeafColumns();
        const distance = groupColumns.reduce((offset, other) => {
            // Left-pinned columns stack from the left edge, right-pinned ones from the right edge.
            const isBefore = group === 'left' ? other.getPinnedIndex() < index : other.getPinnedIndex() > index;
            if (isBefore && other.getIsVisible()) {
                return offset + other.getSize();
            }

            return offset;
        }, 0);

        const styles: CSSProperties = {
            position: 'sticky',
            [group as string]: `${distance}px`,
            width: column.getSize(),
            minWidth: column.getSize(),
        };

        return styles;
    };

    /**
     * Class name for the innermost visible column of a pinned group (the one bordering the
     * scrollable area), or `undefined` for every other column.
     */
    const getPinnedEdgeClassName = (column: Column<TableData, unknown>) => {
        const group = column.getIsPinned();
        if (!group) {
            return;
        }

        const index = column.getPinnedIndex();
        const groupColumns = group === 'left' ? table.getLeftLeafColumns() : table.getRightLeafColumns();
        // Only visible columns can form the edge (matching getPinnedStyles); otherwise a hidden
        // outermost column would leave the group without any edge class.
        const indexes = groupColumns.filter((c) => c.getIsVisible()).map((c) => c.getPinnedIndex());
        const isEdge = group === 'left' ? Math.max(...indexes) === index : Math.min(...indexes) === index;
        if (!isEdge) {
            return;
        }

        return group === 'left' ? 'table-left-pinned-edge' : 'table-right-pinned-edge';
    };

    // columnId -> spanId -> indexes (row.index) of the rendered rows sharing that span id.
    const spannedRowsData = rowSpanHelper
        ? Object.entries(rowSpanHelper).reduce((acc, [columnId, getSpanId]) => {
              const columnSpanData = table.getRowModel().rows.reduce((acc, row) => {
                  const spanId = getSpanId(row);
                  if (!spanId) {
                      return acc;
                  }

                  if (!acc[spanId]) {
                      acc[spanId] = [];
                  }

                  acc[spanId].push(row.index);
                  return acc;
              }, {} as { [spanId: string]: number[] });

              acc[columnId] = columnSpanData;
              return acc;
          }, {} as { [columnId: string]: { [spanId: string]: number[] } })
        : {};

    return (
        <table {...props}>
            <thead>
                {table.getHeaderGroups().map((headerGroup) => (
                    <tr key={headerGroup.id}>
                        {headerGroup.headers.map((header) => {
                            const toggle = () => {
                                const { column } = header;
                                if (!column.getCanSort()) {
                                    return;
                                }

                                // Let react-table drive the cycle while it proposes an allowed
                                // direction or the removal of the sort.
                                const nextDirection = column.getNextSortingOrder();
                                if (!nextDirection || sortDirections.includes(nextDirection)) {
                                    column.toggleSorting();
                                    return;
                                }

                                // react-table proposed a direction outside `sortDirections`
                                // (e.g. desc-only or asc-only columns).
                                if (!column.getIsSorted()) {
                                    // Not sorted yet: start with the first allowed direction.
                                    column.toggleSorting(sortDirections[0] === 'desc');
                                    return;
                                }

                                // Already sorted in the allowed direction and the other one is not
                                // allowed: clear the sort when removal is enabled (react-table's
                                // default), otherwise keep the current sort rather than flipping it.
                                if (enableSortingRemoval ?? true) {
                                    column.clearSorting();
                                }
                            };

                            return (
                                <th
                                    key={header.id}
                                    onClick={toggle}
                                    style={getPinnedStyles(header.column)}
                                    colSpan={header.colSpan}
                                    className={cn(
                                        getPinnedEdgeClassName(header.column),
                                        header.column.columnDef.meta?.className,
                                        {
                                            'table-header-sortable': header.column.getCanSort(),
                                        }
                                    )}
                                >
                                    {header.isPlaceholder
                                        ? null
                                        : flexRender(header.column.columnDef.header, header.getContext())}
                                    {header.column.getCanSort() && renderSortingControls(header.column)}
                                </th>
                            );
                        })}
                    </tr>
                ))}
            </thead>

            {renderBody({
                table,
                getPinnedEdgeClassName,
                getPinnedStyles,
                renderCell,
                spannedRowsData,
                getRowClassName,
                noDataLabel,
                onRowClick,
                renderSubComponent,
                rowSpanHelper,
            })}

            {shouldDisplayFooter && (
                <tfoot>
                    {table.getFooterGroups().map((footerGroup) => (
                        <tr key={footerGroup.id}>
                            {footerGroup.headers.map((header) => (
                                <td
                                    key={header.id}
                                    style={getPinnedStyles(header.column)}
                                    className={getPinnedEdgeClassName(header.column)}
                                >
                                    {header.isPlaceholder
                                        ? null
                                        : flexRender(header.column.columnDef.footer, header.getContext())}
                                </td>
                            ))}
                        </tr>
                    ))}
                </tfoot>
            )}
        </table>
    );
};

/**
 * Arguments passed to `renderCell` for each visible cell: the cell itself plus the pinning and
 * row-spanning helpers computed by Table.
 */
type RenderCellParams<TableData> = {
    cell: Cell<TableData, unknown>;
    /** Column id -> span id getter, as passed to Table. */
    rowSpanHelper?: RowSpanHelper<TableData>;
    /** columnId -> spanId -> row indexes sharing that span id. */
    spannedRowsData: Record<string, Record<string, number[]>>;
    getPinnedStyles: (pinnedColumn: Column<TableData, unknown>) => CSSProperties | undefined;
    getPinnedEdgeClassName: (pinnedColumn: Column<TableData, unknown>) => string | undefined;
};

/**
 * Default `renderCell` implementation. It renders a `<td>` for one visible cell, applying the
 * pinning styles and edge class plus the column's `meta.className`.
 *
 * Row spanning works as follows. When `rowSpanHelper` has a getter for the cell's column and it
 * returns a non-empty span id, the cell is grouped with the rows in `spannedRowsData` that share
 * that id. The first row of the group renders the cell, with `rowSpan` set to the group size when
 * the group has more than one row. The other rows of the group render nothing for this column.
 */
function renderDefaultCell<TableData>(props: RenderCellParams<TableData>) {
    const { cell, rowSpanHelper, spannedRowsData, getPinnedStyles, getPinnedEdgeClassName } = props;

    const spanId = rowSpanHelper?.[cell.column.id]?.(cell.row);
    const rowSpanRange = spanId ? spannedRowsData[cell.column.id]?.[spanId] : undefined;

    // A row that is not the first of its span group is covered by the first row's cell, so it must
    // not emit a <td>. In HTML, `rowspan="0"` means "span the rest of the row group", not "hidden".
    if (rowSpanRange && rowSpanRange[0] !== cell.row.index) {
        return null;
    }

    const rowSpan = rowSpanRange && rowSpanRange.length > 1 ? rowSpanRange.length : undefined;

    return (
        <td
            key={cell.id}
            style={getPinnedStyles(cell.column)}
            className={cn(getPinnedEdgeClassName(cell.column), cell.column.columnDef.meta?.className)}
            rowSpan={rowSpan}
        >
            {flexRender(cell.column.columnDef.cell, cell.getContext())}
        </td>
    );
}

/**
 * Arguments passed to `renderBody`: the react-table instance, the row-level hooks from Table's
 * props, the cell renderer to use, and every cell helper from `RenderCellParams` except `cell`.
 */
type RenderBodyProps<TableData> = {
    table: TableInstance<TableData>;
    onRowClick?: (row: Row<TableData>) => void;
    getRowClassName?: (row: Row<TableData>) => string;
    renderCell: (props: RenderCellParams<TableData>) => ReactNode;
} & Pick<TableProps<TableData>, 'renderSubComponent' | 'noDataLabel'> &
    Omit<RenderCellParams<TableData>, 'cell'>;

/**
 * Default `renderBody` implementation. It renders a `<tbody>` containing:
 * - a single full-width row with `noDataLabel` when the row model is empty and a label is given;
 * - one `<tr>` per row of the current row model, whose cells come from `renderCell`;
 * - an extra full-width row produced by `renderSubComponent` below each expanded row.
 */
function renderDefaultBody<TableData>(bodyProps: RenderBodyProps<TableData>) {
    const { table, renderCell, renderSubComponent, noDataLabel, onRowClick, getRowClassName } = bodyProps;
    // Helpers forwarded unchanged to every renderCell call.
    const { getPinnedEdgeClassName, getPinnedStyles, spannedRowsData, rowSpanHelper } = bodyProps;
    const sharedCellParams = { getPinnedEdgeClassName, getPinnedStyles, spannedRowsData, rowSpanHelper };

    const visibleRows = table.getRowModel().rows;
    const isEmpty = visibleRows.length === 0;

    const renderRow = (row: Row<TableData>) => {
        const cells = row.getVisibleCells();

        return (
            <Fragment key={row.id}>
                <tr className={getRowClassName?.(row)} onClick={() => onRowClick?.(row)}>
                    {cells.map((cell) => renderCell?.({ cell, ...sharedCellParams }))}
                </tr>
                {row.getIsExpanded() && (
                    <tr>
                        <td colSpan={cells.length}>{renderSubComponent?.(row)}</td>
                    </tr>
                )}
            </Fragment>
        );
    };

    return (
        <tbody>
            {isEmpty && noDataLabel && (
                <tr>
                    <td colSpan={table.getLeafHeaders().length}>{noDataLabel}</td>
                </tr>
            )}

            {visibleRows.map((row) => renderRow(row))}
        </tbody>
    );
}

// The Table component is this module's default export.
export { Table as default };
