You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2019/04/04 22:25:03 UTC

[GitHub] [madlib] reductionista commented on a change in pull request #362: DL: Remove num_classes param from madlib_keras_fit()

reductionista commented on a change in pull request #362: DL: Remove num_classes param from madlib_keras_fit()
URL: https://github.com/apache/madlib/pull/362#discussion_r272390533
 
 

 ##########
 File path: src/ports/postgres/modules/utilities/model_arch_info.py_in
 ##########
 @@ -21,69 +21,45 @@ m4_changequote(`<!', `!>')
 
 import sys
 import json
+import plpy
 
-def get_layers(arch):
-    d = json.loads(arch)
+def _get_layers(model_arch):
+    d = json.loads(model_arch)
     config = d['config']
     if type(config) == list:
-        return config  # In keras 1.x, all models are sequential
+        return config  # In keras 2.1.x, all models are sequential
     elif type(config) == dict and 'layers' in config:
         layers = config['layers']
         if type(layers) == list:
             return config['layers']  # In keras 2.x, only sequential models are supported
-    plpy.error('Unable to read input_shape from keras model arch.  Note: only sequential keras models are supported.')
-    return None
+    plpy.error("Unable to read model architecture JSON.")
 
-def get_input_shape(arch):
-    layers = get_layers(arch)
-    return layers[0]['config']['batch_input_shape'][1:]
+def get_input_shape(model_arch):
+    arch_layers = _get_layers(model_arch)
+    if 'batch_input_shape' in arch_layers[0]['config']:
+        return arch_layers[0]['config']['batch_input_shape'][1:]
+    plpy.error('Unable to get input shape from model architecture.')
 
-def print_model_arch_layers(arch):
-    layers = get_layers(arch)
+def get_num_classes(model_arch):
+    arch_layers = _get_layers(model_arch)
+    if 'units' in arch_layers[-1]['config']:
+        return arch_layers[-1]['config']['units']
+    plpy.error('Unable to get number of classes from model architecture.')
 
-    print("\nModel arch layers:")
+def get_model_arch_layers_str(model_arch):
+    arch_layers = _get_layers(model_arch)
+    layers = "Model arch layers:\n"
     first = True
-    for layer in layers:
+    for layer in arch_layers:
         if first:
             first = False
         else:
-            print("   |")
-            print("   V")
+            layers = "{0}   |\n".format(layers)
+            layers = "{0}   V\n".format(layers)
         class_name = layer['class_name']
         config = layer['config']
         if class_name == 'Dense':
-            print("{0}[{1}]".class_name)
+            layers = "{0}{1}[{2}]\n".format(layers, config, class_name)
         else:
-            print(class_name)
-
-def print_input_shape(arch):
-    layers = get_layers(arch)
-    print("\nInput shape:")
-    print(layers[0]['config']['batch_input_shape'][1:])
-
-def print_required_imports(arch):
-    layers = get_layers(arch)
-    class_names = set(layer['class_name'] for layer in layers )
-    print("\nRequired imports:")
-    for module in class_names:
-        print("import {}".module)
 
 Review comment:
   Some version of this function might be good to leave in.  It finds all the layer names and builds import statements out of them.  This would be useful if we ever decided to stop doing `import *` on everything in keras.  We already know which optimizer the user is using, and Sequential is the only model we support.  So I think the only thing left we'd need to do something more for is the regularizers.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services