Posted to commits@madlib.apache.org by ri...@apache.org on 2018/02/09 04:07:56 UTC

[1/2] madlib git commit: KNN: Add weighted averaging/voting by distance

Repository: madlib
Updated Branches:
  refs/heads/master c51da40a1 -> 7c6fea20b


KNN: Add weighted averaging/voting by distance

JIRA: MADLIB-1181

This commit adds an option for weighting the voting (for classification) or
averaging (for regression) by the inverse of the distance between a test
point and its nearest neighbors.
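
A minimal standalone Python sketch of the idea (illustrative only, not
MADlib code; the function name and data layout are assumptions, and the
real implementation runs as SQL over an interim table of
test-point/neighbor distances):

    def weighted_knn_predict(neighbors, classification=True):
        """Predict for one test point from (distance, label) pairs of its
        k nearest neighbors, weighting each neighbor by 1/distance so
        that closer neighbors contribute more."""
        if classification:
            # Sum 1/distance per label; predict the label with max sum.
            weights = {}
            for dist, label in neighbors:
                weights[label] = weights.get(label, 0.0) + 1.0 / dist
            return max(weights, key=weights.get)
        # Regression: inverse-distance weighted average of the labels.
        num = sum(label / dist for dist, label in neighbors)
        den = sum(1.0 / dist for dist, label in neighbors)
        return num / den

    # Three nearest neighbors of one test point, as (distance, label):
    print(weighted_knn_predict([(0.5, 1), (2.0, 0), (3.0, 0)]))  # -> 1
    print(weighted_knn_predict([(0.5, 1.0), (2.0, 0.0), (3.0, 0.0)],
                               classification=False))  # -> ~0.71

(A distance of exactly zero would divide by zero here; the follow-up
commit [2/2] below addresses that case.)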


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/f6aba420
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/f6aba420
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/f6aba420

Branch: refs/heads/master
Commit: f6aba42083c1017bb2092397cb34c20a8cd091e2
Parents: c51da40
Author: Himanshu Pandey <hp...@pivotal.io>
Authored: Tue Jan 16 09:32:41 2018 -0800
Committer: Rahul Iyer <ri...@apache.org>
Committed: Thu Feb 8 20:00:58 2018 -0800

----------------------------------------------------------------------
 src/ports/postgres/modules/knn/knn.py_in       | 92 +++++++++++++++++----
 src/ports/postgres/modules/knn/knn.sql_in      | 75 +++++++++++++++--
 src/ports/postgres/modules/knn/test/knn.sql_in | 18 +++-
 3 files changed, 161 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/f6aba420/src/ports/postgres/modules/knn/knn.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index da67952..b9e7916 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -111,7 +111,7 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
 
 def knn(schema_madlib, point_source, point_column_name, point_id,
         label_column_name, test_source, test_column_name, test_id, output_table,
-        k, output_neighbors, fn_dist):
+        k, output_neighbors, fn_dist, weighted_avg):
     """
         KNN function to find the K Nearest neighbours
         Args:
@@ -142,6 +142,8 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                                         dist_angle , dist_tanimoto
                                         Or user defined function with signature
                                         DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION
+            @param weighted_avg         Calculates the regression or classification of k-NN using
+                                        the weighted average method.
     """
     with MinWarning('warning'):
         k_val = knn_validate_src(schema_madlib, point_source,
@@ -167,6 +169,9 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
         knn_neighbors = ""
         label_out = ""
         cast_to_int = ""
+        view_def = ""
+        view_join = ""
+        view_grp_by = ""
 
         if output_neighbors:
             knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
@@ -178,11 +183,42 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
             if label_column_type in ['boolean', 'integer', 'text']:
                 is_classification = True
                 cast_to_int = '::INTEGER'
+            if weighted_avg:
+                pred_out = ",sum( {label_col_temp} * 1/dist)/sum(1/dist)".format(
+                    label_col_temp=label_col_temp)
+            else:
+                pred_out = ", avg({label_col_temp})".format(
+                    label_col_temp=label_col_temp)
 
-            pred_out = ", avg({label_col_temp})".format(**locals())
             if is_classification:
-                pred_out = (", {schema_madlib}.mode({label_col_temp})"
-                            ).format(**locals())
+                if weighted_avg:
+                    # This view calculates, for each test id, the sum of
+                    # 1/distance grouped by label; the label with the max
+                    # sum is the prediction of the classification model.
+                    view_def = (""" WITH vw
+                                    AS (SELECT distinct on ({test_id_temp}) {test_id_temp},
+                                    max(data_sum) data_dist,
+                                    {label_col_temp}
+                                    FROM   (SELECT {test_id_temp},
+                                    {label_col_temp},
+                                    sum(1 / dist) data_sum
+                                    FROM   pg_temp.{interim_table}
+                                    GROUP  BY {test_id_temp},
+                                    {label_col_temp}) a
+                                    GROUP  BY {test_id_temp} , {label_col_temp})""").format(**locals())
+                    # This join is needed to get the prediction (the label
+                    # with the max weight) calculated above
+                    view_join = (" JOIN vw AS knn_vw "
+                                 "ON knn_temp.{test_id_temp} = knn_vw.{test_id_temp}").format(
+                        test_id_temp=test_id_temp)
+                    view_grp_by = ", knn_vw.{label_col_temp}".format(
+                        label_col_temp=label_col_temp)
+                    pred_out = ", knn_vw.{label_col_temp}".format(
+                        label_col_temp=label_col_temp)
+                else:
+                    pred_out = (", {schema_madlib}.mode({label_col_temp})"
+                                ).format(**locals())
+
             pred_out += " AS prediction"
             label_out = (", train.{label_column_name}{cast_to_int}"
                          " AS {label_col_temp}").format(**locals())
@@ -212,22 +248,27 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
             WHERE {y_temp_table}.r <= {k_val}
             """.format(**locals()))
 
-        plpy.execute(
-            """
+        plpy.execute("""
             CREATE TABLE {output_table} AS
-                SELECT {test_id_temp} AS id, {test_column_name}
+                {view_def}
+                SELECT knn_temp.{test_id_temp} AS id ,
+                    knn_test.data
                     {pred_out}
                     {knn_neighbors}
-                FROM pg_temp.{interim_table} AS knn_temp
-                    JOIN
-                    {test_source} AS knn_test ON
-                    knn_temp.{test_id_temp} = knn_test.{test_id}
-                GROUP BY {test_id_temp} , {test_column_name}
+                FROM   {interim_table}  AS knn_temp
+                JOIN {test_source} AS knn_test
+                ON knn_temp.{test_id_temp} = knn_test.id
+                {view_join}
+                GROUP  BY knn_temp.{test_id_temp},
+                        knn_test.{test_column_name}
+                        {view_grp_by}
+                    ORDER  BY knn_temp.{test_id_temp}
             """.format(**locals()))
 
         plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
         return
 
+
 def knn_help(schema_madlib, message, **kwargs):
     """
     Help function for knn
@@ -258,6 +299,7 @@ SELECT {schema_madlib}.knn(
     k,                  -- value of k. Default will go as 1
     output_neighbors    -- Outputs the list of k-nearest neighbors that were used in the voting/averaging.
     fn_dist             -- The name of the function to use to calculate the distance from a data point to a centroid.
+    weighted_avg        -- Calculates the regression or classification of k-NN using the weighted average method.
     );
 
 -----------------------------------------------------------------------
@@ -340,7 +382,8 @@ SELECT * FROM {schema_madlib}.knn(
                 'knn_result_classification',  -- Output table
                  3,                    -- Number of nearest neighbors
                  True,                 -- True to list nearest-neighbors by id
-                 'madlib.squared_dist_norm2' -- Distance function
+                 'madlib.squared_dist_norm2', -- Distance function
+                 False                        -- False for not using weighted average
                 );
 SELECT * from knn_result_classification ORDER BY id;
 
@@ -360,7 +403,8 @@ SELECT * FROM {schema_madlib}.knn(
                 'knn_result_regression',  -- Output table
                  3,                    -- Number of nearest neighbors
                 True,                  -- True to list nearest-neighbors by id
-                'madlib.dist_norm2'    -- Distance function
+                'madlib.dist_norm2',   -- Distance function
+                False                  -- False for not using weighted average
                 );
 SELECT * FROM knn_result_regression ORDER BY id;
 
@@ -379,6 +423,25 @@ SELECT * FROM {schema_madlib}.knn(
                 3                      -- Number of nearest neighbors
                 );
 SELECT * FROM knn_result_list_neighbors ORDER BY id;
+
+--  Run KNN for classification using weighted average:
+DROP TABLE IF EXISTS knn_result_classification;
+SELECT * FROM {schema_madlib}.knn(
+                'knn_train_data',      -- Table of training data
+                'data',                -- Col name of training data
+                'id',                  -- Col name of id in train data
+                'label',               -- Training labels
+                'knn_test_data',       -- Table of test data
+                'data',                -- Col name of test data
+                'id',                  -- Col name of id in test data
+                'knn_result_classification',  -- Output table
+                 3,                    -- Number of nearest neighbors
+                 True,                 -- True to list nearest-neighbors by id
+                 'madlib.squared_dist_norm2', -- Distance function
+                 True                         -- Calculation using weighted average
+                );
+SELECT * from knn_result_classification ORDER BY id;
+
 """
         else:
             help_string = """
@@ -405,4 +468,3 @@ SELECT {schema_madlib}.knn('example')
 
     return help_string.format(schema_madlib=schema_madlib)
 # ------------------------------------------------------------------------------
-

http://git-wip-us.apache.org/repos/asf/madlib/blob/f6aba420/src/ports/postgres/modules/knn/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in
index 9db702d..3139c15 100644
--- a/src/ports/postgres/modules/knn/knn.sql_in
+++ b/src/ports/postgres/modules/knn/knn.sql_in
@@ -79,7 +79,8 @@ knn( point_source,
      output_table,
      k,
      output_neighbors,
-     fn_dist
+     fn_dist,
+     weighted_avg
    )
 </pre>
 
@@ -145,6 +146,10 @@ The following distance functions can be used:
 <li><b>\ref dist_tanimoto</b>: tanimoto</li>
 <li><b>user defined function</b> with signature <tt>DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION</tt></li></ul></dd>
 
+<dt>weighted_avg (optional)</dt>
+<dd>BOOLEAN, default: FALSE. Calculates the regression or classification
+of k-NN using the weighted average method.</dd>
+
 </dl>
 
 
@@ -326,6 +331,39 @@ Result, with neighbors sorted from closest to furthest:
 (6 rows)
 </pre>
 
+
+-#   Run KNN for classification using the 
+weighted average:
+<pre class="example">
+DROP TABLE IF EXISTS knn_result_classification;
+SELECT * FROM madlib.knn(
+                'knn_train_data',      -- Table of training data
+                'data',                -- Col name of training data
+                'id',                  -- Col name of id in train data
+                'label',               -- Training labels
+                'knn_test_data',       -- Table of test data
+                'data',                -- Col name of test data
+                'id',                  -- Col name of id in test data
+                'knn_result_classification',  -- Output table
+                 3,                    -- Number of nearest neighbors
+                 True,                 -- True to list nearest-neighbors by id
+                 'madlib.squared_dist_norm2', -- Distance function
+                 True                 -- For weighted average
+                );
+SELECT * FROM knn_result_classification ORDER BY id;
+</pre>
+<pre class="result">
+ id |  data   |     prediction      | k_nearest_neighbours 
+----+---------+---------------------+----------------------
+  1 | {2,1}   |                 1   | {2,1,3}
+  2 | {2,6}   |                 1   | {5,4,3}
+  3 | {15,40} |                 0   | {7,6,5}
+  4 | {12,1}  |                 1   | {4,5,3}
+  5 | {2,90}  |                 0   | {9,6,7}
+  6 | {50,45} |                 0   | {6,7,8}
+(6 rows)
+</pre>
+
 @anchor background
 @par Technical Background
 
@@ -397,7 +435,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
     output_table VARCHAR,
     k INTEGER,
     output_neighbors BOOLEAN,
-    fn_dist TEXT
+    fn_dist TEXT, 
+    weighted_avg BOOLEAN
 ) RETURNS VARCHAR AS $$
     PythonFunctionBodyOnly(`knn', `knn')
     return knn.knn(
@@ -412,7 +451,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
         output_table,
         k,
         output_neighbors,
-        fn_dist
+        fn_dist,
+        weighted_avg
     )
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -428,13 +468,36 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
     test_id VARCHAR,
     output_table VARCHAR,
     k INTEGER,
+    output_neighbors BOOLEAN,
+    fn_dist TEXT
+) RETURNS VARCHAR AS $$
+    DECLARE
+    returnstring VARCHAR;
+BEGIN
+    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11, FALSE);
+    RETURN returnstring;
+END;
+$$ LANGUAGE plpgsql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
+    point_source VARCHAR,
+    point_column_name VARCHAR,
+    point_id VARCHAR,
+    label_column_name VARCHAR,
+    test_source VARCHAR,
+    test_column_name VARCHAR,
+    test_id VARCHAR,
+    output_table VARCHAR,
+    k INTEGER,
     output_neighbors BOOLEAN
 ) RETURNS VARCHAR AS $$
 DECLARE
     returnstring VARCHAR;
 BEGIN
     returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,
-                                     'MADLIB_SCHEMA.squared_dist_norm2');
+                                     'MADLIB_SCHEMA.squared_dist_norm2', FALSE);
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE
@@ -455,7 +518,7 @@ DECLARE
     returnstring VARCHAR;
 BEGIN
     returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,TRUE,
-                                     'MADLIB_SCHEMA.squared_dist_norm2');
+                                     'MADLIB_SCHEMA.squared_dist_norm2', FALSE);
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE
@@ -475,7 +538,7 @@ DECLARE
     returnstring VARCHAR;
 BEGIN
     returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,TRUE,
-                                     'MADLIB_SCHEMA.squared_dist_norm2');
+                                     'MADLIB_SCHEMA.squared_dist_norm2',FALSE);
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE

http://git-wip-us.apache.org/repos/asf/madlib/blob/f6aba420/src/ports/postgres/modules/knn/test/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in
index 8c62dad..6a39c89 100644
--- a/src/ports/postgres/modules/knn/test/knn.sql_in
+++ b/src/ports/postgres/modules/knn/test/knn.sql_in
@@ -72,14 +72,15 @@ copy knn_test_data (id, data) from stdin delimiter '|';
 \.
 
 drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2');
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
 select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
 
 drop table if exists madlib_knn_result_classification;
 select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3);
 select assert(array_agg(x)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_classification where id = 1 order by x asc) y;
+
 drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.squared_dist_norm2');
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
 select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
 
 drop table if exists madlib_knn_result_regression;
@@ -87,7 +88,7 @@ select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id',
 select assert(array_agg(x)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_regression where id = 1 order by x asc) y;
 
 drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL);
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL,False);
 select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
 
 drop table if exists madlib_knn_result_classification;
@@ -110,5 +111,16 @@ drop table if exists madlib_knn_result_regression;
 select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.dist_angle');
 select assert(array_agg(prediction order by id)='{0.75,0.25,0.25,0.75,0.25,1}', 'Wrong output in regression') from madlib_knn_result_regression;
 
+
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+select assert(array_agg(prediction::numeric order by id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
 select knn();
 select knn('help');


[2/2] madlib git commit: KNN: Add window for distinct on + other changes

Posted by ri...@apache.org.
KNN: Add window for distinct on + other changes

This commit includes a few fixes:
1. 'DISTINCT ON' requires a window with a partition to pick the row
corresponding to each distinct value. This was added to compute the label
with the highest weighted votes.
2. Zero distances led to divide-by-zero issues during the weighting
process. This was fixed by replacing the inverse of the distance with a
pre-defined large number, as sketched below.
3. Other changes include improved error messages and code cleanup.

Co-authored-by: Nandish Jayaram <nj...@apache.org>
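
A minimal Python sketch of fixes 1 and 2 (illustrative only; the
MAX_WEIGHT_ZERO_DIST constant matches the diff below, everything else
is assumed):

    MAX_WEIGHT_ZERO_DIST = 1e6

    def dist_inverse(dist):
        # A zero distance (exact match) would divide by zero, so its
        # weight is replaced with a pre-defined large number.
        return MAX_WEIGHT_ZERO_DIST if dist == 0.0 else 1.0 / dist

    def weighted_vote(neighbors):
        """Return the label with the highest summed weight for one test
        point, given (distance, label) pairs of its nearest neighbors.
        Taking the maximum over (weight, label) pairs mirrors the SQL's
        DISTINCT ON plus last_value(...) OVER (PARTITION BY id
        ORDER BY data_sum, label) window: ordering by label as well
        breaks weight ties deterministically."""
        weights = {}
        for dist, label in neighbors:
            weights[label] = weights.get(label, 0.0) + dist_inverse(dist)
        return max(weights.items(), key=lambda kv: (kv[1], kv[0]))[0]

    # An exact-match (zero-distance) neighbor dominates the vote:
    print(weighted_vote([(0.0, 1), (2.0, 0), (4.0, 0)]))  # -> 1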


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/7c6fea20
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/7c6fea20
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/7c6fea20

Branch: refs/heads/master
Commit: 7c6fea20bac49ee75d1c68cef5f1248e6bcc78b1
Parents: f6aba42
Author: Rahul Iyer <ri...@apache.org>
Authored: Thu Feb 8 20:02:09 2018 -0800
Committer: Rahul Iyer <ri...@apache.org>
Committed: Thu Feb 8 20:02:09 2018 -0800

----------------------------------------------------------------------
 src/ports/postgres/modules/knn/knn.py_in | 216 +++++++++++++++-----------
 1 file changed, 121 insertions(+), 95 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/7c6fea20/src/ports/postgres/modules/knn/knn.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index b9e7916..cfd93d9 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -20,7 +20,7 @@
 """
 @file knn.py_in
 
-@brief knn: Driver functions
+@brief knn: K-Nearest Neighbors for regression and classification
 
 @namespace knn
 
@@ -32,9 +32,12 @@ from utilities.validate_args import cols_in_tbl_valid
 from utilities.validate_args import is_col_array
 from utilities.validate_args import array_col_has_no_null
 from utilities.validate_args import get_expr_type
+from utilities.utilities import _assert
 from utilities.utilities import unique_string
 from utilities.control import MinWarning
 
+MAX_WEIGHT_ZERO_DIST = 1e6
+
 
 def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
                      label_column_name, test_source, test_column_name,
@@ -43,21 +46,22 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
     input_tbl_valid(point_source, 'kNN')
     input_tbl_valid(test_source, 'kNN')
     output_tbl_valid(output_table, 'kNN')
-    if label_column_name is not None and label_column_name != '':
-        cols_in_tbl_valid(
-            point_source,
-            (label_column_name,
-             point_column_name),
-            'kNN')
+
+    _assert(label_column_name or output_neighbors,
+            "kNN error: Either label_column_name or "
+            "output_neighbors has to be provided.")
+
+    if label_column_name and label_column_name.strip():
+        cols_in_tbl_valid(point_source, [label_column_name], 'kNN')
     cols_in_tbl_valid(point_source, (point_column_name, point_id), 'kNN')
     cols_in_tbl_valid(test_source, (test_column_name, test_id), 'kNN')
 
     if not is_col_array(point_source, point_column_name):
         plpy.error("kNN Error: Feature column '{0}' in train table is not"
-                   " an array.").format(point_column_name)
+                   " an array.".format(point_column_name))
     if not is_col_array(test_source, test_column_name):
         plpy.error("kNN Error: Feature column '{0}' in test table is not"
-                   " an array.").format(test_column_name)
+                   " an array.".format(test_column_name))
 
     if not array_col_has_no_null(point_source, point_column_name):
         plpy.error("kNN Error: Feature column '{0}' in train table has some"
@@ -66,44 +70,46 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
         plpy.error("kNN Error: Feature column '{0}' in test table has some"
                    " NULL values.".format(test_column_name))
 
-    if k is None:
-        k = 1
     if k <= 0:
-        plpy.error("kNN Error: k={0} is an invalid value, must be greater"
+        plpy.error("kNN Error: k={0} is an invalid value, must be greater "
                    "than 0.".format(k))
+
     bound = plpy.execute("SELECT {k} <= count(*) AS bound FROM {tbl}".
                          format(k=k, tbl=point_source))[0]['bound']
     if not bound:
         plpy.error("kNN Error: k={0} is greater than number of rows in"
                    " training table.".format(k))
 
-    if label_column_name is not None and label_column_name != '':
+    if label_column_name:
         col_type = get_expr_type(label_column_name, point_source).lower()
         if col_type not in ['integer', 'double precision', 'float', 'boolean']:
-            plpy.error("kNN error: Data type '{0}' is not a valid type for"
-                       " column '{1}' in table '{2}'.".
-                       format(col_type, label_column_name, point_source))
+            plpy.error("kNN error: Invalid data type '{0}' for"
+                       " label_column_name in table '{1}'.".
+                       format(col_type, point_source))
 
     col_type_test = get_expr_type(test_id, test_source).lower()
     if col_type_test not in ['integer']:
-        plpy.error("kNN Error: Data type '{0}' is not a valid type for"
-                   " column '{1}' in table '{2}'.".
-                   format(col_type_test, test_id, test_source))
+        plpy.error("kNN Error: Invalid data type '{0}' for"
+                   " test_id column in table '{1}'.".
+                   format(col_type_test, test_source))
 
     if fn_dist:
         fn_dist = fn_dist.lower().strip()
-        dist_functions = set([schema_madlib + dist for dist in
-                              ('.dist_norm1', '.dist_norm2', '.squared_dist_norm2', '.dist_angle', '.dist_tanimoto')])
-
-        is_invalid_func = plpy.execute(
-            """select prorettype != 'DOUBLE PRECISION'::regtype
-                OR proisagg = TRUE AS OUTPUT from pg_proc where
-                oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure;
-                """.format(**locals()))[0]['output']
-
-        if is_invalid_func or fn_dist not in dist_functions:
-            plpy.error(
-                "KNN error: Distance function has wrong signature or is not a simple function.")
+        dist_functions = set(["{0}.{1}".format(schema_madlib, dist) for dist in
+                              ('dist_norm1', 'dist_norm2',
+                               'squared_dist_norm2', 'dist_angle',
+                               'dist_tanimoto')])
+
+        is_invalid_func = plpy.execute("""
+            SELECT prorettype != 'DOUBLE PRECISION'::regtype OR
+                   proisagg = TRUE AS OUTPUT
+            FROM pg_proc
+            WHERE oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure;
+            """.format(fn_dist=fn_dist))[0]['output']
+
+        if is_invalid_func or (fn_dist not in dist_functions):
+            plpy.error("KNN error: Distance function has invalid signature "
+                       "or is not a simple function.")
 
     return k
 # ------------------------------------------------------------------------------
@@ -146,21 +152,21 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                                         the weighted average method.
     """
     with MinWarning('warning'):
-        k_val = knn_validate_src(schema_madlib, point_source,
-                                 point_column_name, point_id, label_column_name,
-                                 test_source, test_column_name, test_id,
-                                 output_table, k, output_neighbors, fn_dist)
+        output_neighbors = True if output_neighbors is None else output_neighbors
+        if k is None:
+            k = 1
+        knn_validate_src(schema_madlib, point_source,
+                         point_column_name, point_id, label_column_name,
+                         test_source, test_column_name, test_id,
+                         output_table, k, output_neighbors, fn_dist)
 
         x_temp_table = unique_string(desp='x_temp_table')
         y_temp_table = unique_string(desp='y_temp_table')
         label_col_temp = unique_string(desp='label_col_temp')
         test_id_temp = unique_string(desp='test_id_temp')
 
-        if output_neighbors is None:
-            output_neighbors = True
-
         if not fn_dist:
-            fn_dist = schema_madlib + '.squared_dist_norm2'
+            fn_dist = '{0}.squared_dist_norm2'.format(schema_madlib)
 
         fn_dist = fn_dist.lower().strip()
         interim_table = unique_string(desp='interim_table')
@@ -173,98 +179,121 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
         view_join = ""
         view_grp_by = ""
 
-        if output_neighbors:
-            knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
-                             "knn_temp.dist ASC) AS k_nearest_neighbours ")
         if label_column_name:
-            is_classification = False
             label_column_type = get_expr_type(
                 label_column_name, point_source).lower()
             if label_column_type in ['boolean', 'integer', 'text']:
                 is_classification = True
                 cast_to_int = '::INTEGER'
-            if weighted_avg:
-                pred_out = ",sum( {label_col_temp} * 1/dist)/sum(1/dist)".format(
-                    label_col_temp=label_col_temp)
             else:
-                pred_out = ", avg({label_col_temp})".format(
-                    label_col_temp=label_col_temp)
+                is_classification = False
 
             if is_classification:
                 if weighted_avg:
                     # This view calculates, for each test id, the sum of
                     # 1/distance grouped by label; the label with the max
                     # sum is the prediction of the classification model.
-                    view_def = (""" WITH vw
-                                    AS (SELECT distinct on ({test_id_temp}) {test_id_temp},
-                                    max(data_sum) data_dist,
-                                    {label_col_temp}
-                                    FROM   (SELECT {test_id_temp},
+                    view_def = """
+                        WITH vw AS (
+                            SELECT DISTINCT ON({test_id_temp})
+                                {test_id_temp},
+                                last_value(data_sum) OVER (
+                                    PARTITION BY {test_id_temp}
+                                    ORDER BY data_sum, {label_col_temp}
+                                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+                                    ) AS data_dist ,
+                                last_value({label_col_temp}) OVER (
+                                    PARTITION BY {test_id_temp}
+                                    ORDER BY data_sum, {label_col_temp}
+                                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+                                    ) AS {label_col_temp}
+                            FROM   (
+                                SELECT
+                                    {test_id_temp},
                                     {label_col_temp},
-                                    sum(1 / dist) data_sum
-                                    FROM   pg_temp.{interim_table}
-                                    GROUP  BY {test_id_temp},
-                                    {label_col_temp}) a
-                                    GROUP  BY {test_id_temp} , {label_col_temp})""").format(**locals())
+                                    sum(dist_inverse) data_sum
+                                FROM pg_temp.{interim_table}
+                                GROUP BY {test_id_temp},
+                                         {label_col_temp}
+                            ) a
+                            -- GROUP BY {test_id_temp} , {label_col_temp}
+                        )
+                        """.format(**locals())
                     # This join is needed to get the prediction (the label
                     # with the max weight) calculated above
-                    view_join = (" JOIN vw AS knn_vw "
-                                 "ON knn_temp.{test_id_temp} = knn_vw.{test_id_temp}").format(
-                        test_id_temp=test_id_temp)
-                    view_grp_by = ", knn_vw.{label_col_temp}".format(
-                        label_col_temp=label_col_temp)
-                    pred_out = ", knn_vw.{label_col_temp}".format(
-                        label_col_temp=label_col_temp)
+                    view_join = (" JOIN vw ON knn_temp.{0} = vw.{0}".
+                                 format(test_id_temp))
+                    view_grp_by = ", vw.{0}".format(label_col_temp)
+                    pred_out = ", vw.{0}".format(label_col_temp)
+                else:
+                    pred_out = ", {0}.mode({1})".format(schema_madlib, label_col_temp)
+            else:
+                if weighted_avg:
+                    pred_out = (", sum({0} * dist_inverse) / sum(dist_inverse)".
+                                format(label_col_temp))
                 else:
-                    pred_out = (", {schema_madlib}.mode({label_col_temp})"
-                                ).format(**locals())
+                    pred_out = ", avg({0})".format(label_col_temp)
 
             pred_out += " AS prediction"
             label_out = (", train.{label_column_name}{cast_to_int}"
                          " AS {label_col_temp}").format(**locals())
+            comma_label_out_alias = ', ' + label_col_temp
+        else:
+            pred_out = ""
+            label_out = ""
+            comma_label_out_alias = ""
 
-        if not label_column_name and not output_neighbors:
-
-            plpy.error("kNN error: Either label_column_name or "
-                       "output_neighbors has to be non-NULL.")
-
+        # interim_table picks the 'k' nearest neighbors for each test point
+        if output_neighbors:
+            knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
+                             "knn_temp.dist_inverse DESC) AS k_nearest_neighbours ")
+        else:
+            knn_neighbors = ''
         plpy.execute("""
             CREATE TEMP TABLE {interim_table} AS
                 SELECT * FROM (
                     SELECT row_number() over
                             (partition by {test_id_temp} order by dist) AS r,
-                            {x_temp_table}.*
+                            {test_id_temp},
+                            train_id,
+                            CASE WHEN dist = 0.0 THEN {max_weight_zero_dist}
+                                 ELSE 1.0 / dist
+                            END AS dist_inverse
+                            {comma_label_out_alias}
                     FROM (
-                        SELECT test.{test_id} AS {test_id_temp} ,
-                            train.{point_id} as train_id ,
+                        SELECT test.{test_id} AS {test_id_temp},
+                            train.{point_id} as train_id,
                             {fn_dist}(
                                 train.{point_column_name},
                                 test.{test_column_name})
-                            AS dist {label_out}
+                            AS dist
+                            {label_out}
                             FROM {point_source} AS train,
-                                {test_source} AS test
+                                 {test_source} AS test
                         ) {x_temp_table}
                     ) {y_temp_table}
-            WHERE {y_temp_table}.r <= {k_val}
-            """.format(**locals()))
+            WHERE {y_temp_table}.r <= {k}
+            """.format(max_weight_zero_dist=MAX_WEIGHT_ZERO_DIST, **locals()))
 
-        plpy.execute("""
+        sql = """
             CREATE TABLE {output_table} AS
                 {view_def}
-                SELECT knn_temp.{test_id_temp} AS id ,
-                    knn_test.data
+                SELECT
+                    knn_temp.{test_id_temp} AS id,
+                    knn_test.{test_column_name}
                     {pred_out}
                     {knn_neighbors}
-                FROM   {interim_table}  AS knn_temp
-                JOIN {test_source} AS knn_test
-                ON knn_temp.{test_id_temp} = knn_test.id
-                {view_join}
-                GROUP  BY knn_temp.{test_id_temp},
-                        knn_test.{test_column_name}
-                        {view_grp_by}
-                    ORDER  BY knn_temp.{test_id_temp}
-            """.format(**locals()))
-
+                FROM
+                    pg_temp.{interim_table}  AS knn_temp
+                    JOIN
+                    {test_source} AS knn_test
+                ON knn_temp.{test_id_temp} = knn_test.{test_id}
+                    {view_join}
+                GROUP BY knn_temp.{test_id_temp},
+                         knn_test.{test_column_name}
+                         {view_grp_by}
+            """
+        plpy.execute(sql.format(**locals()))
         plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
         return
 
@@ -394,9 +423,6 @@ to furthest from the corresponding test point.
 DROP TABLE IF EXISTS knn_result_regression;
 SELECT * FROM {schema_madlib}.knn(
                 'knn_train_data_reg',  -- Table of training data
-                'data',                -- Col name of training data
-                'id',                  -- Col Name of id in train data
-                'label',               -- Training labels
                 'knn_test_data',       -- Table of test data
                 'data',                -- Col name of test data
                 'id',                  -- Col name of id in test data