You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nj...@apache.org on 2019/04/09 23:32:29 UTC
[madlib] branch master updated: DL: Simplify madlib_keras_predict interface
This is an automated email from the ASF dual-hosted git repository.
njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push:
new 57b6e50 DL: Simplify madlib_keras_predict interface
57b6e50 is described below
commit 57b6e50a76ddb3e43ee81006956ef6e7660bec5b
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Tue Apr 9 14:00:40 2019 -0700
DL: Simplify madlib_keras_predict interface
JIRA: MADLIB-1316
This commit removes three unnecessary parameters (model_arch_table,
model_arch_id and compile_params) from the deep learning predict
function. These values are now read from the model summary table
(<model_table>_summary) instead.
Note that this PR does only very basic input validation for the predict
function. A separate JIRA (https://issues.apache.org/jira/browse/MADLIB-1321)
tracks adding fuller validation and other related refactoring.
Closes #366
---
.../modules/deep_learning/madlib_keras.sql_in | 6 ---
.../deep_learning/madlib_keras_predict.py_in | 44 +++++++++++++---------
.../modules/deep_learning/test/madlib_keras.sql_in | 6 ---
3 files changed, 26 insertions(+), 30 deletions(-)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index e4a8534..9bc462f 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -171,10 +171,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
model_table VARCHAR,
test_table VARCHAR,
id_col VARCHAR,
- model_arch_table VARCHAR,
- model_arch_id INTEGER,
independent_varname VARCHAR,
- compile_params VARCHAR,
output_table VARCHAR
) RETURNS VOID AS $$
PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
@@ -183,10 +180,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
model_table,
test_table,
id_col,
- model_arch_table,
- model_arch_id,
independent_varname,
- compile_params,
output_table)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index bf14d1e..85cd6c3 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -37,24 +37,33 @@ from madlib_keras_wrapper import convert_string_of_args_to_dict
from madlib_keras_helper import CLASS_VALUES_COLNAME
from madlib_keras_helper import KerasWeightsSerializer
-def predict(schema_madlib, model_table, test_table, id_col, model_arch_table,
- model_arch_id, independent_varname, compile_params, output_table,
- **kwargs):
- module_name = 'madlib_keras_predict'
- input_tbl_valid(test_table, module_name)
- input_tbl_valid(model_arch_table, module_name)
- output_tbl_valid(output_table, module_name)
-
- # _validate_input_args(test_table, model_arch_table, output_table)
+MODULE_NAME = 'madlib_keras_predict'
+def predict(schema_madlib, model_table, test_table, id_col,
+ independent_varname, output_table, **kwargs):
+ input_tbl_valid(model_table, MODULE_NAME)
+ model_summary_table = add_postfix(model_table, '_summary')
+ input_tbl_valid(model_summary_table, MODULE_NAME)
+ input_tbl_valid(test_table, MODULE_NAME)
+ output_tbl_valid(output_table, MODULE_NAME)
+ model_summary_dict = plpy.execute("SELECT * FROM {0}".format(
+ model_summary_table))[0]
+ model_arch_table = model_summary_dict['model_arch_table']
+ model_arch_id = model_summary_dict['model_arch_id']
+ compile_params = model_summary_dict['compile_params']
+ input_tbl_valid(model_arch_table, MODULE_NAME)
model_data_query = "SELECT model_data from {0}".format(model_table)
model_data = plpy.execute(model_data_query)[0]['model_data']
- model_arch_query = "SELECT model_arch, model_weights FROM {0} " \
- "WHERE id = {1}".format(model_arch_table, model_arch_id)
+ model_arch_query = """
+ SELECT model_arch, model_weights
+ FROM {0}
+ WHERE id = {1}
+ """.format(model_arch_table, model_arch_id)
query_result = plpy.execute(model_arch_query)
if not query_result or len(query_result) == 0:
- plpy.error("no model arch found in table {0} with id {1}".format(model_arch_table, model_arch_id))
+ plpy.error("{0}: No model arch found in table {1} with id {2}".format(
+ MODULE_NAME, model_arch_table, model_arch_id))
query_result = query_result[0]
model_arch = query_result['model_arch']
input_shape = get_input_shape(model_arch)
@@ -112,12 +121,11 @@ def _get_class_label(class_values, class_index):
if not class_values:
return class_index
elif class_index != int(class_index):
- plpy.error("Invalid class index {0} returned from Keras predict. "\
- "Index value must be an integer".format(
- class_index))
+ plpy.error("{0}: Invalid class index {1} returned from Keras predict."\
+ " Index value must be an integer".format(MODULE_NAME, class_index))
elif class_index < 0 or class_index >= len(class_values):
- plpy.error("Invalid class index {0} returned from Keras predict. "\
- "Index value must be less than {1}".format(
- class_index, len(class_values)))
+ plpy.error("{0}: Invalid class index {1} returned from Keras predict."\
+ " Index value must be less than {2}".format(
+ MODULE_NAME, class_index, len(class_values)))
else:
return class_values[class_index]
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index a1743c9..ca6f9b6 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -172,10 +172,7 @@ SELECT madlib_keras_predict(
'keras_saved_out',
'cifar_10_sample',
'id',
- 'model_arch',
- 1,
'x',
- $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
'cifar10_predict');
-- Validate that prediction output table exists and has correct schema
@@ -202,9 +199,6 @@ select assert(trap_error($TRAP$madlib_keras_predict(
'keras_saved_out',
'cifar_10_sample_batched',
'id',
- 'model_arch',
- 1,
'x',
- $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
'cifar10_predict');$TRAP$) = 1,
'Passing batched image table to predict should error out.');