package org.deeplearning4j.nn.modelimport.keras.utils;

import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.OrthogonalDistribution;
import org.deeplearning4j.nn.conf.distribution.TruncatedNormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/utils/KerasInitilizationUtils.class */
/**
 * Utilities for translating Keras weight-initializer configurations into a DL4J {@link WeightInit}
 * and, for {@link WeightInit#DISTRIBUTION}, the accompanying {@link Distribution}.
 */
public class KerasInitilizationUtils {
    private static final Logger log = LoggerFactory.getLogger(KerasInitilizationUtils.class);

    /** Keras major version whose initializer configs are dicts of named parameters (class_name + config). */
    private static final int KERAS_2 = 2;

    /**
     * Map a Keras weight initializer (name plus parameters) to a DL4J weight initialization.
     *
     * @param kerasInit         initializer name (Keras 1) or class name (Keras 2); {@code null} yields a pair
     *                          whose elements are both {@code null}
     * @param conf              Keras layer configuration supplying the recognised initializer names and
     *                          parameter field keys
     * @param initConfig        initializer parameters (e.g. minval/maxval, mean/stddev, scale, gain); may be
     *                          {@code null}, which is treated as "no parameters"
     * @param kerasMajorVersion Keras major version; {@code 2} selects the Keras 2 parameter names, any other
     *                          value the Keras 1 single-{@code scale} parameterisation
     * @return the DL4J weight init, plus a distribution when the init is {@link WeightInit#DISTRIBUTION}
     *         (otherwise {@code null})
     * @throws UnsupportedKerasConfigurationException if the initializer name is not recognised
     * @throws InvalidKerasConfigurationException     if a required parameter is missing or non-numeric, or the
     *                                                variance-scaling mode is not fan_in/fan_out/fan_avg
     */
    public static Pair<WeightInit, Distribution> mapWeightInitialization(String kerasInit,
                    KerasLayerConfiguration conf, Map<String, Object> initConfig, int kerasMajorVersion)
                    throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        if (kerasInit == null) {
            return new Pair<WeightInit, Distribution>(null, null);
        }
        boolean keras2 = kerasMajorVersion == KERAS_2;
        WeightInit init;
        Distribution dist = null;

        if (kerasInit.equals(conf.getINIT_GLOROT_NORMAL())) {
            init = WeightInit.XAVIER;
        } else if (kerasInit.equals(conf.getINIT_GLOROT_UNIFORM())) {
            init = WeightInit.XAVIER_UNIFORM;
        } else if (kerasInit.equals(conf.getINIT_LECUN_NORMAL())) {
            init = WeightInit.LECUN_NORMAL;
        } else if (kerasInit.equals(conf.getINIT_LECUN_UNIFORM())) {
            init = WeightInit.LECUN_UNIFORM;
        } else if (kerasInit.equals(conf.getINIT_HE_NORMAL())) {
            init = WeightInit.RELU;
        } else if (kerasInit.equals(conf.getINIT_HE_UNIFORM())) {
            init = WeightInit.RELU_UNIFORM;
        } else if (kerasInit.equals(conf.getINIT_ONE()) || kerasInit.equals(conf.getINIT_ONES())
                        || kerasInit.equals(conf.getINIT_ONES_ALIAS())) {
            init = WeightInit.ONES;
        } else if (kerasInit.equals(conf.getINIT_ZERO()) || kerasInit.equals(conf.getINIT_ZEROS())
                        || kerasInit.equals(conf.getINIT_ZEROS_ALIAS())) {
            init = WeightInit.ZERO;
        } else if (kerasInit.equals(conf.getINIT_UNIFORM()) || kerasInit.equals(conf.getINIT_RANDOM_UNIFORM())
                        || kerasInit.equals(conf.getINIT_RANDOM_UNIFORM_ALIAS())) {
            if (keras2) {
                dist = new UniformDistribution(
                                getRequiredDouble(initConfig, conf.getLAYER_FIELD_INIT_MINVAL(), kerasInit),
                                getRequiredDouble(initConfig, conf.getLAYER_FIELD_INIT_MAXVAL(), kerasInit));
            } else {
                // Keras 1 uses a single symmetric scale: U(-scale, scale), default scale 0.05.
                double scale = getDoubleOrDefault(initConfig, conf.getLAYER_FIELD_INIT_SCALE(), 0.05, kerasInit);
                dist = new UniformDistribution(-scale, scale);
            }
            init = WeightInit.DISTRIBUTION;
        } else if (kerasInit.equals(conf.getINIT_NORMAL()) || kerasInit.equals(conf.getINIT_RANDOM_NORMAL())
                        || kerasInit.equals(conf.getINIT_RANDOM_NORMAL_ALIAS())) {
            if (keras2) {
                dist = new NormalDistribution(
                                getRequiredDouble(initConfig, conf.getLAYER_FIELD_INIT_MEAN(), kerasInit),
                                getRequiredDouble(initConfig, conf.getLAYER_FIELD_INIT_STDDEV(), kerasInit));
            } else {
                // Keras 1 uses zero mean with the scale as standard deviation, default 0.05.
                double scale = getDoubleOrDefault(initConfig, conf.getLAYER_FIELD_INIT_SCALE(), 0.05, kerasInit);
                dist = new NormalDistribution(0.0, scale);
            }
            init = WeightInit.DISTRIBUTION;
        } else if (kerasInit.equals(conf.getINIT_CONSTANT()) || kerasInit.equals(conf.getINIT_CONSTANT_ALIAS())) {
            dist = new ConstantDistribution(getRequiredDouble(initConfig, conf.getLAYER_FIELD_INIT_VALUE(), kerasInit));
            init = WeightInit.DISTRIBUTION;
        } else if (kerasInit.equals(conf.getINIT_ORTHOGONAL()) || kerasInit.equals(conf.getINIT_ORTHOGONAL_ALIAS())) {
            if (keras2) {
                dist = new OrthogonalDistribution(
                                getRequiredDouble(initConfig, conf.getLAYER_FIELD_INIT_GAIN(), kerasInit));
            } else {
                // Keras 1 passes the gain as "scale"; default 1.1 when absent.
                double scale = getDoubleOrDefault(initConfig, conf.getLAYER_FIELD_INIT_SCALE(), 1.1, kerasInit);
                dist = new OrthogonalDistribution(scale);
            }
            init = WeightInit.DISTRIBUTION;
        } else if (kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL())
                        || kerasInit.equals(conf.getINIT_TRUNCATED_NORMAL_ALIAS())) {
            dist = new TruncatedNormalDistribution(
                            getRequiredDouble(initConfig, conf.getLAYER_FIELD_INIT_MEAN(), kerasInit),
                            getRequiredDouble(initConfig, conf.getLAYER_FIELD_INIT_STDDEV(), kerasInit));
            init = WeightInit.DISTRIBUTION;
        } else if (kerasInit.equals(conf.getINIT_IDENTITY()) || kerasInit.equals(conf.getINIT_IDENTITY_ALIAS())) {
            // DL4J's IDENTITY init has no multiplier, so a non-unit gain/scale is dropped with a warning.
            if (keras2) {
                if (getDoubleOrDefault(initConfig, conf.getLAYER_FIELD_INIT_GAIN(), 1.0, kerasInit) != 1.0) {
                    log.warn("Scaled identity weight init not supported, setting gain=1");
                }
            } else if (getDoubleOrDefault(initConfig, conf.getLAYER_FIELD_INIT_SCALE(), 1.0, kerasInit) != 1.0) {
                log.warn("Scaled identity weight init not supported, setting scale=1");
            }
            init = WeightInit.IDENTITY;
        } else if (kerasInit.equals(conf.getINIT_VARIANCE_SCALING())) {
            init = mapVarianceScaling(conf, initConfig, kerasInit);
        } else {
            throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + kerasInit);
        }
        return new Pair<WeightInit, Distribution>(init, dist);
    }

    /**
     * Read the weight initializer stored under {@code initField} in a Keras layer config and map it to DL4J.
     *
     * @param layerConfig           Keras layer config; its inner layer config is extracted via
     *                              {@link KerasLayerUtils#getInnerLayerConfigFromConfig}
     * @param initField             key of the initializer field inside the inner layer config
     * @param enforceTrainingConfig when {@code true}, an unknown initializer is rethrown; otherwise XAVIER is
     *                              substituted and a warning is logged
     * @param conf                  Keras layer configuration (names and field keys)
     * @param kerasMajorVersion     Keras major version; {@code 2} expects a {class_name, config} dict, any
     *                              other value a plain initializer name with parameters in the layer config
     * @return the mapped weight init and optional distribution
     * @throws InvalidKerasConfigurationException     if the field is missing or malformed
     * @throws UnsupportedKerasConfigurationException if a Keras 2 initializer lacks "class_name", or the
     *                                                initializer is unknown and {@code enforceTrainingConfig}
     */
    public static Pair<WeightInit, Distribution> getWeightInitFromConfig(Map<String, Object> layerConfig,
                    String initField, boolean enforceTrainingConfig, KerasLayerConfiguration conf,
                    int kerasMajorVersion) throws InvalidKerasConfigurationException,
                    UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
        if (!innerConfig.containsKey(initField)) {
            throw new InvalidKerasConfigurationException("Keras layer is missing " + initField + " field");
        }

        String kerasInit;
        Map<String, Object> initConfig;
        if (kerasMajorVersion != KERAS_2) {
            // Keras 1: the field holds the initializer name; parameters live alongside it in the layer config.
            kerasInit = (String) innerConfig.get(initField);
            initConfig = innerConfig;
        } else {
            Object rawInit = innerConfig.get(initField);
            if (!(rawInit instanceof Map)) {
                throw new InvalidKerasConfigurationException(
                                "Keras layer field " + initField + " must be an initializer dict, got: " + rawInit);
            }
            @SuppressWarnings("unchecked") // JSON objects are parsed as String-keyed maps
            Map<String, Object> initDict = (Map<String, Object>) rawInit;
            if (!initDict.containsKey("class_name")) {
                throw new UnsupportedKerasConfigurationException("Incomplete initialization class");
            }
            kerasInit = (String) initDict.get("class_name");
            Object rawParams = initDict.get("config");
            if (rawParams instanceof Map) {
                @SuppressWarnings("unchecked") // JSON objects are parsed as String-keyed maps
                Map<String, Object> params = (Map<String, Object>) rawParams;
                initConfig = params;
            } else {
                // Parameterless initializers (e.g. Zeros) may omit "config"; required params are checked later.
                initConfig = new HashMap<String, Object>();
            }
        }

        try {
            return mapWeightInitialization(kerasInit, conf, initConfig, kerasMajorVersion);
        } catch (UnsupportedKerasConfigurationException e) {
            if (enforceTrainingConfig) {
                throw e;
            }
            log.warn("Unknown weight initializer {} (Using XAVIER instead).", kerasInit);
            return new Pair<WeightInit, Distribution>(WeightInit.XAVIER, null);
        }
    }

    /** Map a Keras VarianceScaling initializer's mode and distribution to the matching DL4J variance-scaling init. */
    private static WeightInit mapVarianceScaling(KerasLayerConfiguration conf, Map<String, Object> initConfig,
                    String initName) throws InvalidKerasConfigurationException {
        // Missing values fall back to Keras's VarianceScaling defaults: scale=1.0, mode=fan_in, distribution=normal.
        double scale = getDoubleOrDefault(initConfig, conf.getLAYER_FIELD_INIT_SCALE(), 1.0, initName);
        if (scale != 1.0) {
            log.warn("Scaled variance scaling weight init not supported, setting scale=1");
        }
        Object mode = valueOrDefault(initConfig, conf.getLAYER_FIELD_INIT_MODE(), "fan_in");
        Object distribution = valueOrDefault(initConfig, conf.getLAYER_FIELD_INIT_DISTRIBUTION(), "normal");
        // "normal", "truncated_normal" and "untruncated_normal" are all Gaussian; everything else is uniform.
        boolean normal = distribution instanceof String && ((String) distribution).endsWith("normal");

        if ("fan_in".equals(mode)) {
            return normal ? WeightInit.VAR_SCALING_NORMAL_FAN_IN : WeightInit.VAR_SCALING_UNIFORM_FAN_IN;
        } else if ("fan_out".equals(mode)) {
            return normal ? WeightInit.VAR_SCALING_NORMAL_FAN_OUT : WeightInit.VAR_SCALING_UNIFORM_FAN_OUT;
        } else if ("fan_avg".equals(mode)) {
            return normal ? WeightInit.VAR_SCALING_NORMAL_FAN_AVG : WeightInit.VAR_SCALING_UNIFORM_FAN_AVG;
        }
        throw new InvalidKerasConfigurationException(
                        "Initialization argument 'mode' has to be either fan_in, fan_out or fan_avg, got: " + mode);
    }

    /** Return {@code initConfig.get(key)}, or {@code defaultValue} if the map is null or the value is absent/null. */
    private static Object valueOrDefault(Map<String, Object> initConfig, String key, Object defaultValue) {
        Object value = initConfig == null ? null : initConfig.get(key);
        return value == null ? defaultValue : value;
    }

    /** Read a required numeric parameter, accepting any {@link Number} (JSON integers arrive as Integer/Long). */
    private static double getRequiredDouble(Map<String, Object> initConfig, String key, String initName)
                    throws InvalidKerasConfigurationException {
        Object value = valueOrDefault(initConfig, key, null);
        if (value == null) {
            throw new InvalidKerasConfigurationException(
                            "Keras initializer " + initName + " is missing required parameter " + key);
        }
        return toDouble(value, key, initName);
    }

    /** Read an optional numeric parameter, returning {@code defaultValue} when it is absent. */
    private static double getDoubleOrDefault(Map<String, Object> initConfig, String key, double defaultValue,
                    String initName) throws InvalidKerasConfigurationException {
        Object value = valueOrDefault(initConfig, key, null);
        return value == null ? defaultValue : toDouble(value, key, initName);
    }

    /** Convert a parsed JSON value to double, rejecting non-numeric values with a descriptive error. */
    private static double toDouble(Object value, String key, String initName)
                    throws InvalidKerasConfigurationException {
        if (!(value instanceof Number)) {
            throw new InvalidKerasConfigurationException("Keras initializer " + initName + " parameter " + key
                            + " must be numeric, got: " + value);
        }
        return ((Number) value).doubleValue();
    }
}
