
Simple audio classification with torch

This article translates Daniel Falbel’s ‘Simple Audio Classification’ article from tensorflow/keras to torch/torchaudio. The main goal is to introduce torchaudio and illustrate its contributions to the torch ecosystem. Here, we focus on a popular dataset, the audio loader and the spectrogram transformer. An interesting side product is the parallel between torch and tensorflow, showing sometimes the differences, sometimes the similarities between them.


Downloading and Importing

torchaudio has the speechcommand_dataset built in. It filters out background_noise by default and lets us choose between versions v0.01 and v0.02.

# set an existing folder here to cache the dataset
DATASETS_PATH <- "~/datasets/"

# 1.4GB download
df <- speechcommand_dataset(
  root = DATASETS_PATH, 
  url = "speech_commands_v0.01",
  download = TRUE
)

# expect folder: _background_noise_
# [1] "_background_noise_"

# number of audio files
length(df)
# [1] 64721

# a sample
sample <- df[1]

sample$waveform[, 1:10]
0.0001 *
 0.9155  0.3052  1.8311  1.8311 -0.3052  0.3052  2.4414  0.9155 -0.9155 -0.6104
[ CPUFloatType{1,10} ]
sample$sample_rate
# 16000
sample$label
# bed

plot(sample$waveform[1], type = "l", col = "royalblue", main = sample$label)
A sample waveform for a 'bed'.



df$classes
 [1] "bed"    "bird"   "cat"    "dog"    "down"   "eight"  "five"  
 [8] "four"   "go"     "happy"  "house"  "left"   "marvin" "nine"  
[15] "no"     "off"    "on"     "one"    "right"  "seven"  "sheila"
[22] "six"    "stop"   "three"  "tree"   "two"    "up"     "wow"   
[29] "yes"    "zero"  

Generator Dataloader

torch::dataloader has the same task as data_generator defined in the original article. It is responsible for preparing batches – including shuffling, padding, one-hot encoding, etc. – and for taking care of parallelism / device I/O orchestration.

In torch we do this by passing the train/test subset to torch::dataloader and encapsulating all the batch setup logic inside a collate_fn() function.

id_train <- sample(length(df), size = 0.7*length(df))
id_test <- setdiff(seq_len(length(df)), id_train)
# subsets

train_subset <- torch::dataset_subset(df, id_train)
test_subset <- torch::dataset_subset(df, id_test)

At this point, dataloader(train_subset) would not work because the samples are not padded. So we need to build our own collate_fn() with the padding strategy.

I suggest using the following approach when implementing the collate_fn():

  1. begin with collate_fn <- function(batch) browser().
  2. instantiate dataloader with the collate_fn()
  3. create an environment by calling enumerate(dataloader) so you can ask to retrieve a batch from dataloader.
  4. run environment[[1]][[1]]. Now you should be sent inside collate_fn() with access to batch input object.
  5. build the logic.

collate_fn <- function(batch) {
  browser()
}

ds_train <- dataloader(
  train_subset,
  batch_size = 32, 
  shuffle = TRUE, 
  collate_fn = collate_fn
)

ds_train_env <- enumerate(ds_train)
ds_train_env[[1]][[1]]

The final collate_fn() pads the waveform to length 16001 and then stacks everything up together. At this point there are no spectrograms yet. We are going to make the spectrogram transformation a part of the model architecture.

pad_sequence <- function(batch) {
    # Make all tensors in a batch the same length by padding with zeros
    batch <- sapply(batch, function(x) (x$t()))
    batch <- torch::nn_utils_rnn_pad_sequence(batch, batch_first = TRUE, padding_value = 0.)
    return(batch$permute(c(1, 3, 2)))
}

# Final collate_fn
collate_fn <- function(batch) {
 # Input structure:
 # list of 32 lists: list(waveform, sample_rate, label, speaker_id, utterance_number)
 # Transpose it
 batch <- purrr::transpose(batch)
 tensors <- batch$waveform
 targets <- batch$label_index

 # Group the list of tensors into a batched tensor
 tensors <- pad_sequence(tensors)
 # target encoding
 targets <- torch::torch_stack(targets)

 list(tensors = tensors, targets = targets) # (batch_size, 1, 16001)
}

Batch structure is:

  • batch[[1]]: waveforms – tensor with dimension (32, 1, 16001)
  • batch[[2]]: targets – tensor with dimension (32, 1)
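To make the padding idea concrete, here is a minimal base R sketch of the same strategy on plain numeric vectors. It is not the torch implementation used above: pad_to_longest() and the toy waveforms are hypothetical names and data introduced only for illustration.

```r
# Zero-pad a list of numeric vectors to the length of the longest one and
# stack them into a matrix with one row per sample (base R only).
pad_to_longest <- function(waves, padding_value = 0) {
  max_len <- max(lengths(waves))
  padded <- lapply(waves, function(w) {
    c(w, rep(padding_value, max_len - length(w)))
  })
  do.call(rbind, padded)
}

# Three toy "waveforms" of different lengths (hypothetical data)
waves <- list(c(0.1, 0.2, 0.3), c(0.5), c(-0.1, 0.4))
batch <- pad_to_longest(waves)

dim(batch)   # one row per sample, one column per time step
batch[2, ]   # the shortest waveform, padded with zeros on the right
```

The torch version does the same thing on tensors and additionally keeps a channel dimension, which is why the final batch has three dimensions instead of two.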

Also, torchaudio comes with 3 loaders: av_loader, tuner_loader, and audiofile_loader, with more to come. set_audio_backend() is used to set one of them as the audio loader. Their performance differs based on audio format (mp3 or wav). There is no perfect world yet: tuner_loader is best for mp3, audiofile_loader is best for wav, but neither of them has the option of partially loading a sample from an audio file without bringing all the data into memory first.

For a given audio backend, we need to pass it to each worker through the worker_init_fn() argument.

ds_train <- dataloader(
  train_subset,
  batch_size = 128, 
  shuffle = TRUE, 
  collate_fn = collate_fn,
  num_workers = 16,
  worker_init_fn = function(.) {torchaudio::set_audio_backend("audiofile_loader")},
  worker_globals = c("pad_sequence") # pad_sequence is needed for collate_fn
)

ds_test <- dataloader(
  test_subset,
  batch_size = 64, 
  shuffle = FALSE, 
  collate_fn = collate_fn,
  num_workers = 8,
  worker_globals = c("pad_sequence") # pad_sequence is needed for collate_fn
)

Model definition

Instead of keras::keras_model_sequential(), we are going to define a torch::nn_module(). As referenced by the original article, the model is based on this architecture for MNIST from this tutorial, and I’ll call it ‘DanielNN’.

dan_nn <- torch::nn_module(
  initialize = function(
    window_size_ms = 30, 
    window_stride_ms = 10
  ) {
    # spectrogram spec
    window_size <- as.integer(16000*window_size_ms/1000)
    stride <- as.integer(16000*window_stride_ms/1000)
    fft_size <- as.integer(2^trunc(log(window_size, 2) + 1))
    n_chunks <- length(seq(0, 16000, stride))
    self$spectrogram <- torchaudio::transform_spectrogram(
      n_fft = fft_size, 
      win_length = window_size, 
      hop_length = stride, 
      normalized = TRUE, 
      power = 2
    )
    # convs 2D
    self$conv1 <- torch::nn_conv2d(in_channels = 1, out_channels = 32, kernel_size = c(3,3))
    self$conv2 <- torch::nn_conv2d(in_channels = 32, out_channels = 64, kernel_size = c(3,3))
    self$conv3 <- torch::nn_conv2d(in_channels = 64, out_channels = 128, kernel_size = c(3,3))
    self$conv4 <- torch::nn_conv2d(in_channels = 128, out_channels = 256, kernel_size = c(3,3))
    # denses
    self$dense1 <- torch::nn_linear(in_features = 14336, out_features = 128)
    self$dense2 <- torch::nn_linear(in_features = 128, out_features = 30)
  },
  forward = function(x) {
    x %>% # (64, 1, 16001)
      self$spectrogram() %>% # (64, 1, 257, 101)
      torch::torch_add(0.01) %>%
      torch::torch_log() %>%
      self$conv1() %>%
      torch::nnf_relu() %>%
      torch::nnf_max_pool2d(kernel_size = c(2,2)) %>%
      self$conv2() %>%
      torch::nnf_relu() %>%
      torch::nnf_max_pool2d(kernel_size = c(2,2)) %>%
      self$conv3() %>%
      torch::nnf_relu() %>%
      torch::nnf_max_pool2d(kernel_size = c(2,2)) %>%
      self$conv4() %>%
      torch::nnf_relu() %>%
      torch::nnf_max_pool2d(kernel_size = c(2,2)) %>%
      torch::nnf_dropout(p = 0.25) %>%
      torch::torch_flatten(start_dim = 2) %>%
      self$dense1() %>%
      torch::nnf_relu() %>%
      torch::nnf_dropout(p = 0.5) %>%
      self$dense2()
  }
)

model <- dan_nn()

device <- torch::torch_device(if(torch::cuda_is_available()) "cuda" else "cpu")
model$to(device = device)

An `nn_module` containing 2,226,846 parameters.

── Modules ──────────────────────────────────────────────────────
● spectrogram: <Spectrogram> #0 parameters
● conv1: <nn_conv2d> #320 parameters
● conv2: <nn_conv2d> #18,496 parameters
● conv3: <nn_conv2d> #73,856 parameters
● conv4: <nn_conv2d> #295,168 parameters
● dense1: <nn_linear> #1,835,136 parameters
● dense2: <nn_linear> #3,870 parameters
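
The value in_features = 14336 and the parameter counts above can be checked with a little arithmetic. The following base R sketch assumes a one-sided FFT, centered (padded) frames, 3x3 convolutions without padding, and 2x2 max pooling with flooring; conv_block(), conv_params() and dense_params() are helper names introduced here for illustration only.

```r
sample_rate <- 16000
window_size <- as.integer(sample_rate * 30 / 1000)           # 480 samples
stride      <- as.integer(sample_rate * 10 / 1000)           # 160 samples
fft_size    <- as.integer(2^trunc(log(window_size, 2) + 1))  # 512

# Spectrogram size for a padded waveform of length 16001
# (assumes a one-sided FFT and centered frames)
n_freq   <- fft_size %/% 2 + 1      # frequency bins
n_frames <- 16001 %/% stride + 1    # time frames

# One conv block: 3x3 convolution without padding, then 2x2 max pooling
conv_block <- function(n) (n - 2) %/% 2

h <- n_freq
w <- n_frames
for (i in 1:4) {
  h <- conv_block(h)
  w <- conv_block(w)
}

# Flattened size after the last block, which has 256 channels
in_features <- 256 * h * w

# Parameter counts: weights plus biases
conv_params  <- function(c_in, c_out) c_in * c_out * 9 + c_out
dense_params <- function(n_in, n_out) n_in * n_out + n_out

total <- conv_params(1, 32) + conv_params(32, 64) + conv_params(64, 128) +
  conv_params(128, 256) + dense_params(in_features, 128) + dense_params(128, 30)

c(n_freq = n_freq, n_frames = n_frames, in_features = in_features, total = total)
```

Under these assumptions the numbers agree with the shape comments in forward() and with the module summary printed above.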

Model fitting

Unlike in tensorflow, there is no model %>% compile(…) step in torch, so we are going to set loss criterion, optimizer strategy and evaluation metrics explicitly in the training loop.

loss_criterion <- torch::nn_cross_entropy_loss()
optimizer <- torch::optim_adadelta(model$parameters, rho = 0.95, eps = 1e-7)
metrics <- list(acc = yardstick::accuracy_vec)

Training loop


library(glue)
library(progress)

# map class indices (1-based) back to R factor labels
pred_to_r <- function(x) {
  classes <- factor(df$classes)
  classes[as.numeric(x$to(device = "cpu"))]
}

set_progress_bar <- function(total) {
  progress_bar$new(
    total = total, clear = FALSE, width = 70,
    format = ":current/:total [:bar] - :elapsed - loss: :loss - acc: :acc"
  )
}

epochs <- 20
losses <- c()
accs <- c()

for(epoch in seq_len(epochs)) {
  pb <- set_progress_bar(length(ds_train))
  pb$message(glue("Epoch {epoch}/{epochs}"))
  coro::loop(for(batch in ds_train) {
    optimizer$zero_grad()
    predictions <- model(batch[[1]]$to(device = device))
    targets <- batch[[2]]$to(device = device)
    loss <- loss_criterion(predictions, targets)
    # backpropagate and update the weights
    loss$backward()
    optimizer$step()
    # eval reports
    prediction_r <- pred_to_r(predictions$argmax(dim = 2))
    targets_r <- pred_to_r(targets)
    acc <- metrics$acc(targets_r, prediction_r)
    accs <- c(accs, acc)
    loss_r <- as.numeric(loss$item())
    losses <- c(losses, loss_r)
    pb$tick(tokens = list(loss = round(mean(losses), 4), acc = round(mean(accs), 4)))
  })
}

# test
predictions_r <- c()
targets_r <- c()
coro::loop(for(batch_test in ds_test) {
  predictions <- model(batch_test[[1]]$to(device = device))
  targets <- batch_test[[2]]$to(device = device)
  predictions_r <- c(predictions_r, pred_to_r(predictions$argmax(dim = 2)))
  targets_r <- c(targets_r, pred_to_r(targets))
})
val_acc <- metrics$acc(factor(targets_r, levels = 1:30), factor(predictions_r, levels = 1:30))
cat(glue("val_acc: {val_acc}\n\n"))
Epoch 1/20                                                            
[W SpectralOps.cpp:590] Warning: The function torch.rfft is deprecated and will be removed in a future PyTorch release. Use the new torch.fft module functions, instead, by importing torch.fft and calling torch.fft.fft or torch.fft.rfft. (function operator())
354/354 [=========================] -  1m - loss: 2.6102 - acc: 0.2333
Epoch 2/20                                                            
354/354 [=========================] -  1m - loss: 1.9779 - acc: 0.4138
Epoch 3/20                                                            
354/354 [============================] -  1m - loss: 1.62 - acc: 0.519
Epoch 4/20                                                            
354/354 [=========================] -  1m - loss: 1.3926 - acc: 0.5859
Epoch 5/20                                                            
354/354 [==========================] -  1m - loss: 1.2334 - acc: 0.633
Epoch 6/20                                                            
354/354 [=========================] -  1m - loss: 1.1135 - acc: 0.6685
Epoch 7/20                                                            
354/354 [=========================] -  1m - loss: 1.0199 - acc: 0.6961
Epoch 8/20                                                            
354/354 [=========================] -  1m - loss: 0.9444 - acc: 0.7181
Epoch 9/20                                                            
354/354 [=========================] -  1m - loss: 0.8816 - acc: 0.7365
Epoch 10/20                                                           
354/354 [=========================] -  1m - loss: 0.8278 - acc: 0.7524
Epoch 11/20                                                           
354/354 [=========================] -  1m - loss: 0.7818 - acc: 0.7659
Epoch 12/20                                                           
354/354 [=========================] -  1m - loss: 0.7413 - acc: 0.7778
Epoch 13/20                                                           
354/354 [=========================] -  1m - loss: 0.7064 - acc: 0.7881
Epoch 14/20                                                           
354/354 [=========================] -  1m - loss: 0.6751 - acc: 0.7974
Epoch 15/20                                                           
354/354 [=========================] -  1m - loss: 0.6469 - acc: 0.8058
Epoch 16/20                                                           
354/354 [=========================] -  1m - loss: 0.6216 - acc: 0.8133
Epoch 17/20                                                           
354/354 [=========================] -  1m - loss: 0.5985 - acc: 0.8202
Epoch 18/20                                                           
354/354 [=========================] -  1m - loss: 0.5774 - acc: 0.8263
Epoch 19/20                                                           
354/354 [==========================] -  1m - loss: 0.5582 - acc: 0.832
Epoch 20/20                                                           
354/354 [=========================] -  1m - loss: 0.5403 - acc: 0.8374
val_acc: 0.876705979296493

Making predictions

We already have all predictions calculated for test_subset, so let’s recreate the alluvial plot from the original article.

library(dplyr)
library(alluvial)

df_validation <- data.frame(
  pred_class = df$classes[predictions_r],
  class = df$classes[targets_r]
)
x <-  df_validation %>%
  mutate(correct = pred_class == class) %>%
  count(pred_class, class, correct)

alluvial(
  x %>% select(class, pred_class),
  freq = x$n,
  col = ifelse(x$correct, "lightblue", "red"),
  border = ifelse(x$correct, "lightblue", "red"),
  alpha = 0.6,
  hide = x$n < 20
)
Model performance: true labels <--> predicted labels.


Model accuracy is 87.7%, somewhat worse than the tensorflow version from the original post. Nevertheless, all conclusions from the original post still hold.


Evaluating Design Trade-offs in Visual Model-Based Reinforcement Learning

Model-free reinforcement learning has been successfully demonstrated across a range of domains, including robotics, control, playing games and autonomous vehicles. These systems learn by simple trial and error and thus require a vast number of attempts at a given task before solving it. In contrast, model-based reinforcement learning (MBRL) learns a model of the environment (often referred to as a world model or a dynamics model) that enables the agent to predict the outcomes of potential actions, which reduces the amount of environment interaction needed to solve a task.

In principle, all that is strictly necessary for planning is to predict future rewards, which could then be used to select near-optimal future actions. Nevertheless, many recent methods, such as Dreamer, PlaNet, and SimPLe, additionally leverage the training signal of predicting future images. But is predicting future images actually necessary, or helpful? What benefit do visual MBRL algorithms actually derive from also predicting future images? The computational and representational cost of predicting entire images is considerable, so understanding whether this is actually useful is of profound importance for MBRL research.

In “Models, Pixels, and Rewards: Evaluating Design Trade-offs in Visual Model-Based Reinforcement Learning”, we demonstrate that predicting future images provides a substantial benefit, and is in fact a key ingredient in training successful visual MBRL agents. We developed a new open-source library, called the World Models Library, which enabled us to rigorously evaluate various world model designs to determine the relative impact of image prediction on returned rewards for each.

World Models Library
The World Models Library, designed specifically for visual MBRL training and evaluation, enables the empirical study of the effects of each design decision on the final performance of an agent across multiple tasks on a large scale. The library introduces a platform-agnostic visual MBRL simulation loop and the APIs to seamlessly define new world-models, planners and tasks or to pick and choose from the existing catalog, which includes agents (e.g., PlaNet), video models (e.g., SV2P), and a variety of DeepMind Control tasks and planners, such as CEM and MPPI.

Using the library, developers can study the effect of a varying factor in MBRL, such as the model design or representation space, on the performance of the agent on a suite of tasks. The library supports the training of the agents from scratch, or on a pre-collected set of trajectories, as well as evaluation of a pre-trained agent on a given task. The models, planning algorithms and the tasks can be easily mixed and matched to any desired combination.

To provide the greatest flexibility for users, the library is built using the NumPy interface, which enables different components to be implemented in TensorFlow, PyTorch or JAX. Please look at this colab for a quick introduction.

Impact of Image Prediction
Using the World Models Library, we trained multiple world models with different levels of image prediction. All of these models use the same input (previously observed images) to predict an image and a reward, but they differ on what percentage of the image they predict. As the number of image pixels predicted by the agent increases, the agent performance as measured by the true reward generally improves.

The input to the model is fixed (previous observed images), but the fraction of the image predicted varies. As can be seen in the graph on the right, increasing the number of predicted pixels significantly improves the performance of the model.

Interestingly, the correlation between reward prediction accuracy and agent performance is not as strong, and in some cases a more accurate reward prediction can even result in lower agent performance. At the same time, there is a strong correlation between image reconstruction error and the performance of the agent.

Correlation between accuracy of image/reward prediction (x-axis) and task performance (y-axis). This graph clearly demonstrates a stronger correlation between image prediction accuracy and task performance.

This phenomenon is directly related to exploration, i.e., when the agent attempts more risky and potentially less rewarding actions in order to collect more information about the unknown options in the environment. This can be shown by testing and comparing models in an offline setup (i.e., learning policies from pre-collected datasets, as opposed to online RL, which learns policies by interacting with an environment). An offline setup ensures that there is no exploration and all of the models are trained on the same data. We observed that models that fit the data better usually perform better in the offline setup, and surprisingly, these may not be the same models that perform the best when learning and exploring from scratch.

Scores achieved by different visual MBRL models across different tasks. The top and bottom halves of the graph visualize the achieved score when trained in the online and offline settings for each task, respectively. Each color is a different model. It is common for a poorly-performing model in the online setting to achieve high scores when trained on pre-collected data (the offline setting) and vice versa.

We have empirically demonstrated that predicting images can substantially improve task performance over models that only predict the expected reward. We have also shown that the accuracy of image prediction strongly correlates with the final task performance of these models. These findings can be used for better model design and can be particularly useful for any future setting where the input space is high-dimensional and collecting data is expensive.

If you’d like to develop your own models and experiments, head to our repository and colab where you’ll find instructions on how to reproduce this work and use or extend the World Models Library.

We would like to give special recognition to multiple researchers in the Google Brain team and co-authors of the paper: Mohammad Taghi Saffar, Danijar Hafner, Harini Kannan, Chelsea Finn and Sergey Levine.


Forecasting El Niño-Southern Oscillation (ENSO)

Today, we use the convLSTM introduced in a previous post to predict El Niño-Southern Oscillation (ENSO).

El Niño, la Niña

ENSO refers to a changing pattern of sea surface temperatures and sea-level pressures occurring in the equatorial Pacific. Of its three overall states, probably the best-known is El Niño. El Niño occurs when surface water temperatures in the eastern Pacific are higher than normal, and the strong winds that normally blow from east to west are unusually weak. The opposite conditions are termed La Niña. Everything in-between is classified as normal.

ENSO has great impact on the weather worldwide, and routinely harms ecosystems and societies through storms, droughts and flooding, possibly resulting in famines and economic crises. The best societies can do is try to adapt and mitigate severe consequences. Such efforts are aided by accurate forecasts, the further ahead the better.

Here, deep learning (DL) can potentially help: Variables like sea surface temperatures and pressures are given on a spatial grid – that of the earth – and as we know, DL is good at extracting spatial (e.g., image) features. For ENSO prediction, architectures like convolutional neural networks (Ham, Kim, and Luo (2019)) or convolutional-recurrent hybrids are habitually used. One such hybrid is just our convLSTM; it operates on sequences of features given on a spatial grid. Today, thus, we’ll be training a model for ENSO forecasting. This model will have a convLSTM for its central ingredient.

Before we start, a note. While our model fits well with architectures described in the relevant papers, the same cannot be said for the amount of training data used. For reasons of practicality, we use actual observations only; consequently, we end up with a small (relative to the task) dataset. In contrast, research papers tend to make use of climate simulations, resulting in significantly more data to work with.

From the outset, then, we don’t expect stellar performance. Nevertheless, this should make for an interesting case study, and a useful code template for our readers to apply to their own data.


We will attempt to predict monthly average sea surface temperature in the Niño 3.4 region, as represented by the Niño 3.4 Index, plus categorization as one of El Niño, La Niña or neutral. Predictions will be based on prior monthly sea surface temperatures spanning a large portion of the globe.

On the input side, public and ready-to-use data may be downloaded from Tokyo Climate Center; as to prediction targets, we obtain index and classification here.

Input and target data both are provided monthly. They intersect in the time period ranging from 1891-01-01 to 2020-08-01; so this is the range of dates we’ll be zooming in on.

Input: Sea Surface Temperatures

Monthly sea surface temperatures are provided in a latitude-longitude grid of resolution 1°. Details of how the data were processed are available here.

Data files are available in GRIB format; each file contains averages computed for a single month. We can either download individual files or generate a text file of URLs for download. In case you’d like to follow along with the post, you’ll find the contents of the text file I generated in the appendix. Once you’ve saved these URLs to a file, you can have R get the files for you like so:

purrr::walk(
   readLines("files", warn = FALSE),
   function(f) download.file(url = f, destfile = basename(f))
)

From R, we can read GRIB files using stars. For example:

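# let's just quickly load all libraries we require to start with
library(torch)
library(tidyverse)
library(stars)
library(viridis)
library(ggthemes)

# grb_dir is assumed to hold the path to the directory with the downloaded GRIB files
read_stars(file.path(grb_dir, "sst189101.grb"))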
stars object with 2 dimensions and 1 attribute
 Min.   :-274.9  
 1st Qu.:-272.8  
 Median :-259.1  
 Mean   :-260.0  
 3rd Qu.:-248.4  
 Max.   :-242.8  
 NA's   :21001   
  from  to offset delta                       refsys point values    
x    1 360      0     1 Coordinate System importe...    NA   NULL [x]
y    1 180     90    -1 Coordinate System importe...    NA   NULL [y]

So in this GRIB file, we have one attribute – which we know to be sea surface temperature – on a two-dimensional grid. As to the latter, we can complement what stars tells us with additional info found in the documentation:

The east-west grid points run eastward from 0.5ºE to 0.5ºW, while the north-south grid points run northward from 89.5ºS to 89.5ºN.

We note a few things we’ll want to do with this data. For one, the temperatures seem to be given in Kelvin, but with minus signs. We’ll remove the minus signs and convert to degrees Celsius for convenience. We’ll also have to think about what to do with the NAs that appear for all non-maritime coordinates.

Before we get there though, we need to combine data from all files into a single object. This adds an additional dimension, time, ranging from 1891/01/01 to 2020/12/01:

grb <- read_stars(
  file.path(grb_dir, map(readLines("files", warn = FALSE), basename)), along = "time") %>%
  st_set_dimensions(3,
                    values = seq(as.Date("1891-01-01"), as.Date("2020-12-01"), by = "months"),
                    names = "time"
  )

grb

stars object with 3 dimensions and 1 attribute
attribute(s), summary of first 1e+05 cells:
 Min.   :-274.9  
 1st Qu.:-273.3  
 Median :-258.8  
 Mean   :-260.0  
 3rd Qu.:-247.8  
 Max.   :-242.8  
 NA's   :33724   
     from   to offset delta                       refsys point                    values    
x       1  360      0     1 Coordinate System importe...    NA                      NULL [x]
y       1  180     90    -1 Coordinate System importe...    NA                      NULL [y]
time    1 1560     NA    NA                         Date    NA 1891-01-01,...,2020-12-01    

Let’s visually inspect the spatial distribution of monthly temperatures for one year, 2020:

ggplot() +
  geom_stars(data = grb %>% filter(between(time, as.Date("2020-01-01"), as.Date("2020-12-01"))), alpha = 0.8) +
  facet_wrap("time") +
  scale_fill_viridis() +
  coord_equal() +
  theme_map() +
  theme(legend.position = "none") 
Monthly sea surface temperatures, 2020/01/01 - 2020/12/01.

Target: Niño 3.4 Index

For the Niño 3.4 Index, we download the monthly data and, among the provided features, zoom in on two: the index itself (column NINO34_MEAN) and PHASE, which can be E (El Niño), L (La Niña) or N (neutral).

nino <- read_table2("ONI_NINO34_1854-2020.txt", skip = 9) %>%
  mutate(month = as.Date(paste0(YEAR, "-", `MON/MMM`, "-01"))) %>%
  select(month, NINO34_MEAN, PHASE) %>%
  filter(between(month, as.Date("1891-01-01"), as.Date("2020-08-01"))) %>%
  mutate(phase_code = as.numeric(as.factor(PHASE)))
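
As a quick illustration of how phase_code comes about: as.factor() sorts the distinct labels, so the codes follow alphabetical order of E, L and N. The label vector below is hypothetical toy data, not taken from the downloaded file.

```r
# Hypothetical phase labels for five consecutive months
phase <- c("N", "E", "L", "E", "N")

# Factor levels are sorted, so E -> 1, L -> 2, N -> 3
phase_levels <- levels(as.factor(phase))
phase_code <- as.numeric(as.factor(phase))

phase_levels
phase_code
```

This matches the pairing of phase_code and PHASE that shows up in the grouped summaries further down.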


Next, we look at how to get the data into a format convenient for training and prediction.

Preprocessing Input

First, we remove all input data for points in time where ground truth data are still missing.

sst <- grb %>% filter(time <= as.Date("2020-08-01"))

Next, as is done by e.g. Ham, Kim, and Luo (2019), we only use grid points between 55° south and 60° north. This has the additional advantage of reducing memory requirements.

sst <- sst %>% filter(between(y,-55, 60))

dim(sst)
360, 115, 1556

As already alluded to, with the little data we have we can’t expect much in terms of generalization. Still, we set aside a small portion of the data for validation, since we’d like for this post to serve as a useful template to be used with bigger datasets.

sst_train <- sst %>% filter(time < as.Date("1990-01-01"))
sst_valid <- sst %>% filter(time >= as.Date("1990-01-01"))

From here on, we work with R arrays.

sst_train <- as.tbl_cube.stars(sst_train)$mets[[1]]
sst_valid <- as.tbl_cube.stars(sst_valid)$mets[[1]]

Conversion to degrees Celsius is not strictly necessary, as initial experiments showed a slight performance increase due to normalizing the input, and we’re going to do that anyway. Still, it reads nicer to humans than Kelvin.

sst_train <- sst_train + 273.15
quantile(sst_train, na.rm = TRUE)
     0%     25%     50%     75%    100% 
-1.8000 12.9975 21.8775 26.8200 34.3700 

Not at all surprisingly, global warming is evident from inspecting temperature distribution on the validation set (which was chosen to span the last thirty-one years).

sst_valid <- sst_valid + 273.15
quantile(sst_valid, na.rm = TRUE)
    0%    25%    50%    75%   100% 
-1.800 13.425 22.335 27.240 34.870 

The next-to-last step normalizes both sets according to training mean and standard deviation.

train_mean <- mean(sst_train, na.rm = TRUE)
train_sd <- sd(sst_train, na.rm = TRUE)

sst_train <- (sst_train - train_mean) / train_sd

sst_valid <- (sst_valid - train_mean) / train_sd

Finally, what should we do about the NA entries? We set them to zero, the (training set) mean. That may not be enough of an action though: It means we’re feeding the network roughly 30% misleading data. This is something we’re not done with yet.

sst_train[is.na(sst_train)] <- 0
sst_valid[is.na(sst_valid)] <- 0
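
The two preprocessing steps, normalizing with training statistics and then zero-filling the NAs, can be tried out on toy arrays. This is only a sketch in base R; the arrays train and valid hold made-up values and stand in for the real sea surface temperature grids.

```r
# Toy 2 x 2 grids over time, with NA standing in for land (hypothetical values)
train <- array(c(10, 20, NA, 30, 20, NA, 25, 15), dim = c(2, 2, 2))
valid <- array(c(12, NA, 28, 22), dim = c(2, 2, 1))

# Statistics come from the training set only
train_mean <- mean(train, na.rm = TRUE)
train_sd   <- sd(train, na.rm = TRUE)

train_norm <- (train - train_mean) / train_sd
valid_norm <- (valid - train_mean) / train_sd

# After normalization, zero is the training mean, so NAs become "average" cells
train_norm[is.na(train_norm)] <- 0
valid_norm[is.na(valid_norm)] <- 0

train_norm
valid_norm
```

Filling in zero only after normalization is what makes "zero" coincide with the training mean; filling before would shift both mean and standard deviation.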


The target data are split analogously. Let’s check though: Are phases (categorizations) distributed similarly in both sets?

nino_train <- nino %>% filter(month < as.Date("1990-01-01"))
nino_valid <- nino %>% filter(month >= as.Date("1990-01-01"))

nino_train %>% group_by(phase_code, PHASE) %>% summarise(count = n(), avg = mean(NINO34_MEAN))
# A tibble: 3 x 4
# Groups:   phase_code [3]
  phase_code PHASE count   avg
       <dbl> <chr> <int> <dbl>
1          1 E       301  27.7
2          2 L       333  25.6
3          3 N       554  26.7
nino_valid %>% group_by(phase_code, PHASE) %>% summarise(count = n(), avg = mean(NINO34_MEAN))
# A tibble: 3 x 4
# Groups:   phase_code [3]
  phase_code PHASE count   avg
       <dbl> <chr> <int> <dbl>
1          1 E        93  28.1
2          2 L        93  25.9
3          3 N       182  27.2

This doesn’t look too bad. Of course, we again see the overall rise in temperature, irrespective of phase.

Lastly, we normalize the index, same as we did for the input data.

train_mean_nino <- mean(nino_train$NINO34_MEAN)
train_sd_nino <- sd(nino_train$NINO34_MEAN)

nino_train <- nino_train %>% mutate(NINO34_MEAN = scale(NINO34_MEAN, center = train_mean_nino, scale = train_sd_nino))
nino_valid <- nino_valid %>% mutate(NINO34_MEAN = scale(NINO34_MEAN, center = train_mean_nino, scale = train_sd_nino))

On to the torch dataset.

Torch dataset

The dataset is responsible for correctly matching up inputs and targets.

Our goal is to take six months of global sea surface temperatures and predict the Niño 3.4 Index for the following month. Input-wise, the model will expect the following format semantics:

batch_size * timesteps * channels * width * height, where

  • batch_size is the number of observations worked on in one round of computations,

  • timesteps chains consecutive observations from adjacent months,

  • width and height together constitute the spatial grid, and

  • channels corresponds to available visual channels in the “image”.

In .getitem(), we select the consecutive observations, starting at a given index, and stack them in dimension one. (One, not two, as batches will only start to exist once the dataloader comes into play.)

Now, what about the target? Our ultimate goal was – is – predicting the Niño 3.4 Index. However, as you see we define three targets: One is the index, as expected; an additional one holds the spatially-gridded sea surface temperatures for the prediction month. Why? Our main instrument, the most prominent constituent of the model, will be a convLSTM, an architecture designed for spatial prediction. Thus, to train it efficiently, we want to give it the opportunity to predict values on a spatial grid. So far so good; but there’s one more target, the phase/category. This was added for experimentation purposes: Maybe predicting both index and phase helps in training?

Finally, here is the code for the dataset. In our experiments, we based predictions on inputs from the preceding six months (n_timesteps <- 6). This is a parameter you might want to play with, though.

n_timesteps <- 6

enso_dataset <- dataset(
  name = "enso_dataset",

  initialize = function(sst, nino, n_timesteps) {
    self$sst <- sst
    self$nino <- nino
    self$n_timesteps <- n_timesteps
  },

  .getitem = function(i) {
    # inputs: n_timesteps consecutive months, starting at month i
    x <- torch_tensor(self$sst[, , i:(self$n_timesteps + i - 1)]) # (360, 115, n_timesteps)
    x <- x$permute(c(3, 1, 2))$unsqueeze(2) # (n_timesteps, 1, 360, 115)

    # targets: gridded temperatures, index value and phase for the following month
    y1 <- torch_tensor(self$sst[, , self$n_timesteps + i])$unsqueeze(1) # (1, 360, 115)
    y2 <- torch_tensor(self$nino$NINO34_MEAN[self$n_timesteps + i])
    y3 <- torch_tensor(self$nino$phase_code[self$n_timesteps + i])$squeeze()$to(torch_long())
    list(x = x, y1 = y1, y2 = y2, y3 = y3)
  },

  .length = function() {
    nrow(self$nino) - self$n_timesteps
  }
)

# assumes sst_train and nino_train were prepared like their validation counterparts
train_ds <- enso_dataset(sst_train, nino_train, n_timesteps)
valid_ds <- enso_dataset(sst_valid, nino_valid, n_timesteps)
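To make the indexing in .getitem() and .length() concrete, here is a minimal sketch in plain Python (not R, and independent of torch). The helper names window_indices and dataset_length are made up for illustration, and the series length of 372 months is just an example value; indices are 1-based, as in the R code.

```python
def window_indices(i, n_timesteps):
    """Return (input_months, target_month) for the 1-based item index i."""
    inputs = list(range(i, i + n_timesteps))  # months i .. i + n_timesteps - 1
    target = i + n_timesteps                  # the month right after the window
    return inputs, target


def dataset_length(n_months, n_timesteps):
    """Number of items: every input window needs one following month as target."""
    return n_months - n_timesteps


n_months = 372   # illustrative series length
n_timesteps = 6

first = window_indices(1, n_timesteps)
last = window_indices(dataset_length(n_months, n_timesteps), n_timesteps)

print(first)  # ([1, 2, 3, 4, 5, 6], 7)
print(last)   # ([366, 367, 368, 369, 370, 371], 372)
```

The last item's target is the final month of the series, which is why the length is the number of months minus n_timesteps.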


After the custom dataset, we create the – pretty typical – dataloaders, making use of a batch size of 4.

batch_size <- 4

train_dl <- train_ds %>% dataloader(batch_size = batch_size, shuffle = TRUE)

valid_dl <- valid_ds %>% dataloader(batch_size = batch_size)

Next, we proceed to model creation.


The model’s main ingredient is the convLSTM introduced in a prior post. For convenience, we reproduce the code in the appendix.

Besides the convLSTM, the model makes use of three convolutional layers, a batchnorm layer and five linear layers. The logic is the following.

First, the convLSTM’s job is to predict the next month’s sea surface temperatures on the spatial grid. For that, we almost just return its final state – almost, because we first use self$conv1 to reduce the number of channels to one.

For predicting index and phase, we then need to flatten the grid, as we require a single value each. This is where the additional conv layers come in. We do hope they’ll aid in learning, but we also want to reduce the number of parameters by downsizing the grid (stride = 2 and stride = 3, resp.) before the upcoming torch_flatten().

Once we have a flat structure, learning is shared between the tasks of index and phase prediction (self$linear), until finally their paths split (self$cont and self$cat, resp.), and they return their separate outputs.

(The batchnorm? I’ll comment on that in the Discussion.)

model <- nn_module(

  initialize = function(channels_in,
                        convlstm_hidden,
                        convlstm_kernel,
                        convlstm_layers) {
    self$n_layers <- convlstm_layers

    self$convlstm <- convlstm(
      input_dim = channels_in,
      hidden_dims = convlstm_hidden,
      kernel_sizes = convlstm_kernel,
      n_layers = convlstm_layers
    )

    # reduces the final hidden state to a single channel (the predicted grid)
    self$conv1 <-
      nn_conv2d(
        in_channels = 32,
        out_channels = 1,
        kernel_size = 5,
        padding = 2
      )

    # downsize the grid before flattening
    self$conv2 <-
      nn_conv2d(
        in_channels = 32,
        out_channels = 32,
        kernel_size = 5,
        stride = 2
      )

    self$conv3 <-
      nn_conv2d(
        in_channels = 32,
        out_channels = 32,
        kernel_size = 5,
        stride = 3
      )

    self$linear <- nn_linear(33408, 64)
    self$b1 <- nn_batch_norm1d(num_features = 64)

    self$cont <- nn_linear(64, 128)
    self$cat <- nn_linear(64, 128)

    self$cont_output <- nn_linear(128, 1)
    self$cat_output <- nn_linear(128, 3)
  },

  forward = function(x) {
    ret <- self$convlstm(x)
    layer_last_states <- ret[[2]]
    last_hidden <- layer_last_states[[self$n_layers]][[1]]

    next_sst <- last_hidden %>% self$conv1()

    c2 <- last_hidden %>% self$conv2()
    c3 <- c2 %>% self$conv3()

    flat <- torch_flatten(c3, start_dim = 2)
    # the batchnorm layer is stored as self$b1
    common <- self$linear(flat) %>% self$b1() %>% nnf_relu()

    next_temp <- common %>% self$cont() %>% nnf_relu() %>% self$cont_output()
    next_nino <- common %>% self$cat() %>% nnf_relu() %>% self$cat_output()

    list(next_sst, next_temp, next_nino)
  }
)
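Where does the 33408 in nn_linear(33408, 64) come from? The following plain-Python sketch (the helper conv_out is made up for illustration) applies the usual convolution output-size formula, assuming dilation 1 and the 360 x 115 grid from the dataset section.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output length of one spatial dimension after a convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


width, height = 360, 115  # spatial grid of the convLSTM's last hidden state

# conv1: kernel size 5 with padding 2 and stride 1 keeps the grid size
same_size = (conv_out(width, 5, padding=2), conv_out(height, 5, padding=2))

# conv2 (stride 2) followed by conv3 (stride 3), both with kernel size 5
w = conv_out(conv_out(width, 5, stride=2), 5, stride=3)
h = conv_out(conv_out(height, 5, stride=2), 5, stride=3)

flat_features = 32 * w * h  # 32 output channels after conv3

print(same_size)            # (360, 115)
print(w, h, flat_features)  # 58 18 33408
```

If you change kernel sizes, strides, or the number of channels, the input size of self$linear has to be adapted accordingly.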

Next, we instantiate a pretty small-ish model. You’re more than welcome to experiment with larger models, but training time as well as GPU memory requirements will increase.

net <- model(
  channels_in = 1,
  convlstm_hidden = c(16, 16, 32),
  convlstm_kernel = c(3, 3, 5),
  convlstm_layers = 3
)

device <- torch_device(if (cuda_is_available()) "cuda" else "cpu")

net <- net$to(device = device)
net
An `nn_module` containing 2,389,605 parameters.

── Modules ───────────────────────────────────────────────────────────────────────────────
● convlstm: <nn_module> #182,080 parameters
● conv1: <nn_conv2d> #801 parameters
● conv2: <nn_conv2d> #25,632 parameters
● conv3: <nn_conv2d> #25,632 parameters
● linear: <nn_linear> #2,138,176 parameters
● b1: <nn_batch_norm1d> #128 parameters
● cont: <nn_linear> #8,320 parameters
● cat: <nn_linear> #8,320 parameters
● cont_output: <nn_linear> #129 parameters
● cat_output: <nn_linear> #387 parameters
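As a sanity check on the printout, the per-layer counts can be reproduced by hand. This is a plain-Python sketch with made-up helper names; the convLSTM count is copied from the printout rather than recomputed, since its internals are defined elsewhere.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights of a square kernel plus one bias per output channel."""
    return in_channels * out_channels * kernel_size ** 2 + out_channels


def linear_params(in_features, out_features):
    """Weight matrix plus bias vector."""
    return in_features * out_features + out_features


counts = {
    "convlstm": 182_080,  # taken from the printout, not recomputed here
    "conv1": conv2d_params(32, 1, 5),
    "conv2": conv2d_params(32, 32, 5),
    "conv3": conv2d_params(32, 32, 5),
    "linear": linear_params(33408, 64),
    "b1": 2 * 64,  # batchnorm: one scale and one shift per feature
    "cont": linear_params(64, 128),
    "cat": linear_params(64, 128),
    "cont_output": linear_params(128, 1),
    "cat_output": linear_params(128, 3),
}

total = sum(counts.values())
print(counts["linear"])  # 2138176
print(total)             # 2389605
```

The first linear layer accounts for the bulk of the parameters, which is why downsizing the grid before flattening matters.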


We have three model outputs. How should we combine the losses?

Given that the main goal is predicting the index, and the other two outputs are essentially means to an end, I found the following combination rather effective:

# weight for sea surface temperature prediction
lw_sst <- 0.2

# weight for prediction of El Nino 3.4 Index
lw_temp <- 0.4

# weight for phase prediction
lw_nino <- 0.4

The training process follows the pattern seen in all torch posts so far: For each epoch, loop over the training set, backpropagate, check performance on validation set.

But, when we did the pre-processing, we were aware of an imminent problem: the missing temperatures for continental areas, which we set to zero. As a sole measure, this approach is clearly insufficient. What if we had chosen to use latitude-dependent averages? Or interpolation? Both may be better than a global average, but both have their problems as well. Let’s at least alleviate negative consequences by not using the respective pixels for spatial loss calculation. This is taken care of by the following line:

sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])
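What this masking does can be seen on a toy example. The sketch below is plain Python, not torch; masked_mse and the four-cell "grid" are made up for illustration, with zeros in the target marking land pixels.

```python
def masked_mse(output, target, missing=0.0):
    """Mean squared error over the cells whose target differs from `missing`."""
    pairs = [(o, t) for o, t in zip(output, target) if t != missing]
    if not pairs:
        raise ValueError("no valid target cells")
    return sum((o - t) ** 2 for o, t in pairs) / len(pairs)


# hypothetical flattened grid: zeros in the target mark land pixels
target = [0.0, 1.0, -2.0, 0.0]
output = [5.0, 1.5, -1.0, -3.0]

print(masked_mse(output, target))  # 0.625
```

Whatever the network predicts over land (5.0 and -3.0 here) does not enter the loss. One caveat of masking by value: an ocean cell whose normalized temperature happens to be exactly zero would be dropped as well.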

Here, then, is the complete training code.

optimizer <- optim_adam(net$parameters, lr = 0.001)

num_epochs <- 50

train_batch <- function(b) {
  optimizer$zero_grad()
  output <- net(b$x$to(device = device))

  sst_output <- output[[1]]
  sst_target <- b$y1$to(device = device)

  # exclude land pixels (set to zero in pre-processing) from the spatial loss
  sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])
  temp_loss <- nnf_mse_loss(output[[2]], b$y2$to(device = device))
  nino_loss <- nnf_cross_entropy(output[[3]], b$y3$to(device = device))

  loss <- lw_sst * sst_loss + lw_temp * temp_loss + lw_nino * nino_loss
  loss$backward()
  optimizer$step()

  list(sst_loss$item(), temp_loss$item(), nino_loss$item(), loss$item())
}

valid_batch <- function(b) {
  output <- net(b$x$to(device = device))

  sst_output <- output[[1]]
  sst_target <- b$y1$to(device = device)

  sst_loss <- nnf_mse_loss(sst_output[sst_target != 0], sst_target[sst_target != 0])
  temp_loss <- nnf_mse_loss(output[[2]], b$y2$to(device = device))
  nino_loss <- nnf_cross_entropy(output[[3]], b$y3$to(device = device))
  loss <- lw_sst * sst_loss + lw_temp * temp_loss + lw_nino * nino_loss

  list(sst_loss$item(), temp_loss$item(), nino_loss$item(), loss$item())
}

for (epoch in 1:num_epochs) {
  # training mode: batchnorm uses batch statistics
  net$train()

  train_loss_sst <- c()
  train_loss_temp <- c()
  train_loss_nino <- c()
  train_loss <- c()

  coro::loop(for (b in train_dl) {
    losses <- train_batch(b)
    train_loss_sst <- c(train_loss_sst, losses[[1]])
    train_loss_temp <- c(train_loss_temp, losses[[2]])
    train_loss_nino <- c(train_loss_nino, losses[[3]])
    train_loss <- c(train_loss, losses[[4]])
  })

  cat(sprintf(
    "\nEpoch %d, training: loss: %3.3f sst: %3.3f temp: %3.3f nino: %3.3f \n",
    epoch, mean(train_loss), mean(train_loss_sst), mean(train_loss_temp), mean(train_loss_nino)
  ))

  # evaluation mode: batchnorm uses running statistics
  net$eval()

  valid_loss_sst <- c()
  valid_loss_temp <- c()
  valid_loss_nino <- c()
  valid_loss <- c()

  coro::loop(for (b in valid_dl) {
    losses <- valid_batch(b)
    valid_loss_sst <- c(valid_loss_sst, losses[[1]])
    valid_loss_temp <- c(valid_loss_temp, losses[[2]])
    valid_loss_nino <- c(valid_loss_nino, losses[[3]])
    valid_loss <- c(valid_loss, losses[[4]])
  })

  cat(sprintf(
    "\nEpoch %d, validation: loss: %3.3f sst: %3.3f temp: %3.3f nino: %3.3f \n",
    epoch, mean(valid_loss), mean(valid_loss_sst), mean(valid_loss_temp), mean(valid_loss_nino)
  ))

  # one checkpoint per epoch, named after the epoch's mean losses
  torch_save(net, paste0(
    "model_", epoch, "_", round(mean(train_loss), 3), "_", round(mean(valid_loss), 3), ".pt"
  ))
}

When I ran this, the loss on the training set decreased in a not-too-fast, but continuous way, while the validation loss kept fluctuating. For reference, total (composite) losses looked like this:

Epoch     Training    Validation
   10        0.336         0.633
   20        0.233         0.295
   30        0.135         0.461
   40        0.099         0.903
   50        0.061         0.727

Thinking of the size of the validation set – thirty-one years, or equivalently, 372 data points – those fluctuations may not be all too surprising.


Now losses tend to be abstract; let’s see what actually gets predicted. We obtain predictions for index values and phases like so …


net$eval()

pred_index <- c()
pred_phase <- c()

coro::loop(for (b in valid_dl) {

  output <- net(b$x$to(device = device))

  pred_index <- c(pred_index, output[[2]]$to(device = "cpu"))
  pred_phase <- rbind(pred_phase, as.matrix(output[[3]]$to(device = "cpu")))

})


… and combine these with the ground truth, stripping off the first..


A2C Advantage Actor Critic in TensorFlow 2

In a previous post, I gave an introduction to Policy Gradient reinforcement learning. Policy gradient-based reinforcement learning relies on using neural networks to learn an action policy for the control of agents in an environment. This is opposed to controlling agents based on neural network estimations of a value-based function, such as the Q value in deep Q learning. However, there are problems with straight Monte-Carlo based methods of policy gradient learning as covered in the previously mentioned policy gradient post. In particular, one significant problem is a high variance in the learning. This problem can be mitigated by a process called baselining, with the most effective baselining method being the Advantage Actor Critic method, or A2C. In this post, I’ll review the theory of the A2C method, and demonstrate how to build an A2C algorithm in TensorFlow 2.

All code shown in this tutorial can be found at this site’s Github repository, in the file.

A quick recap of some important concepts

In the A2C algorithm, notice the title “Advantage Actor” – this refers first to the actor, the part of the neural network that is used to determine the actions of the agent. The “advantage” is a concept that expresses the relative benefit of taking a certain action at time t ($a_t$) from a certain state $s_t$. Note that it is not the “absolute” benefit, but the “relative” benefit. This will become clearer when I discuss the concept of “value”. The advantage is expressed as:

$$A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$$

The Q value (discussed in other posts, for instance here, here and here) is the expected future rewards of taking action $a_t$ from state $s_t$. The value $V(s_t)$ is the expected value of the agent being in that state and operating under a certain action policy $\pi$. It can be expressed as:

$$V^{\pi}(s) = \mathbb{E} \left[\sum_{i=1}^T \gamma^{i-1}r_{i}\right]$$

Here $\mathbb{E}$ is the expectation operator, and the value $V^{\pi}(s)$ can be read as the expected value of future discounted rewards that will be gathered by the agent, operating under a certain action policy $\pi$. So, the Q value is the expected value of taking a certain action from the current state, whereas V is the expected value of simply being in the current state, under a certain action policy.

The advantage then is the relative benefit of taking a certain action from the current state. It’s kind of like a normalized Q value. For example, let’s consider the last state in a game, where after the next action the game ends. There are three possible actions from this state, with rewards of (51, 50, 49). Let’s also assume that the action selection policy $\pi$ is simply random, so there is an equal chance of any of the three actions being selected. The value of this state, then, is 50 ((51+50+49) / 3). If the first action is randomly selected (reward=51), the Q value is 51. However, the advantage is only equal to 1 (Q-V = 51-50). As can be observed and as stated above, the advantage is a kind of normalized or relative Q value.
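The arithmetic of this example can be written out in a few lines of plain Python. This is only a sketch of the worked example above; the variable names are mine.

```python
q_values = [51, 50, 49]  # Q value of each of the three final actions

# under a uniformly random policy, V(s) is the plain average of the Q values
value = sum(q_values) / len(q_values)

# A(s, a) = Q(s, a) - V(s)
advantages = [q - value for q in q_values]

print(value)       # 50.0
print(advantages)  # [1.0, 0.0, -1.0]
```

The raw Q values all look large and similar, while the advantages make the small relative differences between the actions explicit.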

Why is this important? If we are using Q values in some way to train our action-taking policy, in the example above the first action would send a “signal” or contribution of 51 to the gradient optimizer, which may be significant enough to push the parameters of the neural network significantly in a certain direction. However, given the other two actions possible from this state also have a high reward (50 and 49), the signal or contribution is really higher than it should be – it is not that much better to take action 1 instead of action 3. Therefore, Q values can be a source of high variance in the training process, and it is much better to use the normalized or baseline Q values i.e. the advantage, in training. For more discussion of Q, values, and advantages, see my post on dueling Q networks.

Policy gradient reinforcement learning and its problems

In a previous post, I presented the policy gradient reinforcement learning algorithm. For details on this algorithm, please consult that post. However, the A2C algorithm shares important similarities with the PG algorithm, and therefore it is necessary to recap some of the theory. First, it has to be recalled that PG-based algorithms involve a neural network that directly outputs estimates of the probability distribution of the best next action to take in a given state. So, for instance, if we have an environment with 4 possible actions, the output from the neural network could be something like [0.5, 0.25, 0.1, 0.15], with the first action being currently favored. In the PG case, then, the neural network is the direct instantiation of the policy of the agent $\pi_{\theta}$ – where this policy is controlled by the parameters of the neural network $\theta$. This is opposed to Q based RL algorithms, where the neural network estimates the Q value in a given state for each possible action. In these algorithms, the action policy is generally an epsilon-greedy policy, where the best action is that action with the highest Q value (with some random choices involved to improve exploration).

The gradient of the loss function for the policy gradient algorithm is as follows:

$$\nabla_\theta J(\theta) \sim \left(\sum_{t=0}^{T-1} \log P_{\pi_{\theta}}(a_t|s_t)\right)\left(\sum_{t'= t + 1}^{T} \gamma^{t'-t-1} r_{t'} \right)$$

Note that the term:

$$G_t = \left(\sum_{t'= t + 1}^{T} \gamma^{t'-t-1} r_{t'} \right)$$

is just the discounted sum of the rewards onwards from state $s_t$. In other words, it is an estimate of the true value function $V^{\pi}(s)$. Remember that in the PG algorithm, the network can only be trained after each full episode, and this is because of the term above. Therefore, note that the $G_t$ term above is an estimate of the true value function as it is based on only a single trajectory of the agent through the game.
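As a small illustration of the $G_t$ term, the sketch below computes the discounted reward sums for a complete episode in plain Python. The function name is mine, and I assume that rewards[t] holds $r_{t+1}$, the reward received after acting in state $s_t$.

```python
def discounted_returns(rewards, gamma):
    """For every step t, sum the rewards from t onward, discounted by gamma."""
    returns = []
    running = 0.0
    # walk backwards so that each step reuses the sum of the steps after it
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


# hypothetical three-step episode with a reward of 1 at every step
print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))  # [1.75, 1.5, 1.0]
```

Note that the whole reward list must be known before the first return can be computed, which is exactly why plain policy gradient has to wait for the episode to finish.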

Now, because it is based on samples of reward trajectories, which aren’t “normalized” or baselined in any way, the PG algorithm suffers from variance issues, resulting in slower and more erratic training progress. A better solution is to replace the $G_t$ function above with the Advantage – $A(s_t, a_t)$, and this is what the Advantage-Actor Critic method does.

The A2C algorithm

Replacing the $G_t$ function with the advantage, we come up with the following gradient function which can be used in training the neural network:

$$\nabla_\theta J(\theta) \sim \left(\sum_{t=0}^{T-1} \log P_{\pi_{\theta}}(a_t|s_t)\right)A(s_t, a_t)$$

Now, as shown above, the advantage is:

$$A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$$

However, using Bellman’s equation, the Q value can be expressed purely in terms of the rewards and the value function:

$$Q(s_t, a_t) = \mathbb{E}\left[r_{t+1} + \gamma V(s_{t+1})\right]$$

Therefore, the advantage can now be estimated as:

$$A(s_t, a_t) = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)$$
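In code, this one-step estimate is a single line. The sketch below is plain Python with made-up numbers; the done flag is my own addition, based on the assumption that a terminal state has no future value.

```python
def td_advantage(reward, value_s, value_next, gamma, done=False):
    """One-step estimate A(s_t, a_t) = r_{t+1} + gamma * V(s_{t+1}) - V(s_t)."""
    # assumption: after a terminal step there is no next state to bootstrap from
    bootstrap = 0.0 if done else gamma * value_next
    return reward + bootstrap - value_s


print(td_advantage(reward=1.0, value_s=2.0, value_next=4.0, gamma=0.5))             # 1.0
print(td_advantage(reward=1.0, value_s=2.0, value_next=4.0, gamma=0.5, done=True))  # -1.0
```

Only the value estimates of the current and the next state are needed, so no complete episode is required.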

As can be seen from the above, there is a requirement to be able to estimate the value function V. We could estimate it by running our agents through full episodes, in the same way we did in the policy gradient method. However, it would be better to be able to just collect batches of game-steps and train whenever the batch buffer was full, rather than having to wait for an episode to finish. That way, the agent could actually learn “on-the-go” during the middle of an episode/game.

So, do we build another neural network to estimate V? We could have two networks, one to learn the policy and produce actions, and another to estimate the state values. A more efficient solution is to create one network, but with two output channels, and this is how the A2C method is implemented. The figure below shows the network architecture for an A2C neural network:

A2C architecture

This architecture is based on an A2C method that takes game images as the state input, hence the convolutional neural network layers at the beginning of the network (for more on CNNs, see my post here). This network architecture also resembles the Dueling Q network architecture (see my Dueling Q post). The point to note about the architecture above is that most of the network is shared, with a late bifurcation between the policy part and the value part. The outputs $P(s, a_i)$ are the action probabilities of the policy (generated from the neural network) – $P(a_t|s_t)$. The other output channel is the value estimation – a scalar output which is the predicted value of state s – $V(s)$. The two dense channels disambiguate the policy and the value outputs from the front-end of the neural network.

In this example, we’ll just be demonstrating the A2C algorithm on the Cartpole OpenAI Gym environment which doesn’t require a visual state input (i.e. a set of pixels as the input to the NN), and therefore the two output channels will simply share some dense layers, rather than a series of CNN layers.

The A2C loss functions

There are actually three loss values that need to be calculated in the A2C algorithm. Each of these losses is in practice given a weighting, and then they are summed together (with the entropy loss having a negative sign, see below).

The Critic loss

The Critic, i.e. the value-estimating output of the neural network $V(s)$, needs to be trained so that it predicts more and more closely the actual value of the state. As shown before, the value of a state is calculated as:

$$V^{\pi}(s) = \mathbb{E} \left[\sum_{i=1}^T \gamma^{i-1}r_{i}\right]$$

So $V^{\pi}(s)$ is the expected value of the discounted future rewards obtained by following a trajectory through the game based on a certain operating policy $\pi$. We can therefore compare the predicted $V(s)$ at each state in the game with the actual sampled discounted rewards that were gathered, and the difference between the two is the Critic loss. In this example, we’ll use a mean squared error function as the loss function, between the discounted rewards and the predicted values ($(V(s) - DR)^2$).

Now, given that, under the A2C algorithm, we collect state, action and reward tuples until a batch buffer is filled, how are we meant to figure out this discounted rewards sum? Let’s say we progress 3 states through a game, and we collect:

$(V(s_0), r_0), (V(s_1), r_1), (V(s_2), r_2)$

For the first Critic loss, we could calculate it as:

$$MSE(V(s_0), r_0 + \gamma r_1 + \gamma^2 r_2)$$

But that is missing all the following rewards $r_3, r_4, \ldots, r_n$ until the game terminates. We didn’t have this problem in the Policy Gradient method, because in that method, we made sure a full run through the game had completed before training the neural network. In the A2C method, we use a trick called bootstrapping. To replace all the discounted $r_3, r_4, \ldots, r_n$ values, we get the network to estimate the value for state 3, $V(s_3)$, and this will be an estimate for all the discounted future rewards beyond that point in the game. So, for the first Critic loss, we would have:

$$MSE(V(s_0), r_0 + \gamma r_1 + \gamma^2 r_2 + \gamma^3 V(s_3))$$

Where $V(s_3)$ is a bootstrapped estimate of the value of the next state $s_3$.
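Here is a minimal plain-Python sketch of bootstrapping for the three-step example. It is not the code from the walk-through; the function name and all numbers (rewards, value estimates, discount factor) are made up for illustration.

```python
def bootstrapped_returns(rewards, bootstrap_value, gamma):
    """Discounted reward sums for a partial trajectory, seeded with a value estimate."""
    targets = []
    running = bootstrap_value  # stands in for all rewards beyond the batch
    for r in reversed(rewards):
        running = r + gamma * running
        targets.append(running)
    targets.reverse()
    return targets


rewards = [1.0, 1.0, 1.0]  # r_0, r_1, r_2
values = [2.0, 3.0, 4.0]   # hypothetical network outputs V(s_0), V(s_1), V(s_2)
v_s3 = 8.0                 # hypothetical network estimate V(s_3)

targets = bootstrapped_returns(rewards, v_s3, gamma=0.5)
print(targets)  # [2.75, 3.5, 5.0]

# critic loss: mean squared error between predicted values and bootstrapped targets
critic_mse = sum((v - t) ** 2 for v, t in zip(values, targets)) / len(values)
```

The first target equals $r_0 + \gamma r_1 + \gamma^2 r_2 + \gamma^3 V(s_3)$ = 1 + 0.5 + 0.25 + 0.125 * 8 = 2.75, matching the formula above.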

This will be explained more in the code-walkthrough to follow.

The Actor loss

The second loss function needs to train the Actor (i.e. the action policy). Recall that the advantage weighted policy loss is:

$$\nabla_\theta J(\theta) \sim \left(\sum_{t=0}^{T-1} \log P_{\pi_{\theta}}(a_t|s_t)\right)A(s_t, a_t)$$

Let’s start with the advantage – $A(s_t, a_t) = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)$

This is simply the bootstrapped discounted rewards minus the predicted state values $V(s_t)$ that we gathered up while playing the game. So calculating the advantage is quite straight-forward, once we have the bootstrapped discounted rewards, as will be seen in the code walk-through shortly.

Now, with regards to the $\log P_{\pi_{\theta}}(a_t|s_t)$ term, in this instance, we can just calculate the log of the softmax probability estimate for whatever action was taken. So, for instance, if in state 1 ($s_1$) the network softmax output produces {0.1, 0.9} (for a 2-action environment), and the second action was actually taken by the agent, we would want to calculate log(0.9). We can make use of the TensorFlow-Keras SparseCategoricalCrossentropy calculation, which takes the action as an integer, and this specifies which softmax output value to apply the log to. So in this example, y_true = [1] and y_pred = [0.1, 0.9] and the answer would be -log(0.9) = 0.105.

Another handy feature of the SparseCategoricalCrossentropy loss in Keras is that it can be called with a “sample_weight” argument. This basically multiplies the log calculation with a value. So, in this example, we can supply the advantages as the sample weights, and it will calculate $\nabla_\theta J(\theta) \sim \left(\sum_{t=0}^{T-1} \log P_{\pi_{\theta}}(a_t|s_t)\right)A(s_t, a_t)$ for us. This will be shown below, but the call will look like:

policy_loss = sparse_ce(actions, policy_logits, sample_weight=advantages)
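To see what such a call computes, here is a plain-Python sketch of the same quantity, assuming logits as input and a sum over the batch. It illustrates the arithmetic only and is not Keras’s implementation; the function names are mine.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_policy_loss(actions, logits_batch, advantages):
    """Sum over the batch of -log(softmax(logits)[action]) * advantage."""
    loss = 0.0
    for action, logits, advantage in zip(actions, logits_batch, advantages):
        prob = softmax(logits)[action]
        loss += -math.log(prob) * advantage
    return loss


# logits chosen so that the softmax output is (0.1, 0.9), as in the example above
logits = [math.log(0.1), math.log(0.9)]

# the second action (index 1) was taken; an advantage of 1.0 leaves the log term unscaled
loss = weighted_policy_loss([1], [logits], [1.0])
print(round(loss, 3))  # 0.105
```

With an advantage of 2.0 the same step would contribute twice as much, and a negative advantage flips the sign, pushing the probability of that action down.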

Entropy loss

In many implementations of the A2C algorithm, another loss term is subtracted – the entropy loss. Entropy is a measure, broadly speaking, of randomness. The higher the entropy, the more random the state of affairs; the lower the entropy, the more ordered the state of affairs. In the case of A2C, entropy is calculated on the softmax policy action ($P(a_t|s_t)$) output of the neural network. Let’s go back to our two action example from above. In the case of a probability output of {0.1, 0.9} for the two possible actions, this is an ordered, less-random selection of actions. In other words, there will be a consistent selection of action 2, and only rarely will action 1 be taken. The entropy formula is:

$$E = -\sum p(x) \log(p(x))$$

So in this case, the entropy of that output would be 0.325. However, if the probability output was instead {0.5, 0.5}, the entropy would be 0.693. The 50-50 action probability distribution will produce more random actions, and therefore the entropy is higher.
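These two numbers are easy to verify. The sketch below is plain Python, uses the natural logarithm, and the function name is mine.

```python
import math


def entropy(probs):
    """Shannon entropy in nats; zero-probability terms contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)


print(round(entropy([0.1, 0.9]), 3))  # 0.325
print(round(entropy([0.5, 0.5]), 3))  # 0.693
```

For two actions, 0.693 (the natural log of 2) is the largest value the entropy can take, and it is reached by the uniform distribution.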

Subtracting the entropy calculation from the total loss (or giving the entropy loss a negative sign) encourages more randomness and therefore more exploration. The A2C algorithm can have a tendency of converging on particular actions, so this subtraction of the entropy encourages a better exploration of alternative actions, though making the weighting on this component of the loss too large can also reduce training performance.

Again, we can use an already existing Keras loss function to calculate the entropy. The Keras categorical cross-entropy performs the following calculation:

Keras output of cross-entropy loss function

If we just pass in the probability outputs as both target and output to this function, then it will calculate the entropy for us. This will be shown in the code below.

The total loss

The total loss function for the A2C algorithm is:

Loss = Actor Loss + Critic Loss * CRITIC_WEIGHT - Entropy Loss * ENTROPY_WEIGHT

A common value for the critic weight is 0.5, and the entropy weight is usually quite low (i.e. on the order of 0.01-0.001), though these hyperparameters can be adjusted and experimented with depending on the environment and network.
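As a quick illustration of how the three terms combine, here is a plain-Python sketch. The weights are picked from the ranges just mentioned, and the three loss values are made-up numbers.

```python
CRITIC_WEIGHT = 0.5    # common choice mentioned above
ENTROPY_WEIGHT = 0.01  # upper end of the range mentioned above


def total_loss(actor_loss, critic_loss, entropy):
    """Weighted sum of the three terms; the entropy enters with a negative sign."""
    return actor_loss + critic_loss * CRITIC_WEIGHT - entropy * ENTROPY_WEIGHT


# hypothetical per-batch values
low_entropy = total_loss(actor_loss=1.0, critic_loss=2.0, entropy=0.3)
high_entropy = total_loss(actor_loss=1.0, critic_loss=2.0, entropy=0.6)

# with everything else equal, a more random policy yields a lower total loss
print(high_entropy < low_entropy)  # True
```

Because the entropy weight is small, the entropy term only nudges the optimization towards exploration rather than dominating it.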

Implementing A2C in TensorFlow 2

In the following section, I will provide a walk-through of some code to implement the A2C methodology in TensorFlow 2. The code for this can be found at this site’s Github repository, in the file.

First, we perform the usual imports, set some constants, initialize the environment and finally create the neural network model which instantiates the A2C architecture:

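import tensorflow as tf
from tensorflow import keras
import numpy as np
import gym
import datetime as dt

STORE_PATH = '/Users/andrewthomas/Adventures in ML/TensorFlowBook/TensorBoard/A2CCartPole'
CRITIC_LOSS_WEIGHT = 0.5
ACTOR_LOSS_WEIGHT = 1.0
ENTROPY_LOSS_WEIGHT = 0.05  # name assumed by analogy with the two weights above
GAMMA = 0.95

env = gym.make("CartPole-v0")
state_size = 4
num_actions = env.action_space.n

class Model(keras.Model):
    def __init__(self, num_actions):
        super().__init__()  # must run before layers are assigned on a keras.Model subclass
        self.num_actions = num_actions
        # two dense layers shared by both output heads
        self.dense1 = keras.layers.Dense(64, activation='relu')
        self.dense2 = keras.layers.Dense(64, activation='relu')
        # critic head: scalar state value V(s)
        self.value = keras.layers.Dense(1)
        # actor head: unnormalized logits, one per action
        self.policy_logits = keras.layers.Dense(num_actions)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.value(x), self.policy_logits(x)

    def action_value(self, state):
        value, logits = self.predict_on_batch(state)
        # tf.random.categorical samples from logits, not from softmax probabilities
        action = tf.random.categorical(logits, 1)[0]
        return action, value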

As can be seen, for this example I have set the critic, actor and entropy loss weights to 0.5, 1.0 and 0.05 respectively. Next the environment is set up, and then the model class is created.

This class inherits from keras.Model, which enables it to be integrated into the streamlined Keras methods of training and evaluating (for more information, see this Keras tutorial). In the initialization of the class, we see that 2 dense layers have been created, with 64 nodes in each. Then a value layer with one output is created, which evaluates $V(s)$, and finally the policy layer output with a size equal to the number of available actions. Note that this layer produces logits only. The softmax function, which creates pseudo-probabilities ($P(a_t|s_t)$), will be applied within the various TensorFlow functions, as will be seen.

Next, the call function is defined – this function is run whenever a state needs to be “run” through the model, to produce a value and policy logits output. The Keras model API will use this function in its predict functions and also its training functions. In this function, it can be observed that the input is passed through the two common dense layers, and then the function returns first the value output, then the policy logits output.

The next function is the action_value function. This function is called upon when an action needs to be chosen from the model. As can be seen, the first step of the function is to run the predict_on_batch Keras model API function, which just runs the call function defined above. The output is both the values and the policy logits. An action is then selected by randomly sampling from the action probabilities. Note that tf.random.categorical takes as input logits, not softmax outputs. The next function, outside of the Model class, is the function that calculates the critic loss:

def critic_loss(discounted_rewards, predicted_values):
    return keras.losses.mean_squared_error(discounted_rewards, predicted_values) * CRITIC_LOSS_WEIGHT

As explained above, the critic loss consists of the mean squared error between the discounted rewards (which are calculated in another function, soon to be discussed) and the values predicted from the value output of the model (which are accumulated in a list during the agent’s trajectory through the game).

The following function shows the actor loss function:

def actor_loss(combined, policy_logits):
    actions = combined[:, 0]
    advantages = combined[:, 1]
    sparse_ce = keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.SUM
    )

    actions = tf.cast(actions, tf.int32)
    policy_loss = sparse_ce(actions, policy_logits, sample_weight=advantages)

    probs = tf.nn.softmax(policy_logits)
    entropy_loss = keras.losses.categorical_crossentropy(probs, probs)

    return policy_loss * ACTOR_LOSS_WEIGHT..

A2C Advantage Actor Critic in TensorFlow 2

In a previous post, I gave an introduction to Policy Gradient reinforcement learning. Policy gradient-based reinforcement learning relies on using neural networks to learn an action policy for the control of agents in an environment. This is opposed to controlling agents based on neural network estimations of a value-based function, such as the Q value in deep Q learning. However, there are problems with straight Monte-Carlo based methods of policy gradient learning as covered in the previously mentioned policy gradient post. In particular, one significant problem is a high variance in the learning. This problem can be solved by a process called baselining, with the most effective baselining method being the Advantage Actor Critic method or A2c. In this post, I’ll review the theory of the A2c method, and demonstrate how to build an A2c algorithm in TensorFlow 2.

All code shown in this tutorial can be found at this site’s Github repository, in the file.

A quick recap of some important concepts

In the A2C algorithm, notice the title “Advantage Actor” – this refers first to the actor, the part of the neural network that is used to determine the actions of the agent. The “advantage” is a concept that expresses the relative benefit of taking a certain action at time t ($a_t$) from a certain state $s_t$. Note that it is not the “absolute” benefit, but the “relative” benefit. This will become clearer when I discuss the concept of “value”. The advantage is expressed as:

$$A(s_t, a_t) = Q(s_t, a_t) – V(s_t)$$

The Q value (discussed in other posts, for instance here, here and here) is the expected future rewards of taking action $a_t$ from state $s_t$. The value $V(s_t)$ is the expected value of the agent being in that state and operating under a certain action policy $pi$. It can be expressed as:

$$V^{pi}(s) = mathbb{E} left[sum_{i=1}^T gamma^{i-1}r_{i}right]$$

Here $mathbb{E}$ is the expectation operator, and the value $V^{pi}(s)$ can be read as the expected value of future discounted rewards that will be gathered by the agent, operating under a certain action policy $pi$. So, the Q value is the expected value of taking a certain action from the current state, whereas V is the expected value of simply being in the current state, under a certain action policy.

The advantage then is the relative benefit of taking a certain action from the current state. It’s kind of like a normalized Q value. For example, let’s consider the last state in a game, where after the next action the game ends. There are three possible actions from this state, with rewards of (51, 50, 49). Let’s also assume that the action selection policy $pi$ is simply random, so there is an equal chance of any of the three actions being selected. The value of this state, then, is 50 ((51+50+49) / 3). If the first action is randomly selected (reward=51), the Q value is 51. However, the advantage is only equal to 1 (Q-V = 51-50). As can be observed and as stated above, the advantage is a kind of normalized or relative Q value.

Why is this important? If we are using Q values in some way to train our action-taking policy, in the example above the first action would send a “signal” or contribution of 51 to the gradient optimizer, which may be significant enough to push the parameters of the neural network significantly in a certain direction. However, given the other two actions possible from this state also have a high reward (50 and 49), the signal or contribution is really higher than it should be – it is not that much better to take action 1 instead of action 3. Therefore, Q values can be a source of high variance in the training process, and it is much better to use the normalized or baseline Q values i.e. the advantage, in training. For more discussion of Q, values, and advantages, see my post on dueling Q networks.

Policy gradient reinforcement learning and its problems

In a previous post, I presented the policy gradient reinforcement learning algorithm. For details on this algorithm, please consult that post. However, the A2C algorithm shares important similarities with the PG algorithm, and therefore it is necessary to recap some of the theory. First, it has to be recalled that PG-based algorithms involve a neural network that directly outputs estimates of the probability distribution of the best next action to take in a given state. So, for instance, if we have an environment with 4 possible actions, the output from the neural network could be something like [0.5, 0.25, 0.1, 0.15], with the first action being currently favored. In the PG case, then, the neural network is the direct instantiation of the policy of the agent $\pi_{\theta}$ – where this policy is controlled by the parameters of the neural network $\theta$. This is opposed to Q based RL algorithms, where the neural network estimates the Q value in a given state for each possible action. In these algorithms, the action policy is generally an epsilon-greedy policy, where the best action is that action with the highest Q value (with some random choices involved to improve exploration).

The gradient of the loss function for the policy gradient algorithm is as follows:

$$\nabla_\theta J(\theta) \sim \left(\sum_{t=0}^{T-1} \log P_{\pi_{\theta}}(a_t|s_t)\right)\left(\sum_{t'= t + 1}^{T} \gamma^{t'-t-1} r_{t'} \right)$$

Note that the term:

$$G_t = \left(\sum_{t'= t + 1}^{T} \gamma^{t'-t-1} r_{t'} \right)$$

is just the discounted sum of the rewards onwards from state $s_t$. In other words, it is an estimate of the true value function $V^{\pi}(s)$. Remember that in the PG algorithm the network can only be trained after each full episode, because $G_t$ needs every reward up to the end of the episode. Note also that $G_t$ is only an estimate of the true value function, as it is based on a single trajectory of the agent through the game.
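To make the discounted sum concrete, here is a minimal sketch in plain Python. The helper name `discounted_return` and the example rewards are illustrative assumptions; the list holds the rewards received after state $s_t$, in order, until the episode ends.

```python
def discounted_return(future_rewards, gamma):
    """Discounted sum of the rewards collected after state s_t.

    future_rewards[k] corresponds to r_{t+1+k}, so it is weighted by gamma**k.
    """
    return sum(gamma ** k * r for k, r in enumerate(future_rewards))


# three remaining steps with a reward of 1 each, discount factor 0.5
g_t = discounted_return([1.0, 1.0, 1.0], gamma=0.5)
print(g_t)  # 1.75  (1 + 0.5 + 0.25)
```

Because the list has to run to the end of the episode, this quantity cannot be computed until the episode has finished.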

Now, because it is based on samples of reward trajectories, which aren’t “normalized” or baselined in any way, the PG algorithm suffers from variance issues, resulting in slower and more erratic training progress. A better solution is to replace the $G_t$ function above with the Advantage – $A(s_t, a_t)$, and this is what the Advantage-Actor Critic method does.

The A2C algorithm

Replacing the $G_t$ function with the advantage, we come up with the following gradient function which can be used in training the neural network:

$$\nabla_\theta J(\theta) \sim \left(\sum_{t=0}^{T-1} \log P_{\pi_{\theta}}(a_t|s_t)\right)A(s_t, a_t)$$

Now, as shown above, the advantage is:

$$A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$$

However, using Bellman’s equation, the Q value can be expressed purely in terms of the rewards and the value function:

$$Q(s_t, a_t) = \mathbb{E}\left[r_{t+1} + \gamma V(s_{t+1})\right]$$

Therefore, the advantage can now be estimated as:

$$A(s_t, a_t) = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)$$
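The one-step estimate above is simple enough to sketch directly. The function name `one_step_advantage` and the numbers are made up for illustration, and the `done` flag is an added assumption: when the step ends the episode there is no next state, so the bootstrapped term is dropped.

```python
def one_step_advantage(reward, value_next, value_now, gamma, done=False):
    """Estimate A(s_t, a_t) = r_{t+1} + gamma * V(s_{t+1}) - V(s_t)."""
    # no future value is available once the episode has terminated
    bootstrap = 0.0 if done else gamma * value_next
    return reward + bootstrap - value_now


# the critic expected 9.0 from this state; we saw a reward of 1.0
# and the critic values the next state at 10.0
print(one_step_advantage(1.0, value_next=10.0, value_now=9.0, gamma=0.5))  # -3.0
print(one_step_advantage(1.0, value_next=10.0, value_now=9.0, gamma=0.5, done=True))  # -8.0
```

A negative advantage means the action turned out worse than the critic expected from that state, so the policy update makes it less likely.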

As can be seen from the above, there is a requirement to be able to estimate the value function V. We could estimate it by running our agents through full episodes, in the same way we did in the policy gradient method. However, it would be better to be able to just collect batches of game-steps and train whenever the batch buffer was full, rather than having to wait for an episode to finish. That way, the agent could actually learn “on-the-go” during the middle of an episode/game.

So, do we build another neural network to estimate V? We could have two networks, one to learn the policy and produce actions, and another to estimate the state values. A more efficient solution is to create one network, but with two output channels, and this is how the A2C method is outworked. The figure below shows the network architecture for an A2C neural network:

A2C architecture

This architecture is based on an A2C method that takes game images as the state input, hence the convolutional neural network layers at the beginning of the network (for more on CNNs, see my post here). This network architecture also resembles the Dueling Q network architecture (see my Dueling Q post). The point to note about the architecture above is that most of the network is shared, with a late bifurcation between the policy part and the value part. The outputs $P(s, a_i)$ are the action probabilities of the policy (generated from the neural network) – $P(a_t|s_t)$. The other output channel is the value estimation – a scalar output which is the predicted value of state s – $V(s)$. The two dense channels disambiguate the policy and the value outputs from the front-end of the neural network.

In this example, we’ll just be demonstrating the A2C algorithm on the Cartpole OpenAI Gym environment which doesn’t require a visual state input (i.e. a set of pixels as the input to the NN), and therefore the two output channels will simply share some dense layers, rather than a series of CNN layers.

The A2C loss functions

There are actually three loss values that need to be calculated in the A2C algorithm. Each of these losses is in practice given a weighting, and then they are summed together (with the entropy loss having a negative sign, see below).

The Critic loss

The loss function of the Critic i.e. the value estimating output of the neural network $V(s)$, needs to be trained so that it predicts more and more closely the actual value of the state. As shown before, the value of a state is calculated as:

$$V^{\pi}(s) = \mathbb{E} \left[\sum_{i=1}^T \gamma^{i-1}r_{i}\right]$$

So $V^{\pi}(s)$ is the expected value of the discounted future rewards obtained by outworking a trajectory through the game based on a certain operating policy $\pi$. We can therefore compare the predicted $V(s)$ at each state in the game, and the actual sampled discounted rewards that were gathered, and the difference between the two is the Critic loss. In this example, we’ll use a mean squared error function as the loss function, between the discounted rewards and the predicted values ($(V(s) - DR)^2$).

Now, given that, under the A2C algorithm, we collect state, action and reward tuples until a batch buffer is filled, how are we meant to figure out this discounted rewards sum? Let’s say we progress 3 states through a game, and we collect:

$(V(s_0), r_0), (V(s_1), r_1), (V(s_2), r_2)$

For the first Critic loss, we could calculate it as:

$$MSE(V(s_0), r_0 + \gamma r_1 + \gamma^2 r_2)$$

But that is missing all the following rewards $r_3, r_4, \ldots, r_n$ until the game terminates. We didn’t have this problem in the Policy Gradient method, because in that method, we made sure a full run through the game had completed before training the neural network. In the A2C method, we use a trick called bootstrapping. To replace all the discounted $r_3, r_4, \ldots, r_n$ values, we get the network to estimate the value for state 3, $V(s_3)$, and this will be an estimate for all the discounted future rewards beyond that point in the game. So, for the first Critic loss, we would have:

$$MSE(V(s_0), r_0 + \gamma r_1 + \gamma^2 r_2 + \gamma^3 V(s_3))$$

Where $V(s_3)$ is a bootstrapped estimate of the value of the next state $s_3$.
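The bootstrapped target can be built by working backwards from the value estimate. The following is a minimal sketch in plain Python; the name `bootstrapped_target` and the example numbers are assumptions made for illustration.

```python
def bootstrapped_target(rewards, bootstrap_value, gamma):
    """Return r_0 + gamma*r_1 + ... + gamma^n * V(s_n) for n collected rewards."""
    target = bootstrap_value
    # fold the rewards in from last to first, discounting once per step
    for reward in reversed(rewards):
        target = reward + gamma * target
    return target


# r_0 = r_1 = r_2 = 1, V(s_3) = 0.5, gamma = 0.5
target = bootstrapped_target([1.0, 1.0, 1.0], bootstrap_value=0.5, gamma=0.5)
print(target)  # 1.8125 = 1 + 0.5 + 0.25 + 0.125 * 0.5
```

The squared difference between this target and the predicted $V(s_0)$ is then the critic loss for the first state in the batch.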

This will be explained more in the code-walkthrough to follow.

The Actor loss

The second loss function needs to train the Actor (i.e. the action policy). Recall that the advantage weighted policy loss is:

$$\nabla_\theta J(\theta) \sim \left(\sum_{t=0}^{T-1} \log P_{\pi_{\theta}}(a_t|s_t)\right)A(s_t, a_t)$$

Let’s start with the advantage – $A(s_t, a_t) = r_{t+1} + \gamma V(s_{t+1}) - V(s_t)$

This is simply the bootstrapped discounted rewards minus the predicted state values $V(s_t)$ that we gathered up while playing the game. So calculating the advantage is quite straight-forward, once we have the bootstrapped discounted rewards, as will be seen in the code walk-through shortly.

Now, with regard to the $\log P_{\pi_{\theta}}(a_t|s_t)$ term, in this instance we can just calculate the log of the softmax probability estimate for whatever action was taken. So, for instance, if in state 1 ($s_1$) the network softmax output produces {0.1, 0.9} (for a 2-action environment), and the second action was actually taken by the agent, we would want to calculate log(0.9). We can make use of the TensorFlow-Keras SparseCategoricalCrossentropy calculation, which takes the action as an integer, and this specifies which softmax output value to apply the log to. So in this example, y_true = [1] and y_pred = [0.1, 0.9] and the answer would be -log(0.9) = 0.105.

Another handy feature of the SparseCategoricalCrossentropy loss in Keras is that it can be called with a “sample_weight” argument. This basically multiplies the log calculation by a value. So, in this example, we can supply the advantages as the sample weights, and it will calculate $\nabla_\theta J(\theta) \sim \left(\sum_{t=0}^{T-1} \log P_{\pi_{\theta}}(a_t|s_t)\right)A(s_t, a_t)$ for us. This will be shown below, but the call will look like:

policy_loss = sparse_ce(actions, policy_logits, sample_weight=advantages)
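To see what that call computes, here is a plain-Python sketch of the same quantity: the sum over the batch of $-\log P(a_t|s_t)$ weighted by the advantage. It is not the Keras implementation; the helper names `softmax` and `weighted_policy_loss` are made up for illustration.

```python
import math


def softmax(logits):
    """Convert raw logits into probabilities (shifted by the max for stability)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_policy_loss(actions, logits_batch, advantages):
    """Sum over the batch of -log P(a_t | s_t) * A(s_t, a_t)."""
    loss = 0.0
    for action, logits, advantage in zip(actions, logits_batch, advantages):
        probs = softmax(logits)
        loss += -math.log(probs[action]) * advantage
    return loss


# logits chosen so that the softmax output is [0.1, 0.9]
logits = [math.log(0.1), math.log(0.9)]
unweighted = weighted_policy_loss([1], [logits], [1.0])
doubled = weighted_policy_loss([1], [logits], [2.0])
print(round(unweighted, 3), round(doubled, 3))  # 0.105 0.211
```

With an advantage of 1.0 the result matches the -log(0.9) = 0.105 example above, and scaling the advantage scales the contribution of that step by the same factor.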

Entropy loss

In many implementations of the A2C algorithm, another loss term is subtracted – the entropy loss. Entropy is a measure, broadly speaking, of randomness. The higher the entropy, the more random the state of affairs; the lower the entropy, the more ordered the state of affairs. In the case of A2C, entropy is calculated on the softmax policy action ($P(a_t|s_t)$) output of the neural network. Let’s go back to our two action example from above. In the case of a probability output of {0.1, 0.9} for the two possible actions, this is an ordered, less-random selection of actions. In other words, there will be a consistent selection of action 2, and only rarely will action 1 be taken. The entropy formula is:

$$E = -\sum p(x) \log(p(x))$$

So in this case, the entropy of that output would be 0.325. However, if the probability output was instead {0.5, 0.5}, the entropy would be 0.693. The 50-50 action probability distribution will produce more random actions, and therefore the entropy is higher.
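These two entropy values can be reproduced with the formula above using the natural logarithm. A minimal sketch, with `entropy` as an illustrative helper name:

```python
import math


def entropy(probs):
    """Shannon entropy in nats: -sum(p * log(p)), skipping zero-probability terms."""
    return -sum(p * math.log(p) for p in probs if p > 0)


print(round(entropy([0.1, 0.9]), 3))  # 0.325
print(round(entropy([0.5, 0.5]), 3))  # 0.693
```

For two actions, 0.693 (that is, log 2) is the maximum possible entropy, reached only by the 50-50 distribution.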

By subtracting the entropy calculation from the total loss (or giving the entropy loss a negative sign), we encourage more randomness and therefore more exploration. The A2C algorithm can have a tendency to converge on particular actions, so this subtraction of the entropy encourages a better exploration of alternative actions, though making the weighting on this component of the loss too large can also reduce training performance.

Again, we can use an already existing Keras loss function to calculate the entropy. The Keras categorical cross-entropy performs the following calculation:

Keras output of cross-entropy loss function

If we just pass in the probability outputs as both target and output to this function, then it will calculate the entropy for us. This will be shown in the code below.
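The reason this trick works is that the categorical cross-entropy between a target distribution $p$ and an output distribution $q$ is $-\sum p(x) \log(q(x))$, which reduces to the entropy when $q = p$. The following plain-Python sketch illustrates this; it is not the Keras function, and `categorical_cross_entropy` is a hypothetical helper written for this example.

```python
import math


def categorical_cross_entropy(target, output):
    """-sum(target * log(output)) over the classes, for a single sample."""
    return -sum(t * math.log(o) for t, o in zip(target, output) if t > 0)


probs = [0.1, 0.9]
# passing the same probabilities as both target and output gives the entropy
self_cross_entropy = categorical_cross_entropy(probs, probs)
print(round(self_cross_entropy, 3))  # 0.325
```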

The total loss

The total loss function for the A2C algorithm is:

Loss = Actor Loss + Critic Loss * CRITIC_WEIGHT – Entropy Loss * ENTROPY_WEIGHT

A common value for the critic weight is 0.5, and the entropy weight is usually quite low (i.e. on the order of 0.01-0.001), though these hyperparameters can be adjusted and experimented with depending on the environment and network.
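As a quick illustration of how the three terms combine, the sketch below plugs made-up loss values into the formula, using a critic weight of 0.5 and an entropy weight of 0.01 from the typical ranges mentioned above. The function name `total_loss` and the input numbers are assumptions for illustration only.

```python
CRITIC_WEIGHT = 0.5    # typical value mentioned in the text
ENTROPY_WEIGHT = 0.01  # within the typical 0.01-0.001 range


def total_loss(actor_loss, critic_loss, entropy):
    """Actor loss plus weighted critic loss, minus weighted entropy."""
    return actor_loss + critic_loss * CRITIC_WEIGHT - entropy * ENTROPY_WEIGHT


# illustrative values only; real values come from the network outputs
total = total_loss(actor_loss=0.2, critic_loss=1.0, entropy=0.5)
print(round(total, 3))  # 0.695
```

Note that a higher entropy lowers the total loss, which is how the optimizer is nudged towards keeping the policy exploratory.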

Implementing A2C in TensorFlow 2

In the following section, I will provide a walk-through of some code to implement the A2C methodology in TensorFlow 2. The code for this can be found at this site’s Github repository, in the file.

First, we perform the usual imports, set some constants, initialize the environment and finally create the neural network model which instantiates the A2C architecture:

import tensorflow as tf
from tensorflow import keras
import numpy as np
import gym
import datetime as dt

STORE_PATH = '/Users/andrewthomas/Adventures in ML/TensorFlowBook/TensorBoard/A2CCartPole'
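GAMMA = 0.95
# loss weights, as described in the text below
CRITIC_LOSS_WEIGHT = 0.5
ACTOR_LOSS_WEIGHT = 1.0
ENTROPY_LOSS_WEIGHT = 0.05
BATCH_SIZE = 64  # assumed value; the batch size is not stated in the text

env = gym.make("CartPole-v0")
state_size = 4
num_actions = env.action_space.n

class Model(keras.Model):
    def __init__(self, num_actions):
        # keras.Model must be initialized before layers are assigned to it
        super().__init__()
        self.num_actions = num_actions
        # two dense layers shared by both output heads
        self.dense1 = keras.layers.Dense(64, activation='relu')
        self.dense2 = keras.layers.Dense(64, activation='relu')
        # critic head: scalar estimate of V(s)
        self.value = keras.layers.Dense(1)
        # actor head: one logit per action (softmax is applied later)
        self.policy_logits = keras.layers.Dense(num_actions)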

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.value(x), self.policy_logits(x)

    def action_value(self, state):
        value, logits = self.predict_on_batch(state)
        action = tf.random.categorical(logits, 1)[0]
        return action, value

For this example I have set the critic, actor and entropy loss weights to 0.5, 1.0 and 0.05 respectively. Next the environment is set up, and then the model class is created.

This class inherits from keras.Model, which enables it to be integrated into the streamlined Keras methods of training and evaluating (for more information, see this Keras tutorial). In the initialization of the class, we see that 2 dense layers have been created, with 64 nodes in each. Then a value layer with one output is created, which evaluates $V(s)$, and finally the policy layer output with a size equal to the number of available actions. Note that this layer produces logits only; the softmax function which creates pseudo-probabilities ($P(a_t|s_t)$) will be applied within the various TensorFlow functions, as will be seen.

Next, the call function is defined – this function is run whenever a state needs to be “run” through the model, to produce a value and policy logits output. The Keras model API will use this function in its predict functions and also its training functions. In this function, it can be observed that the input is passed through the two common dense layers, and then the function returns first the value output, then the policy logits output.

The next function is the action_value function. This function is called upon when an action needs to be chosen from the model. As can be seen, the first step of the function is to run the predict_on_batch Keras model API function, which simply runs the call function defined above. The output is both the values and the policy logits. An action is then selected by randomly sampling according to the action probabilities. Note that tf.random.categorical takes logits as input, not softmax outputs. The next function, outside of the Model class, is the function that calculates the critic loss:

def critic_loss(discounted_rewards, predicted_values):
    return keras.losses.mean_squared_error(discounted_rewards, predicted_values) * CRITIC_LOSS_WEIGHT

As explained above, the critic loss consists of the mean squared error between the discounted rewards (which are calculated in another function, soon to be discussed) and the values predicted from the value output of the model (which are accumulated in a list during the agent’s trajectory through the game).

The following function shows the actor loss function:

def actor_loss(combined, policy_logits):
    actions = combined[:, 0]
    advantages = combined[:, 1]
    sparse_ce = keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.SUM
    )

    actions = tf.cast(actions, tf.int32)
    policy_loss = sparse_ce(actions, policy_logits, sample_weight=advantages)

    probs = tf.nn.softmax(policy_logits)
    entropy_loss = keras.losses.categorical_crossentropy(probs, probs)

    return policy_loss * ACTOR_LOSS_WEIGHT - entropy_loss * ENTROPY_LOSS_WEIGHT

The first argument to the actor_loss function is an array with two columns (and BATCH_SIZE rows). The first column corresponds to the recorded actions of the agent as it traversed the game. The second column is the calculated advantages – the calculation of which will be shown shortly. Next, the sparse categorical cross-entropy function class is created. The arguments specify that the input to the function is logits (i.e. they don’t have softmax applied to them yet), and it also specifies the reduction to apply to the BATCH_SIZE number of calculated losses – in this case, a sum() function which aligns with the summation in:

$$\nabla_\theta J(\theta) \sim \left(\sum_{t=0}^{T-1} \log P_{\pi_{\theta}}(a_t|s_t)\right)A(s_t, a_t)$$

Next, the actions are cast to be integers (rather than floats) and finally, the policy loss is calculated based on the sparse_ce function. As discussed above, the sparse categorical cross-entropy function will select those policy probabilities that correspond to the actions actually taken in the game, and weight them by the advantage values. By applying a summation reduction, the formula above will be implemented in this function.

Next, the actual probabilities for action are estimated by applying the softmax function to the logits, and the entropy loss is calculated by applying the categorical cross-entropy function. See the previous discussion on how this works.

The following function calculates the discounted reward values and the advantages:

def discounted_rewards_advantages(rewards, dones, values, next_value):
    discounted_rewards = np.array(rewards + [next_value[0]])

    for t in reversed(range(len(rewards))):
        discounted_rewards[t] = rewards[t] + GAMMA * discounted_rewards[t+1] * (1-dones[t])
    discounted_rewards = discounted_rewards[:-1]
    # advantages are bootstrapped discounted rewards - values, using Bellman's equation
    advantages = discounted_rewards - np.stack(values)[:, 0]
    return discounted_rewards, advantages

The first input value to this function is a list of all the rewards that were accumulated during the agent’s traversal of the game. The next is a list whose elements are either 1 or 0, depending on whether or not the game or episode ended at each time step. The values argument is a list of all values $V(s)$ generated by the model at each time step. Finally, the next_value argument is the bootstrapped estimate of the value for the state $s_{t+1}$ – in other words, it is the estimated value of all the discounted rewards “downstream” of the last state recorded in the lists. Further discussion on bootstrapping was given in a previous section.

On the first line of the function, a numpy array is created out of the list of rewards, with the bootstrapped next_value appended to it. A reversed loop is then entered into. To explain how this loop works, it is perhaps best to give a simple example. For every time-step in the Cartpole environment, if the stick hasn’t fallen past horizontal, a reward of 1 is awarded. So let’s consider a small batch of samples of only 3 time steps. Let’s also say that the bootstrapped next_value estimate is 0.5. Therefore, at this point, the discounted rewards array looks like the following: [1, 1, 1, 0.5].

This is what the discounted_rewards array looks like at each step in the loop:

t = 2 — discounted_rewards[2] = 1 + GAMMA * 0.5

t = 1 — discounted_rewards[1] = 1 + GAMMA(1 + GAMMA * 0.5) = 1 + GAMMA + GAMMA^2 * 0.5

t = 0 — discounted_rewards[0] = 1 + GAMMA(1 + GAMMA + GAMMA^2 * 0.5) = 1 + GAMMA + GAMMA^2 + GAMMA^3 * 0.5

As can be observed, this loop correctly generates the downstream discounted rewards values for each step in the batch. If the game finishes in one of these time-steps, the (1-dones[t]) factor resets the accumulation of discounted future rewards, so that rewards from a subsequent game won’t flow into the previous game that just ended.
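The effect of the done flags is easiest to see by running the same backward loop on the three-step example, once without and once with an episode ending in the middle of the batch. This is a plain-Python sketch of the loop shown earlier (lists instead of numpy arrays); the function name `discounted_rewards_with_resets` is made up for illustration.

```python
GAMMA = 0.95


def discounted_rewards_with_resets(rewards, dones, next_value, gamma=GAMMA):
    """Backward pass over a batch; a done flag cuts off rewards from later steps."""
    out = list(rewards) + [next_value]
    for t in reversed(range(len(rewards))):
        out[t] = rewards[t] + gamma * out[t + 1] * (1 - dones[t])
    # drop the bootstrapped value that was appended at the end
    return out[:-1]


no_reset = discounted_rewards_with_resets([1, 1, 1], [0, 0, 0], next_value=0.5)
with_reset = discounted_rewards_with_resets([1, 1, 1], [0, 1, 0], next_value=0.5)
print([round(x, 4) for x in no_reset])    # [3.2812, 2.4012, 1.475]
print([round(x, 4) for x in with_reset])  # [1.95, 1.0, 1.475]
```

In the second call the episode ends at t = 1, so that step keeps only its own reward, and the step before it discounts only that single reward rather than anything from the next game.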

Because discounted_rewards[3] actually equals the bootstrapped next_value, it doesn’t apply to calculating the advantage, so the next line in the code simply restricts the scope of the discounted_rewards array so that this next_value is excluded.

Next, the advantages are calculated, simply by subtracting the estimated values from the discounted_rewards.

The following lines of code create a model instance, compile the model, and set up a TensorBoard writer for visualization purposes.

model = Model(num_actions)
model.compile(optimizer=keras.optimizers.Adam(), loss=[critic_loss, actor_loss])

train_writer = tf.summary.create_file_writer(STORE_PATH + f"/A2C-CartPole_{dt.datetime.now().strftime('%d%m%Y%H%M')}")

Note that in the model compilation function, the loss function specified is a compound of the critic and actor loss (with the actor loss featuring the entropy impact, as shown above).

The code below shows the main training loop:

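num_steps = 10000000
episode_reward_sum = 0
state = env.reset()
episode = 1
loss = None  # no training step has run yet when the first episode finishes
for step in range(num_steps):
    rewards = []
    actions = []
    values = []
    states = []
    dones = []
    for _ in range(BATCH_SIZE):
        _, policy_logits = model(state.reshape(1, -1))

        action, value  = model.action_value(state.reshape(1, -1))
        new_state, reward, done, _ = env.step(action.numpy()[0])

        # record this time step in the batch lists
        actions.append(action.numpy()[0])
        values.append(value.numpy()[0])
        states.append(state)
        dones.append(done)
        episode_reward_sum += reward

        state = new_state
        if done:
            rewards.append(0.0)  # assumption: the terminal step stores a zero reward
            state = env.reset()
            print(f"Episode: {episode}, latest episode reward: {episode_reward_sum}, loss: {loss}")
            with train_writer.as_default():
                tf.summary.scalar('rewards', episode_reward_sum, episode)
            episode_reward_sum = 0
            episode += 1
        else:
            rewards.append(reward)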

    _, next_value = model.action_value(state.reshape(1, -1))
    discounted_rewards, advantages = discounted_rewards_advantages(rewards, dones, values, next_value.numpy()[0])

    # combine the actions and advantages into a combined array for passing to
    # actor_loss function
    combined = np.zeros((len(actions), 2))
    combined[:, 0] = actions
    combined[:, 1] = advantages

    loss = model.train_on_batch(tf.stack(states), [discounted_rewards, combined])

    with train_writer.as_default():
        tf.summary.scalar('tot_loss', np.sum(loss), step)

At the beginning of each “step” or batch number, all of the lists (rewards, actions, values, states, dones) are emptied. A secondary loop is then entered into, which accumulates all of these lists. Within this inner loop, the action logits are generated from the model, and the actual action to be taken (action variable) and the state value (value variable) are retrieved from the model.action_value function. The action is then fed into the environment so that a step can be taken. This generates a new state, the reward for taking that action, and the done flag – signifying if that action ended the game. All of these values are then appended to the various lists, and the episode reward accumulator is added to.

If the episode is done, the environment is reset and the total episode rewards are stored in the TensorBoard writer. If not, the reward is simply stored in the list.

After BATCH_SIZE time steps have been stored in the lists, the inner loop is exited and it is time to train the model. The next_value bootstrapped value estimate is generated (recall the state variable has been updated to the new_state variable i.e. the next state in the game), and the discounted rewards and advantages are calculated. Next, a combined array is created and populated column-wise by the actions and advantages. The discounted_rewards and combined variables are then passed to the model.train_on_batch function, which will, in turn, automatically feed them into the critic and actor loss functions, respectively (along with the outputs from the model’s call function – the value estimate and the policy logits).

The loss is returned and, finally, it is logged to TensorBoard as well.

The outcome of training the Cartpole environment for 200 episodes can be seen in the graph below:

A2C training progress on Cartpole environment

That’s the end of this tutorial on the powerful A2C reinforcement learning algorithm, and how to implement it in TensorFlow 2. In a future post, I will demonstrate how to apply this technique to a more challenging Atari game environment, making use of convolutional neural network layers and the actual game screen pixels.

I hope this was useful for you – all the best.

The post A2C Advantage Actor Critic in TensorFlow 2 appeared first on Adventures in Machine Learning.


Learning to Reason Over Tables from Less Data

The task of recognizing textual entailment, also known as natural language inference, consists of determining whether a piece of text (the hypothesis) can be implied or contradicted (or neither) by another piece of text (the premise). While this problem is often considered an important test for the reasoning skills of machine learning (ML) systems and has been studied in depth for plain text inputs, much less effort has been put into applying such models to structured data, such as websites, tables, databases, etc. Yet, recognizing textual entailment is especially relevant whenever the contents of a table need to be accurately summarized and presented to a user, and is essential for high fidelity question answering systems and virtual assistants.

In “Understanding tables with intermediate pre-training”, published in Findings of EMNLP 2020, we introduce the first pre-training tasks customized for table parsing, enabling models to learn better, faster and from less data. We build upon our earlier TAPAS model, which was an extension of the BERT bi-directional Transformer model with special embeddings to find answers in tables. Applying our new pre-training objectives to TAPAS yields a new state of the art on multiple datasets involving tables. On TabFact, for example, it reduces the gap between model and human performance by ~50%. We also systematically benchmark methods of selecting relevant input for higher efficiency, achieving 4x gains in speed and memory, while retaining 92% of the results. All the models for different tasks and sizes are released on our GitHub repo, where you can try them out yourself in a Colab notebook.

Textual Entailment
The task of textual entailment is more challenging when applied to tabular data than plain text. Consider, for example, a table from Wikipedia with some sentences derived from its associated table content. Assessing if the content of the table entails or contradicts the sentence may require looking over multiple columns and rows, and possibly performing simple numeric computations, like averaging, summing, differencing, etc.

A table together with some statements from TabFact. The content of the table can be used to support or contradict the statements.

Following the methods used by TAPAS, we encode the content of a statement and a table together, pass them through a Transformer model, and obtain a single number with the probability that the statement is entailed or refuted by the table.

The TAPAS model architecture uses a BERT model to encode the statement and the flattened table, read row by row. Special embeddings are used to encode the table structure. The vector output of the first token is used to predict the probability of entailment.

Because the only information in the training examples is a binary value (i.e., “correct” or “incorrect”), training a model to understand whether a statement is entailed or not is challenging and highlights the difficulty in achieving generalization in deep learning, especially when the provided training signal is scarce. Seeing isolated entailed or refuted examples, a model can easily pick-up on spurious patterns in the data to make a prediction, for example the presence of the word “tie” in “Greg Norman and Billy Mayfair tie in rank”, instead of truly comparing their ranks, which is what is needed to successfully apply the model beyond the original training data.

Pre-training Tasks
Pre-training tasks can be used to “warm-up” models by providing them with large amounts of readily available unlabeled data. However, pre-training typically includes primarily plain text and not tabular data. In fact, TAPAS was originally pre-trained using a simple masked language modelling objective that was not designed for tabular data applications. In order to improve the model performance on tabular data, we introduce two novel pretraining binary-classification tasks called counterfactual and synthetic, which can be applied as a second stage of pre-training (often called intermediate pre-training).

In the counterfactual task, we source sentences from Wikipedia that mention an entity (person, place or thing) that also appears in a given table. Then, 50% of the time, we modify the statement by swapping the entity for another alternative. To make sure the statement is realistic, we choose a replacement among the entities in the same column in the table. The model is trained to recognize whether the statement was modified or not. This pre-training task includes millions of such examples, and although the reasoning about them is not complex, they typically will still sound natural.
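The entity-swapping idea can be illustrated with a short sketch. This is not the released implementation: the function name `make_counterfactual`, the example sentence and the placeholder entity "Player C" are all assumptions made for illustration, and real entity matching would be more involved than a string replacement.

```python
import random


def make_counterfactual(statement, entity, column_values, rng):
    """Return (text, label); label is 1 when the entity was swapped, else 0."""
    # candidate replacements come from the same table column as the entity
    alternatives = [value for value in column_values if value != entity]
    if alternatives and rng.random() < 0.5:
        replacement = rng.choice(alternatives)
        return statement.replace(entity, replacement), 1
    return statement, 0


rng = random.Random(0)
column = ["Greg Norman", "Billy Mayfair", "Player C"]
original = "Greg Norman finished in first place."
examples = [make_counterfactual(original, "Greg Norman", column, rng) for _ in range(6)]
for text, label in examples:
    print(label, text)
```

Drawing the replacement from the same column keeps the modified sentence plausible, so the model has to consult the table rather than rely on surface cues.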

For the synthetic task, we follow a method similar to semantic parsing in which we generate statements using a simple set of grammar rules that require the model to understand basic mathematical operations, such as sums and averages (e.g., “the sum of earnings”), or to understand how to filter the elements in the table using some condition (e.g., “the country is Australia”). Although these statements are artificial, they help improve the numerical and logical reasoning skills of the model.
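As a rough illustration of the synthetic task, the sketch below fills a single hand-written template with a filter and a sum, and derives the label by evaluating the statement against the table. The grammar used in the paper is not reproduced here; the function name `make_synthetic_statement`, the template wording and the table values are all hypothetical.

```python
def make_synthetic_statement(rows, filter_column, filter_value, value_column, threshold):
    """Build a templated statement and compute its truth value from the table."""
    # filter the rows by the condition, then aggregate the selected cells
    selected = [row[value_column] for row in rows if row[filter_column] == filter_value]
    total = sum(selected)
    text = (f"the sum of {value_column} when the {filter_column} is "
            f"{filter_value} is greater than {threshold}")
    return text, total > threshold


rows = [
    {"country": "Australia", "earnings": 300},
    {"country": "Australia", "earnings": 200},
    {"country": "Spain", "earnings": 900},
]
statement, is_entailed = make_synthetic_statement(rows, "country", "Australia", "earnings", 400)
print(statement, is_entailed)
```

Because the label is computed from the table itself, any number of such examples can be generated without manual annotation.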

Example instances for the two novel pre-training tasks. Counterfactual examples swap entities mentioned in a sentence that accompanies the input table for a plausible alternative. Synthetic statements use grammar rules to create new sentences that require combining the information of the table in complex ways.

We evaluate the success of the counterfactual and synthetic pre-training objectives on the TabFact dataset by comparing to the baseline TAPAS model and to two prior models that have exhibited success in the textual entailment domain, LogicalFactChecker (LFC) and Structure Aware Transformer (SAT). The baseline TAPAS model exhibits improved performance relative to LFC and SAT, but the pre-trained model (TAPAS+CS) performs significantly better, achieving a new state of the art.

We also apply TAPAS+CS to question answering tasks on the SQA dataset, which requires that the model find answers from the content of tables in a dialog setting. The inclusion of CS objectives improves the previous best performance by more than 4 points, demonstrating that this approach also generalizes performance beyond just textual entailment.

Results on TabFact (left) and SQA (right). Using the synthetic and counterfactual datasets, we achieve new state-of-the-art results in both tasks by a large margin.

Data and Compute Efficiency
Another aspect of the counterfactual and synthetic pre-training tasks is that since the models are already tuned for binary classification, they can be applied without any fine-tuning to TabFact. We explore what happens to each of the models when trained only on a subset (or even none) of the data. Without looking at a single example, the TAPAS+CS model is competitive with a strong baseline Table-Bert, and when only 10% of the data are included, the results are comparable to the previous state-of-the-art.

Dev accuracy on TabFact relative to the fraction of the training data used.

A general concern when trying to use large models such as this to operate on tables is that their high computational requirements make it difficult for them to parse very large tables. To address this, we investigate whether one can heuristically select subsets of the input to pass through the model in order to optimize its computational efficiency.

We conducted a systematic study of different approaches to filter the input and discovered that simple methods that select for word overlap between a full column and the subject statement give the best results. By dynamically selecting which tokens of the input to include, we can use fewer resources or work on larger inputs at the same cost. The challenge is doing so without losing important information and hurting accuracy. 
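A word-overlap heuristic of this kind can be sketched in a few lines. The code below is only an illustration of the general idea, not the method as released: the function name `select_columns`, the whitespace tokenization and the toy table are assumptions.

```python
def select_columns(statement, table, max_columns):
    """Keep the columns whose header and cells share the most words with the statement."""
    statement_words = set(statement.lower().split())

    def overlap(column_name):
        cells = [column_name] + table[column_name]
        column_words = set(" ".join(str(cell) for cell in cells).lower().split())
        return len(statement_words & column_words)

    ranked = sorted(table, key=overlap, reverse=True)
    keep = set(ranked[:max_columns])
    # preserve the original column order of the table
    return [name for name in table if name in keep]


table = {
    "player": ["greg norman", "billy mayfair"],
    "rank": ["1", "2"],
    "country": ["australia", "united states"],
}
selected = select_columns("greg norman and billy mayfair tie in rank", table, max_columns=2)
print(selected)  # ['player', 'rank']
```

Dropping the lowest-scoring columns shortens the flattened table, which is what allows a shorter input sequence to be used.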

For instance, the models discussed above all use sequences of 512 tokens, which is around the normal limit for a transformer model (although recent efficiency methods like the Reformer or Performer are proving effective in scaling the input size). The column selection methods we propose here can allow for faster training while still achieving high accuracy on TabFact. For 256 input tokens we get a very small drop in accuracy, but the model can now be pre-trained, fine-tuned and make predictions up to two times faster. With 128 tokens the model still outperforms the previous state-of-the-art model, with an even more significant speed-up — 4x faster across the board.

Accuracy on TabFact using different sequence lengths, by shortening the input with our column selection method.

Using both the column selection method we proposed and the novel pre-training tasks, we can create table parsing models that need less data and less compute power to obtain better results.

We have made the new models and pre-training techniques available at our GitHub repo, where you can try them out yourself in Colab. In order to make this approach more accessible, we also shared models of varying sizes all the way down to “tiny”. It is our hope that these results will help spur development of table reasoning among the broader research community.

This work was carried out by Julian Martin Eisenschlos, Syrine Krichene and Thomas Müller from our Language Team in Zürich. We would like to thank Jordan Boyd-Graber, Yasemin Altun, Emily Pitler, Benjamin Boerschinger, Srini Narayanan, Slav Petrov, William Cohen and Jonathan Herzig for their useful comments and suggestions.


Long-Term Stock Forecasting


Forecasting Dividends


Improving Mobile App Accessibility with Icon Detection

Voice Access enables users to control their Android device hands free, using only verbal commands. In order to function properly, it needs on-screen user interface (UI) elements to have reliable accessibility labels, which are provided to the operating system’s accessibility services via the accessibility tree. Unfortunately, in many apps, adequate labels aren’t always available for UI elements, e.g. images and icons, reducing the usability of Voice Access.

The Voice Access app extracts elements from the view hierarchy to localize and annotate various UI elements. It can provide a precise description for elements that have an explicit content description. On the other hand, the absence of content description can result in many unrecognized elements undermining the ability of Voice Access to function with some apps.

Addressing this challenge requires a system that can automatically detect icons using only the pixel values displayed on the screen, regardless of whether icons have been given suitable accessibility labels. What little research exists on this topic typically uses classifiers, sometimes combined with language models to infer classes and attributes from UI elements. However, these classifiers still rely on the accessibility tree to obtain bounding boxes for UI elements, and fail when appropriate labels do not exist.

Here, we describe IconNet, a vision-based object detection model that can automatically detect icons on the screen in a manner that is agnostic to the underlying structure of the app being used, launched as part of the latest version of Voice Access. IconNet can detect 31 different icon types (to be extended to more than 70 types soon) based on UI screenshots alone. IconNet is optimized to run on-device for mobile environments, with a compact size and fast inference time to enable a seamless user experience. The current IconNet model achieves a mean average precision (mAP) of 94.2% running at 9 FPS on a Pixel 3A.

Voice Access 5.0: the icons detected by IconNet can now be referred to by their names.

Detecting Icons in Screenshots
From a technical perspective, the problem of detecting icons on app screens is similar to classical object detection, in that individual elements are labelled by the model with their locations and sizes. But, in other ways, it’s quite different. Icons are typically small objects, with relatively basic geometric shapes and a limited range of colors, and app screens widely differ from natural images in that they are more structured and geometrical.

A significant challenge in the development of an on-device UI element detector for Voice Access is that it must be able to run on a wide variety of phones with a range of performance capabilities, while preserving the user’s privacy. For a fast user experience, a lightweight model with low inference latency is needed. Because Voice Access needs to use the labels in response to an utterance from a user (e.g., “tap camera”, or “show labels”), inference time needs to be short (<150 ms on a Pixel 3A) with a model size less than 10 MB.

IconNet is based on the novel CenterNet architecture, which extracts features from input images and then predicts appropriate bounding box centers and sizes (in the form of heatmaps). CenterNet is particularly suited here because UI elements consist of simple, symmetric geometric shapes, making it easier to identify their centers than for natural images. The total loss used is a combination of a standard L1 loss for the icon sizes and a modified CornerNet Focal loss for the center predictions, the latter of which addresses icon class imbalances between commonly occurring icons (e.g., arrow backward, menu, more, and star) and underrepresented icons (end call, delete, launch apps, etc.).
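To make the shape of such a loss concrete, here is a minimal sketch that combines the penalty-reduced focal loss in its published CornerNet form with an L1 size loss. It does not reproduce IconNet's modifications, which are not described here; the exponents `alpha=2` and `beta=4` and the `size_weight=0.1` factor are the commonly cited CornerNet/CenterNet defaults and should be read as assumptions.

```python
import math


def focal_center_loss(pred, target, alpha=2.0, beta=4.0, eps=1e-7):
    """Penalty-reduced focal loss over a flattened center heatmap.

    `target` is 1.0 at true centers and a smaller value (for example a
    Gaussian falloff) near them; `pred` holds predicted scores in [0, 1].
    """
    loss = 0.0
    num_centers = 0
    for p, y in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)  # keep log() finite
        if y == 1.0:
            # Easy positives (p close to 1) are down-weighted.
            loss -= (1.0 - p) ** alpha * math.log(p)
            num_centers += 1
        else:
            # Negatives close to a true center are penalized less.
            loss -= (1.0 - y) ** beta * p ** alpha * math.log(1.0 - p)
    return loss / max(num_centers, 1)


def l1_size_loss(pred_sizes, true_sizes):
    """Mean absolute error over (width, height) pairs."""
    total = sum(
        abs(pw - tw) + abs(ph - th)
        for (pw, ph), (tw, th) in zip(pred_sizes, true_sizes)
    )
    return total / max(len(true_sizes), 1)


def total_loss(pred, target, pred_sizes, true_sizes, size_weight=0.1):
    """Weighted sum of the center loss and the size loss (weight assumed)."""
    return focal_center_loss(pred, target) + size_weight * l1_size_loss(
        pred_sizes, true_sizes
    )


good = focal_center_loss([0.9, 0.1], [1.0, 0.0])
bad = focal_center_loss([0.1, 0.9], [1.0, 0.0])
size_error = l1_size_loss([(10, 12)], [(8, 12)])
print(good < bad, size_error)
```

Because confident, correct predictions contribute very little, the gradient is dominated by hard examples, which is what helps with rare icon classes.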

After experimenting with several backbones (MobileNet, ResNet, UNet, etc.), we selected the most promising server-side architecture, Hourglass, as a starting point for designing a backbone tailored for icon and UI element detection. While this architecture is perfectly suitable for server-side models, vanilla Hourglass backbones are not an option for a model that will run on a mobile device, due to their large size and slow inference time. We restricted our on-device network design to a single stack, and drastically reduced the width of the backbone. Furthermore, as the detection of icons relies on more local features (compared to real objects), we could further reduce the depth of the backbone without adversely affecting the performance. Ablation studies convinced us of the importance of skip connections and high resolution features. For example, trimming skip connections in the final layer reduced the mAP by 1.5%, and removing such connections from both the final and penultimate layers resulted in a decline of 3.5% mAP.

IconNet analyzes the pixels of the screen and identifies the centers of icons by generating heatmaps, which provide precise information about the position and type of each icon present on the screen. This enables Voice Access users to refer to these elements by their name (e.g., “Tap ‘menu’”).
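One simple way to turn such a heatmap into icon centers is to keep every cell that exceeds a score threshold and is at least as large as its eight neighbors. The sketch below does this for a single icon class on a toy grid; the threshold and the 3x3 neighborhood are assumptions, not details of the deployed model.

```python
def find_peaks(heatmap, threshold=0.5):
    """Return (row, col, score) for local maxima at or above `threshold`."""
    rows, cols = len(heatmap), len(heatmap[0])
    peaks = []
    for r in range(rows):
        for c in range(cols):
            value = heatmap[r][c]
            if value < threshold:
                continue
            # Compare against the (up to) eight surrounding cells.
            neighbors = [
                heatmap[rr][cc]
                for rr in range(max(0, r - 1), min(rows, r + 2))
                for cc in range(max(0, c - 1), min(cols, c + 2))
                if (rr, cc) != (r, c)
            ]
            if all(value >= n for n in neighbors):
                peaks.append((r, c, value))
    return peaks


# Toy heatmap for one icon class, for example "menu".
menu_heatmap = [
    [0.1, 0.2, 0.1, 0.0],
    [0.2, 0.9, 0.2, 0.0],
    [0.1, 0.2, 0.1, 0.6],
    [0.0, 0.0, 0.3, 0.2],
]
peaks = find_peaks(menu_heatmap)
print(peaks)
```

With one heatmap per icon class, the class of a detection is the heatmap it was found in, and the predicted size at that cell gives the bounding box.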

Model Improvements
Once the backbone architecture was selected, we used neural architecture search (NAS) to explore variations on the network architecture and uncover an optimal set of training and model parameters that would balance model performance (mAP) with latency (FLOPs). Additionally, we used Fine-Grained Stochastic Architecture Search (FiGS) to further refine the backbone design. FiGS is a differentiable architecture search technique that uncovers sparse structures by pruning a candidate architecture and discarding unnecessary connections. This technique allowed us to reduce the model size by 20% without any loss in performance, and by 50% with only a minor drop of 0.3% in mAP.

Improving the quality of the training dataset also played an important role in boosting the model performance. We collected and labeled more than 700K screenshots, and in the process, we streamlined data collection by using heuristics and auxiliary models to identify rarer icons. We also took advantage of data augmentation techniques by enriching existing screenshots with infrequent icons.

To improve the inference time, we modified our model to run using Neural Networks API (NNAPI) on a variety of Qualcomm DSPs available on many mobile phones. For this we converted the model to use 8-bit integer quantization which gives the additional benefit of model size reduction. After some experimentation, we used quantization aware training to quantize the model, while matching the performance of a server-side floating point model. The quantized model results in a 6x speed-up (700ms vs 110ms) and 50% size reduction while losing only ~0.5% mAP compared to the unquantized model.
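For readers unfamiliar with 8-bit quantization, the sketch below shows the usual affine mapping between floating-point values and unsigned 8-bit integers using a scale and a zero point. It illustrates the number format only; quantization aware training, which simulates this rounding during training, is not shown, and the helper names are assumptions.

```python
def quantization_params(min_val, max_val, num_bits=8):
    """Compute the scale and zero point that map [min_val, max_val] to integers."""
    qmin, qmax = 0, 2 ** num_bits - 1
    # The representable range must contain 0.0 so that zero maps exactly.
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero tensor
    zero_point = int(round(qmin - min_val / scale))
    return scale, max(qmin, min(qmax, zero_point))


def quantize(x, scale, zero_point, num_bits=8):
    """Map a float to the nearest representable integer, clipped to range."""
    q = int(round(x / scale)) + zero_point
    return max(0, min(2 ** num_bits - 1, q))


def dequantize(q, scale, zero_point):
    """Map a quantized integer back to an approximate float."""
    return scale * (q - zero_point)


weights = [-1.0, -0.25, 0.0, 0.5, 1.0]
scale, zero_point = quantization_params(min(weights), max(weights))
quantized = [quantize(w, scale, zero_point) for w in weights]
restored = [dequantize(q, scale, zero_point) for q in quantized]
errors = [abs(w - r) for w, r in zip(weights, restored)]
print(quantized, max(errors))
```

Each weight is stored in one byte instead of four, and the round-trip error stays within one quantization step, which is why accuracy typically drops only slightly.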

We use traditional object detection metrics (e.g., mAP) to measure model performance. In addition, to better capture the use case of voice controlled user actions, we define a modified version of a false positive (FP) detection, where we apply a larger penalty to incorrect detections for icon classes that are present on the screen. For comparing detections with ground truth, we use the center in region of interest (CIROI), another metric we developed for this work, which returns a positive match when the center of the detected bounding box lies inside the ground truth bounding box. This better captures the Voice Access mode of operation, where actions are performed by tapping anywhere in the region of the UI element of interest.
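The CIROI matching rule is simple enough to state in a few lines. The sketch below assumes boxes are given as `(x_min, y_min, x_max, y_max)` tuples in pixels, which is a convention chosen for this example.

```python
def center_in_roi(detected_box, truth_box):
    """Return True when the detected box's center lies inside the truth box."""
    x_min, y_min, x_max, y_max = detected_box
    center_x = (x_min + x_max) / 2
    center_y = (y_min + y_max) / 2
    t_x_min, t_y_min, t_x_max, t_y_max = truth_box
    return t_x_min <= center_x <= t_x_max and t_y_min <= center_y <= t_y_max


truth = (15, 15, 40, 40)
hit = center_in_roi((10, 10, 30, 30), truth)   # center (20, 20) is inside
miss = center_in_roi((0, 0, 10, 10), truth)    # center (5, 5) is outside
print(hit, miss)
```

Unlike an intersection-over-union threshold, this rule accepts a loosely fitting box as long as a tap at its center would land on the intended element.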

We compared the IconNet model with various other mobile compatible object detectors, including MobileNetEdgeTPU and SSD MobileNet v2. Experiments showed that for a fixed latency, IconNet outperformed the other models in terms of mAP@CIROI on our internal evaluation set.

Model    mAP@CIROI
IconNet (Hourglass)    96%
IconNet (HRNet)    89%
MobilenetEdgeTPU (AutoML)    91%
SSD Mobilenet v2    88%

The performance advantage of IconNet persists when considering quantized models and models for a fixed latency budget.

Models (Quantized)    mAP@CIROI    Model size    Latency*
IconNet (Currently deployed)    94.20%    8.5 MB    107 ms
IconNet (XS)    92.80%    2.3 MB    102 ms
IconNet (S)    91.70%    4.4 MB    45 ms
MobilenetEdgeTPU (AutoML)    88.90%    7.8 MB    26 ms
*Measured on Pixel 3A.

Conclusion and Future Work
We are constantly working on improving IconNet. Among other things, we are interested in increasing the range of elements supported by IconNet to include any generic UI element, such as images, text, or buttons. We also plan to extend IconNet to differentiate between similar looking icons by identifying their functionality. On the application side, we are hoping to increase the number of apps with valid content descriptions by augmenting developer tools to suggest content descriptions for different UI elements when building applications.

This project is the result of joint work with Maria Wang, Tautvydas Misiūnas, Lijuan Liu, Ying Xu, Nevan Wichers, Xiaoxue Zang, Gabriel Schubiner, Abhinav Rastogi, Jindong (JD) Chen, Abhanshu Sharma, Pranav Khaitan, Matt Sharifi and Blaise Aguera y Arcas. We sincerely thank our collaborators Robert Berry, Folawiyo Campbell, Shraman Ray Chaudhuri, Nghi Doan, Elad Eban, Marybeth Fair, Alec Go, Sahil Goel, Tom Hume, Cassandra Luongo, Yair Movshovitz-Attias, James Stout, Gabriel Taubman and Anton Vayvod.


Addressing Range Anxiety with Smart Electric Vehicle Routing

Mapping algorithms used for navigation often rely on Dijkstra’s algorithm, a fundamental textbook solution for finding shortest paths in graphs. Dijkstra’s algorithm is simple and elegant: rather than considering all possible routes (an exponential number), it iteratively improves an initial solution, and works in polynomial time. The original algorithm and practical extensions of it (such as the A* algorithm) are used millions of times per day for routing vehicles on the global road network. However, since most vehicles are gas-powered, these algorithms ignore refueling considerations, for two reasons: a) gas stations are usually available everywhere at the cost of a small detour, and b) the time needed to refuel is typically only a few minutes and is negligible compared to the total travel time.

This situation is different for electric vehicles (EVs). First, EV charging stations are not as commonly available as gas stations, which can cause range anxiety, the fear that the car will run out of power before reaching a charging station. This concern is common enough that it is considered one of the barriers to the widespread adoption of EVs. Second, charging an EV’s battery is a more decision-demanding task, because the charging time can be a significant fraction of the total travel time and can vary widely by station, vehicle model, and battery level. In addition, the charging time is non-linear — e.g., it takes longer to charge a battery from 90% to 100% than from 20% to 30%.

The EV can only travel a distance up to the illustrated range before needing to recharge. Different roads and different stations have different time costs. The goal is to optimize for the total trip time.

Today, we present a new approach for routing EVs that reduces range anxiety by integrating recharging stations into the navigational route. It is part of the latest release of Google Maps built into your car for participating EVs. Based on the battery level and the destination, Maps will recommend the charging stops and the corresponding charging levels that will minimize the total duration of the trip. To accomplish this we engineered a highly scalable solution for recommending efficient routes through charging stations, which optimizes the sum of the driving time and the charging time together.

The fastest route from Berlin to Paris for a gas fueled car is shown in the top figure. The middle figure shows the optimal route for a 400 km range EV (travel time indicated – charging time excluded), where the larger white circles along the route indicate charging stops. The bottom figure shows the optimal route for a 200 km range EV.

Routing Through Charging Stations
A fundamental constraint on route selection is that the distance between recharging stops cannot exceed what the vehicle can reach on a full charge. Consequently, the route selection model emphasizes the graph of charging stations, in which each charging station is a node and each trip between charging stations is an edge, as opposed to the graph of road segments of the road network. Taking into consideration the various characteristics of each EV (such as the weight, maximum battery level, plug type, etc.) the algorithm identifies which of the edges are feasible for the EV under consideration and which are not. Once the routing request comes in, Maps EV routing augments the feasible graph with two new nodes, the origin and the destination, and with multiple new (feasible) edges that outline the potential trips from the origin to its nearby charging stations and to the destination from each of its nearby charging stations.
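The following sketch illustrates the station graph on a toy example: trips longer than the vehicle's range are dropped, the origin and destination are treated as ordinary nodes, and Dijkstra's algorithm finds the fastest feasible sequence of stops. The distances, times, and the single `max_range_km` feasibility test are assumptions; the production system considers many more vehicle characteristics.

```python
import heapq


def feasible_graph(trips, max_range_km):
    """Build an undirected graph from the trips the vehicle can complete."""
    graph = {}
    for (a, b), (distance_km, minutes) in trips.items():
        if distance_km <= max_range_km:
            graph.setdefault(a, []).append((b, minutes))
            graph.setdefault(b, []).append((a, minutes))
    return graph


def dijkstra(graph, source, target):
    """Return (total_minutes, path) for the fastest route, or (inf, [])."""
    best = {source: 0.0}
    queue = [(0.0, source, [source])]
    while queue:
        cost, node, path = heapq.heappop(queue)
        if node == target:
            return cost, path
        if cost > best.get(node, float("inf")):
            continue  # stale queue entry
        for neighbor, minutes in graph.get(node, []):
            new_cost = cost + minutes
            if new_cost < best.get(neighbor, float("inf")):
                best[neighbor] = new_cost
                heapq.heappush(queue, (new_cost, neighbor, path + [neighbor]))
    return float("inf"), []


# (distance in km, driving time in minutes) for each candidate trip.
trips = {
    ("origin", "A"): (150, 95),
    ("A", "B"): (180, 110),
    ("origin", "B"): (320, 190),
    ("B", "destination"): (170, 105),
    ("A", "destination"): (340, 210),
}
short_range = dijkstra(feasible_graph(trips, 200), "origin", "destination")
long_range = dijkstra(feasible_graph(trips, 400), "origin", "destination")
print(short_range)
print(long_range)
```

The shorter-range vehicle is forced through both stations, while the longer-range vehicle can skip station A, mirroring how the feasible edge set depends on the EV under consideration.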

Routing using Dijkstra’s algorithm or A* on this graph is sufficient to give a feasible solution that optimizes for the travel time for drivers that do not care at all about the charging time (i.e., drivers who always fully charge their batteries at each charging station). However, such algorithms are not sufficient to account for charging times. In this case, the algorithm constructs a new graph by replicating each charging station node multiple times. Half of the copies correspond to entering the station with a partially charged battery, with a charge, x, ranging from 0%-100%. The other half correspond to exiting the station with a fractional charge, y (again from 0%-100%). We add an edge from the entry node at the charge x to the exit node at charge y (constrained by y > x), with a corresponding charging time to get from x to y. When the trip from Station A to Station B spends some fraction (z) of the battery charge, we introduce an edge from every exit node of Station A to the corresponding entry node of Station B, whose charge is lower by z (an exit charge of y arrives at entry charge y − z). After performing this transformation, using Dijkstra or A* recovers the solution.
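The node replication can be sketched as follows, with battery levels discretized in 10% steps. The charging curve, the per-station `slowdown` factors, and the trip data are all hypothetical, and the sketch also adds a zero-cost edge for y = x so that a station can be passed without charging.

```python
import heapq
import itertools

LEVELS = list(range(0, 101, 10))  # battery levels in percent (assumed step)


def charge_minutes(x, y, slowdown):
    """Hypothetical non-linear curve: each 10% step is slower above 80%."""
    return sum((4.0 if level < 80 else 12.0) * slowdown
               for level in range(x, y, 10))


def build_graph(stations, trips, goal):
    """Replicate each station into entry ("in") and exit ("out") nodes."""
    graph = {}
    for name, slowdown in stations.items():
        for x in LEVELS:
            graph[(name, "in", x)] = [
                ((name, "out", y), charge_minutes(x, y, slowdown))
                for y in LEVELS if y >= x
            ]
    for (a, b), (minutes, used) in trips.items():
        for y in LEVELS:
            if y >= used:  # enough charge to complete the trip
                graph.setdefault((a, "out", y), []).append(
                    ((b, "in", y - used), minutes))
    for x in LEVELS:  # arriving with any charge ends the trip
        graph[(goal[0], "in", x)] = [(goal, 0.0)]
    return graph


def shortest_path(graph, source, target):
    """Dijkstra over the replicated graph; the counter breaks heap ties."""
    counter = itertools.count()
    best = {source: 0.0}
    queue = [(0.0, next(counter), source, [source])]
    while queue:
        cost, _, node, path = heapq.heappop(queue)
        if node == target:
            return cost, path
        if cost > best[node]:
            continue
        for neighbor, weight in graph.get(node, []):
            new_cost = cost + weight
            if new_cost < best.get(neighbor, float("inf")):
                best[neighbor] = new_cost
                heapq.heappush(
                    queue, (new_cost, next(counter), neighbor, path + [neighbor]))
    return float("inf"), []


stations = {"A": 2.0, "B": 1.0}  # A charges half as fast as B
trips = {  # (driving minutes, percent of battery used)
    ("origin", "A"): (60, 20),
    ("A", "B"): (90, 40),
    ("B", "destination"): (120, 80),
}
goal = ("destination", "goal", 0)
graph = build_graph(stations, trips, goal)
total_minutes, path = shortest_path(graph, ("origin", "out", 80), goal)
print(total_minutes)
print(path)
```

In this toy instance the fastest plan passes through the slow station A without charging and charges from 20% to 80% at station B, because every edge weight already includes the station-specific charging time.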

An example of our node/edge replication. In this instance the algorithm opts to pass through the first station without charging and charges at the second station from 20% to 80% battery.

Graph Sparsification
To perform the above operations while addressing range anxiety with confidence, the algorithm must compute the battery consumption of each trip between stations with good precision. For this reason, Maps maintains detailed information about the road characteristics along the trip between any two stations (e.g., the length, elevation, and slope, for each segment of the trip), taking into consideration the properties of each type of EV.

Due to the volume of information required for each segment, maintaining a large number of edges can become a memory intensive task. While this is not a problem for areas where EV charging stations are sparse, there exist locations in the world (such as Northern Europe) where the density of stations is very high. In such locations, adding an edge for every pair of stations between which an EV can travel quickly grows to billions of possible edges.

The figure on the left illustrates the high density of charging stations in Northern Europe. Different colors correspond to different plug types. The figure on the right illustrates why the routing graph scales up very quickly in size as the density of stations increases. When there are many stations within range of each other, the induced routing graph is a complete graph that stores detailed information for each edge.

However, this high density implies that a trip between two stations that are relatively far apart will undoubtedly pass through multiple other stations. In this case, maintaining information about the long edge is redundant, making it possible to simply add the smaller edges (spanners) in the graph, resulting in sparser, more computationally feasible, graphs.

The spanner construction algorithm is a direct generalization of the greedy geometric spanner. The trips between charging stations are sorted from fastest to slowest and are processed in that order. For each trip between points a and b, the algorithm examines whether smaller subtrips already included in the spanner subsume the direct trip. To do so it compares the trip time and battery consumption that can be achieved using subtrips already in the spanner, against the same quantities for the direct ab route. If they are found to be within a tiny error threshold, the direct trip from a to b is not added to the spanner, otherwise it is. Applying this sparsification algorithm has a notable impact and allows the graph to be served efficiently in responding to users’ routing requests.
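A minimal sketch of this greedy construction is shown below. To keep it short, it checks only the fastest path already in the spanner and compares that path's time and battery use with the direct trip; the 1% tolerance and the trip data are assumptions, and a full implementation would need a proper two-criteria search.

```python
import heapq


def fastest_path(spanner, source, target):
    """Return (minutes, battery) along the fastest spanner path, or None."""
    best = {source: 0.0}
    queue = [(0.0, 0.0, source)]
    while queue:
        minutes, battery, node = heapq.heappop(queue)
        if node == target:
            return minutes, battery
        if minutes > best.get(node, float("inf")):
            continue  # stale queue entry
        for neighbor, (edge_min, edge_batt) in spanner.get(node, {}).items():
            new_minutes = minutes + edge_min
            if new_minutes < best.get(neighbor, float("inf")):
                best[neighbor] = new_minutes
                heapq.heappush(
                    queue, (new_minutes, battery + edge_batt, neighbor))
    return None


def greedy_spanner(trips, tolerance=0.01):
    """Process trips from fastest to slowest, keeping only the needed ones."""
    spanner = {}
    kept = []
    ordered = sorted(trips.items(), key=lambda item: item[1][0])
    for (a, b), (minutes, battery) in ordered:
        via = fastest_path(spanner, a, b)
        subsumed = (
            via is not None
            and via[0] <= minutes * (1 + tolerance)
            and via[1] <= battery * (1 + tolerance)
        )
        if not subsumed:
            spanner.setdefault(a, {})[b] = (minutes, battery)
            spanner.setdefault(b, {})[a] = (minutes, battery)
            kept.append((a, b))
    return kept


trips = {  # (driving minutes, percent of battery used)
    ("A", "B"): (30, 10),
    ("B", "C"): (40, 15),
    ("A", "C"): (70, 25),
    ("C", "D"): (50, 20),
    ("A", "D"): (100, 40),
}
kept = greedy_spanner(trips)
print(kept)
```

Here the direct A to C trip is dropped because going through B costs the same time and battery, while the direct A to D trip is kept because the route through B and C is noticeably slower.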

On the left is the original road network (EV stations in light red). The station graph in the middle has edges for all feasible trips between stations. The sparse graph on the right maintains the distances with far fewer edges.

In this work we engineer a scalable solution for routing EVs on long trips to include access to charging stations through the use of graph sparsification and novel framing of standard routing algorithms. We are excited to put algorithmic ideas and techniques in the hands of Maps users and look forward to serving stress-free routes for EV drivers across the globe!

We thank our collaborators Dixie Wang, Xin Wei Chow, Navin Gunatillaka, Stephen Broadfoot, Alex Donaldson, and Ivan Kuznetsov.