Posted to user@spark.apache.org by David Leifker <dl...@gmail.com> on 2017/03/14 13:07:10 UTC
[MLlib] Multiple estimators for cross validation
I am hoping to open a discussion about cross-validation in MLlib. I
often want to evaluate multiple estimators/pipelines (with different
algorithms), or the same estimator with different parameter grids.
However, CrossValidator and TrainValidationSplit accept only a single
estimator and a single parameter grid.
I experimented with the idea a bit after checking JIRA and open PRs to
see whether someone had already done this. I didn't find anything, so I
put some code together that at least solves my use case. It is backwards
compatible at the API level and can still read instances serialized in
the previous format.
I am considering opening a pull request, but first I would like to hear
what folks here think. This would be my first contribution.
The general idea is to be able to write the following and then select
the best model across all of the pipelines and their grids:
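import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{LogisticRegression, NaiveBayes}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{Binarizer, HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
// Shared feature stages, as in the Spark cross-validation example.
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
// Configure an ML pipeline using NaiveBayes (multinomial by default).
val nb = new NaiveBayes()
val pipeline1 = new Pipeline("p1").setStages(Array(tokenizer, hashingTF, nb))
val paramGrid1 = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100))
  .build()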
// Configure an ML pipeline using LogisticRegression.
val lr = new LogisticRegression().setMaxIter(10)
val pipeline2 = new Pipeline("p2").setStages(Array(tokenizer, hashingTF, lr))
val paramGrid2 = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100))
  .build()
// Configure an ML pipeline using Bernoulli NaiveBayes on binarized features (4 stages).
val binarizer = new Binarizer()
  .setInputCol(hashingTF.getOutputCol)
  .setOutputCol("binary_features")
val nb2 = new NaiveBayes()
  .setModelType("bernoulli")
  .setFeaturesCol(binarizer.getOutputCol)
val pipeline3 = new Pipeline("p3").setStages(Array(tokenizer, hashingTF, binarizer, nb2))
val paramGrid3 = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100))
  .build()
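// Cross-validate all three pipelines, each with its own parameter grid.
// setEstimators and setEstimatorsParamMaps are the proposed multi-estimator
// counterparts of the existing setEstimator and setEstimatorParamMaps.
val cv = new CrossValidator()
  .setEstimators(Array(pipeline1, pipeline2, pipeline3))
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorsParamMaps(Array(paramGrid1, paramGrid2, paramGrid3))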
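Choosing among several pipelines amounts to running the per-estimator procedure CrossValidator already uses over every (estimator, param map) pair. The plain-Scala sketch below illustrates that selection step without Spark; it is an illustration under assumptions, not the implementation in the proposal. `MultiEstimatorSelection`, `Candidate`, `Best`, `selectBest` and the `evaluate` callback are hypothetical names, a param map is modelled as a `Map[String, Int]`, and `evaluate(candidate, paramMap, fold)` stands in for fitting on one fold's training split and scoring its validation split. As in CrossValidator, the metric is averaged over folds and the winner is the max or min depending on the evaluator's `isLargerBetter`.

```scala
// Hypothetical, Spark-free sketch of multi-estimator model selection.
object MultiEstimatorSelection {

  // One estimator (identified by name) together with its own parameter grid.
  final case class Candidate(name: String, paramMaps: Seq[Map[String, Int]])

  // The winning (estimator, param map) pair and its fold-averaged metric.
  final case class Best(name: String, paramMap: Map[String, Int], avgMetric: Double)

  /** Scores every (candidate, paramMap) pair on each of `numFolds` folds,
    * averages the metric per pair, and returns the best pair.
    * `evaluate` is a stand-in for "fit on the fold's training split,
    * evaluate on its validation split". */
  def selectBest(
      candidates: Seq[Candidate],
      numFolds: Int,
      isLargerBetter: Boolean
  )(evaluate: (Candidate, Map[String, Int], Int) => Double): Best = {
    require(candidates.nonEmpty && candidates.forall(_.paramMaps.nonEmpty),
      "need at least one candidate, each with a non-empty grid")
    require(numFolds >= 2, "numFolds must be at least 2")

    // Flatten all grids into one list of (candidate, paramMap) pairs.
    val scored = for {
      c  <- candidates
      pm <- c.paramMaps
    } yield {
      val avg = (0 until numFolds).map(fold => evaluate(c, pm, fold)).sum / numFolds
      Best(c.name, pm, avg)
    }

    // Same rule CrossValidator applies to its avgMetrics.
    if (isLargerBetter) scored.maxBy(_.avgMetric) else scored.minBy(_.avgMetric)
  }
}
```

With Spark's default of 3 folds, the three two-point grids in the example would mean 3 × 6 = 18 fits, plus one final refit of the winning pipeline on the full dataset, as CrossValidator already does for a single estimator.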