You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ri...@apache.org on 2017/11/30 21:36:38 UTC
madlib git commit: KNN: Add additional distance metrics
Repository: madlib
Updated Branches:
refs/heads/master daf67f81b -> 5a291aa81
KNN: Add additional distance metrics
JIRA: MADLIB-1059
Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/5a291aa8
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/5a291aa8
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/5a291aa8
Branch: refs/heads/master
Commit: 5a291aa81fa7b735b43fc0198b0eb2bf2955364b
Parents: daf67f8
Author: Himanshu Pandey <hp...@pivotal.io>
Authored: Thu Nov 30 13:35:54 2017 -0800
Committer: Rahul Iyer <ri...@apache.org>
Committed: Thu Nov 30 13:35:54 2017 -0800
----------------------------------------------------------------------
src/ports/postgres/modules/knn/knn.py_in | 42 ++++++++++---
src/ports/postgres/modules/knn/knn.sql_in | 37 +++++++++---
src/ports/postgres/modules/knn/test/knn.sql_in | 67 +++++++++++++++++++--
3 files changed, 126 insertions(+), 20 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/5a291aa8/src/ports/postgres/modules/knn/knn.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index 0e21cdd..caa89e0 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -38,7 +38,7 @@ from utilities.control import MinWarning
def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
label_column_name, test_source, test_column_name,
- test_id, output_table, k, output_neighbors, **kwargs):
+ test_id, output_table, k, output_neighbors, fn_dist, **kwargs):
input_tbl_valid(point_source, 'kNN')
input_tbl_valid(test_source, 'kNN')
output_tbl_valid(output_table, 'kNN')
@@ -88,12 +88,28 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
plpy.error("kNN Error: Data type '{0}' is not a valid type for"
" column '{1}' in table '{2}'.".
format(col_type_test, test_id, test_source))
+
+ if fn_dist:
+ fn_dist = fn_dist.lower().strip()
+ dist_functions = set([schema_madlib + dist for dist in
+ ('.dist_norm1', '.dist_norm2', '.squared_dist_norm2', '.dist_angle', '.dist_tanimoto')])
+
+ is_invalid_func = plpy.execute(
+ """select prorettype != 'DOUBLE PRECISION'::regtype
+ OR proisagg = TRUE AS OUTPUT from pg_proc where
+ oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure;
+ """.format(**locals()))[0]['output']
+
+ if is_invalid_func or fn_dist not in dist_functions:
+ plpy.error(
+ "KNN error: Distance function has wrong signature or is not a simple function.")
+
return k
# ------------------------------------------------------------------------------
def knn(schema_madlib, point_source, point_column_name, point_id, label_column_name,
- test_source, test_column_name, test_id, output_table, k, output_neighbors):
+ test_source, test_column_name, test_id, output_table, k, output_neighbors, fn_dist):
"""
KNN function to find the K Nearest neighbours
Args:
@@ -117,12 +133,19 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
neighbors to consider
@output_neighbours Outputs the list of k-nearest neighbors
that were used in the voting/averaging.
+ @param fn_dist Distance metrics function. Default is
+ squared_dist_norm2. Following functions
+ are supported :
+ dist_norm1 , dist_norm2,squared_dist_norm2,
+ dist_angle , dist_tanimoto
+ Or user defined function with signature
+ DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION
"""
with MinWarning('warning'):
k_val = knn_validate_src(schema_madlib, point_source,
point_column_name, point_id, label_column_name,
test_source, test_column_name, test_id,
- output_table, k, output_neighbors)
+ output_table, k, output_neighbors, fn_dist)
x_temp_table = unique_string(desp='x_temp_table')
y_temp_table = unique_string(desp='y_temp_table')
@@ -132,6 +155,10 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
if output_neighbors is None:
output_neighbors = True
+ if not fn_dist:
+ fn_dist = schema_madlib + '.squared_dist_norm2'
+
+ fn_dist = fn_dist.lower().strip()
interim_table = unique_string(desp='interim_table')
pred_out = ""
@@ -141,7 +168,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
if output_neighbors:
knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
- "knn_temp.dist ASC) AS k_nearest_neighbours ")
+ "knn_temp.dist ASC) AS k_nearest_neighbours ")
if label_column_name:
is_classification = False
label_column_type = get_expr_type(
@@ -156,12 +183,12 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
).format(**locals())
pred_out += " AS prediction"
label_out = (", train.{label_column_name}{cast_to_int}"
- " AS {label_col_temp}").format(**locals())
+ " AS {label_col_temp}").format(**locals())
if not label_column_name and not output_neighbors:
plpy.error("kNN error: Either label_column_name or "
- "output_neighbors has to be non-NULL.")
+ "output_neighbors has to be non-NULL.")
plpy.execute("""
CREATE TEMP TABLE {interim_table} AS
@@ -172,7 +199,7 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
FROM (
SELECT test.{test_id} AS {test_id_temp} ,
train.{point_id} as train_id ,
- {schema_madlib}.squared_dist_norm2(
+ {fn_dist}(
train.{point_column_name},
test.{test_column_name})
AS dist {label_out}
@@ -196,7 +223,6 @@ def knn(schema_madlib, point_source, point_column_name, point_id, label_column_n
GROUP BY {test_id_temp} , {test_column_name}
""".format(**locals()))
-
plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
return
# ------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/5a291aa8/src/ports/postgres/modules/knn/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in
index befb768..17d81ad 100644
--- a/src/ports/postgres/modules/knn/knn.sql_in
+++ b/src/ports/postgres/modules/knn/knn.sql_in
@@ -78,7 +78,8 @@ knn( point_source,
test_id,
output_table,
k,
- output_neighbors
+ output_neighbors,
+ fn_dist
)
</pre>
@@ -131,6 +132,22 @@ otherwise the result may depend on ordering of the input data.</dd>
neighbors that were used in the voting/averaging, sorted
from closest to furthest.</dd>
+<dt>fn_dist (optional)</dt>
+<dd>TEXT, default: 'squared_dist_norm2'. The name of the function to use to calculate the distance between two data points.
+
+The following distance functions can be used (computation of barycenter/mean in parentheses):
+<ul>
+<li><b>\ref dist_norm1</b>: 1-norm/Manhattan (element-wise median
+[Note that MADlib does not provide a median aggregate function for support and
+performance reasons.])</li>
+<li><b>\ref dist_norm2</b>: 2-norm/Euclidean (element-wise mean)</li>
+<li><b>\ref squared_dist_norm2</b>: squared Euclidean distance (element-wise mean)</li>
+<li><b>\ref dist_angle</b>: angle (element-wise mean of normalized points)</li>
+<li><b>\ref dist_tanimoto</b>: tanimoto (element-wise mean of normalized points <a href="#kmeans-lit-5">[5]</a>)</li>
+<li><b>user defined function</b> with signature <tt>DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION</tt></li></ul></dd>
+
+
+
</dl>
@@ -227,6 +244,7 @@ SELECT * FROM madlib.knn(
'knn_result_classification', -- Output table
3, -- Number of nearest neighbors
True, -- True if you want to show Nearest-Neighbors by id, False otherwise
+ 'madlib.squared_dist_norm2' -- Distance function
);
SELECT * from knn_result_classification ORDER BY id;
</pre>
@@ -258,6 +276,7 @@ SELECT * FROM madlib.knn(
'knn_result_regression', -- Output table
3, -- Number of nearest neighbors
True, -- True if you want to show Nearest-Neighbors, False otherwise
+ 'madlib.squared_dist_norm2' -- Distance function
);
SELECT * FROM knn_result_regression ORDER BY id;
</pre>
@@ -388,6 +407,7 @@ SELECT {schema_madlib}.knn(
output_table, -- Name of output table
k, -- value of k. Default will go as 1
output_neighbors, -- Outputs the list of k-nearest neighbors that were used in the voting/averaging.
+ fn_dist -- The name of the function to use to calculate the distance between two data points.
);
-----------------------------------------------------------------------
@@ -435,7 +455,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
test_id VARCHAR,
output_table VARCHAR,
k INTEGER,
- output_neighbors Boolean
+ output_neighbors Boolean,
+ fn_dist TEXT
) RETURNS VARCHAR AS $$
PythonFunctionBodyOnly(`knn', `knn')
return knn.knn(
@@ -449,7 +470,9 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
test_id,
output_table,
k,
- output_neighbors
+ output_neighbors,
+ fn_dist
+
)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -469,7 +492,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
DECLARE
returnstring VARCHAR;
BEGIN
- returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,$9);
+ returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,$9, 'MADLIB_SCHEMA.squared_dist_norm2');
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
@@ -489,7 +512,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
DECLARE
returnstring VARCHAR;
BEGIN
- returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,TRUE);
+ returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,TRUE,'MADLIB_SCHEMA.squared_dist_norm2');
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
@@ -508,8 +531,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
DECLARE
returnstring VARCHAR;
BEGIN
- returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,TRUE);
+ returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,TRUE,'MADLIB_SCHEMA.squared_dist_norm2');
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
-m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
http://git-wip-us.apache.org/repos/asf/madlib/blob/5a291aa8/src/ports/postgres/modules/knn/test/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in
index 1e71a0e..8bb8f20 100644
--- a/src/ports/postgres/modules/knn/test/knn.sql_in
+++ b/src/ports/postgres/modules/knn/test/knn.sql_in
@@ -73,24 +73,81 @@ copy knn_test_data (id, data) from stdin delimiter '|';
\.
drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False);
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.squared_dist_norm2');
select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,True);
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,True,'madlib.squared_dist_norm2');
select assert(array_agg(x)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_classification where id = 1 order by x asc) y;
drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False);
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.squared_dist_norm2');
select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True);
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True,'madlib.squared_dist_norm2');
select assert(array_agg(x)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_regression where id = 1 order by x asc) y;
drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',False);
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.squared_dist_norm2');
select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=1') from madlib_knn_result_classification;
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL);
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True, NULL );
+select assert(array_agg(x)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_regression where id = 1 order by x asc) y;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,NULL);
+select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_norm1');
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_norm2');
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_angle');
+select assert(array_agg(prediction order by id)='{1,0,0,1,0,1}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'madlib.dist_tanimoto');
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_norm1');
+select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_norm2');
+select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_angle');
+select assert(array_agg(prediction order by id)='{0.75,0.25,0.25,0.75,0.25,1}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'madlib.dist_tanimoto');
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
+
select knn();
select knn('help');