/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset;

import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.ITimeSeriesInstance;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.TimeSeriesInstance;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.attribute.NDArrayTimeseriesAttribute;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.model.INDArrayTimeseries;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.model.NDArrayTimeseries;
import ai.libs.jaicore.ml.core.dataset.ADataset;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.api4.java.ai.ml.core.dataset.schema.ILabeledInstanceSchema;
import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ITimeseriesAttribute;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class TimeSeriesDataset
extends ADataset<ITimeSeriesInstance>
implements ILabeledDataset<ITimeSeriesInstance> {
    private static final long serialVersionUID = -6819487387561457394L;
    private List<INDArray> valueMatrices;
    private List<INDArray> timestampMatrices;
    private transient List<Object> targets;

    public TimeSeriesDataset(ILabeledInstanceSchema schema, List<INDArray> valueMatrices, List<INDArray> timestampMatrices, List<Object> targets) {
        this(schema);
        for (IAttribute att : schema.getAttributeList()) {
            if (att instanceof ITimeseriesAttribute) continue;
            throw new IllegalArgumentException("The schema contains attributes which are not timeseries");
        }
        Set valueInstances = valueMatrices.stream().map(x -> x.shape()[0]).collect(Collectors.toSet());
        if (valueInstances.size() > 1) {
            throw new IllegalArgumentException("The value matrices vary in length i.e. they have different number of instances");
        }
        Set timestampInstances = timestampMatrices.stream().map(x -> x.shape()[0]).collect(Collectors.toSet());
        if (timestampInstances.size() > 1) {
            throw new IllegalArgumentException("The timestamp matrices vary in length i.e. they have different number of instances");
        }
        valueInstances.addAll(timestampInstances);
        if (valueInstances.size() > 1) {
            throw new IllegalArgumentException("There are different number of instances for values and timestamps");
        }
        this.valueMatrices = valueMatrices;
        this.timestampMatrices = timestampMatrices;
        this.targets = targets;
    }

    public TimeSeriesDataset(ILabeledInstanceSchema schema) {
        super(schema);
    }

    public void add(String attributeName, INDArray valueMatrix, INDArray timestampMatrix) {
        this.valueMatrices.add(valueMatrix);
        this.timestampMatrices.add(timestampMatrix);
        this.addAttribute(attributeName, valueMatrix);
    }

    public void removeColumn(int index) {
        this.valueMatrices.remove(index);
        this.timestampMatrices.remove(index);
        this.getInstanceSchema().removeAttribute(index);
    }

    public void replace(int index, INDArray valueMatrix, INDArray timestampMatrix) {
        this.valueMatrices.set(index, valueMatrix);
        if (timestampMatrix != null && this.timestampMatrices != null && this.timestampMatrices.size() > index) {
            this.timestampMatrices.set(index, timestampMatrix);
        }
        NDArrayTimeseriesAttribute type = this.createAttribute("ts" + index, valueMatrix);
        this.getInstanceSchema().removeAttribute(index);
        this.getInstanceSchema().addAttribute(index, (IAttribute)type);
    }

    public Object getTargets() {
        return this.targets;
    }

    public INDArray getTargetsAsINDArray() {
        if (this.targets.get(0) instanceof Number) {
            return Nd4j.create((double[])this.targets.stream().mapToDouble(x -> (Double)x).toArray());
        }
        return null;
    }

    public int getNumberOfVariables() {
        return this.valueMatrices.size();
    }

    public long getNumberOfInstances() {
        return this.valueMatrices.get(0).shape()[0];
    }

    public INDArray getValues(int index) {
        return this.valueMatrices.get(index);
    }

    public INDArray getTimestamps(int index) {
        return this.timestampMatrices.get(index);
    }

    public INDArray getValuesOrNull(int index) {
        return this.valueMatrices.size() > index ? this.valueMatrices.get(index) : null;
    }

    public INDArray getTimestampsOrNull(int index) {
        return this.timestampMatrices != null && this.timestampMatrices.size() > index ? this.timestampMatrices.get(index) : null;
    }

    @Override
    public boolean isEmpty() {
        return this.valueMatrices.isEmpty();
    }

    public boolean isUnivariate() {
        return this.valueMatrices.size() == 1;
    }

    public boolean isMultivariate() {
        return this.valueMatrices.size() > 1;
    }

    private NDArrayTimeseriesAttribute createAttribute(String name, INDArray valueMatrix) {
        int length = (int)valueMatrix.shape()[1];
        return new NDArrayTimeseriesAttribute(name, length);
    }

    private void addAttribute(String name, INDArray valueMatrix) {
        NDArrayTimeseriesAttribute type = this.createAttribute(name, valueMatrix);
        this.getInstanceSchema().addAttribute((IAttribute)type);
        this.valueMatrices.add(valueMatrix);
    }

    @Override
    public TimeSeriesInstance get(int index) {
        ArrayList<INDArrayTimeseries> attributeValues = new ArrayList<INDArrayTimeseries>();
        for (int i = 0; i < this.valueMatrices.size(); ++i) {
            attributeValues.add(new NDArrayTimeseries(this.valueMatrices.get(i).getRow((long)index)));
        }
        Object target = this.targets.get(index);
        return new TimeSeriesInstance(attributeValues, target);
    }

    @Override
    public Iterator<ITimeSeriesInstance> iterator() {
        return new TimeSeriesDatasetIterator();
    }

    public TimeSeriesDataset createEmptyCopy() throws DatasetCreationException, InterruptedException {
        return new TimeSeriesDataset(this.getInstanceSchema());
    }

    @Override
    public Object[][] getFeatureMatrix() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Object[] getLabelVector() {
        return this.targets.toArray();
    }

    public TimeSeriesDataset createCopy() throws DatasetCreationException, InterruptedException {
        TimeSeriesDataset copy = this.createEmptyCopy();
        for (ITimeSeriesInstance i : this) {
            copy.add(i);
        }
        return copy;
    }

    @Override
    public int hashCode() {
        int prime = 31;
        int result = super.hashCode();
        result = 31 * result + (this.targets == null ? 0 : this.targets.hashCode());
        result = 31 * result + (this.timestampMatrices == null ? 0 : this.timestampMatrices.hashCode());
        result = 31 * result + (this.valueMatrices == null ? 0 : this.valueMatrices.hashCode());
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!super.equals(obj)) {
            return false;
        }
        if (this.getClass() != obj.getClass()) {
            return false;
        }
        TimeSeriesDataset other = (TimeSeriesDataset)obj;
        if (this.targets == null ? other.targets != null : !this.targets.equals(other.targets)) {
            return false;
        }
        if (this.timestampMatrices == null ? other.timestampMatrices != null : !this.timestampMatrices.equals(other.timestampMatrices)) {
            return false;
        }
        return !(this.valueMatrices == null ? other.valueMatrices != null : !this.valueMatrices.equals(other.valueMatrices));
    }

    class TimeSeriesDatasetIterator
    implements Iterator<ITimeSeriesInstance> {
        private int current = 0;

        TimeSeriesDatasetIterator() {
        }

        @Override
        public boolean hasNext() {
            return TimeSeriesDataset.this.getNumberOfInstances() > (long)this.current;
        }

        @Override
        public ITimeSeriesInstance next() {
            if (!this.hasNext()) {
                throw new NoSuchElementException();
            }
            return TimeSeriesDataset.this.get(this.current++);
        }
    }
}

