You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ri...@apache.org on 2017/11/30 21:36:38 UTC

madlib git commit: KNN: Add additional distance metrics

Repository: madlib
Updated Branches:
  refs/heads/master daf67f81b -> 5a291aa81


KNN: Add additional distance metrics

JIRA: MADLIB-1059


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/5a291aa8
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/5a291aa8
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/5a291aa8

Branch: refs/heads/master
Commit: 5a291aa81fa7b735b43fc0198b0eb2bf2955364b
Parents: daf67f8
Author: Himanshu Pandey <hp...@pivotal.io>
Authored: Thu Nov 30 13:35:54 2017 -0800
Committer: Rahul Iyer <ri...@apache.org>
Committed: Thu Nov 30 13:35:54 2017 -0800

----------------------------------------------------------------------
 src/ports/postgres/modules/knn/knn.py_in       | 42 ++++++++++---
 src/ports/postgres/modules/knn/knn.sql_in      | 37 +++++++++---
 src/ports/postgres/modules/knn/test/knn.sql_in | 67 +++++++++++++++++++--
 3 files changed, 126 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/5a291aa8/src/ports/postgres/modules/knn/knn.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index 0e21cdd..caa89e0 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -38,7 +38,7 @@ from utilities.control import MinWarning
 
 def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
                      label_column_name, test_source, test_column_name,
-                     test_id, output_table, k, output_neighbors, **kwargs):
+                     test_id, output_table, k, output_neighbors, fn_dist, **kwargs):
     input_tbl_valid(point_source, 'kNN')
     input_tbl_valid(test_source, 'kNN')
     output_tbl_valid(output_table, 'kNN')
@@ -88,12 +88,28 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
         plpy.error("kNN Error: Data type '{0}' is not a valid type for"
                    " column '{1}' in table '{2}'.".
                    format(col_type_test, test_id, test_source))
+
+    if fn_dist:
+        fn_dist = fn_dist.lower().strip()
+        dist_functions = set([schema_madlib + dist for dist in
+                              ('.dist_norm1', '.dist_norm2', '.squared_dist_norm2', '.dist_angle', '.dist_tanimoto')])
+
+        is_invalid_func = plpy.execute(
+            """select prorettype != 'DOUBLE PRECISION'::regtype
+                OR proisagg = TRUE AS OUTPUT from pg_proc where
+                oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure;
+                """.format(**locals()))[0]['output']
+
+        if is_invalid_func or fn_dist not in dist_functions:
+            plpy.error(
+                "KNN error: Distance function has wrong signature or is not a simple function.")
+
     return k
 # ------------------------------------------------------------------------------
 
 
 def knn(schema_madlib, point_source, point_column_name, point_id, label_column_name,
-        test_source, test_column_name, test_id, output_table, k, output_neighbors):
+        test_source, test_column_name, test_id, output_table, k, output_neighbors, fn_dist):
     """
         KNN function to find the K Nearest neighbours
         Args:
@@ -117,12 +133,19 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
                                         neighbors to consider
             @output_neighbours          Outputs the list of k-nearest neighbors
                                         that were used in the voting/averaging.
+            @param fn_dist              Distance metrics function. Default is
+                                        squared_dist_norm2. Following functions
+                                        are supported :
+                                        dist_norm1, dist_norm2, squared_dist_norm2,
+                                        dist_angle, dist_tanimoto
+                                        Or user defined function with signature
+                                        DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION
     """
     with MinWarning('warning'):
         k_val = knn_validate_src(schema_madlib, point_source,
                                  point_column_name, point_id, label_column_name,
                                  test_source, test_column_name, test_id,
-                                 output_table, k, output_neighbors)
+                                 output_table, k, output_neighbors, fn_dist)
 
         x_temp_table = unique_string(desp='x_temp_table')
         y_temp_table = unique_string(desp='y_temp_table')
@@ -132,6 +155,10 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
         if output_neighbors is None:
             output_neighbors = True
 
+        if not fn_dist:
+            fn_dist = schema_madlib + '.squared_dist_norm2'
+
+        fn_dist = fn_dist.lower().strip()
         interim_table = unique_string(desp='interim_table')
 
         pred_out = ""
@@ -141,7 +168,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
 
         if output_neighbors:
             knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
-                                "knn_temp.dist ASC) AS k_nearest_neighbours ")
+                             "knn_temp.dist ASC) AS k_nearest_neighbours ")
         if label_column_name:
             is_classification = False
             label_column_type = get_expr_type(
@@ -156,12 +183,12 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
                             ).format(**locals())
             pred_out += " AS prediction"
             label_out = (", train.{label_column_name}{cast_to_int}"
-                            " AS {label_col_temp}").format(**locals())
+                         " AS {label_col_temp}").format(**locals())
 
         if not label_column_name and not output_neighbors:
 
             plpy.error("kNN error: Either label_column_name or "
-                   "output_neighbors has to be non-NULL.")
+                       "output_neighbors has to be non-NULL.")
 
         plpy.execute("""
             CREATE TEMP TABLE {interim_table} AS
@@ -172,7 +199,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
                     FROM (
                         SELECT test.{test_id} AS {test_id_temp} ,
                             train.{point_id} as train_id ,
-                            {schema_madlib}.squared_dist_norm2(
+                            {fn_dist}(
                                 train.{point_column_name},
                                 test.{test_column_name})
                             AS dist {label_out}
@@ -196,7 +223,6 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
                 GROUP BY {test_id_temp} , {test_column_name}
             """.format(**locals()))
 
-
         plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
         return
 # ------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/madlib/blob/5a291aa8/src/ports/postgres/modules/knn/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in
index befb768..17d81ad 100644
--- a/src/ports/postgres/modules/knn/knn.sql_in
+++ b/src/ports/postgres/modules/knn/knn.sql_in
@@ -78,7 +78,8 @@ knn( point_source,
      test_id,
      output_table,
      k,
-     output_neighbors
+     output_neighbors,
+     fn_dist
    )
 </pre>
 
@@ -131,6 +132,22 @@ otherwise the result may depend on ordering of the input data.</dd>
 neighbors that were used in the voting/averaging, sorted
 from closest to furthest.</dd>
 
+<dt>fn_dist (optional)</dt>
+<dd>TEXT, default: 'squared_dist_norm2'. The name of the function to use to calculate the distance between two data points.
+
+The following distance functions can be used:
+<ul>
+<li><b>\ref dist_norm1</b>: 1-norm/Manhattan</li>
+<li><b>\ref dist_norm2</b>: 2-norm/Euclidean</li>
+<li><b>\ref squared_dist_norm2</b>: squared Euclidean distance</li>
+<li><b>\ref dist_angle</b>: angle</li>
+<li><b>\ref dist_tanimoto</b>: tanimoto</li>
+<li><b>user defined function</b> with signature <tt>DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION</tt></li></ul></dd>
+
+
+
 </dl>
 
 
@@ -227,6 +244,7 @@ SELECT * FROM madlib.knn(
                 'knn_result_classification',  -- Output table
                  3,                    -- Number of nearest neighbors
                  True                  -- True if you want to show Nearest-Neighbors by id, False otherwise
+                 ,'madlib.squared_dist_norm2' -- Distance function
                 );
 SELECT * from knn_result_classification ORDER BY id;
 </pre>
@@ -258,6 +276,7 @@ SELECT * FROM madlib.knn(
                 'knn_result_regression',  -- Output table
                  3,                    -- Number of nearest neighbors
                 True                   -- True if you want to show Nearest-Neighbors, False otherwise
+                ,'madlib.squared_dist_norm2' -- Distance function
                 );
 SELECT * FROM knn_result_regression ORDER BY id;
 </pre>
@@ -388,6 +407,7 @@ SELECT {schema_madlib}.knn(
     output_table,       -- Name of output table
     k,                  -- value of k. Default will go as 1
     output_neighbors    -- Outputs the list of k-nearest neighbors that were used in the voting/averaging.
+    ,fn_dist            -- The name of the function to use to calculate the distance between two data points.
     );
 
 -----------------------------------------------------------------------
@@ -435,7 +455,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
     test_id VARCHAR,
     output_table VARCHAR,
     k INTEGER,
-    output_neighbors Boolean
+    output_neighbors Boolean,
+    fn_dist TEXT
 ) RETURNS VARCHAR AS $$
     PythonFunctionBodyOnly(`knn', `knn')
     return knn.knn(
@@ -449,7 +470,9 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
         test_id,
         output_table,
         k,
-        output_neighbors
+        output_neighbors,
+        fn_dist
+
     )
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -469,7 +492,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
 DECLARE
     returnstring VARCHAR;
 BEGIN
-    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,$9);
+    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,$9, 'MADLIB_SCHEMA.squared_dist_norm2');
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE
@@ -489,7 +512,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
 DECLARE
     returnstring VARCHAR;
 BEGIN
-    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,TRUE);
+    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,TRUE,'MADLIB_SCHEMA.squared_dist_norm2');
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE
@@ -508,8 +531,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
 DECLARE
     returnstring VARCHAR;
 BEGIN
-    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,TRUE);
+    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,TRUE,'MADLIB_SCHEMA.squared_dist_norm2');
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE
-m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); 

http://git-wip-us.apache.org/repos/asf/madlib/blob/5a291aa8/src/ports/postgres/modules/knn/test/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in
index 1e71a0e..8bb8f20 100644
--- a/src/ports/postgres/modules/knn/test/knn.sql_in
+++ b/src/ports/postgres/modules/knn/test/knn.sql_in
@@ -73,24 +73,81 @@ copy knn_test_data (id, data) from stdin delimiter '|';
 \.
 
 drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False);
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.squared_dist_norm2');
 select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
 
 drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,True);
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,True,'madlib.squared_dist_norm2');
 select assert(array_agg(x)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_classification where id = 1 order by x asc) y;
 
 drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False);
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.squared_dist_norm2');
 select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
 
 drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True);
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True,'madlib.squared_dist_norm2');
 select assert(array_agg(x)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_regression where id = 1 order by x asc) y;
 
 drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',False);
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.squared_dist_norm2');
 select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
 
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL);
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True, NULL );
+select assert(array_agg(x)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_regression where id = 1 order by x asc) y;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,NULL);
+select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_norm1');
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_norm2');
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_angle');
+select assert(array_agg(prediction order by id)='{1,0,0,1,0,1}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_tanimoto');
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_norm1');
+select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_norm2');
+select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_angle');
+select assert(array_agg(prediction order by id)='{0.75,0.25,0.25,0.75,0.25,1}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_tanimoto');
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
+
 select knn();
 select knn('help');