package org.nd4j.linalg.api.ops;

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/nd4j/linalg/api/ops/BaseIndexAccumulation.class */
/**
 * Common base for {@link IndexAccumulation} ops.
 *
 * <p>Supplies default "zero" starting values and seeds {@code extraArgs} with the zero value that
 * matches the global {@link Nd4j#dataType()}. Partial results are merged through
 * {@code update(...)}, and the integer result recorded via {@link #setFinalResult(int)} is stored.
 */
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
    /** Value stored by {@link #setFinalResult(int)}; presumably the selected element index. */
    protected int finalResult;

    /** Creates an op without operands; this class performs no initialisation of its own here. */
    public BaseIndexAccumulation() {
    }

    /**
     * Creates the op over the given operands. It then re-runs {@link #init} so that
     * {@code extraArgs} is seeded with the data-type-specific zero value.
     *
     * @param x input array
     * @param y second input; may be {@code null} (the single-input constructor passes null)
     * @param z presumably the result array
     * @param n presumably the number of elements to process
     */
    public BaseIndexAccumulation(INDArray x, INDArray y, INDArray z, long n) {
        super(x, y, z, n);
        // NOTE(review): this second init call passes z = this.x and n = this.x.lengthLong().
        // That seems to override the z and n given above; whether it does depends on
        // BaseOp.init. Confirm this is intended. Also note that init(...) is overridable and
        // runs here during construction.
        init(this.x, this.y, this.x, this.x.lengthLong());
    }

    /**
     * Creates the op with {@code x} doubling as the result array.
     *
     * @param x input array (also passed as z)
     * @param y second input
     * @param n presumably the number of elements to process
     */
    public BaseIndexAccumulation(INDArray x, INDArray y, long n) {
        this(x, y, x, n);
    }

    /**
     * Creates a single-input op over all of {@code x}.
     *
     * @param x input array (also passed as z); its length is used as n
     */
    public BaseIndexAccumulation(INDArray x) {
        this(x, null, x, x.lengthLong());
    }

    /**
     * Creates a two-input op over all of {@code x}.
     *
     * @param x input array (also passed as z); its length is used as n
     * @param y second input
     */
    public BaseIndexAccumulation(INDArray x, INDArray y) {
        this(x, y, x, x.lengthLong());
    }

    /** Returns the default double starting value, {@code 0.0}. */
    @Override
    public double zeroDouble() {
        return 0.0d;
    }

    /** Returns the default float starting value, {@code 0.0f}. */
    @Override
    public float zeroFloat() {
        return 0.0f;
    }

    /**
     * Returns the starting (value, index) pair: {@link #zeroDouble()} paired with index -1.
     * The -1 presumably means "no index selected yet".
     */
    @Override
    public Pair<Double, Integer> zeroPair() {
        final Double startValue = Double.valueOf(zeroDouble());
        final Integer startIndex = -1;
        return new Pair<>(startValue, startIndex);
    }

    /** Returns a complex zero built from two boxed {@code 0.0} components. */
    @Override
    public IComplexNumber zeroComplex() {
        final Double zero = Double.valueOf(0.0d);
        return Nd4j.createComplexNumber(zero, zero);
    }

    /**
     * Delegates to {@link BaseOp}'s init, then stores the zero value for the current Nd4j data
     * type as the sole extra argument. If the data type is not DOUBLE, FLOAT or HALF,
     * {@code extraArgs} is left as it was.
     */
    @Override
    public void init(INDArray x, INDArray y, INDArray z, long n) {
        super.init(x, y, z, n);
        final Object seed = zeroForDataType(Nd4j.dataType());
        if (seed != null) {
            this.extraArgs = new Object[] {seed};
        }
    }

    /** Boxed zero value for {@code type}, or {@code null} if this class has none for it. */
    private Object zeroForDataType(DataBuffer.Type type) {
        if (type == DataBuffer.Type.DOUBLE) {
            return Double.valueOf(zeroDouble());
        }
        if (type == DataBuffer.Type.FLOAT) {
            return Float.valueOf(zeroFloat());
        }
        if (type == DataBuffer.Type.HALF) {
            return Float.valueOf(zeroHalf());
        }
        return null;
    }

    /** Merges two double partial results by delegating to {@code update}. */
    @Override
    public int combineSubResults(double first, int firstIdx, double second, int secondIdx) {
        return update(first, firstIdx, second, secondIdx);
    }

    /** Merges two float partial results by delegating to {@code update}. */
    @Override
    public int combineSubResults(float first, int firstIdx, float second, int secondIdx) {
        return update(first, firstIdx, second, secondIdx);
    }

    /**
     * Merges two (value, index) partial results. The pair whose index {@code update} returns
     * wins; {@code first} is returned when its index matches, otherwise {@code second}.
     */
    @Override
    public Pair<Double, Integer> combineSubResults(Pair<Double, Integer> first, Pair<Double, Integer> second) {
        final int firstIndex = first.getSecond();
        final double firstValue = first.getFirst();
        final double secondValue = second.getFirst();
        final int secondIndex = second.getSecond();
        final int chosen = update(firstValue, firstIndex, secondValue, secondIndex);
        if (chosen == firstIndex) {
            return first;
        }
        return second;
    }

    /** Stores the final result value. */
    @Override
    public void setFinalResult(int result) {
        this.finalResult = result;
    }

    /** Returns the value last stored by {@link #setFinalResult(int)}. */
    @Override
    public int getFinalResult() {
        return this.finalResult;
    }
}
