You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by "dtenedor (via GitHub)" <gi...@apache.org> on 2023/02/01 20:55:29 UTC

[GitHub] [spark] dtenedor commented on a diff in pull request #38419: [SPARK-40945][SQL] Support built-in function to truncate numbers

dtenedor commented on code in PR #38419:
URL: https://github.com/apache/spark/pull/38419#discussion_r1093715384


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,275 @@ case class RoundCeil(child: Expression, scale: Expression)
     copy(child = newLeft, scale = newRight)
 }
 
+
+/**
+ * Truncates a number to the specified number of digits.
+ * @param child
+ *   expression to get the number to be truncated.
+ * @param scale
+ *   expression to get the number of decimal places to truncate to.
+ */
+case class TruncNumber(child: Expression, scale: Expression)
+    extends BaseBinaryExpression
+    with NullIntolerant {
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): TruncNumber = copy(child = newLeft, scale = newRight)
+
+  /**
+   * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query
+   * the dataType of an unresolved expression (i.e., when `resolved` == false).
+   */
+  override lazy val dataType: DataType = child.dataType
+
+  /**
+   * This overridden implementation delegates the overloaded TruncNumber.trunc methods based on
+   * data type of input values
+   */
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+    (dataType, input1) match {
+      // Trunc function accepts a second parameter to truncate the input number.
+      // If 0, it removes all the decimal values and returns only the integer.
+      // If negative, the number is truncated to the left side of the decimal point.
+      // Value  of decimal places to truncate can range from -ve to +ve
+      // 1) In the case of integral numbers, as there is no decimal part if the value of decimal
+      // places to truncate is +ve, then we can return that input value without any
+      // modification as there is no +ve decimal place to be truncated from an integral number
+      // Truncate the input only if the value of decimal places to truncate is < 0
+      case (ByteType, input: Byte) if (scaleValue < 0) =>

Review Comment:
   no need for the parentheses around `scaleValue < 0` (here and below)



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,275 @@ case class RoundCeil(child: Expression, scale: Expression)
     copy(child = newLeft, scale = newRight)
 }
 
+
+/**
+ * Truncates a number to the specified number of digits.
+ * @param child
+ *   expression to get the number to be truncated.
+ * @param scale
+ *   expression to get the number of decimal places to truncate to.
+ */
+case class TruncNumber(child: Expression, scale: Expression)
+    extends BaseBinaryExpression
+    with NullIntolerant {
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): TruncNumber = copy(child = newLeft, scale = newRight)
+
+  /**
+   * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query

Review Comment:
   no need to repeat these comments from the base class if they appear here unchanged.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,275 @@ case class RoundCeil(child: Expression, scale: Expression)
     copy(child = newLeft, scale = newRight)
 }
 
+
+/**
+ * Truncates a number to the specified number of digits.
+ * @param child
+ *   expression to get the number to be truncated.
+ * @param scale
+ *   expression to get the number of decimal places to truncate to.
+ */
+case class TruncNumber(child: Expression, scale: Expression)
+    extends BaseBinaryExpression
+    with NullIntolerant {
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): TruncNumber = copy(child = newLeft, scale = newRight)
+
+  /**
+   * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query
+   * the dataType of an unresolved expression (i.e., when `resolved` == false).
+   */
+  override lazy val dataType: DataType = child.dataType
+
+  /**
+   * This overridden implementation delegates the overloaded TruncNumber.trunc methods based on
+   * data type of input values
+   */
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+    (dataType, input1) match {
+      // Trunc function accepts a second parameter to truncate the input number.
+      // If 0, it removes all the decimal values and returns only the integer.
+      // If negative, the number is truncated to the left side of the decimal point.
+      // Value  of decimal places to truncate can range from -ve to +ve
+      // 1) In the case of integral numbers, as there is no decimal part if the value of decimal
+      // places to truncate is +ve, then we can return that input value without any
+      // modification as there is no +ve decimal place to be truncated from an integral number
+      // Truncate the input only if the value of decimal places to truncate is < 0
+      case (ByteType, input: Byte) if (scaleValue < 0) =>
+        TruncNumber.trunc(input.toLong, scaleValue).toByte
+      case (ShortType, input: Short) if (scaleValue < 0) =>
+        TruncNumber.trunc(input.toLong, scaleValue).shortValue
+      case (IntegerType, input: Int) if (scaleValue < 0) =>
+        TruncNumber.trunc(input.toLong, scaleValue).intValue
+      case (LongType, input: Long) if (scaleValue < 0) =>
+        TruncNumber.trunc(input, scaleValue).longValue
+      // 2) In the case of Float, Double, and Decimal , TruncNumber.trunc
+      // will accept both -ve and +ve values
+      case (FloatType, input: Float) =>
+        TruncNumber.trunc(input, scaleValue).floatValue
+      case (DoubleType, input: Double) =>
+        TruncNumber.trunc(input, scaleValue).doubleValue
+      case (DecimalType.Fixed(p, s), input: Decimal) =>
+        Decimal(TruncNumber.trunc(input.toJavaBigDecimal, scaleValue), p, s)
+      case _ => input1
+    }
+  }
+
+  /**
+   * Returns Java source code that can be compiled to evaluate this expression.
+   * This overridden implementation delegates the overloaded TruncNumber.trunc methods based on
+   * data type of input values
+   * @param ctx
+   *   a [[CodegenContext]]
+   * @param ev
+   *   an [[ExprCode]] with unique terms.
+   * @return
+   *   an [[ExprCode]] containing the Java source code to generate the given expression
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    defineCodeGen(
+      ctx,
+      ev,
+      (input, _) => {
+        val methodName = "org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc"
+        // Trunc function accepts a second parameter to truncate the input number.
+        // If 0, it removes all the decimal values and returns only the integer.
+        // If negative, the number is truncated to the left side of the decimal point.
+        // Value  of decimal places to truncate can range from -ve to +ve
+        // 1) In the case of integral numbers, as there is no decimal part if the value of decimal
+        // places to truncate is +ve, then we can return that input value without any
+        // modification as there is no +ve decimal place to be truncated from an integral number
+        // Truncate the input only if the value of decimal places to truncate is < 0
+        dataType match {
+          case ByteType if (scaleValue < 0) =>

Review Comment:
   same here



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,275 @@ case class RoundCeil(child: Expression, scale: Expression)
     copy(child = newLeft, scale = newRight)
 }
 
+
+/**
+ * Truncates a number to the specified number of digits.
+ *
+ * A positive scale keeps that many digits after the decimal point, zero keeps only the integer
+ * part, and a negative scale zeroes digits to the left of the decimal point. Integral inputs have
+ * no fractional digits, so they are only modified when the scale is negative. NaN and infinite
+ * float/double inputs are returned unchanged because they have no digits to truncate (and
+ * cannot be converted to a BigDecimal).
+ *
+ * @param child
+ *   expression to get the number to be truncated.
+ * @param scale
+ *   expression to get the number of decimal places to truncate to.
+ */
+case class TruncNumber(child: Expression, scale: Expression)
+    extends BaseBinaryExpression
+    with NullIntolerant {
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): TruncNumber = copy(child = newLeft, scale = newRight)
+
+  // Truncation never changes the type of its input.
+  override lazy val dataType: DataType = child.dataType
+
+  /**
+   * Dispatches to the TruncNumber.trunc overload matching the input's data type.
+   */
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+    (dataType, input1) match {
+      // Integral types: only a negative scale can remove digits.
+      case (ByteType, input: Byte) if scaleValue < 0 =>
+        TruncNumber.trunc(input.toLong, scaleValue).toByte
+      case (ShortType, input: Short) if scaleValue < 0 =>
+        TruncNumber.trunc(input.toLong, scaleValue).shortValue
+      case (IntegerType, input: Int) if scaleValue < 0 =>
+        TruncNumber.trunc(input.toLong, scaleValue).intValue
+      case (LongType, input: Long) if scaleValue < 0 =>
+        TruncNumber.trunc(input, scaleValue).longValue
+      // Fractional types accept any scale. NaN and infinities would make
+      // BigDecimal.valueOf throw NumberFormatException, so they pass through untouched.
+      case (FloatType, input: Float) =>
+        if (java.lang.Float.isFinite(input)) TruncNumber.trunc(input, scaleValue).floatValue
+        else input
+      case (DoubleType, input: Double) =>
+        if (java.lang.Double.isFinite(input)) TruncNumber.trunc(input, scaleValue).doubleValue
+        else input
+      case (DecimalType.Fixed(p, s), input: Decimal) =>
+        Decimal(TruncNumber.trunc(input.toJavaBigDecimal, scaleValue), p, s)
+      // Integral input with a non-negative scale: nothing to truncate.
+      case _ => input1
+    }
+  }
+
+  /**
+   * Generates Java code calling the TruncNumber.trunc overload that matches the input's data
+   * type, mirroring the branches of nullSafeEval.
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    defineCodeGen(
+      ctx,
+      ev,
+      (input, _) => {
+        val methodName = "org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc"
+        dataType match {
+          // Integral types: only a negative scale can remove digits.
+          case ByteType if scaleValue < 0 =>
+            s"""(byte)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case ShortType if scaleValue < 0 =>
+            s"""(short)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case IntegerType if scaleValue < 0 =>
+            s"""(int)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case LongType if scaleValue < 0 =>
+            s"""($methodName(
+               |$input, $scaleValue))""".stripMargin
+          // Fractional types accept any scale; NaN and infinities pass through untouched.
+          case FloatType =>
+            s"""(java.lang.Float.isFinite($input) ? $methodName(
+               |$input, $scaleValue).floatValue() : $input)""".stripMargin
+          case DoubleType =>
+            s"""(java.lang.Double.isFinite($input) ? $methodName(
+               |$input, $scaleValue).doubleValue() : $input)""".stripMargin
+          case DecimalType.Fixed(p, s) =>
+            s"""Decimal.apply(
+               |$methodName(
+               |${input}.toJavaBigDecimal(), $scaleValue), $p, $s)""".stripMargin
+          // Integral input with a non-negative scale: nothing to truncate.
+          case _ => s"$input"
+        }
+      })
+}
+
+object TruncNumber {
+
+  /**
+   * To truncate whole numbers; byte, short, int, and long types.
+   */
+  def trunc(input: Long, position: Int): Long = {
+    if (position >= 0) {
+      input
+    } else {
+      // Here we truncate the number by the absolute value of the position.
+      // For example, if the input is 123 and the scale is -2, then the result is 100.
+      val pow = Math.pow(10, Math.abs(position)).toLong
+      (input / pow) * pow
+    }
+  }
+
+  /**
+   * To truncate double and float type.
+   */
+  def trunc(input: Double, position: Int): BigDecimal = {
+    trunc(jm.BigDecimal.valueOf(input), position)
+  }
+
+  /**
+   * To truncate decimal type.
+   */
+  def trunc(input: jm.BigDecimal, position: Int): jm.BigDecimal = {
+    if (input.scale < position) input
+    else {
+      val wholePart = input.toBigInteger
+      position match {
+        case pos if pos >= 0 =>
+          // Here we truncate only the decimal part by the value of the position.
+          val decimalPart = input.remainder(java.math.BigDecimal.ONE)
+          // If the position is zero OR Decimal part is zero,
+          // we extract the whole part and return it.
+          // For example,
+          // if the input is 123.456 and the scale is 0, the result will be 123.
+          // if the input is 123.000 and the scale is > 0, the result will be 123.
+          val wholePartBD = new jm.BigDecimal(wholePart)
+          if (pos == 0 || jm.BigDecimal.ZERO.compareTo(decimalPart) == 0) {
+            wholePartBD
+          } else {
+            // To avoid overflow during multiplication, we extract the decimal part from the input,
+            // truncate it and then append it to the whole part.
+            // For example, if the input is 123.456 and the scale is 2, the result will be 123.45.
+            val pow = jm.BigDecimal.valueOf(Math.pow(10, pos).toLong)
+            val truncated = new jm.BigDecimal(decimalPart.multiply(pow).toBigInteger).divide(pow)
+            wholePartBD.add(truncated)
+          }
+        case pos if pos < 0 =>
+          // Here we truncate the whole part by the absolute value of the position.
+          // For example, if the input is 123.456 and the scale is -2, the result will be 100.
+          val pow = jm.BigInteger.valueOf(Math.pow(10, Math.abs(pos)).toLong)

Review Comment:
   same here



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,275 @@ case class RoundCeil(child: Expression, scale: Expression)
     copy(child = newLeft, scale = newRight)
 }
 
+
+/**
+ * Truncates a number to the specified number of digits.
+ *
+ * A positive scale keeps that many digits after the decimal point, zero keeps only the integer
+ * part, and a negative scale zeroes digits to the left of the decimal point. Integral inputs have
+ * no fractional digits, so they are only modified when the scale is negative. NaN and infinite
+ * float/double inputs are returned unchanged because they have no digits to truncate (and
+ * cannot be converted to a BigDecimal).
+ *
+ * @param child
+ *   expression to get the number to be truncated.
+ * @param scale
+ *   expression to get the number of decimal places to truncate to.
+ */
+case class TruncNumber(child: Expression, scale: Expression)
+    extends BaseBinaryExpression
+    with NullIntolerant {
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): TruncNumber = copy(child = newLeft, scale = newRight)
+
+  // Truncation never changes the type of its input.
+  override lazy val dataType: DataType = child.dataType
+
+  /**
+   * Dispatches to the TruncNumber.trunc overload matching the input's data type.
+   */
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+    (dataType, input1) match {
+      // Integral types: only a negative scale can remove digits.
+      case (ByteType, input: Byte) if scaleValue < 0 =>
+        TruncNumber.trunc(input.toLong, scaleValue).toByte
+      case (ShortType, input: Short) if scaleValue < 0 =>
+        TruncNumber.trunc(input.toLong, scaleValue).shortValue
+      case (IntegerType, input: Int) if scaleValue < 0 =>
+        TruncNumber.trunc(input.toLong, scaleValue).intValue
+      case (LongType, input: Long) if scaleValue < 0 =>
+        TruncNumber.trunc(input, scaleValue).longValue
+      // Fractional types accept any scale. NaN and infinities would make
+      // BigDecimal.valueOf throw NumberFormatException, so they pass through untouched.
+      case (FloatType, input: Float) =>
+        if (java.lang.Float.isFinite(input)) TruncNumber.trunc(input, scaleValue).floatValue
+        else input
+      case (DoubleType, input: Double) =>
+        if (java.lang.Double.isFinite(input)) TruncNumber.trunc(input, scaleValue).doubleValue
+        else input
+      case (DecimalType.Fixed(p, s), input: Decimal) =>
+        Decimal(TruncNumber.trunc(input.toJavaBigDecimal, scaleValue), p, s)
+      // Integral input with a non-negative scale: nothing to truncate.
+      case _ => input1
+    }
+  }
+
+  /**
+   * Generates Java code calling the TruncNumber.trunc overload that matches the input's data
+   * type, mirroring the branches of nullSafeEval.
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    defineCodeGen(
+      ctx,
+      ev,
+      (input, _) => {
+        val methodName = "org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc"
+        dataType match {
+          // Integral types: only a negative scale can remove digits.
+          case ByteType if scaleValue < 0 =>
+            s"""(byte)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case ShortType if scaleValue < 0 =>
+            s"""(short)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case IntegerType if scaleValue < 0 =>
+            s"""(int)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case LongType if scaleValue < 0 =>
+            s"""($methodName(
+               |$input, $scaleValue))""".stripMargin
+          // Fractional types accept any scale; NaN and infinities pass through untouched.
+          case FloatType =>
+            s"""(java.lang.Float.isFinite($input) ? $methodName(
+               |$input, $scaleValue).floatValue() : $input)""".stripMargin
+          case DoubleType =>
+            s"""(java.lang.Double.isFinite($input) ? $methodName(
+               |$input, $scaleValue).doubleValue() : $input)""".stripMargin
+          case DecimalType.Fixed(p, s) =>
+            s"""Decimal.apply(
+               |$methodName(
+               |${input}.toJavaBigDecimal(), $scaleValue), $p, $s)""".stripMargin
+          // Integral input with a non-negative scale: nothing to truncate.
+          case _ => s"$input"
+        }
+      })
+}
+
+object TruncNumber {
+
+  /**
+   * To truncate whole numbers; byte, short, int, and long types.
+   */
+  def trunc(input: Long, position: Int): Long = {
+    if (position >= 0) {
+      input
+    } else {
+      // Here we truncate the number by the absolute value of the position.
+      // For example, if the input is 123 and the scale is -2, then the result is 100.
+      val pow = Math.pow(10, Math.abs(position)).toLong
+      (input / pow) * pow
+    }
+  }
+
+  /**
+   * To truncate double and float type.
+   */
+  def trunc(input: Double, position: Int): BigDecimal = {
+    trunc(jm.BigDecimal.valueOf(input), position)
+  }
+
+  /**
+   * To truncate decimal type.
+   */
+  def trunc(input: jm.BigDecimal, position: Int): jm.BigDecimal = {
+    if (input.scale < position) input
+    else {
+      val wholePart = input.toBigInteger
+      position match {
+        case pos if pos >= 0 =>
+          // Here we truncate only the decimal part by the value of the position.
+          val decimalPart = input.remainder(java.math.BigDecimal.ONE)
+          // If the position is zero OR Decimal part is zero,
+          // we extract the whole part and return it.
+          // For example,
+          // if the input is 123.456 and the scale is 0, the result will be 123.
+          // if the input is 123.000 and the scale is > 0, the result will be 123.
+          val wholePartBD = new jm.BigDecimal(wholePart)
+          if (pos == 0 || jm.BigDecimal.ZERO.compareTo(decimalPart) == 0) {
+            wholePartBD
+          } else {
+            // To avoid overflow during multiplication, we extract the decimal part from the input,
+            // truncate it and then append it to the whole part.
+            // For example, if the input is 123.456 and the scale is 2, the result will be 123.45.
+            val pow = jm.BigDecimal.valueOf(Math.pow(10, pos).toLong)

Review Comment:
   looks like this logic is duplicated with L457-458, can we dedup it to one place?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,275 @@ case class RoundCeil(child: Expression, scale: Expression)
     copy(child = newLeft, scale = newRight)
 }
 
+
+/**
+ * Truncates a number to the specified number of digits.
+ *
+ * A positive scale keeps that many digits after the decimal point, zero keeps only the integer
+ * part, and a negative scale zeroes digits to the left of the decimal point. Integral inputs have
+ * no fractional digits, so they are only modified when the scale is negative. NaN and infinite
+ * float/double inputs are returned unchanged because they have no digits to truncate (and
+ * cannot be converted to a BigDecimal).
+ *
+ * @param child
+ *   expression to get the number to be truncated.
+ * @param scale
+ *   expression to get the number of decimal places to truncate to.
+ */
+case class TruncNumber(child: Expression, scale: Expression)
+    extends BaseBinaryExpression
+    with NullIntolerant {
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): TruncNumber = copy(child = newLeft, scale = newRight)
+
+  // Truncation never changes the type of its input.
+  override lazy val dataType: DataType = child.dataType
+
+  /**
+   * Dispatches to the TruncNumber.trunc overload matching the input's data type.
+   */
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = {
+    (dataType, input1) match {
+      // Integral types: only a negative scale can remove digits.
+      case (ByteType, input: Byte) if scaleValue < 0 =>
+        TruncNumber.trunc(input.toLong, scaleValue).toByte
+      case (ShortType, input: Short) if scaleValue < 0 =>
+        TruncNumber.trunc(input.toLong, scaleValue).shortValue
+      case (IntegerType, input: Int) if scaleValue < 0 =>
+        TruncNumber.trunc(input.toLong, scaleValue).intValue
+      case (LongType, input: Long) if scaleValue < 0 =>
+        TruncNumber.trunc(input, scaleValue).longValue
+      // Fractional types accept any scale. NaN and infinities would make
+      // BigDecimal.valueOf throw NumberFormatException, so they pass through untouched.
+      case (FloatType, input: Float) =>
+        if (java.lang.Float.isFinite(input)) TruncNumber.trunc(input, scaleValue).floatValue
+        else input
+      case (DoubleType, input: Double) =>
+        if (java.lang.Double.isFinite(input)) TruncNumber.trunc(input, scaleValue).doubleValue
+        else input
+      case (DecimalType.Fixed(p, s), input: Decimal) =>
+        Decimal(TruncNumber.trunc(input.toJavaBigDecimal, scaleValue), p, s)
+      // Integral input with a non-negative scale: nothing to truncate.
+      case _ => input1
+    }
+  }
+
+  /**
+   * Generates Java code calling the TruncNumber.trunc overload that matches the input's data
+   * type, mirroring the branches of nullSafeEval.
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    defineCodeGen(
+      ctx,
+      ev,
+      (input, _) => {
+        val methodName = "org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc"
+        dataType match {
+          // Integral types: only a negative scale can remove digits.
+          case ByteType if scaleValue < 0 =>
+            s"""(byte)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case ShortType if scaleValue < 0 =>
+            s"""(short)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case IntegerType if scaleValue < 0 =>
+            s"""(int)($methodName(
+               |(long)$input, $scaleValue))""".stripMargin
+          case LongType if scaleValue < 0 =>
+            s"""($methodName(
+               |$input, $scaleValue))""".stripMargin
+          // Fractional types accept any scale; NaN and infinities pass through untouched.
+          case FloatType =>
+            s"""(java.lang.Float.isFinite($input) ? $methodName(
+               |$input, $scaleValue).floatValue() : $input)""".stripMargin
+          case DoubleType =>
+            s"""(java.lang.Double.isFinite($input) ? $methodName(
+               |$input, $scaleValue).doubleValue() : $input)""".stripMargin
+          case DecimalType.Fixed(p, s) =>
+            s"""Decimal.apply(
+               |$methodName(
+               |${input}.toJavaBigDecimal(), $scaleValue), $p, $s)""".stripMargin
+          // Integral input with a non-negative scale: nothing to truncate.
+          case _ => s"$input"
+        }
+      })
+}
+
+object TruncNumber {
+
+  /**
+   * Truncates a whole number (byte, short, int or long widened to long) toward zero.
+   *
+   * A non-negative `position` returns the input unchanged because integral values have no
+   * fractional digits. A negative `position` zeroes the lowest `-position` digits,
+   * e.g. `trunc(123, -2) == 100`.
+   */
+  def trunc(input: Long, position: Int): Long = {
+    if (position >= 0) {
+      input
+    } else if (position < -18) {
+      // |input| < 10^19, so truncating 19 or more digits always yields zero. Handling this up
+      // front also avoids Math.pow(10, 19).toLong saturating at Long.MaxValue and
+      // Math.abs(Int.MinValue) staying negative (which made the divisor zero).
+      0L
+    } else {
+      // 10^k for k <= 18 is exactly representable as a double and fits in a long.
+      val pow = Math.pow(10, -position).toLong
+      (input / pow) * pow
+    }
+  }
+
+  /**
+   * Truncates a double (floats are widened by the caller) to `position` digits.
+   *
+   * The value is converted with `jm.BigDecimal.valueOf`, i.e. through its `Double.toString`
+   * representation. `valueOf` throws `NumberFormatException` for NaN and infinities, so callers
+   * must only pass finite values.
+   */
+  def trunc(input: Double, position: Int): BigDecimal = {
+    trunc(jm.BigDecimal.valueOf(input), position)
+  }
+
+  /**
+   * Truncates a decimal toward zero, keeping `position` digits after the decimal point; a
+   * negative `position` truncates digits to the left of the decimal point instead.
+   *
+   * An input whose scale is already below `position` is returned unchanged.
+   */
+  def trunc(input: jm.BigDecimal, position: Int): jm.BigDecimal = {
+    if (input.scale < position) input
+    else {
+      val wholePart = input.toBigInteger
+      position match {
+        case pos if pos >= 0 =>
+          // Here we truncate only the decimal part by the value of the position.
+          val decimalPart = input.remainder(jm.BigDecimal.ONE)
+          // If the position is zero OR Decimal part is zero,
+          // we extract the whole part and return it.
+          // For example,
+          // if the input is 123.456 and the scale is 0, the result will be 123.
+          // if the input is 123.000 and the scale is > 0, the result will be 123.
+          val wholePartBD = new jm.BigDecimal(wholePart)
+          if (pos == 0 || jm.BigDecimal.ZERO.compareTo(decimalPart) == 0) {
+            wholePartBD
+          } else {
+            // To avoid overflow during multiplication, we extract the decimal part from the input,
+            // truncate it and then append it to the whole part.
+            // For example, if the input is 123.456 and the scale is 2, the result will be 123.45.
+            // TEN.pow is exact for any exponent (pos <= input.scale here), whereas
+            // Math.pow(10, pos).toLong saturated for pos > 18 and made divide() non-terminating.
+            val pow = jm.BigDecimal.TEN.pow(pos)
+            val truncated = new jm.BigDecimal(decimalPart.multiply(pow).toBigInteger).divide(pow)
+            wholePartBD.add(truncated)
+          }
+        case pos =>
+          // pos < 0: truncate the whole part by the absolute value of the position.
+          // For example, if the input is 123.456 and the scale is -2, the result will be 100.
+          // Long arithmetic keeps Int.MinValue from overflowing when negated.
+          val digits = Math.abs(pos.toLong)
+          if (digits >= input.precision.toLong - input.scale) {
+            // |input| < 10^(precision - scale), so every integer digit is truncated away.
+            jm.BigDecimal.ZERO
+          } else {
+            val pow = jm.BigInteger.TEN.pow(digits.toInt)
+            new jm.BigDecimal(wholePart.divide(pow).multiply(pow), 0)
+          }
+      }
+    }
+  }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """_FUNC_(number[, position]) - Returns the number after truncating to the specified number of digits.
+    An optional `position` parameter can be specified to truncate digits to the right of the decimal point.
+    If 0, it removes all the decimal values and returns only the integer.
+    If negative, the number is truncated to the left side of the decimal point.
+    Note that there is an overloaded version of this function to truncate date values.
+    _FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.
+  """,
+  arguments = """
+    Arguments:
+      * number - number to be truncated
+      * position - number of decimal places up to which the given number is to be truncated
+    Arguments: To truncate date values:
+      * date - date value or valid date string
+      * fmt - the format representing the unit to be truncated to
+          - "YEAR", "YYYY", "YY" - truncate to the first date of the year that the `date` falls in
+          - "QUARTER" - truncate to the first date of the quarter that the `date` falls in
+          - "MONTH", "MM", "MON" - truncate to the first date of the month that the `date` falls in
+          - "WEEK" - truncate to the Monday of the week that the `date` falls in
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('2019-08-04', 'week');
+       2019-07-29
+      > SELECT _FUNC_('2019-08-04', 'quarter');
+       2019-07-01
+      > SELECT _FUNC_('2009-02-12', 'MM');
+       2009-02-01
+      > SELECT _FUNC_('2015-10-27', 'YEAR');
+       2015-01-01
+      > SELECT _FUNC_(-10.11, 0);
+       -10.00
+      > SELECT _FUNC_(10.11, -1);
+       10.00
+      > SELECT _FUNC_(100.61, 0);
+       100.00
+      > SELECT _FUNC_(-19087.1560, -3);
+       -19000.0000
+      > SELECT _FUNC_(10876.5489, -1);
+       10870.0000
+      > SELECT _FUNC_(-7767.1160, 2);
+       -7767.1100
+      > SELECT _FUNC_(17646.6019, 3);
+       17646.6010
+  """,
+  since = "3.5.0",
+  group = "math_funcs")
+// scalastyle:on line.size.limit
+object TruncExpressionBuilder extends ExpressionBuilder {
+  override def build(funcName: String, expressions: Seq[Expression]): Expression = {
+    val numArgs = expressions.length
+    if (numArgs < 1) {
+      throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(2), funcName, numArgs)
+    }
+    expressions(0).dataType match {
+      case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
+          DecimalType.Fixed(_, _) =>
+        buildTruncNumber(funcName, expressions)
+      case _ => buildTruncDate(funcName, expressions)
+    }
+  }
+
+  private def buildTruncDate(funcName: String, expressions: Seq[Expression]) = {
+    val numArgs = expressions.length
+    if (numArgs == 2) {
+      TruncDate(expressions(0), expressions(1))
+    } else {
+      throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(2), funcName, numArgs)
+    }
+  }
+
+  private def buildTruncNumber(funcName: String, expressions: Seq[Expression]) = {
+    val numArgs = expressions.length
+    if (numArgs < 1) {
+      throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(2), funcName, numArgs)
+    }
+    val position = if (numArgs == 2) {
+      val positionExpr = expressions(1)
+      val scale_value = positionExpr.eval()
+      if (!(positionExpr.foldable && positionExpr.dataType == IntegerType) ||

Review Comment:
   ```suggestion
         if (!positionExpr.foldable || positionExpr.dataType != IntegerType ||
   ```



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala:
##########
@@ -331,6 +332,268 @@ case class RoundCeil(child: Expression, scale: Expression)
     copy(child = newLeft, scale = newRight)
 }
 
+/**
+ * Truncates the numeric `child` toward zero so that at most `scale` digits remain after the
+ * decimal point; a negative `scale` also zeroes that many digits to the left of it
+ * (e.g. 123.456 at 2 -> 123.45, at -2 -> 100). The scale value is read via `_scale`.
+ *
+ * @param child
+ *   expression producing the number to truncate.
+ * @param scale
+ *   expression producing the number of decimal places to keep.
+ */
+case class TruncNumber(child: Expression, scale: Expression)
+  extends BaseBinaryExpression with NullIntolerant {
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): TruncNumber = copy(child = newLeft, scale = newRight)
+
+  /**
+   * Generates Java code that mirrors [[nullSafeEval]] case for case, so whole-stage codegen and
+   * interpreted evaluation return the same value:
+   *  - integral inputs are widened to `long`, truncated and narrowed back when `_scale <= 0`;
+   *    for a positive scale they are returned unchanged;
+   *  - float/double inputs are truncated for every scale, and NaN/infinite values (which
+   *    `java.math.BigDecimal` cannot represent) are returned unchanged;
+   *  - decimal results are built with the precision and scale declared by [[dataType]].
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val truncFn = "org.apache.spark.sql.catalyst.expressions.TruncNumber.trunc"
+    defineCodeGen(ctx, ev, (input, _) => dataType match {
+      case ByteType if _scale <= 0 => s"(byte)($truncFn((long)$input, ${_scale}))"
+      case ShortType if _scale <= 0 => s"(short)($truncFn((long)$input, ${_scale}))"
+      case IntegerType if _scale <= 0 => s"(int)($truncFn((long)$input, ${_scale}))"
+      case LongType if _scale <= 0 => s"$truncFn($input, ${_scale})"
+      case FloatType =>
+        s"(Float.isFinite($input) ? $truncFn($input, ${_scale}).floatValue() : $input)"
+      case DoubleType =>
+        s"(Double.isFinite($input) ? $truncFn($input, ${_scale}).doubleValue() : $input)"
+      case DecimalType.Fixed(precision, decimalScale) =>
+        // Rebuild with the declared precision/scale so the value matches `dataType` exactly.
+        "org.apache.spark.sql.types.Decimal.apply(" +
+          s"$truncFn($input.toJavaBigDecimal(), ${_scale}), $precision, $decimalScale)"
+      case _ => input
+    })
+  }
+
+  /**
+   * Result type: identical to the child's type, except for decimals. A decimal keeps its
+   * integral digits and at most `_scale` fractional digits (never more than the input has, and
+   * none when `_scale <= 0`).
+   */
+  override lazy val dataType: DataType = child.dataType match {
+    case DecimalType.Fixed(p, s) =>
+      val newScale = if (_scale > 0) math.min(_scale, s) else 0
+      // DECIMAL(s, s) truncated to scale 0 would otherwise be given precision 0, which cannot
+      // hold even the value 0.
+      DecimalType(math.max(p - s + newScale, 1), newScale)
+    case t => t
+  }
+
+  /**
+   * Interpreted evaluation. Integral inputs are truncated only when `_scale <= 0`; float and
+   * double inputs are truncated through `java.math.BigDecimal` unless they are NaN or infinite
+   * (returned as-is, because `BigDecimal.valueOf` rejects them); decimal results carry the
+   * precision and scale of [[dataType]]. Any other type is returned unchanged.
+   */
+  override protected def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
+    case ByteType if _scale <= 0 =>
+      TruncNumber.trunc(input1.asInstanceOf[Byte].toLong, _scale).toByte
+    case ShortType if _scale <= 0 =>
+      TruncNumber.trunc(input1.asInstanceOf[Short].toLong, _scale).toShort
+    case IntegerType if _scale <= 0 =>
+      TruncNumber.trunc(input1.asInstanceOf[Int].toLong, _scale).toInt
+    case LongType if _scale <= 0 =>
+      TruncNumber.trunc(input1.asInstanceOf[Long], _scale)
+    case FloatType =>
+      val value = input1.asInstanceOf[Float]
+      if (java.lang.Float.isFinite(value)) TruncNumber.trunc(value, _scale).floatValue else value
+    case DoubleType =>
+      val value = input1.asInstanceOf[Double]
+      if (java.lang.Double.isFinite(value)) TruncNumber.trunc(value, _scale).doubleValue else value
+    case DecimalType.Fixed(precision, decimalScale) =>
+      Decimal(
+        TruncNumber.trunc(input1.asInstanceOf[Decimal].toJavaBigDecimal, _scale),
+        precision,
+        decimalScale)
+    case _ => input1
+  }
+}
+
+object TruncNumber {
+
+  /** Largest power of ten that fits in a `Long`: 10^18 < Long.MaxValue < 10^19. */
+  private val MaxLongPowerOfTen = 18
+
+  /**
+   * Truncates a whole number (byte, short, int and long values widened to long) toward zero.
+   *
+   * @param input
+   *   value to truncate.
+   * @param position
+   *   when >= 0 the input is returned unchanged; when negative, the `-position` least
+   *   significant digits are zeroed (e.g. input 123, position -2 gives 100).
+   * @return
+   *   the truncated value.
+   */
+  def trunc(input: Long, position: Int): Long = {
+    if (position >= 0) {
+      input
+    } else if (position < -MaxLongPowerOfTen) {
+      // |input| < 10^19, so dropping 19 or more digits always leaves 0. Handling it here also
+      // avoids overflowing 10^k in a Long and negating Int.MinValue.
+      0L
+    } else {
+      val pow = Iterator.fill(-position)(10L).product
+      (input / pow) * pow
+    }
+  }
+
+  /**
+   * Truncates a float or double by converting it with `java.math.BigDecimal.valueOf` and
+   * delegating to the decimal overload. `BigDecimal.valueOf` throws `NumberFormatException` for
+   * NaN or infinite input, so callers must handle those values themselves.
+   */
+  def trunc(input: Double, position: Int): BigDecimal = {
+    trunc(jm.BigDecimal.valueOf(input), position)
+  }
+
+  /**
+   * Truncates a decimal toward zero so that at most `position` fractional digits remain.
+   *
+   * @param input
+   *   value to truncate.
+   * @param position
+   *   positive: keep that many fractional digits (123.456 at 2 gives 123.45, scale 2);
+   *   zero: keep only the integral part (123.456 gives 123); negative: also zero `-position`
+   *   integral digits (123.456 at -2 gives 100). Results for `position <= 0` have scale 0.
+   * @return
+   *   `input` itself when it has fewer than `position` fractional digits, otherwise the
+   *   truncated value.
+   */
+  def trunc(input: jm.BigDecimal, position: Int): jm.BigDecimal = {
+    if (input.scale < position) {
+      input
+    } else if (position < 0 && -position.toLong >= input.precision.toLong - input.scale) {
+      // Every integral digit is dropped. Short-circuiting also keeps setScale from building a
+      // huge power of ten (or overflowing the scale) for very negative positions.
+      jm.BigDecimal.ZERO
+    } else {
+      // RoundingMode.DOWN discards digits toward zero (-7767.1160 at 2 gives -7767.11) and the
+      // result always has exactly `position` as its scale.
+      val truncated = input.setScale(position, jm.RoundingMode.DOWN)
+      // A negative position leaves a negative scale (-19087.1560 at -3 gives -1.9E+4); raising
+      // the scale to 0 only appends zeros, so it is exact and yields a plain whole number.
+      if (position < 0) truncated.setScale(0) else truncated
+    }
+  }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """_FUNC_(number[, position]) - Returns the number after truncated to the specified places.
+    An optional `position` parameter can be specified to truncate digits to the right of the decimal point.
+    If 0, it removes all the decimal values and returns only the integer.
+    If negative, the number is truncated to the left side of the decimal point.
+    There is an overloaded version of this function to truncate date values
+    _FUNC_(date, fmt) - Returns `date` with the time portion of the day truncated to the unit specified by the format model `fmt`.
+  """,
+  arguments = """
+    Arguments:
+      * number - number to be truncated
+      * position - number of decimal places up to which the given number is to be truncated
+    Arguments: To truncate date value:
+      * date - date value or valid date string
+      * fmt - the format representing the unit to be truncated to
+          - "YEAR", "YYYY", "YY" - truncate to the first date of the year that the `date` falls in
+          - "QUARTER" - truncate to the first date of the quarter that the `date` falls in
+          - "MONTH", "MM", "MON" - truncate to the first date of the month that the `date` falls in
+          - "WEEK" - truncate to the Monday of the week that the `date` falls in
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('2019-08-04', 'week');
+       2019-07-29
+      > SELECT _FUNC_('2019-08-04', 'quarter');
+       2019-07-01
+      > SELECT _FUNC_('2009-02-12', 'MM');
+       2009-02-01
+      > SELECT _FUNC_('2015-10-27', 'YEAR');
+       2015-01-01
+      > SELECT _FUNC_(-10.11, 0);
+       -10
+      > SELECT _FUNC_(10.11, -1);
+       10
+      > SELECT _FUNC_(100.61, 0);
+       100
+      > SELECT _FUNC_(-19087.1560, -3);
+       -19000
+      > SELECT _FUNC_(10876.5489, -1);
+       10870
+      > SELECT _FUNC_(-7767.1160, 2);
+       -7767.11
+      > SELECT _FUNC_(17646.6019, 3);
+       17646.601
+  """,
+  since = "3.4.0",
+  group = "math_funcs")
+// scalastyle:on line.size.limit
+object TruncExpressionBuilder extends ExpressionBuilder {
+  override def build(funcName: String, expressions: Seq[Expression]): Expression = {
+    val numArgs = expressions.length
+    if (numArgs >= 1) {
+      expressions(0).dataType match {
+        case ByteType |  ShortType | IntegerType | LongType | FloatType | DoubleType
+             | DecimalType.Fixed(_, _) => buildTruncNumber(funcName, expressions)
+        case _ => buildTruncDate(funcName, expressions)
+      }
+    } else {
+      throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(2), funcName, numArgs)
+    }
+  }
+
+  private def buildTruncDate(funcName: String, expressions: Seq[Expression]) = {
+    val numArgs = expressions.length
+    if (numArgs == 2) {
+      TruncDate(expressions(0), expressions(1))
+    } else {
+      throw QueryCompilationErrors.invalidFunctionArgumentNumberError(Seq(2), funcName, numArgs)

Review Comment:
   This change does not seem to be reflected on GitHub; did you forget to push a commit?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org