/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.psf.optimizer;

import com.tencent.angel.ml.math2.ufuncs.OptFuncs;
import com.tencent.angel.ml.math2.ufuncs.Ufuncs;
import com.tencent.angel.ml.math2.vector.Vector;
import com.tencent.angel.ml.psf.optimizer.OptMMUpdateFunc;
import com.tencent.angel.ps.storage.partition.RowBasedPartition;
import com.tencent.angel.ps.storage.vector.ServerRow;
import com.tencent.angel.ps.storage.vector.ServerRowUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Parameter-server update function that applies an FTRL step to a row-based partition.
 *
 * <p>For each factor index {@code f} in {@code [0, factor)}, the partition rows are used as:
 * <ul>
 *   <li>row {@code f}: model weights</li>
 *   <li>row {@code f + factor}: the FTRL {@code z} accumulator</li>
 *   <li>row {@code f + 2 * factor}: the FTRL {@code n} accumulator</li>
 *   <li>row {@code f + 3 * factor}: the pending gradient, which is cleared after it is applied</li>
 * </ul>
 *
 * <p>Hyper-parameters travel to the server in a {@code double[]} packed as
 * {@code {alpha, beta, lambda1, lambda2, epoch, batchSize}}.
 */
public class FTRLUpdateFunc extends OptMMUpdateFunc {
    private static final Log LOG = LogFactory.getLog(FTRLUpdateFunc.class);

    /**
     * No-argument constructor. NOTE(review): presumably needed so the framework can
     * instantiate/deserialize the function reflectively — confirm against OptMMUpdateFunc.
     */
    public FTRLUpdateFunc() {
    }

    /**
     * Creates an FTRL update with a batch size of 1, meaning gradients are not averaged.
     *
     * @param matId   id of the matrix to update
     * @param factor  number of weight rows; each has three companion rows (z, n, gradient)
     * @param alpha   FTRL alpha hyper-parameter
     * @param beta    FTRL beta hyper-parameter
     * @param lambda1 L1 regularization strength
     * @param lambda2 L2 regularization strength
     * @param epoch   current epoch; packed into the scalars but not used by {@link #update}
     */
    public FTRLUpdateFunc(int matId, int factor, double alpha, double beta, double lambda1, double lambda2, int epoch) {
        this(matId, factor, alpha, beta, lambda1, lambda2, epoch, 1);
    }

    /**
     * Creates an FTRL update.
     *
     * @param matId     id of the matrix to update
     * @param factor    number of weight rows; each has three companion rows (z, n, gradient)
     * @param alpha     FTRL alpha hyper-parameter
     * @param beta      FTRL beta hyper-parameter
     * @param lambda1   L1 regularization strength
     * @param lambda2   L2 regularization strength
     * @param epoch     current epoch; packed into the scalars but not used by {@link #update}
     * @param batchSize when greater than 1, the accumulated gradient is divided by this value
     *                  before it is applied
     */
    public FTRLUpdateFunc(int matId, int factor, double alpha, double beta, double lambda1, double lambda2, int epoch, int batchSize) {
        // Order must match the indices read back in update().
        super(matId, new int[]{factor}, new double[]{alpha, beta, lambda1, lambda2, epoch, batchSize});
    }

    /**
     * Applies one FTRL step to every factor row of {@code partition}.
     *
     * <p>For each factor {@code f}, the step runs in this order:
     * <ol>
     *   <li>Optionally average the gradient by {@code batchSize}.</li>
     *   <li>Compute {@code delta} from the current {@code n} and the gradient.</li>
     *   <li>Add the squared gradient to {@code n}.</li>
     *   <li>Add {@code gradient - delta * weight} to {@code z}.</li>
     *   <li>Recompute the weights from {@code z} and {@code n} with thresholding.</li>
     *   <li>Clear the gradient row.</li>
     * </ol>
     *
     * <p>Only the gradient row is bracketed by {@code startWrite()}/{@code endWrite()}.
     * NOTE(review): these are assumed to take/release that row's write lock; the weight,
     * z and n rows are updated without one.
     *
     * @param partition the partition holding the weight, z, n and gradient rows
     * @param factor    number of weight rows (see the class-level row layout)
     * @param scalars   {@code {alpha, beta, lambda1, lambda2, epoch, batchSize}}
     */
    @Override
    public void update(RowBasedPartition partition, int factor, double[] scalars) {
        final double alpha = scalars[0];
        final double beta = scalars[1];
        final double lambda1 = scalars[2];
        final double lambda2 = scalars[3];
        // scalars[4] carries the epoch, which this update does not use.
        final int batchSize = (int) scalars[5];

        for (int f = 0; f < factor; ++f) {
            ServerRow gradientServerRow = partition.getRow(f + 3 * factor);
            // Acquire before entering try: endWrite() must only run once startWrite()
            // has succeeded; otherwise we would release a lock we never obtained.
            gradientServerRow.startWrite();
            try {
                Vector weight = ServerRowUtils.getVector(partition.getRow(f));
                Vector zModel = ServerRowUtils.getVector(partition.getRow(f + factor));
                Vector nModel = ServerRowUtils.getVector(partition.getRow(f + 2 * factor));
                Vector gradient = ServerRowUtils.getVector(gradientServerRow);

                if (batchSize > 1) {
                    // Turn the summed mini-batch gradient into an average.
                    gradient.idiv((double) batchSize);
                }
                // delta must be computed from n *before* n is updated below.
                // NOTE(review): presumably the per-coordinate FTRL sigma term
                // (sqrt(n + g^2) - sqrt(n)) / alpha — confirm in OptFuncs.ftrldelta.
                Vector delta = OptFuncs.ftrldelta(nModel, gradient, alpha);
                // n += 1.0 * g^2 (assumed semantics of iaxpy2 — confirm in Ufuncs).
                Ufuncs.iaxpy2(nModel, gradient, 1.0);
                // z += g - delta * w
                zModel.iadd(gradient.sub(delta.mul(weight)));
                Vector newWeight = Ufuncs.ftrlthreshold(zModel, nModel, alpha, beta, lambda1, lambda2);
                weight.setStorage(newWeight.getStorage());
                // The gradient has been consumed; reset it for the next accumulation.
                gradient.clear();
            } finally {
                gradientServerRow.endWrite();
            }
        }
    }
}

