Posted to commits@madlib.apache.org by ri...@apache.org on 2018/02/09 04:07:56 UTC
[1/2] madlib git commit: KNN: Add weighted averaging/voting by
distance
Repository: madlib
Updated Branches:
refs/heads/master c51da40a1 -> 7c6fea20b
KNN: Add weighted averaging/voting by distance
JIRA: MADLIB-1181
This commit adds an option for weighting the voting (for classification) or
averaging (for regression) by the inverse of the distance between a test
point and its nearest neighbors.
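
In other words, with inverse-distance weights the regression prediction
becomes a weighted mean and the classification vote becomes a weighted
count per label. A minimal SQL sketch of the idea, assuming a hypothetical
table knn_neighbors(test_id, label, dist) that holds the k nearest training
rows for each test point (these names are illustrative, not part of the
patch):

    -- Weighted regression: nearer neighbors contribute more,
    -- in proportion to 1/dist.
    SELECT test_id,
           sum(label / dist) / sum(1.0 / dist) AS prediction
    FROM knn_neighbors
    GROUP BY test_id;

    -- Weighted classification: sum each label's votes as 1/dist;
    -- the label with the largest total per test_id is the prediction.
    SELECT test_id, label,
           sum(1.0 / dist) AS votes
    FROM knn_neighbors
    GROUP BY test_id, label;

A zero distance would divide by zero here; the follow-up commit [2/2] in
this thread handles that case.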
Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/f6aba420
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/f6aba420
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/f6aba420
Branch: refs/heads/master
Commit: f6aba42083c1017bb2092397cb34c20a8cd091e2
Parents: c51da40
Author: Himanshu Pandey <hp...@pivotal.io>
Authored: Tue Jan 16 09:32:41 2018 -0800
Committer: Rahul Iyer <ri...@apache.org>
Committed: Thu Feb 8 20:00:58 2018 -0800
----------------------------------------------------------------------
src/ports/postgres/modules/knn/knn.py_in | 92 +++++++++++++++++----
src/ports/postgres/modules/knn/knn.sql_in | 75 +++++++++++++++--
src/ports/postgres/modules/knn/test/knn.sql_in | 18 +++-
3 files changed, 161 insertions(+), 24 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/f6aba420/src/ports/postgres/modules/knn/knn.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index da67952..b9e7916 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -111,7 +111,7 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
def knn(schema_madlib, point_source, point_column_name, point_id,
label_column_name, test_source, test_column_name, test_id, output_table,
- k, output_neighbors, fn_dist):
+ k, output_neighbors, fn_dist, weighted_avg):
"""
KNN function to find the K Nearest neighbours
Args:
@@ -142,6 +142,8 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
dist_angle , dist_tanimoto
Or user defined function with signature
DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION
+ @param weighted_avg Calculates the regression or classification of k-NN using
+ the weighted average method.
"""
with MinWarning('warning'):
k_val = knn_validate_src(schema_madlib, point_source,
@@ -167,6 +169,9 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
knn_neighbors = ""
label_out = ""
cast_to_int = ""
+ view_def = ""
+ view_join = ""
+ view_grp_by = ""
if output_neighbors:
knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
@@ -178,11 +183,42 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
if label_column_type in ['boolean', 'integer', 'text']:
is_classification = True
cast_to_int = '::INTEGER'
+ if weighted_avg:
+ pred_out = ",sum( {label_col_temp} * 1/dist)/sum(1/dist)".format(
+ label_col_temp=label_col_temp)
+ else:
+ pred_out = ", avg({label_col_temp})".format(
+ label_col_temp=label_col_temp)
- pred_out = ", avg({label_col_temp})".format(**locals())
if is_classification:
- pred_out = (", {schema_madlib}.mode({label_col_temp})"
- ).format(**locals())
+ if weighted_avg:
+ # This view is to calculate the max value of sum of the 1/distance grouped by label and Id.
+ # And this max value will be the prediction for the
+ # classification model.
+ view_def = (""" WITH vw
+ AS (SELECT distinct on ({test_id_temp}) {test_id_temp},
+ max(data_sum) data_dist,
+ {label_col_temp}
+ FROM (SELECT {test_id_temp},
+ {label_col_temp},
+ sum(1 / dist) data_sum
+ FROM pg_temp.{interim_table}
+ GROUP BY {test_id_temp},
+ {label_col_temp}) a
+ GROUP BY {test_id_temp} , {label_col_temp})""").format(**locals())
+ # This join is needed to get the max value of prediction
+ # calculated above
+ view_join = (" JOIN vw AS knn_vw "
+ "ON knn_temp.{test_id_temp} = knn_vw.{test_id_temp}").format(
+ test_id_temp=test_id_temp)
+ view_grp_by = ", knn_vw.{label_col_temp}".format(
+ label_col_temp=label_col_temp)
+ pred_out = ", knn_vw.{label_col_temp}".format(
+ label_col_temp=label_col_temp)
+ else:
+ pred_out = (", {schema_madlib}.mode({label_col_temp})"
+ ).format(**locals())
+
pred_out += " AS prediction"
label_out = (", train.{label_column_name}{cast_to_int}"
" AS {label_col_temp}").format(**locals())
@@ -212,22 +248,27 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
WHERE {y_temp_table}.r <= {k_val}
""".format(**locals()))
- plpy.execute(
- """
+ plpy.execute("""
CREATE TABLE {output_table} AS
- SELECT {test_id_temp} AS id, {test_column_name}
+ {view_def}
+ SELECT knn_temp.{test_id_temp} AS id ,
+ knn_test.data
{pred_out}
{knn_neighbors}
- FROM pg_temp.{interim_table} AS knn_temp
- JOIN
- {test_source} AS knn_test ON
- knn_temp.{test_id_temp} = knn_test.{test_id}
- GROUP BY {test_id_temp} , {test_column_name}
+ FROM {interim_table} AS knn_temp
+ JOIN {test_source} AS knn_test
+ ON knn_temp.{test_id_temp} = knn_test.id
+ {view_join}
+ GROUP BY knn_temp.{test_id_temp},
+ knn_test.{test_column_name}
+ {view_grp_by}
+ ORDER BY knn_temp.{test_id_temp}
""".format(**locals()))
plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
return
+
def knn_help(schema_madlib, message, **kwargs):
"""
Help function for knn
@@ -258,6 +299,7 @@ SELECT {schema_madlib}.knn(
k, -- value of k. Default will go as 1
output_neighbors -- Outputs the list of k-nearest neighbors that were used in the voting/averaging.
fn_dist -- The name of the function to use to calculate the distance from a data point to a centroid.
+ weighted_avg -- Calculates the regression or classification of k-NN using the weighted average method.
);
-----------------------------------------------------------------------
@@ -340,7 +382,8 @@ SELECT * FROM {schema_madlib}.knn(
'knn_result_classification', -- Output table
3, -- Number of nearest neighbors
True, -- True to list nearest-neighbors by id
- 'madlib.squared_dist_norm2' -- Distance function
+ 'madlib.squared_dist_norm2', -- Distance function
+ False -- False for not using weighted average
);
SELECT * from knn_result_classification ORDER BY id;
@@ -360,7 +403,8 @@ SELECT * FROM {schema_madlib}.knn(
'knn_result_regression', -- Output table
3, -- Number of nearest neighbors
True, -- True to list nearest-neighbors by id
- 'madlib.dist_norm2' -- Distance function
+ 'madlib.dist_norm2', -- Distance function
+ False -- False for not using weighted average
);
SELECT * FROM knn_result_regression ORDER BY id;
@@ -379,6 +423,25 @@ SELECT * FROM {schema_madlib}.knn(
3 -- Number of nearest neighbors
);
SELECT * FROM knn_result_list_neighbors ORDER BY id;
+
+-- Run KNN for classification using weighted average:
+DROP TABLE IF EXISTS knn_result_classification;
+SELECT * FROM {schema_madlib}.knn(
+ 'knn_train_data', -- Table of training data
+ 'data', -- Col name of training data
+ 'id', -- Col name of id in train data
+ 'label', -- Training labels
+ 'knn_test_data', -- Table of test data
+ 'data', -- Col name of test data
+ 'id', -- Col name of id in test data
+ 'knn_result_classification', -- Output table
+ 3, -- Number of nearest neighbors
+ True, -- True to list nearest-neighbors by id
+ 'madlib.squared_dist_norm2', -- Distance function
+ True -- Calculation using weighted average
+ );
+SELECT * from knn_result_classification ORDER BY id;
+
"""
else:
help_string = """
@@ -405,4 +468,3 @@ SELECT {schema_madlib}.knn('example')
return help_string.format(schema_madlib=schema_madlib)
# ------------------------------------------------------------------------------
-
http://git-wip-us.apache.org/repos/asf/madlib/blob/f6aba420/src/ports/postgres/modules/knn/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in
index 9db702d..3139c15 100644
--- a/src/ports/postgres/modules/knn/knn.sql_in
+++ b/src/ports/postgres/modules/knn/knn.sql_in
@@ -79,7 +79,8 @@ knn( point_source,
output_table,
k,
output_neighbors,
- fn_dist
+ fn_dist,
+ weighted_avg
)
</pre>
@@ -145,6 +146,10 @@ The following distance functions can be used:
<li><b>\ref dist_tanimoto</b>: tanimoto</li>
<li><b>user defined function</b> with signature <tt>DOUBLE PRECISION[] x, DOUBLE PRECISION[] y -> DOUBLE PRECISION</tt></li></ul></dd>
+<dt>weighted_avg (optional)</dt>
+<dd>BOOLEAN, default: FALSE. Calculates the regression or classification
+of k-NN using the weighted average method.
+
</dl>
@@ -326,6 +331,39 @@ Result, with neighbors sorted from closest to furthest:
(6 rows)
</pre>
+
+-# Run KNN for classification using the
+weighted average:
+<pre class="example">
+DROP TABLE IF EXISTS knn_result_classification;
+SELECT * FROM madlib.knn(
+ 'knn_train_data', -- Table of training data
+ 'data', -- Col name of training data
+ 'id', -- Col name of id in train data
+ 'label', -- Training labels
+ 'knn_test_data', -- Table of test data
+ 'data', -- Col name of test data
+ 'id', -- Col name of id in test data
+ 'knn_result_classification', -- Output table
+ 3, -- Number of nearest neighbors
+ True, -- True to list nearest-neighbors by id
+ 'madlib.squared_dist_norm2', -- Distance function
+ True -- For weighted average
+ );
+SELECT * FROM knn_result_classification ORDER BY id;
+</pre>
+<pre class="result">
+ id | data | prediction | k_nearest_neighbours
+----+---------+---------------------+----------------------
+ 1 | {2,1} | 1 | {2,1,3}
+ 2 | {2,6} | 1 | {5,4,3}
+ 3 | {15,40} | 0 | {7,6,5}
+ 4 | {12,1} | 1 | {4,5,3}
+ 5 | {2,90} | 0 | {9,6,7}
+ 6 | {50,45} | 0 | {6,7,8}
+(6 rows)
+</pre>
+
@anchor background
@par Technical Background
@@ -397,7 +435,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
output_table VARCHAR,
k INTEGER,
output_neighbors BOOLEAN,
- fn_dist TEXT
+ fn_dist TEXT,
+ weighted_avg BOOLEAN
) RETURNS VARCHAR AS $$
PythonFunctionBodyOnly(`knn', `knn')
return knn.knn(
@@ -412,7 +451,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
output_table,
k,
output_neighbors,
- fn_dist
+ fn_dist,
+ weighted_avg
)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -428,13 +468,36 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
test_id VARCHAR,
output_table VARCHAR,
k INTEGER,
+ output_neighbors BOOLEAN,
+ fn_dist TEXT
+) RETURNS VARCHAR AS $$
+ DECLARE
+ returnstring VARCHAR;
+BEGIN
+ returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11, FALSE);
+ RETURN returnstring;
+END;
+$$ LANGUAGE plpgsql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
+ point_source VARCHAR,
+ point_column_name VARCHAR,
+ point_id VARCHAR,
+ label_column_name VARCHAR,
+ test_source VARCHAR,
+ test_column_name VARCHAR,
+ test_id VARCHAR,
+ output_table VARCHAR,
+ k INTEGER,
output_neighbors BOOLEAN
) RETURNS VARCHAR AS $$
DECLARE
returnstring VARCHAR;
BEGIN
returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,
- 'MADLIB_SCHEMA.squared_dist_norm2');
+ 'MADLIB_SCHEMA.squared_dist_norm2', FALSE);
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
@@ -455,7 +518,7 @@ DECLARE
returnstring VARCHAR;
BEGIN
returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,$9,TRUE,
- 'MADLIB_SCHEMA.squared_dist_norm2');
+ 'MADLIB_SCHEMA.squared_dist_norm2', FALSE);
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
@@ -475,7 +538,7 @@ DECLARE
returnstring VARCHAR;
BEGIN
returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,TRUE,
- 'MADLIB_SCHEMA.squared_dist_norm2');
+ 'MADLIB_SCHEMA.squared_dist_norm2',FALSE);
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
http://git-wip-us.apache.org/repos/asf/madlib/blob/f6aba420/src/ports/postgres/modules/knn/test/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in
index 8c62dad..6a39c89 100644
--- a/src/ports/postgres/modules/knn/test/knn.sql_in
+++ b/src/ports/postgres/modules/knn/test/knn.sql_in
@@ -72,14 +72,15 @@ copy knn_test_data (id, data) from stdin delimiter '|';
\.
drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2');
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
drop table if exists madlib_knn_result_classification;
select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3);
select assert(array_agg(x)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_classification where id = 1 order by x asc) y;
+
drop table if exists madlib_knn_result_regression;
-select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.squared_dist_norm2');
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.squared_dist_norm2',False);
select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
drop table if exists madlib_knn_result_regression;
@@ -87,7 +88,7 @@ select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id',
select assert(array_agg(x)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_regression where id = 1 order by x asc) y;
drop table if exists madlib_knn_result_classification;
-select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL);
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,NULL,False);
select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
drop table if exists madlib_knn_result_classification;
@@ -110,5 +111,16 @@ drop table if exists madlib_knn_result_regression;
select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False,'MADLIB_SCHEMA.dist_angle');
select assert(array_agg(prediction order by id)='{0.75,0.25,0.25,0.75,0.25,1}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+select assert(array_agg(prediction::numeric order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+
+
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,False,'MADLIB_SCHEMA.squared_dist_norm2', True);
+select assert(array_agg(prediction::numeric order by id)='{1,1,0.0408728591876018,1,0,0}', 'Wrong output in regression') from madlib_knn_result_regression;
+
+
select knn();
select knn('help');
[2/2] madlib git commit: KNN: Add window for distinct on + other
changes
Posted by ri...@apache.org.
KNN: Add window for distinct on + other changes
This commit includes a few fixes:
1. 'DISTINCT ON' requires a window with a partition to give the rows
corresponding to a distinct value. This was added to compute the label
with the highest weighted votes (sketched below).
2. Zero distances led to divide-by-zero issues during the weighting
process. This was fixed by replacing the inverse of the distance with a
pre-defined large number.
3. Other changes include improved error messages and code cleanup.
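
As a rough sketch of fixes 1 and 2, again against a hypothetical
knn_neighbors(test_id, label, dist) table (names illustrative only): the
CASE expression replaces 1/0 with the pre-defined constant
(MAX_WEIGHT_ZERO_DIST = 1e6 in the patch), and the window partitioned by
the test id carries the winning label onto every row so that DISTINCT ON
can return it:

    SELECT DISTINCT ON (test_id)
           test_id,
           last_value(label) OVER (
               PARTITION BY test_id
               ORDER BY votes, label
               ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
           ) AS prediction
    FROM (
        -- Weight each neighbor by 1/dist, but cap an exact match
        -- (dist = 0) at a large constant instead of dividing by zero.
        SELECT test_id, label,
               sum(CASE WHEN dist = 0.0 THEN 1e6
                        ELSE 1.0 / dist END) AS votes
        FROM knn_neighbors
        GROUP BY test_id, label
    ) per_label;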
Co-authored-by: Nandish Jayaram <nj...@apache.org>
Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/7c6fea20
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/7c6fea20
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/7c6fea20
Branch: refs/heads/master
Commit: 7c6fea20bac49ee75d1c68cef5f1248e6bcc78b1
Parents: f6aba42
Author: Rahul Iyer <ri...@apache.org>
Authored: Thu Feb 8 20:02:09 2018 -0800
Committer: Rahul Iyer <ri...@apache.org>
Committed: Thu Feb 8 20:02:09 2018 -0800
----------------------------------------------------------------------
src/ports/postgres/modules/knn/knn.py_in | 216 +++++++++++++++-----------
1 file changed, 121 insertions(+), 95 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/7c6fea20/src/ports/postgres/modules/knn/knn.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index b9e7916..cfd93d9 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -20,7 +20,7 @@
"""
@file knn.py_in
-@brief knn: Driver functions
+@brief knn: K-Nearest Neighbors for regression and classification
@namespace knn
@@ -32,9 +32,12 @@ from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import is_col_array
from utilities.validate_args import array_col_has_no_null
from utilities.validate_args import get_expr_type
+from utilities.utilities import _assert
from utilities.utilities import unique_string
from utilities.control import MinWarning
+MAX_WEIGHT_ZERO_DIST = 1e6
+
def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
label_column_name, test_source, test_column_name,
@@ -43,21 +46,22 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
input_tbl_valid(point_source, 'kNN')
input_tbl_valid(test_source, 'kNN')
output_tbl_valid(output_table, 'kNN')
- if label_column_name is not None and label_column_name != '':
- cols_in_tbl_valid(
- point_source,
- (label_column_name,
- point_column_name),
- 'kNN')
+
+ _assert(label_column_name or output_neighbors,
+ "kNN error: Either label_column_name or "
+ "output_neighbors has to be inputed.")
+
+ if label_column_name and label_column_name.strip():
+ cols_in_tbl_valid(point_source, [label_column_name], 'kNN')
cols_in_tbl_valid(point_source, (point_column_name, point_id), 'kNN')
cols_in_tbl_valid(test_source, (test_column_name, test_id), 'kNN')
if not is_col_array(point_source, point_column_name):
plpy.error("kNN Error: Feature column '{0}' in train table is not"
- " an array.").format(point_column_name)
+ " an array.".format(point_column_name))
if not is_col_array(test_source, test_column_name):
plpy.error("kNN Error: Feature column '{0}' in test table is not"
- " an array.").format(test_column_name)
+ " an array.".format(test_column_name))
if not array_col_has_no_null(point_source, point_column_name):
plpy.error("kNN Error: Feature column '{0}' in train table has some"
@@ -66,44 +70,46 @@ def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
plpy.error("kNN Error: Feature column '{0}' in test table has some"
" NULL values.".format(test_column_name))
- if k is None:
- k = 1
if k <= 0:
- plpy.error("kNN Error: k={0} is an invalid value, must be greater"
+ plpy.error("kNN Error: k={0} is an invalid value, must be greater "
"than 0.".format(k))
+
bound = plpy.execute("SELECT {k} <= count(*) AS bound FROM {tbl}".
format(k=k, tbl=point_source))[0]['bound']
if not bound:
plpy.error("kNN Error: k={0} is greater than number of rows in"
" training table.".format(k))
- if label_column_name is not None and label_column_name != '':
+ if label_column_name:
col_type = get_expr_type(label_column_name, point_source).lower()
if col_type not in ['integer', 'double precision', 'float', 'boolean']:
- plpy.error("kNN error: Data type '{0}' is not a valid type for"
- " column '{1}' in table '{2}'.".
- format(col_type, label_column_name, point_source))
+ plpy.error("kNN error: Invalid data type '{0}' for"
+ " label_column_name in table '{1}'.".
+ format(col_type, point_source))
col_type_test = get_expr_type(test_id, test_source).lower()
if col_type_test not in ['integer']:
- plpy.error("kNN Error: Data type '{0}' is not a valid type for"
- " column '{1}' in table '{2}'.".
- format(col_type_test, test_id, test_source))
+ plpy.error("kNN Error: Invalid data type '{0}' for"
+ " test_id column in table '{1}'.".
+ format(col_type_test, test_source))
if fn_dist:
fn_dist = fn_dist.lower().strip()
- dist_functions = set([schema_madlib + dist for dist in
- ('.dist_norm1', '.dist_norm2', '.squared_dist_norm2', '.dist_angle', '.dist_tanimoto')])
-
- is_invalid_func = plpy.execute(
- """select prorettype != 'DOUBLE PRECISION'::regtype
- OR proisagg = TRUE AS OUTPUT from pg_proc where
- oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure;
- """.format(**locals()))[0]['output']
-
- if is_invalid_func or fn_dist not in dist_functions:
- plpy.error(
- "KNN error: Distance function has wrong signature or is not a simple function.")
+ dist_functions = set(["{0}.{1}".format(schema_madlib, dist) for dist in
+ ('dist_norm1', 'dist_norm2',
+ 'squared_dist_norm2', 'dist_angle',
+ 'dist_tanimoto')])
+
+ is_invalid_func = plpy.execute("""
+ SELECT prorettype != 'DOUBLE PRECISION'::regtype OR
+ proisagg = TRUE AS OUTPUT
+ FROM pg_proc
+ WHERE oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure;
+ """.format(fn_dist=fn_dist))[0]['output']
+
+ if is_invalid_func or (fn_dist not in dist_functions):
+ plpy.error("KNN error: Distance function has invalid signature "
+ "or is not a simple function.")
return k
# ------------------------------------------------------------------------------
@@ -146,21 +152,21 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
the weighted average method.
"""
with MinWarning('warning'):
- k_val = knn_validate_src(schema_madlib, point_source,
- point_column_name, point_id, label_column_name,
- test_source, test_column_name, test_id,
- output_table, k, output_neighbors, fn_dist)
+ output_neighbors = True if output_neighbors is None else output_neighbors
+ if k is None:
+ k = 1
+ knn_validate_src(schema_madlib, point_source,
+ point_column_name, point_id, label_column_name,
+ test_source, test_column_name, test_id,
+ output_table, k, output_neighbors, fn_dist)
x_temp_table = unique_string(desp='x_temp_table')
y_temp_table = unique_string(desp='y_temp_table')
label_col_temp = unique_string(desp='label_col_temp')
test_id_temp = unique_string(desp='test_id_temp')
- if output_neighbors is None:
- output_neighbors = True
-
if not fn_dist:
- fn_dist = schema_madlib + '.squared_dist_norm2'
+ fn_dist = '{0}.squared_dist_norm2'.format(schema_madlib)
fn_dist = fn_dist.lower().strip()
interim_table = unique_string(desp='interim_table')
@@ -173,98 +179,121 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
view_join = ""
view_grp_by = ""
- if output_neighbors:
- knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
- "knn_temp.dist ASC) AS k_nearest_neighbours ")
if label_column_name:
- is_classification = False
label_column_type = get_expr_type(
label_column_name, point_source).lower()
if label_column_type in ['boolean', 'integer', 'text']:
is_classification = True
cast_to_int = '::INTEGER'
- if weighted_avg:
- pred_out = ",sum( {label_col_temp} * 1/dist)/sum(1/dist)".format(
- label_col_temp=label_col_temp)
else:
- pred_out = ", avg({label_col_temp})".format(
- label_col_temp=label_col_temp)
+ is_classification = False
if is_classification:
if weighted_avg:
# This view is to calculate the max value of sum of the 1/distance grouped by label and Id.
# And this max value will be the prediction for the
# classification model.
- view_def = (""" WITH vw
- AS (SELECT distinct on ({test_id_temp}) {test_id_temp},
- max(data_sum) data_dist,
- {label_col_temp}
- FROM (SELECT {test_id_temp},
+ view_def = """
+ WITH vw AS (
+ SELECT DISTINCT ON({test_id_temp})
+ {test_id_temp},
+ last_value(data_sum) OVER (
+ PARTITION BY {test_id_temp}
+ ORDER BY data_sum, {label_col_temp}
+ ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ ) AS data_dist ,
+ last_value({label_col_temp}) OVER (
+ PARTITION BY {test_id_temp}
+ ORDER BY data_sum, {label_col_temp}
+ ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ ) AS {label_col_temp}
+ FROM (
+ SELECT
+ {test_id_temp},
{label_col_temp},
- sum(1 / dist) data_sum
- FROM pg_temp.{interim_table}
- GROUP BY {test_id_temp},
- {label_col_temp}) a
- GROUP BY {test_id_temp} , {label_col_temp})""").format(**locals())
+ sum(dist_inverse) data_sum
+ FROM pg_temp.{interim_table}
+ GROUP BY {test_id_temp},
+ {label_col_temp}
+ ) a
+ -- GROUP BY {test_id_temp} , {label_col_temp}
+ )
+ """.format(**locals())
# This join is needed to get the max value of prediction
# calculated above
- view_join = (" JOIN vw AS knn_vw "
- "ON knn_temp.{test_id_temp} = knn_vw.{test_id_temp}").format(
- test_id_temp=test_id_temp)
- view_grp_by = ", knn_vw.{label_col_temp}".format(
- label_col_temp=label_col_temp)
- pred_out = ", knn_vw.{label_col_temp}".format(
- label_col_temp=label_col_temp)
+ view_join = (" JOIN vw ON knn_temp.{0} = vw.{0}".
+ format(test_id_temp))
+ view_grp_by = ", vw.{0}".format(label_col_temp)
+ pred_out = ", vw.{0}".format(label_col_temp)
+ else:
+ pred_out = ", {0}.mode({1})".format(schema_madlib, label_col_temp)
+ else:
+ if weighted_avg:
+ pred_out = (", sum({0} * dist_inverse) / sum(dist_inverse)".
+ format(label_col_temp))
else:
- pred_out = (", {schema_madlib}.mode({label_col_temp})"
- ).format(**locals())
+ pred_out = ", avg({0})".format(label_col_temp)
pred_out += " AS prediction"
label_out = (", train.{label_column_name}{cast_to_int}"
" AS {label_col_temp}").format(**locals())
+ comma_label_out_alias = ', ' + label_col_temp
+ else:
+ pred_out = ""
+ label_out = ""
+ comma_label_out_alias = ""
- if not label_column_name and not output_neighbors:
-
- plpy.error("kNN error: Either label_column_name or "
- "output_neighbors has to be non-NULL.")
-
+ # interim_table picks the 'k' nearest neighbors for each test point
+ if output_neighbors:
+ knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
+ "knn_temp.dist_inverse DESC) AS k_nearest_neighbours ")
+ else:
+ knn_neighbors = ''
plpy.execute("""
CREATE TEMP TABLE {interim_table} AS
SELECT * FROM (
SELECT row_number() over
(partition by {test_id_temp} order by dist) AS r,
- {x_temp_table}.*
+ {test_id_temp},
+ train_id,
+ CASE WHEN dist = 0.0 THEN {max_weight_zero_dist}
+ ELSE 1.0 / dist
+ END AS dist_inverse
+ {comma_label_out_alias}
FROM (
- SELECT test.{test_id} AS {test_id_temp} ,
- train.{point_id} as train_id ,
+ SELECT test.{test_id} AS {test_id_temp},
+ train.{point_id} as train_id,
{fn_dist}(
train.{point_column_name},
test.{test_column_name})
- AS dist {label_out}
+ AS dist
+ {label_out}
FROM {point_source} AS train,
- {test_source} AS test
+ {test_source} AS test
) {x_temp_table}
) {y_temp_table}
- WHERE {y_temp_table}.r <= {k_val}
- """.format(**locals()))
+ WHERE {y_temp_table}.r <= {k}
+ """.format(max_weight_zero_dist=MAX_WEIGHT_ZERO_DIST, **locals()))
- plpy.execute("""
+ sql = """
CREATE TABLE {output_table} AS
{view_def}
- SELECT knn_temp.{test_id_temp} AS id ,
- knn_test.data
+ SELECT
+ knn_temp.{test_id_temp} AS id,
+ knn_test.{test_column_name}
{pred_out}
{knn_neighbors}
- FROM {interim_table} AS knn_temp
- JOIN {test_source} AS knn_test
- ON knn_temp.{test_id_temp} = knn_test.id
- {view_join}
- GROUP BY knn_temp.{test_id_temp},
- knn_test.{test_column_name}
- {view_grp_by}
- ORDER BY knn_temp.{test_id_temp}
- """.format(**locals()))
-
+ FROM
+ pg_temp.{interim_table} AS knn_temp
+ JOIN
+ {test_source} AS knn_test
+ ON knn_temp.{test_id_temp} = knn_test.{test_id}
+ {view_join}
+ GROUP BY knn_temp.{test_id_temp},
+ knn_test.{test_column_name}
+ {view_grp_by}
+ """
+ plpy.execute(sql.format(**locals()))
plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
return
@@ -394,9 +423,6 @@ to furthest from the corresponding test point.
DROP TABLE IF EXISTS knn_result_regression;
SELECT * FROM {schema_madlib}.knn(
'knn_train_data_reg', -- Table of training data
- 'data', -- Col name of training data
- 'id', -- Col Name of id in train data
- 'label', -- Training labels
'knn_test_data', -- Table of test data
'data', -- Col name of test data
'id', -- Col name of id in test data