Posted to commits@flink.apache.org by ga...@apache.org on 2021/12/20 03:00:51 UTC

[flink-ml] 01/01: [FLINK-24556] Make model data pojo for naive bayes, kmeans and logistic regression

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit d36fe8feb043a010e45f75dff9d7d7f21aa37fc7
Author: zhangzp <zh...@gmail.com>
AuthorDate: Fri Dec 17 17:36:41 2021 +0800

    [FLINK-24556] Make model data pojo for naive bayes, kmeans and logistic regression
    
    This closes #28.
---
 .../ml/common/feature/LabeledPointWithWeight.java  | 32 +++++++++++++++--
 .../logisticregression/LogisticGradient.java       | 17 ++++-----
 .../logisticregression/LogisticRegression.java     |  8 ++---
 .../LogisticRegressionModelData.java               |  7 ++--
 .../ml/classification/naivebayes/NaiveBayes.java   |  4 ++-
 .../naivebayes/NaiveBayesModelData.java            | 19 +++++-----
 .../ml/clustering/kmeans/KMeansModelData.java      |  7 ++--
 .../logisticregression/LogisticRegressionTest.java |  9 +++--
 .../org/apache/flink/ml/clustering/KMeansTest.java | 41 ++++++++++------------
 9 files changed, 87 insertions(+), 57 deletions(-)
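
For context: Flink serializes a type with its POJO serializer (rather than falling back to Kryo) only when the class is public, has a public no-argument constructor, and every field is either public and non-final or reachable through getters and setters. The diff below applies exactly that shape to the model data classes. As a rough illustration only, a minimal model data class following this pattern could look like the sketch below; the class name HypotheticalModelData and its single field are invented for illustration and do not appear in the commit.

    import org.apache.flink.ml.linalg.DenseVector;

    /** Hypothetical model data class shaped to satisfy Flink's POJO rules. */
    public class HypotheticalModelData {

        // Public, non-final field so Flink's POJO serializer can read and write it directly.
        public DenseVector coefficient;

        // Convenience constructor for application code.
        public HypotheticalModelData(DenseVector coefficient) {
            this.coefficient = coefficient;
        }

        // Public no-argument constructor required for POJO (de)serialization.
        public HypotheticalModelData() {}
    }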

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
index e02192f..8440bc9 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
@@ -23,15 +23,41 @@ import org.apache.flink.ml.linalg.DenseVector;
 /** Utility class to represent a data point that contains features, label and weight. */
 public class LabeledPointWithWeight {
 
-    public final DenseVector features;
+    private DenseVector features;
 
-    public final double label;
+    private double label;
 
-    public final double weight;
+    private double weight;
 
     public LabeledPointWithWeight(DenseVector features, double label, double weight) {
         this.features = features;
         this.label = label;
         this.weight = weight;
     }
+
+    public LabeledPointWithWeight() {}
+
+    public DenseVector getFeatures() {
+        return features;
+    }
+
+    public void setFeatures(DenseVector features) {
+        this.features = features;
+    }
+
+    public double getLabel() {
+        return label;
+    }
+
+    public void setLabel(double label) {
+        this.label = label;
+    }
+
+    public double getWeight() {
+        return weight;
+    }
+
+    public void setWeight(double weight) {
+        this.weight = weight;
+    }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
index c63b72e..13f753b 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
@@ -52,8 +52,8 @@ public class LogisticGradient implements Serializable {
         double weightSum = 0.0;
         double lossSum = 0.0;
         for (LabeledPointWithWeight dataPoint : dataPoints) {
-            lossSum += dataPoint.weight * computeLoss(dataPoint, coefficient);
-            weightSum += dataPoint.weight;
+            lossSum += dataPoint.getWeight() * computeLoss(dataPoint, coefficient);
+            weightSum += dataPoint.getWeight();
         }
         if (Double.compare(0, l2) != 0) {
             lossSum += l2 * Math.pow(BLAS.norm2(coefficient), 2);
@@ -81,16 +81,17 @@ public class LogisticGradient implements Serializable {
     }
 
     private double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
-        double dot = BLAS.dot(dataPoint.features, coefficient);
-        double labelScaled = 2 * dataPoint.label - 1;
+        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+        double labelScaled = 2 * dataPoint.getLabel() - 1;
         return Math.log(1 + Math.exp(-dot * labelScaled));
     }
 
     private void computeGradient(
             LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
-        double dot = BLAS.dot(dataPoint.features, coefficient);
-        double labelScaled = 2 * dataPoint.label - 1;
-        double multiplier = dataPoint.weight * (-labelScaled / (Math.exp(dot * labelScaled) + 1));
-        BLAS.axpy(multiplier, dataPoint.features, cumGradient);
+        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+        double labelScaled = 2 * dataPoint.getLabel() - 1;
+        double multiplier =
+                dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1));
+        BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient);
     }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 7266610..a17269b 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -114,7 +114,7 @@ public class LogisticRegression
                                 dataPoint -> {
                                     Double weight =
                                             getWeightCol() == null
-                                                    ? new Double(1.0)
+                                                    ? 1.0
                                                     : (Double) dataPoint.getField(getWeightCol());
                                     Double label = (Double) dataPoint.getField(getLabelCol());
                                     boolean isBinomial =
@@ -160,9 +160,9 @@ public class LogisticRegression
         @Override
         public void processElement(StreamRecord<LabeledPointWithWeight> streamRecord) {
             if (dim == 0) {
-                dim = streamRecord.getValue().features.size();
+                dim = streamRecord.getValue().getFeatures().size();
             } else {
-                if (dim != streamRecord.getValue().features.size()) {
+                if (dim != streamRecord.getValue().getFeatures().size()) {
                     throw new RuntimeException(
                             "The training data should all have same dimensions.");
                 }
@@ -390,7 +390,7 @@ public class LogisticRegression
         }
 
         @Override
-        public void onIterationTerminated(Context context, Collector collector) {
+        public void onIterationTerminated(Context context, Collector<double[]> collector) {
             trainDataState.clear();
             coefficientState.clear();
             feedbackBufferState.clear();
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
index aae66fb..774c19e 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
@@ -44,12 +44,14 @@ import java.io.OutputStream;
  */
 public class LogisticRegressionModelData {
 
-    public final DenseVector coefficient;
+    public DenseVector coefficient;
 
     public LogisticRegressionModelData(DenseVector coefficient) {
         this.coefficient = coefficient;
     }
 
+    public LogisticRegressionModelData() {}
+
     /**
      * Converts the table model to a data stream.
      *
@@ -59,7 +61,8 @@ public class LogisticRegressionModelData {
     public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
-        return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+        return tEnv.toDataStream(modelData)
+                .map(x -> new LogisticRegressionModelData((DenseVector) x.getField(0)));
     }
 
     /** Data encoder for {@link LogisticRegressionModel}. */
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
index aefb44d..7a3cc3d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
@@ -27,6 +27,7 @@ import org.apache.flink.api.java.tuple.Tuple4;
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -339,7 +340,8 @@ public class NaiveBayes
                 piArray[i] = Math.log(weightSum + smoothing) - piLog;
             }
 
-            NaiveBayesModelData modelData = new NaiveBayesModelData(theta, piArray, labels);
+            NaiveBayesModelData modelData =
+                    new NaiveBayesModelData(theta, Vectors.dense(piArray), Vectors.dense(labels));
             collector.collect(modelData);
         }
     }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
index fee3b35..a03141d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
@@ -29,7 +29,6 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
@@ -54,17 +53,13 @@ public class NaiveBayesModelData {
      * Log of class conditional probabilities, whose dimension is C (number of classes) by D (number
      * of features).
      */
-    public final Map<Double, Double>[][] theta;
+    public Map<Double, Double>[][] theta;
 
     /** Log of class priors, whose dimension is C (number of classes). */
-    public final DenseVector piArray;
+    public DenseVector piArray;
 
     /** Value of labels. */
-    public final DenseVector labels;
-
-    public NaiveBayesModelData(Map<Double, Double>[][] theta, double[] piArray, double[] labels) {
-        this(theta, Vectors.dense(piArray), Vectors.dense(labels));
-    }
+    public DenseVector labels;
 
     public NaiveBayesModelData(
             Map<Double, Double>[][] theta, DenseVector piArray, DenseVector labels) {
@@ -73,6 +68,8 @@ public class NaiveBayesModelData {
         this.labels = labels;
     }
 
+    public NaiveBayesModelData() {}
+
     /**
      * Converts the table model to a data stream.
      *
@@ -85,7 +82,11 @@ public class NaiveBayesModelData {
         return tEnv.toDataStream(modelData)
                 .map(
                         (MapFunction<Row, NaiveBayesModelData>)
-                                row -> (NaiveBayesModelData) row.getField("f0"));
+                                row ->
+                                        new NaiveBayesModelData(
+                                                (Map<Double, Double>[][]) row.getField(0),
+                                                (DenseVector) row.getField(1),
+                                                (DenseVector) row.getField(2)));
     }
 
     /** Data encoder for the {@link NaiveBayesModelData}. */
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
index 4bbf345..af0733d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
@@ -45,12 +45,14 @@ import java.io.OutputStream;
  */
 public class KMeansModelData {
 
-    public final DenseVector[] centroids;
+    public DenseVector[] centroids;
 
     public KMeansModelData(DenseVector[] centroids) {
         this.centroids = centroids;
     }
 
+    public KMeansModelData() {}
+
     /**
      * Converts the table model to a data stream.
      *
@@ -60,7 +62,8 @@ public class KMeansModelData {
     public static DataStream<KMeansModelData> getModelDataStream(Table modelData) {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
-        return tEnv.toDataStream(modelData).map(x -> (KMeansModelData) x.getField(0));
+        return tEnv.toDataStream(modelData)
+                .map(x -> new KMeansModelData((DenseVector[]) x.getField(0)));
     }
 
     /** Data encoder for {@link KMeansModelData}. */
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
index db57bd9..e7dc036 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
@@ -228,7 +228,7 @@ public class LogisticRegressionTest {
         LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
         model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
         assertEquals(
-                Collections.singletonList("f0"),
+                Collections.singletonList("coefficient"),
                 model.getModelData()[0].getResolvedSchema().getColumnNames());
         Table output = model.transform(binomialDataTable)[0];
         verifyPredictionResult(
@@ -242,11 +242,10 @@ public class LogisticRegressionTest {
     public void testGetModelData() throws Exception {
         LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
         LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
-        List<Row> collectedModelData =
-                IteratorUtils.toList(
-                        tEnv.toDataStream(model.getModelData()[0]).executeAndCollect());
         LogisticRegressionModelData modelData =
-                (LogisticRegressionModelData) collectedModelData.get(0).getField(0);
+                LogisticRegressionModelData.getModelDataStream(model.getModelData()[0])
+                        .executeAndCollect()
+                        .next();
         assertNotNull(modelData);
         assertArrayEquals(expectedCoefficient, modelData.coefficient.values, 0.1);
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
index 9f613e1..fe42829 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
@@ -60,7 +60,7 @@ import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
-/** Tests KMeans and KMeansModel. */
+/** Tests {@link KMeans} and {@link KMeansModel}. */
 public class KMeansTest extends AbstractTestBase {
     @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
 
@@ -150,7 +150,7 @@ public class KMeansTest extends AbstractTestBase {
     }
 
     @Test
-    public void testFeaturePredictionParam() throws Exception {
+    public void testFeaturePredictionParam() {
         Table input = dataTable.as("test_feature");
         KMeans kmeans =
                 new KMeans().setFeaturesCol("test_feature").setPredictionCol("test_prediction");
@@ -166,7 +166,7 @@ public class KMeansTest extends AbstractTestBase {
     }
 
     @Test
-    public void testFewerDistinctPointsThanCluster() throws Exception {
+    public void testFewerDistinctPointsThanCluster() {
         List<DenseVector> data =
                 Arrays.asList(
                         Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1));
@@ -185,7 +185,7 @@ public class KMeansTest extends AbstractTestBase {
     }
 
     @Test
-    public void testFitAndPredict() throws Exception {
+    public void testFitAndPredict() {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
         KMeansModel model = kmeans.fit(dataTable);
         Table output = model.transform(dataTable)[0];
@@ -201,18 +201,14 @@ public class KMeansTest extends AbstractTestBase {
     @Test
     public void testSaveLoadAndPredict() throws Exception {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
-
         KMeans loadedKmeans =
                 StageTestUtils.saveAndReload(env, kmeans, tempFolder.newFolder().getAbsolutePath());
-
         KMeansModel model = loadedKmeans.fit(dataTable);
-
         KMeansModel loadedModel =
                 StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
         Table output = loadedModel.transform(dataTable)[0];
-
         assertEquals(
-                Arrays.asList("f0"),
+                Collections.singletonList("centroids"),
                 loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
         assertEquals(
                 Arrays.asList("features", "prediction"),
@@ -226,16 +222,17 @@ public class KMeansTest extends AbstractTestBase {
     @Test
     public void testGetModelData() throws Exception {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
-        KMeansModel modelA = kmeans.fit(dataTable);
-        Table modelData = modelA.getModelData()[0];
-
-        DataStream<KMeansModelData> output =
-                tEnv.toDataStream(modelData).map(row -> (KMeansModelData) row.getField("f0"));
-
-        assertEquals(Arrays.asList("f0"), modelData.getResolvedSchema().getColumnNames());
-        List<KMeansModelData> kMeansModelData = IteratorUtils.toList(output.executeAndCollect());
-        DenseVector[] centroids = kMeansModelData.get(0).centroids;
-        assertEquals(1, kMeansModelData.size());
+        KMeansModel model = kmeans.fit(dataTable);
+        assertEquals(
+                Collections.singletonList("centroids"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+        DataStream<KMeansModelData> modelData =
+                KMeansModelData.getModelDataStream(model.getModelData()[0]);
+        List<KMeansModelData> collectedModelData =
+                IteratorUtils.toList(modelData.executeAndCollect());
+        assertEquals(1, collectedModelData.size());
+        DenseVector[] centroids = collectedModelData.get(0).centroids;
         assertEquals(2, centroids.length);
         Arrays.sort(centroids, Comparator.comparingDouble(vector -> vector.get(0)));
         assertArrayEquals(centroids[0].values, new double[] {0.1, 0.1}, 1e-5);
@@ -243,12 +240,10 @@ public class KMeansTest extends AbstractTestBase {
     }
 
     @Test
-    public void testSetModelData() throws Exception {
+    public void testSetModelData() {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
         KMeansModel modelA = kmeans.fit(dataTable);
-        Table modelData = modelA.getModelData()[0];
-
-        KMeansModel modelB = new KMeansModel().setModelData(modelData);
+        KMeansModel modelB = new KMeansModel().setModelData(modelA.getModelData());
         ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
 
         Table output = modelB.transform(dataTable)[0];