Posted to issues@commons.apache.org by "Artem Barger (JIRA)" <ji...@apache.org> on 2016/06/23 11:58:16 UTC

[jira] [Created] (MATH-1378) KMeansPlusPlusClusterer optimize seeding procedure, by computing sum of squared distances outside the loop.

Artem Barger created MATH-1378:
----------------------------------

             Summary: KMeansPlusPlusClusterer optimize seeding procedure, by computing sum of squared distances outside the loop.
                 Key: MATH-1378
                 URL: https://issues.apache.org/jira/browse/MATH-1378
             Project: Commons Math
          Issue Type: Improvement
            Reporter: Artem Barger
            Assignee: Artem Barger


Currently, in the KMeansPlusPlusClusterer class, the function that implements the initial cluster seeding, *chooseInitialCenters*, executes the following computation on every iteration of the while loop:

{code}
        while (resultSet.size() < k) {

            // Sum up the squared distances for the points in pointList not
            // already taken.
            double distSqSum = 0.0;

            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    distSqSum += minDistSquared[i];
                }
            }

// Rest skipped for simplicity
{code}

This sum could instead be computed once, outside the loop, and later adjusted whenever the minimum distances to the set of centers change. E.g.:

{code}
        // Sum up the squared distances for the points in pointList not
        // already taken.
        double distSqSum = 0.0;

        // There is no need to compute the sum of squared distances within the "while" loop;
        // we can compute the initial value once and maintain deltas in the loop.
        for (int i = 0; i < numPoints; i++) {
            if (!taken[i]) {
                distSqSum += minDistSquared[i];
            }
        }

        while (resultSet.size() < k) {
            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)^2
            final double r = random.nextDouble() * distSqSum;

            // The index of the next point to be added to the resultSet.
            int nextPointIndex = -1;

            // Sum through the squared min distances again, stopping when
            // sum >= r.
            double sum = 0.0;
            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    sum += minDistSquared[i];
                    if (sum >= r) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // If it's not set to >= 0, the point wasn't found in the previous
            // for loop, probably because distances are extremely small.  Just pick
            // the last available point.
            if (nextPointIndex == -1) {
                for (int i = numPoints - 1; i >= 0; i--) {
                    if (!taken[i]) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // We found one.
            if (nextPointIndex >= 0) {

                final T p = pointList.get(nextPointIndex);

                resultSet.add(new CentroidCluster<T> (p));

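                // Mark it as taken and remove its contribution from the running sum,
                // since taken points are excluded from the sum.
                taken[nextPointIndex] = true;
                distSqSum -= minDistSquared[nextPointIndex];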

                if (resultSet.size() < k) {
                    // Now update elements of minDistSquared.  We only have to compute
                    // the distance to the new center to do this.
                    for (int j = 0; j < numPoints; j++) {
                        // Only have to worry about the points still not taken.
                        if (!taken[j]) {
                            double d = distance(p, pointList.get(j));
                            // Subtracting the old value.
                            distSqSum -= minDistSquared[j];
                            // Update minimum distance.
                            minDistSquared[j] = FastMath.min(d*d, minDistSquared[j]);
                            // Adjust the overall sum of squared distances.
                            distSqSum += minDistSquared[j];
                        }
                    }
                }

            } else {
                // None found --
                // Break from the while loop to prevent
                // an infinite loop.
                break;
            }
        }
{code}
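
To check that a running sum maintained this way stays equal to the sum recomputed from scratch, a minimal standalone sketch could look like the one below. It is not the Commons Math code: the class name IncrementalSeedingSum and the methods fullSum, addCenter and maxDrift are hypothetical, points are one-dimensional doubles for brevity, and java.util.Random stands in for the clusterer's random generator. Note that this sketch also subtracts the newly chosen point's own term, because the recomputed sum only covers points that are not taken.

```java
import java.util.Random;

// Hypothetical helper, not part of Commons Math: compares the incrementally
// maintained sum of squared distances against a full recomputation.
final class IncrementalSeedingSum {

    /** Recomputes the sum over all points not yet taken (what the loop does today). */
    static double fullSum(double[] minDistSquared, boolean[] taken) {
        double sum = 0.0;
        for (int i = 0; i < minDistSquared.length; i++) {
            if (!taken[i]) {
                sum += minDistSquared[i];
            }
        }
        return sum;
    }

    /** Adds points[index] as a center and returns the adjusted running sum. */
    static double addCenter(double[] points, double[] minDistSquared,
                            boolean[] taken, int index, double distSqSum) {
        taken[index] = true;
        // The chosen point leaves the "not taken" set, so drop its term.
        distSqSum -= minDistSquared[index];
        for (int j = 0; j < points.length; j++) {
            if (!taken[j]) {
                final double d = points[j] - points[index];
                final double updated = Math.min(d * d, minDistSquared[j]);
                // Apply only the delta between the old and the new minimum.
                distSqSum += updated - minDistSquared[j];
                minDistSquared[j] = updated;
            }
        }
        return distSqSum;
    }

    /** Seeds k centers (k <= numPoints assumed) and returns the largest
     *  gap seen between the running sum and a full recomputation. */
    static double maxDrift(int numPoints, int k, long seed) {
        final Random random = new Random(seed);
        final double[] points = new double[numPoints];
        for (int i = 0; i < numPoints; i++) {
            points[i] = random.nextDouble();
        }
        final double[] minDistSquared = new double[numPoints];
        final boolean[] taken = new boolean[numPoints];

        // First center is chosen uniformly; its own distance is zero.
        final int first = random.nextInt(numPoints);
        taken[first] = true;
        for (int i = 0; i < numPoints; i++) {
            final double d = points[i] - points[first];
            minDistSquared[i] = d * d;
        }

        double distSqSum = fullSum(minDistSquared, taken);
        double drift = 0.0;
        for (int centers = 1; centers < k; centers++) {
            final double r = random.nextDouble() * distSqSum;
            int next = -1;
            double sum = 0.0;
            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    sum += minDistSquared[i];
                    if (sum >= r) {
                        next = i;
                        break;
                    }
                }
            }
            // Fallback: pick the last available point.
            for (int i = numPoints - 1; next == -1 && i >= 0; i--) {
                if (!taken[i]) {
                    next = i;
                }
            }
            distSqSum = addCenter(points, minDistSquared, taken, next, distSqSum);
            drift = Math.max(drift, Math.abs(distSqSum - fullSum(minDistSquared, taken)));
        }
        return drift;
    }
}
```

Calling IncrementalSeedingSum.maxDrift(200, 10, 42L) reports how far the running sum drifts from the recomputed one; any difference should come only from floating-point rounding. The trade-off is that each iteration saves one pass over the points at the cost of accumulating small rounding errors in the running sum.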


