You are viewing a plain text version of this content. The canonical version is on the original mailing-list archive page (the hyperlink to it was not preserved in this plain-text copy).
Posted to commits@ignite.apache.org by iv...@apache.org on 2021/07/22 06:19:56 UTC
[ignite-python-thin-client] branch master updated: IGNITE-15118
Implement handshake timeout - Fixes #47.
This is an automated email from the ASF dual-hosted git repository.
ivandasch pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite-python-thin-client.git
The following commit(s) were added to refs/heads/master by this push:
new de07126 IGNITE-15118 Implement handshake timeout - Fixes #47.
de07126 is described below
commit de07126cc4af51a04c12f6033609755a92da6d53
Author: Ivan Daschinsky <iv...@apache.org>
AuthorDate: Thu Jul 22 09:18:47 2021 +0300
IGNITE-15118 Implement handshake timeout - Fixes #47.
---
.travis.yml | 2 +-
examples/transactions.py | 6 +
pyignite/aio_client.py | 35 +++++-
pyignite/client.py | 39 +++++-
pyignite/connection/aio_connection.py | 48 ++++----
pyignite/connection/connection.py | 59 +++++----
pyignite/connection/protocol_context.py | 3 +
pyignite/transaction.py | 2 +-
tests/common/test_query_listener.py | 18 +--
tests/conftest.py | 5 -
tests/custom/test_cluster.py | 2 +-
tests/custom/test_connection_events.py | 31 +++--
tests/custom/test_handshake_timeout.py | 212 ++++++++++++++++++++++++++++++++
tests/security/test_auth.py | 3 +-
tox.ini | 7 ++
15 files changed, 394 insertions(+), 78 deletions(-)
diff --git a/.travis.yml b/.travis.yml
index 74909b8..2cd3e2b 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -51,4 +51,4 @@ jobs:
env: TOXENV=py39
install: pip install tox
-script: tox
\ No newline at end of file
+script: tox
diff --git a/examples/transactions.py b/examples/transactions.py
index 53da05f..53e9c30 100644
--- a/examples/transactions.py
+++ b/examples/transactions.py
@@ -130,6 +130,12 @@ def sync_example():
if __name__ == '__main__':
+ client = Client()
+ with client.connect('127.0.0.1', 10800):
+ if not client.protocol_context.is_transactions_supported():
+ print("'Transactions' API is not supported by cluster. Finishing...")
+ exit(0)
+
print("Starting sync example")
sync_example()
diff --git a/pyignite/aio_client.py b/pyignite/aio_client.py
index 083c964..b6ded74 100644
--- a/pyignite/aio_client.py
+++ b/pyignite/aio_client.py
@@ -65,6 +65,9 @@ class AioClient(BaseClient):
"""
Initialize client.
+ For the use of the SSL-related parameters see
+ https://docs.python.org/3/library/ssl.html#ssl-certificates.
+
:param compact_footer: (optional) use compact (True, recommended) or
full (False) schema approach when serializing Complex objects.
Default is to use the same approach the server is using (None).
@@ -73,7 +76,37 @@ class AioClient(BaseClient):
:param partition_aware: (optional) try to calculate the exact data
placement from the key before to issue the key operation to the
server node, `True` by default,
- :param event_listeners: (optional) event listeners.
+ :param event_listeners: (optional) event listeners,
+ :param handshake_timeout: (optional) sets timeout (in seconds) for performing handshake (connection)
+ with node. Default is 10.0 seconds,
+ :param use_ssl: (optional) set to True if Ignite server uses SSL
+ on its binary connector. Defaults to use SSL when username
+ and password has been supplied, not to use SSL otherwise,
+ :param ssl_version: (optional) SSL version constant from standard
+ `ssl` module. Defaults to TLS v1.2,
+ :param ssl_ciphers: (optional) ciphers to use. If not provided,
+ `ssl` default ciphers are used,
+ :param ssl_cert_reqs: (optional) determines how the remote side
+ certificate is treated:
+
+ * `ssl.CERT_NONE` − remote certificate is ignored (default),
+ * `ssl.CERT_OPTIONAL` − remote certificate will be validated,
+ if provided,
+ * `ssl.CERT_REQUIRED` − valid remote certificate is required,
+
+ :param ssl_keyfile: (optional) a path to SSL key file to identify
+ local (client) party,
+ :param ssl_keyfile_password: (optional) password for SSL key file,
+ can be provided when key file is encrypted to prevent OpenSSL
+ password prompt,
+ :param ssl_certfile: (optional) a path to ssl certificate file
+ to identify local (client) party,
+ :param ssl_ca_certfile: (optional) a path to a trusted certificate
+ or a certificate chain. Required to check the validity of the remote
+ (server-side) certificate,
+ :param username: (optional) user name to authenticate to Ignite
+ cluster,
+ :param password: (optional) password to authenticate to Ignite cluster.
"""
super().__init__(compact_footer, partition_aware, event_listeners, **kwargs)
self._registry_mux = asyncio.Lock()
diff --git a/pyignite/client.py b/pyignite/client.py
index e3dd71b..397c52e 100644
--- a/pyignite/client.py
+++ b/pyignite/client.py
@@ -346,6 +346,9 @@ class Client(BaseClient):
"""
Initialize client.
+ For the use of the SSL-related parameters see
+ https://docs.python.org/3/library/ssl.html#ssl-certificates.
+
:param compact_footer: (optional) use compact (True, recommended) or
full (False) schema approach when serializing Complex objects.
Default is to use the same approach the server is using (None).
@@ -354,7 +357,41 @@ class Client(BaseClient):
:param partition_aware: (optional) try to calculate the exact data
placement from the key before to issue the key operation to the
server node, `True` by default,
- :param event_listeners: (optional) event listeners.
+ :param event_listeners: (optional) event listeners,
+ :param timeout: (optional) sets timeout (in seconds) for each socket
+ operation including `connect`. 0 means non-blocking mode, which is
+ virtually guaranteed to fail. Can accept integer or float value.
+ Default is None (blocking mode),
+ :param handshake_timeout: (optional) sets timeout (in seconds) for performing handshake (connection)
+ with node. Default is 10.0 seconds,
+ :param use_ssl: (optional) set to True if Ignite server uses SSL
+ on its binary connector. Defaults to use SSL when username
+ and password has been supplied, not to use SSL otherwise,
+ :param ssl_version: (optional) SSL version constant from standard
+ `ssl` module. Defaults to TLS v1.2,
+ :param ssl_ciphers: (optional) ciphers to use. If not provided,
+ `ssl` default ciphers are used,
+ :param ssl_cert_reqs: (optional) determines how the remote side
+ certificate is treated:
+
+ * `ssl.CERT_NONE` − remote certificate is ignored (default),
+ * `ssl.CERT_OPTIONAL` − remote certificate will be validated,
+ if provided,
+ * `ssl.CERT_REQUIRED` − valid remote certificate is required,
+
+ :param ssl_keyfile: (optional) a path to SSL key file to identify
+ local (client) party,
+ :param ssl_keyfile_password: (optional) password for SSL key file,
+ can be provided when key file is encrypted to prevent OpenSSL
+ password prompt,
+ :param ssl_certfile: (optional) a path to ssl certificate file
+ to identify local (client) party,
+ :param ssl_ca_certfile: (optional) a path to a trusted certificate
+ or a certificate chain. Required to check the validity of the remote
+ (server-side) certificate,
+ :param username: (optional) user name to authenticate to Ignite
+ cluster,
+ :param password: (optional) password to authenticate to Ignite cluster.
"""
super().__init__(compact_footer, partition_aware, event_listeners, **kwargs)
diff --git a/pyignite/connection/aio_connection.py b/pyignite/connection/aio_connection.py
index 89de49d..4d13d6e 100644
--- a/pyignite/connection/aio_connection.py
+++ b/pyignite/connection/aio_connection.py
@@ -118,11 +118,13 @@ class AioConnection(BaseConnection):
:param client: Ignite client object,
:param host: Ignite server node's host name or IP,
:param port: Ignite server node's port number,
+ :param handshake_timeout: (optional) sets timeout (in seconds) for performing handshake (connection)
+ with node. Default is 10.0 seconds,
:param use_ssl: (optional) set to True if Ignite server uses SSL
on its binary connector. Defaults to use SSL when username
and password has been supplied, not to use SSL otherwise,
:param ssl_version: (optional) SSL version constant from standard
- `ssl` module. Defaults to TLS v1.1, as in Ignite 2.5,
+ `ssl` module. Defaults to TLS v1.2,
:param ssl_ciphers: (optional) ciphers to use. If not provided,
`ssl` default ciphers are used,
:param ssl_cert_reqs: (optional) determines how the remote side
@@ -165,7 +167,6 @@ class AioConnection(BaseConnection):
"""
if self.alive:
return
- self._closed = False
await self._connect()
async def _connect(self):
@@ -176,27 +177,28 @@ class AioConnection(BaseConnection):
detecting_protocol = True
self.client.protocol_context = ProtocolContext(max(PROTOCOLS), BitmaskFeature.all_supported())
- try:
- self._on_handshake_start()
- result = await self._connect_version()
- except HandshakeError as e:
- if e.expected_version in PROTOCOLS:
- self.client.protocol_context.version = e.expected_version
+ while True:
+ try:
+ self._on_handshake_start()
result = await self._connect_version()
- else:
+ self._on_handshake_success(result)
+ return
+ except HandshakeError as e:
+ if e.expected_version in PROTOCOLS:
+ self.client.protocol_context.version = e.expected_version
+ continue
+ else:
+ self._on_handshake_fail(e)
+ raise e
+ except AuthenticationError as e:
self._on_handshake_fail(e)
raise e
- except AuthenticationError as e:
- self._on_handshake_fail(e)
- raise e
- except Exception as e:
- self._on_handshake_fail(e)
- # restore undefined protocol version
- if detecting_protocol:
- self.client.protocol_context = None
- raise e
-
- self._on_handshake_success(result)
+ except Exception as e:
+ self._on_handshake_fail(e)
+ # restore undefined protocol version
+ if detecting_protocol:
+ self.client.protocol_context = None
+ raise e
def process_connection_lost(self, err, reconnect=False):
self.failed = True
@@ -225,9 +227,13 @@ class AioConnection(BaseConnection):
ssl_context = create_ssl_context(self.ssl_params)
handshake_fut = self._loop.create_future()
+ self._closed = False
self._transport, _ = await self._loop.create_connection(lambda: BaseProtocol(self, handshake_fut),
host=self.host, port=self.port, ssl=ssl_context)
- hs_response = await handshake_fut
+ try:
+ hs_response = await asyncio.wait_for(handshake_fut, self.handshake_timeout)
+ except asyncio.TimeoutError:
+ raise ConnectionError('timed out')
if hs_response.op_code == 0:
await self.close()
diff --git a/pyignite/connection/connection.py b/pyignite/connection/connection.py
index 2b9970a..98ba7e0 100644
--- a/pyignite/connection/connection.py
+++ b/pyignite/connection/connection.py
@@ -19,7 +19,7 @@ import socket
from typing import Union
from pyignite.constants import PROTOCOLS, IGNITE_DEFAULT_HOST, IGNITE_DEFAULT_PORT, PROTOCOL_BYTE_ORDER
-from pyignite.exceptions import HandshakeError, SocketError, connection_errors, AuthenticationError
+from pyignite.exceptions import HandshakeError, SocketError, connection_errors, AuthenticationError, ParameterError
from .bitmask_feature import BitmaskFeature
from .handshake import HandshakeRequest, HandshakeResponse
@@ -34,14 +34,18 @@ logger = logging.getLogger('.'.join(__name__.split('.')[:-1]))
class BaseConnection:
def __init__(self, client, host: str = None, port: int = None, username: str = None, password: str = None,
- **ssl_params):
+ handshake_timeout: float = 10.0, **ssl_params):
self.client = client
+ self.handshake_timeout = handshake_timeout
self.host = host if host else IGNITE_DEFAULT_HOST
self.port = port if port else IGNITE_DEFAULT_PORT
self.username = username
self.password = password
self.uuid = None
+ if handshake_timeout <= 0.0:
+ raise ParameterError("handshake_timeout should be positive")
+
check_ssl_params(ssl_params)
if self.username and self.password and 'use_ssl' not in ssl_params:
@@ -162,8 +166,9 @@ class Connection(BaseConnection):
* binary protocol connector. Encapsulates handshake and failover reconnection.
"""
- def __init__(self, client: 'Client', host: str, port: int, timeout: float = None,
- username: str = None, password: str = None, **ssl_params):
+ def __init__(self, client: 'Client', host: str, port: int, username: str = None, password: str = None,
+ timeout: float = None, handshake_timeout: float = 10.0,
+ **ssl_params):
"""
Initialize connection.
@@ -177,11 +182,13 @@ class Connection(BaseConnection):
operation including `connect`. 0 means non-blocking mode, which is
virtually guaranteed to fail. Can accept integer or float value.
Default is None (blocking mode),
+ :param handshake_timeout: (optional) sets timeout (in seconds) for performing handshake (connection)
+ with node. Default is 10.0.
:param use_ssl: (optional) set to True if Ignite server uses SSL
on its binary connector. Defaults to use SSL when username
and password has been supplied, not to use SSL otherwise,
:param ssl_version: (optional) SSL version constant from standard
- `ssl` module. Defaults to TLS v1.1, as in Ignite 2.5,
+ `ssl` module. Defaults to TLS v1.2,
:param ssl_ciphers: (optional) ciphers to use. If not provided,
`ssl` default ciphers are used,
:param ssl_cert_reqs: (optional) determines how the remote side
@@ -206,7 +213,7 @@ class Connection(BaseConnection):
cluster,
:param password: (optional) password to authenticate to Ignite cluster.
"""
- super().__init__(client, host, port, username, password, **ssl_params)
+ super().__init__(client, host, port, username, password, handshake_timeout, **ssl_params)
self.timeout = timeout
self._socket = None
@@ -225,27 +232,29 @@ class Connection(BaseConnection):
detecting_protocol = True
self.client.protocol_context = ProtocolContext(max(PROTOCOLS), BitmaskFeature.all_supported())
- try:
- self._on_handshake_start()
- result = self._connect_version()
- except HandshakeError as e:
- if e.expected_version in PROTOCOLS:
- self.client.protocol_context.version = e.expected_version
+ while True:
+ try:
+ self._on_handshake_start()
result = self._connect_version()
- else:
+ self._socket.settimeout(self.timeout)
+ self._on_handshake_success(result)
+ return
+ except HandshakeError as e:
+ if e.expected_version in PROTOCOLS:
+ self.client.protocol_context.version = e.expected_version
+ continue
+ else:
+ self._on_handshake_fail(e)
+ raise e
+ except AuthenticationError as e:
self._on_handshake_fail(e)
raise e
- except AuthenticationError as e:
- self._on_handshake_fail(e)
- raise e
- except Exception as e:
- self._on_handshake_fail(e)
- # restore undefined protocol version
- if detecting_protocol:
- self.client.protocol_context = None
- raise e
-
- self._on_handshake_success(result)
+ except Exception as e:
+ self._on_handshake_fail(e)
+ # restore undefined protocol version
+ if detecting_protocol:
+ self.client.protocol_context = None
+ raise e
def _connect_version(self) -> Union[dict, OrderedDict]:
"""
@@ -254,7 +263,7 @@ class Connection(BaseConnection):
"""
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self._socket.settimeout(self.timeout)
+ self._socket.settimeout(self.handshake_timeout)
self._socket = wrap(self._socket, self.ssl_params)
self._socket.connect((self.host, self.port))
diff --git a/pyignite/connection/protocol_context.py b/pyignite/connection/protocol_context.py
index ba6d9e4..f60d45b 100644
--- a/pyignite/connection/protocol_context.py
+++ b/pyignite/connection/protocol_context.py
@@ -40,6 +40,9 @@ class ProtocolContext:
def __str__(self):
return f'ProtocolContext(version={self._version}, features={self._features})'
+ def __repr__(self):
+ return self.__str__()
+
def _ensure_consistency(self):
if not self.is_feature_flags_supported():
self._features = None
diff --git a/pyignite/transaction.py b/pyignite/transaction.py
index eb77f8d..3003eb6 100644
--- a/pyignite/transaction.py
+++ b/pyignite/transaction.py
@@ -23,7 +23,7 @@ from pyignite.utils import status_to_exception
def _validate_int_enum_param(value: Union[int, IntEnum], cls: Type[IntEnum]):
- if value not in cls:
+ if value not in set(v.value for v in cls): # Use this trick to disable warning on python 3.7
raise ValueError(f'{value} not in {cls}')
return value
diff --git a/tests/common/test_query_listener.py b/tests/common/test_query_listener.py
index afff542..8310117 100644
--- a/tests/common/test_query_listener.py
+++ b/tests/common/test_query_listener.py
@@ -17,7 +17,7 @@ import pytest
from pyignite import Client, AioClient
from pyignite.exceptions import CacheError
from pyignite.monitoring import QueryEventListener, QueryStartEvent, QueryFailEvent, QuerySuccessEvent
-from pyignite.queries.op_codes import OP_CACHE_PUT, OP_CACHE_PARTITIONS, OP_CLUSTER_GET_STATE
+from pyignite.queries.op_codes import OP_CACHE_PUT, OP_CACHE_PARTITIONS, OP_CACHE_GET_NAMES
events = []
@@ -93,17 +93,17 @@ def __assert_fail_events(client):
assert ev.port == conn.port
assert ev.node_uuid == str(conn.uuid if conn.uuid else '')
assert 'Cache does not exist' in ev.err_msg
- assert ev.duration > 0
+ assert ev.duration >= 0
def test_query_success_events(client):
- client.get_cluster().get_state()
+ client.get_cache_names()
__assert_success_events(client)
@pytest.mark.asyncio
async def test_query_success_events_async(async_client):
- await async_client.get_cluster().get_state()
+ await async_client.get_cache_names()
__assert_success_events(async_client)
@@ -112,16 +112,16 @@ def __assert_success_events(client):
conn = client._nodes[0]
for ev in events:
if isinstance(ev, QueryStartEvent):
- assert ev.op_code == OP_CLUSTER_GET_STATE
- assert ev.op_name == 'OP_CLUSTER_GET_STATE'
+ assert ev.op_code == OP_CACHE_GET_NAMES
+ assert ev.op_name == 'OP_CACHE_GET_NAMES'
assert ev.host == conn.host
assert ev.port == conn.port
assert ev.node_uuid == str(conn.uuid if conn.uuid else '')
if isinstance(ev, QuerySuccessEvent):
- assert ev.op_code == OP_CLUSTER_GET_STATE
- assert ev.op_name == 'OP_CLUSTER_GET_STATE'
+ assert ev.op_code == OP_CACHE_GET_NAMES
+ assert ev.op_name == 'OP_CACHE_GET_NAMES'
assert ev.host == conn.host
assert ev.port == conn.port
assert ev.node_uuid == str(conn.uuid if conn.uuid else '')
- assert ev.duration > 0
+ assert ev.duration >= 0
diff --git a/tests/conftest.py b/tests/conftest.py
index 70995a2..6f92f0c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -20,11 +20,6 @@ import pytest
logger = logging.getLogger('pyignite')
logger.setLevel(logging.DEBUG)
-handler = logging.StreamHandler(stream=sys.stdout)
-handler.setFormatter(
- logging.Formatter('%(asctime)s %(name)s %(levelname)s %(message)s')
-)
-logger.addHandler(handler)
@pytest.fixture(autouse=True)
diff --git a/tests/custom/test_cluster.py b/tests/custom/test_cluster.py
index e94853a..ae83ecd 100644
--- a/tests/custom/test_cluster.py
+++ b/tests/custom/test_cluster.py
@@ -49,7 +49,7 @@ def cluster_api_supported(request, server1):
client = Client()
with client.connect('127.0.0.1', 10801):
if not client.protocol_context.is_cluster_api_supported():
- pytest.skip(f'skipped {request.node.name}, ExpiryPolicy APIis not supported.')
+ pytest.skip(f'skipped {request.node.name}, Cluster API is not supported.')
def test_cluster_set_active(with_persistence):
diff --git a/tests/custom/test_connection_events.py b/tests/custom/test_connection_events.py
index bee9395..f49ad61 100644
--- a/tests/custom/test_connection_events.py
+++ b/tests/custom/test_connection_events.py
@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import random
-
import pytest
from pyignite import Client, AioClient
+from pyignite.datatypes.cache_config import CacheMode
+from pyignite.datatypes.prop_codes import PROP_NAME, PROP_CACHE_MODE
from pyignite.monitoring import ConnectionEventListener, ConnectionLostEvent, ConnectionClosedEvent, \
HandshakeSuccessEvent, HandshakeFailedEvent, HandshakeStartEvent
@@ -65,12 +65,16 @@ def test_events(request, server2):
with client.connect([('127.0.0.1', 10800 + idx) for idx in range(1, 3)]):
protocol_context = client.protocol_context
nodes = {conn.port: conn for conn in client._nodes}
- cache = client.get_or_create_cache(request.node.name)
+ cache = client.get_or_create_cache({
+ PROP_NAME: request.node.name,
+ PROP_CACHE_MODE: CacheMode.REPLICATED,
+ })
+
kill_process_tree(server2.pid)
- while True:
+ for _ in range(0, 100):
try:
- cache.put(random.randint(0, 1000), 1)
+ cache.put(1, 1)
except: # noqa 13
pass
@@ -86,12 +90,15 @@ async def test_events_async(request, server2):
async with client.connect([('127.0.0.1', 10800 + idx) for idx in range(1, 3)]):
protocol_context = client.protocol_context
nodes = {conn.port: conn for conn in client._nodes}
- cache = await client.get_or_create_cache(request.node.name)
+ cache = await client.get_or_create_cache({
+ PROP_NAME: request.node.name,
+ PROP_CACHE_MODE: CacheMode.REPLICATED,
+ })
kill_process_tree(server2.pid)
- while True:
+ for _ in range(0, 100):
try:
- await cache.put(random.randint(0, 1000), 1)
+ await cache.put(1, 1)
except: # noqa 13
pass
@@ -104,7 +111,7 @@ async def test_events_async(request, server2):
def __assert_events(nodes, protocol_context):
assert len([e for e in events if isinstance(e, ConnectionLostEvent)]) == 1
# ConnectionLostEvent is a subclass of ConnectionClosedEvent
- assert len([e for e in events if type(e) == ConnectionClosedEvent]) == 1
+ assert 1 <= len([e for e in events if type(e) == ConnectionClosedEvent and e.node_uuid]) <= 2
assert len([e for e in events if isinstance(e, HandshakeSuccessEvent)]) == 2
for ev in events:
@@ -114,7 +121,6 @@ def __assert_events(nodes, protocol_context):
assert ev.node_uuid == str(nodes[ev.port].uuid)
assert ev.error_msg
elif isinstance(ev, HandshakeStartEvent):
- assert ev.protocol_context == protocol_context
assert ev.port in {10801, 10802}
elif isinstance(ev, HandshakeFailedEvent):
assert ev.port == 10802
@@ -125,5 +131,6 @@ def __assert_events(nodes, protocol_context):
assert ev.node_uuid == str(nodes[ev.port].uuid)
assert ev.protocol_context == protocol_context
elif isinstance(ev, ConnectionClosedEvent):
- assert ev.port == 10801
- assert ev.node_uuid == str(nodes[ev.port].uuid)
+ assert ev.port in {10801, 10802}
+ if ev.node_uuid: # Possible if protocol negotiation occurred.
+ assert ev.node_uuid == str(nodes[ev.port].uuid)
diff --git a/tests/custom/test_handshake_timeout.py b/tests/custom/test_handshake_timeout.py
new file mode 100644
index 0000000..bae184d
--- /dev/null
+++ b/tests/custom/test_handshake_timeout.py
@@ -0,0 +1,212 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+import socket
+import struct
+import time
+from concurrent.futures import ThreadPoolExecutor
+
+import pytest
+
+from pyignite import Client, AioClient
+from pyignite import monitoring
+from pyignite.exceptions import ReconnectError, ParameterError
+from pyignite.monitoring import HandshakeFailedEvent
+
+logger = logging.getLogger('fake_ignite')
+logger.setLevel(logging.DEBUG)
+
+DEFAULT_HOST = '127.0.0.1'
+DEFAULT_PORT = 10800
+
+
+class FakeIgniteProtocol(asyncio.Protocol):
+ def __init__(self, server):
+ self._transport = None
+ self._server = server
+ self._buf = bytearray()
+ self._done_handshake = False
+
+ def connection_made(self, transport):
+ sock = transport.get_extra_info('socket')
+ if sock is not None:
+ logger.debug('Connecting from %s', sock)
+ self._server.add_client(transport)
+ self._transport = transport
+
+ def _handshake_response(self, error=True):
+ if error:
+ return struct.pack('<lbhhhbll', 16, 0, 1, 3, 0, 9, 0, 0)
+ else:
+ return struct.pack('<lb', 1, 1)
+
+ def _parse_handshake_request(self, buf):
+ return struct.unpack('<lbhhhb', buf)
+
+ def data_received(self, data):
+ logger.debug(f'Data received: {data if data else b""}')
+
+ if self._server.do_handshake and not self._done_handshake:
+ self._buf += data
+
+ if len(self._buf) < 12:
+ return
+
+ req = self._parse_handshake_request(self._buf[0:12])
+
+ if req[1] == 1 and (req[2], req[3], req[4]) > (1, 3, 0):
+ response = self._handshake_response(True)
+ logger.debug(f'Writing handshake response {response}')
+ self._transport.write(response)
+ self._transport.close()
+ else:
+ response = self._handshake_response(False)
+ logger.debug(f'Writing handshake response {response}')
+ self._transport.write(response)
+ self._done_handshake = True
+ self._buf = bytearray()
+
+
+class FakeIgniteServer:
+ def __init__(self, do_handshake=False):
+ self.clients = []
+ self.server = None
+ self.do_handshake = do_handshake
+ self.loop = asyncio.get_event_loop()
+
+ async def start(self):
+ self.server = await self.loop.create_server(lambda: FakeIgniteProtocol(self), DEFAULT_HOST, DEFAULT_PORT)
+
+ def add_client(self, client):
+ self.clients.append(client)
+
+ async def close(self):
+ for client in self.clients:
+ client.close()
+
+ if self.server:
+ self.server.close()
+ await self.server.wait_closed()
+
+
+class HandshakeTimeoutListener(monitoring.ConnectionEventListener):
+ def __init__(self):
+ self.events = []
+
+ def on_handshake_fail(self, event: HandshakeFailedEvent):
+ self.events.append(event)
+
+
+@pytest.fixture
+async def server():
+ server = FakeIgniteServer()
+ try:
+ await server.start()
+ yield server
+ finally:
+ await server.close()
+
+
+@pytest.fixture
+async def server_with_handshake():
+ server = FakeIgniteServer(do_handshake=True)
+ try:
+ await server.start()
+ yield server
+ finally:
+ await server.close()
+
+
+@pytest.mark.asyncio
+async def test_handshake_timeout(server, event_loop):
+ def sync_client_connect():
+ hs_to_listener = HandshakeTimeoutListener()
+ client = Client(handshake_timeout=3.0, event_listeners=[hs_to_listener])
+ start = time.monotonic()
+ try:
+ client.connect(DEFAULT_HOST, DEFAULT_PORT)
+ except Exception as e:
+ return time.monotonic() - start, hs_to_listener.events, e
+ return time.monotonic() - start, hs_to_listener.events, None
+
+ duration, events, err = await event_loop.run_in_executor(ThreadPoolExecutor(), sync_client_connect)
+
+ assert isinstance(err, ReconnectError)
+ assert 3.0 <= duration < 4.0
+ assert len(events) > 0
+ for ev in events:
+ assert isinstance(ev, HandshakeFailedEvent)
+ assert 'timed out' in ev.error_msg
+
+
+@pytest.mark.asyncio
+async def test_handshake_timeout_async(server):
+ hs_to_listener = HandshakeTimeoutListener()
+ client = AioClient(handshake_timeout=3.0, event_listeners=[hs_to_listener])
+ with pytest.raises(ReconnectError):
+ start = time.monotonic()
+ await client.connect(DEFAULT_HOST, DEFAULT_PORT)
+
+ assert 3.0 <= time.monotonic() - start < 4.0
+ assert len(hs_to_listener.events) > 0
+ for ev in hs_to_listener.events:
+ assert isinstance(ev, HandshakeFailedEvent)
+ assert 'timed out' in ev.error_msg
+
+
+@pytest.mark.asyncio
+async def test_socket_timeout_applied_sync(server_with_handshake, event_loop):
+ def sync_client_connect():
+ hs_to_listener = HandshakeTimeoutListener()
+ client = Client(timeout=5.0, handshake_timeout=3.0, event_listeners=[hs_to_listener])
+ start = time.monotonic()
+ try:
+ client.connect(DEFAULT_HOST, DEFAULT_PORT)
+ assert all(n.alive for n in client._nodes)
+ client.get_cache_names()
+ except Exception as e:
+ return time.monotonic() - start, hs_to_listener.events, e
+ return time.monotonic() - start, hs_to_listener.events, None
+
+ duration, events, err = await event_loop.run_in_executor(ThreadPoolExecutor(), sync_client_connect)
+
+ assert isinstance(err, socket.timeout)
+ assert 5.0 <= duration < 6.0
+ assert len(events) == 0
+
+
+@pytest.mark.asyncio
+async def test_handshake_timeout_not_affected_for_others_requests_async(server_with_handshake):
+ hs_to_listener = HandshakeTimeoutListener()
+ client = AioClient(handshake_timeout=3.0, event_listeners=[hs_to_listener])
+ with pytest.raises(asyncio.TimeoutError):
+ await client.connect(DEFAULT_HOST, DEFAULT_PORT)
+ assert all(n.alive for n in client._nodes)
+ await asyncio.wait_for(client.get_cache_names(), 5.0)
+
+
+@pytest.mark.parametrize(
+ 'handshake_timeout',
+ [0.0, -10.0, -0.01]
+)
+@pytest.mark.asyncio
+async def test_handshake_timeout_param_validation(handshake_timeout):
+ with pytest.raises(ParameterError):
+ await AioClient(handshake_timeout=handshake_timeout).connect(DEFAULT_HOST, DEFAULT_PORT)
+
+ with pytest.raises(ParameterError):
+ Client(handshake_timeout=handshake_timeout).connect(DEFAULT_HOST, DEFAULT_PORT)
diff --git a/tests/security/test_auth.py b/tests/security/test_auth.py
index 503cf88..83ac780 100644
--- a/tests/security/test_auth.py
+++ b/tests/security/test_auth.py
@@ -95,7 +95,8 @@ def __assert_successful_connect_events(conn, listener):
assert ev.host == conn.host
assert ev.port == conn.port
if isinstance(ev, (HandshakeSuccessEvent, ConnectionClosedEvent)):
- assert ev.node_uuid == str(conn.uuid if conn.uuid else '')
+ if ev.node_uuid:
+ assert ev.node_uuid == str(conn.uuid)
if isinstance(ev, HandshakeSuccessEvent):
assert ev.protocol_context
diff --git a/tox.ini b/tox.ini
index 90153da..964b748 100644
--- a/tox.ini
+++ b/tox.ini
@@ -17,6 +17,13 @@
skipsdist = True
envlist = codestyle,py{36,37,38,39}
+[pytest]
+log_format = %(asctime)s %(name)s %(levelname)s %(message)s
+log_date_format = %Y-%m-%d %H:%M:%S
+# Uncomment if you want verbose logging for all tests (for failed it will be printed anyway).
+# log_cli = True
+# log_cli_level = DEBUG
+
[flake8]
max-line-length=120
ignore = F401,F403,F405,F821