package org.deeplearning4j.plot;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.SpTree;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/plot/BarnesHutTsne.class */
public class BarnesHutTsne implements Model {
    // Workspace names; not referenced in the visible part of this class.
    public static final String workspaceCache = "LOOP_CACHE";
    public static final String workspaceExternal = "LOOP_EXTERNAL";
    // Number of optimisation iterations run by fit().
    protected int maxIter;
    // Only forwarded to the exact Tsne implementation (theta == 0 path of fit()).
    protected double realMin;
    // Only forwarded to Tsne; the Barnes-Hut path starts from `momentum` as given to the constructor.
    protected double initialMomentum;
    // Momentum used once the iteration counter reaches switchMomentumIteration.
    protected double finalMomentum;
    // Lower bound applied to the adaptive gains in update(INDArray, String).
    protected double minGain;
    // Current momentum multiplying the previous increment (yIncs) in update(INDArray, String).
    protected double momentum;
    // Iteration at which momentum is switched to finalMomentum.
    protected int switchMomentumIteration;
    // Only forwarded to Tsne; the Barnes-Hut path of fit() does not consult this flag.
    protected boolean normalize;
    // Only forwarded to Tsne; the constructor always sets it to false.
    protected boolean usePca;
    // Iteration at which the 12x exaggeration of P (applied in fit()) is removed.
    protected int stopLyingIteration;
    // Convergence threshold on |entropy - log(perplexity)| in the per-row beta search.
    protected double tolerance;
    // Step size: plain multiplier of the gain-scaled gradient, or the AdaGrad learning rate.
    protected double learningRate;
    // Lazily created in update(INDArray, String) when useAdaGrad is true.
    protected AdaGrad adaGrad;
    protected boolean useAdaGrad;
    // Target perplexity used by fit() when computing input affinities.
    protected double perplexity;
    // Current low-dimensional embedding, one row per input sample.
    protected INDArray Y;
    // Number of input samples; set by computeGaussianPerplexity().
    private int N;
    // Passed to SpTree.computeNonEdgeForces; 0 selects the exact Tsne path in fit().
    private double theta;
    // Sparse P matrix in CSR-like form: row offsets (N + 1 entries), column indices and values.
    private INDArray rows;
    private INDArray cols;
    private INDArray vals;
    // Distance name handed to the VP-tree (misspelling mirrors the public getter/setter names).
    private String simiarlityFunction;
    // Forwarded to the VP-tree constructor.
    private boolean invert;
    // Input data assigned by the fit(...) overloads.
    private INDArray x;
    // Number of output dimensions of the embedding.
    private int numDimensions;
    // Key used by step() to fetch the embedding gradient from gradient().
    public static final String Y_GRAD = "yIncs";
    // Space-partitioning tree over Y, rebuilt in gradient() and reused by score().
    private SpTree tree;
    // Per-element adaptive gains used in update(INDArray, String).
    private INDArray gains;
    // Previous update increment: added to Y, then decayed by momentum in update(INDArray, String).
    private INDArray yIncs;
    // Worker count forwarded to the VP-tree.
    private int vpTreeWorkers;
    // Notified after every iteration of fit(), if non-null.
    protected transient TrainingListener trainingListener;
    // Defaults to WorkspaceMode.NONE; not consulted in the visible part of this class.
    protected WorkspaceMode workspaceMode;
    // Produces the initial embedding when Y is null.
    private Initializer initializer;
    // Built by the constructor; not consulted in the visible part of this class.
    protected WorkspaceConfiguration workspaceConfigurationFeedForward;
    private static final Logger log = LoggerFactory.getLogger(BarnesHutTsne.class);
    // Shared workspace configurations; not consulted in the visible part of this class.
    protected static final WorkspaceConfiguration workspaceConfigurationExternal = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(0.3d).policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.BLOCK_LEFT).policySpill(SpillPolicy.REALLOCATE).policyAllocation(AllocationPolicy.OVERALLOCATE).build();
    public static final WorkspaceConfiguration workspaceConfigurationCache = WorkspaceConfiguration.builder().overallocationLimit(0.2d).policyReset(ResetPolicy.BLOCK_LEFT).cyclesBeforeInitialization(3).policyMirroring(MirroringPolicy.FULL).policySpill(SpillPolicy.REALLOCATE).policyLearning(LearningPolicy.OVER_TIME).build();

    /**
     * Fluent builder for {@link BarnesHutTsne}.
     *
     * <p>Each setter records one option and returns this builder. {@link #build()} hands every
     * recorded option, plus a {@code null} training listener, to the full constructor. Options that
     * are never set keep the defaults declared below. Several defaults appear to be float literals
     * widened to double (e.g. 0.8f); they are kept bit-exact.
     */
    public static class Builder {
        private int maxIter = 1000;
        private double realMin = 9.999999960041972E-13d;
        private double initialMomentum = 0.5d;
        private double finalMomentum = 0.800000011920929d;
        private double momentum = 0.5d;
        private int switchMomentumIteration = 100;
        private boolean normalize = true;
        private int stopLyingIteration = 100;
        private double tolerance = 9.999999747378752E-6d;
        private double learningRate = 0.10000000149011612d;
        private boolean useAdaGrad = false;
        private double perplexity = 30.0d;
        private double minGain = 0.009999999776482582d;
        private double theta = 0.5d;
        private boolean invert = true;
        private int numDim = 2;
        private String similarityFunction = Distance.EUCLIDEAN.toString();
        private int vpTreeWorkers = 1;
        protected WorkspaceMode workspaceMode = WorkspaceMode.NONE;
        private INDArray staticInput;

        /** Number of worker threads handed to the VP-tree. */
        public Builder vpTreeWorkers(int workers) {
            vpTreeWorkers = workers;
            return this;
        }

        /** Fixed initial embedding; a copy of it is used instead of random initialisation. */
        public Builder staticInit(INDArray initialEmbedding) {
            staticInput = initialEmbedding;
            return this;
        }

        /** Lower bound for the adaptive per-element gains. */
        public Builder minGain(double value) {
            minGain = value;
            return this;
        }

        /** Target perplexity of the input-space affinities. */
        public Builder perplexity(double value) {
            perplexity = value;
            return this;
        }

        /** Whether to scale the update with AdaGrad instead of a fixed learning rate. */
        public Builder useAdaGrad(boolean enabled) {
            useAdaGrad = enabled;
            return this;
        }

        /** Learning rate of the gradient-descent update. */
        public Builder learningRate(double value) {
            learningRate = value;
            return this;
        }

        /** Convergence tolerance of the per-row perplexity search. */
        public Builder tolerance(double value) {
            tolerance = value;
            return this;
        }

        /** Iteration at which the exaggeration of P ends. */
        public Builder stopLyingIteration(int iteration) {
            stopLyingIteration = iteration;
            return this;
        }

        /** Normalisation flag forwarded to the model. */
        public Builder normalize(boolean enabled) {
            normalize = enabled;
            return this;
        }

        /** Number of optimisation iterations. */
        public Builder setMaxIter(int iterations) {
            maxIter = iterations;
            return this;
        }

        /** {@code realMin} value forwarded to the model. */
        public Builder setRealMin(double value) {
            realMin = value;
            return this;
        }

        /** Initial momentum forwarded to the model. */
        public Builder setInitialMomentum(double value) {
            initialMomentum = value;
            return this;
        }

        /** Momentum used after the momentum-switch iteration. */
        public Builder setFinalMomentum(double value) {
            finalMomentum = value;
            return this;
        }

        /** Momentum used from the first iteration. */
        public Builder setMomentum(double value) {
            momentum = value;
            return this;
        }

        /** Iteration at which the final momentum takes over. */
        public Builder setSwitchMomentumIteration(int iteration) {
            switchMomentumIteration = iteration;
            return this;
        }

        /** Name of the distance function used by the VP-tree. */
        public Builder similarityFunction(String name) {
            similarityFunction = name;
            return this;
        }

        /** Inversion flag forwarded to the VP-tree. */
        public Builder invertDistanceMetric(boolean enabled) {
            invert = enabled;
            return this;
        }

        /** Barnes-Hut parameter; 0 selects the exact t-SNE implementation. */
        public Builder theta(double value) {
            theta = value;
            return this;
        }

        /** Number of output dimensions. */
        public Builder numDimension(int dimensions) {
            numDim = dimensions;
            return this;
        }

        /** Workspace mode; {@code null} is replaced by {@link WorkspaceMode#NONE} in the model. */
        public Builder workspaceMode(WorkspaceMode mode) {
            workspaceMode = mode;
            return this;
        }

        /** Creates the model from the recorded options. */
        public BarnesHutTsne build() {
            TrainingListener noListener = null;
            return new BarnesHutTsne(numDim, similarityFunction, theta, invert, maxIter, realMin,
                    initialMomentum, finalMomentum, momentum, switchMomentumIteration, normalize,
                    stopLyingIteration, tolerance, learningRate, useAdaGrad, perplexity, noListener,
                    minGain, vpTreeWorkers, workspaceMode, staticInput);
        }
    }

    /**
     * Supplies the starting embedding {@code Y} for {@link #fit()}.
     *
     * <p>When constructed with static data, a copy of that data is returned. Otherwise a
     * {@code x.rows() x numDimensions} Gaussian random matrix scaled by 0.001 is generated. It uses
     * the data type of the enclosing model's current input {@code x}.
     */
    public class Initializer {
        private INDArray staticData;

        public Initializer() {
        }

        public Initializer(INDArray staticData) {
            this.staticData = staticData;
        }

        /** Returns the initial embedding described in the class documentation. */
        public INDArray initData() {
            if (staticData != null) {
                return staticData.dup();
            }
            INDArray input = BarnesHutTsne.this.x;
            DataType type = input.dataType();
            long[] shape = {input.rows(), BarnesHutTsne.this.numDimensions};
            return Nd4j.randn(type, shape).muli(Float.valueOf(0.001f));
        }
    }

    /**
     * Holder for a symmetrised sparse matrix: row offsets, column indices and values.
     *
     * <p>The decompiler reports that this type was originally package-private. Equality compares
     * the three arrays with {@code INDArray.equals} (null-safe). The hash uses the multiplier 59
     * and the null marker 43.
     */
    public static class SymResult {
        INDArray rows;
        INDArray cols;
        INDArray vals;

        public SymResult(INDArray rows, INDArray cols, INDArray vals) {
            this.rows = rows;
            this.cols = cols;
            this.vals = vals;
        }

        public INDArray getRows() {
            return rows;
        }

        public INDArray getCols() {
            return cols;
        }

        public INDArray getVals() {
            return vals;
        }

        public void setRows(INDArray rows) {
            this.rows = rows;
        }

        public void setCols(INDArray cols) {
            this.cols = cols;
        }

        public void setVals(INDArray vals) {
            this.vals = vals;
        }

        /** Null-safe equality that always delegates to {@code mine.equals} when mine is non-null. */
        private static boolean sameArray(INDArray mine, INDArray theirs) {
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof SymResult)) {
                return false;
            }
            SymResult other = (SymResult) obj;
            return other.canEqual(this)
                    && sameArray(getRows(), other.getRows())
                    && sameArray(getCols(), other.getCols())
                    && sameArray(getVals(), other.getVals());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof SymResult;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            final int nullMarker = 43;
            int result = 1;
            for (INDArray part : new INDArray[]{getRows(), getCols(), getVals()}) {
                result = result * prime + (part == null ? nullMarker : part.hashCode());
            }
            return result;
        }

        @Override
        public String toString() {
            return "BarnesHutTsne.SymResult(rows=" + getRows() + ", cols=" + getCols() + ", vals=" + getVals() + ")";
        }
    }

    /**
     * Convenience constructor: same as the full constructor, with {@link WorkspaceMode#NONE} and no
     * static initial embedding.
     */
    public BarnesHutTsne(int numDimensions, String simiarlityFunction, double theta, boolean invert,
                         int maxIter, double realMin, double initialMomentum, double finalMomentum,
                         double momentum, int switchMomentumIteration, boolean normalize,
                         int stopLyingIteration, double tolerance, double learningRate,
                         boolean useAdaGrad, double perplexity, TrainingListener trainingListener,
                         double minGain, int vpTreeWorkers) {
        this(numDimensions, simiarlityFunction, theta, invert, maxIter, realMin, initialMomentum,
                finalMomentum, momentum, switchMomentumIteration, normalize, stopLyingIteration,
                tolerance, learningRate, useAdaGrad, perplexity, trainingListener, minGain,
                vpTreeWorkers, WorkspaceMode.NONE, null);
    }

    /**
     * Full constructor.
     *
     * <p>{@code usePca} is always false. A {@code null} workspace mode falls back to
     * {@link WorkspaceMode#NONE}. When {@code staticInit} is non-null, a copy of it becomes the
     * initial embedding instead of random noise.
     */
    public BarnesHutTsne(int numDimensions, String simiarlityFunction, double theta, boolean invert,
                         int maxIter, double realMin, double initialMomentum, double finalMomentum,
                         double momentum, int switchMomentumIteration, boolean normalize,
                         int stopLyingIteration, double tolerance, double learningRate,
                         boolean useAdaGrad, double perplexity, TrainingListener trainingListener,
                         double minGain, int vpTreeWorkers, WorkspaceMode workspaceMode,
                         INDArray staticInit) {
        // Optimisation schedule.
        this.maxIter = maxIter;
        this.realMin = realMin;
        this.initialMomentum = initialMomentum;
        this.finalMomentum = finalMomentum;
        this.momentum = momentum;
        this.switchMomentumIteration = switchMomentumIteration;
        this.stopLyingIteration = stopLyingIteration;
        this.tolerance = tolerance;
        this.learningRate = learningRate;
        this.useAdaGrad = useAdaGrad;
        this.minGain = minGain;
        this.normalize = normalize;
        this.usePca = false;

        // Affinity / embedding settings.
        this.perplexity = perplexity;
        this.numDimensions = numDimensions;
        this.simiarlityFunction = simiarlityFunction;
        this.theta = theta;
        this.invert = invert;
        this.vpTreeWorkers = vpTreeWorkers;
        this.trainingListener = trainingListener;

        // Workspace settings.
        this.workspaceConfigurationFeedForward = WorkspaceConfiguration.builder()
                .initialSize(0L)
                .overallocationLimit(0.2d)
                .policyReset(ResetPolicy.BLOCK_LEFT)
                .policyLearning(LearningPolicy.OVER_TIME)
                .policySpill(SpillPolicy.REALLOCATE)
                .policyAllocation(AllocationPolicy.OVERALLOCATE)
                .build();
        this.workspaceMode = workspaceMode == null ? WorkspaceMode.NONE : workspaceMode;

        this.initializer = staticInit == null ? new Initializer() : new Initializer(staticInit);
    }

    /** Returns the name of the distance function handed to the VP-tree. */
    public String getSimiarlityFunction() {
        return simiarlityFunction;
    }

    /** Sets the name of the distance function handed to the VP-tree. */
    public void setSimiarlityFunction(String functionName) {
        simiarlityFunction = functionName;
    }

    /** Returns the inversion flag forwarded to the VP-tree. */
    public boolean isInvert() {
        return invert;
    }

    /** Sets the inversion flag forwarded to the VP-tree. */
    public void setInvert(boolean inverted) {
        invert = inverted;
    }

    /** Returns the Barnes-Hut parameter (0 selects the exact t-SNE path). */
    public double getTheta() {
        return theta;
    }

    /** Returns the target perplexity. */
    public double getPerplexity() {
        return perplexity;
    }

    /** Returns the number of output dimensions. */
    public int getNumDimensions() {
        return numDimensions;
    }

    /** Sets the number of output dimensions. */
    public void setNumDimensions(int dimensions) {
        numDimensions = dimensions;
    }

    /**
     * Computes sparse input-space affinities for the {@code k = (int) (3 * perplexity)} nearest
     * neighbours of every row of {@code data}.
     *
     * <p>Side effects: sets {@code N} to the number of rows and reallocates the CSR-style fields
     * {@code rows} (offsets), {@code cols} (neighbour indices) and {@code vals} (values). For each
     * row, a bisection over the Gaussian precision {@code beta} runs for at most 200 steps. It stops
     * once the kernel entropy is within {@code tolerance} of {@code log(perplexity)}. The normalised
     * kernel then becomes that row's slice of {@code vals}.
     *
     * @param data       input samples, one per row
     * @param perplexity target perplexity; requires {@code rows - 1 >= 3 * perplexity}
     * @return the {@code vals} array (also stored in the field)
     * @throws IllegalStateException if the perplexity is too large for the sample count, or the
     *                               neighbour search returns nothing for a row
     */
    public INDArray computeGaussianPerplexity(INDArray data, double perplexity) {
        this.N = data.rows();
        final int k = (int) (3.0d * perplexity);
        if (this.N - 1 < 3.0d * perplexity) {
            throw new IllegalStateException("Perplexity " + perplexity + " is too large for number of samples " + this.N);
        }
        this.rows = Nd4j.zeros(DataType.INT, new long[]{1, this.N + 1});
        this.cols = Nd4j.zeros(DataType.INT, new long[]{1, this.N * k});
        this.vals = Nd4j.zeros(data.dataType(), new long[]{this.N * k});
        // CSR offsets: every row owns exactly k neighbour slots.
        for (int row = 0; row < this.N; row++) {
            this.rows.putScalar(row + 1, this.rows.getDouble(row) + k);
        }

        final double logPerplexity = Math.log(perplexity);
        VPTree vpTree = new VPTree(data, this.simiarlityFunction, this.vpTreeWorkers, this.invert);
        log.info("Calculating probabilities of data similarities...");
        for (int row = 0; row < this.N; row++) {
            if (row % 500 == 0) {
                log.info("Handled " + row + " records");
            }
            List<DataPoint> neighbours = new ArrayList<>();
            List<Double> distances = new ArrayList<>();
            vpTree.search(data.getRow(row), k + 1, neighbours, distances, false, true);
            if (neighbours.isEmpty()) {
                throw new IllegalStateException("Search returned no values for vector " + row + " - similarity \"" + this.simiarlityFunction + "\" may not be defined (for example, vector is all zeros with cosine similarity)");
            }
            INDArray neighbourDistances = Nd4j.createFromArray(distances.toArray(new Double[0])).castTo(data.dataType());

            // Bisection on beta; +/-Double.MAX_VALUE mark a still-unbounded side.
            double beta = 1.0d;
            double betaMin = -Double.MAX_VALUE;
            double betaMax = Double.MAX_VALUE;
            INDArray rowP = null;
            boolean found = false;
            int step = 0;
            while (!found && step < 200) {
                Pair<INDArray, Double> kernel = computeGaussianKernel(neighbourDistances, beta, k);
                rowP = kernel.getFirst();
                double hDiff = kernel.getSecond().doubleValue() - logPerplexity;
                if (hDiff >= this.tolerance || -hDiff >= this.tolerance) {
                    if (hDiff > 0.0d) {
                        betaMin = beta;
                        beta = (betaMax == Double.MAX_VALUE || betaMax == -Double.MAX_VALUE) ? beta * 2.0d : (beta + betaMax) / 2.0d;
                    } else {
                        betaMax = beta;
                        beta = (betaMin == -Double.MAX_VALUE || betaMin == Double.MAX_VALUE) ? beta / 2.0d : (beta + betaMin) / 2.0d;
                    }
                    step++;
                } else {
                    found = true;
                }
            }
            rowP.divi(Double.valueOf(rowP.sumNumber().doubleValue() + Double.MIN_VALUE));

            // Plain int[] keeps neighbour indices exact (a default float INDArray loses precision
            // above 2^24) and avoids an off-heap allocation per row. Missing results stay 0.
            int[] neighbourIndex = new int[k + 1];
            for (int j = 0; j < neighbourIndex.length && j < neighbours.size(); j++) {
                neighbourIndex[j] = neighbours.get(j).getIndex();
            }
            // Result slot 0 is skipped, matching the +1 offset in computeGaussianKernel
            // (presumably the query point itself - confirm against VPTree.search).
            int rowStart = this.rows.getInt(new int[]{row});
            for (int j = 0; j < k; j++) {
                this.cols.putScalar(rowStart + j, neighbourIndex[j + 1]);
                this.vals.putScalar(rowStart + j, rowP.getDouble(j));
            }
        }
        return this.vals;
    }

    /** Returns the input matrix last assigned by one of the {@code fit(...)} overloads. */
    public INDArray input() {
        return x;
    }

    /** This model exposes no optimizer; always {@code null}. */
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    /** Named parameters are not exposed; always {@code null}. */
    public INDArray getParam(String paramName) {
        return null;
    }

    /** No-op; the given listeners are discarded. */
    public void addListeners(TrainingListener... listeners) {
    }

    /** Parameter tables are not exposed; always {@code null}. */
    public Map<String, INDArray> paramTable() {
        return null;
    }

    /** Parameter tables are not exposed; always {@code null}. */
    public Map<String, INDArray> paramTable(boolean backpropOnly) {
        return null;
    }

    /** No-op. */
    public void setParamTable(Map<String, INDArray> table) {
    }

    /** No-op. */
    public void setParam(String key, INDArray value) {
    }

    /** No-op. */
    public void clear() {
    }

    /** No-op. */
    public void applyConstraints(int iteration, int epoch) {
    }

    /** Not supported by this model. */
    protected Pair<Double, INDArray> gradient(INDArray input) {
        throw new UnsupportedOperationException();
    }

    /**
     * Symmetrises a sparse CSR-style matrix P, returning {@code (P + P^T) / 2} in the same layout.
     *
     * <p>Uses the {@code N} field (the row count set by {@link #computeGaussianPerplexity}).
     * Pass 1 counts the entries each output row will hold. Pass 2 fills the column indices and
     * values. A pair present in both directions is summed once. An entry whose transpose is
     * missing is mirrored. Finally all values are halved.
     *
     * @param rowP row offsets (N + 1 entries)
     * @param colP column index of every stored entry
     * @param valP value of every stored entry
     * @return new row offsets, column indices and values
     */
    public SymResult symmetrized(INDArray rowP, INDArray colP, INDArray valP) {
        // Pass 1: number of entries per row of the symmetric matrix.
        INDArray rowCounts = Nd4j.create(DataType.INT, new long[]{this.N});
        for (int n = 0; n < this.N; n++) {
            int begin = rowP.getInt(new int[]{n});
            int end = rowP.getInt(new int[]{n + 1});
            for (int i = begin; i < end; i++) {
                int col = colP.getInt(new int[]{i});
                boolean present = false;
                for (int m = rowP.getInt(new int[]{col}); m < rowP.getInt(new int[]{col + 1}); m++) {
                    if (colP.getInt(new int[]{m}) == n) {
                        present = true;
                        break;
                    }
                }
                rowCounts.putScalar(n, rowCounts.getInt(new int[]{n}) + 1);
                if (!present) {
                    rowCounts.putScalar(col, rowCounts.getInt(new int[]{col}) + 1);
                }
            }
        }

        int total = rowCounts.sumNumber().intValue();
        INDArray offset = Nd4j.create(DataType.INT, new long[]{this.N});
        INDArray symRowP = Nd4j.zeros(DataType.INT, new long[]{this.N + 1});
        INDArray symColP = Nd4j.create(DataType.INT, new long[]{total});
        INDArray symValP = Nd4j.create(valP.dataType(), new long[]{total});
        for (int n = 0; n < this.N; n++) {
            symRowP.putScalar(n + 1, symRowP.getInt(new int[]{n}) + rowCounts.getInt(new int[]{n}));
        }

        // Pass 2: fill columns and values; offset[r] counts slots already used in row r.
        for (int n = 0; n < this.N; n++) {
            for (int i = rowP.getInt(new int[]{n}); i < rowP.getInt(new int[]{n + 1}); i++) {
                int col = colP.getInt(new int[]{i});
                boolean present = false;
                for (int m = rowP.getInt(new int[]{col}); m < rowP.getInt(new int[]{col + 1}); m++) {
                    if (colP.getInt(new int[]{m}) == n) {
                        present = true;
                        // Write the summed pair only once (from the lower-indexed row).
                        if (n <= col) {
                            double pairSum = valP.getDouble(i) + valP.getDouble(m);
                            int slotN = symRowP.getInt(new int[]{n}) + offset.getInt(new int[]{n});
                            int slotCol = symRowP.getInt(new int[]{col}) + offset.getInt(new int[]{col});
                            symColP.putScalar(slotN, col);
                            symColP.putScalar(slotCol, n);
                            symValP.putScalar(slotN, pairSum);
                            symValP.putScalar(slotCol, pairSum);
                        }
                    }
                }
                if (!present) {
                    // Transposed entry absent: mirror this one into row `col`.
                    int slotN = symRowP.getInt(new int[]{n}) + offset.getInt(new int[]{n});
                    int slotCol = symRowP.getInt(new int[]{col}) + offset.getInt(new int[]{col});
                    symColP.putScalar(slotN, col);
                    symColP.putScalar(slotCol, n);
                    symValP.putScalar(slotN, valP.getDouble(i));
                    symValP.putScalar(slotCol, valP.getDouble(i));
                }
                if (!present || n <= col) {
                    offset.putScalar(n, offset.getInt(new int[]{n}) + 1);
                    if (col != n) {
                        offset.putScalar(col, offset.getInt(new int[]{col}) + 1);
                    }
                }
            }
        }
        symValP.divi(Double.valueOf(2.0d));
        return new SymResult(symRowP, symColP, symValP);
    }

    /**
     * Evaluates one row of the Gaussian kernel and its entropy.
     *
     * <p>Entry {@code m < k} of the kernel is {@code exp(-beta * distances[m + 1])}; element 0 of
     * {@code distances} is skipped. With {@code sumP} = kernel sum + {@link Double#MIN_VALUE}, the
     * entropy is {@code sum(beta * distances[m + 1] * p[m]) / sumP + log(sumP)}.
     *
     * @return pair of (unnormalised kernel row, entropy)
     */
    public Pair<INDArray, Double> computeGaussianKernel(INDArray distances, double beta, int k) {
        INDArray kernel = Nd4j.create(distances.dataType(), new long[]{k});
        double weighted = 0.0d;
        for (int m = 0; m < k; m++) {
            double distance = distances.getDouble(m + 1);
            kernel.putScalar(m, Math.exp(-beta * distance));
            // Read back so the term uses the value as stored in the kernel's data type.
            weighted += beta * distance * kernel.getDouble(m);
        }
        double kernelSum = kernel.sumNumber().doubleValue() + Double.MIN_VALUE;
        double entropy = (weighted / kernelSum) + Math.log(kernelSum);
        return new Pair<>(kernel, Double.valueOf(entropy));
    }

    /** No-op: this model needs no explicit initialisation. */
    public void init() {
    }

    /** No-op; the given listeners are discarded. */
    public void setListeners(Collection<TrainingListener> listeners) {
    }

    /** No-op; the given listeners are discarded. */
    public void setListeners(TrainingListener... listeners) {
    }

    /**
     * Counts the entries a symmetrised copy of the current {@code rows}/{@code cols} matrix would
     * hold. Each stored (row, col) entry adds one slot to its row. It also adds one slot to row
     * {@code col} when the transposed entry is absent. Not called from the visible part of this
     * class.
     */
    private int calculateOutputLength() {
        INDArray perRow = Nd4j.create(this.N);
        for (int row = 0; row < this.N; row++) {
            int end = this.rows.getInt(new int[]{row + 1});
            for (int entry = this.rows.getInt(new int[]{row}); entry < end; entry++) {
                int col = this.cols.getInt(new int[]{entry});
                boolean mirrored = false;
                for (int m = this.rows.getInt(new int[]{col}); m < this.rows.getInt(new int[]{col + 1}); m++) {
                    if (this.cols.getInt(new int[]{m}) == row) {
                        mirrored = true;
                        break;
                    }
                }
                perRow.putScalar(row, perRow.getDouble(row) + 1.0d);
                if (!mirrored) {
                    perRow.putScalar(col, perRow.getDouble(col) + 1.0d);
                }
            }
        }
        INDArray total = perRow.sum(new int[]{Integer.MAX_VALUE});
        return total.getInt(new int[]{0});
    }

    /** Centres {@code matrix} in place by subtracting its column means from every row. */
    public static void zeroMean(INDArray matrix) {
        INDArray columnMeans = matrix.mean(new int[]{0});
        matrix.subiRowVector(columnMeans);
    }

    /**
     * Runs t-SNE on the current input {@code x} and stores the embedding in {@code Y}.
     *
     * <p>With {@code theta == 0} the exact {@link Tsne} implementation is used. Otherwise:
     * <ol>
     *   <li>{@code Y} is initialised if null.</li>
     *   <li>The input is scaled by its maximum into a new array, so the caller's array is left
     *       untouched. Scaling is skipped when the maximum is zero or non-finite.</li>
     *   <li>The sparse affinities P are computed, symmetrised and normalised.</li>
     *   <li>{@code maxIter} gradient steps run. P is multiplied by 12 until
     *       {@code stopLyingIteration} (t-SNE "early exaggeration"), and momentum switches to
     *       {@code finalMomentum} at {@code switchMomentumIteration}.</li>
     * </ol>
     */
    public void fit() {
        if (this.theta == 0.0d) {
            log.debug("theta == 0, using decomposed version, might be slow");
            this.Y = new Tsne(this.maxIter, this.realMin, this.initialMomentum, this.finalMomentum, this.minGain, this.momentum, this.switchMomentumIteration, this.normalize, this.usePca, this.stopLyingIteration, this.tolerance, this.learningRate, this.useAdaGrad, this.perplexity).calculate(this.x, this.numDimensions, this.perplexity);
            return;
        }
        if (this.Y == null) {
            this.Y = this.initializer.initData();
        }
        // Previously x.divi(...) rescaled the caller's array in place; dividing into a copy keeps
        // input() returning the scaled data without that side effect. A zero maximum would
        // otherwise fill the data with NaN/Infinity.
        double maxValue = this.x.maxNumber().doubleValue();
        if (maxValue != 0.0d && Double.isFinite(maxValue)) {
            this.x = this.x.div(Double.valueOf(maxValue));
        }
        computeGaussianPerplexity(this.x, this.perplexity);
        SymResult sym = symmetrized(this.rows, this.cols, this.vals);
        this.vals = sym.vals.divi(Double.valueOf(sym.vals.sumNumber().doubleValue()));
        this.rows = sym.rows;
        this.cols = sym.cols;
        final int exaggeration = 12;
        this.vals.muli(exaggeration);
        for (int i = 0; i < this.maxIter; i++) {
            step(this.vals, i);
            zeroMean(this.Y);
            if (i == this.switchMomentumIteration) {
                this.momentum = this.finalMomentum;
            }
            if (i == this.stopLyingIteration) {
                this.vals.divi(exaggeration);
            }
            if (this.trainingListener != null) {
                this.trainingListener.iterationDone(this, i, 0);
            }
        }
    }

    /** No-op: embedding updates are applied through {@link #update(INDArray, String)}. */
    public void update(Gradient gradient) {
    }

    /**
     * Performs one optimisation step: computes the gradient of the current embedding and applies
     * it via {@link #update(INDArray, String)}. Both parameters are ignored.
     */
    public void step(INDArray p, int iteration) {
        Gradient current = gradient();
        INDArray yGradient = current.getGradientFor(Y_GRAD);
        update(yGradient, Y_GRAD);
    }

    /**
     * Sign helper for the gain update: -1 for negative values and 0 for (+/-)0. Everything else,
     * including NaN, maps to 1.
     */
    static double sign_tsne(double value) {
        if (value < 0.0d) {
            return -1.0d;
        }
        return value == 0.0d ? 0.0d : 1.0d;
    }

    /**
     * Applies one gradient-descent update to the embedding {@code Y}.
     *
     * <ul>
     *   <li>Per element, the gain is multiplied by 0.8 when the gradient and the previous
     *       increment have the same sign. Otherwise it is increased by 0.2. Gains are then clamped
     *       from below at {@code minGain}.</li>
     *   <li>{@code Y} is advanced by the previous increment {@code yIncs}.</li>
     *   <li>The new increment is {@code momentum * yIncs - step}. {@code step} is
     *       {@code gains * gradient} passed through AdaGrad when {@code useAdaGrad} is set;
     *       otherwise it is multiplied by {@code learningRate}.</li>
     * </ul>
     *
     * @param gradient  gradient w.r.t. {@code Y}; assumed to have Y's shape
     * @param paramType ignored
     */
    public void update(INDArray gradient, String paramType) {
        if (this.gains == null) {
            this.gains = this.Y.ulike().assign(Double.valueOf(1.0d));
        }
        // Callers that did not go through gradient() would otherwise hit an NPE on yIncs below;
        // initialise it the same way gradient() does.
        if (this.yIncs == null) {
            this.yIncs = this.Y.like();
        }
        for (int i = 0; i < gradient.rows(); i++) {
            for (int j = 0; j < gradient.columns(); j++) {
                if (sign_tsne(gradient.getDouble(i, j)) == sign_tsne(this.yIncs.getDouble(i, j))) {
                    this.gains.putScalar(new int[]{i, j}, this.gains.getDouble(i, j) * 0.8d);
                } else {
                    this.gains.putScalar(new int[]{i, j}, this.gains.getDouble(i, j) + 0.2d);
                }
            }
        }
        BooleanIndexing.replaceWhere(this.gains, Double.valueOf(this.minGain), Conditions.lessThan(Double.valueOf(this.minGain)));
        this.Y.addi(this.yIncs);
        INDArray step = this.gains.mul(gradient);
        if (this.useAdaGrad) {
            if (this.adaGrad == null) {
                this.adaGrad = new AdaGrad(gradient.shape(), this.learningRate);
                this.adaGrad.setStateViewArray(Nd4j.zeros(gradient.shape()).reshape(1L, step.length()), step.shape(), gradient.ordering(), true);
            }
            step = this.adaGrad.getGradient(step, 0);
        } else {
            step.muli(Double.valueOf(this.learningRate));
        }
        this.yIncs.muli(Double.valueOf(this.momentum)).subi(step);
    }

    /**
     * Writes the labelled embedding to {@code path} as comma-separated lines.
     *
     * <p>Each line is the coordinates of {@code Y.getRow(i)} followed by {@code labels.get(i)}.
     * Rows with a {@code null} label are skipped, and only the first
     * {@code min(Y.rows(), labels.size())} rows are considered. The file uses the platform default
     * charset (via {@link FileWriter}). The writer is always closed, including when flushing fails.
     *
     * @throws IOException if the file cannot be written
     */
    public void saveAsFile(List<String> labels, String path) throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(path)))) {
            for (int i = 0; i < this.Y.rows() && i < labels.size(); i++) {
                String label = labels.get(i);
                if (label == null) {
                    continue;
                }
                StringBuilder line = appendEmbeddingRow(new StringBuilder(), this.Y.getRow(i));
                line.append(",").append(label).append("\n");
                writer.write(line.toString());
            }
            writer.flush();
        }
    }

    /**
     * Writes every row of the embedding to {@code path} as comma-separated coordinates, one row per
     * line. The file uses the platform default charset. The writer is always closed.
     *
     * @throws IOException if the file cannot be written
     */
    public void saveAsFile(String path) throws IOException {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(path)))) {
            for (int i = 0; i < this.Y.rows(); i++) {
                StringBuilder line = appendEmbeddingRow(new StringBuilder(), this.Y.getRow(i));
                line.append("\n");
                writer.write(line.toString());
            }
            writer.flush();
        }
    }

    /** Appends the elements of {@code row} to {@code sb}, comma separated, and returns {@code sb}. */
    private static StringBuilder appendEmbeddingRow(StringBuilder sb, INDArray row) {
        for (int j = 0; j < row.length(); j++) {
            if (j > 0) {
                sb.append(",");
            }
            sb.append(row.getDouble(j));
        }
        return sb;
    }

    /**
     * Fits the model on {@code data} with {@code nDims} output dimensions, then writes the
     * labelled embedding to {@code path}.
     *
     * @deprecated equivalent to {@code setNumDimensions(nDims)}, {@code fit(data)} and then
     *             {@code saveAsFile(labels, path)}.
     */
    @Deprecated
    public void plot(INDArray data, int nDims, List<String> labels, String path) throws IOException {
        fit(data, nDims);
        saveAsFile(labels, path);
    }

    /**
     * Estimates the t-SNE cost, the KL divergence {@code sum p_ij * log(p_ij / q_ij)}, over the
     * stored (sparse) entries of P.
     *
     * <p>The normaliser {@code sumQ} is accumulated by the Barnes-Hut tree's non-edge force
     * computation. Then {@code q_ij = (1 / (1 + ||y_i - y_j||^2)) / sumQ}. Both probabilities get
     * {@code Nd4j.EPS_THRESHOLD} added inside the log. The previous implementation computed
     * {@code p * log(p + eps) / (q + eps)}, which is not the KL divergence.
     *
     * <p>Requires {@code tree} to have been built already; it is assigned in {@link #gradient()}.
     */
    public double score() {
        INDArray buffer = Nd4j.create(this.numDimensions);
        AtomicDouble sumQ = new AtomicDouble(0.0d);
        for (int n = 0; n < this.N; n++) {
            this.tree.computeNonEdgeForces(n, this.theta, buffer, sumQ);
        }
        double cost = 0.0d;
        INDArray embedding = this.Y;
        for (int n = 0; n < this.N; n++) {
            int begin = this.rows.getInt(new int[]{n});
            int end = this.rows.getInt(new int[]{n + 1});
            for (int i = begin; i < end; i++) {
                // buffer <- y_n - y_col
                embedding.slice(n).subi(embedding.slice(this.cols.getInt(new int[]{i})), buffer);
                double q = (1.0d / (1.0d + Transforms.pow(buffer, 2).sumNumber().doubleValue())) / sumQ.doubleValue();
                double p = this.vals.getDouble(i);
                cost += p * Math.log((p + Nd4j.EPS_THRESHOLD) / (q + Nd4j.EPS_THRESHOLD));
            }
        }
        return cost;
    }

    /**
     * Does nothing in this implementation; see {@link #gradient()} and {@link #score()}.
     *
     * @param layerWorkspaceMgr ignored
     */
    public void computeGradientAndScore(LayerWorkspaceMgr layerWorkspaceMgr) {
        // no-op
    }

    /** @return always {@code null}; this implementation exposes no parameter array */
    public INDArray params() {
        return null;
    }

    /** @return always {@code 0} */
    public long numParams() {
        return 0;
    }

    /**
     * @param z ignored
     * @return always {@code 0}
     */
    public long numParams(boolean z) {
        return 0;
    }

    /**
     * Does nothing in this implementation.
     *
     * @param iNDArray ignored
     */
    public void setParams(INDArray iNDArray) {
        // no-op
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public void setParamsViewArray(INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public INDArray getGradientsViewArray() {
        throw new UnsupportedOperationException();
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }

    /**
     * Stores {@code iNDArray} as the input data and runs {@code fit()}.
     *
     * @param iNDArray the input data
     */
    public void fit(INDArray iNDArray) {
        x = iNDArray;
        this.fit();
    }

    /**
     * Delegates to {@link #fit(INDArray)}; the workspace manager is not used.
     *
     * @param iNDArray          the input data
     * @param layerWorkspaceMgr ignored
     */
    public void fit(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr) {
        this.fit(iNDArray);
    }

    /**
     * Stores {@code iNDArray} as the input data, sets {@code numDimensions} to {@code i},
     * and runs {@code fit()}.
     *
     * @param iNDArray the input data
     * @param i        the value assigned to {@code numDimensions}
     * @deprecated use {@link #fit(INDArray)}
     */
    @Deprecated
    public void fit(INDArray iNDArray, int i) {
        x = iNDArray;
        numDimensions = i;
        this.fit();
    }

    /**
     * Computes the Barnes-Hut approximation of the gradient for the current embedding {@code Y}.
     *
     * <p>Lazily allocates {@code yIncs} (via {@code Y.like()}) and {@code gains} (filled with 1.0)
     * if they are still {@code null}. It then rebuilds {@code tree} from {@code Y}. Edge forces are
     * accumulated from {@code rows}/{@code cols}/{@code vals}, and per-row non-edge forces
     * accumulate the normalization sum. The returned gradient, stored under {@code Y_GRAD}, is
     * {@code edgeForces - nonEdgeForces / sum}. Both force buffers are modified in place.
     *
     * @return the gradient keyed by {@code Y_GRAD}
     */
    public Gradient gradient() {
        if (yIncs == null) {
            yIncs = Y.like();
        }
        if (gains == null) {
            gains = Y.ulike().assign(Double.valueOf(1.0d));
        }

        final AtomicDouble sumQ = new AtomicDouble(0.0d);
        final INDArray edgeForces = Y.like();
        final INDArray nonEdgeForces = Y.like();

        // Rebuild the space-partitioning tree for the current embedding (score() reads it too).
        tree = new SpTree(Y);
        tree.computeEdgeForces(rows, cols, vals, N, edgeForces);
        for (int row = 0; row < N; row++) {
            final INDArray rowForces = nonEdgeForces.slice(row);
            tree.computeNonEdgeForces(row, theta, rowForces, sumQ);
        }

        // Both operations are in place: nonEdgeForces /= sumQ, then edgeForces -= nonEdgeForces.
        final INDArray dC = edgeForces.subi(nonEdgeForces.divi(sumQ));

        final Gradient result = new DefaultGradient();
        result.gradientForVariable().put(Y_GRAD, dC);
        return result;
    }

    /**
     * Computes the gradient and then the score of the current embedding.
     *
     * @return a pair of ({@link #gradient()}, {@link #score()})
     */
    public Pair<Gradient, Double> gradientAndScore() {
        // gradient() must run first: it rebuilds the tree that score() reads.
        final Gradient grad = gradient();
        final double currentScore = score();
        return new Pair<>(grad, Double.valueOf(currentScore));
    }

    /** @return always {@code 0} */
    public int batchSize() {
        final int none = 0;
        return none;
    }

    /** @return always {@code null}; no {@link NeuralNetConfiguration} is kept */
    public NeuralNetConfiguration conf() {
        return null;
    }

    /**
     * Does nothing in this implementation.
     *
     * @param neuralNetConfiguration ignored
     */
    public void setConf(NeuralNetConfiguration neuralNetConfiguration) {
        // no-op
    }

    /** @return the current embedding {@code Y} */
    public INDArray getData() {
        return Y;
    }

    /**
     * Replaces the current embedding {@code Y}.
     *
     * @param iNDArray the new embedding
     */
    public void setData(INDArray iNDArray) {
        Y = iNDArray;
    }

    /**
     * Sets {@code N}, the row count iterated over by {@link #gradient()} and {@link #score()}.
     *
     * @param i the new value of {@code N}
     */
    public void setN(int i) {
        N = i;
    }

    /** Does nothing in this implementation. */
    public void close() {
        // no-op
    }

    /** @return the {@code maxIter} setting */
    public int getMaxIter() {
        return maxIter;
    }

    /** @return the {@code realMin} setting */
    public double getRealMin() {
        return realMin;
    }

    /** @return the {@code initialMomentum} setting */
    public double getInitialMomentum() {
        return initialMomentum;
    }

    /** @return the {@code finalMomentum} setting */
    public double getFinalMomentum() {
        return finalMomentum;
    }

    /** @return the {@code minGain} setting */
    public double getMinGain() {
        return minGain;
    }

    /** @return the {@code momentum} setting */
    public double getMomentum() {
        return momentum;
    }

    /** @return the {@code switchMomentumIteration} setting */
    public int getSwitchMomentumIteration() {
        return switchMomentumIteration;
    }

    /** @return the {@code normalize} flag */
    public boolean isNormalize() {
        return normalize;
    }

    /** @return the {@code usePca} flag */
    public boolean isUsePca() {
        return usePca;
    }

    /** @return the {@code stopLyingIteration} setting */
    public int getStopLyingIteration() {
        return stopLyingIteration;
    }

    /** @return the {@code tolerance} setting */
    public double getTolerance() {
        return tolerance;
    }

    /** @return the {@code learningRate} setting */
    public double getLearningRate() {
        return learningRate;
    }

    /** @return the configured {@link AdaGrad} instance, possibly {@code null} */
    public AdaGrad getAdaGrad() {
        return adaGrad;
    }

    /** @return the {@code useAdaGrad} flag */
    public boolean isUseAdaGrad() {
        return useAdaGrad;
    }

    /** @return the current embedding {@code Y} */
    public INDArray getY() {
        return Y;
    }

    /** @return the {@code N} value */
    public int getN() {
        return N;
    }

    /** @return the {@code rows} array */
    public INDArray getRows() {
        return rows;
    }

    /** @return the {@code cols} array */
    public INDArray getCols() {
        return cols;
    }

    /** @return the {@code vals} array */
    public INDArray getVals() {
        return vals;
    }

    /** @return the input data {@code x} */
    public INDArray getX() {
        return x;
    }

    /** @return the current {@link SpTree}, possibly {@code null} */
    public SpTree getTree() {
        return tree;
    }

    /** @return the {@code gains} array, possibly {@code null} */
    public INDArray getGains() {
        return gains;
    }

    /** @return the {@code yIncs} array, possibly {@code null} */
    public INDArray getYIncs() {
        return yIncs;
    }

    /** @return the {@code vpTreeWorkers} setting */
    public int getVpTreeWorkers() {
        return vpTreeWorkers;
    }

    /** @return the configured {@link TrainingListener}, possibly {@code null} */
    public TrainingListener getTrainingListener() {
        return trainingListener;
    }

    /** @return the configured {@link WorkspaceMode} */
    public WorkspaceMode getWorkspaceMode() {
        return workspaceMode;
    }

    /** @return the configured {@code Initializer} */
    public Initializer getInitializer() {
        return initializer;
    }

    /** @return the {@code workspaceConfigurationFeedForward} setting */
    public WorkspaceConfiguration getWorkspaceConfigurationFeedForward() {
        return workspaceConfigurationFeedForward;
    }

    /** Sets {@code maxIter}. */
    public void setMaxIter(int i) {
        maxIter = i;
    }

    /** Sets {@code realMin}. */
    public void setRealMin(double d) {
        realMin = d;
    }

    /** Sets {@code initialMomentum}. */
    public void setInitialMomentum(double d) {
        initialMomentum = d;
    }

    /** Sets {@code finalMomentum}. */
    public void setFinalMomentum(double d) {
        finalMomentum = d;
    }

    /** Sets {@code minGain}. */
    public void setMinGain(double d) {
        minGain = d;
    }

    /** Sets {@code momentum}. */
    public void setMomentum(double d) {
        momentum = d;
    }

    /** Sets {@code switchMomentumIteration}. */
    public void setSwitchMomentumIteration(int i) {
        switchMomentumIteration = i;
    }

    /** Sets the {@code normalize} flag. */
    public void setNormalize(boolean z) {
        normalize = z;
    }

    /** Sets the {@code usePca} flag. */
    public void setUsePca(boolean z) {
        usePca = z;
    }

    /** Sets {@code stopLyingIteration}. */
    public void setStopLyingIteration(int i) {
        stopLyingIteration = i;
    }

    /** Sets {@code tolerance}. */
    public void setTolerance(double d) {
        tolerance = d;
    }

    /** Sets {@code learningRate}. */
    public void setLearningRate(double d) {
        learningRate = d;
    }

    /** Sets the {@link AdaGrad} instance. */
    public void setAdaGrad(AdaGrad adaGrad) {
        // parameter shadows the field, so the qualifier is required
        this.adaGrad = adaGrad;
    }

    /** Sets the {@code useAdaGrad} flag. */
    public void setUseAdaGrad(boolean z) {
        useAdaGrad = z;
    }

    /** Sets {@code perplexity}. */
    public void setPerplexity(double d) {
        perplexity = d;
    }

    /** Replaces the current embedding {@code Y}. */
    public void setY(INDArray iNDArray) {
        Y = iNDArray;
    }

    /** Sets {@code theta}. */
    public void setTheta(double d) {
        theta = d;
    }

    /** Sets the {@code rows} array. */
    public void setRows(INDArray iNDArray) {
        rows = iNDArray;
    }

    /** Sets the {@code cols} array. */
    public void setCols(INDArray iNDArray) {
        cols = iNDArray;
    }

    /** Sets the {@code vals} array. */
    public void setVals(INDArray iNDArray) {
        vals = iNDArray;
    }

    /** Sets the input data {@code x}. */
    public void setX(INDArray iNDArray) {
        x = iNDArray;
    }

    /** Replaces the current {@link SpTree}. */
    public void setTree(SpTree spTree) {
        tree = spTree;
    }

    /** Sets the {@code gains} array. */
    public void setGains(INDArray iNDArray) {
        gains = iNDArray;
    }

    /** Sets {@code vpTreeWorkers}. */
    public void setVpTreeWorkers(int i) {
        vpTreeWorkers = i;
    }

    /** Sets the {@link TrainingListener}. */
    public void setTrainingListener(TrainingListener trainingListener) {
        // parameter shadows the field, so the qualifier is required
        this.trainingListener = trainingListener;
    }

    /** Sets the {@link WorkspaceMode}. */
    public void setWorkspaceMode(WorkspaceMode workspaceMode) {
        // parameter shadows the field, so the qualifier is required
        this.workspaceMode = workspaceMode;
    }

    /** Sets the {@code Initializer}. */
    public void setInitializer(Initializer initializer) {
        // parameter shadows the field, so the qualifier is required
        this.initializer = initializer;
    }

    /** Sets {@code workspaceConfigurationFeedForward}. */
    public void setWorkspaceConfigurationFeedForward(WorkspaceConfiguration workspaceConfiguration) {
        workspaceConfigurationFeedForward = workspaceConfiguration;
    }

    /**
     * Field-by-field equality over the configuration and state exposed by this class's getters
     * (every getter used in {@link #toString()} except {@code getTrainingListener()}).
     * Doubles are compared with {@link Double#compare}; reference fields are null-safe and use
     * their own {@code equals}. Symmetry with subclasses is enforced through {@link #canEqual}.
     *
     * @param obj the object to compare with
     * @return {@code true} if {@code obj} is an equal {@code BarnesHutTsne}
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof BarnesHutTsne)) {
            return false;
        }
        final BarnesHutTsne other = (BarnesHutTsne) obj;
        if (!other.canEqual(this)) {
            return false;
        }

        // Primitive settings; short-circuit order matches the original comparison order.
        final boolean primitivesDiffer = getMaxIter() != other.getMaxIter()
                || Double.compare(getRealMin(), other.getRealMin()) != 0
                || Double.compare(getInitialMomentum(), other.getInitialMomentum()) != 0
                || Double.compare(getFinalMomentum(), other.getFinalMomentum()) != 0
                || Double.compare(getMinGain(), other.getMinGain()) != 0
                || Double.compare(getMomentum(), other.getMomentum()) != 0
                || getSwitchMomentumIteration() != other.getSwitchMomentumIteration()
                || isNormalize() != other.isNormalize()
                || isUsePca() != other.isUsePca()
                || getStopLyingIteration() != other.getStopLyingIteration()
                || Double.compare(getTolerance(), other.getTolerance()) != 0
                || Double.compare(getLearningRate(), other.getLearningRate()) != 0;
        if (primitivesDiffer) {
            return false;
        }

        final AdaGrad myAdaGrad = getAdaGrad();
        final AdaGrad theirAdaGrad = other.getAdaGrad();
        if (myAdaGrad == null ? theirAdaGrad != null : !myAdaGrad.equals(theirAdaGrad)) {
            return false;
        }
        if (isUseAdaGrad() != other.isUseAdaGrad()
                || Double.compare(getPerplexity(), other.getPerplexity()) != 0) {
            return false;
        }

        final INDArray myY = getY();
        final INDArray theirY = other.getY();
        if (myY == null ? theirY != null : !myY.equals(theirY)) {
            return false;
        }
        if (getN() != other.getN() || Double.compare(getTheta(), other.getTheta()) != 0) {
            return false;
        }

        final INDArray myRows = getRows();
        final INDArray theirRows = other.getRows();
        if (myRows == null ? theirRows != null : !myRows.equals(theirRows)) {
            return false;
        }
        final INDArray myCols = getCols();
        final INDArray theirCols = other.getCols();
        if (myCols == null ? theirCols != null : !myCols.equals(theirCols)) {
            return false;
        }
        final INDArray myVals = getVals();
        final INDArray theirVals = other.getVals();
        if (myVals == null ? theirVals != null : !myVals.equals(theirVals)) {
            return false;
        }

        final String mySimilarity = getSimiarlityFunction();
        final String theirSimilarity = other.getSimiarlityFunction();
        if (mySimilarity == null ? theirSimilarity != null : !mySimilarity.equals(theirSimilarity)) {
            return false;
        }
        if (isInvert() != other.isInvert()) {
            return false;
        }

        final INDArray myX = getX();
        final INDArray theirX = other.getX();
        if (myX == null ? theirX != null : !myX.equals(theirX)) {
            return false;
        }
        if (getNumDimensions() != other.getNumDimensions()) {
            return false;
        }

        final SpTree myTree = getTree();
        final SpTree theirTree = other.getTree();
        if (myTree == null ? theirTree != null : !myTree.equals(theirTree)) {
            return false;
        }
        final INDArray myGains = getGains();
        final INDArray theirGains = other.getGains();
        if (myGains == null ? theirGains != null : !myGains.equals(theirGains)) {
            return false;
        }
        final INDArray myYIncs = getYIncs();
        final INDArray theirYIncs = other.getYIncs();
        if (myYIncs == null ? theirYIncs != null : !myYIncs.equals(theirYIncs)) {
            return false;
        }
        if (getVpTreeWorkers() != other.getVpTreeWorkers()) {
            return false;
        }

        final WorkspaceMode myMode = getWorkspaceMode();
        final WorkspaceMode theirMode = other.getWorkspaceMode();
        if (myMode == null ? theirMode != null : !myMode.equals(theirMode)) {
            return false;
        }
        final Initializer myInit = getInitializer();
        final Initializer theirInit = other.getInitializer();
        if (myInit == null ? theirInit != null : !myInit.equals(theirInit)) {
            return false;
        }

        final WorkspaceConfiguration myConf = getWorkspaceConfigurationFeedForward();
        final WorkspaceConfiguration theirConf = other.getWorkspaceConfigurationFeedForward();
        if (myConf == null) {
            return theirConf == null;
        }
        return myConf.equals(theirConf);
    }

    /**
     * Symmetry hook used by {@link #equals(Object)}: a subclass that adds state can override this
     * so that it refuses equality with plain {@code BarnesHutTsne} instances.
     *
     * @param obj the candidate object
     * @return {@code true} if {@code obj} is a {@code BarnesHutTsne}
     */
    protected boolean canEqual(Object obj) {
        final boolean compatible = obj instanceof BarnesHutTsne;
        return compatible;
    }

    /**
     * Hash code consistent with {@link #equals(Object)}, built over the same fields in the same order.
     * Each step multiplies the running value by 59 and adds a contribution:
     * the value itself for ints, {@link Double#hashCode(double)} for doubles,
     * 79/97 for {@code true}/{@code false}, and the object hash or 43 for {@code null}.
     *
     * @return the combined hash
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;
        int result = 1;
        result = result * prime + getMaxIter();
        result = result * prime + Double.hashCode(getRealMin());
        result = result * prime + Double.hashCode(getInitialMomentum());
        result = result * prime + Double.hashCode(getFinalMomentum());
        result = result * prime + Double.hashCode(getMinGain());
        result = result * prime + Double.hashCode(getMomentum());
        result = result * prime + getSwitchMomentumIteration();
        result = result * prime + (isNormalize() ? 79 : 97);
        result = result * prime + (isUsePca() ? 79 : 97);
        result = result * prime + getStopLyingIteration();
        result = result * prime + Double.hashCode(getTolerance());
        result = result * prime + Double.hashCode(getLearningRate());

        final AdaGrad ada = getAdaGrad();
        result = result * prime + (ada == null ? nullHash : ada.hashCode());
        result = result * prime + (isUseAdaGrad() ? 79 : 97);
        result = result * prime + Double.hashCode(getPerplexity());

        final INDArray embedding = getY();
        result = result * prime + (embedding == null ? nullHash : embedding.hashCode());
        result = result * prime + getN();
        result = result * prime + Double.hashCode(getTheta());

        final INDArray rowArr = getRows();
        result = result * prime + (rowArr == null ? nullHash : rowArr.hashCode());
        final INDArray colArr = getCols();
        result = result * prime + (colArr == null ? nullHash : colArr.hashCode());
        final INDArray valArr = getVals();
        result = result * prime + (valArr == null ? nullHash : valArr.hashCode());

        final String similarity = getSimiarlityFunction();
        result = result * prime + (similarity == null ? nullHash : similarity.hashCode());
        result = result * prime + (isInvert() ? 79 : 97);

        final INDArray input = getX();
        result = result * prime + (input == null ? nullHash : input.hashCode());
        result = result * prime + getNumDimensions();

        final SpTree spTree = getTree();
        result = result * prime + (spTree == null ? nullHash : spTree.hashCode());
        final INDArray gainArr = getGains();
        result = result * prime + (gainArr == null ? nullHash : gainArr.hashCode());
        final INDArray incArr = getYIncs();
        result = result * prime + (incArr == null ? nullHash : incArr.hashCode());
        result = result * prime + getVpTreeWorkers();

        final WorkspaceMode mode = getWorkspaceMode();
        result = result * prime + (mode == null ? nullHash : mode.hashCode());
        final Initializer init = getInitializer();
        result = result * prime + (init == null ? nullHash : init.hashCode());
        final WorkspaceConfiguration ffConf = getWorkspaceConfigurationFeedForward();
        result = result * prime + (ffConf == null ? nullHash : ffConf.hashCode());
        return result;
    }

    /**
     * Renders every getter-exposed field as {@code BarnesHutTsne(name=value, ...)}.
     * Note that the array fields are rendered through their own {@code toString()}.
     *
     * @return a human-readable description of this instance
     */
    @Override
    public String toString() {
        return new StringBuilder("BarnesHutTsne(")
                .append("maxIter=").append(getMaxIter())
                .append(", realMin=").append(getRealMin())
                .append(", initialMomentum=").append(getInitialMomentum())
                .append(", finalMomentum=").append(getFinalMomentum())
                .append(", minGain=").append(getMinGain())
                .append(", momentum=").append(getMomentum())
                .append(", switchMomentumIteration=").append(getSwitchMomentumIteration())
                .append(", normalize=").append(isNormalize())
                .append(", usePca=").append(isUsePca())
                .append(", stopLyingIteration=").append(getStopLyingIteration())
                .append(", tolerance=").append(getTolerance())
                .append(", learningRate=").append(getLearningRate())
                .append(", adaGrad=").append(getAdaGrad())
                .append(", useAdaGrad=").append(isUseAdaGrad())
                .append(", perplexity=").append(getPerplexity())
                .append(", Y=").append(getY())
                .append(", N=").append(getN())
                .append(", theta=").append(getTheta())
                .append(", rows=").append(getRows())
                .append(", cols=").append(getCols())
                .append(", vals=").append(getVals())
                .append(", simiarlityFunction=").append(getSimiarlityFunction())
                .append(", invert=").append(isInvert())
                .append(", x=").append(getX())
                .append(", numDimensions=").append(getNumDimensions())
                .append(", tree=").append(getTree())
                .append(", gains=").append(getGains())
                .append(", yIncs=").append(getYIncs())
                .append(", vpTreeWorkers=").append(getVpTreeWorkers())
                .append(", trainingListener=").append(getTrainingListener())
                .append(", workspaceMode=").append(getWorkspaceMode())
                .append(", initializer=").append(getInitializer())
                .append(", workspaceConfigurationFeedForward=").append(getWorkspaceConfigurationFeedForward())
                .append(")")
                .toString();
    }

    /**
     * Replaces the {@code yIncs} array. {@link #gradient()} allocates one itself when this is
     * still {@code null}.
     *
     * @param iNDArray the new {@code yIncs} array
     */
    public void setYIncs(INDArray iNDArray) {
        yIncs = iNDArray;
    }
}
