You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nj...@apache.org on 2017/02/07 19:12:55 UTC
incubator-madlib git commit: knn: Fix input validation issues
Repository: incubator-madlib
Updated Branches:
refs/heads/master 735dc35c0 -> 2d5a5edb9
knn: Fix input validation issues
- Some missing input validation cases are included now.
- Remove redundant test cases from sql, since they are all now
handled in python code.
- There was still a bug with respect to ambiguous references to column names
in the query that computes the squared_dist_norm2 in knn.sql_in.
We now use unique strings for variables in that query, which fixes it.
- Handle boolean values for classification. MADlib's mode()
function does not handle boolean values, so they have to be converted
to integer before using it with mode().
Closes #98
Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/2d5a5edb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/2d5a5edb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/2d5a5edb
Branch: refs/heads/master
Commit: 2d5a5edb995758d0f8667c2c6c9cf58d9e802d50
Parents: 735dc35
Author: Nandish Jayaram <nj...@apache.org>
Authored: Tue Feb 7 11:06:09 2017 -0800
Committer: Nandish Jayaram <nj...@apache.org>
Committed: Tue Feb 7 11:06:09 2017 -0800
----------------------------------------------------------------------
src/ports/postgres/modules/knn/knn.py_in | 128 ++++++++++----------
src/ports/postgres/modules/knn/knn.sql_in | 97 ++++++++-------
src/ports/postgres/modules/knn/test/knn.sql_in | 16 +--
3 files changed, 128 insertions(+), 113 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/2d5a5edb/src/ports/postgres/modules/knn/knn.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index da7f9d6..c0d9cd7 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -44,86 +44,88 @@ UDF_ON_SEGMENT_NOT_ALLOWED = m4_ifdef(<!__UDF_ON_SEGMENT_NOT_ALLOWED__!>, <!True
# ----------------------------------------------------------------------
-def knn_validate_src(schema_madlib, **kwargs):
- trainingSource = kwargs['trainingSource']
- if not trainingSource:
- plpy.error("knn error: Invalid training table name!")
- if not table_exists(trainingSource):
- plpy.error("knn error: Training table {0} does not exist!".format(trainingSource))
- if table_is_empty(trainingSource):
- plpy.error("knn error: Training table {0} is empty!".format(trainingSource))
-
- testSource = kwargs['testSource']
- if not testSource:
- plpy.error("knn error: Invalid test table name!")
- if not table_exists(testSource):
- plpy.error("knn error: Test table {0} does not exist!".format(testSource))
- if table_is_empty(testSource):
- plpy.error("knn error: Test table {0} is empty!".format(testSource))
-
- trainingClassColumn = kwargs['trainingClassColumn']
- trainingFeatureColumn = kwargs['trainingFeatureColumn']
- for c in (trainingClassColumn, trainingFeatureColumn):
+def knn_validate_src(schema_madlib, point_source, point_column_name, label_column_name,
+ test_source, test_column_name, id_column_name, output_table, operation, k, **kwargs):
+ if not operation or operation not in ['c', 'r']:
+ plpy.error("kNN Error: operation='{0}' is an invalid value, has to be 'r' for regression OR 'c' for classification.".format(operation))
+ if not point_source:
+ plpy.error("kNN Error: Invalid training table name.")
+ if not table_exists(point_source):
+ plpy.error("kNN Error: Training table '{0}' does not exist.".format(point_source))
+ if table_is_empty(point_source):
+ plpy.error("kNN Error: Training table '{0}' is empty.".format(point_source))
+
+ if not test_source:
+ plpy.error("kNN Error: Invalid test table name.")
+ if not table_exists(test_source):
+ plpy.error("kNN Error: Test table '{0}' does not exist.".format(test_source))
+ if table_is_empty(test_source):
+ plpy.error("kNN Error: Test table '{0}' is empty.".format(test_source))
+
+ for c in (label_column_name, point_column_name):
if not c:
- plpy.error("knn error: Invalid column name in training table!")
- if not columns_exist_in_table(trainingSource, [c]):
- plpy.error("knn error: " + \
- "Column '{0}' does not exist in {1}!".format(c, trainingSource))
-
- testingFeatureColumn = kwargs['testingFeatureColumn']
- testingIdColumn = kwargs['testingIdColumn']
- for c in (testingFeatureColumn, testingIdColumn):
+ plpy.error("kNN Error: Invalid column name in training table.")
+ if not columns_exist_in_table(point_source, [c]):
+ plpy.error("kNN Error: " + \
+ "Column '{0}' does not exist in {1}.".format(c, point_source))
+
+ for c in (test_column_name, id_column_name):
if not c:
- plpy.error("knn error: Invalid column name in test table!")
- if not columns_exist_in_table(testSource, [c]):
- plpy.error("knn error: " + \
- "Column '{0}' does not exist in {1}!".format(c, testSource))
-
- if not is_col_array(trainingSource, trainingFeatureColumn):
- plpy.error("knn error:" + \
- "'Feature column {0} in train table is not an array!".format(str(trainingFeatureColumn)))
- if not is_col_array(testSource, testingFeatureColumn):
- plpy.error("knn error:" + \
- "'Feature column {0} in test table is not an array!".format(str(testingFeatureColumn)))
-
- if not array_col_has_no_null(trainingSource, trainingFeatureColumn):
- plpy.error("knn error:" + \
- "'Feature column {0} in train table has some NULL values!".format(str(trainingFeatureColumn)))
- if not array_col_has_no_null(testSource, testingFeatureColumn):
- plpy.error("knn error:" + \
- "'Feature column {0} in test table has some NULL values!".format(str(testingFeatureColumn)))
-
- k = int(kwargs['K'])
+ plpy.error("kNN Error: Invalid column name in test table.")
+ if not columns_exist_in_table(test_source, [c]):
+ plpy.error("kNN Error: " + \
+ "Column '{0}' does not exist in {1}.".format(c, test_source))
+
+ if not is_col_array(point_source, point_column_name):
+ plpy.error("kNN Error: " + \
+ "Feature column '{0}' in train table is not an array.".format(point_column_name))
+ if not is_col_array(test_source, test_column_name):
+ plpy.error("kNN Error: " + \
+ "Feature column '{0}' in test table is not an array.".format(test_column_name))
+
+ if not array_col_has_no_null(point_source, point_column_name):
+ plpy.error("kNN Error: " + \
+ "Feature column '{0}' in train table has some NULL values.".format(point_column_name))
+ if not array_col_has_no_null(test_source, test_column_name):
+ plpy.error("kNN Error: " + \
+ "Feature column '{0}' in test table has some NULL values.".format(test_column_name))
+
+ if not output_table:
+ plpy.error("kNN Error: Invalid output table name")
+ if table_exists(output_table):
+ plpy.error("kNN Error: Table '{0}' already exists, cannot use it as output table.".format(output_table))
+
+ if k is None:
+ k = 1
if k<=0:
- plpy.error("knn error:" + \
- "'k' {0} is not valid for knn!".format(str(k)))
+ plpy.error("kNN Error: k='{0}' is an invalid value, must be greater than 0.".format(k))
bound = plpy.execute("""SELECT {k} <= count(*)
- AS bound FROM {tbl}""".format(k=str(k),
- trainingFeatureColumn=trainingFeatureColumn, tbl=trainingSource))[0]['bound']
+ AS bound FROM {tbl}""".format(k=k,
+ point_column_name=point_column_name, tbl=point_source))[0]['bound']
if not bound:
- plpy.error("knn error:" + \
- "'k' {0} is greater than number of rows in training table!".format(str(k)))
+ plpy.error("kNN Error: " + \
+ "k='{0}' is greater than number of rows in training table.".format(k))
- colTypesList = get_cols_and_types(trainingSource)
+ colTypesList = get_cols_and_types(point_source)
colType = ''
for type in colTypesList:
- if type[0] == trainingClassColumn:
+ if type[0] == label_column_name:
colType = type[1]
break
if colType not in ['INTEGER','integer','double precision','DOUBLE PRECISION','float','FLOAT','boolean','BOOLEAN'] :
- plpy.error("knn error:" + \
- "Data type {0} is not valid as label for scope of knn!".format(str(colType)))
+ plpy.error("kNN Error: " + \
+ "Data type '{0}' is not a valid type for column '{1}' in table '{2}'.".format(colType, label_column_name, point_source))
- colTypesTestList = get_cols_and_types(testSource)
+ colTypesTestList = get_cols_and_types(test_source)
colType = ''
for type in colTypesTestList:
- if type[0] == testingIdColumn:
+ if type[0] == id_column_name:
colType = type[1]
break
if colType not in ['INTEGER','integer'] :
- plpy.error("knn error:" + \
- "Data type {0} is not valid as Id in test table!".format(str(colType)))
-
+ plpy.error("kNN Error: " + \
+ "Data type '{0}' is not a valid type for column '{1}' in table '{2}'.".format(colType, id_column_name, test_source))
+ return k
# ----------------------------------------------------------------------
m4_changequote(<!`!>, <!'!>)
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/2d5a5edb/src/ports/postgres/modules/knn/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in
index 7ee736b..526c8dd 100644
--- a/src/ports/postgres/modules/knn/knn.sql_in
+++ b/src/ports/postgres/modules/knn/knn.sql_in
@@ -271,19 +271,32 @@ File knn.sql_in documenting the knn SQL functions
@endinternal
*/
-
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.__knn_validate_src(
-"trainingSource" VARCHAR,
-"trainingClassColumn" VARCHAR,
-"trainingFeatureColumn" VARCHAR,
-"testSource" VARCHAR,
-"testingIdColumn" VARCHAR,
-"testingFeatureColumn" VARCHAR,
-"K" INTEGER
-) RETURNS VOID AS $$
- PythonFunction(knn, knn, knn_validate_src)
-$$ LANGUAGE plpythonu
-m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `READS SQL DATA', `');
+ point_source VARCHAR,
+ point_column_name VARCHAR,
+ label_column_name VARCHAR,
+ test_source VARCHAR,
+ test_column_name VARCHAR,
+ id_column_name VARCHAR,
+ output_table VARCHAR,
+ operation VARCHAR,
+ k INTEGER
+) RETURNS INTEGER AS $$
+ PythonFunctionBodyOnly(`knn', `knn')
+ return knn.knn_validate_src(
+ schema_madlib,
+ point_source,
+ point_column_name,
+ label_column_name,
+ test_source,
+ test_column_name,
+ id_column_name,
+ output_table,
+ operation,
+ k
+ )
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
@@ -353,10 +366,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
k INTEGER
) RETURNS VARCHAR AS $$
DECLARE
- class_test_source REGCLASS;
- class_point_source REGCLASS;
l FLOAT;
- outputTableFlag INTEGER;
id INTEGER;
vector DOUBLE PRECISION[];
cur_pid integer;
@@ -364,29 +374,24 @@ DECLARE
returnstring VARCHAR;
x_temp_table VARCHAR;
y_temp_table VARCHAR;
+ k_val INTEGER;
+ label_column_name_unique VARCHAR;
+ test_id VARCHAR;
+ convert_boolean_to_int VARCHAR;
BEGIN
oldClientMinMessages :=
(SELECT setting FROM pg_settings WHERE name = 'client_min_messages');
EXECUTE 'SET client_min_messages TO warning';
- PERFORM MADLIB_SCHEMA.__knn_validate_src(point_source, label_column_name, point_column_name, test_source, id_column_name, test_column_name,k);
- class_test_source := test_source;
- class_point_source := point_source;
- --checks
- IF (k <= 0) THEN
- RAISE EXCEPTION 'KNN error: Number of neighbors k must be a positive integer.';
- END IF;
- IF (operation != 'c' AND operation != 'r') THEN
- RAISE EXCEPTION 'KNN error: The operation has to be r for regression OR c for classification.';
- END IF;
+ SELECT * FROM MADLIB_SCHEMA.__knn_validate_src(point_source, point_column_name, label_column_name, test_source, test_column_name, id_column_name, output_table, operation, k) INTO k_val;
PERFORM MADLIB_SCHEMA.create_schema_pg_temp();
x_temp_table := 'knn_'||md5('knn_'||now()::text||random()::text)||'_temp';
y_temp_table := 'knn_'||md5('knn_'||now()::text||random()::text)||'_temp';
+ label_column_name_unique := 'label'||md5('knn_'||now()::text||random()::text)||'_name';
+ test_id := 'id'||md5('knn_'||now()::text||random()::text)||'_name';
- EXECUTE
- $sql$
- SELECT count(*) FROM information_schema.tables WHERE table_name = '$sql$ || output_table || $sql$'$sql$ into outputTableFlag;
- IF (outputTableFlag != 0) THEN
- RAISE Exception 'KNN error: Output table % already exists.', output_table;
+ convert_boolean_to_int := '';
+ IF (operation = 'c') THEN
+ convert_boolean_to_int := '::INTEGER';
END IF;
EXECUTE
@@ -396,30 +401,38 @@ BEGIN
SELECT *
FROM
(
- SELECT row_number() over (partition by test_id order by dist) as r, $sql$ || x_temp_table || $sql$.*
+ SELECT row_number() over (partition by $sql$ || test_id || $sql$ order by dist) AS r, $sql$ || x_temp_table || $sql$.*
FROM
(
- SELECT test. $sql$ || id_column_name || $sql$ as test_id, MADLIB_SCHEMA.squared_dist_norm2(train.$sql$ || point_column_name || $sql$,test.$sql$ || test_column_name || $sql$) as dist, $sql$ || label_column_name || $sql$ from $sql$ || textin(regclassout(point_source)) || $sql$ AS train, $sql$ || textin(regclassout(test_source)) || $sql$ AS test
+ SELECT test.$sql$ || id_column_name || $sql$ AS $sql$ || test_id || $sql$, MADLIB_SCHEMA.squared_dist_norm2(train.$sql$ || point_column_name || $sql$,test.$sql$ || test_column_name || $sql$) AS dist, train.$sql$ || label_column_name || $sql$ $sql$ || convert_boolean_to_int || $sql$ AS $sql$ || label_column_name_unique || $sql$
+ FROM $sql$ || textin(regclassout(point_source)) || $sql$ AS train, $sql$ || textin(regclassout(test_source)) || $sql$ AS test
)$sql$ || x_temp_table || $sql$
)$sql$ || y_temp_table || $sql$
- WHERE $sql$ || y_temp_table || $sql$.r <= $sql$ || k;
- IF (operation = 'c') THEN
+ WHERE $sql$ || y_temp_table || $sql$.r <= $sql$ || k_val;
+
+ IF (operation = 'c') THEN
EXECUTE
$sql$
- CREATE TABLE $sql$ || output_table || $sql$ AS
- SELECT test_id as id, $sql$ || test_column_name || $sql$, MADLIB_SCHEMA.mode($sql$ || label_column_name || $sql$) as prediction from pg_temp.madlib_knn_interm join $sql$ || textin(regclassout(test_source)) || $sql$ on test_id=$sql$ || id_column_name || $sql$ group by test_id, $sql$ || test_column_name;
- ELSE
+ CREATE TABLE $sql$ || output_table || $sql$ AS
+ SELECT $sql$ || test_id || $sql$ AS id, $sql$ || test_column_name || $sql$, MADLIB_SCHEMA.mode($sql$ || label_column_name_unique || $sql$) AS prediction
+ FROM pg_temp.madlib_knn_interm join $sql$ || textin(regclassout(test_source)) || $sql$ ON $sql$ || test_id || $sql$=$sql$ || id_column_name || $sql$
+ GROUP BY $sql$ || test_id || $sql$, $sql$ || test_column_name;
+ ELSE
EXECUTE
$sql$
- CREATE TABLE $sql$ || output_table || $sql$ AS
- SELECT test_id as id, $sql$ || test_column_name || $sql$ ,avg($sql$ || label_column_name || $sql$) as prediction from pg_temp.madlib_knn_interm join $sql$ || textin(regclassout(test_source)) || $sql$ on test_id=$sql$ || id_column_name || $sql$ group by test_id, $sql$ || test_column_name || $sql$ order by test_id $sql$;
- END IF;
+ CREATE TABLE $sql$ || output_table || $sql$ AS
+ SELECT $sql$ || test_id || $sql$ AS id, $sql$ || test_column_name || $sql$, avg($sql$ || label_column_name_unique || $sql$) AS prediction
+ FROM
+ pg_temp.madlib_knn_interm join $sql$ || textin(regclassout(test_source)) || $sql$ on $sql$ || test_id || $sql$=$sql$ || id_column_name || $sql$
+ GROUP BY $sql$ || test_id || $sql$, $sql$ || test_column_name || $sql$
+ ORDER BY $sql$ || test_id || $sql$ $sql$;
+ END IF;
EXECUTE 'SET client_min_messages TO ' || oldClientMinMessages;
IF (operation = 'c') THEN
- returnstring := 'The classification results have been written to table';
+ returnstring := 'The classification results have been written to output table '||output_table;
ELSE
- returnstring := 'The regression results have been written to table';
+ returnstring := 'The regression results have been written to output table '||output_table;
END IF;
DROP TABLE pg_temp.madlib_knn_interm;
RETURN returnstring;
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/2d5a5edb/src/ports/postgres/modules/knn/test/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in
index 3c730ee..1bf6b57 100644
--- a/src/ports/postgres/modules/knn/test/knn.sql_in
+++ b/src/ports/postgres/modules/knn/test/knn.sql_in
@@ -55,16 +55,16 @@ copy knn_test_data (id, data) from stdin delimiter '|';
6|{50,45}
\.
drop table if exists madlib_knn_result_classification;
-select madlib.knn('knn_train_data','data','label','knn_test_data','data','id','madlib_knn_result_classification','c',3);
-select madlib.assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+select knn('knn_train_data','data','label','knn_test_data','data','id','madlib_knn_result_classification','c',3);
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
drop table if exists madlib_knn_result_regression;
-select madlib.knn('knn_train_data','data','label','knn_test_data','data','id','madlib_knn_result_regression','r',4);
-select madlib.assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+select knn('knn_train_data','data','label','knn_test_data','data','id','madlib_knn_result_regression','r',4);
+select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
drop table if exists madlib_knn_result_classification;
-select madlib.knn('knn_train_data','data','label','knn_test_data','data','id','madlib_knn_result_classification','c');
-select madlib.assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=1') from madlib_knn_result_classification;
+select knn('knn_train_data','data','label','knn_test_data','data','id','madlib_knn_result_classification','c');
+select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=1') from madlib_knn_result_classification;
-select madlib.knn();
-select madlib.knn('help');
+select knn();
+select knn('help');