You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2014/09/22 14:28:48 UTC
[06/60] Rewrite the Scala API as (somewhat) thin Layer on Java API
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala
new file mode 100644
index 0000000..05f9917
--- /dev/null
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/coGroupDataSet.scala
@@ -0,0 +1,230 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala
+
+import org.apache.commons.lang3.Validate
+import org.apache.flink.api.common.InvalidProgramException
+import org.apache.flink.api.common.functions.{RichCoGroupFunction, CoGroupFunction}
+import org.apache.flink.api.common.typeutils.TypeSerializer
+import org.apache.flink.api.java.operators._
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo
+import org.apache.flink.api.java.{DataSet => JavaDataSet}
+import org.apache.flink.api.scala.typeutils.{ScalaTupleSerializer, ScalaTupleTypeInfo}
+import org.apache.flink.types.TypeInformation
+import org.apache.flink.util.Collector
+
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
+
+
+/**
+ * A specific [[DataSet]] that results from a `coGroup` operation. The result of a default coGroup is
+ * a tuple containing two arrays of values from the two sides of the coGroup. The result of the
+ * coGroup can be changed by specifying a custom coGroup function using the `apply` method or by
+ * providing a [[RichCoGroupFunction]].
+ *
+ * Example:
+ * {{{
+ * val left = ...
+ * val right = ...
+ * val coGroupResult = left.coGroup(right).where(0, 2).isEqualTo(0, 1) {
+ * (left, right) => new MyCoGroupResult(left.min, right.max)
+ * }
+ * }}}
+ *
+ * Or, using key selector functions with tuple data types:
+ * {{{
+ * val left = ...
+ * val right = ...
+ * val coGroupResult = left.coGroup(right).where({_._1}).isEqualTo({_._1}) {
+ * (left, right) => new MyCoGroupResult(left.max, right.min)
+ * }
+ * }}}
+ *
+ * @tparam T Type of the left input of the coGroup.
+ * @tparam O Type of the right input of the coGroup.
+ */
+trait CoGroupDataSet[T, O] extends DataSet[(Array[T], Array[O])] {
+
+ /**
+ * Creates a new [[DataSet]] where the result for each pair of co-grouped element lists is the
+ * result of the given function. You can either return an element or choose to return [[None]],
+ * which allows implementing a filter directly in the coGroup function.
+ */
+ def apply[R: TypeInformation: ClassTag](
+ fun: (TraversableOnce[T], TraversableOnce[O]) => Option[R]): DataSet[R]
+
+ /**
+ * Creates a new [[DataSet]] where the result for each pair of co-grouped element lists is the
+ * result of the given function. The function can output zero or more elements using the
+ * [[Collector]] which will form the result.
+ */
+ def apply[R: TypeInformation: ClassTag](
+ fun: (TraversableOnce[T], TraversableOnce[O], Collector[R]) => Unit): DataSet[R]
+
+ /**
+ * Creates a new [[DataSet]] by passing each pair of co-grouped element lists to the given
+ * function. The function can output zero or more elements using the [[Collector]] which will form
+ * the result.
+ *
+ * A [[RichCoGroupFunction]] can be used to access the
+ * broadcast variables and the [[org.apache.flink.api.common.functions.RuntimeContext]].
+ */
+ def apply[R: TypeInformation: ClassTag](joiner: CoGroupFunction[T, O, R]): DataSet[R]
+}
+
+/**
+ * Private implementation for [[CoGroupDataSet]] to keep the implementation details, i.e. the
+ * parameters of the constructor, hidden.
+ */
+private[flink] class CoGroupDataSetImpl[T, O](
+ coGroupOperator: CoGroupOperator[T, O, (Array[T], Array[O])],
+ thisSet: JavaDataSet[T],
+ otherSet: JavaDataSet[O],
+ thisKeys: Keys[T],
+ otherKeys: Keys[O]) extends DataSet(coGroupOperator) with CoGroupDataSet[T, O] {
+
+ def apply[R: TypeInformation: ClassTag](
+ fun: (TraversableOnce[T], TraversableOnce[O]) => Option[R]): DataSet[R] = {
+ Validate.notNull(fun, "CoGroup function must not be null.")
+ val coGrouper = new CoGroupFunction[T, O, R] {
+ def coGroup(left: java.lang.Iterable[T], right: java.lang.Iterable[O], out: Collector[R]) = {
+ fun(left.iterator.asScala, right.iterator.asScala) foreach { out.collect(_) }
+ }
+ }
+ val coGroupOperator = new CoGroupOperator[T, O, R](thisSet, otherSet, thisKeys,
+ otherKeys, coGrouper, implicitly[TypeInformation[R]])
+ wrap(coGroupOperator)
+ }
+
+ def apply[R: TypeInformation: ClassTag](
+ fun: (TraversableOnce[T], TraversableOnce[O], Collector[R]) => Unit): DataSet[R] = {
+ Validate.notNull(fun, "CoGroup function must not be null.")
+ val coGrouper = new CoGroupFunction[T, O, R] {
+ def coGroup(left: java.lang.Iterable[T], right: java.lang.Iterable[O], out: Collector[R]) = {
+ fun(left.iterator.asScala, right.iterator.asScala, out)
+ }
+ }
+ val coGroupOperator = new CoGroupOperator[T, O, R](thisSet, otherSet, thisKeys,
+ otherKeys, coGrouper, implicitly[TypeInformation[R]])
+ wrap(coGroupOperator)
+ }
+
+ def apply[R: TypeInformation: ClassTag](joiner: CoGroupFunction[T, O, R]): DataSet[R] = {
+ Validate.notNull(joiner, "CoGroup function must not be null.")
+ val coGroupOperator = new CoGroupOperator[T, O, R](thisSet, otherSet, thisKeys,
+ otherKeys, joiner, implicitly[TypeInformation[R]])
+ wrap(coGroupOperator)
+ }
+}
+
+/**
+ * An unfinished coGroup operation that results from [[DataSet.coGroup()]]. The keys for the left
+ * and right side must be specified using first `where` and then `isEqualTo`. For example:
+ *
+ * {{{
+ * val left = ...
+ * val right = ...
+ * val coGroupResult = left.coGroup(right).where(...).isEqualTo(...)
+ * }}}
+ * @tparam T The type of the left input of the coGroup.
+ * @tparam O The type of the right input of the coGroup.
+ */
+trait UnfinishedCoGroupOperation[T, O]
+ extends UnfinishedKeyPairOperation[T, O, CoGroupDataSet[T, O]]
+
+/**
+ * Private implementation for [[UnfinishedCoGroupOperation]] to keep the implementation details,
+ * i.e. the parameters of the constructor, hidden.
+ */
+private[flink] class UnfinishedCoGroupOperationImpl[T: ClassTag, O: ClassTag](
+ leftSet: JavaDataSet[T],
+ rightSet: JavaDataSet[O])
+ extends UnfinishedKeyPairOperation[T, O, CoGroupDataSet[T, O]](leftSet, rightSet)
+ with UnfinishedCoGroupOperation[T, O] {
+
+ private[flink] def finish(leftKey: Keys[T], rightKey: Keys[O]) = {
+ val coGrouper = new CoGroupFunction[T, O, (Array[T], Array[O])] {
+ def coGroup(
+ left: java.lang.Iterable[T],
+ right: java.lang.Iterable[O],
+ out: Collector[(Array[T], Array[O])]) = {
+ val leftResult = Array[Any](left.asScala.toSeq: _*).asInstanceOf[Array[T]]
+ val rightResult = Array[Any](right.asScala.toSeq: _*).asInstanceOf[Array[O]]
+
+ out.collect((leftResult, rightResult))
+ }
+ }
+
+ // We have to use this hack, for some reason classOf[Array[T]] does not work.
+ // Maybe because ObjectArrayTypeInfo does not accept the Scala Array as an array class.
+ val leftArrayType = ObjectArrayTypeInfo.getInfoFor(new Array[T](0).getClass, leftSet.getType)
+ val rightArrayType = ObjectArrayTypeInfo.getInfoFor(new Array[O](0).getClass, rightSet.getType)
+
+ val returnType = new ScalaTupleTypeInfo[(Array[T], Array[O])](
+ classOf[(Array[T], Array[O])], Seq(leftArrayType, rightArrayType)) {
+
+ override def createSerializer: TypeSerializer[(Array[T], Array[O])] = {
+ val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
+ for (i <- 0 until getArity) {
+ fieldSerializers(i) = types(i).createSerializer
+ }
+
+ new ScalaTupleSerializer[(Array[T], Array[O])](
+ classOf[(Array[T], Array[O])],
+ fieldSerializers) {
+ override def createInstance(fields: Array[AnyRef]) = {
+ (fields(0).asInstanceOf[Array[T]], fields(1).asInstanceOf[Array[O]])
+ }
+ }
+ }
+ }
+ val coGroupOperator = new CoGroupOperator[T, O, (Array[T], Array[O])](
+ leftSet, rightSet, leftKey, rightKey, coGrouper, returnType)
+
+ // sanity check solution set key mismatches
+ leftSet match {
+ case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
+ leftKey match {
+ case keyFields: Keys.FieldPositionKeys[_] =>
+ val positions: Array[Int] = keyFields.computeLogicalKeyPositions
+ solutionSet.checkJoinKeyFields(positions)
+ case _ =>
+ throw new InvalidProgramException("Currently, the solution set may only be " +
+ "coGrouped " +
+ "using tuple field positions.")
+ }
+ case _ =>
+ }
+ rightSet match {
+ case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
+ rightKey match {
+ case keyFields: Keys.FieldPositionKeys[_] =>
+ val positions: Array[Int] = keyFields.computeLogicalKeyPositions
+ solutionSet.checkJoinKeyFields(positions)
+ case _ =>
+ throw new InvalidProgramException("Currently, the solution set may only be " +
+ "coGrouped " +
+ "using tuple field positions.")
+ }
+ case _ =>
+ }
+
+ new CoGroupDataSetImpl(coGroupOperator, leftSet, rightSet, leftKey, rightKey)
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Counter.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Counter.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Counter.scala
index 5a53a85..d538188 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Counter.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Counter.scala
@@ -19,7 +19,7 @@
package org.apache.flink.api.scala.codegen
-class Counter {
+private[flink] class Counter {
private var value: Int = 0
def next: Int = {
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/DeserializeMethodGen.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/DeserializeMethodGen.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/DeserializeMethodGen.scala
deleted file mode 100644
index 1362d3f..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/DeserializeMethodGen.scala
+++ /dev/null
@@ -1,261 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.codegen
-
-import scala.reflect.macros.Context
-
-trait DeserializeMethodGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with TreeGen[C] with SerializerGen[C] with Loggers[C] =>
- import c.universe._
-
- protected def mkDeserialize(desc: UDTDescriptor, listImpls: Map[Int, Type]): List[Tree] = {
-
-// val rootRecyclingOn = mkMethod("deserializeRecyclingOn", Flag.OVERRIDE | Flag.FINAL, List(("record", typeOf[org.apache.flink.pact.common.`type`.Record])), desc.tpe, {
- val rootRecyclingOn = mkMethod("deserializeRecyclingOn", Flag.FINAL, List(("record", typeOf[org.apache.flink.types.Record])), desc.tpe, {
- val env = GenEnvironment(listImpls, "flat" + desc.id, false, true, true, true)
- mkSingle(genDeserialize(desc, Ident("record"), env, Map()))
- })
-
-// val rootRecyclingOff = mkMethod("deserializeRecyclingOff", Flag.OVERRIDE | Flag.FINAL, List(("record", typeOf[org.apache.flink.pact.common.`type`.Record])), desc.tpe, {
- val rootRecyclingOff = mkMethod("deserializeRecyclingOff", Flag.FINAL, List(("record", typeOf[org.apache.flink.types.Record])), desc.tpe, {
- val env = GenEnvironment(listImpls, "flat" + desc.id, false, false, true, true)
- mkSingle(genDeserialize(desc, Ident("record"), env, Map()))
- })
-
- val aux = desc.getRecursiveRefs map { desc =>
- mkMethod("deserialize" + desc.id, Flag.PRIVATE | Flag.FINAL, List(("record", typeOf[org.apache.flink.types.Record])), desc.tpe, {
- val env = GenEnvironment(listImpls, "boxed" + desc.id, true, false, false, true)
- mkSingle(genDeserialize(desc, Ident("record"), env, Map()))
- })
- }
-
- rootRecyclingOn +: rootRecyclingOff +: aux.toList
- }
-
- private def genDeserialize(desc: UDTDescriptor, source: Tree, env: GenEnvironment, scope: Map[Int, (String, Type)]): Seq[Tree] = desc match {
-
- case PactValueDescriptor(id, tpe) => {
- val chk = env.mkChkIdx(id)
- val get = env.mkGetField(id, source, tpe)
-
- Seq(mkIf(chk, get, mkNull))
- }
-
- case PrimitiveDescriptor(id, _, default, _) => {
- val chk = env.mkChkIdx(id)
- val des = env.mkGetFieldInto(id, source)
- val get = env.mkGetValue(id)
-
- Seq(mkIf(chk, Block(List(des), get), default))
- }
-
- case BoxedPrimitiveDescriptor(id, tpe, _, _, box, _) => {
- val des = env.mkGetFieldInto(id, source)
- val chk = mkAnd(env.mkChkIdx(id), des)
- val get = box(env.mkGetValue(id))
-
- Seq(mkIf(chk, get, mkNull))
- }
-
- case list @ ListDescriptor(id, tpe, _, elem) => {
- val chk = mkAnd(env.mkChkIdx(id), env.mkNotIsNull(id, source))
-
- val (init, pactList) = env.reentrant match {
-
- // This is a bit conservative, but avoids runtime checks
- // and/or even more specialized deserialize() methods to
- // track whether it's safe to reuse the list variable.
- case true => {
- val listTpe = env.listImpls(id)
- val list = mkVal("list" + id, NoFlags, false, listTpe, New(TypeTree(listTpe), List(List())))
- (list, Ident("list" + id: TermName))
- }
-
- case false => {
- val clear = Apply(Select(env.mkSelectWrapper(id), "clear"), List())
- (clear, env.mkSelectWrapper(id))
- }
- }
-
- // val buildTpe = appliedType(builderClass.tpe, List(elem.tpe, tpe))
- // val build = mkVal(env.methodSym, "b" + id, 0, false, buildTpe) { _ => Apply(Select(cbf(), "apply"), List()) }
-// val userList = mkVal("b" + id, NoFlags, false, tpe, New(TypeTree(tpe), List(List())))
- val buildTpe = mkBuilderOf(elem.tpe, tpe)
- val cbf = c.inferImplicitValue(mkCanBuildFromOf(tpe, elem.tpe, tpe))
- val build = mkVal("b" + id, NoFlags, false, buildTpe, Apply(Select(cbf, "apply": TermName), List()))
- val des = env.mkGetFieldInto(id, source, pactList)
- val body = genDeserializeList(elem, pactList, Ident("b" + id: TermName), env.copy(allowRecycling = false, chkNull = true), scope)
- val stats = init +: des +: build +: body
-
- Seq(mkIf(chk, Block(stats.init.toList, stats.last), mkNull))
- }
-
- // we have a mutable UDT and the context allows recycling
- case CaseClassDescriptor(_, tpe, true, _, getters) if env.allowRecycling => {
-
- val fields = getters filterNot { _.isBaseField } map {
- case FieldAccessor(_, _, _, _, desc) => (desc.id, mkVal("v" + desc.id, NoFlags, false, desc.tpe, {
- mkSingle(genDeserialize(desc, source, env, scope))
- }), desc.tpe, "v" + desc.id)
- }
-
- val newScope = scope ++ (fields map { case (id, tree, tpe, name) => id -> (name, tpe) })
-
- val stats = fields map { _._2 }
-
- val setterStats = getters map {
- case FieldAccessor(_, setter, fTpe, _, fDesc) => {
- val (name, tpe) = newScope(fDesc.id)
- val castVal = maybeMkAsInstanceOf(Ident(name: TermName))(c.WeakTypeTag(tpe), c.WeakTypeTag(fTpe))
- env.mkCallSetMutableField(desc.id, setter, castVal)
- }
- }
-
- val ret = env.mkSelectMutableUdtInst(desc.id)
-
- (stats ++ setterStats) :+ ret
- }
-
- case CaseClassDescriptor(_, tpe, _, _, getters) => {
-
- val fields = getters filterNot { _.isBaseField } map {
- case FieldAccessor(_, _, _, _, desc) => (desc.id, mkVal("v" + desc.id, NoFlags, false, desc.tpe, {
- mkSingle(genDeserialize(desc, source, env, scope))
- }), desc.tpe, "v" + desc.id)
- }
-
- val newScope = scope ++ (fields map { case (id, tree, tpe, name) => id -> (name, tpe) })
-
- val stats = fields map { _._2 }
-
- val args = getters map {
- case FieldAccessor(_, _, fTpe, _, fDesc) => {
- val (name, tpe) = newScope(fDesc.id)
- maybeMkAsInstanceOf(Ident(name: TermName))(c.WeakTypeTag(tpe), c.WeakTypeTag(fTpe))
- }
- }
-
- val ret = New(TypeTree(tpe), List(args.toList))
-
- stats :+ ret
- }
-
- case BaseClassDescriptor(_, tpe, Seq(tagField, baseFields @ _*), subTypes) => {
-
- val fields = baseFields map {
- case FieldAccessor(_, _, _, _, desc) => (desc.id, mkVal("v" + desc.id, NoFlags, false, desc.tpe, {
- val special = desc match {
- case d @ PrimitiveDescriptor(id, _, _, _) if id == tagField.desc.id => d.copy(default = Literal(Constant(-1)))
- case _ => desc
- }
- mkSingle(genDeserialize(desc, source, env, scope))
- }), desc.tpe, "v" + desc.id)
- }
-
- val newScope = scope ++ (fields map { case (id, tree, tpe, name) => id -> (name, tpe) })
-
- val stats = fields map { _._2 }
-
- val cases = subTypes.zipWithIndex.toList map {
- case (dSubType, i) => {
- val code = mkSingle(genDeserialize(dSubType, source, env, newScope))
- val pat = Bind("tag": TermName, Literal(Constant(i)))
- CaseDef(pat, EmptyTree, code)
- }
- }
-
- val chk = env.mkChkIdx(tagField.desc.id)
- val des = env.mkGetFieldInto(tagField.desc.id, source)
- val get = env.mkGetValue(tagField.desc.id)
- Seq(mkIf(chk, Block(stats.toList :+ des, Match(get, cases)), mkNull))
- }
-
- case RecursiveDescriptor(id, tpe, refId) => {
- val chk = mkAnd(env.mkChkIdx(id), env.mkNotIsNull(id, source))
- val rec = mkVal("record" + id, NoFlags, false, typeOf[org.apache.flink.types.Record], New(TypeTree(typeOf[org.apache.flink.types.Record]), List(List())))
- val get = env.mkGetFieldInto(id, source, Ident("record" + id: TermName))
- val des = env.mkCallDeserialize(refId, Ident("record" + id: TermName))
-
- Seq(mkIf(chk, Block(List(rec, get), des), mkNull))
- }
-
- case _ => Seq(mkNull)
- }
-
- private def genDeserializeList(elem: UDTDescriptor, source: Tree, target: Tree, env: GenEnvironment, scope: Map[Int, (String, Type)]): Seq[Tree] = {
-
- val size = mkVal("size", NoFlags, false, definitions.IntTpe, Apply(Select(source, "size"), List()))
- val sizeHint = Apply(Select(target, "sizeHint"), List(Ident("size": TermName)))
- val i = mkVar("i", NoFlags, false, definitions.IntTpe, mkZero)
-
- val loop = mkWhile(Apply(Select(Ident("i": TermName), "$less"), List(Ident("size": TermName)))) {
-
- val item = mkVal("item", NoFlags, false, getListElemWrapperType(elem, env), Apply(Select(source, "get"), List(Ident("i": TermName))))
-
- val (stats, value) = elem match {
-
- case PrimitiveDescriptor(_, _, _, wrapper) => (Seq(), env.mkGetValue(Ident("item": TermName)))
-
- case BoxedPrimitiveDescriptor(_, _, _, wrapper, box, _) => (Seq(), box(env.mkGetValue(Ident("item": TermName))))
-
- case PactValueDescriptor(_, tpe) => (Seq(), Ident("item": TermName))
-
- case ListDescriptor(id, tpe, _, innerElem) => {
-
- // val buildTpe = appliedType(builderClass.tpe, List(innerElem.tpe, tpe))
- // val build = mkVal(env.methodSym, "b" + id, 0, false, buildTpe) { _ => Apply(Select(cbf(), "apply"), List()) }
- val buildTpe = mkBuilderOf(innerElem.tpe, tpe)
- val cbf = c.inferImplicitValue(mkCanBuildFromOf(tpe, innerElem.tpe, tpe))
- val build = mkVal("b" + id, NoFlags, false, buildTpe, Apply(Select(cbf, "apply": TermName), List()))
- val body = mkVal("v" + id, NoFlags, false, elem.tpe,
- mkSingle(genDeserializeList(innerElem, Ident("item": TermName), Ident("b" + id: TermName), env, scope)))
- (Seq(build, body), Ident("v" + id: TermName))
- }
-
- case RecursiveDescriptor(id, tpe, refId) => (Seq(), env.mkCallDeserialize(refId, Ident("item": TermName)))
-
- case _ => {
- val body = genDeserialize(elem, Ident("item": TermName), env.copy(idxPrefix = "boxed" + elem.id, chkIndex = false, chkNull = false), scope)
- val v = mkVal("v" + elem.id, NoFlags, false, elem.tpe, mkSingle(body))
- (Seq(v), Ident("v" + elem.id: TermName))
- }
- }
-
- val chk = env.mkChkNotNull(Ident("item": TermName), elem.tpe)
- val add = Apply(Select(target, "$plus$eq"), List(value))
- val addNull = Apply(Select(target, "$plus$eq"), List(mkNull))
- val inc = Assign(Ident("i": TermName), Apply(Select(Ident("i": TermName), "$plus"), List(mkOne)))
-
- Block(List(item, mkIf(chk, mkSingle(stats :+ add), addNull)), inc)
- }
-
- val get = Apply(Select(target, "result"), List())
-
- Seq(size, sizeHint, i, loop, get)
- }
-
-
- private def getListElemWrapperType(desc: UDTDescriptor, env: GenEnvironment): Type = desc match {
- case PrimitiveDescriptor(_, _, _, wrapper) => wrapper
- case BoxedPrimitiveDescriptor(_, _, _, wrapper, _, _) => wrapper
- case PactValueDescriptor(_, tpe) => tpe
- case ListDescriptor(id, _, _, _) => env.listImpls(id)
- case _ => typeOf[org.apache.flink.types.Record]
- }
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Logger.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Logger.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Logger.scala
deleted file mode 100644
index c1f98ae..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Logger.scala
+++ /dev/null
@@ -1,118 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.codegen
-
-import scala.reflect.macros.Context
-
-trait Loggers[C <: Context] { this: MacroContextHolder[C] =>
- import c.universe._
-
- abstract sealed class LogLevel extends Ordered[LogLevel] {
- protected[Loggers] val toInt: Int
- override def compare(that: LogLevel) = this.toInt.compare(that.toInt)
- }
-
- object LogLevel {
- def unapply(name: String): Option[LogLevel] = name match {
- case "error" | "Error" => Some(Error)
- case "warn" | "Warn" => Some(Warn)
- case "debug" | "Debug" => Some(Debug)
- case "inspect" | "Inspect" => Some(Inspect)
- case _ => None
- }
- case object Error extends LogLevel { override val toInt = 1 }
- case object Warn extends LogLevel { override val toInt = 2 }
- case object Debug extends LogLevel { override val toInt = 3 }
- case object Inspect extends LogLevel { override val toInt = 4 }
- }
-
- object logger { var level: LogLevel = LogLevel.Warn }
- private val counter = new Counter
-
- trait Logger {
-
- abstract sealed class Severity {
- protected val toInt: Int
- protected def reportInner(msg: String, pos: Position)
-
- protected def formatMsg(msg: String) = msg
-
- def isEnabled = this.toInt <= logger.level.toInt
-
- def report(msg: String) = {
- if (isEnabled) {
- reportInner(formatMsg(msg), c.enclosingPosition)
- }
- }
- }
-
- case object Error extends Severity {
- override val toInt = LogLevel.Error.toInt
- override def reportInner(msg: String, pos: Position) = c.error(pos, msg)
- }
-
- case object Warn extends Severity {
- override val toInt = LogLevel.Warn.toInt
- override def reportInner(msg: String, pos: Position) = c.warning(pos, msg)
- }
-
- case object Debug extends Severity {
- override val toInt = LogLevel.Debug.toInt
- override def reportInner(msg: String, pos: Position) = c.info(pos, msg, true)
- }
-
- def getMsgAndStackLine(e: Throwable) = {
- val lines = e.getStackTrace.map(_.toString)
- val relevant = lines filter { _.contains("org.apache.flink") }
- val stackLine = relevant.headOption getOrElse e.getStackTrace.toString
- e.getMessage() + " @ " + stackLine
- }
-
- def posString(pos: Position): String = pos match {
- case NoPosition => "?:?"
- case _ => pos.line + ":" + pos.column
- }
-
- def safely(default: => Tree, inspect: Boolean)(onError: Throwable => String)(block: => Tree): Tree = {
- try {
- block
- } catch {
- case e:Throwable => {
- Error.report(onError(e));
- val ret = default
- ret
- }
- }
- }
-
- def verbosely[T](obs: T => String)(block: => T): T = {
- val ret = block
- Debug.report(obs(ret))
- ret
- }
-
- def maybeVerbosely[T](guard: T => Boolean)(obs: T => String)(block: => T): T = {
- val ret = block
- if (guard(ret)) Debug.report(obs(ret))
- ret
- }
- }
-}
-
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/MacroContextHolder.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/MacroContextHolder.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/MacroContextHolder.scala
index effc27b..4ce4922 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/MacroContextHolder.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/MacroContextHolder.scala
@@ -16,22 +16,16 @@
* limitations under the License.
*/
-
package org.apache.flink.api.scala.codegen
import scala.reflect.macros.Context
-class MacroContextHolder[C <: Context](val c: C)
+private[flink] class MacroContextHolder[C <: Context](val c: C)
-object MacroContextHolder {
+private[flink] object MacroContextHolder {
def newMacroHelper[C <: Context](c: C) = new MacroContextHolder[c.type](c)
- with Loggers[c.type]
- with UDTDescriptors[c.type]
- with UDTAnalyzer[c.type]
+ with TypeDescriptors[c.type]
+ with TypeAnalyzer[c.type]
with TreeGen[c.type]
- with SerializerGen[c.type]
- with SerializeMethodGen[c.type]
- with DeserializeMethodGen[c.type]
- with UDTGen[c.type]
- with SelectionExtractor[c.type]
+ with TypeInformationGen[c.type]
}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SelectionExtractor.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SelectionExtractor.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SelectionExtractor.scala
deleted file mode 100644
index 37abd27..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SelectionExtractor.scala
+++ /dev/null
@@ -1,184 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.codegen
-
-import scala.reflect.macros.Context
-import scala.Option.option2Iterable
-
-import org.apache.flink.api.scala.analysis.FieldSelector
-
-trait SelectionExtractor[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with UDTAnalyzer[C] with Loggers[C] with TreeGen[C] =>
- import c.universe._
-
- def getSelector[T: c.WeakTypeTag, R: c.WeakTypeTag](fun: c.Expr[T => R]): Expr[List[Int]] =
- (new SelectionExtractorInstance with Logger).extract(fun)
-
- class SelectionExtractorInstance { this: Logger =>
- def extract[T: c.WeakTypeTag, R: c.WeakTypeTag](fun: c.Expr[T => R]): Expr[List[Int]] = {
- val result = getSelector(fun.tree) match {
- case Left(errs) => Left(errs.toList)
- case Right(sels) => getUDTDescriptor(weakTypeOf[T]) match {
- case UnsupportedDescriptor(id, tpe, errs) => Left(errs.toList)
- case desc: UDTDescriptor => chkSelectors(desc, sels) match {
- case Nil => Right(desc, sels map { _.tail })
- case errs => Left(errs)
- }
- }
- }
-
- result match {
- case Left(errs) => {
- errs foreach { err => c.error(c.enclosingPosition, s"Error analyzing FieldSelector ${show(fun.tree)}: " + err) }
- reify { throw new RuntimeException("Invalid key selector."); }
- }
- case Right((udtDesc, sels)) => {
- val descs: List[Option[UDTDescriptor]] = sels flatMap { sel: List[String] => udtDesc.select(sel) }
- descs foreach { desc => desc map { desc => if (!desc.canBeKey) c.error(c.enclosingPosition, "Type " + desc.tpe + " cannot be key.") } }
- val ids = descs map { _ map { _.id } }
- ids forall { _.isDefined } match {
- case false => {
- c.error(c.enclosingPosition, s"Could not determine ids of key fields: ${ids}")
- reify { throw new RuntimeException("Invalid key selector."); }
- }
- case true => {
- val generatedIds = ids map { _.get } map { id => Literal(Constant(id: Int)) }
- val generatedList = mkList(generatedIds)
- reify {
- val list = c.Expr[List[Int]](generatedList).splice
- list
- }
- }
- }
- }
- }
-
- }
-
- private def getSelector(tree: Tree): Either[List[String], List[List[String]]] = tree match {
-
- case Function(List(p), body) => getSelector(body, Map(p.symbol -> Nil)) match {
- case err @ Left(_) => err
- case Right(sels) => Right(sels map { sel => p.name.toString +: sel })
- }
-
- case _ => Left(List("expected lambda expression literal but found " + show(tree)))
- }
-
- private def getSelector(tree: Tree, roots: Map[Symbol, List[String]]): Either[List[String], List[List[String]]] = tree match {
-
- case SimpleMatch(body, bindings) => getSelector(body, roots ++ bindings)
-
- case Match(_, List(CaseDef(pat, EmptyTree, _))) => Left(List("case pattern is too complex"))
- case Match(_, List(CaseDef(_, guard, _))) => Left(List("case pattern is guarded"))
- case Match(_, _ :: _ :: _) => Left(List("match contains more than one case"))
-
- case TupleCtor(args) => {
-
- val (errs, sels) = args.map(arg => getSelector(arg, roots)).partition(_.isLeft)
-
- errs match {
- case Nil => Right(sels.map(_.right.get).flatten)
- case _ => Left(errs.map(_.left.get).flatten)
- }
- }
-
- case Apply(tpt@TypeApply(_, _), _) => Left(List("constructor call on non-tuple type " + tpt.tpe))
-
- case Ident(name) => roots.get(tree.symbol) match {
- case Some(sel) => Right(List(sel))
- case None => Left(List("unexpected identifier " + name))
- }
-
- case Select(src, member) => getSelector(src, roots) match {
- case err @ Left(_) => err
- case Right(List(sel)) => Right(List(sel :+ member.toString))
- case _ => Left(List("unsupported selection"))
- }
-
- case _ => Left(List("unsupported construct of kind " + showRaw(tree)))
-
- }
-
- private object SimpleMatch {
-
- def unapply(tree: Tree): Option[(Tree, Map[Symbol, List[String]])] = tree match {
-
- case Match(arg, List(cd @ CaseDef(CasePattern(bindings), EmptyTree, body))) => Some((body, bindings))
- case _ => None
- }
-
- private object CasePattern {
-
- def unapply(tree: Tree): Option[Map[Symbol, List[String]]] = tree match {
-
- case Apply(MethodTypeTree(params), binds) => {
- val exprs = params.zip(binds) map {
- case (p, CasePattern(inners)) => Some(inners map { case (sym, path) => (sym, p.name.toString +: path) })
- case _ => None
- }
- if (exprs.forall(_.isDefined)) {
- Some(exprs.flatten.flatten.toMap)
- }
- else
- None
- }
-
- case Ident(_) | Bind(_, Ident(_)) => Some(Map(tree.symbol -> Nil))
- case Bind(_, CasePattern(inners)) => Some(inners + (tree.symbol -> Nil))
- case _ => None
- }
- }
-
- private object MethodTypeTree {
- def unapply(tree: Tree): Option[List[Symbol]] = tree match {
- case _: TypeTree => tree.tpe match {
- case MethodType(params, _) => Some(params)
- case _ => None
- }
- case _ => None
- }
- }
- }
-
- private object TupleCtor {
-
- def unapply(tree: Tree): Option[List[Tree]] = tree match {
- case Apply(tpt@TypeApply(_, _), args) if isTupleTpe(tpt.tpe) => Some(args)
- case _ => None
- }
-
- private def isTupleTpe(tpe: Type): Boolean = definitions.TupleClass.contains(tpe.typeSymbol)
- }
- }
-
- protected def chkSelectors(udt: UDTDescriptor, sels: List[List[String]]): List[String] = {
- sels flatMap { sel => chkSelector(udt, sel.head, sel.tail) }
- }
-
- protected def chkSelector(udt: UDTDescriptor, path: String, sel: List[String]): Option[String] = (udt, sel) match {
- case (_, Nil) if udt.isPrimitiveProduct => None
- case (_, Nil) => Some(path + ": " + udt.tpe + " is not a primitive or product of primitives")
- case (_, field :: rest) => udt.select(field) match {
- case None => Some("member " + field + " is not a case accessor of " + path + ": " + udt.tpe)
- case Some(udt) => chkSelector(udt, path + "." + field, rest)
- }
- }
-
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SerializeMethodGen.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SerializeMethodGen.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SerializeMethodGen.scala
deleted file mode 100644
index 6424ea1..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SerializeMethodGen.scala
+++ /dev/null
@@ -1,226 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.codegen
-
-import scala.reflect.macros.Context
-
-trait SerializeMethodGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with TreeGen[C] with SerializerGen[C] with Loggers[C] =>
- import c.universe._
-
- protected def mkSerialize(desc: UDTDescriptor, listImpls: Map[Int, Type]): List[Tree] = {
-
-// val root = mkMethod("serialize", Flag.OVERRIDE | Flag.FINAL, List(("item", desc.tpe), ("record", typeOf[org.apache.flink.pact.common.`type`.Record])), definitions.UnitTpe, {
- val root = mkMethod("serialize", Flag.FINAL, List(("item", desc.tpe), ("record", typeOf[org.apache.flink.types.Record])), definitions.UnitTpe, {
- val env = GenEnvironment(listImpls, "flat" + desc.id, false, true, true, true)
- val stats = genSerialize(desc, Ident("item": TermName), Ident("record": TermName), env)
- Block(stats.toList, mkUnit)
- })
-
- val aux = desc.getRecursiveRefs map { desc =>
- mkMethod("serialize" + desc.id, Flag.PRIVATE | Flag.FINAL, List(("item", desc.tpe), ("record", typeOf[org.apache.flink.types.Record])), definitions.UnitTpe, {
- val env = GenEnvironment(listImpls, "boxed" + desc.id, true, false, false, true)
- val stats = genSerialize(desc, Ident("item": TermName), Ident("record": TermName), env)
- Block(stats.toList, mkUnit)
- })
- }
-
- root +: aux.toList
- }
-
- private def genSerialize(desc: UDTDescriptor, source: Tree, target: Tree, env: GenEnvironment): Seq[Tree] = desc match {
-
- case PactValueDescriptor(id, _) => {
- val chk = env.mkChkIdx(id)
- val set = env.mkSetField(id, target, source)
-
- Seq(mkIf(chk, set))
- }
-
- case PrimitiveDescriptor(id, _, _, _) => {
- val chk = env.mkChkIdx(id)
- val ser = env.mkSetValue(id, source)
- val set = env.mkSetField(id, target)
-
- Seq(mkIf(chk, Block(List(ser), set)))
- }
-
- case BoxedPrimitiveDescriptor(id, tpe, _, _, _, unbox) => {
- val chk = mkAnd(env.mkChkIdx(id), env.mkChkNotNull(source, tpe))
- val ser = env.mkSetValue(id, unbox(source))
- val set = env.mkSetField(id, target)
-
- Seq(mkIf(chk, Block(List(ser), set)))
- }
-
- case desc @ ListDescriptor(id, tpe, iter, elem) => {
- val chk = mkAnd(env.mkChkIdx(id), env.mkChkNotNull(source, tpe))
-
- val upd = desc.getInnermostElem match {
- case _: RecursiveDescriptor => Some(Apply(Select(target, "updateBinaryRepresenation"), List()))
- case _ => None
- }
-
- val (init, list) = env.reentrant match {
-
- // This is a bit conservative, but avoids runtime checks
- // and/or even more specialized serialize() methods to
- // track whether it's safe to reuse the list variable.
- case true => {
- val listTpe = env.listImpls(id)
- val list = mkVal("list" + id, NoFlags, false, listTpe, New(TypeTree(listTpe), List(List())))
- (list, Ident("list" + id: TermName))
- }
-
- case false => {
- val clear = Apply(Select(env.mkSelectWrapper(id), "clear"), List())
- (clear, env.mkSelectWrapper(id))
- }
- }
-
- val body = genSerializeList(elem, iter(source), list, env.copy(chkNull = true))
- val set = env.mkSetField(id, target, list)
- val stats = (init +: body) :+ set
-
- val updStats = upd ++ stats
- Seq(mkIf(chk, Block(updStats.init.toList, updStats.last)))
- }
-
- case CaseClassDescriptor(_, tpe, _, _, getters) => {
- val chk = env.mkChkNotNull(source, tpe)
- val stats = getters filterNot { _.isBaseField } flatMap { case FieldAccessor(sym, _, _, _, desc) => genSerialize(desc, Select(source, sym), target, env.copy(chkNull = true)) }
-
- stats match {
- case Nil => Seq()
- case _ => Seq(mkIf(chk, mkSingle(stats)))
- }
- }
-
- case BaseClassDescriptor(id, tpe, Seq(tagField, baseFields @ _*), subTypes) => {
- val chk = env.mkChkNotNull(source, tpe)
- val fields = baseFields flatMap { (f => genSerialize(f.desc, Select(source, f.getter), target, env.copy(chkNull = true))) }
- val cases = subTypes.zipWithIndex.toList map {
- case (dSubType, i) => {
-
- val pat = Bind("inst": TermName, Typed(Ident("_"), TypeTree(dSubType.tpe)))
- val cast = None
- val inst = Ident("inst": TermName)
- // val (pat, cast, inst) = {
- // val erasedTpe = mkErasedType(env.methodSym, dSubType.tpe)
- //
- // if (erasedTpe =:= dSubType.tpe) {
- //
- // val pat = Bind(newTermName("inst"), Typed(Ident("_"), TypeTree(dSubType.tpe)))
- // (pat, None, Ident(newTermName("inst")))
- //
- // } else {
- //
- // // This avoids type erasure warnings in the generated pattern match
- // val pat = Bind(newTermName("erasedInst"), Typed(Ident("_"), TypeTree(erasedTpe)))
- // val cast = mkVal("inst", NoFlags, false, dSubType.tpe, mkAsInstanceOf(Ident("erasedInst"))(c.WeakTypeTag(dSubType.tpe)))
- // val inst = Ident(cast.symbol)
- // (pat, Some(cast), inst)
- // }
- // }
-
- val tag = genSerialize(tagField.desc, c.literal(i).tree, target, env.copy(chkNull = false))
- val code = genSerialize(dSubType, inst, target, env.copy(chkNull = false))
- val body = (cast.toSeq ++ tag ++ code) :+ mkUnit
-
- CaseDef(pat, EmptyTree, Block(body.init.toList, body.last))
- }
- }
-
- Seq(mkIf(chk, Block(fields.toList,Match(source, cases))))
- }
-
- case RecursiveDescriptor(id, tpe, refId) => {
- // Important: recursive types introduce re-entrant calls to serialize()
-
- val chk = mkAnd(env.mkChkIdx(id), env.mkChkNotNull(source, tpe))
-
- // Persist the outer record prior to recursing, since the call
- // is going to reuse all the PactPrimitive wrappers that were
- // needed *before* the recursion.
- val updTgt = Apply(Select(target, "updateBinaryRepresenation"), List())
-
- val rec = mkVal("record" + id, NoFlags, false, typeOf[org.apache.flink.types.Record], New(TypeTree(typeOf[org.apache.flink.types.Record]), List(List())))
- val ser = env.mkCallSerialize(refId, source, Ident("record" + id: TermName))
-
- // Persist the new inner record after recursing, since the
- // current call is going to reuse all the PactPrimitive
- // wrappers that are needed *after* the recursion.
- val updRec = Apply(Select(Ident("record" + id: TermName), "updateBinaryRepresenation"), List())
-
- val set = env.mkSetField(id, target, Ident("record" + id: TermName))
-
- Seq(mkIf(chk, Block(List(updTgt, rec, ser, updRec), set)))
- }
- }
-
- private def genSerializeList(elem: UDTDescriptor, iter: Tree, target: Tree, env: GenEnvironment): Seq[Tree] = {
-
- val it = mkVal("it", NoFlags, false, mkIteratorOf(elem.tpe), iter)
-
- val loop = mkWhile(Select(Ident("it": TermName), "hasNext")) {
-
- val item = mkVal("item", NoFlags, false, elem.tpe, Select(Ident("it": TermName), "next"))
-
- val (stats, value) = elem match {
-
- case PrimitiveDescriptor(_, _, _, wrapper) => (Seq(), New(TypeTree(wrapper), List(List(Ident("item": TermName)))))
-
- case BoxedPrimitiveDescriptor(_, _, _, wrapper, _, unbox) => (Seq(), New(TypeTree(wrapper), List(List(unbox(Ident("item": TermName))))))
-
- case PactValueDescriptor(_, tpe) => (Seq(), Ident("item": TermName))
-
- case ListDescriptor(id, _, iter, innerElem) => {
- val listTpe = env.listImpls(id)
- val list = mkVal("list" + id, NoFlags, false, listTpe, New(TypeTree(listTpe), List(List())))
- val body = genSerializeList(innerElem, iter(Ident("item": TermName)), Ident("list" + id: TermName), env)
- (list +: body, Ident("list" + id: TermName))
- }
-
- case RecursiveDescriptor(id, tpe, refId) => {
- val rec = mkVal("record" + id, NoFlags, false, typeOf[org.apache.flink.types.Record], New(TypeTree(typeOf[org.apache.flink.types.Record]), List(List())))
- val ser = env.mkCallSerialize(refId, Ident("item": TermName), Ident("record" + id: TermName))
- val updRec = Apply(Select(Ident("record" + id: TermName), "updateBinaryRepresenation"), List())
-
- (Seq(rec, ser, updRec), Ident("record" + id: TermName))
- }
-
- case _ => {
- val rec = mkVal("record", NoFlags, false, typeOf[org.apache.flink.types.Record], New(TypeTree(typeOf[org.apache.flink.types.Record]), List(List())))
- val ser = genSerialize(elem, Ident("item": TermName), Ident("record": TermName), env.copy(idxPrefix = "boxed" + elem.id, chkIndex = false, chkNull = false))
- val upd = Apply(Select(Ident("record": TermName), "updateBinaryRepresenation"), List())
- ((rec +: ser) :+ upd, Ident("record": TermName))
- }
- }
-
- val chk = env.mkChkNotNull(Ident("item": TermName), elem.tpe)
- val add = Apply(Select(target, "add"), List(value))
- val addNull = Apply(Select(target, "add"), List(mkNull))
-
- Block(List(item), mkIf(chk, mkSingle(stats :+ add), addNull))
- }
-
- Seq(it, loop)
- }
-}
-
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SerializerGen.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SerializerGen.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SerializerGen.scala
deleted file mode 100644
index 9e3a31d..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/SerializerGen.scala
+++ /dev/null
@@ -1,328 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.codegen
-
-import scala.language.postfixOps
-import scala.reflect.macros.Context
-
-import org.apache.flink.api.scala.analysis.UDTSerializer
-
-import org.apache.flink.types.Record
-
-
-trait SerializerGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with UDTAnalyzer[C] with TreeGen[C] with SerializeMethodGen[C] with DeserializeMethodGen[C] with Loggers[C] =>
- import c.universe._
-
- def mkUdtSerializerClass[T: c.WeakTypeTag](name: String = "", creatorName: String = "createSerializer"): (ClassDef, Tree) = {
- val desc = getUDTDescriptor(weakTypeOf[T])
-
- desc match {
- case UnsupportedDescriptor(_, _, errs) => {
- val errorString = errs.mkString("\n")
- c.abort(c.enclosingPosition, s"Error analyzing UDT ${weakTypeOf[T]}: $errorString")
- }
- case _ =>
- }
-
- val serName = newTypeName("UDTSerializerImpl" + name)
- val ser = mkClass(serName, Flag.FINAL, List(weakTypeOf[UDTSerializer[T]]), {
-
- val (listImpls, listImplTypes) = mkListImplClasses(desc)
-
- val indexMapIter = Select(Ident("indexMap": TermName), "iterator": TermName)
- val (fields1, inits1) = mkIndexes(desc.id, getIndexFields(desc).toList, false, indexMapIter)
- val (fields2, inits2) = mkBoxedIndexes(desc)
-
- val fields = fields1 ++ fields2
- val init = inits1 ++ inits2 match {
- case Nil => Nil
- case inits => List(mkMethod("init", Flag.OVERRIDE | Flag.FINAL, List(), definitions.UnitTpe, Block(inits, mkUnit)))
- }
-
- val (wrapperFields, wrappers) = mkPactWrappers(desc, listImplTypes)
-
- val mutableUdts = desc.flatten.toList flatMap {
- case cc @ CaseClassDescriptor(_, _, true, _, _) => Some(cc)
- case _ => None
- } distinct
-
- val mutableUdtInsts = mutableUdts map { u => mkMutableUdtInst(u) }
-
- val helpers = listImpls ++ fields ++ wrapperFields ++ mutableUdtInsts ++ init
- val ctor = mkMethod(nme.CONSTRUCTOR.toString(), NoFlags, List(("indexMap", typeOf[Array[Int]])), NoType, {
- Block(List(mkSuperCall(List(Ident(newTermName("indexMap"))))), mkUnit)
- })
-
-// val methods = List(ctor)// ++ List(mkGetFieldIndex(desc)) //++ mkSerialize(desc, listImplTypes) ++ mkDeserialize(desc, listImplTypes)
- val methods = List(ctor) ++ mkSerialize(desc, listImplTypes) ++ mkDeserialize(desc, listImplTypes)
-
- helpers ++ methods
- })
-
-
- val (_, serTpe) = typeCheck(ser)
-
- val createSerializer = mkMethod(creatorName, Flag.OVERRIDE, List(("indexMap", typeOf[Array[Int]])), NoType, {
- Block(List(), mkCtorCall(serTpe, List(Ident(newTermName("indexMap")))))
- })
- (ser, createSerializer)
- }
-
- private def mkListImplClass[T <: org.apache.flink.types.Value: c.WeakTypeTag]: (Tree, Type) = {
- val listImplName = c.fresh[TypeName]("PactListImpl")
- val tpe = weakTypeOf[org.apache.flink.types.ListValue[T]]
-
- val listDef = mkClass(listImplName, Flag.FINAL, List(tpe), {
- List(mkMethod(nme.CONSTRUCTOR.toString(), NoFlags, List(), NoType, Block(List(mkSuperCall()), mkUnit)))
- })
-
- typeCheck(listDef)
- }
-
- def mkListImplClasses(desc: UDTDescriptor): (List[Tree], Map[Int, Type]) = {
- desc match {
- case ListDescriptor(id, _, _, elem: ListDescriptor) => {
- val (defs, tpes) = mkListImplClasses(elem)
- val (listDef, listTpe) = mkListImplClass(c.WeakTypeTag(tpes(elem.id)))
- (defs :+ listDef, tpes + (id -> listTpe))
- }
- case ListDescriptor(id, _, _, elem: PrimitiveDescriptor) => {
- val (classDef, tpe) = mkListImplClass(c.WeakTypeTag(elem.wrapper))
- (List(classDef), Map(id -> tpe))
- }
- case ListDescriptor(id, _, _, elem: BoxedPrimitiveDescriptor) => {
- val (classDef, tpe) = mkListImplClass(c.WeakTypeTag(elem.wrapper))
- (List(classDef), Map(id -> tpe))
- }
- case ListDescriptor(id, _, _, elem: PactValueDescriptor) => {
- val (classDef, tpe) = mkListImplClass(c.WeakTypeTag(elem.tpe))
- (List(classDef), Map(id -> tpe))
- }
- case ListDescriptor(id, _, _, elem) => {
- val (classDefs, tpes) = mkListImplClasses(elem)
- val (classDef, tpe) = mkListImplClass(c.WeakTypeTag(typeOf[org.apache.flink.types.Record]))
- (classDefs :+ classDef, tpes + (id -> tpe))
- }
- case BaseClassDescriptor(_, _, getters, subTypes) => {
- val (defs, tpes) = getters.foldLeft((List[Tree](), Map[Int, Type]())) { (result, f) =>
- val (defs, tpes) = result
- val (newDefs, newTpes) = mkListImplClasses(f.desc)
- (defs ++ newDefs, tpes ++ newTpes)
- }
- val (subDefs, subTpes) = subTypes.foldLeft((List[Tree](), Map[Int, Type]())) { (result, s) =>
- val (defs, tpes) = result
- val (innerDefs, innerTpes) = mkListImplClasses(s)
- (defs ++ innerDefs, tpes ++ innerTpes)
- }
- (defs ++ subDefs, tpes ++ subTpes)
- }
- case CaseClassDescriptor(_, _, _, _, getters) => {
- getters.foldLeft((List[Tree](), Map[Int, Type]())) { (result, f) =>
- val (defs, tpes) = result
- val (newDefs, newTpes) = mkListImplClasses(f.desc)
- (defs ++ newDefs, tpes ++ newTpes)
- }
- }
- case _ => {
- (List[Tree](), Map[Int, Type]())
- }
- }
- }
-
- private def mkIndexes(descId: Int, descFields: List[UDTDescriptor], boxed: Boolean, indexMapIter: Tree): (List[Tree], List[Tree]) = {
-
- val prefix = (if (boxed) "boxed" else "flat") + descId
- val iterName = prefix + "Iter"
- val iter = mkVal(iterName, Flag.PRIVATE, true, mkIteratorOf(definitions.IntTpe), indexMapIter)
-
- val fieldsAndInits = descFields map {
- case d => {
- val next = Apply(Select(Ident(iterName: TermName), "next": TermName), Nil)
- val idxField = mkVal(prefix + "Idx" + d.id, Flag.PRIVATE, false, definitions.IntTpe, next)
-
- (List(idxField), Nil)
- }
- }
-
- val (fields, inits) = fieldsAndInits.unzip
- (iter +: fields.flatten, inits.flatten)
- }
-
- protected def getIndexFields(desc: UDTDescriptor): Seq[UDTDescriptor] = desc match {
- // Flatten product types
- case CaseClassDescriptor(_, _, _, _, getters) => getters filterNot { _.isBaseField } flatMap { f => getIndexFields(f.desc) }
- // TODO: Rather than laying out subclass fields sequentially, just reserve enough fields for the largest subclass.
- // This is tricky because subclasses can contain opaque descriptors, so we don't know how many fields we need until runtime.
- case BaseClassDescriptor(id, _, getters, subTypes) => (getters flatMap { f => getIndexFields(f.desc) }) ++ (subTypes flatMap getIndexFields)
- case _ => Seq(desc)
- }
-
- private def mkBoxedIndexes(desc: UDTDescriptor): (List[Tree], List[Tree]) = {
-
- def getBoxedDescriptors(d: UDTDescriptor): Seq[UDTDescriptor] = d match {
- case ListDescriptor(_, _, _, elem: BaseClassDescriptor) => elem +: getBoxedDescriptors(elem)
- case ListDescriptor(_, _, _, elem: CaseClassDescriptor) => elem +: getBoxedDescriptors(elem)
- case ListDescriptor(_, _, _, elem) => getBoxedDescriptors(elem)
- case CaseClassDescriptor(_, _, _, _, getters) => getters filterNot { _.isBaseField } flatMap { f => getBoxedDescriptors(f.desc) }
- case BaseClassDescriptor(id, _, getters, subTypes) => (getters flatMap { f => getBoxedDescriptors(f.desc) }) ++ (subTypes flatMap getBoxedDescriptors)
- case RecursiveDescriptor(_, _, refId) => desc.findById(refId).map(_.mkRoot).toSeq
- case _ => Seq()
- }
-
- val fieldsAndInits = getBoxedDescriptors(desc).distinct.toList flatMap { d =>
- // the way this is done here is a relic from the support of OpaqueDescriptors
- // there it was not just mkOne but actual differing numbers of fields
- // retrieved from the opaque UDT descriptors
- getIndexFields(d).toList match {
- case Nil => None
- case fields => {
- val widths = fields map {
- case _ => mkOne
- }
- val sum = widths.reduce { (s, i) => Apply(Select(s, "$plus": TermName), List(i)) }
- val range = Apply(Select(Ident("scala": TermName), "Range": TermName), List(mkZero, sum))
- Some(mkIndexes(d.id, fields, true, Select(range, "iterator": TermName)))
- }
- }
- }
-
- val (fields, inits) = fieldsAndInits.unzip
- (fields.flatten, inits.flatten)
- }
-
- private def mkPactWrappers(desc: UDTDescriptor, listImpls: Map[Int, Type]): (List[Tree], List[(Int, Type)]) = {
-
- def getFieldTypes(desc: UDTDescriptor): Seq[(Int, Type)] = desc match {
- case PrimitiveDescriptor(id, _, _, wrapper) => Seq((id, wrapper))
- case BoxedPrimitiveDescriptor(id, _, _, wrapper, _, _) => Seq((id, wrapper))
- case d @ ListDescriptor(id, _, _, elem) => {
- val listField = (id, listImpls(id))
- val elemFields = d.getInnermostElem match {
- case elem: CaseClassDescriptor => getFieldTypes(elem)
- case elem: BaseClassDescriptor => getFieldTypes(elem)
- case _ => Seq()
- }
- listField +: elemFields
- }
- case CaseClassDescriptor(_, _, _, _, getters) => getters filterNot { _.isBaseField } flatMap { f => getFieldTypes(f.desc) }
- case BaseClassDescriptor(_, _, getters, subTypes) => (getters flatMap { f => getFieldTypes(f.desc) }) ++ (subTypes flatMap getFieldTypes)
- case _ => Seq()
- }
-
- getFieldTypes(desc) toList match {
- case Nil => (Nil, Nil)
- case types =>
- val fields = types map { case (id, tpe) => mkVar("w" + id, Flag.PRIVATE, true, tpe, mkCtorCall(tpe, List())) }
- (fields, types)
- }
- }
-
- private def mkMutableUdtInst(desc: CaseClassDescriptor): Tree = {
- val args = desc.getters map {
- case FieldAccessor(_, _, fTpe, _, _) => {
- mkDefault(fTpe)
- }
- }
-
- val ctor = mkCtorCall(desc.tpe, args.toList)
- mkVar("mutableUdtInst" + desc.id, Flag.PRIVATE, true, desc.tpe, ctor)
- }
-
- private def mkGetFieldIndex(desc: UDTDescriptor): Tree = {
-
- val env = GenEnvironment(Map(), "flat" + desc.id, false, true, true, true)
-
- def mkCases(desc: UDTDescriptor, path: Seq[String]): Seq[(Seq[String], Tree)] = desc match {
-
- case PrimitiveDescriptor(id, _, _, _) => Seq((path, mkList(List(env.mkSelectIdx(id)))))
- case BoxedPrimitiveDescriptor(id, _, _, _, _, _) => Seq((path, mkList(List(env.mkSelectIdx(id)))))
-
- case BaseClassDescriptor(_, _, Seq(tag, getters @ _*), _) => {
- val tagCase = Seq((path :+ "getClass", mkList(List(env.mkSelectIdx(tag.desc.id)))))
- val fieldCases = getters flatMap { f => mkCases(f.desc, path :+ f.getter.name.toString) }
- tagCase ++ fieldCases
- }
-
- case CaseClassDescriptor(_, _, _, _, getters) => {
- def fieldCases = getters flatMap { f => mkCases(f.desc, path :+ f.getter.name.toString) }
- val allFieldsCase = desc match {
- case _ if desc.isPrimitiveProduct => {
- val nonRest = fieldCases filter { case (p, _) => p.size == path.size + 1 } map { _._2 }
- Seq((path, nonRest.reduceLeft((z, f) => Apply(Select(z, "$plus$plus"), List(f)))))
- }
- case _ => Seq()
- }
- allFieldsCase ++ fieldCases
- }
- case _ => Seq()
- }
-
- def mkPat(path: Seq[String]): Tree = {
-
- val seqUnapply = TypeApply(mkSelect("scala", "collection", "Seq", "unapplySeq"), List(TypeTree(typeOf[String])))
- val fun = Apply(seqUnapply, List(Ident(nme.WILDCARD)))
-
- val args = path map {
- case null => Bind(newTermName("rest"), Star(Ident(newTermName("_"))))
- case s => Literal(Constant(s))
- }
-
- UnApply(fun, args.toList)
- }
-
- mkMethod("getFieldIndex", Flag.FINAL, List(("selection", mkSeqOf(typeOf[String]))), mkListOf(typeOf[Int]), {
-// mkMethod("getFieldIndex", Flag.OVERRIDE | Flag.FINAL, List(("selection", mkSeqOf(typeOf[String]))), mkListOf(typeOf[Int]), {
- val cases = mkCases(desc, Seq()) map { case (path, idxs) => CaseDef(mkPat(path), EmptyTree, idxs) }
-// val errCase = CaseDef(Ident("_"), EmptyTree, Apply(Ident(newTermName("println")), List(Ident("selection"))))
- val errCase = CaseDef(Ident("_"), EmptyTree, (reify {throw new RuntimeException("Invalid selection")}).tree )
-// Match(Ident("selection"), (cases :+ errCase).toList)
- Match(Ident("selection"), List(errCase))
- })
- }
-
- protected case class GenEnvironment(listImpls: Map[Int, Type], idxPrefix: String, reentrant: Boolean, allowRecycling: Boolean, chkIndex: Boolean, chkNull: Boolean) {
- private def isNullable(tpe: Type) = typeOf[Null] <:< tpe && tpe <:< typeOf[AnyRef]
-
- def mkChkNotNull(source: Tree, tpe: Type): Tree = if (isNullable(tpe) && chkNull) Apply(Select(source, "$bang$eq": TermName), List(mkNull)) else EmptyTree
- def mkChkIdx(fieldId: Int): Tree = if (chkIndex) Apply(Select(mkSelectIdx(fieldId), "$greater$eq": TermName), List(mkZero)) else EmptyTree
-
- def mkSelectIdx(fieldId: Int): Tree = Ident(newTermName(idxPrefix + "Idx" + fieldId))
- def mkSelectSerializer(fieldId: Int): Tree = Ident(newTermName(idxPrefix + "Ser" + fieldId))
- def mkSelectWrapper(fieldId: Int): Tree = Ident(newTermName("w" + fieldId))
- def mkSelectMutableUdtInst(udtId: Int): Tree = Ident(newTermName("mutableUdtInst" + udtId))
-
- def mkCallSetMutableField(udtId: Int, setter: Symbol, source: Tree): Tree = Apply(Select(mkSelectMutableUdtInst(udtId), setter), List(source))
- def mkCallSerialize(refId: Int, source: Tree, target: Tree): Tree = Apply(Ident(newTermName("serialize" + refId)), List(source, target))
- def mkCallDeserialize(refId: Int, source: Tree): Tree = Apply(Ident(newTermName("deserialize" + refId)), List(source))
-
- def mkSetField(fieldId: Int, record: Tree): Tree = mkSetField(fieldId, record, mkSelectWrapper(fieldId))
- def mkSetField(fieldId: Int, record: Tree, wrapper: Tree): Tree = Apply(Select(record, "setField": TermName), List(mkSelectIdx(fieldId), wrapper))
- def mkGetField(fieldId: Int, record: Tree, tpe: Type): Tree = Apply(Select(record, "getField": TermName), List(mkSelectIdx(fieldId), Literal(Constant(tpe))))
- def mkGetFieldInto(fieldId: Int, record: Tree): Tree = mkGetFieldInto(fieldId, record, mkSelectWrapper(fieldId))
- def mkGetFieldInto(fieldId: Int, record: Tree, wrapper: Tree): Tree = Apply(Select(record, "getFieldInto": TermName), List(mkSelectIdx(fieldId), wrapper))
-
- def mkSetValue(fieldId: Int, value: Tree): Tree = mkSetValue(mkSelectWrapper(fieldId), value)
- def mkSetValue(wrapper: Tree, value: Tree): Tree = Apply(Select(wrapper, "setValue": TermName), List(value))
- def mkGetValue(fieldId: Int): Tree = mkGetValue(mkSelectWrapper(fieldId))
- def mkGetValue(wrapper: Tree): Tree = Apply(Select(wrapper, "getValue": TermName), List())
-
- def mkNotIsNull(fieldId: Int, record: Tree): Tree = Select(Apply(Select(record, "isNull": TermName), List(mkSelectIdx(fieldId))), "unary_$bang": TermName)
- }
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TreeGen.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TreeGen.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TreeGen.scala
index 29bf6ed..89454d5 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TreeGen.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TreeGen.scala
@@ -23,11 +23,10 @@ import scala.language.implicitConversions
import scala.reflect.macros.Context
-trait TreeGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with Loggers[C] =>
+private[flink] trait TreeGen[C <: Context] { this: MacroContextHolder[C] with TypeDescriptors[C] =>
import c.universe._
def mkDefault(tpe: Type): Tree = {
- import definitions._
tpe match {
case definitions.BooleanTpe => Literal(Constant(false))
case definitions.ByteTpe => Literal(Constant(0: Byte))
@@ -47,7 +46,8 @@ trait TreeGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C]
def mkZero = reify( 0 ).tree
def mkOne = reify( 1 ).tree
- def mkAsInstanceOf[T: c.WeakTypeTag](source: Tree): Tree = reify(c.Expr(source).splice.asInstanceOf[T]).tree
+ def mkAsInstanceOf[T: c.WeakTypeTag](source: Tree): Tree =
+ reify(c.Expr(source).splice.asInstanceOf[T]).tree
def maybeMkAsInstanceOf[S: c.WeakTypeTag, T: c.WeakTypeTag](source: Tree): Tree = {
if (weakTypeOf[S] <:< weakTypeOf[T])
@@ -57,14 +57,25 @@ trait TreeGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C]
}
// def mkIdent(target: Symbol): Tree = Ident(target) setType target.tpe
- def mkSelect(rootModule: String, path: String*): Tree = mkSelect(Ident(newTermName(rootModule)), path: _*)
- def mkSelect(source: Tree, path: String*): Tree = path.foldLeft(source) { (ret, item) => Select(ret, newTermName(item)) }
- def mkSelectSyms(source: Tree, path: Symbol*): Tree = path.foldLeft(source) { (ret, item) => Select(ret, item) }
+ def mkSelect(rootModule: String, path: String*): Tree =
+ mkSelect(Ident(newTermName(rootModule)), path: _*)
+
+ def mkSelect(source: Tree, path: String*): Tree =
+ path.foldLeft(source) { (ret, item) => Select(ret, newTermName(item)) }
+
+ def mkSelectSyms(source: Tree, path: Symbol*): Tree =
+ path.foldLeft(source) { (ret, item) => Select(ret, item) }
def mkCall(root: Tree, path: String*)(args: List[Tree]) = Apply(mkSelect(root, path: _*), args)
- def mkSeq(items: List[Tree]): Tree = Apply(mkSelect("scala", "collection", "Seq", "apply"), items)
- def mkList(items: List[Tree]): Tree = Apply(mkSelect("scala", "collection", "immutable", "List", "apply"), items)
+ def mkSeq(items: List[Tree]): Tree =
+ Apply(mkSelect("scala", "collection", "Seq", "apply"), items)
+
+ def mkList(items: List[Tree]): Tree =
+ Apply(mkSelect("scala", "collection", "immutable", "List", "apply"), items)
+
+ def mkMap(items: List[Tree]): Tree =
+ Apply(mkSelect("scala", "collection", "immutable", "Map", "apply"), items)
def mkVal(name: String, flags: FlagSet, transient: Boolean, valTpe: Type, value: Tree): Tree = {
ValDef(Modifiers(flags), newTermName(name), TypeTree(valTpe), value)
@@ -81,7 +92,11 @@ trait TreeGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C]
List(valDef, defDef)
}
- def mkVarAndLazyGetter(name: String, flags: FlagSet, valTpe: Type, value: Tree): (Tree, Tree) = {
+ def mkVarAndLazyGetter(
+ name: String,
+ flags: FlagSet,
+ valTpe: Type,
+ value: Tree): (Tree, Tree) = {
val fieldName = name + " "
val field = mkVar(fieldName, NoFlags, false, valTpe, mkNull)
val fieldSel = Ident(newTermName(fieldName))
@@ -117,20 +132,35 @@ trait TreeGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C]
}
}
- def mkMethod(name: String, flags: FlagSet, args: List[(String, Type)], ret: Type, impl: Tree): Tree = {
- val valParams = args map { case (name, tpe) => ValDef(Modifiers(Flag.PARAM), newTermName(name), TypeTree(tpe), EmptyTree) }
+ def mkMethod(
+ name: String,
+ flags: FlagSet,
+ args: List[(String, Type)],
+ ret: Type,
+ impl: Tree): Tree = {
+ val valParams = args map { case (name, tpe) =>
+ ValDef(Modifiers(Flag.PARAM), newTermName(name), TypeTree(tpe), EmptyTree)
+ }
DefDef(Modifiers(flags), newTermName(name), Nil, List(valParams), TypeTree(ret), impl)
}
- def mkClass(name: TypeName, flags: FlagSet, parents: List[Type], members: List[Tree]): ClassDef = {
+ def mkClass(
+ name: TypeName,
+ flags: FlagSet,
+ parents: List[Type],
+ members: List[Tree]): ClassDef = {
val parentTypeTrees = parents map { TypeTree(_) }
val selfType = ValDef(Modifiers(), nme.WILDCARD, TypeTree(NoType), EmptyTree)
ClassDef(Modifiers(flags), name, Nil, Template(parentTypeTrees, selfType, members))
}
- def mkThrow(tpe: Type, msg: Tree): Tree = Throw(Apply(Select(New(TypeTree(tpe)), nme.CONSTRUCTOR), List(msg)))
+ def mkThrow(tpe: Type, msg: Tree): Tree =
+ Throw(Apply(Select(New(TypeTree(tpe)), nme.CONSTRUCTOR), List(msg)))
+
// def mkThrow(tpe: Type, msg: Tree): Tree = Throw(New(TypeTree(tpe)), List(List(msg))))
+
def mkThrow(tpe: Type, msg: String): Tree = mkThrow(tpe, c.literal(msg).tree)
+
def mkThrow(msg: String): Tree = mkThrow(typeOf[java.lang.RuntimeException], msg)
implicit def tree2Ops[T <: Tree](tree: T) = new {
@@ -165,17 +195,24 @@ trait TreeGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C]
}
def mkBuilderOf(elemTpe: Type, listTpe: Type) = {
- def makeIt[ElemTpe: c.WeakTypeTag, ListTpe: c.WeakTypeTag] = weakTypeOf[scala.collection.mutable.Builder[ElemTpe, ListTpe]]
+ def makeIt[ElemTpe: c.WeakTypeTag, ListTpe: c.WeakTypeTag] =
+ weakTypeOf[scala.collection.mutable.Builder[ElemTpe, ListTpe]]
+
makeIt(c.WeakTypeTag(elemTpe), c.WeakTypeTag(listTpe))
}
def mkCanBuildFromOf(fromTpe: Type, elemTpe: Type, toTpe: Type) = {
- def makeIt[From: c.WeakTypeTag, Elem: c.WeakTypeTag, To: c.WeakTypeTag] = weakTypeOf[scala.collection.generic.CanBuildFrom[From, Elem, To]]
+ def makeIt[From: c.WeakTypeTag, Elem: c.WeakTypeTag, To: c.WeakTypeTag] =
+ weakTypeOf[scala.collection.generic.CanBuildFrom[From, Elem, To]]
+
makeIt(c.WeakTypeTag(fromTpe), c.WeakTypeTag(elemTpe), c.WeakTypeTag(toTpe))
}
- def mkCtorCall(tpe: Type, args: List[Tree]) = Apply(Select(New(TypeTree(tpe)), nme.CONSTRUCTOR), args)
- def mkSuperCall(args: List[Tree] = List()) = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), args)
+ def mkCtorCall(tpe: Type, args: List[Tree]) =
+ Apply(Select(New(TypeTree(tpe)), nme.CONSTRUCTOR), args)
+
+ def mkSuperCall(args: List[Tree] = List()) =
+ Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), args)
def mkWhile(cond: Tree)(body: Tree): Tree = {
val lblName = c.fresh[TermName]("while")
@@ -202,7 +239,8 @@ trait TreeGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C]
def extractOneInputUdf(fun: Tree) = {
val (paramName, udfBody) = fun match {
case Function(List(param), body) => (param.name.toString, body)
- case _ => c.abort(c.enclosingPosition, "Could not extract user defined function, got: " + show(fun))
+ case _ =>
+ c.abort(c.enclosingPosition, "Could not extract user defined function, got: " + show(fun))
}
val uncheckedUdfBody = c.resetAllAttrs(udfBody)
(paramName, uncheckedUdfBody)
@@ -210,8 +248,10 @@ trait TreeGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C]
def extractTwoInputUdf(fun: Tree) = {
val (param1Name, param2Name, udfBody) = fun match {
- case Function(List(param1, param2), body) => (param1.name.toString, param2.name.toString, body)
- case _ => c.abort(c.enclosingPosition, "Could not extract user defined function, got: " + show(fun))
+ case Function(List(param1, param2), body) =>
+ (param1.name.toString, param2.name.toString, body)
+ case _ =>
+ c.abort(c.enclosingPosition, "Could not extract user defined function, got: " + show(fun))
}
val uncheckedUdfBody = c.resetAllAttrs(udfBody)
(param1Name, param2Name, uncheckedUdfBody)
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala
new file mode 100644
index 0000000..7b1675d
--- /dev/null
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeAnalyzer.scala
@@ -0,0 +1,382 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.api.scala.codegen
+
+import scala.Option.option2Iterable
+import scala.collection.GenTraversableOnce
+import scala.collection.mutable
+import scala.reflect.macros.Context
+import scala.util.DynamicVariable
+
+import org.apache.flink.types.BooleanValue
+import org.apache.flink.types.ByteValue
+import org.apache.flink.types.CharValue
+import org.apache.flink.types.DoubleValue
+import org.apache.flink.types.FloatValue
+import org.apache.flink.types.IntValue
+import org.apache.flink.types.StringValue
+import org.apache.flink.types.LongValue
+import org.apache.flink.types.ShortValue
+
+
+private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
+ with TypeDescriptors[C] =>
+
+ import c.universe._
+
+ // This value is controlled by the udtRecycling compiler option
+ var enableMutableUDTs = false
+
+ private val mutableTypes = mutable.Set[Type]()
+
+ /** Analyzes `tpe` with a newly created analyzer instance and returns its descriptor. */
+ def getUDTDescriptor(tpe: Type): UDTDescriptor = {
+   val analyzer = new UDTAnalyzerInstance()
+   analyzer.analyze(tpe)
+ }
+
+ /**
+  * Returns the type arguments of `tpe`.
+  *
+  * Only a `TypeRef` carries arguments that can be extracted here; every other kind of type
+  * yields an empty list so that callers do not abort the macro with a `MatchError`.
+  */
+ private def typeArgs(tpe: Type): List[Type] = tpe match {
+   case TypeRef(_, _, args) => args
+   case _ => Nil
+ }
+
+ private class UDTAnalyzerInstance {
+
+ private val cache = new UDTAnalyzerCache()
+
+ /**
+  * Produces the descriptor for `tpe`, going through the analyzer cache so that a type which
+  * is already being analyzed is answered by the cache instead of being analyzed again.
+  */
+ def analyze(tpe: Type): UDTDescriptor = {
+
+   // The extractors are tried top to bottom; the first one that matches decides the
+   // descriptor kind, and anything unmatched becomes a generic class descriptor.
+   def describe(id: Int): UDTDescriptor = tpe match {
+     case PrimitiveType(zero, valueTpe) =>
+       PrimitiveDescriptor(id, tpe, zero, valueTpe)
+     case BoxedPrimitiveType(zero, valueTpe, boxFn, unboxFn) =>
+       BoxedPrimitiveDescriptor(id, tpe, zero, valueTpe, boxFn, unboxFn)
+     case ListType(elementTpe, toIterator) =>
+       analyzeList(id, tpe, elementTpe, toIterator)
+     case CaseClassType() =>
+       analyzeCaseClass(id, tpe)
+     case BaseClassType() =>
+       analyzeClassHierarchy(id, tpe)
+     case ValueType() =>
+       ValueDescriptor(id, tpe)
+     case WritableType() =>
+       WritableDescriptor(id, tpe)
+     case _ =>
+       GenericClassDescriptor(id, tpe)
+   }
+
+   cache.getOrElseUpdate(tpe)(describe)
+ }
+
+ /**
+  * Describes a collection type from the descriptor of its element type. An unsupported
+  * element type makes the collection unsupported too, carrying over the element's errors.
+  */
+ private def analyzeList(
+     id: Int,
+     tpe: Type,
+     elemTpe: Type,
+     iter: Tree => Tree): UDTDescriptor = {
+   val elemDesc = analyze(elemTpe)
+   elemDesc match {
+     case UnsupportedDescriptor(_, _, problems) => UnsupportedDescriptor(id, tpe, problems)
+     case _ => ListDescriptor(id, tpe, iter, elemDesc)
+   }
+ }
+
+ // Builds the descriptor for a base class from its known direct subclasses: every conforming
+ // subclass is analyzed, then the fields common to all of them are derived. The result is an
+ // UnsupportedDescriptor when a subtype reports errors or when no subtype is found.
+ private def analyzeClassHierarchy(id: Int, tpe: Type): UDTDescriptor = {
+
+ // Int-typed accessor without getter/setter symbols; it is placed in front of the base
+ // fields of the resulting descriptor (presumably the subtype tag -- confirm with the users
+ // of BaseClassDescriptor).
+ val tagField = {
+ val (intTpe, intDefault, intWrapper) = PrimitiveType.intPrimitive
+ FieldAccessor(
+ NoSymbol,
+ NoSymbol,
+ NullaryMethodType(intTpe),
+ isBaseField = true,
+ PrimitiveDescriptor(cache.newId, intTpe, intDefault, intWrapper))
+ }
+
+ // Analyze each known direct subclass whose instantiated type conforms to tpe.
+ val subTypes = tpe.typeSymbol.asClass.knownDirectSubclasses.toList flatMap { d =>
+
+ // Instantiate the subclass: each of its type parameters takes the matching type argument
+ // of tpe when one is found, otherwise the parameter's own type signature.
+ val dTpe =
+ {
+ val tArgs = (tpe.typeSymbol.asClass.typeParams, typeArgs(tpe)).zipped.toMap
+ val dArgs = d.asClass.typeParams map { dp =>
+ val tArg = tArgs.keySet.find { tp =>
+ dp == tp.typeSignature.asSeenFrom(d.typeSignature, tpe.typeSymbol).typeSymbol
+ }
+ tArg map { tArgs(_) } getOrElse dp.typeSignature
+ }
+
+ appliedType(d.asType.toType, dArgs)
+ }
+
+ if (dTpe <:< tpe)
+ Some(analyze(dTpe))
+ else
+ None
+ }
+
+ // UnsupportedDescriptor results reported by findByType for the analyzed subtypes.
+ val errors = subTypes flatMap { _.findByType[UnsupportedDescriptor] }
+
+ errors match {
+ // At least one subtype is unsupported: prefix its messages with the subtype.
+ case _ :: _ =>
+ val errorMessage = errors flatMap {
+ case UnsupportedDescriptor(_, subType, errs) =>
+ errs map { err => "Subtype " + subType + " - " + err }
+ }
+ UnsupportedDescriptor(id, tpe, errorMessage)
+
+ case Nil if subTypes.isEmpty =>
+ UnsupportedDescriptor(id, tpe, Seq("No instantiable subtypes found for base class"))
+ case Nil =>
+ // NOTE(review): tParams is not used anywhere below in this method.
+ val (tParams, _) = tpe.typeSymbol.asClass.typeParams.zip(typeArgs(tpe)).unzip
+ // NOTE(review): members are filtered with isSetter, yet each one is bound as bGetter
+ // further down -- confirm whether isGetter was intended.
+ val baseMembers =
+ tpe.members filter { f => f.isMethod } filter { f => f.asMethod.isSetter } map {
+ f => (f, f.asMethod.setter, f.asMethod.returnType)
+ }
+
+ // Field accessors of every subtype; subtypes that are neither base classes nor case
+ // classes contribute none.
+ val subMembers = subTypes map {
+ case BaseClassDescriptor(_, _, getters, _) => getters
+ case CaseClassDescriptor(_, _, _, _, getters) => getters
+ case _ => Seq()
+ }
+
+ // A base member becomes a base field only if every subtype has an accessor with the same
+ // getter name and a conforming return type.
+ val baseFields = baseMembers flatMap {
+ case (bGetter, bSetter, bTpe) =>
+ val accessors = subMembers map {
+ _ find { sf =>
+ sf.getter.name == bGetter.name &&
+ sf.tpe.termSymbol.asMethod.returnType <:< bTpe.termSymbol.asMethod.returnType
+ }
+ }
+ accessors.forall { _.isDefined } match {
+ case true =>
+ Some(
+ FieldAccessor(
+ bGetter, bSetter, bTpe, isBaseField = true, analyze(bTpe.termSymbol.asMethod.returnType)))
+ case false => None
+ }
+ }
+
+ // Rewrites a subtype descriptor so that fields named like a base field are flagged as base
+ // fields and use the base field's descriptor; nested base class subtypes are rewritten too.
+ def wireBaseFields(desc: UDTDescriptor): UDTDescriptor = {
+
+ def updateField(field: FieldAccessor) = {
+ baseFields find { bf => bf.getter.name == field.getter.name } match {
+ case Some(FieldAccessor(_, _, _, _, fieldDesc)) =>
+ field.copy(isBaseField = true, desc = fieldDesc)
+ case None => field
+ }
+ }
+
+ desc match {
+ case desc @ BaseClassDescriptor(_, _, getters, baseSubTypes) =>
+ desc.copy(getters = getters map updateField, subTypes = baseSubTypes map wireBaseFields)
+ case desc @ CaseClassDescriptor(_, _, _, _, getters) =>
+ desc.copy(getters = getters map updateField)
+ case _ => desc
+ }
+ }
+
+ BaseClassDescriptor(id, tpe, tagField +: baseFields.toSeq, subTypes map wireBaseFields)
+ }
+
+ }
+
+ /**
+  * Analyzes a case class type: resolves its primary constructor, builds one
+  * `FieldAccessor` per constructor parameter and collects the errors of unsupported fields.
+  *
+  * @param id  descriptor id assigned to `tpe`
+  * @param tpe the case class type to analyze
+  * @return a `CaseClassDescriptor`, or an `UnsupportedDescriptor` for case-to-case
+  *         inheritance, a missing or ambiguous primary constructor, or unsupported fields
+  */
+ private def analyzeCaseClass(id: Int, tpe: Type): UDTDescriptor = {
+
+   val extendsOtherCaseClass = tpe.baseClasses exists { bc =>
+     bc != tpe.typeSymbol && bc.asClass.isCaseClass
+   }
+
+   if (extendsOtherCaseClass) {
+     UnsupportedDescriptor(id, tpe, Seq("Case-to-case inheritance is not supported."))
+   } else {
+     // Materialize as a List so that the list patterns below cover every possible shape.
+     val ctors = tpe.declarations.collect {
+       case m: MethodSymbol if m.isPrimaryConstructor => m
+     }.toList
+
+     ctors match {
+       case Nil =>
+         // Report a regular analysis error instead of failing with a MatchError.
+         UnsupportedDescriptor(id, tpe, Seq("No primary constructor found."))
+
+       case _ :: _ :: _ =>
+         UnsupportedDescriptor(
+           id,
+           tpe,
+           Seq("Multiple constructors found, this is not supported."))
+
+       case ctor :: Nil =>
+         // First resolve accessor symbols and field types for all constructor parameters ...
+         val caseFields = ctor.paramss.flatten map { sym =>
+           val methodSym = tpe.member(sym.name).asMethod
+           val getter = methodSym.getter
+           val setter = methodSym.setter
+           val returnType = methodSym.returnType.asSeenFrom(tpe, tpe.typeSymbol)
+           (getter, setter, returnType)
+         }
+         // ... then analyze the field types.
+         val fields = caseFields map {
+           case (fgetter, fsetter, fTpe) =>
+             FieldAccessor(fgetter, fsetter, fTpe, isBaseField = false, analyze(fTpe))
+         }
+
+         // Mutable only if the option is enabled and every field has a setter. The local is
+         // not called `mutable` to avoid shadowing the imported scala.collection.mutable.
+         val isMutable = enableMutableUDTs && (fields forall { f => f.setter != NoSymbol })
+         if (isMutable) {
+           mutableTypes.add(tpe)
+         }
+
+         fields filter { _.desc.isInstanceOf[UnsupportedDescriptor] } match {
+           case errs @ _ :: _ =>
+             val msgs = errs flatMap { f =>
+               (f: @unchecked) match {
+                 case FieldAccessor(fgetter, _, _, _, UnsupportedDescriptor(_, fTpe, errors)) =>
+                   errors map { err => "Field " + fgetter.name + ": " + fTpe + " - " + err }
+               }
+             }
+             UnsupportedDescriptor(id, tpe, msgs)
+
+           case Nil => CaseClassDescriptor(id, tpe, isMutable, ctor, fields.toSeq)
+         }
+     }
+   }
+ }
+
+ /** Extractor backed by the `primitives` table. */
+ private object PrimitiveType {
+
+   /** The `Int` type together with the default literal and wrapper type registered for it. */
+   def intPrimitive: (Type, Literal, Type) = primitives(definitions.IntClass) match {
+     case (default, wrapper) => (definitions.IntTpe, default, wrapper)
+   }
+
+   /** Yields the (default literal, wrapper type) entry stored for the symbol of `tpe`. */
+   def unapply(tpe: Type): Option[(Literal, Type)] = {
+     val sym = tpe.typeSymbol
+     primitives.get(sym)
+   }
+ }
+
+ /** Extractor backed by the `boxedPrimitives` table. */
+ private object BoxedPrimitiveType {
+
+   /** Yields the (default, wrapper, box, unbox) entry stored for the symbol of `tpe`. */
+   def unapply(tpe: Type): Option[(Literal, Type, Tree => Tree, Tree => Tree)] = {
+     val boxedSym = tpe.typeSymbol
+     boxedPrimitives.get(boxedSym)
+   }
+ }
+
+ /**
+  * Extractor for arrays and `GenTraversableOnce` collections. A match yields the element
+  * type and a function that wraps a tree of the collection into a tree selecting its
+  * iterator (`iterator` for arrays, `toIterator` for traversables).
+  */
+ private object ListType {
+
+   def unapply(tpe: Type): Option[(Type, Tree => Tree)] = tpe match {
+
+     case ArrayType(elemTpe) =>
+       val iter = { source: Tree =>
+         Select(source, newTermName("iterator"))
+       }
+       Some(elemTpe, iter)
+
+     case TraversableType(elemTpe) =>
+       val iter = { source: Tree => Select(source, newTermName("toIterator")) }
+       Some(elemTpe, iter)
+
+     case _ => None
+   }
+
+   /** Matches array types that carry exactly one type argument. */
+   private object ArrayType {
+     def unapply(tpe: Type): Option[Type] = tpe match {
+       case TypeRef(_, _, elemTpe :: Nil) if tpe <:< typeOf[Array[_]] => Some(elemTpe)
+       case _ => None
+     }
+   }
+
+   /** Matches `GenTraversableOnce` subtypes that carry exactly one type argument. */
+   private object TraversableType {
+     def unapply(tpe: Type): Option[Type] = tpe match {
+       case _ if tpe <:< typeOf[GenTraversableOnce[_]] =>
+         // val abstrElemTpe = genTraversableOnceClass.typeConstructor.typeParams.head.tpe
+         // val elemTpe = abstrElemTpe.asSeenFrom(tpe, genTraversableOnceClass)
+         // Some(elemTpe)
+         // TODO make sure this works as it should
+         tpe match {
+           case TypeRef(_, _, elemTpe :: Nil) => Some(elemTpe.asSeenFrom(tpe, tpe.typeSymbol))
+           // Collections without exactly one type argument (e.g. Map[K, V] or Range) cannot
+           // be handled here: report "no match" instead of failing with a MatchError.
+           case _ => None
+         }
+
+       case _ => None
+     }
+   }
+ }
+
+ /** Matches types whose symbol is a case class. */
+ private object CaseClassType {
+   def unapply(tpe: Type): Boolean = {
+     val sym = tpe.typeSymbol
+     // asClass throws for symbols that are not classes (e.g. type parameters), so such
+     // types must be rejected before the cast.
+     sym.isClass && sym.asClass.isCaseClass
+   }
+ }
+
+ /** Matches types whose symbol is a class that is both abstract and sealed. */
+ private object BaseClassType {
+   def unapply(tpe: Type): Boolean = {
+     val sym = tpe.typeSymbol
+     // asClass throws for symbols that are not classes (e.g. type parameters), so such
+     // types must be rejected before the cast.
+     sym.isClass && {
+       val cls = sym.asClass
+       cls.isAbstractClass && cls.isSealed
+     }
+   }
+ }
+
+ /** Matches class types that have `org.apache.flink.types.Value` among their base classes. */
+ private object ValueType {
+   def unapply(tpe: Type): Boolean = {
+     val sym = tpe.typeSymbol
+     // asClass throws for symbols that are not classes (e.g. type parameters), so such
+     // types must be rejected before the cast.
+     sym.isClass && (sym.asClass.baseClasses exists {
+       s => s.fullName == "org.apache.flink.types.Value"
+     })
+   }
+ }
+
+ /** Matches class types that have `org.apache.hadoop.io.Writable` among their base classes. */
+ private object WritableType {
+   def unapply(tpe: Type): Boolean = {
+     val sym = tpe.typeSymbol
+     // asClass throws for symbols that are not classes (e.g. type parameters), so such
+     // types must be rejected before the cast.
+     sym.isClass && (sym.asClass.baseClasses exists {
+       s => s.fullName == "org.apache.hadoop.io.Writable"
+     })
+   }
+ }
+
+ /**
+  * Tracks the types that are currently being analyzed and hands out descriptor ids.
+  * Entries live in a `DynamicVariable`, so a type is only registered while the analysis
+  * callback passed to `getOrElseUpdate` for it is running.
+  */
+ private class UDTAnalyzerCache {
+
+   private val caches = new DynamicVariable[Map[Type, RecursiveDescriptor]](Map())
+   private val idGen = new Counter
+
+   /** Draws the next id from the counter. */
+   def newId = idGen.next
+
+   /**
+    * Always draws a new id first. If `tpe` is registered, returns a copy of its recursive
+    * descriptor under that id; otherwise registers a recursive descriptor for `tpe` for the
+    * duration of `orElse(id)` and returns the callback's result.
+    */
+   def getOrElseUpdate(tpe: Type)(orElse: Int => UDTDescriptor): UDTDescriptor = {
+
+     val id = idGen.next
+     val inProgress = caches.value
+
+     inProgress.get(tpe) match {
+       case Some(known) =>
+         known.copy(id = id)
+       case None =>
+         val placeholder = RecursiveDescriptor(id, tpe, id)
+         caches.withValue(inProgress + (tpe -> placeholder)) {
+           orElse(id)
+         }
+     }
+   }
+ }
+ }
+
+ /**
+  * Maps the symbol of each primitive class (and of `String`) to its default value literal
+  * and to the Flink value type registered as its wrapper.
+  */
+ lazy val primitives = {
+
+   // Wraps a default value into a literal tree.
+   def lit(value: Any): Literal = Literal(Constant(value))
+
+   Map[Symbol, (Literal, Type)](
+     definitions.BooleanClass -> (lit(false), typeOf[BooleanValue]),
+     definitions.ByteClass -> (lit(0: Byte), typeOf[ByteValue]),
+     definitions.CharClass -> (lit(0: Char), typeOf[CharValue]),
+     definitions.DoubleClass -> (lit(0: Double), typeOf[DoubleValue]),
+     definitions.FloatClass -> (lit(0: Float), typeOf[FloatValue]),
+     definitions.IntClass -> (lit(0: Int), typeOf[IntValue]),
+     definitions.LongClass -> (lit(0: Long), typeOf[LongValue]),
+     definitions.ShortClass -> (lit(0: Short), typeOf[ShortValue]),
+     definitions.StringClass -> (lit(null: String), typeOf[StringValue]))
+ }
+
+ /**
+  * Maps the symbol of each `java.lang` box class to the default literal and wrapper type of
+  * the corresponding primitive, plus two tree builders that apply the `scala.Predef`
+  * conversions named `<prim>2<Box>` and `<Box>2<prim>`.
+  */
+ lazy val boxedPrimitives = {
+
+   // Tree builder for `scala.Predef.<name>(t)`; every invocation creates new tree nodes.
+   def predefCall(name: String) = { t: Tree =>
+     Apply(
+       Select(
+         Select(Ident(newTermName("scala")), newTermName("Predef")),
+         newTermName(name)),
+       List(t))
+   }
+
+   def getBoxInfo(prim: Symbol, primName: String, boxName: String) = {
+     val (default, wrapper) = primitives(prim)
+     val box = predefCall(primName + "2" + boxName)
+     val unbox = predefCall(boxName + "2" + primName)
+     (default, wrapper, box, unbox)
+   }
+
+   Map(
+     typeOf[java.lang.Boolean].typeSymbol ->
+       getBoxInfo(definitions.BooleanClass, "boolean", "Boolean"),
+     typeOf[java.lang.Byte].typeSymbol ->
+       getBoxInfo(definitions.ByteClass, "byte", "Byte"),
+     typeOf[java.lang.Character].typeSymbol ->
+       getBoxInfo(definitions.CharClass, "char", "Character"),
+     typeOf[java.lang.Double].typeSymbol ->
+       getBoxInfo(definitions.DoubleClass, "double", "Double"),
+     typeOf[java.lang.Float].typeSymbol ->
+       getBoxInfo(definitions.FloatClass, "float", "Float"),
+     typeOf[java.lang.Integer].typeSymbol ->
+       getBoxInfo(definitions.IntClass, "int", "Integer"),
+     typeOf[java.lang.Long].typeSymbol ->
+       getBoxInfo(definitions.LongClass, "long", "Long"),
+     typeOf[java.lang.Short].typeSymbol ->
+       getBoxInfo(definitions.ShortClass, "short", "Short"))
+ }
+
+}
+