package org.deeplearning4j.spark.parameterserver.training;

import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.api.TrainingHook;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.paramavg.BaseTrainingWorker;
import org.deeplearning4j.spark.parameterserver.conf.SharedTrainingConfiguration;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.class */
/**
 * Training worker for shared (parameter-server style) training.
 *
 * <p>The worker holds two Spark broadcasts: the initial model tuple and the shared training
 * configuration. The only operations implemented here are building the initial
 * {@link MultiLayerNetwork} / {@link ComputationGraph} from the broadcast tuple and exposing the
 * two broadcasts; every other operation throws {@link UnsupportedOperationException}.
 *
 * <p>The {@code processMinibatch}, {@code getFinalResult} and {@code getFinalResultNoData} methods
 * carry their real names again. The decompiler had renamed them to {@code mNN...} while merging
 * bridge methods, which left the supertype contract unimplemented. The synthetic {@code mNN...}
 * names are kept as deprecated delegating aliases for source compatibility.
 */
public class SharedTrainingWorker extends BaseTrainingWorker<SharedTrainingResult> implements TrainingWorker<SharedTrainingResult> {
    /** Broadcast carrying the initial network configuration, parameters and updater state. */
    private final Broadcast<NetBroadcastTuple> broadcastModel;
    /** Broadcast carrying the shared training configuration. */
    private final Broadcast<SharedTrainingConfiguration> broadcastConfiguration;

    /**
     * Creates a worker around the two broadcasts; both references are stored as given.
     *
     * @param broadcast  broadcast of the initial model tuple
     * @param broadcast2 broadcast of the shared training configuration
     */
    public SharedTrainingWorker(Broadcast<NetBroadcastTuple> broadcast, Broadcast<SharedTrainingConfiguration> broadcast2) {
        this.broadcastModel = broadcast;
        this.broadcastConfiguration = broadcast2;
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    public void removeHook(TrainingHook trainingHook) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    public void addHook(TrainingHook trainingHook) {
        throw new UnsupportedOperationException();
    }

    /**
     * Builds the initial {@link MultiLayerNetwork} from the broadcast model tuple.
     *
     * @return an initialized network with the broadcast parameters and updater state applied
     *         (each only when present), or {@code null} when the tuple carries no multi-layer
     *         configuration
     */
    public MultiLayerNetwork getInitialModel() {
        // Broadcast<NetBroadcastTuple>#getValue() is already typed; no cast required.
        NetBroadcastTuple tuple = this.broadcastModel.getValue();
        if (tuple.getConfiguration() == null) {
            return null;
        }
        MultiLayerNetwork network = new MultiLayerNetwork(tuple.getConfiguration());
        network.init();
        if (tuple.getParameters() != null) {
            network.setParams(tuple.getParameters());
        }
        if (tuple.getUpdaterState() != null) {
            // Copy the broadcast state into the updater's own state view array.
            network.getUpdater().getStateViewArray().assign(tuple.getUpdaterState());
        }
        return network;
    }

    /**
     * Builds the initial {@link ComputationGraph} from the broadcast model tuple.
     *
     * @return an initialized graph with the broadcast parameters and updater state applied
     *         (each only when present), or {@code null} when the tuple carries no graph
     *         configuration
     */
    public ComputationGraph getInitialModelGraph() {
        // Broadcast<NetBroadcastTuple>#getValue() is already typed; no cast required.
        NetBroadcastTuple tuple = this.broadcastModel.getValue();
        if (tuple.getGraphConfiguration() == null) {
            return null;
        }
        ComputationGraph graph = new ComputationGraph(tuple.getGraphConfiguration());
        graph.init();
        if (tuple.getParameters() != null) {
            graph.setParams(tuple.getParameters());
        }
        if (tuple.getUpdaterState() != null) {
            // Copy the broadcast state into the updater's own state view array.
            graph.getUpdater().getUpdaterStateViewArray().assign(tuple.getUpdaterState());
        }
        return graph;
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SharedTrainingResult processMinibatch(DataSet dataSet, MultiLayerNetwork multiLayerNetwork, boolean z) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SharedTrainingResult processMinibatch(DataSet dataSet, ComputationGraph computationGraph, boolean z) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SharedTrainingResult processMinibatch(MultiDataSet multiDataSet, ComputationGraph computationGraph, boolean z) {
        throw new UnsupportedOperationException();
    }

    /**
     * @deprecated decompiler-generated alias; use
     *             {@link #processMinibatch(DataSet, MultiLayerNetwork, boolean)}.
     */
    @Deprecated
    public SharedTrainingResult m16processMinibatch(DataSet dataSet, MultiLayerNetwork multiLayerNetwork, boolean z) {
        return processMinibatch(dataSet, multiLayerNetwork, z);
    }

    /**
     * @deprecated decompiler-generated alias; use
     *             {@link #processMinibatch(DataSet, ComputationGraph, boolean)}.
     */
    @Deprecated
    public SharedTrainingResult m15processMinibatch(DataSet dataSet, ComputationGraph computationGraph, boolean z) {
        return processMinibatch(dataSet, computationGraph, z);
    }

    /**
     * @deprecated decompiler-generated alias; use
     *             {@link #processMinibatch(MultiDataSet, ComputationGraph, boolean)}.
     */
    @Deprecated
    public SharedTrainingResult m14processMinibatch(MultiDataSet multiDataSet, ComputationGraph computationGraph, boolean z) {
        return processMinibatch(multiDataSet, computationGraph, z);
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork multiLayerNetwork, boolean z) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph computationGraph, boolean z) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> processMinibatchWithStats(MultiDataSet multiDataSet, ComputationGraph computationGraph, boolean z) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SharedTrainingResult getFinalResult(MultiLayerNetwork multiLayerNetwork) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SharedTrainingResult getFinalResult(ComputationGraph computationGraph) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public SharedTrainingResult getFinalResultNoData() {
        throw new UnsupportedOperationException();
    }

    /**
     * @deprecated decompiler-generated alias; use {@link #getFinalResult(MultiLayerNetwork)}.
     */
    @Deprecated
    public SharedTrainingResult m13getFinalResult(MultiLayerNetwork multiLayerNetwork) {
        return getFinalResult(multiLayerNetwork);
    }

    /**
     * @deprecated decompiler-generated alias; use {@link #getFinalResult(ComputationGraph)}.
     */
    @Deprecated
    public SharedTrainingResult m12getFinalResult(ComputationGraph computationGraph) {
        return getFinalResult(computationGraph);
    }

    /**
     * @deprecated decompiler-generated alias; use {@link #getFinalResultNoData()}.
     */
    @Deprecated
    public SharedTrainingResult m11getFinalResultNoData() {
        return getFinalResultNoData();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> getFinalResultNoDataWithStats() {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork multiLayerNetwork) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> getFinalResultWithStats(ComputationGraph computationGraph) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this worker.
     *
     * @throws UnsupportedOperationException always
     */
    public WorkerConfiguration getDataConfiguration() {
        throw new UnsupportedOperationException();
    }

    /** @return the broadcast holding the initial model tuple, as passed to the constructor */
    public Broadcast<NetBroadcastTuple> getBroadcastModel() {
        return this.broadcastModel;
    }

    /** @return the broadcast holding the shared training configuration, as passed to the constructor */
    public Broadcast<SharedTrainingConfiguration> getBroadcastConfiguration() {
        return this.broadcastConfiguration;
    }
}
