package de.lmu.ifi.dbs.elki.algorithm.classification;

import de.lmu.ifi.dbs.elki.algorithm.AbstractAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.DistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.data.ClassLabel;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDListIter;
import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.EuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.result.Result;
import de.lmu.ifi.dbs.elki.utilities.Priority;
import de.lmu.ifi.dbs.elki.utilities.documentation.Description;
import de.lmu.ifi.dbs.elki.utilities.documentation.Title;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.util.ArrayList;
import java.util.Collections;

/**
 * k-nearest-neighbor classifier.
 *
 * <p>This is a lazy learner: {@link #buildClassifier} only prepares a kNN query
 * on the training database and remembers the label relation; the actual work
 * happens at prediction time in {@link #classify} and
 * {@link #classProbabilities}, which look up the {@code k} nearest neighbors of
 * the query object and count their class labels.
 *
 * @param <O> the type of objects handled by the distance function
 */
@Priority(100)
@Description("Lazy classifier classifies a given instance to the majority class of the k-nearest neighbors.")
@Title("kNN-classifier")
public class KNNClassifier<O> extends AbstractAlgorithm<Result> implements DistanceBasedAlgorithm<O>, Classifier<O> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(KNNClassifier.class);

    /** Number of neighbors to take into account for classification. */
    protected int k;

    /** kNN query on the training data; set by {@link #buildClassifier}. */
    protected KNNQuery<O> knnq;

    /** Class labels of the training data; set by {@link #buildClassifier}. */
    protected Relation<? extends ClassLabel> labelrep;

    /** Distance function used to determine the nearest neighbors. */
    protected DistanceFunction<? super O> distanceFunction;

    /**
     * Parameterization class: reads the distance function and {@code k} from
     * the configuration and builds the classifier.
     *
     * @param <O> the type of objects handled by the distance function
     */
    public static class Parameterizer<O> extends AbstractParameterizer {
        /** Option for the number of neighbors ({@code knnclassifier.k}). */
        public static final OptionID K_ID = new OptionID("knnclassifier.k", "The number of neighbors to take into account for classification.");

        /** Distance function read from the configuration. */
        protected DistanceFunction<? super O> distanceFunction;

        /** Number of neighbors read from the configuration. */
        protected int k;

        /**
         * Reads the distance function (default: Euclidean distance) and the
         * neighborhood size {@code k} (default 1, constrained to be &gt;= 1).
         *
         * @param parameterization the configuration to read the options from
         */
        @Override
        public void makeOptions(Parameterization parameterization) {
            super.makeOptions(parameterization);
            ObjectParameter<DistanceFunction<? super O>> distanceParam = new ObjectParameter<>(DistanceBasedAlgorithm.DISTANCE_FUNCTION_ID, DistanceFunction.class, EuclideanDistanceFunction.class);
            if (parameterization.grab(distanceParam)) {
                this.distanceFunction = distanceParam.instantiateClass(parameterization);
            }
            IntParameter kParam = new IntParameter(K_ID, 1).addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT);
            if (parameterization.grab(kParam)) {
                this.k = kParam.intValue();
            }
        }

        /**
         * Builds the classifier from the values read in {@link #makeOptions}.
         *
         * @return a new classifier instance
         */
        @Override
        public KNNClassifier<O> makeInstance() {
            return new KNNClassifier<>(this.distanceFunction, this.k);
        }
    }

    /**
     * Constructor.
     *
     * @param distanceFunction distance function used for the neighbor search
     * @param i number of neighbors ({@code k}) to take into account
     */
    public KNNClassifier(DistanceFunction<? super O> distanceFunction, int i) {
        this.distanceFunction = distanceFunction;
        this.k = i;
    }

    /**
     * Prepares the classifier: obtains a kNN query for the configured distance
     * function on the given database and stores the label relation.
     *
     * @param database database holding the training objects
     * @param relation class labels of the training objects
     */
    @Override
    public void buildClassifier(Database database, Relation<? extends ClassLabel> relation) {
        // The data relation is selected by the distance function's input type.
        Relation<O> data = database.getRelation(getDistanceFunction().getInputTypeRestriction());
        // k is passed as a hint to the query factory, as in the original call.
        this.knnq = database.getKNNQuery(database.getDistanceQuery(data, getDistanceFunction()), this.k);
        this.labelrep = relation;
    }

    /**
     * Classifies an object by majority vote among its {@code k} nearest
     * neighbors.
     *
     * <p>On ties, the label encountered first while iterating the internal
     * count map wins (only a strictly larger count replaces the current best).
     *
     * @param o the object to classify
     * @return the most frequent label among the neighbors, or {@code null} if
     *         the query returned no neighbors
     */
    @Override
    public ClassLabel classify(O o) {
        Object2IntOpenHashMap<ClassLabel> votes = new Object2IntOpenHashMap<>();
        for (DoubleDBIDListIter neighbor = this.knnq.getKNNForObject(o, this.k).iter(); neighbor.valid(); neighbor.advance()) {
            votes.addTo(this.labelrep.get(neighbor), 1);
        }
        int bestCount = Integer.MIN_VALUE;
        ClassLabel bestLabel = null;
        for (ObjectIterator<Object2IntMap.Entry<ClassLabel>> it = votes.object2IntEntrySet().fastIterator(); it.hasNext();) {
            Object2IntMap.Entry<ClassLabel> entry = it.next();
            if (entry.getIntValue() > bestCount) {
                bestCount = entry.getIntValue();
                bestLabel = entry.getKey();
            }
        }
        return bestLabel;
    }

    /**
     * Estimates the class distribution of an object as the relative label
     * frequencies among its {@code k} nearest neighbors.
     *
     * <p>Labels are located with {@link Collections#binarySearch}, so
     * {@code arrayList} must be sorted in ascending natural order. Neighbors
     * whose label is not contained in the list contribute to no entry, but are
     * still part of the denominator.
     *
     * @param o the object to classify
     * @param arrayList sorted list of the class labels to report
     * @return one fraction per label, in the order of {@code arrayList}; the
     *         entries are NaN if the query returned no neighbors
     */
    public double[] classProbabilities(O o, ArrayList<ClassLabel> arrayList) {
        int[] occurrences = new int[arrayList.size()];
        // Number of neighbors actually returned; used as the denominator.
        int neighbors = 0;
        for (DoubleDBIDListIter neighbor = this.knnq.getKNNForObject(o, this.k).iter(); neighbor.valid(); neighbor.advance()) {
            neighbors++;
            int index = Collections.binarySearch(arrayList, this.labelrep.get(neighbor));
            if (index >= 0) {
                occurrences[index]++;
            }
        }
        double[] distribution = new double[occurrences.length];
        for (int i = 0; i < distribution.length; i++) {
            // Divide in floating point; int / int would truncate every fraction to 0.
            distribution[i] = ((double) occurrences[i]) / (double) neighbors;
        }
        return distribution;
    }

    /**
     * Describes the learned model.
     *
     * @return a fixed message, as this lazy learner builds no model
     */
    @Override
    public String model() {
        return "lazy learner - provides no model";
    }

    /**
     * Not supported: always throws.
     *
     * @param database ignored
     * @return never returns normally
     * @throws AbortException always
     * @deprecated use {@link #buildClassifier} followed by {@link #classify}
     */
    @Override
    @Deprecated
    public Result run(Database database) throws IllegalStateException {
        throw new AbortException("Classifiers cannot auto-run on a database, but need to be trained and can then predict.");
    }

    /**
     * @return the distance function used for the neighbor search
     */
    @Override
    public DistanceFunction<? super O> getDistanceFunction() {
        return this.distanceFunction;
    }

    /**
     * @return the input type restriction reported by this algorithm
     */
    @Override
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(TypeUtil.NUMBER_VECTOR_FIELD);
    }

    /**
     * @return the class logger
     */
    @Override
    public Logging getLogger() {
        return LOG;
    }
}
