package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.config.KerasModelConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasLoss;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasLstm;
import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasSimpleRnn;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/KerasModel.class */
/**
 * Imports a Keras functional-API model and converts it into a DL4J {@link ComputationGraph}.
 *
 * <p>The importing constructor parses the model config (JSON or YAML) and converts each layer
 * config into a {@link KerasLayer}, in declaration order. When requested, it imports the training
 * config, which appends loss layers and makes them the model outputs. It then infers every layer's
 * output type and, if an HDF5 archive is supplied, imports the weights.
 */
public class KerasModel {
    private static final Logger log = LoggerFactory.getLogger(KerasModel.class);

    /** Field-name configuration used when reading Keras configs; shared by all instances. */
    protected static KerasModelConfiguration config = new KerasModelConfiguration();

    /** Builder bound to the shared {@link #config}. */
    KerasModelBuilder modelBuilder = new KerasModelBuilder(config);
    /** Keras class name read from the model config. */
    protected String className;
    /** When true, the training config (if given) is imported; also forwarded to each layer's conversion. */
    protected boolean enforceTrainingConfig;
    /** Converted layers keyed by layer name. */
    protected Map<String, KerasLayer> layers;
    /** Converted layers in config order; loss layers from the training config are appended last. */
    List<KerasLayer> layersOrdered;
    /** Inferred output type of each layer, keyed by layer name. */
    Map<String, InputType> outputTypes;
    /** Names of the model's input layers, in config order. */
    ArrayList<String> inputLayerNames;
    /** Names of the model's outputs; replaced by loss-layer names when a training config is imported. */
    ArrayList<String> outputLayerNames;
    /** Set when any LSTM or SimpleRNN layer reports {@code unroll}. */
    boolean useTruncatedBPTT = false;
    /** Truncated-BPTT length read from the input layer(s); values <= 0 select standard backprop. */
    int truncatedBPTT = 0;
    /** Keras major version, determined from the model config. */
    int kerasMajorVersion;
    /** Keras backend name, determined from the model config (may be null). */
    String kerasBackend;

    /** Creates an unpopulated model; no configuration is parsed. */
    public KerasModel() {
    }

    /**
     * Returns the builder associated with this model.
     *
     * @return builder bound to the shared {@link #config}
     */
    public KerasModelBuilder modelBuilder() {
        return this.modelBuilder;
    }

    /**
     * Imports a model using every source configured on the given builder.
     *
     * @param kerasModelBuilder supplies the model JSON/YAML, weights, training config and flags
     */
    public KerasModel(KerasModelBuilder kerasModelBuilder) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(kerasModelBuilder.getModelJson(), kerasModelBuilder.getModelYaml(), kerasModelBuilder.getWeightsArchive(),
                kerasModelBuilder.getWeightsRoot(), kerasModelBuilder.getTrainingJson(),
                kerasModelBuilder.getTrainingArchive(), kerasModelBuilder.isEnforceTrainingConfig());
    }

    /**
     * Parses and converts a Keras {@code Model}.
     *
     * @param modelJson             model config as JSON (either this or {@code modelYaml})
     * @param modelYaml             model config as YAML
     * @param weightsArchive        HDF5 archive holding weights, or null to skip weight import
     * @param weightsRoot           root group inside {@code weightsArchive} passed to weight import
     * @param trainingJson          training config as JSON, or null
     * @param trainingArchive       currently unused by this constructor
     * @param enforceTrainingConfig whether to import the training config and enforce it in layers
     * @throws InvalidKerasConfigurationException if required fields are missing or the class is not a Model
     */
    @SuppressWarnings("unchecked")
    protected KerasModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot,
                         String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> modelConfig = KerasModelUtils.parseModelConfig(modelJson, modelYaml);
        this.kerasMajorVersion = KerasModelUtils.determineKerasMajorVersion(modelConfig, config);
        this.kerasBackend = KerasModelUtils.determineKerasBackend(modelConfig, config);
        this.enforceTrainingConfig = enforceTrainingConfig;

        if (!modelConfig.containsKey(config.getFieldClassName())) {
            throw new InvalidKerasConfigurationException("Could not determine Keras model class (no " + config.getFieldClassName() + " field found)");
        }
        this.className = (String) modelConfig.get(config.getFieldClassName());
        if (!this.className.equals(config.getFieldClassNameModel())) {
            throw new InvalidKerasConfigurationException("Expected model class name " + config.getFieldClassNameModel() + " (found " + this.className + ")");
        }

        if (!modelConfig.containsKey(config.getModelFieldConfig())) {
            throw new InvalidKerasConfigurationException("Could not find model configuration details (no " + config.getModelFieldConfig() + " in model config)");
        }
        Map<String, Object> modelDetails = (Map<String, Object>) modelConfig.get(config.getModelFieldConfig());

        this.inputLayerNames = extractLayerNames(modelDetails, config.getModelFieldInputLayers(), "input");
        this.outputLayerNames = extractLayerNames(modelDetails, config.getModelFieldOutputLayers(), "output");

        if (!modelDetails.containsKey(config.getModelFieldLayers())) {
            throw new InvalidKerasConfigurationException("Could not find layer configurations (no " + config.getModelFieldLayers() + " field found)");
        }
        prepareLayers((List<Object>) modelDetails.get(config.getModelFieldLayers()));

        if (trainingJson != null && enforceTrainingConfig) {
            importTrainingConfiguration(trainingJson);
        }
        // Output types depend on the final layer list, including any loss layers added above.
        inferOutputTypes();
        if (weightsArchive != null) {
            KerasModelUtils.importWeights(weightsArchive, weightsRoot, this.layers, this.kerasMajorVersion, this.kerasBackend);
        }
    }

    /**
     * Reads a list of layer references from the model details.
     *
     * @param modelDetails the "config" section of the model config
     * @param field        key holding the list; each entry is a list whose first element is a layer name
     * @param role         "input" or "output", used only in the error message
     * @return the layer names, in config order
     * @throws InvalidKerasConfigurationException if {@code field} is absent
     */
    @SuppressWarnings("unchecked")
    private static ArrayList<String> extractLayerNames(Map<String, Object> modelDetails, String field, String role)
            throws InvalidKerasConfigurationException {
        if (!modelDetails.containsKey(field)) {
            throw new InvalidKerasConfigurationException("Could not find list of " + role + " layers (no " + field + " field found)");
        }
        ArrayList<String> names = new ArrayList<>();
        for (Object entry : (List<Object>) modelDetails.get(field)) {
            names.add((String) ((List<Object>) entry).get(0));
        }
        return names;
    }

    /**
     * Converts each raw layer config into a {@link KerasLayer} and records it by order and by name.
     *
     * <p>Each config map is modified in place: the Keras major version (and, for Keras 2, the
     * backend) is written into it before conversion. Also sets {@link #useTruncatedBPTT} when any
     * LSTM or SimpleRNN layer is unrolled.
     *
     * @param list raw layer config maps, in model order
     */
    @SuppressWarnings("unchecked")
    public void prepareLayers(List<Object> list) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.layersOrdered = new ArrayList<>();
        this.layers = new HashMap<>();
        for (Object entry : list) {
            Map<String, Object> layerConfig = (Map<String, Object>) entry;
            layerConfig.put(config.getFieldKerasVersion(), Integer.valueOf(this.kerasMajorVersion));
            if (this.kerasMajorVersion == 2 && this.kerasBackend != null) {
                layerConfig.put(config.getFieldBackend(), this.kerasBackend);
            }
            // The already-converted layers are passed in so a layer can look up its predecessors.
            KerasLayer layer = KerasLayerUtils.getKerasLayerFromConfig(layerConfig, this.enforceTrainingConfig,
                    new KerasLayer(Integer.valueOf(this.kerasMajorVersion)).conf, KerasLayer.customLayers, this.layers);
            this.layersOrdered.add(layer);
            this.layers.put(layer.getLayerName(), layer);
            if (layer instanceof KerasLstm) {
                this.useTruncatedBPTT = this.useTruncatedBPTT || ((KerasLstm) layer).getUnroll();
            }
            if (layer instanceof KerasSimpleRnn) {
                this.useTruncatedBPTT = this.useTruncatedBPTT || ((KerasSimpleRnn) layer).getUnroll();
            }
        }
    }

    /**
     * Imports the loss configuration from a Keras training config.
     *
     * <p>A {@link KerasLoss} layer named {@code <output>_loss} is created for each output. These
     * layers are appended to the model and become the new model outputs. The loss field may be:
     * <ul>
     *   <li>a single loss name, applied to every output;</li>
     *   <li>a map from output name to loss name; or</li>
     *   <li>a list of loss names matched to the outputs in order.</li>
     * </ul>
     * All loss layers are built before any model state changes, so a failure leaves the model untouched.
     *
     * @param str training config as JSON
     * @throws InvalidKerasConfigurationException if the loss field is missing, null, has an
     *         unsupported shape, or a list does not match the number of outputs
     */
    public void importTrainingConfiguration(String str) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> trainingConfig = KerasModelUtils.parseJsonString(str);
        if (!trainingConfig.containsKey(config.getTrainingLoss())) {
            throw new InvalidKerasConfigurationException("Could not determine training loss function (no " + config.getTrainingLoss() + " field found in training config)");
        }
        Object lossConfig = trainingConfig.get(config.getTrainingLoss());
        if (lossConfig == null) {
            throw new InvalidKerasConfigurationException("Could not determine training loss function (" + config.getTrainingLoss() + " field is null in training config)");
        }

        List<KerasLayer> lossLayers = new ArrayList<>();
        if (lossConfig instanceof String) {
            String lossName = (String) lossConfig;
            for (String outputName : this.outputLayerNames) {
                lossLayers.add(new KerasLoss(outputName + "_loss", outputName, lossName));
            }
        } else if (lossConfig instanceof Map) {
            for (Map.Entry<?, ?> lossEntry : ((Map<?, ?>) lossConfig).entrySet()) {
                String outputName = (String) lossEntry.getKey();
                lossLayers.add(new KerasLoss(outputName + "_loss", outputName, requireLossName(lossEntry.getValue())));
            }
        } else if (lossConfig instanceof List) {
            // Keras allows one loss per output, listed in the same order as the model outputs.
            List<?> losses = (List<?>) lossConfig;
            if (losses.size() != this.outputLayerNames.size()) {
                throw new InvalidKerasConfigurationException("Number of Keras losses (" + losses.size()
                        + ") does not match number of model outputs (" + this.outputLayerNames.size() + ")");
            }
            for (int i = 0; i < losses.size(); i++) {
                String outputName = this.outputLayerNames.get(i);
                lossLayers.add(new KerasLoss(outputName + "_loss", outputName, requireLossName(losses.get(i))));
            }
        } else {
            // Without this, the outputs would be cleared and no loss layers added, leaving a graph with no outputs.
            throw new InvalidKerasConfigurationException("Unsupported Keras training loss configuration: " + lossConfig);
        }

        this.outputLayerNames.clear();
        for (KerasLayer lossLayer : lossLayers) {
            this.layersOrdered.add(lossLayer);
            this.layers.put(lossLayer.getLayerName(), lossLayer);
            this.outputLayerNames.add(lossLayer.getLayerName());
        }
    }

    /**
     * Checks that a loss entry is a loss name.
     *
     * @param loss raw loss value from the training config
     * @return the loss name
     * @throws InvalidKerasConfigurationException if {@code loss} is not a String (including null)
     */
    private static String requireLossName(Object loss) throws InvalidKerasConfigurationException {
        if (!(loss instanceof String)) {
            throw new InvalidKerasConfigurationException("Unknown Keras loss " + loss);
        }
        return (String) loss;
    }

    /**
     * Computes each layer's output type in {@link #layersOrdered} order.
     *
     * <p>Input layers are evaluated with no inputs, and their truncated-BPTT length is recorded; the
     * last input layer visited wins. Every other layer receives the output types of its inbound
     * layers. An inbound layer not yet visited contributes {@code null}.
     */
    public void inferOutputTypes() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.outputTypes = new HashMap<>();
        for (KerasLayer layer : this.layersOrdered) {
            InputType outputType;
            if (layer instanceof KerasInput) {
                outputType = layer.getOutputType(new InputType[0]);
                this.truncatedBPTT = ((KerasInput) layer).getTruncatedBptt();
            } else {
                List<String> inboundNames = layer.getInboundLayerNames();
                InputType[] inboundTypes = new InputType[inboundNames.size()];
                int index = 0;
                for (String inboundName : inboundNames) {
                    inboundTypes[index++] = this.outputTypes.get(inboundName);
                }
                outputType = layer.getOutputType(inboundTypes);
            }
            this.outputTypes.put(layer.getLayerName(), outputType);
        }
    }

    /**
     * Builds a DL4J {@link ComputationGraphConfiguration} from the imported layers.
     *
     * <p>Truncated BPTT is used only when some recurrent layer is unrolled and a positive
     * truncation length was found on the inputs; otherwise standard backprop is used.
     *
     * @return the graph configuration
     * @throws InvalidKerasConfigurationException if the model class is incompatible or an input layer is missing
     * @throws UnsupportedKerasConfigurationException if a layer maps to nothing usable
     */
    public ComputationGraphConfiguration getComputationGraphConfiguration() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!this.className.equals(config.getFieldClassNameModel()) && !this.className.equals(config.getFieldClassNameSequential())) {
            throw new InvalidKerasConfigurationException("Keras model class name " + this.className + " incompatible with ComputationGraph");
        }
        ComputationGraphConfiguration.GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder().graphBuilder();

        graphBuilder.addInputs(this.inputLayerNames.toArray(new String[0]));
        InputType[] inputTypes = new InputType[this.inputLayerNames.size()];
        for (int i = 0; i < inputTypes.length; i++) {
            String inputName = this.inputLayerNames.get(i);
            KerasLayer inputLayer = this.layers.get(inputName);
            if (inputLayer == null) {
                throw new InvalidKerasConfigurationException("Could not find input layer " + inputName);
            }
            inputTypes[i] = inputLayer.getOutputType(new InputType[0]);
        }
        graphBuilder.setInputTypes(inputTypes);
        graphBuilder.setOutputs(this.outputLayerNames.toArray(new String[0]));

        Map<String, InputPreProcessor> preprocessors = new HashMap<>();
        for (KerasLayer layer : this.layersOrdered) {
            String layerName = layer.getLayerName();
            List<String> inboundNames = layer.getInboundLayerNames();
            String[] inboundArray = inboundNames.toArray(new String[0]);
            InputType[] inboundTypes = new InputType[inboundNames.size()];
            int index = 0;
            for (String inboundName : inboundNames) {
                inboundTypes[index++] = this.outputTypes.get(inboundName);
            }
            InputPreProcessor preprocessor = layer.getInputPreprocessor(inboundTypes);
            boolean isOutput = this.outputLayerNames.contains(layerName);

            if (layer.isLayer()) {
                if (preprocessor != null) {
                    preprocessors.put(layerName, preprocessor);
                }
                graphBuilder.addLayer(layerName, layer.getLayer(), inboundArray);
                if (isOutput && !(layer.getLayer() instanceof IOutputLayer)) {
                    log.warn("Model cannot be trained: output layer {} is not an IOutputLayer (no loss function specified)", layerName);
                }
            } else if (layer.isVertex()) {
                if (preprocessor != null) {
                    preprocessors.put(layerName, preprocessor);
                }
                graphBuilder.addVertex(layerName, layer.getVertex(), inboundArray);
                if (isOutput && !(layer.getVertex() instanceof IOutputLayer)) {
                    log.warn("Model cannot be trained: output vertex {} is not an IOutputLayer (no loss function specified)", layerName);
                }
            } else if (layer.isInputPreProcessor()) {
                if (preprocessor == null) {
                    throw new UnsupportedKerasConfigurationException("Layer " + layerName + " could not be mapped to Layer, Vertex, or InputPreProcessor");
                }
                graphBuilder.addVertex(layerName, new PreprocessorVertex(preprocessor), inboundArray);
                // A preprocessor vertex never carries a loss, so an output here cannot be trained.
                if (isOutput) {
                    log.warn("Model cannot be trained: output {} is not an IOutputLayer (no loss function specified)", layerName);
                }
            }
        }
        graphBuilder.setInputPreProcessors(preprocessors);

        if (this.useTruncatedBPTT && this.truncatedBPTT > 0) {
            graphBuilder.backpropType(BackpropType.TruncatedBPTT)
                    .tBPTTForwardLength(this.truncatedBPTT)
                    .tBPTTBackwardLength(this.truncatedBPTT);
        } else {
            graphBuilder.backpropType(BackpropType.Standard);
        }
        return graphBuilder.build();
    }

    /**
     * Builds and initializes a {@link ComputationGraph} with the imported weights copied in.
     *
     * @return the initialized graph
     */
    public ComputationGraph getComputationGraph() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return getComputationGraph(true);
    }

    /**
     * Builds and initializes a {@link ComputationGraph}.
     *
     * @param importWeights when true, copies the imported Keras layer weights into the graph
     * @return the initialized graph
     */
    public ComputationGraph getComputationGraph(boolean importWeights) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        ComputationGraph graph = new ComputationGraph(getComputationGraphConfiguration());
        graph.init();
        if (importWeights) {
            graph = (ComputationGraph) KerasModelUtils.copyWeightsToModel(graph, this.layers);
        }
        return graph;
    }
}
