You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/12/20 03:00:51 UTC
[flink-ml] 01/01: [FLINK-24556] Make model data POJO for naive bayes, kmeans and logistic regression
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit d36fe8feb043a010e45f75dff9d7d7f21aa37fc7
Author: zhangzp <zh...@gmail.com>
AuthorDate: Fri Dec 17 17:36:41 2021 +0800
[FLINK-24556] Make model data POJO for naive bayes, kmeans and logistic regression
This closes #28.
---
.../ml/common/feature/LabeledPointWithWeight.java | 32 +++++++++++++++--
.../logisticregression/LogisticGradient.java | 17 ++++-----
.../logisticregression/LogisticRegression.java | 8 ++---
.../LogisticRegressionModelData.java | 7 ++--
.../ml/classification/naivebayes/NaiveBayes.java | 4 ++-
.../naivebayes/NaiveBayesModelData.java | 19 +++++-----
.../ml/clustering/kmeans/KMeansModelData.java | 7 ++--
.../logisticregression/LogisticRegressionTest.java | 9 +++--
.../org/apache/flink/ml/clustering/KMeansTest.java | 41 ++++++++++------------
9 files changed, 87 insertions(+), 57 deletions(-)
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
index e02192f..8440bc9 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
@@ -23,15 +23,41 @@ import org.apache.flink.ml.linalg.DenseVector;
/** Utility class to represent a data point that contains features, label and weight. */
public class LabeledPointWithWeight {
- public final DenseVector features;
+ private DenseVector features;
- public final double label;
+ private double label;
- public final double weight;
+ private double weight;
public LabeledPointWithWeight(DenseVector features, double label, double weight) {
this.features = features;
this.label = label;
this.weight = weight;
}
+
+ public LabeledPointWithWeight() {}
+
+ public DenseVector getFeatures() {
+ return features;
+ }
+
+ public void setFeatures(DenseVector features) {
+ this.features = features;
+ }
+
+ public double getLabel() {
+ return label;
+ }
+
+ public void setLabel(double label) {
+ this.label = label;
+ }
+
+ public double getWeight() {
+ return weight;
+ }
+
+ public void setWeight(double weight) {
+ this.weight = weight;
+ }
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
index c63b72e..13f753b 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
@@ -52,8 +52,8 @@ public class LogisticGradient implements Serializable {
double weightSum = 0.0;
double lossSum = 0.0;
for (LabeledPointWithWeight dataPoint : dataPoints) {
- lossSum += dataPoint.weight * computeLoss(dataPoint, coefficient);
- weightSum += dataPoint.weight;
+ lossSum += dataPoint.getWeight() * computeLoss(dataPoint, coefficient);
+ weightSum += dataPoint.getWeight();
}
if (Double.compare(0, l2) != 0) {
lossSum += l2 * Math.pow(BLAS.norm2(coefficient), 2);
@@ -81,16 +81,17 @@ public class LogisticGradient implements Serializable {
}
private double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
- double dot = BLAS.dot(dataPoint.features, coefficient);
- double labelScaled = 2 * dataPoint.label - 1;
+ double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+ double labelScaled = 2 * dataPoint.getLabel() - 1;
return Math.log(1 + Math.exp(-dot * labelScaled));
}
private void computeGradient(
LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
- double dot = BLAS.dot(dataPoint.features, coefficient);
- double labelScaled = 2 * dataPoint.label - 1;
- double multiplier = dataPoint.weight * (-labelScaled / (Math.exp(dot * labelScaled) + 1));
- BLAS.axpy(multiplier, dataPoint.features, cumGradient);
+ double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+ double labelScaled = 2 * dataPoint.getLabel() - 1;
+ double multiplier =
+ dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1));
+ BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient);
}
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 7266610..a17269b 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -114,7 +114,7 @@ public class LogisticRegression
dataPoint -> {
Double weight =
getWeightCol() == null
- ? new Double(1.0)
+ ? 1.0
: (Double) dataPoint.getField(getWeightCol());
Double label = (Double) dataPoint.getField(getLabelCol());
boolean isBinomial =
@@ -160,9 +160,9 @@ public class LogisticRegression
@Override
public void processElement(StreamRecord<LabeledPointWithWeight> streamRecord) {
if (dim == 0) {
- dim = streamRecord.getValue().features.size();
+ dim = streamRecord.getValue().getFeatures().size();
} else {
- if (dim != streamRecord.getValue().features.size()) {
+ if (dim != streamRecord.getValue().getFeatures().size()) {
throw new RuntimeException(
"The training data should all have same dimensions.");
}
@@ -390,7 +390,7 @@ public class LogisticRegression
}
@Override
- public void onIterationTerminated(Context context, Collector collector) {
+ public void onIterationTerminated(Context context, Collector<double[]> collector) {
trainDataState.clear();
coefficientState.clear();
feedbackBufferState.clear();
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
index aae66fb..774c19e 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
@@ -44,12 +44,14 @@ import java.io.OutputStream;
*/
public class LogisticRegressionModelData {
- public final DenseVector coefficient;
+ public DenseVector coefficient;
public LogisticRegressionModelData(DenseVector coefficient) {
this.coefficient = coefficient;
}
+ public LogisticRegressionModelData() {}
+
/**
* Converts the table model to a data stream.
*
@@ -59,7 +61,8 @@ public class LogisticRegressionModelData {
public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
- return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+ return tEnv.toDataStream(modelData)
+ .map(x -> new LogisticRegressionModelData((DenseVector) x.getField(0)));
}
/** Data encoder for {@link LogisticRegressionModel}. */
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
index aefb44d..7a3cc3d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
@@ -27,6 +27,7 @@ import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
@@ -339,7 +340,8 @@ public class NaiveBayes
piArray[i] = Math.log(weightSum + smoothing) - piLog;
}
- NaiveBayesModelData modelData = new NaiveBayesModelData(theta, piArray, labels);
+ NaiveBayesModelData modelData =
+ new NaiveBayesModelData(theta, Vectors.dense(piArray), Vectors.dense(labels));
collector.collect(modelData);
}
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
index fee3b35..a03141d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
@@ -29,7 +29,6 @@ import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
@@ -54,17 +53,13 @@ public class NaiveBayesModelData {
* Log of class conditional probabilities, whose dimension is C (number of classes) by D (number
* of features).
*/
- public final Map<Double, Double>[][] theta;
+ public Map<Double, Double>[][] theta;
/** Log of class priors, whose dimension is C (number of classes). */
- public final DenseVector piArray;
+ public DenseVector piArray;
/** Value of labels. */
- public final DenseVector labels;
-
- public NaiveBayesModelData(Map<Double, Double>[][] theta, double[] piArray, double[] labels) {
- this(theta, Vectors.dense(piArray), Vectors.dense(labels));
- }
+ public DenseVector labels;
public NaiveBayesModelData(
Map<Double, Double>[][] theta, DenseVector piArray, DenseVector labels) {
@@ -73,6 +68,8 @@ public class NaiveBayesModelData {
this.labels = labels;
}
+ public NaiveBayesModelData() {}
+
/**
* Converts the table model to a data stream.
*
@@ -85,7 +82,11 @@ public class NaiveBayesModelData {
return tEnv.toDataStream(modelData)
.map(
(MapFunction<Row, NaiveBayesModelData>)
- row -> (NaiveBayesModelData) row.getField("f0"));
+ row ->
+ new NaiveBayesModelData(
+ (Map<Double, Double>[][]) row.getField(0),
+ (DenseVector) row.getField(1),
+ (DenseVector) row.getField(2)));
}
/** Data encoder for the {@link NaiveBayesModelData}. */
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
index 4bbf345..af0733d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
@@ -45,12 +45,14 @@ import java.io.OutputStream;
*/
public class KMeansModelData {
- public final DenseVector[] centroids;
+ public DenseVector[] centroids;
public KMeansModelData(DenseVector[] centroids) {
this.centroids = centroids;
}
+ public KMeansModelData() {}
+
/**
* Converts the table model to a data stream.
*
@@ -60,7 +62,8 @@ public class KMeansModelData {
public static DataStream<KMeansModelData> getModelDataStream(Table modelData) {
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
- return tEnv.toDataStream(modelData).map(x -> (KMeansModelData) x.getField(0));
+ return tEnv.toDataStream(modelData)
+ .map(x -> new KMeansModelData((DenseVector[]) x.getField(0)));
}
/** Data encoder for {@link KMeansModelData}. */
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
index db57bd9..e7dc036 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
@@ -228,7 +228,7 @@ public class LogisticRegressionTest {
LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
assertEquals(
- Collections.singletonList("f0"),
+ Collections.singletonList("coefficient"),
model.getModelData()[0].getResolvedSchema().getColumnNames());
Table output = model.transform(binomialDataTable)[0];
verifyPredictionResult(
@@ -242,11 +242,10 @@ public class LogisticRegressionTest {
public void testGetModelData() throws Exception {
LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
- List<Row> collectedModelData =
- IteratorUtils.toList(
- tEnv.toDataStream(model.getModelData()[0]).executeAndCollect());
LogisticRegressionModelData modelData =
- (LogisticRegressionModelData) collectedModelData.get(0).getField(0);
+ LogisticRegressionModelData.getModelDataStream(model.getModelData()[0])
+ .executeAndCollect()
+ .next();
assertNotNull(modelData);
assertArrayEquals(expectedCoefficient, modelData.coefficient.values, 0.1);
}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
index 9f613e1..fe42829 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
@@ -60,7 +60,7 @@ import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-/** Tests KMeans and KMeansModel. */
+/** Tests {@link KMeans} and {@link KMeansModel}. */
public class KMeansTest extends AbstractTestBase {
@Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
@@ -150,7 +150,7 @@ public class KMeansTest extends AbstractTestBase {
}
@Test
- public void testFeaturePredictionParam() throws Exception {
+ public void testFeaturePredictionParam() {
Table input = dataTable.as("test_feature");
KMeans kmeans =
new KMeans().setFeaturesCol("test_feature").setPredictionCol("test_prediction");
@@ -166,7 +166,7 @@ public class KMeansTest extends AbstractTestBase {
}
@Test
- public void testFewerDistinctPointsThanCluster() throws Exception {
+ public void testFewerDistinctPointsThanCluster() {
List<DenseVector> data =
Arrays.asList(
Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1));
@@ -185,7 +185,7 @@ public class KMeansTest extends AbstractTestBase {
}
@Test
- public void testFitAndPredict() throws Exception {
+ public void testFitAndPredict() {
KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
KMeansModel model = kmeans.fit(dataTable);
Table output = model.transform(dataTable)[0];
@@ -201,18 +201,14 @@ public class KMeansTest extends AbstractTestBase {
@Test
public void testSaveLoadAndPredict() throws Exception {
KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
-
KMeans loadedKmeans =
StageTestUtils.saveAndReload(env, kmeans, tempFolder.newFolder().getAbsolutePath());
-
KMeansModel model = loadedKmeans.fit(dataTable);
-
KMeansModel loadedModel =
StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
Table output = loadedModel.transform(dataTable)[0];
-
assertEquals(
- Arrays.asList("f0"),
+ Collections.singletonList("centroids"),
loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
assertEquals(
Arrays.asList("features", "prediction"),
@@ -226,16 +222,17 @@ public class KMeansTest extends AbstractTestBase {
@Test
public void testGetModelData() throws Exception {
KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
- KMeansModel modelA = kmeans.fit(dataTable);
- Table modelData = modelA.getModelData()[0];
-
- DataStream<KMeansModelData> output =
- tEnv.toDataStream(modelData).map(row -> (KMeansModelData) row.getField("f0"));
-
- assertEquals(Arrays.asList("f0"), modelData.getResolvedSchema().getColumnNames());
- List<KMeansModelData> kMeansModelData = IteratorUtils.toList(output.executeAndCollect());
- DenseVector[] centroids = kMeansModelData.get(0).centroids;
- assertEquals(1, kMeansModelData.size());
+ KMeansModel model = kmeans.fit(dataTable);
+ assertEquals(
+ Collections.singletonList("centroids"),
+ model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+ DataStream<KMeansModelData> modelData =
+ KMeansModelData.getModelDataStream(model.getModelData()[0]);
+ List<KMeansModelData> collectedModelData =
+ IteratorUtils.toList(modelData.executeAndCollect());
+ assertEquals(1, collectedModelData.size());
+ DenseVector[] centroids = collectedModelData.get(0).centroids;
assertEquals(2, centroids.length);
Arrays.sort(centroids, Comparator.comparingDouble(vector -> vector.get(0)));
assertArrayEquals(centroids[0].values, new double[] {0.1, 0.1}, 1e-5);
@@ -243,12 +240,10 @@ public class KMeansTest extends AbstractTestBase {
}
@Test
- public void testSetModelData() throws Exception {
+ public void testSetModelData() {
KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
KMeansModel modelA = kmeans.fit(dataTable);
- Table modelData = modelA.getModelData()[0];
-
- KMeansModel modelB = new KMeansModel().setModelData(modelData);
+ KMeansModel modelB = new KMeansModel().setModelData(modelA.getModelData());
ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
Table output = modelB.transform(dataTable)[0];