package org.deeplearning4j.nn.modelimport.keras;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.DL4JFileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/KerasModelImport.class */
/**
 * Static entry points for importing Keras models into DL4J.
 *
 * <p>Functional (non-Sequential) models are built through {@link KerasModel} and returned as a
 * {@link ComputationGraph} (or {@link ComputationGraphConfiguration}). Sequential models are
 * built through {@link KerasSequentialModel} and returned as a {@link MultiLayerNetwork} (or
 * {@link MultiLayerConfiguration}).
 *
 * <p>The {@code enforceTrainingConfig} argument is forwarded unchanged to the model builder's
 * {@code enforceTrainingConfig(...)} setter. Overloads that omit it use a fixed default, and that
 * default is not the same for every overload; see each method's documentation.
 */
public class KerasModelImport {
    private static final Logger log = LoggerFactory.getLogger(KerasModelImport.class);

    /** Prefix for temporary files created when importing from an {@link InputStream}. */
    private static final String TEMP_FILE_PREFIX = "DL4JKerasModelImport";
    /** Suffix for temporary files created when importing from an {@link InputStream}. */
    private static final String TEMP_FILE_SUFFIX = ".bin";

    /**
     * Imports a Keras functional model (configuration and weights) from a stream holding the
     * HDF5 model file.
     *
     * <p>The stream is copied to a temporary file, which is deleted once the import finishes,
     * whether or not it succeeds. This method does not close the stream.
     *
     * @param inputStream           stream containing the HDF5 model file
     * @param enforceTrainingConfig forwarded to the builder's {@code enforceTrainingConfig}
     * @return the imported network
     * @throws IOException if the stream cannot be copied to a temporary file, or if the builder
     *                     reports an I/O failure
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     */
    public static ComputationGraph importKerasModelAndWeights(InputStream inputStream, boolean enforceTrainingConfig)
            throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        File tempFile = toTempFile(inputStream);
        try {
            return importKerasModelAndWeights(tempFile.getAbsolutePath(), enforceTrainingConfig);
        } finally {
            deleteTempFile(tempFile);
        }
    }

    /**
     * Imports a Keras functional model (configuration and weights) from a stream holding the
     * HDF5 model file, with the training configuration enforced. This delegates to
     * {@link #importKerasModelAndWeights(String)}.
     *
     * <p>The stream is copied to a temporary file, which is deleted once the import finishes,
     * whether or not it succeeds. This method does not close the stream.
     *
     * @param inputStream stream containing the HDF5 model file
     * @return the imported network
     * @throws IOException if the stream cannot be copied to a temporary file, or if the builder
     *                     reports an I/O failure
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     */
    public static ComputationGraph importKerasModelAndWeights(InputStream inputStream)
            throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        File tempFile = toTempFile(inputStream);
        try {
            return importKerasModelAndWeights(tempFile.getAbsolutePath());
        } finally {
            deleteTempFile(tempFile);
        }
    }

    /**
     * Imports a Keras Sequential model (configuration and weights) from a stream holding the
     * HDF5 model file.
     *
     * <p>The stream is copied to a temporary file, which is deleted once the import finishes,
     * whether or not it succeeds. This method does not close the stream.
     *
     * @param inputStream           stream containing the HDF5 model file
     * @param enforceTrainingConfig forwarded to the builder's {@code enforceTrainingConfig}
     * @return the imported network
     * @throws IOException if the stream cannot be copied to a temporary file, or if the builder
     *                     reports an I/O failure
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(InputStream inputStream, boolean enforceTrainingConfig)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        File tempFile = toTempFile(inputStream);
        try {
            return importKerasSequentialModelAndWeights(tempFile.getAbsolutePath(), enforceTrainingConfig);
        } finally {
            deleteTempFile(tempFile);
        }
    }

    /**
     * Imports a Keras Sequential model (configuration and weights) from a stream holding the
     * HDF5 model file, with the training configuration enforced. This delegates to
     * {@link #importKerasSequentialModelAndWeights(String)}.
     *
     * <p>The stream is copied to a temporary file, which is deleted once the import finishes,
     * whether or not it succeeds. This method does not close the stream.
     *
     * @param inputStream stream containing the HDF5 model file
     * @return the imported network
     * @throws IOException if the stream cannot be copied to a temporary file, or if the builder
     *                     reports an I/O failure
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(InputStream inputStream)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        File tempFile = toTempFile(inputStream);
        try {
            return importKerasSequentialModelAndWeights(tempFile.getAbsolutePath());
        } finally {
            deleteTempFile(tempFile);
        }
    }

    /**
     * Imports a Keras functional model (configuration and weights) from an HDF5 file, overriding
     * the input shape.
     *
     * @param modelHdf5Filename     path passed to the builder's {@code modelHdf5Filename}
     * @param inputShape            passed to the builder's {@code inputShape}
     * @param enforceTrainingConfig forwarded to the builder's {@code enforceTrainingConfig}
     * @return the imported network
     * @throws IOException                            propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     */
    public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename, int[] inputShape, boolean enforceTrainingConfig)
            throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        return new KerasModel().modelBuilder()
                .modelHdf5Filename(modelHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .inputShape(inputShape)
                .buildModel()
                .getComputationGraph();
    }

    /**
     * Imports a Keras functional model (configuration and weights) from an HDF5 file.
     *
     * @param modelHdf5Filename     path passed to the builder's {@code modelHdf5Filename}
     * @param enforceTrainingConfig forwarded to the builder's {@code enforceTrainingConfig}
     * @return the imported network
     * @throws IOException                            propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     */
    public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename, boolean enforceTrainingConfig)
            throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        return new KerasModel().modelBuilder()
                .modelHdf5Filename(modelHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraph();
    }

    /**
     * Imports a Keras functional model (configuration and weights) from an HDF5 file, with the
     * training configuration enforced ({@code enforceTrainingConfig(true)}).
     *
     * @param modelHdf5Filename path passed to the builder's {@code modelHdf5Filename}
     * @return the imported network
     * @throws IOException                            propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     */
    public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename)
            throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        return importKerasModelAndWeights(modelHdf5Filename, true);
    }

    /**
     * Imports a Keras Sequential model (configuration and weights) from an HDF5 file, overriding
     * the input shape.
     *
     * @param modelHdf5Filename     path passed to the builder's {@code modelHdf5Filename}
     * @param inputShape            passed to the builder's {@code inputShape}
     * @param enforceTrainingConfig forwarded to the builder's {@code enforceTrainingConfig}
     * @return the imported network
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename, int[] inputShape, boolean enforceTrainingConfig)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return new KerasSequentialModel().modelBuilder()
                .modelHdf5Filename(modelHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .inputShape(inputShape)
                .buildSequential()
                .getMultiLayerNetwork();
    }

    /**
     * Imports a Keras Sequential model (configuration and weights) from an HDF5 file.
     *
     * @param modelHdf5Filename     path passed to the builder's {@code modelHdf5Filename}
     * @param enforceTrainingConfig forwarded to the builder's {@code enforceTrainingConfig}
     * @return the imported network
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename, boolean enforceTrainingConfig)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return new KerasSequentialModel().modelBuilder()
                .modelHdf5Filename(modelHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildSequential()
                .getMultiLayerNetwork();
    }

    /**
     * Imports a Keras Sequential model (configuration and weights) from an HDF5 file, with the
     * training configuration enforced ({@code enforceTrainingConfig(true)}).
     *
     * @param modelHdf5Filename path passed to the builder's {@code modelHdf5Filename}
     * @return the imported network
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return importKerasSequentialModelAndWeights(modelHdf5Filename, true);
    }

    /**
     * Imports a Keras functional model whose configuration (JSON) and weights (HDF5) are stored
     * in separate files.
     *
     * @param modelJsonFilename     path passed to the builder's {@code modelJsonFilename}
     * @param weightsHdf5Filename   path passed to the builder's {@code weightsHdf5FilenameNoRoot}
     * @param enforceTrainingConfig forwarded to the builder's {@code enforceTrainingConfig}
     * @return the imported network
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static ComputationGraph importKerasModelAndWeights(String modelJsonFilename, String weightsHdf5Filename, boolean enforceTrainingConfig)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        // The previous chain set enforceTrainingConfig(false) and then immediately overwrote it;
        // only the final value matters, so it is set once here.
        return new KerasModel().modelBuilder()
                .modelJsonFilename(modelJsonFilename)
                .weightsHdf5FilenameNoRoot(weightsHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraph();
    }

    /**
     * Imports a Keras functional model whose configuration (JSON) and weights (HDF5) are stored
     * in separate files, with the training configuration enforced
     * ({@code enforceTrainingConfig(true)}).
     *
     * @param modelJsonFilename   path passed to the builder's {@code modelJsonFilename}
     * @param weightsHdf5Filename path passed to the builder's {@code weightsHdf5FilenameNoRoot}
     * @return the imported network
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static ComputationGraph importKerasModelAndWeights(String modelJsonFilename, String weightsHdf5Filename)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return importKerasModelAndWeights(modelJsonFilename, weightsHdf5Filename, true);
    }

    /**
     * Imports a Keras Sequential model whose configuration (JSON) and weights (HDF5) are stored
     * in separate files.
     *
     * @param modelJsonFilename     path passed to the builder's {@code modelJsonFilename}
     * @param weightsHdf5Filename   path passed to the builder's {@code weightsHdf5FilenameNoRoot}
     * @param enforceTrainingConfig forwarded to the builder's {@code enforceTrainingConfig}
     * @return the imported network
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelJsonFilename, String weightsHdf5Filename, boolean enforceTrainingConfig)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return new KerasSequentialModel().modelBuilder()
                .modelJsonFilename(modelJsonFilename)
                .weightsHdf5FilenameNoRoot(weightsHdf5Filename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildSequential()
                .getMultiLayerNetwork();
    }

    /**
     * Imports a Keras Sequential model whose configuration (JSON) and weights (HDF5) are stored
     * in separate files, WITHOUT enforcing the training configuration
     * ({@code enforceTrainingConfig(false)}).
     *
     * <p>NOTE(review): this default differs from the functional-model counterpart
     * {@code importKerasModelAndWeights(String, String)}, which passes {@code true}. It is kept
     * as-is because existing callers may rely on it; confirm whether the difference is intentional.
     *
     * @param modelJsonFilename   path passed to the builder's {@code modelJsonFilename}
     * @param weightsHdf5Filename path passed to the builder's {@code weightsHdf5FilenameNoRoot}
     * @return the imported network
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelJsonFilename, String weightsHdf5Filename)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return importKerasSequentialModelAndWeights(modelJsonFilename, weightsHdf5Filename, false);
    }

    /**
     * Imports only the configuration of a Keras functional model from a JSON file; no weights
     * are loaded.
     *
     * @param modelJsonFilename     path passed to the builder's {@code modelJsonFilename}
     * @param enforceTrainingConfig forwarded to the builder's {@code enforceTrainingConfig}
     * @return the imported configuration
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static ComputationGraphConfiguration importKerasModelConfiguration(String modelJsonFilename, boolean enforceTrainingConfig)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return new KerasModel().modelBuilder()
                .modelJsonFilename(modelJsonFilename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildModel()
                .getComputationGraphConfiguration();
    }

    /**
     * Imports only the configuration of a Keras functional model from a JSON file, without
     * enforcing the training configuration ({@code enforceTrainingConfig(false)}).
     *
     * @param modelJsonFilename path passed to the builder's {@code modelJsonFilename}
     * @return the imported configuration
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static ComputationGraphConfiguration importKerasModelConfiguration(String modelJsonFilename)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return importKerasModelConfiguration(modelJsonFilename, false);
    }

    /**
     * Imports only the configuration of a Keras Sequential model from a JSON file; no weights
     * are loaded.
     *
     * @param modelJsonFilename     path passed to the builder's {@code modelJsonFilename}
     * @param enforceTrainingConfig forwarded to the builder's {@code enforceTrainingConfig}
     * @return the imported configuration
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename, boolean enforceTrainingConfig)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return new KerasSequentialModel().modelBuilder()
                .modelJsonFilename(modelJsonFilename)
                .enforceTrainingConfig(enforceTrainingConfig)
                .buildSequential()
                .getMultiLayerConfiguration();
    }

    /**
     * Imports only the configuration of a Keras Sequential model from a JSON file, without
     * enforcing the training configuration ({@code enforceTrainingConfig(false)}).
     *
     * @param modelJsonFilename path passed to the builder's {@code modelJsonFilename}
     * @return the imported configuration
     * @throws IOException                            propagated from the model builder
     * @throws InvalidKerasConfigurationException     propagated from the model builder
     * @throws UnsupportedKerasConfigurationException propagated from the model builder
     */
    public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return importKerasSequentialConfiguration(modelJsonFilename, false);
    }

    /**
     * Copies {@code inputStream} into a newly created temporary file and returns that file.
     *
     * <p>The file is registered with {@link File#deleteOnExit()} as a safety net. Callers should
     * still delete it as soon as they are done. If creating the output stream or copying fails,
     * the partially written file is deleted before the exception is rethrown, because the caller
     * never receives a reference it could clean up. {@code inputStream} itself is not closed.
     *
     * @param inputStream source data; read by {@code IOUtils.copy}
     * @return the temporary file holding the copied bytes
     * @throws IOException if the temporary file cannot be created or written
     */
    private static File toTempFile(InputStream inputStream) throws IOException {
        File tempFile = DL4JFileUtils.createTempFile(TEMP_FILE_PREFIX, TEMP_FILE_SUFFIX);
        tempFile.deleteOnExit();
        try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(tempFile))) {
            IOUtils.copy(inputStream, out);
            // Flush explicitly so buffered-write failures are reported from the copy step itself.
            out.flush();
        } catch (IOException | RuntimeException | Error e) {
            // try-with-resources has already closed the output stream by this point (and attached
            // any close() failure as a suppressed exception), so no open handle blocks the delete.
            deleteTempFile(tempFile);
            throw e;
        }
        return tempFile;
    }

    /**
     * Best-effort deletion of a temporary file created by {@link #toTempFile(InputStream)}.
     * A failed delete is logged rather than thrown, so it cannot mask the result or exception of
     * the surrounding import. Because {@code toTempFile} registers {@code deleteOnExit()}, the
     * file will be deleted again at normal JVM termination.
     *
     * @param file the temporary file to delete
     */
    private static void deleteTempFile(File file) {
        if (!file.delete() && file.exists()) {
            log.warn("Could not delete temporary file {}; deletion will be retried at JVM exit",
                    file.getAbsolutePath());
        }
    }
}
