import {
  CalculationSliceType,
  Change,
  EdgeTableMeasure,
  MeasureElement,
  NodeTableMeasure,
} from "../CalculationGraph/calculations";
import {
  maskToUnit,
  matchSign,
  normaliseByUnit,
  normaliseToMasks,
  roundIfMask,
} from "../utils/unitNormalisation";
import { shortUUID } from "../utils/uuid";

// A table to visit during traversal. `relationId` is null for the entries
// produced by getTraversableTables; findRelationChanges treats a null
// relationId as "no relation to evaluate".
// NOTE(review): presumably a non-null relationId is the relation through which
// the table was reached — confirm against the traversal code that consumes this.
type TraversalEntry = {
  tableId: string;
  relationId: string | null;
};

// Changes keyed by change id. Every function in this file stores each change
// under its own `id`.
type ChangeDict = { [changeId: string]: Change };

/**
 * Lists the tables that can start a traversal: every table in `tableOrder`
 * that is not referenced by the `top` side of any relation.
 *
 * @param tableOrder Table ids, in the order the result should keep.
 * @param relations  All relations; only each relation's `top` entries are read.
 * @returns One entry per qualifying table, each with `relationId: null`.
 */
function getTraversableTables(
  tableOrder: CalculationSliceType["tableOrder"],
  relations: CalculationSliceType["relations"]
): TraversalEntry[] {
  const entries: TraversalEntry[] = [];

  for (const tableId of tableOrder) {
    // A table is fed by a relation when that relation lists it in `top`.
    const isFedByRelation = Object.values(relations).some((relation) =>
      relation.top.some((entry) => entry.tableId === tableId)
    );

    if (!isFedByRelation) {
      entries.push({ tableId, relationId: null });
    }
  }

  return entries;
}

/**
 * Propagates values across a table relation for one size.
 *
 * The last measure of every table in `relation.bottom` (leading into the
 * relation) is converted to masks, weighted by its scalar and summed. Each
 * table in `relation.top` (leading out of the relation) should then hold its
 * scalar-weighted share of that sum in its first measure; a change is emitted
 * for every top element whose current value differs from that share.
 *
 * @param state      Calculation slice holding relations, tables, graph and gauge.
 * @param relationId Relation to evaluate; `null` yields an empty result.
 * @param size       Key used to pick the node of each measure.
 * @returns Changes keyed by change id (possibly empty).
 */
function findRelationChanges(
  state: CalculationSliceType,
  relationId: string | null,
  size: string
): ChangeDict {
  const changes: ChangeDict = {};
  if (relationId === null) return changes;

  const relation = state.relations[relationId];

  // The gauge to use follows the direction of the element being converted.
  const gaugeFor = (isHorizontal: boolean) =>
    isHorizontal ? state.gauge.horizontal : state.gauge.vertical;

  // Total of everything leading into the relation, in masks.
  let bottomStitchesSum = 0;
  for (const tableReference of relation.bottom) {
    const bottomTable = state.tables[tableReference.tableId];
    const lastMeasureId = bottomTable.order[bottomTable.order.length - 1];
    const lastMeasure = bottomTable.measures[lastMeasureId] as NodeTableMeasure;

    const bottomNode = state.graph.nodes[lastMeasure.nodes[size]];
    const bottomElement = state.graph.elements[bottomNode.element];

    bottomStitchesSum +=
      normaliseToMasks(
        bottomElement,
        gaugeFor(bottomElement.direction === "Horizontal")
      ) * tableReference.scalar;
  }

  // Sum of the outgoing scalars, used to split the total proportionally.
  let topScalarSum = 0;
  for (const topEntry of relation.top) {
    topScalarSum += topEntry.scalar;
  }

  for (const relationEntry of relation.top) {
    // relation -> table -> first measure -> node -> element
    const topTable = state.tables[relationEntry.tableId];
    const firstMeasure = topTable.measures[
      topTable.order[0]
    ] as NodeTableMeasure;
    const topNode = state.graph.nodes[firstMeasure.nodes[size]];
    const element = state.graph.elements[topNode.element];

    // This table's share of the mask total, expressed in the element's unit.
    const valueFromRelation =
      (maskToUnit(
        bottomStitchesSum,
        gaugeFor(element.direction === "Horizontal"),
        element.unit
      ) *
        relationEntry.scalar) /
      topScalarSum;

    // Already in agreement with the relation: nothing to record.
    if (valueFromRelation === element.value) continue;

    const id = shortUUID();
    changes[id] = {
      id,
      elementId: element.id,
      newValue: valueFromRelation,
      source: {
        kind: "ForwardTableRelationChange",
        relationId: relation.id,
        bottomSum: bottomStitchesSum,
        scalar: relationEntry.scalar,
      },
    };
  }

  return changes;
}

/**
 * Checks a change that happens within a single round for internal
 * consistency and returns the corrections needed.
 *
 * Two values are re-derived, each in the unit of the element it belongs to:
 *  - the horizontal change: total distance / distance between changes
 *  - the distance between changes ("frequency"): total distance / horizontal change
 *
 * The quotients are unitless, so they are treated as a mask count and
 * converted to the target unit with the horizontal gauge. For decreases the
 * node difference is added twice to the total first, because every decrease
 * consumes two masks.
 *
 * @param elements Current node, following node, horizontal change and
 *                 distance-between-changes elements.
 * @param gauge    Only `gauge.horizontal` is read.
 * @returns Up to two changes (horizontal, then frequency), keyed by change id.
 */
function findRoundChanges(
  elements: {
    node: MeasureElement;
    nextNode: MeasureElement;
    horizontal: MeasureElement;
    frequency: MeasureElement;
  },
  gauge: CalculationSliceType["gauge"]
): ChangeDict {
  const { node, nextNode, horizontal, frequency } = elements;

  // dividend / divisor as a mask count, converted to `unit` and rounded when
  // that unit is masks.
  const roundedQuotient = (
    dividend: number,
    divisor: number,
    unit: MeasureElement["unit"]
  ) => roundIfMask(maskToUnit(dividend / divisor, gauge.horizontal, unit), unit);

  // --- Horizontal change, everything expressed in the horizontal element's unit
  const nodeInHorizontalUnit = normaliseByUnit(
    node,
    gauge.horizontal,
    horizontal.unit
  );
  const frequencyInHorizontalUnit = normaliseByUnit(
    frequency,
    gauge.horizontal,
    horizontal.unit
  );
  const nextNodeInHorizontalUnit = normaliseByUnit(
    nextNode,
    gauge.horizontal,
    horizontal.unit
  );
  const differenceInHorizontalUnit =
    nextNodeInHorizontalUnit - nodeInHorizontalUnit;

  // Stays 0 when the two nodes are equal.
  let newHorizontalValue = 0;
  if (differenceInHorizontalUnit > 0) {
    // Increase
    newHorizontalValue = Math.abs(
      roundedQuotient(
        nodeInHorizontalUnit,
        frequencyInHorizontalUnit,
        horizontal.unit
      )
    );
  } else if (differenceInHorizontalUnit < 0) {
    // Decrease: two masks per decrease, so node / frequency alone is not correct.
    newHorizontalValue = matchSign(
      roundedQuotient(
        nodeInHorizontalUnit + 2 * differenceInHorizontalUnit,
        frequencyInHorizontalUnit,
        horizontal.unit
      ),
      differenceInHorizontalUnit
    );
  }

  // --- Distance between changes, everything expressed in the frequency element's unit
  const nodeInFrequencyUnit = normaliseByUnit(
    node,
    gauge.horizontal,
    frequency.unit
  );
  const nextNodeInFrequencyUnit = normaliseByUnit(
    nextNode,
    gauge.horizontal,
    frequency.unit
  );
  const horizontalInFrequencyUnit = normaliseByUnit(
    horizontal,
    gauge.horizontal,
    frequency.unit
  );
  const differenceInFrequencyUnit =
    nextNodeInFrequencyUnit - nodeInFrequencyUnit;

  // Stays 0 when the two nodes are equal.
  let newFrequencyValue = 0;
  if (differenceInFrequencyUnit > 0) {
    // Increase
    newFrequencyValue = Math.abs(
      roundedQuotient(
        nodeInFrequencyUnit,
        horizontalInFrequencyUnit,
        frequency.unit
      )
    );
  } else if (differenceInFrequencyUnit < 0) {
    // Decrease: remove the difference twice, since each decrease uses two
    // masks and ends up with one.
    newFrequencyValue = Math.abs(
      roundedQuotient(
        nodeInFrequencyUnit + 2 * differenceInFrequencyUnit,
        horizontalInFrequencyUnit,
        frequency.unit
      )
    );
  }

  // --- Record whatever disagrees with the stored values
  const changes: ChangeDict = {};

  if (newHorizontalValue !== horizontal.value) {
    const id = shortUUID();
    changes[id] = {
      id,
      elementId: horizontal.id,
      newValue: newHorizontalValue,
      source: {
        kind: "RoundTotalChange",
        perChangeDistance: frequency.id,
        totalDistance: node.id,
      },
    };
  }

  if (newFrequencyValue !== frequency.value) {
    const id = shortUUID();
    changes[id] = {
      id,
      elementId: frequency.id,
      newValue: newFrequencyValue,
      source: {
        kind: "RoundDistanceChange",
        totalDistance: node.id,
        horizontalChange: horizontal.id,
      },
    };
  }

  return changes;
}

/**
 * Checks a vertical edge change for internal consistency with the table's
 * marks and returns the corrections needed.
 *
 * The three quantities are tied together by (with "distance" being the
 * `frequency` element, i.e. the distance between each instance of change):
 *
 *   horizontal = (vertical / distance) * 2 * marks
 *   distance   = (vertical / horizontal) * 2 * marks
 *   vertical   = distance * horizontal / (2 * marks)
 *
 * Both `vertical` and `frequency` are vertical distances, so both are
 * converted with the vertical gauge in all three equations; only the
 * horizontal change and the nodes use the horizontal gauge. A table with 0
 * marks is treated as having 1 mark in all three equations and in the
 * reported `perChangeMasks`, so the equations stay mutually consistent.
 *
 * @param elements   Current node, following node, horizontal change,
 *                   distance between changes and total vertical distance.
 * @param gauge      Horizontal and vertical gauge.
 * @param tableMarks Number of marks on the table; 0 is treated as 1.
 * @returns Up to three changes (horizontal, frequency, vertical), keyed by
 *          change id.
 */
function findVerticalChanges(
  elements: {
    node: MeasureElement;
    nextNode: MeasureElement;
    horizontal: MeasureElement;
    frequency: MeasureElement;
    vertical: MeasureElement;
  },
  gauge: CalculationSliceType["gauge"],
  tableMarks: number
): ChangeDict {
  //* The variable "frequency" is really the distance between each instance of change.
  // TODO: Change the name of the variable

  // Why normalise vertical in distance and not distance in vertical? Because of unit cancellation.
  // If the vertical is in cm, we end up with a number with unit cm/m, which is off from the
  // correct answer by a ratio of the gauge if we try to convert it back to cm or m.
  // m/m is unitless, but since the unit is 1m, that works fine. Hence we can convert vertical
  // also and always end up with masks. Whereas cm * m => cm without any further normalisation
  // because m is the unit, and we don't need to convert.

  // Vertical changes with 0 marks are treated as 1 mark; e.g. the change happens
  // at the start and end of a row (non-cylindrical). Applied to every equation:
  // multiplying by 0 in two of them while dividing by 1 in the third would make
  // the three results contradict each other.
  const effectiveMarks = tableMarks || 1;
  const perChangeMasks = effectiveMarks * 2;

  // Horizontal value
  const verticalValueWithHorizontalUnit = normaliseByUnit(
    elements.vertical,
    gauge.vertical,
    elements.horizontal.unit
  );
  // The distance between changes is a vertical distance, so it must be
  // converted with the vertical gauge, like `vertical` above. Using the
  // horizontal gauge here skews the ratio by gauge.vertical / gauge.horizontal.
  const frequencyValueWithHorizontalUnit = normaliseByUnit(
    elements.frequency,
    gauge.vertical,
    elements.horizontal.unit
  );
  // Data integrity check: the horizontal change takes its sign from the node
  // difference (distance and vertical are real distances, hence positive).
  const nodeDifference =
    normaliseToMasks(elements.nextNode, gauge.horizontal) -
    normaliseToMasks(elements.node, gauge.horizontal);

  const newHorizontalValue = matchSign(
    roundIfMask(
      maskToUnit(
        (verticalValueWithHorizontalUnit / frequencyValueWithHorizontalUnit) *
          effectiveMarks *
          2,
        gauge.horizontal,
        elements.horizontal.unit
      ),
      elements.horizontal.unit
    ),
    nodeDifference
  );

  // Frequency value
  const verticalValueAsMasks = normaliseToMasks(
    elements.vertical,
    gauge.vertical
  );

  const newFrequencyValue = Math.abs(
    roundIfMask(
      maskToUnit(
        (verticalValueAsMasks /
          normaliseToMasks(elements.horizontal, gauge.horizontal)) *
          effectiveMarks *
          2,
        gauge.vertical,
        elements.frequency.unit
      ),
      elements.frequency.unit
    )
  );

  // Vertical value
  const frequencyValueWithVerticalUnit = normaliseByUnit(
    elements.frequency,
    gauge.vertical,
    elements.vertical.unit
  );

  const newVerticalValue = Math.abs(
    roundIfMask(
      (normaliseToMasks(elements.horizontal, gauge.horizontal) *
        frequencyValueWithVerticalUnit) /
        effectiveMarks /
        2,
      elements.vertical.unit
    )
  );

  const changes: ChangeDict = {};

  if (newHorizontalValue !== elements.horizontal.value) {
    const changeId = shortUUID();
    changes[changeId] = {
      id: changeId,
      elementId: elements.horizontal.id,
      newValue: newHorizontalValue,
      source: {
        kind: "VerticalTotalChange",
        perChangeMasks,
        totalDistance: elements.vertical.id,
        perChangeDistance: elements.frequency.id,
      },
    };
  }

  if (newFrequencyValue !== elements.frequency.value) {
    const changeId = shortUUID();

    changes[changeId] = {
      id: changeId,
      elementId: elements.frequency.id,
      newValue: newFrequencyValue,
      source: {
        kind: "VerticalDistanceChange",
        perChangeMasks,
        totalDistance: elements.vertical.id,
        horizontalChange: elements.horizontal.id,
      },
    };
  }

  if (newVerticalValue !== elements.vertical.value) {
    const changeId = shortUUID();

    changes[changeId] = {
      id: changeId,
      elementId: elements.vertical.id,
      newValue: newVerticalValue,
      source: {
        kind: "VerticalTotalDistanceChange",
        perChangeMasks,
        horizontalChange: elements.horizontal.id,
        perChangeDistance: elements.frequency.id,
      },
    };
  }

  return changes;
}

/**
 * Re-derives the values of a sleeve bind-off edge and returns the
 * corrections needed.
 *
 *  - horizontal = nextNode - node, in the horizontal element's unit
 *  - frequency  = nextNode / marks, in the frequency element's unit
 *
 * A table with 0 marks is treated as having 1 mark to avoid dividing by zero.
 * `elements.vertical` is accepted but not read.
 *
 * @param elements   Elements of the edge and its two surrounding nodes.
 * @param gauge      Only `gauge.horizontal` is read.
 * @param tableMarks Number of marks on the table; 0 is treated as 1.
 * @returns Up to two changes (horizontal, then frequency), keyed by change id.
 */
function findSleeveBindOffChanges(
  elements: {
    node: MeasureElement;
    nextNode: MeasureElement;
    horizontal: MeasureElement;
    frequency: MeasureElement;
    vertical: MeasureElement;
  },
  gauge: CalculationSliceType["gauge"],
  tableMarks: number
): ChangeDict {
  const { node, nextNode, horizontal, frequency } = elements;

  // Avoid dividing by zero
  const marks = tableMarks || 1;

  const nextNodeInHorizontalUnit = normaliseByUnit(
    nextNode,
    gauge.horizontal,
    horizontal.unit
  );
  const nodeInHorizontalUnit = normaliseByUnit(
    node,
    gauge.horizontal,
    horizontal.unit
  );
  const newHorizontalValue = nextNodeInHorizontalUnit - nodeInHorizontalUnit;

  const nextNodeInFrequencyUnit = normaliseByUnit(
    nextNode,
    gauge.horizontal,
    frequency.unit
  );
  const newFrequencyValue = nextNodeInFrequencyUnit / marks;

  const changes: ChangeDict = {};

  if (newHorizontalValue !== horizontal.value) {
    const id = shortUUID();
    changes[id] = {
      id,
      elementId: horizontal.id,
      newValue: newHorizontalValue,
      source: {
        kind: "SleeveBindOffTotalChange",
        parentNode: node.id,
        childNode: nextNode.id,
      },
    };
  }

  if (newFrequencyValue !== frequency.value) {
    const id = shortUUID();
    changes[id] = {
      id,
      elementId: frequency.id,
      newValue: newFrequencyValue,
      source: {
        kind: "SleeveBindOffDistanceBetweenMarksChange",
        marks,
        totalDistance: nextNode.id,
      },
    };
  }

  return changes;
}

/**
 * Finds the changes needed to make an edge measure internally consistent by
 * dispatching on the measure's change kind:
 *
 *  - "HorizontalFirst" / "HorizontalLast" -> findRoundChanges (does not take `tableMarks`)
 *  - "Vertical"                           -> findVerticalChanges
 *  - "SleeveBindOff"                      -> findSleeveBindOffChanges
 *
 * @param elements   Elements of the edge and its two surrounding nodes.
 * @param gauge      Gauge passed through to the selected solver.
 * @param tableMarks Number of marks on the table; forwarded to the vertical
 *                   and sleeve bind-off solvers only.
 * @param changeKind Change kind of the edge measure being checked.
 * @returns The changes reported by the selected solver, keyed by change id.
 */
function findEdgeChanges(
  elements: {
    node: MeasureElement;
    nextNode: MeasureElement;
    horizontal: MeasureElement;
    frequency: MeasureElement;
    vertical: MeasureElement;
  },
  gauge: CalculationSliceType["gauge"],
  tableMarks: number,
  changeKind: EdgeTableMeasure["changeKind"]
): ChangeDict {
  // There is no default branch: any other runtime value falls through and the
  // function returns undefined.
  // NOTE(review): presumably these four literals are every member of
  // EdgeTableMeasure["changeKind"] — confirm against the type declaration.
  switch (changeKind) {
    case "HorizontalFirst":
    case "HorizontalLast":
      return findRoundChanges(elements, gauge);
    case "Vertical":
      return findVerticalChanges(elements, gauge, tableMarks);
    case "SleeveBindOff":
      return findSleeveBindOffChanges(elements, gauge, tableMarks);
  }
}

/**
 * Checks that a node, the following node and the horizontal edge change
 * between them agree, and returns the corrections needed.
 *
 * Each of the three values is re-derived from the other two, in its own unit
 * and using the horizontal gauge:
 *
 *   node     = nextNode - edge
 *   edge     = nextNode - node
 *   nextNode = node + edge
 *
 * @param elements Current node, following node and the horizontal edge change.
 * @param gauge    Only `gauge.horizontal` is read.
 * @returns Up to three changes (node, edge, nextNode — in that order), keyed
 *          by change id.
 */
function findNodeChanges(
  elements: {
    node: MeasureElement;
    nextNode: MeasureElement;
    horizontal: MeasureElement;
  },
  gauge: CalculationSliceType["gauge"]
): ChangeDict {
  const { node, nextNode, horizontal } = elements;

  // Value of `element` expressed in `unit`.
  const inUnit = (element: MeasureElement, unit: MeasureElement["unit"]) =>
    normaliseByUnit(element, gauge.horizontal, unit);

  const newNodeValue =
    inUnit(nextNode, node.unit) - inUnit(horizontal, node.unit);
  const newEdgeValue =
    inUnit(nextNode, horizontal.unit) - inUnit(node, horizontal.unit);
  const newNextNodeValue =
    inUnit(node, nextNode.unit) + inUnit(horizontal, nextNode.unit);

  const changes: ChangeDict = {};

  // Backward: the node as implied by the following node and the edge.
  if (newNodeValue !== node.value) {
    const id = shortUUID();
    changes[id] = {
      id,
      elementId: node.id,
      newValue: newNodeValue,
      source: {
        kind: "BackwardEdgeChange",
        edgeChange: horizontal.id,
        childNode: nextNode.id,
      },
    };
  }

  // The edge as implied by the two nodes.
  if (newEdgeValue !== horizontal.value) {
    const id = shortUUID();
    changes[id] = {
      id,
      elementId: horizontal.id,
      newValue: newEdgeValue,
      source: {
        kind: "NodeDifference",
        parentNode: node.id,
        childNode: nextNode.id,
      },
    };
  }

  // Forward: the following node as implied by the node and the edge.
  if (newNextNodeValue !== nextNode.value) {
    const id = shortUUID();
    changes[id] = {
      id,
      elementId: nextNode.id,
      newValue: newNextNodeValue,
      source: {
        kind: "ForwardEdgeChange",
        edgeChange: horizontal.id,
        parentNode: node.id,
      },
    };
  }

  return changes;
}

export {
  getTraversableTables,
  findRelationChanges,
  findEdgeChanges,
  findNodeChanges,
};
export type { TraversalEntry };
