package org.deeplearning4j.rl4j.learning.sync;

import java.util.List;
import java.util.Objects;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/sync/Transition.class */
/**
 * An {@code (observation, action, reward, isTerminal, nextObservation)} tuple.
 *
 * <p>All fields are {@code final}; note however that the wrapped {@link Observation} and
 * {@link INDArray} objects are not copied on construction (use {@link #dup()} for a deep copy).
 *
 * <p>Only a single slice of the next observation is stored (see the public constructor); the
 * static {@code buildStacked*} helpers reassemble full batches from a list of transitions.
 *
 * @param <A> the action type
 */
public final class Transition<A> {
    private final Observation observation;
    private final A action;
    private final double reward;
    private final boolean isTerminal;
    private final INDArray nextObservation;

    /**
     * Creates a transition.
     *
     * <p>Only the first slice (index 0 along dimension 0) of {@code nextObservation}'s data is
     * retained, reshaped so that its leading dimension is 1. When dimension 0 of the observation
     * data has size greater than 1 (presumably a history of frames — confirm with callers), the
     * remaining slices are recovered later from the current observation by
     * {@link #buildStackedNextObservations(List)}.
     *
     * @param observation     the observation the action was taken from
     * @param action          the action taken
     * @param reward          the reward received
     * @param isTerminal      whether this transition ended the episode
     * @param nextObservation the resulting observation; must be non-null
     */
    public Transition(Observation observation, A action, double reward, boolean isTerminal, Observation nextObservation) {
        this.observation = observation;
        this.action = action;
        this.reward = reward;
        this.isTerminal = isTerminal;
        this.nextObservation = firstSliceKeepingRank(nextObservation.getData());
    }

    /** Creates a transition whose compact next observation is already available (used by {@link #dup()}). */
    private Transition(Observation observation, A action, double reward, boolean isTerminal, INDArray nextObservation) {
        this.observation = observation;
        this.action = action;
        this.reward = reward;
        this.isTerminal = isTerminal;
        this.nextObservation = nextObservation;
    }

    /** Returns slice 0 of {@code data} along dimension 0, reshaped to keep a leading dimension of size 1. */
    private static INDArray firstSliceKeepingRank(INDArray data) {
        // Clone so the array's own shape information is never modified.
        long[] sliceShape = data.shape().clone();
        sliceShape[0] = 1;
        return data.get(NDArrayIndex.point(0L)).reshape(sliceShape);
    }

    /**
     * Concatenates the given arrays along dimension 0.
     *
     * @param iNDArrayArr the arrays to concatenate
     * @return the concatenated array
     */
    public static INDArray concat(INDArray[] iNDArrayArr) {
        return Nd4j.concat(0, iNDArrayArr);
    }

    /**
     * Returns a copy of this transition whose observation and stored next-observation slice are
     * duplicated via their {@code dup()} methods. The action is shared, not copied.
     *
     * @return the copy
     */
    public Transition<A> dup() {
        return new Transition<>(observation.dup(), action, reward, isTerminal, nextObservation.dup());
    }

    /**
     * Stacks the observations of {@code transitions} into one batch array.
     *
     * <p>The observation data are concatenated along dimension 0, then reshaped to the shape
     * computed from the first transition: the batch size replaces a leading dimension of 1, or is
     * prepended otherwise.
     *
     * @param transitions the transitions to stack; must be non-empty
     * @param <A>         the action type
     * @return the stacked observations
     */
    public static <A> INDArray buildStackedObservations(List<Transition<A>> transitions) {
        long[] stackedShape = getShape(transitions);
        INDArray[] observations = new INDArray[transitions.size()];
        int index = 0;
        for (Transition<A> transition : transitions) {
            observations[index++] = transition.getObservation().getData();
        }
        return Nd4j.concat(0, observations).reshape(stackedShape);
    }

    /**
     * Stacks the full next observations of {@code transitions} into one batch array, with the
     * same shape as {@link #buildStackedObservations(List)} produces.
     *
     * @param transitions the transitions to stack; must be non-empty
     * @param <A>         the action type
     * @return the stacked next observations
     */
    public static <A> INDArray buildStackedNextObservations(List<Transition<A>> transitions) {
        long[] stackedShape = getShape(transitions);
        INDArray[] nextObservations = new INDArray[transitions.size()];
        int index = 0;
        for (Transition<A> transition : transitions) {
            nextObservations[index++] = rebuildNextObservation(transition);
        }
        return Nd4j.concat(0, nextObservations).reshape(stackedShape);
    }

    /** Reassembles one transition's full next observation from its stored slice and its current observation. */
    private static INDArray rebuildNextObservation(Transition<?> transition) {
        INDArray current = transition.getObservation().getData();
        long sliceCount = current.shape()[0];
        if (sliceCount == 1) {
            // Nothing was dropped by the constructor: the stored array is the whole next observation.
            return transition.getNextObservation();
        }
        // Stored slice goes first, followed by the current observation's first (sliceCount - 1)
        // slices; the current observation's last slice is dropped.
        return Nd4j.concat(0, transition.getNextObservation(), current.get(NDArrayIndex.interval(0L, sliceCount - 1)));
    }

    /**
     * Computes the batch shape for stacking {@code transitions}, based on the first transition's
     * observation shape (all transitions are assumed to share it — not checked).
     *
     * <p>If that shape is {@code [1, d1, ...]} the result is {@code [n, d1, ...]}; otherwise for
     * {@code [d0, d1, ...]} it is {@code [n, d0, d1, ...]}, where {@code n = transitions.size()}.
     */
    private static <A> long[] getShape(List<Transition<A>> transitions) {
        long[] observationShape = transitions.get(0).getObservation().getData().shape();
        long[] stackedShape;
        if (observationShape[0] == 1) {
            // Reuse the existing leading dimension of size 1 as the batch dimension.
            stackedShape = observationShape.clone();
        } else {
            // Prepend a batch dimension and keep every observation dimension, including d0.
            // (Previously slot 1 was filled with d1 instead of d0, which broke the reshape
            // whenever d0 != d1.)
            stackedShape = new long[observationShape.length + 1];
            System.arraycopy(observationShape, 0, stackedShape, 1, observationShape.length);
        }
        stackedShape[0] = transitions.size();
        return stackedShape;
    }

    public Observation getObservation() {
        return observation;
    }

    public A getAction() {
        return action;
    }

    public double getReward() {
        return reward;
    }

    public boolean isTerminal() {
        return isTerminal;
    }

    /**
     * Returns the stored slice of the next observation (see the public constructor), not
     * necessarily the full next observation.
     *
     * @return the stored next-observation array
     */
    public INDArray getNextObservation() {
        return nextObservation;
    }

    /**
     * Field-wise equality. Rewards are compared with {@link Double#compare}; observations, actions
     * and next-observation arrays with their own {@code equals}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Transition)) {
            return false;
        }
        Transition<?> other = (Transition<?>) obj;
        return Objects.equals(observation, other.observation)
                && Objects.equals(action, other.action)
                && Double.compare(reward, other.reward) == 0
                && isTerminal == other.isTerminal
                && Objects.equals(nextObservation, other.nextObservation);
    }

    @Override
    public int hashCode() {
        // Same constants as the previous (Lombok-style) implementation, so hash values are unchanged.
        final int prime = 59;
        int result = 1;
        result = result * prime + (observation == null ? 43 : observation.hashCode());
        result = result * prime + (action == null ? 43 : action.hashCode());
        result = result * prime + Double.hashCode(reward);
        result = result * prime + (isTerminal ? 79 : 97);
        result = result * prime + (nextObservation == null ? 43 : nextObservation.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "Transition(observation=" + getObservation() + ", action=" + getAction() + ", reward=" + getReward() + ", isTerminal=" + isTerminal() + ", nextObservation=" + getNextObservation() + ")";
    }
}
