package org.deeplearning4j.rl4j.learning.async;

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.class */
/**
 * Asynchronous learner thread for MDPs whose action space is a {@link DiscreteSpace}
 * (actions are {@code Integer}s).
 *
 * <p>Each instance owns a private network, {@code current}.
 * <ul>
 *   <li>It is cloned from the shared {@link IAsyncGlobal} target when the thread is constructed.</li>
 *   <li>It is overwritten from that target at the start of every sub-epoch.</li>
 * </ul>
 *
 * <p>Transitions are collected through an {@link ExperienceHandler}. The default is a
 * {@link StateActionExperienceHandler}; it can be replaced via {@link #setExperienceHandler}.
 * The update algorithm turns the collected experience into gradients, which are then handed
 * to the global.
 *
 * @param <OBSERVATION> observation type of the underlying MDP
 * @param <NN> network type shared with the async global
 */
public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN extends NeuralNet> extends AsyncThread<OBSERVATION, Integer, DiscreteSpace, NN> {

    /** Thread-local network; refreshed from the global target at the start of each sub-epoch. */
    private NN current;

    /**
     * Produces gradients from a training batch.
     * Built by {@link #buildUpdateAlgorithm()} or injected through {@link #setUpdateAlgorithm}.
     */
    private UpdateAlgorithm<NN> updateAlgorithm;

    /**
     * Accumulates experience between gradient updates.
     * The initializer runs right after the superclass constructor returns.
     */
    private ExperienceHandler experienceHandler = new StateActionExperienceHandler();

    /**
     * Creates the thread and gives it its own clone of the global target network.
     *
     * @param iAsyncGlobal shared global whose target is cloned; its monitor is held while cloning
     * @param mdp environment this thread interacts with
     * @param trainingListenerList listeners, forwarded to {@link AsyncThread}
     * @param i forwarded unchanged as the third {@link AsyncThread} constructor argument
     * @param i2 forwarded unchanged as the fourth {@link AsyncThread} constructor argument
     */
    public AsyncThreadDiscrete(IAsyncGlobal<NN> iAsyncGlobal, MDP<OBSERVATION, Integer, DiscreteSpace> mdp, TrainingListenerList trainingListenerList, int i, int i2) {
        super(mdp, trainingListenerList, i, i2);
        // Clone and publish while holding the global's monitor: code that synchronizes on the
        // same object is excluded for the duration of the clone.
        synchronized (iAsyncGlobal) {
            NN privateCopy = (NN) iAsyncGlobal.getTarget().m26clone();
            this.current = privateCopy;
        }
    }

    /**
     * Creates the update algorithm used by {@link #trainSubEpoch}.
     *
     * @return a new update algorithm for this thread's network type
     */
    protected abstract UpdateAlgorithm<NN> buildUpdateAlgorithm();

    /**
     * Installs the history processor, then rebuilds the update algorithm.
     * The rebuild happens after the processor is set (presumably because some algorithms
     * depend on it; confirm against subclasses).
     */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public void setHistoryProcessor(IHistoryProcessor iHistoryProcessor) {
        super.setHistoryProcessor(iHistoryProcessor);
        UpdateAlgorithm<NN> rebuilt = buildUpdateAlgorithm();
        this.updateAlgorithm = rebuilt;
    }

    /** Clears any experience left over from the previous episode. */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    protected void preEpisode() {
        experienceHandler.reset();
    }

    /**
     * Runs one sub-epoch and pushes the resulting gradient to the global.
     *
     * <p>Steps the MDP until it reports done or the experience handler holds {@code i} entries.
     * <ul>
     *   <li>On a skipped observation the previous action is repeated (initially the action
     *       space's no-op) and no experience is recorded.</li>
     *   <li>Scaled reward ({@code reward * rewardFactor}) accumulates across steps. The whole sum
     *       is attached to the next experience that is recorded.</li>
     * </ul>
     *
     * <p>NOTE(review): because of this, reward earned on skipped steps is credited to the
     * following non-skipped experience. Any scaled reward still pending when the loop exits is
     * discarded. Confirm both behaviours are intended.
     *
     * @param observation observation the sub-epoch starts from
     * @param i number of experiences that fills a training batch
     * @return a {@link AsyncThread.SubEpochReturn} built from, in order:
     *     <ol>
     *       <li>the batch size that was trained on;</li>
     *       <li>the last observation reached;</li>
     *       <li>the sum of raw (unscaled) rewards;</li>
     *       <li>{@code current.getLatestScore()};</li>
     *       <li>whether the episode is complete.</li>
     *     </ol>
     */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public AsyncThread.SubEpochReturn trainSubEpoch(Observation observation, int i) {
        NN network = this.current;
        network.copy(getAsyncGlobal().getTarget());

        IPolicy<Integer> policy = getPolicy(network);
        Integer action = getMdp().getActionSpace().noOp();
        Observation obs = observation;
        double rawRewardSum = 0.0d;   // unscaled reward over every step taken
        double pendingReward = 0.0d;  // scaled reward not yet attached to an experience

        while (!getMdp().isDone() && !isTrainingBatchFull(i)) {
            if (!obs.isSkipped()) {
                action = policy.nextAction(obs);
            }

            StepReply<Observation> reply = getLegacyMDPWrapper().step(action);
            double stepReward = reply.getReward();
            pendingReward += stepReward * getConfiguration().getRewardFactor();

            if (!obs.isSkipped()) {
                this.experienceHandler.addExperience(obs, action, pendingReward, reply.isDone());
                pendingReward = 0.0d;
                incrementSteps();
            }

            obs = (Observation) reply.getObservation();
            rawRewardSum += stepReward;
        }

        boolean episodeComplete = getMdp().isDone()
                || getConfiguration().getMaxEpochStep() == this.currentEpisodeStepCount;

        // Only record a final observation when the episode ended before the batch filled up.
        if (episodeComplete && !isTrainingBatchFull(i)) {
            this.experienceHandler.setFinalObservation(obs);
        }

        // Read the size before generateTrainingBatch() runs, in case generating the batch
        // changes the handler's state.
        int batchSize = this.experienceHandler.getTrainingBatchSize();
        getAsyncGlobal().applyGradient(
                this.updateAlgorithm.computeGradients(network, this.experienceHandler.generateTrainingBatch()),
                batchSize);

        return new AsyncThread.SubEpochReturn(batchSize, obs, rawRewardSum, network.getLatestScore(), episodeComplete);
    }

    /** Returns {@code true} when the handler's current training batch holds exactly {@code target} entries. */
    private boolean isTrainingBatchFull(int target) {
        return this.experienceHandler.getTrainingBatchSize() == target;
    }

    /** Returns this thread's private network. */
    @Override // org.deeplearning4j.rl4j.learning.async.AsyncThread
    public NN getCurrent() {
        return current;
    }

    /**
     * Replaces the update algorithm used by {@link #trainSubEpoch}.
     * Public in this build; the decompiler reports it was originally {@code protected}.
     */
    public void setUpdateAlgorithm(UpdateAlgorithm<NN> updateAlgorithm) {
        this.updateAlgorithm = updateAlgorithm;
    }

    /** Replaces the experience handler (default: {@link StateActionExperienceHandler}). */
    protected void setExperienceHandler(ExperienceHandler experienceHandler) {
        this.experienceHandler = experienceHandler;
    }

    /** Returns the experience handler currently in use. */
    public ExperienceHandler getExperienceHandler() {
        return experienceHandler;
    }
}
