You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/03/27 01:04:27 UTC

[iotdb] branch master updated: [IOTDB-5696] Implement client to connect ConfigNode/DataNode (#9365)

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new c2870c6f34 [IOTDB-5696] Implement client to connect ConfigNode/DataNode (#9365)
c2870c6f34 is described below

commit c2870c6f345389e3d6d3164b3773a8b97191f9f9
Author: Yong Liu <li...@gmail.com>
AuthorDate: Mon Mar 27 09:04:20 2023 +0800

    [IOTDB-5696] Implement client to connect ConfigNode/DataNode (#9365)
    
    Co-authored-by: zhouhang <11...@qq.com>
---
 mlnode/.gitignore                           |   6 +-
 mlnode/iotdb/mlnode/client.py               | 247 +++++++++++++++++++++++++++-
 mlnode/requirements.txt                     |   4 +
 mlnode/requirements_dev.txt                 |   7 +-
 thrift-mlnode/src/main/thrift/mlnode.thrift |   1 +
 5 files changed, 250 insertions(+), 15 deletions(-)

diff --git a/mlnode/.gitignore b/mlnode/.gitignore
index 9ba0ff6df8..94606bf62f 100644
--- a/mlnode/.gitignore
+++ b/mlnode/.gitignore
@@ -1,6 +1,4 @@
 /iotdb/thrift/
 
-# generated by Pypi
-/build/
-/dist/
-/*.egg-info/
\ No newline at end of file
+# generated by Poetry
+/dist/
\ No newline at end of file
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index 08a6b925c2..244b6975c9 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -15,12 +15,46 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from thrift.protocol import TCompactProtocol
+import time
+
+from thrift.protocol import TBinaryProtocol, TCompactProtocol
+from thrift.Thrift import TException
 from thrift.transport import TSocket, TTransport
 
+from iotdb.mlnode.config import config
 from iotdb.mlnode.log import logger
+from iotdb.thrift.common.ttypes import TEndPoint, TSStatus
+from iotdb.thrift.confignode import IConfigNodeRPCService
+from iotdb.thrift.confignode.ttypes import TUpdateModelInfoReq
+from iotdb.thrift.datanode import IDataNodeRPCService
+from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq,
+                                          TFetchTimeseriesResp,
+                                          TRecordModelMetricsReq)
 from iotdb.thrift.mlnode import IMLNodeRPCService
-from iotdb.thrift.mlnode.ttypes import TDeleteModelReq
+from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TDeleteModelReq
+
+# status code
+# TSStatus.code value that verify_success() accepts as success
+SUCCESS_STATUS = 200
+# TSStatus.code value that ConfigNodeClient treats as a redirection to another node
+REDIRECTION_RECOMMEND = 400
+
+
+def verify_success(status: TSStatus, err_msg: str) -> None:
+    """Raise if ``status`` does not carry the success code.
+
+    Args:
+        status: status returned by an RPC call.
+        err_msg: context message written to the log on failure.
+
+    Raises:
+        RuntimeError: if ``status.code`` is not ``SUCCESS_STATUS``; the
+            exception text is ``"<code>: <message>"``.
+    """
+    if status.code != SUCCESS_STATUS:
+        # values go in as lazy %-args; the message needs a placeholder for each,
+        # otherwise logging fails to format the record and the status is lost
+        logger.warning("%s, error status is %s", err_msg, status)
+        # f-string formatting also tolerates a status whose message is None
+        raise RuntimeError(f"{status.code}: {status.message}")
+
+
+class ClientManager(object):
+    """Factory for DataNode/ConfigNode clients built from the endpoints in ``config``."""
+
+    def __init__(self):
+        # target endpoints are read from the configuration once, at construction
+        self.__data_node_endpoint = config.get_mn_target_data_node()
+        self.__config_node_endpoint = config.get_mn_target_config_node()
+
+    def borrow_data_node_client(self):
+        """Return a newly constructed DataNodeClient for the configured DataNode."""
+        endpoint = self.__data_node_endpoint
+        return DataNodeClient(host=endpoint.ip, port=endpoint.port)
+
+    def borrow_config_node_client(self):
+        """Return a newly constructed ConfigNodeClient for the configured ConfigNode."""
+        leader = self.__config_node_endpoint
+        return ConfigNodeClient(config_leader=leader)
 
 
 class MLNodeClient(object):
@@ -40,12 +74,207 @@ class MLNodeClient(object):
         protocol = TCompactProtocol.TCompactProtocol(transport)
         self.__client = IMLNodeRPCService.Client(protocol)
 
-    def delete_model(self, model_path: str):
-        req = TDeleteModelReq(model_path)
-        return self.__client.deleteModel(req)
+    def create_training_task(self,
+                             model_id: str,
+                             is_auto: bool,
+                             model_configs: dict,
+                             query_expressions: list[str],
+                             query_filter: str = None) -> None:
+        """Send a createTrainingTask request and check the returned status.
+
+        Config values and query expressions are converted to ``str`` before
+        sending. Thrift exceptions propagate to the caller unchanged; a
+        non-success status raises through ``verify_success``.
+        """
+        string_configs = dict((key, str(value)) for key, value in model_configs.items())
+        string_queries = list(map(str, query_expressions))
+        req = TCreateTrainingTaskReq(
+            modelId=model_id,
+            isAuto=is_auto,
+            modelConfigs=string_configs,
+            queryExpressions=string_queries,
+            queryFilter=query_filter,
+        )
+        status = self.__client.createTrainingTask(req)
+        verify_success(status, "An error occurs when calling create_training_task()")
+
+    def create_forecast_task(self) -> None:
+        """Placeholder for forecast task creation; currently does nothing."""
+        # TODO
+        return None
+
+    def delete_model(self,
+                     model_id: str,
+                     trial_id: str = None) -> None:
+        """Send a deleteModel request and check the returned status.
+
+        Thrift exceptions propagate to the caller unchanged; a non-success
+        status raises through ``verify_success``.
+        """
+        request = TDeleteModelReq(modelId=model_id, trialId=trial_id)
+        status = self.__client.deleteModel(request)
+        verify_success(status, "An error occurs when calling delete_model()")
+
+
+class DataNodeClient(object):
+    """Thrift client for the DataNode RPC service (framed transport, binary protocol)."""
+
+    DEFAULT_FETCH_SIZE = 10000
+    DEFAULT_TIMEOUT = 60000
+
+    def __init__(self, host, port):
+        self.__host = host
+        self.__port = port
+
+        sock = TSocket.TSocket(self.__host, self.__port)
+        transport = TTransport.TFramedTransport(sock)
+        if not transport.isOpen():
+            try:
+                transport.open()
+            except TTransport.TTransportException as e:
+                # a failed open is only logged; construction still completes
+                logger.exception("TTransportException!", exc_info=e)
+
+        self.__client = IDataNodeRPCService.Client(
+            TBinaryProtocol.TBinaryProtocol(transport)
+        )
+
+    def fetch_timeseries(self,
+                         session_id: int,
+                         statement_id: int,
+                         query_expressions: list[str],
+                         query_filter: str = None,
+                         fetch_size: int = DEFAULT_FETCH_SIZE,
+                         timeout: int = DEFAULT_TIMEOUT) -> TFetchTimeseriesResp:
+        """Send a fetchTimeseries request and return the response.
+
+        Thrift exceptions propagate to the caller unchanged; a non-success
+        ``resp.status`` raises through ``verify_success``.
+        """
+        request = TFetchTimeseriesReq(
+            sessionId=session_id,
+            statementId=statement_id,
+            queryExpressions=query_expressions,
+            queryFilter=query_filter,
+            fetchSize=fetch_size,
+            timeout=timeout
+        )
+        response = self.__client.fetchTimeseries(request)
+        verify_success(response.status, "An error occurs when calling fetch_timeseries()")
+        return response
+
+    def record_model_metrics(self,
+                             model_id: str,
+                             trial_id: str,
+                             metrics: list[str],
+                             values: list[float]) -> None:
+        """Send metric names and values for a trial, stamped with the current time.
+
+        Thrift exceptions propagate to the caller unchanged; a non-success
+        status raises through ``verify_success``.
+        """
+        # wall-clock time in milliseconds since the epoch
+        now_ms = int(round(time.time() * 1000))
+        request = TRecordModelMetricsReq(
+            modelId=model_id,
+            trialId=trial_id,
+            metrics=metrics,
+            timestamp=now_ms,
+            values=values
+        )
+        status = self.__client.recordModelMetrics(request)
+        verify_success(status, "An error occurs when calling record_model_metrics()")
+
+
+class ConfigNodeClient(object):
+    """Thrift client for the ConfigNode RPC service with redirection and retry.
+
+    The client first connects to ``config_leader``; if that fails it walks the
+    known ConfigNode list round-robin. ``update_model_info`` retries a fixed
+    number of times, reconnecting between attempts.
+    """
+
+    def __init__(self, config_leader: TEndPoint):
+        self.__config_leader = config_leader
+        # fallback endpoints; nothing in this class fills the list yet
+        # (see the TODO in __sync_latest_config_node_list)
+        self.__config_nodes = []
+        self.__cursor = 0
+        self.__transport = None
+        self.__client = None
+
+        self.__MSG_RECONNECTION_FAIL = "Fail to connect to any config node. Please check status of ConfigNodes"
+        self.__RETRY_NUM = 5
+        self.__RETRY_INTERVAL_MS = 1000
+
+        try:
+            self.__try_to_connect()
+        except TException:
+            # can not connect to each config node
+            self.__sync_latest_config_node_list()
+            self.__try_to_connect()
+
+    def __try_to_connect(self) -> None:
+        """Connect to the leader if known, otherwise to the next reachable node.
+
+        Raises:
+            TException: if no endpoint could be connected.
+        """
+        if self.__config_leader is not None:
+            try:
+                self.__connect(self.__config_leader)
+                # connected: stop here instead of falling through to the failure path
+                return
+            except TException:
+                logger.warning("The current node %s may have been down, try next node", self.__config_leader)
+                self.__config_leader = None
+
+        if self.__transport is not None:
+            self.__transport.close()
+
+        # try every known node once, starting after the last cursor position
+        for _ in range(len(self.__config_nodes)):
+            self.__cursor = (self.__cursor + 1) % len(self.__config_nodes)
+            try_endpoint = self.__config_nodes[self.__cursor]
+            try:
+                self.__connect(try_endpoint)
+                return
+            except TException:
+                logger.warning("The current node %s may have been down, try next node", try_endpoint)
+
+        raise TException(self.__MSG_RECONNECTION_FAIL)
+
+    def __connect(self, target_config_node: TEndPoint) -> None:
+        """Open a framed connection to ``target_config_node`` and build the RPC client.
+
+        Raises:
+            TTransport.TTransportException: if the transport cannot be opened.
+        """
+        transport = TTransport.TFramedTransport(
+            TSocket.TSocket(target_config_node.ip, target_config_node.port)
+        )
+        if not transport.isOpen():
+            try:
+                transport.open()
+            except TTransport.TTransportException as e:
+                logger.exception("TTransportException!", exc_info=e)
+                # propagate so that __try_to_connect can fail over to another node
+                raise
+
+        # release the previous connection before replacing it
+        if self.__transport is not None:
+            self.__transport.close()
+        self.__transport = transport
+
+        # NOTE(review): assumes the ConfigNode speaks TBinaryProtocol, as
+        # DataNodeClient does - confirm against the server configuration
+        protocol = TBinaryProtocol.TBinaryProtocol(transport)
+        self.__client = IConfigNodeRPCService.Client(protocol)
+
+    def __wait_and_reconnect(self) -> None:
+        """Sleep for the retry interval, then reconnect.
+
+        Raises:
+            TException: if reconnection fails even after refreshing the node list.
+        """
+        # wait to start the next try; time.sleep() takes seconds, the interval is in ms
+        time.sleep(self.__RETRY_INTERVAL_MS / 1000)
+
+        try:
+            self.__try_to_connect()
+        except TException:
+            # can not connect to each config node
+            self.__sync_latest_config_node_list()
+            self.__try_to_connect()
+
+    def __sync_latest_config_node_list(self) -> None:
+        """Placeholder for refreshing the ConfigNode list; currently does nothing."""
+        # TODO
+        pass
+
+    def __update_config_node_leader(self, status: TSStatus) -> bool:
+        """Record the redirect target when ``status`` is a redirection.
+
+        Returns:
+            True if ``status.code`` is ``REDIRECTION_RECOMMEND``, False otherwise.
+        """
+        if status.code != REDIRECTION_RECOMMEND:
+            return False
+        # may be None, in which case the next reconnect skips the leader
+        self.__config_leader = status.redirectNode
+        return True
+
+    def update_model_info(self,
+                          model_id: str,
+                          trial_id: str,
+                          model_info: dict) -> None:
+        """Send model info to the ConfigNode, following redirections and retrying.
+
+        Args:
+            model_id: identifier of the model.
+            trial_id: identifier of the trial.
+            model_info: key/value pairs to send; values are converted to ``str``.
+                ``None`` is treated as an empty dict.
+
+        Raises:
+            RuntimeError: if the reply is neither success nor a redirection.
+            TException: if reconnecting fails or all retries are used up.
+        """
+        if model_info is None:
+            model_info = {}
+        req = TUpdateModelInfoReq(
+            modelId=model_id,
+            trialId=trial_id,
+            modelInfo={k: str(v) for k, v in model_info.items()},
+        )
+
+        for _ in range(self.__RETRY_NUM):
+            try:
+                status = self.__client.updateModelInfo(req)
+                if not self.__update_config_node_leader(status):
+                    verify_success(status, "An error occurs when calling update_model_info()")
+                    return
+            except TTransport.TException:
+                logger.warning("Failed to connect to ConfigNode %s from MLNode when executing update_model_info()",
+                               self.__config_leader)
+                self.__config_leader = None
+            # reached after a redirection or a transport failure
+            self.__wait_and_reconnect()
+
+        raise TException(self.__MSG_RECONNECTION_FAIL)
 
 
-if __name__ == "__main__":
-    # test rpc service
-    client = MLNodeClient(host="127.0.0.1", port=10810)
-    print(client.delete_model("test_model_path"))
+# module-level ClientManager instance, created when this module is imported
+client_manager = ClientManager()
diff --git a/mlnode/requirements.txt b/mlnode/requirements.txt
index 05397f0df5..edd85701ab 100644
--- a/mlnode/requirements.txt
+++ b/mlnode/requirements.txt
@@ -20,4 +20,8 @@ pandas>=1.3.5
 numpy>=1.21.4
 apache-iotdb
 poetry
+torch
 pylru
+
+thrift~=0.13.0
+dynaconf~=3.1.12
\ No newline at end of file
diff --git a/mlnode/requirements_dev.txt b/mlnode/requirements_dev.txt
index f3e9ad3cf6..75a78abfb1 100644
--- a/mlnode/requirements_dev.txt
+++ b/mlnode/requirements_dev.txt
@@ -19,7 +19,10 @@
 -r requirements.txt
 # Pytest to run tests
 
+pandas>=1.3.5
+numpy>=1.21.4
+torch
+pylru
 pytest
 thrift
-dynaconf
-torch
\ No newline at end of file
+dynaconf
\ No newline at end of file
diff --git a/thrift-mlnode/src/main/thrift/mlnode.thrift b/thrift-mlnode/src/main/thrift/mlnode.thrift
index 2210fcf9d6..916022e973 100644
--- a/thrift-mlnode/src/main/thrift/mlnode.thrift
+++ b/thrift-mlnode/src/main/thrift/mlnode.thrift
@@ -31,6 +31,7 @@ struct TCreateTrainingTaskReq {
 
 struct TDeleteModelReq {
   1: required string modelId
+  // optional trial identifier; spelled "trialId" as in the other request structs' usage
+  2: optional string trialId
 }
 
 struct TForecastReq {