You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/08/03 16:56:48 UTC

[iceberg] branch master updated: Python: Handle OAuthErrorResponse properly (#5416)

This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new be20b6f13a Python: Handle OAuthErrorResponse properly (#5416)
be20b6f13a is described below

commit be20b6f13a2da4c9e47a934688d1a90e57c5f70f
Author: Fokko Driesprong <fo...@apache.org>
AuthorDate: Wed Aug 3 18:56:42 2022 +0200

    Python: Handle OAuthErrorResponse properly (#5416)
---
 python/pyiceberg/catalog/rest.py  | 63 +++++++++++++++++++++++++--------------
 python/pyiceberg/exceptions.py    |  4 +++
 python/tests/catalog/test_rest.py | 13 ++++++++
 3 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/python/pyiceberg/catalog/rest.py b/python/pyiceberg/catalog/rest.py
index bcedbc5cd4..afac2dd6c7 100644
--- a/python/pyiceberg/catalog/rest.py
+++ b/python/pyiceberg/catalog/rest.py
@@ -18,6 +18,7 @@ from json import JSONDecodeError
 from typing import (
     Dict,
     List,
+    Literal,
     Optional,
     Set,
     Tuple,
@@ -40,6 +41,7 @@ from pyiceberg.exceptions import (
     NamespaceAlreadyExistsError,
     NoSuchNamespaceError,
     NoSuchTableError,
+    OAuthError,
     RESTError,
     ServerError,
     ServiceUnavailableError,
@@ -148,6 +150,14 @@ class ErrorResponse(IcebergBaseModel):
     error: ErrorResponseMessage = Field()
 
 
+class OAuthErrorResponse(IcebergBaseModel):
+    error: Literal[
+        "invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope"
+    ]
+    error_description: Optional[str]
+    error_uri: Optional[str]
+
+
 class RestCatalog(Catalog):
     token: str
     config: Properties
@@ -233,7 +243,7 @@ class RestCatalog(Catalog):
         try:
             response.raise_for_status()
         except HTTPError as exc:
-            self._handle_non_200_response(exc, {401: BadCredentialsError})
+            self._handle_non_200_response(exc, {400: OAuthError, 401: BadCredentialsError})
 
         return TokenResponse(**response.json()).access_token
 
@@ -255,39 +265,46 @@ class RestCatalog(Catalog):
         return {"namespace": identifier[:-1], "name": identifier[-1]}
 
     def _handle_non_200_response(self, exc: HTTPError, error_handler: Dict[int, Type[Exception]]):
-        try:
-            response = ErrorResponse(**exc.response.json())
-        except JSONDecodeError:
-            # In the case we don't have a proper response
-            response = ErrorResponse(
-                error=ErrorResponseMessage(
-                    message=f"Could not decode json payload: {exc.response.text}",
-                    type="RESTError",
-                    code=exc.response.status_code,
-                )
-            )
-
+        exception: Type[Exception]
         code = exc.response.status_code
         if code in error_handler:
-            raise error_handler[code](response.error.message) from exc
+            exception = error_handler[code]
         elif code == 400:
-            raise BadRequestError(response.error.message) from exc
+            exception = BadRequestError
         elif code == 401:
-            raise UnauthorizedError(response.error.message) from exc
+            exception = UnauthorizedError
         elif code == 403:
-            raise ForbiddenError(response.error.message) from exc
+            exception = ForbiddenError
         elif code == 422:
-            raise RESTError(response.error.message) from exc
+            exception = RESTError
         elif code == 419:
-            raise AuthorizationExpiredError(response.error.message)
+            exception = AuthorizationExpiredError
         elif code == 501:
-            raise NotImplementedError(response.error.message)
+            exception = NotImplementedError
         elif code == 503:
-            raise ServiceUnavailableError(response.error.message) from exc
+            exception = ServiceUnavailableError
         elif 500 <= code < 600:
-            raise ServerError(response.error.message) from exc
+            exception = ServerError
         else:
-            raise RESTError(response.error.message) from exc
+            exception = RESTError
+
+        try:
+            if exception == OAuthError:
+                # The OAuthErrorResponse has a different format
+                error = OAuthErrorResponse(**exc.response.json())
+                response = str(error.error)
+                if description := error.error_description:
+                    response += f": {description}"
+                if uri := error.error_uri:
+                    response += f" ({uri})"
+            else:
+                error = ErrorResponse(**exc.response.json()).error
+                response = f"{error.type}: {error.message}"
+        except JSONDecodeError:
+            # The response body is not valid JSON, so fall back to reporting the raw text
+            response = f"RESTError: Could not decode json payload: {exc.response.text}"
+
+        raise exception(response) from exc
 
     def create_table(
         self,
diff --git a/python/pyiceberg/exceptions.py b/python/pyiceberg/exceptions.py
index 432c25675f..b5db6a0e1b 100644
--- a/python/pyiceberg/exceptions.py
+++ b/python/pyiceberg/exceptions.py
@@ -70,3 +70,7 @@ class ForbiddenError(RESTError):
 
 class AuthorizationExpiredError(RESTError):
     """When the credentials are expired when performing an action on the REST catalog"""
+
+
+class OAuthError(RESTError):
+    """Raised when there is an error with the OAuth call"""
diff --git a/python/tests/catalog/test_rest.py b/python/tests/catalog/test_rest.py
index b3c7b661f2..11b1885d54 100644
--- a/python/tests/catalog/test_rest.py
+++ b/python/tests/catalog/test_rest.py
@@ -27,6 +27,7 @@ from pyiceberg.exceptions import (
     NamespaceAlreadyExistsError,
     NoSuchNamespaceError,
     NoSuchTableError,
+    OAuthError,
     TableAlreadyExistsError,
 )
 from pyiceberg.schema import Schema
@@ -76,6 +77,18 @@ def test_token_200(rest_mock: Mocker):
     assert RestCatalog("rest", {}, TEST_URI, TEST_CREDENTIALS).token == TEST_TOKEN
 
 
+def test_token_400(rest_mock: Mocker):
+    rest_mock.post(
+        f"{TEST_URI}v1/oauth/tokens",
+        json={"error": "invalid_client", "error_description": "Credentials for key invalid_key do not match"},
+        status_code=400,
+    )
+
+    with pytest.raises(OAuthError) as e:
+        RestCatalog("rest", {}, TEST_URI, credentials=TEST_CREDENTIALS)
+    assert str(e.value) == "invalid_client: Credentials for key invalid_key do not match"
+
+
 def test_token_401(rest_mock: Mocker):
     message = "Invalid client ID: abc"
     rest_mock.post(