You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by "icexelloss (via GitHub)" <gi...@apache.org> on 2023/02/27 19:48:10 UTC

[GitHub] [arrow] icexelloss commented on a diff in pull request #34373: GH-34366: [Python] Test run_query with a registered scalar UDF

icexelloss commented on code in PR #34373:
URL: https://github.com/apache/arrow/pull/34373#discussion_r1119235870


##########
python/pyarrow/tests/test_udf.py:
##########
@@ -613,3 +614,180 @@ def test_udt_datasource1_generator():
 def test_udt_datasource1_exception():
     with pytest.raises(RuntimeError, match='datasource1_exception'):
         _test_datasource1_udt(datasource1_exception)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_udf_via_substrait(unary_func_fixture, use_threads):
+    test_table_1 = pa.Table.from_pydict({"t": [1, 2, 3], "p": [4, 5, 6]})
+
+    def table_provider(names, _):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table_1
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = """
+    {
+  "extensionUris": [
+    {
+      "extensionUriAnchor": 1
+    },
+    {
+      "extensionUriAnchor": 2,
+      "uri": "urn:arrow:substrait_simple_extension_function"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "extensionUriReference": 2,
+        "functionAnchor": 1,
+        "name": "y=x+1"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "project": {
+            "common": {
+              "emit": {
+                "outputMapping": [
+                  2,
+                  3,
+                  4
+                ]
+              }
+            },
+            "input": {
+              "project": {
+                "common": {
+                  "emit": {
+                    "outputMapping": [
+                      2,
+                      3
+                    ]
+                  }
+                },
+                "input": {
+                  "read": {
+                    "baseSchema": {
+                      "names": [
+                        "t",
+                        "p"
+                      ],
+                      "struct": {
+                        "types": [
+                          {
+                            "i64": {
+                              "nullability": "NULLABILITY_REQUIRED"
+                            }
+                          },
+                          {
+                            "i64": {
+                              "nullability": "NULLABILITY_NULLABLE"
+                            }
+                          }
+                        ],
+                        "nullability": "NULLABILITY_REQUIRED"
+                      }
+                    },
+                    "namedTable": {
+                      "names": [
+                        "t1"
+                      ]
+                    }
+                  }
+                },
+                "expressions": [
+                  {
+                    "selection": {
+                      "directReference": {
+                        "structField": {}
+                      },
+                      "rootReference": {}
+                    }
+                  },
+                  {
+                    "selection": {
+                      "directReference": {
+                        "structField": {
+                          "field": 1
+                        }
+                      },
+                      "rootReference": {}
+                    }
+                  }
+                ]
+              }
+            },
+            "expressions": [
+              {
+                "selection": {
+                  "directReference": {
+                    "structField": {}
+                  },
+                  "rootReference": {}
+                }
+              },
+              {
+                "selection": {
+                  "directReference": {
+                    "structField": {
+                      "field": 1
+                    }
+                  },
+                  "rootReference": {}
+                }
+              },
+              {
+                "scalarFunction": {
+                  "functionReference": 1,
+                  "outputType": {
+                    "i64": {
+                      "nullability": "NULLABILITY_NULLABLE"
+                    }
+                  },
+                  "arguments": [
+                    {
+                      "value": {
+                        "selection": {
+                          "directReference": {
+                            "structField": {
+                              "field": 1
+                            }
+                          },
+                          "rootReference": {}
+                        }
+                      }
+                    }
+                  ]
+                }
+              }
+            ]
+          }
+        },
+        "names": [
+          "t",
+          "p",
+          "p2"
+        ]
+      }
+    }
+  ]
+}
+    """
+
+    buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+    reader = pa.substrait.run_query(
+        buf, table_provider=table_provider, use_threads=use_threads)
+    res_tb = reader.read_all()
+
+    function, name = unary_func_fixture
+    expected_tb = test_table_1.add_column(2, 'p2', function(
+        mock_scalar_udf_context(10), test_table_1['p']))
+    res_tb = res_tb.rename_columns(['t', 'p', 'p2'])

Review Comment:
   This is because of a known bug: Acero doesn't name the result columns correctly.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org