package org.datavec.api.transform.transform.integer;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.IntegerMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseTransform;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Replaces one integer column with a block of one-hot encoded integer columns, one for each value
 * in the inclusive range [{@code minValue}, {@code maxValue}].
 *
 * <p>Output columns are named {@code columnName[v]} for each value {@code v} in the range. Each is
 * declared as an integer column restricted to [0, 1]. The column matching the input value holds
 * {@code 1} and all others hold {@code 0}. Input values outside the range are rejected with an
 * {@link IllegalStateException}. All other columns pass through unchanged.
 */
@JsonIgnoreProperties({"inputSchema", "columnIdx", "stateNames", "statesMap"})
public class IntegerToOneHotTransform extends BaseTransform {
    private String columnName;
    private int minValue;
    private int maxValue;
    /** Index of {@code columnName} in the input schema; -1 until {@link #setInputSchema} runs. */
    private int columnIdx = -1;

    /**
     * @param columnName name of the integer column to expand
     * @param minValue   smallest value (inclusive) that receives its own output column
     * @param maxValue   largest value (inclusive) that receives its own output column
     */
    public IntegerToOneHotTransform(@JsonProperty("columnName") String columnName,
                    @JsonProperty("minValue") int minValue, @JsonProperty("maxValue") int maxValue) {
        this.columnName = columnName;
        this.minValue = minValue;
        this.maxValue = maxValue;
    }

    /**
     * Stores the input schema and resolves the index of the target column within it.
     *
     * @throws IllegalStateException if the target column's metadata is not {@link IntegerMetaData}
     */
    @Override
    public void setInputSchema(Schema schema) {
        super.setInputSchema(schema);
        this.columnIdx = schema.getIndexOfColumn(this.columnName);
        ColumnMetaData metaData = schema.getMetaData(this.columnName);
        if (!(metaData instanceof IntegerMetaData)) {
            throw new IllegalStateException("Cannot convert column \"" + this.columnName
                            + "\" from integer to one-hot: column is not integer (is: "
                            + metaData.getColumnType() + ")");
        }
    }

    /** Describes this transform by its own class name, target column and value range. */
    @Override
    public String toString() {
        // Must name this class: map(List) embeds toString() in its error messages.
        return "IntegerToOneHotTransform(columnName=\"" + columnName + "\",minValue=" + minValue
                        + ",maxValue=" + maxValue + ")";
    }

    /**
     * Builds the output schema. The column at the index resolved by {@link #setInputSchema} is
     * replaced by one integer column per value in the range, named {@code name[v]} and restricted
     * to [0, 1]. Every other column's metadata is copied unchanged.
     *
     * @param schema schema whose columns are to be expanded
     * @return the new schema
     */
    @Override
    public Schema transform(Schema schema) {
        List<String> names = schema.getColumnNames();
        Iterator<ColumnMetaData> metaIter = schema.getColumnMetaData().iterator();
        List<ColumnMetaData> newMeta = new ArrayList<>(schema.numColumns());
        int i = 0;
        for (String name : names) {
            ColumnMetaData meta = metaIter.next();
            if (i++ == columnIdx) {
                // long counter: an int counter wraps past Integer.MAX_VALUE and never terminates
                for (long v = minValue; v <= maxValue; v++) {
                    newMeta.add(new IntegerMetaData(name + "[" + v + "]", 0, 1));
                }
            } else {
                newMeta.add(meta);
            }
        }
        return schema.newSchema(newMeta);
    }

    /**
     * Expands the target column of one record into its one-hot representation.
     *
     * @param writables one record; must have exactly as many entries as the input schema has columns
     * @return a new list with the target entry replaced by {@code IntWritable} 0/1 values
     * @throws IllegalStateException if the input schema was never set, the record length does not
     *                               match the schema, or the target value is outside the range
     */
    @Override
    public List<Writable> map(List<Writable> writables) {
        if (inputSchema == null) {
            throw new IllegalStateException(
                            "Cannot execute transform: input schema has not been set. Transform = " + toString());
        }
        if (writables.size() != inputSchema.numColumns()) {
            throw new IllegalStateException("Cannot execute transform: input writables list length ("
                            + writables.size() + ") does not match expected number of elements (schema: "
                            + inputSchema.numColumns() + "). Transform = " + toString());
        }
        int idx = getColumnIdx();
        // Size computed in long and clamped so an extreme range cannot overflow the int capacity.
        long expectedSize = writables.size() + rangeSize();
        List<Writable> out = new ArrayList<>((int) Math.min(expectedSize, Integer.MAX_VALUE));
        int i = 0;
        for (Writable w : writables) {
            if (i++ == idx) {
                int value = w.toInt();
                checkInRange(value);
                for (long v = minValue; v <= maxValue; v++) {
                    out.add(new IntWritable(v == value ? 1 : 0));
                }
            } else {
                out.add(w);
            }
        }
        return out;
    }

    /**
     * One-hot encodes a single value.
     *
     * @param input a {@link Number}; its {@code intValue()} is encoded
     * @return a {@code List<Integer>} of 0/1 values, one per value in the range
     * @throws IllegalStateException if the value is outside [{@code minValue}, {@code maxValue}]
     */
    @Override
    public Object map(Object input) {
        int value = ((Number) input).intValue();
        checkInRange(value);
        List<Integer> oneHot = new ArrayList<>((int) Math.min(rangeSize(), Integer.MAX_VALUE));
        for (long v = minValue; v <= maxValue; v++) {
            oneHot.add(v == value ? 1 : 0);
        }
        return oneHot;
    }

    /**
     * Applies {@link #map(Object)} to every element of a sequence.
     *
     * @param sequence a {@link List} of values accepted by {@link #map(Object)}
     * @return a list holding one encoded list per input element
     */
    @Override
    public Object mapSequence(Object sequence) {
        List<?> steps = (List<?>) sequence;
        List<List<?>> out = new ArrayList<>(steps.size());
        for (Object step : steps) {
            out.add((List<?>) map(step));
        }
        return out;
    }

    /** Not supported: this transform produces more than one output column. */
    @Override
    public String outputColumnName() {
        throw new UnsupportedOperationException("Output column name will be more than 1");
    }

    /** Returns every column name of the schema produced by applying {@link #transform} to the input schema. */
    @Override
    public String[] outputColumnNames() {
        List<String> names = transform(this.inputSchema).getColumnNames();
        return names.toArray(new String[0]);
    }

    @Override
    public String[] columnNames() {
        return new String[] {this.columnName};
    }

    @Override
    public String columnName() {
        return this.columnName;
    }

    public String getColumnName() {
        return this.columnName;
    }

    public int getMinValue() {
        return this.minValue;
    }

    public int getMaxValue() {
        return this.maxValue;
    }

    public int getColumnIdx() {
        return this.columnIdx;
    }

    public void setColumnName(String columnName) {
        this.columnName = columnName;
    }

    public void setMinValue(int minValue) {
        this.minValue = minValue;
    }

    public void setMaxValue(int maxValue) {
        this.maxValue = maxValue;
    }

    public void setColumnIdx(int columnIdx) {
        this.columnIdx = columnIdx;
    }

    /** Equal when column name, minValue and maxValue match; the resolved columnIdx is ignored. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof IntegerToOneHotTransform)) {
            return false;
        }
        IntegerToOneHotTransform other = (IntegerToOneHotTransform) obj;
        return other.canEqual(this)
                        && Objects.equals(getColumnName(), other.getColumnName())
                        && getMinValue() == other.getMinValue()
                        && getMaxValue() == other.getMaxValue();
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof IntegerToOneHotTransform;
    }

    /** Hash over the same fields as {@link #equals}; formula kept identical to the previous one. */
    @Override
    public int hashCode() {
        String name = getColumnName();
        int result = 1;
        result = result * 59 + (name == null ? 43 : name.hashCode());
        result = result * 59 + getMinValue();
        result = result * 59 + getMaxValue();
        return result;
    }

    /** Number of one-hot columns, in long arithmetic so extreme ranges cannot overflow; 0 if minValue > maxValue. */
    private long rangeSize() {
        return Math.max(0L, (long) maxValue - (long) minValue + 1L);
    }

    /** Rejects values outside the inclusive range [minValue, maxValue]. */
    private void checkInRange(int value) {
        if (value < minValue || value > maxValue) {
            throw new IllegalStateException("Invalid value: integer value (" + value
                            + ") is outside of valid range: must be between " + minValue + " and "
                            + maxValue + " inclusive");
        }
    }
}
