Web Workers with TensorFlow.js

// development

Lately my team has been experimenting with creating a smoother or livelier experience with inference done in a mobile browser. Many examples out there demonstrate or emulate the interaction of taking a photo as the trigger to run a prediction, but very few exist that emulate the experience of taking a video.

I’ve created an example repo for your reference, so you can dive right in if you want to. You can also check out the demo. Otherwise, I’m going to run through some code snippets and walk you through some of the more important components.

For this example, we’ll be using a MobileNetV2 model from tfhub.dev.

Setting things up

A simple HTML page to kick things off:

<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8" />
    <!-- TensorFlow.js from unpkg; app.js uses it through the global `tf` -->
    <script src="https://unpkg.com/@tensorflow/tfjs"></script>
    <style>
      /* The video element itself is hidden; app.js draws its frames onto the canvas */
      #video {
        display: none;
      }
    </style>
  </head>
  <body>
    <div id="canvas-container">
      <canvas id="canvas"></canvas>
    </div>
    <video id="video" autoplay playsinline muted></video>
    <!-- Loaded last so that the elements above exist when app.js looks them up -->
    <script type="text/javascript" src="app.js"></script>
  </body>
</html>

Let’s begin by declaring all our variables, and assigning references to our DOM elements in app.js:

// app.js
// Where the model and its label dictionary are fetched from
const MODEL_URL = "https://...";
const DICT_URL = "https://...";

// Shared state: each stays null until the matching resource has been set up
let model = null;
let dictionary = null;
let webWorker = null;

// Single place that holds every DOM reference the app works with
let DOM_EL = {
  video: null,
  canvas: null,
  ctx: null,
};

window.addEventListener('DOMContentLoaded', () => {
  // The canvas is sized as a square whose side is the viewport width
  const canvas = document.getElementById("canvas");
  const side = window.innerWidth;
  canvas.width = side;
  canvas.height = side;

  // Cache the element references (and the 2D context) on DOM_EL
  DOM_EL.video = document.getElementById("video");
  DOM_EL.canvas = canvas;
  DOM_EL.ctx = canvas.getContext("2d", { alpha: false, desynchronized: false });

  // Everything is in place: start the camera and the model
  init();
});

Because we’re rendering the video buffer onto our canvas, we’ll be using a 2D context with both alpha and desynchronized options set to false for some low-hanging optimisations.

In our init function, we want to first get permission to access the device’s camera stream and return it. After that, we’ll attempt to load our model and render the stream onto the canvas:

// Entry point: attach the camera stream to the video element, start loading
// the model, and begin drawing frames once the video has data.
const init = async function() {
  try {
    const stream = await setupCamera();
    DOM_EL.video.srcObject = stream;
  } catch (cameraError) {
    // A failure here is only logged; the rest of the start-up still runs
    console.log(cameraError);
  }

  // Not awaited: the model loads while the video is starting up
  setupModel();

  const startRendering = () => {
    DOM_EL.video.play();
    render();
  };
  DOM_EL.video.onloadeddata = startRendering;
}

/**
 * Ask for the rear-facing camera and hand back its media stream.
 *
 * @returns {Promise<MediaStream|undefined>} the stream, or undefined when it
 *   could not be obtained (the reason is logged). The promise never rejects,
 *   so callers only need to check for a missing stream.
 */
const setupCamera = async function() {
  try {
    // .getUserMedia triggers a dialogue to ask for permission.
    // Awaiting inside the try also covers the TypeError raised synchronously
    // when navigator.mediaDevices is not available (e.g. an insecure context).
    return await navigator.mediaDevices.getUserMedia({
      video: { facingMode: "environment" },
      audio: false,
    });
  } catch (err) {
    console.error("Can't obtain stream: ", err);
    return undefined;
  }
}

/**
 * Prepare inference: hand it to a web worker when the browser supports one,
 * otherwise load the model and its label dictionary on the current page.
 *
 * On the fallback path the globals `model` and `dictionary` are assigned only
 * after loading AND warm-up have succeeded, so a failed load leaves them
 * untouched instead of crashing on the warm-up call.
 *
 * @returns {Promise<void>} never rejects because of a load failure; the
 *   failure is logged instead.
 */
const setupModel = async function() {
  if (window.Worker) {
    webWorker = new Worker('web-worker.js');
    webWorker.onmessage = evt => {
      console.log(evt.data);
    };
    return;
  }

  try {
    // fromTFHub matches the loader used in web-worker.js for the same model
    const loadedModel = await tf.loadGraphModel(MODEL_URL, { fromTFHub: true });
    const response = await tf.util.fetch(DICT_URL);
    const text = await response.text();
    const labels = text.trim().split('\n');

    // warm-up the model; tf.tidy releases the dummy input and its output
    tf.tidy(() => {
      loadedModel.predict(tf.zeros([1, 224, 224, 3]));
    });

    model = loadedModel;
    dictionary = labels;
  } catch (err) {
    console.error("Can't load model: ", err);
  }
}

// Paint the current video frame onto the canvas as a square as wide as the
// viewport, then schedule the same work for the next animation frame.
const render = function() {
  const { ctx, video } = DOM_EL;
  const side = window.innerWidth;

  ctx.drawImage(video, 0, 0, side, side);
  window.requestAnimationFrame(render);
}

In setupModel, we do a quick check to see if the browser supports web workers. If it does, we simply create one and register an onmessage listener for it. If it doesn’t, we’ll swallow the blue pill and load the model on the current page instead.

web-worker.js

Since web workers run in isolated threads with their own global scope (there is no window or DOM inside a worker), you’ll have to import any libraries here as well. Note that importScripts is synchronous/blocking!

The try/catch block for loading the model is essentially the same as in setupModel. What’s new here is the onmessage listener, which we’ll use to fire off predictions later. We’ll also need to tell the main page that our model is ready.

// web-worker.js
// The worker does not share the page's scripts, so TensorFlow.js is loaded
// again here. importScripts is synchronous: `tf` exists once it returns.
importScripts('https://unpkg.com/@tensorflow/tfjs');
// Select the CPU backend for everything TensorFlow.js does in this worker
tf.setBackend('cpu');

// Where the model and its label dictionary are fetched from
const MODEL_URL = "https://...";
const DICT_URL = "https://...";

// Assigned by setup(); undefined until then
let model;
let dictionary;

/**
 * Load the model and its label dictionary inside the worker, warm the model
 * up, and tell the main page once everything is ready.
 *
 * `model` and `dictionary` are assigned only after loading and warm-up have
 * succeeded. When anything fails the error is logged, the globals stay unset
 * and no "ready" message is posted, so the worker keeps ignoring frames.
 *
 * @returns {Promise<void>} never rejects because of a load failure.
 */
const setup = async () => {
  try {
    const loadedModel = await tf.loadGraphModel(MODEL_URL, { fromTFHub: true });
    const response = await tf.util.fetch(DICT_URL);
    const text = await response.text();
    const labels = text.trim().split('\n');

    // warm-up the model; tf.tidy releases the dummy input and its output
    tf.tidy(() => {
      loadedModel.predict(tf.zeros([1, 224, 224, 3]));
    });

    model = loadedModel;
    dictionary = labels;
  } catch (err) {
    console.error("Can't load model: ", err);
    return;
  }

  // inform the main page that our model is loaded
  postMessage({ modelIsReady: true });
}

// Start loading the model as soon as the worker script runs
setup();

// Listener for data sent by the main page. Messages that arrive before the
// model is available are dropped; predict itself is written later on.
onmessage = message => {
  if (!model) {
    return;
  }
  predict(message.data);
}

We’ll need to catch the message for when our model is ready:

// app.js
// Worker-aware version of setupModel. The worker sends two kinds of messages:
// a one-off { modelIsReady: true } once its model is loaded, and afterwards
// one message per prediction result.
const setupModel = async function() {
  if (window.Worker) {
    webWorker = new Worker('web-worker.js');
    webWorker.onmessage = evt => {
      if (evt.data.modelIsReady) {
        workerModelIsReady = true;
      } else {
        // A prediction came back: clear the flag so that the next frame
        // can be offloaded to the worker
        isWaiting = false;
      }
      console.log(evt.data);
    };
  } else {
    //...
  }
}

Prediction functions

Let’s get back to app.js and write our prediction functions. We’ll need two of them; one to offload data to our web worker, and the other to be run on the current page. Let’s start with our regular prediction function:

// app.js
/**
 * Run one prediction on the current canvas content, on the current page, and
 * log the most probable label together with its probability.
 *
 * Does nothing while the model or the label dictionary is not available: the
 * function is driven by a timer that can fire before loading has finished
 * (or after it has failed).
 *
 * @returns {Promise<void>}
 */
const predict = async function() {
  if (!model || !dictionary) {
    return;
  }

  // tf.tidy disposes every intermediate tensor; only the returned scores survive
  const scores = tf.tidy(() => {
    // tensor <- texture <- pixel data
    const imgTensor = tf.browser.fromPixels(DOM_EL.canvas);
    // pre-processing
    const centerCroppedImg = centerCropAndResize(imgTensor);
    const processedImg = centerCroppedImg.div(127.5).sub(1);
    return model.predict(processedImg);
  });

  let probabilities;
  try {
    probabilities = await scores.data();
  } finally {
    // release the scores tensor even when reading it back fails
    scores.dispose();
  }

  const result = Array.from(probabilities, (prob, i) => ({ label: dictionary[i], prob }));

  // keep the entry with the highest probability (the later entry wins a tie)
  const prediction = result.reduce((prev, current) => {
    return (prev.prob > current.prob) ? prev : current;
  });

  // predicted class
  console.log(prediction.label);
  // probability
  console.log(parseFloat(prediction.prob.toFixed(2)));
}

tf.tidy automatically cleans up intermediate tensors allocated within the function for us. This is important because we don’t want any memory leaks. Inside the function, we’ll create a tensor using tf.browser.fromPixels on our canvas; run a function to center crop and resize it, and finally normalise the pixel values. Once these three steps are done, we can pass it into model.predict and wait for the result in console.

Next, we’ll write the offload function. We need a couple of flags to tell us when we can send data over to the web worker, so that we won’t flood it with messages while waiting for prediction results.

// app.js

// declare these at the top together with other globals
// true from the moment a frame is sent until the worker has answered
let isWaiting = false;
// true once the worker has reported that its model is loaded
let workerModelIsReady = false;

/**
 * Send the current canvas pixels to the web worker for prediction.
 *
 * A frame is sent only when the worker's model is ready AND no earlier frame
 * is still being processed, so the worker is never flooded with messages.
 *
 * @returns {Promise<void>}
 */
const offloadPredict = async function() {
  if (isWaiting || !workerModelIsReady) {
    return;
  }

  const { width, height } = DOM_EL.canvas;
  const imageData = DOM_EL.ctx.getImageData(0, 0, width, height);
  webWorker.postMessage(imageData);
  isWaiting = true;
}

It looks relatively simple compared to our actual predict function, and that’s because we have some limitations on the types of data we can send via postMessage. It uses the structured clone algorithm, which has the following caveats:

  • Function objects cannot be duplicated
  • The lastIndex property of RegExp objects is not preserved
  • Property descriptors, setters, getters, and similar metadata-like features are not duplicated
  • Prototype chain is not walked or duplicated

This means if we send a tfjs tensor via postMessage, tensor methods like toFloat and expandDims which we need for pre-processing, cannot be accessed.

// app.js
// Counter-example: post the tensor object itself to the worker
const imgTensor = tf.browser.fromPixels(DOM_EL.canvas);
webWorker.postMessage(imgTensor);

// web-worker.js
// (imgTensor stands here for the copy the worker receives)
console.log(imgTensor.toFloat().expandDims());
// this throws errors: the copy made for postMessage does not carry the
// tensor's prototype methods, so toFloat cannot be called on it

Hence, we’ll use getImageData and send the object over.

web-worker.js again

// web-worker.js

// importScripts

// setup()

// Frames posted by the main page end up here; anything that arrives before
// the model is available is dropped.
onmessage = message => {
  if (!model) {
    return;
  }
  predict(message.data);
}

// Worker-side prediction: turn the received image data into a tensor, run the
// model on it and post [label, probability] of the best class back to the page.
const predict = async function(imageData) {
  // Intermediate tensors are released by tf.tidy; only the scores are returned
  const scores = tf.tidy(() => {
    const pixels = tf.browser.fromPixels(imageData);
    const normalised = centerCropAndResize(pixels).div(127.5).sub(1);
    return model.predict(normalised);
  });

  // Nothing usable came back: log it and stop here
  if (!scores) {
    console.log('Scores: ', scores);
    return;
  }

  const probabilities = await scores.data();
  scores.dispose();

  // Pair each probability with its label, then keep the most probable pair
  const ranked = Array.from(probabilities, (prob, i) => ({label: dictionary[i], prob}));
  const best = ranked.reduce((winner, candidate) => (
    (winner.prob > candidate.prob) ? winner : candidate
  ));

  // send the result back as an array
  postMessage([best.label, parseFloat(best.prob.toFixed(2))]);
}

No surprises here. It’s almost the same exact predict function we wrote in app.js, except that we’re taking in an argument and also using postMessage to send the results back.

Final steps

We’ve just connected all the pipes, and all that’s left to do is to call our predict functions. Back over at init:

// app.js
/**
 * Entry point: obtain the camera stream, start loading the model, and once
 * the video has data begin rendering frames and running predictions.
 *
 * Without a camera stream nothing else is started.
 *
 * @returns {Promise<void>}
 */
const init = async function() {
  const stream = await setupCamera();
  // `return` is a statement and cannot be a branch of a conditional
  // expression, so the bail-out is written as a guard clause
  if (!stream) {
    return;
  }
  DOM_EL.video.srcObject = stream;
  setupModel();

  DOM_EL.video.onloadeddata = e => {
    DOM_EL.video.play();
    render();
    if (window.Worker) {
      // offloaded predictions run in the worker, so they can be requested more often
      setInterval(offloadPredict, 1000);
    } else {
      setInterval(predict, 2000);
    }
  };
}

For my team’s prototype, one second for offloadPredict is the threshold for a certain amount of liveliness in the web app. Two seconds for predict produces a small amount of jankiness that we can live with. Feel free to adjust these values to whatever is appropriate for your project.

With an appropriate model, you should be getting the results printed in the browser console. You can also check the UI performance of both functions to see the gains for yourself. Check out the demo here.