package org.deeplearning4j.nn.layers.samediff;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/samediff/SameDiffLayer.class */
/**
 * Runtime layer whose forward pass is defined by a {@link SameDiff} graph built from an
 * {@link org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer} configuration.
 *
 * <p>The graph is built lazily by {@link #doInit()} on first use. Parameters and gradients are
 * held as flattened view arrays ({@link #params}, {@link #gradients}) plus per-parameter arrays
 * keyed by parameter name ({@link #paramTable}, {@link #gradTable}).
 */
public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
    /** Name of the SameDiff variable that receives the layer input. */
    public static final String INPUT_KEY = "input";

    protected SameDiff sameDiff;
    protected SDVariable outputVar;
    protected ExternalErrorsFunction fn;
    protected String outputKey;
    protected INDArray params;
    protected INDArray gradients;
    protected Map<String, INDArray> paramTable;
    protected Map<String, INDArray> gradTable;

    public SameDiffLayer(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    /*
     * NOTE(review): this was `clone()` in the original source; the decompiler renamed it when
     * merging it with its bridge method. Cloning is not supported by this layer.
     */
    public Layer m157clone() {
        throw new UnsupportedOperationException();
    }

    @Override // org.deeplearning4j.nn.api.Layer
    public boolean isPretrainLayer() {
        return false;
    }

    /** No-op: this layer has no noise/weight-noise parameters to clear. */
    @Override // org.deeplearning4j.nn.api.Layer
    public void clearNoiseWeightParams() {
    }

    /**
     * Runs the forward pass through the SameDiff graph, building the graph first if needed.
     *
     * <p>Execution happens inside Nd4j's {@code scopeOutOfWorkspaces()} scope. The result is then
     * copied as {@link ArrayType#ACTIVATIONS} via {@code workspaceMgr}.
     *
     * @param training not used by this implementation
     * @param workspaceMgr manager used to place the returned activations
     * @return the layer output
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        if (sameDiff == null) {
            doInit();
        }
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            bindInput();
            bindParams();
            INDArray out = (INDArray) sameDiff.exec((Map) null, new String[] {outputKey}).get(outputKey);
            return workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
        }
    }

    /**
     * Backpropagates {@code epsilon} through the SameDiff graph.
     *
     * <p>Each parameter gradient is copied into the matching {@link #gradTable} array, and that
     * array is registered in the returned {@link Gradient}.
     *
     * @param epsilon gradient of the loss with respect to this layer's output
     * @param workspaceMgr manager used to place the returned input gradient
     * @return the parameter gradients, paired with the gradient with respect to the layer input
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        // Same lazy initialisation as activate(); avoids an NPE if backprop is called first.
        if (sameDiff == null) {
            doInit();
        }
        DefaultGradient gradient = new DefaultGradient();
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            bindInput();
            fn.updateVariable(outputVar.getVarName(), epsilon.dup());
            bindParams();
            sameDiff.execBackwards(Collections.emptyMap());

            for (String key : paramTable.keySet()) {
                INDArray computed = sameDiff.grad(key).getArr();
                INDArray target = gradTable.get(key);
                // Copy into the pre-allocated gradient array rather than replacing it.
                target.assign(computed);
                gradient.gradientForVariable().put(key, target);
            }

            INDArray inputGrad = sameDiff.grad(INPUT_KEY).getArr();
            return new Pair<>(gradient, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, inputGrad));
        }
    }

    /** Associates a copy of the current input with the graph's input variable. */
    private void bindInput() {
        sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
    }

    /** Associates every parameter array in {@link #paramTable} with its graph variable. */
    private void bindParams() {
        for (Map.Entry<String, INDArray> entry : paramTable.entrySet()) {
            sameDiff.associateArrayWithVariable(entry.getValue(), entry.getKey());
        }
    }

    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model, org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray params() {
        return params;
    }

    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public INDArray getParam(String key) {
        return paramTable.get(key);
    }

    /**
     * Returns the number of parameters, or 0 if no parameter view has been set.
     *
     * <p>No narrowing cast: the parameter count may exceed {@code Integer.MAX_VALUE}.
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public long numParams() {
        return params == null ? 0L : params.length();
    }

    /**
     * Copies {@code val} into the existing parameter array for {@code key}.
     *
     * @throws IllegalArgumentException if {@code key} is unknown or the shapes differ
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void setParam(String key, INDArray val) {
        if (!paramTable.containsKey(key)) {
            throw new IllegalArgumentException("Cannot set parameter, invalid/unknown parameter key: " + key);
        }
        INDArray current = paramTable.get(key);
        if (!Arrays.equals(current.shape(), val.shape())) {
            throw new IllegalArgumentException("Cannot set parameter \"" + key + "\", invalid shape: parameter array has shape " + Arrays.toString(current.shape()) + ", trying to set parameter of shape " + Arrays.toString(val.shape()));
        }
        // Previously the value was validated but never written; assign in place so any views
        // of this array observe the new values.
        current.assign(val);
    }

    /** Only {@code null} is accepted; use {@link #setParamsViewArray} or {@link #setParam}. */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void setParams(INDArray params) {
        if (params != null) {
            throw new UnsupportedOperationException("Not supported");
        }
    }

    @Override // org.deeplearning4j.nn.layers.AbstractLayer
    protected void setParams(INDArray params, char order) {
        setParams(params);
    }

    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void setParamsViewArray(INDArray params) {
        this.params = params;
    }

    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public INDArray getGradientsViewArray() {
        return gradients;
    }

    /** Stores the flattened gradient array and derives the per-parameter gradient table from it. */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void setBackpropGradientsViewArray(INDArray gradients) {
        this.gradients = gradients;
        this.gradTable = layerConf().initializer().getGradientsFromFlattened(conf(), gradients);
    }

    /**
     * On first call, adopts {@code map} as the parameter table (by reference). On later calls,
     * copies each entry into the existing arrays via {@link #setParam}.
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void setParamTable(Map<String, INDArray> map) {
        if (paramTable == null) {
            paramTable = map;
            return;
        }
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            setParam(entry.getKey(), entry.getValue());
        }
    }

    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable() {
        return paramTable(false);
    }

    /** Returns the parameter table; the {@code backpropParamsOnly} flag is ignored. */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        return paramTable;
    }

    /**
     * Builds the SameDiff graph for this layer.
     *
     * <p>Creates the input and parameter variables and asks the layer configuration to define the
     * output. It then binds the current parameter arrays and wires an external-errors function on
     * the output so that {@link #backpropGradient} can inject epsilon.
     */
    protected void doInit() {
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer layerConfig =
                    (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
            sameDiff = SameDiff.create();
            Map<String, INDArray> currentParams = paramTable();

            // Input variable shaped like the current input (shape array copied defensively).
            SDVariable inputVar = sameDiff.var(INPUT_KEY, input.shape().clone());

            Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
            Map<String, SDVariable> paramVars = new LinkedHashMap<>();
            for (Map.Entry<String, long[]> shape : paramShapes.entrySet()) {
                paramVars.put(shape.getKey(), sameDiff.var(shape.getKey(), shape.getValue()));
            }

            SDVariable layerOutput = layerConfig.defineLayer(sameDiff, inputVar, paramVars);
            Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
            outputVar = layerOutput;

            for (Map.Entry<String, INDArray> entry : currentParams.entrySet()) {
                sameDiff.associateArrayWithVariable(entry.getValue(), sameDiff.getVariable(entry.getKey()));
            }

            fn = sameDiff.f().externalErrors(new SDVariable[] {layerOutput});
            // NOTE(review): return value discarded; presumably called to force creation of the
            // function's output variable — confirm against ExternalErrorsFunction.
            fn.outputVariable();
            outputKey = outputVar.getVarName();
        }
    }
}
