Computer-aided diagnosis for lung cancer screening

Lung cancer is the leading cause of cancer-related deaths globally with 1.8 million deaths reported in 2020. Late diagnosis dramatically reduces the chances of survival. Lung cancer screening via computed tomography (CT), which provides a detailed 3D image of the lungs, has been shown to reduce mortality in high-risk populations by at least 20% by detecting potential signs of cancers earlier. In the US, screening involves annual scans, with some countries or cases recommending more or less frequent scans.

The United States Preventive Services Task Force recently expanded lung cancer screening recommendations by roughly 80%, which is expected to increase screening access for women and racial and ethnic minority groups. However, false positives (i.e., incorrectly reporting a potential cancer in a cancer-free patient) can cause anxiety and lead to unnecessary procedures for patients while increasing costs for the healthcare system. Moreover, efficiency in screening a large number of individuals can be challenging depending on healthcare infrastructure and radiologist availability.

At Google we have previously developed machine learning (ML) models for lung cancer detection, and have evaluated their ability to automatically detect and classify regions that show signs of potential cancer. Their performance has been shown to be comparable to that of specialists in detecting possible cancer. While these models have achieved high performance, effectively communicating their findings in realistic environments is necessary to realize their full potential.

To that end, in “Assistive AI in Lung Cancer Screening: A Retrospective Multinational Study in the US and Japan”, published in Radiology AI, we investigate how ML models can effectively communicate findings to radiologists. We also introduce a generalizable user-centric interface to help radiologists leverage such models for lung cancer screening. The system takes CT imaging as input and outputs a cancer suspicion rating using four categories (no suspicion, probably benign, suspicious, highly suspicious) along with the corresponding regions of interest. We evaluate the system’s utility in improving clinician performance through randomized reader studies in both the US and Japan, using the local cancer scoring systems (Lung-RADS V1.1 and Sendai Score) and image viewers that mimic realistic settings. We found that reader specificity increases with model assistance in both reader studies. To accelerate progress in conducting similar studies with ML models, we have open-sourced code to process CT images and generate images compatible with the picture archiving and communication system (PACS) used by radiologists.

Developing an interface to communicate model results

Integrating ML models into radiologist workflows involves understanding the nuances and goals of their tasks to meaningfully support them. In the case of lung cancer screening, hospitals follow various country-specific guidelines that are regularly updated. For example, in the US, Lung-RADS V1.1 assigns an alpha-numeric score to indicate the lung cancer risk and follow-up recommendations. When assessing patients, radiologists load the CT in their workstation to read the case, find lung nodules or lesions, and apply set guidelines to determine follow-up decisions.

Our first step was to improve the previously developed ML models through additional training data and architectural improvements, including self-attention. Then, instead of targeting specific guidelines, we experimented with a complementary way of communicating AI results independent of guidelines or their particular versions. Specifically, the system output offers a suspicion rating and localization (regions of interest) for the user to consider in conjunction with their own specific guidelines. The interface produces output images directly associated with the CT study, requiring no changes to the user’s workstation. The radiologist only needs to review a small set of additional images. There is no other change to their system or interaction with the system.

Example of the assistive lung cancer screening system outputs. Results for the radiologist’s evaluation are visualized on the location of the CT volume where the suspicious lesion is found. The overall suspicion is displayed at the top of the CT images. Circles highlight the suspicious lesions while squares show a rendering of the same lesion from a different perspective, called a sagittal view.

The assistive lung cancer screening system comprises 13 models and has a high-level architecture similar to the end-to-end system used in prior work. The models coordinate with each other to first segment the lungs, obtain an overall assessment, locate three suspicious regions, then use the information to assign a suspicion rating to each region. The system was deployed on Google Cloud using Google Kubernetes Engine (GKE), which pulled the images, ran the ML models, and provided results. This allows scalability and directly connects to servers where the images are stored in DICOM stores.

Outline of the Google Cloud deployment of the assistive lung cancer screening system and the directional calling flow for the individual components that serve the images and compute results. Images are served to the viewer and to the system using Google Cloud services. The system is run on Google Kubernetes Engine, which pulls the images, processes them, and writes them back into the DICOM store.

Reader studies

To evaluate the system’s utility in improving clinical performance, we conducted two reader studies (i.e., experiments designed to assess clinical performance comparing expert performance with and without the aid of a technology) with 12 radiologists using pre-existing, de-identified CT scans. We presented 627 challenging cases to 6 US-based and 6 Japan-based radiologists. In the experimental setup, readers were divided into two groups that read each case twice, with and without assistance from the model. Readers were asked to apply scoring guidelines they typically use in their clinical practice and report their overall suspicion of cancer for each case. We then compared the readers’ responses to measure the impact of the model on their workflow and decisions. The score and suspicion level were judged against the actual cancer outcomes of the individuals to measure sensitivity, specificity, and area under the ROC curve (AUC) values. These were compared with and without assistance.

A multi-case multi-reader study involves each case being reviewed by each reader twice, once with ML system assistance and once without. In this visualization one reader group first reviews Set A without assistance (blue) and then with assistance (orange) after a wash-out period. A second reader group follows the opposite path by reading the same set of cases (Set A) with assistance first. Readers are randomized to these groups to remove the effect of ordering.

The ability to conduct these studies using the same interface highlights its generalizability to completely different cancer scoring systems, and the generalization of the model and assistive capability to different patient populations. Our study results demonstrated that when radiologists used the system in their clinical evaluation, they had an increased ability to correctly identify lung images without actionable lung cancer findings (i.e., specificity) by an absolute 5–7% compared to when they didn’t use the assistive system. This potentially means that for every 15–20 patients screened, one may be able to avoid unnecessary follow-up procedures, thus reducing their anxiety and the burden on the health care system. This can, in turn, help improve the sustainability of lung cancer screening programs, particularly as more people become eligible for screening.
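The “one in 15–20” figure follows from simple arithmetic on the absolute specificity gain. Below is a minimal Python sketch of that arithmetic. It assumes the gain applies to cancer-negative individuals, who make up most of a screening population. The function names and the counts are illustrative and are not taken from the study.

```python
def specificity(true_negatives: int, false_positives: int) -> float:
    """Fraction of cancer-negative cases that were correctly not flagged."""
    return true_negatives / (true_negatives + false_positives)


def negatives_per_avoided_followup(absolute_gain: float) -> float:
    """Number of cancer-negative patients read per one avoided false positive."""
    return 1.0 / absolute_gain


# Hypothetical counts for illustration only (not study data).
unassisted = specificity(true_negatives=700, false_positives=300)
assisted = specificity(true_negatives=760, false_positives=240)
gain = assisted - unassisted  # absolute gain in specificity
print(f"illustrative absolute gain: {gain:.2f}")

# The reported range of 5-7% corresponds to roughly one in 20 to one in 14.
for g in (0.05, 0.07):
    print(f"gain {g:.0%}: about 1 in {negatives_per_avoided_followup(g):.0f}")
```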

Reader specificity increases with ML model assistance in both the US-based and Japan-based reader studies. Specificity values were derived from reader scores from actionable findings (something suspicious was found) versus no actionable findings, compared against the true cancer outcome of the individual. Under model assistance, readers flagged fewer cancer-negative individuals for follow-up visits. Sensitivity for cancer positive individuals remained the same.

Translating this into real-world impact through partnership

The system results demonstrate the potential for fewer follow-up visits, reduced anxiety, as well as lower overall costs for lung cancer screening. In an effort to translate this research into real-world clinical impact, we are working with DeepHealth, a leading AI-powered health informatics provider, and Apollo Radiology International, a leading provider of radiology services in India, to explore paths for incorporating this system into future products. In addition, we are looking to help other researchers studying how best to integrate ML model results into clinical workflows by open sourcing code used for the reader study and incorporating the insights described in this blog. We hope that this will help accelerate medical imaging researchers looking to conduct reader studies for their AI models, and catalyze translational research in the field.

Acknowledgements

Key contributors to this project include Corbin Cunningham, Zaid Nabulsi, Ryan Najafi, Jie Yang, Charles Lau, Joseph R. Ledsam, Wenxing Ye, Diego Ardila, Scott M. McKinney, Rory Pilgrim, Hiroaki Saito, Yasuteru Shimamura, Mozziyar Etemadi, Yun Liu, David Melnick, Sunny Jansen, Nadia Harhen, David P. Nadich, Mikhail Fomitchev, Ziyad Helali, Shabir Adeel, Greg S. Corrado, Lily Peng, Daniel Tse, Shravya Shetty, Shruthi Prabhakara, Neeral Beladia, and Krish Eswaran. Thanks to Arnav Agharwal and Andrew Sellergren for their open sourcing support and Vivek Natarajan and Michael D. Howell for their feedback. Sincere appreciation also goes to the radiologists who enabled this work with their image interpretation and annotation efforts throughout the study, and Jonny Wong and Carli Sampson for coordinating the reader studies.

Using AI to expand global access to reliable flood forecasts

Floods are the most common natural disaster, and are responsible for roughly $50 billion in annual financial damages worldwide. The rate of flood-related disasters has more than doubled since the year 2000 partly due to climate change. Nearly 1.5 billion people, making up 19% of the world’s population, are exposed to substantial risks from severe flood events. Upgrading early warning systems to make accurate and timely information accessible to these populations can save thousands of lives per year.

Driven by the potential impact of reliable flood forecasting on people’s lives globally, we started our flood forecasting effort in 2017. Over this multi-year journey, we advanced research hand-in-hand with building a real-time operational flood forecasting system that provides alerts on Google Search, Maps, Android notifications and through the Flood Hub. However, in order to scale globally, especially in places where accurate local data is not available, more research advances were required.

In “Global prediction of extreme floods in ungauged watersheds”, published in Nature, we demonstrate how machine learning (ML) technologies can significantly improve global-scale flood forecasting relative to the current state-of-the-art for countries where flood-related data is scarce. With these AI-based technologies we extended the reliability of currently available global nowcasts, on average, from zero to five days, and improved forecasts across regions in Africa and Asia to be similar to those currently available in Europe. The evaluation of the models was conducted in collaboration with the European Centre for Medium-Range Weather Forecasts (ECMWF).

These technologies also enable Flood Hub to provide real-time river forecasts up to seven days in advance, covering river reaches across over 80 countries. This information can be used by people, communities, governments and international organizations to take anticipatory action to help protect vulnerable populations.

Flood forecasting at Google

The ML models that power the Flood Hub tool are the product of many years of research, conducted in collaboration with several partners, including academics, governments, international organizations, and NGOs.

In 2018, we launched a pilot early warning system in the Ganges-Brahmaputra river basin in India, with the hypothesis that ML could help address the challenging problem of reliable flood forecasting at scale. The pilot was further expanded the following year via the combination of an inundation model, real-time water level measurements, the creation of an elevation map and hydrologic modeling.

In collaboration with academics, and, in particular, with the JKU Institute for Machine Learning we explored ML-based hydrologic models, showing that LSTM-based models could produce more accurate simulations than traditional conceptual and physics-based hydrology models. This research led to flood forecasting improvements that enabled the expansion of our forecasting coverage to include all of India and Bangladesh. We also worked with researchers at Yale University to test technological interventions that increase the reach and impact of flood warnings.

Our hydrological models predict river floods by processing publicly available weather data like precipitation and physical watershed information. Such models must be calibrated to long data records from streamflow gauging stations in individual rivers. A low percentage of global river watersheds (basins) have streamflow gauges, which are expensive but necessary to supply relevant data, and it’s challenging for hydrological simulation and forecasting to provide predictions in basins that lack this infrastructure. Lower gross domestic product (GDP) is correlated with increased vulnerability to flood risks, and there is an inverse correlation between national GDP and the amount of publicly available data in a country. ML helps to address this problem by allowing a single model to be trained on all available river data and to be applied to ungauged basins where no data are available. In this way, models can be trained globally, and can make predictions for any river location.

There is an inverse (log-log) correlation between the amount of publicly available streamflow data in a country and national GDP. Streamflow data from the Global Runoff Data Center.

Our academic collaborations led to ML research that developed methods to estimate uncertainty in river forecasts and showed how ML river forecast models synthesize information from multiple data sources. They demonstrated that these models can simulate extreme events reliably, even when those events are not part of the training data. In an effort to contribute to open science, in 2023 we open-sourced a community-driven dataset for large-sample hydrology in Nature Scientific Data.

The river forecast model

Most hydrology models used by national and international agencies for flood forecasting and river modeling are state-space models, which depend only on daily inputs (e.g., precipitation, temperature, etc.) and the current state of the system (e.g., soil moisture, snowpack, etc.). LSTMs are a variant of state-space models and work by defining a neural network that represents a single time step, where input data (such as current weather conditions) are processed to produce updated state information and output values (streamflow) for that time step. LSTMs are applied sequentially to make time-series predictions, and in this sense, behave similarly to how scientists typically conceptualize hydrologic systems. Empirically, we have found that LSTMs perform well on the task of river forecasting.

A diagram of the LSTM, which is a neural network that operates sequentially in time. An accessible primer can be found here.

Our river forecast model uses two LSTMs applied sequentially: (1) a “hindcast” LSTM ingests historical weather data (dynamic hindcast features) up to the present time (or rather, the issue time of a forecast), and (2) a “forecast” LSTM ingests states from the hindcast LSTM along with forecasted weather data (dynamic forecast features) to make future predictions. One year of historical weather data are input into the hindcast LSTM, and seven days of forecasted weather data are input into the forecast LSTM. Static features include geographical and geophysical characteristics of watersheds that are input into both the hindcast and forecast LSTMs and allow the model to learn different hydrological behaviors and responses in various types of watersheds.
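To make the hindcast-to-forecast handoff concrete, here is a minimal NumPy sketch of two LSTMs run in sequence. It is an illustration under stated assumptions, not the production model: the weights are random, the hidden and cell states are passed directly from the hindcast LSTM to the forecast LSTM (the real model may transform them), both LSTMs see the same number of dynamic features, and all names (`init_lstm`, `lstm_step`, `run_forecast`) are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def init_lstm(rng, n_in, n_hidden):
    """Random parameters for one LSTM (illustrative, untrained)."""
    return {
        "W": rng.normal(0.0, 0.1, size=(n_in + n_hidden, 4 * n_hidden)),
        "b": np.zeros(4 * n_hidden),
    }


def lstm_step(params, x, h, c):
    """One time step: inputs and previous state in, updated state out."""
    z = np.concatenate([x, h]) @ params["W"] + params["b"]
    i, f, g, o = np.split(z, 4)  # input, forget, candidate, output gates
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, c


def run_forecast(hindcast_inputs, forecast_inputs, static, hind, fore):
    """Run the hindcast LSTM, then hand its state to the forecast LSTM."""
    n_hidden = hind["b"].size // 4
    h, c = np.zeros(n_hidden), np.zeros(n_hidden)
    for x in hindcast_inputs:  # e.g. 365 days of historical weather
        h, c = lstm_step(hind, np.concatenate([x, static]), h, c)
    outputs = []
    for x in forecast_inputs:  # e.g. 7 days of forecasted weather
        h, c = lstm_step(fore, np.concatenate([x, static]), h, c)
        outputs.append(h)  # one hidden vector per lead time, fed to the head
    return np.stack(outputs)


rng = np.random.default_rng(0)
n_dyn, n_static, n_hidden = 5, 3, 16  # assumed sizes for illustration
hind = init_lstm(rng, n_dyn + n_static, n_hidden)
fore = init_lstm(rng, n_dyn + n_static, n_hidden)
states = run_forecast(
    rng.normal(size=(365, n_dyn)),
    rng.normal(size=(7, n_dyn)),
    rng.normal(size=n_static),
    hind,
    fore,
)
print(states.shape)
```

Static features are concatenated to the dynamic inputs at every step, which is one simple way to let a single set of weights behave differently across watersheds.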

Output from the forecast LSTM is fed into a “head” layer that uses mixture density networks to produce a probabilistic forecast (i.e., predicted parameters of a probability distribution over streamflow). Specifically, the model predicts the parameters of a mixture of heavy-tailed probability density functions, called asymmetric Laplacian distributions, at each forecast time step. The result is a mixture density function, called a Countable Mixture of Asymmetric Laplacians (CMAL) distribution, which represents a probabilistic prediction of the volumetric flow rate in a particular river at a particular time.
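As a rough illustration of what the head predicts, the sketch below evaluates a mixture of asymmetric Laplacian densities in NumPy. The parameterization (location `mu`, scale `scale`, asymmetry `tau` in (0, 1)) is a common textbook form and is an assumption here, as are the example parameter values. In the model these parameters would come from the network at each forecast time step.

```python
import numpy as np


def asymmetric_laplace_pdf(y, mu, scale, tau):
    """Density of an asymmetric Laplacian with asymmetry tau in (0, 1)."""
    u = (y - mu) / scale
    check = u * (tau - (u < 0))  # "check" (pinball) function, always >= 0
    return tau * (1.0 - tau) / scale * np.exp(-check)


def cmal_pdf(y, weights, mus, scales, taus):
    """Mixture density: weighted sum of asymmetric Laplacian components."""
    y = np.asarray(y, dtype=float)[..., None]  # broadcast over components
    components = asymmetric_laplace_pdf(y, mus, scales, taus)
    return (components * weights).sum(axis=-1)


# Hypothetical parameters for a single forecast time step.
weights = np.array([0.6, 0.4])  # mixture weights, sum to 1
mus = np.array([10.0, 50.0])    # component locations (streamflow units)
scales = np.array([2.0, 10.0])  # component scales
taus = np.array([0.3, 0.7])     # component asymmetries

grid = np.linspace(-200.0, 400.0, 600001)
density = cmal_pdf(grid, weights, mus, scales, taus)
total_mass = float(density.sum() * (grid[1] - grid[0]))  # Riemann sum
print(f"approximate total probability mass: {total_mass:.4f}")
```

Because each component can be skewed independently, the mixture can place a long tail on the high-flow side, which matters for representing rare floods.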

LSTM-based river forecast model architecture. Two LSTMs are applied in sequence, one ingesting historical weather data and one ingesting forecasted weather data. The model outputs are the parameters of a probability distribution over streamflow at each forecasted timestep.

Input and training data

The model uses three types of publicly available data inputs, mostly from governmental sources:

  1. Static watershed attributes representing geographical and geophysical variables: From the HydroATLAS project, including data like long-term climate indexes (precipitation, temperature, snow fractions), land cover, and anthropogenic attributes (e.g., a nighttime lights index as a proxy for human development).
  2. Historical meteorological time-series data: Used to spin up the model for one year prior to the issue time of a forecast. The data comes from NASA IMERG, NOAA CPC Global Unified Gauge-Based Analysis of Daily Precipitation, and the ECMWF ERA5-land reanalysis. Variables include daily total precipitation, air temperature, solar and thermal radiation, snowfall, and surface pressure.
  3. Forecasted meteorological time series over a seven-day forecast horizon: Used as input for the forecast LSTM. These data are the same meteorological variables listed above, and come from the ECMWF HRES atmospheric model.

Training data are daily streamflow values from the Global Runoff Data Center over the time period 1980–2023. A single streamflow forecast model is trained using data from 5,680 diverse watershed streamflow gauges (shown below) to improve accuracy.

Location of 5,680 streamflow gauges that supply training data for the river forecast model from the Global Runoff Data Center.

Improving on the current state-of-the-art

We compared our river forecast model with GloFAS version 4, the current state-of-the-art global flood forecasting system. These experiments showed that ML can provide accurate warnings earlier and over larger and more impactful events.

The figure below shows the distribution of F1 scores when predicting different severity events at river locations around the world, with plus or minus 1 day accuracy. F1 scores are the harmonic mean of precision and recall, and event severity is measured by return period. For example, a 2-year return period event is a volume of streamflow that is expected to be exceeded on average once every two years. Our model achieves reliability scores at up to 4-day or 5-day lead times that are similar to or better, on average, than the reliability of GloFAS nowcasts (0-day lead time).

Distributions of F1 scores over 2-year return period events in 2,092 watersheds globally during the time period 2014-2023 from GloFAS (blue) and our model (orange) at different lead times. On average, our model is statistically as accurate as GloFAS nowcasts (0-day lead time) up to 5 days in advance over 2-year (shown) and 1-year, 5-year, and 10-year events (not shown).
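The following Python sketch shows one way to compute an F1 score for event detection with a timing tolerance of plus or minus one day. The matching rule used here (a predicted event counts as correct if any observed event falls within the tolerance, and vice versa) is an assumption for illustration; the exact procedure in the paper may differ. Event days are plain integers, such as day indices in a time series.

```python
def f1_with_tolerance(observed_days, predicted_days, tolerance=1):
    """F1 score where events match if they fall within `tolerance` days."""
    observed = sorted(set(observed_days))
    predicted = sorted(set(predicted_days))
    if not observed or not predicted:
        return 0.0

    # Precision: share of predicted events that have a nearby observed event.
    matched_predictions = sum(
        any(abs(p - o) <= tolerance for o in observed) for p in predicted
    )
    # Recall: share of observed events that have a nearby predicted event.
    matched_observations = sum(
        any(abs(o - p) <= tolerance for p in predicted) for o in observed
    )
    precision = matched_predictions / len(predicted)
    recall = matched_observations / len(observed)
    if precision + recall == 0:
        return 0.0
    # Harmonic mean of precision and recall.
    return 2 * precision * recall / (precision + recall)


# Hypothetical event days: one prediction is off by a day, one by two days,
# and one is a false alarm.
score = f1_with_tolerance(observed_days=[10, 50, 90], predicted_days=[11, 52, 200])
print(f"F1 = {score:.3f}")
```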

Additionally (not shown), our model achieves accuracies over larger and rarer extreme events, with precision and recall scores over 5-year return period events that are similar to or better than GloFAS accuracies over 1-year return period events. See the paper for more information.

Looking into the future

The flood forecasting initiative is part of our Adaptation and Resilience efforts and reflects Google’s commitment to address climate change while helping global communities become more resilient. We believe that AI and ML will continue to play a critical role in helping advance science and research towards climate action.

We actively collaborate with several international aid organizations (e.g., the Centre for Humanitarian Data and the Red Cross) to provide actionable flood forecasts. Additionally, in an ongoing collaboration with the World Meteorological Organization (WMO) to support early warning systems for climate hazards, we are conducting a study to help understand how AI can help address real-world challenges faced by national flood forecasting agencies.

While the work presented here demonstrates a significant step forward in flood forecasting, future work is needed to further expand flood forecasting coverage to more locations globally and other types of flood-related events and disasters, including flash floods and urban floods. We are looking forward to continuing collaborations with our partners in the academic and expert communities, local governments and the industry to reach these goals.

ScreenAI: A visual language model for UI and visually-situated language understanding

Screen user interfaces (UIs) and infographics, such as charts, diagrams and tables, play important roles in human communication and human-machine interaction as they facilitate rich and interactive user experiences. UIs and infographics share similar design principles and visual language (e.g., icons and layouts), which offers an opportunity to build a single model that can understand, reason about, and interact with these interfaces. However, because of their complexity and varied presentation formats, infographics and UIs present a unique modeling challenge.

To that end, we introduce “ScreenAI: A Vision-Language Model for UI and Infographics Understanding”. ScreenAI improves upon the PaLI architecture with the flexible patching strategy from pix2struct. We train ScreenAI on a unique mixture of datasets and tasks, including a novel Screen Annotation task that requires the model to identify UI element information (i.e., type, location and description) on a screen. These text annotations provide large language models (LLMs) with screen descriptions, enabling them to automatically generate question-answering (QA), UI navigation, and summarization training datasets at scale. At only 5B parameters, ScreenAI achieves state-of-the-art results on UI- and infographic-based tasks (WebSRC and MoTIF), and best-in-class performance on ChartQA, DocVQA, and InfographicVQA compared to models of similar size. We are also releasing three new datasets: Screen Annotation to evaluate the layout understanding capability of the model, as well as ScreenQA Short and Complex ScreenQA for a more comprehensive evaluation of its QA capability.

ScreenAI

ScreenAI’s architecture is based on PaLI, composed of a multimodal encoder block and an autoregressive decoder. The PaLI encoder uses a vision transformer (ViT) that creates image embeddings and a multimodal encoder that takes the concatenation of the image and text embeddings as input. This flexible architecture allows ScreenAI to solve vision tasks that can be recast as text+image-to-text problems.

On top of the PaLI architecture, we employ a flexible patching strategy introduced in pix2struct. Instead of using a fixed-grid pattern, the grid dimensions are selected such that they preserve the native aspect ratio of the input image. This enables ScreenAI to work well across images of various aspect ratios.
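As an illustration of aspect-ratio-preserving patching, the sketch below picks a patch grid whose shape follows the image’s aspect ratio while staying within a fixed patch budget. The formula is in the spirit of the pix2struct approach. The 16-pixel patch size, the 1,024-patch budget, and the function name `patch_grid` are assumptions for this example and are not ScreenAI’s documented settings.

```python
import math


def patch_grid(image_height, image_width, patch_size=16, max_patches=1024):
    """Choose (rows, cols) that follow the image aspect ratio within a budget."""
    # Scale factor so that rows * cols is as large as possible but <= max_patches.
    scale = math.sqrt(
        max_patches * (patch_size / image_height) * (patch_size / image_width)
    )
    rows = max(min(math.floor(scale * image_height / patch_size), max_patches), 1)
    cols = max(min(math.floor(scale * image_width / patch_size), max_patches), 1)
    return rows, cols


tall = patch_grid(2000, 1000)  # portrait screenshot, e.g. a phone
wide = patch_grid(1000, 2000)  # landscape screenshot, e.g. a desktop
print(tall, wide)
```

With a fixed square grid, both screenshots would be squeezed into the same shape. Here the portrait image gets more rows than columns and the landscape image gets the reverse.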

The ScreenAI model is trained in two stages: a pre-training stage followed by a fine-tuning stage. First, self-supervised learning is applied to automatically generate data labels, which are then used to train ViT and the language model. ViT is frozen during the fine-tuning stage, where most data used is manually labeled by human raters.

ScreenAI model architecture.

Data generation

To create a pre-training dataset for ScreenAI, we first compile an extensive collection of screenshots from various devices, including desktops, mobile, and tablets. This is achieved by using publicly accessible web pages and following the programmatic exploration approach used for the RICO dataset for mobile apps. We then apply a layout annotator, based on the DETR model, that identifies and labels a wide range of UI elements (e.g., image, pictogram, button, text) and their spatial relationships. Pictograms undergo further analysis using an icon classifier capable of distinguishing 77 different icon types. This detailed classification is essential for interpreting the subtle information conveyed through icons. For icons that are not covered by the classifier, and for infographics and images, we use the PaLI image captioning model to generate descriptive captions that provide contextual information. We also apply an optical character recognition (OCR) engine to extract and annotate textual content on screen. We combine the OCR text with the previous annotations to create a detailed description of each screen.

A mobile app screenshot with generated annotations that include UI elements and their descriptions, e.g., TEXT elements also contain the text content from OCR, IMAGE elements contain image captions, LIST_ITEMs contain all their child elements.

LLM-based data generation

We enhance the pre-training data’s diversity using PaLM 2 to generate input-output pairs in a two-step process. First, screen annotations are generated using the technique outlined above, then we craft a prompt around this schema for the LLM to create synthetic data. This process requires prompt engineering and iterative refinement to find an effective prompt. We assess the generated data’s quality through human validation against a quality threshold.

You only speak JSON. Do not write text that isn’t JSON.
You are given the following mobile screenshot, described in words. Can you generate 5 questions regarding the content of the screenshot as well as the corresponding short answers to them? 

The answer should be as short as possible, containing only the necessary information. Your answer should be structured as follows:
questions: [
{{question: the question,
    answer: the answer
}},
 ...
]

{THE SCREEN SCHEMA}

A sample prompt for QA data generation.
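The doubled braces in the sample prompt suggest a format-string template in which only the screen schema is substituted. Below is a minimal Python sketch of filling such a template and validating the response. The placeholder name `screen_schema`, the helper names, and the assumption that the LLM returns strict JSON with a top-level `questions` list are all illustrative; no LLM is called.

```python
import json

PROMPT_TEMPLATE = """You only speak JSON. Do not write text that isn’t JSON.
You are given the following mobile screenshot, described in words. Can you generate 5 questions regarding the content of the screenshot as well as the corresponding short answers to them?

The answer should be as short as possible, containing only the necessary information. Your answer should be structured as follows:
questions: [
{{question: the question,
    answer: the answer
}},
 ...
]

{screen_schema}"""


def build_prompt(screen_schema):
    """Insert the textual screen description into the prompt template."""
    return PROMPT_TEMPLATE.format(screen_schema=screen_schema)


def parse_qa_response(raw):
    """Keep only well-formed question/answer pairs; return [] on bad JSON."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return []
    if not isinstance(data, dict):
        return []
    pairs = []
    for item in data.get("questions", []):
        if not isinstance(item, dict):
            continue
        question, answer = item.get("question"), item.get("answer")
        if isinstance(question, str) and isinstance(answer, str):
            if question.strip() and answer.strip():
                pairs.append({"question": question.strip(), "answer": answer.strip()})
    return pairs


example_schema = "<screen schema text goes here>"  # hypothetical placeholder
prompt = build_prompt(example_schema)
print(prompt[-40:])
```

Filtering malformed responses before they reach the training set is a cheap first check that complements the human validation step.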

By combining the natural language capabilities of LLMs with a structured schema, we simulate a wide range of user interactions and scenarios to generate synthetic, realistic tasks. In particular, we generate three categories of tasks:

  • Question answering: The model is asked to answer questions regarding the content of the screenshots, e.g., “When does the restaurant open?”
  • Screen navigation: The model is asked to convert a natural language utterance into an executable action on a screen, e.g., “Click the search button.”
  • Screen summarization: The model is asked to summarize the screen content in one or two sentences.

Block diagram of our workflow for generating data for QA, summarization and navigation tasks using existing ScreenAI models and LLMs. Each task uses a custom prompt to emphasize desired aspects, like questions related to counting, involving reasoning, etc.

LLM-generated data. Examples for screen QA, navigation and summarization. For navigation, the action bounding box is displayed in red on the screenshot.

Experiments and results

As previously mentioned, ScreenAI is trained in two stages: pre-training and fine-tuning. Pre-training data labels are obtained using self-supervised learning and fine-tuning data labels come from human raters.

We fine-tune ScreenAI using public QA, summarization, and navigation datasets and a variety of tasks related to UIs. For QA, we use well-established benchmarks in the multimodal and document understanding field, such as ChartQA, DocVQA, Multi page DocVQA, InfographicVQA, OCR VQA, WebSRC and ScreenQA. For navigation, datasets used include Referring Expressions, MoTIF, Mug, and Android in the Wild. Finally, we use Screen2Words for screen summarization and Widget Captioning for describing specific UI elements. Along with the fine-tuning datasets, we evaluate the fine-tuned ScreenAI model using three novel benchmarks:

  1. Screen Annotation: Enables the evaluation of the model’s layout annotation and spatial understanding capabilities.
  2. ScreenQA Short: A variation of ScreenQA, where its ground truth answers have been shortened to contain only the relevant information that better aligns with other QA tasks.
  3. Complex ScreenQA: Complements ScreenQA Short with more difficult questions (counting, arithmetic, comparison, and non-answerable questions) and contains screens with various aspect ratios.

The fine-tuned ScreenAI model achieves state-of-the-art results on various UI and infographic-based tasks (WebSRC and MoTIF) and best-in-class performance on ChartQA, DocVQA, and InfographicVQA compared to models of similar size. ScreenAI achieves competitive performance on Screen2Words and OCR-VQA. Additionally, we report results on the new benchmark datasets introduced to serve as a baseline for further research.

Comparing model performance of ScreenAI with state-of-the-art (SOTA) models of similar size.

Next, we examine ScreenAI’s scaling capabilities and observe that across all tasks, increasing the model size improves performance, and the improvements have not saturated at the largest size.

Model performance increases with size, and the performance has not saturated even at the largest size of 5B params.

Conclusion

We introduce the ScreenAI model along with a unified representation that enables us to develop self-supervised learning tasks leveraging data from both the UI and infographics domains. We also illustrate the impact of data generation using LLMs and investigate improving model performance on specific aspects by modifying the training mixture. We apply all of these techniques to build multi-task trained models that perform competitively with state-of-the-art approaches on a number of public benchmarks. However, we also note that our approach still lags behind large models and further research is needed to bridge this gap.

Acknowledgements

This project is the result of joint work with Maria Wang, Fedir Zubach, Hassan Mansoor, Vincent Etter, Victor Carbune, Jason Lin, Jindong Chen and Abhanshu Sharma. We thank Fangyu Liu, Xi Chen, Efi Kokiopoulou, Jesse Berent, Gabriel Barcik, Lukas Zilka, Oriana Riva, Gang Li, Yang Li, Radu Soricut, and Tania Bedrax-Weiss for their insightful feedback and discussions, along with Rahul Aralikatte, Hao Cheng and Daniel Kim for their support in data preparation. We also thank Jay Yagnik, Blaise Aguera y Arcas, Ewa Dominowska, David Petrou, and Matt Sharifi for their leadership, vision and support. We are very grateful to Tom Small for helping us create the animation in this post.

MediaPipe FaceStylizer: On-device real-time few-shot face stylization


In recent years, we have witnessed rising interest across consumers and researchers in integrated augmented reality (AR) experiences using real-time face feature generation and editing functions in mobile applications, including short videos, virtual reality, and gaming. As a result, there is a growing demand for lightweight, yet high-quality face generation and editing models, which are often based on generative adversarial network (GAN) techniques. However, the majority of GAN models suffer from high computational complexity and the need for a large training dataset. In addition, it is also important to employ GAN models responsibly.

In this post, we introduce MediaPipe FaceStylizer, an efficient design for few-shot face stylization that addresses the aforementioned model complexity and data efficiency challenges while being guided by Google’s responsible AI Principles. The model consists of a face generator and a face encoder used as GAN inversion to map the image into latent code for the generator. We introduce a mobile-friendly synthesis network for the face generator with an auxiliary head that converts features to RGB at each level of the generator to generate high quality images from coarse to fine granularities. We also carefully designed the loss functions for the aforementioned auxiliary heads and combined them with the common GAN loss functions to distill the student generator from the teacher StyleGAN model, resulting in a lightweight model that maintains high generation quality. The proposed solution is available in open source through MediaPipe. Users can fine-tune the generator to learn a style from one or a few images using MediaPipe Model Maker, and deploy to on-device face stylization applications with the customized model using MediaPipe FaceStylizer.

Few-shot on-device face stylization

An end-to-end pipeline

Our goal is to build a pipeline that lets users adapt the MediaPipe FaceStylizer to different styles by fine-tuning the model with a few examples. To enable this, we built the pipeline from a GAN inversion encoder and an efficient face generator model (see below). The encoder and generator pipeline can then be adapted to different styles via a few-shot learning process. The user first sends one or a few similar samples of the style images to MediaPipe Model Maker to fine-tune the model. The fine-tuning process freezes the encoder module and only fine-tunes the generator. The training process samples multiple latent codes close to the encoding output of the input style images as the input to the generator. The generator is then trained to reconstruct an image of a person’s face in the style of the input style image by optimizing a joint adversarial loss function that also accounts for style and content. With such a fine-tuning process, the MediaPipe FaceStylizer can adapt to the customized style, which approximates the user’s input. It can then be applied to stylize test images of real human faces.

Generator: BlazeStyleGAN

The StyleGAN model family has been widely adopted for face generation and various face editing tasks. To support efficient on-device face generation, we based the design of our generator on StyleGAN. This generator, which we call BlazeStyleGAN, is similar to StyleGAN in that it also contains a mapping network and synthesis network. However, since the synthesis network of StyleGAN is the major contributor to the model’s high computation complexity, we designed and employed a more efficient synthesis network. The improved efficiency and generation quality is achieved by:

  1. Reducing the latent feature dimension in the synthesis network to a quarter of the resolution of the counterpart layers in the teacher StyleGAN,
  2. Designing multiple auxiliary heads to transform the downscaled feature to the image domain to form a coarse-to-fine image pyramid to evaluate the perceptual quality of the reconstruction, and
  3. Skipping all but the final auxiliary head at inference time.

With the newly designed architecture, we train the BlazeStyleGAN model by distilling it from a teacher StyleGAN model. We use a multi-scale perceptual loss and adversarial loss in the distillation to transfer the high fidelity generation capability from the teacher model to the student BlazeStyleGAN model and also to mitigate the artifacts from the teacher model.

More details of the model architecture and training scheme can be found in our paper.

Visual comparison between face samples generated by StyleGAN and BlazeStyleGAN. The images on the first row are generated by the teacher StyleGAN. The images on the second row are generated by the student BlazeStyleGAN. The face generated by BlazeStyleGAN has similar visual quality to the image generated by the teacher model. Some results demonstrate the student BlazeStyleGAN suppresses the artifacts from the teacher model in the distillation.

In the above figure, we demonstrate some sample results of our BlazeStyleGAN. By comparing with the face image generated by the teacher StyleGAN model (top row), the images generated by the student BlazeStyleGAN (bottom row) maintain high visual quality and further reduce artifacts produced by the teacher due to the loss function design in our distillation.

An encoder for efficient GAN inversion

To support image-to-image stylization, we also introduced an efficient GAN inversion as the encoder to map input images to the latent space of the generator. The encoder is defined by a MobileNet V2 backbone and trained with natural face images. The loss is defined as a combination of image perceptual quality loss, which measures the content difference, style similarity and embedding distance, as well as the L1 loss between the input images and reconstructed images.

On-device performance

We documented model complexities in terms of parameter numbers and computing FLOPs in the following table. Compared to the teacher StyleGAN (33.2M parameters), BlazeStyleGAN (generator) significantly reduces the model complexity, with only 2.01M parameters and 1.28G FLOPs for output resolution 256×256. Compared to StyleGAN-1024 (generating image size of 1024×1024), the BlazeStyleGAN-1024 can reduce both model size and computation complexity by 95% with no notable quality difference and can even suppress the artifacts from the teacher StyleGAN model.

Model     Image Size     #Params (M)     FLOPs (G)
StyleGAN     1024     33.17     74.3
BlazeStyleGAN     1024     2.07     4.70
BlazeStyleGAN     512     2.05     1.57
BlazeStyleGAN     256     2.01     1.28
Encoder     256     1.44     0.60
Model complexity measured by parameter numbers and FLOPs.
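The reduction figures quoted above can be cross-checked directly against the table. The short sketch below does that arithmetic; the dictionary layout and function names are illustrative choices rather than part of any released tool. With the table's numbers, BlazeStyleGAN-1024 comes out at a bit under 94% fewer parameters and FLOPs than StyleGAN-1024, in line with the roughly 95% quoted in the text.

```python
# Values copied from the model-complexity table:
# (parameters in millions, FLOPs in billions).
COMPLEXITY = {
    "StyleGAN-1024": (33.17, 74.3),
    "BlazeStyleGAN-1024": (2.07, 4.70),
    "BlazeStyleGAN-512": (2.05, 1.57),
    "BlazeStyleGAN-256": (2.01, 1.28),
}


def reduction_percent(baseline, value):
    """Percentage by which `value` is smaller than `baseline`."""
    return 100.0 * (1.0 - value / baseline)


def compare(student, teacher="StyleGAN-1024"):
    """Return (parameter reduction %, FLOPs reduction %) of student vs. teacher."""
    teacher_params, teacher_flops = COMPLEXITY[teacher]
    student_params, student_flops = COMPLEXITY[student]
    return (
        reduction_percent(teacher_params, student_params),
        reduction_percent(teacher_flops, student_flops),
    )


for name in ("BlazeStyleGAN-1024", "BlazeStyleGAN-512", "BlazeStyleGAN-256"):
    params_cut, flops_cut = compare(name)
    print(f"{name}: {params_cut:.1f}% fewer parameters, {flops_cut:.1f}% fewer FLOPs")
```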

We benchmarked the inference time of the MediaPipe FaceStylizer on various high-end mobile devices and report the results in the table below. Both BlazeStyleGAN-256 and BlazeStyleGAN-512 achieved real-time performance on all GPU devices, and BlazeStyleGAN can run in less than 10 ms on a high-end phone’s GPU. BlazeStyleGAN-256 can also achieve real-time performance on the iOS devices’ CPU.

Model     BlazeStyleGAN-256 (ms)     Encoder-256 (ms)
iPhone 11     12.14     11.48
iPhone 12     11.99     12.25
iPhone 13 Pro     7.22     5.41
Pixel 6     12.24     11.23
Samsung Galaxy S10     17.01     12.70
Samsung Galaxy S20     8.95     8.20
Latency benchmark of the BlazeStyleGAN, face encoder, and the end-to-end pipeline on various mobile devices.
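To relate these latencies to frame rates, the sketch below converts them to frames per second. It rests on two assumptions that are not stated in the post: that the encoder and generator run back to back on every frame, so their latencies add, and that "real time" means at least 30 frames per second. Under those assumptions every device in the table clears the threshold.

```python
# Latencies in milliseconds, copied from the table above: (generator, encoder).
LATENCY_MS = {
    "iPhone 11": (12.14, 11.48),
    "iPhone 12": (11.99, 12.25),
    "iPhone 13 Pro": (7.22, 5.41),
    "Pixel 6": (12.24, 11.23),
    "Samsung Galaxy S10": (17.01, 12.70),
    "Samsung Galaxy S20": (8.95, 8.20),
}

REAL_TIME_FPS = 30.0  # assumed threshold; the post does not define "real time"


def frames_per_second(generator_ms, encoder_ms):
    """Throughput if encoder and generator run sequentially on each frame (assumption)."""
    return 1000.0 / (generator_ms + encoder_ms)


def real_time_devices(threshold_fps=REAL_TIME_FPS):
    """Devices whose summed latency meets the assumed frame-rate threshold."""
    return sorted(
        device
        for device, (generator_ms, encoder_ms) in LATENCY_MS.items()
        if frames_per_second(generator_ms, encoder_ms) >= threshold_fps
    )


for device, (generator_ms, encoder_ms) in LATENCY_MS.items():
    print(f"{device}: {frames_per_second(generator_ms, encoder_ms):.1f} fps")
```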

Fairness evaluation

The model was trained on a highly diverse dataset of human faces and is expected to be fair across different faces. The fairness evaluation shows that the model performs well and in a balanced manner across gender, skin tone, and age.

Face stylization visualization

Some face stylization results are shown in the following figure. The images in the top row (in orange boxes) are the style images used to fine-tune the model. The images in the left column (in the green boxes) are the natural face images used for testing. The 2×4 matrix of images shows the output of the MediaPipe FaceStylizer, which blends the natural faces in the left-most column with the corresponding face styles in the top row. The results demonstrate that our solution can achieve high-quality face stylization for several popular styles.

Sample results of our MediaPipe FaceStylizer.

MediaPipe Solutions

The MediaPipe FaceStylizer is going to be released to public users in MediaPipe Solutions. Users can leverage MediaPipe Model Maker to train a customized face stylization model using their own style images. After training, the exported bundle of TFLite model files can be deployed to applications across platforms (Android, iOS, Web, Python, etc.) using the MediaPipe Tasks FaceStylizer API in just a few lines of code.

Acknowledgements

This work is made possible through a collaboration spanning several teams across Google. We’d like to acknowledge contributions from Omer Tov, Yang Zhao, Andrey Vakunov, Fei Deng, Ariel Ephrat, Inbar Mosseri, Lu Wang, Chuo-Ling Chang, Tingbo Hou, and Matthias Grundmann.


On-device content distillation with graph neural networks

In today’s digital age, smartphones and desktop web browsers serve as the primary tools for accessing news and information. However, the proliferation of website clutter — encompassing complex layouts, navigation elements, and extraneous links — significantly impairs both the reading experience and article navigation. This issue is particularly acute for individuals with accessibility requirements.

To improve the user experience and make reading more accessible, Android and Chrome users may leverage the Reading Mode feature, which enhances accessibility by processing webpages to allow customizable contrast, adjustable text size, more legible fonts, and to enable text-to-speech utilities. Additionally, Android’s Reading Mode is equipped to distill content from apps. Expanding Reading Mode to encompass a wide array of content and improving its performance, while still operating locally on the user’s device without transmitting data externally, poses a unique challenge.

To broaden Reading Mode capabilities without compromising privacy, we have developed a novel on-device content distillation model. Unlike early attempts using DOM Distiller — a heuristic approach limited to news articles — our model excels in both quality and versatility across various types of content. We ensure that article content doesn’t leave the confines of the local environment. Our on-device content distillation model smoothly transforms long-form content into a simple and customizable layout for a more pleasant reading journey while also outperforming the leading alternative approaches. Here we explore details of this research highlighting our approach, methodology, and results.

Graph neural networks

Instead of relying on complicated heuristics that are difficult to maintain and scale to a variety of article layouts, we approach this task as a fully supervised learning problem. This data-driven approach allows the model to generalize better across different layouts, without the constraints and fragility of heuristics. Previous work for optimizing the reading experience relied on HTML or parsing, filtering, and modeling of a document object model (DOM), a programming interface automatically generated by the user’s web browser from site HTML that represents the structure of a document and allows it to be manipulated.

The new Reading Mode model relies on accessibility trees, which provide a streamlined and more accessible representation of the DOM. Accessibility trees are automatically generated from the DOM tree and are utilized by assistive technologies to allow people with disabilities to interact with web content. These are available on Chrome Web browser and on Android through AccessibilityNodeInfo objects, which are provided for both WebView and native application content.

We started by manually collecting and annotating accessibility trees. The Android dataset used for this project comprises on the order of 10k labeled examples, while the Chrome dataset contains approximately 100k labeled examples. We developed a novel tool that uses graph neural networks (GNNs) to distill essential content from the accessibility trees using a multi-class supervised learning approach. The datasets consist of long-form articles sampled from the web and labeled with classes such as headline, paragraph, images, publication date, etc.

GNNs are a natural choice for dealing with tree-like data structures, because unlike traditional models that often demand detailed, hand-crafted features to understand the layout and links within such trees, GNNs learn these connections naturally. To illustrate this, consider the analogy of a family tree. In such a tree, each node represents a family member and the connections denote familial relationships. If one were to predict certain traits using conventional models, features like the “number of immediate family members with a trait” might be needed. However, with GNNs, such manual feature crafting becomes redundant. By directly feeding the tree structure into the model, GNNs utilize a message-passing mechanism where each node communicates with its neighbors. Over time, information gets shared and accumulated across the network, enabling the model to naturally discern intricate relationships.

Returning to the context of accessibility trees, this means that GNNs can efficiently distill content by understanding and leveraging the inherent structure and relationships within the tree. This capability allows them to identify and possibly omit non-essential sections based on the information flow within the tree, ensuring more accurate content distillation.

Our architecture heavily follows the encode-process-decode paradigm using a message-passing neural network to classify text nodes. The overall design is illustrated in the figure below. The tree representation of the article is the input to the model. We compute lightweight features based on bounding box information, text information, and accessibility roles. The GNN then propagates each node’s latent representation through the edges of the tree using a message-passing neural network. This propagation process allows nearby nodes, containers, and text elements to share contextual information with each other, enhancing the model’s understanding of the page’s structure and content. Each node then updates its current state based on the message received, providing a more informed basis for classifying the nodes. After a fixed number of message-passing steps, the now contextualized latent representations of the nodes are decoded into essential or non-essential classes. This approach enables the model to leverage both the inherent relationships in the tree and the hand-crafted features representing each node, thereby enriching the final classification.

A visual demonstration of the algorithm in action, processing an article on a mobile device. A graph neural network (GNN) is used to distill essential content from an article. 1. A tree representation of the article is extracted from the application. 2. Lightweight features are computed for each node, represented as vectors. 3. A message-passing neural network propagates information through the edges of the tree and updates each node representation. 4. Leaf nodes containing text content are classified as essential or non-essential content. 5. A decluttered version of the application is composed based on the GNN output.
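The encode-process-decode loop described above can be sketched in a few lines of plain Python. This is a toy illustration under simplifying assumptions, not the production model: messages are a plain mean over neighbors, the update is a fixed 50/50 blend instead of a learned network, and the decoder is a hand-set linear readout. All function names and the feature layout are hypothetical.

```python
from collections import defaultdict


def message_passing(edges, features, steps=2):
    """Run `steps` rounds of mean-aggregation message passing on an undirected tree.

    edges: list of (parent, child) node ids.
    features: dict mapping node id to a list of floats (the encoded node state).
    Returns a dict mapping node id to its updated state.
    """
    neighbors = defaultdict(list)
    for parent, child in edges:
        neighbors[parent].append(child)
        neighbors[child].append(parent)

    state = {node: list(vec) for node, vec in features.items()}
    for _ in range(steps):
        new_state = {}
        for node, vec in state.items():
            incoming = [state[n] for n in neighbors[node]]
            if incoming:
                # Aggregate: element-wise mean of the neighbors' states.
                message = [sum(column) / len(incoming) for column in zip(*incoming)]
            else:
                message = [0.0] * len(vec)
            # Update: fixed blend of own state and message (a learned function in a real GNN).
            new_state[node] = [0.5 * own + 0.5 * msg for own, msg in zip(vec, message)]
        state = new_state
    return state


def classify(state, weights, bias=0.0):
    """Decode each node state into a label with a linear readout (hand-set weights)."""
    return {
        node: "essential"
        if sum(w * x for w, x in zip(weights, vec)) + bias > 0
        else "non-essential"
        for node, vec in state.items()
    }


# A three-node tree: a container (0) with two text children (1, 2).
states = message_passing([(0, 1), (0, 2)], {0: [0.0], 1: [1.0], 2: [1.0]}, steps=1)
print(states)
print(classify(states, weights=[1.0], bias=-0.25))
```

After one step the container node has absorbed information from its text children, which is the contextualization effect the paragraph describes.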

We deliberately restrict the feature set used by the model to improve its generalization across languages and reduce inference latency on user devices. This was a unique challenge, as we needed to create an on-device lightweight model that could preserve privacy.

Our final lightweight Android model has 64k parameters and is 334kB in size with a median latency of 800ms, while the Chrome model has 241k parameters, is 928kB in size, and has a 378ms median latency. By employing such on-device processing, we ensure that user data never leaves the device, reinforcing our responsible approach and commitment to user privacy. The features used in the model can be grouped into intermediate node features, leaf-node text features, and element position features. We performed feature engineering and feature selection to optimize the set of features for model performance and model size. The final model was transformed into TensorFlow Lite format to deploy as an on-device model on Android or Chrome.

Results

We trained the GNN for about 50 epochs on a single GPU. The performance of the Android model on webpages and native application test sets is presented below:

Android Quality (node metrics)
Class     Webpages Precision     Webpages Recall     Webpages F1-score     Native Apps Precision     Native Apps Recall     Native Apps F1-score
non-essential     0.9842     0.9846     0.9844     0.9744     0.9350     0.9543
headline     0.9187     0.9784     0.9476     0.9183     0.8568     0.8865
main-text     0.9223     0.9172     0.9197     0.8443     0.9424     0.8907
macro-average     0.9417     0.9600     0.9506     0.9124     0.9114     0.9105
weighted average     0.9736     0.9736     0.9736     0.9392     0.9353     0.9363
headline + main-text     0.9510     0.9683     0.9595     0.9473     0.9507     0.9490
The table presents the content distillation metrics in Android for webpages and native apps. We report precision, recall and F1-score for three classes: non-essential content, headline, and main body text, including macro average and weighted average by number of instances in each class. Node metrics assess the classification performance at the granularity of the accessibility tree node, which is analogous to a paragraph level. In contrast, word metrics evaluate classification at an individual word level, meaning each word within a node gets the same classification.

In assessing the results’ quality on commonly visited webpage articles, an F1-score exceeding 0.9 for main-text (essentially paragraphs) corresponds to 88% of these articles being processed without missing any paragraphs. Furthermore, in over 95% of cases, the distillation proves to be valuable for readers. Put simply, the vast majority of readers will perceive the distilled content as both pertinent and precise, with errors or omissions being an infrequent occurrence.
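For readers who want to check how the per-class scores and averages in the Android table relate to one another, the sketch below recomputes them from the standard definitions. The precision, recall and F1 inputs are the webpage node metrics reported above. The weighted average needs per-class instance counts, which the post does not give, so that function is shown without a worked example.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2.0 * precision * recall / (precision + recall)


def macro_average(values):
    """Unweighted mean over classes."""
    return sum(values) / len(values)


def weighted_average(values, counts):
    """Mean over classes weighted by the number of instances in each class."""
    return sum(v * c for v, c in zip(values, counts)) / sum(counts)


# Webpage node metrics for the headline class.
print(round(f1_score(0.9187, 0.9784), 4))

# Macro average of the three per-class F1-scores
# (non-essential, headline, main-text).
print(round(macro_average([0.9844, 0.9476, 0.9197]), 4))
```

Both values agree with the table to four decimal places (0.9476 and 0.9506).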

The table below compares Chrome content distillation with other models, such as DOM Distiller and Mozilla Readability, on a set of English-language pages. We reuse metrics from machine translation to compare the quality of these models: the reference text is the ground-truth main content, and the text produced by each model is the hypothesis text. The results show that our model outperforms the other DOM-based approaches on every reported metric.

Chrome Model Comparison on Webpages
Metric / Model     DOM Distiller     Mozilla Readability     Our Chrome model
BLEU     78.97     79.16     94.59
CHRF     0.92     0.92     0.98
ROUGE1     84.10     84.62     95.13
ROUGE2     81.84     82.66     94.81
ROUGE3     80.21     81.45     94.60
ROUGEL     83.58     84.02     95.04
ROUGEL-SUM     83.46     84.03     95.04
The table presents the comparison between DOM Distiller, Mozilla Readability and the new Chrome model. We report text-based metrics, such as BLEU, CHRF and ROUGE, by comparing the main body text distilled from each model to a ground-truth text manually labeled by raters using our annotation policy.

The F1-score of the Chrome content distillation model for headline and main text content on the test sets of different widely spoken languages demonstrates that the Chrome model, in particular, is able to support a wide range of languages.

Chrome Model on Different Languages
F1-score     de     en     es     fr     it     fa     ja     ko     pt     vi     zh-Hans     zh-Hant     average
headline     0.91     0.97     0.99     0.98     0.97     0.89     0.97     0.98     0.99     0.98     0.97     0.93     0.96
main text     0.84     0.90     0.93     0.91     0.93     0.87     0.88     0.91     0.91     0.90     0.90     0.90     0.90
The table presents per-language F1-scores of the Chrome model for the headline and main text classes. The language codes correspond to the following languages: German, English, Spanish, French, Italian, Persian, Japanese, Korean, Portuguese, Vietnamese, simplified Chinese and traditional Chinese.

Conclusion

The digital age demands both streamlined content presentation and an unwavering commitment to user privacy. Our research highlights the effectiveness of Reading Mode in platforms like Android and Chrome, offering an innovative, data-driven approach to content parsing through Graph Neural Networks. Crucially, our lightweight on-device model ensures that content distillation occurs without compromising user data, with all processes executed locally. This not only enhances the reading experience but also reinforces our dedication to user privacy. As we navigate the evolving landscape of digital content consumption, our findings underscore the paramount importance of prioritizing the user in both experience and security.

Acknowledgements

This project is the result of joint work with Manuel Tragut, Mihai Popa, Abodunrinwa Toki, Abhanshu Sharma, Matt Sharifi, David Petrou and Blaise Aguera y Arcas. We sincerely thank our collaborators Gang Li and Yang Li. We are very grateful to Tom Small for assisting us in preparing the post.


World scale inverse reinforcement learning in Google Maps

Routing in Google Maps remains one of our most helpful and frequently used features. Determining the best route from A to B requires making complex trade-offs between factors including the estimated time of arrival (ETA), tolls, directness, surface conditions (e.g., paved, unpaved roads), and user preferences, which vary across transportation mode and local geography. Often, the most natural visibility we have into travelers’ preferences is by analyzing real-world travel patterns.

Learning preferences from observed sequential decision making behavior is a classic application of inverse reinforcement learning (IRL). Given a Markov decision process (MDP) — a formalization of the road network — and a set of demonstration trajectories (the traveled routes), the goal of IRL is to recover the users’ latent reward function. Although past research has created increasingly general IRL solutions, these have not been successfully scaled to world-sized MDPs. Scaling IRL algorithms is challenging because they typically require solving an RL subroutine at every update step. At first glance, even attempting to fit a world-scale MDP into memory to compute a single gradient step appears infeasible due to the large number of road segments and limited high bandwidth memory. When applying IRL to routing, one needs to consider all reasonable routes between each demonstration’s origin and destination. This implies that any attempt to break the world-scale MDP into smaller components cannot consider components smaller than a metropolitan area.

To this end, in “Massively Scalable Inverse Reinforcement Learning in Google Maps”, we share the result of a multi-year collaboration among Google Research, Maps, and Google DeepMind to surpass this IRL scalability limitation. We revisit classic algorithms in this space, and introduce advances in graph compression and parallelization, along with a new IRL algorithm called Receding Horizon Inverse Planning (RHIP) that provides fine-grained control over performance trade-offs. The final RHIP policy achieves a 16–24% relative improvement in global route match rate, i.e., the percentage of de-identified traveled routes that exactly match the suggested route in Google Maps. To the best of our knowledge, this represents the largest instance of IRL in a real world setting to date.

Google Maps improvements in route match rate relative to the existing baseline, when using the RHIP inverse reinforcement learning policy.
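The route match rate defined above, and the notion of a relative improvement in it, can be made concrete with a small sketch. The route representation (a sequence of road-segment ids) and the function names are assumptions made for illustration, and the toy routes are made up.

```python
def route_match_rate(traveled, suggested):
    """Percentage of traveled routes that exactly match the suggested route.

    traveled, suggested: equal-length lists of routes, where each route is a
    sequence of road-segment ids (a hypothetical representation).
    """
    if len(traveled) != len(suggested):
        raise ValueError("each traveled route needs exactly one suggested route")
    if not traveled:
        return 0.0
    matches = sum(1 for t, s in zip(traveled, suggested) if tuple(t) == tuple(s))
    return 100.0 * matches / len(traveled)


def relative_improvement(baseline_rate, new_rate):
    """Relative lift of `new_rate` over `baseline_rate`, in percent."""
    return 100.0 * (new_rate - baseline_rate) / baseline_rate


# Toy data: three trips, and the routes two different policies would suggest.
traveled = [[1, 2, 3], [1, 4, 3], [5, 6]]
baseline = [[1, 2, 3], [1, 2, 3], [5, 7]]
improved = [[1, 2, 3], [1, 4, 3], [5, 7]]
print(route_match_rate(traveled, baseline), route_match_rate(traveled, improved))
```

Note that a relative improvement is measured against the baseline rate, so moving from a 50% to a 58% match rate is a 16% relative lift, not an 8% one.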

The benefits of IRL

A subtle but crucial detail about the routing problem is that it is goal conditioned, meaning that every destination state induces a slightly different MDP (specifically, the destination is a terminal, zero-reward state). IRL approaches are well suited for these types of problems because the learned reward function transfers across MDPs, and only the destination state is modified. This is in contrast to approaches that directly learn a policy, which typically require an extra factor of S parameters, where S is the number of MDP states.

Once the reward function is learned via IRL, we take advantage of a powerful inference-time trick. First, we evaluate the entire graph’s rewards once in an offline batch setting. This computation is performed entirely on servers without access to individual trips, and operates only over batches of road segments in the graph. Then, we save the results to an in-memory database and use a fast online graph search algorithm to find the highest reward path for routing requests between any origin and destination. This circumvents the need to perform online inference of a deeply parameterized model or policy, and vastly improves serving costs and latency.

Reward model deployment using batch inference and fast online planners.
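The two-phase serving pattern can be sketched as follows: an offline pass evaluates the reward model once per road segment, and an online search then only looks up stored numbers. This is a minimal illustration, not the production system. It assumes all rewards are non-positive so that Dijkstra's algorithm on costs equal to negated rewards finds the highest-reward path, and the graph layout, segment features and `reward_fn` are hypothetical.

```python
import heapq


def batch_rewards(segments, reward_fn):
    """Offline step: evaluate the reward model once for every road segment."""
    return {segment_id: reward_fn(features) for segment_id, features in segments.items()}


def highest_reward_path(graph, rewards, origin, destination):
    """Online step: Dijkstra over costs = -reward (assumes every reward <= 0).

    graph: dict mapping node to a list of (next_node, segment_id).
    Returns (total_reward, [nodes on the path]) or None if unreachable.
    """
    frontier = [(0.0, origin, [origin])]
    settled = set()
    while frontier:
        cost, node, path = heapq.heappop(frontier)
        if node == destination:
            return -cost, path
        if node in settled:
            continue
        settled.add(node)
        for next_node, segment_id in graph.get(node, []):
            if next_node not in settled:
                heapq.heappush(
                    frontier, (cost - rewards[segment_id], next_node, path + [next_node])
                )
    return None


# Toy example: the stand-in "reward model" is just negative travel time.
segments = {"ab": {"minutes": 2.0}, "bc": {"minutes": 2.0}, "ac": {"minutes": 5.0}}
rewards = batch_rewards(segments, lambda features: -features["minutes"])
graph = {"A": [("B", "ab"), ("C", "ac")], "B": [("C", "bc")]}
print(highest_reward_path(graph, rewards, "A", "C"))
```

Because the expensive model runs only in `batch_rewards`, the per-request cost is that of the graph search alone, which is the serving benefit described above.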

Receding Horizon Inverse Planning

To scale IRL to the world MDP, we compress the graph and shard the global MDP using a sparse Mixture of Experts (MoE) based on geographic regions. We then apply classic IRL algorithms to solve the local MDPs, estimate the loss, and send gradients back to the MoE. The worldwide reward graph is computed by decompressing the final MoE reward model. To provide more control over performance characteristics, we introduce a new generalized IRL algorithm called Receding Horizon Inverse Planning (RHIP).

IRL reward model training using MoE parallelization, graph compression, and RHIP.

RHIP is inspired by people’s tendency to perform extensive local planning (“What am I doing for the next hour?”) and approximate long-term planning (“What will my life look like in 5 years?”). To take advantage of this insight, RHIP uses robust yet expensive stochastic policies in the local region surrounding the demonstration path, and switches to cheaper deterministic planners beyond some horizon. Adjusting the horizon H allows controlling computational costs, and often allows the discovery of the performance sweet spot. Interestingly, RHIP generalizes many classic IRL algorithms and provides the novel insight that they can be viewed along a stochastic vs. deterministic spectrum (specifically, for H=∞ it reduces to MaxEnt, for H=1 it reduces to BIRL, and for H=0 it reduces to MMP).

Given a demonstration from s_o to s_d, (1) RHIP follows a robust yet expensive stochastic policy in the local region surrounding the demonstration (blue region). (2) Beyond some horizon H, RHIP switches to following a cheaper deterministic planner (red lines). Adjusting the horizon enables fine-grained control over performance and computational costs.
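The interplay between the horizon H and the two kinds of planner can be illustrated with a toy value computation. The sketch below is only loosely inspired by the description above and is not the algorithm from the paper: it computes deterministic (max) values to the destination, then refines them with H soft (log-sum-exp) backups of the kind used in MaxEnt-style methods. It covers values only, with no reward learning or gradients, and assumes non-positive edge rewards. All names are hypothetical.

```python
import math


def deterministic_values(graph, destination):
    """Highest reward-to-go under a deterministic planner (max backups to a fixed point).

    graph: dict mapping node to a list of (next_node, reward), with reward <= 0.
    """
    nodes = set(graph) | {n for edges in graph.values() for n, _ in edges} | {destination}
    values = {n: (0.0 if n == destination else -math.inf) for n in nodes}
    for _ in range(len(nodes)):
        for node in nodes:
            if node == destination:
                continue
            options = [reward + values[nxt] for nxt, reward in graph.get(node, [])]
            if options:
                values[node] = max(options)
    return values


def soft_backup(graph, values, destination):
    """One soft backup: V(s) = log sum over edges of exp(reward + V(next))."""
    new_values = dict(values)
    for node in values:
        if node == destination:
            continue
        options = [
            reward + values[nxt]
            for nxt, reward in graph.get(node, [])
            if values[nxt] > -math.inf
        ]
        if options:
            top = max(options)  # subtract the max for numerical stability
            new_values[node] = top + math.log(sum(math.exp(o - top) for o in options))
    return new_values


def receding_horizon_values(graph, destination, horizon):
    """Deterministic values refined by `horizon` soft backups (illustrative only)."""
    values = deterministic_values(graph, destination)
    for _ in range(horizon):
        values = soft_backup(graph, values, destination)
    return values


# Two equally good routes from A to D.
graph = {"A": [("B", -1.0), ("C", -1.0)], "B": [("D", -1.0)], "C": [("D", -1.0)]}
print(receding_horizon_values(graph, "D", 0)["A"])
print(receding_horizon_values(graph, "D", 1)["A"])
```

With H = 0 the values are purely deterministic, and each extra soft backup lets nodes account for alternative routes, at the price of more computation per update.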

Routing wins

The RHIP policy provides a 15.9% and 24.1% lift in global route match rate for driving and two-wheelers (e.g., scooters, motorcycles, mopeds) relative to the well-tuned Maps baseline, respectively. We’re especially excited about the benefits to more sustainable transportation modes, where factors beyond journey time play a substantial role. By tuning RHIP’s horizon H, we’re able to achieve a policy that is both more accurate than all other IRL policies and 70% faster than MaxEnt.

Our 360M parameter reward model provides intuitive wins for Google Maps users in live A/B experiments. Examining road segments with a large absolute difference between the learned rewards and the baseline rewards can help improve certain Google Maps routes. For example:

Nottingham, UK. The preferred route (blue) was previously marked as private property due to the presence of a large gate, which indicated to our systems that the road may be closed at times and would not be ideal for drivers. As a result, Google Maps routed drivers through a longer, alternate detour instead (red). However, because real-world driving patterns showed that users regularly take the preferred route without an issue (as the gate is almost never closed), IRL now learns to route drivers along the preferred route by placing a large positive reward on this road segment.

Conclusion

Increasing performance via increased scale – both in terms of dataset size and model complexity – has proven to be a persistent trend in machine learning. Similar gains for inverse reinforcement learning problems have historically remained elusive, largely due to the challenges with handling practically sized MDPs. By introducing scalability advancements to classic IRL algorithms, we’re now able to train reward models on problems with hundreds of millions of states, demonstration trajectories, and model parameters, respectively. To the best of our knowledge, this is the largest instance of IRL in a real-world setting to date. See the paper to learn more about this work.

Acknowledgements

This work is a collaboration across multiple teams at Google. Contributors to the project include Matthew Abueg, Oliver Lange, Matt Deeds, Jason Trader, Denali Molitor, Markus Wulfmeier, Shawn O’Banion, Ryan Epp, Renaud Hartert, Rui Song, Thomas Sharp, Rémi Robert, Zoltan Szego, Beth Luan, Brit Larabee and Agnieszka Madurska.

We’d also like to extend our thanks to Arno Eigenwillig, Jacob Moorman, Jonathan Spencer, Remi Munos, Michael Bloesch and Arun Ahuja for valuable discussions and suggestions.


Differentially private median and more

Differential privacy (DP) is a rigorous mathematical definition of privacy. DP algorithms are randomized to protect user data by ensuring that the probability of any particular output is nearly unchanged when a data point is added or removed. Therefore, the output of a DP algorithm does not disclose the presence of any one data point. There has been significant progress in both foundational research and adoption of differential privacy with contributions such as the Privacy Sandbox and Google Open Source Library.

ML and data analytics algorithms can often be described as performing multiple basic computation steps on the same dataset. When each such step is differentially private, so is the output, but with multiple steps the overall privacy guarantee deteriorates, a phenomenon known as the cost of composition. Composition theorems bound the increase in privacy loss with the number k of computations: In the general case, the privacy loss increases with the square root of k. This means that we need much stricter privacy guarantees for each step in order to meet our overall privacy guarantee goal. But in that case, we lose utility. One way to improve the privacy vs. utility trade-off is to identify when the use cases admit a tighter privacy analysis than what follows from composition theorems.
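To make the cost of composition concrete, the sketch below compares two standard bounds for k steps that are each ε-DP: basic composition, which adds the per-step losses, and the advanced composition theorem of Dwork, Rothblum and Vadhan, whose leading term grows with the square root of k at the price of a small extra failure probability. The parameter name `delta_slack` is chosen here for that extra probability and is not notation from the post.

```python
import math


def basic_composition(epsilon, k):
    """Total privacy loss under basic composition: the per-step losses add up."""
    return k * epsilon


def advanced_composition(epsilon, k, delta_slack):
    """Total epsilon under advanced composition for k adaptive epsilon-DP steps.

    The combined mechanism is (result, delta_slack)-DP; the first term grows
    with sqrt(k), the second is lower order when epsilon is small.
    """
    return (
        epsilon * math.sqrt(2.0 * k * math.log(1.0 / delta_slack))
        + k * epsilon * (math.exp(epsilon) - 1.0)
    )


epsilon, k = 0.01, 10_000
print(basic_composition(epsilon, k))
print(advanced_composition(epsilon, k, delta_slack=1e-6))
```

Even under the tighter bound, the total loss after many steps is far larger than the per-step ε, which is why an analysis that avoids the dependence on k is valuable.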

Good candidates for such improvement are when each step is applied to a disjoint part (slice) of the dataset. When the slices are selected in a data-independent way, each point affects only one of the k outputs and the privacy guarantees do not deteriorate with k. However, there are applications in which we need to select the slices adaptively (that is, in a way that depends on the output of prior steps). In these cases, a change of a single data point may cascade — changing multiple slices and thus increasing composition cost.

In “Õptimal Differentially Private Learning of Thresholds and Quasi-Concave Optimization”, presented at STOC 2023, we describe a new paradigm that allows for slices to be selected adaptively and yet avoids composition cost. We show that DP algorithms for multiple fundamental aggregation and learning tasks can be expressed in this Reorder-Slice-Compute (RSC) paradigm, gaining significant improvements in utility.

The Reorder-Slice-Compute (RSC) paradigm

An algorithm A falls in the RSC paradigm if it can be expressed in the following general form (see visualization below). The input is a sensitive set D of data points. The algorithm then performs a sequence of k steps as follows:

  1. Select an ordering over data points, a slice size m, and a DP algorithm M. The selection may depend on the output of A in prior steps (and hence is adaptive).
  2. Slice out the (approximately) top m data points according to the order from the dataset D, apply M to the slice, and output the result.
A visualization of three Reorder-Slice-Compute (RSC) steps.
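The general form above can be sketched as a short loop. This is only a structural illustration: the chooser callbacks, the exact (rather than approximate) slice size, and the non-private averaging used in the example are all assumptions made for readability, and nothing here provides a privacy guarantee by itself.

```python
from typing import Callable, List

def reorder_slice_compute(data: List[float], choosers: List[Callable]) -> list:
    """Run one RSC step per chooser and return the released outputs."""
    remaining = list(data)
    outputs = []
    for choose in choosers:
        # The choice of order, slice size and mechanism may depend on prior outputs.
        key, m, mechanism = choose(outputs)
        remaining.sort(key=key, reverse=True)             # reorder: top points first
        piece, remaining = remaining[:m], remaining[m:]   # slice (exact size here)
        outputs.append(mechanism(piece))                  # compute on the slice only
    return outputs

def mean(points):
    return sum(points) / len(points)   # placeholder, NOT differentially private

def largest_first(outputs):
    return (lambda x: x), 2, mean

def closest_to_previous(outputs):
    prev = outputs[-1]                 # adaptive: depends on the previous output
    return (lambda x: -abs(x - prev)), 2, mean

result = reorder_slice_compute([1, 2, 3, 4, 10, 11], [largest_first, closest_to_previous])
print(result)
```

Each data point lands in at most one slice, which is the property the tighter analysis builds on.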

If we analyze the overall privacy loss of an RSC algorithm using DP composition theorems, the privacy guarantee suffers from the expected composition cost, i.e., it deteriorates with the square root of the number of steps k. To eliminate this composition cost, we provide a novel analysis that removes the dependence on k altogether: the overall privacy guarantee is close to that of a single step! The idea behind our tighter analysis is a novel technique that limits the potential cascade of affected steps when a single data point is modified (details in the paper).

Tighter privacy analysis means better utility. The effectiveness of DP algorithms is often stated in terms of the smallest input size (number of data points) that suffices in order to release a correct result that meets the privacy requirements. We describe several problems with algorithms that can be expressed in the RSC paradigm and for which our tighter analysis improved utility.

Private interval point

We start with the following basic aggregation task. The input is a dataset D of n points from an ordered domain X (think of the domain as the natural numbers between 1 and |X|). The goal is to return a point y in X that lies in the interval of D, that is, between the minimum and the maximum points in D.

The solution to the interval point problem is trivial without the privacy requirement: simply return any point in the dataset D. But this solution is not privacy-preserving as it discloses the presence of a particular data point in the input. We can also see that if there is only one point in the dataset, a privacy-preserving solution is not possible, as it must return that point. We can therefore ask the following fundamental question: What is the smallest input size N for which we can solve the private interval point problem?

It is known that N must increase with the domain size |X| and that this dependence is at least the iterated log function log* |X| [1, 2]. On the other hand, the best prior DP algorithm required the input size to be at least (log* |X|)^1.5. To close this gap, we designed an RSC algorithm that requires only on the order of log* |X| points.

The iterated log function is extremely slow growing: It is the number of times we need to take a logarithm of a value before we reach a value that is equal to or smaller than 1. How did this function naturally come out in the analysis? Each step of the RSC algorithm remapped the domain to a logarithm of its prior size. Therefore there were log* |X| steps in total. The tighter RSC analysis eliminated a square root of the number of steps from the required input size.
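A minimal sketch of the iterated log function, following the definition above. The base of the logarithm (2) is an assumption; other bases change the value by at most a small additive amount.

```python
import math

def log_star(x: float) -> int:
    """Number of times log2 must be applied before the value is <= 1."""
    count = 0
    while x > 1:
        x = math.log2(x)
        count += 1
    return count

for value in (2, 16, 65_536):
    print(value, log_star(value))
```

With base 2, a domain of size 2^65536 still has an iterated log of only 5, which is why this dependence is so mild in practice.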

Even though the interval point task seems very basic, it captures the essence of the difficulty of private solutions for common aggregation tasks. We next describe two of these tasks and express the required input size to these tasks in terms of N.

Private approximate median

One of these common aggregation tasks is approximate median: The input is a dataset D of n points from an ordered domain X. The goal is to return a point y that is between the ⅓ and ⅔ quantiles of D. That is, at least a third of the points in D are smaller than or equal to y and at least a third of the points are larger than or equal to y. Note that returning an exact median is not possible with differential privacy, since it discloses the presence of a data point. Hence we consider the relaxed requirement of an approximate median (shown below).

We can compute an approximate median by finding an interval point: We slice out the N smallest points and the N largest points and then compute an interval point of the remaining points. The latter must be an approximate median. This works when the dataset size is at least 3N.

An example of a dataset D over domain X, the set of interval points, and the set of approximate medians.
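The reduction from approximate median to interval point can be sketched as follows. The interval-point routine is passed in as a parameter; the placeholder used here is deliberately not private and only stands in for a DP routine that needs N points, and all function names are hypothetical.

```python
from typing import Callable, List

def placeholder_interval_point(points: List[int]) -> int:
    """NOT private: stands in for a DP interval-point routine."""
    return (min(points) + max(points)) // 2

def approximate_median(data: List[int], n_required: int,
                       interval_point: Callable[[List[int]], int]) -> int:
    """Drop the N smallest and N largest points, then find an interval point."""
    if len(data) < 3 * n_required:
        raise ValueError("need at least 3N points")
    ordered = sorted(data)
    middle = ordered[n_required:len(ordered) - n_required]
    return interval_point(middle)

data = [5, 1, 9, 3, 7, 2, 8, 4, 6]
y = approximate_median(data, n_required=3, interval_point=placeholder_interval_point)
print(y)
```

The 3N requirement leaves at least N points in the middle, which is what the interval-point routine needs to succeed.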

Private learning of axis-aligned rectangles

For the next task, the input is a set of n labeled data points, where each point x = (x_1, …, x_d) is a d-dimensional vector over a domain X. The goal is to learn values a_i, b_i for the axes i = 1, …, d that define a d-dimensional rectangle (displayed below), so that for each example x:

  • If x is positively labeled (shown as red plus signs below) then it lies within the rectangle, that is, for all axes i, x_i is in the interval [a_i, b_i], and
  • If x is negatively labeled (shown as blue minus signs below) then it lies outside the rectangle, that is, for at least one axis i, x_i is outside the interval [a_i, b_i].

A set of 2-dimensional labeled points and a respective rectangle.

Any DP solution for this problem must be approximate in that the learned rectangle must be allowed to mislabel some data points, with some positively labeled points outside the rectangle or negatively labeled points inside it. This is because an exact solution could be very sensitive to the presence of a particular data point and would not be private. The goal is a DP solution that keeps this necessary number of mislabeled points small.

We first consider the one-dimensional case (d = 1). We are looking for an interval [a,b] that covers all positive points and none of the negative points. We show that we can do this with at most 2N mislabeled points. We focus on the positively labeled points. In the first RSC step we slice out the N smallest points and compute a private interval point as a. We then slice out the N largest points and compute a private interval point as b. The solution [a,b] correctly labels all negatively labeled points and mislabels at most 2N of the positively labeled points. Thus, at most ~2N points are mislabeled in total.

Illustration for d = 1, we slice out N left positive points and compute an interval point a, slice out N right positive points and compute an interval point b.

With d > 1, we iterate over the axes i = 1, …, d and apply the above to the ith coordinates of the input points to obtain the values a_i, b_i. In each iteration, we perform two RSC steps and slice out 2N positively labeled points. In total, we slice out 2dN points and all remaining points are correctly labeled. That is, all negatively labeled points are outside the final d-dimensional rectangle and all positively labeled points, except perhaps ~2dN, lie inside the rectangle. Note that this algorithm uses the full flexibility of RSC in that the points are ordered differently by each axis. Since the number of RSC steps grows linearly with d, the RSC analysis shaves off a factor of the square root of d from the number of mislabeled points.
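The per-axis procedure can be sketched as below. As before, the interval-point routine is a non-private placeholder, the function and parameter names are hypothetical, and the sketch only handles the positively labeled points, as in the description above.

```python
from typing import Callable, List, Sequence, Tuple

Point = Sequence[float]

def placeholder_interval_point(values: List[float]) -> float:
    """NOT private: stands in for a DP interval-point routine."""
    return (min(values) + max(values)) / 2

def learn_rectangle(positives: List[Point], n_required: int,
                    interval_point: Callable[[List[float]], float] = placeholder_interval_point
                    ) -> Tuple[List[Tuple[float, float]], List[Point]]:
    """Return per-axis bounds (a_i, b_i) and the points that were never sliced out."""
    d = len(positives[0])
    if len(positives) < 2 * d * n_required:
        raise ValueError("need at least 2dN positive points")
    remaining = list(positives)
    bounds = []
    for i in range(d):
        remaining.sort(key=lambda p: p[i])                # reorder by axis i
        low, remaining = remaining[:n_required], remaining[n_required:]
        a_i = interval_point([p[i] for p in low])         # lower edge from N smallest
        high, remaining = remaining[-n_required:], remaining[:-n_required]
        b_i = interval_point([p[i] for p in high])        # upper edge from N largest
        bounds.append((a_i, b_i))
    return bounds, remaining

positives = [(1, 5), (2, 6), (3, 7), (4, 8), (5, 9), (6, 10)]
bounds, kept = learn_rectangle(positives, n_required=1)
print(bounds, kept)
```

Every point that is never sliced out lies inside the learned rectangle, so only the 2dN sliced points can be mislabeled.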

Training ML models with adaptive selection of training examples

The training efficiency or performance of ML models can sometimes be improved by selecting training examples in a way that depends on the current state of the model, e.g., self-paced curriculum learning or active learning.

The most common method for private training of ML models is DP-SGD, where noise is added to the gradient update from each minibatch of training examples. Privacy analysis with DP-SGD typically assumes that training examples are randomly partitioned into minibatches. But if we impose a data-dependent selection order on training examples, and further modify the selection criteria k times during training, then analysis through DP composition results in deterioration of the privacy guarantees of a magnitude equal to the square root of k.

Fortunately, example selection with DP-SGD can be naturally expressed in the RSC paradigm: each selection criterion reorders the training examples and each minibatch is a slice (for which we compute a noisy gradient). With RSC analysis, there is no privacy deterioration with k, which brings DP-SGD training with example selection into the practical domain.
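A toy sketch of this view, for a one-parameter linear model trained on (x, y) pairs. The loss-based ordering, the hyperparameters, and the function name are assumptions chosen for illustration; the sketch does no privacy accounting and is not the training setup from the paper.

```python
import random
from typing import List, Tuple

def dp_sgd_with_selection(examples: List[Tuple[float, float]], batch_size: int,
                          steps: int, lr: float = 0.1, clip: float = 1.0,
                          noise_multiplier: float = 1.0,
                          seed: int = 0) -> Tuple[float, int]:
    """Return the trained weight and the number of examples never used."""
    rng = random.Random(seed)
    remaining = list(examples)
    w = 0.0
    for _ in range(steps):
        if len(remaining) < batch_size:
            break
        # Reorder: examples with the largest current loss come first.
        remaining.sort(key=lambda xy: (w * xy[0] - xy[1]) ** 2, reverse=True)
        # Slice: the minibatch leaves the pool, so slices are disjoint.
        batch, remaining = remaining[:batch_size], remaining[batch_size:]
        # Compute: clipped per-example gradients plus Gaussian noise.
        total = 0.0
        for x, y in batch:
            grad = (w * x - y) * x
            total += max(-clip, min(clip, grad))   # clipping in one dimension
        total += rng.gauss(0.0, noise_multiplier * clip)
        w -= lr * total / batch_size
    return w, len(remaining)

examples = [(i / 10, 2 * i / 10) for i in range(1, 41)]   # y = 2x
print(dp_sgd_with_selection(examples, batch_size=4, steps=5))
```

The ordering changes at every step because it depends on the current weight, which is exactly the adaptive selection that standard composition would charge for.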

Conclusion

The RSC paradigm was introduced in order to tackle an open problem that is primarily of theoretical significance, but turns out to be a versatile tool with the potential to enhance data efficiency in production environments.

Acknowledgments

The work described here was done jointly with Xin Lyu, Jelani Nelson, and Tamas Sarlos.


A novel computational fluid dynamics framework for turbulent flow research

Turbulence is ubiquitous in environmental and engineering fluid flows, and is encountered routinely in everyday life. A better understanding of these turbulent processes could provide valuable insights across a variety of research areas: improving the prediction of cloud formation by atmospheric transport and the spreading of wildfires by turbulent energy exchange, understanding sedimentation of deposits in rivers, and improving the efficiency of combustion in aircraft engines to reduce emissions, to name a few. However, despite its importance, our current understanding of such flows and our ability to reliably predict them remain limited. This is mainly attributed to the highly chaotic nature of these flows and the enormous range of spatial and temporal scales they occupy, from energetic, large-scale movements on the order of several meters at the high end, where energy is injected into the fluid flow, all the way down to micrometers (μm) at the low end, where the turbulence is dissipated into heat by viscous friction.

A powerful tool to understand these turbulent flows is direct numerical simulation (DNS), which provides a detailed representation of the unsteady three-dimensional flow field without making any approximations or simplifications. More specifically, this approach utilizes a discrete grid with small enough grid spacing to capture the underlying continuous equations that govern the dynamics of the system (in this case, variable-density Navier-Stokes equations, which govern all fluid flow dynamics). When the grid spacing is small enough, the discrete grid points are enough to represent the true (continuous) equations without loss of accuracy. While this is attractive, such simulations require tremendous computational resources in order to capture the correct fluid-flow behaviors across such a wide range of spatial scales.

The actual span in spatial resolution to which direct numerical calculations must be applied depends on the task and is determined by the Reynolds number, which compares inertial to viscous forces. Typically, the Reynolds number ranges from 10^2 up to 10^7 (and is even larger for atmospheric or interstellar problems). In 3D, the grid size for the resolution required scales roughly with the Reynolds number to the power of 4.5! Because of this strong scaling dependency, simulating such flows is generally limited to flow regimes with moderate Reynolds numbers, and typically requires access to high-performance computing systems with millions of CPU/GPU cores.

In “A TensorFlow simulation framework for scientific computing of fluid flows on tensor processing units”, we introduce a new simulation framework that enables the computation of fluid flows with TPUs. By leveraging the latest advances in TensorFlow software and TPU hardware architecture, this software tool allows detailed large-scale simulations of turbulent flows at unprecedented scale, pushing the boundaries of scientific discovery and turbulence analysis. We demonstrate that this framework scales efficiently, either to accommodate larger problems or, alternatively, to improve run times, which is remarkable since most large-scale distributed computation frameworks exhibit reduced efficiency with scaling. The software is available as an open-source project on GitHub.

Large-scale scientific computation with accelerators

The software solves variable-density Navier-Stokes equations on TPU architectures using the TensorFlow framework. The single-instruction, multiple-data (SIMD) approach is adopted for parallelization of the TPU solver implementation. The finite difference operators on a colocated structured mesh are cast as filters of the convolution function of TensorFlow, leveraging the TPU’s matrix multiply unit (MXU). The framework takes advantage of the low-latency, high-bandwidth inter-chip interconnect (ICI) between the TPU accelerators. In addition, by leveraging single-precision floating-point computations and a highly optimized executable produced by the accelerated linear algebra (XLA) compiler, it’s possible to perform large-scale simulations with excellent scaling on TPU hardware architectures.
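To illustrate how a finite difference operator can be written as a convolution filter, here is a minimal one-dimensional sketch in plain Python. It does not use TensorFlow or the framework's own code; the helper names are hypothetical, and the sliding-window product is the cross-correlation form that deep learning libraries usually call convolution.

```python
from typing import List

def correlate_valid(field: List[float], stencil: List[float]) -> List[float]:
    """Slide the stencil over the field without padding, as a 1-D conv filter would."""
    width = len(stencil)
    return [sum(s * field[i + j] for j, s in enumerate(stencil))
            for i in range(len(field) - width + 1)]

def central_difference_stencil(h: float) -> List[float]:
    """Second-order central difference for d/dx on a uniform grid with spacing h."""
    return [-1.0 / (2.0 * h), 0.0, 1.0 / (2.0 * h)]

h = 0.5
xs = [i * h for i in range(6)]
field = [x * x for x in xs]          # f(x) = x^2, so df/dx = 2x
dfdx = correlate_valid(field, central_difference_stencil(h))
print(dfdx)                          # derivative at the four interior grid points
```

Because the stencil is just a small fixed filter, the same operation maps directly onto hardware built for convolutions and matrix multiplies.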

This research effort demonstrates that graph-based TensorFlow, in combination with new types of special-purpose ML hardware, can be used as a programming paradigm to solve partial differential equations representing multiphysics flows. The latter is achieved by augmenting the Navier-Stokes equations with physical models to account for chemical reactions, heat transfer, and density changes to enable, for example, simulations of cloud formation and wildfires.

It’s worth noting that this framework is the first open-source computational fluid dynamics (CFD) framework for high-performance, large-scale simulations to fully leverage the cloud accelerators that have become common (and become a commodity) with the advancement of machine learning (ML) in recent years. While our work focuses on using TPU accelerators, the code can be easily adjusted for other accelerators, such as GPU clusters.

This framework demonstrates a way to greatly reduce the cost and turn-around time associated with running large-scale scientific CFD simulations and enables even greater iteration speed in fields such as climate and weather research. Since the framework is implemented using TensorFlow, an ML framework, it also enables ready integration with ML methods and allows the exploration of ML approaches on CFD problems. With the general accessibility of TPU and GPU hardware, this approach lowers the barrier for researchers to contribute to our understanding of large-scale turbulent systems.

Framework validation and homogeneous isotropic turbulence

Beyond demonstrating the performance and the scaling capabilities, it is also critical to validate the correctness of this framework to ensure that when it is used for CFD problems, we get reasonable results. For this purpose, researchers typically use idealized benchmark problems during CFD solver development, many of which we adopted in our work (more details in the paper).

One such benchmark for turbulence analysis is homogeneous isotropic turbulence (HIT), which is a canonical and well-studied flow in which the statistical properties, such as kinetic energy, are invariant under translations and rotations of the coordinate axes. By pushing the resolution to the limits of the current state of the art, we were able to perform direct numerical simulations with more than eight billion degrees of freedom, equivalent to a three-dimensional mesh with 2,048 grid points along each of the three directions. We used 512 TPU-v4 cores, partitioning the grid along the x, y, and z axes across [2, 2, 128] cores, respectively, a layout optimized for performance on TPU. The wall clock time per timestep was around 425 milliseconds and the flow was simulated for a total of 400,000 timesteps. 50 TB of data, which includes the velocity and density fields, is stored for 400 timesteps (every 1,000th step). To our knowledge, this is one of the largest turbulent flow simulations of its kind conducted to date.
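The figures quoted above can be checked with a few lines of arithmetic. The per-core block size assumes the grid is split evenly across cores with no halo cells, and the run-time estimate ignores I/O and startup; both are simplifying assumptions rather than numbers reported in the paper.

```python
grid_points_per_axis = 2_048
core_layout = (2, 2, 128)        # cores along x, y, z
seconds_per_step = 0.425
timesteps = 400_000

total_points = grid_points_per_axis ** 3
total_cores = core_layout[0] * core_layout[1] * core_layout[2]
local_block = tuple(grid_points_per_axis // c for c in core_layout)
wall_clock_hours = seconds_per_step * timesteps / 3600

print(total_points)                  # more than eight billion grid points
print(total_cores)                   # matches the 512 cores used
print(local_block)                   # grid points per core along x, y, z
print(round(wall_clock_hours, 1))    # compute time for the full run, in hours
```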

Due to the complex, chaotic nature of the turbulent flow field, which extends across several orders of magnitude in scale, simulating the system at high resolution is necessary. Because we employ a fine-resolution grid with eight billion points, we are able to accurately resolve the field.

Contours of x-component of velocity along the z midplane. The high resolution of the simulation is critical to accurately represent the turbulent field.

The turbulent kinetic energy and dissipation rate are two statistical quantities commonly used to analyze a turbulent flow. The temporal decay of these properties in a turbulent field without additional energy injection is due to viscous dissipation, and the decay asymptotes follow the expected analytical power law. This is in agreement with the theoretical asymptotes and observations reported in the literature, and thus validates our framework.

Solid line: Temporal evolution of turbulent kinetic energy (k). Dashed line: Analytical power laws for decaying homogeneous isotropic turbulence (n=1.3) (l: eddy turnover time).
Solid line: Temporal evolution of dissipation rate (ε). Dashed line: Analytical power laws for decaying homogeneous isotropic turbulence (n=1.3).

The energy spectrum of a turbulent flow represents the energy content across wavenumber, where the wavenumber k is proportional to the inverse wavelength λ (i.e., k ∝ 1/λ). Generally, the spectrum can be qualitatively divided into three ranges: source range, inertial range and viscous dissipative range (from left to right on the wavenumber axis, below). The lowest wavenumbers in the source range correspond to the largest turbulent eddies, which have the most energy content. These large eddies transfer energy to turbulence in the intermediate wavenumbers (inertial range), which is statistically isotropic (i.e., essentially uniform in all directions). The smallest eddies, corresponding to the largest wavenumbers, are dissipated into thermal energy by the viscosity of the fluid. By virtue of the fine grid having 2,048 points in each of the three spatial directions, we are able to resolve the flow field up to the length scale at which viscous dissipation takes place. This direct numerical simulation approach is the most accurate as it does not require any closure model to approximate the energy cascade below the grid size.

Spectrum of turbulent kinetic energy at different time instances. The spectrum is normalized by the instantaneous integral length (l) and the turbulent kinetic energy (k).

A new era for turbulent flows research

More recently, we extended this framework to predict wildfires and atmospheric flows, which is relevant for climate-risk assessment. Apart from enabling high-fidelity simulations of complex turbulent flows, this simulation framework also provides capabilities for scientific machine learning (SciML) — for example, downsampling from a fine to a coarse grid (model reduction) or building models that run at lower resolution while still capturing the correct dynamic behaviors. It could also provide avenues for further scientific discovery, such as building ML-based models to better parameterize microphysics of turbulent flows, including physical relationships between temperature, pressure, vapor fraction, etc., and could improve upon various control tasks, e.g., to reduce the energy consumption of buildings or find more efficient propeller shapes. While attractive, a main bottleneck in SciML has been the availability of data for training. To explore this, we have been working with groups at Stanford and Kaggle to make the data from our high-resolution HIT simulation available through a community-hosted web-platform, BLASTNet, to provide broad access to high-fidelity data to the research community via a network-of-datasets approach. We hope that the availability of these emerging high-fidelity simulation tools in conjunction with community-driven datasets will lead to significant advances in various areas of fluid mechanics.

Acknowledgements

We would like to thank Qing Wang, Yi-Fan Chen, and John Anderson for consulting and advice, and Tyler Russell and Carla Bromberg for program management.


TSMixer: An all-MLP architecture for time series forecasting

Time series forecasting is critical to various real-world applications, from demand forecasting to pandemic spread prediction. In multivariate time series forecasting (forecasting multiple variables at the same time), one can split existing methods into two categories: univariate models and multivariate models. Univariate models focus on intra-series interactions, or temporal patterns, which encompass trends and seasonal patterns on a time series with a single variable. Examples of such trends and seasonal patterns might be the way mortgage rates increase due to inflation, and how traffic peaks during rush hour. In addition to intra-series patterns, multivariate models process inter-series features, known as cross-variate information, which is especially useful when one series is a leading indicator of another series. For example, a rise in body weight may cause an increase in blood pressure, and increasing the price of a product may lead to a decrease in sales. Multivariate models have recently become popular solutions for multivariate forecasting as practitioners believe their capability of handling cross-variate information may lead to better performance.

In recent years, deep learning Transformer-based architectures have become a popular choice for multivariate forecasting models due to their superior performance on sequence tasks. However, advanced multivariate models surprisingly perform worse than simple univariate linear models on commonly used long-term forecasting benchmarks, such as Electricity Transformer Temperature (ETT), Electricity, Traffic, and Weather. These results raise two questions:

  • Does cross-variate information benefit time series forecasting?
  • When cross-variate information is not beneficial, can multivariate models still perform as well as univariate models?

In “TSMixer: An All-MLP Architecture for Time Series Forecasting”, we analyze the advantages of univariate linear models and reveal their effectiveness. Insights from this analysis lead us to develop Time-Series Mixer (TSMixer), an advanced multivariate model that leverages linear model characteristics and performs well on long-term forecasting benchmarks. To the best of our knowledge, TSMixer is the first multivariate model that performs as well as state-of-the-art univariate models on long-term forecasting benchmarks, where we show that cross-variate information is less beneficial. To demonstrate the importance of cross-variate information, we evaluate a more challenging real-world application, M5. Finally, empirical results show that TSMixer outperforms state-of-the-art models, such as PatchTST, FEDformer, Autoformer, DeepAR and TFT.

TSMixer architecture

A key difference between linear models and Transformers is how they capture temporal patterns. On one hand, linear models apply fixed and time-step-dependent weights to capture static temporal patterns, and are unable to process cross-variate information. On the other hand, Transformers use attention mechanisms that apply dynamic and data-dependent weights at each time step, capturing dynamic temporal patterns and enabling them to process cross-variate information.

In our analysis, we show that under common assumptions of temporal patterns, linear models have naïve solutions to perfectly recover the time series or place bounds on the error, which makes them effective at learning static temporal patterns of univariate time series. In contrast, it is non-trivial to find similar solutions for attention mechanisms, as the weights applied to each time step are dynamic. Consequently, we develop a new architecture by replacing Transformer attention layers with linear layers. The resulting TSMixer model, which is similar to the computer vision MLP-Mixer method, alternates between applications of the multi-layer perceptron in different directions, which we call time-mixing and feature-mixing, respectively. The TSMixer architecture efficiently captures both temporal patterns and cross-variate information, as shown in the figure below. The residual designs ensure that TSMixer retains the capacity of temporal linear models while still being able to exploit cross-variate information.

Transformer block and TSMixer block architectures. TSMixer replaces the multi-head attention layer with time-mixing, a linear model applied on the time dimension.
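A mixing block of this kind can be sketched in plain Python on a small (time steps × features) matrix. This is a simplified illustration, not the released implementation: the ReLU activation and the parameter names are assumptions, and any normalization or dropout a practical model might use is left out.

```python
from typing import List

Matrix = List[List[float]]   # shape: (time steps, features)

def relu(v: float) -> float:
    return v if v > 0.0 else 0.0

def time_mixing(x: Matrix, w_time: Matrix, b_time: List[float]) -> Matrix:
    """Linear layer over the time axis, shared by all features, plus residual."""
    steps, feats = len(x), len(x[0])
    out = [row[:] for row in x]
    for c in range(feats):
        for t in range(steps):
            mixed = b_time[t] + sum(w_time[t][s] * x[s][c] for s in range(steps))
            out[t][c] += relu(mixed)
    return out

def feature_mixing(x: Matrix, w1: Matrix, b1: List[float],
                   w2: Matrix, b2: List[float]) -> Matrix:
    """Two-layer MLP over the feature axis, shared by all time steps, plus residual."""
    out = []
    for row in x:
        hidden = [relu(b1[j] + sum(row[c] * w1[c][j] for c in range(len(row))))
                  for j in range(len(b1))]
        mixed = [b2[c] + sum(hidden[j] * w2[j][c] for j in range(len(hidden)))
                 for c in range(len(row))]
        out.append([r + m for r, m in zip(row, mixed)])
    return out

def mixer_block(x, w_time, b_time, w1, b1, w2, b2) -> Matrix:
    return feature_mixing(time_mixing(x, w_time, b_time), w1, b1, w2, b2)

x = [[1.0, 2.0], [3.0, 4.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
shift = [[0.0, 1.0], [0.0, 0.0]]   # time step 0 reads from time step 1
print(mixer_block(x, shift, [0.0, 0.0], zeros, [0.0, 0.0], zeros, [0.0, 0.0]))
```

With all mixing weights set to zero the block reduces to the identity, which shows how the residual connections let the model fall back to simpler behavior.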

Comparison between data-dependent (attention mechanisms) and time-step-dependent (linear models). This is an example of forecasting the next time step by learning the weights of the previous three time steps.
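The difference between the two kinds of weights can be shown with a toy forecast of the next value from the previous three. The linear weights below are one hand-picked solution for a linear trend, and the similarity score used for the attention-style weights is an assumption made for illustration, not the scoring function of any particular Transformer.

```python
import math
from typing import List

def linear_forecast(history: List[float], weights: List[float]) -> float:
    """Time-step-dependent weights: fixed per position, whatever the values are."""
    return sum(w * v for w, v in zip(weights, history))

def attention_weights(history: List[float]) -> List[float]:
    """Data-dependent weights: softmax of similarity to the latest value."""
    query = history[-1]
    scores = [-(v - query) ** 2 for v in history]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

trend_weights = [0.0, -1.0, 2.0]   # next = 2 * last - second-to-last
print(linear_forecast([1.0, 2.0, 3.0], trend_weights))
print(linear_forecast([10.0, 20.0, 30.0], trend_weights))
print(attention_weights([1.0, 2.0, 3.0]))
print(attention_weights([1.0, 1.0, 1.0]))
```

The same fixed weights extrapolate any linear trend exactly, while the attention-style weights change whenever the input values change.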

Evaluation on long-term forecasting benchmarks

We evaluate TSMixer using seven popular long-term forecasting datasets (ETTm1, ETTm2, ETTh1, ETTh2, Electricity, Traffic, and Weather), where recent research has shown that univariate linear models outperform advanced multivariate models by large margins. We compare TSMixer with state-of-the-art multivariate models (TFT, FEDformer, Autoformer, Informer), and univariate models, including linear models and PatchTST. The figure below shows the average improvement in mean squared error (MSE) of TSMixer compared with others. The average is calculated across datasets and multiple forecasting horizons. We demonstrate that TSMixer significantly outperforms other multivariate models and performs on par with state-of-the-art univariate models. These results show that multivariate models are capable of performing as well as univariate models.

The average MSE improvement of TSMixer compared with other baselines. The red bars show multivariate methods and the blue bars show univariate methods. TSMixer achieves significant improvement over other multivariate models and achieves comparable results to univariate models.

Ablation study

We performed an ablation study to compare TSMixer with TMix-Only, a TSMixer variant that consists of time-mixing layers only. The results show that TMix-Only performs almost the same as TSMixer, which means the additional feature-mixing layers do not improve the performance and confirms that cross-variate information is less beneficial on popular benchmarks. The results validate the superior univariate model performance shown in previous research. However, existing long-term forecasting benchmarks do not adequately represent the need for cross-variate information in some real-world applications where time series may be intermittent or sparse, so that temporal patterns alone may not be sufficient for forecasting. Therefore, it may be inappropriate to evaluate multivariate forecasting models solely on these benchmarks.

Evaluation on M5: Effectiveness of cross-variate information

To further demonstrate the benefit of multivariate models, we evaluate TSMixer on the challenging M5 benchmark, a large-scale retail dataset containing crucial cross-variate interactions. M5 contains the information of 30,490 products collected over 5 years. Each product description includes time series data, like daily sales, sell price, promotional event information, and static (non-time-series) features, such as store location and product category. The goal is to forecast the daily sales of each product for the next 28 days, evaluated using the weighted root mean square scaled error (WRMSSE) from the M5 competition. The complicated nature of retail makes it more challenging to forecast solely using univariate models that focus on temporal patterns, so multivariate models with cross-variate information and even auxiliary features are more essential.
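For reference, the evaluation metric can be sketched as follows. This is a simplified version of the metric: the per-series weights are taken as given inputs assumed to sum to 1, and competition-specific details such as how the weights are derived or how leading zero-sales periods are handled are not reproduced here.

```python
import math
from typing import List, Sequence

def rmsse(history: Sequence[float], actual: Sequence[float],
          forecast: Sequence[float]) -> float:
    """Forecast MSE scaled by the MSE of a one-step naive forecast on the history."""
    naive_mse = sum((history[t] - history[t - 1]) ** 2
                    for t in range(1, len(history))) / (len(history) - 1)
    forecast_mse = sum((a - f) ** 2 for a, f in zip(actual, forecast)) / len(actual)
    return math.sqrt(forecast_mse / naive_mse)

def wrmsse(histories: List[Sequence[float]], actuals: List[Sequence[float]],
           forecasts: List[Sequence[float]], weights: List[float]) -> float:
    """Weighted sum of per-series RMSSE values."""
    return sum(w * rmsse(h, a, f)
               for w, h, a, f in zip(weights, histories, actuals, forecasts))

history = [1.0, 2.0, 3.0, 4.0]
actual = [5.0, 6.0]
print(rmsse(history, actual, [5.0, 8.0]))
print(wrmsse([history, history], [actual, actual],
             [[5.0, 8.0], [5.0, 6.0]], weights=[0.5, 0.5]))
```

Scaling by the naive forecast error makes series with very different sales volumes comparable before they are combined.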

First, we compare TSMixer to other methods considering only the historical data, such as daily sales and historical sell prices. The results show that multivariate models outperform univariate models significantly, indicating the usefulness of cross-variate information. Among all compared methods, TSMixer effectively leverages the cross-variate information and achieves the best performance.

Additionally, to leverage more information, such as static features (e.g., store location, product category) and future time series (e.g., a promotional event scheduled in coming days) provided in M5, we propose a principled design to extend TSMixer. The extended TSMixer aligns different types of features into the same length, and then applies multiple mixing layers to the concatenated features to make predictions. The extended TSMixer architecture outperforms models popular in industrial applications, including DeepAR and TFT, showcasing its strong potential for real-world impact.

The architecture of the extended TSMixer. In the first stage (align stage), it aligns the different types of features into the same length before concatenating them. In the second stage (mixing stage) it applies multiple mixing layers conditioned with static features.

The WRMSSE on M5. The first three methods (blue) are univariate models. The middle three methods (orange) are multivariate models that consider only historical features. The last three methods (red) are multivariate models that consider historical, future, and static features.

Conclusion

We present TSMixer, an advanced multivariate model that leverages linear model characteristics and performs as well as state-of-the-art univariate models on long-term forecasting benchmarks. TSMixer creates new possibilities for the development of time series forecasting architectures by providing insights into the importance of cross-variate and auxiliary information in real-world scenarios. The empirical results highlight the need to consider more realistic benchmarks for multivariate forecasting models in future research. We hope that this work will inspire further exploration in the field of time series forecasting, and lead to the development of more powerful and effective models that can be applied to real-world applications.

Acknowledgements

This research was conducted by Si-An Chen, Chun-Liang Li, Nate Yoder, Sercan O. Arik, and Tomas Pfister.


This demo tests your understanding of light | Barber pole, part 1