package org.deeplearning4j.rl4j.learning.async.nstep.discrete;

import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.class */
/**
 * Worker thread for asynchronous n-step Q-learning over a discrete action space.
 *
 * <p>Each worker acts with an epsilon-greedy policy wrapped around a {@link DQNPolicy} and
 * computes its updates with a {@link QLearningUpdateAlgorithm}.
 *
 * @param <OBSERVATION> observation type produced by the MDP
 */
public class AsyncNStepQLearningThreadDiscrete<OBSERVATION extends Encodable> extends AsyncThreadDiscrete<OBSERVATION, IDQN> {
    protected final AsyncQLearningConfiguration configuration;
    protected final IAsyncGlobal<IDQN> asyncGlobal;
    protected final int threadNumber;

    /** Random generator owned by this worker; used for epsilon-greedy exploration. */
    private final Random rnd;

    /**
     * Creates a worker thread.
     *
     * @param mdp           environment this worker interacts with
     * @param asyncGlobal   shared global state (network) this worker reads from and updates
     * @param configuration learning hyper-parameters
     * @param listeners     training listeners forwarded to the base class
     * @param threadNumber  index of this worker; also offsets the configured seed
     * @param deviceNum     device index forwarded to the base class
     */
    public AsyncNStepQLearningThreadDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal, AsyncQLearningConfiguration configuration, TrainingListenerList listeners, int threadNumber, int deviceNum) {
        super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
        this.configuration = configuration;
        this.asyncGlobal = asyncGlobal;
        this.threadNumber = threadNumber;
        this.rnd = createRandom(configuration.getSeed(), threadNumber);
        setUpdateAlgorithm(buildUpdateAlgorithm());
    }

    /**
     * Creates a random generator private to one worker.
     *
     * <p>The constructor runs on the thread that launches the workers, not on the worker itself.
     * {@code Nd4j.getRandom()} hands back the generator of the calling thread, so using it here
     * would make every worker share (and re-seed) the launcher's generator. A fresh instance per
     * worker keeps per-thread seeding meaningful and avoids concurrent use of one generator.
     *
     * @param seed         configured base seed, or {@code null} when none was configured
     * @param threadNumber index of the worker, added to the base seed
     * @return a new generator owned exclusively by this worker
     */
    private static Random createRandom(Long seed, int threadNumber) {
        if (seed != null) {
            return Nd4j.getRandomFactory().getNewRandomInstance(seed + threadNumber);
        }
        // No seed configured: draw a distinct seed from the launcher thread's generator.
        // This is safe because workers are constructed sequentially on that single thread.
        return Nd4j.getRandomFactory().getNewRandomInstance(Nd4j.getRandom().nextLong());
    }

    /**
     * Builds the exploration policy for the given network.
     *
     * @param idqn network whose Q-values drive greedy action selection
     * @return epsilon-greedy policy wrapping a {@link DQNPolicy} over {@code idqn}
     */
    @Override
    public Policy<Integer> getPolicy(IDQN idqn) {
        return new EpsGreedy(new DQNPolicy(idqn), getMdp(), this.configuration.getUpdateStart(), this.configuration.getEpsilonNbStep(), this.rnd, this.configuration.getMinEpsilon(), this);
    }

    /**
     * Builds the Q-learning update algorithm used by this worker.
     *
     * <p>The input shape comes from the history processor's configuration when one is present.
     * Otherwise it comes from the MDP's observation space.
     */
    @Override
    protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
        // Look up the history processor once rather than twice.
        final var historyProcessor = getHistoryProcessor();
        final int[] shape = historyProcessor == null
                ? getMdp().getObservationSpace().getShape()
                : historyProcessor.getConf().getShape();
        return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), this.configuration.getGamma());
    }

    /** @return the learning configuration this worker was created with */
    @Override
    public AsyncQLearningConfiguration getConfiguration() {
        return this.configuration;
    }

    /** @return the shared global state this worker synchronizes with */
    @Override
    public IAsyncGlobal<IDQN> getAsyncGlobal() {
        return this.asyncGlobal;
    }

    /** @return the index of this worker thread */
    @Override
    public int getThreadNumber() {
        return this.threadNumber;
    }
}
