/*
 * Decompiled with CFR 0.152.
 */
package net.sf.tweety.machinelearning;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashSet;
import java.util.StringTokenizer;
import libsvm.svm_node;
import libsvm.svm_problem;
import net.sf.tweety.commons.util.Pair;
import net.sf.tweety.machinelearning.Category;
import net.sf.tweety.machinelearning.DefaultObservation;
import net.sf.tweety.machinelearning.DoubleCategory;
import net.sf.tweety.machinelearning.Observation;

/**
 * A training set for supervised learning, stored as a set of
 * (observation, category) pairs.
 *
 * @param <S> the type of observations
 * @param <T> the type of categories
 */
public class TrainingSet<S extends Observation, T extends Category>
extends HashSet<Pair<S, T>> {
    private static final long serialVersionUID = 6814079760992723045L;

    /**
     * Adds the pair (obs, cat) to this training set.
     *
     * @param obs an observation
     * @param cat the category of that observation
     * @return {@code true} if the pair was not already contained in this set
     */
    public boolean add(S obs, T cat) {
        return this.add(new Pair<S, T>(obs, cat));
    }

    /**
     * Collects the distinct categories occurring in this training set.
     *
     * @return a newly created set of categories; modifying it does not
     *         affect this training set
     */
    public Collection<T> getCategories() {
        Collection<T> cats = new HashSet<T>();
        for (Pair<S, T> entry : this) {
            cats.add(entry.getSecond());
        }
        return cats;
    }

    /**
     * Returns the sub-training-set of all pairs whose category equals the given one.
     *
     * @param cat the category to filter on
     * @return a new training set containing only the matching pairs
     */
    public TrainingSet<S, T> getObservations(T cat) {
        TrainingSet<S, T> result = new TrainingSet<S, T>();
        for (Pair<S, T> entry : this) {
            if (entry.getSecond().equals(cat)) {
                result.add(entry);
            }
        }
        return result;
    }

    /**
     * Converts this training set into a libsvm problem instance.
     *
     * <p>Entry {@code i} of {@code problem.y} is the category value of the
     * i-th pair, and entry {@code i} of {@code problem.x} is its observation
     * converted via {@code toSvmNode()}. Pairs are visited in this set's
     * iteration order.
     *
     * @return a libsvm problem with {@code l == size()}
     */
    public svm_problem toLibsvmProblem() {
        svm_problem problem = new svm_problem();
        problem.l = this.size();
        problem.y = new double[problem.l];
        problem.x = new svm_node[problem.l][];
        int idx = 0;
        for (Pair<S, T> entry : this) {
            problem.y[idx] = entry.getSecond().asDouble();
            problem.x[idx] = entry.getFirst().toSvmNode();
            idx++;
        }
        return problem;
    }

    /**
     * Loads a training set from a file in libsvm text format, i.e. lines of the form
     * {@code <label> <index1>:<value1> <index2>:<value2> ...}.
     *
     * <p>Note: the feature indices are not interpreted; the feature values are
     * appended to the observation in the order they appear on the line.
     * Blank lines are skipped.
     *
     * @param file the file to read (decoded as UTF-8)
     * @return the training set read from the file
     * @throws NumberFormatException if a label or feature value is not a number,
     *         or a feature token lacks the {@code ':'} separator
     * @throws IOException if reading the file fails
     */
    public static TrainingSet<DefaultObservation, DoubleCategory> loadLibsvmTrainingFile(File file) throws NumberFormatException, IOException {
        TrainingSet<DefaultObservation, DoubleCategory> set = new TrainingSet<DefaultObservation, DoubleCategory>();
        // try-with-resources ensures the file is closed even if parsing fails midway.
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                // Same whitespace delimiters as libsvm's own reader, so tab-separated files work too.
                StringTokenizer tokens = new StringTokenizer(line, " \t\n\r\f");
                if (!tokens.hasMoreTokens()) {
                    continue; // blank line (e.g. trailing newline at end of file)
                }
                DoubleCategory cat = new DoubleCategory(Double.parseDouble(tokens.nextToken()));
                DefaultObservation obs = new DefaultObservation();
                while (tokens.hasMoreTokens()) {
                    obs.add(parseFeatureValue(tokens.nextToken()));
                }
                set.add(obs, cat);
            }
        }
        return set;
    }

    /**
     * Extracts the value part of a libsvm {@code <index>:<value>} feature token.
     *
     * @param token the feature token
     * @return the parsed value
     * @throws NumberFormatException if the token has no {@code ':'} or the value is not a number
     */
    private static double parseFeatureValue(String token) {
        int colon = token.indexOf(':');
        if (colon < 0) {
            throw new NumberFormatException("Expected <index>:<value> but found \"" + token + "\"");
        }
        return Double.parseDouble(token.substring(colon + 1));
    }
}

