You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/04/24 01:52:51 UTC

[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

yunfengzhou-hub commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r857051925


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)

Review Comment:
   I noticed that in several places we first apply an operation within each subtask or partition and then follow it with a global reduce/aggregation. I believe this way of improving the efficiency of map-reduce operations is generally applicable to all Flink operators, so Flink should already provide a mechanism for it, and we could use `reduce()`/`aggregate()` directly instead of implementing it on our own. Could you please check whether that is the case?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)
+                        .map(
+                                new MapFunction<double[], Double>() {
+                                    @Override
+                                    public Double map(double[] aucVariable) {
+                                        if (aucVariable[1] > 0 && aucVariable[2] > 0) {
+                                            return (aucVariable[0]
+                                                            - 1.
+                                                                    * aucVariable[1]
+                                                                    * (aucVariable[1] + 1)
+                                                                    / 2)
+                                                    / (aucVariable[1] * aucVariable[2]);
+                                        } else {
+                                            return Double.NaN;
+                                        }
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(PARTITION_SUMMARY, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(input, new CalcBinaryMetrics());
+                        });
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+        DataStream<Row> evalResult =
+                metrics.map(
+                        new MapFunction<Map<String, Double>, Row>() {
+                            @Override
+                            public Row map(Map<String, Double> value) {
+                                Row ret = new Row(metricsNames.length);
+                                for (int i = 0; i < metricsNames.length; ++i) {
+                                    ret.setField(i, value.get(metricsNames[i]));
+                                }
+                                return ret;
+                            }
+                        },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            if (summary == null) {

Review Comment:
   Shall we move this logic to the `orElse(null)` used in `initializeState`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() { // no-arg constructor; its only work is on paramMap
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this); // presumably fills defaults
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) { // builds the metric pipeline for one input table
+        Preconditions.checkArgument(inputs.length == 1); // exactly one input table is accepted
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Routes every record to the partition whose index equals its appended f3 field. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Buffers each partition, drops the f3 routing field and sorts by descending f0. */
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Reduces the records of each partition to at most one BinarySummary. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Orders the data globally. Output Tuple4: <score, order, isPositive, weight>. */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum, // mean order of the key times its positive weight
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1) // the final reduction runs in a single subtask
+                        .map(
+                                new MapFunction<double[], Double>() {
+                                    @Override
+                                    public Double map(double[] aucVariable) {
+                                        if (aucVariable[1] > 0 && aucVariable[2] > 0) {
+                                            return (aucVariable[0]
+                                                            - 1.
+                                                                    * aucVariable[1]
+                                                                    * (aucVariable[1] + 1)
+                                                                    / 2)
+                                                    / (aucVariable[1] * aucVariable[2]);
+                                        } else {
+                                            return Double.NaN; // needs positive and negative weight
+                                        }
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(PARTITION_SUMMARY, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(input, new CalcBinaryMetrics());
+                        });
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE(); // every requested metric becomes a DOUBLE column
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1); // merge all local metrics in one subtask
+        DataStream<Row> evalResult =
+                metrics.map(
+                        new MapFunction<Map<String, Double>, Row>() {
+                            @Override
+                            public Row map(Map<String, Double> value) {
+                                Row ret = new Row(metricsNames.length);
+                                for (int i = 0; i < metricsNames.length; ++i) {
+                                    ret.setField(i, value.get(metricsNames[i])); // column order follows metricsNames
+                                }
+                                return ret;
+                            }
+                        },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            if (summary == null) {
+                summary =
+                        new BinarySummary(
+                                getRuntimeContext().getIndexOfThisSubtask(),
+                                -Double.MAX_VALUE,
+                                0,
+                                0);
+            }
+            updateBinarySummary(summary, streamRecord.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            summaryState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "summaryState",
+                                            TypeInformation.of(BinarySummary.class)));
+            summary =
+                    OperatorStateUtils.getUniqueElement(summaryState, "summaryState").orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            summaryState.clear();
+            if (summary != null) {
+                summaryState.add(summary);
+            }
+        }
+    }
+
+    /**
+     * Sums the first three elements of all incoming {@code double[]} records element-wise and emits
+     * the total once the bounded input ends. Nothing is emitted when no record was received. The
+     * running total is stored in operator list state on every snapshot.
+     */
+    private static class CalcAucOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<double[], double[]>, BoundedOneInput {
+        private ListState<double[]> aucVariableState;
+        private double[] aucVariable;
+
+        @Override
+        public void endInput() {
+            if (aucVariable != null) {
+                output.collect(new StreamRecord<>(aucVariable));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<double[]> streamRecord) {
+            if (aucVariable == null) {
+                // Accumulate into a private copy: the incoming array belongs to the record and
+                // must not be retained or mutated by this operator (matters with object reuse).
+                aucVariable = streamRecord.getValue().clone();
+            } else {
+                double[] tmpAucVar = streamRecord.getValue();
+                aucVariable[0] += tmpAucVar[0];
+                aucVariable[1] += tmpAucVar[1];
+                aucVariable[2] += tmpAucVar[2];
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            aucVariableState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "aucState", TypeInformation.of(double[].class)));
+            // Stays null when the state holds no element; set by the first processed record.
+            aucVariable =
+                    OperatorStateUtils.getUniqueElement(aucVariableState, "aucState").orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            aucVariableState.clear();
+            if (aucVariable != null) {
+                aucVariableState.add(aucVariable);
+            }
+        }
+    }
+
+    /** Collapses all locally computed {@link BinaryMetrics} into one name-to-value map. */
+    private static class MergeMetrics
+            implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> {
+        private static final long serialVersionUID = 463407033215369847L;
+
+        @Override
+        public void mapPartition(
+                Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) {
+            Iterator<BinaryMetrics> remaining = values.iterator();
+            // The first element seeds the accumulator; every further one is merged into it.
+            BinaryMetrics merged = remaining.next();
+            while (remaining.hasNext()) {
+                BinaryMetrics partial = remaining.next();
+                merged = merged.merge(partial);
+            }
+
+            Map<String, Double> metricsByName = new HashMap<>();
+            metricsByName.put(AREA_UNDER_ROC, merged.areaUnderROC);
+            metricsByName.put(AREA_UNDER_PR, merged.areaUnderPR);
+            metricsByName.put(AREA_UNDER_LORENZ, merged.areaUnderLorenz);
+            metricsByName.put(KS, merged.ks);
+            out.collect(metricsByName);
+        }
+    }
+
+    /**
+     * Computes the {@link BinaryMetrics} of one partition from its records, the broadcast
+     * partition summaries and the broadcast areaUnderROC value.
+     */
+    private static class CalcBinaryMetrics
+            extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, BinaryMetrics> {
+        private static final long serialVersionUID = 5680342197308160013L;
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple3<Double, Boolean, Double>> iterable,
+                Collector<BinaryMetrics> collector) {
+
+            List<BinarySummary> summaries =
+                    getRuntimeContext().getBroadcastVariable(PARTITION_SUMMARY);
+            int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+            Tuple2<Boolean, long[]> reduced = reduceBinarySummary(summaries, subtaskIndex);
+            long[] counts = reduced.f1;
+
+            List<Double> broadcastAuc =
+                    getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC);
+            double auc = broadcastAuc.get(0);
+
+            // Elements 2 and 3 are read as the overall positive and negative counts.
+            boolean noPositive = counts[2] == 0;
+            boolean noNegative = counts[3] == 0;
+            if (noPositive) {
+                System.out.println("There is no positive sample in data!");
+            }
+            if (noNegative) {
+                System.out.println("There is no negative sample in data!");
+            }
+
+            BinaryMetrics partitionMetrics = new BinaryMetrics(0L, auc);
+            double[] previousValues = new double[4];
+            Iterator<Tuple3<Double, Boolean, Double>> samples = iterable.iterator();
+            while (samples.hasNext()) {
+                updateBinaryMetrics(samples.next(), partitionMetrics, counts, previousValues);
+            }
+            collector.collect(partitionMetrics);
+        }
+    }
+
+    private static void updateBinaryMetrics(
+            Tuple3<Double, Boolean, Double> cur,
+            BinaryMetrics binaryMetrics,
+            long[] countValues,
+            double[] recordValues) {
+        if (binaryMetrics.count == 0) {
+            recordValues[0] = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
+            recordValues[1] = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+            recordValues[2] =
+                    countValues[0] + countValues[1] == 0
+                            ? 1.0
+                            : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
+            recordValues[3] =
+                    1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+        }
+
+        binaryMetrics.count++;
+        if (cur.f1) {
+            countValues[0]++;
+        } else {
+            countValues[1]++;
+        }
+
+        double tpr = countValues[2] == 0 ? 1.0 : 1.0 * countValues[0] / countValues[2];
+        double fpr = countValues[3] == 0 ? 1.0 : 1.0 * countValues[1] / countValues[3];
+        double precision =
+                countValues[0] + countValues[1] == 0
+                        ? 1.0
+                        : 1.0 * countValues[0] / (countValues[0] + countValues[1]);
+        double positiveRate =
+                1.0 * (countValues[0] + countValues[1]) / (countValues[2] + countValues[3]);
+
+        binaryMetrics.areaUnderLorenz +=
+                ((positiveRate - recordValues[3]) * (tpr + recordValues[0]) / 2);
+        binaryMetrics.areaUnderPR += ((tpr - recordValues[0]) * (precision + recordValues[2]) / 2);
+        binaryMetrics.ks = Math.max(Math.abs(fpr - tpr), binaryMetrics.ks);
+
+        recordValues[0] = tpr;
+        recordValues[1] = fpr;
+        recordValues[2] = precision;
+        recordValues[3] = positiveRate;

Review Comment:
   Shall we define a class to represent the metric results? It would bring better readability than remembering the meaning of each metric by their index in a double array.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(

Review Comment:
   It seems that we do not need `new ArrayList<>()`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() { // no-arg constructor; its only work is on paramMap
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this); // presumably fills defaults
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<

Review Comment:
   Implementations like this bring the risk of OOM, as the data allocated to a single subtask can still be larger than the available memory. It might be unnecessary to use `apply()` here; we can consider using `reduce()` or `aggregate()` instead.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contains
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })

Review Comment:
   Shall we move anonymous functions like this into private static classes and give them meaningful class names and JavaDocs? That would help improve readability.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                            Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                            Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                            Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                            Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1))));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1))));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("label", "raw");
+        trainDataTableWithMultiScore =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_MULTI_SCORE))
+                        .as("label", "raw");
+        trainDataTableWithWeight =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA_WITH_WEIGHT))
+                        .as("label", "raw", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {

Review Comment:
   The order of the test cases in this class and their naming convention seem to differ from those used in the tests of other existing operators. Shall we follow the test style used in those classes?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contains
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)
+                        .map(
+                                new MapFunction<double[], Double>() {
+                                    @Override
+                                    public Double map(double[] aucVariable) {
+                                        if (aucVariable[1] > 0 && aucVariable[2] > 0) {
+                                            return (aucVariable[0]
+                                                            - 1.
+                                                                    * aucVariable[1]
+                                                                    * (aucVariable[1] + 1)
+                                                                    / 2)
+                                                    / (aucVariable[1] * aucVariable[2]);
+                                        } else {
+                                            return Double.NaN;
+                                        }
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(PARTITION_SUMMARY, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);
+        DataStream<BinaryMetrics> localMetrics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        broadcastMap,
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(input, new CalcBinaryMetrics());
+                        });
+
+        final String[] metricsNames = getMetricsNames();
+        TypeInformation<?>[] metricTypes = new TypeInformation[metricsNames.length];
+        for (int i = 0; i < metricsNames.length; ++i) {
+            metricTypes[i] = Types.DOUBLE();
+        }
+        RowTypeInfo outputTypeInfo = new RowTypeInfo(metricTypes, metricsNames);
+        DataStream<Map<String, Double>> metrics =
+                DataStreamUtils.mapPartition(localMetrics, new MergeMetrics());
+        metrics.getTransformation().setParallelism(1);
+        DataStream<Row> evalResult =
+                metrics.map(
+                        new MapFunction<Map<String, Double>, Row>() {
+                            @Override
+                            public Row map(Map<String, Double> value) {
+                                Row ret = new Row(metricsNames.length);
+                                for (int i = 0; i < metricsNames.length; ++i) {
+                                    ret.setField(i, value.get(metricsNames[i]));
+                                }
+                                return ret;
+                            }
+                        },
+                        outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(evalResult)};
+    }
+
+    private static class PartitionSummaryOperator extends AbstractStreamOperator<BinarySummary>
+            implements OneInputStreamOperator<Tuple3<Double, Boolean, Double>, BinarySummary>,
+                    BoundedOneInput {
+        private ListState<BinarySummary> summaryState;
+        private BinarySummary summary;
+
+        @Override
+        public void endInput() {
+            if (summary != null) {
+                output.collect(new StreamRecord<>(summary));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Double, Boolean, Double>> streamRecord) {
+            if (summary == null) {
+                summary =
+                        new BinarySummary(
+                                getRuntimeContext().getIndexOfThisSubtask(),
+                                -Double.MAX_VALUE,
+                                0,
+                                0);
+            }
+            updateBinarySummary(summary, streamRecord.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            summaryState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "summaryState",
+                                            TypeInformation.of(BinarySummary.class)));
+            summary =
+                    OperatorStateUtils.getUniqueElement(summaryState, "summaryState").orElse(null);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            summaryState.clear();
+            if (summary != null) {
+                summaryState.add(summary);
+            }
+        }
+    }
+
+    /**
+     * Stream operator that adds up the {@code double[]} records it receives, slot by slot for the
+     * indices 0, 1 and 2, and emits the totals once the bounded input has ended.
+     *
+     * <p>The array of the first record becomes the accumulator itself (no copy is made). When no
+     * record arrives and no state is restored, nothing is emitted. The accumulator is written to
+     * operator list state on every snapshot and read back when the state is initialized.
+     */
+    private static class CalcAucOperator extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<double[], double[]>, BoundedOneInput {
+        private ListState<double[]> aucVariableState;
+        private double[] aucVariable;
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            ListStateDescriptor<double[]> descriptor =
+                    new ListStateDescriptor<>("aucState", TypeInformation.of(double[].class));
+            aucVariableState = context.getOperatorStateStore().getListState(descriptor);
+            // Pick up the snapshotted totals if there are any; otherwise start empty.
+            aucVariable =
+                    OperatorStateUtils.getUniqueElement(aucVariableState, "aucState").orElse(null);
+        }
+
+        @Override
+        public void processElement(StreamRecord<double[]> streamRecord) {
+            double[] incoming = streamRecord.getValue();
+            if (aucVariable == null) {
+                aucVariable = incoming;
+                return;
+            }
+            for (int slot = 0; slot < 3; slot++) {
+                aucVariable[slot] += incoming[slot];
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            // Replace whatever was stored before with the current totals, if any.
+            aucVariableState.clear();
+            if (aucVariable == null) {
+                return;
+            }
+            aucVariableState.add(aucVariable);
+        }
+
+        @Override
+        public void endInput() {
+            if (aucVariable == null) {
+                return;
+            }
+            output.collect(new StreamRecord<>(aucVariable));
+        }
+    }
+
+    /**
+     * Merges the metrics calculated locally and outputs the merged metrics data.
+     *
+     * <p>All {@link BinaryMetrics} of the partition are folded into one with {@code merge}, and
+     * the result is emitted as a single map from metric name to metric value. When the partition
+     * holds no metrics at all, nothing is emitted.
+     */
+    private static class MergeMetrics
+            implements MapPartitionFunction<BinaryMetrics, Map<String, Double>> {
+        private static final long serialVersionUID = 463407033215369847L;
+
+        /**
+         * Folds all local metrics into one and emits it as a name-to-value map.
+         *
+         * @param values the locally calculated metrics; may be empty
+         * @param out collector that receives at most one merged metrics map
+         */
+        @Override
+        public void mapPartition(
+                Iterable<BinaryMetrics> values, Collector<Map<String, Double>> out) {
+            Iterator<BinaryMetrics> iter = values.iterator();
+            if (!iter.hasNext()) {
+                // Nothing to merge; calling next() here would throw NoSuchElementException.
+                return;
+            }
+            BinaryMetrics reduceMetrics = iter.next();
+            while (iter.hasNext()) {
+                reduceMetrics = reduceMetrics.merge(iter.next());
+            }
+            Map<String, Double> map = new HashMap<>();
+            map.put(AREA_UNDER_ROC, reduceMetrics.areaUnderROC);
+            map.put(AREA_UNDER_PR, reduceMetrics.areaUnderPR);
+            map.put(AREA_UNDER_LORENZ, reduceMetrics.areaUnderLorenz);
+            map.put(KS, reduceMetrics.ks);
+            out.collect(map);
+        }
+    }
+
+    private static class CalcBinaryMetrics
+            extends RichMapPartitionFunction<Tuple3<Double, Boolean, Double>, BinaryMetrics> {
+        private static final long serialVersionUID = 5680342197308160013L;
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple3<Double, Boolean, Double>> iterable,
+                Collector<BinaryMetrics> collector) {
+
+            List<BinarySummary> statistics =
+                    getRuntimeContext().getBroadcastVariable(PARTITION_SUMMARY);
+            Tuple2<Boolean, long[]> t =
+                    reduceBinarySummary(statistics, getRuntimeContext().getIndexOfThisSubtask());
+            long[] countValues = t.f1;
+
+            double areaUnderROC =
+                    getRuntimeContext().<Double>getBroadcastVariable(AREA_UNDER_ROC).get(0);
+            long totalTrue = countValues[2];
+            long totalFalse = countValues[3];
+            if (totalTrue == 0) {
+                System.out.println("There is no positive sample in data!");

Review Comment:
   It would be better to avoid `println` here; we can use `LOG` to record this kind of information instead.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorTest.java:
##########
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests {@link BinaryClassificationEvaluator}. */
+public class BinaryClassificationEvaluatorTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table trainDataTableWithMultiScore;
+    private Table trainDataTableWithWeight;
+
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.2, 0.8)),
+                            Row.of(1.0, Vectors.dense(0.3, 0.7)),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                            Row.of(1.0, Vectors.dense(0.35, 0.65)),
+                            Row.of(1.0, Vectors.dense(0.45, 0.55)),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                            Row.of(1.0, Vectors.dense(0.65, 0.35)),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1))));
+
+    private static final List<Row> TRAIN_DATA_WITH_MULTI_SCORE =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75)),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4)),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3)),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9)),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2)),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1))));
+
+    private static final List<Row> TRAIN_DATA_WITH_WEIGHT =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.8),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.7),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 0.5),
+                            Row.of(0.0, Vectors.dense(0.25, 0.75), 1.2),
+                            Row.of(0.0, Vectors.dense(0.4, 0.6), 1.3),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.5),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.4),
+                            Row.of(0.0, Vectors.dense(0.6, 0.4), 0.3),
+                            Row.of(0.0, Vectors.dense(0.7, 0.3), 0.5),
+                            Row.of(1.0, Vectors.dense(0.1, 0.9), 1.9),
+                            Row.of(0.0, Vectors.dense(0.8, 0.2), 1.2),
+                            Row.of(1.0, Vectors.dense(0.9, 0.1), 1.0)));
+
+    private static final double[] EXPECTED_DATA =
+            new double[] {0.7691481137909708, 0.3714285714285714, 0.6571428571428571};
+    private static final double[] EXPECTED_DATA_M =
+            new double[] {0.9377705627705628, 0.8571428571428571};
+    private static final double EXPECTED_DATA_W = 0.8911680911680911;
+
+    /**
+     * Creates the table environment (parallelism 4, checkpoint interval 100, no restarts) and the
+     * three input tables used by the tests.
+     */
+    @Before
+    public void before() {
+        Configuration conf = new Configuration();
+        conf.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(conf);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        DataStream<Row> plainRows = env.fromCollection(TRAIN_DATA);
+        trainDataTable = tEnv.fromDataStream(plainRows).as("label", "raw");
+
+        DataStream<Row> multiScoreRows = env.fromCollection(TRAIN_DATA_WITH_MULTI_SCORE);
+        trainDataTableWithMultiScore = tEnv.fromDataStream(multiScoreRows).as("label", "raw");
+
+        DataStream<Row> weightedRows = env.fromCollection(TRAIN_DATA_WITH_WEIGHT);
+        trainDataTableWithWeight =
+                tEnv.fromDataStream(weightedRows).as("label", "raw", "weight");
+    }
+
+    @Test
+    public void testParam() {
+        BinaryClassificationEvaluator binaryEval = new BinaryClassificationEvaluator();
+        assertEquals("label", binaryEval.getLabelCol());
+        assertNull(binaryEval.getWeightCol());
+        assertEquals("rawPrediction", binaryEval.getRawPredictionCol());
+        assertArrayEquals(
+                new String[] {"areaUnderROC", "areaUnderPR"}, binaryEval.getMetricsNames());
+        binaryEval
+                .setLabelCol("labelCol")
+                .setRawPredictionCol("raw")
+                .setMetricsNames("areaUnderROC")
+                .setWeightCol("weight");
+        assertEquals("labelCol", binaryEval.getLabelCol());
+        assertEquals("weight", binaryEval.getWeightCol());
+        assertEquals("raw", binaryEval.getRawPredictionCol());
+        assertArrayEquals(new String[] {"areaUnderROC"}, binaryEval.getMetricsNames());
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")
+                        .setRawPredictionCol("raw");
+        BinaryClassificationEvaluator loadedEval =
+                StageTestUtils.saveAndReload(tEnv, eval, tempFolder.newFolder().getAbsolutePath());
+        Table evalResult = loadedEval.transform(trainDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), 1.0e-5);
+        }
+    }
+
+    @Test
+    public void testTransform() throws Exception {
+        BinaryClassificationEvaluator eval =
+                new BinaryClassificationEvaluator()
+                        .setMetricsNames("areaUnderPR", "KS", "areaUnderROC")
+                        .setLabelCol("label")
+                        .setRawPredictionCol("raw");
+        Table evalResult = eval.transform(trainDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(evalResult);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        Row result = results.get(0);
+        for (int i = 0; i < EXPECTED_DATA.length; ++i) {
+            assertEquals(EXPECTED_DATA[i], result.getFieldAs(i), 1.0e-5);

Review Comment:
   Shall we also check the output column names, instead of just the value at each index?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input data to evaluate has the
+ * columns rawPrediction and label, plus an optional weight column. The output metrics may contain
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz'; which of them are produced is chosen
+ * by the MetricsNames parameter. Here, we use a parallel method to sort the whole evaluated data
+ * and calculate the accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    /**
+     * Creates an evaluator and hands its {@code paramMap} to {@code
+     * ParamUtils.initializeMapWithDefaultValues}, which presumably fills in the default value of
+     * every parameter declared for this stage (helper not visible here; confirm against ParamUtils).
+     */
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(new AppendTaskId());
+                        });
+
+        /* Repartitions the evaluated data by range. */
+        evalDataWithTaskId =
+                evalDataWithTaskId.partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f3);
+
+        /* Sorts local data.*/
+        evalData =
+                DataStreamUtils.mapPartition(
+                        evalDataWithTaskId,
+                        new MapPartitionFunction<
+                                Tuple4<Double, Boolean, Double, Integer>,
+                                Tuple3<Double, Boolean, Double>>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple4<Double, Boolean, Double, Integer>> values,
+                                    Collector<Tuple3<Double, Boolean, Double>> out) {
+                                List<Tuple3<Double, Boolean, Double>> bufferedData =
+                                        new LinkedList<>();
+                                for (Tuple4<Double, Boolean, Double, Integer> t4 : values) {
+                                    bufferedData.add(Tuple3.of(t4.f0, t4.f1, t4.f2));
+                                }
+                                bufferedData.sort(Comparator.comparingDouble(o -> -o.f0));
+                                for (Tuple3<Double, Boolean, Double> dataPoint : bufferedData) {
+                                    out.collect(dataPoint);
+                                }
+                            }
+                        });
+
+        /* Calculates the summary of local data. */
+        DataStream<BinarySummary> partitionSummaries =
+                evalData.transform(
+                        "reduceInEachPartition",
+                        TypeInformation.of(BinarySummary.class),
+                        new PartitionSummaryOperator());
+
+        /* Sorts global data. Output Tuple4 : <score, order, isPositive, weight> */
+        DataStream<Tuple4<Double, Long, Boolean, Double>> dataWithOrders =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(PARTITION_SUMMARY, partitionSummaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.flatMap(new CalcSampleOrders());
+                        });
+
+        dataWithOrders =
+                dataWithOrders.transform(
+                        "appendMaxWaterMark",
+                        dataWithOrders.getType(),
+                        new AppendMaxWatermark(x -> x));
+
+        DataStream<double[]> localAucVariable =
+                dataWithOrders
+                        .keyBy(
+                                (KeySelector<Tuple4<Double, Long, Boolean, Double>, Double>)
+                                        value -> value.f0)
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (WindowFunction<
+                                                Tuple4<Double, Long, Boolean, Double>,
+                                                double[],
+                                                Double,
+                                                TimeWindow>)
+                                        (key, window, values, out) -> {
+                                            long sum = 0;
+                                            long cnt = 0;
+                                            double positiveSum = 0;
+                                            double negativeSum = 0;
+
+                                            for (Tuple4<Double, Long, Boolean, Double> t : values) {
+                                                sum += t.f1;
+                                                cnt++;
+                                                if (t.f2) {
+                                                    positiveSum += t.f3;
+                                                } else {
+                                                    negativeSum += t.f3;
+                                                }
+                                            }
+                                            out.collect(
+                                                    new double[] {
+                                                        1. * sum / cnt * positiveSum,
+                                                        positiveSum,
+                                                        negativeSum
+                                                    });
+                                        })
+                        .returns(double[].class);
+
+        DataStream<Double> areaUnderROC =
+                localAucVariable
+                        .transform(
+                                "reduceInEachPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                TypeInformation.of(double[].class),
+                                new CalcAucOperator())
+                        .setParallelism(1)
+                        .map(
+                                new MapFunction<double[], Double>() {
+                                    @Override
+                                    public Double map(double[] aucVariable) {
+                                        if (aucVariable[1] > 0 && aucVariable[2] > 0) {
+                                            return (aucVariable[0]
+                                                            - 1.
+                                                                    * aucVariable[1]
+                                                                    * (aucVariable[1] + 1)
+                                                                    / 2)
+                                                    / (aucVariable[1] * aucVariable[2]);
+                                        } else {
+                                            return Double.NaN;
+                                        }
+                                    }
+                                });
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(PARTITION_SUMMARY, partitionSummaries);
+        broadcastMap.put(AREA_UNDER_ROC, areaUnderROC);

Review Comment:
   I noticed that `areaUnderROC` is calculated separately from the other metrics. Given that all metrics are computed in a similar way, e.g. through a `mapPartition` or `window().apply()`, I have the sense that a single window operator or map-partition operator could produce all of the metrics. That would greatly simplify the structure of this code. What do you think about trying it this way?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org