You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by hx...@apache.org on 2022/06/27 03:53:52 UTC

[flink-ml] 02/02: [FLINK-28237][python][ml] Fix the package error in flink ml python

This is an automated email from the ASF dual-hosted git repository.

hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 1f94a6f4017647d9f12492317e22b6246c12f0be
Author: huangxingbo <hx...@apache.org>
AuthorDate: Fri Jun 24 11:43:25 2022 +0800

    [FLINK-28237][python][ml] Fix the package error in flink ml python
---
 flink-ml-python/pyflink/ml/__init__.py | 61 ++++++++++++++++++++++++++--------
 flink-ml-python/setup.py               |  5 +++
 2 files changed, 52 insertions(+), 14 deletions(-)

diff --git a/flink-ml-python/pyflink/ml/__init__.py b/flink-ml-python/pyflink/ml/__init__.py
index 07037b4..5bcf4b1 100644
--- a/flink-ml-python/pyflink/ml/__init__.py
+++ b/flink-ml-python/pyflink/ml/__init__.py
@@ -15,6 +15,7 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 ################################################################################
+from py4j.java_gateway import JavaClass, get_java_class, JavaObject
 from pyflink.java_gateway import get_gateway
 from pyflink.util import java_utils
 from pyflink.util.java_utils import to_jarray, load_java_class
@@ -40,22 +41,54 @@ def add_jars_to_context_class_loader(jar_urls):
     if all([url.toString() in existing_urls for url in jar_urls]):
         # if urls all existed, no need to create new class loader.
         return
+
     URLClassLoaderClass = load_java_class("java.net.URLClassLoader")
-    addURL = URLClassLoaderClass.getDeclaredMethod(
-        "addURL",
-        to_jarray(
-            gateway.jvm.Class,
-            [load_java_class("java.net.URL")]))
-    addURL.setAccessible(True)
-    if class_loader_name == "org.apache.flink.runtime.execution.librarycache." \
-                            "FlinkUserCodeClassLoaders$SafetyNetWrapperClassLoader":
-        ensureInner = context_classloader.getClass().getDeclaredMethod("ensureInner", None)
-        ensureInner.setAccessible(True)
-        loader = ensureInner.invoke(context_classloader, None)
+    if is_instance_of(context_classloader, URLClassLoaderClass):
+        if class_loader_name == "org.apache.flink.runtime.execution.librarycache." \
+                                "FlinkUserCodeClassLoaders$SafetyNetWrapperClassLoader":
+            ensureInner = context_classloader.getClass().getDeclaredMethod("ensureInner", None)
+            ensureInner.setAccessible(True)
+            context_classloader = ensureInner.invoke(context_classloader, None)
+
+        addURL = URLClassLoaderClass.getDeclaredMethod(
+            "addURL",
+            to_jarray(
+                gateway.jvm.Class,
+                [load_java_class("java.net.URL")]))
+        addURL.setAccessible(True)
+
+        for url in jar_urls:
+            addURL.invoke(context_classloader, to_jarray(get_gateway().jvm.Object, [url]))
+
+    else:
+        context_classloader = create_url_class_loader(jar_urls, context_classloader)
+        gateway.jvm.Thread.currentThread().setContextClassLoader(context_classloader)
+
+
+def is_instance_of(java_object, java_class):  # JVM-side instanceof check via py4j TypeUtil
+    gateway = get_gateway()
+    if isinstance(java_class, str):  # presumably a fully-qualified class name; passed as-is
+        param = java_class
+    elif isinstance(java_class, JavaClass):  # py4j class proxy -> java.lang.Class
+        param = get_java_class(java_class)
+    elif isinstance(java_class, JavaObject):  # a Class object or any other instance
+        if not is_instance_of(java_class, gateway.jvm.Class):  # else use runtime class
+            param = java_class.getClass()
+        else:
+            param = java_class
     else:
-        loader = context_classloader
-    for url in jar_urls:
-        addURL.invoke(loader, to_jarray(get_gateway().jvm.Object, [url]))
+        raise TypeError(
+            "java_class must be a string, a JavaClass, or a JavaObject")
+
+    return gateway.jvm.org.apache.flink.api.python.shaded.py4j.reflection.TypeUtil.isInstanceOf(
+        param, java_object)  # py4j copy shaded under org.apache.flink.api.python
+
+
+def create_url_class_loader(urls, parent_class_loader):  # new URLClassLoader over urls
+    gateway = get_gateway()
+    url_class_loader = gateway.jvm.java.net.URLClassLoader(
+        to_jarray(gateway.jvm.java.net.URL, urls), parent_class_loader)  # URL[] + parent loader
+    return url_class_loader
 
 
 java_utils.add_jars_to_context_class_loader = add_jars_to_context_class_loader  # override pyflink's helper
diff --git a/flink-ml-python/setup.py b/flink-ml-python/setup.py
index 55b617e..e584904 100644
--- a/flink-ml-python/setup.py
+++ b/flink-ml-python/setup.py
@@ -82,6 +82,11 @@ try:
                 'pyflink.ml',
                 'pyflink.ml.core',
                 'pyflink.ml.lib',
+                'pyflink.ml.lib.classification',
+                'pyflink.ml.lib.clustering',
+                'pyflink.ml.lib.evaluation',
+                'pyflink.ml.lib.feature',
+                'pyflink.ml.lib',
                 'pyflink.ml.util',
                 'pyflink.examples']