package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.ArrayList;
import java.util.Objects;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.class */
public class KerasLoss extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasLoss.class);
    private final String KERAS_CLASS_NAME_LOSS = "Loss";
    private LossFunctions.LossFunction loss;

    public KerasLoss(String str, String str2, String str3) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        this(str, str2, str3, true);
    }

    public KerasLoss(String str, String str2, String str3, boolean z) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        this.KERAS_CLASS_NAME_LOSS = "Loss";
        this.className = "Loss";
        this.layerName = str;
        this.inputShape = null;
        this.dimOrder = KerasLayer.DimOrder.NONE;
        this.inboundLayerNames = new ArrayList();
        this.inboundLayerNames.add(str2);
        try {
            this.loss = KerasLossUtils.mapLossFunction(str3, this.conf);
        } catch (UnsupportedKerasConfigurationException e) {
            if (z) {
                throw e;
            }
            log.warn("Unsupported Keras loss function. Replacing with MSE.");
            this.loss = LossFunctions.LossFunction.SQUARED_LOSS;
        }
    }

    public FeedForwardLayer getLossLayer(InputType inputType) throws UnsupportedKerasConfigurationException {
        if (inputType instanceof InputType.InputTypeFeedForward) {
            this.layer = new LossLayer.Builder(this.loss).name(this.layerName).activation(Activation.IDENTITY).build();
        } else if (inputType instanceof InputType.InputTypeRecurrent) {
            this.layer = new RnnLossLayer.Builder(this.loss).name(this.layerName).activation(Activation.IDENTITY).build();
        } else {
            if (!(inputType instanceof InputType.InputTypeConvolutional)) {
                throw new UnsupportedKerasConfigurationException("Unsupported output layer typegot : " + inputType.toString());
            }
            this.layer = new CnnLossLayer.Builder(this.loss).name(this.layerName).activation(Activation.IDENTITY).build();
        }
        return this.layer;
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public InputType getOutputType(InputType... inputTypeArr) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (inputTypeArr.length > 1) {
            throw new InvalidKerasConfigurationException("Keras Loss layer accepts only one input (received " + inputTypeArr.length + ")");
        }
        return getLossLayer(inputTypeArr[0]).getOutputType(-1, inputTypeArr[0]);
    }

    public String getKERAS_CLASS_NAME_LOSS() {
        getClass();
        return "Loss";
    }

    public LossFunctions.LossFunction getLoss() {
        return this.loss;
    }

    public void setLoss(LossFunctions.LossFunction lossFunction) {
        this.loss = lossFunction;
    }

    public String toString() {
        return "KerasLoss(KERAS_CLASS_NAME_LOSS=" + getKERAS_CLASS_NAME_LOSS() + ", loss=" + getLoss() + ")";
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof KerasLoss)) {
            return false;
        }
        KerasLoss kerasLoss = (KerasLoss) obj;
        if (!kerasLoss.canEqual(this)) {
            return false;
        }
        String keras_class_name_loss = getKERAS_CLASS_NAME_LOSS();
        String keras_class_name_loss2 = kerasLoss.getKERAS_CLASS_NAME_LOSS();
        if (keras_class_name_loss == null) {
            if (keras_class_name_loss2 != null) {
                return false;
            }
        } else if (!keras_class_name_loss.equals(keras_class_name_loss2)) {
            return false;
        }
        LossFunctions.LossFunction loss = getLoss();
        LossFunctions.LossFunction loss2 = kerasLoss.getLoss();
        return loss == null ? loss2 == null : loss.equals(loss2);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof KerasLoss;
    }

    public int hashCode() {
        String keras_class_name_loss = getKERAS_CLASS_NAME_LOSS();
        int hashCode = (1 * 59) + (keras_class_name_loss == null ? 43 : keras_class_name_loss.hashCode());
        LossFunctions.LossFunction loss = getLoss();
        return (hashCode * 59) + (loss == null ? 43 : loss.hashCode());
    }
}
