package org.deeplearning4j.arbiter;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.arbiter.BaseNetworkSpace;
import org.deeplearning4j.arbiter.layers.LayerSpace;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper;
import org.deeplearning4j.arbiter.util.LeafUtils;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Hyperparameter search space for a list-style (multi-layer) network configuration.
 *
 * <p>Each point sampled from this space (see {@link #getValue(double[])}) produces a
 * {@link DL4JConfiguration} bundling the built network configuration, the optional early-stopping
 * configuration and the configured number of epochs.
 */
public class MultiLayerSpace extends BaseNetworkSpace<DL4JConfiguration> {

    /** Optional search space for the network input type; {@code null} means "not set". */
    @JsonProperty
    protected ParameterSpace<InputType> inputType;

    /** Optional search space for per-layer-index input preprocessors; {@code null} means "not set". */
    @JsonProperty
    protected ParameterSpace<Map<Integer, InputPreProcessor>> inputPreProcessors;

    /** Optional early-stopping configuration passed through unchanged to every sampled configuration. */
    @JsonProperty
    protected EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration;

    /** Total parameter count of the unique leaf spaces, computed once at construction time. */
    @JsonProperty
    protected int numParameters;

    /** Fluent builder for {@link MultiLayerSpace}. */
    public static class Builder extends BaseNetworkSpace.Builder<Builder> {
        protected List<BaseNetworkSpace.LayerConf> layerSpaces = new ArrayList<>();
        protected ParameterSpace<InputType> inputType;
        protected ParameterSpace<Map<Integer, InputPreProcessor>> inputPreProcessors;
        protected EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration;

        /**
         * Sets a fixed (non-searched) input type.
         *
         * @param inputType the input type used for every sampled configuration
         * @return this builder
         */
        public Builder setInputType(InputType inputType) {
            return setInputType(new FixedValue<InputType>(inputType));
        }

        /**
         * Sets the search space for the input type.
         *
         * @param inputType search space for the input type
         * @return this builder
         */
        public Builder setInputType(ParameterSpace<InputType> inputType) {
            this.inputType = inputType;
            return this;
        }

        /**
         * Adds a single layer space (layer count fixed at 1, config duplicated).
         *
         * @param layerSpace the layer space to append
         * @return this builder
         */
        public Builder addLayer(LayerSpace<?> layerSpace) {
            return addLayer(layerSpace, new FixedValue<Integer>(1), true);
        }

        /**
         * Appends a layer space, named {@code "layer_<index>"} by its position in the builder.
         *
         * @param layerSpace the layer space to append
         * @param numLayers search space for how many copies of this layer to create
         * @param duplicateConfig whether all copies share one sampled config; {@code false} is
         *     currently rejected by {@link MultiLayerSpace#getValue(double[])} with
         *     {@link UnsupportedOperationException}
         * @return this builder
         */
        public Builder addLayer(LayerSpace<? extends Layer> layerSpace, ParameterSpace<Integer> numLayers,
                                boolean duplicateConfig) {
            String layerName = "layer_" + this.layerSpaces.size();
            // NOTE(review): the third LayerConf argument is passed as null; its meaning is defined in
            // BaseNetworkSpace.LayerConf (not visible here).
            this.layerSpaces.add(new BaseNetworkSpace.LayerConf(layerSpace, layerName, null, numLayers, duplicateConfig));
            return this;
        }

        /**
         * Sets the early-stopping configuration attached to every sampled configuration.
         *
         * @param earlyStoppingConfiguration the early-stopping configuration (may be {@code null})
         * @return this builder
         */
        public Builder earlyStoppingConfiguration(EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration) {
            this.earlyStoppingConfiguration = earlyStoppingConfiguration;
            return this;
        }

        /**
         * Sets fixed (non-searched) input preprocessors keyed by layer index.
         *
         * @param inputPreProcessors preprocessors keyed by layer index
         * @return this builder
         */
        public Builder setInputPreProcessors(Map<Integer, InputPreProcessor> inputPreProcessors) {
            return setInputPreProcessors(new FixedValue<Map<Integer, InputPreProcessor>>(inputPreProcessors));
        }

        /**
         * Sets the search space for the input preprocessors.
         *
         * @param inputPreProcessors search space for preprocessors keyed by layer index
         * @return this builder
         */
        public Builder setInputPreProcessors(ParameterSpace<Map<Integer, InputPreProcessor>> inputPreProcessors) {
            this.inputPreProcessors = inputPreProcessors;
            return this;
        }

        /** Builds the {@link MultiLayerSpace} from the current builder state. */
        @Override // org.deeplearning4j.arbiter.BaseNetworkSpace.Builder
        public MultiLayerSpace build() {
            return new MultiLayerSpace(this);
        }
    }

    protected MultiLayerSpace(Builder builder) {
        super(builder);
        this.inputType = builder.inputType;
        this.inputPreProcessors = builder.inputPreProcessors;
        this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration;
        // Defensive copy: aliasing the builder's list would let later addLayer(...) calls on a reused
        // builder silently mutate this already-built space and invalidate the cached numParameters.
        this.layerSpaces = new ArrayList<>(builder.layerSpaces);

        // Leaves may be shared between layers, so count each unique leaf only once.
        int total = 0;
        for (ParameterSpace<?> leaf : LeafUtils.getUniqueObjects(collectLeaves())) {
            total += leaf.numParameters();
        }
        this.numParameters = total;
    }

    /** No-arg constructor, presumably for Jackson deserialization. */
    protected MultiLayerSpace() {
    }

    /**
     * Samples a concrete configuration from this space.
     *
     * @param values the candidate's raw parameter values, passed to every sub-space's {@code getValue}
     * @return the built network configuration plus early-stopping config and epoch count
     * @throws UnsupportedOperationException if any layer was added with {@code duplicateConfig == false}
     */
    @Override
    public DL4JConfiguration getValue(double[] values) {
        List<Layer> layers = new ArrayList<>();
        for (BaseNetworkSpace.LayerConf layerConf : this.layerSpaces) {
            int count = (Integer) layerConf.numLayers.getValue(values);
            if (!layerConf.duplicateConfig) {
                throw new UnsupportedOperationException("Not yet implemented");
            }
            // One sampled config, cloned so each of the `count` layers is an independent instance.
            Layer template = (Layer) layerConf.layerSpace.getValue(values);
            for (int copy = 0; copy < count; copy++) {
                layers.add(template.clone());
            }
        }

        NeuralNetConfiguration.ListBuilder list = randomGlobalConf(values).list();
        for (int idx = 0; idx < layers.size(); idx++) {
            list.layer(idx, layers.get(idx));
        }

        // Optional global settings: only applied when the corresponding space was configured.
        if (this.backprop != null) {
            list.backprop((Boolean) this.backprop.getValue(values));
        }
        if (this.pretrain != null) {
            list.pretrain((Boolean) this.pretrain.getValue(values));
        }
        if (this.backpropType != null) {
            list.backpropType((BackpropType) this.backpropType.getValue(values));
        }
        if (this.tbpttFwdLength != null) {
            list.tBPTTForwardLength((Integer) this.tbpttFwdLength.getValue(values));
        }
        if (this.tbpttBwdLength != null) {
            list.tBPTTBackwardLength((Integer) this.tbpttBwdLength.getValue(values));
        }
        if (this.inputType != null) {
            list.setInputType(this.inputType.getValue(values));
        }
        if (this.inputPreProcessors != null) {
            list.setInputPreProcessors(this.inputPreProcessors.getValue(values));
        }
        return new DL4JConfiguration(list.build(), this.earlyStoppingConfiguration, Integer.valueOf(this.numEpochs));
    }

    /**
     * Decompiler-generated alias of {@link #getValue(double[])}, kept for backward compatibility.
     *
     * @deprecated use {@link #getValue(double[])}
     */
    @Deprecated
    public DL4JConfiguration m2getValue(double[] values) {
        return getValue(values);
    }

    /** Returns the number of unique leaf parameters, as computed at construction time. */
    public int numParameters() {
        return this.numParameters;
    }

    /**
     * Collects the leaves of the base space plus those of every layer's count and layer space, and
     * of the optional input type and preprocessor spaces. May contain duplicates.
     */
    @Override // org.deeplearning4j.arbiter.BaseNetworkSpace
    public List<ParameterSpace> collectLeaves() {
        List<ParameterSpace> leaves = super.collectLeaves();
        for (BaseNetworkSpace.LayerConf layerConf : this.layerSpaces) {
            leaves.addAll(layerConf.numLayers.collectLeaves());
            leaves.addAll(layerConf.layerSpace.collectLeaves());
        }
        if (this.inputType != null) {
            leaves.addAll(this.inputType.collectLeaves());
        }
        if (this.inputPreProcessors != null) {
            leaves.addAll(this.inputPreProcessors.collectLeaves());
        }
        return leaves;
    }

    @Override // org.deeplearning4j.arbiter.BaseNetworkSpace
    public String toString() {
        StringBuilder sb = new StringBuilder(super.toString());
        int layerIdx = 0;
        for (BaseNetworkSpace.LayerConf layerConf : this.layerSpaces) {
            sb.append("Layer config ").append(layerIdx++)
                    .append(": (Number layers:").append(layerConf.numLayers)
                    .append(", duplicate: ").append(layerConf.duplicateConfig)
                    .append("), ").append(layerConf.layerSpace.toString())
                    .append("\n");
        }
        if (this.inputType != null) {
            sb.append("inputType: ").append(this.inputType).append("\n");
        }
        if (this.inputPreProcessors != null) {
            sb.append("inputPreProcessors: ").append(this.inputPreProcessors).append("\n");
        }
        // Early stopping and a fixed epoch count are alternatives; report whichever applies.
        if (this.earlyStoppingConfiguration != null) {
            sb.append("Early stopping configuration:").append(this.earlyStoppingConfiguration.toString()).append("\n");
        } else {
            sb.append("Training # epochs:").append(this.numEpochs).append("\n");
        }
        return sb.toString();
    }

    /**
     * Returns the layer space at the given layer-config position.
     *
     * @param index position in the order layers were added
     * @return the layer space at {@code index}
     */
    public LayerSpace<?> getLayerSpace(int index) {
        return this.layerSpaces.get(index).getLayerSpace();
    }

    /**
     * Deserializes a space from JSON.
     *
     * @throws RuntimeException wrapping the underlying {@link IOException} on parse failure
     */
    public static MultiLayerSpace fromJson(String json) {
        try {
            return JsonMapper.getMapper().readValue(json, MultiLayerSpace.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deserializes a space from YAML.
     *
     * @throws RuntimeException wrapping the underlying {@link IOException} on parse failure
     */
    public static MultiLayerSpace fromYaml(String yaml) {
        try {
            return YamlMapper.getMapper().readValue(yaml, MultiLayerSpace.class);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public ParameterSpace<InputType> getInputType() {
        return this.inputType;
    }

    public ParameterSpace<Map<Integer, InputPreProcessor>> getInputPreProcessors() {
        return this.inputPreProcessors;
    }

    public EarlyStoppingConfiguration<MultiLayerNetwork> getEarlyStoppingConfiguration() {
        return this.earlyStoppingConfiguration;
    }

    public int getNumParameters() {
        return this.numParameters;
    }

    public void setInputType(ParameterSpace<InputType> inputType) {
        this.inputType = inputType;
    }

    public void setInputPreProcessors(ParameterSpace<Map<Integer, InputPreProcessor>> inputPreProcessors) {
        this.inputPreProcessors = inputPreProcessors;
    }

    public void setEarlyStoppingConfiguration(EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConfiguration) {
        this.earlyStoppingConfiguration = earlyStoppingConfiguration;
    }

    public void setNumParameters(int numParameters) {
        this.numParameters = numParameters;
    }

    @Override // org.deeplearning4j.arbiter.BaseNetworkSpace
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MultiLayerSpace)) {
            return false;
        }
        MultiLayerSpace other = (MultiLayerSpace) obj;
        // canEqual keeps equals symmetric when subclasses add state.
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        return Objects.equals(getInputType(), other.getInputType())
                && Objects.equals(getInputPreProcessors(), other.getInputPreProcessors())
                && Objects.equals(getEarlyStoppingConfiguration(), other.getEarlyStoppingConfiguration())
                && getNumParameters() == other.getNumParameters();
    }

    @Override // org.deeplearning4j.arbiter.BaseNetworkSpace
    protected boolean canEqual(Object obj) {
        return obj instanceof MultiLayerSpace;
    }

    @Override // org.deeplearning4j.arbiter.BaseNetworkSpace
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + super.hashCode();
        result = result * prime + hashOrDefault(getInputType());
        result = result * prime + hashOrDefault(getInputPreProcessors());
        result = result * prime + hashOrDefault(getEarlyStoppingConfiguration());
        result = result * prime + getNumParameters();
        return result;
    }

    /** Hash of {@code value}, or 43 for {@code null} (keeps hash values stable with earlier versions). */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }
}
