You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nj...@apache.org on 2018/04/13 00:22:54 UTC

madlib git commit: MLP: Check for 1-hot encoding of dependent variable for minibatch

Repository: madlib
Updated Branches:
  refs/heads/master feeb8a53a -> 0f78d5a27


MLP: Check for 1-hot encoding of dependent variable for minibatch

This commit adds a check to make sure that the dependent variable for MLP
minibatch is one-hot encoded. The check is only a heuristic: it validates
that the dependent variable array has more than one column (i.e., the upper
bound of its second dimension is greater than 1).

Closed #261

Co-authored-by: Orhan Kislal <ok...@pivotal.io>


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/0f78d5a2
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/0f78d5a2
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/0f78d5a2

Branch: refs/heads/master
Commit: 0f78d5a2751218c3ef96a4bd8441d0f7f081392f
Parents: feeb8a5
Author: Nikhil Kak <nk...@pivotal.io>
Authored: Tue Apr 10 16:40:49 2018 -0700
Committer: Nandish Jayaram <nj...@apache.org>
Committed: Thu Apr 12 17:20:03 2018 -0700

----------------------------------------------------------------------
 src/ports/postgres/modules/convex/mlp_igd.py_in |  3 ++
 .../utilities/minibatch_validation.py_in        | 29 ++++++++++++++++++++
 2 files changed, 32 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/0f78d5a2/src/ports/postgres/modules/convex/mlp_igd.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/convex/mlp_igd.py_in b/src/ports/postgres/modules/convex/mlp_igd.py_in
index 2799355..6ff3d86 100644
--- a/src/ports/postgres/modules/convex/mlp_igd.py_in
+++ b/src/ports/postgres/modules/convex/mlp_igd.py_in
@@ -52,6 +52,7 @@ from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import is_var_valid
 from utilities.validate_args import output_tbl_valid
 from utilities.validate_args import table_exists
+from utilities.minibatch_validation import is_var_one_hot_encoded_for_minibatch
 
 def mlp(schema_madlib, source_table, output_table, independent_varname,
         dependent_varname, hidden_layer_sizes, optimizer_param_str, activation,
@@ -681,6 +682,8 @@ def _validate_dependent_var(source_table, dependent_varname,
         # strip out '[]' from expr_type
         _assert(is_psql_numeric_type(expr_type[:-2]),
                 "Dependent variable column should be of numeric type.")
+        if is_classification:
+            is_var_one_hot_encoded_for_minibatch(source_table,dependent_varname)
     else:
         if is_classification:
             _assert(("[]" in expr_type \

http://git-wip-us.apache.org/repos/asf/madlib/blob/0f78d5a2/src/ports/postgres/modules/utilities/minibatch_validation.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/minibatch_validation.py_in b/src/ports/postgres/modules/utilities/minibatch_validation.py_in
new file mode 100644
index 0000000..16b11a9
--- /dev/null
+++ b/src/ports/postgres/modules/utilities/minibatch_validation.py_in
@@ -0,0 +1,51 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+
+
+def is_var_one_hot_encoded_for_minibatch(table_name, var_name):
+    """
+    Raise an error unless var_name appears to be one-hot encoded.
+
+    Only one row of table_name is inspected (LIMIT 1, no ORDER BY). The
+    check passes when the value of var_name in that row is an array whose
+    second dimension has an upper bound greater than 1. A 1-D array or a
+    NULL value makes array_upper(..., 2) return NULL, which also fails.
+
+    Args:
+        @param table_name: Name of the table to inspect.
+        @param var_name: Column (or expression) holding the variable.
+
+    Returns:
+        None. Calls plpy.error() when the table is empty or the check fails.
+    """
+    query = """SELECT array_upper({var_name}, 2) > 1 AS is_encoded FROM
+              {table_name} LIMIT 1;""".format(var_name=var_name,
+                                              table_name=table_name)
+    result = plpy.execute(query)
+    # Guard against an empty table: indexing result[0] would otherwise
+    # surface as a bare Python IndexError instead of a readable error.
+    if not result:
+        plpy.error("MiniBatch expects the table {0} to be non-empty.".format(
+            table_name))
+    if not result[0]["is_encoded"]:
+        plpy.error("MiniBatch expects the variable {0} to be one hot encoded."
+                   " You might need to re run the minibatch_preprocessor function"
+                   " and make sure that the variable is encoded".format(var_name))