You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sa...@apache.org on 2023/07/04 06:46:13 UTC

[tvm] branch unity updated: [VM] Add repetition penalty functions to Relax VM (#15219)

This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 05278ea77b [VM] Add repetition penalty functions to Relax VM (#15219)
05278ea77b is described below

commit 05278ea77b24e6d6d0984589200af68af6c46002
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Mon Jul 3 23:46:07 2023 -0700

    [VM] Add repetition penalty functions to Relax VM (#15219)
    
    This PR adds repetition-penalty and softmax-with-temperature functions
    to the lm-support part of Relax VM. Both functions are invoked at LM
    sampling time and operate in place on the logits.
    
    Co-authored-by: Zihao Ye <ex...@outlook.com>
---
 src/runtime/relax_vm/lm_support.cc |  48 ++++++++++
 web/src/runtime.ts                 | 173 +++++++++++++++++++++----------------
 2 files changed, 147 insertions(+), 74 deletions(-)

diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index bdee444608..39948ae522 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -424,6 +424,54 @@ int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) {
 
 TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb);
 
+// Apply a repetition penalty to `logits` in place.
+//
+// For each id in `token_ids`, the logit at that index is multiplied by `penalty` when it is
+// <= 0 and divided by `penalty` when it is > 0.
+//
+// \param logits Contiguous float32 CPU tensor, modified in place. Token ids index directly into
+//        its data buffer and must be smaller than its last dimension (assumed to be the
+//        vocabulary size -- TODO confirm callers never pass batched logits).
+// \param token_ids Contiguous int32 CPU tensor; its first shape[ndim - 1] entries are read.
+// \param penalty The penalty factor.
+//
+// NOTE(review): an id that occurs several times in `token_ids` is penalized once per
+// occurrence; callers presumably pass de-duplicated ids -- verify.
+void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) {
+  ICHECK(logits.IsContiguous());
+  ICHECK(token_ids.IsContiguous());
+  ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
+  ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!";
+  ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!";
+  ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!";
+  float* logits_raw_data = static_cast<float*>(logits->data);
+  int* token_ids_data = static_cast<int*>(token_ids->data);
+  int64_t vocab_size = logits->shape[logits->ndim - 1];
+  int64_t num_token_ids = token_ids->shape[token_ids->ndim - 1];
+  for (int64_t i = 0; i < num_token_ids; ++i) {
+    int token_id = token_ids_data[i];
+    // Without this check an out-of-range id would read and write outside the logits buffer.
+    ICHECK(token_id >= 0 && token_id < vocab_size)
+        << "token id " << token_id << " is out of range [0, " << vocab_size << ")!";
+    if (logits_raw_data[token_id] <= 0) {
+      logits_raw_data[token_id] *= penalty;
+    } else {  // logits > 0
+      logits_raw_data[token_id] /= penalty;
+    }
+  }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.apply_repetition_penalty").set_body_typed(ApplyRepetitionPenalty);
+
+// Replace `logits` in place with softmax(logits / temperature).
+//
+// Single-pass "online softmax": the running maximum `m` and the running normalizer
+// `d` = sum(exp(x - m)) are updated together, so every exp() argument is <= 0 and cannot
+// overflow. A second pass then writes exp(x - m) / d into the buffer.
+//
+// \param logits Contiguous float32 CPU tensor, modified in place. Its first shape[ndim - 1]
+//        entries are normalized (assumed to be a single row of vocabulary logits -- TODO
+//        confirm callers never pass batched logits).
+// \param temperature Positive divisor applied to every logit before the softmax.
+void ApplySoftmaxWithTemperature(NDArray logits, double temperature) {
+  ICHECK(logits.IsContiguous());
+  ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
+  ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!";
+  // A zero (or negative) temperature makes inv_temp infinite (or flips the distribution);
+  // zero in particular fills the output with NaN.
+  ICHECK_GT(temperature, 0.0) << "temperature must be positive!";
+  int64_t vocab_size = logits->shape[logits->ndim - 1];
+  float* logits_raw_data = static_cast<float*>(logits->data);
+  float inv_temp = static_cast<float>(1.0 / temperature);
+  // Seed the running max with the most negative finite float. numeric_limits<float>::min()
+  // would be wrong here: it is the smallest *positive* normal float, so for an all-negative
+  // row m would stay > 0, every exp(x - m) could underflow to 0, and the result would be
+  // 0 / 0 = NaN. A finite seed (rather than -infinity) also avoids exp(-inf - -inf) = NaN
+  // when a logit is -infinity.
+  float m = std::numeric_limits<float>::lowest();
+  double d = 0.0;
+  for (int64_t i = 0; i < vocab_size; ++i) {
+    float x = logits_raw_data[i] * inv_temp;
+    float m_prev = m;
+    m = std::max(m, x);
+    // Rescale the old normalizer to the new max, then add this element's contribution.
+    d = d * std::exp(m_prev - m) + std::exp(x - m);
+  }
+  for (int64_t i = 0; i < vocab_size; ++i) {
+    float x = logits_raw_data[i] * inv_temp;
+    logits_raw_data[i] = std::exp(x - m) / d;
+  }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.apply_softmax_with_temperature")
+    .set_body_typed(ApplySoftmaxWithTemperature);
+
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 0c615fd506..021bfceb1e 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -128,9 +128,9 @@ class FFILibrary implements Disposable {
 
     throw new Error(
       "Cannt detect wasm memory from imports " +
-        imports +
-        " or exports" +
-        instance.exports
+      imports +
+      " or exports" +
+      instance.exports
     );
   }
 }
@@ -140,21 +140,23 @@ class FFILibrary implements Disposable {
  * Manages extra runtime context for the runtime.
  */
 class RuntimeContext implements Disposable {
-  arrayGetItem : PackedFunc;
-  arrayGetSize : PackedFunc;
-  arrayMake : PackedFunc;
-  stringMake : PackedFunc;
-  getFFIString : PackedFunc;
-  getSysLib : PackedFunc;
-  arrayCacheGet : PackedFunc;
-  arrayCacheUpdate : PackedFunc;
-  arrayCacheRemove : PackedFunc;
-  arrayCacheClear : PackedFunc;
-  arrayDecodeStorage : PackedFunc;
-  paramModuleFromCache : PackedFunc;
-  makeShapeTuple : PackedFunc;
-  ndarrayCreateView : PackedFunc;
-  sampleTopPFromLogits : PackedFunc;
+  arrayGetItem: PackedFunc;
+  arrayGetSize: PackedFunc;
+  arrayMake: PackedFunc;
+  stringMake: PackedFunc;
+  getFFIString: PackedFunc;
+  getSysLib: PackedFunc;
+  arrayCacheGet: PackedFunc;
+  arrayCacheUpdate: PackedFunc;
+  arrayCacheRemove: PackedFunc;
+  arrayCacheClear: PackedFunc;
+  arrayDecodeStorage: PackedFunc;
+  paramModuleFromCache: PackedFunc;
+  makeShapeTuple: PackedFunc;
+  ndarrayCreateView: PackedFunc;
+  sampleTopPFromLogits: PackedFunc;
+  applyRepetitionPenalty: PackedFunc;
+  applySoftmaxWithTemperature: PackedFunc;
 
   private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
 
@@ -174,6 +176,8 @@ class RuntimeContext implements Disposable {
     this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
     this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
     this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits");
+    this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty");
+    this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
   }
 
   dispose(): void {
@@ -193,13 +197,15 @@ class RuntimeContext implements Disposable {
     this.makeShapeTuple.dispose();
     this.ndarrayCreateView.dispose();
     this.sampleTopPFromLogits.dispose();
+    this.applyRepetitionPenalty.dispose();
+    this.applySoftmaxWithTemperature.dispose();
   }
 
-  beginScope() : void {
+  beginScope(): void {
     this.autoDisposeScope.push([]);
   }
 
-  endScope() : void {
+  endScope(): void {
     if (this.autoDisposeScope.length == 0) {
       throw Error("tvm.endScope called when the stack is empty.");
     }
@@ -296,7 +302,7 @@ class PackedFuncCell implements Disposable {
     }
   }
 
-  getHandle(requireNotNull : boolean = true): Pointer {
+  getHandle(requireNotNull: boolean = true): Pointer {
     if (requireNotNull && this.handle == 0) {
       throw Error("PackedFunc has already been disposed");
     }
@@ -491,29 +497,29 @@ export class NDArray implements Disposable {
    * @param shape The shape of the view.
    * @returns The new sliced ndarray.
    */
-  view(shape: Array<number>) : NDArray {
+  view(shape: Array<number>): NDArray {
     const shapeArray = shape.map((value) => new Scalar(value, "int"));
     return this.ctx.ndarrayCreateView(this, this.ctx.makeShapeTuple(...shapeArray));
   }
 
- /**
-  * Get handle of ndarray, check it is not null.
-  *
-  * @param requireNotNull require handle is not null.
-  * @returns The handle.
-  */
-  getHandle(requireNotNull : boolean = true): Pointer {
+  /**
+   * Get handle of ndarray, check it is not null.
+   *
+   * @param requireNotNull require handle is not null.
+   * @returns The handle.
+   */
+  getHandle(requireNotNull: boolean = true): Pointer {
     if (requireNotNull && this.handle == 0) {
       throw Error("NDArray has already been disposed");
     }
     return this.handle;
   }
 
- /**
-  * Get dataPtr of NDarray
-  *
-  * @returns The handle.
-  */
+  /**
+   * Get dataPtr of NDarray
+   *
+   * @returns The handle.
+   */
   getDataPtr(): Pointer {
     if (this.handle == 0) {
       throw Error("NDArray has already been disposed");
@@ -553,9 +559,9 @@ export class NDArray implements Disposable {
       if (data.length != size) {
         throw new Error(
           "data size and shape mismatch data.length" +
-            data.length +
-            " vs " +
-            size
+          data.length +
+          " vs " +
+          size
         );
       }
       let buffer: ArrayBuffer;
@@ -703,7 +709,7 @@ export class Module implements Disposable {
    * @param requireNotNull require handle is not null.
    * @returns The handle.
    */
-  getHandle(requireNotNull : boolean = true): Pointer {
+  getHandle(requireNotNull: boolean = true): Pointer {
     if (requireNotNull && this.handle == 0) {
       throw Error("Module has already been disposed");
     }
@@ -733,7 +739,7 @@ export class Module implements Disposable {
       (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)(
         this.getHandle(),
         stack.ptrFromOffset(nameOffset),
-        queryImports? 1 : 0,
+        queryImports ? 1 : 0,
         outPtr
       )
     );
@@ -763,7 +769,7 @@ export class Module implements Disposable {
 /**
  * Generic object base
  */
- export class TVMObject implements Disposable {
+export class TVMObject implements Disposable {
   private handle: Pointer;
   private lib: FFILibrary;
   protected ctx: RuntimeContext;
@@ -832,7 +838,7 @@ export class Module implements Disposable {
         outPtr
       )
     );
-    const result =this.lib.memory.loadCString(
+    const result = this.lib.memory.loadCString(
       this.lib.memory.loadPointer(outPtr)
     );
     this.lib.recycleCallStack(stack);
@@ -859,7 +865,7 @@ export class TVMArray extends TVMObject {
   /**
    * @returns the size of the array.
    */
-  size() : number {
+  size(): number {
     return this.ctx.arrayGetSize(this) as number;
   }
   /**
@@ -867,7 +873,7 @@ export class TVMArray extends TVMObject {
    * @param index the array index.
    * @returns The element.
    */
-  get(index : number) : TVMObjectBase {
+  get(index: number): TVMObjectBase {
     return this.ctx.arrayGetItem(this, new Scalar(index, "int32")) as TVMObjectBase;
   }
 }
@@ -885,7 +891,7 @@ export class TVMString extends TVMObject {
   /**
    * @returns the size of the array.
    */
-  toString() : string {
+  toString(): string {
     return this.ctx.getFFIString(this) as string;
   }
 }
@@ -1086,7 +1092,7 @@ export class Instance implements Disposable {
    * @number The number of times to compute the average.
    * @repeat The number of times to repeat the run.
    */
-  async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=1): Promise<number[]> {
+  async benchmark(run: () => void, dev: DLDevice, number = 10, repeat = 1): Promise<number[]> {
     // Skip first run as it can involve GPU warmup and module loading time.
     const perf = compact.getPerformance();
     const results = [];
@@ -1152,7 +1158,7 @@ export class Instance implements Disposable {
    *       we will need to call {@link moveToParentScope}
    *       for the objects that are created in the scope.
    */
-  withNewScope<T>(action: ()=>T): T {
+  withNewScope<T>(action: () => T): T {
     this.beginScope();
     const val = action();
     this.endScope();
@@ -1168,7 +1174,7 @@ export class Instance implements Disposable {
    *       the current scope. You only need to do so when you call
    *       {@link detachFromCurrentScope} to create a detached object.
    */
-  attachToCurrentScope<T extends Disposable>(obj: T) : T {
+  attachToCurrentScope<T extends Disposable>(obj: T): T {
     return this.ctx.attachToCurrentScope(obj);
   }
 
@@ -1181,7 +1187,7 @@ export class Instance implements Disposable {
    * @param obj The object to be moved.
    * @returns The input obj.
    */
-  moveToParentScope<T extends Disposable>(obj: T) : T {
+  moveToParentScope<T extends Disposable>(obj: T): T {
     return this.ctx.moveToParentScope(obj);
   }
 
@@ -1340,12 +1346,12 @@ export class Instance implements Disposable {
     return ret;
   }
 
-   /**
-   * Setup a virtual machine module with given device.
-   *
-   * @param dev DLDevice the device.
-   * @returns The created virtual machime.
-   */
+  /**
+  * Setup a virtual machine module with given device.
+  *
+  * @param dev DLDevice the device.
+  * @returns The created virtual machime.
+  */
   createVirtualMachine(dev: DLDevice): VirtualMachine {
     const mod = this.ctx.detachFromCurrentScope(
       this.systemLib().getFunction("vm_load_executable")()
@@ -1374,7 +1380,7 @@ export class Instance implements Disposable {
    * @param numParams  Number of parameters.
    * @returns
    */
-  getParamsFromCache(prefix: string, numParams: number) : TVMObject {
+  getParamsFromCache(prefix: string, numParams: number): TVMObject {
     return (this.ctx.paramModuleFromCache(
       prefix, new Scalar(numParams, "int32")) as Module).getFunction("get_params")();
   }
@@ -1384,7 +1390,7 @@ export class Instance implements Disposable {
    * @param name  The name of array.
    * @returns  The result.
    */
-  ndarrayCacheGet(name: string) : NDArray | undefined {
+  ndarrayCacheGet(name: string): NDArray | undefined {
     return this.ctx.arrayCacheGet(name);
   }
 
@@ -1393,7 +1399,7 @@ export class Instance implements Disposable {
    * @param name  The name of array.
    * @returns  The result.
    */
-  ndarrayCacheRemove(name: string) : NDArray | undefined {
+  ndarrayCacheRemove(name: string): NDArray | undefined {
     return this.ctx.arrayCacheRemove(name);
   }
 
@@ -1466,10 +1472,10 @@ export class Instance implements Disposable {
     let fetchedBytes = 0;
     let timeElapsed = 0;
 
-    const reportCallback = (iter: number)=> {
+    const reportCallback = (iter: number) => {
       // report
       for (let j = 0; j < this.initProgressCallback.length; ++j) {
-        let text = "Fetching param cache[" + iter + "/" + list.length+ "]: ";
+        let text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
         text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. "
         text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, "
         text += timeElapsed + " secs elapsed.";
@@ -1702,6 +1708,25 @@ export class Instance implements Disposable {
     return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random());
   }
 
+  /**
+   * Apply repetition penalty to the logits, in place.
+   *
+   * For each id in `token_ids`, the corresponding logit is divided by
+   * `penalty` if it is positive and multiplied by `penalty` otherwise.
+   * `logits` is overwritten directly; the backend function
+   * (`vm.builtin.apply_repetition_penalty`) returns nothing, so read the
+   * result from `logits`, not from the return value.
+   *
+   * @param logits The logits to penalize, modified in place. The backend
+   *   requires a contiguous float32 array on the CPU.
+   * @param token_ids The appeared token ids. The backend requires a
+   *   contiguous int32 array on the CPU.
+   * @param penalty The penalty factor.
+   */
+  applyRepetitionPenalty(logits: NDArray, token_ids: NDArray, penalty: number) {
+    return this.ctx.applyRepetitionPenalty(logits, token_ids, penalty);
+  }
+
+  /**
+   * Apply softmax with temperature to the logits, in place.
+   *
+   * Each logit is divided by `temperature` and the softmax of the result
+   * is written back into `logits`. The backend function
+   * (`vm.builtin.apply_softmax_with_temperature`) returns nothing, so read
+   * the probabilities from `logits`, not from the return value.
+   *
+   * @param logits The logits to normalize; overwritten with probabilities.
+   *   The backend requires a contiguous float32 array on the CPU.
+   * @param temperature The temperature factor.
+   */
+  applySoftmaxWithTemperature(logits: NDArray, temperature: number) {
+    return this.ctx.applySoftmaxWithTemperature(logits, temperature);
+  }
+
   /**
    * Bind canvas to the current WebGPU context
    * @param canvas The canvas.
@@ -1750,7 +1775,7 @@ export class Instance implements Disposable {
    * @param inputs The input array
    * @returns The result array.
    */
-   makeTVMArray(
+  makeTVMArray(
     inputs: Array<TVMObjectBase>
   ): TVMArray {
     return this.ctx.arrayMake(...inputs) as TVMArray;
@@ -1771,7 +1796,7 @@ export class Instance implements Disposable {
    * @param shape The shape .
    * @returns The created shape tuple.
    */
-  makeShapeTuple(shape: Array<number>) : TVMObject {
+  makeShapeTuple(shape: Array<number>): TVMObject {
     const shapeArray = shape.map((value) => new Scalar(value, "int"));
     return this.ctx.makeShapeTuple(...shapeArray);
   }
@@ -1782,7 +1807,7 @@ export class Instance implements Disposable {
    */
   typeKey2Index(
     typeKey: string
-  ) : number {
+  ): number {
     const stack = this.lib.getOrAllocCallStack();
     const typeKeyOffset = stack.allocRawBytes(typeKey.length + 1);
     stack.storeRawBytes(typeKeyOffset, StringToUint8Array(typeKey));
@@ -1895,7 +1920,7 @@ export class Instance implements Disposable {
         // report
         for (let j = 0; j < this.initProgressCallback.length; ++j) {
           const progress = finishCounter / fmapEntries.length;
-          let text = "Loading GPU shader modules[" + finishCounter + "/" + fmapEntries.length+ "]: ";
+          let text = "Loading GPU shader modules[" + finishCounter + "/" + fmapEntries.length + "]: ";
           text += Math.floor(progress * 100).toString() + "% completed, "
           text += timeElapsed + " secs elapsed.";
           this.initProgressCallback[j]({
@@ -1905,7 +1930,7 @@ export class Instance implements Disposable {
           });
         }
       });
-      allEvents = Promise.all([allEvents, event]).then(()=>{});
+      allEvents = Promise.all([allEvents, event]).then(() => { });
     }
     await allEvents;
     assert(finishCounter == fmapEntries.length);
@@ -1937,11 +1962,11 @@ export class Instance implements Disposable {
     this.registerObjectConstructor("Array",
       (handle: number, lib: FFILibrary, ctx: RuntimeContext) => {
         return new TVMArray(handle, lib, ctx);
-    });
+      });
     this.registerObjectConstructor("runtime.String",
       (handle: number, lib: FFILibrary, ctx: RuntimeContext) => {
         return new TVMString(handle, lib, ctx);
-    });
+      });
   }
 
   /** Register global packed functions needed by the backend to the env. */
@@ -1992,7 +2017,7 @@ export class Instance implements Disposable {
         } while (durationMs < minRepeatMs && absoluteZeroTimes < limitZeroTimeIterations);
         const speed = durationMs / setupNumber / 1000;
         result.push(speed);
-        if (cooldownIntervalMs > 0.0 && (i % repeatsToCooldown) == 0 ) {
+        if (cooldownIntervalMs > 0.0 && (i % repeatsToCooldown) == 0) {
           await new Promise(r => setTimeout(r, cooldownIntervalMs));
         }
       }
@@ -2030,9 +2055,9 @@ export class Instance implements Disposable {
     this.lib.checkCall(
       (this.exports
         .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)(
-        findex,
-        outPtr
-      )
+          findex,
+          outPtr
+        )
     );
     const ret = this.makePackedFunc(this.memory.loadPointer(outPtr));
     this.lib.recycleCallStack(stack);
@@ -2236,9 +2261,9 @@ export class Instance implements Disposable {
         return this.memory.loadI64(rvaluePtr);
       case ArgTypeCode.Float:
         return this.memory.loadF64(rvaluePtr);
-        case ArgTypeCode.TVMOpaqueHandle: {
-          return this.memory.loadPointer(rvaluePtr);
-        }
+      case ArgTypeCode.TVMOpaqueHandle: {
+        return this.memory.loadPointer(rvaluePtr);
+      }
       case ArgTypeCode.TVMNDArrayHandle: {
         return this.ctx.attachToCurrentScope(
           new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib, this.ctx)
@@ -2256,7 +2281,7 @@ export class Instance implements Disposable {
       }
       case ArgTypeCode.TVMModuleHandle: {
         return this.ctx.attachToCurrentScope(
-            new Module(
+          new Module(
             this.memory.loadPointer(rvaluePtr),
             this.lib,
             (ptr: Pointer) => {