You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by fm...@apache.org on 2019/11/23 00:29:58 UTC
[madlib-site] branch automl updated: hyperband diagonal E2E update
This is an automated email from the ASF dual-hosted git repository.
fmcquillan pushed a commit to branch automl
in repository https://gitbox.apache.org/repos/asf/madlib-site.git
The following commit(s) were added to refs/heads/automl by this push:
new 94a7f7e hyperband diagonal E2E update
94a7f7e is described below
commit 94a7f7e81077ccd67710648850b696e2344e39d9
Author: Frank McQuillan <fm...@pivotal.io>
AuthorDate: Fri Nov 22 16:29:51 2019 -0800
hyperband diagonal E2E update
---
.../hyperband_diag_v2_mnist-checkpoint.ipynb | 157 +++++++++------------
.../automl/hyperband_diag_v2_mnist.ipynb | 130 ++++++++---------
2 files changed, 135 insertions(+), 152 deletions(-)
diff --git a/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb b/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
index 091e6fd..b62f8d5 100644
--- a/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
+++ b/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
@@ -30,19 +30,17 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 16,
"metadata": {
"scrolled": true
},
"outputs": [
{
- "name": "stderr",
+ "name": "stdout",
"output_type": "stream",
"text": [
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
- " \"You should import from traitlets.config instead.\", ShimWarning)\n",
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
- " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
+ "The sql extension is already loaded. To reload it, use:\n",
+ " %reload_ext sql\n"
]
}
],
@@ -52,7 +50,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -74,7 +72,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
@@ -100,7 +98,7 @@
"[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
]
},
- "execution_count": 3,
+ "execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -121,24 +119,9 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 20,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Using TensorFlow backend.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Couldn't import dot_parser, loading of dot files will not be possible.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"from __future__ import print_function\n",
"\n",
@@ -180,7 +163,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@@ -794,7 +777,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
@@ -821,7 +804,7 @@
"[]"
]
},
- "execution_count": 17,
+ "execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@@ -896,7 +879,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@@ -924,7 +907,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@@ -953,12 +936,13 @@
" self.n_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
" self.r_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
" sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
+ " \n",
+ " print (\" \")\n",
+ " print (\"Hyperband brackets\")\n",
"\n",
" #### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
" for s in reversed(range(self.s_max+1)):\n",
- "\n",
- " print (\" \")\n",
- " print (\"Hyperband brackets\")\n",
+ " \n",
" print (\" \")\n",
" print (\"s=\" + str(s))\n",
" print (\"n_i r_i\")\n",
@@ -1040,40 +1024,44 @@
" # filter out early stops, if any\n",
" \n",
" # loop on brackets s desc to prune model selection table\n",
- " print (\"loop on s desc to prune mst table:\")\n",
- " for s in range(self.s_max, self.s_max-i-1, -1):\n",
- " \n",
- " k = int( self.n_vals[s][i] / self.eta)\n",
+ " # don't need to prune if finished last diagonal\n",
+ " if i < self.s_max:\n",
+ " print (\"loop on s desc to prune mst table:\")\n",
+ " for s in range(self.s_max, self.s_max-i-1, -1):\n",
" \n",
- " # temporarily re-run table names again due to weird scope issues\n",
- " results_table = 'results_mnist'\n",
+ " # compute number of configs to keep\n",
+ " # remember i value is different for each bracket s on the diagonal\n",
+ " k = int( self.n_vals[s][s-self.s_max+i] / self.eta)\n",
+ " print (\"pruning s = {} with k = {}\".format(s, k))\n",
"\n",
- " output_table = 'mnist_multi_model'\n",
- " output_table_info = '_'.join([output_table, 'info'])\n",
- " output_table_summary = '_'.join([output_table, 'summary'])\n",
+ " # temporarily re-define table names due to weird Python scope issues\n",
+ " results_table = 'results_mnist'\n",
"\n",
- " mst_table = 'mst_table_hb_mnist'\n",
- " mst_table_summary = '_'.join([mst_table, 'summary'])\n",
+ " output_table = 'mnist_multi_model'\n",
+ " output_table_info = '_'.join([output_table, 'info'])\n",
+ " output_table_summary = '_'.join([output_table, 'summary'])\n",
"\n",
- " mst_diag_table = 'mst_diag_table_hb_mnist'\n",
- " mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
+ " mst_table = 'mst_table_hb_mnist'\n",
+ " mst_table_summary = '_'.join([mst_table, 'summary'])\n",
"\n",
- " model_arch_table = 'model_arch_library_mnist'\n",
- " \n",
- " query = \"\"\"\n",
- " DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
- " \"\"\".format(**locals())\n",
- " cur.execute(query)\n",
- " conn.commit()\n",
- " #%sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
- "# %sql DELETE FROM $mst_table WHERE s={0} AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
- " #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
+ " mst_diag_table = 'mst_diag_table_hb_mnist'\n",
+ " mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
+ "\n",
+ " model_arch_table = 'model_arch_library_mnist'\n",
+ " \n",
+ " query = \"\"\"\n",
+ " DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
+ " \"\"\".format(**locals())\n",
+ " cur.execute(query)\n",
+ " conn.commit()\n",
+ " #%sql DELETE FROM $mst_table WHERE s=$s AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
+ " #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
" return"
]
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@@ -1129,7 +1117,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@@ -1163,7 +1151,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 27,
"metadata": {
"scrolled": false
},
@@ -1182,16 +1170,12 @@
"3.0 3.0\n",
"1.0 9.0\n",
" \n",
- "Hyperband brackets\n",
- " \n",
"s=1\n",
"n_i r_i\n",
"------------\n",
"3 3.0\n",
"1.0 9.0\n",
" \n",
- "Hyperband brackets\n",
- " \n",
"s=0\n",
"n_i r_i\n",
"------------\n",
@@ -1248,6 +1232,7 @@
"9 rows affected.\n",
"9 rows affected.\n",
"loop on s desc to prune mst table:\n",
+ "pruning s = 2 with k = 3\n",
" \n",
"i=1\n",
"Done.\n",
@@ -1264,29 +1249,35 @@
"6 rows affected.\n",
"6 rows affected.\n",
"loop on s desc to prune mst table:\n",
+ "pruning s = 2 with k = 1\n",
+ "pruning s = 1 with k = 1\n",
" \n",
"i=2\n",
"Done.\n",
"loop on s desc to create diagonal table:\n",
"1 rows affected.\n",
- "0 rows affected.\n",
+ "1 rows affected.\n",
"3 rows affected.\n",
"try params for i = 2\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
- "4 rows affected.\n",
+ "5 rows affected.\n",
"Done.\n",
- "4 rows affected.\n",
- "4 rows affected.\n",
- "4 rows affected.\n",
- "loop on s desc to prune mst table:\n"
+ "5 rows affected.\n",
+ "5 rows affected.\n",
+ "5 rows affected.\n",
+ "loop on s desc to prune mst table:\n",
+ "pruning s = 2 with k = 0\n",
+ "pruning s = 1 with k = 0\n",
+ "pruning s = 0 with k = 1\n"
]
}
],
"source": [
"hp = Hyperband_diagonal(get_params, try_params )\n",
- "results = hp.run()"
+ "results = hp.run()\n",
+ "#hp.n_vals[1]"
]
},
{
@@ -1299,7 +1290,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
@@ -1330,7 +1321,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "15 rows affected.\n"
+ "10 rows affected.\n"
]
},
{
@@ -2116,7 +2107,7 @@
{
"data": {
"text/html": [
- "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3xb5b3+H8myJK/sOE7iDDJIwt5lUyhlFFpKGR1QVktv6O2ggz+9rBtuWrjpbbmFrlAohQuUUaCUtrQFWvZoGGUnhGw7y3GWtyRL+n+eVz62LEv2kXRkH9nPy0fIsd7znvf9vkfWT8/5DU88Ho9DTQREQAREQAREQAREQAREQAREQAREQAREQAREoKAEPBLiCspXg4uACIiACIiACIiACIiACIiACIiACIiACIiAISAhTheCCBQ5gbPPPhsPP/ww/v3f/x0/+9nPHF3NIYccgtdffx3/8z//g+9+97uOjq3BnCPw7rvvYt999zUDbtu2DRMmTHBucI0kAiIgAiIgAiIwaARk1/VF3dLSgqqqKvPCq6++CtqnaiIgAiJQzAQkxBXz7mn [...]
+ "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3hb1f3+X0se8ooTJ3GcxM5wEhJ2gSTMJJBCGYVCCJRSaKGD1VL6b/kBLSNAwygd0AE0ZbS0BcoKq0CBssPMYI8MMp3h7HhLsiX9n/fI15GH7CtZlq7s9zyPHyf2uWd8zpX11Xu/IyMUCoWgJgIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIi0KsEMiTE9SpfDS4CIiACIiACIiACIiACIiACIiACIiACIiAChoCEON0IIpDmBE477TTMnz8fP/7xj3H77bcndDeTJ0/GkiVL8Nvf/hb/93//l9CxNVjiCHz66afYd999zYBbt27FkCFDEje4RhIBERABERABEUgaAdl1HVHX1dWhsLDQ/GLRokWgfaomAiIgAulMQEJcOp+e1t5rBDI [...]
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -2138,18 +2129,13 @@
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
- "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 15;\n",
+ "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 10;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"#set up plots\n",
@@ -2199,7 +2185,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "15 rows affected.\n"
+ "10 rows affected.\n"
]
},
{
@@ -2985,7 +2971,7 @@
{
"data": {
"text/html": [
- "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCXhbV5n+X0m2Nttx7CRuFmdplibpknRnui8U2lKgLaQMAx06MCxh5j8z0MJ0oG0IEyiTKVNgYKDDNnRo2UppYaAb0L2lpE2XJG3SZmkWJ3GcxbusxZL+z3vtq8iyZF1ZV/aV9Z5Wj2zr3HPP+Z1zcz+99zvf50omk0moiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIlJSAS0JcSfmqcREQAREQAREQAREQAREQAREQAREQAREQAREwCEiI00IQgTEk8NBDD+Hyyy+Hz+dDOBwecuaRPrPSxWKPt3KOkerccccd+OQnP4nFixdjy5YtxTan40tEQPNUIrBqVgREQAREoOIIyK5z3pSPtz3sPCLqkQiIgBMJSIhz4qyoT7YT+NjHPob [...]
+ "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCZiT1fX/v8ns+zDMhgyLyOaKsoqyuAGCu2LVYkVFLdat1bb8q4IoLqVVa12x1V/FuiMq1YpAQRRQZFNEBUGQZZBhFph9yUyS//O9wztmMsnkTfImeSdz7vPMM0Duve+9n3tDTr733HMsTqfTCSlCQAgIASEgBISAEBACQkAICAEhIASEgBAQAkJACISUgEWEuJDylc6FgBAQAkJACAgBISAEhIAQEAJCQAgIASEgBISAIiBCnGwEIRBGAh999BEmTpyIhIQE1NfXt3pye6/pGWKw7fU8o7068+bNw80334wBAwZg27ZtwXYn7UNEQNYpRGClWyEgBISAEOh0BMSuM9+SR9oeNh8RGZEQEAJmJCBCnBlXRcZkOIEbb7wRL7zwArKysvD [...]
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -3007,18 +2993,13 @@
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
- "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 15;\n",
+ "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 10;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"#set up plots\n",
diff --git a/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb b/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
index 091e6fd..171c9cd 100644
--- a/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
+++ b/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
@@ -52,7 +52,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -74,7 +74,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -100,7 +100,7 @@
"[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
]
},
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -121,7 +121,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -180,7 +180,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -794,7 +794,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
@@ -821,7 +821,7 @@
"[]"
]
},
- "execution_count": 17,
+ "execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
@@ -896,7 +896,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -924,7 +924,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
@@ -953,12 +953,13 @@
" self.n_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
" self.r_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
" sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
+ " \n",
+ " print (\" \")\n",
+ " print (\"Hyperband brackets\")\n",
"\n",
" #### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
" for s in reversed(range(self.s_max+1)):\n",
- "\n",
- " print (\" \")\n",
- " print (\"Hyperband brackets\")\n",
+ " \n",
" print (\" \")\n",
" print (\"s=\" + str(s))\n",
" print (\"n_i r_i\")\n",
@@ -1040,40 +1041,44 @@
" # filter out early stops, if any\n",
" \n",
" # loop on brackets s desc to prune model selection table\n",
- " print (\"loop on s desc to prune mst table:\")\n",
- " for s in range(self.s_max, self.s_max-i-1, -1):\n",
- " \n",
- " k = int( self.n_vals[s][i] / self.eta)\n",
+ " # don't need to prune if finished last diagonal\n",
+ " if i < self.s_max:\n",
+ " print (\"loop on s desc to prune mst table:\")\n",
+ " for s in range(self.s_max, self.s_max-i-1, -1):\n",
" \n",
- " # temporarily re-run table names again due to weird scope issues\n",
- " results_table = 'results_mnist'\n",
+ " # compute number of configs to keep\n",
+ " # remember i value is different for each bracket s on the diagonal\n",
+ " k = int( self.n_vals[s][s-self.s_max+i] / self.eta)\n",
+ " print (\"pruning s = {} with k = {}\".format(s, k))\n",
"\n",
- " output_table = 'mnist_multi_model'\n",
- " output_table_info = '_'.join([output_table, 'info'])\n",
- " output_table_summary = '_'.join([output_table, 'summary'])\n",
+ " # temporarily re-define table names due to weird Python scope issues\n",
+ " results_table = 'results_mnist'\n",
"\n",
- " mst_table = 'mst_table_hb_mnist'\n",
- " mst_table_summary = '_'.join([mst_table, 'summary'])\n",
+ " output_table = 'mnist_multi_model'\n",
+ " output_table_info = '_'.join([output_table, 'info'])\n",
+ " output_table_summary = '_'.join([output_table, 'summary'])\n",
"\n",
- " mst_diag_table = 'mst_diag_table_hb_mnist'\n",
- " mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
+ " mst_table = 'mst_table_hb_mnist'\n",
+ " mst_table_summary = '_'.join([mst_table, 'summary'])\n",
"\n",
- " model_arch_table = 'model_arch_library_mnist'\n",
- " \n",
- " query = \"\"\"\n",
- " DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
- " \"\"\".format(**locals())\n",
- " cur.execute(query)\n",
- " conn.commit()\n",
- " #%sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
- "# %sql DELETE FROM $mst_table WHERE s={0} AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
- " #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
+ " mst_diag_table = 'mst_diag_table_hb_mnist'\n",
+ " mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
+ "\n",
+ " model_arch_table = 'model_arch_library_mnist'\n",
+ " \n",
+ " query = \"\"\"\n",
+ " DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
+ " \"\"\".format(**locals())\n",
+ " cur.execute(query)\n",
+ " conn.commit()\n",
+ " #%sql DELETE FROM $mst_table WHERE s=$s AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
+ " #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
" return"
]
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@@ -1129,7 +1134,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@@ -1163,7 +1168,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 27,
"metadata": {
"scrolled": false
},
@@ -1182,16 +1187,12 @@
"3.0 3.0\n",
"1.0 9.0\n",
" \n",
- "Hyperband brackets\n",
- " \n",
"s=1\n",
"n_i r_i\n",
"------------\n",
"3 3.0\n",
"1.0 9.0\n",
" \n",
- "Hyperband brackets\n",
- " \n",
"s=0\n",
"n_i r_i\n",
"------------\n",
@@ -1248,6 +1249,7 @@
"9 rows affected.\n",
"9 rows affected.\n",
"loop on s desc to prune mst table:\n",
+ "pruning s = 2 with k = 3\n",
" \n",
"i=1\n",
"Done.\n",
@@ -1264,29 +1266,35 @@
"6 rows affected.\n",
"6 rows affected.\n",
"loop on s desc to prune mst table:\n",
+ "pruning s = 2 with k = 1\n",
+ "pruning s = 1 with k = 1\n",
" \n",
"i=2\n",
"Done.\n",
"loop on s desc to create diagonal table:\n",
"1 rows affected.\n",
- "0 rows affected.\n",
+ "1 rows affected.\n",
"3 rows affected.\n",
"try params for i = 2\n",
"Done.\n",
"1 rows affected.\n",
"Done.\n",
- "4 rows affected.\n",
+ "5 rows affected.\n",
"Done.\n",
- "4 rows affected.\n",
- "4 rows affected.\n",
- "4 rows affected.\n",
- "loop on s desc to prune mst table:\n"
+ "5 rows affected.\n",
+ "5 rows affected.\n",
+ "5 rows affected.\n",
+ "loop on s desc to prune mst table:\n",
+ "pruning s = 2 with k = 0\n",
+ "pruning s = 1 with k = 0\n",
+ "pruning s = 0 with k = 1\n"
]
}
],
"source": [
"hp = Hyperband_diagonal(get_params, try_params )\n",
- "results = hp.run()"
+ "results = hp.run()\n",
+ "#hp.n_vals[1]"
]
},
{
@@ -1299,7 +1307,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -1323,14 +1331,14 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "15 rows affected.\n"
+ "12 rows affected.\n"
]
},
{
@@ -2116,7 +2124,7 @@
{
"data": {
"text/html": [
- "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3xb5b3+H8myJK/sOE7iDDJIwt5lUyhlFFpKGR1QVktv6O2ggz+9rBtuWrjpbbmFrlAohQuUUaCUtrQFWvZoGGUnhGw7y3GWtyRL+n+eVz62LEv2kXRkH9nPy0fIsd7znvf9vkfWT8/5DU88Ho9DTQREQAREQAREQAREQAREQAREQAREQAREQAREoKAEPBLiCspXg4uACIiACIiACIiACIiACIiACIiACIiACIiAISAhTheCCBQ5gbPPPhsPP/ww/v3f/x0/+9nPHF3NIYccgtdffx3/8z//g+9+97uOjq3BnCPw7rvvYt999zUDbtu2DRMmTHBucI0kAiIgAiIgAiIwaARk1/VF3dLSgqqqKvPCq6++CtqnaiIgAiJQzAQkxBXz7mn [...]
+ "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3hb5dn+b0m2vDOc2Fl2diBAWVnsXSBQRinQSQu00EJ3oS39GIE2FMr3tdD2T1tmW7poWWXPllF2BrthZMcZjp3YTjwlW9L/ul/52LItyUfS0bB9v9fly4n1nnf83mPr0X2e4QqFQiGoiYAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIpJWAS0JcWvlqcBEQAREQAREQAREQAREQAREQAREQAREQAREwBCTE6UYQgSFO4KyzzsL999+Pb3zjG7j55psd3c2CBQuwcuVK/N///R++//3vOzq2BnOOwHvvvYd9993XDFhfX4/x48c7N7hGEgEREAEREAERyBgB2XUDUbe0tKCsrMy8sHz5ctA+VRMBERCBoUxAQtxQPj2tPW0EXC5X0mP [...]
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -2140,16 +2148,13 @@
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
- "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 15;\n",
+ "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 12;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"#set up plots\n",
@@ -2192,14 +2197,14 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "15 rows affected.\n"
+ "12 rows affected.\n"
]
},
{
@@ -2985,7 +2990,7 @@
{
"data": {
"text/html": [
- "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCXhbV5n+X0m2Nttx7CRuFmdplibpknRnui8U2lKgLaQMAx06MCxh5j8z0MJ0oG0IEyiTKVNgYKDDNnRo2UppYaAb0L2lpE2XJG3SZmkWJ3GcxbusxZL+z3vtq8iyZF1ZV/aV9Z5Wj2zr3HPP+Z1zcz+99zvf50omk0moiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIlJSAS0JcSfmqcREQAREQAREQAREQAREQAREQAREQAREQAREwCEiI00IQgTEk8NBDD+Hyyy+Hz+dDOBwecuaRPrPSxWKPt3KOkerccccd+OQnP4nFixdjy5YtxTan40tEQPNUIrBqVgREQAREoOIIyK5z3pSPtz3sPCLqkQiIgBMJSIhz4qyoT7YT+NjHPob [...]
+ "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCXhU1fnG31myrwQStgQQVMAN2aKAgKBW3Dds3dFqFWtrq/6rbVWkxaV0cWmtxbVq1boUq3VHZUchgAiiICoCCRAC2deZJJP/857kJpPJLPfO3FmSfOd58iQk5557zu+cYb5577dYWlpaWiBNCAgBISAEhIAQEAJCQAgIASEgBISAEBACQkAICIGwErCIEBdWvjK4EBACQkAICAEhIASEgBAQAkJACAgBISAEhIAQUAREiJODIAQiSOD999/HGWecgYSEBDQ0NHS6s7+/6ZliqNfruYe/PosWLcKNN96IkSNHYvv27aEOJ9eHiYDsU5jAyrBCQAgIASHQ6wiIXRd7Wx5tezj2iMiMhIAQiEUCIsTF4q7InEwn8JOf/ARPPfUUsrKysG/ [...]
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -3009,16 +3014,13 @@
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
- "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 15;\n",
+ "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 12;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"#set up plots\n",