You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/03/23 02:53:07 UTC

[GitHub] [spark] sadhen edited a comment on pull request #31735: [SPARK-34799][PYTHON][SQL] Return User-defined types from Pandas UDF

sadhen edited a comment on pull request #31735:
URL: https://github.com/apache/spark/pull/31735#issuecomment-804539589


   @eddyxu I wrote a UDT with Timestamp, but failed to make it work. See the demo pr: https://github.com/eddyxu/spark/pull/4
   
   For ExampleBox, serialize to list works fine. But for ExamplePointWithTimeUDT, to make `pa.StructArray.from_pandas` work, we need to serialize it to dict. For the following snippets, the python part works fine. But I failed to deserialize the ExamplePointWithTime properly in the Scala part.
   
   ``` python
   class ExamplePointWithTimeUDT(UserDefinedType):
       """
       User-defined type (UDT) for ExamplePointWithTime.
       """
   
       @classmethod
       def sqlType(self):
           return StructType([
               StructField("x", DoubleType(), False),
               StructField("y", DoubleType(), True),
               StructField("ts", TimestampType(), False),
           ])
   
       @classmethod
       def module(cls):
           return 'pyspark.sql.tests'
   
       @classmethod
       def scalaUDT(cls):
           return 'org.apache.spark.sql.test.ExamplePointWithTimeUDT'
   
       def serialize(self, obj):
           return {'x': obj.x, 'y': obj.y, 'ts': obj.ts}
   
       def deserialize(self, datum):
           return ExamplePointWithTime(datum['x'], datum['y'], datum['ts'])
   
   
   class ExamplePointWithTime:
       """
       An example class to demonstrate UDT in Scala, Java, and Python.
       """
   
       __UDT__ = ExamplePointWithTimeUDT()
   
       def __init__(self, x, y, ts):
           self.x = x
           self.y = y
           self.ts = ts
   
       def __repr__(self):
           return "ExamplePointWithTime(%s,%s,%s)" % (self.x, self.y, self.ts)
   
       def __str__(self):
           return "(%s,%s,%s)" % (self.x, self.y, self.ts)
   
       def __eq__(self, other):
           return isinstance(other, self.__class__) \
               and other.x == self.x and other.y == self.y \
               and other.ts == self.ts
   ```
   
   ``` scala
   package org.apache.spark.sql.test
   
   import java.sql.Timestamp
   
   import org.apache.spark.sql.catalyst.InternalRow
   import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
   import org.apache.spark.sql.types.{DataType, DoubleType, SQLUserDefinedType, StructField, StructType, TimestampType, UserDefinedType}
   
   
   /**
    * An example class to demonstrate UDT in Scala, Java, and Python.
    * @param x x coordinate
    * @param y y coordinate
    * @param ts timestamp
    */
   @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
   private[sql] class ExamplePointWithTime(val x: Double, val y: Double, val ts: Timestamp)
     extends Serializable {
   
     override def hashCode(): Int = {
       var hash = 13
       hash = hash * 31 + x.hashCode()
       hash = hash * 31 + y.hashCode()
       hash = hash * 31 + ts.hashCode()
       hash
     }
   
     override def equals(other: Any): Boolean = other match {
       case that: ExamplePointWithTime =>
         this.x == that.x && this.y == that.y && this.ts == that.ts
       case _ => false
     }
   
     override def toString(): String = s"($x, $y, ${ts.toString})"
   }
   
   /**
    * User-defined type for [[ExamplePoint]].
    */
   private[sql] class ExamplePointWithTimeUDT extends UserDefinedType[ExamplePointWithTime] {
   
     override def sqlType: DataType = StructType(Array(
       StructField("x", DoubleType, nullable = false),
       StructField("y", DoubleType, nullable = true),
       StructField("ts", TimestampType, nullable = false)
     ))
   
     override def pyUDT: String = "pyspark.testing.sqlutils.ExamplePointWithTimeUDT"
   
     override def serialize(p: ExamplePointWithTime): ArrayBasedMapData = {
       ArrayBasedMapData(
         Array("x", "y", "ts"),
         Array(p.x, p.y, p.ts)
       )
     }
   
     override def deserialize(datum: Any): ExamplePointWithTime = {
       datum match {
         case row: InternalRow =>
           new ExamplePointWithTime(
             row.getDouble(0),
             row.getDouble(1),
             row.get(2, TimestampType)      // .asInstanceOf[Timestamp]   it is Long, cannot be casted to Timestamp
           )
       }
     }
   
     override def userClass: Class[ExamplePointWithTime] = classOf[ExamplePointWithTime]
   
     private[spark] override def asNullable: ExamplePointWithTimeUDT = this
   }
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org