You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text rendering.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/12/21 10:39:00 UTC
[flink-ml] branch master updated: [hotfix][iteration] Updates onEpochWatermarkIncremented() and onIterationTerminated() to throw Exception
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 4935d03 [hotfix][iteration] Updates onEpochWatermarkIncremented() and onIterationTerminated() to throw Exception
4935d03 is described below
commit 4935d03898c8c89bdda61b9cfbe10936b687785d
Author: Dong Lin <li...@gmail.com>
AuthorDate: Tue Dec 21 14:47:26 2021 +0800
[hotfix][iteration] Updates onEpochWatermarkIncremented() and onIterationTerminated() to throw Exception
This closes #41.
---
.../apache/flink/iteration/IterationListener.java | 5 ++-
.../operator/AbstractWrapperOperator.java | 4 +-
.../logisticregression/LogisticRegression.java | 12 ++---
.../apache/flink/ml/clustering/kmeans/KMeans.java | 52 ++++++++++------------
4 files changed, 33 insertions(+), 40 deletions(-)
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java
index 9c451f2..73b323f 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java
@@ -46,7 +46,8 @@ public interface IterationListener<T> {
* the invocation of this method.
* @param collector The collector for returning result values.
*/
- void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector);
+ void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector)
+ throws Exception;
/**
* This callback is invoked after the execution of the iteration body has terminated.
@@ -56,7 +57,7 @@ public interface IterationListener<T> {
* the invocation of this method.
* @param collector The collector for returning result values.
*/
- void onIterationTerminated(Context context, Collector<T> collector);
+ void onIterationTerminated(Context context, Collector<T> collector) throws Exception;
/**
* Information available in an invocation of the callbacks defined in the
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
index 86dbc2f..d790729 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
@@ -123,8 +123,8 @@ public abstract class AbstractWrapperOperator<T>
}
@SuppressWarnings({"unchecked", "rawtypes"})
- protected void notifyEpochWatermarkIncrement(
- IterationListener<?> listener, int epochWatermark) {
+ protected void notifyEpochWatermarkIncrement(IterationListener<?> listener, int epochWatermark)
+ throws Exception {
if (epochWatermark != Integer.MAX_VALUE) {
listener.onEpochWatermarkIncremented(
epochWatermark,
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 9df0cdf..602b9c6 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -361,8 +361,8 @@ public class LogisticRegression
@Override
public void onEpochWatermarkIncremented(
- int epochWatermark, Context context, Collector<double[]> collector) {
- // TODO: let this method throws exception.
+ int epochWatermark, Context context, Collector<double[]> collector)
+ throws Exception {
if (epochWatermark == 0) {
coefficient = new DenseVector(feedbackBuffer);
coefficientDim = coefficient.size();
@@ -372,12 +372,8 @@ public class LogisticRegression
updateModel();
}
Arrays.fill(gradient.values, 0);
- try {
- if (trainData == null) {
- trainData = IteratorUtils.toList(trainDataState.get().iterator());
- }
- } catch (Exception e) {
- throw new RuntimeException(e);
+ if (trainData == null) {
+ trainData = IteratorUtils.toList(trainDataState.get().iterator());
}
miniBatchData = getMiniBatchData(trainData, localBatchSize);
Tuple2<Double, Double> weightSumAndLossSum =
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index 1c1a47e..f9b704b 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -276,37 +276,33 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
@Override
public void onEpochWatermarkIncremented(
- int epochWatermark, Context context, Collector<Tuple2<Integer, DenseVector>> out) {
- // TODO: update onEpochWatermarkIncremented to throw Exception.
- try {
- List<DenseVector[]> list = IteratorUtils.toList(centroids.get().iterator());
- if (list.size() != 1) {
- throw new RuntimeException(
- "The operator received "
- + list.size()
- + " list of centroids in this round");
- }
- DenseVector[] centroidValues = list.get(0);
-
- for (DenseVector point : points.get()) {
- double minDistance = Double.MAX_VALUE;
- int closestCentroidId = -1;
-
- for (int i = 0; i < centroidValues.length; i++) {
- DenseVector centroid = centroidValues[i];
- double distance = distanceMeasure.distance(centroid, point);
- if (distance < minDistance) {
- minDistance = distance;
- closestCentroidId = i;
- }
+ int epochWatermark, Context context, Collector<Tuple2<Integer, DenseVector>> out)
+ throws Exception {
+ List<DenseVector[]> list = IteratorUtils.toList(centroids.get().iterator());
+ if (list.size() != 1) {
+ throw new RuntimeException(
+ "The operator received "
+ + list.size()
+ + " list of centroids in this round");
+ }
+ DenseVector[] centroidValues = list.get(0);
+
+ for (DenseVector point : points.get()) {
+ double minDistance = Double.MAX_VALUE;
+ int closestCentroidId = -1;
+
+ for (int i = 0; i < centroidValues.length; i++) {
+ DenseVector centroid = centroidValues[i];
+ double distance = distanceMeasure.distance(centroid, point);
+ if (distance < minDistance) {
+ minDistance = distance;
+ closestCentroidId = i;
}
-
- output.collect(new StreamRecord<>(Tuple2.of(closestCentroidId, point)));
}
- centroids.clear();
- } catch (Exception e) {
- throw new RuntimeException(e);
+
+ output.collect(new StreamRecord<>(Tuple2.of(closestCentroidId, point)));
}
+ centroids.clear();
}
@Override