package org.deeplearning4j.nn.updater;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Trainable;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.updater.UpdaterBlock;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.IUpdater;

/* loaded from: input_file:org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.class */
/**
 * Base updater that applies per-parameter updates for every layer (or vertex) of a model, using a
 * single flattened layout for parameters, gradients and updater state.
 *
 * <p>Consecutive parameters whose updater configurations compare equal (according to
 * {@code UpdaterUtils.updaterConfigurationsEquals}) are merged into one {@link UpdaterBlock}, so
 * one update can be applied to a contiguous range instead of many small ones.
 *
 * <p>{@link #layersByName} is never assigned in this class. It must be populated (presumably by
 * subclass constructors - TODO confirm) before
 * {@link #update(Gradient, int, int, int, LayerWorkspaceMgr)} is called.
 *
 * @param <T> the model type whose parameters are updated
 */
public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater {
    /** The model whose parameters this updater modifies. */
    protected final T network;
    /** Layer name to layer; used to find the layer to normalize gradients for in {@code update}. */
    protected Map<String, Trainable> layersByName;
    /** Groups of consecutive parameters that share an equal updater configuration, in parameter order. */
    protected final List<UpdaterBlock> updaterBlocks;
    /** Flattened updater state for all blocks; null when no parameter needs updater state. */
    protected INDArray updaterStateViewArray;
    /** Whether {@link #gradientsForMinibatchDivision} has been lazily computed yet. */
    protected boolean initializedMinibatchDivision;
    /** Cached views into the flattened gradient array that are divided by the minibatch size. */
    protected List<INDArray> gradientsForMinibatchDivision;

    /**
     * Creates an updater whose state array (if any state is needed) is freshly allocated.
     *
     * @param network the model to update
     */
    public BaseMultiLayerUpdater(T network) {
        this(network, null);
    }

    /**
     * Creates an updater, building updater blocks from the model's layers.
     *
     * <p>NOTE: this constructor calls the abstract methods {@link #getOrderedLayers()} and
     * {@link #getFlattenedGradientsView()}. Subclasses must be able to answer them before their
     * own constructor body has run.
     *
     * @param network      the model to update
     * @param updaterState existing updater state to use as the state view array, or null to
     *                     allocate a new (uninitialized) array of shape [1, totalStateSize] when
     *                     any state is required
     */
    public BaseMultiLayerUpdater(T network, INDArray updaterState) {
        this.network = network;
        this.updaterBlocks = new ArrayList<>();

        Trainable[] layers = getOrderedLayers();
        INDArray paramsView = network.params();
        INDArray gradientView = getFlattenedGradientsView();

        Trainable lastLayer = null;
        String lastVariable = null;
        UpdaterBlock currentBlock = null;
        int paramsSoFar = 0;
        int stateSoFar = 0;

        for (Trainable layer : layers) {
            Map<String, INDArray> paramTable = layer.paramTable(false);
            if (paramTable == null) {
                continue;
            }
            // Snapshot the keys so iteration order is fixed for the whole pass.
            List<String> variables = new ArrayList<>(paramTable.keySet());
            for (String variable : variables) {
                long paramLength = paramTable.get(variable).length();
                IUpdater updaterConfig = layer.getConfig().getUpdaterByParam(variable);
                Preconditions.checkNotNull(updaterConfig, "Updater for parameter %s, layer \"%s\" was null", variable, layer.getConfig().getLayerName());
                int stateLength = (int) updaterConfig.stateSize(paramLength);
                int paramEnd = (int) (paramsSoFar + paramLength);

                INDArray paramSubset = null;
                INDArray gradientSubset = null;
                if (paramLength > 0) {
                    paramSubset = paramsView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(paramsSoFar, paramsSoFar + paramLength));
                    gradientSubset = gradientView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(paramsSoFar, paramsSoFar + paramLength));
                }

                UpdaterBlock.ParamState paramState = new UpdaterBlock.ParamState(layer, variable, paramsSoFar, paramEnd, paramSubset, gradientSubset);
                // Start a new block unless this parameter's updater config equals the previous one's.
                if (currentBlock == null || !UpdaterUtils.updaterConfigurationsEquals(lastLayer, lastVariable, layer, variable)) {
                    List<UpdaterBlock.ParamState> blockParams = new ArrayList<>();
                    blockParams.add(paramState);
                    currentBlock = new UpdaterBlock(paramsSoFar, paramEnd, stateSoFar, stateSoFar + stateLength, blockParams);
                    this.updaterBlocks.add(currentBlock);
                } else {
                    currentBlock.setParamOffsetEnd((int) (currentBlock.getParamOffsetEnd() + paramLength));
                    currentBlock.setUpdaterViewOffsetEnd(currentBlock.getUpdaterViewOffsetEnd() + stateLength);
                    currentBlock.getLayersAndVariablesInBlock().add(paramState);
                }

                lastLayer = layer;
                lastVariable = variable;
                paramsSoFar = paramEnd;
                stateSoFar += stateLength;
            }
        }

        // Only a freshly allocated (uninitialized) state array needs initialization by the blocks.
        boolean stateRequiresInit = false;
        if (updaterState != null) {
            this.updaterStateViewArray = updaterState;
        } else if (stateSoFar > 0) {
            this.updaterStateViewArray = Nd4j.createUninitialized(new int[] {1, stateSoFar}, Nd4j.order().charValue());
            stateRequiresInit = true;
        }

        // Hand each block its contiguous slice of the state array and of the gradient array.
        int stateOffset = 0;
        int gradientOffset = 0;
        for (UpdaterBlock block : this.updaterBlocks) {
            int blockStateLength = block.getUpdaterViewOffsetEnd() - block.getUpdaterViewOffsetStart();
            int blockParamLength = block.getParamOffsetEnd() - block.getParamOffsetStart();
            if (blockStateLength > 0) {
                block.setUpdaterView(this.updaterStateViewArray.get(NDArrayIndex.point(0L), NDArrayIndex.interval(stateOffset, stateOffset + blockStateLength)));
                block.setUpdaterViewRequiresInitialization(stateRequiresInit);
            }
            if (blockParamLength > 0) {
                block.setGradientView(gradientView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(gradientOffset, gradientOffset + blockParamLength)));
            }
            block.init();
            stateOffset += blockStateLength;
            gradientOffset += blockParamLength;
        }
    }

    /** Returns the layers/vertices with parameters, in the order their parameters are flattened. */
    protected abstract Trainable[] getOrderedLayers();

    /** Returns the model's flattened gradient view array. */
    protected abstract INDArray getFlattenedGradientsView();

    /** Returns the model's flattened parameter array. */
    protected abstract INDArray getParams();

    /** Returns true if gradients should be divided by the minibatch size before updating. */
    protected abstract boolean isMiniBatch();

    /**
     * Copies {@code viewArray} into this updater's state view array.
     *
     * @param viewArray the state to copy; must be null when this updater holds no state
     * @throws IllegalStateException if a non-null array is supplied but this updater has no state,
     *                               if a null array is supplied for an updater with state, or if
     *                               the lengths differ
     */
    public void setStateViewArray(INDArray viewArray) {
        if (this.updaterStateViewArray == null) {
            if (viewArray != null) {
                throw new IllegalStateException("Attempting to set updater state view array, but this updater has no state (state view array is null)");
            }
            // No state held and nothing supplied: nothing to do.
            return;
        }
        if (viewArray == null) {
            throw new IllegalStateException("Attempting to set updater state view array with null value");
        }
        if (this.updaterStateViewArray.length() != viewArray.length()) {
            throw new IllegalStateException("Invalid input: view arrays differ in length. Expected length " + this.updaterStateViewArray.length() + ", got length " + viewArray.length());
        }
        this.updaterStateViewArray.assign(viewArray);
    }

    /** Delegates to {@link #setStateViewArray(INDArray)}; the layer and init flag are ignored. */
    @Override // org.deeplearning4j.nn.api.Updater
    public void setStateViewArray(Trainable trainable, INDArray viewArray, boolean initialize) {
        setStateViewArray(viewArray);
    }

    /** Returns the live (not copied) updater state view array, or null if there is no state. */
    @Override // org.deeplearning4j.nn.api.Updater
    public INDArray getStateViewArray() {
        return this.updaterStateViewArray;
    }

    /**
     * Returns a duplicate of the updater state, after committing pending executioner operations.
     *
     * @return a copy of the state view array, or null when this updater holds no state
     */
    public synchronized INDArray getStateViewArrayCopy() {
        if (this.updaterStateViewArray == null) {
            return null;
        }
        Nd4j.getExecutioner().commit();
        return this.updaterStateViewArray.dup();
    }

    /** Delegates to {@link #update(Gradient, int, int, int, LayerWorkspaceMgr)}; the layer is ignored. */
    @Override // org.deeplearning4j.nn.api.Updater
    public void update(Trainable trainable, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr) {
        update(gradient, iteration, epoch, batchSize, workspaceMgr);
    }

    /**
     * Applies gradient normalization, optional minibatch division, and the configured updaters.
     *
     * @param gradient     the gradient to apply; keys are expected as "layerName_paramName" unless
     *                     this is a single-layer updater
     * @param iteration    current iteration count
     * @param epoch        current epoch count
     * @param batchSize    minibatch size used for division when {@link #isMiniBatch()} is true
     * @param workspaceMgr workspace manager providing updater working memory
     * @throws IllegalStateException if a gradient key has no '_' layer separator
     */
    public synchronized void update(Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr) {
        // Identity (not value) comparison: a gradient array other than this model's own flattened
        // gradient view is treated as external and applied via updateExternalGradient.
        boolean isExternal = gradient.gradient() != getFlattenedGradientsView();

        Map<String, Gradient> gradientsByLayer = splitGradientByLayer(gradient);

        if (isMiniBatch()) {
            divideByMinibatch(isExternal, gradient, batchSize);
        }

        for (Map.Entry<String, Gradient> entry : gradientsByLayer.entrySet()) {
            preApply(this.layersByName.get(entry.getKey()), entry.getValue(), iteration);
        }

        if (getClass() != LayerUpdater.class) {
            workspaceMgr.assertNotOpen(ArrayType.UPDATER_WORKING_MEM, "Updater working memory");
        }

        for (UpdaterBlock block : this.updaterBlocks) {
            if (block.skipDueToPretrainConfig(this instanceof LayerUpdater)) {
                continue;
            }
            // try-with-resources closes the workspace scope on both branches, including on failure.
            try (MemoryWorkspace ignored = workspaceMgr.notifyScopeEntered(ArrayType.UPDATER_WORKING_MEM)) {
                if (isExternal) {
                    block.updateExternalGradient(iteration, epoch, gradient.gradient(), getParams());
                } else {
                    block.update(iteration, epoch);
                }
            }
        }
    }

    /**
     * Groups gradient entries by layer name. Each key is split at its last '_' into layer name and
     * parameter name. A single-layer updater maps the whole gradient to that one layer.
     */
    private Map<String, Gradient> splitGradientByLayer(Gradient gradient) {
        Map<String, Gradient> byLayer = new HashMap<>();
        Trainable[] layers = getOrderedLayers();
        if (layers.length == 1 && isSingleLayerUpdater()) {
            byLayer.put(layers[0].getConfig().getLayerName(), gradient);
            return byLayer;
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            int separator = key.lastIndexOf('_');
            if (separator == -1) {
                throw new IllegalStateException("Invalid key: Gradient key does not have layer separator: \"" + key + "\"");
            }
            String layerName = key.substring(0, separator);
            Gradient layerGradient = byLayer.get(layerName);
            if (layerGradient == null) {
                layerGradient = new DefaultGradient();
                byLayer.put(layerName, layerGradient);
            }
            layerGradient.setGradientFor(key.substring(separator + 1), entry.getValue());
        }
        return byLayer;
    }

    /**
     * Divides, in place, the gradient ranges whose parameters request minibatch division.
     *
     * @param isExternal whether {@code gradient} is external to this model's gradient view; if so,
     *                   its subsets are computed fresh instead of using the cached view subsets
     * @param gradient   the gradient being applied
     * @param batchSize  the divisor
     */
    protected void divideByMinibatch(boolean isExternal, Gradient gradient, int batchSize) {
        if (!this.initializedMinibatchDivision) {
            this.gradientsForMinibatchDivision = getMinibatchDivisionSubsets(getFlattenedGradientsView());
            this.initializedMinibatchDivision = true;
        }
        List<INDArray> toDivide = isExternal ? getMinibatchDivisionSubsets(gradient.gradient()) : this.gradientsForMinibatchDivision;
        for (INDArray subset : toDivide) {
            subset.divi(batchSize);
        }
    }

    /**
     * Returns views into {@code from} covering the maximal runs of consecutive parameters for
     * which {@code updaterDivideByMinibatch} is true. Layers with a null parameter table are skipped,
     * consistent with the constructor.
     *
     * @param from a flattened [1, n] gradient array laid out in {@link #getOrderedLayers()} order
     * @return the sub-views to divide (possibly empty)
     */
    protected List<INDArray> getMinibatchDivisionSubsets(INDArray from) {
        List<INDArray> subsets = new ArrayList<>();
        long paramsSoFar = 0;
        long runStart = 0;
        long runEnd = 0;
        for (Trainable layer : getOrderedLayers()) {
            Map<String, INDArray> paramTable = layer.paramTable(false);
            if (paramTable == null) {
                continue;
            }
            for (Map.Entry<String, INDArray> param : paramTable.entrySet()) {
                long length = param.getValue().length();
                if (layer.updaterDivideByMinibatch(param.getKey())) {
                    runEnd += length;
                } else {
                    // This parameter is excluded: close the current run and restart after it.
                    if (runEnd > runStart) {
                        subsets.add(from.get(NDArrayIndex.point(0L), NDArrayIndex.interval(runStart, runEnd)));
                    }
                    runStart = paramsSoFar + length;
                    runEnd = runStart;
                }
                paramsSoFar += length;
            }
        }
        if (runEnd > runStart && runStart < from.length()) {
            subsets.add(from.get(NDArrayIndex.point(0L), NDArrayIndex.interval(runStart, runEnd)));
        }
        return subsets;
    }

    /** Whether this updater handles exactly one layer whose gradient keys are not layer-prefixed. */
    protected boolean isSingleLayerUpdater() {
        return false;
    }

    /**
     * Applies the layer's configured gradient normalization, in place. Per-layer strategies act on
     * the layer's gradient view array. Per-parameter-type strategies act on each array in
     * {@code gradient}.
     *
     * @param trainable the layer whose configuration selects the strategy
     * @param gradient  the layer's gradients (used by the per-parameter-type strategies)
     * @param iteration current iteration (unused here)
     * @throws RuntimeException for an unrecognized normalization strategy
     */
    public void preApply(Trainable trainable, Gradient gradient, int iteration) {
        if (trainable.getConfig() == null || trainable.numParams() == 0) {
            return;
        }
        GradientNormalization normalization = trainable.getConfig().getGradientNormalization();
        if (normalization == null || normalization == GradientNormalization.None) {
            return;
        }
        double threshold = trainable.getConfig().getGradientNormalizationThreshold();
        INDArray layerGradientView = trainable.getGradientsViewArray();
        switch (normalization) {
            case RenormalizeL2PerLayer:
                if (layerGradientView != null) {
                    double l2 = layerGradientView.norm2Number().doubleValue();
                    if (l2 == 0.0) {
                        l2 = 1.0E-5d; // avoid 0/0 -> NaN
                    }
                    layerGradientView.divi(l2);
                }
                return;
            case RenormalizeL2PerParamType:
                for (INDArray paramGradient : gradient.gradientForVariable().values()) {
                    double l2 = Nd4j.getExecutioner().execAndReturn(new Norm2(paramGradient, new int[0])).getFinalResult().doubleValue();
                    if (l2 == 0.0) {
                        l2 = 1.0E-5d; // avoid 0/0 -> NaN
                    }
                    paramGradient.divi(l2);
                }
                return;
            case ClipElementWiseAbsoluteValue:
                if (layerGradientView != null) {
                    Nd4j.getExecutioner().exec(DynamicCustomOp.builder("clipbyvalue").addInputs(new INDArray[] {layerGradientView}).callInplace(true).addFloatingPointArguments(new Double[] {-threshold, threshold}).build());
                }
                return;
            case ClipL2PerLayer:
                if (layerGradientView != null) {
                    double l2 = layerGradientView.norm2Number().doubleValue();
                    if (l2 > threshold) {
                        layerGradientView.muli(threshold / l2);
                    }
                }
                return;
            case ClipL2PerParamType:
                for (INDArray paramGradient : gradient.gradientForVariable().values()) {
                    double l2 = paramGradient.norm2Number().doubleValue();
                    if (l2 > threshold) {
                        paramGradient.divi(l2 / threshold);
                    }
                }
                return;
            default:
                throw new RuntimeException("Unknown (or not implemented) gradient normalization strategy: " + normalization);
        }
    }

    /** Two updaters of the same class are equal when their state view arrays are equal. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        BaseMultiLayerUpdater<?> other = (BaseMultiLayerUpdater<?>) obj;
        return this.updaterStateViewArray != null ? this.updaterStateViewArray.equals(other.updaterStateViewArray) : other.updaterStateViewArray == null;
    }

    /** Hashes only the state view array, the sole field {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        return this.updaterStateViewArray != null ? this.updaterStateViewArray.hashCode() : 0;
    }

    public T getNetwork() {
        return this.network;
    }

    public Map<String, Trainable> getLayersByName() {
        return this.layersByName;
    }

    public List<UpdaterBlock> getUpdaterBlocks() {
        return this.updaterBlocks;
    }

    public INDArray getUpdaterStateViewArray() {
        return this.updaterStateViewArray;
    }

    public boolean isInitializedMinibatchDivision() {
        return this.initializedMinibatchDivision;
    }

    public List<INDArray> getGradientsForMinibatchDivision() {
        return this.gradientsForMinibatchDivision;
    }
}
