package org.deeplearning4j.earlystopping.scorecalc;

import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/* loaded from: input_file:org/deeplearning4j/earlystopping/scorecalc/VAEReconProbScoreCalculator.class */
/**
 * Early-stopping score calculator that scores a network by the reconstruction probability
 * (or reconstruction log probability) of a {@link VariationalAutoencoder} that must be the
 * network's layer at index 0.
 *
 * <p>Supported model types are {@link MultiLayerNetwork} and {@link ComputationGraph}. Per-minibatch
 * scores are summed over all examples in the minibatch; the final score is optionally divided by
 * the example count (see {@link #finalScore(double, int, int)}).
 */
public class VAEReconProbScoreCalculator extends BaseScoreCalculator<Model> {
    /** Number of samples passed to the VAE reconstruction (log) probability methods. */
    protected final int reconstructionProbNumSamples;
    /** If true, score with the reconstruction log probability; otherwise with the reconstruction probability. */
    protected final boolean logProb;
    /** If true, the final score is the accumulated sum divided by the example count. */
    protected final boolean average;

    /**
     * Creates a calculator that averages the score over examples.
     *
     * @param dataSetIterator              iterator handed to the base class constructor
     * @param reconstructionProbNumSamples number of samples passed to the VAE probability methods
     * @param logProb                      true to use the log probability, false for the probability
     */
    public VAEReconProbScoreCalculator(DataSetIterator dataSetIterator, int reconstructionProbNumSamples, boolean logProb) {
        this(dataSetIterator, reconstructionProbNumSamples, logProb, true);
    }

    /**
     * Creates a calculator.
     *
     * @param dataSetIterator              iterator handed to the base class constructor
     * @param reconstructionProbNumSamples number of samples passed to the VAE probability methods
     * @param logProb                      true to use the log probability, false for the probability
     * @param average                      true to divide the summed score by the example count,
     *                                     false to report the raw sum
     */
    public VAEReconProbScoreCalculator(DataSetIterator dataSetIterator, int reconstructionProbNumSamples, boolean logProb, boolean average) {
        super(dataSetIterator);
        this.reconstructionProbNumSamples = reconstructionProbNumSamples;
        this.logProb = logProb;
        this.average = average;
    }

    /** Clears the accumulated score sum and the minibatch/example counters inherited from the base class. */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    protected void reset() {
        // Plain zero: the accumulator has no relationship to any evaluation-class constant.
        this.scoreSum = 0.0;
        this.minibatchCount = 0;
        this.exampleCount = 0;
    }

    /**
     * Always returns {@code null}: the single-array {@code scoreMinibatch} in this class never reads
     * its last (output) argument, so no network output is produced here.
     */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    protected INDArray output(Model model, INDArray input, INDArray fMask, INDArray lMask) {
        return null;
    }

    /** Always returns {@code null}; see the single-array {@code output} overload. */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    protected INDArray[] output(Model model, INDArray[] input, INDArray[] fMask, INDArray[] lMask) {
        return null;
    }

    /**
     * Scores one minibatch as the sum, over all entries returned by the VAE, of the reconstruction
     * probability (or the negated reconstruction log probability when {@code logProb} is set).
     *
     * <p>Only {@code features} is read. The remaining array parameters are named after their
     * presumed role in the base-class contract (labels, masks, network output) - TODO confirm
     * against {@code BaseScoreCalculator}; they are ignored here.
     *
     * @param model    a {@link MultiLayerNetwork} or {@link ComputationGraph} whose layer 0 is a
     *                 {@link VariationalAutoencoder}
     * @param features array passed to the VAE reconstruction (log) probability method
     * @return the summed minibatch score
     * @throws UnsupportedOperationException if the model is not one of the supported types, or its
     *                                       layer 0 is not a {@link VariationalAutoencoder}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    public double scoreMinibatch(Model model, INDArray features, INDArray labels, INDArray fMask, INDArray lMask, INDArray output) {
        final Layer firstLayer;
        if (model instanceof MultiLayerNetwork) {
            firstLayer = ((MultiLayerNetwork) model).getLayer(0);
        } else if (model instanceof ComputationGraph) {
            firstLayer = ((ComputationGraph) model).getLayer(0);
        } else {
            // Fail with a descriptive error instead of a ClassCastException / NullPointerException.
            throw new UnsupportedOperationException("Can only score MultiLayerNetwork or ComputationGraph models - got "
                    + (model == null ? "null" : model.getClass().getSimpleName()));
        }

        if (!(firstLayer instanceof VariationalAutoencoder)) {
            throw new UnsupportedOperationException("Can only score networks with VariationalAutoencoder layers as first layer - got "
                    + (firstLayer == null ? "null" : firstLayer.getClass().getSimpleName()));
        }
        VariationalAutoencoder vae = (VariationalAutoencoder) firstLayer;

        if (this.logProb) {
            // NOTE(review): the log probability is negated here while minimizeScore() returns false.
            // Those two look inconsistent (maximizing a negated log probability); behavior is kept
            // as-is - confirm the intended sign convention before changing either.
            return -vae.reconstructionLogProbability(features, this.reconstructionProbNumSamples).sumNumber().doubleValue();
        }
        return vae.reconstructionProbability(features, this.reconstructionProbNumSamples).sumNumber().doubleValue();
    }

    /**
     * Multi-array variant: performs no scoring and always returns {@code 0.0}.
     */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    protected double scoreMinibatch(Model model, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) {
        return 0.0;
    }

    /**
     * Converts the accumulated sum into the reported score.
     *
     * @param scoreSum       accumulated score
     * @param minibatchCount unused by this implementation
     * @param exampleCount   divisor used when averaging is enabled
     * @return {@code scoreSum / exampleCount} when {@code average} is set, otherwise {@code scoreSum}
     */
    @Override // org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator
    protected double finalScore(double scoreSum, int minibatchCount, int exampleCount) {
        return this.average ? scoreSum / exampleCount : scoreSum;
    }

    /**
     * @return always {@code false}
     */
    @Override // org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator
    public boolean minimizeScore() {
        return false;
    }
}
