package org.deeplearning4j.rl4j.learning.async.a3c.discrete;

import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Worker thread for discrete-action A3C (asynchronous advantage actor-critic) training.
 *
 * <p>Each worker shares the target actor-critic network through {@link IAsyncGlobal}, picks
 * actions with an {@link ACPolicy}, and computes its updates with an
 * {@link AdvantageActorCriticUpdateAlgorithm}.
 *
 * @param <OBSERVATION> the observation type produced by the MDP
 */
public class A3CThreadDiscrete<OBSERVATION extends Encodable> extends AsyncThreadDiscrete<OBSERVATION, IActorCritic> {
    /** Learning configuration (seed, gamma, ...) for this worker. */
    protected final A3CLearningConfiguration configuration;
    /** Shared global state; its target network is consulted when building the update algorithm. */
    protected final IAsyncGlobal<IActorCritic> asyncGlobal;
    /** Index of this worker; it is also added to the configured seed. */
    protected final int threadNumber;
    /**
     * Random generator owned exclusively by this worker and handed to its policy.
     *
     * <p>It must not be the calling thread's {@code Nd4j.getRandom()} instance: the constructor is
     * not guaranteed to run on the worker thread, so that instance would be shared by every worker
     * constructed on the same thread (each one re-seeding it, and all of them using it concurrently).
     */
    private final Random rnd;

    /**
     * Creates a worker and installs its advantage actor-critic update algorithm.
     *
     * @param mdp the environment this worker interacts with
     * @param iAsyncGlobal shared global state holding the target network
     * @param a3CLearningConfiguration learning configuration; its optional seed is offset by
     *     {@code threadNumber}
     * @param deviceNum forwarded unchanged to {@link AsyncThreadDiscrete} (the device index in
     *     RL4J's thread API)
     * @param trainingListenerList listeners notified during training
     * @param threadNumber index of this worker
     */
    public A3CThreadDiscrete(MDP<OBSERVATION, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> iAsyncGlobal,
            A3CLearningConfiguration a3CLearningConfiguration, int deviceNum,
            TrainingListenerList trainingListenerList, int threadNumber) {
        super(iAsyncGlobal, mdp, trainingListenerList, threadNumber, deviceNum);
        this.configuration = a3CLearningConfiguration;
        this.asyncGlobal = iAsyncGlobal;
        this.threadNumber = threadNumber;
        this.rnd = createWorkerRandom(this.configuration.getSeed(), threadNumber);
        // buildUpdateAlgorithm() reads configuration and asyncGlobal, so it must run after they are set.
        setUpdateAlgorithm(buildUpdateAlgorithm());
    }

    /**
     * Creates a random generator that belongs to a single worker.
     *
     * @param seed optional base seed from the configuration; {@code null} means no seed was configured
     * @param threadNumber worker index added to the seed so workers draw distinct streams
     * @return a new, unshared {@link Random}
     */
    private static Random createWorkerRandom(Long seed, int threadNumber) {
        // Without a configured seed, draw one from the caller's generator instead of relying on the
        // factory's default seeding, so workers created back-to-back still get distinct streams.
        long effectiveSeed = seed != null ? seed.longValue() + threadNumber : Nd4j.getRandom().nextLong();
        return Nd4j.getRandomFactory().getNewRandomInstance(effectiveSeed);
    }

    /**
     * Builds the policy this worker uses to choose actions.
     *
     * @param iActorCritic the actor-critic network backing the policy
     * @return an {@link ACPolicy} over {@code iActorCritic} that uses this worker's private generator
     */
    @Override
    public Policy<Integer> getPolicy(IActorCritic iActorCritic) {
        return new ACPolicy(iActorCritic, this.rnd);
    }

    /**
     * Builds the advantage actor-critic update algorithm for this worker.
     *
     * <p>The observation shape comes from the history processor's configuration when one is set,
     * otherwise from the MDP's observation space.
     *
     * @return the update algorithm applied to this worker's experience
     */
    @Override
    protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
        boolean recurrent = this.asyncGlobal.getTarget().isRecurrent();
        int[] observationShape = getHistoryProcessor() == null
                ? getMdp().getObservationSpace().getShape()
                : getHistoryProcessor().getConf().getShape();
        int actionCount = getMdp().getActionSpace().getSize();
        return new AdvantageActorCriticUpdateAlgorithm(recurrent, observationShape, actionCount,
                this.configuration.getGamma());
    }

    /** @return the learning configuration supplied at construction */
    @Override
    public A3CLearningConfiguration getConfiguration() {
        return this.configuration;
    }

    /** @return the shared global state supplied at construction */
    @Override
    public IAsyncGlobal<IActorCritic> getAsyncGlobal() {
        return this.asyncGlobal;
    }

    /** @return the index of this worker */
    @Override
    public int getThreadNumber() {
        return this.threadNumber;
    }
}
