package org.datavec.api.io.filters;

import java.net.URI;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.writable.Writable;

/**
 * A {@link RandomPathFilter} that additionally balances the returned paths across labels.
 *
 * <p>After delegating to {@link RandomPathFilter#filter(URI[])}, every remaining path is labelled with
 * a {@link PathLabelGenerator}. The paths are grouped per label, and the same number of paths is
 * taken from each label group. The result interleaves paths across labels in round-robin order.
 */
public class BalancedPathFilter extends RandomPathFilter {
    /** Produces the label of each path; replaced by a {@link ParentPathLabelGenerator} in {@link #filter} when null. */
    protected PathLabelGenerator labelGenerator;
    /** Maximum number of distinct labels kept (first encountered win); {@code <= 0} means unlimited. */
    protected long maxLabels;
    /** Lower bound applied to the number of paths taken per label (see {@link #filter}). */
    protected long minPathsPerLabel;
    /** Upper bound on the number of paths taken per label; {@code <= 0} means no explicit bound. */
    protected long maxPathsPerLabel;
    /** Whitelist of label strings; {@code null} or empty accepts every label. */
    protected String[] labels;

    /**
     * @param random             forwarded to {@link RandomPathFilter}
     * @param extensions         forwarded to {@link RandomPathFilter} (presumably allowed file extensions; may be null)
     * @param labelGenerator     generator used to label paths; null selects {@link ParentPathLabelGenerator}
     */
    public BalancedPathFilter(Random random, String[] extensions, PathLabelGenerator labelGenerator) {
        this(random, extensions, labelGenerator, 0L, 0L, 0L, 0L, new String[0]);
    }

    /**
     * @param random             forwarded to {@link RandomPathFilter}
     * @param labelGenerator     generator used to label paths; null selects {@link ParentPathLabelGenerator}
     * @param maxPathsPerLabel   upper bound on paths taken per label; {@code <= 0} means no explicit bound
     */
    public BalancedPathFilter(Random random, PathLabelGenerator labelGenerator, long maxPathsPerLabel) {
        this(random, null, labelGenerator, 0L, 0L, 0L, maxPathsPerLabel, new String[0]);
    }

    /**
     * @param random             forwarded to {@link RandomPathFilter}
     * @param extensions         forwarded to {@link RandomPathFilter} (presumably allowed file extensions; may be null)
     * @param labelGenerator     generator used to label paths; null selects {@link ParentPathLabelGenerator}
     * @param maxPathsPerLabel   upper bound on paths taken per label; {@code <= 0} means no explicit bound
     */
    public BalancedPathFilter(Random random, String[] extensions, PathLabelGenerator labelGenerator,
                    long maxPathsPerLabel) {
        this(random, extensions, labelGenerator, 0L, 0L, 0L, maxPathsPerLabel, new String[0]);
    }

    /**
     * @param random             forwarded to {@link RandomPathFilter}
     * @param labelGenerator     generator used to label paths; null selects {@link ParentPathLabelGenerator}
     * @param maxPaths           forwarded to {@link RandomPathFilter} (presumably a cap on total paths)
     * @param maxLabels          maximum number of distinct labels kept; {@code <= 0} means unlimited
     * @param maxPathsPerLabel   upper bound on paths taken per label; {@code <= 0} means no explicit bound
     */
    public BalancedPathFilter(Random random, PathLabelGenerator labelGenerator, long maxPaths, long maxLabels,
                    long maxPathsPerLabel) {
        this(random, null, labelGenerator, maxPaths, maxLabels, 0L, maxPathsPerLabel, new String[0]);
    }

    /**
     * @param random             forwarded to {@link RandomPathFilter}
     * @param extensions         forwarded to {@link RandomPathFilter} (presumably allowed file extensions; may be null)
     * @param labelGenerator     generator used to label paths; null selects {@link ParentPathLabelGenerator}
     * @param maxLabels          maximum number of distinct labels kept; {@code <= 0} means unlimited
     * @param maxPathsPerLabel   upper bound on paths taken per label; {@code <= 0} means no explicit bound
     */
    public BalancedPathFilter(Random random, String[] extensions, PathLabelGenerator labelGenerator, long maxLabels,
                    long maxPathsPerLabel) {
        this(random, extensions, labelGenerator, 0L, maxLabels, 0L, maxPathsPerLabel, new String[0]);
    }

    /**
     * Full constructor.
     *
     * @param random             forwarded to {@link RandomPathFilter}
     * @param extensions         forwarded to {@link RandomPathFilter} (presumably allowed file extensions; may be null)
     * @param labelGenerator     generator used to label paths; null selects {@link ParentPathLabelGenerator}
     * @param maxPaths           forwarded to {@link RandomPathFilter} (presumably a cap on total paths)
     * @param maxLabels          maximum number of distinct labels kept; {@code <= 0} means unlimited
     * @param minPathsPerLabel   lower bound applied to the per-label path count
     * @param maxPathsPerLabel   upper bound on paths taken per label; {@code <= 0} means no explicit bound
     * @param labels             whitelist of label strings; null or empty accepts every label
     */
    public BalancedPathFilter(Random random, String[] extensions, PathLabelGenerator labelGenerator, long maxPaths,
                    long maxLabels, long minPathsPerLabel, long maxPathsPerLabel, String... labels) {
        super(random, extensions, maxPaths);
        this.labelGenerator = labelGenerator;
        this.maxLabels = maxLabels;
        this.minPathsPerLabel = minPathsPerLabel;
        this.maxPathsPerLabel = maxPathsPerLabel;
        this.labels = labels;
    }

    /**
     * Returns whether paths carrying the given label should be kept.
     *
     * @param label the label's string form
     * @return {@code true} if no whitelist is configured ({@link #labels} null or empty), otherwise
     *         {@code true} iff {@code label} equals one of the whitelisted labels
     */
    protected boolean acceptLabel(String label) {
        if (labels == null || labels.length == 0) {
            return true;
        }
        for (String allowed : labels) {
            if (label.equals(allowed)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Filters the paths with {@link RandomPathFilter#filter(URI[])}, then balances them across labels.
     *
     * <ol>
     *   <li>Each path returned by the superclass is labelled via {@link #labelGenerator}. If the generator
     *       is null, it is first set to a new {@link ParentPathLabelGenerator}. Paths whose label string
     *       is rejected by {@link #acceptLabel} are dropped.</li>
     *   <li>When {@link #maxLabels} {@code > 0}, only the first {@code maxLabels} distinct labels
     *       encountered are kept. Paths with any other label are dropped.</li>
     *   <li>The per-label count is the size of the smallest label group, further capped by
     *       {@link #maxPathsPerLabel} when that is {@code > 0}. If the result is below
     *       {@link #minPathsPerLabel}, it is raised to that value. Labels with fewer paths than the
     *       count contribute all of their paths.</li>
     *   <li>The result interleaves the groups round-robin, in first-seen label order: the first path of
     *       each label, then the second path of each label, and so on.</li>
     * </ol>
     *
     * <p>NOTE(review): labels are grouped by {@link Writable} equality, while the whitelist compares
     * {@code toString()} values — these are assumed to agree for the generators in use.
     *
     * @param uriArr the candidate paths
     * @return the balanced, interleaved paths
     */
    @Override
    public URI[] filter(URI[] uriArr) {
        URI[] candidates = super.filter(uriArr);
        if (labelGenerator == null) {
            labelGenerator = new ParentPathLabelGenerator();
        }

        // LinkedHashMap keeps labels in first-seen order, which fixes the interleaving order below.
        Map<Writable, List<URI>> pathsByLabel = new LinkedHashMap<>();
        for (URI path : candidates) {
            Writable label = labelGenerator.getLabelForPath(path);
            if (!acceptLabel(label.toString())) {
                continue;
            }
            List<URI> group = pathsByLabel.get(label);
            if (group == null) {
                if (maxLabels > 0 && pathsByLabel.size() >= maxLabels) {
                    // Label cap reached: drop paths of labels not among the first maxLabels seen
                    // (previously this fell through to a null group and threw NullPointerException).
                    continue;
                }
                group = new ArrayList<>();
                pathsByLabel.put(label, group);
            }
            group.add(path);
        }

        int perLabel = maxPathsPerLabel > 0 ? (int) Math.min(maxPathsPerLabel, Integer.MAX_VALUE)
                        : Integer.MAX_VALUE;
        int longest = 0;
        for (List<URI> group : pathsByLabel.values()) {
            perLabel = Math.min(perLabel, group.size());
            longest = Math.max(longest, group.size());
        }
        if (perLabel < minPathsPerLabel) {
            perLabel = (int) Math.min(minPathsPerLabel, Integer.MAX_VALUE);
        }
        // Rounds past the longest group add nothing. Bounding here avoids spinning up to
        // Integer.MAX_VALUE empty iterations when no label survived (or minPathsPerLabel is huge).
        int rounds = Math.min(perLabel, longest);

        List<URI> balanced = new ArrayList<>();
        for (int i = 0; i < rounds; i++) {
            for (List<URI> group : pathsByLabel.values()) {
                if (i < group.size()) {
                    balanced.add(group.get(i));
                }
            }
        }
        return balanced.toArray(new URI[balanced.size()]);
    }
}
