/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.dataset;

import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.core.dataset.IOrderedLabeledDataset;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A dataset of dyad rankings: an {@link ArrayList} whose elements are {@link IDyadRankingInstance}s.
 *
 * <p>Besides the list operations inherited from {@link ArrayList}, this class offers a simple line-based
 * text (de)serialization ({@link #serialize(OutputStream)} / {@link #deserialize(InputStream)}) and a
 * conversion of every ranking into an ND4J matrix ({@link #toND4j()}).
 */
public class DyadRankingDataset
extends ArrayList<IDyadRankingInstance>
implements IOrderedLabeledDataset<IDyadRankingInstance, IDyadRankingInstance> {
    // Static rather than a transient instance field: this class is Serializable (via ArrayList), and a
    // transient logger would be null after Java deserialization, turning the catch blocks below into NPEs.
    private static final Logger LOGGER = LoggerFactory.getLogger(DyadRankingDataset.class);
    private static final long serialVersionUID = -1102494546233523992L;

    /** Creates an empty dataset. */
    public DyadRankingDataset() {
    }

    /**
     * Creates a dataset containing the rankings of the given collection, in its iteration order.
     *
     * @param c the rankings to copy into this dataset
     */
    public DyadRankingDataset(Collection<IDyadRankingInstance> c) {
        super(c);
    }

    /**
     * Creates an empty dataset with the given initial capacity.
     *
     * @param initialCapacity the initial capacity of the backing list
     */
    public DyadRankingDataset(int initialCapacity) {
        super(initialCapacity);
    }

    /**
     * Creates a dataset containing the given rankings, in list order.
     *
     * @param dyadRankingInstances the rankings to copy into this dataset
     */
    public DyadRankingDataset(List<IDyadRankingInstance> dyadRankingInstances) {
        super(dyadRankingInstances);
    }

    /**
     * Writes this dataset to the given stream in a line-based text format.
     *
     * <p>Each ranking becomes one line terminated by {@code '\n'}. Each dyad of the ranking is written, in
     * ranking order, as {@code instance.toString() + ";" + alternative.toString() + "|"}. The text is encoded
     * as UTF-8, matching the decoding used by {@link #deserialize(InputStream)}. The stream is neither flushed
     * nor closed. An {@link IOException} aborts writing and is logged, not rethrown.
     *
     * @param out the stream to write to
     */
    public void serialize(OutputStream out) {
        try {
            for (IDyadRankingInstance instance : this) {
                StringBuilder line = new StringBuilder();
                for (Dyad dyad : instance) {
                    line.append(dyad.getInstance().toString()).append(';');
                    line.append(dyad.getAlternative().toString()).append('|');
                }
                line.append('\n');
                // Explicit UTF-8: the platform default charset used previously need not match the UTF-8
                // decoding in deserialize().
                out.write(line.toString().getBytes(StandardCharsets.UTF_8));
            }
        }
        catch (IOException e) {
            LOGGER.warn(e.getMessage(), e);
        }
    }

    /**
     * Replaces the contents of this dataset with the rankings read from the given stream.
     *
     * <p>The stream is decoded as UTF-8 and read line by line in the format written by
     * {@link #serialize(OutputStream)}. Reading stops at the end of the stream or at the first empty line.
     * Within a line, dyad tokens are separated by {@code '|'}. Each token holds an instance part and an
     * alternative part separated by {@code ';'}. Tokens without a {@code ';'}, or with a part of at most one
     * character, are skipped. The stream is not closed. An {@link IOException} is logged, not rethrown.
     *
     * @param in the stream to read from
     * @throws NumberFormatException if a vector component cannot be parsed as a {@code double}
     */
    public void deserialize(InputStream in) {
        this.clear();
        try {
            LineIterator lines = IOUtils.lineIterator(in, StandardCharsets.UTF_8);
            while (lines.hasNext()) {
                String row = lines.next();
                if (row.isEmpty()) {
                    // An empty line terminates the dataset.
                    break;
                }
                this.add(new DyadRankingInstance(parseRanking(row)));
            }
        }
        catch (IOException e) {
            LOGGER.warn(e.getMessage(), e);
        }
    }

    /**
     * Parses one serialized line into its dyads, in order, skipping malformed dyad tokens.
     *
     * @param row a non-empty line as written by {@link #serialize(OutputStream)}
     * @return the parsed dyads (possibly empty)
     */
    private static List<Dyad> parseRanking(String row) {
        List<Dyad> dyads = new ArrayList<>();
        for (String dyadToken : row.split("\\|")) {
            String[] parts = dyadToken.split(";");
            // parts.length check: a token without ';' previously threw ArrayIndexOutOfBoundsException.
            if (parts.length < 2 || parts[0].length() <= 1 || parts[1].length() <= 1) {
                continue;
            }
            dyads.add(new Dyad(parseVector(parts[0]), parseVector(parts[1])));
        }
        return dyads;
    }

    /**
     * Parses a vector token: the first and last characters are dropped unchecked (presumably the brackets
     * produced by {@code Vector#toString()} -- confirm), and the rest is split on {@code ','} into doubles.
     *
     * @param token the serialized vector, at least two characters long
     * @return a dense vector holding the parsed components
     * @throws NumberFormatException if a component is not a valid {@code double}
     */
    private static Vector parseVector(String token) {
        String[] components = token.substring(1, token.length() - 1).split(",");
        DenseDoubleVector vector = new DenseDoubleVector(components.length);
        for (int i = 0; i < components.length; ++i) {
            vector.setValue(i, Double.parseDouble(components[i]));
        }
        return vector;
    }

    /**
     * Returns {@code true} iff {@code o} is also a {@code DyadRankingDataset} holding equal rankings in the
     * same order. Unlike {@link ArrayList#equals(Object)}, a plain {@link List} is never equal to a dataset.
     *
     * @param o the object to compare with
     * @return whether both datasets contain equal rankings in the same order
     */
    @Override
    public boolean equals(Object o) {
        // AbstractList.equals performs the size check and the null-safe, element-wise, in-order comparison.
        return o instanceof DyadRankingDataset && super.equals(o);
    }

    /**
     * Combines the rankings' hash codes in order (seed 17, multiplier 31); a {@code null} element
     * contributes 0. Consistent with {@link #equals(Object)}.
     *
     * @return the hash code of this dataset
     */
    @Override
    public int hashCode() {
        int result = 17;
        for (IDyadRankingInstance instance : this) {
            result = result * 31 + (instance == null ? 0 : instance.hashCode());
        }
        return result;
    }

    /**
     * Converts every ranking into an ND4J matrix. Row {@code k} of a matrix is the {@code k}-th dyad's
     * instance features followed by its alternative features.
     *
     * @return one matrix per ranking, in dataset order
     */
    public List<INDArray> toND4j() {
        List<INDArray> matrices = new ArrayList<>(this.size());
        for (IDyadRankingInstance ranking : this) {
            matrices.add(dyadRankingToMatrix(ranking));
        }
        return matrices;
    }

    /**
     * Concatenates a dyad's instance and alternative feature arrays horizontally into one ND4J array.
     *
     * @param dyad the dyad to convert
     * @return {@code hstack(instance, alternative)}
     */
    private static INDArray dyadToVector(Dyad dyad) {
        INDArray instanceOfDyad = Nd4j.create(dyad.getInstance().asArray());
        INDArray alternativeOfDyad = Nd4j.create(dyad.getAlternative().asArray());
        return Nd4j.hstack(instanceOfDyad, alternativeOfDyad);
    }

    /**
     * Wraps a single ordered list of dyads into a dataset containing exactly one ranking.
     *
     * @param orderedDyad the dyads in ranking order
     * @return a new dataset whose only element is the ranking built from {@code orderedDyad}
     */
    public static DyadRankingDataset fromOrderedDyadList(List<Dyad> orderedDyad) {
        List<IDyadRankingInstance> dyadRankingInstance = Arrays.asList(new DyadRankingInstance(orderedDyad));
        return new DyadRankingDataset(dyadRankingInstance);
    }

    /**
     * Stacks the per-dyad vectors of a ranking vertically, one row per dyad in ranking order.
     * NOTE(review): ND4J's behaviour for a ranking with no dyads (empty stack) is not handled here.
     *
     * @param drInstance the ranking to convert
     * @return the stacked matrix
     */
    private static INDArray dyadRankingToMatrix(IDyadRankingInstance drInstance) {
        List<INDArray> dyadVectors = new ArrayList<>(drInstance.length());
        for (Dyad dyad : drInstance) {
            dyadVectors.add(dyadToVector(dyad));
        }
        return Nd4j.vstack(dyadVectors);
    }

    /**
     * Creates a new, empty dataset of this type.
     *
     * @return an empty {@code DyadRankingDataset}
     */
    public DyadRankingDataset createEmpty() {
        return new DyadRankingDataset();
    }

    /**
     * Counts how many rankings in this dataset are equal to the given one.
     *
     * @param instance the ranking to count (compared via its {@code equals})
     * @return the number of equal rankings
     */
    @Override
    public int getFrequency(IDyadRankingInstance instance) {
        return (int)this.stream().filter(instance::equals).count();
    }
}

