/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Sliding-window batch generator for time series, the DL4J counterpart of Keras'
 * {@code TimeseriesGenerator} (see {@link #fromJson(String)} for importing a Keras config).
 *
 * <p>{@code data} is treated as a matrix whose rows are time steps and whose columns are features;
 * {@code targets} is a matrix with one row per time step. For every selected target row {@code r},
 * the input window is rows {@code [r - length, r)} of {@code data}, taken every
 * {@code samplingRate} rows.
 *
 * <p>Only {@code stride == 1} is accepted by the full constructor; {@link #setStride(int)} does not
 * re-validate this.
 */
public class TimeSeriesGenerator {
    private static final int DEFAULT_SAMPLING_RATE = 1;
    private static final int DEFAULT_STRIDE = 1;
    private static final Integer DEFAULT_START_INDEX = 0;
    private static final Integer DEFAULT_END_INDEX = null;
    private static final boolean DEFAULT_SHUFFLE = false;
    private static final boolean DEFAULT_REVERSE = false;
    private static final int DEFAULT_BATCH_SIZE = 128;

    private INDArray data;
    private INDArray targets;
    // Number of data rows spanned by one input window (before sub-sampling).
    private int length;
    private int samplingRate;
    private int stride;
    // First usable target row: the user-supplied start index plus `length`, so that every window
    // [row - length, row) lies inside `data`.
    private int startIndex;
    // Last usable target row (inclusive).
    private int endIndex;
    private boolean shuffle;
    private boolean reverse;
    private int batchSize;

    /**
     * Builds a generator from a JSON file holding a Keras time series generator configuration.
     *
     * <p>The file must contain a top-level {@code "config"} object with the keys {@code length},
     * {@code sampling_rate}, {@code stride}, {@code start_index}, {@code end_index},
     * {@code batch_size}, {@code shuffle}, {@code reverse}, and {@code data} / {@code targets}
     * given as JSON-encoded strings of nested number lists (one inner list per time step).
     *
     * @param jsonFileName path of the JSON file (read as UTF-8)
     * @return the configured generator
     * @throws IOException if the file cannot be read
     * @throws InvalidKerasConfigurationException if the config is missing, the data/targets are
     *     empty or ragged, or the stride is unsupported
     */
    public static TimeSeriesGenerator fromJson(String jsonFileName) throws IOException, InvalidKerasConfigurationException {
        String json = new String(Files.readAllBytes(Paths.get(jsonFileName)), StandardCharsets.UTF_8);
        Map<String, Object> timeSeriesBaseConfig = KerasModelUtils.parseJsonString(json);
        if (!timeSeriesBaseConfig.containsKey("config")) {
            throw new InvalidKerasConfigurationException("No configuration found for Keras time series generator");
        }
        @SuppressWarnings("unchecked")
        Map<String, Object> timeSeriesConfig = (Map<String, Object>) timeSeriesBaseConfig.get("config");
        int length = (Integer) timeSeriesConfig.get("length");
        int samplingRate = (Integer) timeSeriesConfig.get("sampling_rate");
        int stride = (Integer) timeSeriesConfig.get("stride");
        int startIndex = (Integer) timeSeriesConfig.get("start_index");
        // Kept boxed: a null end_index is passed through so the constructor applies its default
        // (last row of the data) instead of failing here on unboxing.
        Integer endIndex = (Integer) timeSeriesConfig.get("end_index");
        int batchSize = (Integer) timeSeriesConfig.get("batch_size");
        boolean shuffle = (Boolean) timeSeriesConfig.get("shuffle");
        boolean reverse = (Boolean) timeSeriesConfig.get("reverse");

        Gson gson = new Gson();
        List<List<Double>> dataList = gson.fromJson((String) timeSeriesConfig.get("data"),
                new TypeToken<List<List<Double>>>() {}.getType());
        List<List<Double>> targetsList = gson.fromJson((String) timeSeriesConfig.get("targets"),
                new TypeToken<List<List<Double>>>() {}.getType());

        // Data and targets are sized independently: targets need not have as many columns as data.
        INDArray data = toMatrix(dataList, "data");
        INDArray targets = toMatrix(targetsList, "targets");
        return new TimeSeriesGenerator(data, targets, length, samplingRate, stride, startIndex, endIndex, shuffle, reverse, batchSize);
    }

    /** Copies a rectangular list of rows into a new 2D array of the default data type. */
    private static INDArray toMatrix(List<List<Double>> rows, String name) throws InvalidKerasConfigurationException {
        if (rows == null || rows.isEmpty()) {
            throw new InvalidKerasConfigurationException("Empty or missing '" + name + "' in Keras time series generator config");
        }
        int numRows = rows.size();
        int numCols = rows.get(0).size();
        INDArray matrix = Nd4j.create(new int[]{numRows, numCols});
        for (int r = 0; r < numRows; r++) {
            List<Double> row = rows.get(r);
            if (row.size() != numCols) {
                throw new InvalidKerasConfigurationException("Row " + r + " of '" + name + "' has " + row.size()
                        + " values, expected " + numCols);
            }
            for (int c = 0; c < numCols; c++) {
                matrix.putScalar(r, c, row.get(c));
            }
        }
        return matrix;
    }

    /**
     * Creates a generator.
     *
     * @param data input matrix, rows are time steps
     * @param targets target matrix, one row per time step of {@code data}
     * @param length number of data rows covered by each window
     * @param samplingRate step between consecutive rows taken inside a window
     * @param stride distance between consecutive windows; only {@code 1} is supported
     * @param startIndex first data row that may start a window; {@code null} means 0
     * @param endIndex last usable target row (inclusive); {@code null} means {@code data.rows() - 1}
     * @param shuffle draw target rows randomly instead of sequentially
     * @param reverse reverse the time axis of each window
     * @param batchSize number of windows per batch
     * @throws InvalidKerasConfigurationException if {@code stride != 1}
     * @throws IllegalArgumentException if {@code startIndex + length > endIndex}
     */
    public TimeSeriesGenerator(INDArray data, INDArray targets, int length, int samplingRate, int stride, Integer startIndex, Integer endIndex, boolean shuffle, boolean reverse, int batchSize) throws InvalidKerasConfigurationException {
        this.data = data;
        this.targets = targets;
        this.length = length;
        this.samplingRate = samplingRate;
        if (stride != 1) {
            throw new InvalidKerasConfigurationException("currently no strides > 1 supported, got: " + stride);
        }
        this.stride = stride;
        int firstIndex = startIndex == null ? DEFAULT_START_INDEX : startIndex;
        // Shift by `length` so the first window [start, start + length) fits before its target row.
        this.startIndex = firstIndex + length;
        this.endIndex = endIndex == null ? data.rows() - 1 : endIndex;
        this.shuffle = shuffle;
        this.reverse = reverse;
        this.batchSize = batchSize;
        if (this.startIndex > this.endIndex) {
            throw new IllegalArgumentException("Start index of sequence has to be smaller then end index, got startIndex : " + this.startIndex + " and endIndex: " + this.endIndex);
        }
    }

    /**
     * Creates a generator with default sampling rate (1), stride (1), start index (0), end index
     * (last row), no shuffling, no reversal and batch size 128.
     */
    public TimeSeriesGenerator(INDArray data, INDArray targets, int length) throws InvalidKerasConfigurationException {
        this(data, targets, length, DEFAULT_SAMPLING_RATE, DEFAULT_STRIDE, DEFAULT_START_INDEX, DEFAULT_END_INDEX,
                DEFAULT_SHUFFLE, DEFAULT_REVERSE, DEFAULT_BATCH_SIZE);
    }

    /**
     * Returns the number of batches needed to cover target rows {@code [startIndex, endIndex]},
     * i.e. {@code ceil((endIndex - startIndex + 1) / (batchSize * stride))}. Not to be confused
     * with {@link #getLength()}, the window length.
     */
    public int length() {
        return (this.endIndex - this.startIndex + this.batchSize * this.stride) / (this.batchSize * this.stride);
    }

    /**
     * Returns batch number {@code index}.
     *
     * <p>Without shuffling the batch uses target rows starting at
     * {@code startIndex + index * batchSize * stride}, clipped to {@code endIndex}. With shuffling,
     * {@code batchSize} target rows are drawn from {@code [startIndex, endIndex]} and {@code index}
     * is ignored.
     *
     * @param index batch index, expected in {@code [0, length())}
     * @return pair of samples with shape {@code [batch, ceil(length / samplingRate), data.columns()]}
     *     and targets with shape {@code [batch, targets.columns()]}
     */
    public Pair<INDArray, INDArray> next(int index) {
        INDArray rows;
        if (this.shuffle) {
            // nextInt(n, shape) draws from [0, n); shifting by startIndex gives [startIndex, endIndex].
            rows = Nd4j.getRandom().nextInt(this.endIndex - this.startIndex + 1, new int[]{this.batchSize});
            rows.addi((Number) this.startIndex);
        } else {
            int first = this.startIndex + this.batchSize * this.stride * index;
            int last = Math.min(first + this.batchSize * this.stride, this.endIndex + 1);
            rows = Nd4j.arange((double) first, (double) last);
        }
        // `rows` is a vector; size everything by its element count, not rows() (which is 1 for a vector).
        int batch = (int) rows.length();
        // interval(begin, samplingRate, end) with exclusive end yields ceil(length / samplingRate) steps.
        long stepsPerWindow = (this.length + this.samplingRate - 1) / this.samplingRate;
        INDArray samples = Nd4j.create(new long[]{batch, stepsPerWindow, this.data.columns()});
        INDArray batchTargets = Nd4j.create(new long[]{batch, this.targets.columns()});
        for (int j = 0; j < batch; j++) {
            long row = (long) rows.getDouble(j);
            INDArrayIndex window = NDArrayIndex.interval(row - this.length, (long) this.samplingRate, row);
            samples.putSlice(j, this.data.get(window));
            batchTargets.putSlice(j, this.targets.get(NDArrayIndex.point(row)));
        }
        if (this.reverse) {
            samples = reverseTimeAxis(samples);
        }
        return new Pair<>(samples, batchTargets);
    }

    /**
     * Returns a copy of a rank-3 {@code samples} array with axis 1 (the time steps of each window)
     * reversed, leaving the sample and feature axes untouched.
     */
    private static INDArray reverseTimeAxis(INDArray samples) {
        long steps = samples.size(1);
        INDArray reversed = samples.dup();
        for (long t = 0; t < steps; t++) {
            INDArray source = samples.get(NDArrayIndex.all(), NDArrayIndex.point(steps - 1 - t), NDArrayIndex.all());
            reversed.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point(t), NDArrayIndex.all()}, source);
        }
        return reversed;
    }

    public INDArray getData() {
        return this.data;
    }

    public INDArray getTargets() {
        return this.targets;
    }

    /** Returns the window length; see {@link #length()} for the number of batches. */
    public int getLength() {
        return this.length;
    }

    public int getSamplingRate() {
        return this.samplingRate;
    }

    public int getStride() {
        return this.stride;
    }

    /** Returns the first usable target row, i.e. the constructor's start index plus {@code length}. */
    public int getStartIndex() {
        return this.startIndex;
    }

    public int getEndIndex() {
        return this.endIndex;
    }

    public boolean isShuffle() {
        return this.shuffle;
    }

    public boolean isReverse() {
        return this.reverse;
    }

    public int getBatchSize() {
        return this.batchSize;
    }

    public void setData(INDArray data) {
        this.data = data;
    }

    public void setTargets(INDArray targets) {
        this.targets = targets;
    }

    public void setLength(int length) {
        this.length = length;
    }

    public void setSamplingRate(int samplingRate) {
        this.samplingRate = samplingRate;
    }

    /** Sets the stride without validation; {@link #next(int)} assumes a stride of 1. */
    public void setStride(int stride) {
        this.stride = stride;
    }

    /** Sets the internal first target row directly (no {@code length} offset is added). */
    public void setStartIndex(int startIndex) {
        this.startIndex = startIndex;
    }

    public void setEndIndex(int endIndex) {
        this.endIndex = endIndex;
    }

    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    public void setReverse(boolean reverse) {
        this.reverse = reverse;
    }

    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TimeSeriesGenerator)) {
            return false;
        }
        TimeSeriesGenerator other = (TimeSeriesGenerator) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.getLength() != other.getLength()
                || this.getSamplingRate() != other.getSamplingRate()
                || this.getStride() != other.getStride()
                || this.getStartIndex() != other.getStartIndex()
                || this.getEndIndex() != other.getEndIndex()
                || this.isShuffle() != other.isShuffle()
                || this.isReverse() != other.isReverse()
                || this.getBatchSize() != other.getBatchSize()) {
            return false;
        }
        INDArray thisData = this.getData();
        INDArray otherData = other.getData();
        if (thisData == null ? otherData != null : !thisData.equals(otherData)) {
            return false;
        }
        INDArray thisTargets = this.getTargets();
        INDArray otherTargets = other.getTargets();
        return thisTargets == null ? otherTargets == null : thisTargets.equals(otherTargets);
    }

    protected boolean canEqual(Object other) {
        return other instanceof TimeSeriesGenerator;
    }

    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        result = result * PRIME + this.getLength();
        result = result * PRIME + this.getSamplingRate();
        result = result * PRIME + this.getStride();
        result = result * PRIME + this.getStartIndex();
        result = result * PRIME + this.getEndIndex();
        result = result * PRIME + (this.isShuffle() ? 79 : 97);
        result = result * PRIME + (this.isReverse() ? 79 : 97);
        result = result * PRIME + this.getBatchSize();
        INDArray dataArr = this.getData();
        result = result * PRIME + (dataArr == null ? 43 : dataArr.hashCode());
        INDArray targetsArr = this.getTargets();
        result = result * PRIME + (targetsArr == null ? 43 : targetsArr.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "TimeSeriesGenerator(data=" + this.getData() + ", targets=" + this.getTargets() + ", length=" + this.getLength() + ", samplingRate=" + this.getSamplingRate() + ", stride=" + this.getStride() + ", startIndex=" + this.getStartIndex() + ", endIndex=" + this.getEndIndex() + ", shuffle=" + this.isShuffle() + ", reverse=" + this.isReverse() + ", batchSize=" + this.getBatchSize() + ")";
    }
}

