package org.datavec.api.util.ndarray;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.BooleanWritable;
import org.datavec.api.writable.ByteWritable;
import org.datavec.api.writable.BytesWritable;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.guava.base.Preconditions;

/**
 * Static helpers for converting between DataVec {@link Writable} records and ND4J {@link INDArray}s or
 * {@link DataSet}s. This class only holds static methods and cannot be instantiated.
 */
public class RecordConverter {
    private RecordConverter() {
    }

    /**
     * Converts a record to an array.
     *
     * @deprecated the {@code i} argument is ignored; call {@link #toArray(Collection)} instead.
     */
    @Deprecated
    public static INDArray toArray(Collection<Writable> collection, int i) {
        return toArray(collection);
    }

    /**
     * Converts every row of {@code matrix} into a single-element record wrapping that row in an
     * {@link NDArrayWritable}.
     *
     * <p>NOTE(review): unlike the {@link DataSet} conversion below, this uses {@code getRow(i)} without keepDim, so
     * the stored row shape may depend on the row count. Left unchanged for compatibility.
     *
     * @param matrix array whose rows are converted
     * @return one record per row, in row order
     */
    public static List<List<Writable>> toRecords(INDArray matrix) {
        List<List<Writable>> records = new ArrayList<>(matrix.rows());
        for (int i = 0; i < matrix.rows(); i++) {
            records.add(toRecord(matrix.getRow(i)));
        }
        return records;
    }

    /**
     * Converts a batch of sequences using {@link TimeSeriesWritableUtils#convertWritablesSequence}. Only the first
     * element of the returned pair is handed back (presumably the data array rather than a mask; verify against
     * {@code TimeSeriesWritableUtils}).
     */
    public static INDArray toTensor(List<List<List<Writable>>> list) {
        return (INDArray) TimeSeriesWritableUtils.convertWritablesSequence(list).getFirst();
    }

    /**
     * Converts each record with {@link #toArray(Collection)} and stacks the results vertically.
     *
     * @param list records to convert, one per output row
     * @return the vertically stacked arrays
     */
    public static INDArray toMatrix(List<List<Writable>> list) {
        List<INDArray> rows = new ArrayList<>(list.size());
        for (List<Writable> record : list) {
            rows.add(toArray(record));
        }
        return Nd4j.vstack(rows);
    }

    /**
     * Flattens a record into a single {@code [1, length]} array.
     *
     * <p>If the record holds exactly one {@link NDArrayWritable}, that array is returned as-is, whatever its shape.
     * Otherwise each {@link NDArrayWritable} must be a row vector; its values are copied in, in order. Every other
     * writable contributes one value via {@link Writable#toDouble()}.
     *
     * @param collection the writables to flatten, in iteration order
     * @return the flattened array
     * @throws UnsupportedOperationException if several writables are present and an {@link NDArrayWritable} is not
     *                                       a row vector
     */
    public static INDArray toArray(Collection<? extends Writable> collection) {
        // Edge case: a lone NDArrayWritable is passed through unchanged (no copy, no reshape).
        if (collection.size() == 1) {
            Writable only = collection.iterator().next();
            if (only instanceof NDArrayWritable) {
                return ((NDArrayWritable) only).get();
            }
        }

        // First pass: validate NDArray pieces and compute the total output length.
        int length = 0;
        for (Writable w : collection) {
            if (w instanceof NDArrayWritable) {
                INDArray arr = ((NDArrayWritable) w).get();
                if (!arr.isRowVector()) {
                    throw new UnsupportedOperationException("Multiple writables present but NDArrayWritable is not a row vector. Can only concat row vectors with other writables. Shape: " + Arrays.toString(arr.shape()));
                }
                length += arr.length();
            } else {
                // Any non-NDArray writable is treated as a single value.
                length++;
            }
        }

        // Second pass: copy values into the output row at a running offset.
        INDArray out = Nd4j.create(1, length);
        int offset = 0;
        for (Writable w : collection) {
            if (w instanceof NDArrayWritable) {
                INDArray piece = ((NDArrayWritable) w).get();
                out.put(new INDArrayIndex[]{NDArrayIndex.point(0L), NDArrayIndex.interval(offset, offset + piece.length())}, piece);
                offset += piece.length();
            } else {
                out.putScalar(0L, offset, w.toDouble());
                offset++;
            }
        }
        return out;
    }

    /**
     * Converts a list of writables into a minibatch array.
     *
     * <p>A single {@link NDArrayWritable} is returned unchanged. Otherwise the list must contain either only
     * {@link NDArrayWritable}s, which must have {@code size(0) == 1} and are concatenated along dimension 0, or only
     * non-NDArray writables, which become an {@code [n, 1]} FLOAT array of their {@link Writable#toDouble()} values.
     *
     * @param l the writables; must be non-null and non-empty
     * @return the minibatch array
     * @throws NullPointerException          if {@code l} is null
     * @throws IllegalArgumentException      if {@code l} is empty
     * @throws UnsupportedOperationException if an {@link NDArrayWritable} has leading dimension other than 1
     * @throws IllegalStateException         if NDArray and single-value writables are mixed
     */
    public static INDArray toMinibatchArray(@NonNull List<? extends Writable> l) {
        if (l == null) {
            throw new NullPointerException("l is marked @NonNull but is null");
        }
        Preconditions.checkArgument(l.size() > 0, "Cannot convert empty list");
        if (l.size() == 1 && (l.get(0) instanceof NDArrayWritable)) {
            return ((NDArrayWritable) l.get(0)).get();
        }

        // Only one of these ends up non-null for a valid input (checked after the loop).
        List<INDArray> toConcat = null;
        DoubleArrayList values = null;
        for (Writable w : l) {
            if (w instanceof NDArrayWritable) {
                INDArray arr = ((NDArrayWritable) w).get();
                if (arr.size(0) != 1) {
                    throw new UnsupportedOperationException("NDArrayWritable must have leading dimension 1 for this method. Received array with shape: " + Arrays.toString(arr.shape()));
                }
                if (toConcat == null) {
                    toConcat = new ArrayList<>();
                }
                toConcat.add(arr);
            } else {
                if (values == null) {
                    values = new DoubleArrayList();
                }
                values.add(w.toDouble());
            }
        }

        if (toConcat != null && values != null) {
            throw new IllegalStateException("Error converting writables: found both NDArrayWritable and single value (DoubleWritable etc) in the one list. All writables must be NDArrayWritables or single value writables only for this method");
        }
        if (toConcat != null) {
            return Nd4j.concat(0, toConcat.toArray(new INDArray[toConcat.size()]));
        }
        // The list is non-empty, so values is non-null here.
        return Nd4j.create(values.toArray(new double[values.size()]), new long[]{values.size(), 1}, DataType.FLOAT);
    }

    /**
     * Wraps {@code array} in a new, mutable, single-element record. The {@link DataSet} conversions below append
     * label writables to it.
     */
    public static List<Writable> toRecord(INDArray array) {
        List<Writable> record = new ArrayList<>();
        record.add(new NDArrayWritable(array));
        return record;
    }

    /**
     * Converts plain Java values into writables according to {@code schema}.
     *
     * @param schema column metadata; its column count must equal {@code list.size()}
     * @param list   one value per column, in schema order
     * @return one writable per column
     * @throws IllegalArgumentException if the sizes differ, a value fails {@link ColumnMetaData#isValid}, or a
     *                                  value cannot be converted to the column's writable type
     */
    public static List<Writable> toRecord(Schema schema, List<Object> list) {
        List<Writable> record = new ArrayList<>(list.size());
        List<ColumnMetaData> columns = schema.getColumnMetaData();
        if (columns.size() != list.size()) {
            throw new IllegalArgumentException("Schema and source list don't have the same length!");
        }
        for (int i = 0; i < columns.size(); i++) {
            ColumnMetaData meta = columns.get(i);
            Object value = list.get(i);
            if (!meta.isValid(value)) {
                throw new IllegalArgumentException("Element " + i + ": " + value + " is not valid for Column \"" + meta.getName() + "\" (" + meta.getColumnType() + ")");
            }
            try {
                record.add(toWritable(i, value, meta));
            } catch (ClassCastException e) {
                // The value passed isValid but is not the Java type the writable constructor needs.
                throw new IllegalArgumentException(notUsableMessage(i, value, meta), e);
            }
        }
        return record;
    }

    /** Builds the writable for one column value; may throw ClassCastException on a type mismatch. */
    private static Writable toWritable(int i, Object value, ColumnMetaData meta) {
        switch (meta.getColumnType().getWritableType()) {
            case Float:
                return new FloatWritable(((Float) value).floatValue());
            case Double:
                return new DoubleWritable(((Double) value).doubleValue());
            case Int:
                return new IntWritable(((Integer) value).intValue());
            case Byte:
                return new ByteWritable(((Byte) value).byteValue());
            case Boolean:
                return new BooleanWritable(((Boolean) value).booleanValue());
            case Long:
                return new LongWritable(((Long) value).longValue());
            case Null:
                return new NullWritable();
            case Bytes:
                return new BytesWritable((byte[]) value);
            case NDArray:
                return new NDArrayWritable((INDArray) value);
            case Text:
                if (value instanceof String) {
                    return new Text((String) value);
                }
                if (value instanceof Text) {
                    return new Text((Text) value);
                }
                if (value instanceof byte[]) {
                    return new Text((byte[]) value);
                }
                throw new IllegalArgumentException(notUsableMessage(i, value, meta));
            default:
                throw new IllegalArgumentException(notUsableMessage(i, value, meta));
        }
    }

    /** Error message used when a value cannot be turned into its column's writable. */
    private static String notUsableMessage(int i, Object value, ColumnMetaData meta) {
        return "Element " + i + ": " + value + " is not usable for Column \"" + meta.getName() + "\" (" + meta.getColumnType() + ")";
    }

    /**
     * Converts a {@link DataSet} into one record per example. Each record starts with the example's features as an
     * {@link NDArrayWritable}.
     *
     * <p>If the labels look one-hot (see {@link #isClassificationDataSet}), a single {@link IntWritable} holding the
     * argmax of the label row follows. Otherwise one {@link DoubleWritable} per label value follows.
     */
    public static List<List<Writable>> toRecords(DataSet dataSet) {
        return isClassificationDataSet(dataSet) ? getClassificationWritableMatrix(dataSet) : getRegressionWritableMatrix(dataSet);
    }

    /**
     * Returns true if the labels sum to {@code numExamples()} in total and have more than one column. This is a
     * heuristic for one-hot classification labels.
     */
    private static boolean isClassificationDataSet(DataSet dataSet) {
        INDArray labels = dataSet.getLabels();
        return labels.sum(0, 1).getInt(0) == dataSet.numExamples() && labels.shape()[1] > 1;
    }

    /** Builds records of [features row (keepDim), IntWritable(argmax of label row)]. */
    private static List<List<Writable>> getClassificationWritableMatrix(DataSet dataSet) {
        INDArray features = dataSet.getFeatures();
        INDArray labels = dataSet.getLabels();
        List<List<Writable>> records = new ArrayList<>();
        for (int i = 0; i < dataSet.numExamples(); i++) {
            List<Writable> record = toRecord(features.getRow(i, true));
            record.add(new IntWritable(Nd4j.argMax(labels.getRow(i)).getInt(0)));
            records.add(record);
        }
        return records;
    }

    /** Builds records of [features row (keepDim), DoubleWritable per label value]. */
    private static List<List<Writable>> getRegressionWritableMatrix(DataSet dataSet) {
        INDArray features = dataSet.getFeatures();
        INDArray labels = dataSet.getLabels();
        List<List<Writable>> records = new ArrayList<>();
        for (int i = 0; i < dataSet.numExamples(); i++) {
            // keepDim=true, matching the classification path, so the feature writable has the same layout there.
            List<Writable> record = toRecord(features.getRow(i, true));
            // getRow(i) without keepDim can drop the row dimension, so labelRow.shape()[1] is not safe to read;
            // length() counts the label values in either layout.
            INDArray labelRow = labels.getRow(i);
            for (long j = 0; j < labelRow.length(); j++) {
                record.add(new DoubleWritable(labelRow.getDouble(j)));
            }
            records.add(record);
        }
        return records;
    }
}
