package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/ConjugateGradient.class */
/**
 * Nonlinear conjugate gradient optimizer using the Polak-Ribiere update with the coefficient
 * clamped at zero.
 *
 * <p>After each step the search direction is updated as
 * <pre>
 *   gamma     = max(dot(g_t - g_{t-1}, g_t) / dot(g_{t-1}, g_{t-1}), 0)
 *   searchDir = g_t + gamma * searchDir_{t-1}
 * </pre>
 * When {@code gamma} is 0 the new search direction is simply the current gradient (a restart).
 * The previous gradient and search direction are kept in {@code searchState} under
 * {@link BaseOptimizer#GRADIENT_KEY} and {@link BaseOptimizer#SEARCH_DIR}.
 */
public class ConjugateGradient extends BaseOptimizer {
    private static final long serialVersionUID = -1269296013474864091L;
    private static final Logger logger = LoggerFactory.getLogger(ConjugateGradient.class);

    /**
     * Creates a conjugate gradient optimizer. All arguments are forwarded unchanged to
     * {@link BaseOptimizer}.
     *
     * @param conf              network configuration
     * @param stepFunction      step function used to apply updates
     * @param trainingListeners listeners to notify during training
     * @param model             model being optimized
     */
    public ConjugateGradient(NeuralNetConfiguration conf, StepFunction stepFunction,
                    Collection<TrainingListener> trainingListeners, Model model) {
        super(conf, stepFunction, trainingListeners, model);
    }

    /**
     * Sets the search direction to the current gradient held in {@code searchState}.
     * If no search direction exists yet, the gradient is stored as the search direction.
     * Otherwise the existing search-direction array is overwritten in place.
     */
    @Override
    public void preProcessLine() {
        INDArray gradient = (INDArray) this.searchState.get(BaseOptimizer.GRADIENT_KEY);
        INDArray searchDir = (INDArray) this.searchState.get(BaseOptimizer.SEARCH_DIR);
        if (searchDir == null) {
            // NOTE(review): this stores the gradient array itself (no copy), so SEARCH_DIR and
            // GRADIENT_KEY share storage until postStep replaces them; postStep later modifies
            // the SEARCH_DIR array in place. Confirm no other holder of that array relies on it.
            this.searchState.put(BaseOptimizer.SEARCH_DIR, gradient);
        } else {
            searchDir.assign(gradient);
        }
    }

    /**
     * Computes the next conjugate search direction from the new gradient, then stores the new
     * gradient and direction in {@code searchState} for the next iteration.
     *
     * <p>The previous search-direction array is updated in place and re-stored as the new
     * direction.
     *
     * @param gradient gradient at the current parameters
     */
    @Override
    public void postStep(INDArray gradient) {
        INDArray gLast = (INDArray) this.searchState.get(BaseOptimizer.GRADIENT_KEY);
        INDArray searchDirLast = (INDArray) this.searchState.get(BaseOptimizer.SEARCH_DIR);

        // Polak-Ribiere: gamma = dot(g_t - g_{t-1}, g_t) / dot(g_{t-1}, g_{t-1})
        double dgg = Nd4j.getBlasWrapper().dot(gradient.sub(gLast), gradient);
        double gg = Nd4j.getBlasWrapper().dot(gLast, gLast);
        double ratio = dgg / gg;

        double gamma;
        if (Double.isNaN(ratio) || Double.isInfinite(ratio)) {
            // gg == 0 (or overflow) makes the ratio NaN/Inf; Math.max would propagate NaN into the
            // whole search direction, so restart from the plain gradient instead.
            logger.debug("Polak-Ribiere gamma is not finite; using gamma=0.0 -> SGD line search. dgg={}, gg={}", dgg, gg);
            gamma = 0.0;
        } else {
            // Clamping at zero (PR+) discards the previous direction when the raw coefficient
            // is negative.
            gamma = Math.max(ratio, 0.0);
            if (dgg <= 0.0) {
                logger.debug("Polak-Ribiere gamma <= 0.0; using gamma=0.0 -> SGD line search. dgg={}, gg={}", dgg, gg);
            }
        }

        INDArray searchDir;
        if (gamma == 0.0) {
            // Restart: copy the gradient directly so stale non-finite values in the previous
            // direction cannot leak through 0 * Inf = NaN.
            searchDir = searchDirLast.assign(gradient);
        } else {
            searchDir = searchDirLast.muli(gamma).addi(gradient);
        }

        this.searchState.put(BaseOptimizer.GRADIENT_KEY, gradient);
        this.searchState.put(BaseOptimizer.SEARCH_DIR, searchDir);
    }
}
