package org.deeplearning4j.nn.layers.feedforward.autoencoder;

import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/feedforward/autoencoder/AutoEncoder.class */
/**
 * Autoencoder pretrain layer with tied weights, optional input corruption (denoising) and an
 * optional sparsity target.
 *
 * <p>Encoding computes {@code f(x * W + b)} and decoding computes {@code f(h * W^T + vb)}, where
 * {@code f} is the activation function from the layer configuration. The decoder reuses the
 * transpose of the encoder weight matrix {@code W}.
 */
public class AutoEncoder extends BasePretrainNetwork<org.deeplearning4j.nn.conf.layers.AutoEncoder> {
    public AutoEncoder(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        super(neuralNetConfiguration, dataType);
    }

    /**
     * Sets {@code visible} as this layer's input and encodes it in training mode.
     *
     * @param visible visible-unit values
     * @return a pair whose two elements are the same encoded array (no stochastic sampling is done)
     */
    @Override // org.deeplearning4j.nn.layers.BasePretrainNetwork
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray visible) {
        setInput(visible, LayerWorkspaceMgr.noWorkspaces());
        INDArray hidden = encode(visible, true, LayerWorkspaceMgr.noWorkspaces());
        return new Pair<>(hidden, hidden);
    }

    /**
     * Decodes {@code hidden} back into visible space.
     *
     * @param hidden hidden-unit values
     * @return a pair whose two elements are the same decoded array (no stochastic sampling is done)
     */
    @Override // org.deeplearning4j.nn.layers.BasePretrainNetwork
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray hidden) {
        INDArray visible = decode(hidden, LayerWorkspaceMgr.noWorkspaces());
        return new Pair<>(visible, visible);
    }

    /**
     * Encodes {@code input} as {@code f(input * W + b)}.
     *
     * @param input        array whose rows are examples (row count taken from {@code size(0)});
     *                     cast to the weight data type before the matrix multiply
     * @param training     forwarded to parameter-noise lookup and to the activation function
     * @param workspaceMgr workspace manager used to allocate and leverage the activations
     * @return the activations, leveraged to the {@link ArrayType#ACTIVATIONS} workspace
     */
    public INDArray encode(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        INDArray weights = getParamWithNoise(PretrainParamInitializer.WEIGHT_KEY, training, workspaceMgr);
        INDArray hiddenBias = getParamWithNoise(PretrainParamInitializer.BIAS_KEY, training, workspaceMgr);

        // Pre-allocate the [examples x hidden] result in the activations workspace and multiply into it.
        INDArray buffer = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(),
                input.size(0), weights.size(1));
        INDArray preOutput = input.castTo(weights.dataType()).mmuli(weights, buffer).addiRowVector(hiddenBias);

        INDArray activations = layerConf().getActivationFn().getActivation(preOutput, training);
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, activations);
    }

    /**
     * Decodes {@code hidden} as {@code f(hidden * W^T + vb)}, reusing the encoder weights.
     *
     * <p>NOTE(review): parameters and activation are always requested in training mode
     * ({@code true}), independent of how the caller is running — kept as-is, confirm intent.
     *
     * @param hidden       encoded representation, one row per example
     * @param workspaceMgr workspace manager used to leverage the result
     * @return the reconstruction, leveraged to the {@link ArrayType#ACTIVATIONS} workspace
     */
    public INDArray decode(INDArray hidden, LayerWorkspaceMgr workspaceMgr) {
        INDArray weights = getParamWithNoise(PretrainParamInitializer.WEIGHT_KEY, true, workspaceMgr);
        INDArray visibleBias = getParamWithNoise(PretrainParamInitializer.VISIBLE_BIAS_KEY, true, workspaceMgr);
        INDArray preOutput = hidden.mmul(weights.transpose()).addiRowVector(visibleBias);
        INDArray reconstruction = layerConf().getActivationFn().getActivation(preOutput, true);
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reconstruction);
    }

    /** Stores {@code input} as this layer's input and returns its encoding. */
    @Override // org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        setInput(input, workspaceMgr);
        return encode(input, training, workspaceMgr);
    }

    @Override // org.deeplearning4j.nn.api.Layer
    public boolean isPretrainLayer() {
        return true;
    }

    /** Encodes the input previously stored on this layer. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return encode(this.input, training, workspaceMgr);
    }

    /**
     * Runs one reconstruction pass and stores the resulting gradient and score on this layer.
     *
     * <p>If the configured corruption level is positive, the input is corrupted before encoding
     * (denoising autoencoder); the reconstruction error is always measured against the clean input.
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.layers.AbstractLayer, org.deeplearning4j.nn.api.Model
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        INDArray weights = getParamWithNoise(PretrainParamInitializer.WEIGHT_KEY, true, workspaceMgr);
        INDArray cleanInput = this.input.castTo(this.dataType);

        // Corruption is disabled for a level of zero (or below); compare against an explicit 0.0
        // rather than an unrelated evaluation constant.
        double corruptionLevel = layerConf().getCorruptionLevel();
        INDArray corruptedInput = corruptionLevel > 0.0
                ? getCorruptedInput(cleanInput, corruptionLevel)
                : cleanInput;
        // NOTE(review): this replaces this.input with the corrupted input, so setScoreWithZ below
        // may score the reconstruction against the corrupted rather than the clean input — confirm.
        setInput(corruptedInput, workspaceMgr);

        INDArray hidden = encode(corruptedInput, true, workspaceMgr);
        INDArray reconstruction = decode(hidden, workspaceMgr);

        // Error at the visible layer: clean target minus reconstruction.
        INDArray visibleError = cleanInput.sub(reconstruction);

        // Error back-propagated to the hidden layer. Without a sparsity target the h * (1 - h)
        // factor is used (the sigmoid derivative form); with one, h * (h - sparsity) is used.
        // NOTE(review): neither factor depends on the configured activation function.
        double sparsity = layerConf().getSparsity();
        INDArray hiddenError = visibleError.mmul(weights).muli(hidden);
        hiddenError.muli(sparsity == 0.0 ? hidden.rsub(1) : hidden.add(-sparsity));

        // W is shared by encoder and decoder, so its gradient has an encoder and a decoder term.
        INDArray weightGradient = corruptedInput.transpose().mmul(hiddenError)
                .addi(visibleError.transpose().mmul(hidden));
        INDArray visibleBiasGradient = visibleError.sum(0);
        INDArray hiddenBiasGradient = hiddenError.sum(0);

        this.gradient = createGradient(weightGradient, visibleBiasGradient, hiddenBiasGradient);
        setScoreWithZ(reconstruction);
    }
}
