You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ok...@apache.org on 2019/09/18 00:39:58 UTC

[madlib] branch master updated: Kmeans: Add automatic optimal cluster estimation

This is an automated email from the ASF dual-hosted git repository.

okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new bc93e25  Kmeans: Add automatic optimal cluster estimation
bc93e25 is described below

commit bc93e25c6790b814bbc6079601e42c13a4f2ad08
Author: Orhan Kislal <ok...@apache.org>
AuthorDate: Thu Aug 15 20:42:33 2019 -0400

    Kmeans: Add automatic optimal cluster estimation
    
    JIRA: MADLIB-1380
    
    This commit adds the option to run k-means clustering algorithm for a range of
    `k` values and get the optimal `k` with its associated cluster centers. It is
    only supported for random and pp initial seeding options.
    
    Closes #433
    
    Co-authored-by: Nikhil Kak <nk...@pivotal.io>
    Co-authored-by: Ekta Khanna <ek...@pivotal.io>
---
 src/ports/postgres/modules/kmeans/kmeans.sql_in    | 177 +++++++++++++++-
 .../postgres/modules/kmeans/kmeans_auto.py_in      | 223 +++++++++++++++++++++
 .../postgres/modules/kmeans/test/kmeans.sql_in     | 115 +++++++++++
 .../modules/kmeans/test/unit_tests/plpy_mock.py_in |  43 ++++
 .../kmeans/test/unit_tests/test_kmeans_auto.py_in  |  87 ++++++++
 src/ports/postgres/modules/knn/knn.py_in           |  23 +--
 .../test/unit_tests/test_validate_args.py_in       |  21 +-
 .../postgres/modules/utilities/validate_args.py_in |  21 +-
 8 files changed, 685 insertions(+), 25 deletions(-)

diff --git a/src/ports/postgres/modules/kmeans/kmeans.sql_in b/src/ports/postgres/modules/kmeans/kmeans.sql_in
index 81dea80..1eae525 100644
--- a/src/ports/postgres/modules/kmeans/kmeans.sql_in
+++ b/src/ports/postgres/modules/kmeans/kmeans.sql_in
@@ -999,7 +999,7 @@ BEGIN
     sampled_rel_source = MADLIB_SCHEMA.__unique_string();
     sampled_col_name = MADLIB_SCHEMA.__unique_string();
     IF (seeding_sample_ratio < 1.0) THEN
-        EXECUTE 'DROP TABLE IF EXISTS '||sampled_rel_source;
+        EXECUTE 'DROP TABLE IF EXISTS '||sampled_rel_source||' CASCADE';
         EXECUTE 'CREATE TEMP TABLE '||sampled_rel_source||' AS
             SELECT *
             FROM
@@ -1059,7 +1059,7 @@ BEGIN
     EXECUTE 'SET client_min_messages TO ' || oldClientMinMessages;
 
     IF (seeding_sample_ratio < 1.0) THEN
-        EXECUTE 'DROP TABLE IF EXISTS '||sampled_rel_source;
+        EXECUTE 'DROP TABLE IF EXISTS '||sampled_rel_source||' CASCADE';
     END IF;
 
     RETURN theResult;
@@ -1120,6 +1120,7 @@ m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
     min_frac_reassigned
 )</pre>
  */
+
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp(
     rel_source VARCHAR,
     expr_point VARCHAR,
@@ -1733,3 +1734,175 @@ AS $$
         'MADLIB_SCHEMA.dist_norm2')
 $$
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `READS SQL DATA', `');
+
+/**
+ * @brief Auto k-means: run kmeanspp (here) or kmeans_random once per value
+ * in the k array, store every result in output_table, and copy the row of
+ * the selected k (silhouette by default) to <output_table>_summary.
+ */
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */,
+    min_frac_reassigned DOUBLE PRECISION /*+ DEFAULT 0.001 */,
+    seeding_sample_ratio DOUBLE PRECISION  /*+ DEFAULT 1.0 */,
+    k_selection_algorithm VARCHAR /*+ DEFAULT 'silhouette' */
+) RETURNS VOID AS $$
+    PythonFunction(`kmeans', `kmeans_auto', `kmeanspp_auto')
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */,
+    min_frac_reassigned DOUBLE PRECISION /*+ DEFAULT 0.001 */,
+    seeding_sample_ratio DOUBLE PRECISION  /*+ DEFAULT 1.0 */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, $5, $6, $7, $8, $9, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */,
+    min_frac_reassigned DOUBLE PRECISION /*+ DEFAULT 0.001 */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, $5, $6, $7, $8, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, $5, $6, $7, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, $5, NULL, NULL, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeanspp_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[]
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeanspp_auto($1, $2, $3, $4, NULL, NULL, NULL, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */,
+    min_frac_reassigned DOUBLE PRECISION /*+ DEFAULT 0.001 */,
+    k_selection_algorithm VARCHAR /*+ DEFAULT 'silhouette' */
+) RETURNS VOID AS $$
+    PythonFunction(`kmeans', `kmeans_auto', `kmeans_random_auto')
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */,
+    min_frac_reassigned DOUBLE PRECISION /*+ DEFAULT 0.001 */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, $5, $6, $7, $8, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */,
+    max_num_iterations INTEGER /*+ DEFAULT 20 */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, $5, $6, $7, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */,
+    agg_centroid VARCHAR /*+ DEFAULT 'avg' */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[],
+    fn_dist VARCHAR /*+ DEFAULT 'squared_dist_norm2' */
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, $5, NULL, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    expr_point VARCHAR,
+    k INTEGER[]
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, NULL, NULL, NULL, NULL, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/kmeans/kmeans_auto.py_in b/src/ports/postgres/modules/kmeans/kmeans_auto.py_in
new file mode 100644
index 0000000..5eb7d3c
--- /dev/null
+++ b/src/ports/postgres/modules/kmeans/kmeans_auto.py_in
@@ -0,0 +1,223 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+@file kmeans_auto.py_in
+
+@brief Run k-means for several values of k and select the optimal k.
+
+"""
+
+import numpy as np
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import output_tbl_valid
+from utilities.validate_args import get_algorithm_name
+# k_selection_algorithm choices (prefixes accepted); BOTH reports silhouette's k.
+ELBOW = 'elbow'
+SILHOUETTE = 'silhouette'
+BOTH = 'both'
+# Seeding strategies: RANDOM -> kmeans_random, PP -> kmeanspp.
+RANDOM = 'random'
+PP = 'pp'
+
+def _validate(output_table, k):
+    """Validate both output table names and k (non-NULL, 2+ unique values, all > 1)."""
+    output_tbl_valid(output_table, "kmeans_auto")
+    output_tbl_valid('{0}_summary'.format(output_table), "kmeans_auto")
+
+    _assert(k, "kmeans_auto: k cannot be NULL.")
+    _assert(len(k)>1, "kmeans_auto: Length of k array should be more than 1.")
+    _assert(min(k)>1, "kmeans_auto: the minimum k value has to be > 1.")
+    _assert(len(set(k)) == len(k), "kmeans_auto: Duplicate values are not allowed in k.")
+
+
+def set_defaults(schema_madlib, fn_dist, agg_centroid, max_num_iterations, min_frac_reassigned, k_selection_algorithm, seeding, seeding_sample_ratio):
+    """Fill defaults for NULL args (6-tuple); seeding_sample_ratio only for PP."""
+    fn_dist = (fn_dist if fn_dist else '{0}.squared_dist_norm2'.format(schema_madlib))
+    agg_centroid = agg_centroid if agg_centroid \
+                    else '{0}.avg'.format(schema_madlib)
+    max_num_iterations = max_num_iterations if max_num_iterations \
+                          else 20
+    # Only NULL means "use the default": 0.0 is a valid threshold.
+    min_frac_reassigned = (0.001 if min_frac_reassigned is None
+                           else min_frac_reassigned)
+    k_selection_algorithm = get_algorithm_name(k_selection_algorithm, SILHOUETTE,
+        [ELBOW, SILHOUETTE, BOTH], 'kmeans_auto')
+
+    if seeding == PP:  # value comparison; `is` relied on string interning
+        seeding_sample_ratio = (seeding_sample_ratio
+                                if seeding_sample_ratio is not None else 1.0)
+    return (fn_dist, agg_centroid, max_num_iterations, min_frac_reassigned,
+            k_selection_algorithm, seeding_sample_ratio)
+
+def kmeans_auto(schema_madlib, rel_source, output_table, expr_point, k,
+    fn_dist=None, agg_centroid=None, max_num_iterations=None,
+    min_frac_reassigned=None, k_selection_algorithm=None, seeding=None,
+    seeding_sample_ratio=None, **kwargs):
+    """Run k-means per k into output_table; best k's row -> <output_table>_summary."""
+    with MinWarning("error"):
+        _validate(output_table, k)
+
+        (fn_dist, agg_centroid, max_num_iterations, min_frac_reassigned,
+         k_selection_algorithm, seeding_sample_ratio) = set_defaults(
+            schema_madlib, fn_dist, agg_centroid, max_num_iterations,
+            min_frac_reassigned, k_selection_algorithm, seeding,
+            seeding_sample_ratio)
+
+        # Silhouette is computed for 'silhouette' and 'both'; the elbow
+        # (second derivative of objective_fn) for 'elbow' and 'both'.
+        use_silhouette = k_selection_algorithm in [SILHOUETTE, BOTH]
+        use_elbow = k_selection_algorithm in [ELBOW, BOTH]
+
+        # Score columns exist only for the metrics that are computed.
+        silhouette_col = (", {0} DOUBLE PRECISION".format(SILHOUETTE)
+                          if use_silhouette else "")
+        elbow_col = (", {0} DOUBLE PRECISION".format(ELBOW)
+                     if use_elbow else "")
+
+        plpy.execute("""
+            CREATE TABLE {output_table} (
+                k INTEGER,
+                centroids   DOUBLE PRECISION[][],
+                cluster_variance    DOUBLE PRECISION[],
+                objective_fn    DOUBLE PRECISION,
+                frac_reassigned DOUBLE PRECISION,
+                num_iterations  INTEGER
+                {silhouette_col}
+                {elbow_col})
+            """.format(**locals()))
+
+        silhouette_vals = []
+
+        for current_k in k:
+            if seeding == RANDOM:
+                plpy.execute("""
+                    INSERT INTO {output_table}
+                    (k, centroids, cluster_variance, objective_fn, frac_reassigned,
+                    num_iterations)
+                    SELECT {current_k} as k, *
+                    FROM {schema_madlib}.kmeans_random('{rel_source}',
+                                         '{expr_point}',
+                                         {current_k},
+                                         '{fn_dist}',
+                                         '{agg_centroid}',
+                                         {max_num_iterations},
+                                         {min_frac_reassigned});
+                    """.format(**locals()))
+            else:
+                plpy.execute("""
+                    INSERT INTO {output_table}
+                    (k, centroids, cluster_variance, objective_fn, frac_reassigned,
+                    num_iterations)
+                    SELECT {current_k} as k, *
+                    FROM {schema_madlib}.kmeanspp('{rel_source}',
+                                         '{expr_point}',
+                                         {current_k},
+                                         '{fn_dist}',
+                                         '{agg_centroid}',
+                                         {max_num_iterations},
+                                         {min_frac_reassigned},
+                                         {seeding_sample_ratio});
+                    """.format(**locals()))
+
+            if use_silhouette:
+                silhouette_query= """
+                    SELECT * FROM {schema_madlib}.simple_silhouette(
+                        '{rel_source}',
+                        '{expr_point}',
+                        (SELECT centroids
+                         FROM {output_table}
+                         WHERE k = {current_k}),
+                        '{fn_dist}')
+                    """.format(**locals())
+                silhouette_vals.append(
+                    plpy.execute(silhouette_query)[0]['simple_silhouette'])
+
+        update_query = """
+            UPDATE {output_table} SET {{column}} = __value__ FROM
+            (SELECT unnest(ARRAY[{k_arr}]) AS __k__,
+                    unnest(ARRAY[{{calc_arr}}]) AS __value__
+            )sub_q
+            WHERE __k__ = k
+            """.format(output_table = output_table,
+                       k_arr = str(k)[1:-1])
+        if use_silhouette:
+            optimal_sil =  k[np.argmax(np.array(silhouette_vals))]
+            plpy.execute(update_query.format(column = SILHOUETTE,
+                calc_arr = str(silhouette_vals)[1:-1]))
+
+        if use_elbow:
+            optimal_elbow, second_order = _calculate_elbow(output_table)
+            # second_order follows ascending k; realign it with k's input order.
+            elbow_by_k = dict(zip(sorted(k), second_order))
+            plpy.execute(update_query.format(column = ELBOW,
+                calc_arr = str([elbow_by_k[x] for x in k])[1:-1]))
+
+        optimal_k = optimal_sil if use_silhouette else optimal_elbow
+
+        plpy.execute("""
+            CREATE TABLE {output_table}_summary AS
+            SELECT {output_table}.*,
+                   '{algorithm}'::VARCHAR AS selection_algorithm
+            FROM {output_table}
+            WHERE k = {optimal_k}
+            """.format(algorithm = SILHOUETTE if use_silhouette else ELBOW,
+                       **locals()))
+
+    return
+
+def _calculate_elbow(output_table):
+    """Return (k at max 2nd derivative of objective_fn, 2nd derivatives in k ASC order)."""
+    # We have to get the values in ordered fashion because the elbow is only defined for ordered values.
+    inertia_result = plpy.execute("""
+                 SELECT k, objective_fn FROM {output_table} ORDER BY k ASC
+                 """.format(**locals()))
+    k = [ i['k'] for i in inertia_result ]
+    inertia_list = [ i['objective_fn'] for i in inertia_result ]
+    inertia_list = np.array(inertia_list)
+
+    first_order=np.gradient(inertia_list, k)  # k supplies the (possibly uneven) spacing
+    second_order=np.gradient(first_order, k)
+    index_with_elbow=k[np.argmax(second_order)]  # a k value, not a list index
+
+    return index_with_elbow, second_order.tolist()
+
+def kmeans_random_auto(schema_madlib, rel_source, output_table, expr_point, k,
+    fn_dist=None, agg_centroid=None, max_num_iterations=None,
+    min_frac_reassigned=None, k_selection_algorithm=None, **kwargs):
+    """Entry point for SQL kmeans_random_auto: kmeans_auto with RANDOM seeding."""
+    kmeans_auto(schema_madlib, rel_source, output_table, expr_point, k,
+    fn_dist, agg_centroid, max_num_iterations, min_frac_reassigned,
+    k_selection_algorithm, RANDOM)
+
+    return
+
+def kmeanspp_auto(schema_madlib, rel_source, output_table, expr_point, k,
+    fn_dist=None, agg_centroid=None, max_num_iterations=None,
+    min_frac_reassigned=None, seeding_sample_ratio=None,
+    k_selection_algorithm=None, **kwargs):
+    """Entry point for SQL kmeanspp_auto: kmeans_auto with PP (kmeanspp) seeding."""
+    kmeans_auto(schema_madlib, rel_source, output_table, expr_point, k,
+    fn_dist, agg_centroid, max_num_iterations, min_frac_reassigned,
+    k_selection_algorithm, PP, seeding_sample_ratio)
+
+    return
diff --git a/src/ports/postgres/modules/kmeans/test/kmeans.sql_in b/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
index 8d790fa..4553b6c 100644
--- a/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
+++ b/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
@@ -91,6 +91,121 @@ SELECT * FROM kmeans('kmeans_2d', 'array[x,y]', 'centroids', 'array[x,y]');
 SELECT * FROM kmeanspp('kmeans_2d', 'array[x,y]', 10);
 SELECT * FROM kmeans_random('kmeans_2d', 'arRAy [ x,y]', 10);
 
+-- Test kmeanspp_auto
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001, 0.1, 'elbow');
+
+SELECT assert(
+        elbow > 0 AND objective_fn > 0 AND selection_algorithm = 'elbow',
+        'Kmeans: Auto Kmeans_pp failed for elbow.')
+FROM autokm_out_summary;
+
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001, 0.1, 'silhouette');
+
+SELECT assert(
+        silhouette > 0 AND objective_fn > 0 AND selection_algorithm = 'silhouette',
+        'Kmeans: Auto Kmeans_pp failed for silhouette.')
+FROM autokm_out_summary;
+
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001, 0.1, 'both');
+
+SELECT assert(
+        silhouette > 0 AND objective_fn > 0 AND selection_algorithm = 'silhouette',
+        'Kmeans: Auto Kmeans_pp failed for both.')
+FROM autokm_out_summary;
+
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8]);
+
+SELECT assert(
+        silhouette > 0 AND objective_fn > 0 AND selection_algorithm = 'silhouette',
+        'Kmeans: Auto Kmeans_pp failed for default.')
+FROM autokm_out_summary;
+
+-- Test kmeans_random_auto
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001, 'elbow');
+
+SELECT assert(
+        elbow > 0 AND objective_fn > 0 AND selection_algorithm = 'elbow',
+        'Kmeans: Auto Kmeans_random failed for elbow.')
+FROM autokm_out_summary;
+
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001, 'El');
+
+SELECT assert(
+        elbow > 0 AND objective_fn > 0 AND selection_algorithm = 'elbow',
+        'Kmeans: Auto Kmeans_random failed for elbow prefix.')
+FROM autokm_out_summary;
+
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[5,6,7,8]);
+
+SELECT assert(
+        silhouette > 0 AND objective_fn > 0 AND selection_algorithm = 'silhouette',
+        'Kmeans: Auto Kmeans_random failed for default.')
+FROM autokm_out_summary;
+
+
+SELECT assert(count(*) = 4, 'Kmeans: Auto Kmeans_random output has incorrect number of rows')
+FROM (SELECT * FROM autokm_out WHERE k = any(ARRAY[5,6,7,8]))q;
+
+-- Unordered k list test
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[12,3,5,6,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001, 'silhouette');
+
+SELECT assert(
+        silhouette > 0 AND objective_fn > 0,
+        'Kmeans: Auto Kmeans_random failed for silhouette on unordered k vals')
+FROM autokm_out_summary;
+
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001, 'silhouetTe');
+
+SELECT assert(
+        silhouette > 0 AND objective_fn > 0,
+        'Kmeans: Auto Kmeans_random failed for mixed-case silhouette.')
+FROM autokm_out_summary;
+
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001, 'siL');
+
+SELECT assert(
+        silhouette > 0 AND objective_fn > 0,
+        'Kmeans: Auto Kmeans_random failed for silhouette prefix.')
+FROM autokm_out_summary;
+
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001, 'both');
+
+SELECT assert(
+        silhouette > 0 AND objective_fn > 0 AND
+        selection_algorithm = 'silhouette',
+        'Kmeans: Auto Kmeans_random failed for both.')
+FROM autokm_out_summary;
+
+DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
+SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001, 'b');
+
+SELECT assert(
+        silhouette > 0 AND objective_fn > 0 AND
+        selection_algorithm = 'silhouette',
+        'Kmeans: Auto Kmeans_random failed for both prefix.')
+FROM autokm_out_summary;
+
 DROP TABLE IF EXISTS km_sample CASCADE;
 DROP TABLE IF EXISTS centroids CASCADE;
 DROP TABLE IF EXISTS kmeans_2d CASCADE;
diff --git a/src/ports/postgres/modules/kmeans/test/unit_tests/plpy_mock.py_in b/src/ports/postgres/modules/kmeans/test/unit_tests/plpy_mock.py_in
new file mode 100644
index 0000000..dd18649
--- /dev/null
+++ b/src/ports/postgres/modules/kmeans/test/unit_tests/plpy_mock.py_in
@@ -0,0 +1,43 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+m4_changequote(`<!', `!>')
+def __init__(self):  # NOTE(review): module-level __init__ is never invoked; looks vestigial
+    pass
+
+def error(message):
+    raise PLPYException(message)  # mimic plpy.error aborting the caller
+
+def execute(query):
+    pass  # no-op; test_kmeans_auto assigns a MagicMock to plpy.execute
+
+def warning(query):
+    pass
+
+def info(query):
+    print query  # Python 2 print statement: echoes info messages to stdout
+
+
+class PLPYException(Exception):
+    def __init__(self, message):
+        super(PLPYException, self).__init__()  # message is kept on self.message, not in args
+        self.message = message
+
+    def __str__(self):
+        return repr(self.message)
diff --git a/src/ports/postgres/modules/kmeans/test/unit_tests/test_kmeans_auto.py_in b/src/ports/postgres/modules/kmeans/test/unit_tests/test_kmeans_auto.py_in
new file mode 100644
index 0000000..b56f4ad
--- /dev/null
+++ b/src/ports/postgres/modules/kmeans/test/unit_tests/test_kmeans_auto.py_in
@@ -0,0 +1,87 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+import numpy as np
+from os import path
+
+# Add modules to the pythonpath.
+sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
+sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
+
+import unittest
+from mock import *
+import plpy_mock as plpy
+
+m4_changequote(`<!', `!>')
+
+class KmeansAutoTestCase(unittest.TestCase):
+    def setUp(self):
+        # plpy_mock replaces plpy; its execute is swapped for a MagicMock below.
+        patches = {
+            'plpy': plpy,
+            'utilities.mean_std_dev_calculator': Mock(),
+        }
+        # we need to use MagicMock() instead of Mock() for the plpy.execute mock
+        # to be able to iterate on the return value
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+
+        self.default_schema_madlib = "madlib"
+        self.default_source_table = "source"
+        self.default_output_table = "output"
+
+        import kmeans_auto
+        self.module = kmeans_auto
+
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
+    def test_calculate_elbow_evenly_spaced(self):
+
+        self.plpy_mock_execute.return_value = [
+            {'k':2, 'objective_fn':100 },
+            {'k':3, 'objective_fn':50 },
+            {'k':4, 'objective_fn':25 },
+            {'k':5, 'objective_fn':20 },
+            {'k':6, 'objective_fn':10 }
+        ]
+        elbow,_ = self.module._calculate_elbow('foo')
+        self.assertEqual(3, elbow)  # 2nd derivatives by k: [12.5, 17.5, 15, 2.5, -2.5]
+
+    def test_calculate_elbow_unevenly_spaced(self):
+
+        self.plpy_mock_execute.return_value = [
+            {'k':2, 'objective_fn':100 },
+            {'k':4, 'objective_fn':80 },
+            {'k':6, 'objective_fn':25 },
+            {'k':7, 'objective_fn':20 },
+            {'k':8, 'objective_fn':10 }
+        ]
+        elbow,_ = self.module._calculate_elbow('foo')
+        self.assertEqual(6, elbow)  # 2nd derivative is largest (4.375) at k=6
+
+if __name__ == '__main__':
+    unittest.main()
+
+# ---------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index 6d681e2..eb2150e 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -50,6 +50,7 @@ from utilities.validate_args import input_tbl_valid, output_tbl_valid
 from utilities.validate_args import is_col_array
 from utilities.validate_args import is_var_valid
 from utilities.validate_args import quote_ident
+from utilities.validate_args import get_algorithm_name
 
 WEIGHT_FOR_ZERO_DIST = 1e107
 BRUTE_FORCE = 'brute_force'
@@ -421,25 +422,6 @@ def _create_interim_tbl(schema_madlib, point_source, point_column_name, point_id
 
 # ------------------------------------------------------------------------------
 
-def _get_algorithm_name(algorithm):
-    if not algorithm:
-        algorithm = BRUTE_FORCE
-    else:
-        supported_algorithms = [BRUTE_FORCE, KD_TREE]
-        try:
-            # allow user to specify a prefix substring of
-            # supported algorithms. This works because the supported
-            # algorithms have unique prefixes.
-            algorithm = next(x for x in supported_algorithms
-                               if x.startswith(algorithm))
-        except StopIteration:
-            # next() returns a StopIteration if no element found
-            plpy.error("kNN Error: Invalid algorithm: "
-                       "{0}. Supported algorithms are ({1})"
-                       .format(algorithm, ','.join(sorted(supported_algorithms))))
-    return algorithm
-# ------------------------------------------------------------------------------
-
 def knn(schema_madlib, point_source, point_column_name, point_id,
         label_column_name, test_source, test_column_name, test_id, output_table,
         k, output_neighbors, fn_dist, weighted_avg, algorithm, algorithm_params,
@@ -489,7 +471,8 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
         if k is None:
             k = 1
 
-        algorithm = _get_algorithm_name(algorithm)
+        algorithm = get_algorithm_name(algorithm, BRUTE_FORCE,
+            [BRUTE_FORCE, KD_TREE], 'kNN')
 
         # Default values for depth and leaf nodes
         depth = 3
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
index a3f2539..063d762 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_validate_args.py_in
@@ -49,11 +49,11 @@ class ValidateArgsTestCase(unittest.TestCase):
     def test_input_tbl_valid_null_tbl_raises_exception(self):
         with self.assertRaises(plpy.PLPYException):
           self.subject.input_tbl_valid(None, "unittest_module")
-        
+
     def test_input_tbl_valid_whitespaces_tbl_raises(self):
         with self.assertRaises(plpy.PLPYException):
           self.subject.input_tbl_valid("  ", "unittest_module")
-        
+
     def test_input_tbl_valid_table_not_exists_raises(self):
         self.subject.table_exists = Mock(return_value=False)
         with self.assertRaises(plpy.PLPYException):
@@ -113,5 +113,22 @@ class ValidateArgsTestCase(unittest.TestCase):
             self.subject.input_tbl_valid("foo", "unittest_module")
         self.assertNotIn('custom exception', str(error.exception))
 
+    def test_get_algorithm_name(self):
+        self.assertEqual('abc', self.subject.get_algorithm_name(
+            'abc', 'aaa', ['aaa','abc','bcd'], 'qwerty'))
+        self.assertEqual('aaa', self.subject.get_algorithm_name(
+            'AaA', 'aaa', ['aaa','abc','bcd'], 'qwerty'))  # case-insensitive
+        self.assertEqual('aaa', self.subject.get_algorithm_name(
+            'aa', 'aaa', ['aaa','abc','bcd'], 'qwerty'))
+        self.assertEqual('bcd', self.subject.get_algorithm_name(
+            'bc', 'aaa', ['aaa','abc','bcd'], 'qwerty'))
+        self.assertEqual('bcd', self.subject.get_algorithm_name(None, 'bcd', [], 'q'))
+        # If two options satisfy the given selection,
+        # pick the first one from the list
+        self.assertEqual('aaa', self.subject.get_algorithm_name(
+            'a', 'abc', ['aaa','abc','bcd'], 'qwerty'))
+        self.assertEqual('aqq', self.subject.get_algorithm_name(
+            'a', 'abc', ['aqq','abc','bcd'], 'qwerty'))
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in
index e0758d3..ea4d133 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -734,6 +734,26 @@ def does_exclude_reserved(targets, reserved):
                        ', '.join(intersect)))
 # -------------------------------------------------------------------------
 
+
+def get_algorithm_name(algorithm, default, supported_algorithms, module):
+    """Resolve `algorithm`, a case-insensitive prefix, to an entry of
+    supported_algorithms; return `default` when it is NULL or empty."""
+    if not algorithm:
+        return default
+    # Prefix match on the lower-cased input. If several supported names
+    # share the prefix, the first one in supported_algorithms wins.
+    prefix = algorithm.lower()
+    try:
+        return next(x for x in supported_algorithms if x.startswith(prefix))
+    except StopIteration:
+        # next() raises StopIteration when nothing matches; report the
+        # value as the caller passed it, not the lower-cased copy.
+        plpy.error("{0} Error: Invalid algorithm: "
+                   "{1}. Supported algorithms are ({2})"
+                   .format(module, algorithm,
+                           ','.join(sorted(supported_algorithms))))
+    return algorithm
+
 import unittest
 
 
@@ -749,6 +769,5 @@ class TestValidateFunctions(unittest.TestCase):
         self.assertEqual('Test123', unquote_ident('"Test123"'))
         self.assertEqual('test', unquote_ident('"test"'))
 
-
 if __name__ == '__main__':
     unittest.main()