package org.deeplearning4j.rl4j.learning.async;

import java.util.Stack;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.class */
/**
 * Asynchronous worker thread for environments with a discrete action space.
 *
 * <p>Each thread keeps its own copy of the network ({@link #getCurrent()}). The copy is cloned from the
 * global network at construction time. It is refreshed from the global network at the start of every
 * sub-epoch. Subclasses decide how the collected transitions become gradients via
 * {@link #calcGradient(NeuralNet, Stack)}.
 */
public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet> extends AsyncThread<O, Integer, DiscreteSpace, NN> {

    /** Thread-local copy of the network; parameters are re-synced from the global network each sub-epoch. */
    private final NN current;

    /**
     * Creates the thread and clones the current global network into a thread-local copy.
     *
     * @param asyncGlobal  shared global state; locked while its network is cloned
     * @param threadNumber forwarded unchanged to the {@code AsyncThread} constructor
     *                     (presumably the worker index — confirm against AsyncThread)
     * @param deviceNum    forwarded unchanged to the {@code AsyncThread} constructor
     *                     (presumably a device id — confirm against AsyncThread)
     */
    @SuppressWarnings("unchecked") // clone() is declared to return the raw network type
    public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, int threadNumber, int deviceNum) {
        super(asyncGlobal, threadNumber, deviceNum);
        synchronized (asyncGlobal) {
            this.current = (NN) asyncGlobal.getCurrent().m18clone();
        }
    }

    /**
     * Runs up to {@code nstep} decision steps in the environment and queues the resulting gradient.
     *
     * <p>With a history processor, each chosen action is repeated for {@code skipFrame} environment
     * steps. A transition is stored only on decision frames, or when the episode ends. Its reward is
     * the scaled reward accumulated since the previous stored transition.
     *
     * <p>After the loop, a final entry is pushed onto the stack. If the episode is over it is a
     * terminal entry with value 0. Otherwise it is a bootstrap entry valued at the maximum of the
     * first network output on the last state. That output comes from the local network when
     * {@code targetDqnUpdateFreq == -1}, and from the global target network otherwise.
     *
     * @param startObs observation to start from
     * @param nstep    number of decision steps; at most {@code nstep * skipFrame} environment steps run
     * @return the number of environment steps taken, the last observation, the raw (unscaled)
     *         cumulative reward, and the local network's latest score
     */
    @Override
    public AsyncThread.SubEpochReturn<O> trainSubEpoch(O startObs, int nstep) {
        // Pull the latest global parameters into the thread-local network.
        synchronized (getAsyncGlobal()) {
            this.current.copy(getAsyncGlobal().getCurrent());
        }

        Stack<MiniTrans<Integer>> transitions = new Stack<>();
        Policy<O, Integer> policy = getPolicy(this.current);
        IHistoryProcessor historyProcessor = getHistoryProcessor();
        int skipFrame = historyProcessor != null ? historyProcessor.getConf().getSkipFrame() : 1;
        int maxSteps = nstep * skipFrame;

        O obs = startObs;
        Integer lastAction = null;
        double totalReward = 0.0d;   // raw reward, reported back to the caller
        double pendingReward = 0.0d; // scaled reward not yet attached to a stored transition
        int steps = 0;

        while (!getMdp().isDone() && steps < maxSteps) {
            INDArray input = Learning.getInput(getMdp(), obs);
            INDArray state = null;
            if (historyProcessor != null) {
                historyProcessor.record(input);
            }

            // Decide on skip boundaries (or on the very first step); otherwise repeat the last action.
            boolean decisionFrame = steps % skipFrame == 0 || lastAction == null;
            Integer action;
            if (decisionFrame) {
                state = processHistory(input);
                action = policy.nextAction(state);
            } else {
                action = lastAction;
            }

            StepReply<O> reply = getMdp().step(action);
            pendingReward += reply.getReward() * getConf().getRewardFactor();

            // Store a training transition on decision frames, and always when the episode ends.
            if (decisionFrame || reply.isDone()) {
                if (state == null) {
                    state = processHistory(input);
                }
                transitions.add(new MiniTrans<>(state, action, this.current.outputAll(state), pendingReward));
                pendingReward = 0.0d;
            }

            // Advance on every environment step, skipped frames included. Otherwise recording, the
            // bootstrap state and the returned observation would lag behind the environment.
            obs = reply.getObservation();
            totalReward += reply.getReward();
            steps++;
            lastAction = action;
        }

        // Push the value of the state reached after the last step. calcGradient presumably folds the
        // stack back from this entry — confirm in subclasses.
        INDArray lastInput = Learning.getInput(getMdp(), obs);
        INDArray lastState = processHistory(lastInput);
        if (historyProcessor != null) {
            historyProcessor.record(lastInput);
        }

        if (getMdp().isDone()) {
            // Terminal state: there is no future return to bootstrap from, even if the episode
            // ended exactly on the last allowed step.
            transitions.add(new MiniTrans<>(lastState, null, null, 0.0d));
        } else {
            INDArray[] output;
            if (getConf().getTargetDqnUpdateFreq() == -1) {
                output = this.current.outputAll(lastState);
            } else {
                synchronized (getAsyncGlobal()) {
                    output = getAsyncGlobal().getTarget().outputAll(lastState);
                }
            }
            double bootstrapValue = Nd4j.max(output[0]).getDouble(0L);
            transitions.add(new MiniTrans<>(lastState, null, output, bootstrapValue));
        }

        getAsyncGlobal().enqueue(calcGradient(this.current, transitions), steps);
        return new AsyncThread.SubEpochReturn<>(steps, obs, totalReward, this.current.getLatestScore());
    }

    /**
     * Builds the network input for one observation.
     *
     * <p>With a history processor, the observation is added to it, and its stored frames are
     * concatenated and scaled by {@code 1 / scale}. Without one, the observation alone is passed to
     * {@code Transition.concat}. The result is reshaped through {@code Learning.makeShape}: with a
     * trailing {@code 1} when the local network is recurrent, or without it when the array has more
     * than two dimensions. Presumably this adds a leading batch dimension of size 1 — confirm
     * against {@code Learning.makeShape}.
     *
     * @param input the encoded observation
     * @return the (possibly reshaped) network input
     */
    protected INDArray processHistory(INDArray input) {
        IHistoryProcessor historyProcessor = getHistoryProcessor();
        INDArray[] frames;
        if (historyProcessor != null) {
            historyProcessor.add(input);
            frames = historyProcessor.getHistory();
        } else {
            frames = new INDArray[] {input};
        }

        INDArray stacked = Transition.concat(frames);
        if (historyProcessor != null) {
            stacked.muli(1.0d / historyProcessor.getScale());
        }

        if (getCurrent().isRecurrent()) {
            stacked = stacked.reshape(Learning.makeShape(1, ArrayUtil.toInts(stacked.shape()), 1));
        } else if (stacked.shape().length > 2) {
            stacked = stacked.reshape(Learning.makeShape(1, ArrayUtil.toInts(stacked.shape())));
        }
        return stacked;
    }

    /**
     * Converts the transitions collected by {@link #trainSubEpoch} into gradients for the global network.
     *
     * @param nn    the thread-local network used to collect the transitions
     * @param stack transitions in time order; the top entry is the terminal/bootstrap entry
     * @return the gradients to enqueue on the global state
     */
    public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> stack);

    /** Returns this thread's local copy of the network. */
    @Override
    public NN getCurrent() {
        return this.current;
    }
}
