package org.deeplearning4j.rl4j.learning;

import java.util.Random;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/Learning.class */
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet> implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {
    // Class-wide SLF4J logger.
    private static final Logger log = LoggerFactory.getLogger(Learning.class);
    // Random source, created once in the constructor from the configuration's seed.
    private final Random random;
    // Step count; starts at 0 and is post-incremented by incrementStep().
    private int stepCounter = 0;
    // Epoch count; starts at 0 and is post-incremented by incrementEpoch().
    private int epochCounter = 0;
    // Optional history processor; stays null until a setHistoryProcessor overload is called.
    private IHistoryProcessor historyProcessor = null;

    /* loaded from: input_file:org/deeplearning4j/rl4j/learning/Learning$InitMdp.class */
    /**
     * Immutable value holder describing the outcome of an MDP initialisation:
     * the number of steps taken, the last observation seen and the reward
     * accumulated along the way.
     *
     * @param <O> type of the stored observation
     */
    public static final class InitMdp<O> {
        private final int steps;
        private final O lastObs;
        private final double reward;

        /**
         * @param i number of steps
         * @param o last observation (may be {@code null})
         * @param d accumulated reward
         */
        public InitMdp(int i, O o, double d) {
            this.steps = i;
            this.lastObs = o;
            this.reward = d;
        }

        /** @return the stored step count */
        public int getSteps() {
            return steps;
        }

        /** @return the stored last observation, possibly {@code null} */
        public O getLastObs() {
            return lastObs;
        }

        /** @return the stored reward */
        public double getReward() {
            return reward;
        }

        /**
         * Two instances are equal when steps, last observation and reward all match.
         * Rewards are compared with {@link Double#compare(double, double)}.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof InitMdp)) {
                return false;
            }
            InitMdp<?> other = (InitMdp<?>) obj;
            if (steps != other.steps) {
                return false;
            }
            // Null-safe comparison of the observations.
            if (lastObs == null ? other.lastObs != null : !lastObs.equals(other.lastObs)) {
                return false;
            }
            return Double.compare(reward, other.reward) == 0;
        }

        /** Hash combining the three fields with multiplier 59; a null observation contributes 43. */
        @Override
        public int hashCode() {
            final int prime = 59;
            int result = prime + steps;
            result = result * prime + (lastObs != null ? lastObs.hashCode() : 43);
            long rewardBits = Double.doubleToLongBits(reward);
            return result * prime + (int) (rewardBits ^ (rewardBits >>> 32));
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder("Learning.InitMdp(steps=");
            sb.append(steps);
            sb.append(", lastObs=").append(lastObs);
            sb.append(", reward=").append(reward);
            return sb.append(")").toString();
        }
    }

    /**
     * Creates the learner and seeds its {@link Random} from the configuration.
     *
     * @param lConfiguration learning configuration; only its seed is read here
     */
    public Learning(ILearning.LConfiguration lConfiguration) {
        random = new Random(lConfiguration.getSeed());
    }

    /**
     * Runs {@code Nd4j.argMax} on the given array with dimension
     * {@code Integer.MAX_VALUE} and returns the first element of the result.
     * Presumably this is the index of the largest value over the whole array
     * (confirm against the ND4J version in use).
     *
     * @param iNDArray array to inspect
     * @return element 0 of the arg-max result, boxed
     */
    public static Integer getMaxAction(INDArray iNDArray) {
        INDArray argMax = Nd4j.argMax(iNDArray, new int[]{Integer.MAX_VALUE});
        return Integer.valueOf(argMax.getInt(new int[]{0}));
    }

    /**
     * Turns an observation into an {@link INDArray} shaped for the given MDP.
     * The array is built from {@code o.toArray()}. When the observation space
     * shape has exactly one dimension the result is reshaped to
     * {@code [1, length]}; otherwise it is reshaped to the observation space shape.
     *
     * @param mdp MDP whose observation space supplies the target shape
     * @param o   observation to encode
     * @return the reshaped array
     */
    public static <O extends Encodable, A, AS extends ActionSpace<A>> INDArray getInput(MDP<O, A, AS> mdp, O o) {
        INDArray encoded = Nd4j.create(o.toArray());
        int[] shape = mdp.getObservationSpace().getShape();
        if (shape.length != 1) {
            return encoded.reshape(shape);
        }
        // One-dimensional observation space: present the data as a single row.
        return encoded.reshape(new long[]{1, encoded.length()});
    }

    /**
     * Resets the MDP and, when a history processor is supplied, plays no-op
     * actions until enough frames have been fed to it.
     *
     * <p>The number of warm-up steps is {@code skipFrame * (historyLength - 1)},
     * both read from the processor's configuration. Every step's current
     * observation is passed to {@code record}; every {@code skipFrame}-th one is
     * also passed to {@code add}. Without a history processor no step is taken.
     *
     * @param mdp               the MDP to reset and step
     * @param iHistoryProcessor history processor to fill, or {@code null} for none
     * @param <O>               observation type
     * @param <A>               action type
     * @param <AS>              action space type
     * @return the steps taken, the last observation seen and the reward accumulated
     */
    public static <O extends Encodable, A, AS extends ActionSpace<A>> InitMdp<O> initMdp(MDP<O, A, AS> mdp, IHistoryProcessor iHistoryProcessor) {
        O observation = mdp.reset();
        if (iHistoryProcessor == null) {
            // No history to fill: zero warm-up steps, zero reward.
            return new InitMdp<>(0, observation, 0.0d);
        }

        int skipFrame = iHistoryProcessor.getConf().getSkipFrame();
        int warmUpSteps = skipFrame * (iHistoryProcessor.getConf().getHistoryLength() - 1);

        int step = 0;
        double reward = 0.0d;
        while (step < warmUpSteps) {
            // Encode the latest observation. Encoding the reset observation on
            // every iteration would feed the same frame to the processor each time.
            INDArray input = getInput(mdp, observation);
            iHistoryProcessor.record(input);

            A noOp = mdp.getActionSpace().noOp();
            if (step % skipFrame == 0) {
                iHistoryProcessor.add(input);
            }

            StepReply<O> reply = mdp.step(noOp);
            reward += reply.getReward();
            observation = reply.getObservation();
            step++;
        }
        return new InitMdp<>(step, observation, reward);
    }

    /**
     * Builds a new shape by prepending {@code i} to {@code iArr}.
     *
     * @param i    value placed at index 0 of the result
     * @param iArr dimensions copied after it, in order
     * @return a new array of length {@code iArr.length + 1}
     */
    public static int[] makeShape(int i, int[] iArr) {
        int[] shape = new int[1 + iArr.length];
        shape[0] = i;
        for (int k = 0; k < iArr.length; k++) {
            shape[k + 1] = iArr[k];
        }
        return shape;
    }

    /**
     * Builds a three-element shape {@code [i, product(iArr), i2]}.
     *
     * <p>The result is allocated with three slots. A two-element array cannot
     * hold the trailing {@code i2} and fails with
     * {@link ArrayIndexOutOfBoundsException} on every call.
     *
     * @param i    value placed at index 0
     * @param iArr dimensions multiplied together for index 1 (an empty array yields 1)
     * @param i2   value placed at index 2
     * @return a new array of length 3
     */
    public static int[] makeShape(int i, int[] iArr, int i2) {
        int flattened = 1;
        for (int dim : iArr) {
            flattened *= dim;
        }
        return new int[]{i, flattened, i2};
    }

    /**
     * Supplies the neural network used by this learner. Within this class it is
     * reset at the start of the instance {@code initMdp()}.
     *
     * @return the neural network
     */
    @Override // org.deeplearning4j.rl4j.learning.NeuralNetFetchable
    public abstract NN getNeuralNet();

    /**
     * Advances the step counter by one.
     *
     * @return the counter value before the increment
     */
    public int incrementStep() {
        return stepCounter++;
    }

    /**
     * Advances the epoch counter by one.
     *
     * @return the counter value before the increment
     */
    public int incrementEpoch() {
        return epochCounter++;
    }

    /**
     * Installs a new {@code HistoryProcessor} built from the given configuration,
     * replacing any processor set previously.
     *
     * @param configuration configuration handed to the new processor
     */
    public void setHistoryProcessor(IHistoryProcessor.Configuration configuration) {
        historyProcessor = new HistoryProcessor(configuration);
    }

    /**
     * Installs the given history processor, replacing any processor set previously.
     *
     * @param iHistoryProcessor processor to use; stored as-is, including {@code null}
     */
    public void setHistoryProcessor(IHistoryProcessor iHistoryProcessor) {
        historyProcessor = iHistoryProcessor;
    }

    /**
     * Encodes an observation for this learner's MDP by delegating to the static
     * {@code getInput(MDP, O)} with {@code getMdp()}.
     *
     * @param o observation to encode
     * @return the encoded, reshaped array
     */
    public INDArray getInput(O o) {
        return Learning.getInput(getMdp(), o);
    }

    /**
     * Resets the neural network, then initialises this learner's MDP by delegating
     * to the static {@code initMdp(MDP, IHistoryProcessor)} with the current
     * history processor.
     *
     * @return the result of the MDP initialisation
     */
    public InitMdp<O> initMdp() {
        NN network = getNeuralNet();
        network.reset();
        return Learning.initMdp(getMdp(), getHistoryProcessor());
    }

    /** @return the random source created in the constructor */
    public Random getRandom() {
        return random;
    }

    /** @return the current value of the step counter */
    @Override // org.deeplearning4j.rl4j.learning.ILearning, org.deeplearning4j.rl4j.learning.StepCountable
    public int getStepCounter() {
        return stepCounter;
    }

    /**
     * Overwrites the step counter.
     *
     * @param i new counter value
     */
    public void setStepCounter(int i) {
        stepCounter = i;
    }

    /** @return the current value of the epoch counter */
    public int getEpochCounter() {
        return epochCounter;
    }

    /**
     * Overwrites the epoch counter.
     *
     * @param i new counter value
     */
    public void setEpochCounter(int i) {
        epochCounter = i;
    }

    /** @return the current history processor, or {@code null} when none has been set */
    public IHistoryProcessor getHistoryProcessor() {
        return historyProcessor;
    }
}
