package org.deeplearning4j.nn.modelimport.keras.utils;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasLeakyReLU;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasPReLU;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.KerasThresholdedReLU;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousConvolution1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasAtrousConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolution3D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasCropping3D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasDeconvolution2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasDepthwiseConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSeparableConvolution2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasUpsampling2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding1D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding2D;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasZeroPadding3D;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasActivation;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDense;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasFlatten;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasLambda;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasMasking;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasMerge;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasPermute;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasRepeatVector;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasReshape;
import org.deeplearning4j.nn.modelimport.keras.layers.core.KerasSpatialDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
import org.deeplearning4j.nn.modelimport.keras.layers.normalization.KerasBatchNormalization;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasGlobalPooling;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasPooling1D;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasPooling2D;
import org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasPooling3D;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasLSTM;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasSimpleRnn;
import org.deeplearning4j.nn.modelimport.keras.layers.wrappers.KerasBidirectional;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.class */
public class KerasLayerUtils {
    private static final Logger log = LoggerFactory.getLogger(KerasLayerUtils.class);

    /**
     * Checks a layer's regularizer settings for configurations that cannot be imported.
     *
     * <p>Throws if the bias regularizer contains an L1 or L2 term. Then inspects the weight and bias
     * regularizer maps for any field other than L1, L2, name, class name or config.
     *
     * @param map                     parsed Keras layer config (must contain the inner config field)
     * @param z                       enforce-training-config flag: if {@code true}, unknown regularizer
     *                                fields raise an exception; otherwise they are logged and ignored
     * @param kerasLayerConfiguration Keras field-name configuration
     * @throws UnsupportedKerasConfigurationException if a bias L1/L2 term is present, or an unknown
     *                                                regularizer field is found while enforcing
     * @throws InvalidKerasConfigurationException     if the inner config field is missing
     */
    public static void checkForUnsupportedConfigurations(Map<String, Object> map, boolean z,
                                                         KerasLayerConfiguration kerasLayerConfiguration)
            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        // Called only for their side effect: each throws if the bias regularizer has that term.
        getBiasL1RegularizationFromConfig(map, z, kerasLayerConfiguration);
        getBiasL2RegularizationFromConfig(map, z, kerasLayerConfiguration);
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        // An absent regularizer field yields null, which checkForUnknownRegularizer ignores.
        checkForUnknownRegularizer(asMap(innerConfig.get(kerasLayerConfiguration.getLAYER_FIELD_W_REGULARIZER())),
                z, kerasLayerConfiguration);
        checkForUnknownRegularizer(asMap(innerConfig.get(kerasLayerConfiguration.getLAYER_FIELD_B_REGULARIZER())),
                z, kerasLayerConfiguration);
    }

    /**
     * Rejects L1 regularization on the bias parameter.
     *
     * @param map                     parsed Keras layer config
     * @param z                       unused; kept for signature compatibility
     * @param kerasLayerConfiguration Keras field-name configuration
     * @return always {@code 0.0} when no bias L1 term is configured
     * @throws UnsupportedKerasConfigurationException if the bias regularizer has an L1 term
     * @throws InvalidKerasConfigurationException     if the inner config field is missing
     */
    public static double getBiasL1RegularizationFromConfig(Map<String, Object> map, boolean z,
                                                           KerasLayerConfiguration kerasLayerConfiguration)
            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        if (biasRegularizerContains(map, kerasLayerConfiguration.getREGULARIZATION_TYPE_L1(), kerasLayerConfiguration)) {
            throw new UnsupportedKerasConfigurationException("L1 regularization for bias parameter not supported");
        }
        return 0.0d;
    }

    /**
     * Rejects L2 regularization on the bias parameter.
     *
     * @return always {@code 0.0} when no bias L2 term is configured
     * @throws UnsupportedKerasConfigurationException if the bias regularizer has an L2 term
     * @throws InvalidKerasConfigurationException     if the inner config field is missing
     */
    private static double getBiasL2RegularizationFromConfig(Map<String, Object> map, boolean z,
                                                            KerasLayerConfiguration kerasLayerConfiguration)
            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        if (biasRegularizerContains(map, kerasLayerConfiguration.getREGULARIZATION_TYPE_L2(), kerasLayerConfiguration)) {
            throw new UnsupportedKerasConfigurationException("L2 regularization for bias parameter not supported");
        }
        return 0.0d;
    }

    /** Returns true if the layer's bias regularizer map exists and contains {@code regularizationType}. */
    private static boolean biasRegularizerContains(Map<String, Object> layerConfig, String regularizationType,
                                                   KerasLayerConfiguration conf)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(layerConfig, conf);
        Map<String, Object> biasRegularizer = asMap(innerConfig.get(conf.getLAYER_FIELD_B_REGULARIZER()));
        return biasRegularizer != null && biasRegularizer.containsKey(regularizationType);
    }

    /**
     * Verifies that a regularizer map only contains recognised fields: L1, L2, name, class name and config.
     *
     * @param map                     regularizer map; {@code null} is accepted and ignored
     * @param z                       if {@code true}, an unknown field throws; otherwise it is logged
     * @param kerasLayerConfiguration Keras field-name configuration
     * @throws UnsupportedKerasConfigurationException on an unknown field while enforcing
     */
    private static void checkForUnknownRegularizer(Map<String, Object> map, boolean z,
                                                   KerasLayerConfiguration kerasLayerConfiguration)
            throws UnsupportedKerasConfigurationException {
        if (map == null) {
            return;
        }
        for (String field : map.keySet()) {
            boolean known = field.equals(kerasLayerConfiguration.getREGULARIZATION_TYPE_L1())
                    || field.equals(kerasLayerConfiguration.getREGULARIZATION_TYPE_L2())
                    || field.equals(kerasLayerConfiguration.getLAYER_FIELD_NAME())
                    || field.equals(kerasLayerConfiguration.getLAYER_FIELD_CLASS_NAME())
                    || field.equals(kerasLayerConfiguration.getLAYER_FIELD_CONFIG());
            if (known) {
                continue;
            }
            if (z) {
                throw new UnsupportedKerasConfigurationException("Unknown regularization field " + field);
            }
            log.warn("Ignoring unknown regularization field {}", field);
        }
    }

    /**
     * Builds the DL4J wrapper for a Keras layer config without enforcing training-related settings.
     * Equivalent to the six-argument overload with {@code z == false}.
     */
    public static KerasLayer getKerasLayerFromConfig(Map<String, Object> map,
                                                     KerasLayerConfiguration kerasLayerConfiguration,
                                                     Map<String, Class<? extends KerasLayer>> map2,
                                                     Map<String, SameDiffLambdaLayer> map3,
                                                     Map<String, ? extends KerasLayer> map4)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return getKerasLayerFromConfig(map, false, kerasLayerConfiguration, map2, map3, map4);
    }

    /**
     * Builds the DL4J wrapper ({@link KerasLayer} subclass) matching the Keras class name of a layer config.
     *
     * <p>A {@code TimeDistributed} wrapper is unwrapped first via {@link #getTimeDistributedLayerConfig},
     * which rewrites {@code map} in place. Class names not handled by the built-in table are handled as follows:
     * <ul>
     *   <li>{@code Lambda} layers are resolved through {@code map3} by layer name, when {@code map3} is non-empty.</li>
     *   <li>Any other name is looked up in {@code map2}. The custom class is created reflectively through its
     *       {@code (Map)} constructor.</li>
     * </ul>
     *
     * @param map                     parsed Keras layer config (may be mutated, see above)
     * @param z                       enforce-training-config flag, passed to the built-in layer constructors
     * @param kerasLayerConfiguration Keras field/class-name configuration
     * @param map2                    custom layers: Keras class name to DL4J wrapper class
     * @param map3                    SameDiff lambda layers keyed by Keras layer name
     * @param map4                    previously imported layers, passed to layers that need them
     *                                (bidirectional, LSTM, SimpleRNN, depthwise convolution)
     * @return the constructed layer wrapper
     * @throws InvalidKerasConfigurationException     if required config fields are missing
     * @throws UnsupportedKerasConfigurationException if the class name is unknown, or a Lambda layer has no
     *                                                registered SameDiff implementation
     */
    public static KerasLayer getKerasLayerFromConfig(Map<String, Object> map, boolean z,
                                                     KerasLayerConfiguration kerasLayerConfiguration,
                                                     Map<String, Class<? extends KerasLayer>> map2,
                                                     Map<String, SameDiffLambdaLayer> map3,
                                                     Map<String, ? extends KerasLayer> map4)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasLayerConfiguration conf = kerasLayerConfiguration;
        String className = getClassNameFromConfig(map, conf);
        if (className.equals(conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED())) {
            map = getTimeDistributedLayerConfig(map, conf);
            className = getClassNameFromConfig(map, conf);
        }

        // Activations, masking and regularising layers.
        if (className.equals(conf.getLAYER_CLASS_NAME_ACTIVATION())) {
            return new KerasActivation(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_LEAKY_RELU())) {
            return new KerasLeakyReLU(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_MASKING())) {
            return new KerasMasking(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_THRESHOLDED_RELU())) {
            return new KerasThresholdedReLU(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_PRELU())) {
            return new KerasPReLU(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_DROPOUT())) {
            return new KerasDropout(map, z);
        }
        if (isOneOf(className, conf.getLAYER_CLASS_NAME_SPATIAL_DROPOUT_1D(),
                conf.getLAYER_CLASS_NAME_SPATIAL_DROPOUT_2D(), conf.getLAYER_CLASS_NAME_SPATIAL_DROPOUT_3D())) {
            return new KerasSpatialDropout(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_ALPHA_DROPOUT())) {
            return new KerasAlphaDropout(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_GAUSSIAN_DROPOUT())) {
            return new KerasGaussianDropout(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_GAUSSIAN_NOISE())) {
            return new KerasGaussianNoise(map, z);
        }

        // Dense and recurrent layers.
        if (isOneOf(className, conf.getLAYER_CLASS_NAME_DENSE(), conf.getLAYER_CLASS_NAME_TIME_DISTRIBUTED_DENSE())) {
            return new KerasDense(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_BIDIRECTIONAL())) {
            return new KerasBidirectional(map, z, map4);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_LSTM())) {
            return new KerasLSTM(map, z, map4);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_SIMPLE_RNN())) {
            return new KerasSimpleRnn(map, z, map4);
        }

        // Convolutional layers.
        if (className.equals(conf.getLAYER_CLASS_NAME_CONVOLUTION_3D())) {
            return new KerasConvolution3D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_CONVOLUTION_2D())) {
            return new KerasConvolution2D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_DECONVOLUTION_2D())) {
            return new KerasDeconvolution2D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_CONVOLUTION_1D())) {
            return new KerasConvolution1D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_ATROUS_CONVOLUTION_2D())) {
            return new KerasAtrousConvolution2D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_ATROUS_CONVOLUTION_1D())) {
            return new KerasAtrousConvolution1D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_DEPTHWISE_CONVOLUTION_2D())) {
            return new KerasDepthwiseConvolution2D(map, map4, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_SEPARABLE_CONVOLUTION_2D())) {
            return new KerasSeparableConvolution2D(map, z);
        }

        // Pooling layers.
        if (isOneOf(className, conf.getLAYER_CLASS_NAME_MAX_POOLING_3D(), conf.getLAYER_CLASS_NAME_AVERAGE_POOLING_3D())) {
            return new KerasPooling3D(map, z);
        }
        if (isOneOf(className, conf.getLAYER_CLASS_NAME_MAX_POOLING_2D(), conf.getLAYER_CLASS_NAME_AVERAGE_POOLING_2D())) {
            return new KerasPooling2D(map, z);
        }
        if (isOneOf(className, conf.getLAYER_CLASS_NAME_MAX_POOLING_1D(), conf.getLAYER_CLASS_NAME_AVERAGE_POOLING_1D())) {
            return new KerasPooling1D(map, z);
        }
        if (isOneOf(className,
                conf.getLAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_1D(), conf.getLAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_2D(),
                conf.getLAYER_CLASS_NAME_GLOBAL_AVERAGE_POOLING_3D(), conf.getLAYER_CLASS_NAME_GLOBAL_MAX_POOLING_1D(),
                conf.getLAYER_CLASS_NAME_GLOBAL_MAX_POOLING_2D(), conf.getLAYER_CLASS_NAME_GLOBAL_MAX_POOLING_3D())) {
            return new KerasGlobalPooling(map, z);
        }

        // Normalization, embedding, input and shape-only layers.
        if (className.equals(conf.getLAYER_CLASS_NAME_BATCHNORMALIZATION())) {
            return new KerasBatchNormalization(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_EMBEDDING())) {
            return new KerasEmbedding(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_INPUT())) {
            return new KerasInput(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_REPEAT())) {
            return new KerasRepeatVector(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_PERMUTE())) {
            return new KerasPermute(map, z);
        }

        // Merge layers: each op is matched by both its layer class name and its functional alias.
        if (className.equals(conf.getLAYER_CLASS_NAME_MERGE())) {
            return new KerasMerge(map, z);
        }
        if (isOneOf(className, conf.getLAYER_CLASS_NAME_ADD(), conf.getLAYER_CLASS_NAME_FUNCTIONAL_ADD())) {
            return new KerasMerge(map, ElementWiseVertex.Op.Add, z);
        }
        if (isOneOf(className, conf.getLAYER_CLASS_NAME_SUBTRACT(), conf.getLAYER_CLASS_NAME_FUNCTIONAL_SUBTRACT())) {
            return new KerasMerge(map, ElementWiseVertex.Op.Subtract, z);
        }
        if (isOneOf(className, conf.getLAYER_CLASS_NAME_AVERAGE(), conf.getLAYER_CLASS_NAME_FUNCTIONAL_AVERAGE())) {
            return new KerasMerge(map, ElementWiseVertex.Op.Average, z);
        }
        if (isOneOf(className, conf.getLAYER_CLASS_NAME_MULTIPLY(), conf.getLAYER_CLASS_NAME_FUNCTIONAL_MULTIPLY())) {
            return new KerasMerge(map, ElementWiseVertex.Op.Product, z);
        }
        if (isOneOf(className, conf.getLAYER_CLASS_NAME_CONCATENATE(), conf.getLAYER_CLASS_NAME_FUNCTIONAL_CONCATENATE())) {
            // A null op selects concatenation in KerasMerge.
            return new KerasMerge(map, null, z);
        }

        // Reshaping, padding, upsampling and cropping layers.
        if (className.equals(conf.getLAYER_CLASS_NAME_FLATTEN())) {
            return new KerasFlatten(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_RESHAPE())) {
            return new KerasReshape(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_ZERO_PADDING_1D())) {
            return new KerasZeroPadding1D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_ZERO_PADDING_2D())) {
            return new KerasZeroPadding2D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_ZERO_PADDING_3D())) {
            return new KerasZeroPadding3D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_UPSAMPLING_1D())) {
            return new KerasUpsampling1D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_UPSAMPLING_2D())) {
            return new KerasUpsampling2D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_CROPPING_3D())) {
            return new KerasCropping3D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_CROPPING_2D())) {
            return new KerasCropping2D(map, z);
        }
        if (className.equals(conf.getLAYER_CLASS_NAME_CROPPING_1D())) {
            return new KerasCropping1D(map, z);
        }

        // Lambda layers need a user-supplied SameDiff implementation.
        // With none registered, the lookup falls through to the custom-layer table below.
        if (className.equals(conf.getLAYER_CLASS_NAME_LAMBDA()) && !map3.isEmpty()) {
            String layerName = getLayerNameFromConfig(map, conf);
            if (!map3.containsKey(layerName)) {
                throw new UnsupportedKerasConfigurationException("No SameDiff Lambda layer found for Lambdalayer " + layerName);
            }
            return new KerasLambda(map, z, map3.get(layerName));
        }
        return instantiateCustomLayer(className, map, map2);
    }

    /** Returns true if {@code className} equals any of the candidates (null candidates never match). */
    private static boolean isOneOf(String className, String... candidates) {
        for (String candidate : candidates) {
            if (className.equals(candidate)) {
                return true;
            }
        }
        return false;
    }

    /** Instantiates a user-registered layer class through its single-argument {@code (Map)} constructor. */
    private static KerasLayer instantiateCustomLayer(String className, Map<String, Object> layerConfig,
                                                     Map<String, Class<? extends KerasLayer>> customLayers)
            throws UnsupportedKerasConfigurationException {
        Class<? extends KerasLayer> layerClass = customLayers.get(className);
        if (layerClass == null) {
            throw new UnsupportedKerasConfigurationException("Unsupported keras layer type " + className);
        }
        try {
            return layerClass.getConstructor(Map.class).newInstance(layerConfig);
        } catch (Exception e) {
            // Reflection failures (missing constructor, constructor threw, ...) are wrapped unchecked, cause kept.
            throw new RuntimeException("Failed to instantiate custom layer " + layerClass.getName()
                    + " for Keras class " + className, e);
        }
    }

    /**
     * Reads the Keras class name of a layer config.
     *
     * @throws InvalidKerasConfigurationException if the class-name field is missing
     */
    public static String getClassNameFromConfig(Map<String, Object> map, KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        String classNameField = kerasLayerConfiguration.getLAYER_FIELD_CLASS_NAME();
        if (!map.containsKey(classNameField)) {
            throw new InvalidKerasConfigurationException("Field " + classNameField + " missing from layer config");
        }
        return (String) map.get(classNameField);
    }

    /**
     * Unwraps a {@code TimeDistributed} layer config so it looks like the config of the wrapped layer.
     *
     * <p>The given map is modified in place and returned:
     * <ul>
     *   <li>its class-name field becomes the wrapped layer's class name;</li>
     *   <li>the wrapped layer's inner config entries are merged into the outer inner config;</li>
     *   <li>the nested layer field is removed.</li>
     * </ul>
     *
     * @throws InvalidKerasConfigurationException if the class-name or config field is missing, or the layer
     *                                            is not a {@code TimeDistributed} layer
     */
    public static Map<String, Object> getTimeDistributedLayerConfig(Map<String, Object> map,
                                                                    KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        String classNameField = kerasLayerConfiguration.getLAYER_FIELD_CLASS_NAME();
        String configField = kerasLayerConfiguration.getLAYER_FIELD_CONFIG();
        String timeDistributed = kerasLayerConfiguration.getLAYER_CLASS_NAME_TIME_DISTRIBUTED();
        if (!map.containsKey(classNameField)) {
            throw new InvalidKerasConfigurationException("Field " + classNameField + " missing from layer config");
        }
        if (!map.get(classNameField).equals(timeDistributed)) {
            throw new InvalidKerasConfigurationException("Expected " + timeDistributed + " layer, found "
                    + map.get(classNameField));
        }
        if (!map.containsKey(configField)) {
            throw new InvalidKerasConfigurationException("Field " + configField + " missing from layer config");
        }
        Map<String, Object> outerInnerConfig = getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        Map<String, Object> wrappedLayer = asMap(outerInnerConfig.get(kerasLayerConfiguration.getLAYER_FIELD_LAYER()));
        Object wrappedClassName = wrappedLayer.get(classNameField);
        map.put(classNameField, wrappedClassName);
        // NOTE(review): the top-level name field is set to the wrapped layer's *class* name, not its layer name.
        // This is kept for compatibility; the layer name used elsewhere comes from the inner config merged below.
        map.put(kerasLayerConfiguration.getLAYER_FIELD_NAME(), wrappedClassName);
        outerInnerConfig.putAll(getInnerLayerConfigFromConfig(wrappedLayer, kerasLayerConfiguration));
        outerInnerConfig.remove(kerasLayerConfiguration.getLAYER_FIELD_LAYER());
        return map;
    }

    /**
     * Returns the nested config map of a layer config.
     *
     * @throws InvalidKerasConfigurationException if the config field is missing
     */
    public static Map<String, Object> getInnerLayerConfigFromConfig(Map<String, Object> map,
                                                                    KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        String configField = kerasLayerConfiguration.getLAYER_FIELD_CONFIG();
        if (!map.containsKey(configField)) {
            throw new InvalidKerasConfigurationException("Field " + configField + " missing from layer config");
        }
        return asMap(map.get(configField));
    }

    /**
     * Reads the layer name from the inner config of a layer config.
     *
     * @throws InvalidKerasConfigurationException if the config or name field is missing
     */
    public static String getLayerNameFromConfig(Map<String, Object> map, KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        String nameField = kerasLayerConfiguration.getLAYER_FIELD_NAME();
        if (!innerConfig.containsKey(nameField)) {
            throw new InvalidKerasConfigurationException("Field " + nameField + " missing from layer config");
        }
        return (String) innerConfig.get(nameField);
    }

    /**
     * Reads the batch input shape of a layer, dropping the leading (batch) entry.
     * Unknown ({@code null}) dimensions are returned as 0.
     *
     * @return the per-example input shape, or {@code null} if the layer config has no batch input shape field
     * @throws InvalidKerasConfigurationException if the config field is missing, or the batch input shape is
     *                                            null or empty
     */
    public static int[] getInputShapeFromConfig(Map<String, Object> map, KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        String shapeField = kerasLayerConfiguration.getLAYER_FIELD_BATCH_INPUT_SHAPE();
        if (!innerConfig.containsKey(shapeField)) {
            return null;
        }
        List<?> batchShape = (List<?>) innerConfig.get(shapeField);
        if (batchShape == null || batchShape.isEmpty()) {
            throw new InvalidKerasConfigurationException("Field " + shapeField
                    + " must be a non-empty list starting with the batch dimension");
        }
        int[] shape = new int[batchShape.size() - 1];
        for (int i = 1; i < batchShape.size(); i++) {
            Object dim = batchShape.get(i);
            // Number rather than Integer: JSON parsers may hand back Long for integral values.
            shape[i - 1] = dim == null ? 0 : ((Number) dim).intValue();
        }
        return shape;
    }

    /**
     * Determines the dimension ordering of a layer.
     *
     * <p>The backend field of the outer config gives a first guess:
     * <ul>
     *   <li>{@code tensorflow} or {@code cntk} map to TensorFlow ordering;</li>
     *   <li>{@code theano} maps to Theano ordering.</li>
     * </ul>
     * An explicit dim-ordering field in the inner config overrides that guess. Unknown values are logged.
     *
     * @return the resolved ordering, {@link KerasLayer.DimOrder#NONE} if nothing is specified
     * @throws InvalidKerasConfigurationException if the config field is missing
     */
    public static KerasLayer.DimOrder getDimOrderFromConfig(Map<String, Object> map,
                                                            KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        KerasLayer.DimOrder dimOrder = KerasLayer.DimOrder.NONE;
        String backendField = kerasLayerConfiguration.getLAYER_FIELD_BACKEND();
        if (map.containsKey(backendField)) {
            String backend = (String) map.get(backendField);
            if (backend.equals("tensorflow") || backend.equals("cntk")) {
                dimOrder = KerasLayer.DimOrder.TENSORFLOW;
            } else if (backend.equals("theano")) {
                dimOrder = KerasLayer.DimOrder.THEANO;
            }
        }
        String dimOrderingField = kerasLayerConfiguration.getLAYER_FIELD_DIM_ORDERING();
        if (innerConfig.containsKey(dimOrderingField)) {
            String dimOrdering = (String) innerConfig.get(dimOrderingField);
            if (dimOrdering.equals(kerasLayerConfiguration.getDIM_ORDERING_TENSORFLOW())) {
                dimOrder = KerasLayer.DimOrder.TENSORFLOW;
            } else if (dimOrdering.equals(kerasLayerConfiguration.getDIM_ORDERING_THEANO())) {
                dimOrder = KerasLayer.DimOrder.THEANO;
            } else {
                // Report the unrecognised value itself, not the ordering resolved so far.
                log.warn("Keras layer has unknown Keras dimension order: {}", dimOrdering);
            }
        }
        return dimOrder;
    }

    /**
     * Lists the names of the layers feeding into this layer.
     *
     * <p>Only the first inbound node is read. Each of its entries is expected to be a list whose first
     * element is the inbound layer's name.
     *
     * @return the inbound layer names (empty if the inbound-nodes field is absent or empty)
     */
    public static List<String> getInboundLayerNamesFromConfig(Map<String, Object> map,
                                                              KerasLayerConfiguration kerasLayerConfiguration) {
        List<String> inboundNames = new ArrayList<>();
        String inboundField = kerasLayerConfiguration.getLAYER_FIELD_INBOUND_NODES();
        if (map.containsKey(inboundField)) {
            List<?> inboundNodes = (List<?>) map.get(inboundField);
            if (!inboundNodes.isEmpty()) {
                for (Object inbound : (List<?>) inboundNodes.get(0)) {
                    inboundNames.add((String) ((List<?>) inbound).get(0));
                }
            }
        }
        return inboundNames;
    }

    /**
     * Reads the number of outputs of a layer. The fields are tried in this order: output dim,
     * embedding output dim, number of filters.
     *
     * @throws InvalidKerasConfigurationException if the config field or all candidate fields are missing
     */
    public static int getNOutFromConfig(Map<String, Object> map, KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        String[] candidateFields = {
                kerasLayerConfiguration.getLAYER_FIELD_OUTPUT_DIM(),
                kerasLayerConfiguration.getLAYER_FIELD_EMBEDDING_OUTPUT_DIM(),
                kerasLayerConfiguration.getLAYER_FIELD_NB_FILTER()
        };
        for (String field : candidateFields) {
            if (innerConfig.containsKey(field)) {
                return ((Number) innerConfig.get(field)).intValue();
            }
        }
        throw new InvalidKerasConfigurationException("Could not determine number of outputs for layer: no "
                + kerasLayerConfiguration.getLAYER_FIELD_OUTPUT_DIM() + " or "
                + kerasLayerConfiguration.getLAYER_FIELD_NB_FILTER() + " field found");
    }

    /**
     * Reads the layer's dropout setting and returns {@code 1 - value}.
     *
     * <p>Keras stores the fraction of units to drop, so the result is the fraction to keep. The dropout field
     * is preferred over the weight-dropout field. Returns {@code 1.0} when neither is present. Any numeric JSON
     * type is accepted.
     *
     * @throws InvalidKerasConfigurationException if the config field is missing
     */
    public static double getDropoutFromConfig(Map<String, Object> map, KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        String dropoutField = kerasLayerConfiguration.getLAYER_FIELD_DROPOUT();
        if (innerConfig.containsKey(dropoutField)) {
            return 1.0d - ((Number) innerConfig.get(dropoutField)).doubleValue();
        }
        String dropoutWField = kerasLayerConfiguration.getLAYER_FIELD_DROPOUT_W();
        if (innerConfig.containsKey(dropoutWField)) {
            return 1.0d - ((Number) innerConfig.get(dropoutWField)).doubleValue();
        }
        return 1.0d;
    }

    /**
     * Reads the layer's use-bias flag; {@code true} when the field is absent.
     *
     * @throws InvalidKerasConfigurationException if the config field is missing
     */
    public static boolean getHasBiasFromConfig(Map<String, Object> map, KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        String useBiasField = kerasLayerConfiguration.getLAYER_FIELD_USE_BIAS();
        return !innerConfig.containsKey(useBiasField) || (Boolean) innerConfig.get(useBiasField);
    }

    /**
     * Reads the layer's mask-zero flag; {@code true} when the field is absent.
     *
     * <p>NOTE(review): Keras documents {@code mask_zero=False} as the Embedding default. Confirm that
     * defaulting to {@code true} here is intended.
     *
     * @throws InvalidKerasConfigurationException if the config field is missing
     */
    public static boolean getZeroMaskingFromConfig(Map<String, Object> map, KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        String maskZeroField = kerasLayerConfiguration.getLAYER_FIELD_MASK_ZERO();
        return !innerConfig.containsKey(maskZeroField) || (Boolean) innerConfig.get(maskZeroField);
    }

    /**
     * Reads the masking value of a Masking layer.
     *
     * <p>Any numeric value is accepted. A non-numeric or {@code null} value is logged and replaced by 0.0.
     *
     * @throws InvalidKerasConfigurationException if the config field or the mask-value field is missing
     */
    public static double getMaskingValueFromConfig(Map<String, Object> map, KerasLayerConfiguration kerasLayerConfiguration)
            throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = getInnerLayerConfigFromConfig(map, kerasLayerConfiguration);
        String maskValueField = kerasLayerConfiguration.getLAYER_FIELD_MASK_VALUE();
        if (!innerConfig.containsKey(maskValueField)) {
            throw new InvalidKerasConfigurationException("No mask value found, field " + maskValueField);
        }
        Object maskValue = innerConfig.get(maskValueField);
        if (maskValue instanceof Number) {
            return ((Number) maskValue).doubleValue();
        }
        log.warn("Couldn't read masking value, default to 0.0");
        return 0.0d;
    }

    /**
     * Warns about weights other than the default kernel (W) and bias (b) parameters.
     *
     * <p>Only when the map holds more than two entries: the W and b keys are removed <em>from the given map</em>
     * through its key-set view, and the remaining keys are logged. With two or fewer entries the map is left
     * untouched.
     *
     * @param map                     weights keyed by Keras parameter name (mutated as described above)
     * @param kerasLayerConfiguration Keras parameter-name configuration
     */
    public static void removeDefaultWeights(Map<String, INDArray> map, KerasLayerConfiguration kerasLayerConfiguration) {
        if (map.size() <= 2) {
            return;
        }
        // keySet() is a live view: these removals also delete the W/b entries from `map`.
        Set<String> remainingParams = map.keySet();
        remainingParams.remove(kerasLayerConfiguration.getKERAS_PARAM_NAME_W());
        remainingParams.remove(kerasLayerConfiguration.getKERAS_PARAM_NAME_B());
        String bracketed = remainingParams.toString();
        log.warn("Attemping to set weights for unknown parameters: {}", bracketed.substring(1, bracketed.length() - 1));
    }

    /**
     * Derives masking settings from a layer's inbound layers.
     *
     * <p>Masking is enabled if any inbound layer is either:
     * <ul>
     *   <li>a zero-masking {@link KerasEmbedding}, or</li>
     *   <li>a {@link KerasMasking} layer, whose masking value is then used.</li>
     * </ul>
     * Later inbound layers override the value from earlier ones.
     *
     * @param list names of the inbound layers
     * @param map  already imported layers by name; names not present are skipped
     * @return (masking enabled, masking value), defaulting to (false, 0.0)
     */
    public static Pair<Boolean, Double> getMaskingConfiguration(List<String> list, Map<String, ? extends KerasLayer> map) {
        boolean maskingEnabled = false;
        double maskingValue = 0.0d;
        for (String inboundName : list) {
            KerasLayer inbound = map.get(inboundName);
            if (inbound instanceof KerasEmbedding && ((KerasEmbedding) inbound).isZeroMasking()) {
                maskingEnabled = true;
            } else if (inbound instanceof KerasMasking) {
                maskingEnabled = true;
                maskingValue = ((KerasMasking) inbound).getMaskingValue();
            }
        }
        return new Pair<>(maskingEnabled, maskingValue);
    }

    /** Unchecked cast of a parsed JSON value to a string-keyed map; {@code null} stays {@code null}. */
    @SuppressWarnings("unchecked")
    private static Map<String, Object> asMap(Object value) {
        return (Map<String, Object>) value;
    }
}
