/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.psf.optimizer;

import com.tencent.angel.ml.math2.ufuncs.Ufuncs;
import com.tencent.angel.ml.math2.vector.Vector;
import com.tencent.angel.ml.psf.optimizer.OptMMUpdateFunc;
import com.tencent.angel.ps.storage.partition.RowBasedPartition;
import com.tencent.angel.ps.storage.vector.ServerRow;
import com.tencent.angel.ps.storage.vector.ServerRowUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Parameter-server update function that applies one proximal gradient descent (PGD) step with
 * L1/L2 regularization to a row-based matrix partition.
 *
 * <p>Row layout expected in the partition: rows {@code [0, factor)} hold the weights and rows
 * {@code [factor, 2 * factor)} hold the accumulated gradients, where gradient row
 * {@code f + factor} belongs to weight row {@code f}. Each gradient row is cleared after it has
 * been applied.
 *
 * <p>Scalar layout: {@code [lr, l1RegParam, l2RegParam, batchSize]}.
 */
public class PGDUpdateFunc extends OptMMUpdateFunc {
    private static final Log LOG = LogFactory.getLog(PGDUpdateFunc.class);

    /**
     * No-arg constructor; presumably needed for reflective instantiation / deserialization of
     * the function on the server side — TODO confirm.
     */
    public PGDUpdateFunc() {
    }

    /**
     * Creates a PGD update with a batch size of 1 (gradients are applied without averaging).
     *
     * @param matId      id of the matrix whose partitions are updated
     * @param factor     number of weight rows (the gradient rows follow them)
     * @param lr         learning rate
     * @param l1RegParam L1 regularization strength; 0 disables soft-thresholding
     * @param l2RegParam L2 regularization strength; 0 disables the L2 shrinkage
     */
    public PGDUpdateFunc(int matId, int factor, double lr, double l1RegParam, double l2RegParam) {
        this(matId, factor, lr, l1RegParam, l2RegParam, 1);
    }

    /**
     * Creates a PGD update.
     *
     * @param matId      id of the matrix whose partitions are updated
     * @param factor     number of weight rows (the gradient rows follow them)
     * @param lr         learning rate
     * @param l1RegParam L1 regularization strength; 0 disables soft-thresholding
     * @param l2RegParam L2 regularization strength; 0 disables the L2 shrinkage
     * @param batchSize  when greater than 1, accumulated gradients are divided by this value
     */
    public PGDUpdateFunc(int matId, int factor, double lr, double l1RegParam, double l2RegParam, int batchSize) {
        super(matId, new int[]{factor}, new double[]{lr, l1RegParam, l2RegParam, batchSize});
    }

    /**
     * Applies one regularized gradient step to every weight row of {@code partition}.
     *
     * <p>For each {@code f} in {@code [0, factor)}:
     * <ol>
     *   <li>the gradient row {@code f + factor} is divided by {@code batchSize} if it exceeds 1;</li>
     *   <li>{@code w <- w * (1 - lrTemp * l2) - lrTemp * g}, where
     *       {@code lrTemp = lr / (1 + l2 * lr)}. This equals
     *       {@code (w - lr * g) / (1 + lr * l2)};</li>
     *   <li>if {@code l1 != 0}, {@code Ufuncs.isoftthreshold} is applied to {@code w} with
     *       threshold {@code lrTemp * l1};</li>
     *   <li>the gradient row is cleared.</li>
     * </ol>
     *
     * @param partition row-based partition holding weight rows followed by gradient rows
     * @param factor    number of weight rows
     * @param scalars   {@code [lr, l1RegParam, l2RegParam, batchSize]}
     */
    @Override
    public void update(RowBasedPartition partition, int factor, double[] scalars) {
        double lr = scalars[0];
        double l1RegParam = scalars[1];
        double l2RegParam = scalars[2];
        // The batch size travels as a double in the scalar array; truncate it to an integral count.
        double batchSize = (int) scalars[3];

        // Loop-invariant step quantities. With l2RegParam == 0 they reduce to lrTemp == lr and
        // weightScale == 1.
        double lrTemp = lr / (1.0 + l2RegParam * lr);
        double weightScale = 1.0 - lrTemp * l2RegParam;
        double l1Threshold = lrTemp * l1RegParam;

        for (int f = 0; f < factor; ++f) {
            ServerRow gradientServerRow = partition.getRow(f + factor);
            // Acquire the lock before entering try, so that endWrite() only runs for a lock
            // that is actually held.
            gradientServerRow.startWrite();
            try {
                // NOTE(review): the weight row is modified below while only the gradient row's
                // write lock is held; confirm that weight-row access is serialized elsewhere.
                Vector weight = ServerRowUtils.getVector(partition.getRow(f));
                Vector gradient = ServerRowUtils.getVector(gradientServerRow);
                if (batchSize > 1.0) {
                    gradient.idiv(batchSize);
                }
                if (l2RegParam != 0.0) {
                    weight.imul(weightScale).iaxpy(gradient, -lrTemp);
                } else {
                    weight.iaxpy(gradient, -lrTemp);
                }
                if (l1RegParam != 0.0) {
                    Ufuncs.isoftthreshold(weight, l1Threshold);
                }
                // Reset the accumulator so the next round starts from zero.
                gradient.clear();
            } finally {
                gradientServerRow.endWrite();
            }
        }
    }
}

