/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.dataset;

import ai.libs.jaicore.basic.reconstruction.ReconstructionInstruction;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.core.dataset.AInstance;
import ai.libs.jaicore.ml.core.dataset.Dataset;
import ai.libs.jaicore.ml.core.dataset.DenseInstance;
import ai.libs.jaicore.ml.core.dataset.SparseInstance;
import ai.libs.jaicore.ml.core.dataset.schema.LabeledInstanceSchema;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.IntBasedCategoricalAttribute;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.NumericAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.api4.java.ai.ml.core.dataset.schema.ILabeledInstanceSchema;
import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ICategoricalAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.INumericAttribute;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.common.reconstruction.IReconstructible;
import org.api4.java.common.reconstruction.IReconstructionInstruction;

public class DatasetUtil {
    /** Expansion code that requests, for numeric attributes, an extra attribute holding the squared value (named with the suffix "_2"). */
    public static final int EXPANSION_SQUARES = 1;
    /** Expansion code that requests, for numeric attributes, an extra attribute holding the natural logarithm (named with the suffix "_log"). */
    public static final int EXPANSION_LOGARITHM = 2;
    /** Expansion code that requests extra attributes holding the product of every group of two or three attributes (named "x_" followed by the attribute names). */
    public static final int EXPANSION_PRODUCTS = 3;

    /**
     * Private constructor: this class only offers static helpers and is not meant to be instantiated.
     */
    private DatasetUtil() {
    }

    /**
     * Counts how often each label occurs in the given dataset.
     *
     * @param ds the labeled dataset whose instances are inspected
     * @return a newly created map from every observed label to its number of occurrences
     */
    public static Map<Object, Integer> getLabelCounts(ILabeledDataset<?> ds) {
        Map<Object, Integer> occurrences = new HashMap<>();
        for (ILabeledInstance instance : ds) {
            // Start at 1 for an unseen label, otherwise add 1 to the stored count.
            occurrences.merge(instance.getLabel(), 1, Integer::sum);
        }
        return occurrences;
    }

    /**
     * Computes the summed absolute difference of the per-label instance counts of two datasets.
     *
     * <p>A label that occurs in only one of the two datasets is treated as occurring zero times
     * in the other one, so its full count contributes to the difference.
     *
     * @param d1 the first labeled dataset
     * @param d2 the second labeled dataset
     * @return the sum over all labels of {@code |count in d1 - count in d2|}
     */
    public static int getLabelCountDifference(ILabeledDataset<?> d1, ILabeledDataset<?> d2) {
        Map<Object, Integer> c1 = DatasetUtil.getLabelCounts(d1);
        Map<Object, Integer> c2 = DatasetUtil.getLabelCounts(d2);

        // Union of the labels of both datasets; each label is visited exactly once.
        Set<Object> labels = new HashSet<>(c1.keySet());
        labels.addAll(c2.keySet());

        int diff = 0;
        for (Object label : labels) {
            // getOrDefault avoids unboxing null for labels that are missing in one of the datasets.
            diff += Math.abs(c1.getOrDefault(label, 0) - c2.getOrDefault(label, 0));
        }
        return diff;
    }

    /**
     * Builds a copy of the dataset whose label attribute is replaced by {@code attr} and whose
     * labels are translated through {@code conversionMap}.
     *
     * @param dataset the dataset to copy; only {@link DenseInstance} and {@link SparseInstance} elements are supported
     * @param attr the label attribute of the new dataset
     * @param conversionMap maps each original label to the label used in the new dataset
     * @param utilMethodName name of the public method of this class (taking an {@link ILabeledDataset}) that is
     *        recorded as reconstruction instruction when the source dataset is reconstructible
     * @return the converted dataset
     * @throws UnsupportedOperationException if an instance is neither dense nor sparse, or if the util method cannot be looked up
     * @throws IllegalStateException if a converted label is not a valid value of the new label attribute
     */
    private static ILabeledDataset<?> convertTargetOfDataset(ILabeledDataset<?> dataset, IAttribute attr, Map<Object, ? extends Object> conversionMap, String utilMethodName) {
        List<IAttribute> featureAttributes = new ArrayList<>(dataset.getInstanceSchema().getAttributeList());
        Dataset converted = new Dataset(new LabeledInstanceSchema(dataset.getRelationName(), featureAttributes, attr));
        int attributeCount = dataset.getNumAttributes();

        for (ILabeledInstance source : dataset) {
            Object newLabel = conversionMap.get(source.getLabel());
            AInstance copy;
            if (source instanceof DenseInstance) {
                copy = new DenseInstance(source.getAttributes(), newLabel);
            } else if (source instanceof SparseInstance) {
                copy = new SparseInstance(attributeCount, ((SparseInstance) source).getAttributeMap(), newLabel);
            } else {
                throw new UnsupportedOperationException();
            }

            // Reject labels the new label attribute does not accept before they enter the dataset.
            if (!converted.getLabelAttribute().isValidValue(copy.getLabel())) {
                throw new IllegalStateException("Value " + copy.getLabel() + " is not a valid label value for label attribute " + converted.getLabelAttribute());
            }
            converted.add(copy);
        }

        if (!(dataset instanceof IReconstructible)) {
            return converted;
        }

        // Carry over the construction history and append the conversion step itself.
        ((IReconstructible) dataset).getConstructionPlan().getInstructions().forEach(converted::addInstruction);
        try {
            IReconstructionInstruction conversionStep = new ReconstructionInstruction(DatasetUtil.class.getMethod(utilMethodName, ILabeledDataset.class), new Object[]{"this"});
            converted.addInstruction(conversionStep);
        } catch (NoSuchMethodException | SecurityException e) {
            throw new UnsupportedOperationException(e);
        }
        return converted;
    }

    /**
     * Turns the dataset into one with a categorical label attribute.
     *
     * <p>If the label attribute is already categorical, the dataset itself is returned. Otherwise
     * the string forms of all occurring labels become the categories of a new
     * {@link IntBasedCategoricalAttribute}, and each label is mapped to the value that attribute
     * deserializes from the label's string form.
     *
     * @param dataset the dataset to convert
     * @return the given dataset if its label is already categorical, otherwise a converted copy
     */
    public static ILabeledDataset<?> convertToClassificationDataset(ILabeledDataset<?> dataset) {
        IAttribute originalLabel = dataset.getLabelAttribute();
        if (originalLabel instanceof ICategoricalAttribute) {
            return dataset;
        }

        Set<String> categories = dataset.stream().map(instance -> instance.getLabel().toString()).collect(Collectors.toSet());
        List<String> domain = new ArrayList<>(categories);
        IntBasedCategoricalAttribute attr = new IntBasedCategoricalAttribute(originalLabel.getName(), domain);

        Map<Object, Integer> labelToCategory = new HashMap<>();
        for (ILabeledInstance instance : dataset) {
            Object label = instance.getLabel();
            // Each distinct label only needs to be translated once.
            if (!labelToCategory.containsKey(label)) {
                labelToCategory.put(label, attr.deserializeAttributeValue(label.toString()));
            }
        }
        return DatasetUtil.convertTargetOfDataset(dataset, attr, labelToCategory, "convertToClassificationDataset");
    }

    /**
     * Turns the dataset into one with a numeric label attribute.
     *
     * <p>If the label attribute is already numeric, the dataset itself is returned. Otherwise the
     * label attribute is cast to {@link ICategoricalAttribute} and every category is assigned a
     * number: its parsed value (integer first, then double) if all categories are parsable,
     * otherwise its position in the list of categories.
     *
     * @param dataset the dataset to convert
     * @return the given dataset if its label is already numeric, otherwise a converted copy
     */
    public static ILabeledDataset<?> convertToRegressionDataset(ILabeledDataset<?> dataset) {
        IAttribute originalLabel = dataset.getLabelAttribute();
        if (originalLabel instanceof INumericAttribute) {
            return dataset;
        }

        NumericAttribute attr = new NumericAttribute(originalLabel.getName());
        Map<Object, Double> numericValueOfLabel = new HashMap<>();
        ICategoricalAttribute categorical = (ICategoricalAttribute) dataset.getLabelAttribute();
        try {
            for (String label : categorical.getLabels()) {
                Object key = categorical.deserializeAttributeValue(label);
                double value;
                try {
                    value = Integer.valueOf(label).doubleValue();
                } catch (NumberFormatException notAnInteger) {
                    // Not an integer literal; a failure here is handled by the outer catch.
                    value = Double.parseDouble(label);
                }
                numericValueOfLabel.put(key, value);
            }
        } catch (NumberFormatException notNumeric) {
            // At least one category is not a number: discard partial results and use list positions instead.
            numericValueOfLabel.clear();
            List<String> labels = categorical.getLabels();
            for (String label : labels) {
                numericValueOfLabel.put(categorical.deserializeAttributeValue(label), Double.valueOf(labels.indexOf(label)));
            }
        }
        return DatasetUtil.convertTargetOfDataset(dataset, attr, numericValueOfLabel, "convertToRegressionDataset");
    }

    /**
     * Builds a dataset from a collection of rows, using the alphabetically sorted keys of the
     * first row as attribute order.
     *
     * @param datasetAsListOfMaps the rows, each mapping attribute names to values; must contain at least one row
     * @param nameOfLabelAttribute the key that holds the label in every row
     * @return the dataset built by the three-argument overload
     */
    public static ILabeledDataset<?> getDatasetFromMapCollection(Collection<Map<String, Object>> datasetAsListOfMaps, String nameOfLabelAttribute) {
        Map<String, Object> firstRow = datasetAsListOfMaps.iterator().next();
        List<String> sortedKeys = new ArrayList<>(firstRow.keySet());
        sortedKeys.sort(String::compareTo);
        return DatasetUtil.getDatasetFromMapCollection(datasetAsListOfMaps, nameOfLabelAttribute, sortedKeys);
    }

    /**
     * Builds a dataset from a collection of rows with an explicitly given attribute order.
     *
     * <p>The attribute types are derived from the values of the first row: {@link Number} values
     * yield a {@link NumericAttribute}, {@link Boolean} values yield an
     * {@link IntBasedCategoricalAttribute} with the categories "false" and "true". The label
     * attribute is always created as a {@link NumericAttribute}.
     *
     * @param datasetAsListOfMaps the rows, each mapping attribute names to values
     * @param nameOfLabelAttribute the key that holds the label in every row
     * @param orderOfAttributes the keys of the rows in the order in which the attributes shall occur
     * @return a dataset with relation name "rel" containing one instance per row
     * @throws IllegalStateException if the key set of a row differs from the keys in {@code orderOfAttributes}
     * @throws UnsupportedOperationException if the first row holds a value that is neither a number nor a boolean
     *         for a non-label attribute
     */
    public static ILabeledDataset<?> getDatasetFromMapCollection(Collection<Map<String, Object>> datasetAsListOfMaps, String nameOfLabelAttribute, List<String> orderOfAttributes) {
        // All rows must describe exactly the announced attributes.
        Set<String> keys = new HashSet<>(orderOfAttributes);
        for (Map<String, Object> dataPoint : datasetAsListOfMaps) {
            if (!keys.equals(dataPoint.keySet())) {
                throw new IllegalStateException("Data point has keys " + dataPoint.keySet() + " but the expected keys are " + keys + ".");
            }
        }

        // Derive the type of every non-label attribute from the first row.
        List<IAttribute> attributeList = new ArrayList<>();
        for (String key : orderOfAttributes) {
            if (key.equals(nameOfLabelAttribute)) {
                continue;
            }
            Object val = datasetAsListOfMaps.iterator().next().get(key);
            if (val instanceof Number) {
                attributeList.add(new NumericAttribute(key));
            } else if (val instanceof Boolean) {
                attributeList.add(new IntBasedCategoricalAttribute(key, Arrays.asList("false", "true")));
            } else {
                String typeName = val == null ? "null" : val.getClass().getName();
                throw new UnsupportedOperationException("Cannot derive an attribute type for key " + key + " from value " + val + " of type " + typeName + "; only numbers and booleans are supported.");
            }
        }

        LabeledInstanceSchema schema = new LabeledInstanceSchema("rel", attributeList, new NumericAttribute(nameOfLabelAttribute));
        Dataset metaDataset = new Dataset(schema);
        for (Map<String, Object> row : datasetAsListOfMaps) {
            ILabeledInstance inst = DatasetUtil.getInstanceFromMap(schema, row, nameOfLabelAttribute);
            metaDataset.add(inst);
        }
        return metaDataset;
    }

    /**
     * Creates an instance from a row without any functions for computing missing attribute values.
     *
     * @param schema the schema whose attributes are looked up in the row
     * @param row maps attribute names to values
     * @param nameOfLabelAttribute the key that holds the label in the row
     * @return the instance built by the four-argument overload with an empty computer map
     */
    public static ILabeledInstance getInstanceFromMap(ILabeledInstanceSchema schema, Map<String, Object> row, String nameOfLabelAttribute) {
        Map<IAttribute, Function<ILabeledInstance, Double>> noComputers = new HashMap<>();
        return DatasetUtil.getInstanceFromMap(schema, row, nameOfLabelAttribute, noComputers);
    }

    /**
     * Creates a dense instance from a row, computing the values of attributes the row does not contain.
     *
     * <p>Attributes of the schema that are present in the row are copied. Attributes that are
     * absent are first set to {@code null} and afterwards filled, in schema order, with the result
     * of the function registered for them in {@code attributeValueComputer}, applied to the
     * instance built so far.
     *
     * @param schema the schema whose attributes are looked up in the row
     * @param row maps attribute names to values
     * @param nameOfLabelAttribute the key that holds the label in the row
     * @param attributeValueComputer functions that derive the value of an attribute from the instance;
     *        needs an entry for every schema attribute that is missing in the row
     * @return the created instance
     * @throws IllegalStateException if the created instance has a different number of attributes than the schema
     * @throws IllegalArgumentException if an attribute is missing in the row and no function to compute it is given
     */
    public static ILabeledInstance getInstanceFromMap(ILabeledInstanceSchema schema, Map<String, Object> row, String nameOfLabelAttribute, Map<IAttribute, Function<ILabeledInstance, Double>> attributeValueComputer) {
        List<Object> attributes = new ArrayList<>(schema.getNumAttributes());
        List<Integer> attributeToRecover = new ArrayList<>();
        int i = 0;
        for (IAttribute att : schema.getAttributeList()) {
            if (row.containsKey(att.getName())) {
                attributes.add(row.get(att.getName()));
            } else {
                // Placeholder; the value is computed once the instance exists.
                attributeToRecover.add(i);
                attributes.add(null);
            }
            ++i;
        }

        DenseInstance inst = new DenseInstance(attributes, row.get(nameOfLabelAttribute));
        if (inst.getNumAttributes() != schema.getNumAttributes()) {
            throw new IllegalStateException("Created dense instance with " + inst.getNumAttributes() + " attributes where the scheme requires " + schema.getNumAttributes());
        }

        for (int attIndex : attributeToRecover) {
            IAttribute att = schema.getAttribute(attIndex);
            Function<ILabeledInstance, Double> computer = attributeValueComputer == null ? null : attributeValueComputer.get(att);
            if (computer == null) {
                throw new IllegalArgumentException("The row contains no value for attribute " + att.getName() + " and no function to compute it was provided.");
            }
            inst.setAttributeValue(attIndex, computer.apply(inst));
        }
        return inst;
    }

    /**
     * Describes a feature expansion of the dataset: the attributes to add and, for each of them,
     * the function that computes its value from an instance of the original dataset.
     *
     * <ul>
     * <li>{@code EXPANSION_SQUARES} (1): for every numeric attribute {@code a}, an attribute {@code a_2} holding the squared value.</li>
     * <li>{@code EXPANSION_LOGARITHM} (2): for every numeric attribute {@code a}, an attribute {@code a_log} holding {@link Math#log(double)} of the value.</li>
     * <li>{@code EXPANSION_PRODUCTS} (3): for every group of two or three attributes, an attribute named {@code x_} followed by
     * the alphabetically sorted attribute names (joined by {@code _}) holding the product of their values.</li>
     * </ul>
     * The requested expansions are independent of each other; requesting several yields the attributes of all of them.
     * Attribute values are read through their string form, so any value whose {@code toString()} is a parsable number is accepted.
     *
     * @param dataset the dataset whose attributes are expanded
     * @param expansions the expansion codes to apply
     * @return a pair of the new attributes (in creation order) and the map from each new attribute to its value function
     * @throws UnsupportedOperationException if an unknown expansion code is given
     * @throws IllegalStateException if a product attribute collides with an existing or already created attribute
     * @throws InterruptedException declared because {@code SetUtil.powerset} is invoked for the product expansion
     */
    public static Pair<List<IAttribute>, Map<IAttribute, Function<ILabeledInstance, Double>>> getPairOfNewAttributesAndExpansionMap(ILabeledDataset<?> dataset, int ... expansions) throws InterruptedException {
        List<IAttribute> attributeList = dataset.getInstanceSchema().getAttributeList();
        List<IAttribute> newAttributes = new ArrayList<>();

        boolean computeSquares = false;
        boolean computeProducts = false;
        boolean computeLogs = false;
        for (int expansion : expansions) {
            switch (expansion) {
                case 1: // EXPANSION_SQUARES
                    computeSquares = true;
                    break;
                case 2: // EXPANSION_LOGARITHM
                    computeLogs = true;
                    break;
                case 3: // EXPANSION_PRODUCTS
                    computeProducts = true;
                    break;
                default:
                    throw new UnsupportedOperationException("Unknown expansion " + expansion);
            }
        }

        Map<IAttribute, Function<ILabeledInstance, Double>> transformations = new HashMap<>();
        for (int attId = 0; attId < dataset.getNumAttributes(); ++attId) {
            final int attIdFinal = attId;
            IAttribute att = dataset.getAttribute(attId);
            if (!(att instanceof INumericAttribute)) {
                continue;
            }
            // Squares and logarithms are handled independently so that requesting both yields both.
            if (computeSquares) {
                NumericAttribute squared = new NumericAttribute(att.getName() + "_2");
                newAttributes.add(squared);
                transformations.put(squared, i -> Math.pow(Double.parseDouble(i.getAttributeValue(attIdFinal).toString()), 2.0));
            }
            if (computeLogs) {
                NumericAttribute logarithm = new NumericAttribute(att.getName() + "_log");
                newAttributes.add(logarithm);
                // Parse the string form (as for squares) instead of casting, so non-Double numbers are accepted too.
                transformations.put(logarithm, i -> Math.log(Double.parseDouble(i.getAttributeValue(attIdFinal).toString())));
            }
        }

        if (computeProducts) {
            // NOTE(review): every subset returned by SetUtil.powerset is visited and all but those of size 2 or 3
            // are discarded; presumably this is exponential in the number of attributes - confirm before using on wide datasets.
            for (Collection<IAttribute> subset : SetUtil.powerset(attributeList)) {
                if (subset.size() > 3 || subset.size() < 2) {
                    continue;
                }
                List<IAttribute> sortedFeatures = subset.stream().sorted((a1, a2) -> a1.getName().compareTo(a2.getName())).collect(Collectors.toList());
                StringBuilder featureName = new StringBuilder("x");
                List<Integer> indices = new ArrayList<>();
                for (IAttribute feature : sortedFeatures) {
                    featureName.append('_').append(feature.getName());
                    indices.add(attributeList.indexOf(feature));
                }
                NumericAttribute product = new NumericAttribute(featureName.toString());
                if (attributeList.contains(product)) {
                    throw new IllegalStateException("Dataset already has attribute " + product.getName());
                }
                if (newAttributes.contains(product)) {
                    throw new IllegalStateException("Already added attribute " + product.getName());
                }
                newAttributes.add(product);
                transformations.put(product, i -> {
                    double val = 1.0;
                    for (int index : indices) {
                        val *= Double.parseDouble(i.getAttributeValue(index).toString());
                    }
                    return val;
                });
            }
        }
        return new Pair<>(newAttributes, transformations);
    }

    /**
     * Expands the dataset by the attributes that belong to the given expansion codes.
     *
     * @param dataset the dataset to expand
     * @param expansions the expansion codes, as accepted by {@code getPairOfNewAttributesAndExpansionMap}
     * @return the expanded dataset
     * @throws InterruptedException propagated from the computation of the expansion description
     */
    public static ILabeledDataset<?> getExpansionOfDataset(ILabeledDataset<?> dataset, int ... expansions) throws InterruptedException {
        Pair<List<IAttribute>, Map<IAttribute, Function<ILabeledInstance, Double>>> expansionDescription = DatasetUtil.getPairOfNewAttributesAndExpansionMap(dataset, expansions);
        return DatasetUtil.getExpansionOfDataset(dataset, expansionDescription);
    }

    /**
     * Creates a new dataset that holds, for every instance of the given dataset, its expansion
     * according to the expansion description.
     *
     * @param dataset the dataset to expand
     * @param expansionDescription pair of the attributes to append and the functions computing their values
     * @return a dataset named like the original relation plus "_expansion" whose schema consists of the
     *         original attributes followed by the new ones, with the original label attribute
     */
    public static ILabeledDataset<?> getExpansionOfDataset(ILabeledDataset<?> dataset, Pair<List<IAttribute>, Map<IAttribute, Function<ILabeledInstance, Double>>> expansionDescription) {
        List<IAttribute> allAttributes = new ArrayList<>(dataset.getInstanceSchema().getAttributeList());
        allAttributes.addAll(expansionDescription.getX());
        String relationName = dataset.getRelationName() + "_expansion";
        Dataset expanded = new Dataset(new LabeledInstanceSchema(relationName, allAttributes, dataset.getLabelAttribute()));
        dataset.forEach(instance -> expanded.add(DatasetUtil.getExpansionOfInstance(instance, expansionDescription)));
        return expanded;
    }

    /**
     * Creates a dense instance holding the attribute values of the given instance followed by
     * the values of the expansion attributes.
     *
     * @param i the instance to expand
     * @param expansionDescription pair of the attributes to append and the functions computing their values from {@code i}
     * @return a new dense instance with the extended attribute values and the label of {@code i}
     */
    public static ILabeledInstance getExpansionOfInstance(ILabeledInstance i, Pair<List<IAttribute>, Map<IAttribute, Function<ILabeledInstance, Double>>> expansionDescription) {
        List<Object> values = new ArrayList<>(Arrays.asList(i.getAttributes()));
        List<IAttribute> addedAttributes = expansionDescription.getX();
        Map<IAttribute, Function<ILabeledInstance, Double>> computers = expansionDescription.getY();
        for (IAttribute added : addedAttributes) {
            // Every new value is derived from the original instance, not from the values added so far.
            values.add(computers.get(added).apply(i));
        }
        return new DenseInstance(values, i.getLabel());
    }
}

