package org.nd4j.linalg.primitives;

import java.io.Serializable;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:org/nd4j/linalg/primitives/CounterMap.class */
/**
 * A two-level counter: each outer key {@code F} maps to a {@link Counter} over inner keys
 * {@code S}. Inner counters are created lazily on the first write for an outer key.
 *
 * <p>The outer map is a {@link ConcurrentHashMap}, so {@code null} outer keys are rejected.
 * Thread-safety of individual inner counters depends on {@link Counter} (not visible here).
 *
 * @param <F> outer (first-level) key type
 * @param <S> inner (second-level) key type
 */
public class CounterMap<F, S> implements Serializable {
    private static final long serialVersionUID = 119;

    /** Outer key to its inner counter. */
    protected Map<F, Counter<S>> maps = new ConcurrentHashMap<>();

    /** Returns {@code true} if no outer key has a counter. */
    public boolean isEmpty() {
        return this.maps.isEmpty();
    }

    /**
     * Returns {@code true} if there is no counter for {@code f} or that counter is empty.
     *
     * @param f outer key
     */
    public boolean isEmpty(F f) {
        // The isEmpty() short-circuit also avoids a lookup on an empty map.
        if (isEmpty()) {
            return true;
        }
        Counter<S> counter = this.maps.get(f);
        return counter == null || counter.isEmpty();
    }

    /**
     * Adds every count in {@code counterMap} into this map, creating inner counters as needed.
     *
     * @param counterMap source of counts to add; not modified
     */
    public void incrementAll(CounterMap<F, S> counterMap) {
        for (Map.Entry<F, Counter<S>> outer : counterMap.maps.entrySet()) {
            F first = outer.getKey();
            for (Map.Entry<S, AtomicDouble> inner : outer.getValue().entrySet()) {
                incrementCount(first, inner.getKey(), inner.getValue().get());
            }
        }
    }

    /**
     * Increments the count of {@code (f, s)} by {@code d}, creating the inner counter if absent.
     *
     * @param f outer key
     * @param s inner key
     * @param d amount to add
     */
    public void incrementCount(F f, S s, double d) {
        getOrCreateCounter(f).incrementCount(s, d);
    }

    /**
     * Returns the count of {@code (f, s)}, or {@code 0.0} if there is no counter for {@code f}.
     *
     * @param f outer key
     * @param s inner key
     */
    public double getCount(F f, S s) {
        Counter<S> counter = this.maps.get(f);
        return counter == null ? 0.0d : counter.getCount(s);
    }

    /**
     * Sets the count of {@code (f, s)} to {@code d}, creating the inner counter if absent.
     *
     * @param f outer key
     * @param s inner key
     * @param d new count
     * @return whatever {@link Counter#setCount} returns for the inner counter
     */
    public double setCount(F f, S s, double d) {
        return getOrCreateCounter(f).setCount(s, d);
    }

    /**
     * Returns the counter for {@code f}, creating and registering an empty one if absent.
     * {@code computeIfAbsent} makes creation atomic on the concurrent map, so concurrent first
     * writes for the same key cannot install two different counters and lose one's updates.
     */
    private Counter<S> getOrCreateCounter(F f) {
        return this.maps.computeIfAbsent(f, k -> new Counter<>());
    }

    /**
     * Returns the {@code (outer, inner)} pair with the highest count across all inner counters,
     * or {@code null} if there is no non-empty inner counter. Ties keep the first pair found in
     * the outer map's iteration order.
     */
    public Pair<F, S> argMax() {
        double best = Double.NEGATIVE_INFINITY;
        Pair<F, S> result = null;
        for (Map.Entry<F, Counter<S>> entry : this.maps.entrySet()) {
            Counter<S> counter = entry.getValue();
            if (counter.isEmpty()) {
                // Empty inner counters (e.g. left behind by clear(F)) have no arg-max.
                continue;
            }
            S innerArgMax = counter.argMax();
            double count = counter.getCount(innerArgMax);
            if (result == null || count > best) {
                result = new Pair<>(entry.getKey(), innerArgMax);
                best = count;
            }
        }
        return result;
    }

    /** Removes all outer keys and their counters. */
    public void clear() {
        this.maps.clear();
    }

    /**
     * Clears the inner counter for {@code f}, if present. The (now empty) counter stays
     * registered under {@code f}, so {@code f} remains in {@link #keySet()}.
     *
     * @param f outer key
     */
    public void clear(F f) {
        Counter<S> counter = this.maps.get(f);
        if (counter != null) {
            counter.clear();
        }
    }

    /** Returns a live view of the outer keys. */
    public Set<F> keySet() {
        return this.maps.keySet();
    }

    /**
     * Returns the inner counter for {@code f}, or {@code null} if none exists.
     *
     * @param f outer key
     */
    public Counter<S> getCounter(F f) {
        return this.maps.get(f);
    }

    /**
     * Returns an iterator over every {@code (outer, inner)} key pair. Outer keys whose counters
     * are empty or missing are skipped. {@code remove()} is a no-op.
     */
    public Iterator<Pair<F, S>> getIterator() {
        return new Iterator<Pair<F, S>>() {
            private final Iterator<F> outerIt = CounterMap.this.keySet().iterator();
            private Iterator<S> innerIt;
            private F curKey;

            /**
             * Positions {@code innerIt} on an inner iterator that has a next element.
             * Loops past outer keys with empty or missing counters.
             *
             * @return {@code false} once every outer key is exhausted
             */
            private boolean advance() {
                while (this.innerIt == null || !this.innerIt.hasNext()) {
                    if (!this.outerIt.hasNext()) {
                        return false;
                    }
                    this.curKey = this.outerIt.next();
                    Counter<S> counter = CounterMap.this.getCounter(this.curKey);
                    this.innerIt = counter == null ? null : counter.keySet().iterator();
                }
                return true;
            }

            @Override
            public boolean hasNext() {
                return advance();
            }

            @Override
            public Pair<F, S> next() {
                if (!advance()) {
                    throw new NoSuchElementException("No more (outer, inner) pairs");
                }
                return Pair.makePair(this.curKey, this.innerIt.next());
            }

            @Override
            public void remove() {
                // NOTE(review): intentionally a no-op, kept for backward compatibility.
                // Callers cannot remove entries through this iterator.
            }
        };
    }

    /** Returns the number of outer keys. */
    public int size() {
        return this.maps.size();
    }

    /** Returns the sum of the sizes of all inner counters. */
    public int totalSize() {
        int total = 0;
        // Iterate values directly: avoids a second lookup that could miss a concurrently removed key.
        for (Counter<S> counter : this.maps.values()) {
            total += counter.size();
        }
        return total;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CounterMap)) {
            return false;
        }
        CounterMap<?, ?> other = (CounterMap<?, ?>) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.maps, other.maps);
    }

    /** Symmetry hook for {@link #equals}: subclasses may restrict which types they equal. */
    protected boolean canEqual(Object obj) {
        return obj instanceof CounterMap;
    }

    @Override
    public int hashCode() {
        // Same formula as before: (1 * 59) + (maps == null ? 43 : maps.hashCode()).
        return 59 + (this.maps == null ? 43 : this.maps.hashCode());
    }
}
