You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by fm...@apache.org on 2021/03/06 00:54:55 UTC
[madlib] branch master updated: update example in multi-fit to use
new model config generator
This is an automated email from the ASF dual-hosted git repository.
fmcquillan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push:
new 33ad16c update example in multi-fit to use new model config generator
33ad16c is described below
commit 33ad16c29af1e99a02a8a153671a9a16608e74c6
Author: Frank McQuillan <fm...@pivotal.io>
AuthorDate: Fri Mar 5 16:54:34 2021 -0800
update example in multi-fit to use new model config generator
---
.../madlib_keras_fit_multiple_model.sql_in | 69 ++++++++++++----------
1 file changed, 37 insertions(+), 32 deletions(-)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 67ee2c7..e8c4d51 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -999,41 +999,44 @@ $$
'MLP with 2 hidden layers' -- Descr
);
</pre>
--# Define model selection tuples and load. Select the model(s) from the model architecture
-table that you want to run, along with the compile and fit parameters. Combinations will be
-created for the set of model selection parameters will be loaded:
+-# Generate model configurations using grid search. The output table for grid
+search contains the unique combinations of model architectures, compile and
+fit parameters.
<pre class="example">
DROP TABLE IF EXISTS mst_table, mst_table_summary;
-SELECT madlib.load_model_selection_table('model_arch_library', -- model architecture table
- 'mst_table', -- model selection table output
- ARRAY[1,2], -- model ids from model architecture table
- ARRAY[ -- compile params
- $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
- $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
- $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
- ],
- ARRAY[ -- fit params
- $$batch_size=4,epochs=1$$,
- $$batch_size=8,epochs=1$$
- ]
+SELECT madlib.generate_model_configs(
+ 'model_arch_library', -- model architecture table
+ 'mst_table', -- model selection table output
+ ARRAY[1,2], -- model ids from model architecture table
+ $$
+ {'loss': ['categorical_crossentropy'],
+ 'optimizer_params_list': [ {'optimizer': ['Adam'], 'lr': [0.001, 0.01, 0.1]} ],
+ 'metrics': ['accuracy']}
+ $$, -- compile_param_grid
+ $$
+ { 'batch_size': [4, 8],
+ 'epochs': [1]
+ }
+ $$, -- fit_param_grid
+ 'grid' -- search_type
);
SELECT * FROM mst_table ORDER BY mst_key;
</pre>
<pre class="result">
- mst_key | model_id | compile_params | fit_params
+ mst_key | model_id | compile_params | fit_params
---------+----------+---------------------------------------------------------------------------------+-----------------------
- 1 | 1 | loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy'] | batch_size=4,epochs=1
- 2 | 1 | loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy'] | batch_size=8,epochs=1
- 3 | 1 | loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy'] | batch_size=4,epochs=1
- 4 | 1 | loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy'] | batch_size=8,epochs=1
- 5 | 1 | loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy'] | batch_size=4,epochs=1
- 6 | 1 | loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy'] | batch_size=8,epochs=1
- 7 | 2 | loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy'] | batch_size=4,epochs=1
- 8 | 2 | loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy'] | batch_size=8,epochs=1
- 9 | 2 | loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy'] | batch_size=4,epochs=1
- 10 | 2 | loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy'] | batch_size=8,epochs=1
- 11 | 2 | loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy'] | batch_size=4,epochs=1
- 12 | 2 | loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy'] | batch_size=8,epochs=1
+ 1 | 1 | optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=4
+ 2 | 1 | optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=8
+ 3 | 1 | optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=4
+ 4 | 1 | optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=8
+ 5 | 1 | optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=4
+ 6 | 1 | optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=8
+ 7 | 2 | optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=4
+ 8 | 2 | optimizer='Adam(lr=0.001)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=8
+ 9 | 2 | optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=4
+ 10 | 2 | optimizer='Adam(lr=0.01)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=8
+ 11 | 2 | optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=4
+ 12 | 2 | optimizer='Adam(lr=0.1)',metrics=['accuracy'],loss='categorical_crossentropy' | epochs=1,batch_size=8
(12 rows)
</pre>
This is the name of the model architecture table that corresponds to the model selection table:
@@ -1041,9 +1044,9 @@ This is the name of the model architecture table that corresponds to the model s
SELECT * FROM mst_table_summary;
</pre>
<pre class="result">
- model_arch_table
---------------------+
- model_arch_library
+ model_arch_table | object_table
+--------------------+--------------
+ model_arch_library |
</pre>
-# Train multiple models.
@@ -1176,8 +1179,10 @@ SELECT * FROM iris_predict ORDER BY id;
</pre>
Count missclassifications:
<pre class="example">
-SELECT COUNT(*) FROM iris_predict JOIN iris_test USING (id)
+SELECT COUNT(*) FROM iris_predict JOIN iris_test USING (id)
WHERE iris_predict.class_value != iris_test.class_text;
+</pre>
+<pre class="result">
count
-------+
0