/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.ml.core.dataset.schema.DatasetPropertyComputer;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.AttributeDiscretizationPolicy;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.DiscretizationHelper;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.IStratiAmountSelector;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.IStratiAssigner;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.api4.java.ai.ml.core.dataset.IInstance;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Strati amount selector and strati assigner that derives the strata from the values of a
 * configurable set of attributes.
 *
 * <p>On initialization the distinct values of the selected attributes are computed (and discretized
 * for every attribute that has a discretization policy). Each element of the cartesian product of
 * these value sets is given its own stratum id. An attribute index that equals the number of
 * attributes of the dataset refers to the label and is therefore only accepted for labeled
 * datasets.
 */
public class AttributeBasedStratiAmountSelectorAndAssigner
        implements IStratiAmountSelector, IStratiAssigner, ILoggingCustomizable {

    /** Discretization strategy used when the caller does not choose one. */
    private static final DiscretizationHelper.DiscretizationStrategy DEFAULT_DISCRETIZATION_STRATEGY = DiscretizationHelper.DiscretizationStrategy.EQUAL_SIZE;

    /** Number of discretization categories used when the caller does not choose one. */
    private static final int DEFAULT_DISCRETIZATION_CATEGORY_AMOUNT = 5;

    /** Neither static nor final because {@link #setLoggerName(String)} replaces it. */
    private Logger logger = LoggerFactory.getLogger(AttributeBasedStratiAmountSelectorAndAssigner.class);

    private final DiscretizationHelper discretizationHelper = new DiscretizationHelper();

    /** Indices of the attributes that determine the stratum of an instance. */
    private List<Integer> attributeIndices;

    /** Maps a tuple of (possibly discretized) attribute values to its stratum id. */
    private Map<List<Object>, Integer> stratumIDs;

    /** Number of CPUs handed to the attribute value computation. */
    private int numCPUs = 1;

    /** Dataset this assigner has been initialized with. */
    private IDataset<?> dataset;

    /** Cached number of attributes of {@link #dataset}; also the index that denotes the label. */
    private int numAttributes;

    /** Discretization policies per attribute index; computed on demand if {@code null}. */
    private Map<Integer, AttributeDiscretizationPolicy> discretizationPolicies;

    /*
     * Both settings are initialized with their defaults here so that EVERY constructor leaves the
     * object in a state in which default policies can be computed, including the one that takes an
     * explicit (but possibly null) policy map.
     */
    private DiscretizationHelper.DiscretizationStrategy discretizationStrategy = DEFAULT_DISCRETIZATION_STRATEGY;
    private int numberOfCategories = DEFAULT_DISCRETIZATION_CATEGORY_AMOUNT;

    private boolean initialized;

    /**
     * Creates an assigner with default discretization settings and without attribute indices.
     *
     * <p>NOTE: this class offers no way to provide the attribute indices afterwards, so
     * {@link #init(IDataset, int)} fails with an {@link IllegalStateException} for instances
     * created this way.
     */
    public AttributeBasedStratiAmountSelectorAndAssigner() {
        // defaults are assigned by the field initializers
    }

    /**
     * Creates an assigner for the given attributes using the default discretization settings.
     *
     * @param attributeIndices indices of the attributes that define the strata
     * @throws IllegalArgumentException if {@code attributeIndices} is {@code null} or empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(List<Integer> attributeIndices) {
        this(attributeIndices, null);
    }

    /**
     * Creates an assigner for the given attributes with custom settings for the default
     * discretization policies.
     *
     * @param attributeIndices indices of the attributes that define the strata
     * @param discretizationStrategy strategy passed on when default policies are computed
     * @param numberOfCategories number of categories passed on when default policies are computed
     * @throws IllegalArgumentException if {@code attributeIndices} is {@code null} or empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(List<Integer> attributeIndices, DiscretizationHelper.DiscretizationStrategy discretizationStrategy, int numberOfCategories) {
        this(attributeIndices, null);
        this.discretizationStrategy = discretizationStrategy;
        this.numberOfCategories = numberOfCategories;
    }

    /**
     * Creates an assigner for the given attributes with explicit discretization policies.
     *
     * @param attributeIndices indices of the attributes that define the strata
     * @param discretizationPolicies policies per attribute index, or {@code null} to have default
     *        policies computed during initialization
     * @throws IllegalArgumentException if {@code attributeIndices} is {@code null} or empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(List<Integer> attributeIndices, Map<Integer, AttributeDiscretizationPolicy> discretizationPolicies) {
        if (attributeIndices == null || attributeIndices.isEmpty()) {
            throw new IllegalArgumentException("No attribute indices are provided!");
        }
        this.attributeIndices = attributeIndices;
        this.discretizationPolicies = discretizationPolicies;
        this.logger.info("Created assigner. Attributes to be discretized: {}", discretizationPolicies == null ? "none" : discretizationPolicies.keySet());
    }

    /**
     * Returns the number of strati for the given dataset, initializing this object with the dataset
     * if no dataset has been set yet.
     *
     * @param dataset the dataset to be stratified
     * @return the number of strati
     * @throws IllegalArgumentException if a dataset has been set before and it is not equal to
     *         {@code dataset}
     */
    @Override
    public int selectStratiAmount(IDataset<?> dataset) {
        this.logger.debug("Selecting number of strati for dataset with {} items.", dataset.size());
        if (this.dataset == null) {
            this.init(dataset, -1);
        } else if (!this.dataset.equals(dataset)) {
            throw new IllegalArgumentException("Can only select strati amount for a dataset provided before.");
        }
        return this.stratumIDs.size();
    }

    /**
     * Discretizes the given attribute values in place, computing default policies first if none
     * have been provided.
     *
     * @param attributeValues values per attribute index; handed to the discretization helper
     */
    private void discretizeAttributeValues(Map<Integer, Set<Object>> attributeValues) {
        if (this.discretizationPolicies == null) {
            this.logger.info("No discretization policies provided. Computing defaults.");
            this.discretizationPolicies = this.discretizationHelper.createDefaultDiscretizationPolicies(this.dataset, this.attributeIndices, attributeValues, this.discretizationStrategy, this.numberOfCategories);
        }
        if (!this.discretizationPolicies.isEmpty()) {
            this.logger.info("Discretizing numeric attributes using policies: {}", this.discretizationPolicies);
            this.discretizationHelper.discretizeAttributeValues(this.discretizationPolicies, attributeValues);
        }
        this.logger.info("discretizeAttributeValues(): leave");
    }

    /**
     * Sets the number of CPUs used for computing the attribute values.
     *
     * @param numberOfCPUs number of CPUs, at least 1
     * @throws IllegalArgumentException if {@code numberOfCPUs} is smaller than 1
     */
    public void setNumCPUs(int numberOfCPUs) {
        if (numberOfCPUs < 1) {
            throw new IllegalArgumentException("Number of CPU cores must be positive");
        }
        this.numCPUs = numberOfCPUs;
    }

    /**
     * @return the number of CPUs used for computing the attribute values
     */
    public int getNumCPUs() {
        return this.numCPUs;
    }

    /**
     * Initializes this assigner for the given dataset; see {@link #init(IDataset, int)}.
     *
     * @param dataset the dataset to be stratified
     */
    public void init(IDataset<?> dataset) {
        this.init(dataset, -1);
    }

    /**
     * Initializes this assigner for the given dataset by enumerating all strati. Calls after a
     * successful initialization are ignored (a warning is logged).
     *
     * @param dataset the dataset to be stratified
     * @param stratiAmount not used by this implementation; the number of strati follows from the
     *        attribute values
     * @throws IllegalArgumentException if {@code dataset} is {@code null} or an attribute index is
     *         invalid for the dataset
     * @throws IllegalStateException if no attribute indices have been configured
     */
    @Override
    public void init(IDataset<?> dataset, int stratiAmount) {
        this.logger.debug("init(): enter");
        if (this.initialized) {
            this.logger.warn("Ignoring further initialization.");
            return;
        }
        if (dataset == null) {
            throw new IllegalArgumentException("Cannot set dataset to NULL");
        }
        if (this.attributeIndices == null) {
            throw new IllegalStateException("No attribute indices have been configured. Use a constructor that takes the attribute indices.");
        }

        /* validate everything BEFORE touching any field, so that a rejected dataset does not leave
         * this object half-initialized (dataset set, but no strati) */
        int n = dataset.getNumAttributes();
        this.checkAttributeIndices(dataset, n);
        this.dataset = dataset;
        this.numAttributes = n;

        Map<Integer, Set<Object>> attributeValues = DatasetPropertyComputer.computeAttributeValues(dataset, this.attributeIndices, this.numCPUs);
        this.discretizeAttributeValues(attributeValues);

        /* NOTE(review): assignToStrati builds its lookup tuples in the order of attributeIndices;
         * this presumes that attributeValues iterates in that same order - confirm against
         * DatasetPropertyComputer. */
        List<Set<Object>> sets = new ArrayList<>(attributeValues.values());
        Set<List<Object>> cartesianProduct = Sets.cartesianProduct(sets);
        this.logger.info("There are {} elements in the cartesian product of the attribute values", cartesianProduct.size());
        this.logger.info("Assigning stratum numbers to elements in the cartesian product..");
        this.stratumIDs = new HashMap<>();
        int stratumCounter = 0;
        for (List<Object> tuple : cartesianProduct) {
            this.stratumIDs.put(tuple, stratumCounter++);
        }
        this.logger.info("Initialized strati assigner with {} strati.", this.stratumIDs.size());
        this.initialized = true;
    }

    /**
     * Checks that every configured attribute index is valid for the given dataset.
     *
     * @param dataset the dataset the indices refer to
     * @param numberOfAttributes number of attributes of {@code dataset}
     * @throws IllegalArgumentException if an index is negative, exceeds the number of attributes,
     *         or equals the number of attributes although the dataset is not labeled
     */
    private void checkAttributeIndices(IDataset<?> dataset, int numberOfAttributes) {
        boolean labeled = dataset instanceof ILabeledDataset;
        for (int index : this.attributeIndices) {
            if (index < 0) {
                throw new IllegalArgumentException("Attribute index for stratified splits must not be negative!");
            }
            if (index > numberOfAttributes) {
                throw new IllegalArgumentException("Attribute index for stratified splits must not exceed number of attributes!");
            }
            if (index == numberOfAttributes && !labeled) {
                throw new IllegalArgumentException("Attribute index for stratified splits must only equal the number of attributes if the dataset is labeled, because then the label column id is the number of attributes!");
            }
        }
    }

    /**
     * Determines the stratum of the given instance from its values for the configured attributes.
     *
     * @param datapoint the instance to assign; must be an {@link ILabeledInstance} if the label is
     *        among the configured attribute indices
     * @return the id of the stratum the instance belongs to
     * @throws IllegalStateException if this assigner has not been initialized
     * @throws IllegalArgumentException if the label is required but missing, or if no stratum is
     *         known for the attribute values of the instance
     * @throws NullPointerException if a required attribute value is {@code null}
     */
    @Override
    public int assignToStrati(IInstance datapoint) {
        if (!this.initialized) {
            throw new IllegalStateException("Assigner has not been initialized yet.");
        }
        List<Object> instanceAttributeValues = new ArrayList<>(this.attributeIndices.size());
        for (int attributeIndex : this.attributeIndices) {
            /* the index directly behind the last attribute denotes the label */
            boolean isLabel = attributeIndex == this.numAttributes;
            Object raw = isLabel ? ((ILabeledInstance) datapoint).getLabel() : datapoint.getAttributeValue(attributeIndex);
            Object value;
            if (this.toBeDiscretized(attributeIndex)) {
                value = this.discretizationHelper.discretize((Double) raw, this.discretizationPolicies.get(attributeIndex));
                Objects.requireNonNull(value, "Discretization returned null for attribute index " + attributeIndex);
            } else if (isLabel) {
                if (raw == null) {
                    throw new IllegalArgumentException("Cannot assign data point " + datapoint + " to any stratum, because it has no label.");
                }
                value = raw;
            } else {
                value = Objects.requireNonNull(raw, "Data point has no value for attribute index " + attributeIndex);
            }
            instanceAttributeValues.add(value);
        }

        /* look up as Integer: unboxing a missing entry directly would fail with a bare NPE */
        Integer stratum = this.stratumIDs.get(instanceAttributeValues);
        if (stratum == null) {
            throw new IllegalArgumentException("Cannot assign data point to any stratum, because no stratum is known for its attribute values " + instanceAttributeValues + ".");
        }
        this.logger.debug("Attribute values are: {}. Corresponding stratum is: {}", instanceAttributeValues, stratum);
        return stratum;
    }

    /**
     * @param index an attribute index
     * @return whether a discretization policy exists for the given attribute index
     */
    private boolean toBeDiscretized(int index) {
        return this.discretizationPolicies.containsKey(index);
    }

    /**
     * @return the name of the logger currently in use
     */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Switches to the logger with the given name and gives the discretization helper the logger
     * name {@code name + ".discretizer"}.
     *
     * @param name the new logger name
     */
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
        this.discretizationHelper.setLoggerName(name + ".discretizer");
    }
}

