/*
 * Decompiled with CFR 0.152.
 */
package org.meteoinfo.math.stats.kde;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.TreeMap;
import org.apache.commons.math4.core.jdkmath.AccurateMath;
import org.meteoinfo.math.stats.kde.Ball;
import org.meteoinfo.math.stats.kde.Event;

/**
 * Binary space-partitioning tree over two-dimensional {@link Event}s, used to speed up
 * evaluation of a Gaussian kernel density estimate.
 *
 * <p>Each internal node splits its events at the lower median of one coordinate. The
 * coordinate alternates between index 0 and index 1 with depth. Nodes holding at most
 * {@code minParent} events become leaves that keep their events explicitly.
 */
public class BallTree {
    private static final String NAME = "BallTree";

    /** {@code log(2*pi)}: normalising term of a bivariate Gaussian, computed once. */
    private static final double LOG_TWO_PI = AccurateMath.log(Math.PI * 2);

    /**
     * A subtree is replaced by a single aggregate term when the gap between
     * {@code exp(maxPdf)} and {@code exp(minPdf)} for the query falls below this value.
     */
    private static final double SKIP_TOLERANCE = 0.001;

    private final Ball headBall;
    int minParent;

    /**
     * Running count of events whose kernels were approximated instead of evaluated.
     * It is a plain static field incremented without synchronization.
     */
    public static int numSkipped = 0;

    /**
     * Builds the tree over {@code data}.
     *
     * @param data      events to index; must not be null
     * @param minParent nodes with at most this many events become leaves
     * @throws NullPointerException if {@code data} is null
     */
    protected BallTree(List<Event> data, int minParent) {
        Objects.requireNonNull(data, "data");
        this.minParent = minParent;
        this.headBall = this.splitBall(data, 0);
    }

    /**
     * Recursively builds the subtree for {@code data}, splitting on coordinate {@code feature}.
     *
     * @param data    events belonging to this subtree
     * @param feature coordinate index (0 or 1) used to order events at this level
     * @return a leaf {@code Ball} holding {@code data}, or an internal {@code Ball} with the
     *         bounding box, median event, point count, children and bandwidth range
     */
    private Ball splitBall(List<Event> data, int feature) {
        int numPoints = data.size();
        // The explicit empty check terminates recursion even when minParent is negative;
        // otherwise an empty list would be split again forever.
        if (numPoints <= this.minParent || numPoints == 0) {
            return new Ball(data);
        }

        // Axis-aligned bounding box (lower-left / upper-right) and bandwidth range of the node.
        double[] ll = {Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY};
        double[] ur = {Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY};
        double minBW = Double.MAX_VALUE;
        double maxBW = -1.0;
        for (Event e : data) {
            double h = e.getH();
            minBW = Math.min(minBW, h);
            maxBW = Math.max(maxBW, h);
            double[] point = e.getPoint();
            ll[0] = Math.min(ll[0], point[0]);
            ll[1] = Math.min(ll[1], point[1]);
            ur[0] = Math.max(ur[0], point[0]);
            ur[1] = Math.max(ur[1], point[1]);
        }

        // List.sort is stable, so events with equal coordinates keep their input order
        // without needing a Comparable that lies about equality.
        List<Event> sorted = new ArrayList<>(data);
        sorted.sort(Comparator.comparingDouble(ev -> ev.getPoint()[feature]));

        // Lower median: equals ceil(numPoints / 2) - 1 for numPoints >= 1.
        int medIndex = (numPoints - 1) / 2;
        Event medianEvent = sorted.get(medIndex);
        // Copy the sublists so child leaves do not retain a view of the whole parent list.
        List<Event> leftBranchData = new ArrayList<>(sorted.subList(0, medIndex));
        List<Event> rightBranchData = new ArrayList<>(sorted.subList(medIndex + 1, numPoints));

        int nextFeature = (feature + 1) % 2;
        Ball leftBranch = this.splitBall(leftBranchData, nextFeature);
        Ball rightBranch = this.splitBall(rightBranchData, nextFeature);
        return new Ball(ll, ur, medianEvent, numPoints, leftBranch, rightBranch, minBW, maxBW);
    }

    /**
     * Collects log-kernel contributions of all indexed events for the query {@code e}.
     *
     * <p>Each entry is either the exact log kernel of one event, or, for a subtree whose
     * density bounds are tight enough, one aggregate term standing in for that whole
     * subtree. The caller is expected to combine them (e.g. log-sum-exp); that is not done here.
     *
     * @param e query event (its point is used)
     * @return a new mutable list of log values
     */
    protected List<Double> logPdfRecurse(Event e) {
        List<Double> logValues = new ArrayList<>();
        BallTree.collectLogPdf(e, this.headBall, logValues);
        return logValues;
    }

    /** Appends the contributions of {@code ball}'s subtree for query {@code e} to {@code out}. */
    private static void collectLogPdf(Event e, Ball ball, List<Double> out) {
        if (ball.events != null) {
            // Leaf: evaluate every stored event exactly.
            out.addAll(BallTree.computeLogKernel(e.getPoint(), ball.events));
            return;
        }
        // NOTE(review): minPdf/maxPdf are presumably lower/upper log-kernel bounds over the
        // ball's events for this query -- defined in Ball, confirm there.
        double minPdf = ball.minPdf(e);
        double maxPdf = ball.maxPdf(e);
        if (AccurateMath.exp(maxPdf) - AccurateMath.exp(minPdf) < SKIP_TOLERANCE) {
            // Replace all ball.numPoints events by one term: log(count) plus the midpoint of
            // the two log bounds (i.e. count times the geometric mean of the bounds).
            out.add(AccurateMath.log(ball.numPoints) + (maxPdf + minPdf) / 2.0);
            numSkipped += ball.numPoints;
            return;
        }
        // Internal node: the median event is stored on the node itself, not in a child.
        out.add(BallTree.computeLogKernel(e.getPoint(), ball.event.getPoint(), ball.event.getH()));
        collectLogPdf(e, ball.leftBall, out);
        collectLogPdf(e, ball.rightBall, out);
    }

    /**
     * Evaluates the log kernel of {@code y} against each sample, using each sample's own h.
     *
     * @param y       query point (two coordinates)
     * @param samples events to evaluate against
     * @return log kernel values in the iteration order of {@code samples}
     */
    protected static List<Double> computeLogKernel(double[] y, List<Event> samples) {
        List<Double> logValues = new ArrayList<>(samples.size());
        for (Event s : samples) {
            logValues.add(BallTree.computeLogKernel(y, s.getPoint(), s.getH()));
        }
        return logValues;
    }

    /**
     * Log density at {@code y} of an isotropic bivariate Gaussian centred on {@code s}:
     * {@code -log(2*pi) - log(h) - |y - s|^2 / (2h)}. Here {@code h} acts as the per-axis
     * variance (squared bandwidth).
     */
    protected static double computeLogKernel(double[] y, double[] s, double h) {
        return BallTree.computeLogKernel(y, s, h, h);
    }

    /**
     * Gaussian-shaped log kernel whose normaliser uses {@code hMax} and whose exponent uses
     * {@code hMin}: {@code -log(2*pi) - log(hMax) - |y - s|^2 / (2*hMin)}.
     * NOTE(review): presumably used to bound the kernel over a range of bandwidths; confirm
     * against callers in Ball.
     */
    protected static double computeLogKernel(double[] y, double[] s, double hMax, double hMin) {
        double invH = 1.0 / hMin;
        double dx = y[0] - s[0];
        double dy = y[1] - s[1];
        double expVal = dx * (invH * dx) + dy * (invH * dy);
        double firstVal = -LOG_TWO_PI - AccurateMath.log(hMax);
        return firstVal + -0.5 * expVal;
    }

    private String className() {
        return NAME;
    }
}

