You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") in the original page and is not preserved in this plain-text version.
Posted to dev@rocketmq.apache.org by GitBox <gi...@apache.org> on 2018/12/23 08:45:01 UTC

[GitHub] ifplusor closed pull request #19: Fixed deadlock in #18, and add args field for python callback.

ifplusor closed pull request #19: Fixed deadlock in #18, and add args field for python callback. 
URL: https://github.com/apache/rocketmq-client-python/pull/19
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/src/PythonWrapper.cpp b/src/PythonWrapper.cpp
index 8c7c1e2..04a3cf2 100644
--- a/src/PythonWrapper.cpp
+++ b/src/PythonWrapper.cpp
@@ -22,6 +22,7 @@
 #include "CPushConsumer.h"
 #include "PythonWrapper.h"
 #include <boost/python.hpp>
+#include <boost/thread.hpp>
 #include <map>
 
 using namespace boost::python;
@@ -34,20 +35,33 @@ map<CPushConsumer *, PyObject *> g_CallBackMap;
 
 class PyThreadStateLock {
 public:
-    PyThreadStateLock(void) {
+    PyThreadStateLock() {
         state = PyGILState_Ensure();
     }
 
-    ~PyThreadStateLock(void) {
-        if (state == PyGILState_LOCKED) {
-            PyGILState_Release(state);
-        }
+    ~PyThreadStateLock() {
+        // NOTE: 必须跟 PyGILState_Ensure 成对出现,否则可能出现死锁!!!
+        PyGILState_Release(state);
     }
 
 private:
     PyGILState_STATE state;
 };
 
+class PyThreadStateUnlock {
+public:
+    PyThreadStateUnlock() : _save(NULL) {
+        Py_UNBLOCK_THREADS
+    }
+
+    ~PyThreadStateUnlock() {
+        Py_BLOCK_THREADS
+    }
+
+private:
+    PyThreadState *_save;
+};
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -146,22 +160,25 @@ const char *PyGetSendResultMsgID(CSendResult &sendResult) {
 }
 //consumer
 void *PyCreatePushConsumer(const char *groupId) {
-    //Py_Initialize();
-    PyEval_InitThreads();
-//    PyEval_ReleaseThread(PyThreadState_Get());
+    PyEval_InitThreads();  // 因为从 C 中发起对 python 的回调,确保初始化对多线程的支持(主要是创建GIL)
     return (void *) CreatePushConsumer(groupId);
 }
 int PyDestroyPushConsumer(void *consumer) {
-    return DestroyPushConsumer((CPushConsumer *) consumer);
+    CPushConsumer *consumerInner = (CPushConsumer *) consumer;
+    map<CPushConsumer *, pair<PyObject *, object>>::iterator iter;
+    iter = g_CallBackMap.find(consumerInner);
+    if (iter != g_CallBackMap.end()) {
+        UnregisterMessageCallback(consumerInner);
+        g_CallBackMap.erase(iter);
+    }
+    return DestroyPushConsumer(consumerInner);
 }
 int PyStartPushConsumer(void *consumer) {
     return StartPushConsumer((CPushConsumer *) consumer);
 }
 int PyShutdownPushConsumer(void *consumer) {
-    int ret = ShutdownPushConsumer((CPushConsumer *) consumer);
-    //PyGILState_Ensure();
-    //Py_Finalize();
-    return ret;
+    PyThreadStateUnlock PyThreadUnlock;  // 存在阻塞调用,确保线程不持有 GIL
+    return ShutdownPushConsumer((CPushConsumer *) consumer);
 }
 int PySetPushConsumerNameServerAddress(void *consumer, const char *namesrv) {
     return SetPushConsumerNameServerAddress((CPushConsumer *) consumer, namesrv);
@@ -172,29 +189,28 @@ int PySetPushConsumerNameServerDomain(void *consumer, const char *domain){
 int PySubscribe(void *consumer, const char *topic, const char *expression) {
     return Subscribe((CPushConsumer *) consumer, topic, expression);
 }
-int PyRegisterMessageCallback(void *consumer, PyObject *pCallback) {
+int PyRegisterMessageCallback(void *consumer, PyObject *pCallback, object args) {
     CPushConsumer *consumerInner = (CPushConsumer *) consumer;
-    g_CallBackMap[consumerInner] = pCallback;
+    g_CallBackMap[consumerInner] = make_pair(pCallback, std::move(args));
     return RegisterMessageCallback(consumerInner, &PythonMessageCallBackInner);
 }
 
 int PythonMessageCallBackInner(CPushConsumer *consumer, CMessageExt *msg) {
-
-    class PyThreadStateLock PyThreadLock;
-    PyMessageExt message;
-    message.pMessageExt = msg;
-    map<CPushConsumer *, PyObject *>::iterator iter;
+    boost::this_thread::disable_interruption di;  // 调用 python 回调,线程不应被中断
+    PyThreadStateLock PyThreadLock;  // 调用 python 回调,确保线程持有 GIL
+    PyMessageExt message = { .pMessageExt = msg };
+    map<CPushConsumer *, pair<PyObject *, object>>::iterator iter;
     iter = g_CallBackMap.find(consumer);
     if (iter != g_CallBackMap.end()) {
-        PyObject * pCallback = iter->second;
+        pair<PyObject *, object> callback = iter->second;
+        PyObject * pCallback = callback.first;
+        object& args = callback.second;
         if (pCallback != NULL) {
-            int status =
-                    boost::python::call<int>(pCallback, message);
+            int status = boost::python::call<int>(pCallback, message, args);
             return status;
         }
     }
     return 1;
-
 }
 
 int PySetPushConsumerThreadCount(void *consumer, int threadCount) {
@@ -212,7 +228,7 @@ int PySetPushConsumerSessionCredentials(void *consumer, const char *accessKey, c
 }
 
 //push consumer
-int PySetPullConsumerNameServerDomain(void *consumer, const char *domain){
+int PySetPullConsumerNameServerDomain(void *consumer, const char *domain) {
     return SetPullConsumerNameServerDomain((CPullConsumer *) consumer, domain);
 }
 //version
diff --git a/src/PythonWrapper.h b/src/PythonWrapper.h
index 324676b..2ad8255 100644
--- a/src/PythonWrapper.h
+++ b/src/PythonWrapper.h
@@ -89,7 +89,7 @@ int PyShutdownPushConsumer(void *consumer);
 int PySetPushConsumerNameServerAddress(void *consumer, const char *namesrv);
 int PySetPushConsumerNameServerDomain(void *consumer, const char *domain);
 int PySubscribe(void *consumer, const char *topic, const char *expression);
-int PyRegisterMessageCallback(void *consumer, PyObject *pCallback);
+int PyRegisterMessageCallback(void *consumer, PyObject *pCallback, object args);
 int PythonMessageCallBackInner(CPushConsumer *consumer, CMessageExt *msg);
 int PySetPushConsumerThreadCount(void *consumer, int threadCount);
 int PySetPushConsumerMessageBatchMaxSize(void *consumer, int batchSize);


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services