package com.xzchaoo.commons.basic.heap;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Binary heap that supports re-positioning a node after its ordering value has changed.
 * <p>
 * Ordering is defined by {@link NodeFunction#compare}: a node that compares less than another is
 * kept closer to the root, so {@link #peek()} and {@link #pop()} return a smallest node. Every time
 * a node moves, its new position in the backing list is written back through
 * {@link NodeFunction#setIndex}. This is what lets {@link #update(Object)} locate the node directly
 * and restore heap order with a single sift up or sift down.
 * <p>
 * This class performs no internal synchronization.
 * <p>created at 2020-08-11
 *
 * @param <N> node type
 * @author xiangfeng.xzc
 * @since 1.0.3
 */
public class Heap<N> {
    /** Heap-ordered nodes; the node at index i has children at 2i+1 and 2i+2. */
    private final List<N>         nodes = new ArrayList<>();
    /** Supplies the ordering and the per-node index bookkeeping. */
    private final NodeFunction<N> nf;

    /**
     * Creates an empty heap.
     *
     * @param nf ordering and index bookkeeping callbacks
     * @throws NullPointerException if {@code nf} is null
     */
    public Heap(NodeFunction<N> nf) {
        this.nf = Objects.requireNonNull(nf);
    }

    /**
     * Establishes heap order over all nodes added via {@link #initPush(Object)}, using bottom-up
     * heapify. Calling this on an empty heap is a no-op.
     */
    public void init() {
        // Indices >= size/2 are leaves and already satisfy the heap property, so start at the
        // last parent. This also avoids touching index 0 when the heap is empty.
        for (int i = (nodes.size() >>> 1) - 1; i >= 0; i--) {
            siftDown(nodes.get(i));
        }
    }

    /**
     * Appends a node WITHOUT restoring heap order. Use this for bulk loading, then call
     * {@link #init()} once before using {@link #peek()}, {@link #pop()} or {@link #update(Object)}.
     *
     * @param n node to append; the caller must ensure it is not already in the heap
     */
    public void initPush(N n) {
        nf.setIndex(n, nodes.size());
        nodes.add(n);
    }

    /**
     * Returns the smallest node without removing it.
     *
     * @return the root node, or {@code null} if the heap is empty
     */
    public N peek() {
        return nodes.isEmpty() ? null : nodes.get(0);
    }

    /**
     * Removes and returns the smallest node.
     *
     * @return the former root node
     * @throws IllegalStateException if the heap is empty
     */
    public N pop() {
        int size = nodes.size();
        if (size == 0) {
            throw new IllegalStateException("heap is empty");
        }
        N n = nodes.get(0);
        if (size == 1) {
            nodes.clear();
        } else {
            // Move the last node into the root slot and sift it down to its proper place.
            N last = nodes.remove(size - 1);
            nodes.set(0, last);
            nf.setIndex(last, 0);
            siftDown(last);
        }
        return n;
    }

    /**
     * Inserts a node and restores heap order.
     *
     * @param n node to insert; the caller must ensure it is not already in the heap
     */
    public void push(N n) {
        nf.setIndex(n, nodes.size());
        nodes.add(n);
        siftUp(n);
    }

    /**
     * @return {@code true} if the heap holds no nodes
     */
    public boolean isEmpty() {
        return nodes.isEmpty();
    }

    /**
     * @return number of nodes currently in the heap
     */
    public int size() {
        return nodes.size();
    }

    /**
     * Restores heap order after the ordering value of {@code n} has changed.
     *
     * @param n a node currently in the heap
     * @throws IllegalArgumentException if {@code n} is not in this heap, i.e. its recorded index is
     *                                  out of range or the slot at that index holds a different node
     */
    public void update(N n) {
        int oldIndex = nf.getIndex(n);
        // Guard against stale indices (e.g. a node that was already popped). Without this check,
        // sifting would write the foreign node into the heap and silently corrupt it.
        if (oldIndex < 0 || oldIndex >= nodes.size() || nodes.get(oldIndex) != n) {
            throw new IllegalArgumentException("node is not in the heap, index=" + oldIndex);
        }
        // Try moving up first; if the node did not move, it may need to move down instead.
        siftUp(n);
        if (oldIndex == nf.getIndex(n)) {
            siftDown(n);
        }
    }

    /**
     * Moves {@code node} toward the leaves until neither child is smaller. Smaller children are
     * shifted up into the vacated slot, and {@code node} is written once at its final position.
     */
    private void siftDown(N node) {
        int index = nf.getIndex(node);
        int ns = nodes.size();
        int downIndex = (index << 1) + 1;
        while (downIndex < ns) {
            // Pick the smaller of the two children (the right child may not exist).
            N downNode = nodes.get(downIndex);
            int downIndex2 = downIndex + 1;
            if (downIndex2 < ns) {
                N downNode2 = nodes.get(downIndex2);
                if (nf.compare(downNode2, downNode) < 0) {
                    downIndex = downIndex2;
                    downNode = downNode2;
                }
            }
            if (nf.compare(downNode, node) < 0) {
                nodes.set(index, downNode);
                nf.setIndex(downNode, index);
                index = downIndex;
                downIndex = (downIndex << 1) + 1;
            } else {
                break;
            }
        }
        nodes.set(index, node);
        nf.setIndex(node, index);
    }

    /**
     * Moves {@code node} toward the root while it is smaller than its parent. Larger parents are
     * shifted down into the vacated slot, and {@code node} is written once at its final position.
     */
    private void siftUp(N node) {
        int index = nf.getIndex(node);
        while (index > 0) {
            int upIndex = (index - 1) >> 1;
            N upNode = nodes.get(upIndex);
            if (nf.compare(node, upNode) < 0) {
                nodes.set(index, upNode);
                nf.setIndex(upNode, index);
                index = upIndex;
            } else {
                break;
            }
        }
        nodes.set(index, node);
        nf.setIndex(node, index);
    }
}
