package org.datavec.api.transform.transform.string;

import java.io.File;
import java.io.IOException;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Replaces one String column with a single NDArray column of per-vocabulary-entry counts.
 *
 * <p>The value of the source column is split on {@link #delimiter}; every resulting token is looked
 * up in the vocabulary and the array element at the token's vocabulary index is incremented. In
 * {@code binary} mode each vocabulary entry is counted at most once per value, so the array holds
 * only 0 and 1. All other columns are passed through untouched.
 *
 * <p>{@link #setInputSchema(Schema)} must be called before {@link #map(List)}: it supplies both the
 * schema used for the length check and the index of the column to convert.
 */
@JsonIgnoreProperties({"inputSchema", "map", "columnIdx"})
public class StringListToCountsNDArrayTransform extends BaseTransform {
    /** Name of the String column that is converted. */
    protected final String columnName;
    /** Name given to the NDArray column that replaces {@link #columnName} in the output schema. */
    protected final String newColumnName;
    /** Known tokens; the position of a token in this list is its index in the output array. */
    protected final List<String> vocabulary;
    /** Separator passed to {@link String#split(String)}, and therefore interpreted as a regex. */
    protected final String delimiter;
    /** If true, a token contributes at most 1 to its array element regardless of repetitions. */
    protected final boolean binary;
    /** If true, tokens missing from the vocabulary are skipped; otherwise they cause an exception. */
    protected final boolean ignoreUnknown;
    /** Token -&gt; vocabulary index lookup, derived from {@link #vocabulary} in the constructor. */
    protected final Map<String, Integer> map;
    /** Index of {@link #columnName} in the input schema; -1 until an input schema has been set. */
    protected int columnIdx;

    /**
     * Creates a transform whose output column is named {@code columnName + "[BOW]"}.
     *
     * @param columnName    name of the String column to convert
     * @param vocabulary    known tokens, in output-array order
     * @param delimiter     regex used to split the column value into tokens
     * @param binary        whether to record presence (0/1) instead of counts
     * @param ignoreUnknown whether tokens outside the vocabulary are silently skipped
     */
    public StringListToCountsNDArrayTransform(String columnName, List<String> vocabulary, String delimiter,
                    boolean binary, boolean ignoreUnknown) {
        this(columnName, columnName + "[BOW]", vocabulary, delimiter, binary, ignoreUnknown);
    }

    /**
     * Creates a transform with an explicit output column name.
     *
     * @param columnName    name of the String column to convert
     * @param newColumnName name of the NDArray column written to the output schema
     * @param vocabulary    known tokens, in output-array order; must not be null
     * @param delimiter     regex used to split the column value into tokens
     * @param binary        whether to record presence (0/1) instead of counts
     * @param ignoreUnknown whether tokens outside the vocabulary are silently skipped
     */
    public StringListToCountsNDArrayTransform(@JsonProperty("columnName") String columnName,
                    @JsonProperty("newColumnName") String newColumnName,
                    @JsonProperty("vocabulary") List<String> vocabulary,
                    @JsonProperty("delimiter") String delimiter, @JsonProperty("binary") boolean binary,
                    @JsonProperty("ignoreUnknown") boolean ignoreUnknown) {
        this.columnIdx = -1;
        this.columnName = columnName;
        this.newColumnName = newColumnName;
        this.vocabulary = vocabulary;
        this.delimiter = delimiter;
        this.binary = binary;
        this.ignoreUnknown = ignoreUnknown;

        // Build the token -> index lookup once. If a token appears more than once in the
        // vocabulary, the later position overwrites the earlier one.
        this.map = new HashMap<>();
        for (int i = 0; i < vocabulary.size(); i++) {
            this.map.put(vocabulary.get(i), Integer.valueOf(i));
        }
    }

    /**
     * Reads a vocabulary from a text file, one entry per line, decoding the file as UTF-8.
     *
     * @param path path of the file to read
     * @return the lines of the file, in file order
     * @throws IOException if the file cannot be read
     */
    public static List<String> readVocabFromFile(String path) throws IOException {
        return FileUtils.readLines(new File(path), "utf-8");
    }

    /**
     * Derives the output schema: identical to the input schema except that the source column's
     * metadata is replaced by NDArray metadata named {@link #newColumnName} with shape
     * {@code [vocabulary.size()]}.
     *
     * @param schema the input schema
     * @return the schema produced by this transform
     * @throws IllegalStateException if the source column is not of type String
     */
    @Override // org.datavec.api.transform.ColumnOp
    public Schema transform(Schema schema) {
        int targetIdx = schema.getIndexOfColumn(this.columnName);
        List<ColumnMetaData> inputMeta = schema.getColumnMetaData();
        List<ColumnMetaData> outputMeta = new ArrayList<>(inputMeta.size());

        int idx = 0;
        for (ColumnMetaData meta : inputMeta) {
            if (idx++ != targetIdx) {
                outputMeta.add(meta);
                continue;
            }
            if (meta.getColumnType() != ColumnType.String) {
                throw new IllegalStateException("Cannot convert non-string type");
            }
            outputMeta.add(new NDArrayMetaData(this.newColumnName, new long[]{this.vocabulary.size()}));
        }
        return schema.newSchema(outputMeta);
    }

    /**
     * Stores the input schema and caches the index of the source column within it.
     *
     * @param schema the schema of the records that will be passed to {@link #map(List)}
     */
    @Override // org.datavec.api.transform.transform.BaseTransform, org.datavec.api.transform.ColumnOp
    public void setInputSchema(Schema schema) {
        this.inputSchema = schema;
        this.columnIdx = schema.getIndexOfColumn(this.columnName);
    }

    @Override // org.datavec.api.transform.transform.BaseTransform
    public String toString() {
        return "StringListToCountsTransform(columnName=" + this.columnName + ",vocabularySize="
                        + this.vocabulary.size() + ",delimiter=\"" + this.delimiter + "\")";
    }

    /**
     * Splits {@code text} on the delimiter and resolves each token to its vocabulary index.
     *
     * @param text the raw column value; null or empty yields an empty collection
     * @return the resolved indices: a set (no duplicates) in binary mode, otherwise a list with one
     *         entry per occurrence
     * @throws IllegalStateException if a token is not in the vocabulary and {@code ignoreUnknown}
     *                               is false
     */
    protected Collection<Integer> getIndices(String text) {
        // A set collapses repeated tokens so that binary mode counts each entry once.
        Collection<Integer> indices = this.binary ? new HashSet<Integer>() : new ArrayList<Integer>();
        if (text == null || text.isEmpty()) {
            return indices;
        }
        // NOTE: String.split treats the delimiter as a regular expression.
        for (String token : text.split(this.delimiter)) {
            Integer index = this.map.get(token);
            if (index != null) {
                indices.add(index);
            } else if (!this.ignoreUnknown) {
                throw new IllegalStateException("Encountered unknown String: \"" + token + "\"");
            }
        }
        return indices;
    }

    /**
     * Builds the count array: starts from zeros of length {@code vocabulary.size()} and adds 1 to
     * the element at every index in {@code indices}.
     *
     * @param indices vocabulary indices, one per occurrence to count
     * @return the populated array
     */
    protected INDArray makeBOWNDArray(Collection<Integer> indices) {
        INDArray counts = Nd4j.zeros(this.vocabulary.size());
        for (Integer index : indices) {
            int position = index.intValue();
            counts.putScalar(position, counts.getDouble(position) + 1.0d);
        }
        // NOTE(review): presumably flushes any queued executioner work before the array is handed
        // to the caller - confirm against the ND4J executioner documentation.
        Nd4j.getExecutioner().commit();
        return counts;
    }

    /**
     * Transforms one record, replacing the value at the source column index with an
     * {@link NDArrayWritable} of counts and copying every other value through unchanged.
     *
     * @param list the input record; its size must equal the number of columns in the input schema
     * @return a new list holding the transformed record
     * @throws IllegalStateException if the record length does not match the input schema, or if an
     *                               unknown token is met while {@code ignoreUnknown} is false
     */
    @Override // org.datavec.api.transform.Transform
    public List<Writable> map(List<Writable> list) {
        if (list.size() != this.inputSchema.numColumns()) {
            throw new IllegalStateException("Cannot execute transform: input writables list length (" + list.size()
                            + ") does not match expected number of elements (schema: "
                            + this.inputSchema.numColumns() + "). Transform = " + toString());
        }
        List<Writable> out = new ArrayList<>(list.size());
        int idx = 0;
        for (Writable value : list) {
            if (idx++ == this.columnIdx) {
                out.add(new NDArrayWritable(makeBOWNDArray(getIndices(value.toString()))));
            } else {
                out.add(value);
            }
        }
        return out;
    }

    /**
     * Single-object mapping is not implemented for this transform.
     *
     * @param obj ignored
     * @return always null
     */
    @Override // org.datavec.api.transform.Transform
    public Object map(Object obj) {
        return null;
    }

    /**
     * Single-object sequence mapping is not implemented for this transform.
     *
     * @param obj ignored
     * @return always null
     */
    @Override // org.datavec.api.transform.Transform
    public Object mapSequence(Object obj) {
        return null;
    }

    /**
     * Not supported by this transform.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.datavec.api.transform.ColumnOp
    public String outputColumnName() {
        throw new UnsupportedOperationException("New column names is always more than 1 in length");
    }

    /**
     * @return the vocabulary entries as an array
     */
    @Override // org.datavec.api.transform.ColumnOp
    public String[] outputColumnNames() {
        return this.vocabulary.toArray(new String[this.vocabulary.size()]);
    }

    /**
     * @return a one-element array holding the name of the source column
     */
    @Override // org.datavec.api.transform.ColumnOp
    public String[] columnNames() {
        return new String[]{columnName()};
    }

    /**
     * @return the name of the source column this transform operates on
     */
    @Override // org.datavec.api.transform.ColumnOp
    public String columnName() {
        // Must read the field: calling columnName() here would recurse until the stack overflows.
        return this.columnName;
    }

    /**
     * Two transforms are equal when their column names, vocabulary, delimiter, both flags and the
     * derived lookup map are equal. {@code columnIdx} and the input schema are not compared.
     */
    @Override // org.datavec.api.transform.transform.BaseTransform
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof StringListToCountsNDArrayTransform)) {
            return false;
        }
        StringListToCountsNDArrayTransform other = (StringListToCountsNDArrayTransform) obj;
        // Lets a subclass that adds state refuse equality with a plain instance of this class.
        if (!other.canEqual(this)) {
            return false;
        }
        return nullSafeEquals(getColumnName(), other.getColumnName())
                        && nullSafeEquals(getNewColumnName(), other.getNewColumnName())
                        && nullSafeEquals(getVocabulary(), other.getVocabulary())
                        && nullSafeEquals(getDelimiter(), other.getDelimiter())
                        && isBinary() == other.isBinary()
                        && isIgnoreUnknown() == other.isIgnoreUnknown()
                        && nullSafeEquals(getMap(), other.getMap());
    }

    @Override // org.datavec.api.transform.transform.BaseTransform
    protected boolean canEqual(Object obj) {
        return obj instanceof StringListToCountsNDArrayTransform;
    }

    /**
     * Hash over the same properties, in the same order, that {@link #equals(Object)} compares.
     */
    @Override // org.datavec.api.transform.transform.BaseTransform
    public int hashCode() {
        int result = 1;
        result = result * 59 + hashOrDefault(getColumnName());
        result = result * 59 + hashOrDefault(getNewColumnName());
        result = result * 59 + hashOrDefault(getVocabulary());
        result = result * 59 + hashOrDefault(getDelimiter());
        result = result * 59 + (isBinary() ? 79 : 97);
        result = result * 59 + (isIgnoreUnknown() ? 79 : 97);
        result = result * 59 + hashOrDefault(getMap());
        return result;
    }

    /** Null-tolerant equality: two nulls are equal, a null never equals a non-null. */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Hash of {@code value}, or the fixed constant 43 when it is null. */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /** @return the name of the source String column */
    public String getColumnName() {
        return this.columnName;
    }

    /** @return the name of the NDArray column written to the output schema */
    public String getNewColumnName() {
        return this.newColumnName;
    }

    /** @return the vocabulary list this transform was constructed with (not a copy) */
    public List<String> getVocabulary() {
        return this.vocabulary;
    }

    /** @return the delimiter regex used to split column values */
    public String getDelimiter() {
        return this.delimiter;
    }

    /** @return true if each vocabulary entry is counted at most once per value */
    public boolean isBinary() {
        return this.binary;
    }

    /** @return true if tokens outside the vocabulary are skipped instead of raising an exception */
    public boolean isIgnoreUnknown() {
        return this.ignoreUnknown;
    }

    /** @return the token-to-index lookup map (not a copy) */
    public Map<String, Integer> getMap() {
        return this.map;
    }

    /** @return the cached index of the source column, or -1 if no input schema has been set */
    public int getColumnIdx() {
        return this.columnIdx;
    }

    /** @param i the index of the source column within the input schema */
    public void setColumnIdx(int i) {
        this.columnIdx = i;
    }
}
