/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.ml.core.dataset.ILabeledAttributeArrayInstance;
import ai.libs.jaicore.ml.core.dataset.IOrderedLabeledAttributeArrayDataset;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.AttributeDiscretizationPolicy;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.DiscretizationHelper;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.IStratiAmountSelector;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.IStratiAssigner;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.ListProcessor;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.collections4.keyvalue.MultiKey;
import org.apache.commons.collections4.map.MultiKeyMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Strati amount selector and strati assigner that derives strata from the values of a selected set of attributes.
 *
 * <p>Every element of the cartesian product of the observed values of the selected attributes forms one stratum.
 * Attributes for which an {@link AttributeDiscretizationPolicy} exists are discretized first. If no policies are
 * supplied, default policies are computed via {@link DiscretizationHelper} using the configured strategy and number of
 * categories. The attribute index {@code dataset.getNumberOfAttributes()} denotes the target attribute.
 *
 * <p>NOTE(review): instances keep mutable per-dataset state without synchronization; presumably they are used from a
 * single thread only — confirm with callers.
 */
public class AttributeBasedStratiAmountSelectorAndAssigner<I extends ILabeledAttributeArrayInstance<?>, D extends IOrderedLabeledAttributeArrayDataset<I, ?>>
implements IStratiAmountSelector<D>,
IStratiAssigner<I, D> {
    private static final Logger LOG = LoggerFactory.getLogger(AttributeBasedStratiAmountSelectorAndAssigner.class);
    private static final DiscretizationHelper.DiscretizationStrategy DEFAULT_DISCRETIZATION_STRATEGY = DiscretizationHelper.DiscretizationStrategy.EQUAL_SIZE;
    private static final int DEFAULT_DISCRETIZATION_CATEGORY_AMOUNT = 5;

    /** Attributes used for stratification; null (no-arg constructor) means "target attribute only". */
    private List<Integer> attributeIndices;
    /** Maps a value combination, ordered like {@link #attributeIndices}, to its stratum number. */
    private MultiKeyMap<Object, Integer> stratumAssignments;
    private int numCPUs = 1;
    private D dataset;
    private Map<Integer, AttributeDiscretizationPolicy> discretizationPolicies;
    /** Observed (possibly discretized) values per attribute index. */
    private Map<Integer, Set<Object>> attributeValues;
    // Initialized at declaration so every constructor, including the two-argument one, gets usable defaults.
    private DiscretizationHelper.DiscretizationStrategy discretizationStrategy = DEFAULT_DISCRETIZATION_STRATEGY;
    private int numberOfCategories = DEFAULT_DISCRETIZATION_CATEGORY_AMOUNT;

    /**
     * Creates an instance that stratifies by the target attribute only, using the default discretization settings.
     */
    public AttributeBasedStratiAmountSelectorAndAssigner() {
        // attributeIndices stays null; computeAttributeValues() then falls back to the target attribute.
    }

    /**
     * Creates an instance for the given attributes with default discretization settings.
     *
     * @param attributeIndices indices of the attributes to stratify by; must not be null or empty
     * @throws IllegalArgumentException if {@code attributeIndices} is null or empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(List<Integer> attributeIndices) {
        this(attributeIndices, null);
    }

    /**
     * Creates an instance for the given attributes whose default discretization policies use the given settings.
     *
     * @param attributeIndices indices of the attributes to stratify by; must not be null or empty
     * @param discretizationStrategy strategy used when computing default discretization policies
     * @param numberOfCategories number of categories used when computing default discretization policies
     * @throws IllegalArgumentException if {@code attributeIndices} is null or empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(List<Integer> attributeIndices, DiscretizationHelper.DiscretizationStrategy discretizationStrategy, int numberOfCategories) {
        this(attributeIndices, null);
        this.discretizationStrategy = discretizationStrategy;
        this.numberOfCategories = numberOfCategories;
    }

    /**
     * Creates an instance for the given attributes and explicit discretization policies.
     *
     * @param attributeIndices indices of the attributes to stratify by; must not be null or empty
     * @param discretizationPolicies policies per attribute index, or null to compute defaults on first use
     * @throws IllegalArgumentException if {@code attributeIndices} is null or empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(List<Integer> attributeIndices, Map<Integer, AttributeDiscretizationPolicy> discretizationPolicies) {
        if (attributeIndices == null || attributeIndices.isEmpty()) {
            throw new IllegalArgumentException("No attribute indices are provided!");
        }
        // Defensive copy: later mutations by the caller must not desynchronize keys and lookups.
        this.attributeIndices = new ArrayList<>(attributeIndices);
        this.discretizationPolicies = discretizationPolicies;
    }

    /**
     * Computes the attribute values of {@code dataset} and returns the number of strati.
     *
     * <p>The number of strati is the size of the cartesian product of the value sets of the selected attributes.
     *
     * @param dataset the dataset to stratify; it is remembered for later {@link #init} calls
     * @return the number of strati
     * @throws IndexOutOfBoundsException if a selected attribute index is outside the dataset
     * @throws ArithmeticException if the number of strati does not fit into an {@code int}
     */
    @Override
    public int selectStratiAmount(D dataset) {
        this.dataset = dataset;
        this.computeAttributeValues();
        int noStrati = 1;
        for (Set<Object> values : this.attributeValueSetsInIndexOrder()) {
            // A silently overflowing product would yield a meaningless (possibly negative) strati count.
            noStrati = Math.multiplyExact(noStrati, values.size());
        }
        LOG.info("{} strati are needed", noStrati);
        return noStrati;
    }

    /**
     * Determines the value sets of all selected attributes in {@link #dataset} and discretizes them where a
     * policy applies.
     *
     * <p>Falls back to the target attribute if no indices were configured. Computes default policies if none were
     * given.
     */
    private void computeAttributeValues() {
        LOG.info("computeAttributeValues(): enter");
        if (this.attributeIndices == null || this.attributeIndices.isEmpty()) {
            int targetIndex = this.dataset.getNumberOfAttributes();
            LOG.info("No attribute indices provided. Working with target attribute only (index: {})", targetIndex);
            this.attributeIndices = Collections.singletonList(targetIndex);
        }
        LOG.debug("Computing attribute values for attribute indices {}", this.attributeIndices);
        this.validateAttributeIndices();
        this.attributeValues = new HashMap<>();
        for (int attributeIndex : this.attributeIndices) {
            this.attributeValues.put(attributeIndex, new HashSet<>());
        }
        this.collectAttributeValuesInParallel();
        this.applyDiscretization();
        LOG.info("computeAttributeValues(): leave");
    }

    /**
     * Ensures every selected index lies in {@code [0, numberOfAttributes]}.
     *
     * <p>The upper bound itself is valid because it denotes the target attribute.
     */
    private void validateAttributeIndices() {
        int maxIndex = this.dataset.getNumberOfAttributes();
        for (int attributeIndex : this.attributeIndices) {
            if (attributeIndex < 0 || attributeIndex > maxIndex) {
                throw new IndexOutOfBoundsException(String.format("Attribute index %d is out of bounds for the delivered data set!", attributeIndex));
            }
        }
    }

    /**
     * Scans the dataset in at most {@link #numCPUs} partitions on a thread pool and merges the per-partition value
     * sets into {@link #attributeValues}.
     *
     * <p>Failed partitions are logged and skipped.
     */
    private void collectAttributeValuesInParallel() {
        ExecutorService threadPool = Executors.newFixedThreadPool(this.numCPUs);
        try {
            List<Future<Map<Integer, Set<Object>>>> futures = new ArrayList<>();
            LOG.info("Starting {} threads for computation..", this.numCPUs);
            int size = this.dataset.size();
            // Ceiling division -> at most numCPUs partitions. The lower bound of 1 is required because
            // Lists.partition rejects size 0 (empty dataset, or fewer instances than CPUs).
            int partitionSize = Math.max(1, size / this.numCPUs + (size % this.numCPUs == 0 ? 0 : 1));
            for (List<I> partition : Lists.partition(this.dataset, partitionSize)) {
                futures.add(threadPool.submit(new ListProcessor<>(partition, new HashSet<>(this.attributeIndices), this.dataset)));
            }
            for (Future<Map<Integer, Set<Object>>> future : futures) {
                try {
                    this.mergeLocalAttributeValues(future.get());
                } catch (ExecutionException e) {
                    LOG.error("Exception while waiting for future to complete..", e);
                } catch (InterruptedException e) {
                    LOG.error("Thread has been interrupted");
                    Thread.currentThread().interrupt();
                    // With the interrupt flag set, every further get() would fail immediately.
                    break;
                }
            }
        } finally {
            // Always release the worker threads, even if merging throws.
            threadPool.shutdown();
        }
    }

    /**
     * Adds one partition's values to {@link #attributeValues}.
     *
     * <p>Attributes missing from the partition result contribute nothing.
     */
    private void mergeLocalAttributeValues(Map<Integer, Set<Object>> localAttributeValues) {
        for (Map.Entry<Integer, Set<Object>> entry : this.attributeValues.entrySet()) {
            Set<Object> localValues = localAttributeValues.get(entry.getKey());
            if (localValues != null) {
                entry.getValue().addAll(localValues);
            }
        }
    }

    /**
     * Computes default discretization policies if none are set, then discretizes {@link #attributeValues} if any
     * policy exists.
     */
    private void applyDiscretization() {
        DiscretizationHelper<D> discretizationHelper = new DiscretizationHelper<D>();
        if (this.discretizationPolicies == null) {
            LOG.info("No discretization policies provided. Computing defaults..");
            this.discretizationPolicies = discretizationHelper.createDefaultDiscretizationPolicies(this.dataset, this.attributeIndices, this.attributeValues, this.discretizationStrategy, this.numberOfCategories);
        }
        if (!this.discretizationPolicies.isEmpty()) {
            LOG.info("Discretizing numeric attributes using policies: {}", this.discretizationPolicies);
            discretizationHelper.discretizeAttributeValues(this.discretizationPolicies, this.attributeValues);
        }
    }

    /**
     * Returns the value sets of the selected attributes in {@link #attributeIndices} order.
     *
     * <p>This is the order in which {@link #assignToStrati} builds its lookup keys. {@code HashMap#values()}
     * iterates in hash order, which differs e.g. for indices {@code [3, 1]} or indices of 16 and above.
     */
    private List<Set<Object>> attributeValueSetsInIndexOrder() {
        List<Set<Object>> sets = new ArrayList<>(this.attributeIndices.size());
        for (int attributeIndex : this.attributeIndices) {
            sets.add(this.attributeValues.get(attributeIndex));
        }
        return sets;
    }

    /**
     * Sets the number of threads used to scan the dataset.
     *
     * @param numberOfCPUs number of threads; must be at least 1
     * @throws IllegalArgumentException if {@code numberOfCPUs < 1}
     */
    @Override
    public void setNumCPUs(int numberOfCPUs) {
        if (numberOfCPUs < 1) {
            throw new IllegalArgumentException("Number of CPU cores must be positive");
        }
        this.numCPUs = numberOfCPUs;
    }

    /** @return the number of threads used to scan the dataset */
    @Override
    public int getNumCPUs() {
        return this.numCPUs;
    }

    /**
     * Initializes the assigner for {@code dataset}.
     *
     * <p>{@code stratiAmount} is ignored; the strati are derived from the attribute values.
     */
    @Override
    public void init(D dataset, int stratiAmount) {
        this.init(dataset);
    }

    /**
     * Initializes the assigner for {@code dataset} by numbering every combination of attribute values.
     *
     * <p>Attribute values are only recomputed if the dataset differs (by {@code equals}) from the remembered one or
     * none have been computed yet.
     *
     * @param dataset the dataset to stratify
     * @throws IndexOutOfBoundsException if a selected attribute index is outside the dataset
     */
    public void init(D dataset) {
        LOG.debug("init(): enter");
        if (this.dataset != null && this.dataset.equals(dataset) && this.attributeValues != null) {
            LOG.info("No recomputation of the attribute values needed");
        } else {
            this.dataset = dataset;
            this.computeAttributeValues();
        }
        Set<List<Object>> cartesianProducts = Sets.cartesianProduct(this.attributeValueSetsInIndexOrder());
        this.stratumAssignments = new MultiKeyMap<>();
        LOG.info("There are {} elements in the cartesian product of the attribute values", cartesianProducts.size());
        LOG.info("Assigning stratum numbers to elements in the cartesian product..");
        int stratumCounter = 0;
        for (List<Object> combination : cartesianProducts) {
            MultiKey<Object> multiKey = new MultiKey<>(combination.toArray());
            if (this.stratumAssignments.containsKey(multiKey)) {
                throw new IllegalStateException(String.format("Mulitkey %s occured twice!", multiKey.toString()));
            }
            this.stratumAssignments.put(multiKey, stratumCounter++);
        }
        LOG.debug("init(): leave");
    }

    /**
     * Returns the stratum of {@code datapoint}.
     *
     * <p>The stratum is determined by its (possibly discretized) values of the selected attributes.
     *
     * @param datapoint the instance to assign
     * @return the stratum number assigned during {@link #init}
     * @throws IllegalStateException if the assigner has not been initialized or the value combination is unknown
     */
    @Override
    public int assignToStrati(I datapoint) {
        if (this.stratumAssignments == null || this.stratumAssignments.isEmpty()) {
            throw new IllegalStateException("StratiAssigner has not been initialized!");
        }
        DiscretizationHelper<D> discretizationHelper = new DiscretizationHelper<D>();
        Object[] instanceAttributeValues = new Object[this.attributeIndices.size()];
        for (int i = 0; i < this.attributeIndices.size(); ++i) {
            int attributeIndex = this.attributeIndices.get(i);
            Object raw = this.rawValueOf(datapoint, attributeIndex);
            if (this.toBeDiscretized(attributeIndex)) {
                // NOTE(review): assumes attributes with a discretization policy hold Double values.
                instanceAttributeValues[i] = discretizationHelper.discretize((Double) raw, this.discretizationPolicies.get(attributeIndex));
            } else {
                instanceAttributeValues[i] = raw;
            }
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("Attribute values are: %s", Arrays.toString(instanceAttributeValues)));
        }
        // Values stored in the map are never null, so null means "no assignment for this combination".
        Integer stratum = this.stratumAssignments.get(new MultiKey<>(instanceAttributeValues));
        if (stratum == null) {
            throw new IllegalStateException(String.format("No assignment available for attribute combination %s", Arrays.toString(instanceAttributeValues)));
        }
        LOG.debug("Assigned stratum {}", stratum);
        return stratum;
    }

    /**
     * Reads the value of {@code datapoint} at {@code attributeIndex}.
     *
     * <p>The index {@code dataset.getNumberOfAttributes()} denotes the target value.
     */
    private Object rawValueOf(I datapoint, int attributeIndex) {
        if (attributeIndex == this.dataset.getNumberOfAttributes()) {
            return datapoint.getTargetValue();
        }
        return datapoint.getAttributeValueAtPosition(attributeIndex, Object.class).getValue();
    }

    /** @return whether a discretization policy exists for the attribute at {@code index} */
    private boolean toBeDiscretized(int index) {
        return this.discretizationPolicies.containsKey(index);
    }
}

