package org.deeplearning4j.rl4j.network.ac;

import java.io.IOException;
import java.io.OutputStream;
import java.util.Collection;
import java.util.Iterator;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels;
import org.deeplearning4j.rl4j.agent.learning.update.Gradients;
import org.deeplearning4j.rl4j.network.CommonGradientNames;
import org.deeplearning4j.rl4j.network.NeuralNetOutput;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Actor-critic network backed by a single {@link ComputationGraph}.
 *
 * <p>The graph is expected to expose two outputs: index 0 is packaged as {@code "value"} and
 * index 1 as {@code "policy"}. Training labels are supplied to the graph in the same order.
 * When the first output layer is an {@link RnnOutputLayer}, the network is treated as recurrent.
 * Inference then goes through {@code rnnTimeStep} so hidden state is carried across calls.
 */
@Deprecated
public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph> {
    /** Label key for the critic (value) head; mapped to graph output/label index 0. */
    private static final String VALUE_LABEL = "value";
    /** Label key for the actor (policy) head; mapped to graph output/label index 1. */
    private static final String POLICY_LABEL = "policy";

    protected final ComputationGraph cg;
    protected final boolean recurrent;

    /**
     * Wraps an existing computation graph.
     *
     * @param computationGraph the graph to wrap; its first output layer decides whether the
     *     network is recurrent
     */
    public ActorCriticCompGraph(ComputationGraph computationGraph) {
        this.cg = computationGraph;
        this.recurrent = computationGraph.getOutputLayer(0) instanceof RnnOutputLayer;
    }

    @Override // org.deeplearning4j.rl4j.network.NeuralNet
    public NeuralNetwork[] getNeuralNetworks() {
        return new NeuralNetwork[] {cg};
    }

    /**
     * Restores a network previously written with {@link #save(String)}.
     *
     * @param path file path of the serialized computation graph
     * @return a new wrapper around the restored graph
     * @throws IOException if the model cannot be read
     */
    public static ActorCriticCompGraph load(String path) throws IOException {
        return new ActorCriticCompGraph(ModelSerializer.restoreComputationGraph(path));
    }

    /** Fits the graph on a single feature array against the given per-output labels. */
    @Override // org.deeplearning4j.rl4j.network.NeuralNet
    public void fit(INDArray input, INDArray[] labels) {
        cg.fit(new INDArray[] {input}, labels);
    }

    /** Clears the stored RNN state; a no-op for non-recurrent networks. */
    @Override
    public void reset() {
        if (recurrent) {
            cg.rnnClearPreviousState();
        }
    }

    /**
     * Runs inference on {@code input}, stepping the RNN state when the network is recurrent.
     *
     * @return all graph outputs (value at index 0, policy at index 1)
     */
    @Override
    public INDArray[] outputAll(INDArray input) {
        INDArray[] inputs = {input};
        return recurrent ? cg.rnnTimeStep(inputs) : cg.output(inputs);
    }

    /**
     * Creates an independent copy of this network and registers this network's listeners on it.
     *
     * @return the copy
     */
    @Override
    public ActorCriticCompGraph clone() {
        ActorCriticCompGraph copy = new ActorCriticCompGraph(cg.clone());
        copy.cg.setListeners(cg.getListeners());
        return copy;
    }

    /**
     * Alias for {@link #clone()}. It keeps the decompiler-generated name available so existing
     * callers of it still compile.
     *
     * @return the same result as {@link #clone()}
     */
    @Deprecated
    public ActorCriticCompGraph m30clone() {
        return clone();
    }

    /** Fits the graph on the features using the value and policy labels. */
    @Override // org.deeplearning4j.rl4j.network.ITrainableNeuralNet
    public void fit(FeaturesLabels featuresLabels) {
        cg.fit(new INDArray[] {featuresLabels.getFeatures()}, valueAndPolicyLabels(featuresLabels));
    }

    /**
     * Computes the combined actor-critic gradient without updating the parameters.
     *
     * @return the gradient under {@code CommonGradientNames.ActorCritic.Combined}, tagged with the
     *     batch size of {@code featuresLabels}
     */
    @Override // org.deeplearning4j.rl4j.network.ITrainableNeuralNet
    public Gradients computeGradients(FeaturesLabels featuresLabels) {
        Gradient combined =
                computeGradient(featuresLabels.getFeatures(), valueAndPolicyLabels(featuresLabels));
        Gradients result = new Gradients(featuresLabels.getBatchSize());
        result.putGradient(CommonGradientNames.ActorCritic.Combined, combined);
        return result;
    }

    /**
     * Applies the combined gradient from {@code gradients} through the updater.
     * The batch size recorded in {@code gradients} is passed to the updater.
     */
    @Override // org.deeplearning4j.rl4j.network.ITrainableNeuralNet
    public void applyGradients(Gradients gradients) {
        Gradient combined = gradients.getGradient(CommonGradientNames.ActorCritic.Combined);
        applyUpdate(combined, (int) gradients.getBatchSize());
    }

    /** Overwrites this network's parameters with those of {@code from}. */
    @Override // org.deeplearning4j.rl4j.network.ITrainableNeuralNet
    public void copyFrom(ActorCriticCompGraph from) {
        cg.setParams(from.cg.params());
    }

    /**
     * Computes the gradient for the given input and labels.
     *
     * @return a single-element array holding the combined gradient
     */
    @Override // org.deeplearning4j.rl4j.network.NeuralNet
    @Deprecated
    public Gradient[] gradient(INDArray input, INDArray[] labels) {
        return new Gradient[] {computeGradient(input, labels)};
    }

    /**
     * Applies {@code gradients[0]} through the updater.
     *
     * @param gradients gradients as returned by {@link #gradient(INDArray, INDArray[])}
     * @param batchSize batch size passed to the updater; replaced by 1 when the network is recurrent
     */
    @Override // org.deeplearning4j.rl4j.network.NeuralNet
    public void applyGradient(Gradient[] gradients, int batchSize) {
        applyUpdate(gradients[0], recurrent ? 1 : batchSize);
    }

    /** Returns the graph's most recent score. */
    @Override // org.deeplearning4j.rl4j.network.NeuralNet
    public double getLatestScore() {
        return cg.score();
    }

    /** Serializes the graph to {@code outputStream}; the {@code true} flag also saves the updater. */
    @Override // org.deeplearning4j.rl4j.network.NeuralNet
    public void save(OutputStream outputStream) throws IOException {
        ModelSerializer.writeModel(cg, outputStream, true);
    }

    /** Serializes the graph to {@code path}; the {@code true} flag also saves the updater. */
    @Override // org.deeplearning4j.rl4j.network.NeuralNet
    public void save(String path) throws IOException {
        ModelSerializer.writeModel(cg, path, true);
    }

    /** Unsupported: a single graph holds both heads, so use {@link #save(OutputStream)}. */
    @Override // org.deeplearning4j.rl4j.network.ac.IActorCritic
    public void save(OutputStream valueStream, OutputStream policyStream) throws IOException {
        throw new UnsupportedOperationException("Call save(stream)");
    }

    /** Unsupported: a single graph holds both heads, so use {@link #save(String)}. */
    @Override // org.deeplearning4j.rl4j.network.ac.IActorCritic
    public void save(String valuePath, String policyPath) throws IOException {
        throw new UnsupportedOperationException("Call save(path)");
    }

    /**
     * Runs inference on the observation's data. A recurrent network steps its RNN state;
     * a non-recurrent one delegates to {@link #output(INDArray)}.
     */
    @Override // org.deeplearning4j.rl4j.network.IOutputNeuralNet
    public NeuralNetOutput output(Observation observation) {
        if (!isRecurrent()) {
            return output(observation.getData());
        }
        INDArray[] stepped = cg.rnnTimeStep(new INDArray[] {observation.getData()});
        return packageResult(stepped[0], stepped[1]);
    }

    /** Runs stateless inference and packages the value and policy outputs. */
    @Override // org.deeplearning4j.rl4j.network.IOutputNeuralNet
    public NeuralNetOutput output(INDArray input) {
        INDArray[] outputs = cg.output(new INDArray[] {input});
        return packageResult(outputs[0], outputs[1]);
    }

    @Override
    public boolean isRecurrent() {
        return recurrent;
    }

    /** Wraps the two graph outputs under the {@code "value"} and {@code "policy"} keys. */
    private NeuralNetOutput packageResult(INDArray value, INDArray policy) {
        NeuralNetOutput result = new NeuralNetOutput();
        result.put(VALUE_LABEL, value);
        result.put(POLICY_LABEL, policy);
        return result;
    }

    /** Returns the labels in graph-output order: value first, then policy. */
    private static INDArray[] valueAndPolicyLabels(FeaturesLabels featuresLabels) {
        return new INDArray[] {
            featuresLabels.getLabels(VALUE_LABEL), featuresLabels.getLabels(POLICY_LABEL)
        };
    }

    /** Runs a forward/backward pass, notifies listeners, and returns the resulting gradient. */
    private Gradient computeGradient(INDArray input, INDArray[] labels) {
        cg.setInput(0, input);
        cg.setLabels(labels);
        cg.computeGradientAndScore();
        notifyGradientCalculation();
        return cg.gradient();
    }

    /**
     * Runs {@code gradient} through the updater, subtracts it from the parameters, notifies
     * listeners, and then advances the configuration's iteration count.
     */
    private void applyUpdate(Gradient gradient, int batchSize) {
        ComputationGraphConfiguration configuration = cg.getConfiguration();
        int iterationCount = configuration.getIterationCount();
        int epochCount = configuration.getEpochCount();
        cg.getUpdater()
                .update(gradient, iterationCount, epochCount, batchSize, LayerWorkspaceMgr.noWorkspaces());
        // The updater transforms the raw gradient in place; the result is the step to subtract.
        cg.params().subi(gradient.gradient());
        notifyIterationDone(iterationCount, epochCount);
        configuration.setIterationCount(iterationCount + 1);
    }

    /** Invokes {@code onGradientCalculation} on every registered listener, if any. */
    private void notifyGradientCalculation() {
        Collection<TrainingListener> listeners = cg.getListeners();
        if (listeners == null) {
            return;
        }
        for (TrainingListener listener : listeners) {
            listener.onGradientCalculation(cg);
        }
    }

    /** Invokes {@code iterationDone} on every registered listener, if any. */
    private void notifyIterationDone(int iterationCount, int epochCount) {
        Collection<TrainingListener> listeners = cg.getListeners();
        if (listeners == null) {
            return;
        }
        for (TrainingListener listener : listeners) {
            listener.iterationDone(cg, iterationCount, epochCount);
        }
    }
}
