package org.deeplearning4j.nn.layers.samediff;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.util.SameDiffUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/layers/samediff/SameDiffLayer.class */
public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
    /** Name of the SameDiff placeholder that receives the layer input. */
    public static final String INPUT_KEY = "input";
    /** Name of the SameDiff placeholder that receives the input mask. */
    public static final String MASK_KEY = "mask";

    /** Graph for this layer; {@code null} until {@link #doInit()} has run. */
    protected SameDiff sameDiff;
    /** Variable returned by the configuration's {@code defineLayer}. */
    protected SDVariable outputVar;
    /** External-errors function attached to the output; its gradient placeholder receives epsilon in backprop. */
    protected ExternalErrorsFunction fn;
    /** Name of {@link #outputVar}, used to fetch the forward-pass result. */
    protected String outputKey;
    /** Parameter view array set via {@link #setParamsViewArray(INDArray)}. */
    protected INDArray params;
    /** Gradient view array set via {@link #setBackpropGradientsViewArray(INDArray)}. */
    protected INDArray gradients;
    /** Per-parameter arrays keyed by parameter name. */
    protected Map<String, INDArray> paramTable;
    /** Per-parameter gradient arrays, derived from {@link #gradients} by the layer's param initializer. */
    protected Map<String, INDArray> gradTable;

    public SameDiffLayer(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
    }

    /**
     * Cloning is not supported for this layer.
     *
     * <p>This method was renamed from {@code clone} by the decompiler (merged with the bridge
     * method); its name is kept unchanged.
     *
     * @throws UnsupportedOperationException always
     */
    public Layer m163clone() {
        throw new UnsupportedOperationException();
    }

    @Override // org.deeplearning4j.nn.api.Layer
    public boolean isPretrainLayer() {
        return false;
    }

    /** No-op: this layer has no noise weight parameters to clear. */
    @Override // org.deeplearning4j.nn.api.Layer
    public void clearNoiseWeightParams() {
    }

    /**
     * Runs the forward pass through the SameDiff graph, building the graph on first use.
     *
     * @param z                 training flag (not used by this implementation)
     * @param layerWorkspaceMgr supplies the FF working-memory and ACTIVATIONS workspaces
     * @return the layer output; duplicated into the ACTIVATIONS workspace if it is not already
     *         there, or detached when activations are scoped out
     * @throws IllegalStateException if activations have neither a workspace nor are scoped out
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(false);
        ensureInitialized();

        ((org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf()).validateInput(this.input);
        Map<String, INDArray> placeholders = inputPlaceholders();

        // Route SameDiff's memory management through the DL4J workspaces
        String wsNameWorking = layerWorkspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
        String wsNameOutput = layerWorkspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
        WorkspaceConfiguration confWorking = layerWorkspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
        WorkspaceConfiguration confOutput = layerWorkspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
        boolean actScopedOut = layerWorkspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
        Preconditions.checkState(actScopedOut || wsNameOutput != null,
                "Activations must have a workspace or must be scoped out");
        sessionForCurrentThread(this.sameDiff)
                .setMmgr(new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput));

        try {
            INDArray result = this.sameDiff.output(placeholders, this.outputKey).get(this.outputKey);
            if (actScopedOut) {
                if (result.isAttached()) {
                    result = result.detach();
                }
            } else if (!isInWorkspace(result, wsNameOutput)) {
                // The graph output is not in the activations workspace (it may even be detached,
                // e.g. presumably when the graph returns an input unchanged): copy it there.
                result = layerWorkspaceMgr.dup(ArrayType.ACTIVATIONS, result);
            }
            return result;
        } finally {
            // Clear on every path so the graph keeps no references to this call's arrays.
            this.sameDiff.clearPlaceholders(true);
            this.sameDiff.clearOpInputs();
        }
    }

    /**
     * Backpropagates {@code iNDArray} (dL/dOut) through the SameDiff graph.
     *
     * <p>Parameter gradients are copied into the arrays of {@link #gradTable}. The whole
     * computation runs with workspaces scoped out; the input gradient is then duplicated into
     * the ACTIVATION_GRAD workspace.
     *
     * @param iNDArray          gradient of the loss with respect to this layer's output
     * @param layerWorkspaceMgr supplies the BP working-memory and ACTIVATION_GRAD workspaces
     * @return the parameter gradients and the gradient with respect to the layer input
     * @throws IllegalStateException if activation gradients have neither a workspace nor are
     *         scoped out, or if no gradient was calculated for a parameter
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(true);
        DefaultGradient defaultGradient = new DefaultGradient();
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            if (this.sameDiff == null) {
                doInit();
            }
            if (!this.sameDiff.hasGradientFunction()) {
                this.sameDiff.createGradFunction(INPUT_KEY);
            }
            InferenceSession gradSession = sessionForCurrentThread(this.sameDiff.getFunction("grad"));

            String wsNameWorking = layerWorkspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
            String wsNameOutput = layerWorkspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
            WorkspaceConfiguration confWorking = layerWorkspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
            WorkspaceConfiguration confOutput = layerWorkspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
            Preconditions.checkState(layerWorkspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD) || wsNameOutput != null,
                    "Activation gradients must have a workspace or be scoped out");
            gradSession.setMmgr(new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput));

            ((org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf()).validateInput(this.input);
            Map<String, INDArray> placeholders = inputPlaceholders();
            placeholders.put(this.fn.getGradPlaceholderName(), iNDArray);

            List<String> requiredGrads = new ArrayList<>(this.paramTable.size() + 1);
            requiredGrads.add(INPUT_KEY);
            requiredGrads.addAll(this.paramTable.keySet());

            try {
                Map<String, INDArray> grads = this.sameDiff.calculateGradients(placeholders, requiredGrads);
                for (String paramName : this.paramTable.keySet()) {
                    INDArray grad = grads.get(paramName);
                    Preconditions.checkState(grad != null,
                            "No gradient was calculated for parameter \"%s\"", paramName);
                    INDArray gradView = this.gradTable.get(paramName);
                    gradView.assign(grad);
                    defaultGradient.gradientForVariable().put(paramName, gradView);
                }
                INDArray dLdIn = grads.get(INPUT_KEY);
                return new Pair<>(defaultGradient, layerWorkspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));
            } finally {
                // Clear on every path so the graph keeps no references to this call's arrays.
                this.sameDiff.clearPlaceholders(true);
                this.sameDiff.clearOpInputs();
            }
        }
    }

    /** Builds the SameDiff graph (with workspaces scoped out) if it has not been built yet. */
    private void ensureInitialized() {
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            if (this.sameDiff == null) {
                doInit();
            }
        }
    }

    /**
     * Returns a mutable placeholder map holding the layer input and its mask.
     * An all-ones mask from the layer config is used when no mask is set.
     */
    private Map<String, INDArray> inputPlaceholders() {
        Map<String, INDArray> placeholders = new HashMap<>();
        placeholders.put(INPUT_KEY, this.input);
        placeholders.put(MASK_KEY, this.maskArray != null ? this.maskArray : layerConf().onesMaskForInput(this.input));
        return placeholders;
    }

    /** Returns the {@link InferenceSession} of {@code sd} for the calling thread, creating it if absent. */
    private static InferenceSession sessionForCurrentThread(SameDiff sd) {
        Map<Long, InferenceSession> sessions = sd.getSessions();
        return sessions.computeIfAbsent(Thread.currentThread().getId(), id -> new InferenceSession(sd));
    }

    /** Returns true if {@code arr}'s buffer belongs to the workspace with id {@code wsName}; false if detached. */
    private static boolean isInWorkspace(INDArray arr, String wsName) {
        MemoryWorkspace parent = arr.data().getParentWorkspace();
        return parent != null && parent.getId().equals(wsName);
    }

    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model, org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray params() {
        return this.params;
    }

    /** Returns the parameter array named {@code str}, or {@code null} if there is no such parameter. */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public INDArray getParam(String str) {
        return this.paramTable.get(str);
    }

    /** Returns the total number of parameters (0 when no parameter array has been set). */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public long numParams() {
        // No int cast: layers with 2^31 or more parameters must not overflow.
        return this.params == null ? 0L : this.params.length();
    }

    /**
     * Copies {@code iNDArray} into the existing parameter array named {@code str}.
     *
     * @throws IllegalArgumentException if the key is unknown or the shapes differ
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void setParam(String str, INDArray iNDArray) {
        if (!this.paramTable.containsKey(str)) {
            throw new IllegalArgumentException("Cannot set parameter, invalid/unknown parameter key: " + str);
        }
        INDArray current = this.paramTable.get(str);
        if (!Arrays.equals(current.shape(), iNDArray.shape())) {
            throw new IllegalArgumentException("Cannot set parameter \"" + str + "\", invalid shape: parameter array has shape " + Arrays.toString(current.shape()) + ", trying to set parameter of shape " + Arrays.toString(iNDArray.shape()));
        }
        // Copy in place rather than replacing the reference. doInit() associates this same array
        // with the SameDiff variable, and other holders of the array must also see the new values.
        current.assign(iNDArray);
    }

    /**
     * Copies {@code iNDArray} into the parameter view array.
     *
     * @throws IllegalStateException if exactly one of the two arrays is null, or the lengths differ
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void setParams(INDArray iNDArray) {
        if (this.params == null && iNDArray == null) {
            return;
        }
        if (this.params == null) {
            throw new IllegalStateException("Cannot set parameters of length " + iNDArray.length() + " to a layer with no parameters");
        }
        if (iNDArray == null) {
            throw new IllegalStateException("Cannot set null parameters");
        }
        Preconditions.checkState(this.params.length() == iNDArray.length(), "Cannot assign parameter vector of length %s to a layer with %s parameters", iNDArray.length(), this.params.length());
        this.params.assign(iNDArray);
    }

    /** Delegates to {@link #setParams(INDArray)}; the order argument {@code c} is ignored. */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer
    protected void setParams(INDArray iNDArray, char c) {
        setParams(iNDArray);
    }

    /** Stores the parameter view array by reference (no copy). */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void setParamsViewArray(INDArray iNDArray) {
        this.params = iNDArray;
    }

    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public INDArray getGradientsViewArray() {
        return this.gradients;
    }

    /** Stores the gradient view array and rebuilds {@link #gradTable} from it via the layer's param initializer. */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        this.gradients = iNDArray;
        this.gradTable = layerConf().initializer().getGradientsFromFlattened(conf(), iNDArray);
    }

    /**
     * Sets the parameter table. The first call stores {@code map} by reference. Later calls copy
     * each entry into the existing arrays via {@link #setParam(String, INDArray)}.
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void setParamTable(Map<String, INDArray> map) {
        if (this.paramTable == null) {
            this.paramTable = map;
            return;
        }
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            setParam(entry.getKey(), entry.getValue());
        }
    }

    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable() {
        return paramTable(false);
    }

    /** Returns the internal parameter table itself (not a copy); {@code z} is ignored. */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable(boolean z) {
        return this.paramTable;
    }

    /**
     * Builds the SameDiff graph for this layer, with workspaces scoped out.
     *
     * <p>Creates an {@code input} placeholder with the current input's shape, where the first
     * (minibatch) dimension is set to -1. It also creates a {@code mask} placeholder of the same
     * rank with all dimensions -1, and one variable per parameter shape from the config.
     * The configuration's {@code defineLayer} is then called, and the existing parameter arrays
     * are associated with their graph variables.
     */
    protected void doInit() {
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer sameDiffLayer = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
            this.sameDiff = SameDiff.create();
            this.sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
            Map<String, INDArray> currentParams = paramTable();

            long[] inputShape = this.input.shape().clone();
            inputShape[0] = -1; // variable minibatch size
            SDVariable inputVar = this.sameDiff.placeHolder(INPUT_KEY, this.dataType, inputShape);

            Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
            Map<String, SDVariable> paramVars = new LinkedHashMap<>();
            for (Map.Entry<String, long[]> e : paramShapes.entrySet()) {
                paramVars.put(e.getKey(), this.sameDiff.var(e.getKey(), this.dataType, e.getValue()));
            }
            SDVariable maskVar = this.sameDiff.placeHolder(MASK_KEY, this.dataType, ArrayUtil.nTimes(inputShape.length, -1L));

            SDVariable layerOutput = sameDiffLayer.defineLayer(this.sameDiff, inputVar, paramVars, maskVar);
            Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
            this.outputVar = layerOutput;

            for (Map.Entry<String, INDArray> entry : currentParams.entrySet()) {
                this.sameDiff.associateArrayWithVariable(entry.getValue(), this.sameDiff.getVariable(entry.getKey()));
            }

            this.fn = SameDiffUtils.externalErrors(this.sameDiff, (Map) null, new SDVariable[]{layerOutput});
            // Called only for its side effect; the return value is not used.
            // NOTE(review): presumably ensures the function's output variable exists - confirm.
            this.fn.outputVariable();
            this.outputKey = this.outputVar.name();
        }
    }

    /**
     * Records the mask and its state on this layer, then delegates mask propagation to the
     * layer configuration.
     */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Layer
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray iNDArray, MaskState maskState, int i) {
        org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer sameDiffLayer = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
        this.maskArray = iNDArray;
        this.maskState = maskState;
        return sameDiffLayer.feedForwardMaskArray(iNDArray, maskState, i);
    }
}
