You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nj...@apache.org on 2018/04/13 00:22:54 UTC
madlib git commit: MLP: Check for 1-hot encoding of dependent variable for minibatch
Repository: madlib
Updated Branches:
refs/heads/master feeb8a53a -> 0f78d5a27
MLP: Check for 1-hot encoding of dependent variable for minibatch
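This commit adds a check to make sure that the dependent variable for MLP
minibatch is one-hot encoded. The check only validates that the upper bound
of the dependent variable array's second dimension is greater than 1 (i.e.
that the array has more than one value), and it inspects a single row of the
source table; it does not verify the encoded values themselves.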
Closed #261
Co-authored-by: Orhan Kislal <ok...@pivotal.io>
Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/0f78d5a2
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/0f78d5a2
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/0f78d5a2
Branch: refs/heads/master
Commit: 0f78d5a2751218c3ef96a4bd8441d0f7f081392f
Parents: feeb8a5
Author: Nikhil Kak <nk...@pivotal.io>
Authored: Tue Apr 10 16:40:49 2018 -0700
Committer: Nandish Jayaram <nj...@apache.org>
Committed: Thu Apr 12 17:20:03 2018 -0700
----------------------------------------------------------------------
src/ports/postgres/modules/convex/mlp_igd.py_in | 3 ++
.../utilities/minibatch_validation.py_in | 29 ++++++++++++++++++++
2 files changed, 32 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/madlib/blob/0f78d5a2/src/ports/postgres/modules/convex/mlp_igd.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/convex/mlp_igd.py_in b/src/ports/postgres/modules/convex/mlp_igd.py_in
index 2799355..6ff3d86 100644
--- a/src/ports/postgres/modules/convex/mlp_igd.py_in
+++ b/src/ports/postgres/modules/convex/mlp_igd.py_in
@@ -52,6 +52,7 @@ from utilities.validate_args import input_tbl_valid
from utilities.validate_args import is_var_valid
from utilities.validate_args import output_tbl_valid
from utilities.validate_args import table_exists
+from utilities.minibatch_validation import is_var_one_hot_encoded_for_minibatch
def mlp(schema_madlib, source_table, output_table, independent_varname,
dependent_varname, hidden_layer_sizes, optimizer_param_str, activation,
@@ -681,6 +682,8 @@ def _validate_dependent_var(source_table, dependent_varname,
# strip out '[]' from expr_type
_assert(is_psql_numeric_type(expr_type[:-2]),
"Dependent variable column should be of numeric type.")
+ if is_classification:
+ is_var_one_hot_encoded_for_minibatch(source_table,dependent_varname)
else:
if is_classification:
_assert(("[]" in expr_type \
http://git-wip-us.apache.org/repos/asf/madlib/blob/0f78d5a2/src/ports/postgres/modules/utilities/minibatch_validation.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/minibatch_validation.py_in b/src/ports/postgres/modules/utilities/minibatch_validation.py_in
new file mode 100644
index 0000000..16b11a9
--- /dev/null
+++ b/src/ports/postgres/modules/utilities/minibatch_validation.py_in
@@ -0,0 +1,29 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+
+def is_var_one_hot_encoded_for_minibatch(table_name, var_name):
+ query = """SELECT array_upper({var_name}, 2) > 1 AS is_encoded FROM
+ {table_name} LIMIT 1;""".format(**locals())
+ result = plpy.execute(query)
+ if not result[0]["is_encoded"]:
+ plpy.error("MiniBatch expects the variable {0} to be one hot encoded."
+ " You might need to re run the minibatch_preprocessor function"
+ " and make sure that the variable is encoded".format(var_name))