You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by nj...@apache.org on 2019/04/29 19:33:09 UTC
[madlib] branch master updated (6cae627 -> 27ddd27)
This is an automated email from the ASF dual-hosted git repository.
njayaram pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git.
from 6cae627 DL: Trap unsupported options for compile and fit params
new 3e2869d DL: Do not compile params in predict
new 920175f DL: Add GPU support for predict
new 27ddd27 DL: Handle NULL value for optional pred_type param in predict
The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails. The revisions
listed as "add" were already present in the repository and have only
been added to this reference.
Summary of changes:
.../modules/deep_learning/madlib_keras.sql_in | 38 +++++++----
.../deep_learning/madlib_keras_predict.py_in | 26 ++++----
.../deep_learning/madlib_keras_validator.py_in | 1 -
.../deep_learning/madlib_keras_wrapper.py_in | 12 +++-
.../deep_learning/predict_input_params.py_in | 4 --
.../modules/deep_learning/test/madlib_keras.sql_in | 74 ++++++++++++++++++----
6 files changed, 115 insertions(+), 40 deletions(-)
[madlib] 02/03: DL: Add GPU support for predict
Posted by nj...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 920175f4825694628db27a66bcf2a384c2d53322
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Thu Apr 25 12:13:59 2019 -0700
DL: Add GPU support for predict
JIRA: MADLIB-1330
Add an optional parameter named use_gpu to the predict interface, which
defaults to TRUE. When use_gpu is TRUE, predict runs with GPU support
enabled.
Closes #377
---
.../modules/deep_learning/madlib_keras.sql_in | 36 ++++++++---
.../deep_learning/madlib_keras_predict.py_in | 16 +++--
.../deep_learning/madlib_keras_wrapper.py_in | 3 +-
.../modules/deep_learning/test/madlib_keras.sql_in | 74 ++++++++++++++++++----
4 files changed, 102 insertions(+), 27 deletions(-)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 37b1068..5f53488 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -173,7 +173,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
id_col VARCHAR,
independent_varname VARCHAR,
output_table VARCHAR,
- pred_type VARCHAR
+ pred_type VARCHAR,
+ use_gpu BOOLEAN
) RETURNS VOID AS $$
PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
with AOControl(False):
@@ -183,7 +184,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
id_col,
independent_varname,
output_table,
- pred_type)
+ pred_type,
+ use_gpu)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
@@ -192,19 +194,33 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
test_table VARCHAR,
id_col VARCHAR,
independent_varname VARCHAR,
+ output_table VARCHAR,
+ pred_type VARCHAR
+) RETURNS VOID AS $$
+ SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, $6, TRUE);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
+ model_table VARCHAR,
+ test_table VARCHAR,
+ id_col VARCHAR,
+ independent_varname VARCHAR,
output_table VARCHAR
) RETURNS VOID AS $$
- SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, 'response');
+ SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, 'response', TRUE);
$$ LANGUAGE sql VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_predict(
- independent_var double precision [],
+ independent_var DOUBLE PRECISION [],
model_architecture TEXT,
- model_data bytea,
- input_shape integer[],
- is_response BOOLEAN,
- normalizing_const DOUBLE PRECISION
+ model_data BYTEA,
+ input_shape INTEGER[],
+ is_response BOOLEAN,
+ normalizing_const DOUBLE PRECISION,
+ use_gpu BOOLEAN,
+ seg INTEGER
) RETURNS DOUBLE PRECISION[] AS $$
PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
with AOControl(False):
@@ -214,7 +230,9 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_predict(
model_data,
input_shape,
is_response,
- normalizing_const)
+ normalizing_const,
+ use_gpu,
+ seg)
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index e726f57..d47f53a 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -30,11 +30,13 @@ import numpy as np
from madlib_keras_helper import expand_input_dims
from madlib_keras_helper import MODEL_DATA_COLNAME
from madlib_keras_validator import PredictInputValidator
+from madlib_keras_wrapper import get_device_name_and_set_cuda_env
from madlib_keras_wrapper import set_model_weights
from predict_input_params import PredictParamsProcessor
from utilities.model_arch_info import get_input_shape
from utilities.utilities import add_postfix
from utilities.utilities import create_cols_from_array_sql_string
+from utilities.utilities import is_platform_pg
from utilities.utilities import unique_string
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
@@ -79,7 +81,7 @@ def _strip_trailing_nulls_from_class_values(class_values):
return class_values
def predict(schema_madlib, model_table, test_table, id_col,
- independent_varname, output_table, pred_type, **kwargs):
+ independent_varname, output_table, pred_type, use_gpu, **kwargs):
input_validator = PredictInputValidator(
test_table, model_table, id_col, independent_varname,
output_table, pred_type, MODULE_NAME)
@@ -110,6 +112,9 @@ def predict(schema_madlib, model_table, test_table, id_col,
class_values, intermediate_col, pred_col_name,
pred_col_type, is_response, MODULE_NAME)
+ segment_id = -1 if is_platform_pg() else '{0}.gp_segment_id'.format(
+ test_table)
+
plpy.execute("""
CREATE TABLE {output_table} AS
SELECT {id_col}, {prediction_select_clause}
@@ -121,17 +126,18 @@ def predict(schema_madlib, model_table, test_table, id_col,
{0},
ARRAY{input_shape},
{is_response},
- {normalizing_const})
+ {normalizing_const},
+ {use_gpu},
+ {segment_id})
) AS {intermediate_col}
FROM {test_table}, {model_table}
) q
""".format(MODEL_DATA_COLNAME, **locals()))
def internal_keras_predict(x_test, model_arch, model_data, input_shape,
- is_response, normalizing_const):
+ is_response, normalizing_const, use_gpu, seg):
model = model_from_json(model_arch)
- device_name = '/cpu:0'
- os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
+ device_name = get_device_name_and_set_cuda_env(use_gpu, seg)
model_shapes = madlib_keras_serializer.get_model_shapes(model)
set_model_weights(model, device_name, model_data, model_shapes)
# Since the test data isn't mini-batched,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index e0fd8f7..6149411 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -33,6 +33,7 @@ import keras.losses as losses
import madlib_keras_serializer
from utilities.utilities import _assert
+from utilities.utilities import is_platform_pg
#######################################################################
########### Keras specific functions #####
@@ -41,7 +42,7 @@ def get_device_name_and_set_cuda_env(use_gpu, seg):
gpus_per_host = 4
if use_gpu:
device_name = '/gpu:0'
- if seg == -1:
+ if is_platform_pg():
cuda_visible_dev = ','.join([i for i in range(gpus_per_host)])
else:
cuda_visible_dev = str(seg % gpus_per_host)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index 3d1f8d7..81a088e 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -144,6 +144,42 @@ FROM (SELECT * FROM keras_saved_out_summary) summary;
SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') FROM (SELECT * FROM keras_saved_out) k;
+-- Fit with use_gpu set to TRUE must error out on machines
+-- that don't have GPUs. Since Jenkins builds are run on docker containers
+-- that don't have GPUs, these queries must error out.
+DROP TABLE IF EXISTS keras_saved_out_gpu, keras_saved_out_gpu_summary;
+SELECT assert(trap_error($TRAP$madlib_keras_fit(
+ 'cifar_10_sample_batched',
+ 'keras_saved_out_gpu',
+ 'dependent_var',
+ 'independent_var',
+ 'model_arch',
+ 1,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+ $$ batch_size=2, epochs=1, verbose=0 $$::text,
+ 3,
+ TRUE,
+ 'cifar_10_sample_val');$TRAP$) = 1,
+ 'Fit with use_gpu=True must error out.');
+
+-- Prediction with use_gpu set to TRUE must error out on machines
+-- that don't have GPUs. Since Jenkins builds are run on docker containers
+-- that don't have GPUs, these queries must error out.
+
+-- IMPORTANT: The following test must be run when we have a valid
+-- keras_saved_out model table. Otherwise, it will fail because of a
+-- non-existent model table, while we want to trap failure due to
+-- use_gpu=TRUE
+DROP TABLE IF EXISTS cifar10_predict_gpu;
+SELECT assert(trap_error($TRAP$madlib_keras_predict(
+ 'keras_saved_out',
+ 'cifar_10_sample',
+ 'id',
+ 'x',
+ 'cifar10_predict_gpu',
+ NULL,
+ TRUE);$TRAP$) = 1,
+ 'Prediction with use_gpu=TRUE must error out.');
-- Test for
-- Non null name and description columns
@@ -203,7 +239,9 @@ SELECT madlib_keras_predict(
'cifar_10_sample',
'id',
'x',
- 'cifar10_predict');
+ 'cifar10_predict',
+ NULL,
+ FALSE);
-- Validate that prediction output table exists and has correct schema
SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'id column should be INTEGER type')
@@ -223,12 +261,15 @@ SELECT assert(estimated_dependent_var IN (0,1),
'Predicted value not in set of defined class values for model')
FROM cifar10_predict;
-select assert(trap_error($TRAP$madlib_keras_predict(
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT assert(trap_error($TRAP$madlib_keras_predict(
'keras_saved_out',
'cifar_10_sample_batched',
'id',
'x',
- 'cifar10_predict');$TRAP$) = 1,
+ 'cifar10_predict',
+ NULL,
+ FALSE);$TRAP$) = 1,
'Passing batched image table to predict should error out.');
-- Compile and fit parameter tests
@@ -321,7 +362,8 @@ SELECT madlib_keras_predict(
'id',
'x',
'cifar10_predict',
- 'prob');
+ 'prob',
+ FALSE);
SELECT assert(UPPER(atttypid::regtype::TEXT) =
'DOUBLE PRECISION', 'column prob_0 should be double precision type')
@@ -405,7 +447,8 @@ SELECT madlib_keras_predict(
'id',
'x',
'cifar10_predict',
- 'prob');
+ 'prob',
+ FALSE);
-- Validate the output datatype of newly created prediction columns
-- for prediction type = 'prob' and class_values 'TEXT' with NULL as a valid
@@ -438,7 +481,8 @@ SELECT madlib_keras_predict(
'id',
'x',
'cifar10_predict',
- 'response');
+ 'response',
+ FALSE);
-- Validate the output datatype of newly created prediction columns
-- for prediction type = 'response' and class_values 'TEXT' with NULL
@@ -461,7 +505,8 @@ SELECT madlib_keras_predict(
'id',
'x',
'cifar10_predict',
- 'prob');
+ 'prob',
+ FALSE);
-- Validate the output datatype of newly created prediction column
-- for prediction type = 'response' and class_value = NULL
@@ -479,7 +524,8 @@ SELECT madlib_keras_predict(
'id',
'x',
'cifar10_predict',
- 'response');
+ 'response',
+ FALSE);
-- Validate the output datatype of newly created prediction column
-- for prediction type = 'response' and class_value = NULL
@@ -535,7 +581,8 @@ SELECT madlib_keras_predict(
'id',
'x',
'cifar10_predict',
- 'prob');
+ 'prob',
+ FALSE);
-- Validate the output datatype of newly created prediction column
-- for prediction type = 'prob' and class_values 'INT' with NULL
@@ -558,7 +605,8 @@ SELECT madlib_keras_predict(
'id',
'x',
'cifar10_predict',
- 'response');
+ 'response',
+ FALSE);
-- Validate the output datatype of newly created prediction column
-- for prediction type = 'response' and class_values 'TEXT' with NULL
@@ -625,7 +673,8 @@ SELECT madlib_keras_predict(
'id',
'x',
'cifar10_predict',
- 'prob');
+ 'prob',
+ FALSE);
-- Prediction with incorrectly shaped data must error out.
DROP TABLE IF EXISTS cifar10_predict;
@@ -635,5 +684,6 @@ SELECT assert(trap_error($TRAP$madlib_keras_predict(
'id',
'x',
'cifar10_predict',
- 'prob');$TRAP$) = 1,
+ 'prob',
+ FALSE);$TRAP$) = 1,
'Input shape is (32, 32, 3) but model was trained with (3, 32, 32). Should have failed.');
[madlib] 03/03: DL: Handle NULL value for optional pred_type param in predict
Posted by nj...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 27ddd279dc6f3ac4e59a3dc205716f177b5479a5
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Thu Apr 25 14:23:08 2019 -0700
DL: Handle NULL value for optional pred_type param in predict
The pred_type param in predict is optional, so callers may pass NULL
for it. This commit treats NULL as the default value 'response'
instead of erroring out.
---
src/ports/postgres/modules/deep_learning/madlib_keras.sql_in | 2 +-
src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in | 2 ++
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 5f53488..543bbed 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -208,7 +208,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
independent_varname VARCHAR,
output_table VARCHAR
) RETURNS VOID AS $$
- SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, 'response', TRUE);
+ SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, NULL, TRUE);
$$ LANGUAGE sql VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index d47f53a..4e2a206 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -82,6 +82,8 @@ def _strip_trailing_nulls_from_class_values(class_values):
def predict(schema_madlib, model_table, test_table, id_col,
independent_varname, output_table, pred_type, use_gpu, **kwargs):
+ if not pred_type:
+ pred_type = 'response'
input_validator = PredictInputValidator(
test_table, model_table, id_col, independent_varname,
output_table, pred_type, MODULE_NAME)
[madlib] 01/03: DL: Do not compile params in predict
Posted by nj...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 3e2869db20ac24d4d1353e1a4a9d9d5d756e4682
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Thu Apr 25 11:24:36 2019 -0700
DL: Do not compile params in predict
JIRA: MADLIB-1330
Do not compile params in predict; instead, directly load the model
weights and architecture and use them for prediction. Compiling params
before predicting is not necessary as of Keras 1.0.3.
Closes #377
---
src/ports/postgres/modules/deep_learning/madlib_keras.sql_in | 2 --
.../postgres/modules/deep_learning/madlib_keras_predict.py_in | 10 +++-------
.../modules/deep_learning/madlib_keras_validator.py_in | 1 -
.../postgres/modules/deep_learning/madlib_keras_wrapper.py_in | 9 +++++++++
.../postgres/modules/deep_learning/predict_input_params.py_in | 4 ----
5 files changed, 12 insertions(+), 14 deletions(-)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 8e13933..37b1068 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -203,7 +203,6 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_predict(
model_architecture TEXT,
model_data bytea,
input_shape integer[],
- compile_params TEXT,
is_response BOOLEAN,
normalizing_const DOUBLE PRECISION
) RETURNS DOUBLE PRECISION[] AS $$
@@ -214,7 +213,6 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_predict(
model_architecture,
model_data,
input_shape,
- compile_params,
is_response,
normalizing_const)
$$ LANGUAGE plpythonu VOLATILE
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 739f042..e726f57 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -30,7 +30,7 @@ import numpy as np
from madlib_keras_helper import expand_input_dims
from madlib_keras_helper import MODEL_DATA_COLNAME
from madlib_keras_validator import PredictInputValidator
-from madlib_keras_wrapper import compile_and_set_weights
+from madlib_keras_wrapper import set_model_weights
from predict_input_params import PredictParamsProcessor
from utilities.model_arch_info import get_input_shape
from utilities.utilities import add_postfix
@@ -87,7 +87,6 @@ def predict(schema_madlib, model_table, test_table, id_col,
param_proc = PredictParamsProcessor(model_table, MODULE_NAME)
class_values = param_proc.get_class_values()
input_validator.validate_pred_type(class_values)
- compile_params = param_proc.get_compile_params()
dependent_varname = param_proc.get_dependent_varname()
dependent_vartype = param_proc.get_dependent_vartype()
model_data = param_proc.get_model_data()
@@ -95,7 +94,6 @@ def predict(schema_madlib, model_table, test_table, id_col,
normalizing_const = param_proc.get_normalizing_const()
input_shape = get_input_shape(model_arch)
input_validator.validate_input_shape(input_shape)
- compile_params = "$madlib$" + compile_params + "$madlib$"
is_response = True if pred_type == 'response' else False
intermediate_col = unique_string()
@@ -122,7 +120,6 @@ def predict(schema_madlib, model_table, test_table, id_col,
$MAD${model_arch}$MAD$,
{0},
ARRAY{input_shape},
- {compile_params},
{is_response},
{normalizing_const})
) AS {intermediate_col}
@@ -131,13 +128,12 @@ def predict(schema_madlib, model_table, test_table, id_col,
""".format(MODEL_DATA_COLNAME, **locals()))
def internal_keras_predict(x_test, model_arch, model_data, input_shape,
- compile_params, is_response, normalizing_const):
+ is_response, normalizing_const):
model = model_from_json(model_arch)
device_name = '/cpu:0'
os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
model_shapes = madlib_keras_serializer.get_model_shapes(model)
- compile_and_set_weights(model, compile_params, device_name,
- model_data, model_shapes)
+ set_model_weights(model, device_name, model_data, model_shapes)
# Since the test data isn't mini-batched,
# we have to make sure that the test data np array has the same
# number of dimensions as input_shape. So we add a dimension to x.
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index cbe8f3c..ee667d0 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -130,7 +130,6 @@ class PredictInputValidator:
def _validate_summary_tbl_cols(self):
cols_to_check_for = [CLASS_VALUES_COLNAME,
- COMPILE_PARAMS_COLNAME,
DEPENDENT_VARNAME_COLNAME,
DEPENDENT_VARTYPE_COLNAME,
MODEL_ARCH_ID_COLNAME,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 71c257f..e0fd8f7 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -73,6 +73,15 @@ def compile_and_set_weights(segment_model, compile_params, device_name,
previous_state, model_shapes)
segment_model.set_weights(model_weights)
+# TODO: This can be refactored to be part of compile_and_set_weights(),
+# by making compile_params an optional param in that function. Doing that
+# now might create more merge conflicts with other JIRAs, so get to this later.
+def set_model_weights(segment_model, device_name, state, model_shapes):
+ with K.tf.device(device_name):
+ _, _, _, model_weights = madlib_keras_serializer.deserialize_weights(
+ state, model_shapes)
+ segment_model.set_weights(model_weights)
+
"""
Used to convert compile_params and fit_params to actual argument dictionaries
"""
diff --git a/src/ports/postgres/modules/deep_learning/predict_input_params.py_in b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
index 69ee961..aba6dce 100644
--- a/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
+++ b/src/ports/postgres/modules/deep_learning/predict_input_params.py_in
@@ -23,7 +23,6 @@ from utilities.utilities import add_postfix
from utilities.validate_args import input_tbl_valid
from madlib_keras_helper import CLASS_VALUES_COLNAME
-from madlib_keras_helper import COMPILE_PARAMS_COLNAME
from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
from madlib_keras_helper import MODEL_ARCH_ID_COLNAME
@@ -63,9 +62,6 @@ class PredictParamsProcessor:
def get_class_values(self):
return self.model_summary_dict[CLASS_VALUES_COLNAME]
- def get_compile_params(self):
- return self.model_summary_dict[COMPILE_PARAMS_COLNAME]
-
def get_dependent_varname(self):
return self.model_summary_dict[DEPENDENT_VARNAME_COLNAME]