You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text rendering.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/10/09 20:19:02 UTC

[iceberg] branch master updated: Python: Ability to Prune Columns (#5931)

This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new ef8ef491e3 Python: Ability to Prune Columns (#5931)
ef8ef491e3 is described below

commit ef8ef491e3d6748dbcf14e28acec800dab50a4b6
Author: Fokko Driesprong <fo...@apache.org>
AuthorDate: Sun Oct 9 22:18:54 2022 +0200

    Python: Ability to Prune Columns (#5931)
    
    This is required if we want to implement `.select(*columns: str)`
---
 python/pyiceberg/schema.py  | 140 ++++++++++++++++++++++++
 python/pyiceberg/types.py   |   4 +
 python/tests/test_schema.py | 258 +++++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 401 insertions(+), 1 deletion(-)

diff --git a/python/pyiceberg/schema.py b/python/pyiceberg/schema.py
index 7f2c3166ff..7f7d096ace 100644
--- a/python/pyiceberg/schema.py
+++ b/python/pyiceberg/schema.py
@@ -27,6 +27,7 @@ from typing import (
     List,
     Literal,
     Optional,
+    Set,
     Tuple,
     TypeVar,
     Union,
@@ -795,3 +796,142 @@ class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]):
 
     def primitive(self, primitive: PrimitiveType) -> PrimitiveType:
         return primitive
+
+
+def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema:
+    result = visit(schema.as_struct(), _PruneColumnsVisitor(selected, select_full_types))
+    return Schema(*(result or StructType()).fields, schema_id=schema.schema_id, identifier_field_ids=schema.identifier_field_ids)
+
+
+class _PruneColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
+    selected: Set[int]
+    select_full_types: bool
+
+    def __init__(self, selected: Set[int], select_full_types: bool):
+        self.selected = selected
+        self.select_full_types = select_full_types
+
+    def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]:
+        return struct_result
+
+    def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]:
+        fields = struct.fields
+        selected_fields = []
+        same_type = True
+
+        for idx, projected_type in enumerate(field_results):
+            field = fields[idx]
+            if field.field_type == projected_type:
+                selected_fields.append(field)
+            elif projected_type is not None:
+                same_type = False
+                # Type has changed, create a new field with the projected type
+                selected_fields.append(
+                    NestedField(
+                        field_id=field.field_id,
+                        name=field.name,
+                        field_type=projected_type,
+                        doc=field.doc,
+                        required=field.required,
+                    )
+                )
+
+        if selected_fields:
+            if len(selected_fields) == len(fields) and same_type is True:
+                # Nothing has changed, and we can return the original struct
+                return struct
+            else:
+                return StructType(*selected_fields)
+        return None
+
+    def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]:
+        if field.field_id in self.selected:
+            if self.select_full_types:
+                return field.field_type
+            elif field.field_type.is_struct:
+                return self._project_selected_struct(field_result)
+            else:
+                if not field.field_type.is_primitive:
+                    raise ValueError(
+                        f"Cannot explicitly project List or Map types, {field.field_id}:{field.name} of type {field.field_type} was selected"
+                    )
+                # Selected non-struct field
+                return field.field_type
+        elif field_result is not None:
+            # This field wasn't selected but a subfield was so include that
+            return field_result
+        else:
+            return None
+
+    def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
+        if list_type.element_id in self.selected:
+            if self.select_full_types:
+                return list_type
+            elif list_type.element_type and list_type.element_type.is_struct:
+                projected_struct = self._project_selected_struct(element_result)
+                return self._project_list(list_type, projected_struct)
+            else:
+                if not list_type.element_type.is_primitive:
+                    raise ValueError(
+                        f"Cannot explicitly project List or Map types, {list_type.element_id} of type {list_type.element_type} was selected"
+                    )
+                return list_type
+        elif element_result is not None:
+            return self._project_list(list_type, element_result)
+        else:
+            return None
+
+    def map(
+        self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
+    ) -> Optional[IcebergType]:
+        if map_type.value_id in self.selected:
+            if self.select_full_types:
+                return map_type
+            elif map_type.value_type and map_type.value_type.is_struct:
+                projected_struct = self._project_selected_struct(value_result)
+                return self._project_map(map_type, projected_struct)
+            if not map_type.value_type.is_primitive:
+                raise ValueError(
+                    f"Cannot explicitly project List or Map types, Map value {map_type.value_id} of type {map_type.value_type} was selected"
+                )
+            return map_type
+        elif value_result is not None:
+            return self._project_map(map_type, value_result)
+        elif map_type.key_id in self.selected:
+            return map_type
+        return None
+
+    def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
+        return None
+
+    @staticmethod
+    def _project_selected_struct(projected_field: Optional[IcebergType]) -> StructType:
+        if projected_field and not isinstance(projected_field, StructType):
+            raise ValueError("Expected a struct")
+
+        if projected_field is None:
+            return StructType()
+        else:
+            return projected_field
+
+    @staticmethod
+    def _project_list(list_type: ListType, element_result: IcebergType):
+        if list_type.element_type == element_result:
+            return list_type
+        else:
+            return ListType(
+                element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required
+            )
+
+    @staticmethod
+    def _project_map(map_type: MapType, value_result: IcebergType):
+        if map_type.value_type == value_result:
+            return map_type
+        else:
+            return MapType(
+                key_id=map_type.key_id,
+                value_id=map_type.value_id,
+                key_type=map_type.key_type,
+                value_type=value_result,
+                value_required=map_type.value_required,
+            )
diff --git a/python/pyiceberg/types.py b/python/pyiceberg/types.py
index 2f4da429b7..86a234a4e5 100644
--- a/python/pyiceberg/types.py
+++ b/python/pyiceberg/types.py
@@ -97,6 +97,10 @@ class IcebergType(IcebergBaseModel, Singleton):
     def is_primitive(self) -> bool:
         return isinstance(self, PrimitiveType)
 
+    @property
+    def is_struct(self) -> bool:
+        return isinstance(self, StructType)
+
 
 class PrimitiveType(IcebergType):
     """Base class for all Iceberg Primitive Types"""
diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py
index 9fbebb0902..0017760ed2 100644
--- a/python/tests/test_schema.py
+++ b/python/tests/test_schema.py
@@ -23,7 +23,7 @@ import pytest
 from pyiceberg import schema
 from pyiceberg.expressions.base import Accessor
 from pyiceberg.files import StructProtocol
-from pyiceberg.schema import Schema, build_position_accessors
+from pyiceberg.schema import Schema, build_position_accessors, prune_columns
 from pyiceberg.typedef import EMPTY_DICT
 from pyiceberg.types import (
     BooleanType,
@@ -415,3 +415,259 @@ def test_deserialize_schema(table_schema_simple: Schema):
     )
     expected = table_schema_simple
     assert actual == expected
+
+
+def test_prune_columns_string(table_schema_nested: Schema):
+    assert prune_columns(table_schema_nested, {1}, False) == Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=False), schema_id=1, identifier_field_ids=[1]
+    )
+
+
+def test_prune_columns_string_full(table_schema_nested: Schema):
+    assert prune_columns(table_schema_nested, {1}, True) == Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=False), schema_id=1, identifier_field_ids=[1]
+    )
+
+
+def test_prune_columns_list(table_schema_nested: Schema):
+    assert prune_columns(table_schema_nested, {5}, False) == Schema(
+        NestedField(
+            field_id=4,
+            name="qux",
+            field_type=ListType(type="list", element_id=5, element_type=StringType(), element_required=True),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+
+
+def test_prune_columns_list_itself(table_schema_nested: Schema):
+    with pytest.raises(ValueError) as exc_info:
+        assert prune_columns(table_schema_nested, {4}, False)
+    assert "Cannot explicitly project List or Map types, 4:qux of type list<string> was selected" in str(exc_info.value)
+
+
+def test_prune_columns_list_full(table_schema_nested: Schema):
+    assert prune_columns(table_schema_nested, {5}, True) == Schema(
+        NestedField(
+            field_id=4,
+            name="qux",
+            field_type=ListType(type="list", element_id=5, element_type=StringType(), element_required=True),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+
+
+def test_prune_columns_map(table_schema_nested: Schema):
+    assert prune_columns(table_schema_nested, {9}, False) == Schema(
+        NestedField(
+            field_id=6,
+            name="quux",
+            field_type=MapType(
+                type="map",
+                key_id=7,
+                key_type=StringType(),
+                value_id=8,
+                value_type=MapType(
+                    type="map", key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True
+                ),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+
+
+def test_prune_columns_map_itself(table_schema_nested: Schema):
+    with pytest.raises(ValueError) as exc_info:
+        assert prune_columns(table_schema_nested, {6}, False)
+    assert "Cannot explicitly project List or Map types, 6:quux of type map<string, map<string, int>> was selected" in str(
+        exc_info.value
+    )
+
+
+def test_prune_columns_map_full(table_schema_nested: Schema):
+    assert prune_columns(table_schema_nested, {9}, True) == Schema(
+        NestedField(
+            field_id=6,
+            name="quux",
+            field_type=MapType(
+                type="map",
+                key_id=7,
+                key_type=StringType(),
+                value_id=8,
+                value_type=MapType(
+                    type="map", key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True
+                ),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+
+
+def test_prune_columns_map_key(table_schema_nested: Schema):
+    assert prune_columns(table_schema_nested, {10}, False) == Schema(
+        NestedField(
+            field_id=6,
+            name="quux",
+            field_type=MapType(
+                type="map",
+                key_id=7,
+                key_type=StringType(),
+                value_id=8,
+                value_type=MapType(
+                    type="map", key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True
+                ),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+
+
+def test_prune_columns_struct(table_schema_nested: Schema):
+    assert prune_columns(table_schema_nested, {16}, False) == Schema(
+        NestedField(
+            field_id=15,
+            name="person",
+            field_type=StructType(NestedField(field_id=16, name="name", field_type=StringType(), required=False)),
+            required=False,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+
+
+def test_prune_columns_struct_full(table_schema_nested: Schema):
+    actual = prune_columns(table_schema_nested, {16}, True)
+    assert actual == Schema(
+        NestedField(
+            field_id=15,
+            name="person",
+            field_type=StructType(NestedField(field_id=16, name="name", field_type=StringType(), required=False)),
+            required=False,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+
+
+def test_prune_columns_empty_struct():
+    schema_empty_struct = Schema(
+        NestedField(
+            field_id=15,
+            name="person",
+            field_type=StructType(),
+            required=False,
+        )
+    )
+    assert prune_columns(schema_empty_struct, {15}, False) == Schema(
+        NestedField(field_id=15, name="person", field_type=StructType(), required=False), schema_id=0, identifier_field_ids=[]
+    )
+
+
+def test_prune_columns_empty_struct_full():
+    schema_empty_struct = Schema(
+        NestedField(
+            field_id=15,
+            name="person",
+            field_type=StructType(),
+            required=False,
+        )
+    )
+    assert prune_columns(schema_empty_struct, {15}, True) == Schema(
+        NestedField(field_id=15, name="person", field_type=StructType(), required=False), schema_id=0, identifier_field_ids=[]
+    )
+
+
+def test_prune_columns_struct_in_map():
+    table_schema_nested = Schema(
+        NestedField(
+            field_id=6,
+            name="id_to_person",
+            field_type=MapType(
+                key_id=7,
+                key_type=IntegerType(),
+                value_id=8,
+                value_type=StructType(
+                    NestedField(field_id=10, name="name", field_type=StringType(), required=False),
+                    NestedField(field_id=11, name="age", field_type=IntegerType(), required=True),
+                ),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+    assert prune_columns(table_schema_nested, {11}, False) == Schema(
+        NestedField(
+            field_id=6,
+            name="id_to_person",
+            field_type=MapType(
+                type="map",
+                key_id=7,
+                key_type=IntegerType(),
+                value_id=8,
+                value_type=StructType(NestedField(field_id=11, name="age", field_type=IntegerType(), required=True)),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+
+
+def test_prune_columns_struct_in_map_full():
+    table_schema_nested = Schema(
+        NestedField(
+            field_id=6,
+            name="id_to_person",
+            field_type=MapType(
+                key_id=7,
+                key_type=IntegerType(),
+                value_id=8,
+                value_type=StructType(
+                    NestedField(field_id=10, name="name", field_type=StringType(), required=False),
+                    NestedField(field_id=11, name="age", field_type=IntegerType(), required=True),
+                ),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+    assert prune_columns(table_schema_nested, {11}, True) == Schema(
+        NestedField(
+            field_id=6,
+            name="id_to_person",
+            field_type=MapType(
+                type="map",
+                key_id=7,
+                key_type=IntegerType(),
+                value_id=8,
+                value_type=StructType(NestedField(field_id=11, name="age", field_type=IntegerType(), required=True)),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+
+
+def test_prune_columns_select_original_schema(table_schema_nested: Schema):
+    ids = set(range(table_schema_nested.highest_field_id))
+    assert prune_columns(table_schema_nested, ids, True) == table_schema_nested