Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/04/19 03:09:19 UTC

[GitHub] [flink-ml] weibozhao opened a new pull request, #86: [FLINK-27294] Add Transformer for EvalBinaryClass

weibozhao opened a new pull request, #86:
URL: https://github.com/apache/flink-ml/pull/86

   Add Transformer for EvalBinaryClass


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r874281393


##########
flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java:
##########
@@ -100,4 +100,22 @@ public boolean validate(T value) {
     public static <T> ParamValidator<T[]> nonEmptyArray() {
         return value -> value != null && value.length > 0;
     }
+
+    // Check if every element in the array-typed parameter value is in the array of allowed values.

Review Comment:
   nit: Check -> Checks





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r874284646


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassfication/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,724 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassfication;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.AbstractWrapperOperator;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * An Estimator which calculates the evaluation metrics for binary classification. The input data
+ * has columns rawPrediction, label and an optional weight column. The rawPrediction can be of type
+ * double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of
+ * raw predictions, scores, or label probabilities). The output may contain different metrics which
+ * will be defined by parameter MetricsNames. See {@link BinaryClassificationEvaluatorParams}.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final Logger LOG = LoggerFactory.getLogger(AbstractWrapperOperator.class);
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        DataStream<Tuple3<Double, Boolean, Double>> sortEvalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                sortEvalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */

Review Comment:
   nit: The comment should end with `.`.
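The pipeline quoted above sorts the whole data set without collecting it on a single task. It samples split points, range-partitions the records by score, sorts each partition locally, and then derives each record's global order from the sizes of the preceding partitions. The implementations of getBoundaryRange, AppendTaskId and the global-order step are not part of this excerpt, so the following is only a single-JVM sketch of the idea in plain Java. The class and method names (RangeSortSketch, boundaries, partitionOf, globalDescendingOrder) are hypothetical, and the split-point rule is an assumption.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** Hypothetical single-JVM sketch of a range-partitioned global sort (descending by score). */
public class RangeSortSketch {

    /** Picks numPartitions - 1 ascending split points from a non-empty sample of scores. */
    static double[] boundaries(double[] sample, int numPartitions) {
        double[] sorted = sample.clone();
        Arrays.sort(sorted);
        double[] bounds = new double[numPartitions - 1];
        for (int j = 0; j < bounds.length; j++) {
            bounds[j] = sorted[(j + 1) * sorted.length / numPartitions];
        }
        return bounds;
    }

    /** Routes a score so that partition 0 receives the highest scores; ties share a partition. */
    static int partitionOf(double score, double[] bounds) {
        int below = 0; // number of split points strictly smaller than the score
        for (double b : bounds) {
            if (b < score) {
                below++;
            }
        }
        return bounds.length - below;
    }

    /** Range-partitions, sorts each partition locally, and places records by global order. */
    static double[] globalDescendingOrder(double[] scores, double[] sample, int numPartitions) {
        double[] bounds = boundaries(sample, numPartitions);
        List<List<Double>> partitions = new ArrayList<>();
        for (int p = 0; p < numPartitions; p++) {
            partitions.add(new ArrayList<>());
        }
        for (double s : scores) {
            partitions.get(partitionOf(s, bounds)).add(s);
        }
        double[] result = new double[scores.length];
        int offset = 0; // total size of all partitions with a smaller id
        for (List<Double> partition : partitions) {
            partition.sort(Comparator.reverseOrder()); // local descending sort
            for (int local = 0; local < partition.size(); local++) {
                // Global order = offset of the partition + position inside the partition.
                result[offset + local] = partition.get(local);
            }
            offset += partition.size();
        }
        return result;
    }

    public static void main(String[] args) {
        double[] scores = {0.3, 0.9, 0.1, 0.7, 0.5, 0.9, 0.2};
        // The full data set doubles as the sample here; real code would sample a bounded subset.
        System.out.println(Arrays.toString(globalDescendingOrder(scores, scores, 3)));
        // prints [0.9, 0.9, 0.7, 0.5, 0.3, 0.2, 0.1]
    }
}
```

Every score routed to partition k is larger than every score in partition k + 1, so concatenating the locally sorted partitions gives a global descending order. In the PR, the per-partition sizes presumably come from the broadcast partition summaries rather than from an in-memory list.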



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

Review Comment:
   nit: public -> private



##########
flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java:
##########
@@ -100,4 +100,22 @@ public boolean validate(T value) {
     public static <T> ParamValidator<T[]> nonEmptyArray() {
         return value -> value != null && value.length > 0;
     }
+
+    // Check if every element in the array-typed parameter value is in the array of allowed values.
+    public static <T> ParamValidator<T[]> isSubArray(T... allowed) {

Review Comment:
   nit: `subArray` seems a bit confusing here, since it implies that the elements must also appear in the same order as in `allowed`, which is not the case. How about `isSubSet`?
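To illustrate the naming point: the check only asks whether each element of the value belongs to `allowed`, regardless of order or repetition, which is set membership rather than sub-array matching. Below is a minimal sketch, assuming a validator shaped like the lambda-based `ParamValidator` in the diff (a single `boolean validate(T value)` method). The interface is re-declared locally so the example compiles on its own, and `isSubSet` here is illustrative rather than the merged API.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class SubSetValidatorSketch {

    /** Local stand-in for the ParamValidator interface, declared here so the sketch is standalone. */
    @FunctionalInterface
    interface ParamValidator<T> {
        boolean validate(T value);
    }

    /**
     * Hypothetical isSubSet: accepts a non-null array whose elements all appear in {@code allowed}.
     * Order and duplicates are ignored. An empty array is accepted.
     */
    @SafeVarargs
    static <T> ParamValidator<T[]> isSubSet(T... allowed) {
        Set<T> allowedSet = new HashSet<>(Arrays.asList(allowed));
        return value -> value != null && allowedSet.containsAll(Arrays.asList(value));
    }

    public static void main(String[] args) {
        ParamValidator<String[]> metrics =
                isSubSet("areaUnderROC", "areaUnderPR", "KS", "areaUnderLorenz");
        System.out.println(metrics.validate(new String[] {"KS", "areaUnderROC"})); // true
        System.out.println(metrics.validate(new String[] {"accuracy"})); // false
    }
}
```

Because an empty array passes this check, a parameter that must list at least one metric would still need something like the existing `nonEmptyArray()` check as well.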





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867679164


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See {@link BinaryClassificationEvaluatorParams}.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";

Review Comment:
   Got it. Thanks.
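The Javadoc in this version states that rawPrediction may be a double (a binary 0/1 prediction or the probability of label 1) or a length-2 vector. ParseSample itself is not shown in this excerpt, so the following is only a hedged plain-Java sketch of such parsing into the (score, isPositive, weight) triple. A `double[]` stands in for `DenseVector`. Treating element 1 as the label-1 score, label 1.0 as the positive class, and a missing weight as 1.0 are all assumptions.

```java
/** Hypothetical stand-in for ParseSample: maps one row to (score, isPositive, weight). */
public class ParseSampleSketch {

    /** Parsed sample; plays the role of the Tuple3<Double, Boolean, Double> in the diff. */
    static final class Sample {
        final double score;
        final boolean isPositive;
        final double weight;

        Sample(double score, boolean isPositive, double weight) {
            this.score = score;
            this.isPositive = isPositive;
            this.weight = weight;
        }
    }

    /** A double is used as-is; for a length-2 vector, element 1 is taken as the label-1 score. */
    static double positiveScore(Object rawPrediction) {
        if (rawPrediction instanceof Double) {
            return (Double) rawPrediction;
        }
        if (rawPrediction instanceof double[] && ((double[]) rawPrediction).length == 2) {
            return ((double[]) rawPrediction)[1];
        }
        throw new IllegalArgumentException("rawPrediction must be a double or a length-2 vector");
    }

    /** Assumes label 1.0 marks the positive class and a null weight means weight 1.0. */
    static Sample parse(double label, Object rawPrediction, Double weight) {
        return new Sample(positiveScore(rawPrediction), label == 1.0, weight == null ? 1.0 : weight);
    }

    public static void main(String[] args) {
        Sample a = parse(1.0, new double[] {0.2, 0.8}, null);
        Sample b = parse(0.0, 0.3, 2.0);
        System.out.println(a.score + " " + a.isPositive + " " + a.weight); // 0.8 true 1.0
        System.out.println(b.score + " " + b.isPositive + " " + b.weight); // 0.3 false 2.0
    }
}
```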





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r857051925


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)

Review Comment:
   I noticed that in several places we first apply an operation within each subtask or partition and then perform a global reduce/aggregation. This way of improving the efficiency of map-reduce style operations should apply to Flink operators in general, so Flink may already provide a mechanism for it, in which case we could use `reduce()`/`aggregate()` directly instead of implementing it ourselves. Could you please help check?
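For context on the two-stage reduction discussed above: the window function in the diff emits, for each distinct score, `{average order * positive weight, positive weight, negative weight}`, and these triples are combined by element-wise addition. Addition is associative and commutative, which is what makes a per-partition pre-reduction followed by one global reduction valid. The final formula in CalcAucOperator is not part of this excerpt, so the sketch below assumes the standard Mann-Whitney rank-sum form of AUC with 1-based ascending ranks and unweighted samples. All names are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical rank-sum AUC sketch with tie-averaged ranks and mergeable partial sums. */
public class RankSumAucSketch {

    /** Partial result for one group of tied scores: {avgRank * positives, positives, negatives}. */
    static double[] tieGroup(long rankSum, int count, double positives, double negatives) {
        return new double[] {(double) rankSum / count * positives, positives, negatives};
    }

    /** Element-wise sum; associative, so it can run per partition first and globally afterwards. */
    static double[] merge(double[] a, double[] b) {
        return new double[] {a[0] + b[0], a[1] + b[1], a[2] + b[2]};
    }

    /** Unweighted AUC; assumes at least one positive and one negative sample. */
    static double auc(double[] scores, boolean[] isPositive) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        Arrays.sort(idx, Comparator.comparingDouble(i -> scores[i])); // ascending by score
        double[] total = {0, 0, 0};
        int start = 0;
        while (start < n) {
            int end = start;
            long rankSum = 0;
            double positives = 0;
            double negatives = 0;
            // Collect the run of samples sharing the same score; ranks are 1-based.
            while (end < n && scores[idx[end]] == scores[idx[start]]) {
                rankSum += end + 1;
                if (isPositive[idx[end]]) {
                    positives++;
                } else {
                    negatives++;
                }
                end++;
            }
            total = merge(total, tieGroup(rankSum, end - start, positives, negatives));
            start = end;
        }
        double p = total[1];
        double q = total[2];
        // Mann-Whitney: (sum of positive ranks - P(P+1)/2) / (P * N).
        return (total[0] - p * (p + 1) / 2) / (p * q);
    }

    public static void main(String[] args) {
        double[] scores = {0.1, 0.4, 0.35, 0.8};
        boolean[] labels = {false, false, true, true};
        System.out.println(auc(scores, labels)); // 0.75
    }
}
```

Since `merge` is plain element-wise addition, the order in which tie groups are combined does not matter, so the triples can be summed per partition first and then once more globally, which is the two-stage pattern the comment asks about.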



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)
+                        .map(
+                                new MapFunction<double[], Double>() {
+                                    @Override
+                                    public Double map(double[] aucVariable) {
+                                        if (aucVariable[1] > 0 && aucVariable[2] > 0) {
+                                            return (aucVariable[0]
+                                                            - 1.
+                                                                    * aucVariable[1]
+                                                                    * (aucVariable[1] + 1)
+                                                                    / 2)
+                                                    / (aucVariable[1] * aucVariable[2]);
+                                        } else {
+                                            return Double.NaN;
+                                        }
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(PARTITION_SUMMARY, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(input, new CalcBinaryMetrics());
+                        });
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+        DataStream<Row> evalResult =
+                metrics.map(
+                        new MapFunction<Map<String, Double>, Row>() {
+                            @Override
+                            public Row map(Map<String, Double> value) {
+                                Row ret = new Row(metricsNames.length);
+                                for (int i = 0; i < metricsNames.length; ++i) {
+                                    ret.setField(i, value.get(metricsNames[i]));
+                                }
+                                return ret;
+                            }
+                        },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            if (summary == null) {

Review Comment:
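   Shall we move this initialization into `initializeState`, i.e. replace the `orElse(null)` used there with a default `BinarySummary`, so that `processElement` no longer needs the null check?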
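A minimal plain-Java sketch of this suggestion, assuming it means building the default summary once when state is initialized instead of null-checking on every record. `Summary`, `initialize` and `process` are hypothetical, simplified stand-ins for `BinarySummary`, `initializeState` and `processElement`/`updateBinarySummary` (their fields are assumptions); only `Optional.orElseGet` is standard Java. In the diff, `OperatorStateUtils.getUniqueElement(...)` is followed by `.orElse(null)`, so the restored value is an `Optional` there as well.

```java
import java.util.Optional;

/**
 * Plain-Java sketch (no Flink APIs): create the per-subtask default once, when state is
 * initialized, instead of null-checking on every record. Hypothetical names throughout.
 */
public class EagerStateInitSketch {

    /** Hypothetical, simplified stand-in for BinarySummary. */
    static final class Summary {
        final int taskId;
        double maxScore;
        long count;

        Summary(int taskId, double maxScore, long count) {
            this.taskId = taskId;
            this.maxScore = maxScore;
            this.count = count;
        }
    }

    private Summary summary;

    /** Plays the role of initializeState(): use the restored value, else build a default. */
    void initialize(Optional<Summary> restored, int subtaskIndex) {
        // orElseGet builds the default only when nothing was restored.
        summary = restored.orElseGet(() -> new Summary(subtaskIndex, -Double.MAX_VALUE, 0));
    }

    /** Plays the role of processElement(): no null check is needed any more. */
    void process(double score) {
        summary.maxScore = Math.max(summary.maxScore, score);
        summary.count++;
    }

    Summary current() {
        return summary;
    }

    public static void main(String[] args) {
        EagerStateInitSketch fresh = new EagerStateInitSketch();
        fresh.initialize(Optional.empty(), 3);
        fresh.process(0.7);
        fresh.process(0.2);
        Summary s = fresh.current();
        System.out.println(s.taskId + " " + s.maxScore + " " + s.count); // 3 0.7 2

        EagerStateInitSketch restored = new EagerStateInitSketch();
        restored.initialize(Optional.of(new Summary(1, 0.9, 5)), 1);
        restored.process(0.95);
        Summary r = restored.current();
        System.out.println(r.maxScore + " " + r.count); // 0.95 6
    }
}
```

`orElseGet` is used rather than `orElse(new ...)` so the default is only constructed when nothing was restored. One side effect worth checking: with eager initialization the `summary != null` guard in `endInput` always passes, so a subtask that received no records would also emit a summary with zero counts; whether the code consuming `partitionSummaries` tolerates that is not visible in this excerpt.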



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)
+                        .map(
+                                new MapFunction<double[], Double>() {
+                                    @Override
+                                    public Double map(double[] aucVariable) {
+                                        if (aucVariable[1] > 0 && aucVariable[2] > 0) {
+                                            return (aucVariable[0]
+                                                            - 1.
+                                                                    * aucVariable[1]
+                                                                    * (aucVariable[1] + 1)
+                                                                    / 2)
+                                                    / (aucVariable[1] * aucVariable[2]);
+                                        } else {
+                                            return Double.NaN;
+                                        }
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(PARTITION_SUMMARY, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(input, new CalcBinaryMetrics());
+                        });
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+        DataStream<Row> evalResult =
+                metrics.map(
+                        new MapFunction<Map<String, Double>, Row>() {
+                            @Override
+                            public Row map(Map<String, Double> value) {
+                                Row ret = new Row(metricsNames.length);
+                                for (int i = 0; i < metricsNames.length; ++i) {
+                                    ret.setField(i, value.get(metricsNames[i]));
+                                }
+                                return ret;
+                            }
+                        },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            if (summary == null) {
+                summary =
+                        new BinarySummary(
+                                getRuntimeContext().getIndexOfThisSubtask(),
+                                -Double.MAX_VALUE,
+                                0,
+                                0);
+            }
+            updateBinarySummary(summary, streamRecord.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            summaryState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "summaryState",
+                                            TypeInformation.of(BinarySummary.class)));
+            summary =
+                    OperatorStateUtils.getUniqueElement(summaryState, "summaryState").orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            summaryState.clear();
+            if (summary != null) {
+                summaryState.add(summary);
+            }
+        }
+    }
+
+    private static class CalcAucOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<double[], double[]>, BoundedOneInput {
+        private ListState<double[]> aucVariableState;
+        private double[] aucVariable;
+
+        @Override
+        public void endInput() {
+            if (aucVariable != null) {
+                output.collect(new StreamRecord<>(aucVariable));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<double[]> streamRecord) {
+            if (aucVariable == null) {
+                aucVariable = streamRecord.getValue();
+            } else {
+                double[] tmpAucVar = streamRecord.getValue();
+                aucVariable[0] += tmpAucVar[0];
+                aucVariable[1] += tmpAucVar[1];
+                aucVariable[2] += tmpAucVar[2];
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            aucVariableState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "aucState", TypeInformation.of(double[].class)));
+            aucVariable =
+                    OperatorStateUtils.getUniqueElement(aucVariableState, "aucState").orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            aucVariableState.clear();
+            if (aucVariable != null) {
+                aucVariableState.add(aucVariable);
+            }
+        }
+    }
+
+    /** Merges the metrics calculated locally and outputs the merged metrics. */
+    private static class MergeMetrics
+            implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> {
+        private static final long serialVersionUID = 463407033215369847L;
+
+        @Override
+        public void mapPartition(
+                Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) {
+            Iterator<BinaryMetrics> iter = values.iterator();
+            BinaryMetrics reduceMetrics = iter.next();
+            while (iter.hasNext()) {
+                reduceMetrics = reduceMetrics.merge(iter.next());
+            }
+            Map<String, Double> map = new HashMap<>();
+            map.put(AREA_UNDER_ROC, reduceMetrics.areaUnderROC);
+            map.put(AREA_UNDER_PR, reduceMetrics.areaUnderPR);
+            map.put(AREA_UNDER_LORENZ, reduceMetrics.areaUnderLorenz);
+            map.put(KS, reduceMetrics.ks);
+            out.collect(map);
+        }
+    }
+
+    private static class CalcBinaryMetrics
+            extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, BinaryMetrics> {
+        private static final long serialVersionUID = 5680342197308160013L;
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple3<Double, Boolean, Double>> iterable,
+                Collector<BinaryMetrics> collector) {
+
+            List<BinarySummary> statistics =
+                    getRuntimeContext().getBroadcastVariable(PARTITION_SUMMARY);
+            Tuple2<Boolean, long[]> t =
+                    reduceBinarySummary(statistics, getRuntimeContext().getIndexOfThisSubtask());
+            long[] countValues = t.f1;
+
+            double areaUnderROC =
+                    getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC).get(0);
+            long totalTrue = countValues[2];
+            long totalFalse = countValues[3];
+            if (totalTrue == 0) {
+                System.out.println("There is no positive sample in data!");
+            }
+            if (totalFalse == 0) {
+                System.out.println("There is no negative sample in data!");
+            }
+
+            BinaryMetrics metrics = new BinaryMetrics(0L, areaUnderROC);
+            double[] tprFprPrecision = new double[4];
+            for (Tuple3<Double, Boolean, Double> t3 : iterable) {
+                updateBinaryMetrics(t3, metrics, countValues, tprFprPrecision);
+            }
+            collector.collect(metrics);
+        }
+    }
+
+    private static void updateBinaryMetrics(
+            Tuple3<Double, Boolean, Double> cur,
+            BinaryMetrics binaryMetrics,
+            long[] countValues,
+            double[] recordValues) {
+        if (binaryMetrics.count == 0) {
+            recordValues[0] = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
+            recordValues[1] = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+            recordValues[2] =
+                    countValues[0] + countValues[1] == 0
+                            ? 1.0
+                            : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
+            recordValues[3] =
+                    1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+        }
+
+        binaryMetrics.count++;
+        if (cur.f1) {
+            countValues[0]++;
+        } else {
+            countValues[1]++;
+        }
+
+        double tpr = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
+        double fpr = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+        double precision =
+                countValues[0] + countValues[1] == 0
+                        ? 1.0
+                        : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
+        double positiveRate =
+                1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+
+        binaryMetrics.areaUnderLorenz +=
+                ((positiveRate - recordValues[3]) * (tpr + recordValues[0]) / 2);
+        binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision + recordValues[2]) / 2);
+        binaryMetrics.ks = Math.max(Math.abs(fpr - tpr), binaryMetrics.ks);
+
+        recordValues[0] = tpr;
+        recordValues[1] = fpr;
+        recordValues[2] = precision;
+        recordValues[3] = positiveRate;

Review Comment:
   Shall we define a class to represent the metric results? It would be more readable than remembering the meaning of each metric by its index in a double array.
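To make the suggestion concrete, here is a minimal sketch of such a class. `CurvePoint` and its method names are hypothetical and not part of the PR. The formulas are copied from `updateBinaryMetrics` above, reading `countValues[0..3]` as positives seen so far, negatives seen so far, total positives and total negatives, and `recordValues[0..3]` as TPR, FPR, precision and positive rate.

```java
/**
 * Hypothetical value class that could replace the double[4] recordValues array in
 * updateBinaryMetrics. It holds the rates of one point on the ROC/PR/Lorenz curves.
 */
final class CurvePoint {
    final double tpr;          // true positive rate (recall)
    final double fpr;          // false positive rate
    final double precision;
    final double positiveRate; // fraction of all samples visited so far

    CurvePoint(double tpr, double fpr, double precision, double positiveRate) {
        this.tpr = tpr;
        this.fpr = fpr;
        this.precision = precision;
        this.positiveRate = positiveRate;
    }

    /** Mirrors updateBinaryMetrics; tp/fp are the positives/negatives seen so far. */
    static CurvePoint fromCounts(long tp, long fp, long totalTrue, long totalFalse) {
        double tpr = totalTrue == 0 ? 1.0 : 1.0 * tp / totalTrue;
        double fpr = totalFalse == 0 ? 1.0 : 1.0 * fp / totalFalse;
        double precision = tp + fp == 0 ? 1.0 : 1.0 * tp / (tp + fp);
        double positiveRate = 1.0 * (tp + fp) / (totalTrue + totalFalse);
        return new CurvePoint(tpr, fpr, precision, positiveRate);
    }

    /** Trapezoid added to areaUnderPR when moving from previous to this point. */
    double prAreaSince(CurvePoint previous) {
        return (tpr - previous.tpr) * (precision + previous.precision) / 2;
    }

    /** Trapezoid added to areaUnderLorenz when moving from previous to this point. */
    double lorenzAreaSince(CurvePoint previous) {
        return (positiveRate - previous.positiveRate) * (tpr + previous.tpr) / 2;
    }

    /** KS candidate at this point; KS is the maximum over all visited points. */
    double ksCandidate() {
        return Math.abs(fpr - tpr);
    }
}
```

With this, the loop body could read along the lines of `metrics.areaUnderPR += current.prAreaSince(previous)`, which names each quantity instead of indexing into an array.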



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(

Review Comment:
   It seems that we do not need `new ArrayList<>()`, since `Arrays.asList(...)` already returns a `List`.
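A minimal illustration with plain `Double` values (the `ListLiteralExample` and `LABELS` names are hypothetical). `Arrays.asList` already yields a fixed-size `List`, which should be enough for read-only test data such as the lists passed to `env.fromCollection(...)`, assuming the lists are never modified afterwards. Copying into an `ArrayList` only matters if elements must be added or removed later.

```java
import java.util.Arrays;
import java.util.List;

class ListLiteralExample {
    // Arrays.asList returns a fixed-size List backed by the given array: reads and set()
    // work, but add()/remove() throw UnsupportedOperationException.
    static final List<Double> LABELS = Arrays.asList(1.0, 1.0, 0.0, 1.0);
}
```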



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz', as specified by the parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<

Review Comment:
   Implementations like this bring a risk of OOM, since the data allocated to a single subtask can still exceed the available memory. Using `apply()` might be unnecessary here; we could consider `reduce()` or `aggregate()` instead.
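As a sketch of the `aggregate()` alternative: Flink's `WindowedStream.aggregate(...)` takes an `AggregateFunction<IN, ACC, OUT>` with `createAccumulator`, `add`, `merge` and `getResult`, so each record is folded into a small accumulator instead of being buffered for the window. The class below is hypothetical and does not implement the Flink interface; the `Tuple4` fields (order, isPositive, weight) are passed as separate arguments so the sketch compiles on its own. It reproduces the output layout of the current window function.

```java
/**
 * Hypothetical incremental replacement for the per-score WindowFunction. In Flink it would
 * implement AggregateFunction<Tuple4<Double, Long, Boolean, Double>, Accumulator, double[]>.
 */
class AucAggregator {
    /** Constant-size running state for one score key. */
    static class Accumulator {
        long orderSum;
        long count;
        double positiveWeight;
        double negativeWeight;
    }

    Accumulator createAccumulator() {
        return new Accumulator();
    }

    /** Folds one record into the accumulator instead of buffering it. */
    Accumulator add(long order, boolean isPositive, double weight, Accumulator acc) {
        acc.orderSum += order;
        acc.count++;
        if (isPositive) {
            acc.positiveWeight += weight;
        } else {
            acc.negativeWeight += weight;
        }
        return acc;
    }

    /** Combines two partial accumulators of the same key. */
    Accumulator merge(Accumulator a, Accumulator b) {
        a.orderSum += b.orderSum;
        a.count += b.count;
        a.positiveWeight += b.positiveWeight;
        a.negativeWeight += b.negativeWeight;
        return a;
    }

    /** Same layout as the window function: {mean order * positive weight, positive weight, negative weight}. */
    double[] getResult(Accumulator acc) {
        return new double[] {
            1.0 * acc.orderSum / acc.count * acc.positiveWeight,
            acc.positiveWeight,
            acc.negativeWeight
        };
    }
}
```

One accumulator per distinct score is still kept in state until the end of input, so memory grows with the number of distinct scores rather than with the number of records.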



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })

Review Comment:
   Shall we move anonymous functions like this into private static classes with meaningful class names and JavaDocs? That would help improve readability.
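As an example of the suggestion applied to the inline `MapPartitionFunction` that sorts local data, here is a sketch of a named class with JavaDoc. `Sample` (standing in for `Tuple3<Double, Boolean, Double>`) and `SortByScoreDescending` are hypothetical names, and a `java.util.function.Consumer` stands in for Flink's `Collector` so the sketch compiles without Flink on the classpath.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.function.Consumer;

/** Hypothetical stand-in for Tuple3<Double, Boolean, Double>: <score, isPositive, weight>. */
class Sample {
    final double score;
    final boolean isPositive;
    final double weight;

    Sample(double score, boolean isPositive, double weight) {
        this.score = score;
        this.isPositive = isPositive;
        this.weight = weight;
    }
}

/** Sorts the samples of one partition by score in descending order and emits them in that order. */
class SortByScoreDescending {
    void mapPartition(Iterable<Sample> values, Consumer<Sample> out) {
        List<Sample> buffered = new ArrayList<>();
        for (Sample s : values) {
            buffered.add(s);
        }
        // Highest score first, matching the original comparator on -score.
        buffered.sort(Comparator.comparingDouble((Sample s) -> s.score).reversed());
        buffered.forEach(out);
    }
}
```

The sketch buffers into an `ArrayList` rather than a `LinkedList`: `ArrayList` sorts its backing array in place, whereas `LinkedList` relies on the default `List.sort`, which copies the elements out and writes them back.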



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,208 @@
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                            Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                            Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                            Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                            Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1))));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1))));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("label", "raw");
+        trainDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_MULTI_SCORE))
+                        .as("label", "raw");
+        trainDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_WEIGHT))
+                        .as("label", "raw", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {

Review Comment:
   The order of the test cases in this class and their naming convention seem to differ from the tests of other existing operators. Shall we follow the test style used in those classes?
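Separately from the style question, the hard-coded expected values can be cross-checked with a brute-force pairwise AUC. This is a minimal sketch, not the evaluator's code: `PairwiseAuc` is a hypothetical name, and it assumes that index 1 of each `raw` vector is the positive-class score (`ParseSample` is not shown in this hunk). Under that assumption, the `TRAIN_DATA` rows give 23 correctly ordered (positive, negative) pairs out of 7 * 5 = 35, i.e. 23/35, which equals the last entry of `EXPECTED_DATA`.

```java
/** Hypothetical reference helper: AUC as the fraction of correctly ordered (pos, neg) pairs. */
public class PairwiseAuc {

    public static double auc(double[] scores, boolean[] isPositive) {
        double correct = 0;
        long pairs = 0;
        for (int p = 0; p < scores.length; p++) {
            if (!isPositive[p]) {
                continue;
            }
            for (int q = 0; q < scores.length; q++) {
                if (isPositive[q]) {
                    continue;
                }
                pairs++;
                if (scores[p] > scores[q]) {
                    correct += 1.0; // positive ranked above negative
                } else if (scores[p] == scores[q]) {
                    correct += 0.5; // a tie counts as half a correct pair
                }
            }
        }
        // AUC is undefined when one of the classes is missing.
        return pairs == 0 ? Double.NaN : correct / pairs;
    }

    public static void main(String[] args) {
        // Labels and raw[1] of TRAIN_DATA, assuming raw[1] is the positive-class score.
        double[] scores = {0.9, 0.8, 0.7, 0.75, 0.6, 0.65, 0.55, 0.4, 0.3, 0.35, 0.2, 0.1};
        boolean[] positive = {
            true, true, true, false, false, true, true, false, false, true, false, true
        };
        System.out.println(auc(scores, positive)); // about 0.657 (= 23/35)
    }
}
```

The O(P * N) loop is only meant for small test fixtures; it is useful precisely because it shares no logic with the sort-based implementation under review.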



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has a rawPrediction
+ * column, a label column and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz', as selected by the parameter
+ * MetricsNames. The whole input is sorted in parallel so that exact, rather than approximate,
+ * metrics can be computed.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data. */
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4: <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)
+                        .map(
+                                new MapFunction<double[], Double>() {
+                                    @Override
+                                    public Double map(double[] aucVariable) {
+                                        if (aucVariable[1] > 0 && aucVariable[2] > 0) {
+                                            return (aucVariable[0]
+                                                            - 1.
+                                                                    * aucVariable[1]
+                                                                    * (aucVariable[1] + 1)
+                                                                    / 2)
+                                                    / (aucVariable[1] * aucVariable[2]);
+                                        } else {
+                                            return Double.NaN;
+                                        }
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(PARTITION_SUMMARY, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(input, new CalcBinaryMetrics());
+                        });
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+        DataStream<Row> evalResult =
+                metrics.map(
+                        new MapFunction<Map<String, Double>, Row>() {
+                            @Override
+                            public Row map(Map<String, Double> value) {
+                                Row ret = new Row(metricsNames.length);
+                                for (int i = 0; i < metricsNames.length; ++i) {
+                                    ret.setField(i, value.get(metricsNames[i]));
+                                }
+                                return ret;
+                            }
+                        },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            if (summary == null) {
+                summary =
+                        new BinarySummary(
+                                getRuntimeContext().getIndexOfThisSubtask(),
+                                -Double.MAX_VALUE,
+                                0,
+                                0);
+            }
+            updateBinarySummary(summary, streamRecord.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            summaryState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "summaryState",
+                                            TypeInformation.of(BinarySummary.class)));
+            summary =
+                    OperatorStateUtils.getUniqueElement(summaryState, "summaryState").orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            summaryState.clear();
+            if (summary != null) {
+                summaryState.add(summary);
+            }
+        }
+    }
+
+    private static class CalcAucOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<double[], double[]>, BoundedOneInput {
+        private ListState<double[]> aucVariableState;
+        private double[] aucVariable;
+
+        @Override
+        public void endInput() {
+            if (aucVariable != null) {
+                output.collect(new StreamRecord<>(aucVariable));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<double[]> streamRecord) {
+            if (aucVariable == null) {
+                aucVariable = streamRecord.getValue();
+            } else {
+                double[] tmpAucVar = streamRecord.getValue();
+                aucVariable[0] += tmpAucVar[0];
+                aucVariable[1] += tmpAucVar[1];
+                aucVariable[2] += tmpAucVar[2];
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            aucVariableState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "aucState", TypeInformation.of(double[].class)));
+            aucVariable =
+                    OperatorStateUtils.getUniqueElement(aucVariableState, "aucState").orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            aucVariableState.clear();
+            if (aucVariable != null) {
+                aucVariableState.add(aucVariable);
+            }
+        }
+    }
+
+    /** Merges the locally calculated metrics and outputs the merged result. */
+    private static class MergeMetrics
+            implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> {
+        private static final long serialVersionUID = 463407033215369847L;
+
+        @Override
+        public void mapPartition(
+                Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) {
+            Iterator<BinaryMetrics> iter = values.iterator();
+            BinaryMetrics reduceMetrics = iter.next();
+            while (iter.hasNext()) {
+                reduceMetrics = reduceMetrics.merge(iter.next());
+            }
+            Map<String, Double> map = new HashMap<>();
+            map.put(AREA_UNDER_ROC, reduceMetrics.areaUnderROC);
+            map.put(AREA_UNDER_PR, reduceMetrics.areaUnderPR);
+            map.put(AREA_UNDER_LORENZ, reduceMetrics.areaUnderLorenz);
+            map.put(KS, reduceMetrics.ks);
+            out.collect(map);
+        }
+    }
+
+    private static class CalcBinaryMetrics
+            extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, BinaryMetrics> {
+        private static final long serialVersionUID = 5680342197308160013L;
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple3<Double, Boolean, Double>> iterable,
+                Collector<BinaryMetrics> collector) {
+
+            List<BinarySummary> statistics =
+                    getRuntimeContext().getBroadcastVariable(PARTITION_SUMMARY);
+            Tuple2<Boolean, long[]> t =
+                    reduceBinarySummary(statistics, getRuntimeContext().getIndexOfThisSubtask());
+            long[] countValues = t.f1;
+
+            double areaUnderROC =
+                    getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC).get(0);
+            long totalTrue = countValues[2];
+            long totalFalse = countValues[3];
+            if (totalTrue == 0) {
+                System.out.println("There is no positive sample in data!");

Review Comment:
   It might be better to avoid `println` here; instead, we can use `LOG` to record such information.
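The AUC stage quoted above (keyBy on the score, the per-key window that multiplies the average order by the positive sum, then `CalcAucOperator` and the final map) appears to follow the rank-sum (Mann-Whitney) formula: records with equal scores share the average of their ranks, and AUC = (sum of positive ranks - P(P+1)/2) / (P * N). A minimal single-machine, unweighted sketch of that formula follows. `RankSumAuc` is a hypothetical name, and the sketch assumes ranks are ascending (rank 1 is the lowest score); `CalcSampleOrders` is not shown in this hunk, so the direction of the orders it assigns is an assumption.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical helper: unweighted AUC via the rank-sum (Mann-Whitney U) formula. */
public class RankSumAuc {

    public static double auc(double[] scores, boolean[] isPositive) {
        int n = scores.length;
        Integer[] order = new Integer[n];
        for (int k = 0; k < n; k++) {
            order[k] = k;
        }
        // Sort ascending by score, so rank 1 belongs to the lowest score.
        Arrays.sort(order, Comparator.comparingDouble((Integer k) -> scores[k]));

        double positiveRankSum = 0;
        long numPos = 0;
        int start = 0;
        while (start < n) {
            // Find the run [start, end) of records sharing the same score.
            int end = start;
            while (end < n && scores[order[end]] == scores[order[start]]) {
                end++;
            }
            // The run covers 1-based ranks start+1 .. end; tied records share the average.
            double avgRank = (start + 1 + end) / 2.0;
            for (int k = start; k < end; k++) {
                if (isPositive[order[k]]) {
                    positiveRankSum += avgRank;
                    numPos++;
                }
            }
            start = end;
        }
        long numNeg = n - numPos;
        if (numPos == 0 || numNeg == 0) {
            return Double.NaN; // mirrors the NaN branch when one class is absent
        }
        return (positiveRankSum - numPos * (numPos + 1) / 2.0) / ((double) numPos * numNeg);
    }

    public static void main(String[] args) {
        // Labels and raw[1] of TRAIN_DATA_WITH_MULTI_SCORE (raw[1] assumed positive-class score).
        double[] scores = {0.9, 0.9, 0.9, 0.75, 0.6, 0.9, 0.9, 0.4, 0.3, 0.9, 0.2, 0.1};
        boolean[] positive = {
            true, true, true, false, false, true, true, false, false, true, false, true
        };
        System.out.println(auc(scores, positive)); // about 0.857 (= 30/35)
    }
}
```

On this data the result, 30/35, equals the second entry of `EXPECTED_DATA_M` in the test. Averaging ranks within a tie group is what lets the distributed version key by score and reduce each group independently.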



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                            Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                            Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                            Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                            Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1))));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1))));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("label", "raw");
+        trainDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_MULTI_SCORE))
+                        .as("label", "raw");
+        trainDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_WEIGHT))
+                        .as("label", "raw", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")
+                        .setRawPredictionCol("raw");
+        BinaryClassificationEvaluator loadedEval =
+                StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath());
+        Table evalResult = loadedEval.transform(trainDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), 1.0e-5);
+        }
+    }
+
+    @Test
+    public void testTransform() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")
+                        .setRawPredictionCol("raw");
+        Table evalResult = eval.transform(trainDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), 1.0e-5);

Review Comment:
   Shall we also check the output column names, instead of just the value at each index?
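   One way to act on this suggestion, sketched in plain Java without Flink so it runs standalone: pair each output value with its column name and compare it against expected values keyed by metric name. The class `MetricsByNameCheck` and its `checkMetricsByName` helper are hypothetical. In the actual test, the names would come from `evalResult.getResolvedSchema().getColumnNames()` and the values from the collected `Row`.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical, Flink-free sketch: check evaluator output by column name, not position. */
public class MetricsByNameCheck {

    /**
     * Pairs each column name with the value at the same position and compares it with the
     * expected value for that name. Throws AssertionError if the columns are not exactly the
     * expected metric names, or if any value is outside the tolerance.
     */
    static void checkMetricsByName(
            List<String> columnNames, double[] values, Map<String, Double> expected, double tol) {
        if (columnNames.size() != values.length) {
            throw new AssertionError(
                    "got " + columnNames.size() + " column names but " + values.length + " values");
        }
        // Same size plus containsAll means the columns are exactly the expected metric names.
        if (columnNames.size() != expected.size() || !columnNames.containsAll(expected.keySet())) {
            throw new AssertionError("unexpected output columns: " + columnNames);
        }
        for (int i = 0; i < values.length; i++) {
            String name = columnNames.get(i);
            double want = expected.get(name);
            if (Math.abs(want - values[i]) > tol) {
                throw new AssertionError(name + ": expected " + want + " but got " + values[i]);
            }
        }
    }

    public static void main(String[] args) {
        // Expected values copied from EXPECTED_DATA in the test under review.
        Map<String, Double> expected = new HashMap<>();
        expected.put("areaUnderPR", 0.7691481137909708);
        expected.put("KS", 0.3714285714285714);
        expected.put("areaUnderROC", 0.6571428571428571);

        checkMetricsByName(
                Arrays.asList("areaUnderPR", "KS", "areaUnderROC"),
                new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571},
                expected,
                1.0e-5);
        System.out.println("column names and values match");
    }
}
```

   Keying the comparison by name means a reordered schema is still checked correctly, while a renamed or missing metric column fails immediately.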



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz', as selected by the MetricsNames
+ * parameter. The whole input data is sorted in parallel so that exact (rather than approximate)
+ * metrics can be calculated.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data. */
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)
+                        .map(
+                                new MapFunction<double[], Double>() {
+                                    @Override
+                                    public Double map(double[] aucVariable) {
+                                        if (aucVariable[1] > 0 && aucVariable[2] > 0) {
+                                            return (aucVariable[0]
+                                                            - 1.
+                                                                    * aucVariable[1]
+                                                                    * (aucVariable[1] + 1)
+                                                                    / 2)
+                                                    / (aucVariable[1] * aucVariable[2]);
+                                        } else {
+                                            return Double.NaN;
+                                        }
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(PARTITION_SUMMARY, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);

Review Comment:
   I noticed that `areaUnderROC` is calculated separately from the other metrics. Since all the metrics are computed in a similar way (e.g., with a `mapPartition` or a `window().apply()`), I think a single window operator or map-partition operator could output all of them. That would greatly simplify the structure of this code. What do you think about trying this approach?
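   To make the single-operator idea concrete, here is a minimal single-machine sketch in plain Java, not Flink code: once samples are sorted by descending score, areaUnderROC, areaUnderPR and KS can all be accumulated in one sweep. The class and method names are hypothetical. It assumes the data fits in one JVM and that every weight is positive; a distributed version would still need the range-partitioned sort and per-partition prefix counts (like `partitionSummaries`) so that each partition knows the positive and negative weight ranked above it. Tied scores are treated as one threshold step, both curves use the trapezoid rule, and the PR curve is assumed to start at (recall 0, precision 1).

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical single-process sketch: ROC AUC, PR AUC and KS from one sorted sweep. */
public class SinglePassBinaryMetrics {

    /**
     * Returns {areaUnderROC, areaUnderPR, KS}. Assumes every weight is positive; returns NaN
     * for all metrics when one of the two classes is absent.
     */
    static double[] evaluate(double[] scores, boolean[] labels, double[] weights) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        // Descending score order, as in the PR's local sort.
        Arrays.sort(idx, Comparator.comparingDouble(j -> -scores[j]));

        double totalPos = 0;
        double totalNeg = 0;
        for (int i = 0; i < n; i++) {
            if (labels[i]) {
                totalPos += weights[i];
            } else {
                totalNeg += weights[i];
            }
        }
        if (totalPos <= 0 || totalNeg <= 0) {
            return new double[] {Double.NaN, Double.NaN, Double.NaN};
        }

        double tp = 0, fp = 0; // weighted counts scored at or above the current threshold
        double prevTpr = 0, prevFpr = 0, prevPrecision = 1.0; // assumed curve start points
        double roc = 0, pr = 0, ks = 0;
        int pos = 0;
        while (pos < n) {
            // All samples sharing one score form a single threshold step (handles ties).
            double threshold = scores[idx[pos]];
            while (pos < n && scores[idx[pos]] == threshold) {
                int k = idx[pos++];
                if (labels[k]) {
                    tp += weights[k];
                } else {
                    fp += weights[k];
                }
            }
            double tpr = tp / totalPos; // also the recall
            double fpr = fp / totalNeg;
            double precision = tp / (tp + fp);
            roc += (fpr - prevFpr) * (tpr + prevTpr) / 2; // trapezoid under the ROC curve
            pr += (tpr - prevTpr) * (precision + prevPrecision) / 2; // trapezoid under PR
            ks = Math.max(ks, tpr - fpr); // Kolmogorov-Smirnov statistic
            prevTpr = tpr;
            prevFpr = fpr;
            prevPrecision = precision;
        }
        return new double[] {roc, pr, ks};
    }

    public static void main(String[] args) {
        // TRAIN_DATA from the test, using the second raw-vector element as the score.
        double[] scores = {0.9, 0.8, 0.7, 0.75, 0.6, 0.65, 0.55, 0.4, 0.3, 0.35, 0.2, 0.1};
        boolean[] labels = {
            true, true, true, false, false, true, true, false, false, true, false, true
        };
        double[] weights = new double[scores.length];
        Arrays.fill(weights, 1.0);
        double[] m = evaluate(scores, labels, weights);
        System.out.printf("areaUnderROC=%.6f areaUnderPR=%.6f KS=%.6f%n", m[0], m[1], m[2]);
    }
}
```

   On `TRAIN_DATA`, taking the second element of each raw vector as the score (an assumption about how the evaluator reads `rawPrediction`), the three values from this sketch agree with `EXPECTED_DATA` (areaUnderPR, KS, areaUnderROC) within the test's 1e-5 tolerance. That suggests one pass could reproduce the current unweighted outputs; the weighted case was not checked.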





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r858231618


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                            Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                            Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                            Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                            Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1))));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1))));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("label", "raw");
+        trainDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_MULTI_SCORE))
+                        .as("label", "raw");
+        trainDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_WEIGHT))
+                        .as("label", "raw", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {

Review Comment:
   I have refined the naming format. This algorithm differs from the other algorithms, so its type also has some different properties.





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r859473199


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("label", "raw");
+        trainDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_MULTI_SCORE))
+                        .as("label", "raw");
+        trainDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_WEIGHT))
+                        .as("label", "raw", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")
+                        .setRawPredictionCol("raw");
+        BinaryClassificationEvaluator loadedEval =
+                StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath());
+        Table evalResult = loadedEval.transform(trainDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "KS", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), 1.0e-5);
+        }
+    }
+
+    @Test
+    public void testEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")
+                        .setRawPredictionCol("raw");
+        Table evalResult = eval.transform(trainDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "KS", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), 1.0e-5);
+        }
+    }
+
+    @Test
+    public void testEvaluateWithMultiScore() throws Exception {

Review Comment:
   `AccumulateMultiScore` processes samples that share the same score: when more than one sample has the same score, it handles them differently.
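   For comparison, a minimal plain-Java sketch of the unweighted rank-sum (Mann-Whitney) form of AUC, in which samples sharing a score receive the average of the ranks they occupy, so each tied positive/negative pair contributes 1/2. The final step has the same shape as the `(aucVariable[0] - aucVariable[1] * (aucVariable[1] + 1) / 2) / (aucVariable[1] * aucVariable[2])` map in `BinaryClassificationEvaluator`, but this is not the code of `AccumulateMultiScore`. The ascending-rank convention and the class name `RankSumAuc` are assumptions, and weights are ignored.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical sketch: unweighted rank-sum (Mann-Whitney) AUC with tie-averaged ranks. */
public class RankSumAuc {

    static double auc(double[] scores, boolean[] labels) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        // Ascending score order: rank 1 goes to the lowest score.
        Arrays.sort(idx, Comparator.comparingDouble(j -> scores[j]));

        double positiveRankSum = 0;
        long numPos = 0;
        long numNeg = 0;
        int start = 0;
        while (start < n) {
            // Sorted positions start..end-1 share one score.
            int end = start;
            while (end < n && scores[idx[end]] == scores[idx[start]]) {
                end++;
            }
            // Their 1-based ranks are start+1..end, so each gets the average (start+1+end)/2.
            double avgRank = (start + 1 + end) / 2.0;
            for (int k = start; k < end; k++) {
                if (labels[idx[k]]) {
                    positiveRankSum += avgRank;
                    numPos++;
                } else {
                    numNeg++;
                }
            }
            start = end;
        }
        if (numPos == 0 || numNeg == 0) {
            return Double.NaN; // undefined with a single class, as in the PR's final map
        }
        // Subtract the smallest possible positive rank sum, P(P+1)/2, then normalize by P*N.
        return (positiveRankSum - numPos * (numPos + 1) / 2.0) / ((double) numPos * numNeg);
    }

    public static void main(String[] args) {
        // TRAIN_DATA_WITH_MULTI_SCORE, second raw-vector element as the score:
        // six positives tie at 0.9.
        double[] scores = {0.9, 0.9, 0.9, 0.75, 0.6, 0.9, 0.9, 0.4, 0.3, 0.9, 0.2, 0.1};
        boolean[] labels = {
            true, true, true, false, false, true, true, false, false, true, false, true
        };
        System.out.println("areaUnderROC = " + auc(scores, labels));
    }
}
```

   On the tied scores of `TRAIN_DATA_WITH_MULTI_SCORE` (again taking the second raw-vector element as the score), this gives 30/35, roughly 0.857143, which equals the second entry of `EXPECTED_DATA_M`.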





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867595652


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));

Review Comment:
   If this code is removed, the algorithm fails with an error. I added it as CongZhou suggested.
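The pipeline quoted above range-partitions the data, sorts each partition locally, and then derives each sample's global order (`CalcSampleOrders`) from the broadcast per-partition summaries. Assuming each `BinarySummary` carries its partition's sample count (an assumption; its fields are not shown here), the global order is the local index plus the number of samples in all preceding partitions. The class `GlobalOrderSketch` below is hypothetical and only illustrates that prefix-sum step. It uses 0-based orders; the real operator may number from 1.

```java
/**
 * Hypothetical sketch (not the PR's CalcSampleOrders): with partitions already ordered
 * by score range and sorted locally, a sample's global order is its partition's start
 * offset plus its local index.
 */
final class GlobalOrderSketch {
    private GlobalOrderSketch() {}

    /** offsets[i] = total number of samples in partitions 0 .. i-1 (exclusive prefix sum). */
    static long[] startOffsets(long[] partitionCounts) {
        long[] offsets = new long[partitionCounts.length];
        long running = 0;
        for (int i = 0; i < partitionCounts.length; i++) {
            offsets[i] = running;
            running += partitionCounts[i];
        }
        return offsets;
    }

    /** 0-based global order of the sample at localIndex within the given partition. */
    static long globalOrder(long[] offsets, int partition, long localIndex) {
        return offsets[partition] + localIndex;
    }
}
```

Because only one count per partition has to be broadcast, each partition can assign global orders to its own samples without shuffling the data a second time.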




[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867592350


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;

Review Comment:
   A new test for this parameter is not needed. It is only an internal parameter used by partitionSort.
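For context, `NUM_SAMPLE_FOR_RANGE_PARTITION` (100 in the quoted code) appears to control how many scores are sampled to choose the range boundaries that `getBoundaryRange` broadcasts and `AppendTaskId` uses to route samples. The sketch below is a hypothetical, dependency-free illustration of sampling-based range partitioning, not the PR's implementation. `RangePartitionSketch` is an invented name, and sampling with replacement and a descending partition order (chosen to match the local sort by `-o.f0`) are assumptions.

```java
import java.util.Arrays;
import java.util.Random;

/**
 * Hypothetical sketch (not the Flink ML implementation): pick range-partition boundaries
 * from a random sample of scores, then route each score so that partition 0 holds the
 * highest scores.
 */
final class RangePartitionSketch {
    private RangePartitionSketch() {}

    /** Returns numPartitions - 1 boundaries in non-increasing order (empty for empty input). */
    static double[] boundaries(double[] scores, int numSamples, int numPartitions, long seed) {
        if (scores.length == 0 || numPartitions <= 1) {
            return new double[0];
        }
        Random random = new Random(seed);
        // Sample with replacement; at least one sample so the quantile lookup is valid.
        double[] sample = new double[Math.max(1, Math.min(numSamples, scores.length))];
        for (int i = 0; i < sample.length; i++) {
            sample[i] = scores[random.nextInt(scores.length)];
        }
        Arrays.sort(sample); // ascending
        double[] bounds = new double[numPartitions - 1];
        for (int p = 1; p < numPartitions; p++) {
            // Quantile at p / numPartitions measured from the top of the sample.
            int pos = sample.length - 1 - (int) ((long) p * sample.length / numPartitions);
            bounds[p - 1] = sample[pos];
        }
        return bounds;
    }

    /** Partition id: how many boundaries the score is less than or equal to. */
    static int partitionOf(double score, double[] descendingBounds) {
        int id = 0;
        while (id < descendingBounds.length && score <= descendingBounds[id]) {
            id++;
        }
        return id;
    }
}
```

In this sketch, equal scores always land in the same partition, since routing depends only on the score value. That property matters when ties must be grouped, as in tie-aware AUC computation. The sample size only affects how balanced the partitions are, not the correctness of the final ordering, which is consistent with treating it as an internal parameter.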




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867784038


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassificationevaluator/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassificationevaluator;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+
+/**
+ * Params of BinaryClassificationEvaluator.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    String AREA_UNDER_ROC = "areaUnderROC";
+    String AREA_UNDER_PR = "areaUnderPR";
+    String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    String KS = "ks";
+
+    /**
+     * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and
+     * 'areaUnderLorenz').
+     *
+     * <ul>
+     *   <li>areaUnderROC: the area under the receiver operating characteristic (ROC) curve.
+     *   <li>areaUnderPR: the area under the precision-recall curve.
+     *   <li>ks: Kolmogorov-Smirnov, measures the ability of the model to separate positive and
+     *       negative samples.
+     *   <li>areaUnderLorenz: the area under the lorenz curve.
+     * </ul>
+     */
+    Param<String[]> METRICS_NAMES =
+            new StringArrayParam(
+                    "metricsNames",
+                    "Names of output metrics. The array element must be 'areaUnderROC', 'areaUnderPR', 'ks' and 'areaUnderLorenz'",

Review Comment:
   nit: it might be better to remove "The array element ...", since other array-typed parameters, like `HasHandleInvalid` or `HasBatchStrategy`, do not contain this sentence.
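If the enumeration is dropped from the description, the allowed values can still be enforced by a validator (the quoted file already imports `ParamValidators`). The sketch below is a hypothetical, plain-Java version of an "every element must be one of the allowed values" check. `MetricsNamesCheck` and its `isSubSet` method are invented for illustration and are not the validator added in this PR.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Predicate;

/**
 * Hypothetical sketch: accepts an array parameter value only if it is non-null and every
 * element is one of the allowed values.
 */
final class MetricsNamesCheck {
    private MetricsNamesCheck() {}

    static <T> Predicate<T[]> isSubSet(T[] allowedValues) {
        Set<T> allowed = new HashSet<>(Arrays.asList(allowedValues));
        return value -> value != null && allowed.containsAll(Arrays.asList(value));
    }
}
```

A case-sensitive check like this one accepts `"ks"` (the value of the `KS` constant) but rejects `"KS"`, so the spelling used in Javadoc, tests, and constants needs to agree.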



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassificationevaluator/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassificationevaluator;

Review Comment:
   nit: Maybe `binaryclassification` makes a better package name, unless we have something like `binaryclassificationregressor` or `binaryclassificationscaler`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassificationevaluator/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassificationevaluator;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+
+/**
+ * Params of BinaryClassificationEvaluator.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    String AREA_UNDER_ROC = "areaUnderROC";
+    String AREA_UNDER_PR = "areaUnderPR";
+    String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    String KS = "ks";
+
+    /**
+     * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and

Review Comment:
   nit: Javadocs should start with an uppercase letter.
   
   Besides, "metric names in evaluation" is ambiguous. "Param for supported metric names" or "Param for supported metric names in binary classification evaluation" looks good to me.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;

Review Comment:
   If `BinaryClassificationEvaluator` is in `org.apache.flink.ml.evaluation.binaryclassificationevaluator`, `BinaryClassificationEvaluatorTest` should be in `org.apache.flink.ml.evaluation.binaryclassificationevaluator` or `org.apache.flink.ml.evaluation`.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {

Review Comment:
   Spark's BinaryClassificationEvaluator has test cases for the following situations. Could you please add corresponding test cases?
   - `rawPredictionCol`'s data type is double.
   - `label`'s data type is any numeric type, such as int, short, or decimal.
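The requested tests exercise how one input row is normalized into a score and a positive/negative flag. The sketch below is a hypothetical, dependency-free illustration of that normalization, not the PR's `ParseSample`. `SampleParsingSketch` is an invented name. A `double[]` stands in for Flink ML's `DenseVector`. Two conventions are assumptions: that element 1 of a length-2 vector is the positive-class score, and that labels must be 0 or 1.

```java
/**
 * Hypothetical sketch: normalize a raw prediction (double or length-2 vector) and a
 * numeric label of any boxed type into the values a binary evaluator needs.
 */
final class SampleParsingSketch {
    private SampleParsingSketch() {}

    /** Positive-class score: the number itself, or element 1 of a length-2 vector. */
    static double score(Object rawPrediction) {
        if (rawPrediction instanceof Number) {
            return ((Number) rawPrediction).doubleValue();
        }
        if (rawPrediction instanceof double[]) {
            double[] values = (double[]) rawPrediction;
            if (values.length != 2) {
                throw new IllegalArgumentException(
                        "Expected a length-2 vector, got length " + values.length);
            }
            return values[1];
        }
        throw new IllegalArgumentException("Unsupported rawPrediction: " + rawPrediction);
    }

    /** Works for Integer, Short, Long, Double, BigDecimal, ...; only 0 and 1 are accepted. */
    static boolean isPositive(Object label) {
        if (!(label instanceof Number)) {
            throw new IllegalArgumentException("Label must be numeric: " + label);
        }
        double value = ((Number) label).doubleValue();
        if (value == 1.0) {
            return true;
        }
        if (value == 0.0) {
            return false;
        }
        throw new IllegalArgumentException("Label must be 0 or 1: " + label);
    }
}
```

Going through `Number.doubleValue()` is what lets a single code path cover int, short, and decimal labels. A test per type then only has to confirm that the table-to-stream conversion delivers each as a `Number`.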



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassificationevaluator/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,730 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassificationevaluator;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * An Estimator which calculates the evaluation metrics for binary classification. The input data
+ * has columns rawPrediction, label and an optional weight column. The rawPrediction can be of type
+ * double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of
+ * raw predictions, scores, or label probabilities). The output may contain different metrics which
+ * will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAreaUnderROCVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                DataStreamUtils.reduce(
+                        localAreaUnderROCVariable,
+                        (ReduceFunction<double[]>)
+                                (t1, t2) -> {
+                                    t2[0] += t1[0];
+                                    t2[1] += t1[1];
+                                    t2[2] += t1[2];
+                                    return t2;
+                                });
+
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)
+                                                / (value[1] * value[2]);
+                                    } else {
+                                        return Double.NaN;
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(partitionSummariesKey, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalcBinaryMetrics(partitionSummariesKey));
+                        });
+
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        Arrays.fill(metricTypes, Types.DOUBLE());
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+
+        DataStream<Row> evalResult =
+                metrics.map(
+                        (MapFunction<Map<String, Double>, Row>)
+                                value -> {
+                                    Row ret = new Row(metricsNames.length);
+                                    for (int i = 0; i < metricsNames.length; ++i) {
+                                        ret.setField(i, value.get(metricsNames[i]));
+                                    }
+                                    return ret;
+                                },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    /** Updates variables for calculating AreaUnderROC. */
+    private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>,
+                    BoundedOneInput {
+        private ListState<double[]> accValueState;
+        double[] accValue;
+        double score;
+
+        @Override
+        public void endInput() {
+            if (accValue != null) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+            }
+        }
+
+        @Override
+        public void processElement(
+                StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) {
+            Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue();
+            if (accValue == null) {
+                accValue = new double[4];
+                score = t.f0;

Review Comment:
   `score` is not checkpointed or restored. After a failover, stream records with different scores might be accumulated into the same `accValue`.
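
   One possible way to address this, sketched under assumptions: keep `score` in the checkpointed array (here as a fifth slot) so that `initializeState` restores it together with the partial sums. The class below is a plain-Java model with no Flink dependencies. `ScoreGroupAccumulator`, `snapshot()` and `restore()` are hypothetical names standing in for the operator, `snapshotState` and `initializeState`. The sketch also updates the tracked score whenever a new score group starts.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Plain-Java model (hypothetical, no Flink dependencies) of the per-score accumulation done by
 * AccumulateMultiScoreOperator. The current score lives in the same array that is snapshotted,
 * so a restored instance knows which score group it was accumulating.
 */
public class ScoreGroupAccumulator {
    // Layout: [orderSum, count, positiveWeight, negativeWeight, currentScore].
    private static final int SCORE_INDEX = 4;

    private double[] accValue; // null until the first record arrives
    private final List<double[]> emitted = new ArrayList<>();

    public void processElement(double score, long order, boolean isPositive, double weight) {
        if (accValue == null) {
            accValue = new double[5];
            accValue[SCORE_INDEX] = score;
        } else if (accValue[SCORE_INDEX] != score) {
            // A new score group starts: emit the finished group, reset, and track the new score.
            flush();
            Arrays.fill(accValue, 0.0);
            accValue[SCORE_INDEX] = score;
        }
        accValue[0] += order;
        accValue[1] += 1.0;
        if (isPositive) {
            accValue[2] += weight;
        } else {
            accValue[3] += weight;
        }
    }

    public void endInput() {
        if (accValue != null) {
            flush();
        }
    }

    /** Copy of what snapshotState() would add to the ListState. */
    public double[] snapshot() {
        return accValue == null ? null : accValue.clone();
    }

    /** Rebuilds an accumulator from a snapshot, as initializeState() would after a failover. */
    public static ScoreGroupAccumulator restore(double[] snapshot) {
        ScoreGroupAccumulator acc = new ScoreGroupAccumulator();
        acc.accValue = snapshot == null ? null : snapshot.clone();
        return acc;
    }

    public List<double[]> getEmitted() {
        return emitted;
    }

    // Emits <averageOrder * positiveWeight, positiveWeight, negativeWeight>, as the PR does.
    private void flush() {
        emitted.add(
                new double[] {accValue[0] / accValue[1] * accValue[2], accValue[2], accValue[3]});
    }
}
```

   In the operator itself this would roughly mean allocating `new double[5]` in `processElement` and reading the score back from the restored array in `initializeState`; the triple emitted in `endInput` would stay unchanged.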



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.evaluation.binaryclassificationevaluator.BinaryClassificationEvaluator;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+    private Table inputDataTableWithMultiScore;
+    private Table inputDataTableWithWeight;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.8571428571428571, 0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double EPS = 1.0e-5;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(3);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        inputDataTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction");
+        inputDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE))
+                        .as("label", "rawPrediction");
+        inputDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT))
+                        .as("label", "rawPrediction", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")

Review Comment:
   nit: It might be better to use the constants that are already defined, for example `setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC)`.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.evaluation.binaryclassificationevaluator.BinaryClassificationEvaluator;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+    private Table inputDataTableWithMultiScore;
+    private Table inputDataTableWithWeight;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.8571428571428571, 0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double EPS = 1.0e-5;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(3);

Review Comment:
   I tried changing this to `env.setParallelism(20);`, and these test cases failed. Could you please fix this problem and add a test similar to `LinearRegressionTest.testMoreSubtaskThanData`?
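
   For a test like `testMoreSubtaskThanData`, the expected areaUnderROC should not depend on parallelism, so a single-threaded reference can serve as an oracle. The sketch below is an assumption, not the PR's code: it computes AUC with the rank-sum (Mann-Whitney U) formula, AUC = (R+ - P(P+1)/2) / (P * N), where R+ is the sum of the positives' ascending ranks and tied scores share their average rank. The `areaUnderROC` map in the PR has the same shape if `value[0]` is read as the positive rank sum and `value[1]`, `value[2]` as the positive and negative totals. It assumes unit weights, and `RankSumAuc` is a hypothetical name. For `INPUT_DATA` it yields 23/35 (about 0.657), which matches the last entry of `EXPECTED_DATA`.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical single-threaded reference for areaUnderROC (unit weights, average ranks on ties). */
public class RankSumAuc {

    /** Returns AUC, or NaN if there are no positive or no negative samples. */
    public static double areaUnderROC(double[] scores, boolean[] isPositive) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        // Ascending by score, so rank 1 is the lowest score.
        Arrays.sort(idx, Comparator.comparingDouble(i -> scores[i]));

        double positiveRankSum = 0.0;
        double numPositive = 0.0;
        double numNegative = 0.0;
        int start = 0;
        while (start < n) {
            // Collect the group [start, end) of samples sharing the same score.
            int end = start;
            int groupPositives = 0;
            while (end < n && scores[idx[end]] == scores[idx[start]]) {
                if (isPositive[idx[end]]) {
                    groupPositives++;
                } else {
                    numNegative++;
                }
                end++;
            }
            // The group occupies ranks start+1 .. end; each member gets their average.
            double averageRank = (start + 1 + end) / 2.0;
            positiveRankSum += averageRank * groupPositives;
            numPositive += groupPositives;
            start = end;
        }
        if (numPositive == 0 || numNegative == 0) {
            return Double.NaN;
        }
        return (positiveRankSum - numPositive * (numPositive + 1) / 2)
                / (numPositive * numNegative);
    }

    public static void main(String[] args) {
        // Labels and probabilities of label 1 taken from INPUT_DATA in the test above.
        double[] scores = {0.9, 0.8, 0.7, 0.75, 0.6, 0.65, 0.55, 0.4, 0.3, 0.35, 0.2, 0.1};
        boolean[] labels = {
            true, true, true, false, false, true, true, false, false, true, false, true
        };
        System.out.println(areaUnderROC(scores, labels));
    }
}
```

   Because the oracle sees all rows at once, comparing it with the evaluator's output at parallelism 1, 3 and 20 would separate partitioning bugs from metric bugs.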



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassificationevaluator/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,730 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassificationevaluator;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * An AlgoOperator which calculates the evaluation metrics for binary classification. The input
+ * data has columns rawPrediction, label and an optional weight column. The rawPrediction can be of
+ * type double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector
+ * of raw predictions, scores, or label probabilities). The output may contain different metrics,
+ * which are selected by the parameter metricsNames. See {@link
+ * BinaryClassificationEvaluatorParams}.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score. */
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAreaUnderROCVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                DataStreamUtils.reduce(
+                        localAreaUnderROCVariable,
+                        (ReduceFunction<double[]>)
+                                (t1, t2) -> {
+                                    t2[0] += t1[0];
+                                    t2[1] += t1[1];
+                                    t2[2] += t1[2];
+                                    return t2;
+                                });
+
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)
+                                                / (value[1] * value[2]);
+                                    } else {
+                                        return Double.NaN;
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(partitionSummariesKey, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalcBinaryMetrics(partitionSummariesKey));
+                        });
+
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        Arrays.fill(metricTypes, Types.DOUBLE());
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+
+        DataStream<Row> evalResult =
+                metrics.map(
+                        (MapFunction<Map<String, Double>, Row>)
+                                value -> {
+                                    Row ret = new Row(metricsNames.length);
+                                    for (int i = 0; i < metricsNames.length; ++i) {
+                                        ret.setField(i, value.get(metricsNames[i]));
+                                    }
+                                    return ret;
+                                },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    /** Updates variables for calculating AreaUnderROC. */
+    private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>,
+                    BoundedOneInput {
+        private ListState<double[]> accValueState;
+        double[] accValue;
+        double score;
+
+        @Override
+        public void endInput() {
+            if (accValue != null) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+            }
+        }
+
+        @Override
+        public void processElement(
+                StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) {
+            Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue();
+            if (accValue == null) {
+                accValue = new double[4];
+                score = t.f0;
+            } else if (score != t.f0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+                Arrays.fill(accValue, 0.0);
+            }
+            accValue[0] += t.f1;
+            accValue[1] += 1.0;
+            if (t.f2) {
+                accValue[2] += t.f3;
+            } else {
+                accValue[3] += t.f3;
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            accValueState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "accValueState", TypeInformation.of(double[].class)));
+            accValue =
+                    OperatorStateUtils.getUniqueElement(accValueState, "valueState").orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            accValueState.clear();
+            if (accValue != null) {
+                accValueState.add(accValue);
+            }
+        }
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            updateBinarySummary(summary, streamRecord.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            summaryState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "summaryState",
+                                            TypeInformation.of(BinarySummary.class)));
+            summary =
+                    OperatorStateUtils.getUniqueElement(summaryState, "summaryState")
+                            .orElse(
+                                    new BinarySummary(
+                                            getRuntimeContext().getIndexOfThisSubtask(),
+                                            -Double.MAX_VALUE,
+                                            0,
+                                            0));
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            summaryState.clear();
+            if (summary != null) {
+                summaryState.add(summary);
+            }
+        }
+    }
+
+    /** Merges the metrics calculated locally and output metrics data. */
+    private static class MergeMetrics
+            implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> {
+        @Override
+        public void mapPartition(
+                Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) {
+            Iterator<BinaryMetrics> iter = values.iterator();
+            BinaryMetrics reduceMetrics = iter.next();
+            while (iter.hasNext()) {
+                reduceMetrics = reduceMetrics.merge(iter.next());
+            }
+            Map<String, Double> map = new HashMap<>();
+            map.put(AREA_UNDER_ROC, reduceMetrics.areaUnderROC);
+            map.put(AREA_UNDER_PR, reduceMetrics.areaUnderPR);
+            map.put(AREA_UNDER_LORENZ, reduceMetrics.areaUnderLorenz);
+            map.put(KS, reduceMetrics.ks);
+            out.collect(map);
+        }
+    }
+
+    private static class CalcBinaryMetrics
+            extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, BinaryMetrics> {
+        private final String partitionSummariesKey;
+
+        public CalcBinaryMetrics(String partitionSummariesKey) {
+            this.partitionSummariesKey = partitionSummariesKey;
+        }
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple3<Double, Boolean, Double>> iterable,
+                Collector<BinaryMetrics> collector) {
+
+            List<BinarySummary> statistics =
+                    getRuntimeContext().getBroadcastVariable(partitionSummariesKey);
+            long[] countValues =
+                    reduceBinarySummary(statistics, getRuntimeContext().getIndexOfThisSubtask());
+
+            double areaUnderROC =
+                    getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC).get(0);
+            long totalTrue = countValues[2];
+            long totalFalse = countValues[3];
+            if (totalTrue == 0) {
+                LOG.warn("There is no positive sample in data!");
+            }
+            if (totalFalse == 0) {
+                LOG.warn("There is no negative sample in data!");
+            }
+
+            BinaryMetrics metrics = new BinaryMetrics(0L, areaUnderROC);
+            double[] tprFprPrecision = new double[4];
+            for (Tuple3<Double, Boolean, Double> t3 : iterable) {
+                updateBinaryMetrics(t3, metrics, countValues, tprFprPrecision);
+            }
+            collector.collect(metrics);
+        }
+    }
+
+    private static void updateBinaryMetrics(
+            Tuple3<Double, Boolean, Double> cur,
+            BinaryMetrics binaryMetrics,
+            long[] countValues,
+            double[] recordValues) {
+        if (binaryMetrics.count == 0) {
+            recordValues[0] = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
+            recordValues[1] = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+            recordValues[2] =
+                    countValues[0] + countValues[1] == 0
+                            ? 1.0
+                            : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
+            recordValues[3] =
+                    1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+        }
+
+        binaryMetrics.count++;
+        if (cur.f1) {
+            countValues[0]++;
+        } else {
+            countValues[1]++;
+        }
+
+        double tpr = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
+        double fpr = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+        double precision =
+                countValues[0] + countValues[1] == 0
+                        ? 1.0
+                        : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
+        double positiveRate =
+                1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+
+        binaryMetrics.areaUnderLorenz +=
+                ((positiveRate - recordValues[3]) * (tpr + recordValues[0]) / 2);
+        binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision + recordValues[2]) / 2);
+        binaryMetrics.ks = Math.max(Math.abs(fpr - tpr), binaryMetrics.ks);
+
+        recordValues[0] = tpr;
+        recordValues[1] = fpr;
+        recordValues[2] = precision;
+        recordValues[3] = positiveRate;
+    }
+
+    /**
+     * For each sample, calculates its score order among all samples. The sample with minimum score
+     * has order 1, while the sample with maximum score has order samples.
+     *
+     * <p>Input is a dataset of tuple (score, is real positive, wight), output is a dataset of tuple

Review Comment:
   nit: "wight" -> "weight".
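The javadoc above ranks samples by score, with the minimum score getting order 1. The accumulator earlier in the hunk sums the order over runs of samples that share a score, divides by the run length, and multiplies by the positive weight. That pattern looks like the tie-averaged rank sum behind the Mann-Whitney form of area under ROC. Below is a minimal single-machine sketch of that technique, assuming unweighted samples. `RankAuc` is a hypothetical helper, not part of the PR's API, and the sketch does not claim the PR computes AUC exactly this way.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical helper: rank-based (Mann-Whitney) AUC with average ranks for tied scores. */
final class RankAuc {
    /**
     * Ranks samples by ascending score (minimum score gets rank 1). Tied scores share the average
     * of their ranks. AUC is then (sum of positive ranks - nPos * (nPos + 1) / 2) / (nPos * nNeg).
     */
    static double areaUnderROC(double[] scores, boolean[] isPositive) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        Arrays.sort(idx, Comparator.comparingDouble(i -> scores[i]));

        double positiveRankSum = 0.0;
        long numPositive = 0;
        int start = 0;
        while (start < n) {
            // Find the end (exclusive) of the run of samples sharing the same score.
            int end = start;
            while (end < n && scores[idx[end]] == scores[idx[start]]) {
                end++;
            }
            // Positions start..end-1 (0-based) hold ranks start+1..end (1-based); average them.
            double averageRank = (start + 1 + end) / 2.0;
            for (int k = start; k < end; k++) {
                if (isPositive[idx[k]]) {
                    positiveRankSum += averageRank;
                    numPositive++;
                }
            }
            start = end;
        }
        long numNegative = n - numPositive;
        if (numPositive == 0 || numNegative == 0) {
            return Double.NaN; // AUC is undefined unless both classes are present.
        }
        return (positiveRankSum - numPositive * (numPositive + 1) / 2.0)
                / ((double) numPositive * numNegative);
    }
}
```

Take `INPUT_DATA` from the test file quoted later in this thread and use the second vector element as the score. The sketch then returns 23/35 (about 0.657142857), which matches the `areaUnderROC` entry of `EXPECTED_DATA`. Average ranks make a tied positive/negative pair contribute exactly one half, as the pairwise definition of AUC requires.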



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.evaluation.binaryclassificationevaluator.BinaryClassificationEvaluator;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+    private Table inputDataTableWithMultiScore;
+    private Table inputDataTableWithWeight;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.8571428571428571, 0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double EPS = 1.0e-5;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(3);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        inputDataTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction");
+        inputDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE))
+                        .as("label", "rawPrediction");
+        inputDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT))
+                        .as("label", "rawPrediction", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "ks", "areaUnderROC");

Review Comment:
   `areaUnderLorenz` is not tested in these test cases yet.
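One way to obtain an expected value for `areaUnderLorenz` is a small single-machine reference. The sketch below re-implements the trapezoid accumulation from the quoted `updateBinaryMetrics`. Its assumptions are that samples are unweighted, are visited from the highest score to the lowest, and are not grouped by tied score. `CurveMetricsReference` is hypothetical and is not the evaluator's code.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical reference for areaUnderPR, KS and areaUnderLorenz on unweighted samples. */
final class CurveMetricsReference {
    /** Returns {areaUnderPR, ks, areaUnderLorenz}. */
    static double[] compute(double[] scores, boolean[] isPositive) {
        int n = scores.length;
        long totalTrue = 0;
        for (boolean p : isPositive) {
            if (p) {
                totalTrue++;
            }
        }
        long totalFalse = n - totalTrue;

        // Visit samples from the highest score to the lowest; the sort is stable for ties.
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> scores[i]).reversed());

        long tp = 0;
        long fp = 0;
        // Curve point before any sample is predicted positive (mirrors the quoted code).
        double prevTpr = totalTrue == 0 ? 1.0 : 0.0;
        double prevPrecision = 1.0;
        double prevPositiveRate = 0.0;
        double areaUnderPR = 0.0;
        double areaUnderLorenz = 0.0;
        double ks = 0.0;
        for (int idx : order) {
            if (isPositive[idx]) {
                tp++;
            } else {
                fp++;
            }
            double tpr = totalTrue == 0 ? 1.0 : (double) tp / totalTrue;
            double fpr = totalFalse == 0 ? 1.0 : (double) fp / totalFalse;
            double precision = (double) tp / (tp + fp); // tp + fp >= 1 here
            double positiveRate = (double) (tp + fp) / n;

            // Trapezoid rule between consecutive curve points.
            areaUnderPR += (tpr - prevTpr) * (precision + prevPrecision) / 2;
            areaUnderLorenz += (positiveRate - prevPositiveRate) * (tpr + prevTpr) / 2;
            ks = Math.max(ks, Math.abs(fpr - tpr));

            prevTpr = tpr;
            prevPrecision = precision;
            prevPositiveRate = positiveRate;
        }
        return new double[] {areaUnderPR, ks, areaUnderLorenz};
    }
}
```

On `INPUT_DATA`, with the second vector element as the score, the sketch returns the following:
- areaUnderPR is about 0.7691481. It agrees with the first entry of `EXPECTED_DATA`.
- ks is 13/35. It agrees with the second entry of `EXPECTED_DATA`.
- areaUnderLorenz is 95/168, about 0.5654762.

The last value comes only from this sketch. Confirm it against the evaluator before using it as an expected value in a test.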



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {

Review Comment:
   Maybe we can look into performance optimization opportunities in a follow-up after this PR. Please feel free to close this conversation for now.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;

Review Comment:
   Got it.
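The name `NUM_SAMPLE_FOR_RANGE_PARTITION` and the `java.util.Random` import suggest something about how the evaluator partitions data. One plausible reading is that it draws a fixed-size random sample of scores to choose range-partition boundaries. The quoted hunks do not show how the constant is actually used, so that reading is an assumption. Below is a minimal sketch of sample-based range partitioning; `RangePartitionSketch`, `splitPoints` and `partitionOf` are hypothetical names.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch of sample-based range partitioning of scores. */
final class RangePartitionSketch {
    /** Picks numPartitions - 1 ascending split points from a random sample of the scores. */
    static double[] splitPoints(double[] scores, int numSamples, int numPartitions, long seed) {
        Random random = new Random(seed);
        double[] sample = new double[numSamples];
        for (int i = 0; i < numSamples; i++) {
            sample[i] = scores[random.nextInt(scores.length)]; // sampling with replacement
        }
        Arrays.sort(sample);
        double[] splits = new double[numPartitions - 1];
        for (int p = 1; p < numPartitions; p++) {
            // Evenly spaced quantiles of the sorted sample.
            splits[p - 1] = sample[p * numSamples / numPartitions];
        }
        return splits;
    }

    /** Returns the number of split points <= score, i.e. the index of the score's range. */
    static int partitionOf(double score, double[] splits) {
        int lo = 0;
        int hi = splits.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (splits[mid] <= score) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }
}
```

Sampling keeps boundary selection cheap because only a small, fixed number of values is sorted. The upper-bound search sends equal scores to the same partition even when split points repeat.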



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,211 @@
+    @Test
+    public void testSaveLoadAndEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "ks", "areaUnderROC");
+        BinaryClassificationEvaluator loadedEval =
+                StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath());
+        Table evalResult = loadedEval.transform(inputDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "ks", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());

Review Comment:
   nit: `List<Row> results = IteratorUtils.toList(evalResult.execute().collect());` is enough. `tEnv.toDataStream` is unnecessary.





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r857052087


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz', as selected by the parameter
+ * MetricsNames. The whole input is sorted in parallel, so the metrics are computed exactly rather
+ * than approximated.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data. */
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })

Review Comment:
   Shall we move anonymous functions like this into private static classes and add JavaDocs? That might help improve readability.
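As a minimal sketch of this suggestion, the windowed lambda above could become a named static class with Javadoc. To keep the example compilable with the JDK alone, it leaves out Flink's `WindowFunction` interface and uses a hypothetical `ScoredSample` holder class in place of `Tuple4<score, order, isPositive, weight>`. In the real operator, a class like the hypothetical `TiedScoreAggregator` would implement `WindowFunction<Tuple4<Double, Long, Boolean, Double>, double[], Double, TimeWindow>` and call this logic from `apply`.

```java
/** Hypothetical stand-in for the Tuple4<score, order, isPositive, weight> records. */
final class ScoredSample {
    final double score;
    final long order;
    final boolean isPositive;
    final double weight;

    ScoredSample(double score, long order, boolean isPositive, double weight) {
        this.score = score;
        this.order = order;
        this.isPositive = isPositive;
        this.weight = weight;
    }
}

/**
 * Hypothetical named replacement for the anonymous window function: aggregates all samples that
 * share one score (one key of the keyed window).
 *
 * <p>Returns {averageOrder * positiveWeight, positiveWeight, negativeWeight}, the same array the
 * lambda emits. Using the average order gives every tied sample the same (average) rank.
 */
final class TiedScoreAggregator {
    private TiedScoreAggregator() {}

    /** Assumes at least one sample, since a keyed window only exists for keys that received data. */
    static double[] aggregate(Iterable<ScoredSample> values) {
        long orderSum = 0;
        long count = 0;
        double positiveSum = 0;
        double negativeSum = 0;
        for (ScoredSample sample : values) {
            orderSum += sample.order;
            count++;
            if (sample.isPositive) {
                positiveSum += sample.weight;
            } else {
                negativeSum += sample.weight;
            }
        }
        double averageOrder = (double) orderSum / count;
        return new double[] {averageOrder * positiveSum, positiveSum, negativeSum};
    }
}
```

A static nested class also avoids the implicit reference to the enclosing evaluator that an anonymous inner class carries, which can be one reason to prefer it for functions Flink has to serialize.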





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r858242365


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz', as selected by the parameter
+ * MetricsNames. The whole input is sorted in parallel, so the metrics are computed exactly rather
+ * than approximated.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data. */
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)
+                        .map(
+                                new MapFunction<double[], Double>() {
+                                    @Override
+                                    public Double map(double[] aucVariable) {
+                                        if (aucVariable[1] > 0 && aucVariable[2] > 0) {
+                                            return (aucVariable[0]
+                                                            - 1.
+                                                                    * aucVariable[1]
+                                                                    * (aucVariable[1] + 1)
+                                                                    / 2)
+                                                    / (aucVariable[1] * aucVariable[2]);
+                                        } else {
+                                            return Double.NaN;
+                                        }
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(PARTITION_SUMMARY, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);

Review Comment:
   AreaUnderROC's calculation is different from the other metrics: the data needs to be keyed by score before further calculation, while the other metrics only compute simple statistics over the input data.
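As a rough illustration of why areaUnderROC needs the keyed step, here is a plain-Java sketch (no Flink) of the rank-sum (Mann-Whitney) formulation that the final `map` in the quoted code appears to apply, `(sumOfPositiveRanks - P(P+1)/2) / (P * N)`. Samples with equal scores are grouped, which is what keying by score achieves, and every sample in a group gets the group's average rank. The class name `RankSumAuc` is hypothetical, the sketch is unweighted (every sample has weight 1), and it assumes ranks are assigned in ascending score order; how `CalcSampleOrders` actually numbers samples is not shown in this excerpt.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical helper: exact AUC via the rank-sum statistic, with ties given average ranks. */
final class RankSumAuc {
    private RankSumAuc() {}

    /**
     * @param scores predicted scores for the positive class
     * @param labels 1 for positive samples, anything else for negative samples
     * @return the AUC, or NaN if there are no positive or no negative samples
     */
    static double auc(double[] scores, int[] labels) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int k = 0; k < n; k++) {
            idx[k] = k;
        }
        // Sort sample indices by ascending score, so rank 1 is the lowest score.
        Arrays.sort(idx, Comparator.comparingDouble(pos -> scores[pos]));

        double positiveRankSum = 0;
        long numPositive = 0;
        long numNegative = 0;
        int start = 0;
        while (start < n) {
            // [start, end) is one group of tied scores, i.e. one key in the keyed step.
            int end = start;
            while (end < n && scores[idx[end]] == scores[idx[start]]) {
                end++;
            }
            // The group occupies 1-based ranks start+1 .. end; each member gets their average.
            double averageRank = (start + 1 + end) / 2.0;
            for (int k = start; k < end; k++) {
                if (labels[idx[k]] == 1) {
                    positiveRankSum += averageRank;
                    numPositive++;
                } else {
                    numNegative++;
                }
            }
            start = end;
        }
        if (numPositive == 0 || numNegative == 0) {
            return Double.NaN;
        }
        return (positiveRankSum - numPositive * (numPositive + 1) / 2.0)
                / ((double) numPositive * numNegative);
    }
}
```

Fed the (label, score) pairs of `INPUT_DATA_DOUBLE_RAW` from `BinaryClassificationEvaluatorTest`, this sketch returns 23/35 (about 0.657). Whether that equals the evaluator's expected value depends on ordering and weighting details not shown here.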





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r873579828


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassfication/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,721 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassfication;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;

Review Comment:
   We should use `org.slf4j.Logger` directly instead of statically importing `BlobWriter.LOG`. `AbstractBroadcastWrapperOperator` or `Benchmark` could serve as an example.





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r869795956


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+    private Table inputDataTableScore;
+    private Table inputDataTableWithMultiScore;
+    private Table inputDataTableWithWeight;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_DOUBLE_RAW =
+            Arrays.asList(
+                    Row.of(1, 0.9),
+                    Row.of(1, 0.8),
+                    Row.of(1, 0.7),
+                    Row.of(0, 0.75),
+                    Row.of(0, 0.6),
+                    Row.of(1, 0.65),
+                    Row.of(1, 0.55),
+                    Row.of(0, 0.4),
+                    Row.of(0, 0.3),
+                    Row.of(1, 0.35),
+                    Row.of(0, 0.2),
+                    Row.of(1, 0.1));
+
+    private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {
+                0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237
+            };
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double EPS = 1.0e-5;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(3);

Review Comment:
   Thanks for fixing the problem here. Could we also add a test to verify that the algorithm still works when the parallelism is larger than the number of input rows? `LogisticRegressionTest.testMoreSubtaskThanData` is a good example. Otherwise, reviewers would have to manually change this to `env.setParallelism(20)` and re-run the tests.
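A minimal, Flink-free sketch of what such a test guards against. With the 12 rows of `INPUT_DATA` and a parallelism of 20, a round-robin assignment leaves 8 subtasks with no data, so every per-subtask partial result needs a neutral "empty" value that merges without changing the total. The round-robin assignment is an assumption standing in for Flink's actual data distribution, and `EmptyPartitionSketch`, `Summary`, `partition`, and `evaluate` are hypothetical names, not the evaluator's internals.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: partial summaries must merge correctly even when partitions are empty. */
public class EmptyPartitionSketch {

    /** Per-partition partial statistics: weighted counts of positive and negative labels. */
    static final class Summary {
        final double positiveWeight;
        final double negativeWeight;

        Summary(double positiveWeight, double negativeWeight) {
            this.positiveWeight = positiveWeight;
            this.negativeWeight = negativeWeight;
        }

        /** Neutral element: what an empty partition contributes to the merge. */
        static Summary empty() {
            return new Summary(0.0, 0.0);
        }

        Summary merge(Summary other) {
            return new Summary(
                    positiveWeight + other.positiveWeight, negativeWeight + other.negativeWeight);
        }
    }

    /** Labels of the 12 rows in INPUT_DATA. */
    static final double[] LABELS = {1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1};

    /** Assigns rows round-robin to numPartitions partitions (assumed stand-in for a shuffle). */
    static List<List<Double>> partition(double[] labels, int numPartitions) {
        List<List<Double>> partitions = new ArrayList<>();
        for (int p = 0; p < numPartitions; p++) {
            partitions.add(new ArrayList<>());
        }
        for (int i = 0; i < labels.length; i++) {
            partitions.get(i % numPartitions).add(labels[i]);
        }
        return partitions;
    }

    /** Summarizes one partition; an empty partition yields Summary.empty(). */
    static Summary summarize(List<Double> rows) {
        Summary summary = Summary.empty();
        for (double label : rows) {
            summary = summary.merge(label > 0.5 ? new Summary(1.0, 0.0) : new Summary(0.0, 1.0));
        }
        return summary;
    }

    /** Summarizes every partition independently, then merges the partial results. */
    static Summary evaluate(double[] labels, int parallelism) {
        Summary total = Summary.empty();
        for (List<Double> rows : partition(labels, parallelism)) {
            total = total.merge(summarize(rows));
        }
        return total;
    }

    /** Counts partitions that received no rows. */
    static long countEmpty(double[] labels, int parallelism) {
        return partition(labels, parallelism).stream().filter(List::isEmpty).count();
    }

    public static void main(String[] args) {
        Summary p3 = evaluate(LABELS, 3);
        Summary p20 = evaluate(LABELS, 20);
        System.out.println(p3.positiveWeight + " " + p3.negativeWeight);
        System.out.println(p20.positiveWeight + " " + p20.negativeWeight);
        System.out.println("empty partitions at 20: " + countEmpty(LABELS, 20));
    }
}
```

Comparing the merged summary at parallelism 3 and at parallelism 20 is the pure-Java analogue of re-running the evaluator with `env.setParallelism(20)`: the results must be identical.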



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassfication/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassfication;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+
+/**
+ * Params of BinaryClassificationEvaluator.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    String AREA_UNDER_ROC = "areaUnderROC";
+    String AREA_UNDER_PR = "areaUnderPR";
+    String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    String KS = "ks";
+
+    /**
+     * Param for supported metric names in binary classification evaluation (supports
+     * 'areaUnderROC', 'areaUnderPR', 'ks' and 'areaUnderLorenz').
+     *
+     * <ul>
+     *   <li>areaUnderROC: the area under the receiver operating characteristic (ROC) curve.
+     *   <li>areaUnderPR: the area under the precision-recall curve.
+     *   <li>ks: Kolmogorov-Smirnov, measures the ability of the model to separate positive and
+     *       negative samples.
+     *   <li>areaUnderLorenz: the area under the lorenz curve.
+     * </ul>
+     */
+    Param<String[]> METRICS_NAMES =
+            new StringArrayParam(
+                    "metricsNames",
+                    "Names of output metrics, which may contains 'areaUnderROC', 'areaUnderPR', 'ks' or 'areaUnderLorenz'",

Review Comment:
   nit: We don't need to list the possible values in the description string; "Names of output metrics." is enough. Please see `HasHandleInvalid` as an example.
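To sanity-check the `ks` and `areaUnderROC` definitions in the Javadoc above, here is a hedged single-machine sketch using textbook formulas: AUC as the Mann-Whitney fraction of correctly ordered (positive, negative) pairs, and KS as the maximum gap between TPR and FPR over score thresholds. It assumes the positive-class score is the second element of each `rawPrediction` vector in the test's `INPUT_DATA`; under that assumption it yields 23/35 and 13/35, which match the `0.6571428571428571` and `0.3714285714285714` entries of `EXPECTED_DATA`. `BinaryMetricsSketch` is a hypothetical name; this is not the distributed algorithm BinaryClassificationEvaluator uses, and it does not cover areaUnderPR or areaUnderLorenz.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Textbook areaUnderROC and ks on a small in-memory dataset (illustrative sketch). */
public class BinaryMetricsSketch {

    // Labels and positive-class scores (assumed: second vector element) of the rows in INPUT_DATA.
    static final double[] LABELS = {1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1};
    static final double[] SCORES = {
        0.9, 0.8, 0.7, 0.75, 0.6, 0.65, 0.55, 0.4, 0.3, 0.35, 0.2, 0.1
    };

    /** Fraction of (positive, negative) pairs where the positive scores higher; ties count 1/2. */
    static double areaUnderRoc(double[] labels, double[] scores) {
        double correct = 0.0;
        long pairs = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] != 1.0) {
                continue;
            }
            for (int j = 0; j < labels.length; j++) {
                if (labels[j] != 0.0) {
                    continue;
                }
                pairs++;
                if (scores[i] > scores[j]) {
                    correct += 1.0;
                } else if (scores[i] == scores[j]) {
                    correct += 0.5;
                }
            }
        }
        return correct / pairs;
    }

    /** Max of |TPR - FPR| over all score thresholds, scanning from the highest score down. */
    static double ks(double[] labels, double[] scores) {
        Integer[] order = new Integer[labels.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble((Integer idx) -> scores[idx]).reversed());
        double positives = Arrays.stream(labels).filter(l -> l == 1.0).count();
        double negatives = labels.length - positives;
        double tp = 0.0;
        double fp = 0.0;
        double best = 0.0;
        for (int k = 0; k < order.length; k++) {
            if (labels[order[k]] == 1.0) {
                tp++;
            } else {
                fp++;
            }
            // Evaluate only where the score changes, so tied rows share one threshold.
            boolean lastOfTie =
                    k == order.length - 1 || scores[order[k]] != scores[order[k + 1]];
            if (lastOfTie) {
                best = Math.max(best, Math.abs(tp / positives - fp / negatives));
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(areaUnderRoc(LABELS, SCORES)); // expected 23/35
        System.out.println(ks(LABELS, SCORES)); // expected 13/35
    }
}
```

The same labels and scores appear directly in `INPUT_DATA_DOUBLE_RAW`, which is consistent with `testEvaluateWithDoubleRaw` asserting against the same `EXPECTED_DATA`.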



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+    private Table inputDataTableScore;
+    private Table inputDataTableWithMultiScore;
+    private Table inputDataTableWithWeight;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_DOUBLE_RAW =
+            Arrays.asList(
+                    Row.of(1, 0.9),
+                    Row.of(1, 0.8),
+                    Row.of(1, 0.7),
+                    Row.of(0, 0.75),
+                    Row.of(0, 0.6),
+                    Row.of(1, 0.65),
+                    Row.of(1, 0.55),
+                    Row.of(0, 0.4),
+                    Row.of(0, 0.3),
+                    Row.of(1, 0.35),
+                    Row.of(0, 0.2),
+                    Row.of(1, 0.1));
+
+    private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {
+                0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237
+            };
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double EPS = 1.0e-5;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(3);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        inputDataTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction");
+        inputDataTableScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_DOUBLE_RAW))
+                        .as("label", "rawPrediction");
+
+        inputDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE))
+                        .as("label", "rawPrediction");
+        inputDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT))
+                        .as("label", "rawPrediction", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC)

Review Comment:
   Thanks for making this change. There are some other places in BinaryClassificationEvaluatorTest that directly use string values like "areaUnderROC". Shall we replace all of them with constants like `AREA_UNDER_ROC`?
   
   nit: Besides, statically importing the constants might make the code look better.
   ```java
   import static org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams.AREA_UNDER_ROC;
   ...
   .setMetricsNames(AREA_UNDER_ROC)
   ```



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+    private Table inputDataTableScore;
+    private Table inputDataTableWithMultiScore;
+    private Table inputDataTableWithWeight;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_DOUBLE_RAW =
+            Arrays.asList(
+                    Row.of(1, 0.9),
+                    Row.of(1, 0.8),
+                    Row.of(1, 0.7),
+                    Row.of(0, 0.75),
+                    Row.of(0, 0.6),
+                    Row.of(1, 0.65),
+                    Row.of(1, 0.55),
+                    Row.of(0, 0.4),
+                    Row.of(0, 0.3),
+                    Row.of(1, 0.35),
+                    Row.of(0, 0.2),
+                    Row.of(1, 0.1));
+
+    private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {
+                0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237
+            };
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double EPS = 1.0e-5;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(3);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        inputDataTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction");
+        inputDataTableScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_DOUBLE_RAW))
+                        .as("label", "rawPrediction");
+
+        inputDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE))
+                        .as("label", "rawPrediction");
+        inputDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT))
+                        .as("label", "rawPrediction", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC)
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "ks", "areaUnderROC");
+        BinaryClassificationEvaluator loadedEval =
+                StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath());
+        Table evalResult = loadedEval.transform(inputDataTable)[0];
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "ks", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        List<Row> results = IteratorUtils.toList(evalResult.execute().collect());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS);
+        }
+    }
+
+    @Test
+    public void testEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "ks", "areaUnderROC");
+        Table evalResult = eval.transform(inputDataTable)[0];
+        List<Row> results = IteratorUtils.toList(evalResult.execute().collect());
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "ks", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS);
+        }
+    }
+
+    @Test
+    public void testEvaluateWithDoubleRaw() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "ks", "areaUnderROC");
+        Table evalResult = eval.transform(inputDataTableScore)[0];
+        List<Row> results = IteratorUtils.toList(evalResult.execute().collect());
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "ks", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS);
+        }
+    }
+
+    @Test
+    public void testEvaluateWithMultiScore() throws Exception {

Review Comment:
   nit: Shall we check and clean up the warnings in `BinaryClassificationEvaluator` and `BinaryClassificationEvaluatorTest`? For example, `Exception` is never thrown in this method.
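One likely source of such warnings, offered as an assumption rather than something the comment states: commons-collections 3's `IteratorUtils.toList` returns a raw `List`, so assigning its result to `List<Row>` triggers an unchecked-conversion warning. A small generic helper avoids the raw type. `TypedIteratorToList` and its `toList` method are hypothetical names; since the `CloseableIterator<Row>` returned by `TableResult.collect()` is an `Iterator<Row>`, it could be passed in directly.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/** Hypothetical helper: drains an iterator into a typed list without raw types. */
public class TypedIteratorToList {

    /** Copies every remaining element of the iterator into a new ArrayList, preserving order. */
    static <T> List<T> toList(Iterator<T> iterator) {
        List<T> result = new ArrayList<>();
        iterator.forEachRemaining(result::add);
        return result;
    }

    public static void main(String[] args) {
        // The element type is inferred from the iterator, so no unchecked cast is needed.
        List<String> names =
                toList(Arrays.asList("areaUnderPR", "ks", "areaUnderROC").iterator());
        System.out.println(names);
    }
}
```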



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+    private Table inputDataTableScore;
+    private Table inputDataTableWithMultiScore;
+    private Table inputDataTableWithWeight;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_DOUBLE_RAW =
+            Arrays.asList(
+                    Row.of(1, 0.9),
+                    Row.of(1, 0.8),
+                    Row.of(1, 0.7),
+                    Row.of(0, 0.75),
+                    Row.of(0, 0.6),
+                    Row.of(1, 0.65),
+                    Row.of(1, 0.55),
+                    Row.of(0, 0.4),
+                    Row.of(0, 0.3),
+                    Row.of(1, 0.35),
+                    Row.of(0, 0.2),
+                    Row.of(1, 0.1));
+
+    private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {
+                0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237
+            };
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double EPS = 1.0e-5;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(3);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        inputDataTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction");
+        inputDataTableScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_DOUBLE_RAW))
+                        .as("label", "rawPrediction");
+
+        inputDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE))
+                        .as("label", "rawPrediction");
+        inputDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT))
+                        .as("label", "rawPrediction", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC)
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "ks", "areaUnderROC");
+        BinaryClassificationEvaluator loadedEval =
+                StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath());
+        Table evalResult = loadedEval.transform(inputDataTable)[0];
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "ks", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        List<Row> results = IteratorUtils.toList(evalResult.execute().collect());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS);
+        }
+    }
+
+    @Test
+    public void testEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "ks", "areaUnderROC");
+        Table evalResult = eval.transform(inputDataTable)[0];
+        List<Row> results = IteratorUtils.toList(evalResult.execute().collect());
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "ks", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS);
+        }
+    }
+
+    @Test
+    public void testEvaluateWithDoubleRaw() throws Exception {

Review Comment:
   Thanks for adding this test. In Spark's binary classification evaluator, the tests also cover labels of any numeric type, such as int, short, or decimal. Could we add these tests as well?
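One way the parsing side could accept labels of any numeric type is sketched below, under the assumption that the label arrives as a plain `Object` (as returned by `Row#getField`) and is converted through `java.lang.Number`. `NumericLabels.toDoubleLabel` is a hypothetical helper for illustration only; it is not part of this PR or of Flink ML.

```java
import java.math.BigDecimal;

/** Hypothetical helper: normalizes a label of any numeric type to a double. */
public final class NumericLabels {
    private NumericLabels() {}

    /**
     * Converts a label to double. Any java.lang.Number (Integer, Short, Long, Float,
     * Double, BigDecimal, ...) is accepted; anything else is rejected.
     */
    public static double toDoubleLabel(Object label) {
        if (label instanceof Number) {
            return ((Number) label).doubleValue();
        }
        throw new IllegalArgumentException(
                "Label must be numeric but was: "
                        + (label == null ? "null" : label.getClass().getName()));
    }

    public static void main(String[] args) {
        // One label per numeric type a test could cover.
        Object[] labels = {1, (short) 0, 1L, 0.0f, 1.0, new BigDecimal("1")};
        for (Object label : labels) {
            System.out.println(label.getClass().getSimpleName() + " -> " + toDoubleLabel(label));
        }
    }
}
```

A parameterized test could then feed the same rows with the label column built from each of these types and expect identical metrics.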



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassfication/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,742 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassfication;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * An Estimator which calculates the evaluation metrics for binary classification. The input data
+ * has columns rawPrediction, label and an optional weight column. The rawPrediction can be of type
+ * double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of
+ * raw predictions, scores, or label probabilities). The output may contain different metrics which
+ * will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.

Review Comment:
   nit: Use `{@link BinaryClassificationEvaluatorParams}` here. Also, `parameter MetricsNames` might be outdated.
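The Javadoc quoted above allows rawPrediction to be a double or a length-2 vector, and the `transform` method quoted further down combines per-group sums as `(value[0] - value[1] * (value[1] + 1) / 2) / (value[1] * value[2])`, which reads like the rank-sum (Mann-Whitney) form of areaUnderROC. Below is a minimal, unweighted sketch of that formula, assuming the score is the double itself or element 1 of the vector (modeled as `double[]` rather than Flink's `DenseVector`) and that tied scores share their average rank. `RankSumAuc` and its methods are hypothetical. On the `INPUT_DATA` rows from the test, 23 of the 35 positive/negative pairs are ordered correctly, giving about 0.657, consistent with the areaUnderROC entry of `EXPECTED_DATA`.

```java
import java.util.Arrays;

/** Hypothetical sketch: unweighted areaUnderROC via the rank-sum formula with tie averaging. */
public final class RankSumAuc {
    private RankSumAuc() {}

    /** Score of label 1: the double itself, or element 1 of a length-2 vector. */
    public static double score(Object rawPrediction) {
        if (rawPrediction instanceof Double) {
            return (Double) rawPrediction;
        }
        double[] vector = (double[]) rawPrediction;
        if (vector.length != 2) {
            throw new IllegalArgumentException("Expected a length-2 rawPrediction vector");
        }
        return vector[1];
    }

    /** labels[i] is 1.0 for a positive sample and 0.0 for a negative one. */
    public static double areaUnderROC(double[] labels, double[] scores) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int k = 0; k < n; k++) {
            idx[k] = k;
        }
        // Ascending by score, so 1-based rank 1 is the lowest score.
        Arrays.sort(idx, (a, b) -> Double.compare(scores[a], scores[b]));

        double positiveRankSum = 0.0;
        long numPos = 0;
        int i = 0;
        while (i < n) {
            int j = i;
            while (j + 1 < n && scores[idx[j + 1]] == scores[idx[i]]) {
                j++;
            }
            // Sorted positions i..j are tied and share the mean of ranks i+1..j+1.
            double avgRank = (i + j + 2) / 2.0;
            for (int k = i; k <= j; k++) {
                if (labels[idx[k]] == 1.0) {
                    positiveRankSum += avgRank;
                    numPos++;
                }
            }
            i = j + 1;
        }
        long numNeg = n - numPos;
        if (numPos == 0 || numNeg == 0) {
            return Double.NaN; // undefined with only one class, as in the quoted map function
        }
        return (positiveRankSum - numPos * (numPos + 1) / 2.0) / ((double) numPos * numNeg);
    }

    public static void main(String[] args) {
        // Labels and rawPrediction vectors copied from INPUT_DATA in the test.
        double[] labels = {1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1};
        Object[] raw = {
            new double[] {0.1, 0.9}, new double[] {0.2, 0.8}, new double[] {0.3, 0.7},
            new double[] {0.25, 0.75}, new double[] {0.4, 0.6}, new double[] {0.35, 0.65},
            new double[] {0.45, 0.55}, new double[] {0.6, 0.4}, new double[] {0.7, 0.3},
            new double[] {0.65, 0.35}, new double[] {0.8, 0.2}, new double[] {0.9, 0.1}
        };
        double[] scores = new double[raw.length];
        for (int k = 0; k < raw.length; k++) {
            scores[k] = score(raw[k]);
        }
        System.out.println(areaUnderROC(labels, scores)); // 23 of 35 pairs ranked correctly
    }
}
```

`INPUT_DATA_DOUBLE_RAW` carries the same scores as plain doubles, so under these assumptions it yields the same value.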



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+    private Table inputDataTableScore;
+    private Table inputDataTableWithMultiScore;
+    private Table inputDataTableWithWeight;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_DOUBLE_RAW =
+            Arrays.asList(
+                    Row.of(1, 0.9),
+                    Row.of(1, 0.8),
+                    Row.of(1, 0.7),
+                    Row.of(0, 0.75),
+                    Row.of(0, 0.6),
+                    Row.of(1, 0.65),
+                    Row.of(1, 0.55),
+                    Row.of(0, 0.4),
+                    Row.of(0, 0.3),
+                    Row.of(1, 0.35),
+                    Row.of(0, 0.2),
+                    Row.of(1, 0.1));
+
+    private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {
+                0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237
+            };
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double EPS = 1.0e-5;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(3);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        inputDataTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction");
+        inputDataTableScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_DOUBLE_RAW))
+                        .as("label", "rawPrediction");
+
+        inputDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE))
+                        .as("label", "rawPrediction");
+        inputDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT))
+                        .as("label", "rawPrediction", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC)
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndEvaluate() throws Exception {

Review Comment:
   nit: Shall we reorder these test cases to follow the other test classes? For example, `testParam` first, followed by `testOutputSchema` and `testEvaluate`, and then save/load and other special cases.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassfication/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,742 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassfication;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * An Estimator which calculates the evaluation metrics for binary classification. The input data
+ * has columns rawPrediction, label and an optional weight column. The rawPrediction can be of type
+ * double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of
+ * raw predictions, scores, or label probabilities). The output may contain different metrics which
+ * will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 50;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAreaUnderROCVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                DataStreamUtils.reduce(
+                        localAreaUnderROCVariable,
+                        (ReduceFunction<double[]>)
+                                (t1, t2) -> {
+                                    t2[0] += t1[0];
+                                    t2[1] += t1[1];
+                                    t2[2] += t1[2];
+                                    return t2;
+                                });
+
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)
+                                                / (value[1] * value[2]);
+                                    } else {
+                                        return Double.NaN;
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(partitionSummariesKey, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalcBinaryMetrics(partitionSummariesKey));
+                        });
+
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        Arrays.fill(metricTypes, Types.DOUBLE());
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+
+        DataStream<Row> evalResult =
+                metrics.map(
+                        (MapFunction<Map<String, Double>, Row>)
+                                value -> {
+                                    Row ret = new Row(metricsNames.length);
+                                    for (int i = 0; i < metricsNames.length; ++i) {
+                                        ret.setField(i, value.get(metricsNames[i]));
+                                    }
+                                    return ret;
+                                },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    /** Updates variables for calculating AreaUnderROC. */
+    private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>,
+                    BoundedOneInput {
+        private ListState<double[]> accValueState;
+        private ListState<Double> scoreState;
+
+        double[] accValue;
+        double score;
+
+        @Override
+        public void endInput() {
+            if (accValue != null) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+            }
+        }
+
+        @Override
+        public void processElement(
+                StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) {
+            Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue();
+            if (accValue == null) {
+                accValue = new double[4];
+                score = t.f0;
+            } else if (score != t.f0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+                Arrays.fill(accValue, 0.0);
+            }
+            accValue[0] += t.f1;
+            accValue[1] += 1.0;
+            if (t.f2) {
+                accValue[2] += t.f3;
+            } else {
+                accValue[3] += t.f3;
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            accValueState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "accValueState", TypeInformation.of(double[].class)));
+            accValue =
+                    OperatorStateUtils.getUniqueElement(accValueState, "accValueState")
+                            .orElse(null);
+
+            scoreState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "scoreState", TypeInformation.of(Double.class)));
+            score = OperatorStateUtils.getUniqueElement(scoreState, "scoreState").orElse(0.0);

Review Comment:
   `score` is restored from operator state here, but it has not been saved in the snapshot yet.
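To illustrate why this matters, the sketch below models the quoted `AccumulateMultiScoreOperator` fields in plain Java. `TieGroupAccumulator`, its `Snapshot` holder, and the `restoreScore` flag are hypothetical; they do not use Flink's state API. If a checkpoint falls inside a tie group and `score` comes back as the default `0.0`, the next record with the same score is treated as a new group, so tied samples no longer share their average order. In the real operator the fix would presumably be to write `score` into `scoreState` in `snapshotState(StateSnapshotContext)` alongside `accValue`, for example via `scoreState.update(Collections.singletonList(score))`. Separately, the quoted `processElement` only assigns `score` when `accValue` is null, so it appears never to advance to the next group; the sketch assumes it should.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical plain-Java model of AccumulateMultiScoreOperator's tie grouping. */
public final class TieGroupAccumulator {
    /** Stand-in for the accValueState/scoreState ListStates at checkpoint time. */
    public static final class Snapshot {
        final double[] accValue; // null if no record has been seen yet
        final double score;

        Snapshot(double[] accValue, double score) {
            this.accValue = accValue == null ? null : accValue.clone();
            this.score = score;
        }
    }

    // accValue = {sum of orders, record count, positive weight, negative weight}
    private double[] accValue;
    private double score;
    /** One {avgOrder * positiveWeight, positiveWeight, negativeWeight} entry per tie group. */
    public final List<double[]> emitted = new ArrayList<>();

    public void process(double s, long order, boolean isPositive, double weight) {
        if (accValue == null) {
            accValue = new double[4];
            score = s;
        } else if (score != s) {
            emit(); // a different score closes the current tie group
            Arrays.fill(accValue, 0.0);
            score = s; // assumption: track the score of the new group
        }
        accValue[0] += order;
        accValue[1] += 1.0;
        if (isPositive) {
            accValue[2] += weight;
        } else {
            accValue[3] += weight;
        }
    }

    public void endInput() {
        if (accValue != null) {
            emit();
        }
    }

    private void emit() {
        emitted.add(new double[] {accValue[0] / accValue[1] * accValue[2], accValue[2], accValue[3]});
    }

    /** Counterpart of snapshotState: both fields are captured. */
    public Snapshot snapshot() {
        return new Snapshot(accValue, score);
    }

    /** Counterpart of initializeState; restoreScore=false mimics a score that was never saved. */
    public static TieGroupAccumulator restore(Snapshot snap, boolean restoreScore) {
        TieGroupAccumulator acc = new TieGroupAccumulator();
        acc.accValue = snap.accValue == null ? null : snap.accValue.clone();
        acc.score = restoreScore ? snap.score : 0.0;
        return acc;
    }

    public static void main(String[] args) {
        TieGroupAccumulator before = new TieGroupAccumulator();
        before.process(0.1, 1, true, 1.0);
        before.process(0.2, 2, false, 1.0);
        before.process(0.9, 3, true, 1.0);
        Snapshot snap = before.snapshot(); // checkpoint inside the 0.9 tie group

        for (boolean restoreScore : new boolean[] {true, false}) {
            TieGroupAccumulator after = restore(snap, restoreScore);
            after.process(0.9, 4, false, 1.0);
            after.endInput();
            System.out.print("restoreScore=" + restoreScore + ":");
            for (double[] group : after.emitted) {
                System.out.print(" " + Arrays.toString(group));
            }
            System.out.println();
        }
    }
}
```

With the score restored, orders 3 and 4 form one group and the positive sample is credited with the averaged order 3.5; without it, the group is split and the positive sample gets order 3.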





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r869841348


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluator;
+import org.apache.flink.ml.evaluation.binaryclassfication.BinaryClassificationEvaluatorParams;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+    private Table inputDataTableScore;
+    private Table inputDataTableWithMultiScore;
+    private Table inputDataTableWithWeight;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_DOUBLE_RAW =
+            Arrays.asList(
+                    Row.of(1, 0.9),
+                    Row.of(1, 0.8),
+                    Row.of(1, 0.7),
+                    Row.of(0, 0.75),
+                    Row.of(0, 0.6),
+                    Row.of(1, 0.65),
+                    Row.of(1, 0.55),
+                    Row.of(0, 0.4),
+                    Row.of(0, 0.3),
+                    Row.of(1, 0.35),
+                    Row.of(0, 0.2),
+                    Row.of(1, 0.1));
+
+    private static final List<Row> INPUT_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> INPUT_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {
+                0.8571428571428571, 0.9377705627705628, 0.8571428571428571, 0.6488095238095237
+            };
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+    private static final double EPS = 1.0e-5;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(3);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        inputDataTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("label", "rawPrediction");
+        inputDataTableScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_DOUBLE_RAW))
+                        .as("label", "rawPrediction");
+
+        inputDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_MULTI_SCORE))
+                        .as("label", "rawPrediction");
+        inputDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA_WITH_WEIGHT))
+                        .as("label", "rawPrediction", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames(BinaryClassificationEvaluatorParams.AREA_UNDER_ROC)
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "ks", "areaUnderROC");
+        BinaryClassificationEvaluator loadedEval =
+                StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath());
+        Table evalResult = loadedEval.transform(inputDataTable)[0];
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "ks", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        List<Row> results = IteratorUtils.toList(evalResult.execute().collect());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS);
+        }
+    }
+
+    @Test
+    public void testEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "ks", "areaUnderROC");
+        Table evalResult = eval.transform(inputDataTable)[0];
+        List<Row> results = IteratorUtils.toList(evalResult.execute().collect());
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "ks", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), EPS);
+        }
+    }
+
+    @Test
+    public void testEvaluateWithDoubleRaw() throws Exception {

Review Comment:
   I have added tests for labels of type int and double. I don't think it's necessary to test every type.
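The reasoning above holds if every label is read through `java.lang.Number` before use. In that case, int, long, and double labels all take the same code path. The following is a hypothetical sketch, not the PR's actual `ParseSample` code. `LabelSketch` and `isPositive` are invented names, and treating a label equal to exactly 1 as positive is an assumption.

```java
/**
 * Hypothetical sketch (not the PR's ParseSample): normalizes a label of any boxed
 * numeric type to a boolean "is positive" flag.
 */
final class LabelSketch {
    private LabelSketch() {}

    /** Returns true when the label equals 1; Integer, Long, Float and Double all work. */
    static boolean isPositive(Object label) {
        if (!(label instanceof Number)) {
            throw new IllegalArgumentException(
                    "Label must be numeric, got: " + (label == null ? "null" : label.getClass()));
        }
        // Widening every Number to double is why int and double labels share one path.
        return ((Number) label).doubleValue() == 1.0;
    }
}
```

With this design, a test for int and a test for double labels together exercise the single widening branch. Adding tests for further numeric types would mostly repeat that coverage.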





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867591071


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));

Review Comment:
   I don't think the sort can be avoided.
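A single-machine example shows why the sort is hard to avoid. Both areaUnderROC and KS depend on the order of the samples by score. The sketch below computes them by walking the samples from the highest score to the lowest. It is a hypothetical, non-distributed sketch: the class and method names are invented, and it assumes that tied scores are treated as a single threshold. The PR itself range-partitions the data and sorts each partition locally.

```java
import java.util.Arrays;

/**
 * Hypothetical single-machine sketch: computes areaUnderROC and KS by sweeping the samples in
 * descending score order. This is not the distributed algorithm in the PR.
 */
final class RankMetricsSketch {
    private RankMetricsSketch() {}

    /** Returns {areaUnderROC, ks}. Samples with equal scores are treated as one threshold. */
    static double[] aucAndKs(double[] scores, boolean[] positive, double[] weights) {
        int n = scores.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // The sort both metrics depend on: highest score first.
        Arrays.sort(order, (a, b) -> Double.compare(scores[b], scores[a]));

        double totalPos = 0;
        double totalNeg = 0;
        for (int i = 0; i < n; i++) {
            if (positive[i]) {
                totalPos += weights[i];
            } else {
                totalNeg += weights[i];
            }
        }
        if (totalPos == 0 || totalNeg == 0) {
            throw new IllegalArgumentException("Need at least one positive and one negative sample");
        }

        double tp = 0, fp = 0, prevTp = 0, prevFp = 0, area = 0, ks = 0;
        for (int k = 0; k < n; k++) {
            int i = order[k];
            if (positive[i]) {
                tp += weights[i];
            } else {
                fp += weights[i];
            }
            boolean endOfTieGroup = k == n - 1 || scores[order[k + 1]] != scores[i];
            if (endOfTieGroup) {
                // Trapezoid between consecutive ROC points, in un-normalized (fp, tp) units.
                area += (fp - prevFp) * (tp + prevTp) / 2.0;
                // KS is the largest gap between true-positive and false-positive rates.
                ks = Math.max(ks, Math.abs(tp / totalPos - fp / totalNeg));
                prevTp = tp;
                prevFp = fp;
            }
        }
        return new double[] {area / (totalPos * totalNeg), ks};
    }
}
```

As a check, apply the sketch to the 12 rows of `INPUT_DATA` from the test file above, using the second element of each rawPrediction vector as the score and unit weights. It yields areaUnderROC = 23/35 ≈ 0.657143 and KS = 13/35 ≈ 0.371429, which match the third and second entries of `EXPECTED_DATA`.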





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867593318


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {

Review Comment:
   Two actions are applied to the input data: a sort and a statistics computation. The sort has to buffer the data, but the statistics computation does not. I don't think merging the two into one would improve performance.
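The difference between the two actions can be made concrete. The sketch below is hypothetical and is not the PR's `BinarySummary` or `PartitionSummaryOperator`. It folds a summary one record at a time in constant memory. An exclusive prefix sum over summaries, ordered by score range, then gives each partition the global rank of its first sample. Judging from the diff, the broadcast `partitionSummaries` stream seems to serve a similar purpose in `CalcSampleOrders`, but that is an inference.

```java
import java.util.List;

/**
 * Hypothetical sketch: a per-partition summary that needs only one pass and O(1) memory,
 * in contrast to the sort, which must buffer the whole partition.
 */
final class PartitionSummarySketch {
    private PartitionSummarySketch() {}

    static final class Summary {
        long count;
        double positiveWeight;
        double negativeWeight;

        // Each sample is folded in as it arrives; nothing is retained afterwards.
        void add(boolean isPositive, double weight) {
            count++;
            if (isPositive) {
                positiveWeight += weight;
            } else {
                negativeWeight += weight;
            }
        }
    }

    /**
     * Given summaries ordered from the highest-score range to the lowest, returns the global
     * rank of the first sample in each partition (an exclusive prefix sum of the counts).
     */
    static long[] startOffsets(List<Summary> summaries) {
        long[] offsets = new long[summaries.size()];
        long running = 0;
        for (int p = 0; p < summaries.size(); p++) {
            offsets[p] = running;
            running += summaries.get(p).count;
        }
        return offsets;
    }
}
```

Because `add` keeps no per-record state, this step could run alongside the sort without needing its own buffer. That is consistent with the point that fusing the two would not save the buffering cost.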





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r868690036


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {

Review Comment:
   OK, I will add these test cases.





[GitHub] [flink-ml] lindong28 merged pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
lindong28 merged PR #86:
URL: https://github.com/apache/flink-ml/pull/86




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867680909


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));

Review Comment:
   I tried removing this code, and all `BinaryClassificationEvaluatorTest` cases still passed. If this code is necessary, could you please add tests that verify its behavior?
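   For readers following the range-partitioning step in the hunk above: `getBoundaryRange` and `AppendTaskId` are not quoted in this thread, so what follows is only a hedged, single-JVM sketch of the general technique. It samples scores, picks quantile boundaries, and routes each score to a range. The class and method names are hypothetical. The descending-boundary convention (partition 0 holds the highest scores) is an assumption, chosen to match the descending local sort.

```java
import java.util.Arrays;

/** Hypothetical sketch of sample-based range partitioning for scores. */
public class RangeBoundarySketch {

    /** Picks numPartitions - 1 quantile boundaries from a sample, largest boundary first. */
    static double[] boundariesFromSample(double[] sample, int numPartitions) {
        if (sample.length == 0 || numPartitions < 1) {
            throw new IllegalArgumentException("Need a non-empty sample and at least one partition");
        }
        double[] sorted = sample.clone();
        Arrays.sort(sorted); // ascending order
        double[] boundaries = new double[numPartitions - 1];
        for (int k = 1; k < numPartitions; k++) {
            // k-th of numPartitions quantiles; stored in reverse so boundaries are descending.
            int idx = (int) ((long) k * sorted.length / numPartitions);
            boundaries[numPartitions - 1 - k] = sorted[idx];
        }
        return boundaries;
    }

    /** Partition id = number of boundaries strictly greater than the score (0 = highest scores). */
    static int partitionOf(double score, double[] descBoundaries) {
        int id = 0;
        while (id < descBoundaries.length && descBoundaries[id] > score) {
            id++;
        }
        return id;
    }

    public static void main(String[] args) {
        double[] sample = {0.4, 0.1, 0.8, 0.3, 0.6, 0.2, 0.7, 0.5};
        double[] boundaries = boundariesFromSample(sample, 4);
        System.out.println(Arrays.toString(boundaries));
        for (double score : new double[] {0.95, 0.6, 0.05}) {
            System.out.println(score + " -> partition " + partitionOf(score, boundaries));
        }
    }
}
```

   With many partitions, the linear scan in `partitionOf` could be replaced by a binary search over the boundaries.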





[GitHub] [flink-ml] lindong28 commented on pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
lindong28 commented on PR #86:
URL: https://github.com/apache/flink-ml/pull/86#issuecomment-1127537152

   Thanks for the update. LGTM. Will merge this PR after all comments are addressed.




[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r859468461


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,731 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The output may contain different metrics
+ * which will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator

Review Comment:
   I have added support for both rawPrediction formats.
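   The updated Javadoc in a later revision describes the two accepted rawPrediction formats: a double (a 0/1 prediction or the probability of label 1) or a length-2 vector. Below is a minimal sketch of reducing both formats to one score per row. It is not the PR's `ParseSample`. To stay self-contained it uses a plain `double[]` instead of Flink ML's `DenseVector`, and it assumes that index 1 of the vector refers to label 1.

```java
/** Hypothetical helper: turns either rawPrediction format into a single score for label 1. */
public class RawPredictionScore {

    static double toScore(Object rawPrediction) {
        if (rawPrediction instanceof Double) {
            // Scalar format: a 0/1 prediction or the probability of label 1.
            return (Double) rawPrediction;
        }
        if (rawPrediction instanceof double[]) {
            double[] values = (double[]) rawPrediction;
            if (values.length != 2) {
                throw new IllegalArgumentException(
                        "Expected a length-2 vector, got length " + values.length);
            }
            // Vector format: assume index 1 holds the raw score or probability of label 1.
            return values[1];
        }
        throw new IllegalArgumentException("Unsupported rawPrediction type: " + rawPrediction);
    }

    public static void main(String[] args) {
        System.out.println(toScore(0.8));
        System.out.println(toScore(new double[] {0.3, 0.7}));
    }
}
```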





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r858175007


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+
+/**
+ * Params of BinaryClassificationEvaluator.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    /**
+     * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and
+     * 'areaUnderLorenz').
+     */
+    Param<String[]> METRICS_NAMES =
+            new StringArrayParam(
+                    "metricsNames",
+                    "Names of output metrics. The array element must be 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz'",
+                    new String[] {"areaUnderROC", "areaUnderPR"},
+                    ParamValidators.nonEmptyArray());

Review Comment:
   I agree.
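   As quoted, METRICS_NAMES is validated only with `nonEmptyArray()`, so an unsupported or misspelled name such as 'areaUnderRoc' would pass validation. A stricter check would require every element to be one of the supported names. The sketch below uses a locally defined `Validator` interface as a stand-in for Flink ML's `ParamValidator`. The `isSubSet` name is hypothetical.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class SubsetValidatorSketch {

    /** Local stand-in for a ParamValidator-style interface (hypothetical, for illustration). */
    interface Validator<T> {
        boolean validate(T value);
    }

    /** Accepts a non-null, non-empty array whose elements all appear in allowed. */
    @SafeVarargs
    static <T> Validator<T[]> isSubSet(T... allowed) {
        Set<T> allowedSet = new HashSet<>(Arrays.asList(allowed));
        return value ->
                value != null
                        && value.length > 0
                        && allowedSet.containsAll(Arrays.asList(value));
    }

    public static void main(String[] args) {
        Validator<String[]> metricsValidator =
                isSubSet("areaUnderROC", "areaUnderPR", "KS", "areaUnderLorenz");
        System.out.println(metricsValidator.validate(new String[] {"areaUnderROC", "KS"}));
        System.out.println(metricsValidator.validate(new String[] {"areaUnderRoc"}));
    }
}
```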





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867595098


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";

Review Comment:
   I used a static final variable in the last version, but LinDong suggested using a local variable instead, as LogisticRegression does.





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867674748


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {

Review Comment:
   Evaluating data in which all labels are 1.0 (or all are 0.0) is meaningless.
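   Here is a concrete illustration. Unweighted areaUnderROC is the fraction of (positive, negative) pairs in which the positive sample scores higher, with ties counted as half. With a single label class there are no such pairs, so the ratio is 0/0. The small pairwise sketch below returns NaN in that case. It is O(P*N) and meant for illustration only; it is not the evaluator's sort-based algorithm.

```java
/** Hypothetical pairwise (Mann-Whitney) AUC sketch, unweighted. */
public class AucSketch {

    static double areaUnderRoc(double[] scores, boolean[] isPositive) {
        long numPos = 0;
        long numNeg = 0;
        for (boolean positive : isPositive) {
            if (positive) {
                numPos++;
            } else {
                numNeg++;
            }
        }
        if (numPos == 0 || numNeg == 0) {
            // No (positive, negative) pair exists, so the ratio below would be 0/0.
            return Double.NaN;
        }
        double wins = 0.0;
        for (int i = 0; i < scores.length; i++) {
            if (!isPositive[i]) {
                continue;
            }
            for (int j = 0; j < scores.length; j++) {
                if (isPositive[j]) {
                    continue;
                }
                if (scores[i] > scores[j]) {
                    wins += 1.0;
                } else if (scores[i] == scores[j]) {
                    wins += 0.5; // ties count as half a correctly ordered pair
                }
            }
        }
        return wins / (numPos * numNeg);
    }

    public static void main(String[] args) {
        System.out.println(areaUnderRoc(
                new double[] {0.9, 0.8, 0.3, 0.1}, new boolean[] {true, false, true, false}));
        System.out.println(areaUnderRoc(new double[] {0.9, 0.1}, new boolean[] {true, true}));
    }
}
```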





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r869958835


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));

Review Comment:
   I have removed this code.
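   Context for the global-sort step in the hunk above: once data is range-partitioned and sorted locally, each task needs only the record counts of the other partitions to turn local positions into global ranks. This appears to be the role of the `BinarySummary` stream that is broadcast to `CalcSampleOrders`, although their bodies are not quoted here, and the real summary may carry more than counts. The sketch below is hypothetical and single-JVM. It assumes partition 0 holds the highest scores.

```java
import java.util.Arrays;

/** Hypothetical sketch: turning local ranks into global ranks from per-partition counts. */
public class GlobalOrderSketch {

    /** Number of records in all partitions that hold higher scores than partitionId. */
    static long offsetOf(int partitionId, long[] partitionCounts) {
        long offset = 0;
        for (int p = 0; p < partitionId; p++) {
            offset += partitionCounts[p];
        }
        return offset;
    }

    /** Global 0-based ranks for one partition whose scores are already sorted descending. */
    static long[] ordersForPartition(int partitionId, double[] sortedScores, long[] partitionCounts) {
        long base = offsetOf(partitionId, partitionCounts);
        long[] orders = new long[sortedScores.length];
        for (int i = 0; i < sortedScores.length; i++) {
            orders[i] = base + i; // records in earlier partitions + local position
        }
        return orders;
    }

    public static void main(String[] args) {
        // Range-partitioned data, each partition sorted in descending order.
        double[][] partitions = {{0.9, 0.85}, {0.7}, {0.4, 0.3, 0.2}};
        // Only these counts need to be shared between tasks.
        long[] counts = new long[partitions.length];
        for (int p = 0; p < partitions.length; p++) {
            counts[p] = partitions[p].length;
        }
        for (int p = 0; p < partitions.length; p++) {
            System.out.println(Arrays.toString(ordersForPartition(p, partitions[p], counts)));
        }
    }
}
```

   The advantage of this design is that the full sorted data never has to be gathered in one place. Only one small summary per partition is broadcast.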





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867650416


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                localAucVariable
+                        .transform(
+                                "calcLocalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .transform(
+                                "calcGlobalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .setParallelism(1);
+
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)

Review Comment:
   Some values in the double array have no meaning on their own; they are just intermediate values used to calculate AUC. To understand what each variable means, see the code that computes it.
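Reading the window function and the final map in the diff, the three entries appear to follow the rank-sum (Mann-Whitney U) form of AUC: value[0] is the sum of the positive samples' ranks (tied scores share their average rank), value[1] is the positive total (the positive count when all weights are 1), value[2] is the negative total, and AUC = (value[0] - value[1] * (value[1] + 1) / 2) / (value[1] * value[2]). The following is a minimal single-machine sketch of that formula, not the PR's distributed implementation. It ignores sample weights and assumes ranks are assigned in ascending score order, since CalcSampleOrders is not shown in this excerpt. The class name RankSumAuc is hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

/**
 * Hypothetical single-machine sketch of the rank-sum (Mann-Whitney U) form of AUC.
 * Unweighted; ranks are assigned in ascending score order (an assumption).
 */
final class RankSumAuc {

    /** Returns {sumOfPositiveRanks, numPositives, numNegatives}, mirroring the double[] in the diff. */
    static double[] aucVariables(double[] scores, boolean[] isPositive) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        // Ascending score order, so rank 1 is the lowest score.
        Arrays.sort(idx, Comparator.comparingDouble(i -> scores[i]));

        double positiveRankSum = 0;
        double numPositives = 0;
        double numNegatives = 0;
        int start = 0;
        while (start < n) {
            // Find the run [start, end) of samples that share the same score.
            int end = start;
            while (end < n && scores[idx[end]] == scores[idx[start]]) {
                end++;
            }
            // Tied samples all receive the average of ranks start+1 .. end.
            double avgRank = (start + 1 + end) / 2.0;
            for (int k = start; k < end; k++) {
                if (isPositive[idx[k]]) {
                    positiveRankSum += avgRank;
                    numPositives++;
                } else {
                    numNegatives++;
                }
            }
            start = end;
        }
        return new double[] {positiveRankSum, numPositives, numNegatives};
    }

    /** Same final formula as the map() in the diff; NaN if either class is missing. */
    static double auc(double[] v) {
        if (v[1] > 0 && v[2] > 0) {
            return (v[0] - v[1] * (v[1] + 1) / 2) / (v[1] * v[2]);
        }
        return Double.NaN;
    }

    public static void main(String[] args) {
        double[] v =
                aucVariables(
                        new double[] {0.1, 0.4, 0.35, 0.8},
                        new boolean[] {false, false, true, true});
        System.out.println(auc(v)); // 0.75
    }
}
```

Averaging ranks within a run of equal scores is what the keyBy(score) window in the diff does with sum / cnt, so ties contribute half a "win" each, as AUC requires.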



[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r858483138


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)

Review Comment:
   I have replaced the two map-reduce steps with aggregate.
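The comment does not show the new code. As a purely hypothetical illustration of what an aggregate-based version of the per-score reduction (the keyBy/window/apply block in the diff) could look like, the sketch below reorganizes that window function into the four methods of Flink's AggregateFunction<IN, ACC, OUT>: createAccumulator, add, getResult and merge. To compile without Flink it does not implement that interface, and the Sample class stands in for Tuple4<score, order, isPositive, weight>; both TieGroupAggregate and Sample are invented names.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: the per-score reduction from the keyBy/window/apply block, written with
 * the method shapes of Flink's AggregateFunction (createAccumulator, add, getResult, merge).
 * It does not implement the Flink interface so that it compiles without Flink on the classpath.
 */
final class TieGroupAggregate {

    /** Stand-in for Tuple4<score, order, isPositive, weight>. */
    static final class Sample {
        final double score;
        final long order;
        final boolean isPositive;
        final double weight;

        Sample(double score, long order, boolean isPositive, double weight) {
            this.score = score;
            this.order = order;
            this.isPositive = isPositive;
            this.weight = weight;
        }
    }

    /** Accumulator layout: {orderSum, count, positiveWeightSum, negativeWeightSum}. */
    double[] createAccumulator() {
        return new double[4];
    }

    /** Folds one sample into the accumulator; all samples in one window share the same score. */
    double[] add(Sample s, double[] acc) {
        acc[0] += s.order;
        acc[1] += 1;
        if (s.isPositive) {
            acc[2] += s.weight;
        } else {
            acc[3] += s.weight;
        }
        return acc;
    }

    /** Combines two partial accumulators by summing each field. */
    double[] merge(double[] a, double[] b) {
        for (int i = 0; i < a.length; i++) {
            a[i] += b[i];
        }
        return a;
    }

    /** Same output as the window function: {averageOrder * positiveSum, positiveSum, negativeSum}. */
    double[] getResult(double[] acc) {
        double averageOrder = acc[0] / acc[1];
        return new double[] {averageOrder * acc[2], acc[2], acc[3]};
    }

    public static void main(String[] args) {
        TieGroupAggregate agg = new TieGroupAggregate();
        double[] acc = agg.createAccumulator();
        // Two samples tied at score 0.5, occupying orders 3 and 4.
        acc = agg.add(new Sample(0.5, 3, true, 1.0), acc);
        acc = agg.add(new Sample(0.5, 4, false, 1.0), acc);
        System.out.println(Arrays.toString(agg.getResult(acc))); // [3.5, 1.0, 1.0]
    }
}
```

With WindowedStream.aggregate(), Flink folds each element into the accumulator as it arrives, so the window state holds one small array per score rather than every buffered element, which the apply() version has to keep until the window fires.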



[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r858335622


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)
+                        .map(
+                                new MapFunction<double[], Double>() {
+                                    @Override
+                                    public Double map(double[] aucVariable) {
+                                        if (aucVariable[1] > 0 && aucVariable[2] > 0) {
+                                            return (aucVariable[0]
+                                                            - 1.
+                                                                    * aucVariable[1]
+                                                                    * (aucVariable[1] + 1)
+                                                                    / 2)
+                                                    / (aucVariable[1] * aucVariable[2]);
+                                        } else {
+                                            return Double.NaN;
+                                        }
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(PARTITION_SUMMARY, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(input, new CalcBinaryMetrics());
+                        });
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+        DataStream<Row> evalResult =
+                metrics.map(
+                        new MapFunction<Map<String, Double>, Row>() {
+                            @Override
+                            public Row map(Map<String, Double> value) {
+                                Row ret = new Row(metricsNames.length);
+                                for (int i = 0; i < metricsNames.length; ++i) {
+                                    ret.setField(i, value.get(metricsNames[i]));
+                                }
+                                return ret;
+                            }
+                        },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            if (summary == null) {
+                summary =
+                        new BinarySummary(
+                                getRuntimeContext().getIndexOfThisSubtask(),
+                                -Double.MAX_VALUE,
+                                0,
+                                0);
+            }
+            updateBinarySummary(summary, streamRecord.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            summaryState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "summaryState",
+                                            TypeInformation.of(BinarySummary.class)));
+            summary =
+                    OperatorStateUtils.getUniqueElement(summaryState, "summaryState").orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            summaryState.clear();
+            if (summary != null) {
+                summaryState.add(summary);
+            }
+        }
+    }
+
+    private static class CalcAucOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<double[], double[]>, BoundedOneInput {
+        private ListState<double[]> aucVariableState;
+        private double[] aucVariable;
+
+        @Override
+        public void endInput() {
+            if (aucVariable != null) {
+                output.collect(new StreamRecord<>(aucVariable));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<double[]> streamRecord) {
+            if (aucVariable == null) {
+                aucVariable = streamRecord.getValue();
+            } else {
+                double[] tmpAucVar = streamRecord.getValue();
+                aucVariable[0] += tmpAucVar[0];
+                aucVariable[1] += tmpAucVar[1];
+                aucVariable[2] += tmpAucVar[2];
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            aucVariableState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "aucState", TypeInformation.of(double[].class)));
+            aucVariable =
+                    OperatorStateUtils.getUniqueElement(aucVariableState, "aucState").orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            aucVariableState.clear();
+            if (aucVariable != null) {
+                aucVariableState.add(aucVariable);
+            }
+        }
+    }
+
+    /** Merges the metrics calculated locally and output metrics data. */
+    private static class MergeMetrics
+            implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> {
+        private static final long serialVersionUID = 463407033215369847L;
+
+        @Override
+        public void mapPartition(
+                Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) {
+            Iterator<BinaryMetrics> iter = values.iterator();
+            BinaryMetrics reduceMetrics = iter.next();
+            while (iter.hasNext()) {
+                reduceMetrics = reduceMetrics.merge(iter.next());
+            }
+            Map<String, Double> map = new HashMap<>();
+            map.put(AREA_UNDER_ROC, reduceMetrics.areaUnderROC);
+            map.put(AREA_UNDER_PR, reduceMetrics.areaUnderPR);
+            map.put(AREA_UNDER_LORENZ, reduceMetrics.areaUnderLorenz);
+            map.put(KS, reduceMetrics.ks);
+            out.collect(map);
+        }
+    }
+
+    private static class CalcBinaryMetrics
+            extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, BinaryMetrics> {
+        private static final long serialVersionUID = 5680342197308160013L;
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple3<Double, Boolean, Double>> iterable,
+                Collector<BinaryMetrics> collector) {
+
+            List<BinarySummary> statistics =
+                    getRuntimeContext().getBroadcastVariable(PARTITION_SUMMARY);
+            Tuple2<Boolean, long[]> t =
+                    reduceBinarySummary(statistics, getRuntimeContext().getIndexOfThisSubtask());
+            long[] countValues = t.f1;
+
+            double areaUnderROC =
+                    getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC).get(0);
+            long totalTrue = countValues[2];
+            long totalFalse = countValues[3];
+            if (totalTrue == 0) {
+                System.out.println("There is no positive sample in data!");
+            }
+            if (totalFalse == 0) {
+                System.out.println("There is no negative sample in data!");
+            }
+
+            BinaryMetrics metrics = new BinaryMetrics(0L, areaUnderROC);
+            double[] tprFprPrecision = new double[4];
+            for (Tuple3<Double, Boolean, Double> t3 : iterable) {
+                updateBinaryMetrics(t3, metrics, countValues, tprFprPrecision);
+            }
+            collector.collect(metrics);
+        }
+    }
+
+    private static void updateBinaryMetrics(
+            Tuple3<Double, Boolean, Double> cur,
+            BinaryMetrics binaryMetrics,
+            long[] countValues,
+            double[] recordValues) {
+        if (binaryMetrics.count == 0) {
+            recordValues[0] = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
+            recordValues[1] = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+            recordValues[2] =
+                    countValues[0] + countValues[1] == 0
+                            ? 1.0
+                            : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
+            recordValues[3] =
+                    1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+        }
+
+        binaryMetrics.count++;
+        if (cur.f1) {
+            countValues[0]++;
+        } else {
+            countValues[1]++;
+        }
+
+        double tpr = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
+        double fpr = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+        double precision =
+                countValues[0] + countValues[1] == 0
+                        ? 1.0
+                        : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
+        double positiveRate =
+                1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+
+        binaryMetrics.areaUnderLorenz +=
+                ((positiveRate - recordValues[3]) * (tpr + recordValues[0]) / 2);
+        binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision + recordValues[2]) / 2);
+        binaryMetrics.ks = Math.max(Math.abs(fpr - tpr), binaryMetrics.ks);
+
+        recordValues[0] = tpr;
+        recordValues[1] = fpr;
+        recordValues[2] = precision;
+        recordValues[3] = positiveRate;

Review Comment:
   I have defined a class (BinaryMetrics) to represent the metric results. This double array stores the intermediate values that are used to calculate those metrics.
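
To make the role of that intermediate array concrete, here is a minimal single-machine sketch of the same incremental (trapezoidal) update. It is not the PR's code. The class `TrapezoidMetricsSketch` is hypothetical. It assumes the labels are already sorted by descending score, and it ignores sample weights, score ties, and the per-partition prefix counts that the distributed version carries in `countValues`. The array `prev` holds the previous point's tpr, fpr, precision and positive rate, playing the same role as `recordValues`.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not the PR's implementation): incremental trapezoidal accumulation of
 * areaUnderPR, areaUnderLorenz and KS over labels already sorted by descending score.
 * Sample weights and score ties are ignored.
 */
public class TrapezoidMetricsSketch {

    /** Returns {areaUnderPR, areaUnderLorenz, ks}. */
    public static double[] compute(boolean[] labelsSortedByScoreDesc) {
        long total = labelsSortedByScoreDesc.length;
        long totalPos = 0;
        for (boolean isPositive : labelsSortedByScoreDesc) {
            if (isPositive) {
                totalPos++;
            }
        }
        long totalNeg = total - totalPos;

        // Intermediate values of the previous point: {tpr, fpr, precision, positiveRate}.
        // Precision is taken as 1.0 before any sample is predicted positive.
        double[] prev = {totalPos == 0 ? 1.0 : 0.0, totalNeg == 0 ? 1.0 : 0.0, 1.0, 0.0};
        long truePos = 0;
        long falsePos = 0;
        double areaUnderPR = 0.0;
        double areaUnderLorenz = 0.0;
        double ks = 0.0;

        for (boolean isPositive : labelsSortedByScoreDesc) {
            if (isPositive) {
                truePos++;
            } else {
                falsePos++;
            }
            double tpr = totalPos == 0 ? 1.0 : (double) truePos / totalPos;
            double fpr = totalNeg == 0 ? 1.0 : (double) falsePos / totalNeg;
            double precision = (double) truePos / (truePos + falsePos);
            double positiveRate = (double) (truePos + falsePos) / total;

            // Add the trapezoid between the previous point and the current point.
            areaUnderPR += (tpr - prev[0]) * (precision + prev[2]) / 2;
            areaUnderLorenz += (positiveRate - prev[3]) * (tpr + prev[0]) / 2;
            ks = Math.max(ks, Math.abs(fpr - tpr));

            // The current point becomes the "previous" point for the next sample.
            prev[0] = tpr;
            prev[1] = fpr;
            prev[2] = precision;
            prev[3] = positiveRate;
        }
        return new double[] {areaUnderPR, areaUnderLorenz, ks};
    }

    public static void main(String[] args) {
        boolean[] labels = {true, true, false, true, false, false};
        System.out.println(Arrays.toString(compute(labels)));
    }
}
```

Worked by hand, the six labels in `main` give areaUnderPR = 65/72, areaUnderLorenz = 25/36 and KS = 2/3. Only the previous point is needed to add each trapezoid, which is why a small fixed-size array is enough.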





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r874281167


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassfication/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,724 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryclassfication;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.AbstractWrapperOperator;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * An Estimator which calculates the evaluation metrics for binary classification. The input data
+ * has columns rawPrediction, label and an optional weight column. The rawPrediction can be of type
+ * double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of
+ * raw predictions, scores, or label probabilities). The output may contain different metrics which
+ * will be defined by parameter MetricsNames. See {@link BinaryClassificationEvaluatorParams}.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final Logger LOG = LoggerFactory.getLogger(AbstractWrapperOperator.class);

Review Comment:
   nit: should be BinaryClassificationEvaluator.class





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867684157


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));

Review Comment:
   Got it. Maybe we could look into further optimizations in a follow-up PR.





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867592164


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;

Review Comment:
   This variable is used by partitionSort only to repartition the data across workers; each worker then sorts its own data locally. The value 100 is just an empirical choice.
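
A minimal single-process sketch of the sample-then-partition idea, for illustration only. The class `RangePartitionSketch` and its methods are hypothetical, `numSamples` stands in for `NUM_SAMPLE_FOR_RANGE_PARTITION`, and the per-subtask sampling and boundary broadcast in the PR are not modeled. Scores are sorted ascending here; the sort direction is an assumption of the sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical sketch (not the PR's implementation) of sample-based range partitioning:
 * sample scores, pick boundaries, route each score to a range, then sort each range locally.
 */
public class RangePartitionSketch {

    /** Picks numPartitions - 1 ascending boundaries from a random sample of the scores. */
    public static double[] boundaries(double[] scores, int numPartitions, int numSamples, long seed) {
        int n = Math.min(numSamples, scores.length);
        if (n == 0) {
            return new double[0];
        }
        Random random = new Random(seed);
        double[] sample = new double[n];
        for (int i = 0; i < n; i++) {
            sample[i] = scores[random.nextInt(scores.length)];
        }
        Arrays.sort(sample);
        double[] bounds = new double[numPartitions - 1];
        for (int i = 1; i < numPartitions; i++) {
            bounds[i - 1] = sample[(int) ((long) i * n / numPartitions)];
        }
        return bounds;
    }

    /** Returns the number of boundaries <= score, i.e. a partition index in [0, bounds.length]. */
    public static int partitionOf(double score, double[] bounds) {
        int lo = 0;
        int hi = bounds.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (bounds[mid] <= score) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    /** Routes every score to its range, then sorts each range locally (ascending). */
    public static List<List<Double>> partitionAndSort(
            double[] scores, int numPartitions, int numSamples, long seed) {
        double[] bounds = boundaries(scores, numPartitions, numSamples, seed);
        List<List<Double>> partitions = new ArrayList<>();
        for (int p = 0; p < numPartitions; p++) {
            partitions.add(new ArrayList<>());
        }
        for (double score : scores) {
            partitions.get(partitionOf(score, bounds)).add(score);
        }
        for (List<Double> partition : partitions) {
            Collections.sort(partition); // each worker sorts only its own range
        }
        return partitions;
    }

    public static void main(String[] args) {
        Random random = new Random(42);
        double[] scores = new double[1000];
        for (int i = 0; i < scores.length; i++) {
            scores[i] = random.nextDouble();
        }
        for (List<Double> partition : partitionAndSort(scores, 4, 100, 7L)) {
            System.out.println(partition.size());
        }
    }
}
```

Because the partition index never decreases as the score grows, concatenating the locally sorted partitions in index order yields a globally sorted sequence. In this sketch the sample size only affects how evenly the data is spread across partitions, not the correctness of the final order.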
   





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r864421566


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+
+/**
+ * Params of BinaryClassificationEvaluator.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    String AREA_UNDER_ROC = "areaUnderROC";
+    String AREA_UNDER_PR = "areaUnderPR";
+    String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    String KS = "KS";
+
+    /**
+     * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and

Review Comment:
   nit: Javadocs should start with an uppercase letter.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));

Review Comment:
   Buffering every record of the stream in a `List` can easily cause OOM errors, so code blocks like this are expected to be removed as part of the Flink ML 2.1 release plan to optimize performance over the next few weeks. Could you find a way to avoid storing or sorting all records entirely in memory? I will try to optimize the `mapPartition` implementation so that the infrastructure itself is not a concern.
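   One possible direction, sketched below in plain Java under the assumption that many records share the same score: collapse records with equal scores into weighted positive/negative counts kept in a sorted map. Memory then grows with the number of distinct scores rather than the number of records, and partial results from different subtasks can be merged. `ScoreHistogram` is a hypothetical name, not an existing Flink ML class, and the sketch does not help when nearly every score is distinct.

   ```java
   import java.util.Comparator;
   import java.util.Map;
   import java.util.TreeMap;

   /** Hypothetical helper (not Flink ML API): weighted label counts per distinct score. */
   class ScoreHistogram {
       // score -> {positive weight, negative weight}, iterated from the highest score to the lowest.
       private final TreeMap<Double, double[]> counts = new TreeMap<>(Comparator.reverseOrder());

       void add(double score, boolean isPositive, double weight) {
           double[] c = counts.computeIfAbsent(score, k -> new double[2]);
           c[isPositive ? 0 : 1] += weight;
       }

       /** Merges a partial histogram, e.g. one built by another subtask. */
       void merge(ScoreHistogram other) {
           for (Map.Entry<Double, double[]> e : other.counts.entrySet()) {
               double[] c = counts.computeIfAbsent(e.getKey(), k -> new double[2]);
               c[0] += e.getValue()[0];
               c[1] += e.getValue()[1];
           }
       }

       int distinctScores() {
           return counts.size();
       }

       /** AUC from a threshold sweep; records with tied scores form a single trapezoid step. */
       double auc() {
           double totalPos = 0;
           double totalNeg = 0;
           for (double[] c : counts.values()) {
               totalPos += c[0];
               totalNeg += c[1];
           }
           if (totalPos == 0 || totalNeg == 0) {
               return Double.NaN; // undefined when only one class is present
           }
           double area = 0;
           double tp = 0;
           for (double[] c : counts.values()) {
               // Trapezoid in (FP, TP) count space: width c[1], heights tp and tp + c[0].
               area += c[1] * (tp + tp + c[0]) / 2.0;
               tp += c[0];
           }
           return area / (totalPos * totalNeg);
       }
   }
   ```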



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {

Review Comment:
   I noticed that this method builds a relatively complicated JobGraph. Is it possible to simplify its structure? For example, do we have to sort all records before proceeding to the next step? Do we need to compute areaUnderROC before all the other metrics? Can we avoid calling `withBroadcastStream` and running reduce operations multiple times?
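   For reference when discussing which steps are required: the final `map` in `transform` appears to apply the rank-sum (Mann-Whitney) form of AUC, `(R_pos - P * (P + 1) / 2) / (P * N)`, where `R_pos` is the sum of the positives' ranks, `P` the number of positives and `N` the number of negatives. A single-machine sketch of that formula follows, assuming unit weights and ranks assigned in ascending score order, with tied scores sharing their average rank. `RankSumAuc` is a hypothetical helper, not code from this PR.

   ```java
   import java.util.Arrays;
   import java.util.Comparator;

   /** Hypothetical reference (not the PR's code): rank-sum AUC with unit weights. */
   final class RankSumAuc {
       private RankSumAuc() {}

       static double auc(double[] scores, boolean[] isPositive) {
           int n = scores.length;
           Integer[] order = new Integer[n];
           for (int k = 0; k < n; k++) {
               order[k] = k;
           }
           // Ascending by score, so rank 1 belongs to the lowest score.
           Arrays.sort(order, Comparator.comparingDouble(x -> scores[x]));

           double rankSumPos = 0;
           long numPos = 0;
           int start = 0;
           while (start < n) {
               // [start, end] is a run of tied scores; all of them get the average rank.
               int end = start;
               while (end + 1 < n && scores[order[end + 1]] == scores[order[start]]) {
                   end++;
               }
               double avgRank = (start + 1 + end + 1) / 2.0;
               for (int k = start; k <= end; k++) {
                   if (isPositive[order[k]]) {
                       rankSumPos += avgRank;
                       numPos++;
                   }
               }
               start = end + 1;
           }
           long numNeg = n - numPos;
           if (numPos == 0 || numNeg == 0) {
               return Double.NaN; // same NaN fallback as the final map in transform
           }
           return (rankSumPos - numPos * (numPos + 1) / 2.0) / ((double) numPos * numNeg);
       }
   }
   ```

   Note that `P` and `N` are only known after all records have been seen, while the rank sum depends on a global order; those two requirements are what any simplified JobGraph would still have to satisfy.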



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                localAucVariable
+                        .transform(
+                                "calcLocalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .transform(
+                                "calcGlobalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .setParallelism(1);
+
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)
+                                                / (value[1] * value[2]);
+                                    } else {
+                                        return Double.NaN;
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(partitionSummariesKey, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalcBinaryMetrics(partitionSummariesKey));
+                        });
+
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+
+        DataStream<Row> evalResult =
+                metrics.map(
+                        (MapFunction<Map<String, Double>, Row>)
+                                value -> {
+                                    Row ret = new Row(metricsNames.length);
+                                    for (int i = 0; i < metricsNames.length; ++i) {
+                                        ret.setField(i, value.get(metricsNames[i]));
+                                    }
+                                    return ret;
+                                },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    /** Updates variables for calculating Auc. */
+    private static class AucOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<double[], double[]>, BoundedOneInput {
+        private ListState<double[]> valueState;
+        private double[] value;
+
+        @Override
+        public void endInput() {
+            if (value != null) {
+                output.collect(new StreamRecord<>(value));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<double[]> streamRecord) {
+            double[] tmpValues = streamRecord.getValue();
+            value[0] += tmpValues[0];
+            value[1] += tmpValues[1];
+            value[2] += tmpValues[2];
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            valueState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "valueState", TypeInformation.of(double[].class)));
+            value =
+                    OperatorStateUtils.getUniqueElement(valueState, "valueState")
+                            .orElse(new double[3]);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            valueState.clear();
+            if (value != null) {
+                valueState.add(value);
+            }
+        }
+    }
+
+    /** Updates variables for calculating Auc. */
+    private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>,
+                    BoundedOneInput {
+        private ListState<double[]> accValueState;
+        double[] accValue;
+        double score;
+
+        @Override
+        public void endInput() {
+            if (accValue != null) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+            }
+        }
+
+        @Override
+        public void processElement(
+                StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) {
+            Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue();
+            if (accValue == null) {
+                accValue = new double[4];
+                score = t.f0;
+            } else if (score != t.f0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+                Arrays.fill(accValue, 0.0);
+                score = t.f0;
+            }
+            accValue[0] += t.f1;
+            accValue[1] += 1.0;
+            if (t.f2) {
+                accValue[2] += t.f3;
+            } else {
+                accValue[3] += t.f3;
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            accValueState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "accValueState", TypeInformation.of(double[].class)));
+            accValue =
+                    OperatorStateUtils.getUniqueElement(accValueState, "accValueState")
+                            .orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            accValueState.clear();
+            if (accValue != null) {
+                accValueState.add(accValue);
+            }
+        }
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            updateBinarySummary(summary, streamRecord.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            summaryState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "summaryState",
+                                            TypeInformation.of(BinarySummary.class)));
+            summary =
+                    OperatorStateUtils.getUniqueElement(summaryState, "summaryState")
+                            .orElse(
+                                    new BinarySummary(
+                                            getRuntimeContext().getIndexOfThisSubtask(),
+                                            -Double.MAX_VALUE,
+                                            0,
+                                            0));
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            summaryState.clear();
+            if (summary != null) {
+                summaryState.add(summary);
+            }
+        }
+    }
+
+    /** Merges the metrics calculated locally and output metrics data. */
+    private static class MergeMetrics
+            implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> {
+        private static final long serialVersionUID = 463407033215369847L;
+
+        @Override
+        public void mapPartition(
+                Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) {
+            Iterator<BinaryMetrics> iter = values.iterator();
+            BinaryMetrics reduceMetrics = iter.next();
+            while (iter.hasNext()) {
+                reduceMetrics = reduceMetrics.merge(iter.next());
+            }
+            Map<String, Double> map = new HashMap<>();

Review Comment:
   Would it be better to use `Map<String, Double>` from the beginning, instead of introducing a new `BinaryMetrics` class?
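   If the per-partition results were plain maps, merging them could be a key-wise sum, as in the sketch below. This assumes every value in the partial maps is an additive statistic (for example a weighted count); ratio metrics such as AUC or precision would still have to be derived after merging. `MetricMaps.mergeAdditive` is a hypothetical helper.

   ```java
   import java.util.HashMap;
   import java.util.List;
   import java.util.Map;

   /** Hypothetical helper (not Flink ML API): key-wise sum of per-partition metric maps. */
   final class MetricMaps {
       private MetricMaps() {}

       static Map<String, Double> mergeAdditive(List<Map<String, Double>> partials) {
           Map<String, Double> merged = new HashMap<>();
           for (Map<String, Double> partial : partials) {
               // Keys missing from a partition contribute nothing; present keys are summed.
               partial.forEach((key, value) -> merged.merge(key, value, Double::sum));
           }
           return merged;
       }
   }
   ```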



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {

Review Comment:
   Can we add more test cases for corner cases? For example, all labels are 1.0 (or all are 0.0), or the labels are independent of the rawPrediction values. For these corner cases we can generate test data as follows:
   ```java
   final List<Row> INPUT_DATA = new ArrayList<>();
   for (double i = 0.1; i < 1.0; i += 0.001) {
       INPUT_DATA.add(Row.of(i < 0.9 ? 1.0 : 0.0, Vectors.dense(i, 1.0 - i)));
   }
   ```
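   Expected values for such corner cases could be checked against a brute-force oracle that compares every positive/negative pair and counts ties as one half. It runs in O(P * N) time, which is fine for small test inputs. `PairwiseAuc` is a hypothetical test helper; it takes one positive-class score per row and labels encoded as 1.0/0.0, so mapping the `rawPrediction` vector to a single score is left to the caller.

   ```java
   /** Hypothetical test oracle: O(P * N) pairwise AUC, counting ties as one half. */
   final class PairwiseAuc {
       private PairwiseAuc() {}

       static double auc(double[] scores, double[] labels) {
           double wins = 0;
           long pairs = 0;
           for (int i = 0; i < scores.length; i++) {
               if (labels[i] != 1.0) {
                   continue; // i ranges over positives only
               }
               for (int j = 0; j < scores.length; j++) {
                   if (labels[j] != 0.0) {
                       continue; // j ranges over negatives only
                   }
                   pairs++;
                   if (scores[i] > scores[j]) {
                       wins += 1.0;
                   } else if (scores[i] == scores[j]) {
                       wins += 0.5;
                   }
               }
           }
           // With all labels equal there is no positive/negative pair, so AUC is undefined.
           return pairs == 0 ? Double.NaN : wins / pairs;
       }
   }
   ```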



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;

Review Comment:
   According to our offline discussion, this algorithm would not support `numBins`, given that it would cause large deviations, as pointed out in [tensorflow's issue](https://github.com/tensorflow/tensorflow/issues/14834). But as far as I can see, `NUM_SAMPLE_FOR_RANGE_PARTITION` serves a similar function and could cause the same error. Could you please explain the difference between `numBins` and this variable?
   
   Besides, why should we choose a fixed value of `100` for this variable, given that the scale of the input data varies? Should we also add tests where the amount of input data is larger than `the number of samples * the parallelism of the subtasks`?
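To make the question concrete, here is a minimal, single-machine sketch of sample-based range partitioning, under assumptions not confirmed by the diff: boundaries are picked from a random sample (with replacement) of the scores, each record goes to the partition given by how many boundaries lie strictly below its score, and each partition is then sorted locally. The class and method names (`RangePartitionSketch`, `boundariesFromSample`, `sampleSize`) are hypothetical, and `sampleSize` only loosely plays the role of `NUM_SAMPLE_FOR_RANGE_PARTITION`. In this sketch the sample size changes only how evenly records are spread across partitions; concatenating the sorted partitions yields an exactly sorted sequence for any sample size. A `numBins`-style approach, by contrast, replaces exact scores with bucket thresholds. Whether the PR has the same property depends on how `AppendTaskId`, `CalcSampleOrders` and the metric operators use the partitions, which this excerpt only partly shows.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

public class RangePartitionSketch {
    /** Picks numPartitions - 1 split points from a sorted random sample of the scores. */
    static double[] boundariesFromSample(
            double[] scores, int sampleSize, int numPartitions, long seed) {
        Random random = new Random(seed);
        double[] sample = new double[sampleSize];
        for (int i = 0; i < sampleSize; i++) {
            sample[i] = scores[random.nextInt(scores.length)]; // sampling with replacement
        }
        Arrays.sort(sample);
        double[] bounds = new double[numPartitions - 1];
        for (int i = 1; i < numPartitions; i++) {
            bounds[i - 1] = sample[i * sampleSize / numPartitions];
        }
        return bounds; // ascending, because the sample is sorted
    }

    /** Partition index = number of boundaries strictly below the score (monotone in score). */
    static int partitionOf(double score, double[] bounds) {
        int p = 0;
        while (p < bounds.length && bounds[p] < score) {
            p++;
        }
        return p;
    }

    /** Range-partitions the scores, sorts each partition, then concatenates partitions in order. */
    static double[] sortViaRangePartition(double[] scores, double[] bounds) {
        List<List<Double>> partitions = new ArrayList<>();
        for (int i = 0; i <= bounds.length; i++) {
            partitions.add(new ArrayList<>());
        }
        for (double s : scores) {
            partitions.get(partitionOf(s, bounds)).add(s);
        }
        double[] result = new double[scores.length];
        int pos = 0;
        for (List<Double> partition : partitions) {
            partition.sort(null); // natural ascending order
            for (double s : partition) {
                result[pos++] = s;
            }
        }
        return result;
    }

    /** Number of records per partition; this is what the sample size actually influences. */
    static int[] partitionSizes(double[] scores, double[] bounds) {
        int[] sizes = new int[bounds.length + 1];
        for (double s : scores) {
            sizes[partitionOf(s, bounds)]++;
        }
        return sizes;
    }
}
```

With very few samples the boundaries can be badly placed, so one partition may receive most of the records (a load-balancing issue), but the ordering itself is unaffected in this model.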



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns

Review Comment:
   The first sentence of the Javadoc for an Estimator/AlgoOperator class usually starts with a noun. It would be better to follow the practice of existing Javadocs.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                localAucVariable
+                        .transform(
+                                "calcLocalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .transform(
+                                "calcGlobalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .setParallelism(1);
+
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)

Review Comment:
   It might be hard for reviewers to remember the meaning of `value[0]` and `value[1]`. To improve the readability of this PR, would it be better to add Javadocs explaining the meaning of each element in the double arrays and Tuple objects, or to give each such element a meaningful variable name?
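As an illustration of the named-field option, the sketch below replaces an anonymous `double[]` with a small class. The expression in the diff, `(value[0] - value[1] * (value[1] + 1) / 2) / (value[1] * value[2])`, matches the Mann-Whitney rank-sum form of AUC if `value[0]` is the sum of the positives' ranks, `value[1]` the number of positives and `value[2]` the number of negatives; that mapping is inferred from the formula, not confirmed by the diff. The names `AucRankSumSketch`, `AucAccumulator`, `positiveRankSum`, `numPositives` and `numNegatives` are hypothetical. This is a single-machine sketch that ranks scores in ascending order and assumes there are no tied scores (ties would need average ranks).

```java
import java.util.Arrays;
import java.util.Comparator;

public class AucRankSumSketch {
    /** Named fields in place of an anonymous double[]; the field names are hypothetical. */
    static final class AucAccumulator {
        double positiveRankSum; // sum of 1-based ranks (ascending by score) of positive samples
        long numPositives;
        long numNegatives;

        /** Mann-Whitney form: (R_pos - P(P+1)/2) / (P * N); NaN when either class is empty. */
        double areaUnderROC() {
            if (numPositives == 0 || numNegatives == 0) {
                return Double.NaN;
            }
            return (positiveRankSum - numPositives * (numPositives + 1) / 2.0)
                    / ((double) numPositives * numNegatives);
        }
    }

    /** Computes AUC on a single machine; assumes no tied scores. */
    static double auc(double[] scores, boolean[] isPositive) {
        Integer[] order = new Integer[scores.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Sort sample indices by ascending score so that rank 1 is the lowest score.
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> scores[i]));

        AucAccumulator acc = new AucAccumulator();
        for (int rank = 1; rank <= order.length; rank++) {
            if (isPositive[order[rank - 1]]) {
                acc.positiveRankSum += rank;
                acc.numPositives++;
            } else {
                acc.numNegatives++;
            }
        }
        return acc.areaUnderROC();
    }
}
```

For example, `auc(new double[] {0.9, 0.3, 0.5, 0.1}, new boolean[] {true, true, false, false})` returns 0.75: three of the four (positive, negative) pairs are ranked with the positive above the negative.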



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";

Review Comment:
   nit: It might be better to define these broadcast keys as static final fields of the class, instead of defining them as local variables and passing them to the operators' constructors.
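A minimal sketch of what the suggestion could look like, written without Flink dependencies so it runs standalone. In the real operators the lookup would presumably be `getRuntimeContext().getBroadcastVariable(...)` inside a rich function; here a hypothetical `FakeRuntimeContext` stands in for Flink's `RuntimeContext` and models only that lookup. `BroadcastKeySketch`, `AppendTaskIdSketch` and `contextWithBoundaries` are illustrative names, not part of the PR. The point shown is that the operator's constructor no longer takes a key; it reads the class-level constant directly.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class BroadcastKeySketch {
    // Broadcast keys declared once as class constants instead of local variables.
    static final String BOUNDARY_RANGE_KEY = "boundaryRange";
    static final String PARTITION_SUMMARIES_KEY = "partitionSummaries";

    /** Hypothetical stand-in for Flink's RuntimeContext; only the broadcast lookup is modeled. */
    static final class FakeRuntimeContext {
        private final Map<String, List<?>> broadcastVariables = new HashMap<>();

        void putBroadcastVariable(String key, List<?> value) {
            broadcastVariables.put(key, value);
        }

        @SuppressWarnings("unchecked")
        <T> List<T> getBroadcastVariable(String key) {
            return (List<T>) broadcastVariables.get(key);
        }
    }

    /** Hypothetical operator: its constructor takes no key; it reads the class constant. */
    static final class AppendTaskIdSketch {
        private FakeRuntimeContext runtimeContext;

        // Mirrors how Flink injects the runtime context into rich functions after construction.
        void setRuntimeContext(FakeRuntimeContext runtimeContext) {
            this.runtimeContext = runtimeContext;
        }

        /** Task id = number of range boundaries strictly below the score. */
        int taskIdOf(double score) {
            List<double[]> broadcast = runtimeContext.getBroadcastVariable(BOUNDARY_RANGE_KEY);
            double[] boundaries = broadcast.get(0);
            int taskId = 0;
            while (taskId < boundaries.length && boundaries[taskId] < score) {
                taskId++;
            }
            return taskId;
        }
    }

    /** Builds a context holding one boundary array under BOUNDARY_RANGE_KEY. */
    static FakeRuntimeContext contextWithBoundaries(double... boundaries) {
        FakeRuntimeContext context = new FakeRuntimeContext();
        context.putBroadcastVariable(BOUNDARY_RANGE_KEY, Collections.singletonList(boundaries));
        return context;
    }
}
```

Keeping the keys in one place means the pipeline code that registers a broadcast stream and the operator that reads it cannot drift apart through a mistyped string.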



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                localAucVariable
+                        .transform(
+                                "calcLocalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .transform(
+                                "calcGlobalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .setParallelism(1);
+
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)
+                                                / (value[1] * value[2]);
+                                    } else {
+                                        return Double.NaN;
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(partitionSummariesKey, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalcBinaryMetrics(partitionSummariesKey));
+                        });
+
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();

Review Comment:
   nit: `ArrayUtils.addAll` might be better.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));

Review Comment:
   The watermark does not seem to be used in the rest of this method, so code like this can be removed.
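The hunk above range-partitions the data before sorting. `getBoundaryRange(evalData)` is broadcast, `AppendTaskId` tags each record with a partition id in field `f3`, and `partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3)` routes the record to that partition. The bodies of `getBoundaryRange` and `AppendTaskId` are not quoted, so the following is only a plain-Java sketch of the general technique. It assumes quantile cut points taken from a sample (the PR declares `NUM_SAMPLE_FOR_RANGE_PARTITION = 100`) and assumes partition 0 receives the highest scores, to match the descending local sort. `RangePartitionSketch` and its methods are hypothetical names.

```java
import java.util.Arrays;

public class RangePartitionSketch {
    /**
     * Picks numPartitions - 1 cut points from a sample of scores. The returned
     * boundaries are in descending order, so partition 0 receives the highest scores.
     */
    public static double[] boundaries(double[] sample, int numPartitions) {
        if (sample.length == 0 || numPartitions < 1) {
            throw new IllegalArgumentException("need a non-empty sample and at least one partition");
        }
        double[] sorted = sample.clone();
        Arrays.sort(sorted); // ascending
        double[] bounds = new double[numPartitions - 1];
        for (int k = 1; k < numPartitions; k++) {
            // Cut just below the top k/numPartitions fraction of the sample.
            int pos = sorted.length - 1 - (int) ((long) k * sorted.length / numPartitions);
            bounds[k - 1] = sorted[pos];
        }
        return bounds;
    }

    /**
     * Partition k holds scores in (bounds[k], bounds[k - 1]]. Equal scores always map
     * to the same partition, so ties are never split across partitions.
     */
    public static int partitionId(double score, double[] bounds) {
        int id = 0;
        while (id < bounds.length && score <= bounds[id]) {
            id++;
        }
        return id;
    }

    public static void main(String[] args) {
        double[] scores = {0.05, 0.9, 0.3, 0.75, 0.5, 0.1, 0.6, 0.95};
        double[] bounds = boundaries(scores, 4);
        for (double s : scores) {
            System.out.println(s + " -> partition " + partitionId(s, bounds));
        }
    }
}
```

Because equal scores share a partition and a higher score never lands in a later partition, sorting each partition by descending score and reading partitions 0, 1, 2, ... in turn yields a globally sorted sequence. With many partitions, a binary search over `bounds` would be preferable to the linear scan.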


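After the local sort, `PartitionSummaryOperator` emits one `BinarySummary` per partition, and `CalcSampleOrders` receives all summaries as a broadcast and emits `<score, order, isPositive, weight>`. Neither class body is quoted here. Assuming each summary carries at least its partition's sample count, a global order can be derived with prefix sums, as in this hypothetical `GlobalOrderSketch`. Orders count from the highest score; an ascending rank is `total + 1 - order`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class GlobalOrderSketch {
    /** Prefix sums: offsets[p] is the number of samples held by partitions 0..p-1. */
    public static long[] offsets(long[] partitionCounts) {
        long[] offsets = new long[partitionCounts.length];
        long running = 0;
        for (int p = 0; p < partitionCounts.length; p++) {
            offsets[p] = running;
            running += partitionCounts[p];
        }
        return offsets;
    }

    /**
     * Sorts each range partition by descending score and returns {score, order} pairs,
     * where order is the 1-based global position counted from the highest score.
     * Assumes partition 0 holds the highest score range.
     */
    public static List<double[]> assignOrders(List<List<Double>> partitions) {
        long[] counts = new long[partitions.size()];
        for (int p = 0; p < counts.length; p++) {
            counts[p] = partitions.get(p).size();
        }
        long[] offsets = offsets(counts);
        List<double[]> result = new ArrayList<>();
        for (int p = 0; p < partitions.size(); p++) {
            // ArrayList.sort sorts its backing array in place.
            List<Double> local = new ArrayList<>(partitions.get(p));
            local.sort(Comparator.reverseOrder());
            for (int i = 0; i < local.size(); i++) {
                result.add(new double[] {local.get(i), offsets[p] + i + 1});
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<List<Double>> parts = new ArrayList<>();
        parts.add(Arrays.asList(0.7, 0.9)); // highest score range
        parts.add(Arrays.asList(0.5, 0.6));
        parts.add(Arrays.asList(0.1)); // lowest score range
        for (double[] pair : assignOrders(parts)) {
            System.out.println(pair[0] + " -> order " + (long) pair[1]);
        }
    }
}
```

The buffer is an `ArrayList` because `List.sort` on a `LinkedList` (as in the quoted `mapPartition`) falls back to copying the elements into an array, sorting it and writing them back, whereas `ArrayList` overrides `sort` to work on its backing array directly.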

##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+        DataStream<double[]> localAucVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                localAucVariable
+                        .transform(
+                                "calcLocalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .transform(
+                                "calcGlobalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .setParallelism(1);

Review Comment:
   The LinearRegression PR introduces `DataStreamUtils.reduce()` to support reduce operations like this one. We can refer to that PR to see how this infrastructure code can be shared.
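The two chained `AucOperator` transforms quoted above form a two-stage reduction: a local combine in each subtask, then a global combine at parallelism 1. Only part of `AucOperator` is quoted, so the element-wise summation below is an assumption. It is valid for the rank sum only because the orders assigned earlier are global, so partial rank sums from different partitions add up to the global rank sum. The sketch uses `java.util.stream` rather than guessing the signature of the `DataStreamUtils.reduce()` mentioned in the review; `AucReduceSketch` is a hypothetical name.

```java
import java.util.Arrays;
import java.util.List;

public class AucReduceSketch {
    /** Element-wise sum of two {rankSum, numPositives, numNegatives} accumulators. */
    public static double[] merge(double[] a, double[] b) {
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + b[i];
        }
        return out;
    }

    /** Global stage: combines all per-partition accumulators into one. */
    public static double[] reduceAll(List<double[]> partials) {
        return partials.stream().reduce(new double[3], AucReduceSketch::merge);
    }

    /** Same expression as the areaUnderROC map in the PR. */
    public static double auc(double[] v) {
        if (v[1] > 0 && v[2] > 0) {
            return (v[0] - 1. * v[1] * (v[1] + 1) / 2) / (v[1] * v[2]);
        }
        return Double.NaN;
    }

    public static void main(String[] args) {
        // Two partitions holding ascending ranks {1, 2} and {3, 4}; positives at ranks 2 and 4.
        double[] merged = reduceAll(Arrays.asList(new double[] {2, 1, 1}, new double[] {4, 1, 1}));
        System.out.println(Arrays.toString(merged) + " -> AUC " + auc(merged));
    }
}
```

For example, partitions contributing `{2, 1, 1}` and `{4, 1, 1}` merge to `{6, 2, 2}`, for which `auc` returns 0.75.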



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)
+                                                / (value[1] * value[2]);
+                                    } else {
+                                        return Double.NaN;
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(partitionSummariesKey, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalcBinaryMetrics(partitionSummariesKey));
+                        });
+
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+
+        DataStream<Row> evalResult =
+                metrics.map(
+                        (MapFunction<Map<String, Double>, Row>)
+                                value -> {
+                                    Row ret = new Row(metricsNames.length);
+                                    for (int i = 0; i < metricsNames.length; ++i) {
+                                        ret.setField(i, value.get(metricsNames[i]));
+                                    }
+                                    return ret;
+                                },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    /** Updates variables for calculating Auc. */
+    private static class AucOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<double[], double[]>, BoundedOneInput {
+        private ListState<double[]> valueState;
+        private double[] value;
+
+        @Override
+        public void endInput() {
+            if (value != null) {

Review Comment:
   nit: `value` is always initialized in `initializeState`, so it cannot be `null` here.
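The review point relies on the operator lifecycle: Flink calls `initializeState` before any `processElement` or `endInput` call, so if `initializeState` assigns `value` unconditionally (from the restored snapshot or as a fresh array), `endInput` never observes `null`. The plain-Java class below mimics that lifecycle, with a `List<double[]>` standing in for Flink's `ListState<double[]>`. The class, its methods and the three-slot accumulator layout are hypothetical; the real operator body is not quoted here.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class AucStateSketch {
    private double[] value;

    /** Mirrors initializeState: restores the snapshot if present, else starts from zeros. */
    public void initializeState(List<double[]> restoredState) {
        // After this method returns, value is never null.
        value = restoredState.isEmpty() ? new double[3] : restoredState.get(0).clone();
    }

    /** Mirrors processElement: folds one partial accumulator into value. */
    public void processElement(double[] partial) {
        for (int i = 0; i < value.length; i++) {
            value[i] += partial[i];
        }
    }

    /** Mirrors endInput: no null check, because initializeState always ran first. */
    public double[] endInput() {
        return value;
    }

    /** Mirrors snapshotState: stores the current accumulator as the single state element. */
    public List<double[]> snapshotState() {
        return Collections.singletonList(value.clone());
    }

    public static void main(String[] args) {
        AucStateSketch op = new AucStateSketch();
        op.initializeState(Collections.emptyList()); // fresh start, no snapshot
        op.processElement(new double[] {2, 1, 1});
        op.processElement(new double[] {4, 1, 1});
        System.out.println(Arrays.toString(op.endInput())); // prints [6.0, 2.0, 2.0]
    }
}
```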



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by the parameter MetricsNames. See {@link BinaryClassificationEvaluatorParams}.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                localAucVariable
+                        .transform(
+                                "calcLocalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .transform(
+                                "calcGlobalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .setParallelism(1);
+
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)
+                                                / (value[1] * value[2]);
+                                    } else {
+                                        return Double.NaN;
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(partitionSummariesKey, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalcBinaryMetrics(partitionSummariesKey));
+                        });
+
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+
+        DataStream<Row> evalResult =
+                metrics.map(
+                        (MapFunction<Map<String, Double>, Row>)
+                                value -> {
+                                    Row ret = new Row(metricsNames.length);
+                                    for (int i = 0; i < metricsNames.length; ++i) {
+                                        ret.setField(i, value.get(metricsNames[i]));
+                                    }
+                                    return ret;
+                                },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    /** Updates variables for calculating Auc. */

Review Comment:
   It might be better to improve the JavaDoc for the private static classes in this class. For example, it is not clear to me how it would "update" variables, what the "variables" to be updated are, and what "Auc" means. The same applies to the other JavaDocs. Besides, we could reorder the static classes and methods to be consistent with other classes, or to follow the order in which they appear in `transform()`.
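   For reference when rewriting these JavaDocs: the formula in the `areaUnderROC` map above looks like the rank-sum (Mann-Whitney) form of AUC. With P positives, N negatives and R the sum of the ascending ranks of the positives, AUC = (R - P(P+1)/2) / (P * N), and samples with tied scores share their average rank, which appears to be what `AccumulateMultiScoreOperator` computes per score group. Below is a minimal single-machine sketch under those assumptions; `RankSumAuc` is a hypothetical helper, not part of Flink ML, and unlike the diff it ignores the weight column.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical single-machine helper; not part of the Flink ML API. */
public class RankSumAuc {

    /**
     * AUC = (R - P(P+1)/2) / (P * N), where R is the sum of the 1-based ascending ranks of the
     * positive samples. Samples with equal scores share their average rank. Weights are ignored.
     */
    public static double auc(double[] scores, boolean[] isPositive) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        // Ascending score: the sample with the minimum score gets rank 1.
        Arrays.sort(idx, Comparator.comparingDouble(i -> scores[i]));

        double rankSumPos = 0.0;
        long numPos = 0;
        long numNeg = 0;
        int start = 0;
        while (start < n) {
            // Find the run [start, end) of samples that share the same score.
            int end = start;
            while (end < n && scores[idx[end]] == scores[idx[start]]) {
                end++;
            }
            // Ranks start+1 .. end are averaged for the tied group.
            double avgRank = (start + 1 + end) / 2.0;
            for (int k = start; k < end; k++) {
                if (isPositive[idx[k]]) {
                    rankSumPos += avgRank;
                    numPos++;
                } else {
                    numNeg++;
                }
            }
            start = end;
        }
        // Same convention as the diff: AUC is undefined without both classes.
        if (numPos == 0 || numNeg == 0) {
            return Double.NaN;
        }
        return (rankSumPos - numPos * (numPos + 1) / 2.0) / ((double) numPos * numNeg);
    }

    public static void main(String[] args) {
        double[] scores = {0.1, 0.4, 0.35, 0.8};
        boolean[] labels = {false, false, true, true};
        System.out.println(auc(scores, labels)); // prints 0.75
    }
}
```

   Averaging ranks inside a tied group is what makes the result independent of how equal scores happen to be ordered after sorting.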



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics which will
+ * be defined by the parameter MetricsNames. See {@link BinaryClassificationEvaluatorParams}.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartition the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                localAucVariable
+                        .transform(
+                                "calcLocalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .transform(
+                                "calcGlobalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .setParallelism(1);
+
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)
+                                                / (value[1] * value[2]);
+                                    } else {
+                                        return Double.NaN;
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(partitionSummariesKey, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalcBinaryMetrics(partitionSummariesKey));
+                        });
+
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+
+        DataStream<Row> evalResult =
+                metrics.map(
+                        (MapFunction<Map<String, Double>, Row>)
+                                value -> {
+                                    Row ret = new Row(metricsNames.length);
+                                    for (int i = 0; i < metricsNames.length; ++i) {
+                                        ret.setField(i, value.get(metricsNames[i]));
+                                    }
+                                    return ret;
+                                },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    /** Updates variables for calculating Auc. */
+    private static class AucOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<double[], double[]>, BoundedOneInput {
+        private ListState<double[]> valueState;
+        private double[] value;
+
+        @Override
+        public void endInput() {
+            if (value != null) {
+                output.collect(new StreamRecord<>(value));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<double[]> streamRecord) {
+            double[] tmpValues = streamRecord.getValue();
+            value[0] += tmpValues[0];
+            value[1] += tmpValues[1];
+            value[2] += tmpValues[2];
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            valueState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "valueState", TypeInformation.of(double[].class)));
+            value =
+                    OperatorStateUtils.getUniqueElement(valueState, "valueState")
+                            .orElse(new double[3]);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            valueState.clear();
+            if (value != null) {
+                valueState.add(value);
+            }
+        }
+    }
+
+    /** Updates variables for calculating Auc. */
+    private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>,
+                    BoundedOneInput {
+        private ListState<double[]> accValueState;
+        double[] accValue;
+        double score;
+
+        @Override
+        public void endInput() {
+            if (accValue != null) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+            }
+        }
+
+        @Override
+        public void processElement(
+                StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) {
+            Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue();
+            if (accValue == null) {
+                accValue = new double[4];
+                score = t.f0;
+            } else if (score != t.f0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+                Arrays.fill(accValue, 0.0);
+                score = t.f0;
+            }
+            accValue[0] += t.f1;
+            accValue[1] += 1.0;
+            if (t.f2) {
+                accValue[2] += t.f3;
+            } else {
+                accValue[3] += t.f3;
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            accValueState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "accValueState", TypeInformation.of(double[].class)));
+            accValue =
+                    OperatorStateUtils.getUniqueElement(accValueState, "accValueState")
+                            .orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            accValueState.clear();
+            if (accValue != null) {
+                accValueState.add(accValue);
+            }
+        }
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            updateBinarySummary(summary, streamRecord.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            summaryState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "summaryState",
+                                            TypeInformation.of(BinarySummary.class)));
+            summary =
+                    OperatorStateUtils.getUniqueElement(summaryState, "summaryState")
+                            .orElse(
+                                    new BinarySummary(
+                                            getRuntimeContext().getIndexOfThisSubtask(),
+                                            -Double.MAX_VALUE,
+                                            0,
+                                            0));
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            summaryState.clear();
+            if (summary != null) {
+                summaryState.add(summary);
+            }
+        }
+    }
+
+    /** Merges the locally calculated metrics and outputs the merged result. */
+    private static class MergeMetrics
+            implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> {
+        private static final long serialVersionUID = 463407033215369847L;
+
+        @Override
+        public void mapPartition(
+                Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) {
+            Iterator<BinaryMetrics> iter = values.iterator();
+            BinaryMetrics reduceMetrics = iter.next();
+            while (iter.hasNext()) {
+                reduceMetrics = reduceMetrics.merge(iter.next());
+            }
+            Map<String, Double> map = new HashMap<>();
+            map.put(AREA_UNDER_ROC, reduceMetrics.areaUnderROC);
+            map.put(AREA_UNDER_PR, reduceMetrics.areaUnderPR);
+            map.put(AREA_UNDER_LORENZ, reduceMetrics.areaUnderLorenz);
+            map.put(KS, reduceMetrics.ks);
+            out.collect(map);
+        }
+    }
+
+    private static class CalcBinaryMetrics
+            extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, BinaryMetrics> {
+        private static final long serialVersionUID = 5680342197308160013L;
+        private final String partitionSummariesKey;
+
+        public CalcBinaryMetrics(String partitionSummariesKey) {
+            this.partitionSummariesKey = partitionSummariesKey;
+        }
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple3<Double, Boolean, Double>> iterable,
+                Collector<BinaryMetrics> collector) {
+
+            List<BinarySummary> statistics =
+                    getRuntimeContext().getBroadcastVariable(partitionSummariesKey);
+            long[] countValues =
+                    reduceBinarySummary(statistics, getRuntimeContext().getIndexOfThisSubtask());
+
+            double areaUnderROC =
+                    getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC).get(0);
+            long totalTrue = countValues[2];
+            long totalFalse = countValues[3];
+            if (totalTrue == 0) {
+                LOG.warn("There is no positive sample in data!");
+            }
+            if (totalFalse == 0) {
+                LOG.warn("There is no negative sample in data!");
+            }
+
+            BinaryMetrics metrics = new BinaryMetrics(0L, areaUnderROC);
+            double[] tprFprPrecision = new double[4];
+            for (Tuple3<Double, Boolean, Double> t3 : iterable) {
+                updateBinaryMetrics(t3, metrics, countValues, tprFprPrecision);
+            }
+            collector.collect(metrics);
+        }
+    }
+
+    private static void updateBinaryMetrics(
+            Tuple3<Double, Boolean, Double> cur,
+            BinaryMetrics binaryMetrics,
+            long[] countValues,
+            double[] recordValues) {
+        if (binaryMetrics.count == 0) {
+            recordValues[0] = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
+            recordValues[1] = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+            recordValues[2] =
+                    countValues[0] + countValues[1] == 0
+                            ? 1.0
+                            : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
+            recordValues[3] =
+                    1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+        }
+
+        binaryMetrics.count++;
+        if (cur.f1) {
+            countValues[0]++;
+        } else {
+            countValues[1]++;
+        }
+
+        double tpr = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
+        double fpr = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+        double precision =
+                countValues[0] + countValues[1] == 0
+                        ? 1.0
+                        : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
+        double positiveRate =
+                1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+
+        binaryMetrics.areaUnderLorenz +=
+                ((positiveRate - recordValues[3]) * (tpr + recordValues[0]) / 2);
+        binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision + recordValues[2]) / 2);
+        binaryMetrics.ks = Math.max(Math.abs(fpr - tpr), binaryMetrics.ks);
+
+        recordValues[0] = tpr;
+        recordValues[1] = fpr;
+        recordValues[2] = precision;
+        recordValues[3] = positiveRate;
+    }
+
+    /**
+     * For each sample, calculates its score order among all samples. The sample with the minimum
+     * score has order 1, while the sample with the maximum score has order equal to the total number
+     * of samples.
+     *
+     * <p>Input is a dataset of tuples (score, is real positive, weight); output is a dataset of tuples
+     * (score, order, is real positive, weight).
+     */
+    private static class CalcSampleOrders
+            extends RichFlatMapFunction<
+                    Tuple3<Double, Boolean, Double>, Tuple4<Double, Long, Boolean, Double>> {
+        private static final long serialVersionUID = 3047511137846831576L;

Review Comment:
   Is `serialVersionUID` necessary for this operator? Existing operator classes in other algorithms do not need `serialVersionUID`, so I think it is OK to remove it.

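The metric-update code at the top of this excerpt appears to sweep the locally sorted samples one at a time (highest score first, per the descending local sort in the `transform()` method quoted in this thread). At each step it recomputes TPR, FPR, precision and positive rate from cumulative counts, adds a trapezoid to the areas under the PR and Lorenz curves, and tracks KS as the largest |FPR - TPR|. For reference, here is a minimal single-partition sketch of that sweep. It is not the PR's code: `BinaryMetricsSweep` and `sweep` are hypothetical names, samples are unweighted, every sample is treated as its own threshold, and the cross-partition bookkeeping done through `countValues`/`recordValues` is left out.

```java
import java.util.Arrays;
import java.util.Comparator;

final class BinaryMetricsSweep {
    /**
     * Returns {areaUnderPR, areaUnderLorenz, ks} for unweighted samples, treating every sample as
     * its own threshold (tied scores are not grouped).
     */
    static double[] sweep(double[] scores, boolean[] labels) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        // Visit samples from the highest score to the lowest.
        Arrays.sort(idx, Comparator.comparingDouble(i -> -scores[i]));

        long totalPos = 0;
        for (boolean label : labels) {
            if (label) {
                totalPos++;
            }
        }
        long totalNeg = n - totalPos;

        long tp = 0;
        long fp = 0;
        // Values at the previous threshold, starting at the "nothing predicted positive" point.
        double prevTpr = totalPos == 0 ? 1.0 : 0.0;
        double prevPrecision = 1.0;
        double prevPositiveRate = 0.0;
        double areaUnderPR = 0.0;
        double areaUnderLorenz = 0.0;
        double ks = 0.0;
        for (int i : idx) {
            if (labels[i]) {
                tp++;
            } else {
                fp++;
            }
            double tpr = totalPos == 0 ? 1.0 : (double) tp / totalPos;
            double fpr = totalNeg == 0 ? 1.0 : (double) fp / totalNeg;
            double precision = (double) tp / (tp + fp); // tp + fp >= 1 inside the loop
            double positiveRate = (double) (tp + fp) / n;

            // Trapezoidal rule between the previous and the current threshold.
            areaUnderLorenz += (positiveRate - prevPositiveRate) * (tpr + prevTpr) / 2;
            areaUnderPR += (tpr - prevTpr) * (precision + prevPrecision) / 2;
            // KS statistic: largest gap between the TPR and FPR curves.
            ks = Math.max(ks, Math.abs(fpr - tpr));

            prevTpr = tpr;
            prevPrecision = precision;
            prevPositiveRate = positiveRate;
        }
        return new double[] {areaUnderPR, areaUnderLorenz, ks};
    }
}
```

For perfectly separated labels, `sweep(new double[] {0.9, 0.8, 0.3, 0.1}, new boolean[] {true, true, false, false})` yields areaUnderPR = 1.0, areaUnderLorenz = 0.75 and KS = 1.0.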




[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867661821


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The rawPrediction can be of type double
+ * (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw
+ * predictions, scores, or label probabilities). The output may contain different metrics, which are
+ * specified by the parameter MetricsNames. See {@link BinaryClassificationEvaluatorParams}.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+        final String boundaryRangeKey = "boundaryRange";
+        final String partitionSummariesKey = "partitionSummaries";
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(boundaryRangeKey, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId(boundaryRangeKey));
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data by score in descending order. */
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(partitionSummariesKey, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders(partitionSummariesKey));
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders.transform(
+                        "AccumulateMultiScore",
+                        TypeInformation.of(double[].class),
+                        new AccumulateMultiScoreOperator());
+
+        DataStream<double[]> middleAreaUnderROC =
+                localAucVariable
+                        .transform(
+                                "calcLocalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .transform(
+                                "calcGlobalAucValues",
+                                TypeInformation.of(double[].class),
+                                new AucOperator())
+                        .setParallelism(1);
+
+        DataStream<Double> areaUnderROC =
+                middleAreaUnderROC.map(
+                        (MapFunction<double[], Double>)
+                                value -> {
+                                    if (value[1] > 0 && value[2] > 0) {
+                                        return (value[0] - 1. * value[1] * (value[1] + 1) / 2)
+                                                / (value[1] * value[2]);
+                                    } else {
+                                        return Double.NaN;
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(partitionSummariesKey, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalcBinaryMetrics(partitionSummariesKey));
+                        });
+
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+
+        DataStream<Row> evalResult =
+                metrics.map(
+                        (MapFunction<Map<String, Double>, Row>)
+                                value -> {
+                                    Row ret = new Row(metricsNames.length);
+                                    for (int i = 0; i < metricsNames.length; ++i) {
+                                        ret.setField(i, value.get(metricsNames[i]));
+                                    }
+                                    return ret;
+                                },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    /** Updates variables for calculating Auc. */
+    private static class AucOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<double[], double[]>, BoundedOneInput {
+        private ListState<double[]> valueState;
+        private double[] value;
+
+        @Override
+        public void endInput() {
+            if (value != null) {
+                output.collect(new StreamRecord<>(value));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<double[]> streamRecord) {
+            double[] tmpValues = streamRecord.getValue();
+            value[0] += tmpValues[0];
+            value[1] += tmpValues[1];
+            value[2] += tmpValues[2];
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            valueState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "valueState", TypeInformation.of(double[].class)));
+            value =
+                    OperatorStateUtils.getUniqueElement(valueState, "valueState")
+                            .orElse(new double[3]);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            valueState.clear();
+            if (value != null) {
+                valueState.add(value);
+            }
+        }
+    }
+
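+    /**
+     * Groups consecutive samples that share the same score, assigns them their average order to
+     * handle ties, and emits (average order * positive weight, positive weight, negative weight).
+     */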
+    private static class AccumulateMultiScoreOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<Tuple4<Double, Long, Boolean, Double>, double[]>,
+                    BoundedOneInput {
+        private ListState<double[]> accValueState;
+        double[] accValue;
+        double score;
+
+        @Override
+        public void endInput() {
+            if (accValue != null) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+            }
+        }
+
+        @Override
+        public void processElement(
+                StreamRecord<Tuple4<Double, Long, Boolean, Double>> streamRecord) {
+            Tuple4<Double, Long, Boolean, Double> t = streamRecord.getValue();
+            if (accValue == null) {
+                accValue = new double[4];
+                score = t.f0;
+            } else if (score != t.f0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new double[] {
+                                    accValue[0] / accValue[1] * accValue[2],
+                                    accValue[2],
+                                    accValue[3]
+                                }));
+                Arrays.fill(accValue, 0.0);
+            }
+            accValue[0] += t.f1;
+            accValue[1] += 1.0;
+            if (t.f2) {
+                accValue[2] += t.f3;
+            } else {
+                accValue[3] += t.f3;
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            accValueState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "accValueState", TypeInformation.of(double[].class)));
+            accValue =
+                    OperatorStateUtils.getUniqueElement(accValueState, "accValueState")
+                            .orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            accValueState.clear();
+            if (accValue != null) {
+                accValueState.add(accValue);
+            }
+        }
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            updateBinarySummary(summary, streamRecord.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            summaryState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "summaryState",
+                                            TypeInformation.of(BinarySummary.class)));
+            summary =
+                    OperatorStateUtils.getUniqueElement(summaryState, "summaryState")
+                            .orElse(
+                                    new BinarySummary(
+                                            getRuntimeContext().getIndexOfThisSubtask(),
+                                            -Double.MAX_VALUE,
+                                            0,
+                                            0));
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            summaryState.clear();
+            if (summary != null) {
+                summaryState.add(summary);
+            }
+        }
+    }
+
+    /** Merges the metrics calculated locally and outputs the metrics data. */
+    private static class MergeMetrics
+            implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> {
+        private static final long serialVersionUID = 463407033215369847L;
+
+        @Override
+        public void mapPartition(
+                Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) {
+            Iterator<BinaryMetrics> iter = values.iterator();
+            BinaryMetrics reduceMetrics = iter.next();
+            while (iter.hasNext()) {
+                reduceMetrics = reduceMetrics.merge(iter.next());
+            }
+            Map<String, Double> map = new HashMap<>();

Review Comment:
   I don't agree with you.

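The AUC code quoted above relies on the rank-sum (Mann-Whitney) identity: with samples ordered by ascending score (minimum score has order 1), AUC = (R_pos - n_pos(n_pos + 1)/2) / (n_pos * n_neg), where R_pos is the sum of the orders of the positive samples; `AccumulateMultiScoreOperator` gives tied scores their average order. Below is a sequential sketch with unit weights (the PR instead uses summed positive and negative weights in place of n_pos and n_neg); `RankSumAuc` is a hypothetical name.

```java
import java.util.Arrays;
import java.util.Comparator;

final class RankSumAuc {
    /** Returns the AUC of the scores, or NaN if either class is empty. */
    static double auc(double[] scores, boolean[] labels) {
        int n = scores.length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        // Ascending by score: the sample with the minimum score gets order 1.
        Arrays.sort(idx, Comparator.comparingDouble(i -> scores[i]));

        double positiveRankSum = 0.0;
        long numPos = 0;
        long numNeg = 0;
        int start = 0;
        while (start < n) {
            // Find the run [start, end) of samples sharing the same score.
            int end = start;
            while (end < n && scores[idx[end]] == scores[idx[start]]) {
                end++;
            }
            // The run covers orders start + 1 .. end; tied samples share their average.
            double avgOrder = (start + 1 + end) / 2.0;
            for (int k = start; k < end; k++) {
                if (labels[idx[k]]) {
                    positiveRankSum += avgOrder;
                    numPos++;
                } else {
                    numNeg++;
                }
            }
            start = end;
        }
        if (numPos == 0 || numNeg == 0) {
            return Double.NaN; // Same fallback as the quoted areaUnderROC map function.
        }
        return (positiveRankSum - numPos * (numPos + 1) / 2.0) / ((double) numPos * numNeg);
    }
}
```

For example, `RankSumAuc.auc(new double[] {0.1, 0.4, 0.35, 0.8}, new boolean[] {false, false, true, true})` returns 0.75: the positives hold orders 2 and 4, so (6 - 3) / (2 * 2) = 0.75.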




[GitHub] [flink-ml] weibozhao commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r867656982


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,783 @@
+    /** Updates variables for calculating Auc. */

Review Comment:
   The code has been removed.






[GitHub] [flink-ml] lindong28 commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r856870066


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+
+/**
+ * Params of BinaryClassificationEvaluator.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    /**
+     * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and

Review Comment:
   Would it be useful to explain the definition of each metric here?
   
   Note that Spark ML provides the following Java doc in `BinaryClassificationMetrics.scala` [1]:
   
   areaUnderROC: the area under the receiver operating characteristic (ROC) curve.
   
   areaUnderPR: the area under the precision-recall curve.
   
   [1] https://spark.apache.org/docs/2.0.2/api/java/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.html
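As a hedged illustration of what areaUnderROC, areaUnderPR and KS measure (not the parallel, range-partitioned implementation in this PR), the single-machine Java sketch below sorts weighted samples by descending score, treats each distinct score as one threshold, and integrates the ROC and precision-recall curves with the trapezoidal rule. The class and method names are hypothetical. It assumes positive weights and that both classes are present, and it assumes the PR curve starts at (recall 0, precision 1); other libraries such as Spark may pick a different starting point, so areaUnderPR can differ slightly. areaUnderLorenz is left out.

```java
import java.util.Arrays;
import java.util.Comparator;

/**
 * Hypothetical single-machine sketch of areaUnderROC, areaUnderPR and KS for weighted samples.
 * Assumes positive weights and that both classes are present.
 */
public class BinaryMetricsSketch {

    /** Returns {areaUnderROC, areaUnderPR, KS}; a higher score means "more likely label 1". */
    public static double[] compute(double[] scores, boolean[] isPositive, double[] weights) {
        int n = scores.length;
        Integer[] order = new Integer[n];
        double totalPos = 0.0;
        double totalNeg = 0.0;
        for (int k = 0; k < n; k++) {
            order[k] = k;
            if (isPositive[k]) {
                totalPos += weights[k];
            } else {
                totalNeg += weights[k];
            }
        }
        // Visit samples from the highest score to the lowest; each prefix is "predicted positive".
        Arrays.sort(order, Comparator.comparingDouble((Integer j) -> scores[j]).reversed());

        double tp = 0.0;
        double fp = 0.0;
        double prevTpr = 0.0;
        double prevFpr = 0.0;
        // Assumed convention: the PR curve starts at (recall = 0, precision = 1).
        double prevPrecision = 1.0;
        double auc = 0.0;
        double aupr = 0.0;
        double ks = 0.0;
        int pos = 0;
        while (pos < n) {
            // Samples sharing a score form one threshold step, so ties do not depend on order.
            double threshold = scores[order[pos]];
            while (pos < n && scores[order[pos]] == threshold) {
                int idx = order[pos];
                if (isPositive[idx]) {
                    tp += weights[idx];
                } else {
                    fp += weights[idx];
                }
                pos++;
            }
            double tpr = tp / totalPos; // TPR is also the recall
            double fpr = fp / totalNeg;
            double precision = tp / (tp + fp);
            // Trapezoidal rule on ROC (x = FPR, y = TPR) and on PR (x = recall, y = precision).
            auc += (fpr - prevFpr) * (tpr + prevTpr) / 2.0;
            aupr += (tpr - prevTpr) * (precision + prevPrecision) / 2.0;
            // KS: the largest absolute gap between TPR and FPR over all thresholds.
            ks = Math.max(ks, Math.abs(tpr - fpr));
            prevTpr = tpr;
            prevFpr = fpr;
            prevPrecision = precision;
        }
        return new double[] {auc, aupr, ks};
    }
}
```

For scores {0.9, 0.8, 0.3, 0.1} with labels {1, 0, 1, 0} and unit weights, three of the four positive/negative pairs are ranked correctly, which is why areaUnderROC comes out as 0.75.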
   
   
   



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns

Review Comment:
   Since the input data is not evaluated yet, maybe change `input evaluated data` to `input data`
   
   `may contains` -> `may contain`
   
   Instead of explicitly listing the metrics here, would it be simpler to just say `Please refer to the Java doc of the metricsNames parameter for the list of supported metrics`?
   
   Do we need to explain the supported types of the rawPrediction column here, similar to Spark ML's Java doc?
   
   Typically, Java doc focuses on the behavior of the algorithm that users need to know. Would it be simpler to remove implementation details such as `we use a parallel method...`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+
+/**
+ * Params of BinaryClassificationEvaluator.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    /**
+     * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and
+     * 'areaUnderLorenz').
+     */
+    Param<String[]> METRICS_NAMES =
+            new StringArrayParam(
+                    "metricsNames",
+                    "Names of output metrics. The array element must be 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz'",
+                    new String[] {"areaUnderROC", "areaUnderPR"},
+                    ParamValidators.nonEmptyArray());

Review Comment:
   Instead of explaining the constraint in text, how about defining and using the following validator in `ParamValidators.java`?
   
   ```
   // Check if every element in the array-typed parameter value is in the array of allowed values.
   public static <T> ParamValidator<T[]> isSubArray(T... allowed) {
       return new ParamValidator<T[]>() {
           @Override
           public boolean validate(T[] value) {
               if (value == null) {
                   return false;
               }
               for (int i = 0; i < value.length; i++) {
                   if (!ArrayUtils.contains(allowed, value[i])) {
                       return false;
                   }
               }
               return true;
           }
       };
   }
   ```
   
   We probably need a unit test for the newly defined validator if we decide to add it.
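A minimal, self-contained sketch of the proposed validator, for trying the idea outside the Flink ML build. It assumes a local stand-in for Flink ML's `ParamValidator` interface (hypothetical here, mirroring the single `validate` method used in `ParamValidators.java`) and uses a `HashSet` instead of commons-lang's `ArrayUtils` so it compiles with no extra dependencies. Like the loop in the suggestion, it rejects `null` and accepts an empty array, so rejecting empty arrays still needs a separate check such as `nonEmptyArray()`.

```java
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class IsSubArraySketch {

    /** Hypothetical stand-in for Flink ML's ParamValidator so the sketch compiles on its own. */
    @FunctionalInterface
    public interface ParamValidator<T> extends Serializable {
        boolean validate(T value);
    }

    /** Checks that every element of the array-typed value is one of the allowed values. */
    @SafeVarargs
    public static <T> ParamValidator<T[]> isSubArray(T... allowed) {
        // Copy the allowed values into a set once so each membership check is O(1) on average.
        // HashSet is Serializable, so the captured state fits a Serializable validator.
        Set<T> allowedSet = new HashSet<>(Arrays.asList(allowed));
        return value -> value != null && allowedSet.containsAll(Arrays.asList(value));
    }
}
```

The checks in the test below mirror the `StageTest` case quoted later in this thread and add the `null` case.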



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contains
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),

Review Comment:
   nits: Would it be simpler to define `boundaryRange` as a local variable in this method, similar to how we define `broadcastModelKey` in `LogisticRegressionModel::transform(...)`?
   
   Same for `partitionSummaries`.
   
   A program is typically simpler and easier to read when it minimizes class-private (or global) variables.





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r859293793


##########
flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java:
##########
@@ -463,5 +463,10 @@ public void testValidators() {
         Assert.assertTrue(nonEmptyArray.validate(new String[] {"1"}));
         Assert.assertFalse(nonEmptyArray.validate(null));
         Assert.assertFalse(nonEmptyArray.validate(new String[0]));
+
+        ParamValidator<String[]> isSubArray = ParamValidators.isSubArray("a", "b", "c");
+        Assert.assertFalse(isSubArray.validate(new String[] {"c", "v"}));

Review Comment:
   nits: could we also check the case where the input value is null, to be consistent with the above tests?
   
   `Assert.assertFalse(isSubArray.validate(null))`



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+
+/**
+ * Params of BinaryClassificationEvaluator.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    /**
+     * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and
+     * 'areaUnderLorenz').
+     *
+     * <p>areaUnderROC: the area under the receiver operating characteristic (ROC) curve.

Review Comment:
   nits: could we update the Java doc here to follow the format used in `HasHandleInvalid.java`? It seems that the format in `HasHandleInvalid.java` is more readable.
   
   Currently, the docs for `KS` and `areaUnderLorenz` are on the same line.
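A sketch of how the metric list could look with one metric per `<li>` item, assuming `HasHandleInvalid.java` uses a `<ul>`/`<li>` Javadoc layout. The interface and constant names are hypothetical. The KS wording is a standard definition rather than text from this PR, and the areaUnderLorenz entry stays at the level of detail the PR already gives.

```java
/** Hypothetical sketch of the metric-list Javadoc written as an HTML list. */
public interface BinaryEvalMetricsDocSketch {
    /**
     * Supported options for the metrics to be computed in the evaluation.
     *
     * <ul>
     *   <li>areaUnderROC: the area under the receiver operating characteristic (ROC) curve.
     *   <li>areaUnderPR: the area under the precision-recall curve.
     *   <li>KS: the Kolmogorov-Smirnov statistic, i.e. the largest gap between the true positive
     *       rate and the false positive rate over all thresholds.
     *   <li>areaUnderLorenz: the area under the Lorenz curve.
     * </ul>
     */
    String[] SUPPORTED_METRICS = {"areaUnderROC", "areaUnderPR", "KS", "areaUnderLorenz"};
}
```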



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =

Review Comment:
   Hmm... since `BinaryClassificationEvaluator` does not do any training (i.e. fit), would it be more intuitive to rename these to `INPUT_DATA*`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,731 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The output may contain different metrics
+ * which will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator

Review Comment:
   According to Spark's Java doc, its `BinaryClassificationEvaluator` supports rawPrediction column of type double (binary 0/1 prediction, or probability of label 1), in addition to the rawPrediction column of type vector (length-2 vector of raw predictions, scores, or label probabilities).
   
   Should we also support a rawPrediction column of type double? If not, how do we handle the use cases that currently require a rawPrediction column of type double in Spark?
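One hedged way to accept both input types, sketched in plain Java. This hypothetical helper takes either a `Double` (binary 0/1 prediction or probability of label 1) or a length-2 `double[]` standing in for a `DenseVector`, and follows Spark's convention for vector input by using element 1 as the score of label 1. In this PR the logic would presumably sit where the rawPrediction column is parsed (the `ParseSample` map), but that placement is an assumption.

```java
/** Hypothetical helper that turns a rawPrediction cell into the score of label 1. */
public class RawPredictionSketch {

    public static double positiveScore(Object rawPrediction) {
        if (rawPrediction instanceof Double) {
            // Double column: already the 0/1 prediction or the probability of label 1.
            return (Double) rawPrediction;
        }
        if (rawPrediction instanceof double[]) {
            // double[] stands in for a DenseVector of per-class raw scores or probabilities.
            double[] values = (double[]) rawPrediction;
            if (values.length != 2) {
                throw new IllegalArgumentException(
                        "Expected a length-2 vector, got length " + values.length);
            }
            // Element 1 holds the score of label 1, as in Spark's vector convention.
            return values[1];
        }
        throw new IllegalArgumentException(
                "Unsupported rawPrediction type: "
                        + (rawPrediction == null ? "null" : rawPrediction.getClass().getName()));
    }
}
```

Normalizing both inputs to a single `double` early keeps the sorting and metric code unchanged regardless of the column type.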
   
   



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =

Review Comment:
   Since we check double value equality using `delta=1.0e-5`, would it be simpler to reduce the precision of the expected values here accordingly?
   
   Same for other expected values initialized in this file.
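   Before rounding, the AUC and KS entries of `EXPECTED_DATA` can be cross-checked by hand. The sketch below is plain Java, not the evaluator's implementation. It assumes element 1 of each `rawPrediction` vector is the positive-class score. That is an assumption, although it is consistent with the expected values. It computes AUC in the pairwise (Mann-Whitney) form and KS as max |TPR - FPR| over score thresholds. The class and method names are hypothetical. For `TRAIN_DATA` this gives AUC = 23/35 and KS = 13/35, so five decimal places (0.65714 and 0.37143) stay within the 1.0e-5 delta.

```java
import java.util.Arrays;
import java.util.Comparator;

/**
 * Hypothetical cross-check of areaUnderROC and KS for TRAIN_DATA, assuming element 1 of each
 * rawPrediction vector is the positive-class score.
 */
public class BinaryMetricsCrossCheck {
    // Labels and positive-class scores copied from TRAIN_DATA in the test.
    static final double[] LABELS = {1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1};
    static final double[] SCORES = {0.9, 0.8, 0.7, 0.75, 0.6, 0.65, 0.55, 0.4, 0.3, 0.35, 0.2, 0.1};

    /** Fraction of (positive, negative) pairs where the positive scores higher; ties count 0.5. */
    static double areaUnderRoc(double[] labels, double[] scores) {
        double correct = 0;
        long pairs = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] != 1.0) continue;
            for (int j = 0; j < labels.length; j++) {
                if (labels[j] == 1.0) continue;
                pairs++;
                if (scores[i] > scores[j]) correct += 1.0;
                else if (scores[i] == scores[j]) correct += 0.5;
            }
        }
        return correct / pairs;
    }

    /** KS statistic: max |TPR - FPR|, evaluated after each group of equal scores. */
    static double ks(double[] labels, double[] scores) {
        Integer[] order = new Integer[scores.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        // Sort row indices by descending score.
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> scores[i]).reversed());
        double pos = 0, neg = 0;
        for (double label : labels) {
            if (label == 1.0) pos++; else neg++;
        }
        double tp = 0, fp = 0, best = 0;
        for (int k = 0; k < order.length; k++) {
            int idx = order[k];
            if (labels[idx] == 1.0) tp++; else fp++;
            // Only evaluate a threshold once every row sharing this score has been consumed.
            boolean lastOfGroup = k + 1 == order.length || scores[order[k + 1]] != scores[idx];
            if (lastOfGroup) best = Math.max(best, Math.abs(tp / pos - fp / neg));
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println("areaUnderROC = " + areaUnderRoc(LABELS, SCORES));
        System.out.println("KS = " + ks(LABELS, SCORES));
    }
}
```

   Note that rounding to five decimals uses up to half of the 1.0e-5 tolerance (the rounding error is at most 5e-6).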



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,731 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.apache.flink.runtime.blob.BlobWriter.LOG;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data has columns
+ * rawPrediction, label and an optional weight column. The output may contain different metrics
+ * which will be defined by parameter MetricsNames. See @BinaryClassificationEvaluatorParams.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String AREA_UNDER_ROC = "areaUnderROC";

Review Comment:
   Would it be better to move these constants to `BinaryClassificationEvaluatorParams` and make them `public`, so that users can set the parameter by referencing the constants instead of typing the string manually?
   
   Typing the string manually is more error-prone, and the IDE cannot offer hints for it.
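   A minimal sketch of what this could look like, assuming the constants live on a params interface (Java interface fields are implicitly `public static final`). `EvaluatorParamsSketch` and `EvaluatorSketch` are hypothetical stand-ins for `BinaryClassificationEvaluatorParams` and `BinaryClassificationEvaluator`; the constant values are the metric names used in this PR.

```java
import java.util.Arrays;

/** Hypothetical sketch of exposing metric names as public constants on a params interface. */
public class MetricNameConstantsSketch {
    /** Stand-in for BinaryClassificationEvaluatorParams; its fields are implicitly public static final. */
    interface EvaluatorParamsSketch {
        String AREA_UNDER_ROC = "areaUnderROC";
        String AREA_UNDER_PR = "areaUnderPR";
        String KS = "KS";
    }

    /** Stand-in for the evaluator; it only stores the requested metric names. */
    static class EvaluatorSketch implements EvaluatorParamsSketch {
        // Same default as asserted in testParam: {"areaUnderROC", "areaUnderPR"}.
        private String[] metricsNames = {AREA_UNDER_ROC, AREA_UNDER_PR};

        EvaluatorSketch setMetricsNames(String... names) {
            this.metricsNames = names;
            return this;
        }

        String[] getMetricsNames() {
            return metricsNames;
        }
    }

    public static void main(String[] args) {
        // Misspelling a constant is a compile-time error; misspelling a string literal is not.
        EvaluatorSketch eval =
                new EvaluatorSketch()
                        .setMetricsNames(EvaluatorParamsSketch.AREA_UNDER_PR, EvaluatorParamsSketch.KS);
        System.out.println(Arrays.toString(eval.getMetricsNames()));
    }
}
```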



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                    Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                    Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1)));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            Arrays.asList(
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                    Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                    Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                    Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                    Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                    Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                    Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                    Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("label", "raw");
+        trainDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_MULTI_SCORE))
+                        .as("label", "raw");
+        trainDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_WEIGHT))
+                        .as("label", "raw", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")
+                        .setRawPredictionCol("raw");
+        BinaryClassificationEvaluator loadedEval =
+                StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath());
+        Table evalResult = loadedEval.transform(trainDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "KS", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), 1.0e-5);
+        }
+    }
+
+    @Test
+    public void testEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")
+                        .setRawPredictionCol("raw");
+        Table evalResult = eval.transform(trainDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        assertArrayEquals(
+                new String[] {"areaUnderPR", "KS", "areaUnderROC"},
+                evalResult.getResolvedSchema().getColumnNames().toArray());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), 1.0e-5);
+        }
+    }
+
+    @Test
+    public void testEvaluateWithMultiScore() throws Exception {

Review Comment:
   Could you explain why this test is named `testEvaluateWithMultiScore` while the test above is named `testEvaluate`? Both tests seem to use multiple evaluation metrics, and their inputs have the same format.
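   One possible reading, which is an assumption and is not stated in the PR: the two datasets differ in how many distinct scores they contain. Counting the second element of each `rawPrediction` vector shows that `TRAIN_DATA` has 12 distinct scores, while `TRAIN_DATA_WITH_MULTI_SCORE` repeats 0.9 six times and has only 7. If that is the intent, a test name that describes the tied scores might be clearer. The class below is a hypothetical helper for checking this.

```java
import java.util.HashSet;
import java.util.Set;

/** Hypothetical helper: counts distinct positive-class scores in the two test datasets. */
public class DistinctScoreCount {
    // Second element of each rawPrediction vector in TRAIN_DATA (copied from the test).
    static final double[] TRAIN_DATA_SCORES =
            {0.9, 0.8, 0.7, 0.75, 0.6, 0.65, 0.55, 0.4, 0.3, 0.35, 0.2, 0.1};

    // Second element of each rawPrediction vector in TRAIN_DATA_WITH_MULTI_SCORE.
    static final double[] MULTI_SCORE_SCORES =
            {0.9, 0.9, 0.9, 0.75, 0.6, 0.9, 0.9, 0.4, 0.3, 0.9, 0.2, 0.1};

    static int distinct(double[] scores) {
        Set<Double> seen = new HashSet<>();
        for (double s : scores) {
            seen.add(s);
        }
        return seen.size();
    }

    public static void main(String[] args) {
        System.out.println("TRAIN_DATA: " + distinct(TRAIN_DATA_SCORES));
        System.out.println("TRAIN_DATA_WITH_MULTI_SCORE: " + distinct(MULTI_SCORE_SCORES));
    }
}
```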



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,216 @@
+    public void testSaveLoadAndEvaluate() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")

Review Comment:
   Since this test is meant to cover save/load/transform, would it be simpler to rely on the default value of the label parameter (i.e. `label`) instead of setting it explicitly? Similarly, can we rely on the default value of rawPredictionCol?
   
   The same applies to the other tests that are not meant to test the parameter setters and getters.


