Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/21 20:40:16 UTC

(arrow-datafusion) branch main updated: improve file path validation when reading parquet (#8267)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 952e7c302b improve file path validation when reading parquet (#8267)
952e7c302b is described below

commit 952e7c302bcdc05d090fe334269f41705f28ceea
Author: Alex Huang <hu...@gmail.com>
AuthorDate: Tue Nov 21 21:40:09 2023 +0100

    improve file path validation when reading parquet (#8267)
    
    * improve file path validation
    
    * fix cli
    
    * update test
    
    * update test
---
 datafusion/core/src/execution/context/mod.rs     |  8 +--
 datafusion/core/src/execution/context/parquet.rs | 69 ++++++++++++++++++++++--
 2 files changed, 71 insertions(+), 6 deletions(-)

diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index b8e111d361..f829092570 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -858,10 +858,12 @@ impl SessionContext {
 
         // check if the file extension matches the expected extension
         for path in &table_paths {
-            let file_name = path.prefix().filename().unwrap_or_default();
-            if !path.as_str().ends_with(&option_extension) && file_name.contains('.') {
+            let file_path = path.as_str();
+            if !file_path.ends_with(option_extension.as_str())
+                && !path.is_collection()
+            {
                 return exec_err!(
-                    "File '{file_name}' does not match the expected extension '{option_extension}'"
+                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
                 );
             }
         }
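
For context, the new rule in the hunk above boils down to: a table path is
exempt from the extension check when it is a collection (a directory URL,
i.e. one ending in '/'); otherwise the full path string must end with the
configured extension. A minimal standalone sketch of that predicate (the
free function and its error construction are illustrative, not code from
this commit):

    use datafusion::datasource::listing::ListingTableUrl;
    use datafusion::error::{DataFusionError, Result};

    // Hypothetical helper mirroring the check inside SessionContext above.
    fn validate_extension(path: &ListingTableUrl, option_extension: &str) -> Result<()> {
        let file_path = path.as_str();
        // Collections (directory URLs such as "data/bbb..bbb/") skip the
        // check; plain file paths must end with the expected extension.
        if !file_path.ends_with(option_extension) && !path.is_collection() {
            return Err(DataFusionError::Execution(format!(
                "File path '{file_path}' does not match the expected extension '{option_extension}'"
            )));
        }
        Ok(())
    }
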
diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs
index 821b1ccf18..5d649d3e6d 100644
--- a/datafusion/core/src/execution/context/parquet.rs
+++ b/datafusion/core/src/execution/context/parquet.rs
@@ -138,10 +138,10 @@ mod tests {
         Ok(())
     }
 
-    #[cfg(not(target_family = "windows"))]
     #[tokio::test]
     async fn read_from_different_file_extension() -> Result<()> {
         let ctx = SessionContext::new();
+        let sep = std::path::MAIN_SEPARATOR.to_string();
 
         // Make up a new dataframe.
         let write_df = ctx.read_batch(RecordBatch::try_new(
@@ -175,6 +175,25 @@ mod tests {
             .unwrap()
             .to_string();
 
+        let path4 = temp_dir_path
+            .join("output4.parquet".to_owned() + &sep)
+            .to_str()
+            .unwrap()
+            .to_string();
+
+        let path5 = temp_dir_path
+            .join("bbb..bbb")
+            .join("filename.parquet")
+            .to_str()
+            .unwrap()
+            .to_string();
+        let dir = temp_dir_path
+            .join("bbb..bbb".to_owned() + &sep)
+            .to_str()
+            .unwrap()
+            .to_string();
+        std::fs::create_dir(dir).expect("create dir failed");
+
         // Write the dataframe to a parquet file named 'output1.parquet'
         write_df
             .clone()
@@ -205,6 +224,7 @@ mod tests {
 
         // Write the dataframe to a parquet file named 'output3.parquet.snappy.parquet'
         write_df
+            .clone()
             .write_parquet(
                 &path3,
                 DataFrameWriteOptions::new().with_single_file_output(true),
@@ -216,6 +236,19 @@ mod tests {
             )
             .await?;
 
+        // Write the dataframe to a parquet file named 'bbb..bbb/filename.parquet'
+        write_df
+            .write_parquet(
+                &path5,
+                DataFrameWriteOptions::new().with_single_file_output(true),
+                Some(
+                    WriterProperties::builder()
+                        .set_compression(Compression::SNAPPY)
+                        .build(),
+                ),
+            )
+            .await?;
+
         // Read the dataframe from 'output1.parquet' with the default file extension.
         let read_df = ctx
             .read_parquet(
@@ -253,10 +286,11 @@ mod tests {
                 },
             )
             .await;
-
+        let binding = DataFilePaths::to_urls(&path2).unwrap();
+        let expected_path = binding[0].as_str();
         assert_eq!(
             read_df.unwrap_err().strip_backtrace(),
-            "Execution error: File 'output2.parquet.snappy' does not match the expected extension '.parquet'"
+            format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expexted_path)
         );
 
         // Read the dataframe from 'output3.parquet.snappy.parquet' with the correct file extension.
@@ -269,6 +303,35 @@ mod tests {
             )
             .await?;
 
+        let results = read_df.collect().await?;
+        let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
+        assert_eq!(total_rows, 5);
+
+        // Read the dataframe from the 'output4.parquet/' directory
+        std::fs::create_dir(&path4)?;
+        let read_df = ctx
+            .read_parquet(
+                &path4,
+                ParquetReadOptions {
+                    ..Default::default()
+                },
+            )
+            .await?;
+
+        let results = read_df.collect().await?;
+        let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
+        assert_eq!(total_rows, 0);
+
+        // Read the dataframe from the double dot folder
+        let read_df = ctx
+            .read_parquet(
+                &path5,
+                ParquetReadOptions {
+                    ..Default::default()
+                },
+            )
+            .await?;
+
         let results = read_df.collect().await?;
         let total_rows: usize = results.iter().map(|rb| rb.num_rows()).sum();
         assert_eq!(total_rows, 5);
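
The updated test now also covers reading a directory URL ('output4.parquet/')
and a file under a directory whose name contains dots ('bbb..bbb/'), both of
which exercise dots that are not part of a file extension. A hypothetical
end-to-end read of a dot-named directory under the new rule (the path and
the main harness are illustrative, not part of the commit):

    use datafusion::error::Result;
    use datafusion::prelude::*;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        // A directory name containing dots previously tripped the extension
        // check; as a collection it is now accepted and its files listed.
        let df = ctx
            .read_parquet("data/bbb..bbb/", ParquetReadOptions::default())
            .await?;
        df.show().await?;
        Ok(())
    }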