import React, { useMemo, useState, useRef, useLayoutEffect, useCallback } from 'react'
import { Box, SxProps, Theme } from '@mui/material'
import { LinePath, AreaClosed } from '@visx/shape'
import { scaleLinear, scalePoint, scaleOrdinal } from '@visx/scale'
import { AxisBottom, AxisLeft, AxisScale, TickRendererProps } from '@visx/axis'
import { LegendOrdinal } from '@visx/legend'
import { Tooltip as VisxTooltip, defaultStyles } from '@visx/tooltip'
import { localPoint } from '@visx/event'
import { GridRows, GridColumns } from '@visx/grid'
import { curveLinear } from '@visx/curve'
import { useSprings, animated, to } from '@react-spring/web'
import { Colors, Fonts } from '../../../Utils/theme'

/** Y-axis maximum used when there is no non-zero data to scale against. */
const DEFAULT_Y_MAX = 20_000;

/** A raw sample: `x` is the x-axis slot (month or year offset), `y` the cost value. */
type PointData = {
  x: number
  y: number
}

/** A sample tagged with the series it belongs to; used for hit-testing and tooltips. */
type Point = PointData & {
  id: string
  color?: string
}

/** One plotted series: identifier, optional colour, optional React key and its samples. */
type LineData = {
  id: string
  data: PointData[]
  color?: string
  key?: string
}

/** Props of the cost line chart. Series are matched to `colors` and `costTypes` by index. */
type CostLineChartProps = {
  /** Series to plot. */
  data: LineData[]
  /** Series colours, indexed like `data`; also used as the legend range. */
  colors: string[]
  /** Series labels (legend entries, tooltip titles), indexed like `data`. */
  costTypes: string[]
  /** Reference year for axis/tooltip labels. */
  year?: number
  /** When true, x values are rendered as years instead of month numbers. */
  isYearly?: boolean
}

/** Props of the hover tooltip shown for a single data point. */
type LineChartTooltipProps = {
  /** The hovered point. */
  tooltipData: Point
  /** Viewport x coordinate of the tooltip anchor. */
  tooltipLeft: number
  /** Viewport y coordinate of the tooltip anchor. */
  tooltipTop: number
  year: number
  isYearly: boolean
  /** Returns the summed cost of all series at the given x. */
  getTotalMonthlyCost: (x: number) => number
}

// Styles for MUI specific components
// `sx` style objects for the MUI <Box> elements that wrap the chart
const muiStyles: Record<string, SxProps<Theme>> = {
  // Relatively positioned flex column that hosts the SVG, the overlay labels and the legend
  container: {
    width: '100%',
    height: 'calc(100% - 2.5rem)',
    position: 'relative',
    padding: '1rem',
    display: 'flex',
    flexDirection: 'column',
    alignItems: 'center',
  },
  // "€" unit label, absolutely positioned near the top-left corner
  euroSign: {
    position: 'absolute',
    top: '0.875rem',
    left: '2.625rem',
    fontSize: '1.25rem',
    fontWeight: 'bold',
  },
  // Year (or year-range) caption, horizontally centred at the top
  yearLabel: {
    position: 'absolute',
    top: '1.25rem',
    left: '50%',
    transform: 'translateX(-50%)',
    fontSize: '1.125rem',
    fontWeight: 'bold',
  },
  // Legend row styling
  legend: {
    display: 'flex',
    fontSize: '1.125rem',
    fontFamily: Fonts.body,
    opacity: 1,
    visibility: 'visible',
  },
}

// Styles for Tooltip etc. components
// Text styling shared by the tick labels of both axes
const baseTickText = {
  fill: Colors.text,
  fontSize: '1rem',
  fontFamily: Fonts.body,
}

// Plain style/prop objects for the visx tooltip and axis tick labels
const cssStyles = {
  // Floating tooltip. With `position: fixed`, its left/top are viewport coordinates.
  // `defaultStyles` must come first so the overrides below win.
  tooltip: {
    ...defaultStyles,
    position: 'fixed',
    transform: 'translate(-50%, calc(-100% + 2rem))',
    background: Colors.primaryDarker,
    color: Colors.white,
    padding: '0.75rem 1rem',
    borderRadius: '0.5rem',
    fontSize: '0.875rem',
    fontFamily: Fonts.body,
    maxWidth: '13rem',
    textAlign: 'center',
    zIndex: 1000,
    pointerEvents: 'none',
  } as React.CSSProperties,
  // Series name heading inside the tooltip
  tooltipTitle: {
    fontSize: '1.125rem',
    fontFamily: Fonts.body,
  } as React.CSSProperties,
  // Body rows of the tooltip
  tooltipContent: {
    marginTop: '0.5rem',
    fontSize: '1rem',
  } as React.CSSProperties,
  // Left-axis tick labels. dx/dy are SVG <text> attributes, not CSS properties:
  // they only take effect when applied as attributes, not inside a `style` object.
  axisTick: {
    ...baseTickText,
    textAnchor: 'end' as const,
    dx: '-0.5rem',
    dy: '0.32em',
  },
  // Bottom-axis tick labels (passed as visx tickLabelProps)
  bottomTickLabel: {
    ...baseTickText,
    textAnchor: 'middle' as const,
    dy: '0.5rem',
    dx: '0',
  },
}

/**
 * Formats a month number as at least two digits, e.g. 3 -> "03", 11 -> "11".
 */
const getZeroPaddedMonth = (monthNumber: number): string =>
  String(monthNumber).padStart(2, '0')

/*
// Format month names based on index
const getFormattedMonth = (index: number) => {
  const month = DateTime.fromObject({ month: index }).setLocale('fi').toFormat('LLLL')
  return upperFirst(month)
}
*/

/**
 * Converts a rem length to pixels using the root element's current computed font size.
 * Reads `document`, so it can only run in a browser context.
 */
const remToPixels = (rem: number): number => {
  const { fontSize } = getComputedStyle(document.documentElement)
  return rem * parseFloat(fontSize)
}

/**
 * Formats an x-axis value for display.
 * - Yearly view: slot 1 is `currentYear - 10`, slot 2 the year after, and so on.
 * - Monthly view: the value is shown as a zero-padded two-digit month number.
 */
const formatXAxisValue = (value: number, isYearly: boolean, currentYear: number) =>
  isYearly
    ? String(currentYear - 10 + (value - 1))
    : String(value).padStart(2, '0')

interface AnimatedTickProps {
  x: number
  y: number
  formattedValue?: string | number
  opacity: number
  innerHeight: number
}

/**
 * Left-axis tick label that slides up into place and fades in.
 *
 * @param x - horizontal tick position supplied by visx
 * @param y - final vertical tick position supplied by visx
 * @param formattedValue - label text; rendered as '' when missing
 * @param opacity - opacity the label fades in to
 * @param innerHeight - the slide starts `innerHeight / 2` below the final `y`
 */
const AnimatedTick: React.FC<AnimatedTickProps> = ({
  x,
  y,
  formattedValue,
  opacity,
  innerHeight,
}) => {
  const springProps = useSprings(1, [{
    from: { y: y + innerHeight / 2, opacity: 0 },
    to: { y, opacity },
    config: { tension: 120, friction: 14 },
  }])

  // dx/dy are SVG <text> attributes, not CSS properties. Inside `style` the browser
  // ignores them, so split them out and apply them as attributes instead.
  const { dx, dy, ...textStyle } = cssStyles.axisTick

  return (
    <animated.g
      transform={to([springProps[0].y], (yVal) => `translate(${x}, ${yVal})`)}
    >
      <animated.text
        dx={dx}
        dy={dy}
        style={{
          ...textStyle,
          opacity: springProps[0].opacity,
        }}
      >
        {(formattedValue ?? '').toString()}
      </animated.text>
    </animated.g>
  )
}

interface AnimatedAxisLeftProps {
  scale: AxisScale<number>
  left?: number
  top?: number
}

// Finnish-locale number formatter for y-axis labels. Created once at module load
// instead of once per tick label.
const fiNumberFormat = new Intl.NumberFormat('fi-FI')

/**
 * Left axis whose labels are formatted with the Finnish locale and animated in via
 * AnimatedTick. Tick marks and the axis line are hidden; only labels are drawn.
 *
 * @param scale - y scale; `scale.range()[0]` is used as the animation distance basis
 * @param left - horizontal offset of the axis (default 0)
 * @param top - vertical offset of the axis (default 0)
 */
const AnimatedAxisLeft: React.FC<AnimatedAxisLeftProps> = ({
  scale,
  left = 0,
  top = 0,
}) => {
  const formatTickValue = (value: number | { valueOf(): number }) =>
    fiNumberFormat.format(value.valueOf())

  return (
    <AxisLeft
      top={top}
      left={left}
      scale={scale}
      tickFormat={formatTickValue}
      tickComponent={(tickProps: TickRendererProps) => {
        const { x = 0, y = 0, formattedValue } = tickProps
        return (
          <AnimatedTick
            x={x}
            y={y}
            formattedValue={formattedValue}
            opacity={1}
            // For a y scale whose range is [height, 0], the first range value is the plot height
            innerHeight={scale.range()[0]}
          />
        )
      }}
      hideTicks
      hideAxisLine
      tickLabelProps={() => cssStyles.axisTick}
    />
  )
}

/**
 * Hover tooltip for one data point. It shows the series name, the point's date and
 * value, and the summed cost of all series at the same x (values in fi-FI format).
 */
const LineChartTooltip: React.FC<LineChartTooltipProps> = (props) => {
  const { tooltipData, tooltipLeft, tooltipTop, year, isYearly, getTotalMonthlyCost } = props
  const amountFormat = { maximumFractionDigits: 2 }

  // Yearly view: the year this x slot represents; monthly view: "MM/YYYY"
  const dateLabel = isYearly
    ? formatXAxisValue(tooltipData.x, true, year)
    : `${getZeroPaddedMonth(tooltipData.x)}/${year}`
  const pointCost = tooltipData.y.toLocaleString('fi-FI', amountFormat)
  const totalCost = getTotalMonthlyCost(tooltipData.x).toLocaleString('fi-FI', amountFormat)

  return (
    <VisxTooltip top={tooltipTop} left={tooltipLeft} style={cssStyles.tooltip}>
      <div>
        <strong style={cssStyles.tooltipTitle}>{tooltipData.id}</strong>
        <div style={cssStyles.tooltipContent}>
          {dateLabel} {pointCost} €
        </div>
        <div style={cssStyles.tooltipContent}>
          Kulut yht. {totalCost} €
        </div>
      </div>
    </VisxTooltip>
  )
}

// react-spring wrappers so spring values can be passed directly as shape props/styles
const AnimatedAreaClosed = animated(AreaClosed)
const AnimatedLinePath = animated(LinePath)

/**
 * Animated line/area chart of cost series, plotted per x slot (month numbers, or years
 * when `isYearly` is set).
 *
 * - `data[i]` is paired by index with `colors[i]` (falling back to the series' own
 *   colour) and `costTypes[i]` (falling back to the series' own `id`).
 * - Series whose values are all zero are not drawn. If no series has a non-zero value,
 *   only the grid, axes and legend are rendered, with x = 1..12 and a default y maximum
 *   of DEFAULT_Y_MAX.
 * - Hovering within 10px of a point shows a tooltip with the point's value and the sum
 *   of all drawn series at the same x; a crosshair follows the cursor.
 * - The SVG is sized from the container's measured bounding box, which is re-measured
 *   on window resize.
 */
const CostLineChart: React.FC<CostLineChartProps> = ({
  data,
  colors,
  costTypes,
  year = 2024,
  isYearly = false,
}) => {
  const [tooltipData, setTooltipData] = useState<Point | null>(null)
  const [tooltipLeft, setTooltipLeft] = useState<number>(0)
  const [tooltipTop, setTooltipTop] = useState<number>(0)
  const svgRef = useRef<SVGSVGElement | null>(null)
  const containerRef = useRef<HTMLDivElement | null>(null)
  const [width, setWidth] = useState<number>(0)
  const [height, setHeight] = useState<number>(0)
  const [dimensionsReady, setDimensionsReady] = useState<boolean>(false)

  // Crosshair state: cursor position relative to the inner plot area
  const [hoveredX, setHoveredX] = useState<number | null>(null)
  const [hoveredY, setHoveredY] = useState<number | null>(null)

  // Per-instance prefix so gradient ids do not collide when several charts share a page.
  // The SVG is only rendered after client-side measurement (dimensionsReady), so a random
  // value cannot cause a server/client markup mismatch for these ids.
  const [gradientIdPrefix] = useState(
    () => `cost-line-${Math.random().toString(36).slice(2, 10)}`
  )

  // Labels can contain spaces or other characters that are invalid in an SVG id and
  // break `url(#…)` references, so sanitize them. The index keeps ids unique.
  const getGradientId = (index: number, seriesId: string) =>
    `${gradientIdPrefix}-${index}-${String(seriesId).replace(/[^A-Za-z0-9_-]/g, '_')}`

  // Animation bookkeeping: the draw-in animation is reset on the first measured render
  // and whenever the serialized data changes.
  const isInitialMount = useRef(true);
  const prevDataRef = useRef<string>('');

  // Memoize the stringified data to prevent unnecessary updates
  const currentDataString = useMemo(() => JSON.stringify(data), [data]);

  // Measure the container before paint; re-measure on window resize
  useLayoutEffect(() => {
    const updateDimensions = () => {
      if (containerRef.current) {
        const boundingRect = containerRef.current.getBoundingClientRect()
        setWidth(boundingRect.width)
        setHeight(boundingRect.height)
        setDimensionsReady(true)
      }
    }

    updateDimensions()
    window.addEventListener('resize', updateDimensions)
    return () => window.removeEventListener('resize', updateDimensions)
  }, [])

  // Margins are specified in rem. Read the root font size once per render
  // (one getComputedStyle call instead of four).
  const rem = remToPixels(1)
  const margin = {
    top: 3.25 * rem,
    right: 3 * rem,
    bottom: 3 * rem,
    left: 5 * rem,
  }

  const innerWidth = width - margin.left - margin.right
  const innerHeight = height - margin.top - margin.bottom

  // In the yearly view the label shows the full range of the 12 x slots
  const yearLabel = useMemo(() => {
    if (isYearly) {
      return `${year - 10} - ${year + 1}`  // Shows e.g., "2014 - 2025" for year 2024
    }
    return year.toString()
  }, [year, isYearly])

  // True when there is nothing non-zero to draw (lines are hidden, grid/axes still shown)
  const allValuesZero = useMemo(() => {
    return !data ||
      data.length === 0 ||
      data.every(series =>
        !series.data ||
        series.data.length === 0 ||
        series.data.every(point => point.y === 0)
      );
  }, [data]);

  // True when no series carries any points at all
  const isDataEmpty = useMemo(() => {
    return !data || data.length === 0 || data.every(series => !series.data || series.data.length === 0);
  }, [data]);

  // Attach colour, label and a data-derived React key to each series, then drop series
  // that have no non-zero value (including series with missing `data`).
  const mappedData = useMemo<LineData[]>(() => {
    if (allValuesZero) return [];

    return data
      .map((series, index) => {
        const id = costTypes[index] ?? series.id
        return {
          ...series,
          color: colors[index] || series.color,
          id,
          key: `${id}-${JSON.stringify(series.data)}`
        }
      })
      .filter(series => Array.isArray(series.data) && series.data.some(point => point.y !== 0));
  }, [data, colors, costTypes, allValuesZero]);

  // Flatten all drawn points for tooltip hit-testing and per-x totals
  const allPoints: Point[] = useMemo<Point[]>(() => {
    if (isDataEmpty) return [];
    return mappedData.flatMap(series =>
      series.data.map(d => ({
        x: d.x,
        y: d.y,
        id: series.id,
        color: series.color,
      }))
    );
  }, [mappedData, isDataEmpty]);

  // Sorted unique x values; falls back to slots 1..12 when there is nothing to draw
  const xValues = useMemo<number[]>(() => {
    const points = Array.from({ length: 12 }, (_, i) => i + 1)
    if (isDataEmpty || allValuesZero) return points

    const allX = allPoints.map(d => d.x)
    const uniqueX = Array.from(new Set(allX))
    return uniqueX.sort((a, b) => a - b)
  }, [allPoints, isDataEmpty, allValuesZero])

  const xScale = useMemo(
    () =>
      scalePoint<number>({
        domain: xValues.length > 0 ? xValues : [0],
        range: [0, innerWidth],
        padding: 0,
        align: 0.5,
      }),
    [xValues, innerWidth]
  )

  // 10% headroom above the largest value
  const yMax = useMemo<number>(() => {
    if (allValuesZero || isDataEmpty) return DEFAULT_Y_MAX;
    const maxY = Math.max(...allPoints.map(d => d.y), 0);
    return maxY > 0 ? maxY * 1.1 : DEFAULT_Y_MAX;
  }, [allPoints, allValuesZero, isDataEmpty]);

  const yScale = useMemo(
    () =>
      scaleLinear<number>({
        domain: [0, yMax],
        range: [innerHeight, 0],
        nice: true,
      }),
    [yMax, innerHeight]
  )

  const legendScale = useMemo(
    () =>
      scaleOrdinal<string, string>({
        domain: costTypes,
        range: colors,
      }),
    [costTypes, colors]
  )

  // Polyline length of a series in pixels; drives the stroke-dash draw-in animation
  const getPathLength = useCallback((seriesData: PointData[]): number => {
    if (!dimensionsReady) return 0;
    let length = 0;
    for (let i = 1; i < seriesData.length; i++) {
      const x1 = xScale(seriesData[i - 1].x)!;
      const y1 = yScale(seriesData[i - 1].y);
      const x2 = xScale(seriesData[i].x)!;
      const y2 = yScale(seriesData[i].y);
      const dx = x2 - x1;
      const dy = y2 - y1;
      length += Math.sqrt(dx * dx + dy * dy);
    }
    return length;
  }, [dimensionsReady, xScale, yScale]);

  const pathLengths = useMemo(() =>
    mappedData.map(series => getPathLength(series.data)),
    [mappedData, getPathLength]
  );

  // One spring per series: pathLength animates from the full length down to 0
  // (used as strokeDashoffset) and is reset on first render / data change.
  const springs = useSprings(
    mappedData.length,
    mappedData.map((_, i) => ({
      from: { pathLength: pathLengths[i] },
      to: { pathLength: 0 },
      reset: isInitialMount.current || currentDataString !== prevDataRef.current,
      config: { tension: 80, friction: 20 },
    }))
  );

  // Record that the current data has been animated, once dimensions are known
  useLayoutEffect(() => {
    if (!dimensionsReady) return;

    if (isInitialMount.current || currentDataString !== prevDataRef.current) {
      prevDataRef.current = currentDataString;
      isInitialMount.current = false;
    }
  }, [dimensionsReady, currentDataString]);

  // Loosely typed accessors. NOTE(review): presumably `any` because the animated()
  // wrappers do not carry the shapes' Datum generic. Confirm before tightening.
  const getScaledX = (d: any) => xScale(d.x)!;
  const getScaledY = (d: any) => yScale(d.y);

  // Sum of all drawn series at the given x
  const getTotalMonthlyCost = (x: number): number => {
    return allPoints
      .filter(point => point.x === x)
      .reduce((acc, point) => acc + point.y, 0)
  }

  // Updates the crosshair and shows the tooltip for the nearest point within 10px
  const handleTooltip = (event: React.MouseEvent<SVGRectElement, MouseEvent>) => {
    const point = localPoint(event)
    if (!point || !svgRef.current) {
      setTooltipData(null)
      setHoveredX(null)
      setHoveredY(null)
      return
    }

    // Mouse position relative to the inner chart area
    const svgRect = svgRef.current.getBoundingClientRect()
    const xPos = point.x - margin.left
    const yPos = point.y - margin.top

    setHoveredX(xPos)
    setHoveredY(yPos)

    // Find the closest point to the cursor
    const closest = allPoints.reduce<{ point: Point | null; distance: number }>(
      (prev, current) => {
        const cx = xScale(current.x)!
        const cy = yScale(current.y)
        const dx = cx - xPos
        const dy = cy - yPos
        const distance = Math.sqrt(dx * dx + dy * dy)
        if (distance < prev.distance) {
          return { point: current, distance }
        }
        return prev
      },
      { point: null, distance: Infinity }
    )

    const distanceThreshold = 10

    if (closest.point && closest.distance <= distanceThreshold) {
      setTooltipData(closest.point)

      // Viewport coordinates for the fixed-position tooltip.
      // NOTE(review): unlike the left coordinate, margin.top is not added here. Together with
      // the tooltip's `calc(-100% + 2rem)` translate this appears to be a tuned offset. Confirm
      // before changing.
      const absoluteLeft = svgRect.left + margin.left + xScale(closest.point.x)!
      const absoluteTop = svgRect.top + yScale(closest.point.y)

      setTooltipLeft(absoluteLeft)
      setTooltipTop(absoluteTop)
    } else {
      setTooltipData(null)
    }
  }

  return (
    <Box ref={containerRef} sx={muiStyles.container}>
      <Box sx={muiStyles.euroSign}>€</Box>
      <Box sx={muiStyles.yearLabel}>{yearLabel}</Box>
      {dimensionsReady ? (
        <svg
          ref={svgRef}
          width={width}
          height={height}
          role='img'
          aria-label='Cost Line Chart'
          style={{ overflow: 'visible' }}
        >
          <g transform={`translate(${margin.left},${margin.top})`}>
            <GridRows
              scale={yScale}
              width={innerWidth}
              height={innerHeight}
              stroke='rgba(0,0,0,0.1)'
              strokeDasharray='2,2'
            />
            <GridColumns
              scale={xScale}
              width={innerWidth}
              height={innerHeight}
              stroke='rgba(0,0,0,0.1)'
              strokeDasharray='2,2'
            />
            {!allValuesZero && !isDataEmpty && (
              <>
                <defs>
                  {mappedData.map((series, index) => (
                    <linearGradient
                      key={getGradientId(index, series.id)}
                      id={getGradientId(index, series.id)}
                      gradientUnits='userSpaceOnUse'
                      x1={0}
                      y1={yScale(0)}
                      x2={0}
                      y2={yScale(yMax)}
                    >
                      <stop offset='0%' stopColor={series.color} stopOpacity={0.9} />
                      <stop offset='100%' stopColor={series.color} stopOpacity={0.1} />
                    </linearGradient>
                  ))}
                </defs>

                {springs.map((props, index) => {
                  const series = mappedData[index];
                  const pathLength = pathLengths[index];

                  return (
                    <animated.g key={series.key}>
                      <AnimatedAreaClosed
                        data={series.data}
                        x={getScaledX}
                        y={getScaledY}
                        yScale={yScale}
                        fill={`url(#${getGradientId(index, series.id)})`}
                        stroke="none"
                        curve={curveLinear}
                        style={{
                          // A single-point series has zero length; avoid 0/0 = NaN opacity
                          opacity: props.pathLength.to(pl => (pathLength > 0 ? 1 - pl / pathLength : 1))
                        }}
                      />
                      <AnimatedLinePath
                        data={series.data}
                        x={getScaledX}
                        y={getScaledY}
                        stroke={series.color}
                        strokeWidth={2}
                        curve={curveLinear}
                        style={{
                          strokeDasharray: pathLength,
                          strokeDashoffset: props.pathLength
                        }}
                      />
                    </animated.g>
                  );
                })}

                <g>
                  <rect
                    width={innerWidth}
                    height={innerHeight}
                    fill='transparent'
                    onMouseMove={handleTooltip}
                    onMouseLeave={() => {
                      setTooltipData(null);
                      setHoveredX(null);
                      setHoveredY(null);
                    }}
                  />
                  {hoveredX !== null && hoveredY !== null && (
                    <>
                      <line
                        x1={hoveredX}
                        x2={hoveredX}
                        y1={0}
                        y2={innerHeight}
                        stroke='rgba(0,0,0,0.3)'
                        strokeDasharray='2,2'
                        pointerEvents='none'
                      />
                      <line
                        x1={0}
                        x2={innerWidth}
                        y1={hoveredY}
                        y2={hoveredY}
                        stroke='rgba(0,0,0,0.3)'
                        strokeDasharray='2,2'
                        pointerEvents='none'
                      />
                    </>
                  )}
                </g>
              </>
            )}

            {/* Axes */}
            <AxisBottom
              top={innerHeight}
              scale={xScale}
              tickFormat={(value: any) => formatXAxisValue(value, isYearly, year)}
              hideAxisLine
              hideTicks
              tickLabelProps={cssStyles.bottomTickLabel}
            />
            <AnimatedAxisLeft scale={yScale} />
          </g>
        </svg>
      ) : null}
      {dimensionsReady && (
        <LegendOrdinal
          scale={legendScale}
          direction='row'
          labelMargin='0 30px 0 0'
          shape='circle'
          shapeMargin='0 5px 0 0'
          labelFormat={(label: string) => label}
          style={muiStyles.legend as React.CSSProperties}
        />
      )}
      {tooltipData && !allValuesZero && !isDataEmpty && (
        <LineChartTooltip
          tooltipData={tooltipData}
          tooltipLeft={tooltipLeft}
          tooltipTop={tooltipTop}
          year={year}
          isYearly={isYearly}
          getTotalMonthlyCost={getTotalMonthlyCost}
        />
      )}
    </Box>
  )
}

export default CostLineChart
