package org.deeplearning4j.rl4j.learning.async;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/AsyncGlobal.class */
/**
 * Shared training state for asynchronous learning.
 *
 * <p>Holds the {@code current} network, a {@code target} copy of it, and a global step counter.
 * When run as a thread, it drains gradient updates submitted through
 * {@link #enqueue(Gradient[], Integer)}, applies them to {@code current}, advances the step counter
 * by each update's batch size, and periodically copies {@code current} into {@code target}.
 *
 * <p>Mutations of {@code current} and {@code target} are performed while holding this object's
 * monitor. NOTE(review): readers that need a consistent view of the networks presumably synchronize
 * on this instance as well — confirm with callers.
 *
 * @param <NN> the neural network type being trained
 */
public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
    private static final Logger log = LoggerFactory.getLogger(AsyncGlobal.class);

    /** How long the update loop waits for a gradient before re-checking its stop conditions. */
    private static final long POLL_TIMEOUT_MS = 10L;

    private final NN current;
    private final AsyncConfiguration a3cc;
    private final NN target;
    /** Global step counter, advanced by the batch size of every applied update. */
    private final AtomicInteger stepCounter = new AtomicInteger(0);
    /**
     * Cleared by {@link #terminate()} (typically from another thread) and polled by {@link #run()};
     * volatile so the update loop is guaranteed to observe the change.
     */
    private volatile boolean running = true;
    /** Pending (gradients, batch size) updates; blocking so the update loop waits instead of spinning. */
    private final BlockingQueue<Pair<Gradient[], Integer>> queue = new LinkedBlockingQueue<>();

    /**
     * Creates the shared state around an initial network.
     *
     * @param nn the network to train; it is used directly as {@code current}, and {@code target}
     *     starts out as a clone of it
     * @param asyncConfiguration supplies the maximum step count and the target update frequency
     */
    @SuppressWarnings("unchecked") // a clone of an NN is expected to be an NN
    public AsyncGlobal(NN nn, AsyncConfiguration asyncConfiguration) {
        this.current = nn;
        // NOTE(review): m22clone() looks like a decompiler-renamed clone(); confirm against NeuralNet.
        this.target = (NN) nn.m22clone();
        this.a3cc = asyncConfiguration;
    }

    /**
     * Reports whether the global step counter has reached the configured maximum.
     *
     * @return {@code true} once the step counter is at least {@code getMaxStep()}
     */
    @Override
    public boolean isTrainingComplete() {
        return stepCounter.get() >= a3cc.getMaxStep();
    }

    /**
     * Submits a gradient update to be applied by the update loop.
     *
     * <p>The update is silently dropped if this instance has been terminated or training is
     * already complete.
     *
     * @param gradientArr the gradients to apply to {@code current}
     * @param num the number of steps this update accounts for; added to the step counter
     */
    @Override
    public void enqueue(Gradient[] gradientArr, Integer num) {
        if (!running || isTrainingComplete()) {
            return;
        }
        queue.add(new Pair<>(gradientArr, num));
    }

    /**
     * Update loop: applies queued gradients until training completes or {@link #terminate()} is
     * called.
     *
     * <p>Waits up to {@value #POLL_TIMEOUT_MS} ms for each update, so the stop conditions are
     * re-checked regularly without busy-spinning. If the thread is interrupted, the interrupt flag
     * is restored and the loop exits.
     */
    @Override
    public void run() {
        while (!isTrainingComplete() && running) {
            final Pair<Gradient[], Integer> update;
            try {
                update = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            if (update == null) {
                continue; // timed out; re-check the stop conditions
            }
            applyUpdate(update.getFirst(), update.getSecond());
        }
    }

    /**
     * Advances the step counter, applies one gradient update to {@code current}, and refreshes
     * {@code target} when the counter crosses a multiple of the target update frequency.
     *
     * <p>A non-positive target update frequency (e.g. {@code -1}) disables target refreshes.
     */
    private void applyUpdate(Gradient[] gradients, int batchSize) {
        final int newStep = stepCounter.addAndGet(batchSize);
        synchronized (this) {
            current.applyGradient(gradients, batchSize);
        }
        final long freq = a3cc.getTargetDqnUpdateFreq();
        // Refresh when this update moved the counter past a multiple of freq.
        if (freq > 0 && newStep / freq > (newStep - batchSize) / freq) {
            log.info("TARGET UPDATE at T = {}", newStep);
            synchronized (this) {
                target.copy(current);
            }
        }
    }

    /**
     * Stops the update loop and discards any updates still waiting in the queue.
     *
     * <p>Subsequent {@link #enqueue(Gradient[], Integer)} calls are ignored.
     */
    @Override
    public void terminate() {
        running = false;
        queue.clear();
    }

    /** @return the network that gradients are applied to */
    @Override
    public NN getCurrent() {
        return current;
    }

    /** @return the live global step counter (not a snapshot) */
    @Override
    public AtomicInteger getT() {
        return stepCounter;
    }

    /** @return the target network, periodically overwritten with a copy of {@code current} */
    @Override
    public NN getTarget() {
        return target;
    }

    /** @return {@code false} once {@link #terminate()} has been called */
    @Override
    public boolean isRunning() {
        return running;
    }
}
