package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.class */
public class KerasSequentialModel extends KerasModel {
    private static final Logger log = LoggerFactory.getLogger(KerasSequentialModel.class);

    /**
     * Imports a Keras Sequential model from the sources collected in a {@link KerasModelBuilder}.
     *
     * @param kerasModelBuilder builder holding the model configuration (JSON or YAML), optional
     *                          weights archive/root, optional training configuration and the
     *                          enforce-training-config flag
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     * @throws IOException                            propagated from configuration/weights parsing
     * @throws InvalidKerasConfigurationException     if the configuration is not a valid Sequential model
     */
    public KerasSequentialModel(KerasModelBuilder kerasModelBuilder)
            throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(kerasModelBuilder.getModelJson(), kerasModelBuilder.getModelYaml(),
                kerasModelBuilder.getWeightsArchive(), kerasModelBuilder.getWeightsRoot(),
                kerasModelBuilder.getTrainingJson(), kerasModelBuilder.getTrainingArchive(),
                kerasModelBuilder.isEnforceTrainingConfig());
    }

    /**
     * Imports a Keras Sequential model from its serialized configuration and, optionally, its weights.
     *
     * <p>If the first configured layer is not a {@link KerasInput}, an input layer named
     * {@code "input1"} is synthesized from that layer's input shape and dimension ordering and
     * prepended to the layer list. Every layer is then wired so that its sole inbound layer is the
     * layer preceding it.
     *
     * @param modelJson             model configuration as JSON (passed to
     *                              {@code KerasModelUtils.parseModelConfig} together with YAML)
     * @param modelYaml             model configuration as YAML
     * @param weightsArchive        HDF5 archive holding the weights, or {@code null} to skip weight import
     * @param weightsRoot           root group of the weights inside {@code weightsArchive}
     * @param trainingJson          training configuration JSON, or {@code null}
     * @param trainingArchive       NOTE(review): not read by this constructor — confirm whether it is needed
     * @param enforceTrainingConfig whether the training configuration is imported and enforced
     * @throws IOException                            propagated from configuration/weights parsing
     * @throws InvalidKerasConfigurationException     if the class name is missing or not Sequential,
     *                                                or the layer configuration is missing or empty
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     */
    public KerasSequentialModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
                                String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
        this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
        this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
        this.enforceTrainingConfig = enforceTrainingConfig;

        if (!modelConfig.containsKey(config.getFieldClassName())) {
            throw new InvalidKerasConfigurationException("Could not determine Keras model class (no "
                    + config.getFieldClassName() + " field found)");
        }
        this.className = (String) modelConfig.get(config.getFieldClassName());
        // Compare against the constant so a null class name yields a config error, not an NPE.
        if (!config.getFieldClassNameSequential().equals(this.className)) {
            throw new InvalidKerasConfigurationException("Model class name must be "
                    + config.getFieldClassNameSequential() + " (found " + this.className + ")");
        }
        if (!modelConfig.containsKey(config.getModelFieldConfig())) {
            throw new InvalidKerasConfigurationException("Could not find layer configurations (no "
                    + config.getModelFieldConfig() + " field found)");
        }
        prepareLayers((List) modelConfig.get(config.getModelFieldConfig()));
        // Everything below indexes the first layer; fail with a config error instead of an
        // IndexOutOfBoundsException when no layers were declared.
        if (this.layersOrdered.isEmpty()) {
            throw new InvalidKerasConfigurationException("Could not find any layer configurations ("
                    + config.getModelFieldConfig() + " field is empty)");
        }

        KerasLayer firstLayer = this.layersOrdered.get(0);
        KerasInput inputLayer;
        if (firstLayer instanceof KerasInput) {
            inputLayer = (KerasInput) firstLayer;
        } else {
            // No explicit input layer in the config: synthesize one from the first layer's
            // declared input shape so the model has a well-defined entry point.
            inputLayer = new KerasInput("input1", firstLayer.getInputShape());
            inputLayer.setDimOrder(firstLayer.getDimOrder());
            this.layers.put(inputLayer.getLayerName(), inputLayer);
            this.layersOrdered.add(0, inputLayer);
        }
        this.inputLayerNames = new ArrayList<>(Collections.singletonList(inputLayer.getLayerName()));
        this.outputLayerNames = new ArrayList<>(Collections.singletonList(
                this.layersOrdered.get(this.layersOrdered.size() - 1).getLayerName()));

        // Sequential models are a linear chain: each layer's only inbound layer is its predecessor.
        KerasLayer previous = null;
        for (KerasLayer current : this.layersOrdered) {
            if (previous != null) {
                current.setInboundLayerNames(Collections.singletonList(previous.getLayerName()));
            }
            previous = current;
        }

        if (trainingJson != null && enforceTrainingConfig) {
            importTrainingConfiguration(trainingJson);
        }
        inferOutputTypes();
        if (weightsArchive != null) {
            KerasModelUtils.importWeights(weightsArchive, weightsRoot, this.layers,
                    this.kerasMajorVersion, this.kerasBackend);
        }
    }

    /**
     * Creates an instance without parsing any configuration; none of the inherited model state
     * is populated by this constructor.
     */
    public KerasSequentialModel() {
    }

    /**
     * Translates the imported Keras layers into a DL4J {@link MultiLayerConfiguration}.
     *
     * <p>Layers for which {@code isLayer()} is true are added in order. Any input preprocessor
     * required between consecutive layers is registered at the index of the receiving layer.
     * Truncated BPTT is enabled only when it was requested with a positive length.
     *
     * @return the built multi-layer configuration
     * @throws InvalidKerasConfigurationException     if the model is not Sequential, does not have
     *                                                exactly one input and one output, contains a
     *                                                layer without exactly one inbound layer, or
     *                                                contains a graph vertex
     * @throws UnsupportedKerasConfigurationException propagated from layer/preprocessor conversion
     */
    public MultiLayerConfiguration getMultiLayerConfiguration()
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!config.getFieldClassNameSequential().equals(this.className)) {
            throw new InvalidKerasConfigurationException("Keras model class name " + this.className
                    + " incompatible with MultiLayerNetwork");
        }
        if (this.inputLayerNames.size() != 1) {
            throw new InvalidKerasConfigurationException("MultiLayerNetwork expects only 1 input (found "
                    + this.inputLayerNames.size() + ")");
        }
        if (this.outputLayerNames.size() != 1) {
            throw new InvalidKerasConfigurationException("MultiLayerNetwork expects only 1 output (found "
                    + this.outputLayerNames.size() + ")");
        }

        NeuralNetConfiguration.ListBuilder listBuilder = new NeuralNetConfiguration.Builder().list();
        KerasLayer previous = null;
        int layerIndex = 0;
        for (KerasLayer current : this.layersOrdered) {
            if (current.isLayer()) {
                int numInbound = current.getInboundLayerNames().size();
                if (numInbound != 1) {
                    throw new InvalidKerasConfigurationException(
                            "Layers in MultiLayerConfiguration must have exactly one inbound layer (found "
                                    + numInbound + " for layer " + current.getLayerName() + ")");
                }
                if (previous != null) {
                    InputType[] inputTypes = new InputType[1];
                    InputPreProcessor preprocessor;
                    if (previous.isInputPreProcessor()) {
                        // The preceding Keras layer only contributes a preprocessor; feed it the
                        // output type of *its* inbound layer.
                        inputTypes[0] = this.outputTypes.get(previous.getInboundLayerNames().get(0));
                        preprocessor = previous.getInputPreprocessor(inputTypes);
                    } else {
                        inputTypes[0] = this.outputTypes.get(previous.getLayerName());
                        preprocessor = current.getInputPreprocessor(inputTypes);
                    }
                    if (preprocessor != null) {
                        listBuilder.inputPreProcessor(layerIndex, preprocessor);
                    }
                }
                listBuilder.layer(layerIndex, current.getLayer());
                // A network whose final layer has no loss function can run inference but not train.
                if (this.outputLayerNames.contains(current.getLayerName())
                        && !(current.getLayer() instanceof IOutputLayer)) {
                    log.warn("Model cannot be trained: output layer {} is not an IOutputLayer "
                            + "(no loss function specified)", current.getLayerName());
                }
                layerIndex++;
            } else if (current.getVertex() != null) {
                throw new InvalidKerasConfigurationException("Cannot add vertex to MultiLayerConfiguration (class name "
                        + current.getClassName() + ", layer name " + current.getLayerName() + ")");
            }
            previous = current;
        }

        // The first layer is the input layer; its output type is the network's input type.
        InputType networkInputType = this.layersOrdered.get(0).getOutputType(new InputType[0]);
        if (networkInputType != null) {
            listBuilder.setInputType(networkInputType);
        }
        if (this.useTruncatedBPTT && this.truncatedBPTT > 0) {
            listBuilder.backpropType(BackpropType.TruncatedBPTT)
                    .tBPTTForwardLength(this.truncatedBPTT)
                    .tBPTTBackwardLength(this.truncatedBPTT);
        } else {
            listBuilder.backpropType(BackpropType.Standard);
        }
        return listBuilder.build();
    }

    /**
     * Builds and initializes a {@link MultiLayerNetwork}, copying in the imported Keras weights.
     *
     * @return the initialized network with weights copied from the imported layers
     * @throws InvalidKerasConfigurationException     see {@link #getMultiLayerConfiguration()}
     * @throws UnsupportedKerasConfigurationException see {@link #getMultiLayerConfiguration()}
     */
    public MultiLayerNetwork getMultiLayerNetwork()
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return getMultiLayerNetwork(true);
    }

    /**
     * Builds and initializes a {@link MultiLayerNetwork} from {@link #getMultiLayerConfiguration()}.
     *
     * @param importWeights if {@code true}, weights held by the imported Keras layers are copied
     *                      into the freshly initialized network
     * @return the initialized network
     * @throws InvalidKerasConfigurationException     see {@link #getMultiLayerConfiguration()}
     * @throws UnsupportedKerasConfigurationException see {@link #getMultiLayerConfiguration()}
     */
    public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        MultiLayerNetwork network = new MultiLayerNetwork(getMultiLayerConfiguration());
        network.init();
        if (importWeights) {
            network = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(network, this.layers);
        }
        return network;
    }
}
