You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ri...@apache.org on 2015/12/16 01:13:43 UTC

incubator-madlib git commit: SVM: Fix how grouping cols are validated

Repository: incubator-madlib
Updated Branches:
  refs/heads/master afc9c2483 -> 30e92868e


SVM: Fix how grouping cols are validated

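SVM prediction expects grouping_col to be None when no grouping is
performed, but the training function did not enforce that assumption.
Grouping-column validation now skips the string 'NULL' (case-insensitive),
and svm_predict runs the same validation on the new data table.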


Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/30e92868
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/30e92868
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/30e92868

Branch: refs/heads/master
Commit: 30e92868e5ad0e8fff2de0c7b89b8d6307107108
Parents: afc9c24
Author: Xiaocheng Tang <xt...@pivotal.io>
Authored: Tue Dec 15 16:03:48 2015 -0800
Committer: Rahul Iyer <ri...@pivotal.io>
Committed: Tue Dec 15 16:13:12 2015 -0800

----------------------------------------------------------------------
 src/ports/postgres/modules/svm/svm.py_in        | 19 ++---
 .../postgres/modules/svm/test/linear_svm.sql_in | 78 ++++++++++----------
 2 files changed, 49 insertions(+), 48 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/30e92868/src/ports/postgres/modules/svm/svm.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/svm/svm.py_in b/src/ports/postgres/modules/svm/svm.py_in
index 3cdfb5f..25a1d92 100644
--- a/src/ports/postgres/modules/svm/svm.py_in
+++ b/src/ports/postgres/modules/svm/svm.py_in
@@ -92,7 +92,7 @@ def _verify_table(source_table, model_table, dependent_varname,
 
 
 def _verify_grouping(schema_madlib, source_table, grouping_col):
-    if grouping_col:
+    if grouping_col and grouping_col.lower() != 'null':
         cols_in_tbl_valid(source_table,
                           _string_to_array_with_quotes(grouping_col),
                           'SVM')
@@ -330,8 +330,8 @@ def _svm_parsed_params(schema_madlib, source_table, model_table,
     """
     Executes the linear support vector algorithm.
     """
-    grouping_str, grouping_col = _verify_grouping(schema_madlib, 
-                                                  source_table, 
+    grouping_str, grouping_col = _verify_grouping(schema_madlib,
+                                                  source_table,
                                                   grouping_col)
 
     kernel_func = _verify_kernel(kernel_func)
@@ -413,9 +413,9 @@ def svm_predict(schema_madlib, model_table, new_data_table, id_col_name,
         grouping_col = summary['grouping_col']
 
         input_tbl_valid(new_data_table, 'SVM')
-        _assert(is_var_valid(new_data_table, dependent_varname),
-                "SVM Error: dependent_varname ('" + dependent_varname +
-                "') is invalid for new_data_table (" + new_data_table + ")!")
+        grouping_str, grouping_col = _verify_grouping(schema_madlib,
+                                                      new_data_table,
+                                                      grouping_col)
         _assert(is_var_valid(new_data_table, independent_varname),
                 "SVM Error: independent_varname ('" + independent_varname +
                 "') is invalid for new_data_table (" + new_data_table + ")!")
@@ -446,16 +446,17 @@ def svm_predict(schema_madlib, model_table, new_data_table, id_col_name,
             plpy.error("SVM Error: Invalid 'method' value in summary table. "
                        "'method' can only be SVC or SVR!")
 
-        if grouping_col != "NULL":
+        if grouping_col:
             sql = """
             CREATE TABLE {output_table} AS
             SELECT
                 {id_col_name} AS {id_col_name},
                 {pred_query} AS prediction,
-                {model_table}.{grouping_col} as grouping_col
+                ARRAY[{grouping_str}] as grouping_col,
+                {grouping_col}
             FROM {model_table}
             JOIN {new_data_table}
-            ON {model_table}.{grouping_col} = {new_data_table}.{grouping_col}
+            USING ({grouping_col})
             WHERE not {schema_madlib}.array_contains_null({independent_varname})
             ORDER BY grouping_col, {id_col_name}
             """.format(**locals())

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/30e92868/src/ports/postgres/modules/svm/test/linear_svm.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/svm/test/linear_svm.sql_in b/src/ports/postgres/modules/svm/test/linear_svm.sql_in
index bba6567..e94a90a 100644
--- a/src/ports/postgres/modules/svm/test/linear_svm.sql_in
+++ b/src/ports/postgres/modules/svm/test/linear_svm.sql_in
@@ -110,19 +110,19 @@ SELECT svm_regression(
      NULL,
      'init_stepsize=0.01, max_iter=50, lambda=2, norm=l2, epsilon=0.01',
      false);
-DROP TABLE IF EXISTS svr_test_result; 
+DROP TABLE IF EXISTS svr_test_result;
 SELECT svm_predict('svr_model', 'svr_train_data', 'id', 'svr_test_result');
 \x on
 SELECT * FROM svr_model;
 \x off
-SELECT 
+SELECT
     assert(
            avg(subq.err) < 0.1,
            'prediction error is too large!')
-FROM 
+FROM
     (
-        SELECT 
-            train.id, 
+        SELECT
+            train.id,
             abs(train.label - test.prediction) AS err
         FROM svr_train_data AS train, svr_test_result AS test
         WHERE train.id = test.id
@@ -134,13 +134,13 @@ SELECT svm_regression(
      'svr_train_data',
      'svr_model',
      'label',
-     'ind', 
-     NULL, 
-     NULL, 
-     NULL, 
+     'ind',
+     NULL,
+     NULL,
+     NULL,
      'init_stepsize=1, max_iter=10, lambda=2');
-SELECT 
-    assert(epsilon > 0,'default epsilon is positive!') 
+SELECT
+    assert(epsilon > 0,'default epsilon is positive!')
 FROM svr_model_summary;
 
 -- Example usage for LINEAR classification, replace the above by
@@ -389,7 +389,7 @@ SELECT assert(count(*)=4, '4 group exist') FROM svm_model_fancy_label;
 
 DROP TABLE IF EXISTS svm_test_predict CASCADE;
 SELECT svm_predict('svm_model_fancy_label', 'svm_test_normalized_fancy_label', 'id', 'svm_test_predict');
-SELECT o.id, label, prediction, gid FROM svm_test_predict p, svm_test_normalized_fancy_label o where o.id = p.id;
+SELECT o.id, label, prediction, o.gid FROM svm_test_predict p, svm_test_normalized_fancy_label o where o.id = p.id;
 
 -- calculating accuracy
 -- the accuracy is not guaranteed to be high because the stepsize & decay_factor
@@ -554,78 +554,78 @@ INSERT INTO abalone_eps(sex, epsilon) VALUES
 
 -- solve it with grouping and table of epsilon as inputs
 
-DROP TABLE IF EXISTS svr_mdl, svr_mdl_summary; 
+DROP TABLE IF EXISTS svr_mdl, svr_mdl_summary;
 SELECT madlib.svm_regression(
         'abalone_train_small',
-        'svr_mdl', 
+        'svr_mdl',
         'rings',
         'ARRAY[1,diameter,shell,shucked,length]',
-        NULL,NULL,'sex', 
-        'max_iter=50, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 50, eps_table=abalone_eps', 
-        false); 
+        NULL,NULL,'sex',
+        'max_iter=50, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 50, eps_table=abalone_eps',
+        false);
 SELECT * FROM svr_mdl;
 
 
-DROP TABLE IF EXISTS svr_mdl_i, svr_mdl_i_summary; 
+DROP TABLE IF EXISTS svr_mdl_i, svr_mdl_i_summary;
 SELECT madlib.svm_regression(
         'abalone_train_small',
-        'svr_mdl_i', 
+        'svr_mdl_i',
         'rings',
         'ARRAY[1,diameter,shell,shucked,length]',
-        NULL,NULL,'sex', 
-        'max_iter=50, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 0.2', 
-        false); 
+        NULL,NULL,'sex',
+        'max_iter=50, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 0.2',
+        false);
 SELECT * FROM svr_mdl_i where sex = 'I';
 
-DROP TABLE IF EXISTS svr_mdl_m, svr_mdl_m_summary; 
+DROP TABLE IF EXISTS svr_mdl_m, svr_mdl_m_summary;
 SELECT madlib.svm_regression(
         'abalone_train_small',
-        'svr_mdl_m', 
+        'svr_mdl_m',
         'rings',
         'ARRAY[1,diameter,shell,shucked,length]',
-        NULL,NULL,'sex', 
-        'max_iter=50, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 0.05', 
-        false); 
+        NULL,NULL,'sex',
+        'max_iter=50, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 0.05',
+        false);
 SELECT * FROM svr_mdl_m where sex = 'M';
 
-DROP TABLE IF EXISTS svr_mdl_f, svr_mdl_f_summary; 
+DROP TABLE IF EXISTS svr_mdl_f, svr_mdl_f_summary;
 SELECT madlib.svm_regression(
         'abalone_train_small',
-        'svr_mdl_f', 
+        'svr_mdl_f',
         'rings',
         'ARRAY[1,diameter,shell,shucked,length]',
-        NULL,NULL,'sex', 
-        'max_iter=50, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 50', 
-        false); 
+        NULL,NULL,'sex',
+        'max_iter=50, init_stepsize=1, decay_factor=0.9, tolerance=1e-16, epsilon = 50',
+        false);
 SELECT * FROM svr_mdl_f where sex = 'F';
 
--- verify that the results are the same 
+-- verify that the results are the same
 
 SELECT assert(
-    abs_err < 1e-5, 
+    abs_err < 1e-5,
     'SVR with epsilon table input: Wrong results!')
 FROM (
-    SELECT 
+    SELECT
         abs(t1.norm_of_gradient - t2.norm_of_gradient) AS abs_err
     FROM svr_mdl_f AS t1 JOIN svr_mdl AS t2 USING (sex)
     where sex = 'F'
 ) AS q1;
 
 SELECT assert(
-    rel_err < 1e-1, 
+    rel_err < 1e-1,
     'SVR with epsilon table input: Wrong results!')
 FROM (
-    SELECT 
+    SELECT
         relative_error(t1.norm_of_gradient,  t2.norm_of_gradient) AS rel_err
     FROM svr_mdl_i AS t1 JOIN svr_mdl AS t2 USING (sex)
     where sex = 'I'
 ) AS q1;
 
 SELECT assert(
-    rel_err < 1e-1, 
+    rel_err < 1e-1,
     'SVR with epsilon table input: Wrong results!')
 FROM (
-    SELECT 
+    SELECT
         relative_error(t1.norm_of_gradient,  t2.norm_of_gradient) AS rel_err
     FROM svr_mdl_m AS t1 JOIN svr_mdl AS t2 USING (sex)
     where sex = 'M'