import { Box, useTheme } from '@mui/material'
import {
  DataGridPremium,
  GridValidRowModel,
  useGridApiRef,
} from '@mui/x-data-grid-premium'
import { Entity } from '@synop-react/types'
import { useEffect, useMemo, useRef } from 'react'

import { HeaderCell } from './Table'
import { NoData } from '..'
import { NoDataProps } from '../NoData/NoData'
import { TableProps } from './types'

// Page sizes passed to the grid's `pageSizeOptions` prop in `BaseTable`.
const ROWS_PER_PAGE_OPTIONS = [10, 25, 50, 100]

/**
 * Paginated `DataGridPremium` wrapper with the shared header cells, empty-state overlay and
 * styling. Row grouping and aggregation are always disabled; tree data is enabled
 * automatically when a `getTreeDataPath` prop is supplied.
 *
 * Any additional props are forwarded to `DataGridPremium`. The props set explicitly below
 * (`apiRef`, `columns`, `rows`, `slots`, `slotProps`, `sx`, ...) take precedence over them.
 */
export function BaseTable<
  R extends GridValidRowModel,
  T extends Entity<R> = Entity<R>
>({
  columns: columnSpec,
  getRowId = (row: R) => (row as unknown as T).id,
  initialSortColumn,
  initialSortOrder = 'asc',
  noRowsMessage,
  onRowClick = () => null,
  tableData,
  initialState,
  ...dataGridProps
}: TableProps<R>) {
  const theme = useTheme()
  const isTreeStructure = !!dataGridProps.getTreeDataPath

  // The grid expects `columns` to keep the same reference between renders; rebuilding it on
  // every render can make it reset column state (e.g. widths). Grouping is disabled when
  // treeData is enabled, and columns without a custom header get the shared `HeaderCell`.
  const columns = useMemo(
    () =>
      columnSpec.map((column) => ({
        ...column,
        groupable: !isTreeStructure,
        renderHeader:
          column.renderHeader ??
          ((params) => (
            <HeaderCell
              title={params.colDef.headerName}
              tooltip={column.tooltip}
            />
          )),
      })),
    [columnSpec, isTreeStructure]
  )

  // The table data is stored in a ref so that the exact same reference is passed to the
  // `DataGridPremium` component on every render. This is necessary to support the customized
  // state management implemented in the `useTableStateManager` hook. The `tableData` prop
  // can continue to be updated without causing the `DataGridPremium`'s props to change.
  const initialData = useRef(tableData)
  const apiRef = useTableStateManager(tableData, isTreeStructure)

  return (
    <Box sx={{ flexGrow: 1 }}>
      <DataGridPremium
        {...dataGridProps}
        apiRef={apiRef}
        autoHeight
        columns={columns}
        disableAggregation
        disableRowGrouping
        getRowId={getRowId}
        initialState={{
          ...initialState,
          // Only replace the sort model when an initial sort column is given; otherwise keep
          // whatever `sorting` the caller supplied through `initialState`.
          sorting: initialSortColumn
            ? {
                sortModel: [
                  { field: initialSortColumn as string, sort: initialSortOrder },
                ],
              }
            : initialState?.sorting,
        }}
        onRowClick={onRowClick}
        pageSizeOptions={ROWS_PER_PAGE_OPTIONS}
        pagination
        rows={initialData.current}
        slotProps={{
          noRowsOverlay: {
            message: noRowsMessage,
            sx: { textTransform: 'unset' },
          } as NoDataProps,
        }}
        slots={{ noRowsOverlay: NoData }}
        sx={{
          border: 'none',
          '& .MuiDataGrid-columnHeaderTitle': {
            fontWeight: 600,
          },
          '& .MuiDataGrid-row:hover': {
            backgroundColor: theme.palette.primary[50],
          },
        }}
        treeData={isTreeStructure}
      />
    </Box>
  )
}

/**
 * Implements some customized state management logic for the `BaseTable` component. Whenever
 * `rows` changes, the grid's data is replaced through `apiRef.current.setRows`, and any tree
 * rows that were expanded before the update are expanded again afterwards.
 *
 * @param rows The latest table data to display.
 * @param isTreeStructure Whether the grid renders tree data. Expansion state is only
 *   captured and restored when this is `true`.
 * @returns The grid API ref; it must be passed to `DataGridPremium` as its `apiRef` prop.
 */
function useTableStateManager(
  rows: GridValidRowModel[],
  isTreeStructure: boolean
) {
  const apiRef = useGridApiRef()

  useEffect(() => {
    // Save the IDs of the currently expanded group rows before replacing the data. The node's
    // own `id` is used (instead of the `tree` record key from `Object.entries`) so numeric row
    // IDs keep their original type. Skip the lookup entirely for flat tables.
    const expandedIds = isTreeStructure
      ? Object.values(apiRef.current.state.rows.tree)
          .filter((node) => node.type === 'group' && node.childrenExpanded)
          .map((node) => node.id)
      : []

    // Update the row data
    apiRef.current.setRows(rows)

    // Expand the rows which were previously expanded. This is a somewhat inefficient
    // implementation, as the `setRowChildrenExpansion` method calls the apiRef's `setState` and
    // `forceUpdate` methods which are relatively expensive. However, `setRowChildrenExpansion`
    // also implements some relevant logic for determining which rows should be visible
    // (specifically the `visibleRowsLookup` calculation), so it's convenient to just use it.
    for (const id of expandedIds) {
      // A row that was removed, or that no longer has children, cannot be expanded. Check for
      // that explicitly rather than swallowing every error `setRowChildrenExpansion` throws.
      if (apiRef.current.getRowNode(id)?.type === 'group') {
        apiRef.current.setRowChildrenExpansion(id, true)
      }
    }
  }, [apiRef, rows, isTreeStructure])

  return apiRef
}
