You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2014/09/22 14:28:47 UTC
[05/60] Rewrite the Scala API as (somewhat) thin Layer on Java API
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala
new file mode 100644
index 0000000..bbb9e73
--- /dev/null
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeDescriptors.scala
@@ -0,0 +1,240 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.api.scala.codegen
+
+import scala.language.postfixOps
+
+import scala.reflect.macros.Context
+import scala.reflect.classTag
+import scala.reflect.ClassTag
+import scala.Option.option2Iterable
+
+// These are only used internally while analyzing Scala types in TypeAnalyzer and TypeInformationGen
+
+private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C] =>
+  import c.universe._
+
+  /**
+   * Compile-time description of an analyzed Scala type. Every descriptor carries an `id` and
+   * the reflected type `tpe` it describes; composite descriptors nest further descriptors.
+   */
+  abstract sealed class UDTDescriptor {
+    val id: Int
+    val tpe: Type
+    val isPrimitiveProduct: Boolean = false
+
+    /** Whether values of the described type can be used as keys. */
+    def canBeKey: Boolean
+
+    /** Root form of this descriptor; `this` unless a subclass overrides it. */
+    def mkRoot: UDTDescriptor = this
+
+    /** This descriptor followed by all descriptors nested inside it. */
+    def flatten: Seq[UDTDescriptor]
+
+    /** Fields of the described type; empty unless overridden. */
+    def getters: Seq[FieldAccessor] = Seq()
+
+    /** Descriptor of the field whose getter is named `member`, if there is one. */
+    def select(member: String): Option[UDTDescriptor] =
+      getters find { _.getter.name.toString == member } map { _.desc }
+
+    /**
+     * Resolves a path of field names. The empty path selects this descriptor; otherwise the
+     * head names a field and the tail is resolved on that field's descriptor. A name without
+     * a matching field yields `Seq(None)`.
+     */
+    def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
+      case Nil => Seq(Some(this))
+      case head :: tail => selectField(head, tail)
+    }
+
+    /** Resolves `tail` on the field named `head`, or yields `Seq(None)` if there is none. */
+    protected def selectField(head: String, tail: List[String]): Seq[Option[UDTDescriptor]] =
+      getters find { _.getter.name.toString == head } match {
+        case None => Seq(None)
+        case Some(field) => field.desc.select(tail)
+      }
+
+    def findById(id: Int): Option[UDTDescriptor] = flatten.find { _.id == id }
+
+    /** All descriptors in `flatten` that are instances of `T`. */
+    def findByType[T <: UDTDescriptor: ClassTag]: Seq[T] = {
+      val clazz = classTag[T].runtimeClass
+      flatten filter { item => clazz.isAssignableFrom(item.getClass) } map { _.asInstanceOf[T] }
+    }
+
+    /**
+     * The descriptors that the `RecursiveDescriptor`s in this tree point back to (looked up by
+     * `refId` within `flatten`), converted with `mkRoot` and without duplicates.
+     */
+    def getRecursiveRefs: Seq[UDTDescriptor] =
+      findByType[RecursiveDescriptor].flatMap { rd => findById(rd.refId) }.map { _.mkRoot }.distinct
+  }
+
+  /** A type without a specialized descriptor; it can never be a key. */
+  case class GenericClassDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
+    override def flatten = Seq(this)
+
+    def canBeKey = false
+  }
+
+  /** A type the analyzer could not handle; `errors` carries the reported messages. */
+  case class UnsupportedDescriptor(id: Int, tpe: Type, errors: Seq[String]) extends UDTDescriptor {
+    override def flatten = Seq(this)
+
+    def canBeKey = tpe <:< typeOf[Comparable[_]]
+  }
+
+  /**
+   * A primitive type (`String` included). `default` is its default-value literal and `wrapper`
+   * the matching Flink value type, which decides whether the primitive can be a key.
+   */
+  case class PrimitiveDescriptor(id: Int, tpe: Type, default: Literal, wrapper: Type)
+    extends UDTDescriptor {
+    override val isPrimitiveProduct = true
+    override def flatten = Seq(this)
+    override def canBeKey = wrapper <:< typeOf[org.apache.flink.types.Key[_]]
+  }
+
+  /**
+   * A boxed primitive. `box`/`unbox` build conversion trees; being function values, they are
+   * left out of `equals` and `hashCode`.
+   */
+  case class BoxedPrimitiveDescriptor(
+      id: Int, tpe: Type, default: Literal, wrapper: Type, box: Tree => Tree, unbox: Tree => Tree)
+    extends UDTDescriptor {
+
+    override val isPrimitiveProduct = true
+    override def flatten = Seq(this)
+    override def canBeKey = wrapper <:< typeOf[org.apache.flink.types.Key[_]]
+
+    override def hashCode() = (id, tpe, default, wrapper, "BoxedPrimitiveDescriptor").hashCode()
+    override def equals(that: Any) = that match {
+      case BoxedPrimitiveDescriptor(thatId, thatTpe, thatDefault, thatWrapper, _, _) =>
+        // Compare against an explicit tuple instead of relying on argument auto-tupling.
+        (id, tpe, default, wrapper).equals((thatId, thatTpe, thatDefault, thatWrapper))
+      case _ => false
+    }
+  }
+
+  /**
+   * A collection type. `iter` builds a tree iterating over a source tree and `elem` describes
+   * the elements; `iter` is a function value and is left out of `equals` and `hashCode`.
+   */
+  case class ListDescriptor(id: Int, tpe: Type, iter: Tree => Tree, elem: UDTDescriptor)
+    extends UDTDescriptor {
+    override def canBeKey = false
+    override def flatten = this +: elem.flatten
+
+    /** The element descriptor below all directly nested list levels. */
+    def getInnermostElem: UDTDescriptor = elem match {
+      case list: ListDescriptor => list.getInnermostElem
+      case _ => elem
+    }
+
+    override def hashCode() = (id, tpe, elem).hashCode()
+    override def equals(that: Any) = that match {
+      case ListDescriptor(thatId, thatTpe, _, thatElem) =>
+        (id, tpe, elem).equals((thatId, thatTpe, thatElem))
+      case _ => false
+    }
+  }
+
+  /**
+   * A base class with known subclasses. `getters` are the fields shared with the subclasses
+   * and `subTypes` describe those subclasses.
+   */
+  case class BaseClassDescriptor(
+      id: Int, tpe: Type, override val getters: Seq[FieldAccessor], subTypes: Seq[UDTDescriptor])
+    extends UDTDescriptor {
+
+    override def flatten = this +: ((getters flatMap { _.desc.flatten }) ++ (subTypes flatMap { _.flatten }))
+
+    // Checks the nested descriptors only. `flatten` starts with `this`, so checking all of
+    // `flatten` would call `canBeKey` on this instance again and never terminate.
+    override def canBeKey = (getters forall { _.desc.canBeKey }) && (subTypes forall { _.canBeKey })
+
+    // The empty path expands to the fields rather than selecting this descriptor.
+    override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
+      case Nil => getters flatMap { g => g.desc.select(Nil) }
+      case head :: tail => selectField(head, tail)
+    }
+  }
+
+  /**
+   * A case class with constructor `ctor` and fields `getters`. `mutable` is set by the
+   * analyzer (presumably: instances may be updated in place through setters; confirm).
+   */
+  case class CaseClassDescriptor(
+      id: Int, tpe: Type, mutable: Boolean, ctor: Symbol, override val getters: Seq[FieldAccessor])
+    extends UDTDescriptor {
+
+    override val isPrimitiveProduct = getters.nonEmpty && getters.forall(_.desc.isPrimitiveProduct)
+
+    override def mkRoot = this.copy(getters = getters map { _.copy(isBaseField = false) })
+    override def flatten = this +: (getters flatMap { _.desc.flatten })
+
+    // Checks the nested descriptors only. `flatten` starts with `this`, so checking all of
+    // `flatten` would call `canBeKey` on this instance again and never terminate.
+    override def canBeKey = getters forall { _.desc.canBeKey }
+
+    // `mutable` is compared by equals but left out of the hash; equal instances still hash equally.
+    override def hashCode = (id, tpe, ctor, getters).hashCode
+    override def equals(that: Any) = that match {
+      case CaseClassDescriptor(thatId, thatTpe, thatMutable, thatCtor, thatGetters) =>
+        (id, tpe, mutable, ctor, getters).equals((thatId, thatTpe, thatMutable, thatCtor, thatGetters))
+      case _ => false
+    }
+
+    // The empty path expands to the fields rather than selecting this descriptor.
+    override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
+      case Nil => getters flatMap { g => g.desc.select(Nil) }
+      case head :: tail => selectField(head, tail)
+    }
+  }
+
+  /**
+   * A field of a composite type: getter and setter symbols, the field type, whether it was
+   * wired in as a field shared through a base class, and the descriptor of its value.
+   */
+  case class FieldAccessor(getter: Symbol, setter: Symbol, tpe: Type, isBaseField: Boolean, desc: UDTDescriptor)
+
+  /** Reference back to the descriptor with id `refId`, used for recursive types. */
+  case class RecursiveDescriptor(id: Int, tpe: Type, refId: Int) extends UDTDescriptor {
+    override def flatten = Seq(this)
+    override def canBeKey = tpe <:< typeOf[org.apache.flink.types.Key[_]]
+  }
+
+  /** A subtype of Flink's `org.apache.flink.types.Value`. */
+  case class ValueDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
+    override val isPrimitiveProduct = true
+    override def flatten = Seq(this)
+    override def canBeKey = tpe <:< typeOf[org.apache.flink.types.Key[_]]
+  }
+
+  /** A Hadoop `Writable` type. */
+  case class WritableDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
+    override val isPrimitiveProduct = true
+    override def flatten = Seq(this)
+    override def canBeKey = tpe <:< typeOf[org.apache.hadoop.io.WritableComparable[_]]
+  }
+}
+
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala
new file mode 100644
index 0000000..248c396
--- /dev/null
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala
@@ -0,0 +1,203 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.api.scala.codegen
+
+import org.apache.flink.api.common.typeutils.TypeSerializer
+import org.apache.flink.api.java.typeutils._
+import org.apache.flink.api.scala.typeutils.{ScalaTupleSerializer, ScalaTupleTypeInfo}
+import org.apache.flink.types.{Value, TypeInformation}
+import org.apache.hadoop.io.Writable
+
+import scala.reflect.macros.Context
+
+/**
+ * Generates, during macro expansion, the trees that build the `TypeInformation` of a Scala type
+ * from the `UDTDescriptor` produced for it by the type analyzer.
+ */
+private[flink] trait TypeInformationGen[C <: Context] {
+  this: MacroContextHolder[C]
+    with TypeDescriptors[C]
+    with TypeAnalyzer[C]
+    with TreeGen[C] =>
+
+  import c.universe._
+
+  // This is for external calling by TypeUtils.createTypeInfo
+  def mkTypeInfo[T: c.WeakTypeTag]: c.Expr[TypeInformation[T]] = {
+    val desc = getUDTDescriptor(weakTypeOf[T])
+    val result: c.Expr[TypeInformation[T]] = mkTypeInfo(desc)(c.WeakTypeTag(desc.tpe))
+    result
+  }
+
+  // We have this for internal use so that we can use it to recursively generate a tree of
+  // TypeInformation from a tree of UDTDescriptor
+  def mkTypeInfo[T: c.WeakTypeTag](desc: UDTDescriptor): c.Expr[TypeInformation[T]] = desc match {
+    case cc@CaseClassDescriptor(_, tpe, _, _, _) =>
+      mkTupleTypeInfo(cc)(c.WeakTypeTag(tpe).asInstanceOf[c.WeakTypeTag[Product]])
+        .asInstanceOf[c.Expr[TypeInformation[T]]]
+    case p : PrimitiveDescriptor => mkPrimitiveTypeInfo(p.tpe)
+    case p : BoxedPrimitiveDescriptor => mkPrimitiveTypeInfo(p.tpe)
+    // Only arrays get a dedicated array type info; other collections end up in the generic case.
+    case l : ListDescriptor if l.tpe <:< typeOf[Array[_]] => mkListTypeInfo(l)
+    case v : ValueDescriptor =>
+      mkValueTypeInfo(v)(c.WeakTypeTag(v.tpe).asInstanceOf[c.WeakTypeTag[Value]])
+        .asInstanceOf[c.Expr[TypeInformation[T]]]
+    case d : WritableDescriptor =>
+      mkWritableTypeInfo(d)(c.WeakTypeTag(d.tpe).asInstanceOf[c.WeakTypeTag[Writable]])
+        .asInstanceOf[c.Expr[TypeInformation[T]]]
+    case d => mkGenericTypeInfo(d)
+  }
+
+  /**
+   * Type information for a case class: a `ScalaTupleTypeInfo` over the type infos of its
+   * fields, whose serializer creates instances by invoking the constructor with the field values.
+   */
+  def mkTupleTypeInfo[T <: Product : c.WeakTypeTag](
+      desc: CaseClassDescriptor): c.Expr[TypeInformation[T]] = {
+    val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
+    val fields = desc.getters.toList map { field =>
+      mkTypeInfo(field.desc)(c.WeakTypeTag(field.tpe)).tree
+    }
+    val fieldsExpr = c.Expr[Seq[TypeInformation[_]]](mkList(fields))
+    val instance = mkCreateTupleInstance[T](desc)(c.WeakTypeTag(desc.tpe))
+    reify {
+      new ScalaTupleTypeInfo[T](tpeClazz.splice, fieldsExpr.splice) {
+        override def createSerializer: TypeSerializer[T] = {
+          val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
+          for (i <- 0 until getArity) {
+            fieldSerializers(i) = types(i).createSerializer
+          }
+
+          new ScalaTupleSerializer[T](tupleType, fieldSerializers) {
+            // `instance` refers to `fields` by name, so it binds to this parameter.
+            override def createInstance(fields: Array[AnyRef]): T = {
+              instance.splice
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /** Type information for an array, chosen according to the descriptor of its elements. */
+  def mkListTypeInfo[T: c.WeakTypeTag](desc: ListDescriptor): c.Expr[TypeInformation[T]] = {
+    val arrayClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
+    desc.elem match {
+      // special case for string, which in scala is a primitive, but not in java
+      case p: PrimitiveDescriptor if p.tpe <:< typeOf[String] =>
+        reify {
+          BasicArrayTypeInfo.getInfoFor(arrayClazz.splice)
+        }
+      case _: PrimitiveDescriptor =>
+        reify {
+          PrimitiveArrayTypeInfo.getInfoFor(arrayClazz.splice)
+        }
+      case _: BoxedPrimitiveDescriptor =>
+        reify {
+          BasicArrayTypeInfo.getInfoFor(arrayClazz.splice)
+        }
+      case _ =>
+        // Only object arrays need the element type info, so it is generated just here.
+        val elementTypeInfo = mkTypeInfo(desc.elem)
+        reify {
+          ObjectArrayTypeInfo.getInfoFor(
+            arrayClazz.splice,
+            elementTypeInfo.splice).asInstanceOf[TypeInformation[T]]
+        }
+    }
+  }
+
+  /** `ValueTypeInfo` for the descriptor's type. */
+  def mkValueTypeInfo[T <: Value : c.WeakTypeTag](desc: UDTDescriptor): c.Expr[TypeInformation[T]] = {
+    val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
+    reify {
+      new ValueTypeInfo[T](tpeClazz.splice)
+    }
+  }
+
+  /** `WritableTypeInfo` for the descriptor's type. */
+  def mkWritableTypeInfo[T <: Writable : c.WeakTypeTag](desc: UDTDescriptor): c.Expr[TypeInformation[T]] = {
+    val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
+    reify {
+      new WritableTypeInfo[T](tpeClazz.splice)
+    }
+  }
+
+  /** Falls back to the Java `TypeExtractor` for the descriptor's class. */
+  def mkGenericTypeInfo[T: c.WeakTypeTag](desc: UDTDescriptor): c.Expr[TypeInformation[T]] = {
+    val tpeClazz = c.Expr[Class[T]](Literal(Constant(desc.tpe)))
+    reify {
+      TypeExtractor.createTypeInfo(tpeClazz.splice).asInstanceOf[TypeInformation[T]]
+    }
+  }
+
+  /** `BasicTypeInfo` for the given (possibly boxed) primitive type. */
+  def mkPrimitiveTypeInfo[T: c.WeakTypeTag](tpe: Type): c.Expr[TypeInformation[T]] = {
+    val tpeClazz = c.Expr[Class[T]](Literal(Constant(tpe)))
+    reify {
+      BasicTypeInfo.getInfoFor(tpeClazz.splice)
+    }
+  }
+
+  /**
+   * Constructor call for the case class in which argument `i` is `fields(i)` cast to the type
+   * of field `i`; `fields` is resolved where the tree is spliced.
+   */
+  def mkCreateTupleInstance[T: c.WeakTypeTag](desc: CaseClassDescriptor): c.Expr[T] = {
+    val fields = desc.getters.zipWithIndex.map { case (field, i) =>
+      val call = mkCall(Ident(newTermName("fields")), "apply")(List(Literal(Constant(i))))
+      mkAsInstanceOf(call)(c.WeakTypeTag(field.tpe))
+    }
+    val result = Apply(Select(New(TypeTree(desc.tpe)), nme.CONSTRUCTOR), fields.toList)
+    c.Expr[T](result)
+  }
+
+// def mkCaseClassTypeInfo[T: c.WeakTypeTag](desc: CaseClassDescriptor): c.Expr[TypeInformation[T]] = {
+// val tpeClazz = c.Expr[Class[_]](Literal(Constant(desc.tpe)))
+// val caseFields = mkCaseFields(desc)
+// reify {
+// new ScalaTupleTypeInfo[T] {
+// def createSerializer: TypeSerializer[T] = {
+// null
+// }
+//
+// val fields: Map[String, TypeInformation[_]] = caseFields.splice
+// val clazz = tpeClazz.splice
+// }
+// }
+// }
+//
+// private def mkCaseFields(desc: UDTDescriptor): c.Expr[Map[String, TypeInformation[_]]] = {
+// val fields = getFields("_root_", desc).toList map { case (fieldName, fieldDesc) =>
+// val nameTree = c.Expr(Literal(Constant(fieldName)))
+// val fieldTypeInfo = mkTypeInfo(fieldDesc)(c.WeakTypeTag(fieldDesc.tpe))
+// reify { (nameTree.splice, fieldTypeInfo.splice) }.tree
+// }
+//
+// c.Expr(mkMap(fields))
+// }
+//
+// protected def getFields(name: String, desc: UDTDescriptor): Seq[(String, UDTDescriptor)] = desc match {
+// // Flatten product types
+// case CaseClassDescriptor(_, _, _, _, getters) =>
+// getters filterNot { _.isBaseField } flatMap { f => getFields(name + "." + f.getter.name, f.desc) }
+// case _ => Seq((name, desc))
+// }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTAnalyzer.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTAnalyzer.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTAnalyzer.scala
deleted file mode 100644
index 2dad277..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTAnalyzer.scala
+++ /dev/null
@@ -1,344 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.codegen
-
-import scala.Option.option2Iterable
-import scala.collection.GenTraversableOnce
-import scala.collection.mutable
-import scala.reflect.macros.Context
-import scala.util.DynamicVariable
-
-import org.apache.flink.types.BooleanValue
-import org.apache.flink.types.ByteValue
-import org.apache.flink.types.CharValue
-import org.apache.flink.types.DoubleValue
-import org.apache.flink.types.FloatValue
-import org.apache.flink.types.IntValue
-import org.apache.flink.types.StringValue
-import org.apache.flink.types.LongValue
-import org.apache.flink.types.ShortValue
-
-
-trait UDTAnalyzer[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with Loggers[C] => // Builds UDTDescriptor trees for Scala types during macro expansion
- import c.universe._
-
- // This value is controlled by the udtRecycling compiler option
- var enableMutableUDTs = false
-
- private val mutableTypes = mutable.Set[Type]()
-
- def getUDTDescriptor(tpe: Type): UDTDescriptor = (new UDTAnalyzerInstance with Logger).analyze(tpe) // fresh analyzer instance (and cache) per call
-
- private def normTpe(tpe: Type): Type = { // currently returns tpe unchanged
- // TODO Figure out what the heck this does
- // val currentThis = ThisType(localTyper.context.enclClass.owner)
- // currentThis.baseClasses.foldLeft(tpe map { _.dealias }) { (tpe, base) => tpe.substThis(base, currentThis) }
- tpe
- }
-
- private def typeArgs(tpe: Type) = tpe match { case TypeRef(_, _, args) => args } // NOTE(review): non-exhaustive, throws MatchError for non-TypeRef types
-
- private class UDTAnalyzerInstance { this: Logger =>
-
- private val cache = new UDTAnalyzerCache()
-
- def analyze(tpe: Type): UDTDescriptor = { // the first matching case below decides the descriptor kind
-
- val normed = normTpe(tpe)
-
- cache.getOrElseUpdate(normed) { id =>
- normed match {
- case PrimitiveType(default, wrapper) => PrimitiveDescriptor(id, normed, default, wrapper)
- case BoxedPrimitiveType(default, wrapper, box, unbox) => BoxedPrimitiveDescriptor(id, normed, default, wrapper, box, unbox)
- case ListType(elemTpe, iter) => analyzeList(id, normed, elemTpe, iter)
- case CaseClassType() => analyzeCaseClass(id, normed)
- case BaseClassType() => analyzeClassHierarchy(id, normed)
- case PactValueType() => PactValueDescriptor(id, normed)
- case _ => UnsupportedDescriptor(id, normed, Seq("Unsupported type " + normed))
- }
- }
- }
-
- private def analyzeList(id: Int, tpe: Type, elemTpe: Type, iter: Tree => Tree): UDTDescriptor = analyze(elemTpe) match {
- case UnsupportedDescriptor(_, _, errs) => UnsupportedDescriptor(id, tpe, errs)
- case desc => ListDescriptor(id, tpe, iter, desc)
- }
-
- private def analyzeClassHierarchy(id: Int, tpe: Type): UDTDescriptor = { // sealed abstract base class: analyzes its known direct subclasses
-
- val tagField = {
- val (intTpe, intDefault, intWrapper) = PrimitiveType.intPrimitive
- FieldAccessor(NoSymbol, NoSymbol, NullaryMethodType(intTpe), true, PrimitiveDescriptor(cache.newId, intTpe, intDefault, intWrapper))
- }
-
-// c.info(c.enclosingPosition, "KNOWN SUBCLASSES: " + tpe.typeSymbol.asClass.knownDirectSubclasses.toList, true)
-
- val subTypes = tpe.typeSymbol.asClass.knownDirectSubclasses.toList flatMap { d =>
-
- val dTpe = // verbosely[Type] { dTpe => d.tpe + " <: " + tpe + " instantiated as " + dTpe + " (" + (if (dTpe <:< tpe) "Valid" else "Invalid") + " subtype)" }
- {
- val tArgs = (tpe.typeSymbol.asClass.typeParams, typeArgs(tpe)).zipped.toMap
- val dArgs = d.asClass.typeParams map { dp =>
- val tArg = tArgs.keySet.find { tp => dp == tp.typeSignature.asSeenFrom(d.typeSignature, tpe.typeSymbol).typeSymbol }
- tArg map { tArgs(_) } getOrElse dp.typeSignature
- }
-
- normTpe(appliedType(d.asType.toType, dArgs))
- }
-// c.info(c.enclosingPosition, "dTpe: " + dTpe, true)
-
- if (dTpe <:< tpe)
- Some(analyze(dTpe))
- else
- None
- }
-
-// c.info(c.enclosingPosition, c.enclosingRun.units.size + " SUBTYPES: " + subTypes, true)
-
- val errors = subTypes flatMap { _.findByType[UnsupportedDescriptor] } // unsupported descriptors anywhere in the subtype trees
-
-// c.info(c.enclosingPosition, "ERROS: " + errors, true)
-
- errors match {
- case _ :: _ => UnsupportedDescriptor(id, tpe, errors flatMap { case UnsupportedDescriptor(_, subType, errs) => errs map { err => "Subtype " + subType + " - " + err } })
- case Nil if subTypes.isEmpty => UnsupportedDescriptor(id, tpe, Seq("No instantiable subtypes found for base class"))
- case Nil => {
-
- val (tParams, tArgs) = tpe.typeSymbol.asClass.typeParams.zip(typeArgs(tpe)).unzip
- val baseMembers = tpe.members filter { f => f.isMethod } filter { f => f.asMethod.isSetter } map {
- f => (f, f.asMethod.setter, normTpe(f.asMethod.returnType))
- }
-
- val subMembers = subTypes map {
- case BaseClassDescriptor(_, _, getters, _) => getters
- case CaseClassDescriptor(_, _, _, _, getters) => getters
- case _ => Seq()
- }
-
- val baseFields = baseMembers flatMap {
- case (bGetter, bSetter, bTpe) => {
- val accessors = subMembers map {
- _ find { sf =>
- sf.getter.name == bGetter.name && sf.tpe.termSymbol.asMethod.returnType <:< bTpe.termSymbol.asMethod.returnType
- }
- }
- accessors.forall { _.isDefined } match {
- case true => Some(FieldAccessor(bGetter, bSetter, bTpe, true, analyze(bTpe.termSymbol.asMethod.returnType)))
- case false => None
- }
- }
- }
-
- def wireBaseFields(desc: UDTDescriptor): UDTDescriptor = { // marks fields shared with the base class (isBaseField) in every subtype
-
- def updateField(field: FieldAccessor) = {
- baseFields find { bf => bf.getter.name == field.getter.name } match {
- case Some(FieldAccessor(_, _, _, _, desc)) => field.copy(isBaseField = true, desc = desc)
- case None => field
- }
- }
-
- desc match {
- case desc @ BaseClassDescriptor(_, _, getters, subTypes) => desc.copy(getters = getters map updateField, subTypes = subTypes map wireBaseFields)
- case desc @ CaseClassDescriptor(_, _, _, _, getters) => desc.copy(getters = getters map updateField)
- case _ => desc
- }
- }
-
- //Debug.report("BaseClass " + tpe + " has shared fields: " + (baseFields.map { m => m.sym.name + ": " + m.tpe }))
- BaseClassDescriptor(id, tpe, tagField +: (baseFields.toSeq), subTypes map wireBaseFields)
- }
- }
-
- }
-
- private def analyzeCaseClass(id: Int, tpe: Type): UDTDescriptor = { // one FieldAccessor per primary-constructor parameter
-
- tpe.baseClasses exists { bc => !(bc == tpe.typeSymbol) && bc.asClass.isCaseClass } match {
-
- case true => UnsupportedDescriptor(id, tpe, Seq("Case-to-case inheritance is not supported."))
-
- case false => {
-
- val ctors = tpe.declarations collect {
- case m: MethodSymbol if m.isPrimaryConstructor => m
- }
-
- ctors match { // NOTE(review): no case for an empty result, which would throw a MatchError
- case c1 :: c2 :: _ => UnsupportedDescriptor(id, tpe, Seq("Multiple constructors found, this is not supported."))
- case c :: Nil => {
- val caseFields = c.paramss.flatten.map {
- sym =>
- {
- val methodSym = tpe.member(sym.name).asMethod
- (methodSym.getter, methodSym.setter, methodSym.returnType.asSeenFrom(tpe, tpe.typeSymbol))
- }
- }
- val fields = caseFields map {
- case (fgetter, fsetter, fTpe) => FieldAccessor(fgetter, fsetter, fTpe, false, analyze(fTpe))
- }
- val mutable = maybeVerbosely[Boolean](m => m && mutableTypes.add(tpe))(_ => "Detected recyclable type: " + tpe) {
- enableMutableUDTs && (fields forall { f => f.setter != NoSymbol })
- }
- fields filter { _.desc.isInstanceOf[UnsupportedDescriptor] } match {
- case errs @ _ :: _ => {
- val msgs = errs flatMap { f =>
- (f: @unchecked) match {
- case FieldAccessor(fgetter, _, _, _, UnsupportedDescriptor(_, fTpe, errors)) => errors map { err => "Field " + fgetter.name + ": " + fTpe + " - " + err }
- }
- }
- UnsupportedDescriptor(id, tpe, msgs)
- }
- case Nil => CaseClassDescriptor(id, tpe, mutable, c, fields.toSeq)
- }
- }
- }
- }
- }
- }
-
- private object PrimitiveType {
-
- def intPrimitive: (Type, Literal, Type) = {
- val (d, w) = primitives(definitions.IntClass)
- (definitions.IntTpe, d, w)
- }
-
- def unapply(tpe: Type): Option[(Literal, Type)] = primitives.get(tpe.typeSymbol)
- }
-
- private object BoxedPrimitiveType {
-
- def unapply(tpe: Type): Option[(Literal, Type, Tree => Tree, Tree => Tree)] = boxedPrimitives.get(tpe.typeSymbol)
- }
-
- private object ListType {
-
- def unapply(tpe: Type): Option[(Type, Tree => Tree)] = tpe match {
-
- case ArrayType(elemTpe) => {
- val iter = { source: Tree =>
- Select(source, "iterator": TermName)
- }
- Some(elemTpe, iter)
- }
-
- case TraversableType(elemTpe) => {
- val iter = { source: Tree => Select(source, "toIterator": TermName) }
- Some(elemTpe, iter)
- }
-
- case _ => None
- }
-
- private object ArrayType {
- def unapply(tpe: Type): Option[Type] = tpe match {
- case TypeRef(_, _, elemTpe :: Nil) if tpe <:< typeOf[Array[_]] => Some(elemTpe)
- case _ => None
- }
- }
-
- private object TraversableType {
- def unapply(tpe: Type): Option[Type] = tpe match {
- case _ if tpe <:< typeOf[GenTraversableOnce[_]] => {
- // val abstrElemTpe = genTraversableOnceClass.typeConstructor.typeParams.head.tpe
- // val elemTpe = abstrElemTpe.asSeenFrom(tpe, genTraversableOnceClass)
- // Some(elemTpe)
- // TODO make sure this shit works as it should
- tpe match {
- case TypeRef(_, _, elemTpe :: Nil) => Some(elemTpe.asSeenFrom(tpe, tpe.typeSymbol)) // NOTE(review): any other shape throws MatchError
- }
- }
- case _ => None
- }
- }
- }
-
- private object CaseClassType {
- def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isCaseClass
- }
-
- private object BaseClassType {
- def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isAbstractClass && tpe.typeSymbol.asClass.isSealed
- }
-
- private object PactValueType {
- def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.baseClasses exists { s => s.fullName == "org.apache.flink.types.Value" }
- }
-
- private class UDTAnalyzerCache { // detects recursion on types that are currently being analyzed
-
- private val caches = new DynamicVariable[Map[Type, RecursiveDescriptor]](Map()) // types on the current analysis path -> back-reference
- private val idGen = new Counter
-
- def newId = idGen.next
-
- def getOrElseUpdate(tpe: Type)(orElse: Int => UDTDescriptor): UDTDescriptor = {
-
- val id = idGen.next // a fresh id is drawn even when the type is found in the cache
- val cache = caches.value
-
- cache.get(tpe) map { _.copy(id = id) } getOrElse { // hit: the type is being analyzed further up, so emit a RecursiveDescriptor
- val ref = RecursiveDescriptor(id, tpe, id)
- caches.withValue(cache + (tpe -> ref)) {
- orElse(id)
- }
- }
- }
- }
- }
-
- lazy val primitives = Map[Symbol, (Literal, Type)]( // primitive type symbol -> (default literal, Flink value wrapper type)
- definitions.BooleanClass -> (Literal(Constant(false)), typeOf[BooleanValue]),
- definitions.ByteClass -> (Literal(Constant(0: Byte)), typeOf[ByteValue]),
- definitions.CharClass -> (Literal(Constant(0: Char)), typeOf[CharValue]),
- definitions.DoubleClass -> (Literal(Constant(0: Double)), typeOf[DoubleValue]),
- definitions.FloatClass -> (Literal(Constant(0: Float)), typeOf[FloatValue]),
- definitions.IntClass -> (Literal(Constant(0: Int)), typeOf[IntValue]),
- definitions.LongClass -> (Literal(Constant(0: Long)), typeOf[LongValue]),
- definitions.ShortClass -> (Literal(Constant(0: Short)), typeOf[ShortValue]),
- definitions.StringClass -> (Literal(Constant(null: String)), typeOf[StringValue]))
-
- lazy val boxedPrimitives = { // boxed Java type symbol -> (default, wrapper, box tree builder, unbox tree builder)
-
- def getBoxInfo(prim: Symbol, primName: String, boxName: String) = {
- val (default, wrapper) = primitives(prim)
- val box = { t: Tree =>
- Apply(Select(Select(Ident(newTermName("scala")), newTermName("Predef")), newTermName(primName + "2" + boxName)), List(t))
- }
- val unbox = { t: Tree =>
- Apply(Select(Select(Ident(newTermName("scala")), newTermName("Predef")), newTermName(boxName + "2" + primName)), List(t))
- }
- (default, wrapper, box, unbox)
- }
-
- Map(
- typeOf[java.lang.Boolean].typeSymbol -> getBoxInfo(definitions.BooleanClass, "boolean", "Boolean"),
- typeOf[java.lang.Byte].typeSymbol -> getBoxInfo(definitions.ByteClass, "byte", "Byte"),
- typeOf[java.lang.Character].typeSymbol -> getBoxInfo(definitions.CharClass, "char", "Character"),
- typeOf[java.lang.Double].typeSymbol -> getBoxInfo(definitions.DoubleClass, "double", "Double"),
- typeOf[java.lang.Float].typeSymbol -> getBoxInfo(definitions.FloatClass, "float", "Float"),
- typeOf[java.lang.Integer].typeSymbol -> getBoxInfo(definitions.IntClass, "int", "Integer"),
- typeOf[java.lang.Long].typeSymbol -> getBoxInfo(definitions.LongClass, "long", "Long"),
- typeOf[java.lang.Short].typeSymbol -> getBoxInfo(definitions.ShortClass, "short", "Short"))
- }
-
-}
-
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTDescriptors.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTDescriptors.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTDescriptors.scala
deleted file mode 100644
index e57e7bb..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTDescriptors.scala
+++ /dev/null
@@ -1,158 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.codegen
-
-import scala.language.postfixOps
-
-import scala.reflect.macros.Context
-import scala.reflect.classTag
-import scala.reflect.ClassTag
-import scala.Option.option2Iterable
-
-trait UDTDescriptors[C <: Context] { this: MacroContextHolder[C] =>
- import c.universe._
-
- abstract sealed class UDTDescriptor {
- val id: Int
- val tpe: Type
- val isPrimitiveProduct: Boolean = false
-
- def canBeKey: Boolean
-
- def mkRoot: UDTDescriptor = this
-
- def flatten: Seq[UDTDescriptor]
- def getters: Seq[FieldAccessor] = Seq()
-
- def select(member: String): Option[UDTDescriptor] = getters find { _.getter.name.toString == member } map { _.desc }
-
- def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
- case Nil => Seq(Some(this))
- case head :: tail => getters find { _.getter.name.toString == head } match {
- case None => Seq(None)
- case Some(d : FieldAccessor) => d.desc.select(tail)
- }
- }
-
- def findById(id: Int): Option[UDTDescriptor] = flatten.find { _.id == id }
-
- def findByType[T <: UDTDescriptor: ClassTag]: Seq[T] = {
- val clazz = classTag[T].runtimeClass
- flatten filter { item => clazz.isAssignableFrom(item.getClass) } map { _.asInstanceOf[T] }
- }
-
- def getRecursiveRefs: Seq[UDTDescriptor] = findByType[RecursiveDescriptor] flatMap { rd => findById(rd.refId) } map { _.mkRoot } distinct
- }
-
- case class UnsupportedDescriptor(id: Int, tpe: Type, errors: Seq[String]) extends UDTDescriptor {
- override def flatten = Seq(this)
-
- def canBeKey = false
- }
-
- case class PrimitiveDescriptor(id: Int, tpe: Type, default: Literal, wrapper: Type) extends UDTDescriptor {
- override val isPrimitiveProduct = true
- override def flatten = Seq(this)
- override def canBeKey = wrapper <:< typeOf[org.apache.flink.types.Key[_]]
- }
-
- case class BoxedPrimitiveDescriptor(id: Int, tpe: Type, default: Literal, wrapper: Type, box: Tree => Tree, unbox: Tree => Tree) extends UDTDescriptor {
-
- override val isPrimitiveProduct = true
- override def flatten = Seq(this)
- override def canBeKey = wrapper <:< typeOf[org.apache.flink.types.Key[_]]
-
- override def hashCode() = (id, tpe, default, wrapper, "BoxedPrimitiveDescriptor").hashCode()
- override def equals(that: Any) = that match {
- case BoxedPrimitiveDescriptor(thatId, thatTpe, thatDefault, thatWrapper, _, _) => (id, tpe, default, wrapper).equals(thatId, thatTpe, thatDefault, thatWrapper)
- case _ => false
- }
- }
-
- case class ListDescriptor(id: Int, tpe: Type, iter: Tree => Tree, elem: UDTDescriptor) extends UDTDescriptor {
- override def canBeKey = false
- override def flatten = this +: elem.flatten
-
- def getInnermostElem: UDTDescriptor = elem match {
- case list: ListDescriptor => list.getInnermostElem
- case _ => elem
- }
-
- override def hashCode() = (id, tpe, elem).hashCode()
- override def equals(that: Any) = that match {
- case that @ ListDescriptor(thatId, thatTpe, _, thatElem) => (id, tpe, elem).equals((thatId, thatTpe, thatElem))
- case _ => false
- }
- }
-
- case class BaseClassDescriptor(id: Int, tpe: Type, override val getters: Seq[FieldAccessor], subTypes: Seq[UDTDescriptor]) extends UDTDescriptor {
- override def flatten = this +: ((getters flatMap { _.desc.flatten }) ++ (subTypes flatMap { _.flatten }))
- override def canBeKey = flatten forall { f => f.canBeKey }
-
- override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
- case Nil => getters flatMap { g => g.desc.select(Nil) }
- case head :: tail => getters find { _.getter.name.toString == head } match {
- case None => Seq(None)
- case Some(d : FieldAccessor) => d.desc.select(tail)
- }
- }
- }
-
- case class CaseClassDescriptor(id: Int, tpe: Type, mutable: Boolean, ctor: Symbol, override val getters: Seq[FieldAccessor]) extends UDTDescriptor {
-
- override val isPrimitiveProduct = !getters.isEmpty && getters.forall(_.desc.isPrimitiveProduct)
-
- override def mkRoot = this.copy(getters = getters map { _.copy(isBaseField = false) })
- override def flatten = this +: (getters flatMap { _.desc.flatten })
-
- override def canBeKey = flatten forall { f => f.canBeKey }
-
- // Hack: ignore the ctorTpe, since two Type instances representing
- // the same ctor function type don't appear to be considered equal.
- // Equality of the tpe and ctor fields implies equality of ctorTpe anyway.
- override def hashCode = (id, tpe, ctor, getters).hashCode
- override def equals(that: Any) = that match {
- case CaseClassDescriptor(thatId, thatTpe, thatMutable, thatCtor, thatGetters) => (id, tpe, mutable, ctor, getters).equals(thatId, thatTpe, thatMutable, thatCtor, thatGetters)
- case _ => false
- }
-
- override def select(path: List[String]): Seq[Option[UDTDescriptor]] = path match {
- case Nil => getters flatMap { g => g.desc.select(Nil) }
- case head :: tail => getters find { _.getter.name.toString == head } match {
- case None => Seq(None)
- case Some(d : FieldAccessor) => d.desc.select(tail)
- }
- }
- }
-
- case class FieldAccessor(getter: Symbol, setter: Symbol, tpe: Type, isBaseField: Boolean, desc: UDTDescriptor)
-
- case class RecursiveDescriptor(id: Int, tpe: Type, refId: Int) extends UDTDescriptor {
- override def flatten = Seq(this)
- override def canBeKey = tpe <:< typeOf[org.apache.flink.types.Key[_]]
- }
-
- case class PactValueDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
- override val isPrimitiveProduct = true
- override def flatten = Seq(this)
- override def canBeKey = tpe <:< typeOf[org.apache.flink.types.Key[_]]
- }
-}
-
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTGen.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTGen.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTGen.scala
deleted file mode 100644
index c2293ff..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/UDTGen.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.codegen
-
-import scala.reflect.macros.Context
-
-import org.apache.flink.api.scala.analysis.UDT
-
-import org.apache.flink.types.ListValue
-import org.apache.flink.types.Record
-
-trait UDTGen[C <: Context] { this: MacroContextHolder[C] with UDTDescriptors[C] with UDTAnalyzer[C] with TreeGen[C] with SerializerGen[C] with SerializeMethodGen[C] with DeserializeMethodGen[C] with Loggers[C] =>
- import c.universe._
-
- def mkUdtClass[T: c.WeakTypeTag](): (ClassDef, Tree) = {
- val desc = getUDTDescriptor(weakTypeOf[T])
-
- val udtName = c.fresh[TypeName]("GeneratedUDTDescriptor")
- val udt = mkClass(udtName, Flag.FINAL, List(weakTypeOf[UDT[T]]), {
- val (ser, createSer) = mkUdtSerializerClass[T](creatorName = "createSerializer")
- val ctor = mkMethod(nme.CONSTRUCTOR.toString(), NoFlags, List(), NoType, {
- Block(List(mkSuperCall(Nil)), mkUnit)
- })
-
- List(ser, createSer, ctor, mkFieldTypes(desc), mkUDTIdToIndexMap(desc))
- })
-
- val (_, udtTpe) = typeCheck(udt)
-
- (udt, mkCtorCall(udtTpe, Nil))
- }
-
- private def mkFieldTypes(desc: UDTDescriptor): Tree = {
-
- mkVal("fieldTypes", Flag.OVERRIDE | Flag.FINAL, false, typeOf[Array[Class[_ <: org.apache.flink.types.Value]]], {
-
- val fieldTypes = getIndexFields(desc).toList map {
- case PrimitiveDescriptor(_, _, _, wrapper) => Literal(Constant(wrapper))
- case BoxedPrimitiveDescriptor(_, _, _, wrapper, _, _) => Literal(Constant(wrapper))
- case PactValueDescriptor(_, tpe) => Literal(Constant(tpe))
- case ListDescriptor(_, _, _, _) => Literal(Constant(typeOf[ListValue[org.apache.flink.types.Value]]))
- // Box inner instances of recursive types
- case RecursiveDescriptor(_, _, _) => Literal(Constant(typeOf[Record]))
- case BaseClassDescriptor(_, _, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
- case CaseClassDescriptor(_, _, _, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
- case UnsupportedDescriptor(_, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
- }
- Apply(Select(Select(Ident("scala": TermName), "Array": TermName), "apply": TermName), fieldTypes)
- })
- }
-
- private def mkUDTIdToIndexMap(desc: UDTDescriptor): Tree = {
-
- mkVal("udtIdMap", Flag.OVERRIDE | Flag.FINAL, false, typeOf[Map[Int, Int]], {
-
- val fieldIds = getIndexFields(desc).toList map {
- case PrimitiveDescriptor(id, _, _, _) => Literal(Constant(id))
- case BoxedPrimitiveDescriptor(id, _, _, _, _, _) => Literal(Constant(id))
- case ListDescriptor(id, _, _, _) => Literal(Constant(id))
- case RecursiveDescriptor(id, _, _) => Literal(Constant(id))
- case PactValueDescriptor(id, _) => Literal(Constant(id))
- case BaseClassDescriptor(_, _, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
- case CaseClassDescriptor(_, _, _, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
- case UnsupportedDescriptor(_, _, _) => throw new RuntimeException("Illegal descriptor for basic record field.")
- }
- val fields = fieldIds.zipWithIndex map { case (id, idx) =>
- val idExpr = c.Expr[Int](id)
- val idxExpr = c.Expr[Int](Literal(Constant(idx)))
- reify { (idExpr.splice, idxExpr.splice) }.tree
- }
- Apply(Select(Select(Select(Ident("scala": TermName), "Predef": TermName), "Map": TermName), "apply": TermName), fields)
- })
- }
-
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Util.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Util.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Util.scala
deleted file mode 100644
index 278a5e2..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/Util.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.codegen
-
-import language.experimental.macros
-import scala.reflect.macros.Context
-
-import org.apache.flink.api.scala.analysis.UDT
-
-object Util {
-
- implicit def createUDT[T]: UDT[T] = macro createUDTImpl[T]
-
- def createUDTImpl[T: c.WeakTypeTag](c: Context): c.Expr[UDT[T]] = {
- import c.universe._
-
- val slave = MacroContextHolder.newMacroHelper(c)
-
- val (udt, createUdt) = slave.mkUdtClass[T]
-
- val udtResult = reify {
- c.Expr[UDT[T]](createUdt).splice
- }
-
- c.Expr[UDT[T]](Block(List(udt), udtResult.tree))
- }
-
- // filter out forwards that dont forward from one field position to the same field position
- def filterNonForwards(from: Array[Int], to: Array[Int]): Array[Int] = {
- from.zip(to).filter( z => z._1 == z._2).map { _._1}
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala
new file mode 100644
index 0000000..5218745
--- /dev/null
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/crossDataSet.scala
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala
+
+import org.apache.commons.lang3.Validate
+import org.apache.flink.api.common.functions.{RichCrossFunction, CrossFunction}
+import org.apache.flink.api.common.typeutils.TypeSerializer
+import org.apache.flink.api.java.operators._
+import org.apache.flink.api.java.{DataSet => JavaDataSet}
+import org.apache.flink.api.scala.typeutils.{ScalaTupleSerializer, ScalaTupleTypeInfo}
+import org.apache.flink.types.TypeInformation
+import org.apache.flink.util.Collector
+
+import scala.reflect.ClassTag
+
+/**
+ * A specific [[DataSet]] that results from a `cross` operation. The result of a default cross is a
+ * tuple containing the two values from the two sides of the cartesian product. The result of the
+ * cross can be changed by specifying a custom cross function using the `apply` method or by
+ * providing a [[RichCrossFunction]].
+ *
+ * Example:
+ * {{{
+ * val left = ...
+ * val right = ...
+ * val crossResult = left.cross(right) {
+ * (left, right) => new MyCrossResult(left, right)
+ * }
+ * }}}
+ *
+ * @tparam T Type of the left input of the cross.
+ * @tparam O Type of the right input of the cross.
+ */
+trait CrossDataSet[T, O] extends DataSet[(T, O)] {
+
+ /**
+ * Creates a new [[DataSet]] where the result for each pair of elements is the result
+ * of the given function.
+ */
+ def apply[R: TypeInformation: ClassTag](fun: (T, O) => R): DataSet[R]
+
+ /**
+ * Creates a new [[DataSet]] by passing each pair of values to the given function.
+ * The function can output zero or more elements using the [[Collector]] which will form the
+ * result.
+ *
+ * A [[RichCrossFunction]] can be used to access the
+ * broadcast variables and the [[org.apache.flink.api.common.functions.RuntimeContext]].
+ */
+ def apply[R: TypeInformation: ClassTag](joiner: CrossFunction[T, O, R]): DataSet[R]
+}
+
+/**
+ * Private implementation for [[CrossDataSet]] to keep the implementation details, i.e. the
+ * parameters of the constructor, hidden.
+ */
+private[flink] class CrossDataSetImpl[T, O](
+ crossOperator: CrossOperator[T, O, (T, O)],
+ thisSet: JavaDataSet[T],
+ otherSet: JavaDataSet[O])
+ extends DataSet(crossOperator)
+ with CrossDataSet[T, O] {
+
+ def apply[R: TypeInformation: ClassTag](fun: (T, O) => R): DataSet[R] = {
+ Validate.notNull(fun, "Cross function must not be null.")
+ val crosser = new CrossFunction[T, O, R] {
+ def cross(left: T, right: O): R = {
+ fun(left, right)
+ }
+ }
+ val crossOperator = new CrossOperator[T, O, R](
+ thisSet,
+ otherSet,
+ crosser,
+ implicitly[TypeInformation[R]])
+ wrap(crossOperator)
+ }
+
+ def apply[R: TypeInformation: ClassTag](crosser: CrossFunction[T, O, R]): DataSet[R] = {
+ Validate.notNull(crosser, "Cross function must not be null.")
+ val crossOperator = new CrossOperator[T, O, R](
+ thisSet,
+ otherSet,
+ crosser,
+ implicitly[TypeInformation[R]])
+ wrap(crossOperator)
+ }
+}
+
+private[flink] object CrossDataSetImpl {
+ def createCrossOperator[T, O](leftSet: JavaDataSet[T], rightSet: JavaDataSet[O]) = {
+ val crosser = new CrossFunction[T, O, (T, O)] {
+ def cross(left: T, right: O) = {
+ (left, right)
+ }
+ }
+ val returnType = new ScalaTupleTypeInfo[(T, O)](
+ classOf[(T, O)], Seq(leftSet.getType, rightSet.getType)) {
+
+ override def createSerializer: TypeSerializer[(T, O)] = {
+ val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
+ for (i <- 0 until getArity) {
+ fieldSerializers(i) = types(i).createSerializer
+ }
+
+ new ScalaTupleSerializer[(T, O)](classOf[(T, O)], fieldSerializers) {
+ override def createInstance(fields: Array[AnyRef]) = {
+ (fields(0).asInstanceOf[T], fields(1).asInstanceOf[O])
+ }
+ }
+ }
+ }
+ val crossOperator = new CrossOperator[T, O, (T, O)](leftSet, rightSet, crosser, returnType)
+
+ new CrossDataSetImpl(crossOperator, leftSet, rightSet)
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/CoGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/CoGroupFunction.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/CoGroupFunction.scala
deleted file mode 100644
index 6c7e93b..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/CoGroupFunction.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.api.scala.functions
-
-import java.util.{Iterator => JIterator}
-
-import org.apache.flink.api.scala.analysis.{UDTSerializer, UDT}
-import org.apache.flink.api.scala.analysis.UDF2
-
-import org.apache.flink.api.java.record.functions.{CoGroupFunction => JCoGroupFunction}
-import org.apache.flink.types.Record
-import org.apache.flink.util.Collector
-import org.apache.flink.configuration.Configuration
-
-
-abstract class CoGroupFunctionBase[LeftIn: UDT, RightIn: UDT, Out: UDT] extends JCoGroupFunction with Serializable {
- val leftInputUDT = implicitly[UDT[LeftIn]]
- val rightInputUDT = implicitly[UDT[RightIn]]
- val outputUDT = implicitly[UDT[Out]]
- val udf: UDF2[LeftIn, RightIn, Out] = new UDF2(leftInputUDT, rightInputUDT, outputUDT)
-
- protected val outputRecord = new Record()
-
- protected lazy val leftIterator: DeserializingIterator[LeftIn] = new DeserializingIterator(udf.getLeftInputDeserializer)
- protected lazy val leftForwardFrom: Array[Int] = udf.getLeftForwardIndexArrayFrom
- protected lazy val leftForwardTo: Array[Int] = udf.getLeftForwardIndexArrayTo
- protected lazy val rightIterator: DeserializingIterator[RightIn] = new DeserializingIterator(udf.getRightInputDeserializer)
- protected lazy val rightForwardFrom: Array[Int] = udf.getRightForwardIndexArrayFrom
- protected lazy val rightForwardTo: Array[Int] = udf.getRightForwardIndexArrayTo
- protected lazy val serializer: UDTSerializer[Out] = udf.getOutputSerializer
-
- override def open(config: Configuration) = {
- super.open(config)
-
- this.outputRecord.setNumFields(udf.getOutputLength)
- }
-}
-
-abstract class CoGroupFunction[LeftIn: UDT, RightIn: UDT, Out: UDT] extends CoGroupFunctionBase[LeftIn, RightIn, Out] with Function2[Iterator[LeftIn], Iterator[RightIn], Out] {
- override def coGroup(leftRecords: JIterator[Record], rightRecords: JIterator[Record], out: Collector[Record]) = {
- val firstLeftRecord = leftIterator.initialize(leftRecords)
- val firstRightRecord = rightIterator.initialize(rightRecords)
-
- if (firstRightRecord != null) {
- outputRecord.copyFrom(firstRightRecord, rightForwardFrom, rightForwardTo)
- }
- if (firstLeftRecord != null) {
- outputRecord.copyFrom(firstLeftRecord, leftForwardFrom, leftForwardTo)
- }
-
- val output = apply(leftIterator, rightIterator)
-
- serializer.serialize(output, outputRecord)
- out.collect(outputRecord)
- }
-}
-
-abstract class FlatCoGroupFunction[LeftIn: UDT, RightIn: UDT, Out: UDT] extends CoGroupFunctionBase[LeftIn, RightIn, Out] with Function2[Iterator[LeftIn], Iterator[RightIn], Iterator[Out]] {
- override def coGroup(leftRecords: JIterator[Record], rightRecords: JIterator[Record], out: Collector[Record]) = {
- val firstLeftRecord = leftIterator.initialize(leftRecords)
- outputRecord.copyFrom(firstLeftRecord, leftForwardFrom, leftForwardTo)
-
- val firstRightRecord = rightIterator.initialize(rightRecords)
- outputRecord.copyFrom(firstRightRecord, rightForwardFrom, rightForwardTo)
-
- val output = apply(leftIterator, rightIterator)
-
- if (output.nonEmpty) {
-
- for (item <- output) {
- serializer.serialize(item, outputRecord)
- out.collect(outputRecord)
- }
- }
- }
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/CrossFunction.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/CrossFunction.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/CrossFunction.scala
deleted file mode 100644
index c292100..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/CrossFunction.scala
+++ /dev/null
@@ -1,66 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.functions
-
-import org.apache.flink.api.scala.analysis.{UDTSerializer, UDT}
-import org.apache.flink.api.scala.analysis.UDF2
-
-import org.apache.flink.api.java.record.functions.{CrossFunction => JCrossFunction}
-import org.apache.flink.types.Record
-import org.apache.flink.util.Collector
-
-abstract class CrossFunctionBase[LeftIn: UDT, RightIn: UDT, Out: UDT] extends JCrossFunction with Serializable {
- val leftInputUDT = implicitly[UDT[LeftIn]]
- val rightInputUDT = implicitly[UDT[RightIn]]
- val outputUDT = implicitly[UDT[Out]]
- val udf: UDF2[LeftIn, RightIn, Out] = new UDF2(leftInputUDT, rightInputUDT, outputUDT)
-
- protected lazy val leftDeserializer: UDTSerializer[LeftIn] = udf.getLeftInputDeserializer
- protected lazy val leftForwardFrom: Array[Int] = udf.getLeftForwardIndexArrayFrom
- protected lazy val leftForwardTo: Array[Int] = udf.getLeftForwardIndexArrayTo
- protected lazy val leftDiscard: Array[Int] = udf.getLeftDiscardIndexArray.filter(_ < udf.getOutputLength)
- protected lazy val rightDeserializer: UDTSerializer[RightIn] = udf.getRightInputDeserializer
- protected lazy val rightForwardFrom: Array[Int] = udf.getRightForwardIndexArrayFrom
- protected lazy val rightForwardTo: Array[Int] = udf.getRightForwardIndexArrayTo
- protected lazy val serializer: UDTSerializer[Out] = udf.getOutputSerializer
- protected lazy val outputLength: Int = udf.getOutputLength
-
-}
-
-abstract class CrossFunction[LeftIn: UDT, RightIn: UDT, Out: UDT] extends CrossFunctionBase[LeftIn, RightIn, Out] with Function2[LeftIn, RightIn, Out] {
- override def cross(leftRecord: Record, rightRecord: Record) : Record = {
- val left = leftDeserializer.deserializeRecyclingOn(leftRecord)
- val right = rightDeserializer.deserializeRecyclingOn(rightRecord)
- val output = apply(left, right)
-
- leftRecord.setNumFields(outputLength)
-
- for (field <- leftDiscard)
- leftRecord.setNull(field)
-
- leftRecord.copyFrom(rightRecord, rightForwardFrom, rightForwardTo)
- leftRecord.copyFrom(leftRecord, leftForwardFrom, leftForwardTo)
-
- serializer.serialize(output, leftRecord)
- leftRecord
- }
-}
-
-
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/DeserializingIterator.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/DeserializingIterator.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/DeserializingIterator.scala
deleted file mode 100644
index 0d5c128..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/DeserializingIterator.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.functions
-
-import java.util.{ Iterator => JIterator }
-
-import org.apache.flink.api.scala.analysis.UDTSerializer
-
-import org.apache.flink.types.Record
-
-protected final class DeserializingIterator[T](deserializer: UDTSerializer[T]) extends Iterator[T] {
-
- private var source: JIterator[Record] = null
- private var first: Record = null
- private var fresh = true
-
- final def initialize(records: JIterator[Record]): Record = {
- source = records
-
- if (source.hasNext) {
- fresh = true
- first = source.next()
- } else {
- fresh = false
- first = null
- }
-
- first
- }
-
- final def hasNext = fresh || source.hasNext
-
- final def next(): T = {
-
- if (fresh) {
- fresh = false
- val record = deserializer.deserializeRecyclingOff(first)
- first = null
- record
- } else {
- deserializer.deserializeRecyclingOff(source.next())
- }
- }
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/JoinFunction.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/JoinFunction.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/JoinFunction.scala
deleted file mode 100644
index a1d2e2b..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/JoinFunction.scala
+++ /dev/null
@@ -1,86 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.functions
-
-import org.apache.flink.api.scala.analysis.{UDTSerializer, UDT}
-import org.apache.flink.api.scala.analysis.UDF2
-
-import org.apache.flink.api.java.record.functions.{JoinFunction => JJoinFunction}
-import org.apache.flink.types.Record
-import org.apache.flink.util.Collector
-
-abstract class JoinFunctionBase[LeftIn: UDT, RightIn: UDT, Out: UDT] extends JJoinFunction with Serializable {
- val leftInputUDT = implicitly[UDT[LeftIn]]
- val rightInputUDT = implicitly[UDT[RightIn]]
- val outputUDT = implicitly[UDT[Out]]
- val udf: UDF2[LeftIn, RightIn, Out] = new UDF2(leftInputUDT, rightInputUDT, outputUDT)
-
- protected lazy val leftDeserializer: UDTSerializer[LeftIn] = udf.getLeftInputDeserializer
- protected lazy val leftDiscard: Array[Int] = udf.getLeftDiscardIndexArray.filter(_ < udf.getOutputLength)
- protected lazy val leftForwardFrom: Array[Int] = udf.getLeftForwardIndexArrayFrom
- protected lazy val leftForwardTo: Array[Int] = udf.getLeftForwardIndexArrayTo
- protected lazy val rightDeserializer: UDTSerializer[RightIn] = udf.getRightInputDeserializer
- protected lazy val rightForwardFrom: Array[Int] = udf.getRightForwardIndexArrayFrom
- protected lazy val rightForwardTo: Array[Int] = udf.getRightForwardIndexArrayTo
- protected lazy val serializer: UDTSerializer[Out] = udf.getOutputSerializer
- protected lazy val outputLength: Int = udf.getOutputLength
-}
-
-abstract class JoinFunction[LeftIn: UDT, RightIn: UDT, Out: UDT] extends JoinFunctionBase[LeftIn, RightIn, Out] with Function2[LeftIn, RightIn, Out] {
- override def join(leftRecord: Record, rightRecord: Record, out: Collector[Record]) = {
- val left = leftDeserializer.deserializeRecyclingOn(leftRecord)
- val right = rightDeserializer.deserializeRecyclingOn(rightRecord)
- val output = apply(left, right)
-
- leftRecord.setNumFields(outputLength)
- for (field <- leftDiscard)
- leftRecord.setNull(field)
-
- leftRecord.copyFrom(rightRecord, rightForwardFrom, rightForwardTo)
- leftRecord.copyFrom(leftRecord, leftForwardFrom, leftForwardTo)
-
- serializer.serialize(output, leftRecord)
- out.collect(leftRecord)
- }
-}
-
-abstract class FlatJoinFunction[LeftIn: UDT, RightIn: UDT, Out: UDT] extends JoinFunctionBase[LeftIn, RightIn, Out] with Function2[LeftIn, RightIn, Iterator[Out]] {
- override def join(leftRecord: Record, rightRecord: Record, out: Collector[Record]) = {
- val left = leftDeserializer.deserializeRecyclingOn(leftRecord)
- val right = rightDeserializer.deserializeRecyclingOn(rightRecord)
- val output = apply(left, right)
-
- if (output.nonEmpty) {
-
- leftRecord.setNumFields(outputLength)
-
- for (field <- leftDiscard)
- leftRecord.setNull(field)
-
- leftRecord.copyFrom(rightRecord, rightForwardFrom, rightForwardTo)
- leftRecord.copyFrom(leftRecord, leftForwardFrom, leftForwardTo)
-
- for (item <- output) {
- serializer.serialize(item, leftRecord)
- out.collect(leftRecord)
- }
- }
- }
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/MapFunction.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/MapFunction.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/MapFunction.scala
deleted file mode 100644
index 445d443..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/MapFunction.scala
+++ /dev/null
@@ -1,83 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.functions
-
-import org.apache.flink.api.scala.analysis.{UDTSerializer, UDT}
-import org.apache.flink.api.scala.analysis.UDF1
-
-import org.apache.flink.api.java.record.functions.{MapFunction => JMapFunction}
-import org.apache.flink.types.Record
-import org.apache.flink.util.Collector
-
-/**
- * Shared base for the record-based map UDFs: builds a UDF1 from the implicit input and output
- * UDTs and derives the (de)serializers and output field layout from it.
- */
-abstract class MapFunctionBase[In: UDT, Out: UDT] extends JMapFunction with Serializable{
- val inputUDT: UDT[In] = implicitly[UDT[In]]
- val outputUDT: UDT[Out] = implicitly[UDT[Out]]
- val udf: UDF1[In, Out] = new UDF1(inputUDT, outputUDT)
-
- // Derived lazily from `udf` on first use.
- protected lazy val deserializer: UDTSerializer[In] = udf.getInputDeserializer
- protected lazy val serializer: UDTSerializer[Out] = udf.getOutputSerializer
- // Field indexes that are nulled out before the output is written.
- protected lazy val discard: Array[Int] = udf.getDiscardIndexArray
- // Number of fields the output record is resized to.
- protected lazy val outputLength: Int = udf.getOutputLength
-}
-
-/**
- * Record-based map UDF that emits exactly one record per input. The input record is reused as
- * the output record.
- */
-abstract class MapFunction[In: UDT, Out: UDT] extends MapFunctionBase[In, Out] with Function1[In, Out] {
- override def map(record: Record, out: Collector[Record]) = {
- val input = deserializer.deserializeRecyclingOn(record)
- val output = apply(input)
-
- // Reshape the input record into the output layout before writing the result.
- record.setNumFields(outputLength)
-
- for (field <- discard)
- record.setNull(field)
-
- serializer.serialize(output, record)
- out.collect(record)
- }
-}
-
-/**
- * Record-based map UDF whose `apply` returns an iterator; every element is emitted as its own
- * record, reusing the input record. Nothing is emitted when the iterator is empty.
- */
-abstract class FlatMapFunction[In: UDT, Out: UDT] extends MapFunctionBase[In, Out] with Function1[In, Iterator[Out]] {
- override def map(record: Record, out: Collector[Record]) = {
- val input = deserializer.deserializeRecyclingOn(record)
- val output = apply(input)
-
- if (output.nonEmpty) {
-
- // Prepare the output layout once; each item is then serialized into the same record.
- record.setNumFields(outputLength)
-
- for (field <- discard)
- record.setNull(field)
-
- for (item <- output) {
-
- serializer.serialize(item, record)
- out.collect(record)
- }
- }
- }
-}
-
-/**
- * Record-based filter UDF: the input record is forwarded unchanged when `apply` returns true and
- * dropped otherwise. The `Out` type parameter is only used by the shared base class.
- */
-abstract class FilterFunction[In: UDT, Out: UDT] extends MapFunctionBase[In, Out] with Function1[In, Boolean] {
- override def map(record: Record, out: Collector[Record]) = {
- val input = deserializer.deserializeRecyclingOn(record)
- if (apply(input)) {
- out.collect(record)
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/ReduceFunction.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/ReduceFunction.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/ReduceFunction.scala
deleted file mode 100644
index de7b7d1..0000000
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/functions/ReduceFunction.scala
+++ /dev/null
@@ -1,102 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-package org.apache.flink.api.scala.functions
-
-import scala.Iterator
-
-import java.util.{Iterator => JIterator}
-
-import org.apache.flink.api.scala.analysis.{UDTSerializer, FieldSelector, UDT}
-import org.apache.flink.api.scala.analysis.UDF1
-
-import org.apache.flink.api.java.record.functions.{ReduceFunction => JReduceFunction}
-import org.apache.flink.types.Record
-import org.apache.flink.util.Collector
-
-
-
-/**
- * Shared base for the record-based reduce UDFs: builds a UDF1 from the implicit UDTs and keeps a
- * single output record that is reused for every group.
- */
-abstract class ReduceFunctionBase[In: UDT, Out: UDT] extends JReduceFunction with Serializable {
- val inputUDT: UDT[In] = implicitly[UDT[In]]
- val outputUDT: UDT[Out] = implicitly[UDT[Out]]
- val udf: UDF1[In, Out] = new UDF1(inputUDT, outputUDT)
-
- // Created once per instance and reused as the output record of every reduce/combine call.
- // NOTE(review): fields not overwritten by the forward copy or the serializer may carry over
- // from a previous group — confirm this cannot happen.
- protected val reduceRecord = new Record()
-
- // Iterator adapter that deserializes the incoming records of a group.
- protected lazy val reduceIterator: DeserializingIterator[In] = new DeserializingIterator(udf.getInputDeserializer)
- protected lazy val reduceSerializer: UDTSerializer[Out] = udf.getOutputSerializer
- // Source/target field indexes copied from a group's first record into the output record.
- protected lazy val reduceForwardFrom: Array[Int] = udf.getForwardIndexArrayFrom
- protected lazy val reduceForwardTo: Array[Int] = udf.getForwardIndexArrayTo
-}
-
-/**
- * Record-based pairwise reduce UDF: folds each group with the binary `apply` and emits one record.
- * Combining uses the same logic as reducing.
- */
-abstract class ReduceFunction[In: UDT] extends ReduceFunctionBase[In, In] with Function2[In, In, In] {
-
- // The combiner simply runs the full reduce on the partial group.
- override def combine(records: JIterator[Record], out: Collector[Record]) = {
- reduce(records, out)
- }
-
- override def reduce(records: JIterator[Record], out: Collector[Record]) = {
- // `initialize` presumably returns the group's first record — TODO confirm; its forwarded
- // fields are copied into the reused output record.
- val firstRecord = reduceIterator.initialize(records)
- reduceRecord.copyFrom(firstRecord, reduceForwardFrom, reduceForwardTo)
-
- val output = reduceIterator.reduce(apply)
-
- reduceSerializer.serialize(output, reduceRecord)
- out.collect(reduceRecord)
- }
-}
-
-/**
- * Record-based group reduce UDF: `apply` receives an iterator over the whole group and returns a
- * single result, which is emitted as one record.
- */
-abstract class GroupReduceFunction[In: UDT, Out: UDT] extends ReduceFunctionBase[In, Out] with Function1[Iterator[In], Out] {
- override def reduce(records: JIterator[Record], out: Collector[Record]) = {
- // Forwarded fields are taken from the record returned by `initialize`.
- val firstRecord = reduceIterator.initialize(records)
- reduceRecord.copyFrom(firstRecord, reduceForwardFrom, reduceForwardTo)
-
- val output = apply(reduceIterator)
-
- reduceSerializer.serialize(output, reduceRecord)
- out.collect(reduceRecord)
- }
-}
-
-/**
- * Record-based group reduce UDF with a separate combine step. Subclasses implement the
- * iterator-based `reduce` and `combine`; `apply` is not used by this class.
- */
-abstract class CombinableGroupReduceFunction[In: UDT, Out: UDT] extends ReduceFunctionBase[In, Out] with Function1[Iterator[In], Out] {
- // Record-level combine: delegates to the user's iterator-based `combine` overload.
- override def combine(records: JIterator[Record], out: Collector[Record]) = {
- val firstRecord = reduceIterator.initialize(records)
- reduceRecord.copyFrom(firstRecord, reduceForwardFrom, reduceForwardTo)
-
- val output = combine(reduceIterator)
-
- reduceSerializer.serialize(output, reduceRecord)
- out.collect(reduceRecord)
- }
-
- // Record-level reduce: delegates to the user's iterator-based `reduce` overload.
- override def reduce(records: JIterator[Record], out: Collector[Record]) = {
- val firstRecord = reduceIterator.initialize(records)
- reduceRecord.copyFrom(firstRecord, reduceForwardFrom, reduceForwardTo)
-
- val output = reduce(reduceIterator)
-
- reduceSerializer.serialize(output, reduceRecord)
- out.collect(reduceRecord)
- }
-
- def reduce(records: Iterator[In]): Out
- def combine(records: Iterator[In]): Out
-
- // Required by Function1 but superseded by the reduce/combine overloads above.
- def apply(record: Iterator[In]): Out = throw new RuntimeException("This should never be called.")
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/b8131fa7/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
new file mode 100644
index 0000000..8d24ee1
--- /dev/null
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/joinDataSet.scala
@@ -0,0 +1,232 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala
+
+import org.apache.commons.lang3.Validate
+import org.apache.flink.api.common.InvalidProgramException
+import org.apache.flink.api.common.functions.{JoinFunction, RichFlatJoinFunction, FlatJoinFunction}
+import org.apache.flink.api.common.typeutils.TypeSerializer
+import org.apache.flink.api.java.operators.JoinOperator.DefaultJoin.WrappingFlatJoinFunction
+import org.apache.flink.api.java.operators.JoinOperator.{EquiJoin, JoinHint}
+import org.apache.flink.api.java.operators._
+import org.apache.flink.api.java.{DataSet => JavaDataSet}
+import org.apache.flink.api.scala.typeutils.{ScalaTupleSerializer, ScalaTupleTypeInfo}
+import org.apache.flink.types.TypeInformation
+import org.apache.flink.util.Collector
+
+import scala.reflect.ClassTag
+
+/**
+ * A specific [[DataSet]] that results from a `join` operation. The result of a default join is a
+ * tuple containing the two values from the two sides of the join. The result of the join can be
+ * changed by specifying a custom join function using the `apply` method or by providing a
+ * [[RichFlatJoinFunction]].
+ *
+ * Example:
+ * {{{
+ *   val left = ...
+ *   val right = ...
+ *   val joinResult = left.join(right).where(0, 2).isEqualTo(0, 1) {
+ *     (left, right) => new MyJoinResult(left, right)
+ *   }
+ * }}}
+ *
+ * Or, using key selector functions with tuple data types:
+ * {{{
+ *   val left = ...
+ *   val right = ...
+ *   val joinResult = left.join(right).where({_._1}).isEqualTo({_._1}) {
+ *     (left, right) => new MyJoinResult(left, right)
+ *   }
+ * }}}
+ *
+ * @tparam T Type of the left input of the join.
+ * @tparam O Type of the right input of the join.
+ */
+trait JoinDataSet[T, O] extends DataSet[(T, O)] {
+
+  /**
+   * Creates a new [[DataSet]] where the result for each pair of joined elements is the result
+   * of the given function. You can either return an element or choose to return [[None]],
+   * which allows implementing a filter directly in the join function.
+   *
+   * @tparam R The element type of the resulting [[DataSet]].
+   */
+  def apply[R: TypeInformation: ClassTag](fun: (T, O) => Option[R]): DataSet[R]
+
+  /**
+   * Creates a new [[DataSet]] by passing each pair of joined values to the given function.
+   * The function can output zero or more elements using the [[Collector]] which will form the
+   * result.
+   *
+   * @tparam R The element type of the resulting [[DataSet]].
+   */
+  def apply[R: TypeInformation: ClassTag](fun: (T, O, Collector[R]) => Unit): DataSet[R]
+
+  /**
+   * Creates a new [[DataSet]] by passing each pair of joined values to the given function.
+   * The function can output zero or more elements using the [[Collector]] which will form the
+   * result.
+   *
+   * A [[RichFlatJoinFunction]] can be used to access the
+   * broadcast variables and the [[org.apache.flink.api.common.functions.RuntimeContext]].
+   *
+   * @tparam R The element type of the resulting [[DataSet]].
+   */
+  def apply[R: TypeInformation: ClassTag](joiner: FlatJoinFunction[T, O, R]): DataSet[R]
+
+  /**
+   * Creates a new [[DataSet]] by passing each pair of joined values to the given function.
+   * The function must output exactly one value for each pair; together, these values form the
+   * new DataSet.
+   *
+   * A [[org.apache.flink.api.common.functions.RichJoinFunction]] can be used to access the
+   * broadcast variables and the [[org.apache.flink.api.common.functions.RuntimeContext]].
+   *
+   * @tparam R The element type of the resulting [[DataSet]].
+   */
+  def apply[R: TypeInformation: ClassTag](joiner: JoinFunction[T, O, R]): DataSet[R]
+}
+
+/**
+ * Private implementation for [[JoinDataSet]] to keep the implementation details, i.e. the
+ * parameters of the constructor, hidden.
+ *
+ * @param joinOperator The default join, which emits `(left, right)` tuples.
+ * @param thisSet The left input of the join.
+ * @param otherSet The right input of the join.
+ * @param thisKeys The join keys of the left input.
+ * @param otherKeys The join keys of the right input.
+ * @param joinHint The hint used for every join created by the `apply` methods. Pass the same
+ *                 hint that was given to `joinOperator`; otherwise a hint chosen by the user
+ *                 would be dropped as soon as a custom join function is applied.
+ */
+private[flink] class JoinDataSetImpl[T, O](
+    joinOperator: EquiJoin[T, O, (T, O)],
+    thisSet: JavaDataSet[T],
+    otherSet: JavaDataSet[O],
+    thisKeys: Keys[T],
+    otherKeys: Keys[O],
+    joinHint: JoinHint = JoinHint.OPTIMIZER_CHOOSES)
+  extends DataSet(joinOperator)
+  with JoinDataSet[T, O] {
+
+  def apply[R: TypeInformation: ClassTag](fun: (T, O) => Option[R]): DataSet[R] = {
+    Validate.notNull(fun, "Join function must not be null.")
+    val joiner = new FlatJoinFunction[T, O, R] {
+      def join(left: T, right: O, out: Collector[R]): Unit = {
+        // Emit only when the user function returns Some; None filters the pair out.
+        fun(left, right).foreach(out.collect(_))
+      }
+    }
+    createJoin(joiner)
+  }
+
+  def apply[R: TypeInformation: ClassTag](fun: (T, O, Collector[R]) => Unit): DataSet[R] = {
+    Validate.notNull(fun, "Join function must not be null.")
+    val joiner = new FlatJoinFunction[T, O, R] {
+      def join(left: T, right: O, out: Collector[R]): Unit = {
+        fun(left, right, out)
+      }
+    }
+    createJoin(joiner)
+  }
+
+  def apply[R: TypeInformation: ClassTag](joiner: FlatJoinFunction[T, O, R]): DataSet[R] = {
+    Validate.notNull(joiner, "Join function must not be null.")
+    createJoin(joiner)
+  }
+
+  def apply[R: TypeInformation: ClassTag](fun: JoinFunction[T, O, R]): DataSet[R] = {
+    Validate.notNull(fun, "Join function must not be null.")
+    // The join operator is built from a FlatJoinFunction, so adapt the one-output JoinFunction.
+    val generatedFunction: FlatJoinFunction[T, O, R] = new WrappingFlatJoinFunction[T, O, R](fun)
+    createJoin(generatedFunction)
+  }
+
+  /**
+   * Creates the equi-join of both inputs on the stored keys with the given flat join function
+   * and this data set's join hint, and wraps the Java operator in a Scala [[DataSet]].
+   */
+  private def createJoin[R: TypeInformation: ClassTag](
+      joiner: FlatJoinFunction[T, O, R]): DataSet[R] = {
+    val operator = new EquiJoin[T, O, R](thisSet, otherSet, thisKeys, otherKeys, joiner,
+      implicitly[TypeInformation[R]], joinHint)
+    wrap(operator)
+  }
+}
+
+/**
+ * An unfinished join operation that results from [[DataSet.join()]]. The keys for the left and
+ * right side must be specified using first `where` and then `isEqualTo`. For example:
+ *
+ * {{{
+ *   val left = ...
+ *   val right = ...
+ *   val joinResult = left.join(right).where(...).isEqualTo(...)
+ * }}}
+ * @tparam T The type of the left input of the join.
+ * @tparam O The type of the right input of the join.
+ */
+trait UnfinishedJoinOperation[T, O] extends UnfinishedKeyPairOperation[T, O, JoinDataSet[T, O]]
+
+/**
+ * Private implementation for [[UnfinishedJoinOperation]] to keep the implementation details,
+ * i.e. the parameters of the constructor, hidden.
+ */
+private[flink] class UnfinishedJoinOperationImpl[T, O](
+    leftSet: JavaDataSet[T],
+    rightSet: JavaDataSet[O],
+    joinHint: JoinHint)
+  extends UnfinishedKeyPairOperation[T, O, JoinDataSet[T, O]](leftSet, rightSet)
+  with UnfinishedJoinOperation[T, O] {
+
+  /**
+   * Completes the join once both keys are known. It builds the default join, which emits
+   * `(left, right)` tuples, and checks the keys of any solution set input. It then returns a
+   * [[JoinDataSet]] that uses the same join hint for custom join functions applied later.
+   */
+  private[flink] def finish(leftKey: Keys[T], rightKey: Keys[O]) = {
+    val joiner = new FlatJoinFunction[T, O, (T, O)] {
+      def join(left: T, right: O, out: Collector[(T, O)]): Unit = {
+        out.collect((left, right))
+      }
+    }
+    val returnType = new ScalaTupleTypeInfo[(T, O)](
+      classOf[(T, O)], Seq(leftSet.getType, rightSet.getType)) {
+
+      override def createSerializer: TypeSerializer[(T, O)] = {
+        val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](getArity)
+        for (i <- 0 until getArity()) {
+          fieldSerializers(i) = types(i).createSerializer
+        }
+
+        // createInstance assembles the (T, O) pair from the two field values.
+        new ScalaTupleSerializer[(T, O)](classOf[(T, O)], fieldSerializers) {
+          override def createInstance(fields: Array[AnyRef]) = {
+            (fields(0).asInstanceOf[T], fields(1).asInstanceOf[O])
+          }
+        }
+      }
+    }
+    val joinOperator = new EquiJoin[T, O, (T, O)](
+      leftSet, rightSet, leftKey, rightKey, joiner, returnType, joinHint)
+
+    // sanity check solution set key mismatches on both inputs
+    checkSolutionSetKeys(leftSet, leftKey)
+    checkSolutionSetKeys(rightSet, rightKey)
+
+    new JoinDataSetImpl(joinOperator, leftSet, rightSet, leftKey, rightKey, joinHint)
+  }
+
+  /**
+   * If `input` is a solution set placeholder, requires `keys` to be field position keys. Their
+   * logical positions are then passed to the placeholder's `checkJoinKeyFields`. Any other input
+   * is accepted without checks.
+   *
+   * @throws InvalidProgramException if a solution set is joined on keys that are not field
+   *                                 positions.
+   */
+  private def checkSolutionSetKeys(input: JavaDataSet[_], keys: Keys[_]): Unit = input match {
+    case solutionSet: DeltaIteration.SolutionSetPlaceHolder[_] =>
+      keys match {
+        case keyFields: Keys.FieldPositionKeys[_] =>
+          val positions: Array[Int] = keyFields.computeLogicalKeyPositions
+          solutionSet.checkJoinKeyFields(positions)
+        case _ =>
+          throw new InvalidProgramException("Currently, the solution set may only be joined " +
+            "with using tuple field positions.")
+      }
+    case _ =>
+  }
+}
\ No newline at end of file