import { identity, mapValues, size } from "lodash";

import { dfs, keyOf } from "./graph";
import {
  ConnectedNode,
  DirectedGraph,
  Direction,
  Node,
  Reducer,
  Reducers,
  isNode,
} from "./types";

/** Mirrors `A` key-for-key: each aggregate key maps to the finalized value
 *  produced by the reducer registered under that key.
 */
type Aggregates<A extends object> = { [Key in keyof A]: A[Key] };

/** A connected graph node paired with the aggregates computed over the
 *  sub-graph reachable from it.
 */
export type AggregatedNode<
  Graph extends object,
  Kind extends keyof Graph,
  Aggs extends object
> = ConnectedNode<Graph, Kind> & { aggregates: Aggregates<Aggs> };

/** A graph in which every node carries its `aggregates`. */
export type AggregatedGraph<Graph extends object, Aggs extends object> = {
  nodes: Array<AggregatedNode<Graph, keyof Graph, Aggs>>;
};

/** Aggregates the parent nodes of a directed graph
 *
 * Aggregates both children and parents of the selected nodes.
 *
 * For every node in `graph.nodes`, each reducer in `reducers` is fed the
 * node's own value, then the values accumulated by a DFS over its children,
 * then the values accumulated by a DFS over its parents. The finalized result
 * is stored under the reducer's key in the node's `aggregates`.
 *
 * @param graph - graph whose nodes are aggregated
 * @param reducers - one reducer per aggregate key
 * @returns one entry per input node: the node's own fields plus `aggregates`
 */
export const aggregate = <G extends object, A extends object>(
  graph: DirectedGraph<G>,
  reducers: Reducers<G, A>
): AggregatedGraph<G, A> => {
  // Runs `dfs` in one direction. Presumably the result maps each node key to
  // that node's combined memos (one per reducer key) — confirm against `dfs`.
  const dfsValues = (direction: Direction) =>
    dfs(graph, direction, {
      // Seeds a fresh memo per reducer with the node's own value.
      init: (node) =>
        mapValues(reducers, (r) => {
          const memo = r.initialize();
          r.reduce(memo, r.toValue(node));
          return memo;
        }),
      // Folds every value spanned from `right` into `left`'s memos; `left` is
      // mutated in place and returned.
      combine: (left, right) => {
        for (const [agg, reducer] of Object.entries(reducers)) {
          const aggKey = agg as keyof A;
          const memo = left[aggKey];
          // Object.entries loses the per-key reducer typing, hence the cast.
          const redux = reducer as Reducer<G, any, any, A[keyof A]>;
          for (const value of redux.span(right[aggKey])) {
            redux.reduce(memo, value);
          }
        }
        return left;
      },
    });
  const nodes: AggregatedNode<G, keyof G, A>[] = [];
  const childDfs = dfsValues("children");
  const parentDfs = dfsValues("parents");
  for (const parent of graph.nodes) {
    const childValues = childDfs[keyOf(parent)];
    const parentValues = parentDfs[keyOf(parent)];
    const aggregates = {} as Aggregates<A>;
    // `init` builds one memo per reducer key (mapValues over `reducers`), so
    // the keys of `childValues` enumerate every reducer.
    for (const key of Object.keys(childValues)) {
      const aggKey = key as keyof A;
      const redux = reducers[aggKey];
      const memo = redux.initialize();
      // The node's own value is reduced first. For `paint`, this is the only
      // reduce here that can set the fresh memo's `newColor`, because spanned
      // values always arrive with `isPaintNode: false`.
      // NOTE(review): if `dfs` includes each node's `init` memo in its own
      // result, this value has already been reduced into both `childValues`
      // and `parentValues`, so it is seen three times. Idempotent reducers
      // (`distinct`, `max`) are unaffected; additive ones (`sum`, `array`)
      // would over-count — confirm whether that is intended.
      redux.reduce(memo, redux.toValue(parent));
      for (const value of redux.span(childValues[aggKey])) {
        redux.reduce(memo, value);
      }
      for (const value of redux.span(parentValues[aggKey])) {
        redux.reduce(memo, value);
      }
      aggregates[aggKey] = redux.finalize(memo);
    }
    nodes.push({ ...parent, aggregates });
  }
  return { nodes };
};

/** Aggregates the keys of all nodes of a given type, reachable from the
 *  parent node, into an array.
 *
 *  Keys are appended in the order values are reduced; duplicates are kept
 *  (use `distinct` to de-duplicate). `finalize` returns the memo array itself.
 *
 *  @param tpe - node type whose keys are collected; other nodes contribute nothing
 */
export const array = <G extends object, K extends keyof G>(
  tpe: K
): Reducer<G, string[], string | undefined, string[]> => ({
  finalize: identity,
  initialize: () => [],
  reduce: (memo, value) => {
    // `undefined` marks a node of another type. An empty-string key is still
    // a key, so a truthiness check would wrongly drop it.
    if (value !== undefined) memo.push(value);
  },
  span: identity,
  toValue: (node) => (isNode(tpe)(node) ? node.key : undefined),
});

/** Aggregates values extracted from all nodes of a given type, reachable
 *  from the parent node, de-duplicated by `key`.
 *
 *  @param tpe - node type to collect; nodes of other types contribute nothing
 *  @param extract - pulls the value to collect from a matching node; returning
 *    `undefined` (or `null`) skips the node
 *  @param key - identity used for de-duplication; a later value with an
 *    existing key replaces the stored value but keeps its position
 *  @returns reducer whose final result is an array of the distinct values in
 *    first-seen key order
 */
export const distinct = <G extends object, K extends keyof G, T>(
  tpe: K,
  extract: (node: Node<G, K>) => T | undefined,
  key: (item: T) => string
): Reducer<G, Map<string, T>, T | undefined, T[]> => ({
  finalize: (values) => [...values.values()],
  initialize: () => new Map(),
  reduce: (memo, value) => {
    // Skip only missing values; falsy-but-valid values such as 0 or "" must
    // still be collected.
    if (value != null) memo.set(key(value), value);
  },
  span: (memo) => memo.values(),
  toValue: (node) => (isNode(tpe)(node) ? extract(node) : undefined),
});

/** Counts the result of an aggregation
 *
 *  Wraps `inner`, reusing its accumulation (`initialize`, `reduce`, `span`,
 *  `toValue`), and reports how many items `inner`'s finalized result holds.
 *
 *  @param inner - reducer whose finalized result is a collection
 */
export const count = <G extends object>(
  inner: Reducer<G, any, any, Iterable<any>>
): Reducer<G, any, any, number> => ({
  ...inner,
  // Size the finalized result, not the raw memo: the memo's shape is private
  // to `inner` and need not match the collection it finalizes to.
  finalize: (memo) => size(inner.finalize(memo)),
});

/** Takes the first element of an aggregation
 *
 *  Wraps `inner`, reusing its accumulation, and returns the first item of
 *  `inner`'s finalized result (or `undefined` when it is empty).
 *
 *  @param inner - reducer whose finalized result is an iterable
 */
export const first = <G extends object, T>(
  inner: Reducer<G, any, any, Iterable<T>>
): Reducer<G, any, any, T> => ({
  ...inner,
  // Read from the finalized iterable rather than indexing the raw memo: a
  // memo such as `distinct`'s Map has no `[0]` element.
  finalize: (memo) => {
    for (const item of inner.finalize(memo)) return item;
    // Empty aggregation: keep the historical `undefined` result.
    return undefined as T;
  },
});

/** Takes the max of values for all node ids of a given type
 *
 *  Nodes of type `type` contribute `toNumber(node.data)`; every other node
 *  contributes `NaN`, which never beats a real number. The result is `NaN`
 *  only when every contributed value was `NaN`.
 */
export const max = <G extends object, K extends keyof G>(
  type: K,
  toNumber: (value: G[K]) => number
): Reducer<G, { value: number }, number, number> => {
  // Math.max variant that treats NaN as "no value" instead of propagating it.
  const larger = (current: number, candidate: number): number => {
    if (isNaN(current)) return candidate;
    if (isNaN(candidate)) return current;
    return Math.max(current, candidate);
  };
  return {
    finalize: ({ value }) => value,
    initialize: () => ({ value: NaN }),
    reduce: (memo, candidate) => {
      memo.value = larger(memo.value, candidate);
    },
    span: function* (memo) {
      yield memo.value;
    },
    toValue: (node) => {
      if (!isNode(type)(node)) return NaN;
      return toNumber(node.data);
    },
  };
};

/** Sums values for all node ids of a given type
 *
 *  Nodes of type `tpe` contribute `toNumber(node.data)`; all other nodes
 *  contribute 0. An empty aggregation sums to 0.
 */
export const sum = <G extends object, K extends keyof G>(
  tpe: K,
  toNumber: (value: G[K]) => number
): Reducer<G, { value: number }, number, number> => ({
  finalize: ({ value }) => value,
  initialize: () => ({ value: 0 }),
  reduce: (memo, addend) => {
    memo.value = memo.value + addend;
    // The running total is also returned to the caller.
    return memo.value;
  },
  span: ({ value }) => [value],
  toValue: (node) => {
    if (isNode(tpe)(node)) {
      return toNumber(node.data);
    }
    return 0;
  },
});

/** A value tagged with the color bucket it belongs to. */
type PaintedValue<TValue> = {
  /** Bucket color; `undefined` when the value carries no color of its own. */
  color: string | undefined;
  /** True only for a value built directly from a node that has a color. */
  isPaintNode: boolean;
  /** The wrapped inner-reducer value. */
  value: TValue;
};
/** Accumulator for `paint`: one inner memo per color plus one for the rest. */
type PaintedMemo<TMemo> = {
  /** Color inherited by uncolored values; unset until a paint node is seen. */
  newColor: string | undefined;
  /** Inner memos keyed by color. */
  painted: Record<string, TMemo>;
  /** Inner memo for values that never received a color. */
  unpainted: TMemo;
};

/** Groups by labels
 *
 * Splits `inner`'s aggregation into one bucket per color. A node is a paint
 * node when `colorOf` returns a non-empty string. Within one memo, the first
 * paint-node value reduced latches `newColor`; from then on, values without a
 * color of their own land in that color's bucket. Values that never receive a
 * color go to `unpainted`, which `finalize` drops.
 *
 * @param colorOf - returns the node's color, or a falsy value for "no color"
 * @param inner - reducer applied independently within each color bucket
 * @returns reducer whose final result maps each color to `inner`'s finalized
 *   result for that bucket
 */
export const paint = <G extends object, M, V, R>(
  colorOf: (node: Node<G, keyof G>) => string | false | null | undefined,
  inner: Reducer<G, M, V, R>
): Reducer<G, PaintedMemo<M>, PaintedValue<V>, Record<string, R>> => ({
  finalize: ({ painted }) => mapValues(painted, inner.finalize),
  initialize: () => ({ newColor: undefined, painted: {}, unpainted: inner.initialize() }),
  reduce: (memo, { color, isPaintNode, value }) => {
    // Latch the inherited color only while it is still unset.
    if (memo.newColor == null) {
      memo.newColor = isPaintNode ? color : undefined;
    }
    // A value's own color wins; otherwise it inherits the latched color, and
    // with neither it is kept aside as unpainted.
    const target = color ?? memo.newColor;
    const bucket = target
      ? (memo.painted[target] ??= inner.initialize())
      : memo.unpainted;
    inner.reduce(bucket, value);
  },
  span: function* (memo) {
    // Re-emitted values never count as paint nodes, so merging them into
    // another memo cannot change that memo's latched color.
    for (const [bucketColor, bucketMemo] of Object.entries(memo.painted)) {
      for (const value of inner.span(bucketMemo)) {
        yield { color: bucketColor, isPaintNode: false, value };
      }
    }
    for (const value of inner.span(memo.unpainted)) {
      yield { color: undefined, isPaintNode: false, value };
    }
  },
  toValue: (node) => {
    // Collapse every falsy color ("", false, null, undefined) to undefined.
    const color = colorOf(node) || undefined;
    return { color, isPaintNode: color !== undefined, value: inner.toValue(node) };
  },
});
