/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.psf.optimizer;

import com.tencent.angel.ml.math2.ufuncs.OptFuncs;
import com.tencent.angel.ml.math2.vector.Vector;
import com.tencent.angel.ml.psf.optimizer.OptMMUpdateFunc;
import com.tencent.angel.ps.storage.partition.RowBasedPartition;
import com.tencent.angel.ps.storage.vector.ServerRow;
import com.tencent.angel.ps.storage.vector.ServerRowUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Parameter-server function that applies an Adam-style update to a group of {@code factor}
 * weight rows held in a row-based matrix partition.
 *
 * <p>From the row indices used in {@link #update}, the partition is expected to hold
 * {@code 4 * factor} rows laid out as:
 * <ul>
 *   <li>{@code [0, factor)}: weights</li>
 *   <li>{@code [factor, 2 * factor)}: velocity, smoothed with {@code beta} via
 *       {@code OptFuncs.iexpsmoothing}</li>
 *   <li>{@code [2 * factor, 3 * factor)}: square, smoothed with {@code gamma} via
 *       {@code OptFuncs.iexpsmoothing2}</li>
 *   <li>{@code [3 * factor, 4 * factor)}: accumulated gradient, cleared after each update</li>
 * </ul>
 */
public class AdamUpdateFunc
extends OptMMUpdateFunc {
    private static final Log LOG = LogFactory.getLog(AdamUpdateFunc.class);

    /**
     * No-argument constructor. Presumably required so the framework can instantiate the
     * function reflectively before populating it. TODO confirm against OptMMUpdateFunc.
     */
    public AdamUpdateFunc() {
    }

    /**
     * Creates an Adam update with a batch size of 1, so the gradient is not divided.
     *
     * @param matId     id of the matrix to update
     * @param factor    number of weight rows updated together
     * @param gamma     smoothing factor for the "square" rows
     * @param epsilon   forwarded to the server, but not read by {@link #update}
     * @param beta      smoothing factor for the "velocity" rows
     * @param lr        learning rate applied to the computed delta
     * @param regParam  coefficient with which the weights are added to the gradient (0 disables it)
     * @param iteration current iteration; used as the exponent for beta/gamma (0 is treated as 1)
     */
    public AdamUpdateFunc(int matId, int factor, double gamma, double epsilon, double beta, double lr, double regParam, int iteration) {
        this(matId, factor, gamma, epsilon, beta, lr, regParam, iteration, 1);
    }

    /**
     * Creates an Adam update.
     *
     * <p>All numeric hyper-parameters are packed into the scalar array in the order
     * {@code {gamma, epsilon, beta, lr, regParam, iteration, batchSize}}. {@link #update} reads
     * them back by these positions.
     *
     * @param matId     id of the matrix to update
     * @param factor    number of weight rows updated together
     * @param gamma     smoothing factor for the "square" rows
     * @param epsilon   forwarded to the server, but not read by {@link #update}
     * @param beta      smoothing factor for the "velocity" rows
     * @param lr        learning rate applied to the computed delta
     * @param regParam  coefficient with which the weights are added to the gradient (0 disables it)
     * @param iteration current iteration; used as the exponent for beta/gamma (0 is treated as 1)
     * @param batchSize the accumulated gradient is divided by this when it is greater than 1
     */
    public AdamUpdateFunc(int matId, int factor, double gamma, double epsilon, double beta, double lr, double regParam, int iteration, int batchSize) {
        super(matId, new int[]{factor}, new double[]{gamma, epsilon, beta, lr, regParam, iteration, batchSize});
    }

    /**
     * Applies one Adam step to every weight row in {@code partition}.
     *
     * <p>For each factor {@code f}, the gradient row {@code f + 3 * factor} is processed as
     * follows:
     * <ol>
     *   <li>The row is write-locked.</li>
     *   <li>The row is divided by {@code batchSize} when {@code batchSize > 1}.</li>
     *   <li>{@code regParam * weight} is added to it.</li>
     *   <li>The row is used to refresh the velocity and square rows.</li>
     *   <li>{@code weight} is moved by {@code -lr * adamdelta(...)}.</li>
     *   <li>The gradient row is cleared.</li>
     * </ol>
     * Only the gradient row is locked; the weight, velocity and square rows are accessed without
     * a lock here.
     *
     * @param partition row-based partition holding {@code 4 * factor} rows (see class docs)
     * @param factor    number of weight rows
     * @param scalars   {@code {gamma, epsilon, beta, lr, regParam, iteration, batchSize}}
     */
    @Override
    public void update(RowBasedPartition partition, int factor, double[] scalars) {
        double gamma = scalars[0];
        // NOTE(review): scalars[1] (epsilon) is intentionally not read. OptFuncs.adamdelta is
        // called without it, so presumably it applies its own constant. TODO confirm.
        double beta = scalars[2];
        double lr = scalars[3];
        double regParam = scalars[4];
        // An iteration of 0 is treated as 1. Otherwise both powers below would be exactly 1.0,
        // which would presumably zero a (1 - pow) bias-correction term inside adamdelta.
        double epoch = scalars[5] == 0.0 ? 1.0 : scalars[5];
        double batchSize = scalars[6];

        // Loop-invariant powers handed to adamdelta.
        double powBeta = Math.pow(beta, epoch);
        double powGamma = Math.pow(gamma, epoch);

        for (int f = 0; f < factor; ++f) {
            ServerRow gradientServerRow = partition.getRow(f + 3 * factor);
            // Acquire the lock BEFORE entering try. If startWrite() itself fails, finally must
            // not call endWrite() on a lock that was never obtained, since that would mask the
            // original failure.
            gradientServerRow.startWrite();
            try {
                Vector weight = ServerRowUtils.getVector(partition.getRow(f));
                Vector velocity = ServerRowUtils.getVector(partition.getRow(f + factor));
                Vector square = ServerRowUtils.getVector(partition.getRow(f + 2 * factor));
                Vector gradient = ServerRowUtils.getVector(gradientServerRow);

                // Turn the summed gradient into an average over the batch.
                if (batchSize > 1.0) {
                    gradient.idiv(batchSize);
                }
                // Add the weight-proportional term: gradient += regParam * weight.
                if (regParam != 0.0) {
                    gradient.iaxpy(weight, regParam);
                }

                OptFuncs.iexpsmoothing(velocity, gradient, beta);
                OptFuncs.iexpsmoothing2(square, gradient, gamma);
                Vector delta = OptFuncs.adamdelta(velocity, square, powBeta, powGamma);
                weight.iaxpy(delta, -lr);

                // The gradient has been consumed; reset it for the next accumulation round.
                gradient.clear();
            } finally {
                gradientServerRow.endWrite();
            }
        }
    }
}

