You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by fm...@apache.org on 2019/12/21 01:00:30 UTC
[madlib-site] branch automl updated: add warm start to hyperband
diag
This is an automated email from the ASF dual-hosted git repository.
fmcquillan pushed a commit to branch automl
in repository https://gitbox.apache.org/repos/asf/madlib-site.git
The following commit(s) were added to refs/heads/automl by this push:
new b51d9c9 add warm start to hyperband diag
b51d9c9 is described below
commit b51d9c9d050494c67a0736677566ae947c00842c
Author: Frank McQuillan <fm...@pivotal.io>
AuthorDate: Fri Dec 20 17:00:05 2019 -0800
add warm start to hyperband diag
---
.../hyperband_diag_v2_mnist-checkpoint.ipynb | 275 +-
.../Deep-learning/automl/hyperband_diag_v1.ipynb | 382 ---
.../automl/hyperband_diag_v2_mnist.ipynb | 252 +-
.../Deep-learning/automl/hyperband_v0.ipynb | 259 --
.../Deep-learning/automl/hyperband_v1.ipynb | 3424 --------------------
.../Deep-learning/automl/hyperband_v1.py | 99 -
.../Deep-learning/automl/hyperband_v2.ipynb | 3043 -----------------
.../Deep-learning/automl/hyperband_v3_mnist.ipynb | 2928 -----------------
8 files changed, 272 insertions(+), 10390 deletions(-)
diff --git a/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb b/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
index b62f8d5..fa92b05 100644
--- a/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
+++ b/community-artifacts/Deep-learning/automl/.ipynb_checkpoints/hyperband_diag_v2_mnist-checkpoint.ipynb
@@ -6,7 +6,7 @@
"source": [
"# Hyperband diagonal using MNIST\n",
"\n",
- "Implemention of Hyperband https://arxiv.org/pdf/1603.06560.pdf for MPP - uses the Hyperband schedule but runs it on a diagonal across brackets, instead of one bracket at a time. \n",
+ "Implemention of Hyperband https://arxiv.org/pdf/1603.06560.pdf for MPP with a synchronous barrier. Uses the Hyperband schedule but runs it on a diagonal across brackets, instead of one bracket at a time, to be more efficient with cluster resources. \n",
"\n",
"Model architecture based on https://keras.io/examples/mnist_transfer_cnn/ \n",
"\n",
@@ -25,22 +25,24 @@
"\n",
"<a href=\"#plot\">6. Plot results</a>\n",
"\n",
- "<a href=\"#print\">7. Print run schedules</a>"
+ "<a href=\"#print\">7. Print run schedules (display only)</a>"
]
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [
{
- "name": "stdout",
+ "name": "stderr",
"output_type": "stream",
"text": [
- "The sql extension is already loaded. To reload it, use:\n",
- " %reload_ext sql\n"
+ "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
+ " \"You should import from traitlets.config instead.\", ShimWarning)\n",
+ "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
+ " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
]
}
],
@@ -50,7 +52,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -72,7 +74,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -90,15 +92,15 @@
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
- " <td>MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
+ " <td>MADlib version: 1.17-dev, git revision: rel/v1.16-50-g5abfb79, cmake configuration time: Tue Nov 26 01:00:01 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
- "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
+ "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-50-g5abfb79, cmake configuration time: Tue Nov 26 01:00:01 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
]
},
- "execution_count": 19,
+ "execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -119,9 +121,24 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 5,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using TensorFlow backend.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Couldn't import dot_parser, loading of dot files will not be possible.\n"
+ ]
+ }
+ ],
"source": [
"from __future__ import print_function\n",
"\n",
@@ -163,7 +180,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -183,7 +200,7 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -215,7 +232,7 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -638,7 +655,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -648,29 +665,29 @@
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
- "conv2d_3 (Conv2D) (None, 26, 26, 32) 320 \n",
+ "conv2d_1 (Conv2D) (None, 26, 26, 32) 320 \n",
"_________________________________________________________________\n",
- "activation_5 (Activation) (None, 26, 26, 32) 0 \n",
+ "activation_1 (Activation) (None, 26, 26, 32) 0 \n",
"_________________________________________________________________\n",
- "conv2d_4 (Conv2D) (None, 24, 24, 32) 9248 \n",
+ "conv2d_2 (Conv2D) (None, 24, 24, 32) 9248 \n",
"_________________________________________________________________\n",
- "activation_6 (Activation) (None, 24, 24, 32) 0 \n",
+ "activation_2 (Activation) (None, 24, 24, 32) 0 \n",
"_________________________________________________________________\n",
- "max_pooling2d_2 (MaxPooling2 (None, 12, 12, 32) 0 \n",
+ "max_pooling2d_1 (MaxPooling2 (None, 12, 12, 32) 0 \n",
"_________________________________________________________________\n",
- "dropout_3 (Dropout) (None, 12, 12, 32) 0 \n",
+ "dropout_1 (Dropout) (None, 12, 12, 32) 0 \n",
"_________________________________________________________________\n",
- "flatten_2 (Flatten) (None, 4608) 0 \n",
+ "flatten_1 (Flatten) (None, 4608) 0 \n",
"_________________________________________________________________\n",
- "dense_3 (Dense) (None, 128) 589952 \n",
+ "dense_1 (Dense) (None, 128) 589952 \n",
"_________________________________________________________________\n",
- "activation_7 (Activation) (None, 128) 0 \n",
+ "activation_3 (Activation) (None, 128) 0 \n",
"_________________________________________________________________\n",
- "dropout_4 (Dropout) (None, 128) 0 \n",
+ "dropout_2 (Dropout) (None, 128) 0 \n",
"_________________________________________________________________\n",
- "dense_4 (Dense) (None, 10) 1290 \n",
+ "dense_2 (Dense) (None, 10) 1290 \n",
"_________________________________________________________________\n",
- "activation_8 (Activation) (None, 10) 0 \n",
+ "activation_4 (Activation) (None, 10) 0 \n",
"=================================================================\n",
"Total params: 600,810\n",
"Trainable params: 600,810\n",
@@ -716,7 +733,7 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -745,7 +762,7 @@
"[(1, u'feature + classification layers trainable')]"
]
},
- "execution_count": 41,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -765,7 +782,7 @@
"metadata": {},
"source": [
"<a id=\"hyperband\"></a>\n",
- "# 5. Hyperband"
+ "# 5. Hyperband diagonal"
]
},
{
@@ -777,7 +794,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 34,
"metadata": {},
"outputs": [
{
@@ -804,7 +821,7 @@
"[]"
]
},
- "execution_count": 22,
+ "execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
@@ -874,12 +891,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Table names"
+ "Generalize table names"
]
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@@ -902,12 +919,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Hyperband diagonal"
+ "Hyperband diagonal logic"
]
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
@@ -923,15 +940,19 @@
" self.try_params = try_params_function\n",
"\n",
" self.max_iter = 9 # maximum iterations per configuration\n",
- " self.eta = 3 # defines configuration downsampling rate (default = 3)\n",
+ " self.eta = 3 # defines downsampling rate (default = 3)\n",
"\n",
" self.logeta = lambda x: log( x ) / log( self.eta )\n",
" self.s_max = int( self.logeta( self.max_iter ))\n",
" self.B = ( self.s_max + 1 ) * self.max_iter\n",
" self.setup_full_schedule()\n",
" self.create_mst_superset()\n",
+ " \n",
+ " self.best_loss = np.inf\n",
+ " self.best_accuracy = 0.0\n",
+ "\n",
" \n",
- " # create full Hyperband schedule for all brackets\n",
+ " # create full Hyperband schedule for all brackets ahead of time\n",
" def setup_full_schedule(self):\n",
" self.n_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
" self.r_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
@@ -940,7 +961,7 @@
" print (\" \")\n",
" print (\"Hyperband brackets\")\n",
"\n",
- " #### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
+ " # loop through each bracket in reverse order\n",
" for s in reversed(range(self.s_max+1)):\n",
" \n",
" print (\" \")\n",
@@ -953,9 +974,8 @@
" r = self.max_iter*self.eta**(-s) # initial number of iterations to run configurations for\n",
"\n",
" #### Begin Finite Horizon Successive Halving with (n,r)\n",
- " #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
" for i in range(s+1):\n",
- " # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
+ " # n_i configs for r_i iterations\n",
" n_i = n*self.eta**(-i)\n",
" r_i = r*self.eta**(i)\n",
"\n",
@@ -968,15 +988,14 @@
" if counter == s:\n",
" sum_leaf_n_i += n_i\n",
" counter += 1\n",
- "\n",
- " #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
- " #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
+ " \n",
" #### End Finite Horizon Successive Halving with (n,r)\n",
"\n",
" #print (\" \")\n",
" #print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
" #print (\"(if have more workers than this, they may not be 100% busy)\")\n",
" \n",
+ " \n",
" # generate model selection tuples for all brackets\n",
" def create_mst_superset(self):\n",
" # get hyper parameter configs for each bracket s\n",
@@ -984,7 +1003,6 @@
" n = int(ceil(int(self.B/self.max_iter/(s+1))*self.eta**s)) # initial number of configurations\n",
" r = self.max_iter*self.eta**(-s) # initial number of iterations to run configurations for\n",
"\n",
- " \n",
" print (\" \")\n",
" print (\"Create superset of MSTs, i.e., i=0 for for each bracket s\")\n",
" print (\" \")\n",
@@ -997,10 +1015,13 @@
" self.get_params(n, s)\n",
" \n",
" \n",
- " # run Hyperband diagonal logic\n",
- " # can be called multiple times\n",
- " def run( self, skip_last = 0, dry_run = False ): \n",
+ " # Hyperband diagonal logic\n",
+ " def run( self, skip_last = 0, dry_run = False ): \n",
+ " \n",
+ " print (\" \")\n",
+ " print (\"Hyperband diagonal\")\n",
" print (\"outer loop on diagonal:\")\n",
+ " \n",
" # outer loop on diagonal\n",
" for i in range(self.s_max+1):\n",
" print (\" \")\n",
@@ -1015,20 +1036,24 @@
"\n",
" # build up mst table for diagonal\n",
" %sql INSERT INTO $mst_diag_table (SELECT * FROM $mst_table WHERE s=$s);\n",
+ " \n",
+ " # first pass\n",
+ " if i == 0:\n",
+ " first_pass = True\n",
+ " else:\n",
+ " first_pass = False\n",
" \n",
" # multi-model training\n",
+ " print (\" \")\n",
" print (\"try params for i = \" + str(i))\n",
- " U = self.try_params(i, self.r_vals[self.s_max][i]) # r_i is the same for all diagonal elements\n",
- "\n",
- " # select a number of best configurations for the next loop\n",
- " # filter out early stops, if any\n",
+ " U = self.try_params(i, self.r_vals[self.s_max][i], first_pass) # r_i is the same for all diagonal elements\n",
" \n",
" # loop on brackets s desc to prune model selection table\n",
" # don't need to prune if finished last diagonal\n",
" if i < self.s_max:\n",
" print (\"loop on s desc to prune mst table:\")\n",
" for s in range(self.s_max, self.s_max-i-1, -1):\n",
- " \n",
+ " \n",
" # compute number of configs to keep\n",
" # remember i value is different for each bracket s on the diagonal\n",
" k = int( self.n_vals[s][s-self.s_max+i] / self.eta)\n",
@@ -1054,14 +1079,36 @@
" \"\"\".format(**locals())\n",
" cur.execute(query)\n",
" conn.commit()\n",
+ " \n",
+ " # these were not working so used cursor instead\n",
" #%sql DELETE FROM $mst_table WHERE s=$s AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
" #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
+ " \n",
+ " # keep track of best loss so far (for display purposes only)\n",
+ " loss = %sql SELECT validation_loss_final FROM $output_table_info ORDER BY validation_loss_final ASC LIMIT 1;\n",
+ " accuracy = %sql SELECT validation_metrics_final FROM $output_table_info ORDER BY validation_loss_final ASC LIMIT 1;\n",
+ " \n",
+ " if loss < self.best_loss:\n",
+ " self.best_loss = loss\n",
+ " self.best_accuracy = accuracy\n",
+ " \n",
+ " print (\" \")\n",
+ " print (\"best validation loss so far = \" + str(loss))\n",
+ " print (\"best validation accuracy so far = \" + str(accuracy))\n",
+ " \n",
" return"
]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate params and insert into MST table"
+ ]
+ },
+ {
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
@@ -1087,7 +1134,7 @@
"\n",
" # fit params\n",
" # batch size\n",
- " batch_size = [64, 128]\n",
+ " batch_size = [32, 64, 128]\n",
" # epochs\n",
" epochs = [1]\n",
"\n",
@@ -1116,21 +1163,32 @@
]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Run model hopper for candidates in MST table"
+ ]
+ },
+ {
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
- "def try_params(i, r):\n",
+ "def try_params(i, r, first_pass):\n",
" \n",
" # multi-model fit\n",
- " # TO DO: use warm start to continue from where left off after if not 1st time thru for this s value\n",
- " %sql DROP TABLE IF EXISTS $output_table, $output_table_summary, $output_table_info;\n",
- " \n",
- " # passing vars as madlib args does not seem to work\n",
- " #%sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', $output_table, $mst_diag_table, $r_i::INT, 0);\n",
- " %sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', 'mnist_multi_model', 'mst_diag_table_hb_mnist', $r::INT, 0, 'test_mnist_packed');\n",
- " \n",
+ " if first_pass:\n",
+ " # cold start\n",
+ " %sql DROP TABLE IF EXISTS $output_table, $output_table_summary, $output_table_info;\n",
+ " # passing vars as madlib args does not seem to work\n",
+ " #%sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', $output_table, $mst_diag_table, $r_i::INT, 0);\n",
+ " %sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', 'mnist_multi_model', 'mst_diag_table_hb_mnist', $r::INT, FALSE, 'test_mnist_packed');\n",
+ "\n",
+ " else:\n",
+ " # warm start to continue from previous run\n",
+ " %sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', 'mnist_multi_model', 'mst_diag_table_hb_mnist', $r::INT, FALSE, 'test_mnist_packed', NULL, True);\n",
+ "\n",
" # save results via temp table\n",
" # add everything from info table\n",
" %sql DROP TABLE IF EXISTS temp_results;\n",
@@ -1150,8 +1208,15 @@
]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Call Hyperband diagonal"
+ ]
+ },
+ {
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": null,
"metadata": {
"scrolled": false
},
@@ -1216,61 +1281,17 @@
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
+ " \n",
+ "Hyperband diagonal\n",
"outer loop on diagonal:\n",
" \n",
"i=0\n",
"Done.\n",
"loop on s desc to create diagonal table:\n",
"9 rows affected.\n",
- "try params for i = 0\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "9 rows affected.\n",
- "Done.\n",
- "9 rows affected.\n",
- "9 rows affected.\n",
- "9 rows affected.\n",
- "loop on s desc to prune mst table:\n",
- "pruning s = 2 with k = 3\n",
- " \n",
- "i=1\n",
- "Done.\n",
- "loop on s desc to create diagonal table:\n",
- "3 rows affected.\n",
- "3 rows affected.\n",
- "try params for i = 1\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "6 rows affected.\n",
- "Done.\n",
- "6 rows affected.\n",
- "6 rows affected.\n",
- "6 rows affected.\n",
- "loop on s desc to prune mst table:\n",
- "pruning s = 2 with k = 1\n",
- "pruning s = 1 with k = 1\n",
" \n",
- "i=2\n",
- "Done.\n",
- "loop on s desc to create diagonal table:\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "3 rows affected.\n",
- "try params for i = 2\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "5 rows affected.\n",
- "Done.\n",
- "5 rows affected.\n",
- "5 rows affected.\n",
- "5 rows affected.\n",
- "loop on s desc to prune mst table:\n",
- "pruning s = 2 with k = 0\n",
- "pruning s = 1 with k = 0\n",
- "pruning s = 0 with k = 1\n"
+ "try params for i = 0\n",
+ "Done.\n"
]
}
],
@@ -1290,7 +1311,7 @@
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@@ -1321,7 +1342,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "10 rows affected.\n"
+ "8 rows affected.\n"
]
},
{
@@ -2107,7 +2128,7 @@
{
"data": {
"text/html": [
- "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3hb1f3+X0se8ooTJ3GcxM5wEhJ2gSTMJJBCGYVCCJRSaKGD1VL6b/kBLSNAwygd0AE0ZbS0BcoKq0CBssPMYI8MMp3h7HhLsiX9n/fI15GH7CtZlq7s9zyPHyf2uWd8zpX11Xu/IyMUCoWgJgIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIi0KsEMiTE9SpfDS4CIiACIiACIiACIiACIiACIiACIiACIiAChoCEON0IIpDmBE477TTMnz8fP/7xj3H77bcndDeTJ0/GkiVL8Nvf/hb/93//l9CxNVjiCHz66afYd999zYBbt27FkCFDEje4RhIBERABERABEUgaAdl1HVHX1dWhsLDQ/GLRokWgfaomAiIgAulMQEJcOp+e1t5rBDI [...]
+ "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABEwAAAImCAYAAABJvh+8AAAgAElEQVR4XuzdC5xVVfn/8WcGHEEDVEQFQchLKgKapqWCifzwbirmL/9mCt7wlkl44ScJeAEKvBGpQCKYJlpA5j3xggUqppkiWCZeQMUyQC4iyGX+r++yM80MM3PWOWcO59nnfPbrxQuZWfvstd5rjeuZZ6+9dlllZWWlcSCAAAIIIIAAAggggAACCCCAAAIIVAmUkTBhNCCAAAIIIIAAAggggAACCCCAAAI1BUiYMCISKVBWVpZxvc866yybPHlyxudlcsJ3v/tdmzZtmv32t781/Xeuxy9+8Qv74Q9/aBdffLHpvzk2j8Dll19uN910k40ePdr03xwIIIAAAghkIkCckolWfsu+8cYb1rVrV9tnn31M/82BAAIIZCJAwiQTLcq6Eejbt+8mdfn444/ [...]
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -2127,15 +2148,13 @@
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
- "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 10;\n",
+ "df_results = %sql SELECT * FROM $results_table ORDER BY training_loss ASC LIMIT 12;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"#set up plots\n",
@@ -2185,7 +2204,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "10 rows affected.\n"
+ "8 rows affected.\n"
]
},
{
@@ -2971,7 +2990,7 @@
{
"data": {
"text/html": [
- "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCZiT1fX/v8ns+zDMhgyLyOaKsoqyuAGCu2LVYkVFLdat1bb8q4IoLqVVa12x1V/FuiMq1YpAQRRQZFNEBUGQZZBhFph9yUyS//O9wztmMsnkTfImeSdz7vPMM0Duve+9n3tDTr733HMsTqfTCSlCQAgIASEgBISAEBACQkAICAEhIASEgBAQAkJACISUgEWEuJDylc6FgBAQAkJACAgBISAEhIAQEAJCQAgIASEgBISAIiBCnGwEIRBGAh999BEmTpyIhIQE1NfXt3pye6/pGWKw7fU8o7068+bNw80334wBAwZg27ZtwXYn7UNEQNYpRGClWyEgBISAEOh0BMSuM9+SR9oeNh8RGZEQEAJmJCBCnBlXRcZkOIEbb7wRL7zwArKysvD [...]
+ "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABEwAAAImCAYAAABJvh+8AAAgAElEQVR4XuzdCZhUxdX/8cPiCCqgIioIQhQTRRajYuKCEXlxI0TF+Ma/cQE3FI2R4MKrCYsKJqBRQoxCRDAalwTRuEdcIIKK0Sgiaoy4QBQTRWQRQZb5P78yTXqGmenq7mn63O7vfR4fBOr2rfpUDXX63Lp1G1RWVlYaBwIIIIAAAggggAACCCCAAAIIIIDARoEGJEwYDQgggAACCCCAAAIIIIAAAggggEBVARImjAhXAqeeeqrdcccddvrpp9uUKVMy1u2CCy6wG2+80Y477ji77777MpavrcAtt9xiZ599tp155pmm/08db7/9tu2xxx62++67m/4/9li3bp1tscUW1qhRI9P/b47jkEMOsdmzZ9szzzxj+n+OwgvUNm4Kf2WugAACCCBQDAHilNz [...]
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -2991,15 +3010,13 @@
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
"1 rows affected.\n"
]
}
],
"source": [
"#df_results = %sql SELECT * FROM $results_table ORDER BY run_id;\n",
- "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 10;\n",
+ "df_results = %sql SELECT * FROM $results_table ORDER BY validation_loss ASC LIMIT 12;\n",
"df_results = df_results.DataFrame()\n",
"\n",
"#set up plots\n",
@@ -3038,7 +3055,7 @@
"metadata": {},
"source": [
"<a id=\"print\"></a>\n",
- "# 7. Print run schedules"
+ "# 7. Print run schedules (display only)"
]
},
{
@@ -3051,7 +3068,9 @@
{
"cell_type": "code",
"execution_count": 32,
- "metadata": {},
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [
{
"name": "stdout",
diff --git a/community-artifacts/Deep-learning/automl/hyperband_diag_v1.ipynb b/community-artifacts/Deep-learning/automl/hyperband_diag_v1.ipynb
deleted file mode 100644
index c485a81..0000000
--- a/community-artifacts/Deep-learning/automl/hyperband_diag_v1.ipynb
+++ /dev/null
@@ -1,382 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
- " \"You should import from traitlets.config instead.\", ShimWarning)\n",
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
- " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
- ]
- }
- ],
- "source": [
- "%load_ext sql"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Greenplum Database 5.x on GCP (PM demo machine) - direct external IP access\n",
- "#%sql postgresql://gpadmin@34.67.65.96:5432/madlib\n",
- "\n",
- "# Greenplum Database 5.x on GCP - via tunnel\n",
- "%sql postgresql://gpadmin@localhost:8000/madlib\n",
- " \n",
- "# PostgreSQL local\n",
- "#%sql postgresql://fmcquillan@localhost:5432/madlib\n",
- "\n",
- "# psycopg2 connection\n",
- "import psycopg2 as p2\n",
- "#conn = p2.connect('postgresql://fmcquillan@localhost:5432/madlib')\n",
- "conn = p2.connect('postgresql://gpadmin@localhost:8000/madlib')\n",
- "cur = conn.cursor()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "1 rows affected.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<table>\n",
- " <tr>\n",
- " <th>version</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>MADlib version: 1.17-dev, git revision: rel/v1.16-46-g77ee745, cmake configuration time: Thu Nov 14 17:59:26 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
- " </tr>\n",
- "</table>"
- ],
- "text/plain": [
- "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-46-g77ee745, cmake configuration time: Thu Nov 14 17:59:26 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%sql select madlib.version();\n",
- "#%sql select version();"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Pretty print run schedule"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 71,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "max_iter = 81\n",
- "eta = 3\n",
- "B = 5*max_iter = 405\n",
- " \n",
- "s=4\n",
- "n_i r_i\n",
- "------------\n",
- "81 1.0\n",
- "27.0 3.0\n",
- "9.0 9.0\n",
- "3.0 27.0\n",
- "1.0 81.0\n",
- " \n",
- "s=3\n",
- "n_i r_i\n",
- "------------\n",
- "27 3.0\n",
- "9.0 9.0\n",
- "3.0 27.0\n",
- "1.0 81.0\n",
- " \n",
- "s=2\n",
- "n_i r_i\n",
- "------------\n",
- "9 9.0\n",
- "3.0 27.0\n",
- "1.0 81.0\n",
- " \n",
- "s=1\n",
- "n_i r_i\n",
- "------------\n",
- "6 27.0\n",
- "2.0 81.0\n",
- " \n",
- "s=0\n",
- "n_i r_i\n",
- "------------\n",
- "5 81\n",
- " \n",
- "sum of configurations at leaf nodes across all s = 10.0\n",
- "(if have more workers than this, they may not be 100% busy)\n"
- ]
- }
- ],
- "source": [
- "import numpy as np\n",
- "from math import log, ceil\n",
- "\n",
- "#input\n",
- "max_iter = 81 # maximum iterations/epochs per configuration\n",
- "eta = 3 # defines downsampling rate (default=3)\n",
- "\n",
- "logeta = lambda x: log(x)/log(eta)\n",
- "s_max = int(logeta(max_iter)) # number of unique executions of Successive Halving (minus one)\n",
- "B = (s_max+1)*max_iter # total number of iterations (without reuse) per execution of Succesive Halving (n,r)\n",
- "\n",
- "#echo output\n",
- "print (\"max_iter = \" + str(max_iter))\n",
- "print (\"eta = \" + str(eta))\n",
- "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
- "\n",
- "sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
- "\n",
- "#### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
- "for s in reversed(range(s_max+1)):\n",
- " \n",
- " print (\" \")\n",
- " print (\"s=\" + str(s))\n",
- " print (\"n_i r_i\")\n",
- " print (\"------------\")\n",
- " counter = 0\n",
- " \n",
- " n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
- " r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
- "\n",
- " #### Begin Finite Horizon Successive Halving with (n,r)\n",
- " #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
- " for i in range(s+1):\n",
- " # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
- " n_i = n*eta**(-i)\n",
- " r_i = r*eta**(i)\n",
- " \n",
- " print (str(n_i) + \" \" + str (r_i))\n",
- " \n",
- " # check if leaf node for this s\n",
- " if counter == s:\n",
- " sum_leaf_n_i += n_i\n",
- " counter += 1\n",
- " \n",
- " #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
- " #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
- " #### End Finite Horizon Successive Halving with (n,r)\n",
- "\n",
- "print (\" \")\n",
- "print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
- "print (\"(if have more workers than this, they may not be 100% busy)\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Pretty print diagonal"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 72,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "echo input:\n",
- "max_iter = 81\n",
- "eta = 3\n",
- "s_max = 4\n",
- "B = 5*max_iter = 405\n",
- " \n",
- "initial n, r values for each s:\n",
- "s=4\n",
- "n=81\n",
- "r=1.0\n",
- " \n",
- "s=3\n",
- "n=27\n",
- "r=3.0\n",
- " \n",
- "s=2\n",
- "n=9\n",
- "r=9.0\n",
- " \n",
- "s=1\n",
- "n=6\n",
- "r=27.0\n",
- " \n",
- "s=0\n",
- "n=5\n",
- "r=81\n",
- " \n",
- "outer loop on diagonal:\n",
- " \n",
- "i=0\n",
- "inner loop on s desc:\n",
- "s=4\n",
- "n_i=81\n",
- "r_i=1.0\n",
- " \n",
- "i=1\n",
- "inner loop on s desc:\n",
- "s=4\n",
- "n_i=27.0\n",
- "r_i=3.0\n",
- "s=3\n",
- "n_i=27\n",
- "r_i=3.0\n",
- " \n",
- "i=2\n",
- "inner loop on s desc:\n",
- "s=4\n",
- "n_i=9.0\n",
- "r_i=9.0\n",
- "s=3\n",
- "n_i=9.0\n",
- "r_i=9.0\n",
- "s=2\n",
- "n_i=9\n",
- "r_i=9.0\n",
- " \n",
- "i=3\n",
- "inner loop on s desc:\n",
- "s=4\n",
- "n_i=3.0\n",
- "r_i=27.0\n",
- "s=3\n",
- "n_i=3.0\n",
- "r_i=27.0\n",
- "s=2\n",
- "n_i=3.0\n",
- "r_i=27.0\n",
- "s=1\n",
- "n_i=6\n",
- "r_i=27.0\n",
- " \n",
- "i=4\n",
- "inner loop on s desc:\n",
- "s=4\n",
- "n_i=1.0\n",
- "r_i=81.0\n",
- "s=3\n",
- "n_i=1.0\n",
- "r_i=81.0\n",
- "s=2\n",
- "n_i=1.0\n",
- "r_i=81.0\n",
- "s=1\n",
- "n_i=2.0\n",
- "r_i=81.0\n",
- "s=0\n",
- "n_i=5\n",
- "r_i=81\n"
- ]
- }
- ],
- "source": [
- "import numpy as np\n",
- "from math import log, ceil\n",
- "\n",
- "#input\n",
- "max_iter = 81 # maximum iterations/epochs per configuration\n",
- "eta = 3 # defines downsampling rate (default=3)\n",
- "\n",
- "logeta = lambda x: log(x)/log(eta)\n",
- "s_max = int(logeta(max_iter)) # number of unique executions of Successive Halving (minus one)\n",
- "B = (s_max+1)*max_iter # total number of iterations (without reuse) per execution of Succesive Halving (n,r)\n",
- "\n",
- "#echo output\n",
- "print (\"echo input:\")\n",
- "print (\"max_iter = \" + str(max_iter))\n",
- "print (\"eta = \" + str(eta))\n",
- "print (\"s_max = \" + str(s_max))\n",
- "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
- "\n",
- "print (\" \")\n",
- "print (\"initial n, r values for each s:\")\n",
- "initial_n_vals = {}\n",
- "initial_r_vals = {}\n",
- "# get hyper parameter configs for each s\n",
- "for s in reversed(range(s_max+1)):\n",
- " \n",
- " n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
- " r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
- " \n",
- " initial_n_vals[s] = n \n",
- " initial_r_vals[s] = r \n",
- " \n",
- " print (\"s=\" + str(s))\n",
- " print (\"n=\" + str(n))\n",
- " print (\"r=\" + str(r))\n",
- " print (\" \")\n",
- " \n",
- "print (\"outer loop on diagonal:\")\n",
- "# outer loop on diagonal\n",
- "for i in range(s_max+1):\n",
- " print (\" \")\n",
- " print (\"i=\" + str(i))\n",
- " \n",
- " print (\"inner loop on s desc:\")\n",
- " # inner loop on s desc\n",
- " for s in range(s_max, s_max-i-1, -1):\n",
- " n_i = initial_n_vals[s]*eta**(-i+s_max-s)\n",
- " r_i = initial_r_vals[s]*eta**(i-s_max+s)\n",
- " \n",
- " print (\"s=\" + str(s))\n",
- " print (\"n_i=\" + str(n_i))\n",
- " print (\"r_i=\" + str(r_i))"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "language": "python",
- "name": "python2"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.10"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 1
-}
diff --git a/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb b/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
index 171c9cd..fa92b05 100644
--- a/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
+++ b/community-artifacts/Deep-learning/automl/hyperband_diag_v2_mnist.ipynb
@@ -6,7 +6,7 @@
"source": [
"# Hyperband diagonal using MNIST\n",
"\n",
- "Implemention of Hyperband https://arxiv.org/pdf/1603.06560.pdf for MPP - uses the Hyperband schedule but runs it on a diagonal across brackets, instead of one bracket at a time. \n",
+ "Implemention of Hyperband https://arxiv.org/pdf/1603.06560.pdf for MPP with a synchronous barrier. Uses the Hyperband schedule but runs it on a diagonal across brackets, instead of one bracket at a time, to be more efficient with cluster resources. \n",
"\n",
"Model architecture based on https://keras.io/examples/mnist_transfer_cnn/ \n",
"\n",
@@ -25,12 +25,12 @@
"\n",
"<a href=\"#plot\">6. Plot results</a>\n",
"\n",
- "<a href=\"#print\">7. Print run schedules</a>"
+ "<a href=\"#print\">7. Print run schedules (display only)</a>"
]
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 2,
"metadata": {
"scrolled": true
},
@@ -74,7 +74,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -92,15 +92,15 @@
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
- " <td>MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
+ " <td>MADlib version: 1.17-dev, git revision: rel/v1.16-50-g5abfb79, cmake configuration time: Tue Nov 26 01:00:01 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
- "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
+ "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-50-g5abfb79, cmake configuration time: Tue Nov 26 01:00:01 UTC 2019, build type: release, build system: Linux-3.10.0-1062.4.3.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
]
},
- "execution_count": 5,
+ "execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -121,7 +121,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -180,7 +180,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -200,7 +200,7 @@
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -232,7 +232,7 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -655,7 +655,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -665,29 +665,29 @@
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
- "conv2d_3 (Conv2D) (None, 26, 26, 32) 320 \n",
+ "conv2d_1 (Conv2D) (None, 26, 26, 32) 320 \n",
"_________________________________________________________________\n",
- "activation_5 (Activation) (None, 26, 26, 32) 0 \n",
+ "activation_1 (Activation) (None, 26, 26, 32) 0 \n",
"_________________________________________________________________\n",
- "conv2d_4 (Conv2D) (None, 24, 24, 32) 9248 \n",
+ "conv2d_2 (Conv2D) (None, 24, 24, 32) 9248 \n",
"_________________________________________________________________\n",
- "activation_6 (Activation) (None, 24, 24, 32) 0 \n",
+ "activation_2 (Activation) (None, 24, 24, 32) 0 \n",
"_________________________________________________________________\n",
- "max_pooling2d_2 (MaxPooling2 (None, 12, 12, 32) 0 \n",
+ "max_pooling2d_1 (MaxPooling2 (None, 12, 12, 32) 0 \n",
"_________________________________________________________________\n",
- "dropout_3 (Dropout) (None, 12, 12, 32) 0 \n",
+ "dropout_1 (Dropout) (None, 12, 12, 32) 0 \n",
"_________________________________________________________________\n",
- "flatten_2 (Flatten) (None, 4608) 0 \n",
+ "flatten_1 (Flatten) (None, 4608) 0 \n",
"_________________________________________________________________\n",
- "dense_3 (Dense) (None, 128) 589952 \n",
+ "dense_1 (Dense) (None, 128) 589952 \n",
"_________________________________________________________________\n",
- "activation_7 (Activation) (None, 128) 0 \n",
+ "activation_3 (Activation) (None, 128) 0 \n",
"_________________________________________________________________\n",
- "dropout_4 (Dropout) (None, 128) 0 \n",
+ "dropout_2 (Dropout) (None, 128) 0 \n",
"_________________________________________________________________\n",
- "dense_4 (Dense) (None, 10) 1290 \n",
+ "dense_2 (Dense) (None, 10) 1290 \n",
"_________________________________________________________________\n",
- "activation_8 (Activation) (None, 10) 0 \n",
+ "activation_4 (Activation) (None, 10) 0 \n",
"=================================================================\n",
"Total params: 600,810\n",
"Trainable params: 600,810\n",
@@ -733,7 +733,7 @@
},
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -762,7 +762,7 @@
"[(1, u'feature + classification layers trainable')]"
]
},
- "execution_count": 41,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -782,7 +782,7 @@
"metadata": {},
"source": [
"<a id=\"hyperband\"></a>\n",
- "# 5. Hyperband"
+ "# 5. Hyperband diagonal"
]
},
{
@@ -794,7 +794,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 34,
"metadata": {},
"outputs": [
{
@@ -821,7 +821,7 @@
"[]"
]
},
- "execution_count": 22,
+ "execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
@@ -891,12 +891,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Table names"
+ "Generalize table names"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@@ -919,12 +919,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Hyperband diagonal"
+ "Hyperband diagonal logic"
]
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
@@ -940,15 +940,19 @@
" self.try_params = try_params_function\n",
"\n",
" self.max_iter = 9 # maximum iterations per configuration\n",
- " self.eta = 3 # defines configuration downsampling rate (default = 3)\n",
+ " self.eta = 3 # defines downsampling rate (default = 3)\n",
"\n",
" self.logeta = lambda x: log( x ) / log( self.eta )\n",
" self.s_max = int( self.logeta( self.max_iter ))\n",
" self.B = ( self.s_max + 1 ) * self.max_iter\n",
" self.setup_full_schedule()\n",
" self.create_mst_superset()\n",
+ " \n",
+ " self.best_loss = np.inf\n",
+ " self.best_accuracy = 0.0\n",
+ "\n",
" \n",
- " # create full Hyperband schedule for all brackets\n",
+ " # create full Hyperband schedule for all brackets ahead of time\n",
" def setup_full_schedule(self):\n",
" self.n_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
" self.r_vals = np.zeros((self.s_max+1, self.s_max+1), dtype=int)\n",
@@ -957,7 +961,7 @@
" print (\" \")\n",
" print (\"Hyperband brackets\")\n",
"\n",
- " #### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
+ " # loop through each bracket in reverse order\n",
" for s in reversed(range(self.s_max+1)):\n",
" \n",
" print (\" \")\n",
@@ -970,9 +974,8 @@
" r = self.max_iter*self.eta**(-s) # initial number of iterations to run configurations for\n",
"\n",
" #### Begin Finite Horizon Successive Halving with (n,r)\n",
- " #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
" for i in range(s+1):\n",
- " # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
+ " # n_i configs for r_i iterations\n",
" n_i = n*self.eta**(-i)\n",
" r_i = r*self.eta**(i)\n",
"\n",
@@ -985,15 +988,14 @@
" if counter == s:\n",
" sum_leaf_n_i += n_i\n",
" counter += 1\n",
- "\n",
- " #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
- " #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
+ " \n",
" #### End Finite Horizon Successive Halving with (n,r)\n",
"\n",
" #print (\" \")\n",
" #print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
" #print (\"(if have more workers than this, they may not be 100% busy)\")\n",
" \n",
+ " \n",
" # generate model selection tuples for all brackets\n",
" def create_mst_superset(self):\n",
" # get hyper parameter configs for each bracket s\n",
@@ -1001,7 +1003,6 @@
" n = int(ceil(int(self.B/self.max_iter/(s+1))*self.eta**s)) # initial number of configurations\n",
" r = self.max_iter*self.eta**(-s) # initial number of iterations to run configurations for\n",
"\n",
- " \n",
" print (\" \")\n",
" print (\"Create superset of MSTs, i.e., i=0 for for each bracket s\")\n",
" print (\" \")\n",
@@ -1014,10 +1015,13 @@
" self.get_params(n, s)\n",
" \n",
" \n",
- " # run Hyperband diagonal logic\n",
- " # can be called multiple times\n",
- " def run( self, skip_last = 0, dry_run = False ): \n",
+ " # Hyperband diagonal logic\n",
+ " def run( self, skip_last = 0, dry_run = False ): \n",
+ " \n",
+ " print (\" \")\n",
+ " print (\"Hyperband diagonal\")\n",
" print (\"outer loop on diagonal:\")\n",
+ " \n",
" # outer loop on diagonal\n",
" for i in range(self.s_max+1):\n",
" print (\" \")\n",
@@ -1032,20 +1036,24 @@
"\n",
" # build up mst table for diagonal\n",
" %sql INSERT INTO $mst_diag_table (SELECT * FROM $mst_table WHERE s=$s);\n",
+ " \n",
+ " # first pass\n",
+ " if i == 0:\n",
+ " first_pass = True\n",
+ " else:\n",
+ " first_pass = False\n",
" \n",
" # multi-model training\n",
+ " print (\" \")\n",
" print (\"try params for i = \" + str(i))\n",
- " U = self.try_params(i, self.r_vals[self.s_max][i]) # r_i is the same for all diagonal elements\n",
- "\n",
- " # select a number of best configurations for the next loop\n",
- " # filter out early stops, if any\n",
+ " U = self.try_params(i, self.r_vals[self.s_max][i], first_pass) # r_i is the same for all diagonal elements\n",
" \n",
" # loop on brackets s desc to prune model selection table\n",
" # don't need to prune if finished last diagonal\n",
" if i < self.s_max:\n",
" print (\"loop on s desc to prune mst table:\")\n",
" for s in range(self.s_max, self.s_max-i-1, -1):\n",
- " \n",
+ " \n",
" # compute number of configs to keep\n",
" # remember i value is different for each bracket s on the diagonal\n",
" k = int( self.n_vals[s][s-self.s_max+i] / self.eta)\n",
@@ -1071,14 +1079,36 @@
" \"\"\".format(**locals())\n",
" cur.execute(query)\n",
" conn.commit()\n",
+ " \n",
+ " # these were not working so used cursor instead\n",
" #%sql DELETE FROM $mst_table WHERE s=$s AND mst_key NOT IN (SELECT $output_table_info.mst_key FROM $output_table_info JOIN $mst_table ON $output_table_info.mst_key=$mst_table.mst_key WHERE s=$s ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
" #%sql DELETE FROM mst_table_hb_mnist WHERE s=1 AND mst_key NOT IN (SELECT mnist_multi_model_info.mst_key FROM mnist_multi_model_info JOIN mst_table_hb_mnist ON mnist_multi_model_info.mst_key=mst_table_hb_mnist.mst_key WHERE s=1 ORDER BY validation_loss_final ASC LIMIT 1);\n",
+ " \n",
+ " # keep track of best loss so far (for display purposes only)\n",
+ " loss = %sql SELECT validation_loss_final FROM $output_table_info ORDER BY validation_loss_final ASC LIMIT 1;\n",
+ " accuracy = %sql SELECT validation_metrics_final FROM $output_table_info ORDER BY validation_loss_final ASC LIMIT 1;\n",
+ " \n",
+ " if loss < self.best_loss:\n",
+ " self.best_loss = loss\n",
+ " self.best_accuracy = accuracy\n",
+ " \n",
+ " print (\" \")\n",
+ " print (\"best validation loss so far = \" + str(loss))\n",
+ " print (\"best validation accuracy so far = \" + str(accuracy))\n",
+ " \n",
" return"
]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate params and insert into MST table"
+ ]
+ },
+ {
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
@@ -1104,7 +1134,7 @@
"\n",
" # fit params\n",
" # batch size\n",
- " batch_size = [64, 128]\n",
+ " batch_size = [32, 64, 128]\n",
" # epochs\n",
" epochs = [1]\n",
"\n",
@@ -1133,21 +1163,32 @@
]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Run model hopper for candidates in MST table"
+ ]
+ },
+ {
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
- "def try_params(i, r):\n",
+ "def try_params(i, r, first_pass):\n",
" \n",
" # multi-model fit\n",
- " # TO DO: use warm start to continue from where left off after if not 1st time thru for this s value\n",
- " %sql DROP TABLE IF EXISTS $output_table, $output_table_summary, $output_table_info;\n",
- " \n",
- " # passing vars as madlib args does not seem to work\n",
- " #%sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', $output_table, $mst_diag_table, $r_i::INT, 0);\n",
- " %sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', 'mnist_multi_model', 'mst_diag_table_hb_mnist', $r::INT, 0, 'test_mnist_packed');\n",
- " \n",
+ " if first_pass:\n",
+ " # cold start\n",
+ " %sql DROP TABLE IF EXISTS $output_table, $output_table_summary, $output_table_info;\n",
+ " # passing vars as madlib args does not seem to work\n",
+ " #%sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', $output_table, $mst_diag_table, $r_i::INT, 0);\n",
+ " %sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', 'mnist_multi_model', 'mst_diag_table_hb_mnist', $r::INT, FALSE, 'test_mnist_packed');\n",
+ "\n",
+ " else:\n",
+ " # warm start to continue from previous run\n",
+ " %sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', 'mnist_multi_model', 'mst_diag_table_hb_mnist', $r::INT, FALSE, 'test_mnist_packed', NULL, True);\n",
+ "\n",
" # save results via temp table\n",
" # add everything from info table\n",
" %sql DROP TABLE IF EXISTS temp_results;\n",
@@ -1167,8 +1208,15 @@
]
},
{
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Call Hyperband diagonal"
+ ]
+ },
+ {
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": null,
"metadata": {
"scrolled": false
},
@@ -1233,61 +1281,17 @@
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
+ " \n",
+ "Hyperband diagonal\n",
"outer loop on diagonal:\n",
" \n",
"i=0\n",
"Done.\n",
"loop on s desc to create diagonal table:\n",
"9 rows affected.\n",
- "try params for i = 0\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "9 rows affected.\n",
- "Done.\n",
- "9 rows affected.\n",
- "9 rows affected.\n",
- "9 rows affected.\n",
- "loop on s desc to prune mst table:\n",
- "pruning s = 2 with k = 3\n",
- " \n",
- "i=1\n",
- "Done.\n",
- "loop on s desc to create diagonal table:\n",
- "3 rows affected.\n",
- "3 rows affected.\n",
- "try params for i = 1\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "6 rows affected.\n",
- "Done.\n",
- "6 rows affected.\n",
- "6 rows affected.\n",
- "6 rows affected.\n",
- "loop on s desc to prune mst table:\n",
- "pruning s = 2 with k = 1\n",
- "pruning s = 1 with k = 1\n",
" \n",
- "i=2\n",
- "Done.\n",
- "loop on s desc to create diagonal table:\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "3 rows affected.\n",
- "try params for i = 2\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "5 rows affected.\n",
- "Done.\n",
- "5 rows affected.\n",
- "5 rows affected.\n",
- "5 rows affected.\n",
- "loop on s desc to prune mst table:\n",
- "pruning s = 2 with k = 0\n",
- "pruning s = 1 with k = 0\n",
- "pruning s = 0 with k = 1\n"
+ "try params for i = 0\n",
+ "Done.\n"
]
}
],
@@ -1307,7 +1311,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@@ -1331,14 +1335,14 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "12 rows affected.\n"
+ "8 rows affected.\n"
]
},
{
@@ -2124,7 +2128,7 @@
{
"data": {
"text/html": [
- "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydB3hb5dn+b0m2vDOc2Fl2diBAWVnsXSBQRinQSQu00EJ3oS39GIE2FMr3tdD2T1tmW7poWWXPllF2BrthZMcZjp3YTjwlW9L/ul/52LItyUfS0bB9v9fly4n1nnf83mPr0X2e4QqFQiGoiYAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIiIAIpJWAS0JcWvlqcBEQAREQAREQAREQAREQAREQAREQAREQAREwBCTE6UYQgSFO4KyzzsL999+Pb3zjG7j55psd3c2CBQuwcuVK/N///R++//3vOzq2BnOOwHvvvYd9993XDFhfX4/x48c7N7hGEgEREAEREAERyBgB2XUDUbe0tKCsrMy8sHz5ctA+VRMBERCBoUxAQtxQPj2tPW0EXC5X0mP [...]
+ "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABEwAAAImCAYAAABJvh+8AAAgAElEQVR4XuzdC5xVVfn/8WcGHEEDVEQFQchLKgKapqWCifzwbirmL/9mCt7wlkl44ScJeAEKvBGpQCKYJlpA5j3xggUqppkiWCZeQMUyQC4iyGX+r++yM80MM3PWOWcO59nnfPbrxQuZWfvstd5rjeuZZ6+9dlllZWWlcSCAAAIIIIAAAggggAACCCCAAAIIVAmUkTBhNCCAAAIIIIAAAggggAACCCCAAAI1BUiYMCISKVBWVpZxvc866yybPHlyxudlcsJ3v/tdmzZtmv32t781/Xeuxy9+8Qv74Q9/aBdffLHpvzk2j8Dll19uN910k40ePdr03xwIIIAAAghkIkCckolWfsu+8cYb1rVrV9tnn31M/82BAAIIZCJAwiQTLcq6Eejbt+8mdfn444/ [...]
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -2144,10 +2148,6 @@
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
"1 rows affected.\n"
]
}
@@ -2197,14 +2197,14 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "12 rows affected.\n"
+ "8 rows affected.\n"
]
},
{
@@ -2990,7 +2990,7 @@
{
"data": {
"text/html": [
- "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABOIAAAJxCAYAAADvpB2RAAAgAElEQVR4XuydCXhU1fnG31myrwQStgQQVMAN2aKAgKBW3Dds3dFqFWtrq/6rbVWkxaV0cWmtxbVq1boUq3VHZUchgAiiICoCCRAC2deZJJP/857kJpPJLPfO3FmSfOd58iQk5557zu+cYb5577dYWlpaWiBNCAgBISAEhIAQEAJCQAgIASEgBISAEBACQkAICIGwErCIEBdWvjK4EBACQkAICAEhIASEgBAQAkJACAgBISAEhIAQUAREiJODIAQiSOD999/HGWecgYSEBDQ0NHS6s7+/6ZliqNfruYe/PosWLcKNN96IkSNHYvv27aEOJ9eHiYDsU5jAyrBCQAgIASHQ6wiIXRd7Wx5tezj2iMiMhIAQiEUCIsTF4q7InEwn8JOf/ARPPfUUsrKysG/ [...]
+ "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABEwAAAImCAYAAABJvh+8AAAgAElEQVR4XuzdCZhUxdX/8cPiCCqgIioIQhQTRRajYuKCEXlxI0TF+Ma/cQE3FI2R4MKrCYsKJqBRQoxCRDAalwTRuEdcIIKK0Sgiaoy4QBQTRWQRQZb5P78yTXqGmenq7mn63O7vfR4fBOr2rfpUDXX63Lp1G1RWVlYaBwIIIIAAAggggAACCCCAAAIIIIDARoEGJEwYDQgggAACCCCAAAIIIIAAAggggEBVARImjAhXAqeeeqrdcccddvrpp9uUKVMy1u2CCy6wG2+80Y477ji77777MpavrcAtt9xiZ599tp155pmm/08db7/9tu2xxx62++67m/4/9li3bp1tscUW1qhRI9P/b47jkEMOsdmzZ9szzzxj+n+OwgvUNm4Kf2WugAACCCBQDAHilNz [...]
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -3010,10 +3010,6 @@
"1 rows affected.\n",
"1 rows affected.\n",
"1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
"1 rows affected.\n"
]
}
@@ -3059,7 +3055,7 @@
"metadata": {},
"source": [
"<a id=\"print\"></a>\n",
- "# 7. Print run schedules"
+ "# 7. Print run schedules (display only)"
]
},
{
@@ -3072,7 +3068,9 @@
{
"cell_type": "code",
"execution_count": 32,
- "metadata": {},
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [
{
"name": "stdout",
diff --git a/community-artifacts/Deep-learning/automl/hyperband_v0.ipynb b/community-artifacts/Deep-learning/automl/hyperband_v0.ipynb
deleted file mode 100644
index 4cf7293..0000000
--- a/community-artifacts/Deep-learning/automl/hyperband_v0.ipynb
+++ /dev/null
@@ -1,259 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
- " \"You should import from traitlets.config instead.\", ShimWarning)\n",
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
- " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
- ]
- }
- ],
- "source": [
- "%load_ext sql"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "u'Connected: fmcquillan@madlib'"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Greenplum Database 5.x on GCP (PM demo machine) - direct external IP access\n",
- "#%sql postgresql://gpadmin@34.67.65.96:5432/madlib\n",
- "\n",
- "# Greenplum Database 5.x on GCP - via tunnel\n",
- "#%sql postgresql://gpadmin@localhost:8000/madlib\n",
- " \n",
- "# PostgreSQL local\n",
- "%sql postgresql://fmcquillan@localhost:5432/madlib"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 46,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "from random import random\n",
- "from math import log, ceil\n",
- "from time import time, ctime\n",
- "\n",
- "\n",
- "class Hyperband:\n",
- "\n",
- "\tdef __init__( self, get_params_function, try_params_function ):\n",
- "\t\tself.get_params = get_params_function\n",
- "\t\tself.try_params = try_params_function\n",
- "\n",
- "\t\tself.max_iter = 27 \t# maximum iterations per configuration\n",
- "\t\tself.eta = 3\t\t\t# defines configuration downsampling rate (default = 3)\n",
- "\n",
- "\t\tself.logeta = lambda x: log( x ) / log( self.eta )\n",
- "\t\tself.s_max = int( self.logeta( self.max_iter ))\n",
- "\t\tself.B = ( self.s_max + 1 ) * self.max_iter\n",
- "\n",
- "\t\tself.results = []\t# list of dicts\n",
- "\t\tself.counter = 0\n",
- "\t\tself.best_loss = np.inf\n",
- "\t\tself.best_counter = -1\n",
- "\n",
- "\n",
- "\t# can be called multiple times\n",
- "\tdef run( self, skip_last = 0, dry_run = False ):\n",
- "\n",
- "\t\tfor s in reversed( range( self.s_max + 1 )):\n",
- " \n",
- "\t\t\tprint (\" \") \n",
- "\t\t\tprint (\"s = \", s)\n",
- "\n",
- "\t\t\t# initial number of configurations\n",
- "\t\t\tn = int( ceil( self.B / self.max_iter / ( s + 1 ) * self.eta ** s ))\n",
- "\n",
- "\t\t\t# initial number of iterations per config\n",
- "\t\t\tr = self.max_iter * self.eta ** ( -s )\n",
- "\n",
- "\t\t\t# n random configurations\n",
- "\t\t\tT = [ self.get_params() for i in range( n )]\n",
- "\n",
- "\t\t\tfor i in range(( s + 1 ) - int( skip_last )):\t# changed from s + 1\n",
- "\n",
- "\t\t\t\t# Run each of the n configs for <iterations>\n",
- "\t\t\t\t# and keep best (n_configs / eta) configurations\n",
- "\n",
- "\t\t\t\tn_configs = n * self.eta ** ( -i )\n",
- "\t\t\t\tn_iterations = r * self.eta ** ( i )\n",
- "\n",
- "\t\t\t\tprint \"\\n*** {} configurations x {:.1f} iterations each\".format(\n",
- "\t\t\t\t\tn_configs, n_iterations )\n",
- "\n",
- "\t\t\t\tval_losses = []\n",
- "\t\t\t\tearly_stops = []\n",
- "\n",
- "\t\t\t\tfor t in T:\n",
- "\n",
- "\t\t\t\t\tself.counter += 1\n",
- "\t\t\t\t\t#print \"\\n{} | {} | lowest loss so far: {:.4f} (run {})\\n\".format(\n",
- "\t\t\t\t\t#\tself.counter, ctime(), self.best_loss, self.best_counter )\n",
- "\n",
- "\t\t\t\t\tstart_time = time()\n",
- "\n",
- "\t\t\t\t\tif dry_run:\n",
- "\t\t\t\t\t\tresult = { 'loss': random(), 'log_loss': random(), 'auc': random()}\n",
- "\t\t\t\t\telse:\n",
- "\t\t\t\t\t\tresult = self.try_params( n_iterations, t )\t\t# <---\n",
- "\n",
- "\t\t\t\t\tassert( type( result ) == dict )\n",
- "\t\t\t\t\tassert( 'loss' in result )\n",
- "\n",
- "\t\t\t\t\tseconds = int( round( time() - start_time ))\n",
- "\t\t\t\t\t#print \"\\n{} seconds.\".format( seconds )\n",
- "\n",
- "\t\t\t\t\tloss = result['loss']\n",
- "\t\t\t\t\tval_losses.append( loss )\n",
- "\n",
- "\t\t\t\t\tearly_stop = result.get( 'early_stop', False )\n",
- "\t\t\t\t\tearly_stops.append( early_stop )\n",
- "\n",
- "\t\t\t\t\t# keeping track of the best result so far (for display only)\n",
- "\t\t\t\t\t# could do it be checking results each time, but hey\n",
- "\t\t\t\t\tif loss < self.best_loss:\n",
- "\t\t\t\t\t\tself.best_loss = loss\n",
- "\t\t\t\t\t\tself.best_counter = self.counter\n",
- "\n",
- "\t\t\t\t\tresult['counter'] = self.counter\n",
- "\t\t\t\t\tresult['seconds'] = seconds\n",
- "\t\t\t\t\tresult['params'] = t\n",
- "\t\t\t\t\tresult['iterations'] = n_iterations\n",
- "\n",
- "\t\t\t\t\tself.results.append( result )\n",
- "\n",
- "\t\t\t\t# select a number of best configurations for the next loop\n",
- "\t\t\t\t# filter out early stops, if any\n",
- "\t\t\t\tindices = np.argsort( val_losses )\n",
- "\t\t\t\tT = [ T[i] for i in indices if not early_stops[i]]\n",
- "\t\t\t\tT = T[ 0:int( n_configs / self.eta )]\n",
- "\n",
- "\t\treturn self.results\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_params():\n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "def try_params():\n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 47,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " \n",
- "('s = ', 3)\n",
- "\n",
- "*** 27 configurations x 1.0 iterations each\n",
- "\n",
- "*** 9.0 configurations x 3.0 iterations each\n",
- "\n",
- "*** 3.0 configurations x 9.0 iterations each\n",
- "\n",
- "*** 1.0 configurations x 27.0 iterations each\n",
- " \n",
- "('s = ', 2)\n",
- "\n",
- "*** 9 configurations x 3.0 iterations each\n",
- "\n",
- "*** 3.0 configurations x 9.0 iterations each\n",
- "\n",
- "*** 1.0 configurations x 27.0 iterations each\n",
- " \n",
- "('s = ', 1)\n",
- "\n",
- "*** 6 configurations x 9.0 iterations each\n",
- "\n",
- "*** 2.0 configurations x 27.0 iterations each\n",
- " \n",
- "('s = ', 0)\n",
- "\n",
- "*** 4 configurations x 27.0 iterations each\n"
- ]
- }
- ],
- "source": [
- "#!/usr/bin/env python\n",
- "\n",
- "\"bare-bones demonstration of using hyperband to tune sklearn GBT\"\n",
- "\n",
- "#from hyperband import Hyperband\n",
- "#from defs.gb import get_params, try_params\n",
- "\n",
- "hb = Hyperband( get_params, try_params )\n",
- "\n",
- "# no actual tuning, doesn't call try_params()\n",
- "results = hb.run( dry_run = True )\n",
- "\n",
- "#results = hb.run( skip_last = 1 ) # shorter run\n",
- "#results = hb.run()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "language": "python",
- "name": "python2"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.10"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/community-artifacts/Deep-learning/automl/hyperband_v1.ipynb b/community-artifacts/Deep-learning/automl/hyperband_v1.ipynb
deleted file mode 100644
index 106fd45..0000000
--- a/community-artifacts/Deep-learning/automl/hyperband_v1.ipynb
+++ /dev/null
@@ -1,3424 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Hyperband\n",
- "\n",
- "Impelementation of Hyperband https://arxiv.org/pdf/1603.06560.pdf"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
- " \"You should import from traitlets.config instead.\", ShimWarning)\n",
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
- " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
- ]
- }
- ],
- "source": [
- "%load_ext sql"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "u'Connected: gpadmin@madlib'"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Greenplum Database 5.x on GCP (PM demo machine) - direct external IP access\n",
- "#%sql postgresql://gpadmin@34.67.65.96:5432/madlib\n",
- "\n",
- "# Greenplum Database 5.x on GCP - via tunnel\n",
- "%sql postgresql://gpadmin@localhost:8000/madlib\n",
- " \n",
- "# PostgreSQL local\n",
- "#%sql postgresql://fmcquillan@localhost:5432/madlib"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done.\n",
- "Done.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "[]"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%%sql\n",
- "DROP TABLE IF EXISTS results;\n",
- "\n",
- "CREATE TABLE results ( \n",
- " model_id INTEGER, \n",
- " compile_params TEXT,\n",
- " fit_params TEXT, \n",
- " model_type TEXT, \n",
- " model_size DOUBLE PRECISION, \n",
- " metrics_elapsed_time DOUBLE PRECISION[], \n",
- " metrics_type TEXT[], \n",
- " training_metrics_final DOUBLE PRECISION, \n",
- " training_loss_final DOUBLE PRECISION, \n",
- " training_metrics DOUBLE PRECISION[], \n",
- " training_loss DOUBLE PRECISION[], \n",
- " validation_metrics_final DOUBLE PRECISION, \n",
- " validation_loss_final DOUBLE PRECISION, \n",
- " validation_metrics DOUBLE PRECISION[], \n",
- " validation_loss DOUBLE PRECISION[], \n",
- " model_arch_table TEXT, \n",
- " num_iterations INTEGER, \n",
- " start_training_time TIMESTAMP, \n",
- " end_training_time TIMESTAMP,\n",
- " s INTEGER, \n",
- " n INTEGER, \n",
- " r INTEGER\n",
- " );"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "from random import random\n",
- "from math import log, ceil\n",
- "from time import time, ctime\n",
- "\n",
- "\n",
- "class Hyperband:\n",
- " \n",
- " def __init__( self, get_params_function, try_params_function ):\n",
- " self.get_params = get_params_function\n",
- " self.try_params = try_params_function\n",
- "\n",
- " #self.max_iter = 81 # maximum iterations per configuration\n",
- " self.max_iter = 9 # maximum iterations per configuration\n",
- " self.eta = 3 # defines configuration downsampling rate (default = 3)\n",
- "\n",
- " self.logeta = lambda x: log( x ) / log( self.eta )\n",
- " self.s_max = int( self.logeta( self.max_iter ))\n",
- " self.B = ( self.s_max + 1 ) * self.max_iter\n",
- "\n",
- " self.results = [] # list of dicts\n",
- " self.counter = 0\n",
- " self.best_loss = np.inf\n",
- " self.best_counter = -1\n",
- "\n",
- " # can be called multiple times\n",
- " def run( self, skip_last = 0, dry_run = False ):\n",
- "\n",
- " for s in reversed( range( self.s_max + 1 )):\n",
- " \n",
- " # initial number of configurations\n",
- " n = int( ceil( self.B / self.max_iter / ( s + 1 ) * self.eta ** s ))\n",
- "\n",
- " # initial number of iterations per config\n",
- " r = self.max_iter * self.eta ** ( -s )\n",
- " \n",
- " print (\"s = \", s)\n",
- " print (\"n = \", n)\n",
- " print (\"r = \", r)\n",
- "\n",
- " # n random configurations\n",
- " T = self.get_params(n) # what to return from function if anything?\n",
- " \n",
- " for i in range(( s + 1 ) - int( skip_last )): # changed from s + 1\n",
- "\n",
- " # Run each of the n configs for <iterations>\n",
- " # and keep best (n_configs / eta) configurations\n",
- "\n",
- " n_configs = n * self.eta ** ( -i )\n",
- " n_iterations = r * self.eta ** ( i )\n",
- "\n",
- " print \"\\n*** {} configurations x {:.1f} iterations each\".format(\n",
- " n_configs, n_iterations )\n",
- " \n",
- " # multi-model training\n",
- " U = self.try_params(s, n_configs, n_iterations) # what to return from function if anything?\n",
- "\n",
- " # select a number of best configurations for the next loop\n",
- " # filter out early stops, if any\n",
- " k = int( n_configs / self.eta)\n",
- " %sql DELETE FROM mst_table_hb WHERE mst_key NOT IN (SELECT mst_key from iris_multi_model_info ORDER BY training_loss_final ASC LIMIT $k::INT);\n",
- " \n",
- " #return self.results\n",
- " \n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_params(n):\n",
- " \n",
- " from sklearn.model_selection import ParameterSampler\n",
- " from scipy.stats.distributions import expon, uniform, lognorm\n",
- " import numpy as np\n",
- " \n",
- " # model architecture\n",
- " model_id = [1, 2]\n",
- "\n",
- " # compile params\n",
- " # loss function\n",
- " loss = ['categorical_crossentropy']\n",
- " # optimizer\n",
- " optimizer = ['Adam', 'SGD']\n",
- " # learning rate\n",
- " lr = [0.01, 0.1]\n",
- " # metrics\n",
- " metrics = ['accuracy']\n",
- "\n",
- " # fit params\n",
- " # batch size\n",
- " batch_size = [4, 8]\n",
- " # epochs\n",
- " epochs = [1]\n",
- "\n",
- " # create random param list\n",
- " param_grid = {\n",
- " 'model_id': model_id,\n",
- " 'loss': loss,\n",
- " 'optimizer': optimizer,\n",
- " 'lr': uniform(lr[0], lr[1]),\n",
- " 'metrics': metrics,\n",
- " 'batch_size': batch_size,\n",
- " 'epochs': epochs\n",
- " }\n",
- " param_list = list(ParameterSampler(param_grid, n_iter=n))\n",
- " \n",
- " import psycopg2 as p2\n",
- "\n",
- " #conn = p2.connect('postgresql://gpadmin@35.239.240.26:5432/madlib')\n",
- " #conn = p2.connect('postgresql://fmcquillan@localhost:5432/madlib')\n",
- " conn = p2.connect('postgresql://gpadmin@localhost:8000/madlib')\n",
- " cur = conn.cursor()\n",
- "\n",
- " %sql DROP TABLE IF EXISTS mst_table_hb, mst_table_auto_hb;\n",
- "\n",
- " %sql CREATE TABLE mst_table_hb(mst_key serial, model_id integer, compile_params varchar, fit_params varchar);\n",
- "\n",
- " for params in param_list:\n",
- "\n",
- " model_id = str(params.get(\"model_id\"))\n",
- " compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
- " fit_params = \"$$batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\")) + \"$$\" \n",
- " row_content = \"(\" + model_id + \", \" + compile_params + \", \" + fit_params + \");\"\n",
- " \n",
- " %sql INSERT INTO mst_table_hb (model_id, compile_params, fit_params) VALUES $row_content\n",
- " \n",
- " %sql DROP TABLE IF EXISTS mst_table_hb_summary;\n",
- " %sql CREATE TABLE mst_table_hb_summary (model_arch_table varchar);\n",
- " %sql INSERT INTO mst_table_hb_summary VALUES ('model_arch_library');\n",
- " \n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "def try_params(s, n_configs, n_iterations):\n",
- " \n",
- " print (\"s = \", s)\n",
- " print (\"n_configs aka n = \", n_configs)\n",
- " print (\"n_iterations aka r = \", n_iterations)\n",
- " \n",
- " import psycopg2 as p2\n",
- "\n",
- " #conn = p2.connect('postgresql://gpadmin@35.239.240.26:5432/madlib')\n",
- " #conn = p2.connect('postgresql://fmcquillan@localhost:5432/madlib')\n",
- " conn = p2.connect('postgresql://gpadmin@localhost:8000/madlib')\n",
- " cur = conn.cursor()\n",
- "\n",
- " # multi-model fit\n",
- " # TO DO: use warm start to continue from where left off after if not 1st time thru for this s value\n",
- " %sql DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;\n",
- " %sql SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed', 'iris_multi_model', 'mst_table_hb', $n_iterations::INT, 0);\n",
- " \n",
- " # save results\n",
- " %sql DROP TABLE IF EXISTS temp_results;\n",
- " %sql CREATE TABLE temp_results AS (SELECT * FROM iris_multi_model_info);\n",
- " %sql ALTER TABLE temp_results DROP COLUMN mst_key, ADD COLUMN model_arch_table TEXT, ADD COLUMN num_iterations INTEGER, ADD COLUMN start_training_time TIMESTAMP, ADD COLUMN end_training_time TIMESTAMP, ADD COLUMN s INTEGER, ADD COLUMN n INTEGER, ADD COLUMN r INTEGER;\n",
- " %sql UPDATE temp_results SET model_arch_table = (SELECT model_arch_table FROM iris_multi_model_summary), num_iterations = (SELECT num_iterations FROM iris_multi_model_summary), start_training_time = (SELECT start_training_time FROM iris_multi_model_summary), end_training_time = (SELECT end_training_time FROM iris_multi_model_summary), s = $s, n = $n_configs, r = $n_iterations;\n",
- " %sql INSERT INTO results (SELECT * FROM temp_results);\n",
- "\n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "def top_k(k):\n",
- " \n",
- " print (\"k = \", k)\n",
- " %sql DELETE FROM mst_table_hb WHERE mst_key NOT IN (SELECT mst_key from iris_multi_model_info ORDER BY training_loss_final ASC LIMIT $k::INT);\n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n"
- ]
- }
- ],
- "source": [
- "get_params(3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "1 rows affected.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<table>\n",
- " <tr>\n",
- " <th>madlib_keras_fit_multiple_model</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <td></td>\n",
- " </tr>\n",
- "</table>"
- ],
- "text/plain": [
- "[('',)]"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%%sql\n",
- "SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed', 'iris_multi_model', 'mst_table_hb', 3.0::INT, 0);"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "('s = ', 2)\n",
- "('n = ', 9)\n",
- "('r = ', 1.0)\n",
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n",
- "\n",
- "*** 9 configurations x 1.0 iterations each\n",
- "('s = ', 2)\n",
- "('n_configs aka n = ', 9)\n",
- "('n_iterations aka r = ', 1.0)\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "9 rows affected.\n",
- "Done.\n",
- "9 rows affected.\n",
- "9 rows affected.\n",
- "6 rows affected.\n",
- "\n",
- "*** 3.0 configurations x 3.0 iterations each\n",
- "('s = ', 2)\n",
- "('n_configs aka n = ', 3.0)\n",
- "('n_iterations aka r = ', 3.0)\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "3 rows affected.\n",
- "Done.\n",
- "3 rows affected.\n",
- "3 rows affected.\n",
- "2 rows affected.\n",
- "\n",
- "*** 1.0 configurations x 9.0 iterations each\n",
- "('s = ', 2)\n",
- "('n_configs aka n = ', 1.0)\n",
- "('n_iterations aka r = ', 9.0)\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "('s = ', 1)\n",
- "('n = ', 3)\n",
- "('r = ', 3.0)\n",
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n",
- "\n",
- "*** 3 configurations x 3.0 iterations each\n",
- "('s = ', 1)\n",
- "('n_configs aka n = ', 3)\n",
- "('n_iterations aka r = ', 3.0)\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "3 rows affected.\n",
- "Done.\n",
- "3 rows affected.\n",
- "3 rows affected.\n",
- "2 rows affected.\n",
- "\n",
- "*** 1.0 configurations x 9.0 iterations each\n",
- "('s = ', 1)\n",
- "('n_configs aka n = ', 1.0)\n",
- "('n_iterations aka r = ', 9.0)\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "('s = ', 0)\n",
- "('n = ', 3)\n",
- "('r = ', 9)\n",
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n",
- "\n",
- "*** 3 configurations x 9.0 iterations each\n",
- "('s = ', 0)\n",
- "('n_configs aka n = ', 3)\n",
- "('n_iterations aka r = ', 9)\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "3 rows affected.\n",
- "Done.\n",
- "3 rows affected.\n",
- "3 rows affected.\n",
- "2 rows affected.\n"
- ]
- }
- ],
- "source": [
- "hp = Hyperband( get_params, try_params )\n",
- "results = hp.run()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "('s = ', 4)\n",
- "('n = ', 81)\n",
- "('r = ', 1.0)\n",
- "\n",
- "*** 81 configurations x 1.0 iterations each\n",
- "\n",
- "1 | Mon Nov 4 11:31:06 2019 | lowest loss so far: inf (run -1)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "2 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.8345 (run 1)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "3 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.6510 (run 2)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "4 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "5 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "6 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "7 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "8 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "9 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "10 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "11 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "12 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "13 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "14 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "15 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "16 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "17 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "18 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "19 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "20 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "21 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "22 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "23 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "24 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "25 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "26 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "27 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "28 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "29 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "30 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "31 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "32 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "33 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "34 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "35 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "36 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "37 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "38 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "39 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "40 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "41 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "42 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "43 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "44 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "45 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "46 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "47 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "48 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "49 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "50 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "51 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "52 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "53 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "54 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "55 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "56 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "57 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "58 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "59 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "60 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "61 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "62 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "63 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "64 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "65 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "66 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "67 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "68 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "69 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "70 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "71 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "72 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "73 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "74 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "75 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "76 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "77 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "78 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "79 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "80 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "81 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 27.0 configurations x 3.0 iterations each\n",
- "\n",
- "82 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "83 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "84 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "85 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "86 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "87 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "88 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "89 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "90 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "91 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "92 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "93 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "94 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "95 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "96 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "97 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "98 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "99 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "100 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "101 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "102 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "103 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "104 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "105 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "106 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "107 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "108 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 9.0 configurations x 9.0 iterations each\n",
- "\n",
- "109 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "110 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "111 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "112 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "113 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "114 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "115 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "116 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "117 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 3.0 configurations x 27.0 iterations each\n",
- "\n",
- "118 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "119 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "120 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 1.0 configurations x 81.0 iterations each\n",
- "\n",
- "121 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "('s = ', 3)\n",
- "('n = ', 27)\n",
- "('r = ', 3.0)\n",
- "\n",
- "*** 27 configurations x 3.0 iterations each\n",
- "\n",
- "122 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "123 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "124 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "125 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "126 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "127 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "128 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "129 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "130 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "131 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "132 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "133 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "134 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "135 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "136 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "137 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "138 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "139 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "140 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "141 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "142 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "143 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "144 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "145 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "146 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "147 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "148 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 9.0 configurations x 9.0 iterations each\n",
- "\n",
- "149 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "150 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "151 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "152 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "153 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "154 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "155 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "156 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "157 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 3.0 configurations x 27.0 iterations each\n",
- "\n",
- "158 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "159 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "160 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 1.0 configurations x 81.0 iterations each\n",
- "\n",
- "161 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "('s = ', 2)\n",
- "('n = ', 9)\n",
- "('r = ', 9.0)\n",
- "\n",
- "*** 9 configurations x 9.0 iterations each\n",
- "\n",
- "162 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "163 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "164 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "165 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "166 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "167 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "168 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "169 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "170 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 3.0 configurations x 27.0 iterations each\n",
- "\n",
- "171 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "172 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "173 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 1.0 configurations x 81.0 iterations each\n",
- "\n",
- "174 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "('s = ', 1)\n",
- "('n = ', 6)\n",
- "('r = ', 27.0)\n",
- "\n",
- "*** 6 configurations x 27.0 iterations each\n",
- "\n",
- "175 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "176 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "177 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "178 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "179 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "180 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 2.0 configurations x 81.0 iterations each\n",
- "\n",
- "181 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "182 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "('s = ', 0)\n",
- "('n = ', 5)\n",
- "('r = ', 81)\n",
- "\n",
- "*** 5 configurations x 81.0 iterations each\n",
- "\n",
- "183 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "184 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "185 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "186 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "187 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n"
- ]
- }
- ],
- "source": [
- "#!/usr/bin/env python\n",
- "\n",
- "\"bare-bones demonstration of using hyperband to tune sklearn GBT\"\n",
- "\n",
- "#from hyperband import Hyperband\n",
- "#from defs.gb import get_params, try_params\n",
- "\n",
- "hb = Hyperband( get_params, try_params )\n",
- "\n",
- "# no actual tuning, doesn't call try_params()\n",
- "results = hb.run( dry_run = True )\n",
- "\n",
- "#results = hb.run( skip_last = 1 ) # shorter run\n",
- "#results = hb.run()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'auc': 0.14932365125588232,\n",
- " 'counter': 1,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.3449042743689773,\n",
- " 'loss': 0.09612127946443949,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5297427251128467,\n",
- " 'counter': 2,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.6810161167234852,\n",
- " 'loss': 0.29350431308140146,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6457699181279158,\n",
- " 'counter': 3,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.5595007428160708,\n",
- " 'loss': 0.34509736982486094,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8206491838665859,\n",
- " 'counter': 4,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.7352560865167196,\n",
- " 'loss': 0.8643507964048233,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.18475486383110362,\n",
- " 'counter': 5,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.8095582069640777,\n",
- " 'loss': 0.6422878527834606,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5201466775139346,\n",
- " 'counter': 6,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.061716851339827516,\n",
- " 'loss': 0.7637321166865296,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.3678060389872875,\n",
- " 'counter': 7,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.7032141409909735,\n",
- " 'loss': 0.9970910128616833,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.623833972507887,\n",
- " 'counter': 8,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.6271070921439452,\n",
- " 'loss': 0.6818115924622632,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5347124236042501,\n",
- " 'counter': 9,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.628467546548407,\n",
- " 'loss': 0.5524470113431209,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8321190560222416,\n",
- " 'counter': 10,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.41638463730432973,\n",
- " 'loss': 0.09678521908764359,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.061062929867188864,\n",
- " 'counter': 11,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.0028834425894580518,\n",
- " 'loss': 0.4216971031578337,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9376513800154704,\n",
- " 'counter': 12,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.16826870560571294,\n",
- " 'loss': 0.5449483165104079,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5834570772207172,\n",
- " 'counter': 13,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.8877445886262226,\n",
- " 'loss': 0.4698389127372775,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.21304684547405073,\n",
- " 'counter': 14,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.7469669529997487,\n",
- " 'loss': 0.23330551995179538,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6763383197387115,\n",
- " 'counter': 15,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.7302057912871137,\n",
- " 'loss': 0.7292961849015884,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.08100918401803436,\n",
- " 'counter': 16,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.2476607780064477,\n",
- " 'loss': 0.22020659312661672,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6772648893392582,\n",
- " 'counter': 17,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.2635059351107911,\n",
- " 'loss': 0.3548780552739327,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.38959696020655343,\n",
- " 'counter': 18,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.8074850148212164,\n",
- " 'loss': 0.6831797045398854,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.28214128210394984,\n",
- " 'counter': 19,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.5586043847447024,\n",
- " 'loss': 0.5895247404509585,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5935634609284195,\n",
- " 'counter': 20,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.2524445181735028,\n",
- " 'loss': 0.5764078217427744,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.260253788134716,\n",
- " 'counter': 21,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.42072655521356217,\n",
- " 'loss': 0.2895823050412111,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.33883876178689964,\n",
- " 'counter': 22,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.39083614971508474,\n",
- " 'loss': 0.9650770442907194,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.719970588686091,\n",
- " 'counter': 23,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.9903707648425806,\n",
- " 'loss': 0.039462762474445245,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.3209159736248801,\n",
- " 'counter': 24,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.8592443478622364,\n",
- " 'loss': 0.7397879324981425,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7997691243085803,\n",
- " 'counter': 25,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.4922901573443462,\n",
- " 'loss': 0.18765450639699832,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.22631202543756523,\n",
- " 'counter': 26,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.28006503355111945,\n",
- " 'loss': 0.48717636277040455,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5665453352489781,\n",
- " 'counter': 27,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.035009402524677435,\n",
- " 'loss': 0.9930670090906539,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6817156539035648,\n",
- " 'counter': 28,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.15294906076180148,\n",
- " 'loss': 0.8503715476801526,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7685638961337506,\n",
- " 'counter': 29,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.011243538135467968,\n",
- " 'loss': 0.6052913272912087,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6988396342541379,\n",
- " 'counter': 30,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.5071967804132129,\n",
- " 'loss': 0.7805187244077492,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.32570758065286776,\n",
- " 'counter': 31,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.03333287764167736,\n",
- " 'loss': 0.48876586221992946,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8033965511414279,\n",
- " 'counter': 32,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.25782010930427446,\n",
- " 'loss': 0.50225218130975,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8953389854952885,\n",
- " 'counter': 33,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.12933396436879585,\n",
- " 'loss': 0.5469737022467042,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9904744922199102,\n",
- " 'counter': 34,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.27921507740970253,\n",
- " 'loss': 0.4623269363430126,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8216755293065882,\n",
- " 'counter': 35,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.3095970561001742,\n",
- " 'loss': 0.08455142081972067,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.49343941684530424,\n",
- " 'counter': 36,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.19192206286177693,\n",
- " 'loss': 0.8753463298815788,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7768566010812188,\n",
- " 'counter': 37,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.8844797226134172,\n",
- " 'loss': 0.006524719143109925,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5568446688778599,\n",
- " 'counter': 38,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.03343421254704615,\n",
- " 'loss': 0.28727277111436633,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.44890709887467173,\n",
- " 'counter': 39,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.6621309512898387,\n",
- " 'loss': 0.8452824397393393,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7159606495208108,\n",
- " 'counter': 40,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.03640098587868845,\n",
- " 'loss': 0.5345771594057487,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.049708369641641825,\n",
- " 'counter': 41,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.9166311014849851,\n",
- " 'loss': 0.2459478308253078,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.4431795147502938,\n",
- " 'counter': 42,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.34347646704287493,\n",
- " 'loss': 0.7455012183673186,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.577438555466327,\n",
- " 'counter': 43,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.6337934682570888,\n",
- " 'loss': 0.7060270787405257,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.15988021579253786,\n",
- " 'counter': 44,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.23610216907848525,\n",
- " 'loss': 0.7618289903036082,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.34729025936083047,\n",
- " 'counter': 45,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.15386208941648094,\n",
- " 'loss': 0.651297365300279,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6690515453619355,\n",
- " 'counter': 46,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.42416417342440316,\n",
- " 'loss': 0.9798304500519983,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.4062833264316038,\n",
- " 'counter': 47,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.06633302273089348,\n",
- " 'loss': 0.002314197657050765,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.030070938160237537,\n",
- " 'counter': 48,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.6231477920854881,\n",
- " 'loss': 0.7975859655350828,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.41159190478675756,\n",
- " 'counter': 49,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.9030149677278394,\n",
- " 'loss': 0.7962848279570938,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8900336390951261,\n",
- " 'counter': 50,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.29475568711433975,\n",
- " 'loss': 0.28702364265668745,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.2773802251144335,\n",
- " 'counter': 51,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.5928073265043763,\n",
- " 'loss': 0.1581297672397728,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8427151952466633,\n",
- " 'counter': 52,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.4486620660086762,\n",
- " 'loss': 0.3140863162642532,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5118370029322294,\n",
- " 'counter': 53,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.43671336182927545,\n",
- " 'loss': 0.406539489245766,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.341510694918731,\n",
- " 'counter': 54,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.48428189054966786,\n",
- " 'loss': 0.20791346518739418,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6966333576762844,\n",
- " 'counter': 55,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.3138884358841406,\n",
- " 'loss': 0.9104675288781803,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.22114849572838202,\n",
- " 'counter': 56,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.6110657848443738,\n",
- " 'loss': 0.5707451367088578,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7029113308489365,\n",
- " 'counter': 57,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.6178152359864543,\n",
- " 'loss': 0.8274994418321521,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.013018575113299069,\n",
- " 'counter': 58,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.2360837039014222,\n",
- " 'loss': 0.9016765926145995,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8665715166180401,\n",
- " 'counter': 59,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.0026488267932833764,\n",
- " 'loss': 0.059307621400578325,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7743746952811773,\n",
- " 'counter': 60,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.6900012476648771,\n",
- " 'loss': 0.6666667730347293,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5946878497942433,\n",
- " 'counter': 61,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.19961327853188193,\n",
- " 'loss': 0.5141734373013741,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.30579644485953283,\n",
- " 'counter': 62,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.507198375492157,\n",
- " 'loss': 0.02053142631511551,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6590602737914744,\n",
- " 'counter': 63,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.7624233062849223,\n",
- " 'loss': 0.6141459128770841,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.20986327943383043,\n",
- " 'counter': 64,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.7979456085800078,\n",
- " 'loss': 0.05549406428366488,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9826418040282519,\n",
- " 'counter': 65,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.2464671883682481,\n",
- " 'loss': 0.24585107122333882,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8078805365825079,\n",
- " 'counter': 66,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.9749889676719125,\n",
- " 'loss': 0.7386885612341967,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.19149991385444298,\n",
- " 'counter': 67,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.06422864002294604,\n",
- " 'loss': 0.28302300854052853,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5850722867509154,\n",
- " 'counter': 68,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.9241418374035124,\n",
- " 'loss': 0.09780509865247677,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5372504248125591,\n",
- " 'counter': 69,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.6361881418189349,\n",
- " 'loss': 0.393617318158613,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.944692306830134,\n",
- " 'counter': 70,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.10124417675344488,\n",
- " 'loss': 0.30406871143958103,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7295273137695822,\n",
- " 'counter': 71,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.38511124116491224,\n",
- " 'loss': 0.9321688112586601,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.14396665290994226,\n",
- " 'counter': 72,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.7965319124572947,\n",
- " 'loss': 0.7832759619660403,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9598140096591034,\n",
- " 'counter': 73,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.3006372299429394,\n",
- " 'loss': 0.6811306390439349,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9589866032479657,\n",
- " 'counter': 74,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.8086605103459729,\n",
- " 'loss': 0.6480868740360761,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.4604507685852718,\n",
- " 'counter': 75,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.934138847763478,\n",
- " 'loss': 0.46784545504478314,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.18296420228236143,\n",
- " 'counter': 76,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.8245420856995567,\n",
- " 'loss': 0.972634414534186,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7709734625367239,\n",
- " 'counter': 77,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.3683277546209054,\n",
- " 'loss': 0.3394274397919588,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.2979644109828484,\n",
- " 'counter': 78,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.4998427152313397,\n",
- " 'loss': 0.4582683874561885,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.05007082650334205,\n",
- " 'counter': 79,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.796716759796841,\n",
- " 'loss': 0.08778209355256572,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.45979067527981465,\n",
- " 'counter': 80,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.020953727973823777,\n",
- " 'loss': 0.6089566476449716,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.17842291459911253,\n",
- " 'counter': 81,\n",
- " 'iterations': 1.0,\n",
- " 'log_loss': 0.7053992418707126,\n",
- " 'loss': 0.6230410294625963,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8996483331893631,\n",
- " 'counter': 82,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.3167006641102992,\n",
- " 'loss': 0.16476508861020422,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8953395477328949,\n",
- " 'counter': 83,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.7673121965868592,\n",
- " 'loss': 0.025222690970011286,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.38032696111866104,\n",
- " 'counter': 84,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.4796548903059872,\n",
- " 'loss': 0.7458974038760335,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.032455422669963596,\n",
- " 'counter': 85,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.5901866655247737,\n",
- " 'loss': 0.662110478346492,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.11585350435665698,\n",
- " 'counter': 86,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.312173138009064,\n",
- " 'loss': 0.2779682107165128,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7741754906105917,\n",
- " 'counter': 87,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.1906549889869844,\n",
- " 'loss': 0.12691435021828767,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.07746765348272833,\n",
- " 'counter': 88,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.8878745012924916,\n",
- " 'loss': 0.7001492859702206,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.30296011884094276,\n",
- " 'counter': 89,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.29444679975841614,\n",
- " 'loss': 0.7319940270061703,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.03288140548304952,\n",
- " 'counter': 90,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.8463635455425874,\n",
- " 'loss': 0.10618958497581377,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7732339697686416,\n",
- " 'counter': 91,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.2663718116578472,\n",
- " 'loss': 0.9262904221746493,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7344626741509571,\n",
- " 'counter': 92,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.7526650720024616,\n",
- " 'loss': 0.24216877112571178,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5879351331170476,\n",
- " 'counter': 93,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.314940650253717,\n",
- " 'loss': 0.34080463056337396,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6491741667128744,\n",
- " 'counter': 94,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.44334645145718066,\n",
- " 'loss': 0.570325069559337,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.006899761143807082,\n",
- " 'counter': 95,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.18368973795102816,\n",
- " 'loss': 0.8901141509344943,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.4763029036511496,\n",
- " 'counter': 96,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.7347658042913988,\n",
- " 'loss': 0.7743729022882786,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.27860402471208456,\n",
- " 'counter': 97,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.004897700312819553,\n",
- " 'loss': 0.4209772643297237,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9286019249165068,\n",
- " 'counter': 98,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.6753865886577269,\n",
- " 'loss': 0.08516529949245888,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7842613678350525,\n",
- " 'counter': 99,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.18928720248351272,\n",
- " 'loss': 0.2934987986367198,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8815896482892973,\n",
- " 'counter': 100,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.20397676919438612,\n",
- " 'loss': 0.3694663542562,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.10146671272978525,\n",
- " 'counter': 101,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.3280685736687545,\n",
- " 'loss': 0.06304746700176833,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.2209283313543119,\n",
- " 'counter': 102,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.6478610703404913,\n",
- " 'loss': 0.984832243529239,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.44977769472993245,\n",
- " 'counter': 103,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.9236212659411028,\n",
- " 'loss': 0.7828276111713075,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.12553979666406057,\n",
- " 'counter': 104,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.5418233128611745,\n",
- " 'loss': 0.6717846584047122,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.23642297116408884,\n",
- " 'counter': 105,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.47673450637829873,\n",
- " 'loss': 0.9821532224821047,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5741963231702492,\n",
- " 'counter': 106,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.7169544802804382,\n",
- " 'loss': 0.1969495406542292,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.11440620809960522,\n",
- " 'counter': 107,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.5529263134299719,\n",
- " 'loss': 0.7257557665463978,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.03530500348935384,\n",
- " 'counter': 108,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.9110596043136135,\n",
- " 'loss': 0.8651882903450844,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5223031959393326,\n",
- " 'counter': 109,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.5318887983248776,\n",
- " 'loss': 0.8469517836145459,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7208177536119641,\n",
- " 'counter': 110,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.10922773695608279,\n",
- " 'loss': 0.001302267990096584,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.3110495827894265,\n",
- " 'counter': 111,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.44760951981063224,\n",
- " 'loss': 0.46973501425171893,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.3588083186787636,\n",
- " 'counter': 112,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.5532070351942855,\n",
- " 'loss': 0.8825859318997112,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7651239975390371,\n",
- " 'counter': 113,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.9155250548514041,\n",
- " 'loss': 0.025237482554793966,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9300875242084679,\n",
- " 'counter': 114,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.4773688144026147,\n",
- " 'loss': 0.06180465496742127,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.0996109981974258,\n",
- " 'counter': 115,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.6493369819432604,\n",
- " 'loss': 0.013985186777659475,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9012034688917692,\n",
- " 'counter': 116,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.232198008796352,\n",
- " 'loss': 0.9914319418596014,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7757629371895003,\n",
- " 'counter': 117,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.3683260119792202,\n",
- " 'loss': 0.7905506633542351,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6805120101046878,\n",
- " 'counter': 118,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.3020478210948526,\n",
- " 'loss': 0.31890103857000807,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.16132436724581078,\n",
- " 'counter': 119,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.9700131330760121,\n",
- " 'loss': 0.21419549725311016,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.13367618629195566,\n",
- " 'counter': 120,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.8811277401829034,\n",
- " 'loss': 0.2660610384966404,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.0656842181355819,\n",
- " 'counter': 121,\n",
- " 'iterations': 81.0,\n",
- " 'log_loss': 0.531425333497132,\n",
- " 'loss': 0.3738357336717808,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5247195757299832,\n",
- " 'counter': 122,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.31750122429590744,\n",
- " 'loss': 0.5604300117060509,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7766240710183595,\n",
- " 'counter': 123,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.5206230315191374,\n",
- " 'loss': 0.11510225861587098,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.833613810628292,\n",
- " 'counter': 124,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.238903452302888,\n",
- " 'loss': 0.3097701363065719,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5101001510499669,\n",
- " 'counter': 125,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.6501778955203619,\n",
- " 'loss': 0.45666391331535117,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6289704258920873,\n",
- " 'counter': 126,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.3175705641880501,\n",
- " 'loss': 0.2244181097063077,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.049491062942634834,\n",
- " 'counter': 127,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.10125760302696518,\n",
- " 'loss': 0.8928513883669132,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.10723693854962635,\n",
- " 'counter': 128,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.38901129074802565,\n",
- " 'loss': 0.9528848169353429,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7586625887149587,\n",
- " 'counter': 129,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.15115125711208044,\n",
- " 'loss': 0.014653352772398653,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.16869688165252406,\n",
- " 'counter': 130,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.8478428792368242,\n",
- " 'loss': 0.37975147720678026,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.2028507498813994,\n",
- " 'counter': 131,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.5284341093464384,\n",
- " 'loss': 0.8885195119854732,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.07688952516948655,\n",
- " 'counter': 132,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.30626688927558576,\n",
- " 'loss': 0.3094312602240511,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7739127306455509,\n",
- " 'counter': 133,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.4172768606236803,\n",
- " 'loss': 0.7508531480394048,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5764482720961361,\n",
- " 'counter': 134,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.14694952450953735,\n",
- " 'loss': 0.49088397170748677,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9788854348863874,\n",
- " 'counter': 135,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.29798370159397425,\n",
- " 'loss': 0.00032905602146438007,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8220209183525184,\n",
- " 'counter': 136,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.5014649265135969,\n",
- " 'loss': 0.5790784156747228,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6259786905481609,\n",
- " 'counter': 137,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.042576172863746486,\n",
- " 'loss': 0.38933594587482434,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.4648530005837833,\n",
- " 'counter': 138,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.03420911713389474,\n",
- " 'loss': 0.8188856390981347,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.12564906294227307,\n",
- " 'counter': 139,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.7730380012180292,\n",
- " 'loss': 0.8244186334698065,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8164859539155577,\n",
- " 'counter': 140,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.27487923695948524,\n",
- " 'loss': 0.021740489831387655,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.3322569548647567,\n",
- " 'counter': 141,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.5265198266148321,\n",
- " 'loss': 0.5320474239268533,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.43860815961769006,\n",
- " 'counter': 142,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.09297835907649554,\n",
- " 'loss': 0.6769464211974119,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9279725690148007,\n",
- " 'counter': 143,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.6635889423915476,\n",
- " 'loss': 0.7639368887378851,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.033544534103501444,\n",
- " 'counter': 144,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.915246669805891,\n",
- " 'loss': 0.7053455575611667,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5075879672736031,\n",
- " 'counter': 145,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.16942466004412282,\n",
- " 'loss': 0.17174966413087556,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.2800575533463264,\n",
- " 'counter': 146,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.31888425959168454,\n",
- " 'loss': 0.564347214563593,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8432091832234639,\n",
- " 'counter': 147,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.5428688815119704,\n",
- " 'loss': 0.10755044240923706,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.3566612776049186,\n",
- " 'counter': 148,\n",
- " 'iterations': 3.0,\n",
- " 'log_loss': 0.41328249130294803,\n",
- " 'loss': 0.7755190182154043,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6392050218312009,\n",
- " 'counter': 149,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.5925961191076492,\n",
- " 'loss': 0.16762873144777568,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.04674159780230158,\n",
- " 'counter': 150,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.8516102444082616,\n",
- " 'loss': 0.8556994990400649,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.4017873704066903,\n",
- " 'counter': 151,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.5112627752272675,\n",
- " 'loss': 0.3569586561350456,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8340666173377387,\n",
- " 'counter': 152,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.9446102917052954,\n",
- " 'loss': 0.5212135566011181,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.580685685675133,\n",
- " 'counter': 153,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.9008936157670153,\n",
- " 'loss': 0.7133904463463437,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8216669219498715,\n",
- " 'counter': 154,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.642407933520737,\n",
- " 'loss': 0.25987622452527404,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.1219267252329107,\n",
- " 'counter': 155,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.6545156950387196,\n",
- " 'loss': 0.32292023554125926,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6042193527835271,\n",
- " 'counter': 156,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.40200498422921305,\n",
- " 'loss': 0.8421553561696671,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.3493005856626694,\n",
- " 'counter': 157,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.35318104337634926,\n",
- " 'loss': 0.7700194868385192,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.13196247808805261,\n",
- " 'counter': 158,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.0880043096847839,\n",
- " 'loss': 0.30410196671776657,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6093129479295774,\n",
- " 'counter': 159,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.6998649820522845,\n",
- " 'loss': 0.44844223677047723,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.25017231395921147,\n",
- " 'counter': 160,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.047635552750225685,\n",
- " 'loss': 0.8848990434980698,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6335592092432603,\n",
- " 'counter': 161,\n",
- " 'iterations': 81.0,\n",
- " 'log_loss': 0.8621888727907012,\n",
- " 'loss': 0.8803427178896533,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.45753629500055115,\n",
- " 'counter': 162,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.8068030046604604,\n",
- " 'loss': 0.18943723013787606,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9969193634649506,\n",
- " 'counter': 163,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.37500158224223135,\n",
- " 'loss': 0.7688523651856489,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.357386674867637,\n",
- " 'counter': 164,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.8165131961896558,\n",
- " 'loss': 0.9693066933917974,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7535755450807958,\n",
- " 'counter': 165,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.7998447604628575,\n",
- " 'loss': 0.07652694221842082,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9577803131411142,\n",
- " 'counter': 166,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.9342701569501605,\n",
- " 'loss': 0.6280841159525683,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.1449007354101708,\n",
- " 'counter': 167,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.5471894051166079,\n",
- " 'loss': 0.06572953918283864,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.6787455152608078,\n",
- " 'counter': 168,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.8219477112352841,\n",
- " 'loss': 0.7344967690081345,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.13165584552000364,\n",
- " 'counter': 169,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.3889792245599526,\n",
- " 'loss': 0.4210657136923722,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.06201807035142759,\n",
- " 'counter': 170,\n",
- " 'iterations': 9.0,\n",
- " 'log_loss': 0.26716809728290936,\n",
- " 'loss': 0.14597089784162198,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8624386730545928,\n",
- " 'counter': 171,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.8322745409402996,\n",
- " 'loss': 0.20077430739452928,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5715381698905945,\n",
- " 'counter': 172,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.7768489012428331,\n",
- " 'loss': 0.41670892677072635,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9504349501608724,\n",
- " 'counter': 173,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.33037534912176747,\n",
- " 'loss': 0.8416009329004872,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.3599171525262599,\n",
- " 'counter': 174,\n",
- " 'iterations': 81.0,\n",
- " 'log_loss': 0.18406511309555207,\n",
- " 'loss': 0.42147703248451496,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.21820227121120184,\n",
- " 'counter': 175,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.8464401125613037,\n",
- " 'loss': 0.7221498252181305,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.38119137963197636,\n",
- " 'counter': 176,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.9479153512340984,\n",
- " 'loss': 0.29062627566721855,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.22770017353156757,\n",
- " 'counter': 177,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.4974406767661237,\n",
- " 'loss': 0.6351536093891014,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.07157251719815028,\n",
- " 'counter': 178,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.7398139747408429,\n",
- " 'loss': 0.7368347574266718,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8214789406991464,\n",
- " 'counter': 179,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.9021437479943136,\n",
- " 'loss': 0.9330381261782902,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8359682498265789,\n",
- " 'counter': 180,\n",
- " 'iterations': 27.0,\n",
- " 'log_loss': 0.4715081057970302,\n",
- " 'loss': 0.9091357807927576,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.9772179053215029,\n",
- " 'counter': 181,\n",
- " 'iterations': 81.0,\n",
- " 'log_loss': 0.06464750664889651,\n",
- " 'loss': 0.8809809023370188,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.5735130758927682,\n",
- " 'counter': 182,\n",
- " 'iterations': 81.0,\n",
- " 'log_loss': 0.3034499002530837,\n",
- " 'loss': 0.1269187021236282,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.1512467303346996,\n",
- " 'counter': 183,\n",
- " 'iterations': 81,\n",
- " 'log_loss': 0.06311948439063264,\n",
- " 'loss': 0.39857231746894706,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.8624041903201188,\n",
- " 'counter': 184,\n",
- " 'iterations': 81,\n",
- " 'log_loss': 0.4487021577153334,\n",
- " 'loss': 0.482329681305285,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7276284122996118,\n",
- " 'counter': 185,\n",
- " 'iterations': 81,\n",
- " 'log_loss': 0.6673161280963441,\n",
- " 'loss': 0.2502910767343628,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7599292676148188,\n",
- " 'counter': 186,\n",
- " 'iterations': 81,\n",
- " 'log_loss': 0.8642048547707212,\n",
- " 'loss': 0.714166849575351,\n",
- " 'params': None,\n",
- " 'seconds': 0},\n",
- " {'auc': 0.7396708566431607,\n",
- " 'counter': 187,\n",
- " 'iterations': 81,\n",
- " 'log_loss': 0.8495002542298392,\n",
- " 'loss': 0.14587931909618035,\n",
- " 'params': None,\n",
- " 'seconds': 0}]"
- ]
- },
- "execution_count": 37,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 118,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'a': 2, 'b': 0.3388081749546307, 'c': 0.704635960884642},\n",
- " {'a': 1, 'b': 0.4904175136129263, 'c': 0.8971084273807718},\n",
- " {'a': 1, 'b': 1.2386463990117793, 'c': 0.21568311690580266},\n",
- " {'a': 1, 'b': 1.91007461806631, 'c': 0.17778124867596956},\n",
- " {'a': 1, 'b': 1.2563450220231427, 'c': 0.002076412746974121}]"
- ]
- },
- "execution_count": 118,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from sklearn.model_selection import ParameterSampler\n",
- "from scipy.stats.distributions import expon, uniform, lognorm\n",
- "import numpy as np\n",
- "#rng = np.random.RandomState()\n",
- "param_grid = {'a':[1, 2], 'b': expon(), 'c': uniform()}\n",
- "#param_list = list(ParameterSampler(param_grid, n_iter=5, random_state=rng))\n",
- "param_list = list(ParameterSampler(param_grid, n_iter=5))\n",
- "rounded_list = [dict((k, round(v, 6)) for (k, v) in d.items()) for d in param_list]\n",
- "param_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'a': 2, 'b': 0.37954129345633403, 'c': 0.3742154014629032},\n",
- " {'a': 2, 'b': 1.2830633021262747, 'c': 0.4373122879029032},\n",
- " {'a': 1, 'b': 0.22037072550727527, 'c': 0.26397341600176616},\n",
- " {'a': 1, 'b': 0.549444485603122, 'c': 0.8317686948528791},\n",
- " {'a': 1, 'b': 1.0567787144413414, 'c': 0.9560841093558743}]"
- ]
- },
- "execution_count": 33,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "param_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'a': 2.0, 'b': 0.379541, 'c': 0.374215},\n",
- " {'a': 2.0, 'b': 1.283063, 'c': 0.437312},\n",
- " {'a': 1.0, 'b': 0.220371, 'c': 0.263973},\n",
- " {'a': 1.0, 'b': 0.549444, 'c': 0.831769},\n",
- " {'a': 1.0, 'b': 1.056779, 'c': 0.956084}]"
- ]
- },
- "execution_count": 34,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "rounded_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 150,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'d': 2.9713720038716116},\n",
- " {'d': 10.275052606706604},\n",
- " {'d': 4.211836333907813},\n",
- " {'d': 3.6005371688499834},\n",
- " {'d': 14.68709362771547}]"
- ]
- },
- "execution_count": 150,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "#rng = np.random.RandomState(0)\n",
- "param_grid = {'d': lognorm(1, 2, 3)}\n",
- "#param_list = list(ParameterSampler(param_grid, n_iter=5, random_state=rng))\n",
- "param_list = list(ParameterSampler(param_grid, n_iter=5))\n",
- "param_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 266,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.07983433464722507,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 4,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.03805362658279962,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 1,\n",
- " 'optimizer': 'SGD'},\n",
- " {'batch_size': 4,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.09043633721868387,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.02775811670911417,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 1,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.104019113296403,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.06986494800074812,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'SGD'},\n",
- " {'batch_size': 4,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.010449656955883938,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 4,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.04915490422264339,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'SGD'},\n",
- " {'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.05257644929029893,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 1,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.02993608422766151,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'SGD'}]"
- ]
- },
- "execution_count": 266,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# model architecture\n",
- "model_id = [1, 2]\n",
- "\n",
- "# compile params\n",
- "\n",
- "# loss function\n",
- "loss = ['categorical_crossentropy']\n",
- "# optimizer\n",
- "optimizer = ['Adam', 'SGD']\n",
- "# learning rate\n",
- "lr = [0.01, 0.1]\n",
- "# metrics\n",
- "metrics = ['accuracy']\n",
- "\n",
- "# fit params\n",
- "\n",
- "# batch size\n",
- "batch_size = [4, 8]\n",
- "# epochs\n",
- "epochs = [1]\n",
- "\n",
- "# create random param list\n",
- "param_grid = {\n",
- " 'model_id': model_id,\n",
- " 'loss': loss,\n",
- " 'optimizer': optimizer,\n",
- " 'lr': uniform(lr[0], lr[1]),\n",
- " 'metrics': metrics,\n",
- " 'batch_size': batch_size,\n",
- " 'epochs': epochs\n",
- "}\n",
- "param_list = list(ParameterSampler(param_grid, n_iter=10))\n",
- "param_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 212,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.03396784466820144,\n",
- " 'optimizer': 'Adam'}"
- ]
- },
- "execution_count": 212,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "param_list[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 285,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.07983433464722507)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='SGD(lr=0.03805362658279962)',metrics=['accuracy']$$\n",
- "batch_size=4,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.09043633721868387)',metrics=['accuracy']$$\n",
- "batch_size=4,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.02775811670911417)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.104019113296403)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='SGD(lr=0.06986494800074812)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.010449656955883938)',metrics=['accuracy']$$\n",
- "batch_size=4,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='SGD(lr=0.04915490422264339)',metrics=['accuracy']$$\n",
- "batch_size=4,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.05257644929029893)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='SGD(lr=0.02993608422766151)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n"
- ]
- }
- ],
- "source": [
- "for params in param_list:\n",
- "# for key, value in params.items():\n",
- "# print (key, value)\n",
- "\n",
- " compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
- " print (compile_params)\n",
- " \n",
- " fit_params = \"batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\"))\n",
- " print (fit_params)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 301,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n"
- ]
- }
- ],
- "source": [
- "import psycopg2 as p2\n",
- "\n",
- "#conn = p2.connect('postgresql://gpadmin@35.239.240.26:5432/madlib')\n",
- "conn = p2.connect('postgresql://fmcquillan@localhost:5432/madlib')\n",
- "cur = conn.cursor()\n",
- "\n",
- "%sql DROP TABLE IF EXISTS mst_table_hb, mst_table_auto_hb;\n",
- "\n",
- "%sql CREATE TABLE mst_table_hb(mst_key serial, model_id integer, compile_params varchar, fit_params varchar);\n",
- "\n",
- "for params in param_list:\n",
- " model_id = str(params.get(\"model_id\"))\n",
- " compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
- " fit_params = \"$$batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\")) + \"$$\" \n",
- " row_content = \"(\" + model_id + \", \" + compile_params + \", \" + fit_params + \");\"\n",
- " \n",
- " %sql INSERT INTO mst_table_hb (model_id, compile_params, fit_params) VALUES $row_content"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 302,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "10 rows affected.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<table>\n",
- " <tr>\n",
- " <th>mst_key</th>\n",
- " <th>model_arch_id</th>\n",
- " <th>compile_params</th>\n",
- " <th>fit_params</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>1</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.07983433464722507)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>2</td>\n",
- " <td>1</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='SGD(lr=0.03805362658279962)',metrics=['accuracy']</td>\n",
- " <td>batch_size=4,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>3</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.09043633721868387)',metrics=['accuracy']</td>\n",
- " <td>batch_size=4,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>4</td>\n",
- " <td>1</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.02775811670911417)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>5</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.104019113296403)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>6</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='SGD(lr=0.06986494800074812)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>7</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.010449656955883938)',metrics=['accuracy']</td>\n",
- " <td>batch_size=4,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>8</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='SGD(lr=0.04915490422264339)',metrics=['accuracy']</td>\n",
- " <td>batch_size=4,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>9</td>\n",
- " <td>1</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.05257644929029893)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>10</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='SGD(lr=0.02993608422766151)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- "</table>"
- ],
- "text/plain": [
- "[(1, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.07983433464722507)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
- " (2, 1, u\"loss='categorical_crossentropy',optimizer='SGD(lr=0.03805362658279962)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
- " (3, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.09043633721868387)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
- " (4, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.02775811670911417)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
- " (5, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.104019113296403)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
- " (6, 2, u\"loss='categorical_crossentropy',optimizer='SGD(lr=0.06986494800074812)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
- " (7, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.010449656955883938)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
- " (8, 2, u\"loss='categorical_crossentropy',optimizer='SGD(lr=0.04915490422264339)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
- " (9, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.05257644929029893)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
- " (10, 2, u\"loss='categorical_crossentropy',optimizer='SGD(lr=0.02993608422766151)',metrics=['accuracy']\", u'batch_size=8,epochs=1')]"
- ]
- },
- "execution_count": 302,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%%sql\n",
- "SELECT * FROM mst_table_hb ORDER BY mst_key;"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "from random import random\n",
- "from math import log, ceil\n",
- "from time import time, ctime\n",
- "\n",
- "\n",
- "class Hyperband:\n",
- " \n",
- " def __init__( self, get_params_function, try_params_function ):\n",
- " self.get_params = get_params_function\n",
- " self.try_params = try_params_function\n",
- "\n",
- " self.max_iter = 81 # maximum iterations per configuration\n",
- " self.eta = 3 # defines configuration downsampling rate (default = 3)\n",
- "\n",
- " self.logeta = lambda x: log( x ) / log( self.eta )\n",
- " self.s_max = int( self.logeta( self.max_iter ))\n",
- " self.B = ( self.s_max + 1 ) * self.max_iter\n",
- "\n",
- " self.results = [] # list of dicts\n",
- " self.counter = 0\n",
- " self.best_loss = np.inf\n",
- " self.best_counter = -1\n",
- "\n",
- "\n",
- " # can be called multiple times\n",
- " def run( self, skip_last = 0, dry_run = False ):\n",
- "\n",
- " for s in reversed( range( self.s_max + 1 )):\n",
- " \n",
- " # initial number of configurations\n",
- " n = int( ceil( self.B / self.max_iter / ( s + 1 ) * self.eta ** s ))\n",
- "\n",
- " # initial number of iterations per config\n",
- " r = self.max_iter * self.eta ** ( -s )\n",
- " \n",
- " print (\"s = \", s)\n",
- " print (\"n = \", n)\n",
- " print (\"r = \", r)\n",
- "\n",
- " # n random configurations\n",
- " T = self.get_params(n) # what to return from function if anything?\n",
- " \n",
- " return\n",
- "\n",
- " for i in range(( s + 1 ) - int( skip_last )): # changed from s + 1\n",
- "\n",
- " # Run each of the n configs for <iterations>\n",
- " # and keep best (n_configs / eta) configurations\n",
- "\n",
- " n_configs = n * self.eta ** ( -i )\n",
- " n_iterations = r * self.eta ** ( i )\n",
- "\n",
- " print \"\\n*** {} configurations x {:.1f} iterations each\".format(\n",
- " n_configs, n_iterations )\n",
- "\n",
- " val_losses = []\n",
- " early_stops = []\n",
- " \n",
- " \n",
- " \n",
- " \n",
- "\n",
- " for t in T:\n",
- "\n",
- " self.counter += 1\n",
- " print \"\\n{} | {} | lowest loss so far: {:.4f} (run {})\\n\".format(\n",
- " self.counter, ctime(), self.best_loss, self.best_counter )\n",
- "\n",
- " start_time = time()\n",
- "\n",
- " if dry_run:\n",
- " result = { 'loss': random(), 'log_loss': random(), 'auc': random()}\n",
- " else:\n",
- " result = self.try_params( n_iterations, t ) # <---\n",
- "\n",
- " assert( type( result ) == dict )\n",
- " assert( 'loss' in result )\n",
- "\n",
- " seconds = int( round( time() - start_time ))\n",
- " print \"\\n{} seconds.\".format( seconds )\n",
- "\n",
- " loss = result['loss']\n",
- " val_losses.append( loss )\n",
- "\n",
- " early_stop = result.get( 'early_stop', False )\n",
- " early_stops.append( early_stop )\n",
- "\n",
- " # keeping track of the best result so far (for display only)\n",
- " # could do it be checking results each time, but hey\n",
- " if loss < self.best_loss:\n",
- " self.best_loss = loss\n",
- " self.best_counter = self.counter\n",
- "\n",
- " result['counter'] = self.counter\n",
- " result['seconds'] = seconds\n",
- " result['params'] = t\n",
- " result['iterations'] = n_iterations\n",
- " \n",
- " self.results.append( result )\n",
- "\n",
- " # select a number of best configurations for the next loop\n",
- " # filter out early stops, if any\n",
- " indices = np.argsort( val_losses )\n",
- " T = [ T[i] for i in indices if not early_stops[i]]\n",
- " T = T[ 0:int( n_configs / self.eta )]\n",
- " \n",
- " return self.results"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "language": "python",
- "name": "python2"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.10"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/community-artifacts/Deep-learning/automl/hyperband_v1.py b/community-artifacts/Deep-learning/automl/hyperband_v1.py
deleted file mode 100644
index 9bbf0f0..0000000
--- a/community-artifacts/Deep-learning/automl/hyperband_v1.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import numpy as np
-
-from random import random
-from math import log, ceil
-from time import time, ctime
-
-
-class Hyperband:
-
- def __init__( self, get_params_function, try_params_function ):
- self.get_params = get_params_function
- self.try_params = try_params_function
-
- self.max_iter = 81 # maximum iterations per configuration
- self.eta = 3 # defines configuration downsampling rate (default = 3)
-
- self.logeta = lambda x: log( x ) / log( self.eta )
- self.s_max = int( self.logeta( self.max_iter ))
- self.B = ( self.s_max + 1 ) * self.max_iter
-
- self.results = [] # list of dicts
- self.counter = 0
- self.best_loss = np.inf
- self.best_counter = -1
-
-
- # can be called multiple times
- def run( self, skip_last = 0, dry_run = False ):
-
- for s in reversed( range( self.s_max + 1 )):
-
- # initial number of configurations
- n = int( ceil( self.B / self.max_iter / ( s + 1 ) * self.eta ** s ))
-
- # initial number of iterations per config
- r = self.max_iter * self.eta ** ( -s )
-
- # n random configurations
- T = [ self.get_params() for i in range( n )]
-
- for i in range(( s + 1 ) - int( skip_last )): # changed from s + 1
-
- # Run each of the n configs for <iterations>
- # and keep best (n_configs / eta) configurations
-
- n_configs = n * self.eta ** ( -i )
- n_iterations = r * self.eta ** ( i )
-
- print "\n*** {} configurations x {:.1f} iterations each".format(
- n_configs, n_iterations )
-
- val_losses = []
- early_stops = []
-
- for t in T:
-
- self.counter += 1
- print "\n{} | {} | lowest loss so far: {:.4f} (run {})\n".format(
- self.counter, ctime(), self.best_loss, self.best_counter )
-
- start_time = time()
-
- if dry_run:
- result = { 'loss': random(), 'log_loss': random(), 'auc': random()}
- else:
- result = self.try_params( n_iterations, t ) # <---
-
- assert( type( result ) == dict )
- assert( 'loss' in result )
-
- seconds = int( round( time() - start_time ))
- print "\n{} seconds.".format( seconds )
-
- loss = result['loss']
- val_losses.append( loss )
-
- early_stop = result.get( 'early_stop', False )
- early_stops.append( early_stop )
-
- # keeping track of the best result so far (for display only)
- # could do it be checking results each time, but hey
- if loss < self.best_loss:
- self.best_loss = loss
- self.best_counter = self.counter
-
- result['counter'] = self.counter
- result['seconds'] = seconds
- result['params'] = t
- result['iterations'] = n_iterations
-
- self.results.append( result )
-
- # select a number of best configurations for the next loop
- # filter out early stops, if any
- indices = np.argsort( val_losses )
- T = [ T[i] for i in indices if not early_stops[i]]
- T = T[ 0:int( n_configs / self.eta )]
-
- return self.results
diff --git a/community-artifacts/Deep-learning/automl/hyperband_v2.ipynb b/community-artifacts/Deep-learning/automl/hyperband_v2.ipynb
deleted file mode 100644
index d1a2de6..0000000
--- a/community-artifacts/Deep-learning/automl/hyperband_v2.ipynb
+++ /dev/null
@@ -1,3043 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Hyperband\n",
- "\n",
- "Impelementation of Hyperband https://arxiv.org/pdf/1603.06560.pdf with ideas from blog post by the same authors https://homes.cs.washington.edu/~jamieson/hyperband.html"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
- " \"You should import from traitlets.config instead.\", ShimWarning)\n",
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
- " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
- ]
- }
- ],
- "source": [
- "%load_ext sql"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Greenplum Database 5.x on GCP (PM demo machine) - direct external IP access\n",
- "#%sql postgresql://gpadmin@34.67.65.96:5432/madlib\n",
- "\n",
- "# Greenplum Database 5.x on GCP - via tunnel\n",
- "%sql postgresql://gpadmin@localhost:8000/madlib\n",
- " \n",
- "# PostgreSQL local\n",
- "#%sql postgresql://fmcquillan@localhost:5432/madlib\n",
- "\n",
- "#psycopg2 connection\n",
- "import psycopg2 as p2\n",
- "#conn = p2.connect('postgresql://gpadmin@35.239.240.26:5432/madlib')\n",
- "#conn = p2.connect('postgresql://fmcquillan@localhost:5432/madlib')\n",
- "conn = p2.connect('postgresql://gpadmin@localhost:8000/madlib')\n",
- "cur = conn.cursor()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Pretty print run schedule"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "max_iter = 27\n",
- "eta = 3\n",
- "B = 4*max_iter = 108\n",
- " \n",
- "s=3\n",
- "n_i r_i\n",
- "------------\n",
- "27 1.0\n",
- "9.0 3.0\n",
- "3.0 9.0\n",
- "1.0 27.0\n",
- " \n",
- "s=2\n",
- "n_i r_i\n",
- "------------\n",
- "9 3.0\n",
- "3.0 9.0\n",
- "1.0 27.0\n",
- " \n",
- "s=1\n",
- "n_i r_i\n",
- "------------\n",
- "6 9.0\n",
- "2.0 27.0\n",
- " \n",
- "s=0\n",
- "n_i r_i\n",
- "------------\n",
- "4 27\n",
- " \n",
- "sum of configurations at leaf nodes across all s = 8.0\n",
- "(if have more workers than this, they may not be 100% busy)\n"
- ]
- }
- ],
- "source": [
- "import numpy as np\n",
- "from math import log, ceil\n",
- "\n",
- "#input\n",
- "max_iter = 27 # maximum iterations/epochs per configuration\n",
- "eta = 3 # defines downsampling rate (default=3)\n",
- "\n",
- "logeta = lambda x: log(x)/log(eta)\n",
- "s_max = int(logeta(max_iter)) # number of unique executions of Successive Halving (minus one)\n",
- "B = (s_max+1)*max_iter # total number of iterations (without reuse) per execution of Succesive Halving (n,r)\n",
- "\n",
- "#echo output\n",
- "print (\"max_iter = \" + str(max_iter))\n",
- "print (\"eta = \" + str(eta))\n",
- "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
- "\n",
- "sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
- "\n",
- "#### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
- "for s in reversed(range(s_max+1)):\n",
- " \n",
- " print (\" \")\n",
- " print (\"s=\" + str(s))\n",
- " print (\"n_i r_i\")\n",
- " print (\"------------\")\n",
- " counter = 0\n",
- " \n",
- " n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
- " r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
- "\n",
- " #### Begin Finite Horizon Successive Halving with (n,r)\n",
- " #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
- " for i in range(s+1):\n",
- " # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
- " n_i = n*eta**(-i)\n",
- " r_i = r*eta**(i)\n",
- " \n",
- " print (str(n_i) + \" \" + str (r_i))\n",
- " \n",
- " # check if leaf node for this s\n",
- " if counter == s:\n",
- " sum_leaf_n_i += n_i\n",
- " counter += 1\n",
- " \n",
- " #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
- " #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
- " #### End Finite Horizon Successive Halving with (n,r)\n",
- "\n",
- "print (\" \")\n",
- "print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
- "print (\"(if have more workers than this, they may not be 100% busy)\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Create tables"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done.\n",
- "Done.\n",
- "Done.\n",
- "Done.\n",
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "[]"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%%sql\n",
- "-- overall results table\n",
- "DROP TABLE IF EXISTS results;\n",
- "CREATE TABLE results ( \n",
- " model_id INTEGER, \n",
- " compile_params TEXT,\n",
- " fit_params TEXT, \n",
- " model_type TEXT, \n",
- " model_size DOUBLE PRECISION, \n",
- " metrics_elapsed_time DOUBLE PRECISION[], \n",
- " metrics_type TEXT[], \n",
- " training_metrics_final DOUBLE PRECISION, \n",
- " training_loss_final DOUBLE PRECISION, \n",
- " training_metrics DOUBLE PRECISION[], \n",
- " training_loss DOUBLE PRECISION[], \n",
- " validation_metrics_final DOUBLE PRECISION, \n",
- " validation_loss_final DOUBLE PRECISION, \n",
- " validation_metrics DOUBLE PRECISION[], \n",
- " validation_loss DOUBLE PRECISION[], \n",
- " model_arch_table TEXT, \n",
- " num_iterations INTEGER, \n",
- " start_training_time TIMESTAMP, \n",
- " end_training_time TIMESTAMP,\n",
- " s INTEGER, \n",
- " n INTEGER, \n",
- " r INTEGER,\n",
- " run_id SERIAL\n",
- " );\n",
- "\n",
- "-- model selection table\n",
- "DROP TABLE IF EXISTS mst_table_hb, mst_table_auto_hb;\n",
- "CREATE TABLE mst_table_hb (\n",
- " mst_key SERIAL, \n",
- " model_id INTEGER, \n",
- " compile_params VARCHAR, \n",
- " fit_params VARCHAR\n",
- " );\n",
- "\n",
- "-- model selection summary table\n",
- "DROP TABLE IF EXISTS mst_table_hb_summary;\n",
- "CREATE TABLE mst_table_hb_summary (model_arch_table varchar);\n",
- "INSERT INTO mst_table_hb_summary VALUES ('model_arch_library');"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Hyperband main "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "from random import random\n",
- "from math import log, ceil\n",
- "from time import time, ctime\n",
- "\n",
- "\n",
- "class Hyperband:\n",
- " \n",
- " def __init__( self, get_params_function, try_params_function ):\n",
- " self.get_params = get_params_function\n",
- " self.try_params = try_params_function\n",
- "\n",
- " self.max_iter = 3 # maximum iterations per configuration\n",
- " self.eta = 3 # defines configuration downsampling rate (default = 3)\n",
- "\n",
- " self.logeta = lambda x: log( x ) / log( self.eta )\n",
- " self.s_max = int( self.logeta( self.max_iter ))\n",
- " self.B = ( self.s_max + 1 ) * self.max_iter\n",
- "\n",
- " self.results = [] # list of dicts\n",
- " self.counter = 0\n",
- " self.best_loss = np.inf\n",
- " self.best_counter = -1\n",
- "\n",
- " # can be called multiple times\n",
- " def run( self, skip_last = 0, dry_run = False ):\n",
- "\n",
- " for s in reversed( range( self.s_max + 1 )):\n",
- " \n",
- " # initial number of configurations\n",
- " n = int( ceil( self.B / self.max_iter / ( s + 1 ) * self.eta ** s ))\n",
- "\n",
- " # initial number of iterations per config\n",
- " r = self.max_iter * self.eta ** ( -s )\n",
- " \n",
- " print (\"s = \", s)\n",
- " print (\"n = \", n)\n",
- " print (\"r = \", r)\n",
- "\n",
- " # n random configurations\n",
- " T = self.get_params(n) # what to return from function if anything?\n",
- " \n",
- " for i in range(( s + 1 ) - int( skip_last )): # changed from s + 1\n",
- "\n",
- " # Run each of the n configs for <iterations>\n",
- " # and keep best (n_configs / eta) configurations\n",
- "\n",
- " n_configs = n * self.eta ** ( -i )\n",
- " n_iterations = r * self.eta ** ( i )\n",
- "\n",
- " print \"\\n*** {} configurations x {:.1f} iterations each\".format(\n",
- " n_configs, n_iterations )\n",
- " \n",
- " # multi-model training\n",
- " U = self.try_params(s, n_configs, n_iterations) # what to return from function if anything?\n",
- "\n",
- " # select a number of best configurations for the next loop\n",
- " # filter out early stops, if any\n",
- " # drop from model selection table, model table and info table to keep all in sync\n",
- " k = int( n_configs / self.eta)\n",
- " \n",
- " %sql DELETE FROM iris_multi_model_info WHERE mst_key NOT IN (SELECT mst_key FROM iris_multi_model_info ORDER BY training_loss_final ASC LIMIT $k::INT);\n",
- " %sql DELETE FROM iris_multi_model WHERE mst_key NOT IN (SELECT mst_key FROM iris_multi_model_info);\n",
- " %sql DELETE FROM mst_table_hb WHERE mst_key NOT IN (SELECT mst_key FROM iris_multi_model_info);\n",
- "\n",
- " #return self.results\n",
- " \n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_params(n):\n",
- " \n",
- " from sklearn.model_selection import ParameterSampler\n",
- " from scipy.stats.distributions import uniform\n",
- " import numpy as np\n",
- " \n",
- " # model architecture\n",
- " model_id = [1, 2]\n",
- "\n",
- " # compile params\n",
- " # loss function\n",
- " loss = ['categorical_crossentropy']\n",
- " # optimizer\n",
- " optimizer = ['Adam', 'SGD']\n",
- " # learning rate (sample on log scale here not in ParameterSampler)\n",
- " lr_range = [0.01, 0.1]\n",
- " lr = 10**np.random.uniform(np.log10(lr_range[0]), np.log10(lr_range[1]), n)\n",
- " # metrics\n",
- " metrics = ['accuracy']\n",
- "\n",
- " # fit params\n",
- " # batch size\n",
- " batch_size = [4, 8]\n",
- " # epochs\n",
- " epochs = [1]\n",
- "\n",
- " # create random param list\n",
- " param_grid = {\n",
- " 'model_id': model_id,\n",
- " 'loss': loss,\n",
- " 'optimizer': optimizer,\n",
- " 'lr': lr,\n",
- " 'metrics': metrics,\n",
- " 'batch_size': batch_size,\n",
- " 'epochs': epochs\n",
- " }\n",
- " param_list = list(ParameterSampler(param_grid, n_iter=n))\n",
- "\n",
- " for params in param_list:\n",
- "\n",
- " model_id = str(params.get(\"model_id\"))\n",
- " compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
- " fit_params = \"$$batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\")) + \"$$\" \n",
- " row_content = \"(\" + model_id + \", \" + compile_params + \", \" + fit_params + \");\"\n",
- " \n",
- " %sql INSERT INTO mst_table_hb (model_id, compile_params, fit_params) VALUES $row_content\n",
- " \n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "def try_params(s, n_configs, n_iterations):\n",
- "\n",
- " # multi-model fit\n",
- " # TO DO: use warm start to continue from where left off after if not 1st time thru for this s value\n",
- " %sql DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;\n",
- " %sql SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed', 'iris_multi_model', 'mst_table_hb', $n_iterations::INT, 0);\n",
- " \n",
- " # save results\n",
- " %sql DROP TABLE IF EXISTS temp_results;\n",
- " %sql CREATE TABLE temp_results AS (SELECT * FROM iris_multi_model_info);\n",
- " %sql ALTER TABLE temp_results DROP COLUMN mst_key, ADD COLUMN model_arch_table TEXT, ADD COLUMN num_iterations INTEGER, ADD COLUMN start_training_time TIMESTAMP, ADD COLUMN end_training_time TIMESTAMP, ADD COLUMN s INTEGER, ADD COLUMN n INTEGER, ADD COLUMN r INTEGER;\n",
- " %sql UPDATE temp_results SET model_arch_table = (SELECT model_arch_table FROM iris_multi_model_summary), num_iterations = (SELECT num_iterations FROM iris_multi_model_summary), start_training_time = (SELECT start_training_time FROM iris_multi_model_summary), end_training_time = (SELECT end_training_time FROM iris_multi_model_summary), s = $s, n = $n_configs, r = $n_iterations;\n",
- " %sql INSERT INTO results (SELECT * FROM temp_results);\n",
- "\n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "('s = ', 1)\n",
- "('n = ', 3)\n",
- "('r = ', 1.0)\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "\n",
- "*** 3 configurations x 1.0 iterations each\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "3 rows affected.\n",
- "Done.\n",
- "3 rows affected.\n",
- "3 rows affected.\n",
- "2 rows affected.\n",
- "2 rows affected.\n",
- "2 rows affected.\n",
- "\n",
- "*** 1.0 configurations x 3.0 iterations each\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "('s = ', 0)\n",
- "('n = ', 2)\n",
- "('r = ', 3)\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "\n",
- "*** 2 configurations x 3.0 iterations each\n",
- "Done.\n",
- "1 rows affected.\n",
- "Done.\n",
- "2 rows affected.\n",
- "Done.\n",
- "2 rows affected.\n",
- "2 rows affected.\n",
- "2 rows affected.\n",
- "2 rows affected.\n",
- "2 rows affected.\n"
- ]
- }
- ],
- "source": [
- "hp = Hyperband( get_params, try_params )\n",
- "results = hp.run()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Plot results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "%matplotlib notebook\n",
- "import matplotlib.pyplot as plt\n",
- "from collections import defaultdict\n",
- "import pandas as pd\n",
- "import seaborn as sns\n",
- "sns.set_palette(sns.color_palette(\"hls\", 20))\n",
- "plt.rcParams.update({'font.size': 12})\n",
- "pd.set_option('display.max_colwidth', -1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0 rows affected.\n"
- ]
- },
- {
- "data": {
- "application/javascript": [
- "/* Put everything inside the global mpl namespace */\n",
- "window.mpl = {};\n",
- "\n",
- "\n",
- "mpl.get_websocket_type = function() {\n",
- " if (typeof(WebSocket) !== 'undefined') {\n",
- " return WebSocket;\n",
- " } else if (typeof(MozWebSocket) !== 'undefined') {\n",
- " return MozWebSocket;\n",
- " } else {\n",
- " alert('Your browser does not have WebSocket support.' +\n",
- " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
- " 'Firefox 4 and 5 are also supported but you ' +\n",
- " 'have to enable WebSockets in about:config.');\n",
- " };\n",
- "}\n",
- "\n",
- "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
- " this.id = figure_id;\n",
- "\n",
- " this.ws = websocket;\n",
- "\n",
- " this.supports_binary = (this.ws.binaryType != undefined);\n",
- "\n",
- " if (!this.supports_binary) {\n",
- " var warnings = document.getElementById(\"mpl-warnings\");\n",
- " if (warnings) {\n",
- " warnings.style.display = 'block';\n",
- " warnings.textContent = (\n",
- " \"This browser does not support binary websocket messages. \" +\n",
- " \"Performance may be slow.\");\n",
- " }\n",
- " }\n",
- "\n",
- " this.imageObj = new Image();\n",
- "\n",
- " this.context = undefined;\n",
- " this.message = undefined;\n",
- " this.canvas = undefined;\n",
- " this.rubberband_canvas = undefined;\n",
- " this.rubberband_context = undefined;\n",
- " this.format_dropdown = undefined;\n",
- "\n",
- " this.image_mode = 'full';\n",
- "\n",
- " this.root = $('<div/>');\n",
- " this._root_extra_style(this.root)\n",
- " this.root.attr('style', 'display: inline-block');\n",
- "\n",
- " $(parent_element).append(this.root);\n",
- "\n",
- " this._init_header(this);\n",
- " this._init_canvas(this);\n",
- " this._init_toolbar(this);\n",
- "\n",
- " var fig = this;\n",
- "\n",
- " this.waiting = false;\n",
- "\n",
- " this.ws.onopen = function () {\n",
- " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
- " fig.send_message(\"send_image_mode\", {});\n",
- " if (mpl.ratio != 1) {\n",
- " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
- " }\n",
- " fig.send_message(\"refresh\", {});\n",
- " }\n",
- "\n",
- " this.imageObj.onload = function() {\n",
- " if (fig.image_mode == 'full') {\n",
- " // Full images could contain transparency (where diff images\n",
- " // almost always do), so we need to clear the canvas so that\n",
- " // there is no ghosting.\n",
- " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
- " }\n",
- " fig.context.drawImage(fig.imageObj, 0, 0);\n",
- " };\n",
- "\n",
- " this.imageObj.onunload = function() {\n",
- " fig.ws.close();\n",
- " }\n",
- "\n",
- " this.ws.onmessage = this._make_on_message_function(this);\n",
- "\n",
- " this.ondownload = ondownload;\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._init_header = function() {\n",
- " var titlebar = $(\n",
- " '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
- " 'ui-helper-clearfix\"/>');\n",
- " var titletext = $(\n",
- " '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
- " 'text-align: center; padding: 3px;\"/>');\n",
- " titlebar.append(titletext)\n",
- " this.root.append(titlebar);\n",
- " this.header = titletext[0];\n",
- "}\n",
- "\n",
- "\n",
- "\n",
- "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
- "\n",
- "}\n",
- "\n",
- "\n",
- "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
- "\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._init_canvas = function() {\n",
- " var fig = this;\n",
- "\n",
- " var canvas_div = $('<div/>');\n",
- "\n",
- " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
- "\n",
- " function canvas_keyboard_event(event) {\n",
- " return fig.key_event(event, event['data']);\n",
- " }\n",
- "\n",
- " canvas_div.keydown('key_press', canvas_keyboard_event);\n",
- " canvas_div.keyup('key_release', canvas_keyboard_event);\n",
- " this.canvas_div = canvas_div\n",
- " this._canvas_extra_style(canvas_div)\n",
- " this.root.append(canvas_div);\n",
- "\n",
- " var canvas = $('<canvas/>');\n",
- " canvas.addClass('mpl-canvas');\n",
- " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
- "\n",
- " this.canvas = canvas[0];\n",
- " this.context = canvas[0].getContext(\"2d\");\n",
- "\n",
- " var backingStore = this.context.backingStorePixelRatio ||\n",
- "\tthis.context.webkitBackingStorePixelRatio ||\n",
- "\tthis.context.mozBackingStorePixelRatio ||\n",
- "\tthis.context.msBackingStorePixelRatio ||\n",
- "\tthis.context.oBackingStorePixelRatio ||\n",
- "\tthis.context.backingStorePixelRatio || 1;\n",
- "\n",
- " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
- "\n",
- " var rubberband = $('<canvas/>');\n",
- " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
- "\n",
- " var pass_mouse_events = true;\n",
- "\n",
- " canvas_div.resizable({\n",
- " start: function(event, ui) {\n",
- " pass_mouse_events = false;\n",
- " },\n",
- " resize: function(event, ui) {\n",
- " fig.request_resize(ui.size.width, ui.size.height);\n",
- " },\n",
- " stop: function(event, ui) {\n",
- " pass_mouse_events = true;\n",
- " fig.request_resize(ui.size.width, ui.size.height);\n",
- " },\n",
- " });\n",
- "\n",
- " function mouse_event_fn(event) {\n",
- " if (pass_mouse_events)\n",
- " return fig.mouse_event(event, event['data']);\n",
- " }\n",
- "\n",
- " rubberband.mousedown('button_press', mouse_event_fn);\n",
- " rubberband.mouseup('button_release', mouse_event_fn);\n",
- " // Throttle sequential mouse events to 1 every 20ms.\n",
- " rubberband.mousemove('motion_notify', mouse_event_fn);\n",
- "\n",
- " rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
- " rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
- "\n",
- " canvas_div.on(\"wheel\", function (event) {\n",
- " event = event.originalEvent;\n",
- " event['data'] = 'scroll'\n",
- " if (event.deltaY < 0) {\n",
- " event.step = 1;\n",
- " } else {\n",
- " event.step = -1;\n",
- " }\n",
- " mouse_event_fn(event);\n",
- " });\n",
- "\n",
- " canvas_div.append(canvas);\n",
- " canvas_div.append(rubberband);\n",
- "\n",
- " this.rubberband = rubberband;\n",
- " this.rubberband_canvas = rubberband[0];\n",
- " this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
- " this.rubberband_context.strokeStyle = \"#000000\";\n",
- "\n",
- " this._resize_canvas = function(width, height) {\n",
- " // Keep the size of the canvas, canvas container, and rubber band\n",
- " // canvas in synch.\n",
- " canvas_div.css('width', width)\n",
- " canvas_div.css('height', height)\n",
- "\n",
- " canvas.attr('width', width * mpl.ratio);\n",
- " canvas.attr('height', height * mpl.ratio);\n",
- " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
- "\n",
- " rubberband.attr('width', width);\n",
- " rubberband.attr('height', height);\n",
- " }\n",
- "\n",
- " // Set the figure to an initial 600x600px, this will subsequently be updated\n",
- " // upon first draw.\n",
- " this._resize_canvas(600, 600);\n",
- "\n",
- " // Disable right mouse context menu.\n",
- " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
- " return false;\n",
- " });\n",
- "\n",
- " function set_focus () {\n",
- " canvas.focus();\n",
- " canvas_div.focus();\n",
- " }\n",
- "\n",
- " window.setTimeout(set_focus, 100);\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._init_toolbar = function() {\n",
- " var fig = this;\n",
- "\n",
- " var nav_element = $('<div/>')\n",
- " nav_element.attr('style', 'width: 100%');\n",
- " this.root.append(nav_element);\n",
- "\n",
- " // Define a callback function for later on.\n",
- " function toolbar_event(event) {\n",
- " return fig.toolbar_button_onclick(event['data']);\n",
- " }\n",
- " function toolbar_mouse_event(event) {\n",
- " return fig.toolbar_button_onmouseover(event['data']);\n",
- " }\n",
- "\n",
- " for(var toolbar_ind in mpl.toolbar_items) {\n",
- " var name = mpl.toolbar_items[toolbar_ind][0];\n",
- " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
- " var image = mpl.toolbar_items[toolbar_ind][2];\n",
- " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
- "\n",
- " if (!name) {\n",
- " // put a spacer in here.\n",
- " continue;\n",
- " }\n",
- " var button = $('<button/>');\n",
- " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
- " 'ui-button-icon-only');\n",
- " button.attr('role', 'button');\n",
- " button.attr('aria-disabled', 'false');\n",
- " button.click(method_name, toolbar_event);\n",
- " button.mouseover(tooltip, toolbar_mouse_event);\n",
- "\n",
- " var icon_img = $('<span/>');\n",
- " icon_img.addClass('ui-button-icon-primary ui-icon');\n",
- " icon_img.addClass(image);\n",
- " icon_img.addClass('ui-corner-all');\n",
- "\n",
- " var tooltip_span = $('<span/>');\n",
- " tooltip_span.addClass('ui-button-text');\n",
- " tooltip_span.html(tooltip);\n",
- "\n",
- " button.append(icon_img);\n",
- " button.append(tooltip_span);\n",
- "\n",
- " nav_element.append(button);\n",
- " }\n",
- "\n",
- " var fmt_picker_span = $('<span/>');\n",
- "\n",
- " var fmt_picker = $('<select/>');\n",
- " fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
- " fmt_picker_span.append(fmt_picker);\n",
- " nav_element.append(fmt_picker_span);\n",
- " this.format_dropdown = fmt_picker[0];\n",
- "\n",
- " for (var ind in mpl.extensions) {\n",
- " var fmt = mpl.extensions[ind];\n",
- " var option = $(\n",
- " '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
- " fmt_picker.append(option)\n",
- " }\n",
- "\n",
- " // Add hover states to the ui-buttons\n",
- " $( \".ui-button\" ).hover(\n",
- " function() { $(this).addClass(\"ui-state-hover\");},\n",
- " function() { $(this).removeClass(\"ui-state-hover\");}\n",
- " );\n",
- "\n",
- " var status_bar = $('<span class=\"mpl-message\"/>');\n",
- " nav_element.append(status_bar);\n",
- " this.message = status_bar[0];\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
- " // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
- " // which will in turn request a refresh of the image.\n",
- " this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.send_message = function(type, properties) {\n",
- " properties['type'] = type;\n",
- " properties['figure_id'] = this.id;\n",
- " this.ws.send(JSON.stringify(properties));\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.send_draw_message = function() {\n",
- " if (!this.waiting) {\n",
- " this.waiting = true;\n",
- " this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
- " }\n",
- "}\n",
- "\n",
- "\n",
- "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
- " var format_dropdown = fig.format_dropdown;\n",
- " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
- " fig.ondownload(fig, format);\n",
- "}\n",
- "\n",
- "\n",
- "mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
- " var size = msg['size'];\n",
- " if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
- " fig._resize_canvas(size[0], size[1]);\n",
- " fig.send_message(\"refresh\", {});\n",
- " };\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
- " var x0 = msg['x0'] / mpl.ratio;\n",
- " var y0 = (fig.canvas.height - msg['y0']) / mpl.ratio;\n",
- " var x1 = msg['x1'] / mpl.ratio;\n",
- " var y1 = (fig.canvas.height - msg['y1']) / mpl.ratio;\n",
- " x0 = Math.floor(x0) + 0.5;\n",
- " y0 = Math.floor(y0) + 0.5;\n",
- " x1 = Math.floor(x1) + 0.5;\n",
- " y1 = Math.floor(y1) + 0.5;\n",
- " var min_x = Math.min(x0, x1);\n",
- " var min_y = Math.min(y0, y1);\n",
- " var width = Math.abs(x1 - x0);\n",
- " var height = Math.abs(y1 - y0);\n",
- "\n",
- " fig.rubberband_context.clearRect(\n",
- " 0, 0, fig.canvas.width, fig.canvas.height);\n",
- "\n",
- " fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
- " // Updates the figure title.\n",
- " fig.header.textContent = msg['label'];\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
- " var cursor = msg['cursor'];\n",
- " switch(cursor)\n",
- " {\n",
- " case 0:\n",
- " cursor = 'pointer';\n",
- " break;\n",
- " case 1:\n",
- " cursor = 'default';\n",
- " break;\n",
- " case 2:\n",
- " cursor = 'crosshair';\n",
- " break;\n",
- " case 3:\n",
- " cursor = 'move';\n",
- " break;\n",
- " }\n",
- " fig.rubberband_canvas.style.cursor = cursor;\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.handle_message = function(fig, msg) {\n",
- " fig.message.textContent = msg['message'];\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
- " // Request the server to send over a new figure.\n",
- " fig.send_draw_message();\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
- " fig.image_mode = msg['mode'];\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.updated_canvas_event = function() {\n",
- " // Called whenever the canvas gets updated.\n",
- " this.send_message(\"ack\", {});\n",
- "}\n",
- "\n",
- "// A function to construct a web socket function for onmessage handling.\n",
- "// Called in the figure constructor.\n",
- "mpl.figure.prototype._make_on_message_function = function(fig) {\n",
- " return function socket_on_message(evt) {\n",
- " if (evt.data instanceof Blob) {\n",
- " /* FIXME: We get \"Resource interpreted as Image but\n",
- " * transferred with MIME type text/plain:\" errors on\n",
- " * Chrome. But how to set the MIME type? It doesn't seem\n",
- " * to be part of the websocket stream */\n",
- " evt.data.type = \"image/png\";\n",
- "\n",
- " /* Free the memory for the previous frames */\n",
- " if (fig.imageObj.src) {\n",
- " (window.URL || window.webkitURL).revokeObjectURL(\n",
- " fig.imageObj.src);\n",
- " }\n",
- "\n",
- " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
- " evt.data);\n",
- " fig.updated_canvas_event();\n",
- " fig.waiting = false;\n",
- " return;\n",
- " }\n",
- " else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
- " fig.imageObj.src = evt.data;\n",
- " fig.updated_canvas_event();\n",
- " fig.waiting = false;\n",
- " return;\n",
- " }\n",
- "\n",
- " var msg = JSON.parse(evt.data);\n",
- " var msg_type = msg['type'];\n",
- "\n",
- " // Call the \"handle_{type}\" callback, which takes\n",
- " // the figure and JSON message as its only arguments.\n",
- " try {\n",
- " var callback = fig[\"handle_\" + msg_type];\n",
- " } catch (e) {\n",
- " console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
- " return;\n",
- " }\n",
- "\n",
- " if (callback) {\n",
- " try {\n",
- " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
- " callback(fig, msg);\n",
- " } catch (e) {\n",
- " console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
- " }\n",
- " }\n",
- " };\n",
- "}\n",
- "\n",
- "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
- "mpl.findpos = function(e) {\n",
- " //this section is from http://www.quirksmode.org/js/events_properties.html\n",
- " var targ;\n",
- " if (!e)\n",
- " e = window.event;\n",
- " if (e.target)\n",
- " targ = e.target;\n",
- " else if (e.srcElement)\n",
- " targ = e.srcElement;\n",
- " if (targ.nodeType == 3) // defeat Safari bug\n",
- " targ = targ.parentNode;\n",
- "\n",
- " // jQuery normalizes the pageX and pageY\n",
- " // pageX,Y are the mouse positions relative to the document\n",
- " // offset() returns the position of the element relative to the document\n",
- " var x = e.pageX - $(targ).offset().left;\n",
- " var y = e.pageY - $(targ).offset().top;\n",
- "\n",
- " return {\"x\": x, \"y\": y};\n",
- "};\n",
- "\n",
- "/*\n",
- " * return a copy of an object with only non-object keys\n",
- " * we need this to avoid circular references\n",
- " * http://stackoverflow.com/a/24161582/3208463\n",
- " */\n",
- "function simpleKeys (original) {\n",
- " return Object.keys(original).reduce(function (obj, key) {\n",
- " if (typeof original[key] !== 'object')\n",
- " obj[key] = original[key]\n",
- " return obj;\n",
- " }, {});\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.mouse_event = function(event, name) {\n",
- " var canvas_pos = mpl.findpos(event)\n",
- "\n",
- " if (name === 'button_press')\n",
- " {\n",
- " this.canvas.focus();\n",
- " this.canvas_div.focus();\n",
- " }\n",
- "\n",
- " var x = canvas_pos.x * mpl.ratio;\n",
- " var y = canvas_pos.y * mpl.ratio;\n",
- "\n",
- " this.send_message(name, {x: x, y: y, button: event.button,\n",
- " step: event.step,\n",
- " guiEvent: simpleKeys(event)});\n",
- "\n",
- " /* This prevents the web browser from automatically changing to\n",
- " * the text insertion cursor when the button is pressed. We want\n",
- " * to control all of the cursor setting manually through the\n",
- " * 'cursor' event from matplotlib */\n",
- " event.preventDefault();\n",
- " return false;\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
- " // Handle any extra behaviour associated with a key event\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.key_event = function(event, name) {\n",
- "\n",
- " // Prevent repeat events\n",
- " if (name == 'key_press')\n",
- " {\n",
- " if (event.which === this._key)\n",
- " return;\n",
- " else\n",
- " this._key = event.which;\n",
- " }\n",
- " if (name == 'key_release')\n",
- " this._key = null;\n",
- "\n",
- " var value = '';\n",
- " if (event.ctrlKey && event.which != 17)\n",
- " value += \"ctrl+\";\n",
- " if (event.altKey && event.which != 18)\n",
- " value += \"alt+\";\n",
- " if (event.shiftKey && event.which != 16)\n",
- " value += \"shift+\";\n",
- "\n",
- " value += 'k';\n",
- " value += event.which.toString();\n",
- "\n",
- " this._key_event_extra(event, name);\n",
- "\n",
- " this.send_message(name, {key: value,\n",
- " guiEvent: simpleKeys(event)});\n",
- " return false;\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
- " if (name == 'download') {\n",
- " this.handle_save(this, null);\n",
- " } else {\n",
- " this.send_message(\"toolbar_button\", {name: name});\n",
- " }\n",
- "};\n",
- "\n",
- "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
- " this.message.textContent = tooltip;\n",
- "};\n",
- "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\" [...]
- "\n",
- "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
- "\n",
- "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
- " // Create a \"websocket\"-like object which calls the given IPython comm\n",
- " // object with the appropriate methods. Currently this is a non binary\n",
- " // socket, so there is still some room for performance tuning.\n",
- " var ws = {};\n",
- "\n",
- " ws.close = function() {\n",
- " comm.close()\n",
- " };\n",
- " ws.send = function(m) {\n",
- " //console.log('sending', m);\n",
- " comm.send(m);\n",
- " };\n",
- " // Register the callback with on_msg.\n",
- " comm.on_msg(function(msg) {\n",
- " //console.log('receiving', msg['content']['data'], msg);\n",
- " // Pass the mpl event to the overridden (by mpl) onmessage function.\n",
- " ws.onmessage(msg['content']['data'])\n",
- " });\n",
- " return ws;\n",
- "}\n",
- "\n",
- "mpl.mpl_figure_comm = function(comm, msg) {\n",
- " // This is the function which gets called when the mpl process\n",
- " // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
- "\n",
- " var id = msg.content.data.id;\n",
- " // Get hold of the div created by the display call when the Comm\n",
- " // socket was opened in Python.\n",
- " var element = $(\"#\" + id);\n",
- " var ws_proxy = comm_websocket_adapter(comm)\n",
- "\n",
- " function ondownload(figure, format) {\n",
- " window.open(figure.imageObj.src);\n",
- " }\n",
- "\n",
- " var fig = new mpl.figure(id, ws_proxy,\n",
- " ondownload,\n",
- " element.get(0));\n",
- "\n",
- " // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
- " // web socket which is closed, not our websocket->open comm proxy.\n",
- " ws_proxy.onopen();\n",
- "\n",
- " fig.parent_element = element.get(0);\n",
- " fig.cell_info = mpl.find_output_cell(\"<div id='\" + id + \"'></div>\");\n",
- " if (!fig.cell_info) {\n",
- " console.error(\"Failed to find cell for figure\", id, fig);\n",
- " return;\n",
- " }\n",
- "\n",
- " var output_index = fig.cell_info[2]\n",
- " var cell = fig.cell_info[0];\n",
- "\n",
- "};\n",
- "\n",
- "mpl.figure.prototype.handle_close = function(fig, msg) {\n",
- " var width = fig.canvas.width/mpl.ratio\n",
- " fig.root.unbind('remove')\n",
- "\n",
- " // Update the output cell to use the data from the current canvas.\n",
- " fig.push_to_output();\n",
- " var dataURL = fig.canvas.toDataURL();\n",
- " // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
- " // the notebook keyboard shortcuts fail.\n",
- " IPython.keyboard_manager.enable()\n",
- " $(fig.parent_element).html('<img src=\"' + dataURL + '\" width=\"' + width + '\">');\n",
- " fig.close_ws(fig, msg);\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.close_ws = function(fig, msg){\n",
- " fig.send_message('closing', msg);\n",
- " // fig.ws.close()\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
- " // Turn the data on the canvas into data in the output cell.\n",
- " var width = this.canvas.width/mpl.ratio\n",
- " var dataURL = this.canvas.toDataURL();\n",
- " this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\" width=\"' + width + '\">';\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.updated_canvas_event = function() {\n",
- " // Tell IPython that the notebook contents must change.\n",
- " IPython.notebook.set_dirty(true);\n",
- " this.send_message(\"ack\", {});\n",
- " var fig = this;\n",
- " // Wait a second, then push the new image to the DOM so\n",
- " // that it is saved nicely (might be nice to debounce this).\n",
- " setTimeout(function () { fig.push_to_output() }, 1000);\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._init_toolbar = function() {\n",
- " var fig = this;\n",
- "\n",
- " var nav_element = $('<div/>')\n",
- " nav_element.attr('style', 'width: 100%');\n",
- " this.root.append(nav_element);\n",
- "\n",
- " // Define a callback function for later on.\n",
- " function toolbar_event(event) {\n",
- " return fig.toolbar_button_onclick(event['data']);\n",
- " }\n",
- " function toolbar_mouse_event(event) {\n",
- " return fig.toolbar_button_onmouseover(event['data']);\n",
- " }\n",
- "\n",
- " for(var toolbar_ind in mpl.toolbar_items){\n",
- " var name = mpl.toolbar_items[toolbar_ind][0];\n",
- " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
- " var image = mpl.toolbar_items[toolbar_ind][2];\n",
- " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
- "\n",
- " if (!name) { continue; };\n",
- "\n",
- " var button = $('<button class=\"btn btn-default\" href=\"#\" title=\"' + name + '\"><i class=\"fa ' + image + ' fa-lg\"></i></button>');\n",
- " button.click(method_name, toolbar_event);\n",
- " button.mouseover(tooltip, toolbar_mouse_event);\n",
- " nav_element.append(button);\n",
- " }\n",
- "\n",
- " // Add the status bar.\n",
- " var status_bar = $('<span class=\"mpl-message\" style=\"text-align:right; float: right;\"/>');\n",
- " nav_element.append(status_bar);\n",
- " this.message = status_bar[0];\n",
- "\n",
- " // Add the close button to the window.\n",
- " var buttongrp = $('<div class=\"btn-group inline pull-right\"></div>');\n",
- " var button = $('<button class=\"btn btn-mini btn-primary\" href=\"#\" title=\"Stop Interaction\"><i class=\"fa fa-power-off icon-remove icon-large\"></i></button>');\n",
- " button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
- " button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
- " buttongrp.append(button);\n",
- " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
- " titlebar.prepend(buttongrp);\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._root_extra_style = function(el){\n",
- " var fig = this\n",
- " el.on(\"remove\", function(){\n",
- "\tfig.close_ws(fig, {});\n",
- " });\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._canvas_extra_style = function(el){\n",
- " // this is important to make the div 'focusable\n",
- " el.attr('tabindex', 0)\n",
- " // reach out to IPython and tell the keyboard manager to turn it's self\n",
- " // off when our div gets focus\n",
- "\n",
- " // location in version 3\n",
- " if (IPython.notebook.keyboard_manager) {\n",
- " IPython.notebook.keyboard_manager.register_events(el);\n",
- " }\n",
- " else {\n",
- " // location in version 2\n",
- " IPython.keyboard_manager.register_events(el);\n",
- " }\n",
- "\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
- " var manager = IPython.notebook.keyboard_manager;\n",
- " if (!manager)\n",
- " manager = IPython.keyboard_manager;\n",
- "\n",
- " // Check for shift+enter\n",
- " if (event.shiftKey && event.which == 13) {\n",
- " this.canvas_div.blur();\n",
- " event.shiftKey = false;\n",
- " // Send a \"J\" for go to next cell\n",
- " event.which = 74;\n",
- " event.keyCode = 74;\n",
- " manager.command_mode();\n",
- " manager.handle_keydown(event);\n",
- " }\n",
- "}\n",
- "\n",
- "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
- " fig.ondownload(fig, null);\n",
- "}\n",
- "\n",
- "\n",
- "mpl.find_output_cell = function(html_output) {\n",
- " // Return the cell and output element which can be found *uniquely* in the notebook.\n",
- " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
- " // IPython event is triggered only after the cells have been serialised, which for\n",
- " // our purposes (turning an active figure into a static one), is too late.\n",
- " var cells = IPython.notebook.get_cells();\n",
- " var ncells = cells.length;\n",
- " for (var i=0; i<ncells; i++) {\n",
- " var cell = cells[i];\n",
- " if (cell.cell_type === 'code'){\n",
- " for (var j=0; j<cell.output_area.outputs.length; j++) {\n",
- " var data = cell.output_area.outputs[j];\n",
- " if (data.data) {\n",
- " // IPython >= 3 moved mimebundle to data attribute of output\n",
- " data = data.data;\n",
- " }\n",
- " if (data['text/html'] == html_output) {\n",
- " return [cell, data, j];\n",
- " }\n",
- " }\n",
- " }\n",
- " }\n",
- "}\n",
- "\n",
- "// Register the function which deals with the matplotlib target/channel.\n",
- "// The kernel may be null if the page has been refreshed.\n",
- "if (IPython.notebook.kernel != null) {\n",
- " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
- "}\n"
- ],
- "text/plain": [
- "<IPython.core.display.Javascript object>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABEwAAAImCAYAAABJvh+8AAAgAElEQVR4Xu3asQ0AMAzDsPb/pztXgy9gTiA0GbnHESBAgAABAgQIECBAgAABAgQIfAKXBwECBAgQIECAAAECBAgQIECAwC9gMFEEAQIECBAgQIAAAQIECBAgQCACBhNJECBAgAABAgQIECBAgAABAgQMJhogQIAAAQIECBAgQIAAAQIECGwBHyYKIUCAAAECBAgQIECAAAECBAhEwGAiCQIECBAgQIAAAQIECBAgQICAwUQDBAgQIECAAAECBAgQIECAAIEt4MNEIQQIECBAgAABAgQIECBAgACBCBhMJEGAAAECBAgQIECAAAECBAgQMJhogAABAgQIECBAgAABAgQIECCwBXyYKIQAAQIECBAgQIAAAQIECBAgEAGDiSQIECBAgAABAgQIECBAgAA [...]
- ],
- "text/plain": [
- "<IPython.core.display.HTML object>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "ename": "KeyError",
- "evalue": "'run_id'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-9-195f08b29212>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msubplots\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnrows\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mncols\u001b[0m\u [...]
- "\u001b[0;32m/Users/fmcquillan/anaconda/lib/python2.7/site-packages/pandas/core/frame.pyc\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 1967\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_multilevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1968\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001 [...]
- "\u001b[0;32m/Users/fmcquillan/anaconda/lib/python2.7/site-packages/pandas/core/frame.pyc\u001b[0m in \u001b[0;36m_getitem_column\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 1974\u001b[0m \u001b[0;31m# get column\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1975\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_unique\u001b[0m\u001b[0;34m:\u001b[0 [...]
- "\u001b[0;32m/Users/fmcquillan/anaconda/lib/python2.7/site-packages/pandas/core/generic.pyc\u001b[0m in \u001b[0;36m_get_item_cache\u001b[0;34m(self, item)\u001b[0m\n\u001b[1;32m 1089\u001b[0m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1090\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mres\ [...]
- "\u001b[0;32m/Users/fmcquillan/anaconda/lib/python2.7/site-packages/pandas/core/internals.pyc\u001b[0m in \u001b[0;36mget\u001b[0;34m(self, item, fastpath)\u001b[0m\n\u001b[1;32m 3209\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3210\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misnull\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3211\u001b [...]
- "\u001b[0;32m/Users/fmcquillan/anaconda/lib/python2.7/site-packages/pandas/core/index.pyc\u001b[0m in \u001b[0;36mget_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m 1757\u001b[0m 'backfill or nearest lookups')\n\u001b[1;32m 1758\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_values_from_object\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m [...]
- "\u001b[0;32mpandas/index.pyx\u001b[0m in \u001b[0;36mpandas.index.IndexEngine.get_loc (pandas/index.c:3979)\u001b[0;34m()\u001b[0m\n",
- "\u001b[0;32mpandas/index.pyx\u001b[0m in \u001b[0;36mpandas.index.IndexEngine.get_loc (pandas/index.c:3843)\u001b[0;34m()\u001b[0m\n",
- "\u001b[0;32mpandas/hashtable.pyx\u001b[0m in \u001b[0;36mpandas.hashtable.PyObjectHashTable.get_item (pandas/hashtable.c:12265)\u001b[0;34m()\u001b[0m\n",
- "\u001b[0;32mpandas/hashtable.pyx\u001b[0m in \u001b[0;36mpandas.hashtable.PyObjectHashTable.get_item (pandas/hashtable.c:12216)\u001b[0;34m()\u001b[0m\n",
- "\u001b[0;31mKeyError\u001b[0m: 'run_id'"
- ]
- }
- ],
- "source": [
- "output_root_name = 'results'\n",
- "df_results = %sql SELECT * FROM $output_root_name ORDER BY run_id;\n",
- "df_results = df_results.DataFrame()\n",
- "\n",
- "fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5))\n",
- "for run_id in df_results['run_id']:\n",
- " df_output_info = %sql SELECT training_metrics,training_loss FROM $output_root_name WHERE run_id = $run_id\n",
- " df_output_info = df_output_info.DataFrame()\n",
- " training_metrics = df_output_info['training_metrics'][0]\n",
- " training_loss = df_output_info['training_loss'][0]\n",
- " X = range(len(training_metrics))\n",
- " \n",
- " ax_metric = axs[0]\n",
- " ax_loss = axs[1]\n",
- " ax_metric.set_xticks(X[::1])\n",
- " ax_metric.plot(X, training_metrics, label=run_id)\n",
- " ax_metric.set_xlabel('Iteration')\n",
- " ax_metric.set_ylabel('Metric')\n",
- " ax_metric.set_title('Training metric curve')\n",
- "\n",
- " ax_loss.set_xticks(X[::1])\n",
- " ax_loss.plot(X, training_loss, label=run_id)\n",
- " ax_loss.set_xlabel('Iteration')\n",
- " ax_loss.set_ylabel('Loss')\n",
- " ax_loss.set_title('Training loss curve')\n",
- " \n",
- "fig.legend(ncol=4)\n",
- "fig.tight_layout()\n",
- "# fig.savefig('./lc_keras_fit.png', dpi = 300)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# ------------------ Scratch ---------------------"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "best_configs = %sql SELECT ARRAY(SELECT mst_key FROM iris_multi_model_info ORDER BY training_loss_final ASC LIMIT $k::INT);\n",
- " %sql DELETE FROM mst_table_hb WHERE mst_key NOT IN $best_configs;\n",
- " %sql DELETE FROM iris_multi_model WHERE mst_key NOT IN $best_configs;\n",
- " %sql DELETE FROM iris_multi_model_info WHERE mst_key NOT IN $best_configs;"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "('s = ', 4)\n",
- "('n = ', 81)\n",
- "('r = ', 1.0)\n",
- "\n",
- "*** 81 configurations x 1.0 iterations each\n",
- "\n",
- "1 | Mon Nov 4 11:31:06 2019 | lowest loss so far: inf (run -1)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "2 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.8345 (run 1)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "3 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.6510 (run 2)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "4 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "5 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "6 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "7 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "8 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "9 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "10 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "11 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "12 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "13 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "14 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "15 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "16 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "17 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "18 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "19 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "20 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "21 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "22 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "23 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "24 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "25 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "26 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "27 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "28 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "29 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "30 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "31 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "32 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "33 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "34 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "35 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "36 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "37 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "38 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "39 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "40 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "41 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "42 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "43 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "44 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "45 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "46 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "47 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "48 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "49 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "50 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "51 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "52 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "53 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "54 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "55 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "56 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "57 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "58 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "59 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0176 (run 3)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "60 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "61 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "62 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "63 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "64 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "65 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "66 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "67 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "68 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "69 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "70 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "71 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "72 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "73 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "74 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "75 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "76 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "77 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "78 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "79 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "80 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "81 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 27.0 configurations x 3.0 iterations each\n",
- "\n",
- "82 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "83 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "84 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "85 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "86 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "87 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "88 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "89 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "90 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "91 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "92 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "93 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "94 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "95 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "96 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "97 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "98 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "99 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "100 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "101 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "102 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "103 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "104 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "105 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "106 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "107 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "108 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 9.0 configurations x 9.0 iterations each\n",
- "\n",
- "109 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "110 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "111 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "112 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "113 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "114 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "115 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "116 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "117 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 3.0 configurations x 27.0 iterations each\n",
- "\n",
- "118 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "119 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "120 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 1.0 configurations x 81.0 iterations each\n",
- "\n",
- "121 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "('s = ', 3)\n",
- "('n = ', 27)\n",
- "('r = ', 3.0)\n",
- "\n",
- "*** 27 configurations x 3.0 iterations each\n",
- "\n",
- "122 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "123 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "124 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "125 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "126 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "127 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "128 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "129 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "130 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "131 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "132 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "133 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "134 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "135 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "136 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "137 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "138 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "139 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "140 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "141 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "142 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "143 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "144 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "145 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "146 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "147 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "148 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 9.0 configurations x 9.0 iterations each\n",
- "\n",
- "149 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "150 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "151 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "152 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "153 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "154 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "155 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "156 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "157 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 3.0 configurations x 27.0 iterations each\n",
- "\n",
- "158 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "159 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "160 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 1.0 configurations x 81.0 iterations each\n",
- "\n",
- "161 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "('s = ', 2)\n",
- "('n = ', 9)\n",
- "('r = ', 9.0)\n",
- "\n",
- "*** 9 configurations x 9.0 iterations each\n",
- "\n",
- "162 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "163 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "164 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "165 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "166 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "167 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "168 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "169 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "170 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 3.0 configurations x 27.0 iterations each\n",
- "\n",
- "171 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "172 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "173 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 1.0 configurations x 81.0 iterations each\n",
- "\n",
- "174 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "('s = ', 1)\n",
- "('n = ', 6)\n",
- "('r = ', 27.0)\n",
- "\n",
- "*** 6 configurations x 27.0 iterations each\n",
- "\n",
- "175 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "176 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "177 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "178 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "179 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "180 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "*** 2.0 configurations x 81.0 iterations each\n",
- "\n",
- "181 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "182 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "('s = ', 0)\n",
- "('n = ', 5)\n",
- "('r = ', 81)\n",
- "\n",
- "*** 5 configurations x 81.0 iterations each\n",
- "\n",
- "183 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "184 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "185 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "186 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n",
- "\n",
- "187 | Mon Nov 4 11:31:06 2019 | lowest loss so far: 0.0156 (run 59)\n",
- "\n",
- "\n",
- "0 seconds.\n"
- ]
- }
- ],
- "source": [
- "#!/usr/bin/env python\n",
- "\n",
- "\"bare-bones demonstration of using hyperband to tune sklearn GBT\"\n",
- "\n",
- "#from hyperband import Hyperband\n",
- "#from defs.gb import get_params, try_params\n",
- "\n",
- "hb = Hyperband( get_params, try_params )\n",
- "\n",
- "# no actual tuning, doesn't call try_params()\n",
- "results = hb.run( dry_run = True )\n",
- "\n",
- "#results = hb.run( skip_last = 1 ) # shorter run\n",
- "#results = hb.run()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "<table>\n",
- " <tr>\n",
- " <th>?column?</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>[5]</td>\n",
- " </tr>\n",
- "</table>"
- ],
- "text/plain": [
- "[([5],)]"
- ]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "best_configs"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 118,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'a': 2, 'b': 0.3388081749546307, 'c': 0.704635960884642},\n",
- " {'a': 1, 'b': 0.4904175136129263, 'c': 0.8971084273807718},\n",
- " {'a': 1, 'b': 1.2386463990117793, 'c': 0.21568311690580266},\n",
- " {'a': 1, 'b': 1.91007461806631, 'c': 0.17778124867596956},\n",
- " {'a': 1, 'b': 1.2563450220231427, 'c': 0.002076412746974121}]"
- ]
- },
- "execution_count": 118,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from sklearn.model_selection import ParameterSampler\n",
- "from scipy.stats.distributions import expon, uniform, lognorm\n",
- "import numpy as np\n",
- "#rng = np.random.RandomState()\n",
- "param_grid = {'a':[1, 2], 'b': expon(), 'c': uniform()}\n",
- "#param_list = list(ParameterSampler(param_grid, n_iter=5, random_state=rng))\n",
- "param_list = list(ParameterSampler(param_grid, n_iter=5))\n",
- "rounded_list = [dict((k, round(v, 6)) for (k, v) in d.items()) for d in param_list]\n",
- "param_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'a': 2, 'b': 0.37954129345633403, 'c': 0.3742154014629032},\n",
- " {'a': 2, 'b': 1.2830633021262747, 'c': 0.4373122879029032},\n",
- " {'a': 1, 'b': 0.22037072550727527, 'c': 0.26397341600176616},\n",
- " {'a': 1, 'b': 0.549444485603122, 'c': 0.8317686948528791},\n",
- " {'a': 1, 'b': 1.0567787144413414, 'c': 0.9560841093558743}]"
- ]
- },
- "execution_count": 33,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "param_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'a': 2.0, 'b': 0.379541, 'c': 0.374215},\n",
- " {'a': 2.0, 'b': 1.283063, 'c': 0.437312},\n",
- " {'a': 1.0, 'b': 0.220371, 'c': 0.263973},\n",
- " {'a': 1.0, 'b': 0.549444, 'c': 0.831769},\n",
- " {'a': 1.0, 'b': 1.056779, 'c': 0.956084}]"
- ]
- },
- "execution_count": 34,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "rounded_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 150,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'d': 2.9713720038716116},\n",
- " {'d': 10.275052606706604},\n",
- " {'d': 4.211836333907813},\n",
- " {'d': 3.6005371688499834},\n",
- " {'d': 14.68709362771547}]"
- ]
- },
- "execution_count": 150,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "#rng = np.random.RandomState(0)\n",
- "param_grid = {'d': lognorm(1, 2, 3)}\n",
- "#param_list = list(ParameterSampler(param_grid, n_iter=5, random_state=rng))\n",
- "param_list = list(ParameterSampler(param_grid, n_iter=5))\n",
- "param_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 197,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[0.20984381 1.42136262 0.81160104 0.038913 0.22006219 6.32888505\n",
- " 0.09113144]\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "[{'lr': 0.22006218937865862, 'optimizer': 'Adam'},\n",
- " {'lr': 1.4213626223774578, 'optimizer': 'SGD'},\n",
- " {'lr': 0.09113143553155685, 'optimizer': 'SGD'},\n",
- " {'lr': 0.09113143553155685, 'optimizer': 'Adam'},\n",
- " {'lr': 0.038913004274499154, 'optimizer': 'Adam'},\n",
- " {'lr': 0.038913004274499154, 'optimizer': 'SGD'},\n",
- " {'lr': 6.328885051196006, 'optimizer': 'Adam'}]"
- ]
- },
- "execution_count": 197,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from sklearn.model_selection import ParameterSampler\n",
- "from scipy.stats.distributions import uniform\n",
- "import numpy as np\n",
- "\n",
- "#s = np.random.uniform(-1,1,7)\n",
- "#print (s)\n",
- " \n",
- "# optimizer\n",
- "optimizer = ['Adam', 'SGD']\n",
- "# learning rate (log scale)\n",
- "lr_range = [0.001, 10]\n",
- "lr = 10**np.random.uniform(np.log10(lr_range[0]), np.log10(lr_range[1]), 7)\n",
- "print (lr)\n",
- "\n",
- "# create random param list\n",
- "param_grid = {\n",
- " 'optimizer': optimizer,\n",
- " 'lr': lr\n",
- "}\n",
- "param_list = list(ParameterSampler(param_grid, n_iter=7))\n",
- "param_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 266,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[{'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.07983433464722507,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 4,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.03805362658279962,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 1,\n",
- " 'optimizer': 'SGD'},\n",
- " {'batch_size': 4,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.09043633721868387,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.02775811670911417,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 1,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.104019113296403,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.06986494800074812,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'SGD'},\n",
- " {'batch_size': 4,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.010449656955883938,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 4,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.04915490422264339,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'SGD'},\n",
- " {'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.05257644929029893,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 1,\n",
- " 'optimizer': 'Adam'},\n",
- " {'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.02993608422766151,\n",
- " 'metrics': 'accuracy',\n",
- " 'model_id': 2,\n",
- " 'optimizer': 'SGD'}]"
- ]
- },
- "execution_count": 266,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# model architecture\n",
- "model_id = [1, 2]\n",
- "\n",
- "# compile params\n",
- "\n",
- "# loss function\n",
- "loss = ['categorical_crossentropy']\n",
- "# optimizer\n",
- "optimizer = ['Adam', 'SGD']\n",
- "# learning rate\n",
- "lr = [0.01, 0.1]\n",
- "# metrics\n",
- "metrics = ['accuracy']\n",
- "\n",
- "# fit params\n",
- "\n",
- "# batch size\n",
- "batch_size = [4, 8]\n",
- "# epochs\n",
- "epochs = [1]\n",
- "\n",
- "# create random param list\n",
- "param_grid = {\n",
- " 'model_id': model_id,\n",
- " 'loss': loss,\n",
- " 'optimizer': optimizer,\n",
- " 'lr': uniform(lr[0], lr[1]),\n",
- " 'metrics': metrics,\n",
- " 'batch_size': batch_size,\n",
- " 'epochs': epochs\n",
- "}\n",
- "param_list = list(ParameterSampler(param_grid, n_iter=10))\n",
- "param_list"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 212,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'batch_size': 8,\n",
- " 'epochs': 1,\n",
- " 'loss': 'categorical_crossentropy',\n",
- " 'lr': 0.03396784466820144,\n",
- " 'optimizer': 'Adam'}"
- ]
- },
- "execution_count": 212,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "param_list[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 285,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.07983433464722507)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='SGD(lr=0.03805362658279962)',metrics=['accuracy']$$\n",
- "batch_size=4,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.09043633721868387)',metrics=['accuracy']$$\n",
- "batch_size=4,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.02775811670911417)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.104019113296403)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='SGD(lr=0.06986494800074812)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.010449656955883938)',metrics=['accuracy']$$\n",
- "batch_size=4,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='SGD(lr=0.04915490422264339)',metrics=['accuracy']$$\n",
- "batch_size=4,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='Adam(lr=0.05257644929029893)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n",
- "$$loss='categorical_crossentropy',optimizer='SGD(lr=0.02993608422766151)',metrics=['accuracy']$$\n",
- "batch_size=8,epochs=1\n"
- ]
- }
- ],
- "source": [
- "for params in param_list:\n",
- "# for key, value in params.items():\n",
- "# print (key, value)\n",
- "\n",
- " compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
- " print (compile_params)\n",
- " \n",
- " fit_params = \"batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\"))\n",
- " print (fit_params)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 301,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n"
- ]
- }
- ],
- "source": [
- "%sql DROP TABLE IF EXISTS mst_table_hb, mst_table_auto_hb;\n",
- "\n",
- "%sql CREATE TABLE mst_table_hb(mst_key serial, model_id integer, compile_params varchar, fit_params varchar);\n",
- "\n",
- "for params in param_list:\n",
- " model_id = str(params.get(\"model_id\"))\n",
- " compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
- " fit_params = \"$$batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\")) + \"$$\" \n",
- " row_content = \"(\" + model_id + \", \" + compile_params + \", \" + fit_params + \");\"\n",
- " \n",
- " %sql INSERT INTO mst_table_hb (model_id, compile_params, fit_params) VALUES $row_content"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 302,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "10 rows affected.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<table>\n",
- " <tr>\n",
- " <th>mst_key</th>\n",
- " <th>model_arch_id</th>\n",
- " <th>compile_params</th>\n",
- " <th>fit_params</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>1</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.07983433464722507)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>2</td>\n",
- " <td>1</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='SGD(lr=0.03805362658279962)',metrics=['accuracy']</td>\n",
- " <td>batch_size=4,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>3</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.09043633721868387)',metrics=['accuracy']</td>\n",
- " <td>batch_size=4,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>4</td>\n",
- " <td>1</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.02775811670911417)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>5</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.104019113296403)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>6</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='SGD(lr=0.06986494800074812)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>7</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.010449656955883938)',metrics=['accuracy']</td>\n",
- " <td>batch_size=4,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>8</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='SGD(lr=0.04915490422264339)',metrics=['accuracy']</td>\n",
- " <td>batch_size=4,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>9</td>\n",
- " <td>1</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='Adam(lr=0.05257644929029893)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>10</td>\n",
- " <td>2</td>\n",
- " <td>loss='categorical_crossentropy',optimizer='SGD(lr=0.02993608422766151)',metrics=['accuracy']</td>\n",
- " <td>batch_size=8,epochs=1</td>\n",
- " </tr>\n",
- "</table>"
- ],
- "text/plain": [
- "[(1, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.07983433464722507)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
- " (2, 1, u\"loss='categorical_crossentropy',optimizer='SGD(lr=0.03805362658279962)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
- " (3, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.09043633721868387)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
- " (4, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.02775811670911417)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
- " (5, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.104019113296403)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
- " (6, 2, u\"loss='categorical_crossentropy',optimizer='SGD(lr=0.06986494800074812)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
- " (7, 2, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.010449656955883938)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
- " (8, 2, u\"loss='categorical_crossentropy',optimizer='SGD(lr=0.04915490422264339)',metrics=['accuracy']\", u'batch_size=4,epochs=1'),\n",
- " (9, 1, u\"loss='categorical_crossentropy',optimizer='Adam(lr=0.05257644929029893)',metrics=['accuracy']\", u'batch_size=8,epochs=1'),\n",
- " (10, 2, u\"loss='categorical_crossentropy',optimizer='SGD(lr=0.02993608422766151)',metrics=['accuracy']\", u'batch_size=8,epochs=1')]"
- ]
- },
- "execution_count": 302,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%%sql\n",
- "SELECT * FROM mst_table_hb ORDER BY mst_key;"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "from random import random\n",
- "from math import log, ceil\n",
- "from time import time, ctime\n",
- "\n",
- "\n",
- "class Hyperband:\n",
- " \n",
- " def __init__( self, get_params_function, try_params_function ):\n",
- " self.get_params = get_params_function\n",
- " self.try_params = try_params_function\n",
- "\n",
- " self.max_iter = 81 # maximum iterations per configuration\n",
- " self.eta = 3 # defines configuration downsampling rate (default = 3)\n",
- "\n",
- " self.logeta = lambda x: log( x ) / log( self.eta )\n",
- " self.s_max = int( self.logeta( self.max_iter ))\n",
- " self.B = ( self.s_max + 1 ) * self.max_iter\n",
- "\n",
- " self.results = [] # list of dicts\n",
- " self.counter = 0\n",
- " self.best_loss = np.inf\n",
- " self.best_counter = -1\n",
- "\n",
- "\n",
- " # can be called multiple times\n",
- " def run( self, skip_last = 0, dry_run = False ):\n",
- "\n",
- " for s in reversed( range( self.s_max + 1 )):\n",
- " \n",
- " # initial number of configurations\n",
- " n = int( ceil( self.B / self.max_iter / ( s + 1 ) * self.eta ** s ))\n",
- "\n",
- " # initial number of iterations per config\n",
- " r = self.max_iter * self.eta ** ( -s )\n",
- " \n",
- " print (\"s = \", s)\n",
- " print (\"n = \", n)\n",
- " print (\"r = \", r)\n",
- "\n",
- " # n random configurations\n",
- " T = self.get_params(n) # what to return from function if anything?\n",
- " \n",
- " return\n",
- "\n",
- " for i in range(( s + 1 ) - int( skip_last )): # changed from s + 1\n",
- "\n",
- " # Run each of the n configs for <iterations>\n",
- " # and keep best (n_configs / eta) configurations\n",
- "\n",
- " n_configs = n * self.eta ** ( -i )\n",
- " n_iterations = r * self.eta ** ( i )\n",
- "\n",
- " print \"\\n*** {} configurations x {:.1f} iterations each\".format(\n",
- " n_configs, n_iterations )\n",
- "\n",
- " val_losses = []\n",
- " early_stops = []\n",
- " \n",
- " \n",
- " \n",
- " \n",
- "\n",
- " for t in T:\n",
- "\n",
- " self.counter += 1\n",
- " print \"\\n{} | {} | lowest loss so far: {:.4f} (run {})\\n\".format(\n",
- " self.counter, ctime(), self.best_loss, self.best_counter )\n",
- "\n",
- " start_time = time()\n",
- "\n",
- " if dry_run:\n",
- " result = { 'loss': random(), 'log_loss': random(), 'auc': random()}\n",
- " else:\n",
- " result = self.try_params( n_iterations, t ) # <---\n",
- "\n",
- " assert( type( result ) == dict )\n",
- " assert( 'loss' in result )\n",
- "\n",
- " seconds = int( round( time() - start_time ))\n",
- " print \"\\n{} seconds.\".format( seconds )\n",
- "\n",
- " loss = result['loss']\n",
- " val_losses.append( loss )\n",
- "\n",
- " early_stop = result.get( 'early_stop', False )\n",
- " early_stops.append( early_stop )\n",
- "\n",
- " # keeping track of the best result so far (for display only)\n",
- " # could do it be checking results each time, but hey\n",
- " if loss < self.best_loss:\n",
- " self.best_loss = loss\n",
- " self.best_counter = self.counter\n",
- "\n",
- " result['counter'] = self.counter\n",
- " result['seconds'] = seconds\n",
- " result['params'] = t\n",
- " result['iterations'] = n_iterations\n",
- " \n",
- " self.results.append( result )\n",
- "\n",
- " # select a number of best configurations for the next loop\n",
- " # filter out early stops, if any\n",
- " indices = np.argsort( val_losses )\n",
- " T = [ T[i] for i in indices if not early_stops[i]]\n",
- " T = T[ 0:int( n_configs / self.eta )]\n",
- " \n",
- " return self.results"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "language": "python",
- "name": "python2"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.10"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/community-artifacts/Deep-learning/automl/hyperband_v3_mnist.ipynb b/community-artifacts/Deep-learning/automl/hyperband_v3_mnist.ipynb
deleted file mode 100644
index aa290ba..0000000
--- a/community-artifacts/Deep-learning/automl/hyperband_v3_mnist.ipynb
+++ /dev/null
@@ -1,2928 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Hyperband using MNIST\n",
- "\n",
- "Model architecture based on https://keras.io/examples/mnist_transfer_cnn/ \n",
- "\n",
- "To load images into tables we use the script called <em>madlib_image_loader.py</em> located at https://github.com/apache/madlib-site/tree/asf-site/community-artifacts/Deep-learning which uses the Python Imaging Library so supports multiple formats http://www.pythonware.com/products/pil/\n",
- "\n",
- "## Table of contents\n",
- "<a href=\"#import_libraries\">1. Import libraries</a>\n",
- "\n",
- "<a href=\"#load_and_prepare_data\">2. Load and prepare data</a>\n",
- "\n",
- "<a href=\"#image_preproc\">3. Call image preprocessor</a>\n",
- "\n",
- "<a href=\"#define_and_load_model\">4. Define and load model architecture</a>\n",
- "\n",
- "<a href=\"#hyperband\">5. Hyperband</a>\n",
- "\n",
- "<a href=\"#plot\">6. Plot results</a>"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
- " \"You should import from traitlets.config instead.\", ShimWarning)\n",
- "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
- " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
- ]
- }
- ],
- "source": [
- "%load_ext sql"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Greenplum Database 5.x on GCP (PM demo machine) - direct external IP access\n",
- "#%sql postgresql://gpadmin@34.67.65.96:5432/madlib\n",
- "\n",
- "# Greenplum Database 5.x on GCP - via tunnel\n",
- "%sql postgresql://gpadmin@localhost:8000/madlib\n",
- " \n",
- "# PostgreSQL local\n",
- "#%sql postgresql://fmcquillan@localhost:5432/madlib\n",
- "\n",
- "# psycopg2 connection\n",
- "import psycopg2 as p2\n",
- "#conn = p2.connect('postgresql://fmcquillan@localhost:5432/madlib')\n",
- "conn = p2.connect('postgresql://gpadmin@localhost:8000/madlib')\n",
- "cur = conn.cursor()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "1 rows affected.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<table>\n",
- " <tr>\n",
- " <th>version</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5</td>\n",
- " </tr>\n",
- "</table>"
- ],
- "text/plain": [
- "[(u'MADlib version: 1.17-dev, git revision: rel/v1.16-47-g5a1717e, cmake configuration time: Tue Nov 19 01:02:39 UTC 2019, build type: release, build system: Linux-3.10.0-957.27.2.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5',)]"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%sql select madlib.version();\n",
- "#%sql select version();"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "<a id=\"import_libraries\"></a>\n",
- "# 1. Import libraries\n",
- "From https://keras.io/examples/mnist_transfer_cnn/ import libraries and define some params"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Using TensorFlow backend.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Couldn't import dot_parser, loading of dot files will not be possible.\n"
- ]
- }
- ],
- "source": [
- "from __future__ import print_function\n",
- "\n",
- "import datetime\n",
- "import keras\n",
- "from keras.datasets import mnist\n",
- "from keras.models import Sequential\n",
- "from keras.layers import Dense, Dropout, Activation, Flatten\n",
- "from keras.layers import Conv2D, MaxPooling2D\n",
- "from keras import backend as K\n",
- "\n",
- "now = datetime.datetime.now\n",
- "\n",
- "#batch_size = 128\n",
- "num_classes = 10\n",
- "#epochs = 5\n",
- "\n",
- "# input image dimensions\n",
- "img_rows, img_cols = 28, 28\n",
- "# number of convolutional filters to use\n",
- "filters = 32\n",
- "# size of pooling area for max pooling\n",
- "pool_size = 2\n",
- "# convolution kernel size\n",
- "kernel_size = 3\n",
- "\n",
- "if K.image_data_format() == 'channels_first':\n",
- " input_shape = (1, img_rows, img_cols)\n",
- "else:\n",
- " input_shape = (img_rows, img_cols, 1)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Others needed in this workbook"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "import pandas as pd\n",
- "import numpy as np"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "<a id=\"load_and_prepare_data\"></a>\n",
- "# 2. Load and prepare data\n",
- "\n",
- "First load MNIST data from Keras, consisting of 60,000 28x28 grayscale images of the 10 digits, along with a test set of 10,000 images."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "(10000, 28, 28)\n",
- "(10000, 28, 28, 1)\n"
- ]
- }
- ],
- "source": [
- "# the data, split between train and test sets\n",
- "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
- "\n",
- "# reshape to match model architecture\n",
- "print(x_test.shape)\n",
- "x_train = x_train.reshape(len(x_train), *input_shape)\n",
- "x_test = x_test.reshape(len(x_test), *input_shape)\n",
- "print(x_test.shape)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Load datasets into tables using image loader scripts called <em>madlib_image_loader.py</em> located at https://github.com/apache/madlib-site/tree/asf-site/community-artifacts/Deep-learning"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [],
- "source": [
- "# MADlib tools directory\n",
- "import sys\n",
- "import os\n",
- "madlib_site_dir = '/Users/fmcquillan/Documents/Product/MADlib/Demos/data'\n",
- "sys.path.append(madlib_site_dir)\n",
- "\n",
- "# Import image loader module\n",
- "from madlib_image_loader import ImageLoader, DbCredentials"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Specify database credentials, for connecting to db\n",
- "db_creds = DbCredentials(user='gpadmin',\n",
- " host='localhost',\n",
- " port='8000',\n",
- " password='')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Initialize ImageLoader (increase num_workers to run faster)\n",
- "iloader = ImageLoader(num_workers=5, db_creds=db_creds)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 36,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done.\n",
- "MainProcess: Connected to madlib db.\n",
- "Executing: CREATE TABLE train_mnist (id SERIAL, x REAL[], y TEXT)\n",
- "CREATE TABLE\n",
- "Created table train_mnist in madlib db\n",
- "Spawning 5 workers...\n",
- "Initializing PoolWorker-11 [pid 34068]\n",
- "PoolWorker-11: Created temporary directory /tmp/madlib_RbuQlbqxI5\n",
- "Initializing PoolWorker-12 [pid 34069]\n",
- "PoolWorker-12: Created temporary directory /tmp/madlib_tEyH9GMFGV\n",
- "Initializing PoolWorker-13 [pid 34070]\n",
- "PoolWorker-13: Created temporary directory /tmp/madlib_TyYs4viAVD\n",
- "Initializing PoolWorker-14 [pid 34071]\n",
- "Initializing PoolWorker-15 [pid 34072]\n",
- "PoolWorker-14: Created temporary directory /tmp/madlib_KTwnncRsaq\n",
- "PoolWorker-15: Created temporary directory /tmp/madlib_jtG9zAC8HU\n",
- "PoolWorker-11: Connected to madlib db.\n",
- "PoolWorker-13: Connected to madlib db.\n",
- "PoolWorker-14: Connected to madlib db.\n",
- "PoolWorker-12: Connected to madlib db.\n",
- "PoolWorker-15: Connected to madlib db.\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0000.tmp\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0000.tmp\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0000.tmp\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0000.tmp\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0000.tmp\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0001.tmp\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0001.tmp\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0001.tmp\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0001.tmp\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0001.tmp\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0002.tmp\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0002.tmp\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0002.tmp\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0002.tmp\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0002.tmp\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0003.tmp\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0003.tmp\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0003.tmp\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0003.tmp\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0003.tmp\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0004.tmp\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0004.tmp\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0004.tmp\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0004.tmp\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0004.tmp\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0005.tmp\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0005.tmp\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0005.tmp\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0005.tmp\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0005.tmp\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0006.tmp\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0006.tmp\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0006.tmp\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0006.tmp\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0006.tmp\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0007.tmp\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0007.tmp\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0007.tmp\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0007.tmp\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0007.tmp\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0008.tmp\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0008.tmp\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0008.tmp\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0008.tmp\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0008.tmp\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0009.tmp\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0009.tmp\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0009.tmp\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0009.tmp\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0009.tmp\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0010.tmp\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0010.tmp\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0010.tmp\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0010.tmp\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0010.tmp\n",
- "PoolWorker-13: Wrote 1000 images to /tmp/madlib_TyYs4viAVD/train_mnist0011.tmp\n",
- "PoolWorker-11: Wrote 1000 images to /tmp/madlib_RbuQlbqxI5/train_mnist0011.tmp\n",
- "PoolWorker-14: Wrote 1000 images to /tmp/madlib_KTwnncRsaq/train_mnist0011.tmp\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "PoolWorker-12: Wrote 1000 images to /tmp/madlib_tEyH9GMFGV/train_mnist0011.tmp\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-13: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Loaded 1000 images into train_mnist\n",
- "PoolWorker-14: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Wrote 1000 images to /tmp/madlib_jtG9zAC8HU/train_mnist0011.tmp\n",
- "PoolWorker-12: Loaded 1000 images into train_mnist\n",
- "PoolWorker-15: Loaded 1000 images into train_mnist\n",
- "PoolWorker-11: Removed temporary directory /tmp/madlib_RbuQlbqxI5\n",
- "PoolWorker-15: Removed temporary directory /tmp/madlib_jtG9zAC8HU\n",
- "PoolWorker-12: Removed temporary directory /tmp/madlib_tEyH9GMFGV\n",
- "PoolWorker-13: Removed temporary directory /tmp/madlib_TyYs4viAVD\n",
- "PoolWorker-14: Removed temporary directory /tmp/madlib_KTwnncRsaq\n",
- "Done! Loaded 60000 images in 45.7068669796s\n",
- "5 workers terminated.\n",
- "MainProcess: Connected to madlib db.\n",
- "Executing: CREATE TABLE test_mnist (id SERIAL, x REAL[], y TEXT)\n",
- "CREATE TABLE\n",
- "Created table test_mnist in madlib db\n",
- "Spawning 5 workers...\n",
- "Initializing PoolWorker-16 [pid 34074]\n",
- "PoolWorker-16: Created temporary directory /tmp/madlib_MjwU1yRoMW\n",
- "Initializing PoolWorker-17 [pid 34075]\n",
- "PoolWorker-17: Created temporary directory /tmp/madlib_kTezv88uWu\n",
- "Initializing PoolWorker-18 [pid 34076]\n",
- "PoolWorker-18: Created temporary directory /tmp/madlib_TFIofbewK1\n",
- "Initializing PoolWorker-19 [pid 34077]\n",
- "PoolWorker-19: Created temporary directory /tmp/madlib_QUIRxlckvj\n",
- "PoolWorker-20: Created temporary directory /tmp/madlib_Eii5YFUzCZ\n",
- "Initializing PoolWorker-20 [pid 34078]\n",
- "PoolWorker-17: Connected to madlib db.\n",
- "PoolWorker-18: Connected to madlib db.\n",
- "PoolWorker-19: Connected to madlib db.\n",
- "PoolWorker-16: Connected to madlib db.\n",
- "PoolWorker-20: Connected to madlib db.\n",
- "PoolWorker-18: Wrote 1000 images to /tmp/madlib_TFIofbewK1/test_mnist0000.tmp\n",
- "PoolWorker-19: Wrote 1000 images to /tmp/madlib_QUIRxlckvj/test_mnist0000.tmp\n",
- "PoolWorker-17: Wrote 1000 images to /tmp/madlib_kTezv88uWu/test_mnist0000.tmp\n",
- "PoolWorker-16: Wrote 1000 images to /tmp/madlib_MjwU1yRoMW/test_mnist0000.tmp\n",
- "PoolWorker-20: Wrote 1000 images to /tmp/madlib_Eii5YFUzCZ/test_mnist0000.tmp\n",
- "PoolWorker-18: Loaded 1000 images into test_mnist\n",
- "PoolWorker-17: Loaded 1000 images into test_mnist\n",
- "PoolWorker-19: Loaded 1000 images into test_mnist\n",
- "PoolWorker-18: Wrote 1000 images to /tmp/madlib_TFIofbewK1/test_mnist0001.tmp\n",
- "PoolWorker-16: Loaded 1000 images into test_mnist\n",
- "PoolWorker-20: Loaded 1000 images into test_mnist\n",
- "PoolWorker-18: Loaded 1000 images into test_mnist\n",
- "PoolWorker-19: Wrote 1000 images to /tmp/madlib_QUIRxlckvj/test_mnist0001.tmp\n",
- "PoolWorker-17: Wrote 1000 images to /tmp/madlib_kTezv88uWu/test_mnist0001.tmp\n",
- "PoolWorker-16: Wrote 1000 images to /tmp/madlib_MjwU1yRoMW/test_mnist0001.tmp\n",
- "PoolWorker-20: Wrote 1000 images to /tmp/madlib_Eii5YFUzCZ/test_mnist0001.tmp\n",
- "PoolWorker-19: Loaded 1000 images into test_mnist\n",
- "PoolWorker-17: Loaded 1000 images into test_mnist\n",
- "PoolWorker-16: Loaded 1000 images into test_mnist\n",
- "PoolWorker-20: Loaded 1000 images into test_mnist\n",
- "PoolWorker-16: Removed temporary directory /tmp/madlib_MjwU1yRoMW\n",
- "PoolWorker-19: Removed temporary directory /tmp/madlib_QUIRxlckvj\n",
- "PoolWorker-17: Removed temporary directory /tmp/madlib_kTezv88uWu\n",
- "PoolWorker-20: Removed temporary directory /tmp/madlib_Eii5YFUzCZ\n",
- "PoolWorker-18: Removed temporary directory /tmp/madlib_TFIofbewK1\n",
- "Done! Loaded 10000 images in 6.80017995834s\n",
- "5 workers terminated.\n"
- ]
- }
- ],
- "source": [
- "# Drop tables\n",
- "%sql DROP TABLE IF EXISTS train_mnist, test_mnist\n",
- "\n",
- "# Save images to temporary directories and load into database\n",
- "iloader.load_dataset_from_np(x_train, y_train, 'train_mnist', append=False)\n",
- "iloader.load_dataset_from_np(x_test, y_test, 'test_mnist', append=False)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "<a id=\"image_preproc\"></a>\n",
- "# 3. Call image preprocessor\n",
- "\n",
- "Transforms from one image per row to multiple images per row for batch optimization. Also normalizes and one-hot encodes.\n",
- "\n",
- "Training dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<table>\n",
- " <tr>\n",
- " <th>source_table</th>\n",
- " <th>output_table</th>\n",
- " <th>dependent_varname</th>\n",
- " <th>independent_varname</th>\n",
- " <th>dependent_vartype</th>\n",
- " <th>class_values</th>\n",
- " <th>buffer_size</th>\n",
- " <th>normalizing_const</th>\n",
- " <th>num_classes</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>train_mnist</td>\n",
- " <td>train_mnist_packed</td>\n",
- " <td>y</td>\n",
- " <td>x</td>\n",
- " <td>text</td>\n",
- " <td>[u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9']</td>\n",
- " <td>1000</td>\n",
- " <td>255.0</td>\n",
- " <td>10</td>\n",
- " </tr>\n",
- "</table>"
- ],
- "text/plain": [
- "[(u'train_mnist', u'train_mnist_packed', u'y', u'x', u'text', [u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9'], 1000, 255.0, 10)]"
- ]
- },
- "execution_count": 37,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%%sql\n",
- "DROP TABLE IF EXISTS train_mnist_packed, train_mnist_packed_summary;\n",
- "\n",
- "SELECT madlib.training_preprocessor_dl('train_mnist', -- Source table\n",
- " 'train_mnist_packed', -- Output table\n",
- " 'y', -- Dependent variable\n",
- " 'x', -- Independent variable\n",
- " 1000, -- Buffer size\n",
- " 255 -- Normalizing constant\n",
- " );\n",
- "\n",
- "SELECT * FROM train_mnist_packed_summary;"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Test dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 39,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done.\n",
- "1 rows affected.\n",
- "1 rows affected.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<table>\n",
- " <tr>\n",
- " <th>source_table</th>\n",
- " <th>output_table</th>\n",
- " <th>dependent_varname</th>\n",
- " <th>independent_varname</th>\n",
- " <th>dependent_vartype</th>\n",
- " <th>class_values</th>\n",
- " <th>buffer_size</th>\n",
- " <th>normalizing_const</th>\n",
- " <th>num_classes</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>test_mnist</td>\n",
- " <td>test_mnist_packed</td>\n",
- " <td>y</td>\n",
- " <td>x</td>\n",
- " <td>text</td>\n",
- " <td>[u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9']</td>\n",
- " <td>5000</td>\n",
- " <td>255.0</td>\n",
- " <td>10</td>\n",
- " </tr>\n",
- "</table>"
- ],
- "text/plain": [
- "[(u'test_mnist', u'test_mnist_packed', u'y', u'x', u'text', [u'0', u'1', u'2', u'3', u'4', u'5', u'6', u'7', u'8', u'9'], 5000, 255.0, 10)]"
- ]
- },
- "execution_count": 39,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%%sql\n",
- "DROP TABLE IF EXISTS test_mnist_packed, test_mnist_packed_summary;\n",
- "\n",
- "SELECT madlib.validation_preprocessor_dl('test_mnist', -- Source table\n",
- " 'test_mnist_packed', -- Output table\n",
- " 'y', -- Dependent variable\n",
- " 'x', -- Independent variable\n",
- " 'train_mnist_packed' -- Training preproc table\n",
- " );\n",
- "\n",
- "SELECT * FROM test_mnist_packed_summary;"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "<a id=\"define_and_load_model\"></a>\n",
- "# 4. Define and load model architecture\n",
- "\n",
- "Model with feature and classification layers trainable"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "_________________________________________________________________\n",
- "Layer (type) Output Shape Param # \n",
- "=================================================================\n",
- "conv2d_1 (Conv2D) (None, 26, 26, 32) 320 \n",
- "_________________________________________________________________\n",
- "activation_1 (Activation) (None, 26, 26, 32) 0 \n",
- "_________________________________________________________________\n",
- "conv2d_2 (Conv2D) (None, 24, 24, 32) 9248 \n",
- "_________________________________________________________________\n",
- "activation_2 (Activation) (None, 24, 24, 32) 0 \n",
- "_________________________________________________________________\n",
- "max_pooling2d_1 (MaxPooling2 (None, 12, 12, 32) 0 \n",
- "_________________________________________________________________\n",
- "dropout_1 (Dropout) (None, 12, 12, 32) 0 \n",
- "_________________________________________________________________\n",
- "flatten_1 (Flatten) (None, 4608) 0 \n",
- "_________________________________________________________________\n",
- "dense_1 (Dense) (None, 128) 589952 \n",
- "_________________________________________________________________\n",
- "activation_3 (Activation) (None, 128) 0 \n",
- "_________________________________________________________________\n",
- "dropout_2 (Dropout) (None, 128) 0 \n",
- "_________________________________________________________________\n",
- "dense_2 (Dense) (None, 10) 1290 \n",
- "_________________________________________________________________\n",
- "activation_4 (Activation) (None, 10) 0 \n",
- "=================================================================\n",
- "Total params: 600,810\n",
- "Trainable params: 600,810\n",
- "Non-trainable params: 0\n",
- "_________________________________________________________________\n"
- ]
- }
- ],
- "source": [
- "# define two groups of layers: feature (convolutions) and classification (dense)\n",
- "feature_layers = [\n",
- " Conv2D(filters, kernel_size,\n",
- " padding='valid',\n",
- " input_shape=input_shape),\n",
- " Activation('relu'),\n",
- " Conv2D(filters, kernel_size),\n",
- " Activation('relu'),\n",
- " MaxPooling2D(pool_size=pool_size),\n",
- " Dropout(0.25),\n",
- " Flatten(),\n",
- "]\n",
- "\n",
- "classification_layers = [\n",
- " Dense(128),\n",
- " Activation('relu'),\n",
- " Dropout(0.5),\n",
- " Dense(num_classes),\n",
- " Activation('softmax')\n",
- "]\n",
- "\n",
- "# create complete model\n",
- "model = Sequential(feature_layers + classification_layers)\n",
- "\n",
- "model.summary()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Load into model architecture table using psycopg2"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done.\n",
- "1 rows affected.\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "<table>\n",
- " <tr>\n",
- " <th>model_id</th>\n",
- " <th>name</th>\n",
- " </tr>\n",
- " <tr>\n",
- " <td>1</td>\n",
- " <td>feature + classification layers trainable</td>\n",
- " </tr>\n",
- "</table>"
- ],
- "text/plain": [
- "[(1, u'feature + classification layers trainable')]"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%sql DROP TABLE IF EXISTS model_arch_table_mnist;\n",
- "query = \"SELECT madlib.load_keras_model('model_arch_table_mnist', %s, NULL, %s)\"\n",
- "cur.execute(query,[model.to_json(), \"feature + classification layers trainable\"])\n",
- "conn.commit()\n",
- "\n",
- "# check model loaded OK\n",
- "%sql SELECT model_id, name FROM model_arch_table_mnist;"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "<a id=\"hyperband\"></a>\n",
- "# 5. Hyperband"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Create tables"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Done.\n",
- "Done.\n",
- "Done.\n",
- "Done.\n",
- "Done.\n",
- "Done.\n",
- "1 rows affected.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "[]"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "%%sql\n",
- "-- overall results table\n",
- "DROP TABLE IF EXISTS results_mnist;\n",
- "CREATE TABLE results_mnist ( \n",
- " model_id INTEGER, \n",
- " compile_params TEXT,\n",
- " fit_params TEXT, \n",
- " model_type TEXT, \n",
- " model_size DOUBLE PRECISION, \n",
- " metrics_elapsed_time DOUBLE PRECISION[], \n",
- " metrics_type TEXT[], \n",
- " training_metrics_final DOUBLE PRECISION, \n",
- " training_loss_final DOUBLE PRECISION, \n",
- " training_metrics DOUBLE PRECISION[], \n",
- " training_loss DOUBLE PRECISION[], \n",
- " validation_metrics_final DOUBLE PRECISION, \n",
- " validation_loss_final DOUBLE PRECISION, \n",
- " validation_metrics DOUBLE PRECISION[], \n",
- " validation_loss DOUBLE PRECISION[], \n",
- " model_arch_table TEXT, \n",
- " num_iterations INTEGER, \n",
- " start_training_time TIMESTAMP, \n",
- " end_training_time TIMESTAMP,\n",
- " s INTEGER, \n",
- " n INTEGER, \n",
- " r INTEGER,\n",
- " run_id SERIAL\n",
- " );\n",
- "\n",
- "-- model selection table\n",
- "DROP TABLE IF EXISTS mst_table_hb_mnist;\n",
- "CREATE TABLE mst_table_hb_mnist (\n",
- " mst_key SERIAL, \n",
- " model_id INTEGER, \n",
- " compile_params VARCHAR, \n",
- " fit_params VARCHAR\n",
- " );\n",
- "\n",
- "-- model selection summary table\n",
- "DROP TABLE IF EXISTS mst_table_hb_mnist_summary;\n",
- "CREATE TABLE mst_table_hb_mnist_summary (model_arch_table VARCHAR);\n",
- "INSERT INTO mst_table_hb_mnist_summary VALUES ('model_arch_table_mnist');"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Table names"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": [
- "results_table = 'results_mnist'\n",
- "\n",
- "output_table = 'mnist_multi_model'\n",
- "output_table_info = '_'.join([output_table, 'info'])\n",
- "output_table_summary = '_'.join([output_table, 'summary'])\n",
- "\n",
- "mst_table = 'mst_table_hb_mnist'\n",
- "mst_table_summary = '_'.join([mst_table, 'summary'])\n",
- "\n",
- "model_arch_table = 'model_arch_library_mnist'"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Pretty print run schedule"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "max_iter = 9\n",
- "eta = 3\n",
- "B = 3*max_iter = 27\n",
- " \n",
- "s=2\n",
- "n_i r_i\n",
- "------------\n",
- "9 1.0\n",
- "3.0 3.0\n",
- "1.0 9.0\n",
- " \n",
- "s=1\n",
- "n_i r_i\n",
- "------------\n",
- "3 3.0\n",
- "1.0 9.0\n",
- " \n",
- "s=0\n",
- "n_i r_i\n",
- "------------\n",
- "3 9\n",
- " \n",
- "sum of configurations at leaf nodes across all s = 5.0\n",
- "(if have more workers than this, they may not be 100% busy)\n"
- ]
- }
- ],
- "source": [
- "import numpy as np\n",
- "from math import log, ceil\n",
- "\n",
- "#input\n",
- "max_iter = 9 # maximum iterations/epochs per configuration\n",
- "eta = 3 # defines downsampling rate (default=3)\n",
- "\n",
- "logeta = lambda x: log(x)/log(eta)\n",
- "s_max = int(logeta(max_iter)) # number of unique executions of Successive Halving (minus one)\n",
- "B = (s_max+1)*max_iter # total number of iterations (without reuse) per execution of Succesive Halving (n,r)\n",
- "\n",
- "#echo output\n",
- "print (\"max_iter = \" + str(max_iter))\n",
- "print (\"eta = \" + str(eta))\n",
- "print (\"B = \" + str(s_max+1) + \"*max_iter = \" + str(B))\n",
- "\n",
- "sum_leaf_n_i = 0 # count configurations at leaf nodes across all s\n",
- "\n",
- "#### Begin Finite Horizon Hyperband outlerloop. Repeat indefinitely.\n",
- "for s in reversed(range(s_max+1)):\n",
- " \n",
- " print (\" \")\n",
- " print (\"s=\" + str(s))\n",
- " print (\"n_i r_i\")\n",
- " print (\"------------\")\n",
- " counter = 0\n",
- " \n",
- " n = int(ceil(int(B/max_iter/(s+1))*eta**s)) # initial number of configurations\n",
- " r = max_iter*eta**(-s) # initial number of iterations to run configurations for\n",
- "\n",
- " #### Begin Finite Horizon Successive Halving with (n,r)\n",
- " #T = [ get_random_hyperparameter_configuration() for i in range(n) ] \n",
- " for i in range(s+1):\n",
- " # Run each of the n_i configs for r_i iterations and keep best n_i/eta\n",
- " n_i = n*eta**(-i)\n",
- " r_i = r*eta**(i)\n",
- " \n",
- " print (str(n_i) + \" \" + str (r_i))\n",
- " \n",
- " # check if leaf node for this s\n",
- " if counter == s:\n",
- " sum_leaf_n_i += n_i\n",
- " counter += 1\n",
- " \n",
- " #val_losses = [ run_then_return_val_loss(num_iters=r_i,hyperparameters=t) for t in T ]\n",
- " #T = [ T[i] for i in argsort(val_losses)[0:int( n_i/eta )] ]\n",
- " #### End Finite Horizon Successive Halving with (n,r)\n",
- "\n",
- "print (\" \")\n",
- "print (\"sum of configurations at leaf nodes across all s = \" + str(sum_leaf_n_i))\n",
- "print (\"(if have more workers than this, they may not be 100% busy)\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Hyperband "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "\n",
- "from random import random\n",
- "from math import log, ceil\n",
- "from time import time, ctime\n",
- "\n",
- "\n",
- "class Hyperband:\n",
- " \n",
- " def __init__( self, get_params_function, try_params_function ):\n",
- " self.get_params = get_params_function\n",
- " self.try_params = try_params_function\n",
- "\n",
- " self.max_iter = 9 # maximum iterations per configuration\n",
- " self.eta = 3 # defines configuration downsampling rate (default = 3)\n",
- "\n",
- " self.logeta = lambda x: log( x ) / log( self.eta )\n",
- " self.s_max = int( self.logeta( self.max_iter ))\n",
- " self.B = ( self.s_max + 1 ) * self.max_iter\n",
- "\n",
- " self.results = [] # list of dicts\n",
- " self.counter = 0\n",
- " self.best_loss = np.inf\n",
- " self.best_counter = -1\n",
- "\n",
- " # can be called multiple times\n",
- " def run( self, skip_last = 0, dry_run = False ):\n",
- "\n",
- " for s in reversed( range( self.s_max + 1 )):\n",
- " \n",
- " # initial number of configurations\n",
- " n = int( ceil( self.B / self.max_iter / ( s + 1 ) * self.eta ** s ))\n",
- "\n",
- " # initial number of iterations per config\n",
- " r = self.max_iter * self.eta ** ( -s )\n",
- " \n",
- " print (\"s = \", s)\n",
- " print (\"n = \", n)\n",
- " print (\"r = \", r)\n",
- "\n",
- " # n random configurations\n",
- " T = self.get_params(n) # what to return from function if anything?\n",
- " \n",
- " for i in range(( s + 1 ) - int( skip_last )): # changed from s + 1\n",
- "\n",
- " # Run each of the n configs for <iterations>\n",
- " # and keep best (n_configs / eta) configurations\n",
- "\n",
- " n_configs = n * self.eta ** ( -i )\n",
- " n_iterations = r * self.eta ** ( i )\n",
- "\n",
- " print (\"\\n*** {} configurations x {:.1f} iterations each\".format(\n",
- " n_configs, n_iterations ))\n",
- " \n",
- " # multi-model training\n",
- " U = self.try_params(s, n_configs, n_iterations) # what to return from function if anything?\n",
- "\n",
- " # select a number of best configurations for the next loop\n",
- " # filter out early stops, if any\n",
- " # drop from model selection table\n",
- " k = int( n_configs / self.eta)\n",
- " \n",
- " # prune mst_table for next try\n",
- " %sql DELETE FROM $mst_table WHERE mst_key NOT IN (SELECT mst_key FROM $output_table_info ORDER BY validation_loss_final ASC LIMIT $k::INT);\n",
- "\n",
- " #return self.results\n",
- " \n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_params(n):\n",
- " \n",
- " from sklearn.model_selection import ParameterSampler\n",
- " from scipy.stats.distributions import uniform\n",
- " import numpy as np\n",
- " \n",
- " # model architecture\n",
- " model_id = [1]\n",
- "\n",
- " # compile params\n",
- " # loss function\n",
- " loss = ['categorical_crossentropy']\n",
- " # optimizer\n",
- " optimizer = ['Adam', 'SGD']\n",
- " # learning rate (sample on log scale here not in ParameterSampler)\n",
- " lr_range = [0.001, 0.1]\n",
- " lr = 10**np.random.uniform(np.log10(lr_range[0]), np.log10(lr_range[1]), n)\n",
- " # metrics\n",
- " metrics = ['accuracy']\n",
- "\n",
- " # fit params\n",
- " # batch size\n",
- " batch_size = [64, 128]\n",
- " # epochs\n",
- " epochs = [1]\n",
- "\n",
- " # create random param list\n",
- " param_grid = {\n",
- " 'model_id': model_id,\n",
- " 'loss': loss,\n",
- " 'optimizer': optimizer,\n",
- " 'lr': lr,\n",
- " 'metrics': metrics,\n",
- " 'batch_size': batch_size,\n",
- " 'epochs': epochs\n",
- " }\n",
- " param_list = list(ParameterSampler(param_grid, n_iter=n))\n",
- " \n",
- " for params in param_list:\n",
- "\n",
- " model_id = str(params.get(\"model_id\"))\n",
- " compile_params = \"$$loss='\" + str(params.get(\"loss\")) + \"',optimizer='\" + str(params.get(\"optimizer\")) + \"(lr=\" + str(params.get(\"lr\")) + \")',metrics=['\" + str(params.get(\"metrics\")) + \"']$$\" \n",
- " fit_params = \"$$batch_size=\" + str(params.get(\"batch_size\")) + \",epochs=\" + str(params.get(\"epochs\")) + \"$$\" \n",
- " row_content = \"(\" + model_id + \", \" + compile_params + \", \" + fit_params + \");\"\n",
- " \n",
- " %sql INSERT INTO $mst_table (model_id, compile_params, fit_params) VALUES $row_content\n",
- " \n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "def try_params(s, n_configs, n_iterations):\n",
- " \n",
- " # multi-model fit\n",
- " # TO DO: use warm start to continue from where left off after if not 1st time thru for this s value\n",
- " %sql DROP TABLE IF EXISTS $output_table, $output_table_summary, $output_table_info;\n",
- " \n",
- " # passing vars as madlib args does not work???\n",
- " #%sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', $output_table, $mst_table, $n_iterations::INT, 0);\n",
- " %sql SELECT madlib.madlib_keras_fit_multiple_model('train_mnist_packed', 'mnist_multi_model', 'mst_table_hb_mnist', $n_iterations::INT, 0, 'test_mnist_packed');\n",
- " \n",
- " # save results\n",
- " %sql DROP TABLE IF EXISTS temp_results;\n",
- " %sql CREATE TABLE temp_results AS (SELECT * FROM $output_table_info);\n",
- " %sql ALTER TABLE temp_results DROP COLUMN mst_key, ADD COLUMN model_arch_table TEXT, ADD COLUMN num_iterations INTEGER, ADD COLUMN start_training_time TIMESTAMP, ADD COLUMN end_training_time TIMESTAMP, ADD COLUMN s INTEGER, ADD COLUMN n INTEGER, ADD COLUMN r INTEGER;\n",
- " %sql UPDATE temp_results SET model_arch_table = (SELECT model_arch_table FROM $output_table_summary), num_iterations = (SELECT num_iterations FROM iris_multi_model_summary), start_training_time = (SELECT start_training_time FROM iris_multi_model_summary), end_training_time = (SELECT end_training_time FROM iris_multi_model_summary), s = $s, n = $n_configs, r = $n_iterations;\n",
- " %sql INSERT INTO $results_table (SELECT * FROM temp_results);\n",
- "\n",
- " return"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "s = 2\n",
- "n = 9\n",
- "r = 1.0\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "1 rows affected.\n",
- "\n",
- "*** 9 configurations x 1.0 iterations each\n",
- "Done.\n"
- ]
- }
- ],
- "source": [
- "hp = Hyperband( get_params, try_params )\n",
- "results = hp.run()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "<a id=\"plot\"></a>\n",
- "# 6. Plot results"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 62,
- "metadata": {},
- "outputs": [],
- "source": [
- "%matplotlib notebook\n",
- "import matplotlib.pyplot as plt\n",
- "from collections import defaultdict\n",
- "import pandas as pd\n",
- "import seaborn as sns\n",
- "sns.set_palette(sns.color_palette(\"hls\", 20))\n",
- "plt.rcParams.update({'font.size': 12})\n",
- "pd.set_option('display.max_colwidth', -1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 68,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "7 rows affected.\n"
- ]
- },
- {
- "data": {
- "application/javascript": [
- "/* Put everything inside the global mpl namespace */\n",
- "window.mpl = {};\n",
- "\n",
- "\n",
- "mpl.get_websocket_type = function() {\n",
- " if (typeof(WebSocket) !== 'undefined') {\n",
- " return WebSocket;\n",
- " } else if (typeof(MozWebSocket) !== 'undefined') {\n",
- " return MozWebSocket;\n",
- " } else {\n",
- " alert('Your browser does not have WebSocket support.' +\n",
- " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
- " 'Firefox 4 and 5 are also supported but you ' +\n",
- " 'have to enable WebSockets in about:config.');\n",
- " };\n",
- "}\n",
- "\n",
- "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
- " this.id = figure_id;\n",
- "\n",
- " this.ws = websocket;\n",
- "\n",
- " this.supports_binary = (this.ws.binaryType != undefined);\n",
- "\n",
- " if (!this.supports_binary) {\n",
- " var warnings = document.getElementById(\"mpl-warnings\");\n",
- " if (warnings) {\n",
- " warnings.style.display = 'block';\n",
- " warnings.textContent = (\n",
- " \"This browser does not support binary websocket messages. \" +\n",
- " \"Performance may be slow.\");\n",
- " }\n",
- " }\n",
- "\n",
- " this.imageObj = new Image();\n",
- "\n",
- " this.context = undefined;\n",
- " this.message = undefined;\n",
- " this.canvas = undefined;\n",
- " this.rubberband_canvas = undefined;\n",
- " this.rubberband_context = undefined;\n",
- " this.format_dropdown = undefined;\n",
- "\n",
- " this.image_mode = 'full';\n",
- "\n",
- " this.root = $('<div/>');\n",
- " this._root_extra_style(this.root)\n",
- " this.root.attr('style', 'display: inline-block');\n",
- "\n",
- " $(parent_element).append(this.root);\n",
- "\n",
- " this._init_header(this);\n",
- " this._init_canvas(this);\n",
- " this._init_toolbar(this);\n",
- "\n",
- " var fig = this;\n",
- "\n",
- " this.waiting = false;\n",
- "\n",
- " this.ws.onopen = function () {\n",
- " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
- " fig.send_message(\"send_image_mode\", {});\n",
- " if (mpl.ratio != 1) {\n",
- " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n",
- " }\n",
- " fig.send_message(\"refresh\", {});\n",
- " }\n",
- "\n",
- " this.imageObj.onload = function() {\n",
- " if (fig.image_mode == 'full') {\n",
- " // Full images could contain transparency (where diff images\n",
- " // almost always do), so we need to clear the canvas so that\n",
- " // there is no ghosting.\n",
- " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
- " }\n",
- " fig.context.drawImage(fig.imageObj, 0, 0);\n",
- " };\n",
- "\n",
- " this.imageObj.onunload = function() {\n",
- " fig.ws.close();\n",
- " }\n",
- "\n",
- " this.ws.onmessage = this._make_on_message_function(this);\n",
- "\n",
- " this.ondownload = ondownload;\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._init_header = function() {\n",
- " var titlebar = $(\n",
- " '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ' +\n",
- " 'ui-helper-clearfix\"/>');\n",
- " var titletext = $(\n",
- " '<div class=\"ui-dialog-title\" style=\"width: 100%; ' +\n",
- " 'text-align: center; padding: 3px;\"/>');\n",
- " titlebar.append(titletext)\n",
- " this.root.append(titlebar);\n",
- " this.header = titletext[0];\n",
- "}\n",
- "\n",
- "\n",
- "\n",
- "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
- "\n",
- "}\n",
- "\n",
- "\n",
- "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
- "\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._init_canvas = function() {\n",
- " var fig = this;\n",
- "\n",
- " var canvas_div = $('<div/>');\n",
- "\n",
- " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
- "\n",
- " function canvas_keyboard_event(event) {\n",
- " return fig.key_event(event, event['data']);\n",
- " }\n",
- "\n",
- " canvas_div.keydown('key_press', canvas_keyboard_event);\n",
- " canvas_div.keyup('key_release', canvas_keyboard_event);\n",
- " this.canvas_div = canvas_div\n",
- " this._canvas_extra_style(canvas_div)\n",
- " this.root.append(canvas_div);\n",
- "\n",
- " var canvas = $('<canvas/>');\n",
- " canvas.addClass('mpl-canvas');\n",
- " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
- "\n",
- " this.canvas = canvas[0];\n",
- " this.context = canvas[0].getContext(\"2d\");\n",
- "\n",
- " var backingStore = this.context.backingStorePixelRatio ||\n",
- "\tthis.context.webkitBackingStorePixelRatio ||\n",
- "\tthis.context.mozBackingStorePixelRatio ||\n",
- "\tthis.context.msBackingStorePixelRatio ||\n",
- "\tthis.context.oBackingStorePixelRatio ||\n",
- "\tthis.context.backingStorePixelRatio || 1;\n",
- "\n",
- " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n",
- "\n",
- " var rubberband = $('<canvas/>');\n",
- " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
- "\n",
- " var pass_mouse_events = true;\n",
- "\n",
- " canvas_div.resizable({\n",
- " start: function(event, ui) {\n",
- " pass_mouse_events = false;\n",
- " },\n",
- " resize: function(event, ui) {\n",
- " fig.request_resize(ui.size.width, ui.size.height);\n",
- " },\n",
- " stop: function(event, ui) {\n",
- " pass_mouse_events = true;\n",
- " fig.request_resize(ui.size.width, ui.size.height);\n",
- " },\n",
- " });\n",
- "\n",
- " function mouse_event_fn(event) {\n",
- " if (pass_mouse_events)\n",
- " return fig.mouse_event(event, event['data']);\n",
- " }\n",
- "\n",
- " rubberband.mousedown('button_press', mouse_event_fn);\n",
- " rubberband.mouseup('button_release', mouse_event_fn);\n",
- " // Throttle sequential mouse events to 1 every 20ms.\n",
- " rubberband.mousemove('motion_notify', mouse_event_fn);\n",
- "\n",
- " rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
- " rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
- "\n",
- " canvas_div.on(\"wheel\", function (event) {\n",
- " event = event.originalEvent;\n",
- " event['data'] = 'scroll'\n",
- " if (event.deltaY < 0) {\n",
- " event.step = 1;\n",
- " } else {\n",
- " event.step = -1;\n",
- " }\n",
- " mouse_event_fn(event);\n",
- " });\n",
- "\n",
- " canvas_div.append(canvas);\n",
- " canvas_div.append(rubberband);\n",
- "\n",
- " this.rubberband = rubberband;\n",
- " this.rubberband_canvas = rubberband[0];\n",
- " this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
- " this.rubberband_context.strokeStyle = \"#000000\";\n",
- "\n",
- " this._resize_canvas = function(width, height) {\n",
- " // Keep the size of the canvas, canvas container, and rubber band\n",
- " // canvas in synch.\n",
- " canvas_div.css('width', width)\n",
- " canvas_div.css('height', height)\n",
- "\n",
- " canvas.attr('width', width * mpl.ratio);\n",
- " canvas.attr('height', height * mpl.ratio);\n",
- " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n",
- "\n",
- " rubberband.attr('width', width);\n",
- " rubberband.attr('height', height);\n",
- " }\n",
- "\n",
- " // Set the figure to an initial 600x600px, this will subsequently be updated\n",
- " // upon first draw.\n",
- " this._resize_canvas(600, 600);\n",
- "\n",
- " // Disable right mouse context menu.\n",
- " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
- " return false;\n",
- " });\n",
- "\n",
- " function set_focus () {\n",
- " canvas.focus();\n",
- " canvas_div.focus();\n",
- " }\n",
- "\n",
- " window.setTimeout(set_focus, 100);\n",
- "}\n",
- "\n",
- "mpl.figure.prototype._init_toolbar = function() {\n",
- " var fig = this;\n",
- "\n",
- " var nav_element = $('<div/>')\n",
- " nav_element.attr('style', 'width: 100%');\n",
- " this.root.append(nav_element);\n",
- "\n",
- " // Define a callback function for later on.\n",
- " function toolbar_event(event) {\n",
- " return fig.toolbar_button_onclick(event['data']);\n",
- " }\n",
- " function toolbar_mouse_event(event) {\n",
- " return fig.toolbar_button_onmouseover(event['data']);\n",
- " }\n",
- "\n",
- " for(var toolbar_ind in mpl.toolbar_items) {\n",
- " var name = mpl.toolbar_items[toolbar_ind][0];\n",
- " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
- " var image = mpl.toolbar_items[toolbar_ind][2];\n",
- " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
- "\n",
- " if (!name) {\n",
- " // put a spacer in here.\n",
- " continue;\n",
- " }\n",
- " var button = $('<button/>');\n",
- " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
- " 'ui-button-icon-only');\n",
- " button.attr('role', 'button');\n",
- " button.attr('aria-disabled', 'false');\n",
- " button.click(method_name, toolbar_event);\n",
- " button.mouseover(tooltip, toolbar_mouse_event);\n",
- "\n",
- " var icon_img = $('<span/>');\n",
- " icon_img.addClass('ui-button-icon-primary ui-icon');\n",
- " icon_img.addClass(image);\n",
- " icon_img.addClass('ui-corner-all');\n",
... 1445 lines suppressed ...