package org.deeplearning4j.nn.layers.samediff;

import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.params.SameDiffParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.temp.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.class */
/**
 * Graph vertex whose computation is defined by a {@link SameDiffVertex} configuration and executed
 * through a {@link SameDiff} graph.
 *
 * <p>The SameDiff graph is built lazily (see {@link #doInit()}) the first time the vertex is run,
 * because the SameDiff input variables are created with the shapes of the input arrays seen at
 * that time. Parameter and parameter-gradient tables are obtained by splitting the flattened
 * arrays handed to this vertex via {@code SameDiffParamInitializer#subsetAndReshape}.
 */
public class SameDiffGraphVertex extends BaseGraphVertex {
    /** Configuration that declares inputs and parameter shapes and defines the SameDiff graph. */
    protected SameDiffVertex config;
    /** Lazily created SameDiff instance; null until {@link #doInit()} runs. */
    protected SameDiff sameDiff;
    /** Output variable returned by {@code config.defineVertex(...)}. */
    protected SDVariable outputVar;
    /** External-errors function through which the incoming epsilon is injected during backprop. */
    protected ExternalErrorsFunction fn;
    /** Name of {@link #outputVar}, used to look up the activation array after execution. */
    protected String outputKey;
    /** SameDiff input variables keyed by the configured input names. */
    protected Map<String, SDVariable> inputVars;
    /** Flattened parameter array for this vertex. */
    protected INDArray params;
    /** Flattened gradient view array, as supplied to {@link #setBackpropGradientsViewArray}. */
    protected INDArray gradients;
    /** Per-parameter arrays derived from {@link #params}. */
    protected Map<String, INDArray> paramTable;
    /** Per-parameter gradient arrays derived from {@link #gradients}; null until the view is set. */
    protected Map<String, INDArray> gradTable;

    /**
     * Creates the vertex and splits the flattened parameter array into per-parameter arrays.
     *
     * @param sameDiffVertex   vertex configuration
     * @param computationGraph graph that owns this vertex
     * @param str              vertex name
     * @param i                vertex index within the graph
     * @param iNDArray         flattened parameter array for this vertex
     * @param z                if true, the configuration initializes the parameter values
     */
    public SameDiffGraphVertex(SameDiffVertex sameDiffVertex, ComputationGraph computationGraph, String str, int i, INDArray iNDArray, boolean z) {
        super(computationGraph, str, i, null, null);
        this.config = sameDiffVertex;
        SDVertexParams vertexParams = sameDiffVertex.getVertexParams();
        this.paramTable = SameDiffParamInitializer.getInstance().subsetAndReshape(vertexParams.getParameterKeys(), vertexParams.getParamShapes(), iNDArray, null, sameDiffVertex);
        if (z) {
            sameDiffVertex.initializeParameters(this.paramTable);
        }
        this.params = iNDArray;
    }

    /** Returns a short description containing the vertex index and name (never null). */
    @Override
    public String toString() {
        return "SameDiffGraphVertex(id=" + getVertexIndex() + ",name=\"" + getVertexName() + "\")";
    }

    @Override
    public boolean hasLayer() {
        return false;
    }

    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Runs the SameDiff graph forward on the current inputs.
     *
     * @param z                 training flag (not used by this implementation)
     * @param layerWorkspaceMgr workspace manager; the output is duplicated into its ACTIVATIONS workspace
     * @return the activation array for {@link #outputKey}
     */
    @Override
    public INDArray doForward(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (this.sameDiff == null) {
            doInit();
        }
        // SameDiff executes outside any active workspace; only the final result is copied into the
        // caller's ACTIVATIONS workspace (the dup happens before the scope is closed).
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            this.sameDiff.clearExecutionCache();
            associateInputs();
            associateParams();
            this.sameDiff.exec();
            INDArray output = this.sameDiff.getArrForVarName(this.outputKey);
            return layerWorkspaceMgr.dup(ArrayType.ACTIVATIONS, output);
        }
    }

    /**
     * Backpropagates {@link #epsilon} through the SameDiff graph.
     *
     * <p>Parameter gradients are assigned into the arrays of {@link #gradTable}, and those arrays
     * are reported in the returned {@link Gradient}. Input gradients are duplicated into the
     * ACTIVATION_GRAD workspace after leaving the out-of-workspace scope.
     *
     * @param z                 truncated-BPTT flag (not used by this implementation)
     * @param layerWorkspaceMgr workspace manager for the returned input gradients
     * @return the parameter gradients and one gradient array per input
     * @throws IllegalStateException if the vertex has parameters but no gradient view has been set
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (this.sameDiff == null) {
            // Same lazy initialization as doForward, so backward does not NPE if run first.
            doInit();
        }
        boolean hasParams = this.paramTable != null && !this.paramTable.isEmpty();
        if (hasParams && this.gradTable == null) {
            throw new IllegalStateException("Cannot backprop through vertex \"" + getVertexName()
                    + "\": setBackpropGradientsViewArray has not been called");
        }

        Gradient gradient = new DefaultGradient();
        INDArray[] inputGrads = new INDArray[this.inputs.length];
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            this.sameDiff.clearExecutionCache();
            associateInputs();
            this.fn.updateVariable(this.outputVar.getVarName(), this.epsilon.dup());
            associateParams();
            this.sameDiff.execBackwards();

            if (hasParams) {
                for (String paramName : this.paramTable.keySet()) {
                    INDArray sdGrad = this.sameDiff.grad(paramName).getArr();
                    INDArray gradView = this.gradTable.get(paramName);
                    gradView.assign(sdGrad);
                    gradient.gradientForVariable().put(paramName, gradView);
                }
            }
            for (int i = 0; i < this.inputs.length; i++) {
                inputGrads[i] = this.sameDiff.grad(this.config.getVertexParams().getInputs().get(i)).getArr();
            }
        }

        // Copy input gradients into the caller's workspace after leaving the out-of-workspace scope.
        for (int i = 0; i < inputGrads.length; i++) {
            inputGrads[i] = layerWorkspaceMgr.dup(ArrayType.ACTIVATION_GRAD, inputGrads[i]);
        }
        return new Pair<>(gradient, inputGrads);
    }

    /**
     * Stores the flattened gradient view and splits it into per-parameter gradient arrays.
     *
     * @param iNDArray flattened gradient view array for this vertex
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        // Keep a reference so getGradientsViewArray() reports the view instead of null.
        this.gradients = iNDArray;
        SDVertexParams vertexParams = this.config.getVertexParams();
        this.gradTable = SameDiffParamInitializer.getInstance().subsetAndReshape(vertexParams.getParameterKeys(), vertexParams.getParamShapes(), iNDArray, null, this.config);
    }

    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] iNDArrayArr, MaskState maskState, int i) {
        throw new UnsupportedOperationException("Not yet supported");
    }

    /**
     * Builds the SameDiff graph:
     * <ul>
     *   <li>creates the input variables with the shapes of the current inputs;</li>
     *   <li>creates the parameter variables;</li>
     *   <li>lets the configuration define the output;</li>
     *   <li>wires an external-errors function on that output for backprop.</li>
     * </ul>
     */
    protected void doInit() {
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            this.sameDiff = SameDiff.create();

            this.inputVars = new LinkedHashMap<>();
            int inputIdx = 0;
            for (String inputName : this.config.getVertexParams().getInputs()) {
                this.inputVars.put(inputName, this.sameDiff.var(inputName, this.inputs[inputIdx++].shape().clone()));
            }

            Map<String, long[]> paramShapes = this.config.getVertexParams().getParamShapes();
            Map<String, SDVariable> paramVars = new LinkedHashMap<>();
            for (Map.Entry<String, long[]> shapeEntry : paramShapes.entrySet()) {
                paramVars.put(shapeEntry.getKey(), this.sameDiff.var(shapeEntry.getKey(), shapeEntry.getValue()));
            }

            SDVariable layerOutput = this.config.defineVertex(this.sameDiff, this.inputVars, paramVars);
            Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
            this.outputVar = layerOutput;

            if (this.paramTable != null) {
                for (Map.Entry<String, INDArray> entry : this.paramTable.entrySet()) {
                    this.sameDiff.associateArrayWithVariable(entry.getValue(), this.sameDiff.getVariable(entry.getKey()));
                }
            }

            this.fn = this.sameDiff.f().externalErrors(new SDVariable[]{layerOutput});
            // Return value intentionally unused; NOTE(review): presumably forces creation of the
            // function's output variable as a side effect - confirm against ExternalErrorsFunction.
            this.fn.outputVariable();
            this.outputKey = this.outputVar.getVarName();
        }
    }

    /** Binds a copy of each current input array to the SameDiff variable of the same configured name. */
    private void associateInputs() {
        for (int i = 0; i < this.inputs.length; i++) {
            String inputName = this.config.getVertexParams().getInputs().get(i);
            this.sameDiff.associateArrayWithVariable(this.inputs[i].dup(), this.sameDiff.getVariable(inputName));
        }
    }

    /** Binds each parameter array to its SameDiff variable by name; no-op if there is no param table. */
    private void associateParams() {
        if (this.paramTable == null) {
            return;
        }
        for (Map.Entry<String, INDArray> entry : this.paramTable.entrySet()) {
            this.sameDiff.associateArrayWithVariable(entry.getValue(), entry.getKey());
        }
    }

    @Override
    public void clearVertex() {
        clear();
    }

    @Override
    public Map<String, INDArray> paramTable(boolean z) {
        return this.paramTable;
    }

    @Override
    public TrainingConfig getConfig() {
        return this.config;
    }

    @Override
    public INDArray params() {
        return this.params;
    }

    /** Returns the gradient view array most recently passed to {@link #setBackpropGradientsViewArray}. */
    @Override
    public INDArray getGradientsViewArray() {
        return this.gradients;
    }
}
