/*
 * Decompiled with CFR 0.152.
 */
package net.sf.tweety.machinelearning.svm;

import libsvm.svm;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import net.sf.tweety.machinelearning.DefaultObservation;
import net.sf.tweety.machinelearning.DoubleCategory;
import net.sf.tweety.machinelearning.ParameterSet;
import net.sf.tweety.machinelearning.Trainer;
import net.sf.tweety.machinelearning.TrainingParameter;
import net.sf.tweety.machinelearning.TrainingSet;
import net.sf.tweety.machinelearning.svm.SupportVectorMachine;

/**
 * Trains a {@link SupportVectorMachine} through libsvm using C-support vector
 * classification with a radial basis function (RBF) kernel.
 *
 * <p>The trainer is configured by two parameters, {@code C} and {@code gamma}.
 * Each instance holds its own instantiation of both, which is used by
 * {@link #train(TrainingSet)} and can be replaced via
 * {@link #setParameterSet(ParameterSet)}.
 */
public class MultiClassRbfTrainer
implements Trainer<DefaultObservation, DoubleCategory> {
    /**
     * Template for the {@code C} parameter passed to libsvm.
     * NOTE(review): the numeric arguments are presumably default value and
     * bounds - confirm against {@link TrainingParameter}'s constructor.
     */
    public static final TrainingParameter C_PARAMETER = new TrainingParameter("C", 2000.0, 2000.0, 2.0E-5, 2.0E15);
    /**
     * Template for the {@code gamma} parameter passed to libsvm.
     * NOTE(review): the numeric arguments are presumably default value and
     * bounds - confirm against {@link TrainingParameter}'s constructor.
     */
    public static final TrainingParameter GAMMA_PARAMETER = new TrainingParameter("gamma", 2.0E-8, 2.0E-8, 2.0E-15, 2000.0);

    /** Stopping tolerance handed to libsvm ({@code svm_parameter.eps}). */
    private static final double STOPPING_TOLERANCE = 0.001;
    /** Kernel cache size handed to libsvm ({@code svm_parameter.cache_size}). */
    private static final double KERNEL_CACHE_SIZE = 256.0;

    /** This trainer's current instantiation of {@link #C_PARAMETER}. */
    private TrainingParameter c;
    /** This trainer's current instantiation of {@link #GAMMA_PARAMETER}. */
    private TrainingParameter gamma;

    /**
     * Creates a trainer whose {@code C} and {@code gamma} are instantiated
     * with the default values of {@link #C_PARAMETER} and
     * {@link #GAMMA_PARAMETER}.
     */
    public MultiClassRbfTrainer() {
        this.c = C_PARAMETER.instantiateWithDefaultValue();
        this.gamma = GAMMA_PARAMETER.instantiateWithDefaultValue();
    }

    /**
     * Creates a trainer with explicit parameter values.
     *
     * @param c the value used to instantiate {@link #C_PARAMETER}
     * @param gamma the value used to instantiate {@link #GAMMA_PARAMETER}
     */
    public MultiClassRbfTrainer(double c, double gamma) {
        this.c = C_PARAMETER.instantiate(c);
        this.gamma = GAMMA_PARAMETER.instantiate(gamma);
    }

    /**
     * Trains a support vector machine on the given training set using this
     * trainer's current {@code C} and {@code gamma}.
     *
     * @param trainingSet the observations and categories to learn from
     * @return the trained support vector machine
     */
    public SupportVectorMachine train(TrainingSet<DefaultObservation, DoubleCategory> trainingSet) {
        return this.train(trainingSet, this.currentParameters());
    }

    /**
     * Trains a support vector machine on the given training set using the
     * {@code C} and {@code gamma} values found in {@code params}. The values
     * stored in this trainer are not consulted.
     *
     * @param trainingSet the observations and categories to learn from
     * @param params a parameter set containing both {@link #C_PARAMETER} and
     *        {@link #GAMMA_PARAMETER}
     * @return the trained support vector machine
     * @throws IllegalArgumentException if {@code params} lacks either parameter
     */
    public SupportVectorMachine train(TrainingSet<DefaultObservation, DoubleCategory> trainingSet, ParameterSet params) {
        requireKernelParameters(params);
        svm_parameter param = new svm_parameter();
        // Named libsvm constants instead of the raw integer codes 0 and 2.
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.RBF;
        param.eps = STOPPING_TOLERANCE;
        param.cache_size = KERNEL_CACHE_SIZE;
        // No per-class weighting of C.
        param.nr_weight = 0;
        param.gamma = params.getParameter(GAMMA_PARAMETER).getValue();
        param.C = params.getParameter(C_PARAMETER).getValue();
        svm_problem problem = trainingSet.toLibsvmProblem();
        return new SupportVectorMachine(svm.svm_train(problem, param));
    }

    /**
     * Returns a new parameter set holding this trainer's current {@code C}
     * and {@code gamma}.
     *
     * @return a freshly created parameter set with both parameters
     */
    @Override
    public ParameterSet getParameterSet() {
        return this.currentParameters();
    }

    /**
     * Replaces this trainer's {@code C} and {@code gamma} with the ones found
     * in {@code params}.
     *
     * @param params a parameter set containing both {@link #C_PARAMETER} and
     *        {@link #GAMMA_PARAMETER}
     * @return always {@code true}
     * @throws IllegalArgumentException if {@code params} lacks either parameter
     */
    @Override
    public boolean setParameterSet(ParameterSet params) {
        requireKernelParameters(params);
        this.c = params.getParameter(C_PARAMETER);
        this.gamma = params.getParameter(GAMMA_PARAMETER);
        return true;
    }

    /**
     * Renders this trainer as {@code RBF<C,gamma>} using the current values.
     */
    @Override
    public String toString() {
        return "RBF<" + this.c.getValue() + "," + this.gamma.getValue() + ">";
    }

    /** Builds a new parameter set from this trainer's current C and gamma. */
    private ParameterSet currentParameters() {
        ParameterSet set = new ParameterSet();
        set.add(this.c);
        set.add(this.gamma);
        return set;
    }

    /** Rejects parameter sets that do not contain both C and gamma. */
    private static void requireKernelParameters(ParameterSet params) {
        if (!params.containsParameter(C_PARAMETER) || !params.containsParameter(GAMMA_PARAMETER)) {
            throw new IllegalArgumentException("Parameters missing.");
        }
    }
}

