Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2021/10/08 03:56:47 UTC

[GitHub] [flink-ml] zhipeng93 opened a new pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

zhipeng93 opened a new pull request #18:
URL: https://github.com/apache/flink-ml/pull/18


   This PR adds support for the withBroadcast() function in DataStream by caching the broadcast inputs in static variables.
   
   Note that this PR is rebased on [[FLINK-10][iteration]](https://github.com/apache/flink-ml/pull/17).
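   For reference, a minimal usage sketch of the `BroadcastUtils.withBroadcastStream(...)` utility added in this PR. Only the utility's signature is taken from the PR; the operator class `MyBroadcastAwareOperator` (a user operator that looks up the broadcast variable by name) is hypothetical:
   
   import org.apache.flink.api.common.typeinfo.Types;
   import org.apache.flink.ml.common.broadcast.BroadcastUtils;
   import org.apache.flink.streaming.api.datastream.DataStream;
   import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
   
   import java.util.Collections;
   
   StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
   DataStream<Double> input = env.fromElements(1.0, 2.0, 3.0);
   DataStream<Double> model = env.fromElements(0.5);
   
   // "model" is the name under which the broadcast data stream is registered and
   // can later be looked up inside the user operator.
   DataStream<Double> output =
           BroadcastUtils.withBroadcastStream(
                   Collections.<DataStream<?>>singletonList(input),
                   Collections.<String, DataStream<?>>singletonMap("model", model),
                   dataStreams -> {
                       DataStream<Double> in = (DataStream<Double>) dataStreams.get(0);
                       // The user-defined function adds a single operator that
                       // consumes the non-broadcast input.
                       return in.transform(
                               "scale", Types.DOUBLE, new MyBroadcastAwareOperator());
                   });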


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340310



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/CacheStreamOperator.java
##########
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractInput;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/** The operator that processes all broadcast inputs and stores them in {@link BroadcastContext}. */
+public class CacheStreamOperator<OUT> extends AbstractStreamOperatorV2<OUT>
+        implements MultipleInputStreamOperator<OUT>, BoundedMultiInput, Serializable {
+    /** names of the broadcast DataStreams. */
+    private final String[] broadcastNames;
+    /** input list of the multi-input operator. */
+    private final List<Input> inputList;
+    /** output types of input DataStreams. */
+    private final TypeInformation<?>[] inTypes;
+    /** caches of the broadcast inputs. */
+    private final List<?>[] caches;
+    /** state storage of the broadcast inputs. */
+    private ListState<?>[] cacheStates;
+    /** cacheReady state storage of the broadcast inputs. */
+    private ListState<Boolean>[] cacheReadyStates;
+
+    public CacheStreamOperator(
+            StreamOperatorParameters<OUT> parameters,
+            String[] broadcastNames,
+            TypeInformation<?>[] inTypes) {
+        super(parameters, broadcastNames.length);
+        this.broadcastNames = broadcastNames;
+        this.inTypes = inTypes;
+        this.caches = new List[inTypes.length];
+        for (int i = 0; i < inTypes.length; i++) {
+            caches[i] = new ArrayList<>();
+        }
+        this.cacheStates = new ListState[inTypes.length];
+        this.cacheReadyStates = new ListState[inTypes.length];
+
+        inputList = new ArrayList<>();
+        for (int i = 0; i < inTypes.length; i++) {
+            inputList.add(new ProxyInput(this, i + 1));
+        }
+    }
+
+    @Override
+    public List<Input> getInputs() {
+        return inputList;
+    }
+
+    @Override
+    public void endInput(int i) {
+        BroadcastContext.markCacheFinished(
+                Tuple2.of(broadcastNames[i - 1], getRuntimeContext().getIndexOfThisSubtask()));
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+        for (int i = 0; i < inTypes.length; i++) {
+            cacheStates[i].clear();
+            cacheStates[i].addAll((List) caches[i]);
+            cacheReadyStates[i].clear();
+            boolean isCacheFinished =
+                    BroadcastContext.isCacheFinished(
+                            Tuple2.of(
+                                    broadcastNames[i],
+                                    getRuntimeContext().getIndexOfThisSubtask()));
+            cacheReadyStates[i].add(isCacheFinished);
+        }
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws Exception {
+        super.initializeState(context);
+        for (int i = 0; i < inTypes.length; i++) {
+            cacheStates[i] =

Review comment:
       Thanks Yun. I have made the changes.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r744116862



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastStreamingRuntimeContext.java
##########
@@ -0,0 +1,68 @@
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A subclass of {@link StreamingRuntimeContext} that provides access to broadcast
+ * variables.
+ */
+public class BroadcastStreamingRuntimeContext extends StreamingRuntimeContext {
+
+    Map<String, List<?>> broadcastVariables = new HashMap<>();
+
+    public BroadcastStreamingRuntimeContext(
+            Environment env,
+            Map<String, Accumulator<?, ?>> accumulators,
+            OperatorMetricGroup operatorMetricGroup,
+            OperatorID operatorID,
+            ProcessingTimeService processingTimeService,
+            @Nullable KeyedStateStore keyedStateStore,
+            ExternalResourceInfoProvider externalResourceInfoProvider) {
+        super(
+                env,
+                accumulators,
+                operatorMetricGroup,
+                operatorID,
+                processingTimeService,
+                keyedStateStore,
+                externalResourceInfoProvider);
+    }
+
+    @Override
+    public boolean hasBroadcastVariable(String name) {
+        return broadcastVariables.containsKey(name);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public <RT> List<RT> getBroadcastVariable(String name) {
+        return (List<RT>) broadcastVariables.get(name);

Review comment:
       Hi Yun, good question.
   
   I would like to throw an exception if users call `getBroadcastVariable` before all the elements of the broadcast variables have been received, i.e., before `processElement` or `processWatermark` is called. What do you think?
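   
   A rough sketch of the check proposed here, assuming readiness can be detected by the key being absent from `broadcastVariables` (this illustrates the proposal, not necessarily the merged code):
   
   @Override
   @SuppressWarnings("unchecked")
   public <RT> List<RT> getBroadcastVariable(String name) {
       if (!broadcastVariables.containsKey(name)) {
           // Fail fast if the broadcast variable has not been fully received yet,
           // i.e. before processElement/processWatermark has been invoked.
           throw new IllegalStateException(
                   "The broadcast variable [" + name + "] is not ready yet.");
       }
       return (List<RT>) broadcastVariables.get(name);
   }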







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r744116336



##########
File path: flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
##########
@@ -43,6 +43,10 @@
     private final List<Segment> finishSegments;
 
     private SegmentWriter currentSegment;
+    /**
+     * The newly added segments that have not yet been retrieved by getNewlyFinishedSegments().
+     */
+    private final List<Segment> newlyFinishedSegments;

Review comment:
       Hi Yun, you are right. 
   
   I have removed `newlyFinishedSegments` by including the pending elements in the `finishSegments` of new writers. Please check out `DataCacheWriter#Line 57`.







[GitHub] [flink-ml] gaoyunhaii closed pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
gaoyunhaii closed pull request #18:
URL: https://github.com/apache/flink-ml/pull/18


   





[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340301



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for the withBroadcast functionality. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** names of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {

Review comment:
       Thanks Yun, I have made the change.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r732759218



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
##########
@@ -232,6 +247,78 @@ private OperatorMetricGroup createOperatorMetricGroup(
         }
     }
 
+    /**
+     * extracts the common processing logic of the subclasses' processElement methods.
+     *
+     * @param streamRecord the input record.
+     * @param inputIndex input id, starts from zero.
+     * @param consumer the consumer function.
+     * @throws Exception
+     */
+    protected void processElementX(
+            StreamRecord streamRecord,
+            int inputIndex,
+            ThrowingConsumer<StreamRecord, Exception> consumer)
+            throws Exception {
+        if (!isBlocked[inputIndex]) {
+            if (areBroadcastVariablesReady()) {
+                dataCacheWriters[inputIndex].finishCurrentSegmentAndStartNewSegment();

Review comment:
       Thanks @gaoyunhaii ! I have extracted `AbstractBroadcastWrapperOperator#processPendingElements`.
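   
   A purely schematic sketch of what such an extracted helper could look like. The actual implementation relies on the DataCacheWriter/DataCacheReader classes from flink-ml-iteration, whose exact APIs are not shown in this thread; `readPendingRecords` is a hypothetical accessor over the cached segments:
   
   protected void processPendingElements(
           int inputIndex, ThrowingConsumer<StreamRecord, Exception> consumer) throws Exception {
       // Replay every record that was cached for this input while the broadcast
       // variables were not yet ready, then hand control back to normal processing.
       for (Object pending : readPendingRecords(inputIndex)) {
           consumer.accept(new StreamRecord<>(pending));
       }
   }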







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r744116336



##########
File path: flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
##########
@@ -43,6 +43,10 @@
     private final List<Segment> finishSegments;
 
     private SegmentWriter currentSegment;
+    /**
+     * The newly added segments that have not yet been retrieved by getNewlyFinishedSegments().
+     */
+    private final List<Segment> newlyFinishedSegments;

Review comment:
       Hi Yun, this is a good point. But I think maintaining `newlyFinishedSegments` is better. The reason is as follows:
   
   We need to maintain a not-yet-processed segment list for each input (i.e., `segmentLists[inputIndex]`): in this list we need to (1) add the newly cached segments and (2) remove the processed segments.
   
   If we do not maintain the not-yet-processed segment list and instead include the pending segments in the `finishSegments` of the new writers, then we are not able to remove the processed segments.
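   
   A schematic illustration of the bookkeeping described above (the field and method names are illustrative, not the actual DataCacheWriter API):
   
   // One pending-segment list per input, i.e. segmentLists[inputIndex].
   List<Segment>[] segmentLists;
   
   void onSegmentFinished(int inputIndex, Segment segment) {
       // (1) add the newly cached segment so it can be replayed later.
       segmentLists[inputIndex].add(segment);
   }
   
   void onSegmentReplayed(int inputIndex, Segment segment) {
       // (2) drop the segment once its records have been processed, which is not
       // possible if pending segments were only merged into the finishSegments
       // of the new writers.
       segmentLists[inputIndex].remove(segment);
   }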







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340233



##########
File path: flink-ml-lib/pom.xml
##########
@@ -65,6 +71,44 @@ under the License.
       <artifactId>core</artifactId>
       <version>1.1.2</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-clients_${scala.binary.version}</artifactId>
+      <version>${flink.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>

Review comment:
       Thanks @gaoyunhaii. I have removed the dependency on `flink-statebackend-rocksdb`.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast inputs are blocked and
+     * cached to avoid possible deadlocks.
+     *
+     * @param inputList the non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream.
+     * @param <OUT> type of the output data stream.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(

Review comment:
       Thanks. I have done the change.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all

Review comment:
       Thanks. I have done the change.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340305



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for the withBroadcast functionality. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** names of the broadcast DataStreams. */

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for the withBroadcast functionality. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** names of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);

Review comment:
       Thanks Yun. I have made the changes.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r728932487



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    private List<IN> cache;
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+        this.cache = new ArrayList<>();
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {
+            if (areBroadcastVariablesReady()) {
+                for (IN ele : cache) {
+                    wrappedOperator.processElement(new StreamRecord<>(ele));
+                }
+                cache.clear();
+                wrappedOperator.processElement(streamRecord);
+
+            } else {
+                cache.add(streamRecord.getValue());

Review comment:
       Hi Yunfeng, this is a good point.
   
   I have re-implemented the cache on file systems (local and remote) based on this [Patch](https://github.com/apache/flink-ml/pull/17). Please check out AbstractBroadcastWrapperOperator#initializeState and AbstractBroadcastWrapperOperator#snapshot for details.







[GitHub] [flink-ml] gaoyunhaii commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
gaoyunhaii commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r732431833



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
##########
@@ -96,31 +97,36 @@
 
     protected final StreamOperatorFactory<T> operatorFactory;
 
-    /** Metric group for the operator. */
     protected final OperatorMetricGroup metrics;
 
     protected final S wrappedOperator;
 
-    /** variables for withBroadcast operators. */
-    protected final MailboxExecutor mailboxExecutor;
-
-    protected final String[] broadcastStreamNames;
+    protected transient StreamOperatorStateHandler stateHandler;
 
-    protected final boolean[] isBlocking;
+    protected transient InternalTimeServiceManager<?> timeServiceManager;
 
+    protected final MailboxExecutor mailboxExecutor;
+    /** variables specific for withBroadcast functionality. */

Review comment:
       In general, put one empty line before each instance variable.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
##########
@@ -196,17 +203,25 @@ public AbstractBroadcastWrapperOperator(
     }
 
     /**
-     * check whether all of broadcast variables are ready.
+     * checks whether all broadcast variables are ready. Besides, it maintains a state
+     * {broadcastVariablesReady} to avoid invoking {@code BroadcastContext.isCacheFinished(...)}
+     * repeatedly. Finally, it sets broadcast variables for ${@link HasBroadcastVariable} if the
+     * broadcast variables are ready.
      *
-     * @return
+     * @return true if all broadcast variables are ready, false otherwise.
      */
     protected boolean areBroadcastVariablesReady() {
         if (broadcastVariablesReady) {
             return true;
         }
         for (String name : broadcastStreamNames) {
-            if (!BroadcastContext.isCacheFinished(Tuple2.of(name, indexOfSubtask))) {
+            if (!BroadcastContext.isCacheFinished(name + "-" + indexOfSubtask)) {
                 return false;
+            } else if (wrappedOperator instanceof HasBroadcastVariable) {
+                String key = name + "-" + indexOfSubtask;
+                String userKey = name.substring(name.indexOf('-') + 1);
+                ((HasBroadcastVariable) wrappedOperator)

Review comment:
       Use `OperatorUtils#processOperatorOrUdfIfSatisfy` instead since we may need to handle both operators and UDFs if either of them implements the interface.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -22,121 +22,168 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
 import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
 import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
 
+/** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
+    /**
+     * supports withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of an operator that implements ${@link HasBroadcastVariable}. An operator
+     * that wants to access broadcast variables must implement ${@link HasBroadcastVariable}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and further set by
+     * {@code HasBroadcastVariable.setBroadcastVariable(...)}. For now the non-broadcast inputs
+     * are cached by default to avoid possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream. Note that users can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> userDefinedFunction) {
+        Preconditions.checkArgument(inputList.size() > 0);
+
+        StreamExecutionEnvironment env = inputList.get(0).getExecutionEnvironment();
+        String[] broadcastNames = new String[bcStreams.size()];
+        DataStream<?>[] broadcastInputs = new DataStream[bcStreams.size()];
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[bcStreams.size()];
+        int idx = 0;
+        final String broadcastId = new AbstractID().toHexString();
+        for (String name : bcStreams.keySet()) {
+            broadcastNames[idx] = broadcastId + "-" + name;
+            broadcastInputs[idx] = bcStreams.get(name);
+            broadcastInTypes[idx] = broadcastInputs[idx].getType();
+            idx++;
+        }
 
+        DataStream<OUT> resultStream =
+                getResultStream(env, inputList, broadcastNames, userDefinedFunction);
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = "broadcast-co-location-" + UUID.randomUUID();
+        DataStream<OUT> cachedBroadcastInputs =
+                cacheBroadcastVariables(
+                        env,
+                        broadcastNames,
+                        broadcastInputs,
+                        broadcastInTypes,
+                        resultStream.getParallelism(),
+                        outType);
+
+        boolean canCoLocate =
+                cachedBroadcastInputs.getTransformation() instanceof PhysicalTransformation
+                        && resultStream.getTransformation() instanceof PhysicalTransformation;
+        if (canCoLocate) {
+            ((PhysicalTransformation) cachedBroadcastInputs.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+            ((PhysicalTransformation) resultStream.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+        } else {
+            throw new UnsupportedOperationException(
+                    "cannot set chaining strategy on "
+                            + cachedBroadcastInputs.getTransformation()
+                            + " and "
+                            + resultStream.getTransformation()
+                            + ".");
+        }
+        cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
+        resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
+
+        return cachedBroadcastInputs.union(resultStream);
+    }
+
+    /**
+     * caches all broadcast input data streams in static variables and returns the result multi-input
+     * stream operator. The result multi-input stream operator emits nothing and the only
+     * functionality of this operator is to cache all the input records in ${@link
+     * BroadcastContext}.
+     *
+     * @param env execution environment.
+     * @param broadcastInputNames names of the broadcast input data streams.
+     * @param broadcastInputs list of the broadcast data streams.
+     * @param broadcastInTypes output types of the broadcast input data streams.
+     * @param parallelism parallelism.
+     * @param outType output type.
+     * @param <OUT>
+     * @return the result multi-input stream operator.
+     */
     private static <OUT> DataStream<OUT> cacheBroadcastVariables(
             StreamExecutionEnvironment env,
-            Map<String, DataStream<?>> bcStreams,
+            String[] broadcastInputNames,
+            DataStream<?>[] broadcastInputs,
+            TypeInformation<?>[] broadcastInTypes,
+            int parallelism,
             TypeInformation<OUT> outType) {
-        int numBroadcastInput = bcStreams.size();
-        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
-        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
-        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
-        for (int i = 0; i < numBroadcastInput; i++) {
-            broadcastInTypes[i] = broadcastInputs[i].getType();
-        }
-
         MultipleInputTransformation<OUT> transformation =
                 new MultipleInputTransformation<OUT>(
                         "broadcastInputs",
                         new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
                         outType,
-                        env.getParallelism());
-        for (DataStream<?> dataStream : bcStreams.values()) {
+                        parallelism);
+        for (DataStream<?> dataStream : broadcastInputs) {
             transformation.addInput(dataStream.broadcast().getTransformation());
         }
         env.addOperator(transformation);
         return new MultipleConnectedStreams(env).transform(transformation);
     }
 
-    private static String getCoLocationKey(String[] broadcastNames) {
-        StringBuilder sb = new StringBuilder();
-        sb.append("Flink-ML-broadcast-co-location");
-        for (String name : broadcastNames) {
-            sb.append(name);
-        }
-        return sb.toString();
-    }
-
-    private static <OUT> DataStream<OUT> buildGraph(
+    /**
+     * uses {@link DraftExecutionEnvironment} to execute the userDefinedFunction and returns the
+     * resultStream.
+     *
+     * @param env execution environment.
+     * @param inputList non-broadcast input list.
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param graphBuilder user-defined logic.
+     * @param <OUT> output type of the result stream.
+     * @return the result stream by applying user-defined logic on the input list.
+     */
+    private static <OUT> DataStream<OUT> getResultStream(
             StreamExecutionEnvironment env,
             List<DataStream<?>> inputList,
             String[] broadcastStreamNames,
             Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
         TypeInformation[] inTypes = new TypeInformation[inputList.size()];
         for (int i = 0; i < inputList.size(); i++) {
-            TypeInformation type = inputList.get(i).getType();
-            inTypes[i] = type;
+            inTypes[i] = inputList.get(i).getType();
         }
-        // blocking all non-broadcast input edges by default.
-        boolean[] isBlocking = new boolean[inTypes.length];
-        Arrays.fill(isBlocking, true);
+        // do not block all non-broadcast input edges by default.
+        boolean[] isBlocked = new boolean[inputList.size()];

Review comment:
       Do local variables always need explicit initialization in Java?
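   
   For context, a small snippet on the Java semantics behind this question: array elements are default-initialized, so a freshly created boolean[] is already all-false without an explicit fill.
   
   import java.util.Arrays;
   
   boolean[] isBlocked = new boolean[3];
   System.out.println(Arrays.toString(isBlocked)); // prints [false, false, false]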

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/MultipleInputBroadcastWrapperOperator.java
##########
@@ -32,143 +31,73 @@
 import java.util.ArrayList;
 import java.util.List;
 
-/** Wrapper for WithBroadcastMultipleInputStreamOperator. */
+/** Wrapper for {@link MultipleInputStreamOperator} that implements {@link HasBroadcastVariable}. */
 public class MultipleInputBroadcastWrapperOperator<OUT>
         extends AbstractBroadcastWrapperOperator<OUT, MultipleInputStreamOperator<OUT>>
         implements MultipleInputStreamOperator<OUT> {
 
-    public MultipleInputBroadcastWrapperOperator(
+    private final List<Input> inputList;
+
+    MultipleInputBroadcastWrapperOperator(
             StreamOperatorParameters<OUT> parameters,
             StreamOperatorFactory<OUT> operatorFactory,
             String[] broadcastStreamNames,
             TypeInformation[] inTypes,
-            boolean[] isBlocking) {
-        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
-    }
-
-    @Override
-    public List<Input> getInputs() {
-        List<Input> proxyInputs = new ArrayList<>();
+            boolean[] isBlocked) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocked);
+        inputList = new ArrayList<>();
         for (int i = 0; i < wrappedOperator.getInputs().size(); i++) {
-            proxyInputs.add(new ProxyInput(i));
+            inputList.add(new ProxyInput(i));
         }
-        return proxyInputs;
-    }
-
-    private <IN> void processElement(StreamRecord streamRecord, Input<IN> input) throws Exception {
-        input.processElement(streamRecord);
-    }
-
-    private <IN> void processWatermark(Watermark watermark, Input<IN> input) throws Exception {
-        input.processWatermark(watermark);
     }
 
-    private <IN> void processLatencyMarker(LatencyMarker latencyMarker, Input<IN> input)
-            throws Exception {
-        input.processLatencyMarker(latencyMarker);
-    }
-
-    private <IN> void setKeyContextElement(StreamRecord streamRecord, Input<IN> input)
-            throws Exception {
-        input.setKeyContextElement(streamRecord);
-    }
-
-    private <IN> void processWatermarkStatus(WatermarkStatus watermarkStatus, Input<IN> input)
-            throws Exception {
-        input.processWatermarkStatus(watermarkStatus);
+    @Override
+    public List<Input> getInputs() {
+        return inputList;
     }
 
     @Override
     public void endInput(int inputId) throws Exception {
-        ((ProxyInput) (getInputs().get(inputId - 1))).endInput();
+        endInputX(inputId - 1, x -> wrappedOperator.getInputs().get(inputId - 1).processElement(x));

Review comment:
       nit: `wrappedOperator.getInputs().get(inputId - 1)::processElement` ?
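
For readers less familiar with the nit, a standalone sketch (the class and names here are made up) showing that the lambda and the method reference are interchangeable:

```java
import java.util.function.Consumer;

public class MethodReferenceSketch {
    static class Input {
        void processElement(String record) {
            System.out.println("processed " + record);
        }
    }

    public static void main(String[] args) {
        Input input = new Input();
        // Equivalent ways to forward a record; the method reference drops the lambda wrapper.
        Consumer<String> viaLambda = x -> input.processElement(x);
        Consumer<String> viaMethodRef = input::processElement;
        viaLambda.accept("a");
        viaMethodRef.accept("b");
    }
}
```

(A bound method reference captures its receiver once at creation time, while the lambda re-evaluates the receiver expression on each call; for a fixed input list the two behave the same.)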

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
##########
@@ -232,6 +247,78 @@ private OperatorMetricGroup createOperatorMetricGroup(
         }
     }
 
+    /**
+     * extracts common processing logic in subclasses' processing elements.
+     *
+     * @param streamRecord the input record.
+     * @param inputIndex input id, starts from zero.
+     * @param consumer the consumer function.
+     * @throws Exception
+     */
+    protected void processElementX(
+            StreamRecord streamRecord,
+            int inputIndex,
+            ThrowingConsumer<StreamRecord, Exception> consumer)
+            throws Exception {
+        if (!isBlocked[inputIndex]) {
+            if (areBroadcastVariablesReady()) {
+                dataCacheWriters[inputIndex].finishCurrentSegmentAndStartNewSegment();

Review comment:
       Perhaps we could also extract a method that processes the pending records, to eliminate the duplication with `endInputX`?
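
A deliberately simplified standalone sketch of the "shared drain helper" idea (the real version would replay records from the `DataCacheReader` rather than from an in-memory list):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

public class DrainPendingSketch {
    private final List<String> pending = new ArrayList<>();

    // Single helper that flushes everything cached so far through the given consumer;
    // both the processElement-style path and the endInput-style path could reuse it.
    private void drainPending(Consumer<String> consumer) {
        for (String record : pending) {
            consumer.accept(record);
        }
        pending.clear();
    }

    public static void main(String[] args) {
        DrainPendingSketch sketch = new DrainPendingSketch();
        sketch.pending.add("cached-1");
        sketch.pending.add("cached-2");
        sketch.drainPending(r -> System.out.println("replayed " + r));
        sketch.drainPending(r -> System.out.println("never printed")); // already drained
    }
}
```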

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/MultipleInputBroadcastWrapperOperator.java
##########
@@ -32,143 +31,73 @@
 import java.util.ArrayList;
 import java.util.List;
 
-/** Wrapper for WithBroadcastMultipleInputStreamOperator. */
+/** Wrapper for {@link MultipleInputStreamOperator} that implements {@link HasBroadcastVariable}. */
 public class MultipleInputBroadcastWrapperOperator<OUT>
         extends AbstractBroadcastWrapperOperator<OUT, MultipleInputStreamOperator<OUT>>
         implements MultipleInputStreamOperator<OUT> {
 
-    public MultipleInputBroadcastWrapperOperator(
+    private final List<Input> inputList;
+
+    MultipleInputBroadcastWrapperOperator(
             StreamOperatorParameters<OUT> parameters,
             StreamOperatorFactory<OUT> operatorFactory,
             String[] broadcastStreamNames,
             TypeInformation[] inTypes,
-            boolean[] isBlocking) {
-        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
-    }
-
-    @Override
-    public List<Input> getInputs() {
-        List<Input> proxyInputs = new ArrayList<>();
+            boolean[] isBlocked) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocked);
+        inputList = new ArrayList<>();
         for (int i = 0; i < wrappedOperator.getInputs().size(); i++) {
-            proxyInputs.add(new ProxyInput(i));
+            inputList.add(new ProxyInput(i));
         }
-        return proxyInputs;
-    }
-
-    private <IN> void processElement(StreamRecord streamRecord, Input<IN> input) throws Exception {
-        input.processElement(streamRecord);
-    }
-
-    private <IN> void processWatermark(Watermark watermark, Input<IN> input) throws Exception {
-        input.processWatermark(watermark);
     }
 
-    private <IN> void processLatencyMarker(LatencyMarker latencyMarker, Input<IN> input)
-            throws Exception {
-        input.processLatencyMarker(latencyMarker);
-    }
-
-    private <IN> void setKeyContextElement(StreamRecord streamRecord, Input<IN> input)
-            throws Exception {
-        input.setKeyContextElement(streamRecord);
-    }
-
-    private <IN> void processWatermarkStatus(WatermarkStatus watermarkStatus, Input<IN> input)
-            throws Exception {
-        input.processWatermarkStatus(watermarkStatus);
+    @Override
+    public List<Input> getInputs() {
+        return inputList;
     }
 
     @Override
     public void endInput(int inputId) throws Exception {
-        ((ProxyInput) (getInputs().get(inputId - 1))).endInput();
+        endInputX(inputId - 1, x -> wrappedOperator.getInputs().get(inputId - 1).processElement(x));
+        super.endInput(inputId);
     }
 
     private class ProxyInput<IN> implements Input<IN> {
 
-        private final int inputIdMinusOne;
+        /** input index of this input. */
+        private final int inputIndex;
 
         private final Input<IN> input;
 
-        public ProxyInput(int inputIdMinusOne) {
-            this.inputIdMinusOne = inputIdMinusOne;
-            this.input = wrappedOperator.getInputs().get(inputIdMinusOne);
+        public ProxyInput(int inputIndex) {
+            this.inputIndex = inputIndex;
+            this.input = wrappedOperator.getInputs().get(inputIndex);
         }
 
         @Override
         public void processElement(StreamRecord<IN> streamRecord) throws Exception {
-            if (isBlocking[inputIdMinusOne]) {
-                if (areBroadcastVariablesReady()) {
-                    dataCacheWriters[inputIdMinusOne].finishCurrentSegmentAndStartNewSegment();
-                    segmentLists[inputIdMinusOne].addAll(
-                            dataCacheWriters[inputIdMinusOne].getNewlyFinishedSegments());
-                    if (segmentLists[inputIdMinusOne].size() != 0) {
-                        DataCacheReader dataCacheReader =
-                                new DataCacheReader<>(
-                                        inTypes[inputIdMinusOne].createSerializer(
-                                                containingTask.getExecutionConfig()),
-                                        fileSystem,
-                                        segmentLists[inputIdMinusOne]);
-                        while (dataCacheReader.hasNext()) {
-                            MultipleInputBroadcastWrapperOperator.this.processElement(
-                                    new StreamRecord(dataCacheReader.next()), input);
-                        }
-                    }
-                    segmentLists[inputIdMinusOne].clear();
-                    MultipleInputBroadcastWrapperOperator.this.processElement(streamRecord, input);
-
-                } else {
-                    dataCacheWriters[inputIdMinusOne].addRecord(streamRecord.getValue());
-                }
-
-            } else {
-                while (!areBroadcastVariablesReady()) {
-                    mailboxExecutor.yield();
-                }
-                MultipleInputBroadcastWrapperOperator.this.processElement(streamRecord, input);
-            }
+            MultipleInputBroadcastWrapperOperator.this.processElementX(
+                    streamRecord, inputIndex, x -> input.processElement(x));

Review comment:
       Similarly, this might be simplified to `input::processElement`.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -18,106 +18,54 @@
 
 package org.apache.flink.ml.common.broadcast;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.java.tuple.Tuple2;
 
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.concurrent.ConcurrentHashMap;
 
 public class BroadcastContext {
     /**
-     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
-     * is (isBroaddcastVariableReady, cacheList).
+     * stores broadcast data streams in a map. The key is broadcastName-partitionId and the value is
+     * (isBroadcastVariableReady, cacheList).
      */
-    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
-            new HashMap<>();
-    /**
-     * We use lock because we want to enable `getBroadcastVariable(String)` in a TM with multiple
-     * slots here. Note that using ConcurrentHashMap is not enough since we need "contains and get
-     * in an atomic operation".
-     */
-    private static ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
+    private static final Map<String, Tuple2<Boolean, List<?>>> BROADCAST_VARIABLES =
+            new ConcurrentHashMap<>();
 
-    public static void putBroadcastVariable(
-            Tuple2<String, Integer> key, Tuple2<Boolean, List<?>> variable) {
-        lock.writeLock().lock();
-        try {
-            broadcastVariables.put(key, variable);
-        } finally {
-            lock.writeLock().unlock();
-        }
+    @VisibleForTesting
+    public static void putBroadcastVariable(String key, Tuple2<Boolean, List<?>> variable) {
+        BROADCAST_VARIABLES.put(key, variable);
     }
 
     /**
-     * get the cached list with the given key.
-     *
-     * @param key
-     * @param <T>
-     * @return the cache list.
-     */
-    public static <T> List<T> getBroadcastVariable(Tuple2<String, Integer> key) {
-        lock.readLock().lock();
-        List<?> result = null;
-        try {
-            result = broadcastVariables.get(key).f1;
-        } finally {
-            lock.readLock().unlock();
-        }
-        return (List<T>) result;
-    }
-
-    /**
-     * get broadcast variables by name
+     * gets broadcast variables by name if this broadcast variable is fully cached.
      *
      * @param name
      * @param <T>
-     * @return
+     * @return the cache broadcast variable. Return null if it is not fully cached.
      */
+    @VisibleForTesting
     public static <T> List<T> getBroadcastVariable(String name) {
-        lock.readLock().lock();
-        List<?> result = null;
-        try {
-            for (Tuple2<String, Integer> nameAndPartitionId : broadcastVariables.keySet()) {
-                if (name.equals(nameAndPartitionId.f0) && isCacheFinished(nameAndPartitionId)) {
-                    result = broadcastVariables.get(nameAndPartitionId).f1;
-                    break;
-                }
-            }
-        } finally {
-            lock.readLock().unlock();
+        Tuple2<Boolean, List<?>> cacheReadyAndList = BROADCAST_VARIABLES.get(name);
+        if (cacheReadyAndList.f0) {
+            return (List<T>) cacheReadyAndList.f1;
         }
-        return (List<T>) result;
+        return null;
     }
 
-    public static void remove(Tuple2<String, Integer> key) {
-        lock.writeLock().lock();
-        try {
-            broadcastVariables.remove(key);
-        } finally {
-            lock.writeLock().unlock();
-        }
+    @VisibleForTesting
+    public static void remove(String key) {
+        BROADCAST_VARIABLES.remove(key);
     }
 
-    public static void markCacheFinished(Tuple2<String, Integer> key) {
-        lock.writeLock().lock();
-        try {
-            broadcastVariables.get(key).f0 = true;
-        } finally {
-            lock.writeLock().unlock();
-        }
+    @VisibleForTesting
+    public static void markCacheFinished(String key) {

Review comment:
       Shouldn't we explicitly notify the wrapper operator by emitting a mail? Otherwise the wrapper operator may stall in `endInputX`.
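
The concern can be illustrated with a standalone analogy using plain `java.util.concurrent` (not the Flink mailbox API): the waiting side only makes progress if the caching side sends an explicit signal once the broadcast variables are complete.

```java
import java.util.concurrent.CountDownLatch;

public class NotifyOnReadySketch {
    public static void main(String[] args) throws InterruptedException {
        CountDownLatch broadcastReady = new CountDownLatch(1);

        Thread wrapperOperator = new Thread(() -> {
            try {
                // Analogue of the wrapper operator waiting in endInputX: without an
                // explicit signal from the caching side, this await would stall forever.
                broadcastReady.await();
                System.out.println("broadcast variables ready, finishing endInput");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        wrapperOperator.start();

        // Analogue of markCacheFinished: the caching side must explicitly wake the waiter
        // (in the PR this would be a mail submitted to the wrapper's MailboxExecutor).
        System.out.println("cache finished, notifying wrapper operator");
        broadcastReady.countDown();

        wrapperOperator.join();
    }
}
```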

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -22,121 +22,168 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
 import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
 import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
 
+/** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
+    /**
+     * supports withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of an operator that implements ${@link HasBroadcastVariable}. An operator
+     * that wants to access broadcast variables must implement ${@link HasBroadcastVariable}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and further set by
+     * {@code HasBroadcastVariable.setBroadcastVariable(...)}. For now the non-broadcast input are
+     * cached by default to avoid the possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream. Note that users can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(

Review comment:
       @SuppressWarnings({"rawtypes", "unchecked"})

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);
+        this.broadcastStreamNames = broadcastStreamNames;

Review comment:
       This applies to the following assignments, not the first line.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -22,121 +22,168 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
 import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
 import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
 
+/** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
+    /**
+     * supports withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of an operator that implements ${@link HasBroadcastVariable}. An operator
+     * that wants to access broadcast variables must implement ${@link HasBroadcastVariable}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and further set by
+     * {@code HasBroadcastVariable.setBroadcastVariable(...)}. For now the non-broadcast input are
+     * cached by default to avoid the possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream. Note that users can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> userDefinedFunction) {
+        Preconditions.checkArgument(inputList.size() > 0);
+
+        StreamExecutionEnvironment env = inputList.get(0).getExecutionEnvironment();
+        String[] broadcastNames = new String[bcStreams.size()];
+        DataStream<?>[] broadcastInputs = new DataStream[bcStreams.size()];
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[bcStreams.size()];
+        int idx = 0;
+        final String broadcastId = new AbstractID().toHexString();
+        for (String name : bcStreams.keySet()) {
+            broadcastNames[idx] = broadcastId + "-" + name;
+            broadcastInputs[idx] = bcStreams.get(name);
+            broadcastInTypes[idx] = broadcastInputs[idx].getType();
+            idx++;
+        }
 
+        DataStream<OUT> resultStream =
+                getResultStream(env, inputList, broadcastNames, userDefinedFunction);
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = "broadcast-co-location-" + UUID.randomUUID();
+        DataStream<OUT> cachedBroadcastInputs =
+                cacheBroadcastVariables(
+                        env,
+                        broadcastNames,
+                        broadcastInputs,
+                        broadcastInTypes,
+                        resultStream.getParallelism(),
+                        outType);
+
+        boolean canCoLocate =
+                cachedBroadcastInputs.getTransformation() instanceof PhysicalTransformation
+                        && resultStream.getTransformation() instanceof PhysicalTransformation;
+        if (canCoLocate) {
+            ((PhysicalTransformation) cachedBroadcastInputs.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+            ((PhysicalTransformation) resultStream.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+        } else {
+            throw new UnsupportedOperationException(
+                    "cannot set chaining strategy on "
+                            + cachedBroadcastInputs.getTransformation()
+                            + " and "
+                            + resultStream.getTransformation()
+                            + ".");
+        }
+        cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
+        resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
+
+        return cachedBroadcastInputs.union(resultStream);
+    }
+
+    /**
+     * caches all broadcast input data streams in static variables and returns the result multi-input
+     * stream operator. The result multi-input stream operator emits nothing and the only
+     * functionality of this operator is to cache all the input records in ${@link
+     * BroadcastContext}.
+     *
+     * @param env execution environment.
+     * @param broadcastInputNames names of the broadcast input data streams.
+     * @param broadcastInputs list of the broadcast data streams.
+     * @param broadcastInTypes output types of the broadcast input data streams.
+     * @param parallelism parallelism.
+     * @param outType output type.
+     * @param <OUT>
+     * @return the result multi-input stream operator.
+     */
     private static <OUT> DataStream<OUT> cacheBroadcastVariables(
             StreamExecutionEnvironment env,
-            Map<String, DataStream<?>> bcStreams,
+            String[] broadcastInputNames,
+            DataStream<?>[] broadcastInputs,
+            TypeInformation<?>[] broadcastInTypes,
+            int parallelism,
             TypeInformation<OUT> outType) {
-        int numBroadcastInput = bcStreams.size();
-        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
-        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
-        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
-        for (int i = 0; i < numBroadcastInput; i++) {
-            broadcastInTypes[i] = broadcastInputs[i].getType();
-        }
-
         MultipleInputTransformation<OUT> transformation =
                 new MultipleInputTransformation<OUT>(

Review comment:
       `<OUT>` -> `<>`

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -22,121 +22,168 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
 import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
 import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
 
+/** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
+    /**
+     * supports withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of an operator that implements ${@link HasBroadcastVariable}. An operator
+     * that wants to access broadcast variables must implement ${@link HasBroadcastVariable}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and further set by
+     * {@code HasBroadcastVariable.setBroadcastVariable(...)}. For now the non-broadcast input are
+     * cached by default to avoid the possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream. Note that users can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> userDefinedFunction) {
+        Preconditions.checkArgument(inputList.size() > 0);
+
+        StreamExecutionEnvironment env = inputList.get(0).getExecutionEnvironment();
+        String[] broadcastNames = new String[bcStreams.size()];
+        DataStream<?>[] broadcastInputs = new DataStream[bcStreams.size()];
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[bcStreams.size()];
+        int idx = 0;
+        final String broadcastId = new AbstractID().toHexString();
+        for (String name : bcStreams.keySet()) {
+            broadcastNames[idx] = broadcastId + "-" + name;
+            broadcastInputs[idx] = bcStreams.get(name);
+            broadcastInTypes[idx] = broadcastInputs[idx].getType();
+            idx++;
+        }
 
+        DataStream<OUT> resultStream =
+                getResultStream(env, inputList, broadcastNames, userDefinedFunction);
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = "broadcast-co-location-" + UUID.randomUUID();
+        DataStream<OUT> cachedBroadcastInputs =
+                cacheBroadcastVariables(
+                        env,
+                        broadcastNames,
+                        broadcastInputs,
+                        broadcastInTypes,
+                        resultStream.getParallelism(),
+                        outType);
+
+        boolean canCoLocate =
+                cachedBroadcastInputs.getTransformation() instanceof PhysicalTransformation
+                        && resultStream.getTransformation() instanceof PhysicalTransformation;
+        if (canCoLocate) {
+            ((PhysicalTransformation) cachedBroadcastInputs.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+            ((PhysicalTransformation) resultStream.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+        } else {
+            throw new UnsupportedOperationException(
+                    "cannot set chaining strategy on "
+                            + cachedBroadcastInputs.getTransformation()
+                            + " and "
+                            + resultStream.getTransformation()
+                            + ".");
+        }
+        cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
+        resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
+
+        return cachedBroadcastInputs.union(resultStream);
+    }
+
+    /**
+     * caches all broadcast input data streams in static variables and returns the result multi-input
+     * stream operator. The result multi-input stream operator emits nothing and the only
+     * functionality of this operator is to cache all the input records in ${@link
+     * BroadcastContext}.
+     *
+     * @param env execution environment.
+     * @param broadcastInputNames names of the broadcast input data streams.
+     * @param broadcastInputs list of the broadcast data streams.
+     * @param broadcastInTypes output types of the broadcast input data streams.
+     * @param parallelism parallelism.
+     * @param outType output type.
+     * @param <OUT>
+     * @return the result multi-input stream operator.
+     */
     private static <OUT> DataStream<OUT> cacheBroadcastVariables(
             StreamExecutionEnvironment env,
-            Map<String, DataStream<?>> bcStreams,
+            String[] broadcastInputNames,
+            DataStream<?>[] broadcastInputs,
+            TypeInformation<?>[] broadcastInTypes,
+            int parallelism,
             TypeInformation<OUT> outType) {
-        int numBroadcastInput = bcStreams.size();
-        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
-        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
-        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
-        for (int i = 0; i < numBroadcastInput; i++) {
-            broadcastInTypes[i] = broadcastInputs[i].getType();
-        }
-
         MultipleInputTransformation<OUT> transformation =
                 new MultipleInputTransformation<OUT>(
                         "broadcastInputs",
                         new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
                         outType,
-                        env.getParallelism());
-        for (DataStream<?> dataStream : bcStreams.values()) {
+                        parallelism);
+        for (DataStream<?> dataStream : broadcastInputs) {
             transformation.addInput(dataStream.broadcast().getTransformation());
         }
         env.addOperator(transformation);
         return new MultipleConnectedStreams(env).transform(transformation);
     }
 
-    private static String getCoLocationKey(String[] broadcastNames) {
-        StringBuilder sb = new StringBuilder();
-        sb.append("Flink-ML-broadcast-co-location");
-        for (String name : broadcastNames) {
-            sb.append(name);
-        }
-        return sb.toString();
-    }
-
-    private static <OUT> DataStream<OUT> buildGraph(
+    /**
+     * uses {@link DraftExecutionEnvironment} to execute the userDefinedFunction and returns the
+     * resultStream.
+     *
+     * @param env execution environment.
+     * @param inputList non-broadcast input list.
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param graphBuilder user-defined logic.
+     * @param <OUT> output type of the result stream.
+     * @return the result stream by applying user-defined logic on the input list.
+     */
+    private static <OUT> DataStream<OUT> getResultStream(
             StreamExecutionEnvironment env,
             List<DataStream<?>> inputList,
             String[] broadcastStreamNames,
             Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
         TypeInformation[] inTypes = new TypeInformation[inputList.size()];
         for (int i = 0; i < inputList.size(); i++) {
-            TypeInformation type = inputList.get(i).getType();
-            inTypes[i] = type;
+            inTypes[i] = inputList.get(i).getType();
         }
-        // blocking all non-broadcast input edges by default.
-        boolean[] isBlocking = new boolean[inTypes.length];
-        Arrays.fill(isBlocking, true);
+        // do not block all non-broadcast input edges by default.
+        boolean[] isBlocked = new boolean[inputList.size()];
         DraftExecutionEnvironment draftEnv =
                 new DraftExecutionEnvironment(
-                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocked));
 
         List<DataStream<?>> draftSources = new ArrayList<>();
         for (int i = 0; i < inputList.size(); i++) {
             draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
         }
         DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
-
+        Preconditions.checkState(
+                draftEnv.getStreamGraph(false).getStreamNodes().size() == 1 + inputList.size(),

Review comment:
       Are you sure this would work? When `getStreamGraph` is called, the list of `transformations` would be cleared from the env.







[GitHub] [flink-ml] gaoyunhaii commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
gaoyunhaii commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r744108746



##########
File path: flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
##########
@@ -43,6 +43,10 @@
     private final List<Segment> finishSegments;
 
     private SegmentWriter currentSegment;
+    /**
+     * The segments that are newly added that has not been retrieved by getNewlyFinishedSegments().
+     */
+    private final List<Segment> newlyFinishedSegments;

Review comment:
        Do we still need `newlyFinishedSegments` if, on initialization, we could also include the pending segments in the `finishSegments` of the new writers?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java
##########
@@ -0,0 +1,630 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.functions.RichFunction;
+import org.apache.flink.api.common.operators.MailboxExecutor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.ml.common.broadcast.BroadcastStreamingRuntimeContext;
+import org.apache.flink.ml.common.broadcast.typeinfo.CacheElement;
+import org.apache.flink.ml.common.broadcast.typeinfo.CacheElementTypeInfo;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.metrics.groups.InternalOperatorIOMetricGroup;
+import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutpusStreamDecorator;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler.CheckpointedStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailbox;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.ThrowingConsumer;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.UUID;
+
+/** Base class for the broadcast wrapper operators. */
+public abstract class AbstractBroadcastWrapperOperator<T, S extends StreamOperator<T>>
+        implements StreamOperator<T>, StreamOperatorStateHandler.CheckpointedStreamOperator {
+
+    private static final Logger LOG =
+            LoggerFactory.getLogger(AbstractBroadcastWrapperOperator.class);
+
+    protected final StreamOperatorParameters<T> parameters;
+
+    protected final StreamConfig streamConfig;
+
+    protected final StreamTask<?, ?> containingTask;
+
+    protected final Output<StreamRecord<T>> output;
+
+    protected final StreamOperatorFactory<T> operatorFactory;
+
+    protected final OperatorMetricGroup metrics;
+
+    protected final S wrappedOperator;
+
+    protected transient StreamOperatorStateHandler stateHandler;
+
+    protected transient InternalTimeServiceManager<?> timeServiceManager;
+
+    protected final MailboxExecutor mailboxExecutor;
+
+    /** variables specific for withBroadcast functionality. */
+    protected final String[] broadcastStreamNames;
+
+    /**
+     * whether each input is blocked. Inputs with broadcast variables can only process their input
+     * records after broadcast variables are ready. One input is non-blocked if it can consume its
+     * inputs (by caching) when broadcast variables are not ready. Otherwise it has to block the
+     * processing and wait until the broadcast variables are ready to be accessed.
+     */
+    protected final boolean[] isBlocked;
+
+    /** type information of each input. */
+    protected final TypeInformation<?>[] inTypes;
+
+    /** whether all broadcast variables of this operator are ready. */
+    protected boolean broadcastVariablesReady;
+
+    /** index of this subtask. */
+    protected final transient int indexOfSubtask;
+
+    /** number of the inputs of this operator. */
+    protected final int numInputs;
+
+    /** runtimeContext of the rich function in wrapped operator. */
+    BroadcastStreamingRuntimeContext wrappedOperatorRuntimeContext;
+
+    /**
+     * path of the file used to store the cached records. It could be local file system or remote
+     * file system.
+     */
+    private Path basePath;
+
+    /** file system. */
+    protected FileSystem fileSystem;
+
+    /** DataCacheWriter for each input. */
+    @SuppressWarnings("rawtypes")
+    protected DataCacheWriter[] dataCacheWriters;
+
+    /** segment list for each input. */
+    protected List<Segment>[] segmentLists;
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    AbstractBroadcastWrapperOperator(
+            StreamOperatorParameters<T> parameters,
+            StreamOperatorFactory<T> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation<?>[] inTypes,
+            boolean[] isBlocked) {
+        this.parameters = Objects.requireNonNull(parameters);
+        this.streamConfig = Objects.requireNonNull(parameters.getStreamConfig());
+        this.containingTask = Objects.requireNonNull(parameters.getContainingTask());
+        this.output = Objects.requireNonNull(parameters.getOutput());
+        this.operatorFactory = Objects.requireNonNull(operatorFactory);
+        this.metrics = createOperatorMetricGroup(containingTask.getEnvironment(), streamConfig);
+        this.wrappedOperator =
+                (S)
+                        StreamOperatorFactoryUtil.<T, S>createOperator(
+                                        operatorFactory,
+                                        (StreamTask) containingTask,
+                                        streamConfig,
+                                        output,
+                                        parameters.getOperatorEventDispatcher())
+                                .f0;
+
+        boolean hasRichFunction =
+                wrappedOperator instanceof AbstractUdfStreamOperator
+                        && ((AbstractUdfStreamOperator) wrappedOperator).getUserFunction()
+                                instanceof RichFunction;
+
+        if (hasRichFunction) {
+            wrappedOperatorRuntimeContext =
+                    new BroadcastStreamingRuntimeContext(
+                            containingTask.getEnvironment(),
+                            containingTask.getEnvironment().getAccumulatorRegistry().getUserMap(),
+                            wrappedOperator.getMetricGroup(),
+                            wrappedOperator.getOperatorID(),
+                            ((AbstractUdfStreamOperator) wrappedOperator)
+                                    .getProcessingTimeService(),
+                            null,
+                            containingTask.getEnvironment().getExternalResourceInfoProvider());
+
+            ((RichFunction) ((AbstractUdfStreamOperator) wrappedOperator).getUserFunction())
+                    .setRuntimeContext(wrappedOperatorRuntimeContext);
+        } else {
+            throw new RuntimeException(
+                    "The operator is not a instance of "
+                            + AbstractUdfStreamOperator.class.getSimpleName()
+                            + " that contains a "
+                            + RichFunction.class.getSimpleName());
+        }
+
+        this.mailboxExecutor =
+                containingTask.getMailboxExecutorFactory().createExecutor(TaskMailbox.MIN_PRIORITY);
+        // variables specific for withBroadcast functionality.
+        this.broadcastStreamNames = broadcastStreamNames;
+        this.isBlocked = isBlocked;
+        this.inTypes = inTypes;
+        this.broadcastVariablesReady = false;
+        this.indexOfSubtask = containingTask.getIndexInSubtaskGroup();
+        this.numInputs = inTypes.length;
+
+        // puts in mailboxExecutor
+        for (String name : broadcastStreamNames) {
+            BroadcastContext.putMailBoxExecutor(name + "-" + indexOfSubtask, mailboxExecutor);
+        }
+
+        basePath =
+                OperatorUtils.getDataCachePath(
+                        containingTask.getEnvironment().getTaskManagerInfo().getConfiguration(),
+                        containingTask
+                                .getEnvironment()
+                                .getIOManager()
+                                .getSpillingDirectoriesPaths());
+        try {
+            fileSystem = basePath.getFileSystem();
+            dataCacheWriters = new DataCacheWriter[numInputs];
+            for (int i = 0; i < numInputs; i++) {
+                dataCacheWriters[i] =
+                        new DataCacheWriter(
+                                new CacheElementTypeInfo<>(inTypes[i])
+                                        .createSerializer(containingTask.getExecutionConfig()),
+                                fileSystem,
+                                () ->

Review comment:
       Also see `OperatorUtils#createDataCacheFileGenerator`

##########
File path: flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
##########
@@ -43,6 +43,10 @@
     private final List<Segment> finishSegments;
 
     private SegmentWriter currentSegment;
+    /**

Review comment:
       Add empty line.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastStreamingRuntimeContext.java
##########
@@ -0,0 +1,68 @@
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A subclass of {@link StreamingRuntimeContext} that provides access to broadcast
+ * variables.
+ */
+public class BroadcastStreamingRuntimeContext extends StreamingRuntimeContext {
+
+    Map<String, List<?>> broadcastVariables = new HashMap<>();
+
+    public BroadcastStreamingRuntimeContext(
+            Environment env,
+            Map<String, Accumulator<?, ?>> accumulators,
+            OperatorMetricGroup operatorMetricGroup,
+            OperatorID operatorID,
+            ProcessingTimeService processingTimeService,
+            @Nullable KeyedStateStore keyedStateStore,
+            ExternalResourceInfoProvider externalResourceInfoProvider) {
+        super(
+                env,
+                accumulators,
+                operatorMetricGroup,
+                operatorID,
+                processingTimeService,
+                keyedStateStore,
+                externalResourceInfoProvider);
+    }
+
+    @Override
+    public boolean hasBroadcastVariable(String name) {
+        return broadcastVariables.containsKey(name);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public <RT> List<RT> getBroadcastVariable(String name) {
+        return (List<RT>) broadcastVariables.get(name);

Review comment:
       Do we need to consider the case where users call `getBroadcastVariable` before all the elements of the broadcast variables have been received?

##########
File path: flink-ml-lib/src/test/resources/log4j2-test.properties
##########
@@ -0,0 +1,28 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Set root logger level to OFF to not flood build logs
+# set manually to INFO for debugging purposes
+rootLogger.level = INFO

Review comment:
       Should be `OFF`.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r744116862



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastStreamingRuntimeContext.java
##########
@@ -0,0 +1,68 @@
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
+import org.apache.flink.api.common.state.KeyedStateStore;
+import org.apache.flink.metrics.groups.OperatorMetricGroup;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
+
+import javax.annotation.Nullable;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A subclass of {@link StreamingRuntimeContext} that provides access to broadcast
+ * variables.
+ */
+public class BroadcastStreamingRuntimeContext extends StreamingRuntimeContext {
+
+    Map<String, List<?>> broadcastVariables = new HashMap<>();
+
+    public BroadcastStreamingRuntimeContext(
+            Environment env,
+            Map<String, Accumulator<?, ?>> accumulators,
+            OperatorMetricGroup operatorMetricGroup,
+            OperatorID operatorID,
+            ProcessingTimeService processingTimeService,
+            @Nullable KeyedStateStore keyedStateStore,
+            ExternalResourceInfoProvider externalResourceInfoProvider) {
+        super(
+                env,
+                accumulators,
+                operatorMetricGroup,
+                operatorID,
+                processingTimeService,
+                keyedStateStore,
+                externalResourceInfoProvider);
+    }
+
+    @Override
+    public boolean hasBroadcastVariable(String name) {
+        return broadcastVariables.containsKey(name);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public <RT> List<RT> getBroadcastVariable(String name) {
+        return (List<RT>) broadcastVariables.get(name);

Review comment:
       Hi Yun, good question.
   
   I would like to throw an exception if users call `getBroadcastVariable` before all the elements of the broadcast variables have been received, i.e., before `processElement` or `processWatermark` has been called. I believe this is enough for developing ML algorithms for now.
   
   By the way, we can optimize this later if we get a real use case for it. What do you think?
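
   A standalone mock (not the PR's `BroadcastStreamingRuntimeContext`, just a sketch of the fail-fast behavior described above):

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class FailFastBroadcastSketch {
    private final Map<String, List<?>> broadcastVariables = new HashMap<>();

    @SuppressWarnings("unchecked")
    public <RT> List<RT> getBroadcastVariable(String name) {
        List<?> value = broadcastVariables.get(name);
        if (value == null) {
            // Fail fast instead of handing out null before the variable is fully received.
            throw new IllegalStateException(
                    "Broadcast variable [" + name + "] is not ready yet; "
                            + "it can only be read after all of its elements have arrived.");
        }
        return (List<RT>) value;
    }

    public static void main(String[] args) {
        FailFastBroadcastSketch context = new FailFastBroadcastSketch();
        try {
            context.getBroadcastVariable("model");
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```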







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r744116336



##########
File path: flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
##########
@@ -43,6 +43,10 @@
     private final List<Segment> finishSegments;
 
     private SegmentWriter currentSegment;
+    /**
+     * The segments that are newly added that has not been retrieved by getNewlyFinishedSegments().
+     */
+    private final List<Segment> newlyFinishedSegments;

Review comment:
       Hi Yun, thanks for the review. 
   
   I have removed `newlyFinishedSegments` by including the pending elements in the `finishSegments` of the new writers. Please check out `DataCacheWriter`, line 57.







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r727635514



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/CacheStreamOperator.java
##########
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractInput;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/** The operator that process all broadcast inputs and stores them in {@link BroadcastContext}. */
+public class CacheStreamOperator<OUT> extends AbstractStreamOperatorV2<OUT>

Review comment:
       Would it be better to rename `CacheStreamOperator` to something like `BroadcastStreamOperator`? I find it a little hard to associate the current name with its functionality.







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r726999611



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    private List<IN> cache;
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+        this.cache = new ArrayList<>();
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {

Review comment:
       Should this be `if(!isBlocking[0])`? From my understanding, putting a record into the cache and continuing should be a non-blocking action.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340287



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);

Review comment:
       Hi @gaoyunhaii, I have made the changes.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input are blocking and
+     * cached to avoid the possible deadlocks.
+     *
+     * @param inputList the non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream.
+     * @param <OUT> type of the output data stream.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> userDefinedFunction) {
+        Preconditions.checkState(inputList.size() > 0);
+        StreamExecutionEnvironment env = inputList.get(0).getExecutionEnvironment();
+        final String[] broadcastStreamNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<OUT> resultStream =
+                buildGraph(env, inputList, broadcastStreamNames, userDefinedFunction);
+
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = getCoLocationKey(broadcastStreamNames);
+        DataStream<OUT> cachedBroadcastInputs = cacheBroadcastVariables(env, bcStreams, outType);
+
+        for (int i = 0; i < inputList.size(); i++) {
+            inputList.get(i).getTransformation().setCoLocationGroupKey(coLocationKey);

Review comment:
       Hi Yun, thanks for the feedback.
   
   I have updated the code based on my understanding. Please check `BroadcastUtils#line93` for details.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {

Review comment:
       Thanks Yun, I have made the change.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);
+        this.broadcastStreamNames = broadcastStreamNames;

Review comment:
       Thanks Yun. But can we use `checkNotNull` here instead?
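
   For example, something along these lines (only a sketch of the constructor I have in mind; `Preconditions` is `org.apache.flink.util.Preconditions`):

    public BroadcastWrapper(
            String[] broadcastStreamNames, TypeInformation[] inTypes, boolean[] isBlocking) {
        Preconditions.checkState(inTypes.length == isBlocking.length);
        // fail early with a clear message if any of the arguments is null
        this.broadcastStreamNames = Preconditions.checkNotNull(broadcastStreamNames);
        this.inTypes = Preconditions.checkNotNull(inTypes);
        this.isBlocking = Preconditions.checkNotNull(isBlocking);
    }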

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/CacheStreamOperator.java
##########
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractInput;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/** The operator that process all broadcast inputs and stores them in {@link BroadcastContext}. */
+public class CacheStreamOperator<OUT> extends AbstractStreamOperatorV2<OUT>
+        implements MultipleInputStreamOperator<OUT>, BoundedMultiInput, Serializable {
+    /** names of the broadcast DataStreams. */
+    private final String[] broadcastNames;
+    /** input list of the multi-input operator. */
+    private final List<Input> inputList;
+    /** output types of input DataStreams. */
+    private final TypeInformation<?>[] inTypes;
+    /** caches of the broadcast inputs. */
+    private final List<?>[] caches;
+    /** state storage of the broadcast inputs. */
+    private ListState<?>[] cacheStates;
+    /** cacheReady state storage of the broadcast inputs. */
+    private ListState<Boolean>[] cacheReadyStates;
+
+    public CacheStreamOperator(
+            StreamOperatorParameters<OUT> parameters,
+            String[] broadcastNames,
+            TypeInformation<?>[] inTypes) {
+        super(parameters, broadcastNames.length);
+        this.broadcastNames = broadcastNames;
+        this.inTypes = inTypes;
+        this.caches = new List[inTypes.length];
+        for (int i = 0; i < inTypes.length; i++) {
+            caches[i] = new ArrayList<>();
+        }
+        this.cacheStates = new ListState[inTypes.length];
+        this.cacheReadyStates = new ListState[inTypes.length];
+
+        inputList = new ArrayList<>();
+        for (int i = 0; i < inTypes.length; i++) {
+            inputList.add(new ProxyInput(this, i + 1));
+        }
+    }
+
+    @Override
+    public List<Input> getInputs() {
+        return inputList;
+    }
+
+    @Override
+    public void endInput(int i) {
+        BroadcastContext.markCacheFinished(
+                Tuple2.of(broadcastNames[i - 1], getRuntimeContext().getIndexOfThisSubtask()));
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+        for (int i = 0; i < inTypes.length; i++) {
+            cacheStates[i].clear();
+            cacheStates[i].addAll((List) caches[i]);
+            cacheReadyStates[i].clear();
+            boolean isCacheFinished =
+                    BroadcastContext.isCacheFinished(
+                            Tuple2.of(
+                                    broadcastNames[i],
+                                    getRuntimeContext().getIndexOfThisSubtask()));
+            cacheReadyStates[i].add(isCacheFinished);
+        }
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws Exception {
+        super.initializeState(context);
+        for (int i = 0; i < inTypes.length; i++) {
+            cacheStates[i] =

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);

Review comment:
       Thanks Yun. I have made the changes.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {
+            if (areBroadcastVariablesReady()) {
+                dataCacheWriters[0].finishCurrentSegmentAndStartNewSegment();
+                segmentLists[0].addAll(dataCacheWriters[0].getNewlyFinishedSegments());
+                if (segmentLists[0].size() != 0) {
+                    DataCacheReader dataCacheReader =
+                            new DataCacheReader<>(
+                                    inTypes[0].createSerializer(
+                                            containingTask.getExecutionConfig()),
+                                    fileSystem,
+                                    segmentLists[0]);
+                    while (dataCacheReader.hasNext()) {
+                        wrappedOperator.processElement(new StreamRecord(dataCacheReader.next()));
+                    }
+                }
+                segmentLists[0].clear();
+                wrappedOperator.processElement(streamRecord);
+
+            } else {
+                dataCacheWriters[0].addRecord(streamRecord.getValue());
+            }
+
+        } else {
+            while (!areBroadcastVariablesReady()) {
+                mailboxExecutor.yield();
+            }
+            wrappedOperator.processElement(streamRecord);
+        }
+    }

Review comment:
       Thanks Yun. I have finished the refactoring; please refer to `AbstractBroadcastWrapperOperator#processElementX()` and `AbstractBroadcastWrapperOperator#endInputX()`.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340268



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+

Review comment:
       Hi @gaoyunhaii, I am a bit confused about why we need to check that it *only wraps one operator*.
   
   The signature of the graph builder is `Function<List<DataStream<?>>, DataStream<OUT>>`, which already assumes that only one operator is produced.
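
   To make that assumption concrete, this is roughly how I expect the API to be used in this revision (a hedged sketch: the input values, the map function and the `"offset"` name are made up; the broadcast variable is read via `BroadcastContext.getBroadcastVariable(...)` as described in the javadoc of this revision):

    import org.apache.flink.api.common.functions.RichMapFunction;
    import org.apache.flink.ml.common.broadcast.BroadcastContext;
    import org.apache.flink.ml.common.broadcast.BroadcastUtils;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

    import java.util.Collections;
    import java.util.List;

    public class WithBroadcastUsageSketch {

        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            DataStream<Integer> input = env.fromElements(1, 2, 3);
            DataStream<Integer> offsets = env.fromElements(10);

            DataStream<Integer> result =
                    BroadcastUtils.withBroadcastStream(
                            Collections.singletonList(input),
                            Collections.singletonMap("offset", offsets),
                            inputs -> {
                                // the user-defined function adds exactly one operator (the map)
                                @SuppressWarnings("unchecked")
                                DataStream<Integer> in = (DataStream<Integer>) inputs.get(0);
                                return in.map(
                                        new RichMapFunction<Integer, Integer>() {
                                            @Override
                                            public Integer map(Integer value) {
                                                List<Integer> offset =
                                                        BroadcastContext.getBroadcastVariable(
                                                                "offset");
                                                return value + offset.get(0);
                                            }
                                        });
                            });

            result.print();
            env.execute();
        }
    }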

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {

Review comment:
       Thanks @gaoyunhaii. I have removed this method and simply use `"broadcast-co-location-" + UUID.randomUUID()` as the co-location key.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r728936588



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class BroadcastContext {
+    /**
+     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
+     * is (isBroadcastVariableReady, cacheList).
+     */
+    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
+            new HashMap<>();
+    /**
+     * We use lock because we want to enable `getBroadcastVariable(String)` in a TM with multiple
+     * slots here. Note that using ConcurrentHashMap is not enough since we need "contains and get
+     * in an atomic operation".
+     */
+    private static ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
+
+    public static void putBroadcastVariable(
+            Tuple2<String, Integer> key, Tuple2<Boolean, List<?>> variable) {
+        lock.writeLock().lock();
+        try {
+            broadcastVariables.put(key, variable);
+        } finally {
+            lock.writeLock().unlock();
+        }
+    }
+
+    /**
+     * get the cached list with the given key.
+     *
+     * @param key
+     * @param <T>
+     * @return the cache list.
+     */
+    public static <T> List<T> getBroadcastVariable(Tuple2<String, Integer> key) {
+        lock.readLock().lock();
+        List<?> result = null;
+        try {
+            result = broadcastVariables.get(key).f1;
+        } finally {
+            lock.readLock().unlock();
+        }
+        return (List<T>) result;
+    }
+
+    /**
+     * get broadcast variables by name
+     *
+     * @param name
+     * @param <T>
+     * @return
+     */
+    public static <T> List<T> getBroadcastVariable(String name) {
+        lock.readLock().lock();
+        List<?> result = null;
+        try {
+            for (Tuple2<String, Integer> nameAndPartitionId : broadcastVariables.keySet()) {
+                if (name.equals(nameAndPartitionId.f0) && isCacheFinished(nameAndPartitionId)) {
+                    result = broadcastVariables.get(nameAndPartitionId).f1;
+                    break;

Review comment:
       Hi Yunfeng, that is a great observation. I agree that we can optimize this later and let the slots in one TM share one copy of the broadcast variables.
   
   For now, partitionId is employed because we need to clean up the cached variables. If we do not know the partitionId, we cannot decide when to trigger the cleanup for a given variable. (Please check out `CacheStreamOperator#endInput()`.)
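
   A minimal sketch of what I mean, assuming only the `markCacheFinished`/`isCacheFinished` methods quoted elsewhere in this PR (the class and method names below are placeholders); it only illustrates that the readiness, and hence the cleanup decision, is tracked per `(name, subtaskIndex)`:

    import org.apache.flink.api.java.tuple.Tuple2;
    import org.apache.flink.ml.common.broadcast.BroadcastContext;

    /** Sketch only: shows why the cache key needs the subtask index. */
    public class PerPartitionLifecycleSketch {

        /** Marks the broadcast input of this subtask as fully consumed, cf. CacheStreamOperator#endInput(). */
        public static void onBroadcastInputFinished(String name, int subtaskIndex) {
            BroadcastContext.markCacheFinished(Tuple2.of(name, subtaskIndex));
        }

        /** Readiness (and thus the point where this subtask's copy may be cleaned up) is decided per subtask. */
        public static boolean isReady(String name, int subtaskIndex) {
            return BroadcastContext.isCacheFinished(Tuple2.of(name, subtaskIndex));
        }
    }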







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r732690085



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -22,121 +22,168 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
 import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
 import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
 
+/** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
+    /**
+     * supports withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of an operator that implements ${@link HasBroadcastVariable}. An operator
+     * that wants to access broadcast variables must implement ${@link HasBroadcastVariable}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and further set by
+     * {@code HasBroadcastVariable.setBroadcastVariable(...)}. For now the non-broadcast input are
+     * cached by default to avoid the possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream. Note that users can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> userDefinedFunction) {
+        Preconditions.checkArgument(inputList.size() > 0);
+
+        StreamExecutionEnvironment env = inputList.get(0).getExecutionEnvironment();
+        String[] broadcastNames = new String[bcStreams.size()];
+        DataStream<?>[] broadcastInputs = new DataStream[bcStreams.size()];
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[bcStreams.size()];
+        int idx = 0;
+        final String broadcastId = new AbstractID().toHexString();
+        for (String name : bcStreams.keySet()) {
+            broadcastNames[idx] = broadcastId + "-" + name;
+            broadcastInputs[idx] = bcStreams.get(name);
+            broadcastInTypes[idx] = broadcastInputs[idx].getType();
+            idx++;
+        }
 
+        DataStream<OUT> resultStream =
+                getResultStream(env, inputList, broadcastNames, userDefinedFunction);
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = "broadcast-co-location-" + UUID.randomUUID();
+        DataStream<OUT> cachedBroadcastInputs =
+                cacheBroadcastVariables(
+                        env,
+                        broadcastNames,
+                        broadcastInputs,
+                        broadcastInTypes,
+                        resultStream.getParallelism(),
+                        outType);
+
+        boolean canCoLocate =
+                cachedBroadcastInputs.getTransformation() instanceof PhysicalTransformation
+                        && resultStream.getTransformation() instanceof PhysicalTransformation;
+        if (canCoLocate) {
+            ((PhysicalTransformation) cachedBroadcastInputs.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+            ((PhysicalTransformation) resultStream.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+        } else {
+            throw new UnsupportedOperationException(
+                    "cannot set chaining strategy on "
+                            + cachedBroadcastInputs.getTransformation()
+                            + " and "
+                            + resultStream.getTransformation()
+                            + ".");
+        }
+        cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
+        resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
+
+        return cachedBroadcastInputs.union(resultStream);
+    }
+
+    /**
+     * caches all broadcast input data streams in static variables and returns the result multi-input
+     * stream operator. The result multi-input stream operator emits nothing and the only
+     * functionality of this operator is to cache all the input records in ${@link
+     * BroadcastContext}.
+     *
+     * @param env execution environment.
+     * @param broadcastInputNames names of the broadcast input data streams.
+     * @param broadcastInputs list of the broadcast data streams.
+     * @param broadcastInTypes output types of the broadcast input data streams.
+     * @param parallelism parallelism.
+     * @param outType output type.
+     * @param <OUT>
+     * @return the result multi-input stream operator.
+     */
     private static <OUT> DataStream<OUT> cacheBroadcastVariables(
             StreamExecutionEnvironment env,
-            Map<String, DataStream<?>> bcStreams,
+            String[] broadcastInputNames,
+            DataStream<?>[] broadcastInputs,
+            TypeInformation<?>[] broadcastInTypes,
+            int parallelism,
             TypeInformation<OUT> outType) {
-        int numBroadcastInput = bcStreams.size();
-        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
-        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
-        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
-        for (int i = 0; i < numBroadcastInput; i++) {
-            broadcastInTypes[i] = broadcastInputs[i].getType();
-        }
-
         MultipleInputTransformation<OUT> transformation =
                 new MultipleInputTransformation<OUT>(
                         "broadcastInputs",
                         new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
                         outType,
-                        env.getParallelism());
-        for (DataStream<?> dataStream : bcStreams.values()) {
+                        parallelism);
+        for (DataStream<?> dataStream : broadcastInputs) {
             transformation.addInput(dataStream.broadcast().getTransformation());
         }
         env.addOperator(transformation);
         return new MultipleConnectedStreams(env).transform(transformation);
     }
 
-    private static String getCoLocationKey(String[] broadcastNames) {
-        StringBuilder sb = new StringBuilder();
-        sb.append("Flink-ML-broadcast-co-location");
-        for (String name : broadcastNames) {
-            sb.append(name);
-        }
-        return sb.toString();
-    }
-
-    private static <OUT> DataStream<OUT> buildGraph(
+    /**
+     * uses {@link DraftExecutionEnvironment} to execute the userDefinedFunction and returns the
+     * resultStream.
+     *
+     * @param env execution environment.
+     * @param inputList non-broadcast input list.
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param graphBuilder user-defined logic.
+     * @param <OUT> output type of the result stream.
+     * @return the result stream by applying user-defined logic on the input list.
+     */
+    private static <OUT> DataStream<OUT> getResultStream(
             StreamExecutionEnvironment env,
             List<DataStream<?>> inputList,
             String[] broadcastStreamNames,
             Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
         TypeInformation[] inTypes = new TypeInformation[inputList.size()];
         for (int i = 0; i < inputList.size(); i++) {
-            TypeInformation type = inputList.get(i).getType();
-            inTypes[i] = type;
+            inTypes[i] = inputList.get(i).getType();
         }
-        // blocking all non-broadcast input edges by default.
-        boolean[] isBlocking = new boolean[inTypes.length];
-        Arrays.fill(isBlocking, true);
+        // do not block all non-broadcast input edges by default.
+        boolean[] isBlocked = new boolean[inputList.size()];
         DraftExecutionEnvironment draftEnv =
                 new DraftExecutionEnvironment(
-                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocked));
 
         List<DataStream<?>> draftSources = new ArrayList<>();
         for (int i = 0; i < inputList.size(); i++) {
             draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
         }
         DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
-
+        Preconditions.checkState(
+                draftEnv.getStreamGraph(false).getStreamNodes().size() == 1 + inputList.size(),

Review comment:
       There is a parameter that decides whether to clear the transformations. Please check out `StreamExecutionEnvironment#getStreamGraph(boolean clearTransformations)`.
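
    For reference, the flag's effect in a short sketch (assuming the `StreamExecutionEnvironment#getStreamGraph(boolean clearTransformations)` overload referenced here, the `draftEnv`/`inputList` variables from the diff above, and the imports of that file plus `org.apache.flink.streaming.api.graph.StreamGraph`):

    // Build the stream graph only for inspection and keep the recorded transformations,
    // so that the draft environment can still be copied to the actual environment afterwards.
    StreamGraph draftGraph = draftEnv.getStreamGraph(false);
    int numNodes = draftGraph.getStreamNodes().size();

    // The user-defined function may add exactly one operator on top of the per-input sources.
    Preconditions.checkState(
            numNodes == 1 + inputList.size(),
            "The user-defined function should add exactly one operator.");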




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r732760068



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -22,121 +22,168 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
 import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
 import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
+import org.apache.flink.util.AbstractID;
 import org.apache.flink.util.Preconditions;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import java.util.function.Function;
 
+/** Utility class to support withBroadcast in DataStream. */
 public class BroadcastUtils {
+    /**
+     * supports withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of an operator that implements ${@link HasBroadcastVariable}. An operator
+     * that wants to access broadcast variables must implement ${@link HasBroadcastVariable}.
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and further set by
+     * {@code HasBroadcastVariable.setBroadcastVariable(...)}. For now the non-broadcast input are
+     * cached by default to avoid the possible deadlocks.
+     *
+     * @param inputList non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream. Note that users can add only one
+     *     operator in this function, otherwise it raises an exception.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> userDefinedFunction) {
+        Preconditions.checkArgument(inputList.size() > 0);
+
+        StreamExecutionEnvironment env = inputList.get(0).getExecutionEnvironment();
+        String[] broadcastNames = new String[bcStreams.size()];
+        DataStream<?>[] broadcastInputs = new DataStream[bcStreams.size()];
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[bcStreams.size()];
+        int idx = 0;
+        final String broadcastId = new AbstractID().toHexString();
+        for (String name : bcStreams.keySet()) {
+            broadcastNames[idx] = broadcastId + "-" + name;
+            broadcastInputs[idx] = bcStreams.get(name);
+            broadcastInTypes[idx] = broadcastInputs[idx].getType();
+            idx++;
+        }
 
+        DataStream<OUT> resultStream =
+                getResultStream(env, inputList, broadcastNames, userDefinedFunction);
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = "broadcast-co-location-" + UUID.randomUUID();
+        DataStream<OUT> cachedBroadcastInputs =
+                cacheBroadcastVariables(
+                        env,
+                        broadcastNames,
+                        broadcastInputs,
+                        broadcastInTypes,
+                        resultStream.getParallelism(),
+                        outType);
+
+        boolean canCoLocate =
+                cachedBroadcastInputs.getTransformation() instanceof PhysicalTransformation
+                        && resultStream.getTransformation() instanceof PhysicalTransformation;
+        if (canCoLocate) {
+            ((PhysicalTransformation) cachedBroadcastInputs.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+            ((PhysicalTransformation) resultStream.getTransformation())
+                    .setChainingStrategy(ChainingStrategy.HEAD);
+        } else {
+            throw new UnsupportedOperationException(
+                    "cannot set chaining strategy on "
+                            + cachedBroadcastInputs.getTransformation()
+                            + " and "
+                            + resultStream.getTransformation()
+                            + ".");
+        }
+        cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
+        resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
+
+        return cachedBroadcastInputs.union(resultStream);
+    }
+
+    /**
+     * caches all broadcast input data streams in static variables and returns the result multi-input
+     * stream operator. The result multi-input stream operator emits nothing and the only
+     * functionality of this operator is to cache all the input records in ${@link
+     * BroadcastContext}.
+     *
+     * @param env execution environment.
+     * @param broadcastInputNames names of the broadcast input data streams.
+     * @param broadcastInputs list of the broadcast data streams.
+     * @param broadcastInTypes output types of the broadcast input data streams.
+     * @param parallelism parallelism.
+     * @param outType output type.
+     * @param <OUT>
+     * @return the result multi-input stream operator.
+     */
     private static <OUT> DataStream<OUT> cacheBroadcastVariables(
             StreamExecutionEnvironment env,
-            Map<String, DataStream<?>> bcStreams,
+            String[] broadcastInputNames,
+            DataStream<?>[] broadcastInputs,
+            TypeInformation<?>[] broadcastInTypes,
+            int parallelism,
             TypeInformation<OUT> outType) {
-        int numBroadcastInput = bcStreams.size();
-        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
-        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
-        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
-        for (int i = 0; i < numBroadcastInput; i++) {
-            broadcastInTypes[i] = broadcastInputs[i].getType();
-        }
-
         MultipleInputTransformation<OUT> transformation =
                 new MultipleInputTransformation<OUT>(
                         "broadcastInputs",
                         new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
                         outType,
-                        env.getParallelism());
-        for (DataStream<?> dataStream : bcStreams.values()) {
+                        parallelism);
+        for (DataStream<?> dataStream : broadcastInputs) {
             transformation.addInput(dataStream.broadcast().getTransformation());
         }
         env.addOperator(transformation);
         return new MultipleConnectedStreams(env).transform(transformation);
     }
 
-    private static String getCoLocationKey(String[] broadcastNames) {
-        StringBuilder sb = new StringBuilder();
-        sb.append("Flink-ML-broadcast-co-location");
-        for (String name : broadcastNames) {
-            sb.append(name);
-        }
-        return sb.toString();
-    }
-
-    private static <OUT> DataStream<OUT> buildGraph(
+    /**
+     * uses {@link DraftExecutionEnvironment} to execute the userDefinedFunction and returns the
+     * resultStream.
+     *
+     * @param env execution environment.
+     * @param inputList non-broadcast input list.
+     * @param broadcastStreamNames names of the broadcast data streams.
+     * @param graphBuilder user-defined logic.
+     * @param <OUT> output type of the result stream.
+     * @return the result stream by applying user-defined logic on the input list.
+     */
+    private static <OUT> DataStream<OUT> getResultStream(
             StreamExecutionEnvironment env,
             List<DataStream<?>> inputList,
             String[] broadcastStreamNames,
             Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
         TypeInformation[] inTypes = new TypeInformation[inputList.size()];
         for (int i = 0; i < inputList.size(); i++) {
-            TypeInformation type = inputList.get(i).getType();
-            inTypes[i] = type;
+            inTypes[i] = inputList.get(i).getType();
         }
-        // blocking all non-broadcast input edges by default.
-        boolean[] isBlocking = new boolean[inTypes.length];
-        Arrays.fill(isBlocking, true);
+        // do not block all non-broadcast input edges by default.
+        boolean[] isBlocked = new boolean[inputList.size()];

Review comment:
       For better understanding, I have added an explicit initialization.
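
    Presumably the explicit initialization looks roughly like this (Java already zero-initializes boolean arrays, so the `Arrays.fill` only documents the intent):

    // do not block any non-broadcast input edge by default
    boolean[] isBlocked = new boolean[inputList.size()];
    Arrays.fill(isBlocked, false);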




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340242



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input are blocking and

Review comment:
       Thanks. I have made the change.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();

Review comment:
       Thanks. I have made the change.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730339576



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class BroadcastContext {
+    /**
+     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
+     * is (isBroadcastVariableReady, cacheList).
+     */
+    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
+            new HashMap<>();

Review comment:
       Hi Yun, thanks for the feedback. Yes, users have to make sure that the keys are unique.
   
   I have updated the implementation using id + name + subtaskIndex as the key.
   Please refer to `BroadcastUtils#line71` and `HasBroadcastVariable` for details.
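
    A short sketch of the key scheme described here (the exact concatenation in the PR may differ; only the three components come from this comment and the diff):

    import org.apache.flink.util.AbstractID;

    class BroadcastKeySketch {
        // broadcastId: generated once per withBroadcastStream(...) call,
        // name: the user-facing broadcast variable name,
        // subtaskIndex: distinguishes the parallel instances of the consuming operator.
        static String cacheKey(String broadcastId, String name, int subtaskIndex) {
            return broadcastId + "-" + name + "-" + subtaskIndex;
        }

        public static void main(String[] args) {
            String broadcastId = new AbstractID().toHexString();
            System.out.println(cacheKey(broadcastId, "factors", 3)); // "factors" is a made-up name
        }
    }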




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r727632118



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    private List<IN> cache;
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+        this.cache = new ArrayList<>();
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {
+            if (areBroadcastVariablesReady()) {
+                for (IN ele : cache) {
+                    wrappedOperator.processElement(new StreamRecord<>(ele));
+                }
+                cache.clear();
+                wrappedOperator.processElement(streamRecord);
+
+            } else {
+                cache.add(streamRecord.getValue());

Review comment:
       I can see that this PR uses this caching list to avoid filling up Flink's buffers, and that the list is only stored in memory. I am worried that if the size of the cached records grows beyond the available memory, this solution might cause an OutOfMemoryError and make the Flink job fail.
    
    Shall we add some mechanism like the following to avoid this problem? (A rough sketch follows the list.)
    
    - store part of the cached records on disk to avoid excessive memory usage.
    - check the size of the cached records kept in memory and handle possible exceptions.
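
    A rough sketch of the first suggestion, kept deliberately generic (this is not the PR's code; the count-based threshold, the file handling and the class name are made up for illustration). The `DataCacheWriter`/`DataCacheReader` imports in the diff above suggest the operator uses flink-ml's data cache for a similar purpose.

    import java.io.IOException;
    import java.io.ObjectOutputStream;
    import java.io.Serializable;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.util.ArrayList;
    import java.util.List;

    class SpillingCache<T extends Serializable> {
        private final int maxRecordsInMemory;          // simple count-based threshold
        private final List<T> inMemory = new ArrayList<>();
        private final Path spillFile;
        private ObjectOutputStream spillOut;

        SpillingCache(int maxRecordsInMemory) throws IOException {
            this.maxRecordsInMemory = maxRecordsInMemory;
            this.spillFile = Files.createTempFile("broadcast-cache-", ".bin");
        }

        void add(T record) throws IOException {
            if (inMemory.size() < maxRecordsInMemory) {
                inMemory.add(record);                  // fast path: keep the record in memory
            } else {
                if (spillOut == null) {
                    spillOut = new ObjectOutputStream(Files.newOutputStream(spillFile));
                }
                spillOut.writeObject(record);          // spill everything beyond the threshold
            }
        }

        void close() throws IOException {
            if (spillOut != null) {
                spillOut.close();
            }
            Files.deleteIfExists(spillFile);
        }
    }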




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r728930466



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    private List<IN> cache;
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+        this.cache = new ArrayList<>();
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {

Review comment:
       Hi Yunfeng, thanks for the review.
    
    Here isBlocking indicates whether we can call processElement() immediately. When an input is blocking, we have to cache its records (to avoid a possible deadlock) and process them later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340287



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);

Review comment:
       Hi @gaoyunhaii, I have made the changes.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#issuecomment-945667613


   > Very thanks @zhipeng93 for opening the PR! I have left some inline comments.
   > 
   > Besides, it seems we still do not provide an user interface for accessing the broadcast variable inside operator or UDF ?
   
   Hi Yun, I have updated the code and added the `HasBroadcastVariable` interface. Please check out `org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable` for details.
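
    A hedged usage sketch of the two pieces referred to here. The `withBroadcastStream(...)` signature is taken from the diff above; the `setBroadcastVariable(String, List<?>)` signature, the class names and the stream contents are assumptions made only for illustration.

    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.ml.common.broadcast.BroadcastUtils;
    import org.apache.flink.ml.common.broadcast.operator.HasBroadcastVariable;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
    import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
    import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;

    import java.util.Collections;
    import java.util.List;

    // An operator that wants to read a broadcast variable implements HasBroadcastVariable.
    class ScaleOperator extends AbstractStreamOperator<Integer>
            implements OneInputStreamOperator<Integer, Integer>, HasBroadcastVariable {

        private transient List<?> factors;

        @Override
        public void setBroadcastVariable(String name, List<?> value) { // assumed signature
            // Only one broadcast variable is used in this sketch, so the name is not inspected.
            this.factors = value;
        }

        @Override
        public void processElement(StreamRecord<Integer> element) {
            output.collect(new StreamRecord<>(element.getValue() * (Integer) factors.get(0)));
        }
    }

    public class WithBroadcastExample {
        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            DataStream<Integer> input = env.fromElements(1, 2, 3);
            DataStream<Integer> factors = env.fromElements(10);

            DataStream<Integer> result =
                    BroadcastUtils.withBroadcastStream(
                            Collections.singletonList(input),
                            Collections.singletonMap("factors", factors),
                            inputs -> {
                                @SuppressWarnings("unchecked")
                                DataStream<Integer> in = (DataStream<Integer>) inputs.get(0);
                                // the user-defined function must add exactly one operator
                                return in.transform("scale", Types.INT, new ScaleOperator());
                            });

            result.print();
            env.execute("withBroadcastStream sketch");
        }
    }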


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730339576



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class BroadcastContext {
+    /**
+     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
+     * is (isBroadcastVariableReady, cacheList).
+     */
+    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
+            new HashMap<>();

Review comment:
       Hi Yun, thanks for the feedback. Yes, previously users had to make sure that the keys were unique.
   
   To solve this problem, I have updated the implementation to use id + name + subtaskIndex as the key.
   Please refer to `BroadcastUtils#line71` and `HasBroadcastVariable` for details.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] gaoyunhaii commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
gaoyunhaii commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r732437805



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+

Review comment:
       Users may have a.map(xx).map(xx); in that case we would have two operators.
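
    For illustration, such a user function would create two draft operators and fail the single-operator check, while a fused version adds exactly one (a sketch, assuming the imports of BroadcastUtils above plus `org.apache.flink.api.common.typeinfo.Types`; the `returns(Types.LONG)` hints are only there because Flink cannot always extract lambda result types):

    // Two chained maps -> two operators: rejected by the "exactly one operator" check.
    Function<List<DataStream<?>>, DataStream<Long>> twoOperators =
            inputs -> ((DataStream<Long>) inputs.get(0))
                    .map(x -> x + 1).returns(Types.LONG)
                    .map(x -> x * 2).returns(Types.LONG);

    // The same logic fused into a single map -> one operator: accepted.
    Function<List<DataStream<?>>, DataStream<Long>> oneOperator =
            inputs -> ((DataStream<Long>) inputs.get(0))
                    .map(x -> (x + 1) * 2).returns(Types.LONG);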




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340357



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);

Review comment:
       Thanks Yun. I changed `isBlocking` to `isBlocked` and reversed the logic.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r727632118



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    private List<IN> cache;
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+        this.cache = new ArrayList<>();
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {
+            if (areBroadcastVariablesReady()) {
+                for (IN ele : cache) {
+                    wrappedOperator.processElement(new StreamRecord<>(ele));
+                }
+                cache.clear();
+                wrappedOperator.processElement(streamRecord);
+
+            } else {
+                cache.add(streamRecord.getValue());

Review comment:
       I can see that this PR is trying to use this caching list to avoid filling up Flink's buffers, and that the list is only held in memory. I am worried that if the size of the cached records grows beyond the available memory, this approach could cause an OutOfMemoryError and make the Flink job fail.
   
    Shall we add a mechanism like the following to avoid this problem (a rough sketch follows this list)?
   
    - store part of the cached records on disk to avoid excessive memory usage.
    - check the size of the cached records held in memory and handle possible failures.
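    A minimal, self-contained sketch of the spill-to-disk idea. It only illustrates the suggestion above and is not the PR's code: the class name, the fixed in-memory threshold, and the plain `java.io` object serialization are assumptions made for the example.

    import java.io.EOFException;
    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.io.ObjectInputStream;
    import java.io.ObjectOutputStream;
    import java.io.Serializable;
    import java.util.ArrayList;
    import java.util.List;

    /** Hypothetical cache that keeps at most maxRecordsInMemory records in memory and spills the rest to a temp file. */
    public class SpillableCache<T extends Serializable> {

        private final int maxRecordsInMemory;
        private final List<T> inMemory = new ArrayList<>();
        private File spillFile;
        private ObjectOutputStream spillOut;

        public SpillableCache(int maxRecordsInMemory) {
            this.maxRecordsInMemory = maxRecordsInMemory;
        }

        /** Adds a record, spilling to disk once the in-memory budget is exhausted. */
        public void add(T record) throws IOException {
            if (inMemory.size() < maxRecordsInMemory) {
                inMemory.add(record);
            } else {
                if (spillOut == null) {
                    spillFile = File.createTempFile("broadcast-cache", ".bin");
                    spillOut = new ObjectOutputStream(new FileOutputStream(spillFile));
                }
                spillOut.writeObject(record);
            }
        }

        /** Replays all cached records: the in-memory ones first, then the spilled ones. */
        @SuppressWarnings("unchecked")
        public List<T> readAll() throws IOException, ClassNotFoundException {
            List<T> all = new ArrayList<>(inMemory);
            if (spillOut != null) {
                spillOut.flush();
                try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(spillFile))) {
                    while (true) {
                        try {
                            all.add((T) in.readObject());
                        } catch (EOFException e) {
                            break;
                        }
                    }
                }
            }
            return all;
        }
    }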

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/CacheStreamOperator.java
##########
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractInput;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/** The operator that process all broadcast inputs and stores them in {@link BroadcastContext}. */
+public class CacheStreamOperator<OUT> extends AbstractStreamOperatorV2<OUT>

Review comment:
       Would it be better to rename `CacheStreamOperator` to something like `BroadcastStreamOperator`? I find it a bit hard to associate the current name with its functionality.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class BroadcastContext {
+    /**
+     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
+     * is (isBroaddcastVariableReady, cacheList).
+     */
+    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
+            new HashMap<>();
+    /**
+     * We use lock because we want to enable `getBroadcastVariable(String)` in a TM with multiple
+     * slots here. Note that using ConcurrentHashMap is not enough since we need "contains and get
+     * in an atomic operation".
+     */
+    private static ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
+
+    public static void putBroadcastVariable(
+            Tuple2<String, Integer> key, Tuple2<Boolean, List<?>> variable) {
+        lock.writeLock().lock();
+        try {
+            broadcastVariables.put(key, variable);
+        } finally {
+            lock.writeLock().unlock();
+        }
+    }
+
+    /**
+     * get the cached list with the given key.
+     *
+     * @param key
+     * @param <T>
+     * @return the cache list.
+     */
+    public static <T> List<T> getBroadcastVariable(Tuple2<String, Integer> key) {
+        lock.readLock().lock();
+        List<?> result = null;
+        try {
+            result = broadcastVariables.get(key).f1;
+        } finally {
+            lock.readLock().unlock();
+        }
+        return (List<T>) result;
+    }
+
+    /**
+     * get broadcast variables by name
+     *
+     * @param name
+     * @param <T>
+     * @return
+     */
+    public static <T> List<T> getBroadcastVariable(String name) {
+        lock.readLock().lock();
+        List<?> result = null;
+        try {
+            for (Tuple2<String, Integer> nameAndPartitionId : broadcastVariables.keySet()) {
+                if (name.equals(nameAndPartitionId.f0) && isCacheFinished(nameAndPartitionId)) {
+                    result = broadcastVariables.get(nameAndPartitionId).f1;
+                    break;

Review comment:
       If the cached broadcast variables are the same regardless of partitionId, I personally think that it might be unnecessary to store partitionId as part of the key of `broadcastVariables`.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r728933410



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/CacheStreamOperator.java
##########
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractInput;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/** The operator that process all broadcast inputs and stores them in {@link BroadcastContext}. */
+public class CacheStreamOperator<OUT> extends AbstractStreamOperatorV2<OUT>

Review comment:
       I am OK with the renaming. Let's see what others think.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340195



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    private List<IN> cache;
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+        this.cache = new ArrayList<>();
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {

Review comment:
       Hmm, I have reversed the semantics here. Thanks! @yunfengzhou-hub @gaoyunhaii 







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r732757633



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -18,106 +18,54 @@
 
 package org.apache.flink.ml.common.broadcast;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.java.tuple.Tuple2;
 
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.concurrent.ConcurrentHashMap;
 
 public class BroadcastContext {
     /**
-     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
-     * is (isBroaddcastVariableReady, cacheList).
+     * stores broadcast data streams in a map. The key is broadcastName-partitionId and the value is
+     * (isBroadcastVariableReady, cacheList).
      */
-    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
-            new HashMap<>();
-    /**
-     * We use lock because we want to enable `getBroadcastVariable(String)` in a TM with multiple
-     * slots here. Note that using ConcurrentHashMap is not enough since we need "contains and get
-     * in an atomic operation".
-     */
-    private static ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
+    private static final Map<String, Tuple2<Boolean, List<?>>> BROADCAST_VARIABLES =
+            new ConcurrentHashMap<>();
 
-    public static void putBroadcastVariable(
-            Tuple2<String, Integer> key, Tuple2<Boolean, List<?>> variable) {
-        lock.writeLock().lock();
-        try {
-            broadcastVariables.put(key, variable);
-        } finally {
-            lock.writeLock().unlock();
-        }
+    @VisibleForTesting
+    public static void putBroadcastVariable(String key, Tuple2<Boolean, List<?>> variable) {
+        BROADCAST_VARIABLES.put(key, variable);
     }
 
     /**
-     * get the cached list with the given key.
-     *
-     * @param key
-     * @param <T>
-     * @return the cache list.
-     */
-    public static <T> List<T> getBroadcastVariable(Tuple2<String, Integer> key) {
-        lock.readLock().lock();
-        List<?> result = null;
-        try {
-            result = broadcastVariables.get(key).f1;
-        } finally {
-            lock.readLock().unlock();
-        }
-        return (List<T>) result;
-    }
-
-    /**
-     * get broadcast variables by name
+     * gets broadcast variables by name if this broadcast variable is fully cached.
      *
      * @param name
      * @param <T>
-     * @return
+     * @return the cache broadcast variable. Return null if it is not fully cached.
      */
+    @VisibleForTesting
     public static <T> List<T> getBroadcastVariable(String name) {
-        lock.readLock().lock();
-        List<?> result = null;
-        try {
-            for (Tuple2<String, Integer> nameAndPartitionId : broadcastVariables.keySet()) {
-                if (name.equals(nameAndPartitionId.f0) && isCacheFinished(nameAndPartitionId)) {
-                    result = broadcastVariables.get(nameAndPartitionId).f1;
-                    break;
-                }
-            }
-        } finally {
-            lock.readLock().unlock();
+        Tuple2<Boolean, List<?>> cacheReadyAndList = BROADCAST_VARIABLES.get(name);
+        if (cacheReadyAndList.f0) {
+            return (List<T>) cacheReadyAndList.f1;
         }
-        return (List<T>) result;
+        return null;
     }
 
-    public static void remove(Tuple2<String, Integer> key) {
-        lock.writeLock().lock();
-        try {
-            broadcastVariables.remove(key);
-        } finally {
-            lock.writeLock().unlock();
-        }
+    @VisibleForTesting
+    public static void remove(String key) {
+        BROADCAST_VARIABLES.remove(key);
     }
 
-    public static void markCacheFinished(Tuple2<String, Integer> key) {
-        lock.writeLock().lock();
-        try {
-            broadcastVariables.get(key).f0 = true;
-        } finally {
-            lock.writeLock().unlock();
-        }
+    @VisibleForTesting
+    public static void markCacheFinished(String key) {

Review comment:
       Thanks @gaoyunhaii! I have updated the code and also put the operator's mailboxExecutor in BroadcastContext. Please refer to `BroadcastContext#line50 & line71`.
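    The rough shape of that notification idea, as a hypothetical stand-in: plain `Runnable` callbacks are used below instead of Flink's `MailboxExecutor`, and this is not the actual `BroadcastContext` code referenced above.

    import java.util.List;
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.CopyOnWriteArrayList;

    /**
     * Hypothetical sketch: operators register a callback that fires once a broadcast variable is
     * fully cached. A real implementation would also consult the ready flag so that callbacks
     * registered after markCacheFinished still fire immediately.
     */
    public class ReadyNotifier {

        private static final Map<String, List<Runnable>> PENDING = new ConcurrentHashMap<>();

        /** Registers a callback; in the PR this role is played by a mail submitted to the operator's mailbox executor. */
        public static void onCacheFinished(String name, Runnable callback) {
            PENDING.computeIfAbsent(name, k -> new CopyOnWriteArrayList<>()).add(callback);
        }

        /** Called by the caching operator after the last record of the broadcast input has been received. */
        public static void markCacheFinished(String name) {
            List<Runnable> callbacks = PENDING.remove(name);
            if (callbacks != null) {
                callbacks.forEach(Runnable::run);
            }
        }
    }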







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730340233



##########
File path: flink-ml-lib/pom.xml
##########
@@ -65,6 +71,44 @@ under the License.
       <artifactId>core</artifactId>
       <version>1.1.2</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-clients_${scala.binary.version}</artifactId>
+      <version>${flink.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>

Review comment:
       Thanks @gaoyunhaii. I have removed the dependency on `flink-statebackend-rocksdb`.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input are blocking and
+     * cached to avoid the possible deadlocks.
+     *
+     * @param inputList the non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream.
+     * @param <OUT> type of the output data stream.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(

Review comment:
       Thanks. I have made the change.
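    For reference, a hypothetical end-to-end use of `withBroadcastStream` as described by the javadoc above. The parameter order and types are inferred from the @param list, and the stream names and values are made up for the example, so this may differ from the final API.

    import java.util.Collections;
    import java.util.List;

    import org.apache.flink.api.common.functions.MapFunction;
    import org.apache.flink.ml.common.broadcast.BroadcastContext;
    import org.apache.flink.ml.common.broadcast.BroadcastUtils;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

    public class WithBroadcastExample {

        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            DataStream<Integer> input = env.fromElements(1, 2, 3);
            DataStream<Integer> offsets = env.fromElements(10);

            DataStream<Integer> result =
                    BroadcastUtils.withBroadcastStream(
                            Collections.<DataStream<?>>singletonList(input),
                            Collections.<String, DataStream<?>>singletonMap("offsets", offsets),
                            inputs -> {
                                @SuppressWarnings("unchecked")
                                DataStream<Integer> in = (DataStream<Integer>) inputs.get(0);
                                // The broadcast side is retrieved by name inside the user function,
                                // as the javadoc above describes.
                                return in.map(
                                        new MapFunction<Integer, Integer>() {
                                            @Override
                                            public Integer map(Integer value) {
                                                List<Integer> offs =
                                                        BroadcastContext.getBroadcastVariable("offsets");
                                                return value + offs.get(0);
                                            }
                                        });
                            });

            result.print();
            env.execute("withBroadcastStream example");
        }
    }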

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all

Review comment:
       Thanks. I have made the change.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input are blocking and

Review comment:
       Thanks. I have made the change.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();

Review comment:
       Thanks. I have made the change.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730339576



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class BroadcastContext {
+    /**
+     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
+     * is (isBroaddcastVariableReady, cacheList).
+     */
+    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
+            new HashMap<>();

Review comment:
       Hi Yun, thanks for the feedback. Yes, users have to make sure that the keys are unique.
   
    For now, I am not sure how to avoid a possible conflict when users use the same key in two `withBroadcast` calls. 
   
    If we use id + name + subtaskIndex, how can we pass the `AbstractID` to the operator that uses the broadcast variables?
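    One possible direction, purely as a sketch and not something this PR does: generate the id on the client while building the graph and bake it into the broadcast names themselves, so nothing extra has to be shipped to the operators.

    import org.apache.flink.util.AbstractID;

    /** Hypothetical helper: prefixes broadcast names with a per-call id so two withBroadcast calls cannot collide. */
    public class BroadcastKeyUtil {

        public static String[] uniquify(String[] broadcastStreamNames) {
            // One random id per withBroadcast call, generated on the client side.
            String prefix = new AbstractID().toString();
            String[] uniqueNames = new String[broadcastStreamNames.length];
            for (int i = 0; i < broadcastStreamNames.length; i++) {
                uniqueNames[i] = prefix + "-" + broadcastStreamNames[i];
            }
            return uniqueNames;
        }
    }

    On the task side the key could additionally include the subtask index if per-partition entries are still needed.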







[GitHub] [flink-ml] gaoyunhaii commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
gaoyunhaii commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730201843



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all

Review comment:
       `support` -> `supports`; similarly, the following comments should also use the third person singular.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input are blocking and

Review comment:
       blocking -> blocked.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);
+        this.broadcastStreamNames = broadcastStreamNames;

Review comment:
       If possible, always use checkNotNull for object inputs. 

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());

Review comment:
       We should not use `env.getParallelism()`; we should use `resultStream.getParallelism()` directly. 

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class BroadcastContext {
+    /**
+     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
+     * is (isBroaddcastVariableReady, cacheList).
+     */
+    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
+            new HashMap<>();

Review comment:
       This could be problematic: if there are multiple withBroadcast instances that use the same name for different input streams, they would override each other. 
   
   
   As a whole, we may want a unique `AbstractID` for each withBroadcast instance, and each variable should be keyed by id + name + subtaskIndex (we might add the optimization of shared variables in following versions). 
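   
   A minimal sketch of that keying scheme, just to illustrate the idea (the class and method names below are hypothetical, not part of this PR, and it assumes an `AbstractID` created per withBroadcast call is made available to the operators):
   
   ```
   import org.apache.flink.api.java.tuple.Tuple3;
   import org.apache.flink.util.AbstractID;
   
   import java.util.HashMap;
   import java.util.List;
   import java.util.Map;
   
   /** Hypothetical sketch: cache entries keyed by (withBroadcast id, variable name, subtask index). */
   public class KeyedBroadcastCacheSketch {
       private static final Map<Tuple3<AbstractID, String, Integer>, List<?>> VARIABLES = new HashMap<>();
   
       public static void put(AbstractID broadcastId, String name, int subtaskIndex, List<?> cache) {
           VARIABLES.put(Tuple3.of(broadcastId, name, subtaskIndex), cache);
       }
   
       public static List<?> get(AbstractID broadcastId, String name, int subtaskIndex) {
           return VARIABLES.get(Tuple3.of(broadcastId, name, subtaskIndex));
       }
   }
   ```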

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {
+            if (areBroadcastVariablesReady()) {
+                dataCacheWriters[0].finishCurrentSegmentAndStartNewSegment();
+                segmentLists[0].addAll(dataCacheWriters[0].getNewlyFinishedSegments());
+                if (segmentLists[0].size() != 0) {
+                    DataCacheReader dataCacheReader =
+                            new DataCacheReader<>(
+                                    inTypes[0].createSerializer(
+                                            containingTask.getExecutionConfig()),
+                                    fileSystem,
+                                    segmentLists[0]);
+                    while (dataCacheReader.hasNext()) {
+                        wrappedOperator.processElement(new StreamRecord(dataCacheReader.next()));
+                    }
+                }
+                segmentLists[0].clear();
+                wrappedOperator.processElement(streamRecord);
+
+            } else {
+                dataCacheWriters[0].addRecord(streamRecord.getValue());
+            }
+
+        } else {
+            while (!areBroadcastVariablesReady()) {
+                mailboxExecutor.yield();
+            }
+            wrappedOperator.processElement(streamRecord);
+        }
+    }

Review comment:
       Perhaps we could extract a common method in the abstract base class?
   
   ```
   protected void xxx(int inputIndex, ThrowingConsumer consumer) {
         ...  
   }
   ```
   
   Then we could avoid repeating the logic in all the processElementX methods by calling `xxx(0, this::processElement)`. 
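   
   For illustration, one possible shape of that helper (here named `processElementWithCache` instead of `xxx`), sketched from the processElement logic already in this PR; it assumes the base-class fields used in the wrappers (`isBlocking`, `dataCacheWriters`, `segmentLists`, `inTypes`, `containingTask`, `fileSystem`, `mailboxExecutor`), and the name is only a placeholder:
   
   ```
   protected <IN> void processElementWithCache(
           StreamRecord<IN> record,
           int inputIndex,
           ThrowingConsumer<StreamRecord<IN>, Exception> consumer)
           throws Exception {
       if (isBlocking[inputIndex]) {
           if (areBroadcastVariablesReady()) {
               // Broadcast variables are ready: replay the cached records first, then the current one.
               dataCacheWriters[inputIndex].finishCurrentSegmentAndStartNewSegment();
               segmentLists[inputIndex].addAll(
                       dataCacheWriters[inputIndex].getNewlyFinishedSegments());
               if (segmentLists[inputIndex].size() != 0) {
                   DataCacheReader<IN> dataCacheReader =
                           new DataCacheReader<>(
                                   inTypes[inputIndex].createSerializer(
                                           containingTask.getExecutionConfig()),
                                   fileSystem,
                                   segmentLists[inputIndex]);
                   while (dataCacheReader.hasNext()) {
                       consumer.accept(new StreamRecord<>(dataCacheReader.next()));
                   }
               }
               segmentLists[inputIndex].clear();
               consumer.accept(record);
           } else {
               // Broadcast variables not ready yet: cache the record.
               dataCacheWriters[inputIndex].addRecord(record.getValue());
           }
       } else {
           // Non-blocking input: wait until the broadcast variables are ready.
           while (!areBroadcastVariablesReady()) {
               mailboxExecutor.yield();
           }
           consumer.accept(record);
       }
   }
   ```
   
   The `processElement` / `processElement1` / `processElement2` methods in the one- and two-input wrappers could then delegate to it, e.g. `processElementWithCache(streamRecord, 0, wrappedOperator::processElement)`.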

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);

Review comment:
       Do the separate calls of keySet() and values() always give the same order? Perhaps we could create the three arrays in one traversal?
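   
   A sketch of the single-traversal variant, just to illustrate the suggestion:
   
   ```
   String[] broadcastInputNames = new String[numBroadcastInput];
   DataStream<?>[] broadcastInputs = new DataStream<?>[numBroadcastInput];
   TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
   int index = 0;
   for (Map.Entry<String, DataStream<?>> entry : bcStreams.entrySet()) {
       // One pass over entrySet() keeps names, streams and types aligned by construction.
       broadcastInputNames[index] = entry.getKey();
       broadcastInputs[index] = entry.getValue();
       broadcastInTypes[index] = entry.getValue().getType();
       index++;
   }
   ```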

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();

Review comment:
       Could be simplified to `inTypes[i] = inputList.get(i).getType()`

##########
File path: flink-ml-lib/pom.xml
##########
@@ -65,6 +71,44 @@ under the License.
       <artifactId>core</artifactId>
       <version>1.1.2</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-clients_${scala.binary.version}</artifactId>
+      <version>${flink.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>

Review comment:
       It seems we do not need to depend on `flink-statebackend-rocksdb` directly?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input are blocking and
+     * cached to avoid the possible deadlocks.
+     *
+     * @param inputList the non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream.
+     * @param <OUT> type of the output data stream.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(

Review comment:
       Perhaps we could put the public method at the top of the file.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+

Review comment:
       Here we need some way to check that only one operator is added; we do not support the case with multiple operators. 
   
   We may do the check inside the `BroadcastWrapper` by verifying that it only wraps one operator.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);

Review comment:
       Perhaps checkState -> checkArgument.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    private List<IN> cache;
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+        this.cache = new ArrayList<>();
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {

Review comment:
       I also have the reverse intuition here... Perhaps we could change it to `isBlocked` and reverse its semantics?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/CacheStreamOperator.java
##########
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractInput;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/** The operator that process all broadcast inputs and stores them in {@link BroadcastContext}. */
+public class CacheStreamOperator<OUT> extends AbstractStreamOperatorV2<OUT>
+        implements MultipleInputStreamOperator<OUT>, BoundedMultiInput, Serializable {
+    /** names of the broadcast DataStreams. */
+    private final String[] broadcastNames;
+    /** input list of the multi-input operator. */
+    private final List<Input> inputList;
+    /** output types of input DataStreams. */
+    private final TypeInformation<?>[] inTypes;
+    /** caches of the broadcast inputs. */
+    private final List<?>[] caches;
+    /** state storage of the broadcast inputs. */
+    private ListState<?>[] cacheStates;
+    /** cacheReady state storage of the broadcast inputs. */
+    private ListState<Boolean>[] cacheReadyStates;
+
+    public CacheStreamOperator(
+            StreamOperatorParameters<OUT> parameters,
+            String[] broadcastNames,
+            TypeInformation<?>[] inTypes) {
+        super(parameters, broadcastNames.length);
+        this.broadcastNames = broadcastNames;
+        this.inTypes = inTypes;
+        this.caches = new List[inTypes.length];
+        for (int i = 0; i < inTypes.length; i++) {
+            caches[i] = new ArrayList<>();
+        }
+        this.cacheStates = new ListState[inTypes.length];
+        this.cacheReadyStates = new ListState[inTypes.length];
+
+        inputList = new ArrayList<>();
+        for (int i = 0; i < inTypes.length; i++) {
+            inputList.add(new ProxyInput(this, i + 1));
+        }
+    }
+
+    @Override
+    public List<Input> getInputs() {
+        return inputList;
+    }
+
+    @Override
+    public void endInput(int i) {
+        BroadcastContext.markCacheFinished(
+                Tuple2.of(broadcastNames[i - 1], getRuntimeContext().getIndexOfThisSubtask()));
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+        for (int i = 0; i < inTypes.length; i++) {
+            cacheStates[i].clear();
+            cacheStates[i].addAll((List) caches[i]);
+            cacheReadyStates[i].clear();
+            boolean isCacheFinished =
+                    BroadcastContext.isCacheFinished(
+                            Tuple2.of(
+                                    broadcastNames[i],
+                                    getRuntimeContext().getIndexOfThisSubtask()));
+            cacheReadyStates[i].add(isCacheFinished);
+        }
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws Exception {
+        super.initializeState(context);
+        for (int i = 0; i < inTypes.length; i++) {
+            cacheStates[i] =

Review comment:
       There is a utility, `OperatorStateUtils.getUniqueElement`, to get the unique list element. 
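   
   A minimal usage sketch, assuming a signature like `Optional<T> getUniqueElement(ListState<T> state, String stateName)` (the exact package and the state name used below should be checked against the repo):
   
   ```
   // Sketch only: restore the readiness flag from the list state via the utility.
   boolean cacheReady =
           OperatorStateUtils.getUniqueElement(cacheReadyStates[i], "cacheReadyState")
                   .orElse(false);
   ```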

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input are blocking and
+     * cached to avoid the possible deadlocks.
+     *
+     * @param inputList the non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream.
+     * @param <OUT> type of the output data stream.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> userDefinedFunction) {
+        Preconditions.checkState(inputList.size() > 0);
+        StreamExecutionEnvironment env = inputList.get(0).getExecutionEnvironment();
+        final String[] broadcastStreamNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<OUT> resultStream =
+                buildGraph(env, inputList, broadcastStreamNames, userDefinedFunction);
+
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = getCoLocationKey(broadcastStreamNames);
+        DataStream<OUT> cachedBroadcastInputs = cacheBroadcastVariables(env, bcStreams, outType);
+
+        for (int i = 0; i < inputList.size(); i++) {
+            inputList.get(i).getTransformation().setCoLocationGroupKey(coLocationKey);

Review comment:
       This could be a bit problematic: for a node, we must declare the `co-location` key on the first operator in that node, and we cannot ensure that the inputList streams are always the head of a node.
   
   To overcome this problem, we may have to set the co-location key on the resultStream. Besides, we would need to ensure the wrapper operator has the chaining strategy of `HEAD`. 
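   
   Roughly, in `withBroadcastStream(...)` that could look like the sketch below, reusing the existing `coLocationKey` variable (`setCoLocationGroupKey` is available on the underlying `Transformation`; the `ChainingStrategy.HEAD` part would live in the wrapper operator or its factory):
   
   ```
   resultStream.getTransformation().setCoLocationGroupKey(coLocationKey);
   cachedBroadcastInputs.getTransformation().setCoLocationGroupKey(coLocationKey);
   // and in the wrapper operator (or its factory):
   // setChainingStrategy(ChainingStrategy.HEAD);
   ```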

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {

Review comment:
       I have some concerns here: we cannot ensure the key is unique if there are multiple withBroadcast calls with the same names. Although that won't cause an error, it might cause unnecessary co-location. Perhaps we could directly use a random id. 
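   
   For instance (sketch only, using Flink's `AbstractID` as the source of uniqueness; the prefix is illustrative):
   
   ```
   String coLocationKey = "flink-ml-broadcast-" + new AbstractID().toHexString();
   ```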

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {

Review comment:
       @VisibleForTesting

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation<?>[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation<?> type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);

Review comment:
       I'm a bit confused about the name of `isBlocking`: why is it that, when it is true, we are in fact caching the records in the wrapper operator?
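
       A possible rename, just as a sketch to illustrate the point (the `isCached` name is hypothetical, not something in this PR):

           // cache all records from non-broadcast input edges until the broadcast variables are ready.
           boolean[] isCached = new boolean[inTypes.length];
           Arrays.fill(isCached, true);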

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */

Review comment:
       name -> names ? 

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class BroadcastContext {
+    /**
+     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
+     * is (isBroadcastVariableReady, cacheList).
+     */
+    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
+            new HashMap<>();

Review comment:
       Also mark `broadcastVariables` as final and change the name to upper case.
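
       Roughly like this (just a sketch of the suggested change, keeping the existing type):

           private static final Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> BROADCAST_VARIABLES =
                   new HashMap<>();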




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r727636497



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class BroadcastContext {
+    /**
+     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
+     * is (isBroadcastVariableReady, cacheList).
+     */
+    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
+            new HashMap<>();
+    /**
+     * We use a lock here because we want to support `getBroadcastVariable(String)` in a TM with
+     * multiple slots. Note that using a ConcurrentHashMap is not enough, since we need "contains
+     * and get" to be an atomic operation.
+     */
+    private static ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
+
+    public static void putBroadcastVariable(
+            Tuple2<String, Integer> key, Tuple2<Boolean, List<?>> variable) {
+        lock.writeLock().lock();
+        try {
+            broadcastVariables.put(key, variable);
+        } finally {
+            lock.writeLock().unlock();
+        }
+    }
+
+    /**
+     * Gets the cached list with the given key.
+     *
+     * @param key the (broadcastName, partitionId) key of the broadcast variable.
+     * @param <T> the element type of the cached list.
+     * @return the cached list.
+     */
+    public static <T> List<T> getBroadcastVariable(Tuple2<String, Integer> key) {
+        lock.readLock().lock();
+        List<?> result = null;
+        try {
+            result = broadcastVariables.get(key).f1;
+        } finally {
+            lock.readLock().unlock();
+        }
+        return (List<T>) result;
+    }
+
+    /**
+     * Gets a broadcast variable by name.
+     *
+     * @param name the name of the broadcast variable.
+     * @param <T> the element type of the cached list.
+     * @return the cached list for the given name, or null if it is not ready.
+     */
+    public static <T> List<T> getBroadcastVariable(String name) {
+        lock.readLock().lock();
+        List<?> result = null;
+        try {
+            for (Tuple2<String, Integer> nameAndPartitionId : broadcastVariables.keySet()) {
+                if (name.equals(nameAndPartitionId.f0) && isCacheFinished(nameAndPartitionId)) {
+                    result = broadcastVariables.get(nameAndPartitionId).f1;
+                    break;

Review comment:
       If the cached broadcast variables are the same regardless of partitionId, I personally think that it might be unnecessary to store partitionId as part of the key of `broadcastVariables`.
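
       For example, the map could be keyed by name only (a sketch, assuming the cached list really is identical across partitions):

           private static final Map<String, Tuple2<Boolean, List<?>>> BROADCAST_VARIABLES = new HashMap<>();

           public static <T> List<T> getBroadcastVariable(String name) {
               lock.readLock().lock();
               try {
                   Tuple2<Boolean, List<?>> entry = BROADCAST_VARIABLES.get(name);
                   // only return the list once caching has finished for this broadcast variable.
                   return (entry != null && entry.f0) ? (List<T>) entry.f1 : null;
               } finally {
                   lock.readLock().unlock();
               }
           }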




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org