Visualizing browser-based model training with the TensorFlow.js visualizer (tfjs-vis)
Observing the training process of a neural network is important: it lets you track the loss and accuracy per epoch, and much more.
When you train a model in Python, the TensorBoard library handles the visualization. Similarly, if you train a model in the browser with TensorFlow.js, you need a way to watch the training process, and the tfjs-vis visualizer provides it. With it, you can plot the training loss and accuracy per epoch or per batch while the model trains, and plot the per-class accuracy and the confusion matrix while evaluating it, among other things.
I am going to train a classifier on the 10-class Fashion MNIST dataset in the Google Chrome browser using TensorFlow.js, and visualize the model while it trains. Once the model is trained, we can draw on a canvas and have the model identify the correct class.
Yes, you heard that right: we will draw on a canvas and ask our classifier to identify the class of the drawing. I won't dig deep into the rest of the code, since my main purpose in writing this story is to show how to use the tfvis library in our JavaScript code and how the visualizer works for browser-based models. The link to my GitHub code is at the bottom of this story.
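Identifying a drawing ultimately comes down to taking the highest of the 10 class scores the model outputs and mapping that index to a Fashion MNIST label. Below is a minimal plain-JavaScript sketch of that last step. The `CLASS_NAMES` order follows the standard Fashion MNIST label indices 0 to 9; `scores` stands in for the model's softmax output (for example, the array you get from `prediction.dataSync()`), and `argMax`/`predictClassName` are hypothetical helper names, not the author's code.

```javascript
// Fashion MNIST label names, in the dataset's label order (0-9).
const CLASS_NAMES = [
  'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
  'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
];

// Return the index of the largest score (argmax).
function argMax(scores) {
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}

// Map a 10-element score vector (e.g. softmax probabilities) to a class name.
function predictClassName(scores) {
  if (scores.length !== CLASS_NAMES.length) {
    throw new RangeError(`expected ${CLASS_NAMES.length} scores, got ${scores.length}`);
  }
  return CLASS_NAMES[argMax(scores)];
}
```

For example, `predictClassName([0, 0, 0, 0, 0, 0, 0, 0.9, 0.05, 0.05])` returns `'Sneaker'`, because index 7 has the highest score.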
To better understand this story, you can refer to my previous blog: How to train a neural network on Chrome using tensorflow.js
Importing the tfjs-vis library
Just as you would use a train.py script in Python, we will use a separate JavaScript file containing the code that builds and trains the model and downloads the Fashion MNIST sprite sheet.
What is a sprite sheet? A sprite sheet combines thousands of images into a single file. Game developers use this technique to simplify data fetching: they slice the portion they need out of one file instead of requesting each image separately. This speeds things up when a lot of image processing is needed in real-time environments.
Why are we using a sprite sheet? "MNIST have tens of thousands of images and if you are going to be opening http connections tens of thousands of time to download those images that could be problematic. So unlike training in python where you can load 10,000 images one after another, you can't do that in a web browser." (Laurence Moroney, deeplearning.ai)
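To make the slicing idea concrete, here is a plain-JavaScript sketch of pulling individual images out of a sprite sheet once its pixels are in memory. It assumes a hypothetical layout in which each 28x28 grayscale image is flattened to 784 consecutive pixels, and the pixels arrive as RGBA bytes (as canvas `ImageData` provides them). `IMAGE_SIZE`, `rgbaToGray`, and `sliceImage` are illustrative names, not part of tfjs or the author's code.

```javascript
// Assumed layout: one 28x28 grayscale image per 784 consecutive pixels.
const IMAGE_SIZE = 28 * 28;

// Convert RGBA bytes (e.g. canvas ImageData.data) of a grayscale sprite into
// floats in [0, 1]. For grayscale, R == G == B, so the red channel suffices.
function rgbaToGray(rgba) {
  const gray = new Float32Array(rgba.length / 4);
  for (let i = 0; i < gray.length; i++) {
    gray[i] = rgba[i * 4] / 255;
  }
  return gray;
}

// Return the pixels of image `index` as a view into the flat sprite buffer
// (no copy), which is what makes one big file cheap to slice.
function sliceImage(spritePixels, index) {
  const start = index * IMAGE_SIZE;
  if (start < 0 || start + IMAGE_SIZE > spritePixels.length) {
    throw new RangeError(`image ${index} is outside the sprite sheet`);
  }
  return spritePixels.subarray(start, start + IMAGE_SIZE);
}
```

One download, one decode, and then every image is just an offset into the same buffer, instead of one HTTP request per image.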
Defining metrics for callbacks
Let's define the metrics we want to observe (loss, validation loss, accuracy, and validation accuracy), which we will pass to the tfvis fitCallbacks function.
We also define a container with a name and size; the container is the other required parameter of fitCallbacks.
Setting up tfvis.show
Let's set up the visualizer with our metrics and container.
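A minimal sketch of the metrics and container wired into `tfvis.show.fitCallbacks`. The metric names assume the model is compiled with `metrics: ['accuracy']`, which TensorFlow.js reports under the keys `acc` and `val_acc`; the container name, the 640px height, and the `createFitCallbacks` helper are illustrative choices, not taken from the author's code. `tfvis` is passed in as an argument: loaded from the CDN script tag it is the global `tfvis`, and from npm it would come from `import * as tfvis from '@tensorflow/tfjs-vis'`.

```javascript
// Metrics to plot. With metrics: ['accuracy'] in model.compile, tfjs logs
// accuracy as 'acc' and validation accuracy as 'val_acc'.
const METRICS = ['loss', 'val_loss', 'acc', 'val_acc'];

// Visor surface descriptor: `name` titles the surface, `styles` sizes it.
const CONTAINER = { name: 'Model Training', styles: { height: '640px' } };

// Build the training callbacks that draw the metric charts.
// `tfvis` is injected so this works with the CDN global or an npm import.
function createFitCallbacks(tfvis) {
  return tfvis.show.fitCallbacks(CONTAINER, METRICS);
}
```

By default, the returned object contains both `onBatchEnd` and `onEpochEnd` handlers, so the charts update per batch and per epoch.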
Passing the visualizer to the training function (model.fit)
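A sketch of the training call, assuming the training and test tensors have already been built from the sprite sheet and that `fitCallbacks` is the object returned by `tfvis.show.fitCallbacks`. The batch size, epoch count, and the `trainModel`/`data` names are hypothetical; `batchSize`, `epochs`, `validationData`, `shuffle`, and `callbacks` are standard `model.fit` options in TensorFlow.js.

```javascript
// Hypothetical training settings; tune them for your own setup.
const BATCH_SIZE = 512;
const EPOCHS = 10;

// Train the model and stream its metrics to tfvis through `fitCallbacks`.
// `data` holds the train/test tensors prepared from the sprite sheet.
// Returns the Promise<History> produced by model.fit.
function trainModel(model, data, fitCallbacks) {
  return model.fit(data.trainXs, data.trainYs, {
    batchSize: BATCH_SIZE,
    epochs: EPOCHS,
    // Validation data is what makes val_loss / val_acc available to plot.
    validationData: [data.testXs, data.testYs],
    shuffle: true,
    callbacks: fitCallbacks,
  });
}
```

Without `validationData`, the `val_loss` and `val_acc` charts would have nothing to display.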
Once everything is set up, you can watch the visualizer on the web page while the model trains.
The plotted values are the loss and accuracy metrics that we passed to the tfvis fitCallbacks function.
Besides monitoring training, you can also use tfvis to evaluate the model. For a classification task like this one, we can use the `perClassAccuracy` and `confusionMatrix` functions.
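To show what these two evaluations compute, here is a plain-JavaScript sketch over integer class indices (for example, the argmax of the one-hot labels and of the model's predictions). In tfvis itself, the equivalents are `tfvis.metrics.confusionMatrix(labels, preds)` and `tfvis.metrics.perClassAccuracy(labels, preds)`, which take 1D label and prediction tensors and return promises; the results can then be drawn with `tfvis.render.confusionMatrix` and `tfvis.show.perClassAccuracy`. The functions below are illustrative re-implementations, not the library's code.

```javascript
// Confusion matrix: rows are true labels, columns are predicted labels.
function confusionMatrix(labels, predictions, numClasses) {
  if (labels.length !== predictions.length) {
    throw new RangeError('labels and predictions must have the same length');
  }
  const matrix = Array.from({ length: numClasses }, () => new Array(numClasses).fill(0));
  for (let i = 0; i < labels.length; i++) {
    matrix[labels[i]][predictions[i]] += 1;
  }
  return matrix;
}

// Per-class accuracy: correct predictions for a class divided by the number
// of examples of that class (the diagonal entry over its row sum).
// Classes with no examples get accuracy 0 in this sketch.
function perClassAccuracy(labels, predictions, numClasses) {
  const matrix = confusionMatrix(labels, predictions, numClasses);
  return matrix.map((row, c) => {
    const count = row.reduce((sum, v) => sum + v, 0);
    return { accuracy: count === 0 ? 0 : row[c] / count, count };
  });
}
```

Off-diagonal cells in the matrix show which classes the model confuses with each other, for example shirts predicted as T-shirts, which a single overall accuracy number hides.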
Check out my model's output:
P.S. Ignore my drawing 😛
Wanna give it a try?
Click HERE 👈
You can find my GitHub repository here. It contains the code for training a classifier on the Fashion MNIST dataset in the browser and displaying the visualizations.
Sushrut Ashtikar is a member of the Nsemble.ai team. We love to research and develop challenging products using artificial intelligence. Nsemble has developed several solutions in the domains of Industry 4.0 and e-commerce. We will be happy to help you!