Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2019/05/13 16:56:40 UTC

[GitHub] [madlib] njayaram2 commented on a change in pull request #388: DL: Add new param metrics_compute_frequency to madlib_keras_fit()

njayaram2 commented on a change in pull request #388: DL: Add new param metrics_compute_frequency to madlib_keras_fit()
URL: https://github.com/apache/madlib/pull/388#discussion_r283441574
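The PR title introduces a `metrics_compute_frequency` parameter for `madlib_keras_fit()`, but the hunk below does not show how it is applied. The following is only a minimal Python sketch under an assumed meaning: metrics are computed every N iterations and always after the final iteration, and no value means "final iteration only". `should_compute_metrics` is a hypothetical helper, not MADlib code. `curr_iter` is treated as 0-based, matching the `curr_iter + 1` in the diff's log message.

```python
def should_compute_metrics(curr_iter, num_iterations, metrics_compute_frequency=None):
    """Return True if metrics should be computed after iteration curr_iter (0-based).

    Hypothetical helper. Assumed semantics: compute every
    metrics_compute_frequency iterations and always after the last
    iteration; None means "only after the last iteration".
    """
    is_last = curr_iter == num_iterations - 1
    if metrics_compute_frequency is None:
        return is_last
    if metrics_compute_frequency < 1:
        raise ValueError("metrics_compute_frequency must be a positive integer")
    # Use the 1-based iteration number so a frequency of 3 fires on 3, 6, 9, ...
    return is_last or (curr_iter + 1) % metrics_compute_frequency == 0


# Example: 10 iterations, metrics every 3 iterations.
num_iterations = 10
computed = [i + 1 for i in range(num_iterations)
            if should_compute_metrics(i, num_iterations, 3)]
print(computed)  # [3, 6, 9, 10]
```

Always including the last iteration (under this assumption) guarantees the caller gets final metrics even when the iteration count is not a multiple of the frequency.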
 
 

 ##########
 File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
 ##########
 @@ -313,6 +316,61 @@ def fit(schema_madlib, source_table, model,model_arch_table,
     #TODO add a unit test for this in a future PR
     reset_cuda_env(original_cuda_env)
 
+def compute_loss_and_metrics(schema_madlib, table, dependent_varname,
+                             independent_varname, compile_params, model_arch,
+                             model_state, gpus_per_host, segments_per_host,
+                             seg_ids_val, rows_per_seg_val,
+                             gp_segment_id_col, metrics_list, loss_list,
+                             curr_iter, dataset_name):
+    """
+    Compute the loss and metric using a given model (model_state) on the
+    given dataset (table.)
+    """
+    start_val = time.time()
+    evaluate_result = get_loss_acc_from_keras_eval(schema_madlib,
+                                                   table,
+                                                   dependent_varname,
+                                                   independent_varname,
+                                                   compile_params,
+                                                   model_arch, model_state,
+                                                   gpus_per_host,
+                                                   segments_per_host,
+                                                   seg_ids_val,
+                                                   rows_per_seg_val,
+                                                   gp_segment_id_col)
+    end_val = time.time()
+    plpy.info("Time for evaluation in iteration {0}: {1} sec.". format(
+        curr_iter + 1, end_val - start_val))
+    if len(evaluate_result) < 2:
 
 Review comment:
  Yes, I think this will be addressed as part of https://issues.apache.org/jira/browse/MADLIB-1338.
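The comment is anchored on the `if len(evaluate_result) < 2:` line, and the body of that check is not shown. As a hedged illustration only, here is one way a timed evaluation could validate its result before indexing it, assuming the evaluation returns a sequence whose first element is the loss and whose second is a metric. `evaluate_with_timing` is a hypothetical name, `evaluate_fn` stands in for `get_loss_acc_from_keras_eval`, and `print` replaces `plpy.info`, which is only available inside PL/Python. How MADLIB-1338 actually resolves this is not stated in the thread.

```python
import time


def evaluate_with_timing(evaluate_fn, curr_iter, loss_list, metrics_list):
    """Time one evaluation pass and record its loss and first metric.

    Hypothetical sketch: evaluate_fn is assumed to return a sequence
    [loss, metric, ...], like the evaluate_result in the diff.
    """
    start = time.time()
    evaluate_result = evaluate_fn()
    elapsed = time.time() - start
    print("Time for evaluation in iteration {0}: {1:.3f} sec.".format(
        curr_iter + 1, elapsed))
    # Check the length up front so a result without a metric value fails
    # with a clear message rather than an IndexError on evaluate_result[1].
    if len(evaluate_result) < 2:
        raise ValueError(
            "Expected loss and at least one metric, got {0} value(s)".format(
                len(evaluate_result)))
    loss, metric = evaluate_result[0], evaluate_result[1]
    loss_list.append(loss)
    metrics_list.append(metric)
    return loss, metric


# Usage with a stub evaluation that returns [loss, accuracy].
losses, metrics = [], []
evaluate_with_timing(lambda: [0.42, 0.91], 0, losses, metrics)
```

Raising here is one design choice; a caller could instead record `None` for the metric and keep training, which trades an early, explicit failure for a silently incomplete metrics history.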

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services