package org.deeplearning4j.nn.modelimport.keras.preprocessing.sequence;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGenerator.class */
/**
 * Batches of fixed-length windows cut from a time series, together with the target row that
 * follows each window.
 *
 * <p>A sample for row {@code r} consists of the data rows {@code [r - length, r)} taken every
 * {@code samplingRate} rows; its target is {@code targets[r]}. Valid rows {@code r} run from
 * {@link #getStartIndex()} to {@link #getEndIndex()} inclusive. Note that the stored start index is
 * the start index passed to the constructor plus {@code length}, so that every sample has a full
 * window of history in front of it.
 *
 * <p>The batching arithmetic is intended to follow the Keras {@code TimeseriesGenerator}; only a
 * stride of 1 is accepted by the constructor.
 */
public class TimeSeriesGenerator {
    private static final int DEFAULT_SAMPLING_RATE = 1;
    private static final int DEFAULT_STRIDE = 1;
    private static final Integer DEFAULT_START_INDEX = 0;
    private static final Integer DEFAULT_END_INDEX = null;
    private static final boolean DEFAULT_SHUFFLE = false;
    private static final boolean DEFAULT_REVERSE = false;
    private static final int DEFAULT_BATCH_SIZE = 128;

    private INDArray data;
    private INDArray targets;
    private int length;
    private int samplingRate;
    private int stride;
    private int startIndex;
    private int endIndex;
    private boolean shuffle;
    private boolean reverse;
    private int batchSize;

    /**
     * Builds a generator from a JSON file whose top-level {@code "config"} object holds the
     * generator settings and the JSON-encoded {@code "data"} / {@code "targets"} matrices.
     *
     * @param str path of the JSON file to read (decoded as UTF-8)
     * @return the configured generator
     * @throws IOException if the file cannot be read or parsed
     * @throws InvalidKerasConfigurationException if the configuration is missing, incomplete,
     *     or describes data and targets that do not line up
     */
    public static TimeSeriesGenerator fromJson(String str) throws IOException, InvalidKerasConfigurationException {
        // Decode explicitly as UTF-8 instead of relying on the platform default charset.
        String json = new String(Files.readAllBytes(Paths.get(str)), StandardCharsets.UTF_8);
        Map<String, Object> root = KerasModelUtils.parseJsonString(json);
        if (!root.containsKey("config")) {
            throw new InvalidKerasConfigurationException("No configuration found for Keras time series generator");
        }
        @SuppressWarnings("unchecked")
        Map<String, Object> config = (Map<String, Object>) root.get("config");

        int windowLength = requireInt(config, "length");
        int rate = requireInt(config, "sampling_rate");
        int step = requireInt(config, "stride");
        // start_index / end_index may legitimately be null; the constructor supplies defaults.
        Integer start = optionalInt(config, "start_index");
        Integer end = optionalInt(config, "end_index");
        int batch = requireInt(config, "batch_size");
        boolean shuffled = requireBoolean(config, "shuffle");
        boolean reversed = requireBoolean(config, "reverse");

        Gson gson = new Gson();
        List<List<Double>> dataList =
                gson.fromJson((String) config.get("data"), new TypeToken<List<List<Double>>>() { }.getType());
        List<List<Double>> targetList =
                gson.fromJson((String) config.get("targets"), new TypeToken<List<List<Double>>>() { }.getType());
        if (dataList == null || dataList.isEmpty() || targetList == null || targetList.isEmpty()) {
            throw new InvalidKerasConfigurationException(
                    "Keras time series generator configuration contains no data or no targets");
        }
        if (dataList.size() != targetList.size()) {
            throw new InvalidKerasConfigurationException("Data and targets have to be of same length, got "
                    + dataList.size() + " data rows and " + targetList.size() + " target rows");
        }

        int rows = dataList.size();
        INDArray dataArray = Nd4j.create(rows, dataList.get(0).size());
        // Targets get their own column count; they need not be as wide as the data.
        INDArray targetArray = Nd4j.create(rows, targetList.get(0).size());
        for (int row = 0; row < rows; row++) {
            // putRow writes a whole row, whereas put(int, INDArray) addresses a single element.
            dataArray.putRow(row, Nd4j.create(dataList.get(row)));
            targetArray.putRow(row, Nd4j.create(targetList.get(row)));
        }
        return new TimeSeriesGenerator(dataArray, targetArray, windowLength, rate, step, start, end,
                shuffled, reversed, batch);
    }

    /** Reads a mandatory integer field from the configuration map. */
    private static int requireInt(Map<String, Object> config, String key) throws InvalidKerasConfigurationException {
        Object value = config.get(key);
        if (!(value instanceof Number)) {
            throw new InvalidKerasConfigurationException(
                    "Missing or non-numeric field '" + key + "' in Keras time series generator configuration");
        }
        return ((Number) value).intValue();
    }

    /** Reads an integer field that may be absent or null; returns {@code null} in that case. */
    private static Integer optionalInt(Map<String, Object> config, String key) throws InvalidKerasConfigurationException {
        Object value = config.get(key);
        if (value == null) {
            return null;
        }
        if (!(value instanceof Number)) {
            throw new InvalidKerasConfigurationException(
                    "Non-numeric field '" + key + "' in Keras time series generator configuration");
        }
        return ((Number) value).intValue();
    }

    /** Reads a mandatory boolean field from the configuration map. */
    private static boolean requireBoolean(Map<String, Object> config, String key) throws InvalidKerasConfigurationException {
        Object value = config.get(key);
        if (!(value instanceof Boolean)) {
            throw new InvalidKerasConfigurationException(
                    "Missing or non-boolean field '" + key + "' in Keras time series generator configuration");
        }
        return (Boolean) value;
    }

    /**
     * Creates a fully configured generator.
     *
     * @param data time series, one row per time step
     * @param targets targets, one row per time step of {@code data}
     * @param length number of time steps each sample window spans
     * @param samplingRate step between consecutive rows inside a window; must be positive
     * @param stride step between consecutive samples; only 1 is supported
     * @param startIndex first usable row before the window length is added; {@code null} means 0
     * @param endIndex last usable row (inclusive); {@code null} means the last row of {@code data}
     * @param shuffle whether batches are drawn at random instead of sequentially
     * @param reverse whether {@link #next(int)} passes the samples through {@code Nd4j.reverse}
     * @param batchSize number of samples per batch; must be positive
     * @throws InvalidKerasConfigurationException if {@code stride} is not 1
     * @throws IllegalArgumentException if {@code samplingRate} or {@code batchSize} is not
     *     positive, or if the effective start index lies beyond the end index
     */
    public TimeSeriesGenerator(INDArray data, INDArray targets, int length, int samplingRate, int stride,
                               Integer startIndex, Integer endIndex, boolean shuffle, boolean reverse,
                               int batchSize) throws InvalidKerasConfigurationException {
        if (stride != 1) {
            throw new InvalidKerasConfigurationException("currently no strides > 1 supported, got: " + stride);
        }
        // Both values are used as divisors later on, so reject them up front.
        if (samplingRate <= 0) {
            throw new IllegalArgumentException("Sampling rate has to be positive, got: " + samplingRate);
        }
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Batch size has to be positive, got: " + batchSize);
        }
        this.data = data;
        this.targets = targets;
        this.length = length;
        this.samplingRate = samplingRate;
        this.stride = stride;
        // Shift by the window length so that the first sample has a complete window before it.
        int requestedStart = startIndex == null ? DEFAULT_START_INDEX : startIndex;
        this.startIndex = requestedStart + length;
        this.endIndex = endIndex == null ? data.rows() - 1 : endIndex;
        this.shuffle = shuffle;
        this.reverse = reverse;
        this.batchSize = batchSize;
        if (this.startIndex > this.endIndex) {
            throw new IllegalArgumentException("Start index of sequence has to be smaller then end index, got startIndex : " + this.startIndex + " and endIndex: " + this.endIndex);
        }
    }

    /**
     * Creates a generator with default sampling rate (1), stride (1), start index (0), end index
     * (last data row), no shuffling, no reversal and the default batch size (128).
     *
     * @param data time series, one row per time step
     * @param targets targets, one row per time step of {@code data}
     * @param length number of time steps each sample window spans
     * @throws InvalidKerasConfigurationException never for the defaults used here; declared
     *     because the delegated constructor declares it
     */
    public TimeSeriesGenerator(INDArray data, INDArray targets, int length) throws InvalidKerasConfigurationException {
        this(data, targets, length, DEFAULT_SAMPLING_RATE, DEFAULT_STRIDE, DEFAULT_START_INDEX,
                DEFAULT_END_INDEX, DEFAULT_SHUFFLE, DEFAULT_REVERSE, DEFAULT_BATCH_SIZE);
    }

    /**
     * Returns the number of batches, i.e. the number of usable rows divided by
     * {@code batchSize * stride}, rounded up.
     *
     * @return number of batches this generator provides
     */
    public int length() {
        int rowsPerBatch = this.batchSize * this.stride;
        return ((this.endIndex - this.startIndex) + rowsPerBatch) / rowsPerBatch;
    }

    /**
     * Produces one batch of samples and their targets.
     *
     * <p>Without shuffling, batch {@code i} covers the rows starting at
     * {@code startIndex + batchSize * stride * i}, clipped at {@code endIndex}. With shuffling,
     * {@code batchSize} rows are drawn at random from {@code [startIndex, endIndex]} and
     * {@code i} is ignored.
     *
     * @param i zero-based batch index
     * @return pair of samples, shaped {@code [rows, ceil(length / samplingRate), data columns]},
     *     and targets, shaped {@code [rows, target columns]}
     * @throws IndexOutOfBoundsException if shuffling is off and {@code i} does not address a batch
     */
    public Pair<INDArray, INDArray> next(int i) {
        INDArray rows;
        if (this.shuffle) {
            // Draw offsets in [0, endIndex - startIndex] so that, after shifting by startIndex,
            // every sampled row stays inside [startIndex, endIndex].
            int candidates = this.endIndex - this.startIndex + 1;
            rows = Nd4j.getRandom().nextInt(candidates, new int[]{this.batchSize});
            rows.addi(Integer.valueOf(this.startIndex));
        } else {
            int rowsPerBatch = this.batchSize * this.stride;
            // Consecutive batches are rowsPerBatch apart, matching the batch count of length().
            int first = this.startIndex + rowsPerBatch * i;
            if (i < 0 || first > this.endIndex) {
                throw new IndexOutOfBoundsException(
                        "Batch index " + i + " is out of range, generator has " + length() + " batches");
            }
            rows = Nd4j.arange(first, Math.min(first + rowsPerBatch, this.endIndex + 1));
        }

        // A window of `length` rows sampled every `samplingRate` rows has ceil(length / rate) rows.
        int steps = (this.length + this.samplingRate - 1) / this.samplingRate;
        long batchRows = rows.length();
        INDArray samples = Nd4j.create(new long[]{batchRows, steps, this.data.columns()});
        INDArray batchTargets = Nd4j.create(new long[]{batchRows, this.targets.columns()});

        // Iterate over every element of the index vector; its length is what the outputs
        // above were sized with.
        for (int k = 0; k < batchRows; k++) {
            long row = (long) rows.getDouble(k);
            INDArrayIndex window = NDArrayIndex.interval(row - this.length, this.samplingRate, row);
            samples.putSlice(k, this.data.get(new INDArrayIndex[]{window}));
            batchTargets.putSlice(k, this.targets.get(new INDArrayIndex[]{NDArrayIndex.point(row)}));
        }

        if (this.reverse) {
            // NOTE(review): Keras reverses each sample along the time axis only; confirm that
            // Nd4j.reverse on the whole 3d array gives the intended ordering.
            samples = Nd4j.reverse(samples);
        }
        return new Pair<>(samples, batchTargets);
    }

    /** @return the time series the samples are cut from */
    public INDArray getData() {
        return this.data;
    }

    /** @return the targets, one row per time step */
    public INDArray getTargets() {
        return this.targets;
    }

    /** @return number of time steps each sample window spans */
    public int getLength() {
        return this.length;
    }

    /** @return step between consecutive rows inside a window */
    public int getSamplingRate() {
        return this.samplingRate;
    }

    /** @return step between consecutive samples */
    public int getStride() {
        return this.stride;
    }

    /** @return first usable row; the constructor stores its start index plus the window length */
    public int getStartIndex() {
        return this.startIndex;
    }

    /** @return last usable row (inclusive) */
    public int getEndIndex() {
        return this.endIndex;
    }

    /** @return whether batches are drawn at random */
    public boolean isShuffle() {
        return this.shuffle;
    }

    /** @return whether samples are reversed before being returned */
    public boolean isReverse() {
        return this.reverse;
    }

    /** @return number of samples per batch */
    public int getBatchSize() {
        return this.batchSize;
    }

    /** @param data the time series the samples are cut from */
    public void setData(INDArray data) {
        this.data = data;
    }

    /** @param targets the targets, one row per time step */
    public void setTargets(INDArray targets) {
        this.targets = targets;
    }

    /** @param length window length; the stored start index is not adjusted by this setter */
    public void setLength(int length) {
        this.length = length;
    }

    /** @param samplingRate step between consecutive rows inside a window */
    public void setSamplingRate(int samplingRate) {
        this.samplingRate = samplingRate;
    }

    /** @param stride step between consecutive samples */
    public void setStride(int stride) {
        this.stride = stride;
    }

    /** @param startIndex first usable row, stored as given (no window length is added here) */
    public void setStartIndex(int startIndex) {
        this.startIndex = startIndex;
    }

    /** @param endIndex last usable row (inclusive) */
    public void setEndIndex(int endIndex) {
        this.endIndex = endIndex;
    }

    /** @param shuffle whether batches are drawn at random */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    /** @param reverse whether samples are reversed before being returned */
    public void setReverse(boolean reverse) {
        this.reverse = reverse;
    }

    /** @param batchSize number of samples per batch */
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /** Two generators are equal when their data, targets and every setting are equal. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof TimeSeriesGenerator)) {
            return false;
        }
        TimeSeriesGenerator other = (TimeSeriesGenerator) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(getData(), other.getData())
                && Objects.equals(getTargets(), other.getTargets())
                && getLength() == other.getLength()
                && getSamplingRate() == other.getSamplingRate()
                && getStride() == other.getStride()
                && getStartIndex() == other.getStartIndex()
                && getEndIndex() == other.getEndIndex()
                && isShuffle() == other.isShuffle()
                && isReverse() == other.isReverse()
                && getBatchSize() == other.getBatchSize();
    }

    /** Hook that lets subclasses refuse equality with instances of this class. */
    protected boolean canEqual(Object obj) {
        return obj instanceof TimeSeriesGenerator;
    }

    /** Hash over the same fields {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        final int prime = 59;
        INDArray currentData = getData();
        INDArray currentTargets = getTargets();
        int result = 1;
        result = result * prime + (currentData == null ? 43 : currentData.hashCode());
        result = result * prime + (currentTargets == null ? 43 : currentTargets.hashCode());
        result = result * prime + getLength();
        result = result * prime + getSamplingRate();
        result = result * prime + getStride();
        result = result * prime + getStartIndex();
        result = result * prime + getEndIndex();
        result = result * prime + (isShuffle() ? 79 : 97);
        result = result * prime + (isReverse() ? 79 : 97);
        result = result * prime + getBatchSize();
        return result;
    }

    /** Lists the data, targets and every setting of this generator. */
    @Override
    public String toString() {
        return "TimeSeriesGenerator(data=" + getData() + ", targets=" + getTargets() + ", length=" + getLength() + ", samplingRate=" + getSamplingRate() + ", stride=" + getStride() + ", startIndex=" + getStartIndex() + ", endIndex=" + getEndIndex() + ", shuffle=" + isShuffle() + ", reverse=" + isReverse() + ", batchSize=" + getBatchSize() + ")";
    }
}
