package org.deeplearning4j.nn.layers.variational;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper;
import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.params.VariationalAutoencoderParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.class */
public class VariationalAutoencoder implements Layer {
    /** Layer input; read by {@code computeGradientAndScore} and {@code batchSize}, nulled by {@code clear()}. */
    protected INDArray input;
    /** Flattened parameter array: assigned in {@code setParamsViewArray}, returned by {@code params()}. */
    protected INDArray paramsFlattened;
    /** Flattened gradient array: assigned in {@code setBackpropGradientsViewArray}. */
    protected INDArray gradientsFlattened;
    /** Parameters keyed by name; replaced wholesale by {@code setParamTable}. */
    protected Map<String, INDArray> params;
    /** Per-parameter gradient arrays obtained from the layer's param initializer for {@code gradientsFlattened}. */
    protected transient Map<String, INDArray> gradientViews;
    /** Network configuration this layer was constructed with (see {@code conf()} / {@code setConf}). */
    protected NeuralNetConfiguration conf;
    protected ConvexOptimizer optimizer;
    /** Gradient produced by the most recent {@code computeGradientAndScore} call. */
    protected Gradient gradient;
    /** Nulled by {@code clear()}; not otherwise touched in the visible part of this class. */
    protected INDArray maskArray;
    protected Solver solver;
    // The next five fields are copied from the layer configuration in the constructor.
    protected int[] encoderLayerSizes;
    protected int[] decoderLayerSizes;
    protected ReconstructionDistribution reconstructionDistribution;
    protected IActivation pzxActivationFn;
    protected int numSamples;
    /** Data type passed to the constructor; used when sampling noise and when casting the input. */
    protected DataType dataType;
    protected int iterationCount;
    protected int epochCount;
    // NOTE(review): EvaluationBinary.DEFAULT_EDGE_VALUE looks like a decompiler substitution for the
    // literal 0.0 - confirm against the original source.
    protected double score = EvaluationBinary.DEFAULT_EDGE_VALUE;
    protected Collection<TrainingListener> trainingListeners = new ArrayList();
    /** Layer index; only used for building the {@code layerId()} message in the visible code. */
    protected int index = 0;
    protected CacheMode cacheMode = CacheMode.NONE;
    /** Set to true by {@code backpropGradient} once the pretrain-only gradient views have been zeroed. */
    protected boolean zeroedPretrainParamGradients = false;
    /** Noisy parameters cached by {@code getParamWithNoise} during training; cleared at the end of {@code computeGradientAndScore}. */
    protected Map<String, INDArray> weightNoiseParams = new HashMap();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/nn/layers/variational/VariationalAutoencoder$VAEFwdHelper.class */
    /**
     * Plain holder for the intermediate arrays of an encoder forward pass:
     * the per-layer pre-activations, the per-layer activations and the
     * pre-activation of the p(z|x) mean output.
     */
    public static class VAEFwdHelper {
        private INDArray[] encoderPreOuts;
        private INDArray pzxMeanPreOut;
        private INDArray[] encoderActivations;

        /**
         * @param iNDArrayArr  encoder pre-activation arrays
         * @param iNDArray     pre-activation of the p(z|x) mean
         * @param iNDArrayArr2 encoder activation arrays
         */
        public VAEFwdHelper(INDArray[] iNDArrayArr, INDArray iNDArray, INDArray[] iNDArrayArr2) {
            this.encoderPreOuts = iNDArrayArr;
            this.pzxMeanPreOut = iNDArray;
            this.encoderActivations = iNDArrayArr2;
        }

        public INDArray[] getEncoderPreOuts() {
            return encoderPreOuts;
        }

        public INDArray getPzxMeanPreOut() {
            return pzxMeanPreOut;
        }

        public INDArray[] getEncoderActivations() {
            return encoderActivations;
        }

        public void setEncoderPreOuts(INDArray[] iNDArrayArr) {
            encoderPreOuts = iNDArrayArr;
        }

        public void setPzxMeanPreOut(INDArray iNDArray) {
            pzxMeanPreOut = iNDArray;
        }

        public void setEncoderActivations(INDArray[] iNDArrayArr) {
            encoderActivations = iNDArrayArr;
        }

        /**
         * Two helpers are equal when all three components are equal; the
         * arrays are compared with {@link Arrays#deepEquals}.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof VAEFwdHelper)) {
                return false;
            }
            VAEFwdHelper other = (VAEFwdHelper) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            if (!Arrays.deepEquals(getEncoderPreOuts(), other.getEncoderPreOuts())) {
                return false;
            }
            INDArray mine = getPzxMeanPreOut();
            INDArray theirs = other.getPzxMeanPreOut();
            boolean sameMean = (mine == null) ? (theirs == null) : mine.equals(theirs);
            if (!sameMean) {
                return false;
            }
            return Arrays.deepEquals(getEncoderActivations(), other.getEncoderActivations());
        }

        /** Hook used by {@link #equals(Object)} so that subclasses can opt out of equality. */
        protected boolean canEqual(Object obj) {
            return obj instanceof VAEFwdHelper;
        }

        /** Hash over the same three components that {@link #equals(Object)} compares. */
        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + Arrays.deepHashCode(getEncoderPreOuts());
            INDArray mean = getPzxMeanPreOut();
            result = result * prime + (mean == null ? 43 : mean.hashCode());
            result = result * prime + Arrays.deepHashCode(getEncoderActivations());
            return result;
        }

        @Override
        public String toString() {
            String preOuts = Arrays.deepToString(getEncoderPreOuts());
            String activations = Arrays.deepToString(getEncoderActivations());
            return "VariationalAutoencoder.VAEFwdHelper(encoderPreOuts=" + preOuts
                    + ", pzxMeanPreOut=" + getPzxMeanPreOut()
                    + ", encoderActivations=" + activations + ")";
        }
    }

    public VariationalAutoencoder(NeuralNetConfiguration neuralNetConfiguration, DataType dataType) {
        this.conf = neuralNetConfiguration;
        this.dataType = dataType;
        this.encoderLayerSizes = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) neuralNetConfiguration.getLayer()).getEncoderLayerSizes();
        this.decoderLayerSizes = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) neuralNetConfiguration.getLayer()).getDecoderLayerSizes();
        this.reconstructionDistribution = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) neuralNetConfiguration.getLayer()).getOutputDistribution();
        this.pzxActivationFn = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) neuralNetConfiguration.getLayer()).getPzxActivationFn();
        this.numSamples = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) neuralNetConfiguration.getLayer()).getNumSamples();
    }

    /** Returns the layer configuration from {@link #conf()}, cast to the VAE configuration type. */
    protected org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder layerConf() {
        NeuralNetConfiguration current = conf();
        return (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) current.getLayer();
    }

    /**
     * Stores the cache mode, substituting {@link CacheMode#NONE} for {@code null}.
     *
     * @param cacheMode requested mode, may be {@code null}
     */
    @Override
    public void setCacheMode(CacheMode cacheMode) {
        this.cacheMode = (cacheMode == null) ? CacheMode.NONE : cacheMode;
    }

    /**
     * Builds the "(layer name: ..., layer index: ...)" suffix used in this
     * class's exception messages. A null layer name is rendered as two
     * double-quote characters.
     */
    protected String layerId() {
        String name = conf().getLayer().getLayerName();
        if (name == null) {
            name = "\"\"";
        }
        return "(layer name: " + name + ", layer index: " + this.index + ")";
    }

    /** Intentionally empty: this layer performs no work in {@code init()}. */
    @Override
    public void init() {
        // no-op
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void update(Gradient gradient) {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void update(INDArray iNDArray, String str) {
        String message = "Not supported " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /** Returns the current value of the {@code score} field, which {@code computeGradientAndScore} assigns. */
    @Override
    public double score() {
        return score;
    }

    /**
     * Returns the named parameter, with weight noise applied when the layer
     * configuration defines it.
     *
     * <p>Without weight noise this is just {@code getParam(str)}. With weight
     * noise and {@code z == true}, the noisy array is cached in
     * {@code weightNoiseParams} so repeated requests for the same key return
     * the same array until that cache is cleared (done at the end of
     * {@code computeGradientAndScore}).
     *
     * <p>The noisy parameter is generated while scoped out of workspaces; the
     * scope is closed exactly once, whether or not generation throws.
     *
     * @param str               parameter key
     * @param z                 training flag forwarded to the weight-noise implementation; also enables caching
     * @param layerWorkspaceMgr workspace manager forwarded to the weight-noise implementation
     * @return the plain or noisy parameter array
     */
    protected INDArray getParamWithNoise(String str, boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        if (layerConf().getWeightNoise() == null) {
            return getParam(str);
        }
        // Reuse the noisy copy generated earlier in the same training step.
        if (z && this.weightNoiseParams.containsKey(str)) {
            return this.weightNoiseParams.get(str);
        }
        INDArray parameter;
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            parameter = layerConf().getWeightNoise().getParameter(this, str, getIterationCount(), getEpochCount(), z, layerWorkspaceMgr);
        }
        if (z) {
            this.weightNoiseParams.put(str, parameter);
        }
        return parameter;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Computes the unsupervised (pretraining) score and the gradients for every
     * parameter group: encoder, p(z|x) mean and log-variance heads, decoder and
     * the p(x|z) output layer.
     *
     * <p>The method runs the encoder forward once, then loops {@code numSamples}
     * times. Each iteration draws fresh Gaussian noise, forms
     * {@code z = mean + sigma * noise}, runs the decoder, adds the reconstruction
     * negative log probability (divided by {@code numSamples}) to the score and
     * accumulates gradients into the arrays held in {@code gradientViews}.
     * The first iteration overwrites the gradient arrays; later ones add to them.
     *
     * <p>On exit {@code this.score} and {@code this.gradient} are updated and
     * {@code weightNoiseParams} is cleared.
     *
     * @param layerWorkspaceMgr workspace manager forwarded to the forward pass and to {@code getParamWithNoise}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void computeGradientAndScore(LayerWorkspaceMgr layerWorkspaceMgr) {
        INDArray iNDArray;
        VAEFwdHelper doForward = doForward(true, true, layerWorkspaceMgr);
        IActivation activationFn = layerConf().getActivationFn();
        // Pre-activation of the PZX_LOGSTD2 head, computed from the last encoder activation.
        INDArray addiRowVector = doForward.encoderActivations[doForward.encoderActivations.length - 1].mmul(getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W, true, layerWorkspaceMgr)).addiRowVector(getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B, true, layerWorkspaceMgr));
        // dup / dup2: activated copies of the mean and log-variance pre-activations.
        INDArray dup = doForward.pzxMeanPreOut.dup();
        INDArray dup2 = addiRowVector.dup();
        this.pzxActivationFn.getActivation(dup, true);
        this.pzxActivationFn.getActivation(dup2, true);
        // exp = exp(dup2), sqrt = sqrt(exp); both are computed on copies, dup2 is left untouched.
        INDArray exp = Transforms.exp(dup2, true);
        INDArray sqrt = Transforms.sqrt(exp, true);
        long size = this.input.size(0);
        long size2 = doForward.pzxMeanPreOut.size(1);
        // Typed map: a raw HashMap makes the put(...) calls at the end of the method ill-typed.
        Map<String, INDArray> hashMap = new HashMap<>();
        // d is the per-sample weight applied to every gradient contribution.
        double d = 1.0d / this.numSamples;
        Level1 level1 = Nd4j.getBlasWrapper().level1();
        // Cache of encoder activation derivatives, only needed when more than one sample is drawn.
        INDArray[] iNDArrayArr = this.numSamples > 1 ? new INDArray[this.encoderLayerSizes.length] : null;
        int i = 0;
        while (i < this.numSamples) {
            // d2 is passed as the last gemm argument: DEFAULT_EDGE_VALUE for the first sample, 1.0 afterwards.
            // NOTE(review): DEFAULT_EDGE_VALUE is presumably 0.0 (overwrite, then accumulate) - confirm.
            double d2 = i == 0 ? EvaluationBinary.DEFAULT_EDGE_VALUE : 1.0d;
            INDArray randn = Nd4j.randn(this.dataType, new long[]{size, size2});
            // Sampled latent value: sqrt * noise + mean.
            INDArray addi = sqrt.mul(randn).addi(dup);
            int length = this.decoderLayerSizes.length;
            INDArray iNDArray2 = addi;
            INDArray[] iNDArrayArr2 = new INDArray[length];
            INDArray[] iNDArrayArr3 = new INDArray[length];
            // Decoder forward pass; iNDArrayArr2 keeps pre-activations, iNDArrayArr3 keeps activations.
            for (int i2 = 0; i2 < length; i2++) {
                iNDArray2 = iNDArray2.mmul(getParamWithNoise(VariationalAutoencoderParamInitializer.DECODER_PREFIX + i2 + "W", true, layerWorkspaceMgr)).addiRowVector(getParamWithNoise(VariationalAutoencoderParamInitializer.DECODER_PREFIX + i2 + "b", true, layerWorkspaceMgr));
                iNDArrayArr2[i2] = iNDArray2.dup();
                activationFn.getActivation(iNDArray2, true);
                iNDArrayArr3[i2] = iNDArray2;
            }
            INDArray paramWithNoise = getParamWithNoise(VariationalAutoencoderParamInitializer.PXZ_W, true, layerWorkspaceMgr);
            INDArray paramWithNoise2 = getParamWithNoise(VariationalAutoencoderParamInitializer.PXZ_B, true, layerWorkspaceMgr);
            if (i == 0) {
                // Sample-independent score term: (-0.5 / batch) * sum(1 + dup2 - dup^2 - exp), plus regularization.
                INDArray negi = dup.mul(dup).addi(exp).negi();
                negi.addi(dup2).addi(Double.valueOf(1.0d));
                this.score = (((-0.5d) / size) * negi.sumNumber().doubleValue()) + calcRegularizationScore(false);
            }
            INDArray addiRowVector2 = iNDArray2.mmul(paramWithNoise).addiRowVector(paramWithNoise2);
            this.score += this.reconstructionDistribution.negLogProbability(this.input, addiRowVector2, true) / this.numSamples;
            // Listeners are only notified for the first sample.
            if (this.trainingListeners != null && !this.trainingListeners.isEmpty() && i == 0) {
                Map<String, INDArray> linkedHashMap = new LinkedHashMap<>();
                for (int i3 = 0; i3 < doForward.encoderActivations.length; i3++) {
                    linkedHashMap.put(VariationalAutoencoderParamInitializer.ENCODER_PREFIX + i3, doForward.encoderActivations[i3]);
                }
                linkedHashMap.put(VariationalAutoencoderParamInitializer.PZX_PREFIX, addi);
                for (int i4 = 0; i4 < iNDArrayArr3.length; i4++) {
                    linkedHashMap.put(VariationalAutoencoderParamInitializer.DECODER_PREFIX + i4, iNDArrayArr3[i4]);
                }
                linkedHashMap.put(VariationalAutoencoderParamInitializer.PXZ_PREFIX, this.reconstructionDistribution.generateAtMean(addiRowVector2));
                // Notify listeners while scoped out of workspaces; the scope is closed exactly once.
                try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    for (TrainingListener listener : this.trainingListeners) {
                        listener.onForwardPass(this, linkedHashMap);
                    }
                }
            }
            // ---- p(x|z) output layer gradients ----
            INDArray gradient = this.reconstructionDistribution.gradient(this.input, addiRowVector2);
            INDArray iNDArray3 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PXZ_W);
            INDArray iNDArray4 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PXZ_B);
            Nd4j.gemm(iNDArrayArr3[iNDArrayArr3.length - 1], gradient, iNDArray3, true, false, d, d2);
            if (i == 0) {
                gradient.sum(iNDArray4, new int[]{0});
                if (this.numSamples > 1) {
                    iNDArray4.muli(Double.valueOf(d));
                }
            } else {
                level1.axpy(iNDArray4.length(), d, gradient.sum(new int[]{0}), iNDArray4);
            }
            hashMap.put(VariationalAutoencoderParamInitializer.PXZ_W, iNDArray3);
            hashMap.put(VariationalAutoencoderParamInitializer.PXZ_B, iNDArray4);
            // Error signal propagated to the last decoder layer.
            INDArray transpose = paramWithNoise.mmul(gradient.transpose()).transpose();
            // ---- decoder gradients, last layer first ----
            int i5 = length - 1;
            while (i5 >= 0) {
                String str = VariationalAutoencoderParamInitializer.DECODER_PREFIX + i5 + "W";
                String str2 = VariationalAutoencoderParamInitializer.DECODER_PREFIX + i5 + "b";
                INDArray iNDArray5 = (INDArray) activationFn.backprop(iNDArrayArr2[i5], transpose).getFirst();
                INDArray paramWithNoise3 = getParamWithNoise(str, true, layerWorkspaceMgr);
                INDArray iNDArray6 = this.gradientViews.get(str);
                INDArray iNDArray7 = this.gradientViews.get(str2);
                // The first decoder layer takes the sampled latent value as its input.
                Nd4j.gemm(i5 == 0 ? addi : iNDArrayArr3[i5 - 1], iNDArray5, iNDArray6, true, false, d, d2);
                if (i == 0) {
                    iNDArray5.sum(iNDArray7, new int[]{0});
                    if (this.numSamples > 1) {
                        iNDArray7.muli(Double.valueOf(d));
                    }
                } else {
                    level1.axpy(iNDArray7.length(), d, iNDArray5.sum(new int[]{0}), iNDArray7);
                }
                hashMap.put(str, iNDArray6);
                hashMap.put(str2, iNDArray7);
                transpose = paramWithNoise3.mmul(iNDArray5.transpose()).transpose();
                i5--;
            }
            // ---- p(z|x) mean and log-variance head gradients ----
            INDArray paramWithNoise4 = getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_MEAN_W, true, layerWorkspaceMgr);
            INDArray paramWithNoise5 = getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W, true, layerWorkspaceMgr);
            INDArray iNDArray8 = transpose;
            // add: error w.r.t. the mean output; muli: error w.r.t. the log-variance output.
            INDArray add = iNDArray8.add(dup);
            INDArray muli = iNDArray8.mul(randn).muli(sqrt).addi(exp).subi(1).muli(Double.valueOf(0.5d));
            INDArray iNDArray9 = (INDArray) this.pzxActivationFn.backprop(doForward.getPzxMeanPreOut().dup(), add).getFirst();
            INDArray iNDArray10 = (INDArray) this.pzxActivationFn.backprop(addiRowVector.dup(), muli).getFirst();
            INDArray iNDArray11 = doForward.encoderActivations[doForward.encoderActivations.length - 1];
            INDArray iNDArray12 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W);
            INDArray iNDArray13 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W);
            Nd4j.gemm(iNDArray11, iNDArray9, iNDArray12, true, false, d, d2);
            Nd4j.gemm(iNDArray11, iNDArray10, iNDArray13, true, false, d, d2);
            INDArray iNDArray14 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PZX_MEAN_B);
            INDArray iNDArray15 = this.gradientViews.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B);
            // The mean-bias gradient recomputes the activation backprop on fresh copies rather than reusing iNDArray9.
            if (i == 0) {
                iNDArray14.assign(((INDArray) this.pzxActivationFn.backprop(doForward.getPzxMeanPreOut().dup(), iNDArray8.add(dup)).getFirst()).sum(new int[]{0}));
                iNDArray10.sum(iNDArray15, new int[]{0});
                if (this.numSamples > 1) {
                    iNDArray14.muli(Double.valueOf(d));
                    iNDArray15.muli(Double.valueOf(d));
                }
            } else {
                level1.axpy(iNDArray14.length(), d, ((INDArray) this.pzxActivationFn.backprop(doForward.getPzxMeanPreOut().dup(), iNDArray8.add(dup)).getFirst()).sum(new int[]{0}), iNDArray14);
                level1.axpy(iNDArray15.length(), d, iNDArray10.sum(new int[]{0}), iNDArray15);
            }
            hashMap.put(VariationalAutoencoderParamInitializer.PZX_MEAN_W, iNDArray12);
            hashMap.put(VariationalAutoencoderParamInitializer.PZX_MEAN_B, iNDArray14);
            hashMap.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W, iNDArray13);
            hashMap.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B, iNDArray15);
            // Error signal into the last encoder layer: sum of the contributions from both heads.
            INDArray gemm = Nd4j.gemm(iNDArray9, paramWithNoise4, false, true);
            Nd4j.gemm(iNDArray10, paramWithNoise5, gemm, false, true, 1.0d, 1.0d);
            // ---- encoder gradients, last layer first ----
            int length2 = this.encoderLayerSizes.length - 1;
            while (length2 >= 0) {
                String str3 = VariationalAutoencoderParamInitializer.ENCODER_PREFIX + length2 + "W";
                String str4 = VariationalAutoencoderParamInitializer.ENCODER_PREFIX + length2 + "b";
                INDArray paramWithNoise6 = getParamWithNoise(str3, true, layerWorkspaceMgr);
                INDArray iNDArray16 = this.gradientViews.get(str3);
                INDArray iNDArray17 = this.gradientViews.get(str4);
                INDArray iNDArray18 = doForward.encoderPreOuts[length2];
                if (this.numSamples > 1) {
                    // Activation derivatives are computed once (first sample) by back-propagating ones,
                    // then reused as an element-wise multiplier for every sample.
                    if (i == 0) {
                        iNDArrayArr[length2] = (INDArray) activationFn.backprop(doForward.encoderPreOuts[length2], Nd4j.ones(doForward.encoderPreOuts[length2].shape())).getFirst();
                    }
                    iNDArray = gemm.muli(iNDArrayArr[length2]);
                } else {
                    iNDArray = (INDArray) activationFn.backprop(iNDArray18, gemm).getFirst();
                }
                // The first encoder layer takes the (cast) layer input as its input.
                Nd4j.gemm(length2 == 0 ? this.input.castTo(iNDArray16.dataType()) : doForward.encoderActivations[length2 - 1], iNDArray, iNDArray16, true, false, d, d2);
                if (i == 0) {
                    iNDArray.sum(iNDArray17, new int[]{0});
                    if (this.numSamples > 1) {
                        iNDArray17.muli(Double.valueOf(d));
                    }
                } else {
                    level1.axpy(iNDArray17.length(), d, iNDArray.sum(new int[]{0}), iNDArray17);
                }
                hashMap.put(str3, iNDArray16);
                hashMap.put(str4, iNDArray17);
                gemm = paramWithNoise6.mmul(iNDArray.transpose()).transpose();
                length2--;
            }
            i++;
        }
        // Register the gradient arrays in a fixed order: encoder, p(z|x) heads, decoder, p(x|z).
        DefaultGradient defaultGradient = new DefaultGradient(this.gradientsFlattened);
        Map<String, INDArray> gradientForVariable = defaultGradient.gradientForVariable();
        for (int i6 = 0; i6 < this.encoderLayerSizes.length; i6++) {
            String str5 = VariationalAutoencoderParamInitializer.ENCODER_PREFIX + i6 + "W";
            gradientForVariable.put(str5, hashMap.get(str5));
            String str6 = VariationalAutoencoderParamInitializer.ENCODER_PREFIX + i6 + "b";
            gradientForVariable.put(str6, hashMap.get(str6));
        }
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PZX_MEAN_W, hashMap.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W));
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PZX_MEAN_B, hashMap.get(VariationalAutoencoderParamInitializer.PZX_MEAN_B));
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W, hashMap.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W));
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B, hashMap.get(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B));
        for (int i7 = 0; i7 < this.decoderLayerSizes.length; i7++) {
            String str7 = VariationalAutoencoderParamInitializer.DECODER_PREFIX + i7 + "W";
            gradientForVariable.put(str7, hashMap.get(str7));
            String str8 = VariationalAutoencoderParamInitializer.DECODER_PREFIX + i7 + "b";
            gradientForVariable.put(str8, hashMap.get(str8));
        }
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PXZ_W, hashMap.get(VariationalAutoencoderParamInitializer.PXZ_W));
        gradientForVariable.put(VariationalAutoencoderParamInitializer.PXZ_B, hashMap.get(VariationalAutoencoderParamInitializer.PXZ_B));
        // Noisy parameters were only valid for this forward/backward step.
        this.weightNoiseParams.clear();
        this.gradient = defaultGradient;
    }

    /** Returns the flattened parameter array last supplied to {@code setParamsViewArray}. */
    @Override
    public INDArray params() {
        return paramsFlattened;
    }

    /** Returns the layer configuration held by {@code conf}, viewed as a training configuration. */
    @Override
    public TrainingConfig getConfig() {
        TrainingConfig layerConfig = this.conf.getLayer();
        return layerConfig;
    }

    /** Total parameter count over all parameters; equivalent to {@code numParams(false)}. */
    @Override
    public long numParams() {
        final boolean excludePretrainParams = false;
        return numParams(excludePretrainParams);
    }

    /**
     * Sums the lengths of the parameter arrays in {@code params}.
     *
     * <p>The total is accumulated as a {@code long}. The previous version
     * narrowed the running sum to {@code int} on every step, so totals above
     * {@code Integer.MAX_VALUE} wrapped silently even though the return type
     * is {@code long}.
     *
     * @param z when {@code true}, parameters for which {@link #isPretrainParam(String)} is true are skipped
     * @return number of parameter elements counted
     */
    @Override
    public long numParams(boolean z) {
        long total = 0L;
        for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
            if (z && isPretrainParam(entry.getKey())) {
                continue;
            }
            total += entry.getValue().length();
        }
        return total;
    }

    /**
     * Copies the values of {@code iNDArray} into the flattened parameter array.
     *
     * @param iNDArray new parameter values; must have the same length as the flattened parameter array
     * @throws IllegalArgumentException if the lengths differ
     */
    @Override
    public void setParams(INDArray iNDArray) {
        long actual = iNDArray.length();
        long expected = this.paramsFlattened.length();
        if (actual == expected) {
            this.paramsFlattened.assign(iNDArray);
            return;
        }
        throw new IllegalArgumentException("Cannot set parameters: expected parameters vector of length " + expected + " but got parameters array of length " + actual + " " + layerId());
    }

    /**
     * Replaces the flattened parameter array reference. The length is only
     * validated once a parameter table has been set.
     *
     * @param iNDArray array to use as the flattened parameters
     * @throws IllegalArgumentException if a parameter table exists and the length differs from {@code numParams()}
     */
    @Override
    public void setParamsViewArray(INDArray iNDArray) {
        if (this.params != null) {
            long got = iNDArray.length();
            long want = numParams();
            if (got != want) {
                throw new IllegalArgumentException("Invalid input: expect params of length " + want + ", got params of length " + got + " " + layerId());
            }
        }
        this.paramsFlattened = iNDArray;
    }

    /** Returns the flattened gradient array last supplied to {@code setBackpropGradientsViewArray}. */
    @Override
    public INDArray getGradientsViewArray() {
        return gradientsFlattened;
    }

    /**
     * Stores the flattened gradient array and asks the layer's parameter
     * initializer to derive the per-parameter gradient arrays from it.
     *
     * @param iNDArray flattened gradient array
     * @throws IllegalArgumentException if a parameter table exists and the length differs from {@code numParams()}
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        if (this.params != null) {
            long got = iNDArray.length();
            long want = numParams();
            if (got != want) {
                throw new IllegalArgumentException("Invalid input: expect gradients array of length " + want + ", got gradient array of length of length " + got + " " + layerId());
            }
        }
        this.gradientsFlattened = iNDArray;
        this.gradientViews = this.conf.getLayer().initializer().getGradientsFromFlattened(this.conf, iNDArray);
    }

    /**
     * Sets the layer input and then runs the no-argument {@code fit()}.
     *
     * @param iNDArray          input data
     * @param layerWorkspaceMgr workspace manager passed to {@code setInput}
     */
    @Override
    public void fit(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        this.setInput(iNDArray, layerWorkspaceMgr);
        this.fit();
    }

    /** Returns the gradient stored by the last {@code computeGradientAndScore} call (may be null before that). */
    @Override
    public Gradient gradient() {
        return gradient;
    }

    /** Returns the current {@link #gradient()} and {@link #score()} bundled as a pair. */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = gradient();
        Double currentScore = Double.valueOf(score());
        return new Pair<>(currentGradient, currentScore);
    }

    /**
     * Returns the size of dimension 0 of the current input.
     *
     * @throws ND4JArraySizeException if that size does not fit in an {@code int}
     */
    @Override
    public int batchSize() {
        long examples = this.input.size(0);
        if (examples > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        return (int) examples;
    }

    /** Returns the network configuration currently held by this layer. */
    @Override
    public NeuralNetConfiguration conf() {
        return conf;
    }

    /**
     * Replaces the stored network configuration. Values copied from the
     * configuration in the constructor are not refreshed here.
     *
     * @param neuralNetConfiguration new configuration
     */
    @Override
    public void setConf(NeuralNetConfiguration neuralNetConfiguration) {
        conf = neuralNetConfiguration;
    }

    /** Returns the current layer input, or {@code null} after {@code clear()}. */
    @Override
    public INDArray input() {
        return input;
    }

    /** Returns the value of the {@code optimizer} field. */
    @Override
    public ConvexOptimizer getOptimizer() {
        return optimizer;
    }

    /**
     * Looks up a parameter array by key in the parameter table.
     *
     * @param str parameter key
     * @return the mapped array, or {@code null} if the key is absent
     */
    @Override
    public INDArray getParam(String str) {
        Map<String, INDArray> table = this.params;
        return table.get(str);
    }

    /**
     * Returns a new insertion-ordered map with the same entries as the
     * parameter table. The map is a copy; the arrays inside it are not.
     */
    @Override
    public Map<String, INDArray> paramTable() {
        Map<String, INDArray> copy = new LinkedHashMap<>(this.params);
        return copy;
    }

    /**
     * Returns a new insertion-ordered map of parameters, optionally leaving
     * out the pretrain-only ones.
     *
     * @param z when {@code true}, entries for which {@link #isPretrainParam(String)} is true are omitted
     */
    @Override
    public Map<String, INDArray> paramTable(boolean z) {
        Map<String, INDArray> filtered = new LinkedHashMap<>();
        for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
            String key = entry.getKey();
            if (z && isPretrainParam(key)) {
                continue;
            }
            filtered.put(key, entry.getValue());
        }
        return filtered;
    }

    /**
     * Always returns {@code true}; the parameter name is not consulted.
     *
     * @param str parameter key (ignored)
     */
    @Override
    public boolean updaterDivideByMinibatch(String str) {
        final boolean divideByMinibatch = true;
        return divideByMinibatch;
    }

    /**
     * Replaces the parameter table with the given map. The map is stored by
     * reference, not copied.
     *
     * @param map new parameter table
     */
    @Override
    public void setParamTable(Map<String, INDArray> map) {
        params = map;
    }

    /**
     * Copies {@code iNDArray} into the existing array registered under {@code str}.
     *
     * @param str      parameter key; must already be present in the parameter table
     * @param iNDArray values to assign
     * @throws IllegalArgumentException if the key is unknown
     */
    @Override
    public void setParam(String str, INDArray iNDArray) {
        Map<String, INDArray> table = paramTable();
        if (table.containsKey(str)) {
            table.get(str).assign(iNDArray);
            return;
        }
        throw new IllegalArgumentException("Unknown parameter: " + str + " - " + layerId());
    }

    /** Drops the references to the current input and mask array. */
    @Override
    public void clear() {
        maskArray = null;
        input = null;
    }

    /**
     * Applies every constraint defined on the layer configuration to this
     * layer, in iteration order. Does nothing when no constraints are configured.
     *
     * @param i  first value forwarded to each constraint
     * @param i2 second value forwarded to each constraint
     */
    @Override
    public void applyConstraints(int i, int i2) {
        if (layerConf().getConstraints() == null) {
            return;
        }
        for (LayerConstraint constraint : layerConf().getConstraints()) {
            constraint.applyConstraint(this, i, i2);
        }
    }

    /**
     * Reports whether a parameter is treated as pretrain-only: any key that
     * starts with neither the encoder prefix nor the p(z|x) mean prefix.
     *
     * @param str parameter key
     * @return {@code false} for encoder and p(z|x)-mean keys, {@code true} otherwise
     */
    public boolean isPretrainParam(String str) {
        boolean isEncoderParam = str.startsWith(VariationalAutoencoderParamInitializer.ENCODER_PREFIX);
        if (isEncoderParam) {
            return false;
        }
        return !str.startsWith(VariationalAutoencoderParamInitializer.PZX_MEAN_PREFIX);
    }

    /**
     * Sums the score of every regularization configured for every parameter,
     * evaluated at the current iteration and epoch counts.
     *
     * @param z when {@code true}, parameters for which {@link #isPretrainParam(String)} is true are skipped
     * @return total regularization score, 0.0 if nothing is configured
     */
    @Override
    public double calcRegularizationScore(boolean z) {
        double total = 0.0d;
        for (Map.Entry<String, INDArray> entry : paramTable().entrySet()) {
            if (z && isPretrainParam(entry.getKey())) {
                continue;
            }
            List<Regularization> regularizations = layerConf().getRegularizationByParam(entry.getKey());
            if (regularizations == null || regularizations.isEmpty()) {
                continue;
            }
            for (Regularization regularization : regularizations) {
                total += regularization.score(entry.getValue(), getIterationCount(), getEpochCount());
            }
        }
        return total;
    }

    /** This layer always reports itself as {@link Layer.Type#FEED_FORWARD}. */
    @Override
    public Layer.Type type() {
        final Layer.Type layerType = Layer.Type.FEED_FORWARD;
        return layerType;
    }

    /**
     * Backpropagates an error signal through the p(z|x) mean layer and then through the
     * encoder layers, from the last encoder layer down to the first.
     *
     * <p>Weight and bias gradients are written directly into the arrays held in
     * {@code gradientViews} and are also registered in the returned gradient object.
     * On the first call, the gradient views of every parameter for which
     * {@code isPretrainParam} returns {@code true} are set to zero; this is done once,
     * guarded by {@code zeroedPretrainParamGradients}.
     *
     * @param iNDArray          error signal passed to the p(z|x) activation function's backprop
     * @param layerWorkspaceMgr workspace manager used for the forward pass, parameter lookup
     *                          and allocation of the returned array
     * @return pair of the gradient object (p(z|x) mean and encoder weight/bias entries) and
     *         the error signal computed from the first encoder layer's weights and delta
     * @throws IllegalStateException if the layer input has not been set
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(true);
        if (!this.zeroedPretrainParamGradients) {
            for (Map.Entry<String, INDArray> view : this.gradientViews.entrySet()) {
                if (isPretrainParam(view.getKey())) {
                    view.getValue().assign(0);
                }
            }
            this.zeroedPretrainParamGradients = true;
        }
        INDArray castInput = this.input.castTo(this.dataType);
        DefaultGradient gradient = new DefaultGradient();
        VAEFwdHelper fwd = doForward(true, true, layerWorkspaceMgr);

        // p(z|x) mean layer: delta, then weight gradient via gemm and bias gradient via a sum over dimension 0.
        INDArray meanDelta = (INDArray) this.pzxActivationFn.backprop(fwd.pzxMeanPreOut, iNDArray).getFirst();
        INDArray meanWeights = getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_MEAN_W, true, layerWorkspaceMgr);
        INDArray meanWeightGrad = this.gradientViews.get(VariationalAutoencoderParamInitializer.PZX_MEAN_W);
        INDArray lastEncoderActivation = fwd.encoderActivations[fwd.encoderActivations.length - 1];
        // beta is a plain 0.0 (overwrite the target); it was previously spelled with an unrelated
        // evaluation constant that merely happens to have the same value.
        Nd4j.gemm(lastEncoderActivation, meanDelta, meanWeightGrad, true, false, 1.0d, 0.0d);
        INDArray meanBiasGrad = this.gradientViews.get(VariationalAutoencoderParamInitializer.PZX_MEAN_B);
        meanDelta.sum(meanBiasGrad, new int[]{0});
        gradient.gradientForVariable().put(VariationalAutoencoderParamInitializer.PZX_MEAN_W, meanWeightGrad);
        gradient.gradientForVariable().put(VariationalAutoencoderParamInitializer.PZX_MEAN_B, meanBiasGrad);

        INDArray epsilon = meanWeights.mmul(meanDelta.transpose()).transpose();
        int encoderDepth = this.encoderLayerSizes.length;
        IActivation encoderActivationFn = layerConf().getActivationFn();
        for (int layer = encoderDepth - 1; layer >= 0; layer--) {
            String weightKey = VariationalAutoencoderParamInitializer.ENCODER_PREFIX + layer + "W";
            String biasKey = VariationalAutoencoderParamInitializer.ENCODER_PREFIX + layer + "b";
            INDArray weights = getParamWithNoise(weightKey, true, layerWorkspaceMgr);
            INDArray weightGrad = this.gradientViews.get(weightKey);
            INDArray biasGrad = this.gradientViews.get(biasKey);
            INDArray delta = (INDArray) encoderActivationFn.backprop(fwd.encoderPreOuts[layer], epsilon).getFirst();

            // The first encoder layer is fed by the (cast) layer input; the others by the previous activation.
            INDArray layerInput = layer == 0 ? castInput : fwd.encoderActivations[layer - 1];
            Nd4j.gemm(layerInput, delta, weightGrad, true, false, 1.0d, 0.0d);
            delta.sum(biasGrad, new int[]{0});
            gradient.gradientForVariable().put(weightKey, weightGrad);
            gradient.gradientForVariable().put(biasKey, biasGrad);

            if (layer == 0) {
                // The final epsilon is returned to the caller, so it is allocated through the
                // workspace manager as an ACTIVATION_GRAD array.
                INDArray target = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, delta.dataType(), new long[]{weights.size(0), delta.size(0)}, 'f');
                weights.mmuli(delta.transpose(), target);
                epsilon = target.transpose();
            } else {
                epsilon = weights.mmul(delta.transpose()).transpose();
            }
        }
        return new Pair<>(gradient, epsilon);
    }

    /**
     * Runs the encoder forward pass and returns the p(z|x) mean pre-output
     * (the {@code pzxMeanPreOut} field of the forward helper).
     *
     * @param z                 forwarded to {@code doForward} as its first flag
     * @param layerWorkspaceMgr workspace manager forwarded to {@code doForward}
     * @return the {@code pzxMeanPreOut} array produced by the forward pass
     */
    public INDArray preOutput(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        VAEFwdHelper forward = doForward(z, false, layerWorkspaceMgr);
        return forward.pzxMeanPreOut;
    }

    /**
     * Forward pass through the encoder layers followed by the p(z|x) mean layer.
     *
     * @param z                 passed to {@code getParamWithNoise} and to the activation function
     * @param z2                when {@code true}, a copy of each encoder layer's pre-activation
     *                          values is retained; when {@code false} those slots stay {@code null}
     * @param layerWorkspaceMgr used for parameter lookup and to allocate the mean pre-output
     * @return helper holding (encoder pre-outputs, p(z|x) mean pre-output, encoder activations)
     * @throws IllegalStateException if the layer input has not been set
     */
    private VAEFwdHelper doForward(boolean z, boolean z2, LayerWorkspaceMgr layerWorkspaceMgr) {
        assertInputSet(false);
        final int encoderDepth = this.encoderLayerSizes.length;
        final INDArray[] preOuts = new INDArray[encoderDepth];
        final INDArray[] activations = new INDArray[encoderDepth];

        // The input is cast to the data type of the first encoder weight matrix.
        INDArray current = this.input.castTo(getParam("e0W").dataType());
        for (int layer = 0; layer < encoderDepth; layer++) {
            final String prefix = VariationalAutoencoderParamInitializer.ENCODER_PREFIX + layer;
            final INDArray weights = getParamWithNoise(prefix + "W", z, layerWorkspaceMgr);
            final INDArray product = current.mmul(weights);
            final INDArray bias = getParamWithNoise(prefix + "b", z, layerWorkspaceMgr);
            current = product.addiRowVector(bias);
            if (z2) {
                // Copy taken before the activation call below, whose return value is ignored.
                preOuts[layer] = current.dup();
            }
            layerConf().getActivationFn().getActivation(current, z);
            activations[layer] = current;
        }

        final INDArray meanWeights = getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_MEAN_W, z, layerWorkspaceMgr);
        final long[] meanShape = {current.size(0), meanWeights.size(1)};
        final INDArray meanTarget = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, meanWeights.dataType(), meanShape, 'f');
        final INDArray meanProduct = current.mmuli(meanWeights, meanTarget);
        final INDArray meanBias = getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_MEAN_B, z, layerWorkspaceMgr);
        final INDArray meanPreOut = meanProduct.addiRowVector(meanBias);
        return new VAEFwdHelper(preOuts, meanPreOut, activations);
    }

    /**
     * Computes this layer's output for the currently set input: the p(z|x) mean pre-output
     * with the p(z|x) activation function applied to it.
     *
     * @param z                 forwarded to {@code preOutput} and to the activation function
     * @param layerWorkspaceMgr workspace manager forwarded to {@code preOutput}
     * @return the array obtained from {@code preOutput}, after the activation call
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        final INDArray output = preOutput(z, layerWorkspaceMgr);
        // The activation's return value is ignored; the pre-output array itself is returned.
        this.pzxActivationFn.getActivation(output, z);
        return output;
    }

    /**
     * Sets the given array as the layer input and then computes the layer output.
     *
     * @param iNDArray          new layer input (remains set after the call)
     * @param z                 forwarded to {@code activate(boolean, LayerWorkspaceMgr)}
     * @param layerWorkspaceMgr workspace manager used for both steps
     * @return the result of {@code activate(z, layerWorkspaceMgr)}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, boolean z, LayerWorkspaceMgr layerWorkspaceMgr) {
        setInput(iNDArray, layerWorkspaceMgr);
        final INDArray output = activate(z, layerWorkspaceMgr);
        return output;
    }

    /**
     * Returns a copy of the registered training listeners.
     *
     * @return a new list containing the current listeners, or {@code null} if the listener
     *         collection has not been created
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Collection<TrainingListener> getListeners() {
        final Collection<TrainingListener> current = this.trainingListeners;
        if (current != null) {
            return new ArrayList<>(current);
        }
        return null;
    }

    /**
     * Replaces the registered training listeners with the given ones by delegating to
     * {@code setListeners(Collection)}.
     *
     * @param trainingListenerArr listeners to register
     */
    @Override // org.deeplearning4j.nn.api.Layer, org.deeplearning4j.nn.api.Model
    public void setListeners(TrainingListener... trainingListenerArr) {
        final List<TrainingListener> listeners = Arrays.asList(trainingListenerArr);
        setListeners(listeners);
    }

    /**
     * Replaces the registered training listeners with the contents of the given collection.
     *
     * <p>The internal listener collection is created if absent, otherwise cleared, and then
     * filled from {@code collection}. Passing {@code null} or an empty collection leaves the
     * layer with an empty (non-null) listener collection.
     *
     * @param collection listeners to register; may be {@code null} or empty
     */
    @Override // org.deeplearning4j.nn.api.Layer, org.deeplearning4j.nn.api.Model
    public void setListeners(Collection<TrainingListener> collection) {
        // Reset once; the previous code repeated this create-or-clear step twice in a row.
        if (this.trainingListeners == null) {
            this.trainingListeners = new ArrayList<>();
        } else {
            this.trainingListeners.clear();
        }
        if (collection == null || collection.isEmpty()) {
            return;
        }
        this.trainingListeners.addAll(collection);
    }

    /**
     * Appends the given listeners to the registered ones. If no listener collection exists
     * yet, this is equivalent to {@code setListeners(trainingListenerArr)}.
     *
     * @param trainingListenerArr listeners to add, in order
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void addListeners(TrainingListener... trainingListenerArr) {
        if (this.trainingListeners == null) {
            setListeners(trainingListenerArr);
        } else {
            this.trainingListeners.addAll(Arrays.asList(trainingListenerArr));
        }
    }

    /**
     * Stores the index assigned to this layer.
     *
     * @param i the new index value
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setIndex(int i) {
        index = i;
    }

    /**
     * Returns the index previously stored for this layer.
     *
     * @return the current value of the {@code index} field
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public int getIndex() {
        return index;
    }

    /**
     * Stores the given array as this layer's input.
     *
     * @param iNDArray          the new input; may be {@code null} to clear it
     * @param layerWorkspaceMgr not used by this implementation
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setInput(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        input = iNDArray;
    }

    /**
     * No-op in this implementation: the supplied value is ignored.
     *
     * @param i ignored
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setInputMiniBatchSize(int i) {
    }

    /**
     * Returns the size of dimension 0 of the current layer input as an {@code int}.
     *
     * @return {@code input.size(0)} narrowed to {@code int}
     * @throws ND4JArraySizeException if that size exceeds {@code Integer.MAX_VALUE}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public int getInputMiniBatchSize() {
        final long batchSize = this.input.size(0);
        if (batchSize > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        return (int) batchSize;
    }

    /**
     * Stores the given mask array on this layer.
     *
     * @param iNDArray the mask array to keep; may be {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setMaskArray(INDArray iNDArray) {
        maskArray = iNDArray;
    }

    /**
     * Returns the mask array previously stored on this layer.
     *
     * @return the current value of the {@code maskArray} field
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray getMaskArray() {
        return maskArray;
    }

    /**
     * Reports whether this layer is a pretrain layer.
     *
     * @return always {@code true}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public boolean isPretrainLayer() {
        return true;
    }

    /**
     * Removes every entry from the {@code weightNoiseParams} collection.
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void clearNoiseWeightParams() {
        weightNoiseParams.clear();
    }

    /**
     * No-op in this implementation: the supplied flag is ignored.
     *
     * @param z ignored
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void allowInputModification(boolean z) {
    }

    /**
     * Not supported by this layer.
     *
     * @param iNDArray  unused
     * @param maskState unused
     * @param i         unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray iNDArray, MaskState maskState, int i) {
        final String message = "Not yet implemented " + layerId();
        throw new UnsupportedOperationException(message);
    }

    /**
     * Returns the layer helper for this layer.
     *
     * @return always {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public LayerHelper getHelper() {
        return null;
    }

    /**
     * Fits this layer on the currently set input.
     *
     * <p>The solver is created on first use, inside the scope returned by
     * {@code scopeOutOfWorkspaces()}, and reused on later calls. The solver's optimizer is
     * then stored in {@code optimizer} and optimization is run with a no-workspaces manager.
     *
     * @throws IllegalStateException if the layer input has not been set
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit() {
        if (this.input == null) {
            throw new IllegalStateException("Cannot fit layer: layer input is null (not set) " + layerId());
        }
        if (this.solver == null) {
            // try-with-resources closes the scope exactly once on every path (and tolerates a null resource).
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                this.solver = new Solver.Builder().model(this).configure(conf()).listeners(getListeners()).build();
            }
        }
        this.optimizer = this.solver.getOptimizer();
        this.solver.optimize(LayerWorkspaceMgr.noWorkspaces());
    }

    /**
     * Computes the reconstruction probability as the exponential of
     * {@code reconstructionLogProbability(iNDArray, i)}, after casting that result to
     * {@code DataType.DOUBLE}.
     *
     * @param iNDArray data to evaluate
     * @param i        number of samples, forwarded to {@code reconstructionLogProbability}
     * @return the array returned by {@code Transforms.exp} for the cast log probabilities
     */
    public INDArray reconstructionProbability(INDArray iNDArray, int i) {
        final INDArray logProbability = reconstructionLogProbability(iNDArray, i);
        final INDArray asDouble = logProbability.castTo(DataType.DOUBLE);
        return Transforms.exp(asDouble, false);
    }

    /**
     * Estimates the reconstruction log probability of the given data by sampling.
     *
     * <p>The data is cast to the layer data type and set as the layer input, the encoder is
     * run forward, and then {@code i} latent samples are drawn ({@code randn * sigma + mean}),
     * pushed through the decoder layers and the p(x|z) layer, and scored with
     * {@code reconstructionDistribution.exampleNegLogProbability}. The per-sample scores are
     * summed and divided in place by {@code -i}.
     *
     * <p>Side effect: the layer input is replaced during the call and reset to {@code null}
     * before a normal return.
     *
     * @param iNDArray data to evaluate
     * @param i        number of samples; must be greater than zero
     * @return the summed negative log probabilities divided by {@code -i}
     * @throws IllegalArgumentException      if {@code i <= 0}
     * @throws UnsupportedOperationException if the reconstruction distribution is a
     *                                       {@code LossFunctionWrapper}
     */
    public INDArray reconstructionLogProbability(INDArray iNDArray, int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Invalid input: numSamples must be > 0. Got: " + i + " " + layerId());
        }
        if (this.reconstructionDistribution instanceof LossFunctionWrapper) {
            throw new UnsupportedOperationException("Cannot calculate reconstruction log probability when using a LossFunction (via LossFunctionWrapper) instead of a ReconstructionDistribution: ILossFunction instances are not in general probabilistic, hence it is not possible to calculate reconstruction probability " + layerId());
        }

        final INDArray data = iNDArray.castTo(this.dataType);
        final LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces();
        setInput(data, mgr);
        final VAEFwdHelper fwd = doForward(true, true, mgr);
        final IActivation hiddenActivationFn = layerConf().getActivationFn();

        // p(z|x): mean comes from the forward helper; the log-std^2 branch is computed here
        // from the last encoder activation.
        final INDArray logStd2Weights = getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_W, false, mgr);
        final INDArray logStd2Bias = getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_B, false, mgr);
        final INDArray pzxMean = fwd.pzxMeanPreOut;
        final INDArray lastEncoderActivation = fwd.encoderActivations[fwd.encoderActivations.length - 1];
        final INDArray pzxLogStd2 = lastEncoderActivation.mmul(logStd2Weights).addiRowVector(logStd2Bias);
        this.pzxActivationFn.getActivation(pzxMean, false);
        this.pzxActivationFn.getActivation(pzxLogStd2, false);
        // exp followed by sqrt; the sqrt result is not reassigned, so the `false` flag is
        // presumably an in-place (no-dup) request -- TODO confirm against the Transforms API.
        final INDArray pzxSigma = Transforms.exp(pzxLogStd2, false);
        Transforms.sqrt(pzxSigma, false);

        final long batchSize = this.input.size(0);
        final long latentSize = fwd.pzxMeanPreOut.size(1);
        final INDArray pxzWeights = getParamWithNoise(VariationalAutoencoderParamInitializer.PXZ_W, false, mgr);
        final INDArray pxzBias = getParamWithNoise(VariationalAutoencoderParamInitializer.PXZ_B, false, mgr);

        // Look up the decoder parameters once, outside the sampling loop.
        final int decoderDepth = this.decoderLayerSizes.length;
        final INDArray[] decoderWeights = new INDArray[decoderDepth];
        final INDArray[] decoderBiases = new INDArray[decoderDepth];
        for (int layer = 0; layer < decoderDepth; layer++) {
            decoderWeights[layer] = getParamWithNoise(VariationalAutoencoderParamInitializer.DECODER_PREFIX + layer + "W", false, mgr);
            decoderBiases[layer] = getParamWithNoise(VariationalAutoencoderParamInitializer.DECODER_PREFIX + layer + "b", false, mgr);
        }

        INDArray negLogProbSum = null;
        for (int sample = 0; sample < i; sample++) {
            INDArray hidden = Nd4j.randn(this.dataType, new long[]{batchSize, latentSize}).muli(pzxSigma).addi(pzxMean);
            for (int layer = 0; layer < decoderDepth; layer++) {
                hidden = hidden.mmul(decoderWeights[layer]).addiRowVector(decoderBiases[layer]);
                hiddenActivationFn.getActivation(hidden, false);
            }
            final INDArray pxzPreOut = hidden.mmul(pxzWeights).addiRowVector(pxzBias);
            final INDArray negLogProb = this.reconstructionDistribution.exampleNegLogProbability(data, pxzPreOut);
            if (sample == 0) {
                negLogProbSum = negLogProb;
            } else {
                negLogProbSum.addi(negLogProb);
            }
        }

        setInput(null, mgr);
        return negLogProbSum.divi(-i);
    }

    /**
     * Decodes the given latent space values (using a no-workspaces manager) and asks the
     * reconstruction distribution to generate output at the mean.
     *
     * @param iNDArray latent space values
     * @return the result of {@code reconstructionDistribution.generateAtMean} on the decoded values
     */
    public INDArray generateAtMeanGivenZ(INDArray iNDArray) {
        final INDArray decoded = decodeGivenLatentSpaceValues(iNDArray, LayerWorkspaceMgr.noWorkspaces());
        return this.reconstructionDistribution.generateAtMean(decoded);
    }

    /**
     * Decodes the given latent space values and asks the reconstruction distribution to
     * generate a random output from them.
     *
     * @param iNDArray          latent space values
     * @param layerWorkspaceMgr workspace manager used while decoding
     * @return the result of {@code reconstructionDistribution.generateRandom} on the decoded values
     */
    public INDArray generateRandomGivenZ(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        final INDArray decoded = decodeGivenLatentSpaceValues(iNDArray, layerWorkspaceMgr);
        return this.reconstructionDistribution.generateRandom(decoded);
    }

    /**
     * Runs the decoder on the given latent space values and returns the p(x|z) pre-output.
     *
     * <p>Dimension 1 of the input must equal dimension 1 of the p(z|x) mean weight matrix.
     * Each decoder layer applies weights, bias and the configured activation function; the
     * final p(x|z) layer applies weights and bias only.
     *
     * @param iNDArray          latent space values
     * @param layerWorkspaceMgr workspace manager used for parameter lookup
     * @return the p(x|z) layer pre-output
     * @throws IllegalArgumentException if the latent dimension does not match
     */
    private INDArray decodeGivenLatentSpaceValues(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        final boolean sizeMatches = iNDArray.size(1) == getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_MEAN_W, false, layerWorkspaceMgr).size(1);
        if (!sizeMatches) {
            throw new IllegalArgumentException("Invalid latent space values: expected size " + getParamWithNoise(VariationalAutoencoderParamInitializer.PZX_MEAN_W, false, layerWorkspaceMgr).size(1) + ", got size (dimension 1) = " + iNDArray.size(1) + " " + layerId());
        }

        final int decoderDepth = this.decoderLayerSizes.length;
        INDArray hidden = iNDArray;
        final IActivation activationFn = layerConf().getActivationFn();
        for (int layer = 0; layer < decoderDepth; layer++) {
            final String weightKey = VariationalAutoencoderParamInitializer.DECODER_PREFIX + layer + "W";
            final String biasKey = VariationalAutoencoderParamInitializer.DECODER_PREFIX + layer + "b";
            final INDArray product = hidden.mmul(getParamWithNoise(weightKey, false, layerWorkspaceMgr));
            hidden = product.addiRowVector(getParamWithNoise(biasKey, false, layerWorkspaceMgr));
            activationFn.getActivation(hidden, false);
        }

        final INDArray pxzProduct = hidden.mmul(getParamWithNoise(VariationalAutoencoderParamInitializer.PXZ_W, false, layerWorkspaceMgr));
        return pxzProduct.addiRowVector(getParamWithNoise(VariationalAutoencoderParamInitializer.PXZ_B, false, layerWorkspaceMgr));
    }

    /**
     * Reports whether the configured reconstruction distribution has a loss function.
     *
     * @return the value of {@code reconstructionDistribution.hasLossFunction()}
     */
    public boolean hasLossFunction() {
        return reconstructionDistribution.hasLossFunction();
    }

    /**
     * Computes the reconstruction error of the given data using the configured loss function.
     *
     * <p>The data is passed through {@code activate} (which leaves it set as the layer input),
     * the result is decoded at the mean, and the reconstruction is scored against the original
     * data: via {@code computeLossFunctionScoreArray} for a composite distribution, otherwise
     * via the wrapped loss function with an identity activation and no mask.
     *
     * @param iNDArray data to reconstruct and score
     * @return the score array
     * @throws IllegalStateException if {@code hasLossFunction()} is {@code false}
     */
    public INDArray reconstructionError(INDArray iNDArray) {
        if (!hasLossFunction()) {
            throw new IllegalStateException("Cannot use reconstructionError method unless the variational autoencoder is configured with a standard loss function (via LossFunctionWrapper). For VAEs utilizing a reconstruction distribution, use the reconstructionProbability or reconstructionLogProbability methods " + layerId());
        }
        final INDArray latent = activate(iNDArray, false, LayerWorkspaceMgr.noWorkspaces());
        final INDArray reconstruction = generateAtMeanGivenZ(latent);
        if (this.reconstructionDistribution instanceof CompositeReconstructionDistribution) {
            final CompositeReconstructionDistribution composite = (CompositeReconstructionDistribution) this.reconstructionDistribution;
            return composite.computeLossFunctionScoreArray(iNDArray, reconstruction);
        }
        final LossFunctionWrapper wrapper = (LossFunctionWrapper) this.reconstructionDistribution;
        return wrapper.getLossFunction().computeScoreArray(iNDArray, reconstruction, new ActivationIdentity(), (INDArray) null);
    }

    /**
     * Verifies that the layer input has been set.
     *
     * @param z selects the wording of the failure message: {@code true} for backprop,
     *          {@code false} for the forward pass
     * @throws IllegalStateException if the layer input is {@code null}
     */
    public void assertInputSet(boolean z) {
        if (this.input != null) {
            return;
        }
        final String prefix = z ? "Cannot perform backprop in layer " : "Cannot perform forward pass in layer ";
        throw new IllegalStateException(prefix + getClass().getSimpleName() + ": layer input field is not set");
    }

    /**
     * Returns the map of gradient view arrays keyed by parameter name.
     *
     * @return the internal {@code gradientViews} map itself (not a copy)
     */
    public Map<String, INDArray> getGradientViews() {
        return gradientViews;
    }

    /**
     * Returns the iteration count stored on this layer.
     *
     * @return the current value of the {@code iterationCount} field
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public int getIterationCount() {
        return iterationCount;
    }

    /**
     * Stores the iteration count on this layer.
     *
     * @param i the new iteration count
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setIterationCount(int i) {
        iterationCount = i;
    }

    /**
     * Returns the epoch count stored on this layer.
     *
     * @return the current value of the {@code epochCount} field
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public int getEpochCount() {
        return epochCount;
    }

    /**
     * Stores the epoch count on this layer.
     *
     * @param i the new epoch count
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setEpochCount(int i) {
        epochCount = i;
    }
}
