Posted to reviews@spark.apache.org by "WeichenXu123 (via GitHub)" <gi...@apache.org> on 2023/03/06 13:33:05 UTC

[GitHub] [spark] WeichenXu123 opened a new pull request, #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

WeichenXu123 opened a new pull request, #40297:
URL: https://github.com/apache/spark/pull/40297

   <!--
   Thanks for sending a pull request!  Here are some tips for you:
     1. If this is your first time, please read our contributor guidelines: https://spark.apache.org/contributing.html
     2. Ensure you have added or run the appropriate tests for your PR: https://spark.apache.org/developer-tools.html
     3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][SPARK-XXXX] Your PR title ...'.
     4. Be sure to keep the PR description updated to reflect all changes.
     5. Please write your PR title to summarize what this PR proposes.
     6. If possible, provide a concise example to reproduce the issue for a faster review.
     7. If you want to add a new configuration, please read the guideline first for naming configurations in
        'core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala'.
     8. If you want to add or modify an error type or message, please read the guideline first in
        'core/src/main/resources/error/README.md'.
   -->
   
   ### What changes were proposed in this pull request?
   <!--
   Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. 
   If possible, please consider writing useful notes for better and faster reviews in your PR. See the examples below.
     1. If you refactor some codes with changing classes, showing the class hierarchy will help reviewers.
     2. If you fix some SQL features, you can provide some references of other DBMSes.
     3. If there is design documentation, please add the link.
     4. If there is a discussion in the mailing list, please add the link.
   -->
   Design doc:
   https://docs.google.com/document/d/1V5rOgksmOnA8AsJFZ_rasSYDQuP06_vrcfp3RY_22o8/edit#
   
   ### Why are the changes needed?
   <!--
   Please clarify why the changes are needed. For instance,
     1. If you propose a new API, clarify the use case for a new API.
     2. If you fix a bug, you can clarify why it is a bug.
   -->
   
   
   ### Does this PR introduce _any_ user-facing change?
   <!--
   Note that it means *any* user-facing change including all aspects such as the documentation fix.
   If yes, please clarify the previous behavior and the change this PR proposes - provide the console output, description and/or an example to show the behavior difference if possible.
   If possible, please also clarify if this is a user-facing change compared to the released Spark versions or within the unreleased branches such as master.
   If no, write 'No'.
   -->
   
   
   ### How was this patch tested?
   <!--
   If tests were added, say they were added here. Please make sure to add some test cases that check the changes thoroughly including negative and positive cases if possible.
   If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future.
   If tests were not added, please describe why they were not added and/or why it was difficult to add.
   If benchmark tests were added, please run the benchmarks in GitHub Actions for the consistent environment, and the instructions could accord to: https://spark.apache.org/developer-tools.html#github-workflow-benchmarks.
   -->
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133521772


##########
mllib/src/main/scala/org/apache/spark/ml/param/params.scala:
##########
@@ -44,8 +45,14 @@ import org.apache.spark.ml.util.Identifiable
  *                See [[ParamValidators]] for factory methods for common validation functions.
  * @tparam T param value type
  */
-class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
-  extends Serializable {
+class Param[T: ClassTag](

Review Comment:
   Sounds good.



[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1127788041


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)
+    }
+    paramsProto.getDefaultParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance._setDefault(paramDef -> paramValue)
+    }
+  }
+
+  def parseParamValue(paramType: Class[_], paramValueProto: proto.Expression.Literal): Any = {

Review Comment:
   maybe we don't need the `paramType`? The Literal already carries the type.
   
   I guess we can use `LiteralValueProtoConverter.{toCatalystValue, toArrayData}` here
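   
   A rough sketch of that alternative, dispatching on the type the Literal itself carries (the two converter helpers are the ones named above; their exact signatures are assumed here):
   
   ```scala
   import org.apache.spark.connect.proto
   import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
   
   // Dispatch on the literal's own type case instead of the target class.
   def parseParamValue(paramValueProto: proto.Expression.Literal): Any =
     paramValueProto.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
         LiteralValueProtoConverter.toArrayData(paramValueProto)
       case _ =>
         LiteralValueProtoConverter.toCatalystValue(paramValueProto)
     }
   ```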



##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -81,13 +82,50 @@ message Relation {
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
 
+    // ML relation
+    MlRelation ml_relation = 300;
+
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // relations they can add them here. During the planning the correct resolution is done.
     google.protobuf.Any extension = 998;
     Unknown unknown = 999;
   }
 }
 
+message MlRelation {
+  oneof ml_relation_type {
+    ModelTransform model_transform = 1;
+    FeatureTransform feature_transform = 2;
+    ModelAttr model_attr = 3;
+    ModelSummaryAttr model_summary_attr = 4;
+  }
+  message ModelTransform {
+    Relation input = 1;
+    int64 model_ref_id = 2;
+    Params params = 3;
+  }
+  message FeatureTransform {

Review Comment:
   I happened to find a special case: `Bucketizer`.
   
   It is the model of the estimator `QuantileDiscretizer`, yet it can also be treated as a normal transformer whose coefficients are set via `setSplits`.
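   
   For readers following along, a minimal sketch of the two usages (assuming an input DataFrame `df` with an `hour` column):
   
   ```scala
   import org.apache.spark.ml.feature.{Bucketizer, QuantileDiscretizer}
   
   // As a model: QuantileDiscretizer.fit learns the splits from the data.
   val fitted: Bucketizer = new QuantileDiscretizer()
     .setInputCol("hour").setOutputCol("bucket").setNumBuckets(3)
     .fit(df)
   
   // As a plain transformer: the user supplies the splits directly.
   val manual: Bucketizer = new Bucketizer()
     .setInputCol("hour").setOutputCol("bucket")
     .setSplits(Array(Double.NegativeInfinity, 8.0, 16.0, Double.PositiveInfinity))
   ```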



##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -81,13 +82,50 @@ message Relation {
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
 
+    // ML relation
+    MlRelation ml_relation = 300;
+
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // relations they can add them here. During the planning the correct resolution is done.
     google.protobuf.Any extension = 998;
     Unknown unknown = 999;
   }
 }
 
+message MlRelation {
+  oneof ml_relation_type {
+    ModelTransform model_transform = 1;
+    FeatureTransform feature_transform = 2;
+    ModelAttr model_attr = 3;
+    ModelSummaryAttr model_summary_attr = 4;
+  }
+  message ModelTransform {
+    Relation input = 1;
+    int64 model_ref_id = 2;
+    Params params = 3;
+  }
+  message FeatureTransform {

Review Comment:
   For a trained `Bucketizer` we must copy its params back from the server, since its param `splits` contains the trained coefficients.
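   
   That is (a sketch, with `df` an assumed training DataFrame):
   
   ```scala
   import org.apache.spark.ml.feature.{Bucketizer, QuantileDiscretizer}
   
   val fitted: Bucketizer = new QuantileDiscretizer()
     .setInputCol("hour").setOutputCol("bucket").setNumBuckets(3)
     .fit(df)
   
   // The learned boundaries live in an ordinary param, not a dedicated
   // coefficients field, so a client-side copy needs the server's param values.
   val learnedSplits: Array[Double] = fitted.getSplits
   ```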



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.TrainingSummary
+import org.apache.spark.sql.DataFrame
+
+
+object AlgorithmRegistry {
+
+  def get(name: String): Algorithm = {
+    name match {
+      case "LogisticRegression" => new LogisticRegressionAlgorithm
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+}
+
+
+abstract class Algorithm {
+
+  def initiateEstimator(uid: String): Estimator[_]
+
+  def getModelAttr(model: Model[_], name: String): Either[proto.MlCommandResponse, DataFrame]

Review Comment:
   Finally, I think I understand it.
   
   `ModelAttr` may return either a `MlCommandResponse` or a `DataFrame`.
   
   `ModelAttr` is used in `MlCommand` to get a scalar result, or in `MlRelation` to get a logical plan.
   
   Maybe we can split `getModelAttr` (and `getModelSummaryAttr`) into two methods, `getModelScalarAttr` and `getModelDataFrameAttr`, to make it a bit easier to understand?
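   
   A sketch of the shape that suggestion implies (the method names are the reviewer's; the signatures are assumptions modeled on the existing `getModelAttr`):
   
   ```scala
   import org.apache.spark.connect.proto
   import org.apache.spark.ml.Model
   import org.apache.spark.sql.DataFrame
   
   abstract class Algorithm {
     // Scalar-valued attributes, served inline through MlCommand.
     def getModelScalarAttr(model: Model[_], name: String): proto.MlCommandResponse
   
     // DataFrame-valued attributes, served lazily through MlRelation.
     def getModelDataFrameAttr(model: Model[_], name: String): DataFrame
   }
   ```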



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala:
##########
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.linalg.{Matrix, Vector}
+
+object Serializer {

Review Comment:
   I guess we can enhance `LiteralValueProtoConverter.toConnectProtoValue` and use it here.
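   
   For dense vectors, one hedged sketch of such an enhancement (this assumes `toLiteralProto` accepts an `Array[Double]`; the sparse case and a round-trip type marker would still need handling):
   
   ```scala
   import org.apache.spark.connect.proto
   import org.apache.spark.ml.linalg.Vector
   import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
   
   // Encode a dense ML vector as an array-of-double literal.
   def vectorToLiteral(vec: Vector): proto.Expression.Literal =
     LiteralValueProtoConverter.toLiteralProto(vec.toArray)
   ```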



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1127841115


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.TrainingSummary
+import org.apache.spark.sql.DataFrame
+
+
+object AlgorithmRegistry {
+
+  def get(name: String): Algorithm = {
+    name match {
+      case "LogisticRegression" => new LogisticRegressionAlgorithm
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+}
+
+
+abstract class Algorithm {
+
+  def initiateEstimator(uid: String): Estimator[_]
+
+  def getModelAttr(model: Model[_], name: String): Either[proto.MlCommandResponse, DataFrame]

Review Comment:
   > maybe here we can split getModelAttr (/ getModelSummaryAttr ) into two methods getModelScalarAttr and getModelDataFrameAttr to make it a bit easier to understand?
   
   This is what I did in the initial code, but I then changed it to the current design, because:
   
   in the ML code, the logic for producing an `MlCommandResponse` and for producing a `DataFrame` is the same. If we split them into two methods, then for the model summary, which has a complex class hierarchy, every superclass would need to define two near-identical methods. That bloats the code and makes it hard to maintain.
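   
   As a rough illustration of the single-method design, the caller branches on the `Either` exactly once (types are those of the PR's `Algorithm` trait):
   
   ```scala
   import org.apache.spark.ml.Model
   
   // One code path serves both result kinds.
   def handleAttr(algo: Algorithm, model: Model[_], name: String): Any =
     algo.getModelAttr(model, name) match {
       case Left(response) => response // scalar result, answered inline via MlCommand
       case Right(df)      => df       // DataFrame result, becomes an MlRelation plan
     }
   ```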



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129073881


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message Evaluator {
+  string name = 1;
+  Params params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    ModelAttr model_attr = 2;
+    ModelSummaryAttr model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+  }
+
+  message Fit {
+    Stage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    Evaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   For estimators, we need to define a new `LoadEstimator` command.
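   
   A minimal sketch of what that could look like, mirroring `LoadModel` above (the field layout is an assumption; a later revision of this PR folds both into a generic `LoadStage`):
   
   ```protobuf
   message LoadEstimator {
     string name = 1;  // algorithm name, e.g. "LogisticRegression"
     string path = 2;  // path the estimator was saved to
   }
   ```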



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129180218


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)
+    }
+    paramsProto.getDefaultParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance._setDefault(paramDef -> paramValue)
+    }
+  }
+
+  def parseParamValue(paramType: Class[_], paramValueProto: proto.Expression.Literal): Any = {
+    if (paramType == classOf[Int]) {
+      assert (paramValueProto.hasInteger || paramValueProto.hasLong)
+      if (paramValueProto.hasInteger) {
+        paramValueProto.getInteger
+      } else {
+        paramValueProto.getLong
+      }
+    } else if (paramType == classOf[Long]) {
+      assert(paramValueProto.hasLong)
+      paramValueProto.getLong
+    } else if (paramType == classOf[Float]) {
+      assert (paramValueProto.hasFloat || paramValueProto.hasDouble)
+      if (paramValueProto.hasFloat) {
+        paramValueProto.getFloat
+      } else {
+        paramValueProto.getDouble

Review Comment:
   Good catch!
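   
   (Presumably the fix narrows the wire value to the param's type instead of returning the widened one, e.g.:)
   
   ```scala
   if (paramValueProto.hasFloat) {
     paramValueProto.getFloat
   } else {
     // The literal arrived as a Double; narrow it so a Float param gets a Float.
     paramValueProto.getDouble.toFloat
   }
   ```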



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133520106


##########
mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java:
##########
@@ -108,9 +108,6 @@ private void init() {
     myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
     myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param",
       ParamValidators.inRange(0.0, 1.0));
-    List<String> validStrings = Arrays.asList("a", "b");

Review Comment:
   With the `T: ClassTag` context bound, Java code can no longer instantiate classes like `Param[XXX]`, because the Java compiler cannot synthesize the implicit `ClassTag` argument.
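   
   A simplified standalone sketch of the desugaring and its effect on Java callers (names abbreviated from `params.scala`):
   
   ```scala
   import scala.reflect.ClassTag
   
   // `class Param[T: ClassTag](...)` desugars to an extra implicit parameter list:
   class Param[T](val parent: String, val name: String, val doc: String,
       val isValid: T => Boolean)(implicit val valueTag: ClassTag[T])
   
   // Scala call sites compile unchanged; the compiler supplies the ClassTag:
   val p = new Param[Double]("lr", "threshold", "threshold doc", _ => true)
   
   // A Java caller would have to pass the tag explicitly, roughly:
   //   new Param<Double>("lr", "threshold", "threshold doc", v -> true,
   //       scala.reflect.ClassTag$.MODULE$.apply(Double.class));
   ```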



[GitHub] [spark] harupy commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "harupy (via GitHub)" <gi...@apache.org>.
harupy commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136540170


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala:
##########
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLHandler {
+
+  def handleMlCommand(
+      sessionHolder: SessionHolder,
+      mlCommand: proto.MlCommand): proto.MlCommandResponse = {
+    mlCommand.getMlCommandTypeCase match {
+      case proto.MlCommand.MlCommandTypeCase.FIT =>
+        val fitCommandProto = mlCommand.getFit
+        val estimatorProto = fitCommandProto.getEstimator
+        assert(estimatorProto.getType == proto.MlStage.StageType.ESTIMATOR)
+
+        val algoName = fitCommandProto.getEstimator.getName
+        val algo = AlgorithmRegistry.get(algoName)
+
+        val estimator = algo.initiateEstimator(estimatorProto.getUid)
+        MLUtils.setInstanceParams(estimator, estimatorProto.getParams)
+        val dataset = MLUtils.parseRelationProto(fitCommandProto.getDataset, sessionHolder)
+        val model = estimator.fit(dataset).asInstanceOf[Model[_]]
+        val refId = sessionHolder.mlCache.modelCache.register(model, algo)
+
+        proto.MlCommandResponse
+          .newBuilder()
+          .setModelInfo(
+            proto.MlCommandResponse.ModelInfo.newBuilder
+              .setModelRefId(refId)
+              .setModelUid(model.uid))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_ATTR =>
+        val getModelAttrProto = mlCommand.getFetchModelAttr
+        val modelEntry = sessionHolder.mlCache.modelCache.get(getModelAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        algo.getModelAttr(model, getModelAttrProto.getName).left.get
+
+      case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_SUMMARY_ATTR =>
+        val getModelSummaryAttrProto = mlCommand.getFetchModelSummaryAttr
+        val modelEntry =
+          sessionHolder.mlCache.modelCache.get(getModelSummaryAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        // Create a copied model to avoid concurrently modify model params.
+        val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
+        MLUtils.setInstanceParams(copiedModel, getModelSummaryAttrProto.getParams)
+
+        val datasetOpt = if (getModelSummaryAttrProto.hasEvaluationDataset) {
+          val evalDF = MLUtils.parseRelationProto(
+            getModelSummaryAttrProto.getEvaluationDataset,
+            sessionHolder)
+          Some(evalDF)
+        } else None
+
+        algo
+          .getModelSummaryAttr(copiedModel, getModelSummaryAttrProto.getName, datasetOpt)
+          .left
+          .get
+
+      case proto.MlCommand.MlCommandTypeCase.LOAD_MODEL =>
+        val loadModelProto = mlCommand.getLoadModel
+        val algo = AlgorithmRegistry.get(loadModelProto.getName)
+        val model = algo.loadModel(loadModelProto.getPath)
+        val refId = sessionHolder.mlCache.modelCache.register(model, algo)
+
+        proto.MlCommandResponse
+          .newBuilder()
+          .setModelInfo(
+            proto.MlCommandResponse.ModelInfo.newBuilder
+              .setModelRefId(refId)
+              .setModelUid(model.uid)
+              .setParams(MLUtils.convertInstanceParamsToProto(model)))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.SAVE_MODEL =>
+        val saveModelProto = mlCommand.getSaveModel
+        val modelEntry = sessionHolder.mlCache.modelCache.get(saveModelProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        algo.saveModel(
+          model,
+          saveModelProto.getPath,
+          saveModelProto.getOverwrite,
+          saveModelProto.getOptionsMap.asScala.toMap)
+        proto.MlCommandResponse
+          .newBuilder()
+          .setLiteral(LiteralValueProtoConverter.toLiteralProto(null))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.LOAD_STAGE =>
+        val loadStageProto = mlCommand.getLoadStage
+        val name = loadStageProto.getName
+        loadStageProto.getType match {
+          case proto.MlStage.StageType.ESTIMATOR =>
+            val algo = AlgorithmRegistry.get(name)
+            val estimator = algo.loadEstimator(loadStageProto.getPath)
+
+            proto.MlCommandResponse
+              .newBuilder()
+              .setStage(
+                proto.MlStage
+                  .newBuilder()
+                  .setName(name)
+                  .setType(proto.MlStage.StageType.ESTIMATOR)
+                  .setUid(estimator.uid)
+                  .setParams(MLUtils.convertInstanceParamsToProto(estimator)))
+              .build()
+          case _ =>
+            throw new UnsupportedOperationException()
+        }
+
+      case proto.MlCommand.MlCommandTypeCase.SAVE_STAGE =>
+        val saveStageProto = mlCommand.getSaveStage
+        val stageProto = saveStageProto.getStage
+
+        stageProto.getType match {
+          case proto.MlStage.StageType.ESTIMATOR =>
+            val name = stageProto.getName
+            val algo = AlgorithmRegistry.get(name)
+            val estimator = algo.initiateEstimator(stageProto.getUid)
+            MLUtils.setInstanceParams(estimator, stageProto.getParams)
+            algo.saveEstimator(
+              estimator,
+              saveStageProto.getPath,
+              saveStageProto.getOverwrite,
+              saveStageProto.getOptionsMap.asScala.toMap)
+            proto.MlCommandResponse
+              .newBuilder()
+              .setLiteral(LiteralValueProtoConverter.toLiteralProto(null))
+              .build()
+
+          case _ =>
+            throw new UnsupportedOperationException()
+        }
+
+      case proto.MlCommand.MlCommandTypeCase.COPY_MODEL =>
+        val copyModelProto = mlCommand.getCopyModel
+        val modelEntry = sessionHolder.mlCache.modelCache.get(copyModelProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
+        val refId = sessionHolder.mlCache.modelCache.register(copiedModel, algo)
+        proto.MlCommandResponse
+          .newBuilder()
+          .setLiteral(proto.Expression.Literal.newBuilder().setLong(refId))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.DELETE_MODEL =>
+        val modelRefId = mlCommand.getDeleteModel.getModelRefId
+        sessionHolder.mlCache.modelCache.remove(modelRefId)
+        proto.MlCommandResponse
+          .newBuilder()
+          .setLiteral(LiteralValueProtoConverter.toLiteralProto(null))
+          .build()
+
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+  def transformMLRelation(
+      mlRelationProto: proto.MlRelation,
+      sessionHolder: SessionHolder): DataFrame = {
+    mlRelationProto.getMlRelationTypeCase match {
+      case proto.MlRelation.MlRelationTypeCase.MODEL_TRANSFORM =>
+        val modelTransformRelationProto = mlRelationProto.getModelTransform
+        val (model, _) =
+          sessionHolder.mlCache.modelCache.get(modelTransformRelationProto.getModelRefId)
+        // Create a copied model to avoid concurrently modify model params.
+        val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
+        MLUtils.setInstanceParams(copiedModel, modelTransformRelationProto.getParams)
+        val inputDF =
+          MLUtils.parseRelationProto(modelTransformRelationProto.getInput, sessionHolder)
+        copiedModel.transform(inputDF)
+
+      case proto.MlRelation.MlRelationTypeCase.MODEL_ATTR =>
+        val modelAttrProto = mlRelationProto.getModelAttr
+        val modelEntry = sessionHolder.mlCache.modelCache.get(modelAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2

Review Comment:
   ```suggestion
           val (model, algo) = sessionHolder.mlCache.modelCache.get(modelAttrProto.getModelRefId)
   ```
   
   Can we use unpacking?



[GitHub] [spark] harupy commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "harupy (via GitHub)" <gi...@apache.org>.
harupy commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136535013


##########
python/pyspark/sql/connect/ml/base.py:
##########
@@ -0,0 +1,327 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod
+
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.ml import Estimator, Model, Predictor, PredictionModel
+from pyspark.ml.wrapper import _PredictorParams
+from pyspark.ml.util import MLWritable, MLWriter, MLReadable, MLReader
+import pyspark.sql.connect.proto as pb2
+import pyspark.sql.connect.proto.ml_pb2 as ml_pb2
+import pyspark.sql.connect.proto.ml_common_pb2 as ml_common_pb2
+from pyspark.sql.connect.ml.serializer import deserialize, serialize_ml_params
+from pyspark.sql.connect import session as pyspark_session
+from pyspark.sql.connect.plan import LogicalPlan
+
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.util import HasTrainingSummary as PySparkHasTrainingSummary
+
+
+@inherit_doc
+class ClientEstimator(Estimator, metaclass=ABCMeta):
+
+    @classmethod
+    def _algo_name(cls):
+        raise NotImplementedError()
+
+    @classmethod
+    def _model_class(cls):
+        raise NotImplementedError()
+
+    def _fit(self, dataset: DataFrame) -> Model:
+        client = dataset.sparkSession.client
+        dataset_relation = dataset._plan.plan(client)
+        estimator_proto = ml_common_pb2.MlStage(
+            name=self._algo_name(),
+            params=serialize_ml_params(self, client),
+            uid=self.uid,
+            type=ml_common_pb2.MlStage.ESTIMATOR,
+        )
+        fit_command_proto = ml_pb2.MlCommand.Fit(
+            estimator=estimator_proto,
+            dataset=dataset_relation,
+        )
+        req = client._execute_plan_request_with_metadata()
+        req.plan.ml_command.fit.CopyFrom(fit_command_proto)
+
+        resp = client._execute_ml(req)
+        return deserialize(resp, client, clazz=self._model_class())
+
+
+@inherit_doc
+class ClientPredictor(Predictor, ClientEstimator, _PredictorParams, metaclass=ABCMeta):
+    pass
+
+
+@inherit_doc
+class ClientModel(Model, metaclass=ABCMeta):
+
+    ref_id: str = None
+
+    def __del__(self):
+        client = pyspark_session._active_spark_session.client
+        del_model_proto = ml_pb2.MlCommand.DeleteModel(
+            model_ref_id=self.ref_id,
+        )
+        req = client._execute_plan_request_with_metadata()
+        req.plan.ml_command.delete_model.CopyFrom(del_model_proto)
+        client._execute_ml(req)
+
+    @classmethod

Review Comment:
   Can we use `abstractmethod`?
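   
   A sketch of the suggestion, forcing subclasses to override (using `_algo_name` from the snippet above as the example):
   
   ```python
   from abc import ABCMeta, abstractmethod
   
   
   class ClientEstimator(metaclass=ABCMeta):
       @classmethod
       @abstractmethod
       def _algo_name(cls) -> str:
           """Subclasses must return the server-side algorithm name."""
           ...
   ```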



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135529105


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlEvaluator {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    FetchModelAttr fetch_model_attr = 2;
+    FetchModelSummaryAttr fetch_model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+    SaveStage save_stage = 7;
+    LoadStage load_stage = 8;
+    SaveEvaluator save_evaluator = 9;
+    LoadEvaluator load_evaluator = 10;
+    CopyModel copy_model = 11;
+    DeleteModel delete_model = 12;
+  }
+
+  message Fit {
+    MlStage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    MlEvaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   If we want to support 3rd-party algorithms without a registry, then we inevitably have to use Java reflection to invoke methods (e.g. we would need to invoke `XXXModel.load` to load a model), which is unsafe.
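   
   A sketch of the reflection path being described, to make the trade-off concrete (unchecked: any class name a client sends would be resolved on the server):
   
   ```scala
   // Invoke `load(path)` on a model's Scala companion object via JVM reflection.
   def loadModelByReflection(className: String, path: String): Any = {
     val companion = Class.forName(className + "$")           // companion object class
     val module = companion.getField("MODULE$").get(null)     // its singleton instance
     val load = companion.getMethod("load", classOf[String])  // MLReadable.load
     load.invoke(module, path)
   }
   ```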



[GitHub] [spark] github-actions[bot] commented on pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "github-actions[bot] (via GitHub)" <gi...@apache.org>.
github-actions[bot] commented on PR #40297:
URL: https://github.com/apache/spark/pull/40297#issuecomment-1615303795

   We're closing this PR because it hasn't been updated in a while. This isn't a judgement on the merit of the PR in any way. It's just a way of keeping the PR queue manageable.
   If you'd like to revive this PR, please reopen it and ask a committer to remove the Stale tag!


[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1138619277


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala:
##########
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.TrainingSummary
+import org.apache.spark.ml.util.MLWriter
+import org.apache.spark.sql.DataFrame
+
+object AlgorithmRegistry {
+
+  def get(name: String): Algorithm = {
+    name match {
+      case "LogisticRegression" => new LogisticRegressionAlgorithm
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+}
+
+abstract class Algorithm {
+
+  def initiateEstimator(uid: String): Estimator[_]
+
+  def getModelAttr(model: Model[_], name: String): Either[proto.MlCommandResponse, DataFrame]
+
+  def getModelSummaryAttr(
+      model: Model[_],
+      name: String,
+      datasetOpt: Option[DataFrame]): Either[proto.MlCommandResponse, DataFrame]
+
+  def loadModel(path: String): Model[_]
+
+  def loadEstimator(path: String): Estimator[_]
+
+  protected def getEstimatorWriter(estimator: Estimator[_]): MLWriter
+
+  protected def getModelWriter(model: Model[_]): MLWriter
+
+  def _save(
+      writer: MLWriter,
+      path: String,
+      overwrite: Boolean,
+      options: Map[String, String]): Unit = {
+    if (overwrite) {
+      writer.overwrite()
+    }
+    options.map { case (k, v) => writer.option(k, v) }
+    writer.save(path)
+  }
+
+  def saveModel(
+      model: Model[_],
+      path: String,
+      overwrite: Boolean,
+      options: Map[String, String]): Unit = {
+    _save(getModelWriter(model), path, overwrite, options)
+  }
+
+  def saveEstimator(
+      estimator: Estimator[_],
+      path: String,
+      overwrite: Boolean,
+      options: Map[String, String]): Unit = {
+    _save(getEstimatorWriter(estimator), path, overwrite, options)
+  }
+}
+
+class LogisticRegressionAlgorithm extends Algorithm {

Review Comment:
   If we can use Java reflection to invoke methods, we don't need the registry class; we only need some configuration data for the registry.
   
   If we plan to make Spark Connect mode mandatory for DBR starting with Spark 4, then we had better use Java reflection invocation; otherwise it is hard to support the huge number of 3rd-party estimators.
   
   CC @grundprinzip  
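   
   A sketch of what "configuration data instead of a registry class" could mean: an allowlist consulted before any reflective call (the names here are illustrative):
   
   ```scala
   // Map short algorithm names to fully-qualified classes; reflection is only
   // attempted for entries present in the configuration.
   val allowedEstimators: Map[String, String] = Map(
     "LogisticRegression" -> "org.apache.spark.ml.classification.LogisticRegression")
   
   def resolveEstimatorClass(name: String): Class[_] =
     allowedEstimators.get(name) match {
       case Some(fqcn) => Class.forName(fqcn)
       case None => throw new IllegalArgumentException(s"Unknown estimator: $name")
     }
   ```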



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1138372032


##########
mllib/common/src/main/scala/org/apache/spark/ml/param/params.scala:
##########
@@ -793,6 +800,10 @@ trait Params extends Identifiable with Serializable {
     this
   }
 
+  private[spark] def _setDefault(paramPairs: ParamPair[_]*): this.type = {
+    setDefault(paramPairs: _*)
+  }

Review Comment:
   > I think we can simply change setDefault to protected[spark] ?
   
   This would be a breaking change.
   
   Some 3rd-party estimators might override this method; if they are not under the "org.apache" package, compilation will fail.



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129072923


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)
+    }
+    paramsProto.getDefaultParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance._setDefault(paramDef -> paramValue)
+    }
+  }
+
+  def parseParamValue(paramType: Class[_], paramValueProto: proto.Expression.Literal): Any = {

Review Comment:
   > for int/long case, we can specify the DataType in Python Client
   
   This does not help, because pyspark estimator `Param`s do not record a type attribute, so we cannot distinguish int/long or float/double for Python-side param values.
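   
   A small illustration (pyspark's `Param` takes an optional `typeConverter`, but that normalizes values rather than recording a 32- vs 64-bit type):
   
   ```python
   from pyspark.ml.param import Param, Params
   
   # Nothing here records whether the server-side param is Int or Long:
   maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0)")
   # Only the server, which knows the param's ClassTag, can pick the right width.
   ```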



[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129159284


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)

Review Comment:
   Yes, you are right. A 3rd-party lib may directly use `Param[Int]`.



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133534997


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlEvaluator {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    FetchModelAttr fetch_model_attr = 2;
+    FetchModelSummaryAttr fetch_model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+    SaveStage save_stage = 7;
+    LoadStage load_stage = 8;
+    SaveEvaluator save_evaluator = 9;
+    LoadEvaluator load_evaluator = 10;
+    CopyModel copy_model = 11;
+  }
+
+  message Fit {
+    MlStage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    MlEvaluator evaluator = 1;
+  }
+
+  message LoadModel {
+    string name = 1;
+    string path = 2;
+  }
+
+  message SaveModel {
+    int64 model_ref_id = 1;
+    string path = 2; // saving path
+    bool overwrite = 3;
+    map<string, string> options = 4; // saving options
+  }
+
+  message LoadStage {
+    string name = 1;
+    string path = 2;
+    MlStage.StageType type = 3;
+  }
+
+  message SaveStage {
+    MlStage stage = 1;
+    string path = 2; // saving path
+    bool overwrite = 3;
+    map<string, string> options = 4; // saving options
+  }
+
+  message LoadEvaluator {
+    string name = 1;
+    string path = 2;
+  }
+
+  message SaveEvaluator {
+    MlEvaluator evaluator = 1;
+    string path = 2; // saving path
+    bool overwrite = 3;
+    map<string, string> options = 4; // saving options
+  }
+
+  message FetchModelAttr {
+    int64 model_ref_id = 1;
+    string name = 2;
+  }
+
+  message FetchModelSummaryAttr {
+    int64 model_ref_id = 1;
+    string name = 2;
+    MlParams params = 3;
+
+    // Evaluation dataset used to compute
+    // the summary attribute.
+    // If not set, get attributes from
+    // model.summary (i.e. the summary on the training dataset).
+    optional Relation evaluation_dataset = 4;
+  }
+
+  message CopyModel {
+    int64 model_ref_id = 1;
+  }
+}
+
+
+message MlCommandResponse {
+  oneof ml_command_response_type {
+    Expression.Literal literal = 1;
+    ModelInfo model_info = 2;
+    Vector vector = 3;

Review Comment:
   I think we should only abstract them as Vector or Matrix at the protobuf message level.
   Different models have different wrappers for model coefficients.
   
   But we can support large vectors / matrices by converting them to a Spark DataFrame. We can do that in a follow-up PR.
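
   As a rough sketch of that abstraction on the server side (the `Dense`/`Sparse` sub-messages and their field names below are assumptions for discussion, not the final schema):

       import scala.collection.JavaConverters._

       import org.apache.spark.connect.proto
       import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}

       // Convert an ml.linalg.Vector into the assumed proto representation.
       def serializeVector(v: Vector): proto.Vector = v match {
         case DenseVector(values) =>
           proto.Vector.newBuilder()
             .setDense(proto.Vector.Dense.newBuilder()
               .addAllValue(values.map(Double.box).toSeq.asJava))
             .build()
         case SparseVector(size, indices, values) =>
           proto.Vector.newBuilder()
             .setSparse(proto.Vector.Sparse.newBuilder()
               .setSize(size)
               .addAllIndex(indices.map(Int.box).toSeq.asJava)
               .addAllValue(values.map(Double.box).toSeq.asJava))
             .build()
       }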





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135521876


##########
connector/connect/common/src/main/protobuf/spark/connect/ml_common.proto:
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlParams {
+  map<string, Expression.Literal> params = 1;
+  map<string, Expression.Literal> default_params = 2;
+}
+
+message MlStage {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+  StageType type = 4;
+  enum StageType {

Review Comment:
   Yes.





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1127834899


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala:
##########
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.linalg.{Matrix, Vector}
+
+object Serializer {

Review Comment:
   Good idea.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1128994906


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.TrainingSummary
+import org.apache.spark.sql.DataFrame
+
+
+object AlgorithmRegistry {
+
+  def get(name: String): Algorithm = {
+    name match {
+      case "LogisticRegression" => new LogisticRegressionAlgorithm
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+}
+
+
+abstract class Algorithm {
+
+  def initiateEstimator(uid: String): Estimator[_]
+
+  def getModelAttr(model: Model[_], name: String): Either[proto.MlCommandResponse, DataFrame]

Review Comment:
   ok. I am neutral on it.





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129073675


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message Evaluator {
+  string name = 1;
+  Params params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    ModelAttr model_attr = 2;
+    ModelSummaryAttr model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+  }
+
+  message Fit {
+    Stage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    Evaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   ~~I will rename it to `LoadStage` to support estimators too.~~
   For transformer saving we need to define a new message `SaveTransformer`, because transformers don't have a `model_ref_id`.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136925391


##########
python/pyspark/sql/connect/ml/utils.py:
##########
@@ -0,0 +1,55 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.sql.utils import is_remote
+import os
+
+
+def _get_remote_ml_class(cls):
+    remote_module = "pyspark.sql.connect.ml." + cls.__module__[len("pyspark.ml."):]
+    cls_name = cls.__name__
+    m = __import__(remote_module, fromlist=[cls_name])
+    remote_cls = getattr(m, cls_name)
+    return remote_cls
+
+
+def try_remote_ml_class(x):

Review Comment:
   I feel we can also simplify the `pyspark.sql` side by applying this annotation to only a few key classes.
   cc @HyukjinKwon 





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136928670


##########
python/pyspark/sql/connect/proto/catalog_pb2.pyi:
##########
@@ -49,6 +49,7 @@ else:
 
 DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
 
+@typing_extensions.final

Review Comment:
   It seems this `@typing_extensions.final` is gone; just curious, what is the cause?





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1127288752


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.TrainingSummary
+import org.apache.spark.sql.DataFrame
+
+
+object AlgorithmRegistry {
+
+  def get(name: String): Algorithm = {
+    name match {
+      case "LogisticRegression" => new LogisticRegressionAlgorithm
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+}
+
+
+abstract class Algorithm {
+
+  def initiateEstimator(uid: String): Estimator[_]
+
+  def getModelAttr(model: Model[_], name: String): Either[proto.MlCommandResponse, DataFrame]

Review Comment:
   The `DataFrame` case cannot be put in `MlCommandResponse`: in that case we generate the DataFrame plan directly on the client side, so we don't send a request to the server.
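
   Roughly, on the server side the relation path could then look like this (the cache lookup and helper names are illustrative, not the exact PR code):

       // Resolve an MlRelation.ModelAttr against the session's model cache and
       // return the attribute's DataFrame as the relation's result.
       def transformModelAttr(
           attr: proto.MlRelation.ModelAttr,
           sessionHolder: SessionHolder): DataFrame = {
         val model = sessionHolder.mlCache.get(attr.getModelRefId)
         // Assumed mapping from the model class to its registered algorithm.
         val algo = AlgorithmRegistry.get(
           model.getClass.getSimpleName.stripSuffix("Model"))
         algo.getModelAttr(model, attr.getName) match {
           case Right(df) => df
           case Left(_) =>
             throw new IllegalArgumentException(
               s"Attribute ${attr.getName} is not a DataFrame attribute")
         }
       }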





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129010791


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message Evaluator {
+  string name = 1;
+  Params params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    ModelAttr model_attr = 2;
+    ModelSummaryAttr model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+  }
+
+  message Fit {
+    Stage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    Evaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   Will `LoadModel` and `SaveModel` also support loading/saving estimators/transformers?





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129143557


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)

Review Comment:
   > instance.set(p, v.toLong)
   
   If `p` is of type `Param[Int]`, this will cause an error: the `set` call would succeed, but a type conversion error will occur when we later get the param value.
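
   A minimal sketch of that failure mode (values are illustrative):

       val p: Param[Int] = instance.getParam("maxIter").asInstanceOf[Param[Int]]

       // Due to erasure this compiles and stores a boxed java.lang.Long...
       instance.set(p.asInstanceOf[Param[Any]], 10L)

       // ...but reading the value back as Int throws a ClassCastException
       // (java.lang.Long cannot be cast to java.lang.Integer).
       val maxIter: Int = instance.getOrDefault(p)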





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135526116


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlEvaluator {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    FetchModelAttr fetch_model_attr = 2;
+    FetchModelSummaryAttr fetch_model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+    SaveStage save_stage = 7;
+    LoadStage load_stage = 8;
+    SaveEvaluator save_evaluator = 9;
+    LoadEvaluator load_evaluator = 10;
+    CopyModel copy_model = 11;
+    DeleteModel delete_model = 12;
+  }
+
+  message Fit {
+    MlStage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    MlEvaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   The current PR does not support third-party estimators.
   We would need to register the corresponding class for a third-party algorithm in the `AlgorithmRegistry` class.
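
   For example, enabling a third-party algorithm would mean extending the match in `AlgorithmRegistry.get` (a sketch; `MyAlgorithm` and its implementation class are hypothetical):

       def get(name: String): Algorithm = {
         name match {
           case "LogisticRegression" => new LogisticRegressionAlgorithm
           case "MyAlgorithm" => new MyAlgorithmImpl // third-party registration
           case _ =>
             throw new IllegalArgumentException(s"Unsupported algorithm: $name")
         }
       }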





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135522872


##########
connector/connect/common/src/main/protobuf/spark/connect/ml_common.proto:
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlParams {
+  map<string, Expression.Literal> params = 1;
+  map<string, Expression.Literal> default_params = 2;
+}
+
+message MlStage {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+  StageType type = 4;
+  enum StageType {

Review Comment:
   Alternatively, we could make the server side infer the stage type from the stage name, but having the client fill in the stage type keeps the code simpler.
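
   A sketch of the server-side alternative (the enum value names are assumptions, since the `StageType` body is not shown above):

       def inferStageType(name: String): MlStage.StageType = name match {
         case "LogisticRegression" => MlStage.StageType.ESTIMATOR
         case "VectorAssembler" => MlStage.StageType.TRANSFORMER
         case other =>
           throw new IllegalArgumentException(s"Unknown stage name: $other")
       }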





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135600960


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlEvaluator {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    FetchModelAttr fetch_model_attr = 2;
+    FetchModelSummaryAttr fetch_model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+    SaveStage save_stage = 7;
+    LoadStage load_stage = 8;
+    SaveEvaluator save_evaluator = 9;
+    LoadEvaluator load_evaluator = 10;
+    CopyModel copy_model = 11;
+    DeleteModel delete_model = 12;
+  }
+
+  message Fit {
+    MlStage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    MlEvaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   Btw, supporting third-party estimators is risky, because on a shared cluster we will [binpack the Spark workers across different customers](https://docs.google.com/document/d/1sJVjan44XagM48PEqdkg6KWctpPcz54Urf0i_-dGesA/edit?disco=AAAArl9hpF8).
   But a third-party estimator implementation might use RDD operations that we cannot isolate by container, so it is risky to let users run third-party estimators on a shared cluster.





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "HyukjinKwon (via GitHub)" <gi...@apache.org>.
HyukjinKwon commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135031137


##########
python/pyspark/sql/connect/proto/catalog_pb2.pyi:
##########
@@ -49,6 +49,7 @@ else:
 
 DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
 
+@typing_extensions.final

Review Comment:
   It's autogenerated anyway so I suspect it's fine.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1141916038


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala:
##########
@@ -210,7 +211,12 @@ class SparkConnectService(debug: Boolean)
  * @param userId
  * @param session
  */
-case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
+case class SessionHolder(
+  userId: String,
+  sessionId: String,
+  session: SparkSession,
+  mlCache: MLCache = MLCache()

Review Comment:
   https://github.com/apache/spark/pull/40485





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129079073


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)
+    }
+    paramsProto.getDefaultParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance._setDefault(paramDef -> paramValue)
+    }
+  }
+
+  def parseParamValue(paramType: Class[_], paramValueProto: proto.Expression.Literal): Any = {

Review Comment:
   I see, we don't have `IntParam` in PySpark.
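
   Right, which is why the server has to narrow by the param's runtime class, roughly like this (a sketch consistent with the `parseParamValue` signature above; the exact branches in the PR may differ):

       def parseParamValue(paramType: Class[_], lit: proto.Expression.Literal): Any = {
         // PySpark has no IntParam, so integer params arrive as long literals
         // and must be narrowed based on the param definition's value class.
         if (paramType == classOf[Int]) lit.getLong.toInt
         else if (paramType == classOf[Long]) lit.getLong
         else if (paramType == classOf[Double]) lit.getDouble
         else if (paramType == classOf[Boolean]) lit.getBoolean
         else if (paramType == classOf[String]) lit.getString
         else throw new IllegalArgumentException(s"Unsupported param type: $paramType")
       }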





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133669879


##########
mllib/src/main/scala/org/apache/spark/ml/param/params.scala:
##########
@@ -44,8 +45,14 @@ import org.apache.spark.ml.util.Identifiable
  *                See [[ParamValidators]] for factory methods for common validation functions.
  * @tparam T param value type
  */
-class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
-  extends Serializable {
+class Param[T: ClassTag](

Review Comment:
   then that is fine





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133521621


##########
python/pyspark/sql/connect/proto/catalog_pb2.pyi:
##########
@@ -49,6 +49,7 @@ else:
 
 DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
 
+@typing_extensions.final

Review Comment:
   What's this annotation for, and why do we need to remove it?





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135523794


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";

Review Comment:
   Sure!





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135523486


##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -82,13 +83,50 @@ message Relation {
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
 
+    // ML relation
+    MlRelation ml_relation = 300;
+
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // relations they can add them here. During the planning the correct resolution is done.
     google.protobuf.Any extension = 998;
     Unknown unknown = 999;
   }
 }
 
+message MlRelation {
+  oneof ml_relation_type {
+    ModelTransform model_transform = 1;
+    FeatureTransform feature_transform = 2;
+    ModelAttr model_attr = 3;
+    ModelSummaryAttr model_summary_attr = 4;
+  }
+  message ModelTransform {
+    Relation input = 1;
+    int64 model_ref_id = 2;

Review Comment:
   message ModelRef {
     int64 id = 1;
   }
   
   This sounds good.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1138313771


##########
mllib/common/src/main/scala/org/apache/spark/ml/param/params.scala:
##########
@@ -793,6 +800,10 @@ trait Params extends Identifiable with Serializable {
     this
   }
 
+  private[spark] def _setDefault(paramPairs: ParamPair[_]*): this.type = {
+    setDefault(paramPairs: _*)
+  }

Review Comment:
   I think we can simply change `setDefault` to `protected[spark]`?
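
   i.e., instead of adding the `_setDefault` forwarder, widen the existing declaration (a sketch; the body mirrors the current varargs `setDefault` in `params.scala`):

       protected[spark] final def setDefault(paramPairs: ParamPair[_]*): this.type = {
         paramPairs.foreach { p =>
           setDefault(p.param.asInstanceOf[Param[Any]], p.value)
         }
         this
       }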



##########
python/pyspark/ml/base.py:
##########
@@ -17,6 +17,7 @@
 
 from abc import ABCMeta, abstractmethod
 
+import os

Review Comment:
   Is this import needed?



##########
python/pyspark/sql/connect/session.py:
##########
@@ -463,7 +463,7 @@ def stop(self) -> None:
 
     @classmethod
     def getActiveSession(cls) -> Any:
-        raise NotImplementedError("getActiveSession() is not implemented.")

Review Comment:
   Do we need this change? I thought we could use the newly added `getOrCreate`.



##########
python/pyspark/ml/connect/classification.py:
##########
@@ -0,0 +1,510 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+    TYPE_CHECKING,
+)
+
+from pyspark.sql import DataFrame
+from pyspark.ml.classification import (
+    Classifier,
+    ProbabilisticClassifier,
+    _LogisticRegressionParams,
+    _LogisticRegressionCommon,
+    ClassificationModel,
+    ProbabilisticClassificationModel
+)
+from pyspark.ml.linalg import (Matrix, Vector)
+from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
+from pyspark.ml.connect.base import (
+    ClientEstimator,
+    ClientModel,
+    HasTrainingSummary,
+    ClientModelSummary,
+    ClientPredictor,
+    ClientPredictionModel,
+    ClientMLWritable,
+    ClientMLReadable,
+)
+from abc import ABCMeta, abstractmethod
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.classification import (
+    LogisticRegression as PySparkLogisticRegression,
+    LogisticRegressionModel as PySparkLogisticRegressionModel,
+    _ClassificationSummary as _PySparkClassificationSummary,
+    _TrainingSummary as _PySparkTrainingSummary,
+    _BinaryClassificationSummary as _PySparkBinaryClassificationSummary,
+    LogisticRegressionSummary as PySparkLogisticRegressionSummary,
+    LogisticRegressionTrainingSummary as PySparkLogisticRegressionTrainingSummary,
+    BinaryLogisticRegressionSummary as PySparkBinaryLogisticRegressionSummary,
+    BinaryLogisticRegressionTrainingSummary as PySparkBinaryLogisticRegressionTrainingSummary,
+)
+
+
+@inherit_doc
+class _ClientClassifier(Classifier, ClientPredictor, metaclass=ABCMeta):
+    pass
+
+
+@inherit_doc
+class _ClientProbabilisticClassifier(
+    ProbabilisticClassifier, _ClientClassifier, metaclass=ABCMeta
+):
+    pass
+
+
+@inherit_doc
+class _ClientClassificationModel(ClassificationModel, ClientPredictionModel):
+    @property  # type: ignore[misc]
+    def numClasses(self) -> int:
+        return self._get_model_attr("numClasses")
+
+    def predictRaw(self, value: Vector) -> Vector:
+        # TODO: support this.
+        raise NotImplementedError()
+
+
+@inherit_doc
+class _ClientProbabilisticClassificationModel(
+    ProbabilisticClassificationModel, _ClientClassificationModel
+):
+    def predictProbability(self, value: Vector) -> Vector:
+        # TODO: support this.
+        raise NotImplementedError()
+
+
+@inherit_doc
+class LogisticRegression(
+    _ClientProbabilisticClassifier,
+    _LogisticRegressionCommon,
+    ClientMLWritable,
+    ClientMLReadable,
+):
+    _input_kwargs: Dict[str, Any]
+
+    @overload
+    def __init__(
+            self,
+            *,
+            featuresCol: str = ...,
+            labelCol: str = ...,
+            predictionCol: str = ...,
+            maxIter: int = ...,
+            regParam: float = ...,
+            elasticNetParam: float = ...,
+            tol: float = ...,
+            fitIntercept: bool = ...,
+            threshold: float = ...,
+            probabilityCol: str = ...,
+            rawPredictionCol: str = ...,
+            standardization: bool = ...,
+            weightCol: Optional[str] = ...,
+            aggregationDepth: int = ...,
+            family: str = ...,
+            lowerBoundsOnCoefficients: Optional[Matrix] = ...,
+            upperBoundsOnCoefficients: Optional[Matrix] = ...,
+            lowerBoundsOnIntercepts: Optional[Vector] = ...,
+            upperBoundsOnIntercepts: Optional[Vector] = ...,
+            maxBlockSizeInMB: float = ...,
+    ):
+        ...
+
+    @overload
+    def __init__(
+            self,
+            *,
+            featuresCol: str = ...,
+            labelCol: str = ...,
+            predictionCol: str = ...,
+            maxIter: int = ...,
+            regParam: float = ...,
+            elasticNetParam: float = ...,
+            tol: float = ...,
+            fitIntercept: bool = ...,
+            thresholds: Optional[List[float]] = ...,
+            probabilityCol: str = ...,
+            rawPredictionCol: str = ...,
+            standardization: bool = ...,
+            weightCol: Optional[str] = ...,
+            aggregationDepth: int = ...,
+            family: str = ...,
+            lowerBoundsOnCoefficients: Optional[Matrix] = ...,
+            upperBoundsOnCoefficients: Optional[Matrix] = ...,
+            lowerBoundsOnIntercepts: Optional[Vector] = ...,
+            upperBoundsOnIntercepts: Optional[Vector] = ...,
+            maxBlockSizeInMB: float = ...,
+    ):
+        ...
+
+    @keyword_only
+    def __init__(
+            self,
+            *,
+            featuresCol: str = "features",
+            labelCol: str = "label",
+            predictionCol: str = "prediction",
+            maxIter: int = 100,
+            regParam: float = 0.0,
+            elasticNetParam: float = 0.0,
+            tol: float = 1e-6,
+            fitIntercept: bool = True,
+            threshold: float = 0.5,
+            thresholds: Optional[List[float]] = None,
+            probabilityCol: str = "probability",
+            rawPredictionCol: str = "rawPrediction",
+            standardization: bool = True,
+            weightCol: Optional[str] = None,
+            aggregationDepth: int = 2,
+            family: str = "auto",
+            lowerBoundsOnCoefficients: Optional[Matrix] = None,
+            upperBoundsOnCoefficients: Optional[Matrix] = None,
+            lowerBoundsOnIntercepts: Optional[Vector] = None,
+            upperBoundsOnIntercepts: Optional[Vector] = None,
+            maxBlockSizeInMB: float = 0.0,
+    ):
+        """
+        __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
+                 threshold=0.5, thresholds=None, probabilityCol="probability", \
+                 rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
+                 aggregationDepth=2, family="auto", \
+                 lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \
+                 lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \
+                 maxBlockSizeInMB=0.0):
+        If the threshold and thresholds Params are both set, they must be equivalent.
+        """
+        super(LogisticRegression, self).__init__()
+        kwargs = self._input_kwargs
+        self.setParams(**kwargs)
+        self._checkThresholdConsistency()
+
+    @classmethod
+    def _algo_name(cls):
+        return "LogisticRegression"
+
+    @classmethod
+    def _model_class(cls):
+        return LogisticRegressionModel
+
+
+LogisticRegression.__doc__ = PySparkLogisticRegression.__doc__
+
+
+@inherit_doc
+class LogisticRegressionModel(
+    _ClientProbabilisticClassificationModel,
+    _LogisticRegressionParams,
+    HasTrainingSummary,
+    ClientMLWritable,
+    ClientMLReadable,
+):
+    @classmethod
+    def _algo_name(cls):
+        return "LogisticRegression"
+
+    @property  # type: ignore[misc]
+    def coefficients(self) -> Vector:
+        return self._get_model_attr("coefficients")
+
+    @property  # type: ignore[misc]
+    def intercept(self) -> float:
+        return self._get_model_attr("intercept")
+
+    @property  # type: ignore[misc]
+    def coefficientMatrix(self) -> Matrix:
+        return self._get_model_attr("coefficientMatrix")
+
+    @property  # type: ignore[misc]
+    def interceptVector(self) -> Vector:
+        return self._get_model_attr("interceptVector")
+
+    def evaluate(self, dataset):
+        if self.numClasses <= 2:
+            return BinaryLogisticRegressionSummary(self, dataset)
+        else:
+            return LogisticRegressionSummary(self, dataset)
+
+    # TODO: Move this method to common interface shared by connect code and legacy code
+    @property  # type: ignore[misc]
+    def summary(self) -> "LogisticRegressionTrainingSummary":
+        if self.hasSummary:
+            if self.numClasses <= 2:
+                return BinaryLogisticRegressionTrainingSummary(self, None)
+            else:
+                return LogisticRegressionTrainingSummary(self, None)
+        else:
+            raise RuntimeError(
+                "No training summary available for this %s" % self.__class__.__name__
+            )
+
+
+LogisticRegressionModel.__doc__ = PySparkLogisticRegressionModel.__doc__
+
+
+@inherit_doc
+class _ClassificationSummary(ClientModelSummary):
+
+    @property  # type: ignore[misc]
+    def predictions(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("predictions")
+
+    predictions.__doc__ = _PySparkClassificationSummary.predictions.__doc__
+
+    @property  # type: ignore[misc]
+    def predictionCol(self) -> str:
+        return self._get_summary_attr("predictionCol")
+
+    predictionCol.__doc__ = _PySparkClassificationSummary.predictionCol.__doc__
+
+    @property  # type: ignore[misc]
+    def labelCol(self) -> str:
+        return self._get_summary_attr("labelCol")
+
+    labelCol.__doc__ = _PySparkClassificationSummary.labelCol.__doc__
+
+    @property  # type: ignore[misc]
+    def weightCol(self) -> str:
+        return self._get_summary_attr("weightCol")
+
+    weightCol.__doc__ = _PySparkClassificationSummary.weightCol.__doc__
+
+    @property
+    def labels(self) -> List[str]:
+        return self._get_summary_attr("labels")
+
+    labels.__doc__ = _PySparkClassificationSummary.labels.__doc__
+
+    @property  # type: ignore[misc]
+    def truePositiveRateByLabel(self) -> List[float]:
+        return self._get_summary_attr("truePositiveRateByLabel")
+
+    truePositiveRateByLabel.__doc__ = _PySparkClassificationSummary.truePositiveRateByLabel.__doc__
+
+    @property  # type: ignore[misc]
+    def falsePositiveRateByLabel(self) -> List[float]:
+        return self._get_summary_attr("falsePositiveRateByLabel")
+
+    falsePositiveRateByLabel.__doc__ = _PySparkClassificationSummary.falsePositiveRateByLabel.__doc__
+
+    @property  # type: ignore[misc]
+    def precisionByLabel(self) -> List[float]:
+        return self._get_summary_attr("precisionByLabel")
+
+    precisionByLabel.__doc__ = _PySparkClassificationSummary.precisionByLabel.__doc__
+
+    @property  # type: ignore[misc]
+    def recallByLabel(self) -> List[float]:
+        return self._get_summary_attr("recallByLabel")
+
+    recallByLabel.__doc__ = _PySparkClassificationSummary.recallByLabel.__doc__
+
+    def fMeasureByLabel(self, beta: float = 1.0) -> List[float]:
+        # TODO: support this.
+        raise NotImplementedError()
+
+    fMeasureByLabel.__doc__ = _PySparkClassificationSummary.fMeasureByLabel.__doc__
+
+    @property  # type: ignore[misc]
+    def accuracy(self) -> float:
+        return self._get_summary_attr("accuracy")
+
+    accuracy.__doc__ = _PySparkClassificationSummary.accuracy.__doc__
+
+    @property  # type: ignore[misc]
+    def weightedTruePositiveRate(self) -> float:
+        return self._get_summary_attr("weightedTruePositiveRate")
+
+    weightedTruePositiveRate.__doc__ = _PySparkClassificationSummary.weightedTruePositiveRate.__doc__
+
+    @property  # type: ignore[misc]
+    def weightedFalsePositiveRate(self) -> float:
+        return self._get_summary_attr("weightedFalsePositiveRate")
+
+    weightedFalsePositiveRate.__doc__ = _PySparkClassificationSummary.weightedFalsePositiveRate.__doc__
+
+    @property  # type: ignore[misc]
+    def weightedRecall(self) -> float:
+        return self._get_summary_attr("weightedRecall")
+
+    weightedRecall.__doc__ = _PySparkClassificationSummary.weightedRecall.__doc__
+
+    @property  # type: ignore[misc]
+    def weightedPrecision(self) -> float:
+        return self._get_summary_attr("weightedPrecision")
+
+    weightedPrecision.__doc__ = _PySparkClassificationSummary.weightedPrecision.__doc__
+
+    def weightedFMeasure(self, beta: float = 1.0) -> float:
+        # TODO: support this.
+        raise NotImplementedError()
+
+    weightedFMeasure.__doc__ = _PySparkClassificationSummary.weightedFMeasure.__doc__
+
+
+@inherit_doc
+class _TrainingSummary(ClientModelSummary):
+
+    @property  # type: ignore[misc]
+    def objectiveHistory(self) -> List[float]:
+        return self._get_summary_attr("objectiveHistory")
+
+    objectiveHistory.__doc__ = _PySparkTrainingSummary.objectiveHistory.__doc__
+
+    @property  # type: ignore[misc]
+    def totalIterations(self) -> int:
+        return self._get_summary_attr("totalIterations")
+
+    totalIterations.__doc__ = _PySparkTrainingSummary.totalIterations.__doc__
+
+
+@inherit_doc
+class _BinaryClassificationSummary(_ClassificationSummary):
+
+    @property  # type: ignore[misc]
+    def scoreCol(self) -> str:
+        return self._get_summary_attr("scoreCol")
+
+    scoreCol.__doc__ = _PySparkBinaryClassificationSummary.scoreCol.__doc__
+
+    @property
+    def roc(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("roc")
+
+    roc.__doc__ = _PySparkBinaryClassificationSummary.roc.__doc__
+
+    @property  # type: ignore[misc]
+    def areaUnderROC(self) -> float:
+        return self._get_summary_attr("areaUnderROC")
+
+    areaUnderROC.__doc__ = _PySparkBinaryClassificationSummary.areaUnderROC.__doc__
+
+    @property  # type: ignore[misc]
+    def pr(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("pr")
+
+    pr.__doc__ = _PySparkBinaryClassificationSummary.pr.__doc__
+
+    @property  # type: ignore[misc]
+    def fMeasureByThreshold(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("fMeasureByThreshold")
+
+    fMeasureByThreshold.__doc__ = _PySparkBinaryClassificationSummary.fMeasureByThreshold.__doc__
+
+    @property  # type: ignore[misc]
+    def precisionByThreshold(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("precisionByThreshold")
+
+    precisionByThreshold.__doc__ = _PySparkBinaryClassificationSummary.precisionByThreshold.__doc__
+
+    @property  # type: ignore[misc]
+    def recallByThreshold(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("recallByThreshold")
+
+    recallByThreshold.__doc__ = _PySparkBinaryClassificationSummary.recallByThreshold.__doc__
+
+
+@inherit_doc
+class LogisticRegressionSummary(_ClassificationSummary):
+
+    @property  # type: ignore[misc]
+    def probabilityCol(self) -> str:
+        return self._get_summary_attr("probabilityCol")
+
+    probabilityCol.__doc__ = PySparkLogisticRegressionSummary.probabilityCol.__doc__
+
+    @property  # type: ignore[misc]
+    def featuresCol(self) -> str:
+        return self._get_summary_attr("featuresCol")
+
+    featuresCol.__doc__ = PySparkLogisticRegressionSummary.featuresCol.__doc__
+
+
+LogisticRegressionSummary.__doc__ = PySparkLogisticRegressionSummary.__doc__
+
+
+@inherit_doc
+class LogisticRegressionTrainingSummary(LogisticRegressionSummary, _TrainingSummary):
+    pass
+
+
+LogisticRegressionTrainingSummary.__doc__ = PySparkLogisticRegressionTrainingSummary.__doc__
+
+
+@inherit_doc
+class BinaryLogisticRegressionSummary(_BinaryClassificationSummary, LogisticRegressionSummary):
+    pass
+
+
+BinaryLogisticRegressionSummary.__doc__ = PySparkBinaryLogisticRegressionSummary.__doc__
+
+
+@inherit_doc
+class BinaryLogisticRegressionTrainingSummary(
+    BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary
+):
+    pass
+
+
+BinaryLogisticRegressionTrainingSummary.__doc__ = PySparkBinaryLogisticRegressionTrainingSummary.__doc__
+
+
+def _test() -> None:
+    import os
+    import sys
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.sql.connect.ml.classification
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.connect.ml.classification.__dict__.copy()
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("sql.connect.ml.classification tests")

Review Comment:
   doctest should be added in `sparktestsupport/modules.py`
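   
   For reference, a rough sketch of what that registration could look like in `sparktestsupport/modules.py` (the module name and dependency list below are just assumptions for illustration; `Module` is the helper already defined in that file):
   
   ```python
   pyspark_connect_ml = Module(
       name="pyspark-connect-ml",
       dependencies=[pyspark_connect],
       source_file_regexes=["python/pyspark/ml/connect"],
       python_test_goals=[
           # running this module executes _test(), which runs the doctests
           "pyspark.ml.connect.classification",
       ],
   )
   ```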



##########
connector/connect/common/src/main/protobuf/spark/connect/ml_common.proto:
##########
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+// MlParams stores param settings for
+// ML Estimator / Transformer / Model / Evaluator
+message MlParams {
+  // user-supplied params
+  map<string, Expression.Literal> params = 1;
+  // default params
+  map<string, Expression.Literal> default_params = 2;
+}
+
+// MlStage stores ML stage data (Estimator or Transformer)
+message MlStage {
+  // The name of the stage in the registry
+  string name = 1;
+  // param settings for the stage
+  MlParams params = 2;
+  // unique id of the stage
+  string uid = 3;
+  StageType type = 4;
+  enum StageType {
+    UNSPECIFIED = 0;
+    ESTIMATOR = 1;
+    TRANSFORMER = 2;

Review Comment:
   we normally name enum values like this:
   ```
       STAGE_TYPE_UNSPECIFIED = 0;
       STAGE_TYPE_ESTIMATOR = 1;
       STAGE_TYPE_TRANSFORMER = 2;
   ```



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala:
##########
@@ -0,0 +1,61 @@
+/*

Review Comment:
   will we move these ml files to `connector/connect/server/src/main/scala/org/apache/spark/ml/connect`?



##########
python/pyspark/ml/connect/classification.py:
##########
@@ -0,0 +1,510 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+    TYPE_CHECKING,
+)
+
+from pyspark.sql import DataFrame
+from pyspark.ml.classification import (
+    Classifier,
+    ProbabilisticClassifier,
+    _LogisticRegressionParams,
+    _LogisticRegressionCommon,
+    ClassificationModel,
+    ProbabilisticClassificationModel
+)
+from pyspark.ml.linalg import (Matrix, Vector)
+from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
+from pyspark.ml.connect.base import (
+    ClientEstimator,
+    ClientModel,
+    HasTrainingSummary,
+    ClientModelSummary,
+    ClientPredictor,
+    ClientPredictionModel,
+    ClientMLWritable,
+    ClientMLReadable,
+)
+from abc import ABCMeta, abstractmethod
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.classification import (
+    LogisticRegression as PySparkLogisticRegression,
+    LogisticRegressionModel as PySparkLogisticRegressionModel,
+    _ClassificationSummary as _PySparkClassificationSummary,
+    _TrainingSummary as _PySparkTrainingSummary,
+    _BinaryClassificationSummary as _PySparkBinaryClassificationSummary,
+    LogisticRegressionSummary as PySparkLogisticRegressionSummary,
+    LogisticRegressionTrainingSummary as PySparkLogisticRegressionTrainingSummary,
+    BinaryLogisticRegressionSummary as PySparkBinaryLogisticRegressionSummary,
+    BinaryLogisticRegressionTrainingSummary as PySparkBinaryLogisticRegressionTrainingSummary,
+)
+
+
+@inherit_doc
+class _ClientClassifier(Classifier, ClientPredictor, metaclass=ABCMeta):
+    pass
+
+
+@inherit_doc
+class _ClientProbabilisticClassifier(
+    ProbabilisticClassifier, _ClientClassifier, metaclass=ABCMeta
+):
+    pass
+
+
+@inherit_doc
+class _ClientClassificationModel(ClassificationModel, ClientPredictionModel):
+    @property  # type: ignore[misc]
+    def numClasses(self) -> int:
+        return self._get_model_attr("numClasses")
+
+    def predictRaw(self, value: Vector) -> Vector:
+        # TODO: support this.
+        raise NotImplementedError()
+
+
+@inherit_doc
+class _ClientProbabilisticClassificationModel(
+    ProbabilisticClassificationModel, _ClientClassificationModel
+):
+    def predictProbability(self, value: Vector) -> Vector:
+        # TODO: support this.
+        raise NotImplementedError()
+
+
+@inherit_doc
+class LogisticRegression(
+    _ClientProbabilisticClassifier,
+    _LogisticRegressionCommon,
+    ClientMLWritable,
+    ClientMLReadable,
+):
+    _input_kwargs: Dict[str, Any]
+
+    @overload
+    def __init__(
+            self,
+            *,
+            featuresCol: str = ...,
+            labelCol: str = ...,
+            predictionCol: str = ...,
+            maxIter: int = ...,
+            regParam: float = ...,
+            elasticNetParam: float = ...,
+            tol: float = ...,
+            fitIntercept: bool = ...,
+            threshold: float = ...,
+            probabilityCol: str = ...,
+            rawPredictionCol: str = ...,
+            standardization: bool = ...,
+            weightCol: Optional[str] = ...,
+            aggregationDepth: int = ...,
+            family: str = ...,
+            lowerBoundsOnCoefficients: Optional[Matrix] = ...,
+            upperBoundsOnCoefficients: Optional[Matrix] = ...,
+            lowerBoundsOnIntercepts: Optional[Vector] = ...,
+            upperBoundsOnIntercepts: Optional[Vector] = ...,
+            maxBlockSizeInMB: float = ...,
+    ):
+        ...
+
+    @overload
+    def __init__(
+            self,
+            *,
+            featuresCol: str = ...,
+            labelCol: str = ...,
+            predictionCol: str = ...,
+            maxIter: int = ...,
+            regParam: float = ...,
+            elasticNetParam: float = ...,
+            tol: float = ...,
+            fitIntercept: bool = ...,
+            thresholds: Optional[List[float]] = ...,
+            probabilityCol: str = ...,
+            rawPredictionCol: str = ...,
+            standardization: bool = ...,
+            weightCol: Optional[str] = ...,
+            aggregationDepth: int = ...,
+            family: str = ...,
+            lowerBoundsOnCoefficients: Optional[Matrix] = ...,
+            upperBoundsOnCoefficients: Optional[Matrix] = ...,
+            lowerBoundsOnIntercepts: Optional[Vector] = ...,
+            upperBoundsOnIntercepts: Optional[Vector] = ...,
+            maxBlockSizeInMB: float = ...,
+    ):
+        ...
+
+    @keyword_only
+    def __init__(
+            self,
+            *,
+            featuresCol: str = "features",
+            labelCol: str = "label",
+            predictionCol: str = "prediction",
+            maxIter: int = 100,
+            regParam: float = 0.0,
+            elasticNetParam: float = 0.0,
+            tol: float = 1e-6,
+            fitIntercept: bool = True,
+            threshold: float = 0.5,
+            thresholds: Optional[List[float]] = None,
+            probabilityCol: str = "probability",
+            rawPredictionCol: str = "rawPrediction",
+            standardization: bool = True,
+            weightCol: Optional[str] = None,
+            aggregationDepth: int = 2,
+            family: str = "auto",
+            lowerBoundsOnCoefficients: Optional[Matrix] = None,
+            upperBoundsOnCoefficients: Optional[Matrix] = None,
+            lowerBoundsOnIntercepts: Optional[Vector] = None,
+            upperBoundsOnIntercepts: Optional[Vector] = None,
+            maxBlockSizeInMB: float = 0.0,
+    ):
+        """
+        __init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                 maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
+                 threshold=0.5, thresholds=None, probabilityCol="probability", \
+                 rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
+                 aggregationDepth=2, family="auto", \
+                 lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \
+                 lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \
+                 maxBlockSizeInMB=0.0):
+        If the threshold and thresholds Params are both set, they must be equivalent.
+        """
+        super(LogisticRegression, self).__init__()
+        kwargs = self._input_kwargs
+        self.setParams(**kwargs)
+        self._checkThresholdConsistency()
+
+    @classmethod
+    def _algo_name(cls):
+        return "LogisticRegression"
+
+    @classmethod
+    def _model_class(cls):
+        return LogisticRegressionModel
+
+
+LogisticRegression.__doc__ = PySparkLogisticRegression.__doc__
+
+
+@inherit_doc
+class LogisticRegressionModel(
+    _ClientProbabilisticClassificationModel,
+    _LogisticRegressionParams,
+    HasTrainingSummary,
+    ClientMLWritable,
+    ClientMLReadable,
+):
+    @classmethod
+    def _algo_name(cls):
+        return "LogisticRegression"
+
+    @property  # type: ignore[misc]
+    def coefficients(self) -> Vector:
+        return self._get_model_attr("coefficients")
+
+    @property  # type: ignore[misc]
+    def intercept(self) -> float:
+        return self._get_model_attr("intercept")
+
+    @property  # type: ignore[misc]
+    def coefficientMatrix(self) -> Matrix:
+        return self._get_model_attr("coefficientMatrix")
+
+    @property  # type: ignore[misc]
+    def interceptVector(self) -> Vector:
+        return self._get_model_attr("interceptVector")
+
+    def evaluate(self, dataset):
+        if self.numClasses <= 2:
+            return BinaryLogisticRegressionSummary(self, dataset)
+        else:
+            return LogisticRegressionSummary(self, dataset)
+
+    # TODO: Move this method to common interface shared by connect code and legacy code
+    @property  # type: ignore[misc]
+    def summary(self) -> "LogisticRegressionTrainingSummary":
+        if self.hasSummary:
+            if self.numClasses <= 2:
+                return BinaryLogisticRegressionTrainingSummary(self, None)
+            else:
+                return LogisticRegressionTrainingSummary(self, None)
+        else:
+            raise RuntimeError(
+                "No training summary available for this %s" % self.__class__.__name__
+            )
+
+
+LogisticRegressionModel.__doc__ = PySparkLogisticRegressionModel.__doc__
+
+
+@inherit_doc
+class _ClassificationSummary(ClientModelSummary):
+
+    @property  # type: ignore[misc]
+    def predictions(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("predictions")
+
+    predictions.__doc__ = _PySparkClassificationSummary.predictions.__doc__
+
+    @property  # type: ignore[misc]
+    def predictionCol(self) -> str:
+        return self._get_summary_attr("predictionCol")
+
+    predictionCol.__doc__ = _PySparkClassificationSummary.predictionCol.__doc__
+
+    @property  # type: ignore[misc]
+    def labelCol(self) -> str:
+        return self._get_summary_attr("labelCol")
+
+    labelCol.__doc__ = _PySparkClassificationSummary.labelCol.__doc__
+
+    @property  # type: ignore[misc]
+    def weightCol(self) -> str:
+        return self._get_summary_attr("weightCol")
+
+    weightCol.__doc__ = _PySparkClassificationSummary.weightCol.__doc__
+
+    @property
+    def labels(self) -> List[str]:
+        return self._get_summary_attr("labels")
+
+    labels.__doc__ = _PySparkClassificationSummary.labels.__doc__
+
+    @property  # type: ignore[misc]
+    def truePositiveRateByLabel(self) -> List[float]:
+        return self._get_summary_attr("truePositiveRateByLabel")
+
+    truePositiveRateByLabel.__doc__ = _PySparkClassificationSummary.truePositiveRateByLabel.__doc__
+
+    @property  # type: ignore[misc]
+    def falsePositiveRateByLabel(self) -> List[float]:
+        return self._get_summary_attr("falsePositiveRateByLabel")
+
+    falsePositiveRateByLabel.__doc__ = _PySparkClassificationSummary.falsePositiveRateByLabel.__doc__
+
+    @property  # type: ignore[misc]
+    def precisionByLabel(self) -> List[float]:
+        return self._get_summary_attr("precisionByLabel")
+
+    precisionByLabel.__doc__ = _PySparkClassificationSummary.precisionByLabel.__doc__
+
+    @property  # type: ignore[misc]
+    def recallByLabel(self) -> List[float]:
+        return self._get_summary_attr("recallByLabel")
+
+    recallByLabel.__doc__ = _PySparkClassificationSummary.recallByLabel.__doc__
+
+    def fMeasureByLabel(self, beta: float = 1.0) -> List[float]:
+        # TODO: support this.
+        raise NotImplementedError()
+
+    fMeasureByLabel.__doc__ = _PySparkClassificationSummary.fMeasureByLabel.__doc__
+
+    @property  # type: ignore[misc]
+    def accuracy(self) -> float:
+        return self._get_summary_attr("accuracy")
+
+    accuracy.__doc__ = _PySparkClassificationSummary.accuracy.__doc__
+
+    @property  # type: ignore[misc]
+    def weightedTruePositiveRate(self) -> float:
+        return self._get_summary_attr("weightedTruePositiveRate")
+
+    weightedTruePositiveRate.__doc__ = _PySparkClassificationSummary.weightedTruePositiveRate.__doc__
+
+    @property  # type: ignore[misc]
+    def weightedFalsePositiveRate(self) -> float:
+        return self._get_summary_attr("weightedFalsePositiveRate")
+
+    weightedFalsePositiveRate.__doc__ = _PySparkClassificationSummary.weightedFalsePositiveRate.__doc__
+
+    @property  # type: ignore[misc]
+    def weightedRecall(self) -> float:
+        return self._get_summary_attr("weightedRecall")
+
+    weightedRecall.__doc__ = _PySparkClassificationSummary.weightedRecall.__doc__
+
+    @property  # type: ignore[misc]
+    def weightedPrecision(self) -> float:
+        return self._get_summary_attr("weightedPrecision")
+
+    weightedPrecision.__doc__ = _PySparkClassificationSummary.weightedPrecision.__doc__
+
+    def weightedFMeasure(self, beta: float = 1.0) -> float:
+        # TODO: support this.
+        raise NotImplementedError()
+
+    weightedFMeasure.__doc__ = _PySparkClassificationSummary.weightedFMeasure.__doc__
+
+
+@inherit_doc
+class _TrainingSummary(ClientModelSummary):
+
+    @property  # type: ignore[misc]
+    def objectiveHistory(self) -> List[float]:
+        return self._get_summary_attr("objectiveHistory")
+
+    objectiveHistory.__doc__ = _PySparkTrainingSummary.objectiveHistory.__doc__
+
+    @property  # type: ignore[misc]
+    def totalIterations(self) -> int:
+        return self._get_summary_attr("totalIterations")
+
+    totalIterations.__doc__ = _PySparkTrainingSummary.totalIterations.__doc__
+
+
+@inherit_doc
+class _BinaryClassificationSummary(_ClassificationSummary):
+
+    @property  # type: ignore[misc]
+    def scoreCol(self) -> str:
+        return self._get_summary_attr("scoreCol")
+
+    scoreCol.__doc__ = _PySparkBinaryClassificationSummary.scoreCol.__doc__
+
+    @property
+    def roc(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("roc")
+
+    roc.__doc__ = _PySparkBinaryClassificationSummary.roc.__doc__
+
+    @property  # type: ignore[misc]
+    def areaUnderROC(self) -> float:
+        return self._get_summary_attr("areaUnderROC")
+
+    areaUnderROC.__doc__ = _PySparkBinaryClassificationSummary.areaUnderROC.__doc__
+
+    @property  # type: ignore[misc]
+    def pr(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("pr")
+
+    pr.__doc__ = _PySparkBinaryClassificationSummary.pr.__doc__
+
+    @property  # type: ignore[misc]
+    def fMeasureByThreshold(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("fMeasureByThreshold")
+
+    fMeasureByThreshold.__doc__ = _PySparkBinaryClassificationSummary.fMeasureByThreshold.__doc__
+
+    @property  # type: ignore[misc]
+    def precisionByThreshold(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("precisionByThreshold")
+
+    precisionByThreshold.__doc__ = _PySparkBinaryClassificationSummary.precisionByThreshold.__doc__
+
+    @property  # type: ignore[misc]
+    def recallByThreshold(self) -> DataFrame:
+        return self._get_summary_attr_dataframe("recallByThreshold")
+
+    recallByThreshold.__doc__ = _PySparkBinaryClassificationSummary.recallByThreshold.__doc__
+
+
+@inherit_doc
+class LogisticRegressionSummary(_ClassificationSummary):
+
+    @property  # type: ignore[misc]
+    def probabilityCol(self) -> str:
+        return self._get_summary_attr("probabilityCol")
+
+    probabilityCol.__doc__ = PySparkLogisticRegressionSummary.probabilityCol.__doc__
+
+    @property  # type: ignore[misc]
+    def featuresCol(self) -> str:
+        return self._get_summary_attr("featuresCol")
+
+    featuresCol.__doc__ = PySparkLogisticRegressionSummary.featuresCol.__doc__
+
+
+LogisticRegressionSummary.__doc__ = PySparkLogisticRegressionSummary.__doc__
+
+
+@inherit_doc
+class LogisticRegressionTrainingSummary(LogisticRegressionSummary, _TrainingSummary):
+    pass
+
+
+LogisticRegressionTrainingSummary.__doc__ = PySparkLogisticRegressionTrainingSummary.__doc__
+
+
+@inherit_doc
+class BinaryLogisticRegressionSummary(_BinaryClassificationSummary, LogisticRegressionSummary):
+    pass
+
+
+BinaryLogisticRegressionSummary.__doc__ = PySparkBinaryLogisticRegressionSummary.__doc__
+
+
+@inherit_doc
+class BinaryLogisticRegressionTrainingSummary(
+    BinaryLogisticRegressionSummary, LogisticRegressionTrainingSummary
+):
+    pass
+
+
+BinaryLogisticRegressionTrainingSummary.__doc__ = PySparkBinaryLogisticRegressionTrainingSummary.__doc__
+
+
+def _test() -> None:
+    import os
+    import sys
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.sql.connect.ml.classification
+
+    os.chdir(os.environ["SPARK_HOME"])
+
+    globs = pyspark.sql.connect.ml.classification.__dict__.copy()
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("sql.connect.ml.classification tests")

Review Comment:
   ```suggestion
           PySparkSession.builder.appName("ml.connect.classification tests")
   ```




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136451429


##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -82,13 +83,50 @@ message Relation {
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
 
+    // ML relation
+    MlRelation ml_relation = 300;
+
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // relations they can add them here. During the planning the correct resolution is done.
     google.protobuf.Any extension = 998;
     Unknown unknown = 999;
   }
 }
 
+message MlRelation {
+  oneof ml_relation_type {
+    ModelTransform model_transform = 1;
+    FeatureTransform feature_transform = 2;
+    ModelAttr model_attr = 3;
+    ModelSummaryAttr model_summary_attr = 4;
+  }
+  message ModelTransform {
+    Relation input = 1;
+    int64 model_ref_id = 2;

Review Comment:
   > The ID is generated from an incremental counter.
   
   Using a random UUID might be a better idea if we want to support server failover in the future: we would need to persist state and restore it, and a random UUID avoids reusing an ID that was generated before.
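   
   A minimal sketch of the idea (illustrative only; the real server side is Scala, and `new_model_ref_id` is a hypothetical helper name):
   
   ```python
   import uuid
   
   def new_model_ref_id() -> str:
       # A fresh random UUID will not collide with ids minted before a
       # failover, unlike a counter that restarts from zero after recovery.
       return uuid.uuid4().hex
   ```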




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133521342


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala:
##########
@@ -210,7 +211,12 @@ class SparkConnectService(debug: Boolean)
  * @param userId
  * @param session
  */
-case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
+case class SessionHolder(
+  userId: String,
+  sessionId: String,
+  session: SparkSession,
+  mlCache: MLCache = MLCache()

Review Comment:
   But the `mlCache` has ML-specific attributes, like `modelToHandlerMap`.




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133455888


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlEvaluator {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    FetchModelAttr fetch_model_attr = 2;
+    FetchModelSummaryAttr fetch_model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+    SaveStage save_stage = 7;
+    LoadStage load_stage = 8;
+    SaveEvaluator save_evaluator = 9;
+    LoadEvaluator load_evaluator = 10;
+    CopyModel copy_model = 11;
+  }
+
+  message Fit {
+    MlStage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    MlEvaluator evaluator = 1;
+  }
+
+  message LoadModel {
+    string name = 1;
+    string path = 2;
+  }
+
+  message SaveModel {
+    int64 model_ref_id = 1;
+    string path = 2; // saving path
+    bool overwrite = 3;
+    map<string, string> options = 4; // saving options
+  }
+
+  message LoadStage {
+    string name = 1;
+    string path = 2;
+    MlStage.StageType type = 3;
+  }
+
+  message SaveStage {
+    MlStage stage = 1;
+    string path = 2; // saving path
+    bool overwrite = 3;
+    map<string, string> options = 4; // saving options
+  }
+
+  message LoadEvaluator {
+    string name = 1;
+    string path = 2;
+  }
+
+  message SaveEvaluator {
+    MlEvaluator evaluator = 1;
+    string path = 2; // saving path
+    bool overwrite = 3;
+    map<string, string> options = 4; // saving options
+  }
+
+  message FetchModelAttr {
+    int64 model_ref_id = 1;
+    string name = 2;
+  }
+
+  message FetchModelSummaryAttr {
+    int64 model_ref_id = 1;
+    string name = 2;
+    MlParams params = 3;
+
+    // Evaluation dataset that it uses to computes
+    // the summary attribute
+    // If not set, get attributes from
+    // model.summary (i.e. the summary on training dataset)
+    optional Relation evaluation_dataset = 4;
+  }
+
+  message CopyModel {
+    int64 model_ref_id = 1;
+  }
+}
+
+
+message MlCommandResponse {
+  oneof ml_command_response_type {
+    Expression.Literal literal = 1;
+    ModelInfo model_info = 2;
+    Vector vector = 3;

Review Comment:
   do we need an abstraction for model coefficients?
   
   for example, 
   
   ```
   class GaussianMixtureModel private[ml] (
       @Since("2.0.0") override val uid: String,
       @Since("2.0.0") val weights: Array[Double],
       @Since("2.0.0") val gaussians: Array[MultivariateGaussian])
   
   class MultivariateGaussian @Since("2.0.0") (
       @Since("2.0.0") val mean: Vector,
       @Since("2.0.0") val cov: Matrix) extends Serializable
   ```
   
   
   Or maybe always return a DF?
   



##########
mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java:
##########
@@ -108,9 +108,6 @@ private void init() {
     myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
     myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param",
       ParamValidators.inRange(0.0, 1.0));
-    List<String> validStrings = Arrays.asList("a", "b");

Review Comment:
   why remove?



##########
python/pyspark/sql/connect/proto/catalog_pb2.pyi:
##########
@@ -49,6 +49,7 @@ else:
 
 DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
 
+@typing_extensions.final

Review Comment:
   need to figure out how to remove this annotation



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala:
##########
@@ -210,7 +211,12 @@ class SparkConnectService(debug: Boolean)
  * @param userId
  * @param session
  */
-case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
+case class SessionHolder(
+  userId: String,
+  sessionId: String,
+  session: SparkSession,
+  mlCache: MLCache = MLCache()

Review Comment:
   I feel we'd better let `SessionHolder` hold a generalized cache (`ObjectCache`) instead of an ML-specific cache



##########
mllib/src/main/scala/org/apache/spark/ml/param/params.scala:
##########
@@ -44,8 +45,14 @@ import org.apache.spark.ml.util.Identifiable
  *                See [[ParamValidators]] for factory methods for common validation functions.
  * @tparam T param value type
  */
-class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
-  extends Serializable {
+class Param[T: ClassTag](

Review Comment:
   maybe a little simpler:
   
   ```suggestion
   class Param[T](
       val parent: String, val name: String, val doc: String, val isValid: T => Boolean
   )(implicit paramValueClassTag: ClassTag[T]) extends Serializable
   ```




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135600960


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlEvaluator {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    FetchModelAttr fetch_model_attr = 2;
+    FetchModelSummaryAttr fetch_model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+    SaveStage save_stage = 7;
+    LoadStage load_stage = 8;
+    SaveEvaluator save_evaluator = 9;
+    LoadEvaluator load_evaluator = 10;
+    CopyModel copy_model = 11;
+    DeleteModel delete_model = 12;
+  }
+
+  message Fit {
+    MlStage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    MlEvaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   Btw, supporting 3rd-party estimators is risky, because in a shared cluster we will [binpack the spark workers across different customers](https://docs.google.com/document/d/1sJVjan44XagM48PEqdkg6KWctpPcz54Urf0i_-dGesA/edit?disco=AAAArl9hpF8).
   3rd-party estimator implementations might invoke RDD transformations (e.g. RDD.map) that we cannot isolate by container, so it is risky to allow users to run 3rd-party estimators on a shared cluster.




[GitHub] [spark] github-actions[bot] closed pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "github-actions[bot] (via GitHub)" <gi...@apache.org>.
github-actions[bot] closed pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML
URL: https://github.com/apache/spark/pull/40297



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1127830764


##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -81,13 +82,50 @@ message Relation {
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
 
+    // ML relation
+    MlRelation ml_relation = 300;
+
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // relations they can add them here. During the planning the correct resolution is done.
     google.protobuf.Any extension = 998;
     Unknown unknown = 999;
   }
 }
 
+message MlRelation {
+  oneof ml_relation_type {
+    ModelTransform model_transform = 1;
+    FeatureTransform feature_transform = 2;
+    ModelAttr model_attr = 3;
+    ModelSummaryAttr model_summary_attr = 4;
+  }
+  message ModelTransform {
+    Relation input = 1;
+    int64 model_ref_id = 2;
+    Params params = 3;
+  }
+  message FeatureTransform {

Review Comment:
   Got it. Then we can treat it as a normal transformer because it only contains the `splits` param




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1127834377


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)
+    }
+    paramsProto.getDefaultParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance._setDefault(paramDef -> paramValue)
+    }
+  }
+
+  def parseParamValue(paramType: Class[_], paramValueProto: proto.Expression.Literal): Any = {

Review Comment:
   No. The purpose of `paramType` is this:
   in PySpark we cannot distinguish int/long or float/double, so when the client sends a value to the server, an int becomes a long and a float becomes a double. We need the accurate param type on the server side, otherwise the JVM will raise an error at runtime.
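   
   A quick illustration of the ambiguity on the Python side:
   
   ```python
   # Python has one int type and one (double-precision) float type, so the
   # literal value alone cannot tell the server whether the target param is
   # an IntParam or a LongParam, a FloatParam or a DoubleParam:
   isinstance(100, int)    # True for both IntParam and LongParam values
   isinstance(0.5, float)  # True, and a Python float is always a C double
   ```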




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129073675


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message Evaluator {
+  string name = 1;
+  Params params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    ModelAttr model_attr = 2;
+    ModelSummaryAttr model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+  }
+
+  message Fit {
+    Stage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    Evaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   I will rename it to `LoadStage` to support estimators too.




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1128955568


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)
+    }
+    paramsProto.getDefaultParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance._setDefault(paramDef -> paramValue)
+    }
+  }
+
+  def parseParamValue(paramType: Class[_], paramValueProto: proto.Expression.Literal): Any = {

Review Comment:
   for the int/long case, we can specify the DataType in the Python client:
   
   https://github.com/apache/spark/blob/c99a632fea74136964b27b28563115fe2d7667b3/python/pyspark/sql/connect/expressions.py#L165-L171
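   
   e.g. a rough sketch of encoding a param value with an explicit type (assuming the `LiteralExpression(value, dataType)` constructor from the file linked above):
   
   ```python
   from pyspark.sql.connect.expressions import LiteralExpression
   from pyspark.sql.types import IntegerType, LongType
   
   # Encode maxIter explicitly as INT32 so it matches the JVM-side IntParam,
   # instead of letting a plain Python int default to INT64:
   max_iter_literal = LiteralExpression(100, IntegerType())
   seed_literal = LiteralExpression(42, LongType())
   ```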




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129109936


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)

Review Comment:
   what if we make `instance.set` accept different value types here?
   
   ```
   (paramDef, paramValue) match {
      case (p: DoubleParam, v: Float) => instance.set(p, v.toDouble)
      case (p: DoubleParam, v: Double) => instance.set(p, v)
      case (p: LongParam, v: Short) => instance.set(p, v.toLong)
      case (p: LongParam, v: Int) => instance.set(p, v.toLong)
      case (p: LongParam, v: Long) => instance.set(p, v)
      ...
      case _ => instance.set(paramDef, paramValue)
   }
   ```




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133663732


##########
python/pyspark/sql/connect/proto/catalog_pb2.pyi:
##########
@@ -49,6 +49,7 @@ else:
 
 DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
 
+@typing_extensions.final

Review Comment:
   CC @HyukjinKwon Any ideas?




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133666221


##########
mllib/src/main/scala/org/apache/spark/ml/param/params.scala:
##########
@@ -44,8 +45,14 @@ import org.apache.spark.ml.util.Identifiable
  *                See [[ParamValidators]] for factory methods for common validation functions.
  * @tparam T param value type
  */
-class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
-  extends Serializable {
+class Param[T: ClassTag](

Review Comment:
   @zhengruifeng 
   
   I recall I tried the `(implicit paramValueClassTag: ClassTag[T])` approach before, but it makes it hard for us to get the ClassTag object, so I prefer the current approach.




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133605056


##########
python/pyspark/sql/connect/proto/catalog_pb2.pyi:
##########
@@ -49,6 +49,7 @@ else:
 
 DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
 
+@typing_extensions.final

Review Comment:
   I don't know what this annotation is, but I think it should not affect other generated files



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129158322


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)

Review Comment:
   This way we cannot support a param that is defined as a `Param[XXX]` class (when compiling, the XXX type info is erased).


[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129143557


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)

Review Comment:
   > instance.set(p, v.toLong)
   
   If p is of type `Param[Int]`, this will cause an error: the `set` call would succeed, but a type conversion error will occur later, when we get the param value.
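   
   For illustration, a minimal standalone sketch of that failure mode; the `Holder` class below is a hypothetical stand-in for `Param[T]`, not Spark code:
   
   ```scala
   object ErasureDemo extends App {
     class Holder[T] {
       private var v: Any = _
       def set(value: T): Unit = { v = value } // T is erased: any value passes at runtime
       def get: T = v.asInstanceOf[T]          // the cast is deferred to the read site
     }
   
     val typed = new Holder[Int]
     val erased = typed.asInstanceOf[Holder[Any]] // what an untyped dispatch path effectively does
     erased.set(1L)                               // "succeeds", silently storing a Long
   
     val x: Int = typed.get // throws ClassCastException: Long cannot be cast to Integer
     println(x)
   }
   ```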



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1127841115


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.TrainingSummary
+import org.apache.spark.sql.DataFrame
+
+
+object AlgorithmRegistry {
+
+  def get(name: String): Algorithm = {
+    name match {
+      case "LogisticRegression" => new LogisticRegressionAlgorithm
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+}
+
+
+abstract class Algorithm {
+
+  def initiateEstimator(uid: String): Estimator[_]
+
+  def getModelAttr(model: Model[_], name: String): Either[proto.MlCommandResponse, DataFrame]

Review Comment:
   > maybe here we can split getModelAttr (/ getModelSummaryAttr ) into two methods getModelScalarAttr and getModelDataFrameAttr to make it a bit easier to understand?
   
   This is what I did in the initial code, but then I changed it to the current approach, because:
   
   in the ML code, the logic for returning an `MlCommandResponse` and the logic for returning a `DataFrame` are the same; and
   if we split them into two methods, then for the model summary, which has a complex class hierarchy, each superclass would need to define two similar methods, which bloats the code and makes it hard to maintain.
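   
   To illustrate the single-method shape (a sketch with simplified types and made-up attribute names; the real signature returns `Either[proto.MlCommandResponse, DataFrame]`):
   
   ```scala
   // Sketch only: scalar-like attributes go Left, DataFrame-like ones go Right,
   // so each class keeps a single dispatch table.
   def getModelSummaryAttr(name: String): Either[Double, String] = name match {
     case "accuracy" => Left(0.97)
     case "roc"      => Right("a DataFrame of (FPR, TPR) rows")
     case other      => throw new IllegalArgumentException(other)
   }
   ```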



[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1127232443


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "google/protobuf/any.proto";
+import "spark/connect/expressions.proto";
+import "spark/connect/types.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message Evaluator {
+  string name = 1;
+  Params params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    ModelAttr model_attr = 2;
+    ModelSummaryAttr model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+  }
+
+  message Fit {
+    Stage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    Evaluator evaluator = 1;
+  }
+
+  message LoadModel {
+    string name = 1;
+    string path = 2;
+  }
+
+  message SaveModel {
+    int64 model_ref_id = 1;
+    string path = 2; // saving path
+    bool overwrite = 3;
+    map<string, string> options = 4; // saving options
+  }
+
+  message ModelAttr {
+    int64 model_ref_id = 1;
+    string name = 2;
+  }
+
+  message ModelSummaryAttr {
+    int64 model_ref_id = 1;
+    string name = 2;
+    Params params = 3;
+
+    // Evaluation dataset that it uses to computes
+    // the summary attribute
+    // If not set, get attributes from
+    // model.summary (i.e. the summary on training dataset)
+    optional Relation evaluation_dataset = 4;
+  }
+}
+
+
+message MlCommandResponse {
+  oneof ml_command_response_type {
+    Expression.Literal literal = 1;
+    ModelInfo model_info = 2;
+    Vector vector = 3;
+    Matrix matrix = 4;
+  }
+  message ModelInfo {
+    int64 model_ref_id = 1;
+    string model_uid = 2;
+  }
+}
+
+
+message Vector {
+  oneof one_of {
+    Dense dense = 1;
+    Sparse sparse = 2;
+  }
+  message Dense {
+    repeated double values = 1;
+  }
+  message Sparse {
+    int32 size = 1;
+    repeated double indices = 2;
+    repeated double values = 3;
+  }
+}
+
+message Matrix {
+  oneof one_of {
+    Dense dense = 1;
+    Sparse sparse = 2;
+  }
+  message Dense {
+    int32 num_rows = 1;
+    int32 num_cols = 2;
+    repeated double values = 3;

Review Comment:
   nit, `DenseMatrix` also has `isTransposed`
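   
   For reference, a small snippet of the existing Scala-side behavior (using the real `org.apache.spark.ml.linalg.DenseMatrix` API), showing why a faithful proto encoding needs to carry the flag:
   
   ```scala
   import org.apache.spark.ml.linalg.DenseMatrix
   
   // The same 6 values describe two different matrices depending on the flag:
   val colMajor = new DenseMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
   val rowMajor = new DenseMatrix(2, 3, Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), isTransposed = true)
   println(colMajor(0, 1)) // 3.0 (column-major layout)
   println(rowMajor(0, 1)) // 2.0 (row-major layout)
   ```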



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.TrainingSummary
+import org.apache.spark.sql.DataFrame
+
+
+object AlgorithmRegistry {
+
+  def get(name: String): Algorithm = {
+    name match {
+      case "LogisticRegression" => new LogisticRegressionAlgorithm
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+}
+
+
+abstract class Algorithm {
+
+  def initiateEstimator(uid: String): Estimator[_]
+
+  def getModelAttr(model: Model[_], name: String): Either[proto.MlCommandResponse, DataFrame]

Review Comment:
   why not make `DataFrame`/`Relation` a `oneof` item in `MlCommandResponse`?



##########
connector/connect/common/src/main/protobuf/spark/connect/expressions.proto:
##########
@@ -172,6 +172,12 @@ message Expression {
       CalendarInterval calendar_interval = 19;
       int32 year_month_interval = 20;
       int64 day_time_interval = 21;
+
+      List list = 99;

Review Comment:
   https://github.com/apache/spark/commit/0def3de6ed1000efe72c8bbdd3b3804bb34ce620 just introduced the literal array type; we can reuse it



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135519116


##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -82,13 +83,50 @@ message Relation {
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
 
+    // ML relation
+    MlRelation ml_relation = 300;
+
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // relations they can add them here. During the planning the correct resolution is done.
     google.protobuf.Any extension = 998;
     Unknown unknown = 999;
   }
 }
 
+message MlRelation {
+  oneof ml_relation_type {
+    ModelTransform model_transform = 1;
+    FeatureTransform feature_transform = 2;
+    ModelAttr model_attr = 3;
+    ModelSummaryAttr model_summary_attr = 4;
+  }
+  message ModelTransform {
+    Relation input = 1;
+    int64 model_ref_id = 2;

Review Comment:
   The ID is generated from an incremental counter, so I think the int64 type should be fine.
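   
   A sketch of the kind of server-side registry this implies (names are illustrative, not the PR's code): ids come from a monotonically increasing counter, so an int64 cannot realistically overflow within a session.
   
   ```scala
   import java.util.concurrent.ConcurrentHashMap
   import java.util.concurrent.atomic.AtomicLong
   
   class ModelCache[T] {
     private val counter = new AtomicLong(0L)
     private val models = new ConcurrentHashMap[java.lang.Long, T]()
   
     def register(model: T): Long = {
       val id = counter.incrementAndGet() // monotonically increasing ref id
       models.put(id, model)
       id
     }
   
     def get(id: Long): T = models.get(id)
   }
   ```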



[GitHub] [spark] harupy commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "harupy (via GitHub)" <gi...@apache.org>.
harupy commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136544713


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala:
##########
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLHandler {
+
+  def handleMlCommand(
+      sessionHolder: SessionHolder,
+      mlCommand: proto.MlCommand): proto.MlCommandResponse = {
+    mlCommand.getMlCommandTypeCase match {
+      case proto.MlCommand.MlCommandTypeCase.FIT =>
+        val fitCommandProto = mlCommand.getFit
+        val estimatorProto = fitCommandProto.getEstimator
+        assert(estimatorProto.getType == proto.MlStage.StageType.ESTIMATOR)
+
+        val algoName = fitCommandProto.getEstimator.getName
+        val algo = AlgorithmRegistry.get(algoName)
+
+        val estimator = algo.initiateEstimator(estimatorProto.getUid)
+        MLUtils.setInstanceParams(estimator, estimatorProto.getParams)
+        val dataset = MLUtils.parseRelationProto(fitCommandProto.getDataset, sessionHolder)
+        val model = estimator.fit(dataset).asInstanceOf[Model[_]]
+        val refId = sessionHolder.mlCache.modelCache.register(model, algo)
+
+        proto.MlCommandResponse
+          .newBuilder()
+          .setModelInfo(
+            proto.MlCommandResponse.ModelInfo.newBuilder
+              .setModelRefId(refId)
+              .setModelUid(model.uid))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_ATTR =>
+        val getModelAttrProto = mlCommand.getFetchModelAttr
+        val modelEntry = sessionHolder.mlCache.modelCache.get(getModelAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        algo.getModelAttr(model, getModelAttrProto.getName).left.get
+
+      case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_SUMMARY_ATTR =>
+        val getModelSummaryAttrProto = mlCommand.getFetchModelSummaryAttr
+        val modelEntry =
+          sessionHolder.mlCache.modelCache.get(getModelSummaryAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        // Create a copied model to avoid concurrently modify model params.
+        val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
+        MLUtils.setInstanceParams(copiedModel, getModelSummaryAttrProto.getParams)
+
+        val datasetOpt = if (getModelSummaryAttrProto.hasEvaluationDataset) {
+          val evalDF = MLUtils.parseRelationProto(
+            getModelSummaryAttrProto.getEvaluationDataset,
+            sessionHolder)
+          Some(evalDF)
+        } else None

Review Comment:
   ```suggestion
           val datasetOpt = getModelSummaryAttrProto.evaluationDataset.map(ds => MLUtils.parseRelationProto(ds, sessionHolder))
   ```
   
   Can we use `map` to simplify the code here?



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136643875


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala:
##########
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLHandler {
+
+  def handleMlCommand(
+      sessionHolder: SessionHolder,
+      mlCommand: proto.MlCommand): proto.MlCommandResponse = {
+    mlCommand.getMlCommandTypeCase match {
+      case proto.MlCommand.MlCommandTypeCase.FIT =>
+        val fitCommandProto = mlCommand.getFit
+        val estimatorProto = fitCommandProto.getEstimator
+        assert(estimatorProto.getType == proto.MlStage.StageType.ESTIMATOR)
+
+        val algoName = fitCommandProto.getEstimator.getName
+        val algo = AlgorithmRegistry.get(algoName)
+
+        val estimator = algo.initiateEstimator(estimatorProto.getUid)
+        MLUtils.setInstanceParams(estimator, estimatorProto.getParams)
+        val dataset = MLUtils.parseRelationProto(fitCommandProto.getDataset, sessionHolder)
+        val model = estimator.fit(dataset).asInstanceOf[Model[_]]
+        val refId = sessionHolder.mlCache.modelCache.register(model, algo)
+
+        proto.MlCommandResponse
+          .newBuilder()
+          .setModelInfo(
+            proto.MlCommandResponse.ModelInfo.newBuilder
+              .setModelRefId(refId)
+              .setModelUid(model.uid))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_ATTR =>
+        val getModelAttrProto = mlCommand.getFetchModelAttr
+        val modelEntry = sessionHolder.mlCache.modelCache.get(getModelAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        algo.getModelAttr(model, getModelAttrProto.getName).left.get
+
+      case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_SUMMARY_ATTR =>
+        val getModelSummaryAttrProto = mlCommand.getFetchModelSummaryAttr
+        val modelEntry =
+          sessionHolder.mlCache.modelCache.get(getModelSummaryAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        // Create a copied model to avoid concurrently modify model params.
+        val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
+        MLUtils.setInstanceParams(copiedModel, getModelSummaryAttrProto.getParams)
+
+        val datasetOpt = if (getModelSummaryAttrProto.hasEvaluationDataset) {
+          val evalDF = MLUtils.parseRelationProto(
+            getModelSummaryAttrProto.getEvaluationDataset,
+            sessionHolder)
+          Some(evalDF)
+        } else None
+
+        algo
+          .getModelSummaryAttr(copiedModel, getModelSummaryAttrProto.getName, datasetOpt)
+          .left
+          .get
+
+      case proto.MlCommand.MlCommandTypeCase.LOAD_MODEL =>
+        val loadModelProto = mlCommand.getLoadModel
+        val algo = AlgorithmRegistry.get(loadModelProto.getName)
+        val model = algo.loadModel(loadModelProto.getPath)
+        val refId = sessionHolder.mlCache.modelCache.register(model, algo)
+
+        proto.MlCommandResponse
+          .newBuilder()
+          .setModelInfo(
+            proto.MlCommandResponse.ModelInfo.newBuilder
+              .setModelRefId(refId)
+              .setModelUid(model.uid)
+              .setParams(MLUtils.convertInstanceParamsToProto(model)))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.SAVE_MODEL =>
+        val saveModelProto = mlCommand.getSaveModel
+        val modelEntry = sessionHolder.mlCache.modelCache.get(saveModelProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        algo.saveModel(
+          model,
+          saveModelProto.getPath,
+          saveModelProto.getOverwrite,
+          saveModelProto.getOptionsMap.asScala.toMap)
+        proto.MlCommandResponse
+          .newBuilder()
+          .setLiteral(LiteralValueProtoConverter.toLiteralProto(null))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.LOAD_STAGE =>
+        val loadStageProto = mlCommand.getLoadStage
+        val name = loadStageProto.getName
+        loadStageProto.getType match {
+          case proto.MlStage.StageType.ESTIMATOR =>
+            val algo = AlgorithmRegistry.get(name)
+            val estimator = algo.loadEstimator(loadStageProto.getPath)
+
+            proto.MlCommandResponse
+              .newBuilder()
+              .setStage(
+                proto.MlStage
+                  .newBuilder()
+                  .setName(name)
+                  .setType(proto.MlStage.StageType.ESTIMATOR)
+                  .setUid(estimator.uid)
+                  .setParams(MLUtils.convertInstanceParamsToProto(estimator)))
+              .build()
+          case _ =>
+            throw new UnsupportedOperationException()
+        }
+
+      case proto.MlCommand.MlCommandTypeCase.SAVE_STAGE =>
+        val saveStageProto = mlCommand.getSaveStage
+        val stageProto = saveStageProto.getStage
+
+        stageProto.getType match {
+          case proto.MlStage.StageType.ESTIMATOR =>
+            val name = stageProto.getName
+            val algo = AlgorithmRegistry.get(name)
+            val estimator = algo.initiateEstimator(stageProto.getUid)
+            MLUtils.setInstanceParams(estimator, stageProto.getParams)
+            algo.saveEstimator(
+              estimator,
+              saveStageProto.getPath,
+              saveStageProto.getOverwrite,
+              saveStageProto.getOptionsMap.asScala.toMap)
+            proto.MlCommandResponse
+              .newBuilder()
+              .setLiteral(LiteralValueProtoConverter.toLiteralProto(null))
+              .build()
+
+          case _ =>
+            throw new UnsupportedOperationException()
+        }
+
+      case proto.MlCommand.MlCommandTypeCase.COPY_MODEL =>
+        val copyModelProto = mlCommand.getCopyModel
+        val modelEntry = sessionHolder.mlCache.modelCache.get(copyModelProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
+        val refId = sessionHolder.mlCache.modelCache.register(copiedModel, algo)
+        proto.MlCommandResponse
+          .newBuilder()
+          .setLiteral(proto.Expression.Literal.newBuilder().setLong(refId))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.DELETE_MODEL =>
+        val modelRefId = mlCommand.getDeleteModel.getModelRefId
+        sessionHolder.mlCache.modelCache.remove(modelRefId)
+        proto.MlCommandResponse
+          .newBuilder()
+          .setLiteral(LiteralValueProtoConverter.toLiteralProto(null))
+          .build()
+
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+  def transformMLRelation(
+      mlRelationProto: proto.MlRelation,
+      sessionHolder: SessionHolder): DataFrame = {
+    mlRelationProto.getMlRelationTypeCase match {
+      case proto.MlRelation.MlRelationTypeCase.MODEL_TRANSFORM =>
+        val modelTransformRelationProto = mlRelationProto.getModelTransform
+        val (model, _) =
+          sessionHolder.mlCache.modelCache.get(modelTransformRelationProto.getModelRefId)
+        // Create a copied model to avoid concurrently modify model params.
+        val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
+        MLUtils.setInstanceParams(copiedModel, modelTransformRelationProto.getParams)
+        val inputDF =
+          MLUtils.parseRelationProto(modelTransformRelationProto.getInput, sessionHolder)
+        copiedModel.transform(inputDF)
+
+      case proto.MlRelation.MlRelationTypeCase.MODEL_ATTR =>
+        val modelAttrProto = mlRelationProto.getModelAttr
+        val modelEntry = sessionHolder.mlCache.modelCache.get(modelAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2

Review Comment:
   I tried unpacking, but in this case compilation failed. The error is like:
   ```
   inferred existential type (org.apache.spark.ml.Model[_$4], org.apache.spark.sql.connect.ml.Algorithm)( forSome { type _$4 }), which cannot be expressed by wildcards,  should be enabled
   by making the implicit value scala.language.existentials visible.
   ```
   But after adding `import scala.language.existentials`, it works.
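   
   A minimal sketch of the situation, with stand-in types rather than the PR's classes (under Spark's fatal-warnings build, the feature warning becomes the compile error quoted above):
   
   ```scala
   import scala.language.existentials // the unpacked pair's type mentions Model[_$1]
   
   object ExistentialsDemo extends App {
     trait Algorithm
     class Model[M] // stand-in for org.apache.spark.ml.Model
   
     def lookup(): (Model[_], Algorithm) = (new Model[Int], new Algorithm {})
   
     // Without the import, this destructuring infers an existential type that
     // "cannot be expressed by wildcards" and triggers the error quoted above.
     val (model, algo) = lookup()
     println(algo)
   }
   ```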



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136648435


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala:
##########
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLHandler {
+
+  def handleMlCommand(
+      sessionHolder: SessionHolder,
+      mlCommand: proto.MlCommand): proto.MlCommandResponse = {
+    mlCommand.getMlCommandTypeCase match {
+      case proto.MlCommand.MlCommandTypeCase.FIT =>
+        val fitCommandProto = mlCommand.getFit
+        val estimatorProto = fitCommandProto.getEstimator
+        assert(estimatorProto.getType == proto.MlStage.StageType.ESTIMATOR)
+
+        val algoName = fitCommandProto.getEstimator.getName
+        val algo = AlgorithmRegistry.get(algoName)
+
+        val estimator = algo.initiateEstimator(estimatorProto.getUid)
+        MLUtils.setInstanceParams(estimator, estimatorProto.getParams)
+        val dataset = MLUtils.parseRelationProto(fitCommandProto.getDataset, sessionHolder)
+        val model = estimator.fit(dataset).asInstanceOf[Model[_]]
+        val refId = sessionHolder.mlCache.modelCache.register(model, algo)
+
+        proto.MlCommandResponse
+          .newBuilder()
+          .setModelInfo(
+            proto.MlCommandResponse.ModelInfo.newBuilder
+              .setModelRefId(refId)
+              .setModelUid(model.uid))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_ATTR =>
+        val getModelAttrProto = mlCommand.getFetchModelAttr
+        val modelEntry = sessionHolder.mlCache.modelCache.get(getModelAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        algo.getModelAttr(model, getModelAttrProto.getName).left.get
+
+      case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_SUMMARY_ATTR =>
+        val getModelSummaryAttrProto = mlCommand.getFetchModelSummaryAttr
+        val modelEntry =
+          sessionHolder.mlCache.modelCache.get(getModelSummaryAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        // Create a copied model to avoid concurrently modify model params.
+        val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
+        MLUtils.setInstanceParams(copiedModel, getModelSummaryAttrProto.getParams)
+
+        val datasetOpt = if (getModelSummaryAttrProto.hasEvaluationDataset) {
+          val evalDF = MLUtils.parseRelationProto(
+            getModelSummaryAttrProto.getEvaluationDataset,
+            sessionHolder)
+          Some(evalDF)
+        } else None

Review Comment:
   But `getModelSummaryAttrProto` does not have an `evaluationDataset` method that returns an `Option[X]` value; the proto interfaces are generated in the Java style.
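   
   On Scala 2.13, `Option.when` can still wrap the Java-style `hasEvaluationDataset`/`getEvaluationDataset` pair into the `Option` shape the suggestion wanted, e.g. (a sketch using the identifiers from the diff):
   
   ```scala
   val datasetOpt = Option.when(getModelSummaryAttrProto.hasEvaluationDataset) {
     MLUtils.parseRelationProto(getModelSummaryAttrProto.getEvaluationDataset, sessionHolder)
   }
   ```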



[GitHub] [spark] grundprinzip commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "grundprinzip (via GitHub)" <gi...@apache.org>.
grundprinzip commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135241667


##########
connector/connect/common/src/main/protobuf/spark/connect/base.proto:
##########
@@ -261,6 +263,9 @@ message ExecutePlanResponse {
     // Special case for executing SQL commands.
     SqlCommandResult sql_command_result = 5;
 
+    // ML command response

Review Comment:
   ```suggestion
       // ML command response.
   ```



##########
connector/connect/common/src/main/protobuf/spark/connect/ml_common.proto:
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlParams {
+  map<string, Expression.Literal> params = 1;
+  map<string, Expression.Literal> default_params = 2;
+}
+
+message MlStage {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+  StageType type = 4;
+  enum StageType {

Review Comment:
   Is this knowledge actually required on the client?



##########
connector/connect/common/src/main/protobuf/spark/connect/relations.proto:
##########
@@ -82,13 +83,50 @@ message Relation {
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
 
+    // ML relation
+    MlRelation ml_relation = 300;
+
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // relations they can add them here. During the planning the correct resolution is done.
     google.protobuf.Any extension = 998;
     Unknown unknown = 999;
   }
 }
 
+message MlRelation {
+  oneof ml_relation_type {
+    ModelTransform model_transform = 1;
+    FeatureTransform feature_transform = 2;
+    ModelAttr model_attr = 3;
+    ModelSummaryAttr model_summary_attr = 4;
+  }
+  message ModelTransform {
+    Relation input = 1;
+    int64 model_ref_id = 2;

Review Comment:
   My suggestion here is to maybe wrap the `model_ref_id` into an extra message object, so that it becomes easier to extend.
   
   ```
   message ModelRef {
     int64 id = 1;
   }
   ```
   
   That said, is there a reason the ID is numeric vs a string?



##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";

Review Comment:
   it would be great to give all of the protos more documentation once we start getting them in.



##########
connector/connect/common/src/main/protobuf/spark/connect/ml_common.proto:
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlParams {
+  map<string, Expression.Literal> params = 1;
+  map<string, Expression.Literal> default_params = 2;
+}
+
+message MlStage {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+  StageType type = 4;
+  enum StageType {
+    ESTIMATOR = 0;
+    TRANSFORMER = 1;
+  }

Review Comment:
   In proto, the first enum value should always be `UNSPECIFIED`. Please follow the style guide https://protobuf.dev/programming-guides/style/#enums



##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlEvaluator {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    FetchModelAttr fetch_model_attr = 2;
+    FetchModelSummaryAttr fetch_model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+    SaveStage save_stage = 7;
+    LoadStage load_stage = 8;
+    SaveEvaluator save_evaluator = 9;
+    LoadEvaluator load_evaluator = 10;
+    CopyModel copy_model = 11;
+    DeleteModel delete_model = 12;
+  }
+
+  message Fit {
+    MlStage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    MlEvaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   Would this work with an arbitrary model, for example one provided by Spark NLP?



[GitHub] [spark] harupy commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "harupy (via GitHub)" <gi...@apache.org>.
harupy commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136539321


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/AlgorithmRegisty.scala:
##########
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.classification.TrainingSummary
+import org.apache.spark.ml.util.MLWriter
+import org.apache.spark.sql.DataFrame
+
+object AlgorithmRegistry {
+
+  def get(name: String): Algorithm = {
+    name match {
+      case "LogisticRegression" => new LogisticRegressionAlgorithm
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+}
+
+abstract class Algorithm {
+
+  def initiateEstimator(uid: String): Estimator[_]
+
+  def getModelAttr(model: Model[_], name: String): Either[proto.MlCommandResponse, DataFrame]
+
+  def getModelSummaryAttr(
+      model: Model[_],
+      name: String,
+      datasetOpt: Option[DataFrame]): Either[proto.MlCommandResponse, DataFrame]
+
+  def loadModel(path: String): Model[_]
+
+  def loadEstimator(path: String): Estimator[_]
+
+  protected def getEstimatorWriter(estimator: Estimator[_]): MLWriter
+
+  protected def getModelWriter(model: Model[_]): MLWriter
+
+  def _save(
+      writer: MLWriter,
+      path: String,
+      overwrite: Boolean,
+      options: Map[String, String]): Unit = {
+    if (overwrite) {
+      writer.overwrite()
+    }
+    options.map { case (k, v) => writer.option(k, v) }
+    writer.save(path)
+  }
+
+  def saveModel(
+      model: Model[_],
+      path: String,
+      overwrite: Boolean,
+      options: Map[String, String]): Unit = {
+    _save(getModelWriter(model), path, overwrite, options)
+  }
+
+  def saveEstimator(
+      estimator: Estimator[_],
+      path: String,
+      overwrite: Boolean,
+      options: Map[String, String]): Unit = {
+    _save(getEstimatorWriter(estimator), path, overwrite, options)
+  }
+}
+
+class LogisticRegressionAlgorithm extends Algorithm {
+
+  override def initiateEstimator(uid: String): Estimator[_] = {
+    new ml.classification.LogisticRegression(uid)
+  }
+
+  override def loadModel(path: String): Model[_] = {
+    ml.classification.LogisticRegressionModel.load(path)
+  }
+
+  override def loadEstimator(path: String): Estimator[_] = {
+    ml.classification.LogisticRegression.load(path)
+  }
+
+  protected override def getModelWriter(model: Model[_]): MLWriter = {
+    model.asInstanceOf[ml.classification.LogisticRegressionModel].write
+  }
+
+  protected override def getEstimatorWriter(estimator: Estimator[_]): MLWriter = {
+    estimator.asInstanceOf[ml.classification.LogisticRegression].write
+  }
+
+  override def getModelAttr(
+      model: Model[_],
+      name: String): Either[proto.MlCommandResponse, DataFrame] = {
+    val lorModel = model.asInstanceOf[ml.classification.LogisticRegressionModel]
+    // TODO: hasSummary
+    name match {
+      case "hasSummary" => Left(Serializer.serialize(lorModel.hasSummary))
+      case "numClasses" => Left(Serializer.serialize(lorModel.numClasses))
+      case "numFeatures" => Left(Serializer.serialize(lorModel.numFeatures))
+      case "intercept" => Left(Serializer.serialize(lorModel.intercept))
+      case "interceptVector" => Left(Serializer.serialize(lorModel.interceptVector))
+      case "coefficients" => Left(Serializer.serialize(lorModel.coefficients))
+      case "coefficientMatrix" => Left(Serializer.serialize(lorModel.coefficientMatrix))
+      case _ =>
+        throw new IllegalArgumentException()
+    }
+  }
+
+  override def getModelSummaryAttr(
+      model: Model[_],
+      name: String,
+      datasetOpt: Option[DataFrame]): Either[proto.MlCommandResponse, DataFrame] = {
+    val lorModel = model.asInstanceOf[ml.classification.LogisticRegressionModel]
+    val summary = if (datasetOpt.isDefined) {
+      lorModel.evaluate(datasetOpt.get)
+    } else {
+      lorModel.summary
+    }

Review Comment:
   ```suggestion
       val summary = datasetOpt match {
         case Some(dataset) => lorModel.evaluate(dataset)
         case None => lorModel.summary
       }
   ```
   
   Can we use pattern matching here?



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133666221


##########
mllib/src/main/scala/org/apache/spark/ml/param/params.scala:
##########
@@ -44,8 +45,14 @@ import org.apache.spark.ml.util.Identifiable
  *                See [[ParamValidators]] for factory methods for common validation functions.
  * @tparam T param value type
  */
-class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
-  extends Serializable {
+class Param[T: ClassTag](

Review Comment:
   @zhengruifeng 
   
   I recall I tried this approach `(implicit paramValueClassTag: ClassTag[T])` before , but it make us hard to get the classTag object. So I prefer current approach.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1133771185


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlEvaluator {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    FetchModelAttr fetch_model_attr = 2;
+    FetchModelSummaryAttr fetch_model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+    SaveStage save_stage = 7;
+    LoadStage load_stage = 8;
+    SaveEvaluator save_evaluator = 9;
+    LoadEvaluator load_evaluator = 10;
+    CopyModel copy_model = 11;
+  }
+
+  message Fit {
+    MlStage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    MlEvaluator evaluator = 1;
+  }
+
+  message LoadModel {
+    string name = 1;
+    string path = 2;
+  }
+
+  message SaveModel {
+    int64 model_ref_id = 1;
+    string path = 2; // saving path
+    bool overwrite = 3;
+    map<string, string> options = 4; // saving options
+  }
+
+  message LoadStage {
+    string name = 1;
+    string path = 2;
+    MlStage.StageType type = 3;
+  }
+
+  message SaveStage {
+    MlStage stage = 1;
+    string path = 2; // saving path
+    bool overwrite = 3;
+    map<string, string> options = 4; // saving options
+  }
+
+  message LoadEvaluator {
+    string name = 1;
+    string path = 2;
+  }
+
+  message SaveEvaluator {
+    MlEvaluator evaluator = 1;
+    string path = 2; // saving path
+    bool overwrite = 3;
+    map<string, string> options = 4; // saving options
+  }
+
+  message FetchModelAttr {
+    int64 model_ref_id = 1;
+    string name = 2;
+  }
+
+  message FetchModelSummaryAttr {
+    int64 model_ref_id = 1;
+    string name = 2;
+    MlParams params = 3;
+
+    // Evaluation dataset used to compute
+    // the summary attribute.
+    // If not set, attributes are read from
+    // model.summary (i.e. the summary on the training dataset).
+    optional Relation evaluation_dataset = 4;
+  }
+
+  message CopyModel {
+    int64 model_ref_id = 1;
+  }
+}
+
+
+message MlCommandResponse {
+  oneof ml_command_response_type {
+    Expression.Literal literal = 1;
+    ModelInfo model_info = 2;
+    Vector vector = 3;

Review Comment:
   looks fine.
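   
   For anyone skimming the message shapes: a sketch of assembling a Fit command with the generated Java builders (names follow this .proto; `datasetRelation` is a placeholder for a `Relation` built elsewhere):
   
   ```scala
   import org.apache.spark.connect.proto
   
   // Sketch only: build MlCommand(Fit(estimator, dataset)) bottom-up.
   def buildFitCommand(datasetRelation: proto.Relation): proto.MlCommand = {
     val estimator = proto.MlStage.newBuilder()
       .setName("LogisticRegression")
       .setUid("logreg_example")
       .setType(proto.MlStage.StageType.ESTIMATOR)
       .build()
     val fit = proto.MlCommand.Fit.newBuilder()
       .setEstimator(estimator)
       .setDataset(datasetRelation)
       .build()
     proto.MlCommand.newBuilder().setFit(fit).build()
   }
   ```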





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1135600960


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message MlEvaluator {
+  string name = 1;
+  MlParams params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    FetchModelAttr fetch_model_attr = 2;
+    FetchModelSummaryAttr fetch_model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+    SaveStage save_stage = 7;
+    LoadStage load_stage = 8;
+    SaveEvaluator save_evaluator = 9;
+    LoadEvaluator load_evaluator = 10;
+    CopyModel copy_model = 11;
+    DeleteModel delete_model = 12;
+  }
+
+  message Fit {
+    MlStage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    MlEvaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   Btw, supporting 3rd-party estimators is risky, because on a shared cluster we will [binpack the spark workers across different customers](https://docs.google.com/document/d/1sJVjan44XagM48PEqdkg6KWctpPcz54Urf0i_-dGesA/edit?disco=AAAArl9hpF8) (according to @mengxr's explanation).
   But a 3rd-party estimator implementation might invoke RDD transformations (e.g. RDD.map) that we cannot isolate by container, so it is risky to let users run 3rd-party estimators on a shared cluster.
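   
   One possible mitigation, as a sketch only: have the server instantiate algorithms strictly through the registry, so unknown (3rd-party) names are rejected up front. `isRegistered` and the `Algorithm` type name are hypothetical here:
   
   ```scala
   // Sketch: gate fit requests on the server-side allowlist of algorithms.
   // `isRegistered` is a hypothetical helper on the existing AlgorithmRegistry.
   def resolveAlgorithm(name: String): Algorithm = {
     if (!AlgorithmRegistry.isRegistered(name)) {
       throw new IllegalArgumentException(
         s"Algorithm '$name' is not allowed on this cluster")
     }
     AlgorithmRegistry.get(name)
   }
   ```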





[GitHub] [spark] harupy commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "harupy (via GitHub)" <gi...@apache.org>.
harupy commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136535275


##########
python/pyspark/sql/connect/ml/base.py:
##########
@@ -0,0 +1,327 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod
+from typing import Optional
+
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.ml import Estimator, Model, Predictor, PredictionModel
+from pyspark.ml.wrapper import _PredictorParams
+from pyspark.ml.util import MLWritable, MLWriter, MLReadable, MLReader
+import pyspark.sql.connect.proto as pb2
+import pyspark.sql.connect.proto.ml_pb2 as ml_pb2
+import pyspark.sql.connect.proto.ml_common_pb2 as ml_common_pb2
+from pyspark.sql.connect.ml.serializer import deserialize, serialize_ml_params
+from pyspark.sql.connect import session as pyspark_session
+from pyspark.sql.connect.plan import LogicalPlan
+
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.util import HasTrainingSummary as PySparkHasTrainingSummary
+
+
+@inherit_doc
+class ClientEstimator(Estimator, metaclass=ABCMeta):
+
+    @classmethod
+    def _algo_name(cls):
+        raise NotImplementedError()
+
+    @classmethod
+    def _model_class(cls):
+        raise NotImplementedError()
+
+    def _fit(self, dataset: DataFrame) -> Model:
+        client = dataset.sparkSession.client
+        dataset_relation = dataset._plan.plan(client)
+        estimator_proto = ml_common_pb2.MlStage(
+            name=self._algo_name(),
+            params=serialize_ml_params(self, client),
+            uid=self.uid,
+            type=ml_common_pb2.MlStage.ESTIMATOR,
+        )
+        fit_command_proto = ml_pb2.MlCommand.Fit(
+            estimator=estimator_proto,
+            dataset=dataset_relation,
+        )
+        req = client._execute_plan_request_with_metadata()
+        req.plan.ml_command.fit.CopyFrom(fit_command_proto)
+
+        resp = client._execute_ml(req)
+        return deserialize(resp, client, clazz=self._model_class())
+
+
+@inherit_doc
+class ClientPredictor(Predictor, ClientEstimator, _PredictorParams, metaclass=ABCMeta):
+    pass
+
+
+@inherit_doc
+class ClientModel(Model, metaclass=ABCMeta):
+
+    ref_id: Optional[str] = None
+
+    def __del__(self):
+        client = pyspark_session._active_spark_session.client
+        del_model_proto = ml_pb2.MlCommand.DeleteModel(
+            model_ref_id=self.ref_id,
+        )
+        req = client._execute_plan_request_with_metadata()
+        req.plan.ml_command.delete_model.CopyFrom(del_model_proto)
+        client._execute_ml(req)
+
+    @classmethod

Review Comment:
   ```suggestion
       @abstractmethod
   ```
   
   
   
   Can we use `abstractmethod`?





[GitHub] [spark] harupy commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "harupy (via GitHub)" <gi...@apache.org>.
harupy commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1136754416


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala:
##########
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLHandler {
+
+  def handleMlCommand(
+      sessionHolder: SessionHolder,
+      mlCommand: proto.MlCommand): proto.MlCommandResponse = {
+    mlCommand.getMlCommandTypeCase match {
+      case proto.MlCommand.MlCommandTypeCase.FIT =>
+        val fitCommandProto = mlCommand.getFit
+        val estimatorProto = fitCommandProto.getEstimator
+        assert(estimatorProto.getType == proto.MlStage.StageType.ESTIMATOR)
+
+        val algoName = fitCommandProto.getEstimator.getName
+        val algo = AlgorithmRegistry.get(algoName)
+
+        val estimator = algo.initiateEstimator(estimatorProto.getUid)
+        MLUtils.setInstanceParams(estimator, estimatorProto.getParams)
+        val dataset = MLUtils.parseRelationProto(fitCommandProto.getDataset, sessionHolder)
+        val model = estimator.fit(dataset).asInstanceOf[Model[_]]
+        val refId = sessionHolder.mlCache.modelCache.register(model, algo)
+
+        proto.MlCommandResponse
+          .newBuilder()
+          .setModelInfo(
+            proto.MlCommandResponse.ModelInfo.newBuilder
+              .setModelRefId(refId)
+              .setModelUid(model.uid))
+          .build()
+
+      case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_ATTR =>
+        val getModelAttrProto = mlCommand.getFetchModelAttr
+        val modelEntry = sessionHolder.mlCache.modelCache.get(getModelAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        algo.getModelAttr(model, getModelAttrProto.getName).left.get
+
+      case proto.MlCommand.MlCommandTypeCase.FETCH_MODEL_SUMMARY_ATTR =>
+        val getModelSummaryAttrProto = mlCommand.getFetchModelSummaryAttr
+        val modelEntry =
+          sessionHolder.mlCache.modelCache.get(getModelSummaryAttrProto.getModelRefId)
+        val model = modelEntry._1
+        val algo = modelEntry._2
+        // Create a copied model to avoid concurrently modifying model params.
+        val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
+        MLUtils.setInstanceParams(copiedModel, getModelSummaryAttrProto.getParams)
+
+        val datasetOpt = if (getModelSummaryAttrProto.hasEvaluationDataset) {
+          val evalDF = MLUtils.parseRelationProto(
+            getModelSummaryAttrProto.getEvaluationDataset,
+            sessionHolder)
+          Some(evalDF)
+        } else None

Review Comment:
   I didn't know that, thanks for the explanation.





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1138369409


##########
python/pyspark/sql/connect/session.py:
##########
@@ -463,7 +463,7 @@ def stop(self) -> None:
 
     @classmethod
     def getActiveSession(cls) -> Any:
-        raise NotImplementedError("getActiveSession() is not implemented.")

Review Comment:
   oh. I will revert this.





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1127834899


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala:
##########
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.linalg.{Matrix, Vector}
+
+object Serializer {

Review Comment:
   ~~Good idea.~~
   We need to address this issue first: https://github.com/apache/spark/pull/40297#discussion_r1129072923





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129073881


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message Evaluator {
+  string name = 1;
+  Params params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    ModelAttr model_attr = 2;
+    ModelSummaryAttr model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+  }
+
+  message Fit {
+    Stage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    Evaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   For Evaluator, we need to define a new LoadEstimator command.





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129073121


##########
connector/connect/common/src/main/protobuf/spark/connect/ml.proto:
##########
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+import "spark/connect/relations.proto";
+import "spark/connect/ml_common.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+
+message Evaluator {
+  string name = 1;
+  Params params = 2;
+  string uid = 3;
+}
+
+
+message MlCommand {
+  oneof ml_command_type {
+    Fit fit = 1;
+    ModelAttr model_attr = 2;
+    ModelSummaryAttr model_summary_attr = 3;
+    LoadModel load_model = 4;
+    SaveModel save_model = 5;
+    Evaluate evaluate = 6;
+  }
+
+  message Fit {
+    Stage estimator = 1;
+    Relation dataset = 2;
+  }
+
+  message Evaluate {
+    Evaluator evaluator = 1;
+  }
+
+  message LoadModel {

Review Comment:
   Sure.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129171568


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)
+    }
+    paramsProto.getDefaultParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance._setDefault(paramDef -> paramValue)
+    }
+  }
+
+  def parseParamValue(paramType: Class[_], paramValueProto: proto.Expression.Literal): Any = {
+    if (paramType == classOf[Int]) {
+      assert (paramValueProto.hasInteger || paramValueProto.hasLong)
+      if (paramValueProto.hasInteger) {
+        paramValueProto.getInteger
+      } else {
+        paramValueProto.getLong
+      }
+    } else if (paramType == classOf[Long]) {
+      assert(paramValueProto.hasLong)
+      paramValueProto.getLong
+    } else if (paramType == classOf[Float]) {
+      assert (paramValueProto.hasFloat || paramValueProto.hasDouble)
+      if (paramValueProto.hasFloat) {
+        paramValueProto.getFloat
+      } else {
+        paramValueProto.getDouble

Review Comment:
   ```suggestion
           paramValueProto.getDouble.toFloat
   ```
   
   I guess we need to convert the value to float here; otherwise the value type differs from the param type.
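   
   A small helper along these lines (a sketch only) would keep the narrowing in one place:
   
   ```scala
   // Sketch: centralize literal-to-float narrowing so every call site
   // behaves the same. Literal accessors as in this PR's proto.
   private def literalToFloat(lit: proto.Expression.Literal): Float = {
     if (lit.hasFloat) {
       lit.getFloat
     } else if (lit.hasDouble) {
       lit.getDouble.toFloat
     } else {
       throw new IllegalArgumentException("expected a float or double literal")
     }
   }
   ```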





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40297: [SPARK-42412][WIP] Initial PR of Spark connect ML

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40297:
URL: https://github.com/apache/spark/pull/40297#discussion_r1129188765


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+
+object MLUtils {
+
+  def setInstanceParams(instance: Params, paramsProto: proto.Params): Unit = {
+    import scala.collection.JavaConverters._
+    paramsProto.getParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance.set(paramDef, paramValue)
+    }
+    paramsProto.getDefaultParamsMap.asScala.foreach { case (paramName, paramValueProto) =>
+      val paramDef = instance.getParam(paramName)
+      val paramValue = parseParamValue(paramDef.paramValueClassTag.runtimeClass, paramValueProto)
+      instance._setDefault(paramDef -> paramValue)
+    }
+  }
+
+  def parseParamValue(paramType: Class[_], paramValueProto: proto.Expression.Literal): Any = {
+    if (paramType == classOf[Int]) {
+      assert (paramValueProto.hasInteger || paramValueProto.hasLong)
+      if (paramValueProto.hasInteger) {
+        paramValueProto.getInteger
+      } else {
+        paramValueProto.getLong

Review Comment:
   ```suggestion
           paramValueProto.getLong.toInt
   ```
   
   Another one; I think the others are fine.


