/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.model.crf.crfpp;

import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.model.crf.crfpp.CRFEncoderThread;
import com.hankcs.hanlp.model.crf.crfpp.EncoderFeatureIndex;
import com.hankcs.hanlp.model.crf.crfpp.LbfgsOptimizer;
import com.hankcs.hanlp.model.crf.crfpp.TaggerImpl;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

/**
 * Trains a CRF++-compatible model from a feature template and a training corpus.
 *
 * <p>Supported algorithms are L2- or L1-regularised CRF optimised with L-BFGS
 * (see {@link LbfgsOptimizer}), or single-threaded MIRA.
 */
public class Encoder {
    /** Model format version. NOTE(review): not referenced in this class; presumably read by the model writer. */
    public static int MODEL_VERSION = 100;

    /**
     * Trains a model and saves it to {@code modelFile}.
     *
     * @param templFile     feature template file
     * @param trainFile     training corpus, read as UTF-8
     * @param modelFile     destination of the trained model
     * @param textModelFile forwarded to {@code EncoderFeatureIndex.save}; presumably selects a text-format model as well — confirm
     * @param maxitr        maximum number of training iterations
     * @param freq          forwarded to {@code EncoderFeatureIndex.shrink}; presumably the minimum feature frequency — confirm
     * @param eta           CRF stopping threshold on the relative change of the objective; must be &gt; 0
     * @param C             must be &gt;= 0. CRF: the penalty is divided by C, so larger C means weaker regularisation.
     *                      MIRA: upper bound on each sentence's accumulated step size.
     * @param threadNum     number of CRF worker threads; must be &gt; 0 (MIRA ignores it)
     * @param shrinkingSize MIRA only: a sentence tagged correctly this many passes in a row is skipped; must be &gt;= 1
     * @param algorithm     training algorithm; must not be null
     * @return true only if training finished and the model was saved
     */
    public boolean learn(String templFile, String trainFile, String modelFile, boolean textModelFile, int maxitr, int freq, double eta, double C, int threadNum, int shrinkingSize, Algorithm algorithm) {
        if (eta <= 0.0) {
            System.err.println("eta must be > 0.0");
            return false;
        }
        if (C < 0.0) {
            System.err.println("C must be >= 0.0");
            return false;
        }
        if (shrinkingSize < 1) {
            System.err.println("shrinkingSize must be >= 1");
            return false;
        }
        if (threadNum <= 0) {
            System.err.println("thread must be  > 0");
            return false;
        }
        // Reject early instead of NPE-ing at the switch after the whole corpus has been read.
        if (algorithm == null) {
            System.err.println("algorithm must not be null");
            return false;
        }
        EncoderFeatureIndex featureIndex = new EncoderFeatureIndex(threadNum);
        ArrayList<TaggerImpl> x = new ArrayList<TaggerImpl>();
        if (!featureIndex.open(templFile, trainFile)) {
            // Training on a feature index that failed to open cannot produce a valid model.
            System.err.println("Fail to open " + templFile + " " + trainFile);
            return false;
        }
        BufferedReader br = null;
        try {
            br = new BufferedReader(new InputStreamReader(IOUtil.newInputStream(trainFile), "UTF-8"));
            if (!readSentences(br, trainFile, featureIndex, x, threadNum)) {
                return false;
            }
        } catch (IOException e) {
            System.err.println("train file " + trainFile + " does not exist.");
            return false;
        } finally {
            closeQuietly(br);
        }
        featureIndex.shrink(freq, x);
        // Java zero-initialises arrays, so training starts from all-zero weights.
        double[] alpha = new double[featureIndex.size()];
        featureIndex.setAlpha_(alpha);
        System.out.println("Number of sentences: " + x.size());
        System.out.println("Number of features:  " + featureIndex.size());
        System.out.println("Number of thread(s): " + threadNum);
        System.out.println("Freq:                " + freq);
        System.out.println("eta:                 " + eta);
        System.out.println("C:                   " + C);
        System.out.println("shrinking size:      " + shrinkingSize);
        switch (algorithm) {
            case CRF_L1:
                if (!runCRF(x, featureIndex, alpha, maxitr, C, eta, shrinkingSize, threadNum, true)) {
                    System.err.println("CRF_L1 execute error");
                    return false;
                }
                break;
            case CRF_L2:
                if (!runCRF(x, featureIndex, alpha, maxitr, C, eta, shrinkingSize, threadNum, false)) {
                    System.err.println("CRF_L2 execute error");
                    return false;
                }
                break;
            case MIRA:
                if (!runMIRA(x, featureIndex, alpha, maxitr, C, eta, shrinkingSize, threadNum)) {
                    System.err.println("MIRA execute error");
                    return false;
                }
                break;
        }
        if (!featureIndex.save(modelFile, textModelFile)) {
            // Report failure: the caller must not believe a model was written.
            System.err.println("Failed to save model");
            return false;
        }
        System.out.println("Done!");
        return true;
    }

    /**
     * Reads sentences from {@code br} until EOF.
     *
     * <p>Each non-empty sentence becomes a LEARN-mode tagger bound to {@code featureIndex}.
     * Taggers are assigned round-robin to threads and appended to {@code x}.
     *
     * @return false if a sentence could not be read or its features could not be built
     */
    private static boolean readSentences(BufferedReader br, String trainFile, EncoderFeatureIndex featureIndex,
                                         List<TaggerImpl> x, int threadNum) throws IOException {
        int sentenceCount = 0;
        while (true) {
            TaggerImpl tagger = new TaggerImpl(TaggerImpl.Mode.LEARN);
            tagger.open(featureIndex);
            TaggerImpl.ReadStatus status = tagger.read(br);
            if (status == TaggerImpl.ReadStatus.ERROR) {
                System.err.println("error when reading " + trainFile);
                return false;
            }
            if (tagger.empty()) {
                // Empty reads are skipped; only an empty read at EOF ends the corpus.
                if (status == TaggerImpl.ReadStatus.EOF) {
                    return true;
                }
                continue;
            }
            if (!tagger.shrink()) {
                System.err.println("fail to build feature index ");
                return false;
            }
            tagger.setThread_id_(sentenceCount % threadNum);
            x.add(tagger);
            if (++sentenceCount % 100 == 0) {
                System.out.println(sentenceCount + ".. ");
            }
        }
    }

    /** Closes {@code br} if it was opened; a close failure is ignored because reading has already finished or failed. */
    private static void closeQuietly(BufferedReader br) {
        if (br == null) {
            return;
        }
        try {
            br.close();
        } catch (IOException ignored) {
            // Nothing useful to do: the outcome of reading has already been decided.
        }
    }

    /**
     * Runs CRF training with L-BFGS, parallelised over {@code threadNum} worker tasks.
     *
     * @param orthant true for L1 regularisation (passed through to {@link LbfgsOptimizer#optimize}), false for L2
     * @return false if a worker failed or the optimizer reported a non-positive status
     */
    private boolean runCRF(List<TaggerImpl> x, EncoderFeatureIndex featureIndex, double[] alpha, int maxItr, double C, double eta, int shrinkingSize, int threadNum, boolean orthant) {
        double oldObj = 1.0E37;
        int converge = 0;
        LbfgsOptimizer lbfgs = new LbfgsOptimizer();
        List<CRFEncoderThread> threads = new ArrayList<CRFEncoderThread>(threadNum);
        for (int i = 0; i < threadNum; ++i) {
            CRFEncoderThread thread = new CRFEncoderThread(alpha.length);
            thread.start_i = i;
            thread.size = x.size();
            thread.threadNum = threadNum;
            thread.x = x;
            threads.add(thread);
        }
        int all = 0;
        for (TaggerImpl tagger : x) {
            all += tagger.size();
        }
        // Worker 0 doubles as the accumulator: the other workers' results are reduced into it.
        CRFEncoderThread head = threads.get(0);
        ExecutorService executor = Executors.newFixedThreadPool(threadNum);
        try {
            for (int itr = 0; itr < maxItr; ++itr) {
                featureIndex.clear();
                if (!runWorkers(executor, threads)) {
                    return false;
                }
                for (int t = 1; t < threadNum; ++t) {
                    CRFEncoderThread worker = threads.get(t);
                    head.obj += worker.obj;
                    head.err += worker.err;
                    head.zeroone += worker.zeroone;
                    for (int k = 0; k < featureIndex.size(); ++k) {
                        head.expected[k] += worker.expected[k];
                    }
                }
                int numNonZero = 0;
                if (orthant) {
                    // L1: add |w|/C to the objective and count active (non-zero) weights.
                    for (int k = 0; k < featureIndex.size(); ++k) {
                        head.obj += Math.abs(alpha[k] / C);
                        if (alpha[k] != 0.0) {
                            ++numNonZero;
                        }
                    }
                } else {
                    // L2: add w^2/(2C) to the objective and w/C to the gradient.
                    numNonZero = featureIndex.size();
                    for (int k = 0; k < featureIndex.size(); ++k) {
                        head.obj += alpha[k] * alpha[k] / (2.0 * C);
                        head.expected[k] += alpha[k] / C;
                    }
                }
                // NOTE(review): presumably CRFEncoderThread re-allocates `expected` when it is null — confirm.
                for (int t = 1; t < threadNum; ++t) {
                    threads.get(t).expected = null;
                }
                double diff = itr == 0 ? 1.0 : Math.abs(oldObj - head.obj) / oldObj;
                StringBuilder b = new StringBuilder();
                b.append("iter=").append(itr);
                b.append(" terr=").append((double) head.err / (double) all);
                b.append(" serr=").append((double) head.zeroone / (double) x.size());
                b.append(" act=").append(numNonZero);
                b.append(" obj=").append(head.obj);
                b.append(" diff=").append(diff);
                System.out.println(b.toString());
                oldObj = head.obj;
                // Stop after three consecutive iterations whose relative change is below eta.
                converge = diff < eta ? converge + 1 : 0;
                if (converge == 3) {
                    break;
                }
                int ret = lbfgs.optimize(featureIndex.size(), alpha, head.obj, head.expected, orthant, C);
                if (ret <= 0) {
                    return false;
                }
            }
        } finally {
            // Always release the pool, including on early return; otherwise its non-daemon threads linger.
            executor.shutdown();
        }
        return true;
    }

    /**
     * Runs one pass of every worker and waits for all of them.
     *
     * <p>Any exception a worker threw is surfaced through {@code Future.get()};
     * {@code invokeAll} alone would silently discard it.
     *
     * @return true if every worker completed normally
     */
    private static boolean runWorkers(ExecutorService executor, List<CRFEncoderThread> threads) {
        try {
            for (Future<?> future : executor.invokeAll(threads)) {
                future.get();
            }
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return false;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Runs single-threaded MIRA training (Collins-style online updates with shrinking), updating {@code alpha} in place.
     *
     * <p>{@code eta} is unused. {@code threadNum} &gt; 1 only triggers a warning.
     *
     * @return always true
     */
    public boolean runMIRA(List<TaggerImpl> x, EncoderFeatureIndex featureIndex, double[] alpha, int maxItr, double C, double eta, int shrinkingSize, int threadNum) {
        // Per-sentence count of consecutive correctly-tagged passes; sentences at shrinkingSize are skipped.
        int[] shrink = new int[x.size()];
        // Per-sentence accumulated step size, capped at C.
        double[] upperBound = new double[x.size()];
        Double[] expectArr = new Double[featureIndex.size()];
        List<Double> expected = Arrays.asList(expectArr); // write-through view over expectArr
        if (threadNum > 1) {
            System.err.println("WARN: MIRA does not support multi-threading");
        }
        int converge = 0;
        int all = 0;
        for (TaggerImpl tagger : x) {
            all += tagger.size();
        }
        for (int itr = 0; itr < maxItr; ++itr) {
            int zeroone = 0;
            int err = 0;
            int activeSet = 0;
            int upperActiveSet = 0;
            double maxKktViolation = 0.0;
            for (int i = 0; i < x.size(); ++i) {
                if (shrink[i] >= shrinkingSize) {
                    continue;
                }
                ++activeSet;
                Arrays.fill(expectArr, 0.0);
                TaggerImpl tagger = x.get(i);
                double costDiff = tagger.collins(expected);
                int errorNum = tagger.eval();
                err += errorNum;
                if (errorNum == 0) {
                    ++shrink[i];
                    continue;
                }
                ++zeroone;
                shrink[i] = 0;
                double s = 0.0;
                for (int k = 0; k < expected.size(); ++k) {
                    s += expected.get(k) * expected.get(k);
                }
                double mu = Math.max(0.0, ((double) errorNum - costDiff) / s);
                if (upperBound[i] + mu > C) {
                    mu = C - upperBound[i];
                    ++upperActiveSet;
                } else {
                    maxKktViolation = Math.max((double) errorNum - costDiff, maxKktViolation);
                }
                // Negated comparison deliberately also skips a NaN step.
                if (!(mu > 1.0E-10)) {
                    continue;
                }
                upperBound[i] = Math.min(C, upperBound[i] + mu);
                for (int k = 0; k < expected.size(); ++k) {
                    alpha[k] += mu * expected.get(k);
                }
            }
            double obj = 0.0;
            for (int i = 0; i < featureIndex.size(); ++i) {
                obj += alpha[i] * alpha[i];
            }
            StringBuilder b = new StringBuilder();
            b.append("iter=").append(itr);
            b.append(" terr=").append((double) err / (double) all);
            b.append(" serr=").append((double) zeroone / (double) x.size());
            b.append(" act=").append(activeSet);
            b.append(" uact=").append(upperActiveSet);
            b.append(" obj=").append(obj);
            b.append(" kkt=").append(maxKktViolation);
            System.out.println(b.toString());
            if (maxKktViolation <= 0.0) {
                // No violation among active sentences: re-activate everything and count towards convergence.
                Arrays.fill(shrink, 0);
                ++converge;
            } else {
                converge = 0;
            }
            if (converge == 2) {
                break;
            }
        }
        return true;
    }

    /**
     * Command-line entry point.
     *
     * <p>Usage: {@code Encoder <template> <train> <model>}. Trains CRF-L2 with fixed defaults and prints the elapsed milliseconds.
     */
    public static void main(String[] args) {
        if (args.length < 3) {
            System.err.println("incorrect No. of args");
            return;
        }
        String templFile = args[0];
        String trainFile = args[1];
        String modelFile = args[2];
        Encoder enc = new Encoder();
        long start = System.currentTimeMillis();
        if (!enc.learn(templFile, trainFile, modelFile, false, 100000, 1, 1.0E-4, 1.0, 1, 20, Algorithm.CRF_L2)) {
            System.err.println("error training model");
            return;
        }
        System.out.println(System.currentTimeMillis() - start);
    }

    /** Training algorithms supported by {@link Encoder#learn}. */
    public static enum Algorithm {
        CRF_L2,
        CRF_L1,
        MIRA;

        /**
         * Parses an algorithm name case-insensitively: "crf" or "crf-l2", "crf-l1", or "mira".
         *
         * <p>{@code Locale.ROOT} keeps parsing independent of the default locale. Under a Turkish
         * locale, for example, "MIRA" would otherwise lowercase to "mıra".
         *
         * @throws IllegalArgumentException if the name is not recognised
         */
        public static Algorithm fromString(String algorithm) {
            String name = algorithm.toLowerCase(Locale.ROOT);
            if (name.equals("crf") || name.equals("crf-l2")) {
                return CRF_L2;
            }
            if (name.equals("crf-l1")) {
                return CRF_L1;
            }
            if (name.equals("mira")) {
                return MIRA;
            }
            throw new IllegalArgumentException("invalid algorithm: " + name);
        }
    }
}

