package shz.model;

import shz.ToList;
import shz.ToMap;
import shz.Validator;

import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;

/**
 * A node of a tree structure.
 *
 * <p>The self-referential bound {@code T extends TreeNode<T>} lets subclasses expose their own
 * type as the child type. The static helpers build ({@link #group}), sort ({@link #sort}) and
 * merge ({@link #merge}) lists of such nodes.
 *
 * @param <T> the concrete node type
 */
public class TreeNode<T extends TreeNode<T>> {
    /** Identifier of this node; {@link #merge} matches nodes of two trees by this value. */
    protected String id;
    /** Child nodes; empty (and presumably null, given the {@code Validator.isEmpty} checks) for a leaf. */
    protected List<T> childes;

    /**
     * Counts the leaves of the subtree rooted at this node.
     *
     * @return 1 if this node has no children (per {@code Validator.isEmpty}), otherwise the sum of
     *         the leaf counts of all children
     */
    public final int leafCount() {
        if (Validator.isEmpty(childes)) return 1;
        return childes.stream().mapToInt(TreeNode::leafCount).sum();
    }

    /** Returns the identifier of this node. */
    public final String getId() {
        return id;
    }

    /** Sets the identifier of this node. */
    public final void setId(String id) {
        this.id = id;
    }

    /** Returns the children list itself (not a copy). */
    public final List<T> getChildes() {
        return childes;
    }

    /** Stores the given list as the children of this node (no copy is made). */
    public final void setChildes(List<T> childes) {
        this.childes = childes;
    }

    @Override
    public String toString() {
        return "TreeNode{" +
                "id='" + id + '\'' +
                ", childes=" + childes +
                '}';
    }

    /**
     * Builds a grouping tree (a list of root nodes) from flat data.
     *
     * <p>The elements are grouped by {@code classifiers[0]}, each resulting group is grouped again by
     * {@code classifiers[1]}, and so on. Each group of an inner level becomes a node created by
     * {@code mapper} from its key and its already-built child nodes. Each group of the last level is
     * reduced with {@code lastCollector}, and the result becomes a node created by {@code lastMapper}.
     *
     * @param list          the data to group
     * @param mapper        builds an inner-level node from a group key and its child nodes
     * @param lastMapper    builds a last-level node from a group key and the collected value
     * @param lastCollector collects the elements of each last-level group
     * @param classifiers   key extractors, one per tree level, from the root level downward
     * @return the root nodes in unspecified order (the iteration order of the maps built by
     *         {@link Collectors#groupingBy}); an empty list if {@code list} is empty
     *         (per {@code Validator.isEmpty})
     * @throws IllegalArgumentException if {@code list} is not empty and no classifier is given
     */
    @SafeVarargs
    public static <E, K, T extends TreeNode<T>, R> List<T> group(List<E> list, BiFunction<K, List<T>, T> mapper,
                                                                 BiFunction<K, R, T> lastMapper,
                                                                 Collector<E, ?, R> lastCollector,
                                                                 Function<E, K>... classifiers) {
        if (Validator.isEmpty(list)) return Collections.emptyList();
        if (classifiers == null || classifiers.length == 0) {
            throw new IllegalArgumentException("at least one classifier is required");
        }
        // Build the nested collector inside-out: the last classifier wraps lastCollector, and every
        // earlier classifier wraps the collector built so far.
        Collector<E, ?, ? extends Map<K, ?>> collector = Collectors.groupingBy(classifiers[classifiers.length - 1], lastCollector);
        for (int i = classifiers.length - 2; i >= 0; --i) collector = Collectors.groupingBy(classifiers[i], collector);
        // Pass the nesting depth so the last level is recognised by position rather than by
        // "value is a Map", which would misfire when lastCollector itself produces a Map.
        return mergeGroup(list.stream().collect(collector), classifiers.length, mapper, lastMapper);
    }

    /**
     * Recursively converts one level of the nested maps produced in {@link #group} into nodes.
     *
     * @param group      one level of the nested grouping result
     * @param depth      number of grouping levels from this one down (inclusive); at 1 the values
     *                   are last-level collected results of type {@code R}, above 1 they are the
     *                   nested maps of the next level
     * @param mapper     builds an inner-level node from a key and its child nodes
     * @param lastMapper builds a last-level node from a key and its collected value
     * @return one node per entry of {@code group}
     */
    @SuppressWarnings("unchecked") // value types are fixed by the collector assembled in group()
    private static <K, T extends TreeNode<T>, R> List<T> mergeGroup(Map<K, ?> group, int depth,
                                                                    BiFunction<K, List<T>, T> mapper,
                                                                    BiFunction<K, R, T> lastMapper) {
        return ToList.explicitCollect(group.entrySet().stream().map(entry -> {
            K key = entry.getKey();
            Object value = entry.getValue();
            return depth > 1
                    ? mapper.apply(key, mergeGroup((Map<K, ?>) value, depth - 1, mapper, lastMapper))
                    : lastMapper.apply(key, (R) value);
        }), group.size());
    }

    /**
     * Sorts a grouping tree in place.
     *
     * <p>Sorts {@code nodes} with {@code comparator}, then recursively sorts every node's children.
     * Each non-empty list in the tree must therefore be mutable.
     *
     * @param nodes      the nodes to sort; nothing happens if empty (per {@code Validator.isEmpty})
     * @param comparator the ordering applied at every level
     */
    public static <T extends TreeNode<T>> void sort(List<T> nodes, Comparator<T> comparator) {
        if (Validator.isEmpty(nodes)) return;
        nodes.sort(comparator);
        nodes.forEach(node -> sort(node.getChildes(), comparator));
    }

    /**
     * Merges two grouping trees level by level, matching nodes by {@link #getId()}.
     *
     * <p>A node whose id appears on only one side is returned as-is (not copied). For an id that
     * appears on both sides, {@code merger} combines the two nodes. The children of the combined node
     * are then replaced by the recursive merge of both nodes' children.
     *
     * @param nodes      the first list of nodes
     * @param otherNodes the second list of nodes
     * @param merger     combines two nodes with the same id; must not return {@code null}
     *                   (its result's children are overwritten)
     * @return the merged nodes. This is {@code otherNodes} itself if {@code nodes} is empty, and
     *         {@code nodes} itself if {@code otherNodes} is empty. It is an immutable empty list if
     *         both are empty.
     */
    public static <T extends TreeNode<T>> List<T> merge(List<T> nodes, List<T> otherNodes, BiFunction<T, T, T> merger) {
        if (Validator.isEmpty(nodes)) return Validator.isEmpty(otherNodes) ? Collections.emptyList() : otherNodes;
        if (Validator.isEmpty(otherNodes)) return nodes;
        Map<String, T> map = ToMap.collect(nodes.stream(), TreeNode::getId, Function.identity());
        Map<String, T> otherMap = ToMap.collect(otherNodes.stream(), TreeNode::getId, Function.identity());
        // Union of ids: the first tree's ids in its map order, then ids present only in the second.
        Set<String> ids = new LinkedHashSet<>(map.keySet());
        ids.addAll(otherMap.keySet());
        return ToList.explicitCollect(ids.stream().map(key -> {
            T node = map.get(key);
            T other = otherMap.get(key);
            if (node == null) return other;
            if (other == null) return node;
            T result = merger.apply(node, other);
            result.setChildes(merge(node.getChildes(), other.getChildes(), merger));
            return result;
        }), ids.size());
    }
}
