/*
 * Decompiled with CFR 0.152.
 */
package edu.umd.cloud9.util.cfd;

import edu.umd.cloud9.io.pair.PairOfInts;
import edu.umd.cloud9.util.cfd.Int2IntConditionalFrequencyDistribution;
import edu.umd.cloud9.util.fd.Int2IntFrequencyDistributionFastutil;
import edu.umd.cloud9.util.fd.Int2LongFrequencyDistributionFastutil;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;

/**
 * Conditional frequency distribution of {@code int} events keyed by an {@code int} conditioning
 * value, backed by fastutil hash maps.
 *
 * <p>For every conditioning value {@code cond} a separate {@link Int2IntFrequencyDistributionFastutil}
 * stores the counts of events {@code k} observed under that condition. Two aggregates are kept
 * up to date incrementally by {@link #set(int, int, int)}: the marginal count of each event summed
 * over all conditions, and the sum of all counts. {@link #check()} recomputes these aggregates from
 * the per-condition distributions and verifies that they agree.
 *
 * <p>NOTE(review): no synchronization is performed anywhere in this class; treat instances as not
 * thread-safe.
 */
public class Int2IntConditionalFrequencyDistributionFastutil
implements Int2IntConditionalFrequencyDistribution {
    /** Per-condition event counts, keyed by conditioning value. */
    private final Int2ObjectMap<Int2IntFrequencyDistributionFastutil> distributions =
            new Int2ObjectOpenHashMap<>();

    /** Count of each event summed over all conditions, maintained incrementally. */
    private final Int2LongFrequencyDistributionFastutil marginals = new Int2LongFrequencyDistributionFastutil();

    /** Sum of every count in every conditional distribution, maintained incrementally. */
    private long sumOfAllFrequencies = 0L;

    /**
     * Sets the count of event {@code k} under condition {@code cond} to {@code v}.
     *
     * <p>The marginal count of {@code k} and the overall sum are adjusted by the difference between
     * {@code v} and the previous count, so both stay consistent with the per-condition data.
     *
     * @param k the event
     * @param cond the conditioning value
     * @param v the new count
     */
    @Override
    public void set(int k, int cond, int v) {
        Int2IntFrequencyDistributionFastutil fd = this.distributions.get(cond);
        if (fd == null) {
            // First observation for this condition: there is no previous count to subtract.
            fd = new Int2IntFrequencyDistributionFastutil();
            fd.set(k, v);
            this.distributions.put(cond, fd);
            this.marginals.increment(k, v);
            this.sumOfAllFrequencies += v;
            return;
        }
        int previous = fd.get(k);
        // fd is mutated in place; it is already in the map, so no re-insertion is needed.
        fd.set(k, v);
        this.marginals.increment(k, v - previous);
        this.sumOfAllFrequencies += (long) v - previous;
    }

    /**
     * Adds one to the count of event {@code k} under condition {@code cond}.
     *
     * @param k the event
     * @param cond the conditioning value
     */
    @Override
    public void increment(int k, int cond) {
        this.increment(k, cond, 1);
    }

    /**
     * Adds {@code v} to the count of event {@code k} under condition {@code cond}.
     *
     * @param k the event
     * @param cond the conditioning value
     * @param v the amount to add
     */
    @Override
    public void increment(int k, int cond, int v) {
        // get() yields 0 for a condition that has never been seen, so this covers first occurrences.
        this.set(k, cond, this.get(k, cond) + v);
    }

    /**
     * Returns the count of event {@code k} under condition {@code cond}.
     *
     * @param k the event
     * @param cond the conditioning value
     * @return the stored count, or 0 if {@code cond} has never been seen
     */
    @Override
    public int get(int k, int cond) {
        Int2IntFrequencyDistributionFastutil fd = this.distributions.get(cond);
        return fd == null ? 0 : fd.get(k);
    }

    /**
     * Returns the count of event {@code k} summed over all conditions.
     *
     * @param k the event
     * @return the marginal count maintained by {@link #set(int, int, int)}
     */
    @Override
    public long getMarginalCount(int k) {
        return this.marginals.get(k);
    }

    /**
     * Returns the distribution of events observed under {@code cond}.
     *
     * <p>When {@code cond} has been seen, the returned object is the live internal instance. When it
     * has not, a new, empty distribution is returned that is not stored in this object.
     * NOTE(review): mutating the live instance directly bypasses the marginal and total bookkeeping,
     * which {@link #check()} may then report as an inconsistency.
     *
     * @param cond the conditioning value
     * @return the conditional distribution, never {@code null}
     */
    @Override
    public Int2IntFrequencyDistributionFastutil getConditionalDistribution(int cond) {
        Int2IntFrequencyDistributionFastutil fd = this.distributions.get(cond);
        return fd != null ? fd : new Int2IntFrequencyDistributionFastutil();
    }

    /**
     * Returns the sum of all counts across all conditions.
     *
     * @return the incrementally maintained total
     */
    @Override
    public long getSumOfAllCounts() {
        return this.sumOfAllFrequencies;
    }

    /**
     * Verifies internal consistency by recomputing the aggregates from the per-condition
     * distributions.
     *
     * @throws RuntimeException if a conditional distribution's entries do not sum to its own
     *     reported total, if the per-condition totals do not sum to {@link #getSumOfAllCounts()},
     *     if a recomputed marginal differs from the stored one, or if the stored marginals do not
     *     sum to the recomputed total
     */
    @Override
    public void check() {
        Int2IntFrequencyDistributionFastutil recomputedMarginals = new Int2IntFrequencyDistributionFastutil();
        long totalSum = 0L;
        for (Int2IntFrequencyDistributionFastutil fd : this.distributions.values()) {
            long conditionalSum = 0L;
            for (PairOfInts pair : fd) {
                conditionalSum += pair.getRightElement();
                recomputedMarginals.increment(pair.getLeftElement(), pair.getRightElement());
            }
            if (conditionalSum != fd.getSumOfCounts()) {
                throw new RuntimeException("Internal Error!");
            }
            totalSum += fd.getSumOfCounts();
        }
        if (totalSum != this.getSumOfAllCounts()) {
            throw new RuntimeException("Internal Error! Got " + totalSum + ", Expected " + this.getSumOfAllCounts());
        }
        // Every marginal recomputed from the conditional distributions must match the stored one.
        for (PairOfInts e : recomputedMarginals) {
            if ((long) e.getRightElement() != this.marginals.get(e.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }
        // The loop above only checks events present in some conditional distribution. Comparing
        // totals also catches stored marginals that carry mass for events absent from all of them.
        long storedMarginalSum = this.marginals.getSumOfCounts();
        if (storedMarginalSum != totalSum) {
            throw new RuntimeException("Internal Error! Marginal sum " + storedMarginalSum + ", Expected " + totalSum);
        }
    }
}

