package org.nd4j.linalg.cpu.nativecpu.ops;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import lombok.NonNull;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.nd4j.compression.impl.AbstractCompressor;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.GradientOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.ShapeOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.accum.MatchCondition;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.compression.CompressionType;
import org.nd4j.linalg.cpu.nativecpu.CpuTADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jCpu;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.class */
public class NativeOpExecutioner extends DefaultOpExecutioner {
    private static final Logger log = LoggerFactory.getLogger(NativeOpExecutioner.class);
    // Name of the environment variable the constructor reads to toggle native debug mode.
    private static final String DEBUG_ENABLED = "ND4J_DEBUG";
    // Name of the environment variable the constructor reads to toggle native verbose mode.
    private static final String VERBOSE = "ND4J_VERBOSE";
    // Native backend handle; every native kernel invocation in this class goes through it.
    private NativeOps loop = NativeOpsHolder.getInstance().getDeviceNativeOps();
    // Source of constant (immutable) buffers, e.g. for the dimension arrays passed to native reductions.
    private ConstantHandler constantHandler = Nd4j.getConstantHandler();
    // Supplies TAD shape-info / offset buffers; initialised with `loop` and `constantHandler` in the constructor.
    private CpuTADManager tadManager = new CpuTADManager();
    // Per-thread pointer caches keyed by Integer.
    // NOTE(review): not used in this chunk; presumably for custom-op execution. Confirm against the rest of the file.
    private ThreadLocal<Map<Integer, PointerPointer>> inputShapes = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, PointerPointer>> inputBuffers = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, PointerPointer>> outputShapes = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, PointerPointer>> outputBuffers = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, DoublePointer>> tArgsPointer = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, ShortPointer>> halfArgsPointer = new ThreadLocal<>();
    // Custom-op descriptors. Null here; NOTE(review): presumably populated lazily elsewhere.
    protected Map<String, CustomOpDescriptor> customOps = null;
    // Per-thread scratch array of extra native pointers (TAD info, offsets).
    // The exec methods allocate it lazily with 32 slots.
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();
    // Per-thread caches for aggregate/batch execution.
    // NOTE(review): not used in this chunk; memoryBlocks is presumably keyed by aggregate op number. Confirm.
    private ThreadLocal<Map<Integer, Pointer>> batchPointers = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, AggregateMemoryBlock>> memoryBlocks = new ThreadLocal<>();

    /* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner$AggregateMemoryBlock.class */
    /**
     * Bundle of native buffers pre-sized from the limits reported by an {@link Aggregate} op.
     *
     * <p>Identity is defined solely by the op number: two blocks built for the same
     * {@code opNum} are equal and share a hash code.
     */
    private static class AggregateMemoryBlock {
        private List<IntPointer> intArrays;
        private IntPointer indexingPointer;
        private Pointer realArgumentsPointer;
        private PointerPointer shapesPointer;
        private PointerPointer argumentsPointer;
        private PointerPointer arraysPointer;
        private final int opNum;

        /**
         * Allocates every buffer using the {@code max*} capacities of {@code aggregate}.
         *
         * <p>The real-argument buffer is double-typed when the global data type is DOUBLE and
         * float-typed otherwise.
         *
         * @param aggregate op whose limits size the buffers; must not be null
         * @throws NullPointerException if {@code aggregate} is null
         */
        private AggregateMemoryBlock(@NonNull Aggregate aggregate) {
            if (aggregate == null) {
                throw new NullPointerException("op");
            }
            this.intArrays = new ArrayList<>();
            this.opNum = aggregate.opNum();

            for (int slot = 0; slot < aggregate.maxIntArrays(); slot++) {
                this.intArrays.add(new IntPointer(aggregate.maxIntArraySize()));
            }

            this.indexingPointer = new IntPointer(aggregate.maxIndexArguments());

            if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
                this.realArgumentsPointer = new DoublePointer(aggregate.maxRealArguments());
            } else {
                this.realArgumentsPointer = new FloatPointer(aggregate.maxRealArguments());
            }

            this.shapesPointer = new PointerPointer(aggregate.maxShapes());
            this.argumentsPointer = new PointerPointer(aggregate.maxArguments());
            this.arraysPointer = new PointerPointer(aggregate.maxIntArrays());
        }

        /** Equal iff {@code other} is exactly this class and was built for the same op number. */
        @Override
        public boolean equals(Object other) {
            if (other == this) {
                return true;
            }
            if (other == null || other.getClass() != getClass()) {
                return false;
            }
            AggregateMemoryBlock that = (AggregateMemoryBlock) other;
            return that.opNum == this.opNum;
        }

        /** Hash is the op number, consistent with {@link #equals(Object)}. */
        @Override
        public int hashCode() {
            return opNum;
        }

        public List<IntPointer> getIntArrays() { return intArrays; }

        public IntPointer getIndexingPointer() { return indexingPointer; }

        public Pointer getRealArgumentsPointer() { return realArgumentsPointer; }

        public PointerPointer getShapesPointer() { return shapesPointer; }

        public PointerPointer getArgumentsPointer() { return argumentsPointer; }

        public PointerPointer getArraysPointer() { return arraysPointer; }

        public int getOpNum() { return opNum; }

        public void setIntArrays(List<IntPointer> value) { intArrays = value; }

        public void setIndexingPointer(IntPointer value) { indexingPointer = value; }

        public void setRealArgumentsPointer(Pointer value) { realArgumentsPointer = value; }

        public void setShapesPointer(PointerPointer value) { shapesPointer = value; }

        public void setArgumentsPointer(PointerPointer value) { argumentsPointer = value; }

        public void setArraysPointer(PointerPointer value) { arraysPointer = value; }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("NativeOpExecutioner.AggregateMemoryBlock(");
            text.append("intArrays=").append(getIntArrays());
            text.append(", indexingPointer=").append(getIndexingPointer());
            text.append(", realArgumentsPointer=").append(getRealArgumentsPointer());
            text.append(", shapesPointer=").append(getShapesPointer());
            text.append(", argumentsPointer=").append(getArgumentsPointer());
            text.append(", arraysPointer=").append(getArraysPointer());
            text.append(", opNum=").append(getOpNum());
            text.append(")");
            return text.toString();
        }
    }

    public NativeOpExecutioner() {
        this.tadManager.init(this.loop, this.constantHandler);
        if (System.getenv(DEBUG_ENABLED) != null) {
            try {
                this.loop.enableDebugMode(Boolean.parseBoolean(System.getenv(DEBUG_ENABLED)));
            } catch (Exception e) {
                log.error("Can't parse {}: [{}]", DEBUG_ENABLED, System.getenv(DEBUG_ENABLED));
            }
        }
        if (System.getenv(VERBOSE) != null) {
            try {
                this.loop.enableVerboseMode(Boolean.parseBoolean(System.getenv(VERBOSE)));
            } catch (Exception e2) {
                log.error("Can't parse {}: [{}]", VERBOSE, System.getenv(VERBOSE));
            }
        }
    }

    /**
     * Runs {@code op} through the executor matching its runtime type and returns the same instance.
     *
     * <p>Types are checked in a fixed order: scalar, gradient, transform, accumulation, index
     * accumulation, broadcast, shape, random. The first match wins. An op matching none of them
     * is returned untouched after the compression check.
     */
    public Op exec(Op op) {
        checkForCompression(op);

        if (op instanceof ScalarOp) {
            exec((ScalarOp) op);
            return op;
        }
        if (op instanceof GradientOp) {
            // Gradient ops carry their own execution logic.
            op.exec();
            return op;
        }
        if (op instanceof TransformOp) {
            exec((TransformOp) op);
            return op;
        }
        if (op instanceof Accumulation) {
            exec((Accumulation) op);
            return op;
        }
        if (op instanceof IndexAccumulation) {
            exec((IndexAccumulation) op);
            return op;
        }
        if (op instanceof BroadcastOp) {
            BroadcastOp broadcast = (BroadcastOp) op;
            exec(broadcast, broadcast.getDimension());
            return op;
        }
        if (op instanceof ShapeOp) {
            exec((ShapeOp) op);
            return op;
        }
        if (op instanceof RandomOp) {
            exec((RandomOp) op, Nd4j.getRandom());
        }
        return op;
    }

    /**
     * Executes an {@link IndexAccumulation} natively along the given dimensions.
     *
     * <p>Dimension handling:
     * <ul>
     *   <li>No dimensions (null or empty) means {@link Nd4jCpu#MAX_DIMENSION}, i.e. a whole-array
     *       reduction.</li>
     *   <li>Negative dimensions are wrapped by the rank of x.</li>
     *   <li>Covering every axis collapses to a whole-array reduction.</li>
     * </ul>
     *
     * <p>Result array z:
     * <ul>
     *   <li>If z is null or aliases x, a new z is allocated with the reduced shape (vectors
     *       promoted to rank 2) and filled with the op's zero value.</li>
     *   <li>If z is scalar, the scalar native kernel runs. Its result is truncated to {@code int},
     *       stored as the op's final result and written to {@code z[0]}.</li>
     *   <li>Otherwise the TAD-wise native kernel writes directly into z.</li>
     * </ul>
     *
     * @param indexAccumulation op to execute
     * @param iArr dimensions to reduce along; may be null or empty
     * @return the op's z array
     * @throws IllegalStateException if a caller-supplied z does not have the expected shape
     */
    public INDArray exec(IndexAccumulation indexAccumulation, int... iArr) {
        if (iArr == null || iArr.length == 0) {
            iArr = new int[]{Nd4jCpu.MAX_DIMENSION};
        }
        checkForCompression(indexAccumulation);
        validateDataType(Nd4j.dataType(), indexAccumulation);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        int[] normalizeAxis = Shape.normalizeAxis(indexAccumulation.x().rank(), iArr);
        for (int i = 0; i < normalizeAxis.length; i++) {
            if (normalizeAxis[i] < 0) {
                normalizeAxis[i] += indexAccumulation.x().rank();
            }
        }
        // Reducing along every axis is a whole-array reduction. normalizeAxis is not modified
        // again afterwards, so this single collapse is sufficient.
        if (normalizeAxis.length == indexAccumulation.x().rank()) {
            normalizeAxis = new int[]{Nd4jCpu.MAX_DIMENSION};
        }
        int[] removeIndex = Shape.wholeArrayDimension(normalizeAxis) ? new int[]{1, 1} : ArrayUtil.removeIndex(indexAccumulation.x().shape(), normalizeAxis);
        // Results are at least rank 2: a 1-D result becomes a row or column vector.
        if (removeIndex.length == 1) {
            removeIndex = normalizeAxis[0] == 0 ? new int[]{1, removeIndex[0]} : new int[]{removeIndex[0], 1};
        } else if (removeIndex.length == 0) {
            removeIndex = new int[]{1, 1};
        }
        if (indexAccumulation.z() == null || indexAccumulation.x() == indexAccumulation.z()) {
            indexAccumulation.setZ(indexAccumulation.x().data().dataType() == DataBuffer.Type.DOUBLE ? Nd4j.valueArrayOf(removeIndex, indexAccumulation.zeroDouble()) : Nd4j.valueArrayOf(removeIndex, indexAccumulation.zeroFloat()));
        } else if (!Arrays.equals(removeIndex, indexAccumulation.z().shape())) {
            throw new IllegalStateException("Z array shape does not match expected return type for op " + indexAccumulation + ": expected shape " + Arrays.toString(removeIndex) + ", z.shape()=" + Arrays.toString(indexAccumulation.z().shape()));
        }
        IntPointer addressPointer = this.constantHandler.getConstantBuffer(normalizeAxis).addressPointer();
        Pair<DataBuffer, DataBuffer> tADOnlyShapeInfo = this.tadManager.getTADOnlyShapeInfo(indexAccumulation.x(), normalizeAxis);
        Pointer tadShapeInfo = ((DataBuffer) tADOnlyShapeInfo.getFirst()).addressPointer();
        DataBuffer tadOffsets = (DataBuffer) tADOnlyShapeInfo.getSecond();
        // Slot 0: TAD shape info. Slot 1: TAD offsets (null when the manager provides none).
        PointerPointer put = this.extraz.get().put(new Pointer[]{tadShapeInfo, tadOffsets == null ? null : tadOffsets.addressPointer()});
        long profilingHookIn = profilingHookIn(indexAccumulation, new DataBuffer[]{(DataBuffer) tADOnlyShapeInfo.getFirst()});
        if (indexAccumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (indexAccumulation.z().isScalar()) {
                int execIndexReduceScalarDouble = (int) this.loop.execIndexReduceScalarDouble(put, indexAccumulation.opNum(), indexAccumulation.x().data().addressPointer(), indexAccumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(indexAccumulation));
                indexAccumulation.setFinalResult(execIndexReduceScalarDouble);
                indexAccumulation.z().putScalar(0, execIndexReduceScalarDouble);
            } else {
                this.loop.execIndexReduceDouble(put, indexAccumulation.opNum(), indexAccumulation.x().data().addressPointer(), indexAccumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(indexAccumulation), indexAccumulation.z().data().addressPointer(), indexAccumulation.z().shapeInfoDataBuffer().addressPointer(), addressPointer, normalizeAxis.length);
            }
        } else if (indexAccumulation.z().isScalar()) {
            int execIndexReduceScalarFloat = (int) this.loop.execIndexReduceScalarFloat(put, indexAccumulation.opNum(), indexAccumulation.x().data().addressPointer(), indexAccumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(indexAccumulation));
            indexAccumulation.setFinalResult(execIndexReduceScalarFloat);
            indexAccumulation.z().putScalar(0, execIndexReduceScalarFloat);
        } else {
            this.loop.execIndexReduceFloat(put, indexAccumulation.opNum(), indexAccumulation.x().data().addressPointer(), indexAccumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(indexAccumulation), indexAccumulation.z().data().addressPointer(), indexAccumulation.z().shapeInfoDataBuffer().addressPointer(), addressPointer, normalizeAxis.length);
        }
        profilingHookOut(indexAccumulation, profilingHookIn);
        return indexAccumulation.z();
    }

    /**
     * Executes an {@link Accumulation} (reduction) natively along the given dimensions.
     *
     * <p>Dimensions are validated against the rank of the widest operand and wrapped when
     * negative. Covering every axis collapses to {@link Nd4jCpu#MAX_DIMENSION}.
     *
     * <p>Early exit: for a vector x with no y, where the reduced shape has the same length as x
     * and more than one element, the op's {@code noOp()} result is returned.
     *
     * <p>Result array z:
     * <ul>
     *   <li>If z is null or aliases x, it is allocated. For complex (all-distances) ops the shape
     *       is {@code [tads(x), tads(y)]}; otherwise it is the reduced shape filled with the op's
     *       zero value.</li>
     *   <li>If z was supplied, its length is validated and it is reset to the zero value.</li>
     * </ul>
     *
     * <p>Kernel dispatch by data type:
     * <ul>
     *   <li>{@link Variance} uses the summary-stats kernels.</li>
     *   <li>Ops that are not REDUCE3, or that have no y, use the reduce kernels.</li>
     *   <li>Complex REDUCE3 ops use the reduce3-all kernels.</li>
     *   <li>Other REDUCE3 ops use the reduce3 kernels.</li>
     * </ul>
     * Each kernel family except reduce3-all has a scalar variant used when z is scalar.
     *
     * @param accumulation op to execute
     * @param iArr dimensions to reduce along
     * @return the result array (the op's z)
     * @throws ND4JIllegalStateException on out-of-range dimensions, mismatched TAD sizes/counts,
     *     or a wrongly-sized caller-supplied z
     */
    public INDArray exec(Accumulation accumulation, int... iArr) {
        INDArray valueArrayOf;
        int[] normalizeAxis = Shape.normalizeAxis(accumulation.x().rank(), iArr);
        validateDataType(Nd4j.dataType(), accumulation);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        int[] maxShape = Shape.getMaxShape(new INDArray[]{accumulation.x(), accumulation.y()});
        // Reject dimensions beyond the rank of the widest operand. Integer.MAX_VALUE is exempt;
        // presumably it is the whole-array marker.
        for (int i = 0; i < normalizeAxis.length; i++) {
            if (normalizeAxis[i] >= maxShape.length && normalizeAxis[i] != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(normalizeAxis) + " contains element that higher then rank of op.X: [" + accumulation.x().rank() + "]");
            }
        }
        for (int i2 = 0; i2 < normalizeAxis.length; i2++) {
            if (normalizeAxis[i2] < 0) {
                normalizeAxis[i2] += accumulation.x().rank();
            }
        }
        if (normalizeAxis.length == accumulation.x().rank()) {
            normalizeAxis = new int[]{Nd4jCpu.MAX_DIMENSION};
        }
        int[] removeIndex = Shape.wholeArrayDimension(normalizeAxis) ? new int[]{1, 1} : ArrayUtil.removeIndex(maxShape, normalizeAxis);
        // Results are at least rank 2: a 1-D result becomes a row or column vector.
        if (removeIndex.length == 1) {
            removeIndex = normalizeAxis[0] == 0 ? new int[]{1, removeIndex[0]} : new int[]{removeIndex[0], 1};
        } else if (removeIndex.length == 0) {
            removeIndex = new int[]{1, 1};
        }
        if (accumulation.x().isVector() && accumulation.x().length() == ArrayUtil.prod(removeIndex) && ArrayUtil.prodLong(removeIndex) > 1 && accumulation.y() == null) {
            return accumulation.noOp();
        }
        if (accumulation.z() == null || accumulation.z() == accumulation.x()) {
            if (accumulation.isComplexAccumulation()) {
                valueArrayOf = Nd4j.create(accumulation.x().tensorssAlongDimension(normalizeAxis), accumulation.y().tensorssAlongDimension(normalizeAxis));
            } else {
                if (accumulation.y() != null) {
                    if (accumulation.x().lengthLong() != accumulation.y().lengthLong()) {
                        long lengthLong = accumulation.x().lengthLong() / accumulation.x().tensorssAlongDimension(normalizeAxis);
                        if (lengthLong != accumulation.y().length()) {
                            throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution: (x TAD size = " + lengthLong + ", y size = " + accumulation.y().lengthLong());
                        }
                    } else if (accumulation.x().tensorssAlongDimension(normalizeAxis) != accumulation.y().tensorssAlongDimension(normalizeAxis)) {
                        throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + Arrays.toString(accumulation.x().shape()) + ", y shape = " + Arrays.toString(accumulation.y().shape()) + ", dimension = " + Arrays.toString(normalizeAxis) + ")");
                    }
                }
                valueArrayOf = accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE ? Nd4j.valueArrayOf(removeIndex, accumulation.zeroDouble()) : Nd4j.valueArrayOf(removeIndex, accumulation.zeroFloat());
            }
            accumulation.setZ(valueArrayOf);
        } else {
            if (!accumulation.isComplexAccumulation() && accumulation.z().lengthLong() != ArrayUtil.prodLong(removeIndex)) {
                throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(accumulation.z().shape()) + "] doesn't match expected [" + Arrays.toString(removeIndex) + "]");
            }
            if (accumulation.isComplexAccumulation()) {
                int tensorssAlongDimension = accumulation.x().tensorssAlongDimension(normalizeAxis);
                int tensorssAlongDimension2 = accumulation.y().tensorssAlongDimension(normalizeAxis);
                if (accumulation.z().lengthLong() != tensorssAlongDimension * tensorssAlongDimension2) {
                    throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(accumulation.z().shape()) + "] doesn't match expected [" + (tensorssAlongDimension * tensorssAlongDimension2) + "]");
                }
            }
            // A reused z is reset to the op's identity value before the kernel accumulates into it.
            if (accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                accumulation.z().assign(Double.valueOf(accumulation.zeroDouble()));
            } else {
                accumulation.z().assign(Float.valueOf(accumulation.zeroFloat()));
            }
            valueArrayOf = accumulation.z();
        }
        Pair<DataBuffer, DataBuffer> tADOnlyShapeInfo = this.tadManager.getTADOnlyShapeInfo(accumulation.x(), normalizeAxis);
        Pair<DataBuffer, DataBuffer> pair = null;
        Pointer addressPointer = ((DataBuffer) tADOnlyShapeInfo.getFirst()).addressPointer();
        DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
        Pointer addressPointer2 = dataBuffer == null ? null : dataBuffer.addressPointer();
        // True when y has exactly the length of one x TAD (pairwise TAD-vs-y execution).
        boolean yMatchesTadLength = false;
        if (accumulation.y() != null && accumulation.x().tensorAlongDimension(0, normalizeAxis).lengthLong() == accumulation.y().lengthLong()) {
            yMatchesTadLength = true;
        }
        if (accumulation.isComplexAccumulation()) {
            pair = this.tadManager.getTADOnlyShapeInfo(accumulation.y(), normalizeAxis);
            if (accumulation.x().tensorAlongDimension(0, normalizeAxis).lengthLong() != accumulation.y().tensorAlongDimension(0, normalizeAxis).lengthLong()) {
                throw new ND4JIllegalStateException("Impossible to issue AllDistances operation: TAD lengths mismatch along given dimension");
            }
        }
        PointerPointer pointerPointer = this.extraz.get();
        Pointer[] pointerArr = new Pointer[3];
        pointerArr[0] = addressPointer;
        pointerArr[1] = addressPointer2;
        pointerArr[2] = yMatchesTadLength ? addressPointer2 : null;
        PointerPointer put = pointerPointer.put(pointerArr);
        // Keep the start value so profilingHookOut runs for reductions too. Previously it was
        // discarded, so profiling and NaN/Inf panic checks never ran for this op type.
        long profilingStart = profilingHookIn(accumulation, new DataBuffer[]{(DataBuffer) tADOnlyShapeInfo.getFirst()});
        IntPointer addressPointer3 = this.constantHandler.getConstantBuffer(normalizeAxis).addressPointer();
        if (accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (accumulation instanceof Variance) {
                if (valueArrayOf.isScalar()) {
                    valueArrayOf.putScalar(0, this.loop.execSummaryStatsScalarDouble(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), true));
                } else {
                    this.loop.execSummaryStatsDouble(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), accumulation.z().data().addressPointer(), accumulation.z().shapeInfoDataBuffer().addressPointer(), addressPointer3, normalizeAxis.length, ((Variance) accumulation).isBiasCorrected());
                }
            } else if (accumulation.y() == null || accumulation.getOpType() != Op.Type.REDUCE3) {
                if (valueArrayOf.isScalar()) {
                    valueArrayOf.putScalar(0, this.loop.execReduceScalarDouble(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation)));
                } else {
                    this.loop.execReduceDouble(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), accumulation.z().data().addressPointer(), accumulation.z().shapeInfoDataBuffer().addressPointer(), addressPointer3, normalizeAxis.length);
                }
            } else if (accumulation.isComplexAccumulation()) {
                this.loop.execReduce3AllDouble(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), accumulation.y().data().addressPointer(), accumulation.y().shapeInfoDataBuffer().addressPointer(), accumulation.z().data().addressPointer(), accumulation.z().shapeInfoDataBuffer().addressPointer(), addressPointer3, normalizeAxis.length, ((DataBuffer) tADOnlyShapeInfo.getFirst()).addressPointer(), new LongPointerWrapper(((DataBuffer) tADOnlyShapeInfo.getSecond()).addressPointer()), ((DataBuffer) pair.getFirst()).addressPointer(), new LongPointerWrapper(((DataBuffer) pair.getSecond()).addressPointer()));
            } else if (valueArrayOf.isScalar()) {
                valueArrayOf.putScalar(0, this.loop.execReduce3ScalarDouble(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), accumulation.y().data().addressPointer(), accumulation.y().shapeInfoDataBuffer().addressPointer()));
            } else {
                this.loop.execReduce3Double(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), accumulation.y().data().addressPointer(), accumulation.y().shapeInfoDataBuffer().addressPointer(), accumulation.z().data().addressPointer(), accumulation.z().shapeInfoDataBuffer().addressPointer(), addressPointer3, normalizeAxis.length);
            }
        } else if (accumulation instanceof Variance) {
            Variance variance = (Variance) accumulation;
            if (valueArrayOf.isScalar()) {
                valueArrayOf.putScalar(0, this.loop.execSummaryStatsScalarFloat(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), variance.isBiasCorrected()));
            } else {
                this.loop.execSummaryStatsFloat(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), accumulation.z().data().addressPointer(), accumulation.z().shapeInfoDataBuffer().addressPointer(), addressPointer3, normalizeAxis.length, variance.isBiasCorrected());
            }
        } else if (accumulation.y() == null || accumulation.getOpType() != Op.Type.REDUCE3) {
            if (valueArrayOf.isScalar()) {
                valueArrayOf.putScalar(0, this.loop.execReduceScalarFloat(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation)));
            } else {
                this.loop.execReduceFloat(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), accumulation.z().data().addressPointer(), accumulation.z().shapeInfoDataBuffer().addressPointer(), addressPointer3, normalizeAxis.length);
            }
        } else if (accumulation.isComplexAccumulation()) {
            this.loop.execReduce3AllFloat(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), accumulation.y().data().addressPointer(), accumulation.y().shapeInfoDataBuffer().addressPointer(), accumulation.z().data().addressPointer(), accumulation.z().shapeInfoDataBuffer().addressPointer(), addressPointer3, normalizeAxis.length, ((DataBuffer) tADOnlyShapeInfo.getFirst()).addressPointer(), new LongPointerWrapper(((DataBuffer) tADOnlyShapeInfo.getSecond()).addressPointer()), ((DataBuffer) pair.getFirst()).addressPointer(), new LongPointerWrapper(((DataBuffer) pair.getSecond()).addressPointer()));
        } else if (valueArrayOf.isScalar()) {
            valueArrayOf.putScalar(0, this.loop.execReduce3ScalarFloat(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), accumulation.y().data().addressPointer(), accumulation.y().shapeInfoDataBuffer().addressPointer()));
        } else {
            this.loop.execReduce3Float(put, accumulation.opNum(), accumulation.x().data().addressPointer(), accumulation.x().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(accumulation), accumulation.y().data().addressPointer(), accumulation.y().shapeInfoDataBuffer().addressPointer(), accumulation.z().data().addressPointer(), accumulation.z().shapeInfoDataBuffer().addressPointer(), addressPointer3, normalizeAxis.length);
        }
        profilingHookOut(accumulation, profilingStart);
        return valueArrayOf;
    }

    /**
     * Runs a scalar op along the given dimensions using TAD information for both x and z.
     *
     * <p>The buffer of {@code scalarOp.y()} is passed as the scalars argument (presumably one
     * scalar per TAD; confirm against the native kernel). Only FLOAT and DOUBLE data are
     * dispatched; for any other data type this method does nothing.
     *
     * <p>TAD offsets may be absent. As everywhere else in this class, a missing offsets buffer is
     * passed to native code as a null pointer instead of being dereferenced.
     *
     * @param scalarOp op to execute; x, y and z must be non-null
     * @param iArr dimensions along which the scalars are applied
     */
    private void invoke(ScalarOp scalarOp, int[] iArr) {
        int[] normalizeAxis = Shape.normalizeAxis(scalarOp.x().rank(), iArr);

        Pair<DataBuffer, DataBuffer> xTad = this.tadManager.getTADOnlyShapeInfo(scalarOp.x(), normalizeAxis);
        Pointer xTadShapeInfo = ((DataBuffer) xTad.getFirst()).addressPointer();
        DataBuffer xOffsets = (DataBuffer) xTad.getSecond();
        Pointer xTadOffsets = xOffsets == null ? null : xOffsets.addressPointer();

        Pair<DataBuffer, DataBuffer> zTad = this.tadManager.getTADOnlyShapeInfo(scalarOp.z(), normalizeAxis);
        Pointer zTadShapeInfo = ((DataBuffer) zTad.getFirst()).addressPointer();
        DataBuffer zOffsets = (DataBuffer) zTad.getSecond();
        Pointer zTadOffsets = zOffsets == null ? null : zOffsets.addressPointer();

        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // Slots: x TAD shape, x TAD offsets, z TAD shape, z TAD offsets.
        PointerPointer put = this.extraz.get().put(new Pointer[]{xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets});
        if (scalarOp.x().data().dataType() == DataBuffer.Type.FLOAT) {
            this.loop.execScalarFloat(put, scalarOp.opNum(), scalarOp.x().data().addressPointer(), scalarOp.x().shapeInfoDataBuffer().addressPointer(), scalarOp.z().data().addressPointer(), scalarOp.z().shapeInfoDataBuffer().addressPointer(), scalarOp.y().data().addressPointer(), getPointerForExtraArgs(scalarOp), Nd4j.getConstantHandler().getConstantBuffer(normalizeAxis).addressPointer(), normalizeAxis.length);
        } else if (scalarOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            this.loop.execScalarDouble(put, scalarOp.opNum(), scalarOp.x().data().addressPointer(), scalarOp.x().shapeInfoDataBuffer().addressPointer(), scalarOp.z().data().addressPointer(), scalarOp.z().shapeInfoDataBuffer().addressPointer(), scalarOp.y().data().addressPointer(), getPointerForExtraArgs(scalarOp), Nd4j.getConstantHandler().getConstantBuffer(normalizeAxis).addressPointer(), normalizeAxis.length);
        }
        // NOTE(review): other data types are silently ignored here; consider failing loudly.
    }

    /**
     * Executes a {@link ScalarOp} natively.
     *
     * <p>Falls back to the Java implementation when x is complex or the executioner is in JAVA
     * mode. When the op carries dimensions it is delegated to {@link #invoke(ScalarOp, int[])}.
     * Otherwise the element-wise-stride kernel is used when both x and z have a positive
     * element-wise stride and the op is not "exec special"; the shape-info kernel handles every
     * other case. Data types other than DOUBLE take the float kernels.
     *
     * <p>{@code profilingHookOut} is called on every native path, including the dimensional one,
     * so profiling and panic checks stay balanced with {@code profilingHookIn}.
     *
     * @throws ND4JIllegalStateException if x and z lengths differ
     */
    private void exec(ScalarOp scalarOp) {
        if ((scalarOp.x() instanceof IComplexNDArray) || executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec(scalarOp);
            return;
        }
        long profilingHookIn = profilingHookIn(scalarOp);
        validateDataType(Nd4j.dataType(), scalarOp);
        if (scalarOp.x().lengthLong() != scalarOp.z().lengthLong()) {
            throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: [" + Arrays.toString(scalarOp.x().shapeInfoDataBuffer().asInt()) + "] != [" + Arrays.toString(scalarOp.z().shapeInfoDataBuffer().asInt()) + "]");
        }
        if (scalarOp.getDimension() != null) {
            invoke(scalarOp, scalarOp.getDimension());
            profilingHookOut(scalarOp, profilingHookIn);
            return;
        }
        // The strided kernels need a positive element-wise stride on both x and z.
        boolean useShapeInfoKernel = scalarOp.x().elementWiseStride() < 1 || scalarOp.isExecSpecial() || scalarOp.z().elementWiseStride() < 1;
        if (scalarOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (useShapeInfoKernel) {
                this.loop.execScalarDouble((PointerPointer) null, scalarOp.opNum(), scalarOp.x().data().addressPointer(), scalarOp.x().shapeInfoDataBuffer().addressPointer(), scalarOp.z().data().addressPointer(), scalarOp.z().shapeInfoDataBuffer().addressPointer(), scalarOp.scalar().doubleValue(), getPointerForExtraArgs(scalarOp));
            } else {
                this.loop.execScalarDouble((PointerPointer) null, scalarOp.opNum(), scalarOp.x().data().addressPointer(), scalarOp.x().elementWiseStride(), scalarOp.z().data().addressPointer(), scalarOp.z().elementWiseStride(), scalarOp.scalar().doubleValue(), getPointerForExtraArgs(scalarOp), scalarOp.n());
            }
        } else if (useShapeInfoKernel) {
            this.loop.execScalarFloat((PointerPointer) null, scalarOp.opNum(), scalarOp.x().data().addressPointer(), scalarOp.x().shapeInfoDataBuffer().addressPointer(), scalarOp.z().data().addressPointer(), scalarOp.z().shapeInfoDataBuffer().addressPointer(), scalarOp.scalar().floatValue(), getPointerForExtraArgs(scalarOp));
        } else {
            this.loop.execScalarFloat((PointerPointer) null, scalarOp.opNum(), scalarOp.x().data().addressPointer(), scalarOp.x().elementWiseStride(), scalarOp.z().data().addressPointer(), scalarOp.z().elementWiseStride(), scalarOp.scalar().floatValue(), getPointerForExtraArgs(scalarOp), scalarOp.n());
        }
        profilingHookOut(scalarOp, profilingHookIn);
    }

    /**
     * Returns the native address of the op's extra-arguments buffer, or {@code null} when the op
     * has no extra arguments.
     */
    private Pointer getPointerForExtraArgs(Op op) {
        return op.extraArgs() == null ? null : op.extraArgsDataBuff().addressPointer();
    }

    /**
     * Runs a transform op (unary when {@code y} is null, pairwise otherwise) through the native loop.
     *
     * <p>Dispatch: double kernels when x's buffer is DOUBLE, float kernels for any other type. Within
     * each precision, the shape-info kernel is used when strides/ordering are incompatible or the op is
     * "exec special"; otherwise the element-wise-stride kernel over {@code n()} elements is used.
     *
     * @param transformOp the op to execute; results are written into its z
     */
    private void exec(TransformOp transformOp) {
        long profilingHookIn;
        validateDataType(Nd4j.dataType(), transformOp);
        // Per-thread scratch PointerPointer passed to native code as "extra pointers".
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        PointerPointer pointerPointer = this.extraz.get();
        // For op #7 a scalar y is expanded to a full array of x's shape filled with y's value.
        if (transformOp.opNum() == 7 && transformOp.y() != null && transformOp.y().isScalar()) {
            transformOp.setY(Nd4j.valueArrayOf(transformOp.x().shape(), transformOp.y().getDouble(0)));
        }
        if (transformOp.opNum() != 41 || transformOp.extraArgs() == null) {
            profilingHookIn = profilingHookIn(transformOp);
        } else {
            // Op #41 with extra args: extraArgs[0] is a dimension count, followed by that many dimensions.
            // The TAD shape info / offsets for z along those dimensions go into extra-pointer slots 0 and 1.
            int[] iArr = new int[((Integer) transformOp.extraArgs()[0]).intValue()];
            for (int i = 0; i < iArr.length; i++) {
                iArr[i] = ((Integer) transformOp.extraArgs()[i + 1]).intValue();
            }
            Pair<DataBuffer, DataBuffer> tADOnlyShapeInfo = this.tadManager.getTADOnlyShapeInfo(transformOp.z(), iArr);
            Pointer addressPointer = ((DataBuffer) tADOnlyShapeInfo.getFirst()).addressPointer();
            DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
            Pointer addressPointer2 = dataBuffer == null ? null : dataBuffer.addressPointer();
            pointerPointer.put(0L, addressPointer);
            pointerPointer.put(1L, addressPointer2);
            profilingHookIn = profilingHookIn(transformOp, new DataBuffer[]{(DataBuffer) tADOnlyShapeInfo.getFirst()});
        }
        if (transformOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            if (transformOp.y() != null) {
                int elementWiseStride = transformOp.x().elementWiseStride();
                int elementWiseStride2 = transformOp.y().elementWiseStride();
                int elementWiseStride3 = transformOp.z().elementWiseStride();
                boolean isRowVector = transformOp.x().isRowVector();
                boolean isRowVector2 = transformOp.y().isRowVector();
                boolean isRowVector3 = transformOp.z().isRowVector();
                // Shape-info kernel unless all three operands are row vectors sharing one positive EWS.
                if ((elementWiseStride < 1 || elementWiseStride2 < 1 || elementWiseStride != elementWiseStride2 || transformOp.isExecSpecial() || transformOp.x().ordering() != transformOp.y().ordering() || transformOp.x().ordering() != transformOp.z().ordering()) && !(elementWiseStride >= 1 && elementWiseStride2 == elementWiseStride && elementWiseStride3 == elementWiseStride && isRowVector && isRowVector2 && isRowVector3)) {
                    this.loop.execPairwiseTransformDouble(pointerPointer, transformOp.opNum(), transformOp.x().data().addressPointer(), transformOp.x().shapeInfoDataBuffer().addressPointer(), transformOp.y().data().addressPointer(), transformOp.y().shapeInfoDataBuffer().addressPointer(), transformOp.z().data().addressPointer(), transformOp.z().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(transformOp));
                } else {
                    this.loop.execPairwiseTransformDouble(pointerPointer, transformOp.opNum(), transformOp.x().data().addressPointer(), elementWiseStride, transformOp.y().data().addressPointer(), elementWiseStride2, transformOp.z().data().addressPointer(), elementWiseStride3, getPointerForExtraArgs(transformOp), transformOp.n());
                }
            // NOTE(review): isExecSpecial() is tested twice here; harmless but redundant.
            } else if (transformOp.x().elementWiseStride() < 1 || transformOp.isExecSpecial() || transformOp.isExecSpecial() || transformOp.x().ordering() != transformOp.z().ordering()) {
                this.loop.execTransformDouble(pointerPointer, transformOp.opNum(), transformOp.x().data().addressPointer(), transformOp.x().shapeInfoDataBuffer().addressPointer(), transformOp.z().data().addressPointer(), transformOp.z().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(transformOp));
            } else {
                this.loop.execTransformDouble(pointerPointer, transformOp.opNum(), transformOp.x().data().addressPointer(), transformOp.x().elementWiseStride(), transformOp.z().data().addressPointer(), transformOp.z().elementWiseStride(), getPointerForExtraArgs(transformOp), transformOp.n());
            }
        // Non-DOUBLE buffers (any other type) take the float kernels below.
        } else if (transformOp.y() != null) {
            int elementWiseStride4 = transformOp.x().elementWiseStride();
            int elementWiseStride5 = transformOp.y().elementWiseStride();
            int elementWiseStride6 = transformOp.z().elementWiseStride();
            boolean isRowVector4 = transformOp.x().isRowVector();
            boolean isRowVector5 = transformOp.y().isRowVector();
            boolean isRowVector6 = transformOp.z().isRowVector();
            if ((elementWiseStride4 < 1 || elementWiseStride5 < 1 || elementWiseStride4 != elementWiseStride5 || transformOp.isExecSpecial() || transformOp.x().ordering() != transformOp.y().ordering() || transformOp.x().ordering() != transformOp.z().ordering()) && !(elementWiseStride4 >= 1 && elementWiseStride5 == elementWiseStride4 && elementWiseStride6 == elementWiseStride4 && isRowVector4 && isRowVector5 && isRowVector6)) {
                this.loop.execPairwiseTransformFloat(pointerPointer, transformOp.opNum(), transformOp.x().data().addressPointer(), transformOp.x().shapeInfoDataBuffer().addressPointer(), transformOp.y().data().addressPointer(), transformOp.y().shapeInfoDataBuffer().addressPointer(), transformOp.z().data().addressPointer(), transformOp.z().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(transformOp));
            } else {
                this.loop.execPairwiseTransformFloat(pointerPointer, transformOp.opNum(), transformOp.x().data().addressPointer(), elementWiseStride4, transformOp.y().data().addressPointer(), elementWiseStride5, transformOp.z().data().addressPointer(), elementWiseStride6, getPointerForExtraArgs(transformOp), transformOp.n());
            }
        } else if (transformOp.x().elementWiseStride() < 1 || transformOp.isExecSpecial() || transformOp.x().ordering() != transformOp.z().ordering()) {
            this.loop.execTransformFloat(pointerPointer, transformOp.opNum(), transformOp.x().data().addressPointer(), transformOp.x().shapeInfoDataBuffer().addressPointer(), transformOp.z().data().addressPointer(), transformOp.z().shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(transformOp));
        } else {
            this.loop.execTransformFloat(pointerPointer, transformOp.opNum(), transformOp.x().data().addressPointer(), transformOp.x().elementWiseStride(), transformOp.z().data().addressPointer(), transformOp.z().elementWiseStride(), getPointerForExtraArgs(transformOp), transformOp.n());
        }
        profilingHookOut(transformOp, profilingHookIn);
    }

    /**
     * Executes a broadcast op along the given dimensions through the native loop.
     *
     * <p>TAD shape info and offsets for both x and z along the (normalized) dimensions are passed to
     * native code via the per-thread extra-pointer buffer, in the order x-shape, x-offsets, z-shape,
     * z-offsets. Double kernels are used when x's buffer is DOUBLE, float kernels otherwise.
     *
     * @param broadcastOp the op to execute; results are written into its z
     * @param iArr dimensions to broadcast along; {@code null} means {@code Nd4jCpu.MAX_DIMENSION}
     * @return the op's z array
     * @throws ND4JIllegalStateException if a dimension is not smaller than x's rank
     */
    public INDArray exec(BroadcastOp broadcastOp, int... iArr) {
        long profilingHookIn = profilingHookIn(broadcastOp);
        if (iArr == null) {
            iArr = new int[]{Nd4jCpu.MAX_DIMENSION};
        }
        int[] normalizeAxis = Shape.normalizeAxis(broadcastOp.x().rank(), iArr);
        validateDataType(Nd4j.dataType(), broadcastOp);
        for (int axis : normalizeAxis) {
            if (axis >= broadcastOp.x().rank() && axis != Integer.MAX_VALUE) {
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(normalizeAxis) + " contains element that higher then rank of op.X: [" + broadcastOp.x().rank() + "]");
            }
        }
        Pair<DataBuffer, DataBuffer> xTad = this.tadManager.getTADOnlyShapeInfo(broadcastOp.x(), normalizeAxis);
        Pointer xTadShapeInfo = ((DataBuffer) xTad.getFirst()).addressPointer();
        // Offsets may be absent (the transform path guards the same case), so pass null instead of throwing NPE.
        DataBuffer xTadOffsetsBuffer = (DataBuffer) xTad.getSecond();
        Pointer xTadOffsets = xTadOffsetsBuffer == null ? null : xTadOffsetsBuffer.addressPointer();
        Pair<DataBuffer, DataBuffer> zTad = this.tadManager.getTADOnlyShapeInfo(broadcastOp.z(), normalizeAxis);
        Pointer zTadShapeInfo = ((DataBuffer) zTad.getFirst()).addressPointer();
        DataBuffer zTadOffsetsBuffer = (DataBuffer) zTad.getSecond();
        Pointer zTadOffsets = zTadOffsetsBuffer == null ? null : zTadOffsetsBuffer.addressPointer();
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        PointerPointer put = this.extraz.get().put(new Pointer[]{xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets});
        IntPointer addressPointer5 = this.constantHandler.getConstantBuffer(normalizeAxis).addressPointer();
        if (broadcastOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            this.loop.execBroadcastDouble(put, broadcastOp.opNum(), broadcastOp.x().data().addressPointer(), broadcastOp.x().shapeInfoDataBuffer().addressPointer(), broadcastOp.y().data().addressPointer(), broadcastOp.y().shapeInfoDataBuffer().addressPointer(), broadcastOp.z().data().addressPointer(), broadcastOp.z().shapeInfoDataBuffer().addressPointer(), addressPointer5, normalizeAxis.length);
        } else {
            this.loop.execBroadcastFloat(put, broadcastOp.opNum(), broadcastOp.x().data().addressPointer(), broadcastOp.x().shapeInfoDataBuffer().addressPointer(), broadcastOp.y().data().addressPointer(), broadcastOp.y().shapeInfoDataBuffer().addressPointer(), broadcastOp.z().data().addressPointer(), broadcastOp.z().shapeInfoDataBuffer().addressPointer(), addressPointer5, normalizeAxis.length);
        }
        // Close the profiling window opened above, as every other exec path does.
        profilingHookOut(broadcastOp, profilingHookIn);
        return broadcastOp.z();
    }

    /**
     * Executes an index accumulation (e.g. arg-max style ops) over the whole of x natively.
     *
     * <p>Complex inputs and JAVA execution mode are delegated to the superclass. A z that is null or
     * aliased to x is replaced by a fresh scalar. The resulting index is stored both as the op's final
     * result and into z.
     *
     * @param indexAccumulation the op to execute
     */
    private void exec(IndexAccumulation indexAccumulation) {
        boolean delegateToJava = indexAccumulation.x() instanceof IComplexNDArray
                || executionMode() == OpExecutioner.ExecutionMode.JAVA;
        if (delegateToJava) {
            super.exec(indexAccumulation);
            return;
        }

        INDArray target = indexAccumulation.z();
        if (target == indexAccumulation.x() || target == null) {
            indexAccumulation.setZ(Nd4j.scalar(0.0d));
        }

        long hookStart = profilingHookIn(indexAccumulation);
        validateDataType(Nd4j.dataType(), indexAccumulation);

        INDArray source = indexAccumulation.x();
        int index;
        if (source.data().dataType() != DataBuffer.Type.DOUBLE) {
            index = (int) this.loop.execIndexReduceScalarFloat((PointerPointer) null, indexAccumulation.opNum(), source.data().addressPointer(), source.shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(indexAccumulation));
        } else {
            index = (int) this.loop.execIndexReduceScalarDouble((PointerPointer) null, indexAccumulation.opNum(), source.data().addressPointer(), source.shapeInfoDataBuffer().addressPointer(), getPointerForExtraArgs(indexAccumulation));
        }
        indexAccumulation.setFinalResult(index);

        indexAccumulation.z().assign(Integer.valueOf(indexAccumulation.getFinalResult()));
        profilingHookOut(indexAccumulation, hookStart);
    }

    /**
     * Executes a full (scalar-result) accumulation natively and stores the result in element 0 of z.
     *
     * <p>Complex inputs and JAVA execution mode are delegated to the superclass; ops flagged as
     * "exec special" run their own {@code exec()}. Otherwise the op goes to the native summary-stats
     * kernel (for {@link Variance}), the reduce3 kernel (when y is set and the op type is REDUCE3), or
     * the plain reduce kernel, in double precision when x's buffer is DOUBLE and float otherwise.
     *
     * @param accumulation the op to run; its z is replaced by a fresh scalar when null or aliased to x
     * @throws ND4JIllegalStateException if a REDUCE3 op has x and y of different lengths
     */
    private void exec(Accumulation accumulation) {
        if ((accumulation.x() instanceof IComplexNDArray) || executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec(accumulation);
            return;
        }
        if (accumulation.isExecSpecial()) {
            accumulation.exec();
            return;
        }
        long profilingHookIn = profilingHookIn(accumulation);
        validateDataType(Nd4j.dataType(), accumulation);
        // The result is written with putScalar below: a null z would NPE and an x-aliased z would clobber input.
        if (accumulation.z() == null || accumulation.z() == accumulation.x()) {
            accumulation.setZ(Nd4j.scalar(0.0d));
        }
        boolean isReduce3 = accumulation.y() != null && accumulation.getOpType() == Op.Type.REDUCE3;
        if (isReduce3 && accumulation.x().lengthLong() != accumulation.y().lengthLong()) {
            throw new ND4JIllegalStateException("X and Y operands should have equall lengths. X length: " + accumulation.x().lengthLong() + "; Y length: " + accumulation.y().lengthLong());
        }
        INDArray x = accumulation.x();
        Pointer extraArgs = getPointerForExtraArgs(accumulation);
        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            if (accumulation instanceof Variance) {
                // Honour the op's bias-correction flag; the double path previously hard-coded `true`,
                // giving different results from the float path for uncorrected variance.
                accumulation.setFinalResult(this.loop.execSummaryStatsScalarDouble((PointerPointer) null, accumulation.opNum(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(), extraArgs, ((Variance) accumulation).isBiasCorrected()));
            } else if (isReduce3) {
                accumulation.setFinalResult(this.loop.execReduce3ScalarDouble((PointerPointer) null, accumulation.opNum(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(), extraArgs, accumulation.y().data().addressPointer(), accumulation.y().shapeInfoDataBuffer().addressPointer()));
            } else {
                accumulation.setFinalResult(this.loop.execReduceScalarDouble((PointerPointer) null, accumulation.opNum(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(), extraArgs));
            }
            accumulation.z().putScalar(0, accumulation.getFinalResult().doubleValue());
        } else {
            if (accumulation instanceof Variance) {
                accumulation.setFinalResult(this.loop.execSummaryStatsScalarFloat((PointerPointer) null, accumulation.opNum(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(), extraArgs, ((Variance) accumulation).isBiasCorrected()));
            } else if (isReduce3) {
                accumulation.setFinalResult(this.loop.execReduce3ScalarFloat((PointerPointer) null, accumulation.opNum(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(), extraArgs, accumulation.y().data().addressPointer(), accumulation.y().shapeInfoDataBuffer().addressPointer()));
            } else {
                accumulation.setFinalResult(this.loop.execReduceScalarFloat((PointerPointer) null, accumulation.opNum(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(), extraArgs));
            }
            accumulation.z().putScalar(0, accumulation.getFinalResult().floatValue());
        }
        profilingHookOut(accumulation, profilingHookIn);
    }

    /**
     * Returns the per-thread native buffer used to marshal a batch of aggregates, keyed by op number.
     *
     * <p>The buffer is allocated lazily as an {@code IntPointer} sized from the batch sample's required
     * memory (bytes / 4) the first time an op number is seen on the current thread, then reused.
     *
     * @param batch the batch whose op number selects the buffer
     * @return the cached (or newly allocated) buffer
     */
    protected <T extends Aggregate> Pointer getPointer(Batch<T> batch) {
        if (this.batchPointers.get() == null) {
            this.batchPointers.set(new HashMap());
        }
        Map<Integer, Pointer> buffersByOp = this.batchPointers.get();
        Integer opKey = batch.opNum();
        if (!buffersByOp.containsKey(opKey)) {
            buffersByOp.put(opKey, new IntPointer(batch.getSample().getRequiredBatchMemorySize() / 4));
        }
        return buffersByOp.get(opKey);
    }

    /**
     * Marshals a batch of aggregates into a single per-thread native buffer and executes it natively.
     *
     * <p>Buffer layout (section start offsets, each section holding {@code Batch.getBatchLimit()} slots):
     * <ol>
     *   <li>5 int header values per aggregate (counts of args, shapes, index args, real args, int arrays);</li>
     *   <li>index arguments (int units);</li>
     *   <li>int-array arguments (int units);</li>
     *   <li>real arguments (float or double units, depending on {@code Nd4j.dataType()});</li>
     *   <li>argument data pointers, then shape-buffer pointers (pointer units).</li>
     * </ol>
     * NOTE(review): the int-to-pointer unit conversion divides by 2, which assumes 8-byte pointers.
     *
     * @param batch the aggregates to execute
     * @throws UnsupportedOperationException if the global data type is neither FLOAT nor DOUBLE
     */
    public <T extends Aggregate> void exec(Batch<T> batch) {
        // Fail fast so the shared buffer is never partially filled for an unsupported precision.
        boolean isFloat = Nd4j.dataType() == DataBuffer.Type.FLOAT;
        boolean isDouble = Nd4j.dataType() == DataBuffer.Type.DOUBLE;
        if (!isFloat && !isDouble) {
            throw new UnsupportedOperationException("Half precision isn't supported on CPU");
        }
        IntPointer pointer = (IntPointer) getPointer(batch);
        int batchLimit = Batch.getBatchLimit();
        int maxIntArrays = batch.getSample().maxIntArrays();
        int maxIntArraySize = batch.getSample().maxIntArraySize();

        // Section offsets; each is expressed in the element units of the view used to write it.
        int indexPos = 5 * batchLimit;
        int intArraysPos = indexPos + (batch.getSample().maxIndexArguments() * batchLimit);
        int realPos = (intArraysPos + (maxIntArrays * maxIntArraySize * batchLimit)) / (isDouble ? 2 : 1);
        int argsPos = (realPos + (batch.getSample().maxRealArguments() * batchLimit)) / (isDouble ? 1 : 2);
        int shapesPos = argsPos + (batch.getSample().maxArguments() * batchLimit);

        // Typed views over the same native memory, created once instead of once per aggregate.
        FloatPointer floatPointer = isFloat ? new FloatPointer(pointer) : null;
        DoublePointer doublePointer = isDouble ? new DoublePointer(pointer) : null;
        PointerPointer pointerPointer = new PointerPointer(pointer);
        int intArraysStride = maxIntArrays * maxIntArraySize;

        for (int i = 0; i < batch.getNumAggregates(); i++) {
            Aggregate aggregate = (Aggregate) batch.getAggregates().get(i);

            int headerPos = i * 5;
            pointer.put(headerPos, aggregate.getArguments().size());
            pointer.put(headerPos + 1, aggregate.getShapes().size());
            pointer.put(headerPos + 2, aggregate.getIndexingArguments().size());
            pointer.put(headerPos + 3, aggregate.getRealArguments().size());
            pointer.put(headerPos + 4, aggregate.getIntArrayArguments().size());

            int indexBase = indexPos + (i * batch.getSample().maxIndexArguments());
            for (int e = 0; e < aggregate.getIndexingArguments().size(); e++) {
                pointer.put(indexBase + e, ((Integer) aggregate.getIndexingArguments().get(e)).intValue());
            }

            for (int e = 0; e < aggregate.getIntArrayArguments().size(); e++) {
                int[] values = (int[]) aggregate.getIntArrayArguments().get(e);
                if (values == null) {
                    continue;
                }
                int arrayBase = intArraysPos + (i * intArraysStride) + (e * maxIntArraySize);
                for (int v = 0; v < values.length; v++) {
                    pointer.put(arrayBase + v, values[v]);
                }
            }

            int realBase = realPos + (i * aggregate.maxRealArguments());
            for (int e = 0; e < aggregate.getRealArguments().size(); e++) {
                Number value = (Number) aggregate.getRealArguments().get(e);
                if (isFloat) {
                    floatPointer.put(realBase + e, value.floatValue());
                } else {
                    doublePointer.put(realBase + e, value.doubleValue());
                }
            }

            int argsBase = argsPos + (i * batch.getSample().maxArguments());
            for (int e = 0; e < aggregate.getArguments().size(); e++) {
                INDArray argument = (INDArray) aggregate.getArguments().get(e);
                if (argument != null) {
                    pointerPointer.put(argsBase + e, argument.data().addressPointer());
                }
            }

            int shapesBase = shapesPos + (i * batch.getSample().maxShapes());
            for (int e = 0; e < aggregate.getShapes().size(); e++) {
                DataBuffer shape = (DataBuffer) aggregate.getShapes().get(e);
                if (shape != null) {
                    pointerPointer.put(shapesBase + e, shape.addressPointer());
                }
            }
        }

        if (isFloat) {
            this.loop.execAggregateBatchFloat((PointerPointer) null, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), pointer);
        } else {
            this.loop.execAggregateBatchDouble((PointerPointer) null, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), pointer);
        }
    }

    /**
     * Groups the given aggregates into batches and executes each batch in turn.
     *
     * @param list aggregates to execute; an empty list is a no-op
     */
    public void exec(List<Aggregate> list) {
        if (list.isEmpty()) {
            return;
        }
        for (Object batch : Batch.getBatches(list)) {
            exec((Batch) batch);
        }
    }

    /**
     * Executes a single aggregate natively using a per-thread, per-op-number memory block.
     *
     * <p>Arguments, shape buffers, indexing arguments, real arguments and int arrays are copied into the
     * block's native buffers, then dispatched to the float or double aggregate kernel.
     *
     * @param aggregate the aggregate to execute
     * @throws UnsupportedOperationException if the global data type is neither FLOAT nor DOUBLE
     * @throws RuntimeException if a non-null shape buffer is not of INT type
     */
    public void exec(Aggregate aggregate) {
        // Fail fast: real arguments below are only written correctly for FLOAT or DOUBLE.
        if (Nd4j.dataType() != DataBuffer.Type.FLOAT && Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
            throw new UnsupportedOperationException("Half precision isn't supported on CPU");
        }
        if (this.memoryBlocks.get() == null) {
            this.memoryBlocks.set(new HashMap());
        }
        Integer opKey = Integer.valueOf(aggregate.opNum());
        if (this.memoryBlocks.get().get(opKey) == null) {
            this.memoryBlocks.get().put(opKey, new AggregateMemoryBlock(aggregate));
        }
        AggregateMemoryBlock aggregateMemoryBlock = this.memoryBlocks.get().get(opKey);
        int size = aggregate.getArguments().size();
        int size2 = aggregate.getIndexingArguments().size();
        int size3 = aggregate.getRealArguments().size();
        int size4 = aggregate.getShapes().size();
        int size5 = aggregate.getIntArrayArguments().size();
        PointerPointer argumentsPointer = aggregateMemoryBlock.getArgumentsPointer();
        PointerPointer arraysPointer = aggregateMemoryBlock.getArraysPointer();
        for (int i = 0; i < size; i++) {
            INDArray argument = (INDArray) aggregate.getArguments().get(i);
            argumentsPointer.put(i, argument == null ? null : argument.data().addressPointer());
        }
        PointerPointer shapesPointer = aggregateMemoryBlock.getShapesPointer();
        for (int i2 = 0; i2 < size4; i2++) {
            DataBuffer shape = (DataBuffer) aggregate.getShapes().get(i2);
            // Null shapes are passed through as null pointers, so check for null before dereferencing.
            if (shape != null && shape.dataType() != DataBuffer.Type.INT) {
                throw new RuntimeException("ShapeBuffers should have INT data opType");
            }
            shapesPointer.put(i2, shape == null ? null : shape.addressPointer());
        }
        IntPointer indexingPointer = aggregateMemoryBlock.getIndexingPointer();
        for (int i3 = 0; i3 < size2; i3++) {
            indexingPointer.put(i3, ((Integer) aggregate.getIndexingArguments().get(i3)).intValue());
        }
        for (int i4 = 0; i4 < size3; i4++) {
            if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
                aggregateMemoryBlock.getRealArgumentsPointer().put(i4, ((Number) aggregate.getRealArguments().get(i4)).floatValue());
            } else {
                aggregateMemoryBlock.getRealArgumentsPointer().put(i4, ((Number) aggregate.getRealArguments().get(i4)).doubleValue());
            }
        }
        for (int i5 = 0; i5 < size5; i5++) {
            IntPointer intPointer = aggregateMemoryBlock.getIntArrays().get(i5);
            int[] values = (int[]) aggregate.getIntArrayArguments().get(i5);
            intPointer.put(values, 0, values.length);
            arraysPointer.put(i5, intPointer);
        }
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            this.loop.execAggregateFloat((PointerPointer) null, aggregate.opNum(), argumentsPointer, size, shapesPointer, size4, indexingPointer, size2, arraysPointer, size5, aggregateMemoryBlock.getRealArgumentsPointer(), size3);
        } else {
            this.loop.execAggregateDouble((PointerPointer) null, aggregate.opNum(), argumentsPointer, size, shapesPointer, size4, indexingPointer, size2, arraysPointer, size5, aggregateMemoryBlock.getRealArgumentsPointer(), size3);
        }
    }

    /**
     * Extends the base environment description with CPU-backend details: backend name, OpenMP and
     * BLAS thread counts, BLAS vendor, and remaining JavaCPP-trackable native memory.
     *
     * @return the populated properties
     */
    public Properties getEnvironmentInformation() {
        Properties info = super.getEnvironmentInformation();
        info.put("backend", "CPU");
        info.put("omp.threads", this.loop.ompGetMaxThreads());
        info.put("blas.threads", Nd4j.factory().blas().getMaxThreads());
        info.put("blas.vendor", Nd4j.factory().blas().getBlasVendor().toString());
        long freeNativeBytes = Pointer.maxBytes() - Pointer.totalBytes();
        info.put("memory.free", freeNativeBytes);
        return info;
    }

    /**
     * Executes a random op using the global default random generator.
     *
     * @param randomOp the op to execute
     * @return the op's z array
     */
    public INDArray exec(RandomOp randomOp) {
        Random defaultRandom = Nd4j.getRandom();
        return exec(randomOp, defaultRandom);
    }

    /**
     * Executes a random op natively with the given generator state.
     *
     * <p>Chooses the native overload by which operands are present: x, y and z (three-operand);
     * x and z (two-operand); otherwise z only.
     *
     * @param randomOp the op to execute; results go into its z
     * @param random generator whose native state buffer is used
     * @return the op's z array
     * @throws IllegalStateException if the generator has no native state buffer
     * @throws UnsupportedOperationException if the global data type is neither FLOAT nor DOUBLE
     */
    public INDArray exec(RandomOp randomOp, Random random) {
        if (random.getStateBuffer() == null) {
            throw new IllegalStateException("You should use one of NativeRandom classes for NativeOperations execution");
        }
        long profilingHookIn = profilingHookIn(randomOp);
        validateDataType(Nd4j.dataType(), randomOp);
        boolean isFloat = Nd4j.dataType() == DataBuffer.Type.FLOAT;
        // Previously any other type silently skipped execution and returned an unfilled z.
        if (!isFloat && Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
            throw new UnsupportedOperationException("Half precision isn't supported on CPU");
        }
        INDArray x = randomOp.x();
        INDArray y = randomOp.y();
        INDArray z = randomOp.z();
        if (x != null && y != null && z != null) {
            if (isFloat) {
                this.loop.execRandomFloat((PointerPointer) null, randomOp.opNum(), random.getStatePointer(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(), y.data().addressPointer(), y.shapeInfoDataBuffer().addressPointer(), z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), randomOp.extraArgsDataBuff().addressPointer());
            } else {
                this.loop.execRandomDouble((PointerPointer) null, randomOp.opNum(), random.getStatePointer(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(), y.data().addressPointer(), y.shapeInfoDataBuffer().addressPointer(), z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), randomOp.extraArgsDataBuff().addressPointer());
            }
        } else if (x != null && z != null) {
            if (isFloat) {
                this.loop.execRandomFloat((PointerPointer) null, randomOp.opNum(), random.getStatePointer(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(), z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), randomOp.extraArgsDataBuff().addressPointer());
            } else {
                this.loop.execRandomDouble((PointerPointer) null, randomOp.opNum(), random.getStatePointer(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(), z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), randomOp.extraArgsDataBuff().addressPointer());
            }
        } else {
            if (isFloat) {
                this.loop.execRandomFloat((PointerPointer) null, randomOp.opNum(), random.getStatePointer(), z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), randomOp.extraArgsDataBuff().addressPointer());
            } else {
                this.loop.execRandomDouble((PointerPointer) null, randomOp.opNum(), random.getStatePointer(), z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), randomOp.extraArgsDataBuff().addressPointer());
            }
        }
        profilingHookOut(randomOp, profilingHookIn);
        return randomOp.z();
    }

    /**
     * @return the TAD manager used by this executioner to build tensor-along-dimension shape info
     */
    public TADManager getTADManager() {
        TADManager manager = this.tadManager;
        return manager;
    }

    /**
     * Threshold-encodes the input without limiting the number of encoded elements.
     *
     * @param iNDArray array to encode
     * @param d threshold
     * @return the encoded array, or {@code null} when fewer than 2 elements meet the threshold
     */
    public INDArray thresholdEncode(INDArray iNDArray, double d) {
        Integer noLimit = null;
        return thresholdEncode(iNDArray, d, noLimit);
    }

    /**
     * Threshold-encodes the input into a new INT buffer with a 4-int header.
     *
     * <p>Header: [0] number of encoded elements, [1] original length, [2] float bits of the
     * threshold, [3] 0. The native conversion operates on the input's buffer, which is then tagged as
     * living on the host.
     *
     * @param iNDArray array to encode
     * @param d threshold; elements with |value| &gt;= d are counted
     * @param num optional cap on the number of encoded elements ({@code null} = no cap)
     * @return the encoded array, or {@code null} when fewer than 2 elements meet the threshold
     */
    public INDArray thresholdEncode(INDArray iNDArray, double d, Integer num) {
        MatchCondition condition = new MatchCondition(iNDArray, Conditions.absGreaterThanOrEqual(Double.valueOf(d)));
        int matches = Nd4j.getExecutioner().exec(condition, new int[]{Nd4jCpu.MAX_DIMENSION}).getInt(new int[]{0});
        if (matches < 2) {
            return null;
        }
        int encodedCount = num == null ? matches : Math.min(matches, num.intValue());

        DataBuffer source = iNDArray.data();
        long originalBytes = source.length() * Nd4j.sizeOfDataType(source.dataType());
        int encodedInts = encodedCount + 4;

        DataBuffer encoded;
        if (Nd4j.getMemoryManager().getCurrentWorkspace() != null) {
            encoded = Nd4j.getDataBufferFactory().createInt(4 + encodedCount, false, Nd4j.getMemoryManager().getCurrentWorkspace());
        } else {
            encoded = Nd4j.getDataBufferFactory().createInt(4 + encodedCount, false);
        }
        encoded.put(0L, encodedCount);
        encoded.put(1L, (int) source.length());
        encoded.put(2L, Float.floatToIntBits((float) d));
        encoded.put(3L, 0);

        CompressionDescriptor descriptor = new CompressionDescriptor();
        descriptor.setCompressedLength(encodedInts * 4);
        descriptor.setOriginalLength(originalBytes);
        descriptor.setOriginalElementSize(Nd4j.sizeOfDataType(source.dataType()));
        descriptor.setNumberOfElements(source.length());
        descriptor.setCompressionAlgorithm("THRESHOLD");
        descriptor.setCompressionType(CompressionType.LOSSLESS);

        Nd4j.getNDArrayFactory().convertDataEx(AbstractCompressor.getBufferTypeEx(source), source.addressPointer(), DataBuffer.TypeEx.THRESHOLD, encoded.addressPointer(), source.length());
        Nd4j.getAffinityManager().tagLocation(source, AffinityManager.Location.HOST);
        return Nd4j.createArrayFromShapeBuffer(encoded, iNDArray.shapeInfoDataBuffer());
    }

    /**
     * Decodes a threshold-encoded INT array into {@code iNDArray2} via the native type converter.
     *
     * @param iNDArray encoded array (INT buffer with a header whose element [1] is the original length)
     * @param iNDArray2 target array; its length must equal the stored original length
     * @return {@code iNDArray2}
     * @throws ND4JIllegalStateException if the encoded buffer is not INT or the lengths differ
     */
    public INDArray thresholdDecode(INDArray iNDArray, INDArray iNDArray2) {
        DataBuffer encoded = iNDArray.data();
        if (encoded.dataType() != DataBuffer.Type.INT) {
            throw new ND4JIllegalStateException("thresholdEncoded array should have dataType of INT");
        }
        // Header elements 0 and 2 are read but not used here; element 1 holds the original length.
        encoded.getInt(0L);
        long originalLength = encoded.getInt(1L);
        encoded.getInt(2L);
        long targetLength = iNDArray2.lengthLong();
        if (targetLength != originalLength) {
            throw new ND4JIllegalStateException("originalLength [" + originalLength + "] stored in encoded array doesn't match target length [" + iNDArray2.lengthLong() + "]");
        }
        DataBuffer targetBuffer = iNDArray2.data();
        this.loop.convertTypes((PointerPointer) null, DataBuffer.TypeEx.THRESHOLD.ordinal(), encoded.addressPointer(), iNDArray2.length(), AbstractCompressor.getBufferTypeEx(targetBuffer).ordinal(), targetBuffer.addressPointer());
        return iNDArray2;
    }

    /**
     * Bitmap-encodes {@code iNDArray} into the INT target buffer.
     *
     * <p>The target's 4-int header is written first: [0] and [1] the input length, [2] float bits of
     * the threshold, [3] 1. The native encoder then fills the rest.
     *
     * @param iNDArray input (FLOAT or DOUBLE)
     * @param iNDArray2 INT target whose buffer length must be {@code length / 16 + 5}
     * @param d threshold (passed to native code as a float)
     * @return the value returned by the native encoder
     * @throws ND4JIllegalStateException on a wrong target length or type
     * @throws UnsupportedOperationException for input types other than FLOAT/DOUBLE
     */
    public long bitmapEncode(INDArray iNDArray, INDArray iNDArray2, double d) {
        long length = iNDArray.lengthLong();
        long expectedTargetLength = (length / 16) + 5;
        if (iNDArray2.data().length() != expectedTargetLength) {
            throw new ND4JIllegalStateException("Length of target array should be " + expectedTargetLength);
        }
        if (iNDArray2.data().dataType() != DataBuffer.Type.INT) {
            throw new ND4JIllegalStateException("Target array should have INT dataType");
        }
        DataBuffer header = iNDArray2.data();
        header.put(0L, (int) length);
        header.put(1L, (int) length);
        header.put(2L, Float.floatToIntBits((float) d));
        header.put(3L, 1);

        DataBuffer.Type sourceType = iNDArray.data().dataType();
        if (sourceType == DataBuffer.Type.FLOAT) {
            return this.loop.encodeBitmapFloat((PointerPointer) null, iNDArray.data().addressPointer(), length, header.addressPointer(), (float) d);
        }
        if (sourceType == DataBuffer.Type.DOUBLE) {
            return this.loop.encodeBitmapDouble((PointerPointer) null, iNDArray.data().addressPointer(), length, header.addressPointer(), (float) d);
        }
        throw new UnsupportedOperationException("HALF precision isn't supported on CPU yet");
    }

    /**
     * Decodes a bitmap-encoded array into {@code iNDArray2}.
     *
     * <p>Previously only FLOAT targets were decoded; a DOUBLE target was silently returned untouched.
     * Both precisions are now handled, and other types are rejected like in {@code bitmapEncode}.
     *
     * @param iNDArray encoded INT array
     * @param iNDArray2 FLOAT or DOUBLE target
     * @return {@code iNDArray2}
     * @throws UnsupportedOperationException for target types other than FLOAT/DOUBLE
     */
    public INDArray bitmapDecode(INDArray iNDArray, INDArray iNDArray2) {
        DataBuffer.Type targetType = iNDArray2.data().dataType();
        if (targetType == DataBuffer.Type.FLOAT) {
            this.loop.decodeBitmapFloat((PointerPointer) null, iNDArray.data().addressPointer(), iNDArray2.length(), iNDArray2.data().addressPointer());
        } else if (targetType == DataBuffer.Type.DOUBLE) {
            this.loop.decodeBitmapDouble((PointerPointer) null, iNDArray.data().addressPointer(), iNDArray2.length(), iNDArray2.data().addressPointer());
        } else {
            throw new UnsupportedOperationException("HALF precision isn't supported on CPU yet");
        }
        return iNDArray2;
    }

    /**
     * Lazily parses and caches the descriptors of all native custom ops.
     *
     * <p>The native string is a ';'-separated list of entries, each
     * {@code name:hash:numInputs:numOutputs:allowsInplace(1/0):numTArgs:numIArgs}. When the native
     * side reports nothing, a warning is logged and an empty map is cached.
     *
     * @return an unmodifiable map from op name to descriptor
     */
    public synchronized Map<String, CustomOpDescriptor> getCustomOperations() {
        if (this.customOps != null) {
            return this.customOps;
        }
        String serialized = this.loop.getAllCustomOps();
        if (serialized == null || serialized.isEmpty()) {
            log.warn("No customs ops available!");
            this.customOps = Collections.emptyMap();
            return this.customOps;
        }
        Map<String, CustomOpDescriptor> descriptors = new HashMap<>();
        for (String entry : serialized.split(";")) {
            if (entry == null || entry.isEmpty()) {
                continue;
            }
            String[] fields = entry.split(":");
            String opName = fields[0];
            CustomOpDescriptor descriptor = CustomOpDescriptor.builder()
                    .hash(Long.parseLong(fields[1]))
                    .numInputs(Integer.parseInt(fields[2]))
                    .numOutputs(Integer.parseInt(fields[3]))
                    .allowsInplace(Integer.parseInt(fields[4]) == 1)
                    .numTArgs(Integer.parseInt(fields[5]))
                    .numIArgs(Integer.parseInt(fields[6]))
                    .build();
            descriptors.put(opName, descriptor);
        }
        this.customOps = Collections.unmodifiableMap(descriptors);
        return this.customOps;
    }

    /**
     * Returns this thread's cached {@link PointerPointer} of the given size, creating it on first use.
     *
     * @param threadLocal per-thread cache of pointers keyed by size
     * @param i           number of slots the pointer must hold; also the cache key
     * @return the cached (or newly created) pointer for {@code i}
     */
    private PointerPointer getPointerPointerFrom(ThreadLocal<Map<Integer, PointerPointer>> threadLocal, int i) {
        Map<Integer, PointerPointer> bySize = threadLocal.get();
        if (bySize == null) {
            bySize = new HashMap<Integer, PointerPointer>();
            threadLocal.set(bySize);
        }
        PointerPointer cached = bySize.get(Integer.valueOf(i));
        if (cached == null) {
            cached = new PointerPointer(i);
            bySize.put(Integer.valueOf(i), cached);
        }
        return cached;
    }

    /**
     * Returns this thread's cached {@link ShortPointer} of the given size, creating it on first use.
     *
     * @param threadLocal per-thread cache of pointers keyed by size
     * @param i           number of shorts the pointer must hold; also the cache key
     * @return the cached (or newly created) pointer for {@code i}
     */
    private ShortPointer getShortPointerFrom(ThreadLocal<Map<Integer, ShortPointer>> threadLocal, int i) {
        Map<Integer, ShortPointer> bySize = threadLocal.get();
        if (bySize == null) {
            bySize = new HashMap<Integer, ShortPointer>();
            threadLocal.set(bySize);
        }
        ShortPointer cached = bySize.get(Integer.valueOf(i));
        if (cached == null) {
            cached = new ShortPointer(i);
            bySize.put(Integer.valueOf(i), cached);
        }
        return cached;
    }

    /**
     * Returns this thread's cached {@link DoublePointer} of the given size, creating it on first use.
     *
     * @param threadLocal per-thread cache of pointers keyed by size
     * @param i           number of doubles the pointer must hold; also the cache key
     * @return the cached (or newly created) pointer for {@code i}
     */
    private DoublePointer getDoublePointerFrom(ThreadLocal<Map<Integer, DoublePointer>> threadLocal, int i) {
        Map<Integer, DoublePointer> bySize = threadLocal.get();
        if (bySize == null) {
            bySize = new HashMap<Integer, DoublePointer>();
            threadLocal.set(bySize);
        }
        DoublePointer cached = bySize.get(Integer.valueOf(i));
        if (cached == null) {
            cached = new DoublePointer(i);
            bySize.put(Integer.valueOf(i), cached);
        }
        return cached;
    }

    /** Per-thread cached pointer array for {@code count} input shape-info pointers. */
    private PointerPointer getInputShapes(int count) {
        return getPointerPointerFrom(inputShapes, count);
    }

    /** Per-thread cached pointer array for {@code count} input data-buffer pointers. */
    private PointerPointer getInputBuffers(int count) {
        return getPointerPointerFrom(inputBuffers, count);
    }

    /** Per-thread cached pointer array for {@code count} output shape-info pointers. */
    private PointerPointer getOutputShapes(int count) {
        return getPointerPointerFrom(outputShapes, count);
    }

    /** Per-thread cached pointer array for {@code count} output data-buffer pointers. */
    private PointerPointer getOutputBuffers(int count) {
        return getPointerPointerFrom(outputBuffers, count);
    }

    /**
     * Executes a custom op through the native library, using the precision selected by
     * {@code Nd4j.dataType()}.
     *
     * <p>Input/output data and shape-info pointers are marshalled into per-thread cached pointer arrays.
     * Integer and floating-point op arguments are copied into native memory. The native status code is
     * checked for every precision; previously the DOUBLE path ignored it.
     *
     * @param customOp the op to execute; must have outputs unless it is an in-place call
     * @throws NullPointerException          if {@code customOp} or any input argument is null
     * @throws ND4JIllegalStateException     if outputs are missing/null or the native call reports failure
     * @throws UnsupportedOperationException if the global data type is not FLOAT, DOUBLE or HALF
     */
    public void exec(@NonNull CustomOp customOp) {
        if (customOp == null) {
            throw new NullPointerException("op");
        }
        if (customOp.numOutputArguments() == 0 && !customOp.isInplaceCall()) {
            throw new ND4JIllegalStateException("Op name " + customOp.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
        }
        DataBuffer.Type dataType = Nd4j.dataType();
        if (dataType != DataBuffer.Type.FLOAT && dataType != DataBuffer.Type.DOUBLE && dataType != DataBuffer.Type.HALF) {
            // Previously this case fell through and returned without executing anything.
            throw new UnsupportedOperationException("Custom op execution isn't supported for data type " + dataType);
        }
        long opHash = customOp.opHash();

        int numInputs = customOp.numInputArguments();
        PointerPointer inShapes = getInputShapes(numInputs);
        PointerPointer inBuffers = getInputBuffers(numInputs);
        int k = 0;
        for (INDArray input : customOp.inputArguments()) {
            if (input == null) {
                throw new NullPointerException("Input argument is null");
            }
            inBuffers.put(k, input.data().addressPointer());
            inShapes.put(k, input.shapeInfoDataBuffer().addressPointer());
            k++;
        }

        INDArray[] outputs = customOp.outputArguments();
        for (INDArray output : outputs) {
            if (output == null) {
                throw new ND4JIllegalStateException("Op output arguments must not be null!");
            }
        }
        int numOutputs = customOp.numOutputArguments();
        PointerPointer outShapes = getOutputShapes(numOutputs);
        PointerPointer outBuffers = getOutputBuffers(numOutputs);
        for (int o = 0; o < outputs.length; o++) {
            outBuffers.put(o, outputs[o].data().addressPointer());
            outShapes.put(o, outputs[o].shapeInfoDataBuffer().addressPointer());
        }

        int numIArgs = customOp.numIArguments();
        IntPointer iArgs = numIArgs > 0 ? new IntPointer(numIArgs) : null;
        int a = 0;
        for (int value : customOp.iArgs()) {
            iArgs.put(a++, value);
        }

        int numTArgs = customOp.numTArguments();
        boolean inplace = customOp.isInplaceCall();
        OpStatus status;
        try {
            if (dataType == DataBuffer.Type.FLOAT) {
                FloatPointer tArgs = numTArgs > 0 ? new FloatPointer(numTArgs) : null;
                int t = 0;
                for (double d : customOp.tArgs()) {
                    tArgs.put(t++, (float) d);
                }
                status = OpStatus.byNumber(this.loop.execCustomOpFloat((PointerPointer) null, opHash, inBuffers, inShapes, numInputs, outBuffers, outShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, inplace));
            } else if (dataType == DataBuffer.Type.DOUBLE) {
                DoublePointer tArgs = numTArgs > 0 ? getDoublePointerFrom(this.tArgsPointer, numTArgs) : null;
                int t = 0;
                for (double d : customOp.tArgs()) {
                    tArgs.put(t++, d);
                }
                status = OpStatus.byNumber(this.loop.execCustomOpDouble((PointerPointer) null, opHash, inBuffers, inShapes, numInputs, outBuffers, outShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, inplace));
            } else {
                // HALF: arguments are passed as IEEE half bit patterns.
                ShortPointer tArgs = numTArgs > 0 ? getShortPointerFrom(this.halfArgsPointer, numTArgs) : null;
                int t = 0;
                for (double d : customOp.tArgs()) {
                    tArgs.put(t++, ArrayUtil.toHalf(d));
                }
                status = OpStatus.byNumber(this.loop.execCustomOpHalf((PointerPointer) null, opHash, inBuffers, inShapes, numInputs, outBuffers, outShapes, numOutputs, tArgs, numTArgs, iArgs, numIArgs, inplace));
            }
        } catch (RuntimeException e) {
            log.error("Failed to execute. Please see above message (printed out from c++) for a possible cause of error.");
            throw e;
        }
        if (status != OpStatus.ND4J_STATUS_OK) {
            throw new ND4JIllegalStateException("Op execution failed: " + status);
        }
    }

    /**
     * Reads a length-prefixed int array from native memory.
     *
     * <p>Element 0 holds the number of values that follow; the returned array contains elements
     * {@code 1..n}.
     *
     * @param shapeInfo native int buffer whose first element is the count
     * @return the {@code n} ints following the count
     */
    protected int[] getShapeFromPointer(IntPointer shapeInfo) {
        int rank = shapeInfo.get(0L);
        int[] shape = new int[rank];
        for (int pos = 1; pos <= rank; pos++) {
            shape[pos - 1] = shapeInfo.get(pos);
        }
        return shape;
    }

    /**
     * Asks the native library for the output shapes a custom op would produce for its current
     * inputs and arguments.
     *
     * <p>Uses the precision selected by {@code Nd4j.dataType()}. The native shape list is always freed,
     * even if reading it fails. The previous HALF branch also performed a dead lookup in
     * {@link #getCustomOperations()} that could throw for op names not present there; that lookup is
     * gone.
     *
     * @param customOp the op whose output shapes are requested
     * @return one shape per native result; an empty list if the op has no inputs
     * @throws NullPointerException          if {@code customOp} or any input argument is null
     * @throws RuntimeException              if the native call returns no shape list
     * @throws UnsupportedOperationException if the global data type is not FLOAT, DOUBLE or HALF
     */
    public List<int[]> calculateOutputShape(@NonNull CustomOp customOp) {
        if (customOp == null) {
            throw new NullPointerException("op");
        }
        long opHash = customOp.opHash();
        int numInputs = customOp.numInputArguments();
        if (numInputs < 1) {
            return Collections.emptyList();
        }

        PointerPointer bufferPtrs = new PointerPointer(numInputs);
        PointerPointer shapePtrs = new PointerPointer(numInputs);
        int k = 0;
        for (INDArray input : customOp.inputArguments()) {
            if (input == null) {
                throw new NullPointerException("Input argument is null");
            }
            bufferPtrs.put(k, input.data().addressPointer());
            shapePtrs.put(k, input.shapeInfoDataBuffer().addressPointer());
            k++;
        }

        int numIArgs = customOp.numIArguments();
        IntPointer iArgs = numIArgs > 0 ? new IntPointer(numIArgs) : null;
        int a = 0;
        for (int value : customOp.iArgs()) {
            iArgs.put(a++, value);
        }

        int numTArgs = customOp.numTArguments();
        DataBuffer.Type dataType = Nd4j.dataType();
        Nd4jCpu.ShapeList shapeList;
        if (dataType == DataBuffer.Type.FLOAT) {
            FloatPointer tArgs = numTArgs > 0 ? new FloatPointer(numTArgs) : null;
            int t = 0;
            for (double d : customOp.tArgs()) {
                tArgs.put(t++, (float) d);
            }
            shapeList = (Nd4jCpu.ShapeList) this.loop.calculateOutputShapesFloat((PointerPointer) null, opHash, bufferPtrs, shapePtrs, numInputs, tArgs, numTArgs, iArgs, numIArgs);
        } else if (dataType == DataBuffer.Type.DOUBLE) {
            DoublePointer tArgs = numTArgs > 0 ? new DoublePointer(numTArgs) : null;
            int t = 0;
            for (double d : customOp.tArgs()) {
                tArgs.put(t++, d);
            }
            shapeList = (Nd4jCpu.ShapeList) this.loop.calculateOutputShapesDouble((PointerPointer) null, opHash, bufferPtrs, shapePtrs, numInputs, tArgs, numTArgs, iArgs, numIArgs);
        } else if (dataType == DataBuffer.Type.HALF) {
            ShortPointer tArgs = numTArgs > 0 ? new ShortPointer(numTArgs) : null;
            int t = 0;
            for (double d : customOp.tArgs()) {
                tArgs.put(t++, ArrayUtil.toHalf(d));
            }
            shapeList = (Nd4jCpu.ShapeList) this.loop.calculateOutputShapesHalf((PointerPointer) null, opHash, bufferPtrs, shapePtrs, numInputs, tArgs, numTArgs, iArgs, numIArgs);
        } else {
            throw new UnsupportedOperationException("Output shape calculation isn't supported for data type " + dataType);
        }

        if (shapeList == null) {
            throw new RuntimeException("Failed to calculate output shapes for op " + customOp.opName());
        }
        List<int[]> shapes = new ArrayList<int[]>();
        try {
            for (int s = 0; s < shapeList.size(); s++) {
                shapes.add(getShapeFromPointer(new PagedPointer(shapeList.at(s)).asIntPointer()));
            }
        } finally {
            // Native allocation: release it even if reading a shape throws.
            this.loop.deleteShapeList(shapeList);
        }
        return shapes;
    }

    /** Forwards the debug-mode flag to the native library. */
    public void enableDebugMode(boolean enabled) {
        loop.enableDebugMode(enabled);
    }

    /** Forwards the verbose-mode flag to the native library. */
    public void enableVerboseMode(boolean enabled) {
        loop.enableVerboseMode(enabled);
    }

    /**
     * Registers a serialized graph with the native library under the given id, using the precision
     * selected by {@code Nd4j.dataType()}.
     *
     * @param j       id under which the graph is registered
     * @param pointer pointer to the graph definition handed to native code
     * @throws UnsupportedOperationException if the global data type is not FLOAT, DOUBLE or HALF
     *         (previously the graph was silently left unregistered)
     */
    public void registerGraph(long j, Pointer pointer) {
        DataBuffer.Type dataType = Nd4j.dataType();
        if (dataType == DataBuffer.Type.FLOAT) {
            this.loop.registerGraphFloat((PointerPointer) null, j, pointer);
        } else if (dataType == DataBuffer.Type.DOUBLE) {
            this.loop.registerGraphDouble((PointerPointer) null, j, pointer);
        } else if (dataType == DataBuffer.Type.HALF) {
            this.loop.registerGraphHalf((PointerPointer) null, j, pointer);
        } else {
            throw new UnsupportedOperationException("Graph registration isn't supported for data type " + dataType);
        }
    }

    /**
     * Executes a graph previously registered with {@link #registerGraph(long, Pointer)}.
     *
     * <p>Each input is passed to native code together with its position in the key order of
     * {@code map}. Each resulting native variable is copied into a new INDArray. The result is stored
     * under the input key whose position equals the variable's {@code id()}. The native result set is
     * always freed, including when execution reports a failure.
     *
     * @param j   id of the registered graph
     * @param map named input arrays
     * @return results in the order returned by native code, keyed as described above
     * @throws ND4JIllegalStateException     if native execution returns nothing or reports a failure
     * @throws UnsupportedOperationException if the global data type is not FLOAT, DOUBLE or HALF
     */
    public Map<String, INDArray> executeGraph(long j, Map<String, INDArray> map) {
        int numInputs = map.size();
        PointerPointer bufferPtrs = new PointerPointer(numInputs);
        PointerPointer shapePtrs = new PointerPointer(numInputs);
        IntPointer indexPtr = new IntPointer(numInputs);
        List<String> keys = new ArrayList<String>(map.keySet());
        for (int k = 0; k < keys.size(); k++) {
            INDArray array = map.get(keys.get(k));
            bufferPtrs.put(k, array.data().addressPointer());
            shapePtrs.put(k, array.shapeInfoDataBuffer().addressPointer());
            indexPtr.put(k, k);
        }

        Map<String, INDArray> results = new LinkedHashMap<String, INDArray>();
        DataBuffer.Type dataType = Nd4j.dataType();
        if (dataType == DataBuffer.Type.FLOAT) {
            Nd4jCpu.FloatVariablesSet set = (Nd4jCpu.FloatVariablesSet) this.loop.executeStoredGraphFloat((PointerPointer) null, j, bufferPtrs, shapePtrs, indexPtr, numInputs);
            if (set == null) {
                throw new ND4JIllegalStateException("Graph execution returned no results");
            }
            try {
                OpStatus status = OpStatus.byNumber(set.status());
                if (status != OpStatus.ND4J_STATUS_OK) {
                    throw new ND4JIllegalStateException("Op execution failed: " + status);
                }
                for (int v = 0; v < set.size(); v++) {
                    Nd4jCpu.FloatVariable variable = set.at(v);
                    results.put(keys.get(variable.id()), createArrayFromNative(variable.getNDArray().shapeInfo(), variable.getNDArray().buffer()));
                }
            } finally {
                this.loop.deleteVariablesSetFloat(set);
            }
        } else if (dataType == DataBuffer.Type.DOUBLE) {
            Nd4jCpu.DoubleVariablesSet set = (Nd4jCpu.DoubleVariablesSet) this.loop.executeStoredGraphDouble((PointerPointer) null, j, bufferPtrs, shapePtrs, indexPtr, numInputs);
            if (set == null) {
                throw new ND4JIllegalStateException("Graph execution returned no results");
            }
            try {
                OpStatus status = OpStatus.byNumber(set.status());
                if (status != OpStatus.ND4J_STATUS_OK) {
                    throw new ND4JIllegalStateException("Op execution failed: " + status);
                }
                for (int v = 0; v < set.size(); v++) {
                    Nd4jCpu.DoubleVariable variable = set.at(v);
                    results.put(keys.get(variable.id()), createArrayFromNative(variable.getNDArray().shapeInfo(), variable.getNDArray().buffer()));
                }
            } finally {
                this.loop.deleteVariablesSetDouble(set);
            }
        } else if (dataType == DataBuffer.Type.HALF) {
            // The HALF entry point yields the half-precision set; it was previously cast to DoubleVariablesSet.
            Nd4jCpu.HalfVariablesSet set = (Nd4jCpu.HalfVariablesSet) this.loop.executeStoredGraphHalf((PointerPointer) null, j, bufferPtrs, shapePtrs, indexPtr, numInputs);
            if (set == null) {
                throw new ND4JIllegalStateException("Graph execution returned no results");
            }
            try {
                OpStatus status = OpStatus.byNumber(set.status());
                if (status != OpStatus.ND4J_STATUS_OK) {
                    throw new ND4JIllegalStateException("Op execution failed: " + status);
                }
                for (int v = 0; v < set.size(); v++) {
                    Nd4jCpu.HalfVariable variable = set.at(v);
                    results.put(keys.get(variable.id()), createArrayFromNative(variable.getNDArray().shapeInfo(), variable.getNDArray().buffer()));
                }
            } finally {
                this.loop.deleteVariablesSetHalf(set);
            }
        } else {
            throw new UnsupportedOperationException("Graph execution isn't supported for data type " + dataType);
        }
        return results;
    }

    /**
     * Copies a native array into a newly allocated INDArray. The native array is described by its
     * shape-info buffer ({@code 2 * rank + 4} ints, rank first) and its data buffer.
     */
    private static INDArray createArrayFromNative(IntPointer shapeInfo, Pointer buffer) {
        int rank = shapeInfo.get(0L);
        int[] javaShapeInfo = new int[rank * 2 + 4];
        for (int k = 0; k < javaShapeInfo.length; k++) {
            javaShapeInfo[k] = shapeInfo.get(k);
        }
        int[] shape = Shape.shapeOf(javaShapeInfo);
        INDArray array = Nd4j.create(shape, Shape.stridesOf(javaShapeInfo), 0L, Shape.order(javaShapeInfo));
        // Byte count computed in long so large arrays don't overflow int arithmetic.
        long bytes = (long) ArrayUtil.prod(shape) * Nd4j.sizeOfDataType();
        Pointer.memcpy(array.data().addressPointer(), buffer, bytes);
        return array;
    }

    /** Unregisters the graph with the given id from the native library. */
    public void forgetGraph(long graphId) {
        PointerPointer noExtras = null;
        loop.unregisterGraph(noExtras, graphId);
    }

    /** Forwards the element threshold to the native library's {@code setElementThreshold}. */
    public void setElementsThreshold(int threshold) {
        loop.setElementThreshold(threshold);
    }

    /** Forwards the TAD threshold to the native library's {@code setTADThreshold}. */
    public void setTadThreshold(int threshold) {
        loop.setTADThreshold(threshold);
    }
}
