package org.deeplearning4j.nn.conf.layers.samediff;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.class */
/**
 * A {@link SameDiffVertex} that defines no parameters of its own: subclasses only implement
 * {@link #defineVertex(SameDiff, VertexInputs)} and pull their inputs by index from the supplied
 * {@link VertexInputs}.
 */
public abstract class SameDiffLambdaVertex extends SameDiffVertex {
    /**
     * Input holder created on first use by {@link #getInputs(SameDiff)}. Declared transient, so it is
     * not part of the serialized form.
     */
    // NOTE(review): once created, this holder stays bound to the SameDiff instance passed to the first
    // getInputs(...) call and is never reset in this class - confirm that is intended when the vertex is
    // defined again against a different SameDiff instance.
    protected transient VertexInputs inputs;

    /**
     * Index-addressed view of the vertex inputs. Variables are kept in a {@link LinkedHashMap}, so
     * iteration follows the order in which indices were first registered.
     */
    public class VertexInputs {
        private final SameDiff sameDiff;
        private final Map<Integer, SDVariable> map = new LinkedHashMap<>();

        /**
         * @param sameDiff instance used to create variables for indices that have not been registered yet
         */
        protected VertexInputs(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
        }

        /**
         * Returns the variable registered for the given input index, creating and registering a new
         * variable named {@code "var_" + i} on first request for an unknown index.
         *
         * @param i input index; must be {@code >= 0} (checked via {@code Preconditions.checkArgument})
         * @return the variable stored for index {@code i}
         */
        public SDVariable getInput(int i) {
            Preconditions.checkArgument(i >= 0, "Input number must be >= 0.Got: %s", i);
            if (!this.map.containsKey(i)) {
                // Unknown index: lazily create a variable with shape {1} for it.
                this.map.put(i, this.sameDiff.var("var_" + i, new int[]{1}));
            }
            return this.map.get(i);
        }
    }

    /**
     * Defines the vertex computation.
     *
     * @param sameDiff     SameDiff instance to define the operations on
     * @param vertexInputs accessor for the vertex inputs, see {@link VertexInputs#getInput(int)}
     * @return the output variable of the vertex
     */
    public abstract SDVariable defineVertex(SameDiff sameDiff, VertexInputs vertexInputs);

    /**
     * Adapts the map-based vertex definition to the index-based one. If no inputs have been registered
     * yet, the values of {@code map} are registered under indices 0, 1, 2, ... in the map's iteration
     * order; then {@link #defineVertex(SameDiff, VertexInputs)} is invoked.
     *
     * @param sameDiff SameDiff instance to define the operations on
     * @param map      input variables
     * @param map2     not read by this implementation
     * @return the output variable produced by {@link #defineVertex(SameDiff, VertexInputs)}
     */
    @Override // org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex
    public SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> map, Map<String, SDVariable> map2) {
        VertexInputs vertexInputs = getInputs(sameDiff);
        // Only seed the holder once; already-registered inputs are left untouched.
        if (vertexInputs.map.isEmpty() && !map.isEmpty()) {
            int index = 0;
            for (SDVariable input : map.values()) {
                vertexInputs.map.put(index++, input);
            }
        }
        return defineVertex(sameDiff, vertexInputs);
    }

    /**
     * Discovers the vertex inputs by running {@link #defineVertex(SameDiff, VertexInputs)} against a
     * freshly created SameDiff instance and a fresh {@link VertexInputs}, then reports the names of the
     * variables that were requested, in the order they were first requested.
     *
     * @param sDVertexParams receiver of the discovered input names
     */
    @Override // org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex
    public void defineParametersAndInputs(SDVertexParams sDVertexParams) {
        SameDiff scratch = SameDiff.create();
        VertexInputs requested = new VertexInputs(scratch);
        defineVertex(scratch, requested);
        String[] inputNames = new String[requested.map.size()];
        int next = 0;
        for (SDVariable input : requested.map.values()) {
            inputNames[next++] = input.getVarName();
        }
        sDVertexParams.defineInputs(inputNames);
    }

    /**
     * Intentionally empty: this implementation does not initialize any parameters.
     *
     * @param map parameter arrays; not read
     */
    @Override // org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex
    public void initializeParameters(Map<String, INDArray> map) {
    }

    /**
     * Returns the cached {@link VertexInputs}, creating it for {@code sameDiff} on first call.
     *
     * @param sameDiff instance used only when the holder has to be created
     * @return the shared input holder of this vertex
     */
    protected VertexInputs getInputs(SameDiff sameDiff) {
        if (this.inputs == null) {
            this.inputs = new VertexInputs(sameDiff);
        }
        return this.inputs;
    }
}
