package org.deeplearning4j.rl4j.policy;

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/policy/Policy.class */
/**
 * Base class for policies that map a network input to an action and can be
 * played against an {@link MDP} to measure the reward they collect.
 *
 * @param <O> observation type produced by the MDP
 * @param <A> action type accepted by the MDP
 */
public abstract class Policy<O extends Encodable, A> {
    /**
     * Returns the network backing this policy. {@code play} resets it before an
     * episode and asks it whether it is recurrent to decide how to shape the input.
     *
     * @return the neural network used by this policy
     */
    public abstract NeuralNet getNeuralNet();

    /**
     * Selects the action to take for the given input.
     *
     * @param iNDArray the (possibly stacked and reshaped) input built by {@code play}
     * @return the chosen action
     */
    public abstract A nextAction(INDArray iNDArray);

    /**
     * Plays one episode without a history processor.
     *
     * @param mdp the MDP to play
     * @return the accumulated reward
     */
    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) {
        // The cast selects the IHistoryProcessor overload; a bare null would be ambiguous.
        return play(mdp, (IHistoryProcessor) null);
    }

    /**
     * Plays one episode using a new {@link HistoryProcessor} built from the given configuration.
     *
     * @param mdp the MDP to play
     * @param configuration configuration handed to the {@link HistoryProcessor} constructor
     * @return the accumulated reward
     */
    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor.Configuration configuration) {
        return play(mdp, new HistoryProcessor(configuration));
    }

    /**
     * Plays the MDP until {@code mdp.isDone()} and returns the accumulated reward,
     * which starts from the reward reported by {@code Learning.initMdp}.
     *
     * <p>When a history processor is supplied, every step's input is passed to
     * {@code record}, actions are only recomputed every {@code getSkipFrame()} steps
     * (the previous action is repeated otherwise), and the network input is built
     * from {@code getHistory()} scaled by {@code 1 / getScale()}.
     *
     * @param mdp the MDP to play
     * @param iHistoryProcessor optional history processor; may be {@code null}
     * @return the accumulated reward
     */
    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor iHistoryProcessor) {
        getNeuralNet().reset();
        Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, iHistoryProcessor);
        O obs = initMdp.getLastObs();
        double reward = initMdp.getReward();
        A lastAction = mdp.getActionSpace().noOp();
        int stepCount = initMdp.getSteps();
        INDArray[] history = null;
        final boolean hasHistoryProcessor = iHistoryProcessor != null;

        while (!mdp.isDone()) {
            INDArray input = Learning.getInput(mdp, obs);
            if (hasHistoryProcessor) {
                iHistoryProcessor.record(input);
            }

            int skipFrame = hasHistoryProcessor ? iHistoryProcessor.getConf().getSkipFrame() : 1;
            A action;
            if (stepCount % skipFrame != 0) {
                // Skipped frame: repeat the previously chosen action.
                action = lastAction;
            } else {
                if (history == null) {
                    // No step has been taken yet, so seed the history from the current input.
                    if (hasHistoryProcessor) {
                        iHistoryProcessor.add(input);
                        history = iHistoryProcessor.getHistory();
                    } else {
                        history = new INDArray[]{input};
                    }
                }
                INDArray stacked = Transition.concat(history);
                if (hasHistoryProcessor) {
                    stacked.muli(1.0d / iHistoryProcessor.getScale());
                }
                if (getNeuralNet().isRecurrent()) {
                    stacked = stacked.reshape(Learning.makeShape(1, stacked.shape(), 1));
                } else if (stacked.shape().length > 2) {
                    stacked = stacked.reshape(Learning.makeShape(1, stacked.shape()));
                }
                action = nextAction(stacked);
            }
            lastAction = action;

            StepReply<O> reply = mdp.step(action);
            reward += reply.getReward();
            // Advance the current observation; otherwise the next iteration would
            // rebuild (and record) the input of the initial observation again.
            obs = reply.getObservation();

            if (hasHistoryProcessor) {
                iHistoryProcessor.add(Learning.getInput(mdp, obs));
                history = iHistoryProcessor.getHistory();
            } else {
                history = new INDArray[]{Learning.getInput(mdp, obs)};
            }
            stepCount++;
        }
        return reward;
    }
}
