/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.multilabel.sgd.objectives;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.util.NoopNormalizer;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.multilabel.sgd.MultiLabelObjective;

/**
 * Hinge loss for multi-label SGD, applied independently to each label dimension.
 * <p>
 * Each entry of the ground-truth vector is mapped to a binary target: {@code 0.0} becomes
 * {@code -1}, any other value becomes {@code +1}. The per-label loss is
 * {@code max(0, margin - target * prediction)}, and the reported loss is the sum over all labels.
 * <p>
 * This objective reports itself as non-probabilistic and uses a {@link NoopNormalizer}.
 */
public final class Hinge implements MultiLabelObjective {
    @Config(description="The classification margin.")
    private double margin = 1.0;

    /**
     * Constructs a hinge loss with the supplied margin.
     *
     * @param margin The classification margin. A label whose {@code target * prediction} is
     *               below this value contributes to the loss and receives a non-zero update.
     */
    public Hinge(double margin) {
        this.margin = margin;
    }

    /**
     * Constructs a hinge loss with a margin of {@code 1.0}.
     * <p>
     * NOTE(review): presumably this is also the constructor the OLCUT configuration system uses
     * before injecting the {@code margin} field — confirm before changing it.
     */
    public Hinge() {
        this(1.0);
    }

    /**
     * Computes the summed hinge loss and the per-label update vector.
     * <p>
     * A {@link SparseVector} argument is densified into a new vector. Any other argument is cast
     * to {@link DenseVector} (a different {@link SGDVector} subtype throws
     * {@link ClassCastException}). In particular, a dense {@code prediction} is NOT copied: its
     * entries are overwritten with the update values and that same instance is returned.
     * <p>
     * For each label {@code i} with target {@code y = truth[i] == 0 ? -1 : +1} and
     * {@code score = y * prediction[i]}:
     * <ul>
     *   <li>the update entry is {@code y} when {@code score < margin}, otherwise {@code 0};</li>
     *   <li>the loss contribution is {@code max(0, margin - score)}.</li>
     * </ul>
     * NOTE(review): the update is {@code +y}, the negation of the textbook derivative
     * {@code -y}. Presumably this matches the trainer's update convention — confirm.
     *
     * @param truth      The ground-truth label vector; entries equal to {@code 0.0} are negatives.
     * @param prediction The per-label scores; overwritten in place if already dense.
     * @return A pair of (total loss, per-label update vector).
     */
    @Override
    public Pair<Double, SGDVector> lossAndGradient(SGDVector truth, SGDVector prediction) {
        DenseVector labels = toDense(truth);
        // The prediction buffer is reused to hold the update values, avoiding an extra allocation.
        DenseVector update = toDense(prediction);
        double loss = 0.0;
        for (int i = 0; i < labels.size(); i++) {
            double target = labels.get(i) == 0.0 ? -1.0 : 1.0;
            // Read the score before the slot is overwritten with the update value.
            double score = target * update.get(i);
            update.set(i, score < this.margin ? target : 0.0);
            loss += Math.max(0.0, this.margin - score);
        }
        return new Pair<>(loss, update);
    }

    /**
     * Returns {@code vector} as a {@link DenseVector}.
     * <p>
     * Sparse vectors are densified into a new instance. Anything else is cast, so the caller's
     * dense instance is returned as-is.
     */
    private static DenseVector toDense(SGDVector vector) {
        return vector instanceof SparseVector ? ((SparseVector) vector).densify() : (DenseVector) vector;
    }

    /**
     * Returns the normalizer applied to this objective's outputs.
     *
     * @return A new {@link NoopNormalizer}; hinge scores are used unnormalized.
     */
    @Override
    public VectorNormalizer getNormalizer() {
        return new NoopNormalizer();
    }

    /**
     * Reports whether this objective produces probabilistic outputs.
     *
     * @return Always {@code false}.
     */
    @Override
    public boolean isProbabilistic() {
        return false;
    }

    /**
     * Returns the decision threshold associated with this objective.
     * <p>
     * NOTE(review): presumably scores above this value are treated as positive labels by the
     * caller — confirm against the model's prediction code.
     *
     * @return Always {@code 0.0}.
     */
    @Override
    public double threshold() {
        return 0.0;
    }

    @Override
    public String toString() {
        return "MultiLabelHinge(margin=" + this.margin + ")";
    }

    /**
     * Builds the configuration provenance for this objective.
     *
     * @return A provenance recording this object's configured fields under the host type name
     *         {@code "MultiLabelObjective"}.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "MultiLabelObjective");
    }
}

