package org.deeplearning4j.rl4j.learning.async;

import java.util.Objects;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/async/AsyncGlobal.class */
/**
 * Global state shared by asynchronous learning workers.
 *
 * <p>Holds the network that worker gradients are applied to ({@code current}), a copy of it
 * handed out via {@link #getTarget()} ({@code target}), and the step / update counters used to
 * decide when training is complete.
 *
 * <p>All mutation happens while holding {@link #getUpdateLock()}. The counters are {@code
 * volatile} so that {@link #isTrainingComplete()} and the counter getters, which read without the
 * lock, observe values written by other threads.
 *
 * @param <NN> the neural network type being trained
 */
public class AsyncGlobal<NN extends NeuralNet> implements IAsyncGlobal<NN> {
    private static final Logger log = LoggerFactory.getLogger(AsyncGlobal.class);

    /** Network that receives every applied gradient. */
    private final NN current;
    /** Copy of {@link #current}, refreshed according to the learner update frequency. */
    private final NN target;
    private final IAsyncLearningConfiguration configuration;
    private final Lock updateLock = new ReentrantLock();
    // Written only while holding updateLock; volatile so lock-free readers see fresh values.
    private volatile int workerUpdateCount;
    private volatile int stepCount;

    /**
     * Creates the shared state.
     *
     * @param nn network to train; used directly as the current network and cloned to create the
     *     initial target network
     * @param iAsyncLearningConfiguration source of {@code getMaxStep()} and
     *     {@code getLearnerUpdateFrequency()}
     * @throws NullPointerException if either argument is {@code null}
     */
    @SuppressWarnings("unchecked") // NOTE(review): assumes cloning an NN yields an NN — confirm
    public AsyncGlobal(NN nn, IAsyncLearningConfiguration iAsyncLearningConfiguration) {
        this.current = Objects.requireNonNull(nn, "nn");
        this.target = (NN) nn.m25clone();
        this.configuration =
                Objects.requireNonNull(iAsyncLearningConfiguration, "iAsyncLearningConfiguration");
    }

    /**
     * Reports whether the accumulated step count has reached the configured maximum.
     *
     * @return {@code true} once {@code stepCount >= configuration.getMaxStep()}
     */
    @Override
    public boolean isTrainingComplete() {
        return stepCount >= configuration.getMaxStep();
    }

    /**
     * Applies a worker's gradients to the current network and advances the counters.
     *
     * <p>If the learner update frequency is positive, the target network is refreshed from the
     * current network only on every {@code frequency}-th worker update. If it is non-positive
     * (e.g. {@code -1}), the target is refreshed after every update. Does nothing once training is
     * complete.
     *
     * @param gradientArr gradients computed by a worker
     * @param i number of steps those gradients account for; added to the step count
     */
    @Override
    public void applyGradient(Gradient[] gradientArr, int i) {
        if (isTrainingComplete()) {
            return;
        }
        // Acquire before try: if lock() itself fails, finally must not unlock an unheld lock.
        updateLock.lock();
        try {
            // Re-check under the lock: another worker may have hit maxStep while we were waiting.
            if (isTrainingComplete()) {
                return;
            }
            current.applyGradient(gradientArr, i);
            stepCount += i;
            workerUpdateCount++;

            int learnerUpdateFrequency = configuration.getLearnerUpdateFrequency();
            if (learnerUpdateFrequency <= 0) {
                // No throttling (also avoids % by zero): keep target in lockstep with current.
                target.copy(current);
            } else if (workerUpdateCount % learnerUpdateFrequency == 0) {
                // Throttled: sync the target only on every learnerUpdateFrequency-th update.
                log.info("Updating target network at updates={} steps={}", workerUpdateCount, stepCount);
                target.copy(current);
            }
        } finally {
            updateLock.unlock();
        }
    }

    /**
     * Returns the target network.
     *
     * <p>The returned object is shared, not a snapshot: later calls to
     * {@link #applyGradient(Gradient[], int)} copy into it. NOTE(review): callers needing a
     * consistent read presumably should hold {@link #getUpdateLock()} while using it — confirm.
     *
     * @return the shared target network
     */
    @Override
    public NN getTarget() {
        updateLock.lock();
        try {
            return target;
        } finally {
            updateLock.unlock();
        }
    }

    /**
     * Returns the lock guarding every update to the networks and counters.
     *
     * @return the update lock
     */
    public Lock getUpdateLock() {
        return updateLock;
    }

    /**
     * Returns how many worker gradient updates have been applied.
     *
     * @return the worker update count
     */
    @Override
    public int getWorkerUpdateCount() {
        return workerUpdateCount;
    }

    /**
     * Returns the total number of steps accounted for by applied gradients.
     *
     * @return the step count
     */
    @Override
    public int getStepCount() {
        return stepCount;
    }
}
