You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2024/02/15 13:11:50 UTC

(arrow) branch main updated: GH-40055: [Java][Docs] Simplify use of Filter and Expression into Dataset Substrait (#40056)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new a03d957b5b GH-40055: [Java][Docs] Simplify use of Filter and Expression into Dataset Substrait (#40056)
a03d957b5b is described below

commit a03d957b5b8d0425f9d5b6c98b6ee1efa56a1248
Author: david dali susanibar arce <da...@gmail.com>
AuthorDate: Thu Feb 15 08:11:44 2024 -0500

    GH-40055: [Java][Docs] Simplify use of Filter and Expression into Dataset Substrait (#40056)
    
    ### Rationale for this change
    
    Simplify the creation of SQL expression filters and projections in the Arrow Java Dataset module by using the new [Substrait feature for SQL expressions](https://github.com/substrait-io/substrait-java/releases/tag/v0.26.0).
    
    ### What changes are included in this PR?
    
    Update Apache Arrow Java Dataset Substrait documentation
    
    ### Are these changes tested?
    
    Yes
    
    ### Are there any user-facing changes?
    
    No
    * Closes: #40055
    
    Authored-by: david dali susanibar arce <da...@gmail.com>
    Signed-off-by: David Li <li...@gmail.com>
---
 docs/source/java/substrait.rst | 333 ++++++-----------------------------------
 1 file changed, 42 insertions(+), 291 deletions(-)

diff --git a/docs/source/java/substrait.rst b/docs/source/java/substrait.rst
index d8d49a96e8..c5857dcc23 100644
--- a/docs/source/java/substrait.rst
+++ b/docs/source/java/substrait.rst
@@ -113,31 +113,19 @@ This requires the substrait-java library.
 This Java program:
 
 - Loads a Parquet file containing the "nation" table from the TPC-H benchmark.
+- Applies a filter:
+    - ``N_NATIONKEY > 18``
 - Projects two new columns:
-    - ``N_NAME || ' - ' || N_COMMENT``
     - ``N_REGIONKEY + 10``
-- Applies a filter: ``N_NATIONKEY > 18``
+    - ``N_NAME || ' - ' || N_COMMENT``
+
+
 
 .. code-block:: Java
 
-    import io.substrait.extension.ExtensionCollector;
-    import io.substrait.proto.Expression;
-    import io.substrait.proto.ExpressionReference;
+    import com.google.common.collect.ImmutableList;
+    import io.substrait.isthmus.SqlExpressionToSubstrait;
     import io.substrait.proto.ExtendedExpression;
-    import io.substrait.proto.FunctionArgument;
-    import io.substrait.proto.SimpleExtensionDeclaration;
-    import io.substrait.proto.SimpleExtensionURI;
-    import io.substrait.type.NamedStruct;
-    import io.substrait.type.Type;
-    import io.substrait.type.TypeCreator;
-    import io.substrait.type.proto.TypeProtoConverter;
-    import java.nio.ByteBuffer;
-    import java.util.ArrayList;
-    import java.util.Arrays;
-    import java.util.Base64;
-    import java.util.HashMap;
-    import java.util.List;
-    import java.util.Optional;
     import org.apache.arrow.dataset.file.FileFormat;
     import org.apache.arrow.dataset.file.FileSystemDatasetFactory;
     import org.apache.arrow.dataset.jni.NativeMemoryPool;
@@ -148,297 +136,60 @@ This Java program:
     import org.apache.arrow.memory.BufferAllocator;
     import org.apache.arrow.memory.RootAllocator;
     import org.apache.arrow.vector.ipc.ArrowReader;
+    import org.apache.calcite.sql.parser.SqlParseException;
+
+    import java.nio.ByteBuffer;
+    import java.util.Base64;
+    import java.util.Optional;
 
     public class ClientSubstraitExtendedExpressionsCookbook {
 
-      public static void main(String[] args) throws Exception {
-        // project and filter dataset using extended expression definition - 03 Expressions:
-        // Expression 01 - CONCAT: N_NAME || ' - ' || N_COMMENT = col 1 || ' - ' || col 3
-        // Expression 02 - ADD: N_REGIONKEY + 10 = col 1 + 10
-        // Expression 03 - FILTER: N_NATIONKEY > 18 = col 3 > 18
+      public static void main(String[] args) throws SqlParseException {
         projectAndFilterDataset();
       }
 
-      public static void projectAndFilterDataset() {
+      private static void projectAndFilterDataset() throws SqlParseException {
         String uri = "file:///Users/data/tpch_parquet/nation.parquet";
-        ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768)
-            .columns(Optional.empty())
-            .substraitFilter(getSubstraitExpressionFilter())
-            .substraitProjection(getSubstraitExpressionProjection())
-            .build();
-        try (
-            BufferAllocator allocator = new RootAllocator();
-            DatasetFactory datasetFactory = new FileSystemDatasetFactory(
-                allocator, NativeMemoryPool.getDefault(),
-                FileFormat.PARQUET, uri);
-            Dataset dataset = datasetFactory.finish();
-            Scanner scanner = dataset.newScan(options);
-            ArrowReader reader = scanner.scanBatches()
-        ) {
+        ScanOptions options =
+            new ScanOptions.Builder(/*batchSize*/ 32768)
+                .columns(Optional.empty())
+                .substraitFilter(getByteBuffer(new String[]{"N_NATIONKEY > 18"}))
+                .substraitProjection(getByteBuffer(new String[]{"N_REGIONKEY + 10",
+                    "N_NAME || CAST(' - ' as VARCHAR) || N_COMMENT"}))
+                .build();
+        try (BufferAllocator allocator = new RootAllocator();
+             DatasetFactory datasetFactory =
+                 new FileSystemDatasetFactory(
+                     allocator, NativeMemoryPool.getDefault(), FileFormat.PARQUET, uri);
+             Dataset dataset = datasetFactory.finish();
+             Scanner scanner = dataset.newScan(options);
+             ArrowReader reader = scanner.scanBatches()) {
           while (reader.loadNextBatch()) {
-            System.out.println(
-                reader.getVectorSchemaRoot().contentToTSVString());
+            System.out.println(reader.getVectorSchemaRoot().contentToTSVString());
           }
         } catch (Exception e) {
           throw new RuntimeException(e);
         }
       }
 
-      private static ByteBuffer getSubstraitExpressionProjection() {
-        // Expression: N_REGIONKEY + 10 = col 3 + 10
-        Expression.Builder selectionBuilderProjectOne = Expression.newBuilder().
-            setSelection(
-                Expression.FieldReference.newBuilder().
-                    setDirectReference(
-                        Expression.ReferenceSegment.newBuilder().
-                            setStructField(
-                                Expression.ReferenceSegment.StructField.newBuilder().setField(
-                                    2)
-                            )
-                    )
-            );
-        Expression.Builder literalBuilderProjectOne = Expression.newBuilder()
-            .setLiteral(
-                Expression.Literal.newBuilder().setI32(10)
-            );
-        io.substrait.proto.Type outputProjectOne = TypeCreator.NULLABLE.I32.accept(
-            new TypeProtoConverter(new ExtensionCollector()));
-        Expression.Builder expressionBuilderProjectOne = Expression.
-            newBuilder().
-            setScalarFunction(
-                Expression.
-                    ScalarFunction.
-                    newBuilder().
-                    setFunctionReference(0).
-                    setOutputType(outputProjectOne).
-                    addArguments(
-                        0,
-                        FunctionArgument.newBuilder().setValue(
-                            selectionBuilderProjectOne)
-                    ).
-                    addArguments(
-                        1,
-                        FunctionArgument.newBuilder().setValue(
-                            literalBuilderProjectOne)
-                    )
-            );
-        ExpressionReference.Builder expressionReferenceBuilderProjectOne = ExpressionReference.newBuilder().
-            setExpression(expressionBuilderProjectOne)
-            .addOutputNames("ADD_TEN_TO_COLUMN_N_REGIONKEY");
-
-        // Expression: name || name = N_NAME || "-" || N_COMMENT = col 1 || col 3
-        Expression.Builder selectionBuilderProjectTwo = Expression.newBuilder().
-            setSelection(
-                Expression.FieldReference.newBuilder().
-                    setDirectReference(
-                        Expression.ReferenceSegment.newBuilder().
-                            setStructField(
-                                Expression.ReferenceSegment.StructField.newBuilder().setField(
-                                    1)
-                            )
-                    )
-            );
-        Expression.Builder selectionBuilderProjectTwoConcatLiteral = Expression.newBuilder()
-            .setLiteral(
-                Expression.Literal.newBuilder().setString(" - ")
-            );
-        Expression.Builder selectionBuilderProjectOneToConcat = Expression.newBuilder().
-            setSelection(
-                Expression.FieldReference.newBuilder().
-                    setDirectReference(
-                        Expression.ReferenceSegment.newBuilder().
-                            setStructField(
-                                Expression.ReferenceSegment.StructField.newBuilder().setField(
-                                    3)
-                            )
-                    )
-            );
-        io.substrait.proto.Type outputProjectTwo = TypeCreator.NULLABLE.STRING.accept(
-            new TypeProtoConverter(new ExtensionCollector()));
-        Expression.Builder expressionBuilderProjectTwo = Expression.
-            newBuilder().
-            setScalarFunction(
-                Expression.
-                    ScalarFunction.
-                    newBuilder().
-                    setFunctionReference(1).
-                    setOutputType(outputProjectTwo).
-                    addArguments(
-                        0,
-                        FunctionArgument.newBuilder().setValue(
-                            selectionBuilderProjectTwo)
-                    ).
-                    addArguments(
-                        1,
-                        FunctionArgument.newBuilder().setValue(
-                            selectionBuilderProjectTwoConcatLiteral)
-                    ).
-                    addArguments(
-                        2,
-                        FunctionArgument.newBuilder().setValue(
-                            selectionBuilderProjectOneToConcat)
-                    )
-            );
-        ExpressionReference.Builder expressionReferenceBuilderProjectTwo = ExpressionReference.newBuilder().
-            setExpression(expressionBuilderProjectTwo)
-            .addOutputNames("CONCAT_COLUMNS_N_NAME_AND_N_COMMENT");
-
-        List<String> columnNames = Arrays.asList("N_NATIONKEY", "N_NAME",
-            "N_REGIONKEY", "N_COMMENT");
-        List<Type> dataTypes = Arrays.asList(
-            TypeCreator.NULLABLE.I32,
-            TypeCreator.NULLABLE.STRING,
-            TypeCreator.NULLABLE.I32,
-            TypeCreator.NULLABLE.STRING
-        );
-        NamedStruct of = NamedStruct.of(
-            columnNames,
-            Type.Struct.builder().fields(dataTypes).nullable(false).build()
-        );
-        // Extensions URI
-        HashMap<String, SimpleExtensionURI> extensionUris = new HashMap<>();
-        extensionUris.put(
-            "key-001",
-            SimpleExtensionURI.newBuilder()
-                .setExtensionUriAnchor(1)
-                .setUri("/functions_arithmetic.yaml")
-                .build()
-        );
-        // Extensions
-        ArrayList<SimpleExtensionDeclaration> extensions = new ArrayList<>();
-        SimpleExtensionDeclaration extensionFunctionAdd = SimpleExtensionDeclaration.newBuilder()
-            .setExtensionFunction(
-                SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
-                    .setFunctionAnchor(0)
-                    .setName("add:i32_i32")
-                    .setExtensionUriReference(1))
-            .build();
-        SimpleExtensionDeclaration extensionFunctionGreaterThan = SimpleExtensionDeclaration.newBuilder()
-            .setExtensionFunction(
-                SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
-                    .setFunctionAnchor(1)
-                    .setName("concat:vchar")
-                    .setExtensionUriReference(2))
-            .build();
-        extensions.add(extensionFunctionAdd);
-        extensions.add(extensionFunctionGreaterThan);
-        // Extended Expression
-        ExtendedExpression.Builder extendedExpressionBuilder =
-            ExtendedExpression.newBuilder().
-                addReferredExpr(0,
-                    expressionReferenceBuilderProjectOne).
-                addReferredExpr(1,
-                    expressionReferenceBuilderProjectTwo).
-                setBaseSchema(of.toProto(new TypeProtoConverter(
-                    new ExtensionCollector())));
-        extendedExpressionBuilder.addAllExtensionUris(extensionUris.values());
-        extendedExpressionBuilder.addAllExtensions(extensions);
-        ExtendedExpression extendedExpression = extendedExpressionBuilder.build();
-        byte[] extendedExpressions = Base64.getDecoder().decode(
-            Base64.getEncoder().encodeToString(
-                extendedExpression.toByteArray()));
-        ByteBuffer substraitExpressionProjection = ByteBuffer.allocateDirect(
-            extendedExpressions.length);
-        substraitExpressionProjection.put(extendedExpressions);
-        return substraitExpressionProjection;
-      }
-
-      private static ByteBuffer getSubstraitExpressionFilter() {
-        // Expression: Filter: N_NATIONKEY > 18 = col 1 > 18
-        Expression.Builder selectionBuilderFilterOne = Expression.newBuilder().
-            setSelection(
-                Expression.FieldReference.newBuilder().
-                    setDirectReference(
-                        Expression.ReferenceSegment.newBuilder().
-                            setStructField(
-                                Expression.ReferenceSegment.StructField.newBuilder().setField(
-                                    0)
-                            )
-                    )
-            );
-        Expression.Builder literalBuilderFilterOne = Expression.newBuilder()
-            .setLiteral(
-                Expression.Literal.newBuilder().setI32(18)
-            );
-        io.substrait.proto.Type outputFilterOne = TypeCreator.NULLABLE.BOOLEAN.accept(
-            new TypeProtoConverter(new ExtensionCollector()));
-        Expression.Builder expressionBuilderFilterOne = Expression.
-            newBuilder().
-            setScalarFunction(
-                Expression.
-                    ScalarFunction.
-                    newBuilder().
-                    setFunctionReference(1).
-                    setOutputType(outputFilterOne).
-                    addArguments(
-                        0,
-                        FunctionArgument.newBuilder().setValue(
-                            selectionBuilderFilterOne)
-                    ).
-                    addArguments(
-                        1,
-                        FunctionArgument.newBuilder().setValue(
-                            literalBuilderFilterOne)
-                    )
-            );
-        ExpressionReference.Builder expressionReferenceBuilderFilterOne = ExpressionReference.newBuilder().
-            setExpression(expressionBuilderFilterOne)
-            .addOutputNames("COLUMN_N_NATIONKEY_GREATER_THAN_18");
-
-        List<String> columnNames = Arrays.asList("N_NATIONKEY", "N_NAME",
-            "N_REGIONKEY", "N_COMMENT");
-        List<Type> dataTypes = Arrays.asList(
-            TypeCreator.NULLABLE.I32,
-            TypeCreator.NULLABLE.STRING,
-            TypeCreator.NULLABLE.I32,
-            TypeCreator.NULLABLE.STRING
-        );
-        NamedStruct of = NamedStruct.of(
-            columnNames,
-            Type.Struct.builder().fields(dataTypes).nullable(false).build()
-        );
-        // Extensions URI
-        HashMap<String, SimpleExtensionURI> extensionUris = new HashMap<>();
-        extensionUris.put(
-            "key-001",
-            SimpleExtensionURI.newBuilder()
-                .setExtensionUriAnchor(1)
-                .setUri("/functions_comparison.yaml")
-                .build()
-        );
-        // Extensions
-        ArrayList<SimpleExtensionDeclaration> extensions = new ArrayList<>();
-        SimpleExtensionDeclaration extensionFunctionLowerThan = SimpleExtensionDeclaration.newBuilder()
-            .setExtensionFunction(
-                SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
-                    .setFunctionAnchor(1)
-                    .setName("gt:any_any")
-                    .setExtensionUriReference(1))
-            .build();
-        extensions.add(extensionFunctionLowerThan);
-        // Extended Expression
-        ExtendedExpression.Builder extendedExpressionBuilder =
-            ExtendedExpression.newBuilder().
-                addReferredExpr(0,
-                    expressionReferenceBuilderFilterOne).
-                setBaseSchema(of.toProto(new TypeProtoConverter(
-                    new ExtensionCollector())));
-        extendedExpressionBuilder.addAllExtensionUris(extensionUris.values());
-        extendedExpressionBuilder.addAllExtensions(extensions);
-        ExtendedExpression extendedExpression = extendedExpressionBuilder.build();
-        byte[] extendedExpressions = Base64.getDecoder().decode(
-            Base64.getEncoder().encodeToString(
-                extendedExpression.toByteArray()));
-        ByteBuffer substraitExpressionFilter = ByteBuffer.allocateDirect(
-            extendedExpressions.length);
-        substraitExpressionFilter.put(extendedExpressions);
-        return substraitExpressionFilter;
+      private static ByteBuffer getByteBuffer(String[] sqlExpression) throws SqlParseException {
+        String schema =
+            "CREATE TABLE NATION (N_NATIONKEY INT NOT NULL, N_NAME VARCHAR, "
+                + "N_REGIONKEY INT NOT NULL, N_COMMENT VARCHAR)";
+        SqlExpressionToSubstrait expressionToSubstrait = new SqlExpressionToSubstrait();
+        ExtendedExpression expression =
+            expressionToSubstrait.convert(sqlExpression, ImmutableList.of(schema));
+        byte[] expressionToByte =
+            Base64.getDecoder().decode(Base64.getEncoder().encodeToString(expression.toByteArray()));
+        ByteBuffer byteBuffer = ByteBuffer.allocateDirect(expressionToByte.length);
+        byteBuffer.put(expressionToByte);
+        return byteBuffer;
       }
     }
 
 .. code-block:: text
 
-    ADD_TEN_TO_COLUMN_N_REGIONKEY	CONCAT_COLUMNS_N_NAME_AND_N_COMMENT
+    column-1	column-2
     13	ROMANIA - ular asymptotes are about the furious multipliers. express dependencies nag above the ironically ironic account
     14	SAUDI ARABIA - ts. silent requests haggle. closely express packages sleep across the blithely
     12	VIETNAM - hely enticingly express accounts. even, final