You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/04 03:06:56 UTC

spark git commit: [SPARK-7329] [MLLIB] simplify ParamGridBuilder impl

Repository: spark
Updated Branches:
  refs/heads/master 9e25b09f8 -> 1ffa8cb91


[SPARK-7329] [MLLIB] simplify ParamGridBuilder impl

As suggested by justinuang in #5601.

Author: Xiangrui Meng <me...@databricks.com>

Closes #5873 from mengxr/SPARK-7329 and squashes the following commits:

d08f9cf [Xiangrui Meng] simplify tests
b7a7b9b [Xiangrui Meng] simplify grid build


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1ffa8cb9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1ffa8cb9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1ffa8cb9

Branch: refs/heads/master
Commit: 1ffa8cb91f8badf12a8aa190dc25920715a00db7
Parents: 9e25b09
Author: Xiangrui Meng <me...@databricks.com>
Authored: Sun May 3 18:06:48 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Sun May 3 18:06:48 2015 -0700

----------------------------------------------------------------------
 python/pyspark/ml/tuning.py | 28 +++++++++-------------------
 1 file changed, 9 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1ffa8cb9/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index a383bd0..1773ab5 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 #
 
+import itertools
+
 __all__ = ['ParamGridBuilder']
 
 
@@ -37,14 +39,10 @@ class ParamGridBuilder(object):
 {lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
 {lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
 {lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
-    >>> fail_count = 0
-    >>> for e in expected:
-    ...     if e not in output:
-    ...         fail_count += 1
-    >>> if len(expected) != len(output):
-    ...     fail_count += 1
-    >>> fail_count
-    0
+    >>> len(output) == len(expected)
+    True
+    >>> all([m in expected for m in output])
+    True
     """
 
     def __init__(self):
@@ -76,17 +74,9 @@ class ParamGridBuilder(object):
         Builds and returns all combinations of parameters specified
         by the param grid.
         """
-        param_maps = [{}]
-        for (param, values) in self._param_grid.items():
-            new_param_maps = []
-            for value in values:
-                for old_map in param_maps:
-                    copied_map = old_map.copy()
-                    copied_map[param] = value
-                    new_param_maps.append(copied_map)
-            param_maps = new_param_maps
-
-        return param_maps
+        keys = self._param_grid.keys()
+        grid_values = self._param_grid.values()
+        return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
 
 
 if __name__ == "__main__":


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org