package org.deeplearning4j.spark.impl.paramavg;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collection;
import java.util.Random;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.data.BatchAndExportDataSetsFunction;
import org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.impl.paramavg.util.ExportSupport;
import org.deeplearning4j.spark.util.serde.StorageLevelDeserializer;
import org.deeplearning4j.spark.util.serde.StorageLevelSerializer;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/impl/paramavg/BaseTrainingMaster.class */
/**
 * Base class for {@link TrainingMaster} implementations.
 *
 * <p>Provides two things shared by subclasses:
 * <ul>
 *   <li>lazily created, shared JSON and YAML {@link ObjectMapper} instances configured for
 *       field-based (de)serialization;</li>
 *   <li>support for exporting an RDD of {@link DataSet} / {@link MultiDataSet} objects to a
 *       directory, remembering the last export so that the same RDD is not exported twice, and
 *       deleting the previous export when a different RDD is exported.</li>
 * </ul>
 *
 * <p>Export bookkeeping ({@link #lastExportedRDDId}, {@link #lastRDDExportPath}) is plain mutable
 * instance state with no synchronization in this class.
 *
 * @param <R> training result type produced by the workers
 * @param <W> training worker type
 */
public abstract class BaseTrainingMaster<R extends TrainingResult, W extends TrainingWorker<R>> implements TrainingMaster<R, W> {
    private static final Logger log = LoggerFactory.getLogger(BaseTrainingMaster.class);

    /** Shared JSON mapper; created on first call to {@link #getJsonMapper()}. */
    protected static ObjectMapper jsonMapper;
    /** Shared YAML mapper; created on first call to {@link #getYamlMapper()}. */
    protected static ObjectMapper yamlMapper;

    /** When true, export start/end events are reported to {@link #stats}. */
    protected boolean collectTrainingStats;
    /** Stats sink used when {@link #collectTrainingStats} is true; not initialized in this class. */
    protected ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper stats;
    /** Base directory of the most recent export, or null if nothing has been exported yet. */
    protected String lastRDDExportPath;
    /** Batch size handed to the batch-and-export functions. */
    protected int batchSizePerWorker;
    protected Random rng;
    /** Unique identifier of this training master; used as a path component for exports. */
    protected String trainingMasterUID;
    protected StatsStorageRouter statsStorage;
    protected Collection<IterationListener> listeners;
    protected Repartition repartition;
    protected RepartitionStrategy repartitionStrategy;

    @JsonDeserialize(using = StorageLevelDeserializer.class)
    @JsonSerialize(using = StorageLevelSerializer.class)
    protected StorageLevel storageLevel;
    /** ID of the most recently exported RDD; {@code Integer.MIN_VALUE} means "no export yet". */
    protected int lastExportedRDDId = Integer.MIN_VALUE;
    /** Root export directory; when null it is resolved lazily by {@link #getBaseDirForRDD(JavaRDD)}. */
    protected String exportDirectory = null;

    @JsonDeserialize(using = StorageLevelDeserializer.class)
    @JsonSerialize(using = StorageLevelSerializer.class)
    protected StorageLevel storageLevelStreams = StorageLevel.MEMORY_ONLY();
    protected RDDTrainingApproach rddTrainingApproach = RDDTrainingApproach.Export;

    /**
     * Returns the shared JSON mapper, creating it on first use.
     *
     * @return the shared JSON {@link ObjectMapper}
     */
    public static synchronized ObjectMapper getJsonMapper() {
        if (jsonMapper == null) {
            jsonMapper = getNewMapper(new JsonFactory());
        }
        return jsonMapper;
    }

    /**
     * Returns the shared YAML mapper, creating it on first use.
     *
     * @return the shared YAML {@link ObjectMapper}
     */
    public static synchronized ObjectMapper getYamlMapper() {
        if (yamlMapper == null) {
            yamlMapper = getNewMapper(new YAMLFactory());
        }
        return yamlMapper;
    }

    /**
     * Builds a new mapper on top of the given factory. The mapper ignores unknown properties,
     * tolerates empty beans, sorts properties alphabetically, indents its output, and
     * (de)serializes via fields only (all other accessors are disabled).
     *
     * @param jsonFactory factory determining the data format (e.g. JSON or YAML)
     * @return a newly configured {@link ObjectMapper}
     */
    protected static ObjectMapper getNewMapper(JsonFactory jsonFactory) {
        ObjectMapper objectMapper = new ObjectMapper(jsonFactory);
        objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        objectMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
        objectMapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
        objectMapper.enable(SerializationFeature.INDENT_OUTPUT);
        // Disable every accessor kind first, then re-enable fields only.
        objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
        objectMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
        return objectMapper;
    }

    /**
     * Exports the given {@code RDD<DataSet>} unless it is the RDD that was exported last, and
     * returns an RDD over the text files written under the export's {@code paths/} directory.
     * If a different RDD was exported previously, that export's directory is deleted first.
     *
     * @param javaSparkContext Spark context used to check export support, delete stale exports and
     *                         read the paths back
     * @param javaRDD          data to export
     * @return RDD of the lines stored under {@code <export base dir>/paths/}
     */
    public JavaRDD<String> exportIfRequired(JavaSparkContext javaSparkContext, JavaRDD<DataSet> javaRDD) {
        ExportSupport.assertExportSupported(javaSparkContext);
        if (this.collectTrainingStats) {
            this.stats.logExportStart();
        }
        String baseDir = prepareForExport(javaSparkContext, javaRDD) ? export(javaRDD) : getBaseDirForRDD(javaRDD);
        if (this.collectTrainingStats) {
            this.stats.logExportEnd();
        }
        return javaSparkContext.textFile(baseDir + "paths/");
    }

    /**
     * {@link MultiDataSet} counterpart of {@link #exportIfRequired(JavaSparkContext, JavaRDD)}.
     *
     * @param javaSparkContext Spark context used to check export support, delete stale exports and
     *                         read the paths back
     * @param javaRDD          data to export
     * @return RDD of the lines stored under {@code <export base dir>/paths/}
     */
    public JavaRDD<String> exportIfRequiredMDS(JavaSparkContext javaSparkContext, JavaRDD<MultiDataSet> javaRDD) {
        ExportSupport.assertExportSupported(javaSparkContext);
        if (this.collectTrainingStats) {
            this.stats.logExportStart();
        }
        String baseDir = prepareForExport(javaSparkContext, javaRDD) ? exportMDS(javaRDD) : getBaseDirForRDD(javaRDD);
        if (this.collectTrainingStats) {
            this.stats.logExportEnd();
        }
        return javaSparkContext.textFile(baseDir + "paths/");
    }

    /**
     * Decides whether {@code javaRDD} needs a fresh export, deleting the previous export's
     * directory when a different RDD was exported before.
     *
     * @return false if this RDD is the one exported last (its export can be reused), true otherwise
     */
    private boolean prepareForExport(JavaSparkContext javaSparkContext, JavaRDD<?> javaRDD) {
        int id = javaRDD.id();
        if (this.lastExportedRDDId == Integer.MIN_VALUE) {
            // Nothing exported yet: nothing to clean up.
            return true;
        }
        if (this.lastExportedRDDId == id) {
            return false;
        }
        // A different RDD was exported earlier; remove it before exporting the new one.
        deleteTempDir(javaSparkContext, this.lastRDDExportPath);
        return true;
    }

    /**
     * Exports the RDD: data is written under {@code <base dir>/data/} by
     * {@link BatchAndExportDataSetsFunction}, and that function's output is saved as text under
     * {@code <base dir>/paths/}. Records the RDD as the last exported one.
     *
     * @param javaRDD data to export
     * @return the base directory of the export (ends with {@code /})
     */
    protected String export(JavaRDD<DataSet> javaRDD) {
        String baseDirForRDD = getBaseDirForRDD(javaRDD);
        String dataDir = baseDirForRDD + "data/";
        log.info("Initiating RDD<DataSet> export at {}", baseDirForRDD);
        javaRDD.mapPartitionsWithIndex(new BatchAndExportDataSetsFunction(this.batchSizePerWorker, dataDir), true).saveAsTextFile(baseDirForRDD + "paths/");
        log.info("RDD<DataSet> export complete at {}", baseDirForRDD);
        recordExport(javaRDD, baseDirForRDD);
        return baseDirForRDD;
    }

    /**
     * {@link MultiDataSet} counterpart of {@link #export(JavaRDD)}, using
     * {@link BatchAndExportMultiDataSetsFunction}.
     *
     * @param javaRDD data to export
     * @return the base directory of the export (ends with {@code /})
     */
    protected String exportMDS(JavaRDD<MultiDataSet> javaRDD) {
        String baseDirForRDD = getBaseDirForRDD(javaRDD);
        String dataDir = baseDirForRDD + "data/";
        log.info("Initiating RDD<MultiDataSet> export at {}", baseDirForRDD);
        javaRDD.mapPartitionsWithIndex(new BatchAndExportMultiDataSetsFunction(this.batchSizePerWorker, dataDir), true).saveAsTextFile(baseDirForRDD + "paths/");
        log.info("RDD<MultiDataSet> export complete at {}", baseDirForRDD);
        recordExport(javaRDD, baseDirForRDD);
        return baseDirForRDD;
    }

    /** Remembers which RDD was exported last and where, so it can be reused or cleaned up later. */
    private void recordExport(JavaRDD<?> javaRDD, String baseDirForRDD) {
        this.lastExportedRDDId = javaRDD.id();
        this.lastRDDExportPath = baseDirForRDD;
    }

    /**
     * Computes the export base directory for an RDD:
     * {@code <exportDirectory>/<trainingMasterUID>/<rdd id>/}. If {@link #exportDirectory} is
     * null it is first set from {@link #getDefaultExportDirectory(SparkContext)}.
     *
     * @param javaRDD RDD whose ID forms the last path component
     * @return the base directory, always ending with {@code /}
     */
    protected String getBaseDirForRDD(JavaRDD<?> javaRDD) {
        if (this.exportDirectory == null) {
            this.exportDirectory = getDefaultExportDirectory(javaRDD.context());
        }
        return this.exportDirectory + (this.exportDirectory.endsWith("/") ? "" : "/") + this.trainingMasterUID + "/" + javaRDD.id() + "/";
    }

    /**
     * Recursively deletes the given directory through the Hadoop {@link FileSystem} resolved from
     * the path's URI and the context's Hadoop configuration.
     *
     * @param javaSparkContext context supplying the Hadoop configuration
     * @param str              path (URI string) of the directory to delete
     * @return true if the delete call completed without an {@link IOException}; false if an
     *         {@link IOException} occurred (it is logged as a warning, not rethrown)
     * @throws RuntimeException if {@code str} is not a syntactically valid URI
     */
    protected boolean deleteTempDir(JavaSparkContext javaSparkContext, String str) {
        log.info("Attempting to delete temporary directory: {}", str);
        try {
            // The FileSystem is intentionally not closed here.
            // NOTE(review): FileSystem.get is understood to return a shared cached instance - confirm.
            FileSystem fileSystem = FileSystem.get(new URI(str), javaSparkContext.hadoopConfiguration());
            fileSystem.delete(new Path(str), true);
            log.info("Deleted temporary directory: {}", str);
            return true;
        } catch (IOException e) {
            // Best effort: failing to clean up must not abort training.
            log.warn("Could not delete temporary directory: {}", str, e);
            return false;
        } catch (URISyntaxException e) {
            throw new RuntimeException("Invalid temporary directory path: " + str, e);
        }
    }

    /**
     * Derives the default export root from the Hadoop {@code hadoop.tmp.dir} property:
     * {@code <hadoop.tmp.dir>/dl4j/}.
     *
     * @param sparkContext context supplying the Hadoop configuration
     * @return the default export directory, ending with {@code dl4j/}
     * @throws IllegalStateException if {@code hadoop.tmp.dir} is not set in the Hadoop configuration
     */
    protected String getDefaultExportDirectory(SparkContext sparkContext) {
        String tmpDir = sparkContext.hadoopConfiguration().get("hadoop.tmp.dir");
        if (tmpDir == null) {
            throw new IllegalStateException(
                    "Cannot determine default export directory: Hadoop property 'hadoop.tmp.dir' is not set");
        }
        if (!tmpDir.endsWith("/") && !tmpDir.endsWith("\\")) {
            tmpDir = tmpDir + "/";
        }
        return tmpDir + "dl4j/";
    }

    /**
     * Deletes the directory of the most recent export, if there is one.
     *
     * @param javaSparkContext context supplying the Hadoop configuration
     * @return true if nothing has been exported or the deletion succeeded; false otherwise
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public boolean deleteTempFiles(JavaSparkContext javaSparkContext) {
        return this.lastRDDExportPath == null || deleteTempDir(javaSparkContext, this.lastRDDExportPath);
    }

    /**
     * Convenience overload that wraps the {@link SparkContext} in a {@link JavaSparkContext} and
     * delegates to {@link #deleteTempFiles(JavaSparkContext)}.
     *
     * @param sparkContext context supplying the Hadoop configuration
     * @return true if nothing has been exported or the deletion succeeded; false otherwise
     */
    @Override // org.deeplearning4j.spark.api.TrainingMaster
    public boolean deleteTempFiles(SparkContext sparkContext) {
        return deleteTempFiles(new JavaSparkContext(sparkContext));
    }
}
