You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ok...@apache.org on 2017/10/12 21:47:30 UTC
madlib git commit: KNN: Provide additional output information
Repository: madlib
Updated Branches:
refs/heads/master cfc54b668 -> 0a7efca73
KNN: Provide additional output information
JIRA: MADLIB-1129
Closes #184
Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/0a7efca7
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/0a7efca7
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/0a7efca7
Branch: refs/heads/master
Commit: 0a7efca73bc7d38a60d92b2d5c196d7c449d9525
Parents: cfc54b6
Author: hpandeycodeit <hp...@pivotal.io>
Authored: Thu Oct 12 14:44:40 2017 -0700
Committer: Orhan Kislal <ok...@pivotal.io>
Committed: Thu Oct 12 14:46:39 2017 -0700
----------------------------------------------------------------------
src/ports/postgres/modules/knn/knn.py_in | 151 ++++++++++++++------
src/ports/postgres/modules/knn/knn.sql_in | 128 +++++++++++------
src/ports/postgres/modules/knn/test/knn.sql_in | 50 +++++--
3 files changed, 227 insertions(+), 102 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/0a7efca7/src/ports/postgres/modules/knn/knn.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index 4d5d627..477c18d 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -36,19 +36,20 @@ from utilities.utilities import unique_string
from utilities.control import MinWarning
-def knn_validate_src(schema_madlib, point_source, point_column_name,
+def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
label_column_name, test_source, test_column_name,
- id_column_name, output_table, operation, k, **kwargs):
-
- if not operation or operation not in ['c', 'r']:
- plpy.error("kNN Error: operation='{0}' is an invalid value, has to be"
- " 'r' for regression OR 'c' for classification.".
- format(operation))
+ test_id, output_table, k, output_neighbors, **kwargs):
input_tbl_valid(point_source, 'kNN')
input_tbl_valid(test_source, 'kNN')
output_tbl_valid(output_table, 'kNN')
- cols_in_tbl_valid(point_source, (label_column_name, point_column_name), 'kNN')
- cols_in_tbl_valid(test_source, (test_column_name, id_column_name), 'kNN')
+ if label_column_name is not None and label_column_name != '':
+ cols_in_tbl_valid(
+ point_source,
+ (label_column_name,
+ point_column_name),
+ 'kNN')
+ cols_in_tbl_valid(point_source, (point_column_name, point_id), 'kNN')
+ cols_in_tbl_valid(test_source, (test_column_name, test_id), 'kNN')
if not is_col_array(point_source, point_column_name):
plpy.error("kNN Error: Feature column '{0}' in train table is not"
@@ -75,30 +76,32 @@ def knn_validate_src(schema_madlib, point_source, point_column_name,
plpy.error("kNN Error: k={0} is greater than number of rows in"
" training table.".format(k))
- col_type = get_expr_type(label_column_name, point_source).lower()
- if col_type not in ['integer', 'double precision', 'float', 'boolean']:
- plpy.error("kNN error: Data type '{0}' is not a valid type for"
- " column '{1}' in table '{2}'.".
- format(col_type, label_column_name, point_source))
+ if label_column_name is not None and label_column_name != '':
+ col_type = get_expr_type(label_column_name, point_source).lower()
+ if col_type not in ['integer', 'double precision', 'float', 'boolean']:
+ plpy.error("kNN error: Data type '{0}' is not a valid type for"
+ " column '{1}' in table '{2}'.".
+ format(col_type, label_column_name, point_source))
- col_type_test = get_expr_type(id_column_name, test_source).lower()
+ col_type_test = get_expr_type(test_id, test_source).lower()
if col_type_test not in ['integer']:
plpy.error("kNN Error: Data type '{0}' is not a valid type for"
" column '{1}' in table '{2}'.".
- format(col_type_test, id_column_name, test_source))
+ format(col_type_test, test_id, test_source))
return k
# ------------------------------------------------------------------------------
-def knn(schema_madlib, point_source, point_column_name, label_column_name,
- test_source, test_column_name, id_column_name, output_table,
- operation, k):
+def knn(schema_madlib, point_source, point_column_name, point_id, label_column_name,
+ test_source, test_column_name, test_id, output_table, k, output_neighbors):
"""
KNN function to find the K Nearest neighbours
Args:
@param schema_madlib Name of the Madlib Schema
@param point_source Training data table
@param point_column_name Name of the column with training data
points.
+ @param point_id Name of the column having ids of data
+ points in train data table.
@param label_column_name Name of the column with labels/values
of training data points.
@@ -106,7 +109,7 @@ def knn(schema_madlib, point_source, point_column_name, label_column_name,
data points.
@param test_column_name Name of the column with testing data
points.
- @param id_column_name Name of the column having ids of data
+ @param test_id Name of the column having ids of data
points in test data table.
@param output_table Name of the table to store final
results.
@@ -115,33 +118,85 @@ def knn(schema_madlib, point_source, point_column_name, label_column_name,
'r' for regression
@param k default: 1. Number of nearest
neighbors to consider
+ @param output_neighbors Outputs the list of k-nearest neighbors
+ that were used in the voting/averaging.
Returns:
VARCHAR Name of the output table.
"""
with MinWarning('warning'):
k_val = knn_validate_src(schema_madlib, point_source,
- point_column_name, label_column_name,
- test_source, test_column_name, id_column_name,
- output_table, operation, k)
+ point_column_name, point_id, label_column_name,
+ test_source, test_column_name, test_id,
+ output_table, k, output_neighbors)
x_temp_table = unique_string(desp='x_temp_table')
y_temp_table = unique_string(desp='y_temp_table')
label_col_temp = unique_string(desp='label_col_temp')
- test_id = unique_string(desp='test_id')
+ test_id_temp = unique_string(desp='test_id_temp')
+
+ if output_neighbors is None or '':
+ output_neighbors = False
- is_classification = operation == 'c'
interim_table = unique_string(desp='interim_table')
+
+ if label_column_name is None or label_column_name == '':
+ plpy.execute(
+ """
+ CREATE TEMP TABLE {interim_table} AS
+ SELECT * FROM
+ (
+ SELECT row_number() over
+ (partition by {test_id_temp} order by dist) AS r,
+ {x_temp_table}.*
+ FROM
+ (
+ SELECT test.{test_id} AS {test_id_temp} ,
+ train.id as train_id ,
+ {schema_madlib}.squared_dist_norm2(
+ train.{point_column_name},
+ test.{test_column_name})
+ AS dist
+ FROM {point_source} AS train, {test_source} AS test
+ ) {x_temp_table}
+ ) {y_temp_table}
+ WHERE {y_temp_table}.r <= {k_val}
+ """.format(**locals()))
+ plpy.execute(
+ """
+ CREATE TABLE {output_table} AS
+ SELECT {test_id_temp} AS id, {test_column_name} ,
+ CASE WHEN {output_neighbors}
+ THEN array_agg(knn_temp.train_id)
+ ELSE NULL END AS k_nearest_neighbours
+ FROM pg_temp.{interim_table} AS knn_temp
+ join
+ {test_source} AS knn_test ON
+ knn_temp.{test_id_temp} = knn_test.{test_id}
+ GROUP BY {test_id_temp} , {test_column_name}
+ """.format(**locals()))
+ return
+
+ is_classification = False
+ label_column_type = get_expr_type(
+ label_column_name, point_source).lower()
+ if label_column_type in ['boolean', 'integer', 'text']:
+ is_classification = True
+ convert_boolean_to_int = '::INTEGER'
+ else:
+ is_classification = False
+
plpy.execute(
"""
CREATE TEMP TABLE {interim_table} AS
SELECT * FROM
(
SELECT row_number() over
- (partition by {test_id} order by dist) AS r,
+ (partition by {test_id_temp} order by dist) AS r,
{x_temp_table}.*
FROM
(
- SELECT test.{id_column_name} AS {test_id} ,
+ SELECT test.{test_id} AS {test_id_temp} ,
+ train.id as train_id ,
{schema_madlib}.squared_dist_norm2(
train.{point_column_name},
test.{test_column_name})
@@ -155,26 +210,30 @@ def knn(schema_madlib, point_source, point_column_name, label_column_name,
""".format(cast_to_int='::INTEGER' if is_classification else '',
**locals()))
+ knn_create_table = 'CREATE TABLE ' + output_table + ' AS ' \
+ 'SELECT ' + test_id_temp + ' AS id,' + test_column_name + ','
+ knn_pred_class = schema_madlib + \
+ '.mode(' + label_col_temp + ') AS prediction'
+ knn_pred_reg = 'avg(' + label_col_temp + ') AS prediction'
+ knn_neighbours = ', array_agg(knn_temp.train_id) AS k_nearest_neighbours '
+ knn_group_by = 'FROM pg_temp.' + interim_table + ' AS knn_temp join ' \
+ + test_source + ' AS knn_test ON knn_temp.' + test_id_temp + '= knn_test.' \
+ + test_id + ' GROUP BY ' + test_id_temp + ', ' + test_column_name
+
if is_classification:
- plpy.execute(
- """
- CREATE TABLE {output_table} AS
- SELECT {test_id} AS id, {test_column_name},
- {schema_madlib}.mode({label_col_temp}) AS prediction
- FROM {interim_table} JOIN {test_source}
- ON {test_id} = {id_column_name}
- GROUP BY {test_id}, {test_column_name}
- """.format(**locals()))
+ if output_neighbors:
+ plpy.execute("""{knn_create_table}{knn_pred_class}
+ {knn_neighbours}{knn_group_by}""".format(**locals()))
+ else:
+ plpy.execute(""" {knn_create_table}{knn_pred_class}
+ {knn_group_by}""".format(**locals()))
else:
- plpy.execute(
- """
- CREATE TABLE {output_table} AS
- SELECT {test_id} AS id, {test_column_name},
- AVG({label_col_temp}) AS prediction
- FROM
- {interim_table} JOIN {test_source}
- ON {test_id} = {id_column_name}
- GROUP BY {test_id}, {test_column_name}
- """.format(**locals()))
+ if output_neighbors:
+ plpy.execute(""" {knn_create_table}{knn_pred_reg}
+ {knn_neighbours}{knn_group_by}""".format(**locals()))
+ else:
+ plpy.execute("""{knn_create_table}{knn_pred_reg}
+ {knn_group_by}""".format(**locals()))
+
plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
# ------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/0a7efca7/src/ports/postgres/modules/knn/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in
index ca5be88..dfd2374 100644
--- a/src/ports/postgres/modules/knn/knn.sql_in
+++ b/src/ports/postgres/modules/knn/knn.sql_in
@@ -71,13 +71,14 @@ neighbors of the given test point.
<pre class="syntax">
knn( point_source,
point_column_name,
+ point_id,
label_column_name,
test_source,
test_column_name,
- id_column_name,
+ test_id,
output_table,
- operation,
- k
+ k,
+ output_neighbors
)
</pre>
@@ -93,8 +94,17 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
<dt>point_column_name</dt>
<dd>TEXT. Name of the column with training data points.</dd>
+<dt>point_id</dt>
+<dd>TEXT. Name of the column in 'point_source' containing source data ids.
+The ids are of type INTEGER with no duplicates. They do not need to be contiguous.
+This parameter must be used if the list of nearest neighbors is to be output, i.e.,
+if the parameter 'output_neighbors' below is TRUE or if 'label_column_name' is NULL.</dd>
+
<dt>label_column_name</dt>
-<dd>TEXT. Name of the column with labels/values of training data points.</dd>
+<dd>TEXT. Name of the column with labels/values of training data points.
+If the column is of type BOOLEAN, INTEGER or TEXT, KNN classification is performed;
+if it is DOUBLE PRECISION, KNN regression is performed.
+If it is set to NULL, only the nearest neighbors are returned, without classification or regression.</dd>
<dt>test_source</dt>
<dd>TEXT. Name of the table containing the test data points.
@@ -106,7 +116,7 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
<dt>test_column_name</dt>
<dd>TEXT. Name of the column with testing data points.</dd>
-<dt>id_column_name</dt>
+<dt>test_id</dt>
<dd>TEXT. Name of the column having ids of data points in test data table.</dd>
<dt>output_table</dt>
@@ -117,7 +127,12 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
<dt>k (optional)</dt>
<dd>INTEGER. default: 1. Number of nearest neighbors to consider.
-For classification, should be an odd number to break ties.</dd>
+For classification, k should be an odd number to break ties;
+otherwise the result may depend on the ordering of the input data.</dd>
+
+<dt>output_neighbors (optional) </dt>
+<dd>BOOLEAN. default: FALSE. Outputs the list of k-nearest
+neighbors that were used in the voting/averaging.</dd>
</dl>
@@ -145,15 +160,35 @@ The output of the KNN module is a table with the following columns:
@anchor examples
@examp
--# Prepare some training data:
+-# Prepare some training data for classification:
<pre class="example">
DROP TABLE IF EXISTS knn_train_data;
CREATE TABLE knn_train_data (
id integer,
data integer[],
- label float
+ label integer
);
INSERT INTO knn_train_data VALUES
+(1, '{1,1}', 1),
+(2, '{2,2}', 1),
+(3, '{3,3}', 1),
+(4, '{4,4}', 1),
+(5, '{4,5}', 1),
+(6, '{20,50}', 0),
+(7, '{10,31}', 0),
+(8, '{81,13}', 0),
+(9, '{1,111}', 0);
+</pre>
+
+-# Prepare some training data for regression:
+<pre class="example">
+DROP TABLE IF EXISTS knn_train_data_reg;
+CREATE TABLE knn_train_data_reg (
+ id integer,
+ data integer[],
+ label float
+ );
+INSERT INTO knn_train_data_reg VALUES
(1, '{1,1}', 1.0),
(2, '{2,2}', 1.0),
(3, '{3,3}', 1.0),
@@ -187,26 +222,27 @@ DROP TABLE IF EXISTS madlib_knn_result_classification;
SELECT * FROM madlib.knn(
'knn_train_data', -- Table of training data
'data', -- Col name of training data
+ 'id', -- Col Name of id in train data
'label', -- Training labels
'knn_test_data', -- Table of test data
'data', -- Col name of test data
'id', -- Col name of id in test data
'madlib_knn_result_classification', -- Output table
- 'c', -- Classification
- 3 -- Number of nearest neighbours
+ 3, -- Number of nearest neighbours
+ True -- True if you want to show Nearest-Neighbors, False otherwise
);
SELECT * from madlib_knn_result_classification ORDER BY id;
</pre>
Result:
<pre class="result">
- id | data | prediction
-----+---------+------------
- 1 | {2,1} | 1
- 2 | {2,6} | 1
- 3 | {15,40} | 0
- 4 | {12,1} | 1
- 5 | {2,90} | 0
- 6 | {50,45} | 0
+ id | data | prediction | k_nearest_neighbours
+----+---------+------------+----------------------
+ 1 | {2,1} | 1 | {1,2,3}
+ 2 | {2,6} | 1 | {5,4,3}
+ 3 | {15,40} | 0 | {7,6,5}
+ 4 | {12,1} | 1 | {4,5,3}
+ 5 | {2,90} | 0 | {9,6,7}
+ 6 | {50,45} | 0 | {6,7,8}
(6 rows)
</pre>
@@ -214,28 +250,29 @@ Result:
<pre class="example">
DROP TABLE IF EXISTS madlib_knn_result_regression;
SELECT * FROM madlib.knn(
- 'knn_train_data', -- Table of training data
+ 'knn_train_data_reg', -- Table of training data
'data', -- Col name of training data
+ 'id', -- Col Name of id in train data
'label', -- Training labels
'knn_test_data', -- Table of test data
'data', -- Col name of test data
'id', -- Col name of id in test data
'madlib_knn_result_regression', -- Output table
- 'r', -- Regressions
- 3 -- Number of nearest neighbours
+ 3, -- Number of nearest neighbours
+ True -- True if you want to show Nearest-Neighbors, False otherwise
);
SELECT * from madlib_knn_result_regression ORDER BY id;
</pre>
Result:
<pre class="result">
- id | data | prediction
-----+---------+-------------------
- 1 | {2,1} | 1
- 2 | {2,6} | 1
- 3 | {15,40} | 0.333333333333333
- 4 | {12,1} | 1
- 5 | {2,90} | 0
- 6 | {50,45} | 0
+ id | data | prediction | k_nearest_neighbours
+----+---------+-------------------+----------------------
+ 1 | {2,1} | 1 | {1,2,3}
+ 2 | {2,6} | 1 | {5,4,3}
+ 3 | {15,40} | 0.333333333333333 | {7,6,5}
+ 4 | {12,1} | 1 | {4,5,3}
+ 5 | {2,90} | 0 | {9,6,7}
+ 6 | {50,45} | 0 | {6,7,8}
(6 rows)
</pre>
@@ -281,7 +318,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.__knn_validate_src(
label_column_name VARCHAR,
test_source VARCHAR,
test_column_name VARCHAR,
- id_column_name VARCHAR,
+ test_id VARCHAR,
output_table VARCHAR,
operation VARCHAR,
k INTEGER
@@ -294,7 +331,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.__knn_validate_src(
label_column_name,
test_source,
test_column_name,
- id_column_name,
+ test_id,
output_table,
operation,
k
@@ -308,7 +345,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
) RETURNS VOID AS $$
BEGIN
IF arg1 = 'help' OR arg1 = 'usage' OR arg1 = '?' THEN
- RAISE NOTICE
+ RAISE NOTICE
'
-----------------------------------------------------------------------
USAGE
@@ -316,13 +353,14 @@ BEGIN
SELECT {schema_madlib}.knn(
point_source, -- Training data table having training features as vector column and labels
point_column_name, -- Name of column having feature vectors in training data table
+ point_id, -- Name of column having feature vector Ids in train data table
label_column_name, -- Name of column having actual label/vlaue for corresponding feature vector in training data table
test_source, -- Test data table having features as vector column. Id of features is mandatory
test_column_name, -- Name of column having feature vectors in test data table
- id_column_name, -- Name of column having feature vector Ids in test data table
+ test_id, -- Name of column having feature vector Ids in test data table
output_table, -- Name of output table
- operation, -- c for classification task, r for regression task
- k -- value of k. Default will go as 1
+ k, -- value of k. Default will go as 1
+ output_neighbors -- Outputs the list of k-nearest neighbors that were used in the voting/averaging.
);
-----------------------------------------------------------------------
@@ -333,6 +371,7 @@ The output of the KNN module is a table with the following columns:
id The ids of test data points.
test_column_name The test data points.
prediction The output of KNN- label in case of classification, average value in case of regression.
+k_nearest_neighbours The list of k-nearest neighbors that were used in the voting/averaging.
';
END IF;
END;
@@ -362,26 +401,28 @@ m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `READS SQL DATA', `');
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
point_source VARCHAR,
point_column_name VARCHAR,
+ point_id VARCHAR,
label_column_name VARCHAR,
test_source VARCHAR,
test_column_name VARCHAR,
- id_column_name VARCHAR,
+ test_id VARCHAR,
output_table VARCHAR,
- operation VARCHAR,
- k INTEGER
+ k INTEGER,
+ output_neighbors Boolean
) RETURNS VARCHAR AS $$
PythonFunctionBodyOnly(`knn', `knn')
return knn.knn(
schema_madlib,
point_source,
point_column_name,
+ point_id,
label_column_name,
test_source,
test_column_name,
- id_column_name,
+ test_id,
output_table,
- operation,
- k
+ k,
+ output_neighbors
)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -390,17 +431,18 @@ m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
point_source VARCHAR,
point_column_name VARCHAR,
+ point_id VARCHAR,
label_column_name VARCHAR,
test_source VARCHAR,
test_column_name VARCHAR,
- id_column_name VARCHAR,
+ test_id VARCHAR,
output_table VARCHAR,
- operation VARCHAR
+ output_neighbors Boolean
) RETURNS VARCHAR AS $$
DECLARE
returnstring VARCHAR;
BEGIN
- returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1);
+ returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,$9);
RETURN returnstring;
END;
$$ LANGUAGE plpgsql VOLATILE
http://git-wip-us.apache.org/repos/asf/madlib/blob/0a7efca7/src/ports/postgres/modules/knn/test/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in
index c7d6798..fa38751 100644
--- a/src/ports/postgres/modules/knn/test/knn.sql_in
+++ b/src/ports/postgres/modules/knn/test/knn.sql_in
@@ -26,12 +26,29 @@ m4_include(`SQLCommon.m4')
* FIXME: Verify results
* -------------------------------------------------------------------------- */
-drop table if exists "KNN_TRAIN_DATA";
-create table "KNN_TRAIN_DATA" (
- id integer,
- "DATA" integer[],
- label float);
-copy "KNN_TRAIN_DATA" (id, "DATA", label) from stdin delimiter '|';
+drop table if exists knn_train_data;
+create table knn_train_data (
+id integer,
+data integer[],
+label integer);
+copy knn_train_data (id, data, label) from stdin delimiter '|';
+1|{1,1}|1
+2|{2,2}|1
+3|{3,3}|1
+4|{4,4}|1
+5|{4,5}|1
+6|{20,50}|0
+7|{10,31}|0
+8|{81,13}|0
+9|{1,111}|0
+\.
+DROP TABLE IF EXISTS knn_train_data_reg;
+CREATE TABLE knn_train_data_reg (
+ id integer,
+ data integer[],
+ label float
+ );
+COPY knn_train_data_reg (id, data, label) from stdin delimiter '|';
1|{1,1}|1.0
2|{2,2}|1.0
3|{3,3}|1.0
@@ -42,11 +59,10 @@ copy "KNN_TRAIN_DATA" (id, "DATA", label) from stdin delimiter '|';
8|{81,13}|0.0
9|{1,111}|0.0
\.
-drop table if exists knn_test_data;
create table knn_test_data (
- id integer,
- "DATA" integer[]);
-copy knn_test_data (id, "DATA") from stdin delimiter '|';
+id integer,
+data integer[]);
+copy knn_test_data (id, data) from stdin delimiter '|';
1|{2,1}
2|{2,6}
3|{15,40}
@@ -55,15 +71,23 @@ copy knn_test_data (id, "DATA") from stdin delimiter '|';
6|{50,45}
\.
drop table if exists madlib_knn_result_classification;
-select knn('"KNN_TRAIN_DATA"','"DATA"','label','knn_test_data','"DATA"','id','madlib_knn_result_classification','c',3);
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False);
select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,True);
+select assert(array_agg(x)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_classification where id = 1 order by x asc) y;
+
drop table if exists madlib_knn_result_regression;
-select knn('"KNN_TRAIN_DATA"','"DATA"','label','knn_test_data','"DATA"','id','madlib_knn_result_regression','r',4);
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False);
select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True);
+select assert(array_agg(x)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_regression where id = 1 order by x asc) y;
+
drop table if exists madlib_knn_result_classification;
-select knn('"KNN_TRAIN_DATA"','"DATA"','label','knn_test_data','"DATA"','id','madlib_knn_result_classification','c');
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',False);
select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=1') from madlib_knn_result_classification;
select knn();