/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.hasco.knowledgebase;

import ai.libs.hasco.core.Util;
import ai.libs.hasco.knowledgebase.ExtractionOfImportantParametersFailedException;
import ai.libs.hasco.knowledgebase.IParameterImportanceEstimator;
import ai.libs.hasco.knowledgebase.PerformanceKnowledgeBase;
import ai.libs.hasco.model.Component;
import ai.libs.hasco.model.ComponentInstance;
import ai.libs.jaicore.ml.core.FeatureDomain;
import ai.libs.jaicore.ml.core.FeatureSpace;
import ai.libs.jaicore.ml.intervaltree.ExtendedRandomForest;
import com.google.common.collect.Sets;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.Instances;

/**
 * fANOVA-based {@link IParameterImportanceEstimator}.
 *
 * <p>It fits an {@link ExtendedRandomForest} on the performance samples that the
 * {@link PerformanceKnowledgeBase} holds for a composition. It then reports every parameter that
 * belongs to a parameter subset (of at most {@code sizeOfLargestSubsetToConsider} parameters)
 * whose marginal variance contribution reaches {@code importanceThreshold}.
 *
 * <p>Results are cached per pipeline identifier as returned by
 * {@link Util#getComponentNamesOfComposition(ComponentInstance)}. The caches are plain
 * {@link HashMap}/{@link HashSet} instances without any synchronization.
 */
public class FANOVAParameterImportanceEstimator
implements IParameterImportanceEstimator {
    private static final Logger LOGGER = LoggerFactory.getLogger(FANOVAParameterImportanceEstimator.class);

    /** Default maximum size of the parameter subsets whose joint importance is evaluated. */
    private static final int DEFAULT_SIZE_OF_LARGEST_SUBSET_TO_CONSIDER = 2;

    /**
     * Importance used in place of a NaN result from the forest. Such subsets are therefore kept as
     * important, for any threshold up to 1.0.
     */
    private static final double NAN_IMPORTANCE_REPLACEMENT = 1.0;

    private PerformanceKnowledgeBase performanceKnowledgeBase;
    private final String benchmarkName;
    /** Per pipeline identifier: raw marginal variance contribution of each parameter-index subset (may contain NaN). */
    private final Map<String, Map<Set<Integer>, Double>> importanceDictionary;
    /** Per pipeline identifier: names of the parameters found to be important. */
    private final Map<String, Set<String>> importantParameterMap;
    /** Passed as {@code k} to {@link PerformanceKnowledgeBase#kDistinctAttributeValuesAvailable}. */
    private final int minNumSamples;
    /** A subset is important when its (NaN-normalized) importance is at least this value. */
    private final double importanceThreshold;
    /** Largest parameter-subset size whose importance is computed. */
    private final int sizeOfLargestSubsetToConsider;
    /**
     * Parameter names that an extraction considered but did not find important. Later extractions
     * may add names to this set or remove names from it.
     */
    private final Set<String> prunedParameters;

    /**
     * Creates an estimator that evaluates parameter subsets of up to two parameters.
     *
     * @param performanceKnowledgeBase source of the performance samples (may be set later)
     * @param benchmarkName benchmark whose samples are used
     * @param minNumSamples value of {@code k} used by {@link #readyToEstimateImportance}
     * @param importanceThreshold minimal importance for a subset to count as important
     */
    public FANOVAParameterImportanceEstimator(PerformanceKnowledgeBase performanceKnowledgeBase, String benchmarkName, int minNumSamples, double importanceThreshold) {
        this(performanceKnowledgeBase, benchmarkName, minNumSamples, importanceThreshold, DEFAULT_SIZE_OF_LARGEST_SUBSET_TO_CONSIDER);
    }

    /**
     * Creates an estimator that evaluates parameter subsets of up to
     * {@code sizeOfLargestSubsetToConsider} parameters.
     *
     * @param performanceKnowledgeBase source of the performance samples (may be set later)
     * @param benchmarkName benchmark whose samples are used
     * @param minNumSamples value of {@code k} used by {@link #readyToEstimateImportance}
     * @param importanceThreshold minimal importance for a subset to count as important
     * @param sizeOfLargestSubsetToConsider largest subset size to evaluate; must be at least 1
     * @throws IllegalArgumentException if {@code sizeOfLargestSubsetToConsider} is smaller than 1
     */
    public FANOVAParameterImportanceEstimator(PerformanceKnowledgeBase performanceKnowledgeBase, String benchmarkName, int minNumSamples, double importanceThreshold, int sizeOfLargestSubsetToConsider) {
        if (sizeOfLargestSubsetToConsider < 1) {
            throw new IllegalArgumentException("sizeOfLargestSubsetToConsider must be at least 1, but was " + sizeOfLargestSubsetToConsider);
        }
        this.performanceKnowledgeBase = performanceKnowledgeBase;
        this.benchmarkName = benchmarkName;
        this.importanceDictionary = new HashMap<>();
        this.importantParameterMap = new HashMap<>();
        this.minNumSamples = minNumSamples;
        this.importanceThreshold = importanceThreshold;
        this.sizeOfLargestSubsetToConsider = sizeOfLargestSubsetToConsider;
        this.prunedParameters = new HashSet<>();
    }

    /**
     * Creates an estimator without a knowledge base.
     * {@link #setPerformanceKnowledgeBase} must be called before the estimator is used.
     */
    public FANOVAParameterImportanceEstimator(String benchmarkName, int minNumSamples, double importanceThreshold) {
        this(null, benchmarkName, minNumSamples, importanceThreshold);
    }

    /**
     * Determines the important parameters of the pipeline that {@code composition} belongs to.
     *
     * <p>If a result for the pipeline identifier was stored before, it is returned as is,
     * regardless of {@code recompute}. If the samples span fewer than two feature dimensions, all
     * of them are returned without fitting a model, and nothing is cached. Otherwise a forest is
     * fitted, and every parameter contained in a subset whose importance is at least the threshold
     * is reported. NaN importances count as {@value #NAN_IMPORTANCE_REPLACEMENT}.
     *
     * @param composition the composition whose pipeline is analysed
     * @param recompute if {@code true}, subset importances are recomputed instead of being read
     *        from the per-pipeline dictionary
     * @return the names of the important parameters; the same set is stored in the cache
     * @throws ExtractionOfImportantParametersFailedException if the forest cannot be built
     */
    @Override
    public Set<String> extractImportantParameters(ComponentInstance composition, boolean recompute) throws ExtractionOfImportantParametersFailedException {
        String pipelineIdentifier = Util.getComponentNamesOfComposition(composition);
        Set<String> cachedResult = this.importantParameterMap.get(pipelineIdentifier);
        if (cachedResult != null) {
            return cachedResult;
        }
        Instances data = this.performanceKnowledgeBase.getPerformanceSamples(this.benchmarkName, composition);
        FeatureSpace space = new FeatureSpace(data);
        Set<String> importantParameters = new HashSet<>();
        // With fewer than two dimensions there is nothing to compare: report every feature.
        if (space.getDimensionality() < 2) {
            for (FeatureDomain domain : space.getFeatureDomains()) {
                importantParameters.add(domain.getName());
            }
            return importantParameters;
        }
        ExtendedRandomForest forest;
        try {
            forest = fitForest(data);
        }
        catch (Exception e) {
            throw new ExtractionOfImportantParametersFailedException("Could not build model", e);
        }
        // Mark all parameters as pruned only once the model exists, so a failed build leaves no stale entries;
        // the important ones are removed again below.
        for (FeatureDomain domain : space.getFeatureDomains()) {
            this.prunedParameters.add(domain.getName());
        }
        Map<Set<Integer>, Double> importanceBySubset = this.importanceDictionary.get(pipelineIdentifier);
        if (importanceBySubset == null) {
            importanceBySubset = new HashMap<>();
            this.importanceDictionary.put(pipelineIdentifier, importanceBySubset);
        }
        // Every attribute except the last one is treated as a parameter.
        Set<Integer> parameterIndices = new HashSet<>();
        for (int i = 0; i < data.numAttributes() - 1; ++i) {
            parameterIndices.add(i);
        }
        // Sets.combinations rejects sizes larger than the set, so never ask for more than exist.
        int largestSubsetSize = Math.min(this.sizeOfLargestSubsetToConsider, parameterIndices.size());
        for (int k = 1; k <= largestSubsetSize; ++k) {
            for (Set<Integer> subset : Sets.combinations(parameterIndices, k)) {
                double importance = lookupOrComputeImportance(forest, importanceBySubset, subset, recompute);
                boolean important = importance >= this.importanceThreshold;
                LOGGER.debug("Importance value for parameter subset {}: {}", subset, importance);
                LOGGER.debug("Importance value {} >= {}: {}", importance, this.importanceThreshold, important);
                if (important) {
                    for (int index : subset) {
                        importantParameters.add(forest.getFeatureSpace().getFeatureDomain(index).getName());
                    }
                }
            }
        }
        this.importantParameterMap.put(pipelineIdentifier, importantParameters);
        this.prunedParameters.removeAll(importantParameters);
        return importantParameters;
    }

    /**
     * Returns the marginal variance contribution of {@code subset}.
     *
     * <p>The value is read from {@code importanceBySubset} unless {@code recompute} is set or no
     * value is cached. Freshly computed values are cached exactly as the forest returns them (NaN
     * included). The returned value has NaN replaced by {@link #NAN_IMPORTANCE_REPLACEMENT} in
     * every case, so the threshold decision does not depend on where the value came from.
     */
    private static double lookupOrComputeImportance(ExtendedRandomForest forest, Map<Set<Integer>, Double> importanceBySubset, Set<Integer> subset, boolean recompute) {
        Double cachedImportance = recompute ? null : importanceBySubset.get(subset);
        double importance;
        if (cachedImportance != null) {
            LOGGER.debug("Taking value from dictionary");
            importance = cachedImportance;
        } else {
            importance = forest.computeMarginalVarianceContributionForFeatureSubset(subset);
            importanceBySubset.put(subset, importance);
        }
        if (Double.isNaN(importance)) {
            LOGGER.debug("importance value is NaN, so it will be set to 1");
            importance = NAN_IMPORTANCE_REPLACEMENT;
        }
        return importance;
    }

    /** Builds an {@link ExtendedRandomForest} on {@code data} and prepares it for importance queries. */
    private static ExtendedRandomForest fitForest(Instances data) throws Exception {
        ExtendedRandomForest forest = new ExtendedRandomForest();
        forest.buildClassifier(data);
        forest.prepareForest(data);
        return forest;
    }

    /**
     * Computes the single-parameter marginal variance contribution of each parameter of
     * {@code component}.
     *
     * @param component the component whose individual samples are analysed
     * @return a map from attribute name to importance. The result is {@code null} if the knowledge
     *         base has no samples for the component. It may be incomplete if model building or the
     *         computation fails; the failure is logged.
     */
    @Override
    public Map<String, Double> computeImportanceForSingleComponent(Component component) {
        Instances data = this.performanceKnowledgeBase.getPerformanceSamplesForIndividualComponent(this.benchmarkName, component);
        if (data == null) {
            return null;
        }
        Map<String, Double> result = new HashMap<>();
        try {
            // Prepared the same way as in extractImportantParameters before querying contributions.
            ExtendedRandomForest forest = fitForest(data);
            for (int i = 0; i < data.numAttributes() - 1; ++i) {
                Set<Integer> singleParameter = new HashSet<>();
                singleParameter.add(i);
                double importance = forest.computeMarginalVarianceContributionForFeatureSubset(singleParameter);
                result.put(data.attribute(i).name(), importance);
            }
        }
        catch (Exception e) {
            LOGGER.error("Could not build model and compute marginal variance contribution.", e);
        }
        return result;
    }

    /**
     * Returns whether the knowledge base reports at least {@code minNumSamples} distinct attribute
     * values for {@code composition} on the configured benchmark.
     */
    @Override
    public boolean readyToEstimateImportance(ComponentInstance composition) {
        return this.performanceKnowledgeBase.kDistinctAttributeValuesAvailable(this.benchmarkName, composition, this.minNumSamples);
    }

    /** Returns the knowledge base currently used, possibly {@code null}. */
    @Override
    public PerformanceKnowledgeBase getPerformanceKnowledgeBase() {
        return this.performanceKnowledgeBase;
    }

    /** Replaces the knowledge base used for subsequent estimations. */
    @Override
    public void setPerformanceKnowledgeBase(PerformanceKnowledgeBase performanceKnowledgeBase) {
        this.performanceKnowledgeBase = performanceKnowledgeBase;
    }

    /** Returns the current number of pruned parameter names. */
    @Override
    public int getNumberPrunedParameters() {
        return this.prunedParameters.size();
    }

    /** Returns the internal (live, mutable) set of pruned parameter names. */
    @Override
    public Set<String> getPrunedParameters() {
        return this.prunedParameters;
    }
}

