Posted to commits@sedona.apache.org by ji...@apache.org on 2022/08/22 05:52:58 UTC
[incubator-sedona] branch master updated: [SEDONA-94] GeoParquet Reader Writer (#652)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 547ff404 [SEDONA-94] GeoParquet Reader Writer (#652)
547ff404 is described below
commit 547ff40466b79fc9f54f771afa78306f85fefefd
Author: ashar <96...@users.noreply.github.com>
AuthorDate: Sun Aug 21 22:52:52 2022 -0700
[SEDONA-94] GeoParquet Reader Writer (#652)
Co-authored-by: Jia Yu <ji...@apache.org>
Co-authored-by: Jia Yu <ji...@gmail.com>
---
.github/workflows/java.yml | 1 +
.../src/test/resources/geoparquet/example1.parquet | Bin 0 -> 27797 bytes
.../src/test/resources/geoparquet/example2.parquet | Bin 0 -> 3202 bytes
.../src/test/resources/geoparquet/example3.parquet | Bin 0 -> 943485 bytes
docs/tutorial/sql.md | 45 ++
...org.apache.spark.sql.sources.DataSourceRegister | 3 +-
.../datasources/parquet/GeoDataSourceUtils.scala | 89 +++
.../datasources/parquet/GeoDateTimeUtils.scala | 38 ++
.../datasources/parquet/GeoParquetFileFormat.scala | 272 +++++++++
.../datasources/parquet/GeoParquetFilters.scala | 640 +++++++++++++++++++
.../datasources/parquet/GeoParquetOptions.scala | 27 +
.../parquet/GeoParquetReadSupport.scala | 392 ++++++++++++
.../parquet/GeoParquetRecordMaterializer.scala | 59 ++
.../parquet/GeoParquetRowConverter.scala | 675 +++++++++++++++++++++
.../parquet/GeoParquetSchemaConverter.scala | 576 ++++++++++++++++++
.../datasources/parquet/GeoParquetUtils.scala | 80 +++
.../datasources/parquet/GeoSchemaMergeUtils.scala | 102 ++++
.../datasources/parquet/GeometryField.scala | 28 +
.../org/apache/sedona/sql/geoparquetIOTests.scala | 66 ++
19 files changed, 3092 insertions(+), 1 deletion(-)
diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml
index c8712a95..247d592b 100644
--- a/.github/workflows/java.yml
+++ b/.github/workflows/java.yml
@@ -13,6 +13,7 @@ jobs:
runs-on: ubuntu-18.04
strategy:
+ fail-fast: true
matrix:
include:
- spark: 3.3.0
diff --git a/core/src/test/resources/geoparquet/example1.parquet b/core/src/test/resources/geoparquet/example1.parquet
new file mode 100644
index 00000000..5e89f107
Binary files /dev/null and b/core/src/test/resources/geoparquet/example1.parquet differ
diff --git a/core/src/test/resources/geoparquet/example2.parquet b/core/src/test/resources/geoparquet/example2.parquet
new file mode 100644
index 00000000..a1f0105f
Binary files /dev/null and b/core/src/test/resources/geoparquet/example2.parquet differ
diff --git a/core/src/test/resources/geoparquet/example3.parquet b/core/src/test/resources/geoparquet/example3.parquet
new file mode 100644
index 00000000..ff3c4c3e
Binary files /dev/null and b/core/src/test/resources/geoparquet/example3.parquet differ
diff --git a/docs/tutorial/sql.md b/docs/tutorial/sql.md
index 56924bfa..cb9de422 100644
--- a/docs/tutorial/sql.md
+++ b/docs/tutorial/sql.md
@@ -136,6 +136,43 @@ root
Shapefile and GeoJSON must be loaded by SpatialRDD and converted to DataFrame using Adapter. Please read [Load SpatialRDD](../rdd/#create-a-generic-spatialrdd) and [DataFrame <-> RDD](#convert-between-dataframe-and-spatialrdd).
+## Load GeoParquet
+
+GeoParquet is loaded through the DataFrame reader. If the geometry column uses the default name `geometry`, no extra option is needed.
+
+```Scala
+val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation1)
+df.printSchema()
+```
+The output will be as follows:
+
+```
+root
+ |-- pop_est: long (nullable = true)
+ |-- continent: string (nullable = true)
+ |-- name: string (nullable = true)
+ |-- iso_a3: string (nullable = true)
+ |-- gdp_md_est: double (nullable = true)
+ |-- geometry: geometry (nullable = true)
+```
+If the geometry column has a different name, specify it with the `fieldGeometry` option:
+
+```Scala
+val df = sparkSession.read.format("geoparquet").option("fieldGeometry", "new_geometry").load(geoparquetdatalocation1)
+df.printSchema()
+```
+
+The output will be as follows:
+
+```
+root
+ |-- pop_est: long (nullable = true)
+ |-- continent: string (nullable = true)
+ |-- name: string (nullable = true)
+ |-- iso_a3: string (nullable = true)
+ |-- gdp_md_est: double (nullable = true)
+ |-- new_geometry: geometry (nullable = true)
+```
+
## Transform the Coordinate Reference System
@@ -240,6 +277,14 @@ var stringDf = sparkSession.sql(
ST_AsGeoJSON is also available. We would like to invite you to contribute more functions.
+## Save GeoParquet
+
+GeoParquet can be saved as follows:
+
+```Scala
+df.write.format("geoparquet").save(geoparquetoutputlocation + "/GeoParquet_File_Name.parquet")
+```
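+
+To sanity-check the written output, the file can be read back with the same data source (a minimal round-trip sketch; `geoparquetoutputlocation` is the directory used above):
+
+```Scala
+val savedDf = sparkSession.read.format("geoparquet").load(geoparquetoutputlocation + "/GeoParquet_File_Name.parquet")
+savedDf.printSchema()
+```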
+
## Convert between DataFrame and SpatialRDD
### DataFrame to SpatialRDD
diff --git a/sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 6592fb0a..68ea723a 100644
--- a/sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1 +1,2 @@
-org.apache.spark.sql.sedona_sql.io.GeotiffFileFormat
\ No newline at end of file
+org.apache.spark.sql.sedona_sql.io.GeotiffFileFormat
+org.apache.spark.sql.execution.datasources.parquet.GeoParquetFileFormat
\ No newline at end of file
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala
new file mode 100644
index 00000000..c6ee8e3c
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDataSourceUtils.scala
@@ -0,0 +1,89 @@
+/**
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.spark.sql.catalyst.util.RebaseDateTime
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.util.Utils
+
+// Needed by Sedona to support Spark 3.0 - 3.3
+object GeoDataSourceUtils {
+ def datetimeRebaseMode(
+ lookupFileMeta: String => String,
+ modeByConfig: String): LegacyBehaviorPolicy.Value = {
+ if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") {
+ return LegacyBehaviorPolicy.CORRECTED
+ }
+ // If there is no version, we return the mode specified by the config.
+ Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version =>
+ // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to
+ // rebase the datetime values.
+ // Files written by Spark 3.0 and later may also need the rebase if they were written with
+ // the "LEGACY" rebase mode.
+ if (version < "3.0.0" || lookupFileMeta("org.apache.spark.legacyDateTime") != null) {
+ LegacyBehaviorPolicy.LEGACY
+ } else {
+ LegacyBehaviorPolicy.CORRECTED
+ }
+ }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig))
+ }
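+
+ // Worked example (hypothetical metadata): a file carrying SPARK_VERSION_METADATA_KEY = "2.4.5"
+ // yields LEGACY because "2.4.5" < "3.0.0"; a file with "3.2.1" and no
+ // "org.apache.spark.legacyDateTime" entry yields CORRECTED regardless of modeByConfig.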
+
+ def int96RebaseMode(
+ lookupFileMeta: String => String,
+ modeByConfig: String): LegacyBehaviorPolicy.Value = {
+ if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") {
+ return LegacyBehaviorPolicy.CORRECTED
+ }
+ // If there is no version, we return the mode specified by the config.
+ Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version =>
+ // Files written by Spark 3.0 and earlier follow the legacy hybrid calendar and we need to
+ // rebase the INT96 timestamp values.
+ // Files written by Spark 3.1 and later may also need the rebase if they were written with
+ // the "LEGACY" rebase mode.
+ if (version < "3.1.0" || lookupFileMeta("org.apache.spark.legacyINT96") != null) {
+ LegacyBehaviorPolicy.LEGACY
+ } else {
+ LegacyBehaviorPolicy.CORRECTED
+ }
+ }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig))
+ }
+
+ def creteDateRebaseFuncInRead(
+ rebaseMode: LegacyBehaviorPolicy.Value,
+ format: String): Int => Int = rebaseMode match {
+ case LegacyBehaviorPolicy.EXCEPTION => days: Int =>
+ if (days < RebaseDateTime.lastSwitchJulianDay) {
+ throw DataSourceUtils.newRebaseExceptionInRead(format)
+ }
+ days
+ case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays
+ case LegacyBehaviorPolicy.CORRECTED => identity[Int]
+ }
+
+ def creteTimestampRebaseFuncInRead(
+ rebaseMode: LegacyBehaviorPolicy.Value,
+ format: String): Long => Long = rebaseMode match {
+ case LegacyBehaviorPolicy.EXCEPTION => micros: Long =>
+ if (micros < RebaseDateTime.lastSwitchJulianTs) {
+ throw DataSourceUtils.newRebaseExceptionInRead(format)
+ }
+ micros
+ case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianMicros
+ case LegacyBehaviorPolicy.CORRECTED => identity[Long]
+ }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala
new file mode 100644
index 00000000..7a661542
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoDateTimeUtils.scala
@@ -0,0 +1,38 @@
+/**
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS
+
+// Needed by Sedona to support Spark 3.0 - 3.3
+object GeoDateTimeUtils {
+ /**
+ * Converts the timestamp to milliseconds since the epoch. In Spark, timestamp values have microsecond
+ * precision, so this conversion is lossy.
+ */
+ def microsToMillis(micros: Long): Long = {
+ // When the timestamp is negative, i.e. before 1970, we need to adjust the milliseconds portion.
+ // Example - 1965-01-01 10:11:12.123456 is represented as (-157700927876544) in micro precision.
+ // In millis precision the above needs to be represented as (-157700927877).
+ Math.floorDiv(micros, MICROS_PER_MILLIS)
+ }
+
+ /**
+ * Converts milliseconds since the epoch to microseconds.
+ */
+ def millisToMicros(millis: Long): Long = {
+ Math.multiplyExact(millis, MICROS_PER_MILLIS)
+ }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
new file mode 100644
index 00000000..93d0f5b1
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFileFormat.scala
@@ -0,0 +1,272 @@
+/**
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.parquet.filter2.compat.FilterCompat
+import org.apache.parquet.filter2.predicate.FilterApi
+import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS
+import org.apache.parquet.hadoop._
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat.readParquetFootersInParallel
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.SerializableConfiguration
+
+import java.net.URI
+import java.util.NoSuchElementException
+
+class GeoParquetFileFormat extends ParquetFileFormat with FileFormat with DataSourceRegister with Logging with Serializable {
+
+ override def shortName(): String = "geoparquet"
+
+ override def equals(other: Any): Boolean = other.isInstanceOf[GeoParquetFileFormat]
+
+ override def inferSchema(
+ sparkSession: SparkSession,
+ parameters: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] = {
+ val fieldGeometry = new GeoParquetOptions(parameters).fieldGeometry
+ GeometryField.setFieldGeometry(fieldGeometry)
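+ // The requested geometry column name is stashed in the shared GeometryField holder so the
+ // schema conversion can later resolve which column carries geometry (inferred from this
+ // call site; GeometryField's definition is outside this hunk).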
+ GeoParquetUtils.inferSchema(sparkSession, parameters, files)
+ }
+
+ override def buildReaderWithPartitionValues(
+ sparkSession: SparkSession,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String],
+ hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
+ hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
+ hadoopConf.set(
+ ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
+ requiredSchema.json)
+ hadoopConf.set(
+ ParquetWriteSupport.SPARK_ROW_SCHEMA,
+ requiredSchema.json)
+ hadoopConf.set(
+ SQLConf.SESSION_LOCAL_TIMEZONE.key,
+ sparkSession.sessionState.conf.sessionLocalTimeZone)
+ hadoopConf.setBoolean(
+ SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
+ sparkSession.sessionState.conf.nestedSchemaPruningEnabled)
+ hadoopConf.setBoolean(
+ SQLConf.CASE_SENSITIVE.key,
+ sparkSession.sessionState.conf.caseSensitiveAnalysis)
+
+ ParquetWriteSupport.setSchema(requiredSchema, hadoopConf)
+
+ // Sets flags for `ParquetToSparkSchemaConverter`
+ hadoopConf.setBoolean(
+ SQLConf.PARQUET_BINARY_AS_STRING.key,
+ sparkSession.sessionState.conf.isParquetBinaryAsString)
+ hadoopConf.setBoolean(
+ SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
+ sparkSession.sessionState.conf.isParquetINT96AsTimestamp)
+
+ val broadcastedHadoopConf =
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+
+ // TODO: if you move this into the closure it reverts to the default values.
+ // If true, enable using the custom RecordReader for parquet. This only works for
+ // a subset of the types (no complex types).
+ val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields)
+ val sqlConf = sparkSession.sessionState.conf
+ val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled
+ val enableVectorizedReader: Boolean =
+ sqlConf.parquetVectorizedReaderEnabled &&
+ resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
+ val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled
+ val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion
+ val capacity = sqlConf.parquetVectorizedReaderBatchSize
+ val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown
+ // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
+ val returningBatch = supportBatch(sparkSession, resultSchema)
+ val pushDownDate = sqlConf.parquetFilterPushDownDate
+ val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp
+ val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
+ val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith
+ val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
+ val isCaseSensitive = sqlConf.caseSensitiveAnalysis
+
+ (file: PartitionedFile) => {
+ assert(file.partitionValues.numFields == partitionSchema.size)
+
+ val filePath = new Path(new URI(file.filePath))
+ val split =
+ new org.apache.parquet.hadoop.ParquetInputSplit(
+ filePath,
+ file.start,
+ file.start + file.length,
+ file.length,
+ Array.empty,
+ null)
+
+ val sharedConf = broadcastedHadoopConf.value.value
+
+ lazy val footerFileMetaData =
+ ParquetFileReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData
+ // Try to push down filters when filter push-down is enabled.
+ val pushed = if (enableParquetFilterPushDown) {
+ val parquetSchema = footerFileMetaData.getSchema
+ val parquetFilters = new GeoParquetFilters(parquetSchema, pushDownDate, pushDownTimestamp,
+ pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive)
+ filters
+ // Collects all converted Parquet filter predicates. Notice that not all predicates can be
+ // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
+ // is used here.
+ .flatMap(parquetFilters.createFilter(_))
+ .reduceOption(FilterApi.and)
+ } else {
+ None
+ }
+
+ // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps'
+ // *only* if the file was created by something other than "parquet-mr", so check the actual
+ // writer here for this file. We have to do this per-file, as each file in the table may
+ // have different writers.
+ // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads.
+ def isCreatedByParquetMr: Boolean =
+ footerFileMetaData.getCreatedBy().startsWith("parquet-mr")
+
+ val convertTz =
+ if (timestampConversion && !isCreatedByParquetMr) {
+ Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key)))
+ } else {
+ None
+ }
+ var dateTimeRebaseModeConf: String = ""
+ var int96RebaseModeConf: String = ""
+ try {
+ // Try the Spark 3.2+ config keys first
+ dateTimeRebaseModeConf = SQLConf.get.getConfString("spark.sql.parquet.datetimeRebaseModeInRead")
+ int96RebaseModeConf = SQLConf.get.getConfString("spark.sql.parquet.int96RebaseModeInRead")
+ } catch {
+ case _: NoSuchElementException =>
+ // Fall back to the Spark 3.0/3.1 legacy config keys
+ dateTimeRebaseModeConf = SQLConf.get.getConfString("spark.sql.legacy.parquet.datetimeRebaseModeInRead")
+ try {
+ int96RebaseModeConf = SQLConf.get.getConfString("spark.sql.legacy.parquet.int96RebaseModeInRead")
+ } catch {
+ case _: NoSuchElementException =>
+ // Spark 3.0 has no separate INT96 key; assume the INT96 mode matches the datetime mode
+ int96RebaseModeConf = dateTimeRebaseModeConf
+ }
+ }
+ val datetimeRebaseMode = GeoDataSourceUtils.datetimeRebaseMode(
+ footerFileMetaData.getKeyValueMetaData.get,
+ dateTimeRebaseModeConf)
+ val int96RebaseMode = GeoDataSourceUtils.int96RebaseMode(
+ footerFileMetaData.getKeyValueMetaData.get,
+ int96RebaseModeConf)
+
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val hadoopAttemptContext =
+ new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId)
+
+ // Try to push down filters when filter push-down is enabled.
+ // Notice: This push-down is RowGroups level, not individual records.
+ if (pushed.isDefined) {
+ ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get)
+ }
+ val taskContext = Option(TaskContext.get())
+ if (enableVectorizedReader) {
+ logWarning(s"GeoParquet currently does not support vectorized reader. Falling back to parquet-mr")
+ }
+ logDebug(s"Falling back to parquet-mr")
+ // ParquetRecordReader returns InternalRow
+ val readSupport = new GeoParquetReadSupport(
+ convertTz,
+ enableVectorizedReader = false,
+ datetimeRebaseMode,
+ int96RebaseMode)
+ val reader = if (pushed.isDefined && enableRecordFilter) {
+ val parquetFilter = FilterCompat.get(pushed.get, null)
+ new ParquetRecordReader[InternalRow](readSupport, parquetFilter)
+ } else {
+ new ParquetRecordReader[InternalRow](readSupport)
+ }
+ val iter = new RecordReaderIterator[InternalRow](reader)
+ // SPARK-23457 Register a task completion listener before `initialization`.
+ taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
+ reader.initialize(split, hadoopAttemptContext)
+
+ val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+ val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+
+ if (partitionSchema.length == 0) {
+ // There are no partition columns
+ iter.map(unsafeProjection)
+ } else {
+ val joinedRow = new JoinedRow()
+ iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues)))
+ }
+ }
+ }
+
+ override def supportDataType(dataType: DataType): Boolean = super.supportDataType(dataType)
+}
+
+object GeoParquetFileFormat extends Logging {
+
+ /**
+ * Figures out a merged Parquet schema with a distributed Spark job.
+ *
+ * Note that locality is not taken into consideration here because:
+ *
+ * 1. For a single Parquet part-file, in most cases the footer only resides in the last block of
+ * that file. Thus we only need to retrieve the location of the last block. However, Hadoop
+ * `FileSystem` only provides API to retrieve locations of all blocks, which can be
+ * potentially expensive.
+ *
+ * 2. This optimization is mainly useful for S3, where file metadata operations can be pretty
+ * slow. And basically locality is not available when using S3 (you can't run computation on
+ * S3 nodes).
+ */
+ def mergeSchemasInParallel(
+ parameters: Map[String, String],
+ filesToTouch: Seq[FileStatus],
+ sparkSession: SparkSession): Option[StructType] = {
+ val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString
+ val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp
+
+ val reader = (files: Seq[FileStatus], conf: Configuration, ignoreCorruptFiles: Boolean) => {
+ // Converter used to convert Parquet `MessageType` to Spark SQL `StructType`
+ val converter = new GeoParquetToSparkSchemaConverter(
+ assumeBinaryIsString = assumeBinaryIsString,
+ assumeInt96IsTimestamp = assumeInt96IsTimestamp)
+ readParquetFootersInParallel(conf, files, ignoreCorruptFiles)
+ .map(ParquetFileFormat.readSchemaFromFooter(_, converter))
+ }
+
+ GeoSchemaMergeUtils.mergeSchemasInParallel(sparkSession, parameters, filesToTouch, reader)
+ }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala
new file mode 100644
index 00000000..e8964a09
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetFilters.scala
@@ -0,0 +1,640 @@
+/**
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
+import java.math.{BigDecimal => JBigDecimal}
+import java.sql.{Date, Timestamp}
+import java.time.{Instant, LocalDate}
+import java.util.Locale
+
+import scala.collection.JavaConverters.asScalaBufferConverter
+
+import org.apache.parquet.filter2.predicate._
+import org.apache.parquet.filter2.predicate.SparkFilterApi._
+import org.apache.parquet.io.api.Binary
+import org.apache.parquet.schema.{DecimalMetadata, GroupType, MessageType, OriginalType, PrimitiveComparator, PrimitiveType, Type}
+import org.apache.parquet.schema.OriginalType._
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
+
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.sources
+import org.apache.spark.unsafe.types.UTF8String
+
+// Needed by Sedona to support Spark 3.0 - 3.3
+/**
+ * Some utility function to convert Spark data source filters to Parquet filters.
+ */
+class GeoParquetFilters(
+ schema: MessageType,
+ pushDownDate: Boolean,
+ pushDownTimestamp: Boolean,
+ pushDownDecimal: Boolean,
+ pushDownStartWith: Boolean,
+ pushDownInFilterThreshold: Int,
+ caseSensitive: Boolean) {
+ // A map which contains parquet field name and data type, if predicate push down applies.
+ //
+ // Each key in `nameToParquetField` represents a column; `dots` are used as separators for
+ // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
+ // See `org.apache.spark.sql.connector.catalog.quote` for implementation details.
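+ // For example (assuming Spark's multipart-identifier quoting), a primitive field reached
+ // through parent "a" with name "b.c" is keyed as a.`b.c`.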
+ private val nameToParquetField : Map[String, ParquetPrimitiveField] = {
+ // Recursively traverse the parquet schema to get primitive fields that can be pushed-down.
+ // `parentFieldNames` is used to keep track of the current nested level when traversing.
+ def getPrimitiveFields(
+ fields: Seq[Type],
+ parentFieldNames: Array[String] = Array.empty): Seq[ParquetPrimitiveField] = {
+ fields.flatMap {
+ case p: PrimitiveType =>
+ Some(ParquetPrimitiveField(fieldNames = parentFieldNames :+ p.getName,
+ fieldType = ParquetSchemaType(p.getOriginalType,
+ p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata)))
+ // Note that when g is a `Struct`, `g.getOriginalType` is `null`.
+ // When g is a `Map`, `g.getOriginalType` is `MAP`.
+ // When g is a `List`, `g.getOriginalType` is `LIST`.
+ case g: GroupType if g.getOriginalType == null =>
+ getPrimitiveFields(g.getFields.asScala.toSeq, parentFieldNames :+ g.getName)
+ // Parquet only supports push-down for primitive types; as a result, Map and List types
+ // are removed.
+ case _ => None
+ }
+ }
+
+ val primitiveFields = getPrimitiveFields(schema.getFields.asScala.toSeq).map { field =>
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+ (field.fieldNames.toSeq.quoted, field)
+ }
+ if (caseSensitive) {
+ primitiveFields.toMap
+ } else {
+ // Don't consider ambiguity here, i.e. more than one field is matched in case insensitive
+ // mode, just skip pushdown for these fields, they will trigger Exception when reading,
+ // See: SPARK-25132.
+ val dedupPrimitiveFields =
+ primitiveFields
+ .groupBy(_._1.toLowerCase(Locale.ROOT))
+ .filter(_._2.size == 1)
+ .mapValues(_.head._2)
+ CaseInsensitiveMap(dedupPrimitiveFields.toMap)
+ }
+ }
+
+ /**
+ * Holds a single primitive field information stored in the underlying parquet file.
+ *
+ * @param fieldNames the field name as a multi-part identifier (an array of name segments) in the parquet file
+ * @param fieldType field type related info in parquet file
+ */
+ private case class ParquetPrimitiveField(
+ fieldNames: Array[String],
+ fieldType: ParquetSchemaType)
+
+ private case class ParquetSchemaType(
+ originalType: OriginalType,
+ primitiveTypeName: PrimitiveTypeName,
+ length: Int,
+ decimalMetadata: DecimalMetadata)
+
+ private val ParquetBooleanType = ParquetSchemaType(null, BOOLEAN, 0, null)
+ private val ParquetByteType = ParquetSchemaType(INT_8, INT32, 0, null)
+ private val ParquetShortType = ParquetSchemaType(INT_16, INT32, 0, null)
+ private val ParquetIntegerType = ParquetSchemaType(null, INT32, 0, null)
+ private val ParquetLongType = ParquetSchemaType(null, INT64, 0, null)
+ private val ParquetFloatType = ParquetSchemaType(null, FLOAT, 0, null)
+ private val ParquetDoubleType = ParquetSchemaType(null, DOUBLE, 0, null)
+ private val ParquetStringType = ParquetSchemaType(UTF8, BINARY, 0, null)
+ private val ParquetBinaryType = ParquetSchemaType(null, BINARY, 0, null)
+ private val ParquetDateType = ParquetSchemaType(DATE, INT32, 0, null)
+ private val ParquetTimestampMicrosType = ParquetSchemaType(TIMESTAMP_MICROS, INT64, 0, null)
+ private val ParquetTimestampMillisType = ParquetSchemaType(TIMESTAMP_MILLIS, INT64, 0, null)
+
+ private def dateToDays(date: Any): Int = date match {
+ case d: Date => DateTimeUtils.fromJavaDate(d)
+ case ld: LocalDate => DateTimeUtils.localDateToDays(ld)
+ }
+
+ private def timestampToMicros(v: Any): JLong = v match {
+ case i: Instant => DateTimeUtils.instantToMicros(i)
+ case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t)
+ }
+
+ private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue()
+
+ private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue()
+
+ private def decimalToByteArray(decimal: JBigDecimal, numBytes: Int): Binary = {
+ val decimalBuffer = new Array[Byte](numBytes)
+ val bytes = decimal.unscaledValue().toByteArray
+
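+ // Worked example (assumed values): unscaled value 300 encodes as bytes [0x01, 0x2C]; with
+ // numBytes = 4 the buffer is sign-extended on the left to [0x00, 0x00, 0x01, 0x2C].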
+ val fixedLengthBytes = if (bytes.length == numBytes) {
+ bytes
+ } else {
+ val signByte = if (bytes.head < 0) -1: Byte else 0: Byte
+ java.util.Arrays.fill(decimalBuffer, 0, numBytes - bytes.length, signByte)
+ System.arraycopy(bytes, 0, decimalBuffer, numBytes - bytes.length, bytes.length)
+ decimalBuffer
+ }
+ Binary.fromConstantByteArray(fixedLengthBytes, 0, numBytes)
+ }
+
+ private def timestampToMillis(v: Any): JLong = {
+ val micros = timestampToMicros(v)
+ val millis = GeoDateTimeUtils.microsToMillis(micros)
+ millis.asInstanceOf[JLong]
+ }
+
+ private val makeEq:
+ PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetBooleanType =>
+ (n: Array[String], v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[JBoolean])
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) => FilterApi.eq(
+ intColumn(n),
+ Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull)
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ // Binary.fromString and Binary.fromByteArray don't accept null values
+ case ParquetStringType =>
+ (n: Array[String], v: Any) => FilterApi.eq(
+ binaryColumn(n),
+ Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) => FilterApi.eq(
+ binaryColumn(n),
+ Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull)
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) => FilterApi.eq(
+ intColumn(n),
+ Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull)
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.eq(
+ longColumn(n),
+ Option(v).map(timestampToMicros).orNull)
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.eq(
+ longColumn(n),
+ Option(v).map(timestampToMillis).orNull)
+
+ case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) => FilterApi.eq(
+ intColumn(n),
+ Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull)
+ case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) => FilterApi.eq(
+ longColumn(n),
+ Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull)
+ case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) => FilterApi.eq(
+ binaryColumn(n),
+ Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull)
+ }
+
+ private val makeNotEq:
+ PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetBooleanType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[JBoolean])
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(
+ intColumn(n),
+ Option(v).map(_.asInstanceOf[Number].intValue.asInstanceOf[Integer]).orNull)
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ case ParquetStringType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(
+ binaryColumn(n),
+ Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull)
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) => FilterApi.notEq(
+ binaryColumn(n),
+ Option(v).map(b => Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]])).orNull)
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) => FilterApi.notEq(
+ intColumn(n),
+ Option(v).map(date => dateToDays(date).asInstanceOf[Integer]).orNull)
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.notEq(
+ longColumn(n),
+ Option(v).map(timestampToMicros).orNull)
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.notEq(
+ longColumn(n),
+ Option(v).map(timestampToMillis).orNull)
+
+ case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) => FilterApi.notEq(
+ intColumn(n),
+ Option(v).map(d => decimalToInt32(d.asInstanceOf[JBigDecimal])).orNull)
+ case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) => FilterApi.notEq(
+ longColumn(n),
+ Option(v).map(d => decimalToInt64(d.asInstanceOf[JBigDecimal])).orNull)
+ case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) => FilterApi.notEq(
+ binaryColumn(n),
+ Option(v).map(d => decimalToByteArray(d.asInstanceOf[JBigDecimal], length)).orNull)
+ }
+
+ private val makeLt:
+ PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ case ParquetStringType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer])
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v))
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v))
+
+ case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.lt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
+ }
+
+ private val makeLtEq:
+ PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ case ParquetStringType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer])
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v))
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v))
+
+ case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.ltEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
+ }
+
+ private val makeGt:
+ PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ case ParquetStringType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer])
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v))
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v))
+
+ case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gt(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
+ }
+
+ private val makeGtEq:
+ PartialFunction[ParquetSchemaType, (Array[String], Any) => FilterPredicate] = {
+ case ParquetByteType | ParquetShortType | ParquetIntegerType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(intColumn(n), v.asInstanceOf[Number].intValue.asInstanceOf[Integer])
+ case ParquetLongType =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[JLong])
+ case ParquetFloatType =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[JFloat])
+ case ParquetDoubleType =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[JDouble])
+
+ case ParquetStringType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
+ case ParquetBinaryType =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(binaryColumn(n), Binary.fromReusedByteArray(v.asInstanceOf[Array[Byte]]))
+ case ParquetDateType if pushDownDate =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer])
+ case ParquetTimestampMicrosType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v))
+ case ParquetTimestampMillisType if pushDownTimestamp =>
+ (n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v))
+
+ case ParquetSchemaType(DECIMAL, INT32, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(intColumn(n), decimalToInt32(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(DECIMAL, INT64, _, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(longColumn(n), decimalToInt64(v.asInstanceOf[JBigDecimal]))
+ case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, length, _) if pushDownDecimal =>
+ (n: Array[String], v: Any) =>
+ FilterApi.gtEq(binaryColumn(n), decimalToByteArray(v.asInstanceOf[JBigDecimal], length))
+ }
+
+ // Returns filters that can be pushed down when reading Parquet files.
+ def convertibleFilters(filters: Seq[sources.Filter]): Seq[sources.Filter] = {
+ filters.flatMap(convertibleFiltersHelper(_, canPartialPushDown = true))
+ }
+
+ private def convertibleFiltersHelper(
+ predicate: sources.Filter,
+ canPartialPushDown: Boolean): Option[sources.Filter] = {
+ predicate match {
+ case sources.And(left, right) =>
+ val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown)
+ val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown)
+ (leftResultOptional, rightResultOptional) match {
+ case (Some(leftResult), Some(rightResult)) => Some(sources.And(leftResult, rightResult))
+ case (Some(leftResult), None) if canPartialPushDown => Some(leftResult)
+ case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult)
+ case _ => None
+ }
+
+ case sources.Or(left, right) =>
+ val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown)
+ val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown)
+ if (leftResultOptional.isEmpty || rightResultOptional.isEmpty) {
+ None
+ } else {
+ Some(sources.Or(leftResultOptional.get, rightResultOptional.get))
+ }
+ case sources.Not(pred) =>
+ val resultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false)
+ resultOptional.map(sources.Not)
+
+ case other =>
+ if (createFilter(other).isDefined) {
+ Some(other)
+ } else {
+ None
+ }
+ }
+ }
+
+ /**
+ * Converts data sources filters to Parquet filter predicates.
+ */
+ def createFilter(predicate: sources.Filter): Option[FilterPredicate] = {
+ createFilterHelper(predicate, canPartialPushDownConjuncts = true)
+ }
+
+ // Parquet's type in the given file should be matched to the value's type
+ // in the pushed filter in order to push down the filter to Parquet.
+ private def valueCanMakeFilterOn(name: String, value: Any): Boolean = {
+ value == null || (nameToParquetField(name).fieldType match {
+ case ParquetBooleanType => value.isInstanceOf[JBoolean]
+ case ParquetByteType | ParquetShortType | ParquetIntegerType => value.isInstanceOf[Number]
+ case ParquetLongType => value.isInstanceOf[JLong]
+ case ParquetFloatType => value.isInstanceOf[JFloat]
+ case ParquetDoubleType => value.isInstanceOf[JDouble]
+ case ParquetStringType => value.isInstanceOf[String]
+ case ParquetBinaryType => value.isInstanceOf[Array[Byte]]
+ case ParquetDateType =>
+ value.isInstanceOf[Date] || value.isInstanceOf[LocalDate]
+ case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
+ value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant]
+ case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) =>
+ isDecimalMatched(value, decimalMeta)
+ case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) =>
+ isDecimalMatched(value, decimalMeta)
+ case ParquetSchemaType(DECIMAL, FIXED_LEN_BYTE_ARRAY, _, decimalMeta) =>
+ isDecimalMatched(value, decimalMeta)
+ case _ => false
+ })
+ }
+
+ // For decimal types, the filter value's scale must match the scale in the file;
+ // a mismatch would cause data corruption.
+ private def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match {
+ case decimal: JBigDecimal =>
+ decimal.scale == decimalMeta.getScale
+ case _ => false
+ }
+
+ private def canMakeFilterOn(name: String, value: Any): Boolean = {
+ nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
+ }
+
+ /**
+ * @param predicate the input filter predicates. Not all the predicates can be pushed down.
+ * @param canPartialPushDownConjuncts whether a subset of conjuncts of predicates can be pushed
+ * down safely. Pushing ONLY one side of an AND down is safe to
+ * do at the top level, or when none of its ancestors is a NOT or an OR.
+ * @return the Parquet-native filter predicates that are eligible for pushdown.
+ */
+ private def createFilterHelper(
+ predicate: sources.Filter,
+ canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = {
+ // NOTE:
+ //
+ // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`,
+ // which can be casted to `false` implicitly. Please refer to the `eval` method of these
+ // operators and the `PruneFilters` rule for details.
+
+ // Hyukjin:
+ // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]].
+ // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]].
+ // The reason why I did this is that the actual Parquet filter checks null-safe equality
+ // comparison.
+ // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because
+ // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc.
+ // Probably I missed something and obviously this should be changed.
+
+ predicate match {
+ case sources.IsNull(name) if canMakeFilterOn(name, null) =>
+ makeEq.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, null))
+ case sources.IsNotNull(name) if canMakeFilterOn(name, null) =>
+ makeNotEq.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, null))
+
+ case sources.EqualTo(name, value) if canMakeFilterOn(name, value) =>
+ makeEq.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+ case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name, value) =>
+ makeNotEq.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+
+ case sources.EqualNullSafe(name, value) if canMakeFilterOn(name, value) =>
+ makeEq.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+ case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name, value) =>
+ makeNotEq.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+
+ case sources.LessThan(name, value) if canMakeFilterOn(name, value) =>
+ makeLt.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+ case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
+ makeLtEq.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+
+ case sources.GreaterThan(name, value) if canMakeFilterOn(name, value) =>
+ makeGt.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+ case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name, value) =>
+ makeGtEq.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, value))
+
+ case sources.And(lhs, rhs) =>
+ // At here, it is not safe to just convert one side and remove the other side
+ // if we do not understand what the parent filters are.
+ //
+ // Here is an example used to explain the reason.
+ // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to
+ // convert b in ('1'). If we only convert a = 2, we will end up with a filter
+ // NOT(a = 2), which will generate wrong results.
+ //
+ // Pushing one side of AND down is only safe to do at the top level or in the child
+ // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate
+ // can be safely removed.
+ val lhsFilterOption =
+ createFilterHelper(lhs, canPartialPushDownConjuncts)
+ val rhsFilterOption =
+ createFilterHelper(rhs, canPartialPushDownConjuncts)
+
+ (lhsFilterOption, rhsFilterOption) match {
+ case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter))
+ case (Some(lhsFilter), None) if canPartialPushDownConjuncts => Some(lhsFilter)
+ case (None, Some(rhsFilter)) if canPartialPushDownConjuncts => Some(rhsFilter)
+ case _ => None
+ }
+
+ case sources.Or(lhs, rhs) =>
+ // The Or predicate is convertible when both of its children can be pushed down.
+ // That is to say, if one/both of the children can be partially pushed down, the Or
+ // predicate can be partially pushed down as well.
+ //
+ // Here is an example used to explain the reason.
+ // Let's say we have
+ // (a1 AND a2) OR (b1 AND b2),
+ // a1 and b1 is convertible, while a2 and b2 is not.
+ // The predicate can be converted as
+ // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2)
+ // As per the logic in the And predicate, we can push down (a1 OR b1).
+ for {
+ lhsFilter <- createFilterHelper(lhs, canPartialPushDownConjuncts)
+ rhsFilter <- createFilterHelper(rhs, canPartialPushDownConjuncts)
+ } yield FilterApi.or(lhsFilter, rhsFilter)
+
+ case sources.Not(pred) =>
+ createFilterHelper(pred, canPartialPushDownConjuncts = false)
+ .map(FilterApi.not)
+
+ case sources.In(name, values) if canMakeFilterOn(name, values.head)
+ && values.distinct.length <= pushDownInFilterThreshold =>
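+ // e.g. a hypothetical In(col, Array(1, 2)) is rewritten as eq(col, 1) || eq(col, 2)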
+ values.distinct.flatMap { v =>
+ makeEq.lift(nameToParquetField(name).fieldType)
+ .map(_(nameToParquetField(name).fieldNames, v))
+ }.reduceLeftOption(FilterApi.or)
+
+ case sources.StringStartsWith(name, prefix)
+ if pushDownStartWith && canMakeFilterOn(name, prefix) =>
+ Option(prefix).map { v =>
+ FilterApi.userDefined(binaryColumn(nameToParquetField(name).fieldNames),
+ new UserDefinedPredicate[Binary] with Serializable {
+ private val strToBinary = Binary.fromReusedByteArray(v.getBytes)
+ private val size = strToBinary.length
+
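+ // canDrop: over the first `size` bytes, a row group can be skipped when the column max is
+ // below the prefix or the column min is above it, since no value in between can start with it.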
+ override def canDrop(statistics: Statistics[Binary]): Boolean = {
+ val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR
+ val max = statistics.getMax
+ val min = statistics.getMin
+ comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) < 0 ||
+ comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) > 0
+ }
+
+ override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = {
+ val comparator = PrimitiveComparator.UNSIGNED_LEXICOGRAPHICAL_BINARY_COMPARATOR
+ val max = statistics.getMax
+ val min = statistics.getMin
+ comparator.compare(max.slice(0, math.min(size, max.length)), strToBinary) == 0 &&
+ comparator.compare(min.slice(0, math.min(size, min.length)), strToBinary) == 0
+ }
+
+ override def keep(value: Binary): Boolean = {
+ value != null && UTF8String.fromBytes(value.getBytes).startsWith(
+ UTF8String.fromBytes(strToBinary.getBytes))
+ }
+ }
+ )
+ }
+
+ case _ => None
+ }
+ }
+}
\ No newline at end of file
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetOptions.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetOptions.scala
new file mode 100644
index 00000000..9305023e
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetOptions.scala
@@ -0,0 +1,27 @@
+/**
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+
+class GeoParquetOptions(@transient private val parameters: CaseInsensitiveMap[String]) extends Serializable {
+ def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
+ /**
+ * geometry field name. Default: geometry
+ */
+
+ val fieldGeometry = parameters.getOrElse("fieldGeometry", "geometry")
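+ // e.g. a hypothetical read such as
+ // spark.read.format("geoparquet").option("fieldGeometry", "geom") resolves the geometry
+ // from a column named "geom" instead of the default "geometry".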
+
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala
new file mode 100644
index 00000000..8b8f1f35
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetReadSupport.scala
@@ -0,0 +1,392 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.parquet.hadoop.api.ReadSupport.ReadContext
+import org.apache.parquet.hadoop.api.{InitContext, ReadSupport}
+import org.apache.parquet.io.api.RecordMaterializer
+import org.apache.parquet.schema.Type.Repetition
+import org.apache.parquet.schema._
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.types._
+
+import java.time.ZoneId
+import java.util.{Locale, Map => JMap}
+import scala.collection.JavaConverters._
+
+/**
+ * A Parquet [[ReadSupport]] implementation for reading Parquet records as Catalyst
+ * [[InternalRow]]s.
+ *
+ * The API interface of [[ReadSupport]] is a little over-complicated for historical
+ * reasons. In older versions of parquet-mr (say 1.6.0rc3 and prior), [[ReadSupport]] needed to be
+ * instantiated and initialized twice, on both the driver side and the executor side. The [[init()]] method
+ * is for driver side initialization, while [[prepareForRead()]] is for executor side. However,
+ * starting from parquet-mr 1.6.0, it's no longer the case, and [[ReadSupport]] is only instantiated
+ * and initialized on executor side. So, theoretically, now it's totally fine to combine these two
+ * methods into a single initialization method. The only reason (I could think of) to still have
+ * them here is for parquet-mr API backwards-compatibility.
+ *
+ * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from [[init()]]
+ * to [[prepareForRead()]], but use a private `var` for simplicity.
+ */
+class GeoParquetReadSupport(
+ override val convertTz: Option[ZoneId],
+ enableVectorizedReader: Boolean,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ int96RebaseMode: LegacyBehaviorPolicy.Value)
+ extends ParquetReadSupport with Logging {
+ private var catalystRequestedSchema: StructType = _
+
+ /**
+ * Called on executor side before [[prepareForRead()]] and instantiating actual Parquet record
+ * readers. Responsible for figuring out Parquet requested schema used for column pruning.
+ */
+ override def init(context: InitContext): ReadContext = {
+ val conf = context.getConfiguration
+ catalystRequestedSchema = {
+ val schemaString = conf.get(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA)
+ assert(schemaString != null, "Parquet requested schema not set.")
+ StructType.fromString(schemaString)
+ }
+
+ val caseSensitive = conf.getBoolean(SQLConf.CASE_SENSITIVE.key,
+ SQLConf.CASE_SENSITIVE.defaultValue.get)
+ val schemaPruningEnabled = conf.getBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
+ SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get)
+ val parquetFileSchema = context.getFileSchema
+ val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema,
+ catalystRequestedSchema, caseSensitive)
+
+ // We pass two schemas to ParquetRecordMaterializer:
+ // - parquetRequestedSchema: the schema of the file data we want to read
+ // - catalystRequestedSchema: the schema of the rows we want to return
+ // The reader is responsible for reconciling the differences between the two.
+ val parquetRequestedSchema = if (schemaPruningEnabled && !enableVectorizedReader) {
+ // Parquet-MR reader requires that parquetRequestedSchema include only those fields present
+ // in the underlying parquetFileSchema. Therefore, we intersect the parquetClippedSchema
+ // with the parquetFileSchema
+ GeoParquetReadSupport.intersectParquetGroups(parquetClippedSchema, parquetFileSchema)
+ .map(groupType => new MessageType(groupType.getName, groupType.getFields))
+ .getOrElse(ParquetSchemaConverter.EMPTY_MESSAGE)
+ } else {
+ // Spark's vectorized reader currently only supports atomic types. It also skips fields
+ // in parquetRequestedSchema which are not present in the file.
+ parquetClippedSchema
+ }
+ logDebug(
+ s"""Going to read the following fields from the Parquet file with the following schema:
+ |Parquet file schema:
+ |$parquetFileSchema
+ |Parquet clipped schema:
+ |$parquetClippedSchema
+ |Parquet requested schema:
+ |$parquetRequestedSchema
+ |Catalyst requested schema:
+ |${catalystRequestedSchema.treeString}
+ """.stripMargin)
+ new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava)
+ }
+
+ /**
+ * Called on executor side after [[init()]], before instantiating actual Parquet record readers.
+ * Responsible for instantiating [[RecordMaterializer]], which is used for converting Parquet
+ * records to Catalyst [[InternalRow]]s.
+ */
+ override def prepareForRead(
+ conf: Configuration,
+ keyValueMetaData: JMap[String, String],
+ fileSchema: MessageType,
+ readContext: ReadContext): RecordMaterializer[InternalRow] = {
+ val parquetRequestedSchema = readContext.getRequestedSchema
+ new GeoParquetRecordMaterializer(
+ parquetRequestedSchema,
+ ParquetReadSupport.expandUDT(catalystRequestedSchema),
+ new GeoParquetToSparkSchemaConverter(conf),
+ convertTz,
+ datetimeRebaseMode,
+ int96RebaseMode)
+ }
+}
+
+object GeoParquetReadSupport extends Logging {
+
+ /**
+ * Tailors `parquetSchema` according to `catalystSchema` by removing column paths that don't
+ * exist in `catalystSchema`, and adding those that only exist in `catalystSchema`.
+ */
+ def clipParquetSchema(
+ parquetSchema: MessageType,
+ catalystSchema: StructType,
+ caseSensitive: Boolean = true): MessageType = {
+ val clippedParquetFields = clipParquetGroupFields(
+ parquetSchema.asGroupType(), catalystSchema, caseSensitive)
+ if (clippedParquetFields.isEmpty) {
+ ParquetSchemaConverter.EMPTY_MESSAGE
+ } else {
+ Types
+ .buildMessage()
+ .addFields(clippedParquetFields: _*)
+ .named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME)
+ }
+ }
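+
+ // Illustrative example (a sketch): given a Parquet schema {a: int32, b: binary} and a Catalyst
+ // schema {a: IntegerType, c: StringType}, the clipped schema keeps `a`, drops `b`, and appends
+ // a Parquet field converted from `c`, i.e. {a: int32, c: binary (UTF8)}.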
+
+ private def clipParquetType(
+ parquetType: Type, catalystType: DataType, caseSensitive: Boolean): Type = {
+ catalystType match {
+ case t: ArrayType if !isPrimitiveCatalystType(t.elementType) =>
+ // Only clips array types with nested type as element type.
+ clipParquetListType(parquetType.asGroupType(), t.elementType, caseSensitive)
+
+ case t: MapType
+ if !isPrimitiveCatalystType(t.keyType) ||
+ !isPrimitiveCatalystType(t.valueType) =>
+ // Only clips map types with nested key type or value type
+ clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive)
+
+ case t: StructType =>
+ clipParquetGroup(parquetType.asGroupType(), t, caseSensitive)
+
+ case _ =>
+ // UDTs and primitive types are not clipped. For UDTs, a clipped version might no longer map
+ // to the desired user-space types, so UDTs shouldn't participate in schema merging.
+ parquetType
+ }
+ }
+
+ /**
+ * Whether a Catalyst [[DataType]] is primitive. Primitive [[DataType]] is not equivalent to
+ * [[AtomicType]]. For example, [[CalendarIntervalType]] is primitive, but it's not an
+ * [[AtomicType]].
+ */
+ private def isPrimitiveCatalystType(dataType: DataType): Boolean = {
+ dataType match {
+ case _: ArrayType | _: MapType | _: StructType => false
+ case _ => true
+ }
+ }
+
+ /**
+ * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[ArrayType]]. The element type
+ * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a
+ * [[StructType]].
+ */
+ private def clipParquetListType(
+ parquetList: GroupType, elementType: DataType, caseSensitive: Boolean): Type = {
+ // Precondition of this method, should only be called for lists with nested element types.
+ assert(!isPrimitiveCatalystType(elementType))
+
+ // Unannotated repeated group should be interpreted as required list of required element, so
+ // list element type is just the group itself. Clip it.
+ if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) {
+ clipParquetType(parquetList, elementType, caseSensitive)
+ } else {
+ assert(
+ parquetList.getOriginalType == OriginalType.LIST,
+ "Invalid Parquet schema. " +
+ "Original type of annotated Parquet lists must be LIST: " +
+ parquetList.toString)
+
+ assert(
+ parquetList.getFieldCount == 1 && parquetList.getType(0).isRepetition(Repetition.REPEATED),
+ "Invalid Parquet schema. " +
+ "LIST-annotated group should only have exactly one repeated field: " +
+ parquetList)
+
+ // Precondition of this method, should only be called for lists with nested element types.
+ assert(!parquetList.getType(0).isPrimitive)
+
+ val repeatedGroup = parquetList.getType(0).asGroupType()
+
+ // If the repeated field is a group with multiple fields, or if the repeated field is a group
+ // with one field that is named either "array" or the LIST-annotated group's name with
+ // "_tuple" appended, then the repeated type is the element type and elements are required.
+ // Build a new LIST-annotated group with clipped `repeatedGroup` as element type and the
+ // only field.
+ if (
+ repeatedGroup.getFieldCount > 1 ||
+ repeatedGroup.getName == "array" ||
+ repeatedGroup.getName == parquetList.getName + "_tuple"
+ ) {
+ Types
+ .buildGroup(parquetList.getRepetition)
+ .as(OriginalType.LIST)
+ .addField(clipParquetType(repeatedGroup, elementType, caseSensitive))
+ .named(parquetList.getName)
+ } else {
+ // Otherwise, the repeated field's type is the element type with the repeated field's
+ // repetition.
+ Types
+ .buildGroup(parquetList.getRepetition)
+ .as(OriginalType.LIST)
+ .addField(
+ Types
+ .repeatedGroup()
+ .addField(clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive))
+ .named(repeatedGroup.getName))
+ .named(parquetList.getName)
+ }
+ }
+ }
+
+ /**
+ * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. Either key type or
+ * value type of the [[MapType]] must be a nested type, namely an [[ArrayType]], a [[MapType]], or
+ * a [[StructType]].
+ */
+ private def clipParquetMapType(
+ parquetMap: GroupType,
+ keyType: DataType,
+ valueType: DataType,
+ caseSensitive: Boolean): GroupType = {
+ // Precondition of this method, only handles maps with nested key types or value types.
+ assert(!isPrimitiveCatalystType(keyType) || !isPrimitiveCatalystType(valueType))
+
+ val repeatedGroup = parquetMap.getType(0).asGroupType()
+ val parquetKeyType = repeatedGroup.getType(0)
+ val parquetValueType = repeatedGroup.getType(1)
+
+ val clippedRepeatedGroup =
+ Types
+ .repeatedGroup()
+ .as(repeatedGroup.getOriginalType)
+ .addField(clipParquetType(parquetKeyType, keyType, caseSensitive))
+ .addField(clipParquetType(parquetValueType, valueType, caseSensitive))
+ .named(repeatedGroup.getName)
+
+ Types
+ .buildGroup(parquetMap.getRepetition)
+ .as(parquetMap.getOriginalType)
+ .addField(clippedRepeatedGroup)
+ .named(parquetMap.getName)
+ }
+
+ /**
+ * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]].
+ *
+ * @return A clipped [[GroupType]], which has at least one field.
+ * @note Parquet doesn't allow creating empty [[GroupType]] instances except for the empty
+ * [[MessageType]], even though it's legal to construct an empty requested schema for
+ * column pruning.
+ */
+ private def clipParquetGroup(
+ parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): GroupType = {
+ val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType, caseSensitive)
+ Types
+ .buildGroup(parquetRecord.getRepetition)
+ .as(parquetRecord.getOriginalType)
+ .addFields(clippedParquetFields: _*)
+ .named(parquetRecord.getName)
+ }
+
+ /**
+ * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]].
+ *
+ * @return A list of clipped [[GroupType]] fields, which can be empty.
+ */
+ private def clipParquetGroupFields(
+ parquetRecord: GroupType, structType: StructType, caseSensitive: Boolean): Seq[Type] = {
+ val toParquet = new SparkToGeoParquetSchemaConverter(writeLegacyParquetFormat = false)
+ if (caseSensitive) {
+ val caseSensitiveParquetFieldMap =
+ parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap
+ structType.map { f =>
+ caseSensitiveParquetFieldMap
+ .get(f.name)
+ .map(clipParquetType(_, f.dataType, caseSensitive))
+ .getOrElse(toParquet.convertField(f))
+ }
+ } else {
+ // Do case-insensitive resolution only if in case-insensitive mode
+ val caseInsensitiveParquetFieldMap =
+ parquetRecord.getFields.asScala.groupBy(_.getName.toLowerCase(Locale.ROOT))
+ structType.map { f =>
+ caseInsensitiveParquetFieldMap
+ .get(f.name.toLowerCase(Locale.ROOT))
+ .map { parquetTypes =>
+ if (parquetTypes.size > 1) {
+ // Need to fail if there is ambiguity, i.e. more than one field is matched
+ val parquetTypesString = parquetTypes.map(_.getName).mkString("[", ", ", "]")
+ throw new RuntimeException(s"""Found duplicate field(s) "${f.name}": """ +
+ s"$parquetTypesString in case-insensitive mode")
+ } else {
+ clipParquetType(parquetTypes.head, f.dataType, caseSensitive)
+ }
+ }.getOrElse(toParquet.convertField(f))
+ }
+ }
+ }
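+
+ // Example: in case-insensitive mode, a Parquet group containing fields `A` and `a` matched
+ // against a Catalyst field named "a" is ambiguous, so the branch above fails with the
+ // duplicate-field error.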
+
+ /**
+ * Computes the structural intersection between two Parquet group types.
+ * This is used to create a requested schema for the ReadContext of the Parquet-MR reader.
+ * The Parquet-MR reader does not support nested field access to a non-existent field,
+ * while the parquet library does support reading a non-existent field via regular field access.
+ */
+ private def intersectParquetGroups(
+ groupType1: GroupType, groupType2: GroupType): Option[GroupType] = {
+ val fields =
+ groupType1.getFields.asScala
+ .filter(field => groupType2.containsField(field.getName))
+ .flatMap {
+ case field1: GroupType =>
+ val field2 = groupType2.getType(field1.getName)
+ if (field2.isPrimitive) {
+ None
+ } else {
+ intersectParquetGroups(field1, field2.asGroupType)
+ }
+ case field1 => Some(field1)
+ }
+
+ if (fields.nonEmpty) {
+ Some(groupType1.withNewFields(fields.asJava))
+ } else {
+ None
+ }
+ }
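+
+ // Illustrative example: intersecting {a, b {c, d}} with {a, b {c}} yields {a, b {c}}; a group
+ // field whose counterpart is primitive (or absent) is dropped from the intersection.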
+
+ def expandUDT(schema: StructType): StructType = {
+ def expand(dataType: DataType): DataType = {
+ dataType match {
+ case t: ArrayType =>
+ t.copy(elementType = expand(t.elementType))
+
+ case t: MapType =>
+ t.copy(
+ keyType = expand(t.keyType),
+ valueType = expand(t.valueType))
+
+ case t: StructType =>
+ val expandedFields = t.fields.map(f => f.copy(dataType = expand(f.dataType)))
+ t.copy(fields = expandedFields)
+
+ case t: UserDefinedType[_] =>
+ t.sqlType
+
+ case t =>
+ t
+ }
+ }
+
+ expand(schema).asInstanceOf[StructType]
+ }
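+
+ // Example: a field typed as a UDT (such as Sedona's GeometryUDT) is replaced by its underlying
+ // sqlType, so downstream row converters only ever see plain Catalyst types.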
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala
new file mode 100644
index 00000000..575ef7ee
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRecordMaterializer.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.time.ZoneId
+import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer}
+import org.apache.parquet.schema.MessageType
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A [[RecordMaterializer]] for Catalyst rows.
+ *
+ * @param parquetSchema Parquet schema of the records to be read
+ * @param catalystSchema Catalyst schema of the rows to be constructed
+ * @param schemaConverter A Parquet-Catalyst schema converter that helps initializing row converters
+ * @param convertTz the optional time zone used to convert INT96 data
+ * @param datetimeRebaseMode the mode of rebasing date/timestamp from the Julian to the
+ *                           Proleptic Gregorian calendar
+ * @param int96RebaseMode the mode of rebasing INT96 timestamps from the Julian to the
+ *                        Proleptic Gregorian calendar
+ */
+class GeoParquetRecordMaterializer(
+ parquetSchema: MessageType,
+ catalystSchema: StructType,
+ schemaConverter: GeoParquetToSparkSchemaConverter,
+ convertTz: Option[ZoneId],
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ int96RebaseMode: LegacyBehaviorPolicy.Value)
+ extends RecordMaterializer[InternalRow] {
+ private val rootConverter = new GeoParquetRowConverter(
+ schemaConverter,
+ parquetSchema,
+ catalystSchema,
+ convertTz,
+ datetimeRebaseMode,
+ int96RebaseMode,
+ NoopUpdater)
+
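+ // Per the RecordMaterializer contract, parquet-mr brackets each record with
+ // rootConverter.start()/end() and then calls getCurrentRecord for the materialized row.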
+ override def getCurrentRecord: InternalRow = rootConverter.currentRecord
+
+ override def getRootConverter: GroupConverter = rootConverter
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
new file mode 100644
index 00000000..47e2a9d7
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetRowConverter.scala
@@ -0,0 +1,675 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.parquet.column.Dictionary
+import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
+import org.apache.parquet.schema.OriginalType.LIST
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
+import org.apache.parquet.schema.{GroupType, OriginalType, Type}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+import java.math.{BigDecimal, BigInteger}
+import java.time.{ZoneId, ZoneOffset}
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * A [[GeoParquetRowConverter]] is used to convert Parquet records into Catalyst [[InternalRow]]s.
+ * Since Catalyst `StructType` is also a Parquet record, this converter can be used as root
+ * converter. Take the following Parquet type as an example:
+ * {{{
+ * message root {
+ * required int32 f1;
+ * optional group f2 {
+ * required double f21;
+ * optional binary f22 (utf8);
+ * }
+ * }
+ * }}}
+ * 5 converters will be created:
+ *
+ * - a root [[GeoParquetRowConverter]] for [[org.apache.parquet.schema.MessageType]] `root`,
+ * which contains:
+ * - a [[ParquetPrimitiveConverter]] for required
+ * [[org.apache.parquet.schema.OriginalType.INT_32]] field `f1`, and
+ * - a nested [[GeoParquetRowConverter]] for optional [[GroupType]] `f2`, which contains:
+ * - a [[ParquetPrimitiveConverter]] for required
+ * [[org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE]] field `f21`, and
+ * - a [[ParquetStringConverter]] for optional [[org.apache.parquet.schema.OriginalType.UTF8]]
+ * string field `f22`
+ *
+ * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have
+ * any "parent" container.
+ *
+ * @param schemaConverter A utility converter used to convert Parquet types to Catalyst types.
+ * @param parquetType Parquet schema of Parquet records
+ * @param catalystType Spark SQL schema that corresponds to the Parquet record type. User-defined
+ * types should have been expanded.
+ * @param convertTz the optional time zone used to convert INT96 data
+ * @param datetimeRebaseMode the mode of rebasing date/timestamp from Julian to Proleptic Gregorian
+ * calendar
+ * @param int96RebaseMode the mode of rebasing INT96 timestamp from Julian to Proleptic Gregorian
+ * calendar
+ * @param updater An updater which propagates converted field values to the parent container
+ */
+private[parquet] class GeoParquetRowConverter(
+ schemaConverter: GeoParquetToSparkSchemaConverter,
+ parquetType: GroupType,
+ catalystType: StructType,
+ convertTz: Option[ZoneId],
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ int96RebaseMode: LegacyBehaviorPolicy.Value,
+ updater: ParentContainerUpdater)
+ extends ParquetGroupConverter(updater) with Logging {
+
+ assert(
+ parquetType.getFieldCount <= catalystType.length,
+ s"""Field count of the Parquet schema is greater than the field count of the Catalyst schema:
+ |
+ |Parquet schema:
+ |$parquetType
+ |Catalyst schema:
+ |${catalystType.prettyJson}
+ """.stripMargin)
+
+ assert(
+ !catalystType.existsRecursively(_.isInstanceOf[UserDefinedType[_]]),
+ s"""User-defined types in Catalyst schema should have already been expanded:
+ |${catalystType.prettyJson}
+ """.stripMargin)
+
+ logDebug(
+ s"""Building row converter for the following schema:
+ |
+ |Parquet form:
+ |$parquetType
+ |Catalyst form:
+ |${catalystType.prettyJson}
+ """.stripMargin)
+
+ /**
+ * Updater used together with field converters within a [[GeoParquetRowConverter]]. It propagates
+ * converted field values to the `ordinal`-th cell in `currentRow`.
+ */
+ private final class RowUpdater(row: InternalRow, ordinal: Int) extends ParentContainerUpdater {
+ override def set(value: Any): Unit = row(ordinal) = value
+ override def setBoolean(value: Boolean): Unit = row.setBoolean(ordinal, value)
+ override def setByte(value: Byte): Unit = row.setByte(ordinal, value)
+ override def setShort(value: Short): Unit = row.setShort(ordinal, value)
+ override def setInt(value: Int): Unit = row.setInt(ordinal, value)
+ override def setLong(value: Long): Unit = row.setLong(ordinal, value)
+ override def setDouble(value: Double): Unit = row.setDouble(ordinal, value)
+ override def setFloat(value: Float): Unit = row.setFloat(ordinal, value)
+ }
+
+ private[this] val currentRow = new SpecificInternalRow(catalystType.map(_.dataType))
+
+ /**
+ * The [[InternalRow]] converted from an entire Parquet record.
+ */
+ def currentRecord: InternalRow = currentRow
+
+ private val dateRebaseFunc = GeoDataSourceUtils.creteDateRebaseFuncInRead(
+ datetimeRebaseMode, "Parquet")
+
+ private val timestampRebaseFunc = GeoDataSourceUtils.creteTimestampRebaseFuncInRead(
+ datetimeRebaseMode, "Parquet")
+
+ private val int96RebaseFunc = GeoDataSourceUtils.creteTimestampRebaseFuncInRead(
+ int96RebaseMode, "Parquet INT96")
+
+ // Converters for each field.
+ private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = {
+ // (SPARK-31116) Use case insensitive map if spark.sql.caseSensitive is false
+ // to prevent throwing IllegalArgumentException when searching catalyst type's field index
+ val catalystFieldNameToIndex = if (SQLConf.get.caseSensitiveAnalysis) {
+ catalystType.fieldNames.zipWithIndex.toMap
+ } else {
+ CaseInsensitiveMap(catalystType.fieldNames.zipWithIndex.toMap)
+ }
+ parquetType.getFields.asScala.map { parquetField =>
+ val fieldIndex = catalystFieldNameToIndex(parquetField.getName)
+ val catalystField = catalystType(fieldIndex)
+ // Converted field value should be set to the `fieldIndex`-th cell of `currentRow`
+ newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex))
+ }.toArray
+ }
+
+ // Updaters for each field.
+ private[this] val fieldUpdaters: Array[ParentContainerUpdater] = fieldConverters.map(_.updater)
+
+ override def getConverter(fieldIndex: Int): Converter = fieldConverters(fieldIndex)
+
+ override def end(): Unit = {
+ var i = 0
+ while (i < fieldUpdaters.length) {
+ fieldUpdaters(i).end()
+ i += 1
+ }
+ updater.set(currentRow)
+ }
+
+ override def start(): Unit = {
+ var i = 0
+ val numFields = currentRow.numFields
+ while (i < numFields) {
+ currentRow.setNullAt(i)
+ i += 1
+ }
+ i = 0
+ while (i < fieldUpdaters.length) {
+ fieldUpdaters(i).start()
+ i += 1
+ }
+ }
+
+ /**
+ * Creates a converter for the given Parquet type `parquetType` and Spark SQL data type
+ * `catalystType`. Converted values are handled by `updater`.
+ */
+ private def newConverter(
+ parquetType: Type,
+ catalystType: DataType,
+ updater: ParentContainerUpdater): Converter with HasParentContainerUpdater = {
+
+ catalystType match {
+ case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType =>
+ new ParquetPrimitiveConverter(updater)
+
+ case ByteType =>
+ new ParquetPrimitiveConverter(updater) {
+ override def addInt(value: Int): Unit =
+ updater.setByte(value.asInstanceOf[ByteType#InternalType])
+
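+ // Note: appending each byte individually likely serves columns whose Catalyst type is an
+ // expanded UDT backed by an array of bytes (e.g. a geometry serialized as WKB), where the
+ // parent updater collects the elements into an array.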
+ override def addBinary(value: Binary): Unit = {
+ val bytes = value.getBytes
+ for (b <- bytes) {
+ updater.set(b)
+ }
+ }
+ }
+
+ case ShortType =>
+ new ParquetPrimitiveConverter(updater) {
+ override def addInt(value: Int): Unit =
+ updater.setShort(value.asInstanceOf[ShortType#InternalType])
+ }
+
+ // For INT32 backed decimals
+ case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 =>
+ new ParquetIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+
+ // For INT64 backed decimals
+ case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 =>
+ new ParquetLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+
+ // For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals
+ case t: DecimalType
+ if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY ||
+ parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY =>
+ new ParquetBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+
+ case t: DecimalType =>
+ throw new RuntimeException(
+ s"Unable to create Parquet converter for decimal type ${t.json} whose Parquet type is " +
+ s"$parquetType. Parquet DECIMAL type can only be backed by INT32, INT64, " +
+ "FIXED_LEN_BYTE_ARRAY, or BINARY.")
+
+ case StringType =>
+ new ParquetStringConverter(updater)
+
+ case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS =>
+ new ParquetPrimitiveConverter(updater) {
+ override def addLong(value: Long): Unit = {
+ updater.setLong(timestampRebaseFunc(value))
+ }
+ }
+
+ case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MILLIS =>
+ new ParquetPrimitiveConverter(updater) {
+ override def addLong(value: Long): Unit = {
+ val micros = GeoDateTimeUtils.millisToMicros(value)
+ updater.setLong(timestampRebaseFunc(micros))
+ }
+ }
+
+ // INT96 timestamp doesn't have a logical type, here we check the physical type instead.
+ case TimestampType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT96 =>
+ new ParquetPrimitiveConverter(updater) {
+ // Converts nanosecond timestamps stored as INT96
+ override def addBinary(value: Binary): Unit = {
+ val julianMicros = ParquetRowConverter.binaryToSQLTimestamp(value)
+ val gregorianMicros = int96RebaseFunc(julianMicros)
+ val adjTime = convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC))
+ .getOrElse(gregorianMicros)
+ updater.setLong(adjTime)
+ }
+ }
+
+ case DateType =>
+ new ParquetPrimitiveConverter(updater) {
+ override def addInt(value: Int): Unit = {
+ updater.set(dateRebaseFunc(value))
+ }
+ }
+
+ // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor
+ // annotated by `LIST` or `MAP` should be interpreted as a required list of required
+ // elements where the element type is the type of the field.
+ case t: ArrayType if parquetType.getOriginalType != LIST =>
+ if (parquetType.isPrimitive) {
+ new RepeatedPrimitiveConverter(parquetType, t.elementType, updater)
+ } else {
+ new RepeatedGroupConverter(parquetType, t.elementType, updater)
+ }
+
+ case t: ArrayType =>
+ new ParquetArrayConverter(parquetType.asGroupType(), t, updater)
+
+ case t: MapType =>
+ new ParquetMapConverter(parquetType.asGroupType(), t, updater)
+
+ case t: StructType =>
+ val wrappedUpdater = {
+ // SPARK-30338: avoid unnecessary InternalRow copying for nested structs:
+ // There are two cases to handle here:
+ //
+ // 1. Parent container is a map or array: we must make a deep copy of the mutable row
+ // because this converter may be invoked multiple times per Parquet input record
+ // (if the map or array contains multiple elements).
+ //
+ // 2. Parent container is a struct: we don't need to copy the row here because either:
+ //
+ // (a) all ancestors are structs and therefore no copying is required because this
+ // converter will only be invoked once per Parquet input record, or
+ // (b) some ancestor is struct that is nested in a map or array and that ancestor's
+ // converter will perform deep-copying (which will recursively copy this row).
+ if (updater.isInstanceOf[RowUpdater]) {
+ // `updater` is a RowUpdater, implying that the parent container is a struct.
+ updater
+ } else {
+ // `updater` is NOT a RowUpdater, implying that the parent container is a map or array.
+ new ParentContainerUpdater {
+ override def set(value: Any): Unit = {
+ updater.set(value.asInstanceOf[SpecificInternalRow].copy()) // deep copy
+ }
+ }
+ }
+ }
+ new GeoParquetRowConverter(
+ schemaConverter,
+ parquetType.asGroupType(),
+ t,
+ convertTz,
+ datetimeRebaseMode,
+ int96RebaseMode,
+ wrappedUpdater)
+
+ case t =>
+ throw new RuntimeException(
+ s"Unable to create Parquet converter for data type ${t.json} " +
+ s"whose Parquet type is $parquetType")
+ }
+ }
+
+ /**
+ * Parquet converter for strings. A dictionary is used to minimize string decoding cost.
+ */
+ private final class ParquetStringConverter(updater: ParentContainerUpdater)
+ extends ParquetPrimitiveConverter(updater) {
+
+ private var expandedDictionary: Array[UTF8String] = null
+
+ override def hasDictionarySupport: Boolean = true
+
+ override def setDictionary(dictionary: Dictionary): Unit = {
+ this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { i =>
+ UTF8String.fromBytes(dictionary.decodeToBinary(i).getBytes)
+ }
+ }
+
+ override def addValueFromDictionary(dictionaryId: Int): Unit = {
+ updater.set(expandedDictionary(dictionaryId))
+ }
+
+ override def addBinary(value: Binary): Unit = {
+ // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here we
+ // are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying
+ // it.
+ val buffer = value.toByteBuffer
+ val offset = buffer.arrayOffset() + buffer.position()
+ val numBytes = buffer.remaining()
+ updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes))
+ }
+ }
+
+ /**
+ * Parquet converter for fixed-precision decimals.
+ */
+ private abstract class ParquetDecimalConverter(
+ precision: Int, scale: Int, updater: ParentContainerUpdater)
+ extends ParquetPrimitiveConverter(updater) {
+
+ protected var expandedDictionary: Array[Decimal] = _
+
+ override def hasDictionarySupport: Boolean = true
+
+ override def addValueFromDictionary(dictionaryId: Int): Unit = {
+ updater.set(expandedDictionary(dictionaryId))
+ }
+
+ // Converts decimals stored as INT32
+ override def addInt(value: Int): Unit = {
+ addLong(value: Long)
+ }
+
+ // Converts decimals stored as INT64
+ override def addLong(value: Long): Unit = {
+ updater.set(decimalFromLong(value))
+ }
+
+ // Converts decimals stored as either FIXED_LENGTH_BYTE_ARRAY or BINARY
+ override def addBinary(value: Binary): Unit = {
+ updater.set(decimalFromBinary(value))
+ }
+
+ protected def decimalFromLong(value: Long): Decimal = {
+ Decimal(value, precision, scale)
+ }
+
+ protected def decimalFromBinary(value: Binary): Decimal = {
+ if (precision <= Decimal.MAX_LONG_DIGITS) {
+ // Constructs a `Decimal` with an unscaled `Long` value if possible.
+ val unscaled = ParquetRowConverter.binaryToUnscaledLong(value)
+ Decimal(unscaled, precision, scale)
+ } else {
+ // Otherwise, resorts to an unscaled `BigInteger` instead.
+ Decimal(new BigDecimal(new BigInteger(value.getBytes), scale), precision, scale)
+ }
+ }
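+
+ // Example: a DECIMAL(20, 2) stored as BINARY takes the BigInteger path above, since its
+ // precision exceeds Decimal.MAX_LONG_DIGITS (18).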
+ }
+
+ private class ParquetIntDictionaryAwareDecimalConverter(
+ precision: Int, scale: Int, updater: ParentContainerUpdater)
+ extends ParquetDecimalConverter(precision, scale, updater) {
+
+ override def setDictionary(dictionary: Dictionary): Unit = {
+ this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id =>
+ decimalFromLong(dictionary.decodeToInt(id).toLong)
+ }
+ }
+ }
+
+ private class ParquetLongDictionaryAwareDecimalConverter(
+ precision: Int, scale: Int, updater: ParentContainerUpdater)
+ extends ParquetDecimalConverter(precision, scale, updater) {
+
+ override def setDictionary(dictionary: Dictionary): Unit = {
+ this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id =>
+ decimalFromLong(dictionary.decodeToLong(id))
+ }
+ }
+ }
+
+ private class ParquetBinaryDictionaryAwareDecimalConverter(
+ precision: Int, scale: Int, updater: ParentContainerUpdater)
+ extends ParquetDecimalConverter(precision, scale, updater) {
+
+ override def setDictionary(dictionary: Dictionary): Unit = {
+ this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id =>
+ decimalFromBinary(dictionary.decodeToBinary(id))
+ }
+ }
+ }
+
+ /**
+ * Parquet converter for arrays. Spark SQL arrays are represented as Parquet lists. Standard
+ * Parquet lists are represented as a 3-level group annotated by `LIST`:
+ * {{{
+ * <list-repetition> group <name> (LIST) { <-- parquetSchema points here
+ * repeated group list {
+ * <element-repetition> <element-type> element;
+ * }
+ * }
+ * }}}
+ * The `parquetSchema` constructor argument points to the outermost group.
+ *
+ * However, before this representation was standardized, some Parquet libraries/tools used
+ * non-standard formats to represent list-like structures. Backwards-compatibility rules for
+ * handling these cases are described in the Parquet format spec.
+ *
+ * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists
+ */
+ private final class ParquetArrayConverter(
+ parquetSchema: GroupType,
+ catalystSchema: ArrayType,
+ updater: ParentContainerUpdater)
+ extends ParquetGroupConverter(updater) {
+
+ private[this] val currentArray = ArrayBuffer.empty[Any]
+
+ private[this] val elementConverter: Converter = {
+ val repeatedType = parquetSchema.getType(0)
+ val elementType = catalystSchema.elementType
+
+ // At this stage, we're not sure whether the repeated field maps to the element type or is
+ // just the syntactic repeated group of the 3-level standard LIST layout. Take the following
+ // Parquet LIST-annotated group type as an example:
+ //
+ // optional group f (LIST) {
+ // repeated group list {
+ // optional group element {
+ // optional int32 element;
+ // }
+ // }
+ // }
+ //
+ // This type is ambiguous:
+ //
+ // 1. When interpreted as a standard 3-level layout, the `list` field is just the syntactic
+ // group, and the entire type should be translated to:
+ //
+ // ARRAY<STRUCT<element: INT>>
+ //
+ // 2. On the other hand, when interpreted as a non-standard 2-level layout, the `list` field
+ // represents the element type, and the entire type should be translated to:
+ //
+ // ARRAY<STRUCT<element: STRUCT<element: INT>>>
+ //
+ // Here we try to convert field `list` into a Catalyst type to see whether the converted type
+ // matches the Catalyst array element type. If it doesn't match, then it's case 1; otherwise,
+ // it's case 2.
+ val guessedElementType = schemaConverter.convertFieldWithGeo(repeatedType)
+
+ if (DataType.equalsIgnoreCompatibleNullability(guessedElementType, elementType)) {
+ // If the repeated field corresponds to the element type, creates a new converter using the
+ // type of the repeated field.
+ newConverter(repeatedType, elementType, new ParentContainerUpdater {
+ override def set(value: Any): Unit = currentArray += value
+ })
+ } else {
+ // If the repeated field corresponds to the syntactic group in the standard 3-level Parquet
+ // LIST layout, creates a new converter using the only child field of the repeated field.
+ assert(!repeatedType.isPrimitive && repeatedType.asGroupType().getFieldCount == 1)
+ new ElementConverter(repeatedType.asGroupType().getType(0), elementType)
+ }
+ }
+
+ override def getConverter(fieldIndex: Int): Converter = elementConverter
+
+ override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray))
+
+ override def start(): Unit = currentArray.clear()
+
+ /** Array element converter */
+ private final class ElementConverter(parquetType: Type, catalystType: DataType)
+ extends GroupConverter {
+
+ private var currentElement: Any = _
+
+ private[this] val converter =
+ newConverter(parquetType, catalystType, new ParentContainerUpdater {
+ override def set(value: Any): Unit = currentElement = value
+ })
+
+ override def getConverter(fieldIndex: Int): Converter = converter
+
+ override def end(): Unit = currentArray += currentElement
+
+ override def start(): Unit = currentElement = null
+ }
+ }
+
+ /** Parquet converter for maps */
+ private final class ParquetMapConverter(
+ parquetType: GroupType,
+ catalystType: MapType,
+ updater: ParentContainerUpdater)
+ extends ParquetGroupConverter(updater) {
+
+ private[this] val currentKeys = ArrayBuffer.empty[Any]
+ private[this] val currentValues = ArrayBuffer.empty[Any]
+
+ private[this] val keyValueConverter = {
+ val repeatedType = parquetType.getType(0).asGroupType()
+ new KeyValueConverter(
+ repeatedType.getType(0),
+ repeatedType.getType(1),
+ catalystType.keyType,
+ catalystType.valueType)
+ }
+
+ override def getConverter(fieldIndex: Int): Converter = keyValueConverter
+
+ override def end(): Unit = {
+ // The Parquet map may contain null or duplicated map keys. When that happens, the behavior
+ // is undefined.
+ // TODO (SPARK-26174): disallow it with a config.
+ updater.set(
+ new ArrayBasedMapData(
+ new GenericArrayData(currentKeys.toArray),
+ new GenericArrayData(currentValues.toArray)))
+ }
+
+ override def start(): Unit = {
+ currentKeys.clear()
+ currentValues.clear()
+ }
+
+ /** Parquet converter for key-value pairs within the map. */
+ private final class KeyValueConverter(
+ parquetKeyType: Type,
+ parquetValueType: Type,
+ catalystKeyType: DataType,
+ catalystValueType: DataType)
+ extends GroupConverter {
+
+ private var currentKey: Any = _
+
+ private var currentValue: Any = _
+
+ private[this] val converters = Array(
+ // Converter for keys
+ newConverter(parquetKeyType, catalystKeyType, new ParentContainerUpdater {
+ override def set(value: Any): Unit = currentKey = value
+ }),
+
+ // Converter for values
+ newConverter(parquetValueType, catalystValueType, new ParentContainerUpdater {
+ override def set(value: Any): Unit = currentValue = value
+ }))
+
+ override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
+
+ override def end(): Unit = {
+ currentKeys += currentKey
+ currentValues += currentValue
+ }
+
+ override def start(): Unit = {
+ currentKey = null
+ currentValue = null
+ }
+ }
+ }
+
+ private trait RepeatedConverter {
+ private[this] val currentArray = ArrayBuffer.empty[Any]
+
+ protected def newArrayUpdater(updater: ParentContainerUpdater) = new ParentContainerUpdater {
+ override def start(): Unit = currentArray.clear()
+ override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray))
+ override def set(value: Any): Unit = currentArray += value
+ }
+ }
+
+ /**
+ * A primitive converter for converting unannotated repeated primitive values to required arrays
+ * of required primitive values.
+ */
+ private final class RepeatedPrimitiveConverter(
+ parquetType: Type,
+ catalystType: DataType,
+ parentUpdater: ParentContainerUpdater)
+ extends PrimitiveConverter with RepeatedConverter with HasParentContainerUpdater {
+
+ val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater)
+
+ private[this] val elementConverter: PrimitiveConverter =
+ newConverter(parquetType, catalystType, updater).asPrimitiveConverter()
+
+ override def addBoolean(value: Boolean): Unit = elementConverter.addBoolean(value)
+ override def addInt(value: Int): Unit = elementConverter.addInt(value)
+ override def addLong(value: Long): Unit = elementConverter.addLong(value)
+ override def addFloat(value: Float): Unit = elementConverter.addFloat(value)
+ override def addDouble(value: Double): Unit = elementConverter.addDouble(value)
+ override def addBinary(value: Binary): Unit = elementConverter.addBinary(value)
+
+ override def setDictionary(dict: Dictionary): Unit = elementConverter.setDictionary(dict)
+ override def hasDictionarySupport: Boolean = elementConverter.hasDictionarySupport
+ override def addValueFromDictionary(id: Int): Unit = elementConverter.addValueFromDictionary(id)
+ }
+
+ /**
+ * A group converter for converting unannotated repeated group values to required arrays of
+ * required struct values.
+ */
+ private final class RepeatedGroupConverter(
+ parquetType: Type,
+ catalystType: DataType,
+ parentUpdater: ParentContainerUpdater)
+ extends GroupConverter with HasParentContainerUpdater with RepeatedConverter {
+
+ val updater: ParentContainerUpdater = newArrayUpdater(parentUpdater)
+
+ private[this] val elementConverter: GroupConverter =
+ newConverter(parquetType, catalystType, updater).asGroupConverter()
+
+ override def getConverter(field: Int): Converter = elementConverter.getConverter(field)
+ override def end(): Unit = elementConverter.end()
+ override def start(): Unit = elementConverter.start()
+ }
+}
\ No newline at end of file
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
new file mode 100644
index 00000000..5e20bcb1
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetSchemaConverter.scala
@@ -0,0 +1,576 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import scala.collection.JavaConverters._
+import org.apache.hadoop.conf.Configuration
+import org.apache.parquet.schema._
+import org.apache.parquet.schema.OriginalType._
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
+import org.apache.parquet.schema.Type.Repetition._
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.checkConversionRequirement
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.types._
+
+/**
+ * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]].
+ *
+ * Parquet format backwards-compatibility rules are respected when converting Parquet
+ * [[MessageType]] schemas.
+ *
+ * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md
+ *
+ * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL
+ * [[StringType]] fields.
+ * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL
+ * [[TimestampType]] fields.
+ */
+class GeoParquetToSparkSchemaConverter(
+ assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
+ assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get)
+ extends ParquetToSparkSchemaConverter(assumeBinaryIsString, assumeInt96IsTimestamp) {
+
+ def this(conf: SQLConf) = this(
+ assumeBinaryIsString = conf.isParquetBinaryAsString,
+ assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp)
+
+ def this(conf: Configuration) = this(
+ assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
+ assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean)
+
+ /**
+ * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
+ */
+ override def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType())
+
+ private def convert(parquetSchema: GroupType): StructType = {
+ val fields = parquetSchema.getFields.asScala.map { field =>
+ field.getRepetition match {
+ case OPTIONAL =>
+ StructField(field.getName, convertFieldWithGeo(field), nullable = true)
+
+ case REQUIRED =>
+ StructField(field.getName, convertFieldWithGeo(field), nullable = false)
+
+ case REPEATED =>
+ // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor
+ // annotated by `LIST` or `MAP` should be interpreted as a required list of required
+ // elements where the element type is the type of the field.
+ val arrayType = ArrayType(convertFieldWithGeo(field), containsNull = false)
+ StructField(field.getName, arrayType, nullable = false)
+ }
+ }
+
+ StructType(fields.toSeq)
+ }
+
+ /**
+ * Converts a Parquet [[Type]] to a Spark SQL [[DataType]].
+ */
+ def convertFieldWithGeo(parquetType: Type): DataType = parquetType match {
+ case t: PrimitiveType => convertPrimitiveField(t)
+ case t: GroupType => convertGroupField(t.asGroupType())
+ }
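+
+ // Illustrative mapping (a sketch): `int32 id` converts to IntegerType, while an unannotated
+ // `binary geometry` whose name passes GeoParquetSchemaConverter.checkGeomFieldName converts
+ // to GeometryUDT (see convertPrimitiveField below).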
+
+ private def convertPrimitiveField(field: PrimitiveType): DataType = {
+ val typeName = field.getPrimitiveTypeName
+ val originalType = field.getOriginalType
+
+ def typeString =
+ if (originalType == null) s"$typeName" else s"$typeName ($originalType)"
+
+ def typeNotSupported() =
+ throw new AnalysisException(s"Parquet type not supported: $typeString")
+
+ def typeNotImplemented() =
+ throw new AnalysisException(s"Parquet type not yet supported: $typeString")
+
+ def illegalType() =
+ throw new AnalysisException(s"Illegal Parquet type: $typeString")
+
+ // When maxPrecision = -1, we skip precision range check, and always respect the precision
+ // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored
+ // as binaries with variable lengths.
+ def makeDecimalType(maxPrecision: Int = -1): DecimalType = {
+ val precision = field.getDecimalMetadata.getPrecision
+ val scale = field.getDecimalMetadata.getScale
+
+ ParquetSchemaConverter.checkConversionRequirement(
+ maxPrecision == -1 || 1 <= precision && precision <= maxPrecision,
+ s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)")
+
+ DecimalType(precision, scale)
+ }
+
+ typeName match {
+ case BOOLEAN => BooleanType
+
+ case FLOAT => FloatType
+
+ case DOUBLE => DoubleType
+
+ case INT32 =>
+ originalType match {
+ case INT_8 => ByteType
+ case INT_16 => ShortType
+ case INT_32 | null => IntegerType
+ case DATE => DateType
+ case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS)
+ case UINT_8 => typeNotSupported()
+ case UINT_16 => typeNotSupported()
+ case UINT_32 => typeNotSupported()
+ case TIME_MILLIS => typeNotImplemented()
+ case _ => illegalType()
+ }
+
+ case INT64 =>
+ originalType match {
+ case INT_64 | null => LongType
+ case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS)
+ case UINT_64 => typeNotSupported()
+ case TIMESTAMP_MICROS => TimestampType
+ case TIMESTAMP_MILLIS => TimestampType
+ case _ => illegalType()
+ }
+
+ case INT96 =>
+ ParquetSchemaConverter.checkConversionRequirement(
+ assumeInt96IsTimestamp,
+ "INT96 is not supported unless it's interpreted as timestamp. " +
+ s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.")
+ TimestampType
+
+ case BINARY =>
+ originalType match {
+ case UTF8 | ENUM | JSON => StringType
+ case null if GeoParquetSchemaConverter.checkGeomFieldName(field.getName) => GeometryUDT
+ case null if assumeBinaryIsString => StringType
+ case null => BinaryType
+ case BSON => BinaryType
+ case DECIMAL => makeDecimalType()
+ case _ => illegalType()
+ }
+
+ case FIXED_LEN_BYTE_ARRAY =>
+ originalType match {
+ case DECIMAL => makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength))
+ case INTERVAL => typeNotImplemented()
+ case _ => illegalType()
+ }
+
+ case _ => illegalType()
+ }
+ }
+
+ private def convertGroupField(field: GroupType): DataType = {
+ Option(field.getOriginalType).fold(convert(field): DataType) {
+ // A Parquet list is represented as a 3-level structure:
+ //
+ // <list-repetition> group <name> (LIST) {
+ // repeated group list {
+ // <element-repetition> <element-type> element;
+ // }
+ // }
+ //
+ // However, according to the most recent Parquet format spec (not yet released as of this
+ // writing), some 2-level structures are also recognized for backwards-compatibility. Thus,
+ // we need to check whether the 2nd level or the 3rd level refers to the list element type.
+ //
+ // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists
+ case LIST =>
+ ParquetSchemaConverter.checkConversionRequirement(
+ field.getFieldCount == 1, s"Invalid list type $field")
+
+ val repeatedType = field.getType(0)
+ ParquetSchemaConverter.checkConversionRequirement(
+ repeatedType.isRepetition(REPEATED), s"Invalid list type $field")
+
+ if (isElementTypeWithGeo(repeatedType, field.getName)) {
+ ArrayType(convertFieldWithGeo(repeatedType), containsNull = false)
+ } else {
+ val elementType = repeatedType.asGroupType().getType(0)
+ val optional = elementType.isRepetition(OPTIONAL)
+ ArrayType(convertFieldWithGeo(elementType), containsNull = optional)
+ }
+
+ // scalastyle:off
+ // `MAP_KEY_VALUE` is for backwards-compatibility
+ // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1
+ // scalastyle:on
+ case MAP | MAP_KEY_VALUE =>
+ ParquetSchemaConverter.checkConversionRequirement(
+ field.getFieldCount == 1 && !field.getType(0).isPrimitive,
+ s"Invalid map type: $field")
+
+ val keyValueType = field.getType(0).asGroupType()
+ ParquetSchemaConverter.checkConversionRequirement(
+ keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2,
+ s"Invalid map type: $field")
+
+ val keyType = keyValueType.getType(0)
+ val valueType = keyValueType.getType(1)
+ val valueOptional = valueType.isRepetition(OPTIONAL)
+ MapType(
+ convertFieldWithGeo(keyType),
+ convertFieldWithGeo(valueType),
+ valueContainsNull = valueOptional)
+
+ case _ =>
+ throw new AnalysisException(s"Unrecognized Parquet type: $field")
+ }
+ }
+
+ // scalastyle:off
+ // Here we implement Parquet LIST backwards-compatibility rules.
+ // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules
+ // scalastyle:on
+ def isElementTypeWithGeo(repeatedType: Type, parentName: String): Boolean = {
+ {
+ // For legacy 2-level list types with primitive element type, e.g.:
+ //
+ // // ARRAY<INT> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated int32 element;
+ // }
+ //
+ repeatedType.isPrimitive
+ } || {
+ // For legacy 2-level list types whose element type is a group type with 2 or more fields,
+ // e.g.:
+ //
+ // // ARRAY<STRUCT<str: STRING, num: INT>> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated group element {
+ // required binary str (UTF8);
+ // required int32 num;
+ // };
+ // }
+ //
+ repeatedType.asGroupType().getFieldCount > 1
+ } || {
+ // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.:
+ //
+ // // ARRAY<STRUCT<str: STRING>> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated group array {
+ // required binary str (UTF8);
+ // };
+ // }
+ //
+ repeatedType.getName == "array"
+ } || {
+ // For Parquet data generated by parquet-thrift, e.g.:
+ //
+ // // ARRAY<STRUCT<str: STRING>> (nullable list, non-null elements)
+ // optional group my_list (LIST) {
+ // repeated group my_list_tuple {
+ // required binary str (UTF8);
+ // };
+ // }
+ //
+ repeatedType.getName == s"${parentName}_tuple"
+ }
+ }
+}
+
+/**
+ * This converter class is used to convert Spark SQL [[StructType]] to Parquet [[MessageType]].
+ *
+ * @param writeLegacyParquetFormat Whether to use legacy Parquet format compatible with Spark 1.4
+ * and prior versions when converting a Catalyst [[StructType]] to a Parquet [[MessageType]].
+ * When set to false, use standard format defined in parquet-format spec. This argument only
+ * affects Parquet write path.
+ * @param outputTimestampType which parquet timestamp type to use when writing.
+ */
+class SparkToGeoParquetSchemaConverter(
+ writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get,
+ outputTimestampType: SQLConf.ParquetOutputTimestampType.Value =
+ SQLConf.ParquetOutputTimestampType.INT96)
+ extends SparkToParquetSchemaConverter(writeLegacyParquetFormat, outputTimestampType) {
+
+ def this(conf: SQLConf) = this(
+ writeLegacyParquetFormat = conf.writeLegacyParquetFormat,
+ outputTimestampType = conf.parquetOutputTimestampType)
+
+ def this(conf: Configuration) = this(
+ writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean,
+ outputTimestampType = SQLConf.ParquetOutputTimestampType.withName(
+ conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)))
+
+ /**
+ * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]].
+ */
+ override def convert(catalystSchema: StructType): MessageType = {
+ Types
+ .buildMessage()
+ .addFields(catalystSchema.map(convertField): _*)
+ .named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME)
+ }
+
+ /**
+ * Converts a Spark SQL [[StructField]] to a Parquet [[Type]].
+ */
+ override def convertField(field: StructField): Type = {
+ convertField(field, if (field.nullable) OPTIONAL else REQUIRED)
+ }
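+
+ // Example: StructField("name", StringType, nullable = true) converts to
+ // `optional binary name (UTF8)`; with nullable = false the field becomes `required`.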
+
+ private def convertField(field: StructField, repetition: Type.Repetition): Type = {
+ GeoParquetSchemaConverter.checkFieldName(field.name)
+
+ field.dataType match {
+ // ===================
+ // Simple atomic types
+ // ===================
+
+ case BooleanType =>
+ Types.primitive(BOOLEAN, repetition).named(field.name)
+
+ case ByteType =>
+ Types.primitive(INT32, repetition).as(INT_8).named(field.name)
+
+ case ShortType =>
+ Types.primitive(INT32, repetition).as(INT_16).named(field.name)
+
+ case IntegerType =>
+ Types.primitive(INT32, repetition).named(field.name)
+
+ case LongType =>
+ Types.primitive(INT64, repetition).named(field.name)
+
+ case FloatType =>
+ Types.primitive(FLOAT, repetition).named(field.name)
+
+ case DoubleType =>
+ Types.primitive(DOUBLE, repetition).named(field.name)
+
+ case StringType =>
+ Types.primitive(BINARY, repetition).as(UTF8).named(field.name)
+
+ case DateType =>
+ Types.primitive(INT32, repetition).as(DATE).named(field.name)
+
+ // NOTE: Spark SQL can write timestamp values to Parquet using INT96, TIMESTAMP_MICROS or
+ // TIMESTAMP_MILLIS. TIMESTAMP_MICROS is recommended but INT96 is the default to keep the
+ // behavior the same as before.
+ //
+ // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond
+ // timestamp in Impala for some historical reasons. It's not recommended to be used for any
+ // other types and will probably be deprecated in some future version of parquet-format spec.
+ // That's the reason why parquet-format spec only defines `TIMESTAMP_MILLIS` and
+ // `TIMESTAMP_MICROS` which are both logical types annotating `INT64`.
+ //
+ // Originally, Spark SQL used the same nanosecond timestamp type as Impala and Hive. Starting
+ // from Spark 1.5.0, we resort to a timestamp type with microsecond precision so that we can
+ // store a timestamp in a `Long`. This design decision is subject to change though, for
+ // example, we may resort to nanosecond precision in the future.
+ case TimestampType =>
+ outputTimestampType match {
+ case SQLConf.ParquetOutputTimestampType.INT96 =>
+ Types.primitive(INT96, repetition).named(field.name)
+ case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS =>
+ Types.primitive(INT64, repetition).as(TIMESTAMP_MICROS).named(field.name)
+ case SQLConf.ParquetOutputTimestampType.TIMESTAMP_MILLIS =>
+ Types.primitive(INT64, repetition).as(TIMESTAMP_MILLIS).named(field.name)
+ }
+
+ case BinaryType =>
+ Types.primitive(BINARY, repetition).named(field.name)
+
+ // ======================
+ // Decimals (legacy mode)
+ // ======================
+
+ // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and
+ // always store decimals in fixed-length byte arrays. To keep compatibility with these older
+ // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated
+ // by `DECIMAL`.
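+ // Worked example (not part of the patch): DecimalType(10, 2) needs
+ // Decimal.minBytesForPrecision(10) = 5 bytes, so in legacy mode it is written as
+ // FIXED_LEN_BYTE_ARRAY(5) annotated with DECIMAL(10, 2).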
+ case DecimalType.Fixed(precision, scale) if writeLegacyParquetFormat =>
+ Types
+ .primitive(FIXED_LEN_BYTE_ARRAY, repetition)
+ .as(DECIMAL)
+ .precision(precision)
+ .scale(scale)
+ .length(Decimal.minBytesForPrecision(precision))
+ .named(field.name)
+
+ // ========================
+ // Decimals (standard mode)
+ // ========================
+
+ // Uses INT32 for 1 <= precision <= 9
+ case DecimalType.Fixed(precision, scale)
+ if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat =>
+ Types
+ .primitive(INT32, repetition)
+ .as(DECIMAL)
+ .precision(precision)
+ .scale(scale)
+ .named(field.name)
+
+ // Uses INT64 for 1 <= precision <= 18
+ case DecimalType.Fixed(precision, scale)
+ if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat =>
+ Types
+ .primitive(INT64, repetition)
+ .as(DECIMAL)
+ .precision(precision)
+ .scale(scale)
+ .named(field.name)
+
+ // Uses FIXED_LEN_BYTE_ARRAY for all other precisions
+ case DecimalType.Fixed(precision, scale) if !writeLegacyParquetFormat =>
+ Types
+ .primitive(FIXED_LEN_BYTE_ARRAY, repetition)
+ .as(DECIMAL)
+ .precision(precision)
+ .scale(scale)
+ .length(Decimal.minBytesForPrecision(precision))
+ .named(field.name)
+
+ // ===================================
+ // ArrayType and MapType (legacy mode)
+ // ===================================
+
+ // Spark 1.4.x and prior versions convert `ArrayType` with nullable elements into a 3-level
+ // `LIST` structure. This behavior is somewhat of a hybrid of parquet-hive and parquet-avro
+ // (1.6.0rc3): the 3-level structure is similar to parquet-hive while the 3rd level element
+ // field name "array" is borrowed from parquet-avro.
+ case ArrayType(elementType, nullable @ true) if writeLegacyParquetFormat =>
+ // <list-repetition> group <name> (LIST) {
+ // optional group bag {
+ // repeated <element-type> array;
+ // }
+ // }
+
+ // This should not use `listOfElements` here because this new method checks if the
+ // element name is `element` in the `GroupType` and throws an exception if not.
+ // As mentioned above, Spark prior to 1.4.x writes `ArrayType` as `LIST` but with
+ // `array` as its element name as below. Therefore, we manually build the correct
+ // group type here via the builder. (See SPARK-16777)
+ Types
+ .buildGroup(repetition).as(LIST)
+ .addField(Types
+ .buildGroup(REPEATED)
+ // "array" is the name chosen by parquet-hive (1.7.0 and prior version)
+ .addField(convertField(StructField("array", elementType, nullable)))
+ .named("bag"))
+ .named(field.name)
+
+ // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level
+ // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is
+ // covered by the backwards-compatibility rules implemented in `isElementType()`.
+ case ArrayType(elementType, nullable @ false) if writeLegacyParquetFormat =>
+ // <list-repetition> group <name> (LIST) {
+ // repeated <element-type> element;
+ // }
+
+ // Here too, we should not use `listOfElements`. (See SPARK-16777)
+ Types
+ .buildGroup(repetition).as(LIST)
+ // "array" is the name chosen by parquet-avro (1.7.0 and prior version)
+ .addField(convertField(StructField("array", elementType, nullable), REPEATED))
+ .named(field.name)
+
+ // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by
+ // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`.
+ case MapType(keyType, valueType, valueContainsNull) if writeLegacyParquetFormat =>
+ // <map-repetition> group <name> (MAP) {
+ // repeated group map (MAP_KEY_VALUE) {
+ // required <key-type> key;
+ // <value-repetition> <value-type> value;
+ // }
+ // }
+ ConversionPatterns.mapType(
+ repetition,
+ field.name,
+ convertField(StructField("key", keyType, nullable = false)),
+ convertField(StructField("value", valueType, valueContainsNull)))
+
+ // =====================================
+ // ArrayType and MapType (standard mode)
+ // =====================================
+
+ case ArrayType(elementType, containsNull) if !writeLegacyParquetFormat =>
+ // <list-repetition> group <name> (LIST) {
+ // repeated group list {
+ // <element-repetition> <element-type> element;
+ // }
+ // }
+ Types
+ .buildGroup(repetition).as(LIST)
+ .addField(
+ Types.repeatedGroup()
+ .addField(convertField(StructField("element", elementType, containsNull)))
+ .named("list"))
+ .named(field.name)
+
+ case MapType(keyType, valueType, valueContainsNull) =>
+ // <map-repetition> group <name> (MAP) {
+ // repeated group key_value {
+ // required <key-type> key;
+ // <value-repetition> <value-type> value;
+ // }
+ // }
+ Types
+ .buildGroup(repetition).as(MAP)
+ .addField(
+ Types
+ .repeatedGroup()
+ .addField(convertField(StructField("key", keyType, nullable = false)))
+ .addField(convertField(StructField("value", valueType, valueContainsNull)))
+ .named("key_value"))
+ .named(field.name)
+
+ // ===========
+ // Other types
+ // ===========
+
+ case StructType(fields) =>
+ fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) =>
+ builder.addField(convertField(field))
+ }.named(field.name)
+
+ case udt: UserDefinedType[_] =>
+ convertField(field.copy(dataType = udt.sqlType))
+
+ case _ =>
+ throw new AnalysisException(s"Unsupported data type ${field.dataType.catalogString}")
+ }
+ }
+}
+
+private[sql] object GeoParquetSchemaConverter {
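+ // Returns true when `name` matches the geometry column configured via
+ // GeometryField.setFieldGeometry (e.g. set through the "fieldGeometry" read option).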
+ def checkGeomFieldName(name: String): Boolean =
+ name == GeometryField.getFieldGeometry()
+
+ def checkFieldName(name: String): Unit = {
+ // ,;{}()\n\t= and space are special characters in Parquet schema
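+ // e.g. columns named "geom col" or "a,b" (illustrative names) are rejected here;
+ // rename them with .alias(...) before writing.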
+ checkConversionRequirement(
+ !name.matches(".*[ ,;{}()\n\t=].*"),
+ s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=".
+ |Please use alias to rename it.
+ """.stripMargin.split("\n").mkString(" ").trim)
+ }
+}
+
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala
new file mode 100644
index 00000000..5e8c9c98
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoParquetUtils.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.parquet.hadoop.ParquetFileWriter
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.StructType
+
+import scala.language.existentials
+
+object GeoParquetUtils {
+ def inferSchema(
+ sparkSession: SparkSession,
+ parameters: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] = {
+ val parquetOptions = new ParquetOptions(parameters, sparkSession.sessionState.conf)
+ val shouldMergeSchemas = parquetOptions.mergeSchema
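+ // e.g. spark.read.format("geoparquet").option("mergeSchema", "true").load(path)
+ // (an illustrative call; "mergeSchema" is the standard Parquet option honored here)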
+ val mergeRespectSummaries = sparkSession.sessionState.conf.isParquetSchemaRespectSummaries
+ val filesByType = splitFiles(files)
+ val filesToTouch =
+ if (shouldMergeSchemas) {
+ val needMerged: Seq[FileStatus] =
+ if (mergeRespectSummaries) {
+ Seq.empty
+ } else {
+ filesByType.data
+ }
+ needMerged ++ filesByType.metadata ++ filesByType.commonMetadata
+ } else {
+ // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet
+ // don't have this.
+ filesByType.commonMetadata.headOption
+ // Falls back to "_metadata"
+ .orElse(filesByType.metadata.headOption)
+ // Summary file(s) not found, so the Parquet file is either corrupted, or different part-
+ // files contain conflicting user defined metadata (two or more values are associated
+ // with the same key in different files). In either case, we fall back to the first
+ // part-file and just assume all schemas are consistent.
+ .orElse(filesByType.data.headOption)
+ .toSeq
+ }
+ GeoParquetFileFormat.mergeSchemasInParallel(parameters, filesToTouch, sparkSession)
+ }
+
+ case class FileTypes(
+ data: Seq[FileStatus],
+ metadata: Seq[FileStatus],
+ commonMetadata: Seq[FileStatus])
+
+ private def splitFiles(allFiles: Seq[FileStatus]): FileTypes = {
+ val leaves = allFiles.toArray.sortBy(_.getPath.toString)
+
+ FileTypes(
+ data = leaves.filterNot(f => isSummaryFile(f.getPath)),
+ metadata =
+ leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE),
+ commonMetadata =
+ leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))
+ }
+
+ private def isSummaryFile(file: Path): Boolean = {
+ file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
+ file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
+ }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala
new file mode 100644
index 00000000..94232487
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeoSchemaMergeUtils.scala
@@ -0,0 +1,102 @@
+/**
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
+// Needed by Sedona to support Spark 3.0 - 3.3
+object GeoSchemaMergeUtils {
+
+ def mergeSchemasInParallel(
+ sparkSession: SparkSession,
+ parameters: Map[String, String],
+ files: Seq[FileStatus],
+ schemaReader: (Seq[FileStatus], Configuration, Boolean) => Seq[StructType])
+ : Option[StructType] = {
+ val serializedConf = new SerializableConfiguration(
+ sparkSession.sessionState.newHadoopConfWithOptions(parameters))
+
+ // !! HACK ALERT !!
+ // Here is a hack for Parquet, but it can be used by Orc as well.
+ //
+ // Parquet requires `FileStatus`es to read footers.
+ // Here we try to send cached `FileStatus`es to executor side to avoid fetching them again.
+ // However, `FileStatus` is not `Serializable`
+ // but only `Writable`. What makes it worse, for some reason `FileStatus` doesn't play well
+ // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These
+ // facts virtually prevent us from serializing `FileStatus`es.
+ //
+ // Since Parquet only relies on the path and length information of those `FileStatus`es to
+ // read footers, here we just extract them (which can be easily serialized), send them to
+ // the executor side, and reassemble fake `FileStatus`es there.
+ val partialFileStatusInfo = files.map(f => (f.getPath.toString, f.getLen))
+
+ // Set the number of partitions to prevent the following schema reads from generating many
+ // tasks in case of a small number of Parquet/ORC files.
+ val numParallelism = Math.min(Math.max(partialFileStatusInfo.size, 1),
+ sparkSession.sparkContext.defaultParallelism)
+
+ val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles
+
+ // Issues a Spark job to read Parquet/ORC schema in parallel.
+ val partiallyMergedSchemas =
+ sparkSession
+ .sparkContext
+ .parallelize(partialFileStatusInfo, numParallelism)
+ .mapPartitions { iterator =>
+ // Reassembles fake `FileStatus`es from the serialized path and length information.
+ val fakeFileStatuses = iterator.map { case (path, length) =>
+ new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path))
+ }.toSeq
+
+ val schemas = schemaReader(fakeFileStatuses, serializedConf.value, ignoreCorruptFiles)
+
+ if (schemas.isEmpty) {
+ Iterator.empty
+ } else {
+ var mergedSchema = schemas.head
+ schemas.tail.foreach { schema =>
+ try {
+ mergedSchema = mergedSchema.merge(schema)
+ } catch { case cause: SparkException =>
+ throw new SparkException(
+ s"Failed merging schema:\n${schema.treeString}", cause)
+ }
+ }
+ Iterator.single(mergedSchema)
+ }
+ }.collect()
+
+ if (partiallyMergedSchemas.isEmpty) {
+ None
+ } else {
+ var finalSchema = partiallyMergedSchemas.head
+ partiallyMergedSchemas.tail.foreach { schema =>
+ try {
+ finalSchema = finalSchema.merge(schema)
+ } catch { case cause: SparkException =>
+ throw new SparkException(
+ s"Failed merging schema:\n${schema.treeString}", cause)
+ }
+ }
+ Some(finalSchema)
+ }
+ }
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeometryField.scala b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeometryField.scala
new file mode 100644
index 00000000..65f52c05
--- /dev/null
+++ b/sql/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/GeometryField.scala
@@ -0,0 +1,28 @@
+/**
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+object GeometryField {
+
+ var fieldGeometry: String = ""
+
+ def getFieldGeometry(): String = {
+ fieldGeometry
+ }
+
+ def setFieldGeometry(fieldName: String): Unit = {
+ fieldGeometry = fieldName
+ }
+}
\ No newline at end of file
diff --git a/sql/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala b/sql/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
new file mode 100644
index 00000000..21590ab1
--- /dev/null
+++ b/sql/src/test/scala/org/apache/sedona/sql/geoparquetIOTests.scala
@@ -0,0 +1,66 @@
+package org.apache.sedona.sql
+
+import org.apache.spark.sql.SaveMode
+import org.locationtech.jts.geom.Geometry
+import org.scalatest.BeforeAndAfter
+
+class geoparquetIOTests extends TestBaseScala with BeforeAndAfter {
+ val geoparquetdatalocation1: String = resourceFolder + "geoparquet/example1.parquet"
+ val geoparquetdatalocation2: String = resourceFolder + "geoparquet/example2.parquet"
+ val geoparquetdatalocation3: String = resourceFolder + "geoparquet/example3.parquet"
+ val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/"
+
+ describe("GeoParquet IO tests"){
+ it("GEOPARQUET Test example1 i.e. naturalearth_lowers dataset's Read and Write"){
+ val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation1)
+ val rows = df.collect()(0)
+ assert(rows.getAs[Long]("pop_est") == 920938)
+ assert(rows.getAs[String]("continent") == "Oceania")
+ assert(rows.getAs[String]("name") == "Fiji")
+ assert(rows.getAs[String]("iso_a3") == "FJI")
+ assert(rows.getAs[Double]("gdp_md_est") == 8374.0)
+ assert(rows.getAs[Geometry]("geometry").toString == "MULTIPOLYGON (((180 -16.067132663642447, 180 -16.555216566639196, 179.36414266196414 -16.801354076946883, 178.72505936299711 -17.01204167436804, 178.59683859511713 -16.639150000000004, 179.0966093629971 -16.433984277547403, 179.4135093629971 -16.379054277547404, 180 -16.067132663642447)), ((178.12557 -17.50481, 178.3736 -17.33992, 178.71806 -17.62846, 178.55271 -18.15059, 177.93266000000003 -18.28799, 177.38146 -18.16432, 177.285 [...]
+ df.write.format("geoparquet").mode(SaveMode.Overwrite).save(geoparquetoutputlocation + "/gp_sample1.parquet")
+ val df2 = sparkSession.read.format("geoparquet").load(geoparquetoutputlocation + "/gp_sample1.parquet")
+ val newrows = df2.collect()(0)
+ assert(newrows.getAs[Geometry]("geometry").toString == "MULTIPOLYGON (((180 -16.067132663642447, 180 -16.555216566639196, 179.36414266196414 -16.801354076946883, 178.72505936299711 -17.01204167436804, 178.59683859511713 -16.639150000000004, 179.0966093629971 -16.433984277547403, 179.4135093629971 -16.379054277547404, 180 -16.067132663642447)), ((178.12557 -17.50481, 178.3736 -17.33992, 178.71806 -17.62846, 178.55271 -18.15059, 177.93266000000003 -18.28799, 177.38146 -18.16432, 177. [...]
+ }
+ it("GEOPARQUET Test example2 i.e. naturalearth_citie dataset's Read and Write"){
+ val df = sparkSession.read.format("geoparquet").load(geoparquetdatalocation2)
+ val rows = df.collect()(0)
+ assert(rows.getAs[String]("name") == "Vatican City")
+ assert(rows.getAs[Geometry]("geometry").toString == "POINT (12.453386544971766 41.903282179960115)")
+ df.write.format("geoparquet").mode(SaveMode.Overwrite).save(geoparquetoutputlocation + "/gp_sample2.parquet")
+ val df2 = sparkSession.read.format("geoparquet").load(geoparquetoutputlocation + "/gp_sample2.parquet")
+ val newrows = df2.collect()(0)
+ assert(newrows.getAs[String]("name") == "Vatican City")
+ assert(newrows.getAs[Geometry]("geometry").toString == "POINT (12.453386544971766 41.903282179960115)")
+ }
+ it("GEOPARQUET Test example3 i.e. nybb dataset's Read and Write"){
+ val df=sparkSession.read.format("geoparquet").load(geoparquetdatalocation3)
+ val rows = df.collect()(0)
+ assert(rows.getAs[Long]("BoroCode") == 5)
+ assert(rows.getAs[String]("BoroName") == "Staten Island")
+ assert(rows.getAs[Double]("Shape_Leng") == 330470.010332)
+ assert(rows.getAs[Double]("Shape_Area") == 1.62381982381E9)
+ assert(rows.getAs[Geometry]("geometry").toString.startsWith("MULTIPOLYGON (((970217.022"))
+ df.write.format("geoparquet").mode(SaveMode.Overwrite).save(geoparquetoutputlocation + "/gp_sample3.parquet")
+ val df2 = sparkSession.read.format("geoparquet").load(geoparquetoutputlocation + "/gp_sample3.parquet")
+ val newrows = df2.collect()(0)
+ assert(newrows.getAs[Geometry]("geometry").toString.startsWith("MULTIPOLYGON (((970217.022"))
+ }
+ it("GEOPARQUET Test example3 i.e. nybb dataset's Read and Write Options"){
+ val df=sparkSession.read.format("geoparquet").option("fieldGeometry", "geometry").load(geoparquetdatalocation3)
+ val rows = df.collect()(0)
+ assert(rows.getAs[Long]("BoroCode") == 5)
+ assert(rows.getAs[String]("BoroName") == "Staten Island")
+ assert(rows.getAs[Double]("Shape_Leng") == 330470.010332)
+ assert(rows.getAs[Double]("Shape_Area") == 1.62381982381E9)
+ assert(rows.getAs[Geometry]("geometry").toString.startsWith("MULTIPOLYGON (((970217.022"))
+ df.write.format("geoparquet").mode(SaveMode.Overwrite).save(geoparquetoutputlocation + "/gp_sample3o.parquet")
+ val df2 = sparkSession.read.format("geoparquet").load(geoparquetoutputlocation + "/gp_sample3.parquet")
+ val newrows = df2.collect()(0)
+ assert(newrows.getAs[Geometry]("geometry").toString.startsWith("MULTIPOLYGON (((970217.022"))
+ }
+ }
+}