package org.nd4j.autodiff.samediff.internal;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.function.Predicate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/samediff/internal/AbstractSession.class */
public abstract class AbstractSession<T, O> {
    private static final Logger log = LoggerFactory.getLogger(AbstractSession.class);
    public static final String OUTER_FRAME = "main";
    protected final SameDiff sameDiff;
    protected final Map<VarId, T> nodeOutputs = new HashMap();
    protected final Map<VarId, List<T>> tensorArrays = new HashMap();
    protected final DependencyTracker<ExecStep, ExecStep> dt = new DependencyTracker<>();
    protected final Set<String> subgraph = new HashSet();
    protected final Set<String> subgraphOps = new HashSet();
    protected final Set<String> zeroInputOpsInSubgraph = new HashSet();

    /**
     * A single unit of work tracked by the dependency tracker: a step of a given {@link ExecType}, identified by
     * name, executed within a frame/iteration. {@code frameIter} may be null (EXEC_START and CONTROL_DEP steps are
     * created with a null frame in this class).
     *
     * <p>Decompiler note: originally declared {@code protected}.
     */
    public static class ExecStep {
        protected final ExecType type;
        protected final String name;
        protected final FrameIter frameIter;

        protected ExecStep(@NonNull ExecType type, @NonNull String name, FrameIter frameIter) {
            if (type == null) {
                throw new NullPointerException("execType is marked @NonNull but is null");
            }
            if (name == null) {
                throw new NullPointerException("name is marked @NonNull but is null");
            }
            this.type = type;
            this.name = name;
            this.frameIter = frameIter;
        }

        /** Builds the {@link VarId} naming this step's value in this step's frame/iteration; requires a non-null frame. */
        protected VarId toVarId() {
            FrameIter fi = this.frameIter;
            return new VarId(this.name, fi.getFrame(), fi.getIteration(), fi.getParentFrame());
        }

        public ExecType getType() {
            return this.type;
        }

        public String getName() {
            return this.name;
        }

        public FrameIter getFrameIter() {
            return this.frameIter;
        }

        @Override
        public String toString() {
            return new StringBuilder("ExecStep(")
                    .append(this.type)
                    .append(",name=\"")
                    .append(this.name)
                    .append("\",")
                    .append(this.frameIter)
                    .append(")")
                    .toString();
        }

        /** Value equality over type, name and frame/iteration (all compared null-safely). */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof ExecStep)) {
                return false;
            }
            ExecStep other = (ExecStep) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            ExecType myType = getType();
            ExecType otherType = other.getType();
            boolean sameType = myType == null ? otherType == null : myType.equals(otherType);
            if (!sameType) {
                return false;
            }
            String myName = getName();
            String otherName = other.getName();
            boolean sameName = myName == null ? otherName == null : myName.equals(otherName);
            if (!sameName) {
                return false;
            }
            FrameIter myFrame = getFrameIter();
            FrameIter otherFrame = other.getFrameIter();
            return myFrame == null ? otherFrame == null : myFrame.equals(otherFrame);
        }

        /** Symmetry hook for subclasses (Lombok-style canEqual pattern). */
        protected boolean canEqual(Object obj) {
            return obj instanceof ExecStep;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            final int nullHash = 43;
            ExecType myType = getType();
            String myName = getName();
            FrameIter myFrame = getFrameIter();
            int result = 1;
            result = result * prime + (myType == null ? nullHash : myType.hashCode());
            result = result * prime + (myName == null ? nullHash : myName.hashCode());
            result = result * prime + (myFrame == null ? nullHash : myFrame.hashCode());
            return result;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:org/nd4j/autodiff/samediff/internal/AbstractSession$ExecStepPredicate.class */
    public class ExecStepPredicate implements Predicate<ExecStep> {
        protected String currentFrame;
        protected int currentFrameIter;
        protected FrameIter currParentFrame;

        public boolean test(ExecStep execStep) {
            return this.currentFrame.equals(execStep.getFrameIter().getFrame()) && this.currentFrameIter == execStep.getFrameIter().getIteration() && ((this.currParentFrame == null && execStep.getFrameIter().getParentFrame() == null) || this.currParentFrame.equals(execStep.getFrameIter().getParentFrame()));
        }

        public String getCurrentFrame() {
            return this.currentFrame;
        }

        public int getCurrentFrameIter() {
            return this.currentFrameIter;
        }

        public FrameIter getCurrParentFrame() {
            return this.currParentFrame;
        }

        public void setCurrentFrame(String str) {
            this.currentFrame = str;
        }

        public void setCurrentFrameIter(int i) {
            this.currentFrameIter = i;
        }

        public void setCurrParentFrame(FrameIter frameIter) {
            this.currParentFrame = frameIter;
        }

        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof ExecStepPredicate)) {
                return false;
            }
            ExecStepPredicate execStepPredicate = (ExecStepPredicate) obj;
            if (!execStepPredicate.canEqual(this)) {
                return false;
            }
            String currentFrame = getCurrentFrame();
            String currentFrame2 = execStepPredicate.getCurrentFrame();
            if (currentFrame == null) {
                if (currentFrame2 != null) {
                    return false;
                }
            } else if (!currentFrame.equals(currentFrame2)) {
                return false;
            }
            if (getCurrentFrameIter() != execStepPredicate.getCurrentFrameIter()) {
                return false;
            }
            FrameIter currParentFrame = getCurrParentFrame();
            FrameIter currParentFrame2 = execStepPredicate.getCurrParentFrame();
            return currParentFrame == null ? currParentFrame2 == null : currParentFrame.equals(currParentFrame2);
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof ExecStepPredicate;
        }

        public int hashCode() {
            String currentFrame = getCurrentFrame();
            int hashCode = (((1 * 59) + (currentFrame == null ? 43 : currentFrame.hashCode())) * 59) + getCurrentFrameIter();
            FrameIter currParentFrame = getCurrParentFrame();
            return (hashCode * 59) + (currParentFrame == null ? 43 : currParentFrame.hashCode());
        }

        public String toString() {
            return "AbstractSession.ExecStepPredicate(currentFrame=" + getCurrentFrame() + ", currentFrameIter=" + getCurrentFrameIter() + ", currParentFrame=" + getCurrParentFrame() + ")";
        }

        public ExecStepPredicate(String str, int i, FrameIter frameIter) {
            this.currentFrame = str;
            this.currentFrameIter = i;
            this.currParentFrame = frameIter;
        }

        public ExecStepPredicate() {
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:org/nd4j/autodiff/samediff/internal/AbstractSession$ExecType.class */
    /**
     * Kind of an {@link ExecStep}. Declaration order is preserved as-is: reordering would change
     * {@code ordinal()} / {@code values()}.
     */
    public enum ExecType {
        /** Execution of a SameDiff op (looked up by name in {@code sameDiff.getOps()}). */
        OP,
        /** Making a SameDiff VARIABLE-type variable's value available. */
        VARIABLE,
        /** Making a SameDiff CONSTANT-type variable's value available. */
        CONSTANT,
        /** Making a placeholder (one of {@code sameDiff.inputs()}) value available. */
        PLACEHOLDER,
        /** Marked satisfied after a Switch op whose first output (index 0) was produced. */
        SWITCH_L,
        /** Marked satisfied after a Switch op whose second output (index 1) was produced. */
        SWITCH_R,
        /** Synthetic root step that initial variables, constants, placeholders and zero-input ops depend on. */
        EXEC_START,
        /** Control-dependency marker, keyed by name with a null frame. */
        CONTROL_DEP
    }

    /**
     * Identifies an execution context: a frame name, an iteration number within that frame, and the enclosing
     * frame/iteration ({@code null} for the outermost frame).
     *
     * <p>NOTE(review): this class is mutable (setters) yet participates in {@code equals}/{@code hashCode} of hashed
     * keys such as {@link ExecStep} and {@link VarId}; mutating an instance after it is used as a key would break
     * lookups.
     */
    public static class FrameIter {
        private String frame;
        private int iteration;
        private FrameIter parentFrame;

        public FrameIter(String frame, int iteration, FrameIter parentFrame) {
            this.frame = frame;
            this.iteration = iteration;
            this.parentFrame = parentFrame;
        }

        /**
         * Deep copy: the parent chain is copied recursively.
         * (Decompiler-renamed {@code clone()}; the name is kept because callers reference it.)
         */
        public FrameIter m1466clone() {
            FrameIter parentCopy = null;
            if (this.parentFrame != null) {
                parentCopy = this.parentFrame.m1466clone();
            }
            return new FrameIter(this.frame, this.iteration, parentCopy);
        }

        /** Builds the {@link VarId} for the named variable in this frame/iteration. */
        public VarId toVarId(String variableName) {
            return new VarId(variableName, this.frame, this.iteration, this.parentFrame);
        }

        public String getFrame() {
            return this.frame;
        }

        public int getIteration() {
            return this.iteration;
        }

        public FrameIter getParentFrame() {
            return this.parentFrame;
        }

        public void setFrame(String frame) {
            this.frame = frame;
        }

        public void setIteration(int iteration) {
            this.iteration = iteration;
        }

        public void setParentFrame(FrameIter parentFrame) {
            this.parentFrame = parentFrame;
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("(\"").append(this.frame).append("\",").append(this.iteration);
            if (this.parentFrame != null) {
                sb.append(",parent=").append(this.parentFrame.toString());
            }
            return sb.append(")").toString();
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof FrameIter)) {
                return false;
            }
            FrameIter other = (FrameIter) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            String myFrame = getFrame();
            String otherFrame = other.getFrame();
            boolean framesMatch = myFrame == null ? otherFrame == null : myFrame.equals(otherFrame);
            if (!framesMatch || getIteration() != other.getIteration()) {
                return false;
            }
            FrameIter myParent = getParentFrame();
            FrameIter otherParent = other.getParentFrame();
            return myParent == null ? otherParent == null : myParent.equals(otherParent);
        }

        /** Symmetry hook for subclasses (Lombok-style canEqual pattern). */
        protected boolean canEqual(Object obj) {
            return obj instanceof FrameIter;
        }

        @Override
        public int hashCode() {
            String myFrame = getFrame();
            FrameIter myParent = getParentFrame();
            int result = 1;
            result = 59 * result + (myFrame == null ? 43 : myFrame.hashCode());
            result = 59 * result + getIteration();
            result = 59 * result + (myParent == null ? 43 : myParent.hashCode());
            return result;
        }
    }

    /**
     * Key for a single stored value: a variable name plus the frame name, iteration and parent frame in which the
     * value was produced. Used as the key of the session's {@code nodeOutputs} map.
     */
    public static class VarId {
        private String variable;
        private String frame;
        private int iteration;
        private FrameIter parentFrame;

        public VarId(String variable, String frame, int iteration, FrameIter parentFrame) {
            this.variable = variable;
            this.frame = frame;
            this.iteration = iteration;
            this.parentFrame = parentFrame;
        }

        /** Returns a new {@link FrameIter} for this id's frame, iteration and parent (the parent is shared, not copied). */
        public FrameIter toFrameIter() {
            return new FrameIter(this.frame, this.iteration, this.parentFrame);
        }

        public String getVariable() {
            return this.variable;
        }

        public String getFrame() {
            return this.frame;
        }

        public int getIteration() {
            return this.iteration;
        }

        public FrameIter getParentFrame() {
            return this.parentFrame;
        }

        public void setVariable(String variable) {
            this.variable = variable;
        }

        public void setFrame(String frame) {
            this.frame = frame;
        }

        public void setIteration(int iteration) {
            this.iteration = iteration;
        }

        public void setParentFrame(FrameIter parentFrame) {
            this.parentFrame = parentFrame;
        }

        @Override
        public String toString() {
            return new StringBuilder("VarId(\"")
                    .append(this.variable)
                    .append("\",\"")
                    .append(this.frame)
                    .append("\",")
                    .append(this.iteration)
                    .append(",parent=")
                    .append(this.parentFrame)
                    .append(")")
                    .toString();
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof VarId)) {
                return false;
            }
            VarId other = (VarId) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            String myVar = getVariable();
            String otherVar = other.getVariable();
            if (myVar == null ? otherVar != null : !myVar.equals(otherVar)) {
                return false;
            }
            String myFrame = getFrame();
            String otherFrame = other.getFrame();
            if (myFrame == null ? otherFrame != null : !myFrame.equals(otherFrame)) {
                return false;
            }
            if (getIteration() != other.getIteration()) {
                return false;
            }
            FrameIter myParent = getParentFrame();
            FrameIter otherParent = other.getParentFrame();
            return myParent == null ? otherParent == null : myParent.equals(otherParent);
        }

        /** Symmetry hook for subclasses (Lombok-style canEqual pattern). */
        protected boolean canEqual(Object obj) {
            return obj instanceof VarId;
        }

        @Override
        public int hashCode() {
            String myVar = getVariable();
            String myFrame = getFrame();
            FrameIter myParent = getParentFrame();
            int result = 1;
            result = result * 59 + (myVar == null ? 43 : myVar.hashCode());
            result = result * 59 + (myFrame == null ? 43 : myFrame.hashCode());
            result = result * 59 + getIteration();
            result = result * 59 + (myParent == null ? 43 : myParent.hashCode());
            return result;
        }
    }

    /**
     * Creates a session bound to the given SameDiff instance.
     *
     * @param sameDiff the graph to execute; must not be null
     * @throws NullPointerException if {@code sameDiff} is null
     */
    public AbstractSession(@NonNull SameDiff sameDiff) {
        if (sameDiff != null) {
            this.sameDiff = sameDiff;
        } else {
            throw new NullPointerException("sameDiff is marked @NonNull but is null");
        }
    }

    /**
     * Returns whether a value has been stored for the variable in the given frame, iteration and parent frame.
     */
    public boolean contains(String variable, String frame, int iteration, FrameIter parentFrameIter) {
        VarId key = new VarId(variable, frame, iteration, parentFrameIter);
        return this.nodeOutputs.containsKey(key);
    }

    /**
     * Returns the stored value for the variable in the given frame/iteration, requiring that one exists
     * (delegates to the five-argument overload with enforcement enabled).
     */
    public T get(String variable, String frame, int iteration, FrameIter parentFrameIter) {
        final boolean enforceExistence = true;
        return get(variable, frame, iteration, parentFrameIter, enforceExistence);
    }

    /**
     * Looks up the stored value for the variable in the given frame, iteration and parent frame.
     *
     * @param enforceExistence when true, a missing (or null) value fails via {@code Preconditions.checkNotNull};
     *                         when false, null is returned for a missing value
     */
    public T get(String variable, String frame, int iteration, FrameIter parentFrameIter, boolean enforceExistence) {
        T value = this.nodeOutputs.get(new VarId(variable, frame, iteration, parentFrameIter));
        if (!enforceExistence) {
            return value;
        }
        Preconditions.checkNotNull(value, "No output found for variable %s (frame %s, iteration %s)", variable, frame, Integer.valueOf(iteration));
        return value;
    }

    /**
     * Executes the part of the graph needed to compute the requested variables and returns their values.
     *
     * <p>Execution is driven by {@code dt}. Variables, constants, placeholders and zero-input subgraph ops depend on a
     * synthetic EXEC_START step. Each executed step then registers dependencies for its consumers via
     * {@code updateDescendantDeps}. When picking the next step, a step in the same frame/iteration/parent as the
     * previous one is preferred; otherwise any step whose dependencies are satisfied is used.
     *
     * @param variables           names of the variables to return; must not be null, and each must exist in the
     *                            SameDiff instance
     * @param placeholderValues   placeholder values, passed through {@code preprocessPlaceholders}; may be null
     * @param batch               passed through to {@code getOutputs}
     * @param requiredActivations extra variables included in the executed subgraph (their values are also added to
     *                            the result when computed); may be null
     * @param listeners           passed through to {@code getOutputs}
     * @param at                  passed to the hooks; {@code At.defaultAt()} is used when null
     * @return {@code postProcessOutput} applied to the collected name -> value map
     * @throws IllegalStateException if a placeholder needed by the requested outputs was not provided, or (via
     *                               {@code execFailed}) if no step can run before all requested outputs exist
     */
    public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues, MultiDataSet batch,
                                 Collection<String> requiredActivations, List<Listener> listeners, At at) {
        if (variables == null) {
            throw new NullPointerException("variables is marked @NonNull but is null");
        }
        // Normalize null first: the emptiness check below previously dereferenced requiredActivations before this.
        if (requiredActivations == null) {
            requiredActivations = Collections.emptyList();
        }
        Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(),
                "Variables to perform forward pass for must not be empty");
        if (at == null) {
            at = At.defaultAt();
        }
        for (String varName : variables) {
            Preconditions.checkState(this.sameDiff.variableMap().containsKey(varName),
                    "Requested output variable %s does not exist in SameDiff instance", varName);
        }

        Set<String> reqOutputVariablesSet = new HashSet<>(variables);
        Map<String, T> placeholders = preprocessPlaceholders(placeholderValues, at);

        // Reset all per-execution state
        this.dt.clear();
        this.subgraph.clear();
        this.subgraphOps.clear();
        this.nodeOutputs.clear();
        this.tensorArrays.clear();

        Set<String> userRequestedUnique = new HashSet<>(variables);
        Set<String> allRequired = new HashSet<>(requiredActivations);
        allRequired.addAll(variables);
        initSubgraph(allRequired);

        // Fail fast when a placeholder that is requested, or consumed by an op in the subgraph, has no value
        List<String> placeholderNames = this.sameDiff.inputs();
        if (placeholders == null || !placeholders.keySet().containsAll(placeholderNames)) {
            for (String phName : placeholderNames) {
                boolean required = variables.contains(phName);
                if (!required) {
                    List<String> consumers = this.sameDiff.getVariables().get(phName).getInputsForOp();
                    if (consumers != null) {
                        for (String consumer : consumers) {
                            if (this.subgraph.contains(consumer)) {
                                required = true;
                                break;
                            }
                        }
                    }
                }
                if (required && (placeholders == null || !placeholders.containsKey(phName))) {
                    throw new IllegalStateException("An input placeholder \"" + phName
                            + "\" is required to calculate the requested outputs, but a placeholder value was not provided");
                }
            }
        }

        // Seed the dependency tracker: everything available up front depends only on EXEC_START
        ExecStep start = new ExecStep(ExecType.EXEC_START, "", null);
        for (SDVariable sdVar : this.sameDiff.variables()) {
            VariableType varType = sdVar.getVariableType();
            if (varType == VariableType.VARIABLE || varType == VariableType.CONSTANT) {
                ExecType stepType = varType == VariableType.VARIABLE ? ExecType.VARIABLE : ExecType.CONSTANT;
                ExecStep varStep = new ExecStep(stepType, sdVar.name(), new FrameIter(OUTER_FRAME, 0, null));
                this.dt.addDependency(varStep, start);
                Variable meta = this.sameDiff.getVariables().get(sdVar.name());
                if (meta.getControlDeps() != null) {
                    addVarControlDeps(varStep, meta);
                }
            }
        }
        for (String phName : placeholderNames) {
            ExecStep phStep = new ExecStep(ExecType.PLACEHOLDER, phName, new FrameIter(OUTER_FRAME, 0, null));
            this.dt.addDependency(phStep, start);
            Variable meta = this.sameDiff.getVariables().get(phName);
            if (meta.getControlDeps() != null) {
                addVarControlDeps(phStep, meta);
            }
        }
        for (String opName : this.zeroInputOpsInSubgraph) {
            this.dt.addDependency(new ExecStep(ExecType.OP, opName, new FrameIter(OUTER_FRAME, 0, null)), start);
        }
        this.dt.markSatisfied(start, true);

        Map<String, T> out = new HashMap<>();
        int stepCount = 0;
        String currentFrame = OUTER_FRAME;
        int currentFrameIter = 0;
        FrameIter currParentFrame = null;
        ExecStepPredicate predicate = new ExecStepPredicate();
        // Stop once every user-requested name has a value. `out` also receives required activations, so the
        // previous size comparison against userRequestedUnique could stop before all requested outputs existed.
        while (!out.keySet().containsAll(userRequestedUnique)) {
            if (!this.dt.hasNewAllSatisfied()) {
                execFailed(userRequestedUnique, out, stepCount);
            }
            // Prefer continuing in the frame/iteration of the previously executed step
            predicate.setCurrentFrame(currentFrame);
            predicate.setCurrentFrameIter(currentFrameIter);
            predicate.setCurrParentFrame(currParentFrame);
            ExecStep es = this.dt.getFirstNewAllSatisfiedMatching(predicate);
            if (es == null) {
                es = this.dt.getNewAllSatisfied();
            }
            currentFrame = es.getFrameIter().getFrame();
            currentFrameIter = es.getFrameIter().getIteration();
            currParentFrame = es.getFrameIter().getParentFrame();
            log.trace("Beginning execution step {}: {}", Integer.valueOf(stepCount), es);

            FrameIter outFrameIter;
            // Switch/Enter/Exit register descendants and mark satisfaction with their own step, not `es`
            boolean descendantsHandled = false;
            boolean satisfactionHandled = false;
            ExecType esType = es.getType();
            if (esType == ExecType.CONSTANT || esType == ExecType.VARIABLE) {
                VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null);
                T arr = getConstantOrVariable(es.getName());
                Preconditions.checkNotNull(arr, "Encountered null placeholder array for constant: %s", vid);
                this.nodeOutputs.put(vid, arr);
                outFrameIter = new FrameIter(OUTER_FRAME, 0, null);
                if (allRequired.contains(es.getName())) {
                    out.put(es.getName(), arr);
                }
            } else if (esType == ExecType.PLACEHOLDER) {
                // Placeholder steps exist for every input, even when no values were supplied (placeholders == null);
                // unused placeholders simply store null, as they already did when the map lacked the key.
                T phArr = placeholders == null ? null : placeholders.get(es.getName());
                this.nodeOutputs.put(new VarId(es.getName(), OUTER_FRAME, 0, null), phArr);
                outFrameIter = new FrameIter(OUTER_FRAME, 0, null);
                if (allRequired.contains(es.getName())) {
                    out.put(es.getName(), phArr);
                }
            } else if (esType == ExecType.OP) {
                String opName = es.getName();
                SameDiffOp sdOp = this.sameDiff.getOps().get(opName);
                DifferentialFunction op = sdOp.getOp();
                FrameIter esFrameIter = es.getFrameIter();
                if (op instanceof Enter) {
                    // Enter: outputs live in iteration 0 of the new frame, whose parent is the current frame
                    outFrameIter = new FrameIter(((Enter) op).getFrameName(), 0, esFrameIter);
                } else if (op instanceof Exit) {
                    // Exit: outputs return to the parent frame/iteration
                    FrameIter parent = esFrameIter.getParentFrame();
                    outFrameIter = new FrameIter(parent.getFrame(), parent.getIteration(), parent.getParentFrame());
                } else if (op instanceof NextIteration) {
                    // NOTE(review): this copies the current frame/iteration unchanged (the removed
                    // setIteration(getIteration()) call was a no-op). Presumably the iteration increment happens
                    // elsewhere (e.g. when dependencies are added) — confirm before changing.
                    outFrameIter = esFrameIter.m1466clone();
                } else {
                    outFrameIter = esFrameIter;
                }

                Set<VarId> opInputVarIds = null;
                Set<VarId> allIterInputs = null;
                Set<String> constAndPhInputs = null;
                DependencyList<ExecStep, ExecStep> deps = this.dt.getDependencies(es);
                List<String> opInputNames = sdOp.getInputsToOp();
                if (opInputNames != null && !opInputNames.isEmpty()) {
                    opInputVarIds = new HashSet<>();
                    allIterInputs = new HashSet<>();
                    constAndPhInputs = new HashSet<>();
                    List<ExecStep> depSteps = deps.getDependencies();
                    if (depSteps != null && !depSteps.isEmpty()) {
                        for (ExecStep dep : depSteps) {
                            switch (dep.getType()) {
                                case OP:
                                case SWITCH_L:
                                case SWITCH_R: {
                                    // Record which of this op's inputs the dependency produced, in the dependency's frame
                                    List<String> depOutputs = this.sameDiff.getOps().get(dep.getName()).getOutputsOfOp();
                                    FrameIter depFrame = dep.getFrameIter();
                                    for (String inName : opInputNames) {
                                        if (depOutputs.contains(inName)) {
                                            opInputVarIds.add(new VarId(inName, depFrame.getFrame(), depFrame.getIteration(), depFrame.getParentFrame()));
                                        }
                                    }
                                    break;
                                }
                                case VARIABLE:
                                    opInputVarIds.add(new VarId(dep.getName(), OUTER_FRAME, 0, null));
                                    break;
                                case CONSTANT:
                                case PLACEHOLDER:
                                    constAndPhInputs.add(dep.getName());
                                    break;
                                default:
                                    throw new UnsupportedOperationException("Not yet implemented: " + dep.getType());
                            }
                        }
                    }
                }

                O parameterizedOp = getAndParameterizeOp(opName, outFrameIter, opInputVarIds, allIterInputs, constAndPhInputs, placeholders, reqOutputVariablesSet);
                T[] opOutputs = getOutputs(parameterizedOp, outFrameIter, opInputVarIds, allIterInputs, constAndPhInputs, listeners, at, batch, reqOutputVariablesSet);
                List<String> opOutputNames = sdOp.getOutputsOfOp();
                Preconditions.checkState(opOutputs.length == opOutputNames.size(), "Unexpected number of outputs from executed op %s: got %s outputs when %s outputs were expected (%s)", parameterizedOp.getClass().getSimpleName(), Integer.valueOf(opOutputs.length), Integer.valueOf(opOutputNames.size()), opOutputNames);
                for (int j = 0; j < opOutputs.length; j++) {
                    if (opOutputs[j] == null && op instanceof Switch) {
                        continue; // the untaken switch branch produces no value
                    }
                    String outName = opOutputNames.get(j);
                    this.nodeOutputs.put(new VarId(outName, outFrameIter.getFrame(), outFrameIter.getIteration(), outFrameIter.getParentFrame()), opOutputs[j]);
                    if (allRequired.contains(outName)) {
                        out.put(outName, opOutputs[j]);
                    }
                }

                if (op instanceof Switch) {
                    descendantsHandled = true;
                    satisfactionHandled = true;
                    int numPresent = (opOutputs[0] != null ? 1 : 0) + (opOutputs[1] != null ? 1 : 0);
                    // Report the number of outputs present (previously the null count was reported under this message)
                    Preconditions.checkState(numPresent == 1, "Expected exactly one output to be present for switch ops, got %s", numPresent);
                    ExecType branch = opOutputs[0] != null ? ExecType.SWITCH_L : ExecType.SWITCH_R;
                    ExecStep branchStep = new ExecStep(branch, es.getName(), es.getFrameIter());
                    updateDescendantDeps(branchStep, outFrameIter);
                    this.dt.markSatisfied(branchStep, true);
                } else if (op instanceof Enter) {
                    descendantsHandled = true;
                    satisfactionHandled = true;
                    FrameIter enteredFrame = new FrameIter(((Enter) op).getFrameName(), 0, es.getFrameIter());
                    ExecStep enterStep = new ExecStep(ExecType.OP, es.getName(), enteredFrame);
                    updateDescendantDeps(enterStep, enteredFrame);
                    this.dt.markSatisfied(enterStep, true);
                } else if (op instanceof Exit) {
                    descendantsHandled = true;
                    satisfactionHandled = true;
                    FrameIter parentFrame = es.getFrameIter().getParentFrame();
                    ExecStep exitStep = new ExecStep(ExecType.OP, es.getName(), parentFrame);
                    updateDescendantDeps(exitStep, parentFrame);
                    this.dt.markSatisfied(exitStep, true);
                }

                if (sdOp.getControlDepFor() != null) {
                    // Satisfy this op's CONTROL_DEP step (the kind addVarControlDeps registers), at most once
                    ExecStep controlDepStep = new ExecStep(ExecType.CONTROL_DEP, opName, null);
                    if (!this.dt.isSatisfied(controlDepStep)) {
                        this.dt.markSatisfied(controlDepStep, true);
                    }
                }
            } else {
                throw new RuntimeException("Unknown ExecStep: " + es);
            }

            if (!descendantsHandled) {
                updateDescendantDeps(es, outFrameIter);
            }
            if (!satisfactionHandled) {
                this.dt.markSatisfied(es, true);
            }
            stepCount++;
        }
        return postProcessOutput(out);
    }

    /**
     * Makes {@code execStep} depend on a CONTROL_DEP step (null frame) for each name in the variable's control deps.
     * Does nothing when the variable has no control deps (null list).
     */
    protected void addVarControlDeps(ExecStep execStep, Variable variable) {
        List<String> controlDepNames = variable.getControlDeps();
        if (controlDepNames == null) {
            return;
        }
        for (String depName : controlDepNames) {
            ExecStep controlDepStep = new ExecStep(ExecType.CONTROL_DEP, depName, null);
            this.dt.addDependency(execStep, controlDepStep);
        }
    }

    /**
     * Throws an IllegalStateException describing which requested variables are still missing when no step is
     * available to execute.
     *
     * <p>The remaining count is computed from the names actually missing. The previous
     * {@code set.size() - map.size()} was wrong, even zero or negative, when {@code map} held entries not in
     * {@code set} (e.g. required activations).
     *
     * @param set  names that were requested
     * @param map  values produced so far (may contain extra names)
     * @param step index of the execution step at which execution stalled
     * @throws IllegalStateException always
     */
    protected void execFailed(Set<String> set, Map<String, T> map, int step) {
        Set<String> missing = new HashSet<>();
        for (String name : set) {
            if (!map.containsKey(name)) {
                missing.add(name);
            }
        }
        int remaining = missing.size();
        StringBuilder sb = new StringBuilder();
        sb.append("No variable are available for execution at step ").append(step).append(": ").append(remaining).append(" values remaining");
        if (remaining <= 10) {
            sb.append(". Missing variables: ");
            sb.append(missing);
        } else {
            sb.append(". First 10 missing variables: ");
            Iterator<String> it = missing.iterator();
            for (int k = 0; k < 10 && it.hasNext(); k++) {
                if (k > 0) {
                    sb.append(",");
                }
                sb.append(it.next());
            }
        }
        throw new IllegalStateException(sb.toString());
    }

    /**
     * For each subgraph op that consumes a value produced by {@code execStep}, calls
     * {@code addDependenciesForOp(op, frameIter)}.
     *
     * <ul>
     *   <li>OP: every output of the op; both its input-consumers and its control-dependency consumers.</li>
     *   <li>VARIABLE / CONSTANT / PLACEHOLDER: the input-consumers of that variable.</li>
     *   <li>SWITCH_L / SWITCH_R: the input-consumers of the switch's first / second output respectively.</li>
     * </ul>
     *
     * @throws UnsupportedOperationException for any other step type
     */
    protected void updateDescendantDeps(ExecStep execStep, FrameIter frameIter) {
        String stepName = execStep.getName();
        switch (execStep.getType()) {
            case OP: {
                for (String outputName : this.sameDiff.getOps().get(stepName).getOutputsOfOp()) {
                    Variable outputVar = this.sameDiff.getVariables().get(outputName);
                    List<String> inputConsumers = outputVar.getInputsForOp();
                    if (inputConsumers != null) {
                        for (String consumerOp : inputConsumers) {
                            if (this.subgraphOps.contains(consumerOp)) {
                                addDependenciesForOp(consumerOp, frameIter);
                            }
                        }
                    }
                    List<String> controlConsumers = outputVar.getControlDepsForOp();
                    if (controlConsumers != null) {
                        for (String consumerOp : controlConsumers) {
                            if (this.subgraphOps.contains(consumerOp)) {
                                addDependenciesForOp(consumerOp, frameIter);
                            }
                        }
                    }
                }
                break;
            }
            case VARIABLE:
            case CONSTANT:
            case PLACEHOLDER: {
                List<String> inputConsumers = this.sameDiff.getVariables().get(stepName).getInputsForOp();
                if (inputConsumers != null) {
                    for (String consumerOp : inputConsumers) {
                        if (this.subgraphOps.contains(consumerOp)) {
                            addDependenciesForOp(consumerOp, frameIter);
                        }
                    }
                }
                break;
            }
            case SWITCH_L:
            case SWITCH_R: {
                List<String> switchOutputs = this.sameDiff.getOps().get(stepName).getOutputsOfOp();
                int branchIndex = execStep.getType() == ExecType.SWITCH_L ? 0 : 1;
                List<String> inputConsumers = this.sameDiff.getVariables().get(switchOutputs.get(branchIndex)).getInputsForOp();
                if (inputConsumers != null) {
                    for (String consumerOp : inputConsumers) {
                        if (this.subgraphOps.contains(consumerOp)) {
                            addDependenciesForOp(consumerOp, frameIter);
                        }
                    }
                }
                break;
            }
            default:
                throw new UnsupportedOperationException("Unknown or not yet implemented exec type: " + execStep);
        }
    }

    /**
     * Registers in the dependency tracker {@code dt} what op {@code str} needs before it can execute in
     * frame/iteration {@code frameIter}.
     *
     * <ul>
     *   <li>Merge: an OR dependency on the exec steps of its first two inputs (either one suffices).</li>
     *   <li>NextIteration: dependencies on its inputs in {@code frameIter}, registered against the op's step in
     *       the <em>next</em> iteration of the same frame.</li>
     *   <li>Any other op: dependencies on all of its inputs.</li>
     * </ul>
     * In every case the op's control dependencies (looked up as variables) are added as well.
     * Except for NextIteration, nothing is done when the op's step already has dependencies registered.
     *
     * @param str       name of the op
     * @param frameIter frame/iteration in which the op (or, for NextIteration, its inputs) is evaluated
     */
    protected void addDependenciesForOp(String str, FrameIter frameIter) {
        SameDiffOp sameDiffOp = this.sameDiff.getOps().get(str);
        boolean isNextIteration = sameDiffOp.getOp() instanceof NextIteration;
        ExecStep execStep = new ExecStep(ExecType.OP, str, frameIter);

        // Dependencies are added once per op per frame/iteration. NextIteration is exempt: its dependencies are
        // registered under iteration + 1 (below), so the current-iteration step is not a reliable "done" marker.
        if (!isNextIteration && this.dt.hasDependency(execStep)) {
            return;
        }

        List<String> inputsToOp = sameDiffOp.getInputsToOp();
        if (sameDiffOp.getOp() instanceof Merge) {
            // Merge may execute once EITHER of its inputs is available
            Variable first = this.sameDiff.getVariables().get(inputsToOp.get(0));
            Variable second = this.sameDiff.getVariables().get(inputsToOp.get(1));
            this.dt.addOrDependency(execStep,
                    getExecStepForVar(first.getName(), frameIter),
                    getExecStepForVar(second.getName(), frameIter));
        } else {
            if (isNextIteration) {
                // Inputs from iteration i feed NextIteration in iteration i + 1
                FrameIter nextIter = frameIter.m1466clone();
                nextIter.setIteration(nextIter.getIteration() + 1);
                execStep = new ExecStep(ExecType.OP, str, nextIter);
            }
            // An op may have no data inputs (e.g. when it is reached only through control dependencies)
            if (inputsToOp != null) {
                for (String input : inputsToOp) {
                    this.dt.addDependency(execStep, getExecStepForVar(input, frameIter));
                }
            }
        }

        List<String> controlDeps = sameDiffOp.getControlDeps();
        if (controlDeps != null) {
            for (String controlDep : controlDeps) {
                this.dt.addDependency(execStep, getExecStepForVar(controlDep, frameIter));
            }
        }
    }

    /**
     * Returns the exec step that makes variable {@code str} available, as seen from {@code frameIter}.
     *
     * <ul>
     *   <li>VARIABLE / PLACEHOLDER / CONSTANT: a step of that type in frame {@link #OUTER_FRAME}, iteration 0.</li>
     *   <li>Output of a Switch op: {@code SWITCH_L} for output 0, {@code SWITCH_R} for output 1.</li>
     *   <li>Output of a constant Enter op: an {@code OP} step in a copy of {@code frameIter} with iteration 0.
     *       Parent frames reached through a chain of further constant Enter ops also get iteration 0.</li>
     *   <li>Anything else: an {@code OP} step for the producing op in {@code frameIter}.</li>
     * </ul>
     *
     * @param str       variable name
     * @param frameIter frame/iteration from which the variable is being requested
     * @return the producing exec step
     * @throws IllegalStateException if a Switch op's outputs do not contain the variable at index 0 or 1
     */
    protected ExecStep getExecStepForVar(String str, FrameIter frameIter) {
        Variable variable = this.sameDiff.getVariables().get(str);
        VariableType variableType = variable.getVariable().getVariableType();
        if (variableType == VariableType.VARIABLE) {
            return new ExecStep(ExecType.VARIABLE, variable.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
        }
        if (variableType == VariableType.PLACEHOLDER) {
            return new ExecStep(ExecType.PLACEHOLDER, variable.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
        }
        if (variableType == VariableType.CONSTANT) {
            return new ExecStep(ExecType.CONSTANT, variable.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
        }

        String outputOfOp = variable.getOutputOfOp();
        SameDiffOp sameDiffOp = this.sameDiff.getOps().get(outputOfOp);
        if (sameDiffOp.getOp() instanceof Switch) {
            List<String> outputsOfOp = sameDiffOp.getOutputsOfOp();
            int indexOf = outputsOfOp.indexOf(variable.getName());
            if (indexOf == 0) {
                return new ExecStep(ExecType.SWITCH_L, outputOfOp, frameIter);
            }
            if (indexOf == 1) {
                return new ExecStep(ExecType.SWITCH_R, outputOfOp, frameIter);
            }
            throw new IllegalStateException("Expected variable \"" + variable.getName() + "\" to be an output of operation \"" + outputOfOp + "\", but op output variables are: " + outputsOfOp);
        }
        if (!isConstantEnter(sameDiffOp)) {
            return new ExecStep(ExecType.OP, outputOfOp, frameIter);
        }

        // Constant Enter: use iteration 0 rather than the current iteration, and walk up through any chain
        // of constant Enter ops, resetting each corresponding parent frame's iteration to 0 as well.
        FrameIter enterFrame = frameIter.m1466clone();
        enterFrame.setIteration(0);
        String upstreamVar = sameDiffOp.getInputsToOp().get(0);
        FrameIter parentFrame = enterFrame.getParentFrame();
        while (parentFrame != null) {
            String upstreamOpName = this.sameDiff.getVariables().get(upstreamVar).getOutputOfOp();
            if (upstreamOpName == null) {
                break;
            }
            SameDiffOp upstreamOp = this.sameDiff.getOps().get(upstreamOpName);
            // Test the UPSTREAM op's own isConstant flag. The outer Enter is already known to be constant here,
            // so re-testing it would treat any upstream Enter (constant or not) as constant.
            if (!isConstantEnter(upstreamOp)) {
                break;
            }
            // NOTE(review): whether this mutation is visible through the caller's frameIter depends on whether
            // m1466clone() also copies parent frames — confirm.
            parentFrame.setIteration(0);
            parentFrame = parentFrame.getParentFrame();
            upstreamVar = upstreamOp.getInputsToOp().get(0);
        }
        return new ExecStep(ExecType.OP, outputOfOp, enterFrame);
    }

    /** Returns true if {@code op} wraps an {@link Enter} op whose {@code isConstant()} flag is set. */
    private static boolean isConstantEnter(SameDiffOp op) {
        return op.getOp() instanceof Enter && ((Enter) op.getOp()).isConstant();
    }

    /**
     * Populates {@code subgraph}, {@code subgraphOps} and {@code zeroInputOpsInSubgraph} with everything needed
     * to compute the variables in {@code set}.
     *
     * <p>This is a breadth-first walk backwards from the requested variables. The walk follows:
     * <ul>
     *   <li>each variable's producing op,</li>
     *   <li>that op's inputs and control dependencies,</li>
     *   <li>each variable's own control dependencies.</li>
     * </ul>
     * A producing op is recorded in {@code zeroInputOpsInSubgraph} when it has no inputs and the variable
     * has no control dependencies.
     *
     * @param set names of the variables whose values are requested
     */
    protected void initSubgraph(Set<String> set) {
        LinkedList<String> queue = new LinkedList<>(set);
        while (!queue.isEmpty()) {
            String varName = queue.remove();
            String opName = this.sameDiff.getVariableOutputOp(varName) == null
                    ? null
                    : this.sameDiff.getVariableOutputOp(varName).getOwnName();
            // May be null when the op has no inputs (or when the variable has no producing op)
            String[] opInputs = opName == null ? null : this.sameDiff.getInputsForOp(this.sameDiff.getOpById(opName));

            if (!this.subgraph.contains(varName)) {
                List<String> varControlDeps = this.sameDiff.getVariables().get(varName).getControlDeps();
                // Variable control dependencies count as inputs: the variable is not available until they are
                int numInputs = opInputs == null ? 0 : opInputs.length;
                if (varControlDeps != null) {
                    numInputs += varControlDeps.size();
                }
                if (numInputs == 0 && opName != null) {
                    this.zeroInputOpsInSubgraph.add(opName);
                }
                this.subgraph.add(varName);
                if (opName != null) {
                    this.subgraphOps.add(opName);
                }
                enqueueUnvisited(queue, varControlDeps);
            }

            if (opName != null) {
                // To compute this variable, the producing op's inputs and control dependencies are needed
                if (opInputs != null) {
                    for (String input : opInputs) {
                        if (!this.subgraph.contains(input)) {
                            queue.add(input);
                        }
                    }
                }
                enqueueUnvisited(queue, this.sameDiff.getOps().get(opName).getControlDeps());
            }
        }
    }

    /** Appends to {@code queue} every name in {@code names} not yet in {@code subgraph}; a null list is ignored. */
    private void enqueueUnvisited(LinkedList<String> queue, List<String> names) {
        if (names == null) {
            return;
        }
        for (String name : names) {
            if (!this.subgraph.contains(name)) {
                queue.add(name);
            }
        }
    }

    /**
     * Hook that lets subclasses pre-process the placeholder map. This base implementation performs no
     * transformation.
     *
     * @param map the placeholder map, as named by this method
     * @param at  not used by this base implementation
     * @return {@code map} itself (same instance, unmodified)
     */
    protected Map<String, T> preprocessPlaceholders(Map<String, T> map, At at) {
        final Map<String, T> unchanged = map;
        return unchanged;
    }

    /**
     * Hook that lets subclasses post-process an output map. This base implementation performs no
     * transformation.
     *
     * @param map the output map
     * @return {@code map} itself (same instance, unmodified)
     */
    protected Map<String, T> postProcessOutput(Map<String, T> map) {
        final Map<String, T> unchanged = map;
        return unchanged;
    }

    /**
     * Subclass hook returning a value of type {@code T} for the variable named {@code str}.
     *
     * <p>NOTE(review): per the method name this is intended for CONSTANT / VARIABLE type variables; the
     * behavior for other variable types is implementation-defined — confirm against implementations.
     *
     * @param str variable name
     * @return the value associated with the variable
     */
    public abstract T getConstantOrVariable(String str);

    /**
     * Subclass hook returning an op object of type {@code O} for op {@code str} in frame/iteration
     * {@code frameIter}. Per the method name, the implementation also parameterizes the op.
     *
     * <p>NOTE(review): the roles of {@code set}, {@code set2} (both {@code VarId} sets), {@code set3},
     * {@code set4} (both name sets) and {@code map} (name to value) are not visible in this section. They are
     * presumably the op's resolved inputs and the known input values — confirm against callers and
     * implementations.
     *
     * @param str       op name
     * @param frameIter frame/iteration in which the op is to be executed
     * @return the op object to execute
     */
    public abstract O getAndParameterizeOp(String str, FrameIter frameIter, Set<VarId> set, Set<VarId> set2, Set<String> set3, Map<String, T> map, Set<String> set4);

    /**
     * Subclass hook that produces the output values (an array of {@code T}) for op object {@code o} in
     * frame/iteration {@code frameIter}.
     *
     * <p>NOTE(review): the roles of the {@code VarId} sets ({@code set}, {@code set2}) and name sets
     * ({@code set3}, {@code set4}) are not visible in this section — confirm against callers and implementations.
     * {@code list}, {@code at} and {@code multiDataSet} are presumably passed through to listeners.
     *
     * @param o         op object, presumably as returned by {@code getAndParameterizeOp}
     * @param frameIter frame/iteration in which the op executes
     * @param list      listeners (type {@code List<Listener>})
     * @param at        an {@code At} instance supplied by the caller
     * @return the op's output values
     */
    public abstract T[] getOutputs(O o, FrameIter frameIter, Set<VarId> set, Set<VarId> set2, Set<String> set3, List<Listener> list, At at, MultiDataSet multiDataSet, Set<String> set4);

    /**
     * Finds the {@link VarId} whose variable name equals {@code str}. It searches {@code collection} first,
     * then {@code collection2}; a {@code null} collection is skipped.
     *
     * @param str         variable name to look for
     * @param collection  first collection to search (may be {@code null})
     * @param collection2 second collection to search (may be {@code null})
     * @param z           whether to throw when neither collection contains a match
     * @return the first match, or {@code null} if none is found and {@code z} is {@code false}
     * @throws RuntimeException if no match is found and {@code z} is {@code true}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public static VarId lookup(String str, Collection<VarId> collection, Collection<VarId> collection2, boolean z) {
        if (collection != null) {
            VarId inFirst = lookup(str, collection, false);
            if (inFirst != null) {
                return inFirst;
            }
        }
        if (collection2 != null) {
            VarId inSecond = lookup(str, collection2, false);
            if (inSecond != null) {
                return inSecond;
            }
        }
        if (z) {
            throw new RuntimeException("Could not find VarId for input \"" + str + "\"");
        }
        return null;
    }

    /**
     * Returns the first {@link VarId} in {@code collection} whose variable name equals {@code str}.
     *
     * @param str        variable name to look for
     * @param collection collection to search (must not be {@code null})
     * @param z          whether to throw when no match is found
     * @return the first match, or {@code null} if none is found and {@code z} is {@code false}
     * @throws RuntimeException if no match is found and {@code z} is {@code true}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public static VarId lookup(String str, Collection<VarId> collection, boolean z) {
        Iterator<VarId> candidates = collection.iterator();
        while (candidates.hasNext()) {
            VarId candidate = candidates.next();
            if (candidate.getVariable().equals(str)) {
                return candidate;
            }
        }
        if (!z) {
            return null;
        }
        throw new RuntimeException("Could not find VarId to input " + str);
    }

    /**
     * Returns the session's {@code nodeOutputs} map.
     *
     * @return the live internal map (not a copy); changes made through it affect this session
     */
    public Map<VarId, T> getNodeOutputs() {
        return nodeOutputs;
    }

    /**
     * Returns the session's {@code tensorArrays} map.
     *
     * @return the live internal map (not a copy); changes made through it affect this session
     */
    public Map<VarId, List<T>> getTensorArrays() {
        return tensorArrays;
    }
}
