package org.deeplearning4j.rl4j.util;

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/util/LegacyMDPWrapper.class */
/**
 * Adapts an {@link MDP} that produces raw observations of type {@code O} into an
 * {@code MDP<Observation, A, AS>}.
 *
 * <p>Every raw observation is encoded with {@link Encodable#toArray()}, turned into an
 * {@link INDArray} reshaped to the wrapped MDP's observation-space shape and, when an
 * {@link IHistoryProcessor} is available, recorded into and stacked by that processor.
 *
 * <p>The history processor and the global step counter are taken from explicit constructor
 * arguments when given, otherwise from the supplied {@link ILearning}.
 *
 * @param <O>  raw observation type of the wrapped MDP; instances are cast to {@link Encodable}
 * @param <A>  action type
 * @param <AS> action space type
 */
public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
    private final MDP<O, A, AS> wrappedMDP;
    private final WrapperObservationSpace observationSpace;
    // Fallback source for the history processor and the global step counter; may be null.
    private final ILearning learning;
    // Explicitly installed history processor; takes precedence over learning's one. May be null.
    private IHistoryProcessor historyProcessor;
    // Explicit source of the global step counter; takes precedence over learning. May be null.
    private final StepCountable stepCountable;
    // Frame-skip value, read from the history processor configuration in reset().
    private int skipFrame;
    // Number of step() calls since the last reset().
    private int step;

    /**
     * Minimal {@link ObservationSpace} for wrapped observations: only the shape is known, the name
     * and bounds are reported as {@code null}.
     */
    public static class WrapperObservationSpace implements ObservationSpace<Observation> {
        private final int[] shape;

        /**
         * @param shape shape of a single observation (the array is stored, not copied)
         */
        public WrapperObservationSpace(int[] shape) {
            this.shape = shape;
        }

        /** @return always {@code null} */
        @Override
        public String getName() {
            return null;
        }

        /** @return always {@code null} */
        @Override
        public INDArray getLow() {
            return null;
        }

        /** @return always {@code null} */
        @Override
        public INDArray getHigh() {
            return null;
        }

        /** @return the observation shape (internal array, not a copy) */
        @Override
        public int[] getShape() {
            return this.shape;
        }
    }

    /**
     * Wraps {@code mdp}, taking the history processor and step counter from {@code iLearning}.
     *
     * @param mdp       the MDP to wrap
     * @param iLearning source of the history processor and global step counter
     */
    public LegacyMDPWrapper(MDP<O, A, AS> mdp, ILearning iLearning) {
        this(mdp, iLearning, null, null);
    }

    /**
     * Wraps {@code mdp} with an explicit history processor and step counter.
     *
     * @param mdp               the MDP to wrap
     * @param iHistoryProcessor history processor to use; may be null
     * @param stepCountable     source of the global step counter
     */
    public LegacyMDPWrapper(MDP<O, A, AS> mdp, IHistoryProcessor iHistoryProcessor, StepCountable stepCountable) {
        this(mdp, null, iHistoryProcessor, stepCountable);
    }

    private LegacyMDPWrapper(MDP<O, A, AS> mdp, ILearning iLearning, IHistoryProcessor iHistoryProcessor, StepCountable stepCountable) {
        this.step = 0;
        this.wrappedMDP = mdp;
        this.observationSpace = new WrapperObservationSpace(mdp.getObservationSpace().getShape());
        this.learning = iLearning;
        this.historyProcessor = iHistoryProcessor;
        this.stepCountable = stepCountable;
    }

    /**
     * Returns the explicitly installed history processor, otherwise the one from {@code learning},
     * or {@code null} when neither source is available.
     */
    private IHistoryProcessor getHistoryProcessor() {
        if (this.historyProcessor != null) {
            return this.historyProcessor;
        }
        // Previously an NPE when constructed without an ILearning; null means "no history".
        return this.learning != null ? this.learning.getHistoryProcessor() : null;
    }

    /**
     * Installs (or clears, with {@code null}) the history processor used by this wrapper.
     */
    public void setHistoryProcessor(IHistoryProcessor iHistoryProcessor) {
        this.historyProcessor = iHistoryProcessor;
    }

    /**
     * Returns the global step counter from {@code stepCountable}, falling back to {@code learning}.
     *
     * @throws IllegalStateException if neither counter source was supplied
     */
    private int getStep() {
        if (this.stepCountable != null) {
            return this.stepCountable.getStepCounter();
        }
        if (this.learning != null) {
            return this.learning.getStepCounter();
        }
        throw new IllegalStateException(
                "LegacyMDPWrapper has neither a StepCountable nor an ILearning to read the step counter from");
    }

    /** Returns the wrapped MDP's action space. */
    @Override
    public AS getActionSpace() {
        return this.wrappedMDP.getActionSpace();
    }

    /**
     * Resets the wrapped MDP and returns its initial observation as a single-frame
     * {@link Observation}.
     *
     * <p>When a history processor is present, the first frame is recorded and added to it and the
     * frame-skip value is re-read from its configuration. The per-episode step count is reset to 0.
     *
     * @return the initial observation
     */
    @Override
    public Observation reset() {
        INDArray input = getInput(this.wrappedMDP.reset());
        IHistoryProcessor processor = getHistoryProcessor();
        if (processor != null) {
            processor.record(input);
        }
        Observation observation = new Observation(new INDArray[] {input}, false);
        if (processor != null) {
            this.skipFrame = processor.getConf().getSkipFrame();
            processor.add(input);
        }
        this.step = 0;
        return observation;
    }

    /**
     * Same as {@link #reset()}.
     *
     * @return the initial observation
     * @deprecated decompiler-generated name; use {@link #reset()}. Kept for existing callers.
     */
    @Deprecated
    public Observation m27reset() {
        return reset();
    }

    /**
     * Steps the wrapped MDP with {@code a} and converts the resulting observation.
     *
     * <p>Without a history processor, the observation is the single encoded frame. With one, each
     * frame is recorded, and it is added to the history when either the global step is a multiple of
     * {@code skipFrame} (once the episode step has reached the required-frame count) or the episode
     * step is a multiple of {@code skipFrame} (before that point). Once the episode step reaches the
     * required-frame count, the observation is the processor's stacked history scaled by
     * {@code 1 / getScale()}.
     *
     * @param a action to apply
     * @return the converted step reply; reward, done flag and info are passed through unchanged
     */
    @Override
    public StepReply<Observation> step(A a) {
        IHistoryProcessor processor = getHistoryProcessor();
        StepReply<O> rawStepReply = this.wrappedMDP.step(a);
        INDArray input = getInput(rawStepReply.getObservation());
        this.step++;

        // Presumably the number of episode steps needed to fill the history buffer — confirm.
        int requiredFrame = 0;
        if (processor != null) {
            processor.record(input);
            // NOTE(review): skipFrame is only assigned in reset() while a history processor is
            // present; if it is still 0 here, the modulo below throws ArithmeticException.
            requiredFrame = this.skipFrame * (processor.getConf().getHistoryLength() - 1);
            boolean historyFull = this.step >= requiredFrame;
            if ((historyFull && getStep() % this.skipFrame == 0)
                    || (!historyFull && this.step % this.skipFrame == 0)) {
                processor.add(input);
            }
        }

        Observation observation;
        if (processor != null && this.step >= requiredFrame) {
            observation = new Observation(processor.getHistory(), true);
            observation.getData().muli(1.0d / processor.getScale());
        } else {
            observation = new Observation(new INDArray[] {input}, false);
        }
        return new StepReply<>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
    }

    /** Closes the wrapped MDP. */
    @Override
    public void close() {
        this.wrappedMDP.close();
    }

    /** Returns whether the wrapped MDP reports the episode as done. */
    @Override
    public boolean isDone() {
        return this.wrappedMDP.isDone();
    }

    /**
     * Creates a wrapper around a fresh instance of the wrapped MDP, keeping the same
     * {@link ILearning} and {@link StepCountable}.
     *
     * <p>An explicitly installed history processor is not carried over: it accumulates frames via
     * {@code record}/{@code add}, and sharing it would interleave frames from two environments.
     * Install one on the new wrapper with {@link #setHistoryProcessor(IHistoryProcessor)} if needed.
     */
    @Override
    public MDP<Observation, A, AS> newInstance() {
        // Previously the StepCountable was dropped, leaving a wrapper with no counter source.
        return new LegacyMDPWrapper<O, A, AS>(this.wrappedMDP.newInstance(), this.learning, null, this.stepCountable);
    }

    /**
     * Encodes a raw observation into an {@link INDArray} shaped like the observation space.
     * A 1-D shape becomes a single-row matrix {@code [1, length]}.
     */
    private INDArray getInput(O rawObservation) {
        INDArray encoded = Nd4j.create(((Encodable) rawObservation).toArray());
        int[] shape = this.observationSpace.getShape();
        return shape.length == 1 ? encoded.reshape(new long[] {1, encoded.length()}) : encoded.reshape(shape);
    }

    /** Returns the MDP being wrapped. */
    public MDP<O, A, AS> getWrappedMDP() {
        return this.wrappedMDP;
    }

    /** Returns the shape-only observation space describing wrapped observations. */
    @Override
    public WrapperObservationSpace getObservationSpace() {
        return this.observationSpace;
    }

    /**
     * Same as {@link #getObservationSpace()}.
     *
     * @return the wrapper observation space
     * @deprecated decompiler-generated name; use {@link #getObservationSpace()}. Kept for existing callers.
     */
    @Deprecated
    public WrapperObservationSpace m28getObservationSpace() {
        return getObservationSpace();
    }
}
