Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/04/23 09:18:33 UTC

[GitHub] [flink-ml] lindong28 commented on a diff in pull request #86: [FLINK-27294] Add Transformer for BinaryClassificationEvaluator

lindong28 commented on code in PR #86:
URL: https://github.com/apache/flink-ml/pull/86#discussion_r856870066


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+
+/**
+ * Params of BinaryClassificationEvaluator.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    /**
+     * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and

Review Comment:
   Would it be useful to explain the definition of each metric here?
   
   Note that Spark ML provides the following Java doc in `BinaryClassificationMetrics.scala` [1]:
   
   areaUnderROC: the area under the receiver operating characteristic (ROC) curve.
   
   areaUnderPR: the area under the precision-recall curve.
   
   [1] https://spark.apache.org/docs/2.0.2/api/java/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.html
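   To make the definitions concrete, here is a minimal single-machine sketch in plain Java (no Flink dependencies) that computes areaUnderROC and KS from unweighted (score, label) pairs. The class and method names are hypothetical. It assumes both labels occur in the data, treats tied scores as one threshold, and integrates the ROC curve with the trapezoid rule. KS is taken as the maximum of |TPR - FPR| over all thresholds, which is the usual Kolmogorov-Smirnov statistic for binary classifiers. areaUnderPR and areaUnderLorenz are left out because their exact definitions in this PR are not shown here.

   ```java
   import java.util.Arrays;
   import java.util.Comparator;

   /** Hypothetical helper that illustrates areaUnderROC and KS; not part of Flink ML. */
   class BinaryMetricsSketch {
       /** Returns {areaUnderROC, KS} for unweighted (score, label) pairs. Assumes both labels occur. */
       static double[] rocAucAndKs(double[] scores, boolean[] labels) {
           Integer[] order = new Integer[scores.length];
           for (int k = 0; k < order.length; k++) {
               order[k] = k;
           }
           // Sort indices by descending score: each prefix corresponds to one decision threshold.
           Arrays.sort(order, Comparator.comparingDouble((Integer k) -> scores[k]).reversed());

           long totalPos = 0;
           long totalNeg = 0;
           for (boolean label : labels) {
               if (label) {
                   totalPos++;
               } else {
                   totalNeg++;
               }
           }

           double auc = 0.0;
           double ks = 0.0;
           double prevTpr = 0.0;
           double prevFpr = 0.0;
           long tp = 0;
           long fp = 0;
           int i = 0;
           while (i < order.length) {
               // Consume all samples that share the current score so ties form a single ROC point.
               double threshold = scores[order[i]];
               while (i < order.length && scores[order[i]] == threshold) {
                   if (labels[order[i]]) {
                       tp++;
                   } else {
                       fp++;
                   }
                   i++;
               }
               double tpr = (double) tp / totalPos;
               double fpr = (double) fp / totalNeg;
               // Trapezoid between the previous and the current ROC point.
               auc += (fpr - prevFpr) * (tpr + prevTpr) / 2.0;
               ks = Math.max(ks, Math.abs(tpr - fpr));
               prevTpr = tpr;
               prevFpr = fpr;
           }
           return new double[] {auc, ks};
       }
   }
   ```

   For example, with scores {0.9, 0.8, 0.7, 0.6} and labels {true, false, true, false}, 3 of the 4 (positive, negative) pairs are ranked correctly. This gives areaUnderROC = 0.75 and KS = 0.5.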
   
   
   



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.evaluation.binaryeval;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.scala.typeutils.Types;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamMap;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns

Review Comment:
   Since the input data has not been evaluated yet, maybe change `input evaluated data` to `input data`.
   
   `may contains` -> `may contain`
   
   Instead of explicitly listing the metrics here, would it be simpler to just say `Please refer to the Java doc of the metricsNames parameter for the list of supported metrics`?
   
   Do we need to explain the supported types of the rawPrediction column here, similar to Spark ML's Java doc?
   
   Java doc typically focuses on the behavior of the algorithm that users need to know. Would it be simpler to remove implementation details such as `we use a parallel method...`?
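   If the Java doc does describe the rawPrediction column, Spark ML's convention is that it is either a double (a 0/1 prediction, or the probability of label 1) or a length-2 vector of raw predictions, scores, or label probabilities. The sketch below maps both forms to a single positive-class score. The class and method names are hypothetical, and a plain `double[]` stands in for a vector type such as Flink ML's `DenseVector`.

   ```java
   /** Hypothetical helper; a double[] stands in for a vector type such as Flink ML's DenseVector. */
   class RawPredictionSketch {
       /** Returns the score of the positive class (label 1) from a rawPrediction value. */
       static double positiveScore(Object rawPrediction) {
           if (rawPrediction instanceof Number) {
               // Scalar form: a 0/1 prediction or the probability of label 1.
               return ((Number) rawPrediction).doubleValue();
           }
           if (rawPrediction instanceof double[]) {
               double[] vector = (double[]) rawPrediction;
               if (vector.length != 2) {
                   throw new IllegalArgumentException(
                           "Expected a length-2 vector but got length " + vector.length);
               }
               // Vector form: element 1 holds the raw score or probability of label 1.
               return vector[1];
           }
           throw new IllegalArgumentException(
                   "Unsupported rawPrediction type: "
                           + (rawPrediction == null ? "null" : rawPrediction.getClass().getName()));
       }
   }
   ```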



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluatorParams.java:
##########
@@ -0,0 +1,53 @@
+public interface BinaryClassificationEvaluatorParams<T>
+        extends HasLabelCol<T>, HasRawPredictionCol<T>, HasWeightCol<T> {
+    /**
+     * param for metric names in evaluation (supports 'areaUnderROC', 'areaUnderPR', 'KS' and
+     * 'areaUnderLorenz').
+     */
+    Param<String[]> METRICS_NAMES =
+            new StringArrayParam(
+                    "metricsNames",
+                    "Names of output metrics. The array element must be 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz'",
+                    new String[] {"areaUnderROC", "areaUnderPR"},
+                    ParamValidators.nonEmptyArray());

Review Comment:
   Instead of explaining the constraint in text, how about defining and using the following validator in `ParamValidators.java`?
   
   ```java
   // Checks whether every element in the array-typed parameter value is in the array of allowed values.
   // Assumes ArrayUtils refers to org.apache.commons.lang3.ArrayUtils.
   @SafeVarargs
   public static <T> ParamValidator<T[]> isSubArray(T... allowed) {
       return new ParamValidator<T[]>() {
           @Override
           public boolean validate(T[] value) {
               // A null value is never valid.
               if (value == null) {
                   return false;
               }
               for (T element : value) {
                   if (!ArrayUtils.contains(allowed, element)) {
                       return false;
                   }
               }
               return true;
           }
       };
   }
   ```
   
   We probably need a unit test for the newly defined validator if we decide to add it.
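   Below is a self-contained sketch of the proposed validator and of the checks a unit test might cover. `ParamValidator` here is a minimal stand-in for Flink ML's interface and is assumed to expose a single `boolean validate(T value)` method. A plain equality loop replaces `ArrayUtils.contains` so the snippet compiles without extra dependencies. Note that an empty array passes vacuously. Because the `StringArrayParam` constructor in the diff takes a single validator, the non-empty check from `nonEmptyArray()` would presumably have to be folded into the new validator if it is still wanted.

   ```java
   import java.util.Objects;

   /** Minimal stand-in for Flink ML's ParamValidator interface (assumption). */
   interface ParamValidator<T> {
       boolean validate(T value);
   }

   class SubArrayValidatorSketch {
       // Checks whether every element of the array-typed value is one of the allowed values.
       @SafeVarargs
       static <T> ParamValidator<T[]> isSubArray(T... allowed) {
           return value -> {
               // A null value is never valid.
               if (value == null) {
                   return false;
               }
               for (T element : value) {
                   boolean found = false;
                   for (T candidate : allowed) {
                       if (Objects.equals(candidate, element)) {
                           found = true;
                           break;
                       }
                   }
                   if (!found) {
                       return false;
                   }
               }
               // Also reached for an empty array, which is accepted vacuously.
               return true;
           };
       }
   }
   ```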



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryeval/BinaryClassificationEvaluator.java:
##########
@@ -0,0 +1,736 @@
+
+/**
+ * Calculates the evaluation metrics for binary classification. The input evaluated data has columns
+ * rawPrediction, label and an optional weight column. The output metrics may contains
+ * 'areaUnderROC', 'areaUnderPR', 'KS' and 'areaUnderLorenz' which will be defined by parameter
+ * MetricsNames. Here, we use a parallel method to sort the whole evaluated data and calculate the
+ * accurate metrics.
+ */
+public class BinaryClassificationEvaluator
+        implements AlgoOperator<BinaryClassificationEvaluator>,
+                BinaryClassificationEvaluatorParams<BinaryClassificationEvaluator> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final int NUM_SAMPLE_FOR_RANGE_PARTITION = 100;
+    private static final String BOUNDARY_RANGE = "boundaryRange";
+    private static final String PARTITION_SUMMARY = "partitionSummaries";
+    private static final String AREA_UNDER_ROC = "areaUnderROC";
+    private static final String AREA_UNDER_PR = "areaUnderPR";
+    private static final String AREA_UNDER_LORENZ = "areaUnderLorenz";
+    private static final String KS = "KS";
+
+    public BinaryClassificationEvaluator() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple3<Double, Boolean, Double>> evalData =
+                tEnv.toDataStream(inputs[0])
+                        .map(new ParseSample(getLabelCol(), getRawPredictionCol(), getWeightCol()));
+
+        DataStream<Tuple4<Double, Boolean, Double, Integer>> evalDataWithTaskId =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(evalData),
+                        Collections.singletonMap(BOUNDARY_RANGE, getBoundaryRange(evalData)),

Review Comment:
   nits: Would it be simpler to define `boundaryRange` as a local variable in this method, similar to how we define `broadcastModelKey` in `LogisticRegressionModel::transform(...)`?
   
   Same for `partitionSummaries`.
   
   A program is typically simpler and easier to read when it minimizes class-level private (or global) variables.
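   Here is a rough illustration of the suggested style in plain Java. All names are hypothetical. The broadcast key is declared as a local variable inside `transform(...)` and passed through the constructor to the function that needs it, instead of being read from a private static constant. In the real operator, the function would presumably look the value up with `getRuntimeContext().getBroadcastVariable(key)`. In this sketch, a `Map` stands in for the broadcast variables.

   ```java
   import java.util.Collections;
   import java.util.List;
   import java.util.Map;

   class LocalBroadcastKeySketch {
       /** Hypothetical stand-in for a rich function that reads a broadcast variable by name. */
       static class ReadBoundaryRange {
           private final String broadcastKey;

           ReadBoundaryRange(String broadcastKey) {
               // The key is passed in rather than read from a class-level constant.
               this.broadcastKey = broadcastKey;
           }

           double[] read(Map<String, List<double[]>> broadcastVariables) {
               return broadcastVariables.get(broadcastKey).get(0);
           }
       }

       static double[] transform(double[] boundaryRange) {
           // Local to transform(...), mirroring how broadcastModelKey is described above.
           final String boundaryRangeKey = "boundaryRange";
           Map<String, List<double[]>> broadcastVariables =
                   Collections.singletonMap(boundaryRangeKey, Collections.singletonList(boundaryRange));
           return new ReadBoundaryRange(boundaryRangeKey).read(broadcastVariables);
       }
   }
   ```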


