import { MetricName } from "../../graphql";
import { MetricConfig } from "./MetricConfig";
import { DataPoint } from "./types";

/**
 * Calculates weighted average, works with or without segments.
 */
/**
 * Weighted mean of `value`, using each point's `countDataPoints` as its weight.
 *
 * Works whether or not the points are segmented. Points whose `value` is
 * null/undefined are skipped entirely, so they add to neither the numerator
 * nor the total weight. A missing `countDataPoints` counts as a weight of 0.
 *
 * @param dataPoints - Points to average.
 * @returns The weighted mean. The divisor is clamped to at least 1, so an
 *   empty input (or no usable points) yields 0 rather than NaN.
 */
const calculateNonCountAverage = (dataPoints: DataPoint[]): number => {
  let weightSum = 0;
  let weightedValueSum = 0;

  for (const point of dataPoints) {
    const value = point.value;
    if (value == null) {
      continue;
    }
    const weight = point.countDataPoints ?? 0;
    weightSum += weight;
    weightedValueSum += value * weight;
  }

  const divisor = Math.max(weightSum, 1);
  return weightedValueSum / divisor;
};

/**
 * Calculates unsegmented average.
 */
/**
 * Averages a count metric over distinct `dataId`s.
 *
 * Sums every defined (non-null, non-undefined) `value` and divides by the
 * number of distinct `dataId`s among those points. Several points that share a
 * `dataId` are therefore added together before averaging.
 *
 * @param dataPoints - Points to average. Entries with a null/undefined `value`
 *   are ignored, both for the total and for the distinct-`dataId` count.
 * @returns The per-`dataId` mean, or 0 when no point has a defined value. This
 *   matches the weighted-average path and avoids returning NaN from 0 / 0.
 */
const calculateCountAverage = (dataPoints: DataPoint[]): number => {
  const definedValues = dataPoints.filter(
    (d) => d.value !== undefined && d.value !== null
  );
  const numXValues = new Set(definedValues.map((v) => v.dataId)).size;
  // With no defined values there is nothing to average; without this guard
  // the division below would be 0 / 0 = NaN.
  if (numXValues === 0) {
    return 0;
  }
  const metricTotal = definedValues.reduce((sum, d) => sum + (d.value ?? 0), 0);
  return metricTotal / numXValues;
};

/**
 * Averages `dataPoints` using the strategy configured for `metric`.
 *
 * When `MetricConfig[metric].countMetric` is truthy, the values are summed and
 * divided by the number of distinct `dataId`s. Otherwise a mean weighted by
 * `countDataPoints` is used.
 *
 * @param dataPoints - Points to average. Null/undefined values are ignored.
 * @param metric - Selects the averaging strategy via `MetricConfig`.
 * @returns The computed average.
 */
export const calculateAverage = (
  dataPoints: DataPoint[],
  metric: MetricName
): number => {
  const { countMetric } = MetricConfig[metric];
  if (countMetric) {
    return calculateCountAverage(dataPoints);
  }
  return calculateNonCountAverage(dataPoints);
};
