/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.psf.optimizer;

import com.tencent.angel.ml.math2.vector.Vector;
import com.tencent.angel.ml.psf.optimizer.OptMMUpdateFunc;
import com.tencent.angel.ps.storage.partition.RowBasedPartition;
import com.tencent.angel.ps.storage.vector.ServerRow;
import com.tencent.angel.ps.storage.vector.ServerRowUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * PS-side update function that applies SGD with momentum to a row-based partition.
 *
 * <p>For each index {@code f} in {@code [0, factor)}, {@link #update} reads three rows of the
 * partition: the weight at row {@code f}, the velocity at row {@code f + factor} and the
 * gradient at row {@code f + 2 * factor}. It then applies:
 *
 * <pre>
 *   g = g / batchSize        (only when batchSize &gt; 1)
 *   g = iaxpy(g, w, regParam) (only when regParam != 0; presumably g += regParam * w)
 *   v = momentum * v + g
 *   w = w - lr * v
 *   g cleared
 * </pre>
 *
 * <p>The scalars are passed to {@link OptMMUpdateFunc} in the order
 * {@code {momentum, lr, regParam, batchSize}}.
 */
public class MomentumUpdateFunc extends OptMMUpdateFunc {
    private static final Log LOG = LogFactory.getLog(MomentumUpdateFunc.class);

    /** No-arg constructor (kept for frameworks that instantiate the function reflectively). */
    public MomentumUpdateFunc() {
    }

    /**
     * Creates a momentum update with no regularization and a batch size of 1.
     *
     * @param matId    id of the matrix to update
     * @param offset   value wrapped as the single-element int array given to OptMMUpdateFunc
     * @param momentum momentum coefficient applied to the previous velocity
     * @param lr       learning rate
     */
    public MomentumUpdateFunc(int matId, int offset, double momentum, double lr) {
        this(matId, offset, momentum, lr, 0.0, 1);
    }

    /**
     * Creates a momentum update with the given regularization and a batch size of 1.
     *
     * @param matId    id of the matrix to update
     * @param offset   value wrapped as the single-element int array given to OptMMUpdateFunc
     * @param momentum momentum coefficient applied to the previous velocity
     * @param lr       learning rate
     * @param regParam regularization coefficient; 0 disables the regularization term
     */
    public MomentumUpdateFunc(int matId, int offset, double momentum, double lr, double regParam) {
        this(matId, offset, momentum, lr, regParam, 1);
    }

    /**
     * Creates a fully-specified momentum update.
     *
     * @param matId     id of the matrix to update
     * @param offset    value wrapped as the single-element int array given to OptMMUpdateFunc
     * @param momentum  momentum coefficient applied to the previous velocity
     * @param lr        learning rate
     * @param regParam  regularization coefficient; 0 disables the regularization term
     * @param batchSize divisor applied to the gradient when greater than 1
     */
    public MomentumUpdateFunc(int matId, int offset, double momentum, double lr, double regParam, int batchSize) {
        super(matId, new int[]{offset}, new double[]{momentum, lr, regParam, batchSize});
    }

    /**
     * Applies one momentum step to every weight/velocity/gradient row triple in the partition.
     *
     * @param partition partition holding the weight, velocity and gradient rows
     * @param factor    number of row triples; also the stride between weight, velocity and
     *                  gradient rows
     * @param scalars   {@code {momentum, lr, regParam, batchSize}}
     */
    @Override
    public void update(RowBasedPartition partition, int factor, double[] scalars) {
        double momentum = scalars[0];
        double lr = scalars[1];
        double regParam = scalars[2];
        double batchSize = scalars[3];
        for (int f = 0; f < factor; ++f) {
            ServerRow gradientServerRow = partition.getRow(f + 2 * factor);
            // Acquire before entering try: if startWrite() throws, the finally block must not
            // call endWrite() on a lock this thread never acquired.
            gradientServerRow.startWrite();
            try {
                // NOTE(review): only the gradient row is write-locked here; the weight and
                // velocity rows are mutated without their own locks. Confirm callers serialize
                // access to them.
                Vector weight = ServerRowUtils.getVector(partition.getRow(f));
                Vector velocity = ServerRowUtils.getVector(partition.getRow(f + factor));
                Vector gradient = ServerRowUtils.getVector(gradientServerRow);
                if (batchSize > 1.0) {
                    // Turn the accumulated gradient sum into a mean.
                    gradient.idiv(batchSize);
                }
                if (regParam != 0.0) {
                    // Presumably gradient += regParam * weight (axpy semantics).
                    gradient.iaxpy(weight, regParam);
                }
                velocity.imul(momentum).iadd(gradient);
                weight.isub(velocity.mul(lr));
                // Reset the gradient so the next round starts accumulating from zero.
                gradient.clear();
            } finally {
                gradientServerRow.endWrite();
            }
        }
    }
}

