package org.deeplearning4j.rl4j.network.dqn;

import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import java.util.Iterator;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Deep Q-Network backed by a single DL4J {@link MultiLayerNetwork}.
 *
 * <p>Every operation delegates to the wrapped network. The network is treated as
 * non-recurrent ({@link #isRecurrent()} returns {@code false} and {@link #reset()} is a
 * no-op). It exposes a single output, so the array-based overloads of {@code fit} and
 * {@code gradient} only use the first element of the label array.
 *
 * @param <NN> concrete self type returned by {@link #clone()} and accepted by {@link #copy}
 */
public class DQN<NN extends DQN> implements IDQN<NN> {

    /** The wrapped network; prediction, training and persistence all go through it. */
    protected final MultiLayerNetwork mln;

    // Not read anywhere in this class. Kept because it is package-visible and may be
    // referenced elsewhere in the package. NOTE(review): remove if nothing uses it.
    int i = 0;

    /**
     * Wraps an existing network.
     *
     * @param multiLayerNetwork the network to delegate to; stored as-is, not copied
     */
    public DQN(MultiLayerNetwork multiLayerNetwork) {
        this.mln = multiLayerNetwork;
    }

    /** Returns the wrapped network as a single-element array. */
    @Override
    public NeuralNetwork[] getNeuralNetworks() {
        return new NeuralNetwork[] {this.mln};
    }

    /**
     * Restores a DQN from a serialized {@link MultiLayerNetwork} file, such as one written
     * by {@link #save(String)}.
     *
     * @param path location of the serialized model
     * @return a new DQN wrapping the restored network
     * @throws IOException if the model cannot be read
     */
    public static DQN load(String path) throws IOException {
        return new DQN(ModelSerializer.restoreMultiLayerNetwork(path));
    }

    /** Always {@code false}: this network keeps no recurrent state. */
    @Override
    public boolean isRecurrent() {
        return false;
    }

    /** No-op, since the network is not recurrent and has no state to clear. */
    @Override
    public void reset() {
    }

    /**
     * Runs one fit step of the wrapped network.
     *
     * @param input  network input
     * @param labels training targets
     */
    @Override
    public void fit(INDArray input, INDArray labels) {
        this.mln.fit(input, labels);
    }

    /**
     * Fits against the first (and only used) label array.
     *
     * @param input  network input
     * @param labels label arrays; only {@code labels[0]} is used
     * @throws IllegalArgumentException if {@code labels} is null or empty
     */
    @Override
    public void fit(INDArray input, INDArray[] labels) {
        fit(input, firstElement(labels, "labels"));
    }

    /**
     * Runs inference on a batch.
     *
     * @param input network input
     * @return the network's output for {@code input}
     */
    @Override
    public INDArray output(INDArray input) {
        return this.mln.output(input);
    }

    /**
     * Runs inference on the data carried by an observation.
     *
     * @param observation observation whose {@code getData()} is fed to the network
     * @return the network's output
     */
    @Override
    public INDArray output(Observation observation) {
        return output(observation.getData());
    }

    /**
     * Single-output variant of {@link #output(INDArray)}, wrapped in an array.
     *
     * @param input network input
     * @return a one-element array holding the network output
     */
    @Override
    public INDArray[] outputAll(INDArray input) {
        return new INDArray[] {output(input)};
    }

    /**
     * Creates an independent copy that wraps a clone of the underlying network.
     *
     * <p>The clone is given the same training listeners as this instance.
     *
     * @return the cloned DQN
     */
    @Override
    @SuppressWarnings("unchecked") // DQN is the erasure of NN; callers parameterize NN with DQN types
    public NN clone() {
        NN cloned = (NN) new DQN(this.mln.clone());
        cloned.mln.setListeners(this.mln.getListeners());
        return cloned;
    }

    /**
     * Legacy alias of {@link #clone()}.
     *
     * @return the cloned DQN
     * @deprecated kept only for callers compiled against this name; use {@link #clone()}.
     */
    @Deprecated
    public NN m27clone() {
        return clone();
    }

    /**
     * Overwrites this network's parameters with those of {@code from}.
     *
     * <p>Only the parameter vector is copied, via {@code setParams}.
     *
     * @param from source network
     */
    @Override
    public void copy(NN from) {
        this.mln.setParams(from.mln.params());
    }

    /**
     * Computes the gradient (and score) for one batch without updating the parameters.
     *
     * <p>Registered listeners are notified via {@code onGradientCalculation}.
     *
     * @param input  network input
     * @param labels training targets
     * @return a one-element array holding the computed gradient
     */
    @Override
    public Gradient[] gradient(INDArray input, INDArray labels) {
        this.mln.setInput(input);
        this.mln.setLabels(labels);
        this.mln.computeGradientAndScore();
        Collection<TrainingListener> listeners = this.mln.getListeners();
        if (listeners != null) {
            for (TrainingListener listener : listeners) {
                listener.onGradientCalculation(this.mln);
            }
        }
        return new Gradient[] {this.mln.gradient()};
    }

    /**
     * Computes the gradient against the first (and only used) label array.
     *
     * @param input  network input
     * @param labels label arrays; only {@code labels[0]} is used
     * @return a one-element array holding the computed gradient
     * @throws IllegalArgumentException if {@code labels} is null or empty
     */
    @Override
    public Gradient[] gradient(INDArray input, INDArray[] labels) {
        return gradient(input, firstElement(labels, "labels"));
    }

    /**
     * Applies an externally computed gradient to this network's parameters.
     *
     * <p>The network's updater is run on the gradient before it is subtracted in place from
     * the parameters. Listeners are then notified via {@code iterationDone}, and the
     * configuration's iteration counter is incremented.
     *
     * @param gradientArr gradients; only {@code gradientArr[0]} is used
     * @param batchSize   batch size passed through to the updater
     * @throws IllegalArgumentException if {@code gradientArr} is null or empty
     */
    @Override
    public void applyGradient(Gradient[] gradientArr, int batchSize) {
        Gradient gradient = firstElement(gradientArr, "gradientArr");
        MultiLayerConfiguration conf = this.mln.getLayerWiseConfigurations();
        int iterationCount = conf.getIterationCount();
        int epochCount = conf.getEpochCount();
        // NOTE(review): relies on the updater transforming the gradient in place
        // (learning rate, momentum, ...) before the subtraction below — confirm per DL4J version.
        this.mln.getUpdater().update(this.mln, gradient, iterationCount, epochCount, batchSize,
                LayerWorkspaceMgr.noWorkspaces());
        // Assumes params() returns a view, so subi() updates the network's own parameters.
        this.mln.params().subi(gradient.gradient());
        Collection<TrainingListener> listeners = this.mln.getListeners();
        if (listeners != null) {
            for (TrainingListener listener : listeners) {
                listener.iterationDone(this.mln, iterationCount, epochCount);
            }
        }
        conf.setIterationCount(iterationCount + 1);
    }

    /**
     * Returns the wrapped network's current score, e.g. the one computed by
     * {@link #gradient(INDArray, INDArray)}.
     */
    @Override
    public double getLatestScore() {
        return this.mln.score();
    }

    /**
     * Serializes the network, including updater state, to a stream.
     *
     * @param outputStream destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void save(OutputStream outputStream) throws IOException {
        // true = saveUpdater: also persist the updater state.
        ModelSerializer.writeModel(this.mln, outputStream, true);
    }

    /**
     * Serializes the network, including updater state, to a file.
     *
     * @param path destination file path
     * @throws IOException if writing fails
     */
    @Override
    public void save(String path) throws IOException {
        // true = saveUpdater: also persist the updater state.
        ModelSerializer.writeModel(this.mln, path, true);
    }

    /**
     * Returns {@code values[0]}, failing with a descriptive message instead of an opaque
     * NullPointerException or ArrayIndexOutOfBoundsException.
     */
    private static <T> T firstElement(T[] values, String name) {
        if (values == null || values.length == 0) {
            throw new IllegalArgumentException(name + " must contain at least one element");
        }
        return values[0];
    }
}
