You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@predictionio.apache.org by do...@apache.org on 2016/07/18 20:17:36 UTC
[05/34] incubator-predictionio git commit: rename all except examples
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/engine/CategoricalNaiveBayesTest.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/io/prediction/e2/engine/CategoricalNaiveBayesTest.scala b/e2/src/test/scala/io/prediction/e2/engine/CategoricalNaiveBayesTest.scala
deleted file mode 100644
index 2e3eadd..0000000
--- a/e2/src/test/scala/io/prediction/e2/engine/CategoricalNaiveBayesTest.scala
+++ /dev/null
@@ -1,132 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.prediction.e2.engine
-
-import io.prediction.e2.fixture.{NaiveBayesFixture, SharedSparkContext}
-import org.scalatest.{Matchers, FlatSpec}
-
-import scala.language.reflectiveCalls
-
-class CategoricalNaiveBayesTest extends FlatSpec with Matchers
-with SharedSparkContext with NaiveBayesFixture {
- val Tolerance = .0001
- val labeledPoints = fruit.labeledPoints
-
- "Model" should "have log priors and log likelihoods" in {
- val labeledPointsRdd = sc.parallelize(labeledPoints)
- val model = CategoricalNaiveBayes.train(labeledPointsRdd)
-
- model.priors(fruit.Banana) should be(-.7885 +- Tolerance)
- model.priors(fruit.Orange) should be(-1.7047 +- Tolerance)
- model.priors(fruit.OtherFruit) should be(-1.0116 +- Tolerance)
-
- model.likelihoods(fruit.Banana)(0)(fruit.Long) should
- be(-.2231 +- Tolerance)
- model.likelihoods(fruit.Banana)(0)(fruit.NotLong) should
- be(-1.6094 +- Tolerance)
- model.likelihoods(fruit.Banana)(1)(fruit.Sweet) should
- be(-.2231 +- Tolerance)
- model.likelihoods(fruit.Banana)(1)(fruit.NotSweet) should
- be(-1.6094 +- Tolerance)
- model.likelihoods(fruit.Banana)(2)(fruit.Yellow) should
- be(-.2231 +- Tolerance)
- model.likelihoods(fruit.Banana)(2)(fruit.NotYellow) should
- be(-1.6094 +- Tolerance)
-
- model.likelihoods(fruit.Orange)(0) should not contain key(fruit.Long)
- model.likelihoods(fruit.Orange)(0)(fruit.NotLong) should be(0.0)
- model.likelihoods(fruit.Orange)(1)(fruit.Sweet) should
- be(-.6931 +- Tolerance)
- model.likelihoods(fruit.Orange)(1)(fruit.NotSweet) should
- be(-.6931 +- Tolerance)
- model.likelihoods(fruit.Orange)(2)(fruit.NotYellow) should be(0.0)
- model.likelihoods(fruit.Orange)(2) should not contain key(fruit.Yellow)
-
- model.likelihoods(fruit.OtherFruit)(0)(fruit.Long) should
- be(-.6931 +- Tolerance)
- model.likelihoods(fruit.OtherFruit)(0)(fruit.NotLong) should
- be(-.6931 +- Tolerance)
- model.likelihoods(fruit.OtherFruit)(1)(fruit.Sweet) should
- be(-.2877 +- Tolerance)
- model.likelihoods(fruit.OtherFruit)(1)(fruit.NotSweet) should
- be(-1.3863 +- Tolerance)
- model.likelihoods(fruit.OtherFruit)(2)(fruit.Yellow) should
- be(-1.3863 +- Tolerance)
- model.likelihoods(fruit.OtherFruit)(2)(fruit.NotYellow) should
- be(-.2877 +- Tolerance)
- }
-
- "Model's log score" should "be the log score of the given point" in {
- val labeledPointsRdd = sc.parallelize(labeledPoints)
- val model = CategoricalNaiveBayes.train(labeledPointsRdd)
-
- val score = model.logScore(LabeledPoint(
- fruit.Banana,
- Array(fruit.Long, fruit.NotSweet, fruit.NotYellow))
- )
-
- score should not be None
- score.get should be(-4.2304 +- Tolerance)
- }
-
- it should "be negative infinity for a point with a non-existing feature" in {
- val labeledPointsRdd = sc.parallelize(labeledPoints)
- val model = CategoricalNaiveBayes.train(labeledPointsRdd)
-
- val score = model.logScore(LabeledPoint(
- fruit.Banana,
- Array(fruit.Long, fruit.NotSweet, "Not Exist"))
- )
-
- score should not be None
- score.get should be(Double.NegativeInfinity)
- }
-
- it should "be none for a point with a non-existing label" in {
- val labeledPointsRdd = sc.parallelize(labeledPoints)
- val model = CategoricalNaiveBayes.train(labeledPointsRdd)
-
- val score = model.logScore(LabeledPoint(
- "Not Exist",
- Array(fruit.Long, fruit.NotSweet, fruit.Yellow))
- )
-
- score should be(None)
- }
-
- it should "use the provided default likelihood function" in {
- val labeledPointsRdd = sc.parallelize(labeledPoints)
- val model = CategoricalNaiveBayes.train(labeledPointsRdd)
-
- val score = model.logScore(
- LabeledPoint(
- fruit.Banana,
- Array(fruit.Long, fruit.NotSweet, "Not Exist")
- ),
- ls => ls.min - math.log(2)
- )
-
- score should not be None
- score.get should be(-4.9236 +- Tolerance)
- }
-
- "Model predict" should "return the correct label" in {
- val labeledPointsRdd = sc.parallelize(labeledPoints)
- val model = CategoricalNaiveBayes.train(labeledPointsRdd)
-
- val label = model.predict(Array(fruit.Long, fruit.Sweet, fruit.Yellow))
- label should be(fruit.Banana)
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/engine/MarkovChainTest.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/io/prediction/e2/engine/MarkovChainTest.scala b/e2/src/test/scala/io/prediction/e2/engine/MarkovChainTest.scala
deleted file mode 100644
index a33a30a..0000000
--- a/e2/src/test/scala/io/prediction/e2/engine/MarkovChainTest.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-package io.prediction.e2.engine
-
-import io.prediction.e2.fixture.{MarkovChainFixture, SharedSparkContext}
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix
-import org.scalatest.{FlatSpec, Matchers}
-
-import scala.language.reflectiveCalls
-
-class MarkovChainTest extends FlatSpec with Matchers with SharedSparkContext
-with MarkovChainFixture {
-
- "Markov chain training" should "produce a model" in {
- val matrix =
- new CoordinateMatrix(sc.parallelize(twoByTwoMatrix.matrixEntries))
- val model = MarkovChain.train(matrix, 2)
-
- model.n should be(2)
- model.transitionVectors.collect() should contain theSameElementsAs Seq(
- (0, Vectors.sparse(2, Array(0, 1), Array(0.3, 0.7))),
- (1, Vectors.sparse(2, Array(0, 1), Array(0.5, 0.5)))
- )
- }
-
- it should "contains probabilities of the top N only" in {
- val matrix =
- new CoordinateMatrix(sc.parallelize(fiveByFiveMatrix.matrixEntries))
- val model = MarkovChain.train(matrix, 2)
-
- model.n should be(2)
- (0, Vectors.sparse(5, Array(1, 2), Array(.6, .4)))
- model.transitionVectors.collect() should contain theSameElementsAs Seq(
- (0, Vectors.sparse(5, Array(1, 2), Array(.6, .4))),
- (1, Vectors.sparse(5, Array(2, 4), Array(9.0 / 25, 8.0 / 25))),
- (2, Vectors.sparse(5, Array(1, 4), Array(10.0 / 28, 10.0 / 28))),
- (3, Vectors.sparse(5, Array(3, 4), Array(3.0 / 9, 4.0 / 9))),
- (4, Vectors.sparse(5, Array(3, 4), Array(8.0 / 25, 0.4)))
- )
- }
-
- "Model predict" should "calculate the probablities of new states" in {
- val matrix =
- new CoordinateMatrix(sc.parallelize(twoByTwoMatrix.matrixEntries))
- val model = MarkovChain.train(matrix, 2)
- val nextState = model.predict(Seq(0.4, 0.6))
-
- nextState should contain theSameElementsInOrderAs Seq(0.42, 0.58)
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/evaluation/CrossValidationTest.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/io/prediction/e2/evaluation/CrossValidationTest.scala b/e2/src/test/scala/io/prediction/e2/evaluation/CrossValidationTest.scala
deleted file mode 100644
index ead51b2..0000000
--- a/e2/src/test/scala/io/prediction/e2/evaluation/CrossValidationTest.scala
+++ /dev/null
@@ -1,111 +0,0 @@
-package io.prediction.e2.evaluation
-
-import org.scalatest.{Matchers, Inspectors, FlatSpec}
-import org.apache.spark.rdd.RDD
-import io.prediction.e2.fixture.SharedSparkContext
-import io.prediction.e2.engine.LabeledPoint
-
-object CrossValidationTest {
- case class TrainingData(labeledPoints: Seq[LabeledPoint])
- case class Query(features: Array[String])
- case class ActualResult(label: String)
-
- case class EmptyEvaluationParams()
-
- def toTrainingData(labeledPoints: RDD[LabeledPoint]) = TrainingData(labeledPoints.collect().toSeq)
- def toQuery(labeledPoint: LabeledPoint) = Query(labeledPoint.features)
- def toActualResult(labeledPoint: LabeledPoint) = ActualResult(labeledPoint.label)
-
-}
-
-
-class CrossValidationTest extends FlatSpec with Matchers with Inspectors
-with SharedSparkContext{
-
-
- val Label1 = "l1"
- val Label2 = "l2"
- val Label3 = "l3"
- val Label4 = "l4"
- val Attribute1 = "a1"
- val NotAttribute1 = "na1"
- val Attribute2 = "a2"
- val NotAttribute2 = "na2"
-
- val labeledPoints = Seq(
- LabeledPoint(Label1, Array(Attribute1, Attribute2)),
- LabeledPoint(Label2, Array(NotAttribute1, Attribute2)),
- LabeledPoint(Label3, Array(Attribute1, NotAttribute2)),
- LabeledPoint(Label4, Array(NotAttribute1, NotAttribute2))
- )
-
- val dataCount = labeledPoints.size
- val evalKs = (1 to dataCount)
- val emptyParams = new CrossValidationTest.EmptyEvaluationParams()
- type Fold = (
- CrossValidationTest.TrainingData,
- CrossValidationTest.EmptyEvaluationParams,
- RDD[(CrossValidationTest.Query, CrossValidationTest.ActualResult)])
-
- def toTestTrain(dataSplit: Fold): (Seq[LabeledPoint], Seq[LabeledPoint]) = {
- val trainingData = dataSplit._1.labeledPoints
- val queryActual = dataSplit._3
- val testingData = queryActual.map { case (query, actual) =>
- LabeledPoint(actual.label, query.features)
- }
- (trainingData, testingData.collect().toSeq)
- }
-
- def splitData(k: Int, labeledPointsRDD: RDD[LabeledPoint]): Seq[Fold] = {
- CommonHelperFunctions.splitData[
- LabeledPoint,
- CrossValidationTest.TrainingData,
- CrossValidationTest.EmptyEvaluationParams,
- CrossValidationTest.Query,
- CrossValidationTest.ActualResult](
- k,
- labeledPointsRDD,
- emptyParams,
- CrossValidationTest.toTrainingData,
- CrossValidationTest.toQuery,
- CrossValidationTest.toActualResult)
- }
-
- "Fold count" should "equal evalK" in {
- val labeledPointsRDD = sc.parallelize(labeledPoints)
- val lengths = evalKs.map(k => splitData(k, labeledPointsRDD).length)
- lengths should be(evalKs)
- }
-
- "Testing data size" should "be within 1 of total / evalK" in {
- val labeledPointsRDD = sc.parallelize(labeledPoints)
- val splits = evalKs.map(k => k -> splitData(k, labeledPointsRDD))
- val diffs = splits.map { case (k, folds) =>
- folds.map(fold => fold._3.count() - dataCount / k)
- }
- forAll(diffs) {foldDiffs => foldDiffs.max should be <= 1L}
- diffs.map(folds => folds.sum) should be(evalKs.map(k => dataCount % k))
- }
-
- "Training + testing" should "equal original dataset" in {
- val labeledPointsRDD = sc.parallelize(labeledPoints)
- forAll(evalKs) {k =>
- val split = splitData(k, labeledPointsRDD)
- forAll(split) {fold =>
- val(training, testing) = toTestTrain(fold)
- (training ++ testing).toSet should be(labeledPoints.toSet)
- }
- }
- }
-
- "Training and testing" should "be disjoint" in {
- val labeledPointsRDD = sc.parallelize(labeledPoints)
- forAll(evalKs) { k =>
- val split = splitData(k, labeledPointsRDD)
- forAll(split) { fold =>
- val (training, testing) = toTestTrain(fold)
- training.toSet.intersect(testing.toSet) should be('empty)
- }
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/fixture/BinaryVectorizerFixture.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/io/prediction/e2/fixture/BinaryVectorizerFixture.scala b/e2/src/test/scala/io/prediction/e2/fixture/BinaryVectorizerFixture.scala
deleted file mode 100644
index 56ebbd8..0000000
--- a/e2/src/test/scala/io/prediction/e2/fixture/BinaryVectorizerFixture.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.prediction.e2.fixture
-
-import scala.collection.immutable.HashMap
-import scala.collection.immutable.HashSet
-import org.apache.spark.mllib.linalg.Vector
-
-trait BinaryVectorizerFixture {
-
- def base = {
- new {
- val maps : Seq[HashMap[String, String]] = Seq(
- HashMap("food" -> "orange", "music" -> "rock", "hobby" -> "scala"),
- HashMap("food" -> "orange", "music" -> "pop", "hobby" ->"running"),
- HashMap("food" -> "banana", "music" -> "rock", "hobby" -> "guitar"),
- HashMap("food" -> "banana", "music" -> "rock", "hobby" -> "guitar")
- )
-
- val properties = HashSet("food", "hobby")
- }
- }
-
-
- def testArrays = {
- new {
- // Test case for checking food value not listed in base.maps, and
- // property not in properties.
- val one = Array(("food", "burger"), ("music", "rock"), ("hobby", "scala"))
-
- // Test case for making sure indices are preserved.
- val twoA = Array(("food", "orange"), ("hobby", "scala"))
- val twoB = Array(("food", "banana"), ("hobby", "scala"))
- val twoC = Array(("hobby", "guitar"))
- }
- }
-
- def vecSum (vec1 : Vector, vec2 : Vector) : Array[Double] = {
- (0 until vec1.size).map(
- k => vec1(k) + vec2(k)
- ).toArray
- }
-
-}
-
-
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/fixture/MarkovChainFixture.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/io/prediction/e2/fixture/MarkovChainFixture.scala b/e2/src/test/scala/io/prediction/e2/fixture/MarkovChainFixture.scala
deleted file mode 100644
index e47d49e..0000000
--- a/e2/src/test/scala/io/prediction/e2/fixture/MarkovChainFixture.scala
+++ /dev/null
@@ -1,39 +0,0 @@
-package io.prediction.e2.fixture
-
-import org.apache.spark.mllib.linalg.distributed.MatrixEntry
-
-trait MarkovChainFixture {
- def twoByTwoMatrix = {
- new {
- val matrixEntries = Seq(
- MatrixEntry(0, 0, 3),
- MatrixEntry(0, 1, 7),
- MatrixEntry(1, 0, 10),
- MatrixEntry(1, 1, 10)
- )
- }
- }
-
- def fiveByFiveMatrix = {
- new {
- val matrixEntries = Seq(
- MatrixEntry(0, 1, 12),
- MatrixEntry(0, 2, 8),
- MatrixEntry(1, 0, 3),
- MatrixEntry(1, 1, 3),
- MatrixEntry(1, 2, 9),
- MatrixEntry(1, 3, 2),
- MatrixEntry(1, 4, 8),
- MatrixEntry(2, 1, 10),
- MatrixEntry(2, 2, 8),
- MatrixEntry(2, 4, 10),
- MatrixEntry(3, 0, 2),
- MatrixEntry(3, 3, 3),
- MatrixEntry(3, 4, 4),
- MatrixEntry(4, 1, 7),
- MatrixEntry(4, 3, 8),
- MatrixEntry(4, 4, 10)
- )
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/fixture/NaiveBayesFixture.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/io/prediction/e2/fixture/NaiveBayesFixture.scala b/e2/src/test/scala/io/prediction/e2/fixture/NaiveBayesFixture.scala
deleted file mode 100644
index 97dd663..0000000
--- a/e2/src/test/scala/io/prediction/e2/fixture/NaiveBayesFixture.scala
+++ /dev/null
@@ -1,48 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.prediction.e2.fixture
-
-import io.prediction.e2.engine.LabeledPoint
-
-trait NaiveBayesFixture {
-
- def fruit = {
- new {
- val Banana = "Banana"
- val Orange = "Orange"
- val OtherFruit = "Other Fruit"
- val NotLong = "Not Long"
- val Long = "Long"
- val NotSweet = "Not Sweet"
- val Sweet = "Sweet"
- val NotYellow = "Not Yellow"
- val Yellow = "Yellow"
-
- val labeledPoints = Seq(
- LabeledPoint(Banana, Array(Long, Sweet, Yellow)),
- LabeledPoint(Banana, Array(Long, Sweet, Yellow)),
- LabeledPoint(Banana, Array(Long, Sweet, Yellow)),
- LabeledPoint(Banana, Array(Long, Sweet, Yellow)),
- LabeledPoint(Banana, Array(NotLong, NotSweet, NotYellow)),
- LabeledPoint(Orange, Array(NotLong, Sweet, NotYellow)),
- LabeledPoint(Orange, Array(NotLong, NotSweet, NotYellow)),
- LabeledPoint(OtherFruit, Array(Long, Sweet, NotYellow)),
- LabeledPoint(OtherFruit, Array(NotLong, Sweet, NotYellow)),
- LabeledPoint(OtherFruit, Array(Long, Sweet, Yellow)),
- LabeledPoint(OtherFruit, Array(NotLong, NotSweet, NotYellow))
- )
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/io/prediction/e2/fixture/SharedSparkContext.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/io/prediction/e2/fixture/SharedSparkContext.scala b/e2/src/test/scala/io/prediction/e2/fixture/SharedSparkContext.scala
deleted file mode 100644
index 74dd814..0000000
--- a/e2/src/test/scala/io/prediction/e2/fixture/SharedSparkContext.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.prediction.e2.fixture
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.scalatest.{BeforeAndAfterAll, Suite}
-
-trait SharedSparkContext extends BeforeAndAfterAll {
- self: Suite =>
- @transient private var _sc: SparkContext = _
-
- def sc: SparkContext = _sc
-
- var conf = new SparkConf(false)
-
- override def beforeAll() {
- _sc = new SparkContext("local", "test", conf)
- super.beforeAll()
- }
-
- override def afterAll() {
- LocalSparkContext.stop(_sc)
-
- _sc = null
- super.afterAll()
- }
-}
-
-object LocalSparkContext {
- def stop(sc: SparkContext) {
- if (sc != null) {
- sc.stop()
- }
- // To avoid Akka rebinding to the same port, since it doesn't unbind
- // immediately on shutdown
- System.clearProperty("spark.driver.port")
- }
-}
-
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/engine/BinaryVectorizerTest.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/org/apache/predictionio/e2/engine/BinaryVectorizerTest.scala b/e2/src/test/scala/org/apache/predictionio/e2/engine/BinaryVectorizerTest.scala
new file mode 100644
index 0000000..576b8c6
--- /dev/null
+++ b/e2/src/test/scala/org/apache/predictionio/e2/engine/BinaryVectorizerTest.scala
@@ -0,0 +1,56 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.predictionio.e2.engine
+
+import org.apache.predictionio.e2.fixture.BinaryVectorizerFixture
+import org.apache.predictionio.e2.fixture.SharedSparkContext
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.rdd.RDD
+import org.scalatest.FlatSpec
+import org.scalatest.Matchers
+import scala.collection.immutable.HashMap
+
+
+import scala.language.reflectiveCalls
+
+class BinaryVectorizerTest extends FlatSpec with Matchers with SharedSparkContext
+with BinaryVectorizerFixture{
+
+ "toBinary" should "produce the following summed values:" in {
+ val testCase = BinaryVectorizer(sc.parallelize(base.maps), base.properties)
+ val vectorTwoA = testCase.toBinary(testArrays.twoA)
+ val vectorTwoB = testCase.toBinary(testArrays.twoB)
+
+
+ // Make sure vectors produced are the same size.
+ vectorTwoA.size should be (vectorTwoB.size)
+
+ // // Test case for checking food value not listed in base.maps.
+ testCase.toBinary(testArrays.one).toArray.sum should be (1.0)
+
+ // Test cases for making sure indices are preserved.
+ val sumOne = vecSum(vectorTwoA, vectorTwoB)
+
+ exactly (1, sumOne) should be (2.0)
+ exactly (2,sumOne) should be (0.0)
+ exactly (2, sumOne) should be (1.0)
+
+ val sumTwo = vecSum(Vectors.dense(sumOne), testCase.toBinary(testArrays.twoC))
+
+ exactly (3, sumTwo) should be (1.0)
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayesTest.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayesTest.scala b/e2/src/test/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayesTest.scala
new file mode 100644
index 0000000..4373d7d
--- /dev/null
+++ b/e2/src/test/scala/org/apache/predictionio/e2/engine/CategoricalNaiveBayesTest.scala
@@ -0,0 +1,132 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.predictionio.e2.engine
+
+import org.apache.predictionio.e2.fixture.{NaiveBayesFixture, SharedSparkContext}
+import org.scalatest.{Matchers, FlatSpec}
+
+import scala.language.reflectiveCalls
+
+class CategoricalNaiveBayesTest extends FlatSpec with Matchers
+with SharedSparkContext with NaiveBayesFixture {
+ val Tolerance = .0001
+ val labeledPoints = fruit.labeledPoints
+
+ "Model" should "have log priors and log likelihoods" in {
+ val labeledPointsRdd = sc.parallelize(labeledPoints)
+ val model = CategoricalNaiveBayes.train(labeledPointsRdd)
+
+ model.priors(fruit.Banana) should be(-.7885 +- Tolerance)
+ model.priors(fruit.Orange) should be(-1.7047 +- Tolerance)
+ model.priors(fruit.OtherFruit) should be(-1.0116 +- Tolerance)
+
+ model.likelihoods(fruit.Banana)(0)(fruit.Long) should
+ be(-.2231 +- Tolerance)
+ model.likelihoods(fruit.Banana)(0)(fruit.NotLong) should
+ be(-1.6094 +- Tolerance)
+ model.likelihoods(fruit.Banana)(1)(fruit.Sweet) should
+ be(-.2231 +- Tolerance)
+ model.likelihoods(fruit.Banana)(1)(fruit.NotSweet) should
+ be(-1.6094 +- Tolerance)
+ model.likelihoods(fruit.Banana)(2)(fruit.Yellow) should
+ be(-.2231 +- Tolerance)
+ model.likelihoods(fruit.Banana)(2)(fruit.NotYellow) should
+ be(-1.6094 +- Tolerance)
+
+ model.likelihoods(fruit.Orange)(0) should not contain key(fruit.Long)
+ model.likelihoods(fruit.Orange)(0)(fruit.NotLong) should be(0.0)
+ model.likelihoods(fruit.Orange)(1)(fruit.Sweet) should
+ be(-.6931 +- Tolerance)
+ model.likelihoods(fruit.Orange)(1)(fruit.NotSweet) should
+ be(-.6931 +- Tolerance)
+ model.likelihoods(fruit.Orange)(2)(fruit.NotYellow) should be(0.0)
+ model.likelihoods(fruit.Orange)(2) should not contain key(fruit.Yellow)
+
+ model.likelihoods(fruit.OtherFruit)(0)(fruit.Long) should
+ be(-.6931 +- Tolerance)
+ model.likelihoods(fruit.OtherFruit)(0)(fruit.NotLong) should
+ be(-.6931 +- Tolerance)
+ model.likelihoods(fruit.OtherFruit)(1)(fruit.Sweet) should
+ be(-.2877 +- Tolerance)
+ model.likelihoods(fruit.OtherFruit)(1)(fruit.NotSweet) should
+ be(-1.3863 +- Tolerance)
+ model.likelihoods(fruit.OtherFruit)(2)(fruit.Yellow) should
+ be(-1.3863 +- Tolerance)
+ model.likelihoods(fruit.OtherFruit)(2)(fruit.NotYellow) should
+ be(-.2877 +- Tolerance)
+ }
+
+ "Model's log score" should "be the log score of the given point" in {
+ val labeledPointsRdd = sc.parallelize(labeledPoints)
+ val model = CategoricalNaiveBayes.train(labeledPointsRdd)
+
+ val score = model.logScore(LabeledPoint(
+ fruit.Banana,
+ Array(fruit.Long, fruit.NotSweet, fruit.NotYellow))
+ )
+
+ score should not be None
+ score.get should be(-4.2304 +- Tolerance)
+ }
+
+ it should "be negative infinity for a point with a non-existing feature" in {
+ val labeledPointsRdd = sc.parallelize(labeledPoints)
+ val model = CategoricalNaiveBayes.train(labeledPointsRdd)
+
+ val score = model.logScore(LabeledPoint(
+ fruit.Banana,
+ Array(fruit.Long, fruit.NotSweet, "Not Exist"))
+ )
+
+ score should not be None
+ score.get should be(Double.NegativeInfinity)
+ }
+
+ it should "be none for a point with a non-existing label" in {
+ val labeledPointsRdd = sc.parallelize(labeledPoints)
+ val model = CategoricalNaiveBayes.train(labeledPointsRdd)
+
+ val score = model.logScore(LabeledPoint(
+ "Not Exist",
+ Array(fruit.Long, fruit.NotSweet, fruit.Yellow))
+ )
+
+ score should be(None)
+ }
+
+ it should "use the provided default likelihood function" in {
+ val labeledPointsRdd = sc.parallelize(labeledPoints)
+ val model = CategoricalNaiveBayes.train(labeledPointsRdd)
+
+ val score = model.logScore(
+ LabeledPoint(
+ fruit.Banana,
+ Array(fruit.Long, fruit.NotSweet, "Not Exist")
+ ),
+ ls => ls.min - math.log(2)
+ )
+
+ score should not be None
+ score.get should be(-4.9236 +- Tolerance)
+ }
+
+ "Model predict" should "return the correct label" in {
+ val labeledPointsRdd = sc.parallelize(labeledPoints)
+ val model = CategoricalNaiveBayes.train(labeledPointsRdd)
+
+ val label = model.predict(Array(fruit.Long, fruit.Sweet, fruit.Yellow))
+ label should be(fruit.Banana)
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/engine/MarkovChainTest.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/org/apache/predictionio/e2/engine/MarkovChainTest.scala b/e2/src/test/scala/org/apache/predictionio/e2/engine/MarkovChainTest.scala
new file mode 100644
index 0000000..137095a
--- /dev/null
+++ b/e2/src/test/scala/org/apache/predictionio/e2/engine/MarkovChainTest.scala
@@ -0,0 +1,49 @@
+package org.apache.predictionio.e2.engine
+
+import org.apache.predictionio.e2.fixture.{MarkovChainFixture, SharedSparkContext}
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix
+import org.scalatest.{FlatSpec, Matchers}
+
+import scala.language.reflectiveCalls
+
+class MarkovChainTest extends FlatSpec with Matchers with SharedSparkContext
+with MarkovChainFixture {
+
+ "Markov chain training" should "produce a model" in {
+ val matrix =
+ new CoordinateMatrix(sc.parallelize(twoByTwoMatrix.matrixEntries))
+ val model = MarkovChain.train(matrix, 2)
+
+ model.n should be(2)
+ model.transitionVectors.collect() should contain theSameElementsAs Seq(
+ (0, Vectors.sparse(2, Array(0, 1), Array(0.3, 0.7))),
+ (1, Vectors.sparse(2, Array(0, 1), Array(0.5, 0.5)))
+ )
+ }
+
+ it should "contains probabilities of the top N only" in {
+ val matrix =
+ new CoordinateMatrix(sc.parallelize(fiveByFiveMatrix.matrixEntries))
+ val model = MarkovChain.train(matrix, 2)
+
+ model.n should be(2)
+ (0, Vectors.sparse(5, Array(1, 2), Array(.6, .4)))
+ model.transitionVectors.collect() should contain theSameElementsAs Seq(
+ (0, Vectors.sparse(5, Array(1, 2), Array(.6, .4))),
+ (1, Vectors.sparse(5, Array(2, 4), Array(9.0 / 25, 8.0 / 25))),
+ (2, Vectors.sparse(5, Array(1, 4), Array(10.0 / 28, 10.0 / 28))),
+ (3, Vectors.sparse(5, Array(3, 4), Array(3.0 / 9, 4.0 / 9))),
+ (4, Vectors.sparse(5, Array(3, 4), Array(8.0 / 25, 0.4)))
+ )
+ }
+
+ "Model predict" should "calculate the probablities of new states" in {
+ val matrix =
+ new CoordinateMatrix(sc.parallelize(twoByTwoMatrix.matrixEntries))
+ val model = MarkovChain.train(matrix, 2)
+ val nextState = model.predict(Seq(0.4, 0.6))
+
+ nextState should contain theSameElementsInOrderAs Seq(0.42, 0.58)
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/evaluation/CrossValidationTest.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/org/apache/predictionio/e2/evaluation/CrossValidationTest.scala b/e2/src/test/scala/org/apache/predictionio/e2/evaluation/CrossValidationTest.scala
new file mode 100644
index 0000000..d15b927
--- /dev/null
+++ b/e2/src/test/scala/org/apache/predictionio/e2/evaluation/CrossValidationTest.scala
@@ -0,0 +1,111 @@
+package org.apache.predictionio.e2.evaluation
+
+import org.scalatest.{Matchers, Inspectors, FlatSpec}
+import org.apache.spark.rdd.RDD
+import org.apache.predictionio.e2.fixture.SharedSparkContext
+import org.apache.predictionio.e2.engine.LabeledPoint
+
+/** Minimal engine types and converters used by [[CrossValidationTest]]. */
+object CrossValidationTest {
+  case class TrainingData(labeledPoints: Seq[LabeledPoint])
+  // `features` is an Array, so case-class equality compares it by reference.
+  case class Query(features: Array[String])
+  case class ActualResult(label: String)
+
+  case class EmptyEvaluationParams()
+
+  /** Collects the RDD's points to the driver as training data. */
+  def toTrainingData(labeledPoints: RDD[LabeledPoint]): TrainingData =
+    TrainingData(labeledPoints.collect().toSeq)
+
+  /** Builds a query from a labeled point's features. */
+  def toQuery(labeledPoint: LabeledPoint): Query = Query(labeledPoint.features)
+
+  /** Wraps a labeled point's label as the expected result. */
+  def toActualResult(labeledPoint: LabeledPoint): ActualResult =
+    ActualResult(labeledPoint.label)
+}
+
+
+/** Properties of the k-fold splits produced by
+ *  `CommonHelperFunctions.splitData`, checked for every k from 1 to the number
+ *  of data points.
+ */
+class CrossValidationTest extends FlatSpec with Matchers with Inspectors
+  with SharedSparkContext {
+
+  val Label1 = "l1"
+  val Label2 = "l2"
+  val Label3 = "l3"
+  val Label4 = "l4"
+  val Attribute1 = "a1"
+  val NotAttribute1 = "na1"
+  val Attribute2 = "a2"
+  val NotAttribute2 = "na2"
+
+  // Four distinct points, one per combination of the two binary attributes.
+  val labeledPoints = Seq(
+    LabeledPoint(Label1, Array(Attribute1, Attribute2)),
+    LabeledPoint(Label2, Array(NotAttribute1, Attribute2)),
+    LabeledPoint(Label3, Array(Attribute1, NotAttribute2)),
+    LabeledPoint(Label4, Array(NotAttribute1, NotAttribute2))
+  )
+
+  val dataCount = labeledPoints.size
+  val evalKs = 1 to dataCount
+  val emptyParams = CrossValidationTest.EmptyEvaluationParams()
+
+  /** One split: (training data, evaluation params, testing (query, actual) pairs). */
+  type Fold = (
+    CrossValidationTest.TrainingData,
+    CrossValidationTest.EmptyEvaluationParams,
+    RDD[(CrossValidationTest.Query, CrossValidationTest.ActualResult)])
+
+  /** Turns a fold back into labeled points, returned as (training, testing). */
+  def toTestTrain(dataSplit: Fold): (Seq[LabeledPoint], Seq[LabeledPoint]) = {
+    val trainingData = dataSplit._1.labeledPoints
+    val testingData = dataSplit._3.map { case (query, actual) =>
+      LabeledPoint(actual.label, query.features)
+    }
+    (trainingData, testingData.collect().toSeq)
+  }
+
+  /** Splits `labeledPointsRDD` into `k` folds using this suite's converters. */
+  def splitData(k: Int, labeledPointsRDD: RDD[LabeledPoint]): Seq[Fold] = {
+    CommonHelperFunctions.splitData[
+      LabeledPoint,
+      CrossValidationTest.TrainingData,
+      CrossValidationTest.EmptyEvaluationParams,
+      CrossValidationTest.Query,
+      CrossValidationTest.ActualResult](
+      k,
+      labeledPointsRDD,
+      emptyParams,
+      CrossValidationTest.toTrainingData,
+      CrossValidationTest.toQuery,
+      CrossValidationTest.toActualResult)
+  }
+
+  "Fold count" should "equal evalK" in {
+    val labeledPointsRDD = sc.parallelize(labeledPoints)
+    val lengths = evalKs.map(k => splitData(k, labeledPointsRDD).length)
+    lengths should be(evalKs)
+  }
+
+  "Testing data size" should "be within 1 of total / evalK" in {
+    val labeledPointsRDD = sc.parallelize(labeledPoints)
+    val splits = evalKs.map(k => k -> splitData(k, labeledPointsRDD))
+    // Per fold: testing-set size minus the integer quotient dataCount / k.
+    val diffs = splits.map { case (k, folds) =>
+      folds.map(fold => fold._3.count() - dataCount / k)
+    }
+    // Each testing set has either floor(total / k) or one more point...
+    forAll(diffs) { foldDiffs =>
+      foldDiffs.min should be >= 0L
+      foldDiffs.max should be <= 1L
+    }
+    // ...and exactly `total % k` folds receive the extra point.
+    diffs.map(folds => folds.sum) should be(evalKs.map(k => dataCount % k))
+  }
+
+  "Training + testing" should "equal original dataset" in {
+    val labeledPointsRDD = sc.parallelize(labeledPoints)
+    forAll(evalKs) { k =>
+      val split = splitData(k, labeledPointsRDD)
+      forAll(split) { fold =>
+        val (training, testing) = toTestTrain(fold)
+        // The size check catches duplicates that the set comparison would hide.
+        (training.size + testing.size) should be(dataCount)
+        (training ++ testing).toSet should be(labeledPoints.toSet)
+      }
+    }
+  }
+
+  "Training and testing" should "be disjoint" in {
+    val labeledPointsRDD = sc.parallelize(labeledPoints)
+    forAll(evalKs) { k =>
+      val split = splitData(k, labeledPointsRDD)
+      forAll(split) { fold =>
+        val (training, testing) = toTestTrain(fold)
+        training.toSet.intersect(testing.toSet) should be(empty)
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/fixture/BinaryVectorizerFixture.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/org/apache/predictionio/e2/fixture/BinaryVectorizerFixture.scala b/e2/src/test/scala/org/apache/predictionio/e2/fixture/BinaryVectorizerFixture.scala
new file mode 100644
index 0000000..76d8db3
--- /dev/null
+++ b/e2/src/test/scala/org/apache/predictionio/e2/fixture/BinaryVectorizerFixture.scala
@@ -0,0 +1,59 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.predictionio.e2.fixture
+
+import scala.collection.immutable.HashMap
+import scala.collection.immutable.HashSet
+import org.apache.spark.mllib.linalg.Vector
+
+/** Fixture data for BinaryVectorizer tests.
+ *
+ *  `base` and `testArrays` return structural (anonymous) types, so callers
+ *  should import `scala.language.reflectiveCalls` to read their members.
+ */
+trait BinaryVectorizerFixture {
+
+  /** Four sample records plus the properties ("food", "hobby") to encode. */
+  def base = {
+    new {
+      val maps: Seq[HashMap[String, String]] = Seq(
+        HashMap("food" -> "orange", "music" -> "rock", "hobby" -> "scala"),
+        HashMap("food" -> "orange", "music" -> "pop", "hobby" -> "running"),
+        HashMap("food" -> "banana", "music" -> "rock", "hobby" -> "guitar"),
+        HashMap("food" -> "banana", "music" -> "rock", "hobby" -> "guitar")
+      )
+
+      val properties = HashSet("food", "hobby")
+    }
+  }
+
+  /** Property/value arrays to vectorize in tests. */
+  def testArrays = {
+    new {
+      // Test case for checking food value not listed in base.maps, and
+      // property not in properties.
+      val one = Array(("food", "burger"), ("music", "rock"), ("hobby", "scala"))
+
+      // Test case for making sure indices are preserved.
+      val twoA = Array(("food", "orange"), ("hobby", "scala"))
+      val twoB = Array(("food", "banana"), ("hobby", "scala"))
+      val twoC = Array(("hobby", "guitar"))
+    }
+  }
+
+  /** Element-wise sum of two vectors of the same size.
+   *
+   *  @throws IllegalArgumentException if the sizes differ
+   */
+  def vecSum(vec1: Vector, vec2: Vector): Array[Double] = {
+    // Fail loudly instead of silently ignoring the tail of a longer vec2 or
+    // throwing an opaque index error on a shorter one.
+    require(vec1.size == vec2.size,
+      s"Vector sizes differ: ${vec1.size} vs ${vec2.size}")
+    Array.tabulate(vec1.size)(k => vec1(k) + vec2(k))
+  }
+
+}
+
+
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/fixture/MarkovChainFixture.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/org/apache/predictionio/e2/fixture/MarkovChainFixture.scala b/e2/src/test/scala/org/apache/predictionio/e2/fixture/MarkovChainFixture.scala
new file mode 100644
index 0000000..a214be0
--- /dev/null
+++ b/e2/src/test/scala/org/apache/predictionio/e2/fixture/MarkovChainFixture.scala
@@ -0,0 +1,39 @@
+package org.apache.predictionio.e2.fixture
+
+import org.apache.spark.mllib.linalg.distributed.MatrixEntry
+
+/** Sparse count matrices, given as `MatrixEntry` coordinates, for Markov chain
+ * tests. Both are returned as structural types, so `matrixEntries` is read
+ * reflectively by callers.
+ */
+trait MarkovChainFixture {
+ // 2x2 counts; row totals are 10 (row 0) and 20 (row 1).
+ def twoByTwoMatrix = {
+ new {
+ val matrixEntries = Seq(
+ MatrixEntry(0, 0, 3),
+ MatrixEntry(0, 1, 7),
+ MatrixEntry(1, 0, 10),
+ MatrixEntry(1, 1, 10)
+ )
+ }
+ }
+
+ // 5x5 counts; cells not listed (e.g. (0, 0)) have no entry.
+ // Row totals: 20, 25, 28, 9 and 25. The order of entries is kept as written.
+ def fiveByFiveMatrix = {
+ new {
+ val matrixEntries = Seq(
+ MatrixEntry(0, 1, 12),
+ MatrixEntry(0, 2, 8),
+ MatrixEntry(1, 0, 3),
+ MatrixEntry(1, 1, 3),
+ MatrixEntry(1, 2, 9),
+ MatrixEntry(1, 3, 2),
+ MatrixEntry(1, 4, 8),
+ MatrixEntry(2, 1, 10),
+ MatrixEntry(2, 2, 8),
+ MatrixEntry(2, 4, 10),
+ MatrixEntry(3, 0, 2),
+ MatrixEntry(3, 3, 3),
+ MatrixEntry(3, 4, 4),
+ MatrixEntry(4, 1, 7),
+ MatrixEntry(4, 3, 8),
+ MatrixEntry(4, 4, 10)
+ )
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/fixture/NaiveBayesFixture.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/org/apache/predictionio/e2/fixture/NaiveBayesFixture.scala b/e2/src/test/scala/org/apache/predictionio/e2/fixture/NaiveBayesFixture.scala
new file mode 100644
index 0000000..483f366
--- /dev/null
+++ b/e2/src/test/scala/org/apache/predictionio/e2/fixture/NaiveBayesFixture.scala
@@ -0,0 +1,48 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.predictionio.e2.fixture
+
+import org.apache.predictionio.e2.engine.LabeledPoint
+
+/** Fixture data for categorical naive Bayes tests. */
+trait NaiveBayesFixture {
+
+ // Each call builds a fresh anonymous (structural) object; callers read its
+ // members reflectively.
+ def fruit = {
+ new {
+ // Label values.
+ val Banana = "Banana"
+ val Orange = "Orange"
+ val OtherFruit = "Other Fruit"
+ // Feature values, in pairs of opposites.
+ val NotLong = "Not Long"
+ val Long = "Long"
+ val NotSweet = "Not Sweet"
+ val Sweet = "Sweet"
+ val NotYellow = "Not Yellow"
+ val Yellow = "Yellow"
+
+ // 11 points: 5 Banana, 2 Orange, 4 Other Fruit. Each point's features are
+ // given as (length, sweetness, colour) values.
+ val labeledPoints = Seq(
+ LabeledPoint(Banana, Array(Long, Sweet, Yellow)),
+ LabeledPoint(Banana, Array(Long, Sweet, Yellow)),
+ LabeledPoint(Banana, Array(Long, Sweet, Yellow)),
+ LabeledPoint(Banana, Array(Long, Sweet, Yellow)),
+ LabeledPoint(Banana, Array(NotLong, NotSweet, NotYellow)),
+ LabeledPoint(Orange, Array(NotLong, Sweet, NotYellow)),
+ LabeledPoint(Orange, Array(NotLong, NotSweet, NotYellow)),
+ LabeledPoint(OtherFruit, Array(Long, Sweet, NotYellow)),
+ LabeledPoint(OtherFruit, Array(NotLong, Sweet, NotYellow)),
+ LabeledPoint(OtherFruit, Array(Long, Sweet, Yellow)),
+ LabeledPoint(OtherFruit, Array(NotLong, NotSweet, NotYellow))
+ )
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/e2/src/test/scala/org/apache/predictionio/e2/fixture/SharedSparkContext.scala
----------------------------------------------------------------------
diff --git a/e2/src/test/scala/org/apache/predictionio/e2/fixture/SharedSparkContext.scala b/e2/src/test/scala/org/apache/predictionio/e2/fixture/SharedSparkContext.scala
new file mode 100644
index 0000000..d0d762e
--- /dev/null
+++ b/e2/src/test/scala/org/apache/predictionio/e2/fixture/SharedSparkContext.scala
@@ -0,0 +1,51 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.predictionio.e2.fixture
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.scalatest.{BeforeAndAfterAll, Suite}
+
+/** Mixes one local SparkContext into a ScalaTest suite. The context is
+ *  created in `beforeAll` and stopped in `afterAll`.
+ */
+trait SharedSparkContext extends BeforeAndAfterAll {
+  self: Suite =>
+
+  @transient private var _sc: SparkContext = _
+
+  /** The suite's SparkContext; `null` before `beforeAll` and after `afterAll`. */
+  def sc: SparkContext = _sc
+
+  /** Configuration for the context; `new SparkConf(false)` does not load
+   *  `spark.*` system properties. Public and mutable, presumably so a suite
+   *  can adjust it before `beforeAll` runs.
+   */
+  var conf: SparkConf = new SparkConf(false)
+
+  override def beforeAll(): Unit = {
+    _sc = new SparkContext("local", "test", conf)
+    super.beforeAll()
+  }
+
+  override def afterAll(): Unit = {
+    // Always drop the reference and run the rest of the afterAll chain,
+    // even if stopping the context throws.
+    try {
+      LocalSparkContext.stop(_sc)
+    } finally {
+      _sc = null
+      super.afterAll()
+    }
+  }
+}
+
+/** Helpers for tearing down test SparkContexts. */
+object LocalSparkContext {
+  /** Stops `sc` when it is non-null, then clears `spark.driver.port`. */
+  def stop(sc: SparkContext): Unit = {
+    Option(sc).foreach(_.stop())
+    // Akka does not release the driver port immediately on shutdown, so
+    // forget it to avoid rebinding to the same port.
+    System.clearProperty("spark.driver.port")
+  }
+}
+
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/RegisterEngine.scala
----------------------------------------------------------------------
diff --git a/tools/src/main/scala/io/prediction/tools/RegisterEngine.scala b/tools/src/main/scala/io/prediction/tools/RegisterEngine.scala
deleted file mode 100644
index 74324c9..0000000
--- a/tools/src/main/scala/io/prediction/tools/RegisterEngine.scala
+++ /dev/null
@@ -1,84 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.prediction.tools
-
-import java.io.File
-
-import grizzled.slf4j.Logging
-import io.prediction.data.storage.EngineManifest
-import io.prediction.data.storage.EngineManifestSerializer
-import io.prediction.data.storage.Storage
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
-import org.json4s._
-import org.json4s.native.Serialization.read
-
-import scala.io.Source
-
-/** Registers and unregisters engines in the metadata engine-manifest store. */
-object RegisterEngine extends Logging {
-  val engineManifests = Storage.getMetaDataEngineManifests
-  implicit val formats = DefaultFormats + new EngineManifestSerializer
-
-  /** Registers the engine described by `jsonManifest`, recording the URIs of
-   *  `engineFiles` as its files.
-   *
-   *  @param jsonManifest engine manifest JSON file; the JVM exits with status 1
-   *                      if it does not exist
-   *  @param engineFiles files to record in the manifest
-   *  @param copyLocal not used by this implementation
-   */
-  def registerEngine(
-      jsonManifest: File,
-      engineFiles: Seq[File],
-      copyLocal: Boolean = false): Unit = {
-    val engineManifest = readManifest(jsonManifest)
-
-    info(s"Registering engine ${engineManifest.id} ${engineManifest.version}")
-    engineManifests.update(
-      engineManifest.copy(files = engineFiles.map(_.toURI.toString)), true)
-  }
-
-  /** Deletes the recorded files of the engine described by `jsonManifest` and
-   *  removes its manifest, or logs an error if that (id, version) is not
-   *  registered.
-   */
-  def unregisterEngine(jsonManifest: File): Unit = {
-    val fileEngineManifest = readManifest(jsonManifest)
-    engineManifests.get(
-      fileEngineManifest.id,
-      fileEngineManifest.version) match {
-      case Some(em) =>
-        val conf = new Configuration
-        em.files foreach { f =>
-          val path = new Path(f)
-          info(s"Removing ${f}")
-          // Resolve the file system from the path's own scheme: registerEngine
-          // records local file: URIs, which need not live on the default FS.
-          path.getFileSystem(conf).delete(path, false)
-        }
-        engineManifests.delete(em.id, em.version)
-        info(s"Unregistered engine ${em.id} ${em.version}")
-      case None =>
-        error(s"${fileEngineManifest.id} ${fileEngineManifest.version} is not " +
-          "registered.")
-    }
-  }
-
-  /** Reads and parses a manifest file, always closing it; exits the JVM with
-   *  status 1 if the file is not found.
-   */
-  private def readManifest(jsonManifest: File): EngineManifest = {
-    val jsonString = try {
-      val source = Source.fromFile(jsonManifest)
-      try source.mkString finally source.close()
-    } catch {
-      case e: java.io.FileNotFoundException =>
-        error(s"Engine manifest file not found: ${e.getMessage}. Aborting.")
-        sys.exit(1)
-    }
-    read[EngineManifest](jsonString)
-  }
-}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/RunServer.scala
----------------------------------------------------------------------
diff --git a/tools/src/main/scala/io/prediction/tools/RunServer.scala b/tools/src/main/scala/io/prediction/tools/RunServer.scala
deleted file mode 100644
index eb65e87..0000000
--- a/tools/src/main/scala/io/prediction/tools/RunServer.scala
+++ /dev/null
@@ -1,178 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.prediction.tools
-
-import java.io.File
-import java.net.URI
-
-import grizzled.slf4j.Logging
-import io.prediction.data.storage.EngineManifest
-import io.prediction.tools.console.ConsoleArgs
-import io.prediction.workflow.WorkflowUtils
-
-import scala.sys.process._
-
-/** Submits `io.prediction.workflow.CreateServer` to Spark to deploy an engine
- * instance.
- */
-object RunServer extends Logging {
- /** Builds a spark-submit command for CreateServer, starts it, and blocks
- * until it exits, returning its exit code. A JVM shutdown hook destroys the
- * child process if this JVM is terminated first.
- */
- def runServer(
- ca: ConsoleArgs,
- core: File,
- em: EngineManifest,
- engineInstanceId: String): Int = {
- // PIO_* environment variables serialized as comma-separated KEY=VALUE pairs.
- val pioEnvVars = sys.env.filter(kv => kv._1.startsWith("PIO_")).map(kv =>
- s"${kv._1}=${kv._2}"
- ).mkString(",")
-
- val sparkHome = ca.common.sparkHome.getOrElse(
- sys.env.getOrElse("SPARK_HOME", "."))
-
- val extraFiles = WorkflowUtils.thirdPartyConfFiles
-
- // A user-supplied --driver-class-path is kept in front of the third-party
- // class paths.
- val driverClassPathIndex =
- ca.common.sparkPassThrough.indexOf("--driver-class-path")
- val driverClassPathPrefix =
- if (driverClassPathIndex != -1) {
- Seq(ca.common.sparkPassThrough(driverClassPathIndex + 1))
- } else {
- Seq()
- }
- val extraClasspaths =
- driverClassPathPrefix ++ WorkflowUtils.thirdPartyClasspaths
-
- val deployModeIndex =
- ca.common.sparkPassThrough.indexOf("--deploy-mode")
- val deployMode = if (deployModeIndex != -1) {
- ca.common.sparkPassThrough(deployModeIndex + 1)
- } else {
- "client"
- }
-
- // Main JAR: with an uber JAR, its HDFS copy in cluster mode or its local copy
- // otherwise; without one, the pio-assembly JAR in cluster mode or the local
- // core JAR otherwise.
- // NOTE(review): `.head` throws NoSuchElementException if no file matches.
- val mainJar =
- if (ca.build.uberJar) {
- if (deployMode == "cluster") {
- em.files.filter(_.startsWith("hdfs")).head
- } else {
- em.files.filterNot(_.startsWith("hdfs")).head
- }
- } else {
- if (deployMode == "cluster") {
- em.files.filter(_.contains("pio-assembly")).head
- } else {
- core.getCanonicalPath
- }
- }
-
- // Engine files plus everything in $PIO_HOME/plugins (skipped when that
- // directory cannot be listed), joined for --jars.
- // NOTE(review): `pioHome.get` throws if no PIO home was supplied.
- val jarFiles = (em.files ++ Option(new File(ca.common.pioHome.get, "plugins")
- .listFiles()).getOrElse(Array.empty[File]).map(_.getAbsolutePath)).mkString(",")
-
- // spark-submit options come first, then the main JAR, then the arguments
- // passed to CreateServer.
- val sparkSubmit =
- Seq(Seq(sparkHome, "bin", "spark-submit").mkString(File.separator)) ++
- ca.common.sparkPassThrough ++
- Seq(
- "--class",
- "io.prediction.workflow.CreateServer",
- "--name",
- s"PredictionIO Engine Instance: ${engineInstanceId}") ++
- (if (!ca.build.uberJar) {
- Seq("--jars", jarFiles)
- } else Seq()) ++
- (if (extraFiles.size > 0) {
- Seq("--files", extraFiles.mkString(","))
- } else {
- Seq()
- }) ++
- (if (extraClasspaths.size > 0) {
- Seq("--driver-class-path", extraClasspaths.mkString(":"))
- } else {
- Seq()
- }) ++
- (if (ca.common.sparkKryo) {
- Seq(
- "--conf",
- "spark.serializer=org.apache.spark.serializer.KryoSerializer")
- } else {
- Seq()
- }) ++
- Seq(
- mainJar,
- "--engineInstanceId",
- engineInstanceId,
- "--ip",
- ca.deploy.ip,
- "--port",
- ca.deploy.port.toString,
- "--event-server-ip",
- ca.eventServer.ip,
- "--event-server-port",
- ca.eventServer.port.toString) ++
- (if (ca.accessKey.accessKey != "") {
- Seq("--accesskey", ca.accessKey.accessKey)
- } else {
- Seq()
- }) ++
- (if (ca.eventServer.enabled) Seq("--feedback") else Seq()) ++
- (if (ca.common.batch != "") Seq("--batch", ca.common.batch) else Seq()) ++
- (if (ca.common.verbose) Seq("--verbose") else Seq()) ++
- ca.deploy.logUrl.map(x => Seq("--log-url", x)).getOrElse(Seq()) ++
- ca.deploy.logPrefix.map(x => Seq("--log-prefix", x)).getOrElse(Seq()) ++
- Seq("--json-extractor", ca.common.jsonExtractor.toString)
-
- info(s"Submission command: ${sparkSubmit.mkString(" ")}")
-
- // The child runs with an empty CLASSPATH; PIO_* variables are also passed
- // through SPARK_YARN_USER_ENV.
- val proc =
- Process(sparkSubmit, None, "CLASSPATH" -> "", "SPARK_YARN_USER_ENV" -> pioEnvVars).run()
- Runtime.getRuntime.addShutdownHook(new Thread(new Runnable {
- def run(): Unit = {
- proc.destroy()
- }
- }))
- // Blocks until spark-submit exits.
- proc.exitValue()
- }
-
- /** Deploys through [[Runner.runOnSpark]], which assembles the spark-submit
- * command (JARs, files, optional scratch-URI staging); returns its result.
- * `jarFiles` are the engine files plus any $PIO_HOME/plugins entries.
- */
- def newRunServer(
- ca: ConsoleArgs,
- em: EngineManifest,
- engineInstanceId: String): Int = {
- val jarFiles = em.files.map(new URI(_)) ++
- Option(new File(ca.common.pioHome.get, "plugins").listFiles())
- .getOrElse(Array.empty[File]).map(_.toURI)
- val args = Seq(
- "--engineInstanceId",
- engineInstanceId,
- "--engine-variant",
- ca.common.variantJson.toURI.toString,
- "--ip",
- ca.deploy.ip,
- "--port",
- ca.deploy.port.toString,
- "--event-server-ip",
- ca.eventServer.ip,
- "--event-server-port",
- ca.eventServer.port.toString) ++
- (if (ca.accessKey.accessKey != "") {
- Seq("--accesskey", ca.accessKey.accessKey)
- } else {
- Nil
- }) ++
- (if (ca.eventServer.enabled) Seq("--feedback") else Nil) ++
- (if (ca.common.batch != "") Seq("--batch", ca.common.batch) else Nil) ++
- (if (ca.common.verbose) Seq("--verbose") else Nil) ++
- ca.deploy.logUrl.map(x => Seq("--log-url", x)).getOrElse(Nil) ++
- ca.deploy.logPrefix.map(x => Seq("--log-prefix", x)).getOrElse(Nil) ++
- Seq("--json-extractor", ca.common.jsonExtractor.toString)
-
- Runner.runOnSpark("io.prediction.workflow.CreateServer", args, ca, jarFiles)
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/RunWorkflow.scala
----------------------------------------------------------------------
diff --git a/tools/src/main/scala/io/prediction/tools/RunWorkflow.scala b/tools/src/main/scala/io/prediction/tools/RunWorkflow.scala
deleted file mode 100644
index b18690e..0000000
--- a/tools/src/main/scala/io/prediction/tools/RunWorkflow.scala
+++ /dev/null
@@ -1,212 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.prediction.tools
-
-import java.io.File
-import java.net.URI
-
-import grizzled.slf4j.Logging
-import io.prediction.data.storage.EngineManifest
-import io.prediction.tools.console.ConsoleArgs
-import io.prediction.workflow.WorkflowUtils
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
-
-import scala.sys.process._
-
-/** Submits `io.prediction.workflow.CreateWorkflow` (training or evaluation)
- *  to Spark.
- */
-object RunWorkflow extends Logging {
-  /** Builds and runs a spark-submit command for CreateWorkflow, blocking until
-   *  it exits.
-   *
-   *  @param ca parsed console arguments
-   *  @param core local core assembly JAR, used as the main JAR in client mode
-   *              when no uber JAR is built
-   *  @param em manifest of the engine to run
-   *  @param variantJson engine variant JSON file
-   *  @return exit code of the spark-submit process
-   */
-  def runWorkflow(
-      ca: ConsoleArgs,
-      core: File,
-      em: EngineManifest,
-      variantJson: File): Int = {
-    // Collect and serialize PIO_* environmental variables
-    val pioEnvVars = sys.env.filter(kv => kv._1.startsWith("PIO_")).map(kv =>
-      s"${kv._1}=${kv._2}"
-    ).mkString(",")
-
-    val sparkHome = ca.common.sparkHome.getOrElse(
-      sys.env.getOrElse("SPARK_HOME", "."))
-
-    val hadoopConf = new Configuration
-    val hdfs = FileSystem.get(hadoopConf)
-
-    // A user-supplied --driver-class-path is kept in front of the third-party
-    // class paths.
-    val driverClassPathIndex =
-      ca.common.sparkPassThrough.indexOf("--driver-class-path")
-    val driverClassPathPrefix =
-      if (driverClassPathIndex != -1) {
-        Seq(ca.common.sparkPassThrough(driverClassPathIndex + 1))
-      } else {
-        Seq()
-      }
-    val extraClasspaths =
-      driverClassPathPrefix ++ WorkflowUtils.thirdPartyClasspaths
-
-    val deployModeIndex =
-      ca.common.sparkPassThrough.indexOf("--deploy-mode")
-    val deployMode = if (deployModeIndex != -1) {
-      ca.common.sparkPassThrough(deployModeIndex + 1)
-    } else {
-      "client"
-    }
-
-    val extraFiles = WorkflowUtils.thirdPartyConfFiles
-
-    // Main JAR: with an uber JAR, its HDFS copy in cluster mode or its local
-    // copy otherwise; without one, the pio-assembly JAR in cluster mode or the
-    // local core JAR otherwise. `.head` throws if no registered file matches.
-    val mainJar =
-      if (ca.build.uberJar) {
-        if (deployMode == "cluster") {
-          em.files.filter(_.startsWith("hdfs")).head
-        } else {
-          em.files.filterNot(_.startsWith("hdfs")).head
-        }
-      } else {
-        if (deployMode == "cluster") {
-          em.files.filter(_.contains("pio-assembly")).head
-        } else {
-          core.getCanonicalPath
-        }
-      }
-
-    val workMode =
-      ca.common.evaluation.map(_ => "Evaluation").getOrElse("Training")
-
-    // Only used in cluster mode; lazy so that client mode does not require
-    // PIO_FS_ENGINESDIR to be set.
-    lazy val engineLocation = Seq(
-      sys.env("PIO_FS_ENGINESDIR"),
-      em.id,
-      em.version)
-
-    // In cluster mode the variant JSON is copied to the engine directory on the
-    // default file system, and --engine-variant below points at that copy.
-    if (deployMode == "cluster") {
-      val dstPath = new Path(engineLocation.mkString(Path.SEPARATOR))
-      info("Cluster deploy mode detected. Trying to copy " +
-        s"${variantJson.getCanonicalPath} to " +
-        s"${hdfs.makeQualified(dstPath).toString}.")
-      hdfs.copyFromLocalFile(new Path(variantJson.toURI), dstPath)
-    }
-
-    // spark-submit options come first, then the main JAR, then the arguments
-    // consumed by CreateWorkflow.
-    val sparkSubmit =
-      Seq(Seq(sparkHome, "bin", "spark-submit").mkString(File.separator)) ++
-      ca.common.sparkPassThrough ++
-      Seq(
-        "--class",
-        "io.prediction.workflow.CreateWorkflow",
-        "--name",
-        s"PredictionIO $workMode: ${em.id} ${em.version} (${ca.common.batch})") ++
-      (if (!ca.build.uberJar) {
-        Seq("--jars", em.files.mkString(","))
-      } else Seq()) ++
-      (if (extraFiles.size > 0) {
-        Seq("--files", extraFiles.mkString(","))
-      } else {
-        Seq()
-      }) ++
-      (if (extraClasspaths.size > 0) {
-        Seq("--driver-class-path", extraClasspaths.mkString(":"))
-      } else {
-        Seq()
-      }) ++
-      (if (ca.common.sparkKryo) {
-        Seq(
-          "--conf",
-          "spark.serializer=org.apache.spark.serializer.KryoSerializer")
-      } else {
-        Seq()
-      }) ++
-      Seq(
-        mainJar,
-        "--env",
-        pioEnvVars,
-        "--engine-id",
-        em.id,
-        "--engine-version",
-        em.version,
-        "--engine-variant",
-        if (deployMode == "cluster") {
-          hdfs.makeQualified(new Path(
-            (engineLocation :+ variantJson.getName).mkString(Path.SEPARATOR))).
-            toString
-        } else {
-          variantJson.getCanonicalPath
-        },
-        "--verbosity",
-        ca.common.verbosity.toString) ++
-      ca.common.engineFactory.map(
-        x => Seq("--engine-factory", x)).getOrElse(Seq()) ++
-      ca.common.engineParamsKey.map(
-        x => Seq("--engine-params-key", x)).getOrElse(Seq()) ++
-      (if (deployMode == "cluster") Seq("--deploy-mode", "cluster") else Seq()) ++
-      // --batch is passed exactly once.
-      (if (ca.common.batch != "") Seq("--batch", ca.common.batch) else Seq()) ++
-      (if (ca.common.verbose) Seq("--verbose") else Seq()) ++
-      (if (ca.common.skipSanityCheck) Seq("--skip-sanity-check") else Seq()) ++
-      (if (ca.common.stopAfterRead) Seq("--stop-after-read") else Seq()) ++
-      (if (ca.common.stopAfterPrepare) {
-        Seq("--stop-after-prepare")
-      } else {
-        Seq()
-      }) ++
-      ca.common.evaluation.map(x => Seq("--evaluation-class", x)).
-        getOrElse(Seq()) ++
-      // If engineParamsGenerator is specified, it overrides the evaluation.
-      ca.common.engineParamsGenerator.orElse(ca.common.evaluation)
-        .map(x => Seq("--engine-params-generator-class", x))
-        .getOrElse(Seq()) ++
-      Seq("--json-extractor", ca.common.jsonExtractor.toString)
-
-    info(s"Submission command: ${sparkSubmit.mkString(" ")}")
-    // Runs with an empty CLASSPATH and blocks until spark-submit exits.
-    Process(sparkSubmit, None, "CLASSPATH" -> "", "SPARK_YARN_USER_ENV" -> pioEnvVars).!
-  }
-
-  /** Runs CreateWorkflow through [[Runner.runOnSpark]] with the engine's files
-   *  as extra JARs; returns its result.
-   */
-  def newRunWorkflow(ca: ConsoleArgs, em: EngineManifest): Int = {
-    val jarFiles = em.files.map(new URI(_))
-    val args = Seq(
-      "--engine-id",
-      em.id,
-      "--engine-version",
-      em.version,
-      "--engine-variant",
-      ca.common.variantJson.toURI.toString,
-      "--verbosity",
-      ca.common.verbosity.toString) ++
-      ca.common.engineFactory.map(
-        x => Seq("--engine-factory", x)).getOrElse(Seq()) ++
-      ca.common.engineParamsKey.map(
-        x => Seq("--engine-params-key", x)).getOrElse(Seq()) ++
-      // --batch is passed exactly once.
-      (if (ca.common.batch != "") Seq("--batch", ca.common.batch) else Seq()) ++
-      (if (ca.common.verbose) Seq("--verbose") else Seq()) ++
-      (if (ca.common.skipSanityCheck) Seq("--skip-sanity-check") else Seq()) ++
-      (if (ca.common.stopAfterRead) Seq("--stop-after-read") else Seq()) ++
-      (if (ca.common.stopAfterPrepare) {
-        Seq("--stop-after-prepare")
-      } else {
-        Seq()
-      }) ++
-      ca.common.evaluation.map(x => Seq("--evaluation-class", x)).
-        getOrElse(Seq()) ++
-      // If engineParamsGenerator is specified, it overrides the evaluation.
-      ca.common.engineParamsGenerator.orElse(ca.common.evaluation)
-        .map(x => Seq("--engine-params-generator-class", x))
-        .getOrElse(Seq()) ++
-      Seq("--json-extractor", ca.common.jsonExtractor.toString)
-
-    Runner.runOnSpark(
-      "io.prediction.workflow.CreateWorkflow",
-      args,
-      ca,
-      jarFiles)
-  }
-}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/Runner.scala
----------------------------------------------------------------------
diff --git a/tools/src/main/scala/io/prediction/tools/Runner.scala b/tools/src/main/scala/io/prediction/tools/Runner.scala
deleted file mode 100644
index 3156660..0000000
--- a/tools/src/main/scala/io/prediction/tools/Runner.scala
+++ /dev/null
@@ -1,211 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.prediction.tools
-
-import java.io.File
-import java.net.URI
-
-import grizzled.slf4j.Logging
-import io.prediction.tools.console.ConsoleArgs
-import io.prediction.workflow.WorkflowUtils
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
-
-import scala.sys.process._
-
-/** Shared spark-submit plumbing for PredictionIO tools. */
-object Runner extends Logging {
-  /** Parses a comma-separated list of `KEY=VALUE` pairs (the format built for
-   *  `--env` in [[runOnSpark]]) into a map.
-   *
-   *  Only the first `=` separates key from value, so values may contain `=`
-   *  (and `KEY=` yields an empty value); entries with no `=` are skipped.
-   */
-  def envStringToMap(env: String): Map[String, String] =
-    env.split(',').flatMap(p =>
-      p.split("=", 2) match {
-        case Array(k, v) => List(k -> v)
-        case _ => Nil
-      }
-    ).toMap
-
-  /** Returns the element that follows `argumentName` in `arguments`, or
-   *  `None` when the flag is absent or is the last element.
-   */
-  def argumentValue(arguments: Seq[String], argumentName: String): Option[String] = {
-    val argumentIndex = arguments.indexOf(argumentName)
-    if (argumentIndex >= 0) arguments.lift(argumentIndex + 1) else None
-  }
-
-  /** When both a file system and a scratch URI are given, copies `localFile`
-   *  there (its canonical local path merged under the scratch URI) and returns
-   *  the qualified destination URI; otherwise returns the local file's URI.
-   */
-  def handleScratchFile(
-      fileSystem: Option[FileSystem],
-      uri: Option[URI],
-      localFile: File): String = {
-    val localFilePath = localFile.getCanonicalPath
-    (fileSystem, uri) match {
-      case (Some(fs), Some(u)) =>
-        val dest = fs.makeQualified(Path.mergePaths(
-          new Path(u),
-          new Path(localFilePath)))
-        info(s"Copying $localFile to ${dest.toString}")
-        fs.copyFromLocalFile(new Path(localFilePath), dest)
-        dest.toUri.toString
-      case _ => localFile.toURI.toString
-    }
-  }
-
-  /** Closes the scratch file system if one was opened for a scratch URI. */
-  def cleanup(fs: Option[FileSystem], uri: Option[URI]): Unit = {
-    (fs, uri) match {
-      case (Some(f), Some(_)) => f.close()
-      case _ => ()
-    }
-  }
-
-  /** Replaces each argument that names an existing local file (as a URI or a
-   *  plain path) with the result of [[handleScratchFile]]; other arguments
-   *  are returned unchanged.
-   */
-  def detectFilePaths(
-      fileSystem: Option[FileSystem],
-      uri: Option[URI],
-      args: Seq[String]): Seq[String] = {
-    import scala.util.Try
-
-    args map { arg =>
-      // Only non-fatal failures (malformed or non-file URI) fall back to
-      // treating the argument as a plain path; fatal errors propagate.
-      val f = Try(new File(new URI(arg))).getOrElse(new File(arg))
-      if (f.exists()) handleScratchFile(fileSystem, uri, f) else arg
-    }
-  }
-
-  /** Submits `className` with `classArgs` through spark-submit, using the
-   *  PredictionIO core assembly as the main JAR and `extraJars` for `--jars`.
-   *  Local files are staged under `--scratch-uri` when one is configured.
-   *
-   *  @return 1 for unsupported scratch-URI / deploy-mode combinations,
-   *          otherwise the exit code of spark-submit
-   */
-  def runOnSpark(
-      className: String,
-      classArgs: Seq[String],
-      ca: ConsoleArgs,
-      extraJars: Seq[URI]): Int = {
-    // Return error for unsupported cases
-    val deployMode =
-      argumentValue(ca.common.sparkPassThrough, "--deploy-mode").getOrElse("client")
-    val master =
-      argumentValue(ca.common.sparkPassThrough, "--master").getOrElse("local")
-
-    (ca.common.scratchUri, deployMode, master) match {
-      case (Some(_), "client", m) if m != "yarn-cluster" =>
-        error("--scratch-uri cannot be set when deploy mode is client")
-        return 1
-      case (_, "cluster", m) if m.startsWith("spark://") =>
-        error("Using cluster deploy mode with Spark standalone cluster is not supported")
-        return 1
-      case _ => ()
-    }
-
-    // Initialize HDFS API for scratch URI
-    val fs = ca.common.scratchUri map { uri =>
-      FileSystem.get(uri, new Configuration())
-    }
-
-    // Collect and serialize PIO_* environmental variables
-    val pioEnvVars = sys.env.filter(kv => kv._1.startsWith("PIO_")).map(kv =>
-      s"${kv._1}=${kv._2}"
-    ).mkString(",")
-
-    // Location of Spark
-    val sparkHome = ca.common.sparkHome.getOrElse(
-      sys.env.getOrElse("SPARK_HOME", "."))
-
-    // Local path to PredictionIO assembly JAR
-    val mainJar = handleScratchFile(
-      fs,
-      ca.common.scratchUri,
-      console.Console.coreAssembly(ca.common.pioHome.get))
-
-    // Extra JARs that are needed by the driver
-    val driverClassPathPrefix =
-      argumentValue(ca.common.sparkPassThrough, "--driver-class-path") map { v =>
-        Seq(v)
-      } getOrElse {
-        Nil
-      }
-
-    val extraClasspaths =
-      driverClassPathPrefix ++ WorkflowUtils.thirdPartyClasspaths
-
-    // Extra files that are needed to be passed to --files
-    val extraFiles = WorkflowUtils.thirdPartyConfFiles map { f =>
-      handleScratchFile(fs, ca.common.scratchUri, new File(f))
-    }
-
-    // NOTE(review): `new File(URI)` only accepts absolute file: URIs, so a
-    // non-local entry in extraJars would throw IllegalArgumentException here.
-    val deployedJars = extraJars map { j =>
-      handleScratchFile(fs, ca.common.scratchUri, new File(j))
-    }
-
-    val sparkSubmitCommand =
-      Seq(Seq(sparkHome, "bin", "spark-submit").mkString(File.separator))
-
-    val sparkSubmitJars = if (extraJars.nonEmpty) {
-      Seq("--jars", deployedJars.mkString(","))
-    } else {
-      Nil
-    }
-
-    val sparkSubmitFiles = if (extraFiles.nonEmpty) {
-      Seq("--files", extraFiles.mkString(","))
-    } else {
-      Nil
-    }
-
-    val sparkSubmitExtraClasspaths = if (extraClasspaths.nonEmpty) {
-      Seq("--driver-class-path", extraClasspaths.mkString(":"))
-    } else {
-      Nil
-    }
-
-    val sparkSubmitKryo = if (ca.common.sparkKryo) {
-      Seq(
-        "--conf",
-        "spark.serializer=org.apache.spark.serializer.KryoSerializer")
-    } else {
-      Nil
-    }
-
-    val verbose = if (ca.common.verbose) Seq("--verbose") else Nil
-
-    // spark-submit options, then the main JAR, then the application arguments;
-    // empty strings are dropped.
-    val sparkSubmit = Seq(
-      sparkSubmitCommand,
-      ca.common.sparkPassThrough,
-      Seq("--class", className),
-      sparkSubmitJars,
-      sparkSubmitFiles,
-      sparkSubmitExtraClasspaths,
-      sparkSubmitKryo,
-      Seq(mainJar),
-      detectFilePaths(fs, ca.common.scratchUri, classArgs),
-      Seq("--env", pioEnvVars),
-      verbose).flatten.filter(_ != "")
-    info(s"Submission command: ${sparkSubmit.mkString(" ")}")
-    val proc = Process(
-      sparkSubmit,
-      None,
-      "CLASSPATH" -> "",
-      "SPARK_YARN_USER_ENV" -> pioEnvVars).run()
-    Runtime.getRuntime.addShutdownHook(new Thread(new Runnable {
-      def run(): Unit = {
-        cleanup(fs, ca.common.scratchUri)
-        proc.destroy()
-      }
-    }))
-    cleanup(fs, ca.common.scratchUri)
-    // Blocks until spark-submit exits.
-    proc.exitValue()
-  }
-}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/admin/AdminAPI.scala
----------------------------------------------------------------------
diff --git a/tools/src/main/scala/io/prediction/tools/admin/AdminAPI.scala b/tools/src/main/scala/io/prediction/tools/admin/AdminAPI.scala
deleted file mode 100644
index c5ec913..0000000
--- a/tools/src/main/scala/io/prediction/tools/admin/AdminAPI.scala
+++ /dev/null
@@ -1,156 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.prediction.tools.admin
-
-import akka.actor.{Actor, ActorSystem, Props}
-import akka.event.Logging
-import akka.io.IO
-import akka.util.Timeout
-import io.prediction.data.api.StartServer
-import io.prediction.data.storage.Storage
-import org.json4s.{Formats, DefaultFormats}
-
-import java.util.concurrent.TimeUnit
-
-import spray.can.Http
-import spray.http.{MediaTypes, StatusCodes}
-import spray.httpx.Json4sSupport
-import spray.routing._
-
-import scala.concurrent.ExecutionContext
-
/** spray routing actor that serves the Admin API as JSON.
 *
 * Endpoints (all delegate to `commandClient`):
 *  - `GET    /`                     liveness check, returns `{"status":"alive"}`
 *  - `GET    /cmd/app`              list apps
 *  - `POST   /cmd/app`              create an app from an [[AppRequest]] body
 *  - `DELETE /cmd/app/{name}`       delete an app
 *  - `DELETE /cmd/app/{name}/data`  delete and re-initialize an app's event data
 *
 * @param commandClient backend that executes the admin commands
 */
class AdminServiceActor(val commandClient: CommandClient)
  extends HttpServiceActor {

  object Json4sProtocol extends Json4sSupport {
    implicit def json4sFormats: Formats = DefaultFormats
  }

  import Json4sProtocol._

  val log = Logging(context.system, this)

  // we use the enclosing ActorContext's or ActorSystem's dispatcher for our
  // Futures
  implicit def executionContext: ExecutionContext = actorRefFactory.dispatcher
  implicit val timeout: Timeout = Timeout(5, TimeUnit.SECONDS)

  // JSON error bodies for common rejections. Declared implicit because spray's
  // `runRoute` takes its RejectionHandler as an implicit parameter; as a plain
  // `val` this handler would never be consulted.
  implicit val rejectionHandler: RejectionHandler = RejectionHandler {
    case MalformedRequestContentRejection(msg, _) :: _ =>
      complete(StatusCodes.BadRequest, Map("message" -> msg))
    case MissingQueryParamRejection(msg) :: _ =>
      complete(StatusCodes.NotFound,
        Map("message" -> s"missing required query parameter ${msg}."))
    case AuthenticationFailedRejection(cause, challengeHeaders) :: _ =>
      complete(StatusCodes.Unauthorized, challengeHeaders,
        Map("message" -> s"Invalid accessKey."))
  }

  // Not referenced by `route` below; kept for compatibility.
  val jsonPath = """(.+)\.json$""".r

  val route: Route =
    pathSingleSlash {
      get {
        respondWithMediaType(MediaTypes.`application/json`) {
          complete(Map("status" -> "alive"))
        }
      }
    } ~
    path("cmd" / "app" / Segment / "data") { appName =>
      delete {
        respondWithMediaType(MediaTypes.`application/json`) {
          complete(commandClient.futureAppDataDelete(appName))
        }
      }
    } ~
    path("cmd" / "app" / Segment) { appName =>
      delete {
        respondWithMediaType(MediaTypes.`application/json`) {
          complete(commandClient.futureAppDelete(appName))
        }
      }
    } ~
    path("cmd" / "app") {
      get {
        respondWithMediaType(MediaTypes.`application/json`) {
          complete(commandClient.futureAppList())
        }
      } ~
      post {
        entity(as[AppRequest]) { appArgs =>
          respondWithMediaType(MediaTypes.`application/json`) {
            complete(commandClient.futureAppNew(appArgs))
          }
        }
      }
    }

  def receive: Actor.Receive = runRoute(route)
}
-
/** Top-level Admin server actor.
 *
 * Creates an [[AdminServiceActor]] child. On [[StartServer]], it asks
 * spray-can to bind that child to the requested interface and port. It then
 * logs the outcome of the bind.
 *
 * @param commandClient handed to the child service actor
 */
class AdminServerActor(val commandClient: CommandClient) extends Actor {
  val log = Logging(context.system, this)
  val child = context.actorOf(
    Props(classOf[AdminServiceActor], commandClient),
    "AdminServiceActor")

  // Implicit ActorSystem required by `IO(Http)` below.
  implicit val system: ActorSystem = context.system

  def receive: PartialFunction[Any, Unit] = {
    case StartServer(host, portNum) =>
      IO(Http) ! Http.Bind(child, interface = host, port = portNum)
    case _: Http.Bound =>
      log.info("Bound received. AdminServer is ready.")
    case failed: Http.CommandFailed =>
      // Include the failed command so e.g. a Bind on a busy port is diagnosable.
      log.error(s"Command failed: $failed")
    case other =>
      log.error(s"Unknown message: $other")
  }
}
-
/** Network settings for the Admin HTTP server.
 *
 * @param ip   interface to bind, `"localhost"` unless overridden
 * @param port TCP port to listen on, `7071` unless overridden
 */
case class AdminServerConfig(ip: String = "localhost", port: Int = 7071)
-
/** Boots the Admin server. */
object AdminServer {

  /** Creates the actor system and a [[CommandClient]] backed by `Storage`. It
   * then starts an [[AdminServerActor]] listening on `config.ip:config.port`.
   * The call blocks in `awaitTermination` on the actor system.
   *
   * @param config interface and port to bind
   */
  def createAdminServer(config: AdminServerConfig): Unit = {
    implicit val system = ActorSystem("AdminServerSystem")

    val client = new CommandClient(
      appClient = Storage.getMetaDataApps,
      accessKeyClient = Storage.getMetaDataAccessKeys,
      eventClient = Storage.getLEvents())

    val server = system.actorOf(
      Props(classOf[AdminServerActor], client),
      "AdminServerActor")
    server ! StartServer(config.ip, config.port)

    system.awaitTermination()
  }
}
-
/** Development launcher that starts the Admin server on `localhost:7071`.
 * Command-line arguments are ignored.
 */
object AdminRun {
  def main(args: Array[String]): Unit = {
    AdminServer.createAdminServer(AdminServerConfig(
      ip = "localhost",
      port = 7071))
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/admin/CommandClient.scala
----------------------------------------------------------------------
diff --git a/tools/src/main/scala/io/prediction/tools/admin/CommandClient.scala b/tools/src/main/scala/io/prediction/tools/admin/CommandClient.scala
deleted file mode 100644
index 924b6f0..0000000
--- a/tools/src/main/scala/io/prediction/tools/admin/CommandClient.scala
+++ /dev/null
@@ -1,160 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.prediction.tools.admin
-
-import io.prediction.data.storage._
-
-import scala.concurrent.{ExecutionContext, Future}
-
/** Common supertype of the response objects that [[CommandClient]] returns. */
abstract class BaseResponse()

/** Response carrying only an outcome flag and a human-readable message.
 *
 * @param status  outcome flag; [[CommandClient]] uses 1 for success, 0 for failure
 * @param message description of the outcome
 */
case class GeneralResponse(status: Int = 0, message: String = "") extends BaseResponse()

/** Body of an app-creation request.
 *
 * @param id          requested app ID (0 when not supplied)
 * @param name        app name
 * @param description free-text description ("" when not supplied)
 */
case class AppRequest(id: Int = 0, name: String = "", description: String = "")

/** Body of a training request; training is not implemented yet.
 *
 * @param enginePath path of the engine to train
 */
case class TrainRequest(enginePath: String = "")
/** One entry of an app listing: the app's ID and name plus all of its access keys. */
case class AppResponse(id: Int = 0, name: String = "", keys: Seq[AccessKey])
  extends BaseResponse()

/** Result of creating an app. Field names are part of the JSON the Admin API emits.
 *
 * @param status  outcome flag (set to 1 when the app was created)
 * @param message description of the outcome
 * @param id      ID of the created app
 * @param name    name of the created app
 * @param key     access key issued for the created app
 */
case class AppNewResponse(
    status: Int = 0, message: String = "", id: Int = 0, name: String = "", key: String)
  extends BaseResponse()

/** Result of listing apps.
 *
 * @param status  outcome flag
 * @param message description of the outcome
 * @param apps    listed apps, each with its access keys
 */
case class AppListResponse(status: Int = 0, message: String = "", apps: Seq[AppResponse])
  extends BaseResponse()
-
/** Executes Admin API commands against the app, access-key and event stores.
 *
 * Every command performs its storage calls inside a `Future` on the
 * caller-supplied `ExecutionContext`. Problems a command detects itself are
 * reported in the response with `status = 0` rather than as a failed `Future`.
 *
 * @param appClient       app metadata store
 * @param accessKeyClient access-key metadata store
 * @param eventClient     event store, initialized and removed per app ID
 */
class CommandClient(
    val appClient: Apps,
    val accessKeyClient: AccessKeys,
    val eventClient: LEvents) {

  /** Creates an app, initializes its event store and issues its first access key.
   *
   * Aborts before writing anything if an app with the same name, or with the
   * same ID, already exists.
   *
   * @param req name, ID and description of the new app
   * @return an [[AppNewResponse]] (status 1) with the new app ID and access key,
   *         or a [[GeneralResponse]] (status 0) naming the check or step that failed
   */
  def futureAppNew(req: AppRequest)(implicit ec: ExecutionContext): Future[BaseResponse] =
    Future {
      appClient.getByName(req.name).map { _ =>
        GeneralResponse(0, s"App ${req.name} already exists. Aborting.")
      }.orElse {
        appClient.get(req.id).map { existing =>
          GeneralResponse(0,
            s"App ID ${existing.id} already exists and maps to the app '${existing.name}'. " +
              "Aborting.")
        }
      }.getOrElse(createApp(req))
    }

  /** Inserts the app, initializes its event store and creates an access key for
   * it. Stops at, and reports, the first step that fails.
   */
  private def createApp(req: AppRequest): BaseResponse = {
    val newApp = App(id = req.id, name = req.name, description = Option(req.description))
    appClient.insert(newApp).map { id =>
      if (eventClient.init(id)) {
        accessKeyClient.insert(AccessKey(key = "", appid = id, events = Seq()))
          .map(key => AppNewResponse(1, "App created successfully.", id, req.name, key))
          .getOrElse(GeneralResponse(0, "Unable to create new access key."))
      } else {
        GeneralResponse(0, s"Unable to initialize Event Store for this app ID: ${id}.")
      }
    }.getOrElse(GeneralResponse(0, "Unable to create new app."))
  }

  /** Lists all apps, sorted by name, each with its access keys.
   *
   * @return an [[AppListResponse]] with status 1
   */
  def futureAppList()(implicit ec: ExecutionContext): Future[AppListResponse] = Future {
    val apps = appClient.getAll().sortBy(_.name).map { app =>
      AppResponse(app.id, app.name, accessKeyClient.getByAppid(app.id))
    }
    AppListResponse(1, "Successful retrieved app list.", apps)
  }

  /** Deletes an app's event data by removing and then re-initializing its event store.
   *
   * Re-initialization is attempted even when removal fails. The status is 1
   * only if both steps succeed, and the message reports both steps.
   *
   * @param appName name of the app whose event data is deleted
   */
  def futureAppDataDelete(appName: String)
      (implicit ec: ExecutionContext): Future[GeneralResponse] = Future {
    appClient.getByName(appName).map { app =>
      val removed =
        if (eventClient.remove(app.id)) {
          GeneralResponse(1, s"Removed Event Store for this app ID: ${app.id}.")
        } else {
          GeneralResponse(0, "Error removing Event Store for this app.")
        }
      val initialized =
        if (eventClient.init(app.id)) {
          GeneralResponse(1, s"Initialized Event Store for this app ID: ${app.id}.")
        } else {
          GeneralResponse(0, s"Unable to initialize Event Store for this appId: ${app.id}.")
        }
      // Join the two step messages with a space so they remain readable.
      GeneralResponse(removed.status * initialized.status,
        s"${removed.message} ${initialized.message}")
    }.getOrElse(GeneralResponse(0, s"App ${appName} does not exist."))
  }

  /** Deletes an app. It removes the app's event store, revokes the app's access
   * keys, and then deletes the app record.
   *
   * Nothing else is deleted if removing the event store fails.
   *
   * @param appName name of the app to delete
   */
  def futureAppDelete(appName: String)
      (implicit ec: ExecutionContext): Future[GeneralResponse] = Future {
    appClient.getByName(appName).map { app =>
      if (eventClient.remove(app.id)) {
        // Revoke the app's keys so none keeps referring to a deleted app ID.
        accessKeyClient.getByAppid(app.id).foreach(k => accessKeyClient.delete(k.key))
        // Use the injected client rather than a global Storage lookup.
        appClient.delete(app.id)
        GeneralResponse(1, "App successfully deleted")
      } else {
        GeneralResponse(0, s"Error removing Event Store for app ${app.name}.")
      }
    }.getOrElse(GeneralResponse(0, s"App ${appName} does not exist."))
  }

  /** Placeholder for engine training, which is not implemented yet.
   *
   * @return an already-completed `Future` holding a status-0 [[GeneralResponse]]
   */
  def futureTrain(req: TrainRequest)
      (implicit ec: ExecutionContext): Future[GeneralResponse] =
    Future.successful(GeneralResponse(0, "Training is not implemented yet."))
}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/admin/README.md
----------------------------------------------------------------------
diff --git a/tools/src/main/scala/io/prediction/tools/admin/README.md b/tools/src/main/scala/io/prediction/tools/admin/README.md
deleted file mode 100644
index 475a3de..0000000
--- a/tools/src/main/scala/io/prediction/tools/admin/README.md
+++ /dev/null
@@ -1,161 +0,0 @@
-## Admin API (under development)
-
-### Start Admin HTTP Server without bin/pio (for development)
-
-NOTE: Elasticsearch and HBase must be running first.
-
-```
-$ sbt/sbt "tools/compile"
-$ set -a
-$ source conf/pio-env.sh
-$ set +a
-$ sbt/sbt "tools/run-main io.prediction.tools.admin.AdminRun"
-```
-
-### Unit tests (very minimal)
-
-```
-$ set -a
-$ source conf/pio-env.sh
-$ set +a
-$ sbt/sbt "tools/test-only io.prediction.tools.admin.AdminAPISpec"
-```
-
-### Start with pio command adminserver
-
-```
-$ pio adminserver
-```
-
-The Admin Server URL defaults to `http://localhost:7071`.
-
-The host and port can be set with the `--ip` and `--port` options:
-
-```
-$ pio adminserver --ip 127.0.0.1 --port 7080
-```
-
-### Current Supported Commands
-
-#### Check status
-
-```
-$ curl -i http://localhost:7071/
-
-{"status":"alive"}
-```
-
-#### Get list of apps
-
-```
-$ curl -i -X GET http://localhost:7071/cmd/app
-
-{"status":1,"message":"Successful retrieved app list.","apps":[{"id":12,"name":"scratch","keys":[{"key":"gtPgVMIr3uthus1QJWFBcIjNf6d1SNuhaOWQAgdLbOBP1eRWMNIJWl6SkHgI1OoN","appid":12,"events":[]}]},{"id":17,"name":"test-ecommercerec","keys":[{"key":"zPkr6sBwQoBwBjVHK2hsF9u26L38ARSe19QzkdYentuomCtYSuH0vXP5fq7advo4","appid":17,"events":[]}]}]}
-```
-
-#### Create a new app
-
-```
-$ curl -i -X POST http://localhost:7071/cmd/app \
--H "Content-Type: application/json" \
--d '{ "name" : "my_new_app" }'
-
-{"status":1,"message":"App created successfully.","id":19,"name":"my_new_app","key":"<ACCESS_KEY>"}
-```
-
-#### Delete data of app
-
-```
-$ curl -i -X DELETE http://localhost:7071/cmd/app/my_new_app/data
-```
-
-#### Delete app
-
-```
-$ curl -i -X DELETE http://localhost:7071/cmd/app/my_new_app
-
-{"status":1,"message":"App successfully deleted"}
-```
-
-
-## API Doc (To be updated)
-
-### app list:
-GET http://localhost:7071/cmd/app
-
-OK Response:
-{
- "status": <STATUS>,
- "message": "<MESSAGE>",
- "apps": [
- { "id": <APP_ID>,
- "name": "<APP_NAME>",
- "keys": [ { "key": "<ACCESS_KEY>", "appid": <APP_ID>, "events": [ ... ] }, ... ] },
- { "id": <APP_ID>,
- "name": "<APP_NAME>",
- "keys": [ ... ] }, ... ]
-}
-
-Error Response:
-{"status": <STATUS>, "message": "<MESSAGE>"}
-
-### app new
-POST http://localhost:7071/cmd/app
-Request Body:
-{ "name": "<APP_NAME>", // required
- "id": <APP_ID>, // optional
- "description": "<DESCRIPTION>" } // optional
-
-OK Response:
-{ "status": <STATUS>,
- "message": "<MESSAGE>",
- "id": <APP_ID>,
- "name": "<APP_NAME>",
- "key": "<ACCESS_KEY>"
-}
-
-Error Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-### app delete
-DELETE http://localhost:7071/cmd/app/{appName}
-
-OK Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-Error Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-### app data-delete
-DELETE http://localhost:7071/cmd/app/{appName}/data
-
-OK Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-Error Response:
-{ "status": <STATUS>, "message": "<MESSAGE>" }
-
-
-### train TBD
-
-#### Training request:
-POST http://localhost:7071/cmd/train
-Request body: TBD
-
-OK Response: TBD
-
-Error Response: TBD
-
-#### Get training status:
-GET http://localhost:7071/cmd/train/{engineInstanceId}
-
-OK Response: TBD
-INIT
-TRAINING
-DONE
-ERROR
-
-Error Response: TBD
-
-### deploy TBD
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/tools/src/main/scala/io/prediction/tools/console/AccessKey.scala
----------------------------------------------------------------------
diff --git a/tools/src/main/scala/io/prediction/tools/console/AccessKey.scala b/tools/src/main/scala/io/prediction/tools/console/AccessKey.scala
deleted file mode 100644
index 85955e8..0000000
--- a/tools/src/main/scala/io/prediction/tools/console/AccessKey.scala
+++ /dev/null
@@ -1,83 +0,0 @@
-/** Copyright 2015 TappingStone, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.prediction.tools.console
-
-import io.prediction.data.storage
-
-import grizzled.slf4j.Logging
-
/** Access-key options read by the `AccessKey` console commands.
 *
 * @param accessKey key string to create or delete
 * @param events    event names the key is limited to; empty means every event
 *                  (listed as "(all)")
 */
case class AccessKeyArgs(accessKey: String = "", events: Seq[String] = Seq())
-
/** Console commands for access keys. Each command logs its outcome and
 * returns an exit code: 0 on success, 1 on failure.
 */
object AccessKey extends Logging {

  /** Creates an access key for the app named by `ca.app.name`.
   *
   * The key string is `ca.accessKey.accessKey` and the allowed events are
   * `ca.accessKey.events`.
   *
   * @return 0 if the key was stored, 1 if the app does not exist or the
   *         insert returned no key
   */
  def create(ca: ConsoleArgs): Int = {
    val apps = storage.Storage.getMetaDataApps
    apps.getByName(ca.app.name) map { app =>
      val accessKeys = storage.Storage.getMetaDataAccessKeys
      val accessKey = accessKeys.insert(storage.AccessKey(
        key = ca.accessKey.accessKey,
        appid = app.id,
        events = ca.accessKey.events))
      accessKey map { k =>
        info(s"Created new access key: ${k}")
        0
      } getOrElse {
        error("Unable to create new access key.")
        1
      }
    } getOrElse {
      error(s"App ${ca.app.name} does not exist. Aborting.")
      1
    }
  }

  /** Logs a table of access keys, sorted by app ID.
   *
   * It lists every key when `ca.app.name` is empty. Otherwise it lists only
   * the keys of that app.
   *
   * @return 0 after listing, 1 if the named app does not exist
   */
  def list(ca: ConsoleArgs): Int = {
    // None signals "named app not found"; avoids a nonlocal `return` from a closure.
    val keysOpt =
      if (ca.app.name == "") {
        Some(storage.Storage.getMetaDataAccessKeys.getAll)
      } else {
        storage.Storage.getMetaDataApps.getByName(ca.app.name) map { app =>
          storage.Storage.getMetaDataAccessKeys.getByAppid(app.id)
        }
      }

    keysOpt map { keys =>
      val title = "Access Key(s)"
      info(f"$title%64s | App ID | Allowed Event(s)")
      keys.sortBy(_.appid) foreach { k =>
        val events =
          if (k.events.nonEmpty) k.events.sorted.mkString(",") else "(all)"
        info(f"${k.key}%64s | ${k.appid}%6d | $events%s")
      }
      info(s"Finished listing ${keys.size} access key(s).")
      0
    } getOrElse {
      error(s"App ${ca.app.name} does not exist. Aborting.")
      1
    }
  }

  /** Deletes the access key `ca.accessKey.accessKey`.
   *
   * @return 0 if the storage call completed, 1 if it threw an `Exception`
   *         (logged with the exception)
   */
  def delete(ca: ConsoleArgs): Int = {
    try {
      storage.Storage.getMetaDataAccessKeys.delete(ca.accessKey.accessKey)
      info(s"Deleted access key ${ca.accessKey.accessKey}.")
      0
    } catch {
      case e: Exception =>
        error(s"Error deleting access key ${ca.accessKey.accessKey}.", e)
        1
    }
  }
}