/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.IInstance;
import ai.libs.jaicore.ml.core.dataset.attribute.IAttributeType;
import ai.libs.jaicore.ml.core.dataset.attribute.primitive.NumericAttributeType;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.AttributeDiscretizationPolicy;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.Interval;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds and applies discretization policies for numeric attributes so that their (possibly many)
 * distinct values can be reduced to a small number of interval categories.
 *
 * @param <I> the instance type of the datasets handled by this helper
 */
public class DiscretizationHelper<I extends IInstance> {
    private static final Logger LOG = LoggerFactory.getLogger(DiscretizationHelper.class);

    /**
     * Creates a discretization policy for every numeric attribute in {@code indices} that has more
     * distinct values than {@code numberOfCategories}.
     *
     * <p>An index counts as numeric if the corresponding attribute type of {@code dataset} is a
     * {@link NumericAttributeType}. The target attribute is included at index
     * {@code dataset.getAttributeTypes().size()}.
     *
     * @param dataset dataset whose attribute types determine which indices are numeric
     * @param indices attribute indices to consider; non-numeric indices are ignored
     * @param attributeValues distinct values per attribute index; values of the considered numeric
     *        attributes are cast to {@link Double}
     * @param discretizationStrategy strategy used to build the intervals
     * @param numberOfCategories number of intervals to create per attribute
     * @return map from attribute index to its policy; attributes that need no discretization are absent
     * @throws IllegalArgumentException if the strategy is not supported or {@code numberOfCategories}
     *         is not positive while some attribute needs discretization
     */
    public Map<Integer, AttributeDiscretizationPolicy> createDefaultDiscretizationPolicies(IDataset<I> dataset, List<Integer> indices, Map<Integer, Set<Object>> attributeValues, DiscretizationStrategy discretizationStrategy, int numberOfCategories) {
        Map<Integer, AttributeDiscretizationPolicy> discretizationPolicies = new HashMap<>();
        Set<Integer> indicesToConsider = this.getNumericIndicesFromDataset(dataset);
        indicesToConsider.retainAll(indices);
        for (int index : indicesToConsider) {
            List<Double> numericValues = this.getSortedNumericValues(attributeValues, index);
            if (numericValues.size() <= numberOfCategories) {
                // The attribute already has no more distinct values than categories; nothing to reduce.
                LOG.info("No discretization policy for attribute {} needed", index);
                continue;
            }
            discretizationPolicies.put(index, this.createPolicy(numericValues, discretizationStrategy, numberOfCategories));
        }
        return discretizationPolicies;
    }

    /**
     * Dispatches to the policy factory that matches {@code discretizationStrategy}.
     *
     * @throws IllegalArgumentException if the strategy is not supported
     */
    private AttributeDiscretizationPolicy createPolicy(List<Double> numericValues, DiscretizationStrategy discretizationStrategy, int numberOfCategories) {
        switch (discretizationStrategy) {
            case EQUAL_SIZE:
                return this.equalSizePolicy(numericValues, numberOfCategories);
            case EQUAL_LENGTH:
                return this.equalLengthPolicy(numericValues, numberOfCategories);
            default:
                throw new IllegalArgumentException(String.format("Invalid strategy: %s", discretizationStrategy));
        }
    }

    /**
     * Creates intervals that each cover (roughly) the same number of the given values.
     *
     * <p>The values are expected to be sorted in ascending order. Each interval spans from its first
     * to its last covered value. The last interval absorbs the remainder of the integer division. If
     * there are fewer values than categories, one interval per value is created.
     *
     * @param numericValues sorted values to split
     * @param numberOfCategories desired number of intervals
     * @return the resulting policy
     * @throws IllegalArgumentException if {@code numericValues} is empty or
     *         {@code numberOfCategories} is not positive
     */
    public AttributeDiscretizationPolicy equalSizePolicy(List<Double> numericValues, int numberOfCategories) {
        if (numericValues.isEmpty()) {
            throw new IllegalArgumentException("No values provided");
        }
        if (numberOfCategories < 1) {
            throw new IllegalArgumentException(String.format("Number of categories must be positive, but was %d", numberOfCategories));
        }
        int numberOfIntervals = Math.min(numberOfCategories, numericValues.size());
        // Deriving the step from the actual interval count (instead of numberOfCategories) keeps it
        // >= 1, so no interval index can become negative when there are fewer values than categories.
        int stepwidth = numericValues.size() / numberOfIntervals;
        List<Interval> intervals = new ArrayList<>(numberOfIntervals);
        for (int i = 0; i < numberOfIntervals; ++i) {
            int lower = i * stepwidth;
            int upper = i == numberOfIntervals - 1 ? numericValues.size() - 1 : (i + 1) * stepwidth - 1;
            intervals.add(new Interval(numericValues.get(lower), numericValues.get(upper)));
        }
        return new AttributeDiscretizationPolicy(intervals);
    }

    /**
     * Splits the range {@code [min, max]} of the given values into {@code numberOfCategories}
     * intervals of equal width. Neighbouring intervals share their boundary value.
     *
     * @param numericValues values whose range is split (need not be sorted)
     * @param numberOfCategories number of intervals to create
     * @return the resulting policy
     * @throws IllegalArgumentException if {@code numericValues} is empty or
     *         {@code numberOfCategories} is not positive
     */
    public AttributeDiscretizationPolicy equalLengthPolicy(List<Double> numericValues, int numberOfCategories) {
        if (numericValues.isEmpty()) {
            throw new IllegalArgumentException("No values provided");
        }
        if (numberOfCategories < 1) {
            throw new IllegalArgumentException(String.format("Number of categories must be positive, but was %d", numberOfCategories));
        }
        double max = Collections.max(numericValues);
        double min = Collections.min(numericValues);
        double stepwidth = (max - min) / numberOfCategories;
        List<Interval> intervals = new ArrayList<>(numberOfCategories);
        for (int i = 0; i < numberOfCategories; ++i) {
            double lower = min + i * stepwidth;
            // Pin the last upper bound to max: min + n * stepwidth can round to slightly below max,
            // which would leave the maximum value uncovered and make discretize() fail on it.
            double upper = i == numberOfCategories - 1 ? max : min + (i + 1) * stepwidth;
            intervals.add(new Interval(lower, upper));
        }
        return new AttributeDiscretizationPolicy(intervals);
    }

    /**
     * Returns the values recorded for {@code attributeIndex}, cast to {@link Double} and sorted
     * ascending.
     */
    private List<Double> getSortedNumericValues(Map<Integer, Set<Object>> attributeValues, int attributeIndex) {
        Set<Object> values = attributeValues.get(attributeIndex);
        List<Double> sortedValues = new ArrayList<>(values.size());
        for (Object value : values) {
            sortedValues.add((Double) value);
        }
        Collections.sort(sortedValues);
        return sortedValues;
    }

    /**
     * Collects the indices whose {@link IAttributeType} is a {@link NumericAttributeType}. The
     * target type is treated as the attribute directly after the regular attributes.
     */
    private Set<Integer> getNumericIndicesFromDataset(IDataset<I> dataset) {
        List<Object> attributeTypes = new ArrayList<>(dataset.getAttributeTypes());
        attributeTypes.add(dataset.getTargetType());
        Set<Integer> numericAttributes = new HashSet<>();
        for (int i = 0; i < attributeTypes.size(); ++i) {
            if (attributeTypes.get(i) instanceof NumericAttributeType) {
                numericAttributes.add(i);
            }
        }
        return numericAttributes;
    }

    /**
     * Replaces, for every attribute that has a policy, its value set in {@code attributeValues} by
     * the set of interval indices its values fall into.
     *
     * @param discretizationPolicies policies keyed by attribute index
     * @param attributeValues distinct values per attribute index; modified in place
     */
    protected void discretizeAttributeValues(Map<Integer, AttributeDiscretizationPolicy> discretizationPolicies, Map<Integer, Set<Object>> attributeValues) {
        for (Map.Entry<Integer, AttributeDiscretizationPolicy> entry : discretizationPolicies.entrySet()) {
            int index = entry.getKey();
            AttributeDiscretizationPolicy policy = entry.getValue();
            Set<Object> originalValues = attributeValues.get(index);
            // Must be Set<Object> (not Set<Integer>) so it can be stored back into attributeValues.
            Set<Object> discretizedValues = new HashSet<>();
            for (Object value : originalValues) {
                discretizedValues.add(this.discretize((Double) value, policy));
            }
            LOG.info("Attribute index {}: Reduced values from {} to {}", index, originalValues.size(), discretizedValues.size());
            attributeValues.put(index, discretizedValues);
        }
    }

    /**
     * Returns the position of the first interval of {@code policy} that contains {@code value}.
     *
     * @throws IllegalStateException if no interval contains the value
     */
    protected int discretize(double value, AttributeDiscretizationPolicy policy) {
        List<Interval> intervals = policy.getIntervals();
        // Track the position while iterating rather than calling indexOf, which rescans the list
        // and depends on Interval.equals.
        int position = 0;
        for (Interval interval : intervals) {
            if (interval.contains(value)) {
                return position;
            }
            position++;
        }
        throw new IllegalStateException(String.format("Policy does not cover value %f", value));
    }

    /** Strategies for splitting a numeric attribute into intervals. */
    public enum DiscretizationStrategy {
        /** Intervals of equal width between the minimum and the maximum value. */
        EQUAL_LENGTH,
        /** Intervals covering (roughly) the same number of distinct values. */
        EQUAL_SIZE;
    }
}

