package org.deeplearning4j.nn.weights.embeddings;

import java.util.Objects;
import lombok.NonNull;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Weight initializer that fills an embedding weight matrix from an {@link EmbeddingInitializer}.
 *
 * <p>The wrapped initializer is stored in exactly one of two fields, chosen by
 * {@link EmbeddingInitializer#jsonSerializable()}:
 * <ul>
 *   <li>{@code serializableInit} - kept for the lifetime of this object;</li>
 *   <li>{@code nonSerializableInit} - excluded from JSON (see the class-level
 *       {@code @JsonIgnoreProperties}) and released (set to {@code null}) after the first call
 *       to {@link #init(double, double, long[], char, INDArray)}.</li>
 * </ul>
 *
 * <p>Because {@code nonSerializableInit} is cleared by {@code init}, the results of
 * {@link #equals(Object)}, {@link #hashCode()} and {@link #shape()} can change after
 * initialization when a non-serializable initializer was supplied.
 */
@JsonIgnoreProperties({"nonSerializableInit"})
public class WeightInitEmbedding implements IWeightInit {
    /** Initializer that reports itself as JSON serializable; {@code null} if none was supplied. */
    private EmbeddingInitializer serializableInit;
    /** Initializer that is not JSON serializable; cleared after the first {@code init} call. */
    private EmbeddingInitializer nonSerializableInit;

    /**
     * Wraps the given initializer, storing it as serializable or non-serializable according to
     * {@link EmbeddingInitializer#jsonSerializable()}.
     *
     * @param embeddingInitializer source of the embedding weights; must not be {@code null}
     * @throws NullPointerException if {@code embeddingInitializer} is {@code null}
     */
    public WeightInitEmbedding(@NonNull EmbeddingInitializer embeddingInitializer) {
        // The null check has to run inside the this(...) arguments: a check placed after the
        // delegation would never be reached, since the arguments already dereference the value.
        this(selectBySerializability(embeddingInitializer, true), selectBySerializability(embeddingInitializer, false));
    }

    /**
     * Creates an instance from explicitly separated initializers.
     *
     * @param embeddingInitializer  value for {@code serializableInit}; may be {@code null}
     * @param embeddingInitializer2 value for {@code nonSerializableInit}; may be {@code null}
     */
    protected WeightInitEmbedding(@JsonProperty("serializableInit") EmbeddingInitializer embeddingInitializer, @JsonProperty("nonSerializableInit") EmbeddingInitializer embeddingInitializer2) {
        this.serializableInit = embeddingInitializer;
        this.nonSerializableInit = embeddingInitializer2;
    }

    /**
     * Returns {@code initializer} when its {@code jsonSerializable()} flag equals
     * {@code wantSerializable}, otherwise {@code null}.
     *
     * @throws NullPointerException if {@code initializer} is {@code null}
     */
    private static EmbeddingInitializer selectBySerializability(EmbeddingInitializer initializer, boolean wantSerializable) {
        Objects.requireNonNull(initializer, "embeddingInitializer is marked non-null but is null");
        return initializer.jsonSerializable() == wantSerializable ? initializer : null;
    }

    /** Returns the serializable initializer if present, else the non-serializable one (may be {@code null}). */
    private EmbeddingInitializer activeInitializer() {
        return this.serializableInit != null ? this.serializableInit : this.nonSerializableInit;
    }

    /**
     * Loads the embedding weights into the supplied parameter array.
     *
     * <p>The array is reshaped to {@code jArr} in 'c' order, filled via
     * {@link EmbeddingInitializer#loadWeightsInto(INDArray)}, and the reshaped array is returned.
     * Afterwards the non-serializable initializer (if any) is released, so a second call on an
     * instance that only held a non-serializable initializer fails with
     * {@link IllegalStateException}.
     *
     * @param d        not used by this implementation
     * @param d2       not used by this implementation
     * @param jArr     target shape; index 0 must equal the initializer's vocab size and index 1
     *                 its vector size
     * @param c        not used by this implementation (the reshape always uses 'c' order)
     * @param iNDArray parameter array to reshape and fill
     * @return the reshaped array that the weights were loaded into
     * @throws IllegalStateException    if no initializer is available, or the shape does not
     *                                  match the initializer's vocab/vector size
     * @throws IllegalArgumentException if {@code jArr} is {@code null} or has fewer than two
     *                                  entries
     */
    @Override // org.deeplearning4j.nn.weights.IWeightInit
    public INDArray init(double d, double d2, long[] jArr, char c, INDArray iNDArray) {
        EmbeddingInitializer initializer = activeInitializer();
        if (initializer == null) {
            throw new IllegalStateException("Cannot initialize embedding layer weights: no EmbeddingInitializer is available."
                    + " This can occur if you save network configuration, load it, and then try to initialize the network"
                    + " parameters: an EmbeddingInitializer that is not JSON serializable is not stored with the configuration."
                    + " It can also occur if the weights are initialized a second time, because a non-serializable"
                    + " EmbeddingInitializer is released after its first use.");
        }
        if (jArr == null || jArr.length < 2) {
            throw new IllegalArgumentException("Embedding weight shape must have at least 2 dimensions [vocabSize, vectorSize], got "
                    + (jArr == null ? "null" : "an array of length " + jArr.length));
        }
        Preconditions.checkState(jArr[0] == initializer.vocabSize(), "Parameters shape[0]=%s does not match embedding initializer vocab size of %s", jArr[0], initializer.vocabSize());
        Preconditions.checkState(jArr[1] == ((long) initializer.vectorSize()), "Parameters shape[1]=%s does not match embedding initializer vector size of %s", jArr[1], initializer.vectorSize());
        INDArray weights = iNDArray.reshape('c', jArr);
        initializer.loadWeightsInto(weights);
        // Drop the reference to the non-serializable initializer once its weights have been copied.
        this.nonSerializableInit = null;
        return weights;
    }

    /**
     * Returns the weight shape implied by the wrapped initializer.
     *
     * @return {@code {vocabSize, vectorSize}} of the available initializer, or {@code null} when
     *         no initializer is held (callers must handle the {@code null} case)
     */
    public long[] shape() {
        EmbeddingInitializer initializer = activeInitializer();
        if (initializer == null) {
            return null;
        }
        return new long[]{initializer.vocabSize(), initializer.vectorSize()};
    }

    /**
     * Two instances are equal when both their serializable and non-serializable initializers are
     * equal (or both absent), and the other object agrees via {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof WeightInitEmbedding)) {
            return false;
        }
        WeightInitEmbedding other = (WeightInitEmbedding) obj;
        return other.canEqual(this)
                && Objects.equals(this.serializableInit, other.serializableInit)
                && Objects.equals(this.nonSerializableInit, other.nonSerializableInit);
    }

    /**
     * Hook used by {@link #equals(Object)} so that subclasses can refuse equality with instances
     * of this class.
     *
     * @param obj the object attempting the comparison
     * @return {@code true} if {@code obj} is a {@code WeightInitEmbedding}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof WeightInitEmbedding;
    }

    /** Hash over both initializer fields, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        // Same prime (59) and null placeholder (43) as before, so hash values are unchanged.
        int result = 59 + hashOrPlaceholder(this.serializableInit);
        return (result * 59) + hashOrPlaceholder(this.nonSerializableInit);
    }

    /** Returns {@code value.hashCode()}, or 43 for {@code null}. */
    private static int hashOrPlaceholder(Object value) {
        return value == null ? 43 : value.hashCode();
    }
}
