package de.lmu.ifi.dbs.elki.algorithm.clustering.em;

import de.lmu.ifi.dbs.elki.algorithm.AbstractAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.clustering.ClusteringAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeans;
import de.lmu.ifi.dbs.elki.data.Cluster;
import de.lmu.ifi.dbs.elki.data.Clustering;
import de.lmu.ifi.dbs.elki.data.NumberVector;
import de.lmu.ifi.dbs.elki.data.model.MeanModel;
import de.lmu.ifi.dbs.elki.data.type.SimpleTypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.datastore.DataStoreUtil;
import de.lmu.ifi.dbs.elki.database.datastore.WritableDataStore;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.relation.MaterializedRelation;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.SquaredEuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.utilities.Alias;
import de.lmu.ifi.dbs.elki.utilities.documentation.Description;
import de.lmu.ifi.dbs.elki.utilities.documentation.Reference;
import de.lmu.ifi.dbs.elki.utilities.documentation.Title;
import de.lmu.ifi.dbs.elki.utilities.exceptions.ExceptionMessages;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Parameter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Clustering by expectation maximization (EM) over a set of cluster models
 * created by an {@link EMClusterModelFactory}.
 *
 * <p>After initialization, each iteration re-estimates the models from the
 * current membership probabilities ({@link #recomputeCovarianceMatrices}) and
 * then recomputes the membership probabilities and the average log-likelihood
 * ({@link #assignProbabilitiesToInstances}). Iteration stops when the
 * log-likelihood changes by at most {@code delta}, decreases, or the iteration
 * limit is reached. Every object is finally assigned to the model with the
 * highest membership probability.
 *
 * @param <V> vector type of the input relation
 * @param <M> model type produced by the cluster models
 */
@Reference(authors = "A. P. Dempster, N. M. Laird, D. B. Rubin", title = "Maximum Likelihood from Incomplete Data via the EM algorithm", booktitle = "Journal of the Royal Statistical Society, Series B, 39(1), 1977, pp. 1-31", url = "http://www.jstor.org/stable/2984875")
@Alias({"de.lmu.ifi.dbs.elki.algorithm.clustering.EM", "EM"})
@Description("Cluster data via Gaussian mixture modeling and the EM algorithm")
@Title("EM-Clustering: Clustering by Expectation Maximization")
public class EM<V extends NumberVector, M extends MeanModel> extends AbstractAlgorithm<Clustering<M>> implements ClusteringAlgorithm<Clustering<M>> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(EM.class);

    /**
     * Lower bound applied to each object's log-likelihood, so that an object
     * whose total density is zero (log 0 = -infinity) cannot drive the
     * average log-likelihood to -infinity.
     */
    private static final double MIN_LOGLIKELIHOOD = -100000.0d;

    /**
     * Hint flags passed to {@code DataStoreUtil.makeStorage} for the membership
     * probability store. NOTE(review): 10 presumably equals
     * {@code DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_SORTED}; confirm
     * and replace with the named constants.
     */
    private static final int STORAGE_HINTS = 10;

    /** Type information of the soft-assignment relation attached to the result in soft mode. */
    public static final SimpleTypeInformation<double[]> SOFT_TYPE = new SimpleTypeInformation<>(double[].class);

    /** Number of clusters requested from the model factory. */
    private final int k;

    /** Convergence threshold on the change of the average log-likelihood. */
    private final double delta;

    /** Factory creating the initial cluster models. */
    private final EMClusterModelFactory<V, M> mfactory;

    /** Maximum number of iterations; a negative value means unlimited. */
    private final int maxiter;

    /** Whether to attach the per-object membership probabilities to the result. */
    private boolean soft;

    /**
     * Parameterization class.
     *
     * @param <V> vector type
     * @param <M> model type
     */
    public static class Parameterizer<V extends NumberVector, M extends MeanModel> extends AbstractParameterizer {
        public static final OptionID K_ID = new OptionID("em.k", "The number of clusters to find.");
        public static final OptionID DELTA_ID = new OptionID("em.delta", "The termination criterion for maximization of E(M): E(M) - E(M') < em.delta");
        public static final OptionID INIT_ID = new OptionID("em.model", "Model factory.");
        protected int k;
        protected double delta;
        protected EMClusterModelFactory<V, M> initializer;
        protected int maxiter = -1;

        /**
         * Registers and reads the options k, model factory, delta and the
         * optional iteration limit, in that order.
         *
         * @param config parameterization to read from
         */
        @Override
        public void makeOptions(Parameterization config) {
            super.makeOptions(config);

            IntParameter kP = new IntParameter(K_ID);
            kP.addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT);
            if (config.grab(kP)) {
                k = kP.getValue().intValue();
            }

            ObjectParameter<EMClusterModelFactory<V, M>> initialP = new ObjectParameter<>(INIT_ID, EMClusterModelFactory.class, MultivariateGaussianModelFactory.class);
            if (config.grab(initialP)) {
                initializer = initialP.instantiateClass(config);
            }

            DoubleParameter deltaP = new DoubleParameter(DELTA_ID, 1.0E-5d);
            deltaP.addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE);
            if (config.grab(deltaP)) {
                delta = deltaP.getValue().doubleValue();
            }

            // Optional: when not given, maxiter stays -1 (no iteration limit).
            IntParameter maxiterP = new IntParameter(KMeans.MAXITER_ID);
            maxiterP.addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT);
            maxiterP.setOptional(true);
            if (config.grab(maxiterP)) {
                maxiter = maxiterP.getValue().intValue();
            }
        }

        /**
         * Creates the configured EM instance (hard assignment only).
         *
         * @return new EM instance
         */
        @Override
        public EM<V, M> makeInstance() {
            return new EM<>(this.k, this.delta, this.initializer, this.maxiter, false);
        }
    }

    /**
     * Constructor.
     *
     * @param k number of clusters
     * @param delta convergence threshold on the average log-likelihood change
     * @param mfactory factory for the initial cluster models
     * @param maxiter maximum number of iterations; negative for unlimited
     * @param soft whether to attach the soft assignments to the result
     */
    public EM(int k, double delta, EMClusterModelFactory<V, M> mfactory, int maxiter, boolean soft) {
        this.k = k;
        this.delta = delta;
        this.mfactory = mfactory;
        this.maxiter = maxiter;
        // Assigned directly: calling the overridable setSoft() from a constructor is unsafe.
        this.soft = soft;
    }

    /**
     * Runs EM on the given relation.
     *
     * @param database database (passed to the model factory)
     * @param relation input vectors
     * @return clustering with one top-level cluster per model
     * @throws IllegalArgumentException if the relation is empty
     */
    public Clustering<M> run(Database database, Relation<V> relation) {
        if (relation.size() == 0) {
            throw new IllegalArgumentException(ExceptionMessages.DATABASE_EMPTY);
        }
        if (LOG.isVerbose()) {
            LOG.verbose("initializing " + this.k + " models");
        }
        List<? extends EMClusterModel<M>> models = this.mfactory.buildInitialModels(database, relation, this.k, SquaredEuclideanDistanceFunction.STATIC);
        WritableDataStore<double[]> probClusterIGivenX = DataStoreUtil.makeStorage(relation.getDBIDs(), STORAGE_HINTS, double[].class);
        double loglikelihood = assignProbabilitiesToInstances(relation, models, probClusterIGivenX);
        if (LOG.isVerbose()) {
            LOG.verbose("iterating EM");
            LOG.verbose("iteration 0 - expectation value: " + loglikelihood);
        }
        for (int it = 1; this.maxiter < 0 || it <= this.maxiter; it++) {
            final double previous = loglikelihood;
            recomputeCovarianceMatrices(relation, probClusterIGivenX, models);
            loglikelihood = assignProbabilitiesToInstances(relation, models, probClusterIGivenX);
            if (LOG.isVerbose()) {
                LOG.verbose("iteration " + it + " - expectation value: " + loglikelihood);
            }
            // Stop on convergence, or as soon as the log-likelihood gets worse.
            if (Math.abs(previous - loglikelihood) <= this.delta || previous > loglikelihood) {
                break;
            }
        }
        if (LOG.isVerbose()) {
            LOG.verbose("assigning clusters");
        }
        // Size by the actual model count: the probability arrays have models.size() entries.
        final int numClusters = models.size();
        List<ModifiableDBIDs> members = assignHard(relation, probClusterIGivenX, numClusters);
        Clustering<M> clustering = new Clustering<>("EM Clustering", "em-clustering");
        for (int i = 0; i < numClusters; i++) {
            clustering.addToplevelCluster(new Cluster<>(members.get(i), models.get(i).finalizeCluster()));
        }
        if (isSoft()) {
            // The store becomes part of the result, so it must not be destroyed.
            clustering.addChildResult(new MaterializedRelation<>("cluster assignments", "em-soft-score", SOFT_TYPE, probClusterIGivenX, relation.getDBIDs()));
        } else {
            probClusterIGivenX.destroy();
        }
        return clustering;
    }

    /**
     * Assigns every object to the cluster with the largest membership
     * probability. Ties keep the lower index; objects whose probabilities are
     * all zero (or not positive) go to cluster 0.
     *
     * @param relation relation whose objects are assigned
     * @param probClusterIGivenX per-object membership probabilities
     * @param numClusters number of clusters (length of each probability array)
     * @return one modifiable ID set per cluster
     */
    private static List<ModifiableDBIDs> assignHard(Relation<?> relation, WritableDataStore<double[]> probClusterIGivenX, int numClusters) {
        List<ModifiableDBIDs> members = new ArrayList<>(numClusters);
        for (int i = 0; i < numClusters; i++) {
            members.add(DBIDUtil.newHashSet());
        }
        for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
            double[] probs = probClusterIGivenX.get(iter);
            int best = 0;
            double bestProb = 0.0d;
            for (int i = 0; i < numClusters; i++) {
                if (probs[i] > bestProb) {
                    best = i;
                    bestProb = probs[i];
                }
            }
            members.get(best).add(iter);
        }
        return members;
    }

    /**
     * Re-estimates the cluster models from the current membership probabilities.
     *
     * <p>Calls {@code beginEStep()} on every model. It then passes each object
     * to {@code updateE} of every model for which its membership probability is
     * positive. Finally it calls {@code finalizeEStep()} and sets each model's
     * weight to the mean membership probability over all objects.
     *
     * @param relation input vectors
     * @param probClusterIGivenX per-object membership probabilities, one entry per model
     * @param models cluster models to update
     */
    public static void recomputeCovarianceMatrices(Relation<? extends NumberVector> relation, WritableDataStore<double[]> probClusterIGivenX, List<? extends EMClusterModel<?>> models) {
        for (EMClusterModel<?> model : models) {
            model.beginEStep();
        }
        double[] wsum = new double[models.size()];
        for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
            double[] probs = probClusterIGivenX.get(iter);
            NumberVector vec = relation.get(iter);
            int i = 0;
            for (EMClusterModel<?> model : models) {
                final double p = probs[i];
                if (p > 0.0d) {
                    model.updateE(vec, p);
                }
                wsum[i] += p;
                i++;
            }
        }
        int i = 0;
        for (EMClusterModel<?> model : models) {
            model.finalizeEStep();
            model.setWeight(wsum[i] / relation.size());
            i++;
        }
    }

    /**
     * Computes and stores the membership probabilities of every object, and
     * returns the average log-likelihood.
     *
     * <p>Each object's probability for a model is that model's
     * {@code estimateDensity} divided by the sum over all models. If that sum is
     * not positive, the object gets an all-zero probability array. Each object's
     * log-likelihood is clamped below at {@link #MIN_LOGLIKELIHOOD}. NaN values
     * are skipped in the sum but still counted in the divisor.
     *
     * @param relation input vectors
     * @param models cluster models
     * @param probClusterIGivenX output store for the membership probabilities
     * @return average (clamped) log-likelihood over all objects
     */
    public static double assignProbabilitiesToInstances(Relation<? extends NumberVector> relation, List<? extends EMClusterModel<?>> models, WritableDataStore<double[]> probClusterIGivenX) {
        final int numModels = models.size();
        double emSum = 0.0d;
        for (DBIDIter iter = relation.iterDBIDs(); iter.valid(); iter.advance()) {
            NumberVector vec = relation.get(iter);
            double[] densities = new double[numModels];
            int i = 0;
            for (EMClusterModel<?> model : models) {
                densities[i++] = model.estimateDensity(vec);
            }
            double total = 0.0d;
            for (double density : densities) {
                total += density;
            }
            double logLikelihood = Math.max(Math.log(total), MIN_LOGLIKELIHOOD);
            if (!Double.isNaN(logLikelihood)) {
                emSum += logLikelihood;
            }
            double[] probs = new double[numModels];
            if (total > 0.0d) {
                for (int j = 0; j < numModels; j++) {
                    probs[j] = densities[j] / total;
                }
            }
            probClusterIGivenX.put(iter, probs);
        }
        return emSum / relation.size();
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(TypeUtil.NUMBER_VECTOR_FIELD);
    }

    @Override
    public Logging getLogger() {
        return LOG;
    }

    /**
     * @return whether soft assignments are attached to the result
     */
    public boolean isSoft() {
        return this.soft;
    }

    /**
     * @param soft whether to attach soft assignments to the result
     */
    public void setSoft(boolean soft) {
        this.soft = soft;
    }
}
