You are viewing a plain-text version of this content; the hyperlink to the canonical version was lost in this rendering.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/12/21 10:39:00 UTC

[flink-ml] branch master updated: [hotfix][iteration] Updates onEpochWatermarkIncremented() and onIterationTerminated() to throw Exception

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 4935d03  [hotfix][iteration] Updates onEpochWatermarkIncremented() and onIterationTerminated() to throw Exception
4935d03 is described below

commit 4935d03898c8c89bdda61b9cfbe10936b687785d
Author: Dong Lin <li...@gmail.com>
AuthorDate: Tue Dec 21 14:47:26 2021 +0800

    [hotfix][iteration] Updates onEpochWatermarkIncremented() and onIterationTerminated() to throw Exception
    
    This closes #41.
---
 .../apache/flink/iteration/IterationListener.java  |  5 ++-
 .../operator/AbstractWrapperOperator.java          |  4 +-
 .../logisticregression/LogisticRegression.java     | 12 ++---
 .../apache/flink/ml/clustering/kmeans/KMeans.java  | 52 ++++++++++------------
 4 files changed, 33 insertions(+), 40 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java
index 9c451f2..73b323f 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java
@@ -46,7 +46,8 @@ public interface IterationListener<T> {
      *     the invocation of this method.
      * @param collector The collector for returning result values.
      */
-    void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector);
+    void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector)
+            throws Exception;
 
     /**
      * This callback is invoked after the execution of the iteration body has terminated.
@@ -56,7 +57,7 @@ public interface IterationListener<T> {
      *     the invocation of this method.
      * @param collector The collector for returning result values.
      */
-    void onIterationTerminated(Context context, Collector<T> collector);
+    void onIterationTerminated(Context context, Collector<T> collector) throws Exception;
 
     /**
      * Information available in an invocation of the callbacks defined in the
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
index 86dbc2f..d790729 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
@@ -123,8 +123,8 @@ public abstract class AbstractWrapperOperator<T>
     }
 
     @SuppressWarnings({"unchecked", "rawtypes"})
-    protected void notifyEpochWatermarkIncrement(
-            IterationListener<?> listener, int epochWatermark) {
+    protected void notifyEpochWatermarkIncrement(IterationListener<?> listener, int epochWatermark)
+            throws Exception {
         if (epochWatermark != Integer.MAX_VALUE) {
             listener.onEpochWatermarkIncremented(
                     epochWatermark,
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 9df0cdf..602b9c6 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -361,8 +361,8 @@ public class LogisticRegression
 
         @Override
         public void onEpochWatermarkIncremented(
-                int epochWatermark, Context context, Collector<double[]> collector) {
-            // TODO: let this method throws exception.
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
             if (epochWatermark == 0) {
                 coefficient = new DenseVector(feedbackBuffer);
                 coefficientDim = coefficient.size();
@@ -372,12 +372,8 @@ public class LogisticRegression
                 updateModel();
             }
             Arrays.fill(gradient.values, 0);
-            try {
-                if (trainData == null) {
-                    trainData = IteratorUtils.toList(trainDataState.get().iterator());
-                }
-            } catch (Exception e) {
-                throw new RuntimeException(e);
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
             }
             miniBatchData = getMiniBatchData(trainData, localBatchSize);
             Tuple2<Double, Double> weightSumAndLossSum =
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index 1c1a47e..f9b704b 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -276,37 +276,33 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
 
         @Override
         public void onEpochWatermarkIncremented(
-                int epochWatermark, Context context, Collector<Tuple2<Integer, DenseVector>> out) {
-            // TODO: update onEpochWatermarkIncremented to throw Exception.
-            try {
-                List<DenseVector[]> list = IteratorUtils.toList(centroids.get().iterator());
-                if (list.size() != 1) {
-                    throw new RuntimeException(
-                            "The operator received "
-                                    + list.size()
-                                    + " list of centroids in this round");
-                }
-                DenseVector[] centroidValues = list.get(0);
-
-                for (DenseVector point : points.get()) {
-                    double minDistance = Double.MAX_VALUE;
-                    int closestCentroidId = -1;
-
-                    for (int i = 0; i < centroidValues.length; i++) {
-                        DenseVector centroid = centroidValues[i];
-                        double distance = distanceMeasure.distance(centroid, point);
-                        if (distance < minDistance) {
-                            minDistance = distance;
-                            closestCentroidId = i;
-                        }
+                int epochWatermark, Context context, Collector<Tuple2<Integer, DenseVector>> out)
+                throws Exception {
+            List<DenseVector[]> list = IteratorUtils.toList(centroids.get().iterator());
+            if (list.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + list.size()
+                                + " list of centroids in this round");
+            }
+            DenseVector[] centroidValues = list.get(0);
+
+            for (DenseVector point : points.get()) {
+                double minDistance = Double.MAX_VALUE;
+                int closestCentroidId = -1;
+
+                for (int i = 0; i < centroidValues.length; i++) {
+                    DenseVector centroid = centroidValues[i];
+                    double distance = distanceMeasure.distance(centroid, point);
+                    if (distance < minDistance) {
+                        minDistance = distance;
+                        closestCentroidId = i;
                     }
-
-                    output.collect(new StreamRecord<>(Tuple2.of(closestCentroidId, point)));
                 }
-                centroids.clear();
-            } catch (Exception e) {
-                throw new RuntimeException(e);
+
+                output.collect(new StreamRecord<>(Tuple2.of(closestCentroidId, point)));
             }
+            centroids.clear();
         }
 
         @Override