package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.AtomicDouble;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/samediff/internal/TrainingSession.class */
/**
 * Session that runs a single training iteration on a {@link SameDiff} instance.
 *
 * <p>It extends {@link InferenceSession}: ops are executed by the superclass, and
 * {@link #getOutputs} post-processes every op's outputs to (a) accumulate the values of the
 * configured loss variables and (b) apply regularization, the gradient updater and the parameter
 * update as soon as the gradient for a trainable variable has been produced.
 *
 * <p>The per-iteration state held in the fields below is (re)initialised by
 * {@link #trainingIteration}; {@link #getOutputs} reads it without null checks.
 */
public class TrainingSession extends InferenceSession {
    private static final Logger log = LoggerFactory.getLogger(TrainingSession.class);

    /** Training configuration of the current iteration (regularization, updater, minimize flag). */
    protected TrainingConfig config;
    /** Key: name of a gradient variable. Value: name of the variable it is the gradient of. */
    protected Map<String, String> gradVarToVarMap;
    /** Gradient updaters, looked up by the name of the variable being trained. */
    protected Map<String, GradientUpdater> updaters;
    /** Key: loss variable name. Value: its index in {@link #currIterLoss}. */
    protected Map<String, Integer> lossVarsToLossIdx;
    /** Loss accumulated during the current iteration, one slot per loss variable. */
    protected double[] currIterLoss;
    /**
     * Regularization score accumulated during the current iteration, keyed by regularization class.
     * Only populated while {@link #listeners} is non-null.
     */
    protected Map<Class<?>, AtomicDouble> currIterRegLoss;
    /** Listeners that reported themselves active for the current operation; null when there are none. */
    protected List<Listener> listeners;

    public TrainingSession(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Performs one training iteration and returns the losses recorded while doing so.
     *
     * @param trainingConfig training configuration; stored and used by {@link #getOutputs}
     * @param map            second argument forwarded unchanged to {@code output(...)}
     *                       (presumably the placeholder arrays - confirm against caller)
     * @param set            names of the variables to train; each must exist in the SameDiff instance
     *                       and be of type {@link VariableType#VARIABLE}
     * @param map2           gradient updaters keyed by trained variable name
     * @param multiDataSet   the batch, forwarded to {@code output(...)} and to the listeners
     * @param list           names of the loss variables; defines the order of the leading loss values
     * @param list2          listeners, may be null
     * @param at             current position in training (iteration, epoch, operation)
     * @return the loss variable values in the order of {@code list}, followed by one entry per
     *         regularization class for which a score was accumulated (named by simple class name)
     */
    public Loss trainingIteration(TrainingConfig trainingConfig, Map<String, INDArray> map, Set<String> set, Map<String, GradientUpdater> map2, MultiDataSet multiDataSet, List<String> list, List<Listener> list2, At at) {
        this.config = trainingConfig;
        this.updaters = map2;

        // Keep only the listeners that are active for this operation; null means "none".
        if (list2 == null) {
            this.listeners = null;
        } else {
            List<Listener> activeListeners = new ArrayList<>();
            for (Listener listener : list2) {
                if (listener.isActive(at.operation())) {
                    activeListeners.add(listener);
                }
            }
            this.listeners = activeListeners.isEmpty() ? null : activeListeners;
        }

        // Collect the gradient variables of everything we train.
        List<String> gradientNames = new ArrayList<>();
        this.gradVarToVarMap = new HashMap<>();
        for (String varName : set) {
            Preconditions.checkState(this.sameDiff.hasVariable(varName), "SameDiff instance does not have a variable with name \"%s\"", varName);
            SDVariable variable = this.sameDiff.getVariable(varName);
            Preconditions.checkState(variable.getVariableType() == VariableType.VARIABLE, "Can only train VARIABLE type variable - \"%s\" has type %s", varName, variable.getVariableType());
            SDVariable gradient = variable.getGradient();
            // Variables without a gradient are skipped rather than treated as an error.
            if (gradient != null) {
                gradientNames.add(gradient.name());
                this.gradVarToVarMap.put(gradient.name(), varName);
            }
        }

        // Reset the per-iteration loss bookkeeping.
        this.lossVarsToLossIdx = new LinkedHashMap<>();
        this.currIterLoss = new double[list.size()];
        this.currIterRegLoss = new HashMap<>();
        for (int i = 0; i < list.size(); i++) {
            this.lossVarsToLossIdx.put(list.get(i), Integer.valueOf(i));
        }

        // Execute the graph; losses are recorded and parameters updated inside getOutputs(...).
        output(new ArrayList<>(this.gradVarToVarMap.keySet()), map, multiDataSet, gradientNames, list2, at);

        // Assemble the result: loss variable values first, regularization scores appended after them.
        double[] allLosses = new double[this.currIterLoss.length + this.currIterRegLoss.size()];
        System.arraycopy(this.currIterLoss, 0, allLosses, 0, this.currIterLoss.length);
        List<String> lossNames;
        if (this.currIterRegLoss.size() > 0) {
            lossNames = new ArrayList<>(list.size() + this.currIterRegLoss.size());
            lossNames.addAll(list);
            // Regularization entries start right after the loss variables and each gets its own slot,
            // so that allLosses[i] always corresponds to lossNames.get(i).
            int regIdx = this.currIterLoss.length;
            for (Map.Entry<Class<?>, AtomicDouble> entry : this.currIterRegLoss.entrySet()) {
                lossNames.add(entry.getKey().getSimpleName());
                allLosses[regIdx++] = entry.getValue().get();
            }
        } else {
            lossNames = list;
        }

        Loss loss = new Loss(lossNames, allLosses);
        if (list2 != null) {
            for (Listener listener : list2) {
                if (listener.isActive(Operation.TRAINING)) {
                    listener.iterationDone(this.sameDiff, at, multiDataSet, loss);
                }
            }
        }
        return loss;
    }

    /**
     * Executes the op via the superclass, then handles each of its outputs:
     * <ul>
     *   <li>an output that is a loss variable has its value (scalar, or the sum of the array)
     *       added to {@link #currIterLoss};</li>
     *   <li>an output that is the gradient of a trained variable triggers, in order:
     *       BEFORE_UPDATER regularization, the gradient updater, POST_UPDATER regularization,
     *       {@code preUpdate} on the active listeners, and finally the in-place update of the
     *       variable's array (subtract when minimizing, add otherwise).</li>
     * </ul>
     * All parameters are passed through unchanged to the superclass implementation; {@code list}
     * (listeners) and {@code at} are additionally used here.
     *
     * @return the arrays returned by the superclass, in the order of the op's outputs
     * @throws IllegalStateException if the guard on the gradient variable's consuming ops trips
     */
    @Override // org.nd4j.autodiff.samediff.internal.InferenceSession
    public INDArray[] getOutputs(SameDiffOp sameDiffOp, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3, List<Listener> list, At at, MultiDataSet multiDataSet, Set<String> set4) {
        INDArray[] opOutputs = super.getOutputs(sameDiffOp, frameIter, set, set2, set3, list, at, multiDataSet, set4);

        int outIdx = 0;
        for (String outName : sameDiffOp.getOutputsOfOp()) {
            // Loss variable: accumulate its value for this iteration.
            if (this.lossVarsToLossIdx.containsKey(outName)) {
                int lossIdx = this.lossVarsToLossIdx.get(outName).intValue();
                INDArray lossArr = opOutputs[outIdx];
                double lossValue = lossArr.isScalar() ? lossArr.getDouble(0L) : lossArr.sumNumber().doubleValue();
                this.currIterLoss[lossIdx] += lossValue;
            }

            // Gradient variable: apply the updater and update the parameter array in place.
            if (this.gradVarToVarMap.containsKey(outName)) {
                String varName = this.gradVarToVarMap.get(outName);
                Variable gradVar = this.sameDiff.getVariables().get(outName);
                // NOTE(review): this fires when the list of consuming ops is non-null but EMPTY,
                // which reads as the opposite of the message below. Kept as-is - confirm intent.
                if (gradVar.getInputsForOp() != null && gradVar.getInputsForOp().isEmpty()) {
                    throw new IllegalStateException("Op depends on gradient variable: " + outName + " for variable " + varName);
                }
                GradientUpdater updater = this.updaters.get(varName);
                Preconditions.checkState(updater != null, "No updater found for variable \"%s\"", varName);
                Variable paramVar = this.sameDiff.getVariables().get(varName);
                INDArray gradArr = opOutputs[outIdx];
                INDArray paramArr = paramVar.getVariable().getArr();
                List<Regularization> regularization = this.config.getRegularization();

                applyRegularization(regularization, Regularization.ApplyStep.BEFORE_UPDATER, paramArr, gradArr, at);
                updater.applyUpdater(gradArr, at.iteration(), at.epoch());
                applyRegularization(regularization, Regularization.ApplyStep.POST_UPDATER, paramArr, gradArr, at);

                if (list != null) {
                    for (Listener listener : list) {
                        if (listener.isActive(at.operation())) {
                            listener.preUpdate(this.sameDiff, at, paramVar, gradArr);
                        }
                    }
                }

                if (this.config.isMinimize()) {
                    paramArr.subi(gradArr);
                } else {
                    paramArr.addi(gradArr);
                }
                log.trace("Applied updater to gradient and updated variable: {}", varName);
            }
            outIdx++;
        }
        return opOutputs;
    }

    /**
     * Applies every configured regularization whose apply step equals {@code step}. While there are
     * active listeners, each regularization's score is added to {@link #currIterRegLoss} (keyed by
     * its class) before it is applied.
     *
     * @param regularization configured regularization; null or empty means nothing to do
     * @param step           the apply step to select
     * @param paramArr       array of the variable being trained
     * @param gradArr        gradient array for that variable
     * @param at             current iteration/epoch, used for the learning rate and the score
     */
    private void applyRegularization(List<Regularization> regularization, Regularization.ApplyStep step, INDArray paramArr, INDArray gradArr, At at) {
        if (regularization == null || regularization.isEmpty()) {
            return;
        }
        // Updaters without a learning rate use 1.0.
        double learningRate = this.config.getUpdater().hasLearningRate() ? this.config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0d;
        for (Regularization reg : regularization) {
            if (reg.applyStep() != step) {
                continue;
            }
            if (this.listeners != null) {
                // Score first: apply(...) may modify the arrays.
                double score = reg.score(paramArr, at.iteration(), at.epoch());
                AtomicDouble total = this.currIterRegLoss.get(reg.getClass());
                if (total == null) {
                    total = new AtomicDouble();
                    this.currIterRegLoss.put(reg.getClass(), total);
                }
                total.addAndGet(score);
            }
            reg.apply(paramArr, gradArr, learningRate, at.iteration(), at.epoch());
        }
    }
}
