Rearranging the Visual World

Rearranging objects (such as organizing books on a bookshelf, moving utensils on a dinner table, or pushing piles of coffee beans) is a fundamental skill that can enable robots to physically interact with our diverse and unstructured world. While easy for people, accomplishing such tasks remains an open research challenge for embodied machine learning (ML) systems, as it requires both high-level and low-level perceptual reasoning. For example, when stacking a pile of books, one might consider where the books should be stacked, and in which order, while ensuring that the edges of the books align with each other to form a neat pile.

Across many application areas in ML, simple differences in model architecture can exhibit vastly different generalization properties. Therefore, one might ask whether there are certain deep network architectures that favor simple underlying elements of the rearrangement problem. Convolutional architectures, for example, are common in computer vision as they encode translational invariance, yielding the same response even if an image is shifted, while Transformer architectures are common in language processing because they exploit self-attention to capture long-range contextual dependencies. In robotics applications, one common architectural element is to use object-centric representations such as poses, keypoints, or object descriptors inside learned models, but these representations require additional training data (often manually annotated) and struggle to describe difficult scenarios such as deformables (e.g., playdough), fluids (honey), or piles of stuff (chopped onions).

Today, we present the Transporter Network, a simple model architecture for learning vision-based rearrangement tasks, which appeared as a publication and plenary talk during CoRL 2020. Transporter Nets use a novel approach to 3D spatial understanding that avoids reliance on object-centric representations, making them general for vision-based manipulation but far more sample efficient than benchmarked end-to-end alternatives. As a consequence, they are fast and practical to train on real robots. We are also releasing an accompanying open-source implementation of Transporter Nets together with Ravens, our new simulated benchmark suite of ten vision-based manipulation tasks.

Transporter Networks: Rearranging the Visual World for Robotic Manipulation
The key idea behind the Transporter Network architecture is that one can formulate the rearrangement problem as learning how to move a chunk of 3D space. Rather than relying on an explicit definition of objects (which is bound to struggle at capturing all edge cases), 3D space is a much broader definition for what could serve as the atomic units being rearranged, and can broadly encompass an object, part of an object, or multiple objects, etc. Transporter Nets leverage this structure by capturing a deep representation of the 3D visual world, then overlaying parts of it on itself to imagine various possible rearrangements of 3D space. It then chooses the rearrangements that best match those it has seen during training (e.g., from expert demonstrations), and uses them to parameterize robot actions. This formulation allows Transporter Nets to generalize to unseen objects and enables them to better exploit geometric symmetries in the data, so that they can extrapolate to new scene configurations. Transporter Nets are applicable to a wide variety of rearrangement tasks for robotic manipulation, expanding beyond our earlier models, such as affordance-based manipulation and TossingBot, that focus only on grasping and tossing.

Transporter Nets capture a deep representation of the visual world, then overlay parts of it on itself to imagine various possible rearrangements of 3D space to find the best one and inform robot actions.
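To make the "overlay parts of it on itself" intuition concrete, below is a minimal NumPy sketch of the core operation: features cropped around a candidate pick are used as a correlation kernel over the dense features of the whole scene, so every pixel receives a score for being the destination of the transported chunk. The feature extractor, tensor shapes, and function name are illustrative placeholders, not the released implementation (which, among other things, also searches over rotations of the crop).

import numpy as np
from scipy.signal import correlate2d

def score_placements(scene_features, pick_uv, crop_size=32):
    # scene_features: (H, W, C) per-pixel deep features of an orthographic RGB-D
    # view of the scene (how these features are computed is not shown here).
    # pick_uv: (row, col) of the chosen pick, e.g. the argmax of a separate
    # per-pixel pick-affordance map.
    u, v = pick_uv
    half = crop_size // 2
    crop = scene_features[u - half:u + half, v - half:v + half, :]  # the picked "chunk"
    # Correlate the cropped features against the whole scene, channel by channel:
    # each output pixel scores transporting the chunk so that its center lands there.
    scores = np.zeros(scene_features.shape[:2])
    for c in range(scene_features.shape[-1]):
        scores += correlate2d(scene_features[:, :, c], crop[:, :, c], mode="same")
    return scores  # the argmax gives the best placement pixel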

Ravens Benchmark
To evaluate the performance of Transporter Nets in a consistent environment for fair comparisons to baselines and ablations, we developed Ravens, a benchmark suite of ten simulated vision-based rearrangement tasks. Ravens features a Gym API with a built-in stochastic oracle to evaluate the sample efficiency of imitation learning methods. Ravens avoids assumptions that cannot transfer to a real setup: observation data contains only RGB-D images and camera parameters; actions are end effector poses (converted into joint positions with inverse kinematics).
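As a rough illustration of that interface, a demonstration-collection loop looks like the sketch below. It assumes any Gym-style environment and oracle object are passed in; the names and the exact step signature are placeholders rather than the released Ravens API.

def collect_demo(env, oracle, max_steps=100):
    # env: a Gym-style Ravens task whose observations are RGB-D images plus camera
    # parameters and whose actions are pairs of end effector poses (converted to
    # joint positions with inverse kinematics inside the environment).
    # oracle: the built-in stochastic expert that maps an observation to an action.
    episode = []
    obs = env.reset()
    for _ in range(max_steps):
        act = oracle(obs)                             # expert demonstration action
        next_obs, reward, done, info = env.step(act)
        episode.append((obs, act, reward))            # store the imitation-learning tuple
        obs = next_obs
        if done:
            break
    return episode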

Experiments on these ten tasks show that Transporter Nets are orders of magnitude more sample efficient than other end-to-end methods, and are capable of achieving over 90% success on many tasks with just 100 demonstrations, while the baselines struggle to generalize with the same amount of data. In practice, this makes collecting enough demonstrations a more viable option for training these models on real robots (which we show examples of below).

Our new Ravens benchmark includes ten simulated vision-based manipulation tasks, including pushing and pick-and-place, for which experiments show that Transporter Nets are orders of magnitude more sample efficient than other end-to-end methods. Ravens features a Gym API with a built-in stochastic oracle to evaluate the sample efficiency of imitation learning methods.

Highlights
Given 10 example demonstrations, Transporter Nets can learn pick and place tasks such as stacking plates (surprisingly easy to misplace!), multimodal tasks like aligning any corner of a box to a marker on the tabletop, or building a pyramid of blocks.

By leveraging closed-loop visual feedback, Transporter Nets have the capacity to learn various multi-step sequential tasks with a modest number of demonstrations, such as moving disks for the Tower of Hanoi, palletizing boxes, or assembling kits of new objects not seen during training. These tasks have considerably long horizons, meaning that to solve a task the model must correctly sequence many individual choices. Policies also tend to learn emergent recovery behaviors.

One surprising thing about these results was that beyond just perception, the models were starting to learn behaviors that resemble high-level planning. For example, to solve the Tower of Hanoi, the models have to pick which disk to move next, which requires recognizing the state of the board based on the current visible disks and their positions. With a box-palletizing task, the models must locate the empty spaces of the pallet, and identify how new boxes can fit into those voids. Such behaviors are exciting because they suggest that with all the baked-in invariances, the model can focus its capacity on learning the more high-level patterns in manipulation.

Transporter Nets can also learn tasks that use any motion primitive defined by two end effector poses, such as pushing piles of small objects into a target set, or reconfiguring a deformable rope to connect the two end-points of a 3-sided square. This suggests that rigid spatial displacements can serve as useful priors for nonrigid ones.

Conclusion
Transporter Nets bring a promising approach to learning vision-based manipulation, but are not without limitations. For example, they can be susceptible to noisy 3D data, we have only demonstrated them for sparse waypoint-based control with motion primitives, and it remains unclear how to extend them beyond spatial action spaces to force or torque-based actions. But overall, we are excited about this direction of work, and we hope that it provides inspiration for extensions beyond the applications we’ve discussed. For more details, please check out our paper.

Acknowledgements
This research was done by Andy Zeng, Pete Florence, Jonathan Tompson, Stefan Welker, Jonathan Chien, Maria Attarian, Travis Armstrong, Ivan Krasin, Dan Duong, Vikas Sindhwani, and Johnny Lee, with special thanks to Ken Goldberg, Razvan Surdulescu, Daniel Seita, Ayzaan Wahid, Vincent Vanhoucke, Anelia Angelova, Kendra Byrne, for helpful feedback on writing; Sean Snyder, Jonathan Vela, Larry Bisares, Michael Villanueva, Brandon Hurd for operations and hardware support; Robert Baruch for software infrastructure, Jared Braun for UI contributions; Erwin Coumans for PyBullet advice; Laura Graesser for video narration.

3D Scene Understanding with TensorFlow 3D

The growing ubiquity of 3D sensors (e.g., Lidar, depth sensing cameras and radar) over the last few years has created a need for scene understanding technology that can process the data these devices capture. Such technology can enable machine learning (ML) systems that use these sensors, like autonomous cars and robots, to navigate and operate in the real world, and can create an improved augmented reality experience on mobile devices. The field of computer vision has recently begun making good progress in 3D scene understanding, including models for mobile 3D object detection, transparent object detection, and more, but entry to the field can be challenging due to the limited availability of tools and resources that can be applied to 3D data.

In order to further improve 3D scene understanding and reduce barriers to entry for interested researchers, we are releasing TensorFlow 3D (TF 3D), a highly modular and efficient library that is designed to bring 3D deep learning capabilities into TensorFlow. TF 3D provides a set of popular operations, loss functions, data processing tools, models and metrics that enables the broader research community to develop, train and deploy state-of-the-art 3D scene understanding models.

TF 3D contains training and evaluation pipelines for state-of-the-art 3D semantic segmentation, 3D object detection and 3D instance segmentation, with support for distributed training. It also enables other potential applications like 3D object shape prediction, point cloud registration and point cloud densification. In addition, it offers a unified dataset specification and configuration for training and evaluation of the standard 3D scene understanding datasets. It currently supports the Waymo Open, ScanNet, and Rio datasets. However, users can freely convert other popular datasets, such as nuScenes and KITTI, into a similar format and use them in the pre-existing or custom-created pipelines, and can leverage TF 3D for a wide variety of 3D deep learning research and applications, from quickly prototyping and trying new ideas to deploying a real-time inference system.

An example output of the 3D object detection model in TF 3D on a frame from Waymo Open Dataset is shown on the left. An example output of the 3D instance segmentation model on a scene from ScanNet dataset is shown on the right.

Here, we will present the efficient and configurable sparse convolutional backbone that is provided in TF 3D, which is the key to achieving state-of-the-art results on various 3D scene understanding tasks. Furthermore, we will go over each of the three pipelines that TF 3D currently supports: 3D semantic segmentation, 3D object detection and 3D instance segmentation.

3D Sparse Convolutional Network
The 3D data captured by sensors often consists of a scene that contains a set of objects of interest (e.g. cars, pedestrians, etc.) surrounded mostly by open space, which is of limited (or no) interest. As such, 3D data is inherently sparse. In such an environment, a standard dense implementation of convolutions would be computationally intensive and consume a large amount of memory. So, in TF 3D we use submanifold sparse convolution and pooling operations, which are designed to process 3D sparse data more efficiently. Sparse convolutional models are core to the state-of-the-art methods applied in most outdoor self-driving (e.g. Waymo, nuScenes) and indoor benchmarks (e.g. ScanNet).
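As a deliberately naive sketch of the idea (not the optimized TF 3D ops), a submanifold sparse convolution computes outputs only at voxels that are already occupied, gathering whatever occupied neighbors exist and skipping empty space entirely:

import numpy as np

def submanifold_sparse_conv(coords, feats, weights):
    # coords:  (N, 3) integer coordinates of the occupied voxels.
    # feats:   (N, C_in) features at those voxels.
    # weights: dict mapping a (dz, dy, dx) offset within the kernel (e.g. 3x3x3)
    #          to a (C_in, C_out) weight matrix.
    index = {tuple(c): i for i, c in enumerate(coords)}
    c_out = next(iter(weights.values())).shape[1]
    out = np.zeros((len(coords), c_out))
    for i, c in enumerate(coords):            # outputs only at the original sites,
        for offset, w in weights.items():     # so the sparsity pattern never grows
            j = index.get(tuple(np.add(c, offset)))
            if j is not None:                 # empty neighbors contribute nothing
                out[i] += feats[j] @ w
    return out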

We also use various CUDA techniques to speed up the computation (e.g., hashing, partitioning / caching the filter in shared memory, and using bit operations). Experiments on the Waymo Open dataset show that this implementation is around 20x faster than a well-designed implementation with pre-existing TensorFlow operations.

TF 3D then uses the 3D submanifold sparse U-Net architecture to extract a feature for each voxel. The U-Net architecture has proven to be effective by letting the network extract both coarse and fine features and combining them to make the predictions. The U-Net network consists of three modules, an encoder, a bottleneck, and a decoder, each of which consists of a number of sparse convolution blocks with possible pooling or un-pooling operations.

A 3D sparse voxel U-Net architecture. Note that a horizontal arrow takes in the voxel features and applies a submanifold sparse convolution to it. An arrow that is moving down performs a submanifold sparse pooling. An arrow that is moving up will gather back the pooled features, concatenate them with the features coming from the horizontal arrow, and perform a submanifold sparse convolution on the concatenated features.

The sparse convolutional network described above is the backbone for the 3D scene understanding pipelines that are offered in TF 3D. Each of the models described below uses this backbone network to extract features for the sparse voxels, and then adds one or multiple additional prediction heads to infer the task of interest. The user can configure the U-Net network by changing the number of encoder/decoder layers and the number of convolutions in each layer, and by modifying the convolution filter sizes, which enables a wide range of speed/accuracy tradeoffs to be explored through the different backbone configurations.
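As a sketch of what such a configuration might look like, the knobs described above could be grouped as follows. The parameter names and values here are illustrative only, not TF 3D's exact config keys.

# Hypothetical backbone configuration; names and values are illustrative only.
backbone_config = {
    "encoder_blocks": [(2, 64), (2, 128), (2, 256)],   # (convs per block, filters), one per level
    "bottleneck_convs": 2,
    "decoder_blocks": [(2, 256), (2, 128), (2, 64)],
    "conv_filter_size": 3,                             # e.g. 3x3x3 submanifold sparse convolutions
}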

3D Semantic Segmentation
The 3D semantic segmentation model has only one output head for predicting the per-voxel semantic scores, which are mapped back to points to predict a semantic label per point.

3D semantic segmentation of an indoor scene from ScanNet dataset.

3D Instance Segmentation
In 3D instance segmentation, in addition to predicting semantics, the goal is to group the voxels that belong to the same object together. The 3D instance segmentation algorithm used in TF 3D is based on our previous work on 2D image segmentation using deep metric learning. The model predicts a per-voxel instance embedding vector as well as a semantic score for each voxel. The instance embedding vectors map the voxels to an embedding space where voxels that correspond to the same object instance are close together, while those that correspond to different objects are far apart. In this case, the input is a point cloud instead of an image, and it uses a 3D sparse network instead of a 2D image network. At inference time, a greedy algorithm picks one instance seed at a time, and uses the distance between the voxel embeddings to group them into segments.
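A toy version of that greedy grouping step, assuming per-voxel embeddings and confidence scores are already available (TF 3D's actual implementation will differ in its details), might look like this:

import numpy as np

def greedy_group(embeddings, scores, threshold=0.5):
    # embeddings: (N, D) per-voxel instance embeddings.
    # scores:     (N,) per-voxel confidence used to pick seeds.
    # Returns an (N,) array of instance ids (-1 = unassigned).
    instance_ids = np.full(len(embeddings), -1)
    order = np.argsort(-scores)                  # most confident voxels become seeds first
    next_id = 0
    for seed in order:
        if instance_ids[seed] != -1:             # already claimed by an earlier instance
            continue
        dist = np.linalg.norm(embeddings - embeddings[seed], axis=1)
        members = (dist < threshold) & (instance_ids == -1)
        instance_ids[members] = next_id          # voxels close in embedding space join the seed
        next_id += 1
    return instance_ids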

3D Object Detection
The 3D object detection model predicts per-voxel size, center, and rotation matrices, as well as the object semantic scores. At inference time, a box proposal mechanism is used to reduce the hundreds of thousands of per-voxel box predictions into a few accurate box proposals; at training time, box prediction and classification losses are applied to the per-voxel predictions. We apply a Huber loss on the distance between the predicted and ground-truth box corners. Since the function that estimates the box corners from its size, center and rotation matrix is differentiable, the loss will automatically propagate back to those predicted object properties. We use a dynamic box classification loss that classifies a box that strongly overlaps with the ground truth as positive and classifies the non-overlapping boxes as negative.
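To make the corner loss concrete, here is a minimal NumPy version of a Huber loss over the eight corner distances of a single box. In a real framework the predicted corners are a differentiable function of size, center and rotation, so the gradient flows back to those quantities; this standalone sketch only shows the loss itself.

import numpy as np

def corner_huber_loss(pred_corners, gt_corners, delta=1.0):
    # pred_corners, gt_corners: (8, 3) arrays of box corner coordinates for one box.
    d = np.linalg.norm(pred_corners - gt_corners, axis=1)  # per-corner distance
    quadratic = np.minimum(d, delta)                       # quadratic region up to delta,
    return np.sum(0.5 * quadratic ** 2 + delta * (d - quadratic))  # linear beyond it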

Our 3D object detection results on ScanNet dataset.

In our recent paper, “DOPS: Learning to Detect 3D Objects and Predict their 3D Shapes”, we describe in detail the single-stage weakly supervised learning algorithm used for object detection in TF 3D. In addition, in a follow-up work, we extended the 3D object detection model to leverage temporal information by proposing a sparse LSTM-based multi-frame model. We go on to show that this temporal model outperforms the frame-by-frame approach by 7.5% on the Waymo Open dataset.

The 3D object detection and shape prediction model introduced in the DOPS paper. A 3D sparse U-Net is used to extract a feature vector for each voxel. The object detection module uses these features to propose 3D boxes and semantic scores. At the same time, the other branch of the network predicts a shape embedding that is used to output a mesh for each object.

Ready to Get Started?
We’ve certainly found this codebase to be useful for our 3D computer vision projects, and we hope that you will as well. Contributions to the codebase are welcome, and please stay tuned for our own further updates to the framework. To get started, please visit our GitHub repository.

Acknowledgements
The release of the TensorFlow 3D codebase and model has been the result of widespread collaboration among Google researchers with feedback and testing from product groups. In particular we want to highlight the core contributions by Alireza Fathi and Rui Huang (work performed while at Google), with special additional thanks to Guangda Lai, Abhijit Kundu, Pei Sun, Thomas Funkhouser, David Ross, Caroline Pantofaru, Johanna Wald, Angela Dai and Matthias Niessner.

Uncovering Unknown Unknowns in Machine Learning

The performance of machine learning (ML) models depends on both the learning algorithms and the data used for training and evaluation. The role of the algorithms is well studied and the focus of a multitude of challenges, such as SQuAD, GLUE, ImageNet, and many others. In addition, there have been efforts to also improve the data, including a series of workshops addressing issues for ML evaluation. In contrast, research and challenges that focus on the data used for evaluation of ML models are not commonplace. Furthermore, many evaluation datasets contain items that are easy to evaluate, e.g., photos with a subject that is easy to identify, and thus they miss the natural ambiguity of real world context. The absence of ambiguous real-world examples in evaluation undermines the ability to reliably test machine learning performance, which makes ML models prone to developing “weak spots”, i.e., classes of examples that are difficult or impossible for a model to accurately evaluate, because that class of examples is missing from the evaluation set.

To address the problem of identifying these weaknesses in ML models, we recently launched the Crowdsourcing Adverse Test Sets for Machine Learning (CATS4ML) Data Challenge at HCOMP 2020 (open until 30 April, 2021 to researchers and developers worldwide). The goal of the challenge is to raise the bar in ML evaluation sets and to find as many examples as possible that are confusing or otherwise problematic for algorithms to process. CATS4ML relies on people’s abilities and intuition to spot new data examples that ML models classify with confidence but actually get wrong.

What are ML “Weak Spots”?
There are two categories of weak spots: known unknowns and unknown unknowns. Known unknowns are examples for which a model is unsure about the correct classification. The research community continues to study this in a field known as active learning, and has found the solution to be, in very general terms, to interactively solicit new labels from people on uncertain examples. For example, if a model is not certain whether or not the subject of a photo is a cat, a person is asked to verify; but if the system is certain, a person is not asked. While there is room for improvement in this area, what is comforting is that the confidence of the model is correlated with its performance, i.e., one can see what the model doesn’t know.

Unknown unknowns, on the other hand, are examples where a model is confident about its answer, but is actually wrong. Efforts to proactively discover unknown unknowns (e.g., Attenberg 2015 and Crawford 2019) have helped uncover a multitude of unintended machine behaviours. In contrast to such approaches for the discovery of unknown unknowns, generative adversarial networks (GANs) generate unknown unknowns for image recognition models in the form of optical illusions for computers that cause deep learning models to make mistakes beyond human perception. While GANs uncover model exploits in the event of an intentional manipulation, real-world examples can better highlight a model’s failures in its day-to-day performance. These real-world examples are the unknown unknowns of interest to CATS4ML — the challenge aims to gather unmanipulated examples that humans can reliably interpret but on which many ML models would confidently disagree.

Example illustrating how optical illusions for computers caused by adversarial noise help discover machine manipulated unknown unknowns for ML models (based on Brown 2018).

First Edition of CATS4ML Data Challenge: Open Images Dataset
The CATS4ML Data Challenge focuses on visual recognition, using images and labels from the Open Images Dataset. The target images for the challenge are selected from the Open Images Dataset along with a set of 24 target labels from the same dataset. The challenge participants are invited to invent new and creative ways to explore this existing publicly available dataset and, focusing on a list of pre-selected target labels, discover examples of unknown unknowns for ML models.

Examples from the Open Images Dataset as possible unknown unknowns for ML models.

CATS4ML is a complementary effort to FAIR’s recently introduced DynaBench research platform for dynamic data collection. Where DynaBench tackles issues with static benchmarks using ML models with humans in the loop, CATS4ML focuses on improving evaluation datasets for ML by encouraging the exploration of existing ML benchmarks for adverse examples that can be unknown unknowns. The results will help detect and avoid future errors, and also will give insights to model explainability.

In this way, CATS4ML aims to raise greater awareness of the problem by providing dataset resources that developers can use to uncover the weak spots of their algorithms. This will also inform researchers on how to create benchmark datasets for machine learning that are more balanced, diverse and socially aware.

Get Involved
We invite the global community of ML researchers and practitioners to join us in the effort of discovering interesting, difficult examples from the Open Images Dataset. Register on the challenge website, download the target images and labeled data, contribute the images you discover, and compete to become the winning participant!

To score points in this competition, participants should submit a set of image-label pairs that will be confirmed by human-in-the-loop raters, whose votes should be in disagreement with the average machine score for the label over a number of machine learning models.

An example of how a submitted image can score points. The same image can score as a false positive (Left) and as a false negative (Right) with two different labels. In both cases the human verification is in disagreement with the machine score. Participants score on submitted image-label pairs, which means that one and the same image can be an example of an ML unknown unknown for different labels.

The challenge is open until 30 April, 2021 to researchers and developers worldwide. To learn more about CATS4ML and how to join, please review these slides and visit the challenge website.

Acknowledgements
The release of the CATS4ML Data Challenge has been possible thanks to the hard work of a lot of people including, but not limited to, the following (in alphabetical order of last name): Osman Aka, Ken Burke, Tulsee Doshi, Mig Gerard, Victor Gomes, Shahab Kamali, Igor Karpov, Devi Krishna, Daphne Luong, Carey Radebaugh, Jamie Taylor, Nithum Thain, Kenny Wibowo, Ka Wong, and Tong Zhou.

torch, tidymodels, and high-energy physics

So what’s with the clickbait (high-energy physics)? Well, it’s not just clickbait. To showcase TabNet, we will be using the Higgs dataset (Baldi, Sadowski, and Whiteson (2014)), available at UCI Machine Learning Repository. I don’t know about you, but I always enjoy using datasets that motivate me to learn more about things. But first, let’s get acquainted with the main actors of this post!

TabNet

TabNet was introduced in Arik and Pfister (2020). It is interesting for three reasons:

  • It claims highly competitive performance on tabular data, an area where deep learning has not gained much of a reputation yet.

  • TabNet includes interpretability1 features by design.

  • It is claimed to profit significantly from self-supervised pre-training, again in an area where such gains are anything but commonplace.

In this post, we won’t go into (3), but we do expand on (2), the ways TabNet allows access to its inner workings.

How do we use TabNet from R? The torch ecosystem includes a package – tabnet – that not only implements the model of the same name, but also allows you to make use of it as part of a tidymodels workflow.

tidymodels

To many R-using data scientists, the tidymodels framework will not be a stranger. tidymodels provides a high-level, unified approach to model training, hyperparameter optimization, and inference.

tabnet is the first (of many, we hope) torch model that lets you use a tidymodels workflow all the way: from data pre-processing over hyperparameter tuning to performance evaluation and inference. While the first, as well as the last, may seem nice-to-have but not “mandatory”, the tuning experience is likely to be something you won’t want to do without!

Using tabnet with tidymodels

In this post, we first showcase a tabnet-using workflow in a nutshell, making use of hyperparameter settings reported in the paper.

Then, we initiate a tidymodels-powered hyperparameter search, focusing on the basics but also, encouraging you to dig deeper at your leisure.

Finally, we circle back to the promise of interpretability, demonstrating what is offered by tabnet and ending in a short discussion.

In the flow with TabNet

As usual, we start by loading all required libraries. We also set a random seed, on the R as well as the torch sides. When model interpretation is part of your task, you will want to investigate the role of random initialization.

library(torch)
library(tabnet)
library(tidyverse)
library(tidymodels)
library(finetune) # to use tuning functions from the new finetune package
library(vip) # to plot feature importances

set.seed(777)
torch_manual_seed(777)

Next, we load the dataset.

# download from https://archive.ics.uci.edu/ml/datasets/HIGGS
higgs <- read_csv(
  "HIGGS.csv",
  col_names = c("class", "lepton_pT", "lepton_eta", "lepton_phi", "missing_energy_magnitude",
                "missing_energy_phi", "jet_1_pt", "jet_1_eta", "jet_1_phi", "jet_1_b_tag",
                "jet_2_pt", "jet_2_eta", "jet_2_phi", "jet_2_b_tag", "jet_3_pt", "jet_3_eta",
                "jet_3_phi", "jet_3_b_tag", "jet_4_pt", "jet_4_eta", "jet_4_phi", "jet_4_b_tag",
                "m_jj", "m_jjj", "m_lv", "m_jlv", "m_bb", "m_wbb", "m_wwbb"),
  col_types = "fdddddddddddddddddddddddddddd"
  )

What’s this about? In high-energy physics, the search for new particles takes place at powerful particle accelerators, such as (and most prominently) CERN’s Large Hadron Collider. In addition to actual experiments, simulation plays an important role. In simulations, “measurement” data are generated according to different underlying hypotheses, resulting in distributions that can be compared with each other. Given the likelihood of the simulated data, the goal then is to make inferences about the hypotheses.

The above dataset (Baldi, Sadowski, and Whiteson (2014)) results from just such a simulation. It explores what features could be measured assuming two different processes. In the first process, two gluons collide, and a heavy Higgs boson is produced; this is the signal process, the one we’re interested in. In the second, the collision of the gluons results in a pair of top quarks – this is the background process.

Through different intermediaries, both processes result in the same end products – so tracking these does not help. Instead, what the paper authors did was simulate kinematic features (momenta, specifically) of decay products, such as leptons (electrons and muons) and particle jets. In addition, they constructed a number of high-level features, features that presuppose domain knowledge. In their article, they showed that, in contrast to other machine learning methods, deep neural networks performed nearly as well when presented with the low-level features (the momenta) alone as with the high-level features alone.

Certainly, it would be interesting to double-check these results on tabnet, and then, look at the respective feature importances. However, given the size of the dataset, non-negligible computing resources (and patience) will be required.

Speaking of size, let’s take a look:

higgs %>% glimpse()
Rows: 11,000,000
Columns: 29
$ class                    <fct> 1.000000000000000000e+00, 1.000000…
$ lepton_pT                <dbl> 0.8692932, 0.9075421, 0.7988347, 1…
$ lepton_eta               <dbl> -0.6350818, 0.3291473, 1.4706388, …
$ lepton_phi               <dbl> 0.225690261, 0.359411865, -1.63597…
$ missing_energy_magnitude <dbl> 0.3274701, 1.4979699, 0.4537732, 1…
$ missing_energy_phi       <dbl> -0.68999320, -0.31300953, 0.425629…
$ jet_1_pt                 <dbl> 0.7542022, 1.0955306, 1.1048746, 1…
$ jet_1_eta                <dbl> -0.24857314, -0.55752492, 1.282322…
$ jet_1_phi                <dbl> -1.09206390, -1.58822978, 1.381664…
$ jet_1_b_tag              <dbl> 0.000000, 2.173076, 0.000000, 0.00…
$ jet_2_pt                 <dbl> 1.3749921, 0.8125812, 0.8517372, 2…
$ jet_2_eta                <dbl> -0.6536742, -0.2136419, 1.5406590,…
$ jet_2_phi                <dbl> 0.9303491, 1.2710146, -0.8196895, …
$ jet_2_b_tag              <dbl> 1.107436, 2.214872, 2.214872, 2.21…
$ jet_3_pt                 <dbl> 1.1389043, 0.4999940, 0.9934899, 1…
$ jet_3_eta                <dbl> -1.578198314, -1.261431813, 0.3560…
$ jet_3_phi                <dbl> -1.04698539, 0.73215616, -0.208777…
$ jet_3_b_tag              <dbl> 0.000000, 0.000000, 2.548224, 0.00…
$ jet_4_pt                 <dbl> 0.6579295, 0.3987009, 1.2569546, 0…
$ jet_4_eta                <dbl> -0.01045457, -1.13893008, 1.128847…
$ jet_4_phi                <dbl> -0.0457671694, -0.0008191102, 0.90…
$ jet_4_b_tag              <dbl> 3.101961, 0.000000, 0.000000, 0.00…
$ m_jj                     <dbl> 1.3537600, 0.3022199, 0.9097533, 0…
$ m_jjj                    <dbl> 0.9795631, 0.8330482, 1.1083305, 1…
$ m_lv                     <dbl> 0.9780762, 0.9856997, 0.9856922, 0…
$ m_jlv                    <dbl> 0.9200048, 0.9780984, 0.9513313, 0…
$ m_bb                     <dbl> 0.7216575, 0.7797322, 0.8032515, 0…
$ m_wbb                    <dbl> 0.9887509, 0.9923558, 0.8659244, 1…
$ m_wwbb                   <dbl> 0.8766783, 0.7983426, 0.7801176, 0…

Eleven million “observations” (kind of) – that’s a lot! Like the authors of the TabNet paper (Arik and Pfister (2020)), we’ll use 500,000 of these for validation. (Unlike them, though, we won’t be able to train for 870,000 iterations!)

The first variable, class, is either 1 or 0, depending on whether a Higgs boson was present or not. While in experiments, only a tiny fraction of collisions produce one of those, both classes are about equally frequent in this dataset.

As for the predictors, the last seven are high-level (derived). All others are “measured”.

Data loaded, we’re ready to build a tidymodels workflow, resulting in a short sequence of concise steps.

First, split the data:

n <- 11000000
n_test <- 500000
test_frac <- n_test/n

split <- initial_time_split(higgs, prop = 1 - test_frac)
train <- training(split)
test  <- testing(split)

Second, create a recipe. We want to predict class from all other features present:

rec <- recipe(class ~ ., train) 

Third, create a parsnip model specification of class tabnet. The parameters passed are those reported by the TabNet paper, for the S-sized model variant used on this dataset.2

# hyperparameter settings (apart from epochs) as per the TabNet paper (TabNet-S)
mod <- tabnet(epochs = 3, batch_size = 16384, decision_width = 24, attention_width = 26,
              num_steps = 5, penalty = 0.000001, virtual_batch_size = 512, momentum = 0.6,
              feature_reusage = 1.5, learn_rate = 0.02) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("classification")

Fourth, bundle recipe and model specifications in a workflow:

wf <- workflow() %>%
  add_model(mod) %>%
  add_recipe(rec)

Fifth, train the model. This will take some time. Training finished, we save the trained parsnip model, so we can reuse it at a later time.

fitted_model <- wf %>% fit(train)

# access the underlying parsnip model and save it to RDS format
# depending on when you read this, a nice wrapper may exist
# see https://github.com/mlverse/tabnet/issues/27  
fitted_model$fit$fit$fit %>% saveRDS("saved_model.rds")

After three epochs, loss was at 0.609.

Sixth – and finally – we ask the model for test-set predictions and have accuracy computed.

preds <- test %>% 
  bind_cols(predict(fitted_model, test))

yardstick::accuracy(preds, class, .pred_class)
# A tibble: 1 x 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.672

We didn’t quite arrive at the accuracy reported in the TabNet paper (0.783), but then, we only trained for a tiny fraction of the time.

In case you’re thinking: well, that was a nice and effortless way of training a neural network! – just wait and see how easy hyperparameter tuning can get. In fact, no need to wait, we’ll take a look right now.

TabNet tuning

For hyperparameter tuning, the tidymodels framework makes use of cross-validation. With a dataset of considerable size, some time and patience are needed; for the purpose of this post, I’ll use 1/1,000 of the observations.

Changes to the above workflow start at model specification. Let’s say we’ll leave most settings fixed, but vary the TabNet-specific hyperparameters decision_width, attention_width, and num_steps, as well as the learning rate:3

mod <- tabnet(epochs = 1, batch_size = 16384, decision_width = tune(), attention_width = tune(),
              num_steps = tune(), penalty = 0.000001, virtual_batch_size = 512, momentum = 0.6,
              feature_reusage = 1.5, learn_rate = tune()) %>%
  set_engine("torch", verbose = TRUE) %>%
  set_mode("classification")

Workflow creation looks the same as before:

wf <- workflow() %>%
  add_model(mod) %>%
  add_recipe(rec)

Next, we specify the hyperparameter ranges we’re interested in, and call one of the grid construction functions from the dials package to build one for us. If it wasn’t for demonstration purposes, we’d probably want to have more than eight alternatives though, and pass a higher size to grid_max_entropy().

grid <-
  wf %>%
  parameters() %>%
  update(
    decision_width = decision_width(range = c(20, 40)),
    attention_width = attention_width(range = c(20, 40)),
    num_steps = num_steps(range = c(4, 6)),
    learn_rate = learn_rate(range = c(-2.5, -1))
  ) %>%
  grid_max_entropy(size = 8)

grid
# A tibble: 8 x 4
  learn_rate decision_width attention_width num_steps
       <dbl>          <int>           <int>     <int>
1    0.00529             28              25         5
2    0.0858              24              34         5
3    0.0230              38              36         4
4    0.0968              27              23         6
5    0.0825              26              30         4
6    0.0286              36              25         5
7    0.0230              31              37         5
8    0.00341             39              23         5

To search the space, we use tune_race_anova() from the new finetune package, making use of five-fold cross-validation:

ctrl <- control_race(verbose_elim = TRUE)
folds <- vfold_cv(train, v = 5)
set.seed(777)

res <- wf %>% 
    tune_race_anova(
    resamples = folds, 
    grid = grid,
    control = ctrl
  )

We can now extract the best hyperparameter combinations:

res %>% show_best("accuracy") %>% select(- c(.estimator, .config))
# A tibble: 5 x 8
  learn_rate decision_width attention_width num_steps .metric   mean     n std_err
       <dbl>          <int>           <int>     <int> <chr>    <dbl> <int>   <dbl>
1     0.0858             24              34         5 accuracy 0.516     5 0.00370
2     0.0230             38              36         4 accuracy 0.510     5 0.00786
3     0.0230             31              37         5 accuracy 0.510     5 0.00601
4     0.0286             36              25         5 accuracy 0.510     5 0.0136 
5     0.0968             27              23         6 accuracy 0.498     5 0.00835

It’s hard to imagine how tuning could be more convenient!

Now, we circle back to the original training workflow, and inspect TabNet’s interpretability features.

TabNet interpretability features

TabNet’s most prominent characteristic is the way – inspired by decision trees – it executes in distinct steps. At each step, it again looks at the original input features, and decides which of those to consider based on lessons learned in prior steps. Concretely, it uses an attention mechanism to learn sparse masks which are then applied to the features.

Now, these masks being “just” model weights means we can extract them and draw conclusions about feature importance. Depending on how we proceed, we can either

  • aggregate mask weights over steps, resulting in global per-feature importances;

  • run the model on a few test samples and aggregate over steps, resulting in observation-wise feature importances; or

  • run the model on a few test samples and extract individual weights observation- as well as step-wise.

This is how to accomplish the above with tabnet.

Per-feature importances

We continue with the fitted_model workflow object we ended up with at the end of part 1. vip::vip is able to display feature importances directly from the parsnip model:

fit <- pull_workflow_fit(fitted_model)
vip(fit) + theme_minimal()
Global feature importances.

Together, two high-level features dominate, accounting for nearly 50% of overall attention. Along with a third high-level feature, ranked in place four, they occupy about 60% of “importance space”.

Observation-level feature importances

We choose the first hundred observations in the test set to extract feature importances. Due to how TabNet enforces sparsity, we see that many features have not been made use of:

ex_fit <- tabnet_explain(fit$fit, test[1:100, ])

ex_fit$M_explain %>%
  mutate(observation = row_number()) %>%
  pivot_longer(-observation, names_to = "variable", values_to = "m_agg") %>%
  ggplot(aes(x = observation, y = variable, fill = m_agg)) +
  geom_tile() +
  theme_minimal() + 
  scale_fill_viridis_c()
Per-observation feature importances.

Per-step, observation-level feature importances

Finally and on the same selection of observations, we again inspect the masks, but this time, per decision step:

ex_fit$masks %>% 
  imap_dfr(~mutate(
    .x, 
    step = sprintf("Step %d", .y),
    observation = row_number()
  )) %>% 
  pivot_longer(-c(observation, step), names_to = "variable", values_to = "m_agg") %>% 
  ggplot(aes(x = observation, y = variable, fill = m_agg)) +
  geom_tile() +
  theme_minimal() + 
  theme(axis.text = element_text(size = 5)) +
  scale_fill_viridis_c() +
  facet_wrap(~step)
Per-observation, per-step feature importances.

This is nice: We clearly see how TabNet makes use of different features at different times.

So what do we make of this? It depends. Given the enormous societal importance of this topic – call it interpretability, explainability, or whatever – let’s finish this post with a short discussion.

Interpretable, explainable, …? Beyond the arbitrariness of definitions

An internet search for “interpretable vs. explainable ML” immediately turns up a number of sites confidently stating “interpretable ML is …” and “explainable ML is …”, as though there were no arbitrariness in common-speech definitions. Going deeper, you find articles such as Cynthia Rudin’s “Stop Explaining Black Box Machine Learning Models for High Stakes Decisions and Use Interpretable Models Instead” (Rudin (2018)) that present you with a clear-cut, deliberate, instrumentalizable distinction that can actually be used in real-world scenarios.

In a nutshell, what she decides to call explainability is: approximate a black-box model by a simpler (e.g., linear) model and, starting from the simple model, make inferences about how the black-box model works. One of the examples she gives for how this could fail is so striking I’d like to fully cite it:

Even an explanation model that performs almost identically to a black box model might use completely different features, and is thus not faithful to the computation of the black box. Consider a black box model for criminal recidivism prediction, where the goal is to predict whether someone will be arrested within a certain time after being released from jail/prison. Most recidivism prediction models depend explicitly on age and criminal history, but do not explicitly depend on race. Since criminal history and age are correlated with race in all of our datasets, a fairly accurate explanation model could construct a rule such as “This person is predicted to be arrested because they are black.” This might be an accurate explanation model since it correctly mimics the predictions of the original model, but it would not be faithful to what the original model computes.

What she calls interpretability, in contrast, is deeply related to domain knowledge:

Interpretability is a domain-specific notion […] Usually, however, an interpretable machine learning model is constrained in model form so that it is either useful to someone, or obeys structural knowledge of the domain, such as monotonicity [e.g.,8], causality, structural (generative) constraints, additivity [9], or physical constraints that come from domain knowledge. Often for structured data, sparsity is a useful measure of interpretability […]. Sparse models allow a view of how variables interact jointly rather than individually. […] e.g., in some domains, sparsity is useful, and in others it is not.

If we accept these well-thought-out definitions, what can we say about TabNet? Is looking at attention masks more like constructing a post-hoc model or more like having domain knowledge incorporated? I believe Rudin would argue the former, since

  • the image-classification example she uses to point out weaknesses of explainability techniques employs saliency maps, a technical device comparable, in some ontological sense, to attention masks;

  • the sparsity enforced by TabNet is a technical, not a domain-related constraint;

  • we only know what features were used by TabNet, not how it used them.

On the other hand, one could disagree with Rudin (and others) about the premises. Do explanations have to be modeled after human cognition to be considered valid? Personally, I guess I’m not sure, and to cite from a post by Keith O’Rourke on just this topic of interpretability,

As with any critically-thinking inquirer, the views behind these deliberations are always subject to rethinking and revision at any time.

In any case though, we can be sure that this topic’s importance will only grow with time. While in the very early days of the GDPR (the EU General Data Protection Regulation) it was said that Article 22 (on automated decision-making) would have significant impact on how ML is used4, unfortunately the current view seems to be that its wordings are far too vague to have immediate consequences (e.g., Wachter, Mittelstadt, and Floridi (2017)). But this will be a fascinating topic to follow, from a technical as well as a political point of view.

Thanks for reading!

Arik, Sercan O., and Tomas Pfister. 2020. “TabNet: Attentive Interpretable Tabular Learning.” http://arxiv.org/abs/1908.07442.

Baldi, P., P. Sadowski, and D. Whiteson. 2014. “Searching for exotic particles in high-energy physics with deep learning.” Nature Communications 5 (July): 4308. https://doi.org/10.1038/ncomms5308.

Rudin, Cynthia. 2018. “Stop Explaining Black Box Machine Learning Models for High Stakes Decisions and Use Interpretable Models Instead.” http://arxiv.org/abs/1811.10154.

Wachter, Sandra, Brent Mittelstadt, and Luciano Floridi. 2017. “Why a Right to Explanation of Automated Decision-Making Does Not Exist in the General Data Protection Regulation.” International Data Privacy Law 7 (2): 76–99. https://doi.org/10.1093/idpl/ipx005.

  1. I’m using the term “naively” here. For a short discussion, see the final section, Interpretable, explainable, …? Beyond the arbitrariness of definitions.↩︎

  2. Apart from the number of epochs, that is.↩︎

  3. The number of epochs is set to one for demonstration purposes only; in reality, you will want to tune this as well.↩︎

  4. See, e.g., <http://www.odbms.org/2018/07/ai-machine-learning-and-the-gdpr-are-the-wild-west-days-of-advanced-analytics-over/>.↩︎

The ants and the pheromones

TLDR; this is the last edition of The Morning Paper for now. Plus: one strand of research you won’t want to miss!

I was listening to a BBC Radio 4 podcast recently (More or Less: Behind the Stats – Ants and Algorithms) in which the host Tim Harford is interviewing David Sumpter about his recent book, ‘The ten equations that rule the world.’ One of those equations, the ‘reward equation’, models how ants communicate using pheromones, and how our own brains keep track of rewards using dopamine.

About 4 and a half minutes into the podcast Tim asks a fascinating question: the reward equation includes a decay or ‘forgetting’ parameter, so what happens if you disrupt established solutions for long enough that their hold is broken? For example, the complete disruption to our established routines that Covid has caused over the last year? The answer for ants, if you disrupt all of the pheromone trails around their nest, is that they converge on a new solution in the environment, but it won’t necessarily look the same as the one they had before the disruption. (If you’re interested in the amazing problem-solving skills of ants and how we can learn from them in computer science, I covered ‘Ant algorithms for discrete optimization’ in a previous edition of The Morning Paper). It’s highly likely that the same thing will happen to us when we can eventually return to normal – the patterns that we establish won’t necessarily be the same as the ones we had before the series of lockdowns began.

The lockdowns (as I write this, we’re in another strict lockdown in England, with no end date given) have certainly disrupted my own routines. I’ve lost the time and space that I depended on for studying and writing The Morning Paper – the one-hour each way train journey on my morning commute, and more crucially with two older children both full-time studying from home, the time and space within the home for the many hours of concentrated work required. I don’t think my love of learning will ever leave me though, and at the same time I’ve been branching out and studying other things: philosophy, ethics, physics, a little maths, a little biology,… I’m really enjoying that. My love of computer science remains of course, but when we finally get to lay down our new pheromone trails and establish a new normal, I’m not sure I’m going to want to focus on computer science to the exclusion of all else. It’s been an intense six-and-a-half years doing largely that while writing the blog. For the time being then, I’m putting The Morning Paper back on pause.

Before I wrap up though, I can’t resist pointing you in the direction of one incredibly exciting research project from the Hydro team at Berkeley’s RISELab. Joe Hellerstein recently posted a whole bunch of links and resources in a Twitter thread, including the following:

The “PACT” paper is here: New Directions in Cloud Programming, Cheung et al., CIDR 2021.

TracIn — A Simple Method to Estimate Training Data Influence

The quality of a machine learning (ML) model’s training data can have a significant impact on its performance. One measure of data quality is the notion of influence, i.e., the degree to which a given training example affects the model and its predictive performance. And while influence is a well-known concept to ML researchers, the complexity behind deep learning models, coupled with their growing size, features and datasets, have made the quantification of influence difficult.

A few methods have been proposed recently to quantify influence. Some rely on changes in accuracy when retraining with one or several data points dropped, and some use established statistical methods, e.g., influence functions that estimate the impact of perturbing input points or representer methods that decompose a prediction into an importance weighted combination of training examples. Still other approaches require use of additional estimators, such as data valuation using reinforcement learning. Though these approaches are theoretically sound, their use in products has been limited by the resources needed to run them at scale or the additional burdens they place on training.

In “Estimating Training Data Influence by Tracing Gradient Descent”, published as a spotlight paper at NeurIPS 2020, we proposed TracIn, a simple scalable approach to tackle this challenge. The idea behind TracIn is straightforward — trace the training process to capture changes in prediction as individual training examples are visited. TracIn is effective in finding mislabeled examples and outliers from a variety of datasets, and is useful in explaining predictions in terms of training examples (as opposed to features) by assigning an influence score to each training example.

The Ideas Underlying TracIn
Deep learning algorithms are typically trained using an algorithm called stochastic gradient descent (SGD), or a variant of it. SGD operates by making multiple passes over the data and making modifications to the model parameters that locally reduce the loss (i.e., the model’s objective) with each pass. An example of this is demonstrated for an image classification task in the figure below, where the model’s task is to predict the subject of the test image on the left (“zucchini”). As the model progresses through training, it is exposed to various training examples that affect the loss on the test image, where the loss is a function both of the prediction score and the actual label — the higher the prediction score for zucchini, the lower the loss.

Estimating training data influence of the images on the right by tracing the loss change of the zucchini in the seatbelt image during training.

Suppose that the test example is known at training time and that the training process visited each training example one at a time. During the training, visiting a specific training example would change the model’s parameters, and that change would then modify the prediction/loss on the test example. If one could trace the training example through the process, then the change in loss or prediction on the test example could be attributed to the training example in question, where the influence of a training example would be the cumulative attribution across visits to the training example.

There are two types of relevant training examples. Those that reduce loss, like the images of zucchinis above, are called proponents, while those that increase loss, like the images of seatbelts, are called opponents. In the example above, the image labeled “sunglasses” is also a proponent, because it has a seatbelt in the image, but is labeled as “sunglasses,” driving the model to better distinguish between zucchini and seatbelts.

In practice, the test example is unknown at training time, a limitation that can be overcome by using the checkpoints output by the learning algorithm as a sketch of the training process. Another challenge is that the learning algorithm typically visits several points at once, not individually, which requires a method to disentangle the relative contributions of each training example. This can be done by applying pointwise loss gradients. Together, these two strategies capture the TracIn method, which can be reduced to the simple form of the dot product of loss gradients of the test and training examples, weighted by the learning rate, and summed across checkpoints.

The simple expression for TracIn influence. The dot product of loss gradients of training example (z) and test example (z’) is weighted by learning rate (ηi) at different checkpoints and summed up.
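In code, this idealized score is just a learning-rate-weighted sum of gradient dot products over checkpoints. The sketch below assumes the per-checkpoint loss gradients have already been computed and flattened (which is framework-specific and not shown); positive values mark proponents, negative values mark opponents.

import numpy as np

def tracin_influence(train_grads, test_grads, learning_rates):
    # train_grads, test_grads: lists with one flattened loss-gradient vector per
    # checkpoint, for the training example z and the test example z' respectively.
    # learning_rates: the learning rate in effect at each of those checkpoints.
    return sum(
        lr * np.dot(g_train, g_test)
        for lr, g_train, g_test in zip(learning_rates, train_grads, test_grads)
    )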

Alternatively, one could instead examine the influence on the prediction score, which would be useful if the test example has no label. This form simply requires the substitution of the loss gradient at the test example with the prediction gradient.

Computing Top Influence Examples
We illustrate the utility of TracIn by first calculating the loss gradient vector for some training data and a test example for a specific classification — an image of a chameleon — and then leveraging a standard k-nearest neighbors library to retrieve the top proponents and opponents. The top opponents indicate the chameleon’s ability to blend in! For comparison, we also show the k nearest neighbors with embeddings from the penultimate layer. Proponents are images that are not only similar, but also belong to the same class, and opponents are similar images but in a different class. Note that there isn’t an explicit enforcement on whether proponents or opponents belong to the same class.
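Concretely, once learning-rate-weighted, checkpoint-concatenated gradient vectors are available, proponents and opponents can be ranked with a plain inner product; the random arrays below are stand-ins for real gradients, and any inner-product nearest-neighbor library could be used instead.

import numpy as np

rng = np.random.default_rng(0)
train_grads = rng.normal(size=(1000, 512))  # stand-in for per-training-example gradient vectors
test_grad = rng.normal(size=512)            # stand-in for the chameleon test example's gradient

scores = train_grads @ test_grad            # TracIn influence of every training example
proponents = np.argsort(-scores)[:5]        # indices of the most loss-reducing examples
opponents = np.argsort(scores)[:5]          # indices of the most loss-increasing examples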

Top row: Top proponents and opponents of influence vectors. Bottom row: Most similar and dissimilar examples of embedding vectors from the penultimate layer.

Clustering
The simplistic breakdown of the loss of the test example into training example influences given by TracIn also suggests that the loss (or prediction) from any gradient descent based neural model can be expressed as a sum of similarities in the space of gradients. Recent work has demonstrated that this functional form is similar to that of a kernel, implying that this gradient similarity described here can be applied to other similarity tasks, like clustering.

In this case, TracIn can be used as a similarity function within a clustering algorithm. To bound the similarity metric so that it can be converted to a distance measure (1 – similarity), we normalize the gradient vectors to have unit norm. Below, we apply TracIn clustering on images of zucchini to obtain finer clusters.
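A small sketch of that recipe, using SciPy's hierarchical clustering; the random matrix stands in for real per-image gradient vectors, and the number of clusters is arbitrary.

import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

rng = np.random.default_rng(0)
grads = rng.normal(size=(100, 512))                          # stand-in for per-image gradients

unit = grads / np.linalg.norm(grads, axis=1, keepdims=True)  # unit-norm gradient vectors
distance = 1.0 - unit @ unit.T                               # distance = 1 - similarity
labels = fcluster(
    linkage(squareform(distance, checks=False), method="average"),
    t=3, criterion="maxclust",                               # e.g. three finer clusters
)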

Finer clusters within Zucchini images using TracIn similarity. Each row is a cluster with zucchini in similar forms within the cluster: cross-sectionally sliced zucchini (top), piles of zucchinis (middle), and zucchinis on pizzas (bottom).

Identifying Outliers with Self-Influence
Finally, we can also use TracIn to identify outliers that exhibit a high self-influence, i.e., the influence of a training point on its own prediction. This happens either when the example is mislabeled or rare, both of which make it difficult for the model to generalize over the example. Below are some examples with high self-influence.

Mislabeled examples. Assigned labels are struck out; correct labels are at the bottom.

Left: A rare oscilloscope example with just the oscillations, and no instrument in the image gets high self-influence. Right: Other common oscilloscope images have the scope with knobs and wires. These have a low self-influence.

Applications
Having no requirement other than being trained using SGD (or related variants), TracIn is task-independent and applicable to a variety of models. For example, we have used TracIn to study training data for a deep learning model used to parse queries to the Google Assistant, queries of the kind “set my alarm for 7AM”. We were intrigued to see that the top opponent for the query “disable my alarm” with an alarm active on the device, was “disable my timer”, also with an alarm active on the device. This suggests that Assistant users often interchange the words “timer” and “alarm”. TracIn helped us interpret the Assistant data.

More examples can be found in the paper, including a regression task on structured data and a number of text classification tasks.

Conclusion
TracIn is a simple, easy-to-implement, scalable way to compute the influence of training data examples on individual predictions, or to find rare and mislabeled training examples. For reference implementations of the method, see the code examples for image models in the GitHub repository linked from the paper.

Acknowledgements
The NeurIPS paper was co-authored with Satyen Kale and Mukund Sundararajan (corresponding author). A special thanks to Binbin Xiong for providing various conceptual and implementation insights. We also thank Qiqi Yan and Salem Haykal for numerous discussions. Images throughout this post sourced from Getty Images.

Categories
Offsites

Machine Learning for Computer Architecture

One of the key contributors to recent machine learning (ML) advancements is the development of custom accelerators, such as Google TPUs and Edge TPUs, which significantly increase available compute power, unlocking various capabilities such as AlphaGo, RankBrain, WaveNets, and Conversational Agents. This increase can lead to improved performance in neural network training and inference, enabling new possibilities in a broad range of applications, such as vision, language understanding, and self-driving cars.

To sustain these advances, the hardware accelerator ecosystem must continue to innovate in architecture design and adapt to rapidly evolving ML models and applications. This requires the evaluation of many different accelerator design points, each of which may not only improve the compute power, but also unlock a new capability. These design points are generally parameterized by a variety of hardware and software factors (e.g., memory capacity, number of compute units at different levels, parallelism, interconnection networks, pipelining, software mapping, etc.). This is a daunting optimization task, because the search space is exponentially large1 while the objective function (e.g., lower latency and/or higher energy efficiency) is computationally expensive to evaluate through simulations or synthesis, making the identification of feasible accelerator configurations challenging.

In “Apollo: Transferable Architecture Exploration”, we present the progress of our research on ML-driven design of custom accelerators. While recent work has demonstrated promising results in leveraging ML to improve the low-level floorplanning process (in which the hardware components are spatially laid out and connected in silicon), in this work we focus on blending ML into the high-level system specification and architectural design stage, in which the design elements that control the high-level functionality are established and which is a pivotal contributor to the overall performance of the chip. Our research shows how ML algorithms can facilitate architecture exploration and suggest high-performing architectures across a range of deep neural networks, with domains spanning image classification, object detection, OCR, and semantic segmentation.

Architecture Search Space and Workloads
The objective in architecture exploration is to discover a set of feasible accelerator parameters for a set of workloads, such that a desired objective function (e.g., the weighted average of runtime) is minimized under an optional set of user-defined constraints. However, the manifold of the architecture search space generally contains many points for which there is no feasible mapping from software to hardware. Some of these design points are known a priori and can be bypassed by the user formulating them as optimization constraints (e.g., for an area budget2 constraint, the total memory size must not exceed a predefined limit). However, due to the interplay of the architecture and the compiler and the complexity of the search space, some of the constraints may not be properly formulated into the optimization, and so the compiler may fail to find a feasible software mapping for the target hardware. These infeasible points are not easy to formulate in the optimization problem, and are generally unknown until the whole compiler pass is performed. As such, one of the main challenges for architecture exploration is to effectively sidestep the infeasible points, so that the search space can be explored efficiently with a minimum number of cycle-accurate architecture simulations.
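
Schematically, and only as a paraphrase of the description above (not the paper's exact notation), the search can be written as:

$$\min_{\theta \in \Theta} \; \sum_{w \in W} a_w \, \mathrm{Runtime}(\theta, w) \quad \text{s.t.} \quad \mathrm{Area}(\theta) \le \text{budget}, \;\; \theta \text{ admits a feasible software mapping},$$

where theta is an accelerator configuration, W the set of target workloads, and a_w user-chosen workload weights. The second constraint is precisely the part that is generally unknown until the compiler pass has been run.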

The following figure shows the overall architecture search space of a target ML accelerator. The accelerator contains a 2D array of processing elements (PE), each of which performs a set of arithmetic computations in a single instruction multiple data (SIMD) manner. The main architectural components of each PE are processing cores that include multiple compute lanes for SIMD operations. Each PE has memory (PE Memory) shared across all of its compute cores, which is mainly used to store model activations, partial results, and outputs, while individual cores feature memory that is mainly used for storing model parameters. Each core has multiple compute lanes with multi-way multiply-accumulate (MAC) units. The results of model computations at each cycle are either stored back in the PE memory for further computation or offloaded back into DRAM.

Overview of the template-based ML accelerator used for architecture exploration.

Optimization Strategies
In this study, we explored four optimization strategies in the context of architecture exploration:

  1. Random: Samples the architecture search space uniformly at random.
  2. Vizier: Uses Bayesian optimization for the exploration in the search space in which the evaluation of the objective function is expensive (e.g. hardware simulation which can take hours to complete). Using a collection of sampled points from the search space, the Bayesian optimization forms a surrogate function, usually represented by a Gaussian process, that approximates the manifold of the search space. Guided by the value of the surrogate function, the Bayesian optimization algorithm decides, in an exploration and exploitation trade-off, whether to sample more from the promising regions in the manifold (exploitation) or sample more from the unseen regions in the search space (exploration). Then, the optimization algorithm uses these newly sampled points and further updates the surrogate function to better model the target search space. Vizier uses expected improvement as its core acquisition function.
  3. Evolutionary: Performs evolutionary search using a population of k individuals, where the genome of each individual corresponds to a sequence of discretized accelerator configurations. New individuals are generated by selecting, for each individual, two parents from the population using tournament selection, recombining their genomes with some crossover rate, and mutating the recombined genome with some probability (a minimal sketch follows after this list).
  4. Population-based black-box optimization (P3BO): Uses an ensemble of optimization methods, including evolutionary and model-based, which has been shown to increase sample-efficiency and robustness. The sampled data are exchanged between optimization methods in the ensemble, and optimizers are weighted by their performance history to generate new configurations. In our study, we use a variant of P3BO in which the hyper-parameters of the optimizers are dynamically updated using evolutionary search.
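
To make the evolutionary strategy (item 3) concrete, here is a minimal, self-contained sketch in R. The genome encoding, fitness function, and hyper-parameters are illustrative stand-ins, not the actual Apollo setup; infeasible configurations are modeled as fitness -Inf:

set.seed(1)
genome_length <- 10      # number of discretized accelerator parameters
n_levels <- 4            # discrete settings per parameter
pop_size <- 20
n_generations <- 30
crossover_rate <- 0.8
mutation_rate <- 0.1

# stand-in fitness: some configurations count as "infeasible" and score -Inf
fitness <- function(g) if (sum(g) %% 3 == 0) -Inf else sum(g * seq_along(g))

random_genome <- function() sample(n_levels, genome_length, replace = TRUE)
population <- replicate(pop_size, random_genome(), simplify = FALSE)

tournament <- function(pop, scores, k = 3) {
  idx <- sample(length(pop), k)
  pop[[idx[which.max(scores[idx])]]]
}

for (gen in seq_len(n_generations)) {
  scores <- sapply(population, fitness)
  population <- lapply(seq_len(pop_size), function(i) {
    p1 <- tournament(population, scores)
    p2 <- tournament(population, scores)
    child <- if (runif(1) < crossover_rate) {
      cut <- sample(genome_length - 1, 1)               # one-point crossover
      c(p1[1:cut], p2[(cut + 1):genome_length])
    } else p1
    flip <- runif(genome_length) < mutation_rate        # per-gene mutation
    child[flip] <- sample(n_levels, sum(flip), replace = TRUE)
    child
  })
}

best <- population[[which.max(sapply(population, fitness))]]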

Accelerator Search Space Embeddings
To better visualize the effectiveness of each optimization strategy in navigating the accelerator search space, we use t-distributed stochastic neighbor embedding (t-SNE) to map the explored configurations into a two-dimensional space across the optimization horizon. The objective (reward) for all the experiments is defined as the throughput (inference/second) per accelerator area. In the figures below, the x and y axes indicate the t-SNE components (embedding 1 and embedding 2) of the embedding space. The star and circular markers show the infeasible (zero reward) and feasible design points, respectively, with the size of the feasible points corresponding to their reward.
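
A rough sketch of this visualization step, assuming the Rtsne package and hypothetical configs / reward objects (one row of configs and one reward value per explored design point):

library(Rtsne)

# configs: numeric matrix of explored accelerator configurations (one row per design point)
# reward:  numeric vector of rewards, with 0 indicating an infeasible point
emb <- Rtsne(configs, dims = 2, perplexity = 30, check_duplicates = FALSE)$Y

plot(emb,
     pch = ifelse(reward == 0, 8, 16),        # stars for infeasible, circles for feasible
     cex = 0.5 + reward,                      # feasible point size scales with reward
     xlab = "embedding 1", ylab = "embedding 2")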

As expected, the random strategy searches the space in a uniformly distributed way and eventually finds very few feasible points in the design space.

Visualization presenting the t-SNE of the explored design points (~4K) by random optimization strategy (max reward = 0.96). The maximum reward points (red cross markers) are highlighted at the last frame of the animation.

Compared to the random sampling approach, the Vizier default optimization strategy strikes a good balance between exploring the search space and finding the design points with higher rewards (1.14 vs. 0.96). However, this approach tends to get stuck in infeasible regions and, while it does find a few points with the maximum reward (indicated by the red cross markers), it finds few feasible points during the last iterations of exploration.

As above, with the Vizier (default) optimization strategy (max reward = 1.14). The maximum reward points (red cross markers) are highlighted at the last frame of the animation.

The evolutionary optimization strategy, on the other hand, finds feasible solutions very early in the optimization and assembles clusters of feasible points around them. As such, this approach mostly navigates the feasible regions (the green circles) and efficiently sidesteps the infeasible points. In addition, the evolutionary search is able to find more design options with maximum reward (the red crosses). This diversity in high-reward solutions provides flexibility to the designer in exploring various architectures with different design trade-offs.

As above, with the evolutionary optimization strategy (max reward = 1.10). The maximum reward points (red cross markers) are highlighted at the last frame of the animation.

Finally, the population-based optimization method (P3BO) explores the design space in a more targeted way (regions with high reward points) in order to find optimal solutions. The P3BO strategy finds design points with the highest reward in search spaces with tighter constraints (e.g., a larger number of infeasible points), showing its effectiveness in navigating search spaces with large numbers of infeasible points.

As above, with the P3BO optimization strategy (max reward = 1.13). The maximum reward points (red cross markers) are highlighted at the last frame of the animation.

Architecture Exploration under Different Design Constraints
We also studied the benefits of each optimization strategy under different area budget constraints: 6.8 mm2, 5.8 mm2, and 4.8 mm2. The following violin plots show the full distribution of the maximum achievable reward at the end of optimization (after ten runs, each with 4K trials) across the studied optimization strategies. The wider sections represent a higher probability of observing feasible architecture configurations at a given reward. This means we favor an optimization algorithm that yields increased width at points with higher reward (higher performance).

The two top-performing optimization strategies for architecture exploration are evolutionary and P3BO, both in terms of delivering solutions with high reward and robustness across multiple runs. Looking into different design constraints, we observe that as the area budget constraint is tightened, the P3BO optimization strategy yields more high-performing solutions. For example, when the area budget constraint is set to 5.8 mm2, P3BO finds design points with a reward (throughput / accelerator area) of 1.25, outperforming all the other optimization strategies. The same trend is observed when the area budget constraint is set to 4.8 mm2: a slightly better reward is found, with more robustness (less variability) across multiple runs.

Violin plot showing the full distribution of the maximum achievable reward in ten runs across the optimization strategies after 4K trial evaluations under an area budget of 6.8 mm2. The P3BO and Evolutionary algorithms yield larger numbers of high-performing designs (wider sections). The x and y axes indicate the studied optimization algorithms and the geometric mean of speedup (reward) over the baseline accelerator, respectively.
As above, under an area budget of 5.8 mm2.
As above, under an area budget of 4.8 mm2.

Conclusion
While Apollo presents the first step towards better understanding of accelerator design space and building more efficient hardware, inventing accelerators with new capabilities is still an uncharted territory and a new frontier. We believe that this research is an exciting path forward to further explore ML-driven techniques for architecture design and co-optimization (e.g., compiler, mapping, and scheduling) across the computing stack to invent efficient accelerators with new capabilities for the next generation of applications.

Acknowledgments
This work was performed by Amir Yazdanbakhsh, Christof Angermueller, and Berkin Akin. We would also like to thank Milad Hashemi, Kevin Swersky, James Laudon, Herman Schmit, Cliff Young, Yanqi Zhou, Albin Jones, Satrajit Chatterjee, Ravi Narayanaswami, Ray (I-Jui) Sung, Suyog Gupta, Kiran Seshadri, Suvinay Subramanian, Matthew Denton, and the Vizier team for their help and support.


1 In our target accelerator, the total number of design points is around 5 x 10^8.

2 The chip area is approximately the sum of the areas of all hardware components on the chip, including on-chip storage, processing engines, controllers, and I/O pins.

Categories
Offsites

Simple audio classification with torch

This article translates Daniel Falbel’s ‘Simple Audio Classification’ article from tensorflow/keras to torch/torchaudio. The main goal is to introduce torchaudio and illustrate its contributions to the torch ecosystem. Here, we focus on a popular dataset, the audio loader and the spectrogram transformer. An interesting side product is the parallel between torch and tensorflow, showing sometimes the differences, sometimes the similarities between them.

library(torch)
library(torchaudio)

Downloading and Importing

torchaudio has the speechcommand_dataset built in. It filters out background_noise by default and lets us choose between versions v0.01 and v0.02.

# set an existing folder here to cache the dataset
DATASETS_PATH <- "~/datasets/"

# 1.4GB download
df <- speechcommand_dataset(
  root = DATASETS_PATH, 
  url = "speech_commands_v0.01",
  download = TRUE
)

# expect folder: _background_noise_
df$EXCEPT_FOLDER
# [1] "_background_noise_"

# number of audio files
length(df)
# [1] 64721

# a sample
sample <- df[1]

sample$waveform[, 1:10]
torch_tensor
0.0001 *
 0.9155  0.3052  1.8311  1.8311 -0.3052  0.3052  2.4414  0.9155 -0.9155 -0.6104
[ CPUFloatType{1,10} ]
sample$sample_rate
# 16000
sample$label
# bed

plot(sample$waveform[1], type = "l", col = "royalblue", main = sample$label)
A sample waveform for a 'bed'.

Classes

df$classes
 [1] "bed"    "bird"   "cat"    "dog"    "down"   "eight"  "five"  
 [8] "four"   "go"     "happy"  "house"  "left"   "marvin" "nine"  
[15] "no"     "off"    "on"     "one"    "right"  "seven"  "sheila"
[22] "six"    "stop"   "three"  "tree"   "two"    "up"     "wow"   
[29] "yes"    "zero"  

Generator Dataloader

torch::dataloader has the same task as data_generator defined in the original article. It is responsible for preparing batches – including shuffling, padding, one-hot encoding, etc. – and for taking care of parallelism / device I/O orchestration.

In torch we do this by passing the train/test subset to torch::dataloader and encapsulating all the batch setup logic inside a collate_fn() function.

set.seed(6)
id_train <- sample(length(df), size = 0.7*length(df))
id_test <- setdiff(seq_len(length(df)), id_train)
# subsets

train_subset <- torch::dataset_subset(df, id_train)
test_subset <- torch::dataset_subset(df, id_test)

At this point, dataloader(train_subset) would not work because the samples are not padded. So we need to build our own collate_fn() with the padding strategy.

I suggest using the following approach when implementing the collate_fn():

  1. begin with collate_fn <- function(batch) browser().
  2. instantiate a dataloader with this collate_fn().
  3. create an environment by calling enumerate(dataloader) so you can request a batch from the dataloader.
  4. run environment[[1]][[1]]. Now you should be dropped inside collate_fn() with access to the batch input object.
  5. build the logic.
collate_fn <- function(batch) {
  browser()
}

ds_train <- dataloader(
  train_subset, 
  batch_size = 32, 
  shuffle = TRUE, 
  collate_fn = collate_fn
)

ds_train_env <- enumerate(ds_train)
ds_train_env[[1]][[1]]

The final collate_fn() pads the waveform to length 16001 and then stacks everything up together. At this point there are no spectrograms yet; we are going to make the spectrogram transformation part of the model architecture.

pad_sequence <- function(batch) {
    # Make all tensors in a batch the same length by padding with zeros
    batch <- sapply(batch, function(x) (x$t()))
    batch <- torch::nn_utils_rnn_pad_sequence(batch, batch_first = TRUE, padding_value = 0.)
    return(batch$permute(c(1, 3, 2)))
  }

# Final collate_fn
collate_fn <- function(batch) {
 # Input structure:
 # list of 32 lists: list(waveform, sample_rate, label, speaker_id, utterance_number)
 # Transpose it
 batch <- purrr::transpose(batch)
 tensors <- batch$waveform
 targets <- batch$label_index

 # Group the list of tensors into a batched tensor
 tensors <- pad_sequence(tensors)
 
 # target encoding
 targets <- torch::torch_stack(targets)

 list(tensors = tensors, targets = targets) # (32, 1, 16001)
}

Batch structure is:

  • batch[[1]]: waveforms – tensor with dimension (32, 1, 16001)
  • batch[[2]]: targets – tensor with dimension (32, 1)

Also, torchaudio comes with three loaders: av_loader, tuner_loader, and audiofile_loader, with more to come. set_audio_backend() is used to set one of them as the audio loader. Their performance differs based on audio format (mp3 or wav). There is no perfect world yet: tuner_loader is best for mp3, audiofile_loader is best for wav, but neither of them can partially load a sample from an audio file without first bringing all the data into memory.

For a given audio backend, we need to pass it to each worker through the worker_init_fn() argument.

ds_train <- dataloader(
  train_subset, 
  batch_size = 128, 
  shuffle = TRUE, 
  collate_fn = collate_fn,
  num_workers = 16,
  worker_init_fn = function(.) {torchaudio::set_audio_backend("audiofile_loader")},
  worker_globals = c("pad_sequence") # pad_sequence is needed for collate_fn
)

ds_test <- dataloader(
  test_subset, 
  batch_size = 64, 
  shuffle = FALSE, 
  collate_fn = collate_fn,
  num_workers = 8,
  worker_globals = c("pad_sequence") # pad_sequence is needed for collate_fn
)

Model definition

Instead of keras::keras_model_sequential(), we are going to define a torch::nn_module(). As referenced by the original article, the model is based on this architecture for MNIST from this tutorial, and I’ll call it ‘DanielNN’.

dan_nn <- torch::nn_module(
  "DanielNN",
  
  initialize = function(
    window_size_ms = 30, 
    window_stride_ms = 10
  ) {
    
    # spectrogram spec
    window_size <- as.integer(16000*window_size_ms/1000)
    stride <- as.integer(16000*window_stride_ms/1000)
    fft_size <- as.integer(2^trunc(log(window_size, 2) + 1))
    n_chunks <- length(seq(0, 16000, stride))
    
    self$spectrogram <- torchaudio::transform_spectrogram(
      n_fft = fft_size, 
      win_length = window_size, 
      hop_length = stride, 
      normalized = TRUE, 
      power = 2
    )
    
    # convs 2D
    self$conv1 <- torch::nn_conv2d(in_channels = 1, out_channels = 32, kernel_size = c(3,3))
    self$conv2 <- torch::nn_conv2d(in_channels = 32, out_channels = 64, kernel_size = c(3,3))
    self$conv3 <- torch::nn_conv2d(in_channels = 64, out_channels = 128, kernel_size = c(3,3))
    self$conv4 <- torch::nn_conv2d(in_channels = 128, out_channels = 256, kernel_size = c(3,3))
    
    # denses
    self$dense1 <- torch::nn_linear(in_features = 14336, out_features = 128)
    self$dense2 <- torch::nn_linear(in_features = 128, out_features = 30)
  },
  
  forward = function(x) {
    x %>% # (64, 1, 16001)
      self$spectrogram() %>% # (64, 1, 257, 101)
      torch::torch_add(0.01) %>%
      torch::torch_log() %>%
      self$conv1() %>%
      torch::nnf_relu() %>%
      torch::nnf_max_pool2d(kernel_size = c(2,2)) %>%
      
      self$conv2() %>%
      torch::nnf_relu() %>%
      torch::nnf_max_pool2d(kernel_size = c(2,2)) %>%
      
      self$conv3() %>%
      torch::nnf_relu() %>%
      torch::nnf_max_pool2d(kernel_size = c(2,2)) %>%
      
      self$conv4() %>%
      torch::nnf_relu() %>%
      torch::nnf_max_pool2d(kernel_size = c(2,2)) %>%
      
      torch::nnf_dropout(p = 0.25) %>%
      torch::torch_flatten(start_dim = 2) %>%
      
      self$dense1() %>%
      torch::nnf_relu() %>%
      torch::nnf_dropout(p = 0.5) %>%
      self$dense2() 
  }
)

model <- dan_nn()


device <- torch::torch_device(if(torch::cuda_is_available()) "cuda" else "cpu")
model$to(device = device)

print(model)
An `nn_module` containing 2,226,846 parameters.

── Modules ──────────────────────────────────────────────────────
● spectrogram: <Spectrogram> #0 parameters
● conv1: <nn_conv2d> #320 parameters
● conv2: <nn_conv2d> #18,496 parameters
● conv3: <nn_conv2d> #73,856 parameters
● conv4: <nn_conv2d> #295,168 parameters
● dense1: <nn_linear> #1,835,136 parameters
● dense2: <nn_linear> #3,870 parameters

Model fitting

Unlike in tensorflow, there is no model %>% compile(…) step in torch, so we are going to set loss criterion, optimizer strategy and evaluation metrics explicitly in the training loop.

loss_criterion <- torch::nn_cross_entropy_loss()
optimizer <- torch::optim_adadelta(model$parameters, rho = 0.95, eps = 1e-7)
metrics <- list(acc = yardstick::accuracy_vec)

Training loop

library(glue)
library(progress)

pred_to_r <- function(x) {
  classes <- factor(df$classes)
  classes[as.numeric(x$to(device = "cpu"))]
}

set_progress_bar <- function(total) {
  progress_bar$new(
    total = total, clear = FALSE, width = 70,
    format = ":current/:total [:bar] - :elapsed - loss: :loss - acc: :acc"
  )
}
epochs <- 20
losses <- c()
accs <- c()

for(epoch in seq_len(epochs)) {
  pb <- set_progress_bar(length(ds_train))
  pb$message(glue("Epoch {epoch}/{epochs}"))
  coro::loop(for(batch in ds_train) {
    optimizer$zero_grad()
    predictions <- model(batch[[1]]$to(device = device))
    targets <- batch[[2]]$to(device = device)
    loss <- loss_criterion(predictions, targets)
    loss$backward()
    optimizer$step()
    
    # eval reports
    prediction_r <- pred_to_r(predictions$argmax(dim = 2))
    targets_r <- pred_to_r(targets)
    acc <- metrics$acc(targets_r, prediction_r)
    accs <- c(accs, acc)
    loss_r <- as.numeric(loss$item())
    losses <- c(losses, loss_r)
    
    pb$tick(tokens = list(loss = round(mean(losses), 4), acc = round(mean(accs), 4)))
  })
}



# test
predictions_r <- c()
targets_r <- c()
coro::loop(for(batch_test in ds_test) {
  predictions <- model(batch_test[[1]]$to(device = device))
  targets <- batch_test[[2]]$to(device = device)
  predictions_r <- c(predictions_r, pred_to_r(predictions$argmax(dim = 2)))
  targets_r <- c(targets_r, pred_to_r(targets))
})
val_acc <- metrics$acc(factor(targets_r, levels = 1:30), factor(predictions_r, levels = 1:30))
cat(glue("val_acc: {val_acc}\n\n"))
Epoch 1/20                                                            
[W SpectralOps.cpp:590] Warning: The function torch.rfft is deprecated and will be removed in a future PyTorch release. Use the new torch.fft module functions, instead, by importing torch.fft and calling torch.fft.fft or torch.fft.rfft. (function operator())
354/354 [=========================] -  1m - loss: 2.6102 - acc: 0.2333
Epoch 2/20                                                            
354/354 [=========================] -  1m - loss: 1.9779 - acc: 0.4138
Epoch 3/20                                                            
354/354 [============================] -  1m - loss: 1.62 - acc: 0.519
Epoch 4/20                                                            
354/354 [=========================] -  1m - loss: 1.3926 - acc: 0.5859
Epoch 5/20                                                            
354/354 [==========================] -  1m - loss: 1.2334 - acc: 0.633
Epoch 6/20                                                            
354/354 [=========================] -  1m - loss: 1.1135 - acc: 0.6685
Epoch 7/20                                                            
354/354 [=========================] -  1m - loss: 1.0199 - acc: 0.6961
Epoch 8/20                                                            
354/354 [=========================] -  1m - loss: 0.9444 - acc: 0.7181
Epoch 9/20                                                            
354/354 [=========================] -  1m - loss: 0.8816 - acc: 0.7365
Epoch 10/20                                                           
354/354 [=========================] -  1m - loss: 0.8278 - acc: 0.7524
Epoch 11/20                                                           
354/354 [=========================] -  1m - loss: 0.7818 - acc: 0.7659
Epoch 12/20                                                           
354/354 [=========================] -  1m - loss: 0.7413 - acc: 0.7778
Epoch 13/20                                                           
354/354 [=========================] -  1m - loss: 0.7064 - acc: 0.7881
Epoch 14/20                                                           
354/354 [=========================] -  1m - loss: 0.6751 - acc: 0.7974
Epoch 15/20                                                           
354/354 [=========================] -  1m - loss: 0.6469 - acc: 0.8058
Epoch 16/20                                                           
354/354 [=========================] -  1m - loss: 0.6216 - acc: 0.8133
Epoch 17/20                                                           
354/354 [=========================] -  1m - loss: 0.5985 - acc: 0.8202
Epoch 18/20                                                           
354/354 [=========================] -  1m - loss: 0.5774 - acc: 0.8263
Epoch 19/20                                                           
354/354 [==========================] -  1m - loss: 0.5582 - acc: 0.832
Epoch 20/20                                                           
354/354 [=========================] -  1m - loss: 0.5403 - acc: 0.8374
val_acc: 0.876705979296493

Making predictions

We already have all predictions calculated for test_subset; let’s recreate the alluvial plot from the original article.

library(dplyr)
library(alluvial)
df_validation <- data.frame(
  pred_class = df$classes[predictions_r],
  class = df$classes[targets_r]
)
x <-  df_validation %>%
  mutate(correct = pred_class == class) %>%
  count(pred_class, class, correct)

alluvial(
  x %>% select(class, pred_class),
  freq = x$n,
  col = ifelse(x$correct, "lightblue", "red"),
  border = ifelse(x$correct, "lightblue", "red"),
  alpha = 0.6,
  hide = x$n < 20
)
Model performance: true labels <--> predicted labels.

Model accuracy is 87.7%, somewhat worse than the tensorflow version from the original post. Nevertheless, all conclusions from the original post still hold.

Categories
Offsites

Evaluating Design Trade-offs in Visual Model-Based Reinforcement Learning

Model-free reinforcement learning has been successfully demonstrated across a range of domains, including robotics, control, playing games and autonomous vehicles. These systems learn by simple trial and error and thus require a vast number of attempts at a given task before solving it. In contrast, model-based reinforcement learning (MBRL) learns a model of the environment (often referred to as a world model or a dynamics model) that enables the agent to predict the outcomes of potential actions, which reduces the amount of environment interaction needed to solve a task.

In principle, all that is strictly necessary for planning is to predict future rewards, which could then be used to select near-optimal future actions. Nevertheless, many recent methods, such as Dreamer, PlaNet, and SimPLe, additionally leverage the training signal of predicting future images. But is predicting future images actually necessary, or helpful? What benefit do visual MBRL algorithms actually derive from also predicting future images? The computational and representational cost of predicting entire images is considerable, so understanding whether this is actually useful is of profound importance for MBRL research.

In “Models, Pixels, and Rewards: Evaluating Design Trade-offs in Visual Model-Based Reinforcement Learning”, we demonstrate that predicting future images provides a substantial benefit, and is in fact a key ingredient in training successful visual MBRL agents. We developed a new open-source library, called the World Models Library, which enabled us to rigorously evaluate various world model designs to determine the relative impact of image prediction on returned rewards for each.

World Models Library
The World Models Library, designed specifically for visual MBRL training and evaluation, enables the empirical study of the effects of each design decision on the final performance of an agent across multiple tasks on a large scale. The library introduces a platform-agnostic visual MBRL simulation loop and the APIs to seamlessly define new world-models, planners and tasks or to pick and choose from the existing catalog, which includes agents (e.g., PlaNet), video models (e.g., SV2P), and a variety of DeepMind Control tasks and planners, such as CEM and MPPI.

Using the library, developers can study the effect of a varying factor in MBRL, such as the model design or representation space, on the performance of the agent on a suite of tasks. The library supports the training of the agents from scratch, or on a pre-collected set of trajectories, as well as evaluation of a pre-trained agent on a given task. The models, planning algorithms and the tasks can be easily mixed and matched to any desired combination.

To provide the greatest flexibility for users, the library is built using the NumPy interface, which enables different components to be implemented in either TensorFlow, PyTorch or JAX. Please look at this colab for a quick introduction.

Impact of Image Prediction
Using the World Models Library, we trained multiple world models with different levels of image prediction. All of these models use the same input (previously observed images) to predict an image and a reward, but they differ on what percentage of the image they predict. As the number of image pixels predicted by the agent increases, the agent performance as measured by the true reward generally improves.

The input to the model is fixed (previous observed images), but the fraction of the image predicted varies. As can be seen in the graph on the right, increasing the number of predicted pixels significantly improves the performance of the model.

Interestingly, the correlation between reward prediction accuracy and agent performance is not as strong, and in some cases a more accurate reward prediction can even result in lower agent performance. At the same time, there is a strong correlation between image reconstruction error and the performance of the agent.

Correlation between accuracy of image/reward prediction (x-axis) and task performance (y-axis). This graph clearly demonstrates a stronger correlation between image prediction accuracy and task performance.

This phenomenon is directly related to exploration, i.e., when the agent attempts more risky and potentially less rewarding actions in order to collect more information about the unknown options in the environment. This can be shown by testing and comparing models in an offline setup (i.e., learning policies from pre-collected datasets, as opposed to online RL, which learns policies by interacting with an environment). An offline setup ensures that there is no exploration and all of the models are trained on the same data. We observed that models that fit the data better usually perform better in the offline setup, and surprisingly, these may not be the same models that perform the best when learning and exploring from scratch.

Scores achieved by different visual MBRL models across different tasks. The top and bottom halves of the graph visualize the scores achieved when training in the online and offline settings for each task, respectively. Each color is a different model. It is common for a poorly-performing model in the online setting to achieve high scores when trained on pre-collected data (the offline setting) and vice versa.

Conclusion
We have empirically demonstrated that predicting images can substantially improve task performance over models that only predict the expected reward. We have also shown that the accuracy of image prediction strongly correlates with the final task performance of these models. These findings can be used for better model design and can be particularly useful for any future setting where the input space is high-dimensional and collecting data is expensive.

If you’d like to develop your own models and experiments, head to our repository and colab where you’ll find instructions on how to reproduce this work and use or extend the World Models Library.

Acknowledgement:
We would like to give special recognition to multiple researchers in the Google Brain team and co-authors of the paper: Mohammad Taghi Saffar, Danijar Hafner, Harini Kannan, Chelsea Finn and Sergey Levine.

Categories
Offsites

Forecasting El Niño-Southern Oscillation (ENSO)

Today, we use the convLSTM introduced in a previous post to predict El Niño-Southern Oscillation (ENSO).

El Niño, la Niña

ENSO refers to a changing pattern of sea surface temperatures and sea-level pressures occurring in the equatorial Pacific. Of its three overall states, the best known is probably El Niño. El Niño occurs when surface water temperatures in the eastern Pacific are higher than normal, and the strong winds that normally blow from east to west are unusually weak. The opposite conditions are termed La Niña. Everything in between is classified as normal.

ENSO has great impact on the weather worldwide, and routinely harms ecosystems and societies through storms, droughts and flooding, possibly resulting in famines and economic crises. The best societies can do is try to adapt and mitigate severe consequences. Such efforts are aided by accurate forecasts, the further ahead the better.

Here, deep learning (DL) can potentially help: Variables like sea surface temperatures and pressures are given on a spatial grid – that of the earth – and as we know, DL is good at extracting spatial (e.g., image) features. For ENSO prediction, architectures like convolutional neural networks (Ham, Kim, and Luo (2019)) or convolutional-recurrent hybrids1 are habitually used. One such hybrid is just our convLSTM; it operates on sequences of features given on a spatial grid. Today, thus, we’ll be training a model for ENSO forecasting. This model will have a convLSTM for its central ingredient.

Before we start, a note. While our model fits well with architectures described in the relevant papers, the same cannot be said for the amount of training data used. For reasons of practicality, we use actual observations only; consequently, we end up with a small (relative to the task) dataset. In contrast, research papers tend to make use of climate simulations2, resulting in significantly more data to work with.

From the outset, then, we don’t expect stellar performance. Nevertheless, this should make for an interesting case study, and a useful code template for our readers to apply to their own data.

Data

We will attempt to predict monthly average sea surface temperature in the Niño 3.4 region3, as represented by the Niño 3.4 Index, plus categorization as one of El Niño, La Niña or neutral4. Predictions will be based on prior monthly sea surface temperatures spanning a large portion of the globe.

On the input side, public and ready-to-use data may be downloaded from Tokyo Climate Center; as to prediction targets, we obtain index and classification here.

Input and target data both are provided monthly. They intersect in the time period ranging from 1891-01-01 to 2020-08-01; so this is the range of dates we’ll be zooming in on.

Input: Sea Surface Temperatures

Monthly sea surface temperatures are provided in a latitude-longitude grid of resolution 1°. Details of how the data were processed are available here.

Data files are available in GRIB format; each file contains averages computed for a single month. We can either download individual files or generate a text file of URLs for download. In case you’d like to follow along with the post, you’ll find the contents of the text file I generated in the appendix. Once you’ve saved these URLs to a file, you can have R get the files for you like so:

purrr::walk(
   readLines("files"),
   function(f) download.file(url = f, destfile = basename(f))
)

From R, we can read GRIB files using stars. For example:

# let's just quickly load all libraries we require to start with

library(torch)
library(tidyverse)
library(stars)
library(viridis)
library(ggthemes)

torch_manual_seed(777)

read_stars(file.path(grb_dir, "sst189101.grb"))
stars object with 2 dimensions and 1 attribute
attribute(s):
 sst189101.grb   
 Min.   :-274.9  
 1st Qu.:-272.8  
 Median :-259.1  
 Mean   :-260.0  
 3rd Qu.:-248.4  
 Max.   :-242.8  
 NA's   :21001   
dimension(s):
  from  to offset delta                       refsys point values    
x    1 360      0     1 Coordinate System importe...    NA   NULL [x]
y    1 180     90    -1 Coordinate System importe...    NA   NULL [y]

So in this GRIB file, we have one attribute – which we know to be sea surface temperature – on a two-dimensional grid. As to the latter, we can complement what stars tells us with additional info found in the documentation:

The east-west grid points run eastward from 0.5ºE to 0.5ºW, while the north-south grid points run northward from 89.5ºS to 89.5ºN.

We note a few things we’ll want to do with this data. For one, the temperatures seem to be given in Kelvin, but with minus signs.5 We’ll remove the minus signs and convert to degrees Celsius for convenience. We’ll also have to think about what to do with the NAs that appear for all non-maritime coordinates.

Before we get there though, we need to combine data from all files into a single data frame. This adds an additional dimension, time, ranging from 1891/01/01 to 2020/12/01:

grb <- read_stars(
  file.path(grb_dir, map(readLines("files", warn = FALSE), basename)), along = "time") %>%
  st_set_dimensions(3,
                    values = seq(as.Date("1891-01-01"), as.Date("2020-12-01"), by = "months"),
                    names = "time"
                    )

grb
stars object with 3 dimensions and 1 attribute
attribute(s), summary of first 1e+05 cells:
 sst189101.grb   
 Min.   :-274.9  
 1st Qu.:-273.3  
 Median :-258.8  
 Mean   :-260.0  
 3rd Qu.:-247.8  
 Max.   :-242.8  
 NA's   :33724   
dimension(s):
     from   to offset delta                       refsys point                    values    
x       1  360      0     1 Coordinate System importe...    NA                      NULL [x]
y       1  180     90    -1 Coordinate System importe...    NA                      NULL [y]
time    1 1560     NA    NA                         Date    NA 1891-01-01,...,2020-12-01    

Let’s visually inspect the spatial distribution of monthly temperatures for one year, 2020:

ggplot() +
  geom_stars(data = grb %>% filter(between(time, as.Date("2020-01-01"), as.Date("2020-12-01"))), alpha = 0.8) +
  facet_wrap("time") +
  scale_fill_viridis() +
  coord_equal() +
  theme_map() +
  theme(legend.position = "none") 
Monthly sea surface temperatures, 2020/01/01 – 2020/12/01.

Target: Niño 3.4 Index

For the Niño 3.4 Index, we download the monthly data and, among the provided features, zoom in on two: the index itself (column NINO34_MEAN) and PHASE, which can be E (El Niño), L (La Niña) or N (neutral).

nino <- read_table2("ONI_NINO34_1854-2020.txt", skip = 9) %>%
  mutate(month = as.Date(paste0(YEAR, "-", `MON/MMM`, "-01"))) %>%
  select(month, NINO34_MEAN, PHASE) %>%
  filter(between(month, as.Date("1891-01-01"), as.Date("2020-08-01"))) %>%
  mutate(phase_code = as.numeric(as.factor(PHASE)))

nrow(nino)
1556

Next, we look at how to get the data into a format convenient for training and prediction.

Preprocessing Input

First, we remove all input data for points in time where ground truth data are still missing.

sst <- grb %>% filter(time <= as.Date("2020-08-01"))

Next, as is done by e.g. Ham, Kim, and Luo (2019), we only use grid points between 55° south and 60° north. This has the additional advantage of reducing memory requirements.

sst <- sst %>% filter(between(y, -55, 60))

dim(sst)
360, 115, 1556

As already alluded to, with the little data we have we can’t expect much in terms of generalization. Still, we set aside a small portion of the data for validation, since we’d like for this post to serve as a useful template to be used with bigger datasets.

sst_train <- sst %>% filter(time < as.Date("1990-01-01"))
sst_valid <- sst %>% filter(time >= as.Date("1990-01-01"))

From here on, we work with R arrays.

sst_train <- as.tbl_cube.stars(sst_train)$mets[[1]]
sst_valid <- as.tbl_cube.stars(sst_valid)$mets[[1]]

Conversion to degrees Celsius is not strictly necessary: initial experiments showed a slight performance increase due to normalizing the input, and we are going to do that anyway. Still, Celsius reads nicer to humans than Kelvin.

sst_train <- sst_train + 273.15
quantile(sst_train, na.rm = TRUE)
     0%     25%     50%     75%    100% 
-1.8000 12.9975 21.8775 26.8200 34.3700 

Not at all surprisingly, global warming is evident from inspecting temperature distribution on the validation set (which was chosen to span the last thirty-one years).

sst_valid <- sst_valid + 273.15
quantile(sst_valid, na.rm = TRUE)
    0%    25%    50%    75%   100% 
-1.800 13.425 22.335 27.240 34.870 

The next-to-last step normalizes both sets according to training mean and variance.

train_mean <- mean(sst_train, na.rm = TRUE)
train_sd <- sd(sst_train, na.rm = TRUE)

sst_train <- (sst_train - train_mean) / train_sd

sst_valid <- (sst_valid - train_mean) / train_sd

Finally, what should we do about the NA entries? We set them to zero, the (training set) mean. That may not be enough of an action though: It means we’re feeding the network roughly 30% misleading data. This is something we’re not done with yet.

sst_train[is.na(sst_train)] <- 0
sst_valid[is.na(sst_valid)] <- 0

Target

The target data are split analogously. Let’s check though: are the phases (categorizations) distributed similarly in both sets?

nino_train <- nino %>% filter(month < as.Date("1990-01-01"))
nino_valid <- nino %>% filter(month >= as.Date("1990-01-01"))

nino_train %>% group_by(phase_code, PHASE) %>% summarise(count = n(), avg = mean(NINO34_MEAN))
# A tibble: 3 x 4
# Groups:   phase_code [3]
  phase_code PHASE count   avg
       <dbl> <chr> <int> <dbl>
1          1 E       301  27.7
2          2 L       333  25.6
3          3 N       554  26.7
nino_valid %>% group_by(phase_code, PHASE) %>% summarise(count = n(), avg = mean(NINO34_MEAN))
# A tibble: 3 x 4
# Groups:   phase_code [3]
  phase_code PHASE count   avg
       <dbl> <chr> <int> <dbl>
1          1 E        93  28.1
2          2 L        93  25.9
3          3 N       182  27.2

This doesn’t look too bad. Of course, we again see the overall rise in temperature, irrespective of phase.

Lastly, we normalize the index, same as we did for the input data.

train_mean_nino <- mean(nino_train$NINO34_MEAN)
train_sd_nino <- sd(nino_train$NINO34_MEAN)

nino_train <- nino_train %>% mutate(NINO34_MEAN = scale(NINO34_MEAN, center = train_mean_nino, scale = train_sd_nino))
nino_valid <- nino_valid %>% mutate(NINO34_MEAN = scale(NINO34_MEAN, center = train_mean_nino, scale = train_sd_nino))

On to the torch dataset.

Torch dataset

The dataset is responsible for correctly matching up inputs and targets.

Our goal is to take six months of global sea surface temperatures and predict the Niño 3.4 Index for the following month. Input-wise, the model will expect the following format semantics:

batch_size * timesteps * width * height * channels, where

  • batch_size is the number of observations worked on in one round of computations,

  • timesteps chains consecutive observations from adjacent months,

  • width and height together constitute the spatial grid, and

  • channels corresponds to available visual channels in the “image”.

In .getitem(), we select the consecutive observations, starting at a given index, and stack them in dimension one. (One, not two, as batches will only start to exist once the dataloader comes into play.)

Now, what about the target? Our ultimate goal was – is – predicting the Niño 3.4 Index. However, as you will see, we define three targets: one is the index, as expected; an additional one holds the spatially gridded sea surface temperatures for the prediction month. Why? Our main instrument, the most prominent constituent of the model, will be a convLSTM, an architecture designed for spatial prediction. Thus, to train it efficiently, we want to give it the opportunity to predict values on a spatial grid. So far so good; but there is one more target, the phase/category. This was added for experimentation purposes: maybe predicting both index and phase helps in training?

Finally, here is the code for the dataset. In our experiments, we based predictions on inputs from the preceding six months (n_timesteps <- 6). This is a parameter you might want to play with, though.

n_timesteps <- 6

enso_dataset <- dataset(
  name = "enso_dataset",
  
  initialize = function(sst, nino, n_timesteps) {
    self$sst <- sst
    self$nino <- nino
    self$n_timesteps <- n_timesteps
  },
  
  .getitem = function(i) {
    x <- torch_tensor(self$sst[, , i:(self$n_timesteps + i - 1)]) # (360, 115, n_timesteps)
    x <- x$permute(c(3, 1, 2))$unsqueeze(2) # (n_timesteps, 1, 360, 115)
    
    y1 <- torch_tensor(self$sst[, , self$n_timesteps + i])$unsqueeze(1) # (1, 360, 115)
    y2 <- torch_tensor(self$nino$NINO34_MEAN[self$n_timesteps + i])
    y3 <- torch_tensor(self$nino$phase_code[self$n_timesteps + i])$squeeze()$to(torch_long())
    list(x = x, y1 = y1, y2 = y2, y3 = y3)
  },
  
  .length = function() {
    nrow(self$nino) - self$n_timesteps
  }
  
)

train_ds <- enso_dataset(sst_train, nino_train, n_timesteps)
valid_ds <- enso_dataset(sst_valid, nino_valid, n_timesteps)

Dataloaders

After the custom dataset, we create the – pretty typical – dataloaders, making use of a batch size of 4.

batch_size <- 4

train_dl <- train_ds %>% dataloader(batch_size = batch_size, shuffle = TRUE)

valid_dl <- valid_ds %>% dataloader(batch_size = batch_size)

Next, we proceed to model creation.

Model

The model’s main ingredient is the convLSTM introduced in a prior post. For convenience, we reproduce the code in the appendix.

Besides the convLSTM, the model makes use of three convolutional layers, a batchnorm layer and five linear layers. The logic is the following.

First, the convLSTM’s job is to predict the next month’s sea surface temperatures on the spatial grid. For that, we almost just return its final state – almost: we use self$conv1 to reduce the number of channels to one.

For predicting index and phase, we then need to flatten the grid, as we require a single value each. This is where the additional conv layers come in. We do hope they’ll aid in learning, but we also want to reduce the number of parameters a bit, downsizing the grid (stride = 2 and stride = 3, respectively) before the upcoming torch_flatten().

Once we have a flat structure, learning is shared between the tasks of index and phase prediction (self$linear), until finally their paths split (self$cont and self$cat, resp.), and they return their separate outputs.

(The batchnorm? I’ll comment on that in the Discussion.)

model <- nn_module(
  
  initialize = function(channels_in,
                        convlstm_hidden,
                        convlstm_kernel,
                        convlstm_layers) {
    
    self$n_layers <- convlstm_layers
    
    self$convlstm <- convlstm(
      input_dim = channels_in,
      hidden_dims = convlstm_hidden,
      kernel_sizes = convlstm_kernel,
      n_layers = convlstm_layers
    )
    
    self$conv1 <-
      nn_conv2d(
        in_channels = 32,
        out_channels = 1,
        kernel_size = 5,
        padding = 2
      )
    
    self$conv2 <-
      nn_conv2d(
        in_channels = 32,
        out_channels = 32,
        kernel_size = 5,
        stride = 2
      )
    
    self$conv3 <-
      nn_conv2d(
        in_channels = 32,
        out_channels = 32,
        kernel_size = 5,
        stride = 3
      )
    
    self$linear <- nn_linear(33408, 64)
    
    self$b1 <- nn_batch_norm1d(num_features = 64)
        
    self$cont <- nn_linear(64, 128)
    self$cat <- nn_linear(64, 128)
    
    self$cont_output <- nn_linear(128, 1)
    self$cat_output <- nn_linear(128, 3)
    
  },
  
  forward = function(x) {
    
    ret <- self$convlstm(x)
    layer_last_states <- ret[[2]]
    last_hidden <- layer_last_states[[self$n_layers]][[1]]
    
    next_sst <- last_hidden %>% self$conv1() 
    
    c2 <- last_hidden %>% self$conv2() 
    c3 <- c2 %>% self$conv3() 
    
    flat <- torch_flatten(c3, start_dim = 2)
    common <- self$linear(flat) %>% self$b1() %>% nnf_relu()

    next_temp <- common %>% self$cont() %>% nnf_relu() %>% self$cont_output()
    next_nino <- common %>% self$cat() %>% nnf_relu() %>% self$cat_output()
    
    list(next_sst, next_temp, next_nino)
    
  }
  
)

Next, we instantiate a pretty small-ish model. You’re more than welcome to experiment with larger models, but training time as well as GPU memory requirements will increase.

net <- model(
  channels_in = 1,
  convlstm_hidden = c(16, 16, 32),
  convlstm_kernel = c(3, 3, 5),
  convlstm_layers = 3
)

device <- torch_device(if (cuda_is_available()) "cuda" else "cpu")

net <- net$to(device = device)
net
An `nn_module` containing 2,389,605 parameters.

── Modules ───────────────────────────────────────────────────────────────────────────────
● convlstm: <nn_module> #182,080 parameters
● conv1: <nn_conv2d> #801 parameters
● conv2: <nn_conv2d> #25,632 parameters
● conv3: <nn_conv2d> #25,632 parameters
● linear: <nn_linear> #2,138,176 parameters
● b1: <nn_batch_norm1d> #128 parameters
● cont: <nn_linear> #8,320 parameters
● cat: <nn_linear> #8,320 parameters
● cont_output: <nn_linear> #129 parameters
● cat_output: <nn_linear> #387 parameters

Training

We have three model outputs. How should we combine the losses?

Given that the main goal is predicting the index, and the other two outputs are essentially means to an end, I found the following combination rather effective:

# weight for sea surface temperature prediction
lw_sst <- 0.2

# weight for prediction of El Nino 3.4 Index
lw_temp <- 0.4

# weight for phase prediction
lw_nino <- 0.4

The training process follows the pattern seen in all torch posts so far: For each epoch, loop over the training set, backpropagate, check performance on validation set.

But, when we did the pre-processing, we were aware of a problem we had not yet resolved: the missing temperatures for continental areas, which we set to zero. As a sole measure, this approach is clearly insufficient. What if we had chosen to use latitude-dependent averages? Or interpolation? Both may be better than a global average, but both have their problems as well. Let’s at least alleviate the negative consequences by not using the respective pixels for spatial loss calculation. This is taken care of by the following line, which appears in both the training and validation loss computations:

sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])

Here, then, is the complete training code.

optimizer <- optim_adam(net$parameters, lr = 0.001)

num_epochs <- 50

train_batch <- function(b) {
  
  optimizer$zero_grad()
  output <- net(b$x$to(device = device))
  
  sst_output <- output[[1]]
  sst_target <- b$y1$to(device = device)
  
  sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])
  temp_loss <- nnf_mse_loss(output[[2]], b$y2$to(device = device))
  nino_loss <- nnf_cross_entropy(output[[3]], b$y3$to(device = device))
  
  loss <- lw_sst * sst_loss + lw_temp * temp_loss + lw_nino * nino_loss
  loss$backward()
  optimizer$step()

  list(sst_loss$item(), temp_loss$item(), nino_loss$item(), loss$item())
  
}

valid_batch <- function(b) {
  
  output <- net(b$x$to(device = device))
  
  sst_output <- output[[1]]
  sst_target <- b$y1$to(device = device)
  
  sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])
  temp_loss <- nnf_mse_loss(output[[2]], b$y2$to(device = device))
  nino_loss <- nnf_cross_entropy(output[[3]], b$y3$to(device = device))
  
  loss <-
    lw_sst * sst_loss + lw_temp * temp_loss + lw_nino * nino_loss

  list(sst_loss$item(),
       temp_loss$item(),
       nino_loss$item(),
       loss$item())
}

for (epoch in 1:num_epochs) {
  
  net$train()
  
  train_loss_sst <- c()
  train_loss_temp <- c()
  train_loss_nino <- c()
  train_loss <- c()

  coro::loop(for (b in train_dl) {
    losses <- train_batch(b)
    train_loss_sst <- c(train_loss_sst, losses[[1]])
    train_loss_temp <- c(train_loss_temp, losses[[2]])
    train_loss_nino <- c(train_loss_nino, losses[[3]])
    train_loss <- c(train_loss, losses[[4]])
  })
  
  cat(
    sprintf(
      "\nEpoch %d, training: loss: %3.3f sst: %3.3f temp: %3.3f nino: %3.3f\n",
      epoch, mean(train_loss), mean(train_loss_sst), mean(train_loss_temp), mean(train_loss_nino)
    )
  )
  
  net$eval()
  
  valid_loss_sst <- c()
  valid_loss_temp <- c()
  valid_loss_nino <- c()
  valid_loss <- c()

  coro::loop(for (b in valid_dl) {
    losses <- valid_batch(b)
    valid_loss_sst <- c(valid_loss_sst, losses[[1]])
    valid_loss_temp <- c(valid_loss_temp, losses[[2]])
    valid_loss_nino <- c(valid_loss_nino, losses[[3]])
    valid_loss <- c(valid_loss, losses[[4]])
    
  })
  
  cat(
    sprintf(
      "\nEpoch %d, validation: loss: %3.3f sst: %3.3f temp: %3.3f nino: %3.3f\n",
      epoch, mean(valid_loss), mean(valid_loss_sst), mean(valid_loss_temp), mean(valid_loss_nino)
    )
  )
  
  torch_save(net, paste0(
    "model_", epoch, "_", round(mean(train_loss), 3), "_", round(mean(valid_loss), 3), ".pt"
  ))
  
}

When I ran this, the training loss decreased in a not-too-fast, but continuous, way, while the validation loss kept fluctuating. For reference, total (composite) losses looked like this:

Epoch     Training    Validation
   
   10        0.336         0.633
   20        0.233         0.295
   30        0.135         0.461
   40        0.099         0.903
   50        0.061         0.727

Thinking of the size of the validation set – thirty-one years, or equivalently, 372 data points – those fluctuations may not be all too surprising.

Predictions

Now losses tend to be abstract; let’s see what actually gets predicted. We obtain predictions for index values and phases like so …

net$eval()

pred_index <- c()
pred_phase <- c()

coro::loop(for (b in valid_dl) {

  output <- net(b$x$to(device = device))

  pred_index <- c(pred_index, output[[2]]$to(device = "cpu"))
  pred_phase <- rbind(pred_phase, as.matrix(output[[3]]$to(device = "cpu")))

})

… and combine these with the ground truth, stripping off the first..