You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2020/07/09 22:09:59 UTC

[GitHub] [madlib] orhankislal commented on a change in pull request #503: DL: Add support for custom loss functions

orhankislal commented on a change in pull request #503:
URL: https://github.com/apache/madlib/pull/503#discussion_r452502853



##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -331,3 +337,32 @@ def validate_compile_param_types(compile_dict):
             compile_dict['sample_weight_mode'] is None or
             compile_dict['sample_weight_mode'] == "temporal",
             """compile parameter sample_weight_mode can only be "temporal" or None""")
+
+# Returns an object of custom function name and it corresponding object
+def query_custom_functions_map(object_table, custom_fn_names):
+    if len(custom_fn_names) < 1:

Review comment:
       We should check if `custom_fn_names` is not `None` first, otherwise the `len` operation would fail.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -175,6 +179,17 @@ def get_metrics_from_compile_param(str_of_args):
                         "please refer to the documentation").format(ckey))
     return metrics
 
+def get_loss_from_compile_param(str_of_args):
+    compile_dict = convert_string_of_args_to_dict(str_of_args)
+    loss = None
+    ckey = 'loss'

Review comment:
       We use `'loss'` in a number of different places and this function is literally named for loss. I don't think we need the `ckey` variable. If we want to generalize this function as in `get_key_from_compile_param(str_args,key)` that might work as well.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -331,3 +337,32 @@ def validate_compile_param_types(compile_dict):
             compile_dict['sample_weight_mode'] is None or
             compile_dict['sample_weight_mode'] == "temporal",
             """compile parameter sample_weight_mode can only be "temporal" or None""")
+
+# Returns an object of custom function name and it corresponding object
+def query_custom_functions_map(object_table, custom_fn_names):
+    if len(custom_fn_names) < 1:
+        return None
+    custom_obj_col_name = '{0}'.format(CustomFunctionSchema.FN_OBJ)
+    # Dictionary map of name:object
+    # {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
+    custom_fn_map = defaultdict(list)
+    # Query the custom function if not yet loaded from table
+    res = plpy.execute("SELECT {custom_fn_col_name}, {custom_obj_col_name} FROM {object_table} " \

Review comment:
       Formatting: we should wrap this at around 80 characters, using a triple-quoted string for the query.

##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
##########
@@ -137,3 +137,82 @@ SELECT assert(loss >= 0 AND
         metric >= 0 AND
         metrics_type = '{accuracy}', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;
+
+-- TEST custom loss function
+-- Custom loss function returns 0 as the loss
+CREATE OR REPLACE FUNCTION custom_function_zero_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn(a, b):
+  c = a*b*0
+  return c
+
+pb=dill.dumps(test_custom_fn)
+return pb
+$$ language plpythonu;
+
+
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT madlib_keras_fit(
+    'iris_data_packed',
+    'iris_model',
+    'iris_model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    FALSE, NULL, 1, NULL, NULL, NULL,
+    'test_custom_function_table'
+);
+
+SELECT assert(
+        model_arch_table = 'iris_model_arch' AND
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        source_table = 'iris_data_packed' AND
+        model = 'iris_model' AND
+        dependent_varname = 'class_text' AND
+        independent_varname = 'attributes' AND
+        dependent_vartype LIKE '%char%' AND
+        normalizing_const = 1 AND
+        pg_typeof(normalizing_const) = 'real'::regtype AND
+        name is NULL AND
+        description is NULL AND
+        object_table = 'test_custom_function_table' AND
+        model_size > 0 AND
+        madlib_version is NOT NULL AND
+        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['mae']$$::text AND
+        fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
+        num_iterations = 3 AND
+        metrics_compute_frequency = 1 AND
+        num_classes = 3 AND
+        class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+        metrics_type = '{mae}' AND
+        array_upper(training_metrics, 1) = 3 AND
+        training_loss = '{0,0,0}' AND
+        array_upper(metrics_elapsed_time, 1) = 3 ,
+        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_model_summary) summary;
+
+SELECT assert(
+        model_weights IS NOT NULL AND
+        model_arch IS NOT NULL, 'Keras model output validation failed. Actual:' || __to_char(k))
+FROM (SELECT * FROM iris_model) k;
+
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
+    'iris_data_packed',
+    'iris_model',
+    'iris_model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn1', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    FALSE, NULL, 1, NULL, NULL, NULL,
+    'test_custom_function_table'
+);$TRAP$) = 1,
+'custom function in compile_params not defined in Object table.');

Review comment:
       Please add a newline at the end of the file.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -307,9 +309,13 @@ def get_optimizers():
     return optimizers
 
 # Run the keras.compile with the given parameters
-def compile_model(model, compile_params):
+def compile_model(model, compile_params, custom_function_map=None):
     optimizers = get_optimizers()
     (opt_name,final_args,compile_dict) = parse_and_validate_compile_params(compile_params)
+    if custom_function_map is not None:
+        import dill

Review comment:
       Do we have to import dill a second time?

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -33,6 +33,8 @@ import keras.losses as losses
 
 import madlib_keras_serializer
 import madlib_keras_gpu_info
+from madlib_keras_custom_function import CustomFunctionSchema

Review comment:
       I am guessing this was supposed to be a part of the previous commit. No need to rebase just to fix it, just wanted to let you know.

##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
##########
@@ -331,3 +337,32 @@ def validate_compile_param_types(compile_dict):
             compile_dict['sample_weight_mode'] is None or
             compile_dict['sample_weight_mode'] == "temporal",
             """compile parameter sample_weight_mode can only be "temporal" or None""")
+
+# Returns an object of custom function name and it corresponding object
+def query_custom_functions_map(object_table, custom_fn_names):
+    if len(custom_fn_names) < 1:
+        return None
+    custom_obj_col_name = '{0}'.format(CustomFunctionSchema.FN_OBJ)
+    # Dictionary map of name:object
+    # {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
+    custom_fn_map = defaultdict(list)
+    # Query the custom function if not yet loaded from table
+    res = plpy.execute("SELECT {custom_fn_col_name}, {custom_obj_col_name} FROM {object_table} " \
+                       "WHERE {custom_fn_col_name} = ANY(ARRAY{custom_fn_names})".format(custom_obj_col_name=custom_obj_col_name,
+                                                                                object_table=object_table,
+                                                                                custom_fn_col_name=CustomFunctionSchema.FN_NAME,
+                                                                                custom_fn_names=custom_fn_names))
+    if res.nrows() < len(custom_fn_names):
+        plpy.error("Custom function {0} not defined in object table '{1}'".format(custom_fn_names, object_table))
+    for r in res:
+        custom_fn_map[r[CustomFunctionSchema.FN_NAME]] = dill.loads(r[custom_obj_col_name])
+    custom_fn_map_obj = dill.dumps(custom_fn_map)
+    return custom_fn_map_obj
+
+def get_custom_functions_list(compile_params):
+    compile_dict = convert_string_of_args_to_dict(compile_params)
+    builtin_losses = dir(losses)
+    custom_fn_list = []
+    if compile_dict['loss'] not in builtin_losses:
+        custom_fn_list.append(compile_dict['loss'])
+    return custom_fn_list

Review comment:
       Please add a newline at the end of the file.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org