/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.psf.optimizer;

import com.tencent.angel.ml.math2.ufuncs.OptFuncs;
import com.tencent.angel.ml.math2.vector.Vector;
import com.tencent.angel.ml.psf.optimizer.OptMMUpdateFunc;
import com.tencent.angel.ps.storage.partition.RowBasedPartition;
import com.tencent.angel.ps.storage.vector.ServerRow;
import com.tencent.angel.ps.storage.vector.ServerRowUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * PS-side update function that applies one AdaDelta-style optimizer step to every
 * weight row of a row-based matrix partition.
 *
 * <p>For each {@code f} in {@code [0, factor)}, the partition is expected to hold four rows:
 * <ul>
 *   <li>{@code f} - the weight row, updated in place;</li>
 *   <li>{@code f + factor} - the first squared accumulator ({@code square1}), smoothed with
 *       {@code alpha};</li>
 *   <li>{@code f + 2 * factor} - the second squared accumulator ({@code square2}), smoothed with
 *       {@code beta};</li>
 *   <li>{@code f + 3 * factor} - the pushed gradient row, which is cleared once applied.</li>
 * </ul>
 */
public class AdaDeltaUpdateFunc
extends OptMMUpdateFunc {
    private static final Log LOG = LogFactory.getLog(AdaDeltaUpdateFunc.class);

    /**
     * No-arg constructor. NOTE(review): presumably required so the function can be instantiated
     * reflectively / deserialized on the PS side - confirm against {@code OptMMUpdateFunc}.
     */
    public AdaDeltaUpdateFunc() {
    }

    /**
     * Creates an update function whose batch size is 1, i.e. the pushed gradient is applied
     * without being divided by a batch size.
     *
     * @param matId      id of the matrix to update
     * @param factor     number of weight rows; the partition holds {@code 4 * factor} rows
     * @param epsilon    stored in the scalar array (index 0) but not read by {@link #update}
     * @param alpha      smoothing factor for the first squared accumulator
     * @param beta       smoothing factor for the second squared accumulator
     * @param lr         stored in the scalar array (index 3) but not read by {@link #update}
     * @param regL1Param L1 regularization strength; 0 disables the thresholding step
     * @param regL2Param L2 regularization strength; 0 disables adding it to the gradient
     * @param epoch      stored in the scalar array (index 6) but not read by {@link #update}
     */
    public AdaDeltaUpdateFunc(int matId, int factor, double epsilon, double alpha, double beta, double lr, double regL1Param, double regL2Param, int epoch) {
        this(matId, factor, epsilon, alpha, beta, lr, regL1Param, regL2Param, epoch, 1);
    }

    /**
     * Creates an update function. The scalars are packed, in this order, into
     * {@code [epsilon, alpha, beta, lr, regL1Param, regL2Param, epoch, batchSize]}, which is the
     * layout {@link #update} reads back.
     *
     * @param matId      id of the matrix to update
     * @param factor     number of weight rows; the partition holds {@code 4 * factor} rows
     * @param epsilon    stored at scalar index 0 (not read by {@link #update})
     * @param alpha      smoothing factor for the first squared accumulator
     * @param beta       smoothing factor for the second squared accumulator
     * @param lr         stored at scalar index 3 (not read by {@link #update})
     * @param regL1Param L1 regularization strength; 0 disables the thresholding step
     * @param regL2Param L2 regularization strength; 0 disables adding it to the gradient
     * @param epoch      stored at scalar index 6 (not read by {@link #update})
     * @param batchSize  the gradient is divided by this value when it is greater than 1
     */
    public AdaDeltaUpdateFunc(int matId, int factor, double epsilon, double alpha, double beta, double lr, double regL1Param, double regL2Param, int epoch, int batchSize) {
        super(matId, new int[]{factor}, new double[]{epsilon, alpha, beta, lr, regL1Param, regL2Param, epoch, batchSize});
    }

    /**
     * Applies one optimizer step to each of the {@code factor} weight rows in {@code partition}.
     * The gradient row is write-locked while it is read and cleared.
     *
     * @param partition the partition holding the weight, accumulator and gradient rows
     * @param factor    number of weight rows
     * @param scalars   {@code [epsilon, alpha, beta, lr, regL1Param, regL2Param, epoch, batchSize]}
     */
    @Override
    public void update(RowBasedPartition partition, int factor, double[] scalars) {
        // Indices 0 (epsilon), 3 (lr) and 6 (epoch) are carried in the array but unused here.
        double alpha = scalars[1];
        double beta = scalars[2];
        double l1RegParam = scalars[4];
        double l2RegParam = scalars[5];
        // batchSize originates from an int constructor argument; keep the integer truncation.
        double batchSize = (int) scalars[7];

        for (int f = 0; f < factor; f++) {
            ServerRow gradientServerRow = partition.getRow(f + 3 * factor);
            // Acquire before entering try: endWrite() in finally must only release a lock we hold,
            // otherwise a failed acquisition would be masked by an unlock error.
            gradientServerRow.startWrite();
            try {
                Vector weight = ServerRowUtils.getVector(partition.getRow(f));
                Vector square1 = ServerRowUtils.getVector(partition.getRow(f + factor));
                Vector square2 = ServerRowUtils.getVector(partition.getRow(f + 2 * factor));
                Vector gradient = ServerRowUtils.getVector(gradientServerRow);

                // The pushed gradient is a sum over the batch; turn it into a mean.
                if (batchSize > 1.0) {
                    gradient.idiv(batchSize);
                }

                OptFuncs.iexpsmoothing2(square1, gradient, alpha);
                Vector hessian = OptFuncs.adadeltahessian(square1, square2);

                if (l2RegParam != 0.0) {
                    gradient.iaxpy(weight, l2RegParam);
                }

                // NOTE(review): l2RegParam is passed here as well as added into the gradient
                // above; confirm iadadeltadelta is meant to receive it.
                OptFuncs.iadadeltadelta(gradient, hessian, l2RegParam);
                weight.isub(gradient);

                // After iadadeltadelta, `gradient` holds the applied delta; smooth it into square2.
                OptFuncs.iexpsmoothing2(square2, gradient, beta);

                if (l1RegParam != 0.0) {
                    OptFuncs.iadadeltathredshold(weight, hessian, l1RegParam, l2RegParam);
                }

                // Reset the gradient row so the next push starts from zero.
                gradient.clear();
            } finally {
                gradientServerRow.endWrite();
            }
        }
    }
}

