You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by fo...@apache.org on 2022/12/30 22:07:31 UTC
[iceberg] branch master updated: Python: Projection by Field ID (#6437)
This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new a1a196da85 Python: Projection by Field ID (#6437)
a1a196da85 is described below
commit a1a196da850a65050232ab10a929ac9eb66560b5
Author: Fokko Driesprong <fo...@apache.org>
AuthorDate: Fri Dec 30 23:07:24 2022 +0100
Python: Projection by Field ID (#6437)
* Python: Projection by Field ID
instead of name
* Comments
* Add tests
* WIP
* Support nested structures
* Cleanup
* Add support for filtering on a column that's not projected
* Python: Add SchemaWithPartnerVisitor, update pyarrow reads. (#339)
* Comments
* Fix nested list projection.
* Update based on comments
* Remove duplicate fixtures
* Cleanup
Co-authored-by: Ryan Blue <bl...@apache.org>
---
python/pyiceberg/avro/resolver.py | 92 +----
python/pyiceberg/exceptions.py | 4 +
python/pyiceberg/expressions/visitors.py | 89 +++++
python/pyiceberg/io/pyarrow.py | 217 +++++++++++-
python/pyiceberg/schema.py | 216 +++++++++++-
python/pyiceberg/table/__init__.py | 47 +--
python/tests/avro/test_decoder.py | 5 +-
python/tests/avro/test_resolver.py | 58 ++--
python/tests/io/test_pyarrow.py | 570 ++++++++++++++++++++++++++++++-
9 files changed, 1138 insertions(+), 160 deletions(-)
diff --git a/python/pyiceberg/avro/resolver.py b/python/pyiceberg/avro/resolver.py
index ca559a2998..5542e8de3c 100644
--- a/python/pyiceberg/avro/resolver.py
+++ b/python/pyiceberg/avro/resolver.py
@@ -31,27 +31,19 @@ from pyiceberg.avro.reader import (
Reader,
StructReader,
)
-from pyiceberg.schema import Schema, visit
+from pyiceberg.exceptions import ResolveError
+from pyiceberg.schema import Schema, promote, visit
from pyiceberg.types import (
- BinaryType,
- DecimalType,
DoubleType,
FloatType,
IcebergType,
- IntegerType,
ListType,
- LongType,
MapType,
PrimitiveType,
- StringType,
StructType,
)
-class ResolveException(Exception):
- pass
-
-
@singledispatch
def resolve(file_schema: Union[Schema, IcebergType], read_schema: Union[Schema, IcebergType]) -> Reader:
"""This resolves the file and read schema
@@ -79,7 +71,7 @@ def _(file_struct: StructType, read_struct: IcebergType) -> Reader:
"""Iterates over the file schema, and checks if the field is in the read schema"""
if not isinstance(read_struct, StructType):
- raise ResolveException(f"File/read schema are not aligned for {file_struct}, got {read_struct}")
+ raise ResolveError(f"File/read schema are not aligned for {file_struct}, got {read_struct}")
results: List[Tuple[Optional[int], Reader]] = []
read_fields = {field.field_id: (pos, field) for pos, field in enumerate(read_struct.fields)}
@@ -98,7 +90,7 @@ def _(file_struct: StructType, read_struct: IcebergType) -> Reader:
for pos, read_field in enumerate(read_struct.fields):
if read_field.field_id not in file_fields:
if read_field.required:
- raise ResolveException(f"{read_field} is non-optional, and not part of the file schema")
+ raise ResolveError(f"{read_field} is non-optional, and not part of the file schema")
# Just set the new field to None
results.append((pos, NoneReader()))
@@ -108,7 +100,7 @@ def _(file_struct: StructType, read_struct: IcebergType) -> Reader:
@resolve.register(ListType)
def _(file_list: ListType, read_list: IcebergType) -> Reader:
if not isinstance(read_list, ListType):
- raise ResolveException(f"File/read schema are not aligned for {file_list}, got {read_list}")
+ raise ResolveError(f"File/read schema are not aligned for {file_list}, got {read_list}")
element_reader = resolve(file_list.element_type, read_list.element_type)
return ListReader(element_reader)
@@ -116,79 +108,29 @@ def _(file_list: ListType, read_list: IcebergType) -> Reader:
@resolve.register(MapType)
def _(file_map: MapType, read_map: IcebergType) -> Reader:
if not isinstance(read_map, MapType):
- raise ResolveException(f"File/read schema are not aligned for {file_map}, got {read_map}")
+ raise ResolveError(f"File/read schema are not aligned for {file_map}, got {read_map}")
key_reader = resolve(file_map.key_type, read_map.key_type)
value_reader = resolve(file_map.value_type, read_map.value_type)
return MapReader(key_reader, value_reader)
+@resolve.register(FloatType)
+def _(file_type: PrimitiveType, read_type: IcebergType) -> Reader:
+ """This is a special case, when we need to adhere to the bytes written"""
+ if isinstance(read_type, DoubleType):
+ return visit(file_type, ConstructReader())
+ else:
+ raise ResolveError(f"Cannot promote an float to {read_type}")
+
+
@resolve.register(PrimitiveType)
def _(file_type: PrimitiveType, read_type: IcebergType) -> Reader:
"""Converting the primitive type into an actual reader that will decode the physical data"""
if not isinstance(read_type, PrimitiveType):
- raise ResolveException(f"Cannot promote {file_type} to {read_type}")
+ raise ResolveError(f"Cannot promote {file_type} to {read_type}")
# In the case of a promotion, we want to check if it is valid
if file_type != read_type:
- return promote(file_type, read_type)
+ read_type = promote(file_type, read_type)
return visit(read_type, ConstructReader())
-
-
-@singledispatch
-def promote(file_type: IcebergType, read_type: IcebergType) -> Reader:
- """Promotes reading a file type to a read type
-
- Args:
- file_type (IcebergType): The type of the Avro file
- read_type (IcebergType): The requested read type
-
- Raises:
- ResolveException: If attempting to resolve an unrecognized object type
- """
- raise ResolveException(f"Cannot promote {file_type} to {read_type}")
-
-
-@promote.register(IntegerType)
-def _(file_type: IntegerType, read_type: IcebergType) -> Reader:
- if isinstance(read_type, LongType):
- # Ints/Longs are binary compatible in Avro, so this is okay
- return visit(read_type, ConstructReader())
- else:
- raise ResolveException(f"Cannot promote an int to {read_type}")
-
-
-@promote.register(FloatType)
-def _(file_type: FloatType, read_type: IcebergType) -> Reader:
- if isinstance(read_type, DoubleType):
- # We should just read the float, and return it, since it both returns a float
- return visit(file_type, ConstructReader())
- else:
- raise ResolveException(f"Cannot promote an float to {read_type}")
-
-
-@promote.register(StringType)
-def _(file_type: StringType, read_type: IcebergType) -> Reader:
- if isinstance(read_type, BinaryType):
- return visit(read_type, ConstructReader())
- else:
- raise ResolveException(f"Cannot promote an string to {read_type}")
-
-
-@promote.register(BinaryType)
-def _(file_type: BinaryType, read_type: IcebergType) -> Reader:
- if isinstance(read_type, StringType):
- return visit(read_type, ConstructReader())
- else:
- raise ResolveException(f"Cannot promote an binary to {read_type}")
-
-
-@promote.register(DecimalType)
-def _(file_type: DecimalType, read_type: IcebergType) -> Reader:
- if isinstance(read_type, DecimalType):
- if file_type.precision <= read_type.precision and file_type.scale == file_type.scale:
- return visit(read_type, ConstructReader())
- else:
- raise ResolveException(f"Cannot reduce precision from {file_type} to {read_type}")
- else:
- raise ResolveException(f"Cannot promote an decimal to {read_type}")
diff --git a/python/pyiceberg/exceptions.py b/python/pyiceberg/exceptions.py
index 0438a5322a..69e40159ce 100644
--- a/python/pyiceberg/exceptions.py
+++ b/python/pyiceberg/exceptions.py
@@ -86,3 +86,7 @@ class NotInstalledError(Exception):
class SignError(Exception):
"""Raises when unable to sign a S3 request"""
+
+
+class ResolveError(Exception):
+ pass
diff --git a/python/pyiceberg/expressions/visitors.py b/python/pyiceberg/expressions/visitors.py
index f5c11a61cf..de7a489ab3 100644
--- a/python/pyiceberg/expressions/visitors.py
+++ b/python/pyiceberg/expressions/visitors.py
@@ -39,12 +39,15 @@ from pyiceberg.expressions import (
BoundIsNull,
BoundLessThan,
BoundLessThanOrEqual,
+ BoundLiteralPredicate,
BoundNotEqualTo,
BoundNotIn,
BoundNotNaN,
BoundNotNull,
BoundPredicate,
+ BoundSetPredicate,
BoundTerm,
+ BoundUnaryPredicate,
L,
Not,
Or,
@@ -753,3 +756,89 @@ def inclusive_projection(
schema: Schema, spec: PartitionSpec, case_sensitive: bool = True
) -> Callable[[BooleanExpression], BooleanExpression]:
return InclusiveProjection(schema, spec, case_sensitive).project
+
+
+class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
+ """Replaces the column names in the expression with the ones used in the actual file
+
+ Args:
+ file_schema (Schema): The schema of the file
+ case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema, defaults to True
+
+ Raises:
+ TypeError: In the case of an UnboundPredicate
+ ValueError: When a column name cannot be found
+ """
+
+ file_schema: Schema
+ case_sensitive: bool
+
+ def __init__(self, file_schema: Schema, case_sensitive: bool) -> None:
+ self.file_schema = file_schema
+ self.case_sensitive = case_sensitive
+
+ def visit_true(self) -> BooleanExpression:
+ return AlwaysTrue()
+
+ def visit_false(self) -> BooleanExpression:
+ return AlwaysFalse()
+
+ def visit_not(self, child_result: BooleanExpression) -> BooleanExpression:
+ return Not(child=child_result)
+
+ def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
+ return And(left=left_result, right=right_result)
+
+ def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
+ return Or(left=left_result, right=right_result)
+
+ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpression:
+ raise TypeError(f"Expected Bound Predicate, got: {predicate.term}")
+
+ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression:
+ file_column_name = self.file_schema.find_column_name(predicate.term.ref().field.field_id)
+
+ if not file_column_name:
+ raise ValueError(f"Not found in file schema: {file_column_name}")
+
+ if isinstance(predicate, BoundUnaryPredicate):
+ return predicate.as_unbound(file_column_name)
+ elif isinstance(predicate, BoundLiteralPredicate):
+ return predicate.as_unbound(file_column_name, predicate.literal)
+ elif isinstance(predicate, BoundSetPredicate):
+ return predicate.as_unbound(file_column_name, predicate.literals)
+ else:
+ raise ValueError(f"Unsupported predicate: {predicate}")
+
+
+def translate_column_names(expr: BooleanExpression, file_schema: Schema, case_sensitive: bool) -> BooleanExpression:
+ return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive))
+
+
+class _ExpressionFieldIDs(BooleanExpressionVisitor[Set[int]]):
+ """Extracts the field IDs used in the BooleanExpression"""
+
+ def visit_true(self) -> Set[int]:
+ return set()
+
+ def visit_false(self) -> Set[int]:
+ return set()
+
+ def visit_not(self, child_result: Set[int]) -> Set[int]:
+ return child_result
+
+ def visit_and(self, left_result: Set[int], right_result: Set[int]) -> Set[int]:
+ return left_result.union(right_result)
+
+ def visit_or(self, left_result: Set[int], right_result: Set[int]) -> Set[int]:
+ return left_result.union(right_result)
+
+ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> Set[int]:
+ raise ValueError("Only works on bound records")
+
+ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> Set[int]:
+ return {predicate.term.ref().field.field_id}
+
+
+def extract_field_ids(expr: BooleanExpression) -> Set[int]:
+ return visit(expr, _ExpressionFieldIDs())
diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py
index d5f7e80f62..b4c9024f2d 100644
--- a/python/pyiceberg/io/pyarrow.py
+++ b/python/pyiceberg/io/pyarrow.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=redefined-outer-name,arguments-renamed
"""FileIO implementation for reading and writing table files that uses pyarrow.fs
This file contains a FileIO implementation that relies on the filesystem interface provided
@@ -21,13 +22,17 @@ by PyArrow. It relies on PyArrow's `from_uri` method that infers the correct fil
type to use. Theoretically, this allows the supported storage types to grow naturally
with the pyarrow library.
"""
+from __future__ import annotations
import os
from functools import lru_cache
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
+ Iterable,
List,
+ Optional,
Set,
Tuple,
Union,
@@ -36,6 +41,8 @@ from urllib.parse import urlparse
import pyarrow as pa
import pyarrow.compute as pc
+import pyarrow.dataset as ds
+import pyarrow.parquet as pq
from pyarrow.fs import (
FileInfo,
FileSystem,
@@ -44,8 +51,19 @@ from pyarrow.fs import (
S3FileSystem,
)
-from pyiceberg.expressions import BooleanExpression, BoundTerm, Literal
-from pyiceberg.expressions.visitors import BoundBooleanExpressionVisitor
+from pyiceberg.avro.resolver import ResolveError, promote
+from pyiceberg.expressions import (
+ AlwaysTrue,
+ BooleanExpression,
+ BoundTerm,
+ Literal,
+)
+from pyiceberg.expressions.visitors import (
+ BoundBooleanExpressionVisitor,
+ bind,
+ extract_field_ids,
+ translate_column_names,
+)
from pyiceberg.expressions.visitors import visit as boolean_expression_visit
from pyiceberg.io import (
FileIO,
@@ -54,7 +72,15 @@ from pyiceberg.io import (
OutputFile,
OutputStream,
)
-from pyiceberg.schema import Schema, SchemaVisitorPerPrimitiveType, visit
+from pyiceberg.schema import (
+ PartnerAccessor,
+ Schema,
+ SchemaVisitorPerPrimitiveType,
+ SchemaWithPartnerVisitor,
+ prune_columns,
+ visit,
+ visit_with_partner,
+)
from pyiceberg.typedef import EMPTY_DICT, Properties
from pyiceberg.types import (
BinaryType,
@@ -80,6 +106,11 @@ from pyiceberg.types import (
)
from pyiceberg.utils.singleton import Singleton
+if TYPE_CHECKING:
+ from pyiceberg.table import FileScanTask, Table
+
+ICEBERG_SCHEMA = b"iceberg.schema"
+
class PyArrowFile(InputFile, OutputFile):
"""A combined InputFile and OutputFile implementation that uses a pyarrow filesystem to generate pyarrow.lib.NativeFile instances
@@ -195,7 +226,7 @@ class PyArrowFile(InputFile, OutputFile):
raise # pragma: no cover - If some other kind of OSError, raise the raw error
return output_file
- def to_input_file(self) -> "PyArrowFile":
+ def to_input_file(self) -> PyArrowFile:
"""Returns a new PyArrowFile for the location of an existing PyArrowFile instance
This method is included to abide by the OutputFile abstract base class. Since this implementation uses a single
@@ -287,7 +318,7 @@ class PyArrowFileIO(FileIO):
raise # pragma: no cover - If some other kind of OSError, raise the raw error
-def schema_to_pyarrow(schema: Schema) -> pa.schema:
+def schema_to_pyarrow(schema: Union[Schema, IcebergType]) -> pa.schema:
return visit(schema, _ConvertToArrowSchema())
@@ -302,7 +333,7 @@ class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType], Singleto
return pa.field(
name=field.name,
type=field_result,
- nullable=not field.required,
+ nullable=field.optional,
metadata={"doc": field.doc, "id": str(field.field_id)} if field.doc else {},
)
@@ -423,3 +454,177 @@ class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
return boolean_expression_visit(expr, _ConvertToArrowExpression())
+
+
+def project_table(
+ files: Iterable[FileScanTask], table: Table, row_filter: BooleanExpression, projected_schema: Schema, case_sensitive: bool
+) -> pa.Table:
+ """Resolves the right columns based on the field ID
+
+ Args:
+ files(Iterable[FileScanTask]): The file scan tasks whose data files are read
+ table(Table): The table that's being queried
+ row_filter(BooleanExpression): The expression for filtering rows
+ projected_schema(Schema): The output schema
+ case_sensitive(bool): Case sensitivity when looking up column names
+
+ Raises:
+ ResolveError: When an incompatible query is done
+ """
+
+ if isinstance(table.io, PyArrowFileIO):
+ scheme, path = PyArrowFileIO.parse_location(table.location())
+ fs = table.io.get_fs(scheme)
+ else:
+ raise ValueError(f"Expected PyArrowFileIO, got: {table.io}")
+
+ bound_row_filter = bind(table.schema(), row_filter, case_sensitive=case_sensitive)
+
+ projected_field_ids = {
+ id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
+ }.union(extract_field_ids(bound_row_filter))
+
+ tables = []
+ for task in files:
+ _, path = PyArrowFileIO.parse_location(task.file.file_path)
+
+ # Get the schema
+ with fs.open_input_file(path) as fout:
+ parquet_schema = pq.read_schema(fout)
+ schema_raw = parquet_schema.metadata.get(ICEBERG_SCHEMA)
+ if schema_raw is None:
+ raise ValueError(
+ "Iceberg schema is not embedded into the Parquet file, see https://github.com/apache/iceberg/issues/6505"
+ )
+ file_schema = Schema.parse_raw(schema_raw)
+
+ pyarrow_filter = None
+ if row_filter is not AlwaysTrue():
+ translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
+ bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
+ pyarrow_filter = expression_to_pyarrow(bound_file_filter)
+
+ file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
+
+ if file_schema is None:
+ raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
+
+ # Prune the stuff that we don't need anyway
+ file_project_schema_arrow = schema_to_pyarrow(file_project_schema)
+
+ arrow_table = ds.dataset(
+ source=[path], schema=file_project_schema_arrow, format=ds.ParquetFileFormat(), filesystem=fs
+ ).to_table(filter=pyarrow_filter)
+
+ tables.append(to_requested_schema(projected_schema, file_project_schema, arrow_table))
+
+ if len(tables) > 1:
+ return pa.concat_tables(tables)
+ else:
+ return tables[0]
+
+
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: pa.Table) -> pa.Table:
+ struct_array = visit_with_partner(requested_schema, table, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+
+ arrays = []
+ fields = []
+ for pos, field in enumerate(requested_schema.fields):
+ array = struct_array.field(pos)
+ arrays.append(array)
+ fields.append(pa.field(field.name, array.type, field.optional))
+ return pa.Table.from_arrays(arrays, schema=pa.schema(fields))
+
+
+class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
+ file_schema: Schema
+
+ def __init__(self, file_schema: Schema):
+ self.file_schema = file_schema
+
+ def cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
+ file_field = self.file_schema.find_field(field.field_id)
+ if field.field_type.is_primitive and field.field_type != file_field.field_type:
+ return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type)))
+ return values
+
+ def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
+ return struct_result
+
+ def struct(
+ self, struct: StructType, struct_array: Optional[pa.Array], field_results: List[Optional[pa.Array]]
+ ) -> Optional[pa.Array]:
+ if struct_array is None:
+ return None
+ field_arrays: List[pa.Array] = []
+ fields: List[pa.Field] = []
+ for field, field_array in zip(struct.fields, field_results):
+ if field_array is not None:
+ array = self.cast_if_needed(field, field_array)
+ field_arrays.append(array)
+ fields.append(pa.field(field.name, array.type, field.optional))
+ elif field.optional:
+ arrow_type = schema_to_pyarrow(field.field_type)
+ field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
+ fields.append(pa.field(field.name, arrow_type, field.optional))
+ else:
+ raise ResolveError(f"Field is required, and could not be found in the file: {field}")
+
+ return pa.StructArray.from_arrays(arrays=field_arrays, fields=pa.struct(fields))
+
+ def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional[pa.Array]) -> Optional[pa.Array]:
+ return field_array
+
+ def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]:
+ return (
+ pa.ListArray.from_arrays(list_array.offsets, self.cast_if_needed(list_type.element_field, value_array))
+ if isinstance(list_array, pa.ListArray)
+ else None
+ )
+
+ def map(
+ self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
+ ) -> Optional[pa.Array]:
+ return (
+ pa.MapArray.from_arrays(
+ map_array.offsets,
+ self.cast_if_needed(map_type.key_field, key_result),
+ self.cast_if_needed(map_type.value_field, value_result),
+ )
+ if isinstance(map_array, pa.MapArray)
+ else None
+ )
+
+ def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) -> Optional[pa.Array]:
+ return array
+
+
+class ArrowAccessor(PartnerAccessor[pa.Array]):
+ file_schema: Schema
+
+ def __init__(self, file_schema: Schema):
+ self.file_schema = file_schema
+
+ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str) -> Optional[pa.Array]:
+ if partner_struct:
+ # use the field name from the file schema
+ try:
+ name = self.file_schema.find_field(field_id).name
+ except ValueError:
+ return None
+
+ if isinstance(partner_struct, pa.StructArray):
+ return partner_struct.field(name)
+ elif isinstance(partner_struct, pa.Table):
+ return partner_struct.column(name).combine_chunks()
+
+ return None
+
+ def list_element_partner(self, partner_list: Optional[pa.Array]) -> Optional[pa.Array]:
+ return partner_list.values if isinstance(partner_list, pa.ListArray) else None
+
+ def map_key_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]:
+ return partner_map.keys if isinstance(partner_map, pa.MapArray) else None
+
+ def map_value_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]:
+ return partner_map.items if isinstance(partner_map, pa.MapArray) else None
diff --git a/python/pyiceberg/schema.py b/python/pyiceberg/schema.py
index b265c444f1..33f5cf3dc0 100644
--- a/python/pyiceberg/schema.py
+++ b/python/pyiceberg/schema.py
@@ -35,7 +35,8 @@ from typing import (
from pydantic import Field, PrivateAttr
-from pyiceberg.typedef import StructProtocol
+from pyiceberg.exceptions import ResolveError
+from pyiceberg.typedef import EMPTY_DICT, StructProtocol
from pyiceberg.types import (
BinaryType,
BooleanType,
@@ -61,6 +62,7 @@ from pyiceberg.types import (
from pyiceberg.utils.iceberg_base_model import IcebergBaseModel
T = TypeVar("T")
+P = TypeVar("P")
INITIAL_SCHEMA_ID = 0
@@ -248,6 +250,11 @@ class Schema(IcebergBaseModel):
return prune_columns(self, ids)
+ @property
+ def field_ids(self) -> Set[int]:
+ """Returns the IDs of the current schema"""
+ return set(self._name_to_id.values())
+
class SchemaVisitor(Generic[T], ABC):
def before_field(self, field: NestedField) -> None:
@@ -331,6 +338,142 @@ class PreOrderSchemaVisitor(Generic[T], ABC):
"""Visit a PrimitiveType"""
+class SchemaWithPartnerVisitor(Generic[P, T], ABC):
+ def before_field(self, field: NestedField, field_partner: Optional[P]) -> None:
+ """Override this method to perform an action immediately before visiting a field"""
+
+ def after_field(self, field: NestedField, field_partner: Optional[P]) -> None:
+ """Override this method to perform an action immediately after visiting a field"""
+
+ def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
+ """Override this method to perform an action immediately before visiting an element within a ListType"""
+ self.before_field(element, element_partner)
+
+ def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
+ """Override this method to perform an action immediately after visiting an element within a ListType"""
+ self.after_field(element, element_partner)
+
+ def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
+ """Override this method to perform an action immediately before visiting a key within a MapType"""
+ self.before_field(key, key_partner)
+
+ def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
+ """Override this method to perform an action immediately after visiting a key within a MapType"""
+ self.after_field(key, key_partner)
+
+ def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
+ """Override this method to perform an action immediately before visiting a value within a MapType"""
+ self.before_field(value, value_partner)
+
+ def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
+ """Override this method to perform an action immediately after visiting a value within a MapType"""
+ self.after_field(value, value_partner)
+
+ @abstractmethod
+ def schema(self, schema: Schema, schema_partner: Optional[P], struct_result: T) -> T:
+ """Visit a schema with a partner"""
+
+ @abstractmethod
+ def struct(self, struct: StructType, struct_partner: Optional[P], field_results: List[T]) -> T:
+ """Visit a struct type with a partner"""
+
+ @abstractmethod
+ def field(self, field: NestedField, field_partner: Optional[P], field_result: T) -> T:
+ """Visit a nested field with a partner"""
+
+ @abstractmethod
+ def list(self, list_type: ListType, list_partner: Optional[P], element_result: T) -> T:
+ """Visit a list type with a partner"""
+
+ @abstractmethod
+ def map(self, map_type: MapType, map_partner: Optional[P], key_result: T, value_result: T) -> T:
+ """Visit a map type with a partner"""
+
+ @abstractmethod
+ def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[P]) -> T:
+ """Visit a primitive type with a partner"""
+
+
+class PartnerAccessor(Generic[P], ABC):
+ @abstractmethod
+ def field_partner(self, partner_struct: Optional[P], field_id: int, field_name: str) -> Optional[P]:
+ """Returns the equivalent struct field by name or id in the partner struct"""
+
+ @abstractmethod
+ def list_element_partner(self, partner_list: Optional[P]) -> Optional[P]:
+ """Returns the equivalent list element in the partner list"""
+
+ @abstractmethod
+ def map_key_partner(self, partner_map: Optional[P]) -> Optional[P]:
+ """Returns the equivalent map key in the partner map"""
+
+ @abstractmethod
+ def map_value_partner(self, partner_map: Optional[P]) -> Optional[P]:
+ """Returns the equivalent map value in the partner map"""
+
+
+@singledispatch
+def visit_with_partner(
+ schema_or_type: Union[Schema, IcebergType], partner: P, visitor: SchemaWithPartnerVisitor[T, P], accessor: PartnerAccessor[P]
+) -> T:
+ raise ValueError(f"Unsupported type: {type}")
+
+
+@visit_with_partner.register(Schema)
+def _(schema: Schema, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
+ return visitor.schema(schema, partner, visit_with_partner(schema.as_struct(), partner, visitor, accessor)) # type: ignore
+
+
+@visit_with_partner.register(StructType)
+def _(struct: StructType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
+ field_results = []
+ for field in struct.fields:
+ field_partner = accessor.field_partner(partner, field.field_id, field.name)
+ visitor.before_field(field, field_partner)
+ try:
+ field_result = visit_with_partner(field.field_type, field_partner, visitor, accessor) # type: ignore
+ field_results.append(visitor.field(field, field_partner, field_result))
+ finally:
+ visitor.after_field(field, field_partner)
+
+ return visitor.struct(struct, partner, field_results)
+
+
+@visit_with_partner.register(ListType)
+def _(list_type: ListType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
+ element_partner = accessor.list_element_partner(partner)
+ visitor.before_list_element(list_type.element_field, element_partner)
+ try:
+ element_result = visit_with_partner(list_type.element_type, element_partner, visitor, accessor) # type: ignore
+ finally:
+ visitor.after_list_element(list_type.element_field, element_partner)
+
+ return visitor.list(list_type, partner, element_result)
+
+
+@visit_with_partner.register(MapType)
+def _(map_type: MapType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
+ key_partner = accessor.map_key_partner(partner)
+ visitor.before_map_key(map_type.key_field, key_partner)
+ try:
+ key_result = visit_with_partner(map_type.key_type, key_partner, visitor, accessor) # type: ignore
+ finally:
+ visitor.after_map_key(map_type.key_field, key_partner)
+
+ value_partner = accessor.map_value_partner(partner)
+ visitor.before_map_value(map_type.value_field, value_partner)
+ try:
+ value_result = visit_with_partner(map_type.value_type, value_partner, visitor, accessor) # type: ignore
+ finally:
+ visitor.after_map_value(map_type.value_field, value_partner)
+ return visitor.map(map_type, partner, key_result, value_result)
+
+
+@visit_with_partner.register(PrimitiveType)
+def _(primitive: PrimitiveType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], _: PartnerAccessor[P]) -> T:
+ return visitor.primitive(primitive, partner)
+
+
class SchemaVisitorPerPrimitiveType(SchemaVisitor[T], ABC):
def primitive(self, primitive: PrimitiveType) -> T:
"""Visit a PrimitiveType"""
@@ -727,9 +870,12 @@ def index_by_name(schema_or_type: Union[Schema, IcebergType]) -> Dict[str, int]:
Returns:
Dict[str, int]: An index of field names to field IDs
"""
- indexer = _IndexByName()
- visit(schema_or_type, indexer)
- return indexer.by_name()
+ if len(schema_or_type.fields) > 0:
+ indexer = _IndexByName()
+ visit(schema_or_type, indexer)
+ return indexer.by_name()
+ else:
+ return EMPTY_DICT
def index_name_by_id(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, str]:
@@ -1046,3 +1192,65 @@ class _PruneColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
value_type=value_result,
value_required=map_type.value_required,
)
+
+
+@singledispatch
+def promote(file_type: IcebergType, read_type: IcebergType) -> IcebergType:
+ """Promotes reading a file type to a read type
+
+ Args:
+ file_type (IcebergType): The type of the data as written in the file
+ read_type (IcebergType): The requested read type
+
+ Raises:
+ ResolveError: If attempting to resolve an unrecognized object type
+ """
+ if file_type == read_type:
+ return file_type
+ else:
+ raise ResolveError(f"Cannot promote {file_type} to {read_type}")
+
+
+@promote.register(IntegerType)
+def _(file_type: IntegerType, read_type: IcebergType) -> IcebergType:
+ if isinstance(read_type, LongType):
+ # Ints/Longs are binary compatible in Avro, so this is okay
+ return read_type
+ else:
+ raise ResolveError(f"Cannot promote an int to {read_type}")
+
+
+@promote.register(FloatType)
+def _(file_type: FloatType, read_type: IcebergType) -> IcebergType:
+ if isinstance(read_type, DoubleType):
+ # A double type is wider
+ return read_type
+ else:
+ raise ResolveError(f"Cannot promote an float to {read_type}")
+
+
+@promote.register(StringType)
+def _(file_type: StringType, read_type: IcebergType) -> IcebergType:
+ if isinstance(read_type, BinaryType):
+ return read_type
+ else:
+ raise ResolveError(f"Cannot promote an string to {read_type}")
+
+
+@promote.register(BinaryType)
+def _(file_type: BinaryType, read_type: IcebergType) -> IcebergType:
+ if isinstance(read_type, StringType):
+ return read_type
+ else:
+ raise ResolveError(f"Cannot promote an binary to {read_type}")
+
+
+@promote.register(DecimalType)
+def _(file_type: DecimalType, read_type: IcebergType) -> IcebergType:
+ if isinstance(read_type, DecimalType):
+ if file_type.precision <= read_type.precision and file_type.scale == file_type.scale:
+ return read_type
+ else:
+ raise ResolveError(f"Cannot reduce precision from {file_type} to {read_type}")
+ else:
+ raise ResolveError(f"Cannot promote an decimal to {read_type}")
diff --git a/python/pyiceberg/table/__init__.py b/python/pyiceberg/table/__init__.py
index 5dbc2f22ae..3fb0702889 100644
--- a/python/pyiceberg/table/__init__.py
+++ b/python/pyiceberg/table/__init__.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property
@@ -41,8 +40,9 @@ from pyiceberg.expressions import (
BooleanExpression,
visitors,
)
-from pyiceberg.expressions.visitors import bind, inclusive_projection
+from pyiceberg.expressions.visitors import inclusive_projection
from pyiceberg.io import FileIO
+from pyiceberg.io.pyarrow import project_table
from pyiceberg.manifest import DataFile, ManifestFile, files
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import Schema
@@ -313,7 +313,7 @@ class DataScan(TableScan["DataScan"]):
return lambda data_file: evaluator(wrapper.wrap(data_file.partition))
- def plan_files(self) -> Iterator[ScanTask]:
+ def plan_files(self) -> Iterator[FileScanTask]:
snapshot = self.snapshot()
if not snapshot:
return
@@ -344,47 +344,10 @@ class DataScan(TableScan["DataScan"]):
yield from (FileScanTask(file) for file in matching_partition_files)
def to_arrow(self) -> pa.Table:
- from pyiceberg.io.pyarrow import PyArrowFileIO, expression_to_pyarrow, schema_to_pyarrow
-
- warnings.warn(
- "Projection is currently done by name instead of Field ID, this can lead to incorrect results in some cases."
+ return project_table(
+ self.plan_files(), self.table, self.row_filter, self.projection(), case_sensitive=self.case_sensitive
)
- fs = None
- if isinstance(self.table.io, PyArrowFileIO):
- scheme, path = PyArrowFileIO.parse_location(self.table.location())
- fs = self.table.io.get_fs(scheme)
-
- locations = []
- for task in self.plan_files():
- if isinstance(task, FileScanTask):
- _, path = PyArrowFileIO.parse_location(task.file.file_path)
- locations.append(path)
- else:
- raise ValueError(f"Cannot read unexpected task: {task}")
-
- columns = None
- if "*" not in self.selected_fields:
- columns = list(self.selected_fields)
-
- pyarrow_filter = None
- if self.row_filter is not AlwaysTrue():
- bound_row_filter = bind(self.table.schema(), self.row_filter, case_sensitive=self.case_sensitive)
- pyarrow_filter = expression_to_pyarrow(bound_row_filter)
-
- from pyarrow.dataset import dataset
-
- ds = dataset(
- source=locations,
- filesystem=fs,
- # Optionally provide the Schema for the Dataset,
- # in which case it will not be inferred from the source.
- # https://arrow.apache.org/docs/python/generated/pyarrow.dataset.dataset.html#pyarrow.dataset.dataset
- schema=schema_to_pyarrow(self.table.schema()),
- )
-
- return ds.to_table(filter=pyarrow_filter, columns=columns)
-
def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
return self.to_arrow().to_pandas(**kwargs)
diff --git a/python/tests/avro/test_decoder.py b/python/tests/avro/test_decoder.py
index e723270376..56b9cf7489 100644
--- a/python/tests/avro/test_decoder.py
+++ b/python/tests/avro/test_decoder.py
@@ -26,7 +26,7 @@ from uuid import UUID
import pytest
from pyiceberg.avro.decoder import BinaryDecoder
-from pyiceberg.avro.resolver import promote
+from pyiceberg.avro.resolver import resolve
from pyiceberg.io import InputStream
from pyiceberg.io.memory import MemoryInputStream
from pyiceberg.types import DoubleType, FloatType
@@ -225,6 +225,5 @@ def test_skip_utf8() -> None:
def test_read_int_as_float() -> None:
mis = MemoryInputStream(b"\x00\x00\x9A\x41")
decoder = BinaryDecoder(mis)
- reader = promote(FloatType(), DoubleType())
-
+ reader = resolve(FloatType(), DoubleType())
assert reader.read(decoder) == 19.25
diff --git a/python/tests/avro/test_resolver.py b/python/tests/avro/test_resolver.py
index 853be4e448..f051881e0d 100644
--- a/python/tests/avro/test_resolver.py
+++ b/python/tests/avro/test_resolver.py
@@ -25,7 +25,7 @@ from pyiceberg.avro.reader import (
StringReader,
StructReader,
)
-from pyiceberg.avro.resolver import ResolveException, promote, resolve
+from pyiceberg.avro.resolver import ResolveError, resolve
from pyiceberg.schema import Schema
from pyiceberg.types import (
BinaryType,
@@ -101,7 +101,7 @@ def test_resolver_new_required_field() -> None:
schema_id=1,
)
- with pytest.raises(ResolveException) as exc_info:
+ with pytest.raises(ResolveError) as exc_info:
resolve(write_schema, read_schema)
assert "2: data: required string is non-optional, and not part of the file schema" in str(exc_info.value)
@@ -117,7 +117,7 @@ def test_resolver_invalid_evolution() -> None:
schema_id=1,
)
- with pytest.raises(ResolveException) as exc_info:
+ with pytest.raises(ResolveError) as exc_info:
resolve(write_schema, read_schema)
assert "Cannot promote long to double" in str(exc_info.value)
@@ -157,69 +157,69 @@ def test_resolver_change_type() -> None:
schema_id=1,
)
- with pytest.raises(ResolveException) as exc_info:
+ with pytest.raises(ResolveError) as exc_info:
resolve(write_schema, read_schema)
assert "File/read schema are not aligned for list<string>, got map<string, string>" in str(exc_info.value)
-def test_promote_int_to_long() -> None:
- assert promote(IntegerType(), LongType()) == IntegerReader()
+def test_resolve_int_to_long() -> None:
+ assert resolve(IntegerType(), LongType()) == IntegerReader()
-def test_promote_float_to_double() -> None:
+def test_resolve_float_to_double() -> None:
# We should still read floats, because it is encoded in 4 bytes
- assert promote(FloatType(), DoubleType()) == FloatReader()
+ assert resolve(FloatType(), DoubleType()) == FloatReader()
-def test_promote_decimal_to_decimal() -> None:
+def test_resolve_decimal_to_decimal() -> None:
# DecimalType(P, S) to DecimalType(P2, S) where P2 > P
- assert promote(DecimalType(19, 25), DecimalType(22, 25)) == DecimalReader(22, 25)
+ assert resolve(DecimalType(19, 25), DecimalType(22, 25)) == DecimalReader(22, 25)
def test_struct_not_aligned() -> None:
- with pytest.raises(ResolveException):
- assert promote(StructType(), StringType())
+ with pytest.raises(ResolveError):
+ assert resolve(StructType(), StringType())
def test_map_not_aligned() -> None:
- with pytest.raises(ResolveException):
- assert promote(MapType(1, StringType(), 2, IntegerType()), StringType())
+ with pytest.raises(ResolveError):
+ assert resolve(MapType(1, StringType(), 2, IntegerType()), StringType())
def test_primitive_not_aligned() -> None:
- with pytest.raises(ResolveException):
- assert promote(IntegerType(), MapType(1, StringType(), 2, IntegerType()))
+ with pytest.raises(ResolveError):
+ assert resolve(IntegerType(), MapType(1, StringType(), 2, IntegerType()))
def test_integer_not_aligned() -> None:
- with pytest.raises(ResolveException):
- assert promote(IntegerType(), StringType())
+ with pytest.raises(ResolveError):
+ assert resolve(IntegerType(), StringType())
def test_float_not_aligned() -> None:
- with pytest.raises(ResolveException):
- assert promote(FloatType(), StringType())
+ with pytest.raises(ResolveError):
+ assert resolve(FloatType(), StringType())
def test_string_not_aligned() -> None:
- with pytest.raises(ResolveException):
- assert promote(StringType(), FloatType())
+ with pytest.raises(ResolveError):
+ assert resolve(StringType(), FloatType())
def test_binary_not_aligned() -> None:
- with pytest.raises(ResolveException):
- assert promote(BinaryType(), FloatType())
+ with pytest.raises(ResolveError):
+ assert resolve(BinaryType(), FloatType())
def test_decimal_not_aligned() -> None:
- with pytest.raises(ResolveException):
- assert promote(DecimalType(22, 19), StringType())
+ with pytest.raises(ResolveError):
+ assert resolve(DecimalType(22, 19), StringType())
-def test_promote_decimal_to_decimal_reduce_precision() -> None:
+def test_resolve_decimal_to_decimal_reduce_precision() -> None:
# DecimalType(P, S) to DecimalType(P2, S) where P2 > P
- with pytest.raises(ResolveException) as exc_info:
- _ = promote(DecimalType(19, 25), DecimalType(10, 25)) == DecimalReader(22, 25)
+ with pytest.raises(ResolveError) as exc_info:
+ _ = resolve(DecimalType(19, 25), DecimalType(10, 25)) == DecimalReader(22, 25)
assert "Cannot reduce precision from decimal(19, 25) to decimal(10, 25)" in str(exc_info.value)
diff --git a/python/tests/io/test_pyarrow.py b/python/tests/io/test_pyarrow.py
index 674331bf1c..02b7d02729 100644
--- a/python/tests/io/test_pyarrow.py
+++ b/python/tests/io/test_pyarrow.py
@@ -18,17 +18,20 @@
import os
import tempfile
-from typing import Any
+from typing import Any, List, Optional
from unittest.mock import MagicMock, patch
import pyarrow as pa
+import pyarrow.parquet as pq
import pytest
from pyarrow.fs import FileType
+from pyiceberg.avro.resolver import ResolveError
from pyiceberg.expressions import (
AlwaysFalse,
AlwaysTrue,
And,
+ BooleanExpression,
BoundEqualTo,
BoundGreaterThan,
BoundGreaterThanOrEqual,
@@ -42,6 +45,7 @@ from pyiceberg.expressions import (
BoundNotNaN,
BoundNotNull,
BoundReference,
+ GreaterThan,
Not,
Or,
literal,
@@ -52,9 +56,14 @@ from pyiceberg.io.pyarrow import (
PyArrowFileIO,
_ConvertToArrowSchema,
expression_to_pyarrow,
+ project_table,
schema_to_pyarrow,
)
+from pyiceberg.manifest import DataFile, FileFormat
+from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import Schema, visit
+from pyiceberg.table import FileScanTask, Table
+from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.types import (
BinaryType,
BooleanType,
@@ -69,6 +78,7 @@ from pyiceberg.types import (
MapType,
NestedField,
StringType,
+ StructType,
TimestampType,
TimestamptzType,
TimeType,
@@ -572,3 +582,561 @@ def test_always_true_to_pyarrow(bound_reference: BoundReference[str]) -> None:
def test_always_false_to_pyarrow(bound_reference: BoundReference[str]) -> None:
assert repr(expression_to_pyarrow(AlwaysFalse())) == "<pyarrow.compute.Expression false>"
+
+
+@pytest.fixture
+def schema_int() -> Schema:
+ return Schema(NestedField(1, "id", IntegerType(), required=False))
+
+
+@pytest.fixture
+def schema_int_str() -> Schema:
+ return Schema(NestedField(1, "id", IntegerType(), required=False), NestedField(2, "data", StringType(), required=False))
+
+
+@pytest.fixture
+def schema_str() -> Schema:
+ return Schema(NestedField(2, "data", StringType(), required=False))
+
+
+@pytest.fixture
+def schema_long() -> Schema:
+ return Schema(NestedField(3, "id", LongType(), required=False))
+
+
+@pytest.fixture
+def schema_struct() -> Schema:
+ return Schema(
+ NestedField(
+ 4,
+ "location",
+ StructType(
+ NestedField(41, "lat", DoubleType()),
+ NestedField(42, "long", DoubleType()),
+ ),
+ )
+ )
+
+
+@pytest.fixture
+def schema_list() -> Schema:
+ return Schema(
+ NestedField(5, "ids", ListType(51, IntegerType(), element_required=False), required=False),
+ )
+
+
+@pytest.fixture
+def schema_list_of_structs() -> Schema:
+ return Schema(
+ NestedField(
+ 5,
+ "locations",
+ ListType(
+ 51,
+ StructType(NestedField(511, "lat", DoubleType()), NestedField(512, "long", DoubleType())),
+ element_required=False,
+ ),
+ required=False,
+ ),
+ )
+
+
+@pytest.fixture
+def schema_map() -> Schema:
+ return Schema(
+ NestedField(
+ 5,
+ "properties",
+ MapType(
+ key_id=51,
+ key_type=StringType(),
+ value_id=52,
+ value_type=StringType(),
+ value_required=True,
+ ),
+ required=False,
+ ),
+ )
+
+
+def _write_table_to_file(filepath: str, schema: pa.Schema, table: pa.Table) -> str:
+ with pq.ParquetWriter(filepath, schema) as writer:
+ writer.write_table(table)
+ return filepath
+
+
+@pytest.fixture
+def file_int(schema_int: Schema, tmpdir: str) -> str:
+ pyarrow_schema = pa.schema(schema_to_pyarrow(schema_int), metadata={"iceberg.schema": schema_int.json()})
+ return _write_table_to_file(
+ f"file:{tmpdir}/a.parquet", pyarrow_schema, pa.Table.from_arrays([pa.array([0, 1, 2])], schema=pyarrow_schema)
+ )
+
+
+@pytest.fixture
+def file_int_str(schema_int_str: Schema, tmpdir: str) -> str:
+    """Write a three-row Parquet file matching `schema_int_str` and return its location."""
+    pyarrow_schema = pa.schema(schema_to_pyarrow(schema_int_str), metadata={"iceberg.schema": schema_int_str.json()})
+    return _write_table_to_file(
+        # Distinct file name: `file_int` writes a.parquet under the same tmpdir and would be overwritten
+        f"file:{tmpdir}/a_str.parquet",
+        pyarrow_schema,
+        pa.Table.from_arrays([pa.array([0, 1, 2]), pa.array(["0", "1", "2"])], schema=pyarrow_schema),
+    )
+
+
+@pytest.fixture
+def file_string(schema_str: Schema, tmpdir: str) -> str:
+ pyarrow_schema = pa.schema(schema_to_pyarrow(schema_str), metadata={"iceberg.schema": schema_str.json()})
+ return _write_table_to_file(
+ f"file:{tmpdir}/b.parquet", pyarrow_schema, pa.Table.from_arrays([pa.array(["0", "1", "2"])], schema=pyarrow_schema)
+ )
+
+
+@pytest.fixture
+def file_long(schema_long: Schema, tmpdir: str) -> str:
+ pyarrow_schema = pa.schema(schema_to_pyarrow(schema_long), metadata={"iceberg.schema": schema_long.json()})
+ return _write_table_to_file(
+ f"file:{tmpdir}/c.parquet", pyarrow_schema, pa.Table.from_arrays([pa.array([0, 1, 2])], schema=pyarrow_schema)
+ )
+
+
+@pytest.fixture
+def file_struct(schema_struct: Schema, tmpdir: str) -> str:
+ pyarrow_schema = pa.schema(schema_to_pyarrow(schema_struct), metadata={"iceberg.schema": schema_struct.json()})
+ return _write_table_to_file(
+ f"file:{tmpdir}/d.parquet",
+ pyarrow_schema,
+ pa.Table.from_pylist(
+ [
+ {"location": {"lat": 52.371807, "long": 4.896029}},
+ {"location": {"lat": 52.387386, "long": 4.646219}},
+ {"location": {"lat": 52.078663, "long": 4.288788}},
+ ],
+ schema=pyarrow_schema,
+ ),
+ )
+
+
+@pytest.fixture
+def file_list(schema_list: Schema, tmpdir: str) -> str:
+ pyarrow_schema = pa.schema(schema_to_pyarrow(schema_list), metadata={"iceberg.schema": schema_list.json()})
+ return _write_table_to_file(
+ f"file:{tmpdir}/e.parquet",
+ pyarrow_schema,
+ pa.Table.from_pylist(
+ [
+ {"ids": list(range(1, 10))},
+ {"ids": list(range(2, 20))},
+ {"ids": list(range(3, 30))},
+ ],
+ schema=pyarrow_schema,
+ ),
+ )
+
+
+@pytest.fixture
+def file_list_of_structs(schema_list_of_structs: Schema, tmpdir: str) -> str:
+    """Write a three-row Parquet file matching `schema_list_of_structs` and return its location."""
+    pyarrow_schema = pa.schema(
+        schema_to_pyarrow(schema_list_of_structs), metadata={"iceberg.schema": schema_list_of_structs.json()}
+    )
+    return _write_table_to_file(
+        # Distinct file name: `file_list` writes e.parquet under the same tmpdir and would be overwritten
+        f"file:{tmpdir}/f.parquet",
+        pyarrow_schema,
+        pa.Table.from_pylist(
+            [
+                {"locations": [{"lat": 52.371807, "long": 4.896029}, {"lat": 52.387386, "long": 4.646219}]},
+                {"locations": []},
+                {"locations": [{"lat": 52.078663, "long": 4.288788}, {"lat": 52.387386, "long": 4.646219}]},
+            ],
+            schema=pyarrow_schema,
+        ),
+    )
+
+
+@pytest.fixture
+def file_map(schema_map: Schema, tmpdir: str) -> str:
+    """Write a three-row Parquet file matching `schema_map` and return its location."""
+    pyarrow_schema = pa.schema(schema_to_pyarrow(schema_map), metadata={"iceberg.schema": schema_map.json()})
+    return _write_table_to_file(
+        # Distinct file name: `file_list` writes e.parquet under the same tmpdir and would be overwritten
+        f"file:{tmpdir}/g.parquet",
+        pyarrow_schema,
+        pa.Table.from_pylist(
+            [
+                {"properties": [("a", "b")]},
+                {"properties": [("c", "d")]},
+                {"properties": [("e", "f"), ("g", "h")]},
+            ],
+            schema=pyarrow_schema,
+        ),
+    )
+
+
+def project(
+    schema: Schema, files: List[str], expr: Optional[BooleanExpression] = None, table_schema: Optional[Schema] = None
+) -> pa.Table:
+    """Call `project_table` on the given Parquet files with `schema` as the requested schema.
+
+    The table metadata carries `table_schema` when given, otherwise `schema`; a missing
+    `expr` is replaced by `AlwaysTrue()`.
+    """
+    tasks = [
+        FileScanTask(
+            DataFile(file_path=path, file_format=FileFormat.PARQUET, partition={}, record_count=3, file_size_in_bytes=3)
+        )
+        for path in files
+    ]
+    metadata = TableMetadataV2(
+        location="file://a/b/",
+        last_column_id=1,
+        format_version=2,
+        schemas=[table_schema or schema],
+        partition_specs=[PartitionSpec()],
+    )
+    table = Table(
+        ("namespace", "table"),
+        metadata=metadata,
+        metadata_location="file://a/b/c.json",
+        io=PyArrowFileIO(),
+    )
+    row_filter = expr or AlwaysTrue()
+    return project_table(tasks, table, row_filter, schema, case_sensitive=True)
+
+
+def test_projection_add_column(file_int: str) -> None:
+ schema = Schema(
+ # All new IDs
+ NestedField(10, "id", IntegerType(), required=False),
+ NestedField(20, "list", ListType(21, IntegerType(), element_required=False), required=False),
+ NestedField(
+ 30,
+ "map",
+ MapType(key_id=31, key_type=IntegerType(), value_id=32, value_type=StringType(), value_required=False),
+ required=False,
+ ),
+ NestedField(
+ 40,
+ "location",
+ StructType(
+ NestedField(41, "lat", DoubleType(), required=False), NestedField(42, "lon", DoubleType(), required=False)
+ ),
+ required=False,
+ ),
+ )
+ result_table = project(schema, [file_int])
+
+ for col in result_table.columns:
+ assert len(col) == 3
+
+ for actual, expected in zip(result_table.columns[0], [None, None, None]):
+ assert actual.as_py() == expected
+
+ for actual, expected in zip(result_table.columns[1], [None, None, None]):
+ assert actual.as_py() == expected
+
+ for actual, expected in zip(result_table.columns[2], [None, None, None]):
+ assert actual.as_py() == expected
+
+ for actual, expected in zip(result_table.columns[3], [None, None, None]):
+ assert actual.as_py() == expected
+
+ assert (
+ repr(result_table.schema)
+ == """id: int32
+list: list<item: int32>
+ child 0, item: int32
+map: map<int32, string>
+ child 0, entries: struct<key: int32 not null, value: string> not null
+ child 0, key: int32 not null
+ child 1, value: string
+location: struct<lat: double, lon: double>
+ child 0, lat: double
+ child 1, lon: double"""
+ )
+
+
+def test_read_list(schema_list: Schema, file_list: str) -> None:
+ result_table = project(schema_list, [file_list])
+
+ assert len(result_table.columns[0]) == 3
+ for actual, expected in zip(result_table.columns[0], [list(range(1, 10)), list(range(2, 20)), list(range(3, 30))]):
+ assert actual.as_py() == expected
+
+ assert repr(result_table.schema) == "ids: list<item: int32>\n child 0, item: int32"
+
+
+def test_read_map(schema_map: Schema, file_map: str) -> None:
+ result_table = project(schema_map, [file_map])
+
+ assert len(result_table.columns[0]) == 3
+ for actual, expected in zip(result_table.columns[0], [[("a", "b")], [("c", "d")], [("e", "f"), ("g", "h")]]):
+ assert actual.as_py() == expected
+
+ assert (
+ repr(result_table.schema)
+ == """properties: map<string, string>
+ child 0, entries: struct<key: string not null, value: string> not null
+ child 0, key: string not null
+ child 1, value: string"""
+ )
+
+
+def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None:
+ schema = Schema(
+ # A new ID
+ NestedField(
+ 2,
+ "id",
+ MapType(key_id=3, key_type=IntegerType(), value_id=4, value_type=StringType(), value_required=False),
+ required=False,
+ )
+ )
+ result_table = project(schema, [file_int])
+ # Everything should be None
+ for r in result_table.columns[0]:
+ assert r.as_py() is None
+
+ assert (
+ repr(result_table.schema)
+ == """id: map<int32, string>
+ child 0, entries: struct<key: int32 not null, value: string> not null
+ child 0, key: int32 not null
+ child 1, value: string"""
+ )
+
+
+def test_projection_add_column_struct_required(file_int: str) -> None:
+    """Projecting a required field whose ID is absent from the file raises ResolveError."""
+    # Field ID 2 is new: the file only holds field ID 1
+    required_field = NestedField(2, "other_id", IntegerType(), required=True)
+    schema = Schema(required_field)
+
+    expected_message = "Field is required, and could not be found in the file: 2: other_id: required int"
+    with pytest.raises(ResolveError) as exc_info:
+        project(schema, [file_int])
+    assert expected_message in str(exc_info.value)
+
+
+def test_projection_rename_column(schema_int: Schema, file_int: str) -> None:
+    """A read field that reuses field ID 1 under another name still returns the file's values."""
+    # Reuses the id 1
+    renamed = Schema(NestedField(1, "other_name", IntegerType()))
+    result_table = project(renamed, [file_int])
+
+    column = result_table.columns[0]
+    assert len(column) == 3
+    assert [value.as_py() for value in column] == [0, 1, 2]
+    assert repr(result_table.schema) == "other_name: int32 not null"
+
+
+def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
+    """Projecting the same file twice returns its rows twice, in order."""
+    result_table = project(schema_int, [file_int, file_int])
+
+    column = result_table.columns[0]
+    values = [value.as_py() for value in column]
+    assert values == [0, 1, 2, 0, 1, 2]
+    assert len(column) == 6
+    assert repr(result_table.schema) == "id: int32"
+
+
+def test_projection_filter(schema_int: Schema, file_int: str) -> None:
+ result_table = project(schema_int, [file_int], GreaterThan("id", 4))
+ assert len(result_table.columns[0]) == 0
+ assert repr(result_table.schema) == "id: int32"
+
+
+def test_projection_filter_renamed_column(file_int: str) -> None:
+ schema = Schema(
+ # Reuses the id 1
+ NestedField(1, "other_id", IntegerType())
+ )
+ result_table = project(schema, [file_int], GreaterThan("other_id", 1))
+ assert len(result_table.columns[0]) == 1
+ assert repr(result_table.schema) == "other_id: int32 not null"
+
+
+def test_projection_filter_add_column(schema_int: Schema, file_int: str, file_string: str) -> None:
+    """One file has the projected column and the other does not; the missing rows come back as None."""
+    result_table = project(schema_int, [file_int, file_string])
+
+    column = result_table.columns[0]
+    values = [value.as_py() for value in column]
+    assert values == [0, 1, 2, None, None, None]
+    assert len(column) == 6
+    assert repr(result_table.schema) == "id: int32"
+
+
+def test_projection_filter_add_column_promote(file_int: str) -> None:
+    """Reading the int column of the file as a long yields the same values typed as int64."""
+    long_schema = Schema(NestedField(1, "id", LongType()))
+    result_table = project(long_schema, [file_int])
+
+    column = result_table.columns[0]
+    values = [value.as_py() for value in column]
+    assert values == [0, 1, 2]
+    assert len(column) == 3
+    assert repr(result_table.schema) == "id: int64 not null"
+
+
+def test_projection_filter_add_column_demote(file_long: str) -> None:
+    """Reading the long column of the file as an int is rejected with ResolveError."""
+    int_schema = Schema(NestedField(3, "id", IntegerType()))
+    with pytest.raises(ResolveError) as exc_info:
+        project(int_schema, [file_long])
+    message = str(exc_info.value)
+    assert "Cannot promote long to int" in message
+
+
+def test_projection_nested_struct_subset(file_struct: str) -> None:
+ schema = Schema(
+ NestedField(
+ 4,
+ "location",
+ StructType(
+ NestedField(41, "lat", DoubleType()),
+ # long is missing!
+ ),
+ )
+ )
+
+ result_table = project(schema, [file_struct])
+
+ for actual, expected in zip(result_table.columns[0], [52.371807, 52.387386, 52.078663]):
+ assert actual.as_py() == {"lat": expected}
+
+ assert len(result_table.columns[0]) == 3
+ assert repr(result_table.schema) == "location: struct<lat: double not null> not null\n child 0, lat: double not null"
+
+
+def test_projection_nested_new_field(file_struct: str) -> None:
+ schema = Schema(
+ NestedField(
+ 4,
+ "location",
+ StructType(
+ NestedField(43, "null", DoubleType(), required=False), # Whoa, this column doesn't exist in the file
+ ),
+ )
+ )
+
+ result_table = project(schema, [file_struct])
+
+ for actual, expected in zip(result_table.columns[0], [None, None, None]):
+ assert actual.as_py() == {"null": expected}
+ assert len(result_table.columns[0]) == 3
+ assert repr(result_table.schema) == "location: struct<null: double> not null\n child 0, null: double"
+
+
+def test_projection_nested_struct(schema_struct: Schema, file_struct: str) -> None:
+ schema = Schema(
+ NestedField(
+ 4,
+ "location",
+ StructType(
+ NestedField(41, "lat", DoubleType(), required=False),
+ NestedField(43, "null", DoubleType(), required=False),
+ NestedField(42, "long", DoubleType(), required=False),
+ ),
+ )
+ )
+
+ result_table = project(schema, [file_struct])
+ for actual, expected in zip(
+ result_table.columns[0],
+ [
+ {"lat": 52.371807, "long": 4.896029, "null": None},
+ {"lat": 52.387386, "long": 4.646219, "null": None},
+ {"lat": 52.078663, "long": 4.288788, "null": None},
+ ],
+ ):
+ assert actual.as_py() == expected
+ assert len(result_table.columns[0]) == 3
+ assert (
+ repr(result_table.schema)
+ == "location: struct<lat: double, null: double, long: double> not null\n child 0, lat: double\n child 1, null: double\n child 2, long: double"
+ )
+
+
+def test_projection_list_of_structs(schema_list_of_structs: Schema, file_list_of_structs: str) -> None:
+ schema = Schema(
+ NestedField(
+ 5,
+ "locations",
+ ListType(
+ 51,
+ StructType(
+ NestedField(511, "latitude", DoubleType()),
+ NestedField(512, "longitude", DoubleType()),
+ NestedField(513, "altitude", DoubleType(), required=False),
+ ),
+ element_required=False,
+ ),
+ required=False,
+ ),
+ )
+
+ result_table = project(schema, [file_list_of_structs])
+ assert len(result_table.columns) == 1
+ assert len(result_table.columns[0]) == 3
+ for actual, expected in zip(
+ result_table.columns[0],
+ [
+ [
+ {"latitude": 52.371807, "longitude": 4.896029, "altitude": None},
+ {"latitude": 52.387386, "longitude": 4.646219, "altitude": None},
+ ],
+ [],
+ [
+ {"latitude": 52.078663, "longitude": 4.288788, "altitude": None},
+ {"latitude": 52.387386, "longitude": 4.646219, "altitude": None},
+ ],
+ ],
+ ):
+ assert actual.as_py() == expected
+ assert (
+ repr(result_table.schema)
+ == """locations: list<item: struct<latitude: double not null, longitude: double not null, altitude: double>>
+ child 0, item: struct<latitude: double not null, longitude: double not null, altitude: double>
+ child 0, latitude: double not null
+ child 1, longitude: double not null
+ child 2, altitude: double"""
+ )
+
+
+def test_projection_nested_struct_different_parent_id(file_struct: str) -> None:
+ schema = Schema(
+ NestedField(
+ 5, # 😱 this is 4 in the file, this will be fixed when projecting the file schema
+ "location",
+ StructType(
+ NestedField(41, "lat", DoubleType(), required=False), NestedField(42, "long", DoubleType(), required=False)
+ ),
+ required=False,
+ )
+ )
+
+ result_table = project(schema, [file_struct])
+ for actual, expected in zip(result_table.columns[0], [None, None, None]):
+ assert actual.as_py() == expected
+ assert len(result_table.columns[0]) == 3
+ assert (
+ repr(result_table.schema)
+ == """location: struct<lat: double, long: double>
+ child 0, lat: double
+ child 1, long: double"""
+ )
+
+
+def test_projection_filter_on_unprojected_field(schema_int_str: Schema, file_int_str: str) -> None:
+    """Filtering on `data` while projecting only `id` keeps just the row that passes the filter."""
+    id_only = Schema(NestedField(1, "id", IntegerType()))
+    row_filter = GreaterThan("data", "1")
+
+    result_table = project(id_only, [file_int_str], row_filter, schema_int_str)
+
+    column = result_table.columns[0]
+    assert [value.as_py() for value in column] == [2]
+    assert len(column) == 1
+    assert repr(result_table.schema) == "id: int32 not null"
+
+
+def test_projection_filter_on_unknown_field(schema_int_str: Schema, file_int_str: str) -> None:
+    """Filtering on a field name that is not in the table schema raises ValueError."""
+    id_only = Schema(NestedField(1, "id", IntegerType()))
+    row_filter = GreaterThan("unknown_field", "1")
+
+    with pytest.raises(ValueError) as exc_info:
+        project(id_only, [file_int_str], row_filter, schema_int_str)
+
+    message = str(exc_info.value)
+    assert "Could not find field with name unknown_field, case_sensitive=True" in message