/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.dataset;

import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.ml.core.dataset.Dataset;
import ai.libs.jaicore.ml.core.dataset.schema.LabeledInstanceSchema;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.DyadRankingAttribute;
import ai.libs.jaicore.ml.core.dataset.schema.attribute.SetOfObjectsAttribute;
import ai.libs.jaicore.ml.ranking.dyad.dataset.AGeneralDatasetBackedDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.DenseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.learner.Dyad;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.api4.java.ai.ml.core.dataset.schema.ILabeledInstanceSchema;
import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A dataset whose instances are dyad rankings ({@link IDyadRankingInstance}).
 *
 * <p>Instances are stored in an internal {@link Dataset}. Unless a schema is passed explicitly, the
 * schema consists of one set-of-dyads attribute named {@code "dyads"} and a dyad-ranking label
 * attribute named {@code "ranking"}. Removing columns is not supported.
 *
 * <p>{@link #serialize(OutputStream)} writes one ranking per line as a sequence of
 * {@code context;alternative|} tokens using the vectors' {@code toString()} representation.
 * {@link #deserialize(InputStream)} reads that format back, decoding the stream as UTF-8.
 */
public class DyadRankingDataset extends AGeneralDatasetBackedDataset<IDyadRankingInstance> implements IDyadRankingDataset {

    private static final String MSG_REMOVAL_FORBIDDEN = "Cannot remove a column for dyad DyadRankingDataset.";

    private static final Logger LOGGER = LoggerFactory.getLogger(DyadRankingDataset.class);

    private LabeledInstanceSchema labeledInstanceSchema;

    /** Creates an empty dataset with an empty relation name and the default dyad ranking schema. */
    public DyadRankingDataset() {
        this("");
    }

    /**
     * Creates an empty dataset with the default dyad ranking schema.
     *
     * @param relationName name of the relation stored in the schema
     */
    public DyadRankingDataset(String relationName) {
        this.createInstanceSchema(relationName);
        this.setInternalDataset(new Dataset(this.labeledInstanceSchema));
    }

    /**
     * Creates an empty dataset backed by a copy of the given schema.
     *
     * @param labeledInstanceSchema schema to copy; the dataset does not keep a reference to it
     */
    public DyadRankingDataset(LabeledInstanceSchema labeledInstanceSchema) {
        this.labeledInstanceSchema = labeledInstanceSchema.getCopy();
        this.setInternalDataset(new Dataset(this.labeledInstanceSchema));
    }

    /**
     * Creates a dataset with the default schema and adds all given instances.
     *
     * @param relationName name of the relation stored in the schema
     * @param c instances to add
     */
    public DyadRankingDataset(String relationName, Collection<IDyadRankingInstance> c) {
        this(relationName);
        this.addAll(c);
    }

    /**
     * Creates a dataset with an empty relation name and the default schema, and adds all given instances.
     *
     * @param c instances to add
     */
    public DyadRankingDataset(Collection<IDyadRankingInstance> c) {
        this("", c);
    }

    /** Builds the default schema: a "dyads" set-of-dyads attribute labeled by a "ranking" attribute. */
    private void createInstanceSchema(String relationName) {
        SetOfObjectsAttribute<IDyad> dyadSetAttribute = new SetOfObjectsAttribute<>("dyads", IDyad.class);
        DyadRankingAttribute dyadRankingAttribute = new DyadRankingAttribute("ranking");
        this.labeledInstanceSchema = new LabeledInstanceSchema(relationName, Arrays.asList(dyadSetAttribute), (IAttribute) dyadRankingAttribute);
    }

    /**
     * Writes every ranking as one UTF-8 encoded line of {@code context;alternative|} tokens.
     *
     * <p>The stream is neither flushed nor closed; that remains the caller's responsibility.
     * I/O errors are logged (with their cause) and abort the write.
     *
     * @param out destination stream
     */
    public void serialize(OutputStream out) {
        try {
            for (IDyadRankingInstance instance : this) {
                // Build the whole line first so each ranking costs a single write.
                StringBuilder line = new StringBuilder();
                for (IDyad dyad : instance) {
                    line.append(dyad.getContext().toString())
                            .append(';')
                            .append(dyad.getAlternative().toString())
                            .append('|');
                }
                line.append('\n');
                // Explicit UTF-8 so the output matches what deserialize() decodes, regardless of platform charset.
                out.write(line.toString().getBytes(StandardCharsets.UTF_8));
            }
        } catch (IOException e) {
            LOGGER.warn("Could not serialize dyad ranking dataset.", e);
        }
    }

    /**
     * Replaces the contents of this dataset with rankings read from the given UTF-8 stream.
     *
     * <p>Reading stops at the end of the stream or at the first empty line. Dyad tokens that lack a
     * {@code ;} separator, or whose context or alternative part is at most one character long, are
     * skipped. I/O errors are logged (with their cause).
     *
     * @param in source stream in the format produced by {@link #serialize(OutputStream)}; not closed
     */
    public void deserialize(InputStream in) {
        this.clear();
        try {
            LineIterator input = IOUtils.lineIterator(in, StandardCharsets.UTF_8);
            while (input.hasNext()) {
                String row = input.next();
                if (row.isEmpty()) {
                    // An empty line terminates the ranking section of the stream.
                    break;
                }
                this.add(new DenseDyadRankingInstance(parseDyadRow(row)));
            }
        } catch (IOException e) {
            LOGGER.warn("Could not deserialize dyad ranking dataset.", e);
        }
    }

    /** Parses one serialized line into its dyads, skipping malformed tokens. */
    private static List<IDyad> parseDyadRow(String row) {
        List<IDyad> dyads = new ArrayList<>();
        for (String dyadString : row.split("\\|")) {
            String[] values = dyadString.split(";");
            // Guard the length first: a token without ';' would otherwise throw ArrayIndexOutOfBoundsException.
            if (values.length < 2 || values[0].length() <= 1 || values[1].length() <= 1) {
                continue;
            }
            dyads.add(new Dyad(parseVector(values[0]), parseVector(values[1])));
        }
        return dyads;
    }

    /** Parses a bracketed, comma-separated vector string such as {@code [1.0,2.0]} (outer characters are stripped). */
    private static IVector parseVector(String bracketed) {
        String[] entries = bracketed.substring(1, bracketed.length() - 1).split(",");
        DenseDoubleVector vector = new DenseDoubleVector(entries.length);
        for (int i = 0; i < entries.length; ++i) {
            vector.setValue(i, Double.parseDouble(entries[i]));
        }
        return vector;
    }

    /**
     * Two datasets are equal if they are both {@code DyadRankingDataset}s holding equal rankings in the same order.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof DyadRankingDataset)) {
            return false;
        }
        DyadRankingDataset dataset = (DyadRankingDataset) o;
        if (dataset.size() != this.size()) {
            return false;
        }
        for (int i = 0; i < dataset.size(); ++i) {
            IDyadRankingInstance mine = (IDyadRankingInstance) this.get(i);
            IDyadRankingInstance theirs = (IDyadRankingInstance) dataset.get(i);
            if (!mine.equals(theirs)) {
                return false;
            }
        }
        return true;
    }

    /** Order-sensitive hash over the contained rankings, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = 17;
        for (IDyadRankingInstance instance : this) {
            result = result * 31 + instance.hashCode();
        }
        return result;
    }

    /**
     * Converts every ranking into an ND4J matrix with one row per dyad (context followed by alternative).
     *
     * @return one matrix per ranking, in dataset order
     */
    public List<INDArray> toND4j() {
        List<INDArray> ndList = new ArrayList<>(this.size());
        for (IDyadRankingInstance instance : this) {
            ndList.add(this.dyadRankingToMatrix(instance));
        }
        return ndList;
    }

    /** Concatenates a dyad's context and alternative into a single ND4J row vector. */
    private INDArray dyadToVector(IDyad dyad) {
        INDArray instanceOfDyad = Nd4j.create(dyad.getContext().asArray());
        INDArray alternativeOfDyad = Nd4j.create(dyad.getAlternative().asArray());
        return Nd4j.hstack(instanceOfDyad, alternativeOfDyad);
    }

    /** Stacks the vectors of all dyads of a ranking vertically, preserving the ranking order. */
    private INDArray dyadRankingToMatrix(IDyadRankingInstance drInstance) {
        List<INDArray> dyadList = new ArrayList<>(drInstance.getNumberOfRankedElements());
        for (IDyad dyad : drInstance) {
            dyadList.add(this.dyadToVector(dyad));
        }
        return Nd4j.vstack(dyadList);
    }

    /** Returns the schema backing this dataset (not a copy). */
    public ILabeledInstanceSchema getInstanceSchema() {
        return this.labeledInstanceSchema;
    }

    /** Delegates to the internal dataset's label vector. */
    public Object[] getLabelVector() {
        return this.getInternalDataset().getLabelVector();
    }

    /** Creates an empty dataset with a copy of this dataset's schema. */
    public DyadRankingDataset createEmptyCopy() {
        return new DyadRankingDataset(this.labeledInstanceSchema);
    }

    /** Delegates to the internal dataset's feature matrix. */
    public Object[][] getFeatureMatrix() {
        return this.getInternalDataset().getFeatureMatrix();
    }

    /** Not supported. @throws UnsupportedOperationException always */
    public void removeColumn(int columnPos) {
        throw new UnsupportedOperationException(MSG_REMOVAL_FORBIDDEN);
    }

    /** Not supported. @throws UnsupportedOperationException always */
    public void removeColumn(String columnName) {
        throw new UnsupportedOperationException(MSG_REMOVAL_FORBIDDEN);
    }

    /** Not supported. @throws UnsupportedOperationException always */
    public void removeColumn(IAttribute attribute) {
        throw new UnsupportedOperationException(MSG_REMOVAL_FORBIDDEN);
    }

    /**
     * Creates a dataset with a copied schema that contains the same ranking instances.
     * The instances themselves are shared, not deep-copied.
     */
    public IDataset<IDyadRankingInstance> createCopy() throws DatasetCreationException, InterruptedException {
        DyadRankingDataset copy = this.createEmptyCopy();
        for (IDyadRankingInstance instance : this) {
            copy.add(instance);
        }
        return copy;
    }
}

