Posted to commits@madlib.apache.org by ok...@apache.org on 2017/10/12 21:47:30 UTC

madlib git commit: KNN: Provide additional output information

Repository: madlib
Updated Branches:
  refs/heads/master cfc54b668 -> 0a7efca73


KNN: Provide additional output information

JIRA: MADLIB-1129

Closes #184
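
For reference, the interface change in this commit: knn() now takes a train-table id
column (point_id) and an output_neighbors flag, and the explicit operation argument
('c'/'r') is gone; classification vs. regression is inferred from the type of
label_column_name. A call that mirrors the classification example in knn.sql_in below
(table and column names are the ones used in that example) would look like:

DROP TABLE IF EXISTS madlib_knn_result_classification;
SELECT * FROM madlib.knn(
                'knn_train_data',      -- table of training data
                'data',                -- column with training features
                'id',                  -- id column in training table
                'label',               -- integer label, so classification
                'knn_test_data',       -- table of test data
                'data',                -- column with test features
                'id',                  -- id column in test table
                'madlib_knn_result_classification',  -- output table
                3,                     -- number of nearest neighbours
                True                   -- also output the k nearest neighbour ids
              );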


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/0a7efca7
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/0a7efca7
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/0a7efca7

Branch: refs/heads/master
Commit: 0a7efca73bc7d38a60d92b2d5c196d7c449d9525
Parents: cfc54b6
Author: hpandeycodeit <hp...@pivotal.io>
Authored: Thu Oct 12 14:44:40 2017 -0700
Committer: Orhan Kislal <ok...@pivotal.io>
Committed: Thu Oct 12 14:46:39 2017 -0700

----------------------------------------------------------------------
 src/ports/postgres/modules/knn/knn.py_in       | 151 ++++++++++++++------
 src/ports/postgres/modules/knn/knn.sql_in      | 128 +++++++++++------
 src/ports/postgres/modules/knn/test/knn.sql_in |  50 +++++--
 3 files changed, 227 insertions(+), 102 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/0a7efca7/src/ports/postgres/modules/knn/knn.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.py_in b/src/ports/postgres/modules/knn/knn.py_in
index 4d5d627..477c18d 100644
--- a/src/ports/postgres/modules/knn/knn.py_in
+++ b/src/ports/postgres/modules/knn/knn.py_in
@@ -36,19 +36,20 @@ from utilities.utilities import unique_string
 from utilities.control import MinWarning
 
 
-def knn_validate_src(schema_madlib, point_source, point_column_name,
+def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
                      label_column_name, test_source, test_column_name,
-                     id_column_name, output_table, operation, k, **kwargs):
-
-    if not operation or operation not in ['c', 'r']:
-        plpy.error("kNN Error: operation='{0}' is an invalid value, has to be"
-                   " 'r' for regression OR 'c' for classification.".
-                   format(operation))
+                     test_id, output_table, k, output_neighbors, **kwargs):
     input_tbl_valid(point_source, 'kNN')
     input_tbl_valid(test_source, 'kNN')
     output_tbl_valid(output_table, 'kNN')
-    cols_in_tbl_valid(point_source, (label_column_name, point_column_name), 'kNN')
-    cols_in_tbl_valid(test_source, (test_column_name, id_column_name), 'kNN')
+    if label_column_name is not None and label_column_name != '':
+        cols_in_tbl_valid(
+            point_source,
+            (label_column_name,
+             point_column_name),
+            'kNN')
+    cols_in_tbl_valid(point_source, (point_column_name, point_id), 'kNN')
+    cols_in_tbl_valid(test_source, (test_column_name, test_id), 'kNN')
 
     if not is_col_array(point_source, point_column_name):
         plpy.error("kNN Error: Feature column '{0}' in train table is not"
@@ -75,30 +76,32 @@ def knn_validate_src(schema_madlib, point_source, point_column_name,
         plpy.error("kNN Error: k={0} is greater than number of rows in"
                    " training table.".format(k))
 
-    col_type = get_expr_type(label_column_name, point_source).lower()
-    if col_type not in ['integer', 'double precision', 'float', 'boolean']:
-        plpy.error("kNN error: Data type '{0}' is not a valid type for"
-                   " column '{1}' in table '{2}'.".
-                   format(col_type, label_column_name, point_source))
+    if label_column_name is not None and label_column_name != '':
+        col_type = get_expr_type(label_column_name, point_source).lower()
+        if col_type not in ['integer', 'double precision', 'float', 'boolean']:
+            plpy.error("kNN error: Data type '{0}' is not a valid type for"
+                       " column '{1}' in table '{2}'.".
+                       format(col_type, label_column_name, point_source))
 
-    col_type_test = get_expr_type(id_column_name, test_source).lower()
+    col_type_test = get_expr_type(test_id, test_source).lower()
     if col_type_test not in ['integer']:
         plpy.error("kNN Error: Data type '{0}' is not a valid type for"
                    " column '{1}' in table '{2}'.".
-                   format(col_type_test, id_column_name, test_source))
+                   format(col_type_test, test_id, test_source))
     return k
 # ------------------------------------------------------------------------------
 
 
-def knn(schema_madlib, point_source, point_column_name, label_column_name,
-        test_source, test_column_name, id_column_name, output_table,
-        operation, k):
+def knn(schema_madlib, point_source, point_column_name, point_id, label_column_name,
+        test_source, test_column_name, test_id, output_table, k, output_neighbors):
     """
         KNN function to find the K Nearest neighbours
         Args:
             @param schema_madlib        Name of the Madlib Schema
             @param point_source         Training data table
             @param point_column_name    Name of the column with training data
                                         points.
+            @param point_id             Name of the column having ids of data
+                                        points in the training data table.
             @param label_column_name    Name of the column with labels/values
                                         of training data points.
@@ -106,7 +109,7 @@ def knn(schema_madlib, point_source, point_column_name, label_column_name,
                                         data points.
             @param test_column_name     Name of the column with testing data
                                         points.
-            @param id_column_name       Name of the column having ids of data
+            @param test_id              Name of the column having ids of data
                                         points in test data table.
             @param output_table         Name of the table to store final
                                         results.
@@ -115,33 +118,85 @@ def knn(schema_madlib, point_source, point_column_name, label_column_name,
                                         'r' for regression
             @param k                    default: 1. Number of nearest
                                         neighbors to consider
+            @param output_neighbors     Outputs the list of k-nearest neighbors
+                                        that were used in the voting/averaging.
         Returns:
             VARCHAR                     Name of the output table.
     """
     with MinWarning('warning'):
         k_val = knn_validate_src(schema_madlib, point_source,
-                                 point_column_name, label_column_name,
-                                 test_source, test_column_name, id_column_name,
-                                 output_table, operation, k)
+                                 point_column_name, point_id, label_column_name,
+                                 test_source, test_column_name, test_id,
+                                 output_table, k, output_neighbors)
 
         x_temp_table = unique_string(desp='x_temp_table')
         y_temp_table = unique_string(desp='y_temp_table')
         label_col_temp = unique_string(desp='label_col_temp')
-        test_id = unique_string(desp='test_id')
+        test_id_temp = unique_string(desp='test_id_temp')
+
+        if output_neighbors is None or output_neighbors == '':
+            output_neighbors = False
 
-        is_classification = operation == 'c'
         interim_table = unique_string(desp='interim_table')
+
+        if label_column_name is None or label_column_name == '':
+            plpy.execute(
+                """
+                CREATE TEMP TABLE {interim_table} AS
+                SELECT * FROM
+                    (
+                    SELECT row_number() over
+                            (partition by {test_id_temp} order by dist) AS r,
+                            {x_temp_table}.*
+                    FROM
+                        (
+                        SELECT test.{test_id} AS {test_id_temp} ,
+                            train.id as train_id ,
+                            {schema_madlib}.squared_dist_norm2(
+                                train.{point_column_name},
+                                test.{test_column_name})
+                            AS dist
+                            FROM {point_source} AS train, {test_source} AS test
+                        ) {x_temp_table}
+                    ) {y_temp_table}
+                WHERE {y_temp_table}.r <= {k_val}
+                """.format(**locals()))
+            plpy.execute(
+                """
+                CREATE TABLE {output_table} AS
+                    SELECT {test_id_temp} AS id, {test_column_name} ,
+                        CASE WHEN {output_neighbors}
+                        THEN array_agg(knn_temp.train_id)
+                        ELSE NULL END  AS k_nearest_neighbours
+                    FROM pg_temp.{interim_table} AS knn_temp
+                    join
+                        {test_source} AS knn_test ON
+                        knn_temp.{test_id_temp} = knn_test.{test_id}
+                    GROUP BY {test_id_temp}  ,  {test_column_name}
+                """.format(**locals()))
+            return
+
+        is_classification = False
+        label_column_type = get_expr_type(
+            label_column_name, point_source).lower()
+        if label_column_type in ['boolean', 'integer', 'text']:
+            is_classification = True
+            convert_boolean_to_int = '::INTEGER'
+        else:
+            is_classification = False
+
         plpy.execute(
             """
             CREATE TEMP TABLE {interim_table} AS
             SELECT * FROM
                 (
                 SELECT row_number() over
-                        (partition by {test_id} order by dist) AS r,
+                        (partition by {test_id_temp} order by dist) AS r,
                         {x_temp_table}.*
                 FROM
                     (
-                    SELECT test.{id_column_name} AS {test_id} ,
+                    SELECT test.{test_id} AS {test_id_temp} ,
+                        train.id as train_id ,
                         {schema_madlib}.squared_dist_norm2(
                             train.{point_column_name},
                             test.{test_column_name})
@@ -155,26 +210,30 @@ def knn(schema_madlib, point_source, point_column_name, label_column_name,
             """.format(cast_to_int='::INTEGER' if is_classification else '',
                        **locals()))
 
+        knn_create_table = 'CREATE TABLE ' + output_table + ' AS '  \
+            'SELECT ' + test_id_temp + ' AS id,' + test_column_name + ','
+        knn_pred_class = schema_madlib + \
+            '.mode(' + label_col_temp + ') AS prediction'
+        knn_pred_reg = 'avg(' + label_col_temp + ') AS prediction'
+        knn_neighbours = ', array_agg(knn_temp.train_id) AS k_nearest_neighbours '
+        knn_group_by = 'FROM pg_temp.' + interim_table + ' AS knn_temp join ' \
+            + test_source + ' AS knn_test ON  knn_temp.' + test_id_temp + '= knn_test.' \
+            + test_id + ' GROUP BY ' + test_id_temp + ', ' + test_column_name
+
         if is_classification:
-            plpy.execute(
-                """
-                CREATE TABLE {output_table} AS
-                SELECT {test_id} AS id, {test_column_name},
-                       {schema_madlib}.mode({label_col_temp}) AS prediction
-                FROM {interim_table} JOIN {test_source}
-                     ON {test_id} = {id_column_name}
-                GROUP BY {test_id}, {test_column_name}
-                """.format(**locals()))
+            if output_neighbors:
+                plpy.execute("""{knn_create_table}{knn_pred_class}
+                    {knn_neighbours}{knn_group_by}""".format(**locals()))
+            else:
+                plpy.execute(""" {knn_create_table}{knn_pred_class}
+                    {knn_group_by}""".format(**locals()))
         else:
-            plpy.execute(
-                """
-                CREATE TABLE {output_table} AS
-                SELECT {test_id} AS id, {test_column_name},
-                       AVG({label_col_temp}) AS prediction
-                FROM
-                    {interim_table} JOIN {test_source}
-                    ON {test_id} = {id_column_name}
-                GROUP BY {test_id}, {test_column_name}
-                """.format(**locals()))
+            if output_neighbors:
+                plpy.execute(""" {knn_create_table}{knn_pred_reg}
+                    {knn_neighbours}{knn_group_by}""".format(**locals()))
+            else:
+                plpy.execute("""{knn_create_table}{knn_pred_reg}
+                    {knn_group_by}""".format(**locals()))
+
         plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
 # ------------------------------------------------------------------------------
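
The new branch above (taken when label_column_name is NULL or empty) only ranks the
training ids by distance and skips prediction entirely. A sketch of how that mode
would be invoked, using the example tables from the documentation below; the output
table name here is illustrative:

DROP TABLE IF EXISTS madlib_knn_result_neighbors_only;
SELECT * FROM madlib.knn(
                'knn_train_data',                    -- table of training data
                'data',                              -- column with training features
                'id',                                -- id column in training table
                NULL,                                -- no label: return neighbours only
                'knn_test_data',                     -- table of test data
                'data',                              -- column with test features
                'id',                                -- id column in test table
                'madlib_knn_result_neighbors_only',  -- output table
                3,                                   -- number of nearest neighbours
                True                                 -- list the k nearest neighbour ids
              );
SELECT * FROM madlib_knn_result_neighbors_only ORDER BY id;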

http://git-wip-us.apache.org/repos/asf/madlib/blob/0a7efca7/src/ports/postgres/modules/knn/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/knn.sql_in b/src/ports/postgres/modules/knn/knn.sql_in
index ca5be88..dfd2374 100644
--- a/src/ports/postgres/modules/knn/knn.sql_in
+++ b/src/ports/postgres/modules/knn/knn.sql_in
@@ -71,13 +71,14 @@ neighbors of the given test point.
 <pre class="syntax">
 knn( point_source,
      point_column_name,
+     point_id,
      label_column_name,
      test_source,
      test_column_name,
-     id_column_name,
+     test_id,
      output_table,
-     operation,
-     k
+     k,
+     output_neighbors
    )
 </pre>
 
@@ -93,8 +94,17 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
 <dt>point_column_name</dt>
 <dd>TEXT. Name of the column with training data points.</dd>
 
+<dt>point_id</dt>
+<dd>TEXT. Name of the column in 'point_source' containing source data ids.
+The ids are of type INTEGER with no duplicates; they do not need to be contiguous.
+This parameter is required if the list of nearest neighbors is to be output, i.e.,
+if the parameter 'output_neighbors' below is TRUE or if 'label_column_name' is NULL.</dd>
+
 <dt>label_column_name</dt>
-<dd>TEXT. Name of the column with labels/values of training data points.</dd>
+<dd>TEXT. Name of the column with labels/values of training data points.
+Boolean, integer, and text label types run KNN classification, while
+double precision values run KNN regression.
+If set to NULL, only the nearest neighbors are returned, without classification or regression.</dd>
 
 <dt>test_source</dt>
 <dd>TEXT. Name of the table containing the test data points.
@@ -106,7 +116,7 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
 <dt>test_column_name</dt>
 <dd>TEXT. Name of the column with testing data points.</dd>
 
-<dt>id_column_name</dt>
+<dt>test_id</dt>
 <dd>TEXT. Name of the column having ids of data points in test data table.</dd>
 
 <dt>output_table</dt>
@@ -117,7 +127,12 @@ in a column of type <tt>DOUBLE PRECISION[]</tt>.
 
 <dt>k (optional)</dt>
 <dd>INTEGER. default: 1. Number of nearest neighbors to consider.
-For classification, should be an odd number to break ties.</dd>
+For classification, it should be an odd number to break ties;
+otherwise the result may depend on the ordering of the input data.</dd>
+
+<dt>output_neighbors (optional)</dt>
+<dd>BOOLEAN. default: FALSE. Outputs the list of k-nearest
+neighbors that were used in the voting/averaging.</dd>
 
 </dl>
 
@@ -145,15 +160,35 @@ The output of the KNN module is a table with the following columns:
 @anchor examples
 @examp
 
--#  Prepare some training data:
+-#  Prepare some training data for classification:
 <pre class="example">
 DROP TABLE IF EXISTS knn_train_data;
 CREATE TABLE knn_train_data (
                     id integer, 
                     data integer[], 
-                    label float
+                    label integer
                     );
 INSERT INTO knn_train_data VALUES
+(1, '{1,1}', 1),
+(2, '{2,2}', 1),
+(3, '{3,3}', 1),
+(4, '{4,4}', 1),
+(5, '{4,5}', 1),
+(6, '{20,50}', 0),
+(7, '{10,31}', 0),
+(8, '{81,13}', 0),
+(9, '{1,111}', 0);
+</pre>
+
+-#  Prepare some training data for regression:
+<pre class="example">
+DROP TABLE IF EXISTS knn_train_data_reg;
+CREATE TABLE knn_train_data_reg (
+                    id integer, 
+                    data integer[], 
+                    label float
+                    );
+INSERT INTO knn_train_data_reg VALUES
 (1, '{1,1}', 1.0),
 (2, '{2,2}', 1.0),
 (3, '{3,3}', 1.0),
@@ -187,26 +222,27 @@ DROP TABLE IF EXISTS madlib_knn_result_classification;
 SELECT * FROM madlib.knn( 
                 'knn_train_data',      -- Table of training data
                 'data',                -- Col name of training data
+                'id',                  -- Col name of id in train data
                 'label',               -- Training labels
                 'knn_test_data',       -- Table of test data
                 'data',                -- Col name of test data
                 'id',                  -- Col name of id in test data 
                 'madlib_knn_result_classification',  -- Output table
-                'c',                   -- Classification
-                 3                     -- Number of nearest neighbours
+                 3,                    -- Number of nearest neighbours
+                 True                  -- True to show the nearest neighbors, False otherwise
                 );
 SELECT * from madlib_knn_result_classification ORDER BY id;
 </pre>
 Result:
 <pre class="result">
- id |  data   | prediction 
-----+---------+------------
-  1 | {2,1}   |          1
-  2 | {2,6}   |          1
-  3 | {15,40} |          0
-  4 | {12,1}  |          1
-  5 | {2,90}  |          0
-  6 | {50,45} |          0
+ id |  data   | prediction | k_nearest_neighbours 
+----+---------+------------+----------------------
+  1 | {2,1}   |          1 | {1,2,3}
+  2 | {2,6}   |          1 | {5,4,3}
+  3 | {15,40} |          0 | {7,6,5}
+  4 | {12,1}  |          1 | {4,5,3}
+  5 | {2,90}  |          0 | {9,6,7}
+  6 | {50,45} |          0 | {6,7,8}
 (6 rows)
 </pre>
 
@@ -214,28 +250,29 @@ Result:
 <pre class="example">
 DROP TABLE IF EXISTS madlib_knn_result_regression;
 SELECT * FROM madlib.knn( 
-                'knn_train_data',      -- Table of training data
+                'knn_train_data_reg',  -- Table of training data
                 'data',                -- Col name of training data
+                'id',                  -- Col name of id in train data
                 'label',               -- Training labels
                 'knn_test_data',       -- Table of test data
                 'data',                -- Col name of test data
                 'id',                  -- Col name of id in test data 
                 'madlib_knn_result_regression',  -- Output table
-                'r',                   -- Regressions
-                 3                     -- Number of nearest neighbours
+                 3,                    -- Number of nearest neighbours
+                 True                  -- True to show the nearest neighbors, False otherwise
                 );
 SELECT * from madlib_knn_result_regression ORDER BY id;
 </pre>
 Result:
 <pre class="result">
- id |  data   |    prediction     
-----+---------+-------------------
-  1 | {2,1}   |                 1
-  2 | {2,6}   |                 1
-  3 | {15,40} | 0.333333333333333
-  4 | {12,1}  |                 1
-  5 | {2,90}  |                 0
-  6 | {50,45} |                 0
+ id |  data   |    prediction     | k_nearest_neighbours 
+----+---------+-------------------+----------------------
+  1 | {2,1}   |                 1 | {1,2,3}
+  2 | {2,6}   |                 1 | {5,4,3}
+  3 | {15,40} | 0.333333333333333 | {7,6,5}
+  4 | {12,1}  |                 1 | {4,5,3}
+  5 | {2,90}  |                 0 | {9,6,7}
+  6 | {50,45} |                 0 | {6,7,8}
 (6 rows)
 </pre>
 
@@ -281,7 +318,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.__knn_validate_src(
     label_column_name VARCHAR,
     test_source VARCHAR,
     test_column_name VARCHAR,
-    id_column_name VARCHAR,
+    test_id VARCHAR,
     output_table VARCHAR,
     operation VARCHAR,
     k INTEGER
@@ -294,7 +331,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.__knn_validate_src(
         label_column_name,
         test_source,
         test_column_name,
-        id_column_name,
+        test_id,
         output_table,
         operation,
         k
@@ -308,7 +345,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
 ) RETURNS VOID AS $$
 BEGIN
     IF arg1 = 'help' OR arg1 = 'usage' OR arg1 = '?' THEN
-	RAISE NOTICE
+    RAISE NOTICE
 '
 -----------------------------------------------------------------------
                             USAGE
@@ -316,13 +353,14 @@ BEGIN
 SELECT {schema_madlib}.knn(
     point_source,       -- Training data table having training features as vector column and labels
     point_column_name,  -- Name of column having feature vectors in training data table
+    point_id,           -- Name of column having feature vector Ids in train data table
     label_column_name,  -- Name of column having actual label/value for corresponding feature vector in training data table
     test_source,        -- Test data table having features as vector column. Id of features is mandatory
     test_column_name,   -- Name of column having feature vectors in test data table
-    id_column_name,     -- Name of column having feature vector Ids in test data table
+    test_id,            -- Name of column having feature vector Ids in test data table
     output_table,       -- Name of output table
-    operation,          -- c for classification task, r for regression task
-    k                   -- value of k. Default will go as 1
+    k,                  -- value of k. Default will go as 1
+    output_neighbors    -- Outputs the list of k-nearest neighbors that were used in the voting/averaging.
     );
 
 -----------------------------------------------------------------------
@@ -333,6 +371,7 @@ The output of the KNN module is a table with the following columns:
 id                  The ids of test data points.
 test_column_name    The test data points.
 prediction          The output of KNN- label in case of classification, average value in case of regression.
+k_nearest_neighbours The list of k-nearest neighbors that were used in the voting/averaging.
 ';
     END IF;
 END;
@@ -362,26 +401,28 @@ m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `READS SQL DATA', `');
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
     point_source VARCHAR,
     point_column_name VARCHAR,
+    point_id VARCHAR,
     label_column_name VARCHAR,
     test_source VARCHAR,
     test_column_name VARCHAR,
-    id_column_name VARCHAR,
+    test_id VARCHAR,
     output_table VARCHAR,
-    operation VARCHAR,
-    k INTEGER
+    k INTEGER,
+    output_neighbors BOOLEAN
 ) RETURNS VARCHAR AS $$
     PythonFunctionBodyOnly(`knn', `knn')
     return knn.knn(
         schema_madlib,
         point_source,
         point_column_name,
+        point_id,
         label_column_name,
         test_source,
         test_column_name,
-        id_column_name,
+        test_id,
         output_table,
-        operation,
-        k
+        k,
+        output_neighbors
     )
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -390,17 +431,18 @@ m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.knn(
     point_source VARCHAR,
     point_column_name VARCHAR,
+    point_id VARCHAR,
     label_column_name VARCHAR,
     test_source VARCHAR,
     test_column_name VARCHAR,
-    id_column_name VARCHAR,
+    test_id VARCHAR,
     output_table VARCHAR,
-    operation VARCHAR
+    output_neighbors BOOLEAN
 ) RETURNS VARCHAR AS $$
 DECLARE
     returnstring VARCHAR;
 BEGIN
-    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1);
+    returnstring = MADLIB_SCHEMA.knn($1,$2,$3,$4,$5,$6,$7,$8,1,$9);
     RETURN returnstring;
 END;
 $$ LANGUAGE plpgsql VOLATILE
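
The shorter overload just above forwards to the full function with k fixed at 1, so a
call that omits k (again using the example tables from the docs; the output table name
is illustrative) would look like:

DROP TABLE IF EXISTS madlib_knn_result_k1;
SELECT * FROM madlib.knn(
                'knn_train_data',        -- table of training data
                'data',                  -- column with training features
                'id',                    -- id column in training table
                'label',                 -- training labels
                'knn_test_data',         -- table of test data
                'data',                  -- column with test features
                'id',                    -- id column in test table
                'madlib_knn_result_k1',  -- output table
                False                    -- neighbour list not requested; k defaults to 1
              );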

http://git-wip-us.apache.org/repos/asf/madlib/blob/0a7efca7/src/ports/postgres/modules/knn/test/knn.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/knn/test/knn.sql_in b/src/ports/postgres/modules/knn/test/knn.sql_in
index c7d6798..fa38751 100644
--- a/src/ports/postgres/modules/knn/test/knn.sql_in
+++ b/src/ports/postgres/modules/knn/test/knn.sql_in
@@ -26,12 +26,29 @@ m4_include(`SQLCommon.m4')
  * FIXME: Verify results
  * -------------------------------------------------------------------------- */
 
-drop table if exists "KNN_TRAIN_DATA";
-create table "KNN_TRAIN_DATA" (
-                id  integer,
-                "DATA"    integer[],
-                label   float);
-copy "KNN_TRAIN_DATA" (id, "DATA", label) from stdin delimiter '|';
+drop table if exists knn_train_data;
+create table knn_train_data (
+id  integer,
+data    integer[],
+label   integer);
+copy knn_train_data (id, data, label) from stdin delimiter '|';
+1|{1,1}|1
+2|{2,2}|1
+3|{3,3}|1
+4|{4,4}|1
+5|{4,5}|1
+6|{20,50}|0
+7|{10,31}|0
+8|{81,13}|0
+9|{1,111}|0
+\.
+DROP TABLE IF EXISTS knn_train_data_reg;
+CREATE TABLE knn_train_data_reg (
+                    id integer, 
+                    data integer[], 
+                    label float
+                    );
+COPY knn_train_data_reg (id, data, label) from stdin delimiter '|';
 1|{1,1}|1.0
 2|{2,2}|1.0
 3|{3,3}|1.0
@@ -42,11 +59,10 @@ copy "KNN_TRAIN_DATA" (id, "DATA", label) from stdin delimiter '|';
 8|{81,13}|0.0
 9|{1,111}|0.0
 \.
-drop table if exists knn_test_data;
 create table knn_test_data (
-                id  integer,
-                "DATA" integer[]);
-copy knn_test_data (id, "DATA") from stdin delimiter '|';
+id  integer,
+data integer[]);
+copy knn_test_data (id, data) from stdin delimiter '|';
 1|{2,1}
 2|{2,6}
 3|{15,40}
@@ -55,15 +71,23 @@ copy knn_test_data (id, "DATA") from stdin delimiter '|';
 6|{50,45}
 \.
 drop table if exists madlib_knn_result_classification;
-select knn('"KNN_TRAIN_DATA"','"DATA"','label','knn_test_data','"DATA"','id','madlib_knn_result_classification','c',3);
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,False);
 select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=3') from madlib_knn_result_classification;
 
+drop table if exists madlib_knn_result_classification;
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',3,True);
+select assert(array_agg(x)= '{1,2,3}','Wrong output in classification with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_classification where id = 1 order by x asc) y;
+
 drop table if exists madlib_knn_result_regression;
-select knn('"KNN_TRAIN_DATA"','"DATA"','label','knn_test_data','"DATA"','id','madlib_knn_result_regression','r',4);
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',4,False);
 select assert(array_agg(prediction order by id)='{1,1,0.5,1,0.25,0.25}', 'Wrong output in regression') from madlib_knn_result_regression;
 
+drop table if exists madlib_knn_result_regression;
+select knn('knn_train_data_reg','data','id','label','knn_test_data','data','id','madlib_knn_result_regression',3,True);
+select assert(array_agg(x)= '{1,2,3}' , 'Wrong output in regression with k=3') from (select unnest(k_nearest_neighbours) as x from madlib_knn_result_regression where id = 1 order by x asc) y;
+
 drop table if exists madlib_knn_result_classification;
-select knn('"KNN_TRAIN_DATA"','"DATA"','label','knn_test_data','"DATA"','id','madlib_knn_result_classification','c');
+select knn('knn_train_data','data','id','label','knn_test_data','data','id','madlib_knn_result_classification',False);
 select assert(array_agg(prediction order by id)='{1,1,0,1,0,0}', 'Wrong output in classification with k=1') from madlib_knn_result_classification;
 
 select knn();
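
A simplified variant of the asserts above, for inspecting the new column by hand on a
result table created with output_neighbors set to True (such as the
madlib_knn_result_classification produced by the True call above); the order of ids
within k_nearest_neighbours is not guaranteed, so they are sorted before reading:

SELECT id,
       ARRAY(SELECT unnest(k_nearest_neighbours) ORDER BY 1) AS sorted_neighbours
FROM madlib_knn_result_classification
ORDER BY id;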