/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.multilabel.sgd.objectives;

import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.util.SigmoidNormalizer;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.multilabel.sgd.MultiLabelObjective;

/**
 * A multi-label objective which applies an independent binary cross-entropy
 * (logistic) loss to each label dimension and sums the result.
 * <p>
 * Predictions are treated as unnormalised scores (logits). Each score is passed
 * through the sigmoid function to obtain a per-label output. The sigmoid is also
 * the normalizer reported by {@link #getNormalizer()}.
 */
public final class BinaryCrossEntropy implements MultiLabelObjective {
    /** Normalizer which applies the sigmoid function to the predicted scores. */
    private static final VectorNormalizer normalizer = new SigmoidNormalizer();

    /**
     * Computes the binary cross-entropy loss summed over all labels, together with
     * the negative gradient of that loss with respect to the predicted scores.
     * <p>
     * For each dimension the loss uses the numerically stable formulation
     * {@code max(x, 0) - x * z + log1p(exp(-|x|))}, where {@code x} is the score and
     * {@code z} the target. This avoids overflow in {@code exp} for large-magnitude
     * scores.
     * <p>
     * Note: when {@code prediction} is a {@link DenseVector} it is overwritten in
     * place with the gradient and returned. A {@link SparseVector} prediction is
     * first densified into a new vector, which is then overwritten and returned.
     *
     * @param truth      The target vector, one entry per label (presumably 0/1
     *                   indicators; confirm against the caller). It must have at
     *                   least {@code prediction.size()} entries.
     * @param prediction The predicted scores (logits), one per label.
     * @return A pair of the summed loss and a vector holding
     *         {@code -(sigmoid(x) - z)} for each dimension.
     */
    @Override
    public Pair<Double, SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction) {
        DenseVector labels = toDense(truth);
        DenseVector densePred = toDense(prediction);
        double loss = 0.0;
        for (int i = 0; i < densePred.size(); i++) {
            double label = labels.get(i);
            double score = densePred.get(i);
            double yhat = SigmoidNormalizer.sigmoid(score);
            // Stable rewrite of -[z*log(sigmoid(x)) + (1-z)*log(1-sigmoid(x))].
            loss += Math.max(score, 0.0) - score * label + Math.log1p(Math.exp(-Math.abs(score)));
            // d(loss)/dx = sigmoid(x) - z; the negated value is stored, reusing the
            // prediction buffer so that no extra allocation is needed.
            densePred.set(i, -(yhat - label));
        }
        return new Pair<>(loss, densePred);
    }

    /**
     * Returns a dense view of the supplied vector.
     * <p>
     * Sparse vectors are densified into a new vector. Dense vectors are returned
     * as-is (same instance).
     *
     * @param vector The vector to convert.
     * @return A dense representation of {@code vector}.
     */
    private static DenseVector toDense(SGDVector vector) {
        return vector instanceof SparseVector ? ((SparseVector) vector).densify() : (DenseVector) vector;
    }

    /**
     * Returns the sigmoid normalizer used to turn scores into per-label outputs.
     *
     * @return The shared sigmoid normalizer.
     */
    @Override
    public VectorNormalizer getNormalizer() {
        return normalizer;
    }

    /**
     * Reports that this objective produces probabilistic outputs.
     *
     * @return Always {@code true}.
     */
    @Override
    public boolean isProbabilistic() {
        return true;
    }

    /**
     * Returns the decision threshold for this objective. Presumably callers apply it
     * to the sigmoid-normalized outputs to decide label presence; verify against the
     * caller.
     *
     * @return Always {@code 0.5}.
     */
    @Override
    public double threshold() {
        return 0.5;
    }

    @Override
    public String toString() {
        return "BinaryCrossEntropy";
    }

    /**
     * Returns the configuration provenance of this objective.
     *
     * @return A provenance recording this object under the type name "MultiLabelObjective".
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "MultiLabelObjective");
    }
}

