You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by fm...@apache.org on 2019/11/23 00:29:58 UTC

[madlib-site] branch automl updated: hyperband diagonal E2E update

This is an automated email from the ASF dual-hosted git repository.

fmcquillan pushed a commit to branch automl
in repository https://gitbox.apache.org/repos/asf/madlib-site.git


The following commit(s) were added to refs/heads/automl by this push:
     new 94a7f7e  hyperband diagonal E2E update
94a7f7e is described below

commit 94a7f7e81077ccd67710648850b696e2344e39d9
Author: Frank McQuillan <fm...@pivotal.io>
AuthorDate: Fri Nov 22 16:29:51 2019 -0800

    hyperband diagonal E2E update
---
 .../hyperband_diag_v2_mnist-checkpoint.ipynb       | 157 +++++++++------------
 .../automl/hyperband_diag_v2_mnist.ipynb           | 130 ++++++++---------
 2 files changed, 135 insertions(+), 152 deletions(-)

diff --git a/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb b/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
index 091e6fd..b62f8d5 100644
--- a/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
+++ b/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
@@ -30,19 +30,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 16,
    "metadata": {
     "scrolled": true
    },
    "outputs": [
     {
-     "name": "stderr",
+     "name": "stdout",
      "output_type": "stream",
      "text": [
-      "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
-      "  \"You should import from traitlets.config instead.\", ShimWarning)\n",
-      "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
-      "  warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
+      "The sql extension is already loaded. To reload it, use:\n",
+      "  %reload_ext sql\n"
      ]
     }
    ],
@@ -52,7 +50,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -74,7 +72,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
@@ -100,7 +98,7 @@
        "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
       ]
      },
-     "execution_count": 3,
+     "execution_count": 19,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -121,24 +119,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 20,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Using TensorFlow backend.\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Couldn't import dot_parser, loading of dot files will not be possible.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from __future__ import print_function\n",
     "\n",
@@ -180,7 +163,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -794,7 +777,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
@@ -821,7 +804,7 @@
        "[]"
       ]
      },
-     "execution_count": 17,
+     "execution_count": 22,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -896,7 +879,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 23,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -924,7 +907,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 24,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -953,12 +936,13 @@
     "        self.n_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
     "        self.r_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
     "        sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
+    "        \n",
+    "        print (\" \")\n",
+    "        print (\"Hyperband brackets\")\n",
     "\n",
     "        #### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
     "        for s in reversed(range(self.s_max+1)):\n",
-    "\n",
-    "            print (\" \")\n",
-    "            print (\"Hyperband brackets\")\n",
+    "            \n",
     "            print (\" \")\n",
     "            print (\"s=\" + str(s))\n",
     "            print (\"n_i      r_i\")\n",
@@ -1040,40 +1024,44 @@
     "            # filter out early stops, if any\n",
     "            \n",
     "            # loop on brackets s desc to prune model selection table\n",
-    "            print (\"loop on s desc to prune mst table:\")\n",
-    "            for s in range(self.s_max, self.s_max-i-1, -1):\n",
-    "                \n",
-    "                k = int( self.n_vals[s][i] / self.eta)\n",
+    "            # don't need to prune if finished last diagonal\n",
+    "            if i < self.s_max:\n",
+    "                print (\"loop on s desc to prune mst table:\")\n",
+    "                for s in range(self.s_max, self.s_max-i-1, -1):\n",
     "                \n",
-    "                # temporarily re-run table names again due to weird scope issues\n",
-    "                results_table = 'results_mnist'\n",
+    "                    # compute number of configs to keep\n",
+    "                    # remember i value is different for each bracket s on the diagonal\n",
+    "                    k = int( self.n_vals[s][s-self.s_max+i] / self.eta)\n",
+    "                    print (\"pruning s = {} with k = {}\".format(s, k))\n",
     "\n",
-    "                output_table = 'mnist_multi_model'\n",
-    "                output_table_info = '_'.join([output_table, 'info'])\n",
-    "                output_table_summary = '_'.join([output_table, 'summary'])\n",
+    "                    # temporarily re-define table names due to weird Python scope issues\n",
+    "                    results_table = 'results_mnist'\n",
     "\n",
-    "                mst_table = 'mst_table_hb_mnist'\n",
-    "                mst_table_summary = '_'.join([mst_table, 'summary'])\n",
+    "                    output_table = 'mnist_multi_model'\n",
+    "                    output_table_info = '_'.join([output_table, 'info'])\n",
+    "                    output_table_summary = '_'.join([output_table, 'summary'])\n",
     "\n",
-    "                mst_diag_table = 'mst_diag_table_hb_mnist'\n",
-    "                mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
+    "                    mst_table = 'mst_table_hb_mnist'\n",
+    "                    mst_table_summary = '_'.join([mst_table, 'summary'])\n",
     "\n",
-    "                model_arch_table = 'model_arch_library_mnist'\n",
-    "                \n",
-    "                query = \"\"\"\n",
-    "                DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
-    "                \"\"\".format(**locals())\n",
-    "                cur.execute(query)\n",
-    "                conn.commit()\n",
-    "                #%sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
-    "#                 %sql DELETE FROM $mst_table WHERE s={0} AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
-    "                #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
+    "                    mst_diag_table = 'mst_diag_table_hb_mnist'\n",
+    "                    mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
+    "\n",
+    "                    model_arch_table = 'model_arch_library_mnist'\n",
+    "            \n",
+    "                    query = \"\"\"\n",
+    "                    DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
+    "                    \"\"\".format(**locals())\n",
+    "                    cur.execute(query)\n",
+    "                    conn.commit()\n",
+    "                    #%sql DELETE FROM $mst_table WHERE s=$s AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
+    "                    #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
     "        return"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1129,7 +1117,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 26,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1163,7 +1151,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 27,
    "metadata": {
     "scrolled": false
    },
@@ -1182,16 +1170,12 @@
       "3.0     3.0\n",
       "1.0     9.0\n",
       " \n",
-      "Hyperband brackets\n",
-      " \n",
       "s=1\n",
       "n_i      r_i\n",
       "------------\n",
       "3     3.0\n",
       "1.0     9.0\n",
       " \n",
-      "Hyperband brackets\n",
-      " \n",
       "s=0\n",
       "n_i      r_i\n",
       "------------\n",
@@ -1248,6 +1232,7 @@
       "9 rows affected.\n",
       "9 rows affected.\n",
       "loop on s desc to prune mst table:\n",
+      "pruning s = 2 with k = 3\n",
       " \n",
       "i=1\n",
       "Done.\n",
@@ -1264,29 +1249,35 @@
       "6 rows affected.\n",
       "6 rows affected.\n",
       "loop on s desc to prune mst table:\n",
+      "pruning s = 2 with k = 1\n",
+      "pruning s = 1 with k = 1\n",
       " \n",
       "i=2\n",
       "Done.\n",
       "loop on s desc to create diagonal table:\n",
       "1 rows affected.\n",
-      "0 rows affected.\n",
+      "1 rows affected.\n",
       "3 rows affected.\n",
       "try params for i = 2\n",
       "Done.\n",
       "1 rows affected.\n",
       "Done.\n",
-      "4 rows affected.\n",
+      "5 rows affected.\n",
       "Done.\n",
-      "4 rows affected.\n",
-      "4 rows affected.\n",
-      "4 rows affected.\n",
-      "loop on s desc to prune mst table:\n"
+      "5 rows affected.\n",
+      "5 rows affected.\n",
+      "5 rows affected.\n",
+      "loop on s desc to prune mst table:\n",
+      "pruning s = 2 with k = 0\n",
+      "pruning s = 1 with k = 0\n",
+      "pruning s = 0 with k = 1\n"
      ]
     }
    ],
    "source": [
     "hp = Hyperband_diagonal(get_params, try_params )\n",
-    "results = hp.run()"
+    "results = hp.run()\n",
+    "#hp.n_vals[1]"
    ]
   },
   {
@@ -1299,7 +1290,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 28,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1330,7 +1321,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "15 rows affected.\n"
+      "10 rows affected.\n"
      ]
     },
     {
@@ -2116,7 +2107,7 @@
     {
      "data": {
       "text/html": [
-       "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3xb5b3+H8myJK/sOE7iDDJIwt5lUyhlFFpKGR1QVktv6O2ggz+9rBtuWrjpbbmFrlAohQuUUaCUtrQFWvZoGGUnhGw7y3GWtyRL+n+eVz62LEv2kXRkH9nPy0fIsd7znvf9vkfWT8/5DU88Ho9DTQREQAREQAREQAREQAREQAREQAREQAREQAREoKAEPBLiCspXg4uACIiACIiACIiACIiACIiACIiACIiACIiAISAhTheCCBQ5gbPPPhsPP/ww/v3f/x0/+9nPHF3NIYccgtdffx3/8z//g+9+97uOjq3BnCPw7rvvYt999zUDbtu2DRMmTHBucI0kAiIgAiIgAiIwaARk1/VF3dLSgqqqKvPCq6++CtqnaiIgAiJQzAQkxBXz7mn [...]
+       "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3hb1f3+X0se8ooTJ3GcxM5wEhJ2gSTMJJBCGYVCCJRSaKGD1VL6b/kBLSNAwygd0AE0ZbS0BcoKq0CBssPMYI8MMp3h7HhLsiX9n/fI15GH7CtZlq7s9zyPHyf2uWd8zpX11Xu/IyMUCoWgJgIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIi0KsEMiTE9SpfDS4CIiACIiACIiACIiACIiACIiACIiACIiAChoCEON0IIpDmBE477TTMnz8fP/7xj3H77bcndDeTJ0/GkiVL8Nvf/hb/93//l9CxNVjiCHz66afYd999zYBbt27FkCFDEje4RhIBERABERABEUgaAdl1HVHX1dWhsLDQ/GLRokWgfaomAiIgAulMQEJcOp+e1t5rBDI [...]
       ],
       "text/plain": [
        "<IPython.core.display.HTML object>"
@@ -2138,18 +2129,13 @@
       "1 rows affected.\n",
       "1 rows affected.\n",
       "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
       "1 rows affected.\n"
      ]
     }
    ],
    "source": [
     "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
-    "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 15;\n",
+    "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 10;\n",
     "df_results = df_results.DataFrame()\n",
     "\n",
     "#set up plots\n",
@@ -2199,7 +2185,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "15 rows affected.\n"
+      "10 rows affected.\n"
      ]
     },
     {
@@ -2985,7 +2971,7 @@
     {
      "data": {
       "text/html": [
-       "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCXhbV5n+X0m2Nttx7CRuFmdplibpknRnui8U2lKgLaQMAx06MCxh5j8z0MJ0oG0IEyiTKVNgYKDDNnRo2UppYaAb0L2lpE2XJG3SZmkWJ3GcxbusxZL+z3vtq8iyZF1ZV/aV9Z5Wj2zr3HPP+Z1zcz+99zvf50omk0moiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIlJSAS0JcSfmqcREQAREQAREQAREQAREQAREQAREQAREQAREwCEiI00IQgTEk8NBDD+Hyyy+Hz+dDOBwecuaRPrPSxWKPt3KOkerccccd+OQnP4nFixdjy5YtxTan40tEQPNUIrBqVgREQAREoOIIyK5z3pSPtz3sPCLqkQiIgBMJSIhz4qyoT7YT+NjHPob [...]
+       "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCZiT1fX/v8ns+zDMhgyLyOaKsoqyuAGCu2LVYkVFLdat1bb8q4IoLqVVa12x1V/FuiMq1YpAQRRQZFNEBUGQZZBhFph9yUyS//O9wztmMsnkTfImeSdz7vPMM0Duve+9n3tDTr733HMsTqfTCSlCQAgIASEgBISAEBACQkAICAEhIASEgBAQAkJACISUgEWEuJDylc6FgBAQAkJACAgBISAEhIAQEAJCQAgIASEgBISAIiBCnGwEIRBGAh999BEmTpyIhIQE1NfXt3pye6/pGWKw7fU8o7068+bNw80334wBAwZg27ZtwXYn7UNEQNYpRGClWyEgBISAEOh0BMSuM9+SR9oeNh8RGZEQEAJmJCBCnBlXRcZkOIEbb7wRL7zwArKysvD [...]
       ],
       "text/plain": [
        "<IPython.core.display.HTML object>"
@@ -3007,18 +2993,13 @@
       "1 rows affected.\n",
       "1 rows affected.\n",
       "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
       "1 rows affected.\n"
      ]
     }
    ],
    "source": [
     "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
-    "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 15;\n",
+    "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 10;\n",
     "df_results = df_results.DataFrame()\n",
     "\n",
     "#set up plots\n",
diff --git a/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb b/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
index 091e6fd..171c9cd 100644
--- a/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
+++ b/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
@@ -52,7 +52,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -74,7 +74,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
@@ -100,7 +100,7 @@
        "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
       ]
      },
-     "execution_count": 3,
+     "execution_count": 5,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -121,7 +121,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [
     {
@@ -180,7 +180,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -794,7 +794,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
@@ -821,7 +821,7 @@
        "[]"
       ]
      },
-     "execution_count": 17,
+     "execution_count": 22,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -896,7 +896,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -924,7 +924,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 32,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -953,12 +953,13 @@
     "        self.n_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
     "        self.r_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
     "        sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
+    "        \n",
+    "        print (\" \")\n",
+    "        print (\"Hyperband brackets\")\n",
     "\n",
     "        #### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
     "        for s in reversed(range(self.s_max+1)):\n",
-    "\n",
-    "            print (\" \")\n",
-    "            print (\"Hyperband brackets\")\n",
+    "            \n",
     "            print (\" \")\n",
     "            print (\"s=\" + str(s))\n",
     "            print (\"n_i      r_i\")\n",
@@ -1040,40 +1041,44 @@
     "            # filter out early stops, if any\n",
     "            \n",
     "            # loop on brackets s desc to prune model selection table\n",
-    "            print (\"loop on s desc to prune mst table:\")\n",
-    "            for s in range(self.s_max, self.s_max-i-1, -1):\n",
-    "                \n",
-    "                k = int( self.n_vals[s][i] / self.eta)\n",
+    "            # don't need to prune if finished last diagonal\n",
+    "            if i < self.s_max:\n",
+    "                print (\"loop on s desc to prune mst table:\")\n",
+    "                for s in range(self.s_max, self.s_max-i-1, -1):\n",
     "                \n",
-    "                # temporarily re-run table names again due to weird scope issues\n",
-    "                results_table = 'results_mnist'\n",
+    "                    # compute number of configs to keep\n",
+    "                    # remember i value is different for each bracket s on the diagonal\n",
+    "                    k = int( self.n_vals[s][s-self.s_max+i] / self.eta)\n",
+    "                    print (\"pruning s = {} with k = {}\".format(s, k))\n",
     "\n",
-    "                output_table = 'mnist_multi_model'\n",
-    "                output_table_info = '_'.join([output_table, 'info'])\n",
-    "                output_table_summary = '_'.join([output_table, 'summary'])\n",
+    "                    # temporarily re-define table names due to weird Python scope issues\n",
+    "                    results_table = 'results_mnist'\n",
     "\n",
-    "                mst_table = 'mst_table_hb_mnist'\n",
-    "                mst_table_summary = '_'.join([mst_table, 'summary'])\n",
+    "                    output_table = 'mnist_multi_model'\n",
+    "                    output_table_info = '_'.join([output_table, 'info'])\n",
+    "                    output_table_summary = '_'.join([output_table, 'summary'])\n",
     "\n",
-    "                mst_diag_table = 'mst_diag_table_hb_mnist'\n",
-    "                mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
+    "                    mst_table = 'mst_table_hb_mnist'\n",
+    "                    mst_table_summary = '_'.join([mst_table, 'summary'])\n",
     "\n",
-    "                model_arch_table = 'model_arch_library_mnist'\n",
-    "                \n",
-    "                query = \"\"\"\n",
-    "                DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
-    "                \"\"\".format(**locals())\n",
-    "                cur.execute(query)\n",
-    "                conn.commit()\n",
-    "                #%sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
-    "#                 %sql DELETE FROM $mst_table WHERE s={0} AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
-    "                #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
+    "                    mst_diag_table = 'mst_diag_table_hb_mnist'\n",
+    "                    mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
+    "\n",
+    "                    model_arch_table = 'model_arch_library_mnist'\n",
+    "            \n",
+    "                    query = \"\"\"\n",
+    "                    DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
+    "                    \"\"\".format(**locals())\n",
+    "                    cur.execute(query)\n",
+    "                    conn.commit()\n",
+    "                    #%sql DELETE FROM $mst_table WHERE s=$s AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
+    "                    #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
     "        return"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1129,7 +1134,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 26,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1163,7 +1168,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 27,
    "metadata": {
     "scrolled": false
    },
@@ -1182,16 +1187,12 @@
       "3.0     3.0\n",
       "1.0     9.0\n",
       " \n",
-      "Hyperband brackets\n",
-      " \n",
       "s=1\n",
       "n_i      r_i\n",
       "------------\n",
       "3     3.0\n",
       "1.0     9.0\n",
       " \n",
-      "Hyperband brackets\n",
-      " \n",
       "s=0\n",
       "n_i      r_i\n",
       "------------\n",
@@ -1248,6 +1249,7 @@
       "9 rows affected.\n",
       "9 rows affected.\n",
       "loop on s desc to prune mst table:\n",
+      "pruning s = 2 with k = 3\n",
       " \n",
       "i=1\n",
       "Done.\n",
@@ -1264,29 +1266,35 @@
       "6 rows affected.\n",
       "6 rows affected.\n",
       "loop on s desc to prune mst table:\n",
+      "pruning s = 2 with k = 1\n",
+      "pruning s = 1 with k = 1\n",
       " \n",
       "i=2\n",
       "Done.\n",
       "loop on s desc to create diagonal table:\n",
       "1 rows affected.\n",
-      "0 rows affected.\n",
+      "1 rows affected.\n",
       "3 rows affected.\n",
       "try params for i = 2\n",
       "Done.\n",
       "1 rows affected.\n",
       "Done.\n",
-      "4 rows affected.\n",
+      "5 rows affected.\n",
       "Done.\n",
-      "4 rows affected.\n",
-      "4 rows affected.\n",
-      "4 rows affected.\n",
-      "loop on s desc to prune mst table:\n"
+      "5 rows affected.\n",
+      "5 rows affected.\n",
+      "5 rows affected.\n",
+      "loop on s desc to prune mst table:\n",
+      "pruning s = 2 with k = 0\n",
+      "pruning s = 1 with k = 0\n",
+      "pruning s = 0 with k = 1\n"
      ]
     }
    ],
    "source": [
     "hp = Hyperband_diagonal(get_params, try_params )\n",
-    "results = hp.run()"
+    "results = hp.run()\n",
+    "#hp.n_vals[1]"
    ]
   },
   {
@@ -1299,7 +1307,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1323,14 +1331,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "15 rows affected.\n"
+      "12 rows affected.\n"
      ]
     },
     {
@@ -2116,7 +2124,7 @@
     {
      "data": {
       "text/html": [
-       "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3xb5b3+H8myJK/sOE7iDDJIwt5lUyhlFFpKGR1QVktv6O2ggz+9rBtuWrjpbbmFrlAohQuUUaCUtrQFWvZoGGUnhGw7y3GWtyRL+n+eVz62LEv2kXRkH9nPy0fIsd7znvf9vkfWT8/5DU88Ho9DTQREQAREQAREQAREQAREQAREQAREQAREQAREoKAEPBLiCspXg4uACIiACIiACIiACIiACIiACIiACIiACIiAISAhTheCCBQ5gbPPPhsPP/ww/v3f/x0/+9nPHF3NIYccgtdffx3/8z//g+9+97uOjq3BnCPw7rvvYt999zUDbtu2DRMmTHBucI0kAiIgAiIgAiIwaARk1/VF3dLSgqqqKvPCq6++CtqnaiIgAiJQzAQkxBXz7mn [...]
+       "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3hb5dn+b0m2vDOc2Fl2diBAWVnsXSBQRinQSQu00EJ3oS39GIE2FMr3tdD2T1tmW7poWWXPllF2BrthZMcZjp3YTjwlW9L/ul/52LItyUfS0bB9v9fly4n1nnf83mPr0X2e4QqFQiGoiYAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIpJWAS0JcWvlqcBEQAREQAREQAREQAREQAREQAREQAREQAREwBCTE6UYQgSFO4KyzzsL999+Pb3zjG7j55psd3c2CBQuwcuVK/N///R++//3vOzq2BnOOwHvvvYd9993XDFhfX4/x48c7N7hGEgEREAEREAERyBgB2XUDUbe0tKCsrMy8sHz5ctA+VRMBERCBoUxAQtxQPj2tPW0EXC5X0mP [...]
       ],
       "text/plain": [
        "<IPython.core.display.HTML object>"
@@ -2140,16 +2148,13 @@
       "1 rows affected.\n",
       "1 rows affected.\n",
       "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
       "1 rows affected.\n"
      ]
     }
    ],
    "source": [
     "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
-    "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 15;\n",
+    "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 12;\n",
     "df_results = df_results.DataFrame()\n",
     "\n",
     "#set up plots\n",
@@ -2192,14 +2197,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "15 rows affected.\n"
+      "12 rows affected.\n"
      ]
     },
     {
@@ -2985,7 +2990,7 @@
     {
      "data": {
       "text/html": [
-       "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCXhbV5n+X0m2Nttx7CRuFmdplibpknRnui8U2lKgLaQMAx06MCxh5j8z0MJ0oG0IEyiTKVNgYKDDNnRo2UppYaAb0L2lpE2XJG3SZmkWJ3GcxbusxZL+z3vtq8iyZF1ZV/aV9Z5Wj2zr3HPP+Z1zcz+99zvf50omk0moiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIlJSAS0JcSfmqcREQAREQAREQAREQAREQAREQAREQAREQAREwCEiI00IQgTEk8NBDD+Hyyy+Hz+dDOBwecuaRPrPSxWKPt3KOkerccccd+OQnP4nFixdjy5YtxTan40tEQPNUIrBqVgREQAREoOIIyK5z3pSPtz3sPCLqkQiIgBMJSIhz4qyoT7YT+NjHPob [...]
+       "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCXhU1fnG31myrwQStgQQVMAN2aKAgKBW3Dds3dFqFWtrq/6rbVWkxaV0cWmtxbVq1boUq3VHZUchgAiiICoCCRAC2deZJJP/857kJpPJLPfO3FmSfOd58iQk5557zu+cYb5577dYWlpaWiBNCAgBISAEhIAQEAJCQAgIASEgBISAEBACQkAICIGwErCIEBdWvjK4EBACQkAICAEhIASEgBAQAkJACAgBISAEhIAQUAREiJODIAQiSOD999/HGWecgYSEBDQ0NHS6s7+/6ZliqNfruYe/PosWLcKNN96IkSNHYvv27aEOJ9eHiYDsU5jAyrBCQAgIASHQ6wiIXRd7Wx5tezj2iMiMhIAQiEUCIsTF4q7InEwn8JOf/ARPPfUUsrKysG/ [...]
       ],
       "text/plain": [
        "<IPython.core.display.HTML object>"
@@ -3009,16 +3014,13 @@
       "1 rows affected.\n",
       "1 rows affected.\n",
       "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
       "1 rows affected.\n"
      ]
     }
    ],
    "source": [
     "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
-    "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 15;\n",
+    "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 12;\n",
     "df_results = df_results.DataFrame()\n",
     "\n",
     "#set up plots\n",