You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by ja...@apache.org on 2023/07/19 01:52:46 UTC

[iotdb] branch rel/1.2 updated: [IOTDB-6069] Support SessionPool in client-py

This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch rel/1.2
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/rel/1.2 by this push:
     new 5af06886660 [IOTDB-6069] Support SessionPool in client-py
5af06886660 is described below

commit 5af068866603d217285e1bf909fda620da84b29b
Author: YangCaiyin <yc...@gmail.com>
AuthorDate: Wed Jul 19 09:52:39 2023 +0800

    [IOTDB-6069] Support SessionPool in client-py
---
 iotdb-client/client-py/iotdb/Session.py           |   1 +
 iotdb-client/client-py/iotdb/SessionPool.py       | 135 ++++++++++++++++++++++
 iotdb-client/client-py/tests/test_session.py      |  68 ++++++-----
 iotdb-client/client-py/tests/test_session_pool.py |  66 +++++++++++
 4 files changed, 244 insertions(+), 26 deletions(-)

diff --git a/iotdb-client/client-py/iotdb/Session.py b/iotdb-client/client-py/iotdb/Session.py
index 310db405679..90dcc33c08f 100644
--- a/iotdb-client/client-py/iotdb/Session.py
+++ b/iotdb-client/client-py/iotdb/Session.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+
 import logging
 import random
 import struct
diff --git a/iotdb-client/client-py/iotdb/SessionPool.py b/iotdb-client/client-py/iotdb/SessionPool.py
new file mode 100644
index 00000000000..d815e0a0218
--- /dev/null
+++ b/iotdb-client/client-py/iotdb/SessionPool.py
@@ -0,0 +1,231 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import logging
+import multiprocessing
+import time
+from queue import Empty, Queue
+from threading import Lock
+from typing import Optional
+
+from iotdb.Session import Session
+
+DEFAULT_MULTIPIE = 5
+DEFAULT_FETCH_SIZE = 5000
+DEFAULT_MAX_RETRY = 3
+DEFAULT_TIME_ZONE = "Asia/Shanghai"
+# Longest single wait on the idle-session queue in get_session(); between
+# waits the pool re-checks whether put_back() has freed a slot.
+DEFAULT_POLL_INTERVAL_IN_S = 0.1
+logger = logging.getLogger("IoTDB")
+
+
+class PoolConfig(object):
+    """Connection settings shared by every session a SessionPool creates.
+
+    :param host: server host, used when ``node_urls`` is empty.
+    :param ip: passed to ``Session`` right after ``host``, i.e. as the
+        server port (despite its name).
+    :param user_name: login user name.
+    :param password: login password.
+    :param node_urls: when non-empty, sessions are created with
+        ``Session.init_from_node_urls`` instead of ``host``/``ip``.
+    :param fetch_size: fetch size passed to every session.
+    :param time_zone: zone id passed to every session.
+    :param max_retry: stored on the config; SessionPool does not read it.
+    """
+
+    def __init__(
+        self,
+        host: str,
+        ip: str,
+        user_name: str,
+        password: str,
+        node_urls: Optional[list] = None,
+        fetch_size: int = DEFAULT_FETCH_SIZE,
+        time_zone: str = DEFAULT_TIME_ZONE,
+        max_retry: int = DEFAULT_MAX_RETRY,
+    ):
+        self.host = host
+        self.ip = ip
+        # A fresh list per instance, so configs never share node_urls.
+        self.node_urls = [] if node_urls is None else node_urls
+        self.user_name = user_name
+        self.password = password
+        self.fetch_size = fetch_size
+        self.time_zone = time_zone
+        self.max_retry = max_retry
+
+
+class SessionPool(object):
+    """Pool of at most ``max_pool_size`` IoTDB sessions shared by threads.
+
+    Sessions are created lazily by get_session(). The pool does not open the
+    sessions it creates, so callers open them before use; put_back() keeps
+    only sessions that are still open for reuse.
+    """
+
+    def __init__(
+        self, pool_config: PoolConfig, max_pool_size: int, wait_timeout_in_ms: int
+    ):
+        self.__pool_config = pool_config
+        self.__max_pool_size = max_pool_size
+        # Stored in seconds because it is compared against time.time().
+        self.__wait_timeout_in_s = wait_timeout_in_ms / 1000
+        # Sessions owned by the pool: idle ones in the queue plus borrowed ones.
+        self.__pool_size = 0
+        # queue.Queue, not multiprocessing.Queue: the pool is shared between
+        # threads of one process, while a multiprocessing queue pickles every
+        # item in a background feeder thread, and a session holding a live
+        # connection is generally not picklable.
+        self.__queue = Queue(max_pool_size)
+        self.__lock = Lock()
+        self.__closed = False
+
+    def __construct_session(self) -> Session:
+        """Create a new session from the pool config (open() is not called here)."""
+        config = self.__pool_config
+        if config.node_urls:
+            return Session.init_from_node_urls(
+                config.node_urls,
+                config.user_name,
+                config.password,
+                config.fetch_size,
+                config.time_zone,
+            )
+        return Session(
+            config.host,
+            config.ip,
+            config.user_name,
+            config.password,
+            config.fetch_size,
+            config.time_zone,
+        )
+
+    def __poll_session(self, timeout: Optional[float] = None) -> Optional[Session]:
+        """Take an idle session from the queue, or return None if there is none.
+
+        With ``timeout`` None the call never blocks; otherwise it waits up to
+        ``timeout`` seconds for another thread to put a session back.
+        """
+        try:
+            if timeout is None:
+                return self.__queue.get_nowait()
+            return self.__queue.get(timeout=timeout)
+        except Empty:
+            return None
+
+    def get_session(self) -> Session:
+        """Borrow a session from the pool.
+
+        An idle session is reused when available. Otherwise a new session is
+        created if fewer than ``max_pool_size`` sessions exist; if the pool is
+        full, the call waits until a session is put back or a slot is freed.
+
+        :raises ConnectionError: if the pool has been closed.
+        :raises TimeoutError: if nothing became available within the wait
+            timeout given to the constructor.
+        """
+        if self.__closed:
+            raise ConnectionError("SessionPool has already been closed.")
+
+        deadline = time.time() + self.__wait_timeout_in_s
+        session = self.__poll_session()
+        while session is None:
+            if self.__closed:
+                raise ConnectionError("SessionPool has already been closed.")
+            # Only the bookkeeping runs under the lock; session construction
+            # below happens outside it so other callers are not blocked.
+            with self.__lock:
+                if self.__pool_size < self.__max_pool_size:
+                    self.__pool_size += 1
+                    break
+            remaining = deadline - time.time()
+            if remaining < 0:
+                raise TimeoutError(
+                    "Wait to get session timeout in SessionPool, current pool size: {0}".format(
+                        self.__max_pool_size
+                    )
+                )
+            # Wait WITHOUT holding the lock so put_back() and other callers can
+            # proceed; wake up periodically to re-check for a freed slot.
+            session = self.__poll_session(min(remaining, DEFAULT_POLL_INTERVAL_IN_S))
+
+        if session is None:
+            # A slot was reserved above: fill it, or give it back on failure.
+            try:
+                session = self.__construct_session()
+            except Exception:
+                with self.__lock:
+                    self.__pool_size -= 1
+                raise
+        return session
+
+    def put_back(self, session: Session):
+        """Return a borrowed session to the pool.
+
+        Open sessions are queued for reuse; a session that is no longer open
+        is dropped and its slot is freed for a new session.
+
+        :raises ConnectionError: if the pool has been closed; the caller then
+            stays responsible for closing ``session``.
+        """
+        if self.__closed:
+            raise ConnectionError(
+                "SessionPool has already been closed, please close the session manually."
+            )
+
+        if session.is_open():
+            self.__queue.put(session)
+        else:
+            with self.__lock:
+                self.__pool_size -= 1
+
+    def close(self):
+        """Close the pool and every idle session it holds.
+
+        Borrowed sessions are not tracked; once the pool is closed put_back()
+        refuses them, so their holders must close them directly.
+        """
+        # Mark the pool closed before draining, so put_back() calls that start
+        # after this point are refused instead of queueing into a drained pool.
+        self.__closed = True
+        while True:
+            try:
+                session = self.__queue.get_nowait()
+            except Empty:
+                break
+            session.close()
+            with self.__lock:
+                self.__pool_size -= 1
+        logger.info("SessionPool has been closed successfully.")
+
+
+def create_session_pool(
+    pool_config: PoolConfig, max_pool_size: int, wait_timeout_in_ms: int
+) -> SessionPool:
+    """Create a SessionPool for ``pool_config``.
+
+    :param pool_config: connection settings for the pooled sessions.
+    :param max_pool_size: maximum number of sessions; a value <= 0 selects
+        ``multiprocessing.cpu_count() * DEFAULT_MULTIPIE``.
+    :param wait_timeout_in_ms: how long get_session() waits on a full pool
+        before raising TimeoutError.
+    """
+    if max_pool_size <= 0:
+        max_pool_size = multiprocessing.cpu_count() * DEFAULT_MULTIPIE
+    return SessionPool(pool_config, max_pool_size, wait_timeout_in_ms)
diff --git a/iotdb-client/client-py/tests/test_session.py b/iotdb-client/client-py/tests/test_session.py
index f611bb374a0..5118fe90658 100644
--- a/iotdb-client/client-py/tests/test_session.py
+++ b/iotdb-client/client-py/tests/test_session.py
@@ -20,6 +20,7 @@
 import numpy as np
 
 from iotdb.Session import Session
+from iotdb.SessionPool import PoolConfig, create_session_pool
 from iotdb.utils.BitMap import BitMap
 from iotdb.utils.IoTDBConstants import TSDataType, TSEncoding, Compressor
 from iotdb.utils.NumpyTablet import NumpyTablet
@@ -46,9 +47,33 @@ def print_message(message):
 
 
 def test_session():
+    session_test()
+
+
+def test_session_pool():
+    session_test(use_session_pool=True)
+
+
+def session_test(use_session_pool=False):
+    """Run the end-to-end session checks, optionally using a SessionPool session."""
     with IoTDBContainer("iotdb:dev") as db:
         db: IoTDBContainer
-        session = Session(db.get_container_host_ip(), db.get_exposed_port(6667))
+
+        if use_session_pool:
+            pool_config = PoolConfig(
+                host=db.get_container_host_ip(),
+                ip=db.get_exposed_port(6667),
+                user_name="root",
+                password="root",
+                node_urls=None,
+                fetch_size=1024,
+                time_zone="Asia/Shanghai",
+                max_retry=3,
+            )
+            session_pool = create_session_pool(pool_config, 1, 3000)
+            session = session_pool.get_session()
+        else:
+            session = Session(db.get_container_host_ip(), db.get_exposed_port(6667))
         session.open(False)
 
         if not session.is_open():
@@ -303,7 +328,7 @@ def test_session():
         ]
         np_timestamps_ = np.array([30, 31, 32, 33], np.dtype(">i8"))
         np_bitmaps_ = []
-        for i in range(len(measurements_)):
+        for _ in range(len(measurements_)):
             np_bitmaps_.append(BitMap(len(np_timestamps_)))
         np_bitmaps_[0].mark(0)
         np_bitmaps_[1].mark(1)
diff --git a/iotdb-client/client-py/tests/test_session_pool.py b/iotdb-client/client-py/tests/test_session_pool.py
new file mode 100644
index 00000000000..5efbfe45b81
--- /dev/null
+++ b/iotdb-client/client-py/tests/test_session_pool.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from threading import Thread
+
+from iotdb.IoTDBContainer import IoTDBContainer
+from iotdb.SessionPool import create_session_pool, PoolConfig
+
+
+def test_session_pool():
+    """Exercise SessionPool limits, timeout, hand-back and close behaviour."""
+    with IoTDBContainer("iotdb:dev") as db:
+        db: IoTDBContainer
+        max_pool_size = 2
+        pool_config = PoolConfig(
+            host=db.get_container_host_ip(),
+            ip=db.get_exposed_port(6667),
+            user_name="root",
+            password="root",
+            node_urls=[],
+            fetch_size=1024,
+            time_zone="Asia/Shanghai",
+            max_retry=3,
+        )
+        session_pool = create_session_pool(pool_config, max_pool_size, 3000)
+        session = session_pool.get_session()
+        session.open(False)
+        assert session.is_open() is True
+
+        session2 = session_pool.get_session()
+        assert session2 is not None
+
+        # The pool is now full, so a third request must time out.
+        timeout = False
+        try:
+            session_pool.get_session()
+        except TimeoutError as e:
+            timeout = True
+            assert str(e) == (
+                "Wait to get session timeout in SessionPool, current pool size: "
+                + str(max_pool_size)
+            )
+        assert timeout is True
+
+        # session2 was never opened, so putting it back frees its slot.
+        put_back_thread = Thread(target=lambda: session_pool.put_back(session2))
+        put_back_thread.start()
+        session3 = session_pool.get_session()
+        put_back_thread.join()
+        assert session3 is not None
+
+        session_pool.close()
+
+        is_closed = False
+        try:
+            session_pool.get_session()
+        except ConnectionError as e:
+            is_closed = True
+            assert str(e) == "SessionPool has already been closed."
+        assert is_closed is True
+
+        is_closed = False
+        try:
+            session_pool.put_back(session3)
+        except ConnectionError as e:
+            is_closed = True
+            assert str(e) == (
+                "SessionPool has already been closed, please close the session manually."
+            )
+        assert is_closed is True
+
+        # `session` was opened but never returned to the pool; close it here.
+        session.close()