You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2021/10/17 00:41:16 UTC

[GitHub] [flink-ml] gaoyunhaii commented on a change in pull request #18: [FLINK-24279] Support withBroadcast in DataStream by caching in static variables

gaoyunhaii commented on a change in pull request #18:
URL: https://github.com/apache/flink-ml/pull/18#discussion_r730201843



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all

Review comment:
       `support` -> `supports`; similarly, the following comments should also use the third-person singular.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input are blocking and

Review comment:
       blocking -> blocked.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);
+        this.broadcastStreamNames = broadcastStreamNames;

Review comment:
       If possible, always use checkNotNull for object inputs. 

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());

Review comment:
       We should not use `env.getParallelism()`; we should use `resultStream.getParallelism()` directly.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class BroadcastContext {
+    /**
+     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
+     * is (isBroaddcastVariableReady, cacheList).
+     */
+    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
+            new HashMap<>();

Review comment:
       This would cause problems: if there are multiple withBroadcast instances and they use the same name for different input streams, they would override each other.
   
   
   Overall, we could have a unique `AbstractID` for each withBroadcast instance, and each variable should be keyed by id + name + subtaskIndex (we might add the optimization of shared variables in later versions).

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {
+            if (areBroadcastVariablesReady()) {
+                dataCacheWriters[0].finishCurrentSegmentAndStartNewSegment();
+                segmentLists[0].addAll(dataCacheWriters[0].getNewlyFinishedSegments());
+                if (segmentLists[0].size() != 0) {
+                    DataCacheReader dataCacheReader =
+                            new DataCacheReader<>(
+                                    inTypes[0].createSerializer(
+                                            containingTask.getExecutionConfig()),
+                                    fileSystem,
+                                    segmentLists[0]);
+                    while (dataCacheReader.hasNext()) {
+                        wrappedOperator.processElement(new StreamRecord(dataCacheReader.next()));
+                    }
+                }
+                segmentLists[0].clear();
+                wrappedOperator.processElement(streamRecord);
+
+            } else {
+                dataCacheWriters[0].addRecord(streamRecord.getValue());
+            }
+
+        } else {
+            while (!areBroadcastVariablesReady()) {
+                mailboxExecutor.yield();
+            }
+            wrappedOperator.processElement(streamRecord);
+        }
+    }

Review comment:
       Perhaps we could extract a common method in the abstract base class?
   
   ```
   protected void xxx(int inputIndex, ThrowingConsumer consumer) {
         ...  
   }
   ```
   
   Then we could avoid repeating the logic in all the processElementX methods by calling `xxx(0, this::processElement)`.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);

Review comment:
       Do the separate calls of keySet() and values() always give the same order? Perhaps we could create the three arrays in one traversal?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();

Review comment:
       Could be simplified to `inTypes[i] = inputList.get(i).getType()`.

##########
File path: flink-ml-lib/pom.xml
##########
@@ -65,6 +71,44 @@ under the License.
       <artifactId>core</artifactId>
       <version>1.1.2</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-clients_${scala.binary.version}</artifactId>
+      <version>${flink.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>

Review comment:
       It seems we do not need to depend on `flink-statebackend-rocksdb` directly?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input are blocking and
+     * cached to avoid the possible deadlocks.
+     *
+     * @param inputList the non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream.
+     * @param <OUT> type of the output data stream.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(

Review comment:
       Perhaps we could move the public method to the top of the file.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+

Review comment:
       Here we need some way to check that only one operator is added, since we cannot support the case with multiple operators.
   
   We may do the check inside the `BroadcastWrapper` by checking it only wraps one operator.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {
+        this(broadcastStreamNames, inTypes, new boolean[inTypes.length]);
+    }
+
+    public BroadcastWrapper(
+            String[] broadcastStreamNames, TypeInformation[] inTypes, boolean[] isBlocking) {
+        Preconditions.checkState(inTypes.length == isBlocking.length);

Review comment:
       Perhaps checkState -> checkArgument.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/OneInputBroadcastWrapperOperator.java
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheReader;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.ml.iteration.datacache.nonkeyed.Segment;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+
+/** Wrapper for WithBroadcastOneInputStreamOperator. */
+public class OneInputBroadcastWrapperOperator<IN, OUT>
+        extends AbstractBroadcastWrapperOperator<OUT, OneInputStreamOperator<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    private List<IN> cache;
+
+    public OneInputBroadcastWrapperOperator(
+            StreamOperatorParameters<OUT> parameters,
+            StreamOperatorFactory<OUT> operatorFactory,
+            String[] broadcastStreamNames,
+            TypeInformation[] inTypes,
+            boolean[] isBlocking) {
+        super(parameters, operatorFactory, broadcastStreamNames, inTypes, isBlocking);
+        this.cache = new ArrayList<>();
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+        if (isBlocking[0]) {

Review comment:
       I also have the opposite intuition here... Perhaps we could change it to `isBlocked` and reverse its semantics?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/CacheStreamOperator.java
##########
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractInput;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Input;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/** The operator that process all broadcast inputs and stores them in {@link BroadcastContext}. */
+public class CacheStreamOperator<OUT> extends AbstractStreamOperatorV2<OUT>
+        implements MultipleInputStreamOperator<OUT>, BoundedMultiInput, Serializable {
+    /** names of the broadcast DataStreams. */
+    private final String[] broadcastNames;
+    /** input list of the multi-input operator. */
+    private final List<Input> inputList;
+    /** output types of input DataStreams. */
+    private final TypeInformation<?>[] inTypes;
+    /** caches of the broadcast inputs. */
+    private final List<?>[] caches;
+    /** state storage of the broadcast inputs. */
+    private ListState<?>[] cacheStates;
+    /** cacheReady state storage of the broadcast inputs. */
+    private ListState<Boolean>[] cacheReadyStates;
+
+    public CacheStreamOperator(
+            StreamOperatorParameters<OUT> parameters,
+            String[] broadcastNames,
+            TypeInformation<?>[] inTypes) {
+        super(parameters, broadcastNames.length);
+        this.broadcastNames = broadcastNames;
+        this.inTypes = inTypes;
+        this.caches = new List[inTypes.length];
+        for (int i = 0; i < inTypes.length; i++) {
+            caches[i] = new ArrayList<>();
+        }
+        this.cacheStates = new ListState[inTypes.length];
+        this.cacheReadyStates = new ListState[inTypes.length];
+
+        inputList = new ArrayList<>();
+        for (int i = 0; i < inTypes.length; i++) {
+            inputList.add(new ProxyInput(this, i + 1));
+        }
+    }
+
+    @Override
+    public List<Input> getInputs() {
+        return inputList;
+    }
+
+    @Override
+    public void endInput(int i) {
+        BroadcastContext.markCacheFinished(
+                Tuple2.of(broadcastNames[i - 1], getRuntimeContext().getIndexOfThisSubtask()));
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+        for (int i = 0; i < inTypes.length; i++) {
+            cacheStates[i].clear();
+            cacheStates[i].addAll((List) caches[i]);
+            cacheReadyStates[i].clear();
+            boolean isCacheFinished =
+                    BroadcastContext.isCacheFinished(
+                            Tuple2.of(
+                                    broadcastNames[i],
+                                    getRuntimeContext().getIndexOfThisSubtask()));
+            cacheReadyStates[i].add(isCacheFinished);
+        }
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws Exception {
+        super.initializeState(context);
+        for (int i = 0; i < inTypes.length; i++) {
+            cacheStates[i] =

Review comment:
       There is a utility, `OperatorStateUtils.getUniqueElement`, to get the unique list element.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);
+        DraftExecutionEnvironment draftEnv =
+                new DraftExecutionEnvironment(
+                        env, new BroadcastWrapper<>(broadcastStreamNames, inTypes, isBlocking));
+
+        List<DataStream<?>> draftSources = new ArrayList<>();
+        for (int i = 0; i < inputList.size(); i++) {
+            draftSources.add(draftEnv.addDraftSource(inputList.get(i), inputList.get(i).getType()));
+        }
+        DataStream<OUT> draftOutStream = graphBuilder.apply(draftSources);
+
+        draftEnv.copyToActualEnvironment();
+        DataStream<OUT> outStream = draftEnv.getActualStream(draftOutStream.getId());
+        return outStream;
+    }
+
+    /**
+     * Support withBroadcastStream in DataStream API. Broadcast data streams are available at all
+     * parallel instances of the input operators. A broadcast data stream is registered under a
+     * certain name and can be retrieved under that name via {@link
+     * BroadcastContext}.getBroadcastVariable(...).
+     *
+     * <p>In detail, the broadcast input data streams will be consumed first and cached as static
+     * variables in {@link BroadcastContext}. For now the non-broadcast input are blocking and
+     * cached to avoid the possible deadlocks.
+     *
+     * @param inputList the non-broadcast input list.
+     * @param bcStreams map of the broadcast data streams, where the key is the name and the value
+     *     is the corresponding data stream.
+     * @param userDefinedFunction the user defined logic in which users can access the broadcast
+     *     data streams and produce the output data stream.
+     * @param <OUT> type of the output data stream.
+     * @return the output data stream.
+     */
+    @PublicEvolving
+    public static <OUT> DataStream<OUT> withBroadcastStream(
+            List<DataStream<?>> inputList,
+            Map<String, DataStream<?>> bcStreams,
+            Function<List<DataStream<?>>, DataStream<OUT>> userDefinedFunction) {
+        Preconditions.checkState(inputList.size() > 0);
+        StreamExecutionEnvironment env = inputList.get(0).getExecutionEnvironment();
+        final String[] broadcastStreamNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<OUT> resultStream =
+                buildGraph(env, inputList, broadcastStreamNames, userDefinedFunction);
+
+        TypeInformation outType = resultStream.getType();
+        final String coLocationKey = getCoLocationKey(broadcastStreamNames);
+        DataStream<OUT> cachedBroadcastInputs = cacheBroadcastVariables(env, bcStreams, outType);
+
+        for (int i = 0; i < inputList.size(); i++) {
+            inputList.get(i).getTransformation().setCoLocationGroupKey(coLocationKey);

Review comment:
       This is a bit problematic: for a node, we must declare the `co-location` key on the first operator in that node, and we cannot ensure that the streams in `inputList` are always the head of a node.
   
   To overcome this problem, we may have to set the co-location key on the `resultStream`. Besides, we would need to ensure that the wrapper operator has the chaining strategy `HEAD`.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {

Review comment:
       I have some concerns here: we cannot ensure the key is unique if there are multiple `withBroadcast` calls with the same names. Although it won't cause an error, it might cause unnecessary co-location. Perhaps we could directly use a random id.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */
+    private final String[] broadcastStreamNames;
+    /** types of input DataStreams. */
+    private final TypeInformation[] inTypes;
+    /** whether each input is blocking or not. */
+    private final boolean[] isBlocking;
+
+    public BroadcastWrapper(String[] broadcastStreamNames, TypeInformation[] inTypes) {

Review comment:
       @VisibleForTesting

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastUtils.java
##########
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.ml.common.broadcast.operator.BroadcastWrapper;
+import org.apache.flink.ml.common.broadcast.operator.CacheStreamOperatorFactory;
+import org.apache.flink.ml.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.MultipleConnectedStreams;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.transformations.MultipleInputTransformation;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+public class BroadcastUtils {
+
+    private static <OUT> DataStream<OUT> cacheBroadcastVariables(
+            StreamExecutionEnvironment env,
+            Map<String, DataStream<?>> bcStreams,
+            TypeInformation<OUT> outType) {
+        int numBroadcastInput = bcStreams.size();
+        String[] broadcastInputNames = bcStreams.keySet().toArray(new String[0]);
+        DataStream<?>[] broadcastInputs = bcStreams.values().toArray(new DataStream<?>[0]);
+        TypeInformation<?>[] broadcastInTypes = new TypeInformation[numBroadcastInput];
+        for (int i = 0; i < numBroadcastInput; i++) {
+            broadcastInTypes[i] = broadcastInputs[i].getType();
+        }
+
+        MultipleInputTransformation<OUT> transformation =
+                new MultipleInputTransformation<OUT>(
+                        "broadcastInputs",
+                        new CacheStreamOperatorFactory<OUT>(broadcastInputNames, broadcastInTypes),
+                        outType,
+                        env.getParallelism());
+        for (DataStream<?> dataStream : bcStreams.values()) {
+            transformation.addInput(dataStream.broadcast().getTransformation());
+        }
+        env.addOperator(transformation);
+        return new MultipleConnectedStreams(env).transform(transformation);
+    }
+
+    private static String getCoLocationKey(String[] broadcastNames) {
+        StringBuilder sb = new StringBuilder();
+        sb.append("Flink-ML-broadcast-co-location");
+        for (String name : broadcastNames) {
+            sb.append(name);
+        }
+        return sb.toString();
+    }
+
+    private static <OUT> DataStream<OUT> buildGraph(
+            StreamExecutionEnvironment env,
+            List<DataStream<?>> inputList,
+            String[] broadcastStreamNames,
+            Function<List<DataStream<?>>, DataStream<OUT>> graphBuilder) {
+        TypeInformation[] inTypes = new TypeInformation[inputList.size()];
+        for (int i = 0; i < inputList.size(); i++) {
+            TypeInformation type = inputList.get(i).getType();
+            inTypes[i] = type;
+        }
+        // blocking all non-broadcast input edges by default.
+        boolean[] isBlocking = new boolean[inTypes.length];
+        Arrays.fill(isBlocking, true);

Review comment:
       I'm a bit confused by the name `isBlocking`: why is it that, when it is true, we are in fact caching the records in the wrapper operator?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast.operator;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.ml.iteration.operator.OperatorWrapper;
+import org.apache.flink.streaming.api.operators.MultipleInputStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** The operator wrapper for broadcast wrappers. */
+public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
+    /** name of the broadcast DataStreams. */

Review comment:
       name -> names?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/BroadcastContext.java
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.broadcast;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+public class BroadcastContext {
+    /**
+     * Store broadcast DataStreams in a Map. The key is (broadcastName, partitionId) and the value
+     * is (isBroaddcastVariableReady, cacheList).
+     */
+    private static Map<Tuple2<String, Integer>, Tuple2<Boolean, List<?>>> broadcastVariables =
+            new HashMap<>();

Review comment:
       Also mark `broadcastVariables` as final and change the name to upper case.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org