You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by fm...@apache.org on 2019/11/22 01:20:51 UTC

[madlib-site] branch automl updated: hyperband diagonal E2E still in work...

This is an automated email from the ASF dual-hosted git repository.

fmcquillan pushed a commit to branch automl
in repository https://gitbox.apache.org/repos/asf/madlib-site.git


The following commit(s) were added to refs/heads/automl by this push:
     new c606abc  hyperband diagonal E2E still in work...
c606abc is described below

commit c606abcf87684808eaa68fc47b700ae247a7f20c
Author: Frank McQuillan <fm...@pivotal.io>
AuthorDate: Thu Nov 21 17:20:43 2019 -0800

    hyperband diagonal E2E still in work...
---
 .../hyperband_diag_v2_mnist-checkpoint.ipynb       | 924 ++++++++++-----------
 .../automl/hyperband_diag_v2_mnist.ipynb           | 924 ++++++++++-----------
 2 files changed, 866 insertions(+), 982 deletions(-)

diff --git a/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb b/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
index 09598ea..091e6fd 100644
--- a/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
+++ b/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
@@ -23,7 +23,9 @@
     "\n",
     "<a href=\"#hyperband\">5. Hyperband diagonal</a>\n",
     "\n",
-    "<a href=\"#plot\">6. Plot results</a>"
+    "<a href=\"#plot\">6. Plot results</a>\n",
+    "\n",
+    "<a href=\"#print\">7. Print run schedules</a>"
    ]
   },
   {
@@ -792,7 +794,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [
     {
@@ -819,7 +821,7 @@
        "[]"
       ]
      },
-     "execution_count": 6,
+     "execution_count": 17,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -894,7 +896,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -917,344 +919,12 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Pretty print reg Hyperband run schedule"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "max_iter = 3\n",
-      "eta = 3\n",
-      "B = 2*max_iter = 6\n",
-      " \n",
-      "s=1\n",
-      "n_i      r_i\n",
-      "------------\n",
-      "3     1.0\n",
-      "1.0     3.0\n",
-      " \n",
-      "s=0\n",
-      "n_i      r_i\n",
-      "------------\n",
-      "2     3\n",
-      " \n",
-      "sum of configurations at leaf nodes across all s = 3.0\n",
-      "(if have more workers than this, they may not be 100% busy)\n"
-     ]
-    }
-   ],
-   "source": [
-    "import numpy as np\n",
-    "from math import log, ceil\n",
-    "\n",
-    "#input\n",
-    "max_iter = 3  # maximum iterations/epochs per configuration\n",
-    "eta = 3  # defines downsampling rate (default=3)\n",
-    "\n",
-    "logeta = lambda x: log(x)/log(eta)\n",
-    "s_max = int(logeta(max_iter))  # number of unique executions of Successive Halving (minus one)\n",
-    "B = (s_max+1)*max_iter  # total number of iterations (without reuse) per execution of Succesive Halving (n,r)\n",
-    "\n",
-    "#echo output\n",
-    "print (\"max_iter = \" + str(max_iter))\n",
-    "print (\"eta = \" + str(eta))\n",
-    "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
-    "\n",
-    "sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
-    "\n",
-    "#### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
-    "for s in reversed(range(s_max+1)):\n",
-    "    \n",
-    "    print (\" \")\n",
-    "    print (\"s=\" + str(s))\n",
-    "    print (\"n_i      r_i\")\n",
-    "    print (\"------------\")\n",
-    "    counter = 0\n",
-    "    \n",
-    "    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
-    "    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
-    "\n",
-    "    #### Begin Finite Horizon Successive Halving with (n,r)\n",
-    "    #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
-    "    for i in range(s+1):\n",
-    "        # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
-    "        n_i = n*eta**(-i)\n",
-    "        r_i = r*eta**(i)\n",
-    "        \n",
-    "        print (str(n_i) + \"     \" + str (r_i))\n",
-    "        \n",
-    "        # check if leaf node for this s\n",
-    "        if counter == s:\n",
-    "            sum_leaf_n_i += n_i\n",
-    "        counter += 1\n",
-    "        \n",
-    "        #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
-    "        #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
-    "    #### End Finite Horizon Successive Halving with (n,r)\n",
-    "\n",
-    "print (\" \")\n",
-    "print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
-    "print (\"(if have more workers than this, they may not be 100% busy)\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Pretty print Hyperband diagonal run schedule"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "echo input:\n",
-      "max_iter = 3\n",
-      "eta = 3\n",
-      "s_max = 1\n",
-      "B = 2*max_iter = 6\n",
-      " \n",
-      "initial n, r values for each s:\n",
-      "s=1\n",
-      "n=3\n",
-      "r=1.0\n",
-      " \n",
-      "s=0\n",
-      "n=2\n",
-      "r=3\n",
-      " \n",
-      "outer loop on diagonal:\n",
-      " \n",
-      "i=0\n",
-      "inner loop on s desc:\n",
-      "s=1\n",
-      "n_i=3\n",
-      "r_i=1.0\n",
-      " \n",
-      "i=1\n",
-      "inner loop on s desc:\n",
-      "s=1\n",
-      "n_i=1.0\n",
-      "r_i=3.0\n",
-      "s=0\n",
-      "n_i=2\n",
-      "r_i=3\n"
-     ]
-    }
-   ],
-   "source": [
-    "import numpy as np\n",
-    "from math import log, ceil\n",
-    "\n",
-    "#input\n",
-    "max_iter = 3  # maximum iterations/epochs per configuration\n",
-    "eta = 3  # defines downsampling rate (default=3)\n",
-    "\n",
-    "logeta = lambda x: log(x)/log(eta)\n",
-    "s_max = int(logeta(max_iter))  # number of unique executions of Successive Halving (minus one)\n",
-    "B = (s_max+1)*max_iter  # total number of iterations (without reuse) per execution of Succesive Halving (n,r)\n",
-    "\n",
-    "#echo output\n",
-    "print (\"echo input:\")\n",
-    "print (\"max_iter = \" + str(max_iter))\n",
-    "print (\"eta = \" + str(eta))\n",
-    "print (\"s_max = \" + str(s_max))\n",
-    "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
-    "\n",
-    "print (\" \")\n",
-    "print (\"initial n, r values for each s:\")\n",
-    "initial_n_vals = {}\n",
-    "initial_r_vals = {}\n",
-    "# get hyper parameter configs for each s\n",
-    "for s in reversed(range(s_max+1)):\n",
-    "    \n",
-    "    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
-    "    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
-    "    \n",
-    "    initial_n_vals[s] = n \n",
-    "    initial_r_vals[s] = r \n",
-    "    \n",
-    "    print (\"s=\" + str(s))\n",
-    "    print (\"n=\" + str(n))\n",
-    "    print (\"r=\" + str(r))\n",
-    "    print (\" \")\n",
-    "    \n",
-    "print (\"outer loop on diagonal:\")\n",
-    "# outer loop on diagonal\n",
-    "for i in range(s_max+1):\n",
-    "    print (\" \")\n",
-    "    print (\"i=\" + str(i))\n",
-    "    \n",
-    "    print (\"inner loop on s desc:\")\n",
-    "    # inner loop on s desc\n",
-    "    for s in range(s_max, s_max-i-1, -1):\n",
-    "        n_i = initial_n_vals[s]*eta**(-i+s_max-s)\n",
-    "        r_i = initial_r_vals[s]*eta**(i-s_max+s)\n",
-    "        \n",
-    "        print (\"s=\" + str(s))\n",
-    "        print (\"n_i=\" + str(n_i))\n",
-    "        print (\"r_i=\" + str(r_i))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Compute and store run schedule"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 45,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "max_iter = 3\n",
-      "eta = 3\n",
-      "B = 2*max_iter = 6\n",
-      " \n",
-      "s=1\n",
-      "n_i      r_i\n",
-      "------------\n",
-      "3     1.0\n",
-      "1.0     3.0\n",
-      " \n",
-      "s=0\n",
-      "n_i      r_i\n",
-      "------------\n",
-      "2     3\n",
-      " \n",
-      "sum of configurations at leaf nodes across all s = 3.0\n",
-      "(if have more workers than this, they may not be 100% busy)\n"
-     ]
-    }
-   ],
-   "source": [
-    "import numpy as np\n",
-    "from math import log, ceil\n",
-    "\n",
-    "#input\n",
-    "max_iter = 3  # maximum iterations/epochs per configuration\n",
-    "eta = 3  # defines downsampling rate (default=3)\n",
-    "\n",
-    "logeta = lambda x: log(x)/log(eta)\n",
-    "s_max = int(logeta(max_iter))  # number of unique executions of Successive Halving (minus one)\n",
-    "B = (s_max+1)*max_iter  # total number of iterations (without reuse) per execution of Succesive Halving (n,r)\n",
-    "\n",
-    "#echo output\n",
-    "print (\"max_iter = \" + str(max_iter))\n",
-    "print (\"eta = \" + str(eta))\n",
-    "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
-    "\n",
-    "sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
-    "\n",
-    "n_vals = np.zeros((s_max+1, s_max+1), dtype=int)\n",
-    "r_vals = np.zeros((s_max+1, s_max+1), dtype=int)\n",
-    "\n",
-    "#### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
-    "for s in reversed(range(s_max+1)):\n",
-    "    \n",
-    "    print (\" \")\n",
-    "    print (\"s=\" + str(s))\n",
-    "    print (\"n_i      r_i\")\n",
-    "    print (\"------------\")\n",
-    "    counter = 0\n",
-    "    \n",
-    "    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
-    "    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
-    "\n",
-    "    #### Begin Finite Horizon Successive Halving with (n,r)\n",
-    "    #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
-    "    for i in range(s+1):\n",
-    "        # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
-    "        n_i = n*eta**(-i)\n",
-    "        r_i = r*eta**(i)\n",
-    "        \n",
-    "        n_vals[s][i] = n_i\n",
-    "        r_vals[s][i] = r_i\n",
-    "        \n",
-    "        print (str(n_i) + \"     \" + str (r_i))\n",
-    "        \n",
-    "        # check if leaf node for this s\n",
-    "        if counter == s:\n",
-    "            sum_leaf_n_i += n_i\n",
-    "        counter += 1\n",
-    "        \n",
-    "        #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
-    "        #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
-    "    #### End Finite Horizon Successive Halving with (n,r)\n",
-    "\n",
-    "print (\" \")\n",
-    "print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
-    "print (\"(if have more workers than this, they may not be 100% busy)\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 46,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "array([[2, 0],\n",
-       "       [3, 1]])"
-      ]
-     },
-     "execution_count": 46,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "n_vals "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 47,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "array([[3, 0],\n",
-       "       [1, 3]])"
-      ]
-     },
-     "execution_count": 47,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "r_vals"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "Hyperband diagonal"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1269,7 +939,7 @@
     "        self.get_params = get_params_function\n",
     "        self.try_params = try_params_function\n",
     "\n",
-    "        self.max_iter = 3  # maximum iterations per configuration\n",
+    "        self.max_iter = 9  # maximum iterations per configuration\n",
     "        self.eta = 3        # defines configuration downsampling rate (default = 3)\n",
     "\n",
     "        self.logeta = lambda x: log( x ) / log( self.eta )\n",
@@ -1288,6 +958,8 @@
     "        for s in reversed(range(self.s_max+1)):\n",
     "\n",
     "            print (\" \")\n",
+    "            print (\"Hyperband brackets\")\n",
+    "            print (\" \")\n",
     "            print (\"s=\" + str(s))\n",
     "            print (\"n_i      r_i\")\n",
     "            print (\"------------\")\n",
@@ -1315,9 +987,12 @@
     "\n",
     "                #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
     "                #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
-    "            #### End Finite Horizon Successive Halving with (n,r)\n",
+    "                #### End Finite Horizon Successive Halving with (n,r)\n",
     "\n",
-    "            \n",
+    "            #print (\" \")\n",
+    "            #print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
+    "            #print (\"(if have more workers than this, they may not be 100% busy)\")\n",
+    "        \n",
     "    # generate model selection tuples for all brackets\n",
     "    def create_mst_superset(self):\n",
     "        # get hyper parameter configs for each bracket s\n",
@@ -1325,6 +1000,10 @@
     "            n = int(ceil(int(self.B/self.max_iter/(s+1))*self.eta**s)) # initial number of configurations\n",
     "            r = self.max_iter*self.eta**(-s) # initial number of iterations to run configurations for\n",
     "\n",
+    "            \n",
+    "            print (\" \")\n",
+    "            print (\"Create superset of MSTs, i.e., i=0 for for each bracket s\")\n",
+    "            print (\" \")\n",
     "            print (\"s=\" + str(s))\n",
     "            print (\"n=\" + str(n))\n",
     "            print (\"r=\" + str(r))\n",
@@ -1364,15 +1043,37 @@
     "            print (\"loop on s desc to prune mst table:\")\n",
     "            for s in range(self.s_max, self.s_max-i-1, -1):\n",
     "                \n",
-    "                k = int( self.n_vals[s][i] / self.eta)                \n",
-    "                %sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
-    "            \n",
+    "                k = int( self.n_vals[s][i] / self.eta)\n",
+    "                \n",
+    "                # temporarily re-run table names again due to weird scope issues\n",
+    "                results_table = 'results_mnist'\n",
+    "\n",
+    "                output_table = 'mnist_multi_model'\n",
+    "                output_table_info = '_'.join([output_table, 'info'])\n",
+    "                output_table_summary = '_'.join([output_table, 'summary'])\n",
+    "\n",
+    "                mst_table = 'mst_table_hb_mnist'\n",
+    "                mst_table_summary = '_'.join([mst_table, 'summary'])\n",
+    "\n",
+    "                mst_diag_table = 'mst_diag_table_hb_mnist'\n",
+    "                mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
+    "\n",
+    "                model_arch_table = 'model_arch_library_mnist'\n",
+    "                \n",
+    "                query = \"\"\"\n",
+    "                DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
+    "                \"\"\".format(**locals())\n",
+    "                cur.execute(query)\n",
+    "                conn.commit()\n",
+    "                #%sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
+    "#                 %sql DELETE FROM $mst_table WHERE s={0} AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
+    "                #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
     "        return"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1428,7 +1129,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 23,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1462,159 +1163,130 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
-   "metadata": {},
+   "execution_count": 24,
+   "metadata": {
+    "scrolled": false
+   },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       " \n",
+      "Hyperband brackets\n",
+      " \n",
+      "s=2\n",
+      "n_i      r_i\n",
+      "------------\n",
+      "9     1.0\n",
+      "3.0     3.0\n",
+      "1.0     9.0\n",
+      " \n",
+      "Hyperband brackets\n",
+      " \n",
       "s=1\n",
       "n_i      r_i\n",
       "------------\n",
-      "3     1.0\n",
-      "1.0     3.0\n",
+      "3     3.0\n",
+      "1.0     9.0\n",
+      " \n",
+      "Hyperband brackets\n",
       " \n",
       "s=0\n",
       "n_i      r_i\n",
       "------------\n",
-      "2     3\n",
+      "3     9\n",
+      " \n",
+      "Create superset of MSTs, i.e., i=0 for for each bracket s\n",
+      " \n",
+      "s=2\n",
+      "n=9\n",
+      "r=1.0\n",
+      " \n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      " \n",
+      "Create superset of MSTs, i.e., i=0 for for each bracket s\n",
+      " \n",
       "s=1\n",
       "n=3\n",
-      "r=1.0\n",
+      "r=3.0\n",
       " \n",
       "1 rows affected.\n",
       "1 rows affected.\n",
       "1 rows affected.\n",
+      " \n",
+      "Create superset of MSTs, i.e., i=0 for for each bracket s\n",
+      " \n",
       "s=0\n",
-      "n=2\n",
-      "r=3\n",
+      "n=3\n",
+      "r=9\n",
       " \n",
       "1 rows affected.\n",
       "1 rows affected.\n",
+      "1 rows affected.\n",
       "outer loop on diagonal:\n",
       " \n",
       "i=0\n",
       "Done.\n",
       "loop on s desc to create diagonal table:\n",
-      "3 rows affected.\n",
+      "9 rows affected.\n",
       "try params for i = 0\n",
       "Done.\n",
       "1 rows affected.\n",
       "Done.\n",
-      "3 rows affected.\n",
+      "9 rows affected.\n",
       "Done.\n",
-      "3 rows affected.\n",
-      "3 rows affected.\n",
-      "3 rows affected.\n",
+      "9 rows affected.\n",
+      "9 rows affected.\n",
+      "9 rows affected.\n",
       "loop on s desc to prune mst table:\n",
-      "4 rows affected.\n",
       " \n",
       "i=1\n",
       "Done.\n",
       "loop on s desc to create diagonal table:\n",
-      "1 rows affected.\n",
-      "0 rows affected.\n",
+      "3 rows affected.\n",
+      "3 rows affected.\n",
       "try params for i = 1\n",
       "Done.\n",
       "1 rows affected.\n",
       "Done.\n",
-      "1 rows affected.\n",
+      "6 rows affected.\n",
       "Done.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
+      "6 rows affected.\n",
+      "6 rows affected.\n",
+      "6 rows affected.\n",
       "loop on s desc to prune mst table:\n",
+      " \n",
+      "i=2\n",
+      "Done.\n",
+      "loop on s desc to create diagonal table:\n",
       "1 rows affected.\n",
-      "0 rows affected.\n"
-     ]
-    }
-   ],
-   "source": [
-    "hp = Hyperband_diagonal(get_params, try_params )\n",
-    "results = hp.run()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 15,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "array([[3, 0],\n",
-       "       [1, 3]])"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "self.r_vals"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "0"
-      ]
-     },
-     "execution_count": 16,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "i"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 17,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "1"
-      ]
-     },
-     "execution_count": 17,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "self.s_max"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 18,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "1"
-      ]
-     },
-     "execution_count": 18,
-     "metadata": {},
-     "output_type": "execute_result"
+      "0 rows affected.\n",
+      "3 rows affected.\n",
+      "try params for i = 2\n",
+      "Done.\n",
+      "1 rows affected.\n",
+      "Done.\n",
+      "4 rows affected.\n",
+      "Done.\n",
+      "4 rows affected.\n",
+      "4 rows affected.\n",
+      "4 rows affected.\n",
+      "loop on s desc to prune mst table:\n"
+     ]
     }
    ],
    "source": [
-    "self.r_vals[self.s_max][i]"
+    "hp = Hyperband_diagonal(get_params, try_params )\n",
+    "results = hp.run()"
    ]
   },
   {
@@ -1627,12 +1299,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [],
    "source": [
     "%matplotlib notebook\n",
     "import matplotlib.pyplot as plt\n",
+    "from matplotlib.ticker import MaxNLocator\n",
     "from collections import defaultdict\n",
     "import pandas as pd\n",
     "import seaborn as sns\n",
@@ -1642,15 +1315,22 @@
    ]
   },
   {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Training dataset"
+   ]
+  },
+  {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 30,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "4 rows affected.\n"
+      "15 rows affected.\n"
      ]
     },
     {
@@ -2436,7 +2116,7 @@
     {
      "data": {
       "text/html": [
-       "<img src=\" [...]
+       "<img src=\" [...]
       ],
       "text/plain": [
        "<IPython.core.display.HTML object>"
@@ -2452,16 +2132,44 @@
       "1 rows affected.\n",
       "1 rows affected.\n",
       "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
       "1 rows affected.\n"
      ]
     }
    ],
    "source": [
     "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
-    "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 7;\n",
+    "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 15;\n",
     "df_results = df_results.DataFrame()\n",
     "\n",
+    "#set up plots\n",
     "fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
+    "fig.legend(ncol=4)\n",
+    "fig.tight_layout()\n",
+    "\n",
+    "ax_metric = axs[0]\n",
+    "ax_loss = axs[1]\n",
+    "\n",
+    "ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
+    "ax_metric.set_xlabel('Iteration')\n",
+    "ax_metric.set_ylabel('Metric')\n",
+    "ax_metric.set_title('Training metric curve')\n",
+    "\n",
+    "ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
+    "ax_loss.set_xlabel('Iteration')\n",
+    "ax_loss.set_ylabel('Loss')\n",
+    "ax_loss.set_title('Training loss curve')\n",
+    "\n",
     "for run_id in df_results['run_id']:\n",
     "    df_output_info = %sql SELECT training_metrics,training_loss FROM $results_table WHERE run_id = $run_id\n",
     "    df_output_info = df_output_info.DataFrame()\n",
@@ -2469,35 +2177,29 @@
     "    training_loss = df_output_info['training_loss'][0]\n",
     "    X = range(len(training_metrics))\n",
     "    \n",
-    "    ax_metric = axs[0]\n",
-    "    ax_loss = axs[1]\n",
-    "    ax_metric.set_xticks(X[::1])\n",
-    "    ax_metric.plot(X, training_metrics, label=run_id)\n",
-    "    ax_metric.set_xlabel('Iteration')\n",
-    "    ax_metric.set_ylabel('Metric')\n",
-    "    ax_metric.set_title('Training metric curve')\n",
-    "\n",
-    "    ax_loss.set_xticks(X[::1])\n",
-    "    ax_loss.plot(X, training_loss, label=run_id)\n",
-    "    ax_loss.set_xlabel('Iteration')\n",
-    "    ax_loss.set_ylabel('Loss')\n",
-    "    ax_loss.set_title('Training loss curve')\n",
-    "    \n",
-    "fig.legend(ncol=4)\n",
-    "fig.tight_layout()\n",
+    "    ax_metric.plot(X, training_metrics, label=run_id, marker='o')\n",
+    "    ax_loss.plot(X, training_loss, label=run_id, marker='o')\n",
+    "\n",
     "# fig.savefig('./lc_keras_fit.png', dpi = 300)"
    ]
   },
   {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Validation dataset"
+   ]
+  },
+  {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 31,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "4 rows affected.\n"
+      "15 rows affected.\n"
      ]
     },
     {
@@ -3283,7 +2985,7 @@
     {
      "data": {
       "text/html": [
-       "<img src=\" [...]
+       "<img src=\" [...]
       ],
       "text/plain": [
        "<IPython.core.display.HTML object>"
@@ -3299,16 +3001,44 @@
       "1 rows affected.\n",
       "1 rows affected.\n",
       "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
       "1 rows affected.\n"
      ]
     }
    ],
    "source": [
     "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
-    "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 5;\n",
+    "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 15;\n",
     "df_results = df_results.DataFrame()\n",
     "\n",
+    "#set up plots\n",
     "fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
+    "# fig.legend() and fig.tight_layout() are called after the plotting loop:\n",
+    "# a legend created before any labelled line exists is empty\n",
+    "\n",
+    "ax_metric = axs[0]\n",
+    "ax_loss = axs[1]\n",
+    "\n",
+    "ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
+    "ax_metric.set_xlabel('Iteration')\n",
+    "ax_metric.set_ylabel('Metric')\n",
+    "ax_metric.set_title('Validation metric curve')\n",
+    "\n",
+    "ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
+    "ax_loss.set_xlabel('Iteration')\n",
+    "ax_loss.set_ylabel('Loss')\n",
+    "ax_loss.set_title('Validation loss curve')\n",
+    "\n",
     "for run_id in df_results['run_id']:\n",
     "    df_output_info = %sql SELECT validation_metrics,validation_loss FROM $results_table WHERE run_id = $run_id\n",
     "    df_output_info = df_output_info.DataFrame()\n",
@@ -3316,24 +3046,236 @@
     "    validation_loss = df_output_info['validation_loss'][0]\n",
     "    X = range(len(validation_metrics))\n",
     "    \n",
-    "    ax_metric = axs[0]\n",
-    "    ax_loss = axs[1]\n",
-    "    ax_metric.set_xticks(X[::1])\n",
-    "    ax_metric.plot(X, validation_metrics, label=run_id)\n",
-    "    ax_metric.set_xlabel('Iteration')\n",
-    "    ax_metric.set_ylabel('Metric')\n",
-    "    ax_metric.set_title('Validation metric curve')\n",
-    "\n",
-    "    ax_loss.set_xticks(X[::1])\n",
-    "    ax_loss.plot(X, validation_loss, label=run_id)\n",
-    "    ax_loss.set_xlabel('Iteration')\n",
-    "    ax_loss.set_ylabel('Loss')\n",
-    "    ax_loss.set_title('Validation loss curve')\n",
-    "    \n",
-    "fig.legend(ncol=4)\n",
-    "fig.tight_layout()\n",
+    "    ax_metric.plot(X, validation_metrics, label=run_id, marker='o')\n",
+    "    ax_loss.plot(X, validation_loss, label=run_id, marker='o')\n",
+    "fig.legend(ncol=4)\n", "fig.tight_layout()\n",
     "# fig.savefig('./lc_keras_fit.png', dpi = 300)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<a id=\"print\"></a>\n",
+    "# 7. Print run schedules"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Pretty print the regular (non-diagonal) Hyperband run schedule"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "max_iter = 9\n",
+      "eta = 3\n",
+      "B = 3*max_iter = 27\n",
+      " \n",
+      "s=2\n",
+      "n_i      r_i\n",
+      "------------\n",
+      "9     1.0\n",
+      "3.0     3.0\n",
+      "1.0     9.0\n",
+      " \n",
+      "s=1\n",
+      "n_i      r_i\n",
+      "------------\n",
+      "3     3.0\n",
+      "1.0     9.0\n",
+      " \n",
+      "s=0\n",
+      "n_i      r_i\n",
+      "------------\n",
+      "3     9\n",
+      " \n",
+      "sum of configurations at leaf nodes across all s = 5.0\n",
+      "(if have more workers than this, they may not be 100% busy)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Print the standard (per-bracket) Hyperband schedule: n_i configs trained for r_i iterations at each rung\n",
+    "from math import log, ceil\n",
+    "\n",
+    "#input\n",
+    "max_iter = 9  # maximum iterations/epochs per configuration\n",
+    "eta = 3  # defines downsampling rate (default=3)\n",
+    "\n",
+    "logeta = lambda x: log(x)/log(eta)\n",
+    "s_max = int(logeta(max_iter) + 1e-9)  # number of unique executions of Successive Halving (minus one); epsilon absorbs float error, e.g. log(243)/log(3) = 4.999... (apply the same guard wherever s_max is derived)\n",
+    "B = (s_max+1)*max_iter  # total number of iterations (without reuse) per execution of Successive Halving (n,r)\n",
+    "\n",
+    "#echo output\n",
+    "print (\"max_iter = \" + str(max_iter))\n",
+    "print (\"eta = \" + str(eta))\n",
+    "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
+    "\n",
+    "sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
+    "\n",
+    "#### Begin Finite Horizon Hyperband outer loop. Repeat indefinitely.\n",
+    "for s in reversed(range(s_max+1)):\n",
+    "    \n",
+    "    print (\" \")\n",
+    "    print (\"s=\" + str(s))\n",
+    "    print (\"n_i      r_i\")\n",
+    "    print (\"------------\")\n",
+    "    \n",
+    "    # NOTE(review): int() truncates B/max_iter/(s+1) before scaling by eta**s; the Hyperband paper takes ceil of the whole product - confirm intended\n",
+    "    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
+    "    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
+    "\n",
+    "    #### Begin Finite Horizon Successive Halving with (n,r)\n",
+    "    #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
+    "    for i in range(s+1):\n",
+    "        # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
+    "        n_i = n*eta**(-i)\n",
+    "        r_i = r*eta**(i)\n",
+    "        \n",
+    "        print (str(n_i) + \"     \" + str (r_i))\n",
+    "        \n",
+    "        # check if leaf node for this s: the last rung (i == s),\n",
+    "        # where r_i = r*eta**s = max_iter\n",
+    "        if i == s:\n",
+    "            sum_leaf_n_i += n_i\n",
+    "        \n",
+    "        #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
+    "        #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
+    "    #### End Finite Horizon Successive Halving with (n,r)\n",
+    "\n",
+    "print (\" \")\n",
+    "print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
+    "print (\"(if have more workers than this, they may not be 100% busy)\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Pretty print Hyperband diagonal run schedule"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "echo input:\n",
+      "max_iter = 9\n",
+      "eta = 3\n",
+      "s_max = 2\n",
+      "B = 3*max_iter = 27\n",
+      " \n",
+      "initial n, r values for each s:\n",
+      "s=2\n",
+      "n=9\n",
+      "r=1.0\n",
+      " \n",
+      "s=1\n",
+      "n=3\n",
+      "r=3.0\n",
+      " \n",
+      "s=0\n",
+      "n=3\n",
+      "r=9\n",
+      " \n",
+      "outer loop on diagonal:\n",
+      " \n",
+      "i=0\n",
+      "inner loop on s desc:\n",
+      "s=2\n",
+      "n_i=9\n",
+      "r_i=1.0\n",
+      " \n",
+      "i=1\n",
+      "inner loop on s desc:\n",
+      "s=2\n",
+      "n_i=3.0\n",
+      "r_i=3.0\n",
+      "s=1\n",
+      "n_i=3\n",
+      "r_i=3.0\n",
+      " \n",
+      "i=2\n",
+      "inner loop on s desc:\n",
+      "s=2\n",
+      "n_i=1.0\n",
+      "r_i=9.0\n",
+      "s=1\n",
+      "n_i=1.0\n",
+      "r_i=9.0\n",
+      "s=0\n",
+      "n_i=3\n",
+      "r_i=9\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Print the Hyperband diagonal schedule: at diagonal step i, each bracket s >= s_max-i runs its rung i-(s_max-s)\n",
+    "from math import log, ceil\n",
+    "\n",
+    "#input\n",
+    "max_iter = 9  # maximum iterations/epochs per configuration\n",
+    "eta = 3  # defines downsampling rate (default=3)\n",
+    "\n",
+    "logeta = lambda x: log(x)/log(eta)\n",
+    "s_max = int(logeta(max_iter) + 1e-9)  # number of unique executions of Successive Halving (minus one); epsilon absorbs float error, e.g. log(243)/log(3) = 4.999... (apply the same guard wherever s_max is derived)\n",
+    "B = (s_max+1)*max_iter  # total number of iterations (without reuse) per execution of Successive Halving (n,r)\n",
+    "\n",
+    "#echo output\n",
+    "print (\"echo input:\")\n",
+    "print (\"max_iter = \" + str(max_iter))\n",
+    "print (\"eta = \" + str(eta))\n",
+    "print (\"s_max = \" + str(s_max))\n",
+    "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
+    "\n",
+    "print (\" \")\n",
+    "print (\"initial n, r values for each s:\")\n",
+    "initial_n_vals = {}\n",
+    "initial_r_vals = {}\n",
+    "# initial (n, r) for each bracket s\n",
+    "for s in reversed(range(s_max+1)):\n",
+    "    # NOTE(review): int() truncates B/max_iter/(s+1) before scaling by eta**s; the Hyperband paper takes ceil of the whole product - confirm intended\n",
+    "    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
+    "    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
+    "    \n",
+    "    initial_n_vals[s] = n \n",
+    "    initial_r_vals[s] = r \n",
+    "    \n",
+    "    print (\"s=\" + str(s))\n",
+    "    print (\"n=\" + str(n))\n",
+    "    print (\"r=\" + str(r))\n",
+    "    print (\" \")\n",
+    "    \n",
+    "print (\"outer loop on diagonal:\")\n",
+    "# outer loop on diagonal\n",
+    "for i in range(s_max+1):\n",
+    "    print (\" \")\n",
+    "    print (\"i=\" + str(i))\n",
+    "    \n",
+    "    print (\"inner loop on s desc:\")\n",
+    "    # inner loop over brackets s = s_max down to s_max-i; bracket s is at rung i-(s_max-s)\n",
+    "    for s in range(s_max, s_max-i-1, -1):\n",
+    "        n_i = initial_n_vals[s]*eta**(-i+s_max-s)\n",
+    "        r_i = initial_r_vals[s]*eta**(i-s_max+s)\n",
+    "        \n",
+    "        print (\"s=\" + str(s))\n",
+    "        print (\"n_i=\" + str(n_i))\n",
+    "        print (\"r_i=\" + str(r_i))"
+   ]
   }
  ],
  "metadata": {
diff --git a/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb b/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
index 09598ea..091e6fd 100644
--- a/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
+++ b/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
@@ -23,7 +23,9 @@
     "\n",
     "<a href=\"#hyperband\">5. Hyperband diagonal</a>\n",
     "\n",
-    "<a href=\"#plot\">6. Plot results</a>"
+    "<a href=\"#plot\">6. Plot results</a>\n",
+    "\n",
+    "<a href=\"#print\">7. Print run schedules</a>"
    ]
   },
   {
@@ -792,7 +794,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [
     {
@@ -819,7 +821,7 @@
        "[]"
       ]
      },
-     "execution_count": 6,
+     "execution_count": 17,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -894,7 +896,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 18,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -917,344 +919,12 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Pretty print reg Hyperband run schedule"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "max_iter = 3\n",
-      "eta = 3\n",
-      "B = 2*max_iter = 6\n",
-      " \n",
-      "s=1\n",
-      "n_i      r_i\n",
-      "------------\n",
-      "3     1.0\n",
-      "1.0     3.0\n",
-      " \n",
-      "s=0\n",
-      "n_i      r_i\n",
-      "------------\n",
-      "2     3\n",
-      " \n",
-      "sum of configurations at leaf nodes across all s = 3.0\n",
-      "(if have more workers than this, they may not be 100% busy)\n"
-     ]
-    }
-   ],
-   "source": [
-    "import numpy as np\n",
-    "from math import log, ceil\n",
-    "\n",
-    "#input\n",
-    "max_iter = 3  # maximum iterations/epochs per configuration\n",
-    "eta = 3  # defines downsampling rate (default=3)\n",
-    "\n",
-    "logeta = lambda x: log(x)/log(eta)\n",
-    "s_max = int(logeta(max_iter))  # number of unique executions of Successive Halving (minus one)\n",
-    "B = (s_max+1)*max_iter  # total number of iterations (without reuse) per execution of Succesive Halving (n,r)\n",
-    "\n",
-    "#echo output\n",
-    "print (\"max_iter = \" + str(max_iter))\n",
-    "print (\"eta = \" + str(eta))\n",
-    "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
-    "\n",
-    "sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
-    "\n",
-    "#### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
-    "for s in reversed(range(s_max+1)):\n",
-    "    \n",
-    "    print (\" \")\n",
-    "    print (\"s=\" + str(s))\n",
-    "    print (\"n_i      r_i\")\n",
-    "    print (\"------------\")\n",
-    "    counter = 0\n",
-    "    \n",
-    "    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
-    "    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
-    "\n",
-    "    #### Begin Finite Horizon Successive Halving with (n,r)\n",
-    "    #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
-    "    for i in range(s+1):\n",
-    "        # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
-    "        n_i = n*eta**(-i)\n",
-    "        r_i = r*eta**(i)\n",
-    "        \n",
-    "        print (str(n_i) + \"     \" + str (r_i))\n",
-    "        \n",
-    "        # check if leaf node for this s\n",
-    "        if counter == s:\n",
-    "            sum_leaf_n_i += n_i\n",
-    "        counter += 1\n",
-    "        \n",
-    "        #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
-    "        #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
-    "    #### End Finite Horizon Successive Halving with (n,r)\n",
-    "\n",
-    "print (\" \")\n",
-    "print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
-    "print (\"(if have more workers than this, they may not be 100% busy)\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Pretty print Hyperband diagonal run schedule"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "echo input:\n",
-      "max_iter = 3\n",
-      "eta = 3\n",
-      "s_max = 1\n",
-      "B = 2*max_iter = 6\n",
-      " \n",
-      "initial n, r values for each s:\n",
-      "s=1\n",
-      "n=3\n",
-      "r=1.0\n",
-      " \n",
-      "s=0\n",
-      "n=2\n",
-      "r=3\n",
-      " \n",
-      "outer loop on diagonal:\n",
-      " \n",
-      "i=0\n",
-      "inner loop on s desc:\n",
-      "s=1\n",
-      "n_i=3\n",
-      "r_i=1.0\n",
-      " \n",
-      "i=1\n",
-      "inner loop on s desc:\n",
-      "s=1\n",
-      "n_i=1.0\n",
-      "r_i=3.0\n",
-      "s=0\n",
-      "n_i=2\n",
-      "r_i=3\n"
-     ]
-    }
-   ],
-   "source": [
-    "import numpy as np\n",
-    "from math import log, ceil\n",
-    "\n",
-    "#input\n",
-    "max_iter = 3  # maximum iterations/epochs per configuration\n",
-    "eta = 3  # defines downsampling rate (default=3)\n",
-    "\n",
-    "logeta = lambda x: log(x)/log(eta)\n",
-    "s_max = int(logeta(max_iter))  # number of unique executions of Successive Halving (minus one)\n",
-    "B = (s_max+1)*max_iter  # total number of iterations (without reuse) per execution of Succesive Halving (n,r)\n",
-    "\n",
-    "#echo output\n",
-    "print (\"echo input:\")\n",
-    "print (\"max_iter = \" + str(max_iter))\n",
-    "print (\"eta = \" + str(eta))\n",
-    "print (\"s_max = \" + str(s_max))\n",
-    "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
-    "\n",
-    "print (\" \")\n",
-    "print (\"initial n, r values for each s:\")\n",
-    "initial_n_vals = {}\n",
-    "initial_r_vals = {}\n",
-    "# get hyper parameter configs for each s\n",
-    "for s in reversed(range(s_max+1)):\n",
-    "    \n",
-    "    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
-    "    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
-    "    \n",
-    "    initial_n_vals[s] = n \n",
-    "    initial_r_vals[s] = r \n",
-    "    \n",
-    "    print (\"s=\" + str(s))\n",
-    "    print (\"n=\" + str(n))\n",
-    "    print (\"r=\" + str(r))\n",
-    "    print (\" \")\n",
-    "    \n",
-    "print (\"outer loop on diagonal:\")\n",
-    "# outer loop on diagonal\n",
-    "for i in range(s_max+1):\n",
-    "    print (\" \")\n",
-    "    print (\"i=\" + str(i))\n",
-    "    \n",
-    "    print (\"inner loop on s desc:\")\n",
-    "    # inner loop on s desc\n",
-    "    for s in range(s_max, s_max-i-1, -1):\n",
-    "        n_i = initial_n_vals[s]*eta**(-i+s_max-s)\n",
-    "        r_i = initial_r_vals[s]*eta**(i-s_max+s)\n",
-    "        \n",
-    "        print (\"s=\" + str(s))\n",
-    "        print (\"n_i=\" + str(n_i))\n",
-    "        print (\"r_i=\" + str(r_i))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Compute and store run schedule"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 45,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "max_iter = 3\n",
-      "eta = 3\n",
-      "B = 2*max_iter = 6\n",
-      " \n",
-      "s=1\n",
-      "n_i      r_i\n",
-      "------------\n",
-      "3     1.0\n",
-      "1.0     3.0\n",
-      " \n",
-      "s=0\n",
-      "n_i      r_i\n",
-      "------------\n",
-      "2     3\n",
-      " \n",
-      "sum of configurations at leaf nodes across all s = 3.0\n",
-      "(if have more workers than this, they may not be 100% busy)\n"
-     ]
-    }
-   ],
-   "source": [
-    "import numpy as np\n",
-    "from math import log, ceil\n",
-    "\n",
-    "#input\n",
-    "max_iter = 3  # maximum iterations/epochs per configuration\n",
-    "eta = 3  # defines downsampling rate (default=3)\n",
-    "\n",
-    "logeta = lambda x: log(x)/log(eta)\n",
-    "s_max = int(logeta(max_iter))  # number of unique executions of Successive Halving (minus one)\n",
-    "B = (s_max+1)*max_iter  # total number of iterations (without reuse) per execution of Succesive Halving (n,r)\n",
-    "\n",
-    "#echo output\n",
-    "print (\"max_iter = \" + str(max_iter))\n",
-    "print (\"eta = \" + str(eta))\n",
-    "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
-    "\n",
-    "sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
-    "\n",
-    "n_vals = np.zeros((s_max+1, s_max+1), dtype=int)\n",
-    "r_vals = np.zeros((s_max+1, s_max+1), dtype=int)\n",
-    "\n",
-    "#### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
-    "for s in reversed(range(s_max+1)):\n",
-    "    \n",
-    "    print (\" \")\n",
-    "    print (\"s=\" + str(s))\n",
-    "    print (\"n_i      r_i\")\n",
-    "    print (\"------------\")\n",
-    "    counter = 0\n",
-    "    \n",
-    "    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
-    "    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
-    "\n",
-    "    #### Begin Finite Horizon Successive Halving with (n,r)\n",
-    "    #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
-    "    for i in range(s+1):\n",
-    "        # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
-    "        n_i = n*eta**(-i)\n",
-    "        r_i = r*eta**(i)\n",
-    "        \n",
-    "        n_vals[s][i] = n_i\n",
-    "        r_vals[s][i] = r_i\n",
-    "        \n",
-    "        print (str(n_i) + \"     \" + str (r_i))\n",
-    "        \n",
-    "        # check if leaf node for this s\n",
-    "        if counter == s:\n",
-    "            sum_leaf_n_i += n_i\n",
-    "        counter += 1\n",
-    "        \n",
-    "        #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
-    "        #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
-    "    #### End Finite Horizon Successive Halving with (n,r)\n",
-    "\n",
-    "print (\" \")\n",
-    "print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
-    "print (\"(if have more workers than this, they may not be 100% busy)\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 46,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "array([[2, 0],\n",
-       "       [3, 1]])"
-      ]
-     },
-     "execution_count": 46,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "n_vals "
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 47,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "array([[3, 0],\n",
-       "       [1, 3]])"
-      ]
-     },
-     "execution_count": 47,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "r_vals"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
     "Hyperband diagonal"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1269,7 +939,7 @@
     "        self.get_params = get_params_function\n",
     "        self.try_params = try_params_function\n",
     "\n",
-    "        self.max_iter = 3  # maximum iterations per configuration\n",
+    "        self.max_iter = 9  # maximum iterations per configuration\n",
     "        self.eta = 3        # defines configuration downsampling rate (default = 3)\n",
     "\n",
     "        self.logeta = lambda x: log( x ) / log( self.eta )\n",
@@ -1288,6 +958,8 @@
     "        for s in reversed(range(self.s_max+1)):\n",
     "\n",
     "            print (\" \")\n",
+    "            print (\"Hyperband brackets\")\n",
+    "            print (\" \")\n",
     "            print (\"s=\" + str(s))\n",
     "            print (\"n_i      r_i\")\n",
     "            print (\"------------\")\n",
@@ -1315,9 +987,12 @@
     "\n",
     "                #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
     "                #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
-    "            #### End Finite Horizon Successive Halving with (n,r)\n",
+    "                #### End Finite Horizon Successive Halving with (n,r)\n",
     "\n",
-    "            \n",
+    "            #print (\" \")\n",
+    "            #print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
+    "            #print (\"(if have more workers than this, they may not be 100% busy)\")\n",
+    "        \n",
     "    # generate model selection tuples for all brackets\n",
     "    def create_mst_superset(self):\n",
     "        # get hyper parameter configs for each bracket s\n",
@@ -1325,6 +1000,10 @@
     "            n = int(ceil(int(self.B/self.max_iter/(s+1))*self.eta**s)) # initial number of configurations\n",
     "            r = self.max_iter*self.eta**(-s) # initial number of iterations to run configurations for\n",
     "\n",
+    "            \n",
+    "            print (\" \")\n",
+    "            print (\"Create superset of MSTs, i.e., i=0 for for each bracket s\")\n",
+    "            print (\" \")\n",
     "            print (\"s=\" + str(s))\n",
     "            print (\"n=\" + str(n))\n",
     "            print (\"r=\" + str(r))\n",
@@ -1364,15 +1043,37 @@
     "            print (\"loop on s desc to prune mst table:\")\n",
     "            for s in range(self.s_max, self.s_max-i-1, -1):\n",
     "                \n",
-    "                k = int( self.n_vals[s][i] / self.eta)                \n",
-    "                %sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
-    "            \n",
+    "                k = int( self.n_vals[s][i] / self.eta)\n",
+    "                \n",
+    "                # temporarily re-run table names again due to weird scope issues\n",
+    "                results_table = 'results_mnist'\n",
+    "\n",
+    "                output_table = 'mnist_multi_model'\n",
+    "                output_table_info = '_'.join([output_table, 'info'])\n",
+    "                output_table_summary = '_'.join([output_table, 'summary'])\n",
+    "\n",
+    "                mst_table = 'mst_table_hb_mnist'\n",
+    "                mst_table_summary = '_'.join([mst_table, 'summary'])\n",
+    "\n",
+    "                mst_diag_table = 'mst_diag_table_hb_mnist'\n",
+    "                mst_diag_table_summary = '_'.join([mst_diag_table, 'summary'])\n",
+    "\n",
+    "                model_arch_table = 'model_arch_library_mnist'\n",
+    "                \n",
+    "                query = \"\"\"\n",
+    "                DELETE FROM {mst_table} WHERE s={s} AND mst_key NOT IN (SELECT {output_table_info}.mst_key FROM {output_table_info} JOIN {mst_table} ON {output_table_info}.mst_key={mst_table}.mst_key WHERE s={s} ORDER BY validation_loss_final ASC LIMIT {k}::INT);\n",
+    "                \"\"\".format(**locals())\n",
+    "                cur.execute(query)\n",
+    "                conn.commit()\n",
+    "                #%sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
+    "#                 %sql DELETE FROM $mst_table WHERE s={0} AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
+    "                #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
     "        return"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1428,7 +1129,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 23,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1462,159 +1163,130 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
-   "metadata": {},
+   "execution_count": 24,
+   "metadata": {
+    "scrolled": false
+   },
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       " \n",
+      "Hyperband brackets\n",
+      " \n",
+      "s=2\n",
+      "n_i      r_i\n",
+      "------------\n",
+      "9     1.0\n",
+      "3.0     3.0\n",
+      "1.0     9.0\n",
+      " \n",
+      "Hyperband brackets\n",
+      " \n",
       "s=1\n",
       "n_i      r_i\n",
       "------------\n",
-      "3     1.0\n",
-      "1.0     3.0\n",
+      "3     3.0\n",
+      "1.0     9.0\n",
+      " \n",
+      "Hyperband brackets\n",
       " \n",
       "s=0\n",
       "n_i      r_i\n",
       "------------\n",
-      "2     3\n",
+      "3     9\n",
+      " \n",
+      "Create superset of MSTs, i.e., i=0 for for each bracket s\n",
+      " \n",
+      "s=2\n",
+      "n=9\n",
+      "r=1.0\n",
+      " \n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      " \n",
+      "Create superset of MSTs, i.e., i=0 for for each bracket s\n",
+      " \n",
       "s=1\n",
       "n=3\n",
-      "r=1.0\n",
+      "r=3.0\n",
       " \n",
       "1 rows affected.\n",
       "1 rows affected.\n",
       "1 rows affected.\n",
+      " \n",
+      "Create superset of MSTs, i.e., i=0 for for each bracket s\n",
+      " \n",
       "s=0\n",
-      "n=2\n",
-      "r=3\n",
+      "n=3\n",
+      "r=9\n",
       " \n",
       "1 rows affected.\n",
       "1 rows affected.\n",
+      "1 rows affected.\n",
       "outer loop on diagonal:\n",
       " \n",
       "i=0\n",
       "Done.\n",
       "loop on s desc to create diagonal table:\n",
-      "3 rows affected.\n",
+      "9 rows affected.\n",
       "try params for i = 0\n",
       "Done.\n",
       "1 rows affected.\n",
       "Done.\n",
-      "3 rows affected.\n",
+      "9 rows affected.\n",
       "Done.\n",
-      "3 rows affected.\n",
-      "3 rows affected.\n",
-      "3 rows affected.\n",
+      "9 rows affected.\n",
+      "9 rows affected.\n",
+      "9 rows affected.\n",
       "loop on s desc to prune mst table:\n",
-      "4 rows affected.\n",
       " \n",
       "i=1\n",
       "Done.\n",
       "loop on s desc to create diagonal table:\n",
-      "1 rows affected.\n",
-      "0 rows affected.\n",
+      "3 rows affected.\n",
+      "3 rows affected.\n",
       "try params for i = 1\n",
       "Done.\n",
       "1 rows affected.\n",
       "Done.\n",
-      "1 rows affected.\n",
+      "6 rows affected.\n",
       "Done.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
-      "1 rows affected.\n",
+      "6 rows affected.\n",
+      "6 rows affected.\n",
+      "6 rows affected.\n",
       "loop on s desc to prune mst table:\n",
+      " \n",
+      "i=2\n",
+      "Done.\n",
+      "loop on s desc to create diagonal table:\n",
       "1 rows affected.\n",
-      "0 rows affected.\n"
-     ]
-    }
-   ],
-   "source": [
-    "hp = Hyperband_diagonal(get_params, try_params )\n",
-    "results = hp.run()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 15,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "array([[3, 0],\n",
-       "       [1, 3]])"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "self.r_vals"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "0"
-      ]
-     },
-     "execution_count": 16,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "i"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 17,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "1"
-      ]
-     },
-     "execution_count": 17,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "self.s_max"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 18,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "1"
-      ]
-     },
-     "execution_count": 18,
-     "metadata": {},
-     "output_type": "execute_result"
+      "0 rows affected.\n",
+      "3 rows affected.\n",
+      "try params for i = 2\n",
+      "Done.\n",
+      "1 rows affected.\n",
+      "Done.\n",
+      "4 rows affected.\n",
+      "Done.\n",
+      "4 rows affected.\n",
+      "4 rows affected.\n",
+      "4 rows affected.\n",
+      "loop on s desc to prune mst table:\n"
+     ]
     }
    ],
    "source": [
-    "self.r_vals[self.s_max][i]"
+    "hp = Hyperband_diagonal(get_params, try_params )\n",
+    "results = hp.run()"
    ]
   },
   {
@@ -1627,12 +1299,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [],
    "source": [
     "%matplotlib notebook\n",
     "import matplotlib.pyplot as plt\n",
+    "from matplotlib.ticker import MaxNLocator\n",
     "from collections import defaultdict\n",
     "import pandas as pd\n",
     "import seaborn as sns\n",
@@ -1642,15 +1315,22 @@
    ]
   },
   {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Training dataset"
+   ]
+  },
+  {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 30,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "4 rows affected.\n"
+      "15 rows affected.\n"
      ]
     },
     {
@@ -2436,7 +2116,7 @@
     {
      "data": {
       "text/html": [
-       "<img src=\" [...]
+       "<img src=\" [...]
       ],
       "text/plain": [
        "<IPython.core.display.HTML object>"
@@ -2452,16 +2132,44 @@
       "1 rows affected.\n",
       "1 rows affected.\n",
       "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
       "1 rows affected.\n"
      ]
     }
    ],
    "source": [
     "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
-    "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 7;\n",
+    "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 15;\n",
     "df_results = df_results.DataFrame()\n",
     "\n",
+    "#set up plots\n",
     "fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
+    "# fig.legend() and fig.tight_layout() are called after the plotting loop:\n",
+    "# a legend created before any labelled line exists is empty\n",
+    "\n",
+    "ax_metric = axs[0]\n",
+    "ax_loss = axs[1]\n",
+    "\n",
+    "ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
+    "ax_metric.set_xlabel('Iteration')\n",
+    "ax_metric.set_ylabel('Metric')\n",
+    "ax_metric.set_title('Training metric curve')\n",
+    "\n",
+    "ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
+    "ax_loss.set_xlabel('Iteration')\n",
+    "ax_loss.set_ylabel('Loss')\n",
+    "ax_loss.set_title('Training loss curve')\n",
+    "\n",
     "for run_id in df_results['run_id']:\n",
     "    df_output_info = %sql SELECT training_metrics,training_loss FROM $results_table WHERE run_id = $run_id\n",
     "    df_output_info = df_output_info.DataFrame()\n",
@@ -2469,35 +2177,29 @@
     "    training_loss = df_output_info['training_loss'][0]\n",
     "    X = range(len(training_metrics))\n",
     "    \n",
-    "    ax_metric = axs[0]\n",
-    "    ax_loss = axs[1]\n",
-    "    ax_metric.set_xticks(X[::1])\n",
-    "    ax_metric.plot(X, training_metrics, label=run_id)\n",
-    "    ax_metric.set_xlabel('Iteration')\n",
-    "    ax_metric.set_ylabel('Metric')\n",
-    "    ax_metric.set_title('Training metric curve')\n",
-    "\n",
-    "    ax_loss.set_xticks(X[::1])\n",
-    "    ax_loss.plot(X, training_loss, label=run_id)\n",
-    "    ax_loss.set_xlabel('Iteration')\n",
-    "    ax_loss.set_ylabel('Loss')\n",
-    "    ax_loss.set_title('Training loss curve')\n",
-    "    \n",
-    "fig.legend(ncol=4)\n",
-    "fig.tight_layout()\n",
+    "    ax_metric.plot(X, training_metrics, label=run_id, marker='o')\n",
+    "    ax_loss.plot(X, training_loss, label=run_id, marker='o')\n",
+    "fig.legend(ncol=4)\n", "fig.tight_layout()\n",
     "# fig.savefig('./lc_keras_fit.png', dpi = 300)"
    ]
   },
   {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Validation dataset"
+   ]
+  },
+  {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 31,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "4 rows affected.\n"
+      "15 rows affected.\n"
      ]
     },
     {
@@ -3283,7 +2985,7 @@
     {
      "data": {
       "text/html": [
-       "<img src=\" [...]
+       "<img src=\" [...]
       ],
       "text/plain": [
        "<IPython.core.display.HTML object>"
@@ -3299,16 +3001,44 @@
       "1 rows affected.\n",
       "1 rows affected.\n",
       "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
+      "1 rows affected.\n",
       "1 rows affected.\n"
      ]
     }
    ],
    "source": [
     "#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
-    "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 5;\n",
+    "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 15;\n",
     "df_results = df_results.DataFrame()\n",
     "\n",
+    "#set up plots\n",
     "fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
+    "\n",
+    "ax_metric = axs[0]\n",
+    "ax_loss = axs[1]\n",
+    "\n",
+    "ax_metric.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
+    "ax_metric.set_xlabel('Iteration')\n",
+    "ax_metric.set_ylabel('Metric')\n",
+    "ax_metric.set_title('Validation metric curve')\n",
+    "\n",
+    "ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))\n",
+    "ax_loss.set_xlabel('Iteration')\n",
+    "ax_loss.set_ylabel('Loss')\n",
+    "ax_loss.set_title('Validation loss curve')\n",
+    "fig.tight_layout()  # after titles/labels so the layout reserves room for them\n",
+    "\n",
+    "# one labelled curve per run; the figure legend is built after this loop\n",
     "for run_id in df_results['run_id']:\n",
     "    df_output_info = %sql SELECT validation_metrics,validation_loss FROM $results_table WHERE run_id = $run_id\n",
     "    df_output_info = df_output_info.DataFrame()\n",
@@ -3316,24 +3046,236 @@
     "    validation_loss = df_output_info['validation_loss'][0]\n",
     "    X = range(len(validation_metrics))\n",
     "    \n",
-    "    ax_metric = axs[0]\n",
-    "    ax_loss = axs[1]\n",
-    "    ax_metric.set_xticks(X[::1])\n",
-    "    ax_metric.plot(X, validation_metrics, label=run_id)\n",
-    "    ax_metric.set_xlabel('Iteration')\n",
-    "    ax_metric.set_ylabel('Metric')\n",
-    "    ax_metric.set_title('Validation metric curve')\n",
-    "\n",
-    "    ax_loss.set_xticks(X[::1])\n",
-    "    ax_loss.plot(X, validation_loss, label=run_id)\n",
-    "    ax_loss.set_xlabel('Iteration')\n",
-    "    ax_loss.set_ylabel('Loss')\n",
-    "    ax_loss.set_title('Validation loss curve')\n",
-    "    \n",
-    "fig.legend(ncol=4)\n",
-    "fig.tight_layout()\n",
+    "    ax_metric.plot(X, validation_metrics, label=run_id, marker='o')\n",
+    "    ax_loss.plot(X, validation_loss, label=run_id, marker='o')\n",
+    "fig.legend(ncol=4)  # after plotting: fig.legend() only collects lines that already exist\n",
     "# fig.savefig('./lc_keras_fit.png', dpi = 300)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "<a id=\"print\"></a>\n",
+    "# 7. Print run schedules"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Pretty print regular Hyperband run schedule"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "max_iter = 9\n",
+      "eta = 3\n",
+      "B = 3*max_iter = 27\n",
+      " \n",
+      "s=2\n",
+      "n_i      r_i\n",
+      "------------\n",
+      "9     1.0\n",
+      "3     3.0\n",
+      "1     9.0\n",
+      " \n",
+      "s=1\n",
+      "n_i      r_i\n",
+      "------------\n",
+      "3     3.0\n",
+      "1     9.0\n",
+      " \n",
+      "s=0\n",
+      "n_i      r_i\n",
+      "------------\n",
+      "3     9\n",
+      " \n",
+      "sum of configurations at leaf nodes across all s = 5\n",
+      "(if have more workers than this, they may not be 100% busy)\n"
+     ]
+    }
+   ],
+   "source": [
+    "from math import ceil\n",
+    "\n",
+    "#input\n",
+    "max_iter = 9  # maximum iterations/epochs per configuration\n",
+    "eta = 3  # defines downsampling rate (default=3)\n",
+    "\n",
+    "# s_max = floor(log_eta(max_iter)), found by exact integer search: the float ratio\n",
+    "# log(max_iter)/log(eta) can land just below an exact power (e.g. max_iter=243, eta=3)\n",
+    "s_max = max(s for s in range(max_iter.bit_length() + 1) if eta**s <= max_iter)  # number of unique executions of Successive Halving (minus one)\n",
+    "B = (s_max+1)*max_iter  # total number of iterations (without reuse) per execution of Successive Halving (n,r)\n",
+    "\n",
+    "#echo output\n",
+    "print (\"max_iter = \" + str(max_iter))\n",
+    "print (\"eta = \" + str(eta))\n",
+    "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
+    "\n",
+    "sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
+    "\n",
+    "#### Begin Finite Horizon Hyperband outer loop. Repeat indefinitely.\n",
+    "for s in reversed(range(s_max+1)):\n",
+    "    \n",
+    "    print (\" \")\n",
+    "    print (\"s=\" + str(s))\n",
+    "    print (\"n_i      r_i\")\n",
+    "    print (\"------------\")\n",
+    "    counter = 0\n",
+    "    \n",
+    "    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
+    "    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
+    "\n",
+    "    #### Begin Finite Horizon Successive Halving with (n,r)\n",
+    "    #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
+    "    for i in range(s+1):\n",
+    "        # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
+    "        n_i = n // eta**i  # floor(n*eta^-i): configuration counts are whole numbers\n",
+    "        r_i = r*eta**(i)\n",
+    "        \n",
+    "        print (str(n_i) + \"     \" + str (r_i))\n",
+    "        \n",
+    "        # check if leaf node for this s\n",
+    "        if counter == s:\n",
+    "            sum_leaf_n_i += n_i\n",
+    "        counter += 1\n",
+    "        \n",
+    "        #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
+    "        #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
+    "    #### End Finite Horizon Successive Halving with (n,r)\n",
+    "\n",
+    "print (\" \")\n",
+    "print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
+    "print (\"(if have more workers than this, they may not be 100% busy)\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Pretty print Hyperband diagonal run schedule"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "echo input:\n",
+      "max_iter = 9\n",
+      "eta = 3\n",
+      "s_max = 2\n",
+      "B = 3*max_iter = 27\n",
+      " \n",
+      "initial n, r values for each s:\n",
+      "s=2\n",
+      "n=9\n",
+      "r=1.0\n",
+      " \n",
+      "s=1\n",
+      "n=3\n",
+      "r=3.0\n",
+      " \n",
+      "s=0\n",
+      "n=3\n",
+      "r=9\n",
+      " \n",
+      "outer loop on diagonal:\n",
+      " \n",
+      "i=0\n",
+      "inner loop on s desc:\n",
+      "s=2\n",
+      "n_i=9\n",
+      "r_i=1.0\n",
+      " \n",
+      "i=1\n",
+      "inner loop on s desc:\n",
+      "s=2\n",
+      "n_i=3\n",
+      "r_i=3.0\n",
+      "s=1\n",
+      "n_i=3\n",
+      "r_i=3.0\n",
+      " \n",
+      "i=2\n",
+      "inner loop on s desc:\n",
+      "s=2\n",
+      "n_i=1\n",
+      "r_i=9.0\n",
+      "s=1\n",
+      "n_i=1\n",
+      "r_i=9.0\n",
+      "s=0\n",
+      "n_i=3\n",
+      "r_i=9\n"
+     ]
+    }
+   ],
+   "source": [
+    "from math import ceil\n",
+    "\n",
+    "#input\n",
+    "max_iter = 9  # maximum iterations/epochs per configuration\n",
+    "eta = 3  # defines downsampling rate (default=3)\n",
+    "\n",
+    "# s_max = floor(log_eta(max_iter)), found by exact integer search: the float ratio\n",
+    "# log(max_iter)/log(eta) can land just below an exact power (e.g. max_iter=243, eta=3)\n",
+    "s_max = max(s for s in range(max_iter.bit_length() + 1) if eta**s <= max_iter)  # number of unique executions of Successive Halving (minus one)\n",
+    "B = (s_max+1)*max_iter  # total number of iterations (without reuse) per execution of Successive Halving (n,r)\n",
+    "\n",
+    "#echo output\n",
+    "print (\"echo input:\")\n",
+    "print (\"max_iter = \" + str(max_iter))\n",
+    "print (\"eta = \" + str(eta))\n",
+    "print (\"s_max = \" + str(s_max))\n",
+    "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
+    "\n",
+    "print (\" \")\n",
+    "print (\"initial n, r values for each s:\")\n",
+    "initial_n_vals = {}\n",
+    "initial_r_vals = {}\n",
+    "# get hyper parameter configs for each s\n",
+    "for s in reversed(range(s_max+1)):\n",
+    "    \n",
+    "    n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
+    "    r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
+    "    \n",
+    "    initial_n_vals[s] = n \n",
+    "    initial_r_vals[s] = r \n",
+    "    \n",
+    "    print (\"s=\" + str(s))\n",
+    "    print (\"n=\" + str(n))\n",
+    "    print (\"r=\" + str(r))\n",
+    "    print (\" \")\n",
+    "    \n",
+    "print (\"outer loop on diagonal:\")\n",
+    "# outer loop on diagonal\n",
+    "for i in range(s_max+1):\n",
+    "    print (\" \")\n",
+    "    print (\"i=\" + str(i))\n",
+    "    \n",
+    "    print (\"inner loop on s desc:\")\n",
+    "    # inner loop on s desc\n",
+    "    for s in range(s_max, s_max-i-1, -1):\n",
+    "        n_i = initial_n_vals[s] // eta**(i-s_max+s)  # bracket s is in round i-(s_max-s); floor keeps counts whole\n",
+    "        r_i = initial_r_vals[s]*eta**(i-s_max+s)\n",
+    "        \n",
+    "        print (\"s=\" + str(s))\n",
+    "        print (\"n_i=\" + str(n_i))\n",
+    "        print (\"r_i=\" + str(r_i))"
+   ]
   }
  ],
  "metadata": {