package org.deeplearning4j.nn.conf.layers.variational;

import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.LayerValidation;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.VariationalAutoencoderParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Configuration for a variational autoencoder (VAE) layer.
 *
 * <p>Holds the encoder and decoder hidden-layer sizes, the reconstruction (output) distribution,
 * the activation function applied to p(z|x), and the number of samples used per data point.
 * Instances are normally created through {@link Builder}. The public no-argument constructor and
 * the setters presumably exist for reflective (de)serialization — confirm before removing them.
 */
public class VariationalAutoencoder extends BasePretrainNetwork {
    private int[] encoderLayerSizes;
    private int[] decoderLayerSizes;
    private ReconstructionDistribution outputDistribution;
    private IActivation pzxActivationFn;
    private int numSamples;

    /**
     * Builder for {@link VariationalAutoencoder}.
     *
     * <p>Defaults: encoder layer sizes {@code {100}}, decoder layer sizes {@code {100}}, a
     * {@link GaussianReconstructionDistribution} with {@link Activation#TANH}, an identity p(z|x)
     * activation, and one sample per data point.
     */
    public static class Builder extends BasePretrainNetwork.Builder<Builder> {
        private int[] encoderLayerSizes = {100};
        private int[] decoderLayerSizes = {100};
        private ReconstructionDistribution outputDistribution = new GaussianReconstructionDistribution(Activation.TANH);
        private IActivation pzxActivationFn = new ActivationIdentity();
        private int numSamples = 1;

        /**
         * Sets the sizes of the encoder layers.
         *
         * @param sizes one entry per encoder layer; must be non-null and non-empty
         * @return this builder
         * @throws IllegalArgumentException if {@code sizes} is null or empty
         */
        public Builder encoderLayerSizes(int... sizes) {
            if (sizes == null || sizes.length < 1) {
                throw new IllegalArgumentException("Encoder layer sizes array must have length > 0");
            }
            // Defensive copy: a caller mutating its (varargs) array afterwards must not alter this builder.
            this.encoderLayerSizes = sizes.clone();
            return this;
        }

        /**
         * Setter form of {@link #encoderLayerSizes(int...)}; applies the same validation.
         *
         * @param sizes one entry per encoder layer; must be non-null and non-empty
         */
        public void setEncoderLayerSizes(int[] sizes) {
            encoderLayerSizes(sizes);
        }

        /**
         * Sets the sizes of the decoder layers.
         *
         * @param sizes one entry per decoder layer; must be non-null and non-empty
         * @return this builder
         * @throws IllegalArgumentException if {@code sizes} is null or empty
         */
        public Builder decoderLayerSizes(int... sizes) {
            if (sizes == null || sizes.length < 1) {
                throw new IllegalArgumentException("Decoder layer sizes array must have length > 0");
            }
            // Defensive copy, as for the encoder sizes.
            this.decoderLayerSizes = sizes.clone();
            return this;
        }

        /**
         * Setter form of {@link #decoderLayerSizes(int...)}; applies the same validation.
         *
         * @param sizes one entry per decoder layer; must be non-null and non-empty
         */
        public void setDecoderLayerSizes(int[] sizes) {
            decoderLayerSizes(sizes);
        }

        /**
         * Sets the reconstruction distribution used for the layer's output.
         *
         * @param distribution the reconstruction distribution
         * @return this builder
         */
        public Builder reconstructionDistribution(ReconstructionDistribution distribution) {
            this.outputDistribution = distribution;
            return this;
        }

        /**
         * Uses a loss function instead of a reconstruction distribution. The pair is wrapped in a
         * {@link LossFunctionWrapper}.
         *
         * @param activation   activation function for the output
         * @param lossFunction loss function enum value
         * @return this builder
         */
        public Builder lossFunction(IActivation activation, LossFunctions.LossFunction lossFunction) {
            return lossFunction(activation, lossFunction.getILossFunction());
        }

        /**
         * Uses a loss function instead of a reconstruction distribution. The pair is wrapped in a
         * {@link LossFunctionWrapper}.
         *
         * @param activation   activation enum value for the output
         * @param lossFunction loss function enum value
         * @return this builder
         */
        public Builder lossFunction(Activation activation, LossFunctions.LossFunction lossFunction) {
            return lossFunction(activation.getActivationFunction(), lossFunction.getILossFunction());
        }

        /**
         * Uses a loss function instead of a reconstruction distribution. The pair is wrapped in a
         * {@link LossFunctionWrapper}.
         *
         * @param activation   activation function for the output
         * @param lossFunction loss function instance
         * @return this builder
         */
        public Builder lossFunction(IActivation activation, ILossFunction lossFunction) {
            return reconstructionDistribution(new LossFunctionWrapper(activation, lossFunction));
        }

        /**
         * Sets the activation function applied to p(z|x).
         *
         * @param activation the activation function
         * @return this builder
         */
        public Builder pzxActivationFn(IActivation activation) {
            this.pzxActivationFn = activation;
            return this;
        }

        /**
         * Enum convenience form of {@link #pzxActivationFn(IActivation)}.
         *
         * @param activation the activation enum value
         * @return this builder
         */
        public Builder pzxActivationFunction(Activation activation) {
            return pzxActivationFn(activation.getActivationFunction());
        }

        /** Overridden only to narrow the return type to this builder. */
        @Override // org.deeplearning4j.nn.conf.layers.FeedForwardLayer.Builder
        public Builder nOut(int nOut) {
            super.nOut(nOut);
            return this;
        }

        /**
         * Sets the number of samples used per data point.
         *
         * @param numSamples number of samples; must be at least 1
         * @return this builder
         * @throws IllegalArgumentException if {@code numSamples < 1}
         */
        public Builder numSamples(int numSamples) {
            if (numSamples < 1) {
                throw new IllegalArgumentException("Number of samples must be > 0, got " + numSamples);
            }
            this.numSamples = numSamples;
            return this;
        }

        /** {@inheritDoc} */
        @Override // org.deeplearning4j.nn.conf.layers.Layer.Builder
        public VariationalAutoencoder build() {
            return new VariationalAutoencoder(this);
        }

        /** @return the configured encoder layer sizes (internal array, not a copy) */
        public int[] getEncoderLayerSizes() {
            return this.encoderLayerSizes;
        }

        /** @return the configured decoder layer sizes (internal array, not a copy) */
        public int[] getDecoderLayerSizes() {
            return this.decoderLayerSizes;
        }

        /** @return the configured reconstruction distribution */
        public ReconstructionDistribution getOutputDistribution() {
            return this.outputDistribution;
        }

        /** @return the configured p(z|x) activation function */
        public IActivation getPzxActivationFn() {
            return this.pzxActivationFn;
        }

        /** @return the configured number of samples */
        public int getNumSamples() {
            return this.numSamples;
        }

        /** @param distribution the reconstruction distribution to use */
        public void setOutputDistribution(ReconstructionDistribution distribution) {
            this.outputDistribution = distribution;
        }

        /** @param activation the p(z|x) activation function to use */
        public void setPzxActivationFn(IActivation activation) {
            this.pzxActivationFn = activation;
        }

        /**
         * Setter form of {@link #numSamples(int)}; applies the same validation.
         *
         * @param numSamples number of samples; must be at least 1
         */
        public void setNumSamples(int numSamples) {
            numSamples(numSamples);
        }
    }

    private VariationalAutoencoder(Builder builder) {
        super(builder);
        // Copy the arrays so that layers built from the same builder do not share them.
        this.encoderLayerSizes = builder.encoderLayerSizes.clone();
        this.decoderLayerSizes = builder.decoderLayerSizes.clone();
        this.outputDistribution = builder.outputDistribution;
        this.pzxActivationFn = builder.pzxActivationFn;
        this.numSamples = builder.numSamples;
    }

    /**
     * Creates the runtime VAE layer for this configuration.
     *
     * @param conf              the network configuration for this layer
     * @param trainingListeners listeners to attach to the created layer
     * @param layerIndex        index of this layer in the network
     * @param layerParamsView   parameter view array handed to the layer and the param initializer
     * @param initializeParams  forwarded to the param initializer's {@code init}; presumably whether
     *                          to initialize parameter values — confirm
     * @return the runtime layer
     */
    @Override // org.deeplearning4j.nn.conf.layers.Layer
    public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams) {
        LayerValidation.assertNInNOutSet("VariationalAutoencoder", getLayerName(), layerIndex, getNIn(), getNOut());
        org.deeplearning4j.nn.layers.variational.VariationalAutoencoder layer =
                new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(conf);
        layer.setListeners(trainingListeners);
        layer.setIndex(layerIndex);
        layer.setParamsViewArray(layerParamsView);
        layer.setParamTable(initializer().init(conf, layerParamsView, initializeParams));
        layer.setConf(conf);
        return layer;
    }

    /** @return the shared {@link VariationalAutoencoderParamInitializer} instance */
    @Override // org.deeplearning4j.nn.conf.layers.Layer
    public ParamInitializer initializer() {
        return VariationalAutoencoderParamInitializer.getInstance();
    }

    /**
     * Reports whether a parameter is pretrain-only. True for keys that start with the decoder,
     * p(z|x) log-std2, or p(x|z) prefixes.
     *
     * @param paramName the parameter key
     * @return true if the key is a pretrain parameter
     */
    @Override // org.deeplearning4j.nn.conf.layers.BasePretrainNetwork, org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.Layer, org.deeplearning4j.nn.api.TrainingConfig
    public boolean isPretrainParam(String paramName) {
        return paramName.startsWith(VariationalAutoencoderParamInitializer.DECODER_PREFIX)
                || paramName.startsWith(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_PREFIX)
                || paramName.startsWith(VariationalAutoencoderParamInitializer.PXZ_PREFIX);
    }

    /**
     * Estimates the memory requirements of this layer for the given input type.
     *
     * <p>Parameter and updater-state sizes come from the param initializer and the configured
     * updater. Working memory is estimated per example, with no fixed component. Cache memory is
     * reported as all zeros.
     *
     * @param inputType type of the input to this layer
     * @return the memory report
     */
    @Override // org.deeplearning4j.nn.conf.layers.Layer
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        InputType outputType = getOutputType(-1, inputType);

        long numParams = initializer().numParams(this);
        // Kept as long: an (int) cast here silently truncates large updater states.
        long updaterStateSize = getIUpdater().stateSize(numParams);

        // Per-example inference working memory: sum of the encoder layer sizes.
        // NOTE(review): the loop starts at index 1, so the first encoder layer is not counted —
        // confirm whether this is intentional.
        long inferencePerExample = 0;
        for (int layer = 1; layer < this.encoderLayerSizes.length; layer++) {
            inferencePerExample += this.encoderLayerSizes[layer];
        }

        // Additional per-example forward size for training. The terms scale with nOut, numSamples
        // and the decoder layer sizes. Evaluated in long arithmetic to avoid int overflow.
        long decoderForward = 4L * this.nOut
                + (long) this.numSamples * (2L * this.nOut + ArrayUtil.sum(getDecoderLayerSizes()))
                + this.nOut;
        // Training estimate is twice the forward size (presumably to cover backprop — confirm).
        long trainPerExample = 2 * (inferencePerExample + decoderForward);
        if (getIDropout() != null) {
            // Presumably dropout duplicates the input, so add one input's worth per example.
            trainPerExample += inputType.arrayElementsPerExample();
        }

        return new LayerMemoryReport.Builder(this.layerName, VariationalAutoencoder.class, inputType, outputType)
                .standardMemory(numParams, updaterStateSize)
                .workingMemory(0L, inferencePerExample, 0L, trainPerExample)
                .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS)
                .build();
    }

    /** @return the encoder layer sizes (internal array, not a copy) */
    public int[] getEncoderLayerSizes() {
        return this.encoderLayerSizes;
    }

    /** @return the decoder layer sizes (internal array, not a copy) */
    public int[] getDecoderLayerSizes() {
        return this.decoderLayerSizes;
    }

    /** @return the reconstruction distribution */
    public ReconstructionDistribution getOutputDistribution() {
        return this.outputDistribution;
    }

    /** @return the p(z|x) activation function */
    public IActivation getPzxActivationFn() {
        return this.pzxActivationFn;
    }

    /** @return the number of samples */
    public int getNumSamples() {
        return this.numSamples;
    }

    /** @param encoderLayerSizes encoder layer sizes; stored without validation or copying */
    public void setEncoderLayerSizes(int[] encoderLayerSizes) {
        this.encoderLayerSizes = encoderLayerSizes;
    }

    /** @param decoderLayerSizes decoder layer sizes; stored without validation or copying */
    public void setDecoderLayerSizes(int[] decoderLayerSizes) {
        this.decoderLayerSizes = decoderLayerSizes;
    }

    /** @param outputDistribution the reconstruction distribution */
    public void setOutputDistribution(ReconstructionDistribution outputDistribution) {
        this.outputDistribution = outputDistribution;
    }

    /** @param pzxActivationFn the p(z|x) activation function */
    public void setPzxActivationFn(IActivation pzxActivationFn) {
        this.pzxActivationFn = pzxActivationFn;
    }

    /** @param numSamples the number of samples; stored without validation */
    public void setNumSamples(int numSamples) {
        this.numSamples = numSamples;
    }

    @Override // org.deeplearning4j.nn.conf.layers.BasePretrainNetwork, org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.BaseLayer, org.deeplearning4j.nn.conf.layers.Layer
    public String toString() {
        return "VariationalAutoencoder(encoderLayerSizes=" + Arrays.toString(getEncoderLayerSizes())
                + ", decoderLayerSizes=" + Arrays.toString(getDecoderLayerSizes())
                + ", outputDistribution=" + getOutputDistribution()
                + ", pzxActivationFn=" + getPzxActivationFn()
                + ", numSamples=" + getNumSamples() + ")";
    }

    /** No-argument constructor; leaves all VAE-specific fields at their Java defaults. */
    public VariationalAutoencoder() {
    }

    /**
     * Compares the superclass state, the layer sizes (by content), the reconstruction distribution,
     * the p(z|x) activation and the number of samples.
     */
    @Override // org.deeplearning4j.nn.conf.layers.BasePretrainNetwork, org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.BaseLayer, org.deeplearning4j.nn.conf.layers.Layer
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof VariationalAutoencoder)) {
            return false;
        }
        VariationalAutoencoder other = (VariationalAutoencoder) obj;
        return other.canEqual(this)
                && super.equals(obj)
                && Arrays.equals(getEncoderLayerSizes(), other.getEncoderLayerSizes())
                && Arrays.equals(getDecoderLayerSizes(), other.getDecoderLayerSizes())
                && Objects.equals(getOutputDistribution(), other.getOutputDistribution())
                && Objects.equals(getPzxActivationFn(), other.getPzxActivationFn())
                && getNumSamples() == other.getNumSamples();
    }

    /** Symmetry hook for {@link #equals(Object)}: only other VariationalAutoencoders may compare equal. */
    @Override // org.deeplearning4j.nn.conf.layers.BasePretrainNetwork, org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.BaseLayer, org.deeplearning4j.nn.conf.layers.Layer
    protected boolean canEqual(Object obj) {
        return obj instanceof VariationalAutoencoder;
    }

    /** Hash consistent with {@link #equals(Object)}; null fields contribute the constant 43. */
    @Override // org.deeplearning4j.nn.conf.layers.BasePretrainNetwork, org.deeplearning4j.nn.conf.layers.FeedForwardLayer, org.deeplearning4j.nn.conf.layers.BaseLayer, org.deeplearning4j.nn.conf.layers.Layer
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + Arrays.hashCode(getEncoderLayerSizes());
        result = result * prime + Arrays.hashCode(getDecoderLayerSizes());
        ReconstructionDistribution distribution = getOutputDistribution();
        result = result * prime + (distribution == null ? 43 : distribution.hashCode());
        IActivation pzx = getPzxActivationFn();
        result = result * prime + (pzx == null ? 43 : pzx.hashCode());
        result = result * prime + getNumSamples();
        return result;
    }
}
