package org.nd4j.autodiff.execution;

import java.io.File;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.autodiff.execution.GraphExecutioner;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.graph.FlatResult;
import org.nd4j.graph.FlatVariable;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.ResultWrapperAbstraction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/execution/NativeGraphExecutioner.class */
public class NativeGraphExecutioner implements GraphExecutioner {
    private static final Logger log = LoggerFactory.getLogger(NativeGraphExecutioner.class);

    /* renamed from: org.nd4j.autodiff.execution.NativeGraphExecutioner$1, reason: invalid class name */
    /* loaded from: input_file:org/nd4j/autodiff/execution/NativeGraphExecutioner$1.class */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$ops$Op$Type = new int[Op.Type.values().length];

        static {
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.SCALAR.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.BROADCAST.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_FLOAT.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_SAME.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_STRICT.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_BOOL.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_FLOAT.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_BOOL.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_SAME.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.INDEXREDUCE.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.CUSTOM.ordinal()] = 11;
            } catch (NoSuchFieldError e11) {
            }
        }
    }

    public GraphExecutioner.Type getExecutionerType() {
        return GraphExecutioner.Type.LOCAL;
    }

    public INDArray[] executeGraph(SameDiff sameDiff) {
        return executeGraph(sameDiff, ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).executionMode(ExecutionMode.SEQUENTIAL).profilingMode(OpExecutioner.ProfilingMode.DISABLED).build());
    }

    public INDArray[] reuseGraph(SameDiff sameDiff, Map<Integer, INDArray> map) {
        throw new UnsupportedOperationException();
    }

    public ByteBuffer convertToFlatBuffers(SameDiff sameDiff, ExecutorConfiguration executorConfiguration, Map<Integer, Node> map) {
        log.info("Configuration: {}", executorConfiguration);
        return sameDiff.asFlatBuffers(executorConfiguration);
    }

    public ByteBuffer convertToFlatBuffers(SameDiff sameDiff, ExecutorConfiguration executorConfiguration) {
        return convertToFlatBuffers(sameDiff, executorConfiguration, new HashMap());
    }

    public INDArray[] executeGraph(SameDiff sameDiff, ExecutorConfiguration executorConfiguration) {
        ByteBuffer convertToFlatBuffers = convertToFlatBuffers(sameDiff, executorConfiguration, new HashMap());
        Pointer bytePointer = new BytePointer(convertToFlatBuffers);
        log.info("Buffer length: {}", Integer.valueOf(convertToFlatBuffers.limit()));
        ResultWrapperAbstraction executeFlatGraph = NativeOpsHolder.getInstance().getDeviceNativeOps().executeFlatGraph(null, bytePointer);
        if (executeFlatGraph == null) {
            throw new ND4JIllegalStateException("Graph execution failed");
        }
        FlatResult rootAsFlatResult = FlatResult.getRootAsFlatResult(new PagedPointer(executeFlatGraph.pointer(), executeFlatGraph.size()).asBytePointer().asByteBuffer());
        log.info("VarMap: {}", sameDiff.variableMap());
        INDArray[] iNDArrayArr = new INDArray[rootAsFlatResult.variablesLength()];
        for (int i = 0; i < rootAsFlatResult.variablesLength(); i++) {
            FlatVariable variables = rootAsFlatResult.variables(i);
            INDArray createFromFlatArray = Nd4j.createFromFlatArray(variables.ndarray());
            iNDArrayArr[i] = createFromFlatArray;
            if (variables.name() != null && sameDiff.variableMap().containsKey(variables.name())) {
                sameDiff.associateArrayWithVariable(createFromFlatArray, (SDVariable) sameDiff.variableMap().get(variables.name()));
            } else if (sameDiff.variableMap().get(variables.name()) != null) {
                sameDiff.associateArrayWithVariable(createFromFlatArray, sameDiff.getVariable(variables.name()));
            } else {
                log.warn("Unknown variable received: [{}]", variables.name());
            }
        }
        NativeOpsHolder.getInstance().getDeviceNativeOps().deleteResultWrapper(executeFlatGraph);
        return iNDArrayArr;
    }

    public static long getOpNum(String str, Op.Type type) {
        if (type == Op.Type.CUSTOM) {
            return ((CustomOpDescriptor) Nd4j.getExecutioner().getCustomOperations().get(str.toLowerCase())).getHash();
        }
        try {
            return DifferentialFunctionClassHolder.getInstance().getInstance(str).opNum();
        } catch (Exception e) {
            throw new RuntimeException("Could not find op number for operation: [" + str + "]", e);
        }
    }

    public static byte getFlatOpType(Op.Type type) {
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[type.ordinal()]) {
            case 1:
                return (byte) 10;
            case 2:
                return (byte) 12;
            case 3:
                return (byte) 0;
            case 4:
                return (byte) 1;
            case 5:
                return (byte) 3;
            case 6:
                return (byte) 2;
            case 7:
                return (byte) 5;
            case 8:
                return (byte) 8;
            case 9:
                return (byte) 6;
            case 10:
                return (byte) 9;
            case 11:
                return (byte) 21;
            default:
                throw new UnsupportedOperationException("Unknown op type passed in: " + type);
        }
    }

    public INDArray[] executeGraph(int i, SDVariable... sDVariableArr) {
        return new INDArray[0];
    }

    public int registerGraph(SameDiff sameDiff) {
        return 0;
    }

    public INDArray[] importProto(File file) {
        throw new UnsupportedOperationException("Not implemented yet");
    }
}
