package org.deeplearning4j.rl4j.policy;

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;

/* loaded from: input_file:org/deeplearning4j/rl4j/policy/Policy.class */
public abstract class Policy<O, A> implements IPolicy<O, A> {

    /**
     * Mutable step counter created per {@link #play(MDP, IHistoryProcessor)} call and handed to the
     * {@link LegacyMDPWrapper}. {@code play} seeds it with the number of initialization steps and
     * increments it after every step of the main loop.
     */
    private static final class RefacStepCountable implements StepCountable {
        private int stepCounter;

        private RefacStepCountable() {
            this.stepCounter = 0;
        }

        /** Advances the counter by one step. */
        public void increment() {
            this.stepCounter++;
        }

        /**
         * Returns the live step count. This must not be a constant: {@code play} uses it to decide
         * when to query the policy versus repeat the previous action, and the same object is given
         * to the MDP wrapper.
         *
         * @return the number of steps counted so far
         */
        @Override
        public int getStepCounter() {
            return this.stepCounter;
        }

        /**
         * Overwrites the counter.
         *
         * @param stepCounter the new step count
         */
        public void setStepCounter(int stepCounter) {
            this.stepCounter = stepCounter;
        }
    }

    /**
     * Returns the network backing this policy; it is reset at the start of every episode played.
     *
     * @return the policy's neural network
     */
    public abstract NeuralNet getNeuralNet();

    /**
     * Chooses the action to take for the given observation.
     *
     * @param observation the current (possibly history-stacked) observation
     * @return the selected action
     */
    @Override
    public abstract A nextAction(Observation observation);

    /**
     * Plays one episode on {@code mdp} without a history processor (no frame skipping).
     *
     * @param mdp the environment to play
     * @param <AS> the MDP's action-space type
     * @return the accumulated reward of the episode
     */
    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) {
        return play(mdp, (IHistoryProcessor) null);
    }

    /**
     * Plays one episode on {@code mdp} using a new {@link HistoryProcessor} built from
     * {@code configuration}.
     *
     * @param mdp the environment to play
     * @param configuration history-processor configuration (history length, skip frame, ...)
     * @param <AS> the MDP's action-space type
     * @return the accumulated reward of the episode
     */
    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor.Configuration configuration) {
        return play(mdp, new HistoryProcessor(configuration));
    }

    /**
     * Plays one episode on {@code mdp} until the wrapper reports it is done.
     *
     * <p>When {@code historyProcessor} is non-null, the policy is only queried on every
     * {@code skipFrame}-th step; on the steps in between the previous action is repeated.
     *
     * @param mdp the environment to play
     * @param historyProcessor optional history processor; {@code null} disables frame skipping
     * @param <AS> the MDP's action-space type
     * @return the accumulated reward, including rewards collected during initialization
     */
    @Override
    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor historyProcessor) {
        RefacStepCountable stepCountable = new RefacStepCountable();
        LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<>(mdp, historyProcessor, stepCountable);

        int skipFrame = historyProcessor != null ? historyProcessor.getConf().getSkipFrame() : 1;

        Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, historyProcessor);
        Observation observation = initMdp.getLastObs();
        double reward = initMdp.getReward();

        // Action repeated on skipped frames; until the policy is first queried it is the no-op.
        A lastAction = mdpWrapper.getActionSpace().noOp();
        // Steps spent in initialization count toward the frame-skip schedule.
        stepCountable.setStepCounter(initMdp.getSteps());

        while (!mdpWrapper.isDone()) {
            A action = stepCountable.getStepCounter() % skipFrame == 0 ? nextAction(observation) : lastAction;
            lastAction = action;

            StepReply<Observation> stepReply = mdpWrapper.step(action);
            reward += stepReply.getReward();
            observation = stepReply.getObservation();
            stepCountable.increment();
        }
        return reward;
    }

    /**
     * Resets the network and the MDP, then takes no-op steps until enough frames have been seen
     * to fill the history (or the episode ends).
     *
     * @param mdpWrapper the wrapped environment
     * @param historyProcessor optional history processor; {@code null} means no warm-up steps
     * @param <AS> the MDP's action-space type
     * @return the number of warm-up steps taken, the last observation, and the reward collected
     */
    private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(
            LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor historyProcessor) {
        getNeuralNet().reset();
        // NOTE(review): m27reset looks like a decompiler-generated name for the wrapper's reset();
        // kept as-is to match the LegacyMDPWrapper declaration in this project.
        Observation observation = mdpWrapper.m27reset();

        // Without a history processor no warm-up is needed; otherwise cover (historyLength - 1)
        // skipped-frame intervals.
        int requiredFrames = 0;
        if (historyProcessor != null) {
            requiredFrames = historyProcessor.getConf().getSkipFrame()
                    * (historyProcessor.getConf().getHistoryLength() - 1);
        }

        int steps = 0;
        double reward = 0.0d;
        while (steps < requiredFrames && !mdpWrapper.isDone()) {
            StepReply<Observation> stepReply = mdpWrapper.step(mdpWrapper.getActionSpace().noOp());
            reward += stepReply.getReward();
            observation = stepReply.getObservation();
            steps++;
        }
        return new Learning.InitMdp<>(steps, observation, reward);
    }
}
