You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@sedona.apache.org by ji...@apache.org on 2020/12/17 01:32:23 UTC
[incubator-sedona] branch master updated: [SEDONA-11] Add faster Python conversion from spatial rdd to df. (#496)
This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-sedona.git
The following commit(s) were added to refs/heads/master by this push:
new f94e5c5 [SEDONA-11] Add faster Python conversion from spatial rdd to df. (#496)
f94e5c5 is described below
commit f94e5c5a593de16ee698fea3f505f2ca6e5fb36d
Author: Paweł Kociński <pa...@gmail.com>
AuthorDate: Thu Dec 17 02:32:14 2020 +0100
[SEDONA-11] Add faster Python conversion from spatial rdd to df. (#496)
* [SEDONA-3] Add faster Python conversion from spatial rdd to df.
* [SEDONA-3] Add missing license clauses.
* Merge the toGeometryDf to toDf
* [SEDONA-3] Change Python api to handle one function toDf.
Co-authored-by: Jia Yu <ji...@gmail.com>
---
jts | 2 +-
.../utils/PythonAdapterWrapper.scala | 6 +-
python/sedona/core/spatialOperator/join_query.py | 80 ++-------
.../sedona/core/spatialOperator/join_query_raw.py | 99 ++++++++++++
python/sedona/core/spatialOperator/range_query.py | 24 +--
.../{range_query.py => range_query_raw.py} | 12 +-
python/sedona/core/spatialOperator/rdd.py | 90 +++++++++++
python/sedona/utils/adapter.py | 33 ++--
.../core/test_avoiding_python_jvm_serde_df.py | 180 +++++++++++++++++++++
.../core/test_avoiding_python_jvm_serde_to_rdd.py | 93 +++++++++++
python/tests/core/test_core_spatial_relations.py | 60 -------
python/tests/resources/small/areas.csv | 4 +
python/tests/resources/small/points.csv | 10 ++
python/tests/sql/test_adapter.py | 9 +-
.../org/apache/sedona/sql/utils/Adapter.scala | 143 +++++-----------
.../spark/sql/sedona_sql/UDT/GeometryUDT.scala | 6 +-
.../sql/sedona_sql/expressions/Constructors.scala | 18 +--
.../sql/sedona_sql/expressions/Functions.scala | 38 ++---
.../expressions_udaf/AggregateFunctions.scala | 18 +--
sql/src/test/resources/small/areas.csv | 4 +
sql/src/test/resources/small/points.csv | 10 ++
.../org/apache/sedona/sql/TestBaseScala.scala | 2 +
.../org/apache/sedona/sql/adapterTestScala.scala | 52 +++---
.../apache/sedona/sql/constructorTestScala.scala | 2 -
24 files changed, 644 insertions(+), 351 deletions(-)
diff --git a/jts b/jts
index e4d7d6f..30fba3d 160000
--- a/jts
+++ b/jts
@@ -1 +1 @@
-Subproject commit e4d7d6f451c908b1bfeef6a10230a5220d7e28d1
+Subproject commit 30fba3dc16e0ea48595e461b46b8ae393430fddc
diff --git a/python-adapter/src/main/scala/org.apache.sedona.python.wrapper/utils/PythonAdapterWrapper.scala b/python-adapter/src/main/scala/org.apache.sedona.python.wrapper/utils/PythonAdapterWrapper.scala
index 9a7d512..f3a6b0b 100644
--- a/python-adapter/src/main/scala/org.apache.sedona.python.wrapper/utils/PythonAdapterWrapper.scala
+++ b/python-adapter/src/main/scala/org.apache.sedona.python.wrapper/utils/PythonAdapterWrapper.scala
@@ -24,17 +24,17 @@ import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.locationtech.jts.geom.Geometry
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
object PythonAdapterWrapper {
def toDf[T <: Geometry](spatialRDD: SpatialRDD[T], fieldNames: java.util.ArrayList[String], sparkSession: SparkSession): DataFrame = {
- Adapter.toDf(spatialRDD, fieldNames.toList, sparkSession)
+ Adapter.toDf(spatialRDD, fieldNames.asScala, sparkSession)
}
def toDf(spatialPairRDD: JavaPairRDD[Geometry, Geometry],
leftFieldnames: java.util.ArrayList[String],
rightFieldNames: java.util.ArrayList[String],
sparkSession: SparkSession): DataFrame = {
- Adapter.toDf(spatialPairRDD, leftFieldnames.toList, rightFieldNames.toList, sparkSession)
+ Adapter.toDf(spatialPairRDD, leftFieldnames.asScala, rightFieldNames.asScala, sparkSession)
}
}
diff --git a/python/sedona/core/spatialOperator/join_query.py b/python/sedona/core/spatialOperator/join_query.py
index 26ce888..d4483e7 100644
--- a/python/sedona/core/spatialOperator/join_query.py
+++ b/python/sedona/core/spatialOperator/join_query.py
@@ -18,10 +18,9 @@
from pyspark import RDD
from sedona.core.SpatialRDD.spatial_rdd import SpatialRDD
-from sedona.core.jvm.translate import JvmSedonaPythonConverter
from sedona.core.spatialOperator.join_params import JoinParams
+from sedona.core.spatialOperator.join_query_raw import JoinQueryRaw
from sedona.utils.decorators import require
-from sedona.utils.spatial_rdd_parser import SedonaPickler
class JoinQuery:
@@ -39,19 +38,8 @@ class JoinQuery:
:return:
"""
- jvm = spatialRDD._jvm
- sc = spatialRDD._sc
-
- srdd = jvm.JoinQuery.SpatialJoinQuery(
- spatialRDD._srdd,
- queryRDD._srdd,
- useIndex,
- considerBoundaryIntersection
- )
- serialized = JvmSedonaPythonConverter(jvm) \
- .translate_spatial_pair_rdd_with_list_to_python(srdd)
-
- return RDD(serialized, sc, SedonaPickler())
+ pair_rdd = JoinQueryRaw.SpatialJoinQuery(spatialRDD, queryRDD, useIndex, considerBoundaryIntersection)
+ return pair_rdd.to_rdd()
@classmethod
@require(["JoinQuery"])
@@ -66,18 +54,8 @@ class JoinQuery:
:return:
"""
- jvm = spatialRDD._jvm
- sc = spatialRDD._sc
- srdd = jvm.JoinQuery.DistanceJoinQuery(
- spatialRDD._srdd,
- queryRDD._srdd,
- useIndex,
- considerBoundaryIntersection
- )
- serialized = JvmSedonaPythonConverter(jvm). \
- translate_spatial_pair_rdd_with_list_to_python(srdd)
-
- return RDD(serialized, sc, SedonaPickler())
+ pair_rdd = JoinQueryRaw.DistanceJoinQuery(spatialRDD, queryRDD, useIndex, considerBoundaryIntersection)
+ return pair_rdd.to_rdd()
@classmethod
@require(["JoinQuery"])
@@ -90,16 +68,8 @@ class JoinQuery:
:return:
"""
- jvm = queryWindowRDD._jvm
- sc = queryWindowRDD._sc
-
- jvm_join_params = joinParams.jvm_instance(jvm)
-
- srdd = jvm.JoinQuery.spatialJoin(queryWindowRDD._srdd, objectRDD._srdd, jvm_join_params)
- serialized = JvmSedonaPythonConverter(jvm). \
- translate_spatial_pair_rdd_to_python(srdd)
-
- return RDD(serialized, sc, SedonaPickler())
+ pair_rdd = JoinQueryRaw.spatialJoin(queryWindowRDD, objectRDD, joinParams)
+ return pair_rdd.to_rdd()
@classmethod
@require(["JoinQuery"])
@@ -120,21 +90,9 @@ class JoinQuery:
:return:
"""
- jvm = spatialRDD._jvm
- sc = spatialRDD._sc
-
- spatial_join = jvm.JoinQuery.DistanceJoinQueryFlat
- srdd = spatial_join(
- spatialRDD._srdd,
- queryRDD._srdd,
- useIndex,
- considerBoundaryIntersection
- )
-
- serialized = JvmSedonaPythonConverter(jvm). \
- translate_spatial_pair_rdd_to_python(srdd)
-
- return RDD(serialized, sc, SedonaPickler())
+ pair_rdd = JoinQueryRaw.DistanceJoinQueryFlat(spatialRDD, queryRDD, useIndex,
+ considerBoundaryIntersection)
+ return pair_rdd.to_rdd()
@classmethod
@require(["JoinQuery"])
@@ -159,18 +117,6 @@ class JoinQuery:
[[GeoData(Polygon, ), GeoData()], [GeoData(), GeoData()], [GeoData(), GeoData()]]
"""
- jvm = spatialRDD._jvm
- sc = spatialRDD._sc
-
- spatial_join = jvm.JoinQuery.SpatialJoinQueryFlat
- srdd = spatial_join(
- spatialRDD._srdd,
- queryRDD._srdd,
- useIndex,
- considerBoundaryIntersection
- )
-
- serialized = JvmSedonaPythonConverter(jvm). \
- translate_spatial_pair_rdd_to_python(srdd)
-
- return RDD(serialized, sc, SedonaPickler())
+ pair_rdd = JoinQueryRaw.SpatialJoinQueryFlat(spatialRDD, queryRDD, useIndex,
+ considerBoundaryIntersection)
+ return pair_rdd.to_rdd()
diff --git a/python/sedona/core/spatialOperator/join_query_raw.py b/python/sedona/core/spatialOperator/join_query_raw.py
new file mode 100644
index 0000000..4e1b54e
--- /dev/null
+++ b/python/sedona/core/spatialOperator/join_query_raw.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from sedona.core.SpatialRDD import SpatialRDD
+from sedona.core.spatialOperator.join_params import JoinParams
+from sedona.core.spatialOperator.rdd import SedonaPairRDDList, SedonaPairRDD
+from sedona.utils.decorators import require
+
+
+class JoinQueryRaw:
+
+ @classmethod
+ @require(["JoinQuery"])
+ def SpatialJoinQuery(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, considerBoundaryIntersection: bool) -> SedonaPairRDDList:
+ jvm = spatialRDD._jvm
+ sc = spatialRDD._sc
+
+ srdd = jvm.JoinQuery.SpatialJoinQuery(
+ spatialRDD._srdd,
+ queryRDD._srdd,
+ useIndex,
+ considerBoundaryIntersection
+ )
+
+ return SedonaPairRDDList(srdd, sc)
+
+ @classmethod
+ @require(["JoinQuery"])
+ def DistanceJoinQuery(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, considerBoundaryIntersection: bool) -> SedonaPairRDDList:
+
+ jvm = spatialRDD._jvm
+ sc = spatialRDD._sc
+ srdd = jvm.JoinQuery.DistanceJoinQuery(
+ spatialRDD._srdd,
+ queryRDD._srdd,
+ useIndex,
+ considerBoundaryIntersection
+ )
+ return SedonaPairRDDList(srdd, sc)
+
+ @classmethod
+ @require(["JoinQuery"])
+ def spatialJoin(cls, queryWindowRDD: SpatialRDD, objectRDD: SpatialRDD, joinParams: JoinParams) -> SedonaPairRDD:
+
+ jvm = queryWindowRDD._jvm
+ sc = queryWindowRDD._sc
+
+ jvm_join_params = joinParams.jvm_instance(jvm)
+
+ srdd = jvm.JoinQuery.spatialJoin(queryWindowRDD._srdd, objectRDD._srdd, jvm_join_params)
+
+ return SedonaPairRDD(srdd, sc)
+
+ @classmethod
+ @require(["JoinQuery"])
+ def DistanceJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool, considerBoundaryIntersection: bool) -> SedonaPairRDD:
+
+ jvm = spatialRDD._jvm
+ sc = spatialRDD._sc
+
+ spatial_join = jvm.JoinQuery.DistanceJoinQueryFlat
+ srdd = spatial_join(
+ spatialRDD._srdd,
+ queryRDD._srdd,
+ useIndex,
+ considerBoundaryIntersection
+ )
+ return SedonaPairRDD(srdd, sc)
+
+ @classmethod
+ @require(["JoinQuery"])
+ def SpatialJoinQueryFlat(cls, spatialRDD: SpatialRDD, queryRDD: SpatialRDD, useIndex: bool,
+ considerBoundaryIntersection: bool) -> SedonaPairRDD:
+
+ jvm = spatialRDD._jvm
+ sc = spatialRDD._sc
+
+ spatial_join = jvm.JoinQuery.SpatialJoinQueryFlat
+ srdd = spatial_join(
+ spatialRDD._srdd,
+ queryRDD._srdd,
+ useIndex,
+ considerBoundaryIntersection
+ )
+ return SedonaPairRDD(srdd, sc)
diff --git a/python/sedona/core/spatialOperator/range_query.py b/python/sedona/core/spatialOperator/range_query.py
index c8ad140..df44ee3 100644
--- a/python/sedona/core/spatialOperator/range_query.py
+++ b/python/sedona/core/spatialOperator/range_query.py
@@ -15,14 +15,11 @@
# specific language governing permissions and limitations
# under the License.
-from pyspark import RDD
from shapely.geometry.base import BaseGeometry
from sedona.core.SpatialRDD.spatial_rdd import SpatialRDD
-from sedona.core.jvm.translate import JvmSedonaPythonConverter
+from sedona.core.spatialOperator.range_query_raw import RangeQueryRaw
from sedona.utils.decorators import require
-from sedona.utils.geometry_adapter import GeometryAdapter
-from sedona.utils.spatial_rdd_parser import SedonaPickler
class RangeQuery:
@@ -39,20 +36,5 @@ class RangeQuery:
:param usingIndex:
:return:
"""
-
- jvm = spatialRDD._jvm
- sc = spatialRDD._sc
-
- jvm_geom = GeometryAdapter.create_jvm_geometry_from_base_geometry(jvm, rangeQueryWindow)
-
- srdd = jvm. \
- RangeQuery.SpatialRangeQuery(
- spatialRDD._srdd,
- jvm_geom,
- considerBoundaryIntersection,
- usingIndex
- )
-
- serialized = JvmSedonaPythonConverter(jvm).translate_spatial_rdd_to_python(srdd)
-
- return RDD(serialized, sc, SedonaPickler())
+ j_srdd = RangeQueryRaw.SpatialRangeQuery(spatialRDD, rangeQueryWindow, considerBoundaryIntersection, usingIndex)
+ return j_srdd.to_rdd()
diff --git a/python/sedona/core/spatialOperator/range_query.py b/python/sedona/core/spatialOperator/range_query_raw.py
similarity index 84%
copy from python/sedona/core/spatialOperator/range_query.py
copy to python/sedona/core/spatialOperator/range_query_raw.py
index c8ad140..dced4cd 100644
--- a/python/sedona/core/spatialOperator/range_query.py
+++ b/python/sedona/core/spatialOperator/range_query_raw.py
@@ -15,22 +15,20 @@
# specific language governing permissions and limitations
# under the License.
-from pyspark import RDD
from shapely.geometry.base import BaseGeometry
from sedona.core.SpatialRDD.spatial_rdd import SpatialRDD
-from sedona.core.jvm.translate import JvmSedonaPythonConverter
+from sedona.core.spatialOperator.rdd import SedonaPairRDD, SedonaRDD
from sedona.utils.decorators import require
from sedona.utils.geometry_adapter import GeometryAdapter
-from sedona.utils.spatial_rdd_parser import SedonaPickler
-class RangeQuery:
+class RangeQueryRaw:
@classmethod
@require(["RangeQuery", "GeometryAdapter", "GeoSerializerData"])
def SpatialRangeQuery(self, spatialRDD: SpatialRDD, rangeQueryWindow: BaseGeometry,
- considerBoundaryIntersection: bool, usingIndex: bool):
+ considerBoundaryIntersection: bool, usingIndex: bool) -> SedonaRDD:
"""
:param spatialRDD:
@@ -53,6 +51,4 @@ class RangeQuery:
usingIndex
)
- serialized = JvmSedonaPythonConverter(jvm).translate_spatial_rdd_to_python(srdd)
-
- return RDD(serialized, sc, SedonaPickler())
+ return SedonaRDD(srdd, sc)
diff --git a/python/sedona/core/spatialOperator/rdd.py b/python/sedona/core/spatialOperator/rdd.py
new file mode 100644
index 0000000..cf7ba83
--- /dev/null
+++ b/python/sedona/core/spatialOperator/rdd.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, Optional
+
+from pyspark import SparkContext, RDD
+from pyspark.sql import SparkSession, DataFrame
+
+from sedona.core.SpatialRDD import SpatialRDD
+from sedona.core.jvm.translate import JvmSedonaPythonConverter
+from sedona.utils.spatial_rdd_parser import SedonaPickler
+
+
+class SedonaRDD:
+
+ def __init__(self, jsrdd, sc: SparkContext):
+ self.jsrdd = jsrdd
+ self.sc = sc
+
+ def to_df(self, spark: SparkSession, field_names: List[str] = None) -> DataFrame:
+ from sedona.utils.adapter import Adapter
+ srdd = SpatialRDD(self.sc)
+ srdd.setRawSpatialRDD(self.jsrdd)
+ if field_names:
+ return Adapter.toDf(srdd, field_names, spark)
+ else:
+ return Adapter.toDf(srdd, spark)
+
+ def to_rdd(self) -> RDD:
+ jvm = self.sc._jvm
+ serialized = JvmSedonaPythonConverter(jvm). \
+ translate_spatial_rdd_to_python(self.jsrdd)
+
+ return RDD(serialized, self.sc, SedonaPickler())
+
+
+class SedonaPairRDD:
+
+ def __init__(self, jsrdd, sc: SparkContext):
+ self.jsrdd = jsrdd
+ self.sc = sc
+
+ def to_df(self, spark: SparkSession,
+ left_field_names: Optional[List] = None,
+ right_field_names: Optional[List] = None) -> DataFrame:
+ from sedona.utils.adapter import Adapter
+ if left_field_names is not None and right_field_names is not None:
+ df = Adapter.toDf(self, left_field_names, right_field_names, spark)
+ return df
+
+ elif left_field_names is None and right_field_names is None:
+ df = Adapter.toDf(self, spark)
+ return df
+ else:
+ raise AttributeError("when passing left_field_names you have also pass right_field_names and reverse")
+
+ def to_rdd(self) -> RDD:
+ jvm = self.sc._jvm
+ serialized = JvmSedonaPythonConverter(jvm). \
+ translate_spatial_pair_rdd_to_python(self.jsrdd)
+
+ return RDD(serialized, self.sc, SedonaPickler())
+
+
+class SedonaPairRDDList:
+
+ def __init__(self, jsrdd, sc: SparkContext):
+ self.jsrdd = jsrdd
+ self.sc = sc
+
+ def to_rdd(self):
+ jvm = self.sc._jvm
+ serialized = JvmSedonaPythonConverter(jvm). \
+ translate_spatial_pair_rdd_with_list_to_python(self.jsrdd)
+
+ return RDD(serialized, self.sc, SedonaPickler())
diff --git a/python/sedona/utils/adapter.py b/python/sedona/utils/adapter.py
index b32e091..260e036 100644
--- a/python/sedona/utils/adapter.py
+++ b/python/sedona/utils/adapter.py
@@ -22,6 +22,7 @@ from pyspark.sql import DataFrame, SparkSession
from sedona.core.SpatialRDD.spatial_rdd import SpatialRDD
from sedona.core.enums.spatial import SpatialType
+from sedona.core.spatialOperator.rdd import SedonaPairRDD
from sedona.utils.meta import MultipleMeta
@@ -59,23 +60,6 @@ class Adapter(metaclass=MultipleMeta):
return spatial_rdd
@classmethod
- def toSpatialRdd(cls, dataFrame: DataFrame):
- """
-
- :param dataFrame:
- :return:
- """
- sc = dataFrame._sc
- jvm = sc._jvm
-
- srdd = jvm.Adapter.toSpatialRdd(dataFrame._jdf)
-
- spatial_rdd = SpatialRDD(sc)
- spatial_rdd.set_srdd(srdd)
-
- return spatial_rdd
-
- @classmethod
def toSpatialRdd(cls, dataFrame: DataFrame, fieldNames: List) -> SpatialRDD:
"""
@@ -161,3 +145,18 @@ class Adapter(metaclass=MultipleMeta):
return df.toDF(*combined_columns)
else:
raise TypeError("Column length does not match")
+
+ @classmethod
+ def toDf(cls, rawPairRDD: SedonaPairRDD, sparkSession: SparkSession):
+ jvm = sparkSession._jvm
+ jdf = jvm.Adapter.toDf(rawPairRDD.jsrdd, sparkSession._jsparkSession)
+ df = DataFrame(jdf, sparkSession._wrapped)
+ return df
+
+ @classmethod
+ def toDf(cls, rawPairRDD: SedonaPairRDD, leftFieldnames: List, rightFieldNames: List, sparkSession: SparkSession):
+ jvm = sparkSession._jvm
+ jdf = jvm.PythonAdapterWrapper.toDf(
+ rawPairRDD.jsrdd, leftFieldnames, rightFieldNames, sparkSession._jsparkSession)
+ df = DataFrame(jdf, sparkSession._wrapped)
+ return df
diff --git a/python/tests/core/test_avoiding_python_jvm_serde_df.py b/python/tests/core/test_avoiding_python_jvm_serde_df.py
new file mode 100644
index 0000000..3b78bce
--- /dev/null
+++ b/python/tests/core/test_avoiding_python_jvm_serde_df.py
@@ -0,0 +1,180 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from pyspark.sql.types import StructField, StructType
+
+from sedona.core.SpatialRDD import CircleRDD
+from sedona.core.enums import GridType, IndexType
+from sedona.core.formatMapper import WktReader
+from sedona.core.spatialOperator.join_params import JoinParams
+from sedona.core.spatialOperator.join_query_raw import JoinQueryRaw
+from sedona.core.spatialOperator.range_query_raw import RangeQueryRaw
+from sedona.sql.types import GeometryType
+from tests.test_base import TestBase
+
+import os
+
+from tests.tools import tests_path
+from shapely.wkt import loads
+
+bank_csv_path = os.path.join(tests_path, "resources/small/points.csv")
+areas_csv_path = os.path.join(tests_path, "resources/small/areas.csv")
+
+
+class TestOmitPythonJvmSerdeToDf(TestBase):
+ expected_pois_within_areas_ids = [['4', '4'], ['1', '6'], ['2', '1'], ['3', '3'], ['3', '7']]
+
+ def test_spatial_join_to_df(self):
+ poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+ areas_polygon_rdd = WktReader.readToGeometryRDD(self.sc, areas_csv_path, 1, False, False)
+ poi_point_rdd.analyze()
+ areas_polygon_rdd.analyze()
+
+ poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
+ areas_polygon_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
+
+ jvm_sedona_rdd = JoinQueryRaw.spatialJoin(poi_point_rdd, areas_polygon_rdd, JoinParams())
+ sedona_df = jvm_sedona_rdd.to_df(spark=self.spark,
+ left_field_names=["area_id", "area_name"],
+ right_field_names=["poi_id", "poi_name"])
+
+ assert sedona_df.count() == 5
+ assert sedona_df.columns == ["leftgeometry", "area_id", "area_name", "rightgeometry",
+ "poi_id", "poi_name"]
+
+ def test_distance_join_query_flat_to_df(self):
+ poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+ circle_rdd = CircleRDD(poi_point_rdd, 2.0)
+
+ circle_rdd.analyze()
+ poi_point_rdd.analyze()
+
+ poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
+ circle_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
+
+ jvm_sedona_rdd = JoinQueryRaw.DistanceJoinQueryFlat(poi_point_rdd, circle_rdd, False, True)
+
+ df_sedona_rdd = jvm_sedona_rdd.to_df(
+ self.spark,
+ left_field_names=["poi_from_id", "poi_from_name"],
+ right_field_names=["poi_to_id", "poi_to_name"]
+ )
+
+ assert df_sedona_rdd.count() == 10
+ assert df_sedona_rdd.columns == [
+ "leftgeometry",
+ "poi_from_id",
+ "poi_from_name",
+ "rightgeometry",
+ "poi_to_id",
+ "poi_to_name"
+ ]
+
+ def test_spatial_join_query_flat_to_df(self):
+ poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+ areas_polygon_rdd = WktReader.readToGeometryRDD(self.sc, areas_csv_path, 1, False, False)
+ poi_point_rdd.analyze()
+ areas_polygon_rdd.analyze()
+
+ poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
+ areas_polygon_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
+
+ jvm_sedona_rdd = JoinQueryRaw.SpatialJoinQueryFlat(
+ poi_point_rdd, areas_polygon_rdd, False, True)
+
+ pois_within_areas_with_default_column_names = jvm_sedona_rdd.to_df(self.spark)
+
+ assert pois_within_areas_with_default_column_names.count() == 5
+
+ pois_within_areas_with_passed_column_names = jvm_sedona_rdd.to_df(
+ spark=self.spark,
+ left_field_names=["area_id", "area_name"],
+ right_field_names=["poi_id", "poi_name"]
+ )
+
+ assert pois_within_areas_with_passed_column_names.count() == 5
+
+ assert pois_within_areas_with_passed_column_names.columns == ["leftgeometry", "area_id", "area_name",
+ "rightgeometry",
+ "poi_id", "poi_name"]
+
+ assert pois_within_areas_with_default_column_names.schema == StructType(
+ [
+ StructField("leftgeometry", GeometryType()),
+ StructField("rightgeometry", GeometryType()),
+ ]
+ )
+
+ left_geometries_raw = pois_within_areas_with_default_column_names. \
+ selectExpr("ST_AsText(leftgeometry)"). \
+ collect()
+
+ left_geometries = self.__row_to_list(left_geometries_raw)
+
+ right_geometries_raw = pois_within_areas_with_default_column_names. \
+ selectExpr("ST_AsText(rightgeometry)"). \
+ collect()
+
+ right_geometries = self.__row_to_list(right_geometries_raw)
+
+ assert left_geometries == [
+ ['POLYGON ((0 4, -3 3, -8 6, -6 8, -2 9, 0 4))'],
+ ['POLYGON ((2 2, 2 4, 3 5, 7 5, 9 3, 8 1, 4 1, 2 2))'],
+ ['POLYGON ((10 3, 10 6, 14 6, 14 3, 10 3))'],
+ ['POLYGON ((-1 -1, -1 -3, -2 -5, -6 -8, -5 -2, -3 -2, -1 -1))'],
+ ['POLYGON ((-1 -1, -1 -3, -2 -5, -6 -8, -5 -2, -3 -2, -1 -1))']
+ ]
+ assert right_geometries == [['POINT (-3 5)'],
+ ['POINT (4 3)'],
+ ['POINT (11 5)'],
+ ['POINT (-1 -1)'],
+ ['POINT (-4 -5)']]
+
+ def test_range_query_flat_to_df(self):
+ poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+
+ poi_point_rdd.analyze()
+
+ poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
+ poi_point_rdd.buildIndex(IndexType.QUADTREE, False)
+
+ result = RangeQueryRaw.SpatialRangeQuery(
+ poi_point_rdd, loads("POLYGON((0 0, 0 20, 20 20, 20 0, 0 0))"), True, True
+ )
+
+ rdd = result.to_rdd()
+
+ assert rdd.collect().__len__() == 4
+
+ df_without_column_names = result.to_df(self.spark)
+
+ raw_geometries = self.__row_to_list(
+ df_without_column_names.collect()
+ )
+
+ assert [point[0].wkt for point in raw_geometries] == [
+ 'POINT (9 8)', 'POINT (4 3)', 'POINT (12 1)', 'POINT (11 5)'
+ ]
+ assert df_without_column_names.count() == 4
+ assert df_without_column_names.schema == StructType([StructField("geometry", GeometryType())])
+
+ df = result.to_df(self.spark, field_names=["poi_id", "poi_name"])
+
+ assert df.count() == 4
+ assert df.columns == ["geometry", "poi_id", "poi_name"]
+
+ def __row_to_list(self, row_list):
+ return [[*element] for element in row_list]
diff --git a/python/tests/core/test_avoiding_python_jvm_serde_to_rdd.py b/python/tests/core/test_avoiding_python_jvm_serde_to_rdd.py
new file mode 100644
index 0000000..48f6012
--- /dev/null
+++ b/python/tests/core/test_avoiding_python_jvm_serde_to_rdd.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from sedona.core.SpatialRDD import CircleRDD
+from sedona.core.enums import GridType, IndexType
+from sedona.core.formatMapper import WktReader
+from sedona.core.spatialOperator.join_params import JoinParams
+from sedona.core.spatialOperator.join_query_raw import JoinQueryRaw
+from sedona.core.spatialOperator.range_query_raw import RangeQueryRaw
+from tests.test_base import TestBase
+
+import os
+
+from tests.tools import tests_path
+from shapely.wkt import loads
+
+bank_csv_path = os.path.join(tests_path, "resources/small/points.csv")
+areas_csv_path = os.path.join(tests_path, "resources/small/areas.csv")
+
+
+class TestOmitPythonJvmSerdeToRDD(TestBase):
+ expected_pois_within_areas_ids = [['4', '4'], ['1', '6'], ['2', '1'], ['3', '3'], ['3', '7']]
+
+ def test_spatial_join_to_spatial_rdd(self):
+ poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+ areas_polygon_rdd = WktReader.readToGeometryRDD(self.sc, areas_csv_path, 1, False, False)
+ poi_point_rdd.analyze()
+ areas_polygon_rdd.analyze()
+
+ poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
+ areas_polygon_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
+
+ jvm_sedona_rdd = JoinQueryRaw.spatialJoin(poi_point_rdd, areas_polygon_rdd, JoinParams())
+ sedona_rdd = jvm_sedona_rdd.to_rdd().collect()
+ assert sedona_rdd.__len__() == 5
+
+ def test_distance_join_query_flat_to_df(self):
+ poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+ circle_rdd = CircleRDD(poi_point_rdd, 2.0)
+
+ circle_rdd.analyze()
+ poi_point_rdd.analyze()
+
+ poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
+ circle_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
+
+ jvm_sedona_rdd = JoinQueryRaw.DistanceJoinQueryFlat(poi_point_rdd, circle_rdd, False, True)
+
+ assert jvm_sedona_rdd.to_rdd().collect().__len__() == 10
+
+ def test_spatial_join_query_flat_to_df(self):
+ poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+ areas_polygon_rdd = WktReader.readToGeometryRDD(self.sc, areas_csv_path, 1, False, False)
+ poi_point_rdd.analyze()
+ areas_polygon_rdd.analyze()
+
+ poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
+ areas_polygon_rdd.spatialPartitioning(poi_point_rdd.getPartitioner())
+
+ jvm_sedona_rdd = JoinQueryRaw.SpatialJoinQueryFlat(
+ poi_point_rdd, areas_polygon_rdd, False, True)
+
+ assert jvm_sedona_rdd.to_rdd().collect().__len__() == 5
+
+ def test_range_query_flat_to_df(self):
+ poi_point_rdd = WktReader.readToGeometryRDD(self.sc, bank_csv_path, 1, False, False)
+
+ poi_point_rdd.analyze()
+
+ poi_point_rdd.spatialPartitioning(GridType.QUADTREE)
+ poi_point_rdd.buildIndex(IndexType.QUADTREE, False)
+
+ result = RangeQueryRaw.SpatialRangeQuery(
+ poi_point_rdd, loads("POLYGON((0 0, 0 20, 20 20, 20 0, 0 0))"), True, True
+ )
+
+ rdd = result.to_rdd()
+
+ assert rdd.collect().__len__() == 4
diff --git a/python/tests/core/test_core_spatial_relations.py b/python/tests/core/test_core_spatial_relations.py
deleted file mode 100644
index ebe978e..0000000
--- a/python/tests/core/test_core_spatial_relations.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import os
-
-from sedona.core.SpatialRDD import PointRDD, PolygonRDD
-from sedona.core.enums import FileDataSplitter, GridType
-from sedona.core.spatialOperator import JoinQuery
-from tests.test_base import TestBase
-from tests.tools import tests_path
-
-point_path = os.path.join(tests_path, "resources/points.csv")
-counties_path = os.path.join(tests_path, "resources/counties_tsv.csv")
-
-
-class TestJoinQuery(TestBase):
-
- def test_spatial_join_query(self):
- point_rdd = PointRDD(
- self.sc,
- point_path,
- 4,
- FileDataSplitter.WKT,
- True
- )
-
- polygon_rdd = PolygonRDD(
- self.sc,
- counties_path,
- 2,
- 3,
- FileDataSplitter.WKT,
- True
- )
-
- point_rdd.analyze()
- point_rdd.spatialPartitioning(GridType.KDBTREE, num_partitions=10)
- polygon_rdd.spatialPartitioning(point_rdd.getPartitioner())
- result = JoinQuery.SpatialJoinQuery(
- point_rdd,
- polygon_rdd,
- True,
- False
- )
-
- assert result.count() == 26
diff --git a/python/tests/resources/small/areas.csv b/python/tests/resources/small/areas.csv
new file mode 100644
index 0000000..0977b3d
--- /dev/null
+++ b/python/tests/resources/small/areas.csv
@@ -0,0 +1,4 @@
+1 POLYGON((2 2, 2 4, 3 5, 7 5, 9 3, 8 1, 4 1, 2 2)) area1
+2 POLYGON((10 3, 10 6, 14 6, 14 3, 10 3)) area2
+3 POLYGON((-1 -1, -1 -3, -2 -5, -6 -8, -5 -2, -3 -2, -1 -1)) area3
+4 POLYGON((0 4, -3 3, -8 6, -6 8, -2 9, 0 4)) area4
\ No newline at end of file
diff --git a/python/tests/resources/small/points.csv b/python/tests/resources/small/points.csv
new file mode 100644
index 0000000..7ac95ec
--- /dev/null
+++ b/python/tests/resources/small/points.csv
@@ -0,0 +1,10 @@
+1 POINT(11 5) bank1
+2 POINT(12 1) bank2
+3 POINT(-1 -1) bank3
+4 POINT(-3 5) bank4
+5 POINT(9 8) bank5
+6 POINT(4 3) bank6
+7 POINT(-4 -5) bank7
+8 POINT(4 -2) bank8
+9 POINT(-3 1) bank9
+10 POINT(-7 3) bank10
\ No newline at end of file
diff --git a/python/tests/sql/test_adapter.py b/python/tests/sql/test_adapter.py
index 0f06117..dc1aab6 100644
--- a/python/tests/sql/test_adapter.py
+++ b/python/tests/sql/test_adapter.py
@@ -99,7 +99,7 @@ class TestAdapter(TestBase):
spatial_df = self.spark.sql("select ST_GeomFromWKT(inputtable._c0) as usacounty from inputtable")
spatial_df.show()
spatial_df.printSchema()
- spatial_rdd = Adapter.toSpatialRdd(spatial_df)
+ spatial_rdd = Adapter.toSpatialRdd(spatial_df, "usacounty")
spatial_rdd.analyze()
Adapter.toDf(spatial_rdd, self.spark).show()
assert (Adapter.toDf(spatial_rdd, self.spark).columns.__len__() == 1)
@@ -148,10 +148,9 @@ class TestAdapter(TestBase):
)
spatial_rdd.analyze()
+ Adapter.toDf(spatial_rdd, self.spark).show()
+ df = Adapter.toDf(spatial_rdd, self.spark)
- df = Adapter.toDf(spatial_rdd, self.spark).\
- withColumn("geometry", expr("ST_GeomFromWKT(geometry)"))
- df.show()
assert (df.columns[1] == "STATEFP")
def test_convert_spatial_join_result_to_dataframe(self):
@@ -268,7 +267,7 @@ class TestAdapter(TestBase):
def test_to_spatial_rdd_df(self):
spatial_df = self._create_spatial_point_table()
- spatial_rdd = Adapter.toSpatialRdd(spatial_df)
+ spatial_rdd = Adapter.toSpatialRdd(spatial_df, "geometry")
spatial_rdd.analyze()
diff --git a/sql/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala b/sql/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala
index cfaf077..8611e15 100644
--- a/sql/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala
+++ b/sql/src/main/scala/org/apache/sedona/sql/utils/Adapter.scala
@@ -19,21 +19,15 @@
package org.apache.sedona.sql.utils
import org.apache.sedona.core.spatialRDD.SpatialRDD
-import org.apache.sedona.core.utils.GeomUtils
-import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.rdd.RDD
-import org.locationtech.jts.geom.Geometry
-//import org.apache.spark.sql.sedona_sql.GeometryWrapper
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.locationtech.jts.geom.Geometry
object Adapter {
- @deprecated("use toSpatialRdd and append geometry column's name", "1.2")
- def toSpatialRdd(dataFrame: DataFrame): SpatialRDD[Geometry] = {
- toSpatialRdd(dataFrame, "geometry")
- }
-
/**
* Convert a Spatial DF to a Spatial RDD. The geometry column can be at any place in the DF
*
@@ -47,7 +41,7 @@ object Adapter {
toSpatialRdd(dataFrame, 0, List[String]())
}
else {
- val fieldList = dataFrame.schema.toList.map(f => f.name.toString)
+ val fieldList = dataFrame.schema.toList.map(f => f.name)
toSpatialRdd(dataFrame, geometryFieldName, fieldList.filter(p => !p.equalsIgnoreCase(geometryFieldName)))
}
}
@@ -60,17 +54,17 @@ object Adapter {
* @param fieldNames
* @return
*/
- def toSpatialRdd(dataFrame: DataFrame, geometryFieldName: String, fieldNames: List[String]): SpatialRDD[Geometry] = {
+ def toSpatialRdd(dataFrame: DataFrame, geometryFieldName: String, fieldNames: Seq[String]): SpatialRDD[Geometry] = {
var spatialRDD = new SpatialRDD[Geometry]
spatialRDD.rawSpatialRDD = toRdd(dataFrame, geometryFieldName).toJavaRDD()
import scala.collection.JavaConversions._
- if (fieldNames.nonEmpty) spatialRDD.fieldNames = fieldNames
+ if (fieldNames != null && fieldNames.nonEmpty) spatialRDD.fieldNames = fieldNames
else spatialRDD.fieldNames = null
spatialRDD
}
private def toRdd(dataFrame: DataFrame, geometryFieldName: String): RDD[Geometry] = {
- val fieldList = dataFrame.schema.toList.map(f => f.name.toString)
+ val fieldList = dataFrame.schema.toList.map(f => f.name)
val geomColId = fieldList.indexOf(geometryFieldName)
assert(geomColId >= 0)
toRdd(dataFrame, geomColId)
@@ -84,7 +78,7 @@ object Adapter {
* @param fieldNames
* @return
*/
- def toSpatialRdd(dataFrame: DataFrame, geometryColId: Int, fieldNames: List[String]): SpatialRDD[Geometry] = {
+ def toSpatialRdd(dataFrame: DataFrame, geometryColId: Int, fieldNames: Seq[String]): SpatialRDD[Geometry] = {
var spatialRDD = new SpatialRDD[Geometry]
spatialRDD.rawSpatialRDD = toRdd(dataFrame, geometryColId).toJavaRDD()
import scala.collection.JavaConversions._
@@ -106,7 +100,7 @@ object Adapter {
toSpatialRdd(dataFrame, 0, List[String]())
}
else {
- val fieldList = dataFrame.schema.toList.map(f => f.name.toString)
+ val fieldList = dataFrame.schema.toList.map(f => f.name)
val geometryFieldName = fieldList(geometryColId)
toSpatialRdd(dataFrame, geometryColId, fieldList.filter(p => !p.equalsIgnoreCase(geometryFieldName)))
}
@@ -118,64 +112,44 @@ object Adapter {
toDf(spatialRDD, null, sparkSession);
}
- def toDf[T <: Geometry](spatialRDD: SpatialRDD[T], fieldNames: List[String], sparkSession: SparkSession): DataFrame = {
- val rowRdd = spatialRDD.rawSpatialRDD.rdd.map[Row](f => Row.fromSeq(GeomUtils.printGeom(f).split("\t", -1).toSeq))
+ def toDf[T <: Geometry](spatialRDD: SpatialRDD[T], fieldNames: Seq[String], sparkSession: SparkSession): DataFrame = {
+ val rowRdd = spatialRDD.rawSpatialRDD.rdd.map[Row](f => {
+ var userData = f.getUserData
+ f.setUserData(null)
+ if (userData != null) Row.fromSeq(f +: userData.asInstanceOf[String].split("\t", -1))
+ else Row.fromSeq(Seq(f))
+ })
+ var cols:Seq[StructField] = Seq(StructField("geometry", GeometryUDT))
if (fieldNames != null && fieldNames.nonEmpty) {
- var fieldArray = new Array[StructField](fieldNames.size + 1)
- fieldArray(0) = StructField("geometry", StringType)
- for (i <- 1 until fieldArray.length) fieldArray(i) = StructField(fieldNames(i - 1), StringType)
- val schema = StructType(fieldArray)
- sparkSession.createDataFrame(rowRdd, schema)
- }
- else {
- var fieldArray = new Array[StructField](rowRdd.take(1)(0).size)
- fieldArray(0) = StructField("geometry", StringType)
- for (i <- 1 until fieldArray.length) fieldArray(i) = StructField("_c" + i, StringType)
- val schema = StructType(fieldArray)
- sparkSession.createDataFrame(rowRdd, schema)
+ cols = cols ++ fieldNames.map(f => StructField(f, StringType))
}
+ val schema = StructType(cols)
+ sparkSession.createDataFrame(rowRdd, schema)
}
def toDf(spatialPairRDD: JavaPairRDD[Geometry, Geometry], sparkSession: SparkSession): DataFrame = {
- val rowRdd = spatialPairRDD.rdd.map[Row](f => {
- val seq1 = GeomUtils.printGeom(f._1).split("\t").toSeq
- val seq2 = GeomUtils.printGeom(f._2).split("\t").toSeq
- val result = seq1 ++ seq2
- Row.fromSeq(result)
- })
- val leftgeomlength = spatialPairRDD.rdd.take(1)(0)._1.toString.split("\t").length
-
- var fieldArray = new Array[StructField](rowRdd.take(1)(0).size)
- for (i <- fieldArray.indices) fieldArray(i) = StructField("_c" + i, StringType)
- fieldArray(0) = StructField("leftgeometry", StringType)
- fieldArray(leftgeomlength) = StructField("rightgeometry", StringType)
- val schema = StructType(fieldArray)
- sparkSession.createDataFrame(rowRdd, schema)
+ toDf(spatialPairRDD, null, null, sparkSession)
}
- def toDf(spatialPairRDD: JavaPairRDD[Geometry, Geometry], leftFieldnames: List[String], rightFieldNames: List[String], sparkSession: SparkSession): DataFrame = {
- val rowRdd = spatialPairRDD.rdd.map[Row](f => {
- val seq1 = GeomUtils.printGeom(f._1).split("\t").toSeq
- val seq2 = GeomUtils.printGeom(f._2).split("\t").toSeq
- val result = seq1 ++ seq2
- Row.fromSeq(result)
+ def toDf(spatialPairRDD: JavaPairRDD[Geometry, Geometry], leftFieldnames: Seq[String], rightFieldNames: Seq[String], sparkSession: SparkSession): DataFrame = {
+ val rowRdd = spatialPairRDD.rdd.map(f => {
+ val left = getGeomAndFields(f._1, leftFieldnames)
+ val right = getGeomAndFields(f._2, rightFieldNames)
+ Row.fromSeq(left._1 ++ left._2 ++ right._1 ++ right._2)
})
- val leftgeometryName = List("leftgeometry")
- val rightgeometryName = List("rightgeometry")
- val fullFieldNames = leftgeometryName ++ leftFieldnames ++ rightgeometryName ++ rightFieldNames
- val schema = StructType(fullFieldNames.map(fieldName => StructField(fieldName, StringType)))
+ var cols:Seq[StructField] = Seq(StructField("leftgeometry", GeometryUDT))
+ if (leftFieldnames != null && leftFieldnames.nonEmpty) cols = cols ++ leftFieldnames.map(fName => StructField(fName, StringType))
+ cols = cols ++ Seq(StructField("rightgeometry", GeometryUDT))
+ if (rightFieldNames != null && rightFieldNames.nonEmpty) cols = cols ++ rightFieldNames.map(fName => StructField(fName, StringType))
+ val schema = StructType(cols)
sparkSession.createDataFrame(rowRdd, schema)
}
- private def toJavaRdd(dataFrame: DataFrame, geometryColId: Int): JavaRDD[Geometry] = {
- toRdd(dataFrame, geometryColId).toJavaRDD()
- }
-
private def toRdd(dataFrame: DataFrame, geometryColId: Int): RDD[Geometry] = {
dataFrame.rdd.map[Geometry](f => {
var geometry = f.get(geometryColId).asInstanceOf[Geometry]
var fieldSize = f.size
- var userData:String = null
+ var userData: String = null
if (fieldSize > 1) {
userData = ""
// Add all attributes into geometry user data
@@ -188,46 +162,13 @@ object Adapter {
})
}
- private def toJavaRdd(dataFrame: DataFrame): JavaRDD[Geometry] = {
- toRdd(dataFrame, 0).toJavaRDD()
- }
-
- private def toRdd(dataFrame: DataFrame): RDD[Geometry] = {
- dataFrame.rdd.map[Geometry](f => {
- var geometry = f.get(0).asInstanceOf[Geometry]
- var fieldSize = f.size
- var userData:String = null
- if (fieldSize > 1) {
- userData = ""
- // Add all attributes into geometry user data
- for (i <- 1 until f.size) userData += f.get(i) + "\t"
- userData = userData.dropRight(1)
- }
- geometry.setUserData(userData)
- geometry
- })
- }
-
- /*
- * Since UserDefinedType is hidden from users. We cannot directly return spatialRDD to spatialDf.
- * Let's wait for Spark side's change
- */
- /*
- def toSpatialDf(spatialRDD: SpatialRDD[Geometry], sparkSession: SparkSession): DataFrame =
- {
- val rowRdd = spatialRDD.rawSpatialRDD.rdd.map[Row](f =>
- {
- var seq = Seq(new GeometryWrapper(f))
- var otherFields = f.getUserData.asInstanceOf[String].split("\t").toSeq
- seq :+ otherFields
- Row.fromSeq(seq)
- }
- )
- var fieldArray = new Array[StructField](rowRdd.take(1)(0).size)
- fieldArray(0) = StructField("rddshape", ArrayType(ByteType, false))
- for (i <- 1 to fieldArray.length-1) fieldArray(i) = StructField("_c"+i, StringType)
- val schema = StructType(fieldArray)
- return sparkSession.createDataFrame(rowRdd, schema)
+ private def getGeomAndFields(geom: Geometry, fieldNames: Seq[String]): (Seq[Geometry], Seq[String]) = {
+ if (fieldNames != null && fieldNames.nonEmpty) {
+ val userData = "" + geom.getUserData.asInstanceOf[String]
+ val fields = userData.split("\t")
+// geom.setUserData(null) // Set to null will lead to the null pointer exception of the previous line. Not sure why.
+ (Seq(geom), fields)
+ }
+ else (Seq(geom), Seq())
}
- */
-}
\ No newline at end of file
+}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala
index 9dd5c2c..b7c79d6 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/UDT/GeometryUDT.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.types._
import org.locationtech.jts.geom.Geometry
-private[sql] class GeometryUDT extends UserDefinedType[Geometry] {
+class GeometryUDT extends UserDefinedType[Geometry] {
override def sqlType: DataType = ArrayType(ByteType, containsNull = false)
override def pyUDT: String = "sedona.sql.types.GeometryType"
@@ -40,7 +40,7 @@ private[sql] class GeometryUDT extends UserDefinedType[Geometry] {
GeometrySerializer.deserialize(values)
}
}
+}
- case object GeometryUDT extends GeometryUDT
+case object GeometryUDT extends org.apache.spark.sql.sedona_sql.UDT.GeometryUDT with scala.Serializable
-}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
index c94af99..ae05583 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala
@@ -51,7 +51,7 @@ case class ST_PointFromText(inputExpressions: Seq[Expression])
return new GenericArrayData(GeometrySerializer.serialize(geometry))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -77,7 +77,7 @@ case class ST_PolygonFromText(inputExpressions: Seq[Expression])
return new GenericArrayData(GeometrySerializer.serialize(geometry))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -105,7 +105,7 @@ case class ST_LineStringFromText(inputExpressions: Seq[Expression])
return new GenericArrayData(GeometrySerializer.serialize(geometry))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -130,7 +130,7 @@ case class ST_GeomFromWKT(inputExpressions: Seq[Expression])
return new GenericArrayData(GeometrySerializer.serialize(geometry))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -155,7 +155,7 @@ case class ST_GeomFromText(inputExpressions: Seq[Expression])
return new GenericArrayData(GeometrySerializer.serialize(geometry))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -180,7 +180,7 @@ case class ST_GeomFromWKB(inputExpressions: Seq[Expression])
return new GenericArrayData(GeometrySerializer.serialize(geometry))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -211,7 +211,7 @@ case class ST_GeomFromGeoJSON(inputExpressions: Seq[Expression])
return new GenericArrayData(GeometrySerializer.serialize(geometry))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -241,7 +241,7 @@ case class ST_Point(inputExpressions: Seq[Expression])
return new GenericArrayData(GeometrySerializer.serialize(geometry))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -289,7 +289,7 @@ case class ST_PolygonFromEnvelope(inputExpressions: Seq[Expression]) extends Exp
new GenericArrayData(GeometrySerializer.serialize(polygon))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
index 89dade2..83af40a 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala
@@ -94,7 +94,7 @@ case class ST_ConvexHull(inputExpressions: Seq[Expression])
new GenericArrayData(GeometrySerializer.serialize(geometry.convexHull()))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -143,7 +143,7 @@ case class ST_Buffer(inputExpressions: Seq[Expression])
new GenericArrayData(GeometrySerializer.serialize(geometry.buffer(buffer)))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -164,7 +164,7 @@ case class ST_Envelope(inputExpressions: Seq[Expression])
new GenericArrayData(GeometrySerializer.serialize(geometry.getEnvelope()))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -224,7 +224,7 @@ case class ST_Centroid(inputExpressions: Seq[Expression])
new GenericArrayData(GeometrySerializer.serialize(geometry.getCentroid()))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -274,7 +274,7 @@ case class ST_Transform(inputExpressions: Seq[Expression])
targetFactory.createCoordinateReferenceSystem(codeString)
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -316,7 +316,7 @@ case class ST_Intersection(inputExpressions: Seq[Expression])
return new GenericArrayData(GeometrySerializer.serialize(leftgeometry.intersection(rightgeometry)))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -444,7 +444,7 @@ case class ST_SimplifyPreserveTopology(inputExpressions: Seq[Expression])
new GenericArrayData(GeometrySerializer.serialize(simplifiedGeometry))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -467,7 +467,7 @@ case class ST_PrecisionReduce(inputExpressions: Seq[Expression])
new GenericArrayData(GeometrySerializer.serialize(precisionReduce.reduce(geometry)))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -561,7 +561,7 @@ case class ST_LineMerge(inputExpressions: Seq[Expression])
new GenericArrayData(GeometrySerializer.serialize(output))
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -649,7 +649,7 @@ case class ST_StartPoint(inputExpressions: Seq[Expression])
}
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
@@ -668,7 +668,7 @@ case class ST_Boundary(inputExpressions: Seq[Expression])
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
@@ -688,7 +688,7 @@ case class ST_EndPoint(inputExpressions: Seq[Expression])
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
@@ -708,7 +708,7 @@ case class ST_ExteriorRing(inputExpressions: Seq[Expression])
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
@@ -739,7 +739,7 @@ case class ST_GeometryN(inputExpressions: Seq[Expression])
}
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
@@ -760,7 +760,7 @@ case class ST_InteriorRingN(inputExpressions: Seq[Expression])
}
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
}
@@ -784,7 +784,7 @@ case class ST_Dump(inputExpressions: Seq[Expression])
ArrayData.toArrayData(geometryCollection)
}
- override def dataType: DataType = ArrayType(new GeometryUDT())
+ override def dataType: DataType = ArrayType(GeometryUDT)
override def children: Seq[Expression] = inputExpressions
}
@@ -799,7 +799,7 @@ case class ST_DumpPoints(inputExpressions: Seq[Expression])
ArrayData.toArrayData(geometry.getPoints.map(geom => geom.toGenericArrayData))
}
- override def dataType: DataType = ArrayType(new GeometryUDT())
+ override def dataType: DataType = ArrayType(GeometryUDT)
override def children: Seq[Expression] = inputExpressions
}
@@ -891,7 +891,7 @@ case class ST_AddPoint(inputExpressions: Seq[Expression])
private def lineStringFromCoordinates(coordinates: Array[Coordinate]): LineString =
geometryFactory.createLineString(coordinates)
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
@@ -920,7 +920,7 @@ case class ST_RemovePoint(inputExpressions: Seq[Expression])
}
}
- override def dataType: DataType = new GeometryUDT()
+ override def dataType: DataType = GeometryUDT
override def children: Seq[Expression] = inputExpressions
diff --git a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions_udaf/AggregateFunctions.scala b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions_udaf/AggregateFunctions.scala
index 5b3c0f5..f527d0b 100644
--- a/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions_udaf/AggregateFunctions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/sedona_sql/expressions_udaf/AggregateFunctions.scala
@@ -25,13 +25,13 @@ import org.locationtech.jts.geom.{Coordinate, Geometry, GeometryFactory}
*/
class ST_Union_Aggr extends UserDefinedAggregateFunction {
- override def inputSchema: StructType = StructType(StructField("Union", new GeometryUDT) :: Nil)
+ override def inputSchema: StructType = StructType(StructField("Union", GeometryUDT) :: Nil)
override def bufferSchema: StructType = StructType(
- StructField("Union", new GeometryUDT) :: Nil
+ StructField("Union", GeometryUDT) :: Nil
)
- override def dataType: DataType = new GeometryUDT
+ override def dataType: DataType = GeometryUDT
override def deterministic: Boolean = true
@@ -72,15 +72,15 @@ class ST_Union_Aggr extends UserDefinedAggregateFunction {
class ST_Envelope_Aggr extends UserDefinedAggregateFunction {
// This is the input fields for your aggregate function.
override def inputSchema: org.apache.spark.sql.types.StructType =
- StructType(StructField("Envelope", new GeometryUDT) :: Nil)
+ StructType(StructField("Envelope", GeometryUDT) :: Nil)
// This is the internal fields you keep for computing your aggregate.
override def bufferSchema: StructType = StructType(
- StructField("Envelope", new GeometryUDT) :: Nil
+ StructField("Envelope", GeometryUDT) :: Nil
)
// This is the output type of your aggregatation function.
- override def dataType: DataType = new GeometryUDT
+ override def dataType: DataType = GeometryUDT
override def deterministic: Boolean = true
@@ -176,13 +176,13 @@ class ST_Envelope_Aggr extends UserDefinedAggregateFunction {
* Return the polygon intersection of all Polygon in the given column
*/
class ST_Intersection_Aggr extends UserDefinedAggregateFunction {
- override def inputSchema: StructType = StructType(StructField("Intersection", new GeometryUDT) :: Nil)
+ override def inputSchema: StructType = StructType(StructField("Intersection", GeometryUDT) :: Nil)
override def bufferSchema: StructType = StructType(
- StructField("Intersection", new GeometryUDT) :: Nil
+ StructField("Intersection", GeometryUDT) :: Nil
)
- override def dataType: DataType = new GeometryUDT
+ override def dataType: DataType = GeometryUDT
override def deterministic: Boolean = true
diff --git a/sql/src/test/resources/small/areas.csv b/sql/src/test/resources/small/areas.csv
new file mode 100644
index 0000000..0977b3d
--- /dev/null
+++ b/sql/src/test/resources/small/areas.csv
@@ -0,0 +1,4 @@
+1 POLYGON((2 2, 2 4, 3 5, 7 5, 9 3, 8 1, 4 1, 2 2)) area1
+2 POLYGON((10 3, 10 6, 14 6, 14 3, 10 3)) area2
+3 POLYGON((-1 -1, -1 -3, -2 -5, -6 -8, -5 -2, -3 -2, -1 -1)) area3
+4 POLYGON((0 4, -3 3, -8 6, -6 8, -2 9, 0 4)) area4
\ No newline at end of file
diff --git a/sql/src/test/resources/small/points.csv b/sql/src/test/resources/small/points.csv
new file mode 100644
index 0000000..7ac95ec
--- /dev/null
+++ b/sql/src/test/resources/small/points.csv
@@ -0,0 +1,10 @@
+1 POINT(11 5) bank1
+2 POINT(12 1) bank2
+3 POINT(-1 -1) bank3
+4 POINT(-3 5) bank4
+5 POINT(9 8) bank5
+6 POINT(4 3) bank6
+7 POINT(-4 -5) bank7
+8 POINT(4 -2) bank8
+9 POINT(-3 1) bank9
+10 POINT(-7 3) bank10
\ No newline at end of file
diff --git a/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 28274bb..f3405fd 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -59,6 +59,8 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
val csvPoint1InputLocation = resourceFolder + "equalitycheckfiles/testequals_point1.csv"
val csvPoint2InputLocation = resourceFolder + "equalitycheckfiles/testequals_point2.csv"
val geojsonIdInputLocation = resourceFolder + "testContainsId.json"
+ val smallAreasLocation: String = resourceFolder + "small/areas.csv"
+ val smallPointsLocation: String = resourceFolder + "small/points.csv"
override def beforeAll(): Unit = {
SedonaSQLRegistrator.registerAll(sparkSession)
diff --git a/sql/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala
index dc75217..b4cc427 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/adapterTestScala.scala
@@ -24,8 +24,11 @@ import org.apache.sedona.core.formatMapper.shapefileParser.ShapefileReader
import org.apache.sedona.core.spatialOperator.JoinQuery
import org.apache.sedona.core.spatialRDD.{CircleRDD, PolygonRDD}
import org.apache.sedona.sql.utils.Adapter
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.scalatest.GivenWhenThen
+import org.scalatest.matchers.should.Matchers
-class adapterTestScala extends TestBaseScala {
+class adapterTestScala extends TestBaseScala with Matchers with GivenWhenThen{
describe("Sedona-SQL Scala Adapter Test") {
@@ -35,7 +38,8 @@ class adapterTestScala extends TestBaseScala {
var spatialDf = sparkSession.sql("select ST_PointFromText(inputtable._c0,\",\") as arealandmark from inputtable")
var spatialRDD = Adapter.toSpatialRdd(spatialDf, "arealandmark")
spatialRDD.analyze()
- Adapter.toDf(spatialRDD, sparkSession).show(1)
+ val resultDf = Adapter.toDf(spatialRDD, sparkSession)
+ assert(resultDf.schema(0).dataType == GeometryUDT)
}
it("Read CSV point at a different column id into a SpatialRDD") {
@@ -64,21 +68,8 @@ class adapterTestScala extends TestBaseScala {
var spatialDf = sparkSession.sql("select ST_Point(cast(inputtable._c0 as Decimal(24,20)),cast(inputtable._c1 as Decimal(24,20))) as arealandmark from inputtable")
var spatialRDD = Adapter.toSpatialRdd(spatialDf, "arealandmark")
assert(Adapter.toDf(spatialRDD, sparkSession).columns.length == 1)
- // Adapter.toDf(spatialRDD, sparkSession).show(1)
}
- it("Read CSV point into a SpatialRDD with unique Id by passing coordinates") {
- var df = sparkSession.read.format("csv").option("delimiter", ",").option("header", "false").load(arealmPointInputLocation)
- df.createOrReplaceTempView("inputtable")
- // Use Column _c0 as the unique Id but the id can be anything in the same row
- var spatialDf = sparkSession.sql("select ST_Point(cast(inputtable._c0 as Decimal(24,20)),cast(inputtable._c1 as Decimal(24,20))) as arealandmark from inputtable")
- var spatialRDD = Adapter.toSpatialRdd(spatialDf, "arealandmark")
- spatialRDD.analyze()
- assert(Adapter.toDf(spatialRDD, sparkSession).columns.length == 1)
- // Adapter.toDf(spatialRDD, sparkSession).show(1)
- }
-
-
it("Read mixed WKT geometries into a SpatialRDD") {
var df = sparkSession.read.format("csv").option("delimiter", "\t").option("header", "false").load(mixedWktGeometryInputLocation)
df.createOrReplaceTempView("inputtable")
@@ -101,21 +92,21 @@ class adapterTestScala extends TestBaseScala {
var spatialRDD = ShapefileReader.readToGeometryRDD(sparkSession.sparkContext, shapefileInputLocation)
spatialRDD.analyze()
var df = Adapter.toDf(spatialRDD, sparkSession)
- df.show(1)
+ assert(df.schema.toList.map(f => f.name).mkString("\t").equals("geometry\tSTATEFP\tCOUNTYFP\tCOUNTYNS\tAFFGEOID\tGEOID\tNAME\tLSAD\tALAND\tAWATER"))
+ assert(df.count() == 3220)
}
it("Read shapefileWithMissing -> DataFrame") {
var spatialRDD = ShapefileReader.readToGeometryRDD(sparkSession.sparkContext, shapefileWithMissingsTrailingInputLocation)
spatialRDD.analyze()
var df = Adapter.toDf(spatialRDD, sparkSession)
- df.show(1)
+ assert(df.count() == 3)
}
it("Read GeoJSON to DataFrame") {
- import org.apache.spark.sql.functions.{callUDF, col}
var spatialRDD = new PolygonRDD(sparkSession.sparkContext, geojsonInputLocation, FileDataSplitter.GEOJSON, true)
spatialRDD.analyze()
- var df = Adapter.toDf(spatialRDD, sparkSession).withColumn("geometry", callUDF("ST_GeomFromWKT", col("geometry")))
+ var df = Adapter.toDf(spatialRDD, sparkSession)//.withColumn("geometry", callUDF("ST_GeomFromWKT", col("geometry")))
assert(df.columns(1) == "STATEFP")
}
@@ -138,12 +129,19 @@ class adapterTestScala extends TestBaseScala {
pointRDD.buildIndex(IndexType.QUADTREE, true)
val joinResultPairRDD = JoinQuery.SpatialJoinQueryFlat(pointRDD, polygonRDD, true, true)
-
val joinResultDf = Adapter.toDf(joinResultPairRDD, sparkSession)
- joinResultDf.show(1)
-
- val joinResultDf2 = Adapter.toDf(joinResultPairRDD, List("abc", "def"), List(), sparkSession)
- joinResultDf2.show(1)
+ assert(joinResultDf.schema(0).dataType == GeometryUDT)
+ assert(joinResultDf.schema(1).dataType == GeometryUDT)
+ assert(joinResultDf.schema(0).name == "leftgeometry")
+ assert(joinResultDf.schema(1).name == "rightgeometry")
+ import scala.collection.JavaConversions._
+ val joinResultDf2 = Adapter.toDf(joinResultPairRDD, polygonRDD.fieldNames, List(), sparkSession)
+ assert(joinResultDf2.schema(0).dataType == GeometryUDT)
+ assert(joinResultDf2.schema(0).name == "leftgeometry")
+ assert(joinResultDf2.schema(1).name == "abc")
+ assert(joinResultDf2.schema(2).name == "def")
+ assert(joinResultDf2.schema(3).dataType == GeometryUDT)
+ assert(joinResultDf2.schema(3).name == "rightgeometry")
}
it("Convert distance join result to DataFrame") {
@@ -168,7 +166,10 @@ class adapterTestScala extends TestBaseScala {
var joinResultPairRDD = JoinQuery.DistanceJoinQueryFlat(pointRDD, circleRDD, true, true)
var joinResultDf = Adapter.toDf(joinResultPairRDD, sparkSession)
- joinResultDf.show(1)
+ assert(joinResultDf.schema(0).dataType == GeometryUDT)
+ assert(joinResultDf.schema(1).dataType == GeometryUDT)
+ assert(joinResultDf.schema(0).name == "leftgeometry")
+ assert(joinResultDf.schema(1).name == "rightgeometry")
}
it("load id column Data check") {
@@ -179,6 +180,5 @@ class adapterTestScala extends TestBaseScala {
assert(df.count() == 1)
}
-
}
}
diff --git a/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala b/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
index 6aa225b..0b6c7bd 100644
--- a/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
+++ b/sql/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala
@@ -107,8 +107,6 @@ class constructorTestScala extends TestBaseScala {
spatialRDD.analyze()
var df = Adapter.toDf(spatialRDD, sparkSession)
assert(df.columns(1) == "STATEFP")
- import org.apache.spark.sql.functions.{callUDF, col}
- df = df.withColumn("geometry", callUDF("ST_GeomFromWKT", col("geometry")))
var spatialRDD2 = Adapter.toSpatialRdd(df, "geometry")
Adapter.toDf(spatialRDD2, sparkSession).show(1)
}