Posted to commits@flink.apache.org by fh...@apache.org on 2017/04/06 16:47:23 UTC
[1/5] flink git commit: [FLINK-6257] [table] Consistent naming of ProcessFunction and methods for OVER windows.
Repository: flink
Updated Branches:
refs/heads/master 5ff9c99ff -> c5173fa26
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala
new file mode 100644
index 0000000..525d4d7
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeUnboundedOver.scala
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.util
+import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.types.Row
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.java.typeutils.ListTypeInfo
+import org.apache.flink.streaming.api.operators.TimestampedCollector
+import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
+import org.slf4j.LoggerFactory
+
+
+/**
+ * A ProcessFunction to support unbounded event-time OVER windows.
+ *
+ * @param genAggregations Generated aggregate helper function
+ * @param intermediateType the type of the intermediate rows kept in state
+ * @param inputType the type of the input rows kept in state
+ */
+abstract class RowTimeUnboundedOver(
+ genAggregations: GeneratedAggregationsFunction,
+ intermediateType: TypeInformation[Row],
+ inputType: TypeInformation[Row])
+ extends ProcessFunction[Row, Row]
+ with Compiler[GeneratedAggregations] {
+
+ protected var output: Row = _
+ // state to hold the accumulators of the aggregations
+ private var accumulatorState: ValueState[Row] = _
+ // state to hold rows until the next watermark arrives
+ private var rowMapState: MapState[Long, JList[Row]] = _
+ // list to sort timestamps to access rows in timestamp order
+ private var sortedTimestamps: util.LinkedList[Long] = _
+
+ val LOG = LoggerFactory.getLogger(this.getClass)
+ protected var function: GeneratedAggregations = _
+
+ override def open(config: Configuration) {
+ LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+ s"Code:\n$genAggregations.code")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader,
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+
+ output = function.createOutputRow()
+ sortedTimestamps = new util.LinkedList[Long]()
+
+ // initialize accumulator state
+ val accDescriptor: ValueStateDescriptor[Row] =
+ new ValueStateDescriptor[Row]("accumulatorstate", intermediateType)
+ accumulatorState = getRuntimeContext.getState[Row](accDescriptor)
+
+ // initialize row state
+ val rowListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputType)
+ val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+ new MapStateDescriptor[Long, JList[Row]]("rowmapstate",
+ BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo)
+ rowMapState = getRuntimeContext.getMapState(mapStateDescriptor)
+ }
+
+ /**
+ * Puts an element from the input stream into state if it is not late.
+ * Registers a timer for the next watermark.
+ *
+ * @param input The input value.
+ * @param ctx The context for registering timers and querying the current time.
+ * @param out The collector for returning result values.
+ *
+ */
+ override def processElement(
+ input: Row,
+ ctx: ProcessFunction[Row, Row]#Context,
+ out: Collector[Row]): Unit = {
+
+ val timestamp = ctx.timestamp()
+ val curWatermark = ctx.timerService().currentWatermark()
+
+ // discard late record
+ if (timestamp >= curWatermark) {
+ // ensure that each key registers at most one timer (for the next watermark)
+ ctx.timerService.registerEventTimeTimer(curWatermark + 1)
+
+ // put row into state
+ var rowList = rowMapState.get(timestamp)
+ if (rowList == null) {
+ rowList = new util.ArrayList[Row]()
+ }
+ rowList.add(input)
+ rowMapState.put(timestamp, rowList)
+ }
+ }
+
+ /**
+ * Called when a watermark arrives.
+ * Sorts records according to their timestamp, computes aggregates, and emits all records
+ * with a timestamp smaller than or equal to the watermark in timestamp order.
+ *
+ * @param timestamp The timestamp of the firing timer.
+ * @param ctx The context for registering timers and querying the current time.
+ * @param out The collector for returning result values.
+ */
+ override def onTimer(
+ timestamp: Long,
+ ctx: ProcessFunction[Row, Row]#OnTimerContext,
+ out: Collector[Row]): Unit = {
+
+ Preconditions.checkArgument(out.isInstanceOf[TimestampedCollector[Row]])
+ val collector = out.asInstanceOf[TimestampedCollector[Row]]
+
+ val keyIterator = rowMapState.keys.iterator
+ if (keyIterator.hasNext) {
+ val curWatermark = ctx.timerService.currentWatermark
+ var existEarlyRecord: Boolean = false
+
+ // sort the record timestamps
+ do {
+ val recordTime = keyIterator.next
+ // only take timestamps smaller/equal to the watermark
+ if (recordTime <= curWatermark) {
+ insertToSortedList(recordTime)
+ } else {
+ existEarlyRecord = true
+ }
+ } while (keyIterator.hasNext)
+
+ // get last accumulator
+ var lastAccumulator = accumulatorState.value
+ if (lastAccumulator == null) {
+ // initialize accumulator
+ lastAccumulator = function.createAccumulators()
+ }
+
+ // emit the rows in order
+ while (!sortedTimestamps.isEmpty) {
+ val curTimestamp = sortedTimestamps.removeFirst()
+ val curRowList = rowMapState.get(curTimestamp)
+ collector.setAbsoluteTimestamp(curTimestamp)
+
+ // process rows with the same timestamp; the mechanism differs between ROWS and RANGE
+ processElementsWithSameTimestamp(curRowList, lastAccumulator, collector)
+
+ rowMapState.remove(curTimestamp)
+ }
+
+ accumulatorState.update(lastAccumulator)
+
+ // if there are rows with timestamp > watermark, register a timer for the next watermark
+ if (existEarlyRecord) {
+ ctx.timerService.registerEventTimeTimer(curWatermark + 1)
+ }
+ }
+ }
+
+ /**
+ * Inserts timestamps in order into a linked list.
+ *
+ * If timestamps arrive in order (as is the case when using the RocksDB state backend),
+ * this is just an O(1) append.
+ */
+ private def insertToSortedList(recordTimestamp: Long) = {
+ val listIterator = sortedTimestamps.listIterator(sortedTimestamps.size)
+ var continue = true
+ while (listIterator.hasPrevious && continue) {
+ val timestamp = listIterator.previous
+ if (recordTimestamp >= timestamp) {
+ listIterator.next
+ listIterator.add(recordTimestamp)
+ continue = false
+ }
+ }
+
+ if (continue) {
+ sortedTimestamps.addFirst(recordTimestamp)
+ }
+ }
+
+ /**
+ * Processes all rows with the same timestamp. The mechanism differs between
+ * ROWS and RANGE windows.
+ */
+ def processElementsWithSameTimestamp(
+ curRowList: JList[Row],
+ lastAccumulator: Row,
+ out: Collector[Row]): Unit
+
+}
+
+/**
+ * A ProcessFunction to support unbounded ROWS window.
+ * The ROWS clause defines on a physical level how many rows are included in a window frame.
+ */
+class RowTimeUnboundedRowsOver(
+ genAggregations: GeneratedAggregationsFunction,
+ intermediateType: TypeInformation[Row],
+ inputType: TypeInformation[Row])
+ extends RowTimeUnboundedOver(
+ genAggregations: GeneratedAggregationsFunction,
+ intermediateType,
+ inputType) {
+
+ override def processElementsWithSameTimestamp(
+ curRowList: JList[Row],
+ lastAccumulator: Row,
+ out: Collector[Row]): Unit = {
+
+ var i = 0
+ while (i < curRowList.size) {
+ val curRow = curRowList.get(i)
+
+ // copy forwarded fields to output row
+ function.setForwardedFields(curRow, output)
+
+ // update accumulators and copy aggregates to output row
+ function.accumulate(lastAccumulator, curRow)
+ function.setAggregationResults(lastAccumulator, output)
+ // emit output row
+ out.collect(output)
+ i += 1
+ }
+ }
+}
+
+
+/**
+ * A ProcessFunction to support unbounded RANGE window.
+ * The RANGE option includes all the rows within the window frame
+ * that have the same ORDER BY values as the current row.
+ */
+class RowTimeUnboundedRangeOver(
+ genAggregations: GeneratedAggregationsFunction,
+ intermediateType: TypeInformation[Row],
+ inputType: TypeInformation[Row])
+ extends RowTimeUnboundedOver(
+ genAggregations: GeneratedAggregationsFunction,
+ intermediateType,
+ inputType) {
+
+ override def processElementsWithSameTimestamp(
+ curRowList: JList[Row],
+ lastAccumulator: Row,
+ out: Collector[Row]): Unit = {
+
+ var i = 0
+ // all rows with the same timestamp receive the same aggregation result.
+ while (i < curRowList.size) {
+ val curRow = curRowList.get(i)
+
+ function.accumulate(lastAccumulator, curRow)
+ i += 1
+ }
+
+ // emit output row
+ i = 0
+ while (i < curRowList.size) {
+ val curRow = curRowList.get(i)
+
+ // copy forwarded fields to output row
+ function.setForwardedFields(curRow, output)
+
+ // copy aggregates to output row
+ function.setAggregationResults(lastAccumulator, output)
+ out.collect(output)
+ i += 1
+ }
+ }
+}
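
For orientation, this is roughly how the planner applies one of these functions to a stream (compare the DataStreamOverAggregate changes further below). A minimal sketch; the generated function and the type information are passed in as parameters here because they come out of the code generator and the type factory:

    import org.apache.flink.api.common.typeinfo.TypeInformation
    import org.apache.flink.api.java.typeutils.RowTypeInfo
    import org.apache.flink.streaming.api.datastream.DataStream
    import org.apache.flink.table.codegen.GeneratedAggregationsFunction
    import org.apache.flink.types.Row

    def applyUnboundedRowsOver(
        inputDS: DataStream[Row],
        partitionKeys: Array[Int],
        genAggFunction: GeneratedAggregationsFunction,
        accType: TypeInformation[Row],
        inputType: TypeInformation[Row],
        returnType: RowTypeInfo): DataStream[Row] = {
      inputDS
        // event-time timers need a keyed stream; the non-partitioned case
        // uses a NullByteKeySelector instead (see DataStreamOverAggregate)
        .keyBy(partitionKeys: _*)
        .process(new RowTimeUnboundedRowsOver(genAggFunction, accType, inputType))
        .returns(returnType)
        .asInstanceOf[DataStream[Row]]
    }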
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala
deleted file mode 100644
index 4539164..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowsClauseBoundedOverProcessFunction.scala
+++ /dev/null
@@ -1,222 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import java.util
-import java.util.{List => JList}
-
-import org.apache.flink.api.common.state._
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.functions.ProcessFunction
-import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
-import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
-import org.slf4j.LoggerFactory
-
-/**
- * Process Function for ROWS clause event-time bounded OVER window
- *
- * @param genAggregations Generated aggregate helper function
- * @param aggregationStateType row type info of aggregation
- * @param inputRowType row type info of input row
- * @param precedingOffset preceding offset
- */
-class RowsClauseBoundedOverProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- aggregationStateType: RowTypeInfo,
- inputRowType: RowTypeInfo,
- precedingOffset: Long)
- extends ProcessFunction[Row, Row]
- with Compiler[GeneratedAggregations] {
-
- Preconditions.checkNotNull(aggregationStateType)
- Preconditions.checkNotNull(precedingOffset)
-
- private var output: Row = _
-
- // the state which keeps the last triggering timestamp
- private var lastTriggeringTsState: ValueState[Long] = _
-
- // the state which keeps the count of data
- private var dataCountState: ValueState[Long] = _
-
- // the state which used to materialize the accumulator for incremental calculation
- private var accumulatorState: ValueState[Row] = _
-
- // the state which keeps all the data that are not expired.
- // The first element (as the mapState key) of the tuple is the time stamp. Per each time stamp,
- // the second element of tuple is a list that contains the entire data of all the rows belonging
- // to this time stamp.
- private var dataState: MapState[Long, JList[Row]] = _
-
- val LOG = LoggerFactory.getLogger(this.getClass)
- private var function: GeneratedAggregations = _
-
- override def open(config: Configuration) {
- LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
- s"Code:\n$genAggregations.code")
- val clazz = compile(
- getRuntimeContext.getUserCodeClassLoader,
- genAggregations.name,
- genAggregations.code)
- LOG.debug("Instantiating AggregateHelper.")
- function = clazz.newInstance()
-
- output = function.createOutputRow()
-
- val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] =
- new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long])
- lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)
-
- val dataCountStateDescriptor =
- new ValueStateDescriptor[Long]("dataCountState", classOf[Long])
- dataCountState = getRuntimeContext.getState(dataCountStateDescriptor)
-
- val accumulatorStateDescriptor =
- new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType)
- accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor)
-
- val keyTypeInformation: TypeInformation[Long] =
- BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
- val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType)
-
- val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
- new MapStateDescriptor[Long, JList[Row]](
- "dataState",
- keyTypeInformation,
- valueTypeInformation)
-
- dataState = getRuntimeContext.getMapState(mapStateDescriptor)
- }
-
- override def processElement(
- input: Row,
- ctx: ProcessFunction[Row, Row]#Context,
- out: Collector[Row]): Unit = {
-
- // triggering timestamp for trigger calculation
- val triggeringTs = ctx.timestamp
-
- val lastTriggeringTs = lastTriggeringTsState.value
- // check if the data is expired, if not, save the data and register event time timer
-
- if (triggeringTs > lastTriggeringTs) {
- val data = dataState.get(triggeringTs)
- if (null != data) {
- data.add(input)
- dataState.put(triggeringTs, data)
- } else {
- val data = new util.ArrayList[Row]
- data.add(input)
- dataState.put(triggeringTs, data)
- // register event time timer
- ctx.timerService.registerEventTimeTimer(triggeringTs)
- }
- }
- }
-
- override def onTimer(
- timestamp: Long,
- ctx: ProcessFunction[Row, Row]#OnTimerContext,
- out: Collector[Row]): Unit = {
-
- // gets all window data from state for the calculation
- val inputs: JList[Row] = dataState.get(timestamp)
-
- if (null != inputs) {
-
- var accumulators = accumulatorState.value
- var dataCount = dataCountState.value
-
- var retractList: JList[Row] = null
- var retractTs: Long = Long.MaxValue
- var retractCnt: Int = 0
- var i = 0
-
- while (i < inputs.size) {
- val input = inputs.get(i)
-
- // initialize when first run or failover recovery per key
- if (null == accumulators) {
- accumulators = function.createAccumulators()
- }
-
- var retractRow: Row = null
-
- if (dataCount >= precedingOffset) {
- if (null == retractList) {
- // find the smallest timestamp
- retractTs = Long.MaxValue
- val dataTimestampIt = dataState.keys.iterator
- while (dataTimestampIt.hasNext) {
- val dataTs = dataTimestampIt.next
- if (dataTs < retractTs) {
- retractTs = dataTs
- }
- }
- // get the oldest rows to retract them
- retractList = dataState.get(retractTs)
- }
-
- retractRow = retractList.get(retractCnt)
- retractCnt += 1
-
- // remove retracted values from state
- if (retractList.size == retractCnt) {
- dataState.remove(retractTs)
- retractList = null
- retractCnt = 0
- }
- } else {
- dataCount += 1
- }
-
- // copy forwarded fields to output row
- function.setForwardedFields(input, output)
-
- // retract old row from accumulators
- if (null != retractRow) {
- function.retract(accumulators, retractRow)
- }
-
- // accumulate current row and set aggregate in output row
- function.accumulate(accumulators, input)
- function.setAggregationResults(accumulators, output)
- i += 1
-
- out.collect(output)
- }
-
- // update all states
- if (dataState.contains(retractTs)) {
- if (retractCnt > 0) {
- retractList.subList(0, retractCnt).clear()
- dataState.put(retractTs, retractList)
- }
- }
- dataCountState.update(dataCount)
- accumulatorState.update(accumulators)
- }
-
- lastTriggeringTsState.update(timestamp)
- }
-}
-
-
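The heart of the function removed above (its successor is RowTimeBoundedRowsOver, see the file list below) is the accumulate/retract pattern: every new row is accumulated, and once precedingOffset rows have been seen, the oldest buffered row is retracted again. A minimal sketch of that idea with a hardcoded sum in place of the generated aggregate code:

    import scala.collection.mutable

    // Running sum over the last windowSize rows, i.e. the frame
    // ROWS BETWEEN (windowSize - 1) PRECEDING AND CURRENT ROW.
    class BoundedRowsSum(windowSize: Int) {
      private val frame = mutable.Queue[Long]()
      private var sum = 0L

      def add(value: Long): Long = {
        sum += value                  // like function.accumulate(...)
        frame.enqueue(value)
        if (frame.size > windowSize) {
          sum -= frame.dequeue()      // like function.retract(...)
        }
        sum
      }
    }

    // e.g. with windowSize = 3:
    // val agg = new BoundedRowsSum(3)
    // Seq(1L, 2L, 3L, 4L).map(agg.add)  // List(1, 3, 6, 9)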
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedEventTimeOverProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedEventTimeOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedEventTimeOverProcessFunction.scala
deleted file mode 100644
index cca3e3f..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedEventTimeOverProcessFunction.scala
+++ /dev/null
@@ -1,292 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import java.util
-import java.util.{List => JList}
-
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.types.Row
-import org.apache.flink.streaming.api.functions.ProcessFunction
-import org.apache.flink.util.{Collector, Preconditions}
-import org.apache.flink.api.common.state._
-import org.apache.flink.api.java.typeutils.ListTypeInfo
-import org.apache.flink.streaming.api.operators.TimestampedCollector
-import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
-import org.slf4j.LoggerFactory
-
-
-/**
- * A ProcessFunction to support unbounded event-time over-window
- *
- * @param genAggregations Generated aggregate helper function
- * @param intermediateType the intermediate row tye which the state saved
- * @param inputType the input row tye which the state saved
- */
-abstract class UnboundedEventTimeOverProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- intermediateType: TypeInformation[Row],
- inputType: TypeInformation[Row])
- extends ProcessFunction[Row, Row]
- with Compiler[GeneratedAggregations] {
-
- protected var output: Row = _
- // state to hold the accumulators of the aggregations
- private var accumulatorState: ValueState[Row] = _
- // state to hold rows until the next watermark arrives
- private var rowMapState: MapState[Long, JList[Row]] = _
- // list to sort timestamps to access rows in timestamp order
- private var sortedTimestamps: util.LinkedList[Long] = _
-
- val LOG = LoggerFactory.getLogger(this.getClass)
- protected var function: GeneratedAggregations = _
-
- override def open(config: Configuration) {
- LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
- s"Code:\n$genAggregations.code")
- val clazz = compile(
- getRuntimeContext.getUserCodeClassLoader,
- genAggregations.name,
- genAggregations.code)
- LOG.debug("Instantiating AggregateHelper.")
- function = clazz.newInstance()
-
- output = function.createOutputRow()
- sortedTimestamps = new util.LinkedList[Long]()
-
- // initialize accumulator state
- val accDescriptor: ValueStateDescriptor[Row] =
- new ValueStateDescriptor[Row]("accumulatorstate", intermediateType)
- accumulatorState = getRuntimeContext.getState[Row](accDescriptor)
-
- // initialize row state
- val rowListTypeInfo: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputType)
- val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
- new MapStateDescriptor[Long, JList[Row]]("rowmapstate",
- BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo)
- rowMapState = getRuntimeContext.getMapState(mapStateDescriptor)
- }
-
- /**
- * Puts an element from the input stream into state if it is not late.
- * Registers a timer for the next watermark.
- *
- * @param input The input value.
- * @param ctx The ctx to register timer or get current time
- * @param out The collector for returning result values.
- *
- */
- override def processElement(
- input: Row,
- ctx: ProcessFunction[Row, Row]#Context,
- out: Collector[Row]): Unit = {
-
- val timestamp = ctx.timestamp()
- val curWatermark = ctx.timerService().currentWatermark()
-
- // discard late record
- if (timestamp >= curWatermark) {
- // ensure every key just registers one timer
- ctx.timerService.registerEventTimeTimer(curWatermark + 1)
-
- // put row into state
- var rowList = rowMapState.get(timestamp)
- if (rowList == null) {
- rowList = new util.ArrayList[Row]()
- }
- rowList.add(input)
- rowMapState.put(timestamp, rowList)
- }
- }
-
- /**
- * Called when a watermark arrived.
- * Sorts records according the timestamp, computes aggregates, and emits all records with
- * timestamp smaller than the watermark in timestamp order.
- *
- * @param timestamp The timestamp of the firing timer.
- * @param ctx The ctx to register timer or get current time
- * @param out The collector for returning result values.
- */
- override def onTimer(
- timestamp: Long,
- ctx: ProcessFunction[Row, Row]#OnTimerContext,
- out: Collector[Row]): Unit = {
-
- Preconditions.checkArgument(out.isInstanceOf[TimestampedCollector[Row]])
- val collector = out.asInstanceOf[TimestampedCollector[Row]]
-
- val keyIterator = rowMapState.keys.iterator
- if (keyIterator.hasNext) {
- val curWatermark = ctx.timerService.currentWatermark
- var existEarlyRecord: Boolean = false
-
- // sort the record timestamps
- do {
- val recordTime = keyIterator.next
- // only take timestamps smaller/equal to the watermark
- if (recordTime <= curWatermark) {
- insertToSortedList(recordTime)
- } else {
- existEarlyRecord = true
- }
- } while (keyIterator.hasNext)
-
- // get last accumulator
- var lastAccumulator = accumulatorState.value
- if (lastAccumulator == null) {
- // initialize accumulator
- lastAccumulator = function.createAccumulators()
- }
-
- // emit the rows in order
- while (!sortedTimestamps.isEmpty) {
- val curTimestamp = sortedTimestamps.removeFirst()
- val curRowList = rowMapState.get(curTimestamp)
- collector.setAbsoluteTimestamp(curTimestamp)
-
- // process the same timestamp datas, the mechanism is different according ROWS or RANGE
- processElementsWithSameTimestamp(curRowList, lastAccumulator, collector)
-
- rowMapState.remove(curTimestamp)
- }
-
- accumulatorState.update(lastAccumulator)
-
- // if are are rows with timestamp > watermark, register a timer for the next watermark
- if (existEarlyRecord) {
- ctx.timerService.registerEventTimeTimer(curWatermark + 1)
- }
- }
- }
-
- /**
- * Inserts timestamps in order into a linked list.
- *
- * If timestamps arrive in order (as in case of using the RocksDB state backend) this is just
- * an append with O(1).
- */
- private def insertToSortedList(recordTimestamp: Long) = {
- val listIterator = sortedTimestamps.listIterator(sortedTimestamps.size)
- var continue = true
- while (listIterator.hasPrevious && continue) {
- val timestamp = listIterator.previous
- if (recordTimestamp >= timestamp) {
- listIterator.next
- listIterator.add(recordTimestamp)
- continue = false
- }
- }
-
- if (continue) {
- sortedTimestamps.addFirst(recordTimestamp)
- }
- }
-
- /**
- * Process the same timestamp datas, the mechanism is different between
- * rows and range window.
- */
- def processElementsWithSameTimestamp(
- curRowList: JList[Row],
- lastAccumulator: Row,
- out: Collector[Row]): Unit
-
-}
-
-/**
- * A ProcessFunction to support unbounded ROWS window.
- * The ROWS clause defines on a physical level how many rows are included in a window frame.
- */
-class UnboundedEventTimeRowsOverProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- intermediateType: TypeInformation[Row],
- inputType: TypeInformation[Row])
- extends UnboundedEventTimeOverProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- intermediateType,
- inputType) {
-
- override def processElementsWithSameTimestamp(
- curRowList: JList[Row],
- lastAccumulator: Row,
- out: Collector[Row]): Unit = {
-
- var i = 0
- while (i < curRowList.size) {
- val curRow = curRowList.get(i)
-
- var j = 0
- // copy forwarded fields to output row
- function.setForwardedFields(curRow, output)
-
- // update accumulators and copy aggregates to output row
- function.accumulate(lastAccumulator, curRow)
- function.setAggregationResults(lastAccumulator, output)
- // emit output row
- out.collect(output)
- i += 1
- }
- }
-}
-
-
-/**
- * A ProcessFunction to support unbounded RANGE window.
- * The RANGE option includes all the rows within the window frame
- * that have the same ORDER BY values as the current row.
- */
-class UnboundedEventTimeRangeOverProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- intermediateType: TypeInformation[Row],
- inputType: TypeInformation[Row])
- extends UnboundedEventTimeOverProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- intermediateType,
- inputType) {
-
- override def processElementsWithSameTimestamp(
- curRowList: JList[Row],
- lastAccumulator: Row,
- out: Collector[Row]): Unit = {
-
- var i = 0
- // all same timestamp data should have same aggregation value.
- while (i < curRowList.size) {
- val curRow = curRowList.get(i)
-
- function.accumulate(lastAccumulator, curRow)
- i += 1
- }
-
- // emit output row
- i = 0
- while (i < curRowList.size) {
- val curRow = curRowList.get(i)
-
- // copy forwarded fields to output row
- function.setForwardedFields(curRow, output)
-
- //copy aggregates to output row
- function.setAggregationResults(lastAccumulator, output)
- out.collect(output)
- i += 1
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala
deleted file mode 100644
index 1a8399b..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedNonPartitionedProcessingOverProcessFunction.scala
+++ /dev/null
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import org.apache.flink.api.common.state.{ListState, ListStateDescriptor}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
-import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
-import org.apache.flink.streaming.api.functions.ProcessFunction
-import org.apache.flink.types.Row
-import org.apache.flink.util.Collector
-import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
-import org.slf4j.LoggerFactory
-
-/**
- * Process Function for non-partitioned processing-time unbounded OVER window
- *
- * @param genAggregations Generated aggregate helper function
- * @param aggregationStateType row type info of aggregation
- */
-class UnboundedNonPartitionedProcessingOverProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- aggregationStateType: RowTypeInfo)
- extends ProcessFunction[Row, Row]
- with CheckpointedFunction
- with Compiler[GeneratedAggregations] {
-
- private var accumulators: Row = _
- private var output: Row = _
- private var state: ListState[Row] = _
- val LOG = LoggerFactory.getLogger(this.getClass)
-
- private var function: GeneratedAggregations = _
-
- override def open(config: Configuration) {
- LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
- s"Code:\n$genAggregations.code")
- val clazz = compile(
- getRuntimeContext.getUserCodeClassLoader,
- genAggregations.name,
- genAggregations.code)
- LOG.debug("Instantiating AggregateHelper.")
- function = clazz.newInstance()
-
- output = function.createOutputRow()
- if (null == accumulators) {
- val it = state.get().iterator()
- if (it.hasNext) {
- accumulators = it.next()
- } else {
- accumulators = function.createAccumulators()
- }
- }
- }
-
- override def processElement(
- input: Row,
- ctx: ProcessFunction[Row, Row]#Context,
- out: Collector[Row]): Unit = {
-
- function.setForwardedFields(input, output)
-
- function.accumulate(accumulators, input)
- function.setAggregationResults(accumulators, output)
-
- out.collect(output)
- }
-
- override def snapshotState(context: FunctionSnapshotContext): Unit = {
- state.clear()
- if (null != accumulators) {
- state.add(accumulators)
- }
- }
-
- override def initializeState(context: FunctionInitializationContext): Unit = {
- val accumulatorsDescriptor = new ListStateDescriptor[Row]("overState", aggregationStateType)
- state = context.getOperatorStateStore.getOperatorState(accumulatorsDescriptor)
- }
-}
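The notable part of this removed function (now ProcTimeUnboundedNonPartitionedOver, per the file list below) is the CheckpointedFunction pattern: the live accumulator row is a plain field, snapshotState mirrors it into operator ListState, and initializeState restores it on recovery. A minimal sketch of the same pattern with a plain counter instead of the generated accumulator row, using the same 1.2-era state API as the code above:

    import org.apache.flink.api.common.state.{ListState, ListStateDescriptor}
    import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
    import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
    import org.apache.flink.streaming.api.functions.ProcessFunction
    import org.apache.flink.util.Collector

    class CountingFunction extends ProcessFunction[String, Long]
      with CheckpointedFunction {

      private var count: Long = 0L            // live value, updated per record
      private var state: ListState[Long] = _

      override def processElement(
          in: String,
          ctx: ProcessFunction[String, Long]#Context,
          out: Collector[Long]): Unit = {
        count += 1
        out.collect(count)
      }

      override def snapshotState(context: FunctionSnapshotContext): Unit = {
        // mirror the live value into operator state for the checkpoint
        state.clear()
        state.add(count)
      }

      override def initializeState(context: FunctionInitializationContext): Unit = {
        val descriptor = new ListStateDescriptor[Long]("count", classOf[Long])
        state = context.getOperatorStateStore.getOperatorState(descriptor)
        // restore the live value after a failover
        val it = state.get().iterator()
        if (it.hasNext) {
          count = it.next()
        }
      }
    }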
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedProcessingOverProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedProcessingOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedProcessingOverProcessFunction.scala
deleted file mode 100644
index 9a6d4f0..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/UnboundedProcessingOverProcessFunction.scala
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.functions.ProcessFunction
-import org.apache.flink.types.Row
-import org.apache.flink.util.Collector
-import org.apache.flink.api.common.state.ValueStateDescriptor
-import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.api.common.state.ValueState
-import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
-import org.slf4j.LoggerFactory
-
-/**
- * Process Function for processing-time unbounded OVER window
- *
- * @param genAggregations Generated aggregate helper function
- * @param aggregationStateType row type info of aggregation
- */
-class UnboundedProcessingOverProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- aggregationStateType: RowTypeInfo)
- extends ProcessFunction[Row, Row]
- with Compiler[GeneratedAggregations] {
-
- private var output: Row = _
- private var state: ValueState[Row] = _
- val LOG = LoggerFactory.getLogger(this.getClass)
- private var function: GeneratedAggregations = _
-
- override def open(config: Configuration) {
- LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
- s"Code:\n$genAggregations.code")
- val clazz = compile(
- getRuntimeContext.getUserCodeClassLoader,
- genAggregations.name,
- genAggregations.code)
- LOG.debug("Instantiating AggregateHelper.")
- function = clazz.newInstance()
-
- output = function.createOutputRow()
- val stateDescriptor: ValueStateDescriptor[Row] =
- new ValueStateDescriptor[Row]("overState", aggregationStateType)
- state = getRuntimeContext.getState(stateDescriptor)
- }
-
- override def processElement(
- input: Row,
- ctx: ProcessFunction[Row, Row]#Context,
- out: Collector[Row]): Unit = {
-
- var accumulators = state.value()
-
- if (null == accumulators) {
- accumulators = function.createAccumulators()
- }
-
- function.setForwardedFields(input, output)
-
- function.accumulate(accumulators, input)
- function.setAggregationResults(accumulators, output)
-
- state.update(accumulators)
-
- out.collect(output)
- }
-
-}
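For comparison, the per-key variant removed here (now ProcTimeUnboundedPartitionedOver, per the file list below) reduces to: read the accumulator from keyed ValueState, create it on first access, accumulate, write it back, and emit. A minimal sketch with a hardcoded running sum instead of the generated code; note that ValueState.value() returns null before the first update, which is what the null check on the accumulator row handles above:

    import java.lang.{Long => JLong}

    import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
    import org.apache.flink.configuration.Configuration
    import org.apache.flink.streaming.api.functions.ProcessFunction
    import org.apache.flink.util.Collector

    class RunningSum extends ProcessFunction[(String, Long), (String, Long)] {

      private var sumState: ValueState[JLong] = _

      override def open(config: Configuration): Unit = {
        val descriptor = new ValueStateDescriptor[JLong]("sum", classOf[JLong])
        sumState = getRuntimeContext.getState(descriptor)
      }

      override def processElement(
          in: (String, Long),
          ctx: ProcessFunction[(String, Long), (String, Long)]#Context,
          out: Collector[(String, Long)]): Unit = {

        // initialize on first access, then accumulate and write back
        val previous = sumState.value()
        val sum = (if (previous == null) 0L else previous.longValue()) + in._2
        sumState.update(sum)
        out.collect((in._1, sum))
      }
    }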
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
index 25ec36e..3610898 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunctionTest.scala
@@ -165,7 +165,7 @@ class BoundedProcessingOverRangeProcessFunctionTest {
val genAggFunction = GeneratedAggregationsFunction(funcName, funcCode)
val processFunction = new KeyedProcessOperator[String, Row, Row](
- new BoundedProcessingOverRangeProcessFunction(
+ new ProcTimeBoundedRangeOver(
genAggFunction,
1000,
aggregationStateType,
[3/5] flink git commit: [FLINK-5435] [table] Remove FlinkAggregateJoinTransposeRule and FlinkRelDecorrelator after bumping Calcite to v1.12.
Posted by fh...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/c5173fa2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
index 3e44526..7496ff8 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/QueryDecorrelationTest.scala
@@ -40,56 +40,6 @@ class QueryDecorrelationTest extends TableTestBase {
"and e1.deptno < 10 and d1.deptno < 15\n" +
"and e1.salary > (select avg(salary) from emp e2 where e1.empno = e2.empno)"
- // the inner query "select avg(salary) from emp e2 where e1.empno = e2.empno" will be
- // decorrelated into a join and then groupby. And the filters
- // "e1.deptno < 10 and d1.deptno < 15" will also be pushed down before join.
- val decorrelatedSubQuery = unaryNode(
- "DataSetAggregate",
- unaryNode(
- "DataSetCalc",
- binaryNode(
- "DataSetJoin",
- unaryNode(
- "DataSetCalc",
- batchTableNode(0),
- term("select", "empno", "salary")
- ),
- unaryNode(
- "DataSetDistinct",
- unaryNode(
- "DataSetCalc",
- binaryNode(
- "DataSetJoin",
- unaryNode(
- "DataSetCalc",
- batchTableNode(0),
- term("select", "empno", "deptno"),
- term("where", "<(deptno, 10)")
- ),
- unaryNode(
- "DataSetCalc",
- batchTableNode(1),
- term("select", "deptno"),
- term("where", "<(deptno, 15)")
- ),
- term("where", "=(deptno, deptno0)"),
- term("join", "empno", "deptno", "deptno0"),
- term("joinType", "InnerJoin")
- ),
- term("select", "empno")
- ),
- term("distinct", "empno")
- ),
- term("where", "=(empno0, empno)"),
- term("join", "empno", "salary", "empno0"),
- term("joinType", "InnerJoin")
- ),
- term("select", "empno0", "salary")
- ),
- term("groupBy", "empno0"),
- term("select", "empno0", "AVG(salary) AS EXPR$0")
- )
-
val expectedQuery = unaryNode(
"DataSetCalc",
binaryNode(
@@ -112,7 +62,17 @@ class QueryDecorrelationTest extends TableTestBase {
term("join", "empno", "ename", "job", "salary", "deptno", "deptno0", "name"),
term("joinType", "InnerJoin")
),
- decorrelatedSubQuery,
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "salary", "empno"),
+ term("where", "IS NOT NULL(empno)")
+ ),
+ term("groupBy", "empno"),
+ term("select", "empno", "AVG(salary) AS EXPR$0")
+ ),
term("where", "AND(=(empno, empno0), >(salary, EXPR$0))"),
term("join", "empno", "ename", "job", "salary", "deptno",
"deptno0", "name", "empno0", "EXPR$0"),
@@ -132,51 +92,6 @@ class QueryDecorrelationTest extends TableTestBase {
" select avg(e2.salary) from emp e2 where e2.deptno = d1.deptno" +
")"
- val decorrelatedSubQuery = unaryNode(
- "DataSetAggregate",
- unaryNode(
- "DataSetCalc",
- binaryNode(
- "DataSetJoin",
- unaryNode(
- "DataSetCalc",
- batchTableNode(0),
- term("select", "salary", "deptno")
- ),
- unaryNode(
- "DataSetDistinct",
- unaryNode(
- "DataSetCalc",
- binaryNode(
- "DataSetJoin",
- unaryNode(
- "DataSetCalc",
- batchTableNode(0),
- term("select", "deptno")
- ),
- unaryNode(
- "DataSetCalc",
- batchTableNode(1),
- term("select", "deptno")
- ),
- term("where", "=(deptno, deptno0)"),
- term("join", "deptno", "deptno0"),
- term("joinType", "InnerJoin")
- ),
- term("select", "deptno0")
- ),
- term("distinct", "deptno0")
- ),
- term("where", "=(deptno, deptno0)"),
- term("join", "salary", "deptno", "deptno0"),
- term("joinType", "InnerJoin")
- ),
- term("select", "deptno0", "salary")
- ),
- term("groupBy", "deptno0"),
- term("select", "deptno0", "AVG(salary) AS EXPR$0")
- )
-
val expectedQuery = unaryNode(
"DataSetAggregate",
binaryNode(
@@ -198,10 +113,20 @@ class QueryDecorrelationTest extends TableTestBase {
term("join", "empno", "ename", "job", "salary", "deptno", "deptno0", "name"),
term("joinType", "InnerJoin")
),
- decorrelatedSubQuery,
- term("where", "AND(=(deptno0, deptno00), >(salary, EXPR$0))"),
+ unaryNode(
+ "DataSetAggregate",
+ unaryNode(
+ "DataSetCalc",
+ batchTableNode(0),
+ term("select", "deptno", "salary"),
+ term("where", "IS NOT NULL(deptno)")
+ ),
+ term("groupBy", "deptno"),
+ term("select", "deptno", "AVG(salary) AS EXPR$0")
+ ),
+ term("where", "AND(=(deptno0, deptno1), >(salary, EXPR$0))"),
term("join", "empno", "ename", "job", "salary", "deptno", "deptno0",
- "name", "deptno00", "EXPR$0"),
+ "name", "deptno1", "EXPR$0"),
term("joinType", "InnerJoin")
),
term("select", "empno")
http://git-wip-us.apache.org/repos/asf/flink/blob/c5173fa2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SetOperatorsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SetOperatorsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SetOperatorsTest.scala
index f902338..2f9057d 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SetOperatorsTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SetOperatorsTest.scala
@@ -58,33 +58,15 @@ class SetOperatorsTest extends TableTestBase {
"DataSetAggregate",
unaryNode(
"DataSetCalc",
- binaryNode(
- "DataSetJoin",
- unaryNode(
- "DataSetCalc",
- batchTableNode(1),
- term("select", "b_long")
- ),
- unaryNode(
- "DataSetDistinct",
- unaryNode(
- "DataSetCalc",
- batchTableNode(0),
- term("select", "a_long")
- ),
- term("distinct", "a_long")
- ),
- term("where", "=(a_long, b_long)"),
- term("join", "b_long", "a_long"),
- term("joinType", "InnerJoin")
- ),
- term("select", "a_long", "true AS $f0")
+ batchTableNode(1),
+ term("select", "b_long AS b_long3", "true AS $f0"),
+ term("where", "IS NOT NULL(b_long)")
),
- term("groupBy", "a_long"),
- term("select", "a_long", "MIN($f0) AS $f1")
+ term("groupBy", "b_long3"),
+ term("select", "b_long3", "MIN($f0) AS $f1")
),
- term("where", "=(a_long, a_long0)"),
- term("join", "a_long", "a_int", "a_string", "a_long0", "$f1"),
+ term("where", "=(a_long, b_long3)"),
+ term("join", "a_long", "a_int", "a_string", "b_long3", "$f1"),
term("joinType", "InnerJoin")
),
term("select", "a_int", "a_string")
[2/5] flink git commit: [FLINK-6257] [table] Consistent naming of ProcessFunction and methods for OVER windows.
Posted by fh...@apache.org.
[FLINK-6257] [table] Consistent naming of ProcessFunction and methods for OVER windows.
- Add check for sort order of OVER windows.
This closes #3681.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/07f1b035
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/07f1b035
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/07f1b035
Branch: refs/heads/master
Commit: 07f1b035ffbf07d160503c48e2c58a464ec5d014
Parents: 5ff9c99
Author: sunjincheng121 <su...@gmail.com>
Authored: Thu Apr 6 10:33:30 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Thu Apr 6 16:33:45 2017 +0200
----------------------------------------------------------------------
.../datastream/DataStreamOverAggregate.scala | 196 +++++--------
.../table/runtime/aggregate/AggregateUtil.scala | 122 +++-----
...ndedProcessingOverRangeProcessFunction.scala | 183 ------------
...oundedProcessingOverRowProcessFunction.scala | 179 ------------
.../aggregate/ProcTimeBoundedRangeOver.scala | 182 ++++++++++++
.../aggregate/ProcTimeBoundedRowsOver.scala | 179 ++++++++++++
.../ProcTimeUnboundedNonPartitionedOver.scala | 96 ++++++
.../ProcTimeUnboundedPartitionedOver.scala | 84 ++++++
.../RangeClauseBoundedOverProcessFunction.scala | 201 -------------
.../aggregate/RowTimeBoundedRangeOver.scala | 200 +++++++++++++
.../aggregate/RowTimeBoundedRowsOver.scala | 222 ++++++++++++++
.../aggregate/RowTimeUnboundedOver.scala | 292 +++++++++++++++++++
.../RowsClauseBoundedOverProcessFunction.scala | 222 --------------
.../UnboundedEventTimeOverProcessFunction.scala | 292 -------------------
...rtitionedProcessingOverProcessFunction.scala | 96 ------
...UnboundedProcessingOverProcessFunction.scala | 84 ------
...ProcessingOverRangeProcessFunctionTest.scala | 2 +-
17 files changed, 1380 insertions(+), 1452 deletions(-)
----------------------------------------------------------------------
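For readers skimming the diff: pairing each deleted file with the new file of (nearly) the same size in the stat above suggests the following rename map (the second old name is truncated in the stat and kept as-is here):

    BoundedProcessingOverRangeProcessFunction            -> ProcTimeBoundedRangeOver
    ...oundedProcessingOverRowProcessFunction            -> ProcTimeBoundedRowsOver
    UnboundedNonPartitionedProcessingOverProcessFunction -> ProcTimeUnboundedNonPartitionedOver
    UnboundedProcessingOverProcessFunction               -> ProcTimeUnboundedPartitionedOver
    RangeClauseBoundedOverProcessFunction                -> RowTimeBoundedRangeOver
    RowsClauseBoundedOverProcessFunction                 -> RowTimeBoundedRowsOver
    UnboundedEventTimeOverProcessFunction                -> RowTimeUnboundedOver
                                                            (with its ROWS and RANGE subclasses)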
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
index 947775b..2224752 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
@@ -17,20 +17,21 @@
*/
package org.apache.flink.table.plan.nodes.datastream
+import java.util.{List => JList}
+
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rel.core.{AggregateCall, Window}
+import org.apache.calcite.rel.core.Window.Group
import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.streaming.api.datastream.DataStream
import org.apache.flink.table.api.{StreamTableEnvironment, TableException}
import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.runtime.aggregate._
import org.apache.flink.table.plan.nodes.OverAggregate
+import org.apache.flink.table.runtime.aggregate._
import org.apache.flink.types.Row
-import org.apache.calcite.rel.core.Window
-import org.apache.calcite.rel.core.Window.Group
-import java.util.{List => JList}
import org.apache.flink.api.java.functions.NullByteKeySelector
import org.apache.flink.table.codegen.CodeGenerator
@@ -90,12 +91,20 @@ class DataStreamOverAggregate(
val overWindow: org.apache.calcite.rel.core.Window.Group = logicWindow.groups.get(0)
- val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
+ val orderKeys = overWindow.orderKeys.getFieldCollations
- if (overWindow.orderKeys.getFieldCollations.size() != 1) {
+ if (orderKeys.size() != 1) {
throw new TableException(
- "Unsupported use of OVER windows. The window may only be ordered by a single time column.")
+ "Unsupported use of OVER windows. The window can only be ordered by a single time column.")
}
+ val orderKey = orderKeys.get(0)
+
+ if (!orderKey.direction.equals(ASCENDING)) {
+ throw new TableException(
+ "Unsupported use of OVER windows. The window can only be ordered in ASCENDING mode.")
+ }
+
+ val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
val generator = new CodeGenerator(
tableEnv.getConfig,
@@ -104,78 +113,69 @@ class DataStreamOverAggregate(
val timeType = inputType
.getFieldList
- .get(overWindow.orderKeys.getFieldCollations.get(0).getFieldIndex)
+ .get(orderKey.getFieldIndex)
.getValue
+
timeType match {
case _: ProcTimeType =>
// proc-time OVER window
if (overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
- // unbounded preceding OVER window
- createUnboundedAndCurrentRowProcessingTimeOverWindow(
+ // unbounded OVER window
+ createUnboundedAndCurrentRowOverWindow(
generator,
- inputDS)
+ inputDS,
+ isRowTimeType = false,
+ isRowsClause = overWindow.isRows)
} else if (
overWindow.lowerBound.isPreceding && !overWindow.lowerBound.isUnbounded &&
- overWindow.upperBound.isCurrentRow) {
+ overWindow.upperBound.isCurrentRow) {
// bounded OVER window
- if (overWindow.isRows) {
- // ROWS clause bounded OVER window
- createBoundedAndCurrentRowOverWindow(
- generator,
- inputDS,
- isRangeClause = false,
- isRowTimeType = false)
- } else {
- // RANGE clause bounded OVER window
- createBoundedAndCurrentRowOverWindow(
- generator,
- inputDS,
- isRangeClause = true,
- isRowTimeType = false)
- }
+ createBoundedAndCurrentRowOverWindow(
+ generator,
+ inputDS,
+ isRowTimeType = false,
+ isRowsClause = overWindow.isRows
+ )
} else {
throw new TableException(
- "processing-time OVER RANGE FOLLOWING window is not supported yet.")
+ "OVER RANGE FOLLOWING windows are not supported yet.")
}
case _: RowTimeType =>
// row-time OVER window
if (overWindow.lowerBound.isPreceding &&
- overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
- // ROWS/RANGE clause unbounded OVER window
- createUnboundedAndCurrentRowEventTimeOverWindow(
+ overWindow.lowerBound.isUnbounded && overWindow.upperBound.isCurrentRow) {
+ // unbounded OVER window
+ createUnboundedAndCurrentRowOverWindow(
generator,
inputDS,
- overWindow.isRows)
+ isRowTimeType = true,
+ isRowsClause = overWindow.isRows
+ )
} else if (overWindow.lowerBound.isPreceding && overWindow.upperBound.isCurrentRow) {
// bounded OVER window
- if (overWindow.isRows) {
- // ROWS clause bounded OVER window
- createBoundedAndCurrentRowOverWindow(
- generator,
- inputDS,
- isRangeClause = false,
- isRowTimeType = true)
- } else {
- // RANGE clause bounded OVER window
- createBoundedAndCurrentRowOverWindow(
- generator,
- inputDS,
- isRangeClause = true,
- isRowTimeType = true)
- }
+ createBoundedAndCurrentRowOverWindow(
+ generator,
+ inputDS,
+ isRowTimeType = true,
+ isRowsClause = overWindow.isRows
+ )
} else {
throw new TableException(
- "row-time OVER RANGE FOLLOWING window is not supported yet.")
+ "OVER RANGE FOLLOWING windows are not supported yet.")
}
case _ =>
- throw new TableException(s"Unsupported time type {$timeType}")
+ throw new TableException(
+ "Unsupported time type {$timeType}. " +
+ "OVER windows do only support RowTimeType and ProcTimeType.")
}
}
- def createUnboundedAndCurrentRowProcessingTimeOverWindow(
+ def createUnboundedAndCurrentRowOverWindow(
generator: CodeGenerator,
- inputDS: DataStream[Row]): DataStream[Row] = {
+ inputDS: DataStream[Row],
+ isRowTimeType: Boolean,
+ isRowsClause: Boolean): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
@@ -184,14 +184,17 @@ class DataStreamOverAggregate(
// get the output types
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
+ val processFunction = AggregateUtil.createUnboundedOverProcessFunction(
+ generator,
+ namedAggregates,
+ inputType,
+ isRowTimeType,
+ partitionKeys.nonEmpty,
+ isRowsClause)
+
val result: DataStream[Row] =
// partitioned aggregation
if (partitionKeys.nonEmpty) {
- val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction(
- generator,
- namedAggregates,
- inputType)
-
inputDS
.keyBy(partitionKeys: _*)
.process(processFunction)
@@ -201,17 +204,19 @@ class DataStreamOverAggregate(
}
// non-partitioned aggregation
else {
- val processFunction = AggregateUtil.createUnboundedProcessingOverProcessFunction(
- generator,
- namedAggregates,
- inputType,
- isPartitioned = false)
-
- inputDS
- .process(processFunction).setParallelism(1).setMaxParallelism(1)
- .returns(rowTypeInfo)
- .name(aggOpName)
- .asInstanceOf[DataStream[Row]]
+ if (isRowTimeType) {
+ inputDS.keyBy(new NullByteKeySelector[Row])
+ .process(processFunction).setParallelism(1).setMaxParallelism(1)
+ .returns(rowTypeInfo)
+ .name(aggOpName)
+ .asInstanceOf[DataStream[Row]]
+ } else {
+ inputDS
+ .process(processFunction).setParallelism(1).setMaxParallelism(1)
+ .returns(rowTypeInfo)
+ .name(aggOpName)
+ .asInstanceOf[DataStream[Row]]
+ }
}
result
}
@@ -219,15 +224,15 @@ class DataStreamOverAggregate(
def createBoundedAndCurrentRowOverWindow(
generator: CodeGenerator,
inputDS: DataStream[Row],
- isRangeClause: Boolean,
- isRowTimeType: Boolean): DataStream[Row] = {
+ isRowTimeType: Boolean,
+ isRowsClause: Boolean): DataStream[Row] = {
val overWindow: Group = logicWindow.groups.get(0)
val partitionKeys: Array[Int] = overWindow.keys.toArray
val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
val precedingOffset =
- getLowerBoundary(logicWindow, overWindow, getInput()) + (if (isRangeClause) 0 else 1)
+ getLowerBoundary(logicWindow, overWindow, getInput()) + (if (isRowsClause) 1 else 0)
// get the output types
val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
@@ -237,8 +242,9 @@ class DataStreamOverAggregate(
namedAggregates,
inputType,
precedingOffset,
- isRangeClause,
- isRowTimeType)
+ isRowsClause,
+ isRowTimeType
+ )
val result: DataStream[Row] =
// partitioned aggregation
if (partitionKeys.nonEmpty) {
@@ -253,49 +259,7 @@ class DataStreamOverAggregate(
else {
inputDS
.keyBy(new NullByteKeySelector[Row])
- .process(processFunction)
- .setParallelism(1)
- .setMaxParallelism(1)
- .returns(rowTypeInfo)
- .name(aggOpName)
- .asInstanceOf[DataStream[Row]]
- }
- result
- }
-
- def createUnboundedAndCurrentRowEventTimeOverWindow(
- generator: CodeGenerator,
- inputDS: DataStream[Row],
- isRows: Boolean): DataStream[Row] = {
-
- val overWindow: Group = logicWindow.groups.get(0)
- val partitionKeys: Array[Int] = overWindow.keys.toArray
- val namedAggregates: Seq[CalcitePair[AggregateCall, String]] = generateNamedAggregates
-
- // get the output types
- val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
-
- val processFunction = AggregateUtil.createUnboundedEventTimeOverProcessFunction(
- generator,
- namedAggregates,
- inputType,
- isRows)
-
- val result: DataStream[Row] =
- // partitioned aggregation
- if (partitionKeys.nonEmpty) {
- inputDS.keyBy(partitionKeys: _*)
- .process(processFunction)
- .returns(rowTypeInfo)
- .name(aggOpName)
- .asInstanceOf[DataStream[Row]]
- }
- // global non-partitioned aggregation
- else {
- inputDS.keyBy(new NullByteKeySelector[Row])
- .process(processFunction)
- .setParallelism(1)
- .setMaxParallelism(1)
+ .process(processFunction).setParallelism(1).setMaxParallelism(1)
.returns(rowTypeInfo)
.name(aggOpName)
.asInstanceOf[DataStream[Row]]
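
The net effect of the new checks in DataStreamOverAggregate above: an OVER window must be ordered by exactly one time attribute, in ascending order. Sketched below with the time attribute written as rowtime; how the attribute is actually declared depends on the Flink version:

    // accepted: ordered by a single time attribute, ascending (ASC is the default)
    val supported =
      "SELECT c, SUM(a) OVER (PARTITION BY c ORDER BY rowtime " +
        "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM T"

    // rejected: "The window can only be ordered in ASCENDING mode."
    val wrongDirection =
      "SELECT c, SUM(a) OVER (PARTITION BY c ORDER BY rowtime DESC " +
        "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM T"

    // rejected: "The window can only be ordered by a single time column."
    val tooManyKeys =
      "SELECT c, SUM(a) OVER (PARTITION BY c ORDER BY rowtime, c " +
        "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) FROM T"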
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index fc03ac1..09d1a13 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -55,21 +55,23 @@ object AggregateUtil {
type JavaList[T] = java.util.List[T]
/**
- * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] to evaluate final
- * aggregate value.
+ * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for an unbounded OVER
+ * window to evaluate the final aggregate value.
*
* @param generator code generator instance
* @param namedAggregates List of calls to aggregate functions and their output field names
- * @param inputType Input row type
- * @param isPartitioned Flag to indicate whether the input is partitioned or not
- *
- * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
+ * @param inputType Input row type
+ * @param isRowTimeType Flag that indicates whether the time type is row time
+ * @param isPartitioned Flag that indicates whether the input is partitioned
+ * @param isRowsClause Flag that indicates whether the OVER clause is a ROWS clause
*/
- private[flink] def createUnboundedProcessingOverProcessFunction(
- generator: CodeGenerator,
- namedAggregates: Seq[CalcitePair[AggregateCall, String]],
- inputType: RelDataType,
- isPartitioned: Boolean = true): ProcessFunction[Row, Row] = {
+ private[flink] def createUnboundedOverProcessFunction(
+ generator: CodeGenerator,
+ namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+ inputType: RelDataType,
+ isRowTimeType: Boolean,
+ isPartitioned: Boolean,
+ isRowsClause: Boolean): ProcessFunction[Row, Row] = {
val (aggFields, aggregates) =
transformToAggregateFunctions(
@@ -95,14 +97,30 @@ object AggregateUtil {
outputArity
)
- if (isPartitioned) {
- new UnboundedProcessingOverProcessFunction(
- genFunction,
- aggregationStateType)
+ if (isRowTimeType) {
+ if (isRowsClause) {
+ // ROWS unbounded over process function
+ new RowTimeUnboundedRowsOver(
+ genFunction,
+ aggregationStateType,
+ FlinkTypeFactory.toInternalRowTypeInfo(inputType))
+ } else {
+ // RANGE unbounded over process function
+ new RowTimeUnboundedRangeOver(
+ genFunction,
+ aggregationStateType,
+ FlinkTypeFactory.toInternalRowTypeInfo(inputType))
+ }
} else {
- new UnboundedNonPartitionedProcessingOverProcessFunction(
- genFunction,
- aggregationStateType)
+ if (isPartitioned) {
+ new ProcTimeUnboundedPartitionedOver(
+ genFunction,
+ aggregationStateType)
+ } else {
+ new ProcTimeUnboundedNonPartitionedOver(
+ genFunction,
+ aggregationStateType)
+ }
}
}
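
The dispatch above selects one of four ProcessFunction implementations from the three flags; note that the proc-time branch does not distinguish ROWS from RANGE. A sketch of the mapping (class names taken from this commit, construction elided):

    // returns the implementation chosen by createUnboundedOverProcessFunction
    def chooseUnboundedOverFunction(
        isRowTimeType: Boolean,
        isRowsClause: Boolean,
        isPartitioned: Boolean): String =
      (isRowTimeType, isRowsClause, isPartitioned) match {
        case (true, true, _)   => "RowTimeUnboundedRowsOver"
        case (true, false, _)  => "RowTimeUnboundedRangeOver"
        case (false, _, true)  => "ProcTimeUnboundedPartitionedOver"
        case (false, _, false) => "ProcTimeUnboundedNonPartitionedOver"
      }
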
@@ -114,7 +132,7 @@ object AggregateUtil {
* @param namedAggregates List of calls to aggregate functions and their output field names
* @param inputType Input row type
* @param precedingOffset the preceding offset
- * @param isRangeClause It is a tag that indicates whether the OVER clause is rangeClause
+ * @param isRowsClause Flag that indicates whether the OVER clause is a ROWS clause
* @param isRowTimeType Flag that indicates whether the time type is row time
* @return [[org.apache.flink.streaming.api.functions.ProcessFunction]]
*/
@@ -123,7 +141,7 @@ object AggregateUtil {
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
precedingOffset: Long,
- isRangeClause: Boolean,
+ isRowsClause: Boolean,
isRowTimeType: Boolean): ProcessFunction[Row, Row] = {
val (aggFields, aggregates) =
@@ -151,15 +169,15 @@ object AggregateUtil {
)
if (isRowTimeType) {
- if (isRangeClause) {
- new RangeClauseBoundedOverProcessFunction(
+ if (isRowsClause) {
+ new RowTimeBoundedRowsOver(
genFunction,
aggregationStateType,
inputRowType,
precedingOffset
)
} else {
- new RowsClauseBoundedOverProcessFunction(
+ new RowTimeBoundedRangeOver(
genFunction,
aggregationStateType,
inputRowType,
@@ -167,14 +185,14 @@ object AggregateUtil {
)
}
} else {
- if (isRangeClause) {
- new BoundedProcessingOverRangeProcessFunction(
+ if (isRowsClause) {
+ new ProcTimeBoundedRowsOver(
genFunction,
precedingOffset,
aggregationStateType,
inputRowType)
} else {
- new BoundedProcessingOverRowProcessFunction(
+ new ProcTimeBoundedRangeOver(
genFunction,
precedingOffset,
aggregationStateType,
@@ -183,58 +201,6 @@ object AggregateUtil {
}
}
- /**
- * Create an [[ProcessFunction]] to evaluate final aggregate value.
- *
- * @param generator code generator instance
- * @param namedAggregates List of calls to aggregate functions and their output field names
- * @param inputType Input row type
- * @param isRows Flag to indicate if whether this is a Row (true) or a Range (false)
- * over window process
- * @return [[UnboundedEventTimeOverProcessFunction]]
- */
- private[flink] def createUnboundedEventTimeOverProcessFunction(
- generator: CodeGenerator,
- namedAggregates: Seq[CalcitePair[AggregateCall, String]],
- inputType: RelDataType,
- isRows: Boolean): UnboundedEventTimeOverProcessFunction = {
-
- val (aggFields, aggregates) =
- transformToAggregateFunctions(
- namedAggregates.map(_.getKey),
- inputType,
- needRetraction = false)
-
- val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates)
-
- val forwardMapping = (0 until inputType.getFieldCount).map(x => (x, x)).toArray
- val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray
- val outputArity = inputType.getFieldCount + aggregates.length
-
- val genFunction = generator.generateAggregations(
- "UnboundedEventTimeOverAggregateHelper",
- generator,
- inputType,
- aggregates,
- aggFields,
- aggMapping,
- forwardMapping,
- outputArity)
-
- if (isRows) {
- // ROWS unbounded over process function
- new UnboundedEventTimeRowsOverProcessFunction(
- genFunction,
- aggregationStateType,
- FlinkTypeFactory.toInternalRowTypeInfo(inputType))
- } else {
- // RANGE unbounded over process function
- new UnboundedEventTimeRangeOverProcessFunction(
- genFunction,
- aggregationStateType,
- FlinkTypeFactory.toInternalRowTypeInfo(inputType))
- }
- }
/**
* Create a [[org.apache.flink.api.common.functions.MapFunction]] that prepares for aggregates.
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunction.scala
deleted file mode 100644
index 8f3aa3e..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRangeProcessFunction.scala
+++ /dev/null
@@ -1,183 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.functions.ProcessFunction
-import org.apache.flink.types.Row
-import org.apache.flink.util.Collector
-import org.apache.flink.api.common.state.ValueState
-import org.apache.flink.api.common.state.ValueStateDescriptor
-import org.apache.flink.api.common.state.MapState
-import org.apache.flink.api.common.state.MapStateDescriptor
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.java.typeutils.ListTypeInfo
-import java.util.{ArrayList, List => JList}
-
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo
-import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
-import org.slf4j.LoggerFactory
-
-/**
- * Process Function used for the aggregate in bounded proc-time OVER window
- * [[org.apache.flink.streaming.api.datastream.DataStream]]
- *
- * @param genAggregations Generated aggregate helper function
- * @param precedingTimeBoundary Is used to indicate the processing time boundaries
- * @param aggregatesTypeInfo row type info of aggregation
- * @param inputType row type info of input row
- */
-class BoundedProcessingOverRangeProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- precedingTimeBoundary: Long,
- aggregatesTypeInfo: RowTypeInfo,
- inputType: TypeInformation[Row])
- extends ProcessFunction[Row, Row]
- with Compiler[GeneratedAggregations] {
-
- private var output: Row = _
- private var accumulatorState: ValueState[Row] = _
- private var rowMapState: MapState[Long, JList[Row]] = _
-
- val LOG = LoggerFactory.getLogger(this.getClass)
- private var function: GeneratedAggregations = _
-
- override def open(config: Configuration) {
- LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
- s"Code:\n$genAggregations.code")
- val clazz = compile(
- getRuntimeContext.getUserCodeClassLoader,
- genAggregations.name,
- genAggregations.code)
- LOG.debug("Instantiating AggregateHelper.")
- function = clazz.newInstance()
- output = function.createOutputRow()
-
- // We keep the elements received in a MapState indexed based on their ingestion time
- val rowListTypeInfo: TypeInformation[JList[Row]] =
- new ListTypeInfo[Row](inputType).asInstanceOf[TypeInformation[JList[Row]]]
- val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
- new MapStateDescriptor[Long, JList[Row]]("rowmapstate",
- BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo)
- rowMapState = getRuntimeContext.getMapState(mapStateDescriptor)
-
- val stateDescriptor: ValueStateDescriptor[Row] =
- new ValueStateDescriptor[Row]("overState", aggregatesTypeInfo)
- accumulatorState = getRuntimeContext.getState(stateDescriptor)
- }
-
- override def processElement(
- input: Row,
- ctx: ProcessFunction[Row, Row]#Context,
- out: Collector[Row]): Unit = {
-
- val currentTime = ctx.timerService.currentProcessingTime
- // buffer the event incoming event
-
- // add current element to the window list of elements with corresponding timestamp
- var rowList = rowMapState.get(currentTime)
- // null value means that this si the first event received for this timestamp
- if (rowList == null) {
- rowList = new ArrayList[Row]()
- // register timer to process event once the current millisecond passed
- ctx.timerService.registerProcessingTimeTimer(currentTime + 1)
- }
- rowList.add(input)
- rowMapState.put(currentTime, rowList)
-
- }
-
- override def onTimer(
- timestamp: Long,
- ctx: ProcessFunction[Row, Row]#OnTimerContext,
- out: Collector[Row]): Unit = {
-
- // we consider the original timestamp of events that have registered this time trigger 1 ms ago
- val currentTime = timestamp - 1
- var i = 0
-
- // initialize the accumulators
- var accumulators = accumulatorState.value()
-
- if (null == accumulators) {
- accumulators = function.createAccumulators()
- }
-
- // update the elements to be removed and retract them from aggregators
- val limit = currentTime - precedingTimeBoundary
-
- // we iterate through all elements in the window buffer based on timestamp keys
- // when we find timestamps that are out of interest, we retrieve corresponding elements
- // and eliminate them. Multiple elements could have been received at the same timestamp
- // the removal of old elements happens only once per proctime as onTimer is called only once
- val iter = rowMapState.keys.iterator
- val markToRemove = new ArrayList[Long]()
- while (iter.hasNext) {
- val elementKey = iter.next
- if (elementKey < limit) {
- // element key outside of window. Retract values
- val elementsRemove = rowMapState.get(elementKey)
- var iRemove = 0
- while (iRemove < elementsRemove.size()) {
- val retractRow = elementsRemove.get(iRemove)
- function.retract(accumulators, retractRow)
- iRemove += 1
- }
- // mark element for later removal not to modify the iterator over MapState
- markToRemove.add(elementKey)
- }
- }
- // need to remove in 2 steps not to have concurrent access errors via iterator to the MapState
- i = 0
- while (i < markToRemove.size()) {
- rowMapState.remove(markToRemove.get(i))
- i += 1
- }
-
- // get the list of elements of current proctime
- val currentElements = rowMapState.get(currentTime)
- // add current elements to aggregator. Multiple elements might have arrived in the same proctime
- // the same accumulator value will be computed for all elements
- var iElemenets = 0
- while (iElemenets < currentElements.size()) {
- val input = currentElements.get(iElemenets)
- function.accumulate(accumulators, input)
- iElemenets += 1
- }
-
- // we need to build the output and emit for every event received at this proctime
- iElemenets = 0
- while (iElemenets < currentElements.size()) {
- val input = currentElements.get(iElemenets)
-
- // set the fields of the last event to carry on with the aggregates
- function.setForwardedFields(input, output)
-
- // add the accumulators values to result
- function.setAggregationResults(accumulators, output)
- out.collect(output)
- iElemenets += 1
- }
-
- // update the value of accumulators for future incremental computation
- accumulatorState.update(accumulators)
-
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRowProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRowProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRowProcessFunction.scala
deleted file mode 100644
index d5ee4ae..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/BoundedProcessingOverRowProcessFunction.scala
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import java.util
-
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.functions.ProcessFunction
-import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
-import org.apache.flink.api.common.state.ValueStateDescriptor
-import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.api.common.state.ValueState
-import org.apache.flink.api.common.state.MapState
-import org.apache.flink.api.common.state.MapStateDescriptor
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.java.typeutils.ListTypeInfo
-import java.util.{List => JList}
-
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo
-import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
-import org.slf4j.LoggerFactory
-
-/**
- * Process Function for ROW clause processing-time bounded OVER window
- *
- * @param genAggregations Generated aggregate helper function
- * @param precedingOffset preceding offset
- * @param aggregatesTypeInfo row type info of aggregation
- * @param inputType row type info of input row
- */
-class BoundedProcessingOverRowProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- precedingOffset: Long,
- aggregatesTypeInfo: RowTypeInfo,
- inputType: TypeInformation[Row])
- extends ProcessFunction[Row, Row]
- with Compiler[GeneratedAggregations] {
-
- Preconditions.checkArgument(precedingOffset > 0)
-
- private var accumulatorState: ValueState[Row] = _
- private var rowMapState: MapState[Long, JList[Row]] = _
- private var output: Row = _
- private var counterState: ValueState[Long] = _
- private var smallestTsState: ValueState[Long] = _
-
- val LOG = LoggerFactory.getLogger(this.getClass)
- private var function: GeneratedAggregations = _
-
- override def open(config: Configuration) {
- LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
- s"Code:\n$genAggregations.code")
- val clazz = compile(
- getRuntimeContext.getUserCodeClassLoader,
- genAggregations.name,
- genAggregations.code)
- LOG.debug("Instantiating AggregateHelper.")
- function = clazz.newInstance()
-
- output = function.createOutputRow()
- // We keep the elements received in a Map state keyed
- // by the ingestion time in the operator.
- // we also keep counter of processed elements
- // and timestamp of oldest element
- val rowListTypeInfo: TypeInformation[JList[Row]] =
- new ListTypeInfo[Row](inputType).asInstanceOf[TypeInformation[JList[Row]]]
-
- val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
- new MapStateDescriptor[Long, JList[Row]]("windowBufferMapState",
- BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo)
- rowMapState = getRuntimeContext.getMapState(mapStateDescriptor)
-
- val aggregationStateDescriptor: ValueStateDescriptor[Row] =
- new ValueStateDescriptor[Row]("aggregationState", aggregatesTypeInfo)
- accumulatorState = getRuntimeContext.getState(aggregationStateDescriptor)
-
- val processedCountDescriptor : ValueStateDescriptor[Long] =
- new ValueStateDescriptor[Long]("processedCountState", classOf[Long])
- counterState = getRuntimeContext.getState(processedCountDescriptor)
-
- val smallestTimestampDescriptor : ValueStateDescriptor[Long] =
- new ValueStateDescriptor[Long]("smallestTSState", classOf[Long])
- smallestTsState = getRuntimeContext.getState(smallestTimestampDescriptor)
- }
-
- override def processElement(
- input: Row,
- ctx: ProcessFunction[Row, Row]#Context,
- out: Collector[Row]): Unit = {
-
- val currentTime = ctx.timerService.currentProcessingTime
-
- // initialize state for the processed element
- var accumulators = accumulatorState.value
- if (accumulators == null) {
- accumulators = function.createAccumulators()
- }
-
- // get smallest timestamp
- var smallestTs = smallestTsState.value
- if (smallestTs == 0L) {
- smallestTs = currentTime
- smallestTsState.update(smallestTs)
- }
- // get previous counter value
- var counter = counterState.value
-
- if (counter == precedingOffset) {
- val retractList = rowMapState.get(smallestTs)
-
- // get oldest element beyond buffer size
- // and if oldest element exist, retract value
- val retractRow = retractList.get(0)
- function.retract(accumulators, retractRow)
- retractList.remove(0)
-
- // if reference timestamp list not empty, keep the list
- if (!retractList.isEmpty) {
- rowMapState.put(smallestTs, retractList)
- } // if smallest timestamp list is empty, remove and find new smallest
- else {
- rowMapState.remove(smallestTs)
- val iter = rowMapState.keys.iterator
- var currentTs: Long = 0L
- var newSmallestTs: Long = Long.MaxValue
- while (iter.hasNext) {
- currentTs = iter.next
- if (currentTs < newSmallestTs) {
- newSmallestTs = currentTs
- }
- }
- smallestTsState.update(newSmallestTs)
- }
- } // we update the counter only while buffer is getting filled
- else {
- counter += 1
- counterState.update(counter)
- }
-
- // copy forwarded fields in output row
- function.setForwardedFields(input, output)
-
- // accumulate current row and set aggregate in output row
- function.accumulate(accumulators, input)
- function.setAggregationResults(accumulators, output)
-
- // update map state, accumulator state, counter and timestamp
- val currentTimeState = rowMapState.get(currentTime)
- if (currentTimeState != null) {
- currentTimeState.add(input)
- rowMapState.put(currentTime, currentTimeState)
- } else { // add new input
- val newList = new util.ArrayList[Row]
- newList.add(input)
- rowMapState.put(currentTime, newList)
- }
-
- accumulatorState.update(accumulators)
-
- out.collect(output)
- }
-
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
new file mode 100644
index 0000000..7f87e50
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRangeOver.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
+import org.apache.flink.api.common.state.ValueState
+import org.apache.flink.api.common.state.ValueStateDescriptor
+import org.apache.flink.api.common.state.MapState
+import org.apache.flink.api.common.state.MapStateDescriptor
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.ListTypeInfo
+import java.util.{ArrayList, List => JList}
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
+import org.slf4j.LoggerFactory
+
+/**
+ * Process Function used for the aggregate in bounded proc-time OVER window
+ * [[org.apache.flink.streaming.api.datastream.DataStream]]
+ *
+ * @param genAggregations Generated aggregate helper function
+ * @param precedingTimeBoundary Preceding time boundary of the window, in milliseconds
+ * @param aggregatesTypeInfo row type info of aggregation
+ * @param inputType row type info of input row
+ */
+class ProcTimeBoundedRangeOver(
+ genAggregations: GeneratedAggregationsFunction,
+ precedingTimeBoundary: Long,
+ aggregatesTypeInfo: RowTypeInfo,
+ inputType: TypeInformation[Row])
+ extends ProcessFunction[Row, Row]
+ with Compiler[GeneratedAggregations] {
+ private var output: Row = _
+ private var accumulatorState: ValueState[Row] = _
+ private var rowMapState: MapState[Long, JList[Row]] = _
+
+ val LOG = LoggerFactory.getLogger(this.getClass)
+ private var function: GeneratedAggregations = _
+
+ override def open(config: Configuration) {
+ LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+ s"Code:\n$genAggregations.code")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader,
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+ output = function.createOutputRow()
+
+ // We keep the elements received in a MapState indexed by their ingestion time
+ val rowListTypeInfo: TypeInformation[JList[Row]] =
+ new ListTypeInfo[Row](inputType).asInstanceOf[TypeInformation[JList[Row]]]
+ val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+ new MapStateDescriptor[Long, JList[Row]]("rowmapstate",
+ BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo)
+ rowMapState = getRuntimeContext.getMapState(mapStateDescriptor)
+
+ val stateDescriptor: ValueStateDescriptor[Row] =
+ new ValueStateDescriptor[Row]("overState", aggregatesTypeInfo)
+ accumulatorState = getRuntimeContext.getState(stateDescriptor)
+ }
+
+ override def processElement(
+ input: Row,
+ ctx: ProcessFunction[Row, Row]#Context,
+ out: Collector[Row]): Unit = {
+
+ val currentTime = ctx.timerService.currentProcessingTime
+ // buffer the incoming event
+
+ // add current element to the window list of elements with corresponding timestamp
+ var rowList = rowMapState.get(currentTime)
+ // a null value means that this is the first event received for this timestamp
+ if (rowList == null) {
+ rowList = new ArrayList[Row]()
+ // register timer to process event once the current millisecond passed
+ ctx.timerService.registerProcessingTimeTimer(currentTime + 1)
+ }
+ rowList.add(input)
+ rowMapState.put(currentTime, rowList)
+
+ }
+
+ override def onTimer(
+ timestamp: Long,
+ ctx: ProcessFunction[Row, Row]#OnTimerContext,
+ out: Collector[Row]): Unit = {
+
+ // the timer fires 1 ms after the rows' original processing time, so look them up at timestamp - 1
+ val currentTime = timestamp - 1
+ var i = 0
+
+ // initialize the accumulators
+ var accumulators = accumulatorState.value()
+
+ if (null == accumulators) {
+ accumulators = function.createAccumulators()
+ }
+
+ // update the elements to be removed and retract them from aggregators
+ val limit = currentTime - precedingTimeBoundary
+
+ // we iterate through all elements in the window buffer based on timestamp keys
+ // when we find timestamps that are out of interest, we retrieve corresponding elements
+ // and eliminate them. Multiple elements could have been received at the same timestamp
+ // the removal of old elements happens only once per proctime as onTimer is called only once
+ val iter = rowMapState.keys.iterator
+ val markToRemove = new ArrayList[Long]()
+ while (iter.hasNext) {
+ val elementKey = iter.next
+ if (elementKey < limit) {
+ // element key outside of window. Retract values
+ val elementsRemove = rowMapState.get(elementKey)
+ var iRemove = 0
+ while (iRemove < elementsRemove.size()) {
+ val retractRow = elementsRemove.get(iRemove)
+ function.retract(accumulators, retractRow)
+ iRemove += 1
+ }
+ // mark the element for later removal so we do not modify the MapState while iterating
+ markToRemove.add(elementKey)
+ }
+ }
+ // remove in two steps to avoid concurrent-access errors through the MapState iterator
+ i = 0
+ while (i < markToRemove.size()) {
+ rowMapState.remove(markToRemove.get(i))
+ i += 1
+ }
+
+ // get the list of elements of current proctime
+ val currentElements = rowMapState.get(currentTime)
+ // add the elements of the current proctime to the aggregates. Multiple elements
+ // might have arrived at the same proctime; they all share the same accumulator value
+ var iElements = 0
+ while (iElements < currentElements.size()) {
+ val input = currentElements.get(iElements)
+ function.accumulate(accumulators, input)
+ iElements += 1
+ }
+
+ // build and emit the output for every event received at this proctime
+ iElements = 0
+ while (iElements < currentElements.size()) {
+ val input = currentElements.get(iElements)
+
+ // copy the forwarded fields of the current event into the output row
+ function.setForwardedFields(input, output)
+
+ // add the accumulator values to the result
+ function.setAggregationResults(accumulators, output)
+ out.collect(output)
+ iElements += 1
+ }
+
+ // update the value of accumulators for future incremental computation
+ accumulatorState.update(accumulators)
+
+ }
+
+}
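
Two details of the function above are easy to miss: processElement only buffers, and the timer registered at currentTime + 1 means onTimer works on rows stamped timestamp - 1. A worked example of the boundary arithmetic, with illustrative numbers:

    object RangeBoundaryExample extends App {
      val timerTimestamp = 10001L            // onTimer fires here
      val precedingTimeBoundary = 4000L      // a 4 s RANGE preceding window
      val currentTime = timerTimestamp - 1   // rows were buffered at 10000
      val limit = currentTime - precedingTimeBoundary
      // every MapState key strictly below this limit leaves the window and
      // its rows are passed to function.retract(...) before accumulation
      println(limit)                         // 6000
    }
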
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
new file mode 100644
index 0000000..31cfd73
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedRowsOver.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.util
+
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.api.common.state.ValueStateDescriptor
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.common.state.ValueState
+import org.apache.flink.api.common.state.MapState
+import org.apache.flink.api.common.state.MapStateDescriptor
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.ListTypeInfo
+import java.util.{List => JList}
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
+import org.slf4j.LoggerFactory
+
+/**
+ * Process Function for ROWS clause processing-time bounded OVER window
+ *
+ * @param genAggregations Generated aggregate helper function
+ * @param precedingOffset preceding offset
+ * @param aggregatesTypeInfo row type info of aggregation
+ * @param inputType row type info of input row
+ */
+class ProcTimeBoundedRowsOver(
+ genAggregations: GeneratedAggregationsFunction,
+ precedingOffset: Long,
+ aggregatesTypeInfo: RowTypeInfo,
+ inputType: TypeInformation[Row])
+ extends ProcessFunction[Row, Row]
+ with Compiler[GeneratedAggregations] {
+
+ Preconditions.checkArgument(precedingOffset > 0)
+
+ private var accumulatorState: ValueState[Row] = _
+ private var rowMapState: MapState[Long, JList[Row]] = _
+ private var output: Row = _
+ private var counterState: ValueState[Long] = _
+ private var smallestTsState: ValueState[Long] = _
+
+ val LOG = LoggerFactory.getLogger(this.getClass)
+ private var function: GeneratedAggregations = _
+
+ override def open(config: Configuration) {
+ LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+ s"Code:\n$genAggregations.code")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader,
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+
+ output = function.createOutputRow()
+ // We keep the elements received in a Map state keyed
+ // by the ingestion time in the operator.
+ // we also keep a counter of processed elements
+ // and the timestamp of the oldest element
+ val rowListTypeInfo: TypeInformation[JList[Row]] =
+ new ListTypeInfo[Row](inputType).asInstanceOf[TypeInformation[JList[Row]]]
+
+ val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+ new MapStateDescriptor[Long, JList[Row]]("windowBufferMapState",
+ BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo)
+ rowMapState = getRuntimeContext.getMapState(mapStateDescriptor)
+
+ val aggregationStateDescriptor: ValueStateDescriptor[Row] =
+ new ValueStateDescriptor[Row]("aggregationState", aggregatesTypeInfo)
+ accumulatorState = getRuntimeContext.getState(aggregationStateDescriptor)
+
+ val processedCountDescriptor : ValueStateDescriptor[Long] =
+ new ValueStateDescriptor[Long]("processedCountState", classOf[Long])
+ counterState = getRuntimeContext.getState(processedCountDescriptor)
+
+ val smallestTimestampDescriptor : ValueStateDescriptor[Long] =
+ new ValueStateDescriptor[Long]("smallestTSState", classOf[Long])
+ smallestTsState = getRuntimeContext.getState(smallestTimestampDescriptor)
+ }
+
+ override def processElement(
+ input: Row,
+ ctx: ProcessFunction[Row, Row]#Context,
+ out: Collector[Row]): Unit = {
+
+ val currentTime = ctx.timerService.currentProcessingTime
+
+ // initialize state for the processed element
+ var accumulators = accumulatorState.value
+ if (accumulators == null) {
+ accumulators = function.createAccumulators()
+ }
+
+ // get smallest timestamp
+ var smallestTs = smallestTsState.value
+ if (smallestTs == 0L) {
+ smallestTs = currentTime
+ smallestTsState.update(smallestTs)
+ }
+ // get previous counter value
+ var counter = counterState.value
+
+ if (counter == precedingOffset) {
+ val retractList = rowMapState.get(smallestTs)
+
+ // get the oldest element beyond the buffer size
+ // and retract its value
+ val retractRow = retractList.get(0)
+ function.retract(accumulators, retractRow)
+ retractList.remove(0)
+
+ // if reference timestamp list not empty, keep the list
+ if (!retractList.isEmpty) {
+ rowMapState.put(smallestTs, retractList)
+ } // if smallest timestamp list is empty, remove and find new smallest
+ else {
+ rowMapState.remove(smallestTs)
+ val iter = rowMapState.keys.iterator
+ var currentTs: Long = 0L
+ var newSmallestTs: Long = Long.MaxValue
+ while (iter.hasNext) {
+ currentTs = iter.next
+ if (currentTs < newSmallestTs) {
+ newSmallestTs = currentTs
+ }
+ }
+ smallestTsState.update(newSmallestTs)
+ }
+ } // we update the counter only while the buffer is filling up
+ else {
+ counter += 1
+ counterState.update(counter)
+ }
+
+ // copy forwarded fields in output row
+ function.setForwardedFields(input, output)
+
+ // accumulate current row and set aggregate in output row
+ function.accumulate(accumulators, input)
+ function.setAggregationResults(accumulators, output)
+
+ // update map state, accumulator state, counter and timestamp
+ val currentTimeState = rowMapState.get(currentTime)
+ if (currentTimeState != null) {
+ currentTimeState.add(input)
+ rowMapState.put(currentTime, currentTimeState)
+ } else { // add new input
+ val newList = new util.ArrayList[Row]
+ newList.add(input)
+ rowMapState.put(currentTime, newList)
+ }
+
+ accumulatorState.update(accumulators)
+
+ out.collect(output)
+ }
+
+}
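
The counter and smallest-timestamp bookkeeping above keep exactly precedingOffset rows in state, retracting the oldest row once the buffer is full; recall that DataStreamOverAggregate adds 1 to the lower boundary for ROWS, so the current row is counted. The same sliding behaviour on a plain queue, as an illustrative sketch with a concrete sum aggregate:

    import scala.collection.mutable

    // ROWS window of size precedingOffset (current row included),
    // mirroring the retract-then-accumulate order of the ProcessFunction.
    class RowsWindowSum(precedingOffset: Int) {
      require(precedingOffset > 0)
      private val buffer = mutable.Queue.empty[Long]
      private var sum = 0L

      def accumulate(value: Long): Long = {
        if (buffer.size == precedingOffset) {
          sum -= buffer.dequeue()   // retract the oldest buffered row
        }
        buffer.enqueue(value)
        sum += value
        sum                         // aggregate emitted with the current row
      }
    }
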
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala
new file mode 100644
index 0000000..6b9800b
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedNonPartitionedOver.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import org.apache.flink.api.common.state.{ListState, ListStateDescriptor}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
+import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
+import org.slf4j.LoggerFactory
+
+/**
+ * Process Function for non-partitioned processing-time unbounded OVER window
+ *
+ * @param genAggregations Generated aggregate helper function
+ * @param aggregationStateType row type info of aggregation
+ */
+class ProcTimeUnboundedNonPartitionedOver(
+ genAggregations: GeneratedAggregationsFunction,
+ aggregationStateType: RowTypeInfo)
+ extends ProcessFunction[Row, Row]
+ with CheckpointedFunction
+ with Compiler[GeneratedAggregations] {
+
+ private var accumulators: Row = _
+ private var output: Row = _
+ private var state: ListState[Row] = _
+ val LOG = LoggerFactory.getLogger(this.getClass)
+
+ private var function: GeneratedAggregations = _
+
+ override def open(config: Configuration) {
+ LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+ s"Code:\n$genAggregations.code")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader,
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+
+ output = function.createOutputRow()
+ if (null == accumulators) {
+ val it = state.get().iterator()
+ if (it.hasNext) {
+ accumulators = it.next()
+ } else {
+ accumulators = function.createAccumulators()
+ }
+ }
+ }
+
+ override def processElement(
+ input: Row,
+ ctx: ProcessFunction[Row, Row]#Context,
+ out: Collector[Row]): Unit = {
+
+ function.setForwardedFields(input, output)
+
+ function.accumulate(accumulators, input)
+ function.setAggregationResults(accumulators, output)
+
+ out.collect(output)
+ }
+
+ override def snapshotState(context: FunctionSnapshotContext): Unit = {
+ state.clear()
+ if (null != accumulators) {
+ state.add(accumulators)
+ }
+ }
+
+ override def initializeState(context: FunctionInitializationContext): Unit = {
+ val accumulatorsDescriptor = new ListStateDescriptor[Row]("overState", aggregationStateType)
+ state = context.getOperatorStateStore.getOperatorState(accumulatorsDescriptor)
+ }
+}
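
Because the non-partitioned variant runs on an unkeyed stream, it cannot use keyed state; the accumulator row lives in a field and is saved through the CheckpointedFunction hooks above. A minimal sketch of that snapshot/restore pattern with a plain counter (class and state names are illustrative):

    import org.apache.flink.api.common.functions.RichMapFunction
    import org.apache.flink.api.common.state.{ListState, ListStateDescriptor}
    import org.apache.flink.runtime.state.{FunctionInitializationContext, FunctionSnapshotContext}
    import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction

    class CountingMap extends RichMapFunction[String, Long] with CheckpointedFunction {
      @transient private var checkpointed: ListState[java.lang.Long] = _
      private var count: Long = 0L

      override def map(value: String): Long = { count += 1; count }

      override def snapshotState(ctx: FunctionSnapshotContext): Unit = {
        checkpointed.clear()          // operator state holds a single element
        checkpointed.add(count)
      }

      override def initializeState(ctx: FunctionInitializationContext): Unit = {
        val desc = new ListStateDescriptor[java.lang.Long]("countState", classOf[java.lang.Long])
        checkpointed = ctx.getOperatorStateStore.getOperatorState(desc)
        val it = checkpointed.get().iterator()
        if (it.hasNext) count = it.next()   // restore after a failover
      }
    }
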
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala
new file mode 100644
index 0000000..9baa6a3
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeUnboundedPartitionedOver.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.types.Row
+import org.apache.flink.util.Collector
+import org.apache.flink.api.common.state.ValueStateDescriptor
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.common.state.ValueState
+import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
+import org.slf4j.LoggerFactory
+
+/**
+ * Process Function for partitioned processing-time unbounded OVER window
+ *
+ * @param genAggregations Generated aggregate helper function
+ * @param aggregationStateType row type info of aggregation
+ */
+class ProcTimeUnboundedPartitionedOver(
+ genAggregations: GeneratedAggregationsFunction,
+ aggregationStateType: RowTypeInfo)
+ extends ProcessFunction[Row, Row]
+ with Compiler[GeneratedAggregations] {
+
+ private var output: Row = _
+ private var state: ValueState[Row] = _
+ val LOG = LoggerFactory.getLogger(this.getClass)
+ private var function: GeneratedAggregations = _
+
+ override def open(config: Configuration) {
+ LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+ s"Code:\n$genAggregations.code")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader,
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+
+ output = function.createOutputRow()
+ val stateDescriptor: ValueStateDescriptor[Row] =
+ new ValueStateDescriptor[Row]("overState", aggregationStateType)
+ state = getRuntimeContext.getState(stateDescriptor)
+ }
+
+ override def processElement(
+ input: Row,
+ ctx: ProcessFunction[Row, Row]#Context,
+ out: Collector[Row]): Unit = {
+
+ var accumulators = state.value()
+
+ if (null == accumulators) {
+ accumulators = function.createAccumulators()
+ }
+
+ function.setForwardedFields(input, output)
+
+ function.accumulate(accumulators, input)
+ function.setAggregationResults(accumulators, output)
+
+ state.update(accumulators)
+
+ out.collect(output)
+ }
+
+}
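
For the partitioned variant the accumulator row sits in per-key ValueState instead, created lazily on the first row of each key. The same pattern with a concrete running-sum aggregate in place of the generated helper (a sketch; names and types are illustrative):

    import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
    import org.apache.flink.configuration.Configuration
    import org.apache.flink.streaming.api.functions.ProcessFunction
    import org.apache.flink.util.Collector

    class RunningSum extends ProcessFunction[(String, Long), (String, Long)] {
      @transient private var sumState: ValueState[java.lang.Long] = _

      override def open(config: Configuration): Unit = {
        sumState = getRuntimeContext.getState(
          new ValueStateDescriptor[java.lang.Long]("sum", classOf[java.lang.Long]))
      }

      override def processElement(
          in: (String, Long),
          ctx: ProcessFunction[(String, Long), (String, Long)]#Context,
          out: Collector[(String, Long)]): Unit = {
        val prev = sumState.value()   // null for the first row of a key
        val sum = (if (prev == null) 0L else prev.longValue()) + in._2
        sumState.update(sum)          // state persists across rows of the key
        out.collect((in._1, sum))     // one output per input, like the OVER window
      }
    }
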
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RangeClauseBoundedOverProcessFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RangeClauseBoundedOverProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RangeClauseBoundedOverProcessFunction.scala
deleted file mode 100644
index 0f1ef49..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RangeClauseBoundedOverProcessFunction.scala
+++ /dev/null
@@ -1,201 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.runtime.aggregate
-
-import java.util.{List => JList, ArrayList => JArrayList}
-
-import org.apache.flink.api.common.state._
-import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.streaming.api.functions.ProcessFunction
-import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
-import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
-import org.slf4j.LoggerFactory
-
-/**
- * Process Function for RANGE clause event-time bounded OVER window
- *
- * @param genAggregations Generated aggregate helper function
- * @param aggregationStateType row type info of aggregation
- * @param inputRowType row type info of input row
- * @param precedingOffset preceding offset
- */
-class RangeClauseBoundedOverProcessFunction(
- genAggregations: GeneratedAggregationsFunction,
- aggregationStateType: RowTypeInfo,
- inputRowType: RowTypeInfo,
- precedingOffset: Long)
- extends ProcessFunction[Row, Row]
- with Compiler[GeneratedAggregations] {
-
- Preconditions.checkNotNull(aggregationStateType)
- Preconditions.checkNotNull(precedingOffset)
-
- private var output: Row = _
-
- // the state which keeps the last triggering timestamp
- private var lastTriggeringTsState: ValueState[Long] = _
-
- // the state which used to materialize the accumulator for incremental calculation
- private var accumulatorState: ValueState[Row] = _
-
- // the state which keeps all the data that are not expired.
- // The first element (as the mapState key) of the tuple is the time stamp. Per each time stamp,
- // the second element of tuple is a list that contains the entire data of all the rows belonging
- // to this time stamp.
- private var dataState: MapState[Long, JList[Row]] = _
-
- val LOG = LoggerFactory.getLogger(this.getClass)
- private var function: GeneratedAggregations = _
-
- override def open(config: Configuration) {
- LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
- s"Code:\n$genAggregations.code")
- val clazz = compile(
- getRuntimeContext.getUserCodeClassLoader,
- genAggregations.name,
- genAggregations.code)
- LOG.debug("Instantiating AggregateHelper.")
- function = clazz.newInstance()
-
- output = function.createOutputRow()
-
- val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] =
- new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long])
- lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)
-
- val accumulatorStateDescriptor =
- new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType)
- accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor)
-
- val keyTypeInformation: TypeInformation[Long] =
- BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
- val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType)
-
- val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
- new MapStateDescriptor[Long, JList[Row]](
- "dataState",
- keyTypeInformation,
- valueTypeInformation)
-
- dataState = getRuntimeContext.getMapState(mapStateDescriptor)
- }
-
- override def processElement(
- input: Row,
- ctx: ProcessFunction[Row, Row]#Context,
- out: Collector[Row]): Unit = {
-
- // triggering timestamp for trigger calculation
- val triggeringTs = ctx.timestamp
-
- val lastTriggeringTs = lastTriggeringTsState.value
-
- // check if the data is expired, if not, save the data and register event time timer
- if (triggeringTs > lastTriggeringTs) {
- val data = dataState.get(triggeringTs)
- if (null != data) {
- data.add(input)
- dataState.put(triggeringTs, data)
- } else {
- val data = new JArrayList[Row]
- data.add(input)
- dataState.put(triggeringTs, data)
- // register event time timer
- ctx.timerService.registerEventTimeTimer(triggeringTs)
- }
- }
- }
-
- override def onTimer(
- timestamp: Long,
- ctx: ProcessFunction[Row, Row]#OnTimerContext,
- out: Collector[Row]): Unit = {
- // gets all window data from state for the calculation
- val inputs: JList[Row] = dataState.get(timestamp)
-
- if (null != inputs) {
-
- var accumulators = accumulatorState.value
- var dataListIndex = 0
- var aggregatesIndex = 0
-
- // initialize when first run or failover recovery per key
- if (null == accumulators) {
- accumulators = function.createAccumulators()
- aggregatesIndex = 0
- }
-
- // keep up timestamps of retract data
- val retractTsList: JList[Long] = new JArrayList[Long]
-
- // do retraction
- val dataTimestampIt = dataState.keys.iterator
- while (dataTimestampIt.hasNext) {
- val dataTs: Long = dataTimestampIt.next()
- val offset = timestamp - dataTs
- if (offset > precedingOffset) {
- val retractDataList = dataState.get(dataTs)
- dataListIndex = 0
- while (dataListIndex < retractDataList.size()) {
- val retractRow = retractDataList.get(dataListIndex)
- function.retract(accumulators, retractRow)
- dataListIndex += 1
- }
- retractTsList.add(dataTs)
- }
- }
-
- // do accumulation
- dataListIndex = 0
- while (dataListIndex < inputs.size()) {
- val curRow = inputs.get(dataListIndex)
- // accumulate current row
- function.accumulate(accumulators, curRow)
- dataListIndex += 1
- }
-
- // set aggregate in output row
- function.setAggregationResults(accumulators, output)
-
- // copy forwarded fields to output row and emit output row
- dataListIndex = 0
- while (dataListIndex < inputs.size()) {
- aggregatesIndex = 0
- function.setForwardedFields(inputs.get(dataListIndex), output)
- out.collect(output)
- dataListIndex += 1
- }
-
- // remove the data that has been retracted
- dataListIndex = 0
- while (dataListIndex < retractTsList.size) {
- dataState.remove(retractTsList.get(dataListIndex))
- dataListIndex += 1
- }
-
- // update state
- accumulatorState.update(accumulators)
- lastTriggeringTsState.update(timestamp)
- }
- }
-}
-
-
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala
new file mode 100644
index 0000000..03ca02c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRangeOver.scala
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.util.{List => JList, ArrayList => JArrayList}
+
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+import org.slf4j.LoggerFactory
+
+/**
+ * Process Function for RANGE clause event-time bounded OVER window
+ *
+ * @param genAggregations Generated aggregate helper function
+ * @param aggregationStateType row type info of aggregation
+ * @param inputRowType row type info of input row
+ * @param precedingOffset preceding offset
+ */
+class RowTimeBoundedRangeOver(
+ genAggregations: GeneratedAggregationsFunction,
+ aggregationStateType: RowTypeInfo,
+ inputRowType: RowTypeInfo,
+ precedingOffset: Long)
+ extends ProcessFunction[Row, Row]
+ with Compiler[GeneratedAggregations] {
+ Preconditions.checkNotNull(aggregationStateType)
+ Preconditions.checkNotNull(precedingOffset)
+
+ private var output: Row = _
+
+ // the state which keeps the last triggering timestamp
+ private var lastTriggeringTsState: ValueState[Long] = _
+
+ // the state which is used to materialize the accumulator for incremental calculation
+ private var accumulatorState: ValueState[Row] = _
+
+ // the state which keeps all the data that is not expired.
+ // The MapState key is the timestamp; its value is a list that
+ // contains all the rows belonging to that timestamp.
+ private var dataState: MapState[Long, JList[Row]] = _
+
+ val LOG = LoggerFactory.getLogger(this.getClass)
+ private var function: GeneratedAggregations = _
+
+ override def open(config: Configuration) {
+ LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+ s"Code:\n$genAggregations.code")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader,
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+
+ output = function.createOutputRow()
+
+ val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] =
+ new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long])
+ lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)
+
+ val accumulatorStateDescriptor =
+ new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType)
+ accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor)
+
+ val keyTypeInformation: TypeInformation[Long] =
+ BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
+ val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType)
+
+ val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+ new MapStateDescriptor[Long, JList[Row]](
+ "dataState",
+ keyTypeInformation,
+ valueTypeInformation)
+
+ dataState = getRuntimeContext.getMapState(mapStateDescriptor)
+ }
+
+ override def processElement(
+ input: Row,
+ ctx: ProcessFunction[Row, Row]#Context,
+ out: Collector[Row]): Unit = {
+
+ // triggering timestamp for trigger calculation
+ val triggeringTs = ctx.timestamp
+
+ val lastTriggeringTs = lastTriggeringTsState.value
+
+ // check if the data is expired; if not, save the data and register an event-time timer
+ if (triggeringTs > lastTriggeringTs) {
+ val data = dataState.get(triggeringTs)
+ if (null != data) {
+ data.add(input)
+ dataState.put(triggeringTs, data)
+ } else {
+ val data = new JArrayList[Row]
+ data.add(input)
+ dataState.put(triggeringTs, data)
+ // register event time timer
+ ctx.timerService.registerEventTimeTimer(triggeringTs)
+ }
+ }
+ }
+
+ override def onTimer(
+ timestamp: Long,
+ ctx: ProcessFunction[Row, Row]#OnTimerContext,
+ out: Collector[Row]): Unit = {
+ // gets all window data from state for the calculation
+ val inputs: JList[Row] = dataState.get(timestamp)
+
+ if (null != inputs) {
+
+ var accumulators = accumulatorState.value
+ var dataListIndex = 0
+ var aggregatesIndex = 0
+
+ // initialize on the first run or after failover recovery, per key
+ if (null == accumulators) {
+ accumulators = function.createAccumulators()
+ aggregatesIndex = 0
+ }
+
+ // keep track of the timestamps of rows to retract
+ val retractTsList: JList[Long] = new JArrayList[Long]
+
+ // do retraction
+ val dataTimestampIt = dataState.keys.iterator
+ while (dataTimestampIt.hasNext) {
+ val dataTs: Long = dataTimestampIt.next()
+ val offset = timestamp - dataTs
+ if (offset > precedingOffset) {
+ val retractDataList = dataState.get(dataTs)
+ dataListIndex = 0
+ while (dataListIndex < retractDataList.size()) {
+ val retractRow = retractDataList.get(dataListIndex)
+ function.retract(accumulators, retractRow)
+ dataListIndex += 1
+ }
+ retractTsList.add(dataTs)
+ }
+ }
+
+ // do accumulation
+ dataListIndex = 0
+ while (dataListIndex < inputs.size()) {
+ val curRow = inputs.get(dataListIndex)
+ // accumulate current row
+ function.accumulate(accumulators, curRow)
+ dataListIndex += 1
+ }
+
+ // set aggregate in output row
+ function.setAggregationResults(accumulators, output)
+
+ // copy forwarded fields to output row and emit output row
+ dataListIndex = 0
+ while (dataListIndex < inputs.size()) {
+ aggregatesIndex = 0
+ function.setForwardedFields(inputs.get(dataListIndex), output)
+ out.collect(output)
+ dataListIndex += 1
+ }
+
+ // remove the data that has been retracted
+ dataListIndex = 0
+ while (dataListIndex < retractTsList.size) {
+ dataState.remove(retractTsList.get(dataListIndex))
+ dataListIndex += 1
+ }
+
+ // update state
+ accumulatorState.update(accumulators)
+ lastTriggeringTsState.update(timestamp)
+ }
+ }
+}
+
+
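
The retract/accumulate bookkeeping that this ProcessFunction performs on each timer firing can be sketched over plain collections. The following is a simplified, self-contained sketch with a hard-coded SUM aggregate; SumAcc, addRow, fireTimer, and the mutable map standing in for MapState are hypothetical names for illustration, not Flink API:

import scala.collection.mutable

object RangeOverSketch {
  // stand-in for the generated aggregations: a single SUM with retraction
  final case class SumAcc(var sum: Long = 0L) {
    def accumulate(v: Long): Unit = sum += v
    def retract(v: Long): Unit = sum -= v
  }

  val precedingOffset = 10L // RANGE BETWEEN 10 PRECEDING AND CURRENT ROW
  val acc = SumAcc()
  // timestamp -> values arriving at that timestamp (stand-in for MapState[Long, JList[Row]])
  val dataState = mutable.Map.empty[Long, mutable.Buffer[Long]]

  // mirror of processElement: buffer the row under its timestamp
  def addRow(ts: Long, value: Long): Unit =
    dataState.getOrElseUpdate(ts, mutable.Buffer.empty) += value

  // mirror of onTimer: retract expired rows, accumulate the triggering rows,
  // and emit one (value, aggregate) pair per row at the triggering timestamp
  def fireTimer(timestamp: Long): Seq[(Long, Long)] = {
    val expired = dataState.keys.filter(ts => timestamp - ts > precedingOffset).toList
    expired.foreach { ts =>
      dataState(ts).foreach(acc.retract)
      dataState.remove(ts)
    }
    val inputs = dataState.getOrElse(timestamp, mutable.Buffer.empty[Long])
    inputs.foreach(acc.accumulate)
    inputs.map(v => (v, acc.sum)).toSeq
  }
}

Note that, as in the ProcessFunction above, rows at the triggering timestamp stay buffered after emission; they are only retracted once they fall out of the preceding range of a later timer.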
http://git-wip-us.apache.org/repos/asf/flink/blob/07f1b035/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
new file mode 100644
index 0000000..4a9a14c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/RowTimeBoundedRowsOver.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import java.util
+import java.util.{List => JList}
+
+import org.apache.flink.api.common.state._
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.{ListTypeInfo, RowTypeInfo}
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.table.codegen.{GeneratedAggregationsFunction, Compiler}
+import org.slf4j.LoggerFactory
+
+/**
+ * Process Function for ROWS clause event-time bounded OVER window
+ *
+ * @param genAggregations Generated aggregate helper function
+ * @param aggregationStateType row type info of aggregation
+ * @param inputRowType row type info of input row
+ * @param precedingOffset preceding offset
+ */
+class RowTimeBoundedRowsOver(
+ genAggregations: GeneratedAggregationsFunction,
+ aggregationStateType: RowTypeInfo,
+ inputRowType: RowTypeInfo,
+ precedingOffset: Long)
+ extends ProcessFunction[Row, Row]
+ with Compiler[GeneratedAggregations] {
+
+ Preconditions.checkNotNull(aggregationStateType)
+ Preconditions.checkNotNull(precedingOffset)
+
+ private var output: Row = _
+
+ // the state which keeps the last triggering timestamp
+ private var lastTriggeringTsState: ValueState[Long] = _
+
+ // the state which keeps the count of data
+ private var dataCountState: ValueState[Long] = _
+
+ // the state which is used to materialize the accumulator for incremental calculation
+ private var accumulatorState: ValueState[Row] = _
+
+ // the state which keeps all the data that is not expired.
+ // The map key is the timestamp. Per timestamp, the map value is a list that
+ // contains all the rows belonging to this timestamp.
+ private var dataState: MapState[Long, JList[Row]] = _
+
+ val LOG = LoggerFactory.getLogger(this.getClass)
+ private var function: GeneratedAggregations = _
+
+ override def open(config: Configuration) {
+ LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+ s"Code:\n$genAggregations.code")
+ val clazz = compile(
+ getRuntimeContext.getUserCodeClassLoader,
+ genAggregations.name,
+ genAggregations.code)
+ LOG.debug("Instantiating AggregateHelper.")
+ function = clazz.newInstance()
+
+ output = function.createOutputRow()
+
+ val lastTriggeringTsDescriptor: ValueStateDescriptor[Long] =
+ new ValueStateDescriptor[Long]("lastTriggeringTsState", classOf[Long])
+ lastTriggeringTsState = getRuntimeContext.getState(lastTriggeringTsDescriptor)
+
+ val dataCountStateDescriptor =
+ new ValueStateDescriptor[Long]("dataCountState", classOf[Long])
+ dataCountState = getRuntimeContext.getState(dataCountStateDescriptor)
+
+ val accumulatorStateDescriptor =
+ new ValueStateDescriptor[Row]("accumulatorState", aggregationStateType)
+ accumulatorState = getRuntimeContext.getState(accumulatorStateDescriptor)
+
+ val keyTypeInformation: TypeInformation[Long] =
+ BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]]
+ val valueTypeInformation: TypeInformation[JList[Row]] = new ListTypeInfo[Row](inputRowType)
+
+ val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] =
+ new MapStateDescriptor[Long, JList[Row]](
+ "dataState",
+ keyTypeInformation,
+ valueTypeInformation)
+
+ dataState = getRuntimeContext.getMapState(mapStateDescriptor)
+ }
+
+ override def processElement(
+ input: Row,
+ ctx: ProcessFunction[Row, Row]#Context,
+ out: Collector[Row]): Unit = {
+
+ // triggering timestamp for trigger calculation
+ val triggeringTs = ctx.timestamp
+
+ val lastTriggeringTs = lastTriggeringTsState.value
+
+ // check if the data is expired; if not, save the data and register an event-time timer
+ if (triggeringTs > lastTriggeringTs) {
+ val data = dataState.get(triggeringTs)
+ if (null != data) {
+ data.add(input)
+ dataState.put(triggeringTs, data)
+ } else {
+ val data = new util.ArrayList[Row]
+ data.add(input)
+ dataState.put(triggeringTs, data)
+ // register event time timer
+ ctx.timerService.registerEventTimeTimer(triggeringTs)
+ }
+ }
+ }
+
+ override def onTimer(
+ timestamp: Long,
+ ctx: ProcessFunction[Row, Row]#OnTimerContext,
+ out: Collector[Row]): Unit = {
+
+ // gets all window data from state for the calculation
+ val inputs: JList[Row] = dataState.get(timestamp)
+
+ if (null != inputs) {
+
+ var accumulators = accumulatorState.value
+ var dataCount = dataCountState.value
+
+ var retractList: JList[Row] = null
+ var retractTs: Long = Long.MaxValue
+ var retractCnt: Int = 0
+ var i = 0
+
+ while (i < inputs.size) {
+ val input = inputs.get(i)
+
+ // initialize on the first run or after failover recovery, per key
+ if (null == accumulators) {
+ accumulators = function.createAccumulators()
+ }
+
+ var retractRow: Row = null
+
+ if (dataCount >= precedingOffset) {
+ if (null == retractList) {
+ // find the smallest timestamp
+ retractTs = Long.MaxValue
+ val dataTimestampIt = dataState.keys.iterator
+ while (dataTimestampIt.hasNext) {
+ val dataTs = dataTimestampIt.next
+ if (dataTs < retractTs) {
+ retractTs = dataTs
+ }
+ }
+ // get the oldest rows to retract them
+ retractList = dataState.get(retractTs)
+ }
+
+ retractRow = retractList.get(retractCnt)
+ retractCnt += 1
+
+ // remove retracted values from state
+ if (retractList.size == retractCnt) {
+ dataState.remove(retractTs)
+ retractList = null
+ retractCnt = 0
+ }
+ } else {
+ dataCount += 1
+ }
+
+ // copy forwarded fields to output row
+ function.setForwardedFields(input, output)
+
+ // retract old row from accumulators
+ if (null != retractRow) {
+ function.retract(accumulators, retractRow)
+ }
+
+ // accumulate current row and set aggregate in output row
+ function.accumulate(accumulators, input)
+ function.setAggregationResults(accumulators, output)
+ i += 1
+
+ out.collect(output)
+ }
+
+ // update all states
+ if (dataState.contains(retractTs)) {
+ if (retractCnt > 0) {
+ retractList.subList(0, retractCnt).clear()
+ dataState.put(retractTs, retractList)
+ }
+ }
+ dataCountState.update(dataCount)
+ accumulatorState.update(accumulators)
+ }
+
+ lastTriggeringTsState.update(timestamp)
+ }
+}
+
+
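
The counting logic above amounts to a fixed-size sliding window over row arrival order: once the window is full, every new row retracts exactly one oldest row. Below is a minimal, self-contained sketch of that invariant, assuming precedingOffset denotes the total number of rows in the window (current row included) and hard-coding a SUM aggregate; RowsOverSketch, processRow, and the FIFO queue are hypothetical stand-ins, not Flink API:

import scala.collection.mutable

object RowsOverSketch {
  // window of 3 rows, e.g. ROWS BETWEEN 2 PRECEDING AND CURRENT ROW
  val precedingOffset = 3L
  private val buffered = mutable.Queue.empty[Long] // oldest row first
  private var sum = 0L
  private var dataCount = 0L

  def processRow(value: Long): Long = {
    if (dataCount >= precedingOffset) {
      sum -= buffered.dequeue() // retract exactly one oldest row
    } else {
      dataCount += 1 // window not yet full
    }
    sum += value // accumulate the current row
    buffered.enqueue(value)
    sum // aggregate emitted alongside the current row
  }
}

For example, Seq(1L, 2L, 3L, 4L).map(RowsOverSketch.processRow) yields Seq(1, 3, 6, 9): the fourth result drops the first row (2 + 3 + 4) because the three-row window is full.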
[4/5] flink git commit: [FLINK-5435] [table] Remove
FlinkAggregateJoinTransposeRule and FlinkRelDecorrelator after bumping
Calcite to v1.12.
Posted by fh...@apache.org.
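For context on what the removed class did: decorrelation rewrites a query that logically re-evaluates a subquery for every outer row into an equivalent set-oriented plan. The following toy sketch over plain Scala collections is purely illustrative of that idea; it is neither Calcite nor Flink API, and the class below performs the rewrite on RelNode trees via planner rules instead:

object DecorrelationSketch {
  final case class Order(customer: String, amount: Long)

  // correlated form: a per-row "subquery" over the same relation
  def correlated(orders: Seq[Order]): Seq[(Order, Long)] =
    orders.map { o =>
      val total = orders.filter(_.customer == o.customer).map(_.amount).sum
      (o, total)
    }

  // decorrelated form: evaluate the former subquery once per correlation
  // key, then join the result back to the outer relation
  def decorrelated(orders: Seq[Order]): Seq[(Order, Long)] = {
    val totals: Map[String, Long] =
      orders.groupBy(_.customer).map { case (c, os) => c -> os.map(_.amount).sum }
    orders.map(o => (o, totals(o.customer)))
  }
}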
http://git-wip-us.apache.org/repos/asf/flink/blob/c5173fa2/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/sql2rel/FlinkRelDecorrelator.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/sql2rel/FlinkRelDecorrelator.java b/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/sql2rel/FlinkRelDecorrelator.java
deleted file mode 100644
index 0179192..0000000
--- a/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/sql2rel/FlinkRelDecorrelator.java
+++ /dev/null
@@ -1,2216 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to you under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.calcite.sql2rel;
-
-import com.google.common.base.Supplier;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.ImmutableSortedMap;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
-import com.google.common.collect.Multimap;
-import com.google.common.collect.Multimaps;
-import com.google.common.collect.Sets;
-import com.google.common.collect.SortedSetMultimap;
-import org.apache.calcite.linq4j.Ord;
-import org.apache.calcite.linq4j.function.Function2;
-import org.apache.calcite.plan.Context;
-import org.apache.calcite.plan.RelOptCluster;
-import org.apache.calcite.plan.RelOptCostImpl;
-import org.apache.calcite.plan.RelOptRule;
-import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.plan.RelOptUtil;
-import org.apache.calcite.plan.hep.HepPlanner;
-import org.apache.calcite.plan.hep.HepProgram;
-import org.apache.calcite.plan.hep.HepRelVertex;
-import org.apache.calcite.rel.BiRel;
-import org.apache.calcite.rel.RelCollation;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.RelShuttleImpl;
-import org.apache.calcite.rel.core.Aggregate;
-import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.core.Correlate;
-import org.apache.calcite.rel.core.CorrelationId;
-import org.apache.calcite.rel.core.JoinRelType;
-import org.apache.calcite.rel.core.Project;
-import org.apache.calcite.rel.core.RelFactories;
-import org.apache.calcite.rel.core.Sort;
-import org.apache.calcite.rel.core.Values;
-import org.apache.calcite.rel.logical.LogicalAggregate;
-import org.apache.calcite.rel.logical.LogicalCorrelate;
-import org.apache.calcite.rel.logical.LogicalFilter;
-import org.apache.calcite.rel.logical.LogicalJoin;
-import org.apache.calcite.rel.logical.LogicalProject;
-import org.apache.calcite.rel.logical.LogicalSort;
-import org.apache.calcite.rel.metadata.RelMdUtil;
-import org.apache.calcite.rel.metadata.RelMetadataQuery;
-import org.apache.calcite.rel.rules.FilterCorrelateRule;
-import org.apache.calcite.rel.rules.FilterJoinRule;
-import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeFactory;
-import org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexCall;
-import org.apache.calcite.rex.RexCorrelVariable;
-import org.apache.calcite.rex.RexFieldAccess;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexLiteral;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.rex.RexShuttle;
-import org.apache.calcite.rex.RexSubQuery;
-import org.apache.calcite.rex.RexUtil;
-import org.apache.calcite.rex.RexVisitorImpl;
-import org.apache.calcite.sql.SqlExplainLevel;
-import org.apache.calcite.sql.SqlFunction;
-import org.apache.calcite.sql.SqlKind;
-import org.apache.calcite.sql.SqlOperator;
-import org.apache.calcite.sql.fun.SqlCountAggFunction;
-import org.apache.calcite.sql.fun.SqlSingleValueAggFunction;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.tools.RelBuilder;
-import org.apache.calcite.util.Bug;
-import org.apache.calcite.util.Holder;
-import org.apache.calcite.util.ImmutableBitSet;
-import org.apache.calcite.util.Litmus;
-import org.apache.calcite.util.Pair;
-import org.apache.calcite.util.ReflectUtil;
-import org.apache.calcite.util.ReflectiveVisitor;
-import org.apache.calcite.util.Util;
-import org.apache.calcite.util.mapping.Mappings;
-import org.apache.calcite.util.trace.CalciteTrace;
-import org.apache.flink.util.Preconditions;
-import org.slf4j.Logger;
-
-import java.math.BigDecimal;
-import java.util.ArrayDeque;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Deque;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.NavigableMap;
-import java.util.Objects;
-import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
-import java.util.TreeSet;
-
-/**
- * Copied from {@link org.apache.calcite.sql2rel.RelDecorrelator}; should be
- * removed once <a href="https://issues.apache.org/jira/browse/CALCITE-1543">[CALCITE-1543]</a> is fixed.
- */
-public class FlinkRelDecorrelator implements ReflectiveVisitor {
- //~ Static fields/initializers ---------------------------------------------
-
- private static final Logger SQL2REL_LOGGER = CalciteTrace.getSqlToRelTracer();
-
- //~ Instance fields --------------------------------------------------------
-
- private final RelBuilder relBuilder;
-
- // map built during translation
- private CorelMap cm;
-
- private final ReflectUtil.MethodDispatcher<Frame> dispatcher = ReflectUtil.createMethodDispatcher(Frame.class, this, "decorrelateRel", RelNode.class);
-
- private final RexBuilder rexBuilder;
-
- // The rel which is being visited
- private RelNode currentRel;
-
- private final Context context;
-
- /**
- * A map, built during decorrelation, of a rel to all the newly created
- * correlated variables in its output, and of old input positions to new
- * input positions. This is from the viewpoint of the parent rel of a new rel.
- */
- private final Map<RelNode, Frame> map = new HashMap<>();
-
- private final HashSet<LogicalCorrelate> generatedCorRels = Sets.newHashSet();
-
- //~ Constructors -----------------------------------------------------------
-
- private FlinkRelDecorrelator(RelOptCluster cluster, CorelMap cm, Context context) {
- this.cm = cm;
- this.rexBuilder = cluster.getRexBuilder();
- this.context = context;
- relBuilder = RelFactories.LOGICAL_BUILDER.create(cluster, null);
-
- }
-
- //~ Methods ----------------------------------------------------------------
-
- /**
- * Decorrelates a query.
- *
- * <p>This is the main entry point to {@code FlinkRelDecorrelator}.
- *
- * @param rootRel Root node of the query
- * @return Equivalent query with all
- * {@link LogicalCorrelate} instances removed
- */
- public static RelNode decorrelateQuery(RelNode rootRel) {
- final CorelMap corelMap = new CorelMapBuilder().build(rootRel);
- if (!corelMap.hasCorrelation()) {
- return rootRel;
- }
-
- final RelOptCluster cluster = rootRel.getCluster();
- final FlinkRelDecorrelator decorrelator = new FlinkRelDecorrelator(cluster, corelMap, cluster.getPlanner().getContext());
-
- RelNode newRootRel = decorrelator.removeCorrelationViaRule(rootRel);
-
- if (SQL2REL_LOGGER.isDebugEnabled()) {
- SQL2REL_LOGGER.debug(RelOptUtil.dumpPlan("Plan after removing Correlator", newRootRel, false, SqlExplainLevel.EXPPLAN_ATTRIBUTES));
- }
-
- if (!decorrelator.cm.mapCorVarToCorRel.isEmpty()) {
- newRootRel = decorrelator.decorrelate(newRootRel);
- }
-
- return newRootRel;
- }
-
- private void setCurrent(RelNode root, LogicalCorrelate corRel) {
- currentRel = corRel;
- if (corRel != null) {
- cm = new CorelMapBuilder().build(Util.first(root, corRel));
- }
- }
-
- private RelNode decorrelate(RelNode root) {
- // first adjust count() expression if any
- HepProgram program = HepProgram.builder().addRuleInstance(new AdjustProjectForCountAggregateRule(false)).addRuleInstance(new AdjustProjectForCountAggregateRule(true)).addRuleInstance(FilterJoinRule.FILTER_ON_JOIN).addRuleInstance(FilterProjectTransposeRule.INSTANCE).addRuleInstance(FilterCorrelateRule.INSTANCE).build();
-
- HepPlanner planner = createPlanner(program);
-
- planner.setRoot(root);
- root = planner.findBestExp();
-
- // Perform decorrelation.
- map.clear();
-
- final Frame frame = getInvoke(root, null);
- if (frame != null) {
- // has been rewritten; apply rules post-decorrelation
- final HepProgram program2 = HepProgram.builder().addRuleInstance(FilterJoinRule.FILTER_ON_JOIN).addRuleInstance(FilterJoinRule.JOIN).build();
-
- final HepPlanner planner2 = createPlanner(program2);
- final RelNode newRoot = frame.r;
- planner2.setRoot(newRoot);
- return planner2.findBestExp();
- }
-
- return root;
- }
-
- private Function2<RelNode, RelNode, Void> createCopyHook() {
- return new Function2<RelNode, RelNode, Void>() {
- public Void apply(RelNode oldNode, RelNode newNode) {
- if (cm.mapRefRelToCorVar.containsKey(oldNode)) {
- cm.mapRefRelToCorVar.putAll(newNode, cm.mapRefRelToCorVar.get(oldNode));
- }
- if (oldNode instanceof LogicalCorrelate && newNode instanceof LogicalCorrelate) {
- LogicalCorrelate oldCor = (LogicalCorrelate) oldNode;
- CorrelationId c = oldCor.getCorrelationId();
- if (cm.mapCorVarToCorRel.get(c) == oldNode) {
- cm.mapCorVarToCorRel.put(c, newNode);
- }
-
- if (generatedCorRels.contains(oldNode)) {
- generatedCorRels.add((LogicalCorrelate) newNode);
- }
- }
- return null;
- }
- };
- }
-
- private HepPlanner createPlanner(HepProgram program) {
- // Create a planner with a hook to update the mapping tables when a
- // node is copied when it is registered.
- return new HepPlanner(program, context, true, createCopyHook(), RelOptCostImpl.FACTORY);
- }
-
- public RelNode removeCorrelationViaRule(RelNode root) {
- HepProgram program = HepProgram.builder().addRuleInstance(new RemoveSingleAggregateRule()).addRuleInstance(new RemoveCorrelationForScalarProjectRule()).addRuleInstance(new RemoveCorrelationForScalarAggregateRule()).build();
-
- HepPlanner planner = createPlanner(program);
-
- planner.setRoot(root);
- return planner.findBestExp();
- }
-
- protected RexNode decorrelateExpr(RexNode exp) {
- DecorrelateRexShuttle shuttle = new DecorrelateRexShuttle();
- return exp.accept(shuttle);
- }
-
- protected RexNode removeCorrelationExpr(RexNode exp, boolean projectPulledAboveLeftCorrelator) {
- RemoveCorrelationRexShuttle shuttle = new RemoveCorrelationRexShuttle(rexBuilder, projectPulledAboveLeftCorrelator, null, ImmutableSet.<Integer>of());
- return exp.accept(shuttle);
- }
-
- protected RexNode removeCorrelationExpr(RexNode exp, boolean projectPulledAboveLeftCorrelator, RexInputRef nullIndicator) {
- RemoveCorrelationRexShuttle shuttle = new RemoveCorrelationRexShuttle(rexBuilder, projectPulledAboveLeftCorrelator, nullIndicator, ImmutableSet.<Integer>of());
- return exp.accept(shuttle);
- }
-
- protected RexNode removeCorrelationExpr(RexNode exp, boolean projectPulledAboveLeftCorrelator, Set<Integer> isCount) {
- RemoveCorrelationRexShuttle shuttle = new RemoveCorrelationRexShuttle(rexBuilder, projectPulledAboveLeftCorrelator, null, isCount);
- return exp.accept(shuttle);
- }
-
- /**
- * Fallback if none of the other {@code decorrelateRel} methods match.
- */
- public Frame decorrelateRel(RelNode rel) {
- RelNode newRel = rel.copy(rel.getTraitSet(), rel.getInputs());
-
- if (rel.getInputs().size() > 0) {
- List<RelNode> oldInputs = rel.getInputs();
- List<RelNode> newInputs = Lists.newArrayList();
- for (int i = 0; i < oldInputs.size(); ++i) {
- final Frame frame = getInvoke(oldInputs.get(i), rel);
- if (frame == null || !frame.corVarOutputPos.isEmpty()) {
- // if input is not rewritten, or if it produces correlated
- // variables, terminate rewrite
- return null;
- }
- newInputs.add(frame.r);
- newRel.replaceInput(i, frame.r);
- }
-
- if (!Util.equalShallow(oldInputs, newInputs)) {
- newRel = rel.copy(rel.getTraitSet(), newInputs);
- }
- }
-
- // the output position should not change since there are no corVars
- // coming from below.
- return register(rel, newRel, identityMap(rel.getRowType().getFieldCount()), ImmutableSortedMap.<Correlation, Integer>of());
- }
-
- /**
- * Rewrite Sort.
- *
- * @param rel Sort to be rewritten
- */
- public Frame decorrelateRel(Sort rel) {
- //
- // Rewrite logic:
- //
- // 1. change the collations field to reference the new input.
- //
-
- // Sort itself should not reference cor vars.
- assert !cm.mapRefRelToCorVar.containsKey(rel);
-
- // Sort only references field positions in collations field.
- // The collations field in the newRel now needs to refer to the
- // new output positions in its input.
- // Its output does not change the input ordering, so there's no
- // need to call propagateExpr.
-
- final RelNode oldInput = rel.getInput();
- final Frame frame = getInvoke(oldInput, rel);
- if (frame == null) {
- // If input has not been rewritten, do not rewrite this rel.
- return null;
- }
- final RelNode newInput = frame.r;
-
- Mappings.TargetMapping mapping = Mappings.target(frame.oldToNewOutputPos, oldInput.getRowType().getFieldCount(), newInput.getRowType().getFieldCount());
-
- RelCollation oldCollation = rel.getCollation();
- RelCollation newCollation = RexUtil.apply(mapping, oldCollation);
-
- final Sort newSort = LogicalSort.create(newInput, newCollation, rel.offset, rel.fetch);
-
- // Sort does not change input ordering
- return register(rel, newSort, frame.oldToNewOutputPos, frame.corVarOutputPos);
- }
-
- /**
- * Rewrites a {@link Values}.
- *
- * @param rel Values to be rewritten
- */
- public Frame decorrelateRel(Values rel) {
- // There are no inputs, so rel does not need to be changed.
- return null;
- }
-
- /**
- * Rewrites a {@link LogicalAggregate}.
- *
- * @param rel Aggregate to rewrite
- */
- public Frame decorrelateRel(LogicalAggregate rel) {
- if (rel.getGroupType() != Aggregate.Group.SIMPLE) {
- throw new AssertionError(Bug.CALCITE_461_FIXED);
- }
- //
- // Rewrite logic:
- //
- // 1. Permute the group by keys to the front.
- // 2. If the input of an aggregate produces correlated variables,
- // add them to the group list.
- // 3. Change aggCalls to reference the new project.
- //
-
- // Aggregate itself should not reference cor vars.
- assert !cm.mapRefRelToCorVar.containsKey(rel);
-
- final RelNode oldInput = rel.getInput();
- final Frame frame = getInvoke(oldInput, rel);
- if (frame == null) {
- // If input has not been rewritten, do not rewrite this rel.
- return null;
- }
- final RelNode newInput = frame.r;
-
- // map from newInput
- Map<Integer, Integer> mapNewInputToProjOutputPos = Maps.newHashMap();
- final int oldGroupKeyCount = rel.getGroupSet().cardinality();
-
- // Project projects the original expressions,
- // plus any correlated variables the input wants to pass along.
- final List<Pair<RexNode, String>> projects = Lists.newArrayList();
-
- List<RelDataTypeField> newInputOutput = newInput.getRowType().getFieldList();
-
- int newPos = 0;
-
- // oldInput has the original group by keys in the front.
- final NavigableMap<Integer, RexLiteral> omittedConstants = new TreeMap<>();
- for (int i = 0; i < oldGroupKeyCount; i++) {
- final RexLiteral constant = projectedLiteral(newInput, i);
- if (constant != null) {
- // Exclude constants. Aggregate({true}) occurs because Aggregate({})
- // would generate 1 row even when applied to an empty table.
- omittedConstants.put(i, constant);
- continue;
- }
- int newInputPos = frame.oldToNewOutputPos.get(i);
- projects.add(RexInputRef.of2(newInputPos, newInputOutput));
- mapNewInputToProjOutputPos.put(newInputPos, newPos);
- newPos++;
- }
-
- final SortedMap<Correlation, Integer> mapCorVarToOutputPos = new TreeMap<>();
- if (!frame.corVarOutputPos.isEmpty()) {
- // If input produces correlated variables, move them to the front,
- // right after any existing GROUP BY fields.
-
- // Now add the corVars from the input, starting from
- // position oldGroupKeyCount.
- for (Map.Entry<Correlation, Integer> entry : frame.corVarOutputPos.entrySet()) {
- projects.add(RexInputRef.of2(entry.getValue(), newInputOutput));
-
- mapCorVarToOutputPos.put(entry.getKey(), newPos);
- mapNewInputToProjOutputPos.put(entry.getValue(), newPos);
- newPos++;
- }
- }
-
- // add the remaining fields
- final int newGroupKeyCount = newPos;
- for (int i = 0; i < newInputOutput.size(); i++) {
- if (!mapNewInputToProjOutputPos.containsKey(i)) {
- projects.add(RexInputRef.of2(i, newInputOutput));
- mapNewInputToProjOutputPos.put(i, newPos);
- newPos++;
- }
- }
-
- assert newPos == newInputOutput.size();
-
- // This Project will be what the old input maps to,
- // replacing any previous mapping from the old input.
- RelNode newProject = RelOptUtil.createProject(newInput, projects, false);
-
- // update mappings:
- // oldInput ----> newInput
- //
- // newProject
- // |
- // oldInput ----> newInput
- //
- // is transformed to
- //
- // oldInput ----> newProject
- // |
- // newInput
- Map<Integer, Integer> combinedMap = Maps.newHashMap();
-
- for (Integer oldInputPos : frame.oldToNewOutputPos.keySet()) {
- combinedMap.put(oldInputPos, mapNewInputToProjOutputPos.get(frame.oldToNewOutputPos.get(oldInputPos)));
- }
-
- register(oldInput, newProject, combinedMap, mapCorVarToOutputPos);
-
- // now it's time to rewrite the Aggregate
- final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount);
- List<AggregateCall> newAggCalls = Lists.newArrayList();
- List<AggregateCall> oldAggCalls = rel.getAggCallList();
-
- int oldInputOutputFieldCount = rel.getGroupSet().cardinality();
- int newInputOutputFieldCount = newGroupSet.cardinality();
-
- int i = -1;
- for (AggregateCall oldAggCall : oldAggCalls) {
- ++i;
- List<Integer> oldAggArgs = oldAggCall.getArgList();
-
- List<Integer> aggArgs = Lists.newArrayList();
-
- // Adjust the aggregator argument positions.
- // Note aggregator does not change input ordering, so the input
- // output position mapping can be used to derive the new positions
- // for the argument.
- for (int oldPos : oldAggArgs) {
- aggArgs.add(combinedMap.get(oldPos));
- }
- final int filterArg = oldAggCall.filterArg < 0 ? oldAggCall.filterArg : combinedMap.get(oldAggCall.filterArg);
-
- newAggCalls.add(oldAggCall.adaptTo(newProject, aggArgs, filterArg, oldGroupKeyCount, newGroupKeyCount));
-
- // The old to new output position mapping will be the same as that
- // of newProject, plus any aggregates that the oldAgg produces.
- combinedMap.put(oldInputOutputFieldCount + i, newInputOutputFieldCount + i);
- }
-
- relBuilder.push(LogicalAggregate.create(newProject, false, newGroupSet, null, newAggCalls));
-
- if (!omittedConstants.isEmpty()) {
- final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
- for (Map.Entry<Integer, RexLiteral> entry : omittedConstants.descendingMap().entrySet()) {
- postProjects.add(entry.getKey() + frame.corVarOutputPos.size(), entry.getValue());
- }
- relBuilder.project(postProjects);
- }
-
- // Aggregate does not change input ordering so corVars will be
- // located at the same position as the input newProject.
- return register(rel, relBuilder.build(), combinedMap, mapCorVarToOutputPos);
- }
-
- public Frame getInvoke(RelNode r, RelNode parent) {
- final Frame frame = dispatcher.invoke(r);
- if (frame != null) {
- map.put(r, frame);
- }
- currentRel = parent;
- return frame;
- }
-
- /**
- * Returns a literal output field, or null if it is not literal.
- */
- private static RexLiteral projectedLiteral(RelNode rel, int i) {
- if (rel instanceof Project) {
- final Project project = (Project) rel;
- final RexNode node = project.getProjects().get(i);
- if (node instanceof RexLiteral) {
- return (RexLiteral) node;
- }
- }
- return null;
- }
-
- /**
- * Rewrite LogicalProject.
- *
- * @param rel the project rel to rewrite
- */
- public Frame decorrelateRel(LogicalProject rel) {
- //
- // Rewrite logic:
- //
- // 1. Pass along any correlated variables coming from the input.
- //
-
- final RelNode oldInput = rel.getInput();
- Frame frame = getInvoke(oldInput, rel);
- if (frame == null) {
- // If input has not been rewritten, do not rewrite this rel.
- return null;
- }
- final List<RexNode> oldProjects = rel.getProjects();
- final List<RelDataTypeField> relOutput = rel.getRowType().getFieldList();
-
- // LogicalProject projects the original expressions,
- // plus any correlated variables the input wants to pass along.
- final List<Pair<RexNode, String>> projects = Lists.newArrayList();
-
- // If this LogicalProject has correlated reference, create value generator
- // and produce the correlated variables in the new output.
- if (cm.mapRefRelToCorVar.containsKey(rel)) {
- decorrelateInputWithValueGenerator(rel);
-
- // The old input should be mapped to the LogicalJoin created by
- // rewriteInputWithValueGenerator().
- frame = map.get(oldInput);
- }
-
- // LogicalProject projects the original expressions
- final Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap();
- int newPos;
- for (newPos = 0; newPos < oldProjects.size(); newPos++) {
- projects.add(newPos, Pair.of(decorrelateExpr(oldProjects.get(newPos)), relOutput.get(newPos).getName()));
- mapOldToNewOutputPos.put(newPos, newPos);
- }
-
- // Project any correlated variables the input wants to pass along.
- final SortedMap<Correlation, Integer> mapCorVarToOutputPos = new TreeMap<>();
- for (Map.Entry<Correlation, Integer> entry : frame.corVarOutputPos.entrySet()) {
- projects.add(RexInputRef.of2(entry.getValue(), frame.r.getRowType().getFieldList()));
- mapCorVarToOutputPos.put(entry.getKey(), newPos);
- newPos++;
- }
-
- RelNode newProject = RelOptUtil.createProject(frame.r, projects, false);
-
- return register(rel, newProject, mapOldToNewOutputPos, mapCorVarToOutputPos);
- }
-
- /**
- * Create RelNode tree that produces a list of correlated variables.
- *
- * @param correlations correlated variables to generate
- * @param valueGenFieldOffset offset in the output that generated columns
- * will start
- * @param mapCorVarToOutputPos output positions for the correlated variables
- * generated
- * @return RelNode the root of the resultant RelNode tree
- */
- private RelNode createValueGenerator(Iterable<Correlation> correlations, int valueGenFieldOffset, SortedMap<Correlation, Integer> mapCorVarToOutputPos) {
- final Map<RelNode, List<Integer>> mapNewInputToOutputPos = new HashMap<>();
-
- final Map<RelNode, Integer> mapNewInputToNewOffset = new HashMap<>();
-
- // Input provides the definition of a correlated variable.
- // Add to map all the referenced positions (relative to each input rel).
- for (Correlation corVar : correlations) {
- final int oldCorVarOffset = corVar.field;
-
- final RelNode oldInput = getCorRel(corVar);
- assert oldInput != null;
- final Frame frame = map.get(oldInput);
- assert frame != null;
- final RelNode newInput = frame.r;
-
- final List<Integer> newLocalOutputPosList;
- if (!mapNewInputToOutputPos.containsKey(newInput)) {
- newLocalOutputPosList = Lists.newArrayList();
- } else {
- newLocalOutputPosList = mapNewInputToOutputPos.get(newInput);
- }
-
- final int newCorVarOffset = frame.oldToNewOutputPos.get(oldCorVarOffset);
-
- // Add all unique positions referenced.
- if (!newLocalOutputPosList.contains(newCorVarOffset)) {
- newLocalOutputPosList.add(newCorVarOffset);
- }
- mapNewInputToOutputPos.put(newInput, newLocalOutputPosList);
- }
-
- int offset = 0;
-
- // Project only the correlated fields out of each inputRel
- // and join the projectRel together.
- // To make sure the plan does not change in terms of join order,
- // join these rels based on their occurrence in cor var list which
- // is sorted.
- final Set<RelNode> joinedInputRelSet = Sets.newHashSet();
-
- RelNode r = null;
- for (Correlation corVar : correlations) {
- final RelNode oldInput = getCorRel(corVar);
- assert oldInput != null;
- final RelNode newInput = map.get(oldInput).r;
- assert newInput != null;
-
- if (!joinedInputRelSet.contains(newInput)) {
- RelNode project = RelOptUtil.createProject(newInput, mapNewInputToOutputPos.get(newInput));
- RelNode distinct = RelOptUtil.createDistinctRel(project);
- RelOptCluster cluster = distinct.getCluster();
-
- joinedInputRelSet.add(newInput);
- mapNewInputToNewOffset.put(newInput, offset);
- offset += distinct.getRowType().getFieldCount();
-
- if (r == null) {
- r = distinct;
- } else {
- r = LogicalJoin.create(r, distinct, cluster.getRexBuilder().makeLiteral(true), ImmutableSet.<CorrelationId>of(), JoinRelType.INNER);
- }
- }
- }
-
- // Translate the positions of correlated variables to be relative to
- // the join output, leaving room for valueGenFieldOffset because
- // valueGenerators are joined with the original left input of the rel
- // referencing correlated variables.
- for (Correlation corVar : correlations) {
- // The first input of a Correlator is always the rel defining
- // the correlated variables.
- final RelNode oldInput = getCorRel(corVar);
- assert oldInput != null;
- final Frame frame = map.get(oldInput);
- final RelNode newInput = frame.r;
- assert newInput != null;
-
- final List<Integer> newLocalOutputPosList = mapNewInputToOutputPos.get(newInput);
-
- final int newLocalOutputPos = frame.oldToNewOutputPos.get(corVar.field);
-
- // newOutputPos is the index of the cor var in the referenced
- // position list plus the offset of referenced position list of
- // each newInput.
- final int newOutputPos = newLocalOutputPosList.indexOf(newLocalOutputPos) + mapNewInputToNewOffset.get(newInput) + valueGenFieldOffset;
-
- if (mapCorVarToOutputPos.containsKey(corVar)) {
- assert mapCorVarToOutputPos.get(corVar) == newOutputPos;
- }
- mapCorVarToOutputPos.put(corVar, newOutputPos);
- }
-
- return r;
- }
-
- private RelNode getCorRel(Correlation corVar) {
- final RelNode r = cm.mapCorVarToCorRel.get(corVar.corr);
- return r.getInput(0);
- }
-
- private void decorrelateInputWithValueGenerator(RelNode rel) {
- // currently only handles a single input
- assert rel.getInputs().size() == 1;
- RelNode oldInput = rel.getInput(0);
- final Frame frame = map.get(oldInput);
-
- final SortedMap<Correlation, Integer> mapCorVarToOutputPos = new TreeMap<>(frame.corVarOutputPos);
-
- final Collection<Correlation> corVarList = cm.mapRefRelToCorVar.get(rel);
-
- int leftInputOutputCount = frame.r.getRowType().getFieldCount();
-
- // can directly add positions into mapCorVarToOutputPos since join
- // does not change the output ordering from the inputs.
- RelNode valueGen = createValueGenerator(corVarList, leftInputOutputCount, mapCorVarToOutputPos);
-
- RelNode join = LogicalJoin.create(frame.r, valueGen, rexBuilder.makeLiteral(true), ImmutableSet.<CorrelationId>of(), JoinRelType.INNER);
-
- // LogicalJoin or LogicalFilter does not change the old input ordering. All
- // input fields from newLeftInput (i.e. the original input to the old
- // LogicalFilter) are in the output and in the same position.
- register(oldInput, join, frame.oldToNewOutputPos, mapCorVarToOutputPos);
- }
-
- /**
- * Rewrite LogicalFilter.
- *
- * @param rel the filter rel to rewrite
- */
- public Frame decorrelateRel(LogicalFilter rel) {
- //
- // Rewrite logic:
- //
- // 1. If a LogicalFilter references a correlated field in its filter
- // condition, rewrite the LogicalFilter to be
- // LogicalFilter
- // LogicalJoin(cross product)
- // OriginalFilterInput
- // ValueGenerator(produces distinct sets of correlated variables)
- // and rewrite the correlated fieldAccess in the filter condition to
- // reference the LogicalJoin output.
- //
- // 2. If LogicalFilter does not reference correlated variables, simply
- // rewrite the filter condition using new input.
- //
-
- final RelNode oldInput = rel.getInput();
- Frame frame = getInvoke(oldInput, rel);
- if (frame == null) {
- // If input has not been rewritten, do not rewrite this rel.
- return null;
- }
-
- // If this LogicalFilter has correlated reference, create value generator
- // and produce the correlated variables in the new output.
- if (cm.mapRefRelToCorVar.containsKey(rel)) {
- decorrelateInputWithValueGenerator(rel);
-
- // The old input should be mapped to the newly created LogicalJoin by
- // rewriteInputWithValueGenerator().
- frame = map.get(oldInput);
- }
-
- // Replace the filter expression to reference output of the join
- // Map filter to the new filter over join
- RelNode newFilter = RelOptUtil.createFilter(frame.r, decorrelateExpr(rel.getCondition()));
-
- // Filter does not change the input ordering.
- // Filter rel does not permute the input.
- // All corvars produced by filter will have the same output positions in the
- // input rel.
- return register(rel, newFilter, frame.oldToNewOutputPos, frame.corVarOutputPos);
- }
-
- /**
- * Rewrite Correlator into a left outer join.
- *
- * @param rel Correlator
- */
- public Frame decorrelateRel(LogicalCorrelate rel) {
- //
- // Rewrite logic:
- //
- // The original left input will be joined with the new right input that
- // has generated correlated variables propagated up. For any generated
- // cor vars that are not used in the join key, pass them along to be
- // joined later with the CorrelatorRels that produce them.
- //
-
- // the right input to Correlator should produce correlated variables
- final RelNode oldLeft = rel.getInput(0);
- final RelNode oldRight = rel.getInput(1);
-
- final Frame leftFrame = getInvoke(oldLeft, rel);
- final Frame rightFrame = getInvoke(oldRight, rel);
-
- if (leftFrame == null || rightFrame == null) {
- // If any input has not been rewritten, do not rewrite this rel.
- return null;
- }
-
- if (rightFrame.corVarOutputPos.isEmpty()) {
- return null;
- }
-
- assert rel.getRequiredColumns().cardinality() <= rightFrame.corVarOutputPos.keySet().size();
-
- // Change correlator rel into a join.
- // Join all the correlated variables produced by this correlator rel
- // with the values generated and propagated from the right input
- final SortedMap<Correlation, Integer> corVarOutputPos = new TreeMap<>(rightFrame.corVarOutputPos);
- final List<RexNode> conditions = new ArrayList<>();
- final List<RelDataTypeField> newLeftOutput = leftFrame.r.getRowType().getFieldList();
- int newLeftFieldCount = newLeftOutput.size();
-
- final List<RelDataTypeField> newRightOutput = rightFrame.r.getRowType().getFieldList();
-
- for (Map.Entry<Correlation, Integer> rightOutputPos : Lists.newArrayList(corVarOutputPos.entrySet())) {
- final Correlation corVar = rightOutputPos.getKey();
- if (!corVar.corr.equals(rel.getCorrelationId())) {
- continue;
- }
- final int newLeftPos = leftFrame.oldToNewOutputPos.get(corVar.field);
- final int newRightPos = rightOutputPos.getValue();
- conditions.add(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, RexInputRef.of(newLeftPos, newLeftOutput), new RexInputRef(newLeftFieldCount + newRightPos, newRightOutput.get(newRightPos).getType())));
-
- // remove this cor var from output position mapping
- corVarOutputPos.remove(corVar);
- }
-
- // Update the output position for the cor vars: only pass on the cor
- // vars that are not used in the join key.
- for (Correlation corVar : corVarOutputPos.keySet()) {
- int newPos = corVarOutputPos.get(corVar) + newLeftFieldCount;
- corVarOutputPos.put(corVar, newPos);
- }
-
- // then add any cor var from the left input. Do not need to change
- // output positions.
- corVarOutputPos.putAll(leftFrame.corVarOutputPos);
-
- // Create the mapping between the output of the old correlation rel
- // and the new join rel
- final Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap();
-
- int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();
-
- int oldRightFieldCount = oldRight.getRowType().getFieldCount();
- assert rel.getRowType().getFieldCount() == oldLeftFieldCount + oldRightFieldCount;
-
- // Left input positions are not changed.
- mapOldToNewOutputPos.putAll(leftFrame.oldToNewOutputPos);
-
- // Right input positions are shifted by newLeftFieldCount.
- for (int i = 0; i < oldRightFieldCount; i++) {
- mapOldToNewOutputPos.put(i + oldLeftFieldCount, rightFrame.oldToNewOutputPos.get(i) + newLeftFieldCount);
- }
-
- final RexNode condition = RexUtil.composeConjunction(rexBuilder, conditions, false);
- RelNode newJoin = LogicalJoin.create(leftFrame.r, rightFrame.r, condition, ImmutableSet.<CorrelationId>of(), rel.getJoinType().toJoinType());
-
- return register(rel, newJoin, mapOldToNewOutputPos, corVarOutputPos);
- }
-
- /**
- * Rewrite LogicalJoin.
- *
- * @param rel LogicalJoin
- */
- public Frame decorrelateRel(LogicalJoin rel) {
- //
- // Rewrite logic:
- //
- // 1. rewrite join condition.
- // 2. map output positions and produce cor vars if any.
- //
-
- final RelNode oldLeft = rel.getInput(0);
- final RelNode oldRight = rel.getInput(1);
-
- final Frame leftFrame = getInvoke(oldLeft, rel);
- final Frame rightFrame = getInvoke(oldRight, rel);
-
- if (leftFrame == null || rightFrame == null) {
- // If any input has not been rewritten, do not rewrite this rel.
- return null;
- }
-
- final RelNode newJoin = LogicalJoin.create(leftFrame.r, rightFrame.r, decorrelateExpr(rel.getCondition()), ImmutableSet.<CorrelationId>of(), rel.getJoinType());
-
- // Create the mapping between the output of the old correlation rel
- // and the new join rel
- Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap();
-
- int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();
- int newLeftFieldCount = leftFrame.r.getRowType().getFieldCount();
-
- int oldRightFieldCount = oldRight.getRowType().getFieldCount();
- assert rel.getRowType().getFieldCount() == oldLeftFieldCount + oldRightFieldCount;
-
- // Left input positions are not changed.
- mapOldToNewOutputPos.putAll(leftFrame.oldToNewOutputPos);
-
- // Right input positions are shifted by newLeftFieldCount.
- for (int i = 0; i < oldRightFieldCount; i++) {
- mapOldToNewOutputPos.put(i + oldLeftFieldCount, rightFrame.oldToNewOutputPos.get(i) + newLeftFieldCount);
- }
-
- final SortedMap<Correlation, Integer> mapCorVarToOutputPos = new TreeMap<>(leftFrame.corVarOutputPos);
-
- // Right input positions are shifted by newLeftFieldCount.
- for (Map.Entry<Correlation, Integer> entry : rightFrame.corVarOutputPos.entrySet()) {
- mapCorVarToOutputPos.put(entry.getKey(), entry.getValue() + newLeftFieldCount);
- }
- return register(rel, newJoin, mapOldToNewOutputPos, mapCorVarToOutputPos);
- }
-
- private RexInputRef getNewForOldInputRef(RexInputRef oldInputRef) {
- assert currentRel != null;
-
- int oldOrdinal = oldInputRef.getIndex();
- int newOrdinal = 0;
-
- // determine which input rel oldOrdinal references, and adjust
- // oldOrdinal to be relative to that input rel
- RelNode oldInput = null;
-
- for (RelNode oldInput0 : currentRel.getInputs()) {
- RelDataType oldInputType = oldInput0.getRowType();
- int n = oldInputType.getFieldCount();
- if (oldOrdinal < n) {
- oldInput = oldInput0;
- break;
- }
- RelNode newInput = map.get(oldInput0).r;
- newOrdinal += newInput.getRowType().getFieldCount();
- oldOrdinal -= n;
- }
-
- assert oldInput != null;
-
- final Frame frame = map.get(oldInput);
- assert frame != null;
-
- // now oldOrdinal is relative to oldInput
- int oldLocalOrdinal = oldOrdinal;
-
- // figure out the newLocalOrdinal, relative to the newInput.
- int newLocalOrdinal = oldLocalOrdinal;
-
- if (!frame.oldToNewOutputPos.isEmpty()) {
- newLocalOrdinal = frame.oldToNewOutputPos.get(oldLocalOrdinal);
- }
-
- newOrdinal += newLocalOrdinal;
-
- return new RexInputRef(newOrdinal, frame.r.getRowType().getFieldList().get(newLocalOrdinal).getType());
- }
-
- /**
- * Pulls project above the join from its RHS input. Enforces nullability
- * for join output.
- *
- * @param join Join
- * @param project Original project as the right-hand input of the join
- * @param nullIndicatorPos Position of null indicator
- * @return the subtree with the new LogicalProject at the root
- */
- private RelNode projectJoinOutputWithNullability(LogicalJoin join, LogicalProject project, int nullIndicatorPos) {
- final RelDataTypeFactory typeFactory = join.getCluster().getTypeFactory();
- final RelNode left = join.getLeft();
- final JoinRelType joinType = join.getJoinType();
-
- RexInputRef nullIndicator = new RexInputRef(nullIndicatorPos, typeFactory.createTypeWithNullability(join.getRowType().getFieldList().get(nullIndicatorPos).getType(), true));
-
- // now create the new project
- List<Pair<RexNode, String>> newProjExprs = Lists.newArrayList();
-
- // project everything from the LHS and then those from the original
- // projRel
- List<RelDataTypeField> leftInputFields = left.getRowType().getFieldList();
-
- for (int i = 0; i < leftInputFields.size(); i++) {
- newProjExprs.add(RexInputRef.of2(i, leftInputFields));
- }
-
- // Mark where the projected expr comes from so that the types become
- // nullable for the original projections, which now come out of the
- // nullable side of the OJ.
- boolean projectPulledAboveLeftCorrelator = joinType.generatesNullsOnRight();
-
- for (Pair<RexNode, String> pair : project.getNamedProjects()) {
- RexNode newProjExpr = removeCorrelationExpr(pair.left, projectPulledAboveLeftCorrelator, nullIndicator);
-
- newProjExprs.add(Pair.of(newProjExpr, pair.right));
- }
-
- return RelOptUtil.createProject(join, newProjExprs, false);
- }
-
- /**
- * Pulls a {@link Project} above a {@link Correlate} from its RHS input.
- * Enforces nullability for join output.
- *
- * @param correlate Correlate
- * @param project the original project as the RHS input of the join
- * @param isCount Positions which are calls to the <code>COUNT</code>
- * aggregation function
- * @return the subtree with the new LogicalProject at the root
- */
- private RelNode aggregateCorrelatorOutput(Correlate correlate, LogicalProject project, Set<Integer> isCount) {
- final RelNode left = correlate.getLeft();
- final JoinRelType joinType = correlate.getJoinType().toJoinType();
-
- // now create the new project
- final List<Pair<RexNode, String>> newProjects = Lists.newArrayList();
-
- // Project everything from the LHS and then those from the original
- // project
- final List<RelDataTypeField> leftInputFields = left.getRowType().getFieldList();
-
- for (int i = 0; i < leftInputFields.size(); i++) {
- newProjects.add(RexInputRef.of2(i, leftInputFields));
- }
-
- // Mark where the projected expr comes from so that the types become
- // nullable for the original projections, which now come out of the
- // nullable side of the OJ.
- boolean projectPulledAboveLeftCorrelator = joinType.generatesNullsOnRight();
-
- for (Pair<RexNode, String> pair : project.getNamedProjects()) {
- RexNode newProjExpr = removeCorrelationExpr(pair.left, projectPulledAboveLeftCorrelator, isCount);
- newProjects.add(Pair.of(newProjExpr, pair.right));
- }
-
- return RelOptUtil.createProject(correlate, newProjects, false);
- }
-
- /**
- * Checks whether the correlations in projRel and filter are related to
- * the correlated variables provided by corRel.
- *
- * @param correlate Correlate
- * @param project The original Project as the RHS input of the join
- * @param filter Filter
- * @param correlatedJoinKeys Correlated join keys
- * @return true if filter and proj only references corVar provided by corRel
- */
- private boolean checkCorVars(LogicalCorrelate correlate, LogicalProject project, LogicalFilter filter, List<RexFieldAccess> correlatedJoinKeys) {
- if (filter != null) {
- assert correlatedJoinKeys != null;
-
- // check that all correlated refs in the filter condition are
- // used in the join (as field access).
- Set<Correlation> corVarInFilter = Sets.newHashSet(cm.mapRefRelToCorVar.get(filter));
-
- for (RexFieldAccess correlatedJoinKey : correlatedJoinKeys) {
- corVarInFilter.remove(cm.mapFieldAccessToCorVar.get(correlatedJoinKey));
- }
-
- if (!corVarInFilter.isEmpty()) {
- return false;
- }
-
- // Check that the correlated variables referenced in these
- // comparisons do come from the correlatorRel.
- corVarInFilter.addAll(cm.mapRefRelToCorVar.get(filter));
-
- for (Correlation corVar : corVarInFilter) {
- if (cm.mapCorVarToCorRel.get(corVar.corr) != correlate) {
- return false;
- }
- }
- }
-
- // if the project has any correlated references, make sure they are also
- // provided by the current correlate. They will be projected out of the LHS
- // of the correlate.
- if ((project != null) && cm.mapRefRelToCorVar.containsKey(project)) {
- for (Correlation corVar : cm.mapRefRelToCorVar.get(project)) {
- if (cm.mapCorVarToCorRel.get(corVar.corr) != correlate) {
- return false;
- }
- }
- }
-
- return true;
- }
-
- /**
- * Remove correlated variables from the tree at root corRel
- *
- * @param correlate Correlator
- */
- private void removeCorVarFromTree(LogicalCorrelate correlate) {
- if (cm.mapCorVarToCorRel.get(correlate.getCorrelationId()) == correlate) {
- cm.mapCorVarToCorRel.remove(correlate.getCorrelationId());
- }
- }
-
- /**
- * Projects all {@code input} output fields plus the additional expressions.
- *
- * @param input Input relational expression
- * @param additionalExprs Additional expressions and names
- * @return the new LogicalProject
- */
- private RelNode createProjectWithAdditionalExprs(RelNode input, List<Pair<RexNode, String>> additionalExprs) {
- final List<RelDataTypeField> fieldList = input.getRowType().getFieldList();
- List<Pair<RexNode, String>> projects = Lists.newArrayList();
- for (Ord<RelDataTypeField> field : Ord.zip(fieldList)) {
- projects.add(Pair.of((RexNode) rexBuilder.makeInputRef(field.e.getType(), field.i), field.e.getName()));
- }
- projects.addAll(additionalExprs);
- return RelOptUtil.createProject(input, projects, false);
- }
-
- /* Returns an immutable map with the identity [0: 0, .., count-1: count-1]. */
- static Map<Integer, Integer> identityMap(int count) {
- ImmutableMap.Builder<Integer, Integer> builder = ImmutableMap.builder();
- for (int i = 0; i < count; i++) {
- builder.put(i, i);
- }
- return builder.build();
- }
-
- /**
- * Registers a relational expression and the relational expression it became
- * after decorrelation.
- */
- Frame register(RelNode rel, RelNode newRel, Map<Integer, Integer> oldToNewOutputPos, SortedMap<Correlation, Integer> corVarToOutputPos) {
- assert allLessThan(oldToNewOutputPos.keySet(), newRel.getRowType().getFieldCount(), Litmus.THROW);
- final Frame frame = new Frame(newRel, corVarToOutputPos, oldToNewOutputPos);
- map.put(rel, frame);
- return frame;
- }
-
- static boolean allLessThan(Collection<Integer> integers, int limit, Litmus ret) {
- for (int value : integers) {
- if (value >= limit) {
- return ret.fail("out of range; value: {}, limit: {}", value, limit);
- }
- }
- return ret.succeed();
- }
-
- private static RelNode stripHep(RelNode rel) {
- if (rel instanceof HepRelVertex) {
- HepRelVertex hepRelVertex = (HepRelVertex) rel;
- rel = hepRelVertex.getCurrentRel();
- }
- return rel;
- }
-
- //~ Inner Classes ----------------------------------------------------------
-
- /**
- * Shuttle that decorrelates.
- */
- private class DecorrelateRexShuttle extends RexShuttle {
- @Override
- public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
- int newInputOutputOffset = 0;
- for (RelNode input : currentRel.getInputs()) {
- final Frame frame = map.get(input);
-
- if (frame != null) {
- // try to find in this input rel the position of cor var
- final Correlation corVar = cm.mapFieldAccessToCorVar.get(fieldAccess);
-
- if (corVar != null) {
- Integer newInputPos = frame.corVarOutputPos.get(corVar);
- if (newInputPos != null) {
- // This input rel does produce the cor var referenced.
- // Assume fieldAccess has the correct type info.
- return new RexInputRef(newInputPos + newInputOutputOffset, fieldAccess.getType());
- }
- }
-
- // this input rel does not produce the cor var needed
- newInputOutputOffset += frame.r.getRowType().getFieldCount();
- } else {
- // this input rel is not rewritten
- newInputOutputOffset += input.getRowType().getFieldCount();
- }
- }
- return fieldAccess;
- }
-
- @Override
- public RexNode visitInputRef(RexInputRef inputRef) {
- return getNewForOldInputRef(inputRef);
- }
- }
-
- /**
- * Shuttle that removes correlations.
- */
- private class RemoveCorrelationRexShuttle extends RexShuttle {
- final RexBuilder rexBuilder;
- final RelDataTypeFactory typeFactory;
- final boolean projectPulledAboveLeftCorrelator;
- final RexInputRef nullIndicator;
- final ImmutableSet<Integer> isCount;
-
- public RemoveCorrelationRexShuttle(RexBuilder rexBuilder, boolean projectPulledAboveLeftCorrelator, RexInputRef nullIndicator, Set<Integer> isCount) {
- this.projectPulledAboveLeftCorrelator = projectPulledAboveLeftCorrelator;
- this.nullIndicator = nullIndicator; // may be null
- this.isCount = ImmutableSet.copyOf(isCount);
- this.rexBuilder = rexBuilder;
- this.typeFactory = rexBuilder.getTypeFactory();
- }
-
- private RexNode createCaseExpression(RexInputRef nullInputRef, RexLiteral lit, RexNode rexNode) {
- RexNode[] caseOperands = new RexNode[3];
-
- // Construct a CASE expression to handle the null indicator.
- //
- // This also covers the case where a left correlated subquery
- // projects fields from the outer relation. Since an LOJ cannot produce
- // nulls on the LHS, the projection now needs to make a nullable LHS
- // reference using a nullability indicator. If this indicator
- // is null, it means the subquery does not produce any value. As a
- // result, any RHS ref by this subquery needs to produce a null value.
-
- // WHEN indicator IS NULL
- caseOperands[0] = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, new RexInputRef(nullInputRef.getIndex(), typeFactory.createTypeWithNullability(nullInputRef.getType(), true)));
-
- // THEN CAST(NULL AS newInputTypeNullable)
- caseOperands[1] = rexBuilder.makeCast(typeFactory.createTypeWithNullability(rexNode.getType(), true), lit);
-
- // ELSE cast (newInput AS newInputTypeNullable) END
- caseOperands[2] = rexBuilder.makeCast(typeFactory.createTypeWithNullability(rexNode.getType(), true), rexNode);
-
- return rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands);
- }
-
- @Override
- public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
- if (cm.mapFieldAccessToCorVar.containsKey(fieldAccess)) {
- // if it is a corVar, change it to an input ref.
- Correlation corVar = cm.mapFieldAccessToCorVar.get(fieldAccess);
-
- // corVar offset should point to the leftInput of currentRel,
- // which is the Correlator.
- RexNode newRexNode = new RexInputRef(corVar.field, fieldAccess.getType());
-
- if (projectPulledAboveLeftCorrelator && (nullIndicator != null)) {
- // need to enforce nullability by applying an additional
- // cast operator over the transformed expression.
- newRexNode = createCaseExpression(nullIndicator, rexBuilder.constantNull(), newRexNode);
- }
- return newRexNode;
- }
- return fieldAccess;
- }
-
- @Override
- public RexNode visitInputRef(RexInputRef inputRef) {
- if (currentRel instanceof LogicalCorrelate) {
- // If this rel references a corVar and now needs to be rewritten,
- // it must have been pulled above the Correlator; replace the
- // input ref to account for the LHS of the Correlator.
- final int leftInputFieldCount = ((LogicalCorrelate) currentRel).getLeft().getRowType().getFieldCount();
- RelDataType newType = inputRef.getType();
-
- if (projectPulledAboveLeftCorrelator) {
- newType = typeFactory.createTypeWithNullability(newType, true);
- }
-
- int pos = inputRef.getIndex();
- RexInputRef newInputRef = new RexInputRef(leftInputFieldCount + pos, newType);
-
- if ((isCount != null) && isCount.contains(pos)) {
- return createCaseExpression(newInputRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO), newInputRef);
- } else {
- return newInputRef;
- }
- }
- return inputRef;
- }
-
- @Override
- public RexNode visitLiteral(RexLiteral literal) {
- // Use nullIndicator to decide whether to project null.
- // Do nothing if the literal is null.
- if (!RexUtil.isNull(literal) && projectPulledAboveLeftCorrelator && (nullIndicator != null)) {
- return createCaseExpression(nullIndicator, rexBuilder.constantNull(), literal);
- }
- return literal;
- }
-
- @Override
- public RexNode visitCall(final RexCall call) {
- RexNode newCall;
-
- boolean[] update = {false};
- List<RexNode> clonedOperands = visitList(call.operands, update);
- if (update[0]) {
- SqlOperator operator = call.getOperator();
-
- boolean isSpecialCast = false;
- if (operator instanceof SqlFunction) {
- SqlFunction function = (SqlFunction) operator;
- if (function.getKind() == SqlKind.CAST) {
- if (call.operands.size() < 2) {
- isSpecialCast = true;
- }
- }
- }
-
- final RelDataType newType;
- if (!isSpecialCast) {
- // TODO: ideally this only needs to be called if the result
- // type will also change. However, since that requires
- // support from type inference rules to tell whether a rule
- // decides return type based on input types, for now all
- // operators will be recreated with new type if any operand
- // changed, unless the operator has "built-in" type.
- newType = rexBuilder.deriveReturnType(operator, clonedOperands);
- } else {
- // Use the current return type when creating a new call, for
- // operators with return type built into the operator
- // definition, and with no type inference rules, such as
- // cast function with less than 2 operands.
-
- // TODO: Comments in RexShuttle.visitCall() mention other
- // types in this category. Need to resolve those together
- // and preferably in the base class RexShuttle.
- newType = call.getType();
- }
- newCall = rexBuilder.makeCall(newType, operator, clonedOperands);
- } else {
- newCall = call;
- }
-
- if (projectPulledAboveLeftCorrelator && (nullIndicator != null)) {
- return createCaseExpression(nullIndicator, rexBuilder.constantNull(), newCall);
- }
- return newCall;
- }
- }
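
At runtime the CASE built by createCaseExpression means: if the left outer join found no match (the indicator column is null), produce the literal instead of the rewritten expression. A plain-Java analogue of those semantics, sketched only to show the branching (the rule itself builds a RexCall, not Java code):

    // WHEN nullIndicator IS NULL THEN literal ELSE expr END
    static Integer caseSemantics(Boolean nullIndicator, Integer literal, Integer expr) {
        return (nullIndicator == null) ? literal : expr;
    }

For example caseSemantics(null, 0, 5) yields 0, which is how an unmatched COUNT ends up as zero rather than null.
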
-
- /**
- * Rule to remove a SINGLE_VALUE aggregate. Matches plans of the shape:
- * <p>
- * <blockquote><pre>
- * LogicalAggregate (SINGLE_VALUE)
- *   LogicalProject (single expression)
- *     LogicalAggregate (no group keys)
- * </pre></blockquote>
- * <p>
- * as produced, for example, by a scalar subquery joined on a unique LHS key.
- */
- private final class RemoveSingleAggregateRule extends RelOptRule {
- public RemoveSingleAggregateRule() {
- super(operand(LogicalAggregate.class, operand(LogicalProject.class, operand(LogicalAggregate.class, any()))));
- }
-
- public void onMatch(RelOptRuleCall call) {
- LogicalAggregate singleAggregate = call.rel(0);
- LogicalProject project = call.rel(1);
- LogicalAggregate aggregate = call.rel(2);
-
- // check that singleAggRel is a single_value agg
- if ((!singleAggregate.getGroupSet().isEmpty()) || (singleAggregate.getAggCallList().size() != 1) || !(singleAggregate.getAggCallList().get(0).getAggregation() instanceof SqlSingleValueAggFunction)) {
- return;
- }
-
- // check that the project only projects one expression, i.e. a scalar
- // subquery
- List<RexNode> projExprs = project.getProjects();
- if (projExprs.size() != 1) {
- return;
- }
-
- // check that the input to projRel is an aggregate on the entire input
- if (!aggregate.getGroupSet().isEmpty()) {
- return;
- }
-
- // singleAggRel produces a nullable type, so create the new
- // projection that casts the proj expr to a nullable type.
- final RelOptCluster cluster = project.getCluster();
- RelNode newProject = RelOptUtil.createProject(aggregate, ImmutableList.of(rexBuilder.makeCast(cluster.getTypeFactory().createTypeWithNullability(projExprs.get(0).getType(), true), projExprs.get(0))), null);
- call.transformTo(newProject);
- }
- }
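
The nullable cast this rule leaves behind matters because a scalar subquery over an empty input yields NULL even when the aggregate's own type is non-nullable. A small illustration of that behavior in plain Java (a hypothetical helper, not part of the rule):

    import java.util.Arrays;
    import java.util.List;

    public class ScalarSubqueryNullability {
        // a global MAX() returns exactly one row, but its value is null
        // when the input is empty, so the projected type must be nullable
        static Double scalarMax(List<Double> rows) {
            return rows.stream().max(Double::compareTo).orElse(null);
        }

        public static void main(String[] args) {
            System.out.println(scalarMax(Arrays.asList(1.0, 2.0))); // 2.0
            System.out.println(scalarMax(Arrays.<Double>asList())); // null
        }
    }
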
-
- /**
- * Planner rule that removes correlations for scalar projects.
- */
- private final class RemoveCorrelationForScalarProjectRule extends RelOptRule {
- public RemoveCorrelationForScalarProjectRule() {
- super(operand(LogicalCorrelate.class, operand(RelNode.class, any()), operand(LogicalAggregate.class, operand(LogicalProject.class, operand(RelNode.class, any())))));
- }
-
- public void onMatch(RelOptRuleCall call) {
- final LogicalCorrelate correlate = call.rel(0);
- final RelNode left = call.rel(1);
- final LogicalAggregate aggregate = call.rel(2);
- final LogicalProject project = call.rel(3);
- RelNode right = call.rel(4);
- final RelOptCluster cluster = correlate.getCluster();
-
- setCurrent(call.getPlanner().getRoot(), correlate);
-
- // Check for this pattern.
- // The pattern matching could be simplified if rules can be applied
- // during decorrelation.
- //
- // CorrelateRel(left correlation, condition = true)
- // LeftInputRel
- // LogicalAggregate (groupby (0) single_value())
- // LogicalProject-A (may reference corVar)
- // RightInputRel
- final JoinRelType joinType = correlate.getJoinType().toJoinType();
-
- // corRel.getCondition was here; however, Correlate was updated so it
- // never includes a join condition. For brevity the code was kept as-is.
- RexNode joinCond = rexBuilder.makeLiteral(true);
- if ((joinType != JoinRelType.LEFT) || (joinCond != rexBuilder.makeLiteral(true))) {
- return;
- }
-
- // check that the agg is of the following type:
- // doing a single_value() on the entire input
- if ((!aggregate.getGroupSet().isEmpty()) || (aggregate.getAggCallList().size() != 1) || !(aggregate.getAggCallList().get(0).getAggregation() instanceof SqlSingleValueAggFunction)) {
- return;
- }
-
- // check this project only projects one expression, i.e. scalar
- // subqueries.
- if (project.getProjects().size() != 1) {
- return;
- }
-
- int nullIndicatorPos;
-
- if ((right instanceof LogicalFilter) && cm.mapRefRelToCorVar.containsKey(right)) {
- // rightInputRel has this shape:
- //
- // LogicalFilter (references corvar)
- // FilterInputRel
-
- // If rightInputRel is a filter and contains correlated
- // reference, make sure the correlated keys in the filter
- // condition forms a unique key of the RHS.
-
- LogicalFilter filter = (LogicalFilter) right;
- right = filter.getInput();
-
- assert right instanceof HepRelVertex;
- right = ((HepRelVertex) right).getCurrentRel();
-
- // check filter input contains no correlation
- if (RelOptUtil.getVariablesUsed(right).size() > 0) {
- return;
- }
-
- // extract the correlation out of the filter
-
- // First break up the filter conditions into equality
- // comparisons between rightJoinKeys (from the original
- // filterInputRel) and correlatedJoinKeys. correlatedJoinKeys
- // can be expressions, while rightJoinKeys need to be input
- // refs. These comparisons are AND'ed together.
- List<RexNode> tmpRightJoinKeys = Lists.newArrayList();
- List<RexNode> correlatedJoinKeys = Lists.newArrayList();
- RelOptUtil.splitCorrelatedFilterCondition(filter, tmpRightJoinKeys, correlatedJoinKeys, false);
-
- // check that the columns referenced in these comparisons form
- // a unique key of the filterInputRel
- final List<RexInputRef> rightJoinKeys = new ArrayList<>();
- for (RexNode key : tmpRightJoinKeys) {
- assert key instanceof RexInputRef;
- rightJoinKeys.add((RexInputRef) key);
- }
-
- // the rewrite requires at least one equality comparison on the
- // filterInputRel
- if (rightJoinKeys.isEmpty()) {
- return;
- }
-
- // The join filters out the nulls. So, it's ok if there are
- // nulls in the join keys.
- final RelMetadataQuery mq = RelMetadataQuery.instance();
- if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(mq, right, rightJoinKeys)) {
- SQL2REL_LOGGER.debug("{} are not unique keys for {}", rightJoinKeys.toString(), right.toString());
- return;
- }
-
- RexUtil.FieldAccessFinder visitor = new RexUtil.FieldAccessFinder();
- RexUtil.apply(visitor, correlatedJoinKeys, null);
- List<RexFieldAccess> correlatedKeyList = visitor.getFieldAccessList();
-
- if (!checkCorVars(correlate, project, filter, correlatedKeyList)) {
- return;
- }
-
- // Change the plan to this structure.
- // Note that the aggregateRel is removed.
- //
- // LogicalProject-A' (replace corvar with input ref from the LogicalJoin)
- // LogicalJoin (replace corvar with input ref from LeftInputRel)
- // LeftInputRel
- // RightInputRel (previously FilterInputRel)
-
- // Change the filter condition into a join condition
- joinCond = removeCorrelationExpr(filter.getCondition(), false);
-
- nullIndicatorPos = left.getRowType().getFieldCount() + rightJoinKeys.get(0).getIndex();
- } else if (cm.mapRefRelToCorVar.containsKey(project)) {
- // check that rightInputRel contains no correlation
- if (RelOptUtil.getVariablesUsed(right).size() > 0) {
- return;
- }
-
- if (!checkCorVars(correlate, project, null, null)) {
- return;
- }
-
- // Change the plan to this structure.
- //
- // LogicalProject-A' (replace corvar with input ref from LogicalJoin)
- // LogicalJoin (left, condition = true)
- // LeftInputRel
- // LogicalAggregate(groupby(0), single_value(0), s_v(1)....)
- // LogicalProject-B (everything from input plus literal true)
- // ProjInputRel
-
- // make the new projRel to provide a null indicator
- right = createProjectWithAdditionalExprs(right, ImmutableList.of(Pair.<RexNode, String>of(rexBuilder.makeLiteral(true), "nullIndicator")));
-
- // make the new aggRel
- right = RelOptUtil.createSingleValueAggRel(cluster, right);
-
- // The last field:
- // single_value(true)
- // is the nullIndicator
- nullIndicatorPos = left.getRowType().getFieldCount() + right.getRowType().getFieldCount() - 1;
- } else {
- return;
- }
-
- // make the new join rel
- LogicalJoin join = LogicalJoin.create(left, right, joinCond, ImmutableSet.<CorrelationId>of(), joinType);
-
- RelNode newProject = projectJoinOutputWithNullability(join, project, nullIndicatorPos);
-
- call.transformTo(newProject);
-
- removeCorVarFromTree(correlate);
- }
- }
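
The uniqueness check in the filter branch is what keeps this rewrite sound: if the correlated join keys were not a unique key of the RHS, the left outer join would multiply left rows, which a scalar subquery never does. A tiny demonstration with hypothetical data:

    import java.util.Arrays;
    import java.util.List;

    public class UniqueKeyGate {
        public static void main(String[] args) {
            List<Integer> left = Arrays.asList(1, 2);
            // right-side join keys; key 1 appears twice, so keys are not unique
            List<Integer> rightKeys = Arrays.asList(1, 1, 2);
            long joinedRows = left.stream()
                .flatMap(l -> rightKeys.stream().filter(r -> r.equals(l)))
                .count();
            // 3 output rows for 2 left rows: scalar semantics would be broken
            System.out.println(joinedRows);
        }
    }
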
-
- /**
- * Planner rule that removes correlations for scalar aggregates.
- */
- private final class RemoveCorrelationForScalarAggregateRule extends RelOptRule {
- public RemoveCorrelationForScalarAggregateRule() {
- super(operand(LogicalCorrelate.class, operand(RelNode.class, any()), operand(LogicalProject.class, operand(LogicalAggregate.class, null, Aggregate.IS_SIMPLE, operand(LogicalProject.class, operand(RelNode.class, any()))))));
- }
-
- public void onMatch(RelOptRuleCall call) {
- final LogicalCorrelate correlate = call.rel(0);
- final RelNode left = call.rel(1);
- final LogicalProject aggOutputProject = call.rel(2);
- final LogicalAggregate aggregate = call.rel(3);
- final LogicalProject aggInputProject = call.rel(4);
- RelNode right = call.rel(5);
- final RelOptCluster cluster = correlate.getCluster();
-
- setCurrent(call.getPlanner().getRoot(), correlate);
-
- // check for this pattern
- // The pattern matching could be simplified if rules can be applied
- // during decorrelation.
- //
- // CorrelateRel(left correlation, condition = true)
- // LeftInputRel
- // LogicalProject-A (a RexNode)
- // LogicalAggregate (groupby (0), agg0(), agg1()...)
- // LogicalProject-B (references corVar)
- // rightInputRel
-
- // check aggOutputProject projects only one expression
- final List<RexNode> aggOutputProjects = aggOutputProject.getProjects();
- if (aggOutputProjects.size() != 1) {
- return;
- }
-
- final JoinRelType joinType = correlate.getJoinType().toJoinType();
- // corRel.getCondition was here; however, Correlate was updated so it
- // never includes a join condition. For brevity the code was kept as-is.
- RexNode joinCond = rexBuilder.makeLiteral(true);
- if ((joinType != JoinRelType.LEFT) || (joinCond != rexBuilder.makeLiteral(true))) {
- return;
- }
-
- // check that the agg is on the entire input
- if (!aggregate.getGroupSet().isEmpty()) {
- return;
- }
-
- final List<RexNode> aggInputProjects = aggInputProject.getProjects();
-
- final List<AggregateCall> aggCalls = aggregate.getAggCallList();
- final Set<Integer> isCountStar = Sets.newHashSet();
-
- // mark if agg produces count(*) which needs to reference the
- // nullIndicator after the transformation.
- int k = -1;
- for (AggregateCall aggCall : aggCalls) {
- ++k;
- if ((aggCall.getAggregation() instanceof SqlCountAggFunction) && (aggCall.getArgList().size() == 0)) {
- isCountStar.add(k);
- }
- }
-
- if ((right instanceof LogicalFilter) && cm.mapRefRelToCorVar.containsKey(right)) {
- // rightInputRel has this shape:
- //
- // LogicalFilter (references corvar)
- // FilterInputRel
- LogicalFilter filter = (LogicalFilter) right;
- right = filter.getInput();
-
- assert right instanceof HepRelVertex;
- right = ((HepRelVertex) right).getCurrentRel();
-
- // check filter input contains no correlation
- if (RelOptUtil.getVariablesUsed(right).size() > 0) {
- return;
- }
-
- // check the filter condition type; first extract the correlation
- // out of the filter
-
- // First break up the filter conditions into equality
- // comparisons between rightJoinKeys (from the original
- // filterInputRel) and correlatedJoinKeys. correlatedJoinKeys
- // can only be RexFieldAccess, while rightJoinKeys can be
- // expressions. These comparisons are AND'ed together.
- List<RexNode> rightJoinKeys = Lists.newArrayList();
- List<RexNode> tmpCorrelatedJoinKeys = Lists.newArrayList();
- RelOptUtil.splitCorrelatedFilterCondition(filter, rightJoinKeys, tmpCorrelatedJoinKeys, true);
-
- // make sure the correlated reference forms a unique key: check
- // that the columns referenced in these comparisons form a
- // unique key of the leftInputRel
- List<RexFieldAccess> correlatedJoinKeys = Lists.newArrayList();
- List<RexInputRef> correlatedInputRefJoinKeys = Lists.newArrayList();
- for (RexNode joinKey : tmpCorrelatedJoinKeys) {
- assert joinKey instanceof RexFieldAccess;
- correlatedJoinKeys.add((RexFieldAccess) joinKey);
- RexNode correlatedInputRef = removeCorrelationExpr(joinKey, false);
- assert correlatedInputRef instanceof RexInputRef;
- correlatedInputRefJoinKeys.add((RexInputRef) correlatedInputRef);
- }
-
- // the rewrite requires at least one correlated equality
- // comparison
- if (correlatedInputRefJoinKeys.isEmpty()) {
- return;
- }
-
- // The join filters out the nulls. So, it's ok if there are
- // nulls in the join keys.
- final RelMetadataQuery mq = RelMetadataQuery.instance();
- if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(mq, left, correlatedInputRefJoinKeys)) {
- SQL2REL_LOGGER.debug("{} are not unique keys for {}", correlatedJoinKeys.toString(), left.toString());
- return;
- }
-
- // check cor var references are valid
- if (!checkCorVars(correlate, aggInputProject, filter, correlatedJoinKeys)) {
- return;
- }
-
- // Rewrite the above plan:
- //
- // CorrelateRel(left correlation, condition = true)
- // LeftInputRel
- // LogicalProject-A (a RexNode)
- // LogicalAggregate (groupby(0), agg0(),agg1()...)
- // LogicalProject-B (may reference corVar)
- // LogicalFilter (references corVar)
- // RightInputRel (no correlated reference)
- //
-
- // to this plan:
- //
- // LogicalProject-A' (all gby keys + rewritten nullable ProjExpr)
- // LogicalAggregate (groupby(all left input refs)
- // agg0(rewritten expression),
- // agg1()...)
- // LogicalProject-B' (rewritten original projected exprs)
- // LogicalJoin(replace corvar w/ input ref from LeftInputRel)
- // LeftInputRel
- // RightInputRel
- //
-
- // In the case where agg is count(*) or count($corVar), it is
- // changed to count(nullIndicator).
- // Note: any non-nullable field from the RHS can be used as
- // the indicator; however, a "true" field is added to the
- // projection list from the RHS for simplicity, to avoid
- // searching for non-null fields.
- //
- // LogicalProject-A' (all gby keys + rewritten nullable ProjExpr)
- // LogicalAggregate (groupby(all left input refs),
- // count(nullIndicator), other aggs...)
- // LogicalProject-B' (all left input refs plus
- // the rewritten original projected exprs)
- // LogicalJoin (replace corvar with input ref from LeftInputRel)
- // LeftInputRel
- // LogicalProject (everything from RightInputRel plus
- // the nullIndicator "true")
- // RightInputRel
- //
-
- // first change the filter condition into a join condition
- joinCond = removeCorrelationExpr(filter.getCondition(), false);
- } else if (cm.mapRefRelToCorVar.containsKey(aggInputProject)) {
- // check rightInputRel contains no correlation
- if (RelOptUtil.getVariablesUsed(right).size() > 0) {
- return;
- }
-
- // check cor var references are valid
- if (!checkCorVars(correlate, aggInputProject, null, null)) {
- return;
- }
-
- int nFields = left.getRowType().getFieldCount();
- ImmutableBitSet allCols = ImmutableBitSet.range(nFields);
-
- // check that leftInputRel contains unique keys, i.e. each row
- // is distinct and we can group by all the left fields
- final RelMetadataQuery mq = RelMetadataQuery.instance();
- if (!RelMdUtil.areColumnsDefinitelyUnique(mq, left, allCols)) {
- SQL2REL_LOGGER.debug("There are no unique keys for {}", left);
- return;
- }
- //
- // Rewrite the above plan:
- //
- // CorrelateRel(left correlation, condition = true)
- // LeftInputRel
- // LogicalProject-A (a RexNode)
- // LogicalAggregate (groupby(0), agg0(), agg1()...)
- // LogicalProject-B (references corVar)
- // RightInputRel (no correlated reference)
- //
-
- // to this plan:
- //
- // LogicalProject-A' (all gby keys + rewritten nullable ProjExpr)
- // LogicalAggregate (groupby(all left input refs)
- // agg0(rewritten expression),
- // agg1()...)
- // LogicalProject-B' (rewritten original projected exprs)
- // LogicalJoin (LOJ cond = true)
- // LeftInputRel
- // RightInputRel
- //
-
- // In the case where agg is count($corVar), it is changed to
- // count(nullIndicator).
- // Note: any non-nullable field from the RHS can be used as
- // the indicator; however, a "true" field is added to the
- // projection list from the RHS for simplicity, to avoid
- // searching for non-null fields.
- //
- // LogicalProject-A' (all gby keys + rewritten nullable ProjExpr)
- // LogicalAggregate (groupby(all left input refs),
- // count(nullIndicator), other aggs...)
- // LogicalProject-B' (all left input refs plus
- // the rewritten original projected exprs)
- // LogicalJoin (replace corvar with input ref from LeftInputRel)
- // LeftInputRel
- // LogicalProject (everything from RightInputRel plus
- // the nullIndicator "true")
- // RightInputRel
- } else {
- return;
- }
-
- RelDataType leftInputFieldType = left.getRowType();
- int leftInputFieldCount = leftInputFieldType.getFieldCount();
- int joinOutputProjExprCount = leftInputFieldCount + aggInputProjects.size() + 1;
-
- right = createProjectWithAdditionalExprs(right, ImmutableList.of(Pair.<RexNode, String>of(rexBuilder.makeLiteral(true), "nullIndicator")));
-
- LogicalJoin join = LogicalJoin.create(left, right, joinCond, ImmutableSet.<CorrelationId>of(), joinType);
-
- // To the consumer of joinOutputProjRel, nullIndicator is located
- // at the end
- int nullIndicatorPos = join.getRowType().getFieldCount() - 1;
-
- RexInputRef nullIndicator = new RexInputRef(nullIndicatorPos, cluster.getTypeFactory().createTypeWithNullability(join.getRowType().getFieldList().get(nullIndicatorPos).getType(), true));
-
- // first project all group-by keys plus the transformed agg input
- List<RexNode> joinOutputProjects = Lists.newArrayList();
-
- // LOJ preserves LHS types
- for (int i = 0; i < leftInputFieldCount; i++) {
- joinOutputProjects.add(rexBuilder.makeInputRef(leftInputFieldType.getFieldList().get(i).getType(), i));
- }
-
- for (RexNode aggInputProjExpr : aggInputProjects) {
- joinOutputProjects.add(removeCorrelationExpr(aggInputProjExpr, joinType.generatesNullsOnRight(), nullIndicator));
- }
-
- joinOutputProjects.add(rexBuilder.makeInputRef(join, nullIndicatorPos));
-
- RelNode joinOutputProject = RelOptUtil.createProject(join, joinOutputProjects, null);
-
- // nullIndicator is now at a different location in the output of
- // the join
- nullIndicatorPos = joinOutputProjExprCount - 1;
-
- final int groupCount = leftInputFieldCount;
-
- List<AggregateCall> newAggCalls = Lists.newArrayList();
- k = -1;
- for (AggregateCall aggCall : aggCalls) {
- ++k;
- final List<Integer> argList;
-
- if (isCountStar.contains(k)) {
- // this is a count(*), transform it to count(nullIndicator)
- // the null indicator is located at the end
- argList = Collections.singletonList(nullIndicatorPos);
- } else {
- argList = Lists.newArrayList();
-
- for (int aggArg : aggCall.getArgList()) {
- argList.add(aggArg + groupCount);
- }
- }
-
- int filterArg = aggCall.filterArg < 0 ? aggCall.filterArg : aggCall.filterArg + groupCount;
- newAggCalls.add(aggCall.adaptTo(joinOutputProject, argList, filterArg, aggregate.getGroupCount(), groupCount));
- }
-
- ImmutableBitSet groupSet = ImmutableBitSet.range(groupCount);
- LogicalAggregate newAggregate = LogicalAggregate.create(joinOutputProject, false, groupSet, null, newAggCalls);
-
- List<RexNode> newAggOutputProjectList = Lists.newArrayList();
- for (int i : groupSet) {
- newAggOutputProjectList.add(rexBuilder.makeInputRef(newAggregate, i));
- }
-
- RexNode newAggOutputProjects = removeCorrelationExpr(aggOutputProjects.get(0), false);
- newAggOutputProjectList.add(rexBuilder.makeCast(cluster.getTypeFactory().createTypeWithNullability(newAggOutputProjects.getType(), true), newAggOutputProjects));
-
- RelNode newAggOutputProject = RelOptUtil.createProject(newAggregate, newAggOutputProjectList, null);
-
- call.transformTo(newAggOutputProject);
-
- removeCorVarFromTree(correlate);
- }
- }
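
The count(*) to count(nullIndicator) substitution deserves a concrete check: after the left outer join, an unmatched left row still contributes one joined row, so counting rows would wrongly yield 1, while counting the (null) indicator yields 0. Sketched with hypothetical data:

    import java.util.Arrays;
    import java.util.List;

    public class CountIndicator {
        public static void main(String[] args) {
            // joined rows for one left group; a null indicator means no RHS match
            List<Boolean> indicators = Arrays.asList((Boolean) null);
            long countStar = indicators.size();              // 1, the wrong answer
            long countIndicator = indicators.stream()
                .filter(i -> i != null).count();             // 0, the right answer
            System.out.println(countStar + " vs " + countIndicator);
        }
    }
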
-
- // REVIEW jhyde 29-Oct-2007: This rule is non-static, depends on the state
- // of members in FlinkRelDecorrelator, and has side-effects in the decorrelator.
- // This breaks the contract of a planner rule, and the rule will not be
- // reusable in other planners.
-
- // REVIEW jvs 29-Oct-2007: Shouldn't it also be incorporating
- // the flavor attribute into the description?
-
- /**
- * Planner rule that adjusts projects when counts are added.
- */
- private final class AdjustProjectForCountAggregateRule extends RelOptRule {
- final boolean flavor;
-
- public AdjustProjectForCountAggregateRule(boolean flavor) {
- super(flavor ? operand(LogicalCorrelate.class, operand(RelNode.class, any()), operand(LogicalProject.class, operand(LogicalAggregate.class, any()))) : operand(LogicalCorrelate.class, operand(RelNode.class, any()), operand(LogicalAggregate.class, any())));
- this.flavor = flavor;
- }
-
- public void onMatch(RelOptRuleCall call) {
- final LogicalCorrelate correlate = call.rel(0);
- final RelNode left = call.rel(1);
- final LogicalProject aggOutputProject;
- final LogicalAggregate aggregate;
- if (flavor) {
- aggOutputProject = call.rel(2);
- aggregate = call.rel(3);
- } else {
- aggregate = call.rel(2);
-
- // Create identity projection
- final List<Pair<RexNode, String>> projects = Lists.newArrayList();
- final List<RelDataTypeField> fields = aggregate.getRowType().getFieldList();
- for (int i = 0; i < fields.size(); i++) {
- projects.add(RexInputRef.of2(projects.size(), fields));
- }
- aggOutputProject = (LogicalProject) RelOptUtil.createProject(aggregate, projects, false);
- }
- onMatch2(call, correlate, left, aggOutputProject, aggregate);
- }
-
- private void onMatch2(RelOptRuleCall call, LogicalCorrelate correlate, RelNode leftInput, LogicalProject aggOutputProject, LogicalAggregate aggregate) {
- if (generatedCorRels.contains(correlate)) {
- // This correlator was generated by a previous invocation of
- // this rule. No further work to do.
- return;
- }
-
- setCurrent(call.getPlanner().getRoot(), correlate);
-
- // check for this pattern
- // The pattern matching could be simplified if rules can be applied
- // during decorrelation.
- //
- // CorrelateRel(left correlation, condition = true)
- // LeftInputRel
- // LogicalProject-A (a RexNode)
- // LogicalAggregate (groupby (0), agg0(), agg1()...)
-
- // check aggOutputProj projects only one expression
- List<RexNode> aggOutputProjExprs = aggOutputProject.getProjects();
- if (aggOutputProjExprs.size() != 1) {
- return;
- }
-
- JoinRelType joinType = correlate.getJoinType().toJoinType();
- // corRel.getCondition was here; however, Correlate was updated so it
- // never includes a join condition. For brevity the code was kept as-is.
- RexNode joinCond = rexBuilder.makeLiteral(true);
- if ((joinType != JoinRelType.LEFT) || (joinCond != rexBuilder.makeLiteral(true))) {
- return;
- }
-
- // check that the agg is on the entire input
- if (!aggregate.getGroupSet().isEmpty()) {
- return;
- }
-
- List<AggregateCall> aggCalls = aggregate.getAggCallList();
- Set<Integer> isCount = Sets.newHashSet();
-
- // remember the count() positions
- int i = -1;
- for (AggregateCall aggCall : aggCalls) {
- ++i;
- if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
- isCount.add(i);
- }
- }
-
- // now rewrite the plan to
- //
- // Project-A' (all LHS plus transformed original projections,
- // replacing references to count() with case statement)
- // Correlator(left correlation, condition = true)
- // LeftInputRel
- // LogicalAggregate (groupby (0), agg0(), agg1()...)
- //
- LogicalCorrelate newCorrelate = LogicalCorrelate.create(leftInput, aggregate, correlate.getCorrelationId(), correlate.getRequiredColumns(), correlate.getJoinType());
-
- // remember this rel so we don't fire rule on it again
- // REVIEW jhyde 29-Oct-2007: rules should not save state; rule
- // should recognize patterns where it does or does not need to do
- // work
- generatedCorRels.add(newCorrelate);
-
- // Need to update mapCorVarToCorRel, and update the output position
- // for the cor vars: only pass on the cor vars that are not used in
- // the join key.
- if (cm.mapCorVarToCorRel.get(correlate.getCorrelationId()) == correlate) {
- cm.mapCorVarToCorRel.put(correlate.getCorrelationId(), newCorrelate);
- }
-
- RelNode newOutput = aggregateCorrelatorOutput(newCorrelate, aggOutputProject, isCount);
-
- call.transformTo(newOutput);
- }
- }
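
The generatedCorRels bookkeeping above is an instance of a general guard: a rewrite whose output still matches its own pattern must remember what it produced, or the planner will fire it forever. A minimal sketch of that guard, independent of Calcite:

    import java.util.HashSet;
    import java.util.Set;
    import java.util.function.UnaryOperator;

    class RewriteGuard<T> {
        private final Set<T> generated = new HashSet<>();

        // returns null when the node came from a previous invocation
        T rewriteOnce(T node, UnaryOperator<T> rewrite) {
            if (generated.contains(node)) {
                return null; // no further work to do
            }
            T result = rewrite.apply(node);
            generated.add(result); // don't fire on our own output later
            return result;
        }
    }

As the REVIEW note says, recognizing already-rewritten patterns statelessly would be the cleaner fix; the guard is a workaround.
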
-
- /**
- * {@code Correlation} here represents a unique reference to a correlation
- * field.
- * For instance, if a RelNode references emp.name multiple times, it would
- * result in multiple {@code Correlation} objects that differ just in
- * {@link Correlation#uniqueKey}.
- */
- static class Correlation implements Comparable<Correlation> {
- public final int uniqueKey;
- public final CorrelationId corr;
- public final int field;
-
- Correlation(CorrelationId corr, int field, int uniqueKey) {
- this.corr = corr;
- this.field = field;
- this.uniqueKey = uniqueKey;
- }
-
- public int compareTo(Correlation o) {
- int c = corr.compareTo(o.corr);
- if (c != 0) {
- return c;
- }
- c = Integer.compare(field, o.field);
- if (c != 0) {
- return c;
- }
- return Integer.compare(uniqueKey, o.uniqueKey);
- }
- }
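
compareTo above imposes a total order on (corr, field, uniqueKey), which is what lets Correlation act as a key in the sorted maps and TreeSet-backed multimap below; note it uses Integer.compare rather than subtraction, avoiding overflow. The same ordering expressed as a comparator over hypothetical int triples:

    import java.util.Comparator;
    import java.util.TreeSet;

    public class CorrelationOrderDemo {
        // order by corr id, then field, then uniqueKey
        static final Comparator<int[]> ORDER =
            Comparator.<int[]>comparingInt(c -> c[0])
                      .thenComparingInt(c -> c[1])
                      .thenComparingInt(c -> c[2]);

        public static void main(String[] args) {
            TreeSet<int[]> set = new TreeSet<>(ORDER);
            set.add(new int[]{0, 2, 0});
            set.add(new int[]{0, 1, 0});
            // prints 0,1,0 then 0,2,0
            set.forEach(c -> System.out.println(c[0] + "," + c[1] + "," + c[2]));
        }
    }
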
-
- /**
- * A map of the locations of
- * {@link LogicalCorrelate}
- * in a tree of {@link RelNode}s.
- *
- * <p>It is used to drive the decorrelation process.
- * Treat it as immutable; rebuild if you modify the tree.
- *
- * <p>There are three maps:
- * <ol>
- * <li>mapRefRelToCorVar maps a rel node to the correlated variables it
- * references;
- * <li>mapCorVarToCorRel maps a correlated variable to the correlator rel
- * providing it;
- * <li>mapFieldAccessToCorVar maps a rex field access to
- * the cor var it represents. Because typeFlattener does not clone or
- * modify a correlated field access, this map does not need to be
- * updated.
- * </ol>
- */
- private static class CorelMap {
- private final Multimap<RelNode, Correlation> mapRefRelToCorVar;
- private final SortedMap<CorrelationId, RelNode> mapCorVarToCorRel;
- private final Map<RexFieldAccess, Correlation> mapFieldAccessToCorVar;
-
- // TODO: create immutable copies of all maps
- private CorelMap(Multimap<RelNode, Correlation> mapRefRelToCorVar, SortedMap<CorrelationId, RelNode> mapCorVarToCorRel, Map<RexFieldAccess, Correlation> mapFieldAccessToCorVar) {
- this.mapRefRelToCorVar = mapRefRelToCorVar;
- this.mapCorVarToCorRel = mapCorVarToCorRel;
- this.mapFieldAccessToCorVar = ImmutableMap.copyOf(mapFieldAccessToCorVar);
- }
-
- @Override
- public String toString() {
- return "mapRefRelToCorVar=" + mapRefRelToCorVar + "\nmapCorVarToCorRel=" + mapCorVarToCorRel + "\nmapFieldAccessToCorVar=" + mapFieldAccessToCorVar + "\n";
- }
-
- @Override
- public boolean equals(Object obj) {
- return obj == this || obj instanceof CorelMap && mapRefRelToCorVar.equals(((CorelMap) obj).mapRefRelToCorVar) && mapCorVarToCorRel.equals(((CorelMap) obj).mapCorVarToCorRel) && mapFieldAccessToCorVar.equals(((CorelMap) obj).mapFieldAccessToCorVar);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(mapRefRelToCorVar, mapCorVarToCorRel, mapFieldAccessToCorVar);
- }
-
- /**
- * Creates a CorelMap with given contents.
- */
- public static CorelMap of(SortedSetMultimap<RelNode, Correlation> mapRefRelToCorVar, SortedMap<CorrelationId, RelNode> mapCorVarToCorRel, Map<RexFieldAccess, Correlation> mapFieldAccessToCorVar) {
- return new CorelMap(mapRefRelToCorVar, mapCorVarToCorRel, mapFieldAccessToCorVar);
- }
-
- /**
- * Returns whether there are any correlating variables in this statement.
- *
- * @return whether there are any correlating variables
- */
- public boolean hasCorrelation() {
- return !mapCorVarToCorRel.isEmpty();
- }
- }
-
- /**
- * Builds a {@link FlinkRelDecorrelator.CorelMap}.
- */
- private static class CorelMapBuilder extends RelShuttleImpl {
- final SortedMap<CorrelationId, RelNode> mapCorVarToCorRel = new TreeMap<>();
-
- final SortedSetMultimap<RelNode, Correlation> mapRefRelToCorVar = Multimaps.newSortedSetMultimap(Maps.<RelNode, Collection<Correlation>>newHashMap(), new Supplier<TreeSet<Correlation>>() {
- public TreeSet<Correlation> get() {
- Bug.upgrade("use MultimapBuilder when we're on Guava-16");
- return Sets.newTreeSet();
- }
- });
-
- final Map<RexFieldAccess, Correlation> mapFieldAccessToCorVar = new HashMap<>();
-
- final Holder<Integer> offset = Holder.of(0);
- int corrIdGenerator = 0;
-
- final Deque<RelNode> stack = new ArrayDeque<>();
-
- /**
- * Creates a CorelMap by iterating over a {@link RelNode} tree.
- */
- CorelMap build(RelNode rel) {
- stripHep(rel).accept(this);
- return new CorelMap(mapRefRelToCorVar, mapCorVarToCorRel, mapFieldAccessToCorVar);
- }
-
- @Override
- public RelNode visit(LogicalJoin join) {
- try {
- stack.push(join);
- join.getCondition().accept(rexVisitor(join));
- } finally {
- stack.pop();
- }
- return visitJoin(join);
- }
-
- @Override
- protected RelNode visitChild(RelNode parent, int i, RelNode input) {
- return super.visitChild(parent, i, stripHep(input));
- }
-
- @Override
- public RelNode visit(LogicalCorrelate correlate) {
- mapCorVarToCorRel.put(correlate.getCorrelationId(), correlate);
- return visitJoin(correlate);
- }
-
- private RelNode visitJoin(BiRel join) {
- final int x = offset.get();
- visitChild(join, 0, join.getLeft());
- offset.set(x + join.getLeft().getRowType().getFieldCount());
- visitChild(join, 1, join.getRight());
- offset.set(x);
- return join;
- }
-
- @Override
- public RelNode visit(final LogicalFilter filter) {
- try {
- stack.push(filter);
- filter.getCondition().accept(rexVisitor(filter));
- } finally {
- stack.pop();
- }
- return super.visit(filter);
- }
-
- @Override
- public RelNode visit(LogicalProject project) {
- try {
- stack.push(project);
- for (RexNode node : project.getProjects()) {
- node.accept(rexVisitor(project));
- }
- } finally {
- stack.pop();
- }
- return super.visit(project);
- }
-
- private RexVisitorImpl<Void> rexVisitor(final RelNode rel) {
- return new RexVisitorImpl<Void>(true) {
- @Override
- public Void visitFieldAccess(RexFieldAccess fieldAccess) {
- final RexNode ref = fieldAccess.getReferenceExpr();
- if (ref instanceof RexCorrelVariable) {
- final RexCorrelVariable var = (RexCorrelVariable) ref;
- final Correlation correlation = new Correlation(var.id, fieldAccess.getField().getIndex(), corrIdGenerator++);
- mapFieldAccessToCorVar.put(fieldAccess, correlation);
- mapRefRelToCorVar.put(rel, correlation);
- }
- return super.visitFieldAccess(fieldAccess);
- }
-
- @Override
- public Void visitSubQuery(RexSubQuery subQuery) {
- subQuery.rel.accept(FlinkRelDecorrelator.CorelMapBuilder.this);
- return super.visitSubQuery(subQuery);
- }
- };
- }
- }
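
The Bug.upgrade note in the Supplier above points at Guava's MultimapBuilder. On Guava 16 or later the anonymous supplier collapses to a one-liner; a sketch, assuming that Guava version is on the classpath:

    import com.google.common.collect.MultimapBuilder;
    import com.google.common.collect.SortedSetMultimap;

    class SortedMultimapExample {
        // hash-keyed multimap whose value collections are TreeSets
        static <K, V extends Comparable<V>> SortedSetMultimap<K, V> create() {
            return MultimapBuilder.hashKeys().treeSetValues().build();
        }
    }
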
-
- /**
- * Frame describing a relational expression after decorrelation,
- * and where to find its original output fields and correlation
- * variables among its new output fields.
- */
- static class Frame {
- final RelNode r;
- final ImmutableSortedMap<Correlation, Integer> corVarOutputPos;
- final ImmutableMap<Integer, Integer> oldToNewOutputPos;
-
- Frame(RelNode r, SortedMap<Correlation, Integer> corVarOutputPos, Map<Integer, Integer> oldToNewOutputPos) {
- this.r = Preconditions.checkNotNull(r);
- this.corVarOutputPos = ImmutableSortedMap.copyOf(corVarOutputPos);
- this.oldToNewOutputPos = ImmutableSortedMap.copyOf(oldToNewOutputPos);
- }
- }
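
Frame copies its maps into immutable views on construction, so a frame can be shared between rule invocations without later mutation leaking in. The same defensive-copy pattern with only the JDK, as a sketch:

    import java.util.Collections;
    import java.util.SortedMap;
    import java.util.TreeMap;

    final class ImmutableFrame {
        final SortedMap<Integer, Integer> oldToNewOutputPos;

        ImmutableFrame(SortedMap<Integer, Integer> oldToNew) {
            // copy first, then wrap: the caller's map can no longer affect us
            this.oldToNewOutputPos =
                Collections.unmodifiableSortedMap(new TreeMap<>(oldToNew));
        }
    }
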
-}
-
-// End FlinkRelDecorrelator.java
http://git-wip-us.apache.org/repos/asf/flink/blob/c5173fa2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkPlannerImpl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkPlannerImpl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkPlannerImpl.scala
index 4f3e317..09e3277 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkPlannerImpl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/calcite/FlinkPlannerImpl.scala
@@ -32,11 +32,9 @@ import org.apache.calcite.schema.SchemaPlus
import org.apache.calcite.sql.parser.{SqlParser, SqlParseException => CSqlParseException}
import org.apache.calcite.sql.validate.SqlValidator
import org.apache.calcite.sql.{SqlNode, SqlOperatorTable}
-import org.apache.calcite.sql2rel.{SqlRexConvertletTable, SqlToRelConverter}
+import org.apache.calcite.sql2rel.{RelDecorrelator, SqlRexConvertletTable, SqlToRelConverter}
import org.apache.calcite.tools.{FrameworkConfig, RelConversionException}
import org.apache.flink.table.api.{SqlParserException, TableException, ValidationException}
-import org.apache.flink.table.calcite.sql2rel.FlinkRelDecorrelator
-import org.apache.flink.table.plan.cost.FlinkDefaultRelMetadataProvider
import scala.collection.JavaConversions._
@@ -109,7 +107,7 @@ class FlinkPlannerImpl(
// we disable automatic flattening in order to let composite types pass without modification
// we might enable it again once Calcite has better support for structured types
// root = root.withRel(sqlToRelConverter.flattenTypes(root.rel, true))
- root = root.withRel(FlinkRelDecorrelator.decorrelateQuery(root.rel))
+ root = root.withRel(RelDecorrelator.decorrelateQuery(root.rel))
root
} catch {
case e: RelConversionException => throw TableException(e.getMessage)
@@ -148,7 +146,7 @@ class FlinkPlannerImpl(
new ViewExpanderImpl, validator, catalogReader, cluster, convertletTable, config)
root = sqlToRelConverter.convertQuery(validatedSqlNode, true, false)
root = root.withRel(sqlToRelConverter.flattenTypes(root.rel, true))
- root = root.withRel(FlinkRelDecorrelator.decorrelateQuery(root.rel))
+ root = root.withRel(RelDecorrelator.decorrelateQuery(root.rel))
FlinkPlannerImpl.this.root
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/c5173fa2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
index 1301c8d..5caaf1f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala
@@ -20,7 +20,7 @@ package org.apache.flink.table.plan.rules
import org.apache.calcite.rel.rules._
import org.apache.calcite.tools.{RuleSet, RuleSets}
-import org.apache.flink.table.calcite.rules.{FlinkAggregateExpandDistinctAggregatesRule, FlinkAggregateJoinTransposeRule}
+import org.apache.flink.table.calcite.rules.FlinkAggregateExpandDistinctAggregatesRule
import org.apache.flink.table.plan.rules.dataSet._
import org.apache.flink.table.plan.rules.datastream._
@@ -79,7 +79,7 @@ object FlinkRuleSets {
// remove aggregation if it does not aggregate and input is already distinct
AggregateRemoveRule.INSTANCE,
// push aggregate through join
- FlinkAggregateJoinTransposeRule.EXTENDED,
+ AggregateJoinTransposeRule.EXTENDED,
// aggregate union rule
AggregateUnionAggregateRule.INSTANCE,
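
With the Flink copy gone, the rule set references Calcite's own AggregateJoinTransposeRule. A minimal Java sketch of composing such a RuleSet, assuming Calcite 1.12 on the classpath (FlinkRuleSets itself does this in Scala):

    import org.apache.calcite.rel.rules.AggregateJoinTransposeRule;
    import org.apache.calcite.rel.rules.AggregateRemoveRule;
    import org.apache.calcite.tools.RuleSet;
    import org.apache.calcite.tools.RuleSets;

    class RuleSetExample {
        static RuleSet rules() {
            return RuleSets.ofList(
                AggregateRemoveRule.INSTANCE,         // drop no-op aggregates
                AggregateJoinTransposeRule.EXTENDED); // push aggregate through join
        }
    }
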
[5/5] flink git commit: [FLINK-5435] [table] Remove
FlinkAggregateJoinTransposeRule and FlinkRelDecorrelator after bumping
Calcite to v1.12.
Posted by fh...@apache.org.
[FLINK-5435] [table] Remove FlinkAggregateJoinTransposeRule and FlinkRelDecorrelator after bumping Calcite to v1.12.
This closes #3689.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/c5173fa2
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/c5173fa2
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/c5173fa2
Branch: refs/heads/master
Commit: c5173fa26d3d8a32b0b182a37d34a8eeff6e36d0
Parents: 07f1b03
Author: Kurt Young <yk...@gmail.com>
Authored: Thu Apr 6 22:06:51 2017 +0800
Committer: Fabian Hueske <fh...@apache.org>
Committed: Thu Apr 6 16:34:02 2017 +0200
----------------------------------------------------------------------
.../rules/FlinkAggregateJoinTransposeRule.java | 358 ---
.../calcite/sql2rel/FlinkRelDecorrelator.java | 2216 ------------------
.../flink/table/calcite/FlinkPlannerImpl.scala | 8 +-
.../flink/table/plan/rules/FlinkRuleSets.scala | 4 +-
.../batch/sql/QueryDecorrelationTest.scala | 123 +-
.../api/scala/batch/sql/SetOperatorsTest.scala | 32 +-
6 files changed, 36 insertions(+), 2705 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/c5173fa2/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java b/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java
deleted file mode 100644
index a817c91..0000000
--- a/flink-libraries/flink-table/src/main/java/org/apache/flink/table/calcite/rules/FlinkAggregateJoinTransposeRule.java
+++ /dev/null
@@ -1,358 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to you under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.calcite.rules;
-
-import com.google.common.base.Function;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Lists;
-import org.apache.calcite.linq4j.Ord;
-import org.apache.calcite.plan.RelOptRule;
-import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.plan.RelOptUtil;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.core.Aggregate;
-import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.core.Join;
-import org.apache.calcite.rel.core.JoinRelType;
-import org.apache.calcite.rel.core.RelFactories;
-import org.apache.calcite.rel.logical.LogicalAggregate;
-import org.apache.calcite.rel.logical.LogicalJoin;
-import org.apache.calcite.rel.metadata.RelMetadataQuery;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexCall;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.rex.RexUtil;
-import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.SqlSplittableAggFunction;
-import org.apache.calcite.tools.RelBuilder;
-import org.apache.calcite.tools.RelBuilderFactory;
-import org.apache.calcite.util.Bug;
-import org.apache.calcite.util.ImmutableBitSet;
-import org.apache.calcite.util.Util;
-import org.apache.calcite.util.mapping.Mapping;
-import org.apache.calcite.util.mapping.Mappings;
-
-import java.util.ArrayList;
-import java.util.BitSet;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.SortedMap;
-import java.util.TreeMap;
-import org.apache.flink.util.Preconditions;
-
-/**
- * Copied from {@link org.apache.calcite.rel.rules.AggregateJoinTransposeRule}; should be
- * removed once <a href="https://issues.apache.org/jira/browse/CALCITE-1544">[CALCITE-1544]</a> is fixed.
- */
-public class FlinkAggregateJoinTransposeRule extends RelOptRule {
- public static final FlinkAggregateJoinTransposeRule INSTANCE = new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, RelFactories.LOGICAL_BUILDER, false);
-
- /**
- * Extended instance of the rule that can push down aggregate functions.
- */
- public static final FlinkAggregateJoinTransposeRule EXTENDED = new FlinkAggregateJoinTransposeRule(LogicalAggregate.class, LogicalJoin.class, RelFactories.LOGICAL_BUILDER, true);
-
- private final boolean allowFunctions;
-
- /**
- * Creates a FlinkAggregateJoinTransposeRule.
- */
- public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, Class<? extends Join> joinClass, RelBuilderFactory relBuilderFactory, boolean allowFunctions) {
- super(operand(aggregateClass, null, Aggregate.IS_SIMPLE, operand(joinClass, any())), relBuilderFactory, null);
- this.allowFunctions = allowFunctions;
- }
-
- /**
- * @deprecated to be removed before 2.0
- */
- @Deprecated
- public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory) {
- this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory), false);
- }
-
- /**
- * @deprecated to be removed before 2.0
- */
- @Deprecated
- public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, boolean allowFunctions) {
- this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory), allowFunctions);
- }
-
- /**
- * @deprecated to be removed before 2.0
- */
- @Deprecated
- public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory) {
- this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), false);
- }
-
- /**
- * @deprecated to be removed before 2.0
- */
- @Deprecated
- public FlinkAggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass, RelFactories.AggregateFactory aggregateFactory, Class<? extends Join> joinClass, RelFactories.JoinFactory joinFactory, RelFactories.ProjectFactory projectFactory, boolean allowFunctions) {
- this(aggregateClass, joinClass, RelBuilder.proto(aggregateFactory, joinFactory, projectFactory), allowFunctions);
- }
-
- public void onMatch(RelOptRuleCall call) {
- final Aggregate aggregate = call.rel(0);
- final Join join = call.rel(1);
- final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
- final RelBuilder relBuilder = call.builder();
-
- // If any aggregate functions do not support splitting, bail out
- // If any aggregate call has a filter, bail out
- for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
- if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
- return;
- }
- if (aggregateCall.filterArg >= 0) {
- return;
- }
- }
-
- // If it is not an inner join, we do not push the
- // aggregate operator
- if (join.getJoinType() != JoinRelType.INNER) {
- return;
- }
-
- if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
- return;
- }
-
- // Do the columns used by the join appear in the output of the aggregate?
- final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
- final RelMetadataQuery mq = RelMetadataQuery.instance();
- final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
- final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
- final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
- final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
-
- // Split join condition
- final List<Integer> leftKeys = Lists.newArrayList();
- final List<Integer> rightKeys = Lists.newArrayList();
- final List<Boolean> filterNulls = Lists.newArrayList();
- RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
- // If it contains non-equi join conditions, we bail out
- if (!nonEquiConj.isAlwaysTrue()) {
- return;
- }
-
- // Push each aggregate function down to each side that contains all of its
- // arguments. Note that COUNT(*), because it has no arguments, can go to
- // both sides.
- final Map<Integer, Integer> map = new HashMap<>();
- final List<Side> sides = new ArrayList<>();
- int uniqueCount = 0;
- int offset = 0;
- int belowOffset = 0;
- for (int s = 0; s < 2; s++) {
- final Side side = new Side();
- final RelNode joinInput = join.getInput(s);
- int fieldCount = joinInput.getRowType().getFieldCount();
- final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
- final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
- for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
- map.put(c.e, belowOffset + c.i);
- }
- final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
- final boolean unique;
- if (!allowFunctions) {
- assert aggregate.getAggCallList().isEmpty();
- // If there are no functions, it doesn't matter as much whether we
- // aggregate the inputs before the join, because there will not be
- // any functions experiencing a cartesian product effect.
- //
- // But finding out whether the input is already unique requires a call
- // to areColumnsUnique that currently (until [CALCITE-1048] "Make
- // metadata more robust" is fixed) places a heavy load on
- // the metadata system.
- //
- // So we choose to imagine that the input is already unique, which is
- // untrue but harmless.
- //
- Util.discard(Bug.CALCITE_1048_FIXED);
- unique = true;
- } else {
- final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
- unique = unique0 != null && unique0;
- }
- if (unique) {
- ++uniqueCount;
- side.aggregate = false;
- side.newInput = joinInput;
- } else {
- side.aggregate = true;
- List<AggregateCall> belowAggCalls = new ArrayList<>();
- final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
- final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
- for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
- final SqlAggFunction aggregation = aggCall.e.getAggregation();
- final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
- final AggregateCall call1;
- if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
- call1 = splitter.split(aggCall.e, mapping);
- } else {
- call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
- }
- if (call1 != null) {
- side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
- }
- }
- side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, false, null), belowAggCalls).build();
- }
- offset += fieldCount;
- belowOffset += side.newInput.getRowType().getFieldCount();
- sides.add(side);
- }
-
- if (uniqueCount == 2) {
- // Both inputs to the join are unique. There is nothing to be gained by
- // this rule. In fact, this aggregate+join may be the result of a previous
- // invocation of this rule; if we continue we might loop forever.
- return;
- }
-
- // Update condition
- final Mapping mapping = (Mapping) Mappings.target(new Function<Integer, Integer>() {
- public Integer apply(Integer a0) {
- return map.get(a0);
- }
- }, join.getRowType().getFieldCount(), belowOffset);
- final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
-
- // Create new join
- relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
-
- // Aggregate above to sum up the sub-totals
- final List<AggregateCall> newAggCalls = new ArrayList<>();
- final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
- final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
- final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
- for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
- final SqlAggFunction aggregation = aggCall.e.getAggregation();
- final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
- final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
- final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
- newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
- }
-
- relBuilder.project(projects);
-
- boolean aggConvertedToProjects = false;
- if (allColumnsInAggregate) {
- // let's see if we can convert aggregate into projects
- List<RexNode> projects2 = new ArrayList<>();
- for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
- projects2.add(relBuilder.field(key));
- }
- for (AggregateCall newAggCall : newAggCalls) {
- final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
- if (splitter != null) {
- projects2.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), newAggCall));
- }
- }
- if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
- // We successfully converted agg calls into projects.
- relBuilder.project(projects2);
- aggConvertedToProjects = true;
- }
- }
-
- if (!aggConvertedToProjects) {
- relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), aggregate.indicator, Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
- }
-
- call.transformTo(relBuilder.build());
- }
-
- /**
- * Widens a set of columns with the columns directly equated to them by a
- * given list of constraints. Each 'x = y' constraint causes bit y to be
- * set if bit x is set, and vice versa.
- */
- private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, ImmutableList<RexNode> predicates) {
- SortedMap<Integer, BitSet> equivalence = new TreeMap<>();
- for (RexNode pred : predicates) {
- populateEquivalences(equivalence, pred);
- }
- ImmutableBitSet keyColumns = aggregateColumns;
- for (Integer aggregateColumn : aggregateColumns) {
- final BitSet bitSet = equivalence.get(aggregateColumn);
- if (bitSet != null) {
- keyColumns = keyColumns.union(bitSet);
- }
- }
- return keyColumns;
- }
-
- private static void populateEquivalences(Map<Integer, BitSet> equivalence, RexNode predicate) {
- switch (predicate.getKind()) {
- case EQUALS:
- RexCall call = (RexCall) predicate;
- final List<RexNode> operands = call.getOperands();
- if (operands.get(0) instanceof RexInputRef) {
- final RexInputRef ref0 = (RexInputRef) operands.get(0);
- if (operands.get(1) instanceof RexInputRef) {
- final RexInputRef ref1 = (RexInputRef) operands.get(1);
- populateEquivalence(equivalence, ref0.getIndex(), ref1.getIndex());
- populateEquivalence(equivalence, ref1.getIndex(), ref0.getIndex());
- }
- }
- }
- }
-
- private static void populateEquivalence(Map<Integer, BitSet> equivalence, int i0, int i1) {
- BitSet bitSet = equivalence.get(i0);
- if (bitSet == null) {
- bitSet = new BitSet();
- equivalence.put(i0, bitSet);
- }
- bitSet.set(i1);
- }
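
The two methods above record each 'x = y' predicate in both directions and then widen the group-by columns with everything directly equated to them. A compact, pure-JDK demonstration:

    import java.util.BitSet;
    import java.util.SortedMap;
    import java.util.TreeMap;

    public class EquivalenceDemo {
        public static void main(String[] args) {
            SortedMap<Integer, BitSet> equivalence = new TreeMap<>();
            // predicate $0 = $3, recorded both ways as populateEquivalences does
            equivalence.computeIfAbsent(0, k -> new BitSet()).set(3);
            equivalence.computeIfAbsent(3, k -> new BitSet()).set(0);

            BitSet keyColumns = new BitSet();
            keyColumns.set(0); // the aggregate groups on column 0
            BitSet direct = equivalence.get(0);
            if (direct != null) {
                keyColumns.or(direct); // widen with direct equivalences
            }
            System.out.println(keyColumns); // {0, 3}
        }
    }
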
-
- /**
- * Creates a {@link SqlSplittableAggFunction.Registry}
- * that is a view of a list.
- */
- private static <E> SqlSplittableAggFunction.Registry<E> registry(final List<E> list) {
- return new SqlSplittableAggFunction.Registry<E>() {
- public int register(E e) {
- int i = list.indexOf(e);
- if (i < 0) {
- i = list.size();
- list.add(e);
- }
- return i;
- }
- };
- }
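
The registry view turns a plain list into an intern table: registering an element that is already present returns its existing index, so equal sub-aggregates are shared rather than duplicated in the projection. Usage, with the register logic inlined for a self-contained run:

    import java.util.ArrayList;
    import java.util.List;

    public class RegistryDemo {
        public static void main(String[] args) {
            List<String> list = new ArrayList<>();
            System.out.println(register(list, "SUM($2)")); // 0 (appended)
            System.out.println(register(list, "COUNT()")); // 1 (appended)
            System.out.println(register(list, "SUM($2)")); // 0 (reused)
        }

        // same logic as the anonymous Registry above
        static <E> int register(List<E> list, E e) {
            int i = list.indexOf(e);
            if (i < 0) {
                i = list.size();
                list.add(e);
            }
            return i;
        }
    }
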
-
- /**
- * Work space for an input to a join.
- */
- private static class Side {
- final Map<Integer, Integer> split = new HashMap<>();
- RelNode newInput;
- boolean aggregate;
- }
-}
-
-// End FlinkAggregateJoinTransposeRule.java