Posted to commits@flink.apache.org by al...@apache.org on 2014/09/22 14:29:15 UTC
[33/60] git commit: [scala] Add Field Name Key Support for Scala Case Classes
[scala] Add Field Name Key Support for Scala Case Classes
This does not change the runtime behaviour. The key field names are
mapped to tuple indices at pre-flight time.
Also extends tests to cover the new feature.
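
In practice, the change lets programs key case-class DataSets by field name
rather than by tuple index. A minimal, hedged sketch of the new API (WordCount
is a hypothetical case class, not part of this commit):

  import org.apache.flink.api.scala._

  object FieldNameKeysExample {
    case class WordCount(word: String, count: Int)

    def main(args: Array[String]) {
      val env = ExecutionEnvironment.getExecutionEnvironment

      env.fromCollection(Seq(WordCount("a", 1), WordCount("b", 2), WordCount("a", 3)))
        // "word" and "count" are resolved to tuple indices at pre-flight time
        .groupBy("word")
        .sum("count")
        .print()

      env.execute("Field Name Keys Example")
    }
  }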
Project: http://git-wip-us.apache.org/repos/asf/incubator-flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-flink/commit/83debdb3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-flink/tree/83debdb3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-flink/diff/83debdb3
Branch: refs/heads/master
Commit: 83debdb3cb5b34bd1f7f395f061bd3be876db061
Parents: 299cef7
Author: Aljoscha Krettek <al...@gmail.com>
Authored: Fri Sep 12 11:38:20 2014 +0200
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Mon Sep 22 09:59:58 2014 +0200
----------------------------------------------------------------------
.../scala/graph/EnumTrianglesBasic.scala | 7 +-
.../examples/scala/graph/EnumTrianglesOpt.scala | 383 ++++++++++---------
.../examples/scala/graph/PageRankBasic.scala | 313 +++++++--------
.../examples/scala/misc/PiEstimation.scala | 23 +-
.../org/apache/flink/api/scala/DataSet.scala | 81 +++-
.../apache/flink/api/scala/GroupedDataSet.scala | 118 +++---
.../apache/flink/api/scala/coGroupDataSet.scala | 28 +-
.../api/scala/codegen/TypeInformationGen.scala | 5 +-
.../apache/flink/api/scala/crossDataSet.scala | 2 +-
.../apache/flink/api/scala/joinDataSet.scala | 14 +-
.../org/apache/flink/api/scala/package.scala | 15 +-
.../scala/typeutils/ScalaTupleTypeInfo.scala | 15 +-
.../api/scala/unfinishedKeyPairOperation.scala | 54 ++-
.../scala/operators/AggregateOperatorTest.scala | 40 ++
.../scala/operators/CoGroupOperatorTest.scala | 65 ++++
.../scala/operators/DistinctOperatorTest.scala | 53 ++-
.../api/scala/operators/GroupingTest.scala | 51 ++-
.../api/scala/operators/JoinOperatorTest.scala | 90 ++++-
18 files changed, 900 insertions(+), 457 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/EnumTrianglesBasic.scala
----------------------------------------------------------------------
diff --git a/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/EnumTrianglesBasic.scala b/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/EnumTrianglesBasic.scala
index 8672b2c..0bf01b4 100644
--- a/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/EnumTrianglesBasic.scala
+++ b/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/EnumTrianglesBasic.scala
@@ -85,9 +85,9 @@ object EnumTrianglesBasic {
val triangles = edgesById
// build triads
- .groupBy(0).sortGroup(1, Order.ASCENDING).reduceGroup(new TriadBuilder())
+ .groupBy("v1").sortGroup("v2", Order.ASCENDING).reduceGroup(new TriadBuilder())
// filter triads
- .join(edgesById).where(1,2).equalTo(0,1) { (t, _) => Some(t) }
+ .join(edgesById).where("v2", "v3").equalTo("v1", "v2") { (t, _) => Some(t) }
// emit result
if (fileOutput) {
@@ -163,8 +163,7 @@ object EnumTrianglesBasic {
private def getEdgeDataSet(env: ExecutionEnvironment): DataSet[Edge] = {
if (fileOutput) {
- env.readCsvFile[(Int, Int)](edgePath, fieldDelimiter = ' ', includedFields = Array(0, 1)).
- map { x => new Edge(x._1, x._2) }
+ env.readCsvFile[Edge](edgePath, fieldDelimiter = ' ', includedFields = Array(0, 1))
} else {
val edges = EnumTrianglesData.EDGES.map{ case Array(v1, v2) => new Edge(v1.asInstanceOf[Int], v2.asInstanceOf[Int]) }
env.fromCollection(edges)
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/EnumTrianglesOpt.scala
----------------------------------------------------------------------
diff --git a/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/EnumTrianglesOpt.scala b/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/EnumTrianglesOpt.scala
index 65c8b3e..e198986 100644
--- a/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/EnumTrianglesOpt.scala
+++ b/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/EnumTrianglesOpt.scala
@@ -30,206 +30,215 @@ import scala.collection.mutable.MutableList
/**
* Triangle enumeration is a pre-processing step to find closely connected parts in graphs.
* A triangle consists of three edges that connect three vertices with each other.
- *
- * <p>
- * The basic algorithm works as follows:
+ *
+ * The basic algorithm works as follows:
* It groups all edges that share a common vertex and builds triads, i.e., triples of vertices
* that are connected by two edges. Finally, all triads are filtered for which no third edge exists
* that closes the triangle.
- *
- * <p>
- * For a group of <i>n</i> edges that share a common vertex, the number of built triads is quadratic <i>((n*(n-1))/2)</i>.
- * Therefore, an optimization of the algorithm is to group edges on the vertex with the smaller output degree to
- * reduce the number of triads.
+ *
+ * For a group of ''n'' edges that share a common vertex, the number of built triads is quadratic
+ * ''(n*(n-1))/2''. Therefore, an optimization of the algorithm is to group edges on the vertex
+ * with the smaller output degree to reduce the number of triads.
* This implementation extends the basic algorithm by computing output degrees of edge vertices and
* grouping on edges on the vertex with the smaller degree.
- *
- * <p>
+ *
* Input files are plain text files and must be formatted as follows:
- * <ul>
- * <li>Edges are represented as pairs for vertex IDs which are separated by space
- * characters. Edges are separated by new-line characters.<br>
- * For example <code>"1 2\n2 12\n1 12\n42 63\n"</code> gives four (undirected) edges (1)-(2), (2)-(12), (1)-(12), and (42)-(63)
- * that include a triangle
- * </ul>
+ *
+ * - Edges are represented as pairs of vertex IDs which are separated by space
+ * characters. Edges are separated by new-line characters.
+ * For example `"1 2\n2 12\n1 12\n42 63\n"` gives four (undirected) edges (1)-(2), (2)-(12),
+ * (1)-(12), and (42)-(63) that include a triangle
+ *
* <pre>
* (1)
* / \
* (2)-(12)
* </pre>
- *
- * Usage: <code>EnumTriangleOpt <edge path> <result path></code><br>
- * If no parameters are provided, the program is run with default data from {@link EnumTrianglesData}.
- *
- * <p>
+ *
+ * Usage:
+ * {{{
+ * EnumTriangleOpt <edge path> <result path>
+ * }}}
+ *
+ * If no parameters are provided, the program is run with default data from
+ * [[org.apache.flink.example.java.graph.util.EnumTrianglesData]].
+ *
* This example shows how to use:
- * <ul>
- * <li>Custom Java objects which extend Tuple
- * <li>Group Sorting
- * </ul>
- *
+ *
+ * - Custom Java objects which extend Tuple
+ * - Group Sorting
+ *
*/
object EnumTrianglesOpt {
-
- def main(args: Array[String]) {
- if (!parseParameters(args)) {
- return
- }
-
- // set up execution environment
- val env = ExecutionEnvironment.getExecutionEnvironment
-
- // read input data
- val edges = getEdgeDataSet(env)
-
- val edgesWithDegrees = edges
- // duplicate and switch edges
- .flatMap( e => Array(e, Edge(e.v2, e.v1)) )
- // add degree of first vertex
- .groupBy(0).sortGroup(1, Order.ASCENDING).reduceGroup(new DegreeCounter())
- // join degrees of vertices
- .groupBy(0,2).reduce( (e1, e2) => if(e1.d2 == 0)
- new EdgeWithDegrees(e1.v1, e1.d1, e1.v2, e2.d2)
- else
- new EdgeWithDegrees(e1.v1, e2.d1, e1.v2, e1.d2)
- )
-
- // project edges by degrees, vertex with smaller degree comes first
- val edgesByDegree = edgesWithDegrees
- .map(e => if (e.d1 < e.d2) Edge(e.v1, e.v2) else Edge(e.v2, e.v1) )
- // project edges by Id, vertex with smaller Id comes first
- val edgesById = edgesByDegree
- .map(e => if (e.v1 < e.v2) e else Edge(e.v2, e.v1) )
-
- val triangles = edgesByDegree
- // build triads
- .groupBy(0).sortGroup(1, Order.ASCENDING).reduceGroup(new TriadBuilder())
- // filter triads
- .join(edgesById).where(1,2).equalTo(0,1) { (t, _) => Some(t) }
-
- // emit result
- if (fileOutput) {
- triangles.writeAsCsv(outputPath, "\n", " ")
- } else {
- triangles.print()
- }
-
- // execute program
- env.execute("TriangleEnumeration Example")
- }
-
- // *************************************************************************
- // USER DATA TYPES
- // *************************************************************************
-
- case class Edge(v1: Int, v2: Int) extends Serializable
- case class Triad(v1: Int, v2: Int, v3: Int) extends Serializable
- case class EdgeWithDegrees(v1: Int, d1: Int, v2: Int, d2: Int) extends Serializable
-
-
- // *************************************************************************
- // USER FUNCTIONS
- // *************************************************************************
-
- /**
- * Counts the number of edges that share a common vertex.
- * Emits one edge for each input edge with a degree annotation for the shared vertex.
- * For each emitted edge, the first vertex is the vertex with the smaller id.
- */
- class DegreeCounter extends GroupReduceFunction[Edge, EdgeWithDegrees] {
-
- val vertices = MutableList[Integer]()
- var groupVertex = 0
-
- override def reduce(edges: java.lang.Iterable[Edge], out: Collector[EdgeWithDegrees]) = {
-
- // empty vertex list
- vertices.clear
-
- // collect all vertices
- for(e <- edges.asScala) {
- groupVertex = e.v1
- if(!vertices.contains(e.v2) && e.v1 != e.v2) {
- vertices += e.v2
- }
- }
-
- // count vertices to obtain degree of groupVertex
- val degree = vertices.length
-
- // create and emit edges with degrees
- for(v <- vertices) {
- if (v < groupVertex) {
- out.collect(new EdgeWithDegrees(v, 0, groupVertex, degree))
- } else {
- out.collect(new EdgeWithDegrees(groupVertex, degree, v, 0))
- }
- }
- }
- }
-
- /**
- * Builds triads (triples of vertices) from pairs of edges that share a vertex.
- * The first vertex of a triad is the shared vertex, the second and third vertex are ordered by vertexId.
- * Assumes that input edges share the first vertex and are in ascending order of the second vertex.
- */
- class TriadBuilder extends GroupReduceFunction[Edge, Triad] {
-
- val vertices = MutableList[Integer]()
-
- override def reduce(edges: java.lang.Iterable[Edge], out: Collector[Triad]) = {
-
- // clear vertex list
- vertices.clear
-
- // build and emit triads
- for(e <- edges.asScala) {
-
- // combine vertex with all previously read vertices
- for(v <- vertices) {
- out.collect(Triad(e.v1, v, e.v2))
- }
- vertices += e.v2
- }
- }
- }
-
- // *************************************************************************
- // UTIL METHODS
- // *************************************************************************
-
- private def parseParameters(args: Array[String]): Boolean = {
- if (args.length > 0) {
- fileOutput = true
- if (args.length == 2) {
- edgePath = args(0)
- outputPath = args(1)
- } else {
- System.err.println("Usage: EnumTriangleBasic <edge path> <result path>")
- false
- }
- } else {
- System.out.println("Executing Enum Triangles Basic example with built-in default data.");
- System.out.println(" Provide parameters to read input data from files.");
- System.out.println(" See the documentation for the correct format of input files.");
- System.out.println(" Usage: EnumTriangleBasic <edge path> <result path>");
- }
- true
- }
-
- private def getEdgeDataSet(env: ExecutionEnvironment): DataSet[Edge] = {
- if (fileOutput) {
- env.readCsvFile[(Int, Int)](edgePath, fieldDelimiter = ' ', includedFields = Array(0, 1)).
- map { x => new Edge(x._1, x._2) }
- } else {
- val edges = EnumTrianglesData.EDGES.map{ case Array(v1, v2) => new Edge(v1.asInstanceOf[Int], v2.asInstanceOf[Int]) }
- env.fromCollection(edges)
- }
- }
-
-
- private var fileOutput: Boolean = false
- private var edgePath: String = null
- private var outputPath: String = null
+
+ def main(args: Array[String]) {
+ if (!parseParameters(args)) {
+ return
+ }
+
+ // set up execution environment
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ // read input data
+ val edges = getEdgeDataSet(env)
+
+ val edgesWithDegrees = edges
+ // duplicate and switch edges
+ .flatMap(e => Seq(e, Edge(e.v2, e.v1)))
+ // add degree of first vertex
+ .groupBy("v1").sortGroup("v2", Order.ASCENDING).reduceGroup(new DegreeCounter())
+ // join degrees of vertices
+ .groupBy("v1", "v2").reduce {
+ (e1, e2) =>
+ if (e1.d2 == 0) {
+ new EdgeWithDegrees(e1.v1, e1.d1, e1.v2, e2.d2)
+ } else {
+ new EdgeWithDegrees(e1.v1, e2.d1, e1.v2, e1.d2)
+ }
+ }
+
+ // project edges by degrees, vertex with smaller degree comes first
+ val edgesByDegree = edgesWithDegrees
+ .map(e => if (e.d1 <= e.d2) Edge(e.v1, e.v2) else Edge(e.v2, e.v1))
+ // project edges by Id, vertex with smaller Id comes first
+ val edgesById = edgesByDegree
+ .map(e => if (e.v1 < e.v2) e else Edge(e.v2, e.v1))
+
+ val triangles = edgesByDegree
+ // build triads
+ .groupBy("v1").sortGroup("v2", Order.ASCENDING).reduceGroup(new TriadBuilder())
+ // filter triads
+ .join(edgesById).where("v2", "v3").equalTo("v1", "v2") { (t, _) => Some(t) }
+
+ // emit result
+ if (fileOutput) {
+ triangles.writeAsCsv(outputPath, "\n", ",")
+ } else {
+ triangles.print()
+ }
+
+ // execute program
+ env.execute("TriangleEnumeration Example")
+ }
+
+ // *************************************************************************
+ // USER DATA TYPES
+ // *************************************************************************
+
+ case class Edge(v1: Int, v2: Int) extends Serializable
+
+ case class Triad(v1: Int, v2: Int, v3: Int) extends Serializable
+
+ case class EdgeWithDegrees(v1: Int, d1: Int, v2: Int, d2: Int) extends Serializable
+
+
+ // *************************************************************************
+ // USER FUNCTIONS
+ // *************************************************************************
+
+ /**
+ * Counts the number of edges that share a common vertex.
+ * Emits one edge for each input edge with a degree annotation for the shared vertex.
+ * For each emitted edge, the first vertex is the vertex with the smaller id.
+ */
+ class DegreeCounter extends GroupReduceFunction[Edge, EdgeWithDegrees] {
+
+ val vertices = mutable.MutableList[Integer]()
+ var groupVertex = 0
+
+ override def reduce(edges: java.lang.Iterable[Edge], out: Collector[EdgeWithDegrees]) = {
+
+ // empty vertex list
+ vertices.clear()
+
+ // collect all vertices
+ for (e <- edges.asScala) {
+ groupVertex = e.v1
+ if (!vertices.contains(e.v2) && e.v1 != e.v2) {
+ vertices += e.v2
+ }
+ }
+
+ // count vertices to obtain degree of groupVertex
+ val degree = vertices.length
+
+ // create and emit edges with degrees
+ for (v <- vertices) {
+ if (v < groupVertex) {
+ out.collect(new EdgeWithDegrees(v, 0, groupVertex, degree))
+ } else {
+ out.collect(new EdgeWithDegrees(groupVertex, degree, v, 0))
+ }
+ }
+ }
+ }
+
+ /**
+ * Builds triads (triples of vertices) from pairs of edges that share a vertex.
+ * The first vertex of a triad is the shared vertex, the second and third vertex are ordered by
+ * vertexId.
+ * Assumes that input edges share the first vertex and are in ascending order of the second
+ * vertex.
+ */
+ class TriadBuilder extends GroupReduceFunction[Edge, Triad] {
+
+ val vertices = mutable.MutableList[Integer]()
+
+ override def reduce(edges: java.lang.Iterable[Edge], out: Collector[Triad]) = {
+
+ // clear vertex list
+ vertices.clear()
+
+ // build and emit triads
+ for (e <- edges.asScala) {
+ // combine vertex with all previously read vertices
+ for (v <- vertices) {
+ out.collect(Triad(e.v1, v, e.v2))
+ }
+ vertices += e.v2
+ }
+ }
+ }
+
+ // *************************************************************************
+ // UTIL METHODS
+ // *************************************************************************
+
+ private def parseParameters(args: Array[String]): Boolean = {
+ if (args.length > 0) {
+ fileOutput = true
+ if (args.length == 2) {
+ edgePath = args(0)
+ outputPath = args(1)
+ } else {
+ System.err.println("Usage: EnumTriangleOpt <edge path> <result path>")
+ false
+ }
+ } else {
+ System.out.println("Executing Enum Triangles Optimized example with built-in default data.")
+ System.out.println(" Provide parameters to read input data from files.")
+ System.out.println(" See the documentation for the correct format of input files.")
+ System.out.println(" Usage: EnumTriangleBasic <edge path> <result path>")
+ }
+ true
+ }
+
+ private def getEdgeDataSet(env: ExecutionEnvironment): DataSet[Edge] = {
+ if (fileOutput) {
+ env.readCsvFile[Edge](
+ edgePath,
+ fieldDelimiter = ' ',
+ includedFields = Array(0, 1))
+ } else {
+ val edges = EnumTrianglesData.EDGES.map {
+ case Array(v1, v2) => new Edge(v1.asInstanceOf[Int], v2.asInstanceOf[Int])}
+ env.fromCollection(edges)
+ }
+ }
+
+
+ private var fileOutput: Boolean = false
+ private var edgePath: String = null
+ private var outputPath: String = null
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/PageRankBasic.scala
----------------------------------------------------------------------
diff --git a/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/PageRankBasic.scala b/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/PageRankBasic.scala
index e24727c..28a0e48 100644
--- a/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/PageRankBasic.scala
+++ b/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/graph/PageRankBasic.scala
@@ -17,169 +17,184 @@
*/
package org.apache.flink.examples.scala.graph
-import scala.collection.JavaConverters._
import org.apache.flink.api.scala._
import org.apache.flink.example.java.graph.util.PageRankData
-import org.apache.flink.util.Collector
import org.apache.flink.api.java.aggregation.Aggregations.SUM
+import org.apache.flink.util.Collector
/**
* A basic implementation of the Page Rank algorithm using a bulk iteration.
*
- * <p>
- * This implementation requires a set of pages and a set of directed links as input and works as follows. <br>
- * In each iteration, the rank of every page is evenly distributed to all pages it points to.
- * Each page collects the partial ranks of all pages that point to it, sums them up, and applies a dampening factor to the sum.
- * The result is the new rank of the page. A new iteration is started with the new ranks of all pages.
- * This implementation terminates after a fixed number of iterations.<br>
- * This is the Wikipedia entry for the <a href="http://en.wikipedia.org/wiki/Page_rank">Page Rank algorithm</a>.
+ * This implementation requires a set of pages and a set of directed links as input and works as
+ * follows.
+ *
+ * In each iteration, the rank of every page is evenly distributed to all pages it points to. Each
+ * page collects the partial ranks of all pages that point to it, sums them up, and applies a
+ * dampening factor to the sum. The result is the new rank of the page. A new iteration is started
+ * with the new ranks of all pages. This implementation terminates after a fixed number of
+ * iterations. See the Wikipedia entry for the
+ * [[http://en.wikipedia.org/wiki/Page_rank Page Rank algorithm]]
*
- * <p>
* Input files are plain text files and must be formatted as follows:
- * <ul>
- * <li>Pages represented as an (long) ID separated by new-line characters.<br>
- * For example <code>"1\n2\n12\n42\n63\n"</code> gives five pages with IDs 1, 2, 12, 42, and 63.
- * <li>Links are represented as pairs of page IDs which are separated by space
- * characters. Links are separated by new-line characters.<br>
- * For example <code>"1 2\n2 12\n1 12\n42 63\n"</code> gives four (directed) links (1)->(2), (2)->(12), (1)->(12), and (42)->(63).<br>
- * For this simple implementation it is required that each page has at least one incoming and one outgoing link (a page can point to itself).
- * </ul>
- *
- * <p>
- * Usage: <code>PageRankBasic <pages path> <links path> <output path> <num pages> <num iterations></code><br>
- * If no parameters are provided, the program is run with default data from {@link PageRankData} and 10 iterations.
+ *
+ * - Pages are represented as a (long) ID separated by new-line characters.
+ * For example `"1\n2\n12\n42\n63\n"` gives five pages with IDs 1, 2, 12, 42, and 63.
+ * - Links are represented as pairs of page IDs which are separated by space characters. Links
+ * are separated by new-line characters.
+ * For example `"1 2\n2 12\n1 12\n42 63\n"` gives four (directed) links (1)->(2), (2)->(12),
+ * (1)->(12), and (42)->(63). For this simple implementation it is required that each page has
+ * at least one incoming and one outgoing link (a page can point to itself).
+ *
+ * Usage:
+ * {{{
+ * PageRankBasic <pages path> <links path> <output path> <num pages> <num iterations>
+ * }}}
+ *
+ * If no parameters are provided, the program is run with default data from
+ * [[org.apache.flink.example.java.graph.util.PageRankData]] and 10 iterations.
*
- * <p>
* This example shows how to use:
- * <ul>
- * <li>Bulk Iterations
- * <li>Default Join
- * <li>Configure user-defined functions using constructor parameters.
- * </ul>
+ *
+ * - Bulk Iterations
+ * - Default Join
+ * - Configure user-defined functions using constructor parameters.
*
*/
object PageRankBasic {
-
- private final val DAMPENING_FACTOR: Double = 0.85;
- private final val EPSILON: Double = 0.0001;
-
- def main(args: Array[String]) {
- if (!parseParameters(args)) {
- return
- }
-
- // set up execution environment
- val env = ExecutionEnvironment.getExecutionEnvironment
-
- // read input data
- val pages = getPagesDataSet(env)
- val links = getLinksDataSet(env)
-
- // assign initial ranks to pages
- val pagesWithRanks = pages.map(p => Page(p, (1.0/numPages)))
-
- // build adjacency list from link input
- val adjacencyLists = links
- // initialize lists
- .map( e => AdjacencyList(e.sourceId, Array[java.lang.Long](e.targetId) ))
- // concatenate lists
- .groupBy(0).reduce( (l1, l2) => AdjacencyList(l1.sourceId, l1.targetIds ++ l2.targetIds))
-
- // start iteration
- val finalRanks = pagesWithRanks.iterateWithTermination(maxIterations) {
- currentRanks =>
- val newRanks = currentRanks
- // distribute ranks to target pages
- .join(adjacencyLists).where(0).equalTo(0)
- .flatMap { x => for(targetId <- x._2.targetIds) yield Page(targetId, (x._1.rank / x._2.targetIds.length))}
- // collect ranks and sum them up
- .groupBy(0).aggregate(SUM, 1)
- // apply dampening factor
- .map { p => Page(p.pageId, (p.rank * DAMPENING_FACTOR) + ((1 - DAMPENING_FACTOR) / numPages) ) }
-
- // terminate if no rank update was significant
- val termination = currentRanks
- .join(newRanks).where(0).equalTo(0)
- // check for significant update
- .filter( x => math.abs(x._1.rank - x._2.rank) > EPSILON )
-
- (newRanks, termination)
- }
-
- val result = finalRanks;
-
- // emit result
- if (fileOutput) {
- result.writeAsCsv(outputPath, "\n", " ")
- } else {
- result.print()
- }
-
- // execute program
- env.execute("Basic PageRank Example")
- }
-
- // *************************************************************************
- // USER TYPES
- // *************************************************************************
-
- case class Link(sourceId: Long, targetId: Long)
- case class Page(pageId: java.lang.Long, rank: Double)
- case class AdjacencyList(sourceId: java.lang.Long, targetIds: Array[java.lang.Long])
-
- // *************************************************************************
- // UTIL METHODS
- // *************************************************************************
-
- private def parseParameters(args: Array[String]): Boolean = {
- if (args.length > 0) {
- fileOutput = true
- if (args.length == 5) {
- pagesInputPath = args(0)
- linksInputPath = args(1)
- outputPath = args(2)
- numPages = args(3).toLong
- maxIterations = args(4).toInt
- } else {
- System.err.println("Usage: PageRankBasic <pages path> <links path> <output path> <num pages> <num iterations>");
- false
- }
- } else {
- System.out.println("Executing PageRank Basic example with default parameters and built-in default data.");
- System.out.println(" Provide parameters to read input data from files.");
- System.out.println(" See the documentation for the correct format of input files.");
- System.out.println(" Usage: PageRankBasic <pages path> <links path> <output path> <num pages> <num iterations>");
-
- numPages = PageRankData.getNumberOfPages();
- }
- true
- }
-
- private def getPagesDataSet(env: ExecutionEnvironment): DataSet[Long] = {
- if(fileOutput) {
- env.readCsvFile[Tuple1[Long]](pagesInputPath, fieldDelimiter = ' ', lineDelimiter = "\n")
- .map(x => x._1)
- } else {
- env.fromCollection(Seq.range(1, PageRankData.getNumberOfPages()+1))
- }
- }
-
- private def getLinksDataSet(env: ExecutionEnvironment): DataSet[Link] = {
- if (fileOutput) {
- env.readCsvFile[(Long, Long)](linksInputPath, fieldDelimiter = ' ', includedFields = Array(0, 1))
- .map { x => Link(x._1, x._2) }
- } else {
- val edges = PageRankData.EDGES.map{ case Array(v1, v2) => Link(v1.asInstanceOf[Long], v2.asInstanceOf[Long]) }
- env.fromCollection(edges)
- }
- }
-
- private var fileOutput: Boolean = false
- private var pagesInputPath: String = null
- private var linksInputPath: String = null
- private var outputPath: String = null
- private var numPages: Long = 0;
- private var maxIterations: Int = 10;
+
+ private final val DAMPENING_FACTOR: Double = 0.85
+ private final val EPSILON: Double = 0.0001
+
+ def main(args: Array[String]) {
+ if (!parseParameters(args)) {
+ return
+ }
+
+ // set up execution environment
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ // read input data
+ val pages = getPagesDataSet(env)
+ val links = getLinksDataSet(env)
+
+ // assign initial ranks to pages
+ val pagesWithRanks = pages.map(p => Page(p, 1.0 / numPages))
+
+ // build adjacency list from link input
+ val adjacencyLists = links
+ // initialize lists
+ .map(e => AdjacencyList(e.sourceId, Array(e.targetId)))
+ // concatenate lists
+ .groupBy("sourceId").reduce((l1, l2) => AdjacencyList(l1.sourceId, l1.targetIds ++ l2.targetIds))
+
+ // start iteration
+ val finalRanks = pagesWithRanks.iterateWithTermination(maxIterations) {
+ currentRanks =>
+ val newRanks = currentRanks
+ // distribute ranks to target pages
+ .join(adjacencyLists).where("pageId").equalTo("sourceId") {
+ (page, adjacent, out: Collector[Page]) =>
+ for (targetId <- adjacent.targetIds) {
+ out.collect(Page(targetId, page.rank / adjacent.targetIds.length))
+ }
+ }
+ // collect ranks and sum them up
+ .groupBy("pageId").aggregate(SUM, "rank")
+ // apply dampening factor
+ .map { p =>
+ Page(p.pageId, (p.rank * DAMPENING_FACTOR) + ((1 - DAMPENING_FACTOR) / numPages))
+ }
+
+ // terminate if no rank update was significant
+ val termination = currentRanks.join(newRanks).where("pageId").equalTo("pageId") {
+ (current, next) =>
+ // check for significant update
+ if (math.abs(current.rank - next.rank) > EPSILON) Some(1) else None
+ }
+
+ (newRanks, termination)
+ }
+
+ val result = finalRanks
+
+ // emit result
+ if (fileOutput) {
+ result.writeAsCsv(outputPath, "\n", " ")
+ } else {
+ result.print()
+ }
+
+ // execute program
+ env.execute("Basic PageRank Example")
+ }
+
+ // *************************************************************************
+ // USER TYPES
+ // *************************************************************************
+
+ case class Link(sourceId: Long, targetId: Long)
+
+ case class Page(pageId: Long, rank: Double)
+
+ case class AdjacencyList(sourceId: Long, targetIds: Array[Long])
+
+ // *************************************************************************
+ // UTIL METHODS
+ // *************************************************************************
+
+ private def parseParameters(args: Array[String]): Boolean = {
+ if (args.length > 0) {
+ fileOutput = true
+ if (args.length == 5) {
+ pagesInputPath = args(0)
+ linksInputPath = args(1)
+ outputPath = args(2)
+ numPages = args(3).toLong
+ maxIterations = args(4).toInt
+ } else {
+ System.err.println("Usage: PageRankBasic <pages path> <links path> <output path> <num " +
+ "pages> <num iterations>")
+ false
+ }
+ } else {
+ System.out.println("Executing PageRank Basic example with default parameters and built-in " +
+ "default data.")
+ System.out.println(" Provide parameters to read input data from files.")
+ System.out.println(" See the documentation for the correct format of input files.")
+ System.out.println(" Usage: PageRankBasic <pages path> <links path> <output path> <num " +
+ "pages> <num iterations>")
+
+ numPages = PageRankData.getNumberOfPages
+ }
+ true
+ }
+
+ private def getPagesDataSet(env: ExecutionEnvironment): DataSet[Long] = {
+ if (fileOutput) {
+ env.readCsvFile[Tuple1[Long]](pagesInputPath, fieldDelimiter = ' ', lineDelimiter = "\n")
+ .map(x => x._1)
+ } else {
+ env.generateSequence(1, 15)
+ }
+ }
+
+ private def getLinksDataSet(env: ExecutionEnvironment): DataSet[Link] = {
+ if (fileOutput) {
+ env.readCsvFile[Link](linksInputPath, fieldDelimiter = ' ',
+ includedFields = Array(0, 1))
+ } else {
+ val edges = PageRankData.EDGES.map { case Array(v1, v2) => Link(v1.asInstanceOf[Long],
+ v2.asInstanceOf[Long])}
+ env.fromCollection(edges)
+ }
+ }
+
+ private var fileOutput: Boolean = false
+ private var pagesInputPath: String = null
+ private var linksInputPath: String = null
+ private var outputPath: String = null
+ private var numPages: Long = 0
+ private var maxIterations: Int = 10
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/misc/PiEstimation.scala
----------------------------------------------------------------------
diff --git a/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/misc/PiEstimation.scala b/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/misc/PiEstimation.scala
index d702f61..bb66b10 100644
--- a/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/misc/PiEstimation.scala
+++ b/flink-examples/flink-scala-examples/src/main/scala/org/apache/flink/examples/scala/misc/PiEstimation.scala
@@ -26,22 +26,23 @@ object PiEstimation {
val numSamples: Long = if (args.length > 0) args(0).toLong else 1000000
- val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val env = ExecutionEnvironment.getExecutionEnvironment
// count how many of the samples would randomly fall into
- // the unit circle
+ // the upper right quadrant of the unit circle
val count =
env.generateSequence(1, numSamples)
- .map (sample => {
- val x = Math.random()
- val y = Math.random()
- if (x * x + y * y < 1) 1L else 0L
- })
- .reduce(_+_)
-
- // the ratio of the unit circle surface to 4 times the unit square is pi
+ .map { sample =>
+ val x = Math.random()
+ val y = Math.random()
+ if (x * x + y * y < 1) 1L else 0L
+ }
+ .reduce(_+_)
+
+ // the ratio of samples in the upper right quadrant to all samples approximates the
+ // area of that quadrant (pi/4); multiplying by 4 yields an estimate of pi
val pi = count
- .map (_ * 4.0 / numSamples)
+ .map ( _ * 4.0 / numSamples)
println("We estimate Pi to be:")
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
index 041f269..8f14c0a 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
@@ -30,6 +30,7 @@ import org.apache.flink.api.java.operators.Keys.FieldPositionKeys
import org.apache.flink.api.java.operators._
import org.apache.flink.api.java.{DataSet => JavaDataSet}
import org.apache.flink.api.scala.operators.{ScalaCsvOutputFormat, ScalaAggregateOperator}
+import org.apache.flink.api.scala.typeutils.ScalaTupleTypeInfo
import org.apache.flink.core.fs.FileSystem.WriteMode
import org.apache.flink.core.fs.{FileSystem, Path}
import org.apache.flink.types.TypeInformation
@@ -376,6 +377,25 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
}
/**
+ * Creates a new [[DataSet]] by aggregating the specified field using the given aggregation
+ * function. Since this is not a keyed DataSet the aggregation will be performed on the whole
+ * collection of elements.
+ *
+ * This only works on CaseClass DataSets.
+ */
+ def aggregate(agg: Aggregations, field: String): DataSet[T] = {
+ val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)
+
+ set match {
+ case aggregation: ScalaAggregateOperator[T] =>
+ aggregation.and(agg, fieldIndex)
+ wrap(aggregation)
+
+ case _ => wrap(new ScalaAggregateOperator[T](set, agg, fieldIndex))
+ }
+ }
+
+ /**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: Int) = {
@@ -397,6 +417,27 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
}
/**
+ * Syntactic sugar for [[aggregate]] with `SUM`
+ */
+ def sum(field: String) = {
+ aggregate(Aggregations.SUM, field)
+ }
+
+ /**
+ * Syntactic sugar for [[aggregate]] with `MAX`
+ */
+ def max(field: String) = {
+ aggregate(Aggregations.MAX, field)
+ }
+
+ /**
+ * Syntactic sugar for [[aggregate]] with `MIN`
+ */
+ def min(field: String) = {
+ aggregate(Aggregations.MIN, field)
+ }
+
+ /**
* Creates a new [[DataSet]] by merging the elements of this DataSet using an associative reduce
* function.
*/
@@ -486,7 +527,7 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
* Creates a new DataSet containing the distinct elements of this DataSet. The decision whether
* two elements are distinct or not is made based on only the specified tuple fields.
*
- * This only works if this DataSet contains Tuples.
+ * This only works on tuple DataSets.
*/
def distinct(fields: Int*): DataSet[T] = {
wrap(new DistinctOperator[T](
@@ -496,6 +537,19 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
/**
* Creates a new DataSet containing the distinct elements of this DataSet. The decision whether
+ * two elements are distinct or not is made based on only the specified fields.
+ *
+ * This only works on CaseClass DataSets.
+ */
+ def distinct(firstField: String, otherFields: String*): DataSet[T] = {
+ val fieldIndices = fieldNames2Indices(set.getType, firstField +: otherFields.toArray)
+ wrap(new DistinctOperator[T](
+ set,
+ new Keys.FieldPositionKeys[T](fieldIndices, set.getType, true)))
+ }
+
+ /**
+ * Creates a new DataSet containing the distinct elements of this DataSet. The decision whether
* two elements are distinct or not is made based on all tuple fields.
*
* This only works if this DataSet contains Tuples.
@@ -539,6 +593,23 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
new Keys.FieldPositionKeys[T](fields.toArray, set.getType,false))
}
+ /**
+ * Creates a [[GroupedDataSet]] which provides operations on groups of elements. Elements are
+ * grouped based on the given fields.
+ *
+ * This will not create a new DataSet, it will just attach the field names which will be
+ * used for grouping when executing a grouped operation.
+ *
+ * This only works on CaseClass DataSets.
+ */
+ def groupBy(firstField: String, otherFields: String*): GroupedDataSet[T] = {
+ val fieldIndices = fieldNames2Indices(set.getType, firstField +: otherFields.toArray)
+
+ new GroupedDataSetImpl[T](
+ set,
+ new Keys.FieldPositionKeys[T](fieldIndices, set.getType, false))
+ }
+
// public UnsortedGrouping<T> groupBy(String... fields) {
// new UnsortedGrouping<T>(this, new Keys.ExpressionKeys<T>(fields, getType()));
// }
@@ -587,21 +658,21 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
* }}}
*/
def join[O](other: DataSet[O]): UnfinishedJoinOperation[T, O] =
- new UnfinishedJoinOperationImpl(this.set, other.set, JoinHint.OPTIMIZER_CHOOSES)
+ new UnfinishedJoinOperationImpl(this, other, JoinHint.OPTIMIZER_CHOOSES)
/**
* Special [[join]] operation for explicitly telling the system that the right side is assumed
* to be a lot smaller than the left side of the join.
*/
def joinWithTiny[O](other: DataSet[O]): UnfinishedJoinOperation[T, O] =
- new UnfinishedJoinOperationImpl(this.set, other.set, JoinHint.BROADCAST_HASH_SECOND)
+ new UnfinishedJoinOperationImpl(this, other, JoinHint.BROADCAST_HASH_SECOND)
/**
* Special [[join]] operation for explicitly telling the system that the left side is assumed
* to be a lot smaller than the right side of the join.
*/
def joinWithHuge[O](other: DataSet[O]): UnfinishedJoinOperation[T, O] =
- new UnfinishedJoinOperationImpl(this.set, other.set, JoinHint.BROADCAST_HASH_FIRST)
+ new UnfinishedJoinOperationImpl(this, other, JoinHint.BROADCAST_HASH_FIRST)
// --------------------------------------------------------------------------------------------
// Co-Group
@@ -641,7 +712,7 @@ class DataSet[T: ClassTag](private[flink] val set: JavaDataSet[T]) {
* }}}
*/
def coGroup[O: ClassTag](other: DataSet[O]): UnfinishedCoGroupOperation[T, O] =
- new UnfinishedCoGroupOperationImpl(this.set, other.set)
+ new UnfinishedCoGroupOperationImpl(this, other)
// --------------------------------------------------------------------------------------------
// Cross
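
A hedged usage sketch of the name-based DataSet methods added above (Person is
a hypothetical case class; env is an ExecutionEnvironment and the usual
org.apache.flink.api.scala._ import is assumed):

  case class Person(name: String, age: Int)

  val people = env.fromCollection(
    Seq(Person("Ada", 36), Person("Ada", 36), Person("Alan", 41)))

  // distinct on a named field instead of a tuple index
  val uniqueByName = people.distinct("name")

  // whole-DataSet aggregation on a named field
  val oldest = people.max("age")

  // grouping on a named field, ready for keyed operations
  val byName = people.groupBy("name")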
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
index dfd5cf0..a7ca821 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala
@@ -19,6 +19,7 @@ package org.apache.flink.api.scala
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.scala.operators.ScalaAggregateOperator
+import org.apache.flink.api.scala.typeutils.ScalaTupleTypeInfo
import scala.collection.JavaConverters._
@@ -45,11 +46,21 @@ trait GroupedDataSet[T] {
/**
* Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
- * use one of the group-at-a-time, i.e. `reduceGroup`
+ * use one of the group-at-a-time operations, i.e. `reduceGroup`.
+ *
+ * This only works on Tuple DataSets.
*/
def sortGroup(field: Int, order: Order): GroupedDataSet[T]
/**
+ * Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
+ * use one of the group-at-a-time operations, i.e. `reduceGroup`.
+ *
+ * This only works on CaseClass DataSets.
+ */
+ def sortGroup(field: String, order: Order): GroupedDataSet[T]
+
+ /**
* Creates a new [[DataSet]] by aggregating the specified tuple field using the given aggregation
* function. Since this is a keyed DataSet the aggregation will be performed on groups of
* tuples with the same key.
@@ -59,6 +70,15 @@ trait GroupedDataSet[T] {
def aggregate(agg: Aggregations, field: Int): DataSet[T]
/**
+ * Creates a new [[DataSet]] by aggregating the specified field using the given aggregation
+ * function. Since this is a keyed DataSet the aggregation will be performed on groups of
+ * elements with the same key.
+ *
+ * This only works on CaseClass DataSets.
+ */
+ def aggregate(agg: Aggregations, field: String): DataSet[T]
+
+ /**
* Syntactic sugar for [[aggregate]] with `SUM`
*/
def sum(field: Int): DataSet[T]
@@ -74,6 +94,21 @@ trait GroupedDataSet[T] {
def min(field: Int): DataSet[T]
/**
+ * Syntactic sugar for [[aggregate]] with `SUM`
+ */
+ def sum(field: String): DataSet[T]
+
+ /**
+ * Syntactic sugar for [[aggregate]] with `MAX`
+ */
+ def max(field: String): DataSet[T]
+
+ /**
+ * Syntactic sugar for [[aggregate]] with `MIN`
+ */
+ def min(field: String): DataSet[T]
+
+ /**
* Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
* using an associative reduce function.
*/
@@ -124,14 +159,10 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
private val groupSortKeyPositions = mutable.MutableList[Int]()
private val groupSortOrders = mutable.MutableList[Order]()
- /**
- * Adds a secondary sort key to this [[GroupedDataSet]]. This will only have an effect if you
- * use one of the group-at-a-time, i.e. `reduceGroup`
- */
def sortGroup(field: Int, order: Order): GroupedDataSet[T] = {
if (!set.getType.isTupleType) {
throw new InvalidProgramException("Specifying order keys via field positions is only valid " +
- "for tuple data types")
+ "for tuple data types.")
}
if (field >= set.getType.getArity) {
throw new IllegalArgumentException("Order key out of tuple bounds.")
@@ -141,10 +172,14 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
this
}
- /**
- * Creates a [[SortedGrouping]] if any secondary sort fields were specified. Otherwise, just
- * create an [[UnsortedGrouping]].
- */
+ def sortGroup(field: String, order: Order): GroupedDataSet[T] = {
+ val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)
+
+ groupSortKeyPositions += fieldIndex
+ groupSortOrders += order
+ this
+ }
+
private def maybeCreateSortedGrouping(): Grouping[T] = {
if (groupSortKeyPositions.length > 0) {
val grouping = new SortedGrouping[T](set, keys, groupSortKeyPositions(0), groupSortOrders(0))
@@ -161,13 +196,18 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
/** Convenience methods for creating the [[UnsortedGrouping]] */
private def createUnsortedGrouping(): Grouping[T] = new UnsortedGrouping[T](set, keys)
- /**
- * Creates a new [[DataSet]] by aggregating the specified tuple field using the given aggregation
- * function. Since this is a keyed DataSet the aggregation will be performed on groups of
- * tuples with the same key.
- *
- * This only works on Tuple DataSets.
- */
+ def aggregate(agg: Aggregations, field: String): DataSet[T] = {
+ val fieldIndex = fieldNames2Indices(set.getType, Array(field))(0)
+
+ set match {
+ case aggregation: ScalaAggregateOperator[T] =>
+ aggregation.and(agg, fieldIndex)
+ wrap(aggregation)
+
+ case _ => wrap(new ScalaAggregateOperator[T](createUnsortedGrouping(), agg, fieldIndex))
+ }
+ }
+
def aggregate(agg: Aggregations, field: Int): DataSet[T] = set match {
case aggregation: ScalaAggregateOperator[T] =>
aggregation.and(agg, field)
@@ -176,31 +216,30 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
case _ => wrap(new ScalaAggregateOperator[T](createUnsortedGrouping(), agg, field))
}
- /**
- * Syntactic sugar for [[aggregate]] with `SUM`
- */
def sum(field: Int): DataSet[T] = {
aggregate(Aggregations.SUM, field)
}
- /**
- * Syntactic sugar for [[aggregate]] with `MAX`
- */
def max(field: Int): DataSet[T] = {
aggregate(Aggregations.MAX, field)
}
- /**
- * Syntactic sugar for [[aggregate]] with `MIN`
- */
def min(field: Int): DataSet[T] = {
aggregate(Aggregations.MIN, field)
}
- /**
- * Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
- * using an associative reduce function.
- */
+ def sum(field: String): DataSet[T] = {
+ aggregate(Aggregations.SUM, field)
+ }
+
+ def max(field: String): DataSet[T] = {
+ aggregate(Aggregations.MAX, field)
+ }
+
+ def min(field: String): DataSet[T] = {
+ aggregate(Aggregations.MIN, field)
+ }
+
def reduce(fun: (T, T) => T): DataSet[T] = {
Validate.notNull(fun, "Reduce function must not be null.")
val reducer = new ReduceFunction[T] {
@@ -211,20 +250,11 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
wrap(new ReduceOperator[T](createUnsortedGrouping(), reducer))
}
- /**
- * Creates a new [[DataSet]] by merging the elements of each group (elements with the same key)
- * using an associative reduce function.
- */
def reduce(reducer: ReduceFunction[T]): DataSet[T] = {
Validate.notNull(reducer, "Reduce function must not be null.")
wrap(new ReduceOperator[T](createUnsortedGrouping(), reducer))
}
- /**
- * Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
- * of elements to the group reduce function. The function must output one element. The
- * concatenation of those will form the resulting [[DataSet]].
- */
def reduceGroup[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T]) => R): DataSet[R] = {
Validate.notNull(fun, "Group reduce function must not be null.")
@@ -238,11 +268,6 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
implicitly[TypeInformation[R]], reducer))
}
- /**
- * Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
- * of elements to the group reduce function. The function can output zero or more elements using
- * the [[Collector]]. The concatenation of the emitted values will form the resulting [[DataSet]].
- */
def reduceGroup[R: TypeInformation: ClassTag](
fun: (TraversableOnce[T], Collector[R]) => Unit): DataSet[R] = {
Validate.notNull(fun, "Group reduce function must not be null.")
@@ -256,11 +281,6 @@ private[flink] class GroupedDataSetImpl[T: ClassTag](
implicitly[TypeInformation[R]], reducer))
}
- /**
- * Creates a new [[DataSet]] by passing for each group (elements with the same key) the list
- * of elements to the [[GroupReduceFunction]]. The function can output zero or more elements. The
- * concatenation of the emitted values will form the resulting [[DataSet]].
- */
def reduceGroup[R: TypeInformation: ClassTag](reducer: GroupReduceFunction[T, R]): DataSet[R] = {
Validate.notNull(reducer, "GroupReduce function must not be null.")
wrap(
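
A hedged sketch of the new name-based GroupedDataSet methods (Event is a
hypothetical case class; env and the usual org.apache.flink.api.scala._
import are assumed):

  import org.apache.flink.api.common.operators.Order

  case class Event(user: String, time: Long, amount: Double)

  val events = env.fromCollection(Seq(
    Event("u1", 2L, 1.5), Event("u1", 1L, 2.5), Event("u2", 3L, 0.5)))

  // secondary sort within each group by field name, then a group-at-a-time reduce
  val earliestPerUser = events
    .groupBy("user")
    .sortGroup("time", Order.ASCENDING)
    .reduceGroup(group => group.toSeq.head)

  // keyed aggregation by field name
  val totals = events.groupBy("user").sum("amount")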
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala
index 05f9917..f936b43 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala
@@ -94,8 +94,8 @@ trait CoGroupDataSet[T, O] extends DataSet[(Array[T], Array[O])] {
*/
private[flink] class CoGroupDataSetImpl[T, O](
coGroupOperator: CoGroupOperator[T, O, (Array[T], Array[O])],
- thisSet: JavaDataSet[T],
- otherSet: JavaDataSet[O],
+ thisSet: DataSet[T],
+ otherSet: DataSet[O],
thisKeys: Keys[T],
otherKeys: Keys[O]) extends DataSet(coGroupOperator) with CoGroupDataSet[T, O] {
@@ -107,7 +107,7 @@ private[flink] class CoGroupDataSetImpl[T, O](
fun(left.iterator.asScala, right.iterator.asScala) map { out.collect(_) }
}
}
- val coGroupOperator = new CoGroupOperator[T, O, R](thisSet, otherSet, thisKeys,
+ val coGroupOperator = new CoGroupOperator[T, O, R](thisSet.set, otherSet.set, thisKeys,
otherKeys, coGrouper, implicitly[TypeInformation[R]])
wrap(coGroupOperator)
}
@@ -120,14 +120,14 @@ private[flink] class CoGroupDataSetImpl[T, O](
fun(left.iterator.asScala, right.iterator.asScala, out)
}
}
- val coGroupOperator = new CoGroupOperator[T, O, R](thisSet, otherSet, thisKeys,
+ val coGroupOperator = new CoGroupOperator[T, O, R](thisSet.set, otherSet.set, thisKeys,
otherKeys, coGrouper, implicitly[TypeInformation[R]])
wrap(coGroupOperator)
}
def apply[R: TypeInformation: ClassTag](joiner: CoGroupFunction[T, O, R]): DataSet[R] = {
Validate.notNull(joiner, "CoGroup function must not be null.")
- val coGroupOperator = new CoGroupOperator[T, O, R](thisSet, otherSet, thisKeys,
+ val coGroupOperator = new CoGroupOperator[T, O, R](thisSet.set, otherSet.set, thisKeys,
otherKeys, joiner, implicitly[TypeInformation[R]])
wrap(coGroupOperator)
}
@@ -153,8 +153,8 @@ trait UnfinishedCoGroupOperation[T, O]
* i.e. the parameters of the constructor, hidden.
*/
private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
- leftSet: JavaDataSet[T],
- rightSet: JavaDataSet[O])
+ leftSet: DataSet[T],
+ rightSet: DataSet[O])
extends UnfinishedKeyPairOperation[T, O, CoGroupDataSet[T, O]](leftSet, rightSet)
with UnfinishedCoGroupOperation[T, O] {
@@ -173,11 +173,13 @@ private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
// We have to use this hack, for some reason classOf[Array[T]] does not work.
// Maybe because ObjectArrayTypeInfo does not accept the Scala Array as an array class.
- val leftArrayType = ObjectArrayTypeInfo.getInfoFor(new Array[T](0).getClass, leftSet.getType)
- val rightArrayType = ObjectArrayTypeInfo.getInfoFor(new Array[O](0).getClass, rightSet.getType)
+ val leftArrayType =
+ ObjectArrayTypeInfo.getInfoFor(new Array[T](0).getClass, leftSet.set.getType)
+ val rightArrayType =
+ ObjectArrayTypeInfo.getInfoFor(new Array[O](0).getClass, rightSet.set.getType)
val returnType = new ScalaTupleTypeInfo[(Array[T], Array[O])](
- classOf[(Array[T], Array[O])], Seq(leftArrayType, rightArrayType)) {
+ classOf[(Array[T], Array[O])], Seq(leftArrayType, rightArrayType), Array("_1", "_2")) {
override def createSerializer: TypeSerializer[(Array[T], Array[O])] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
@@ -195,10 +197,10 @@ private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
}
}
val coGroupOperator = new CoGroupOperator[T, O, (Array[T], Array[O])](
- leftSet, rightSet, leftKey, rightKey, coGrouper, returnType)
+ leftSet.set, rightSet.set, leftKey, rightKey, coGrouper, returnType)
// sanity check solution set key mismatches
- leftSet match {
+ leftSet.set match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
leftKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
@@ -211,7 +213,7 @@ private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
}
case _ =>
}
- rightSet match {
+ rightSet.set match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
rightKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala
index 248c396..6ad1f74 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala
@@ -68,8 +68,11 @@ private[flink] trait TypeInformationGen[C <: Context] {
}
val fieldsExpr = c.Expr[Seq[TypeInformation[_]]](mkList(fields))
val instance = mkCreateTupleInstance[T](desc)(c.WeakTypeTag(desc.tpe))
+
+ val fieldNames = desc.getters map { f => Literal(Constant(f.getter.name.toString)) } toList
+ val fieldNamesExpr = c.Expr[Seq[String]](mkSeq(fieldNames))
reify {
- new ScalaTupleTypeInfo[T](tpeClazz.splice, fieldsExpr.splice) {
+ new ScalaTupleTypeInfo[T](tpeClazz.splice, fieldsExpr.splice, fieldNamesExpr.splice) {
override def createSerializer: TypeSerializer[T] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
for (i <- 0 until getArity) {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala
index 5218745..2db2ff6 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala
@@ -110,7 +110,7 @@ private[flink] object CrossDataSetImpl {
}
}
val returnType = new ScalaTupleTypeInfo[(T, O)](
- classOf[(T, O)], Seq(leftSet.getType, rightSet.getType)) {
+ classOf[(T, O)], Seq(leftSet.getType, rightSet.getType), Array("_1", "_2")) {
override def createSerializer: TypeSerializer[(T, O)] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
index 8d24ee1..a2b09e1 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
@@ -168,8 +168,8 @@ trait UnfinishedJoinOperation[T, O] extends UnfinishedKeyPairOperation[T, O, Joi
* i.e. the parameters of the constructor, hidden.
*/
private[flink] class UnfinishedJoinOperationImpl[T, O](
- leftSet: JavaDataSet[T],
- rightSet: JavaDataSet[O],
+ leftSet: DataSet[T],
+ rightSet: DataSet[O],
joinHint: JoinHint)
extends UnfinishedKeyPairOperation[T, O, JoinDataSet[T, O]](leftSet, rightSet)
with UnfinishedJoinOperation[T, O] {
@@ -181,7 +181,7 @@ private[flink] class UnfinishedJoinOperationImpl[T, O](
}
}
val returnType = new ScalaTupleTypeInfo[(T, O)](
- classOf[(T, O)], Seq(leftSet.getType, rightSet.getType)) {
+ classOf[(T, O)], Seq(leftSet.set.getType, rightSet.set.getType), Array("_1", "_2")) {
override def createSerializer: TypeSerializer[(T, O)] = {
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
@@ -197,10 +197,10 @@ private[flink] class UnfinishedJoinOperationImpl[T, O](
}
}
val joinOperator = new EquiJoin[T, O, (T, O)](
- leftSet, rightSet, leftKey, rightKey, joiner, returnType, joinHint)
+ leftSet.set, rightSet.set, leftKey, rightKey, joiner, returnType, joinHint)
// sanity check solution set key mismatches
- leftSet match {
+ leftSet.set match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
leftKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
@@ -213,7 +213,7 @@ private[flink] class UnfinishedJoinOperationImpl[T, O](
}
case _ =>
}
- rightSet match {
+ rightSet.set match {
case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
rightKey match {
case keyFields: Keys.FieldPositionKeys[_] =>
@@ -227,6 +227,6 @@ private[flink] class UnfinishedJoinOperationImpl[T, O](
case _ =>
}
- new JoinDataSetImpl(joinOperator, leftSet, rightSet, leftKey, rightKey)
+ new JoinDataSetImpl(joinOperator, leftSet.set, rightSet.set, leftKey, rightKey)
}
}
\ No newline at end of file
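
With this plumbing in place, join keys can be given as case-class field names
on both inputs. A hedged sketch (Purchase and Customer are hypothetical case
classes; env and the usual imports are assumed):

  case class Purchase(id: Int, customerId: Int)
  case class Customer(id: Int, name: String)

  val purchases = env.fromCollection(Seq(Purchase(1, 7), Purchase(2, 8)))
  val customers = env.fromCollection(Seq(Customer(7, "Ada"), Customer(8, "Alan")))

  // both key sides are resolved from field names to tuple indices
  val idWithName = purchases.join(customers).where("customerId").equalTo("id") {
    (p, c) => Some((p.id, c.name))
  }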
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/main/scala/org/apache/flink/api/scala/package.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/package.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/package.scala
index 405158f..c63c991 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/package.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/package.scala
@@ -21,7 +21,7 @@ package org.apache.flink.api
import _root_.scala.reflect.ClassTag
import language.experimental.macros
import org.apache.flink.types.TypeInformation
-import org.apache.flink.api.scala.typeutils.TypeUtils
+import org.apache.flink.api.scala.typeutils.{ScalaTupleTypeInfo, TypeUtils}
import org.apache.flink.api.java.{DataSet => JavaDataSet}
package object scala {
@@ -31,4 +31,17 @@ package object scala {
// We need to wrap Java DataSet because we need the scala operations
private[flink] def wrap[R: ClassTag](set: JavaDataSet[R]) = new DataSet[R](set)
+
+ private[flink] def fieldNames2Indices(
+ typeInfo: TypeInformation[_],
+ fields: Array[String]): Array[Int] = {
+ typeInfo match {
+ case ti: ScalaTupleTypeInfo[_] =>
+ ti.getFieldIndices(fields)
+
+ case _ =>
+ throw new UnsupportedOperationException("Specifying fields by name is only " +
+ "supported on Case Classes (for now).")
+ }
+ }
}
\ No newline at end of file
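
fieldNames2Indices is private[flink], but its resolution logic amounts to an index lookup against the names recorded in ScalaTupleTypeInfo. A standalone sketch of that logic in plain Scala (the field names are assumed):

// Sketch of the lookup that fieldNames2Indices delegates to.
val fieldNames = Seq("v1", "v2", "v3") // as recorded for a case class
val requested = Array("v2", "v3")

val indices = requested.map(fieldNames.indexOf) // Array(1, 2)
if (indices.contains(-1)) {
  throw new IllegalArgumentException(
    "Fields '" + requested.mkString(", ") + "' are not valid.")
}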
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaTupleTypeInfo.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaTupleTypeInfo.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaTupleTypeInfo.scala
index 069e03b..c191e81 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaTupleTypeInfo.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/ScalaTupleTypeInfo.scala
@@ -18,8 +18,7 @@
package org.apache.flink.api.scala.typeutils
-import org.apache.flink.api.java.tuple.Tuple
-import org.apache.flink.api.java.typeutils.{TupleTypeInfo, AtomicType, TupleTypeInfoBase}
+import org.apache.flink.api.java.typeutils.{AtomicType, TupleTypeInfoBase}
import org.apache.flink.types.TypeInformation
import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer}
@@ -29,7 +28,8 @@ import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer}
*/
abstract class ScalaTupleTypeInfo[T <: Product](
tupleClass: Class[T],
- fieldTypes: Seq[TypeInformation[_]])
+ fieldTypes: Seq[TypeInformation[_]],
+ val fieldNames: Seq[String])
extends TupleTypeInfoBase[T](tupleClass, fieldTypes: _*) {
def createComparator(logicalKeyFields: Array[Int], orders: Array[Boolean]): TypeComparator[T] = {
@@ -76,5 +76,14 @@ abstract class ScalaTupleTypeInfo[T <: Product](
new ScalaTupleComparator[T](logicalKeyFields, fieldComparators, fieldSerializers)
}
+ def getFieldIndices(fields: Array[String]): Array[Int] = {
+ val result = fields map { x => fieldNames.indexOf(x) }
+ if (result.contains(-1)) {
+ throw new IllegalArgumentException("Fields '" + fields.mkString(", ") + "' are not valid for" +
+ " " + tupleClass + " with fields '" + fieldNames.mkString(", ") + "'.")
+ }
+ result
+ }
+
override def toString = "Scala " + super.toString
}
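
Because getFieldIndices validates eagerly, a misspelled field name fails when the plan is constructed rather than when the job runs. A behaviour sketch, assuming a hypothetical Edge case class:

import org.apache.flink.api.scala._

case class Edge(v1: Int, v2: Int)

val env = ExecutionEnvironment.getExecutionEnvironment
val edges = env.fromCollection(Seq(Edge(1, 2), Edge(2, 3)))

edges.groupBy("v1") // fine: "v1" resolves to index 0
try {
  edges.groupBy("foo") // not a field of Edge
} catch {
  case e: IllegalArgumentException =>
    // message names the offending fields and the valid ones, e.g.
    // "Fields 'foo' are not valid for ... with fields 'v1, v2'."
}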
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala
index f8cbb62..198388e 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/unfinishedKeyPairOperation.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.java.{DataSet => JavaDataSet}
import org.apache.flink.api.java.functions.KeySelector
import org.apache.flink.api.java.operators.Keys
import org.apache.flink.api.java.operators.Keys.FieldPositionKeys
+import org.apache.flink.api.scala.typeutils.ScalaTupleTypeInfo
import org.apache.flink.types.TypeInformation
/**
@@ -43,27 +44,42 @@ import org.apache.flink.types.TypeInformation
* @tparam R The type of the resulting Operation.
*/
private[flink] abstract class UnfinishedKeyPairOperation[T, O, R](
- private[flink] val leftSet: JavaDataSet[T],
- private[flink] val rightSet: JavaDataSet[O]) {
+ private[flink] val leftSet: DataSet[T],
+ private[flink] val rightSet: DataSet[O]) {
private[flink] def finish(leftKey: Keys[T], rightKey: Keys[O]): R
/**
* Specify the key fields for the left side of the key based operation. This returns
- * a [[HalfUnfinishedKeyPairOperation]] on which `isEqualTo` must be called to specify the
+ * a [[HalfUnfinishedKeyPairOperation]] on which `equalTo` must be called to specify the
* key for the right side. The result after specifying the right side key is the finished
* operation.
*
- * This only works on a Tuple [[DataSet]].
+ * This only works on Tuple [[DataSet]].
*/
def where(leftKeys: Int*) = {
- val leftKey = new FieldPositionKeys[T](leftKeys.toArray, leftSet.getType)
+ val leftKey = new FieldPositionKeys[T](leftKeys.toArray, leftSet.set.getType)
+ new HalfUnfinishedKeyPairOperation[T, O, R](this, leftKey)
+ }
+
+ /**
+ * Specify the key fields for the left side of the key based operation. This returns
+ * a [[HalfUnfinishedKeyPairOperation]] on which `equalTo` must be called to specify the
+ * key for the right side. The result after specifying the right side key is the finished
+ * operation.
+ *
+ * This only works on a CaseClass [[DataSet]].
+ */
+ def where(firstLeftField: String, otherLeftFields: String*) = {
+ val fieldIndices = fieldNames2Indices(
+ leftSet.set.getType,
+ firstLeftField +: otherLeftFields.toArray)
+
+ val leftKey = new FieldPositionKeys[T](fieldIndices, leftSet.set.getType)
new HalfUnfinishedKeyPairOperation[T, O, R](this, leftKey)
}
/**
* Specify the key selector function for the left side of the key based operation. This returns
- * a [[HalfUnfinishedKeyPairOperation]] on which `isEqualTo` must be called to specify the
+ * a [[HalfUnfinishedKeyPairOperation]] on which `equalTo` must be called to specify the
* key for the right side. The result after specifying the right side key is the finished
* operation.
*/
@@ -72,7 +88,7 @@ private[flink] abstract class UnfinishedKeyPairOperation[T, O, R](
val keyExtractor = new KeySelector[T, K] {
def getKey(in: T) = fun(in)
}
- val leftKey = new Keys.SelectorFunctionKeys[T, K](keyExtractor, leftSet.getType, keyType)
+ val leftKey = new Keys.SelectorFunctionKeys[T, K](keyExtractor, leftSet.set.getType, keyType)
new HalfUnfinishedKeyPairOperation[T, O, R](this, leftKey)
}
}
@@ -87,7 +103,7 @@ private[flink] class HalfUnfinishedKeyPairOperation[T, O, R](
* This only works on a Tuple [[DataSet]].
*/
def equalTo(rightKeys: Int*): R = {
- val rightKey = new FieldPositionKeys[O](rightKeys.toArray, unfinished.rightSet.getType)
+ val rightKey = new FieldPositionKeys[O](rightKeys.toArray, unfinished.rightSet.set.getType)
if (!leftKey.areCompatibale(rightKey)) {
throw new InvalidProgramException("The types of the key fields do not match. Left: " +
leftKey + " Right: " + rightKey)
@@ -96,6 +112,26 @@ private[flink] class HalfUnfinishedKeyPairOperation[T, O, R](
}
/**
+ * Specify the key fields for the right side of the key based operation. This returns
+ * the finished operation.
+ *
+ * This only works on a CaseClass [[DataSet]].
+ */
+ def equalTo(firstRightField: String, otherRightFields: String*): R = {
+ val fieldIndices = fieldNames2Indices(
+ unfinished.rightSet.set.getType,
+ firstRightField +: otherRightFields.toArray)
+
+ val rightKey = new FieldPositionKeys[O](fieldIndices, unfinished.rightSet.set.getType)
+ if (!leftKey.areCompatibale(rightKey)) {
+ throw new InvalidProgramException("The types of the key fields do not match. Left: " +
+ leftKey + " Right: " + rightKey)
+ }
+ unfinished.finish(leftKey, rightKey)
+ }
+
+ /**
* Specify the key selector function for the right side of the key based operation. This returns
* the finished operation.
*/
@@ -105,7 +141,7 @@ private[flink] class HalfUnfinishedKeyPairOperation[T, O, R](
def getKey(in: O) = fun(in)
}
val rightKey =
- new Keys.SelectorFunctionKeys[O, K](keyExtractor, unfinished.rightSet.getType, keyType)
+ new Keys.SelectorFunctionKeys[O, K](keyExtractor, unfinished.rightSet.set.getType, keyType)
if (!leftKey.areCompatibale(rightKey)) {
throw new InvalidProgramException("The types of the key fields do not match. Left: " +
leftKey + " Right: " + rightKey)
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateOperatorTest.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateOperatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateOperatorTest.scala
index ae9fe22..ea35138 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateOperatorTest.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/AggregateOperatorTest.scala
@@ -66,14 +66,54 @@ class AggregateOperatorTest {
}
@Test
+ def testFieldNamesAggregate(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val tupleDs = env.fromCollection(emptyTupleData)
+
+ // should work
+ try {
+ tupleDs.aggregate(Aggregations.SUM, "_2")
+ } catch {
+ case e: Exception => Assert.fail()
+ }
+
+ // should not work: invalid field
+ try {
+ tupleDs.aggregate(Aggregations.SUM, "foo")
+ Assert.fail()
+ } catch {
+ case iae: IllegalArgumentException =>
+ case e: Exception => Assert.fail()
+ }
+
+ val longDs = env.fromCollection(emptyLongData)
+
+ // should not work: not applied to tuple DataSet
+ try {
+ longDs.aggregate(Aggregations.MIN, "_1")
+ Assert.fail()
+ } catch {
+ case uoe: InvalidProgramException =>
+ case uoe: UnsupportedOperationException =>
+ case e: Exception => Assert.fail()
+ }
+ }
+
+ @Test
def testAggregationTypes(): Unit = {
try {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
+ // should work: multiple aggregates
tupleDs.aggregate(Aggregations.SUM, 0).aggregate(Aggregations.MIN, 4)
+
+ // should work: nested aggregates
tupleDs.aggregate(Aggregations.MIN, 2).aggregate(Aggregations.SUM, 1)
+
+ // should not work: sum on string
try {
tupleDs.aggregate(Aggregations.SUM, 2)
Assert.fail()
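
Aggregations accept the same field names; for plain Scala tuples the accessor names "_1", "_2", ... are used. A sketch:

import org.apache.flink.api.scala._
import org.apache.flink.api.java.aggregation.Aggregations

val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromCollection(Seq((1, 10.0), (2, 20.0)))

// Equivalent to ds.aggregate(Aggregations.SUM, 1).
val summed = ds.aggregate(Aggregations.SUM, "_2")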
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala
index 3608a50..2b459ab 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/CoGroupOperatorTest.scala
@@ -94,6 +94,71 @@ class CoGroupOperatorTest {
ds1.coGroup(ds2).where(5).equalTo(0)
}
+ @Test
+ def testCoGroupKeyFieldNames1(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(emptyTupleData)
+
+ // Should work
+ try {
+ ds1.coGroup(ds2).where("_1").equalTo("_1")
+ }
+ catch {
+ case e: Exception => Assert.fail()
+ }
+ }
+
+ @Test(expected = classOf[InvalidProgramException])
+ def testCoGroupKeyFieldNames2(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(emptyTupleData)
+
+ // Should not work, incompatible key types
+ ds1.coGroup(ds2).where("_1").equalTo("_3")
+ }
+
+ @Test(expected = classOf[InvalidProgramException])
+ def testCoGroupKeyFieldNames3(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(emptyTupleData)
+
+ // Should not work, incompatible number of key fields
+ ds1.coGroup(ds2).where("_1", "_2").equalTo("_3")
+ }
+
+ @Test(expected = classOf[IllegalArgumentException])
+ def testCoGroupKeyFieldNames4(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(emptyTupleData)
+
+ // Should not work, invalid field name
+ ds1.coGroup(ds2).where("_6").equalTo("_1")
+ }
+
+ @Test(expected = classOf[IllegalArgumentException])
+ def testCoGroupKeyFieldNames5(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(emptyTupleData)
+
+ // Should not work, invalid field name
+ ds1.coGroup(ds2).where("_1").equalTo("bar")
+ }
+
+ @Test(expected = classOf[UnsupportedOperationException])
+ def testCoGroupKeyFieldNames6(): Unit = {
+ val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(customTypeData)
+
+ // Should not work, field name key on custom data type
+ ds1.coGroup(ds2).where("_3").equalTo("_1")
+ }
+
@Ignore
@Test
def testCoGroupKeyExpressions1(): Unit = {
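
CoGroup gets the same treatment as join: both sides can be keyed by field name. A sketch (the Click/Payment case classes are illustrative):

import org.apache.flink.api.scala._

case class Click(userId: Int, url: String)
case class Payment(userId: Int, amount: Double)

val env = ExecutionEnvironment.getExecutionEnvironment
val clicks = env.fromCollection(Seq(Click(1, "/a"), Click(2, "/b")))
val payments = env.fromCollection(Seq(Payment(1, 9.99)))

// Both keys are resolved to indices before the plan is built.
val byUser = clicks.coGroup(payments).where("userId").equalTo("userId")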
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctOperatorTest.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctOperatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctOperatorTest.scala
index 693cd87..6a6e7b3 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctOperatorTest.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/DistinctOperatorTest.scala
@@ -30,7 +30,7 @@ class DistinctOperatorTest {
private val emptyLongData = Array[Long]()
@Test
- def testDistinctByKeyFields1(): Unit = {
+ def testDistinctByKeyIndices1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
@@ -44,7 +44,7 @@ class DistinctOperatorTest {
}
@Test(expected = classOf[InvalidProgramException])
- def testDistinctByKeyFields2(): Unit = {
+ def testDistinctByKeyIndices2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val longDs = env.fromCollection(emptyLongData)
@@ -53,7 +53,7 @@ class DistinctOperatorTest {
}
@Test(expected = classOf[InvalidProgramException])
- def testDistinctByKeyFields3(): Unit = {
+ def testDistinctByKeyIndices3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val customDs = env.fromCollection(customTypeData)
@@ -62,7 +62,7 @@ class DistinctOperatorTest {
}
@Test
- def testDistinctByKeyFields4(): Unit = {
+ def testDistinctByKeyIndices4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
@@ -71,7 +71,7 @@ class DistinctOperatorTest {
}
@Test(expected = classOf[InvalidProgramException])
- def testDistinctByKeyFields5(): Unit = {
+ def testDistinctByKeyIndices5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val customDs = env.fromCollection(customTypeData)
@@ -80,7 +80,7 @@ class DistinctOperatorTest {
}
@Test(expected = classOf[IllegalArgumentException])
- def testDistinctByKeyFields6(): Unit = {
+ def testDistinctByKeyIndices6(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
@@ -89,6 +89,47 @@ class DistinctOperatorTest {
}
@Test
+ def testDistinctByKeyFields1(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tupleDs = env.fromCollection(emptyTupleData)
+
+ // Should work
+ try {
+ tupleDs.distinct("_1")
+ }
+ catch {
+ case e: Exception => Assert.fail()
+ }
+ }
+
+ @Test(expected = classOf[UnsupportedOperationException])
+ def testDistinctByKeyFields2(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val longDs = env.fromCollection(emptyLongData)
+
+ // should not work: distinct on basic type
+ longDs.distinct("_1")
+ }
+
+ @Test(expected = classOf[UnsupportedOperationException])
+ def testDistinctByKeyFields3(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val customDs = env.fromCollection(customTypeData)
+
+ // should not work: field key on custom type
+ customDs.distinct("_1")
+ }
+
+ @Test(expected = classOf[IllegalArgumentException])
+ def testDistinctByKeyFields4(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tupleDs = env.fromCollection(emptyTupleData)
+
+ // should not work, invalid field
+ tupleDs.distinct("foo")
+ }
+
+ @Test
def testDistinctByKeySelector1(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
try {
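
distinct accepts field names as well. A sketch, assuming a hypothetical Point case class:

import org.apache.flink.api.scala._

case class Point(label: String, x: Double)

val env = ExecutionEnvironment.getExecutionEnvironment
val points = env.fromCollection(Seq(Point("a", 1.0), Point("a", 2.0)))

// One element per distinct label; equivalent to distinct(0).
val uniqueLabels = points.distinct("label")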
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala
index 841fd52..affb007 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/GroupingTest.scala
@@ -33,7 +33,7 @@ class GroupingTest {
private val emptyLongData = Array[Long]()
@Test
- def testGroupByKeyFields1(): Unit = {
+ def testGroupByKeyIndices1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
@@ -47,7 +47,7 @@ class GroupingTest {
}
@Test(expected = classOf[InvalidProgramException])
- def testGroupByKeyFields2(): Unit = {
+ def testGroupByKeyIndices2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val longDs = env.fromCollection(emptyLongData)
@@ -56,7 +56,7 @@ class GroupingTest {
}
@Test(expected = classOf[InvalidProgramException])
- def testGroupByKeyFields3(): Unit = {
+ def testGroupByKeyIndices3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val customDs = env.fromCollection(customTypeData)
@@ -65,7 +65,7 @@ class GroupingTest {
}
@Test(expected = classOf[IllegalArgumentException])
- def testGroupByKeyFields4(): Unit = {
+ def testGroupByKeyIndices4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
@@ -74,7 +74,7 @@ class GroupingTest {
}
@Test(expected = classOf[IllegalArgumentException])
- def testGroupByKeyFields5(): Unit = {
+ def testGroupByKeyIndices5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val tupleDs = env.fromCollection(emptyTupleData)
@@ -82,6 +82,47 @@ class GroupingTest {
tupleDs.groupBy(-1)
}
+ @Test
+ def testGroupByKeyFields1(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tupleDs = env.fromCollection(emptyTupleData)
+
+ // should work
+ try {
+ tupleDs.groupBy("_1")
+ }
+ catch {
+ case e: Exception => Assert.fail()
+ }
+ }
+
+ @Test(expected = classOf[UnsupportedOperationException])
+ def testGroupByKeyFields2(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val longDs = env.fromCollection(emptyLongData)
+
+ // should not work, grouping on basic type
+ longDs.groupBy("_1")
+ }
+
+ @Test(expected = classOf[UnsupportedOperationException])
+ def testGroupByKeyFields3(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val customDs = env.fromCollection(customTypeData)
+
+ // should not work, field key on custom type
+ customDs.groupBy("_1")
+ }
+
+ @Test(expected = classOf[IllegalArgumentException])
+ def testGroupByKeyFields4(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val tupleDs = env.fromCollection(emptyTupleData)
+
+ // should not work, invalid field
+ tupleDs.groupBy("foo")
+ }
+
@Ignore
@Test
def testGroupByKeyExpressions1(): Unit = {
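
Grouping by name composes with the rest of the grouped API. A word-count-style sketch (the WordCount case class is assumed):

import org.apache.flink.api.scala._

case class WordCount(word: String, count: Int)

val env = ExecutionEnvironment.getExecutionEnvironment
val words = env.fromCollection(Seq(WordCount("a", 1), WordCount("a", 2)))

// "word" resolves to index 0; the reduce runs once per group.
val counts = words
  .groupBy("word")
  .reduce((l, r) => WordCount(l.word, l.count + r.count))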
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/83debdb3/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala
index fff9857..4240b6c 100644
--- a/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala
+++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/operators/JoinOperatorTest.scala
@@ -31,10 +31,12 @@ class JoinOperatorTest {
private val emptyLongData = Array[Long]()
@Test
- def testJoinKeyFields1(): Unit = {
+ def testJoinKeyIndices1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
+
+ // should work
try {
ds1.join(ds2).where(0).equalTo(0)
}
@@ -44,45 +46,121 @@ class JoinOperatorTest {
}
@Test(expected = classOf[InvalidProgramException])
- def testJoinKeyFields2(): Unit = {
+ def testJoinKeyIndices2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
+
+ // should not work, incompatible key types
ds1.join(ds2).where(0).equalTo(2)
}
@Test(expected = classOf[InvalidProgramException])
- def testJoinKeyFields3(): Unit = {
+ def testJoinKeyIndices3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
+
+ // should not work, non-matching number of key indices
ds1.join(ds2).where(0, 1).equalTo(2)
}
@Test(expected = classOf[IllegalArgumentException])
- def testJoinKeyFields4(): Unit = {
+ def testJoinKeyIndices4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
+
+ // should not work, index out of range
ds1.join(ds2).where(5).equalTo(0)
}
@Test(expected = classOf[IllegalArgumentException])
- def testJoinKeyFields5(): Unit = {
+ def testJoinKeyIndices5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyTupleData)
+
+ // should not work, negative position
ds1.join(ds2).where(-1).equalTo(-1)
}
@Test(expected = classOf[IllegalArgumentException])
- def testJoinKeyFields6(): Unit = {
+ def testJoinKeyIndices6(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(customTypeData)
+
+ // should not work, key index on custom type
ds1.join(ds2).where(5).equalTo(0)
}
+ @Test
+ def testJoinKeyFields1(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(emptyTupleData)
+
+ // should work
+ try {
+ ds1.join(ds2).where("_1").equalTo("_1")
+ }
+ catch {
+ case e: Exception => Assert.fail()
+ }
+ }
+
+ @Test(expected = classOf[InvalidProgramException])
+ def testJoinKeyFields2(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(emptyTupleData)
+
+ // should not work, incompatible field types
+ ds1.join(ds2).where("_1").equalTo("_3")
+ }
+
+ @Test(expected = classOf[InvalidProgramException])
+ def testJoinKeyFields3(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(emptyTupleData)
+
+ // should not work, non-matching number of key fields
+ ds1.join(ds2).where("_1", "_2").equalTo("_3")
+ }
+
+ @Test(expected = classOf[IllegalArgumentException])
+ def testJoinKeyFields4(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(emptyTupleData)
+
+ // should not work, non-existent key
+ ds1.join(ds2).where("foo").equalTo("_1")
+ }
+
+ @Test(expected = classOf[IllegalArgumentException])
+ def testJoinKeyFields5(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(emptyTupleData)
+
+ // should not work, invalid field name
+ ds1.join(ds2).where("_1").equalTo("bar")
+ }
+
+ @Test(expected = classOf[UnsupportedOperationException])
+ def testJoinKeyFields6(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val ds1 = env.fromCollection(emptyTupleData)
+ val ds2 = env.fromCollection(customTypeData)
+
+ // should not work, field key on custom type
+ ds1.join(ds2).where("_2").equalTo("_1")
+ }
+
@Ignore
@Test
def testJoinKeyExpressions1(): Unit = {