/*
 * Decompiled with CFR 0.152.
 */
package net.sf.tweety.machinelearning;

import net.sf.tweety.machinelearning.Category;
import net.sf.tweety.machinelearning.ClassificationTester;
import net.sf.tweety.machinelearning.Observation;
import net.sf.tweety.machinelearning.ParameterSet;
import net.sf.tweety.machinelearning.ParameterTrainer;
import net.sf.tweety.machinelearning.Trainer;
import net.sf.tweety.machinelearning.TrainingParameter;
import net.sf.tweety.machinelearning.TrainingSet;

/**
 * Learns the parameters of a trainer by an iterated grid search.
 * <p>
 * Each {@link TrainingParameter} in the wrapped trainer's parameter set is searched
 * over the interval [lower bound, upper bound]. The interval is split into
 * {@code partitions} equally sized steps, giving {@code partitions + 1} candidate
 * values per parameter (both end points included). Every combination of candidate
 * values is installed on the trainer and scored with the given
 * {@link ClassificationTester}. Higher scores are treated as better, and ties keep
 * the first combination evaluated.
 * <p>
 * After each round except the last, every parameter interval is narrowed to one step
 * on either side of that round's best candidate, clamped to the current interval.
 * The search then repeats on the narrowed grid, for {@code depth} rounds in total.
 * The returned parameter set is the best one observed over all rounds.
 *
 * @param <S> the observation type
 * @param <T> the category type
 */
public class GridSearchParameterLearner<S extends Observation, T extends Category>
extends ParameterTrainer<S, T> {
    /** Number of search rounds (the initial grid plus refinements). */
    private final int depth;
    /** Number of steps each parameter interval is divided into in every round. */
    private final int partitions;
    /** Scores each candidate parameter set; a higher score is considered better. */
    private final ClassificationTester<S, T> tester;

    /**
     * Creates a grid-search learner for the given trainer.
     *
     * @param trainer    the trainer whose parameters are searched
     * @param tester     used to score the trainer for each candidate parameter set
     * @param depth      number of search rounds; with a value below 1 no candidate is
     *                   evaluated and the lower bounds are returned
     * @param partitions number of steps each parameter interval is split into per round
     * @throws IllegalArgumentException if {@code partitions} is less than 1
     */
    public GridSearchParameterLearner(Trainer<S, T> trainer, ClassificationTester<S, T> tester, int depth, int partitions) {
        super(trainer);
        if (partitions < 1) {
            // A zero step count would divide by zero when computing grid points.
            throw new IllegalArgumentException("partitions must be at least 1, but was " + partitions);
        }
        this.tester = tester;
        this.depth = depth;
        this.partitions = partitions;
    }

    /**
     * Runs the grid search on the given training set.
     * <p>
     * As a side effect, the wrapped trainer's parameter set is overwritten with each
     * evaluated candidate. It is left holding the last candidate tried, not the
     * returned one.
     *
     * @param trainingSet the data passed to the tester for every candidate
     * @return the best-scoring parameter set seen over all rounds; if no candidate
     *         received a comparable score (no rounds run, or only NaN scores), a
     *         parameter set at the current lower bounds
     */
    @Override
    public ParameterSet learnParameters(TrainingSet<S, T> trainingSet) {
        Trainer<S, T> trainer = this.getTrainer();
        ParameterSet set = trainer.getParameterSet();
        int size = set.size();
        // Java zero-initialises arrays, so the grid starts at the lower-bound corner.
        int[] indices = new int[size];
        double[] lowerBounds = new double[size];
        double[] upperBounds = new double[size];
        int idx = 0;
        for (TrainingParameter param : set) {
            lowerBounds[idx] = param.getLowerBound();
            upperBounds[idx] = param.getUpperBound();
            idx++;
        }
        // The overall best is kept across rounds: a refined grid does not necessarily
        // contain the previous round's best point (e.g. for an odd number of partitions).
        double maxPerformance = Double.NEGATIVE_INFINITY;
        ParameterSet bestParameters = null;
        for (int round = 0; round < this.depth; ++round) {
            double roundMax = Double.NEGATIVE_INFINITY;
            int[] roundBestIdxs = new int[size];
            do {
                trainer.setParameterSet(this.adjustParameterSet(set, indices, lowerBounds, upperBounds));
                double performance = this.tester.test(trainer, trainingSet);
                if (performance > roundMax) {
                    roundMax = performance;
                    System.arraycopy(indices, 0, roundBestIdxs, 0, size);
                }
                if (performance > maxPerformance) {
                    maxPerformance = performance;
                    // Build a fresh instance so the result is not shared with the trainer.
                    bestParameters = this.adjustParameterSet(set, indices, lowerBounds, upperBounds);
                }
            } while (!this.increment(indices, this.partitions));
            // increment() has wrapped every index back to 0, ready for the next round.
            if (round + 1 < this.depth) {
                this.refineBounds(roundBestIdxs, lowerBounds, upperBounds);
            }
        }
        if (bestParameters == null) {
            return this.adjustParameterSet(set, new int[size], lowerBounds, upperBounds);
        }
        return bestParameters;
    }

    /**
     * Narrows each parameter interval, in place, to one grid step on either side of
     * the given best grid index, clamped to the current interval.
     *
     * @param bestIdxs    per-parameter grid index of the best candidate in this round
     * @param lowerBounds current lower bounds; overwritten with the refined values
     * @param upperBounds current upper bounds; overwritten with the refined values
     */
    private void refineBounds(int[] bestIdxs, double[] lowerBounds, double[] upperBounds) {
        for (int j = 0; j < bestIdxs.length; ++j) {
            // Capture the old origin and step before overwriting either bound.
            double oldLower = lowerBounds[j];
            double step = (upperBounds[j] - oldLower) / (double) this.partitions;
            lowerBounds[j] = oldLower + step * (double) Math.max(bestIdxs[j] - 1, 0);
            upperBounds[j] = oldLower + step * (double) Math.min(bestIdxs[j] + 1, this.partitions);
        }
    }

    /**
     * Builds a parameter set holding the grid values selected by {@code indices}.
     * The value for position {@code i} is
     * {@code lowerBounds[i] + (upperBounds[i] - lowerBounds[i]) / partitions * indices[i]}.
     * Positions follow the iteration order of {@code set}.
     *
     * @param set         the template parameters to instantiate
     * @param indices     grid index per parameter, expected in {@code [0, partitions]}
     * @param lowerBounds lower interval bound per parameter
     * @param upperBounds upper interval bound per parameter
     * @return a new parameter set with one instantiated value per parameter
     */
    private ParameterSet adjustParameterSet(ParameterSet set, int[] indices, double[] lowerBounds, double[] upperBounds) {
        ParameterSet newParams = new ParameterSet();
        int idx = 0;
        for (TrainingParameter param : set) {
            double step = (upperBounds[idx] - lowerBounds[idx]) / (double) this.partitions;
            newParams.add(param.instantiate(lowerBounds[idx] + step * (double) indices[idx]));
            ++idx;
        }
        return newParams;
    }

    /**
     * Advances {@code indices} to the next grid point, treating it as a little-endian
     * counter whose digits range over {@code [0, maxIdx]}.
     *
     * @param indices the counter to advance in place
     * @param maxIdx  the largest value a single digit may take
     * @return {@code true} once the counter has wrapped around (every digit reset to 0),
     *         i.e. all grid points have been visited; always {@code true} for an empty
     *         array, so a parameterless search evaluates exactly one candidate
     */
    private boolean increment(int[] indices, int maxIdx) {
        for (int i = 0; i < indices.length; ++i) {
            if (indices[i] < maxIdx) {
                indices[i]++;
                return false;
            }
            indices[i] = 0;
        }
        return true;
    }
}

