You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by fo...@apache.org on 2023/05/07 19:34:55 UTC
[iceberg] branch master updated: Python: Add REST support for SigV4 (#7519)
This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 6a0732f7c4 Python: Add REST support for SigV4 (#7519)
6a0732f7c4 is described below
commit 6a0732f7c48f81cac8b382b47edaec98b58b8d30
Author: Daniel Weeks <dw...@apache.org>
AuthorDate: Sun May 7 12:34:50 2023 -0700
Python: Add REST support for SigV4 (#7519)
* Add REST support for SigV4 signed requests
* Add comment for second session create
* lint
* Address comments
* Fix typing issues
* Add docs
* Lint
* remove lint ignore
---
python/mkdocs/docs/configuration.md | 9 +++
python/pyiceberg/catalog/rest.py | 123 ++++++++++++++++++++++++++----------
python/tests/catalog/test_rest.py | 18 +++++-
3 files changed, 117 insertions(+), 33 deletions(-)
diff --git a/python/mkdocs/docs/configuration.md b/python/mkdocs/docs/configuration.md
index 65a3fdad43..947b701e1b 100644
--- a/python/mkdocs/docs/configuration.md
+++ b/python/mkdocs/docs/configuration.md
@@ -97,6 +97,15 @@ catalog:
cabundle: /absolute/path/to/cabundle.pem
```
+| Key | Example | Description |
+| ------------------- | ----------------------- | -------------------------------------------------------------------------- |
+| uri | https://rest-catalog/ws | URI identifying the REST Server |
+| credential | t-1234:secret | Credential to use for OAuth2 credential flow when initializing the catalog |
+| token | FEW23.DFSDF.FSDF | Bearer token value to use for `Authorization` header |
+| rest.sigv4-enabled | true | Sign requests to the REST Server using AWS SigV4 protocol |
+| rest.signing-region | us-east-1 | The region to use when SigV4 signing a request |
+| rest.signing-name | execute-api | The service signing name to use when SigV4 signing a request |
+
## Hive Catalog
```yaml
diff --git a/python/pyiceberg/catalog/rest.py b/python/pyiceberg/catalog/rest.py
index 2c0cf634a9..e24e23910b 100644
--- a/python/pyiceberg/catalog/rest.py
+++ b/python/pyiceberg/catalog/rest.py
@@ -96,6 +96,9 @@ CERT = "cert"
CLIENT = "client"
CA_BUNDLE = "cabundle"
SSL = "ssl"
+SIGV4 = "rest.sigv4-enabled"
+SIGV4_REGION = "rest.signing-region"
+SIGV4_SERVICE = "rest.signing-name"
NAMESPACE_SEPARATOR = b"\x1F".decode("UTF-8")
@@ -172,8 +175,7 @@ class OAuthErrorResponse(IcebergBaseModel):
class RestCatalog(Catalog):
uri: str
- session: Session
- properties: Properties
+ _session: Session
def __init__(self, name: str, **properties: str):
"""Rest Catalog
@@ -184,37 +186,43 @@ class RestCatalog(Catalog):
name: Name to identify the catalog
properties: Properties that are passed along to the configuration
"""
- self.properties = properties
+ super().__init__(name, **properties)
self.uri = properties[URI]
- self._create_session()
- super().__init__(name, **self._fetch_config(properties))
+ self._fetch_config()
+ self._session = self._create_session()
- def _create_session(self) -> None:
+ def _create_session(self) -> Session:
"""Creates a request session with provided catalog configuration"""
+ session = Session()
- self.session = Session()
# Sets the client side and server side SSL cert verification, if provided as properties.
if ssl_config := self.properties.get(SSL):
if ssl_ca_bundle := ssl_config.get(CA_BUNDLE): # type: ignore
- self.session.verify = ssl_ca_bundle
+ session.verify = ssl_ca_bundle
if ssl_client := ssl_config.get(CLIENT): # type: ignore
if all(k in ssl_client for k in (CERT, KEY)):
- self.session.cert = (ssl_client[CERT], ssl_client[KEY])
+ session.cert = (ssl_client[CERT], ssl_client[KEY])
elif ssl_client_cert := ssl_client.get(CERT):
- self.session.cert = ssl_client_cert
+ session.cert = ssl_client_cert
# If we have credentials, but not a token, we want to fetch a token
if TOKEN not in self.properties and CREDENTIAL in self.properties:
- self.properties[TOKEN] = self._fetch_access_token(self.properties[CREDENTIAL])
+ self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])
# Set Auth token for subsequent calls in the session
if token := self.properties.get(TOKEN):
- self.session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
+ session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
# Set HTTP headers
- self.session.headers["Content-type"] = "application/json"
- self.session.headers["X-Client-Version"] = ICEBERG_REST_SPEC_VERSION
- self.session.headers["User-Agent"] = f"PyIceberg/{__version__}"
+ session.headers["Content-type"] = "application/json"
+ session.headers["X-Client-Version"] = ICEBERG_REST_SPEC_VERSION
+ session.headers["User-Agent"] = f"PyIceberg/{__version__}"
+
+ # Configure SigV4 Request Signing
+ if str(self.properties.get(SIGV4, False)).lower() == "true":
+ self._init_sigv4(session)
+
+ return session
def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier:
"""The identifier should have at least one element"""
@@ -243,7 +251,7 @@ class RestCatalog(Catalog):
return url + endpoint.format(**kwargs)
- def _fetch_access_token(self, credential: str) -> str:
+ def _fetch_access_token(self, session: Session, credential: str) -> str:
if SEMICOLON in credential:
client_id, client_secret = credential.split(SEMICOLON)
else:
@@ -251,7 +259,7 @@ class RestCatalog(Catalog):
data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret, SCOPE: CATALOG_SCOPE}
url = self.url(Endpoints.get_token, prefixed=False)
# Uses application/x-www-form-urlencoded by default
- response = self.session.post(url=url, data=data)
+ response = session.post(url=url, data=data)
try:
response.raise_for_status()
except HTTPError as exc:
@@ -259,21 +267,26 @@ class RestCatalog(Catalog):
return TokenResponse(**response.json()).access_token
- def _fetch_config(self, properties: Properties) -> Properties:
+ def _fetch_config(self) -> None:
params = {}
- if warehouse_location := properties.get(WAREHOUSE_LOCATION):
+ if warehouse_location := self.properties.get(WAREHOUSE_LOCATION):
params[WAREHOUSE_LOCATION] = warehouse_location
- response = self.session.get(self.url(Endpoints.get_config, prefixed=False), params=params)
+ with self._create_session() as session:
+ response = session.get(self.url(Endpoints.get_config, prefixed=False), params=params)
try:
response.raise_for_status()
except HTTPError as exc:
self._handle_non_200_response(exc, {})
config_response = ConfigResponse(**response.json())
+
config = config_response.defaults
- config.update(properties)
+ config.update(self.properties)
config.update(config_response.overrides)
- return config
+ self.properties = config
+
+ # Update URI based on overrides
+ self.uri = config[URI]
def _split_identifier_for_path(self, identifier: Union[str, Identifier]) -> Properties:
identifier_tuple = self.identifier_to_tuple(identifier)
@@ -335,6 +348,52 @@ class RestCatalog(Catalog):
raise exception(response) from exc
+ def _init_sigv4(self, session: Session) -> None:
+ from urllib import parse
+
+ import boto3
+ from botocore.auth import SigV4Auth
+ from botocore.awsrequest import AWSRequest
+ from requests import PreparedRequest
+ from requests.adapters import HTTPAdapter
+
+ class SigV4Adapter(HTTPAdapter):
+ def __init__(self, **properties: str):
+ super().__init__()
+ self._properties = properties
+
+ def add_headers(self, request: PreparedRequest, **kwargs: Any) -> None: # pylint: disable=W0613
+ boto_session = boto3.Session()
+ credentials = boto_session.get_credentials().get_frozen_credentials()
+ region = self._properties.get(SIGV4_REGION, boto_session.region_name)
+ service = self._properties.get(SIGV4_SERVICE, "execute-api")
+
+ url = str(request.url).split("?")[0]
+ query = str(parse.urlsplit(request.url).query)
+ params = dict(parse.parse_qsl(query))
+
+ # remove the connection header as it will be updated after signing
+ del request.headers["connection"]
+
+ aws_request = AWSRequest(
+ method=request.method, url=url, params=params, data=request.body, headers=dict(request.headers)
+ )
+
+ SigV4Auth(credentials, service, region).add_auth(aws_request)
+ original_header = request.headers
+ signed_headers = aws_request.headers
+ relocated_headers = {}
+
+ # relocate headers if there is a conflict with signed headers
+ for header, value in original_header.items():
+ if header in signed_headers and signed_headers[header] != value:
+ relocated_headers[f"Original-{header}"] = value
+
+ request.headers.update(relocated_headers)
+ request.headers.update(signed_headers)
+
+ session.mount(self.uri, SigV4Adapter(**self.properties))
+
def create_table(
self,
identifier: Union[str, Identifier],
@@ -354,7 +413,7 @@ class RestCatalog(Catalog):
properties=properties,
)
serialized_json = request.json()
- response = self.session.post(
+ response = self._session.post(
self.url(Endpoints.create_table, namespace=namespace_and_table["namespace"]),
data=serialized_json,
)
@@ -377,7 +436,7 @@ class RestCatalog(Catalog):
def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
- response = self.session.get(self.url(Endpoints.list_tables, namespace=namespace_concat))
+ response = self._session.get(self.url(Endpoints.list_tables, namespace=namespace_concat))
try:
response.raise_for_status()
except HTTPError as exc:
@@ -390,7 +449,7 @@ class RestCatalog(Catalog):
if len(identifier_tuple) <= 1:
raise NoSuchTableError(f"Missing namespace or invalid identifier: {identifier}")
- response = self.session.get(self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier)))
+ response = self._session.get(self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier)))
try:
response.raise_for_status()
except HTTPError as exc:
@@ -407,7 +466,7 @@ class RestCatalog(Catalog):
)
def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = False) -> None:
- response = self.session.delete(
+ response = self._session.delete(
self.url(Endpoints.drop_table, prefixed=True, purge=purge_requested, **self._split_identifier_for_path(identifier)),
)
try:
@@ -423,7 +482,7 @@ class RestCatalog(Catalog):
"source": self._split_identifier_for_json(from_identifier),
"destination": self._split_identifier_for_json(to_identifier),
}
- response = self.session.post(self.url(Endpoints.rename_table), json=payload)
+ response = self._session.post(self.url(Endpoints.rename_table), json=payload)
try:
response.raise_for_status()
except HTTPError as exc:
@@ -434,7 +493,7 @@ class RestCatalog(Catalog):
def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
payload = {"namespace": namespace_tuple, "properties": properties}
- response = self.session.post(self.url(Endpoints.create_namespace), json=payload)
+ response = self._session.post(self.url(Endpoints.create_namespace), json=payload)
try:
response.raise_for_status()
except HTTPError as exc:
@@ -443,7 +502,7 @@ class RestCatalog(Catalog):
def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
- response = self.session.delete(self.url(Endpoints.drop_namespace, namespace=namespace))
+ response = self._session.delete(self.url(Endpoints.drop_namespace, namespace=namespace))
try:
response.raise_for_status()
except HTTPError as exc:
@@ -451,7 +510,7 @@ class RestCatalog(Catalog):
def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
namespace_tuple = self.identifier_to_tuple(namespace)
- response = self.session.get(
+ response = self._session.get(
self.url(
f"{Endpoints.list_namespaces}?parent={NAMESPACE_SEPARATOR.join(namespace_tuple)}"
if namespace_tuple
@@ -469,7 +528,7 @@ class RestCatalog(Catalog):
def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
- response = self.session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace))
+ response = self._session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace))
try:
response.raise_for_status()
except HTTPError as exc:
@@ -483,7 +542,7 @@ class RestCatalog(Catalog):
namespace_tuple = self._check_valid_namespace_identifier(namespace)
namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
payload = {"removals": list(removals or []), "updates": updates}
- response = self.session.post(self.url(Endpoints.update_properties, namespace=namespace), json=payload)
+ response = self._session.post(self.url(Endpoints.update_properties, namespace=namespace), json=payload)
try:
response.raise_for_status()
except HTTPError as exc:
diff --git a/python/tests/catalog/test_rest.py b/python/tests/catalog/test_rest.py
index 9c1700d97f..6f3cafffc1 100644
--- a/python/tests/catalog/test_rest.py
+++ b/python/tests/catalog/test_rest.py
@@ -91,7 +91,8 @@ def test_token_200(rest_mock: Mocker) -> None:
request_headers=OAUTH_TEST_HEADERS,
)
assert (
- RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS).session.headers["Authorization"] == f"Bearer {TEST_TOKEN}"
+ RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS)._session.headers["Authorization"] # pylint: disable=W0212
+ == f"Bearer {TEST_TOKEN}"
)
@@ -161,6 +162,21 @@ def test_list_tables_200(rest_mock: Mocker) -> None:
assert RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).list_tables(namespace) == [("examples", "fooshare")]
+def test_list_tables_200_sigv4(rest_mock: Mocker) -> None:
+ namespace = "examples"
+ rest_mock.get(
+ f"{TEST_URI}v1/namespaces/{namespace}/tables",
+ json={"identifiers": [{"namespace": ["examples"], "name": "fooshare"}]},
+ status_code=200,
+ request_headers=TEST_HEADERS,
+ )
+
+ assert RestCatalog("rest", **{"uri": TEST_URI, "token": TEST_TOKEN, "rest.sigv4-enabled": "true"}).list_tables(namespace) == [
+ ("examples", "fooshare")
+ ]
+ assert rest_mock.called
+
+
def test_list_tables_404(rest_mock: Mocker) -> None:
namespace = "examples"
rest_mock.get(