This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 45df1247c6 [Web] Implement linear congruential generator, make runtime
seedable (#16722)
45df1247c6 is described below
commit 45df1247c66adf117d6a690aea3f51e3c1bd0453
Author: Charlie Ruan <[email protected]>
AuthorDate: Fri Mar 15 10:10:34 2024 -0400
[Web] Implement linear congruential generator, make runtime seedable
(#16722)
This PR implements `LinearCongruentialGenerator` in TVMjs,
following the C++ counterpart in https://github.com/apache/tvm/pull/8642/.
The motivation is that we want to seed autoregressive generation
to make results reproducible, supporting the OpenAI field `seed`.
The main function is `nextInt()`, which generates an integer in the range
`(0, 2^31 - 1)`, exclusive of both endpoints.
Subsequently, we change all `Math.random()` in `runtime.ts` to
`this.rng.randomFloat()`, exposing API `Instance.setSeed()`.
Unit tests are added for `LinearCongruentialGenerator` to check seeding
behavior and output coverage.
---
web/src/index.ts | 2 +-
web/src/runtime.ts | 17 ++++++--
web/src/support.ts | 76 +++++++++++++++++++++++++++++++++
web/tests/node/test_random_generator.js | 71 ++++++++++++++++++++++++++++++
4 files changed, 161 insertions(+), 5 deletions(-)
diff --git a/web/src/index.ts b/web/src/index.ts
index 9099d8f373..edc695978f 100644
--- a/web/src/index.ts
+++ b/web/src/index.ts
@@ -26,7 +26,7 @@ export {
} from "./runtime";
export { Disposable, LibraryProvider } from "./types";
export { RPCServer } from "./rpc_server";
-export { wasmPath } from "./support";
+export { wasmPath, LinearCongruentialGenerator } from "./support";
export { detectGPUDevice, GPUDeviceDetectOutput } from "./webgpu";
export { assert } from "./support";
export { createPolyfillWASI } from "./compact";
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index ea022d1b3e..9142571b9e 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -23,7 +23,7 @@
import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes";
import { Disposable } from "./types";
import { Memory, CachedCallStack } from "./memory";
-import { assert, StringToUint8Array } from "./support";
+import { assert, StringToUint8Array, LinearCongruentialGenerator } from
"./support";
import { Environment } from "./environment";
import { AsyncifyHandler } from "./asyncify";
import { FunctionInfo, WebGPUContext } from "./webgpu";
@@ -1079,6 +1079,7 @@ export class Instance implements Disposable {
private ctx: RuntimeContext;
private asyncifyHandler: AsyncifyHandler;
private initProgressCallback: Array<InitProgressCallback> = [];
+ private rng: LinearCongruentialGenerator;
/**
* Internal function(registered by the runtime)
@@ -1131,6 +1132,7 @@ export class Instance implements Disposable {
);
this.registerEnvGlobalPackedFuncs();
this.registerObjectFactoryFuncs();
+ this.rng = new LinearCongruentialGenerator();
}
/**
@@ -1811,11 +1813,18 @@ export class Instance implements Disposable {
const scale = high - low;
const input = new Float32Array(size);
for (let i = 0; i < input.length; ++i) {
- input[i] = low + Math.random() * scale;
+ input[i] = low + this.rng.randomFloat() * scale;
}
return ret.copyFrom(input);
}
+  /**
+   * Reseed the generator that supplies this instance's uniform random draws
+   * (used by e.g. `sampleTopPFromLogits` and `sampleTopPFromProb`).
+   *
+   * @param seed Integer seed, forwarded to `LinearCongruentialGenerator.setSeed`,
+   *   which throws if it is not an integer.
+   */
+  setSeed(seed: number): void {
+    const generator: LinearCongruentialGenerator = this.rng;
+    generator.setSeed(seed);
+  }
+
/**
* Sample index via top-p sampling.
*
@@ -1825,7 +1834,7 @@ export class Instance implements Disposable {
* @returns The sampled index.
*/
 sampleTopPFromLogits(logits: NDArray, temperature: number, top_p: number): number {
- return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random());
+ // The uniform draw in (0, 1) now comes from the seedable `this.rng` (see `setSeed`)
+ // instead of `Math.random()`.
+ return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, this.rng.randomFloat());
 }
/**
@@ -1836,7 +1845,7 @@ export class Instance implements Disposable {
* @returns The sampled index.
*/
 sampleTopPFromProb(prob: NDArray, top_p: number): number {
- return this.ctx.sampleTopPFromProb(prob, top_p, Math.random());
+ // The uniform draw in (0, 1) now comes from the seedable `this.rng` (see `setSeed`)
+ // instead of `Math.random()`.
+ return this.ctx.sampleTopPFromProb(prob, top_p, this.rng.randomFloat());
 }
/**
diff --git a/web/src/support.ts b/web/src/support.ts
index b03fa363cd..2fa87ed291 100644
--- a/web/src/support.ts
+++ b/web/src/support.ts
@@ -74,3 +74,79 @@ export function assert(condition: boolean, msg?: string):
asserts condition {
export function wasmPath(): string {
return __dirname + "/wasm";
}
+
+/**
+ * Linear congruential generator for random number generation that can be seeded.
+ *
+ * Follows the implementation of `include/tvm/support/random_engine.h`, which follows the
+ * specification in https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine.
+ * The parameters (multiplier 48271, increment 0, modulus 2^31 - 1) match C++'s
+ * `std::minstd_rand`.
+ *
+ * Note `Number.MAX_SAFE_INTEGER = 2^53 - 1`, and our intermediates are strictly less than
+ * 2^47 (multiplier < 2^16, state < 2^31), so all arithmetic is exact in a double.
+ */
+
+export class LinearCongruentialGenerator {
+  readonly modulus: number;
+  readonly multiplier: number;
+  readonly increment: number;
+  // Always within the range (0, 2^31 - 1) non-inclusive; if 0, will forever generate 0.
+  private rand_state: number;
+
+  /**
+   * Set modulus, multiplier, and increment, then seed the generator.
+   *
+   * @param seed Optional integer seed (see `setSeed`). Defaults to `Date.now()`, so
+   *   generators constructed without a seed differ from run to run.
+   * @throws Error if `seed` is provided but is not an integer.
+   */
+  constructor(seed?: number) {
+    this.modulus = 2147483647; // 2^31 - 1 (a prime)
+    this.multiplier = 48271; // between 2^15 and 2^16
+    this.increment = 0;
+    // Definite assignment for the type checker; immediately overwritten by `setSeed`.
+    this.rand_state = 1;
+    this.setSeed(seed ?? Date.now());
+  }
+
+  /**
+   * Sets `rand_state` from `seed`, reduced into the valid range [1, modulus - 1].
+   *
+   * @param seed Any integer, including negative values. It is reduced modulo `modulus`
+   *   into [0, modulus); a result of 0 is replaced by 1 because a zero state would make
+   *   the generator emit 0 forever.
+   * @throws Error if `seed` is not an integer (this includes NaN and +/-Infinity).
+   *
+   * Postcondition: passes `checkRandState()`, i.e. 0 < rand_state < modulus and is an
+   * integer.
+   */
+  setSeed(seed: number): void {
+    if (!Number.isInteger(seed)) {
+      throw new Error("Seed should be an integer.");
+    }
+    // JavaScript's `%` takes the sign of the dividend, so a negative seed would yield a
+    // negative state; shift it back into [0, modulus) before use.
+    const reduced = ((seed % this.modulus) + this.modulus) % this.modulus;
+    this.rand_state = reduced === 0 ? 1 : reduced;
+    this.checkRandState();
+  }
+
+  /**
+   * Generate the next integer in the range (0, this.modulus) non-inclusive, updating
+   * `rand_state`.
+   *
+   * Postcondition: passes `checkRandState()`.
+   */
+  nextInt(): number {
+    // `intermediate` is always < 2^47, hence below `Number.MAX_SAFE_INTEGER`, because
+    // multiplier < 2^16 and rand_state < 2^31.
+    const intermediate = this.multiplier * this.rand_state + this.increment;
+    this.rand_state = intermediate % this.modulus;
+    this.checkRandState();
+    return this.rand_state;
+  }
+
+  /**
+   * Generates a random float in (0, 1) non-inclusive, advancing `rand_state` by one step.
+   *
+   * Postcondition: passes `checkRandState()`.
+   */
+  randomFloat(): number {
+    return this.nextInt() / this.modulus;
+  }
+
+  /** Verify the invariant: rand_state is an integer with 0 < rand_state < modulus. */
+  private checkRandState(): void {
+    if (this.rand_state <= 0) {
+      throw new Error("Random state is unexpectedly not strictly positive.");
+    }
+    if (!Number.isInteger(this.rand_state)) {
+      throw new Error("Random state is unexpectedly not an integer.");
+    }
+    if (this.rand_state >= this.modulus) {
+      throw new Error("Random state is unexpectedly not less than the modulus.");
+    }
+  }
+}
diff --git a/web/tests/node/test_random_generator.js
b/web/tests/node/test_random_generator.js
new file mode 100644
index 0000000000..adc6635d05
--- /dev/null
+++ b/web/tests/node/test_random_generator.js
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/* eslint-disable no-undef */
+
+const tvmjs = require("../../dist");
+
+test("Test coverage of [0, 100): every residue 0..99 is produced", () => {
+  const numBuckets = 100;
+  const covered = new Array(numBuckets).fill(false);
+  const rng = new tvmjs.LinearCongruentialGenerator();
+  // Seed so that any failure is replayable rather than dependent on Date.now().
+  rng.setSeed(42);
+  let minValue = Infinity;
+  let maxValue = -Infinity;
+  for (let i = 0; i < 100000; i++) {
+    const value = rng.nextInt();
+    minValue = Math.min(minValue, value);
+    maxValue = Math.max(maxValue, value);
+    covered[value % numBuckets] = true;
+  }
+  // nextInt() must stay strictly inside (0, modulus).
+  expect(minValue).toBeGreaterThan(0);
+  expect(maxValue).toBeLessThan(rng.modulus);
+  const notCovered = [];
+  for (let i = 0; i < numBuckets; i++) {
+    if (!covered[i]) {
+      notCovered.push(i);
+    }
+  }
+  expect(notCovered).toEqual([]);
+});
+
+test("Test whether the same seed make two RNGs generate same results", () => {
+  const rng1 = new tvmjs.LinearCongruentialGenerator();
+  const rng2 = new tvmjs.LinearCongruentialGenerator();
+  rng1.setSeed(42);
+  rng2.setSeed(42);
+
+  // Compare exactly: identical seeds must yield bit-identical sequences, whereas
+  // `toBeCloseTo` (2 decimal digits by default) would also accept differing values.
+  for (let i = 0; i < 100; i++) {
+    expect(rng1.randomFloat()).toBe(rng2.randomFloat());
+  }
+});
+
+test("Test two RNGs with different seeds generate different results", () => {
+  const rng1 = new tvmjs.LinearCongruentialGenerator();
+  const rng2 = new tvmjs.LinearCongruentialGenerator();
+  rng1.setSeed(41);
+  rng2.setSeed(42);
+  let numSame = 0;
+  const numTest = 100;
+
+  // The modulus 2^31 - 1 is prime and the multiplier is nonzero, so each step is a
+  // bijection on the states [1, modulus - 1]; two distinct states can never collide.
+  // Differently-seeded generators must therefore differ at every step.
+  for (let i = 0; i < numTest; i++) {
+    if (rng1.nextInt() === rng2.nextInt()) {
+      numSame += 1;
+    }
+  }
+  expect(numSame).toBe(0);
+});
+
+test("Illegal argument to `setSeed()`", () => {
+  // Construct outside the `expect` callbacks so only `setSeed()` can produce the error.
+  const rng = new tvmjs.LinearCongruentialGenerator();
+  for (const badSeed of [42.5, NaN, Infinity, -Infinity]) {
+    expect(() => rng.setSeed(badSeed)).toThrow("Seed should be an integer.");
+  }
+});