package org.deeplearning4j.nn.graph;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.api.FwdPassType;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.OutputAdapter;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
import org.deeplearning4j.nn.graph.util.GraphIndices;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.FrozenVertex;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.util.CrashReportingUtil;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.NetworkUtils;
import org.deeplearning4j.util.OutputLayerUtil;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.memory.abstracts.DummyWorkspace;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.workspace.ND4JWorkspaceException;
import org.nd4j.linalg.workspace.WorkspaceUtils;
import org.nd4j.util.OneTimeLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/graph/ComputationGraph.class */
public class ComputationGraph implements Serializable, Model, NeuralNetwork {
    /** Graph configuration; assigned once in the constructor. */
    protected ComputationGraphConfiguration configuration;
    /** Solver used for optimization; created in {@code init(INDArray, boolean)} if still null. Not serialized. */
    protected transient Solver solver;
    /** Single array holding all parameters; set in {@code init(INDArray, boolean)} and left null when the graph has no parameters. */
    protected INDArray flattenedParams;
    /** Single array holding all gradients; allocated by {@code initGradientsView()}. Not serialized. */
    protected transient INDArray flattenedGradients;
    protected Gradient gradient;
    protected double score;
    // Workspace identifiers passed to the LayerWorkspaceMgr builder (see pretrainLayerHelper).
    protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM";
    protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT";
    protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM";
    protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM";
    // Per-instance workspace configurations; built in the constructor because their
    // cyclesBeforeInitialization value depends on the number of configured vertices.
    protected final WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG;
    protected final WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG;
    /** All vertices (network inputs first, then configured vertices); populated by {@code init(INDArray, boolean)}. */
    protected GraphVertex[] vertices;
    /** Vertex name to vertex lookup; populated by {@code init(INDArray, boolean)}. */
    protected Map<String, GraphVertex> verticesMap;
    /** Topological sort order of vertex indices; taken from {@code calculateIndices()} during init. */
    protected int[] topologicalOrder;
    // NOTE(review): not assigned in the visible part of this file - presumably cached by calculateIndices(); confirm.
    protected GraphIndices graphIndices;
    /** Layers of all vertices that report {@code hasLayer()}; populated by {@code init(INDArray, boolean)}. */
    protected Layer[] layers;
    /** Number of network inputs, from {@code configuration.getNetworkInputs().size()}. */
    private int numInputArrays;
    /** Number of network outputs, from {@code configuration.getNetworkOutputs().size()}. */
    private int numOutputArrays;
    // Current inputs/labels and their masks. All transient, so they are not serialized with the graph.
    private transient INDArray[] inputs;
    private transient INDArray[] labels;
    private transient INDArray[] inputMaskArrays;
    private transient INDArray[] labelMaskArrays;
    /** Vertex indices of the network outputs; computed on first call to {@code getOutputLayerIndices()}. */
    private transient int[] outputLayerIdxs;
    /** Default layer configuration, from {@code configuration.getDefaultConfiguration()}. */
    private NeuralNetConfiguration defaultConfiguration;
    private static final Logger log = LoggerFactory.getLogger(ComputationGraph.class);
    // Shared (static) workspace configurations for activations and RNN loop working memory.
    protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(0.05d).policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(0.05d).policyReset(ResetPolicy.BLOCK_LEFT).policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.FIRST_LOOP).build();
    /** Set to true at the end of {@code init(INDArray, boolean)}; makes later init calls no-ops. */
    protected boolean initCalled = false;
    // NOTE(review): initDone is not read or written in the visible part of this file.
    private boolean initDone = false;
    protected boolean clearTbpttState = true;
    /** Helper workspace pointers handed to the workspace manager (see pretrainLayerHelper). Not serialized. */
    protected transient Map<String, Pointer> helperWorkspaces = new HashMap();
    // NOTE(review): usage of occupiedBy is outside the visible part of this file; initial value is -1.
    private final transient AtomicLong occupiedBy = new AtomicLong(-1);
    /** Last ETL time, stored per thread; see {@code setLastEtlTime}/{@code getLastEtlTime}. */
    protected transient ThreadLocal<Long> lastEtlTime = new ThreadLocal<>();
    private Collection<TrainingListener> trainingListeners = new ArrayList();

    public ComputationGraph(ComputationGraphConfiguration computationGraphConfiguration) {
        this.configuration = computationGraphConfiguration;
        this.numInputArrays = computationGraphConfiguration.getNetworkInputs().size();
        this.numOutputArrays = computationGraphConfiguration.getNetworkOutputs().size();
        this.inputs = new INDArray[this.numInputArrays];
        this.labels = new INDArray[this.numOutputArrays];
        this.defaultConfiguration = computationGraphConfiguration.getDefaultConfiguration();
        this.WS_LAYER_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.02d).policyLearning(LearningPolicy.OVER_TIME).cyclesBeforeInitialization(2 * computationGraphConfiguration.getVertices().size()).policyReset(ResetPolicy.BLOCK_LEFT).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
        this.WS_LAYER_ACT_X_CONFIG = WorkspaceConfiguration.builder().initialSize(0L).overallocationLimit(0.02d).policyLearning(LearningPolicy.OVER_TIME).cyclesBeforeInitialization(computationGraphConfiguration.getVertices().size()).policyReset(ResetPolicy.BLOCK_LEFT).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    }

    /**
     * Records the last ETL time for the calling thread.
     *
     * @param j the value to store in the thread-local holder
     */
    public void setLastEtlTime(long j) {
        final Long boxed = j;
        lastEtlTime.set(boxed);
    }

    /**
     * Returns the last ETL time recorded for the calling thread.
     *
     * @return the stored value, or 0 when nothing has been recorded on this thread
     */
    public long getLastEtlTime() {
        final Long recorded = lastEtlTime.get();
        return (recorded != null) ? recorded.longValue() : 0L;
    }

    /**
     * Applies a cache mode to every layer of the graph.
     *
     * @param cacheMode the mode to apply; {@code null} is treated as {@link CacheMode#NONE}
     */
    public void setCacheMode(CacheMode cacheMode) {
        final CacheMode effectiveMode = (cacheMode != null) ? cacheMode : CacheMode.NONE;
        for (Layer current : this.layers) {
            current.setCacheMode(effectiveMode);
        }
    }

    /**
     * @return the configuration this graph was constructed with
     */
    public ComputationGraphConfiguration getConfiguration() {
        return configuration;
    }

    /**
     * @return the number of layers, or 0 when the layer array has not been created yet
     */
    public int getNumLayers() {
        return (this.layers == null) ? 0 : this.layers.length;
    }

    /**
     * Returns the layer at the given position of the layer array.
     *
     * @param i index into the layer array
     * @return the layer stored at that index
     */
    public Layer getLayer(int i) {
        final Layer[] all = layers;
        return all[i];
    }

    /**
     * @return the internal layer array (not a copy)
     */
    public Layer[] getLayers() {
        return layers;
    }

    /**
     * Looks up the layer belonging to the vertex with the given name.
     *
     * @param str the vertex name
     * @return the result of {@code getLayer()} on the matching vertex
     */
    public Layer getLayer(String str) {
        final boolean known = verticesMap.containsKey(str);
        // Fails the state check when no vertex with this name is registered.
        Preconditions.checkState(known, "Layer with name %s does not exist in the network", str);
        final GraphVertex owner = verticesMap.get(str);
        return owner.getLayer();
    }

    /**
     * @return the internal vertex array (not a copy)
     */
    public GraphVertex[] getVertices() {
        return vertices;
    }

    /**
     * Looks up a vertex by name.
     *
     * @param str the vertex name
     * @return the matching vertex, or {@code null} when the name is not in the vertex map
     */
    public GraphVertex getVertex(String str) {
        final GraphVertex found = verticesMap.get(str);
        return found;
    }

    /**
     * @return the number of network inputs, as read from the configuration in the constructor
     */
    public int getNumInputArrays() {
        return numInputArrays;
    }

    /**
     * @return the number of network outputs, as read from the configuration in the constructor
     */
    public int getNumOutputArrays() {
        return numOutputArrays;
    }

    /**
     * Stores a single network input, allocating the input holder first if it is currently null.
     *
     * @param i        position of the input
     * @param iNDArray the array to store at that position
     */
    public void setInput(int i, INDArray iNDArray) {
        INDArray[] holder = this.inputs;
        if (holder == null) {
            holder = new INDArray[this.numInputArrays];
            this.inputs = holder;
        }
        holder[i] = iNDArray;
    }

    /**
     * Replaces all network inputs at once.
     *
     * @param iNDArrayArr the new inputs; {@code null} is accepted and stored as-is
     * @throws IllegalArgumentException if a non-null array's length differs from the number of network inputs
     */
    public void setInputs(INDArray... iNDArrayArr) {
        final boolean acceptable = iNDArrayArr == null || iNDArrayArr.length == this.numInputArrays;
        if (acceptable) {
            this.inputs = iNDArrayArr;
            return;
        }
        throw new IllegalArgumentException("Invalid input array: network has " + this.numInputArrays + " inputs, but array is of length " + iNDArrayArr.length);
    }

    /**
     * Returns a single stored network input.
     *
     * @param i position of the input
     * @return the stored array, or {@code null} when no input holder exists
     */
    public INDArray getInput(int i) {
        return (this.inputs != null) ? this.inputs[i] : null;
    }

    /**
     * @return the internal input holder (not a copy); may be {@code null}
     */
    public INDArray[] getInputs() {
        return inputs;
    }

    /**
     * @return the currently stored input mask arrays (not a copy); may be {@code null}
     */
    public INDArray[] getInputMaskArrays() {
        return inputMaskArrays;
    }

    /**
     * @return the currently stored label mask arrays (not a copy); may be {@code null}
     */
    public INDArray[] getLabelMaskArrays() {
        return labelMaskArrays;
    }

    /**
     * Stores a single label array, allocating the label holder first if it is currently null.
     *
     * <p>The label holder is a transient field and {@code setLabels(null)} is accepted, so it
     * can legitimately be null here. It is allocated lazily in the same way {@code setInput}
     * allocates the input holder, instead of failing with a NullPointerException.
     *
     * @param i        position of the label
     * @param iNDArray the array to store at that position
     */
    public void setLabel(int i, INDArray iNDArray) {
        if (this.labels == null) {
            this.labels = new INDArray[this.numOutputArrays];
        }
        this.labels[i] = iNDArray;
    }

    /**
     * Replaces all label arrays at once.
     *
     * @param iNDArrayArr the new labels; {@code null} is accepted and stored as-is
     * @throws IllegalArgumentException if a non-null array's length differs from the number of network outputs
     */
    public void setLabels(INDArray... iNDArrayArr) {
        final boolean acceptable = iNDArrayArr == null || iNDArrayArr.length == this.numOutputArrays;
        if (acceptable) {
            this.labels = iNDArrayArr;
            return;
        }
        throw new IllegalArgumentException("Invalid output array: network has " + this.numOutputArrays + " outputs, but array is of length " + iNDArrayArr.length);
    }

    /**
     * Hands a gradients accumulator to the solver's optimizer, initializing the graph first
     * when {@code init} has not been called yet.
     *
     * @param gradientsAccumulator the accumulator to install on the optimizer
     */
    public void setGradientsAccumulator(GradientsAccumulator gradientsAccumulator) {
        final boolean needsInit = !initCalled;
        if (needsInit) {
            init();
        }
        solver.getOptimizer().setGradientsAccumulator(gradientsAccumulator);
    }

    /**
     * Initializes the graph without externally supplied parameters; delegates to
     * {@code init(null, false)}.
     */
    @Override // org.deeplearning4j.nn.api.Model, org.deeplearning4j.nn.api.NeuralNetwork
    public void init() {
        // No parameter array supplied, so the clone flag is irrelevant.
        this.init(null, false);
    }

    /**
     * Initializes the graph: creates the input vertices, instantiates every configured vertex,
     * wires up input/output vertex indices, creates the solver when none exists, and decides
     * which layers may modify their input in place. Does nothing if init has already completed.
     *
     * <p>The workspace scope around solver creation is managed with try-with-resources, so it
     * is closed exactly once and a failure in {@code close()} cannot replace the primary
     * exception.
     *
     * @param iNDArray parameters to use, or {@code null} to allocate a new {@code 1 x numParams}
     *                 array (in which case vertices are asked to initialize their parameters)
     * @param z        when {@code true} and {@code iNDArray} is non-null, a {@code dup()} of the
     *                 array is used; otherwise the given array is used directly
     * @throws IllegalArgumentException if {@code iNDArray} is not a row vector/scalar or its length
     *                                  differs from the total parameter count of the vertices
     * @throws IllegalStateException    if a vertex configuration instantiates to {@code null}, or a
     *                                  vertex input cannot be found in its own input list
     */
    public void init(INDArray iNDArray, boolean z) {
        if (this.initCalled) {
            return;
        }

        // Unset workspace/cache modes default to NONE.
        if (this.configuration.getTrainingWorkspaceMode() == null) {
            this.configuration.setTrainingWorkspaceMode(WorkspaceMode.NONE);
        }
        if (this.configuration.getInferenceWorkspaceMode() == null) {
            this.configuration.setInferenceWorkspaceMode(WorkspaceMode.NONE);
        }
        if (this.configuration.getCacheMode() == null) {
            this.configuration.setCacheMode(CacheMode.NONE);
        }
        OneTimeLogger.info(log, "Starting ComputationGraph with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]", new Object[]{this.configuration.getTrainingWorkspaceMode(), this.configuration.getInferenceWorkspaceMode(), this.configuration.getCacheMode()});

        GraphIndices indices = calculateIndices();
        this.topologicalOrder = indices.getTopologicalSortOrder();

        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertices = this.configuration.getVertices();
        List<String> networkInputs = this.configuration.getNetworkInputs();
        Map<String, List<String>> vertexInputs = this.configuration.getVertexInputs();
        this.vertices = new GraphVertex[networkInputs.size() + configVertices.size()];

        // Network inputs occupy the first vertex slots.
        Map<String, Integer> vertexIndexByName = new HashMap<>();
        int vertexNumber = 0;
        for (String inputName : networkInputs) {
            InputVertex inputVertex = new InputVertex(this, inputName, vertexNumber, null);
            vertexIndexByName.put(inputName, vertexNumber);
            this.vertices[vertexNumber] = inputVertex;
            vertexNumber++;
        }

        // Parameter count per vertex; input vertices keep the array default of 0.
        long totalParams = 0;
        long[] numParamsForVertex = new long[this.topologicalOrder.length];
        for (int v = networkInputs.size(); v < this.topologicalOrder.length; v++) {
            String name = indices.getIdxToName().get(Integer.valueOf(v));
            numParamsForVertex[v] = configVertices.get(name).numParams(true);
            totalParams += numParamsForVertex[v];
        }

        // Vertices only initialize their own parameters when the array is freshly allocated here.
        boolean initializeParams;
        if (iNDArray != null) {
            if (!iNDArray.isRowVectorOrScalar()) {
                throw new IllegalArgumentException("Invalid parameters: should be a row vector");
            }
            if (iNDArray.length() != totalParams) {
                throw new IllegalArgumentException("Invalid parameters: expected length " + totalParams + ", got length " + iNDArray.length());
            }
            this.flattenedParams = z ? iNDArray.dup() : iNDArray;
            initializeParams = false;
        } else if (totalParams > 0) {
            this.flattenedParams = Nd4j.create(new long[]{1, totalParams});
            initializeParams = true;
        } else {
            this.flattenedParams = null;
            initializeParams = false;
        }
        if (initializeParams) {
            Nd4j.getRandom().setSeed(conf().getSeed());
        }

        // Give each vertex with parameters a view into the flattened array, laid out in topological order.
        INDArray[] paramsViewForVertex = new INDArray[this.topologicalOrder.length];
        long paramOffset = 0;
        for (int vertexIdx : this.topologicalOrder) {
            long nParams = numParamsForVertex[vertexIdx];
            if (nParams != 0) {
                paramsViewForVertex[vertexIdx] = this.flattenedParams.get(new INDArrayIndex[]{NDArrayIndex.point(0L), NDArrayIndex.interval(paramOffset, paramOffset + nParams)});
            }
            paramOffset += nParams;
        }

        // Instantiate the configured vertices and collect their layers and variable names.
        int numLayers = 0;
        List<Layer> layerList = new ArrayList<>();
        this.defaultConfiguration.clearVariables();
        List<String> variables = this.defaultConfiguration.variables(false);
        for (int v = networkInputs.size(); v < this.topologicalOrder.length; v++) {
            String name = indices.getIdxToName().get(Integer.valueOf(v));
            org.deeplearning4j.nn.conf.graph.GraphVertex vertexConf = configVertices.get(name);
            GraphVertex vertex = vertexConf.instantiate(this, name, vertexNumber, paramsViewForVertex[vertexNumber], initializeParams);
            if (vertex == null) {
                throw new IllegalStateException("Encountered null layer/vertex during initialization for layer \"" + name + "\": " + vertexConf.getClass().getSimpleName() + " initialization returned null layer/vertex?");
            }
            if (vertex.hasLayer()) {
                numLayers++;
                Layer layer = vertex.getLayer();
                layerList.add(layer);
                List<String> layerVariables = layer.conf().variables();
                if (layerVariables != null) {
                    for (String variable : layerVariables) {
                        variables.add(vertex.getVertexName() + "_" + variable);
                    }
                }
            }
            vertexIndexByName.put(name, vertexNumber);
            this.vertices[vertexNumber] = vertex;
            vertexNumber++;
        }
        this.layers = layerList.toArray(new Layer[numLayers]);

        this.verticesMap = new HashMap<>();
        for (GraphVertex vertex : this.vertices) {
            this.verticesMap.put(vertex.getVertexName(), vertex);
        }

        // Invert vertexInputs: for every vertex, the names of the vertices that consume its output.
        Map<String, List<String>> vertexOutputs = new HashMap<>();
        for (GraphVertex vertex : this.vertices) {
            String vertexName = vertex.getVertexName();
            List<String> inputsOfVertex = vertexInputs.get(vertexName);
            if (inputsOfVertex == null) {
                continue;
            }
            for (String inputName : inputsOfVertex) {
                List<String> consumers = vertexOutputs.get(inputName);
                if (consumers == null) {
                    consumers = new ArrayList<>();
                    vertexOutputs.put(inputName, consumers);
                }
                consumers.add(vertexName);
            }
        }

        // Tell each vertex where its inputs come from.
        for (GraphVertex vertex : this.vertices) {
            String vertexName = vertex.getVertexName();
            int vertexIndex = vertex.getVertexIndex();
            List<String> inputsOfVertex = vertexInputs.get(vertexName);
            if (inputsOfVertex == null) {
                continue;
            }
            VertexIndices[] inputIndices = new VertexIndices[inputsOfVertex.size()];
            for (int j = 0; j < inputsOfVertex.size(); j++) {
                String inputName = inputsOfVertex.get(j);
                int inputVertexIndex = vertexIndexByName.get(inputName).intValue();
                int edgeNumber = vertexInputs.get(vertexName).indexOf(inputName);
                if (edgeNumber == -1) {
                    throw new IllegalStateException("Could not find vertex " + vertexIndex + " in the list of inputs for vertex " + vertex.getVertexName() + "; error in graph structure?");
                }
                inputIndices[j] = new VertexIndices(inputVertexIndex, edgeNumber);
            }
            vertex.setInputVertices(inputIndices);
        }

        // Tell each vertex which vertices consume its output.
        for (GraphVertex vertex : this.vertices) {
            String vertexName = vertex.getVertexName();
            List<String> consumers = vertexOutputs.get(vertexName);
            if (consumers == null || consumers.isEmpty()) {
                continue;
            }
            VertexIndices[] outputIndices = new VertexIndices[consumers.size()];
            int pos = 0;
            for (String consumer : consumers) {
                int consumerIndex = vertexIndexByName.get(consumer).intValue();
                int edgeNumber = vertexInputs.get(consumer).indexOf(vertexName);
                outputIndices[pos++] = new VertexIndices(consumerIndex, edgeNumber);
            }
            vertex.setOutputVertices(outputIndices);
        }

        for (String outputName : this.configuration.getNetworkOutputs()) {
            this.verticesMap.get(outputName).setOutputVertex(true);
        }

        // Create the solver outside of any workspace.
        if (this.solver == null) {
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                this.solver.initOptimizer();
            }
        }

        // Map every input name to all vertices that consume it, based on the full configuration.
        Map<String, List<String>> consumersByInput = new HashMap<>();
        for (Map.Entry<String, List<String>> entry : this.configuration.getVertexInputs().entrySet()) {
            for (String inputName : entry.getValue()) {
                if (!consumersByInput.containsKey(inputName)) {
                    consumersByInput.put(inputName, new ArrayList<String>());
                }
                consumersByInput.get(inputName).add(entry.getKey());
            }
        }

        // A layer may modify its (first) input in place when that input is not a network input and
        // the layer is either its only consumer or the last consumer in topological order.
        for (Layer layer : this.layers) {
            String layerName = layer.conf().getLayer().getLayerName();
            String firstInput = this.configuration.getVertexInputs().get(layerName).get(0);
            if (this.configuration.getNetworkInputs().contains(firstInput)) {
                continue;
            }
            List<String> consumers = consumersByInput.get(firstInput);
            if (consumers.size() == 1) {
                layer.allowInputModification(true);
                continue;
            }
            int thisPosition = ArrayUtils.indexOf(indices.getTopologicalSortOrder(), indices.getNameToIdx().get(layerName).intValue());
            int lastPosition = -1;
            for (String consumer : consumers) {
                lastPosition = Math.max(lastPosition, ArrayUtils.indexOf(indices.getTopologicalSortOrder(), indices.getNameToIdx().get(consumer).intValue()));
            }
            if (thisPosition == lastPosition) {
                layer.allowInputModification(true);
            }
        }

        synchronizeIterEpochCounts();
        this.initCalled = true;
    }

    /**
     * Allocates the flattened gradient array (when the graph has parameters) and hands each
     * vertex with parameters a view into it, laid out in topological order. Calls
     * {@code init()} first if the graph has not been initialized.
     *
     * <p>The whole operation runs with workspaces scoped out. The scope is managed with
     * try-with-resources, so it is closed exactly once and a failure in {@code close()} cannot
     * replace the primary exception.
     */
    public void initGradientsView() {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            if (!this.initCalled) {
                init();
            }
            GraphIndices indices = calculateIndices();

            // Parameter count per vertex; input vertices keep the array default of 0.
            long totalParams = 0;
            long[] numParamsForVertex = new long[this.topologicalOrder.length];
            Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> configVertices = this.configuration.getVertices();
            for (int v = this.configuration.getNetworkInputs().size(); v < this.topologicalOrder.length; v++) {
                String name = indices.getIdxToName().get(Integer.valueOf(v));
                numParamsForVertex[v] = configVertices.get(name).numParams(true);
                totalParams += numParamsForVertex[v];
            }

            if (totalParams > 0) {
                this.flattenedGradients = Nd4j.create(new long[]{1, totalParams});
            }

            // Vertices without parameters are skipped, so the array is only touched when it was allocated above.
            long offset = 0;
            for (int vertexIdx : this.topologicalOrder) {
                long nParams = numParamsForVertex[vertexIdx];
                if (nParams != 0) {
                    this.vertices[vertexIdx].setBackpropGradientsViewArray(this.flattenedGradients.get(new INDArrayIndex[]{NDArrayIndex.point(0L), NDArrayIndex.interval(offset, offset + nParams)}));
                }
                offset += nParams;
            }
        }
    }

    /**
     * Returns the vertex indices of the network outputs, in the order the outputs are listed
     * in the configuration. The array is computed on the first call and cached afterwards.
     *
     * @return the cached array of output vertex indices
     */
    protected int[] getOutputLayerIndices() {
        if (this.outputLayerIdxs != null) {
            return this.outputLayerIdxs;
        }
        this.outputLayerIdxs = new int[this.numOutputArrays];
        int position = 0;
        for (String outputName : this.configuration.getNetworkOutputs()) {
            this.outputLayerIdxs[position++] = this.verticesMap.get(outputName).getVertexIndex();
        }
        return this.outputLayerIdxs;
    }

    /**
     * Pretrains from a {@link DataSetIterator}, passing 1 as the second argument of the
     * two-argument overload.
     *
     * @param dataSetIterator the data to pretrain on
     */
    public void pretrain(DataSetIterator dataSetIterator) {
        final int epochArgument = 1;
        pretrain(dataSetIterator, epochArgument);
    }

    /**
     * Pretrains from a {@link DataSetIterator} by converting it to a multi-dataset iterator.
     *
     * @param dataSetIterator the data to pretrain on
     * @param i               value forwarded unchanged to the {@link MultiDataSetIterator} overload
     * @throws UnsupportedOperationException if the graph does not have exactly one input
     */
    public void pretrain(DataSetIterator dataSetIterator, int i) {
        if (this.numInputArrays == 1) {
            final MultiDataSetIterator converted = ComputationGraphUtil.toMultiDataSetIterator(dataSetIterator);
            pretrain(converted, i);
            return;
        }
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
    }

    /**
     * Pretrains from a {@link MultiDataSetIterator}, passing 1 as the second argument of the
     * two-argument overload.
     *
     * @param multiDataSetIterator the data to pretrain on
     */
    public void pretrain(MultiDataSetIterator multiDataSetIterator) {
        final int epochArgument = 1;
        pretrain(multiDataSetIterator, epochArgument);
    }

    /**
     * Pretrains the graph via {@code pretrainHelper}. If an {@link OutOfMemoryError} occurs,
     * a memory crash dump is written before the error is rethrown.
     *
     * @param multiDataSetIterator the data to pretrain on
     * @param i                    value forwarded unchanged to {@code pretrainHelper}
     */
    public void pretrain(MultiDataSetIterator multiDataSetIterator, int i) {
        try {
            this.pretrainHelper(multiDataSetIterator, i);
        } catch (OutOfMemoryError oom) {
            // Record diagnostics, then let the error propagate unchanged.
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Pretrains every eligible layer, visiting vertices by array position. Does nothing when
     * the configuration has pretraining disabled.
     *
     * @param multiDataSetIterator the data to pretrain on
     * @param i                    value forwarded unchanged to {@code pretrainLayerHelper}
     */
    private void pretrainHelper(MultiDataSetIterator multiDataSetIterator, int i) {
        if (!this.configuration.isPretrain()) {
            return;
        }
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        for (int idx = 0; idx < this.topologicalOrder.length; idx++) {
            if (!this.vertices[idx].hasLayer()) {
                continue;
            }
            if (this.vertices[idx].getLayer() instanceof IOutputLayer) {
                // Output layers are never pretrained.
                continue;
            }
            if (!this.vertices[idx].getLayer().isPretrainLayer()) {
                continue;
            }
            pretrainLayerHelper(this.vertices[idx].getVertexName(), multiDataSetIterator, i);
        }
    }

    /**
     * Pretrains a single named layer from a {@link DataSetIterator} by converting it to a
     * multi-dataset iterator.
     *
     * @param str             name of the layer's vertex
     * @param dataSetIterator the data to pretrain on
     * @throws UnsupportedOperationException if the graph does not have exactly one input
     */
    public void pretrainLayer(String str, DataSetIterator dataSetIterator) {
        if (this.numInputArrays == 1) {
            final MultiDataSetIterator converted = ComputationGraphUtil.toMultiDataSetIterator(dataSetIterator);
            pretrainLayer(str, converted);
            return;
        }
        throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs using a DataSetIterator");
    }

    /**
     * Pretrains a single named layer via {@code pretrainLayerHelper}, passing 1 as its last
     * argument. If an {@link OutOfMemoryError} occurs, a memory crash dump is written before
     * the error is rethrown.
     *
     * @param str                  name of the layer's vertex
     * @param multiDataSetIterator the data to pretrain on
     */
    public void pretrainLayer(String str, MultiDataSetIterator multiDataSetIterator) {
        try {
            this.pretrainLayerHelper(str, multiDataSetIterator, 1);
        } catch (OutOfMemoryError oom) {
            // Record diagnostics, then let the error propagate unchanged.
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Pretrains the layer of the named vertex: for each batch, runs a forward pass up to that
     * vertex, feeds the resulting activations in as the vertex inputs, and calls the layer's
     * {@code fit} on its own input. Does nothing when the vertex has no layer.
     *
     * <p>The activations workspace scope is managed with try-with-resources, so it is closed
     * exactly once per batch and a failure in {@code close()} cannot replace the primary
     * exception. The layer's pretrain flag is reset in a {@code finally} block so a failed
     * {@code fit} does not leave the layer flagged as pretraining.
     *
     * <p>NOTE(review): the parameter {@code i} is not used in this method; a single pass is
     * made over the iterator regardless of its value.
     *
     * @param str                  name of the vertex whose layer is pretrained
     * @param multiDataSetIterator the data to pretrain on; reset first if exhausted and resettable
     * @param i                    currently unused
     * @throws IllegalStateException if no vertex with the given name exists
     */
    private void pretrainLayerHelper(String str, MultiDataSetIterator multiDataSetIterator, int i) {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        if (!this.verticesMap.containsKey(str)) {
            throw new IllegalStateException("Invalid vertex name: " + str + " - all vertex names: " + this.verticesMap.keySet());
        }
        if (!this.verticesMap.get(str).hasLayer()) {
            return;
        }

        GraphVertex toTrain = this.verticesMap.get(str);
        int vertexIndex = toTrain.getVertexIndex();

        LayerWorkspaceMgr workspaceMgr;
        if (this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.UPDATER_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        if (!multiDataSetIterator.hasNext() && multiDataSetIterator.resetSupported()) {
            multiDataSetIterator.reset();
        }
        MultiDataSetIterator batches = multiDataSetIterator.asyncSupported() ? new AsyncMultiDataSetIterator(multiDataSetIterator) : multiDataSetIterator;

        while (batches.hasNext()) {
            MultiDataSet batch = (MultiDataSet) batches.next();
            try (MemoryWorkspace ignored = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
                Map<String, INDArray> activations = ffToLayerActivationsInWS(false, vertexIndex, new int[]{vertexIndex}, FwdPassType.STANDARD, false, batch.getFeatures(), batch.getFeaturesMaskArrays(), batch.getLabelsMaskArrays(), true);

                // Feed the activations of the upstream vertices in as this vertex's inputs.
                for (VertexIndices inputEdge : toTrain.getInputVertices()) {
                    String upstreamName = this.vertices[inputEdge.getVertexIndex()].getVertexName();
                    toTrain.setInput(inputEdge.getVertexEdgeNumber(), activations.get(upstreamName), workspaceMgr);
                }

                Layer layer = toTrain.getLayer();
                layer.getConfig().setPretrain(true);
                try {
                    layer.fit(layer.input(), workspaceMgr);
                } finally {
                    layer.getConfig().setPretrain(false);
                }
            }
        }
    }

    /**
     * Fits the graph on a single {@link DataSet}. Mask arrays, when present, are passed along
     * and the layer mask arrays are cleared afterwards; layer states are always cleared at the end.
     *
     * @param dataSet the features/labels (and optional masks) to train on
     * @throws UnsupportedOperationException if the graph does not have exactly one input and one output
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void fit(DataSet dataSet) {
        final boolean singleInSingleOut = this.numInputArrays == 1 && this.numOutputArrays == 1;
        if (!singleInSingleOut) {
            throw new UnsupportedOperationException("Cannot train ComputationGraph network with  multiple inputs or outputs using a DataSet");
        }

        final boolean masked = dataSet.hasMaskArrays();
        final INDArray[] features = new INDArray[]{dataSet.getFeatures()};
        final INDArray[] targets = new INDArray[]{dataSet.getLabels()};

        if (masked) {
            // Each mask is wrapped in a one-element array, or passed as null when absent.
            final INDArray[] featureMasks = dataSet.getFeaturesMaskArray() != null ? new INDArray[]{dataSet.getFeaturesMaskArray()} : null;
            final INDArray[] labelMasks = dataSet.getLabelsMaskArray() != null ? new INDArray[]{dataSet.getLabelsMaskArray()} : null;
            fit(features, targets, featureMasks, labelMasks);
            clearLayerMaskArrays();
        } else {
            fit(features, targets);
        }
        clearLayersStates();
    }

    /**
     * Trains for a fixed number of epochs, calling {@link #fit(DataSetIterator)} once per epoch.
     *
     * @param dataSetIterator source of training data; must not be null
     * @param i number of epochs; validated to be greater than zero. More than one epoch
     *          additionally requires {@code dataSetIterator.resetSupported()} to be true
     * @throws NullPointerException if {@code dataSetIterator} is null
     */
    public void fit(@NonNull DataSetIterator dataSetIterator, int i) {
        if (dataSetIterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        Preconditions.checkArgument(i > 0, "Number of epochs much be > 0. Got numEpochs = %s", i);
        boolean canRunAllEpochs = i == 1 || dataSetIterator.resetSupported();
        Preconditions.checkArgument(canRunAllEpochs, "Cannot perform multiple epochs training usingiterator thas does not support resetting (iterator.resetSupported() returned false)");
        int epochsLeft = i;
        while (epochsLeft > 0) {
            fit(dataSetIterator);
            epochsLeft--;
        }
    }

    /**
     * Fits the network for one pass over a {@link DataSetIterator} by wrapping it in a
     * {@link MultiDataSetIteratorAdapter} and delegating to {@link #fit(MultiDataSetIterator)}.
     *
     * @param dataSetIterator source of training data; must not be null
     * @throws NullPointerException if {@code dataSetIterator} is null
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void fit(@NonNull DataSetIterator dataSetIterator) {
        if (dataSetIterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(dataSetIterator);
        fit(adapted);
    }

    /**
     * Fits the network on a single {@link MultiDataSet}, then clears the layer mask arrays if the
     * data set reported having mask arrays.
     *
     * @param multiDataSet features, labels and optional mask arrays to train on
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void fit(MultiDataSet multiDataSet) {
        INDArray[] features = multiDataSet.getFeatures();
        INDArray[] labels = multiDataSet.getLabels();
        INDArray[] featuresMasks = multiDataSet.getFeaturesMaskArrays();
        INDArray[] labelsMasks = multiDataSet.getLabelsMaskArrays();
        fit(features, labels, featuresMasks, labelsMasks);

        boolean masked = multiDataSet.hasMaskArrays();
        if (masked) {
            clearLayerMaskArrays();
        }
    }

    /**
     * Trains for a fixed number of epochs, calling {@link #fit(MultiDataSetIterator)} once per
     * epoch.
     *
     * @param multiDataSetIterator source of training data; must not be null
     * @param i number of epochs; validated to be greater than zero. More than one epoch
     *          additionally requires {@code multiDataSetIterator.resetSupported()} to be true
     * @throws NullPointerException if {@code multiDataSetIterator} is null
     */
    public void fit(@NonNull MultiDataSetIterator multiDataSetIterator, int i) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        Preconditions.checkArgument(i > 0, "Number of epochs much be > 0. Got numEpochs = %s", i);
        boolean canRunAllEpochs = i == 1 || multiDataSetIterator.resetSupported();
        Preconditions.checkArgument(canRunAllEpochs, "Cannot perform multiple epochs training usingiterator thas does not support resetting (iterator.resetSupported() returned false)");
        int epochsLeft = i;
        while (epochsLeft > 0) {
            fit(multiDataSetIterator);
            epochsLeft--;
        }
    }

    /**
     * Fits the network for one epoch over the given iterator.
     *
     * <p>Steps, in order: initialise the gradients view if it is missing; reset the iterator if
     * it is exhausted and supports resetting; notify listeners of epoch start; train on every
     * batch (recording in {@code lastEtlTime} the wall-clock time spent waiting for each batch);
     * notify listeners of epoch end; increment the epoch count.
     *
     * <p>When the iterator reports {@code asyncSupported()}, it is wrapped in an
     * {@link AsyncMultiDataSetIterator}. That wrapper is shut down in a {@code finally} block so
     * it is released even when training on a batch throws.
     *
     * @param multiDataSetIterator source of training data
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public synchronized void fit(MultiDataSetIterator multiDataSetIterator) {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        if (!multiDataSetIterator.hasNext() && multiDataSetIterator.resetSupported()) {
            multiDataSetIterator.reset();
        }
        for (TrainingListener listener : this.trainingListeners) {
            listener.onEpochStart(this);
        }

        final boolean wrappedAsync = multiDataSetIterator.asyncSupported();
        final MultiDataSetIterator source;
        if (wrappedAsync) {
            source = new AsyncMultiDataSetIterator(
                    multiDataSetIterator,
                    Math.max(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2),
                    this.configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE);
        } else {
            source = multiDataSetIterator;
        }

        try {
            // Measured from the end of the previous fit call to the moment next() returns.
            long waitStart = System.currentTimeMillis();
            while (source.hasNext()) {
                MultiDataSet batch = (MultiDataSet) source.next();
                this.lastEtlTime.set(Long.valueOf(System.currentTimeMillis() - waitStart));
                fit(batch.getFeatures(), batch.getLabels(), batch.getFeaturesMaskArrays(), batch.getLabelsMaskArrays());
                waitStart = System.currentTimeMillis();
            }
        } finally {
            // Only the wrapper created above is shut down; the caller's iterator is left alone.
            if (wrappedAsync) {
                ((AsyncMultiDataSetIterator) source).shutdown();
            }
        }

        for (TrainingListener listener : this.trainingListeners) {
            listener.onEpochEnd(this);
        }
        incrementEpochCount();
    }

    /**
     * Fits the network on the given arrays without any mask arrays.
     *
     * @param iNDArrayArr first argument forwarded to the four-argument {@code fit}
     *                    (used there as the network inputs)
     * @param iNDArrayArr2 second argument forwarded to the four-argument {@code fit}
     *                     (used there as the labels)
     */
    public void fit(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        final INDArray[] noFeatureMasks = null;
        final INDArray[] noLabelMasks = null;
        fit(iNDArrayArr, iNDArrayArr2, noFeatureMasks, noLabelMasks);
    }

    /**
     * Fits the network on the given arrays by delegating to {@code fitHelper}. If an
     * {@link OutOfMemoryError} is raised, a memory crash dump is written through
     * {@code CrashReportingUtil} and the error is rethrown.
     *
     * @param iNDArrayArr network inputs
     * @param iNDArrayArr2 labels
     * @param iNDArrayArr3 first mask array set passed to {@code setLayerMaskArrays}; may be null
     * @param iNDArrayArr4 second mask array set passed to {@code setLayerMaskArrays}; may be null
     */
    public void fit(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, INDArray[] iNDArrayArr4) {
        try {
            fitHelper(iNDArrayArr, iNDArrayArr2, iNDArrayArr3, iNDArrayArr4);
        } catch (OutOfMemoryError oom) {
            // Record diagnostics first, then let the error propagate unchanged.
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Performs one fit step on the given arrays.
     *
     * <p>Returns immediately when the network has no parameters. Otherwise it sets inputs,
     * labels and mask arrays on the network, builds the workspace manager for the configured
     * training workspace mode, and runs either truncated BPTT or the solver (created on first
     * use). Afterwards mask arrays are cleared if any were supplied, layer states are cleared,
     * and iteration/epoch counts are synchronized.
     *
     * @param iNDArrayArr network inputs
     * @param iNDArrayArr2 labels
     * @param iNDArrayArr3 first mask array set passed to {@code setLayerMaskArrays}; may be null
     * @param iNDArrayArr4 second mask array set passed to {@code setLayerMaskArrays}; may be null
     * @throws IllegalStateException if the configuration has backprop disabled
     */
    private synchronized void fitHelper(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, INDArray[] iNDArrayArr4) {
        if (numParams() == 0) {
            // No parameters: nothing to train.
            return;
        }
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        setInputs(iNDArrayArr);
        setLabels(iNDArrayArr2);
        setLayerMaskArrays(iNDArrayArr3, iNDArrayArr4);
        update(TaskUtils.buildTask(iNDArrayArr, iNDArrayArr2));

        final LayerWorkspaceMgr workspaceMgr;
        if (this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        if (!this.configuration.isBackprop()) {
            throw new IllegalStateException("Network configuration is set to backprop(false). Use the pretrain and pretrainLayer methods to perform training for unsupervised layerwise training of neural networks");
        }

        if (this.configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
            doTruncatedBPTT(iNDArrayArr, iNDArrayArr2, iNDArrayArr3, iNDArrayArr4, workspaceMgr);
        } else {
            if (this.solver == null) {
                // Build the solver while scoped out of any workspace; the scope is closed
                // exactly once, whether or not construction throws.
                try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
            }
            this.solver.optimize(workspaceMgr);
        }

        if (iNDArrayArr3 != null || iNDArrayArr4 != null) {
            clearLayerMaskArrays();
        }
        clearLayersStates();
        synchronizeIterEpochCounts();
    }

    /**
     * Returns the topological sort order of the graph vertices, as held by the
     * {@link GraphIndices} returned from {@link #calculateIndices()}.
     *
     * @return vertex indices in topological order
     */
    public int[] topologicalSortOrder() {
        GraphIndices indices = calculateIndices();
        return indices.getTopologicalSortOrder();
    }

    /**
     * Returns the graph's index information: the topological sort order plus the name-to-index
     * and index-to-name mappings. The result is cached in {@code graphIndices}.
     *
     * <p>Resolution order:
     * <ol>
     *   <li>the cached {@code graphIndices}, if already set;</li>
     *   <li>the ordering stored on the configuration, if both its index and name forms are
     *       present;</li>
     *   <li>otherwise a fresh topological sort, whose result is also written back to the
     *       configuration.</li>
     * </ol>
     *
     * <p>For a fresh sort, network inputs receive indices first, followed by the configured
     * vertices in the iteration order of the configuration's vertex map. Vertices are emitted
     * from a FIFO queue once all of their inputs have been emitted.
     *
     * @return the graph indices
     * @throws IllegalStateException if some vertex still has unprocessed inputs after the sort,
     *                               i.e. the graph contains a cycle
     */
    public GraphIndices calculateIndices() {
        if (this.graphIndices != null) {
            return this.graphIndices;
        }

        if (this.configuration.getTopologicalOrder() != null && this.configuration.getTopologicalOrderStr() != null) {
            // Rebuild both lookup maps from the ordering already stored on the configuration.
            int[] storedOrder = this.configuration.getTopologicalOrder();
            List<String> storedNames = this.configuration.getTopologicalOrderStr();
            Map<String, Integer> storedNameToIdx = new HashMap<>();
            Map<Integer, String> storedIdxToName = new HashMap<>();
            for (int pos = 0; pos < storedOrder.length; pos++) {
                storedNameToIdx.put(storedNames.get(pos), Integer.valueOf(storedOrder[pos]));
                storedIdxToName.put(Integer.valueOf(storedOrder[pos]), storedNames.get(pos));
            }
            this.graphIndices = GraphIndices.builder().topologicalSortOrder(storedOrder).nameToIdx(storedNameToIdx).idxToName(storedIdxToName).build();
            return this.graphIndices;
        }

        Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> confVertices = this.configuration.getVertices();
        int[] sortedIndices = new int[this.configuration.getNetworkInputs().size() + this.configuration.getVertices().size()];

        // Index assignment: network inputs first, then the configured vertices.
        Map<Integer, String> idxToName = new HashMap<>();
        Map<String, Integer> nameToIdx = new HashMap<>();
        int nextIdx = 0;
        for (String inputName : this.configuration.getNetworkInputs()) {
            idxToName.put(Integer.valueOf(nextIdx), inputName);
            nameToIdx.put(inputName, Integer.valueOf(nextIdx));
            nextIdx++;
        }
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> vertexEntry : confVertices.entrySet()) {
            String vertexName = vertexEntry.getKey();
            idxToName.put(Integer.valueOf(nextIdx), vertexName);
            nameToIdx.put(vertexName, Integer.valueOf(nextIdx));
            nextIdx++;
        }

        // pendingInputs: vertex index -> indices of its inputs not yet emitted (null = no inputs).
        // consumers: vertex index -> indices of the vertices that take it as an input.
        Map<Integer, Set<Integer>> pendingInputs = new HashMap<>();
        Map<Integer, Set<Integer>> consumers = new HashMap<>();
        for (String inputName : this.configuration.getNetworkInputs()) {
            int inputIdx = nameToIdx.get(inputName).intValue();
            pendingInputs.put(Integer.valueOf(inputIdx), null);
        }
        for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> vertexEntry : confVertices.entrySet()) {
            String vertexName = vertexEntry.getKey();
            int vertexIdx = nameToIdx.get(vertexName).intValue();
            List<String> inputNames = this.configuration.getVertexInputs().get(vertexName);
            if (inputNames == null || inputNames.isEmpty()) {
                pendingInputs.put(Integer.valueOf(vertexIdx), null);
                continue;
            }
            Set<Integer> inputIdxs = new HashSet<>();
            for (String inputName : inputNames) {
                Integer inputIdx = nameToIdx.get(inputName);
                inputIdxs.add(inputIdx);
                Set<Integer> outs = consumers.get(inputIdx);
                if (outs == null) {
                    outs = new HashSet<>();
                    consumers.put(inputIdx, outs);
                }
                outs.add(Integer.valueOf(vertexIdx));
            }
            pendingInputs.put(Integer.valueOf(vertexIdx), inputIdxs);
        }

        // Seed the queue with every vertex that has no inputs.
        LinkedList<Integer> ready = new LinkedList<>();
        for (Map.Entry<Integer, Set<Integer>> entry : pendingInputs.entrySet()) {
            Set<Integer> remaining = entry.getValue();
            if (remaining == null || remaining.isEmpty()) {
                ready.add(entry.getKey());
            }
        }

        // Emit vertices in FIFO order; a consumer becomes ready once its last input is emitted.
        int emitted = 0;
        while (!ready.isEmpty()) {
            int current = ready.removeFirst().intValue();
            sortedIndices[emitted] = current;
            emitted++;
            Set<Integer> outs = consumers.get(Integer.valueOf(current));
            if (outs == null) {
                continue;
            }
            for (Integer consumer : outs) {
                Set<Integer> remaining = pendingInputs.get(consumer);
                remaining.remove(Integer.valueOf(current));
                if (remaining.isEmpty()) {
                    ready.add(consumer);
                }
            }
        }

        // Any vertex that still has pending inputs was never emitted, so it is part of a cycle.
        for (Map.Entry<Integer, Set<Integer>> entry : pendingInputs.entrySet()) {
            Set<Integer> remaining = entry.getValue();
            if (remaining != null && !remaining.isEmpty()) {
                throw new IllegalStateException("Invalid configuration: cycle detected in graph. Cannot calculate topological ordering with graph cycle (cycle includes vertex \"" + idxToName.get(entry.getKey()) + "\")");
            }
        }

        List<String> sortedNames = new ArrayList<>(sortedIndices.length);
        for (int idx : sortedIndices) {
            sortedNames.add(idxToName.get(Integer.valueOf(idx)));
        }
        this.configuration.setTopologicalOrder(sortedIndices);
        this.configuration.setTopologicalOrderStr(sortedNames);
        this.graphIndices = GraphIndices.builder().topologicalSortOrder(sortedIndices).nameToIdx(nameToIdx).idxToName(idxToName).build();
        return this.graphIndices;
    }

    /**
     * Computes the gradient and score by delegating to the no-argument
     * {@link #computeGradientAndScore()}.
     *
     * @param layerWorkspaceMgr not used by this implementation; the no-argument overload builds
     *                          its own workspace manager
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void computeGradientAndScore(LayerWorkspaceMgr layerWorkspaceMgr) {
        this.computeGradientAndScore();
    }

    /**
     * Runs a forward pass, computes backprop gradients and accumulates the network score.
     *
     * <p>Everything after the workspace manager is built runs inside the
     * {@code ArrayType.ACTIVATIONS} scope of that manager:
     * <ol>
     *   <li>forward pass via {@code ffToLayerActivationsInWS};</li>
     *   <li>{@code onForwardPass} listener callbacks, scoped out of workspaces;</li>
     *   <li>{@code calcBackpropGradients};</li>
     *   <li>score accumulation over the configured network outputs, each computed inside an
     *       {@code ArrayType.FF_WORKING_MEM} scope;</li>
     *   <li>{@code onBackwardPass} listener callbacks, scoped out of workspaces;</li>
     *   <li>{@code clear()} on every vertex.</li>
     * </ol>
     *
     * <p>All workspace scopes are managed with try-with-resources, so each is closed exactly
     * once even when a listener or a layer throws.
     */
    public void computeGradientAndScore() {
        synchronizeIterEpochCounts();

        final LayerWorkspaceMgr workspaceMgr;
        if (this.configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .build();
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        final boolean tbptt = this.configuration.getBackpropType() == BackpropType.TruncatedBPTT;
        final FwdPassType fwdPassType = tbptt ? FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE : FwdPassType.STANDARD;
        synchronizeIterEpochCounts();

        try (MemoryWorkspace activationsScope = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
            Map<String, INDArray> activations = ffToLayerActivationsInWS(true, -1, getOutputLayerIndices(), fwdPassType, tbptt, this.inputs, this.inputMaskArrays, this.labelMaskArrays, false);

            if (!this.trainingListeners.isEmpty()) {
                try (MemoryWorkspace listenerScope = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener listener : this.trainingListeners) {
                        listener.onForwardPass(this, activations);
                    }
                }
            }

            calcBackpropGradients(false, false, new INDArray[0]);
            workspaceMgr.assertCurrentWorkspace(ArrayType.ACTIVATIONS, null);

            double l1 = calcL1();
            double l2 = calcL2();
            // NOTE(review): score is reset to this constant before accumulation; it is
            // presumably 0.0 substituted by the decompiler - confirm against EvaluationBinary.
            this.score = EvaluationBinary.DEFAULT_EDGE_VALUE;
            int outputNum = 0;
            for (String outputName : this.configuration.getNetworkOutputs()) {
                GraphVertex outputVertex = this.verticesMap.get(outputName);
                if (outputVertex instanceof LayerVertex) {
                    // The input may so far be set only on the vertex, not on its layer.
                    LayerVertex layerVertex = (LayerVertex) outputVertex;
                    if (!layerVertex.isSetLayerInput()) {
                        layerVertex.applyPreprocessorAndSetInput(workspaceMgr);
                    }
                }
                Layer outputLayer = outputVertex.getLayer();
                if (outputLayer instanceof FrozenLayerWithBackprop) {
                    outputLayer = ((FrozenLayerWithBackprop) outputLayer).getInsideLayer();
                }
                outputLayer.setMaskArray(this.labelMaskArrays == null ? null : this.labelMaskArrays[outputNum]);

                try (MemoryWorkspace scoreScope = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                    this.score += ((IOutputLayer) outputLayer).computeScore(l1, l2, true, workspaceMgr);
                }

                // The l1/l2 terms are passed only to the first output layer.
                l1 = 0.0d;
                l2 = 0.0d;
                outputNum++;
            }

            if (!this.trainingListeners.isEmpty()) {
                try (MemoryWorkspace listenerScope = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener listener : this.trainingListeners) {
                        listener.onBackwardPass(this);
                    }
                }
            }

            for (GraphVertex vertex : this.vertices) {
                vertex.clear();
            }
        }
    }

    /**
     * Feed-forward for a single-input graph up to the given layer: sets the array as input 0 and
     * delegates to {@link #feedForward(boolean, int)}.
     *
     * @param iNDArray the single network input
     * @param i layer number, forwarded to {@code feedForward(boolean, int)}
     * @param z flag forwarded to {@code feedForward(boolean, int)}
     * @return activations keyed by vertex name
     * @throws UnsupportedOperationException if the graph does not have exactly one input array
     */
    public Map<String, INDArray> feedForward(INDArray iNDArray, int i, boolean z) {
        boolean singleInput = this.numInputArrays == 1;
        if (!singleInput) {
            String message = "Cannot feedForward with single input for graph network with " + this.numInputArrays + " expected inputs";
            throw new UnsupportedOperationException(message);
        }
        setInput(0, iNDArray);
        Map<String, INDArray> activations = feedForward(z, i);
        return activations;
    }

    /**
     * Sets the given inputs and runs a detached forward pass up to index {@code i}. If an
     * {@link OutOfMemoryError} is raised during the forward pass, a memory crash dump is written
     * and the error is rethrown.
     *
     * @param iNDArrayArr network inputs
     * @param i index passed directly to {@code ffToLayerActivationsDetached} as the layer index
     * @param z selects the training (true) or inference (false) workspace mode in
     *          {@code ffToLayerActivationsDetached}
     * @param z2 forwarded as the last argument of {@code ffToLayerActivationsDetached}
     * @return activations keyed by vertex name
     */
    public Map<String, INDArray> feedForward(INDArray[] iNDArrayArr, int i, boolean z, boolean z2) {
        setInputs(iNDArrayArr);
        try {
            Map<String, INDArray> activations = ffToLayerActivationsDetached(z, FwdPassType.STANDARD, false, i, null, iNDArrayArr, this.inputMaskArrays, this.labelMaskArrays, z2);
            return activations;
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Sets the given inputs, then delegates to {@link #feedForward(boolean, int)}.
     *
     * @param iNDArrayArr network inputs
     * @param i layer number, forwarded to {@code feedForward(boolean, int)}
     * @param z flag forwarded to {@code feedForward(boolean, int)}
     * @return activations keyed by vertex name
     */
    public Map<String, INDArray> feedForward(INDArray[] iNDArrayArr, int i, boolean z) {
        setInputs(iNDArrayArr);
        Map<String, INDArray> activations = feedForward(z, i);
        return activations;
    }

    /**
     * Runs a detached forward pass on the currently set inputs, up to the vertex index reported
     * by {@code layers[i].getIndex()}. If an {@link OutOfMemoryError} is raised, a memory crash
     * dump is written and the error is rethrown.
     *
     * @param z selects the training (true) or inference (false) workspace mode in
     *          {@code ffToLayerActivationsDetached}
     * @param i position in the {@code layers} array
     * @return activations keyed by vertex name
     */
    public Map<String, INDArray> feedForward(boolean z, int i) {
        try {
            int targetIndex = this.layers[i].getIndex();
            Map<String, INDArray> activations = ffToLayerActivationsDetached(z, FwdPassType.STANDARD, false, targetIndex, null, this.inputs, this.inputMaskArrays, this.labelMaskArrays, true);
            return activations;
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Feed-forward for a single-input graph: sets the array as input 0 and delegates to
     * {@link #feedForward(boolean)}.
     *
     * @param iNDArray the single network input
     * @param z flag forwarded to {@code feedForward(boolean)}
     * @return activations keyed by vertex name
     * @throws UnsupportedOperationException if the graph does not have exactly one input array
     */
    public Map<String, INDArray> feedForward(INDArray iNDArray, boolean z) {
        boolean singleInput = this.numInputArrays == 1;
        if (!singleInput) {
            String message = "Cannot feedForward with single input for graph network with " + this.numInputArrays + " expected inputs";
            throw new UnsupportedOperationException(message);
        }
        setInput(0, iNDArray);
        Map<String, INDArray> activations = feedForward(z);
        return activations;
    }

    /**
     * Feed-forward on the given inputs; equivalent to
     * {@code feedForward(iNDArrayArr, z, true)}.
     *
     * @param iNDArrayArr network inputs
     * @param z flag forwarded to the three-argument overload
     * @return activations keyed by vertex name
     */
    public Map<String, INDArray> feedForward(INDArray[] iNDArrayArr, boolean z) {
        final boolean lastFlag = true;
        return feedForward(iNDArrayArr, z, lastFlag);
    }

    /**
     * Sets the given inputs and runs a detached forward pass up to index
     * {@code vertices.length - 1}. If an {@link OutOfMemoryError} is raised during the forward
     * pass, a memory crash dump is written and the error is rethrown.
     *
     * @param iNDArrayArr network inputs
     * @param z selects the training (true) or inference (false) workspace mode in
     *          {@code ffToLayerActivationsDetached}
     * @param z2 forwarded as the last argument of {@code ffToLayerActivationsDetached}
     * @return activations keyed by vertex name
     */
    public Map<String, INDArray> feedForward(INDArray[] iNDArrayArr, boolean z, boolean z2) {
        setInputs(iNDArrayArr);
        try {
            int lastIndex = this.vertices.length - 1;
            Map<String, INDArray> activations = ffToLayerActivationsDetached(z, FwdPassType.STANDARD, false, lastIndex, null, iNDArrayArr, this.inputMaskArrays, this.labelMaskArrays, z2);
            return activations;
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Feed-forward on the currently set inputs; equivalent to {@code feedForward(false)}.
     *
     * @return activations keyed by vertex name
     */
    public Map<String, INDArray> feedForward() {
        final boolean flag = false;
        return feedForward(flag);
    }

    /**
     * Runs a detached forward pass on the currently set inputs, up to index
     * {@code vertices.length - 1}. If an {@link OutOfMemoryError} is raised, a memory crash dump
     * is written and the error is rethrown.
     *
     * @param z selects the training (true) or inference (false) workspace mode in
     *          {@code ffToLayerActivationsDetached}
     * @return activations keyed by vertex name
     */
    public Map<String, INDArray> feedForward(boolean z) {
        try {
            int lastIndex = this.vertices.length - 1;
            Map<String, INDArray> activations = ffToLayerActivationsDetached(z, FwdPassType.STANDARD, false, lastIndex, null, this.inputs, this.inputMaskArrays, this.labelMaskArrays, true);
            return activations;
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Runs a detached forward pass on the currently set inputs and optionally filters the result
     * down to layer and input vertices.
     *
     * @param z selects the training (true) or inference (false) workspace mode in
     *          {@code ffToLayerActivationsDetached}
     * @param z2 when true, the output layer indices are passed as the index array argument of
     *           {@code ffToLayerActivationsDetached}; otherwise null is passed
     * @param z3 when true, all activations are returned; when false, only entries whose vertex
     *           is a {@code LayerVertex} or an {@code InputVertex} are kept
     * @return activations keyed by vertex name
     */
    public Map<String, INDArray> feedForward(boolean z, boolean z2, boolean z3) {
        int[] indexArray = null;
        if (z2) {
            indexArray = getOutputLayerIndices();
        }
        Map<String, INDArray> allActivations = ffToLayerActivationsDetached(z, FwdPassType.STANDARD, false, this.vertices.length - 1, indexArray, this.inputs, this.inputMaskArrays, this.labelMaskArrays, true);
        if (z3) {
            return allActivations;
        }

        Map<String, INDArray> filtered = new HashMap<>();
        for (Map.Entry<String, INDArray> activation : allActivations.entrySet()) {
            String vertexName = activation.getKey();
            GraphVertex vertex = this.verticesMap.get(vertexName);
            boolean keep = vertex instanceof LayerVertex || vertex instanceof InputVertex;
            if (keep) {
                filtered.put(vertexName, activation.getValue());
            }
        }
        return filtered;
    }

    /**
     * Computes the network outputs for the given inputs, passing {@code false} as the leading
     * flag and the currently stored input and label mask arrays.
     *
     * @param iNDArrayArr network inputs
     * @return the network output arrays
     */
    public INDArray[] output(INDArray... iNDArrayArr) {
        INDArray[] outputs = output(false, iNDArrayArr, this.inputMaskArrays, this.labelMaskArrays);
        return outputs;
    }

    /**
     * Single-output convenience overload; equivalent to {@code outputSingle(false, iNDArrayArr)}.
     *
     * @param iNDArrayArr network inputs
     * @return the single network output array
     */
    public INDArray outputSingle(INDArray... iNDArrayArr) {
        INDArray single = outputSingle(false, iNDArrayArr);
        return single;
    }

    /**
     * Computes the network outputs without supplying an output workspace (null is passed).
     *
     * @param z flag forwarded to the {@code (boolean, MemoryWorkspace, INDArray...)} overload
     * @param iNDArrayArr network inputs
     * @return the network output arrays
     */
    public INDArray[] output(boolean z, INDArray... iNDArrayArr) {
        final MemoryWorkspace noWorkspace = null;
        return output(z, noWorkspace, iNDArrayArr);
    }

    /**
     * Computes the network outputs using the currently stored input and label mask arrays and
     * the given workspace.
     *
     * @param z flag forwarded to the five-argument {@code output} overload
     * @param memoryWorkspace workspace forwarded to the five-argument overload; may be null
     * @param iNDArrayArr network inputs
     * @return the network output arrays
     */
    public INDArray[] output(boolean z, MemoryWorkspace memoryWorkspace, INDArray... iNDArrayArr) {
        INDArray[] outputs = output(z, iNDArrayArr, this.inputMaskArrays, this.labelMaskArrays, memoryWorkspace);
        return outputs;
    }

    /**
     * Computes the network outputs with one set of mask arrays; null is passed for the second
     * mask set.
     *
     * @param z flag forwarded to the four-argument {@code output} overload
     * @param iNDArrayArr network inputs; must not be null
     * @param iNDArrayArr2 first mask array set; may be null
     * @return the network output arrays
     * @throws NullPointerException if {@code iNDArrayArr} is null
     */
    public INDArray[] output(boolean z, @NonNull INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        if (iNDArrayArr == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        final INDArray[] noSecondMasks = null;
        return output(z, iNDArrayArr, iNDArrayArr2, noSecondMasks);
    }

    /**
     * Computes the network outputs with the given mask arrays and no output workspace (null is
     * passed).
     *
     * @param z flag forwarded to the five-argument {@code output} overload
     * @param iNDArrayArr network inputs; must not be null
     * @param iNDArrayArr2 first mask array set; may be null
     * @param iNDArrayArr3 second mask array set; may be null
     * @return the network output arrays
     * @throws NullPointerException if {@code iNDArrayArr} is null
     */
    public INDArray[] output(boolean z, @NonNull INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3) {
        if (iNDArrayArr == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        final MemoryWorkspace noWorkspace = null;
        return output(z, iNDArrayArr, iNDArrayArr2, iNDArrayArr3, noWorkspace);
    }

    /**
     * Computes the network outputs inside a dedicated workspace and converts them with the given
     * adapter.
     *
     * <p>The workspace is obtained from
     * {@code Nd4j.getWorkspaceManager().getAndActivateWorkspace(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)},
     * passed to the five-argument {@code output} overload, and closed after the adapter has been
     * applied. The adapter therefore runs while the workspace is still open. The workspace is
     * closed exactly once, whether the call completes normally or throws.
     *
     * @param iNDArrayArr network inputs; must not be null
     * @param iNDArrayArr2 first mask array set; may be null
     * @param iNDArrayArr3 second mask array set; may be null
     * @param outputAdapter converts the output arrays to the result type; must not be null
     * @param <T> result type produced by the adapter
     * @return the adapter's result
     * @throws NullPointerException if {@code iNDArrayArr} or {@code outputAdapter} is null
     */
    public synchronized <T> T output(@NonNull INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, @NonNull OutputAdapter<T> outputAdapter) {
        if (iNDArrayArr == null) {
            throw new NullPointerException("inputs is marked @NonNull but is null");
        }
        if (outputAdapter == null) {
            throw new NullPointerException("outputAdapter is marked @NonNull but is null");
        }
        try (MemoryWorkspace outputWorkspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)) {
            return outputAdapter.apply(output(false, iNDArrayArr, iNDArrayArr2, iNDArrayArr3, outputWorkspace));
        }
    }

    /**
     * Computes the network outputs for the given inputs and mask arrays.
     *
     * <p>Sets the layer mask arrays, obtains the outputs of the output layers through
     * {@code outputOfLayersDetached}, then clears the layer mask arrays and layer states. If an
     * {@link OutOfMemoryError} is raised, a memory crash dump is written and the error is
     * rethrown.
     *
     * @param z flag forwarded as the first argument of {@code outputOfLayersDetached}
     * @param iNDArrayArr network inputs; must not be null
     * @param iNDArrayArr2 first mask array set; may be null
     * @param iNDArrayArr3 second mask array set; may be null
     * @param memoryWorkspace workspace forwarded to {@code outputOfLayersDetached}; may be null
     * @return the network output arrays
     * @throws NullPointerException if {@code iNDArrayArr} is null
     */
    public synchronized INDArray[] output(boolean z, @NonNull INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, MemoryWorkspace memoryWorkspace) {
        if (iNDArrayArr == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        final INDArray[] outputs;
        try {
            setLayerMaskArrays(iNDArrayArr2, iNDArrayArr3);
            int[] outputIndices = getOutputLayerIndices();
            outputs = outputOfLayersDetached(z, FwdPassType.STANDARD, outputIndices, iNDArrayArr, iNDArrayArr2, iNDArrayArr3, true, false, memoryWorkspace);
            clearLayerMaskArrays();
            clearLayersStates();
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return outputs;
    }

    /**
     * Single-output convenience overload; equivalent to
     * {@code outputSingle(z, true, iNDArrayArr)}.
     *
     * @param z flag forwarded to the three-argument overload
     * @param iNDArrayArr network inputs
     * @return the single network output array
     */
    public INDArray outputSingle(boolean z, INDArray... iNDArrayArr) {
        final boolean secondFlag = true;
        return outputSingle(z, secondFlag, iNDArrayArr);
    }

    /**
     * Returns the first (and only) output array of a graph with exactly one output.
     *
     * @param z flag forwarded to {@code output(boolean, boolean, INDArray...)}
     * @param z2 flag forwarded to {@code output(boolean, boolean, INDArray...)}
     * @param iNDArrayArr network inputs
     * @return the single network output array
     * @throws IllegalStateException if the graph does not have exactly one output array
     */
    public INDArray outputSingle(boolean z, boolean z2, INDArray... iNDArrayArr) {
        boolean singleOutput = this.numOutputArrays == 1;
        if (!singleOutput) {
            String message = "Cannot use outputSingle with ComputationGraph that does not have exactly 1 output. nOutputs: " + this.numOutputArrays;
            throw new IllegalStateException(message);
        }
        INDArray[] outputs = output(z, z2, iNDArrayArr);
        return outputs[0];
    }

    /**
     * Computes the outputs of the output layers with no mask arrays and no output workspace.
     * If an {@link OutOfMemoryError} is raised, a memory crash dump is written and the error is
     * rethrown.
     *
     * @param z flag forwarded as the first argument of {@code outputOfLayersDetached}
     * @param z2 forwarded to {@code outputOfLayersDetached} as one boolean argument, with its
     *           negation passed as the next one
     * @param iNDArrayArr network inputs
     * @return the network output arrays
     */
    public synchronized INDArray[] output(boolean z, boolean z2, INDArray... iNDArrayArr) {
        try {
            int[] outputIndices = getOutputLayerIndices();
            boolean negated = !z2;
            INDArray[] outputs = outputOfLayersDetached(z, FwdPassType.STANDARD, outputIndices, iNDArrayArr, null, null, z2, negated, null);
            return outputs;
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Computes outputs for every batch of a {@link DataSetIterator} by wrapping it in a
     * {@link MultiDataSetIteratorAdapter} and delegating to
     * {@link #output(MultiDataSetIterator)}.
     *
     * @param dataSetIterator source of input data
     * @return the merged network output arrays
     */
    public INDArray[] output(DataSetIterator dataSetIterator) {
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(dataSetIterator);
        return output(adapted);
    }

    /**
     * Computes outputs for every remaining batch of the iterator and merges the per-batch
     * results with {@code DataSetUtil.mergeFeatures}.
     *
     * <p>The iterator is consumed from its current position; it is not reset here.
     *
     * @param multiDataSetIterator source of input data
     * @return the first element of the pair returned by {@code DataSetUtil.mergeFeatures}
     */
    public INDArray[] output(MultiDataSetIterator multiDataSetIterator) {
        List<INDArray[]> perBatchOutputs = new ArrayList<>();
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet batch = (MultiDataSet) multiDataSetIterator.next();
            INDArray[] batchOutputs = output(false, batch.getFeatures(), batch.getFeaturesMaskArrays(), batch.getLabelsMaskArrays());
            perBatchOutputs.add(batchOutputs);
        }
        INDArray[][] stacked = perBatchOutputs.toArray(new INDArray[perBatchOutputs.size()][0]);
        return (INDArray[]) DataSetUtil.mergeFeatures(stacked, (INDArray[][]) null).getFirst();
    }

    /**
     * Computes outputs for every batch of the iterator and returns the first output array.
     * The graph is validated to have exactly one output array.
     *
     * @param dataSetIterator source of input data
     * @return the single merged network output array
     */
    public INDArray outputSingle(DataSetIterator dataSetIterator) {
        boolean singleOutput = this.numOutputArrays == 1;
        Preconditions.checkArgument(singleOutput, "Cannot use this method with nets that have more than 1 output array. This network has %s outputs", this.numOutputArrays);
        INDArray[] outputs = output(dataSetIterator);
        return outputs[0];
    }

    /**
     * Computes outputs for every batch of the iterator and returns the first output array.
     * The graph is validated to have exactly one output array.
     *
     * @param multiDataSetIterator source of input data
     * @return the single merged network output array
     */
    public INDArray outputSingle(MultiDataSetIterator multiDataSetIterator) {
        boolean singleOutput = this.numOutputArrays == 1;
        Preconditions.checkArgument(singleOutput, "Cannot use this method with nets that have more than 1 output array. This network has %s outputs", this.numOutputArrays);
        INDArray[] outputs = output(multiDataSetIterator);
        return outputs[0];
    }

    /**
     * Computes the activations of the named vertices.
     *
     * <p>Each name is resolved to its vertex index through {@code verticesMap}; the indices are
     * then passed to {@code outputOfLayersDetached}. The name list is validated to be non-null
     * and non-empty, and every name is validated to exist in the network.
     *
     * @param list names of the vertices whose outputs are requested
     * @param z flag forwarded as the first argument of {@code outputOfLayersDetached}
     * @param iNDArrayArr network inputs
     * @param iNDArrayArr2 first mask array set; may be null
     * @return one array per requested name, as returned by {@code outputOfLayersDetached}
     */
    public INDArray[] output(List<String> list, boolean z, INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        Preconditions.checkState(list != null && list.size() > 0, "Layers must not be null: got later names %s", list);
        final int count = list.size();
        int[] vertexIndices = new int[count];
        int pos = 0;
        while (pos < list.size()) {
            String vertexName = list.get(pos);
            Preconditions.checkState(this.verticesMap.containsKey(vertexName), "Layer with name %s not found in network", vertexName);
            GraphVertex vertex = this.verticesMap.get(vertexName);
            vertexIndices[pos] = vertex.getVertexIndex();
            pos++;
        }
        return outputOfLayersDetached(z, FwdPassType.STANDARD, vertexIndices, iNDArrayArr, iNDArrayArr2, null, true, false, null);
    }

    /**
     * Validates an array's workspace location through
     * {@code LayerWorkspaceMgr.validateArrayLocation}, translating a failure into an
     * {@link IllegalStateException} that names the vertex and its class.
     *
     * @param layerWorkspaceMgr workspace manager performing the validation
     * @param iNDArray array to validate
     * @param arrayType array type the array is validated as
     * @param str name of the vertex the array belongs to; used for the error message
     * @param z forwarded as the last argument of {@code validateArrayLocation}
     * @param str2 prefix of the error message
     * @throws IllegalStateException if validation raises an {@code ND4JWorkspaceException}
     *                               (kept as the cause)
     */
    protected void validateArrayWorkspaces(LayerWorkspaceMgr layerWorkspaceMgr, INDArray iNDArray, ArrayType arrayType, String str, boolean z, String str2) {
        try {
            layerWorkspaceMgr.validateArrayLocation(arrayType, iNDArray, false, z);
        } catch (ND4JWorkspaceException cause) {
            GraphVertex vertex = this.verticesMap.get(str);
            // Layer vertices are reported by their layer's class, all others by the vertex class.
            String className;
            if (vertex instanceof LayerVertex) {
                className = vertex.getLayer().getClass().getSimpleName();
            } else {
                className = vertex.getClass().getSimpleName();
            }
            String message = str2 + ": array (" + arrayType + ") workspace validation failed (vertex " + str + " - class: " + className + ") - array is defined in incorrect workspace";
            throw new IllegalStateException(message, cause);
        }
    }

    /**
     * Runs a forward pass over the vertices at topological positions {@code 0..i} (inclusive) and returns the
     * activations keyed by vertex name. The workspace manager is built with no workspace for
     * {@code ArrayType.ACTIVATIONS}, and no workspace may be open when this method is called.
     *
     * @param z            {@code true} selects the training workspace mode, {@code false} the inference one;
     *                     also passed to every {@code doForward} / {@code rnnActivateUsingStoredState} call
     * @param fwdPassType  kind of forward pass: {@code STANDARD}, {@code RNN_TIMESTEP} or
     *                     {@code RNN_ACTIVATE_WITH_STORED_STATE}
     * @param z2           forwarded unchanged to {@code rnnActivateUsingStoredState}
     *                     (presumably "store last state for TBPTT" - confirm against that API)
     * @param i            last position in {@code topologicalOrder} to evaluate (inclusive)
     * @param iArr         vertex indices to skip entirely, or {@code null} to skip none
     * @param iNDArrayArr  network input arrays
     * @param iNDArrayArr2 first argument of {@code setLayerMaskArrays}
     * @param iNDArrayArr3 second argument of {@code setLayerMaskArrays}
     * @param z3           if {@code true}, {@code clear()} is called on each vertex after it has been evaluated
     * @return map from network-input names and evaluated vertex names to their arrays
     * @throws NullPointerException     if {@code fwdPassType} or {@code iNDArrayArr} is null
     * @throws IllegalArgumentException if {@code i} is out of range or the forward pass type is unsupported
     */
    protected synchronized Map<String, INDArray> ffToLayerActivationsDetached(boolean z, @NonNull FwdPassType fwdPassType, boolean z2, int i, int[] iArr, @NonNull INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, boolean z3) {
        if (fwdPassType == null) {
            throw new NullPointerException("fwdPassType is marked @NonNull but is null");
        }
        if (iNDArrayArr == null) {
            throw new NullPointerException("features is marked @NonNull but is null");
        }
        if (i < 0 || i >= this.topologicalOrder.length) {
            throw new IllegalArgumentException("Invalid layer index - index must be >= 0 and < " + this.topologicalOrder.length + ", got index " + i);
        }
        setInputs(iNDArrayArr);
        setLayerMaskArrays(iNDArrayArr2, iNDArrayArr3);
        WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active before call to ffToLayerActivationsDetached");

        final WorkspaceMode workspaceMode = z ? this.configuration.getTrainingWorkspaceMode() : this.configuration.getInferenceWorkspaceMode();
        final LayerWorkspaceMgr workspaceMgr;
        if (workspaceMode == WorkspaceMode.NONE) {
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .noWorkspaceFor(ArrayType.ACTIVATIONS)
                    .with(ArrayType.INPUT, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            if (iNDArrayArr[0].isAttached()) {
                // The first input lives in a workspace: register that workspace id as a no-leverage override.
                workspaceMgr.setNoLeverageOverride(iNDArrayArr[0].data().getParentWorkspace().getId());
            }
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        final Map<String, INDArray> activations = new HashMap<>();
        for (int in = 0; in < iNDArrayArr.length; in++) {
            activations.put(this.configuration.getNetworkInputs().get(in), iNDArrayArr[in]);
        }

        for (int step = 0; step <= i; step++) {
            final GraphVertex current = this.vertices[this.topologicalOrder[step]];
            final String vertexName = current.getVertexName();
            final int vertexIndex = current.getVertexIndex();
            if (iArr != null && ArrayUtils.contains(iArr, vertexIndex)) {
                continue; // explicitly excluded vertex
            }

            // The FF working-memory scope is closed automatically, including on failure; an exception
            // from close() is attached as suppressed instead of replacing the primary one.
            try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                final VertexIndices[] consumers = current.getOutputVertices();
                final INDArray out;
                if (current.isInputVertex()) {
                    out = this.inputs[vertexIndex];
                } else {
                    if (fwdPassType == FwdPassType.STANDARD) {
                        out = current.doForward(z, workspaceMgr);
                    } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
                        if (current.hasLayer()) {
                            final INDArray stepInput = current.getInputs()[0];
                            final Layer layer = current.getLayer();
                            if (layer instanceof RecurrentLayer) {
                                out = ((RecurrentLayer) layer).rnnTimeStep(reshapeTimeStepInput(stepInput), workspaceMgr);
                            } else if (layer instanceof MultiLayerNetwork) {
                                out = ((MultiLayerNetwork) layer).rnnTimeStep(reshapeTimeStepInput(stepInput));
                            } else {
                                out = current.doForward(z, workspaceMgr);
                            }
                        } else {
                            out = current.doForward(z, workspaceMgr);
                        }
                    } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                        if (current.hasLayer()) {
                            final Layer layer = current.getLayer();
                            if (layer instanceof RecurrentLayer) {
                                out = ((RecurrentLayer) layer).rnnActivateUsingStoredState(current.getInputs()[0], z, z2, workspaceMgr);
                            } else if (layer instanceof MultiLayerNetwork) {
                                // Nested network: its last returned activation is this vertex's output.
                                final List<INDArray> nested = ((MultiLayerNetwork) layer).rnnActivateUsingStoredState(current.getInputs()[0], z, z2);
                                out = nested.get(nested.size() - 1);
                            } else {
                                out = current.doForward(z, workspaceMgr);
                            }
                        } else {
                            out = current.doForward(z, workspaceMgr);
                        }
                    } else {
                        throw new IllegalArgumentException("Unsupported forward pass type for this method: " + fwdPassType);
                    }
                    validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vertexName, false, "Feed forward (inference)");
                }
                activations.put(current.getVertexName(), out);

                // Hand the result to every vertex that consumes it (null means no consumers).
                if (consumers != null) {
                    for (VertexIndices consumer : consumers) {
                        this.vertices[consumer.getVertexIndex()].setInput(consumer.getVertexEdgeNumber(), out, workspaceMgr);
                    }
                }
                if (z3) {
                    current.clear();
                }
            }
        }
        return activations;
    }

    /**
     * Runs a forward pass and returns the activations keyed by vertex name, placing activations and inputs
     * in the {@code WS_ALL_LAYERS_ACT} workspace. Unless the selected workspace mode is {@code NONE}, that
     * workspace must already be open and active when this method is called.
     *
     * @param z            {@code true} selects the training workspace mode, {@code false} the inference one;
     *                     also passed to every {@code doForward} / {@code rnnActivateUsingStoredState} call
     * @param i            vertex index to stop at (its position in {@code topologicalOrder} is the last one
     *                     evaluated), or {@code -1} to evaluate all vertices
     * @param iArr         vertex indices to skip entirely, or {@code null} to skip none
     * @param fwdPassType  {@code STANDARD} or {@code RNN_ACTIVATE_WITH_STORED_STATE}
     * @param z2           forwarded unchanged to {@code rnnActivateUsingStoredState}
     *                     (presumably "store last state for TBPTT" - confirm against that API)
     * @param iNDArrayArr  network input arrays
     * @param iNDArrayArr2 first argument of {@code setLayerMaskArrays}
     * @param iNDArrayArr3 second argument of {@code setLayerMaskArrays}
     * @param z3           if {@code true}, {@code clear()} is called on each vertex after it has been evaluated
     * @return map from evaluated vertex names to their activation arrays
     * @throws IllegalArgumentException if {@code i} is neither {@code -1} nor a valid index
     * @throws IllegalStateException    if the forward pass type is not supported by this method
     */
    protected synchronized Map<String, INDArray> ffToLayerActivationsInWS(boolean z, int i, int[] iArr, FwdPassType fwdPassType, boolean z2, INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, boolean z3) {
        if (i != -1 && (i < 0 || i >= this.topologicalOrder.length)) {
            throw new IllegalArgumentException("Invalid input index - index must be >= 0 and < " + this.topologicalOrder.length + ", got index " + i);
        }
        setInputs(iNDArrayArr);
        setLayerMaskArrays(iNDArrayArr2, iNDArrayArr3);

        final WorkspaceMode workspaceMode = z ? this.configuration.getTrainingWorkspaceMode() : this.configuration.getInferenceWorkspaceMode();
        final LayerWorkspaceMgr workspaceMgr;
        if (workspaceMode == WorkspaceMode.NONE) {
            WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in ffToLayerActivationsInWS");
            workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, "ffToLayerActivationsInWs method requires workspace WS_ALL_LAYERS_ACT to be open");
            workspaceMgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
            if (iNDArrayArr[0].isAttached()) {
                // The first input lives in a workspace: register that workspace id as a no-leverage override.
                workspaceMgr.setNoLeverageOverride(iNDArrayArr[0].data().getParentWorkspace().getId());
            }
            if (this.configuration.getCacheMode() != CacheMode.NONE) {
                // With caching enabled, cached arrays share the all-layers activation workspace.
                workspaceMgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG);
            }
        }
        workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

        final Map<String, INDArray> activations = new HashMap<>();

        // NOTE(review): the test is "i > 0", so vertex index 0 is treated like the -1 sentinel
        // (full pass). Kept as-is because callers may rely on it - confirm whether ">= 0" was intended.
        final int stopIndex = i > 0 ? ArrayUtils.indexOf(this.topologicalOrder, i) : this.topologicalOrder.length - 1;

        for (int step = 0; step <= stopIndex; step++) {
            final GraphVertex current = this.vertices[this.topologicalOrder[step]];
            final String vertexName = current.getVertexName();
            final int vertexIndex = current.getVertexIndex();
            if (iArr != null && ArrayUtils.contains(iArr, vertexIndex)) {
                continue; // explicitly excluded vertex
            }

            // The FF working-memory scope is closed automatically, including on failure; an exception
            // from close() is attached as suppressed instead of replacing the primary one.
            try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                final VertexIndices[] consumers = current.getOutputVertices();
                final INDArray out;
                if (current.isInputVertex()) {
                    out = this.inputs[vertexIndex];
                } else {
                    if (fwdPassType == FwdPassType.STANDARD) {
                        out = current.doForward(z, workspaceMgr);
                    } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) {
                        if (current.hasLayer()) {
                            final Layer layer = current.getLayer();
                            if (layer instanceof RecurrentLayer) {
                                out = ((RecurrentLayer) layer).rnnActivateUsingStoredState(current.getInputs()[0], z, z2, workspaceMgr);
                            } else if (layer instanceof MultiLayerNetwork) {
                                // Nested network: its last returned activation is this vertex's output.
                                final List<INDArray> nested = ((MultiLayerNetwork) layer).rnnActivateUsingStoredState(current.getInputs()[0], z, z2);
                                out = nested.get(nested.size() - 1);
                            } else {
                                out = current.doForward(z, workspaceMgr);
                            }
                        } else {
                            out = current.doForward(z, workspaceMgr);
                        }
                    } else {
                        throw new IllegalStateException("FwdPassType not supported for this method: " + fwdPassType);
                    }
                    validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vertexName, false, "Feed forward (inference)");
                }
                activations.put(current.getVertexName(), out);

                // Hand the result to every vertex that consumes it (null means no consumers).
                if (consumers != null) {
                    for (VertexIndices consumer : consumers) {
                        this.vertices[consumer.getVertexIndex()].setInput(consumer.getVertexEdgeNumber(), out, workspaceMgr);
                    }
                }
                if (z3) {
                    current.clear();
                }
            }
        }
        return activations;
    }

    /**
     * Computes the activations of the requested vertices, evaluating only as much of the graph (in
     * topological order) as needed to reach the last requested vertex.
     *
     * <p>Unless the selected workspace mode is {@code NONE}, each vertex's activations are placed in a
     * {@code WS_LAYER_ACT_n} workspace that stays open until the last vertex consuming those activations
     * has been evaluated; its manager is then reused for later vertices. Activations of requested vertices
     * are instead written to the caller-supplied output workspace, or scoped out of workspaces when none
     * is supplied.
     *
     * @param z               {@code true} selects the training workspace mode, {@code false} the inference
     *                        one; also passed to every {@code doForward} call
     * @param fwdPassType     {@code STANDARD} or {@code RNN_TIMESTEP}
     * @param iArr            vertex indices whose activations are requested
     * @param iNDArrayArr     network input arrays; length must equal {@code numInputArrays}
     * @param iNDArrayArr2    first argument of {@code setLayerMaskArrays}
     * @param iNDArrayArr3    second argument of {@code setLayerMaskArrays}
     * @param z2              if {@code true}, {@code clear()} is called on each vertex after evaluation
     * @param z3              if {@code true}, newly built workspace managers scope out {@code INPUT} and
     *                        {@code ACTIVATIONS}; otherwise the workspace of the first input (if attached)
     *                        is registered as a no-leverage override
     * @param memoryWorkspace optional workspace for the returned arrays; when given (and not a
     *                        {@code DummyWorkspace}) it must already be open, and closing it remains the
     *                        caller's responsibility
     * @return the requested activations, in the same order as {@code iArr}
     * @throws NullPointerException     if {@code fwdPassType}, {@code iArr} or {@code iNDArrayArr} is null
     * @throws IllegalArgumentException if the number of inputs is wrong, a requested index is out of
     *                                  range, or the forward pass type is unsupported
     * @throws IllegalStateException    if a consumer vertex is missing from the topological order, or
     *                                  the workspace pre/post-conditions do not hold
     */
    protected INDArray[] outputOfLayersDetached(boolean z, @NonNull FwdPassType fwdPassType, @NonNull int[] iArr, @NonNull INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, boolean z2, boolean z3, MemoryWorkspace memoryWorkspace) {
        if (fwdPassType == null) {
            throw new NullPointerException("fwdPassType is marked @NonNull but is null");
        }
        if (iArr == null) {
            throw new NullPointerException("layerIndexes is marked @NonNull but is null");
        }
        if (iNDArrayArr == null) {
            throw new NullPointerException("features is marked @NonNull but is null");
        }
        if (iNDArrayArr.length != this.numInputArrays) {
            throw new IllegalArgumentException("Invalid number of input arrays: network has " + this.numInputArrays + " inputs, got " + iNDArrayArr.length + " input arrays");
        }
        final int topoLength = this.topologicalOrder.length;
        for (int requested : iArr) {
            if (requested < 0 || requested >= topoLength) {
                throw new IllegalArgumentException("Invalid input index - index must be >= 0 and < " + topoLength + ", got index " + requested);
            }
        }
        setInputs(iNDArrayArr);
        setLayerMaskArrays(iNDArrayArr2, iNDArrayArr3);

        // A DummyWorkspace is treated exactly like "no output workspace supplied".
        final boolean hasOutputWorkspace = memoryWorkspace != null && !(memoryWorkspace instanceof DummyWorkspace);
        MemoryWorkspace outputPrevious = null;
        if (hasOutputWorkspace) {
            Preconditions.checkState(memoryWorkspace.isScopeActive(), "Workspace \"" + memoryWorkspace.getId() + "\" was provided for the network/layer outputs. When provided, this workspace must be opened before calling the output method; furthermore, closing the workspace is the responsibility of the user");
            outputPrevious = memoryWorkspace.getParentWorkspace();
        } else {
            WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active before call to outputOfLayersDetached");
        }

        // For every vertex: the topological position of its last consumer, i.e. the step after which
        // its activations workspace may be closed. Vertices without consumers are kept to the end.
        final int[] lastConsumerStep = new int[topoLength];
        for (GraphVertex vertex : this.vertices) {
            final VertexIndices[] consumers = vertex.getOutputVertices();
            int lastStep;
            if (consumers == null || consumers.length == 0) {
                lastStep = topoLength - 1;
            } else {
                lastStep = -1;
                for (VertexIndices consumer : consumers) {
                    final int position = ArrayUtils.indexOf(this.topologicalOrder, consumer.getVertexIndex());
                    if (position == -1) {
                        throw new IllegalStateException("Did not find vertex " + consumer.getVertexIndex() + " in topological sort array");
                    }
                    lastStep = Math.max(lastStep, position);
                }
            }
            lastConsumerStep[vertex.getVertexIndex()] = lastStep;
        }

        final INDArray[] outputs = new INDArray[iArr.length];

        // Only evaluate up to the last requested vertex.
        int stopIndex = -1;
        for (int requested : iArr) {
            stopIndex = Math.max(stopIndex, ArrayUtils.indexOf(this.topologicalOrder, requested));
        }

        final List<LayerWorkspaceMgr> allWorkspaceManagers = new ArrayList<>();
        final List<LayerWorkspaceMgr> freeWorkspaceManagers = new ArrayList<>();
        // Identity semantics: the workspace instance itself is the key.
        final Map<MemoryWorkspace, LayerWorkspaceMgr> openActivationsWorkspaces = new IdentityHashMap<>();
        final WorkspaceMode workspaceMode = z ? this.configuration.getTrainingWorkspaceMode() : this.configuration.getInferenceWorkspaceMode();
        final boolean noWorkspaces = workspaceMode == WorkspaceMode.NONE;
        final LayerWorkspaceMgr allNone = noWorkspaces ? LayerWorkspaceMgr.noWorkspaces(this.helperWorkspaces) : null;
        @SuppressWarnings("unchecked") // generic array creation; only List<MemoryWorkspace> is ever stored
        final List<MemoryWorkspace>[] closeAtEndOfStep = (List<MemoryWorkspace>[]) new List[topoLength];
        final MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();

        try {
            for (int step = 0; step <= stopIndex; step++) {
                final GraphVertex current = this.vertices[this.topologicalOrder[step]];
                final String vertexName = current.getVertexName();
                final int vertexIndex = current.getVertexIndex();

                // Pick a workspace manager: the shared no-op one, a recycled one, or a freshly built one.
                final LayerWorkspaceMgr workspaceMgr;
                if (noWorkspaces) {
                    workspaceMgr = allNone;
                } else if (!freeWorkspaceManagers.isEmpty()) {
                    workspaceMgr = freeWorkspaceManagers.remove(freeWorkspaceManagers.size() - 1);
                } else {
                    final String wsName = "WS_LAYER_ACT_" + allWorkspaceManagers.size();
                    workspaceMgr = LayerWorkspaceMgr.builder()
                            .with(ArrayType.INPUT, wsName, this.WS_LAYER_ACT_X_CONFIG)
                            .with(ArrayType.ACTIVATIONS, wsName, this.WS_LAYER_ACT_X_CONFIG)
                            .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                            .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                            .build();
                    if (z3) {
                        workspaceMgr.setScopedOutFor(ArrayType.INPUT);
                        workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
                    } else if (iNDArrayArr[0].isAttached()) {
                        workspaceMgr.setNoLeverageOverride(iNDArrayArr[0].data().getParentWorkspace().getId());
                    }
                    allWorkspaceManagers.add(workspaceMgr);
                }
                workspaceMgr.setHelperWorkspacePointers(this.helperWorkspaces);

                // For a requested vertex, temporarily redirect ACTIVATIONS to the output workspace
                // (or scope it out), remembering the original setting so it can be restored below.
                final boolean isRequiredOutput = ArrayUtils.contains(iArr, vertexIndex);
                String origActivationsWsName = null;
                WorkspaceConfiguration origActivationsWsConfig = null;
                if (isRequiredOutput) {
                    if (hasOutputWorkspace) {
                        origActivationsWsName = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
                        origActivationsWsConfig = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
                        workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, memoryWorkspace.getId(), memoryWorkspace.getWorkspaceConfiguration());
                    } else if (!workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS)) {
                        origActivationsWsName = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
                        origActivationsWsConfig = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
                        workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
                    }
                }

                // Open the activations scope, except when the result goes into the already-open
                // caller-supplied output workspace.
                MemoryWorkspace activationsWs = null;
                if (!hasOutputWorkspace || !isRequiredOutput) {
                    activationsWs = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
                    openActivationsWorkspaces.put(activationsWs, workspaceMgr);
                }
                if (activationsWs != null) {
                    activationsWs.setPreviousWorkspace(initialWorkspace);
                }

                // Schedule the activations workspace to be closed once its last consumer has run.
                final int closeableAt = lastConsumerStep[vertexIndex];
                if (!hasOutputWorkspace || (activationsWs != null && !memoryWorkspace.getId().equals(activationsWs.getId()))) {
                    if (closeAtEndOfStep[closeableAt] == null) {
                        closeAtEndOfStep[closeableAt] = new ArrayList<>();
                    }
                    closeAtEndOfStep[closeableAt].add(activationsWs);
                }

                try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                    final VertexIndices[] consumers = current.getOutputVertices();
                    final INDArray out;
                    if (current.isInputVertex()) {
                        out = iNDArrayArr[vertexIndex];
                    } else {
                        if (fwdPassType == FwdPassType.STANDARD) {
                            out = current.doForward(z, workspaceMgr);
                        } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
                            if (current.hasLayer()) {
                                final INDArray stepInput = current.getInputs()[0];
                                final Layer layer = current.getLayer();
                                if (layer instanceof RecurrentLayer) {
                                    out = ((RecurrentLayer) layer).rnnTimeStep(reshapeTimeStepInput(stepInput), workspaceMgr);
                                } else if (layer instanceof MultiLayerNetwork) {
                                    out = ((MultiLayerNetwork) layer).rnnTimeStep(reshapeTimeStepInput(stepInput));
                                } else {
                                    out = current.doForward(z, workspaceMgr);
                                }
                            } else {
                                out = current.doForward(z, workspaceMgr);
                            }
                        } else {
                            throw new IllegalArgumentException("Unsupported forward pass type for this method: " + fwdPassType);
                        }
                        validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vertexName, false, "Feed forward (inference)");
                    }

                    // Hand the result to every vertex that consumes it (null means no consumers).
                    if (consumers != null) {
                        for (VertexIndices consumer : consumers) {
                            this.vertices[consumer.getVertexIndex()].setInput(consumer.getVertexEdgeNumber(), out, workspaceMgr);
                        }
                    }
                    if (z2) {
                        current.clear();
                    }
                    if (isRequiredOutput) {
                        outputs[ArrayUtils.indexOf(iArr, vertexIndex)] = out;
                        if (origActivationsWsName != null) {
                            // Undo the temporary ACTIVATIONS redirection so the manager can be reused.
                            workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origActivationsWsName, origActivationsWsConfig);
                        }
                    }
                }

                // Close activations workspaces whose contents have now been fully consumed and
                // return their managers to the free pool.
                final List<MemoryWorkspace> finished = closeAtEndOfStep[step];
                if (finished != null) {
                    for (MemoryWorkspace ws : finished) {
                        ws.close();
                        freeWorkspaceManagers.add(openActivationsWorkspaces.remove(ws));
                    }
                }
            }
        } finally {
            // Cleanup runs once, after the whole pass (also on failure). Running it per vertex would
            // close activations workspaces before their consumers have been evaluated.
            for (MemoryWorkspace ws : openActivationsWorkspaces.keySet()) {
                while (ws.isScopeActive()) {
                    ws.close();
                }
            }
            Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
            if (hasOutputWorkspace) {
                Preconditions.checkState(memoryWorkspace.isScopeActive(), "Expected output workspace to still be open at end of outputOfLayersDetached, but it is closed");
                memoryWorkspace.setPreviousWorkspace(outputPrevious);
            } else {
                WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayersDetached");
            }
        }
        return outputs;
    }

    /**
     * Returns a rank-3 view of a rank-2 time-step input by appending a trailing dimension of size 1.
     * Arrays of any other rank are returned as they are.
     *
     * @param iNDArray input for a single time step
     * @return {@code iNDArray} reshaped to {@code [shape[0], shape[1], 1]} when it has rank 2, otherwise
     *         {@code iNDArray} itself
     */
    private INDArray reshapeTimeStepInput(INDArray iNDArray) {
        if (iNDArray.rank() != 2) {
            return iNDArray;
        }
        final long[] dims = iNDArray.shape();
        return iNDArray.reshape(new long[]{dims[0], dims[1], 1});
    }

    /**
     * Runs backpropagation using the supplied epsilon arrays and returns the resulting gradient.
     *
     * <p>If an {@link OutOfMemoryError} occurs, a memory crash dump is written through
     * {@code CrashReportingUtil} before the error is rethrown.
     *
     * @param iNDArrayArr epsilons, exactly one per network output array
     * @return the {@code gradient} field of this network after the backward pass
     * @throws IllegalArgumentException if {@code iNDArrayArr} is null or its length differs from the number
     *                                  of output arrays
     */
    public Gradient backpropGradient(INDArray... iNDArrayArr) {
        final boolean epsilonCountMatches = iNDArrayArr != null && iNDArrayArr.length == this.numOutputArrays;
        if (!epsilonCountMatches) {
            throw new IllegalArgumentException("Invalid input: must have epsilons length equal to number of output arrays");
        }
        try {
            final boolean truncatedBptt = BackpropType.TruncatedBPTT == this.configuration.getBackpropType();
            calcBackpropGradients(true, truncatedBptt, iNDArrayArr);
            return this.gradient;
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    protected void calcBackpropGradients(boolean z, boolean z2, INDArray... iNDArrayArr) {
        LayerWorkspaceMgr build;
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        if (iNDArrayArr == null || (iNDArrayArr.length == 0 && this.configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE)) {
            WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, "Expected workspace WS_ALL_LAYERS_ACT to be active and open in calcBackpropGradients when workspace mode is not set to NONE");
        }
        if (iNDArrayArr != null && iNDArrayArr.length > 0) {
            for (String str : this.configuration.getNetworkOutputs()) {
                GraphVertex vertex = getVertex(str);
                if ((vertex instanceof LayerVertex) && (((LayerVertex) vertex).getLayer() instanceof IOutputLayer)) {
                    throw new IllegalStateException("Cannot perform backprop with external errors in conjunction with an output layer: output layers cannot use external errors for backprop. Layer name: " + str);
                }
            }
        }
        int[] iArr = new int[this.topologicalOrder.length];
        for (GraphVertex graphVertex : this.vertices) {
            int vertexIndex = graphVertex.getVertexIndex();
            int i = Integer.MAX_VALUE;
            VertexIndices[] inputVertices = graphVertex.getInputVertices();
            if (inputVertices != null) {
                for (VertexIndices vertexIndices : inputVertices) {
                    int indexOf = ArrayUtils.indexOf(this.topologicalOrder, vertexIndices.getVertexIndex());
                    if (indexOf == -1) {
                        throw new IllegalStateException("Did not find vertex " + vertexIndices.getVertexIndex() + " in topological sort array");
                    }
                    i = Math.min(i, indexOf);
                }
            }
            if (i == Integer.MAX_VALUE) {
                iArr[vertexIndex] = 0;
            } else {
                iArr[vertexIndex] = i;
            }
        }
        boolean z3 = this.configuration.getInferenceWorkspaceMode() == WorkspaceMode.NONE;
        LayerWorkspaceMgr noWorkspaces = z3 ? LayerWorkspaceMgr.noWorkspaces(this.helperWorkspaces) : null;
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        IdentityHashMap identityHashMap = new IdentityHashMap();
        List[] listArr = new List[this.topologicalOrder.length];
        LinkedList linkedList = new LinkedList();
        boolean[] zArr = new boolean[this.topologicalOrder.length];
        MemoryWorkspace currentWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        try {
            for (int length = this.topologicalOrder.length - 1; length >= 0; length--) {
                GraphVertex graphVertex2 = this.vertices[this.topologicalOrder[length]];
                int vertexIndex2 = graphVertex2.getVertexIndex();
                String vertexName = graphVertex2.getVertexName();
                boolean z4 = (graphVertex2.hasLayer() && (graphVertex2.getLayer() instanceof FrozenLayer)) || (graphVertex2 instanceof FrozenVertex);
                if (graphVertex2.isInputVertex() || z4) {
                    if (listArr[length] != null) {
                        for (MemoryWorkspace memoryWorkspace : listArr[length]) {
                            memoryWorkspace.close();
                            arrayList2.add((LayerWorkspaceMgr) identityHashMap.remove(memoryWorkspace));
                        }
                    }
                    listArr[length] = null;
                } else {
                    if (z3) {
                        build = noWorkspaces;
                    } else if (arrayList2.size() > 0) {
                        build = (LayerWorkspaceMgr) arrayList2.remove(arrayList2.size() - 1);
                    } else {
                        build = LayerWorkspaceMgr.builder().with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG).with(ArrayType.ACTIVATION_GRAD, "WS_LAYER_ACT_" + arrayList.size(), this.WS_LAYER_ACT_X_CONFIG).with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG).with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG).build();
                        arrayList.add(build);
                    }
                    build.setHelperWorkspacePointers(this.helperWorkspaces);
                    if (graphVertex2.isOutputVertex()) {
                        int indexOf2 = this.configuration.getNetworkOutputs().indexOf(graphVertex2.getVertexName());
                        Layer layer = graphVertex2.getLayer();
                        if (layer instanceof FrozenLayerWithBackprop) {
                            layer = ((FrozenLayerWithBackprop) layer).getInsideLayer();
                        }
                        if (layer instanceof IOutputLayer) {
                            ((IOutputLayer) layer).setLabels(this.labels[indexOf2]);
                        } else {
                            if ((iNDArrayArr == null || iNDArrayArr.length == 0) && this.labels[indexOf2] != null) {
                                throw new DL4JException("Layer \"" + graphVertex2.getVertexName() + "\" of type " + graphVertex2.getLayer().getClass().getSimpleName() + " is set as network output (but isn't an IOutputLayer). Only IOutputLayer layers can be fit via backprop with a labels array. ");
                            }
                            graphVertex2.setEpsilon(iNDArrayArr[indexOf2]);
                            zArr[this.topologicalOrder[length]] = true;
                        }
                    }
                    MemoryWorkspace notifyScopeEntered = build.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
                    identityHashMap.put(notifyScopeEntered, build);
                    notifyScopeEntered.setPreviousWorkspace(currentWorkspace);
                    int i2 = iArr[vertexIndex2];
                    if (i2 >= 0) {
                        if (listArr[i2] == null) {
                            listArr[i2] = new ArrayList();
                        }
                        listArr[i2].add(notifyScopeEntered);
                    }
                    MemoryWorkspace notifyScopeEntered2 = build.notifyScopeEntered(ArrayType.BP_WORKING_MEM);
                    Throwable th = null;
                    try {
                        try {
                            Pair<Gradient, INDArray[]> doBackward = graphVertex2.doBackward(z2, build);
                            INDArray[] iNDArrayArr2 = (INDArray[]) doBackward.getSecond();
                            for (INDArray iNDArray : iNDArrayArr2) {
                                if (iNDArray != null) {
                                    validateArrayWorkspaces(build, iNDArray, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop");
                                }
                            }
                            if (notifyScopeEntered2 != null) {
                                if (0 != 0) {
                                    try {
                                        notifyScopeEntered2.close();
                                    } catch (Throwable th2) {
                                        th.addSuppressed(th2);
                                    }
                                } else {
                                    notifyScopeEntered2.close();
                                }
                            }
                            VertexIndices[] inputVertices2 = graphVertex2.getInputVertices();
                            if (inputVertices2 != null) {
                                int i3 = 0;
                                for (VertexIndices vertexIndices2 : inputVertices2) {
                                    GraphVertex graphVertex3 = this.vertices[vertexIndices2.getVertexIndex()];
                                    if (zArr[graphVertex3.getVertexIndex()]) {
                                        int i4 = i3;
                                        i3++;
                                        graphVertex3.setEpsilon(graphVertex3.getEpsilon().addi(iNDArrayArr2[i4]));
                                    } else {
                                        int i5 = i3;
                                        i3++;
                                        graphVertex3.setEpsilon(iNDArrayArr2[i5]);
                                    }
                                    zArr[graphVertex3.getVertexIndex()] = true;
                                }
                            }
                            if (doBackward.getFirst() != null) {
                                Gradient gradient = (Gradient) doBackward.getFirst();
                                Map<String, INDArray> gradientForVariable = gradient.gradientForVariable();
                                LinkedList linkedList2 = new LinkedList();
                                for (Map.Entry<String, INDArray> entry : gradientForVariable.entrySet()) {
                                    String key = entry.getKey();
                                    linkedList2.addFirst(new Triple(graphVertex2.getVertexName() + "_" + key, entry.getValue(), gradient.flatteningOrderForVariable(key)));
                                }
                                Iterator it = linkedList2.iterator();
                                while (it.hasNext()) {
                                    linkedList.addFirst((Triple) it.next());
                                }
                            }
                            if (listArr[length] != null) {
                                for (MemoryWorkspace memoryWorkspace2 : listArr[length]) {
                                    memoryWorkspace2.close();
                                    arrayList2.add((LayerWorkspaceMgr) identityHashMap.remove(memoryWorkspace2));
                                }
                                listArr[length] = null;
                            }
                        } catch (Throwable th3) {
                            th = th3;
                            throw th3;
                        }
                    } finally {
                    }
                }
            }
            DefaultGradient defaultGradient = new DefaultGradient(this.flattenedGradients);
            Iterator it2 = linkedList.iterator();
            while (it2.hasNext()) {
                Triple triple = (Triple) it2.next();
                defaultGradient.setGradientFor((String) triple.getFirst(), (INDArray) triple.getSecond(), (Character) triple.getThird());
            }
            this.gradient = defaultGradient;
            if (z2 && this.clearTbpttState) {
                rnnClearPreviousState();
            }
            if (z) {
                for (GraphVertex graphVertex4 : this.vertices) {
                    graphVertex4.clear();
                }
            }
        } finally {
            Iterator it3 = identityHashMap.keySet().iterator();
            while (it3.hasNext()) {
                ((MemoryWorkspace) it3.next()).close();
            }
            Nd4j.getMemoryManager().setCurrentWorkspace(currentWorkspace);
        }
    }

    /**
     * Builds a copy of this graph: the configuration is cloned, the parameters are duplicated,
     * and (when a solver exists and the updater has state) the updater state view is duplicated.
     * Vertices whose layer is a {@code FrozenLayer} here are marked frozen in the copy as well.
     *
     * <p>Note that the listener collection is assigned by reference, not copied, so the copy
     * and this instance share the same listener collection object.
     *
     * @return the newly created graph
     */
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public ComputationGraph m112clone() {
        ComputationGraph copy = new ComputationGraph(this.configuration.m26clone());
        copy.init(params().dup(), false);
        if (this.solver != null) {
            INDArray updaterState = getUpdater().getStateViewArray();
            if (updaterState != null) {
                copy.getUpdater().setStateViewArray(updaterState.dup());
            }
        }
        copy.trainingListeners = this.trainingListeners;
        for (int vertexIdx : this.topologicalOrder) {
            if (!this.vertices[vertexIdx].hasLayer()) {
                continue;
            }
            String layerVertexName = this.vertices[vertexIdx].getVertexName();
            if (getLayer(layerVertexName) instanceof FrozenLayer) {
                copy.getVertex(layerVertexName).setLayerAsFrozen();
            }
        }
        return copy;
    }

    /**
     * Sums {@code calcL2(true)} over every layer of the network.
     *
     * @return the accumulated value (0.0 when there are no layers)
     */
    public double calcL2() {
        double total = 0.0d;
        for (int k = 0; k < this.layers.length; k++) {
            total += this.layers[k].calcL2(true);
        }
        return total;
    }

    /**
     * Sums {@code calcL1(true)} over every layer of the network.
     *
     * @return the accumulated value (0.0 when there are no layers)
     */
    public double calcL1() {
        double total = 0.0d;
        for (int k = 0; k < this.layers.length; k++) {
            total += this.layers[k].calcL1(true);
        }
        return total;
    }

    /**
     * Replaces the training listeners of this network, of every layer, and of the solver
     * (when one exists). Initializes the network first if its layers have not been created yet.
     *
     * <p>The argument is snapshotted before the internal listener collection is cleared, so
     * passing the very collection returned by {@link #getListeners()} keeps its contents
     * instead of silently removing every listener.
     *
     * @param collection the new listeners; {@code null} results in an empty listener collection
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setListeners(Collection<TrainingListener> collection) {
        if (this.layers == null) {
            init();
        }
        // Copy up front: 'collection' may alias this.trainingListeners, which is cleared below.
        List<TrainingListener> snapshot = null;
        if (collection != null) {
            snapshot = new ArrayList<TrainingListener>(collection);
        }
        for (Layer layer : this.layers) {
            layer.setListeners(collection);
        }
        if (this.solver != null) {
            this.solver.setListeners(collection);
        }
        this.trainingListeners.clear();
        if (snapshot != null) {
            this.trainingListeners.addAll(snapshot);
        }
    }

    /**
     * Varargs convenience overload: drops {@code null} entries and delegates to
     * {@link #setListeners(Collection)}. A {@code null} or empty array results in an empty list
     * being passed on.
     *
     * @param trainingListenerArr the listeners to install; may be {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setListeners(TrainingListener... trainingListenerArr) {
        List<TrainingListener> nonNullListeners = new ArrayList<TrainingListener>();
        if (trainingListenerArr != null) {
            for (int k = 0; k < trainingListenerArr.length; k++) {
                TrainingListener candidate = trainingListenerArr[k];
                if (candidate == null) {
                    continue;
                }
                nonNullListeners.add(candidate);
            }
        }
        setListeners(nonNullListeners);
    }

    /**
     * Adds the given listeners to the ones already registered and re-applies the combined set
     * through {@link #setListeners(Collection)}.
     *
     * <p>A {@code null} array is treated as "nothing to add" and {@code null} elements are
     * skipped, matching the filtering done by {@link #setListeners(TrainingListener...)}.
     *
     * @param trainingListenerArr the listeners to add; may be {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void addListeners(TrainingListener... trainingListenerArr) {
        if (this.trainingListeners == null) {
            setListeners(trainingListenerArr);
            return;
        }
        // Work on a copy: setListeners(Collection) clears this.trainingListeners.
        List<TrainingListener> combined = new ArrayList<TrainingListener>(this.trainingListeners);
        if (trainingListenerArr != null) {
            for (TrainingListener listener : trainingListenerArr) {
                if (listener != null) {
                    combined.add(listener);
                }
            }
        }
        setListeners(combined);
        if (this.solver != null) {
            this.solver.setListeners(this.trainingListeners);
        }
    }

    /**
     * Returns the internal listener collection itself (not a copy).
     *
     * @return the training listeners currently held by this network
     */
    public Collection<TrainingListener> getListeners() {
        Collection<TrainingListener> current = this.trainingListeners;
        return current;
    }

    /**
     * Returns the updater, creating the solver and updater first if they do not exist yet.
     * Equivalent to {@code getUpdater(true)}.
     *
     * @return the updater of this network
     */
    public ComputationGraphUpdater getUpdater() {
        final boolean initializeIfAbsent = true;
        return getUpdater(initializeIfAbsent);
    }

    /**
     * Returns the updater held by the solver's optimizer.
     *
     * @param z when {@code true} and no solver exists yet, a solver is built and a new
     *          {@code ComputationGraphUpdater} is installed on its optimizer
     * @return the updater, or {@code null} if no solver exists and {@code z} is {@code false}
     */
    public ComputationGraphUpdater getUpdater(boolean z) {
        if (this.solver == null) {
            if (!z) {
                return null;
            }
            this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
            this.solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this));
        }
        return this.solver.getOptimizer().getComputationGraphUpdater();
    }

    /**
     * Installs the given updater on the solver's optimizer, building the solver first if needed.
     *
     * @param computationGraphUpdater the updater to install
     */
    public void setUpdater(ComputationGraphUpdater computationGraphUpdater) {
        boolean solverMissing = this.solver == null;
        if (solverMissing) {
            Solver.Builder solverBuilder = new Solver.Builder();
            this.solver = solverBuilder.configure(conf()).listeners(getListeners()).model(this).build();
        }
        this.solver.getOptimizer().setUpdaterComputationGraph(computationGraphUpdater);
    }

    /**
     * Looks up the layer behind the i-th configured network output.
     *
     * @param i index into the configured network outputs
     * @return the layer registered under that output's name
     * @throws IllegalArgumentException if {@code i} is not smaller than the number of outputs
     */
    public Layer getOutputLayer(int i) {
        if (i < this.numOutputArrays) {
            String outputName = this.configuration.getNetworkOutputs().get(i);
            return getLayer(outputName);
        }
        throw new IllegalArgumentException("Invalid index: cannot get output layer " + i + ", total number of network outputs = " + this.numOutputArrays);
    }

    /**
     * Returns the network parameters.
     *
     * @param z when {@code true}, the stored flattened parameter array is returned directly;
     *          otherwise the non-null parameter arrays of all layer vertices are collected in
     *          topological order and flattened in 'f' order
     * @return the parameters as described above
     */
    public INDArray params(boolean z) {
        if (z) {
            return this.flattenedParams;
        }
        List<INDArray> perLayerParams = new ArrayList<INDArray>(this.layers.length);
        for (int vertexIdx : this.topologicalOrder) {
            if (!this.vertices[vertexIdx].hasLayer()) {
                continue;
            }
            INDArray layerParams = this.vertices[vertexIdx].getLayer().params();
            if (layerParams != null) {
                perLayerParams.add(layerParams);
            }
        }
        return Nd4j.toFlattened('f', perLayerParams);
    }

    /**
     * Scores the given data set; equivalent to {@code score(dataSet, false)}.
     *
     * @param dataSet the data to score
     * @return the score
     */
    public double score(DataSet dataSet) {
        final boolean flag = false;
        return score(dataSet, flag);
    }

    /**
     * Scores a single-input, single-output data set by converting it to a {@code MultiDataSet}.
     *
     * @param dataSet the data to score
     * @param z       forwarded unchanged to {@link #score(MultiDataSet, boolean)}
     * @return the score
     * @throws UnsupportedOperationException if the network does not have exactly one input and
     *                                       one output array
     */
    public double score(DataSet dataSet, boolean z) {
        boolean singleInSingleOut = this.numInputArrays == 1 && this.numOutputArrays == 1;
        if (!singleInSingleOut) {
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with  DataSet: network does not have 1 input and 1 output arrays");
        }
        return score(ComputationGraphUtil.toMultiDataSet(dataSet), z);
    }

    /**
     * Scores the given multi data set; equivalent to {@code score(multiDataSet, false)}.
     *
     * @param multiDataSet the data to score
     * @return the score
     */
    public double score(MultiDataSet multiDataSet) {
        final boolean flag = false;
        return score(multiDataSet, flag);
    }

    /**
     * Scores the given multi data set via {@code scoreHelper}. If an {@link OutOfMemoryError}
     * occurs, a memory crash dump is written before the error is rethrown.
     *
     * @param multiDataSet the data to score
     * @param z            forwarded unchanged to the helper
     * @return the score
     */
    public double score(MultiDataSet multiDataSet, boolean z) {
        double result;
        try {
            result = scoreHelper(multiDataSet, z);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return result;
    }

    /**
     * Runs a forward pass on the given data and sums the scores of all network outputs.
     *
     * <p>The workspace mode is taken from the training configuration when {@code z} is
     * {@code true} and from the inference configuration otherwise. The L1/L2 terms are passed
     * to the first output only; later outputs receive 0 for both.
     *
     * <p>The ACTIVATIONS workspace scope is managed with try-with-resources so that it is
     * closed exactly once on every path, including the early return below.
     *
     * @param multiDataSet features, labels and optional masks
     * @param z            selects the workspace mode and is forwarded to the forward pass and
     *                     to {@code computeScore}
     * @return the summed score, or {@code EvaluationBinary.DEFAULT_EDGE_VALUE} (after a logged
     *         warning) if any network output vertex is not an {@code IOutputLayer}
     */
    private double scoreHelper(MultiDataSet multiDataSet, boolean z) {
        boolean workspacesDisabled = (z ? this.configuration.getTrainingWorkspaceMode() : this.configuration.getInferenceWorkspaceMode()) == WorkspaceMode.NONE;
        LayerWorkspaceMgr mgr;
        if (workspacesDisabled) {
            mgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            mgr = LayerWorkspaceMgr.builder()
                    .noWorkspaceFor(ArrayType.ACTIVATIONS)
                    .noWorkspaceFor(ArrayType.INPUT)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
        }
        mgr.setHelperWorkspacePointers(this.helperWorkspaces);
        if (multiDataSet.hasMaskArrays()) {
            setLayerMaskArrays(multiDataSet.getFeaturesMaskArrays(), multiDataSet.getLabelsMaskArrays());
        }
        double totalScore = 0.0d;
        setInputs(multiDataSet.getFeatures());
        try (MemoryWorkspace activationsScope = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
            ffToLayerActivationsDetached(z, FwdPassType.STANDARD, false, this.vertices.length - 1, getOutputLayerIndices(), multiDataSet.getFeatures(), multiDataSet.getFeaturesMaskArrays(), multiDataSet.getLabelsMaskArrays(), false);
            INDArray[] labels = multiDataSet.getLabels();
            setLabels(labels);
            double l1 = calcL1();
            double l2 = calcL2();
            int outputIdx = 0;
            for (String outputName : this.configuration.getNetworkOutputs()) {
                GraphVertex outputVertex = this.verticesMap.get(outputName);
                Layer outputLayer = outputVertex.getLayer();
                // instanceof is false for null, so this also covers vertices without a layer.
                if (!(outputLayer instanceof IOutputLayer)) {
                    log.warn("Cannot calculate score: vertex \"" + outputName + "\" is not an output layer");
                    // NOTE(review): constant name was substituted by the decompiler; presumably 0.0 - confirm.
                    return EvaluationBinary.DEFAULT_EDGE_VALUE;
                }
                ((IOutputLayer) outputLayer).setLabels(labels[outputIdx++]);
                totalScore += ((LayerVertex) outputVertex).computeScore(l1, l2, z, mgr);
                // Regularization terms are only counted once, with the first output.
                l1 = 0.0d;
                l2 = 0.0d;
            }
            clearLayersStates();
            return totalScore;
        }
    }

    /**
     * Computes per-example scores for a single-input, single-output data set by converting it
     * to a {@code MultiDataSet}.
     *
     * @param dataSet the data to score
     * @param z       forwarded unchanged to {@link #scoreExamples(MultiDataSet, boolean)}
     * @return the per-example scores
     * @throws UnsupportedOperationException if the network does not have exactly one input and
     *                                       one output array
     */
    public INDArray scoreExamples(DataSet dataSet, boolean z) {
        boolean singleInSingleOut = this.numInputArrays == 1 && this.numOutputArrays == 1;
        if (!singleInSingleOut) {
            throw new UnsupportedOperationException("Cannot score ComputationGraph network with  DataSet: network does not have 1 input and 1 output arrays");
        }
        return scoreExamples(ComputationGraphUtil.toMultiDataSet(dataSet), z);
    }

    /**
     * Computes per-example scores via {@code scoreExamplesHelper}. If an
     * {@link OutOfMemoryError} occurs, a memory crash dump is written before the error is
     * rethrown.
     *
     * @param multiDataSet the data to score
     * @param z            forwarded unchanged to the helper
     * @return the per-example scores
     */
    public INDArray scoreExamples(MultiDataSet multiDataSet, boolean z) {
        INDArray result;
        try {
            result = scoreExamplesHelper(multiDataSet, z);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return result;
    }

    /**
     * Runs a forward pass and accumulates the per-example scores of all network outputs.
     *
     * <p>The workspace configuration is chosen from the inference workspace mode. The score of
     * the first output is detached and used as the accumulator; scores of later outputs are
     * added to it in place. When {@code z} is {@code true} the L1/L2 terms are passed to the
     * first output only.
     *
     * <p>Both workspace scopes are managed with try-with-resources so each is closed exactly
     * once, whether the body completes normally or throws.
     *
     * @param multiDataSet features, labels and optional masks
     * @param z            whether to compute and pass the L1/L2 terms
     * @return the accumulated per-example scores, or {@code null} if the configuration lists
     *         no network outputs
     * @throws UnsupportedOperationException if a network output vertex is not an
     *                                       {@code IOutputLayer}
     */
    private INDArray scoreExamplesHelper(MultiDataSet multiDataSet, boolean z) {
        LayerWorkspaceMgr mgr;
        if (this.configuration.getInferenceWorkspaceMode() == WorkspaceMode.NONE) {
            mgr = LayerWorkspaceMgr.noWorkspaces();
        } else {
            mgr = LayerWorkspaceMgr.builder()
                    .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG)
                    .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, this.WS_LAYER_WORKING_MEM_CONFIG)
                    .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                    .build();
        }
        mgr.setHelperWorkspacePointers(this.helperWorkspaces);
        if (multiDataSet.hasMaskArrays()) {
            setLayerMaskArrays(multiDataSet.getFeaturesMaskArrays(), multiDataSet.getLabelsMaskArrays());
        }
        INDArray accumulated = null;
        setInputs(multiDataSet.getFeatures());
        try (MemoryWorkspace activationsScope = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) {
            ffToLayerActivationsInWS(false, this.vertices.length - 1, getOutputLayerIndices(), FwdPassType.STANDARD, false, multiDataSet.getFeatures(), multiDataSet.getFeaturesMaskArrays(), multiDataSet.getLabelsMaskArrays(), false);
            INDArray[] labels = multiDataSet.getLabels();
            setLabels(labels);
            // NOTE(review): constant name was substituted by the decompiler; presumably 0.0 - confirm.
            double l1 = z ? calcL1() : EvaluationBinary.DEFAULT_EDGE_VALUE;
            double l2 = z ? calcL2() : EvaluationBinary.DEFAULT_EDGE_VALUE;
            int outputIdx = 0;
            for (String outputName : this.configuration.getNetworkOutputs()) {
                GraphVertex outputVertex = this.verticesMap.get(outputName);
                Layer outputLayer = outputVertex.getLayer();
                // instanceof is false for null, so this also covers vertices without a layer.
                if (!(outputLayer instanceof IOutputLayer)) {
                    throw new UnsupportedOperationException("Cannot calculate score: vertex \"" + outputName + "\" is not an output layer");
                }
                ((IOutputLayer) outputLayer).setLabels(labels[outputIdx++]);
                INDArray layerScores;
                try (MemoryWorkspace workingMemScope = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) {
                    layerScores = ((LayerVertex) outputVertex).computeScoreForExamples(l1, l2, mgr);
                }
                // The working-memory scope is closed before the result is detached/accumulated,
                // matching the original statement order.
                if (accumulated == null) {
                    accumulated = layerScores.detach();
                } else {
                    accumulated.addi(layerScores);
                }
                // Regularization terms are only counted once, with the first output.
                l1 = 0.0d;
                l2 = 0.0d;
            }
            if (multiDataSet.hasMaskArrays()) {
                clearLayerMaskArrays();
            }
            clearLayersStates();
            return accumulated;
        }
    }

    /**
     * Fits the network using the inputs, labels and mask arrays currently stored on this
     * instance.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit() {
        fit(
                this.inputs,
                this.labels,
                this.inputMaskArrays,
                this.labelMaskArrays);
    }

    /**
     * Not implemented for this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void update(INDArray iNDArray, String str) {
        final String message = "Not implemented";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Applies the given gradient: every per-variable entry is stored in this network's gradient
     * map and forwarded to the owning layer, then the flat gradient array is installed as the
     * gradient view of all layers.
     *
     * <p>Gradient keys have the form {@code <layerName>_<paramName>}. The split is done at the
     * LAST underscore, exactly as in {@link #getParam(String)} and
     * {@link #setParam(String, INDArray)}, so layer names containing underscores resolve to the
     * correct parameter name.
     *
     * @param gradient the gradient to apply
     * @throws IllegalArgumentException if the flat gradient length differs from
     *                                  {@code numParams(true)}
     * @throws IllegalStateException    if a key contains no underscore separator
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void update(Gradient gradient) {
        if (gradient.gradient().length() != numParams(true)) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams(true));
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            INDArray value = entry.getValue();
            int separatorIdx = key.lastIndexOf('_');
            if (separatorIdx == -1) {
                throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\"");
            }
            String layerName = key.substring(0, separatorIdx);
            // Take everything after the last separator; splitting on every '_' and picking
            // element 1 breaks as soon as the layer name itself contains an underscore.
            String paramName = key.substring(separatorIdx + 1);
            this.gradient.gradientForVariable().put(key, value);
            getLayer(layerName).update(value, paramName);
        }
        setBackpropGradientsViewArray(gradient.gradient());
    }

    /**
     * Reports a one-time STANDALONE heartbeat event for this model. Subsequent calls are no-ops
     * because {@code initDone} is set on the first call.
     *
     * @param task not read by this method; the reported task is derived from this model
     */
    private void update(Task task) {
        if (!this.initDone) {
            this.initDone = true;
            Heartbeat reporter = Heartbeat.getInstance();
            Task modelTask = ModelSerializer.taskByModel(this);
            reporter.reportEvent(Event.STANDALONE, EnvironmentUtils.buildEnvironment(), modelTask);
        }
    }

    /**
     * Returns the score value currently stored on this network.
     *
     * @return the stored score
     */
    @Override // org.deeplearning4j.nn.api.Model
    public double score() {
        final double stored = this.score;
        return stored;
    }

    /**
     * Overwrites the stored score value.
     *
     * @param d the new score
     */
    public void setScore(double d) {
        final double newScore = d;
        this.score = newScore;
    }

    /**
     * Returns the network parameters; equivalent to {@code params(true)}.
     *
     * @return the parameters
     */
    @Override // org.deeplearning4j.nn.api.Model, org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray params() {
        final boolean flag = true;
        return params(flag);
    }

    /**
     * Returns the updater's state view array.
     *
     * @return the updater state view array, or {@code null} when {@link #getUpdater()} returns
     *         {@code null}
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray updaterState() {
        if (getUpdater() == null) {
            return null;
        }
        return getUpdater().getUpdaterStateViewArray();
    }

    /**
     * Returns the total parameter count; equivalent to {@code numParams(true)}.
     *
     * @return the number of parameters
     */
    @Override // org.deeplearning4j.nn.api.Model
    public long numParams() {
        final boolean flag = true;
        return numParams(flag);
    }

    /**
     * Sums {@code numParams(z)} over all layers.
     *
     * <p>The sum is accumulated in a {@code long}, matching the return type, so that networks
     * with more than {@code Integer.MAX_VALUE} parameters are counted correctly instead of
     * wrapping around through an {@code int} cast.
     *
     * @param z forwarded unchanged to each layer's {@code numParams}
     * @return the total number of parameters
     */
    @Override // org.deeplearning4j.nn.api.Model
    public long numParams(boolean z) {
        long total = 0L;
        for (Layer layer : this.layers) {
            total += layer.numParams(z);
        }
        return total;
    }

    /**
     * Sets the network parameters from a flat array.
     *
     * <ul>
     *   <li>If the argument is the stored flattened parameter array itself, nothing happens.</li>
     *   <li>If a flattened parameter array exists and has the same length, the values are
     *       assigned into it.</li>
     *   <li>Otherwise each layer (in topological order) with at least one parameter receives
     *       the interval of row 0 of the argument that corresponds to its parameter count.</li>
     * </ul>
     *
     * <p>The running offset is kept as a {@code long} so it cannot wrap for very large
     * parameter vectors.
     *
     * @param iNDArray the new parameters
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParams(INDArray iNDArray) {
        if (iNDArray == this.flattenedParams) {
            return;
        }
        if (this.flattenedParams != null && this.flattenedParams.length() == iNDArray.length()) {
            this.flattenedParams.assign(iNDArray);
            return;
        }
        long offset = 0L;
        for (int vertexIdx : this.topologicalOrder) {
            if (!this.vertices[vertexIdx].hasLayer()) {
                continue;
            }
            Layer layer = this.vertices[vertexIdx].getLayer();
            long layerParamCount = layer.numParams();
            if (layerParamCount > 0) {
                layer.setParams(iNDArray.get(new INDArrayIndex[]{NDArrayIndex.point(0L), NDArrayIndex.interval(offset, layerParamCount + offset)}));
                offset += layerParamCount;
            }
        }
    }

    /**
     * Not supported by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParamsViewArray(INDArray iNDArray) {
        final String message = "Not supported";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Returns the stored flattened gradient array.
     *
     * @return the flattened gradients (may be {@code null} if not yet initialized)
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray getGradientsViewArray() {
        INDArray gradientsView = this.flattenedGradients;
        return gradientsView;
    }

    /**
     * Hands each layer (in topological order) with at least one parameter the interval of row 0
     * of the given array that corresponds to its parameter count, as its gradient view.
     *
     * <p>The running offset is kept as a {@code long} so it cannot wrap for very large
     * gradient vectors.
     *
     * @param iNDArray the flat gradient array to slice
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        long offset = 0L;
        for (int vertexIdx : this.topologicalOrder) {
            if (!this.vertices[vertexIdx].hasLayer()) {
                continue;
            }
            Layer layer = this.vertices[vertexIdx].getLayer();
            long layerParamCount = layer.numParams();
            if (layerParamCount > 0) {
                layer.setBackpropGradientsViewArray(iNDArray.get(new INDArrayIndex[]{NDArrayIndex.point(0L), NDArrayIndex.interval(offset, offset + layerParamCount)}));
                offset += layerParamCount;
            }
        }
    }

    /**
     * Not supported by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        final String message = "Cannot pretrain ComputationGraph with single INDArray";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Returns the gradient object currently stored on this network.
     *
     * @return the stored gradient (may be {@code null})
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        Gradient stored = this.gradient;
        return stored;
    }

    /**
     * Bundles the stored gradient and the stored score into a pair.
     *
     * @return a pair of {@link #gradient()} and {@link #score()}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = gradient();
        Double currentScore = Double.valueOf(score());
        return new Pair<>(currentGradient, currentScore);
    }

    /**
     * Returns the size of dimension 0 of the first stored input array, narrowed to
     * {@code int}.
     *
     * @return size of dimension 0 of {@code inputs[0]}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int batchSize() {
        long firstDimension = this.inputs[0].size(0);
        return (int) firstDimension;
    }

    /**
     * Returns the default layer-level configuration stored on this network.
     *
     * @return the default configuration
     */
    @Override // org.deeplearning4j.nn.api.Model
    public NeuralNetConfiguration conf() {
        NeuralNetConfiguration defaults = this.defaultConfiguration;
        return defaults;
    }

    /**
     * Not supported by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setConf(NeuralNetConfiguration neuralNetConfiguration) {
        UnsupportedOperationException unsupported = new UnsupportedOperationException();
        throw unsupported;
    }

    /**
     * Returns the single input array of a one-input network.
     *
     * @return the first stored input, or {@code null} if no inputs are stored
     * @throws UnsupportedOperationException if the network does not have exactly one input
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray input() {
        if (this.numInputArrays != 1) {
            throw new UnsupportedOperationException("Cannot return single input: ComputationGraph  has multiple inputs");
        }
        return this.inputs == null ? null : this.inputs[0];
    }

    /**
     * Returns the optimizer of the current solver. The solver is dereferenced without a null
     * check.
     *
     * @return the solver's optimizer
     */
    @Override // org.deeplearning4j.nn.api.Model, org.deeplearning4j.nn.api.NeuralNetwork
    public ConvexOptimizer getOptimizer() {
        ConvexOptimizer optimizer = this.solver.getOptimizer();
        return optimizer;
    }

    /**
     * Looks up a parameter by its {@code <layerName>_<paramName>} key; the key is split at the
     * last underscore.
     *
     * @param str the combined key
     * @return the parameter array returned by the layer
     * @throws IllegalStateException if the key contains no underscore
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray getParam(String str) {
        int separatorIdx = str.lastIndexOf('_');
        if (separatorIdx == -1) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + str + "\"");
        }
        String layerName = str.substring(0, separatorIdx);
        String paramName = str.substring(separatorIdx + 1);
        return getLayer(layerName).getParam(paramName);
    }

    /**
     * Returns the parameter table; equivalent to {@code paramTable(false)}.
     *
     * @return map from {@code <vertexName>_<paramName>} to parameter array
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable() {
        final boolean flag = false;
        return paramTable(flag);
    }

    /**
     * Collects the parameters of all vertices into one insertion-ordered map. Each key is the
     * vertex name, an underscore, and the vertex-local parameter key.
     *
     * @param z forwarded unchanged to each vertex's {@code paramTable}
     * @return the combined parameter table
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable(boolean z) {
        Map<String, INDArray> combined = new LinkedHashMap<String, INDArray>();
        for (int k = 0; k < this.vertices.length; k++) {
            GraphVertex vertex = this.vertices[k];
            Map<String, INDArray> vertexParams = vertex.paramTable(z);
            for (Map.Entry<String, INDArray> param : vertexParams.entrySet()) {
                String qualifiedKey = vertex.getVertexName() + "_" + param.getKey();
                combined.put(qualifiedKey, param.getValue());
            }
        }
        return combined;
    }

    /**
     * Assigns new values to every parameter of the network.
     *
     * <p>The key set of the argument must equal the key set of the current parameter table, and
     * every new array must have the same shape as the parameter it replaces. All shapes are
     * validated before any value is assigned. On a shape mismatch the error message reports
     * both shapes.
     *
     * @param map new parameter values keyed like {@link #paramTable()}
     * @throws NullPointerException if {@code map} is {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParamTable(@NonNull Map<String, INDArray> map) {
        if (map == null) {
            throw new NullPointerException("paramTable is marked @NonNull but is null");
        }
        // Build the current table once and reuse it for the key check, validation and assignment.
        Map<String, INDArray> currentTable = paramTable();
        Preconditions.checkArgument(map.keySet().equals(currentTable.keySet()), "Cannot set param table: parameter set keys are not equal");
        for (Map.Entry<String, INDArray> current : currentTable.entrySet()) {
            String paramKey = current.getKey();
            long[] existingShape = current.getValue().shape();
            long[] replacementShape = map.get(paramKey).shape();
            Preconditions.checkState(Arrays.equals(existingShape, replacementShape), "Cannot set parameters: shape array for parameter \"%s\" does not match existing shape: parameter shape = %s, new param shape = %s", paramKey, existingShape, replacementShape);
        }
        for (Map.Entry<String, INDArray> current : currentTable.entrySet()) {
            current.getValue().assign(map.get(current.getKey()));
        }
    }

    /**
     * Sets a parameter by its {@code <layerName>_<paramName>} key; the key is split at the last
     * underscore.
     *
     * @param str      the combined key
     * @param iNDArray the new parameter value
     * @throws IllegalStateException if the key contains no underscore
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParam(String str, INDArray iNDArray) {
        int separatorIdx = str.lastIndexOf('_');
        if (separatorIdx == -1) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + str + "\"");
        }
        String layerName = str.substring(0, separatorIdx);
        String paramName = str.substring(separatorIdx + 1);
        getLayer(layerName).setParam(paramName, iNDArray);
    }

    /**
     * Drops the references to the stored inputs, labels and both sets of mask arrays.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void clear() {
        // Data arrays first, then their masks.
        this.inputs = null;
        this.labels = null;

        this.inputMaskArrays = null;
        this.labelMaskArrays = null;
    }

    /**
     * Forwards {@code applyConstraints(i, i2)} to every layer.
     *
     * @param i  first argument forwarded to each layer
     * @param i2 second argument forwarded to each layer
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void applyConstraints(int i, int i2) {
        for (int k = 0; k < this.layers.length; k++) {
            this.layers[k].applyConstraints(i, i2);
        }
    }

    /**
     * Performs an RNN time step without an output workspace ({@code null} is passed as the
     * workspace to the helper).
     *
     * <p>Like {@link #rnnTimeStep(MemoryWorkspace, INDArray...)} and the scoring entry points,
     * an {@link OutOfMemoryError} triggers a memory crash dump before being rethrown.
     *
     * @param iNDArrayArr the network inputs for this time step
     * @return the network outputs
     */
    public INDArray[] rnnTimeStep(INDArray... iNDArrayArr) {
        try {
            return rnnTimeStepHelper(null, iNDArrayArr);
        } catch (OutOfMemoryError e) {
            CrashReportingUtil.writeMemoryCrashDump(this, e);
            throw e;
        }
    }

    /**
     * Performs an RNN time step, passing the given workspace to the helper. If an
     * {@link OutOfMemoryError} occurs, a memory crash dump is written before the error is
     * rethrown.
     *
     * @param memoryWorkspace workspace forwarded to the helper; may be {@code null}
     * @param iNDArrayArr     the network inputs for this time step
     * @return the network outputs
     */
    public INDArray[] rnnTimeStep(MemoryWorkspace memoryWorkspace, INDArray... iNDArrayArr) {
        INDArray[] outputs;
        try {
            outputs = rnnTimeStepHelper(memoryWorkspace, iNDArrayArr);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
        return outputs;
    }

    /**
     * Runs a detached forward pass in RNN_TIMESTEP mode and returns the outputs.
     *
     * <p>If every input has rank 2, each output that has rank 3 with size 1 along dimension 2
     * is replaced by its tensor along dimensions {1, 0} at index 0. The stored inputs are
     * cleared before returning.
     *
     * @param memoryWorkspace workspace forwarded to the forward pass; may be {@code null}
     * @param iNDArrayArr     the network inputs for this time step
     * @return the (possibly reduced) network outputs
     */
    private INDArray[] rnnTimeStepHelper(MemoryWorkspace memoryWorkspace, INDArray... iNDArrayArr) {
        boolean allInputsRank2 = true;
        for (INDArray in : iNDArrayArr) {
            if (in.rank() != 2) {
                allInputsRank2 = false;
                break;
            }
        }
        INDArray[] outputs = outputOfLayersDetached(false, FwdPassType.RNN_TIMESTEP, getOutputLayerIndices(), iNDArrayArr, null, null, true, false, memoryWorkspace);
        if (allInputsRank2) {
            for (int k = 0; k < outputs.length; k++) {
                INDArray out = outputs[k];
                boolean singleStepRank3 = out.rank() == 3 && out.size(2) == 1;
                if (singleStepRank3) {
                    outputs[k] = out.tensorAlongDimension(0, new int[]{1, 0});
                }
            }
        }
        this.inputs = null;
        return outputs;
    }

    /**
     * Returns the stored RNN state of the layer at the given index in the layers array, looked
     * up via that layer's configured name.
     *
     * @param i index into the layers array
     * @return the state map, or {@code null} if that layer is not recurrent
     */
    public Map<String, INDArray> rnnGetPreviousState(int i) {
        String layerName = this.layers[i].conf().getLayer().getLayerName();
        return rnnGetPreviousState(layerName);
    }

    /**
     * Returns the stored RNN state of the layer held by the named vertex.
     *
     * @param str the vertex name
     * @return the state map, or {@code null} if the vertex has no layer or the layer is not a
     *         {@code RecurrentLayer}
     */
    public Map<String, INDArray> rnnGetPreviousState(String str) {
        Layer candidate = this.verticesMap.get(str).getLayer();
        // instanceof is false for null, so a missing layer also yields null.
        if (candidate instanceof RecurrentLayer) {
            return ((RecurrentLayer) candidate).rnnGetPreviousState();
        }
        return null;
    }

    /**
     * Collects the stored RNN state of every {@code RecurrentLayer} in the layers array, keyed
     * by the layer's configured name.
     *
     * @return map from layer name to that layer's state map
     */
    public Map<String, Map<String, INDArray>> rnnGetPreviousStates() {
        Map<String, Map<String, INDArray>> statesByLayer = new HashMap<String, Map<String, INDArray>>();
        for (int k = 0; k < this.layers.length; k++) {
            Layer candidate = this.layers[k];
            if (!(candidate instanceof RecurrentLayer)) {
                continue;
            }
            String layerName = candidate.conf().getLayer().getLayerName();
            statesByLayer.put(layerName, ((RecurrentLayer) candidate).rnnGetPreviousState());
        }
        return statesByLayer;
    }

    /**
     * Sets the stored RNN state of the layer at the given index in the layers array, looked up
     * via that layer's configured name.
     *
     * @param i   index into the layers array
     * @param map the state to install
     */
    public void rnnSetPreviousState(int i, Map<String, INDArray> map) {
        String layerName = this.layers[i].conf().getLayer().getLayerName();
        rnnSetPreviousState(layerName, map);
    }

    /**
     * Sets the stored RNN state of the layer held by the named vertex.
     *
     * @param str the vertex name
     * @param map the state to install
     * @throws UnsupportedOperationException if the vertex has no layer or the layer is not a
     *                                       {@code RecurrentLayer}
     */
    public void rnnSetPreviousState(String str, Map<String, INDArray> map) {
        Layer candidate = this.verticesMap.get(str).getLayer();
        // instanceof is false for null, so a missing layer is rejected as well.
        if (candidate instanceof RecurrentLayer) {
            ((RecurrentLayer) candidate).rnnSetPreviousState(map);
            return;
        }
        throw new UnsupportedOperationException("Layer \"" + str + "\" is not a recurrent layer. Cannot set state");
    }

    /**
     * Sets the stored RNN state for several layers at once by calling
     * {@link #rnnSetPreviousState(String, Map)} for every entry.
     *
     * @param map map from vertex name to the state to install
     */
    public void rnnSetPreviousStates(Map<String, Map<String, INDArray>> map) {
        Iterator<Map.Entry<String, Map<String, INDArray>>> entries = map.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry<String, Map<String, INDArray>> layerState = entries.next();
            rnnSetPreviousState(layerState.getKey(), layerState.getValue());
        }
    }

    /**
     * Clears the stored RNN state of every {@code RecurrentLayer} and every nested
     * {@code MultiLayerNetwork} in the layers array. Does nothing if the layers array has not
     * been created.
     */
    public void rnnClearPreviousState() {
        if (this.layers != null) {
            for (int k = 0; k < this.layers.length; k++) {
                Layer candidate = this.layers[k];
                if (candidate instanceof RecurrentLayer) {
                    ((RecurrentLayer) candidate).rnnClearPreviousState();
                    continue;
                }
                if (candidate instanceof MultiLayerNetwork) {
                    ((MultiLayerNetwork) candidate).rnnClearPreviousState();
                }
            }
        }
    }

    /**
     * Fits the network with truncated backpropagation through time.
     *
     * <p>The common time-series length is read from dimension 2 of every rank-3 array in the
     * first two arguments; if two such arrays disagree, a warning is logged and nothing is
     * fitted. The series is then processed in consecutive segments of the configured TBPTT
     * forward length (the last segment may be shorter). For each segment the sub-arrays are
     * installed as inputs, labels and masks, the solver is run, and the RNN state is updated
     * from the TBPTT state.
     *
     * <p>The solver is created lazily inside a scope-out-of-workspaces block; that scope is
     * managed with try-with-resources so it is closed exactly once.
     *
     * @param iNDArrayArr       presumably the input arrays (element 0 of each subset is passed
     *                          to {@code setInputs}) - confirm against {@code getSubsetsForTbptt}
     * @param iNDArrayArr2      presumably the label arrays
     * @param iNDArrayArr3      presumably the feature mask arrays
     * @param iNDArrayArr4      presumably the label mask arrays
     * @param layerWorkspaceMgr workspace manager handed to the solver
     */
    protected void doTruncatedBPTT(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, INDArray[] iNDArrayArr4, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        // -1 means "no rank-3 array seen yet".
        long timeSeriesLength = -1;
        for (INDArray array : iNDArrayArr) {
            if (array.rank() != 3) {
                continue;
            }
            if (timeSeriesLength == -1) {
                timeSeriesLength = array.size(2);
            } else if (timeSeriesLength != array.size(2)) {
                log.warn("Cannot do TBPTT with time series of different lengths");
                return;
            }
        }
        for (INDArray array : iNDArrayArr2) {
            if (array.rank() != 3) {
                continue;
            }
            if (timeSeriesLength == -1) {
                timeSeriesLength = array.size(2);
            } else if (timeSeriesLength != array.size(2)) {
                log.warn("Cannot do TBPTT with time series of different lengths");
                return;
            }
        }
        // NOTE(review): timeSeriesLength is still -1 here when no rank-3 array was supplied;
        // the segment arithmetic below is not guarded for that case.
        long fwdLength = this.configuration.getTbpttFwdLength();
        long segmentCount = timeSeriesLength / fwdLength;
        if (timeSeriesLength % fwdLength != 0) {
            segmentCount++;
        }
        rnnClearPreviousState();
        for (int i = 0; i < segmentCount; i++) {
            long segmentStart = i * fwdLength;
            long segmentEnd = Math.min(segmentStart + fwdLength, timeSeriesLength);
            List<INDArray[]> subsets = getSubsetsForTbptt((int) segmentStart, segmentEnd, iNDArrayArr, iNDArrayArr2, iNDArrayArr3, iNDArrayArr4);
            setInputs(subsets.get(0));
            setLabels(subsets.get(1));
            setLayerMaskArrays(subsets.get(2), subsets.get(3));
            if (this.solver == null) {
                try (MemoryWorkspace outOfWorkspaces = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
            }
            this.solver.optimize(layerWorkspaceMgr);
            rnnUpdateStateWithTBPTTState();
        }
        if (this.clearTbpttState) {
            rnnClearPreviousState();
        }
        clearLayerMaskArrays();
    }

    /**
     * Slices features, labels and masks down to the window {@code interval(i, j)}.
     *
     * <p>Rank-3 feature and label arrays are sliced along dimension 2; arrays of
     * any other rank are passed through unchanged. Non-null mask arrays are
     * sliced along dimension 1; null entries (and null mask array holders) stay
     * null.
     *
     * @param i            start index passed to {@code NDArrayIndex.interval}
     * @param j            end index passed to {@code NDArrayIndex.interval}
     * @param iNDArrayArr  feature arrays
     * @param iNDArrayArr2 label arrays
     * @param iNDArrayArr3 feature mask arrays, may be {@code null}
     * @param iNDArrayArr4 label mask arrays, may be {@code null}
     * @return a four-element list: feature subsets, label subsets, feature mask
     *         subsets, label mask subsets (the last two may be {@code null})
     */
    private List<INDArray[]> getSubsetsForTbptt(int i, long j, INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2, INDArray[] iNDArrayArr3, INDArray[] iNDArrayArr4) {
        INDArray[] featureSubsets = new INDArray[iNDArrayArr.length];
        INDArray[] labelSubsets = new INDArray[iNDArrayArr2.length];
        INDArray[] featureMaskSubsets = null;
        if (iNDArrayArr3 != null) {
            featureMaskSubsets = new INDArray[iNDArrayArr3.length];
        }
        INDArray[] labelMaskSubsets = null;
        if (iNDArrayArr4 != null) {
            labelMaskSubsets = new INDArray[iNDArrayArr4.length];
        }

        for (int idx = 0; idx < iNDArrayArr.length; idx++) {
            INDArray full = iNDArrayArr[idx];
            featureSubsets[idx] = full.rank() == 3
                    ? full.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, j)})
                    : full;
        }

        for (int idx = 0; idx < iNDArrayArr2.length; idx++) {
            INDArray full = iNDArrayArr2[idx];
            labelSubsets[idx] = full.rank() == 3
                    ? full.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(i, j)})
                    : full;
        }

        if (iNDArrayArr3 != null) {
            for (int idx = 0; idx < iNDArrayArr3.length; idx++) {
                INDArray mask = iNDArrayArr3[idx];
                if (mask == null) {
                    continue;
                }
                featureMaskSubsets[idx] = mask.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(i, j)});
            }
        }

        if (iNDArrayArr4 != null) {
            for (int idx = 0; idx < iNDArrayArr4.length; idx++) {
                INDArray mask = iNDArrayArr4[idx];
                if (mask == null) {
                    continue;
                }
                labelMaskSubsets[idx] = mask.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(i, j)});
            }
        }

        return Arrays.asList(featureSubsets, labelSubsets, featureMaskSubsets, labelMaskSubsets);
    }

    /**
     * Runs a detached forward pass of type
     * {@link FwdPassType#RNN_ACTIVATE_WITH_STORED_STATE} up to the last vertex,
     * using the currently set input and label mask arrays.
     *
     * @param iNDArrayArr input arrays for the forward pass
     * @param z           first boolean forwarded to {@code ffToLayerActivationsDetached}
     * @param z2          second boolean forwarded to {@code ffToLayerActivationsDetached}
     * @return the map returned by {@code ffToLayerActivationsDetached}
     */
    public Map<String, INDArray> rnnActivateUsingStoredState(INDArray[] iNDArrayArr, boolean z, boolean z2) {
        final int lastVertexIndex = this.vertices.length - 1;
        return ffToLayerActivationsDetached(z, FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE, z2, lastVertexIndex, null, iNDArrayArr, this.inputMaskArrays, this.labelMaskArrays, true);
    }

    /**
     * Installs feature and label mask arrays on the graph.
     *
     * <p>Any previously set masks are cleared first. Feature masks are then fed
     * forward through the vertices in topological order: input vertices start with
     * their own mask in state {@link MaskState#Active}, and every other vertex
     * receives the masks of its inputs via
     * {@code GraphVertex.feedForwardMaskArrays}. Label masks are set directly on
     * the layer of the corresponding network output vertex.
     *
     * @param iNDArrayArr  feature mask arrays, one slot per network input; may be
     *                     {@code null}, and individual entries may be {@code null}
     * @param iNDArrayArr2 label mask arrays, one slot per network output; may be
     *                     {@code null}, and individual entries may be {@code null}
     * @throws IllegalArgumentException if a non-null array's length does not match
     *                                  the number of network inputs / outputs
     */
    public void setLayerMaskArrays(INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        clearLayerMaskArrays();
        this.inputMaskArrays = iNDArrayArr;
        this.labelMaskArrays = iNDArrayArr2;

        if (iNDArrayArr != null) {
            if (iNDArrayArr.length != this.numInputArrays) {
                throw new IllegalArgumentException("Invalid number of feature mask arrays");
            }

            // size(0) of the last non-null mask; stays -1 when every entry is null.
            long minibatchSize = -1;
            for (INDArray mask : iNDArrayArr) {
                if (mask != null) {
                    minibatchSize = mask.size(0);
                }
            }

            // Mask and mask state produced by each vertex, keyed by vertex index.
            Map<Integer, Pair<INDArray, MaskState>> maskByVertex = new HashMap<>();
            for (int vertexIdx : this.topologicalOrder) {
                GraphVertex vertex = this.vertices[vertexIdx];
                if (vertex.isInputVertex()) {
                    maskByVertex.put(vertex.getVertexIndex(), new Pair<>(iNDArrayArr[vertex.getVertexIndex()], MaskState.Active));
                    continue;
                }

                VertexIndices[] inputVertices = vertex.getInputVertices();
                INDArray[] inputMasks = null;
                MaskState inputState = null;
                for (int in = 0; in < inputVertices.length; in++) {
                    Pair<INDArray, MaskState> fromInput = maskByVertex.get(inputVertices[in].getVertexIndex());
                    if (fromInput == null) {
                        continue;
                    }
                    if (inputMasks == null) {
                        inputMasks = new INDArray[inputVertices.length];
                    }
                    inputMasks[in] = fromInput.getFirst();
                    // Keep the first state seen, but let a later input replace Passthrough.
                    if (inputState == null || inputState == MaskState.Passthrough) {
                        inputState = fromInput.getSecond();
                    }
                }
                maskByVertex.put(vertexIdx, vertex.feedForwardMaskArrays(inputMasks, inputState, (int) minibatchSize));
            }
        }

        if (iNDArrayArr2 != null) {
            if (iNDArrayArr2.length != this.numOutputArrays) {
                throw new IllegalArgumentException("Invalid number of label mask arrays");
            }
            for (int out = 0; out < iNDArrayArr2.length; out++) {
                if (iNDArrayArr2[out] != null) {
                    this.verticesMap.get(this.configuration.getNetworkOutputs().get(out)).getLayer().setMaskArray(iNDArrayArr2[out]);
                }
            }
        }
    }

    /**
     * Removes all mask arrays: sets the mask of every layer to {@code null} and
     * resets the stored input and label mask array fields.
     */
    public void clearLayerMaskArrays() {
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].setMaskArray(null);
        }
        this.inputMaskArrays = null;
        this.labelMaskArrays = null;
    }

    /**
     * Copies the TBPTT state of each recurrent layer into its "previous state"
     * slot. Nested {@link MultiLayerNetwork} instances are asked to do the same
     * for themselves; all other layer types are skipped.
     */
    protected void rnnUpdateStateWithTBPTTState() {
        for (Layer current : this.layers) {
            if (current instanceof RecurrentLayer) {
                RecurrentLayer recurrent = (RecurrentLayer) current;
                recurrent.rnnSetPreviousState(recurrent.rnnGetTBPTTState());
                continue;
            }
            if (current instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) current).updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Evaluates the network on the given iterator without supplying a labels list.
     * Delegates to {@link #evaluate(DataSetIterator, List)} with a {@code null} list.
     *
     * @param dataSetIterator data to evaluate on
     * @param <T>             evaluation type expected by the caller (unchecked cast)
     * @return the result of the delegate call
     */
    public <T extends Evaluation> T evaluate(DataSetIterator dataSetIterator) {
        final List<String> noLabelsList = null;
        return (T) evaluate(dataSetIterator, noLabelsList);
    }

    /**
     * Evaluates the network on the given iterator without supplying a labels list.
     * Delegates to {@link #evaluate(MultiDataSetIterator, List)} with a {@code null} list.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param <T>                  evaluation type expected by the caller (unchecked cast)
     * @return the result of the delegate call
     */
    public <T extends Evaluation> T evaluate(MultiDataSetIterator multiDataSetIterator) {
        final List<String> noLabelsList = null;
        return (T) evaluate(multiDataSetIterator, noLabelsList);
    }

    /**
     * Evaluates the network with the given labels list, delegating to
     * {@link #evaluate(DataSetIterator, List, int)} with {@code 1} as the integer
     * argument (presumably the top-N setting - confirm against the
     * {@code Evaluation} constructor).
     *
     * @param dataSetIterator data to evaluate on
     * @param list            labels list, may be {@code null}
     * @param <T>             evaluation type expected by the caller (unchecked cast)
     * @return the result of the delegate call
     */
    public <T extends Evaluation> T evaluate(DataSetIterator dataSetIterator, List<String> list) {
        final int defaultIntArg = 1;
        return (T) evaluate(dataSetIterator, list, defaultIntArg);
    }

    /**
     * Evaluates the network with the given labels list, delegating to
     * {@link #evaluate(MultiDataSetIterator, List, int)} with {@code 1} as the
     * integer argument (presumably the top-N setting - confirm against the
     * {@code Evaluation} constructor).
     *
     * @param multiDataSetIterator data to evaluate on
     * @param list                 labels list, may be {@code null}
     * @param <T>                  evaluation type expected by the caller (unchecked cast)
     * @return the result of the delegate call
     */
    public <T extends Evaluation> T evaluate(MultiDataSetIterator multiDataSetIterator, List<String> list) {
        final int defaultIntArg = 1;
        return (T) evaluate(multiDataSetIterator, list, defaultIntArg);
    }

    /**
     * Evaluates the network on the given iterator using a new
     * {@code org.deeplearning4j.eval.Evaluation(list, i)}.
     *
     * <p>When {@code list} is {@code null} the iterator's own labels are used.
     * If the configuration requests output-layer validation, the configuration of
     * output layer 0 is checked for suitability before evaluating.
     *
     * @param dataSetIterator data to evaluate on
     * @param list            labels list, or {@code null} to use {@code dataSetIterator.getLabels()}
     * @param i               integer forwarded to the {@code Evaluation} constructor
     * @param <T>             evaluation type expected by the caller (unchecked cast)
     * @return the populated evaluation instance
     */
    public <T extends Evaluation> T evaluate(DataSetIterator dataSetIterator, List<String> list, int i) {
        if (list == null) {
            list = dataSetIterator.getLabels();
        }
        Layer outputLayer = getOutputLayer(0);
        if (getConfiguration().isValidateOutputLayerConfig()) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class);
        }
        // Exactly one evaluator is passed in, so it is at index 0 of the result array.
        // The unchecked cast to T is required for the return statement to type-check.
        return (T) doEvaluation(dataSetIterator, new org.deeplearning4j.eval.Evaluation(list, i))[0];
    }

    /**
     * Evaluates the network on the given iterator using a new
     * {@code org.deeplearning4j.eval.Evaluation(list, i)}.
     *
     * <p>If the configuration requests output-layer validation, the configuration
     * of output layer 0 is checked for suitability before evaluating.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param list                 labels list forwarded to the {@code Evaluation} constructor
     * @param i                    integer forwarded to the {@code Evaluation} constructor
     * @param <T>                  evaluation type expected by the caller (unchecked cast)
     * @return the populated evaluation instance
     */
    public <T extends Evaluation> T evaluate(MultiDataSetIterator multiDataSetIterator, List<String> list, int i) {
        Layer outputLayer = getOutputLayer(0);
        if (getConfiguration().isValidateOutputLayerConfig()) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class);
        }
        // Exactly one evaluator is passed in, so it is at index 0 of the result array.
        // The unchecked cast to T is required for the return statement to type-check.
        return (T) doEvaluation(multiDataSetIterator, new org.deeplearning4j.eval.Evaluation(list, i))[0];
    }

    /**
     * Runs a regression evaluation without column names, delegating to
     * {@link #evaluateRegression(DataSetIterator, List)} with a {@code null} list.
     *
     * @param dataSetIterator data to evaluate on
     * @param <T>             evaluation type expected by the caller (unchecked cast)
     * @return the result of the delegate call
     */
    public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator dataSetIterator) {
        final List<String> noNames = null;
        return (T) evaluateRegression(dataSetIterator, noNames);
    }

    /**
     * Runs a regression evaluation without column names, delegating to
     * {@link #evaluateRegression(MultiDataSetIterator, List)} with a {@code null} list.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param <T>                  evaluation type expected by the caller (unchecked cast)
     * @return the result of the delegate call
     */
    public <T extends RegressionEvaluation> T evaluateRegression(MultiDataSetIterator multiDataSetIterator) {
        final List<String> noNames = null;
        return (T) evaluateRegression(multiDataSetIterator, noNames);
    }

    /**
     * Runs a regression evaluation on the given iterator using a new
     * {@code org.deeplearning4j.eval.RegressionEvaluation(list)}.
     *
     * @param dataSetIterator data to evaluate on
     * @param list            list forwarded to the {@code RegressionEvaluation} constructor, may be {@code null}
     * @param <T>             evaluation type expected by the caller (unchecked cast)
     * @return the populated evaluation instance
     */
    public <T extends RegressionEvaluation> T evaluateRegression(DataSetIterator dataSetIterator, List<String> list) {
        // Exactly one evaluator is passed in, so it is at index 0 of the result array.
        // The unchecked cast to T is required for the return statement to type-check.
        return (T) doEvaluation(dataSetIterator, new org.deeplearning4j.eval.RegressionEvaluation(list))[0];
    }

    /**
     * Runs a regression evaluation on the given iterator using a new
     * {@code org.deeplearning4j.eval.RegressionEvaluation(list)}.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param list                 list forwarded to the {@code RegressionEvaluation} constructor, may be {@code null}
     * @param <T>                  evaluation type expected by the caller (unchecked cast)
     * @return the populated evaluation instance
     */
    public <T extends RegressionEvaluation> T evaluateRegression(MultiDataSetIterator multiDataSetIterator, List<String> list) {
        // Exactly one evaluator is passed in, so it is at index 0 of the result array.
        // The unchecked cast to T is required for the return statement to type-check.
        return (T) doEvaluation(multiDataSetIterator, new org.deeplearning4j.eval.RegressionEvaluation(list))[0];
    }

    /**
     * Runs a ROC evaluation, delegating to {@link #evaluateROC(DataSetIterator, int)}
     * with {@code 0} as the integer argument.
     *
     * @param dataSetIterator data to evaluate on
     * @param <T>             evaluation type expected by the caller (unchecked cast)
     * @return the result of the delegate call
     */
    public <T extends ROC> T evaluateROC(DataSetIterator dataSetIterator) {
        final int defaultIntArg = 0;
        return (T) evaluateROC(dataSetIterator, defaultIntArg);
    }

    /**
     * Runs a ROC evaluation on the given iterator using a new
     * {@code org.deeplearning4j.eval.ROC(i)}.
     *
     * <p>If the configuration requests output-layer validation, the configuration
     * of output layer 0 is checked for suitability before evaluating.
     *
     * @param dataSetIterator data to evaluate on
     * @param i               integer forwarded to the {@code ROC} constructor
     * @param <T>             evaluation type expected by the caller (unchecked cast)
     * @return the populated ROC instance
     */
    public <T extends ROC> T evaluateROC(DataSetIterator dataSetIterator, int i) {
        Layer outputLayer = getOutputLayer(0);
        if (getConfiguration().isValidateOutputLayerConfig()) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class);
        }
        // Exactly one evaluator is passed in, so it is at index 0 of the result array.
        // The unchecked cast to T is required for the return statement to type-check.
        return (T) doEvaluation(dataSetIterator, new org.deeplearning4j.eval.ROC(i))[0];
    }

    /**
     * Runs a ROC evaluation, delegating to
     * {@link #evaluateROC(MultiDataSetIterator, int)} with {@code 0} as the
     * integer argument.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param <T>                  evaluation type expected by the caller (unchecked cast)
     * @return the result of the delegate call
     */
    public <T extends ROC> T evaluateROC(MultiDataSetIterator multiDataSetIterator) {
        final int defaultIntArg = 0;
        return (T) evaluateROC(multiDataSetIterator, defaultIntArg);
    }

    /**
     * Runs a ROC evaluation on the given iterator using a new
     * {@code org.deeplearning4j.eval.ROC(i)}.
     *
     * <p>If the configuration requests output-layer validation, the configuration
     * of output layer 0 is checked for suitability before evaluating.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param i                    integer forwarded to the {@code ROC} constructor
     * @param <T>                  evaluation type expected by the caller (unchecked cast)
     * @return the populated ROC instance
     */
    public <T extends ROC> T evaluateROC(MultiDataSetIterator multiDataSetIterator, int i) {
        Layer outputLayer = getOutputLayer(0);
        if (getConfiguration().isValidateOutputLayerConfig()) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class);
        }
        // Exactly one evaluator is passed in, so it is at index 0 of the result array.
        // The unchecked cast to T is required for the return statement to type-check.
        return (T) doEvaluation(multiDataSetIterator, new org.deeplearning4j.eval.ROC(i))[0];
    }

    /**
     * Runs a multi-class ROC evaluation, delegating to
     * {@link #evaluateROCMultiClass(DataSetIterator, int)} with {@code 0} as the
     * integer argument.
     *
     * @param dataSetIterator data to evaluate on
     * @param <T>             evaluation type expected by the caller (unchecked cast)
     * @return the result of the delegate call
     */
    public <T extends ROCMultiClass> T evaluateROCMultiClass(DataSetIterator dataSetIterator) {
        final int defaultIntArg = 0;
        return (T) evaluateROCMultiClass(dataSetIterator, defaultIntArg);
    }

    /**
     * Runs a multi-class ROC evaluation on the given iterator using a new
     * {@code org.deeplearning4j.eval.ROCMultiClass(i)}.
     *
     * <p>If the configuration requests output-layer validation, the configuration
     * of output layer 0 is checked for suitability before evaluating.
     *
     * @param dataSetIterator data to evaluate on
     * @param i               integer forwarded to the {@code ROCMultiClass} constructor
     * @param <T>             evaluation type expected by the caller (unchecked cast)
     * @return the populated ROCMultiClass instance
     */
    public <T extends ROCMultiClass> T evaluateROCMultiClass(DataSetIterator dataSetIterator, int i) {
        Layer outputLayer = getOutputLayer(0);
        if (getConfiguration().isValidateOutputLayerConfig()) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class);
        }
        // Exactly one evaluator is passed in, so it is at index 0 of the result array.
        // The unchecked cast to T is required for the return statement to type-check.
        return (T) doEvaluation(dataSetIterator, new org.deeplearning4j.eval.ROCMultiClass(i))[0];
    }

    /**
     * Runs a multi-class ROC evaluation on the given iterator using a new
     * {@code org.deeplearning4j.eval.ROCMultiClass(i)}.
     *
     * <p>If the configuration requests output-layer validation, the configuration
     * of output layer 0 is checked for suitability before evaluating.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param i                    integer forwarded to the {@code ROCMultiClass} constructor
     * @param <T>                  evaluation type expected by the caller (unchecked cast)
     * @return the populated ROCMultiClass instance
     */
    public <T extends ROCMultiClass> T evaluateROCMultiClass(MultiDataSetIterator multiDataSetIterator, int i) {
        Layer outputLayer = getOutputLayer(0);
        if (getConfiguration().isValidateOutputLayerConfig()) {
            OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class);
        }
        // Exactly one evaluator is passed in, so it is at index 0 of the result array.
        // The unchecked cast to T is required for the return statement to type-check.
        return (T) doEvaluation(multiDataSetIterator, new org.deeplearning4j.eval.ROCMultiClass(i))[0];
    }

    /**
     * Evaluates the network with the supplied evaluators by wrapping the iterator
     * in a {@link MultiDataSetIteratorAdapter} and delegating to the
     * {@code MultiDataSetIterator} overload.
     *
     * @param dataSetIterator data to evaluate on
     * @param tArr            evaluators to populate
     * @param <T>             evaluator type
     * @return the array returned by the delegate
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public <T extends IEvaluation> T[] doEvaluation(DataSetIterator dataSetIterator, T... tArr) {
        MultiDataSetIteratorAdapter adapted = new MultiDataSetIteratorAdapter(dataSetIterator);
        return (T[]) doEvaluation(adapted, tArr);
    }

    /**
     * Evaluates the network with the supplied evaluators.
     *
     * <p>If an {@link OutOfMemoryError} is raised during evaluation, a memory crash
     * dump is written via {@code CrashReportingUtil} before the error is rethrown.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param tArr                 evaluators to populate
     * @param <T>                  evaluator type
     * @return the array returned by {@code doEvaluationHelper}
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public <T extends IEvaluation> T[] doEvaluation(MultiDataSetIterator multiDataSetIterator, T... tArr) {
        try {
            return (T[]) doEvaluationHelper(multiDataSetIterator, tArr);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Evaluates several outputs at once by wrapping the iterator in a
     * {@link MultiDataSetIteratorAdapter} and delegating to
     * {@link #evaluate(MultiDataSetIterator, Map)}.
     *
     * @param dataSetIterator data to evaluate on
     * @param map             evaluators keyed by output index
     * @param <T>             evaluator type
     * @return the map returned by the delegate
     */
    public <T extends IEvaluation> Map<Integer, T[]> evaluate(DataSetIterator dataSetIterator, Map<Integer, T[]> map) {
        MultiDataSetIteratorAdapter adapted = new MultiDataSetIteratorAdapter(dataSetIterator);
        return evaluate(adapted, map);
    }

    /**
     * Evaluates several outputs at once.
     *
     * <p>If an {@link OutOfMemoryError} is raised during evaluation, a memory crash
     * dump is written via {@code CrashReportingUtil} before the error is rethrown.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param map                  evaluators keyed by output index
     * @param <T>                  evaluator type
     * @return the map returned by {@code doEvaluationHelper}
     */
    public <T extends IEvaluation> Map<Integer, T[]> evaluate(MultiDataSetIterator multiDataSetIterator, Map<Integer, T[]> map) {
        try {
            return doEvaluationHelper(multiDataSetIterator, map);
        } catch (OutOfMemoryError oom) {
            CrashReportingUtil.writeMemoryCrashDump(this, oom);
            throw oom;
        }
    }

    /**
     * Varargs convenience form: registers the evaluators under output index 0 in a
     * singleton map, runs the map-based helper, and returns the entry for index 0.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param tArr                 evaluators for output 0
     * @param <T>                  evaluator type
     * @return the evaluator array stored under key 0 in the returned map
     */
    @SafeVarargs
    private final <T extends IEvaluation> T[] doEvaluationHelper(MultiDataSetIterator multiDataSetIterator, T... tArr) {
        Map<Integer, T[]> forFirstOutput = Collections.singletonMap(0, tArr);
        return doEvaluationHelper(multiDataSetIterator, forFirstOutput).get(0);
    }

    /**
     * Runs the network over every example of the iterator and feeds labels,
     * network outputs and label masks to the supplied evaluators.
     *
     * <p>For the duration of the call the configuration's training workspace mode
     * is set to its inference workspace mode; the previous mode is restored before
     * returning, including when evaluation fails. Iterators that report
     * {@code asyncSupported()} are wrapped in an {@link AsyncMultiDataSetIterator},
     * which is shut down on exit.
     *
     * <p>When the configuration uses {@link BackpropType#TruncatedBPTT}, each
     * example is evaluated window by window (window length
     * {@code getTbpttFwdLength()}) through {@code rnnTimeStep}; otherwise a single
     * standard forward pass is used. Examples whose features or labels are
     * {@code null} are skipped.
     *
     * @param multiDataSetIterator data to evaluate on
     * @param map                  evaluators keyed by network output index; entries with a
     *                             {@code null} value are ignored
     * @param <T>                  evaluator type
     * @return the same {@code map} instance, with its evaluators updated
     * @throws IllegalStateException if the network has no layers, output layer 0 is not an
     *                               {@link IOutputLayer}, or TBPTT is configured but an
     *                               example has no rank-3 feature array
     */
    private <T extends IEvaluation> Map<Integer, T[]> doEvaluationHelper(MultiDataSetIterator multiDataSetIterator, Map<Integer, T[]> map) {
        if (this.layers == null || !(getOutputLayer(0) instanceof IOutputLayer)) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }
        WorkspaceUtils.assertNoWorkspacesOpen("Expected no external workspaces open at start of evaluation (doEvaluationHelper)");
        if (multiDataSetIterator.resetSupported() && !multiDataSetIterator.hasNext()) {
            multiDataSetIterator.reset();
        }

        final boolean wrappedAsync = multiDataSetIterator.asyncSupported();
        MultiDataSetIterator iter = wrappedAsync ? new AsyncMultiDataSetIterator(multiDataSetIterator, 2, true) : multiDataSetIterator;

        // Temporarily run with the inference workspace mode; undone in the finally block.
        WorkspaceMode previousTrainingMode = this.configuration.getTrainingWorkspaceMode();
        this.configuration.setTrainingWorkspaceMode(this.configuration.getInferenceWorkspaceMode());
        try {
            boolean useTbptt = this.configuration.getBackpropType() == BackpropType.TruncatedBPTT;
            MemoryWorkspace outputWs = getConfiguration().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED ? Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM) : new DummyWorkspace();

            while (iter.hasNext()) {
                MultiDataSet example = (MultiDataSet) iter.next();
                if (example.getFeatures() == null || example.getLabels() == null) {
                    continue;
                }

                if (useTbptt) {
                    rnnClearPreviousState();
                    int fwdLength = this.configuration.getTbpttFwdLength();

                    // size(2) of the last rank-3 feature array.
                    long timeSeriesLength = -1;
                    long featureCount = example.getFeatures().length;
                    for (int f = 0; f < featureCount; f++) {
                        if (example.getFeatures(f).rank() == 3) {
                            timeSeriesLength = example.getFeatures(f).size(2);
                        }
                    }
                    if (timeSeriesLength < 0) {
                        throw new IllegalStateException("Invalid configuration: detected TBPTT backprop type without time series features");
                    }

                    // Round up so a trailing partial window is evaluated too.
                    long subsetCount = timeSeriesLength / fwdLength;
                    if (timeSeriesLength % fwdLength != 0) {
                        subsetCount++;
                    }

                    for (int subset = 0; subset < subsetCount; subset++) {
                        int startIdx = subset * fwdLength;
                        long endIdx = Math.min(startIdx + fwdLength, timeSeriesLength);
                        List<INDArray[]> subsets = getSubsetsForTbptt(startIdx, endIdx, example.getFeatures(), example.getLabels(), example.getFeaturesMaskArrays(), example.getLabelsMaskArrays());
                        setLayerMaskArrays(subsets.get(2), subsets.get(3));

                        try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
                            INDArray[] windowOutputs = rnnTimeStep(ws, subsets.get(0));
                            for (Integer outputIndex : map.keySet()) {
                                T[] evaluators = map.get(outputIndex);
                                if (evaluators == null) {
                                    continue;
                                }
                                INDArray windowLabels = subsets.get(1) == null ? null : subsets.get(1)[outputIndex.intValue()];
                                INDArray windowLabelMask = subsets.get(3) == null ? null : subsets.get(3)[outputIndex.intValue()];
                                INDArray windowOutput = windowOutputs[outputIndex.intValue()];
                                // The evaluators run with workspaces scoped out.
                                try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                                    for (T evaluator : evaluators) {
                                        evaluator.eval(windowLabels, windowOutput, windowLabelMask);
                                    }
                                }
                            }
                        }
                    }
                    rnnClearPreviousState();
                } else {
                    INDArray[] features = example.getFeatures();
                    INDArray[] featureMasks = example.getFeaturesMaskArrays();
                    INDArray[] labels = example.getLabels();
                    INDArray[] labelMasks = example.getLabelsMaskArrays();

                    try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) {
                        INDArray[] outputs = outputOfLayersDetached(false, FwdPassType.STANDARD, getOutputLayerIndices(), features, featureMasks, labelMasks, true, false, ws);
                        for (Integer outputIndex : map.keySet()) {
                            Preconditions.checkState(outputIndex.intValue() >= 0 && outputIndex.intValue() < labels.length, "Invalid output index: evaluation/output indices must be between 0 and numOutputs-1 (%s), got index %s", this.numOutputArrays, outputIndex.intValue());
                            T[] evaluators = map.get(outputIndex);
                            if (evaluators == null) {
                                continue;
                            }
                            Preconditions.checkState(outputIndex.intValue() >= 0 && outputIndex.intValue() < getNumOutputArrays(), "Invalid output index: indices for outputs must be between 0 and %s inclusive - found index %s", this.numOutputArrays, outputIndex.intValue());
                            INDArray output = outputs[outputIndex.intValue()];
                            INDArray label = labels[outputIndex.intValue()];
                            // The evaluators run with workspaces scoped out.
                            try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                                for (T evaluator : evaluators) {
                                    evaluator.eval(label, output, example.getLabelsMaskArray(outputIndex.intValue()));
                                }
                            }
                        }
                    }
                }
                clearLayersStates();
            }
        } finally {
            // Restore the caller-visible configuration and stop the async wrapper even on failure.
            this.configuration.setTrainingWorkspaceMode(previousTrainingMode);
            if (wrappedAsync) {
                ((AsyncMultiDataSetIterator) iter).shutdown();
            }
        }
        return map;
    }

    /**
     * Builds the summary table without input/output shape columns by delegating
     * to {@link #summary(InputType...)} with a {@code null} array.
     *
     * @return the formatted summary
     */
    public String summary() {
        final InputType[] noInputTypes = null;
        return summary(noInputTypes);
    }

    /**
     * Builds a text table describing every vertex of the graph in topological
     * order, followed by total, trainable and frozen parameter counts.
     *
     * <p>Each row lists the vertex name and type, nIn/nOut (when the layer
     * configuration exposes them), parameter count, parameter shapes and the
     * vertex inputs. When input types are supplied, two further columns show the
     * input shape (including the effect of any pre-processor) and the output shape
     * of each vertex. A {@code null} or empty {@code inputTypeArr} produces the
     * five-column table without shape information.
     *
     * @param inputTypeArr input types of the network inputs, in the order of
     *                     {@code configuration.getNetworkInputs()}; may be {@code null} or empty
     * @return the formatted summary
     */
    public String summary(InputType... inputTypeArr) {
        // An empty array carries no shape information, so it is handled like null. Using a
        // single flag keeps the header, the column widths and the rows in agreement.
        final boolean withShapes = inputTypeArr != null && inputTypeArr.length > 0;

        StringBuilder out = new StringBuilder();
        out.append("\n");
        long frozenParams = 0;
        Map<String, InputType> vertexOutputs = new HashMap<>();
        int currLayerIdx = -1;

        List<String[]> rows = new ArrayList<>();
        if (withShapes) {
            rows.add(new String[]{"VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs", "InputShape", "OutputShape"});
        } else {
            rows.add(new String[]{"VertexName (VertexType)", "nIn,nOut", "TotalParams", "ParamsShape", "Vertex Inputs"});
        }
        String[] header = rows.get(0);
        int[] maxLength = new int[header.length];
        for (int col = 0; col < header.length; col++) {
            maxLength[col] = header[col].length();
        }

        for (int vertexIdx : this.topologicalOrder) {
            GraphVertex vertex = this.vertices[vertexIdx];
            String vertexName = vertex.getVertexName();
            String[] vertexClassParts = vertex.getClass().toString().split("\\.");
            String className = vertexClassParts[vertexClassParts.length - 1];
            String connections = "-";
            String inShape = "-";
            String outShape = "-";
            String paramCount = "-";
            String nIn = "-";
            String nOut = "-";
            String paramShape = "-";

            if (!vertex.isInputVertex()) {
                connections = this.configuration.getVertexInputs().get(vertexName).toString();
                List<InputType> inputTypeList = new ArrayList<>();
                if (vertex.hasLayer()) {
                    Layer layer = ((LayerVertex) vertex).getLayer();
                    String[] layerClassParts = layer.getClass().getName().split("\\.");
                    className = layerClassParts[layerClassParts.length - 1];
                    paramCount = String.valueOf(layer.numParams());
                    if (layer.numParams() > 0) {
                        if (layer instanceof BidirectionalLayer) {
                            Bidirectional bidirectionalConf = (Bidirectional) layer.conf().getLayer();
                            nIn = String.valueOf(bidirectionalConf.getNIn());
                            nOut = String.valueOf(bidirectionalConf.getNOut());
                        } else if (layer.conf().getLayer() instanceof FeedForwardLayer) {
                            // Layers with parameters whose configuration is not a FeedForwardLayer
                            // simply keep "-" for nIn/nOut.
                            FeedForwardLayer feedForwardConf = (FeedForwardLayer) layer.conf().getLayer();
                            nIn = String.valueOf(feedForwardConf.getNIn());
                            nOut = String.valueOf(feedForwardConf.getNOut());
                        }
                        List<String> shapes = new ArrayList<>();
                        for (String paramName : layer.conf().variables()) {
                            shapes.add(paramName + ":" + ArrayUtils.toString(layer.paramTable().get(paramName).shape()));
                        }
                        if (!shapes.isEmpty()) {
                            paramShape = StringUtils.join(shapes, ", ");
                        }
                    }
                    if (layer instanceof FrozenLayer) {
                        frozenParams += layer.numParams();
                        String[] insideClassParts = ((FrozenLayer) layer).getInsideLayer().getClass().getName().split("\\.");
                        className = "Frozen " + insideClassParts[insideClassParts.length - 1];
                    }
                    if (withShapes) {
                        // A layer vertex takes its input type from its first input vertex.
                        InputType inputType = vertexOutputs.get(this.vertices[vertex.getInputVertices()[0].getVertexIndex()].getVertexName());
                        inShape = inputType.toString();
                        inputTypeList.add(inputType);
                        InputPreProcessor preProcessor = ((org.deeplearning4j.nn.conf.graph.LayerVertex) this.configuration.getVertices().get(vertexName)).getPreProcessor();
                        if (preProcessor != null) {
                            inShape = inShape + "-->" + preProcessor.getOutputType(inputType);
                        }
                    }
                    currLayerIdx++;
                } else if (withShapes) {
                    VertexIndices[] inputVertices = vertex.getInputVertices();
                    if (inputVertices != null) {
                        for (VertexIndices inputVertex : inputVertices) {
                            inputTypeList.add(vertexOutputs.get(this.vertices[inputVertex.getVertexIndex()].getVertexName()));
                        }
                    }
                }
                if (withShapes) {
                    InputType outputType = this.configuration.getVertices().get(vertexName).getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()]));
                    outShape = outputType.toString();
                    vertexOutputs.put(vertexName, outputType);
                }
            } else if (withShapes) {
                vertexOutputs.put(vertexName, inputTypeArr[this.configuration.getNetworkInputs().indexOf(vertexName)]);
            }

            String[] row = withShapes
                    ? new String[]{vertexName + " (" + className + ")", nIn + "," + nOut, paramCount, paramShape, connections, inShape, outShape}
                    : new String[]{vertexName + " (" + className + ")", nIn + "," + nOut, paramCount, paramShape, connections};
            for (int col = 0; col < row.length; col++) {
                maxLength[col] = Math.max(maxLength[col], row[col] == null ? 0 : row[col].length());
            }
            rows.add(row);
        }

        // Left-aligned columns, each padded by three spaces except the last.
        StringBuilder formatBuilder = new StringBuilder();
        int totalLength = 0;
        for (int col = 0; col < maxLength.length; col++) {
            int width = col == maxLength.length - 1 ? maxLength[col] : maxLength[col] + 3;
            formatBuilder.append("%-").append(width).append("s");
            totalLength += width;
        }
        formatBuilder.append("\n");
        String rowFormat = formatBuilder.toString();

        String doubleRule = StringUtils.repeat("=", totalLength);
        out.append(doubleRule).append("\n");
        boolean headerRow = true;
        for (String[] row : rows) {
            out.append(String.format(rowFormat, (Object[]) row));
            if (headerRow) {
                out.append(doubleRule).append("\n");
                headerRow = false;
            }
        }

        long totalParams = params().length();
        out.append(StringUtils.repeat("-", totalLength))
                .append(String.format("\n%30s %d", "Total Parameters: ", totalParams))
                .append(String.format("\n%30s %d", "Trainable Parameters: ", totalParams - frozenParams))
                .append(String.format("\n%30s %d", "Frozen Parameters: ", frozenParams))
                .append("\n")
                .append(doubleRule)
                .append("\n");
        return out.toString();
    }

    /**
     * Produces a memory status report for this network by delegating to
     * {@code CrashReportingUtil.generateMemoryStatus}.
     *
     * @param i            integer forwarded to the report generator
     * @param inputTypeArr input types forwarded to the report generator
     * @return the generated report
     */
    public String memoryInfo(int i, InputType... inputTypeArr) {
        String report = CrashReportingUtil.generateMemoryStatus(this, i, inputTypeArr);
        return report;
    }

    /**
     * Resets transient state after a pass: every layer is cleared and has its
     * noise weight parameters cleared, then every graph vertex is cleared.
     */
    public void clearLayersStates() {
        for (int idx = 0; idx < this.layers.length; idx++) {
            Layer current = this.layers[idx];
            current.clear();
            current.clearNoiseWeightParams();
        }
        for (int idx = 0; idx < this.vertices.length; idx++) {
            this.vertices[idx].clearVertex();
        }
    }

    /**
     * Increases the epoch counter stored in the configuration by one and then pushes the
     * current iteration/epoch counters to the layers via {@link #synchronizeIterEpochCounts()}.
     */
    public void incrementEpochCount() {
        final int nextEpoch = this.configuration.getEpochCount() + 1;
        this.configuration.setEpochCount(nextEpoch);
        synchronizeIterEpochCounts();
    }

    /**
     * Copies the iteration and epoch counters held by the configuration onto every layer,
     * so that each layer reports the same counts as the graph.
     */
    protected void synchronizeIterEpochCounts() {
        final int iterations = getConfiguration().getIterationCount();
        final int epochs = getConfiguration().getEpochCount();
        final Layer[] allLayers = this.layers;
        for (int idx = 0; idx < allLayers.length; idx++) {
            final Layer current = allLayers[idx];
            current.setIterationCount(iterations);
            current.setEpochCount(epochs);
        }
    }

    /**
     * @return the iteration count currently stored in this graph's configuration
     */
    public int getIterationCount() {
        final int iterations = configuration.getIterationCount();
        return iterations;
    }

    /**
     * @return the epoch count currently stored in this graph's configuration
     */
    public int getEpochCount() {
        final int epochs = configuration.getEpochCount();
        return epochs;
    }

    /**
     * Saves this graph to the given file, using {@code true} for the boolean option of
     * {@link #save(File, boolean)}.
     *
     * @param file destination file
     * @throws IOException if the two-argument overload fails
     */
    public void save(File file) throws IOException {
        final boolean defaultFlag = true;
        save(file, defaultFlag);
    }

    /**
     * Saves this graph to the given file by delegating to {@code ModelSerializer.writeModel}.
     *
     * @param file destination file
     * @param z    forwarded unchanged as the serializer's boolean argument (presumably
     *             "also save the updater state"; confirm against ModelSerializer.writeModel)
     * @throws IOException propagated from the serializer
     */
    public void save(File file, boolean z) throws IOException {
        ModelSerializer.writeModel(this, file, z);
    }

    /**
     * Restores a graph from the given file by delegating to
     * {@code ModelSerializer.restoreComputationGraph}.
     *
     * @param file source file
     * @param z    forwarded unchanged as the serializer's boolean argument (presumably
     *             "also load the updater state"; confirm against ModelSerializer)
     * @return the graph returned by the serializer
     * @throws IOException propagated from the serializer
     */
    public static ComputationGraph load(File file, boolean z) throws IOException {
        final ComputationGraph restored = ModelSerializer.restoreComputationGraph(file, z);
        return restored;
    }

    /**
     * Applies a fixed learning rate to this graph by delegating to
     * {@code NetworkUtils.setLearningRate(this, d)}.
     *
     * @param d learning rate value forwarded to the utility
     */
    public void setLearningRate(double d) {
        NetworkUtils.setLearningRate(this, d);
    }

    /**
     * Applies a learning rate schedule to this graph by delegating to
     * {@code NetworkUtils.setLearningRate(this, iSchedule)}.
     *
     * @param iSchedule schedule object forwarded to the utility
     */
    public void setLearningRate(ISchedule iSchedule) {
        NetworkUtils.setLearningRate(this, iSchedule);
    }

    /**
     * Applies a fixed learning rate for one named target by delegating to
     * {@code NetworkUtils.setLearningRate(this, str, d)}.
     *
     * @param str name forwarded to the utility (presumably a layer name; confirm against
     *            NetworkUtils)
     * @param d   learning rate value forwarded to the utility
     */
    public void setLearningRate(String str, double d) {
        NetworkUtils.setLearningRate(this, str, d);
    }

    /**
     * Applies a learning rate schedule for one named target by delegating to
     * {@code NetworkUtils.setLearningRate(this, str, iSchedule)}.
     *
     * @param str       name forwarded to the utility (presumably a layer name; confirm
     *                  against NetworkUtils)
     * @param iSchedule schedule object forwarded to the utility
     */
    public void setLearningRate(String str, ISchedule iSchedule) {
        NetworkUtils.setLearningRate(this, str, iSchedule);
    }

    /**
     * Looks up a learning rate by name by delegating to
     * {@code NetworkUtils.getLearningRate(this, str)}.
     *
     * @param str name forwarded to the utility (presumably a layer name; confirm against
     *            NetworkUtils)
     * @return whatever the utility returns; the boxed return type suggests it may be
     *         {@code null} (confirm against NetworkUtils)
     */
    public Double getLearningRate(String str) {
        final Double rate = NetworkUtils.getLearningRate(this, str);
        return rate;
    }

    /**
     * Returns the size of the layer at the given index, resolved through the layer's
     * configured name and {@link #layerSize(String)}.
     *
     * @param i index into the layers array; valid values are {@code 0} to
     *          {@code layers.length - 1} inclusive
     * @return the value returned by {@link #layerSize(String)} for that layer's name
     * @throws IllegalArgumentException if {@code i} is outside the valid range
     */
    public int layerSize(int i) {
        // The upper bound must be exclusive: i == layers.length is already past the last
        // element and would otherwise fail with ArrayIndexOutOfBoundsException below.
        if (i < 0 || i >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + i + ". Layer index must be between 0 and " + (this.layers.length - 1) + " inclusive");
        }
        return layerSize(this.layers[i].conf().getLayer().getLayerName());
    }

    /**
     * Returns the input size of the layer at the given index, resolved through the layer's
     * configured name and {@link #layerInputSize(String)}.
     *
     * @param i index into the layers array; valid values are {@code 0} to
     *          {@code layers.length - 1} inclusive
     * @return the value returned by {@link #layerInputSize(String)} for that layer's name
     * @throws IllegalArgumentException if {@code i} is outside the valid range
     */
    public int layerInputSize(int i) {
        // The upper bound must be exclusive: i == layers.length is already past the last
        // element and would otherwise fail with ArrayIndexOutOfBoundsException below.
        if (i < 0 || i >= this.layers.length) {
            throw new IllegalArgumentException("Invalid layer index: " + i + ". Layer index must be between 0 and " + (this.layers.length - 1) + " inclusive");
        }
        return layerInputSize(this.layers[i].conf().getLayer().getLayerName());
    }

    /**
     * Returns the {@code nOut} value of the named layer when its configuration is a
     * {@link FeedForwardLayer}, and {@code 0} for any other (or missing) configuration.
     *
     * @param str name of the layer to look up via {@code getLayer}
     * @return {@code nOut} narrowed to {@code int}, or {@code 0} if the layer configuration
     *         is not a {@link FeedForwardLayer}
     * @throws IllegalArgumentException if {@code getLayer} returns {@code null} for the name
     */
    public int layerSize(String str) {
        final Layer namedLayer = getLayer(str);
        if (namedLayer == null) {
            throw new IllegalArgumentException("No layer with name \"" + str + "\" exists");
        }
        final org.deeplearning4j.nn.conf.layers.Layer layerConf = namedLayer.conf().getLayer();
        // instanceof is false for null, so a missing configuration also yields 0.
        if (layerConf instanceof FeedForwardLayer) {
            // NOTE(review): narrowing cast; a value beyond the int range would be truncated.
            return (int) ((FeedForwardLayer) layerConf).getNOut();
        }
        return 0;
    }

    /**
     * Returns the {@code nIn} value of the named layer when its configuration is a
     * {@link FeedForwardLayer}, and {@code 0} for any other (or missing) configuration.
     *
     * @param str name of the layer to look up via {@code getLayer}
     * @return {@code nIn} narrowed to {@code int}, or {@code 0} if the layer configuration
     *         is not a {@link FeedForwardLayer}
     * @throws IllegalArgumentException if {@code getLayer} returns {@code null} for the name
     */
    public int layerInputSize(String str) {
        final Layer namedLayer = getLayer(str);
        if (namedLayer == null) {
            throw new IllegalArgumentException("No layer with name \"" + str + "\" exists");
        }
        final org.deeplearning4j.nn.conf.layers.Layer layerConf = namedLayer.conf().getLayer();
        // instanceof is false for null, so a missing configuration also yields 0.
        if (layerConf instanceof FeedForwardLayer) {
            // NOTE(review): narrowing cast; a value beyond the int range would be truncated.
            return (int) ((FeedForwardLayer) layerConf).getNIn();
        }
        return 0;
    }

    /**
     * Two graphs are considered equal when their parameters, configurations and updaters
     * are all equal, compared in that order with short-circuiting.
     *
     * <p>NOTE(review): no matching {@code hashCode} is visible in this part of the file;
     * confirm one exists that is consistent with this definition.
     *
     * @param obj object to compare against
     * @return {@code true} only if {@code obj} is a {@code ComputationGraph} whose
     *         parameters, configuration and updater all compare equal to this graph's
     */
    @Override
    public boolean equals(Object obj) {
        // instanceof is false for null, so this single test also rejects null.
        if (!(obj instanceof ComputationGraph)) {
            return false;
        }
        final ComputationGraph other = (ComputationGraph) obj;
        if (!other.params().equals(params())) {
            return false;
        }
        if (!getConfiguration().equals(other.getConfiguration())) {
            return false;
        }
        return getUpdater().equals(other.getUpdater());
    }

    /**
     * Java serialization hook: writes this graph to the object stream through
     * {@code ModelSerializer.writeModel}, passing {@code true} as its boolean argument.
     *
     * @param objectOutputStream stream supplied by the serialization machinery
     * @throws IOException propagated from the serializer
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        // Typed locals pin the (Model, OutputStream, boolean) overload.
        final Model self = this;
        final OutputStream sink = objectOutputStream;
        ModelSerializer.writeModel(self, sink, true);
    }

    /**
     * Java serialization hook: restores a graph from the object stream through
     * {@code ModelSerializer.restoreComputationGraph} and copies its state into this
     * instance.
     *
     * <p>Order matters: both configurations are assigned first, then {@code init()} is
     * called, and only afterwards are the restored parameters (and, if present, the updater
     * state view) assigned into this instance's arrays.
     *
     * @param objectInputStream stream supplied by the serialization machinery
     * @throws ClassNotFoundException declared for the serialization hook signature
     * @throws IOException            propagated from the serializer
     */
    private void readObject(ObjectInputStream objectInputStream) throws ClassNotFoundException, IOException {
        final ComputationGraph restored = ModelSerializer.restoreComputationGraph((InputStream) objectInputStream, true);
        // Take copies of the restored configurations rather than sharing the instances
        // (m33clone/m26clone look like decompiler names for clone() - confirm).
        this.defaultConfiguration = restored.defaultConfiguration.m33clone();
        this.configuration = restored.configuration.m26clone();
        init();
        this.flattenedParams.assign(restored.flattenedParams);
        // Updater state is optional: copy it only when the restored graph actually has some.
        final boolean hasUpdaterState = restored.getUpdater() != null
                && restored.getUpdater(false).getStateViewArray() != null;
        if (hasUpdaterState) {
            getUpdater(true).getStateViewArray().assign(restored.getUpdater(false).getStateViewArray());
        }
    }

    /**
     * @return the {@code flattenedGradients} array held by this graph (the field itself,
     *         not a copy)
     */
    public INDArray getFlattenedGradients() {
        return flattenedGradients;
    }

    /**
     * Overwrites the {@code initDone} flag of this graph.
     *
     * @param z new value for the flag
     */
    public void setInitDone(boolean z) {
        initDone = z;
    }

    /**
     * @return the {@code helperWorkspaces} map held by this graph (the field itself, not a
     *         copy)
     */
    public Map<String, Pointer> getHelperWorkspaces() {
        return helperWorkspaces;
    }
}
