package de.lmu.ifi.dbs.elki.utilities.scaling.outlier;

import de.lmu.ifi.dbs.elki.database.ids.ArrayDBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DBIDArrayIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.relation.DoubleRelation;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.math.MeanVariance;
import de.lmu.ifi.dbs.elki.result.outlier.OutlierResult;
import de.lmu.ifi.dbs.elki.utilities.Alias;
import de.lmu.ifi.dbs.elki.utilities.datastructures.BitsUtil;
import de.lmu.ifi.dbs.elki.utilities.datastructures.arraylike.NumberArrayAdapter;
import de.lmu.ifi.dbs.elki.utilities.documentation.Reference;
import net.jafama.FastMath;

@Reference(authors = "J. Gao, P.-N. Tan", title = "Converting Output Scores from Outlier Detection Algorithms into Probability Estimates", booktitle = "Proc. Sixth International Conference on Data Mining, 2006. ICDM'06.", url = "https://doi.org/10.1109/ICDM.2006.43", bibkey = "DBLP:conf/icdm/GaoT06")
@Alias({"de.lmu.ifi.dbs.elki.utilities.scaling.outlier.SigmoidOutlierScalingFunction"})
/**
 * Outlier scaling that fits a sigmoid
 * {@code P(outlier | s) = 1 / (1 + exp(-(A * s + B)))} to the outlier scores, following Gao and
 * Tan.
 * <p>
 * The fit starts at {@code A = 1, B = -mean(scores)} and alternates two steps:
 * <ul>
 * <li>E-step: every score {@code s} is labelled an outlier iff {@code A * s + B > 0}, i.e. iff
 * the current sigmoid assigns it a probability above 0.5;</li>
 * <li>M-step: A and B are refitted to these labels by minimizing Platt's smoothed cross-entropy,
 * using the Newton method with backtracking line search described by Lin, Lin and Weng ("A note
 * on Platt's probabilistic outputs for support vector machines").</li>
 * </ul>
 * The iteration stops once the labelling no longer changes or an iteration limit is reached.
 * <p>
 * Non-finite scores (NaN, infinities) are ignored, both for the initial mean and for the fit.
 */
public class SigmoidOutlierScaling implements OutlierScaling {
    private static final Logging LOG = Logging.getLogger(SigmoidOutlierScaling.class);

    /** Maximum number of E/M rounds before giving up with a warning. */
    private static final int MAX_EM_ITERATIONS = 100;

    /** Maximum number of Newton iterations within one M-step. */
    private static final int MAX_NEWTON_ITERATIONS = 10;

    /** Smallest line search step size before the Newton iteration is abandoned. */
    private static final double MIN_STEP = 1e-8;

    /** Ridge added to the Hessian diagonal so that its determinant is strictly positive. */
    private static final double SIGMA = 1e-12;

    /** Gradient magnitude below which the Newton iteration counts as converged. */
    private static final double GRADIENT_EPSILON = 1e-5;

    /** Fitted slope A of the sigmoid (set by {@code prepare}). */
    double Afinal;

    /** Fitted offset B of the sigmoid (set by {@code prepare}). */
    double Bfinal;

    /**
     * Fits the sigmoid to the scores of an outlier result.
     *
     * @param or outlier result whose score relation is scaled
     */
    @Override
    public void prepare(OutlierResult or) {
        final DoubleRelation scores = or.getScores();
        final double[] values = new double[scores.getDBIDs().size()];
        final MeanVariance mv = new MeanVariance();
        int n = 0;
        for (DBIDIter iter = scores.iterDBIDs(); iter.valid(); iter.advance()) {
            final double val = scores.doubleValue(iter);
            if (Double.isFinite(val)) {
                mv.put(val);
                values[n++] = val;
            }
        }
        fit(values, n, mv.getMean());
    }

    /**
     * Fits the sigmoid to an array of scores accessed through an adapter.
     *
     * @param array score array
     * @param adapter adapter to read doubles from {@code array}
     * @param <A> array type
     */
    @Override
    public <A> void prepare(A array, NumberArrayAdapter<?, A> adapter) {
        final int size = adapter.size(array);
        final double[] values = new double[size];
        final MeanVariance mv = new MeanVariance();
        int n = 0;
        for (int i = 0; i < size; i++) {
            final double val = adapter.getDouble(array, i);
            if (Double.isFinite(val)) {
                mv.put(val);
                values[n++] = val;
            }
        }
        fit(values, n, mv.getMean());
    }

    /**
     * Runs the alternating E/M fit and stores the result in {@link #Afinal} and {@link #Bfinal}.
     *
     * @param values finite scores; only the first {@code n} entries are used
     * @param n number of valid entries in {@code values}
     * @param mean mean of the valid entries, used for the initial offset
     */
    private void fit(double[] values, int n, double mean) {
        double a = 1.0;
        double b = -mean;
        final long[] outlier = BitsUtil.zero(n);
        int iterations = 0;
        while (relabel(a, b, values, n, outlier)) {
            final double[] ab = mStepLevenbergMarquardt(a, b, values, n, outlier);
            a = ab[0];
            b = ab[1];
            if (++iterations > MAX_EM_ITERATIONS) {
                LOG.warning("Max iterations met in sigmoid fitting.");
                break;
            }
        }
        this.Afinal = a;
        this.Bfinal = b;
        LOG.debugFine("A = " + this.Afinal + " B = " + this.Bfinal);
    }

    /**
     * E-step: labels each value as outlier ({@code a * x + b > 0}) or inlier.
     *
     * @param a current slope
     * @param b current offset
     * @param values scores; only the first {@code n} entries are used
     * @param n number of valid entries in {@code values}
     * @param outlier bitset of labels, updated in place (bit set = outlier)
     * @return {@code true} if at least one label changed
     */
    private static boolean relabel(double a, double b, double[] values, int n, long[] outlier) {
        boolean changed = false;
        for (int i = 0; i < n; i++) {
            final boolean isOutlier = a * values[i] + b > 0.0;
            if (isOutlier != BitsUtil.get(outlier, i)) {
                if (isOutlier) {
                    BitsUtil.setI(outlier, i);
                } else {
                    BitsUtil.clearI(outlier, i);
                }
                changed = true;
            }
        }
        return changed;
    }

    /**
     * M-step: refits the sigmoid parameters to a fixed labelling.
     * <p>
     * Minimizes {@code f(A, B) = sum_i t_i * z_i + log(1 + exp(-z_i))} with
     * {@code z_i = A * x_i + B}. This is the cross-entropy between the sigmoid output and the
     * smoothed target {@code 1 - t_i}. The minimization uses Newton steps with backtracking line
     * search (Lin, Lin and Weng).
     *
     * @param a current slope A
     * @param b current offset B
     * @param values scores; only the first {@code n} entries are used
     * @param n number of valid entries in {@code values}
     * @param outlier bitset of the current labelling (bit set = outlier)
     * @return the refined parameters {@code {A, B}}
     */
    private static double[] mStepLevenbergMarquardt(double a, double b, double[] values, int n, long[] outlier) {
        final int numOutliers = BitsUtil.cardinality(outlier);
        final int numInliers = n - numOutliers;
        // Platt's smoothed targets. As in Lin et al., the objective is written in terms of
        // p = 1 / (1 + exp(z)) = 1 - P(outlier), so these are targets for "being an inlier".
        // Each class is smoothed with its OWN size: (N + 1) / (N + 2) for the N inliers and
        // 1 / (M + 2) for the M outliers.
        final double targetOutlier = 1.0 / (numOutliers + 2.0);
        final double targetInlier = (numInliers + 1.0) / (numInliers + 2.0);

        double fval = objective(a, b, values, n, outlier, targetOutlier, targetInlier);
        for (int it = 0; it < MAX_NEWTON_ITERATIONS; it++) {
            // Gradient (g1, g2) and Hessian [[h11, h21], [h21, h22]] of the objective.
            double h11 = SIGMA;
            double h22 = SIGMA;
            double h21 = 0.0;
            double g1 = 0.0;
            double g2 = 0.0;
            for (int i = 0; i < n; i++) {
                final double x = values[i];
                final double z = x * a + b;
                // p = 1 / (1 + exp(z)) and q = 1 - p, arranged so exp never overflows.
                final double p;
                final double q;
                if (z >= 0.0) {
                    final double e = FastMath.exp(-z);
                    p = e / (1.0 + e);
                    q = 1.0 / (1.0 + e);
                } else {
                    final double e = FastMath.exp(z);
                    p = 1.0 / (1.0 + e);
                    q = e / (1.0 + e);
                }
                final double pq = p * q;
                h11 += x * x * pq;
                h22 += pq;
                h21 += x * pq;
                final double d = (BitsUtil.get(outlier, i) ? targetOutlier : targetInlier) - p;
                g1 += x * d;
                g2 += d;
            }
            if (Math.abs(g1) < GRADIENT_EPSILON && Math.abs(g2) < GRADIENT_EPSILON) {
                break; // Converged.
            }
            // Newton direction: -H^{-1} g.
            final double det = h11 * h22 - h21 * h21;
            final double dA = -(h22 * g1 - h21 * g2) / det;
            final double dB = -(-h21 * g1 + h11 * g2) / det;
            final double gd = g1 * dA + g2 * dB;

            // Backtracking line search with a sufficient-decrease (Armijo) condition.
            double stepsize = 1.0;
            while (stepsize >= MIN_STEP) {
                final double newA = a + stepsize * dA;
                final double newB = b + stepsize * dB;
                final double newf = objective(newA, newB, values, n, outlier, targetOutlier, targetInlier);
                if (newf < fval + 1e-4 * stepsize * gd) {
                    a = newA;
                    b = newB;
                    fval = newf;
                    break;
                }
                stepsize /= 2.0;
            }
            if (stepsize < MIN_STEP) {
                // A, B and fval are unchanged, so another iteration would repeat the same
                // failing search; stop as in Lin et al.
                LOG.debug("Minstep hit.");
                break;
            }
        }
        return new double[] { a, b };
    }

    /**
     * Evaluates {@code sum_i t_i * z_i + log(1 + exp(-z_i))} with {@code z_i = a * x_i + b}.
     * The two branches compute the same expression. They are arranged so that {@code exp} is
     * only evaluated on non-positive arguments.
     *
     * @param a slope
     * @param b offset
     * @param values scores; only the first {@code n} entries are used
     * @param n number of valid entries in {@code values}
     * @param outlier bitset of labels (bit set = outlier)
     * @param targetOutlier target {@code t_i} for labelled outliers
     * @param targetInlier target {@code t_i} for labelled inliers
     * @return objective value
     */
    private static double objective(double a, double b, double[] values, int n, long[] outlier, double targetOutlier, double targetInlier) {
        double f = 0.0;
        for (int i = 0; i < n; i++) {
            final double z = values[i] * a + b;
            final double t = BitsUtil.get(outlier, i) ? targetOutlier : targetInlier;
            if (z >= 0.0) {
                f = f + t * z + FastMath.log(1.0 + FastMath.exp(-z));
            } else {
                f = f + (t - 1.0) * z + FastMath.log(1.0 + FastMath.exp(z));
            }
        }
        return f;
    }

    /** @return upper bound of the scaled values, {@code 1} */
    @Override
    public double getMax() {
        return 1.0d;
    }

    /** @return lower bound of the scaled values, {@code 0} */
    @Override
    public double getMin() {
        return 0.0d;
    }

    /**
     * Maps a raw score to {@code 1 / (1 + exp(-(A * d + B)))} using the fitted parameters.
     *
     * @param d raw outlier score
     * @return scaled score in {@code [0, 1]}
     */
    @Override
    public double getScaled(double d) {
        return 1.0d / (1.0d + FastMath.exp((-this.Afinal * d) - this.Bfinal));
    }
}
