You are viewing a plain-text version of this content. The canonical link to it (the "here" hyperlink on the original page) was not preserved in this copy.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/10/21 20:15:37 UTC

[iceberg] branch master updated: Python: Implement Schema.select (#5966)

This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new bac84bcc58 Python: Implement Schema.select (#5966)
bac84bcc58 is described below

commit bac84bcc580a392955eb520b9e3c316a16af62b7
Author: Fokko Driesprong <fo...@apache.org>
AuthorDate: Fri Oct 21 22:15:32 2022 +0200

    Python: Implement Schema.select (#5966)
---
 python/pyiceberg/schema.py  | 30 +++++++++++++++++-------------
 python/tests/test_schema.py | 41 +++++++++++++++++++++++++++++++----------
 2 files changed, 48 insertions(+), 23 deletions(-)

diff --git a/python/pyiceberg/schema.py b/python/pyiceberg/schema.py
index 7f7d096ace..82a9533392 100644
--- a/python/pyiceberg/schema.py
+++ b/python/pyiceberg/schema.py
@@ -210,7 +210,7 @@ class Schema(IcebergBaseModel):
 
         return self._lazy_id_to_accessor[field_id]
 
-    def select(self, names: List[str], case_sensitive: bool = True) -> "Schema":
+    def select(self, *names: str, case_sensitive: bool = True) -> "Schema":
         """Return a new schema instance pruned to a subset of columns
 
         Args:
@@ -219,20 +219,20 @@ class Schema(IcebergBaseModel):
 
         Returns:
             Schema: A new schema with pruned columns
+
+        Raises:
+            ValueError: If a column is selected that doesn't exist
         """
-        if case_sensitive:
-            return self._case_sensitive_select(schema=self, names=names)
-        return self._case_insensitive_select(schema=self, names=names)
 
-    @classmethod
-    def _case_sensitive_select(cls, schema: "Schema", names: List[str]):
-        # TODO: Add a PruneColumns schema visitor and use it here
-        raise NotImplementedError()
+        try:
+            if case_sensitive:
+                ids = {self._name_to_id[name] for name in names}
+            else:
+                ids = {self._lazy_name_to_id_lower[name.lower()] for name in names}
+        except KeyError as e:
+            raise ValueError(f"Could not find column: {e}") from e
 
-    @classmethod
-    def _case_insensitive_select(cls, schema: "Schema", names: List[str]):
-        # TODO: Add a PruneColumns schema visitor and use it here
-        raise NotImplementedError()
+        return prune_columns(self, ids)
 
 
 class SchemaVisitor(Generic[T], ABC):
@@ -800,7 +800,11 @@ class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]):
 
 def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema:
     result = visit(schema.as_struct(), _PruneColumnsVisitor(selected, select_full_types))
-    return Schema(*(result or StructType()).fields, schema_id=schema.schema_id, identifier_field_ids=schema.identifier_field_ids)
+    return Schema(
+        *(result or StructType()).fields,
+        schema_id=schema.schema_id,
+        identifier_field_ids=list(selected.intersection(schema.identifier_field_ids)),
+    )
 
 
 class _PruneColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py
index 5b64d95d6b..b2a2790438 100644
--- a/python/tests/test_schema.py
+++ b/python/tests/test_schema.py
@@ -438,7 +438,7 @@ def test_prune_columns_list(table_schema_nested: Schema):
             required=True,
         ),
         schema_id=1,
-        identifier_field_ids=[1],
+        identifier_field_ids=[],
     )
 
 
@@ -457,7 +457,7 @@ def test_prune_columns_list_full(table_schema_nested: Schema):
             required=True,
         ),
         schema_id=1,
-        identifier_field_ids=[1],
+        identifier_field_ids=[],
     )
 
 
@@ -479,7 +479,7 @@ def test_prune_columns_map(table_schema_nested: Schema):
             required=True,
         ),
         schema_id=1,
-        identifier_field_ids=[1],
+        identifier_field_ids=[],
     )
 
 
@@ -509,7 +509,7 @@ def test_prune_columns_map_full(table_schema_nested: Schema):
             required=True,
         ),
         schema_id=1,
-        identifier_field_ids=[1],
+        identifier_field_ids=[],
     )
 
 
@@ -531,7 +531,7 @@ def test_prune_columns_map_key(table_schema_nested: Schema):
             required=True,
         ),
         schema_id=1,
-        identifier_field_ids=[1],
+        identifier_field_ids=[],
     )
 
 
@@ -544,7 +544,7 @@ def test_prune_columns_struct(table_schema_nested: Schema):
             required=False,
         ),
         schema_id=1,
-        identifier_field_ids=[1],
+        identifier_field_ids=[],
     )
 
 
@@ -558,7 +558,7 @@ def test_prune_columns_struct_full(table_schema_nested: Schema):
             required=False,
         ),
         schema_id=1,
-        identifier_field_ids=[1],
+        identifier_field_ids=[],
     )
 
 
@@ -625,7 +625,7 @@ def test_prune_columns_struct_in_map():
             required=True,
         ),
         schema_id=1,
-        identifier_field_ids=[1],
+        identifier_field_ids=[],
     )
 
 
@@ -647,7 +647,7 @@ def test_prune_columns_struct_in_map_full():
             required=True,
         ),
         schema_id=1,
-        identifier_field_ids=[1],
+        identifier_field_ids=[],
     )
     assert prune_columns(table_schema_nested, {11}, True) == Schema(
         NestedField(
@@ -664,10 +664,31 @@ def test_prune_columns_struct_in_map_full():
             required=True,
         ),
         schema_id=1,
-        identifier_field_ids=[1],
+        identifier_field_ids=[],
     )
 
 
 def test_prune_columns_select_original_schema(table_schema_nested: Schema):
     ids = set(range(table_schema_nested.highest_field_id))
     assert prune_columns(table_schema_nested, ids, True) == table_schema_nested
+
+
+def test_schema_select(table_schema_nested: Schema):
+    assert table_schema_nested.select("bar", "baz") == Schema(
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+        schema_id=1,
+        identifier_field_ids=[],
+    )
+
+
+def test_schema_select_case_insensitive(table_schema_nested: Schema):
+    assert table_schema_nested.select("BAZ", case_sensitive=False) == Schema(
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), schema_id=1, identifier_field_ids=[]
+    )
+
+
+def test_schema_select_cant_be_found(table_schema_nested: Schema):
+    with pytest.raises(ValueError) as exc_info:
+        table_schema_nested.select("BAZ", case_sensitive=True)
+    assert "Could not find column: 'BAZ'" in str(exc_info.value)