package org.datavec.api.writable;

import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import lombok.NonNull;
import org.datavec.api.io.WritableComparable;
import org.datavec.api.util.ndarray.DataInputWrapperStream;
import org.datavec.api.util.ndarray.DataOutputWrapperStream;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.MathUtils;

/* loaded from: input_file:org/datavec/api/writable/NDArrayWritable.class */
/**
 * A {@link Writable} that wraps a single ND4J {@link INDArray} (which may be {@code null}).
 *
 * <p>Serialized form: one header byte, either {@link #NDARRAY_SER_VERSION_HEADER_NULL} (no array
 * follows) or {@link #NDARRAY_SER_VERSION_HEADER}. In the second case, the output of
 * {@code Nd4j.write} for the array follows the header byte.
 *
 * <p>The hash code is computed lazily and cached; it is invalidated whenever the wrapped array is
 * replaced through {@link #set(INDArray)} or {@link #readFields(DataInput)}. Mutating the wrapped
 * array in place does NOT invalidate the cached hash.
 */
public class NDArrayWritable extends ArrayWritable implements WritableComparable {
    /** Header byte written when the wrapped array is {@code null}. */
    public static final byte NDARRAY_SER_VERSION_HEADER_NULL = 0;
    /** Header byte written when a serialized array follows. */
    public static final byte NDARRAY_SER_VERSION_HEADER = 1;

    private INDArray array = null;
    /** Lazily computed hash of {@link #array}; {@code null} means "not yet computed". */
    private Integer hash = null;

    /** Creates a writable with no array; call {@link #set} or {@link #readFields} to populate it. */
    public NDArrayWritable() {
    }

    /**
     * Creates a writable wrapping the given array.
     *
     * @param array the array to wrap (may be {@code null})
     */
    public NDArrayWritable(INDArray array) {
        set(array);
    }

    /**
     * Replaces the wrapped array with one read from {@code in}.
     *
     * @param in source positioned at a record previously produced by {@link #write(DataOutput)}
     * @throws IOException if reading from {@code in} fails
     * @throws IllegalStateException if the header byte is neither known header value
     */
    @Override // org.datavec.api.writable.Writable
    public void readFields(DataInput in) throws IOException {
        DataInputStream dis = new DataInputStream(new DataInputWrapperStream(in));
        byte header = dis.readByte();
        if (header != NDARRAY_SER_VERSION_HEADER && header != NDARRAY_SER_VERSION_HEADER_NULL) {
            throw new IllegalStateException("Unexpected NDArrayWritable version header - stream corrupt?");
        }
        // set() also clears the cached hash. This matters for the null branch too; the previous
        // implementation left a stale hash behind when it read a null record.
        set(header == NDARRAY_SER_VERSION_HEADER_NULL ? null : Nd4j.read(dis));
    }

    /** Writes this writable's type index as a {@code short}. */
    @Override // org.datavec.api.writable.Writable
    public void writeType(DataOutput out) throws IOException {
        out.writeShort(WritableType.NDArray.typeIdx());
    }

    @Override // org.datavec.api.writable.Writable
    public WritableType getType() {
        return WritableType.NDArray;
    }

    /**
     * Serializes the wrapped array: a header byte, followed by the array if it is non-null.
     *
     * @param out destination
     * @throws IOException if writing fails
     */
    @Override // org.datavec.api.writable.Writable
    public void write(DataOutput out) throws IOException {
        if (array == null) {
            out.write(NDARRAY_SER_VERSION_HEADER_NULL);
            return;
        }
        // Views are duplicated before writing. Presumably this makes Nd4j.write see a standalone
        // array rather than a view over a larger parent buffer -- confirm against Nd4j.write.
        INDArray toWrite = array.isView() ? array.dup() : array;
        out.write(NDARRAY_SER_VERSION_HEADER);
        Nd4j.write(toWrite, new DataOutputStream(new DataOutputWrapperStream(out)));
    }

    /**
     * Replaces the wrapped array and invalidates the cached hash.
     *
     * @param array the new array (may be {@code null})
     */
    public void set(INDArray array) {
        this.array = array;
        this.hash = null;
    }

    /** @return the wrapped array, possibly {@code null} */
    public INDArray get() {
        return array;
    }

    /**
     * Two instances are equal when both wrap {@code null}, or when both arrays are non-null and
     * {@code equalsWithEps(other, 0.0)} holds.
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof NDArrayWritable)) {
            return false;
        }
        INDArray other = ((NDArrayWritable) obj).get();
        if (array == null || other == null) {
            // Equal only when both are null.
            return array == other;
        }
        return array.equalsWithEps(other, 0.0d);
    }

    /**
     * Combines the shape hash with the XOR of per-element hashes (visited in 'c' order). The result
     * is cached until the array is replaced.
     */
    @Override
    public int hashCode() {
        if (hash != null) {
            return hash;
        }
        if (array == null) {
            hash = 0;
            return 0;
        }
        long[] shape = array.shape();
        int h = Arrays.hashCode(shape);
        long length = array.length();
        NdIndexIterator it = new NdIndexIterator('c', shape);
        // Use a long counter: array.length() is a long, and an int counter would overflow.
        for (long i = 0; i < length; i++) {
            double v = array.getDouble(it.next());
            // Normalize -0.0 to 0.0 (note -0.0 == 0.0 is true). Assuming equalsWithEps(..., 0.0)
            // compares by absolute difference, arrays differing only in the sign of zero are
            // equal, so they must hash equally.
            h ^= MathUtils.hashCode(v == 0.0 ? 0.0 : v);
        }
        hash = h;
        return h;
    }

    /**
     * Orders writables in this sequence: a {@code null} array sorts first; then arrays are
     * compared by rank, then by length, then dimension by dimension on shape, and finally element
     * by element in 'c' order using {@link Double#compare}.
     *
     * @param o another {@code NDArrayWritable}
     * @throws NullPointerException if {@code o} is null
     * @throws ClassCastException if {@code o} is not an {@code NDArrayWritable}
     */
    @Override // java.lang.Comparable
    public int compareTo(@NonNull Object o) {
        // Explicit check so the contract holds even if Lombok annotation processing is off.
        if (o == null) {
            throw new NullPointerException("o is marked non-null but is null");
        }
        INDArray other = ((NDArrayWritable) o).array;
        if (array == null) {
            return other == null ? 0 : -1;
        }
        if (other == null) {
            return 1;
        }
        if (array.rank() != other.rank()) {
            return Integer.compare(array.rank(), other.rank());
        }
        if (array.length() != other.length()) {
            return Long.compare(array.length(), other.length());
        }
        for (int d = 0; d < array.rank(); d++) {
            int cmp = Long.compare(array.size(d), other.size(d));
            if (cmp != 0) {
                return cmp;
            }
        }
        NdIndexIterator it = new NdIndexIterator('c', array.shape());
        while (it.hasNext()) {
            long[] idx = it.next();
            int cmp = Double.compare(array.getDouble(idx), other.getDouble(idx));
            if (cmp != 0) {
                return cmp;
            }
        }
        return 0;
    }

    /** @return the array's string form, or {@code "null"} when no array is wrapped */
    @Override
    public String toString() {
        return String.valueOf(array);
    }

    /**
     * @return the length of the wrapped array's backing data buffer
     *     ({@code array.data().length()}). NOTE(review): for a view this may differ from
     *     {@code array.length()} -- confirm which callers expect.
     */
    @Override // org.datavec.api.writable.ArrayWritable
    public long length() {
        return array.data().length();
    }

    /** Reads element {@code j} of the backing data buffer as a {@code double}. */
    @Override // org.datavec.api.writable.ArrayWritable
    public double getDouble(long j) {
        return array.data().getDouble(j);
    }

    /** Reads element {@code j} of the backing data buffer as a {@code float}. */
    @Override // org.datavec.api.writable.ArrayWritable
    public float getFloat(long j) {
        return array.data().getFloat(j);
    }

    /** Reads element {@code j} of the backing data buffer as an {@code int}. */
    @Override // org.datavec.api.writable.ArrayWritable
    public int getInt(long j) {
        return array.data().getInt(j);
    }

    /**
     * Reads element {@code j} of the backing data buffer as a {@code long}. This reads the buffer
     * directly rather than going through {@code double}, so long-typed data above 2^53 keeps its
     * precision.
     */
    @Override // org.datavec.api.writable.ArrayWritable
    public long getLong(long j) {
        return array.data().getLong(j);
    }
}
