You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/10/09 20:19:02 UTC
[iceberg] branch master updated: Python: Ability to Prune Columns (#5931)
This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new ef8ef491e3 Python: Ability to Prune Columns (#5931)
ef8ef491e3 is described below
commit ef8ef491e3d6748dbcf14e28acec800dab50a4b6
Author: Fokko Driesprong <fo...@apache.org>
AuthorDate: Sun Oct 9 22:18:54 2022 +0200
Python: Ability to Prune Columns (#5931)
This is required if we want to implement `.select(*columns: str)`
---
python/pyiceberg/schema.py | 140 ++++++++++++++++++++++++
python/pyiceberg/types.py | 4 +
python/tests/test_schema.py | 258 +++++++++++++++++++++++++++++++++++++++++++-
3 files changed, 401 insertions(+), 1 deletion(-)
diff --git a/python/pyiceberg/schema.py b/python/pyiceberg/schema.py
index 7f2c3166ff..7f7d096ace 100644
--- a/python/pyiceberg/schema.py
+++ b/python/pyiceberg/schema.py
@@ -27,6 +27,7 @@ from typing import (
List,
Literal,
Optional,
+ Set,
Tuple,
TypeVar,
Union,
@@ -795,3 +796,142 @@ class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]):
def primitive(self, primitive: PrimitiveType) -> PrimitiveType:
return primitive
+
+
+def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema:
+ result = visit(schema.as_struct(), _PruneColumnsVisitor(selected, select_full_types))
+ return Schema(*(result or StructType()).fields, schema_id=schema.schema_id, identifier_field_ids=schema.identifier_field_ids)
+
+
+class _PruneColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
+ selected: Set[int]
+ select_full_types: bool
+
+ def __init__(self, selected: Set[int], select_full_types: bool):
+ self.selected = selected
+ self.select_full_types = select_full_types
+
+ def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]:
+ return struct_result
+
+ def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]:
+ fields = struct.fields
+ selected_fields = []
+ same_type = True
+
+ for idx, projected_type in enumerate(field_results):
+ field = fields[idx]
+ if field.field_type == projected_type:
+ selected_fields.append(field)
+ elif projected_type is not None:
+ same_type = False
+ # Type has changed, create a new field with the projected type
+ selected_fields.append(
+ NestedField(
+ field_id=field.field_id,
+ name=field.name,
+ field_type=projected_type,
+ doc=field.doc,
+ required=field.required,
+ )
+ )
+
+ if selected_fields:
+ if len(selected_fields) == len(fields) and same_type is True:
+ # Nothing has changed, and we can return the original struct
+ return struct
+ else:
+ return StructType(*selected_fields)
+ return None
+
+ def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]:
+ if field.field_id in self.selected:
+ if self.select_full_types:
+ return field.field_type
+ elif field.field_type.is_struct:
+ return self._project_selected_struct(field_result)
+ else:
+ if not field.field_type.is_primitive:
+ raise ValueError(
+ f"Cannot explicitly project List or Map types, {field.field_id}:{field.name} of type {field.field_type} was selected"
+ )
+ # Selected non-struct field
+ return field.field_type
+ elif field_result is not None:
+ # This field wasn't selected but a subfield was so include that
+ return field_result
+ else:
+ return None
+
+ def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
+ if list_type.element_id in self.selected:
+ if self.select_full_types:
+ return list_type
+ elif list_type.element_type and list_type.element_type.is_struct:
+ projected_struct = self._project_selected_struct(element_result)
+ return self._project_list(list_type, projected_struct)
+ else:
+ if not list_type.element_type.is_primitive:
+ raise ValueError(
+ f"Cannot explicitly project List or Map types, {list_type.element_id} of type {list_type.element_type} was selected"
+ )
+ return list_type
+ elif element_result is not None:
+ return self._project_list(list_type, element_result)
+ else:
+ return None
+
+ def map(
+ self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
+ ) -> Optional[IcebergType]:
+ if map_type.value_id in self.selected:
+ if self.select_full_types:
+ return map_type
+ elif map_type.value_type and map_type.value_type.is_struct:
+ projected_struct = self._project_selected_struct(value_result)
+ return self._project_map(map_type, projected_struct)
+ if not map_type.value_type.is_primitive:
+ raise ValueError(
+ f"Cannot explicitly project List or Map types, Map value {map_type.value_id} of type {map_type.value_type} was selected"
+ )
+ return map_type
+ elif value_result is not None:
+ return self._project_map(map_type, value_result)
+ elif map_type.key_id in self.selected:
+ return map_type
+ return None
+
+ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
+ return None
+
+ @staticmethod
+ def _project_selected_struct(projected_field: Optional[IcebergType]) -> StructType:
+ if projected_field and not isinstance(projected_field, StructType):
+ raise ValueError("Expected a struct")
+
+ if projected_field is None:
+ return StructType()
+ else:
+ return projected_field
+
+ @staticmethod
+ def _project_list(list_type: ListType, element_result: IcebergType):
+ if list_type.element_type == element_result:
+ return list_type
+ else:
+ return ListType(
+ element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required
+ )
+
+ @staticmethod
+ def _project_map(map_type: MapType, value_result: IcebergType):
+ if map_type.value_type == value_result:
+ return map_type
+ else:
+ return MapType(
+ key_id=map_type.key_id,
+ value_id=map_type.value_id,
+ key_type=map_type.key_type,
+ value_type=value_result,
+ value_required=map_type.value_required,
+ )
diff --git a/python/pyiceberg/types.py b/python/pyiceberg/types.py
index 2f4da429b7..86a234a4e5 100644
--- a/python/pyiceberg/types.py
+++ b/python/pyiceberg/types.py
@@ -97,6 +97,10 @@ class IcebergType(IcebergBaseModel, Singleton):
def is_primitive(self) -> bool:
return isinstance(self, PrimitiveType)
+    @property
+    def is_struct(self) -> bool:
+        """Whether this type is an instance of StructType."""
+        return isinstance(self, StructType)
+
class PrimitiveType(IcebergType):
"""Base class for all Iceberg Primitive Types"""
diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py
index 9fbebb0902..0017760ed2 100644
--- a/python/tests/test_schema.py
+++ b/python/tests/test_schema.py
@@ -23,7 +23,7 @@ import pytest
from pyiceberg import schema
from pyiceberg.expressions.base import Accessor
from pyiceberg.files import StructProtocol
-from pyiceberg.schema import Schema, build_position_accessors
+from pyiceberg.schema import Schema, build_position_accessors, prune_columns
from pyiceberg.typedef import EMPTY_DICT
from pyiceberg.types import (
BooleanType,
@@ -415,3 +415,259 @@ def test_deserialize_schema(table_schema_simple: Schema):
)
expected = table_schema_simple
assert actual == expected
+
+
+def test_prune_columns_string(table_schema_nested: Schema):
+ assert prune_columns(table_schema_nested, {1}, False) == Schema(
+ NestedField(field_id=1, name="foo", field_type=StringType(), required=False), schema_id=1, identifier_field_ids=[1]
+ )
+
+
+def test_prune_columns_string_full(table_schema_nested: Schema):
+ assert prune_columns(table_schema_nested, {1}, True) == Schema(
+ NestedField(field_id=1, name="foo", field_type=StringType(), required=False), schema_id=1, identifier_field_ids=[1]
+ )
+
+
+def test_prune_columns_list(table_schema_nested: Schema):
+ assert prune_columns(table_schema_nested, {5}, False) == Schema(
+ NestedField(
+ field_id=4,
+ name="qux",
+ field_type=ListType(type="list", element_id=5, element_type=StringType(), element_required=True),
+ required=True,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+
+
+def test_prune_columns_list_itself(table_schema_nested: Schema):
+ with pytest.raises(ValueError) as exc_info:
+ assert prune_columns(table_schema_nested, {4}, False)
+ assert "Cannot explicitly project List or Map types, 4:qux of type list<string> was selected" in str(exc_info.value)
+
+
+def test_prune_columns_list_full(table_schema_nested: Schema):
+ assert prune_columns(table_schema_nested, {5}, True) == Schema(
+ NestedField(
+ field_id=4,
+ name="qux",
+ field_type=ListType(type="list", element_id=5, element_type=StringType(), element_required=True),
+ required=True,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+
+
+def test_prune_columns_map(table_schema_nested: Schema):
+ assert prune_columns(table_schema_nested, {9}, False) == Schema(
+ NestedField(
+ field_id=6,
+ name="quux",
+ field_type=MapType(
+ type="map",
+ key_id=7,
+ key_type=StringType(),
+ value_id=8,
+ value_type=MapType(
+ type="map", key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True
+ ),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+
+
+def test_prune_columns_map_itself(table_schema_nested: Schema):
+ with pytest.raises(ValueError) as exc_info:
+ assert prune_columns(table_schema_nested, {6}, False)
+ assert "Cannot explicitly project List or Map types, 6:quux of type map<string, map<string, int>> was selected" in str(
+ exc_info.value
+ )
+
+
+def test_prune_columns_map_full(table_schema_nested: Schema):
+ assert prune_columns(table_schema_nested, {9}, True) == Schema(
+ NestedField(
+ field_id=6,
+ name="quux",
+ field_type=MapType(
+ type="map",
+ key_id=7,
+ key_type=StringType(),
+ value_id=8,
+ value_type=MapType(
+ type="map", key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True
+ ),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+
+
+def test_prune_columns_map_key(table_schema_nested: Schema):
+ assert prune_columns(table_schema_nested, {10}, False) == Schema(
+ NestedField(
+ field_id=6,
+ name="quux",
+ field_type=MapType(
+ type="map",
+ key_id=7,
+ key_type=StringType(),
+ value_id=8,
+ value_type=MapType(
+ type="map", key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True
+ ),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+
+
+def test_prune_columns_struct(table_schema_nested: Schema):
+ assert prune_columns(table_schema_nested, {16}, False) == Schema(
+ NestedField(
+ field_id=15,
+ name="person",
+ field_type=StructType(NestedField(field_id=16, name="name", field_type=StringType(), required=False)),
+ required=False,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+
+
+def test_prune_columns_struct_full(table_schema_nested: Schema):
+ actual = prune_columns(table_schema_nested, {16}, True)
+ assert actual == Schema(
+ NestedField(
+ field_id=15,
+ name="person",
+ field_type=StructType(NestedField(field_id=16, name="name", field_type=StringType(), required=False)),
+ required=False,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+
+
+def test_prune_columns_empty_struct():
+ schema_empty_struct = Schema(
+ NestedField(
+ field_id=15,
+ name="person",
+ field_type=StructType(),
+ required=False,
+ )
+ )
+ assert prune_columns(schema_empty_struct, {15}, False) == Schema(
+ NestedField(field_id=15, name="person", field_type=StructType(), required=False), schema_id=0, identifier_field_ids=[]
+ )
+
+
+def test_prune_columns_empty_struct_full():
+ schema_empty_struct = Schema(
+ NestedField(
+ field_id=15,
+ name="person",
+ field_type=StructType(),
+ required=False,
+ )
+ )
+ assert prune_columns(schema_empty_struct, {15}, True) == Schema(
+ NestedField(field_id=15, name="person", field_type=StructType(), required=False), schema_id=0, identifier_field_ids=[]
+ )
+
+
+def test_prune_columns_struct_in_map():
+ table_schema_nested = Schema(
+ NestedField(
+ field_id=6,
+ name="id_to_person",
+ field_type=MapType(
+ key_id=7,
+ key_type=IntegerType(),
+ value_id=8,
+ value_type=StructType(
+ NestedField(field_id=10, name="name", field_type=StringType(), required=False),
+ NestedField(field_id=11, name="age", field_type=IntegerType(), required=True),
+ ),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+ assert prune_columns(table_schema_nested, {11}, False) == Schema(
+ NestedField(
+ field_id=6,
+ name="id_to_person",
+ field_type=MapType(
+ type="map",
+ key_id=7,
+ key_type=IntegerType(),
+ value_id=8,
+ value_type=StructType(NestedField(field_id=11, name="age", field_type=IntegerType(), required=True)),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+
+
+def test_prune_columns_struct_in_map_full():
+ table_schema_nested = Schema(
+ NestedField(
+ field_id=6,
+ name="id_to_person",
+ field_type=MapType(
+ key_id=7,
+ key_type=IntegerType(),
+ value_id=8,
+ value_type=StructType(
+ NestedField(field_id=10, name="name", field_type=StringType(), required=False),
+ NestedField(field_id=11, name="age", field_type=IntegerType(), required=True),
+ ),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+ assert prune_columns(table_schema_nested, {11}, True) == Schema(
+ NestedField(
+ field_id=6,
+ name="id_to_person",
+ field_type=MapType(
+ type="map",
+ key_id=7,
+ key_type=IntegerType(),
+ value_id=8,
+ value_type=StructType(NestedField(field_id=11, name="age", field_type=IntegerType(), required=True)),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+
+
+def test_prune_columns_select_original_schema(table_schema_nested: Schema):
+ ids = set(range(table_schema_nested.highest_field_id))
+ assert prune_columns(table_schema_nested, ids, True) == table_schema_nested