package org.datavec.api.transform.sequence.expansion;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonInclude;

/**
 * Base class for sequence transforms that expand each time step of a sequence into zero or more
 * output time steps.
 *
 * <p>For every input time step, the values of the {@code requiredColumns} are extracted (in
 * {@code requiredColumns} order) and passed to {@link #expandTimeStep(List)}. Each list returned by
 * that method becomes one output time step: its values replace the required columns, while every
 * other column is copied unchanged from the input time step. The output schema keeps every column
 * in its original position, replacing the metadata of each required column with the matching entry
 * produced by {@link #expandedColumnMetaDatas(List, List)}.
 *
 * <p>Only sequence data is supported: {@link #map(List)} always throws. The input schema must be
 * provided through {@link #setInputSchema(Schema)} before {@link #mapSequence(List)} is called.
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonIgnoreProperties({"inputSchema"})
public abstract class BaseSequenceExpansionTransform implements Transform {

    /** Names of the input columns whose values are expanded (the constructor rejects an empty list). */
    protected List<String> requiredColumns;
    /** Names for the expanded columns; passed to {@link #expandedColumnMetaDatas(List, List)}. */
    protected List<String> expandedColumnNames;
    /** Schema of the incoming sequence data; excluded from JSON by {@code @JsonIgnoreProperties}. */
    protected Schema inputSchema;

    /**
     * @param requiredColumns     names of the columns to expand; must be non-null and non-empty
     * @param expandedColumnNames names for the expanded columns; must be non-null
     * @throws NullPointerException     if either argument is null
     * @throws IllegalArgumentException if {@code requiredColumns} is empty
     */
    public BaseSequenceExpansionTransform(@NonNull List<String> requiredColumns,
                                          @NonNull List<String> expandedColumnNames) {
        if (requiredColumns == null) {
            throw new NullPointerException("requiredColumns is marked @NonNull but is null");
        }
        if (expandedColumnNames == null) {
            throw new NullPointerException("expandedColumnNames is marked @NonNull but is null");
        }
        if (requiredColumns.isEmpty()) {
            throw new IllegalArgumentException("No columns have values to be expanded. Must have requiredColumns.size() > 0");
        }
        this.requiredColumns = requiredColumns;
        this.expandedColumnNames = expandedColumnNames;
    }

    /**
     * Computes the output metadata for the expanded columns.
     *
     * @param origColumnMeta      metadata of the required columns, in {@code requiredColumns} order
     * @param expandedColumnNames the configured expanded column names
     * @return one metadata entry per required column, in the same order as {@code origColumnMeta};
     *         {@link #transform(Schema)} relies on this positional correspondence
     */
    protected abstract List<ColumnMetaData> expandedColumnMetaDatas(List<ColumnMetaData> origColumnMeta,
                                                                    List<String> expandedColumnNames);

    /**
     * Expands the required-column values of a single time step.
     *
     * @param values the required-column values of one input time step, in {@code requiredColumns} order
     * @return the output time steps to emit (possibly none); each inner list must hold one value per
     *         required column, in {@code requiredColumns} order
     */
    protected abstract List<List<Writable>> expandTimeStep(List<Writable> values);

    /**
     * Returns the output schema. It is identical to {@code schema}, except that the metadata of each
     * required column is replaced, in place, by the entry that
     * {@link #expandedColumnMetaDatas(List, List)} produced for that column.
     *
     * @throws IllegalStateException if {@code expandedColumnMetaDatas} returns fewer entries than
     *                               there are required columns
     */
    @Override
    public Schema transform(Schema schema) {
        List<ColumnMetaData> toExpand = new ArrayList<>(requiredColumns.size());
        for (String name : requiredColumns) {
            toExpand.add(schema.getMetaData(name));
        }
        List<ColumnMetaData> expanded = expandedColumnMetaDatas(toExpand, expandedColumnNames);
        if (expanded.size() < requiredColumns.size()) {
            throw new IllegalStateException("expandedColumnMetaDatas returned " + expanded.size()
                    + " entries, expected " + requiredColumns.size());
        }

        // The expanded metadata is in requiredColumns order, which need not match the order of those
        // columns in the schema, so look replacements up by name (as mapSequence does by position).
        Map<String, ColumnMetaData> replacementByName = new HashMap<>();
        for (int i = 0; i < requiredColumns.size(); i++) {
            replacementByName.put(requiredColumns.get(i), expanded.get(i));
        }

        List<ColumnMetaData> outMeta = new ArrayList<>(schema.numColumns());
        for (ColumnMetaData meta : schema.getColumnMetaData()) {
            String name = meta.getName();
            outMeta.add(replacementByName.containsKey(name) ? replacementByName.get(name) : meta);
        }
        return schema.newSchema(outMeta);
    }

    /** Returns the first expanded column name. */
    @Override
    public String outputColumnName() {
        return expandedColumnNames.get(0);
    }

    /** Returns all expanded column names, with no padding. */
    @Override
    public String[] outputColumnNames() {
        // Size from expandedColumnNames itself: sizing by requiredColumns padded the result with
        // nulls whenever fewer expanded names than required columns were configured.
        return expandedColumnNames.toArray(new String[0]);
    }

    /** Returns the names of the required (input) columns. */
    @Override
    public String[] columnNames() {
        return requiredColumns.toArray(new String[0]);
    }

    /** Returns the first required column name. */
    @Override
    public String columnName() {
        return columnNames()[0];
    }

    /**
     * Not supported: expansion only applies to sequences.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Writable> map(List<Writable> writables) {
        throw new UnsupportedOperationException("Cannot perform sequence expansion on non-sequence data");
    }

    /**
     * Expands every time step of {@code sequence} via {@link #expandTimeStep(List)}.
     *
     * @param sequence the input time steps, each laid out according to the input schema
     * @return the concatenation, in input order, of the output steps produced for each input step
     * @throws IllegalStateException if {@link #setInputSchema(Schema)} has not been called
     */
    @Override
    public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
        if (inputSchema == null) {
            throw new IllegalStateException(
                    "Input schema has not been set: call setInputSchema(Schema) before mapSequence");
        }
        int numColumns = inputSchema.numColumns();

        // requiredColumnIdxs[i] = schema index of requiredColumns[i];
        // expansionIdxByColumn[c] = position of column c in the expanded lists, or -1 if passed through.
        int[] requiredColumnIdxs = new int[requiredColumns.size()];
        int[] expansionIdxByColumn = new int[numColumns];
        Arrays.fill(expansionIdxByColumn, -1);
        for (int i = 0; i < requiredColumnIdxs.length; i++) {
            int colIdx = inputSchema.getIndexOfColumn(requiredColumns.get(i));
            requiredColumnIdxs[i] = colIdx;
            expansionIdxByColumn[colIdx] = i;
        }

        List<List<Writable>> out = new ArrayList<>();
        for (List<Writable> step : sequence) {
            // Fresh list per step so an implementation may safely retain the reference it receives.
            List<Writable> toExpand = new ArrayList<>(requiredColumnIdxs.length);
            for (int colIdx : requiredColumnIdxs) {
                toExpand.add(step.get(colIdx));
            }

            for (List<Writable> expandedStep : expandTimeStep(toExpand)) {
                List<Writable> outStep = new ArrayList<>(numColumns);
                for (int c = 0; c < numColumns; c++) {
                    int expansionIdx = expansionIdxByColumn[c];
                    outStep.add(expansionIdx >= 0 ? expandedStep.get(expansionIdx) : step.get(c));
                }
                out.add(outStep);
            }
        }
        return out;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Object map(Object input) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Object mapSequence(Object sequence) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Equality is based on {@code requiredColumns} and {@code expandedColumnNames} only. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof BaseSequenceExpansionTransform)) {
            return false;
        }
        BaseSequenceExpansionTransform other = (BaseSequenceExpansionTransform) obj;
        return other.canEqual(this)
                && Objects.equals(getRequiredColumns(), other.getRequiredColumns())
                && Objects.equals(getExpandedColumnNames(), other.getExpandedColumnNames());
    }

    /** Symmetry hook for {@link #equals(Object)}; subclasses that add state should override. */
    protected boolean canEqual(Object obj) {
        return obj instanceof BaseSequenceExpansionTransform;
    }

    @Override
    public int hashCode() {
        // Formula kept identical to the original so existing hash values do not change.
        List<String> required = getRequiredColumns();
        int result = 59 + (required == null ? 43 : required.hashCode());
        List<String> expandedNames = getExpandedColumnNames();
        return result * 59 + (expandedNames == null ? 43 : expandedNames.hashCode());
    }

    public List<String> getRequiredColumns() {
        return requiredColumns;
    }

    public List<String> getExpandedColumnNames() {
        return expandedColumnNames;
    }

    public void setRequiredColumns(List<String> requiredColumns) {
        this.requiredColumns = requiredColumns;
    }

    public void setExpandedColumnNames(List<String> expandedColumnNames) {
        this.expandedColumnNames = expandedColumnNames;
    }

    @Override
    public String toString() {
        return "BaseSequenceExpansionTransform(requiredColumns=" + getRequiredColumns() + ", expandedColumnNames=" + getExpandedColumnNames() + ", inputSchema=" + getInputSchema() + ")";
    }

    @Override
    public void setInputSchema(Schema schema) {
        this.inputSchema = schema;
    }

    @Override
    public Schema getInputSchema() {
        return inputSchema;
    }
}
