You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/10/21 20:15:37 UTC
[iceberg] branch master updated: Python: Implement Schema.select (#5966)
This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new bac84bcc58 Python: Implement Schema.select (#5966)
bac84bcc58 is described below
commit bac84bcc580a392955eb520b9e3c316a16af62b7
Author: Fokko Driesprong <fo...@apache.org>
AuthorDate: Fri Oct 21 22:15:32 2022 +0200
Python: Implement Schema.select (#5966)
---
python/pyiceberg/schema.py | 30 +++++++++++++++++-------------
python/tests/test_schema.py | 41 +++++++++++++++++++++++++++++++----------
2 files changed, 48 insertions(+), 23 deletions(-)
diff --git a/python/pyiceberg/schema.py b/python/pyiceberg/schema.py
index 7f7d096ace..82a9533392 100644
--- a/python/pyiceberg/schema.py
+++ b/python/pyiceberg/schema.py
@@ -210,7 +210,7 @@ class Schema(IcebergBaseModel):
return self._lazy_id_to_accessor[field_id]
- def select(self, names: List[str], case_sensitive: bool = True) -> "Schema":
+ def select(self, *names: str, case_sensitive: bool = True) -> "Schema":
"""Return a new schema instance pruned to a subset of columns
Args:
@@ -219,20 +219,20 @@ class Schema(IcebergBaseModel):
Returns:
Schema: A new schema with pruned columns
+
+ Raises:
+ ValueError: If a column is selected that doesn't exist
"""
- if case_sensitive:
- return self._case_sensitive_select(schema=self, names=names)
- return self._case_insensitive_select(schema=self, names=names)
- @classmethod
- def _case_sensitive_select(cls, schema: "Schema", names: List[str]):
- # TODO: Add a PruneColumns schema visitor and use it here
- raise NotImplementedError()
+ try:
+ if case_sensitive:
+ ids = {self._name_to_id[name] for name in names}
+ else:
+ ids = {self._lazy_name_to_id_lower[name.lower()] for name in names}
+ except KeyError as e:
+ raise ValueError(f"Could not find column: {e}") from e
- @classmethod
- def _case_insensitive_select(cls, schema: "Schema", names: List[str]):
- # TODO: Add a PruneColumns schema visitor and use it here
- raise NotImplementedError()
+ return prune_columns(self, ids)
class SchemaVisitor(Generic[T], ABC):
@@ -800,7 +800,11 @@ class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]):
def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema:
result = visit(schema.as_struct(), _PruneColumnsVisitor(selected, select_full_types))
- return Schema(*(result or StructType()).fields, schema_id=schema.schema_id, identifier_field_ids=schema.identifier_field_ids)
+ return Schema(
+ *(result or StructType()).fields,
+ schema_id=schema.schema_id,
+ identifier_field_ids=list(selected.intersection(schema.identifier_field_ids)),
+ )
class _PruneColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py
index 5b64d95d6b..b2a2790438 100644
--- a/python/tests/test_schema.py
+++ b/python/tests/test_schema.py
@@ -438,7 +438,7 @@ def test_prune_columns_list(table_schema_nested: Schema):
required=True,
),
schema_id=1,
- identifier_field_ids=[1],
+ identifier_field_ids=[],
)
@@ -457,7 +457,7 @@ def test_prune_columns_list_full(table_schema_nested: Schema):
required=True,
),
schema_id=1,
- identifier_field_ids=[1],
+ identifier_field_ids=[],
)
@@ -479,7 +479,7 @@ def test_prune_columns_map(table_schema_nested: Schema):
required=True,
),
schema_id=1,
- identifier_field_ids=[1],
+ identifier_field_ids=[],
)
@@ -509,7 +509,7 @@ def test_prune_columns_map_full(table_schema_nested: Schema):
required=True,
),
schema_id=1,
- identifier_field_ids=[1],
+ identifier_field_ids=[],
)
@@ -531,7 +531,7 @@ def test_prune_columns_map_key(table_schema_nested: Schema):
required=True,
),
schema_id=1,
- identifier_field_ids=[1],
+ identifier_field_ids=[],
)
@@ -544,7 +544,7 @@ def test_prune_columns_struct(table_schema_nested: Schema):
required=False,
),
schema_id=1,
- identifier_field_ids=[1],
+ identifier_field_ids=[],
)
@@ -558,7 +558,7 @@ def test_prune_columns_struct_full(table_schema_nested: Schema):
required=False,
),
schema_id=1,
- identifier_field_ids=[1],
+ identifier_field_ids=[],
)
@@ -625,7 +625,7 @@ def test_prune_columns_struct_in_map():
required=True,
),
schema_id=1,
- identifier_field_ids=[1],
+ identifier_field_ids=[],
)
@@ -647,7 +647,7 @@ def test_prune_columns_struct_in_map_full():
required=True,
),
schema_id=1,
- identifier_field_ids=[1],
+ identifier_field_ids=[],
)
assert prune_columns(table_schema_nested, {11}, True) == Schema(
NestedField(
@@ -664,10 +664,31 @@ def test_prune_columns_struct_in_map_full():
required=True,
),
schema_id=1,
- identifier_field_ids=[1],
+ identifier_field_ids=[],
)
def test_prune_columns_select_original_schema(table_schema_nested: Schema):
ids = set(range(table_schema_nested.highest_field_id))
assert prune_columns(table_schema_nested, ids, True) == table_schema_nested
+
+
+def test_schema_select(table_schema_nested: Schema):
+ assert table_schema_nested.select("bar", "baz") == Schema(
+ NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+ NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+ schema_id=1,
+ identifier_field_ids=[],
+ )
+
+
+def test_schema_select_case_insensitive(table_schema_nested: Schema):
+ assert table_schema_nested.select("BAZ", case_sensitive=False) == Schema(
+ NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), schema_id=1, identifier_field_ids=[]
+ )
+
+
+def test_schema_select_cant_be_found(table_schema_nested: Schema):
+ with pytest.raises(ValueError) as exc_info:
+ table_schema_nested.select("BAZ", case_sensitive=True)
+ assert "Could not find column: 'BAZ'" in str(exc_info.value)