You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text conversion.
Posted to commits@madlib.apache.org by nj...@apache.org on 2019/04/24 22:10:58 UTC
[madlib] branch master updated: DL: Validate predict input parameters
This is an automated email from the ASF dual-hosted git repository.
njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push:
new 4e6f337 DL: Validate predict input parameters
4e6f337 is described below
commit 4e6f337cfec64875f0759c2d220a84215276673d
Author: Nandish Jayaram <nj...@apache.org>
AuthorDate: Mon Apr 22 15:20:07 2019 -0700
DL: Validate predict input parameters
JIRA: MADLIB-1321
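Add validation code for predict input parameters (new PredictInputValidator
class). This commit also moves the input-shape check into a module-level
helper, _validate_input_shapes, shared by fit and predict, and updates the
relevant unit tests.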
Closes #374
---
.../modules/deep_learning/madlib_keras.py_in | 4 +-
.../deep_learning/madlib_keras_predict.py_in | 23 +--
.../deep_learning/madlib_keras_validator.py_in | 173 ++++++++++++++++-----
.../test/unit_tests/test_madlib_keras.py_in | 78 ++++++----
4 files changed, 189 insertions(+), 89 deletions(-)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 8d4e384..542f489 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -79,9 +79,7 @@ def fit(schema_madlib, source_table, model, dependent_varname,
model_arch = query_result[Format.MODEL_ARCH]
input_shape = get_input_shape(model_arch)
num_classes = get_num_classes(model_arch)
- fit_validator.validate_input_shapes(source_table, input_shape, 2)
- if validation_table:
- fit_validator.validate_input_shapes(validation_table, input_shape, 1)
+ fit_validator.validate_input_shapes(input_shape)
model_weights_serialized = query_result[Format.MODEL_WEIGHTS]
# Convert model from json and initialize weights
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 95ae2cf..739f042 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -29,6 +29,7 @@ import numpy as np
from madlib_keras_helper import expand_input_dims
from madlib_keras_helper import MODEL_DATA_COLNAME
+from madlib_keras_validator import PredictInputValidator
from madlib_keras_wrapper import compile_and_set_weights
from predict_input_params import PredictParamsProcessor
from utilities.model_arch_info import get_input_shape
@@ -42,15 +43,6 @@ import madlib_keras_serializer
MODULE_NAME = 'madlib_keras_predict'
-def validate_pred_type(pred_type, class_values):
- if not pred_type in ['prob', 'response']:
- plpy.error("{0}: Invalid value for pred_type param ({1}). Must be "\
- "either response or prob.".format(MODULE_NAME, pred_type))
- if pred_type == 'prob' and class_values and len(class_values)+1 >= 1600:
- plpy.error({"{0}: The output will have {1} columns, exceeding the "\
- " max number of columns that can be created (1600)".format(
- MODULE_NAME, len(class_values)+1)})
-
def _strip_trailing_nulls_from_class_values(class_values):
"""
class_values is a list of unique class levels in training data. This
@@ -88,24 +80,23 @@ def _strip_trailing_nulls_from_class_values(class_values):
def predict(schema_madlib, model_table, test_table, id_col,
independent_varname, output_table, pred_type, **kwargs):
- # Refactor and add more validation as part of MADLIB-1312.
- input_tbl_valid(model_table, MODULE_NAME)
- input_tbl_valid(test_table, MODULE_NAME)
- output_tbl_valid(output_table, MODULE_NAME)
- param_proc = PredictParamsProcessor(model_table, MODULE_NAME)
+ input_validator = PredictInputValidator(
+ test_table, model_table, id_col, independent_varname,
+ output_table, pred_type, MODULE_NAME)
+ param_proc = PredictParamsProcessor(model_table, MODULE_NAME)
class_values = param_proc.get_class_values()
+ input_validator.validate_pred_type(class_values)
compile_params = param_proc.get_compile_params()
dependent_varname = param_proc.get_dependent_varname()
dependent_vartype = param_proc.get_dependent_vartype()
model_data = param_proc.get_model_data()
model_arch = param_proc.get_model_arch()
normalizing_const = param_proc.get_normalizing_const()
- # TODO: Validate input shape as part of MADLIB-1312
input_shape = get_input_shape(model_arch)
+ input_validator.validate_input_shape(input_shape)
compile_params = "$madlib$" + compile_params + "$madlib$"
- validate_pred_type(pred_type, class_values)
is_response = True if pred_type == 'response' else False
intermediate_col = unique_string()
if is_response:
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 606461d..cbe8f3c 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -17,7 +17,16 @@
# specific language governing permissions and limitations
# under the License.
+import plpy
from madlib_keras_helper import CLASS_VALUES_COLNAME
+from madlib_keras_helper import COMPILE_PARAMS_COLNAME
+from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
+from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
+from madlib_keras_helper import MODEL_ARCH_ID_COLNAME
+from madlib_keras_helper import MODEL_ARCH_TABLE_COLNAME
+from madlib_keras_helper import MODEL_DATA_COLNAME
+from madlib_keras_helper import NORMALIZING_CONST_COLNAME
+
from utilities.minibatch_validation import validate_dependent_var_for_minibatch
from utilities.utilities import _assert
from utilities.utilities import add_postfix
@@ -25,10 +34,126 @@ from utilities.utilities import is_var_valid
from utilities.utilities import is_valid_psql_type
from utilities.utilities import NUMERIC
from utilities.utilities import ONLY_ARRAY
+from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import get_expr_type
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid
-import plpy
+
+
+def _validate_input_shapes(table, independent_varname, input_shape, offset):
+ """
+ Validate if the input shape specified in model architecture is the same
+ as the shape of the image specified in the indepedent var of the input
+ table.
+ offset: This offset is the index of the start of the image array. We also
+ need to consider that sql array indexes start from 1
+ For ex if the image is of shape [32,32,3] and is minibatched, the image will
+ look like [10, 32, 32, 3]. The offset in this case is 1 (start the index at 1) +
+ 1 (ignore the buffer size 10) = 2.
+ If the image is not batched then it will look like [32, 32 ,3] and the offset in
+ this case is 1 (start the index at 1).
+ """
+ array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format(
+ independent_varname, i+offset, i) for i in range(len(input_shape)))
+ query = """
+ SELECT {0}
+ FROM {1}
+ LIMIT 1
+ """.format(array_upper_query, table)
+ # This query will fail if an image in independent var does not have the
+ # same number of dimensions as the input_shape.
+ result = plpy.execute(query)[0]
+ _assert(len(result) == len(input_shape),
+ "model_keras error: The number of dimensions ({0}) of each image"
+ " in model architecture and {1} in {2} ({3}) do not match.".format(
+ len(input_shape), independent_varname, table, len(result)))
+ for i in range(len(input_shape)):
+ key_name = "n_{0}".format(i)
+ if result[key_name] != input_shape[i]:
+ # Construct the shape in independent varname to display
+ # meaningful error msg.
+ input_shape_from_table = [result["n_{0}".format(i)]
+ for i in range(len(input_shape))]
+ plpy.error("model_keras error: Input shape {0} in the model"
+ " architecture does not match the input shape {1} of column"
+ " {2} in table {3}.".format(
+ input_shape, input_shape_from_table,
+ independent_varname, table))
+
+class PredictInputValidator:
+ def __init__(self, test_table, model_table, id_col, independent_varname,
+ output_table, pred_type, module_name):
+ self.test_table = test_table
+ self.model_table = model_table
+ self.id_col = id_col
+ self.independent_varname = independent_varname
+ self.output_table = output_table
+ self.pred_type = pred_type
+ if self.model_table:
+ self.model_summary_table = add_postfix(
+ self.model_table, "_summary")
+ self.module_name = module_name
+ self._validate_input_args()
+
+ def _validate_input_args(self):
+ input_tbl_valid(self.model_table, self.module_name)
+ self._validate_model_data_col()
+ input_tbl_valid(self.model_summary_table, self.module_name)
+ self._validate_summary_tbl_cols()
+ input_tbl_valid(self.test_table, self.module_name)
+ self._validate_test_tbl_cols()
+ output_tbl_valid(self.output_table, self.module_name)
+
+ def _validate_model_data_col(self):
+ _assert(is_var_valid(self.model_table, MODEL_DATA_COLNAME),
+ "{module_name} error: invalid model_data "
+ "('{model_data}') in model table ({table}).".format(
+ module_name=self.module_name,
+ model_data=MODEL_DATA_COLNAME,
+ table=self.model_table))
+
+ def _validate_test_tbl_cols(self):
+ _assert(is_var_valid(self.test_table, self.independent_varname),
+ "{module_name} error: invalid independent_varname "
+ "('{independent_varname}') for test table "
+ "({table}).".format(
+ module_name=self.module_name,
+ independent_varname=self.independent_varname,
+ table=self.test_table))
+
+ _assert(is_var_valid(self.test_table, self.id_col),
+ "{module_name} error: invalid id column "
+ "('{id_col}') for test table ({table}).".format(
+ module_name=self.module_name,
+ id_col=self.id_col,
+ table=self.test_table))
+
+ def _validate_summary_tbl_cols(self):
+ cols_to_check_for = [CLASS_VALUES_COLNAME,
+ COMPILE_PARAMS_COLNAME,
+ DEPENDENT_VARNAME_COLNAME,
+ DEPENDENT_VARTYPE_COLNAME,
+ MODEL_ARCH_ID_COLNAME,
+ MODEL_ARCH_TABLE_COLNAME,
+ NORMALIZING_CONST_COLNAME]
+ _assert(columns_exist_in_table(
+ self.model_summary_table, cols_to_check_for, cols_to_check_for),
+ "{0} error: One or more expected columns missing in model "
+ "summary table ('{1}'). The expected columns are {2}.".format(
+ self.module_name, self.model_summary_table, cols_to_check_for))
+
+ def validate_pred_type(self, class_values):
+ if not self.pred_type in ['prob', 'response']:
+ plpy.error("{0}: Invalid value for pred_type param ({1}). Must be "\
+ "either response or prob.".format(self.module_name, self.pred_type))
+ if self.pred_type == 'prob' and class_values and len(class_values)+1 >= 1600:
+ plpy.error({"{0}: The output will have {1} columns, exceeding the "\
+ " max number of columns that can be created (1600)".format(
+ self.module_name, len(class_values)+1)})
+
+ def validate_input_shape(self, input_shape_from_arch):
+ _validate_input_shapes(self.test_table, self.independent_varname,
+ input_shape_from_arch, 1)
class FitInputValidator:
def __init__(self, source_table, validation_table, output_model_table,
@@ -105,42 +230,10 @@ class FitInputValidator:
self.dependent_varname, self.validation_table))
- def validate_input_shapes(self, table, input_shape, offset):
- """
- Validate if the input shape specified in model architecture is the same
- as the shape of the image specified in the indepedent var of the input
- table.
- offset: This offset is the index of the start of the image array. We also
- need to consider that sql array indexes start from 1
- For ex if the image is of shape [32,32,3] and is minibatched, the image will
- look like [10, 32, 32, 3]. The offset in this case is 1 (start the index at 1) +
- 1 (ignore the buffer size 10) = 2.
- If the image is not batched then it will look like [32, 32 ,3] and the offset in
- this case is 1 (start the index at 1).
- """
- array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format(
- self.independent_varname, i+offset, i) for i in range(len(input_shape)))
- query = """
- SELECT {0}
- FROM {1}
- LIMIT 1
- """.format(array_upper_query, table)
- # This query will fail if an image in independent var does not have the
- # same number of dimensions as the input_shape.
- result = plpy.execute(query)[0]
- _assert(len(result) == len(input_shape),
- "model_keras error: The number of dimensions ({0}) of each image" \
- " in model architecture and {1} in {2} ({3}) do not match.".format(
- len(input_shape), self.independent_varname, table, len(result)))
- for i in range(len(input_shape)):
- key_name = "n_{0}".format(i)
- if result[key_name] != input_shape[i]:
- # Construct the shape in independent varname to display
- # meaningful error msg.
- input_shape_from_table = [result["n_{0}".format(i)]
- for i in range(len(input_shape))]
- plpy.error("model_keras error: Input shape {0} in the model" \
- " architecture does not match the input shape {1} of column" \
- " {2} in table {3}.".format(
- input_shape, input_shape_from_table,
- self.independent_varname, table))
+ def validate_input_shapes(self, input_shape):
+ _validate_input_shapes(self.source_table, self.independent_varname,
+ input_shape, 2)
+ if self.validation_table:
+ _validate_input_shapes(
+ self.validation_table, self.independent_varname,
+ input_shape, 1)
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index e6b09e4..84e4ce7 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -211,8 +211,8 @@ class MadlibKerasValidatorTestCase(unittest.TestCase):
self.module_patcher = patch.dict('sys.modules', patches)
self.module_patcher.start()
- from madlib_keras_validator import FitInputValidator
- self.subject = FitInputValidator
+ import madlib_keras_validator
+ self.subject = madlib_keras_validator
self.model = Sequential()
self.model.add(Conv2D(2, kernel_size=(1, 1), activation='relu',
@@ -233,35 +233,25 @@ class MadlibKerasValidatorTestCase(unittest.TestCase):
def test_validate_input_shapes_shapes_do_not_match(self):
self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32}]
self.subject._validate_input_args = Mock()
- input_validator_obj = self.subject('foo',
- 'foo_valid',
- 'model',
- 'model_arch_table',
- 'dependent_varname',
- 'independent_varname',
- 1)
with self.assertRaises(plpy.PLPYException):
- input_validator_obj.validate_input_shapes('dummy_tbl', [32,32,3], 2)
+ self.subject._validate_input_shapes(
+ 'dummy_tbl', 'dummy_col', [32,32,3], 2)
self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': 32, 'n_2': 32}]
with self.assertRaises(plpy.PLPYException):
- input_validator_obj.validate_input_shapes('dummy_tbl', [32,32,3], 2)
+ self.subject._validate_input_shapes(
+ 'dummy_tbl', 'dummy_col', [32,32,3], 2)
self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': None, 'n_2': None}]
with self.assertRaises(plpy.PLPYException):
- input_validator_obj.validate_input_shapes('dummy_tbl', [3,32], 2)
+ self.subject._validate_input_shapes(
+ 'dummy_tbl', 'dummy_col', [3,32], 2)
def test_validate_input_shapes_shapes_match(self):
self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32, 'n_2': 3}]
self.subject._validate_input_args = Mock()
- input_validator_obj = self.subject('foo',
- 'foo_valid',
- 'model',
- 'model_arch_table',
- 'dependent_varname',
- 'independent_varname',
- 1)
- input_validator_obj.validate_input_shapes('dummy_tbl', [32,32,3], 1)
+ self.subject._validate_input_shapes(
+ 'dummy_tbl', 'dummy_col', [32,32,3], 1)
class MadlibSerializerTestCase(unittest.TestCase):
def setUp(self):
@@ -369,7 +359,7 @@ class MadlibSerializerTestCase(unittest.TestCase):
self.assertEqual(np.array([0,1,2,1,3,4,5], dtype=np.float32).tostring(),
res)
-class MadlibKerasPredictTestCase(unittest.TestCase):
+class PredictInputPredTypeValidationTestCase(unittest.TestCase):
def setUp(self):
self.plpy_mock = Mock(spec='error')
patches = {
@@ -381,26 +371,54 @@ class MadlibKerasPredictTestCase(unittest.TestCase):
self.module_patcher = patch.dict('sys.modules', patches)
self.module_patcher.start()
- import madlib_keras_predict
- self.subject = madlib_keras_predict
+ import madlib_keras_validator
+ self.module = madlib_keras_validator
+ self.module.PredictInputValidator._validate_input_args = Mock()
+ self.subject = self.module.PredictInputValidator(
+ 'test_table', 'model_table', 'id_col', 'independent_varname',
+ 'output_table', 'pred_type', 'module_name')
self.classes = ['train', 'boat', 'car', 'airplane']
def tearDown(self):
self.module_patcher.stop()
def test_validate_pred_type_invalid_pred_type(self):
+ self.subject.pred_type = 'invalid'
with self.assertRaises(plpy.PLPYException):
- self.subject.validate_pred_type('invalid', ['cat', 'dog'])
+ self.subject.validate_pred_type(['cat', 'dog'])
def test_validate_pred_type_valid_pred_type_invalid_num_class_values(self):
+ self.subject.pred_type = 'prob'
with self.assertRaises(plpy.PLPYException):
- self.subject.validate_pred_type('prob', range(1599))
+ self.subject.validate_pred_type(range(1599))
+
+ def test_validate_pred_type_valid_pred_type_valid_class_values_prob(self):
+ self.subject.pred_type = 'prob'
+ self.subject.validate_pred_type(range(1598))
+ self.subject.validate_pred_type(None)
+
+ def test_validate_pred_type_valid_pred_type_valid_class_values_response(self):
+ self.subject.pred_type = 'response'
+ self.subject.validate_pred_type(range(1598))
+ self.subject.validate_pred_type(None)
+
+class MadlibKerasPredictTestCase(unittest.TestCase):
+ def setUp(self):
+ self.plpy_mock = Mock(spec='error')
+ patches = {
+ 'plpy': plpy
+ }
+
+ self.plpy_mock_execute = MagicMock()
+ plpy.execute = self.plpy_mock_execute
- def test_validate_pred_type_valid_pred_type_valid_class_values(self):
- self.subject.validate_pred_type('prob', range(1598))
- self.subject.validate_pred_type('prob', None)
- self.subject.validate_pred_type('response', range(1598))
- self.subject.validate_pred_type('response', None)
+ self.module_patcher = patch.dict('sys.modules', patches)
+ self.module_patcher.start()
+ import madlib_keras_predict
+ self.subject = madlib_keras_predict
+
+ def tearDown(self):
+ self.module_patcher.stop()
def test_strip_trailing_nulls_from_class_values(self):
self.assertEqual(['cat', 'dog'],