package org.deeplearning4j.nn.params;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/DefaultParamInitializer.class */
public class DefaultParamInitializer implements ParamInitializer {
    private static final DefaultParamInitializer INSTANCE = new DefaultParamInitializer();
    public static final String WEIGHT_KEY = "W";
    public static final String BIAS_KEY = "b";
    public static final String GAIN_KEY = "g";

    public static DefaultParamInitializer getInstance() {
        return INSTANCE;
    }

    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public long numParams(NeuralNetConfiguration neuralNetConfiguration) {
        return numParams(neuralNetConfiguration.getLayer());
    }

    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public long numParams(Layer layer) {
        FeedForwardLayer feedForwardLayer = (FeedForwardLayer) layer;
        long nIn = feedForwardLayer.getNIn();
        long nOut = feedForwardLayer.getNOut();
        return (nIn * nOut) + (hasBias(layer) ? nOut : 0L) + (hasLayerNorm(layer) ? nOut : 0L);
    }

    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> paramKeys(Layer layer) {
        ArrayList arrayList = new ArrayList(3);
        arrayList.addAll(weightKeys(layer));
        arrayList.addAll(biasKeys(layer));
        return arrayList;
    }

    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> weightKeys(Layer layer) {
        return hasLayerNorm(layer) ? Arrays.asList("W", "g") : Collections.singletonList("W");
    }

    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> biasKeys(Layer layer) {
        return hasBias(layer) ? Collections.singletonList("b") : Collections.emptyList();
    }

    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public boolean isWeightParam(Layer layer, String str) {
        return "W".equals(str) || (hasLayerNorm(layer) && "g".equals(str));
    }

    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public boolean isBiasParam(Layer layer, String str) {
        return "b".equals(str);
    }

    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> init(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        if (!(neuralNetConfiguration.getLayer() instanceof FeedForwardLayer)) {
            throw new IllegalArgumentException("unsupported layer type: " + neuralNetConfiguration.getLayer().getClass().getName());
        }
        Map<String, INDArray> synchronizedMap = Collections.synchronizedMap(new LinkedHashMap());
        long numParams = numParams(neuralNetConfiguration);
        if (iNDArray.length() != numParams) {
            throw new IllegalStateException("Expected params view of length " + numParams + ", got length " + iNDArray.length());
        }
        FeedForwardLayer feedForwardLayer = (FeedForwardLayer) neuralNetConfiguration.getLayer();
        long nIn = feedForwardLayer.getNIn();
        long nOut = feedForwardLayer.getNOut();
        long j = nIn * nOut;
        synchronizedMap.put("W", createWeightMatrix(neuralNetConfiguration, iNDArray.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(0L, j)}), z));
        neuralNetConfiguration.addVariable("W");
        long j2 = j;
        if (hasBias(feedForwardLayer)) {
            synchronizedMap.put("b", createBias(neuralNetConfiguration, iNDArray.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(j2, j2 + nOut)}), z));
            neuralNetConfiguration.addVariable("b");
            j2 += nOut;
        }
        if (hasLayerNorm(feedForwardLayer)) {
            synchronizedMap.put("g", createGain(neuralNetConfiguration, iNDArray.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(j2, j2 + nOut)}), z));
            neuralNetConfiguration.addVariable("g");
        }
        return synchronizedMap;
    }

    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        FeedForwardLayer feedForwardLayer = (FeedForwardLayer) neuralNetConfiguration.getLayer();
        long nIn = feedForwardLayer.getNIn();
        long nOut = feedForwardLayer.getNOut();
        long j = nIn * nOut;
        INDArray reshape = iNDArray.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(0L, j)}).reshape('f', new long[]{nIn, nOut});
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        linkedHashMap.put("W", reshape);
        long j2 = j;
        if (hasBias(feedForwardLayer)) {
            linkedHashMap.put("b", iNDArray.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(j2, j2 + nOut)}));
            j2 += nOut;
        }
        if (hasLayerNorm(feedForwardLayer)) {
            linkedHashMap.put("g", iNDArray.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(j2, j2 + nOut)}));
        }
        return linkedHashMap;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public INDArray createBias(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        FeedForwardLayer feedForwardLayer = (FeedForwardLayer) neuralNetConfiguration.getLayer();
        return createBias(feedForwardLayer.getNOut(), feedForwardLayer.getBiasInit(), iNDArray, z);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public INDArray createBias(long j, double d, INDArray iNDArray, boolean z) {
        if (z) {
            iNDArray.assign(Double.valueOf(d));
        }
        return iNDArray;
    }

    protected INDArray createGain(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        FeedForwardLayer feedForwardLayer = (FeedForwardLayer) neuralNetConfiguration.getLayer();
        return createGain(feedForwardLayer.getNOut(), feedForwardLayer.getGainInit(), iNDArray, z);
    }

    protected INDArray createGain(long j, double d, INDArray iNDArray, boolean z) {
        if (z) {
            iNDArray.assign(Double.valueOf(d));
        }
        return iNDArray;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public INDArray createWeightMatrix(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        FeedForwardLayer feedForwardLayer = (FeedForwardLayer) neuralNetConfiguration.getLayer();
        return z ? createWeightMatrix(feedForwardLayer.getNIn(), feedForwardLayer.getNOut(), feedForwardLayer.getWeightInitFn(), iNDArray, true) : createWeightMatrix(feedForwardLayer.getNIn(), feedForwardLayer.getNOut(), null, iNDArray, false);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public INDArray createWeightMatrix(long j, long j2, IWeightInit iWeightInit, INDArray iNDArray, boolean z) {
        long[] jArr = {j, j2};
        return z ? iWeightInit.init(j, j2, jArr, 'f', iNDArray) : WeightInitUtil.reshapeWeights(jArr, iNDArray);
    }

    protected boolean hasBias(Layer layer) {
        if (layer instanceof BaseOutputLayer) {
            return ((BaseOutputLayer) layer).hasBias();
        }
        if (layer instanceof DenseLayer) {
            return ((DenseLayer) layer).hasBias();
        }
        if (layer instanceof EmbeddingLayer) {
            return ((EmbeddingLayer) layer).hasBias();
        }
        if (layer instanceof EmbeddingSequenceLayer) {
            return ((EmbeddingSequenceLayer) layer).hasBias();
        }
        return true;
    }

    protected boolean hasLayerNorm(Layer layer) {
        if (layer instanceof DenseLayer) {
            return ((DenseLayer) layer).hasLayerNorm();
        }
        return false;
    }
}
