package org.deeplearning4j.gradientcheck;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.layers.LossLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.function.Consumer;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/gradientcheck/GradientCheckUtil.class */
/**
 * Numerical gradient checking for {@link MultiLayerNetwork}, {@link ComputationGraph} and single (pretrain)
 * {@link Layer} instances.
 * <p>
 * For every checked parameter entry, the parameter is perturbed by {@code +epsilon} and {@code -epsilon}, the score
 * is recomputed, and the central-difference estimate {@code (scorePlus - scoreMinus) / (2 * epsilon)} is compared
 * with the backprop gradient. An entry passes if its relative error {@code |g - n| / (|g| + |n|)} is at most
 * {@code maxRelError}, or if its absolute error {@code |g - n|} is below {@code minAbsoluteError}. NaN errors
 * always fail; a NaN numerical gradient throws.
 * <p>
 * The updater is applied to the backprop gradient before comparison, so only {@code NoOp} or {@code Sgd} with
 * lr=1.0 are accepted. Double precision is required both in the ND4J context and in the network. Note that the
 * checks mutate the supplied network: inputs/labels/masks are set, and clipping epsilons of MCXENT (with softmax)
 * and binary XENT loss functions are forced to 0.0.
 */
public class GradientCheckUtil {
    private static final Logger log = LoggerFactory.getLogger(GradientCheckUtil.class);

    /**
     * Activation functions considered smooth enough for gradient checks. Any other activation only triggers a
     * warning, since discontinuities (ReLU, LeakyReLU, ...) may cause spurious failures.
     */
    private static final List<Class<? extends IActivation>> VALID_ACTIVATION_FUNCTIONS = Arrays.asList(
                    Activation.CUBE.getActivationFunction().getClass(),
                    Activation.ELU.getActivationFunction().getClass(),
                    Activation.IDENTITY.getActivationFunction().getClass(),
                    Activation.RATIONALTANH.getActivationFunction().getClass(),
                    Activation.SIGMOID.getActivationFunction().getClass(),
                    Activation.SOFTMAX.getActivationFunction().getClass(),
                    Activation.SOFTPLUS.getActivationFunction().getClass(),
                    Activation.SOFTSIGN.getActivationFunction().getClass(),
                    Activation.TANH.getActivationFunction().getClass());

    private GradientCheckUtil() {
        // Static utility class; not instantiable.
    }

    /**
     * Disables loss-function clipping that would otherwise distort gradients near 0/1 outputs: softmax clipping
     * for MCXENT + softmax, and clipping for binary XENT. Other output layers/loss functions are left untouched.
     *
     * @param outputLayer output layer whose loss function may be modified in place
     */
    private static void configureLossFnClippingIfPresent(IOutputLayer outputLayer) {
        ILossFunction lossFn = null;
        IActivation activationFn = null;
        if (outputLayer instanceof BaseOutputLayer) {
            org.deeplearning4j.nn.conf.layers.BaseOutputLayer conf =
                            (org.deeplearning4j.nn.conf.layers.BaseOutputLayer) ((BaseOutputLayer) outputLayer).layerConf();
            lossFn = conf.getLossFn();
            activationFn = conf.getActivationFn();
        } else if (outputLayer instanceof LossLayer) {
            LossLayer lossLayer = (LossLayer) outputLayer;
            lossFn = lossLayer.layerConf().getLossFn();
            activationFn = lossLayer.layerConf().getActivationFn();
        }

        if (lossFn instanceof LossMCXENT && activationFn instanceof ActivationSoftmax
                        && ((LossMCXENT) lossFn).getSoftmaxClipEps() != 0.0) {
            log.info("Setting softmax clipping epsilon to 0.0 for " + lossFn.getClass()
                            + " loss function to avoid spurious gradient check failures");
            ((LossMCXENT) lossFn).setSoftmaxClipEps(0.0);
        } else if (lossFn instanceof LossBinaryXENT && ((LossBinaryXENT) lossFn).getClipEps() != 0.0) {
            log.info("Setting clipping epsilon to 0.0 for " + lossFn.getClass()
                            + " loss function to avoid spurious gradient check failures");
            ((LossBinaryXENT) lossFn).setClipEps(0.0);
        }
    }

    /**
     * Validates the perturbation size and relative-error tolerance. Written as negated range checks so that NaN
     * values are rejected as well.
     */
    private static void validateTolerances(double epsilon, double maxRelError) {
        if (!(epsilon > 0.0 && epsilon <= 0.1)) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (!(maxRelError > 0.0 && maxRelError <= 0.25)) {
            throw new IllegalArgumentException("Invalid maxRelError: " + maxRelError);
        }
    }

    /** Requires the ND4J context data type to be DOUBLE. */
    private static void requireDoubleContext() {
        DataType contextType = DataTypeUtil.getDtypeFromContext();
        if (contextType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (is: "
                            + contextType + "). Double precision must be used for gradient checks. Set "
                            + "DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil");
        }
    }

    /** Requires the network configuration to be DOUBLE and its parameters to match that data type. */
    private static void requireDoubleNetwork(DataType networkType, DataType paramsType) {
        if (networkType != DataType.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (is: "
                            + networkType + "). Double precision must be used for gradient checks. Create network with "
                            + ".dataType(DataType.DOUBLE) before using GradientCheckUtil");
        }
        if (networkType != paramsType) {
            throw new IllegalStateException("Parameters datatype does not match network configuration datatype (is: "
                            + paramsType + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE.");
        }
    }

    /**
     * Validates the updater and activation of one layer configuration.
     *
     * @param baseLayer layer configuration to validate
     * @param layerId   identifier used in messages (layer index for MLN, quoted vertex name for graphs)
     */
    private static void validateBaseLayerConfig(BaseLayer baseLayer, String layerId) {
        // Held as Object: only instanceof checks and string formatting are needed here.
        Object updater = baseLayer.getIUpdater();
        if (updater instanceof Sgd) {
            double lr = ((Sgd) updater).getLearningRate();
            if (lr != 1.0) {
                throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + layerId
                                + "; got " + updater + " with lr=" + lr + " for layer \"" + baseLayer.getLayerName() + "\"");
            }
        } else if (!(updater instanceof NoOp)) {
            throw new IllegalStateException("Must have Updater.NONE (or SGD + lr=1.0) for layer " + layerId + "; got " + updater);
        }

        IActivation activationFn = baseLayer.getActivationFn();
        if (activationFn != null && !VALID_ACTIVATION_FUNCTIONS.contains(activationFn.getClass())) {
            log.warn("Layer " + layerId + " is possibly using an unsuitable activation function: " + activationFn.getClass()
                            + ". Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not "
                            + "contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)");
        }
    }

    /**
     * Returns the exclusive end offset (in the flattened parameter vector) of each named parameter, in the order of
     * {@code paramNames}. Assumes the flattened parameter vector follows the param table's iteration order.
     */
    private static long[] computeParamEnds(Map<String, INDArray> paramTable, List<String> paramNames) {
        long[] ends = new long[paramNames.size()];
        long offset = 0;
        for (int k = 0; k < ends.length; k++) {
            offset += paramTable.get(paramNames.get(k)).length();
            ends[k] = offset;
        }
        return ends;
    }

    /**
     * Returns the index step to use within each named parameter: 1 when {@code subset} is false, otherwise
     * {@code max(1, length / maxPerParam)} (so short parameters are checked fully rather than only once).
     */
    private static long[] computeStepSizes(Map<String, INDArray> paramTable, List<String> paramNames, boolean subset,
                    int maxPerParam) {
        long[] steps = new long[paramNames.size()];
        for (int k = 0; k < steps.length; k++) {
            steps[k] = subset ? Math.max(1L, paramTable.get(paramNames.get(k)).length() / maxPerParam) : 1L;
        }
        return steps;
    }

    /**
     * Compares backprop gradient entries with their central-difference estimates, logs per-entry results when
     * requested, and tracks the number of checked/failed entries and the largest relative error.
     */
    private static final class GradientComparison {
        private final double epsilon;
        private final double maxRelError;
        private final double minAbsoluteError;
        private final boolean print;
        private long nChecked;
        private long nFailed;
        private double maxError;

        GradientComparison(double epsilon, double maxRelError, double minAbsoluteError, boolean print) {
            this.epsilon = epsilon;
            this.maxRelError = maxRelError;
            this.minAbsoluteError = minAbsoluteError;
            this.print = print;
        }

        /**
         * Compares one entry.
         *
         * @return true if the entry passed
         * @throws IllegalStateException if the numerical gradient is NaN
         */
        boolean compare(long paramIdx, long nParams, String paramName, double backpropGradient, double scorePlus,
                        double scoreMinus, double origValue) {
            double numericalGradient = (scorePlus - scoreMinus) / (2.0 * epsilon);
            if (Double.isNaN(numericalGradient)) {
                throw new IllegalStateException("Numerical gradient was NaN for parameter " + paramIdx + " of " + nParams);
            }
            nChecked++;

            double relError = Math.abs(backpropGradient - numericalGradient)
                            / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
            if (backpropGradient == 0.0 && numericalGradient == 0.0) {
                relError = 0.0; // both exactly zero: avoid 0/0 = NaN
            }
            if (relError > maxError) {
                maxError = relError;
            }

            // Both comparisons below are false for NaN, so NaN errors fall through to the failure branch.
            if (relError <= maxRelError) {
                if (print) {
                    log.info("Param " + paramIdx + " (" + paramName + ") passed: grad= " + backpropGradient
                                    + ", numericalGrad= " + numericalGradient + ", relError= " + relError);
                }
                return true;
            }
            double absError = Math.abs(backpropGradient - numericalGradient);
            if (absError < minAbsoluteError) {
                if (print) {
                    log.info("Param " + paramIdx + " (" + paramName + ") passed: grad= " + backpropGradient
                                    + ", numericalGrad= " + numericalGradient + ", relError= " + relError
                                    + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsoluteError);
                }
                return true;
            }
            if (print) {
                log.info("Param " + paramIdx + " (" + paramName + ") FAILED: grad= " + backpropGradient
                                + ", numericalGrad= " + numericalGradient + ", relError= " + relError + ", scorePlus="
                                + scorePlus + ", scoreMinus= " + scoreMinus + ", paramValue = " + origValue);
            }
            nFailed++;
            return false;
        }

        boolean passed() {
            return nFailed == 0;
        }

        void logSummary() {
            if (print) {
                log.info("GradientCheckUtil.checkGradients(): " + nChecked + " params checked, " + (nChecked - nFailed)
                                + " passed, " + nFailed + " failed. Largest relative error = " + maxError);
            }
        }
    }

    /** Checks all gradients of {@code mln} without masks, subset checking or exclusions. */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels) {
        return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels,
                        (INDArray) null, (INDArray) null);
    }

    /** Checks gradients of {@code mln}, skipping the parameters named in {@code excludeParams}. */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels,
                    Set<String> excludeParams) {
        return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels,
                        (INDArray) null, (INDArray) null, false, -1, excludeParams, (Integer) null);
    }

    /** Checks all gradients of {@code mln} with optional (nullable) input and label masks. */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels,
                    INDArray inputMask, INDArray labelMask) {
        return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels,
                        inputMask, labelMask, false, -1);
    }

    /** Checks gradients of {@code mln}, optionally only a subset of entries per parameter. */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels,
                    INDArray inputMask, INDArray labelMask, boolean subset, int maxPerParam) {
        return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels,
                        inputMask, labelMask, subset, maxPerParam, null);
    }

    /** Checks gradients of {@code mln} with subset checking and parameter exclusion. */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels,
                    INDArray inputMask, INDArray labelMask, boolean subset, int maxPerParam, Set<String> excludeParams) {
        return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels,
                        inputMask, labelMask, subset, maxPerParam, excludeParams, (Consumer<MultiLayerNetwork>) null);
    }

    /**
     * Checks gradients of {@code mln}, resetting the ND4J RNG seed to {@code rngSeedResetEachIter} (if non-null)
     * before every forward pass. This is required when the network uses dropout.
     */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels,
                    INDArray inputMask, INDArray labelMask, boolean subset, int maxPerParam, Set<String> excludeParams,
                    final Integer rngSeedResetEachIter) {
        Consumer<MultiLayerNetwork> callEachIter = null;
        if (rngSeedResetEachIter != null) {
            callEachIter = new Consumer<MultiLayerNetwork>() {
                @Override
                public void accept(MultiLayerNetwork net) {
                    Nd4j.getRandom().setSeed(rngSeedResetEachIter.intValue());
                }
            };
        }
        return checkGradients(mln, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, input, labels,
                        inputMask, labelMask, subset, maxPerParam, excludeParams, callEachIter);
    }

    /**
     * Checks the backprop gradients of a {@link MultiLayerNetwork} against central-difference estimates.
     *
     * @param mln              network to check; must have an {@link IOutputLayer} and DOUBLE data type
     * @param epsilon          perturbation size, in (0, 0.1]
     * @param maxRelError      maximum relative error per entry, in (0, 0.25]
     * @param minAbsoluteError entries with absolute error below this pass even if the relative error is too large
     * @param print            whether to log per-entry results and a summary
     * @param exitOnFirstError whether to return false immediately at the first failing entry
     * @param input            network input
     * @param labels           labels
     * @param inputMask        input (feature) mask; may be null
     * @param labelMask        label mask; may be null
     * @param subset           if true, check every n-th entry of each parameter, n = max(1, length / maxPerParam)
     * @param maxPerParam      used only when {@code subset} is true; must then be positive
     * @param excludeParams    parameter names (param table keys) to skip; may be null
     * @param callEachIter     invoked before every forward pass (e.g. to reset the RNG); may be null, but must be
     *                         non-null when any layer has dropout
     * @return true if every checked entry passed
     * @throws IllegalArgumentException on invalid tolerances, missing output layer, or non-positive maxPerParam
     *                                  in subset mode
     * @throws IllegalStateException    on non-DOUBLE data types, unsupported updaters, dropout without
     *                                  {@code callEachIter}, or a NaN numerical gradient
     */
    public static boolean checkGradients(MultiLayerNetwork mln, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, INDArray labels,
                    INDArray inputMask, INDArray labelMask, boolean subset, int maxPerParam, Set<String> excludeParams,
                    Consumer<MultiLayerNetwork> callEachIter) {
        validateTolerances(epsilon, maxRelError);
        if (subset && maxPerParam <= 0) {
            throw new IllegalArgumentException("Invalid maxPerParam: must be > 0 when subset is true; got " + maxPerParam);
        }
        if (!(mln.getOutputLayer() instanceof IOutputLayer)) {
            throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
        }
        requireDoubleContext();
        requireDoubleNetwork(mln.getLayerWiseConfigurations().getDataType(), mln.params().dataType());

        int layerIdx = 0;
        for (NeuralNetConfiguration conf : mln.getLayerWiseConfigurations().getConfs()) {
            if (conf.getLayer() instanceof BaseLayer) {
                validateBaseLayerConfig((BaseLayer) conf.getLayer(), String.valueOf(layerIdx));
            }
            if (conf.getLayer().getIDropout() != null && callEachIter == null) {
                throw new IllegalStateException("When gradient checking dropout, need to reset RNG seed each iter, or no "
                                + "dropout should be present during gradient checks - got dropout = "
                                + conf.getLayer().getIDropout() + " for layer " + layerIdx);
            }
            layerIdx++;
        }

        for (Layer layer : mln.getLayers()) {
            if (layer instanceof IOutputLayer) {
                configureLossFnClippingIfPresent((IOutputLayer) layer);
            }
        }

        mln.setInput(input);
        mln.setLabels(labels);
        mln.setLayerMaskArrays(inputMask, labelMask);
        if (callEachIter != null) {
            callEachIter.accept(mln);
        }
        mln.computeGradientAndScore();
        Gradient gradient = mln.gradientAndScore().getFirst();
        // Apply the (validated: NoOp or Sgd lr=1.0) updater to the raw gradient before comparing.
        UpdaterCreator.getUpdater(mln).update(mln, gradient, 0, 0, mln.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        INDArray backpropGradients = gradient.gradient().dup();

        long nParams = mln.params().length();
        Map<String, INDArray> paramTable = mln.paramTable();
        List<String> paramNames = new ArrayList<>(paramTable.keySet());
        long[] paramEnds = computeParamEnds(paramTable, paramNames);
        long[] stepSizes = computeStepSizes(paramTable, paramNames, subset, maxPerParam);

        if (print) {
            int idx = 0;
            for (Layer layer : mln.getLayers()) {
                log.info("Layer " + idx + ": " + layer.getClass().getSimpleName() + " - params " + layer.paramTable().keySet());
                idx++;
            }
        }

        DataSet ds = new DataSet(input, labels, inputMask, labelMask);
        INDArray params = mln.params();
        GradientComparison comparison = new GradientComparison(epsilon, maxRelError, minAbsoluteError, print);
        int nameIdx = 0;
        long i = 0;
        while (i < nParams) {
            while (i >= paramEnds[nameIdx]) {
                nameIdx++;
            }
            String paramName = paramNames.get(nameIdx);
            if (excludeParams != null && excludeParams.contains(paramName)) {
                log.info("Skipping parameters for parameter name: {}", paramName);
                i = paramEnds[nameIdx]; // first entry of the next parameter
                continue;
            }

            double origValue = params.getDouble(i);
            double scorePlus;
            double scoreMinus;
            try {
                params.putScalar(i, origValue + epsilon);
                if (callEachIter != null) {
                    callEachIter.accept(mln);
                }
                scorePlus = mln.score(ds, true);
                params.putScalar(i, origValue - epsilon);
                if (callEachIter != null) {
                    callEachIter.accept(mln);
                }
                scoreMinus = mln.score(ds, true);
            } finally {
                params.putScalar(i, origValue); // always restore, even if scoring throws
            }

            if (!comparison.compare(i, nParams, paramName, backpropGradients.getDouble(i), scorePlus, scoreMinus, origValue)
                            && exitOnFirstError) {
                return false;
            }
            // Never step past the current parameter, so the next parameter's first entry is always checked.
            i = Math.min(i + stepSizes[nameIdx], paramEnds[nameIdx]);
        }

        comparison.logSummary();
        return comparison.passed();
    }

    /** Checks all gradients of {@code graph} without masks or exclusions. */
    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs,
                    INDArray[] labels) {
        return checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, labels,
                        null, null, null);
    }

    /** Checks all gradients of {@code graph} with optional (nullable) input and label masks. */
    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs,
                    INDArray[] labels, INDArray[] inputMask, INDArray[] labelMask) {
        return checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, labels,
                        inputMask, labelMask, null);
    }

    /** Checks gradients of {@code graph}, skipping the parameters named in {@code excludeParams}. */
    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs,
                    INDArray[] labels, INDArray[] inputMask, INDArray[] labelMask, Set<String> excludeParams) {
        return checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, labels,
                        inputMask, labelMask, excludeParams, (Consumer<ComputationGraph>) null);
    }

    /**
     * Checks gradients of {@code graph}, resetting the ND4J RNG seed to {@code rngSeedResetEachIter} (if non-null)
     * before every forward pass. This is required when the graph uses dropout.
     */
    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs,
                    INDArray[] labels, INDArray[] inputMask, INDArray[] labelMask, Set<String> excludeParams,
                    final Integer rngSeedResetEachIter) {
        Consumer<ComputationGraph> callEachIter = null;
        if (rngSeedResetEachIter != null) {
            callEachIter = new Consumer<ComputationGraph>() {
                @Override
                public void accept(ComputationGraph net) {
                    Nd4j.getRandom().setSeed(rngSeedResetEachIter.intValue());
                }
            };
        }
        return checkGradients(graph, epsilon, maxRelError, minAbsoluteError, print, exitOnFirstError, inputs, labels,
                        inputMask, labelMask, excludeParams, callEachIter);
    }

    /**
     * Checks the backprop gradients of a {@link ComputationGraph} against central-difference estimates.
     *
     * @param graph            graph to check; must have DOUBLE data type
     * @param epsilon          perturbation size, in (0, 0.1]
     * @param maxRelError      maximum relative error per entry, in (0, 0.25]
     * @param minAbsoluteError entries with absolute error below this pass even if the relative error is too large
     * @param print            whether to log per-entry results and a summary
     * @param exitOnFirstError whether to return false immediately at the first failing entry
     * @param inputs           one array per graph input
     * @param labels           one array per graph output
     * @param inputMask        input masks; may be null
     * @param labelMask        label masks; may be null
     * @param excludeParams    parameter names (param table keys) to skip; may be null
     * @param callEachIter     invoked before every forward pass; may be null, but must be non-null with dropout
     * @return true if every checked entry passed
     * @throws IllegalArgumentException on invalid tolerances or wrong number of input/label arrays
     * @throws IllegalStateException    on non-DOUBLE data types, unsupported updaters, dropout without
     *                                  {@code callEachIter}, or a NaN numerical gradient
     */
    public static boolean checkGradients(ComputationGraph graph, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray[] inputs,
                    INDArray[] labels, INDArray[] inputMask, INDArray[] labelMask, Set<String> excludeParams,
                    Consumer<ComputationGraph> callEachIter) {
        validateTolerances(epsilon, maxRelError);
        if (graph.getNumInputArrays() != inputs.length) {
            throw new IllegalArgumentException("Invalid input arrays: expect " + graph.getNumInputArrays() + " inputs");
        }
        if (graph.getNumOutputArrays() != labels.length) {
            throw new IllegalArgumentException("Invalid labels arrays: expect " + graph.getNumOutputArrays() + " outputs");
        }
        requireDoubleContext();
        requireDoubleNetwork(graph.getConfiguration().getDataType(), graph.params().dataType());

        for (Map.Entry<String, GraphVertex> entry : graph.getConfiguration().getVertices().entrySet()) {
            if (!(entry.getValue() instanceof LayerVertex)) {
                continue;
            }
            String vertexName = entry.getKey();
            LayerVertex layerVertex = (LayerVertex) entry.getValue();
            if (layerVertex.getLayerConf().getLayer() instanceof BaseLayer) {
                validateBaseLayerConfig((BaseLayer) layerVertex.getLayerConf().getLayer(), "\"" + vertexName + "\"");
            }
            if (layerVertex.getLayerConf().getLayer().getIDropout() != null && callEachIter == null) {
                throw new IllegalStateException("When gradient checking dropout, rng seed must be reset each iteration, or "
                                + "no dropout should be present during gradient checks - got dropout = "
                                + layerVertex.getLayerConf().getLayer().getIDropout() + " for layer \"" + vertexName + "\"");
            }
        }

        for (Layer layer : graph.getLayers()) {
            if (layer instanceof IOutputLayer) {
                configureLossFnClippingIfPresent((IOutputLayer) layer);
            }
        }

        for (int k = 0; k < inputs.length; k++) {
            graph.setInput(k, inputs[k]);
        }
        for (int k = 0; k < labels.length; k++) {
            graph.setLabel(k, labels[k]);
        }
        graph.setLayerMaskArrays(inputMask, labelMask);
        if (callEachIter != null) {
            callEachIter.accept(graph);
        }
        graph.computeGradientAndScore();
        Gradient gradient = graph.gradientAndScore().getFirst();
        // Apply the (validated: NoOp or Sgd lr=1.0) updater to the raw gradient before comparing.
        new ComputationGraphUpdater(graph).update(gradient, 0, 0, graph.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        INDArray backpropGradients = gradient.gradient().dup();

        long nParams = graph.params().length();
        Map<String, INDArray> paramTable = graph.paramTable();
        List<String> paramNames = new ArrayList<>(paramTable.keySet());
        long[] paramEnds = computeParamEnds(paramTable, paramNames);

        org.nd4j.linalg.dataset.api.MultiDataSet mds = new MultiDataSet(inputs, labels, inputMask, labelMask);
        INDArray params = graph.params();
        GradientComparison comparison = new GradientComparison(epsilon, maxRelError, minAbsoluteError, print);
        int nameIdx = 0;
        long i = 0;
        while (i < nParams) {
            while (i >= paramEnds[nameIdx]) {
                nameIdx++;
            }
            String paramName = paramNames.get(nameIdx);
            if (excludeParams != null && excludeParams.contains(paramName)) {
                log.info("Skipping parameters for parameter name: {}", paramName);
                i = paramEnds[nameIdx]; // first entry of the next parameter (previously skipped by one)
                continue;
            }

            double origValue = params.getDouble(i);
            double scorePlus;
            double scoreMinus;
            try {
                params.putScalar(i, origValue + epsilon);
                if (callEachIter != null) {
                    callEachIter.accept(graph);
                }
                scorePlus = graph.score(mds, true);
                params.putScalar(i, origValue - epsilon);
                if (callEachIter != null) {
                    callEachIter.accept(graph);
                }
                scoreMinus = graph.score(mds, true);
            } finally {
                params.putScalar(i, origValue); // always restore, even if scoring throws
            }

            if (!comparison.compare(i, nParams, paramName, backpropGradients.getDouble(i), scorePlus, scoreMinus, origValue)
                            && exitOnFirstError) {
                return false;
            }
            i++;
        }

        comparison.logSummary();
        return comparison.passed();
    }

    /**
     * Checks the gradients of a single layer in isolation (as used for pretrain layers). The ND4J RNG is reset to
     * {@code rngSeed} before every gradient/score computation so stochastic layers are reproducible.
     *
     * @param layer            layer to check
     * @param epsilon          perturbation size, in (0, 0.1]
     * @param maxRelError      maximum relative error per entry, in (0, 0.25]
     * @param minAbsoluteError entries with absolute error below this pass even if the relative error is too large
     * @param print            whether to log per-entry results and a summary
     * @param exitOnFirstError whether to return false immediately at the first failing entry
     * @param input            layer input
     * @param rngSeed          RNG seed set before each forward pass
     * @return true if every entry passed
     * @throws IllegalArgumentException on invalid tolerances
     * @throws IllegalStateException    if the context data type is not DOUBLE or a numerical gradient is NaN
     */
    public static boolean checkGradientsPretrainLayer(Layer layer, double epsilon, double maxRelError,
                    double minAbsoluteError, boolean print, boolean exitOnFirstError, INDArray input, int rngSeed) {
        LayerWorkspaceMgr workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
        validateTolerances(epsilon, maxRelError);
        requireDoubleContext();

        layer.setInput(input, workspaceMgr);
        Nd4j.getRandom().setSeed(rngSeed);
        layer.computeGradientAndScore(workspaceMgr);
        Gradient gradient = layer.gradientAndScore().getFirst();
        UpdaterCreator.getUpdater(layer).update(layer, gradient, 0, 0, layer.batchSize(), LayerWorkspaceMgr.noWorkspaces());
        INDArray backpropGradients = gradient.gradient().dup();

        long nParams = layer.params().length();
        Map<String, INDArray> paramTable = layer.paramTable();
        List<String> paramNames = new ArrayList<>(paramTable.keySet());
        long[] paramEnds = computeParamEnds(paramTable, paramNames);

        INDArray params = layer.params();
        GradientComparison comparison = new GradientComparison(epsilon, maxRelError, minAbsoluteError, print);
        int nameIdx = 0;
        for (long i = 0; i < nParams; i++) {
            while (i >= paramEnds[nameIdx]) {
                nameIdx++;
            }
            String paramName = paramNames.get(nameIdx);

            double origValue = params.getDouble(i);
            double scorePlus;
            double scoreMinus;
            try {
                params.putScalar(i, origValue + epsilon);
                Nd4j.getRandom().setSeed(rngSeed);
                layer.computeGradientAndScore(workspaceMgr);
                scorePlus = layer.score();
                params.putScalar(i, origValue - epsilon);
                Nd4j.getRandom().setSeed(rngSeed);
                layer.computeGradientAndScore(workspaceMgr);
                scoreMinus = layer.score();
            } finally {
                params.putScalar(i, origValue); // always restore, even if scoring throws
            }

            if (!comparison.compare(i, nParams, paramName, backpropGradients.getDouble(i), scorePlus, scoreMinus, origValue)
                            && exitOnFirstError) {
                return false;
            }
        }

        comparison.logSummary();
        return comparison.passed();
    }
}
