package org.deeplearning4j.optimize.solvers;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction;
import org.deeplearning4j.optimize.stepfunctions.NegativeGradientStepFunction;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/BaseOptimizer.class */
/**
 * Common base for {@link ConvexOptimizer} implementations.
 *
 * <p>Holds the network configuration, the {@link StepFunction}, the training listeners, the model
 * and a {@link BackTrackLineSearch}. {@link #optimize(LayerWorkspaceMgr)} performs one iteration:
 * compute gradient/score, run the line search, apply the step and notify listeners. Subclasses
 * customize it through the hooks {@link #setupSearchState}, {@link #preProcessLine} and
 * {@link #postStep}.
 *
 * <p>Line-search state lives in {@link #searchState}, keyed by {@link #GRADIENT_KEY},
 * {@link #SCORE_KEY}, {@link #PARAMS_KEY} and {@link #SEARCH_DIR}.
 */
public abstract class BaseOptimizer implements ConvexOptimizer {
    /** {@link #searchState} key for the current gradient. */
    public static final String GRADIENT_KEY = "g";
    /** {@link #searchState} key for the score recorded by {@link #setupSearchState}. */
    public static final String SCORE_KEY = "score";
    /** {@link #searchState} key for the working copy of the model parameters. */
    public static final String PARAMS_KEY = "params";
    /**
     * {@link #searchState} key for the search direction read by {@link #optimize}.
     * This class never writes it.
     * NOTE(review): presumably subclasses populate it (e.g. in {@link #preProcessLine}); confirm per subclass.
     */
    public static final String SEARCH_DIR = "searchDirection";

    protected static final Logger log = LoggerFactory.getLogger(BaseOptimizer.class);

    protected NeuralNetConfiguration conf;
    protected StepFunction stepFunction;
    protected Collection<TrainingListener> trainingListeners;
    protected Model model;
    protected BackTrackLineSearch lineMaximizer;
    /** Updater for non-graph models; created lazily. */
    protected Updater updater;
    /** Updater for {@link ComputationGraph} models; created lazily. */
    protected ComputationGraphUpdater computationGraphUpdater;
    /** Step size from the latest line search; 0.0 when the line search threw {@link InvalidStepException}. */
    protected double step;
    /**
     * Value of {@link #setBatchSize}. Note that {@link #gradientAndScore} passes
     * {@code model.batchSize()} to the updater, not this field.
     */
    private int batchSize;
    /** Score from the most recent {@link #gradientAndScore} call. */
    protected double score;
    /** Score from the {@link #gradientAndScore} call before the most recent one. */
    protected double oldScore;
    protected GradientsAccumulator accumulator;
    /** Maximum step handed to the line search at construction time. */
    protected double stepMax = Double.MAX_VALUE;
    /** Line-search state shared with subclasses; see the {@code *_KEY} constants and {@link #SEARCH_DIR}. */
    protected Map<String, Object> searchState = new ConcurrentHashMap<>();

    /**
     * Creates an optimizer for {@code model}.
     *
     * @param conf              network configuration. Its {@code getMaxNumLineSearchIterations()}
     *                          bounds the line search; it is dereferenced here, so it must not be null.
     * @param stepFunction      step function to apply. If null, the default for this optimizer's
     *                          runtime class is used ({@link #getDefaultStepFunctionForOptimizer}).
     * @param trainingListeners listeners to notify. Null is replaced by an empty mutable list.
     * @param model             model whose parameters are optimized
     */
    public BaseOptimizer(NeuralNetConfiguration conf, StepFunction stepFunction,
                    Collection<TrainingListener> trainingListeners, Model model) {
        this.conf = conf;
        this.stepFunction = stepFunction != null ? stepFunction : getDefaultStepFunctionForOptimizer(getClass());
        this.trainingListeners = trainingListeners != null ? trainingListeners
                        : new ArrayList<TrainingListener>();
        this.model = model;
        // 'this' is handed to the line search before any subclass constructor body has run.
        this.lineMaximizer = new BackTrackLineSearch(model, this.stepFunction, this);
        this.lineMaximizer.setStepMax(this.stepMax);
        this.lineMaximizer.setMaxIterations(conf.getMaxNumLineSearchIterations());
    }

    @Override
    public void setGradientsAccumulator(GradientsAccumulator accumulator) {
        this.accumulator = accumulator;
    }

    @Override
    public GradientsAccumulator getGradientsAccumulator() {
        return accumulator;
    }

    /**
     * Not supported by this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double score() {
        throw new UnsupportedOperationException("Not yet reimplemented");
    }

    /**
     * Returns the updater for non-graph models, creating it from {@link #model} on first use.
     *
     * <p>NOTE(review): unlike {@link #updateGradientAccordingToParams}, this path does not create
     * the updater outside workspaces. Confirm whether that difference is intentional.
     */
    @Override
    public Updater getUpdater() {
        if (updater == null) {
            updater = UpdaterCreator.getUpdater(model);
        }
        return updater;
    }

    @Override
    public void setUpdater(Updater updater) {
        this.updater = updater;
    }

    /**
     * Returns the graph updater. It is created on first use when {@link #model} is a
     * {@link ComputationGraph}; otherwise this returns whatever was set, possibly null.
     */
    @Override
    public ComputationGraphUpdater getComputationGraphUpdater() {
        if (computationGraphUpdater == null && model instanceof ComputationGraph) {
            computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph) model);
        }
        return computationGraphUpdater;
    }

    @Override
    public void setUpdaterComputationGraph(ComputationGraphUpdater updater) {
        this.computationGraphUpdater = updater;
    }

    /**
     * Replaces the listener collection.
     *
     * @param listeners new listeners. Null installs an immutable empty list.
     */
    @Override
    public void setListeners(Collection<TrainingListener> listeners) {
        trainingListeners = listeners == null
                        ? Collections.<TrainingListener>emptyList() : listeners;
    }

    @Override
    public NeuralNetConfiguration getConf() {
        return conf;
    }

    /**
     * Computes gradient and score on the model, then applies the updater to the gradient.
     *
     * <p>The previous {@link #score} is saved to {@link #oldScore}. Listeners receive
     * {@code onGradientCalculation} outside any active workspace. The gradient is then passed
     * through {@link #updateGradientAccordingToParams} using {@code model.batchSize()}.
     *
     * @param workspaceMgr workspace manager forwarded to the model and the updater
     * @return the model's gradient/score pair (the gradient has been processed by the updater)
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        oldScore = score;
        model.computeGradientAndScore(workspaceMgr);

        if (trainingListeners != null && !trainingListeners.isEmpty()) {
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                for (TrainingListener listener : trainingListeners) {
                    listener.onGradientCalculation(model);
                }
            }
        }

        Pair<Gradient, Double> pair = model.gradientAndScore();
        score = pair.getSecond();
        updateGradientAccordingToParams(pair.getFirst(), model, model.batchSize(), workspaceMgr);
        return pair;
    }

    /**
     * Runs one optimization iteration.
     *
     * <ol>
     *   <li>Compute gradient and score. On the first call (empty {@link #searchState}) also call
     *       {@link #setupSearchState}.</li>
     *   <li>Call {@link #preProcessLine}.</li>
     *   <li>Run the line search over the stored parameters, gradient and search direction.</li>
     *   <li>If the step is non-zero, apply it with the step function and write the parameters back
     *       to the model.</li>
     *   <li>Recompute gradient and score, then call {@link #postStep}.</li>
     *   <li>Notify listeners via {@code iterationDone}, increment the iteration count and apply
     *       constraints.</li>
     * </ol>
     *
     * <p>Hook calls and listener notifications run outside any active workspace. Every scope is
     * closed exactly once, including when a hook throws.
     *
     * @param workspaceMgr workspace manager forwarded to gradient computation and the line search
     * @return always {@code true}
     */
    @Override
    public boolean optimize(LayerWorkspaceMgr workspaceMgr) {
        Pair<Gradient, Double> pair = gradientAndScore(workspaceMgr);
        searchState.put(GRADIENT_KEY, pair.getFirst().gradient());
        if (searchState.size() == 1) {
            // Only GRADIENT_KEY is present, i.e. the state was empty before this call: first iteration.
            try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                setupSearchState(pair);
            }
        }

        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            preProcessLine();
        }

        INDArray gradient = (INDArray) searchState.get(GRADIENT_KEY);
        INDArray searchDirection = (INDArray) searchState.get(SEARCH_DIR);
        INDArray parameters = (INDArray) searchState.get(PARAMS_KEY);

        try {
            step = lineMaximizer.optimize(parameters, gradient, searchDirection, workspaceMgr);
        } catch (InvalidStepException e) {
            log.warn("Invalid step...continuing another iteration: {}", e.getMessage());
            step = 0.0;
        }

        if (step != 0.0) {
            stepFunction.step(parameters, searchDirection, step);
            model.setParams(parameters);
        } else {
            log.debug("Step size returned by line search is 0.0.");
        }

        pair = gradientAndScore(workspaceMgr);
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            postStep(pair.getFirst().gradient());
        }

        int iterationCount = getIterationCount(model);
        int epochCount = getEpochCount(model);
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            for (TrainingListener listener : trainingListeners) {
                listener.iterationDone(model, iterationCount, epochCount);
            }
        }

        incrementIterationCount(model, 1);
        applyConstraints(model);
        return true;
    }

    /** No-op hook for subclasses; this class never calls it. */
    protected void postFirstStep(INDArray gradient) {
    }

    @Override
    public int batchSize() {
        return batchSize;
    }

    @Override
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /** Hook called by {@link #optimize} before the line search. No-op by default. */
    @Override
    public void preProcessLine() {
    }

    /**
     * Hook called by {@link #optimize} after the step, with the freshly computed gradient.
     * No-op by default.
     */
    @Override
    public void postStep(INDArray gradient) {
    }

    /**
     * Applies the model's updater to {@code gradient}.
     *
     * <p>For a {@link ComputationGraph}, uses (and lazily creates) the graph updater. For any other
     * model, uses (and lazily creates) the regular updater; the model is cast to {@link Layer}, so it
     * is assumed to be one. Lazily created updaters are built outside any active workspace.
     *
     * @param gradient     gradient to update in place via the updater
     * @param model        model the gradient belongs to
     * @param batchSize    minibatch size forwarded to the updater
     * @param workspaceMgr workspace manager forwarded to the updater
     */
    @Override
    public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize,
                    LayerWorkspaceMgr workspaceMgr) {
        int iterationCount = getIterationCount(model);
        int epochCount = getEpochCount(model);
        if (model instanceof ComputationGraph) {
            if (computationGraphUpdater == null) {
                try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    computationGraphUpdater = new ComputationGraphUpdater((ComputationGraph) model);
                }
            }
            computationGraphUpdater.update(gradient, iterationCount, epochCount, batchSize, workspaceMgr);
        } else {
            if (updater == null) {
                try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    updater = UpdaterCreator.getUpdater(model);
                }
            }
            updater.update((Layer) model, gradient, iterationCount, epochCount, batchSize, workspaceMgr);
        }
    }

    /**
     * Initializes {@link #searchState} for the first iteration.
     *
     * <p>Stores the gradient restricted to {@code conf.variables()} under {@link #GRADIENT_KEY},
     * the score under {@link #SCORE_KEY}, and a duplicate of the model parameters under
     * {@link #PARAMS_KEY}.
     *
     * @param pair gradient/score pair from {@link #gradientAndScore}
     */
    @Override
    public void setupSearchState(Pair<Gradient, Double> pair) {
        INDArray gradient = pair.getFirst().gradient(conf.variables());
        INDArray params = model.params().dup();
        searchState.put(GRADIENT_KEY, gradient);
        searchState.put(SCORE_KEY, pair.getSecond());
        searchState.put(PARAMS_KEY, params);
    }

    /**
     * Returns the default step function for an optimizer class.
     *
     * @param optimizerClass optimizer class to look up
     * @return a {@link NegativeGradientStepFunction} for {@link StochasticGradientDescent}, otherwise
     *         a {@link NegativeDefaultStepFunction}
     */
    public static StepFunction getDefaultStepFunctionForOptimizer(Class<? extends ConvexOptimizer> optimizerClass) {
        if (optimizerClass == StochasticGradientDescent.class) {
            return new NegativeGradientStepFunction();
        }
        return new NegativeDefaultStepFunction();
    }

    /**
     * Reads the iteration count from the configuration that owns it: the
     * {@link MultiLayerConfiguration} for a {@link MultiLayerNetwork}, the
     * {@link ComputationGraphConfiguration} for a {@link ComputationGraph}, and
     * {@code model.conf()} otherwise.
     */
    public static int getIterationCount(Model model) {
        if (model instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getIterationCount();
        }
        if (model instanceof ComputationGraph) {
            return ((ComputationGraph) model).getConfiguration().getIterationCount();
        }
        return model.conf().getIterationCount();
    }

    /**
     * Adds {@code increment} to the iteration count stored in the same configuration that
     * {@link #getIterationCount} reads.
     */
    public static void incrementIterationCount(Model model, int increment) {
        if (model instanceof MultiLayerNetwork) {
            MultiLayerConfiguration mlc = ((MultiLayerNetwork) model).getLayerWiseConfigurations();
            mlc.setIterationCount(mlc.getIterationCount() + increment);
        } else if (model instanceof ComputationGraph) {
            ComputationGraphConfiguration cgc = ((ComputationGraph) model).getConfiguration();
            cgc.setIterationCount(cgc.getIterationCount() + increment);
        } else {
            model.conf().setIterationCount(model.conf().getIterationCount() + increment);
        }
    }

    /** Reads the epoch count, using the same configuration lookup as {@link #getIterationCount}. */
    public static int getEpochCount(Model model) {
        if (model instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount();
        }
        if (model instanceof ComputationGraph) {
            return ((ComputationGraph) model).getConfiguration().getEpochCount();
        }
        return model.conf().getEpochCount();
    }

    /** Calls {@code model.applyConstraints} with the model's current iteration and epoch counts. */
    public static void applyConstraints(Model model) {
        model.applyConstraints(getIterationCount(model), getEpochCount(model));
    }

    @Override
    public StepFunction getStepFunction() {
        return stepFunction;
    }
}
