You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by fo...@apache.org on 2022/10/17 11:11:59 UTC

[iceberg] branch master updated: Python: Cleanup inconsistencies around OAuth responses (#5957)

This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 7c2fcccf4a Python: Cleanup inconsistencies around OAuth responses (#5957)
7c2fcccf4a is described below

commit 7c2fcccf4a033ce620a254fae4ddc92cdc10bf85
Author: Fokko Driesprong <fo...@apache.org>
AuthorDate: Mon Oct 17 13:11:51 2022 +0200

    Python: Cleanup inconsistencies around OAuth responses (#5957)
    
    * Python: Cleanup inconsistency in OAuth responses
    
    401 also returns an OAuth token response
    https://github.com/apache/iceberg/blob/master/open-api/rest-catalog-open-api.yaml#L170-L178
    
    * Comments
    
    * Switch them around
---
 python/pyiceberg/catalog/__init__.py | 35 +++++++++++++++++++++++++++++------
 python/pyiceberg/catalog/rest.py     | 18 ++++++++++++------
 python/pyiceberg/cli/console.py      |  9 +--------
 python/pyiceberg/exceptions.py       |  4 ----
 python/tests/catalog/test_rest.py    | 28 +++-------------------------
 5 files changed, 45 insertions(+), 49 deletions(-)

diff --git a/python/pyiceberg/catalog/__init__.py b/python/pyiceberg/catalog/__init__.py
index ee8372cdc9..d4e3163c29 100644
--- a/python/pyiceberg/catalog/__init__.py
+++ b/python/pyiceberg/catalog/__init__.py
@@ -41,6 +41,7 @@ logger = logging.getLogger(__name__)
 _ENV_CONFIG = Config()
 
 TYPE = "type"
+URI = "uri"
 
 
 class CatalogType(Enum):
@@ -69,7 +70,7 @@ AVAILABLE_CATALOGS: dict[CatalogType, Callable[[str, Properties], Catalog]] = {
 }
 
 
-def infer_catalog_type(catalog_properties: RecursiveDict) -> CatalogType | None:
+def infer_catalog_type(name: str, catalog_properties: RecursiveDict) -> CatalogType | None:
     """Tries to infer the type based on the dict
 
     Args:
@@ -77,6 +78,9 @@ def infer_catalog_type(catalog_properties: RecursiveDict) -> CatalogType | None:
 
     Returns:
         The inferred type based on the provided properties
+
+    Raises:
+        ValueError: Raises a ValueError in case properties are missing, or the wrong type
     """
     if uri := catalog_properties.get("uri"):
         if isinstance(uri, str):
@@ -84,20 +88,39 @@ def infer_catalog_type(catalog_properties: RecursiveDict) -> CatalogType | None:
                 return CatalogType.REST
             elif uri.startswith("thrift"):
                 return CatalogType.HIVE
-    return None
+            else:
+                raise ValueError(f"Could not infer the catalog type from the uri: {uri}")
+        else:
+            raise ValueError(f"Expects the URI to be a string, got: {type(uri)}")
+    raise ValueError(
+        f"URI missing, please provide using --uri, the config or environment variable PYICEBERG_CATALOG__{name.upper()}__URI"
+    )
 
 
 def load_catalog(name: str, **properties: str | None) -> Catalog:
+    """Load the catalog based on the properties
+
+    Will look up the properties from the config, based on the name
+
+    Args:
+        name: The name of the catalog
+        properties: The properties that are used next to the configuration
+
+    Returns:
+        An initialized Catalog
+
+    Raises:
+        ValueError: Raises a ValueError in case properties are missing or malformed,
+            or if it could not determine the catalog based on the properties
+    """
     env = _ENV_CONFIG.get_catalog_config(name)
     conf = merge_config(env or {}, properties)
 
+    catalog_type: CatalogType | None
     if provided_catalog_type := conf.get(TYPE):
         catalog_type = CatalogType[provided_catalog_type.upper()]
     else:
-        if inferred_catalog_type := infer_catalog_type(conf):
-            catalog_type = inferred_catalog_type
-        else:
-            raise ValueError(f"Invalid configuration. Could not determine the catalog type: {properties}")
+        catalog_type = infer_catalog_type(name, conf)
 
     if catalog_type:
         return AVAILABLE_CATALOGS[catalog_type](name, conf)
diff --git a/python/pyiceberg/catalog/rest.py b/python/pyiceberg/catalog/rest.py
index 09f5793f3c..90db6691f9 100644
--- a/python/pyiceberg/catalog/rest.py
+++ b/python/pyiceberg/catalog/rest.py
@@ -31,6 +31,7 @@ from requests import HTTPError
 
 from pyiceberg import __version__
 from pyiceberg.catalog import (
+    URI,
     Catalog,
     Identifier,
     Properties,
@@ -38,7 +39,6 @@ from pyiceberg.catalog import (
 )
 from pyiceberg.exceptions import (
     AuthorizationExpiredError,
-    BadCredentialsError,
     BadRequestError,
     ForbiddenError,
     NamespaceAlreadyExistsError,
@@ -88,7 +88,9 @@ CLIENT_CREDENTIALS = "client_credentials"
 CREDENTIAL = "credential"
 GRANT_TYPE = "grant_type"
 SCOPE = "scope"
+TOKEN = "token"
 TOKEN_EXCHANGE = "urn:ietf:params:oauth:grant-type:token-exchange"
+SEMICOLON = ":"
 
 NAMESPACE_SEPARATOR = b"\x1F".decode("UTF-8")
 
@@ -181,9 +183,10 @@ class RestCatalog(Catalog):
             properties: Properties that are passed along to the configuration
         """
         self.properties = properties
-        self.uri = properties["uri"]
-        if credential := properties.get("credential"):
-            properties["token"] = self._fetch_access_token(credential)
+        self.uri = properties[URI]
+
+        if credential := properties.get(CREDENTIAL):
+            properties[TOKEN] = self._fetch_access_token(credential)
         super().__init__(name, **self._fetch_config(properties))
 
     def _check_valid_namespace_identifier(self, identifier: Union[str, Identifier]) -> Identifier:
@@ -225,7 +228,10 @@ class RestCatalog(Catalog):
         return url + endpoint.format(**kwargs)
 
     def _fetch_access_token(self, credential: str) -> str:
-        client_id, client_secret = credential.split(":")
+        if SEMICOLON in credential:
+            client_id, client_secret = credential.split(SEMICOLON)
+        else:
+            client_id, client_secret = None, credential
         data = {GRANT_TYPE: CLIENT_CREDENTIALS, CLIENT_ID: client_id, CLIENT_SECRET: client_secret, SCOPE: CATALOG_SCOPE}
         url = self.url(Endpoints.get_token, prefixed=False)
         # Uses application/x-www-form-urlencoded by default
@@ -233,7 +239,7 @@ class RestCatalog(Catalog):
         try:
             response.raise_for_status()
         except HTTPError as exc:
-            self._handle_non_200_response(exc, {400: OAuthError, 401: BadCredentialsError})
+            self._handle_non_200_response(exc, {400: OAuthError, 401: OAuthError})
 
         return TokenResponse(**response.json()).access_token
 
diff --git a/python/pyiceberg/cli/console.py b/python/pyiceberg/cli/console.py
index bac4ae6bc2..e783a4312f 100644
--- a/python/pyiceberg/cli/console.py
+++ b/python/pyiceberg/cli/console.py
@@ -65,14 +65,7 @@ def run(ctx: Context, catalog: str, verbose: bool, output: str, uri: Optional[st
         ctx.obj["output"] = JsonOutput(verbose=verbose)
 
     try:
-        try:
-            ctx.obj["catalog"] = load_catalog(catalog, **properties)
-        except ValueError as exc:
-            if not uri:
-                raise ValueError(
-                    f"URI missing, please provide using --uri, the config or environment variable PYICEBERG_CATALOG__{catalog.upper()}__URI"
-                ) from exc
-            raise exc
+        ctx.obj["catalog"] = load_catalog(catalog, **properties)
     except Exception as e:
         ctx.obj["output"].exception(e)
         ctx.exit(1)
diff --git a/python/pyiceberg/exceptions.py b/python/pyiceberg/exceptions.py
index 86e8a8102f..e44125cd9c 100644
--- a/python/pyiceberg/exceptions.py
+++ b/python/pyiceberg/exceptions.py
@@ -44,10 +44,6 @@ class RESTError(Exception):
     """Raises when there is an unknown response from the REST Catalog"""
 
 
-class BadCredentialsError(RESTError):
-    """Raises when providing invalid credentials"""
-
-
 class BadRequestError(RESTError):
     """Raises when an invalid request is being made"""
 
diff --git a/python/tests/catalog/test_rest.py b/python/tests/catalog/test_rest.py
index 270bad85f1..e91d8e3674 100644
--- a/python/tests/catalog/test_rest.py
+++ b/python/tests/catalog/test_rest.py
@@ -23,12 +23,10 @@ from requests_mock import Mocker
 from pyiceberg.catalog import PropertiesUpdateSummary, Table
 from pyiceberg.catalog.rest import RestCatalog
 from pyiceberg.exceptions import (
-    BadCredentialsError,
     NamespaceAlreadyExistsError,
     NoSuchNamespaceError,
     NoSuchTableError,
     OAuthError,
-    RESTError,
     TableAlreadyExistsError,
 )
 from pyiceberg.schema import Schema
@@ -96,34 +94,14 @@ def test_token_400(rest_mock: Mocker):
 
 
 def test_token_401(rest_mock: Mocker):
-    message = "Invalid client ID: abc"
+    message = "invalid_client"
     rest_mock.post(
         f"{TEST_URI}v1/oauth/tokens",
-        json={
-            "error": {
-                "message": message,
-                "type": "BadCredentialsException",
-                "code": 401,
-            }
-        },
-        status_code=401,
-    )
-
-    with pytest.raises(BadCredentialsError) as e:
-        RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS)
-    assert message in str(e.value)
-
-
-def test_token_401_oauth_error(rest_mock: Mocker):
-    """This test returns a OAuth error instead of an OpenAPI error"""
-    message = """RESTError 401: Received unexpected JSON Payload: {"error": "invalid_client", "error_description": "Invalid credentials"}, errors: value is not a valid dict"""
-    rest_mock.post(
-        f"{TEST_URI}v1/oauth/tokens",
-        json={"error": "invalid_client", "error_description": "Invalid credentials"},
+        json={"error": "invalid_client", "error_description": "Unknown or invalid client"},
         status_code=401,
     )
 
-    with pytest.raises(RESTError) as e:
+    with pytest.raises(OAuthError) as e:
         RestCatalog("rest", uri=TEST_URI, credential=TEST_CREDENTIALS)
     assert message in str(e.value)