You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sa...@apache.org on 2023/07/04 06:46:13 UTC
[tvm] branch unity updated: [VM] Add repetition penalty functions to Relax VM (#15219)
This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 05278ea77b [VM] Add repetition penalty functions to Relax VM (#15219)
05278ea77b is described below
commit 05278ea77b24e6d6d0984589200af68af6c46002
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Mon Jul 3 23:46:07 2023 -0700
[VM] Add repetition penalty functions to Relax VM (#15219)
This PR brings the repetition penalty functions to the lm-support
part of Relax VM. The repetition penalty function is invoked at
LM sampling time and operates in-place.
Co-authored-by: Zihao Ye <ex...@outlook.com>
---
src/runtime/relax_vm/lm_support.cc | 48 ++++++++++
web/src/runtime.ts | 173 +++++++++++++++++++++----------------
2 files changed, 147 insertions(+), 74 deletions(-)
diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc
index bdee444608..39948ae522 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -424,6 +424,54 @@ int SampleTopPFromProb(NDArray prob, double top_p, double uniform_sample) {
TVM_REGISTER_GLOBAL("vm.builtin.sample_top_p_from_prob").set_body_typed(SampleTopPFromProb);
+/*!
+ * \brief Apply repetition penalty to the logits. This is an inplace operation.
+ *
+ * CTRL-style penalty: for every token id that has already appeared, a
+ * non-positive logit is multiplied by the penalty and a positive logit is
+ * divided by it, so a penalty > 1 lowers the chance of repeating the token.
+ * \param logits The float32 logits on CPU, penalized in place.
+ * \param token_ids The int32 ids of previously generated tokens, on CPU.
+ * \param penalty The penalty factor; expected to be positive.
+ */
+void ApplyRepetitionPenalty(NDArray logits, NDArray token_ids, double penalty) {
+  ICHECK(logits.IsContiguous());
+  ICHECK(token_ids.IsContiguous());
+  ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
+  ICHECK(token_ids.DataType() == DataType::Int(32)) << "token ids must be int32!";
+  ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!";
+  ICHECK(token_ids->device.device_type == kDLCPU) << "token_ids device must be CPU!";
+  float* logits_raw_data = static_cast<float*>(logits->data);
+  int* token_ids_data = static_cast<int*>(token_ids->data);
+  // DLTensor shapes are int64_t; keep the index type consistent with them.
+  int64_t vocab_size = logits->shape[logits->ndim - 1];
+  int64_t num_token_ids = token_ids->shape[token_ids->ndim - 1];
+  for (int64_t i = 0; i < num_token_ids; ++i) {
+    int token_id = token_ids_data[i];
+    // Guard against out-of-bounds writes from a corrupted/garbage token id.
+    ICHECK(0 <= token_id && token_id < vocab_size)
+        << "token id " << token_id << " out of vocab range " << vocab_size;
+    if (logits_raw_data[token_id] <= 0) {
+      // Non-positive logit: scale further away from zero.
+      logits_raw_data[token_id] *= penalty;
+    } else {
+      // Positive logit: scale towards zero.
+      logits_raw_data[token_id] /= penalty;
+    }
+  }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.apply_repetition_penalty").set_body_typed(ApplyRepetitionPenalty);
+
+/*!
+ * \brief Compute softmax(logits / temperature). This is an inplace operation.
+ *
+ * Uses a single-pass online max/denominator update for numerical stability,
+ * then normalizes in a second pass; logits is overwritten with probabilities.
+ * \param logits The float32 logits on CPU, overwritten in place.
+ * \param temperature The temperature factor; expected to be nonzero.
+ */
+void ApplySoftmaxWithTemperature(NDArray logits, double temperature) {
+  ICHECK(logits.IsContiguous());
+  ICHECK(logits.DataType() == DataType::Float(32)) << "Logits data type is not float32!";
+  ICHECK(logits->device.device_type == kDLCPU) << "logits device must be CPU!";
+  int vocab_size = logits->shape[logits->ndim - 1];
+  float* logits_raw_data = static_cast<float*>(logits->data);
+  float inv_temp = 1.0f / temperature;
+  // NOTE: std::numeric_limits<float>::min() is the smallest *positive* float,
+  // not the most negative one. Seeding the running max with it breaks the
+  // max-subtraction trick whenever all scaled logits are negative, and for
+  // strongly negative logits every exp() underflows, d becomes 0 and the
+  // output is NaN. lowest() is the correct -FLT_MAX seed.
+  float m = std::numeric_limits<float>::lowest();
+  double d = 0.0;
+  for (int i = 0; i < vocab_size; ++i) {
+    float x = logits_raw_data[i] * inv_temp;
+    float m_prev = m;
+    m = std::max(m, x);
+    // Online softmax update: rescale the running denominator when max grows.
+    d = d * std::exp(m_prev - m) + std::exp(x - m);
+  }
+  for (int i = 0; i < vocab_size; ++i) {
+    float x = logits_raw_data[i] * inv_temp;
+    logits_raw_data[i] = std::exp(x - m) / d;
+  }
+}
+
+TVM_REGISTER_GLOBAL("vm.builtin.apply_softmax_with_temperature")
+    .set_body_typed(ApplySoftmaxWithTemperature);
+
} // namespace relax_vm
} // namespace runtime
} // namespace tvm
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 0c615fd506..021bfceb1e 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -128,9 +128,9 @@ class FFILibrary implements Disposable {
throw new Error(
"Cannt detect wasm memory from imports " +
- imports +
- " or exports" +
- instance.exports
+ imports +
+ " or exports" +
+ instance.exports
);
}
}
@@ -140,21 +140,23 @@ class FFILibrary implements Disposable {
* Manages extra runtime context for the runtime.
*/
class RuntimeContext implements Disposable {
- arrayGetItem : PackedFunc;
- arrayGetSize : PackedFunc;
- arrayMake : PackedFunc;
- stringMake : PackedFunc;
- getFFIString : PackedFunc;
- getSysLib : PackedFunc;
- arrayCacheGet : PackedFunc;
- arrayCacheUpdate : PackedFunc;
- arrayCacheRemove : PackedFunc;
- arrayCacheClear : PackedFunc;
- arrayDecodeStorage : PackedFunc;
- paramModuleFromCache : PackedFunc;
- makeShapeTuple : PackedFunc;
- ndarrayCreateView : PackedFunc;
- sampleTopPFromLogits : PackedFunc;
+ arrayGetItem: PackedFunc;
+ arrayGetSize: PackedFunc;
+ arrayMake: PackedFunc;
+ stringMake: PackedFunc;
+ getFFIString: PackedFunc;
+ getSysLib: PackedFunc;
+ arrayCacheGet: PackedFunc;
+ arrayCacheUpdate: PackedFunc;
+ arrayCacheRemove: PackedFunc;
+ arrayCacheClear: PackedFunc;
+ arrayDecodeStorage: PackedFunc;
+ paramModuleFromCache: PackedFunc;
+ makeShapeTuple: PackedFunc;
+ ndarrayCreateView: PackedFunc;
+ sampleTopPFromLogits: PackedFunc;
+ applyRepetitionPenalty: PackedFunc;
+ applySoftmaxWithTemperature: PackedFunc;
private autoDisposeScope: Array<Array<Disposable | undefined>> = [];
@@ -174,6 +176,8 @@ class RuntimeContext implements Disposable {
this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits");
+ this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty");
+ this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
}
dispose(): void {
@@ -193,13 +197,15 @@ class RuntimeContext implements Disposable {
this.makeShapeTuple.dispose();
this.ndarrayCreateView.dispose();
this.sampleTopPFromLogits.dispose();
+ this.applyRepetitionPenalty.dispose();
+ this.applySoftmaxWithTemperature.dispose();
}
- beginScope() : void {
+ beginScope(): void {
this.autoDisposeScope.push([]);
}
- endScope() : void {
+ endScope(): void {
if (this.autoDisposeScope.length == 0) {
throw Error("tvm.endScope called when the stack is empty.");
}
@@ -296,7 +302,7 @@ class PackedFuncCell implements Disposable {
}
}
- getHandle(requireNotNull : boolean = true): Pointer {
+ getHandle(requireNotNull: boolean = true): Pointer {
if (requireNotNull && this.handle == 0) {
throw Error("PackedFunc has already been disposed");
}
@@ -491,29 +497,29 @@ export class NDArray implements Disposable {
* @param shape The shape of the view.
* @returns The new sliced ndarray.
*/
- view(shape: Array<number>) : NDArray {
+ view(shape: Array<number>): NDArray {
const shapeArray = shape.map((value) => new Scalar(value, "int"));
return this.ctx.ndarrayCreateView(this, this.ctx.makeShapeTuple(...shapeArray));
}
- /**
- * Get handle of ndarray, check it is not null.
- *
- * @param requireNotNull require handle is not null.
- * @returns The handle.
- */
- getHandle(requireNotNull : boolean = true): Pointer {
+ /**
+ * Get handle of ndarray, check it is not null.
+ *
+ * @param requireNotNull require handle is not null.
+ * @returns The handle.
+ */
+ getHandle(requireNotNull: boolean = true): Pointer {
if (requireNotNull && this.handle == 0) {
throw Error("NDArray has already been disposed");
}
return this.handle;
}
- /**
- * Get dataPtr of NDarray
- *
- * @returns The handle.
- */
+ /**
+ * Get dataPtr of NDarray
+ *
+ * @returns The handle.
+ */
getDataPtr(): Pointer {
if (this.handle == 0) {
throw Error("NDArray has already been disposed");
@@ -553,9 +559,9 @@ export class NDArray implements Disposable {
if (data.length != size) {
throw new Error(
"data size and shape mismatch data.length" +
- data.length +
- " vs " +
- size
+ data.length +
+ " vs " +
+ size
);
}
let buffer: ArrayBuffer;
@@ -703,7 +709,7 @@ export class Module implements Disposable {
* @param requireNotNull require handle is not null.
* @returns The handle.
*/
- getHandle(requireNotNull : boolean = true): Pointer {
+ getHandle(requireNotNull: boolean = true): Pointer {
if (requireNotNull && this.handle == 0) {
throw Error("Module has already been disposed");
}
@@ -733,7 +739,7 @@ export class Module implements Disposable {
(this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)(
this.getHandle(),
stack.ptrFromOffset(nameOffset),
- queryImports? 1 : 0,
+ queryImports ? 1 : 0,
outPtr
)
);
@@ -763,7 +769,7 @@ export class Module implements Disposable {
/**
* Generic object base
*/
- export class TVMObject implements Disposable {
+export class TVMObject implements Disposable {
private handle: Pointer;
private lib: FFILibrary;
protected ctx: RuntimeContext;
@@ -832,7 +838,7 @@ export class Module implements Disposable {
outPtr
)
);
- const result =this.lib.memory.loadCString(
+ const result = this.lib.memory.loadCString(
this.lib.memory.loadPointer(outPtr)
);
this.lib.recycleCallStack(stack);
@@ -859,7 +865,7 @@ export class TVMArray extends TVMObject {
/**
* @returns the size of the array.
*/
- size() : number {
+ size(): number {
return this.ctx.arrayGetSize(this) as number;
}
/**
@@ -867,7 +873,7 @@ export class TVMArray extends TVMObject {
* @param index the array index.
* @returns The element.
*/
- get(index : number) : TVMObjectBase {
+ get(index: number): TVMObjectBase {
return this.ctx.arrayGetItem(this, new Scalar(index, "int32")) as TVMObjectBase;
}
}
@@ -885,7 +891,7 @@ export class TVMString extends TVMObject {
/**
* @returns the size of the array.
*/
- toString() : string {
+ toString(): string {
return this.ctx.getFFIString(this) as string;
}
}
@@ -1086,7 +1092,7 @@ export class Instance implements Disposable {
* @number The number of times to compute the average.
* @repeat The number of times to repeat the run.
*/
- async benchmark(run: ()=>void, dev: DLDevice, number=10, repeat=1): Promise<number[]> {
+ async benchmark(run: () => void, dev: DLDevice, number = 10, repeat = 1): Promise<number[]> {
// Skip first run as it can involve GPU warmup and module loading time.
const perf = compact.getPerformance();
const results = [];
@@ -1152,7 +1158,7 @@ export class Instance implements Disposable {
* we will need to call {@link moveToParentScope}
* for the objects that are created in the scope.
*/
- withNewScope<T>(action: ()=>T): T {
+ withNewScope<T>(action: () => T): T {
this.beginScope();
const val = action();
this.endScope();
@@ -1168,7 +1174,7 @@ export class Instance implements Disposable {
* the current scope. You only need to do so when you call
* {@link detachFromCurrentScope} to create a detached object.
*/
- attachToCurrentScope<T extends Disposable>(obj: T) : T {
+ attachToCurrentScope<T extends Disposable>(obj: T): T {
return this.ctx.attachToCurrentScope(obj);
}
@@ -1181,7 +1187,7 @@ export class Instance implements Disposable {
* @param obj The object to be moved.
* @returns The input obj.
*/
- moveToParentScope<T extends Disposable>(obj: T) : T {
+ moveToParentScope<T extends Disposable>(obj: T): T {
return this.ctx.moveToParentScope(obj);
}
@@ -1340,12 +1346,12 @@ export class Instance implements Disposable {
return ret;
}
- /**
- * Setup a virtual machine module with given device.
- *
- * @param dev DLDevice the device.
- * @returns The created virtual machime.
- */
+ /**
+ * Setup a virtual machine module with given device.
+ *
+ * @param dev DLDevice the device.
+ * @returns The created virtual machine.
+ */
createVirtualMachine(dev: DLDevice): VirtualMachine {
const mod = this.ctx.detachFromCurrentScope(
this.systemLib().getFunction("vm_load_executable")()
@@ -1374,7 +1380,7 @@ export class Instance implements Disposable {
* @param numParams Number of parameters.
* @returns
*/
- getParamsFromCache(prefix: string, numParams: number) : TVMObject {
+ getParamsFromCache(prefix: string, numParams: number): TVMObject {
return (this.ctx.paramModuleFromCache(
prefix, new Scalar(numParams, "int32")) as Module).getFunction("get_params")();
}
@@ -1384,7 +1390,7 @@ export class Instance implements Disposable {
* @param name The name of array.
* @returns The result.
*/
- ndarrayCacheGet(name: string) : NDArray | undefined {
+ ndarrayCacheGet(name: string): NDArray | undefined {
return this.ctx.arrayCacheGet(name);
}
@@ -1393,7 +1399,7 @@ export class Instance implements Disposable {
* @param name The name of array.
* @returns The result.
*/
- ndarrayCacheRemove(name: string) : NDArray | undefined {
+ ndarrayCacheRemove(name: string): NDArray | undefined {
return this.ctx.arrayCacheRemove(name);
}
@@ -1466,10 +1472,10 @@ export class Instance implements Disposable {
let fetchedBytes = 0;
let timeElapsed = 0;
- const reportCallback = (iter: number)=> {
+ const reportCallback = (iter: number) => {
// report
for (let j = 0; j < this.initProgressCallback.length; ++j) {
- let text = "Fetching param cache[" + iter + "/" + list.length+ "]: ";
+ let text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. "
text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, "
text += timeElapsed + " secs elapsed.";
@@ -1702,6 +1708,25 @@ export class Instance implements Disposable {
return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random());
}
+ /**
+ * Apply repetition penalty to the logits.
+ *
+ * Note: this mutates `logits` in place via the VM builtin
+ * `vm.builtin.apply_repetition_penalty`; the builtin requires
+ * float32 logits and int32 token ids, both on CPU.
+ * @param logits The input logits before penalty; penalized in place.
+ * @param token_ids The appeared token ids.
+ * @param penalty The penalty factor.
+ */
+ applyRepetitionPenalty(logits: NDArray, token_ids: NDArray, penalty: number) {
+ return this.ctx.applyRepetitionPenalty(logits, token_ids, penalty);
+ }
+
+ /**
+ * Apply softmax with temperature to the logits.
+ *
+ * Note: this mutates `logits` in place via the VM builtin
+ * `vm.builtin.apply_softmax_with_temperature`, overwriting the
+ * logits with probabilities; the builtin requires float32 logits on CPU.
+ * @param logits The input logits before softmax w/ temperature.
+ * @param temperature The temperature factor.
+ */
+ applySoftmaxWithTemperature(logits: NDArray, temperature: number) {
+ return this.ctx.applySoftmaxWithTemperature(logits, temperature);
+ }
+
/**
* Bind canvas to the current WebGPU context
* @param canvas The canvas.
@@ -1750,7 +1775,7 @@ export class Instance implements Disposable {
* @param inputs The input array
* @returns The result array.
*/
- makeTVMArray(
+ makeTVMArray(
inputs: Array<TVMObjectBase>
): TVMArray {
return this.ctx.arrayMake(...inputs) as TVMArray;
@@ -1771,7 +1796,7 @@ export class Instance implements Disposable {
* @param shape The shape .
* @returns The created shape tuple.
*/
- makeShapeTuple(shape: Array<number>) : TVMObject {
+ makeShapeTuple(shape: Array<number>): TVMObject {
const shapeArray = shape.map((value) => new Scalar(value, "int"));
return this.ctx.makeShapeTuple(...shapeArray);
}
@@ -1782,7 +1807,7 @@ export class Instance implements Disposable {
*/
typeKey2Index(
typeKey: string
- ) : number {
+ ): number {
const stack = this.lib.getOrAllocCallStack();
const typeKeyOffset = stack.allocRawBytes(typeKey.length + 1);
stack.storeRawBytes(typeKeyOffset, StringToUint8Array(typeKey));
@@ -1895,7 +1920,7 @@ export class Instance implements Disposable {
// report
for (let j = 0; j < this.initProgressCallback.length; ++j) {
const progress = finishCounter / fmapEntries.length;
- let text = "Loading GPU shader modules[" + finishCounter + "/" + fmapEntries.length+ "]: ";
+ let text = "Loading GPU shader modules[" + finishCounter + "/" + fmapEntries.length + "]: ";
text += Math.floor(progress * 100).toString() + "% completed, "
text += timeElapsed + " secs elapsed.";
this.initProgressCallback[j]({
@@ -1905,7 +1930,7 @@ export class Instance implements Disposable {
});
}
});
- allEvents = Promise.all([allEvents, event]).then(()=>{});
+ allEvents = Promise.all([allEvents, event]).then(() => { });
}
await allEvents;
assert(finishCounter == fmapEntries.length);
@@ -1937,11 +1962,11 @@ export class Instance implements Disposable {
this.registerObjectConstructor("Array",
(handle: number, lib: FFILibrary, ctx: RuntimeContext) => {
return new TVMArray(handle, lib, ctx);
- });
+ });
this.registerObjectConstructor("runtime.String",
(handle: number, lib: FFILibrary, ctx: RuntimeContext) => {
return new TVMString(handle, lib, ctx);
- });
+ });
}
/** Register global packed functions needed by the backend to the env. */
@@ -1992,7 +2017,7 @@ export class Instance implements Disposable {
} while (durationMs < minRepeatMs && absoluteZeroTimes < limitZeroTimeIterations);
const speed = durationMs / setupNumber / 1000;
result.push(speed);
- if (cooldownIntervalMs > 0.0 && (i % repeatsToCooldown) == 0 ) {
+ if (cooldownIntervalMs > 0.0 && (i % repeatsToCooldown) == 0) {
await new Promise(r => setTimeout(r, cooldownIntervalMs));
}
}
@@ -2030,9 +2055,9 @@ export class Instance implements Disposable {
this.lib.checkCall(
(this.exports
.TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)(
- findex,
- outPtr
- )
+ findex,
+ outPtr
+ )
);
const ret = this.makePackedFunc(this.memory.loadPointer(outPtr));
this.lib.recycleCallStack(stack);
@@ -2236,9 +2261,9 @@ export class Instance implements Disposable {
return this.memory.loadI64(rvaluePtr);
case ArgTypeCode.Float:
return this.memory.loadF64(rvaluePtr);
- case ArgTypeCode.TVMOpaqueHandle: {
- return this.memory.loadPointer(rvaluePtr);
- }
+ case ArgTypeCode.TVMOpaqueHandle: {
+ return this.memory.loadPointer(rvaluePtr);
+ }
case ArgTypeCode.TVMNDArrayHandle: {
return this.ctx.attachToCurrentScope(
new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib, this.ctx)
@@ -2256,7 +2281,7 @@ export class Instance implements Disposable {
}
case ArgTypeCode.TVMModuleHandle: {
return this.ctx.attachToCurrentScope(
- new Module(
+ new Module(
this.memory.loadPointer(rvaluePtr),
this.lib,
(ptr: Pointer) => {