You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by fo...@apache.org on 2023/05/07 19:34:55 UTC

[iceberg] branch master updated: Python: Add REST support for SigV4 (#7519)

This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a0732f7c4 Python: Add REST support for SigV4 (#7519)
6a0732f7c4 is described below

commit 6a0732f7c48f81cac8b382b47edaec98b58b8d30
Author: Daniel Weeks <dw...@apache.org>
AuthorDate: Sun May 7 12:34:50 2023 -0700

    Python: Add REST support for SigV4 (#7519)
    
    * Add REST support for SigV4 signed requests
    
    * Add comment for second session create
    
    * lint
    
    * Address comments
    
    * Fix typing issues
    
    * Add docs
    
    * Lint
    
    * remove lint ignore
---
 python/mkdocs/docs/configuration.md |   9 +++
 python/pyiceberg/catalog/rest.py    | 123 ++++++++++++++++++++++++++----------
 python/tests/catalog/test_rest.py   |  18 +++++-
 3 files changed, 117 insertions(+), 33 deletions(-)

diff --git a/python/mkdocs/docs/configuration.md b/python/mkdocs/docs/configuration.md
index 65a3fdad43..947b701e1b 100644
--- a/python/mkdocs/docs/configuration.md
+++ b/python/mkdocs/docs/configuration.md
@@ -97,6 +97,15 @@ catalog:
       cabundle: /absolute/path/to/cabundle.pem
 ```
 
+| Key                 | Example                 | Description                                                                |
+| ------------------- | ----------------------- | -------------------------------------------------------------------------- |
+| uri                 | https://rest-catalog/ws | URI identifying the REST Server                                            |
+| credential          | t-1234:secret           | Credential to use for OAuth2 credential flow when initializing the catalog |
+| token               | FEW23.DFSDF.FSDF        | Bearer token value to use for `Authorization` header                       |
+| rest.sigv4-enabled  | true                    | Sign requests to the REST Server using AWS SigV4 protocol                  |
+| rest.signing-region | us-east-1               | The region to use when SigV4 signing a request                             |
+| rest.signing-name   | execute-api             | The service signing name to use when SigV4 signing a request               |
+
 ## Hive Catalog
 
 ```yaml
diff --git a/python/pyiceberg/catalog/rest.py b/python/pyiceberg/catalog/rest.py
index 2c0cf634a9..e24e23910b 100644
--- a/python/pyiceberg/catalog/rest.py
+++ b/python/pyiceberg/catalog/rest.py
@@ -96,6 +96,9 @@ CERT = "cert"
 CLIENT = "client"
 CA_BUNDLE = "cabundle"
 SSL = "ssl"
+SIGV4 = "rest.sigv4-enabled"
+SIGV4_REGION = "rest.signing-region"
+SIGV4_SERVICE = "rest.signing-name"
 
 NAMESPACE_SEPARATOR = b"\x1F".decode("UTF-8")
 
@@ -172,8 +175,7 @@ class OAuthErrorResponse(IcebergBaseModel):
 
 class RestCatalog(Catalog):
     uri: str
-    session: Session
-    properties: Properties
+    _session: Session
 
     def __init__(self, name: str, **properties: str):
         """Rest Catalog
@@ -184,37 +186,43 @@ class RestCatalog(Catalog):
             name: Name to identify the catalog
             properties: Properties that are passed along to the configuration
         """
-        self.properties = properties
+        super().__init__(name, **properties)
         self.uri = properties[URI]
-        self._create_session()
-        super().__init__(name, **self._fetch_config(properties))
+        self._fetch_config()
+        self._session = self._create_session()
 
-    def _create_session(self) -> None:
+    def _create_session(self) -> Session:
         """Creates a request session with provided catalog configuration"""
+        session = Session()
 
-        self.session = Session()
         # Sets the client side and server side SSL cert verification, if provided as properties.
         if ssl_config := self.properties.get(SSL):
             if ssl_ca_bundle := ssl_config.get(CA_BUNDLE):  # type: ignore
-                self.session.verify = ssl_ca_bundle
+                session.verify = ssl_ca_bundle
             if ssl_client := ssl_config.get(CLIENT):  # type: ignore
                 if all(k in ssl_client for k in (CERT, KEY)):
-                    self.session.cert = (ssl_client[CERT], ssl_client[KEY])
+                    session.cert = (ssl_client[CERT], ssl_client[KEY])
                 elif ssl_client_cert := ssl_client.get(CERT):
-                    self.session.cert = ssl_client_cert
+                    session.cert = ssl_client_cert
 
         # If we have credentials, but not a token, we want to fetch a token
         if TOKEN not in self.properties and CREDENTIAL in self.properties:
-            self.properties[TOKEN] = self._fetch_access_token(self.properties[CREDENTIAL])
+            self.properties[TOKEN] = self._fetch_access_token(session, self.properties[CREDENTIAL])
 
         # Set Auth token for subsequent calls in the session
         if token := self.properties.get(TOKEN):
-            self.session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
+            session.headers[AUTHORIZATION_HEADER] = f"{BEARER_PREFIX} {token}"
 
         # Set HTTP headers
-        self.session.headers["Content-type"] = "application/json"
-        self.session.headers["X-Client-Version"] = ICEBERG_REST_SPEC_VERSION
-        self.session.headers["User-Agent"] = f"PyIceberg/{__version__}"
+        session.headers["Content-type"] = "application/json"
+        session.headers["X-Client-Version"] = ICEBERG_REST_SPEC_VERSION
+        session.headers["User-Agent"] = f"PyIceberg/{__version__}"
+
+        # Configure SigV4 Request Signing
+        if str(self.properties.get(SIGV4, False)).lower() == "true":
+            self._init_sigv4(session)
+
+        return session
 
     def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier:
         """The identifier should have at least one element"""
@@ -243,7 +251,7 @@ class RestCatalog(Catalog):
 
         return url + endpoint.format(**kwargs)
 
-    def _fetch_access_token(self, credential: str) -> str:
+    def _fetch_access_token(self, session: Session, credential: str) -> str:
         if SEMICOLON in credential:
             client_id, client_secret = credential.split(SEMICOLON)
         else:
@@ -251,7 +259,7 @@ class RestCatalog(Catalog):
         data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret, SCOPE: CATALOG_SCOPE}
         url = self.url(Endpoints.get_token, prefixed=False)
         # Uses application/x-www-form-urlencoded by default
-        response = self.session.post(url=url, data=data)
+        response = session.post(url=url, data=data)
         try:
             response.raise_for_status()
         except HTTPError as exc:
@@ -259,21 +267,26 @@ class RestCatalog(Catalog):
 
         return TokenResponse(**response.json()).access_token
 
-    def _fetch_config(self, properties: Properties) -> Properties:
+    def _fetch_config(self) -> None:
         params = {}
-        if warehouse_location := properties.get(WAREHOUSE_LOCATION):
+        if warehouse_location := self.properties.get(WAREHOUSE_LOCATION):
             params[WAREHOUSE_LOCATION] = warehouse_location
 
-        response = self.session.get(self.url(Endpoints.get_config, prefixed=False), params=params)
+        with self._create_session() as session:
+            response = session.get(self.url(Endpoints.get_config, prefixed=False), params=params)
         try:
             response.raise_for_status()
         except HTTPError as exc:
             self._handle_non_200_response(exc, {})
         config_response = ConfigResponse(**response.json())
+
         config = config_response.defaults
-        config.update(properties)
+        config.update(self.properties)
         config.update(config_response.overrides)
-        return config
+        self.properties = config
+
+        # Update URI based on overrides
+        self.uri = config[URI]
 
     def _split_identifier_for_path(self, identifier: Union[str, Identifier]) -> Properties:
         identifier_tuple = self.identifier_to_tuple(identifier)
@@ -335,6 +348,52 @@ class RestCatalog(Catalog):
 
         raise exception(response) from exc
 
+    def _init_sigv4(self, session: Session) -> None:
+        from urllib import parse
+
+        import boto3
+        from botocore.auth import SigV4Auth
+        from botocore.awsrequest import AWSRequest
+        from requests import PreparedRequest
+        from requests.adapters import HTTPAdapter
+
+        class SigV4Adapter(HTTPAdapter):
+            def __init__(self, **properties: str):
+                super().__init__()
+                self._properties = properties
+
+            def add_headers(self, request: PreparedRequest, **kwargs: Any) -> None:  # pylint: disable=W0613
+                boto_session = boto3.Session()
+                credentials = boto_session.get_credentials().get_frozen_credentials()
+                region = self._properties.get(SIGV4_REGION, boto_session.region_name)
+                service = self._properties.get(SIGV4_SERVICE, "execute-api")
+
+                url = str(request.url).split("?")[0]
+                query = str(parse.urlsplit(request.url).query)
+                params = dict(parse.parse_qsl(query))
+
+                # remove the connection header as it will be updated after signing
+                del request.headers["connection"]
+
+                aws_request = AWSRequest(
+                    method=request.method, url=url, params=params, data=request.body, headers=dict(request.headers)
+                )
+
+                SigV4Auth(credentials, service, region).add_auth(aws_request)
+                original_header = request.headers
+                signed_headers = aws_request.headers
+                relocated_headers = {}
+
+                # relocate headers if there is a conflict with signed headers
+                for header, value in original_header.items():
+                    if header in signed_headers and signed_headers[header] != value:
+                        relocated_headers[f"Original-{header}"] = value
+
+                request.headers.update(relocated_headers)
+                request.headers.update(signed_headers)
+
+        session.mount(self.uri, SigV4Adapter(**self.properties))
+
     def create_table(
         self,
         identifier: Union[str, Identifier],
@@ -354,7 +413,7 @@ class RestCatalog(Catalog):
             properties=properties,
         )
         serialized_json = request.json()
-        response = self.session.post(
+        response = self._session.post(
             self.url(Endpoints.create_table, namespace=namespace_and_table["namespace"]),
             data=serialized_json,
         )
@@ -377,7 +436,7 @@ class RestCatalog(Catalog):
     def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
         namespace_tuple = self._check_valid_namespace_identifier(namespace)
         namespace_concat = NAMESPACE_SEPARATOR.join(namespace_tuple)
-        response = self.session.get(self.url(Endpoints.list_tables, namespace=namespace_concat))
+        response = self._session.get(self.url(Endpoints.list_tables, namespace=namespace_concat))
         try:
             response.raise_for_status()
         except HTTPError as exc:
@@ -390,7 +449,7 @@ class RestCatalog(Catalog):
         if len(identifier_tuple) <= 1:
             raise NoSuchTableError(f"Missing namespace or invalid identifier: {identifier}")
 
-        response = self.session.get(self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier)))
+        response = self._session.get(self.url(Endpoints.load_table, prefixed=True, **self._split_identifier_for_path(identifier)))
         try:
             response.raise_for_status()
         except HTTPError as exc:
@@ -407,7 +466,7 @@ class RestCatalog(Catalog):
         )
 
     def drop_table(self, identifier: Union[str, Identifier], purge_requested: bool = False) -> None:
-        response = self.session.delete(
+        response = self._session.delete(
             self.url(Endpoints.drop_table, prefixed=True, purge=purge_requested, **self._split_identifier_for_path(identifier)),
         )
         try:
@@ -423,7 +482,7 @@ class RestCatalog(Catalog):
             "source": self._split_identifier_for_json(from_identifier),
             "destination": self._split_identifier_for_json(to_identifier),
         }
-        response = self.session.post(self.url(Endpoints.rename_table), json=payload)
+        response = self._session.post(self.url(Endpoints.rename_table), json=payload)
         try:
             response.raise_for_status()
         except HTTPError as exc:
@@ -434,7 +493,7 @@ class RestCatalog(Catalog):
     def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None:
         namespace_tuple = self._check_valid_namespace_identifier(namespace)
         payload = {"namespace": namespace_tuple, "properties": properties}
-        response = self.session.post(self.url(Endpoints.create_namespace), json=payload)
+        response = self._session.post(self.url(Endpoints.create_namespace), json=payload)
         try:
             response.raise_for_status()
         except HTTPError as exc:
@@ -443,7 +502,7 @@ class RestCatalog(Catalog):
     def drop_namespace(self, namespace: Union[str, Identifier]) -> None:
         namespace_tuple = self._check_valid_namespace_identifier(namespace)
         namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
-        response = self.session.delete(self.url(Endpoints.drop_namespace, namespace=namespace))
+        response = self._session.delete(self.url(Endpoints.drop_namespace, namespace=namespace))
         try:
             response.raise_for_status()
         except HTTPError as exc:
@@ -451,7 +510,7 @@ class RestCatalog(Catalog):
 
     def list_namespaces(self, namespace: Union[str, Identifier] = ()) -> List[Identifier]:
         namespace_tuple = self.identifier_to_tuple(namespace)
-        response = self.session.get(
+        response = self._session.get(
             self.url(
                 f"{Endpoints.list_namespaces}?parent={NAMESPACE_SEPARATOR.join(namespace_tuple)}"
                 if namespace_tuple
@@ -469,7 +528,7 @@ class RestCatalog(Catalog):
     def load_namespace_properties(self, namespace: Union[str, Identifier]) -> Properties:
         namespace_tuple = self._check_valid_namespace_identifier(namespace)
         namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
-        response = self.session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace))
+        response = self._session.get(self.url(Endpoints.load_namespace_metadata, namespace=namespace))
         try:
             response.raise_for_status()
         except HTTPError as exc:
@@ -483,7 +542,7 @@ class RestCatalog(Catalog):
         namespace_tuple = self._check_valid_namespace_identifier(namespace)
         namespace = NAMESPACE_SEPARATOR.join(namespace_tuple)
         payload = {"removals": list(removals or []), "updates": updates}
-        response = self.session.post(self.url(Endpoints.update_properties, namespace=namespace), json=payload)
+        response = self._session.post(self.url(Endpoints.update_properties, namespace=namespace), json=payload)
         try:
             response.raise_for_status()
         except HTTPError as exc:
diff --git a/python/tests/catalog/test_rest.py b/python/tests/catalog/test_rest.py
index 9c1700d97f..6f3cafffc1 100644
--- a/python/tests/catalog/test_rest.py
+++ b/python/tests/catalog/test_rest.py
@@ -91,7 +91,8 @@ def test_token_200(rest_mock: Mocker) -> None:
         request_headers=OAUTH_TEST_HEADERS,
     )
     assert (
-        RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS).session.headers["Authorization"] == f"Bearer {TEST_TOKEN}"
+        RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS)._session.headers["Authorization"]  # pylint: disable=W0212
+        == f"Bearer {TEST_TOKEN}"
     )
 
 
@@ -161,6 +162,21 @@ def test_list_tables_200(rest_mock: Mocker) -> None:
     assert RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN).list_tables(namespace) == [("examples", "fooshare")]
 
 
+def test_list_tables_200_sigv4(rest_mock: Mocker) -> None:
+    namespace = "examples"
+    rest_mock.get(
+        f"{TEST_URI}v1/namespaces/{namespace}/tables",
+        json={"identifiers": [{"namespace": ["examples"], "name": "fooshare"}]},
+        status_code=200,
+        request_headers=TEST_HEADERS,
+    )
+
+    assert RestCatalog("rest", **{"uri": TEST_URI, "token": TEST_TOKEN, "rest.sigv4-enabled": "true"}).list_tables(namespace) == [
+        ("examples", "fooshare")
+    ]
+    assert rest_mock.called
+
+
 def test_list_tables_404(rest_mock: Mocker) -> None:
     namespace = "examples"
     rest_mock.get(