You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/09 09:33:16 UTC
[flink-ml] branch master updated: [hotfix] Remove those library
infra classes that need to be revisited
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 81cd74a [hotfix] Remove those library infra classes that need to be revisited
81cd74a is described below
commit 81cd74adb7679cfb5cc5f446d758c4ff3e50583f
Author: Dong Lin <li...@gmail.com>
AuthorDate: Tue Nov 9 17:21:20 2021 +0800
[hotfix] Remove those library infra classes that need to be revisited
This closes #19.
---
.../flink/examples/java/ml/LinearRegression.java | 257 ---------
.../java/ml/util/LinearRegressionData.java | 69 ---
.../ml/util/LinearRegressionDataGenerator.java | 115 ----
.../flink/examples/scala/ml/LinearRegression.scala | 159 ------
flink-ml-examples/examples-streaming/pom.xml | 190 -------
.../examples/ml/IncrementalLearningSkeleton.java | 212 --------
.../ml/util/IncrementalLearningSkeletonData.java | 33 --
.../examples/ml/IncrementalLearningSkeleton.scala | 184 -------
.../streaming/test/StreamingExamplesITCase.java | 36 --
.../scala/examples/StreamingExamplesITCase.scala | 37 --
flink-ml-examples/pom.xml | 111 ----
flink-ml-lib/pom.xml | 26 +-
.../org/apache/flink/ml/common/MLEnvironment.java | 89 ----
.../flink/ml/common/MLEnvironmentFactory.java | 116 -----
.../org/apache/flink/ml/common/linalg/BLAS.java | 234 ---------
.../apache/flink/ml/common/linalg/DenseMatrix.java | 577 ---------------------
.../apache/flink/ml/common/linalg/DenseVector.java | 379 --------------
.../apache/flink/ml/common/linalg/MatVecOp.java | 307 -----------
.../flink/ml/common/linalg/SparseVector.java | 574 --------------------
.../org/apache/flink/ml/common/linalg/Vector.java | 89 ----
.../flink/ml/common/linalg/VectorIterator.java | 73 ---
.../apache/flink/ml/common/linalg/VectorUtil.java | 240 ---------
.../org/apache/flink/ml/common/mapper/Mapper.java | 79 ---
.../flink/ml/common/mapper/MapperAdapter.java | 46 --
.../apache/flink/ml/common/mapper/ModelMapper.java | 66 ---
.../flink/ml/common/mapper/ModelMapperAdapter.java | 62 ---
.../common/model/BroadcastVariableModelSource.java | 47 --
.../apache/flink/ml/common/model/ModelSource.java | 40 --
.../flink/ml/common/model/RowsModelSource.java | 46 --
.../basicstatistic/MultivariateGaussian.java | 138 -----
.../ml/common/utils/DataStreamConversionUtil.java | 167 ------
.../flink/ml/common/utils/OutputColsHelper.java | 211 --------
.../apache/flink/ml/common/utils/TableUtil.java | 424 ---------------
.../apache/flink/ml/common/utils/VectorTypes.java | 43 --
.../org/apache/flink/ml/operator/AlgoOperator.java | 186 -------
.../flink/ml/operator/batch/BatchOperator.java | 113 ----
.../operator/batch/source/TableSourceBatchOp.java | 40 --
.../flink/ml/operator/stream/StreamOperator.java | 114 ----
.../stream/source/TableSourceStreamOp.java | 40 --
.../flink/ml/params/shared/HasMLEnvironmentId.java | 43 --
.../ml/params/shared/colname/HasOutputCol.java | 48 --
.../shared/colname/HasOutputColDefaultAsNull.java | 49 --
.../ml/params/shared/colname/HasOutputCols.java | 48 --
.../shared/colname/HasOutputColsDefaultAsNull.java | 49 --
.../ml/params/shared/colname/HasPredictionCol.java | 42 --
.../shared/colname/HasPredictionDetailCol.java | 47 --
.../ml/params/shared/colname/HasReservedCols.java | 45 --
.../ml/params/shared/colname/HasSelectedCol.java | 48 --
.../colname/HasSelectedColDefaultAsNull.java | 49 --
.../ml/params/shared/colname/HasSelectedCols.java | 48 --
.../colname/HasSelectedColsDefaultAsNull.java | 49 --
.../apache/flink/ml/common/MLEnvironmentTest.java | 65 ---
.../apache/flink/ml/common/linalg/BLASTest.java | 186 -------
.../flink/ml/common/linalg/DenseMatrixTest.java | 195 -------
.../flink/ml/common/linalg/DenseVectorTest.java | 158 ------
.../flink/ml/common/linalg/MatVecOpTest.java | 103 ----
.../flink/ml/common/linalg/SparseVectorTest.java | 232 ---------
.../flink/ml/common/linalg/VectorUtilTest.java | 76 ---
.../basicstatistic/MultivariateGaussianTest.java | 72 ---
.../common/utils/DataStreamConversionUtilTest.java | 208 --------
.../ml/common/utils/OutputColsHelperTest.java | 249 ---------
.../flink/ml/common/utils/TableUtilTest.java | 200 -------
.../flink/ml/common/utils/VectorTypesTest.java | 78 ---
pom.xml | 1 -
64 files changed, 4 insertions(+), 8353 deletions(-)
diff --git a/flink-ml-examples/examples-batch/src/main/java/org/apache/flink/examples/java/ml/LinearRegression.java b/flink-ml-examples/examples-batch/src/main/java/org/apache/flink/examples/java/ml/LinearRegression.java
deleted file mode 100644
index 4f2f528..0000000
--- a/flink-ml-examples/examples-batch/src/main/java/org/apache/flink/examples/java/ml/LinearRegression.java
+++ /dev/null
@@ -1,257 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.examples.java.ml;
-
-import org.apache.flink.api.common.functions.MapFunction;
-import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.api.java.DataSet;
-import org.apache.flink.api.java.ExecutionEnvironment;
-import org.apache.flink.api.java.operators.IterativeDataSet;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.api.java.utils.ParameterTool;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.examples.java.ml.util.LinearRegressionData;
-
-import java.io.Serializable;
-import java.util.Collection;
-
-/**
- * This example implements a basic Linear Regression to solve the y = theta0 + theta1*x problem
- * using the batch gradient descent (BGD) algorithm.
- *
- * <p>Linear Regression with BGD is an iterative algorithm and works as follows:<br>
- * Given a data set and a target set, BGD tries to find the parameters with which the data set
- * fits the target set best. In each iteration, the algorithm computes the gradient of the cost
- * function and uses it to update all the parameters. The algorithm terminates after a fixed
- * number of iterations (as in this implementation). With enough iterations, the algorithm can
- * minimize the cost function and find the best parameters. See the Wikipedia entries for <a
- * href="http://en.wikipedia.org/wiki/Linear_regression">Linear regression</a> and <a
- * href="http://en.wikipedia.org/wiki/Gradient_descent">Gradient descent algorithm</a>.
- *
- * <p>This implementation works on one-dimensional data and finds the two-dimensional theta that
- * fits the target best.
- *
- * <p>Input files are plain text files and must be formatted as follows:
- *
- * <ul>
- *   <li>Data points are represented as two double values separated by a blank character. The
- *       first one represents X (the training data) and the second one represents Y (the
- *       target). Data points are separated by newline characters.<br>
- *       For example <code>"-0.02 -0.04\n5.3 10.6\n"</code> gives two data points (x=-0.02,
- *       y=-0.04) and (x=5.3, y=10.6).
- * </ul>
- *
- * <p>This example shows how to use:
- *
- * <ul>
- *   <li>Bulk iterations
- *   <li>Broadcast variables in bulk iterations
- *   <li>Custom Java objects (POJOs)
- * </ul>
- */
-@SuppressWarnings("serial")
-public class LinearRegression {
-
-    // *************************************************************************
-    // PROGRAM
-    // *************************************************************************
-
-    /**
-     * Runs the example.
-     *
-     * <p>Recognized arguments: {@code --input <path>} (data file, otherwise the built-in data set
-     * is used), {@code --output <path>} (result file, otherwise the result is printed) and {@code
-     * --iterations <n>} (number of bulk iterations, 10 if absent).
-     */
-    public static void main(String[] args) throws Exception {
-
-        final ParameterTool params = ParameterTool.fromArgs(args);
-
-        // set up execution environment
-        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
-
-        final int iterations = params.getInt("iterations", 10);
-
-        // make parameters available in the web interface
-        env.getConfig().setGlobalJobParameters(params);
-
-        // get the input data, either from a file or from the built-in default data set
-        DataSet<Data> data;
-        if (params.has("input")) {
-            // read data from CSV file
-            data =
-                    env.readCsvFile(params.get("input"))
-                            .fieldDelimiter(" ")
-                            .includeFields(true, true)
-                            .pojoType(Data.class);
-        } else {
-            System.out.println("Executing LinearRegression example with default input data set.");
-            System.out.println("Use --input to specify file input.");
-            data = LinearRegressionData.getDefaultDataDataSet(env);
-        }
-
-        // get the initial parameters
-        DataSet<Params> parameters = LinearRegressionData.getDefaultParamsDataSet(env);
-
-        // set number of bulk iterations for BGD linear regression
-        IterativeDataSet<Params> loop = parameters.iterate(iterations);
-
-        DataSet<Params> newParameters =
-                data
-                        // compute a single step using every sample
-                        .map(new SubUpdate())
-                        .withBroadcastSet(loop, "parameters")
-                        // sum up all the steps
-                        .reduce(new UpdateAccumulator())
-                        // average the steps and update all parameters
-                        .map(new Update());
-
-        // feed new parameters back into next iteration
-        DataSet<Params> result = loop.closeWith(newParameters);
-
-        // emit result
-        if (params.has("output")) {
-            result.writeAsText(params.get("output"));
-            // writeAsText only declares a sink, so the job has to be started explicitly
-            env.execute("Linear Regression example");
-        } else {
-            System.out.println("Printing result to stdout. Use --output to specify output path.");
-            result.print();
-        }
-    }
-
-    // *************************************************************************
-    // DATA TYPES
-    // *************************************************************************
-
-    /** A simple data sample, x means the input, and y means the target. */
-    public static class Data implements Serializable {
-        public double x, y;
-
-        public Data() {}
-
-        public Data(double x, double y) {
-            this.x = x;
-            this.y = y;
-        }
-
-        @Override
-        public String toString() {
-            return "(" + x + "|" + y + ")";
-        }
-    }
-
-    /** A set of parameters -- theta0, theta1. */
-    public static class Params implements Serializable {
-
-        private double theta0, theta1;
-
-        public Params() {}
-
-        public Params(double x0, double x1) {
-            this.theta0 = x0;
-            this.theta1 = x1;
-        }
-
-        @Override
-        public String toString() {
-            return theta0 + " " + theta1;
-        }
-
-        public double getTheta0() {
-            return theta0;
-        }
-
-        public double getTheta1() {
-            return theta1;
-        }
-
-        public void setTheta0(double theta0) {
-            this.theta0 = theta0;
-        }
-
-        public void setTheta1(double theta1) {
-            this.theta1 = theta1;
-        }
-
-        /**
-         * Divides both parameters by {@code a} in place.
-         *
-         * @param a the divisor
-         * @return this instance, after the division
-         */
-        public Params div(Integer a) {
-            this.theta0 = theta0 / a;
-            this.theta1 = theta1 / a;
-            return this;
-        }
-    }
-
-    // *************************************************************************
-    // USER FUNCTIONS
-    // *************************************************************************
-
-    /** Computes a single BGD-type update of all parameters for one data sample. */
-    public static class SubUpdate extends RichMapFunction<Data, Tuple2<Params, Integer>> {
-
-        /** Step size applied to the gradient of a single sample. */
-        private static final double LEARNING_RATE = 0.01;
-
-        /** Parameters taken from the "parameters" broadcast variable in {@link #open}. */
-        private Params parameter;
-
-        /**
-         * Reads the current parameters from the broadcast variable once, instead of re-scanning
-         * the broadcast collection for every record in {@link #map}.
-         */
-        @Override
-        public void open(Configuration parameters) throws Exception {
-            Collection<Params> broadcast = getRuntimeContext().getBroadcastVariable("parameters");
-            // if the collection holds several entries, the last one wins
-            for (Params p : broadcast) {
-                this.parameter = p;
-            }
-        }
-
-        /**
-         * Applies one gradient step for the given sample.
-         *
-         * @param in the sample to learn from
-         * @return the updated parameters together with a count of 1
-         */
-        @Override
-        public Tuple2<Params, Integer> map(Data in) throws Exception {
-
-            // prediction error of the current parameters for this sample
-            final double error = (parameter.theta0 + (parameter.theta1 * in.x)) - in.y;
-
-            double theta0 = parameter.theta0 - LEARNING_RATE * error;
-            double theta1 = parameter.theta1 - LEARNING_RATE * (error * in.x);
-
-            return new Tuple2<>(new Params(theta0, theta1), 1);
-        }
-    }
-
-    /** Accumulates all the updates by summing the parameters and the sample counts. */
-    public static class UpdateAccumulator implements ReduceFunction<Tuple2<Params, Integer>> {
-
-        @Override
-        public Tuple2<Params, Integer> reduce(
-                Tuple2<Params, Integer> val1, Tuple2<Params, Integer> val2) {
-
-            double newTheta0 = val1.f0.theta0 + val2.f0.theta0;
-            double newTheta1 = val1.f0.theta1 + val2.f0.theta1;
-            Params result = new Params(newTheta0, newTheta1);
-            return new Tuple2<>(result, val1.f1 + val2.f1);
-        }
-    }
-
-    /** Computes the final update by averaging the accumulated parameters. */
-    public static class Update implements MapFunction<Tuple2<Params, Integer>, Params> {
-
-        @Override
-        public Params map(Tuple2<Params, Integer> arg0) throws Exception {
-
-            // f0 holds the summed parameters, f1 the number of samples that were summed
-            return arg0.f0.div(arg0.f1);
-        }
-    }
-}
diff --git a/flink-ml-examples/examples-batch/src/main/java/org/apache/flink/examples/java/ml/util/LinearRegressionData.java b/flink-ml-examples/examples-batch/src/main/java/org/apache/flink/examples/java/ml/util/LinearRegressionData.java
deleted file mode 100644
index 5cb1339..0000000
--- a/flink-ml-examples/examples-batch/src/main/java/org/apache/flink/examples/java/ml/util/LinearRegressionData.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.examples.java.ml.util;
-
-import org.apache.flink.api.java.DataSet;
-import org.apache.flink.api.java.ExecutionEnvironment;
-import org.apache.flink.examples.java.ml.LinearRegression.Data;
-import org.apache.flink.examples.java.ml.LinearRegression.Params;
-
-import java.util.LinkedList;
-import java.util.List;
-
-/**
- * Supplies the built-in data sets of the Linear Regression example program. They are used
- * whenever the program is started without the corresponding input parameters.
- */
-public class LinearRegressionData {
-
-    // The values are kept as plain object arrays so that Scala data sources can be generated
-    // from them as well.
-
-    /** Initial parameters, one row of (theta0, theta1). */
-    public static final Object[][] PARAMS = {{0.0, 0.0}};
-
-    /** Default samples, one row of (x, y) per data point. */
-    public static final Object[][] DATA = {
-        {0.5, 1.0},
-        {1.0, 2.0},
-        {2.0, 4.0},
-        {3.0, 6.0},
-        {4.0, 8.0},
-        {5.0, 10.0},
-        {6.0, 12.0},
-        {7.0, 14.0},
-        {8.0, 16.0},
-        {9.0, 18.0},
-        {10.0, 20.0},
-        {-0.08, -0.16},
-        {0.13, 0.26},
-        {-1.17, -2.35},
-        {1.72, 3.45},
-        {1.70, 3.41},
-        {1.20, 2.41},
-        {-0.59, -1.18},
-        {0.28, 0.57},
-        {1.65, 3.30},
-        {-0.55, -1.08}
-    };
-
-    /**
-     * Builds a data set holding one {@code Params} per row of {@link #PARAMS}.
-     *
-     * @param env the environment used to create the data set
-     * @return the default parameter data set
-     */
-    public static DataSet<Params> getDefaultParamsDataSet(ExecutionEnvironment env) {
-        List<Params> defaults = new LinkedList<>();
-        for (int row = 0; row < PARAMS.length; row++) {
-            Double theta0 = (Double) PARAMS[row][0];
-            Double theta1 = (Double) PARAMS[row][1];
-            defaults.add(new Params(theta0, theta1));
-        }
-        return env.fromCollection(defaults);
-    }
-
-    /**
-     * Builds a data set holding one {@code Data} sample per row of {@link #DATA}.
-     *
-     * @param env the environment used to create the data set
-     * @return the default sample data set
-     */
-    public static DataSet<Data> getDefaultDataDataSet(ExecutionEnvironment env) {
-        List<Data> samples = new LinkedList<>();
-        for (int row = 0; row < DATA.length; row++) {
-            Double x = (Double) DATA[row][0];
-            Double y = (Double) DATA[row][1];
-            samples.add(new Data(x, y));
-        }
-        return env.fromCollection(samples);
-    }
-}
diff --git a/flink-ml-examples/examples-batch/src/main/java/org/apache/flink/examples/java/ml/util/LinearRegressionDataGenerator.java b/flink-ml-examples/examples-batch/src/main/java/org/apache/flink/examples/java/ml/util/LinearRegressionDataGenerator.java
deleted file mode 100644
index 52a912f..0000000
--- a/flink-ml-examples/examples-batch/src/main/java/org/apache/flink/examples/java/ml/util/LinearRegressionDataGenerator.java
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.examples.java.ml.util;
-
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.FileWriter;
-import java.io.IOException;
-import java.text.DecimalFormat;
-import java.util.Locale;
-import java.util.Random;
-
-/**
- * Generates data for the {@link org.apache.flink.examples.java.ml.LinearRegression} example
- * program.
- */
-public class LinearRegressionDataGenerator {
-
-    static {
-        // force '.' as the decimal separator of FORMAT, independent of the platform locale
-        Locale.setDefault(Locale.US);
-    }
-
-    private static final String POINTS_FILE = "data";
-    private static final long DEFAULT_SEED = 4650285087650871364L;
-    private static final int DIMENSIONALITY = 1;
-    private static final DecimalFormat FORMAT = new DecimalFormat("#0.00");
-    private static final char DELIMITER = ' ';
-
-    /**
-     * Main method to generate data for the {@link
-     * org.apache.flink.examples.java.ml.LinearRegression} example program.
-     *
-     * <p>The generator creates one file:
-     *
-     * <ul>
-     *   <li><code>{java.io.tmpdir}/data</code> for the data points
-     * </ul>
-     *
-     * @param args
-     *     <ol>
-     *       <li>Int: Number of data points
-     *       <li><b>Optional</b> Long: Random seed
-     *     </ol>
-     *
-     * @throws IOException if the data file cannot be written
-     */
-    public static void main(String[] args) throws IOException {
-
-        // check parameter count
-        if (args.length < 1) {
-            System.out.println("LinearRegressionDataGenerator <numberOfDataPoints> [<seed>]");
-            System.exit(1);
-        }
-
-        // parse parameters
-        final int numDataPoints = Integer.parseInt(args[0]);
-        // the optional seed is the second argument; reading index 4 here threw an
-        // ArrayIndexOutOfBoundsException whenever a seed was actually supplied
-        final long firstSeed = args.length > 1 ? Long.parseLong(args[1]) : DEFAULT_SEED;
-        final Random random = new Random(firstSeed);
-        final String tmpDir = System.getProperty("java.io.tmpdir");
-
-        // write the points out; the writer is closed even if writing fails
-        try (BufferedWriter pointsOut =
-                new BufferedWriter(new FileWriter(new File(tmpDir + "/" + POINTS_FILE)))) {
-            StringBuilder buffer = new StringBuilder();
-
-            // DIMENSIONALITY + 1 slots: the x value(s) followed by the target y
-            double[] point = new double[DIMENSIONALITY + 1];
-
-            for (int i = 1; i <= numDataPoints; i++) {
-                point[0] = random.nextGaussian();
-                // y = 2 * x plus a small amount of Gaussian noise
-                point[1] = 2 * point[0] + 0.01 * random.nextGaussian();
-                writePoint(point, buffer, pointsOut);
-            }
-        }
-
-        System.out.println(
-                "Wrote " + numDataPoints + " data points to " + tmpDir + "/" + POINTS_FILE);
-    }
-
-    /**
-     * Writes one data point as a single line of DELIMITER-separated, FORMAT-formatted values.
-     *
-     * @param data the coordinates of the point
-     * @param buffer scratch buffer that is cleared and reused for every point
-     * @param out the writer that receives the line
-     * @throws IOException if writing to {@code out} fails
-     */
-    private static void writePoint(double[] data, StringBuilder buffer, BufferedWriter out)
-            throws IOException {
-        buffer.setLength(0);
-
-        // write coordinates
-        for (int j = 0; j < data.length; j++) {
-            buffer.append(FORMAT.format(data[j]));
-            if (j < data.length - 1) {
-                buffer.append(DELIMITER);
-            }
-        }
-
-        out.write(buffer.toString());
-        out.newLine();
-    }
-}
diff --git a/flink-ml-examples/examples-batch/src/main/scala/org/apache/flink/examples/scala/ml/LinearRegression.scala b/flink-ml-examples/examples-batch/src/main/scala/org/apache/flink/examples/scala/ml/LinearRegression.scala
deleted file mode 100644
index 4663db6..0000000
--- a/flink-ml-examples/examples-batch/src/main/scala/org/apache/flink/examples/scala/ml/LinearRegression.scala
+++ /dev/null
@@ -1,159 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.examples.scala.ml
-
-import org.apache.flink.api.common.functions._
-import org.apache.flink.api.java.utils.ParameterTool
-import org.apache.flink.api.scala._
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.examples.java.ml.util.LinearRegressionData
-
-import scala.collection.JavaConverters._
-
-/**
- * A basic Linear Regression example that solves the y = theta0 + theta1*x problem with the
- * batch gradient descent (BGD) algorithm.
- *
- * Linear Regression with BGD is an iterative algorithm and works as follows:
- *
- * Given a data set and a target set, BGD tries to find the parameters with which the data set
- * fits the target set best.
- * In each iteration, the algorithm computes the gradient of the cost function and uses it to
- * update all the parameters.
- * The algorithm terminates after a fixed number of iterations (as in this implementation).
- * With enough iterations, the algorithm can minimize the cost function and find the best
- * parameters.
- * See the Wikipedia entries for
- * [[http://en.wikipedia.org/wiki/Linear_regression Linear regression]] and
- * [[http://en.wikipedia.org/wiki/Gradient_descent Gradient descent algorithm]].
- *
- * This implementation works on one-dimensional data and finds the best two-dimensional theta to
- * fit the target.
- *
- * Input files are plain text files and must be formatted as follows:
- *
- *  - Data points are represented as two double values separated by a blank character. The first
- *    one represents X (the training data) and the second one represents Y (the target). Data
- *    points are separated by newline characters.
- *    For example `"-0.02 -0.04\n5.3 10.6\n"` gives two data points
- *    (x=-0.02, y=-0.04) and (x=5.3, y=10.6).
- *
- * This example shows how to use:
- *
- *  - Bulk iterations
- *  - Broadcast variables in bulk iterations
- */
-object LinearRegression {
-
-  def main(args: Array[String]): Unit = {
-
-    val params: ParameterTool = ParameterTool.fromArgs(args)
-
-    // set up execution environment
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    // make parameters available in the web interface
-    env.getConfig.setGlobalJobParameters(params)
-
-    // initial parameters, taken from the shared default data
-    val initialParams = LinearRegressionData.PARAMS.map {
-      case Array(t0, t1) => Params(t0.asInstanceOf[Double], t1.asInstanceOf[Double])
-    }
-    val parameters = env.fromCollection(initialParams)
-
-    // samples, either from the given file or from the shared default data
-    val data =
-      if (params.has("input")) {
-        val tuples = env.readCsvFile[(Double, Double)](
-          params.get("input"),
-          fieldDelimiter = " ",
-          includedFields = Array(0, 1))
-        tuples.map { t => Data(t._1, t._2) }
-      } else {
-        println("Executing LinearRegression example with default input data set.")
-        println("Use --input to specify file input.")
-        val defaultSamples = LinearRegressionData.DATA.map {
-          case Array(x, y) => Data(x.asInstanceOf[Double], y.asInstanceOf[Double])
-        }
-        env.fromCollection(defaultSamples)
-      }
-
-    val numIterations = params.getInt("iterations", 10)
-
-    val result = parameters.iterate(numIterations) { currentParameters =>
-      data
-        // one update step per sample, based on the current parameters
-        .map(new SubUpdate).withBroadcastSet(currentParameters, "parameters")
-        // sum up the parameters and the sample counts
-        .reduce { (left, right) => (left._1 + right._1, left._2 + right._2) }
-        // average the summed parameters
-        .map { summed => summed._1.div(summed._2) }
-    }
-
-    if (params.has("output")) {
-      result.writeAsText(params.get("output"))
-      env.execute("Scala Linear Regression example")
-    } else {
-      println("Printing result to stdout. Use --output to specify output path.")
-      result.print()
-    }
-  }
-
-  /**
-   * A simple data sample, x means the input, and y means the target.
-   */
-  case class Data(var x: Double, var y: Double)
-
-  /**
-   * A set of parameters -- theta0, theta1.
-   */
-  case class Params(theta0: Double, theta1: Double) {
-
-    /** Returns new parameters with both values divided by `a`. */
-    def div(a: Int): Params = Params(theta0 / a, theta1 / a)
-
-    /** Returns the component-wise sum of these parameters and `other`. */
-    def + (other: Params): Params = Params(theta0 + other.theta0, theta1 + other.theta1)
-  }
-
-  // *************************************************************************
-  //     USER FUNCTIONS
-  // *************************************************************************
-
-  /**
-   * Computes a single BGD-type update of all parameters for one data sample.
-   */
-  class SubUpdate extends RichMapFunction[Data, (Params, Int)] {
-
-    private var parameter: Params = null
-
-    /** Reads the first entry of the "parameters" broadcast variable. */
-    override def open(parameters: Configuration): Unit = {
-      val broadcast = getRuntimeContext.getBroadcastVariable[Params]("parameters").asScala
-      parameter = broadcast.head
-    }
-
-    def map(in: Data): (Params, Int) = {
-      // prediction error of the current parameters for this sample
-      val error = (parameter.theta0 + (parameter.theta1 * in.x)) - in.y
-      val theta0 = parameter.theta0 - 0.01 * error
-      val theta1 = parameter.theta1 - 0.01 * (error * in.x)
-      (Params(theta0, theta1), 1)
-    }
-  }
-}
diff --git a/flink-ml-examples/examples-streaming/pom.xml b/flink-ml-examples/examples-streaming/pom.xml
deleted file mode 100644
index 34e5399..0000000
--- a/flink-ml-examples/examples-streaming/pom.xml
+++ /dev/null
@@ -1,190 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<!--
-Licensed to the Apache Software Foundation (ASF) under one
-or more contributor license agreements. See the NOTICE file
-distributed with this work for additional information
-regarding copyright ownership. The ASF licenses this file
-to you under the Apache License, Version 2.0 (the
-"License"); you may not use this file except in compliance
-with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing,
-software distributed under the License is distributed on an
-"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-KIND, either express or implied. See the License for the
-specific language governing permissions and limitations
-under the License.
--->
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
-
- <modelVersion>4.0.0</modelVersion>
-
- <parent>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-ml-examples</artifactId>
- <version>0.1-SNAPSHOT</version>
- </parent>
-
- <artifactId>flink-ml-examples-streaming_${scala.binary.version}</artifactId>
- <name>Flink ML : Examples : Streaming</name>
-
- <packaging>jar</packaging>
-
- <dependencies>
- <!-- core dependencies -->
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-streaming-java_${scala.binary.version}</artifactId>
- <version>${flink.version}</version>
- </dependency>
-
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-streaming-scala_${scala.binary.version}</artifactId>
- <version>${flink.version}</version>
- </dependency>
-
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-clients_${scala.binary.version}</artifactId>
- <version>${flink.version}</version>
- </dependency>
-
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-connector-twitter_${scala.binary.version}</artifactId>
- <version>${flink.version}</version>
- </dependency>
-
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-connector-kafka_${scala.binary.version}</artifactId>
- <version>${flink.version}</version>
- </dependency>
-
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-shaded-jackson</artifactId>
- </dependency>
-
- <!-- test dependencies -->
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-test-utils_${scala.binary.version}</artifactId>
- <version>${flink.version}</version>
- <scope>test</scope>
- </dependency>
-
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-statebackend-rocksdb_${scala.binary.version}</artifactId>
- <version>${flink.version}</version>
- </dependency>
- </dependencies>
-
- <build>
- <plugins>
- <!-- Scala Code Style, most of the configuration done via plugin management -->
- <plugin>
- <groupId>org.scalastyle</groupId>
- <artifactId>scalastyle-maven-plugin</artifactId>
- <configuration>
- <configLocation>${project.basedir}/../../tools/maven/scalastyle-config.xml</configLocation>
- </configuration>
- </plugin>
-
- <!-- self-contained jars for each example -->
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-jar-plugin</artifactId>
- <version>2.4</version><!--$NO-MVN-MAN-VER$-->
- <executions>
- <!-- Default Execution -->
- <execution>
- <id>default</id>
- <phase>package</phase>
- <goals>
- <goal>test-jar</goal>
- </goals>
- </execution>
-
- <!-- IncrementalLearning -->
- <execution>
- <id>IncrementalLearning</id>
- <phase>package</phase>
- <goals>
- <goal>jar</goal>
- </goals>
- <configuration>
- <classifier>IncrementalLearning</classifier>
-
- <archive>
- <manifestEntries>
- <program-class>org.apache.flink.streaming.examples.ml.IncrementalLearningSkeleton</program-class>
- </manifestEntries>
- </archive>
-
- <includes>
- <include>org/apache/flink/streaming/examples/ml/*.class</include>
- <include>META-INF/LICENSE</include>
- <include>META-INF/NOTICE</include>
- </includes>
- </configuration>
- </execution>
- </executions>
- </plugin>
-
- <!-- Scala Compiler -->
- <plugin>
- <groupId>net.alchim31.maven</groupId>
- <artifactId>scala-maven-plugin</artifactId>
- <executions>
- <!-- Run scala compiler in the process-resources phase, so that dependencies on
- scala classes can be resolved later in the (Java) compile phase -->
- <execution>
- <id>scala-compile-first</id>
- <phase>process-resources</phase>
- <goals>
- <goal>compile</goal>
- </goals>
- </execution>
-
- <!-- Run scala compiler in the process-test-resources phase, so that dependencies on
- scala classes can be resolved later in the (Java) test-compile phase -->
- <execution>
- <id>scala-test-compile</id>
- <phase>process-test-resources</phase>
- <goals>
- <goal>testCompile</goal>
- </goals>
- </execution>
- </executions>
- <configuration>
- <jvmArgs>
- <jvmArg>-Xms128m</jvmArg>
- <jvmArg>-Xmx512m</jvmArg>
- </jvmArgs>
- </configuration>
- </plugin>
-
- <!--simplify the name of example JARs for build-target/examples -->
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-antrun-plugin</artifactId>
- <executions>
- <execution>
- <id>rename</id>
- <configuration>
- <target>
- <copy file="${project.basedir}/target/flink-ml-examples-streaming_${scala.binary.version}-${project.version}-IncrementalLearning.jar" tofile="${project.basedir}/target/IncrementalLearning.jar" />
- </target>
- </configuration>
- </execution>
- </executions>
- </plugin>
- </plugins>
- </build>
-
-</project>
diff --git a/flink-ml-examples/examples-streaming/src/main/java/org/apache/flink/streaming/examples/ml/IncrementalLearningSkeleton.java b/flink-ml-examples/examples-streaming/src/main/java/org/apache/flink/streaming/examples/ml/IncrementalLearningSkeleton.java
deleted file mode 100644
index 4f55f61..0000000
--- a/flink-ml-examples/examples-streaming/src/main/java/org/apache/flink/streaming/examples/ml/IncrementalLearningSkeleton.java
+++ /dev/null
@@ -1,212 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.examples.ml;
-
-import org.apache.flink.api.java.utils.ParameterTool;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks;
-import org.apache.flink.streaming.api.functions.co.CoMapFunction;
-import org.apache.flink.streaming.api.functions.source.SourceFunction;
-import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
-import org.apache.flink.streaming.api.watermark.Watermark;
-import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows;
-import org.apache.flink.streaming.api.windowing.time.Time;
-import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
-import org.apache.flink.util.Collector;
-
-/**
- * Skeleton for incremental machine learning algorithm consisting of a pre-computed model, which
- * gets updated for the new inputs and new input data for which the job provides predictions.
- *
- * <p>This may serve as a base of a number of algorithms, e.g. updating an incremental Alternating
- * Least Squares model while also providing the predictions.
- *
- * <p>This example shows how to use:
- *
- * <ul>
- * <li>Connected streams
- * <li>CoFunctions
- * <li>Tuple data types
- * </ul>
- */
-public class IncrementalLearningSkeleton {
-
- // *************************************************************************
- // PROGRAM
- // *************************************************************************
-
- public static void main(String[] args) throws Exception {
-
- // Checking input parameters
- final ParameterTool params = ParameterTool.fromArgs(args);
-
- StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
-
- DataStream<Integer> trainingData = env.addSource(new FiniteTrainingDataSource());
- DataStream<Integer> newData = env.addSource(new FiniteNewDataSource());
-
- // build a new partial model for every 5000 ms event-time window of training data
- DataStream<Double[]> model =
- trainingData
- .assignTimestampsAndWatermarks(new LinearTimestamp())
- .windowAll(TumblingEventTimeWindows.of(Time.milliseconds(5000)))
- .apply(new PartialModelBuilder());
-
- // use partial model for newData
- DataStream<Integer> prediction = newData.connect(model).map(new Predictor());
-
- // emit result
- if (params.has("output")) {
- prediction.writeAsText(params.get("output"));
- } else {
- System.out.println("Printing result to stdout. Use --output to specify output path.");
- prediction.print();
- }
-
- // execute program
- env.execute("Streaming Incremental Learning");
- }
-
- // *************************************************************************
- // USER FUNCTIONS
- // *************************************************************************
-
- /**
- * Feeds new data for newData. By default it is implemented as constantly emitting the Integer 1
- * in a loop.
- */
- public static class FiniteNewDataSource implements SourceFunction<Integer> {
- private static final long serialVersionUID = 1L;
- private int counter;
-
- @Override
- public void run(SourceContext<Integer> ctx) throws Exception {
- Thread.sleep(15);
- while (counter < 50) {
- ctx.collect(getNewData());
- }
- }
-
- @Override
- public void cancel() {
- // No cleanup needed
- }
-
- private Integer getNewData() throws InterruptedException {
- Thread.sleep(5);
- counter++;
- return 1;
- }
- }
-
- /**
- * Feeds new training data for the partial model builder. By default it is implemented as
- * constantly emitting the Integer 1 in a loop.
- */
- public static class FiniteTrainingDataSource implements SourceFunction<Integer> {
- private static final long serialVersionUID = 1L;
- private int counter = 0;
-
- @Override
- public void run(SourceContext<Integer> collector) throws Exception {
- while (counter < 8200) {
- collector.collect(getTrainingData());
- }
- }
-
- @Override
- public void cancel() {
- // No cleanup needed
- }
-
- private Integer getTrainingData() throws InterruptedException {
- counter++;
- return 1;
- }
- }
-
- private static class LinearTimestamp implements AssignerWithPunctuatedWatermarks<Integer> {
- private static final long serialVersionUID = 1L;
-
- private long counter = 0L;
-
- @Override
- public long extractTimestamp(Integer element, long previousElementTimestamp) {
- return counter += 10L;
- }
-
- @Override
- public Watermark checkAndGetNextWatermark(Integer lastElement, long extractedTimestamp) {
- return new Watermark(counter - 1);
- }
- }
-
- /** Builds up-to-date partial models on new training data. */
- public static class PartialModelBuilder
- implements AllWindowFunction<Integer, Double[], TimeWindow> {
- private static final long serialVersionUID = 1L;
-
- protected Double[] buildPartialModel(Iterable<Integer> values) {
- return new Double[] {1.};
- }
-
- @Override
- public void apply(TimeWindow window, Iterable<Integer> values, Collector<Double[]> out)
- throws Exception {
- out.collect(buildPartialModel(values));
- }
- }
-
- /**
- * Creates predictions for newData using the model produced in batch-processing and the
- * up-to-date partial model.
- *
- * <p>By default emits the Integer 0 for every newData and the Integer 1 for every model update.
- */
- public static class Predictor implements CoMapFunction<Integer, Double[], Integer> {
- private static final long serialVersionUID = 1L;
-
- Double[] batchModel = null;
- Double[] partialModel = null;
-
- @Override
- public Integer map1(Integer value) {
- // Return newData
- return predict(value);
- }
-
- @Override
- public Integer map2(Double[] value) {
- // Update model
- partialModel = value;
- batchModel = getBatchModel();
- return 1;
- }
-
- // pulls model built with batch-job on the old training data
- protected Double[] getBatchModel() {
- return new Double[] {0.};
- }
-
- // performs prediction on newData using the two models
- protected Integer predict(Integer inTuple) {
- return 0;
- }
- }
-}
diff --git a/flink-ml-examples/examples-streaming/src/main/java/org/apache/flink/streaming/examples/ml/util/IncrementalLearningSkeletonData.java b/flink-ml-examples/examples-streaming/src/main/java/org/apache/flink/streaming/examples/ml/util/IncrementalLearningSkeletonData.java
deleted file mode 100644
index eaf5bdb..0000000
--- a/flink-ml-examples/examples-streaming/src/main/java/org/apache/flink/streaming/examples/ml/util/IncrementalLearningSkeletonData.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.examples.ml.util;
-
-/** Expected results of IncrementalLearningSkeleton, checked by StreamingExamplesITCase. */
-public class IncrementalLearningSkeletonData {
-
- public static final String RESULTS =
- "1\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n"
- + "1\n" + "1\n" + "1\n" + "1\n" + "1\n" + "1\n" + "0\n" + "0\n" + "0\n" + "0\n"
- + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n"
- + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n"
- + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n"
- + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n"
- + "0\n" + "0\n" + "0\n" + "0\n" + "0\n" + "0\n";
-
- private IncrementalLearningSkeletonData() {}
-}
diff --git a/flink-ml-examples/examples-streaming/src/main/scala/org/apache/flink/streaming/scala/examples/ml/IncrementalLearningSkeleton.scala b/flink-ml-examples/examples-streaming/src/main/scala/org/apache/flink/streaming/scala/examples/ml/IncrementalLearningSkeleton.scala
deleted file mode 100644
index 2f1e168..0000000
--- a/flink-ml-examples/examples-streaming/src/main/scala/org/apache/flink/streaming/scala/examples/ml/IncrementalLearningSkeleton.scala
+++ /dev/null
@@ -1,184 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.scala.examples.ml
-
-import org.apache.flink.api.java.utils.ParameterTool
-import org.apache.flink.api.scala._
-import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks
-import org.apache.flink.streaming.api.functions.co.CoMapFunction
-import org.apache.flink.streaming.api.functions.source.SourceFunction
-import org.apache.flink.streaming.api.functions.source.SourceFunction.SourceContext
-import org.apache.flink.streaming.api.scala.function.AllWindowFunction
-import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
-import org.apache.flink.streaming.api.watermark.Watermark
-import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindows
-import org.apache.flink.streaming.api.windowing.time.Time
-import org.apache.flink.streaming.api.windowing.windows.TimeWindow
-import org.apache.flink.util.Collector
-
-/**
- * Skeleton for incremental machine learning algorithm consisting of a
- * pre-computed model, which gets updated for the new inputs and new input data
- * for which the job provides predictions.
- *
- * This may serve as a base of a number of algorithms, e.g. updating an
- * incremental Alternating Least Squares model while also providing the
- * predictions.
- *
- * This example shows how to use:
- *
- * - Connected streams
- * - CoFunctions
- * - Tuple data types
- *
- */
-object IncrementalLearningSkeleton {
-
- // *************************************************************************
- // PROGRAM
- // *************************************************************************
-
- def main(args: Array[String]): Unit = {
- // Checking input parameters
- val params = ParameterTool.fromArgs(args)
-
- // set up the execution environment
- val env = StreamExecutionEnvironment.getExecutionEnvironment
-
- // set up the sources; a new model is built for every 5000 ms window of training data
- val trainingData: DataStream[Integer] = env.addSource(new FiniteTrainingDataSource)
- val newData: DataStream[Integer] = env.addSource(new FiniteNewDataSource)
-
- val model: DataStream[Array[java.lang.Double]] = trainingData
- .assignTimestampsAndWatermarks(new LinearTimestamp)
- .windowAll(TumblingEventTimeWindows.of(Time.milliseconds(5000)))
- .apply(new PartialModelBuilder)
-
- // use partial model for newData
- val prediction: DataStream[Integer] = newData.connect(model).map(new Predictor)
-
- // emit result
- if (params.has("output")) {
- prediction.writeAsText(params.get("output"))
- } else {
- println("Printing result to stdout. Use --output to specify output path.")
- prediction.print()
- }
-
- // execute program
- env.execute("Streaming Incremental Learning")
- }
-
- // *************************************************************************
- // USER FUNCTIONS
- // *************************************************************************
-
- /**
- * Feeds new data for newData. By default it is implemented as constantly
- * emitting the Integer 1 in a loop.
- */
- private class FiniteNewDataSource extends SourceFunction[Integer] {
- override def run(ctx: SourceContext[Integer]) = {
- Thread.sleep(15)
- (0 until 50).foreach{ _ =>
- Thread.sleep(5)
- ctx.collect(1)
- }
- }
-
- override def cancel() = {
- // No cleanup needed
- }
- }
-
- /**
- * Feeds new training data for the partial model builder. By default it is
- * implemented as constantly emitting the Integer 1 in a loop.
- */
- private class FiniteTrainingDataSource extends SourceFunction[Integer] {
- override def run(ctx: SourceContext[Integer]) =
- (0 until 8200).foreach( _ => ctx.collect(1) )
-
- override def cancel() = {
- // No cleanup needed
- }
- }
-
- private class LinearTimestamp extends AssignerWithPunctuatedWatermarks[Integer] {
- var counter = 0L
-
- override def extractTimestamp(element: Integer, previousElementTimestamp: Long): Long = {
- counter += 10L
- counter
- }
-
- override def checkAndGetNextWatermark(lastElement: Integer, extractedTimestamp: Long) = {
- new Watermark(counter - 1)
- }
- }
-
- /**
- * Builds up-to-date partial models on new training data.
- */
- private class PartialModelBuilder
- extends AllWindowFunction[Integer, Array[java.lang.Double], TimeWindow] {
-
- protected def buildPartialModel(values: Iterable[Integer]): Array[java.lang.Double] =
- Array[java.lang.Double](1)
-
- override def apply(window: TimeWindow,
- values: Iterable[Integer],
- out: Collector[Array[java.lang.Double]]): Unit = {
- out.collect(buildPartialModel(values))
- }
- }
-
- /**
- * Creates predictions for newData using the model produced in batch-processing
- * and the up-to-date partial model.
- *
- * By default emits the Integer 0 for every newData and the Integer 1
- * for every model update.
- *
- */
- private class Predictor extends CoMapFunction[Integer, Array[java.lang.Double], Integer] {
-
- var batchModel: Array[java.lang.Double] = null
- var partialModel: Array[java.lang.Double] = null
-
- override def map1(value: Integer): Integer = {
- // Return newData
- predict(value)
- }
-
- override def map2(value: Array[java.lang.Double]): Integer = {
- // Update model
- partialModel = value
- batchModel = getBatchModel()
- 1
- }
-
- // pulls model built with batch-job on the old training data
- protected def getBatchModel(): Array[java.lang.Double] = Array[java.lang.Double](0)
-
- // performs prediction on newData using the two models
- protected def predict(inTuple: Int): Int = 0
- }
-
-}
diff --git a/flink-ml-examples/examples-streaming/src/test/java/org/apache/flink/streaming/test/StreamingExamplesITCase.java b/flink-ml-examples/examples-streaming/src/test/java/org/apache/flink/streaming/test/StreamingExamplesITCase.java
deleted file mode 100644
index 6f7bf5d..0000000
--- a/flink-ml-examples/examples-streaming/src/test/java/org/apache/flink/streaming/test/StreamingExamplesITCase.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.test;
-
-import org.apache.flink.streaming.examples.ml.util.IncrementalLearningSkeletonData;
-import org.apache.flink.test.util.AbstractTestBase;
-
-import org.junit.Test;
-
-/** Integration test for streaming programs in Java examples. */
-public class StreamingExamplesITCase extends AbstractTestBase {
-
- @Test
- public void testIncrementalLearningSkeleton() throws Exception {
- final String resultPath = getTempDirPath("result");
- org.apache.flink.streaming.examples.ml.IncrementalLearningSkeleton.main(
- new String[] {"--output", resultPath});
- compareResultsByLinesInMemory(IncrementalLearningSkeletonData.RESULTS, resultPath);
- }
-}
diff --git a/flink-ml-examples/examples-streaming/src/test/scala/org/apache/flink/streaming/scala/examples/StreamingExamplesITCase.scala b/flink-ml-examples/examples-streaming/src/test/scala/org/apache/flink/streaming/scala/examples/StreamingExamplesITCase.scala
deleted file mode 100644
index 7fcbe9d..0000000
--- a/flink-ml-examples/examples-streaming/src/test/scala/org/apache/flink/streaming/scala/examples/StreamingExamplesITCase.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.scala.examples
-
-import org.apache.flink.streaming.examples.ml.util.IncrementalLearningSkeletonData
-import org.apache.flink.streaming.scala.examples.ml.IncrementalLearningSkeleton
-import org.apache.flink.test.util.{AbstractTestBase, TestBaseUtils}
-import org.junit.Test
-
-/**
- * Integration test for streaming programs in Scala examples.
- */
-class StreamingExamplesITCase extends AbstractTestBase {
-
- @Test
- def testIncrementalLearningSkeleton(): Unit = {
- val resultPath = getTempDirPath("result")
- IncrementalLearningSkeleton.main(Array("--output", resultPath))
- TestBaseUtils.compareResultsByLinesInMemory(IncrementalLearningSkeletonData.RESULTS, resultPath)
- }
-}
diff --git a/flink-ml-examples/pom.xml b/flink-ml-examples/pom.xml
deleted file mode 100644
index e223fdd..0000000
--- a/flink-ml-examples/pom.xml
+++ /dev/null
@@ -1,111 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<!--
-Licensed to the Apache Software Foundation (ASF) under one
-or more contributor license agreements. See the NOTICE file
-distributed with this work for additional information
-regarding copyright ownership. The ASF licenses this file
-to you under the Apache License, Version 2.0 (the
-"License"); you may not use this file except in compliance
-with the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing,
-software distributed under the License is distributed on an
-"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-KIND, either express or implied. See the License for the
-specific language governing permissions and limitations
-under the License.
--->
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
-
- <modelVersion>4.0.0</modelVersion>
-
- <parent>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-ml-parent</artifactId>
- <version>0.1-SNAPSHOT</version>
- </parent>
-
- <artifactId>flink-ml-examples</artifactId>
- <name>Flink ML : Examples :</name>
- <packaging>pom</packaging>
-
- <modules>
- <module>examples-streaming</module>
- <module>examples-batch</module>
- </modules>
-
- <dependencies>
- <!-- Flink dependencies -->
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-core</artifactId>
- <version>${flink.version}</version>
- </dependency>
-
- <!-- Add a logging Framework, to make the examples produce -->
- <!-- logs when executing in the IDE -->
- <dependency>
- <groupId>org.apache.logging.log4j</groupId>
- <artifactId>log4j-slf4j-impl</artifactId>
- <scope>compile</scope>
- </dependency>
-
- <dependency>
- <groupId>org.apache.logging.log4j</groupId>
- <artifactId>log4j-api</artifactId>
- <scope>compile</scope>
- </dependency>
-
- <dependency>
- <groupId>org.apache.logging.log4j</groupId>
- <artifactId>log4j-core</artifactId>
- <scope>compile</scope>
- </dependency>
-
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-test-utils-junit</artifactId>
- </dependency>
- </dependencies>
-
- <build>
- <pluginManagement>
- <plugins>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-antrun-plugin</artifactId>
- <version>1.7</version>
- <executions>
- <execution>
- <id>rename</id>
- <phase>package</phase>
- <goals>
- <goal>run</goal>
- </goals>
- </execution>
- </executions>
- </plugin>
- </plugins>
- </pluginManagement>
-
- <plugins>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-enforcer-plugin</artifactId>
- <executions>
- <execution>
- <id>dependency-convergence</id>
- <goals>
- <goal>enforce</goal>
- </goals>
- <configuration>
- <skip>true</skip>
- </configuration>
- </execution>
- </executions>
- </plugin>
- </plugins>
- </build>
-</project>
diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml
index bee7f25..883ae81 100644
--- a/flink-ml-lib/pom.xml
+++ b/flink-ml-lib/pom.xml
@@ -36,6 +36,7 @@ under the License.
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
+
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-ml-iteration</artifactId>
@@ -48,24 +49,21 @@ under the License.
<version>${flink.version}</version>
<scope>provided</scope>
</dependency>
+
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-api-java-bridge_${scala.binary.version}</artifactId>
<version>${flink.version}</version>
<scope>provided</scope>
</dependency>
+
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-planner_${scala.binary.version}</artifactId>
<version>${flink.version}</version>
<scope>test</scope>
</dependency>
- <dependency>
- <groupId>org.apache.flink</groupId>
- <artifactId>flink-clients_${scala.binary.version}</artifactId>
- <version>${flink.version}</version>
- <scope>provided</scope>
- </dependency>
+
<dependency>
<groupId>com.github.fommil.netlib</groupId>
<artifactId>core</artifactId>
@@ -106,20 +104,4 @@ under the License.
</dependency>
</dependencies>
- <build>
- <plugins>
- <!-- Because PyFlink uses it in tests -->
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-jar-plugin</artifactId>
- <executions>
- <execution>
- <goals>
- <goal>test-jar</goal>
- </goals>
- </execution>
- </executions>
- </plugin>
- </plugins>
- </build>
</project>
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/MLEnvironment.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/MLEnvironment.java
deleted file mode 100644
index aea7a67..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/MLEnvironment.java
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common;
-
-import org.apache.flink.api.java.ExecutionEnvironment;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.EnvironmentSettings;
-import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-
-/**
- * The MLEnvironment stores the necessary context in Flink. Each MLEnvironment will be associated
- * with a unique ID. The operations associated with the same MLEnvironment ID will share the same
- * Flink job context.
- *
- * <p>Both MLEnvironment ID and MLEnvironment can only be retrieved from MLEnvironmentFactory.
- *
- * @see ExecutionEnvironment
- * @see StreamExecutionEnvironment
- * @see StreamTableEnvironment
- */
-public class MLEnvironment {
- private StreamExecutionEnvironment streamEnv;
- private StreamTableEnvironment streamTableEnv;
-
- /** Constructs with null environments, which are lazily created by the getter methods. */
- public MLEnvironment() {
- this(null, null);
- }
-
- /**
- * Constructs with the stream environment and the stream table environment given by the user.
- *
- * <p>Either env may be null, in which case it is created on first access by its getter.
- *
- * @param streamEnv the StreamExecutionEnvironment
- * @param streamTableEnv the StreamTableEnvironment
- */
- public MLEnvironment(
- StreamExecutionEnvironment streamEnv, StreamTableEnvironment streamTableEnv) {
- this.streamEnv = streamEnv;
- this.streamTableEnv = streamTableEnv;
- }
-
- /**
- * Gets the StreamExecutionEnvironment. If the StreamExecutionEnvironment has not been set, it
- * is initialized from StreamExecutionEnvironment.getExecutionEnvironment().
- *
- * @return the {@link StreamExecutionEnvironment}
- */
- public StreamExecutionEnvironment getStreamExecutionEnvironment() {
- if (null == streamEnv) {
- streamEnv = StreamExecutionEnvironment.getExecutionEnvironment();
- }
- return streamEnv;
- }
-
- /**
- * Gets the StreamTableEnvironment. If the StreamTableEnvironment has not been set, it is
- * created on top of the StreamExecutionEnvironment with default EnvironmentSettings.
- *
- * @return the {@link StreamTableEnvironment}
- */
- public StreamTableEnvironment getStreamTableEnvironment() {
- if (null == streamTableEnv) {
- streamTableEnv =
- StreamTableEnvironment.create(
- getStreamExecutionEnvironment(),
- EnvironmentSettings.newInstance().build());
- }
- return streamTableEnv;
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/MLEnvironmentFactory.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/MLEnvironmentFactory.java
deleted file mode 100644
index 8d2526c..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/MLEnvironmentFactory.java
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common;
-
-import org.apache.flink.util.Preconditions;
-
-import java.util.HashMap;
-
-/**
- * Factory to get the MLEnvironment using a MLEnvironmentId.
- *
- * <p>The following code snippet shows how to interact with MLEnvironmentFactory.
- *
- * <pre>{@code
- * long mlEnvId = MLEnvironmentFactory.getNewMLEnvironmentId();
- * MLEnvironment mlEnv = MLEnvironmentFactory.get(mlEnvId);
- * }</pre>
- */
-public class MLEnvironmentFactory {
-
- /** The default MLEnvironmentId. */
- public static final Long DEFAULT_ML_ENVIRONMENT_ID = 0L;
-
- /**
- * A monotonically increasing id for the MLEnvironments. Each id uniquely identifies an
- * MLEnvironment.
- */
- private static Long nextId = 1L;
-
- /** Map that holds the MLEnvironments, keyed by MLEnvironmentId. */
- private static final HashMap<Long, MLEnvironment> map = new HashMap<>();
-
- static {
- map.put(DEFAULT_ML_ENVIRONMENT_ID, new MLEnvironment());
- }
-
- /**
- * Get the MLEnvironment using a MLEnvironmentId.
- *
- * @param mlEnvId the MLEnvironmentId
- * @return the MLEnvironment
- */
- public static synchronized MLEnvironment get(Long mlEnvId) {
- if (!map.containsKey(mlEnvId)) {
- throw new IllegalArgumentException(
- String.format(
- "Cannot find MLEnvironment for MLEnvironmentId %s."
- + " Did you get the MLEnvironmentId by calling getNewMLEnvironmentId?",
- mlEnvId));
- }
-
- return map.get(mlEnvId);
- }
-
- /**
- * Get the MLEnvironment registered under the default MLEnvironmentId.
- *
- * @return the default MLEnvironment.
- */
- public static synchronized MLEnvironment getDefault() {
- return get(DEFAULT_ML_ENVIRONMENT_ID);
- }
-
- /**
- * Create a unique MLEnvironment id and register a new MLEnvironment in the factory.
- *
- * @return the MLEnvironment id.
- */
- public static synchronized Long getNewMLEnvironmentId() {
- return registerMLEnvironment(new MLEnvironment());
- }
-
- /**
- * Register a new MLEnvironment to the factory and return a new MLEnvironment id.
- *
- * @param env the MLEnvironment that will be stored in the factory.
- * @return the MLEnvironment id.
- */
- public static synchronized Long registerMLEnvironment(MLEnvironment env) {
- map.put(nextId, env);
- return nextId++;
- }
-
- /**
- * Remove the MLEnvironment using the MLEnvironmentId.
- *
- * @param mlEnvId the id.
- * @return the removed MLEnvironment
- */
- public static synchronized MLEnvironment remove(Long mlEnvId) {
- Preconditions.checkNotNull(mlEnvId, "The environment id cannot be null.");
- // Never remove the default MLEnvironment. Just return the default environment.
- if (DEFAULT_ML_ENVIRONMENT_ID.equals(mlEnvId)) {
- return getDefault();
- } else {
- return map.remove(mlEnvId);
- }
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java
deleted file mode 100644
index aca2f23..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java
+++ /dev/null
@@ -1,234 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import org.apache.flink.util.Preconditions;
-
-/**
- * A utility class that provides BLAS routines over matrices and vectors.
- *
- * <p>The level-1 routines (asum, axpy, scal) are served by {@code F2J_BLAS}; the dense level-2 and
- * level-3 routines (gemv, gemm) are served by {@code NATIVE_BLAS}. Matrices are passed to the
- * underlying BLAS with their row count as leading dimension.
- */
-public class BLAS {
-
- /**
- * For level-2 and level-3 routines, we use the native BLAS.
- *
- * <p>The NATIVE_BLAS instance tries to load BLAS implementations in the order: 1) optimized
- * system libraries such as Intel MKL, 2) self-contained native builds using the reference
- * Fortran from netlib.org, 3) F2J implementation. When using optimized system libraries, it is
- * important to turn off their multi-thread support. Otherwise, it will conflict with Flink's
- * executor and lead to performance loss.
- */
- private static final com.github.fommil.netlib.BLAS NATIVE_BLAS =
- com.github.fommil.netlib.BLAS.getInstance();
-
- /** For level-1 routines, we use the Java (F2J) implementation. */
- private static final com.github.fommil.netlib.BLAS F2J_BLAS =
- com.github.fommil.netlib.F2jBLAS.getInstance();
-
- /**
- * Compute \sum_i |x_i| over n consecutive elements of x.
- *
- * @param n the number of elements to sum.
- * @param x the array holding the elements.
- * @param offset the index of the first element in x.
- * @return the sum of the absolute values.
- */
- public static double asum(int n, double[] x, int offset) {
- return F2J_BLAS.dasum(n, x, offset, 1);
- }
-
- /** Compute \sum_i |x_i| over all elements of the dense vector. */
- public static double asum(DenseVector x) {
- return asum(x.data.length, x.data, 0);
- }
-
- /** Compute \sum_i |x_i| over the stored values of the sparse vector. */
- public static double asum(SparseVector x) {
- return asum(x.values.length, x.values, 0);
- }
-
- /**
- * y += a * x .
- *
- * <p>The two arrays must have the same length; this is verified with
- * {@code Preconditions.checkArgument}.
- */
- public static void axpy(double a, double[] x, double[] y) {
- Preconditions.checkArgument(x.length == y.length, "Array dimension mismatched.");
- F2J_BLAS.daxpy(x.length, a, x, 1, y, 1);
- }
-
- /**
- * y += a * x .
- *
- * <p>The two vectors must have the same size; this is verified with
- * {@code Preconditions.checkArgument}.
- */
- public static void axpy(double a, DenseVector x, DenseVector y) {
- Preconditions.checkArgument(x.data.length == y.data.length, "Vector dimension mismatched.");
- F2J_BLAS.daxpy(x.data.length, a, x.data, 1, y.data, 1);
- }
-
- /**
- * y += a * x , where only the stored entries of the sparse x are visited.
- *
- * <p>NOTE(review): unlike the dense overloads there is no size check here; an index of x that
- * lies outside y surfaces as an ArrayIndexOutOfBoundsException from the array access.
- */
- public static void axpy(double a, SparseVector x, DenseVector y) {
- for (int i = 0; i < x.indices.length; i++) {
- y.data[x.indices[i]] += a * x.values[i];
- }
- }
-
- /**
- * y += a * x , applied element-wise to the backing arrays of the two matrices.
- *
- * <p>Both matrices must have the same number of rows and columns; this is verified with
- * {@code Preconditions.checkArgument}.
- */
- public static void axpy(double a, DenseMatrix x, DenseMatrix y) {
- Preconditions.checkArgument(x.m == y.m && x.n == y.n, "Matrix dimension mismatched.");
- F2J_BLAS.daxpy(x.data.length, a, x.data, 1, y.data, 1);
- }
-
- /**
- * y[yOffset:yOffset+n] += a * x[xOffset:xOffset+n] .
- *
- * <p>No bounds are checked here; the ranges are handed to the underlying daxpy as given.
- */
- public static void axpy(int n, double a, double[] x, int xOffset, double[] y, int yOffset) {
- F2J_BLAS.daxpy(n, a, x, xOffset, 1, y, yOffset, 1);
- }
-
- /**
- * x \cdot y .
- *
- * <p>Implemented as a plain Java loop rather than a BLAS call. The two arrays must have the
- * same length; this is verified with {@code Preconditions.checkArgument}.
- */
- public static double dot(double[] x, double[] y) {
- Preconditions.checkArgument(x.length == y.length, "Array dimension mismatched.");
- double s = 0.;
- for (int i = 0; i < x.length; i++) {
- s += x[i] * y[i];
- }
- return s;
- }
-
- /** x \cdot y , computed on the data arrays of the two dense vectors. */
- public static double dot(DenseVector x, DenseVector y) {
- return dot(x.getData(), y.getData());
- }
-
- /** x = x * a , in place. */
- public static void scal(double a, double[] x) {
- F2J_BLAS.dscal(x.length, a, x, 1);
- }
-
- /** x = x * a , in place on the vector's data array. */
- public static void scal(double a, DenseVector x) {
- F2J_BLAS.dscal(x.data.length, a, x.data, 1);
- }
-
- /** x = x * a , in place on the stored values of the sparse vector. */
- public static void scal(double a, SparseVector x) {
- F2J_BLAS.dscal(x.values.length, a, x.values, 1);
- }
-
- /** x = x * a , in place on the matrix's data array. */
- public static void scal(double a, DenseMatrix x) {
- F2J_BLAS.dscal(x.data.length, a, x.data, 1);
- }
-
- /**
- * C := alpha * op(A) * op(B) + beta * C , where op(X) is X or its transpose depending on the
- * corresponding flag. The result is written into matC.
- *
- * <p>The outer dimensions of op(A) and op(B) are checked against C with
- * {@code Preconditions.checkArgument}.
- *
- * <p>NOTE(review): the inner dimensions of op(A) and op(B) are not compared with each other; k
- * is taken from A alone, so a mismatching B is passed on to dgemm unchecked.
- *
- * @param alpha scalar applied to the product.
- * @param matA the left operand.
- * @param transA whether to use the transpose of matA.
- * @param matB the right operand.
- * @param transB whether to use the transpose of matB.
- * @param beta scalar applied to the previous content of matC.
- * @param matC the matrix that is read and overwritten with the result.
- */
- public static void gemm(
- double alpha,
- DenseMatrix matA,
- boolean transA,
- DenseMatrix matB,
- boolean transB,
- double beta,
- DenseMatrix matC) {
- if (transA) {
- Preconditions.checkArgument(
- matA.numCols() == matC.numRows(),
- "The columns of A does not match the rows of C");
- } else {
- Preconditions.checkArgument(
- matA.numRows() == matC.numRows(), "The rows of A does not match the rows of C");
- }
- if (transB) {
- Preconditions.checkArgument(
- matB.numRows() == matC.numCols(),
- "The rows of B does not match the columns of C");
- } else {
- Preconditions.checkArgument(
- matB.numCols() == matC.numCols(),
- "The columns of B does not match the columns of C");
- }
-
- final int m = matC.numRows();
- final int n = matC.numCols();
- final int k = transA ? matA.numRows() : matA.numCols();
- final int lda = matA.numRows();
- final int ldb = matB.numRows();
- final int ldc = matC.numRows();
- final String ta = transA ? "T" : "N";
- final String tb = transB ? "T" : "N";
- NATIVE_BLAS.dgemm(
- ta,
- tb,
- m,
- n,
- k,
- alpha,
- matA.getData(),
- lda,
- matB.getData(),
- ldb,
- beta,
- matC.getData(),
- ldc);
- }
-
- /**
- * Check the compatibility of matrix and vector sizes in <code>gemv</code>.
- *
- * <p>Without transposition y must have as many entries as A has rows and x as many as A has
- * columns; with transposition the two roles are exchanged.
- */
- private static void gemvDimensionCheck(DenseMatrix matA, boolean transA, Vector x, Vector y) {
- if (transA) {
- Preconditions.checkArgument(
- matA.numCols() == y.size() && matA.numRows() == x.size(),
- "Matrix and vector size mismatched.");
- } else {
- Preconditions.checkArgument(
- matA.numRows() == y.size() && matA.numCols() == x.size(),
- "Matrix and vector size mismatched.");
- }
- }
-
- /**
- * y := alpha * op(A) * x + beta * y for a dense x, where op(A) is A or its transpose depending
- * on transA. The result is written into y.
- *
- * <p>Sizes are validated by gemvDimensionCheck before the call is forwarded to dgemv.
- */
- public static void gemv(
- double alpha,
- DenseMatrix matA,
- boolean transA,
- DenseVector x,
- double beta,
- DenseVector y) {
- gemvDimensionCheck(matA, transA, x, y);
- final int m = matA.numRows();
- final int n = matA.numCols();
- final int lda = matA.numRows();
- final String ta = transA ? "T" : "N";
- NATIVE_BLAS.dgemv(
- ta, m, n, alpha, matA.getData(), lda, x.getData(), 1, beta, y.getData(), 1);
- }
-
- /**
- * y := alpha * op(A) * x + beta * y for a sparse x, where op(A) is A or its transpose depending
- * on transA. The result is written into y.
- *
- * <p>With transposition, entry i of y is updated from the sum of the stored values of x
- * multiplied by the matching entries of column i of A. Without transposition, y is first scaled
- * by beta and then, for every stored entry of x, the corresponding column of A scaled by
- * alpha times that entry is added to y.
- */
- public static void gemv(
- double alpha,
- DenseMatrix matA,
- boolean transA,
- SparseVector x,
- double beta,
- DenseVector y) {
- gemvDimensionCheck(matA, transA, x, y);
- final int m = matA.numRows();
- final int n = matA.numCols();
- if (transA) {
- int start = 0;
- for (int i = 0; i < n; i++) {
- double s = 0.;
- for (int j = 0; j < x.indices.length; j++) {
- s += x.values[j] * matA.data[start + x.indices[j]];
- }
- y.data[i] = beta * y.data[i] + alpha * s;
- start += m;
- }
- } else {
- scal(beta, y);
- for (int i = 0; i < x.indices.length; i++) {
- int index = x.indices[i];
- double value = alpha * x.values[i];
- F2J_BLAS.daxpy(m, value, matA.data, index * m, 1, y.data, 0, 1);
- }
- }
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseMatrix.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseMatrix.java
deleted file mode 100644
index 643416c..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseMatrix.java
+++ /dev/null
@@ -1,577 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import java.io.Serializable;
-import java.util.Arrays;
-
-/**
- * DenseMatrix stores dense matrix data and provides some methods to operate on the matrix it
- * represents.
- *
- * <p>The elements live in a single array in column major order, i.e. element (i, j) is found at
- * position {@code j * m + i}.
- */
-public class DenseMatrix implements Serializable {
-
- /**
- * Row dimension.
- *
- * <p>Package private to allow access from {@link MatVecOp} and {@link BLAS}.
- */
- int m;
-
- /**
- * Column dimension.
- *
- * <p>Package private to allow access from {@link MatVecOp} and {@link BLAS}.
- */
- int n;
-
- /**
- * Array for internal storage of elements.
- *
- * <p>Package private to allow access from {@link MatVecOp} and {@link BLAS}.
- *
- * <p>The matrix data is stored in column major format internally.
- */
- double[] data;
-
- /**
- * Construct an m-by-n matrix of zeros.
- *
- * @param m Number of rows.
- * @param n Number of columns.
- */
- public DenseMatrix(int m, int n) {
- this(m, n, new double[m * n], false);
- }
-
- /**
- * Construct a matrix from a 1-D array. The data in the array should be organized in column
- * major. The array is adopted as the matrix storage without being copied.
- *
- * @param m Number of rows.
- * @param n Number of cols.
- * @param data One-dimensional array of doubles.
- */
- public DenseMatrix(int m, int n, double[] data) {
- this(m, n, data, false);
- }
-
- /**
- * Construct a matrix from a 1-D array. The data in the array is organized in column major or in
- * row major, which is specified by parameter 'inRowMajor'.
- *
- * <p>The array is adopted as the matrix storage without being copied. When 'inRowMajor' is
- * true, its content is rearranged into column major in place, so the caller's array changes.
- *
- * <p>NOTE(review): the length of 'data' is only verified by an {@code assert}, which has no
- * effect unless assertions are enabled.
- *
- * @param m Number of rows.
- * @param n Number of cols.
- * @param data One-dimensional array of doubles.
- * @param inRowMajor Whether the matrix in 'data' is in row major format.
- */
- public DenseMatrix(int m, int n, double[] data, boolean inRowMajor) {
- assert (data.length == m * n);
- this.m = m;
- this.n = n;
- if (inRowMajor) {
- toColumnMajor(m, n, data);
- }
- this.data = data;
- }
-
- /**
- * Construct a matrix from a 2-D array, where data[i][j] becomes element (i, j). The values are
- * copied into a newly allocated array. An empty outer array yields a 0-by-0 matrix.
- *
- * @param data Two-dimensional array of doubles.
- * @throws IllegalArgumentException All rows must have the same size
- */
- public DenseMatrix(double[][] data) {
- this.m = data.length;
- if (this.m == 0) {
- this.n = 0;
- this.data = new double[0];
- return;
- }
- this.n = data[0].length;
- for (int i = 0; i < m; i++) {
- if (data[i].length != n) {
- throw new IllegalArgumentException("All rows must have the same size.");
- }
- }
- this.data = new double[m * n];
- for (int i = 0; i < m; i++) {
- for (int j = 0; j < n; j++) {
- this.set(i, j, data[i][j]);
- }
- }
- }
-
- /**
- * Create an identity matrix.
- *
- * @param n the dimension of the eye matrix.
- * @return an n * n identity matrix.
- */
- public static DenseMatrix eye(int n) {
- return eye(n, n);
- }
-
- /**
- * Create a m * n identity matrix, i.e. a zero matrix whose first min(m, n) diagonal elements
- * are set to one.
- *
- * @param m the row dimension.
- * @param n the column dimension.
- * @return the m * n identity matrix.
- */
- public static DenseMatrix eye(int m, int n) {
- DenseMatrix mat = new DenseMatrix(m, n);
- int k = Math.min(m, n);
- for (int i = 0; i < k; i++) {
- mat.data[i * m + i] = 1.0;
- }
- return mat;
- }
-
- /**
- * Create a zero matrix.
- *
- * @param m the row dimension.
- * @param n the column dimension.
- * @return a m * n zero matrix.
- */
- public static DenseMatrix zeros(int m, int n) {
- return new DenseMatrix(m, n);
- }
-
- /**
- * Create a matrix with all elements set to 1.
- *
- * @param m the row dimension
- * @param n the column dimension
- * @return the m * n matrix with all elements set to 1.
- */
- public static DenseMatrix ones(int m, int n) {
- DenseMatrix mat = new DenseMatrix(m, n);
- Arrays.fill(mat.data, 1.);
- return mat;
- }
-
- /**
- * Create a random matrix whose elements are drawn independently with {@link Math#random()}.
- *
- * @param m the row dimension
- * @param n the column dimension.
- * @return a m * n random matrix.
- */
- public static DenseMatrix rand(int m, int n) {
- DenseMatrix mat = new DenseMatrix(m, n);
- for (int i = 0; i < mat.data.length; i++) {
- mat.data[i] = Math.random();
- }
- return mat;
- }
-
- /**
- * Create a random symmetric matrix. One value is drawn with {@link Math#random()} for every
- * position on or above the diagonal and mirrored to the position below it.
- *
- * @param n the dimension of the symmetric matrix.
- * @return a n * n random symmetric matrix.
- */
- public static DenseMatrix randSymmetric(int n) {
- DenseMatrix mat = new DenseMatrix(n, n);
- for (int i = 0; i < n; i++) {
- for (int j = i; j < n; j++) {
- double r = Math.random();
- mat.set(i, j, r);
- if (i != j) {
- mat.set(j, i, r);
- }
- }
- }
- return mat;
- }
-
- /**
- * Get a single element.
- *
- * @param i Row index.
- * @param j Column index.
- * @return matA(i, j)
- * @throws ArrayIndexOutOfBoundsException if {@code j * m + i} lies outside the storage array.
- */
- public double get(int i, int j) {
- return data[j * m + i];
- }
-
- /**
- * Get the data array of this matrix. This is the internal column major array itself, not a
- * copy, so changes made through it are visible in the matrix.
- *
- * @return the data array of this matrix.
- */
- public double[] getData() {
- return this.data;
- }
-
- /**
- * Get all the matrix data, returned as a newly allocated 2-D array indexed as [row][column].
- *
- * @return all matrix data, returned as a 2-D array.
- */
- public double[][] getArrayCopy2D() {
- double[][] arrayData = new double[m][n];
- for (int i = 0; i < m; i++) {
- for (int j = 0; j < n; j++) {
- arrayData[i][j] = this.get(i, j);
- }
- }
- return arrayData;
- }
-
- /**
- * Get all matrix data, returned as a newly allocated 1-D array.
- *
- * @param inRowMajor Whether to return data in row major; otherwise a clone of the internal
- *     column major array is returned.
- * @return all matrix data, returned as a 1-D array.
- */
- public double[] getArrayCopy1D(boolean inRowMajor) {
- if (inRowMajor) {
- double[] arrayData = new double[m * n];
- for (int i = 0; i < m; i++) {
- for (int j = 0; j < n; j++) {
- arrayData[i * n + j] = this.get(i, j);
- }
- }
- return arrayData;
- } else {
- return this.data.clone();
- }
- }
-
- /**
- * Get a copy of one row. The index is only verified by an {@code assert}.
- *
- * @param row the row index.
- * @return the row with the given index.
- */
- public double[] getRow(int row) {
- assert (row >= 0 && row < m) : "Invalid row index.";
- double[] r = new double[n];
- for (int i = 0; i < n; i++) {
- r[i] = this.get(row, i);
- }
- return r;
- }
-
- /**
- * Get a copy of one column. The index is only verified by an {@code assert}.
- *
- * @param col the column index.
- * @return the column with the given index.
- */
- public double[] getColumn(int col) {
- assert (col >= 0 && col < n) : "Invalid column index.";
- double[] columnData = new double[m];
- System.arraycopy(this.data, col * m, columnData, 0, m);
- return columnData;
- }
-
- /** Clone the Matrix object. The clone owns a copy of the data array. */
- @Override
- public DenseMatrix clone() {
- return new DenseMatrix(this.m, this.n, this.data.clone(), false);
- }
-
- /**
- * Create a new matrix by selecting some of the rows. Row i of the result is row rows[i] of this
- * matrix, so rows may be repeated or reordered.
- *
- * @param rows the array of row indexes to select.
- * @return a new matrix by selecting some of the rows.
- */
- public DenseMatrix selectRows(int[] rows) {
- DenseMatrix sub = new DenseMatrix(rows.length, this.n);
- for (int i = 0; i < rows.length; i++) {
- for (int j = 0; j < this.n; j++) {
- sub.set(i, j, this.get(rows[i], j));
- }
- }
- return sub;
- }
-
- /**
- * Get sub matrix. The result is a new matrix holding copies of the selected elements. The
- * index range is only verified by an {@code assert}.
- *
- * @param m0 the starting row index (inclusive)
- * @param m1 the ending row index (exclusive)
- * @param n0 the starting column index (inclusive)
- * @param n1 the ending column index (exclusive)
- * @return the specified sub matrix.
- */
- public DenseMatrix getSubMatrix(int m0, int m1, int n0, int n1) {
- assert (m0 >= 0 && m1 <= m) && (n0 >= 0 && n1 <= n) : "Invalid index range.";
- DenseMatrix sub = new DenseMatrix(m1 - m0, n1 - n0);
- for (int i = 0; i < sub.m; i++) {
- for (int j = 0; j < sub.n; j++) {
- sub.set(i, j, this.get(m0 + i, n0 + j));
- }
- }
- return sub;
- }
-
- /**
- * Set part of the matrix values from the values of another matrix.
- *
- * <p>NOTE(review): m1 and n1 take part only in the {@code assert}; the copied region is
- * determined by the dimensions of 'sub', starting at (m0, n0). Nothing verifies that 'sub' is
- * (m1 - m0) by (n1 - n0).
- *
- * @param sub the matrix whose element values will be assigned to the sub matrix of this matrix.
- * @param m0 the starting row index (inclusive)
- * @param m1 the ending row index (exclusive)
- * @param n0 the starting column index (inclusive)
- * @param n1 the ending column index (exclusive)
- */
- public void setSubMatrix(DenseMatrix sub, int m0, int m1, int n0, int n1) {
- assert (m0 >= 0 && m1 <= m) && (n0 >= 0 && n1 <= n) : "Invalid index range.";
- for (int i = 0; i < sub.m; i++) {
- for (int j = 0; j < sub.n; j++) {
- this.set(m0 + i, n0 + j, sub.get(i, j));
- }
- }
- }
-
- /**
- * Set a single element.
- *
- * @param i Row index.
- * @param j Column index.
- * @param s The new value of A(i,j).
- * @throws ArrayIndexOutOfBoundsException if {@code j * m + i} lies outside the storage array.
- */
- public void set(int i, int j, double s) {
- data[j * m + i] = s;
- }
-
- /**
- * Add the given value to a single element.
- *
- * @param i Row index.
- * @param j Column index.
- * @param s The value added to A(i,j).
- * @throws ArrayIndexOutOfBoundsException if {@code j * m + i} lies outside the storage array.
- */
- public void add(int i, int j, double s) {
- data[j * m + i] += s;
- }
-
- /**
- * Check whether the matrix is square matrix.
- *
- * @return <code>true</code> if this matrix is a square matrix, <code>false</code> otherwise.
- */
- public boolean isSquare() {
- return m == n;
- }
-
- /**
- * Check whether the matrix is symmetric matrix. Mirrored elements are compared for exact
- * equality, without any tolerance; a non-square matrix is never symmetric.
- *
- * @return <code>true</code> if this matrix is a symmetric matrix, <code>false</code> otherwise.
- */
- public boolean isSymmetric() {
- if (m != n) {
- return false;
- }
- for (int i = 0; i < n; i++) {
- for (int j = i + 1; j < n; j++) {
- if (this.get(i, j) != this.get(j, i)) {
- return false;
- }
- }
- }
- return true;
- }
-
- /**
- * Get the number of rows.
- *
- * @return the number of rows.
- */
- public int numRows() {
- return m;
- }
-
- /**
- * Get the number of columns.
- *
- * @return the number of columns.
- */
- public int numCols() {
- return n;
- }
-
- /** Sum of all elements of the matrix. */
- public double sum() {
- double s = 0.;
- for (int i = 0; i < this.data.length; i++) {
- s += this.data[i];
- }
- return s;
- }
-
- /** Scale the matrix by value "v" and create a new matrix to store the result. */
- public DenseMatrix scale(double v) {
- DenseMatrix r = this.clone();
- BLAS.scal(v, r);
- return r;
- }
-
- /** Scale the matrix by value "v" in place. */
- public void scaleEqual(double v) {
- BLAS.scal(v, this);
- }
-
- /** Create a new matrix by adding another matrix of the same dimensions. */
- public DenseMatrix plus(DenseMatrix mat) {
- DenseMatrix r = this.clone();
- BLAS.axpy(1.0, mat, r);
- return r;
- }
-
- /** Create a new matrix by adding a constant to every element. */
- public DenseMatrix plus(double alpha) {
- DenseMatrix r = this.clone();
- for (int i = 0; i < r.data.length; i++) {
- r.data[i] += alpha;
- }
- return r;
- }
-
- /** Add another matrix of the same dimensions to this matrix in place. */
- public void plusEquals(DenseMatrix mat) {
- BLAS.axpy(1.0, mat, this);
- }
-
- /** Add a constant to every element of this matrix in place. */
- public void plusEquals(double alpha) {
- for (int i = 0; i < this.data.length; i++) {
- this.data[i] += alpha;
- }
- }
-
- /** Create a new matrix by subtracting another matrix of the same dimensions. */
- public DenseMatrix minus(DenseMatrix mat) {
- DenseMatrix r = this.clone();
- BLAS.axpy(-1.0, mat, r);
- return r;
- }
-
- /** Subtract another matrix of the same dimensions from this matrix in place. */
- public void minusEquals(DenseMatrix mat) {
- BLAS.axpy(-1.0, mat, this);
- }
-
- /** Multiply with another matrix and return the product as a new matrix. */
- public DenseMatrix multiplies(DenseMatrix mat) {
- DenseMatrix r = new DenseMatrix(this.m, mat.n);
- BLAS.gemm(1.0, this, false, mat, false, 0., r);
- return r;
- }
-
- /** Multiply with a dense vector and return the product as a new vector. */
- public DenseVector multiplies(DenseVector x) {
- DenseVector y = new DenseVector(this.numRows());
- BLAS.gemv(1.0, this, false, x, 0.0, y);
- return y;
- }
-
- /**
- * Multiply with a sparse vector and return the product as a new dense vector.
- *
- * <p>A RuntimeException is thrown when a stored index of x is not smaller than the number of
- * columns of this matrix.
- */
- public DenseVector multiplies(SparseVector x) {
- DenseVector y = new DenseVector(this.numRows());
- for (int i = 0; i < this.numRows(); i++) {
- double s = 0.;
- int[] indices = x.getIndices();
- double[] values = x.getValues();
- for (int j = 0; j < indices.length; j++) {
- int index = indices[j];
- if (index >= this.numCols()) {
- throw new RuntimeException("Vector index out of bound:" + index);
- }
- s += this.get(i, index) * values[j];
- }
- y.set(i, s);
- }
- return y;
- }
-
- /**
- * Create a new matrix by transposing current matrix.
- *
- * <p>The elements are copied tile by tile. The tile dimensions start as the full matrix
- * dimensions and the larger one is halved repeatedly until a tile holds at most 16384 elements.
- */
- public DenseMatrix transpose() {
- DenseMatrix mat = new DenseMatrix(n, m);
- int m0 = m;
- int n0 = n;
- int barrierSize = 16384;
- while (m0 * n0 > barrierSize) {
- if (m0 >= n0) {
- m0 /= 2;
- } else {
- n0 /= 2;
- }
- }
- for (int i0 = 0; i0 < m; i0 += m0) {
- for (int j0 = 0; j0 < n; j0 += n0) {
- for (int i = i0; i < i0 + m0 && i < m; i++) {
- for (int j = j0; j < j0 + n0 && j < n; j++) {
- mat.set(j, i, this.get(i, j));
- }
- }
- }
- }
- return mat;
- }
-
- /**
- * Converts the data layout in "data" from row major to column major. The array is modified in
- * place: a square matrix is handled by swapping mirrored elements, any other shape by
- * transposing a temporary matrix and copying the result back.
- */
- private static void toColumnMajor(int m, int n, double[] data) {
- if (m == n) {
- for (int i = 0; i < m; i++) {
- for (int j = i + 1; j < m; j++) {
- int pos0 = j * m + i;
- int pos1 = i * m + j;
- double t = data[pos0];
- data[pos0] = data[pos1];
- data[pos1] = t;
- }
- }
- } else {
- DenseMatrix temp = new DenseMatrix(n, m, data, false);
- System.arraycopy(temp.transpose().data, 0, data, 0, data.length);
- }
- }
-
- @Override
- public String toString() {
- StringBuilder sbd = new StringBuilder();
- sbd.append(String.format("mat[%d,%d]:\n", m, n));
- for (int i = 0; i < m; i++) {
- sbd.append(" ");
- for (int j = 0; j < n; j++) {
- if (j > 0) {
- sbd.append(",");
- }
- sbd.append(this.get(i, j));
- }
- sbd.append("\n");
- }
- return sbd.toString();
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseVector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseVector.java
deleted file mode 100644
index 951cb49..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/DenseVector.java
+++ /dev/null
@@ -1,379 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import java.util.Arrays;
-import java.util.Random;
-
-/**
- * A dense vector represented by a values array.
- *
- * <p>The array handed to the constructor or to setData is used directly and the one returned by
- * getData is the internal array, so callers share storage with the vector.
- */
-public class DenseVector extends Vector {
- /**
- * The array holding the vector data.
- *
- * <p>Package private to allow access from {@link MatVecOp} and {@link BLAS}.
- */
- double[] data;
-
- /** Create a zero size vector. */
- public DenseVector() {
- this(0);
- }
-
- /**
- * Create a size n vector with all elements zero.
- *
- * @param n Size of the vector.
- */
- public DenseVector(int n) {
- this.data = new double[n];
- }
-
- /**
- * Create a dense vector with the user provided data. The array is stored as is, without being
- * copied.
- *
- * @param data The vector data.
- */
- public DenseVector(double[] data) {
- this.data = data;
- }
-
- /** Get the data array. This is the internal array itself, not a copy. */
- public double[] getData() {
- return this.data;
- }
-
- /** Set the data array. The array is stored as is, without being copied. */
- public void setData(double[] data) {
- this.data = data;
- }
-
- /**
- * Create a dense vector with all elements one.
- *
- * @param n Size of the vector.
- * @return The newly created dense vector.
- */
- public static DenseVector ones(int n) {
- DenseVector r = new DenseVector(n);
- Arrays.fill(r.data, 1.0);
- return r;
- }
-
- /**
- * Create a dense vector with all elements zero.
- *
- * @param n Size of the vector.
- * @return The newly created dense vector.
- */
- public static DenseVector zeros(int n) {
- DenseVector r = new DenseVector(n);
- Arrays.fill(r.data, 0.0);
- return r;
- }
-
- /**
- * Create a dense vector with random values uniformly distributed in the range of [0.0, 1.0),
- * drawn from a newly created {@link Random}.
- *
- * @param n Size of the vector.
- * @return The newly created dense vector.
- */
- public static DenseVector rand(int n) {
- Random random = new Random();
- DenseVector v = new DenseVector(n);
- for (int i = 0; i < n; i++) {
- v.data[i] = random.nextDouble();
- }
- return v;
- }
-
- @Override
- public DenseVector clone() {
- return new DenseVector(this.data.clone());
- }
-
- @Override
- public String toString() {
- return VectorUtil.toString(this);
- }
-
- @Override
- public int size() {
- return data.length;
- }
-
- @Override
- public double get(int i) {
- return data[i];
- }
-
- @Override
- public void set(int i, double d) {
- data[i] = d;
- }
-
- @Override
- public void add(int i, double d) {
- data[i] += d;
- }
-
- @Override
- public double normL1() {
- double d = 0;
- for (double t : data) {
- d += Math.abs(t);
- }
- return d;
- }
-
- @Override
- public double normL2() {
- double d = 0;
- for (double t : data) {
- d += t * t;
- }
- return Math.sqrt(d);
- }
-
- @Override
- public double normL2Square() {
- double d = 0;
- for (double t : data) {
- d += t * t;
- }
- return d;
- }
-
- @Override
- public double normInf() {
- double d = 0;
- for (double t : data) {
- d = Math.max(Math.abs(t), d);
- }
- return d;
- }
-
- @Override
- public DenseVector slice(int[] indices) {
- double[] values = new double[indices.length];
- for (int i = 0; i < indices.length; ++i) {
- if (indices[i] >= data.length) {
- throw new RuntimeException("Index is larger than vector size.");
- }
- values[i] = data[indices[i]];
- }
- return new DenseVector(values);
- }
-
- @Override
- public DenseVector prefix(double d) {
- double[] data = new double[this.size() + 1];
- data[0] = d;
- System.arraycopy(this.data, 0, data, 1, this.data.length);
- return new DenseVector(data);
- }
-
- @Override
- public DenseVector append(double d) {
- double[] data = new double[this.size() + 1];
- System.arraycopy(this.data, 0, data, 0, this.data.length);
- data[this.size()] = d;
- return new DenseVector(data);
- }
-
- @Override
- public void scaleEqual(double d) {
- BLAS.scal(d, this);
- }
-
- @Override
- public DenseVector plus(Vector other) {
- DenseVector r = this.clone();
- if (other instanceof DenseVector) {
- BLAS.axpy(1.0, (DenseVector) other, r);
- } else {
- BLAS.axpy(1.0, (SparseVector) other, r);
- }
- return r;
- }
-
- @Override
- public DenseVector minus(Vector other) {
- DenseVector r = this.clone();
- if (other instanceof DenseVector) {
- BLAS.axpy(-1.0, (DenseVector) other, r);
- } else {
- BLAS.axpy(-1.0, (SparseVector) other, r);
- }
- return r;
- }
-
- @Override
- public DenseVector scale(double d) {
- DenseVector r = this.clone();
- BLAS.scal(d, r);
- return r;
- }
-
- @Override
- public double dot(Vector vec) {
- if (vec instanceof DenseVector) {
- return BLAS.dot(this, (DenseVector) vec);
- } else {
- return vec.dot(this);
- }
- }
-
- @Override
- public void standardizeEqual(double mean, double stdvar) {
- int size = data.length;
- for (int i = 0; i < size; i++) {
- data[i] -= mean;
- data[i] *= (1.0 / stdvar);
- }
- }
-
- @Override
- public void normalizeEqual(double p) {
- double norm = 0.0;
- if (Double.isInfinite(p)) {
- norm = normInf();
- } else if (p == 1.0) {
- norm = normL1();
- } else if (p == 2.0) {
- norm = normL2();
- } else {
- for (int i = 0; i < data.length; i++) {
- norm += Math.pow(Math.abs(data[i]), p);
- }
- norm = Math.pow(norm, 1 / p);
- }
- for (int i = 0; i < data.length; i++) {
- data[i] /= norm;
- }
- }
-
- /**
- * Set the data of the vector the same as those of another vector. The values are copied into
- * this vector's own array.
- *
- * <p>NOTE(review): the size match is only verified by an {@code assert}, which has no effect
- * unless assertions are enabled.
- */
- public void setEqual(DenseVector other) {
- assert this.size() == other.size() : "Size of the two vectors mismatched.";
- System.arraycopy(other.data, 0, this.data, 0, this.size());
- }
-
- /**
- * Add another vector to this vector in place. An argument that is not a DenseVector is cast to
- * SparseVector.
- */
- public void plusEqual(Vector other) {
- if (other instanceof DenseVector) {
- BLAS.axpy(1.0, (DenseVector) other, this);
- } else {
- BLAS.axpy(1.0, (SparseVector) other, this);
- }
- }
-
- /**
- * Subtract another vector from this vector in place. An argument that is not a DenseVector is
- * cast to SparseVector.
- */
- public void minusEqual(Vector other) {
- if (other instanceof DenseVector) {
- BLAS.axpy(-1.0, (DenseVector) other, this);
- } else {
- BLAS.axpy(-1.0, (SparseVector) other, this);
- }
- }
-
- /**
- * Add another vector scaled by "alpha" to this vector in place. An argument that is not a
- * DenseVector is cast to SparseVector.
- */
- public void plusScaleEqual(Vector other, double alpha) {
- if (other instanceof DenseVector) {
- BLAS.axpy(alpha, (DenseVector) other, this);
- } else {
- BLAS.axpy(alpha, (SparseVector) other, this);
- }
- }
-
- @Override
- public DenseMatrix outer() {
- return this.outer(this);
- }
-
- /**
- * Compute the outer product with another vector.
- *
- * <p>The result has as many rows as this vector has elements and as many columns as the other
- * vector has elements; element (i, j) is this[i] * other[j].
- *
- * @param other the vector on the right hand side of the product.
- * @return The outer product matrix.
- */
- public DenseMatrix outer(DenseVector other) {
- int nrows = this.size();
- int ncols = other.size();
- double[] data = new double[nrows * ncols];
- int pos = 0;
- for (int j = 0; j < ncols; j++) {
- for (int i = 0; i < nrows; i++) {
- data[pos++] = this.data[i] * other.data[j];
- }
- }
- return new DenseMatrix(nrows, ncols, data, false);
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- DenseVector that = (DenseVector) o;
- return Arrays.equals(data, that.data);
- }
-
- @Override
- public int hashCode() {
- return Arrays.hashCode(data);
- }
-
- @Override
- public VectorIterator iterator() {
- return new DenseVectorIterator();
- }
-
- private class DenseVectorIterator implements VectorIterator {
- private int cursor = 0;
-
- @Override
- public boolean hasNext() {
- return cursor < data.length;
- }
-
- @Override
- public void next() {
- cursor++;
- }
-
- @Override
- public int getIndex() {
- if (cursor >= data.length) {
- throw new RuntimeException("Iterator out of bound.");
- }
- return cursor;
- }
-
- @Override
- public double getValue() {
- if (cursor >= data.length) {
- throw new RuntimeException("Iterator out of bound.");
- }
- return data[cursor];
- }
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/MatVecOp.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/MatVecOp.java
deleted file mode 100644
index fafb6f2..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/MatVecOp.java
+++ /dev/null
@@ -1,307 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import java.util.function.BiFunction;
-import java.util.function.Function;
-
-/**
- * A utility class that provides operations over {@link DenseVector}, {@link SparseVector} and
- * {@link DenseMatrix}.
- */
-public class MatVecOp {
- /** Compute vec1 + vec2 by delegating to {@code vec1.plus(vec2)}. */
- public static Vector plus(Vector vec1, Vector vec2) {
-     final Vector sum = vec1.plus(vec2);
-     return sum;
- }
-
- /** Compute vec1 - vec2 by delegating to {@code vec1.minus(vec2)}. */
- public static Vector minus(Vector vec1, Vector vec2) {
-     final Vector difference = vec1.minus(vec2);
-     return difference;
- }
-
- /** Compute vec1 \cdot vec2 by delegating to {@code vec1.dot(vec2)}. */
- public static double dot(Vector vec1, Vector vec2) {
-     final double product = vec1.dot(vec2);
-     return product;
- }
-
- /**
-  * Compute || vec1 - vec2 ||_1 .
-  *
-  * <p>Each argument is treated as a {@link SparseVector} unless it is a {@link DenseVector};
-  * the matching {@code applySum} overload is chosen accordingly.
-  */
- public static double sumAbsDiff(Vector vec1, Vector vec2) {
-     final BiFunction<Double, Double, Double> absDiff = (a, b) -> Math.abs(a - b);
-     final boolean firstDense = vec1 instanceof DenseVector;
-     final boolean secondDense = vec2 instanceof DenseVector;
-     if (firstDense && secondDense) {
-         return applySum((DenseVector) vec1, (DenseVector) vec2, absDiff);
-     }
-     if (firstDense) {
-         return applySum((DenseVector) vec1, (SparseVector) vec2, absDiff);
-     }
-     if (secondDense) {
-         return applySum((SparseVector) vec1, (DenseVector) vec2, absDiff);
-     }
-     return applySum((SparseVector) vec1, (SparseVector) vec2, absDiff);
- }
-
- /**
-  * Compute || vec1 - vec2 ||_2^2 .
-  *
-  * <p>Each argument is treated as a {@link SparseVector} unless it is a {@link DenseVector};
-  * the matching {@code applySum} overload is chosen accordingly.
-  */
- public static double sumSquaredDiff(Vector vec1, Vector vec2) {
-     final BiFunction<Double, Double, Double> squaredDiff = (a, b) -> (a - b) * (a - b);
-     final boolean firstDense = vec1 instanceof DenseVector;
-     final boolean secondDense = vec2 instanceof DenseVector;
-     if (firstDense && secondDense) {
-         return applySum((DenseVector) vec1, (DenseVector) vec2, squaredDiff);
-     }
-     if (firstDense) {
-         return applySum((DenseVector) vec1, (SparseVector) vec2, squaredDiff);
-     }
-     if (secondDense) {
-         return applySum((SparseVector) vec1, (DenseVector) vec2, squaredDiff);
-     }
-     return applySum((SparseVector) vec1, (SparseVector) vec2, squaredDiff);
- }
-
- /**
-  * y = func(x), applied entry by entry over the backing arrays of two matrices.
-  *
-  * <p>The dimensions are only verified through an {@code assert}.
-  */
- public static void apply(DenseMatrix x, DenseMatrix y, Function<Double, Double> func) {
-     assert (x.m == y.m && x.n == y.n) : "x and y size mismatched.";
-     final double[] source = x.data;
-     final double[] target = y.data;
-     int k = 0;
-     while (k < source.length) {
-         target[k] = func.apply(source[k]);
-         k++;
-     }
- }
-
- /**
-  * y = func(x1, x2), applied entry by entry over the backing arrays of three matrices.
-  *
-  * <p>The dimensions are only verified through {@code assert}s.
-  */
- public static void apply(
-         DenseMatrix x1,
-         DenseMatrix x2,
-         DenseMatrix y,
-         BiFunction<Double, Double, Double> func) {
-
-     assert (x1.m == y.m && x1.n == y.n) : "x1 and y size mismatched.";
-     assert (x2.m == y.m && x2.n == y.n) : "x2 and y size mismatched.";
-     final double[] left = x1.data;
-     final double[] right = x2.data;
-     final double[] target = y.data;
-     int k = 0;
-     while (k < target.length) {
-         target[k] = func.apply(left[k], right[k]);
-         k++;
-     }
- }
-
- /**
-  * y = func(x), applied element by element to two dense vectors.
-  *
-  * <p>The lengths are only verified through an {@code assert}.
-  */
- public static void apply(DenseVector x, DenseVector y, Function<Double, Double> func) {
-     assert (x.data.length == y.data.length) : "x and y size mismatched.";
-     final int length = x.data.length;
-     int k = 0;
-     while (k < length) {
-         y.data[k] = func.apply(x.data[k]);
-         k++;
-     }
- }
-
- /**
-  * y = func(x1, x2), applied element by element: {@code y[i] = func(x1[i], x2[i])}.
-  *
-  * <p>The lengths of both inputs are compared against {@code y} through {@code assert}s only,
-  * so the checks are active only when assertions are enabled.
-  *
-  * @param x1 the first input vector.
-  * @param x2 the second input vector.
-  * @param y the vector whose data array receives the results.
-  * @param func the function evaluated for every pair of elements.
-  */
- public static void apply(
-         DenseVector x1,
-         DenseVector x2,
-         DenseVector y,
-         BiFunction<Double, Double, Double> func) {
-
-     assert (x1.data.length == y.data.length) : "x1 and y size mismatched.";
-     // This assertion checks x2, so the message has to name x2 (it used to say "x1").
-     assert (x2.data.length == y.data.length) : "x2 and y size mismatched.";
-     for (int i = 0; i < y.data.length; i++) {
-         y.data[i] = func.apply(x1.data[i], x2.data[i]);
-     }
- }
-
- /**
-  * Create a new {@link SparseVector} by element wise operation between two {@link
-  * SparseVector}s. y = func(x1, x2).
-  *
-  * <p>The result holds one entry for every index stored in either input. Where only one input
-  * stores an index, {@code 0.0} is passed to {@code func} for the other side. Both index arrays
-  * are walked in step, which presumes they are in ascending order.
-  */
- public static SparseVector apply(
-         SparseVector x1, SparseVector x2, BiFunction<Double, Double, Double> func) {
-     assert (x1.size() == x2.size()) : "x1 and x2 size mismatched.";
-
-     final int len1 = x1.values.length;
-     final int len2 = x2.values.length;
-
-     // First pass: count the indices stored in both inputs to size the result exactly.
-     int shared = 0;
-     for (int a = 0, b = 0; a < len1 && b < len2; ) {
-         if (x1.indices[a] == x2.indices[b]) {
-             shared++;
-             a++;
-             b++;
-         } else if (x1.indices[a] < x2.indices[b]) {
-             a++;
-         } else {
-             b++;
-         }
-     }
-     final int unionSize = len1 + len2 - shared;
-
-     SparseVector result = new SparseVector(x1.size());
-     result.indices = new int[unionSize];
-     result.values = new double[unionSize];
-
-     // Second pass: merge the two inputs into the result.
-     int a = 0;
-     int b = 0;
-     for (int out = 0; out < unionSize; out++) {
-         final boolean hasLeft = a < len1;
-         final boolean hasRight = b < len2;
-         if (hasLeft && hasRight && x1.indices[a] == x2.indices[b]) {
-             result.indices[out] = x1.indices[a];
-             result.values[out] = func.apply(x1.values[a], x2.values[b]);
-             a++;
-             b++;
-         } else if (hasLeft && (!hasRight || x1.indices[a] < x2.indices[b])) {
-             result.indices[out] = x1.indices[a];
-             result.values[out] = func.apply(x1.values[a], 0.0);
-             a++;
-         } else {
-             // Only the right input has an entry here, or its index is the smaller one.
-             result.indices[out] = x2.indices[b];
-             result.values[out] = func.apply(0.0, x2.values[b]);
-             b++;
-         }
-     }
-
-     return result;
- }
-
- /**
-  * \sum_i func(x1_i, x2_i) for two dense vectors.
-  *
-  * <p>The terms are accumulated in index order; the sizes are only verified through an
-  * {@code assert}.
-  */
- public static double applySum(
-         DenseVector x1, DenseVector x2, BiFunction<Double, Double, Double> func) {
-     assert x1.size() == x2.size() : "x1 and x2 size mismatched.";
-     final double[] left = x1.getData();
-     final double[] right = x2.getData();
-     double total = 0.;
-     int k = 0;
-     while (k < left.length) {
-         total += func.apply(left[k], right[k]);
-         k++;
-     }
-     return total;
- }
-
- /**
-  * \sum_i func(x1_i, x2_i) for two sparse vectors.
-  *
-  * <p>Only indices stored in at least one input contribute a term. Where one side has no entry,
-  * {@code 0.} is passed to {@code func} for it. Both index arrays are walked in step, which
-  * presumes they are in ascending order.
-  */
- public static double applySum(
-         SparseVector x1, SparseVector x2, BiFunction<Double, Double, Double> func) {
-     final int[] leftIdx = x1.getIndices();
-     final double[] leftVal = x1.getValues();
-     final int[] rightIdx = x2.getIndices();
-     final double[] rightVal = x2.getValues();
-     final int leftCount = leftIdx.length;
-     final int rightCount = rightIdx.length;
-
-     double total = 0.;
-     int a = 0;
-     int b = 0;
-     // Merge while both inputs still have entries.
-     while (a < leftCount && b < rightCount) {
-         if (leftIdx[a] == rightIdx[b]) {
-             total += func.apply(leftVal[a], rightVal[b]);
-             a++;
-             b++;
-         } else if (leftIdx[a] < rightIdx[b]) {
-             total += func.apply(leftVal[a], 0.);
-             a++;
-         } else {
-             total += func.apply(0., rightVal[b]);
-             b++;
-         }
-     }
-     // At most one of the two tails below is non-empty.
-     while (a < leftCount) {
-         total += func.apply(leftVal[a], 0.);
-         a++;
-     }
-     while (b < rightCount) {
-         total += func.apply(0., rightVal[b]);
-         b++;
-     }
-     return total;
- }
-
- /**
-  * \sum_i func(x1_i, x2_i) for a dense and a sparse vector.
-  *
-  * <p>Every position of the dense vector contributes a term; positions not stored in the
-  * sparse vector are passed to {@code func} as {@code 0.}.
-  */
- public static double applySum(
-         DenseVector x1, SparseVector x2, BiFunction<Double, Double, Double> func) {
-     assert x1.size() == x2.size() : "x1 and x2 size mismatched.";
-     final double[] dense = x1.getData();
-     final int[] sparseIdx = x2.getIndices();
-     final double[] sparseVal = x2.getValues();
-     final int sparseCount = sparseIdx.length;
-
-     double total = 0.;
-     int next = 0;
-     for (int i = 0; i < dense.length; i++) {
-         final boolean stored = next < sparseCount && sparseIdx[next] == i;
-         final double other = stored ? sparseVal[next++] : 0.;
-         total += func.apply(dense[i], other);
-     }
-     return total;
- }
-
- /**
-  * \sum_i func(x1_i, x2_i) for a sparse and a dense vector.
-  *
-  * <p>Every position of the dense vector contributes a term; positions not stored in the
-  * sparse vector are passed to {@code func} as {@code 0.}.
-  */
- public static double applySum(
-         SparseVector x1, DenseVector x2, BiFunction<Double, Double, Double> func) {
-     assert x1.size() == x2.size() : "x1 and x2 size mismatched.";
-     final int[] sparseIdx = x1.getIndices();
-     final double[] sparseVal = x1.getValues();
-     final int sparseCount = sparseIdx.length;
-     final double[] dense = x2.getData();
-
-     double total = 0.;
-     int next = 0;
-     for (int i = 0; i < dense.length; i++) {
-         final boolean stored = next < sparseCount && sparseIdx[next] == i;
-         final double other = stored ? sparseVal[next++] : 0.;
-         total += func.apply(other, dense[i]);
-     }
-     return total;
- }
-
- /**
-  * \sum_ij func(x1_ij, x2_ij) , accumulated in the order of the matrices' backing arrays.
-  *
-  * <p>The dimensions are only verified through an {@code assert}.
-  */
- public static double applySum(
-         DenseMatrix x1, DenseMatrix x2, BiFunction<Double, Double, Double> func) {
-     assert (x1.m == x2.m && x1.n == x2.n) : "x1 and x2 size mismatched.";
-     final double[] left = x1.data;
-     final double[] right = x2.data;
-     double total = 0.;
-     int k = 0;
-     while (k < left.length) {
-         total += func.apply(left[k], right[k]);
-         k++;
-     }
-     return total;
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/SparseVector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/SparseVector.java
deleted file mode 100644
index f11a188..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/SparseVector.java
+++ /dev/null
@@ -1,574 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.TreeMap;
-
-/** A sparse vector represented by an indices array and a values array. */
-public class SparseVector extends Vector {
-
- /**
- * Size of the vector. n = -1 indicates that the vector size is undetermined.
- *
- * <p>Package private to allow access from {@link MatVecOp} and {@link BLAS}.
- */
- int n;
-
- /**
- * Column indices.
- *
- * <p>Package private to allow access from {@link MatVecOp} and {@link BLAS}.
- */
- int[] indices;
-
- /**
- * Column values.
- *
- * <p>Package private to allow access from {@link MatVecOp} and {@link BLAS}.
- */
- double[] values;
-
- /** Construct an empty sparse vector whose size is undetermined (stored as -1). */
- public SparseVector() {
-     this(-1);
- }
-
- /** Construct a sparse vector of the given size that stores no entries. */
- public SparseVector(int n) {
-     this.indices = new int[0];
-     this.values = new double[0];
-     this.n = n;
- }
-
- /**
-  * Construct a sparse vector with the given indices and values.
-  *
-  * <p>The arrays are stored as passed in (no copy is made) and are reordered in place when the
-  * indices are not ascending.
-  *
-  * @throws IllegalArgumentException If size of indices array and values array differ.
-  * @throws IllegalArgumentException If an index is negative, or n >= 0 and an index is >= n.
-  */
- public SparseVector(int n, int[] indices, double[] values) {
-     this.indices = indices;
-     this.values = values;
-     this.n = n;
-     checkSizeAndIndicesRange();
-     sortIndices();
- }
-
- /**
-  * Construct a sparse vector with given indices to values map.
-  *
-  * <p>The entries are copied in the map's iteration order and then sorted by index when that
-  * order is not ascending. This also holds for a {@link TreeMap}: one built with a non-natural
-  * comparator (for example reverse order) does not iterate in ascending index order.
-  *
-  * @param n the size of the vector, or -1 if undetermined.
-  * @param kv the index to value mapping.
-  * @throws IllegalArgumentException If an index is negative, or n >= 0 and an index is >= n.
-  */
- public SparseVector(int n, Map<Integer, Double> kv) {
-     this.n = n;
-     int nnz = kv.size();
-     int[] indices = new int[nnz];
-     double[] values = new double[nnz];
-
-     int pos = 0;
-     for (Map.Entry<Integer, Double> entry : kv.entrySet()) {
-         indices[pos] = entry.getKey();
-         values[pos] = entry.getValue();
-         pos++;
-     }
-
-     this.indices = indices;
-     this.values = values;
-     checkSizeAndIndicesRange();
-
-     // sortIndices() first scans for an out-of-order pair and does nothing when the indices are
-     // already ascending, so calling it unconditionally is cheap and no longer trusts the map
-     // type to guarantee ascending order.
-     sortIndices();
- }
-
- /**
-  * Validate the stored arrays: both must have the same length, no index may be negative, and
-  * when the size is determined (n >= 0) every index must be smaller than n.
-  *
-  * @throws IllegalArgumentException if any of the conditions above is violated.
-  */
- private void checkSizeAndIndicesRange() {
-     if (indices.length != values.length) {
-         throw new IllegalArgumentException("Indices size and values size should be the same.");
-     }
-     for (int index : indices) {
-         final boolean beyondSize = n >= 0 && index >= n;
-         if (index < 0 || beyondSize) {
-             throw new IllegalArgumentException("Index out of bound.");
-         }
-     }
- }
-
- /**
-  * Sort {@code indices[low..high]} ascending with a recursive quick sort, moving the entries of
-  * {@code values} along with their indices.
-  *
-  * <p>The last element of the range is the pivot. Elements not greater than the pivot are
-  * gathered at the front of the range, then the pivot is placed right behind them and both
-  * sides are sorted recursively.
-  */
- private static void sortImpl(int[] indices, double[] values, int low, int high) {
-     final int pivot = indices[high];
-     // Next slot that receives an element not greater than the pivot.
-     int boundary = low;
-     for (int scan = low; scan < high; scan++) {
-         if (indices[scan] <= pivot) {
-             swapEntries(indices, values, boundary, scan);
-             boundary++;
-         }
-     }
-     // Move the pivot to its final position.
-     swapEntries(indices, values, boundary, high);
-
-     if (boundary - 1 > low) {
-         sortImpl(indices, values, low, boundary - 1);
-     }
-     if (high > boundary + 1) {
-         sortImpl(indices, values, boundary + 1, high);
-     }
- }
-
- /** Exchange the index/value pairs stored at positions {@code a} and {@code b}. */
- private static void swapEntries(int[] indices, double[] values, int a, int b) {
-     final int keptIndex = indices[a];
-     final double keptValue = values[a];
-     indices[a] = indices[b];
-     values[a] = values[b];
-     indices[b] = keptIndex;
-     values[b] = keptValue;
- }
-
- /**
-  * Sort the indices (and their values) when at least one adjacent pair of indices is in
-  * descending order; otherwise leave the arrays untouched.
-  */
- private void sortIndices() {
-     for (int i = 1; i < this.indices.length; i++) {
-         if (this.indices[i - 1] > this.indices[i]) {
-             sortImpl(this.indices, this.values, 0, this.indices.length - 1);
-             return;
-         }
-     }
- }
-
- /** Create a copy with the same size and with copies of the indices and values arrays. */
- @Override
- public SparseVector clone() {
-     final SparseVector copy = new SparseVector(this.n);
-     copy.values = Arrays.copyOf(this.values, this.values.length);
-     copy.indices = Arrays.copyOf(this.indices, this.indices.length);
-     return copy;
- }
-
- /**
-  * Create a new vector with {@code d} stored at index 0 and every existing index shifted up by
-  * one. A determined size grows by one; an undetermined size stays undetermined.
-  */
- @Override
- public SparseVector prefix(double d) {
-     final int stored = this.indices.length;
-     final int[] shiftedIndices = new int[stored + 1];
-     final double[] shiftedValues = new double[this.values.length + 1];
-
-     shiftedIndices[0] = 0;
-     shiftedValues[0] = d;
-     for (int k = 0; k < stored; k++) {
-         shiftedIndices[k + 1] = this.indices[k] + 1;
-     }
-     System.arraycopy(this.values, 0, shiftedValues, 1, stored);
-
-     final int newSize = (this.n >= 0) ? this.n + 1 : this.n;
-     return new SparseVector(newSize, shiftedIndices, shiftedValues);
- }
-
- /**
-  * Create a new vector with {@code d} added behind the last position.
-  *
-  * <p>With a determined size {@code n} the new element gets index {@code n} and the result has
-  * size {@code n + 1}. With an undetermined size the new element is placed right after the
-  * largest stored index (index 0 for an empty vector) and the size stays undetermined. This
-  * case used to compute the index as {@code n - 1 = -2} and always failed with "Index out of
-  * bound.".
-  *
-  * @param d the value to append.
-  * @return a new vector; this vector is not modified.
-  */
- @Override
- public SparseVector append(double d) {
-     final int stored = this.indices.length;
-     final int[] newIndices = Arrays.copyOf(this.indices, stored + 1);
-     final double[] newValues = Arrays.copyOf(this.values, this.values.length + 1);
-
-     final int newSize;
-     final int newIndex;
-     if (this.n >= 0) {
-         newIndex = this.n;
-         newSize = this.n + 1;
-     } else {
-         // Undetermined size: the indices are kept ascending, so the last one is the largest.
-         newIndex = (stored == 0) ? 0 : this.indices[stored - 1] + 1;
-         newSize = this.n;
-     }
-
-     newIndices[stored] = newIndex;
-     newValues[this.values.length] = d;
-
-     return new SparseVector(newSize, newIndices, newValues);
- }
-
- /** Get the indices array (the internal array itself, not a copy). */
- public int[] getIndices() {
-     return this.indices;
- }
-
- /** Get the values array (the internal array itself, not a copy). */
- public double[] getValues() {
-     return this.values;
- }
-
- /** Get the size of the vector; -1 means the size is undetermined. */
- @Override
- public int size() {
-     return this.n;
- }
-
- /**
-  * Get the element at index {@code i}, or 0 when no value is stored for that index. The lookup
-  * is a binary search over the indices array.
-  */
- @Override
- public double get(int i) {
-     final int found = Arrays.binarySearch(this.indices, i);
-     return (found >= 0) ? this.values[found] : 0.;
- }
-
- /** Set the size of the vector. */
- public void setSize(int n) {
-     this.n = n;
- }
-
- /** Get number of values stored in this vector. */
- public int numberOfValues() {
-     return values.length;
- }
-
- /**
-  * Set the element at index {@code i} to {@code val}: overwrite the stored value when the
-  * index is already present, otherwise insert a new entry at its sorted position.
-  */
- @Override
- public void set(int i, double val) {
-     final int found = Arrays.binarySearch(indices, i);
-     if (found < 0) {
-         // binarySearch returns -(insertionPoint) - 1 for a missing key.
-         insert(-(found + 1), i, val);
-         return;
-     }
-     this.values[found] = val;
- }
-
- /**
-  * Add {@code val} to the element at index {@code i}: increase the stored value when the index
-  * is already present, otherwise insert {@code val} as a new entry at its sorted position.
-  */
- @Override
- public void add(int i, double val) {
-     final int found = Arrays.binarySearch(indices, i);
-     if (found < 0) {
-         // binarySearch returns -(insertionPoint) - 1 for a missing key.
-         insert(-(found + 1), i, val);
-         return;
-     }
-     this.values[found] += val;
- }
-
- /**
-  * Insert value "val" in the position "pos" with index "index". Both arrays are replaced by
-  * new arrays that are one element longer; entries from "pos" onwards move up by one slot.
-  */
- private void insert(int pos, int index, double val) {
-     final int oldLength = this.values.length;
-     final double[] grownValues = Arrays.copyOf(this.values, oldLength + 1);
-     final int[] grownIndices = Arrays.copyOf(this.indices, oldLength + 1);
-
-     // Shift the tail by one slot to open a gap at "pos".
-     System.arraycopy(this.values, pos, grownValues, pos + 1, oldLength - pos);
-     System.arraycopy(this.indices, pos, grownIndices, pos + 1, oldLength - pos);
-     grownValues[pos] = val;
-     grownIndices[pos] = index;
-
-     this.values = grownValues;
-     this.indices = grownIndices;
- }
-
- /** Format this vector as a string via {@code VectorUtil.toString}. */
- @Override
- public String toString() {
-     final String text = VectorUtil.toString(this);
-     return text;
- }
-
- /** L2 norm: the square root of the sum of squares of the stored values. */
- @Override
- public double normL2() {
-     double sumOfSquares = 0;
-     for (int k = 0; k < values.length; k++) {
-         sumOfSquares += values[k] * values[k];
-     }
-     return Math.sqrt(sumOfSquares);
- }
-
- /** L1 norm: the sum of the absolute stored values. */
- @Override
- public double normL1() {
-     double sumOfAbs = 0;
-     for (int k = 0; k < values.length; k++) {
-         sumOfAbs += Math.abs(values[k]);
-     }
-     return sumOfAbs;
- }
-
- /** Inf norm: the largest absolute stored value, or 0 when nothing is stored. */
- @Override
- public double normInf() {
-     double largest = 0;
-     for (int k = 0; k < values.length; k++) {
-         largest = Math.max(Math.abs(values[k]), largest);
-     }
-     return largest;
- }
-
- /** Square of the L2 norm: the sum of squares of the stored values. */
- @Override
- public double normL2Square() {
-     double sumOfSquares = 0;
-     for (int k = 0; k < values.length; k++) {
-         sumOfSquares += values[k] * values[k];
-     }
-     return sumOfSquares;
- }
-
- /**
-  * Create a new vector of size {@code indices.length} whose i-th element is the element of
-  * this vector at {@code indices[i]}. Only positions for which this vector stores a value get
-  * an entry in the result.
-  */
- @Override
- public SparseVector slice(int[] indices) {
-     final int requested = indices.length;
-     int[] pickedIndices = new int[requested];
-     double[] pickedValues = new double[requested];
-
-     int picked = 0;
-     for (int i = 0; i < requested; i++) {
-         final int found = Arrays.binarySearch(this.indices, indices[i]);
-         if (found < 0) {
-             continue;
-         }
-         pickedIndices[picked] = i;
-         pickedValues[picked] = this.values[found];
-         picked++;
-     }
-
-     // Trim the buffers when some requested positions had no stored value.
-     if (picked < requested) {
-         pickedIndices = Arrays.copyOf(pickedIndices, picked);
-         pickedValues = Arrays.copyOf(pickedValues, picked);
-     }
-
-     final SparseVector sliced = new SparseVector(requested);
-     sliced.indices = pickedIndices;
-     sliced.values = pickedValues;
-     return sliced;
- }
-
- /**
-  * Create a new vector holding {@code this + vec}. The result is a {@link DenseVector} when
-  * {@code vec} is dense, otherwise it is built by {@code MatVecOp.apply}.
-  *
-  * @throws IllegalArgumentException if the two vectors report different sizes.
-  */
- @Override
- public Vector plus(Vector vec) {
-     if (this.size() != vec.size()) {
-         throw new IllegalArgumentException("The size of the two vectors are different.");
-     }
-     if (!(vec instanceof DenseVector)) {
-         return MatVecOp.apply(this, (SparseVector) vec, Double::sum);
-     }
-
-     // Start from a copy of the dense operand and add the stored entries of this vector.
-     final DenseVector sum = ((DenseVector) vec).clone();
-     for (int k = 0; k < this.indices.length; k++) {
-         sum.add(this.indices[k], this.values[k]);
-     }
-     return sum;
- }
-
- /**
-  * Create a new vector holding {@code this - vec}. The result is a {@link DenseVector} when
-  * {@code vec} is dense, otherwise it is built by {@code MatVecOp.apply}.
-  *
-  * @throws IllegalArgumentException if the two vectors report different sizes.
-  */
- @Override
- public Vector minus(Vector vec) {
-     if (this.size() != vec.size()) {
-         throw new IllegalArgumentException("The size of the two vectors are different.");
-     }
-     if (!(vec instanceof DenseVector)) {
-         return MatVecOp.apply(this, (SparseVector) vec, ((a, b) -> a - b));
-     }
-
-     // Start from the dense operand scaled by -1 and add the stored entries of this vector.
-     final DenseVector difference = ((DenseVector) vec).scale(-1.0);
-     for (int k = 0; k < this.indices.length; k++) {
-         difference.add(this.indices[k], this.values[k]);
-     }
-     return difference;
- }
-
- /** Create a copy of this vector and pass it to {@code BLAS.scal} with factor {@code d}. */
- @Override
- public SparseVector scale(double d) {
-     final SparseVector scaled = clone();
-     BLAS.scal(d, scaled);
-     return scaled;
- }
-
- /** Pass this vector itself to {@code BLAS.scal} with factor {@code d}. */
- @Override
- public void scaleEqual(double d) {
-     BLAS.scal(d, this);
- }
-
- /**
-  * Drop every stored entry whose value equals zero. The indices and values arrays are replaced
-  * by new, compacted arrays; a vector that stores nothing is left untouched.
-  */
- public void removeZeroValues() {
-     if (this.values.length == 0) {
-         return;
-     }
-
-     // Positions of the entries to keep, in their original order.
-     final List<Integer> keptPositions = new ArrayList<>();
-     for (int pos = 0; pos < values.length; pos++) {
-         if (0 != values[pos]) {
-             keptPositions.add(pos);
-         }
-     }
-
-     final int kept = keptPositions.size();
-     final int[] compactIndices = new int[kept];
-     final double[] compactValues = new double[kept];
-     int out = 0;
-     for (int pos : keptPositions) {
-         compactIndices[out] = indices[pos];
-         compactValues[out] = values[pos];
-         out++;
-     }
-     this.indices = compactIndices;
-     this.values = compactValues;
- }
-
- /**
-  * Dot product with another sparse vector: both index arrays are walked in step and only
-  * indices stored in both vectors contribute.
-  *
-  * @throws RuntimeException if the two vectors report different sizes.
-  */
- private double dot(SparseVector other) {
-     if (this.size() != other.size()) {
-         throw new RuntimeException("the size of the two vectors are different");
-     }
-
-     final int mine = this.values.length;
-     final int theirs = other.values.length;
-     double total = 0;
-     for (int a = 0, b = 0; a < mine && b < theirs; ) {
-         final int left = this.indices[a];
-         final int right = other.indices[b];
-         if (left == right) {
-             total += this.values[a] * other.values[b];
-             a++;
-             b++;
-         } else if (left < right) {
-             a++;
-         } else {
-             b++;
-         }
-     }
-     return total;
- }
-
- /**
-  * Dot product with a dense vector: every stored value is multiplied with the dense element at
-  * the same index.
-  *
-  * @throws RuntimeException if the two vectors report different sizes.
-  */
- private double dot(DenseVector other) {
-     final int mySize = this.size();
-     final int otherSize = other.size();
-     if (mySize != otherSize) {
-         throw new RuntimeException(
-                 "The size of the two vectors are different: " + mySize + " vs " + otherSize);
-     }
-     double total = 0.;
-     for (int k = 0; k < this.indices.length; k++) {
-         total += this.values[k] * other.get(this.indices[k]);
-     }
-     return total;
- }
-
- /**
-  * Compute the dot product with another vector; anything that is not a {@link DenseVector} is
-  * cast to {@link SparseVector}.
-  */
- @Override
- public double dot(Vector other) {
-     if (!(other instanceof DenseVector)) {
-         return dot((SparseVector) other);
-     }
-     return dot((DenseVector) other);
- }
-
- /** Compute the outer product of this vector with itself. */
- @Override
- public DenseMatrix outer() {
-     return outer(this);
- }
-
- /**
-  * Compute the outer product with another vector.
-  *
-  * <p>The product of a stored pair is written at offset {@code rowIndex + colIndex * nrows} of
-  * the backing array; all other entries keep the array's initial value 0.
-  *
-  * @return The outer product matrix.
-  */
- public DenseMatrix outer(SparseVector other) {
-     final int rowCount = this.size();
-     final int colCount = other.size();
-     final double[] product = new double[colCount * rowCount];
-     for (int a = 0; a < this.values.length; a++) {
-         final int row = this.indices[a];
-         for (int b = 0; b < other.values.length; b++) {
-             final int col = other.indices[b];
-             product[row + col * rowCount] = this.values[a] * other.values[b];
-         }
-     }
-     return new DenseMatrix(rowCount, colCount, product);
- }
-
- /**
-  * Convert to a dense vector.
-  *
-  * <p>With a determined size the result has that size. With an undetermined size the result is
-  * just large enough to hold the last stored index, or is created by the no-argument
-  * {@code DenseVector} constructor when nothing is stored.
-  */
- public DenseVector toDenseVector() {
-     final int stored = this.indices.length;
-     if (n < 0 && stored == 0) {
-         return new DenseVector();
-     }
-
-     final int denseSize = (n >= 0) ? n : this.indices[stored - 1] + 1;
-     final DenseVector dense = new DenseVector(denseSize);
-     for (int k = 0; k < stored; k++) {
-         dense.set(this.indices[k], this.values[k]);
-     }
-     return dense;
- }
-
- /**
-  * Standardize the stored values in place: subtract {@code mean}, then multiply by
-  * {@code 1.0 / stdvar}. Positions without a stored value are not touched.
-  */
- @Override
- public void standardizeEqual(double mean, double stdvar) {
-     for (int k = 0; k < indices.length; k++) {
-         values[k] = (values[k] - mean) * (1.0 / stdvar);
-     }
- }
-
- /**
-  * Normalize the stored values in place by dividing each of them by the p-norm of the vector.
-  *
-  * <p>An infinite {@code p} selects {@link #normInf()}, 1.0 selects {@link #normL1()} and 2.0
-  * selects {@link #normL2()}. Any other {@code p} uses {@code (sum |v_i|^p)^(1/p)}.
-  *
-  * <p>There is no guard against a zero norm; in that case the division follows the usual
-  * floating point rules (NaN or infinity).
-  *
-  * @param p the order of the norm.
-  */
- @Override
- public void normalizeEqual(double p) {
-     double norm = 0.0;
-     if (Double.isInfinite(p)) {
-         norm = normInf();
-     } else if (p == 1.0) {
-         norm = normL1();
-     } else if (p == 2.0) {
-         norm = normL2();
-     } else {
-         for (int i = 0; i < indices.length; i++) {
-             // The p-norm is defined on absolute values. Without Math.abs a negative entry
-             // gives NaN for a non-integer p and a negative term for an odd p.
-             norm += Math.pow(Math.abs(values[i]), p);
-         }
-         norm = Math.pow(norm, 1 / p);
-     }
-
-     for (int i = 0; i < indices.length; i++) {
-         values[i] /= norm;
-     }
- }
-
- /**
-  * Two sparse vectors are equal when they have exactly the same class, the same size and the
-  * same indices and values arrays (compared element by element).
-  */
- @Override
- public boolean equals(Object o) {
-     if (o == this) {
-         return true;
-     }
-     if (o == null || o.getClass() != getClass()) {
-         return false;
-     }
-     final SparseVector other = (SparseVector) o;
-     if (this.n != other.n) {
-         return false;
-     }
-     return Arrays.equals(this.indices, other.indices)
-             && Arrays.equals(this.values, other.values);
- }
-
- /** Hash code combining the size, the indices and the values with the factor 31. */
- @Override
- public int hashCode() {
-     final int sizeHash = Objects.hash(n);
-     final int withIndices = 31 * sizeHash + Arrays.hashCode(indices);
-     return 31 * withIndices + Arrays.hashCode(values);
- }
-
- /** Create an iterator positioned at the first stored entry of this vector. */
- @Override
- public VectorIterator iterator() {
-     return new SparseVectorVectorIterator();
- }
-
- /**
-  * Iterator over the stored entries of the enclosing vector, in the order of the indices and
-  * values arrays.
-  */
- private class SparseVectorVectorIterator implements VectorIterator {
-     private int cursor = 0;
-
-     @Override
-     public boolean hasNext() {
-         return cursor < values.length;
-     }
-
-     @Override
-     public void next() {
-         cursor++;
-     }
-
-     @Override
-     public int getIndex() {
-         return indices[checkedCursor()];
-     }
-
-     @Override
-     public double getValue() {
-         return values[checkedCursor()];
-     }
-
-     /** Return the cursor, failing if it has moved past the last stored entry. */
-     private int checkedCursor() {
-         if (cursor < values.length) {
-             return cursor;
-         }
-         throw new RuntimeException("Iterator out of bound.");
-     }
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/Vector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/Vector.java
deleted file mode 100644
index bf4e329..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/Vector.java
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import java.io.Serializable;
-
-/**
- * The Vector class defines some common methods for both DenseVector and SparseVector.
- *
- * <p>Methods whose name ends in "Equal" return nothing and are meant to update the vector
- * itself; the others that return a {@code Vector} are documented as creating a new vector.
- */
-public abstract class Vector implements Serializable {
- /** Get the size of the vector. */
- public abstract int size();
-
- /** Get the i-th element of the vector. */
- public abstract double get(int i);
-
- /** Set the i-th element of the vector to value "val". */
- public abstract void set(int i, double val);
-
- /** Add value "val" to the i-th element of the vector. */
- public abstract void add(int i, double val);
-
- /** Return the L1 norm of the vector. */
- public abstract double normL1();
-
- /** Return the Inf norm of the vector. */
- public abstract double normInf();
-
- /** Return the L2 norm of the vector. */
- public abstract double normL2();
-
- /** Return the square of L2 norm of the vector. */
- public abstract double normL2Square();
-
- /** Scale the vector by value "v" and create a new vector to store the result. */
- public abstract Vector scale(double v);
-
- /** Scale the vector by value "v". */
- public abstract void scaleEqual(double v);
-
- /**
- * Normalize the vector.
- *
- * @param p the order of the norm used for normalization.
- */
- public abstract void normalizeEqual(double p);
-
- /**
- * Standardize the vector.
- *
- * @param mean the mean used for standardization.
- * @param stdvar the deviation used for standardization (presumably the standard deviation,
- *     going by the name; confirm against the implementations).
- */
- public abstract void standardizeEqual(double mean, double stdvar);
-
- /** Create a new vector by adding an element to the head of the vector. */
- public abstract Vector prefix(double v);
-
- /** Create a new vector by adding an element to the end of the vector. */
- public abstract Vector append(double v);
-
- /** Create a new vector by adding another vector. */
- public abstract Vector plus(Vector vec);
-
- /** Create a new vector by subtracting another vector. */
- public abstract Vector minus(Vector vec);
-
- /** Compute the dot product with another vector. */
- public abstract double dot(Vector vec);
-
- /** Get the iterator of the vector. */
- public abstract VectorIterator iterator();
-
- /**
- * Slice the vector.
- *
- * @param indexes the positions of this vector to pick, in the order they should appear in
- *     the result.
- */
- public abstract Vector slice(int[] indexes);
-
- /**
- * Compute the outer product with itself.
- *
- * @return The outer product matrix.
- */
- public abstract DenseMatrix outer();
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorIterator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorIterator.java
deleted file mode 100644
index 422b7a4..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorIterator.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import java.io.Serializable;
-import java.util.Iterator;
-
-/**
- * An iterator over the elements of a vector.
- *
- * <p>Usage:
- *
- * <pre>{@code
- * Vector vector = ...;
- * VectorIterator iterator = vector.iterator();
- *
- * while (iterator.hasNext()) {
- *     int index = iterator.getIndex();
- *     double value = iterator.getValue();
- *     iterator.next();
- * }
- * }</pre>
- */
-public interface VectorIterator extends Serializable {
-
- /**
- * Returns {@code true} if the iteration has more elements. Otherwise, {@code false} will be
- * returned.
- *
- * @return {@code true} if the iteration has more elements
- */
- boolean hasNext();
-
- /**
- * Moves the cursor to the next element of the vector.
- *
- * <p>Afterwards {@link #getIndex()} will return the index of the element the cursor points
- * to, and {@link #getValue()} will return its value.
- *
- * <p>The difference to {@link Iterator#next()} is that this method returns nothing, which
- * avoids returning a boxed type.
- */
- void next();
-
- /**
- * Returns the index of the element the cursor points to.
- *
- * @return the index of the element the cursor points to.
- */
- int getIndex();
-
- /**
- * Returns the value of the element the cursor points to.
- *
- * @return the value of the element the cursor points to.
- */
- double getValue();
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorUtil.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorUtil.java
deleted file mode 100644
index 7e328ca..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/VectorUtil.java
+++ /dev/null
@@ -1,240 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import org.apache.commons.lang3.StringUtils;
-
-/** Utility class for the operations on {@link Vector} and its subclasses. */
-public class VectorUtil {
- /** Delimiter between elements. */
- private static final char ELEMENT_DELIMITER = ' ';
- /** Delimiter between vector size and vector data. */
- private static final char HEADER_DELIMITER = '$';
- /** Delimiter between index and value. */
- private static final char INDEX_VALUE_DELIMITER = ':';
-
- /**
- * Parse either a {@link SparseVector} or a {@link DenseVector} from a formatted string.
- *
- * <p>The format of a dense vector is space separated values such as "1 2 3 4". The format of a
- * sparse vector is space separated index-value pairs, such as "0:1 2:3 3:4". If the sparse
- * vector has determined vector size, the size is prepended to the head. For example, the string
- * "$4$0:1 2:3 3:4" represents a sparse vector with size 4.
- *
- * @param str A formatted string representing a vector.
- * @return The parsed vector.
- */
- public static Vector parse(String str) {
- boolean isSparse =
- org.apache.flink.util.StringUtils.isNullOrWhitespaceOnly(str)
- || StringUtils.indexOf(str, INDEX_VALUE_DELIMITER) != -1
- || StringUtils.indexOf(str, HEADER_DELIMITER) != -1;
- if (isSparse) {
- return parseSparse(str);
- } else {
- return parseDense(str);
- }
- }
-
- /**
- * Parse the dense vector from a formatted string.
- *
- * <p>The format of a dense vector is space separated values such as "1 2 3 4".
- *
- * @param str A string of space separated values.
- * @return The parsed vector.
- */
- public static DenseVector parseDense(String str) {
- if (org.apache.flink.util.StringUtils.isNullOrWhitespaceOnly(str)) {
- return new DenseVector();
- }
-
- int len = str.length();
-
- int inDataBuffPos = 0;
- boolean isInBuff = false;
-
- for (int i = 0; i < len; ++i) {
- char c = str.charAt(i);
-
- if (c == ELEMENT_DELIMITER
- // to be compatible with previous delimiter
- || c == ',') {
- if (isInBuff) {
- inDataBuffPos++;
- }
-
- isInBuff = false;
- } else {
- isInBuff = true;
- }
- }
-
- if (isInBuff) {
- inDataBuffPos++;
- }
-
- double[] data = new double[inDataBuffPos];
- int lastestInCharBuffPos = 0;
-
- inDataBuffPos = 0;
- isInBuff = false;
-
- for (int i = 0; i < len; ++i) {
- char c = str.charAt(i);
-
- if (c == ELEMENT_DELIMITER) {
- if (isInBuff) {
- data[inDataBuffPos++] =
- Double.parseDouble(
- StringUtils.substring(str, lastestInCharBuffPos, i).trim());
-
- lastestInCharBuffPos = i + 1;
- }
-
- isInBuff = false;
- } else {
- isInBuff = true;
- }
- }
-
- if (isInBuff) {
- data[inDataBuffPos] =
- Double.valueOf(StringUtils.substring(str, lastestInCharBuffPos).trim());
- }
-
- return new DenseVector(data);
- }
-
- /**
- * Parse the sparse vector from a formatted string.
- *
- * <p>The format of a sparse vector is space separated index-value pairs, such as "0:1 2:3 3:4".
- * If the sparse vector has determined vector size, the size is prepended to the head. For
- * example, the string "$4$0:1 2:3 3:4" represents a sparse vector with size 4.
- *
- * @param str A formatted string representing a sparse vector.
- * @throws IllegalArgumentException If the string is of invalid format.
- */
- public static SparseVector parseSparse(String str) {
- try {
- if (org.apache.flink.util.StringUtils.isNullOrWhitespaceOnly(str)) {
- return new SparseVector();
- }
-
- int n = -1;
- int firstDollarPos = str.indexOf(HEADER_DELIMITER);
- int lastDollarPos = -1;
- if (firstDollarPos >= 0) {
- lastDollarPos = StringUtils.lastIndexOf(str, HEADER_DELIMITER);
- String sizeStr = StringUtils.substring(str, firstDollarPos + 1, lastDollarPos);
- n = Integer.valueOf(sizeStr);
- if (lastDollarPos == str.length() - 1) {
- return new SparseVector(n);
- }
- }
-
- int numValues = StringUtils.countMatches(str, String.valueOf(INDEX_VALUE_DELIMITER));
- double[] data = new double[numValues];
- int[] indices = new int[numValues];
- int startPos = lastDollarPos + 1;
- int endPos;
- for (int i = 0; i < numValues; i++) {
- int colonPos = StringUtils.indexOf(str, INDEX_VALUE_DELIMITER, startPos);
- if (colonPos < 0) {
- throw new IllegalArgumentException("Format error.");
- }
- endPos = StringUtils.indexOf(str, ELEMENT_DELIMITER, colonPos);
-
- if (endPos == -1) {
- endPos = str.length();
- }
- indices[i] = Integer.valueOf(str.substring(startPos, colonPos).trim());
- data[i] = Double.valueOf(str.substring(colonPos + 1, endPos).trim());
- startPos = endPos + 1;
- }
- return new SparseVector(n, indices, data);
- } catch (Exception e) {
- throw new IllegalArgumentException(
- String.format("Fail to getVector sparse vector from string: \"%s\".", str), e);
- }
- }
-
- /**
- * Serialize the vector to a string.
- *
- * @param vector The vector to serialize.
- * @see #toString(DenseVector)
- * @see #toString(SparseVector)
- */
- public static String toString(Vector vector) {
- if (vector instanceof SparseVector) {
- return toString((SparseVector) vector);
- }
- return toString((DenseVector) vector);
- }
-
- /**
- * Serialize the SparseVector to string.
- *
- * <p>The format of the returned string is described at {@link #parseSparse(String)}
- *
- * @param sparseVector The sparse vector to serialize.
- */
- public static String toString(SparseVector sparseVector) {
- StringBuilder sbd = new StringBuilder();
- if (sparseVector.n > 0) {
- sbd.append(HEADER_DELIMITER);
- sbd.append(sparseVector.n);
- sbd.append(HEADER_DELIMITER);
- }
- if (null != sparseVector.indices) {
- for (int i = 0; i < sparseVector.indices.length; i++) {
- sbd.append(sparseVector.indices[i]);
- sbd.append(INDEX_VALUE_DELIMITER);
- sbd.append(sparseVector.values[i]);
- if (i < sparseVector.indices.length - 1) {
- sbd.append(ELEMENT_DELIMITER);
- }
- }
- }
-
- return sbd.toString();
- }
-
- /**
- * Serialize the DenseVector to String.
- *
- * <p>The format of the returned string is described at {@link #parseDense(String)}
- *
- * @param denseVector The DenseVector to serialize.
- */
- public static String toString(DenseVector denseVector) {
- StringBuilder sbd = new StringBuilder();
-
- for (int i = 0; i < denseVector.data.length; i++) {
- sbd.append(denseVector.data[i]);
- if (i < denseVector.data.length - 1) {
- sbd.append(ELEMENT_DELIMITER);
- }
- }
- return sbd.toString();
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/Mapper.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/Mapper.java
deleted file mode 100644
index 4cce10b..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/Mapper.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.mapper;
-
-import org.apache.flink.ml.api.misc.param.Params;
-import org.apache.flink.table.api.TableSchema;
-import org.apache.flink.table.types.DataType;
-import org.apache.flink.types.Row;
-
-import java.io.Serializable;
-
-/**
- * Abstract class for mappers. A mapper takes one row as input and transforms it into another row.
- */
-public abstract class Mapper implements Serializable {
-
- /** Schema of the input rows. */
- private final String[] dataFieldNames;
-
- private final DataType[] dataFieldTypes;
-
- /** Parameters for the Mapper. Users can set the params before the Mapper is executed. */
- protected final Params params;
-
- /**
- * Construct a Mapper.
- *
- * @param dataSchema The schema of input rows.
- * @param params The parameters for this mapper.
- */
- public Mapper(TableSchema dataSchema, Params params) {
- this.dataFieldNames = dataSchema.getFieldNames();
- this.dataFieldTypes = dataSchema.getFieldDataTypes();
- this.params = (null == params) ? new Params() : params.clone();
- }
-
- /**
- * Get the schema of input rows.
- *
- * @return The schema of input rows.
- */
- protected TableSchema getDataSchema() {
- return TableSchema.builder().fields(dataFieldNames, dataFieldTypes).build();
- }
-
- /**
- * Map a row to a new row.
- *
- * @param row The input row.
- * @return A new row.
- * @throws Exception This method may throw exceptions. Throwing an exception will cause the
- * operation to fail.
- */
- public abstract Row map(Row row) throws Exception;
-
- /**
- * Get the schema of the output rows of {@link #map(Row)} method.
- *
- * @return The table schema of the output rows of {@link #map(Row)} method.
- */
- public abstract TableSchema getOutputSchema();
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/MapperAdapter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/MapperAdapter.java
deleted file mode 100644
index d804302..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/MapperAdapter.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.mapper;
-
-import org.apache.flink.api.common.functions.MapFunction;
-import org.apache.flink.types.Row;
-
-/**
- * A class that helps adapt a {@link Mapper} to a {@link MapFunction} so that the mapper can run in
- * Flink.
- */
-public class MapperAdapter implements MapFunction<Row, Row> {
-
- private final Mapper mapper;
-
- /**
- * Construct a MapperAdapter with the given mapper.
- *
- * @param mapper The {@link Mapper} to adapt.
- */
- public MapperAdapter(Mapper mapper) {
- this.mapper = mapper;
- }
-
- @Override
- public Row map(Row row) throws Exception {
- return this.mapper.map(row);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/ModelMapper.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/ModelMapper.java
deleted file mode 100644
index 772e6d6..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/ModelMapper.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.mapper;
-
-import org.apache.flink.ml.api.misc.param.Params;
-import org.apache.flink.table.api.TableSchema;
-import org.apache.flink.table.types.DataType;
-import org.apache.flink.types.Row;
-
-import java.util.List;
-
-/** An abstract class for {@link Mapper Mappers} with a model. */
-public abstract class ModelMapper extends Mapper {
-
- /** Field names of the model rows. */
- private final String[] modelFieldNames;
-
- /** Field types of the model rows. */
- private final DataType[] modelFieldTypes;
-
- /**
- * Constructs a ModelMapper.
- *
- * @param modelSchema The schema of the model rows passed to {@link #loadModel(List)}.
- * @param dataSchema The schema of the input data rows.
- * @param params The parameters of this ModelMapper.
- */
- public ModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
- super(dataSchema, params);
- this.modelFieldNames = modelSchema.getFieldNames();
- this.modelFieldTypes = modelSchema.getFieldDataTypes();
- }
-
- /**
- * Get the schema of the model rows that are passed to {@link #loadModel(List)}.
- *
- * @return The schema of the model rows.
- */
- protected TableSchema getModelSchema() {
- return TableSchema.builder().fields(this.modelFieldNames, this.modelFieldTypes).build();
- }
-
- /**
- * Load the model from the list of rows.
- *
- * @param modelRows The list of rows containing the model.
- */
- public abstract void loadModel(List<Row> modelRows);
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/ModelMapperAdapter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/ModelMapperAdapter.java
deleted file mode 100644
index 821b383..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/mapper/ModelMapperAdapter.java
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.mapper;
-
-import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.ml.common.model.ModelSource;
-import org.apache.flink.types.Row;
-
-import java.util.List;
-
-/**
- * A class that adapts a {@link ModelMapper} to a Flink {@link RichMapFunction} so the model can be
- * loaded in a Flink job.
- *
- * <p>This adapter class holds the target {@link ModelMapper} and its {@link ModelSource}. Upon
- * open(), it will load model rows from {@link ModelSource} into {@link ModelMapper}.
- */
-public class ModelMapperAdapter extends RichMapFunction<Row, Row> {
-
- private final ModelMapper mapper;
- private final ModelSource modelSource;
-
- /**
- * Construct a ModelMapperAdapter with the given ModelMapper and ModelSource.
- *
- * @param mapper The {@link ModelMapper} to adapt.
- * @param modelSource The {@link ModelSource} to load the model from.
- */
- public ModelMapperAdapter(ModelMapper mapper, ModelSource modelSource) {
- this.mapper = mapper;
- this.modelSource = modelSource;
- }
-
- @Override
- public void open(Configuration parameters) throws Exception {
- List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext());
- this.mapper.loadModel(modelRows);
- }
-
- @Override
- public Row map(Row row) throws Exception {
- return this.mapper.map(row);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/model/BroadcastVariableModelSource.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/model/BroadcastVariableModelSource.java
deleted file mode 100644
index 0e552ae..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/model/BroadcastVariableModelSource.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.model;
-
-import org.apache.flink.api.common.functions.RuntimeContext;
-import org.apache.flink.types.Row;
-
-import java.util.List;
-
-/** A {@link ModelSource} implementation that reads the model from the broadcast variable. */
-public class BroadcastVariableModelSource implements ModelSource {
-
- /** The name of the broadcast variable that hosts the model. */
- private final String modelVariableName;
-
- /**
- * Construct a BroadcastVariableModelSource.
- *
- * @param modelVariableName The name of the broadcast variable that hosts the model rows
- * returned by this source.
- */
- public BroadcastVariableModelSource(String modelVariableName) {
- this.modelVariableName = modelVariableName;
- }
-
- @Override
- public List<Row> getModelRows(RuntimeContext runtimeContext) {
- return runtimeContext.getBroadcastVariable(modelVariableName);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/model/ModelSource.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/model/ModelSource.java
deleted file mode 100644
index 960f9aa..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/model/ModelSource.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.model;
-
-import org.apache.flink.api.common.functions.RuntimeContext;
-import org.apache.flink.types.Row;
-
-import java.io.Serializable;
-import java.util.List;
-
-/**
- * An interface for loading the model from different sources, e.g. broadcast variables, a list
- * of rows, etc.
- */
-public interface ModelSource extends Serializable {
-
- /**
- * Get the rows containing the model.
- *
- * @return the rows containing the model.
- */
- List<Row> getModelRows(RuntimeContext runtimeContext);
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/model/RowsModelSource.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/model/RowsModelSource.java
deleted file mode 100644
index d314a70..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/model/RowsModelSource.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.model;
-
-import org.apache.flink.api.common.functions.RuntimeContext;
-import org.apache.flink.types.Row;
-
-import java.util.List;
-
-/** A {@link ModelSource} implementation that reads the model from the memory. */
-public class RowsModelSource implements ModelSource {
-
- /** The rows that host the model. */
- private final List<Row> modelRows;
-
- /**
- * Construct a RowsModelSource with the given rows containing a model.
- *
- * @param modelRows The rows that contain a model.
- */
- public RowsModelSource(List<Row> modelRows) {
- this.modelRows = modelRows;
- }
-
- @Override
- public List<Row> getModelRows(RuntimeContext runtimeContext) {
- return modelRows;
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/statistics/basicstatistic/MultivariateGaussian.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/statistics/basicstatistic/MultivariateGaussian.java
deleted file mode 100644
index bf3a7b9..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/statistics/basicstatistic/MultivariateGaussian.java
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.statistics.basicstatistic;
-
-import org.apache.flink.ml.common.linalg.BLAS;
-import org.apache.flink.ml.common.linalg.DenseMatrix;
-import org.apache.flink.ml.common.linalg.DenseVector;
-import org.apache.flink.ml.common.linalg.SparseVector;
-import org.apache.flink.ml.common.linalg.Vector;
-
-import com.github.fommil.netlib.LAPACK;
-import org.netlib.util.intW;
-
-/** This class provides basic functionality for a Multivariate Gaussian (Normal) Distribution. */
-public class MultivariateGaussian {
-
- private static final LAPACK LAPACK_INST = LAPACK.getInstance();
- private static final com.github.fommil.netlib.BLAS F2J_BLAS_INST =
- com.github.fommil.netlib.F2jBLAS.getInstance();
- private static final double EPSILON;
-
- static {
- double eps = 1.0;
- while ((1.0 + (eps / 2.0)) != 1.0) {
- eps /= 2.0;
- }
- EPSILON = eps;
- }
-
- private final DenseVector mean;
- private final DenseMatrix cov;
-
- private DenseMatrix rootSigmaInv;
- private double u;
-
- // data buffers for computing pdf
- private DenseVector delta;
- private DenseVector v;
-
- /**
- * The constructor.
- *
- * @param mean The mean vector of the distribution.
- * @param cov The covariance matrix of the distribution.
- */
- public MultivariateGaussian(DenseVector mean, DenseMatrix cov) {
- this.mean = mean;
- this.cov = cov;
- this.delta = DenseVector.zeros(mean.size());
- this.v = DenseVector.zeros(mean.size());
- calculateCovarianceConstants();
- }
-
- /** Returns density of this multivariate Gaussian at given point x . */
- public double pdf(Vector x) {
- return Math.exp(logpdf(x));
- }
-
- /** Returns the log-density of this multivariate Gaussian at given point x . */
- public double logpdf(Vector x) {
- int n = mean.size();
- System.arraycopy(mean.getData(), 0, delta.getData(), 0, n);
- BLAS.scal(-1.0, delta);
- if (x instanceof DenseVector) {
- BLAS.axpy(1., (DenseVector) x, delta);
- } else if (x instanceof SparseVector) {
- BLAS.axpy(1., (SparseVector) x, delta);
- }
- BLAS.gemv(1.0, rootSigmaInv, true, delta, 0., v);
- return u - 0.5 * BLAS.dot(v, v);
- }
-
- /**
- * Compute distribution dependent constants.
- *
- * <p>The probability density function is calculated as: pdf(x) = (2*pi)^(-k/2)^ *
- * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu))
- *
- * <p>Here we compute the following distribution dependent constants that can be reused in each
- * pdf computation: A) u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) B) rootSigmaInv =
- * sqrt(inv(sigma)) = U * D^(-1/2)^
- *
- * <ul>
- * <li>sigma = U * D * U.t
- * <li>inv(sigma) = U * inv(D) * U.t = (U * D^(-1/2)^) * (U * D^(-1/2)^).t
- * <li>sqrt(inv(sigma)) = U * D^(-1/2)^
- * </ul>
- */
- private void calculateCovarianceConstants() {
- int k = this.mean.size();
- int lwork = 3 * k - 1;
- double[] matU = new double[k * k];
- double[] work = new double[lwork];
- double[] evs = new double[k];
- intW info = new intW(0);
-
- System.arraycopy(cov.getData(), 0, matU, 0, k * k);
- LAPACK_INST.dsyev("V", "U", k, matU, k, evs, work, lwork, info);
-
- double maxEv = Double.MIN_VALUE;
- for (double ev : evs) {
- maxEv = Math.max(maxEv, ev);
- }
- double tol = EPSILON * k * maxEv;
-
- // log(pseudo-determinant) is sum of the logs of all non-zero singular values
- double logPseudoDetSigma = 0.;
- for (double ev : evs) {
- if (ev > tol) {
- logPseudoDetSigma += Math.log(ev);
- }
- }
-
- for (int i = 0; i < k; i++) {
- double invEv = evs[i] > tol ? Math.sqrt(1.0 / evs[i]) : 0.;
- F2J_BLAS_INST.dscal(k, invEv, matU, i * k, 1);
- }
- this.rootSigmaInv = new DenseMatrix(k, k, matU);
- this.u = -0.5 * (k * Math.log(2.0 * Math.PI) + logPseudoDetSigma);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/DataStreamConversionUtil.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/DataStreamConversionUtil.java
deleted file mode 100644
index 112e249..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/DataStreamConversionUtil.java
+++ /dev/null
@@ -1,167 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.utils;
-
-import org.apache.flink.api.common.functions.MapFunction;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.ml.common.MLEnvironment;
-import org.apache.flink.ml.common.MLEnvironmentFactory;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
-import org.apache.flink.table.api.Expressions;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableSchema;
-import org.apache.flink.table.api.ValidationException;
-import org.apache.flink.table.expressions.Expression;
-import org.apache.flink.types.Row;
-
-import java.util.Arrays;
-
-/** Provide functions of conversions between DataStream and Table. */
-public class DataStreamConversionUtil {
- /**
- * Convert the given Table to {@link DataStream}<{@link Row}>.
- *
- * @param sessionId the sessionId of {@link MLEnvironmentFactory}
- * @param table the Table to convert.
- * @return the converted DataStream.
- */
- public static DataStream<Row> fromTable(Long sessionId, Table table) {
- return MLEnvironmentFactory.get(sessionId)
- .getStreamTableEnvironment()
- .toAppendStream(table, Row.class);
- }
-
- /**
- * Convert the given DataStream to Table with specified TableSchema.
- *
- * @param sessionId the sessionId of {@link MLEnvironmentFactory}
- * @param data the DataStream to convert.
- * @param schema the specified TableSchema.
- * @return the converted Table.
- */
- public static Table toTable(Long sessionId, DataStream<Row> data, TableSchema schema) {
- // TableSchema.getFieldTypes() is deprecated, this should be improved once FLIP-65 is fully
- // merged.
- return toTable(sessionId, data, schema.getFieldNames(), schema.getFieldTypes());
- }
-
- /**
- * Convert the given DataStream to Table with specified colNames.
- *
- * @param sessionId the sessionId of {@link MLEnvironmentFactory}.
- * @param data the DataStream to convert.
- * @param colNames the specified colNames.
- * @return the converted Table.
- */
- public static Table toTable(Long sessionId, DataStream<Row> data, String[] colNames) {
- return toTable(MLEnvironmentFactory.get(sessionId), data, colNames);
- }
-
- /**
- * Convert the given DataStream to Table with specified colNames and colTypes.
- *
- * @param sessionId the sessionId of {@link MLEnvironmentFactory}.
- * @param data the DataStream to convert.
- * @param colNames the specified colNames.
- * @param colTypes the specified colTypes. This variable is used only when the DataStream is
- * produced by a function and Flink cannot determine automatically what the produced type
- * is.
- * @return the converted Table.
- */
- public static Table toTable(
- Long sessionId,
- DataStream<Row> data,
- String[] colNames,
- TypeInformation<?>[] colTypes) {
- return toTable(MLEnvironmentFactory.get(sessionId), data, colNames, colTypes);
- }
-
- /**
- * Convert the given DataStream to Table with specified colNames.
- *
- * @param session the MLEnvironment used to convert DataStream to Table.
- * @param data the DataStream to convert.
- * @param colNames the specified colNames.
- * @return the converted Table.
- */
- public static Table toTable(MLEnvironment session, DataStream<Row> data, String[] colNames) {
- if (null == colNames || colNames.length == 0) {
- return session.getStreamTableEnvironment().fromDataStream(data);
- } else {
- return session.getStreamTableEnvironment()
- .fromDataStream(
- data,
- Arrays.stream(colNames).map(Expressions::$).toArray(Expression[]::new));
- }
- }
-
- /**
- * Convert the given DataStream to Table with specified colNames and colTypes.
- *
- * @param session the MLEnvironment used to convert DataStream to Table.
- * @param data the DataStream to convert.
- * @param colNames the specified colNames.
- * @param colTypes the specified colTypes. This variable is used only when the DataStream is
- * produced by a function and Flink cannot determine automatically what the produced type
- * is.
- * @return the converted Table.
- */
- public static Table toTable(
- MLEnvironment session,
- DataStream<Row> data,
- String[] colNames,
- TypeInformation<?>[] colTypes) {
- try {
- if (null != colTypes) {
- // Try to add row type information for the datastream to be converted.
- // In most cases, this keeps us from the fallback logic in the catch block,
- // which adds an unnecessary map function just in order to add row type information.
- if (data instanceof SingleOutputStreamOperator) {
- ((SingleOutputStreamOperator) data)
- .returns(new RowTypeInfo(colTypes, colNames));
- }
- }
- return toTable(session, data, colNames);
- } catch (ValidationException ex) {
- if (null == colTypes) {
- throw ex;
- } else {
- DataStream<Row> t = fallbackToExplicitTypeDefine(data, colNames, colTypes);
- return toTable(session, t, colNames);
- }
- }
- }
-
- /**
- * Adds a type information hint about the colTypes with the Row to the DataStream.
- *
- * @param data the DataStream to add type information.
- * @param colNames the specified colNames
- * @param colTypes the specified colTypes
- * @return the DataStream with type information hint.
- */
- private static DataStream<Row> fallbackToExplicitTypeDefine(
- DataStream<Row> data, String[] colNames, TypeInformation<?>[] colTypes) {
- return data.map((MapFunction<Row, Row>) t -> t)
- .returns(new RowTypeInfo(colTypes, colNames));
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/OutputColsHelper.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/OutputColsHelper.java
deleted file mode 100644
index 18ca429..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/OutputColsHelper.java
+++ /dev/null
@@ -1,211 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.utils;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.table.api.TableSchema;
-import org.apache.flink.types.Row;
-
-import org.apache.commons.lang3.ArrayUtils;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashSet;
-
-/**
- * Helper that merges the columns of the input data with the columns produced by a
- * prediction/transformation operator.
- *
- * <p>It is configured with: 1) the schema of the input data, 2) the names of the operator's output
- * columns, 3) the types of the operator's output columns, and 4) the names of the input columns
- * that should be kept ("reserved") in the result.
- *
- * <p>From these it derives the schema of the result and can merge an input row with an operator
- * output row into a result row.
- *
- * <p>The layout of the result follows these rules:
- *
- * <ul>
- *   <li>If no reserved column names are given (<code>null</code>), every input column is
- *       reserved.
- *   <li>Reserved columns keep the relative order they have in the input.
- *   <li>An output column that has the same name as an input column replaces that input column
- *       and takes over its position in the result.
- *   <li>All remaining output columns are appended after the columns above, in the order in which
- *       they were given.
- * </ul>
- *
- * <p>For example, with the input schema ["id":INT, "f1":FLOAT, "f2":DOUBLE], an operator that
- * outputs a column "label" of type STRING, and "id" as the only reserved column, the result
- * schema is ["id":INT, "label":STRING].
- *
- * <p>End users are not expected to use this helper directly; it is meant to be used by concrete
- * algorithms.
- */
-public class OutputColsHelper implements Serializable {
-    private String[] inputColNames;
-    private TypeInformation<?>[] inputColTypes;
-    private String[] outputColNames;
-    private TypeInformation<?>[] outputColTypes;
-
-    /** Indices, within the input data, of the columns that are forwarded to the result. */
-    private int[] reservedCols;
-
-    /** Position in the result of each reserved column; parallel to {@link #reservedCols}. */
-    private int[] reservedColsPosInResult;
-
-    /** Position in the result of each output column; parallel to {@link #outputColNames}. */
-    private int[] outputColsPosInResult;
-
-    /**
-     * Creates a helper for a single output column that reserves every input column.
-     *
-     * @param inputSchema Schema of input data being predicted or transformed.
-     * @param outputColName Name of the operator's output column.
-     * @param outputColType Type of the operator's output column.
-     */
-    public OutputColsHelper(
-            TableSchema inputSchema, String outputColName, TypeInformation<?> outputColType) {
-        this(
-                inputSchema,
-                new String[] {outputColName},
-                new TypeInformation<?>[] {outputColType},
-                inputSchema.getFieldNames());
-    }
-
-    /**
-     * Creates a helper for a single output column.
-     *
-     * @param inputSchema Schema of input data being predicted or transformed.
-     * @param outputColName Name of the operator's output column.
-     * @param outputColType Type of the operator's output column.
-     * @param reservedColNames Names of the input columns to keep; <code>null</code> keeps all of
-     *     them.
-     */
-    public OutputColsHelper(
-            TableSchema inputSchema,
-            String outputColName,
-            TypeInformation<?> outputColType,
-            String[] reservedColNames) {
-        this(
-                inputSchema,
-                new String[] {outputColName},
-                new TypeInformation<?>[] {outputColType},
-                reservedColNames);
-    }
-
-    /**
-     * Creates a helper for several output columns that reserves every input column.
-     *
-     * @param inputSchema Schema of input data being predicted or transformed.
-     * @param outputColNames Names of the operator's output columns.
-     * @param outputColTypes Types of the operator's output columns.
-     */
-    public OutputColsHelper(
-            TableSchema inputSchema, String[] outputColNames, TypeInformation<?>[] outputColTypes) {
-        this(inputSchema, outputColNames, outputColTypes, inputSchema.getFieldNames());
-    }
-
-    /**
-     * Creates a helper and pre-computes where every reserved and output column lands in the
-     * result.
-     *
-     * @param inputSchema Schema of input data being predicted or transformed.
-     * @param outputColNames Output column names of the prediction/transformation operator.
-     * @param outputColTypes Output column types of the prediction/transformation operator.
-     * @param reservedColNames Names of the input columns to keep; <code>null</code> keeps all of
-     *     them.
-     */
-    public OutputColsHelper(
-            TableSchema inputSchema,
-            String[] outputColNames,
-            TypeInformation<?>[] outputColTypes,
-            String[] reservedColNames) {
-        this.inputColNames = inputSchema.getFieldNames();
-        this.inputColTypes = inputSchema.getFieldTypes();
-        this.outputColNames = outputColNames;
-        this.outputColTypes = outputColTypes;
-
-        String[] requested = reservedColNames != null ? reservedColNames : this.inputColNames;
-        HashSet<String> requestedNames = new HashSet<>(Arrays.asList(requested));
-
-        // Input indices of the kept columns and, in parallel, their positions in the result.
-        ArrayList<Integer> keptInputIndices = new ArrayList<>(requestedNames.size());
-        ArrayList<Integer> keptResultPositions = new ArrayList<>(requestedNames.size());
-
-        // -1 marks an output column that has not been given a position yet.
-        this.outputColsPosInResult = new int[outputColNames.length];
-        Arrays.fill(this.outputColsPosInResult, -1);
-
-        int nextPos = 0;
-        for (int inputIdx = 0; inputIdx < inputColNames.length; inputIdx++) {
-            String name = inputColNames[inputIdx];
-            int outputIdx = ArrayUtils.indexOf(outputColNames, name);
-            if (outputIdx >= 0) {
-                // An output column with the same name overrides this input column in place.
-                outputColsPosInResult[outputIdx] = nextPos++;
-            } else if (requestedNames.contains(name)) {
-                keptInputIndices.add(inputIdx);
-                keptResultPositions.add(nextPos++);
-            }
-        }
-
-        // Output columns that did not override an input column go to the end.
-        for (int outputIdx = 0; outputIdx < outputColsPosInResult.length; outputIdx++) {
-            if (outputColsPosInResult[outputIdx] == -1) {
-                outputColsPosInResult[outputIdx] = nextPos++;
-            }
-        }
-
-        // Store the reserved column information as primitive arrays.
-        final int keptCount = keptInputIndices.size();
-        this.reservedCols = new int[keptCount];
-        this.reservedColsPosInResult = new int[keptCount];
-        for (int k = 0; k < keptCount; k++) {
-            this.reservedCols[k] = keptInputIndices.get(k);
-            this.reservedColsPosInResult[k] = keptResultPositions.get(k);
-        }
-    }
-
-    /**
-     * Returns the names of the input columns that are forwarded to the result.
-     *
-     * @return the reserved colNames, in input order.
-     */
-    public String[] getReservedColumns() {
-        String[] names = new String[reservedCols.length];
-        int pos = 0;
-        for (int inputIdx : reservedCols) {
-            names[pos++] = inputColNames[inputIdx];
-        }
-        return names;
-    }
-
-    /**
-     * Builds the schema of the result, i.e. the reserved input columns combined with the
-     * operator's output columns.
-     *
-     * @return The result table schema.
-     */
-    public TableSchema getResultSchema() {
-        final int width = reservedCols.length + outputColNames.length;
-        String[] names = new String[width];
-        TypeInformation<?>[] types = new TypeInformation[width];
-
-        for (int k = 0; k < reservedCols.length; k++) {
-            final int to = reservedColsPosInResult[k];
-            final int from = reservedCols[k];
-            names[to] = inputColNames[from];
-            types[to] = inputColTypes[from];
-        }
-        for (int k = 0; k < outputColsPosInResult.length; k++) {
-            final int to = outputColsPosInResult[k];
-            names[to] = outputColNames[k];
-            types[to] = outputColTypes[k];
-        }
-        return new TableSchema(names, types);
-    }
-
-    /**
-     * Merges an input row with the row produced by the operator for it.
-     *
-     * @param input The input row being predicted or transformed.
-     * @param output The output row of the prediction/transformation operator.
-     * @return The result row, which combines the reserved input fields with the operator's
-     *     output fields.
-     * @throws IllegalArgumentException if the arity of <code>output</code> differs from the
-     *     number of output columns.
-     */
-    public Row getResultRow(Row input, Row output) {
-        final int expectedArity = outputColsPosInResult.length;
-        if (output.getArity() != expectedArity) {
-            throw new IllegalArgumentException("Invalid output size");
-        }
-
-        Row merged = new Row(reservedCols.length + outputColNames.length);
-        for (int k = 0; k < reservedCols.length; k++) {
-            merged.setField(reservedColsPosInResult[k], input.getField(reservedCols[k]));
-        }
-        for (int k = 0; k < expectedArity; k++) {
-            merged.setField(outputColsPosInResult[k], output.getField(k));
-        }
-        return merged;
-    }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/TableUtil.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/TableUtil.java
deleted file mode 100644
index 3f2261e..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/TableUtil.java
+++ /dev/null
@@ -1,424 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.utils;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.common.typeinfo.Types;
-import org.apache.flink.table.api.TableSchema;
-import org.apache.flink.types.Row;
-import org.apache.flink.util.Preconditions;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.UUID;
-
-/** Utility methods for interacting with Table contents, such as rows and columns. */
-public class TableUtil {
-    /**
-     * Returns a temp table name: the prefix `temp_` followed by a random UUID.
-     *
-     * <p>UUID hyphens ("-") are replaced by underscores ("_").
-     *
-     * @return tableName
-     */
-    public static synchronized String getTempTableName() {
-        return ("temp_" + UUID.randomUUID().toString().replaceAll("-", "_")).toLowerCase();
-    }
-
-    /**
-     * Finds the index of <code>targetCol</code> in string array <code>tableCols</code>. The
-     * comparison ignores case.
-     *
-     * @param tableCols a string array among which to find the targetCol.
-     * @param targetCol the targetCol to find; must not be null.
-     * @return the index of the first matching column, or -1 if there is none.
-     */
-    public static int findColIndex(String[] tableCols, String targetCol) {
-        Preconditions.checkNotNull(targetCol, "targetCol is null!");
-        for (int i = 0; i < tableCols.length; i++) {
-            if (targetCol.equalsIgnoreCase(tableCols[i])) {
-                return i;
-            }
-        }
-        return -1;
-    }
-
-    /**
-     * Finds the index of <code>targetCol</code> in the <code>tableSchema</code>. The comparison
-     * ignores case.
-     *
-     * @param tableSchema the TableSchema among which to find the targetCol.
-     * @param targetCol the targetCol to find.
-     * @return the index of the targetCol, or -1 if it is not found.
-     */
-    public static int findColIndex(TableSchema tableSchema, String targetCol) {
-        return findColIndex(tableSchema.getFieldNames(), targetCol);
-    }
-
-    /**
-     * Finds the indices of <code>targetCols</code> in string array <code>tableCols</code>. If
-     * <code>targetCols</code> is null, the indices of all <code>tableCols</code> are returned.
-     *
-     * @param tableCols a string array among which to find the targetCols.
-     * @param targetCols the targetCols to find.
-     * @return the indices of the targetCols; a column that is not found yields -1.
-     */
-    public static int[] findColIndices(String[] tableCols, String[] targetCols) {
-        if (targetCols == null) {
-            int[] indices = new int[tableCols.length];
-            for (int i = 0; i < tableCols.length; i++) {
-                indices[i] = i;
-            }
-            return indices;
-        }
-        int[] indices = new int[targetCols.length];
-        for (int i = 0; i < indices.length; i++) {
-            indices[i] = findColIndex(tableCols, targetCols[i]);
-        }
-        return indices;
-    }
-
-    /**
-     * Finds the indices of <code>targetCols</code> in the <code>tableSchema</code>.
-     *
-     * @param tableSchema the TableSchema among which to find the targetCols.
-     * @param targetCols the targetCols to find.
-     * @return the indices of the targetCols; a column that is not found yields -1.
-     */
-    public static int[] findColIndices(TableSchema tableSchema, String[] targetCols) {
-        return findColIndices(tableSchema.getFieldNames(), targetCols);
-    }
-
-    /**
-     * Finds the types of the <code>targetCols</code>. A column that does not exist yields a null
-     * entry. If <code>targetCols</code> is null, the types of all columns are returned.
-     *
-     * @param tableSchema TableSchema.
-     * @param targetCols the targetCols to find.
-     * @return the corresponding types.
-     */
-    public static TypeInformation<?>[] findColTypes(TableSchema tableSchema, String[] targetCols) {
-        if (targetCols == null) {
-            return tableSchema.getFieldTypes();
-        }
-        TypeInformation<?>[] types = new TypeInformation[targetCols.length];
-        for (int i = 0; i < types.length; i++) {
-            types[i] = findColType(tableSchema, targetCols[i]);
-        }
-        return types;
-    }
-
-    /**
-     * Finds the type of the <code>targetCol</code>. If the targetCol does not exist, returns
-     * null.
-     *
-     * @param tableSchema TableSchema
-     * @param targetCol the targetCol to find.
-     * @return the corresponding type, or null if the column is not found.
-     */
-    public static TypeInformation<?> findColType(TableSchema tableSchema, String targetCol) {
-        int index = findColIndex(tableSchema.getFieldNames(), targetCol);
-
-        return index == -1 ? null : tableSchema.getFieldTypes()[index];
-    }
-
-    /**
-     * Determines whether the type is a supported number type, i.e. double, long, byte, int,
-     * float or short.
-     *
-     * @param dataType the dataType to determine; null is not a number type.
-     * @return whether it is number type
-     */
-    public static boolean isSupportedNumericType(TypeInformation<?> dataType) {
-        // equals() rather than ==, consistent with isVector(): a type information instance that
-        // is equal to, but not the same object as, the constant must be recognized as well.
-        return Types.DOUBLE.equals(dataType)
-                || Types.LONG.equals(dataType)
-                || Types.BYTE.equals(dataType)
-                || Types.INT.equals(dataType)
-                || Types.FLOAT.equals(dataType)
-                || Types.SHORT.equals(dataType);
-    }
-
-    /**
-     * Determines whether the type is the string type.
-     *
-     * @param dataType the dataType to determine; null is not a string type.
-     * @return whether it is string type
-     */
-    public static boolean isString(TypeInformation<?> dataType) {
-        return Types.STRING.equals(dataType);
-    }
-
-    /**
-     * Determines whether the type is one of the vector types declared in <code>VectorTypes
-     * </code>.
-     *
-     * @param dataType the dataType to determine.
-     * @return whether it is vector type
-     */
-    public static boolean isVector(TypeInformation<?> dataType) {
-        return VectorTypes.VECTOR.equals(dataType)
-                || VectorTypes.DENSE_VECTOR.equals(dataType)
-                || VectorTypes.SPARSE_VECTOR.equals(dataType);
-    }
-
-    /**
-     * Checks that every one of the <code>selectedCols</code> exists; null entries are skipped.
-     *
-     * @param tableCols a string array among which to find the target selectedCols.
-     * @param selectedCols the selectedCols to assert.
-     * @throws IllegalArgumentException if a selected column does not exist.
-     */
-    public static void assertSelectedColExist(String[] tableCols, String... selectedCols) {
-        if (null != selectedCols) {
-            for (String selectedCol : selectedCols) {
-                if (null != selectedCol) {
-                    if (-1 == findColIndex(tableCols, selectedCol)) {
-                        throw new IllegalArgumentException(" col is not exist " + selectedCol);
-                    }
-                }
-            }
-        }
-    }
-
-    /**
-     * Checks that the types of the <code>selectedCols</code> are numerical; null entries are
-     * skipped. A column that does not exist in the schema fails the check as well.
-     *
-     * @param tableSchema TableSchema
-     * @param selectedCols the selectedCols to assert.
-     * @throws IllegalArgumentException if a selected column is not of a numerical type.
-     */
-    public static void assertNumericalCols(TableSchema tableSchema, String... selectedCols) {
-        if (selectedCols != null && selectedCols.length != 0) {
-            for (String selectedCol : selectedCols) {
-                if (null != selectedCol) {
-                    if (!isSupportedNumericType(findColType(tableSchema, selectedCol))) {
-                        throw new IllegalArgumentException(
-                                "col type must be number " + selectedCol);
-                    }
-                }
-            }
-        }
-    }
-
-    /**
-     * Checks that the types of the <code>selectedCols</code> are string; null entries are
-     * skipped. A column that does not exist in the schema fails the check as well.
-     *
-     * @param tableSchema TableSchema
-     * @param selectedCols the selectedCol to assert.
-     * @throws IllegalArgumentException if a selected column is not of string type.
-     */
-    public static void assertStringCols(TableSchema tableSchema, String... selectedCols) {
-        if (selectedCols != null && selectedCols.length != 0) {
-            for (String selectedCol : selectedCols) {
-                if (null != selectedCol) {
-                    if (!isString(findColType(tableSchema, selectedCol))) {
-                        throw new IllegalArgumentException(
-                                "col type must be string " + selectedCol);
-                    }
-                }
-            }
-        }
-    }
-
-    /**
-     * Checks that the types of the <code>selectedCols</code> are vector; null entries are
-     * skipped. A column that does not exist in the schema fails the check as well.
-     *
-     * @param tableSchema TableSchema
-     * @param selectedCols the selectedCol to assert.
-     * @throws IllegalArgumentException if a selected column is not of a vector type.
-     * @see #isVector(TypeInformation)
-     */
-    public static void assertVectorCols(TableSchema tableSchema, String... selectedCols) {
-        if (selectedCols != null && selectedCols.length != 0) {
-            for (String selectedCol : selectedCols) {
-                if (null != selectedCol) {
-                    if (!isVector(findColType(tableSchema, selectedCol))) {
-                        // This message used to say "string", copied from assertStringCols.
-                        throw new IllegalArgumentException(
-                                "col type must be vector " + selectedCol);
-                    }
-                }
-            }
-        }
-    }
-
-    /**
-     * Returns the columns in the table whose types are string.
-     *
-     * @param tableSchema TableSchema
-     * @return String columns.
-     */
-    public static String[] getStringCols(TableSchema tableSchema) {
-        return getStringCols(tableSchema, null);
-    }
-
-    /**
-     * Returns the columns in the table whose types are string and which are not included in the
-     * excludeCols.
-     *
-     * <p>If <code>excludeCols</code> is null, returns all the string columns.
-     *
-     * @param tableSchema TableSchema.
-     * @param excludeCols The columns which are not considered.
-     * @return string columns.
-     */
-    public static String[] getStringCols(TableSchema tableSchema, String[] excludeCols) {
-        ArrayList<String> stringCols = new ArrayList<>();
-        List<String> excludeColsList = null == excludeCols ? null : Arrays.asList(excludeCols);
-        String[] inColNames = tableSchema.getFieldNames();
-        TypeInformation<?>[] inColTypes = tableSchema.getFieldTypes();
-
-        for (int i = 0; i < inColNames.length; i++) {
-            if (isString(inColTypes[i])) {
-                if (null == excludeColsList || !excludeColsList.contains(inColNames[i])) {
-                    stringCols.add(inColNames[i]);
-                }
-            }
-        }
-
-        return stringCols.toArray(new String[0]);
-    }
-
-    /**
-     * Returns the columns in the table whose types are numeric.
-     *
-     * @param tableSchema TableSchema
-     * @return numeric columns.
-     */
-    public static String[] getNumericCols(TableSchema tableSchema) {
-        return getNumericCols(tableSchema, null);
-    }
-
-    /**
-     * Returns the columns in the table whose types are numeric and which are not included in the
-     * excludeCols.
-     *
-     * <p>If <code>excludeCols</code> is null, returns all the numeric columns.
-     *
-     * @param tableSchema TableSchema.
-     * @param excludeCols the columns which are not considered.
-     * @return numeric columns.
-     */
-    public static String[] getNumericCols(TableSchema tableSchema, String[] excludeCols) {
-        ArrayList<String> numericCols = new ArrayList<>();
-        List<String> excludeColsList = (null == excludeCols ? null : Arrays.asList(excludeCols));
-        String[] inColNames = tableSchema.getFieldNames();
-        TypeInformation<?>[] inColTypes = tableSchema.getFieldTypes();
-
-        for (int i = 0; i < inColNames.length; i++) {
-            if (isSupportedNumericType(inColTypes[i])) {
-                if (null == excludeColsList || !excludeColsList.contains(inColNames[i])) {
-                    numericCols.add(inColNames[i]);
-                }
-            }
-        }
-
-        return numericCols.toArray(new String[0]);
-    }
-
-    /**
-     * Returns the columns of <code>featureCols</code> that are categorical: those listed in
-     * <code>categoricalCols</code> plus those whose types are string or boolean.
-     *
-     * <p>If <code>featureCols</code> is null, <code>categoricalCols</code> is returned as is. If
-     * <code>categoricalCols</code> is null, only the string and boolean feature columns are
-     * returned.
-     *
-     * <p>For example, FeatureHasher, which projects a number of categorical or numerical
-     * features into a feature vector of a specified dimension, needs to identify the categorical
-     * features. A column which is string or boolean must be categorical, so these columns are
-     * treated as categorical when the user does not specify the types (categorical or numerical).
-     *
-     * @param tableSchema TableSchema.
-     * @param featureCols the columns to choose from.
-     * @param categoricalCols the columns which are included in the final result whatever their
-     *     types are. It must be a subset of featureCols.
-     * @return the categoricalCols, in the order of featureCols.
-     * @throws IllegalArgumentException if categoricalCols is not a subset of featureCols.
-     */
-    public static String[] getCategoricalCols(
-            TableSchema tableSchema, String[] featureCols, String[] categoricalCols) {
-        if (null == featureCols) {
-            return categoricalCols;
-        }
-        List<String> categoricalList =
-                null == categoricalCols ? null : Arrays.asList(categoricalCols);
-        List<String> featureList = Arrays.asList(featureCols);
-        if (null != categoricalCols && !featureList.containsAll(categoricalList)) {
-            throw new IllegalArgumentException("CategoricalCols must be included in featureCols!");
-        }
-
-        TypeInformation<?>[] featureColTypes = findColTypes(tableSchema, featureCols);
-        List<String> res = new ArrayList<>();
-        for (int i = 0; i < featureCols.length; i++) {
-            boolean included = null != categoricalList && categoricalList.contains(featureCols[i]);
-            // equals() rather than ==, consistent with the other type checks in this class.
-            if (included
-                    || Types.BOOLEAN.equals(featureColTypes[i])
-                    || Types.STRING.equals(featureColTypes[i])) {
-                res.add(featureCols[i]);
-            }
-        }
-
-        return res.toArray(new String[0]);
-    }
-
-    /**
-     * Formats the column names as the header of a markdown table: the names joined by "|",
-     * followed by "\r\n" and a separator line with one "-" per character of each name.
-     *
-     * @param colNames the column names; a null name is rendered as "null".
-     * @return the header line and the separator line.
-     */
-    public static String formatTitle(String[] colNames) {
-        StringBuilder sbd = new StringBuilder();
-        StringBuilder sbdSplitter = new StringBuilder();
-
-        for (int i = 0; i < colNames.length; ++i) {
-            if (i > 0) {
-                sbd.append("|");
-                sbdSplitter.append("|");
-            }
-
-            sbd.append(colNames[i]);
-
-            // A null name is appended as the 4-character text "null".
-            int t = null == colNames[i] ? 4 : colNames[i].length();
-            for (int j = 0; j < t; j++) {
-                sbdSplitter.append("-");
-            }
-        }
-
-        return sbd.toString() + "\r\n" + sbdSplitter.toString();
-    }
-
-    /**
-     * Formats the row as a body line of a markdown table: the fields joined by "|". Double and
-     * Float fields are printed with four decimal places.
-     *
-     * @param row the row to format.
-     * @return the formatted line.
-     */
-    public static String formatRows(Row row) {
-        StringBuilder sbd = new StringBuilder();
-
-        for (int i = 0; i < row.getArity(); ++i) {
-            if (i > 0) {
-                sbd.append("|");
-            }
-            Object obj = row.getField(i);
-            if (obj instanceof Double || obj instanceof Float) {
-                // Go through Number: casting the Object to double directly only works for a
-                // Double and throws a ClassCastException for a Float.
-                sbd.append(String.format("%.4f", ((Number) obj).doubleValue()));
-            } else {
-                sbd.append(obj);
-            }
-        }
-
-        return sbd.toString();
-    }
-
-    /**
-     * Formats the column names and rows of a table as markdown: the header produced by {@link
-     * #formatTitle(String[])}, then one line per row, each preceded by "\n".
-     *
-     * @param colNames the column names.
-     * @param data the rows of the table.
-     * @return the formatted table.
-     */
-    public static String format(String[] colNames, List<Row> data) {
-        StringBuilder sbd = new StringBuilder();
-        sbd.append(formatTitle(colNames));
-
-        for (Row row : data) {
-            sbd.append("\n").append(formatRows(row));
-        }
-
-        return sbd.toString();
-    }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/VectorTypes.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/VectorTypes.java
deleted file mode 100644
index ce49609..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/utils/VectorTypes.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.utils;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.ml.common.linalg.DenseVector;
-import org.apache.flink.ml.common.linalg.SparseVector;
-import org.apache.flink.ml.common.linalg.Vector;
-
-/** Predefined {@link TypeInformation} constants for the vector classes. */
-public class VectorTypes {
-    /**
-     * Type information of the general <code>Vector</code> class. Intended for operators that
-     * output both kinds of vectors; for efficiency, prefer the type information of the
-     * sub-classes <code>DenseVector</code> and <code>SparseVector</code> whenever possible.
-     */
-    public static final TypeInformation<Vector> VECTOR = TypeInformation.of(Vector.class);
-
-    /** Type information of <code>DenseVector</code>. */
-    public static final TypeInformation<DenseVector> DENSE_VECTOR =
-            TypeInformation.of(DenseVector.class);
-
-    /** Type information of <code>SparseVector</code>. */
-    public static final TypeInformation<SparseVector> SPARSE_VECTOR =
-            TypeInformation.of(SparseVector.class);
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/AlgoOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/AlgoOperator.java
deleted file mode 100644
index 2efaca7..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/AlgoOperator.java
+++ /dev/null
@@ -1,186 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.operator;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.ml.api.misc.param.Params;
-import org.apache.flink.ml.api.misc.param.WithParams;
-import org.apache.flink.ml.params.shared.HasMLEnvironmentId;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableSchema;
-import org.apache.flink.util.Preconditions;
-
-import java.io.Serializable;
-
-/**
- * Base class for algorithm operators.
- *
- * <p>It hosts the parameters and the output tables of an algorithm operator. Each AlgoOperator
- * may have one or more output tables. One of them is the primary output table, which can be
- * obtained by calling {@link #getOutput}. The others are side output tables, which can be
- * obtained by calling {@link #getSideOutputs()}.
- *
- * <p>The input of an AlgoOperator is defined in the subclasses of the AlgoOperator.
- *
- * <p>NOTE(review): the output tables are held in non-transient fields of a Serializable class.
- * Confirm whether operators are ever serialized after their outputs have been set.
- *
- * @param <T> The class type of the {@link AlgoOperator} implementation itself
- */
-public abstract class AlgoOperator<T extends AlgoOperator<T>>
-        implements WithParams<T>, HasMLEnvironmentId<T>, Serializable {
-
-    /** Params for algorithms. */
-    private Params params;
-
-    /** The primary output table of the operator; null until {@link #setOutput} is called. */
-    private Table output = null;
-
-    /** The side outputs of the operator, similar to a stream's side outputs; null until set. */
-    private Table[] sideOutputs = null;
-
-    /**
-     * Constructs the operator with empty Params.
-     *
-     * <p>This constructor is especially useful when users want to set parameters for the
-     * algorithm operators. For example: SplitBatchOp is widely used in ML data pre-processing,
-     * which splits one dataset into two datasets: training set and validation set. It is very
-     * convenient for us to write code like this:
-     *
-     * <pre>{@code
-     * new SplitBatchOp().setSplitRatio(0.9)
-     * }</pre>
-     */
-    protected AlgoOperator() {
-        this(null);
-    }
-
-    /**
-     * Constructs the operator with the initial Params.
-     *
-     * @param params the initial Params. A clone of them is stored; if null, the operator starts
-     *     with empty Params.
-     */
-    protected AlgoOperator(Params params) {
-        if (null == params) {
-            this.params = new Params();
-        } else {
-            this.params = params.clone();
-        }
-    }
-
-    @Override
-    public Params getParams() {
-        return this.params;
-    }
-
-    /** Returns the primary output table of the operator, or null if it has not been set. */
-    public Table getOutput() {
-        return this.output;
-    }
-
-    /** Returns the side outputs, or null if they have not been set. */
-    public Table[] getSideOutputs() {
-        return this.sideOutputs;
-    }
-
-    /**
-     * Sets the side outputs.
-     *
-     * @param sideOutputs the side outputs of the operator.
-     */
-    protected void setSideOutputs(Table[] sideOutputs) {
-        this.sideOutputs = sideOutputs;
-    }
-
-    /**
-     * Sets the primary output table of the operator.
-     *
-     * @param output the output table.
-     */
-    protected void setOutput(Table output) {
-        this.output = output;
-    }
-
-    /** Returns the column names of the output table. */
-    public String[] getColNames() {
-        return getSchema().getFieldNames();
-    }
-
-    /** Returns the column types of the output table. */
-    public TypeInformation<?>[] getColTypes() {
-        return getSchema().getFieldTypes();
-    }
-
-    /**
-     * Gets the column names of the specified side-output table.
-     *
-     * @param index the index of the table.
-     * @return the column names of the table.
-     */
-    public String[] getSideOutputColNames(int index) {
-        checkSideOutputAccessibility(index);
-
-        return sideOutputs[index].getSchema().getFieldNames();
-    }
-
-    /**
-     * Gets the column types of the specified side-output table.
-     *
-     * @param index the index of the table.
-     * @return the column types of the table.
-     */
-    public TypeInformation<?>[] getSideOutputColTypes(int index) {
-        checkSideOutputAccessibility(index);
-
-        return sideOutputs[index].getSchema().getFieldTypes();
-    }
-
-    /**
-     * Returns the schema of the output table.
-     *
-     * @throws NullPointerException if the output table has not been set.
-     */
-    public TableSchema getSchema() {
-        Table table = this.getOutput();
-        // Still a NullPointerException, as before, but now with a message that explains it.
-        Preconditions.checkNotNull(
-                table, "The output table is null. Maybe the operator has not been linked.");
-        return table.getSchema();
-    }
-
-    /**
-     * Returns the string form of the output table. If the output table has not been set, a
-     * placeholder naming the operator class is returned, so that logging or debugging an
-     * operator that has no output yet does not throw.
-     */
-    @Override
-    public String toString() {
-        Table table = getOutput();
-        if (table == null) {
-            return getClass().getSimpleName() + "(output not set)";
-        }
-        return table.toString();
-    }
-
-    /**
-     * Checks that exactly <code>size</code> operators are given.
-     *
-     * @param size the expected number of operators.
-     * @param inputs the operators to check.
-     */
-    protected static void checkOpSize(int size, AlgoOperator<?>... inputs) {
-        Preconditions.checkNotNull(inputs, "Operators should not be null.");
-        Preconditions.checkState(
-                inputs.length == size,
-                "The size of operators should be equal to " + size + ", current: " + inputs.length);
-    }
-
-    /**
-     * Checks that at least <code>size</code> operators are given.
-     *
-     * @param size the minimal number of operators.
-     * @param inputs the operators to check.
-     */
-    protected static void checkMinOpSize(int size, AlgoOperator<?>... inputs) {
-        Preconditions.checkNotNull(inputs, "Operators should not be null.");
-        Preconditions.checkState(
-                inputs.length >= size,
-                "The size of operators should be equal or greater than "
-                        + size
-                        + ", current: "
-                        + inputs.length);
-    }
-
-    /**
-     * Checks that the side output at <code>index</code> can be read: the side outputs are set,
-     * the index is within bounds and the table at that index is not null.
-     *
-     * @param index the index of the side output to access.
-     */
-    private void checkSideOutputAccessibility(int index) {
-        Preconditions.checkNotNull(sideOutputs, "There is not side-outputs in this AlgoOperator.");
-        Preconditions.checkState(
-                index >= 0 && index < sideOutputs.length,
-                String.format("The index(%s) of side-outputs is out of bound.", index));
-        // The message used to append "nd" to every index, producing "0nd", "1nd", "3nd".
-        Preconditions.checkNotNull(
-                sideOutputs[index],
-                String.format(
-                        "The side-output at index %s is null. Maybe the operator has not been linked.",
-                        index));
-    }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/batch/BatchOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/batch/BatchOperator.java
deleted file mode 100644
index 1ddfdb3..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/batch/BatchOperator.java
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.operator.batch;
-
-import org.apache.flink.ml.api.misc.param.Params;
-import org.apache.flink.ml.operator.AlgoOperator;
-import org.apache.flink.ml.operator.batch.source.TableSourceBatchOp;
-import org.apache.flink.table.api.Table;
-
-/**
- * Base class of batch algorithm operators.
- *
- * <p>This class extends {@link AlgoOperator} to support data transmission between BatchOperators.
- */
-public abstract class BatchOperator<T extends BatchOperator<T>> extends AlgoOperator<T> {
-
- public BatchOperator() {
- super();
- }
-
- /**
- * The constructor of BatchOperator with {@link Params}.
- *
- * @param params the initial Params.
- */
- public BatchOperator(Params params) {
- super(params);
- }
-
- /**
- * Link to another {@link BatchOperator}.
- *
- * <p>Link the <code>next</code> BatchOperator using this BatchOperator as its input.
- *
- * <p>For example:
- *
- * <pre>{@code
- * BatchOperator a = ...;
- * BatchOperator b = ...;
- * BatchOperator c = a.link(b)
- * }</pre>
- *
- * <p>The BatchOperator <code>c</code> in the above code is the same instance as <code>b</code>
- * which takes <code>a</code> as its input. Note that BatchOperator <code>b</code> will be
- * changed to link from BatchOperator <code>a</code>.
- *
- * @param next The operator that will be modified to add this operator to its input.
- * @param <B> type of BatchOperator returned
- * @return the linked next
- * @see #linkFrom(BatchOperator[])
- */
- public <B extends BatchOperator<?>> B link(B next) {
- next.linkFrom(this);
- return next;
- }
-
- /**
- * Link from others {@link BatchOperator}.
- *
- * <p>Link this object to BatchOperator using the BatchOperators as its input.
- *
- * <p>For example:
- *
- * <pre>{@code
- * BatchOperator a = ...;
- * BatchOperator b = ...;
- * BatchOperator c = ...;
- *
- * BatchOperator d = c.linkFrom(a, b)
- * }</pre>
- *
- * <p>The <code>d</code> in the above code is the same instance as BatchOperator <code>c</code>
- * which takes both <code>a</code> and <code>b</code> as its input.
- *
- * <p>note: It is not recommended to linkFrom itself or linkFrom the same group inputs twice.
- *
- * @param inputs the linked inputs
- * @return the linked this object
- */
- public abstract T linkFrom(BatchOperator<?>... inputs);
-
- /**
- * create a new BatchOperator from table.
- *
- * @param table the input table
- * @return the new BatchOperator
- */
- public static BatchOperator<?> fromTable(Table table) {
- return new TableSourceBatchOp(table);
- }
-
- protected static BatchOperator<?> checkAndGetFirst(BatchOperator<?>... inputs) {
- checkOpSize(1, inputs);
- return inputs[0];
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/batch/source/TableSourceBatchOp.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/batch/source/TableSourceBatchOp.java
deleted file mode 100644
index f53ae12..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/batch/source/TableSourceBatchOp.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.operator.batch.source;
-
-import org.apache.flink.ml.operator.batch.BatchOperator;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.util.Preconditions;
-
-/** Transform the Table to SourceBatchOp. */
-public final class TableSourceBatchOp extends BatchOperator<TableSourceBatchOp> {
-
- public TableSourceBatchOp(Table table) {
- super(null);
- Preconditions.checkArgument(table != null, "The source table cannot be null.");
- this.setOutput(table);
- }
-
- @Override
- public TableSourceBatchOp linkFrom(BatchOperator<?>... inputs) {
- throw new UnsupportedOperationException(
- "Table source operator should not have any upstream to link from.");
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/stream/StreamOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/stream/StreamOperator.java
deleted file mode 100644
index 373556c..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/stream/StreamOperator.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.operator.stream;
-
-import org.apache.flink.ml.api.misc.param.Params;
-import org.apache.flink.ml.operator.AlgoOperator;
-import org.apache.flink.ml.operator.stream.source.TableSourceStreamOp;
-import org.apache.flink.table.api.Table;
-
-/**
- * Base class of stream algorithm operators.
- *
- * <p>This class extends {@link AlgoOperator} to support data transmission between StreamOperator.
- */
-public abstract class StreamOperator<T extends StreamOperator<T>> extends AlgoOperator<T> {
-
- public StreamOperator() {
- super();
- }
-
- /**
- * The constructor of StreamOperator with {@link Params}.
- *
- * @param params the initial Params.
- */
- public StreamOperator(Params params) {
- super(params);
- }
-
- /**
- * Link to another {@link StreamOperator}.
- *
- * <p>Link the <code>next</code> StreamOperator using this StreamOperator as its input.
- *
- * <p>For example:
- *
- * <pre>{@code
- * StreamOperator a = ...;
- * StreamOperator b = ...;
- *
- * StreamOperator c = a.link(b)
- * }</pre>
- *
- * <p>The StreamOperator <code>c</code> in the above code is the same instance as <code>b</code>
- * which takes <code>a</code> as its input. Note that StreamOperator <code>b</code> will be
- * changed to link from StreamOperator <code>a</code>.
- *
- * @param next the linked StreamOperator
- * @param <S> type of StreamOperator returned
- * @return the linked next
- * @see #linkFrom(StreamOperator[])
- */
- public <S extends StreamOperator<?>> S link(S next) {
- next.linkFrom(this);
- return next;
- }
-
- /**
- * Link from others {@link StreamOperator}.
- *
- * <p>Link this object to StreamOperator using the StreamOperators as its input.
- *
- * <p>For example:
- *
- * <pre>{@code
- * StreamOperator a = ...;
- * StreamOperator b = ...;
- * StreamOperator c = ...;
- *
- * StreamOperator d = c.linkFrom(a, b)
- * }</pre>
- *
- * <p>The <code>d</code> in the above code is the same instance as StreamOperator <code>c</code>
- * which takes both <code>a</code> and <code>b</code> as its input.
- *
- * <p>note: It is not recommended to linkFrom itself or linkFrom the same group inputs twice.
- *
- * @param inputs the linked inputs
- * @return the linked this object
- */
- public abstract T linkFrom(StreamOperator<?>... inputs);
-
- /**
- * create a new StreamOperator from table.
- *
- * @param table the input table
- * @return the new StreamOperator
- */
- public static StreamOperator<?> fromTable(Table table) {
- return new TableSourceStreamOp(table);
- }
-
- protected static StreamOperator<?> checkAndGetFirst(StreamOperator<?>... inputs) {
- checkOpSize(1, inputs);
- return inputs[0];
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/stream/source/TableSourceStreamOp.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/stream/source/TableSourceStreamOp.java
deleted file mode 100644
index 90a284e..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/operator/stream/source/TableSourceStreamOp.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.operator.stream.source;
-
-import org.apache.flink.ml.operator.stream.StreamOperator;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.util.Preconditions;
-
-/** Transform the Table to SourceStreamOp. */
-public final class TableSourceStreamOp extends StreamOperator<TableSourceStreamOp> {
-
- public TableSourceStreamOp(Table table) {
- super(null);
- Preconditions.checkArgument(table != null, "The source table cannot be null.");
- this.setOutput(table);
- }
-
- @Override
- public TableSourceStreamOp linkFrom(StreamOperator<?>... inputs) {
- throw new UnsupportedOperationException(
- "Table source operator should not have any upstream to link from.");
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/HasMLEnvironmentId.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/HasMLEnvironmentId.java
deleted file mode 100644
index 4ea1b6c..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/HasMLEnvironmentId.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-import org.apache.flink.ml.common.MLEnvironmentFactory;
-
-/** An interface for classes with a parameter specifying the id of MLEnvironment. */
-public interface HasMLEnvironmentId<T> extends WithParams<T> {
-
- ParamInfo<Long> ML_ENVIRONMENT_ID =
- ParamInfoFactory.createParamInfo("MLEnvironmentId", Long.class)
- .setDescription("ID of ML environment.")
- .setHasDefaultValue(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID)
- .build();
-
- default Long getMLEnvironmentId() {
- return get(ML_ENVIRONMENT_ID);
- }
-
- default T setMLEnvironmentId(Long value) {
- return set(ML_ENVIRONMENT_ID, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputCol.java
deleted file mode 100644
index e731b8e..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputCol.java
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/**
- * An interface for classes with a parameter specifying the name of the output column.
- *
- * @see HasOutputCols
- * @see HasOutputColDefaultAsNull
- * @see HasOutputColsDefaultAsNull
- */
-public interface HasOutputCol<T> extends WithParams<T> {
-
- ParamInfo<String> OUTPUT_COL =
- ParamInfoFactory.createParamInfo("outputCol", String.class)
- .setDescription("Name of the output column")
- .setRequired()
- .build();
-
- default String getOutputCol() {
- return get(OUTPUT_COL);
- }
-
- default T setOutputCol(String value) {
- return set(OUTPUT_COL, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputColDefaultAsNull.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputColDefaultAsNull.java
deleted file mode 100644
index 1846f89..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputColDefaultAsNull.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/**
- * An interface for classes with a parameter specifying name of the output column with a null
- * default value.
- *
- * @see HasOutputCol
- * @see HasOutputCols
- * @see HasOutputColsDefaultAsNull
- */
-public interface HasOutputColDefaultAsNull<T> extends WithParams<T> {
-
- ParamInfo<String> OUTPUT_COL =
- ParamInfoFactory.createParamInfo("outputCol", String.class)
- .setDescription("Name of the output column")
- .setHasDefaultValue(null)
- .build();
-
- default String getOutputCol() {
- return get(OUTPUT_COL);
- }
-
- default T setOutputCol(String value) {
- return set(OUTPUT_COL, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputCols.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputCols.java
deleted file mode 100644
index 7bccc73..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputCols.java
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/**
- * An interface for classes with a parameter specifying names of multiple output columns.
- *
- * @see HasOutputCol
- * @see HasOutputColDefaultAsNull
- * @see HasOutputColsDefaultAsNull
- */
-public interface HasOutputCols<T> extends WithParams<T> {
-
- ParamInfo<String[]> OUTPUT_COLS =
- ParamInfoFactory.createParamInfo("outputCols", String[].class)
- .setDescription("Names of the output columns")
- .setRequired()
- .build();
-
- default String[] getOutputCols() {
- return get(OUTPUT_COLS);
- }
-
- default T setOutputCols(String... value) {
- return set(OUTPUT_COLS, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputColsDefaultAsNull.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputColsDefaultAsNull.java
deleted file mode 100644
index 2b9eacd..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasOutputColsDefaultAsNull.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/**
- * An interface for classes with a parameter specifying names of multiple output columns. The
- * default parameter value is null.
- *
- * @see HasOutputCol
- * @see HasOutputColDefaultAsNull
- * @see HasOutputCols
- */
-public interface HasOutputColsDefaultAsNull<T> extends WithParams<T> {
-
- ParamInfo<String[]> OUTPUT_COLS =
- ParamInfoFactory.createParamInfo("outputCols", String[].class)
- .setDescription("Names of the output columns")
- .setHasDefaultValue(null)
- .build();
-
- default String[] getOutputCols() {
- return get(OUTPUT_COLS);
- }
-
- default T setOutputCols(String... value) {
- return set(OUTPUT_COLS, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasPredictionCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasPredictionCol.java
deleted file mode 100644
index 1bfd493..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasPredictionCol.java
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/** An interface for classes with a parameter specifying the column name of the prediction. */
-public interface HasPredictionCol<T> extends WithParams<T> {
-
- ParamInfo<String> PREDICTION_COL =
- ParamInfoFactory.createParamInfo("predictionCol", String.class)
- .setDescription("Column name of prediction.")
- .setRequired()
- .build();
-
- default String getPredictionCol() {
- return get(PREDICTION_COL);
- }
-
- default T setPredictionCol(String value) {
- return set(PREDICTION_COL, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasPredictionDetailCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasPredictionDetailCol.java
deleted file mode 100644
index 4e409fa..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasPredictionDetailCol.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/**
- * An interface for classes with a parameter specifying the column name of prediction detail.
- *
- * <p>The detail is the information of prediction result, such as the probability of each label in
- * classifier.
- */
-public interface HasPredictionDetailCol<T> extends WithParams<T> {
-
- ParamInfo<String> PREDICTION_DETAIL_COL =
- ParamInfoFactory.createParamInfo("predictionDetailCol", String.class)
- .setDescription(
- "Column name of prediction result, it will include detailed info.")
- .build();
-
- default String getPredictionDetailCol() {
- return get(PREDICTION_DETAIL_COL);
- }
-
- default T setPredictionDetailCol(String value) {
- return set(PREDICTION_DETAIL_COL, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasReservedCols.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasReservedCols.java
deleted file mode 100644
index f06689e..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasReservedCols.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/**
- * An interface for classes with a parameter specifying the names of the columns to be retained in
- * the output table.
- */
-public interface HasReservedCols<T> extends WithParams<T> {
-
- ParamInfo<String[]> RESERVED_COLS =
- ParamInfoFactory.createParamInfo("reservedCols", String[].class)
- .setDescription("Names of the columns to be retained in the output table")
- .setHasDefaultValue(null)
- .build();
-
- default String[] getReservedCols() {
- return get(RESERVED_COLS);
- }
-
- default T setReservedCols(String... value) {
- return set(RESERVED_COLS, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedCol.java
deleted file mode 100644
index 88560dc..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedCol.java
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/**
- * An interface for classes with a parameter specifying the name of the table column.
- *
- * @see HasSelectedColDefaultAsNull
- * @see HasSelectedCols
- * @see HasSelectedColsDefaultAsNull
- */
-public interface HasSelectedCol<T> extends WithParams<T> {
-
- ParamInfo<String> SELECTED_COL =
- ParamInfoFactory.createParamInfo("selectedCol", String.class)
- .setDescription("Name of the selected column used for processing")
- .setRequired()
- .build();
-
- default String getSelectedCol() {
- return get(SELECTED_COL);
- }
-
- default T setSelectedCol(String value) {
- return set(SELECTED_COL, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedColDefaultAsNull.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedColDefaultAsNull.java
deleted file mode 100644
index 72a1dd7..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedColDefaultAsNull.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/**
- * An interface for classes with a parameter specifying the name of the table column with null
- * default value.
- *
- * @see HasSelectedCol
- * @see HasSelectedCols
- * @see HasSelectedColsDefaultAsNull
- */
-public interface HasSelectedColDefaultAsNull<T> extends WithParams<T> {
-
- ParamInfo<String> SELECTED_COL =
- ParamInfoFactory.createParamInfo("selectedCol", String.class)
- .setDescription("Name of the selected column used for processing")
- .setHasDefaultValue(null)
- .build();
-
- default String getSelectedCol() {
- return get(SELECTED_COL);
- }
-
- default T setSelectedCol(String value) {
- return set(SELECTED_COL, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedCols.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedCols.java
deleted file mode 100644
index 68b5bba..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedCols.java
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/**
- * An interface for classes with a parameter specifying the name of multiple table columns.
- *
- * @see HasSelectedCol
- * @see HasSelectedColDefaultAsNull
- * @see HasSelectedColsDefaultAsNull
- */
-public interface HasSelectedCols<T> extends WithParams<T> {
-
- ParamInfo<String[]> SELECTED_COLS =
- ParamInfoFactory.createParamInfo("selectedCols", String[].class)
- .setDescription("Names of the columns used for processing")
- .setRequired()
- .build();
-
- default String[] getSelectedCols() {
- return get(SELECTED_COLS);
- }
-
- default T setSelectedCols(String... value) {
- return set(SELECTED_COLS, value);
- }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedColsDefaultAsNull.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedColsDefaultAsNull.java
deleted file mode 100644
index 5e2801f..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/params/shared/colname/HasSelectedColsDefaultAsNull.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.params.shared.colname;
-
-import org.apache.flink.ml.api.misc.param.ParamInfo;
-import org.apache.flink.ml.api.misc.param.ParamInfoFactory;
-import org.apache.flink.ml.api.misc.param.WithParams;
-
-/**
- * An interface for classes with a parameter specifying the name of multiple table columns with null
- * default value.
- *
- * @see HasSelectedCol
- * @see HasSelectedColDefaultAsNull
- * @see HasSelectedCols
- */
-public interface HasSelectedColsDefaultAsNull<T> extends WithParams<T> {
-
- ParamInfo<String[]> SELECTED_COLS =
- ParamInfoFactory.createParamInfo("selectedCols", String[].class)
- .setDescription("Names of the columns used for processing")
- .setHasDefaultValue(null)
- .build();
-
- default String[] getSelectedCols() {
- return get(SELECTED_COLS);
- }
-
- default T setSelectedCols(String... value) {
- return set(SELECTED_COLS, value);
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/MLEnvironmentTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/MLEnvironmentTest.java
deleted file mode 100644
index db51622..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/MLEnvironmentTest.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common;
-
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.EnvironmentSettings;
-import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-
-/** Test cases for MLEnvironment. */
-public class MLEnvironmentTest {
- @Test
- public void testDefaultConstructor() {
- MLEnvironment mlEnvironment = new MLEnvironment();
- Assert.assertNotNull(mlEnvironment.getStreamExecutionEnvironment());
- Assert.assertNotNull(mlEnvironment.getStreamTableEnvironment());
- }
-
- @Test
- public void testConstructWithStreamEnv() {
- StreamExecutionEnvironment streamExecutionEnvironment =
- StreamExecutionEnvironment.getExecutionEnvironment();
- StreamTableEnvironment streamTableEnvironment =
- StreamTableEnvironment.create(
- streamExecutionEnvironment, EnvironmentSettings.newInstance().build());
-
- MLEnvironment mlEnvironment =
- new MLEnvironment(streamExecutionEnvironment, streamTableEnvironment);
-
- Assert.assertSame(
- mlEnvironment.getStreamExecutionEnvironment(), streamExecutionEnvironment);
- Assert.assertSame(mlEnvironment.getStreamTableEnvironment(), streamTableEnvironment);
- }
-
- @Test
- public void testRemoveDefaultMLEnvironment() {
- MLEnvironment defaultEnv = MLEnvironmentFactory.getDefault();
- MLEnvironmentFactory.remove(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID);
- assertEquals(
- "The default MLEnvironment should not have been removed",
- defaultEnv,
- MLEnvironmentFactory.get(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID));
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/BLASTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/BLASTest.java
deleted file mode 100644
index 13e67ff..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/BLASTest.java
+++ /dev/null
@@ -1,186 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import org.junit.Assert;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ExpectedException;
-
-/** The test cases for {@link BLAS}. */
-public class BLASTest {
- private static final double TOL = 1.0e-8;
- private DenseMatrix mat = new DenseMatrix(2, 3, new double[] {1, 4, 2, 5, 3, 6});
- private DenseVector dv1 = new DenseVector(new double[] {1, 2});
- private DenseVector dv2 = new DenseVector(new double[] {1, 2, 3});
- private SparseVector spv1 = new SparseVector(2, new int[] {0, 1}, new double[] {1, 2});
- private SparseVector spv2 = new SparseVector(3, new int[] {0, 2}, new double[] {1, 3});
-
- @Rule public ExpectedException thrown = ExpectedException.none();
-
- @Test
- public void testAsum() throws Exception {
- Assert.assertEquals(BLAS.asum(dv1), 3.0, TOL);
- Assert.assertEquals(BLAS.asum(spv1), 3.0, TOL);
- }
-
- @Test
- public void testScal() throws Exception {
- DenseVector v1 = dv1.clone();
- BLAS.scal(0.5, v1);
- Assert.assertArrayEquals(v1.getData(), new double[] {0.5, 1.0}, TOL);
-
- SparseVector v2 = spv1.clone();
- BLAS.scal(0.5, v2);
- Assert.assertArrayEquals(v2.getIndices(), spv1.getIndices());
- Assert.assertArrayEquals(v2.getValues(), new double[] {0.5, 1.0}, TOL);
- }
-
- @Test
- public void testDot() throws Exception {
- DenseVector v = DenseVector.ones(2);
- Assert.assertEquals(BLAS.dot(dv1, v), 3.0, TOL);
- }
-
- @Test
- public void testAxpy() throws Exception {
- DenseVector v = DenseVector.ones(2);
- BLAS.axpy(1.0, dv1, v);
- Assert.assertArrayEquals(v.getData(), new double[] {2, 3}, TOL);
- BLAS.axpy(1.0, spv1, v);
- Assert.assertArrayEquals(v.getData(), new double[] {3, 5}, TOL);
- BLAS.axpy(1, 1.0, new double[] {1}, 0, v.getData(), 1);
- Assert.assertArrayEquals(v.getData(), new double[] {3, 6}, TOL);
- }
-
- private DenseMatrix simpleMM(DenseMatrix m1, DenseMatrix m2) {
- DenseMatrix mm = new DenseMatrix(m1.numRows(), m2.numCols());
- for (int i = 0; i < m1.numRows(); i++) {
- for (int j = 0; j < m2.numCols(); j++) {
- double s = 0.;
- for (int k = 0; k < m1.numCols(); k++) {
- s += m1.get(i, k) * m2.get(k, j);
- }
- mm.set(i, j, s);
- }
- }
- return mm;
- }
-
- @Test
- public void testGemm() throws Exception {
- DenseMatrix m32 = DenseMatrix.rand(3, 2);
- DenseMatrix m24 = DenseMatrix.rand(2, 4);
- DenseMatrix m34 = DenseMatrix.rand(3, 4);
- DenseMatrix m42 = DenseMatrix.rand(4, 2);
- DenseMatrix m43 = DenseMatrix.rand(4, 3);
-
- DenseMatrix a34 = DenseMatrix.zeros(3, 4);
- BLAS.gemm(1.0, m32, false, m24, false, 0., a34);
- Assert.assertArrayEquals(a34.getData(), simpleMM(m32, m24).getData(), TOL);
-
- BLAS.gemm(1.0, m32, false, m42, true, 0., a34);
- Assert.assertArrayEquals(a34.getData(), simpleMM(m32, m42.transpose()).getData(), TOL);
-
- DenseMatrix a24 = DenseMatrix.zeros(2, 4);
- BLAS.gemm(1.0, m32, true, m34, false, 0., a24);
- Assert.assertArrayEquals(a24.getData(), simpleMM(m32.transpose(), m34).getData(), TOL);
-
- BLAS.gemm(1.0, m32, true, m43, true, 0., a24);
- Assert.assertArrayEquals(
- a24.getData(), simpleMM(m32.transpose(), m43.transpose()).getData(), TOL);
- }
-
- @Test
- public void testGemmSizeCheck() throws Exception {
- thrown.expect(IllegalArgumentException.class);
- DenseMatrix m32 = DenseMatrix.rand(3, 2);
- DenseMatrix m42 = DenseMatrix.rand(4, 2);
- DenseMatrix a34 = DenseMatrix.zeros(3, 4);
- BLAS.gemm(1.0, m32, false, m42, false, 0., a34);
- }
-
- @Test
- public void testGemmTransposeSizeCheck() throws Exception {
- thrown.expect(IllegalArgumentException.class);
- DenseMatrix m32 = DenseMatrix.rand(3, 2);
- DenseMatrix m42 = DenseMatrix.rand(4, 2);
- DenseMatrix a34 = DenseMatrix.zeros(3, 4);
- BLAS.gemm(1.0, m32, true, m42, true, 0., a34);
- }
-
- @Test
- public void testGemvDense() throws Exception {
- DenseVector y1 = DenseVector.ones(2);
- BLAS.gemv(2.0, mat, false, dv2, 0., y1);
- Assert.assertArrayEquals(new double[] {28, 64}, y1.data, TOL);
-
- DenseVector y2 = DenseVector.ones(2);
- BLAS.gemv(2.0, mat, false, dv2, 1., y2);
- Assert.assertArrayEquals(new double[] {29, 65}, y2.data, TOL);
- }
-
- @Test
- public void testGemvDenseTranspose() throws Exception {
- DenseVector y1 = DenseVector.ones(3);
- BLAS.gemv(1.0, mat, true, dv1, 0., y1);
- Assert.assertArrayEquals(new double[] {9, 12, 15}, y1.data, TOL);
-
- DenseVector y2 = DenseVector.ones(3);
- BLAS.gemv(1.0, mat, true, dv1, 1., y2);
- Assert.assertArrayEquals(new double[] {10, 13, 16}, y2.data, TOL);
- }
-
- @Test
- public void testGemvSparse() throws Exception {
- DenseVector y1 = DenseVector.ones(2);
- BLAS.gemv(2.0, mat, false, spv2, 0., y1);
- Assert.assertArrayEquals(new double[] {20, 44}, y1.data, TOL);
-
- DenseVector y2 = DenseVector.ones(2);
- BLAS.gemv(2.0, mat, false, spv2, 1., y2);
- Assert.assertArrayEquals(new double[] {21, 45}, y2.data, TOL);
- }
-
- @Test
- public void testGemvSparseTranspose() throws Exception {
- DenseVector y1 = DenseVector.ones(3);
- BLAS.gemv(2.0, mat, true, spv1, 0., y1);
- Assert.assertArrayEquals(new double[] {18, 24, 30}, y1.data, TOL);
-
- DenseVector y2 = DenseVector.ones(3);
- BLAS.gemv(2.0, mat, true, spv1, 1., y2);
- Assert.assertArrayEquals(new double[] {19, 25, 31}, y2.data, TOL);
- }
-
- @Test
- public void testGemvSizeCheck() throws Exception {
- thrown.expect(IllegalArgumentException.class);
- DenseVector y = DenseVector.ones(2);
- BLAS.gemv(2.0, mat, false, dv1, 0., y);
- }
-
- @Test
- public void testGemvTransposeSizeCheck() throws Exception {
- thrown.expect(IllegalArgumentException.class);
- DenseVector y = DenseVector.ones(2);
- BLAS.gemv(2.0, mat, true, dv1, 0., y);
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseMatrixTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseMatrixTest.java
deleted file mode 100644
index 40cb9a4..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseMatrixTest.java
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-/** Test cases for DenseMatrix. */
-public class DenseMatrixTest {
-
- private static final double TOL = 1.0e-6;
-
- private static void assertEqual2D(double[][] matA, double[][] matB) {
- assert (matA.length == matB.length);
- assert (matA[0].length == matB[0].length);
- int m = matA.length;
- int n = matA[0].length;
- for (int i = 0; i < m; i++) {
- for (int j = 0; j < n; j++) {
- Assert.assertEquals(matA[i][j], matB[i][j], TOL);
- }
- }
- }
-
- private static double[][] simpleMM(double[][] matA, double[][] matB) {
- int m = matA.length;
- int n = matB[0].length;
- int k = matA[0].length;
- double[][] matC = new double[m][n];
- for (int i = 0; i < m; i++) {
- for (int j = 0; j < n; j++) {
- matC[i][j] = 0.;
- for (int l = 0; l < k; l++) {
- matC[i][j] += matA[i][l] * matB[l][j];
- }
- }
- }
- return matC;
- }
-
- private static double[] simpleMV(double[][] matA, double[] x) {
- int m = matA.length;
- int n = matA[0].length;
- assert (n == x.length);
- double[] y = new double[m];
- for (int i = 0; i < m; i++) {
- y[i] = 0.;
- for (int j = 0; j < n; j++) {
- y[i] += matA[i][j] * x[j];
- }
- }
- return y;
- }
-
- @Test
- public void testPlusEquals() throws Exception {
- DenseMatrix matA =
- new DenseMatrix(
- new double[][] {
- new double[] {1, 3, 5},
- new double[] {2, 4, 6},
- });
- DenseMatrix matB = DenseMatrix.ones(2, 3);
- matA.plusEquals(matB);
- Assert.assertArrayEquals(matA.getData(), new double[] {2, 3, 4, 5, 6, 7}, TOL);
- matA.plusEquals(1.0);
- Assert.assertArrayEquals(matA.getData(), new double[] {3, 4, 5, 6, 7, 8}, TOL);
- }
-
- @Test
- public void testMinusEquals() throws Exception {
- DenseMatrix matA =
- new DenseMatrix(
- new double[][] {
- new double[] {1, 3, 5},
- new double[] {2, 4, 6},
- });
- DenseMatrix matB = DenseMatrix.ones(2, 3);
- matA.minusEquals(matB);
- Assert.assertArrayEquals(matA.getData(), new double[] {0, 1, 2, 3, 4, 5}, TOL);
- }
-
- @Test
- public void testPlus() throws Exception {
- DenseMatrix matA =
- new DenseMatrix(
- new double[][] {
- new double[] {1, 3, 5},
- new double[] {2, 4, 6},
- });
- DenseMatrix matB = DenseMatrix.ones(2, 3);
- DenseMatrix matC = matA.plus(matB);
- Assert.assertArrayEquals(matC.getData(), new double[] {2, 3, 4, 5, 6, 7}, TOL);
- DenseMatrix matD = matA.plus(1.0);
- Assert.assertArrayEquals(matD.getData(), new double[] {2, 3, 4, 5, 6, 7}, TOL);
- }
-
- @Test
- public void testMinus() throws Exception {
- DenseMatrix matA =
- new DenseMatrix(
- new double[][] {
- new double[] {1, 3, 5},
- new double[] {2, 4, 6},
- });
- DenseMatrix matB = DenseMatrix.ones(2, 3);
- DenseMatrix matC = matA.minus(matB);
- Assert.assertArrayEquals(matC.getData(), new double[] {0, 1, 2, 3, 4, 5}, TOL);
- }
-
- @Test
- public void testMM() throws Exception {
- DenseMatrix matA = DenseMatrix.rand(4, 3);
- DenseMatrix matB = DenseMatrix.rand(3, 5);
- DenseMatrix matC = matA.multiplies(matB);
- assertEqual2D(
- matC.getArrayCopy2D(), simpleMM(matA.getArrayCopy2D(), matB.getArrayCopy2D()));
-
- DenseMatrix matD = new DenseMatrix(5, 4);
- BLAS.gemm(1., matB, true, matA, true, 0., matD);
- Assert.assertArrayEquals(matD.transpose().getData(), matC.data, TOL);
- }
-
- @Test
- public void testMV() throws Exception {
- DenseMatrix matA = DenseMatrix.rand(4, 3);
- DenseVector x = DenseVector.ones(3);
- DenseVector y = matA.multiplies(x);
- Assert.assertArrayEquals(y.getData(), simpleMV(matA.getArrayCopy2D(), x.getData()), TOL);
-
- SparseVector x2 = new SparseVector(3, new int[] {0, 1, 2}, new double[] {1, 1, 1});
- DenseVector y2 = matA.multiplies(x2);
- Assert.assertArrayEquals(y2.getData(), y.getData(), TOL);
- }
-
- @Test
- public void testDataSelection() throws Exception {
- DenseMatrix mat =
- new DenseMatrix(
- new double[][] {
- new double[] {1, 2, 3},
- new double[] {4, 5, 6},
- new double[] {7, 8, 9},
- });
- DenseMatrix sub1 = mat.selectRows(new int[] {1});
- DenseMatrix sub2 = mat.getSubMatrix(1, 2, 1, 2);
- Assert.assertEquals(sub1.numRows(), 1);
- Assert.assertEquals(sub1.numCols(), 3);
- Assert.assertEquals(sub2.numRows(), 1);
- Assert.assertEquals(sub2.numCols(), 1);
- Assert.assertArrayEquals(sub1.getData(), new double[] {4, 5, 6}, TOL);
- Assert.assertArrayEquals(sub2.getData(), new double[] {5}, TOL);
-
- double[] row = mat.getRow(1);
- double[] col = mat.getColumn(1);
- Assert.assertArrayEquals(row, new double[] {4, 5, 6}, 0.);
- Assert.assertArrayEquals(col, new double[] {2, 5, 8}, 0.);
- }
-
- @Test
- public void testSum() throws Exception {
- DenseMatrix matA = DenseMatrix.ones(3, 2);
- Assert.assertEquals(matA.sum(), 6.0, TOL);
- }
-
- @Test
- public void testRowMajorFormat() throws Exception {
- double[] data = new double[] {1, 2, 3, 4, 5, 6};
- DenseMatrix matA = new DenseMatrix(2, 3, data, true);
- Assert.assertArrayEquals(data, new double[] {1, 4, 2, 5, 3, 6}, 0.);
- Assert.assertArrayEquals(matA.getData(), new double[] {1, 4, 2, 5, 3, 6}, 0.);
-
- data = new double[] {1, 2, 3, 4};
- matA = new DenseMatrix(2, 2, data, true);
- Assert.assertArrayEquals(data, new double[] {1, 3, 2, 4}, 0.);
- Assert.assertArrayEquals(matA.getData(), new double[] {1, 3, 2, 4}, 0.);
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseVectorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseVectorTest.java
deleted file mode 100644
index 7859317..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/DenseVectorTest.java
+++ /dev/null
@@ -1,158 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-/** Test cases for DenseVector. */
-public class DenseVectorTest {
- private static final double TOL = 1.0e-6;
-
- @Test
- public void testSize() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- Assert.assertEquals(vec.size(), 3);
- }
-
- @Test
- public void testNormL1() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- Assert.assertEquals(vec.normL1(), 6, 0);
- }
-
- @Test
- public void testNormMax() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- Assert.assertEquals(vec.normInf(), 3, 0);
- }
-
- @Test
- public void testNormL2() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- Assert.assertEquals(vec.normL2(), Math.sqrt(1 + 4 + 9), TOL);
- }
-
- @Test
- public void testNormL2Square() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- Assert.assertEquals(vec.normL2Square(), 1 + 4 + 9, TOL);
- }
-
- @Test
- public void testSlice() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- DenseVector sliced = vec.slice(new int[] {0, 2});
- Assert.assertArrayEquals(new double[] {1, -3}, sliced.getData(), 0);
- }
-
- @Test
- public void testMinus() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- DenseVector d = new DenseVector(new double[] {1, 2, 1});
- DenseVector vec2 = vec.minus(d);
- Assert.assertArrayEquals(vec.getData(), new double[] {1, 2, -3}, 0);
- Assert.assertArrayEquals(vec2.getData(), new double[] {0, 0, -4}, TOL);
- vec.minusEqual(d);
- Assert.assertArrayEquals(vec.getData(), new double[] {0, 0, -4}, TOL);
- }
-
- @Test
- public void testPlus() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- DenseVector d = new DenseVector(new double[] {1, 2, 1});
- DenseVector vec2 = vec.plus(d);
- Assert.assertArrayEquals(vec.getData(), new double[] {1, 2, -3}, 0);
- Assert.assertArrayEquals(vec2.getData(), new double[] {2, 4, -2}, TOL);
- vec.plusEqual(d);
- Assert.assertArrayEquals(vec.getData(), new double[] {2, 4, -2}, TOL);
- }
-
- @Test
- public void testPlusScaleEqual() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- DenseVector vec2 = new DenseVector(new double[] {1, 0, 2});
- vec.plusScaleEqual(vec2, 2.);
- Assert.assertArrayEquals(vec.getData(), new double[] {3, 2, 1}, TOL);
- }
-
- @Test
- public void testDot() throws Exception {
- DenseVector vec1 = new DenseVector(new double[] {1, 2, -3});
- DenseVector vec2 = new DenseVector(new double[] {3, 2, 1});
- Assert.assertEquals(vec1.dot(vec2), 3 + 4 - 3, TOL);
- }
-
- @Test
- public void testPrefix() throws Exception {
- DenseVector vec1 = new DenseVector(new double[] {1, 2, -3});
- DenseVector vec2 = vec1.prefix(0);
- Assert.assertArrayEquals(vec2.getData(), new double[] {0, 1, 2, -3}, 0);
- }
-
- @Test
- public void testAppend() throws Exception {
- DenseVector vec1 = new DenseVector(new double[] {1, 2, -3});
- DenseVector vec2 = vec1.append(0);
- Assert.assertArrayEquals(vec2.getData(), new double[] {1, 2, -3, 0}, 0);
- }
-
- @Test
- public void testOuter() throws Exception {
- DenseVector vec1 = new DenseVector(new double[] {1, 2, -3});
- DenseVector vec2 = new DenseVector(new double[] {3, 2, 1});
- DenseMatrix outer = vec1.outer(vec2);
- Assert.assertArrayEquals(
- outer.getArrayCopy1D(true), new double[] {3, 2, 1, 6, 4, 2, -9, -6, -3}, TOL);
- }
-
- @Test
- public void testNormalize() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- vec.normalizeEqual(1.0);
- Assert.assertArrayEquals(vec.getData(), new double[] {1. / 6, 2. / 6, -3. / 6}, TOL);
- }
-
- @Test
- public void testStandardize() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- vec.standardizeEqual(1.0, 1.0);
- Assert.assertArrayEquals(vec.getData(), new double[] {0, 1, -4}, TOL);
- }
-
- @Test
- public void testIterator() throws Exception {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- VectorIterator iterator = vec.iterator();
- Assert.assertTrue(iterator.hasNext());
- Assert.assertEquals(iterator.getIndex(), 0);
- Assert.assertEquals(iterator.getValue(), 1, 0);
- iterator.next();
- Assert.assertTrue(iterator.hasNext());
- Assert.assertEquals(iterator.getIndex(), 1);
- Assert.assertEquals(iterator.getValue(), 2, 0);
- iterator.next();
- Assert.assertTrue(iterator.hasNext());
- Assert.assertEquals(iterator.getIndex(), 2);
- Assert.assertEquals(iterator.getValue(), -3, 0);
- iterator.next();
- Assert.assertFalse(iterator.hasNext());
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/MatVecOpTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/MatVecOpTest.java
deleted file mode 100644
index bbcd6d8..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/MatVecOpTest.java
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-
-/** Test cases for {@link MatVecOp}. */
-public class MatVecOpTest {
- private static final double TOL = 1.0e-6;
- private DenseVector dv;
- private SparseVector sv;
-
- @Before
- public void setUp() throws Exception {
- dv = new DenseVector(new double[] {1, 2, 3, 4});
- sv = new SparseVector(4, new int[] {0, 2}, new double[] {1., 1.});
- }
-
- @Test
- public void testPlus() throws Exception {
- Vector plusResult1 = MatVecOp.plus(dv, sv);
- Vector plusResult2 = MatVecOp.plus(sv, dv);
- Vector plusResult3 = MatVecOp.plus(sv, sv);
- Vector plusResult4 = MatVecOp.plus(dv, dv);
- Assert.assertTrue(plusResult1 instanceof DenseVector);
- Assert.assertTrue(plusResult2 instanceof DenseVector);
- Assert.assertTrue(plusResult3 instanceof SparseVector);
- Assert.assertTrue(plusResult4 instanceof DenseVector);
- Assert.assertArrayEquals(
- ((DenseVector) plusResult1).getData(), new double[] {2, 2, 4, 4}, TOL);
- Assert.assertArrayEquals(
- ((DenseVector) plusResult2).getData(), new double[] {2, 2, 4, 4}, TOL);
- Assert.assertArrayEquals(((SparseVector) plusResult3).getIndices(), new int[] {0, 2});
- Assert.assertArrayEquals(
- ((SparseVector) plusResult3).getValues(), new double[] {2., 2.}, TOL);
- Assert.assertArrayEquals(
- ((DenseVector) plusResult4).getData(), new double[] {2, 4, 6, 8}, TOL);
- }
-
- @Test
- public void testMinus() throws Exception {
- Vector minusResult1 = MatVecOp.minus(dv, sv);
- Vector minusResult2 = MatVecOp.minus(sv, dv);
- Vector minusResult3 = MatVecOp.minus(sv, sv);
- Vector minusResult4 = MatVecOp.minus(dv, dv);
- Assert.assertTrue(minusResult1 instanceof DenseVector);
- Assert.assertTrue(minusResult2 instanceof DenseVector);
- Assert.assertTrue(minusResult3 instanceof SparseVector);
- Assert.assertTrue(minusResult4 instanceof DenseVector);
- Assert.assertArrayEquals(
- ((DenseVector) minusResult1).getData(), new double[] {0, 2, 2, 4}, TOL);
- Assert.assertArrayEquals(
- ((DenseVector) minusResult2).getData(), new double[] {0, -2, -2, -4}, TOL);
- Assert.assertArrayEquals(((SparseVector) minusResult3).getIndices(), new int[] {0, 2});
- Assert.assertArrayEquals(
- ((SparseVector) minusResult3).getValues(), new double[] {0., 0.}, TOL);
- Assert.assertArrayEquals(
- ((DenseVector) minusResult4).getData(), new double[] {0, 0, 0, 0}, TOL);
- }
-
- @Test
- public void testDot() throws Exception {
- Assert.assertEquals(MatVecOp.dot(dv, sv), 4.0, TOL);
- Assert.assertEquals(MatVecOp.dot(sv, dv), 4.0, TOL);
- Assert.assertEquals(MatVecOp.dot(sv, sv), 2.0, TOL);
- Assert.assertEquals(MatVecOp.dot(dv, dv), 30.0, TOL);
- }
-
- @Test
- public void testSumAbsDiff() throws Exception {
- Assert.assertEquals(MatVecOp.sumAbsDiff(dv, sv), 8.0, TOL);
- Assert.assertEquals(MatVecOp.sumAbsDiff(sv, dv), 8.0, TOL);
- Assert.assertEquals(MatVecOp.sumAbsDiff(sv, sv), 0.0, TOL);
- Assert.assertEquals(MatVecOp.sumAbsDiff(dv, dv), 0.0, TOL);
- }
-
- @Test
- public void testSumSquaredDiff() throws Exception {
- Assert.assertEquals(MatVecOp.sumSquaredDiff(dv, sv), 24.0, TOL);
- Assert.assertEquals(MatVecOp.sumSquaredDiff(sv, dv), 24.0, TOL);
- Assert.assertEquals(MatVecOp.sumSquaredDiff(sv, sv), 0.0, TOL);
- Assert.assertEquals(MatVecOp.sumSquaredDiff(dv, dv), 0.0, TOL);
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/SparseVectorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/SparseVectorTest.java
deleted file mode 100644
index 0f68102..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/SparseVectorTest.java
+++ /dev/null
@@ -1,232 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-import java.util.Map;
-import java.util.TreeMap;
-
-/** Test cases for SparseVector. */
-public class SparseVectorTest {
- private static final double TOL = 1.0e-6;
- private SparseVector v1 =
- new SparseVector(8, new int[] {1, 3, 5, 7}, new double[] {2.0, 2.0, 2.0, 2.0});
- private SparseVector v2 =
- new SparseVector(8, new int[] {3, 4, 5}, new double[] {1.0, 1.0, 1.0});
-
- @Test
- public void testConstructor() throws Exception {
- int[] indices = new int[] {3, 7, 2, 1};
- double[] values = new double[] {3.0, 7.0, 2.0, 1.0};
- Map<Integer, Double> map = new TreeMap<>();
- for (int i = 0; i < indices.length; i++) {
- map.put(indices[i], values[i]);
- }
- SparseVector v = new SparseVector(8, map);
- Assert.assertArrayEquals(v.getIndices(), new int[] {1, 2, 3, 7});
- Assert.assertArrayEquals(v.getValues(), new double[] {1, 2, 3, 7}, TOL);
- }
-
- @Test
- public void testSize() throws Exception {
- Assert.assertEquals(v1.size(), 8);
- }
-
- @Test
- public void testSet() throws Exception {
- SparseVector v = v1.clone();
- v.set(2, 2.0);
- v.set(3, 3.0);
- Assert.assertEquals(v.get(2), 2.0, TOL);
- Assert.assertEquals(v.get(3), 3.0, TOL);
- }
-
- @Test
- public void testAdd() throws Exception {
- SparseVector v = v1.clone();
- v.add(2, 2.0);
- v.add(3, 3.0);
- Assert.assertEquals(v.get(2), 2.0, TOL);
- Assert.assertEquals(v.get(3), 5.0, TOL);
- }
-
- @Test
- public void testPrefix() throws Exception {
- SparseVector prefixed = v1.prefix(0.2);
- Assert.assertArrayEquals(prefixed.getIndices(), new int[] {0, 2, 4, 6, 8});
- Assert.assertArrayEquals(prefixed.getValues(), new double[] {0.2, 2, 2, 2, 2}, 0);
- }
-
- @Test
- public void testAppend() throws Exception {
- SparseVector prefixed = v1.append(0.2);
- Assert.assertArrayEquals(prefixed.getIndices(), new int[] {1, 3, 5, 7, 8});
- Assert.assertArrayEquals(prefixed.getValues(), new double[] {2, 2, 2, 2, 0.2}, 0);
- }
-
- @Test
- public void testSortIndices() throws Exception {
- int n = 8;
- int[] indices = new int[] {7, 5, 3, 1};
- double[] values = new double[] {7, 5, 3, 1};
- v1 = new SparseVector(n, indices, values);
- Assert.assertArrayEquals(values, new double[] {1, 3, 5, 7}, 0.);
- Assert.assertArrayEquals(v1.getValues(), new double[] {1, 3, 5, 7}, 0.);
- Assert.assertArrayEquals(indices, new int[] {1, 3, 5, 7});
- Assert.assertArrayEquals(v1.getIndices(), new int[] {1, 3, 5, 7});
- }
-
- @Test
- public void testNormL2Square() throws Exception {
- Assert.assertEquals(v2.normL2Square(), 3.0, TOL);
- }
-
- @Test
- public void testMinus() throws Exception {
- Vector d = v2.minus(v1);
- Assert.assertEquals(d.get(0), 0.0, TOL);
- Assert.assertEquals(d.get(1), -2.0, TOL);
- Assert.assertEquals(d.get(2), 0.0, TOL);
- Assert.assertEquals(d.get(3), -1.0, TOL);
- Assert.assertEquals(d.get(4), 1.0, TOL);
- }
-
- @Test
- public void testPlus() throws Exception {
- Vector d = v1.plus(v2);
- Assert.assertEquals(d.get(0), 0.0, TOL);
- Assert.assertEquals(d.get(1), 2.0, TOL);
- Assert.assertEquals(d.get(2), 0.0, TOL);
- Assert.assertEquals(d.get(3), 3.0, TOL);
-
- DenseVector dv = DenseVector.ones(8);
- dv = dv.plus(v2);
- Assert.assertArrayEquals(dv.getData(), new double[] {1, 1, 1, 2, 2, 2, 1, 1}, TOL);
- }
-
- @Test
- public void testDot() throws Exception {
- Assert.assertEquals(v1.dot(v2), 4.0, TOL);
- }
-
- @Test
- public void testGet() throws Exception {
- Assert.assertEquals(v1.get(5), 2.0, TOL);
- Assert.assertEquals(v1.get(6), 0.0, TOL);
- }
-
- @Test
- public void testSlice() throws Exception {
- int n = 8;
- int[] indices = new int[] {1, 3, 5, 7};
- double[] values = new double[] {2.0, 3.0, 4.0, 5.0};
- SparseVector v = new SparseVector(n, indices, values);
-
- int[] indices1 = new int[] {5, 4, 3};
- SparseVector vec1 = v.slice(indices1);
- Assert.assertEquals(vec1.size(), 3);
- Assert.assertArrayEquals(vec1.getIndices(), new int[] {0, 2});
- Assert.assertArrayEquals(vec1.getValues(), new double[] {4.0, 3.0}, 0.0);
-
- int[] indices2 = new int[] {3, 5};
- SparseVector vec2 = v.slice(indices2);
- Assert.assertArrayEquals(vec2.getIndices(), new int[] {0, 1});
- Assert.assertArrayEquals(vec2.getValues(), new double[] {3.0, 4.0}, 0.0);
-
- int[] indices3 = new int[] {2, 4};
- SparseVector vec3 = v.slice(indices3);
- Assert.assertEquals(vec3.size(), 2);
- Assert.assertArrayEquals(vec3.getIndices(), new int[] {});
- Assert.assertArrayEquals(vec3.getValues(), new double[] {}, 0.0);
-
- int[] indices4 = new int[] {2, 2, 4, 4};
- SparseVector vec4 = v.slice(indices4);
- Assert.assertEquals(vec4.size(), 4);
- Assert.assertArrayEquals(vec4.getIndices(), new int[] {});
- Assert.assertArrayEquals(vec4.getValues(), new double[] {}, 0.0);
- }
-
- @Test
- public void testToDenseVector() throws Exception {
- int[] indices = new int[] {1, 3, 5};
- double[] values = new double[] {1.0, 3.0, 5.0};
- SparseVector v = new SparseVector(-1, indices, values);
- DenseVector dv = v.toDenseVector();
- Assert.assertEquals(dv.size(), 6);
- Assert.assertArrayEquals(dv.getData(), new double[] {0, 1, 0, 3, 0, 5}, TOL);
- }
-
- @Test
- public void testRemoveZeroValues() throws Exception {
- int[] indices = new int[] {1, 3, 5};
- double[] values = new double[] {0.0, 3.0, 0.0};
- SparseVector v = new SparseVector(6, indices, values);
- v.removeZeroValues();
- Assert.assertArrayEquals(v.getIndices(), new int[] {3});
- Assert.assertArrayEquals(v.getValues(), new double[] {3}, TOL);
- }
-
- @Test
- public void testOuter() throws Exception {
- DenseMatrix outerProduct = v1.outer(v2);
- Assert.assertEquals(outerProduct.numRows(), 8);
- Assert.assertEquals(outerProduct.numCols(), 8);
- Assert.assertArrayEquals(
- outerProduct.getRow(0), new double[] {0, 0, 0, 0, 0, 0, 0, 0}, TOL);
- Assert.assertArrayEquals(
- outerProduct.getRow(1), new double[] {0, 0, 0, 2, 2, 2, 0, 0}, TOL);
- Assert.assertArrayEquals(
- outerProduct.getRow(2), new double[] {0, 0, 0, 0, 0, 0, 0, 0}, TOL);
- Assert.assertArrayEquals(
- outerProduct.getRow(3), new double[] {0, 0, 0, 2, 2, 2, 0, 0}, TOL);
- Assert.assertArrayEquals(
- outerProduct.getRow(4), new double[] {0, 0, 0, 0, 0, 0, 0, 0}, TOL);
- Assert.assertArrayEquals(
- outerProduct.getRow(5), new double[] {0, 0, 0, 2, 2, 2, 0, 0}, TOL);
- Assert.assertArrayEquals(
- outerProduct.getRow(6), new double[] {0, 0, 0, 0, 0, 0, 0, 0}, TOL);
- Assert.assertArrayEquals(
- outerProduct.getRow(7), new double[] {0, 0, 0, 2, 2, 2, 0, 0}, TOL);
- }
-
- @Test
- public void testIterator() throws Exception {
- VectorIterator iterator = v1.iterator();
- Assert.assertTrue(iterator.hasNext());
- Assert.assertEquals(iterator.getIndex(), 1);
- Assert.assertEquals(iterator.getValue(), 2, 0);
- iterator.next();
- Assert.assertTrue(iterator.hasNext());
- Assert.assertEquals(iterator.getIndex(), 3);
- Assert.assertEquals(iterator.getValue(), 2, 0);
- iterator.next();
- Assert.assertTrue(iterator.hasNext());
- Assert.assertEquals(iterator.getIndex(), 5);
- Assert.assertEquals(iterator.getValue(), 2, 0);
- iterator.next();
- Assert.assertTrue(iterator.hasNext());
- Assert.assertEquals(iterator.getIndex(), 7);
- Assert.assertEquals(iterator.getValue(), 2, 0);
- iterator.next();
- Assert.assertFalse(iterator.hasNext());
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/VectorUtilTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/VectorUtilTest.java
deleted file mode 100644
index 26d0ac4..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/linalg/VectorUtilTest.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.linalg;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-/** Test cases for VectorUtil. */
-public class VectorUtilTest {
- @Test
- public void testParseDenseAndToString() {
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- String str = VectorUtil.toString(vec);
- Assert.assertEquals(str, "1.0 2.0 -3.0");
- Assert.assertArrayEquals(vec.getData(), VectorUtil.parseDense(str).getData(), 0);
- }
-
- @Test
- public void testParseDenseWithSpace() {
- DenseVector vec1 = VectorUtil.parseDense("1 2 -3");
- DenseVector vec2 = VectorUtil.parseDense(" 1 2 -3 ");
- DenseVector vec = new DenseVector(new double[] {1, 2, -3});
- Assert.assertArrayEquals(vec1.getData(), vec.getData(), 0);
- Assert.assertArrayEquals(vec2.getData(), vec.getData(), 0);
- }
-
- @Test
- public void testSparseToString() {
- SparseVector v1 =
- new SparseVector(8, new int[] {1, 3, 5, 7}, new double[] {2.0, 2.0, 2.0, 2.0});
- Assert.assertEquals(VectorUtil.toString(v1), "$8$1:2.0 3:2.0 5:2.0 7:2.0");
- }
-
- @Test
- public void testParseSparse() {
- SparseVector vec1 = VectorUtil.parseSparse("0:1 2:-3");
- SparseVector vec3 = VectorUtil.parseSparse("$4$0:1 2:-3");
- SparseVector vec4 = VectorUtil.parseSparse("$4$");
- SparseVector vec5 = VectorUtil.parseSparse("");
- Assert.assertEquals(vec1.get(0), 1., 0.);
- Assert.assertEquals(vec1.get(2), -3., 0.);
- Assert.assertArrayEquals(vec3.toDenseVector().getData(), new double[] {1, 0, -3, 0}, 0);
- Assert.assertEquals(vec3.size(), 4);
- Assert.assertArrayEquals(vec4.toDenseVector().getData(), new double[] {0, 0, 0, 0}, 0);
- Assert.assertEquals(vec4.size(), 4);
- Assert.assertEquals(vec5.size(), -1);
- }
-
- @Test
- public void testParseAndToStringOfVector() {
- Vector sparse = VectorUtil.parseSparse("0:1 2:-3");
- Vector dense = VectorUtil.parseDense("1 0 -3");
-
- Assert.assertEquals(VectorUtil.toString(sparse), "0:1.0 2:-3.0");
- Assert.assertEquals(VectorUtil.toString(dense), "1.0 0.0 -3.0");
- Assert.assertTrue(VectorUtil.parse("$4$0:1 2:-3") instanceof SparseVector);
- Assert.assertTrue(VectorUtil.parse("1 0 -3") instanceof DenseVector);
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/statistics/basicstatistic/MultivariateGaussianTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/statistics/basicstatistic/MultivariateGaussianTest.java
deleted file mode 100644
index 1bda854..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/statistics/basicstatistic/MultivariateGaussianTest.java
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.statistics.basicstatistic;
-
-import org.apache.flink.ml.common.linalg.DenseMatrix;
-import org.apache.flink.ml.common.linalg.DenseVector;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-/** Test cases for {@link MultivariateGaussian}. */
-public class MultivariateGaussianTest {
- private static final double TOL = 1.0e-5;
-
- @Test
- public void testUnivariate() throws Exception {
- DenseVector x1 = new DenseVector(new double[] {0.0});
- DenseVector x2 = new DenseVector(new double[] {1.5});
- DenseVector mu = DenseVector.zeros(1);
- DenseMatrix sigma1 = DenseMatrix.ones(1, 1);
- MultivariateGaussian dist1 = new MultivariateGaussian(mu, sigma1);
- Assert.assertEquals(dist1.pdf(x1), 0.39894, TOL);
- Assert.assertEquals(dist1.pdf(x2), 0.12952, TOL);
-
- DenseMatrix sigma2 = DenseMatrix.ones(1, 1);
- sigma2.scaleEqual(4.0);
- MultivariateGaussian dist2 = new MultivariateGaussian(mu, sigma2);
- Assert.assertEquals(dist2.pdf(x1), 0.19947, TOL);
- Assert.assertEquals(dist2.pdf(x2), 0.15057, TOL);
- }
-
- @Test
- public void testMultivariate() throws Exception {
- DenseVector mu = DenseVector.zeros(2);
-
- DenseMatrix sigma1 = DenseMatrix.eye(2);
- MultivariateGaussian mg1 = new MultivariateGaussian(mu, sigma1);
- Assert.assertEquals(mg1.pdf(DenseVector.zeros(2)), 0.15915, TOL);
- Assert.assertEquals(mg1.pdf(DenseVector.ones(2)), 0.05855, TOL);
-
- DenseMatrix sigma2 = new DenseMatrix(2, 2, new double[] {4.0, -1.0, -1.0, 2.0});
- MultivariateGaussian mg2 = new MultivariateGaussian(mu, sigma2);
- Assert.assertEquals(mg2.pdf(DenseVector.zeros(2)), 0.060155, TOL);
- Assert.assertEquals(mg2.pdf(DenseVector.ones(2)), 0.033971, TOL);
- }
-
- @Test
- public void testMultivariateDegenerate() throws Exception {
- DenseVector mu = DenseVector.zeros(2);
- DenseMatrix sigma = new DenseMatrix(2, 2, new double[] {1.0, 1.0, 1.0, 1.0});
- MultivariateGaussian mg = new MultivariateGaussian(mu, sigma);
- Assert.assertEquals(mg.pdf(DenseVector.zeros(2)), 0.11254, TOL);
- Assert.assertEquals(mg.pdf(DenseVector.ones(2)), 0.068259, TOL);
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/DataStreamConversionUtilTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/DataStreamConversionUtilTest.java
deleted file mode 100644
index 60edbc1..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/DataStreamConversionUtilTest.java
+++ /dev/null
@@ -1,208 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.utils;
-
-import org.apache.flink.api.common.functions.MapFunction;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.ml.common.MLEnvironmentFactory;
-import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.datastream.DataStreamUtils;
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.Table;
-import org.apache.flink.table.api.TableSchema;
-import org.apache.flink.table.api.ValidationException;
-import org.apache.flink.types.Row;
-
-import org.junit.Assert;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ExpectedException;
-
-import java.util.Iterator;
-import java.util.concurrent.ExecutionException;
-
-/** Unit Test for DataStreamConversionUtil. */
-public class DataStreamConversionUtilTest {
- @Rule public ExpectedException thrown = ExpectedException.none();
-
- @Test
- public void testForceTypeSchema() {
- StreamExecutionEnvironment env =
- MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment();
-
- DataStream<Row> input =
- env.fromElements(Row.of("s1"))
- .map(new DataStreamConversionUtilTest.GenericTypeMap());
- Table table2 =
- DataStreamConversionUtil.toTable(
- MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID,
- input,
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(Integer.class)});
- Assert.assertEquals(
- new TableSchema(
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(Integer.class)}),
- table2.getSchema());
- }
-
- @Test
- public void testForceTypeWithTableSchema() {
- StreamExecutionEnvironment env =
- MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment();
-
- DataStream<Row> input =
- env.fromElements(Row.of("s1"))
- .map(new DataStreamConversionUtilTest.GenericTypeMap());
- Table table2 =
- DataStreamConversionUtil.toTable(
- MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID,
- input,
- new TableSchema(
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(Integer.class)}));
- Assert.assertEquals(
- new TableSchema(
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(Integer.class)}),
- table2.getSchema());
- }
-
- @Test
- public void testExceptionWithoutTypeSchema() {
- thrown.expect(ValidationException.class);
- StreamExecutionEnvironment env =
- MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment();
- DataStream<Row> input =
- env.fromElements(Row.of("s1"))
- .map(new DataStreamConversionUtilTest.GenericTypeMap());
- DataStreamConversionUtil.toTable(
- MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID, input, new String[] {"f0"});
- }
-
- @Test
- public void testBasicConvert() throws Exception {
- StreamExecutionEnvironment env =
- MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment();
- DataStream<Row> input = env.fromElements(Row.of("a"));
- Table table1 =
- DataStreamConversionUtil.toTable(
- MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID,
- input,
- new String[] {"word"});
- Assert.assertEquals(
- new TableSchema(
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(String.class)}),
- table1.getSchema());
- DataStream<Row> rowDataStream =
- DataStreamConversionUtil.fromTable(
- MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID, table1);
- Iterator<Row> result = DataStreamUtils.collect(rowDataStream);
- Assert.assertEquals(Row.of("a"), result.next());
- Assert.assertFalse(result.hasNext());
- }
-
- @Test
- public void testE2E() throws Exception {
- StreamExecutionEnvironment env =
- MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment();
-
- DataStream<Row> input = env.fromElements(Row.of("a"));
-
- Table table1 =
- DataStreamConversionUtil.toTable(
- MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID,
- input,
- new String[] {"word"});
- Assert.assertEquals(
- new TableSchema(
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(String.class)}),
- table1.getSchema());
-
- DataStream<Row> genericInput1 = input.map(new GenericTypeMap());
-
- // Force type should go through with explicit type info on generic type input.
- Table table2 =
- DataStreamConversionUtil.toTable(
- MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID,
- genericInput1,
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(Integer.class)});
-
- Assert.assertEquals(
- new TableSchema(
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(Integer.class)}),
- table2.getSchema());
-
- DataStream<Row> genericInput2 = input.map(new GenericTypeMap());
-
- // Force type should go through with table schema on generic type input.
- Table table3 =
- DataStreamConversionUtil.toTable(
- MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID,
- genericInput2,
- new TableSchema(
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(Integer.class)}));
-
- Assert.assertEquals(
- new TableSchema(
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(Integer.class)}),
- table3.getSchema());
-
- // applying toTable again on the same input should fail
- thrown.expect(IllegalStateException.class);
- DataStreamConversionUtil.toTable(
- MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID,
- genericInput2,
- new TableSchema(
- new String[] {"word"},
- new TypeInformation[] {TypeInformation.of(Integer.class)}));
-
- // Validation should fail due to type inference error.
- DataStream<Row> genericInput3 = input.map(new GenericTypeMap());
- thrown.expect(ValidationException.class);
- DataStreamConversionUtil.toTable(
- MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID,
- genericInput3,
- new String[] {"word"});
-
- // Output should go through when using correct type to output.
- DataStreamConversionUtil.fromTable(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID, table1)
- .print();
-
- // Output should NOT go through when using incorrect type forcing.
- thrown.expect(ExecutionException.class);
- DataStreamConversionUtil.fromTable(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID, table2)
- .print();
- }
-
- private static class GenericTypeMap implements MapFunction<Row, Row> {
-
- @Override
- public Row map(Row value) throws Exception {
- return value;
- }
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/OutputColsHelperTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/OutputColsHelperTest.java
deleted file mode 100644
index 8b89e1e..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/OutputColsHelperTest.java
+++ /dev/null
@@ -1,249 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.utils;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.table.api.TableSchema;
-import org.apache.flink.types.Row;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-/** Unit test for OutputColsHelper. */
-public class OutputColsHelperTest {
-
- private TableSchema tableSchema =
- new TableSchema(
- new String[] {"f0", "f1", "f2"},
- new TypeInformation[] {
- TypeInformation.of(String.class),
- TypeInformation.of(Long.class),
- TypeInformation.of(Integer.class)
- });
- private String[] reservedColNames = new String[] {"f0"};
- private Row row = Row.of("a", 1L, 1);
-
- @Test
- public void testResultSchema() {
- TableSchema expectSchema =
- new TableSchema(
- new String[] {"f0", "f1", "f2", "res"},
- new TypeInformation[] {
- TypeInformation.of(String.class),
- TypeInformation.of(Long.class),
- TypeInformation.of(Integer.class),
- TypeInformation.of(String.class)
- });
- OutputColsHelper helper =
- new OutputColsHelper(tableSchema, "res", TypeInformation.of(String.class));
- Assert.assertEquals(expectSchema, helper.getResultSchema());
-
- expectSchema =
- new TableSchema(
- new String[] {"f0", "res"},
- new TypeInformation[] {
- TypeInformation.of(String.class), TypeInformation.of(String.class)
- });
- helper =
- new OutputColsHelper(
- tableSchema, "res", TypeInformation.of(String.class), reservedColNames);
- Assert.assertEquals(expectSchema, helper.getResultSchema());
- Assert.assertArrayEquals(reservedColNames, helper.getReservedColumns());
-
- expectSchema =
- new TableSchema(
- new String[] {"f0", "res1", "res2"},
- new TypeInformation[] {
- TypeInformation.of(String.class),
- TypeInformation.of(String.class),
- TypeInformation.of(Integer.class)
- });
- helper =
- new OutputColsHelper(
- tableSchema,
- new String[] {"res1", "res2"},
- new TypeInformation[] {
- TypeInformation.of(String.class), TypeInformation.of(Integer.class)
- },
- reservedColNames);
- Assert.assertEquals(expectSchema, helper.getResultSchema());
-
- expectSchema =
- new TableSchema(
- new String[] {"f0", "f1", "f2", "res"},
- new TypeInformation[] {
- TypeInformation.of(String.class),
- TypeInformation.of(Long.class),
- TypeInformation.of(Integer.class),
- TypeInformation.of(String.class)
- });
- helper = new OutputColsHelper(tableSchema, "res", TypeInformation.of(String.class));
- Assert.assertEquals(expectSchema, helper.getResultSchema());
- Assert.assertArrayEquals(tableSchema.getFieldNames(), helper.getReservedColumns());
-
- expectSchema =
- new TableSchema(
- new String[] {"f0", "f1", "f2"},
- new TypeInformation[] {
- TypeInformation.of(Integer.class),
- TypeInformation.of(Long.class),
- TypeInformation.of(Integer.class)
- });
- helper = new OutputColsHelper(tableSchema, "f0", TypeInformation.of(Integer.class));
- Assert.assertEquals(expectSchema, helper.getResultSchema());
-
- expectSchema =
- new TableSchema(
- new String[] {"f0", "f1", "f2"},
- new TypeInformation[] {
- TypeInformation.of(Integer.class),
- TypeInformation.of(Long.class),
- TypeInformation.of(String.class)
- });
- helper =
- new OutputColsHelper(
- tableSchema,
- new String[] {"f0", "f2"},
- new TypeInformation[] {
- TypeInformation.of(Integer.class), TypeInformation.of(String.class)
- });
- Assert.assertEquals(expectSchema, helper.getResultSchema());
-
- expectSchema =
- new TableSchema(
- new String[] {"f0", "res"},
- new TypeInformation[] {
- TypeInformation.of(String.class), TypeInformation.of(Integer.class)
- });
- helper =
- new OutputColsHelper(
- tableSchema,
- new String[] {"res", "f0"},
- new TypeInformation[] {
- TypeInformation.of(Integer.class), TypeInformation.of(String.class)
- },
- reservedColNames);
- Assert.assertEquals(expectSchema, helper.getResultSchema());
-
- expectSchema =
- new TableSchema(
- new String[] {"f0", "f1", "res"},
- new TypeInformation[] {
- TypeInformation.of(String.class),
- TypeInformation.of(Long.class),
- TypeInformation.of(Integer.class)
- });
- helper =
- new OutputColsHelper(
- tableSchema,
- new String[] {"res"},
- new TypeInformation[] {
- TypeInformation.of(Integer.class), TypeInformation.of(String.class)
- },
- new String[] {"f1", "f0"});
- Assert.assertEquals(expectSchema, helper.getResultSchema());
- }
-
- @Test
- public void testResultRow() {
- OutputColsHelper helper =
- new OutputColsHelper(tableSchema, "res", TypeInformation.of(String.class));
- Row expectRow = Row.of("a", 1L, 1, "b");
- Assert.assertEquals(helper.getResultRow(row, Row.of("b")), expectRow);
-
- helper =
- new OutputColsHelper(
- tableSchema,
- new String[] {"res1", "res2"},
- new TypeInformation[] {
- TypeInformation.of(String.class), TypeInformation.of(Integer.class)
- });
- expectRow = Row.of("a", 1L, 1, "b", 2);
- Assert.assertEquals(expectRow, helper.getResultRow(row, Row.of("b", 2)));
-
- helper =
- new OutputColsHelper(
- tableSchema,
- new String[] {"res", "f0"},
- new TypeInformation[] {
- TypeInformation.of(Integer.class), TypeInformation.of(String.class)
- },
- reservedColNames);
- expectRow = Row.of("b", 2);
- Assert.assertEquals(expectRow, helper.getResultRow(row, Row.of(2, "b")));
- }
-
- @Test
- public void testExceptionCase() {
- TableSchema expectSchema =
- new TableSchema(
- new String[] {"f0", "res"},
- new TypeInformation[] {
- TypeInformation.of(String.class), TypeInformation.of(Integer.class)
- });
- OutputColsHelper helper =
- new OutputColsHelper(
- tableSchema,
- new String[] {"res", "f0"},
- new TypeInformation[] {
- TypeInformation.of(Integer.class), TypeInformation.of(String.class)
- },
- new String[] {"res", "res2"});
- Assert.assertEquals(expectSchema, helper.getResultSchema());
-
- expectSchema =
- new TableSchema(
- new String[] {"f0", "f1", "res"},
- new TypeInformation[] {
- TypeInformation.of(String.class),
- TypeInformation.of(Long.class),
- TypeInformation.of(Integer.class)
- });
- helper =
- new OutputColsHelper(
- tableSchema,
- new String[] {"res", "f0"},
- new TypeInformation[] {
- TypeInformation.of(Integer.class), TypeInformation.of(String.class)
- },
- new String[] {"f1", "res"});
- Assert.assertEquals(expectSchema, helper.getResultSchema());
-
- expectSchema =
- new TableSchema(
- new String[] {"f0", "f1", "f2"},
- new TypeInformation[] {
- TypeInformation.of(String.class),
- TypeInformation.of(Integer.class),
- TypeInformation.of(Double.class)
- });
- helper =
- new OutputColsHelper(
- tableSchema,
- new String[] {"f1", "f0", "f2"},
- new TypeInformation[] {
- TypeInformation.of(Integer.class),
- TypeInformation.of(String.class),
- TypeInformation.of(Double.class)
- },
- new String[] {"f1", "res"});
- Assert.assertEquals(expectSchema, helper.getResultSchema());
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/TableUtilTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/TableUtilTest.java
deleted file mode 100644
index e3733c3..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/TableUtilTest.java
+++ /dev/null
@@ -1,200 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.utils;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.common.typeinfo.Types;
-import org.apache.flink.table.api.TableSchema;
-import org.apache.flink.types.Row;
-
-import org.junit.Assert;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ExpectedException;
-
-import java.util.Collections;
-
-/** Unit test for TableUtil. */
-public class TableUtilTest {
- @Rule public ExpectedException thrown = ExpectedException.none();
- private String[] colNames = new String[] {"f0", "f1", "f2"};
- private TableSchema tableSchema =
- new TableSchema(colNames, new TypeInformation[] {Types.INT, Types.LONG, Types.STRING});
-
- @Test
- public void testFindIndexFromName() {
- String[] colNames = new String[] {"f0", "f1", "F2"};
- Assert.assertEquals(0, TableUtil.findColIndex(colNames, "f0"));
- Assert.assertEquals(1, TableUtil.findColIndex(colNames, "F1"));
- Assert.assertEquals(-1, TableUtil.findColIndex(colNames, "f3"));
- Assert.assertEquals(0, TableUtil.findColIndex(tableSchema, "f0"));
-
- Assert.assertArrayEquals(
- new int[] {1, 2}, TableUtil.findColIndices(colNames, new String[] {"f1", "F2"}));
- Assert.assertArrayEquals(
- new int[] {1, 2}, TableUtil.findColIndices(tableSchema, new String[] {"f1", "F2"}));
- Assert.assertArrayEquals(
- new int[] {-1, 2},
- TableUtil.findColIndices(tableSchema, new String[] {"f3", "F2"}));
- Assert.assertArrayEquals(new int[] {0, 1, 2}, TableUtil.findColIndices(colNames, null));
- }
-
- @Test
- public void testFindTypeFromTable() {
- Assert.assertArrayEquals(
- new TypeInformation[] {TypeInformation.of(Integer.class), Types.LONG},
- TableUtil.findColTypes(tableSchema, new String[] {"f0", "f1"}));
- Assert.assertArrayEquals(
- new TypeInformation[] {Types.LONG, null},
- TableUtil.findColTypes(tableSchema, new String[] {"f1", "f3"}));
- Assert.assertArrayEquals(
- new TypeInformation[] {Types.INT, Types.LONG, Types.STRING},
- TableUtil.findColTypes(tableSchema, null));
-
- Assert.assertEquals(
- TypeInformation.of(Integer.class), TableUtil.findColType(tableSchema, "f0"));
- Assert.assertNull(TableUtil.findColType(tableSchema, "f3"));
- }
-
- @Test
- public void isNumberIsStringTest() {
- Assert.assertTrue(TableUtil.isSupportedNumericType(Types.INT));
- Assert.assertTrue(TableUtil.isSupportedNumericType(Types.DOUBLE));
- Assert.assertTrue(TableUtil.isSupportedNumericType(Types.LONG));
- Assert.assertTrue(TableUtil.isSupportedNumericType(Types.BYTE));
- Assert.assertTrue(TableUtil.isSupportedNumericType(Types.FLOAT));
- Assert.assertTrue(TableUtil.isSupportedNumericType(Types.SHORT));
- Assert.assertFalse(TableUtil.isSupportedNumericType(Types.STRING));
- Assert.assertTrue(TableUtil.isString(Types.STRING));
- }
-
- @Test
- public void assertColExistOrTypeTest() {
- String[] colNames = new String[] {"f0", "f1", "f2"};
- TableUtil.assertSelectedColExist(colNames, null);
- TableUtil.assertSelectedColExist(colNames, "f0");
- TableUtil.assertSelectedColExist(colNames, "f0", "f1");
-
- TableUtil.assertNumericalCols(tableSchema, null);
- TableUtil.assertNumericalCols(tableSchema, "f1");
- TableUtil.assertNumericalCols(tableSchema, "f0", "f1");
-
- TableUtil.assertStringCols(tableSchema, null);
- TableUtil.assertStringCols(tableSchema, "f2");
-
- TableUtil.assertVectorCols(tableSchema, null);
- }
-
- @Test
- public void assertColExistOrTypeExceptionTest() {
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage(" col is not exist f3");
- TableUtil.assertSelectedColExist(colNames, "f3");
-
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage(" col is not exist f3");
- TableUtil.assertSelectedColExist(colNames, "f0", "f3");
-
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("col type must be number f2");
- TableUtil.assertNumericalCols(tableSchema, "f2");
-
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("col type must be number f2");
- TableUtil.assertNumericalCols(tableSchema, "f2", "f0");
-
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("col type must be string f2");
- TableUtil.assertStringCols(tableSchema, "f2");
-
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("col type must be string f0");
- TableUtil.assertStringCols(tableSchema, "f0", "f3");
- }
-
- @Test
- public void getNumericColsTest() {
- TableSchema tableSchema =
- new TableSchema(
- new String[] {"f0", "f1", "F2", "f3"},
- new TypeInformation[] {Types.INT, Types.LONG, Types.STRING, Types.BOOLEAN});
-
- Assert.assertArrayEquals(new String[] {"f0", "f1"}, TableUtil.getNumericCols(tableSchema));
- Assert.assertArrayEquals(
- new String[] {"f1"}, TableUtil.getNumericCols(tableSchema, new String[] {"f0"}));
- Assert.assertArrayEquals(
- new String[] {"f0", "f1"},
- TableUtil.getNumericCols(tableSchema, new String[] {"f2"}));
- }
-
- @Test
- public void getCategoricalColsTest() {
- TableSchema tableSchema =
- new TableSchema(
- new String[] {"f0", "f1", "f2", "f3"},
- new TypeInformation[] {Types.INT, Types.LONG, Types.STRING, Types.BOOLEAN});
-
- Assert.assertArrayEquals(
- new String[] {"f2", "f3"},
- TableUtil.getCategoricalCols(tableSchema, tableSchema.getFieldNames(), null));
- Assert.assertArrayEquals(
- new String[] {"f2", "f0", "f3"},
- TableUtil.getCategoricalCols(
- tableSchema, new String[] {"f2", "f1", "f0", "f3"}, new String[] {"f0"}));
-
- thrown.expect(IllegalArgumentException.class);
- Assert.assertArrayEquals(
- new String[] {"f3", "f2"},
- TableUtil.getCategoricalCols(
- tableSchema, new String[] {"f3", "f0"}, new String[] {"f2"}));
- }
-
- @Test
- public void getStringColsTest() {
- TableSchema tableSchema =
- new TableSchema(
- new String[] {"f0", "f1", "F2", "f3"},
- new TypeInformation[] {Types.INT, Types.LONG, Types.STRING, Types.BOOLEAN});
-
- Assert.assertArrayEquals(new String[] {"F2"}, TableUtil.getStringCols(tableSchema));
- Assert.assertArrayEquals(
- new String[] {}, TableUtil.getStringCols(tableSchema, new String[] {"F2"}));
- }
-
- @Test
- public void formatTest() {
- TableSchema tableSchema =
- new TableSchema(
- new String[] {"f0", "f1", "F2", "f3"},
- new TypeInformation[] {Types.INT, Types.LONG, Types.STRING, Types.BOOLEAN});
- Row row = Row.of(1, 2L, "3", true);
-
- String format =
- TableUtil.format(tableSchema.getFieldNames(), Collections.singletonList(row));
- Assert.assertTrue(
- ("f0|f1|F2|f3\r\n" + "--|--|--|--\n" + "1|2|3|true").equalsIgnoreCase(format));
- }
-
- @Test
- public void testTempTable() {
- Assert.assertTrue(TableUtil.getTempTableName().startsWith("temp_"));
- Assert.assertFalse(TableUtil.getTempTableName().contains("-"));
- }
-}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/VectorTypesTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/VectorTypesTest.java
deleted file mode 100644
index 054956d..0000000
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/utils/VectorTypesTest.java
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.flink.ml.common.utils;
-
-import org.apache.flink.api.common.ExecutionConfig;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.memory.DataInputDeserializer;
-import org.apache.flink.core.memory.DataOutputSerializer;
-import org.apache.flink.ml.common.linalg.DenseVector;
-import org.apache.flink.ml.common.linalg.SparseVector;
-import org.apache.flink.ml.common.linalg.Vector;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.concurrent.ThreadLocalRandom;
-
-/** Test cases for VectorTypes. */
-public class VectorTypesTest {
- @SuppressWarnings("unchecked")
- private static <V extends Vector> void doVectorSerDeserTest(TypeSerializer ser, V vector)
- throws IOException {
- DataOutputSerializer out = new DataOutputSerializer(1024);
- ser.serialize(vector, out);
- DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
- Vector deserialize = (Vector) ser.deserialize(in);
- Assert.assertEquals(vector.getClass(), deserialize.getClass());
- Assert.assertEquals(vector, deserialize);
- }
-
- @Test
- public void testVectorsSerDeser() throws IOException {
- // Prepare data
- SparseVector sparseVector =
- new SparseVector(
- 10,
- new HashMap<Integer, Double>() {
- {
- ThreadLocalRandom rand = ThreadLocalRandom.current();
- for (int i = 0; i < 10; i += 2) {
- this.put(i, rand.nextDouble());
- }
- }
- });
- DenseVector denseVector = DenseVector.rand(10);
-
- // Prepare serializer
- ExecutionConfig config = new ExecutionConfig();
- TypeSerializer<Vector> vecSer = VectorTypes.VECTOR.createSerializer(config);
- TypeSerializer<SparseVector> sparseSer = VectorTypes.SPARSE_VECTOR.createSerializer(config);
- TypeSerializer<DenseVector> denseSer = VectorTypes.DENSE_VECTOR.createSerializer(config);
-
- // Do tests.
- doVectorSerDeserTest(vecSer, sparseVector);
- doVectorSerDeserTest(vecSer, denseVector);
- doVectorSerDeserTest(sparseSer, sparseVector);
- doVectorSerDeserTest(denseSer, denseVector);
- }
-}
diff --git a/pom.xml b/pom.xml
index ee15d3d..5eb1805 100644
--- a/pom.xml
+++ b/pom.xml
@@ -57,7 +57,6 @@ under the License.
<module>flink-ml-uber</module>
<module>flink-ml-iteration</module>
<module>flink-ml-tests</module>
- <module>flink-ml-examples</module>
</modules>
<properties>