Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/04/26 06:53:13 UTC

[GitHub] [spark] zhengruifeng commented on a change in pull request #32199: [SPARK-35100][ML] Refactor AFT - support virtual centering

zhengruifeng commented on a change in pull request #32199:
URL: https://github.com/apache/spark/pull/32199#discussion_r619753068



##########
File path: mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/AFTBlockAggregator.scala
##########
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.ml.optim.aggregator
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.feature._
+import org.apache.spark.ml.linalg._
+
+/**
+ * AFTBlockAggregator computes the gradient and loss used in AFT survival regression
+ * for blocks stored as sparse or dense matrices, in an online fashion.
+ *
+ * Two AFTBlockAggregators can be merged together to obtain a summary of the loss and
+ * gradient of the corresponding joint dataset.
+ *
+ * NOTE: The feature values are expected to have already been scaled (multiplied by
+ * bcInverseStd, but NOT centered) before computation.
+ *
+ * @param bcScaledMean The feature means scaled by the inverse standard deviations, used to
+ *                     center the data virtually when fitIntercept is true.
+ * @param fitIntercept Whether to fit an intercept term. When true, data centering is
+ *                     performed virtually, so the caller MUST adjust the intercept of both
+ *                     the initial coefficients and the final solution.
+ * @param bcCoefficients The coefficients: the linear coefficients for the features,
+ *                       followed by the intercept and log(sigma).
+ */
+private[ml] class AFTBlockAggregator (
+    bcScaledMean: Broadcast[Array[Double]],
+    fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector])
+  extends DifferentiableLossAggregator[InstanceBlock,
+    AFTBlockAggregator] {
+
+  protected override val dim: Int = bcCoefficients.value.size
+  private val numFeatures = dim - 2
+
+  @transient private lazy val coefficientsArray = bcCoefficients.value match {
+    case DenseVector(values) => values
+    case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" +
+      s" but got type ${bcCoefficients.value.getClass}.")
+  }
+
+  @transient private lazy val linear = Vectors.dense(coefficientsArray.take(numFeatures))
+
+  // Pre-computed margin of an all-zero vector (the intercept minus the dot product of the
+  // coefficients with the scaled means). With this value as an offset, only the non-zero
+  // values of a sparse vector need to be handled when computing predictions.
+  private val marginOffset = if (fitIntercept) {
+    coefficientsArray(dim - 2) -
+      BLAS.getBLAS(numFeatures).ddot(numFeatures, coefficientsArray, 1, bcScaledMean.value, 1)
+  } else {
+    Double.NaN
+  }
+
+  /**
+   * Add a new training instance block to this AFTBlockAggregator, and update the loss and
+   * gradient of the objective function.
+   *
+   * @return This AFTBlockAggregator object.
+   */
+  def add(block: InstanceBlock): this.type = {
+    require(block.matrix.isTransposed)
+    require(numFeatures == block.numFeatures, s"Dimensions mismatch when adding new " +
+      s"instance. Expecting $numFeatures but got ${block.numFeatures}.")
+    require(block.labels.forall(_ > 0.0), "The lifetime or label should be greater than 0.")
+
+    val size = block.size
+    val intercept = coefficientsArray(dim - 2)
+    // sigma is the scale parameter of the AFT model
+    val sigma = math.exp(coefficientsArray(dim - 1))
+
+    // vec/arr here represents margins
+    val vec = new DenseVector(Array.ofDim[Double](size))
+    val arr = vec.values
+    // Otherwise margins start from 0.0 and no intercept is added.
+    if (fitIntercept) java.util.Arrays.fill(arr, marginOffset)
+    BLAS.gemv(1.0, block.matrix, linear, 1.0, vec)
+
+    // in-place convert margins to gradient scales
+    // then, vec represents gradient scales
+    var localLossSum = 0.0
+    var i = 0
+    var sigmaGradSum = 0.0
+    var multiplierSum = 0.0
+    while (i < size) {
+      val ti = block.getLabel(i)
+      // Here Instance.weight is used to store the censor indicator, for convenience
+      val delta = block.getWeight(i)
+      val margin = arr(i)
+      val epsilon = (math.log(ti) - margin) / sigma
+      val expEpsilon = math.exp(epsilon)
+      localLossSum += delta * math.log(sigma) - delta * epsilon + expEpsilon
+      val multiplier = (delta - expEpsilon) / sigma
+      arr(i) = multiplier
+      multiplierSum += multiplier
+      sigmaGradSum += delta + multiplier * sigma * epsilon
+      i += 1
+    }
+    lossSum += localLossSum
+    weightSum += size
+
+    block.matrix match {
+      case dm: DenseMatrix =>
+        BLAS.nativeBLAS.dgemv("N", dm.numCols, dm.numRows, 1.0, dm.values, dm.numCols,
+          arr, 1, 1.0, gradientSumArray, 1)
+
+      case sm: SparseMatrix =>
+        val linearGradSumVec = new DenseVector(Array.ofDim[Double](numFeatures))
+        BLAS.gemv(1.0, sm.transpose, vec, 0.0, linearGradSumVec)
+        BLAS.javaBLAS.daxpy(numFeatures, 1.0, linearGradSumVec.values, 1,
+          gradientSumArray, 1)
+    }
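
The virtual centering described in the AFTBlockAggregator doc comment can be checked with a small standalone sketch. It assumes plain Double arrays in place of Spark's broadcast variables and BLAS, and every name in it (VirtualCenteringSketch, marginOffset, centeredMargin, sparseMargin, toCenteredIntercept, centeredLinearGradient) is hypothetical. It shows that starting each margin at the intercept minus w . scaledMean and then adding only the non-zero products gives the same margin as explicitly centering the scaled features. It also shows that the linear gradient on centered data equals X^T * m minus multiplierSum times scaledMean; the quoted hunk ends before any such correction, so whether the PR applies it in exactly this form is an assumption here.

```scala
// Hypothetical sketch of virtual centering; names are illustrative, not Spark APIs.
// Features are assumed to be already scaled (x * inverseStd) but not centered, and
// scaledMean holds the correspondingly scaled feature means.
object VirtualCenteringSketch {

  def dot(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, "dimension mismatch")
    var s = 0.0
    var i = 0
    while (i < a.length) {
      s += a(i) * b(i)
      i += 1
    }
    s
  }

  // Margin of an all-zero vector: the centered-space intercept minus w . scaledMean.
  def marginOffset(w: Array[Double], centeredIntercept: Double,
      scaledMean: Array[Double]): Double =
    centeredIntercept - dot(w, scaledMean)

  // Reference: margin computed on explicitly centered features, b_c + w . (x - mu).
  def centeredMargin(w: Array[Double], centeredIntercept: Double,
      scaledMean: Array[Double], x: Array[Double]): Double =
    centeredIntercept + dot(w, x.zip(scaledMean).map { case (xi, mi) => xi - mi })

  // Virtual centering: start from the offset and visit only the non-zeros of a sparse x.
  def sparseMargin(offset: Double, w: Array[Double],
      indices: Array[Int], values: Array[Double]): Double = {
    var m = offset
    var k = 0
    while (k < indices.length) {
      m += w(indices(k)) * values(k)
      k += 1
    }
    m
  }

  // Converts an intercept for uncentered scaled features into the centered space,
  // e.g. for the initial coefficients.
  def toCenteredIntercept(intercept: Double, w: Array[Double],
      scaledMean: Array[Double]): Double =
    intercept + dot(w, scaledMean)

  // Linear-part gradient on centered rows, sum_i ms(i) * (xs(i) - mu), computed as
  // X^T * ms minus (sum of ms) * mu, i.e. without centering any row.
  def centeredLinearGradient(xs: Array[Array[Double]], ms: Array[Double],
      mu: Array[Double]): Array[Double] = {
    val g = Array.ofDim[Double](mu.length)
    var i = 0
    while (i < xs.length) {
      var j = 0
      while (j < mu.length) {
        g(j) += ms(i) * xs(i)(j)
        j += 1
      }
      i += 1
    }
    val multiplierSum = ms.sum
    var j = 0
    while (j < mu.length) {
      g(j) -= multiplierSum * mu(j)
      j += 1
    }
    g
  }

  def main(args: Array[String]): Unit = {
    val w = Array(0.5, -1.0, 2.0)
    val mu = Array(1.0, 0.25, 0.5)
    val offset = marginOffset(w, 0.3, mu)
    // x = (0.0, 4.0, 0.0): only index 1 is non-zero.
    println(sparseMargin(offset, w, Array(1), Array(4.0)))
    println(centeredMargin(w, 0.3, mu, Array(0.0, 4.0, 0.0)))
  }
}
```

Under these assumptions, marginOffset is also exactly the intercept that applies to the scaled but uncentered features, since b_c + w . (x - mu) = (b_c - w . mu) + w . x. That is why the caller has to convert the intercept before and after optimization.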

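The while loop in add() turns each margin into a per-instance loss term and its derivatives. The sketch below recomputes the same three expressions (loss, multiplier, and the log(sigma) gradient term) for a single instance and compares the two derivatives against central finite differences of the loss. The names AftInstanceSketch, terms and numDeriv are hypothetical, and it assumes delta is 1.0 for an observed event and 0.0 for a censored one, matching the comment that Instance.weight carries the censor indicator. It illustrates that multiplier is the derivative of the loss with respect to the margin, and that delta + multiplier * sigma * epsilon is its derivative with respect to log(sigma).

```scala
// Hypothetical sketch mirroring one iteration of the loop in add(). Assumes
// delta = 1.0 for an observed event and 0.0 for a censored one, t > 0, and
// logSigma = log(sigma), i.e. the last coefficient.
object AftInstanceSketch {

  final case class Terms(loss: Double, multiplier: Double, sigmaGrad: Double)

  def terms(t: Double, delta: Double, margin: Double, logSigma: Double): Terms = {
    require(t > 0.0, "The lifetime or label should be greater than 0.")
    val sigma = math.exp(logSigma)
    val epsilon = (math.log(t) - margin) / sigma
    val expEpsilon = math.exp(epsilon)
    // Per-instance loss term, as in the quoted loop.
    val loss = delta * math.log(sigma) - delta * epsilon + expEpsilon
    // Derivative of loss with respect to the margin.
    val multiplier = (delta - expEpsilon) / sigma
    // Derivative of loss with respect to log(sigma).
    val sigmaGrad = delta + multiplier * sigma * epsilon
    Terms(loss, multiplier, sigmaGrad)
  }

  // Central finite-difference approximation of f'(x).
  def numDeriv(f: Double => Double, x: Double, h: Double = 1e-6): Double =
    (f(x + h) - f(x - h)) / (2 * h)

  def main(args: Array[String]): Unit = {
    val (t, delta, margin, logSigma) = (2.5, 1.0, 0.4, -0.2)
    val tr = terms(t, delta, margin, logSigma)
    val dMargin = numDeriv(m => terms(t, delta, m, logSigma).loss, margin)
    val dLogSigma = numDeriv(s => terms(t, delta, margin, s).loss, logSigma)
    println(s"multiplier = ${tr.multiplier}, finite difference = $dMargin")
    println(s"sigmaGrad  = ${tr.sigmaGrad}, finite difference = $dLogSigma")
  }
}
```

Because multiplier is the derivative with respect to the margin, the linear gradient for a block is X^T times the vector of multipliers, which is what the DenseMatrix and SparseMatrix branches accumulate into gradientSumArray.
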
Review comment:
       > any opportunities to refactor some common block aggregator code now that there are many of them?
   
   I think it is possible to simplify this by adding some BLAS utility functions in the future, so that we do not need to create temporary vectors/matrices/arrays just to match shapes. I only have an idea for now and will look into it later; if it works out, it will be a separate PR.
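
   The temporaries in question are visible in the SparseMatrix branch, which allocates linearGradSumVec only to receive X^T * vec before adding it into gradientSumArray. One possible shape for such utilities is sketched below: helpers that add alpha * X^T * v straight into a target array, for a row-major dense block or a CSR sparse block. The object name TransposedGemvSketch and the helpers denseRowMajorAccumulate and csrAccumulate are hypothetical and not part of Spark's BLAS. The sketch relies on the convention, as Spark's ml.linalg documentation describes transposed matrices, that a transposed DenseMatrix stores its values row-major and a transposed SparseMatrix behaves as CSR (colPtrs acting as row pointers, rowIndices as column indices).

```scala
// Hypothetical helpers, not existing Spark BLAS functions: accumulate
// out(0 until numCols) += alpha * X^T * v directly, without a temporary vector.
object TransposedGemvSketch {

  // X is numRows x numCols, dense and row-major: X(i, j) == values(i * numCols + j).
  def denseRowMajorAccumulate(numRows: Int, numCols: Int, values: Array[Double],
      alpha: Double, v: Array[Double], out: Array[Double]): Unit = {
    require(values.length == numRows * numCols, "values size mismatch")
    require(v.length == numRows && out.length >= numCols, "vector size mismatch")
    var i = 0
    while (i < numRows) {
      val a = alpha * v(i)
      val base = i * numCols
      var j = 0
      while (j < numCols) {
        out(j) += a * values(base + j)
        j += 1
      }
      i += 1
    }
  }

  // X is CSR: the non-zeros of row i are at positions rowPtrs(i) until rowPtrs(i + 1)
  // of colIndices / values.
  def csrAccumulate(numRows: Int, rowPtrs: Array[Int], colIndices: Array[Int],
      values: Array[Double], alpha: Double, v: Array[Double], out: Array[Double]): Unit = {
    require(rowPtrs.length == numRows + 1 && v.length == numRows, "size mismatch")
    var i = 0
    while (i < numRows) {
      val a = alpha * v(i)
      var k = rowPtrs(i)
      while (k < rowPtrs(i + 1)) {
        out(colIndices(k)) += a * values(k)
        k += 1
      }
      i += 1
    }
  }

  def main(args: Array[String]): Unit = {
    // X = [[1, 0, 2], [0, 3, 0]], v = (2, -1); the last slot of out stands in for
    // trailing entries (such as intercept and log(sigma)) that must stay untouched.
    val v = Array(2.0, -1.0)
    val denseOut = Array(1.0, 1.0, 1.0, 5.0)
    denseRowMajorAccumulate(2, 3, Array(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), 1.0, v, denseOut)
    val sparseOut = Array(1.0, 1.0, 1.0, 5.0)
    csrAccumulate(2, Array(0, 2, 3), Array(0, 2, 1), Array(1.0, 2.0, 3.0), 1.0, v, sparseOut)
    println(denseOut.mkString(", "))
    println(sparseOut.mkString(", "))
  }
}
```

   Because the target may be longer than numCols, a helper of this shape could write into the leading numFeatures entries of gradientSumArray and leave the trailing intercept and log(sigma) slots untouched, with no intermediate vector.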



