You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/12/20 03:00:50 UTC

[flink-ml] branch master updated (eeccb82 -> d36fe8f)

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git.


    from eeccb82  [FLINK-24845] Add allreduce utility function in FlinkML
     add 0f20eeb  [FLINK-24556] Add Estimator and Transformer for logistic regression
     add a7a0b44  [hotfix] Reformat naive bayes, kmeans and logistic regression
     new d36fe8f  [FLINK-24556] Make model data pojo for naive bayes, kmeans and logistic regression

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 flink-ml-core/pom.xml                              |   4 +-
 .../ml/common/datastream/DataStreamUtils.java      |  66 +++
 .../ml/common/datastream/EndOfStreamWindows.java   |   5 +-
 .../datastream/MapPartitionFunctionWrapper.java    |  67 ---
 .../ml/common/feature/LabeledPointWithWeight.java  |  63 +++
 ...axIterationNum.java => TerminateOnMaxIter.java} |  13 +-
 ...rationNum.java => TerminateOnMaxIterOrTol.java} |  43 +-
 .../main/java/org/apache/flink/ml/linalg/BLAS.java |  21 +
 .../org/apache/flink/ml/linalg/DenseVector.java    |   4 +
 .../org/apache/flink/ml/util/ReadWriteUtils.java   |   6 +-
 .../ml/common/datastream/DataStreamUtilsTest.java  |  75 ++++
 .../java/org/apache/flink/ml/linalg/BLASTest.java  |  64 +++
 .../logisticregression/LogisticGradient.java       |  97 +++++
 .../logisticregression/LogisticRegression.java     | 458 +++++++++++++++++++++
 .../LogisticRegressionModel.java                   | 173 ++++++++
 .../LogisticRegressionModelData.java               | 111 +++++
 .../LogisticRegressionModelParams.java             |  20 +-
 .../LogisticRegressionParams.java}                 |  37 +-
 .../ml/classification/naivebayes/NaiveBayes.java   |  54 +--
 .../classification/naivebayes/NaiveBayesModel.java |  10 +-
 .../naivebayes/NaiveBayesModelData.java            |  85 ++--
 .../apache/flink/ml/clustering/kmeans/KMeans.java  |  19 +-
 .../flink/ml/clustering/kmeans/KMeansModel.java    | 151 +++----
 .../ml/clustering/kmeans/KMeansModelData.java      |  77 ++--
 .../ml/clustering/kmeans/KMeansModelParams.java    |   2 +-
 .../flink/ml/clustering/kmeans/KMeansParams.java   |   2 +-
 .../flink/ml/common/param/HasDistanceMeasure.java  |   2 +-
 .../{HasMaxIter.java => HasGlobalBatchSize.java}   |  22 +-
 .../flink/ml/common/param/HasHandleInvalid.java    |   2 +-
 .../apache/flink/ml/common/param/HasLabelCol.java  |   2 +-
 .../{HasInputCols.java => HasLearningRate.java}    |  24 +-
 .../apache/flink/ml/common/param/HasMaxIter.java   |   2 +-
 .../{HasHandleInvalid.java => HasMultiClass.java}  |  38 +-
 .../flink/ml/common/param/HasPredictionCol.java    |   2 +-
 .../{HasLabelCol.java => HasRawPredictionCol.java} |  18 +-
 .../common/param/{HasMaxIter.java => HasReg.java}  |  19 +-
 .../param/{HasFeaturesCol.java => HasTol.java}     |  25 +-
 .../param/{HasLabelCol.java => HasWeightCol.java}  |  19 +-
 .../ml/feature/onehotencoder/OneHotEncoder.java    |  16 +-
 .../logisticregression/LogisticRegressionTest.java | 281 +++++++++++++
 .../org/apache/flink/ml/clustering/KMeansTest.java |  49 ++-
 flink-ml-tests/pom.xml                             |   7 +
 .../BoundedAllRoundStreamIterationITCase.java      |   6 +-
 .../operators/RoundBasedTerminationCriteria.java   |  48 ---
 44 files changed, 1801 insertions(+), 508 deletions(-)
 delete mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/MapPartitionFunctionWrapper.java
 create mode 100644 flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
 copy flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/{TerminateOnMaxIterationNum.java => TerminateOnMaxIter.java} (85%)
 rename flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/{TerminateOnMaxIterationNum.java => TerminateOnMaxIterOrTol.java} (56%)
 create mode 100644 flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
 create mode 100644 flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
 create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
 create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
 create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
 create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
 copy flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/EpochAware.java => flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java (63%)
 copy flink-ml-lib/src/main/java/org/apache/flink/ml/classification/{naivebayes/NaiveBayesParams.java => logisticregression/LogisticRegressionParams.java} (52%)
 copy flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/{HasMaxIter.java => HasGlobalBatchSize.java} (66%)
 copy flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/{HasInputCols.java => HasLearningRate.java} (64%)
 copy flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/{HasHandleInvalid.java => HasMultiClass.java} (57%)
 copy flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/{HasLabelCol.java => HasRawPredictionCol.java} (69%)
 copy flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/{HasMaxIter.java => HasReg.java} (70%)
 copy flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/{HasFeaturesCol.java => HasTol.java} (66%)
 copy flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/{HasLabelCol.java => HasWeightCol.java} (69%)
 create mode 100644 flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
 delete mode 100644 flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/RoundBasedTerminationCriteria.java

[flink-ml] 01/01: [FLINK-24556] Make model data pojo for naive bayes, kmeans and logistic regression

Posted by ga...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit d36fe8feb043a010e45f75dff9d7d7f21aa37fc7
Author: zhangzp <zh...@gmail.com>
AuthorDate: Fri Dec 17 17:36:41 2021 +0800

    [FLINK-24556] Make model data pojo for naive bayes, kmeans and logistic regression
    
    This closes #28.
---
 .../ml/common/feature/LabeledPointWithWeight.java  | 32 +++++++++++++++--
 .../logisticregression/LogisticGradient.java       | 17 ++++-----
 .../logisticregression/LogisticRegression.java     |  8 ++---
 .../LogisticRegressionModelData.java               |  7 ++--
 .../ml/classification/naivebayes/NaiveBayes.java   |  4 ++-
 .../naivebayes/NaiveBayesModelData.java            | 19 +++++-----
 .../ml/clustering/kmeans/KMeansModelData.java      |  7 ++--
 .../logisticregression/LogisticRegressionTest.java |  9 +++--
 .../org/apache/flink/ml/clustering/KMeansTest.java | 41 ++++++++++------------
 9 files changed, 87 insertions(+), 57 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
index e02192f..8440bc9 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
@@ -23,15 +23,41 @@ import org.apache.flink.ml.linalg.DenseVector;
 /** Utility class to represent a data point that contains features, label and weight. */
 public class LabeledPointWithWeight {
 
-    public final DenseVector features;
+    private DenseVector features;
 
-    public final double label;
+    private double label;
 
-    public final double weight;
+    private double weight;
 
     public LabeledPointWithWeight(DenseVector features, double label, double weight) {
         this.features = features;
         this.label = label;
         this.weight = weight;
     }
+
+    public LabeledPointWithWeight() {}
+
+    public DenseVector getFeatures() {
+        return features;
+    }
+
+    public void setFeatures(DenseVector features) {
+        this.features = features;
+    }
+
+    public double getLabel() {
+        return label;
+    }
+
+    public void setLabel(double label) {
+        this.label = label;
+    }
+
+    public double getWeight() {
+        return weight;
+    }
+
+    public void setWeight(double weight) {
+        this.weight = weight;
+    }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
index c63b72e..13f753b 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
@@ -52,8 +52,8 @@ public class LogisticGradient implements Serializable {
         double weightSum = 0.0;
         double lossSum = 0.0;
         for (LabeledPointWithWeight dataPoint : dataPoints) {
-            lossSum += dataPoint.weight * computeLoss(dataPoint, coefficient);
-            weightSum += dataPoint.weight;
+            lossSum += dataPoint.getWeight() * computeLoss(dataPoint, coefficient);
+            weightSum += dataPoint.getWeight();
         }
         if (Double.compare(0, l2) != 0) {
             lossSum += l2 * Math.pow(BLAS.norm2(coefficient), 2);
@@ -81,16 +81,17 @@ public class LogisticGradient implements Serializable {
     }
 
     private double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
-        double dot = BLAS.dot(dataPoint.features, coefficient);
-        double labelScaled = 2 * dataPoint.label - 1;
+        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+        double labelScaled = 2 * dataPoint.getLabel() - 1;
         return Math.log(1 + Math.exp(-dot * labelScaled));
     }
 
     private void computeGradient(
             LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
-        double dot = BLAS.dot(dataPoint.features, coefficient);
-        double labelScaled = 2 * dataPoint.label - 1;
-        double multiplier = dataPoint.weight * (-labelScaled / (Math.exp(dot * labelScaled) + 1));
-        BLAS.axpy(multiplier, dataPoint.features, cumGradient);
+        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+        double labelScaled = 2 * dataPoint.getLabel() - 1;
+        double multiplier =
+                dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1));
+        BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient);
     }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 7266610..a17269b 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -114,7 +114,7 @@ public class LogisticRegression
                                 dataPoint -> {
                                     Double weight =
                                             getWeightCol() == null
-                                                    ? new Double(1.0)
+                                                    ? 1.0
                                                     : (Double) dataPoint.getField(getWeightCol());
                                     Double label = (Double) dataPoint.getField(getLabelCol());
                                     boolean isBinomial =
@@ -160,9 +160,9 @@ public class LogisticRegression
         @Override
         public void processElement(StreamRecord<LabeledPointWithWeight> streamRecord) {
             if (dim == 0) {
-                dim = streamRecord.getValue().features.size();
+                dim = streamRecord.getValue().getFeatures().size();
             } else {
-                if (dim != streamRecord.getValue().features.size()) {
+                if (dim != streamRecord.getValue().getFeatures().size()) {
                     throw new RuntimeException(
                             "The training data should all have same dimensions.");
                 }
@@ -390,7 +390,7 @@ public class LogisticRegression
         }
 
         @Override
-        public void onIterationTerminated(Context context, Collector collector) {
+        public void onIterationTerminated(Context context, Collector<double[]> collector) {
             trainDataState.clear();
             coefficientState.clear();
             feedbackBufferState.clear();
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
index aae66fb..774c19e 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
@@ -44,12 +44,14 @@ import java.io.OutputStream;
  */
 public class LogisticRegressionModelData {
 
-    public final DenseVector coefficient;
+    public DenseVector coefficient;
 
     public LogisticRegressionModelData(DenseVector coefficient) {
         this.coefficient = coefficient;
     }
 
+    public LogisticRegressionModelData() {}
+
     /**
      * Converts the table model to a data stream.
      *
@@ -59,7 +61,8 @@ public class LogisticRegressionModelData {
     public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
-        return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+        return tEnv.toDataStream(modelData)
+                .map(x -> new LogisticRegressionModelData((DenseVector) x.getField(0)));
     }
 
     /** Data encoder for {@link LogisticRegressionModel}. */
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
index aefb44d..7a3cc3d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
@@ -27,6 +27,7 @@ import org.apache.flink.api.java.tuple.Tuple4;
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -339,7 +340,8 @@ public class NaiveBayes
                 piArray[i] = Math.log(weightSum + smoothing) - piLog;
             }
 
-            NaiveBayesModelData modelData = new NaiveBayesModelData(theta, piArray, labels);
+            NaiveBayesModelData modelData =
+                    new NaiveBayesModelData(theta, Vectors.dense(piArray), Vectors.dense(labels));
             collector.collect(modelData);
         }
     }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
index fee3b35..a03141d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
@@ -29,7 +29,6 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
@@ -54,17 +53,13 @@ public class NaiveBayesModelData {
      * Log of class conditional probabilities, whose dimension is C (number of classes) by D (number
      * of features).
      */
-    public final Map<Double, Double>[][] theta;
+    public Map<Double, Double>[][] theta;
 
     /** Log of class priors, whose dimension is C (number of classes). */
-    public final DenseVector piArray;
+    public DenseVector piArray;
 
     /** Value of labels. */
-    public final DenseVector labels;
-
-    public NaiveBayesModelData(Map<Double, Double>[][] theta, double[] piArray, double[] labels) {
-        this(theta, Vectors.dense(piArray), Vectors.dense(labels));
-    }
+    public DenseVector labels;
 
     public NaiveBayesModelData(
             Map<Double, Double>[][] theta, DenseVector piArray, DenseVector labels) {
@@ -73,6 +68,8 @@ public class NaiveBayesModelData {
         this.labels = labels;
     }
 
+    public NaiveBayesModelData() {}
+
     /**
      * Converts the table model to a data stream.
      *
@@ -85,7 +82,11 @@ public class NaiveBayesModelData {
         return tEnv.toDataStream(modelData)
                 .map(
                         (MapFunction<Row, NaiveBayesModelData>)
-                                row -> (NaiveBayesModelData) row.getField("f0"));
+                                row ->
+                                        new NaiveBayesModelData(
+                                                (Map<Double, Double>[][]) row.getField(0),
+                                                (DenseVector) row.getField(1),
+                                                (DenseVector) row.getField(2)));
     }
 
     /** Data encoder for the {@link NaiveBayesModelData}. */
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
index 4bbf345..af0733d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
@@ -45,12 +45,14 @@ import java.io.OutputStream;
  */
 public class KMeansModelData {
 
-    public final DenseVector[] centroids;
+    public DenseVector[] centroids;
 
     public KMeansModelData(DenseVector[] centroids) {
         this.centroids = centroids;
     }
 
+    public KMeansModelData() {}
+
     /**
      * Converts the table model to a data stream.
      *
@@ -60,7 +62,8 @@ public class KMeansModelData {
     public static DataStream<KMeansModelData> getModelDataStream(Table modelData) {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
-        return tEnv.toDataStream(modelData).map(x -> (KMeansModelData) x.getField(0));
+        return tEnv.toDataStream(modelData)
+                .map(x -> new KMeansModelData((DenseVector[]) x.getField(0)));
     }
 
     /** Data encoder for {@link KMeansModelData}. */
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
index db57bd9..e7dc036 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
@@ -228,7 +228,7 @@ public class LogisticRegressionTest {
         LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
         model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
         assertEquals(
-                Collections.singletonList("f0"),
+                Collections.singletonList("coefficient"),
                 model.getModelData()[0].getResolvedSchema().getColumnNames());
         Table output = model.transform(binomialDataTable)[0];
         verifyPredictionResult(
@@ -242,11 +242,10 @@ public class LogisticRegressionTest {
     public void testGetModelData() throws Exception {
         LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
         LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
-        List<Row> collectedModelData =
-                IteratorUtils.toList(
-                        tEnv.toDataStream(model.getModelData()[0]).executeAndCollect());
         LogisticRegressionModelData modelData =
-                (LogisticRegressionModelData) collectedModelData.get(0).getField(0);
+                LogisticRegressionModelData.getModelDataStream(model.getModelData()[0])
+                        .executeAndCollect()
+                        .next();
         assertNotNull(modelData);
         assertArrayEquals(expectedCoefficient, modelData.coefficient.values, 0.1);
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
index 9f613e1..fe42829 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
@@ -60,7 +60,7 @@ import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
-/** Tests KMeans and KMeansModel. */
+/** Tests {@link KMeans} and {@link KMeansModel}. */
 public class KMeansTest extends AbstractTestBase {
     @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
 
@@ -150,7 +150,7 @@ public class KMeansTest extends AbstractTestBase {
     }
 
     @Test
-    public void testFeaturePredictionParam() throws Exception {
+    public void testFeaturePredictionParam() {
         Table input = dataTable.as("test_feature");
         KMeans kmeans =
                 new KMeans().setFeaturesCol("test_feature").setPredictionCol("test_prediction");
@@ -166,7 +166,7 @@ public class KMeansTest extends AbstractTestBase {
     }
 
     @Test
-    public void testFewerDistinctPointsThanCluster() throws Exception {
+    public void testFewerDistinctPointsThanCluster() {
         List<DenseVector> data =
                 Arrays.asList(
                         Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1));
@@ -185,7 +185,7 @@ public class KMeansTest extends AbstractTestBase {
     }
 
     @Test
-    public void testFitAndPredict() throws Exception {
+    public void testFitAndPredict() {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
         KMeansModel model = kmeans.fit(dataTable);
         Table output = model.transform(dataTable)[0];
@@ -201,18 +201,14 @@ public class KMeansTest extends AbstractTestBase {
     @Test
     public void testSaveLoadAndPredict() throws Exception {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
-
         KMeans loadedKmeans =
                 StageTestUtils.saveAndReload(env, kmeans, tempFolder.newFolder().getAbsolutePath());
-
         KMeansModel model = loadedKmeans.fit(dataTable);
-
         KMeansModel loadedModel =
                 StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
         Table output = loadedModel.transform(dataTable)[0];
-
         assertEquals(
-                Arrays.asList("f0"),
+                Collections.singletonList("centroids"),
                 loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
         assertEquals(
                 Arrays.asList("features", "prediction"),
@@ -226,16 +222,17 @@ public class KMeansTest extends AbstractTestBase {
     @Test
     public void testGetModelData() throws Exception {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
-        KMeansModel modelA = kmeans.fit(dataTable);
-        Table modelData = modelA.getModelData()[0];
-
-        DataStream<KMeansModelData> output =
-                tEnv.toDataStream(modelData).map(row -> (KMeansModelData) row.getField("f0"));
-
-        assertEquals(Arrays.asList("f0"), modelData.getResolvedSchema().getColumnNames());
-        List<KMeansModelData> kMeansModelData = IteratorUtils.toList(output.executeAndCollect());
-        DenseVector[] centroids = kMeansModelData.get(0).centroids;
-        assertEquals(1, kMeansModelData.size());
+        KMeansModel model = kmeans.fit(dataTable);
+        assertEquals(
+                Collections.singletonList("centroids"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+        DataStream<KMeansModelData> modelData =
+                KMeansModelData.getModelDataStream(model.getModelData()[0]);
+        List<KMeansModelData> collectedModelData =
+                IteratorUtils.toList(modelData.executeAndCollect());
+        assertEquals(1, collectedModelData.size());
+        DenseVector[] centroids = collectedModelData.get(0).centroids;
         assertEquals(2, centroids.length);
         Arrays.sort(centroids, Comparator.comparingDouble(vector -> vector.get(0)));
         assertArrayEquals(centroids[0].values, new double[] {0.1, 0.1}, 1e-5);
@@ -243,12 +240,10 @@ public class KMeansTest extends AbstractTestBase {
     }
 
     @Test
-    public void testSetModelData() throws Exception {
+    public void testSetModelData() {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
         KMeansModel modelA = kmeans.fit(dataTable);
-        Table modelData = modelA.getModelData()[0];
-
-        KMeansModel modelB = new KMeansModel().setModelData(modelData);
+        KMeansModel modelB = new KMeansModel().setModelData(modelA.getModelData());
         ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
 
         Table output = modelB.transform(dataTable)[0];