/*
 * Decompiled with CFR 0.152.
 */
package edu.umd.hooka.alignment.hmm;

import edu.umd.hooka.Alignment;
import edu.umd.hooka.AlignmentPosteriorGrid;
import edu.umd.hooka.Array2D;
import edu.umd.hooka.PhrasePair;
import edu.umd.hooka.alignment.CrossEntropyCounters;
import edu.umd.hooka.alignment.PartialCountContainer;
import edu.umd.hooka.alignment.PerplexityReporter;
import edu.umd.hooka.alignment.ZeroProbabilityException;
import edu.umd.hooka.alignment.hmm.ATable;
import edu.umd.hooka.alignment.hmm.IntArray2D;
import edu.umd.hooka.alignment.model1.Model1;
import edu.umd.hooka.ttables.TTable;
import java.io.IOException;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;

/**
 * HMM word-alignment model built on top of {@link Model1}.
 *
 * <p>For each sentence pair, {@link #buildHMMTables} fills an emission table from the translation
 * table ({@code tmodel.get(e_i, f_j)}) and a transition table from the jump table
 * ({@code amodel.get(i - i_prev, |E|)}). In the tables built by this class, state 0 is the start
 * state and states 1..|E| correspond to E positions. Training uses scaled forward-backward
 * ({@link #baumWelch}); decoding uses Viterbi ({@link #viterbiAlign}).
 *
 * <p>If Model 1 posteriors are supplied via {@link #setModel1Posteriors}, they are thresholded at
 * {@link #THRESH}. For every F position that the thresholded Model 1 alignment links, a transition
 * into a linked E state gets probability sqrt(posterior), and every other transition into that
 * position is scaled by 1 - sqrt(posterior).
 *
 * <p>Instances keep per-sentence work buffers in fields and reuse them across calls, so a single
 * instance should not be shared between threads.
 */
public class HMM extends Model1 {
    /** Key under which {@link #writePartialCounts} emits the accumulated jump counts. */
    public static final IntWritable ACOUNT_VOC_ID = new IntWritable(999999);
    static final int MAX_LENGTH = 500;
    /** Threshold applied to the Model 1 posteriors to obtain a hard Model 1 alignment. */
    static final float THRESH = 0.5f;
    // Per-sentence work buffers: constructed once (with 250000) and resized for each pair.
    Array2D emission = new Array2D(250000);
    IntArray2D e_coords = new IntArray2D(250000);
    IntArray2D e_words = new IntArray2D(250000);
    Array2D transition = new Array2D(250000);
    IntArray2D transition_coords = new IntArray2D(250000);
    Array2D alphas = new Array2D(250000);
    Array2D betas = new Array2D(250000);
    Array2D viterbi = new Array2D(250000);
    IntArray2D backtrace = new IntArray2D(250000);
    /** Jump (transition) model. */
    ATable amodel;
    /** Jump partial counts accumulated during training; starts as a cleared clone of amodel. */
    ATable acounts;
    /** |E| of the pair last passed to {@link #buildHMMTables}, or -1 before the first call. */
    int l = -1;
    /** |F| of the pair last passed to {@link #buildHMMTables}, or -1 before the first call. */
    int m = -1;
    /** Optional Model 1 posteriors used to constrain transitions; null disables the constraint. */
    AlignmentPosteriorGrid m1_post = null;

    /**
     * Supplies Model 1 posteriors used to constrain transitions in {@link #baumWelch} and
     * {@link #viterbiAlign}; pass {@code null} to disable the constraint.
     *
     * <p>NOTE(review): the grid is not tied to a sentence pair here; callers presumably set it for
     * the pair about to be processed.
     */
    public void setModel1Posteriors(AlignmentPosteriorGrid m1pg) {
        this.m1_post = m1pg;
    }

    /**
     * Creates an HMM over the given translation and jump tables.
     *
     * @param ttable translation table, passed to {@link Model1}
     * @param atable jump table; a cleared clone of it accumulates the jump counts
     * @param useNull passed through to {@link Model1}
     */
    protected HMM(TTable ttable, ATable atable, boolean useNull) {
        super(ttable, useNull);
        this.amodel = atable;
        this.acounts = (ATable) this.amodel.clone();
        this.acounts.clear();
    }

    /** Creates an HMM over the given translation and jump tables with {@code useNull == false}. */
    public HMM(TTable ttable, ATable atable) {
        this(ttable, atable, false);
    }

    /**
     * Writes the {@link Model1} partial counts, then the accumulated jump counts under
     * {@link #ACOUNT_VOC_ID}, and finally clears the jump counts.
     */
    @Override
    public void writePartialCounts(OutputCollector<IntWritable, PartialCountContainer> output) throws IOException {
        super.writePartialCounts(output);
        PartialCountContainer pcc = new PartialCountContainer();
        pcc.setContent(this.acounts);
        output.collect(ACOUNT_VOC_ID, pcc);
        // NOTE(review): if setContent keeps a reference to acounts, this relies on collect()
        // serializing the value before it returns.
        this.acounts.clear();
    }

    /**
     * Builds the emission and transition tables for one sentence pair and records |E| in
     * {@link #l} and |F| in {@link #m}.
     *
     * <p>Emission cell (j, i), for 1 <= j <= |F| and 1 <= i <= |E|, holds
     * {@code tmodel.get(e[i-1], f[j-1])}. For that cell, {@code e_coords} holds {@code i} and
     * {@code e_words} holds the 0-based E index {@code i - 1}; every other coordinate cell is -1.
     * Transition cell (iPrev, i), for 0 <= iPrev <= |E| and 1 <= i <= |E|, holds
     * {@code amodel.get(i - iPrev, |E|)}, and {@code transition_coords} holds its ATable coordinate
     * (-1 elsewhere). All other probability cells are 0.
     */
    public void buildHMMTables(PhrasePair pp) {
        int[] es = pp.getE().getWords();
        int[] fs = pp.getF().getWords();
        this.l = es.length;
        this.m = fs.length;

        this.emission.resize(this.m + 1, this.l + 1);
        this.e_coords.resize(this.m + 1, this.l + 1);
        this.e_words.resize(this.m + 1, this.l + 1);
        // Row 0 and the state-0 column are never assigned below. Whether resize() clears the
        // reused buffer is not visible here, so clear explicitly instead of risking values left
        // over from a previous sentence pair.
        this.emission.fill(0.0f);
        this.e_words.fill(-1);
        this.e_coords.fill(-1);
        for (int i = 1; i <= this.l; ++i) {
            int ei = es[i - 1];
            for (int j = 1; j <= this.m; ++j) {
                this.e_coords.set(j, i, i);
                this.emission.set(j, i, this.tmodel.get(ei, fs[j - 1]));
                this.e_words.set(j, i, i - 1);
            }
        }

        this.transition.resize(this.l + 1, this.l + 1);
        this.transition_coords.resize(this.l + 1, this.l + 1);
        // Transitions into state 0 are never assigned below; clear for the same reason as above.
        this.transition.fill(0.0f);
        this.transition_coords.fill(-1);
        for (int iPrev = 0; iPrev <= this.l; ++iPrev) {
            for (int i = 1; i <= this.l; ++i) {
                int jump = i - iPrev;
                this.transition_coords.set(iPrev, i, this.amodel.getCoord(jump, (char) this.l));
                this.transition.set(iPrev, i, this.amodel.get(jump, (char) this.l));
            }
        }
    }

    /** Returns the number of HMM states in the current tables (the transition table's width). */
    public final int getNumStates() {
        return this.transition.getSize2();
    }

    /** Returns the unconstrained probability of moving from state {@code s_prev} to {@code s}. */
    public final float getTransitionProb(int s_prev, int s) {
        return this.transition.get(s_prev, s);
    }

    /** Returns the probability that state {@code s} emits F position {@code j} (1-based). */
    public final float getEmissionProb(int j, int s) {
        return this.emission.get(j, s);
    }

    /** Adds this instance's accumulated jump counts into {@code ac}. */
    public final void addPartialJumpCountsToATable(ATable ac) {
        ac.plusEquals(this.acounts);
    }

    /**
     * Runs one Baum-Welch E-step on {@code pp}, accumulating translation and jump counts.
     *
     * <p>A pair is skipped if either side is empty or has length >= {@code amodel.getMaxDist() - 1}.
     * If {@code r} is non-null, two counters are updated:
     * {@link CrossEntropyCounters#LOGPROB} gets the negated value returned by {@link #baumWelch}
     * (truncated to long), and {@link CrossEntropyCounters#WORDCOUNT} gets |F|.
     */
    @Override
    public void processTrainingInstance(PhrasePair pp, Reporter r) {
        int elen = pp.getE().size();
        int flen = pp.getF().size();
        int maxLen = this.amodel.getMaxDist() - 1;
        if (elen == 0 || flen == 0 || elen >= maxLen || flen >= maxLen) {
            return;
        }
        this.buildHMMTables(pp);
        float totalLogProb = this.baumWelch(pp, null);
        if (r != null) {
            r.incrCounter(CrossEntropyCounters.LOGPROB, (long) (-totalLogProb));
            r.incrCounter(CrossEntropyCounters.WORDCOUNT, (long) flen);
        }
    }

    /**
     * Runs scaled forward-backward over the tables built by {@link #buildHMMTables}, which must
     * already have been called for {@code pp}.
     *
     * <p>If {@code pg} is null, state posteriors are added as translation counts via
     * {@code addTranslationCount}. Transition posteriors are added to the jump counts, except for
     * transitions into positions whose F word is constrained by Model 1. If {@code pg} is non-null,
     * state posteriors are instead added into {@code pg}, and no counts are accumulated.
     *
     * @return the sum of the natural logs of the per-position normalizers returned by
     *     {@code normalizeColumn} (presumably the sentence log-likelihood), or 0 if the forward
     *     pass hits a {@link ZeroProbabilityException}; in that case {@code notifyUnalignablePair}
     *     is called first
     * @throws RuntimeException wrapping any exception thrown by {@code addTranslationCount}
     */
    public final float baumWelch(PhrasePair pp, AlignmentPosteriorGrid pg) {
        this.initializeCountTableForSentencePair(pp);
        int[] obs = pp.getF().getWords();
        int J = obs.length + 1;
        int numStates = this.getNumStates();
        int elen = pp.getE().getWords().length;
        float[] anorms = new float[J];
        this.alphas.resize(J + 1, numStates);
        this.betas.resize(J + 1, numStates);
        Alignment m1a = this.m1ThresholdAlignment();
        float[] m1boosts = this.m1TransitionBoosts(m1a, obs.length, elen);

        // Forward pass. Position 0 is the start distribution: all mass on state 0.
        this.alphas.set(0, 0, 1.0f);
        for (int s = 1; s < numStates; ++s) {
            this.alphas.set(0, s, 0.0f);
        }
        anorms[0] = 1.0f;
        for (int j = 1; j < J; ++j) {
            for (int s = 0; s < numStates; ++s) {
                float alpha = 0.0f;
                for (int sPrev = 0; sPrev < numStates; ++sPrev) {
                    // The transition into position j emits F word j - 1.
                    alpha += this.alphas.get(j - 1, sPrev) * this.m1ConstrainedTransition(m1a, m1boosts, j - 1, sPrev, s, elen);
                }
                this.alphas.set(j, s, alpha * this.getEmissionProb(j, s));
            }
            try {
                anorms[j] = this.alphas.normalizeColumn(j);
            } catch (ZeroProbabilityException ex) {
                this.notifyUnalignablePair(pp, ex.getMessage());
                return 0.0f;
            }
        }

        // Backward pass, scaled with the forward normalizers. State 0 (start) is not a final state.
        this.betas.set(J - 1, 0, 0.0f);
        for (int s = 1; s < numStates; ++s) {
            this.betas.set(J - 1, s, 1.0f);
        }
        for (int j = J - 2; j >= 1; --j) {
            for (int s = 0; s < numStates; ++s) {
                float beta = 0.0f;
                for (int sNext = 0; sNext < numStates; ++sNext) {
                    // The transition into position j + 1 emits F word j, so it must be constrained
                    // exactly as the forward pass constrained it: by F word j and destination sNext.
                    float trans = this.m1ConstrainedTransition(m1a, m1boosts, j, s, sNext, elen);
                    beta += this.betas.get(j + 1, sNext) * trans * this.getEmissionProb(j + 1, sNext);
                }
                this.betas.set(j, s, beta / anorms[j]);
            }
        }

        // State posteriors: translation counts, or the caller's posterior grid.
        float[] totalProb = new float[J];
        for (int j = 1; j < J; ++j) {
            float tp = 0.0f;
            for (int s = 0; s < numStates; ++s) {
                tp += this.betas.get(j, s) * this.alphas.get(j, s);
            }
            totalProb[j] = tp;
            for (int s = 0; s < numStates; ++s) {
                int iplus1 = this.e_coords.get(j, s);
                if (iplus1 == -1) {
                    continue;
                }
                float pc = this.betas.get(j, s) * this.alphas.get(j, s) / tp;
                if (pg != null) {
                    if (s != 0) {
                        // States beyond |E| (if a subclass adds any) are credited to E index 0.
                        int e = s <= elen ? s : 0;
                        pg.setAlignmentPointPosterior(j - 1, e, pg.getAlignmentPointPosterior(j - 1, e) + pc);
                    }
                    continue;
                }
                try {
                    this.addTranslationCount(iplus1, j - 1, pc);
                } catch (Exception ex) {
                    throw new RuntimeException("J=" + J + ", numStates=" + numStates + ": Failed to add (" + iplus1 + "," + (j - 1) + ") += " + pc + " s=" + s + " pp=" + pp + "\n E:\n" + this.e_coords, ex);
                }
            }
        }

        // Transition posteriors -> jump counts (training only).
        if (pg == null) {
            for (int j = 1; j < J - 1; ++j) {
                // A transition into position j + 1 (F word j) that Model 1 overrode says nothing
                // about the jump distribution, so it is not counted.
                if (m1a != null && m1a.isFAligned(j)) {
                    continue;
                }
                for (int sPrev = 0; sPrev < numStates; ++sPrev) {
                    for (int s = 0; s < numStates; ++s) {
                        int tc = this.transition_coords.get(sPrev, s);
                        if (tc == -1) {
                            continue;
                        }
                        float pc = this.alphas.get(j, sPrev) * this.getTransitionProb(sPrev, s) * this.emission.get(j + 1, s) / anorms[j + 1] * this.betas.get(j + 1, s) / totalProb[j + 1];
                        this.acounts.add(tc, (char) elen, pc);
                    }
                }
            }
        }

        float tlp = 0.0f;
        for (float n : anorms) {
            tlp = (float) ((double) tlp + Math.log(n));
        }
        return tlp;
    }

    /** Returns the Model 1 posteriors thresholded at {@link #THRESH}, or null if none were set. */
    private Alignment m1ThresholdAlignment() {
        return this.m1_post == null ? null : this.m1_post.alignPosteriorThreshold(THRESH);
    }

    /**
     * Precomputes the Model 1 boost for each 0-based F position that {@code m1a} links.
     *
     * <p>The boost is the square root of the Model 1 posterior of that position's link. If several
     * E positions are linked, the last one wins. Entries for unlinked positions are left at 0 and
     * are never read.
     *
     * @return the per-F-position boosts, or null when {@code m1a} is null
     */
    private float[] m1TransitionBoosts(Alignment m1a, int flen, int elen) {
        if (m1a == null) {
            return null;
        }
        float[] boosts = new float[flen];
        for (int f = 0; f < flen; ++f) {
            if (!m1a.isFAligned(f)) {
                continue;
            }
            float post = 0.0f;
            for (int i = 0; i < elen; ++i) {
                if (m1a.aligned(f, i)) {
                    // The posterior grid's E index is offset by one (index 0 presumably NULL).
                    post = this.m1_post.getAlignmentPointPosterior(f, i + 1);
                }
            }
            boosts[f] = (float) Math.sqrt(post);
        }
        return boosts;
    }

    /**
     * Returns the Model-1-constrained probability of moving from {@code sPrev} to {@code s}.
     * The destination state {@code s} emits the 0-based F position {@code fIndex}.
     *
     * <p>If {@code m1a} is null, or {@code fIndex} is unlinked in it, this is the plain
     * {@link #getTransitionProb}. Otherwise there are two cases:
     * <ul>
     *   <li>If {@code s} is in {@code 1..maxLinkedState} and Model 1 links it to {@code fIndex},
     *       the result is {@code boosts[fIndex]}.</li>
     *   <li>Otherwise the plain probability is scaled by {@code 1 - boosts[fIndex]}.</li>
     * </ul>
     */
    private float m1ConstrainedTransition(Alignment m1a, float[] boosts, int fIndex, int sPrev, int s, int maxLinkedState) {
        float trans = this.getTransitionProb(sPrev, s);
        if (m1a == null || !m1a.isFAligned(fIndex)) {
            return trans;
        }
        float boost = boosts[fIndex];
        if (s > 0 && s <= maxLinkedState && m1a.aligned(fIndex, s - 1)) {
            return boost;
        }
        return trans * (1.0f - boost);
    }

    /**
     * Returns a posterior grid for {@code pp} filled by {@link #baumWelch}; no translation or
     * jump counts are accumulated. If the forward pass fails, the grid is returned without any
     * posteriors added.
     */
    @Override
    public AlignmentPosteriorGrid computeAlignmentPosteriors(PhrasePair pp) {
        AlignmentPosteriorGrid res = new AlignmentPosteriorGrid(pp);
        this.buildHMMTables(pp);
        this.baumWelch(pp, res);
        return res;
    }

    /**
     * Computes the Viterbi alignment of {@code sentence} under the current model.
     *
     * <p>If no state is reachable at some position, that position is bridged. Every state there
     * gets score 0 and points back to the best state of the previous position. F positions whose
     * Viterbi score is not negative (such bridged positions, in particular) are left unaligned.
     * So are positions whose state has a negative {@code e_words} entry. When {@code reporter} is
     * non-null, the best final score and |F| are passed to {@code reporter.addFactor}.
     *
     * @throws ZeroProbabilityException if the backtrace reaches state 0 or no state before
     *     position 1 (for example, when E is empty and F is not)
     */
    @Override
    public Alignment viterbiAlign(PhrasePair sentence, PerplexityReporter reporter) {
        this.buildHMMTables(sentence);
        Alignment res = new Alignment(sentence.getF().size(), sentence.getE().size());
        int J = sentence.getF().size() + 1;
        int numStates = this.getNumStates();
        this.viterbi.resize(J, numStates);
        this.backtrace.resize(J, numStates);
        this.viterbi.fill(Float.NEGATIVE_INFINITY);
        this.viterbi.set(0, 0, 0.0f);
        int lene = sentence.getE().getWords().length;
        Alignment m1a = this.m1ThresholdAlignment();
        float[] m1boosts = this.m1TransitionBoosts(m1a, J - 1, lene);

        for (int j = 1; j < J; ++j) {
            boolean anyReachable = false;
            for (int s = 1; s < numStates; ++s) {
                double emitLogProb = Math.log(this.emission.get(j, s));
                if (emitLogProb == Double.NEGATIVE_INFINITY) {
                    continue;
                }
                float best = Float.NEGATIVE_INFINITY;
                int bestPrev = -1;
                for (int sPrev = 0; sPrev < numStates; ++sPrev) {
                    float trans = this.m1ConstrainedTransition(m1a, m1boosts, j - 1, sPrev, s, this.l);
                    float cur = (float) ((double) this.viterbi.get(j - 1, sPrev) + Math.log(trans) + emitLogProb);
                    if (cur > best) {
                        best = cur;
                        bestPrev = sPrev;
                    }
                }
                this.viterbi.set(j, s, best);
                this.backtrace.set(j, s, bestPrev);
                if (best != Float.NEGATIVE_INFINITY) {
                    anyReachable = true;
                }
            }
            if (anyReachable) {
                continue;
            }
            // Bridge an unreachable position: restart scores at 0 and link every state back to
            // the best state of the previous position.
            float bestPrevScore = Float.NEGATIVE_INFINITY;
            int bestPrevState = -1;
            for (int s = 1; s < numStates; ++s) {
                float score = this.viterbi.get(j - 1, s);
                if (score > bestPrevScore) {
                    bestPrevScore = score;
                    bestPrevState = s;
                }
            }
            for (int s = 1; s < numStates; ++s) {
                this.viterbi.set(j, s, 0.0f);
                this.backtrace.set(j, s, bestPrevState);
            }
        }

        float bestScore = Float.NEGATIVE_INFINITY;
        int bestFinal = -1;
        for (int s = 1; s < numStates; ++s) {
            float score = this.viterbi.get(J - 1, s);
            if (score > bestScore) {
                bestScore = score;
                bestFinal = s;
            }
        }
        if (reporter != null) {
            reporter.addFactor(bestScore, J - 1);
        }

        int e = bestFinal;
        for (int f = J - 1; f > 0; --f) {
            if (e <= 0) {
                throw new ZeroProbabilityException("  Error f=" + f + " e=" + e + "  sentence=" + sentence + "\n" + this.viterbi + "\n" + this.emission + "\n" + this.transition + "\n" + this.backtrace);
            }
            if (this.viterbi.get(f, e) < 0.0f) {
                try {
                    int ae = this.e_words.get(f, e);
                    if (ae >= 0) {
                        res.align(f - 1, ae);
                    }
                } catch (RuntimeException ex) {
                    throw new RuntimeException("Caught " + ex + "\nvit(f,e)=" + this.viterbi.get(f, e) + "  size(f,e)=" + sentence.getF().size() + "," + sentence.getE().size() + " Error f=" + f + " e=" + e + "  sentence=" + sentence + "\n" + this.viterbi + "\n" + this.emission + "\n" + this.transition + "\n" + this.backtrace + "\n" + this.e_words, ex);
                }
            }
            e = this.backtrace.get(f, e);
        }
        return res;
    }
}

