Web Workers with TensorFlow.js
Lately my team has been experimenting with making in-browser inference on mobile feel smoother and livelier. Many examples out there demonstrate taking a photo as the trigger to run a prediction, but very few walk through running predictions continuously against a live video stream.
I’ve created an example repo for your reference, so you can dive right in if you want to. You can also check out the demo. Otherwise, I’m going to run through some code snippets and walk you through some of the more important components.
For this example, we’ll be using a MobileNetV2 model from tfhub.dev.
Setting things up
A simple HTML page to kick things off:
<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8" />
    <script src="https://unpkg.com/@tensorflow/tfjs"></script>
    <style>
      #video {
        display: none;
      }
    </style>
  </head>
  <body>
    <div id="canvas-container">
      <canvas id="canvas"></canvas>
    </div>
    <video id="video" autoplay playsinline muted></video>
    <script type="text/javascript" src="app.js"></script>
  </body>
</html>
Let’s begin by declaring all our variables and assigning references to our DOM elements in app.js:
// app.js
const MODEL_URL = "https://...";
const DICT_URL = "https://...";

// some globals
let model = null;
let dictionary = null;
let webWorker = null;

// and another object to hold all our DOM references
let DOM_EL = {
  video: null,
  canvas: null,
  ctx: null
};

window.addEventListener('DOMContentLoaded', () => {
  // assign references
  DOM_EL.video = document.getElementById("video");
  DOM_EL.canvas = document.getElementById("canvas");
  // use a square canvas sized to the viewport width
  DOM_EL.canvas.width = window.innerWidth;
  DOM_EL.canvas.height = window.innerWidth;
  DOM_EL.ctx = DOM_EL.canvas.getContext("2d", { alpha: false, desynchronized: false });
  init();
});
Because we’re rendering the video buffer onto our canvas, we’re using a 2D context with the alpha and desynchronized options set explicitly to false. Turning alpha off is a low-hanging optimisation, since an opaque canvas lets the browser skip alpha blending; desynchronized is simply left at its default.
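If you want to confirm which options the browser actually honoured, newer Chromium-based browsers expose a getContextAttributes() method on the 2D context; treat this as an optional debugging aid, since support varies:
// app.js -- optional sanity check, only where the browser supports it
if (typeof DOM_EL.ctx.getContextAttributes === "function") {
  // e.g. { alpha: false, desynchronized: false, ... }
  console.log(DOM_EL.ctx.getContextAttributes());
}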
In our init function, we first ask for permission to access the device’s camera and assign the returned stream to our video element. After that, we’ll attempt to load our model and render the stream onto the canvas:
const init = async function() {
  try {
    DOM_EL.video.srcObject = await setupCamera();
  } catch(err) {
    console.log(err);
  }
  setupModel();
  DOM_EL.video.onloadeddata = e => {
    DOM_EL.video.play();
    render();
  };
}

const setupCamera = async function() {
  // .getUserMedia triggers a dialogue to ask for permission
  return navigator.mediaDevices
    .getUserMedia({ video: { facingMode: "environment" }, audio: false })
    .catch(err => {
      console.error("Can't obtain stream: ", err);
    });
}
const setupModel = async function() {
  if (window.Worker) {
    webWorker = new Worker('web-worker.js');
    webWorker.onmessage = evt => {
      console.log(evt.data);
    };
  } else {
    try {
      model = await tf.loadGraphModel(MODEL_URL, { fromTFHub: true });
      const response = await tf.util.fetch(DICT_URL);
      const text = await response.text();
      dictionary = text.trim().split('\n');
    } catch(err) {
      console.error("Can't load model: ", err);
    }
    // warm-up the model and dispose of the dummy tensors straight away
    tf.tidy(() => model.predict(tf.zeros([1, 224, 224, 3])));
  }
}
const render = function() {
  DOM_EL.ctx.drawImage(DOM_EL.video, 0, 0, window.innerWidth, window.innerWidth);
  window.requestAnimationFrame(render);
}
In setupModel, we do a quick check to see if the browser supports web workers. If it does, we simply create one and register an onmessage listener for it. If it doesn’t, we’ll swallow the blue pill and load the model on the current page instead.
web-worker.js
Since web workers run in isolated threads and have their own global scope (there’s no window or DOM), you’ll have to import any libraries here as well. Note that importScripts is synchronous and blocking!
The try/catch block for loading the model is essentially the same as in setupModel. What’s new here is the onmessage listener, which we’ll use to fire off predictions later. We’ll also need to tell the main page that our model is ready.
// web-worker.js
importScripts('https://unpkg.com/@tensorflow/tfjs');
// use the CPU backend inside the worker
tf.setBackend('cpu');

const MODEL_URL = "https://...";
const DICT_URL = "https://...";

let model;
let dictionary;

const setup = async () => {
  try {
    model = await tf.loadGraphModel(MODEL_URL, { fromTFHub: true });
    const response = await tf.util.fetch(DICT_URL);
    const text = await response.text();
    dictionary = text.trim().split('\n');
  } catch(err) {
    console.error("Can't load model: ", err);
  }
  // warm-up the model and dispose of the dummy tensors straight away
  tf.tidy(() => model.predict(tf.zeros([1, 224, 224, 3])));
  // inform the main page that our model is loaded
  postMessage({ modelIsReady: true });
}

setup();

// Register our listener for now, we'll write the predict function later
onmessage = evt => {
  if (model) {
    predict(evt.data);
  }
}
We’ll need to catch the message for when our model is ready:
// app.js
const setupModel = async function() {
  if (window.Worker) {
    webWorker = new Worker('web-worker.js');
    webWorker.onmessage = evt => {
      if (evt.data.modelIsReady) {
        workerModelIsReady = true;
      }
      console.log(evt.data);
    };
  } else {
    //...
  }
}
Prediction functions
Let’s get back to app.js and write our prediction functions. We’ll need two of them: one to offload data to our web worker, and the other to run on the current page. Let’s start with our regular prediction function:
// app.js
const predict = async function() {
  const scores = tf.tidy(() => {
    // tensor <- texture <- pixel data
    const imgTensor = tf.browser.fromPixels(DOM_EL.canvas);
    // pre-processing
    const centerCroppedImg = centerCropAndResize(imgTensor);
    const processedImg = centerCroppedImg.div(127.5).sub(1);
    return model.predict(processedImg);
  });
  const probabilities = await scores.data();
  scores.dispose();
  const result = Array.from(probabilities)
    .map((prob, i) => ({ label: dictionary[i], prob }));
  const prediction = result.reduce((prev, current) => {
    return (prev.prob > current.prob) ? prev : current;
  });
  // predicted class
  console.log(prediction.label);
  // probability
  console.log(parseFloat(prediction.prob.toFixed(2)));
}
tf.tidy automatically cleans up intermediate tensors allocated within the function for us. This is important because we don’t want any memory leaks. Inside the function, we create a tensor from our canvas using tf.browser.fromPixels, run a function to center crop and resize it, and finally normalise the pixel values. Once these three steps are done, we can pass the result into model.predict and wait for the result in the console.
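The centerCropAndResize helper isn’t shown in these snippets, so here’s a minimal sketch of what it could look like, assuming the 224x224 input size MobileNetV2 expects (the actual implementation in the repo may differ):
// a possible centerCropAndResize, shared by app.js and web-worker.js
const centerCropAndResize = function(imgTensor) {
  // crop the largest centred square out of the incoming [height, width, 3] tensor
  const [height, width] = imgTensor.shape;
  const size = Math.min(height, width);
  const top = Math.floor((height - size) / 2);
  const left = Math.floor((width - size) / 2);
  const cropped = imgTensor.slice([top, left, 0], [size, size, 3]);
  // resize to the model's expected input and add a batch dimension -> [1, 224, 224, 3]
  return tf.image.resizeBilinear(cropped, [224, 224]).expandDims();
}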
Next, we’ll write the offload function. We need a couple of flags to tell us when we can send data over to the web worker, so that we won’t flood it with messages while waiting for prediction results.
// app.js
// declare these at the top together with other globals
let isWaiting = false;
let workerModelIsReady = false;

const offloadPredict = async function() {
  // only post a frame when the worker's model is ready and it isn't still busy
  if (!isWaiting && workerModelIsReady) {
    const imageData = DOM_EL.ctx.getImageData(0, 0, DOM_EL.canvas.width, DOM_EL.canvas.height);
    webWorker.postMessage(imageData);
    isWaiting = true;
  }
}
It looks relatively simple compared to our actual predict function, and that’s because we have some limitations on the types of data we can send via postMessage. It uses the structured clone algorithm, which has the following caveats:
- Function objects cannot be duplicated
- The lastIndex property of RegExp objects is not preserved
- Property descriptors, setters, getters, and similar metadata-like features are not duplicated
- The prototype chain is not walked or duplicated
This means that if we send a tfjs tensor via postMessage, tensor methods like toFloat and expandDims, which we need for pre-processing, cannot be accessed on the other side.
// app.js
const imgTensor = tf.browser.fromPixels(DOM_EL.canvas);
webWorker.postMessage(imgTensor);

// web-worker.js
// this throws, because the cloned object no longer has tensor methods
console.log(imgTensor.toFloat().expandDims());
Hence, we’ll use getImageData and send the resulting ImageData object over instead.
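As an aside (this isn’t used in the demo), the ImageData for a full-size canvas is a fairly large buffer to clone every frame. postMessage also accepts a transfer list, so you could hand the underlying ArrayBuffer to the worker instead of copying it; a rough sketch:
// app.js -- hypothetical variant that transfers the pixel buffer instead of cloning it
const imageData = DOM_EL.ctx.getImageData(0, 0, DOM_EL.canvas.width, DOM_EL.canvas.height);
webWorker.postMessage(
  { buffer: imageData.data.buffer, width: imageData.width, height: imageData.height },
  [imageData.data.buffer] // listed as transferable, so it's moved rather than copied
);

// web-worker.js -- rebuild the ImageData before handing it to tf.browser.fromPixels
const { buffer, width, height } = evt.data;
const rebuilt = new ImageData(new Uint8ClampedArray(buffer), width, height);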
web-worker.js again
// web-worker.js
// importScripts
// setup()
onmessage = evt => {
  if (model) {
    predict(evt.data);
  }
}

const predict = async function(imageData) {
  const scores = tf.tidy(() => {
    const imgAsTensor = tf.browser.fromPixels(imageData);
    const centerCroppedImg = centerCropAndResize(imgAsTensor);
    const processedImg = centerCroppedImg.div(127.5).sub(1);
    return model.predict(processedImg);
  });
  if (scores) {
    const probabilities = await scores.data();
    scores.dispose();
    const result = Array.from(probabilities)
      .map((prob, i) => ({ label: dictionary[i], prob }));
    const prediction = result.reduce((prev, current) => {
      return (prev.prob > current.prob) ? prev : current;
    });
    // send the result back as an array
    postMessage([prediction.label, parseFloat(prediction.prob.toFixed(2))]);
  } else {
    console.log('Scores: ', scores);
  }
}
No surprises here. It’s almost the exact same predict function we wrote in app.js, except that we’re taking in an argument and using postMessage to send the results back.
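One loose end on the main page: isWaiting has to be flipped back to false once the worker replies, otherwise offloadPredict would only ever send a single frame. Here’s a minimal sketch of how the onmessage handler from earlier could consume the results array and reset the flag (check the repo for the actual version):
// app.js
webWorker.onmessage = evt => {
  if (evt.data.modelIsReady) {
    workerModelIsReady = true;
    return;
  }
  // otherwise it's the [label, probability] array posted by the worker
  const [label, prob] = evt.data;
  console.log(label, prob);
  // allow offloadPredict to post the next frame
  isWaiting = false;
};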
Final steps
We’ve just connected all the pipes, and all that’s left to do is to call our predict functions. Back over at init:
// app.js
const init = async function() {
  DOM_EL.video.srcObject = await setupCamera();
  // bail out if we didn't get a camera stream
  if (!DOM_EL.video.srcObject) return;
  setupModel();
  DOM_EL.video.onloadeddata = e => {
    DOM_EL.video.play();
    render();
    if (window.Worker) {
      setInterval(offloadPredict, 1000);
    } else {
      setInterval(predict, 2000);
    }
  };
}
For my team’s prototype, an interval of one second for offloadPredict gives us the level of liveliness we were after in the web app. Two seconds for predict produces a small amount of jankiness that we can live with. Feel free to adjust the values to whatever is appropriate for your project.
With an appropriate model, you should be getting the results printed in the browser console. You can also check the UI performance of both functions to see the gains for yourself. Check out the demo here.