package org.deeplearning4j.nn.conf.dropout;

import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Dropout applied to layer activations.
 *
 * <p>{@code p} is the probability of <em>retaining</em> an activation. It is either fixed, or
 * supplied per iteration/epoch by an {@link ISchedule}. When a schedule is present it takes
 * precedence, and {@code p} is NaN for schedule-only instances.
 *
 * <p>On the CUDA backend, a {@code CudnnDropoutHelper} is used if it can be loaded and reports
 * itself as supported. Otherwise, forward passes build a mask with ND4J's {@code DropOutInverted}
 * op and multiply it into the output. The mask is kept in {@link #getMask()} so that
 * {@link #backprop} can reuse it.
 *
 * <p>NOTE(review): the mask is per-instance mutable state, so one instance presumably should not
 * serve concurrent forward/backward passes — confirm against callers.
 */
@JsonIgnoreProperties({"mask", "helper"})
public class Dropout implements IDropout {
    private static final Logger log = LoggerFactory.getLogger(Dropout.class);

    /** Fixed activation retain probability; NaN when a schedule is used instead. */
    private double p;
    /** Optional schedule for the retain probability; overrides {@link #p} when non-null. */
    private ISchedule pSchedule;
    /** Mask from the most recent non-helper forward pass; cleared by backprop()/clear(). */
    private transient INDArray mask;
    /** Optional accelerated implementation; null when unavailable or unsupported. */
    private transient DropoutHelper helper;

    /**
     * Creates a dropout instance with a fixed activation retain probability.
     *
     * @param d probability of retaining each activation; must be in (0, 1]
     * @throws IllegalArgumentException if {@code d} is NaN, not strictly positive, or above 1
     */
    public Dropout(double d) {
        // Validate before delegating so no helper initialization happens for invalid values.
        this(validateRetainProbability(d), null);
    }

    /**
     * Creates a dropout instance whose retain probability comes from a schedule.
     *
     * @param iSchedule schedule queried with (iteration, epoch) on each forward pass
     */
    public Dropout(ISchedule iSchedule) {
        this(Double.NaN, iSchedule);
    }

    /**
     * Full constructor, also used by Jackson for deserialization. It performs no validation.
     *
     * @param d fixed retain probability, or NaN when {@code iSchedule} is used
     * @param iSchedule optional retain-probability schedule
     */
    protected Dropout(@JsonProperty("p") double d, @JsonProperty("pSchedule") ISchedule iSchedule) {
        this.p = d;
        this.pSchedule = iSchedule;
        initializeHelper();
    }

    /** Rejects retain probabilities outside (0, 1]; returns the value unchanged otherwise. */
    private static double validateRetainProbability(double d) {
        if (Double.isNaN(d)) {
            throw new IllegalArgumentException("Activation retain probability must not be NaN");
        }
        if (d < 0.0) {
            throw new IllegalArgumentException("Activation retain probability must be > 0. Got: " + d);
        }
        if (d == 0.0) {
            throw new IllegalArgumentException("Invalid probability value: Dropout with 0.0 probability of retaining activations is not supported");
        }
        if (d > 1.0) {
            throw new IllegalArgumentException("Activation retain probability must be <= 1.0. Got: " + d);
        }
        return d;
    }

    /**
     * Tries to install the cuDNN dropout helper when running on the CUDA backend.
     *
     * <p>This is best-effort. If the helper class is absent, the built-in implementation is used
     * silently. Any other failure is logged and also falls back. The helper is only published
     * after {@code checkSupported()} returns true, so a failure there cannot leave a
     * half-initialized helper installed.
     */
    protected void initializeHelper() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        if (!"CUDA".equalsIgnoreCase(backend)) {
            return;
        }
        try {
            DropoutHelper candidate = Class.forName("org.deeplearning4j.nn.layers.dropout.CudnnDropoutHelper")
                    .asSubclass(DropoutHelper.class)
                    .getDeclaredConstructor()
                    .newInstance();
            if (candidate.checkSupported()) {
                this.helper = candidate;
                log.debug("CudnnDropoutHelper successfully initialized");
            } else {
                this.helper = null;
            }
        } catch (ClassNotFoundException e) {
            // cuDNN module not on the classpath: keep the built-in implementation, nothing to report.
        } catch (Throwable th) {
            // Throwable (not Exception) because native-library problems surface as Errors,
            // e.g. UnsatisfiedLinkError; dropout must still work without the helper.
            this.helper = null;
            log.warn("Could not initialize CudnnDropoutHelper", th);
        }
    }

    /**
     * Applies dropout to {@code input}, writing the result into {@code output}.
     *
     * @param input activations to apply dropout to; cast to the output's data type if it differs
     * @param output destination array; must have a floating-point data type
     * @param iteration current iteration, passed to the schedule if one is set
     * @param epoch current epoch, passed to the schedule if one is set
     * @param workspaceMgr workspace manager used to allocate the mask (as {@code ArrayType.INPUT})
     * @return {@code output}
     */
    @Override
    public INDArray applyDropout(INDArray input, INDArray output, int iteration, int epoch, LayerWorkspaceMgr workspaceMgr) {
        Preconditions.checkState(output.dataType().isFPType(), "Output array must be a floating point type, got %s for array of shape %ndShape", output.dataType(), output);
        double retainProb = this.pSchedule != null ? this.pSchedule.valueAt(iteration, epoch) : this.p;
        if (this.helper != null) {
            // Use the effective probability. The raw field is NaN for schedule-based instances
            // and ignores the schedule otherwise.
            this.helper.applyDropout(input, output, retainProb);
            return output;
        }
        INDArray source = input;
        if (source != output && source.dataType() != output.dataType()) {
            source = source.castTo(output.dataType());
        }
        // Start from an all-ones array and run DropOutInverted on it in place. The result is the
        // multiplicative mask, which is kept for backprop().
        this.mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), output.shape(), output.ordering()).assign(1.0);
        Nd4j.getExecutioner().exec(new DropOutInverted(this.mask, this.mask, retainProb));
        Nd4j.getExecutioner().exec(new OldMulOp(source, this.mask, output));
        return output;
    }

    /**
     * Back-propagates through the most recent forward pass.
     *
     * <p>Without a helper, this multiplies {@code gradAtOutput} by the stored mask into
     * {@code gradAtInput}, then clears the mask. A second call without a new forward pass
     * therefore fails.
     *
     * @param gradAtOutput incoming gradient
     * @param gradAtInput destination for the resulting gradient
     * @param iteration current iteration (unused here)
     * @param epoch current epoch (unused here)
     * @return {@code gradAtInput}
     * @throws IllegalStateException if no mask is present (no forward pass, or already cleared)
     */
    @Override
    public INDArray backprop(INDArray gradAtOutput, INDArray gradAtInput, int iteration, int epoch) {
        if (this.helper != null) {
            this.helper.backprop(gradAtOutput, gradAtInput);
            return gradAtInput;
        }
        Preconditions.checkState(this.mask != null, "Cannot perform backprop: Dropout mask array is absent (already cleared?)");
        INDArray maskForGrad = this.mask.dataType() == gradAtInput.dataType()
                ? this.mask
                : this.mask.castTo(gradAtInput.dataType());
        Nd4j.getExecutioner().exec(new OldMulOp(gradAtOutput, maskForGrad, gradAtInput));
        this.mask = null;
        return gradAtInput;
    }

    /** Drops the stored mask, if any. */
    @Override
    public void clear() {
        this.mask = null;
    }

    /**
     * Returns a copy with the same {@code p} and a cloned schedule. The transient mask is not
     * copied, and the helper is re-initialized.
     */
    @Override
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Dropout m43clone() {
        return new Dropout(this.p, this.pSchedule == null ? null : this.pSchedule.clone());
    }

    public double getP() {
        return this.p;
    }

    public ISchedule getPSchedule() {
        return this.pSchedule;
    }

    public INDArray getMask() {
        return this.mask;
    }

    public DropoutHelper getHelper() {
        return this.helper;
    }

    public void setP(double d) {
        this.p = d;
    }

    public void setPSchedule(ISchedule iSchedule) {
        this.pSchedule = iSchedule;
    }

    public void setMask(INDArray iNDArray) {
        this.mask = iNDArray;
    }

    public void setHelper(DropoutHelper dropoutHelper) {
        this.helper = dropoutHelper;
    }

    @Override
    public String toString() {
        return "Dropout(p=" + getP() + ", pSchedule=" + getPSchedule() + ", mask=" + getMask() + ", helper=" + getHelper() + ")";
    }

    /** Equality considers only {@code p} and the schedule; mask and helper are runtime state. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Dropout)) {
            return false;
        }
        Dropout other = (Dropout) obj;
        if (!other.canEqual(this) || Double.compare(getP(), other.getP()) != 0) {
            return false;
        }
        ISchedule mine = getPSchedule();
        ISchedule theirs = other.getPSchedule();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof Dropout;
    }

    @Override
    public int hashCode() {
        long pBits = Double.doubleToLongBits(getP());
        int result = (1 * 59) + ((int) ((pBits >>> 32) ^ pBits));
        ISchedule schedule = getPSchedule();
        return (result * 59) + (schedule == null ? 43 : schedule.hashCode());
    }
}
