Categories
Offsites

Cappy: Outperforming and boosting large multi-task language models with a small scorer

Large language model (LLM) advancements have led to a new paradigm that unifies various natural language processing (NLP) tasks within an instruction-following framework. This paradigm is exemplified by recent multi-task LLMs, such as T0, FLAN, and OPT-IML. First, multi-task data is gathered with each task following a task-specific template, where each labeled example is converted into an instruction (e.g., Put the concepts together to form a sentence: ski, mountain, skier) paired with a corresponding response (e.g., Skier skis down the mountain). These instruction-response pairs are used to train the LLM, resulting in a conditional generation model that takes an instruction as input and generates a response. Moreover, multi-task LLMs have exhibited remarkable task-wise generalization capabilities as they can address unseen tasks by understanding and solving brand-new instructions.
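As a minimal sketch of the template step described above, the snippet below converts one labeled example into an instruction-response pair. The `InstructionPair` class and `apply_template` function are hypothetical names introduced for illustration; only the template wording and the example come from the text.

```python
from dataclasses import dataclass


@dataclass
class InstructionPair:
    instruction: str
    response: str


def apply_template(concepts, sentence):
    """Convert one labeled example into an instruction-response pair."""
    # The template wording follows the example given in the text.
    instruction = "Put the concepts together to form a sentence: " + ", ".join(concepts)
    return InstructionPair(instruction=instruction, response=sentence)


pair = apply_template(["ski", "mountain", "skier"], "Skier skis down the mountain")
print(pair.instruction)
print(pair.response)
```

Pairs of this form are what the LLM is trained on: the instruction is the input and the response is the generation target.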

An illustration of the instruction-following pre-training of multi-task LLMs, e.g., FLAN. Pre-training on tasks under this paradigm improves performance on unseen tasks.

Due to the complexity of understanding and solving various tasks solely using instructions, the size of multi-task LLMs typically spans from several billion parameters to hundreds of billions (e.g., FLAN-11B, T0-11B and OPT-IML-175B). As a result, operating such sizable models poses significant challenges because they demand considerable computational power and impose substantial requirements on the memory capacities of GPUs and TPUs, making their training and inference expensive and inefficient. Extensive storage is required to maintain a unique LLM copy for each downstream task. Moreover, the most powerful multi-task LLMs (e.g., FLAN-PaLM-540B) are closed-source, making them impossible to adapt. However, in practical applications, harnessing a single multi-task LLM to manage all conceivable tasks in a zero-shot manner remains difficult, particularly when dealing with complex tasks, personalized tasks and those that cannot be succinctly defined using instructions. On the other hand, the size of downstream training data is usually insufficient to train a model well without incorporating rich prior knowledge. Hence, it has long been desirable to adapt LLMs with downstream supervision while bypassing storage, memory, and access issues.

Certain parameter-efficient tuning strategies, including prompt tuning and adapters, substantially diminish storage requirements, but they still perform back-propagation through LLM parameters during the tuning process, thereby keeping their memory demands high. Additionally, some in-context learning techniques circumvent parameter tuning by integrating a limited number of supervised examples into the instruction. However, these techniques are constrained by the model’s maximum input length, which permits only a few samples to guide task resolution.

In “Cappy: Outperforming and Boosting Large Multi-Task LMs with a Small Scorer”, presented at NeurIPS 2023, we propose a novel approach that enhances the performance and efficiency of multi-task LLMs. We introduce a lightweight pre-trained scorer, Cappy, based on continual pre-training on top of RoBERTa with merely 360 million parameters. Cappy takes in an instruction and a candidate response as input, and produces a score between 0 and 1, indicating an estimated correctness of the response with respect to the instruction. Cappy either functions independently on classification tasks or serves as an auxiliary component for LLMs, boosting their performance. Moreover, Cappy efficiently enables downstream supervision without requiring any fine-tuning of the LLM itself, which avoids the need for back-propagation through LLM parameters and reduces memory requirements. Finally, adaptation with Cappy doesn’t require access to LLM parameters, so it is compatible with closed-source multi-task LLMs, such as those only accessible via web APIs.

Cappy takes an instruction and response pair as input and outputs a score ranging from 0 to 1, indicating an estimation of the correctness of the response with respect to the instruction.

Pre-training

We begin with the same collection of 39 diverse datasets from PromptSource that was used to train T0. This collection encompasses a wide range of task types, such as question answering, sentiment analysis, and summarization. Each dataset is associated with one or more templates that convert each instance from the original datasets into an instruction paired with its ground truth response.

Cappy’s regression modeling requires each pre-training data instance to include an instruction-response pair along with a correctness annotation for the response, so we produce a dataset with correctness annotations that range from 0 to 1. For every instance within a generation task, we leverage an existing multi-task LLM to generate multiple responses by sampling, conditioned on the given instruction. Subsequently, we assign an annotation to the pair formed by the instruction and every response, using the similarity between the response and the ground truth response of the instance. Specifically, we employ Rouge-L, a commonly-used metric for measuring overall multi-task performance that has demonstrated a strong alignment with human evaluation, to calculate this similarity as a form of weak supervision.
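The annotation step above can be sketched as follows. This is a simplified illustration, assuming whitespace tokenization and the F1 form of Rouge-L; the exact Rouge-L variant used for Cappy is not stated here, and the function names and sampled responses are hypothetical.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(candidate, reference):
    """Rouge-L F1 over lowercased whitespace tokens (a simplifying assumption)."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    if not cand or not ref:
        return 0.0
    lcs = lcs_length(cand, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


def annotate(instruction, sampled_responses, ground_truth):
    """Build (instruction, response, score) regression instances."""
    return [(instruction, r, rouge_l_f1(r, ground_truth)) for r in sampled_responses]


instruction = "Put the concepts together to form a sentence: ski, mountain, skier"
ground_truth = "Skier skis down the mountain"
# Hypothetical responses standing in for samples drawn from a multi-task LLM.
samples = ["Skier skis down the mountain", "A skier skis on the mountain", "The cat sleeps"]
scores = [score for _, _, score in annotate(instruction, samples, ground_truth)]
print(scores)
```

Each triple becomes one regression instance: the scorer sees the instruction and response, and is trained to predict the similarity score.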

As a result, we obtain an effective regression dataset of 160 million instances paired with correctness score annotations. The final Cappy model is the result of continual pre-training using the regression dataset on top of the RoBERTa model. The pre-training of Cappy is conducted on Google’s TPU-v4, with RedCoast, a lightweight toolkit for automating distributed training.

Data augmentation with a multi-task LLM to construct a weakly supervised regression dataset for Cappy’s pre-training and fine-tuning.

Applying Cappy

Cappy solves practical tasks within a candidate-selection mechanism. More specifically, given an instruction and a set of candidate responses, Cappy produces a score for each candidate response. This is achieved by inputting the instruction alongside each individual response, and then assigning the response with the highest score as its prediction. In classification tasks, all candidate responses are inherently predefined. For example, for an instruction of a sentiment classification task (e.g., “Based on this review, would the user recommend this product?: ‘Stunning even for the non-gamer.’”), the candidate responses are “Yes” or “No”. In such scenarios, Cappy functions independently. On the other hand, in generation tasks, candidate responses are not pre-defined, requiring an existing multi-task LLM to yield the candidate responses. In this case, Cappy serves as an auxiliary component of the multi-task LLM, enhancing its decoding.
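The candidate-selection mechanism can be sketched in a few lines. The `toy_scorer` below is a hypothetical keyword-based stand-in and not Cappy itself; in practice the scorer would be the pre-trained model, and for generation tasks the candidates would come from a multi-task LLM.

```python
from typing import Callable, Sequence


def select_best(instruction: str,
                candidates: Sequence[str],
                scorer: Callable[[str, str], float]) -> str:
    """Score every candidate response and return the highest-scoring one."""
    if not candidates:
        raise ValueError("at least one candidate response is required")
    scores = [scorer(instruction, candidate) for candidate in candidates]
    best_index = max(range(len(candidates)), key=scores.__getitem__)
    return candidates[best_index]


def toy_scorer(instruction: str, response: str) -> float:
    """Hypothetical stand-in for a scorer: returns a value between 0 and 1."""
    looks_positive = any(word in instruction.lower() for word in ("stunning", "great"))
    if response == "Yes":
        return 0.9 if looks_positive else 0.1
    return 0.1 if looks_positive else 0.9


instruction = ("Based on this review, would the user recommend this product?: "
               "'Stunning even for the non-gamer.'")
prediction = select_best(instruction, ["Yes", "No"], toy_scorer)
print(prediction)
```

The same function covers both settings: only the source of the candidate list changes between classification and generation.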

Adapting multi-task LLMs with Cappy

When there is available downstream training data, Cappy enables effective and efficient adaptation of multi-task LLMs on downstream tasks. Specifically, we fine-tune Cappy to integrate downstream task information into LLM predictions. This process involves creating a separate regression dataset specific to the downstream training data with the same data annotation process used to construct the pre-training data. As a result, the fine-tuned Cappy collaborates with a multi-task LLM, boosting the LLM’s performance on the downstream task.

In contrast to other LLM tuning strategies, adapting LLMs with Cappy significantly reduces the high demand for device memory as it avoids the need for back-propagation through LLM parameters for downstream tasks. Moreover, Cappy adaptation does not rely on access to LLM parameters, making it compatible with closed-source multi-task LLMs, such as the ones only accessible via web APIs. Compared with in-context learning approaches, which circumvent model tuning by attaching training examples to the instruction prefix, Cappy is not restricted by the LLM’s maximum input length. Thus, Cappy can incorporate an unlimited number of downstream training examples. Cappy can also be applied with other adaptation methods, such as fine-tuning and in-context learning, further boosting their overall performance.

Downstream adaptation comparison between Cappy and approaches that rely on an LLM’s parameters, such as fine-tuning and prompt tuning. Cappy’s application enhances multi-task LLMs.

Results

We assess Cappy’s performance across eleven held-out language understanding classification tasks from PromptSource. We demonstrate that Cappy, with 360M parameters, outperforms OPT-175B and OPT-IML-30B, and matches the accuracy of the best existing multi-task LLMs (T0-11B and OPT-IML-175B). These findings highlight Cappy’s capabilities and parameter efficiency, which can be credited to its scoring-based pre-training strategy that integrates contrastive information by differentiating between high-quality and low-quality responses. In contrast, previous multi-task LLMs depend exclusively on teacher-forcing training that utilizes only the ground truth responses.

The overall accuracy averaged over eleven test tasks from PromptSource. “RM” refers to a pre-trained RLHF reward model. Cappy matches the best ones among existing multi-task LLMs.

We also examine the adaptation of multi-task LLMs with Cappy on complex tasks from BIG-Bench, a set of manually curated tasks that are considered beyond the capability of many LLMs. We focus on all 45 generation tasks in BIG-Bench, specifically those that do not offer pre-established answer choices. We evaluate the performance using the Rouge-L score (representing the overall similarity between model generations and corresponding ground truths) on every test set, reporting the average score across the 45 tests. In this experiment, all variants of FLAN-T5 serve as the backbone LLMs, and the foundational FLAN-T5 models are frozen. These results, shown below, suggest that Cappy enhances the performance of FLAN-T5 models by a large margin, consistently outperforming the most effective baseline achieved through sample selection using self-scoring of the LLM itself.

The averaged Rouge-L score over 45 complex tasks within BIG-Bench. The x-axis refers to FLAN-T5 models of different sizes. Every dashed line represents an approach working on FLAN-T5s. Self-scoring refers to using the cross-entropy of the LLM to select responses. Cappy enhances the performance of FLAN-T5 models by a large margin.

Conclusion

We introduce Cappy, a novel approach that enhances the performance and efficiency of multi-task LLMs. In our experiments, we adapt a single LLM to several domains with Cappy. In the future, Cappy as a pre-trained model can potentially be used in other creative ways beyond single LLMs.

Acknowledgments

Thanks to Bowen Tan, Jindong Chen, Lei Meng, Abhanshu Sharma and Ewa Dominowska for their valuable feedback. We would also like to thank Eric Xing and Zhiting Hu for their suggestions.


Talk like a graph: Encoding graphs for large language models

Imagine all the things around you — your friends, tools in your kitchen, or even the parts of your bike. They are all connected in different ways. In computer science, the term graph is used to describe connections between objects. Graphs consist of nodes (the objects themselves) and edges (connections between two nodes, indicating a relationship between them). Graphs are everywhere now. The internet itself is a giant graph of websites linked together. Even the knowledge search engines use is organized in a graph-like way.

Furthermore, consider the remarkable advancements in artificial intelligence — such as chatbots that can write stories in seconds, and even software that can interpret medical reports. This exciting progress is largely thanks to large language models (LLMs). New LLM technology is constantly being developed for different uses.

Since graphs are everywhere and LLM technology is on the rise, in “Talk like a Graph: Encoding Graphs for Large Language Models”, presented at ICLR 2024, we present a way to teach powerful LLMs how to better reason with graph information. Graphs are a useful way to organize information, but LLMs are mostly trained on regular text. The objective is to test different techniques to see what works best and gain practical insights. Translating graphs into text that LLMs can understand is a remarkably complex task. The difficulty stems from the inherent complexity of graph structures with multiple nodes and the intricate web of edges that connect them. Our work studies how to take a graph and translate it into a format that an LLM can understand. We also design a benchmark called GraphQA to study different approaches on different graph reasoning problems and show how to phrase a graph-related problem in a way that enables the LLM to solve the graph problem. We show that LLM performance on graph reasoning tasks varies on three fundamental levels: 1) the graph encoding method, 2) the nature of the graph task itself, and 3) interestingly, the very structure of the graph considered. These findings give us clues on how to best represent graphs for LLMs. Picking the right method can make the LLM up to 60% better at graph tasks!

Pictured, the process of encoding a graph as text using two different approaches and feeding the text and a question about the graph to the LLM.

Graphs as text

To systematically find out the best way to translate a graph to text, we first design a benchmark called GraphQA. Think of GraphQA as an exam designed to evaluate powerful LLMs on graph-specific problems. We want to see how well LLMs can understand and solve problems that involve graphs in different setups. To create a comprehensive and realistic exam for LLMs, we don’t just use one type of graph; we use a mix of graphs, ensuring breadth in the number of connections. This is mainly because different graph types make solving such problems easier or harder. This way, GraphQA can help expose biases in how an LLM thinks about the graphs, and the whole exam gets closer to a realistic setup that LLMs might encounter in the real world.

Overview of our framework for reasoning with graphs using LLMs.

GraphQA focuses on simple tasks related to graphs, like checking if an edge exists, calculating the number of nodes or edges, finding nodes that are connected to a specific node, and checking for cycles in a graph. These tasks might seem basic, but they require understanding the relationships between nodes and edges. By covering different types of challenges, from identifying patterns to creating new connections, GraphQA helps models learn how to analyze graphs effectively. These basic tasks are crucial for more complex reasoning on graphs, like finding the shortest path between nodes, detecting communities, or identifying influential nodes. Additionally, GraphQA includes generating random graphs using various algorithms like Erdős–Rényi, scale-free networks, the Barabási–Albert model, and the stochastic block model, as well as simpler graph structures like paths, complete graphs, and star graphs, providing a diverse set of data for training.
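To make the basic tasks concrete, here is a minimal sketch of how ground-truth answers could be computed for a few of them on generated graphs. It assumes simple undirected graphs stored as edge lists; the function names are illustrative and this is not the GraphQA code itself.

```python
import random
from itertools import combinations


def erdos_renyi(n, p, seed=0):
    """Erdős–Rényi graph: each possible edge is included with probability p."""
    rng = random.Random(seed)
    return [(u, v) for u, v in combinations(range(n), 2) if rng.random() < p]


def path_graph(n):
    return [(i, i + 1) for i in range(n - 1)]


def complete_graph(n):
    return list(combinations(range(n), 2))


def edge_exists(edges, u, v):
    """Edge existence, ignoring the order of the endpoints."""
    return (min(u, v), max(u, v)) in {(min(a, b), max(a, b)) for a, b in edges}


def node_degree(edges, node):
    return sum(node in edge for edge in edges)


def has_cycle(n, edges):
    """Cycle check with union-find: an edge inside one component closes a cycle."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for u, v in edges:
        root_u, root_v = find(u), find(v)
        if root_u == root_v:
            return True
        parent[root_u] = root_v
    return False


print(has_cycle(4, path_graph(4)), has_cycle(4, complete_graph(4)))
```

Answers computed this way can be compared against the LLM's reply to the corresponding question.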

When working with graphs, we also need to find ways to ask graph-related questions that LLMs can understand. Prompting heuristics are different strategies for doing this. Let’s break down the common ones:

  • Zero-shot: Simply describe the task (“Is there a cycle in this graph?”) and tell the LLM to go for it. No examples provided.
  • Few-shot: This is like giving the LLM a mini practice test before the real deal. We provide a few example graph questions and their correct answers.
  • Chain-of-Thought: Here, we show the LLM how to break down a problem step-by-step with examples. The goal is to teach it to generate its own “thought process” when faced with new graphs.
  • Zero-CoT: Similar to CoT, but instead of training examples, we give the LLM a simple prompt, like “Let’s think step-by-step,” to trigger its own problem-solving breakdown.
  • BAG (build a graph): This is specifically for graph tasks. We add the phrase “Let’s build a graph…” to the description, helping the LLM focus on the graph structure.
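The prompting heuristics listed above can be sketched as a single prompt builder. The layout of the prompt (the "Q:" and "A:" markers and where each phrase is placed) is an assumption made for illustration; only the quoted phrases come from the list above.

```python
METHODS = {"zero-shot", "few-shot", "cot", "zero-cot", "bag"}


def build_prompt(graph_text, question, method="zero-shot", examples=()):
    """Assemble a prompt for one graph question under a given heuristic."""
    if method not in METHODS:
        raise ValueError(f"unknown method: {method}")
    parts = []
    if method in {"few-shot", "cot"}:
        # Solved examples come first; for "cot" each example should also
        # spell out its intermediate reasoning steps.
        parts.extend(examples)
    description = graph_text
    if method == "bag":
        # The rest of this phrase is elided in the text, so it is kept as-is.
        description += " Let's build a graph…"
    parts.append(description)
    parts.append("Q: " + question)
    parts.append("A: Let's think step-by-step." if method == "zero-cot" else "A:")
    return "\n".join(parts)


graph = "Node 0 is connected to nodes 1, 2."
question = "Is there a cycle in this graph?"
example = "Node 0 is connected to node 1.\nQ: Is there a cycle in this graph?\nA: No."
print(build_prompt(graph, question, "few-shot", [example]))
```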

We explored different ways to translate graphs into text that LLMs can work with. Our key questions were:

  • Node encoding: How do we represent individual nodes? Options tested include simple integers, common names (people, characters), and letters.
  • Edge encoding: How do we describe the relationships between nodes? Methods involved parenthesis notation, phrases like “are friends”, and symbolic representations like arrows.

Various node and edge encodings were combined systematically. This led to functions like the ones in the following figure:

Examples of graph encoding functions used to encode graphs via text.
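As a rough illustration of how node and edge encodings combine, the sketch below implements three text encoders. The exact sentence templates are assumptions and not the wording used in the benchmark; `encode_adjacency`, `encode_incident`, and `encode_friendship` are hypothetical names.

```python
from collections import defaultdict


def _namer(names):
    """Node encoding: integer IDs by default, or a supplied list of names."""
    return (lambda i: names[i]) if names else str


def encode_adjacency(n, edges, names=None):
    """Edge encoding with parenthesis notation."""
    name = _namer(names)
    node_text = ", ".join(name(i) for i in range(n))
    edge_text = " ".join(f"({name(u)}, {name(v)})" for u, v in edges)
    return f"The nodes are {node_text}. The edges are: {edge_text}."


def encode_incident(n, edges, names=None):
    """Edge encoding that lists, for each node, the nodes it is connected to."""
    name = _namer(names)
    neighbors = defaultdict(list)
    for u, v in edges:
        neighbors[u].append(v)
        neighbors[v].append(u)
    lines = []
    for i in range(n):
        if neighbors[i]:
            listed = ", ".join(name(j) for j in sorted(neighbors[i]))
            lines.append(f"Node {name(i)} is connected to nodes {listed}.")
    return "\n".join(lines)


def encode_friendship(edges, names):
    """Edge encoding using common names and the phrase "are friends"."""
    return " ".join(f"{names[u]} and {names[v]} are friends." for u, v in edges)


edges = [(0, 1), (0, 2)]
print(encode_adjacency(3, edges))
print(encode_incident(3, edges))
print(encode_friendship(edges, ["Ann", "Bob", "Cat"]))
```

All three functions describe the same graph; only the textual surface differs, which is the variable the experiments below examine.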

Analysis and results

We carried out three key experiments: one to test how LLMs handle graph tasks, and two to understand how the size of the LLM and different graph shapes affected performance. We ran all our experiments on GraphQA.

How LLMs handle graph tasks

In this experiment, we tested how well pre-trained LLMs tackle graph problems like identifying connections, cycles, and node degrees. Here is what we learned:

  • LLMs struggle: On most of these basic tasks, LLMs did not do much better than a random guess.
  • Encoding matters significantly: How we represent the graph as text has a great effect on LLM performance. The “incident” encoding excelled for most of the tasks in general.

Our results are summarized in the following chart.

Comparison of various graph encoder functions based on their accuracy on different graph tasks. The main conclusion from this figure is that the graph encoding functions matter significantly.

Bigger is (usually) better

In this experiment, we wanted to see if the size of the LLM (in terms of the number of parameters) affects how well they can handle graph problems. For that, we tested the same graph tasks on the XXS, XS, S, and L sizes of PaLM 2. Here is a summary of our findings:

  • In general, bigger models did better on graph reasoning tasks. It seems like the extra parameters gave them space to learn more complex patterns.
  • Oddly, size didn’t matter as much for the “edge existence” task (finding out if two nodes in a graph are connected).
  • Even the biggest LLM couldn’t consistently beat a simple baseline solution on the cycle check problem (finding out if a graph contains a cycle or not). This shows LLMs still have room to improve with certain graph tasks.

Effect of model capacity on graph reasoning tasks for PaLM 2-XXS, XS, S, and L.

Do different graph shapes confuse LLMs?

We wondered if the “shape” of a graph (how nodes are connected) influences how well LLMs can solve problems on it. Think of the following figure as different examples of graph shapes.

Samples of graphs generated with different graph generators from GraphQA. ER, BA, SBM, and SFN refer to Erdős–Rényi, Barabási–Albert, Stochastic Block Model, and Scale-Free Network respectively.

We found that graph structure has a big impact on LLM performance. For example, in a task asking if a cycle exists, LLMs did great on tightly interconnected graphs (cycles are common there) but struggled on path graphs (where cycles never happen). Interestingly, providing some mixed examples helped the LLM adapt. For instance, for cycle check, we added some examples containing a cycle and some examples with no cycles as few-shot examples in our prompt. Similar patterns occurred with other tasks.

Comparing different graph generators on different graph tasks. The main observation here is that graph structure has a significant impact on the LLM’s performance. ER, BA, SBM, and SFN refer to Erdős–Rényi, Barabási–Albert, Stochastic Block Model, and Scale-Free Network respectively.

Conclusion

In short, we dug deep into how to best represent graphs as text so LLMs can understand them. We found three major factors that make a difference:

  • How to translate the graph to text: How we represent the graph as text significantly influences LLM performance. The incident encoding excelled for most of the tasks in general.
  • Task type: Certain types of graph questions tend to be harder for LLMs, even with a good translation from graph to text.
  • Graph structure: Surprisingly, the “shape” of the graph on which we do inference (dense with connections, sparse, etc.) influences how well an LLM does.

This study revealed key insights about how to prepare graphs for LLMs. The right encoding techniques can significantly boost an LLM’s accuracy on graph problems (ranging from around 5% to over 60% improvement). Our new benchmark, GraphQA, will help drive further research in this area.

Acknowledgements

We would like to express our gratitude to our co-author, Jonathan Halcrow, for his valuable contributions to this work. We express our sincere gratitude to Anton Tsitsulin, Dustin Zelle, Silvio Lattanzi, Vahab Mirrokni, and the entire graph mining team at Google Research, for their insightful comments, thorough proofreading, and constructive feedback which greatly enhanced the quality of our work. We would also like to extend special thanks to Tom Small for creating the animation used in this post.


Chain-of-table: Evolving tables in the reasoning chain for table understanding

People use tables every day to organize and interpret complex information in a structured, easily accessible format. Due to the ubiquity of such tables, reasoning over tabular data has long been a central topic in natural language processing (NLP). Researchers in this field have aimed to leverage language models to help users answer questions, verify statements, and analyze data based on tables. However, language models are trained over large amounts of plain text, so the inherently structured nature of tabular data can be difficult for language models to fully comprehend and utilize.

Recently, large language models (LLMs) have achieved outstanding performance across diverse natural language understanding (NLU) tasks by generating reliable reasoning chains, as shown in works like Chain-of-Thought and Least-to-Most. However, the most suitable way for LLMs to reason over tabular data remains an open question.

In “Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding”, we propose a framework to tackle table understanding tasks, where we guide LLMs to outline their reasoning step by step, updating a given table iteratively to reflect each part of a thought process, akin to how people solve table-based problems. This enables the LLM to transform the table into simpler and more manageable segments so that it can understand and analyze each part of the table in depth. This approach has yielded significant improvements and achieved new state-of-the-art results on the WikiTQ, TabFact, and FeTaQA benchmarks. The figure below shows the high-level overview of the proposed Chain-of-Table and other methods.

Given a complex table where a cyclist’s nationality and name are in the same cell, (a) generic, multi-step reasoning is unable to provide the correct answer; (b) program-aided reasoning generates and executes programs (e.g., SQL queries) to deliver the answer, but falls short in accurately addressing the question. In contrast, (c) Chain-of-Table iteratively samples a chain of operations that effectively transform the complex table into a version specifically tailored to the question.

Chain-of-Table

In Chain-of-Table, we guide LLMs using in-context learning to iteratively generate operations and to update the table to represent its reasoning chain over tabular data. This enables LLMs to dynamically plan the next operation based on the results of previous ones. This continuous evolution of the table forms a chain, which provides a more structured and clear representation of the reasoning process for a given problem and enables more accurate and reliable predictions from the LLM.

For example, when asked, “Which actor has the most NAACP image awards?” the Chain-of-Table framework prompts an LLM to generate tabular operations mirroring tabular reasoning processes. It first identifies the relevant columns. Then, it aggregates rows based on shared content. Finally, it reorders the aggregated results to yield a final table that clearly answers the posed question.
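The three steps in this example (select the relevant columns, aggregate rows, reorder the result) can be sketched on a toy table. The table contents are hypothetical placeholders, and `f_select_column` and `f_sort_by` are names assumed for illustration alongside the `f_group_by` operation mentioned later in the text.

```python
from collections import Counter


def f_select_column(table, columns):
    """Keep only the columns relevant to the question."""
    return [{column: row[column] for column in columns} for row in table]


def f_group_by(table, column):
    """Aggregate rows that share the same value in `column`."""
    counts = Counter(row[column] for row in table)
    return [{column: value, "Count": count} for value, count in counts.items()]


def f_sort_by(table, column, descending=True):
    """Reorder rows by `column`."""
    return sorted(table, key=lambda row: row[column], reverse=descending)


# Hypothetical toy table; the names and years are placeholders.
table = [
    {"Year": 2001, "Actor": "Actor B", "Category": "Lead"},
    {"Year": 2002, "Actor": "Actor A", "Category": "Lead"},
    {"Year": 2003, "Actor": "Actor A", "Category": "Supporting"},
]

step1 = f_select_column(table, ["Actor"])
step2 = f_group_by(step1, "Actor")
final = f_sort_by(step2, "Count")
print(final[0])
```

After the last step, the answer can be read directly from the first row of the final table.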

These operations transform the table to align with the question presented. To balance performance with computational expense on large tables, we construct the operation chain according to a subset of tabular rows. Meanwhile, the step-by-step operations reveal the underlying reasoning process through the display of intermediate results from the tabular operations, fostering enhanced interpretability and understanding.

Illustration of the tabular reasoning process in Chain-of-Table. This iterative process involves dynamically planning an operation chain and accurately storing intermediate results in the transformed tables. These intermediate tables serve as a tabular thought process that can guide the LLM to arrive at the correct answer more reliably.

Chain-of-Table consists of three main stages. In the first stage, it instructs the LLM to dynamically plan the next operation by in-context learning. Specifically, the prompt involves three components as shown in the following figure:

  1. The question Q: “Which country had the most cyclists finish in the top 3?”
  2. The operation history chain: f_add_col(Country) and f_select_row(1, 2, 3).
  3. The latest intermediate table T: the transformed intermediate table.

By providing the triplet (T, Q, chain) in the prompt, the LLM can observe the previous tabular reasoning process and select the next operation from the operation pool to complete the reasoning chain step by step.

Illustration of how Chain-of-Table selects the next operation from the operation pool and generates the arguments for the operation. (a) Chain-of-Table samples the next operation from the operation pool. (b) It takes the selected operation as input and generates its arguments.

After the next operation f is determined, in the second stage, we need to generate the arguments. As above, Chain-of-Table considers three components in the prompt as shown in the figure: (1) the question, (2) the selected operation and its required arguments, and (3) the latest intermediate table.

For instance, when the operation f_group_by is selected, it requires a header name as its argument.

The LLM selects a suitable header within the table. Equipped with the selected operation and the generated arguments, Chain-of-Table executes the operation and constructs a new intermediate table for the following reasoning.

Chain-of-Table iterates the previous two stages to plan the next operation and generate the required arguments. During this process, we create an operation chain acting as a proxy for the tabular reasoning steps. These operations generate intermediate tables presenting the results of each step to the LLM. Consequently, the output table contains comprehensive information about the intermediate phases of tabular reasoning. In our final stage, we employ this output table in formulating the final query and prompt the LLM along with the question for the final answer.
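The three stages can be sketched as a small loop. In the sketch below the LLM calls are replaced by scripted stand-ins, the `END` sentinel for stopping the chain is an assumption, and the table contents are hypothetical; it is meant only to show the control flow, not the actual prompts.

```python
from collections import Counter

END = "END"  # assumed sentinel meaning "the operation chain is complete"


def f_add_col(table, name, values):
    return [dict(row, **{name: value}) for row, value in zip(table, values)]


def f_select_row(table, *indices):
    return [table[i - 1] for i in indices]  # 1-based row indices


def f_group_by(table, column):
    counts = Counter(row[column] for row in table)
    return [{column: value, "Count": count} for value, count in counts.most_common()]


OPERATIONS = {"f_add_col": f_add_col, "f_select_row": f_select_row, "f_group_by": f_group_by}


def chain_of_table(table, question, plan_next, generate_args, final_query, max_steps=5):
    chain = []
    for _ in range(max_steps):
        op_name = plan_next(table, question, chain)      # stage 1: plan the next operation
        if op_name == END:
            break
        args = generate_args(table, question, op_name)   # stage 2: generate its arguments
        table = OPERATIONS[op_name](table, *args)        # execute and update the table
        chain.append((op_name, args))
    return final_query(table, question), chain           # stage 3: answer from final table


# Scripted stand-ins for the LLM calls (hypothetical).
def scripted_planner(table, question, chain):
    return ["f_add_col", "f_select_row", "f_group_by", END][len(chain)]


def scripted_args(table, question, op_name):
    if op_name == "f_add_col":
        countries = [row["Cyclist"].split("(")[1].rstrip(")") for row in table]
        return ("Country", countries)
    if op_name == "f_select_row":
        return (1, 2, 3)
    return ("Country",)


def read_answer(table, question):
    return table[0]["Country"]


table = [
    {"Rank": 1, "Cyclist": "Rider A (ESP)"},
    {"Rank": 2, "Cyclist": "Rider B (ITA)"},
    {"Rank": 3, "Cyclist": "Rider C (ESP)"},
    {"Rank": 4, "Cyclist": "Rider D (FRA)"},
]
question = "Which country had the most cyclists finish in the top 3?"
answer, chain = chain_of_table(table, question, scripted_planner, scripted_args, read_answer)
print(answer, [name for name, _ in chain])
```

In a real system, `plan_next`, `generate_args`, and `final_query` would each prompt the LLM with the question, the operation history, and the latest intermediate table.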

Experimental setup

We use PaLM 2-S and GPT 3.5 as the backbone LLMs and conduct the experiments on three public table understanding benchmarks: WikiTQ, TabFact, and FeTaQA. WikiTQ and FeTaQA are datasets for table-based question answering. TabFact is a table-based fact verification benchmark. In this blog post, we will focus on the results on WikiTQ and TabFact. We compare Chain-of-Table with the generic reasoning methods (e.g., End-to-End QA, Few-Shot QA, and Chain-of-Thought) and the program-aided methods (e.g., Text-to-SQL, Binder, and Dater).

More accurate answers

Compared to the generic reasoning methods and program-aided reasoning methods, Chain-of-Table achieves better performance across PaLM 2 and GPT 3.5. This is attributed to the dynamically sampled operations and the informative intermediate tables.

Table understanding results on WikiTQ and TabFact with PaLM 2 and GPT 3.5 compared with various models.

Better robustness on harder questions

In Chain-of-Table, longer operation chains indicate higher difficulty and complexity of the questions and their corresponding tables. We categorize the test samples according to their operation lengths in Chain-of-Table. We compare Chain-of-Table with Chain-of-Thought and Dater, as representative generic and program-aided reasoning methods. We illustrate this using results from PaLM 2 on WikiTQ.

Performance of Chain-of-Thought, Dater, and the proposed Chain-of-Table on WikiTQ for questions that require an operation chain of varying lengths. Our proposed atomic operations significantly improve performance over generic and program-aided reasoning counterparts.

Notably, Chain-of-Table consistently surpasses both baseline methods across all operation chain lengths, with a significant margin of up to 11.6% compared with Chain-of-Thought, and up to 7.9% compared with Dater. Moreover, the performance of Chain-of-Table declines gracefully with an increasing number of operations compared to other baseline methods, exhibiting only a minimal decrease when the number of operations increases from four to five.

Better robustness with larger tables

We categorize the tables from WikiTQ into three groups based on token number: small (<2000 tokens), medium (2000 to 4000 tokens) and large (>4000 tokens). We then compare Chain-of-Table with Dater and Binder, the two latest and strongest baselines.
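The grouping rule can be written as a small helper. How tokens are counted depends on the tokenizer, which is not specified here; `table_size_bucket` is a hypothetical name and takes the token count as given.

```python
def table_size_bucket(num_tokens: int) -> str:
    """Assign a table to a size group using the thresholds stated above."""
    if num_tokens < 2000:
        return "small"
    if num_tokens <= 4000:
        return "medium"
    return "large"


print([table_size_bucket(n) for n in (1500, 3000, 5000)])
```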

Performance of Binder, Dater, and the proposed Chain-of-Table on small (<2000 tokens), medium (2000 to 4000 tokens), and large (>4000 tokens) tables from WikiTQ. We observe that the performance decreases with larger input tables while the performance of Chain-of-Table diminishes gracefully, achieving significant improvements over competing methods. (As above, underlined text denotes the second-best performance; bold denotes the best performance.)

As anticipated, the performance decreases with larger input tables, as models are required to reason through longer contexts. Nevertheless, the performance of the proposed Chain-of-Table diminishes gracefully, achieving a significant 10+% improvement over the second best competing method when dealing with large tables. This demonstrates the efficacy of the reasoning chain in handling long tabular inputs.

Conclusion

Our proposed Chain-of-Table method enhances the reasoning capability of LLMs by leveraging the tabular structure to express intermediate steps for table-based reasoning. It instructs LLMs to dynamically plan an operation chain according to the input table and its associated question. This evolving table design sheds new light on the understanding of prompting LLMs for table understanding.

Acknowledgements

This research was conducted by Zilong Wang, Hao Zhang, Chun-Liang Li, Julian Martin Eisenschlos, Vincent Perot, Zifeng Wang, Lesly Miculicich, Yasuhisa Fujii, Jingbo Shang, Chen-Yu Lee, Tomas Pfister. Thanks to Chih-Kuan Yeh and Sergey Ioffe for their valuable feedback.


Computer-aided diagnosis for lung cancer screening

Lung cancer is the leading cause of cancer-related deaths globally with 1.8 million deaths reported in 2020. Late diagnosis dramatically reduces the chances of survival. Lung cancer screening via computed tomography (CT), which provides a detailed 3D image of the lungs, has been shown to reduce mortality in high-risk populations by at least 20% by detecting potential signs of cancers earlier. In the US, screening involves annual scans, with some countries or cases recommending more or less frequent scans.

The United States Preventive Services Task Force recently expanded lung cancer screening recommendations by roughly 80%, which is expected to increase screening access for women and racial and ethnic minority groups. However, false positives (i.e., incorrectly reporting a potential cancer in a cancer-free patient) can cause anxiety and lead to unnecessary procedures for patients while increasing costs for the healthcare system. Moreover, efficiency in screening a large number of individuals can be challenging depending on healthcare infrastructure and radiologist availability.

At Google we have previously developed machine learning (ML) models for lung cancer detection, and have evaluated their ability to automatically detect and classify regions that show signs of potential cancer. Performance has been shown to be comparable to that of specialists in detecting possible cancer. While they have achieved high performance, effectively communicating findings in realistic environments is necessary to realize their full potential.

To that end, in “Assistive AI in Lung Cancer Screening: A Retrospective Multinational Study in the US and Japan”, published in Radiology AI, we investigate how ML models can effectively communicate findings to radiologists. We also introduce a generalizable user-centric interface to help radiologists leverage such models for lung cancer screening. The system takes CT imaging as input and outputs a cancer suspicion rating using four categories (no suspicion, probably benign, suspicious, highly suspicious) along with the corresponding regions of interest. We evaluate the system’s utility in improving clinician performance through randomized reader studies in both the US and Japan, using the local cancer scoring systems (Lung-RADS V1.1 and Sendai Score) and image viewers that mimic realistic settings. We found that reader specificity increases with model assistance in both reader studies. To accelerate progress in conducting similar studies with ML models, we have open-sourced code to process CT images and generate images compatible with the picture archiving and communication system (PACS) used by radiologists.

Developing an interface to communicate model results

Integrating ML models into radiologist workflows involves understanding the nuances and goals of their tasks to meaningfully support them. In the case of lung cancer screening, hospitals follow various country-specific guidelines that are regularly updated. For example, in the US, Lung-RADS V1.1 assigns an alphanumeric score to indicate the lung cancer risk and follow-up recommendations. When assessing patients, radiologists load the CT in their workstation to read the case, find lung nodules or lesions, and apply set guidelines to determine follow-up decisions.

Our first step was to improve the previously developed ML models through additional training data and architectural improvements, including self-attention. Then, instead of targeting specific guidelines, we experimented with a complementary way of communicating AI results independent of guidelines or their particular versions. Specifically, the system output offers a suspicion rating and localization (regions of interest) for the user to consider in conjunction with their own specific guidelines. The interface produces output images directly associated with the CT study, requiring no changes to the user’s workstation. The radiologist only needs to review a small set of additional images. There is no other change to their system or interaction with the system.

Example of the assistive lung cancer screening system outputs. Results for the radiologist’s evaluation are visualized on the location of the CT volume where the suspicious lesion is found. The overall suspicion is displayed at the top of the CT images. Circles highlight the suspicious lesions while squares show a rendering of the same lesion from a different perspective, called a sagittal view.

The assistive lung cancer screening system comprises 13 models and has a high-level architecture similar to the end-to-end system used in prior work. The models coordinate with each other to first segment the lungs, obtain an overall assessment, locate three suspicious regions, then use the information to assign a suspicion rating to each region. The system was deployed on Google Cloud using a Google Kubernetes Engine (GKE) that pulled the images, ran the ML models, and provided results. This allows scalability and directly connects to servers where the images are stored in DICOM stores.

Outline of the Google Cloud deployment of the assistive lung cancer screening system and the directional calling flow for the individual components that serve the images and compute results. Images are served to the viewer and to the system using Google Cloud services. The system is run on a Google Kubernetes Engine that pulls the images, processes them, and writes them back into the DICOM store.

Reader studies

To evaluate the system’s utility in improving clinical performance, we conducted two reader studies (i.e., experiments designed to assess clinical performance comparing expert performance with and without the aid of a technology) with 12 radiologists using pre-existing, de-identified CT scans. We presented 627 challenging cases to 6 US-based and 6 Japan-based radiologists. In the experimental setup, readers were divided into two groups that read each case twice, with and without assistance from the model. Readers were asked to apply scoring guidelines they typically use in their clinical practice and report their overall suspicion of cancer for each case. We then compared the readers’ responses to measure the impact of the model on their workflow and decisions. The score and suspicion level were judged against the actual cancer outcomes of the individuals to measure sensitivity, specificity, and area under the ROC curve (AUC) values. These were compared with and without assistance.

A multi-case multi-reader study involves each case being reviewed by each reader twice, once with ML system assistance and once without. In this visualization one reader group first reviews Set A without assistance (blue) and then with assistance (orange) after a wash-out period. A second reader group follows the opposite path by reading the same set of cases (Set A) with assistance first. Readers are randomized to these groups to remove the effect of ordering.
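As a rough illustration of the randomized crossover design described in this caption (not the code used in the study), the sketch below shuffles readers and splits them into two arms with opposite reading orders. The reader names, the function name `assign_reading_order`, and the fixed seed are all hypothetical.

```python
import random


def assign_reading_order(readers, seed=0):
    """Randomly split readers into two arms with opposite reading orders.

    Returns a dict mapping each reader to a (first_pass, second_pass) tuple.
    The seed is fixed only to make this illustration reproducible.
    """
    rng = random.Random(seed)
    shuffled = list(readers)
    rng.shuffle(shuffled)
    half = len(shuffled) // 2
    schedule = {}
    for reader in shuffled[:half]:
        schedule[reader] = ("unassisted", "assisted")
    for reader in shuffled[half:]:
        schedule[reader] = ("assisted", "unassisted")
    return schedule


# Hypothetical reader identifiers; the wash-out period between the two
# passes is not modeled here.
schedule = assign_reading_order([f"reader_{i}" for i in range(1, 7)])
for reader, order in sorted(schedule.items()):
    print(reader, "->", " then ".join(order))
```

Because every reader sees every case under both conditions, differences between the assisted and unassisted passes can be attributed to the assistance rather than to reader skill, and the randomized order guards against learning effects.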

The ability to conduct these studies using the same interface highlights its generalizability to completely different cancer scoring systems, and the generalization of the model and assistive capability to different patient populations. Our study results demonstrated that when radiologists used the system in their clinical evaluation, they had an increased ability to correctly identify lung images without actionable lung cancer findings (i.e., specificity) by an absolute 5–7% compared to when they didn’t use the assistive system. This potentially means that for every 15–20 patients screened, one may be able to avoid unnecessary follow-up procedures, thus reducing their anxiety and the burden on the health care system. This can, in turn, help improve the sustainability of lung cancer screening programs, particularly as more people become eligible for screening.
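The link between an absolute specificity gain and "one in every 15–20 patients" can be checked with a few lines of arithmetic. The sketch below is illustrative only: the toy labels are made up rather than taken from the study, the function names are hypothetical, and the reciprocal rule assumes that nearly all screened patients are cancer-free, so a gain in specificity translates almost directly into avoided follow-ups per patient screened.

```python
def sensitivity_specificity(has_cancer, flagged):
    """Return (sensitivity, specificity) from parallel lists of booleans.

    Assumes at least one cancer-positive and one cancer-negative case.
    """
    pairs = list(zip(has_cancer, flagged))
    tp = sum(1 for c, f in pairs if c and f)
    fn = sum(1 for c, f in pairs if c and not f)
    tn = sum(1 for c, f in pairs if not c and not f)
    fp = sum(1 for c, f in pairs if not c and f)
    return tp / (tp + fn), tn / (tn + fp)


def patients_per_avoided_followup(absolute_specificity_gain):
    """Cancer-free patients screened per one avoided false positive."""
    return 1.0 / absolute_specificity_gain


# Toy data, invented for illustration: 2 cancers and 8 cancer-free cases.
has_cancer = [True] * 2 + [False] * 8
unassisted = [True, True, True, True, True] + [False] * 5  # 3 false positives
assisted = [True, True, True, True] + [False] * 6          # 2 false positives

sens_u, spec_u = sensitivity_specificity(has_cancer, unassisted)
sens_a, spec_a = sensitivity_specificity(has_cancer, assisted)
print(f"specificity: {spec_u:.3f} -> {spec_a:.3f}, sensitivity unchanged: {sens_u == sens_a}")

# Gains of 5 and 7 percentage points, as reported in the text.
for gain in (0.05, 0.07):
    print(f"gain {gain:.0%}: about 1 in {patients_per_avoided_followup(gain):.0f} patients")
```

A 5-point gain corresponds to roughly one in 20 patients and a 7-point gain to roughly one in 14, which is consistent with the range quoted above.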

Reader specificity increases with ML model assistance in both the US-based and Japan-based reader studies. Specificity values were derived from reader scores from actionable findings (something suspicious was found) versus no actionable findings, compared against the true cancer outcome of the individual. Under model assistance, readers flagged fewer cancer-negative individuals for follow-up visits. Sensitivity for cancer positive individuals remained the same.

Translating this into real-world impact through partnership

The system results demonstrate the potential for fewer follow-up visits, reduced anxiety, as well as lower overall costs for lung cancer screening. In an effort to translate this research into real-world clinical impact, we are working with DeepHealth, a leading AI-powered health informatics provider, and Apollo Radiology International, a leading provider of radiology services in India, to explore paths for incorporating this system into future products. In addition, we are looking to help other researchers studying how best to integrate ML model results into clinical workflows by open sourcing code used for the reader study and incorporating the insights described in this blog. We hope that this will help accelerate medical imaging researchers looking to conduct reader studies for their AI models, and catalyze translational research in the field.

Acknowledgements

Key contributors to this project include Corbin Cunningham, Zaid Nabulsi, Ryan Najafi, Jie Yang, Charles Lau, Joseph R. Ledsam, Wenxing Ye, Diego Ardila, Scott M. McKinney, Rory Pilgrim, Hiroaki Saito, Yasuteru Shimamura, Mozziyar Etemadi, Yun Liu, David Melnick, Sunny Jansen, Nadia Harhen, David P. Nadich, Mikhail Fomitchev, Ziyad Helali, Shabir Adeel, Greg S. Corrado, Lily Peng, Daniel Tse, Shravya Shetty, Shruthi Prabhakara, Neeral Beladia, and Krish Eswaran. Thanks to Arnav Agharwal and Andrew Sellergren for their open sourcing support and Vivek Natarajan and Michael D. Howell for their feedback. Sincere appreciation also goes to the radiologists who enabled this work with their image interpretation and annotation efforts throughout the study, and Jonny Wong and Carli Sampson for coordinating the reader studies.

Using AI to expand global access to reliable flood forecasts

Floods are the most common natural disaster, and are responsible for roughly $50 billion in annual financial damages worldwide. The rate of flood-related disasters has more than doubled since the year 2000 partly due to climate change. Nearly 1.5 billion people, making up 19% of the world’s population, are exposed to substantial risks from severe flood events. Upgrading early warning systems to make accurate and timely information accessible to these populations can save thousands of lives per year.

Driven by the potential impact of reliable flood forecasting on people’s lives globally, we started our flood forecasting effort in 2017. Over this multi-year journey, we advanced research hand-in-hand with building a real-time operational flood forecasting system that provides alerts on Google Search, Maps, Android notifications and through the Flood Hub. However, in order to scale globally, especially in places where accurate local data is not available, more research advances were required.

In “Global prediction of extreme floods in ungauged watersheds”, published in Nature, we demonstrate how machine learning (ML) technologies can significantly improve global-scale flood forecasting relative to the current state-of-the-art for countries where flood-related data is scarce. With these AI-based technologies we extended the reliability of currently-available global nowcasts, on average, from zero to five days, and improved forecasts across regions in Africa and Asia to be similar to what are currently available in Europe. The evaluation of the models was conducted in collaboration with the European Centre for Medium-Range Weather Forecasts (ECMWF).

These technologies also enable Flood Hub to provide real-time river forecasts up to seven days in advance, covering river reaches across over 80 countries. This information can be used by people, communities, governments and international organizations to take anticipatory action to help protect vulnerable populations.

Flood forecasting at Google

The ML models that power the Flood Hub tool are the product of many years of research, conducted in collaboration with several partners, including academics, governments, international organizations, and NGOs.

In 2018, we launched a pilot early warning system in the Ganges-Brahmaputra river basin in India, with the hypothesis that ML could help address the challenging problem of reliable flood forecasting at scale. The pilot was further expanded the following year via the combination of an inundation model, real-time water level measurements, the creation of an elevation map and hydrologic modeling.

In collaboration with academics, and in particular with the JKU Institute for Machine Learning, we explored ML-based hydrologic models, showing that LSTM-based models could produce more accurate simulations than traditional conceptual and physics-based hydrology models. This research led to flood forecasting improvements that enabled the expansion of our forecasting coverage to include all of India and Bangladesh. We also worked with researchers at Yale University to test technological interventions that increase the reach and impact of flood warnings.

Our hydrological models predict river floods by processing publicly available weather data like precipitation and physical watershed information. Such models must be calibrated to long data records from streamflow gauging stations in individual rivers. A low percentage of global river watersheds (basins) have streamflow gauges, which are expensive but necessary to supply relevant data, and it’s challenging for hydrological simulation and forecasting to provide predictions in basins that lack this infrastructure. Lower gross domestic product (GDP) is correlated with increased vulnerability to flood risks, and there is an inverse correlation between national GDP and the amount of publicly available data in a country. ML helps to address this problem by allowing a single model to be trained on all available river data and to be applied to ungauged basins where no data are available. In this way, models can be trained globally, and can make predictions for any river location.

There is an inverse (log-log) correlation between the amount of publicly available streamflow data in a country and national GDP. Streamflow data from the Global Runoff Data Center.

Our academic collaborations led to ML research that developed methods to estimate uncertainty in river forecasts and showed how ML river forecast models synthesize information from multiple data sources. They demonstrated that these models can simulate extreme events reliably, even when those events are not part of the training data. In an effort to contribute to open science, in 2023 we open-sourced a community-driven dataset for large-sample hydrology in Nature Scientific Data.

The river forecast model

Most hydrology models used by national and international agencies for flood forecasting and river modeling are state-space models, which depend only on daily inputs (e.g., precipitation, temperature, etc.) and the current state of the system (e.g., soil moisture, snowpack, etc.). LSTMs are a variant of state-space models and work by defining a neural network that represents a single time step, where input data (such as current weather conditions) are processed to produce updated state information and output values (streamflow) for that time step. LSTMs are applied sequentially to make time-series predictions, and in this sense, behave similarly to how scientists typically conceptualize hydrologic systems. Empirically, we have found that LSTMs perform well on the task of river forecasting.

A diagram of the LSTM, which is a neural network that operates sequentially in time. An accessible primer can be found here.

Our river forecast model uses two LSTMs applied sequentially: (1) a “hindcast” LSTM ingests historical weather data (dynamic hindcast features) up to the present time (or rather, the issue time of a forecast), and (2) a “forecast” LSTM ingests states from the hindcast LSTM along with forecasted weather data (dynamic forecast features) to make future predictions. One year of historical weather data are input into the hindcast LSTM, and seven days of forecasted weather data are input into the forecast LSTM. Static features include geographical and geophysical characteristics of watersheds that are input into both the hindcast and forecast LSTMs and allow the model to learn different hydrological behaviors and responses in various types of watersheds.
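To make the hindcast-to-forecast hand-off concrete, here is a minimal sketch using a single-unit LSTM written in plain Python. It is not the production model: the weights are random, the input is one scalar per day rather than a full feature vector, static watershed features are omitted, and passing the hidden and cell states directly from one LSTM to the other is an assumption made for illustration. All function names are hypothetical.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def make_weights(rng):
    """Random (w_input, w_hidden, bias) triple for each of the four gates."""
    return {name: (rng.uniform(-1, 1), rng.uniform(-1, 1), rng.uniform(-1, 1))
            for name in ("input", "forget", "output", "candidate")}


def lstm_step(x, h, c, weights):
    """One time step: current input and state in, updated state out."""
    def gate(name, activation):
        w_x, w_h, b = weights[name]
        return activation(w_x * x + w_h * h + b)

    i = gate("input", sigmoid)
    f = gate("forget", sigmoid)
    o = gate("output", sigmoid)
    g = gate("candidate", math.tanh)
    c_new = f * c + i * g          # cell state: the "memory" of the watershed
    h_new = o * math.tanh(c_new)   # hidden state: what is exposed downstream
    return h_new, c_new


def run_lstm(inputs, h, c, weights):
    """Apply lstm_step sequentially; return per-step outputs and final state."""
    outputs = []
    for x in inputs:
        h, c = lstm_step(x, h, c, weights)
        outputs.append(h)
    return outputs, h, c


rng = random.Random(0)
hindcast_weights = make_weights(rng)
forecast_weights = make_weights(rng)

past_weather = [rng.random() for _ in range(365)]  # one year of history
future_weather = [rng.random() for _ in range(7)]  # seven forecast days

# The hindcast LSTM spins up the state from historical data ...
_, h, c = run_lstm(past_weather, 0.0, 0.0, hindcast_weights)
# ... and the forecast LSTM continues from that state with forecasted inputs.
forecast_outputs, _, _ = run_lstm(future_weather, h, c, forecast_weights)
print(len(forecast_outputs), "forecast steps")
```

Using separate weights for the two stages lets each LSTM specialize in its own input source (observed versus forecasted weather) while still sharing a single evolving state.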

Output from the forecast LSTM is fed into a “head” layer that uses mixture density networks to produce a probabilistic forecast (i.e., predicted parameters of a probability distribution over streamflow). Specifically, the model predicts the parameters of a mixture of heavy-tailed probability density functions, called asymmetric Laplacian distributions, at each forecast time step. The result is a mixture density function, called a Countable Mixture of Asymmetric Laplacians (CMAL) distribution, which represents a probabilistic prediction of the volumetric flow rate in a particular river at a particular time.

LSTM-based river forecast model architecture. Two LSTMs are applied in sequence, one ingesting historical weather data and one ingesting forecasted weather data. The model outputs are the parameters of a probability distribution over streamflow at each forecasted timestep.
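The probabilistic head can be illustrated by evaluating a mixture of asymmetric Laplacian densities directly. The sketch below assumes the common quantile-style parameterization with location `mu`, scale `b` and asymmetry `tau`; the exact parameterization used in the model is not stated here, and the component values are invented for illustration.

```python
import math


def asymmetric_laplace_pdf(y, mu, b, tau):
    """Asymmetric Laplace density with location mu, scale b > 0, 0 < tau < 1."""
    u = (y - mu) / b
    # "Check" loss from quantile regression; it is always non-negative.
    check = u * (tau - (1.0 if u < 0 else 0.0))
    return tau * (1.0 - tau) / b * math.exp(-check)


def mixture_pdf(y, components):
    """Weighted sum of asymmetric Laplace densities.

    components: list of (weight, mu, b, tau) with weights summing to 1.
    """
    return sum(w * asymmetric_laplace_pdf(y, mu, b, tau)
               for w, mu, b, tau in components)


# Hypothetical head output for one forecast time step (flow in m^3/s).
components = [(0.7, 10.0, 2.0, 0.3), (0.3, 25.0, 5.0, 0.6)]

for flow in (5.0, 10.0, 25.0, 60.0):
    print(f"p({flow:5.1f}) = {mixture_pdf(flow, components):.5f}")
```

A smaller `tau` places more probability mass above the location, giving the heavy upper tail needed to represent rare high-flow events, while the mixture allows a multi-modal forecast when the model is uncertain between regimes.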

Input and training data

The model uses three types of publicly available data inputs, mostly from governmental sources:

  1. Static watershed attributes representing geographical and geophysical variables: From the HydroATLAS project, including data like long-term climate indexes (precipitation, temperature, snow fractions), land cover, and anthropogenic attributes (e.g., a nighttime lights index as a proxy for human development).
  2. Historical meteorological time-series data: Used to spin up the model for one year prior to the issue time of a forecast. The data comes from NASA IMERG, NOAA CPC Global Unified Gauge-Based Analysis of Daily Precipitation, and the ECMWF ERA5-land reanalysis. Variables include daily total precipitation, air temperature, solar and thermal radiation, snowfall, and surface pressure.
  3. Forecasted meteorological time series over a seven-day forecast horizon: Used as input for the forecast LSTM. These data are the same meteorological variables listed above, and come from the ECMWF HRES atmospheric model.

Training data are daily streamflow values from the Global Runoff Data Center over the time period 1980 – 2023. A single streamflow forecast model is trained using data from 5,680 diverse watershed streamflow gauges (shown below) to improve accuracy.

Location of 5,680 streamflow gauges that supply training data for the river forecast model from the Global Runoff Data Center.

Improving on the current state-of-the-art

We compared our river forecast model with GloFAS version 4, the current state-of-the-art global flood forecasting system. These experiments showed that ML can provide accurate warnings earlier and over larger and more impactful events.

The figure below shows the distribution of F1 scores when predicting different severity events at river locations around the world, with plus or minus 1 day accuracy. The F1 score is the harmonic mean of precision and recall, and event severity is measured by return period. For example, a 2-year return period event is a volume of streamflow that is expected to be exceeded on average once every two years. Our model achieves reliability scores at up to 4-day or 5-day lead times that are similar to or better, on average, than the reliability of GloFAS nowcasts (0-day lead time).

Distributions of F1 scores over 2-year return period events in 2,092 watersheds globally during the time period 2014-2023 from GloFAS (blue) and our model (orange) at different lead times. On average, our model is statistically as accurate as GloFAS nowcasts (0–day lead time) up to 5 days in advance over 2-year (shown) and 1-year, 5-year, and 10-year events (not shown).
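As an illustration of how an F1 score with a timing tolerance can be computed, the sketch below counts a predicted event as correct if an observed event falls within one day of it, and vice versa. The matching rule and the event days are assumptions made for this example; the paper should be consulted for the exact evaluation protocol.

```python
def count_matched(events, others, tolerance_days=1):
    """Number of `events` with at least one entry of `others` within tolerance."""
    return sum(1 for day in events
               if any(abs(day - other) <= tolerance_days for other in others))


def f1_with_tolerance(observed_days, predicted_days, tolerance_days=1):
    """Harmonic mean of precision and recall for event days (day indices)."""
    if not observed_days or not predicted_days:
        return 0.0
    precision = count_matched(predicted_days, observed_days, tolerance_days) / len(predicted_days)
    recall = count_matched(observed_days, predicted_days, tolerance_days) / len(observed_days)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Hypothetical day indices on which flow exceeded the return-period threshold.
observed = [10, 50, 200]
predicted = [11, 120, 199]

score = f1_with_tolerance(observed, predicted)
print(f"F1 = {score:.3f}")
```

Here two of the three predictions fall within a day of an observed event and two of the three observed events are caught, so precision and recall are both 2/3 and so is their harmonic mean.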

Additionally (not shown), our model achieves accuracies over larger and rarer extreme events, with precision and recall scores over 5-year return period events that are similar to or better than GloFAS accuracies over 1-year return period events. See the paper for more information.

Looking into the future

The flood forecasting initiative is part of our Adaptation and Resilience efforts and reflects Google’s commitment to address climate change while helping global communities become more resilient. We believe that AI and ML will continue to play a critical role in helping advance science and research towards climate action.

We actively collaborate with several international aid organizations (e.g., the Centre for Humanitarian Data and the Red Cross) to provide actionable flood forecasts. Additionally, in an ongoing collaboration with the World Meteorological Organization (WMO) to support early warning systems for climate hazards, we are conducting a study to help understand how AI can help address real-world challenges faced by national flood forecasting agencies.

While the work presented here demonstrates a significant step forward in flood forecasting, future work is needed to further expand flood forecasting coverage to more locations globally and other types of flood-related events and disasters, including flash floods and urban floods. We are looking forward to continuing collaborations with our partners in the academic and expert communities, local governments and the industry to reach these goals.

ScreenAI: A visual language model for UI and visually-situated language understanding

Screen user interfaces (UIs) and infographics, such as charts, diagrams and tables, play important roles in human communication and human-machine interaction as they facilitate rich and interactive user experiences. UIs and infographics share similar design principles and visual language (e.g., icons and layouts), that offer an opportunity to build a single model that can understand, reason, and interact with these interfaces. However, because of their complexity and varied presentation formats, infographics and UIs present a unique modeling challenge.

To that end, we introduce “ScreenAI: A Vision-Language Model for UI and Infographics Understanding”. ScreenAI improves upon the PaLI architecture with the flexible patching strategy from pix2struct. We train ScreenAI on a unique mixture of datasets and tasks, including a novel Screen Annotation task that requires the model to identify UI element information (i.e., type, location and description) on a screen. These text annotations provide large language models (LLMs) with screen descriptions, enabling them to automatically generate question-answering (QA), UI navigation, and summarization training datasets at scale. At only 5B parameters, ScreenAI achieves state-of-the-art results on UI- and infographic-based tasks (WebSRC and MoTIF), and best-in-class performance on ChartQA, DocVQA, and InfographicVQA compared to models of similar size. We are also releasing three new datasets: Screen Annotation to evaluate the layout understanding capability of the model, as well as ScreenQA Short and Complex ScreenQA for a more comprehensive evaluation of its QA capability.

ScreenAI

ScreenAI’s architecture is based on PaLI, composed of a multimodal encoder block and an autoregressive decoder. The PaLI encoder uses a vision transformer (ViT) that creates image embeddings and a multimodal encoder that takes the concatenation of the image and text embeddings as input. This flexible architecture allows ScreenAI to solve vision tasks that can be recast as text+image-to-text problems.

On top of the PaLI architecture, we employ a flexible patching strategy introduced in pix2struct. Instead of using a fixed-grid pattern, the grid dimensions are selected such that they preserve the native aspect ratio of the input image. This enables ScreenAI to work well across images of various aspect ratios.
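A minimal sketch of aspect-ratio-preserving patching is shown below. It follows the general rescaling idea described for pix2struct: pick the largest grid of square patches that keeps the image's aspect ratio while staying within a patch budget. The function name and the default patch size and budget are assumptions for illustration, not ScreenAI's actual settings.

```python
import math


def patch_grid(image_height, image_width, patch_size=16, max_patches=1024):
    """Return (rows, cols) of a patch grid that preserves the aspect ratio.

    The image would then be resized to (rows * patch_size, cols * patch_size).
    """
    # Scale factor such that rows * cols is as close as possible to the
    # budget without exceeding it.
    scale = math.sqrt(max_patches * (patch_size / image_height) * (patch_size / image_width))
    rows = max(min(math.floor(scale * image_height / patch_size), max_patches), 1)
    cols = max(min(math.floor(scale * image_width / patch_size), max_patches), 1)
    return rows, cols


# A portrait phone screenshot and the same screenshot rotated to landscape.
for height, width in ((2400, 1080), (1080, 2400)):
    rows, cols = patch_grid(height, width)
    print(f"{height}x{width}: {rows} x {cols} = {rows * cols} patches")
```

With a fixed square grid, a tall phone screenshot would be squashed horizontally or stretched vertically; here the portrait image gets more rows than columns and the landscape one gets the reverse, so text and icons keep their proportions.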

The ScreenAI model is trained in two stages: a pre-training stage followed by a fine-tuning stage. First, self-supervised learning is applied to automatically generate data labels, which are then used to train ViT and the language model. ViT is frozen during the fine-tuning stage, where most data used is manually labeled by human raters.

ScreenAI model architecture.

Data generation

To create a pre-training dataset for ScreenAI, we first compile an extensive collection of screenshots from various devices, including desktops, mobile, and tablets. This is achieved by using publicly accessible web pages and following the programmatic exploration approach used for the RICO dataset for mobile apps. We then apply a layout annotator, based on the DETR model, that identifies and labels a wide range of UI elements (e.g., image, pictogram, button, text) and their spatial relationships. Pictograms undergo further analysis using an icon classifier capable of distinguishing 77 different icon types. This detailed classification is essential for interpreting the subtle information conveyed through icons. For icons that are not covered by the classifier, and for infographics and images, we use the PaLI image captioning model to generate descriptive captions that provide contextual information. We also apply an optical character recognition (OCR) engine to extract and annotate textual content on screen. We combine the OCR text with the previous annotations to create a detailed description of each screen.

A mobile app screenshot with generated annotations that include UI elements and their descriptions, e.g., TEXT elements also contain the text content from OCR, IMAGE elements contain image captions, LIST_ITEMs contain all their child elements.

LLM-based data generation

We enhance the pre-training data’s diversity using PaLM 2 to generate input-output pairs in a two-step process. First, screen annotations are generated using the technique outlined above, then we craft a prompt around this schema for the LLM to create synthetic data. This process requires prompt engineering and iterative refinement to find an effective prompt. We assess the generated data’s quality through human validation against a quality threshold.

You only speak JSON. Do not write text that isn’t JSON.
You are given the following mobile screenshot, described in words. Can you generate 5 questions regarding the content of the screenshot as well as the corresponding short answers to them? 

The answer should be as short as possible, containing only the necessary information. Your answer should be structured as follows:
questions: [
{{question: the question,
    answer: the answer
}},
 ...
]

{THE SCREEN SCHEMA}

A sample prompt for QA data generation.
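The doubled braces in the sample prompt suggest a format-string template, where `{{` and `}}` produce literal braces and a single-brace field is substituted. The sketch below shows one way such a template could be filled and the model's reply validated. The placeholder name `screen_schema`, the example schema string, and the example reply are all hypothetical, and the LLM call itself is deliberately left out.

```python
import json

PROMPT_TEMPLATE = """You only speak JSON. Do not write text that isn’t JSON.
You are given the following mobile screenshot, described in words. Can you generate 5 questions regarding the content of the screenshot as well as the corresponding short answers to them?

The answer should be as short as possible, containing only the necessary information. Your answer should be structured as follows:
questions: [
{{question: the question,
    answer: the answer
}},
 ...
]

{screen_schema}
"""


def build_prompt(screen_schema):
    """Insert a textual screen description into the prompt template."""
    return PROMPT_TEMPLATE.format(screen_schema=screen_schema)


def parse_qa_pairs(response_text):
    """Return (question, answer) pairs, or [] if the reply is not usable JSON."""
    try:
        data = json.loads(response_text)
    except json.JSONDecodeError:
        return []
    items = data.get("questions", []) if isinstance(data, dict) else []
    return [(item["question"], item["answer"])
            for item in items
            if isinstance(item, dict) and "question" in item and "answer" in item]


# Hypothetical screen description and model reply, for illustration only.
prompt = build_prompt("TEXT Opens at 9 am\nBUTTON Reserve a table")
reply = '{"questions": [{"question": "When does the restaurant open?", "answer": "9 am"}]}'
print(parse_qa_pairs(reply))
```

Discarding replies that fail to parse is a cheap first filter before the human validation step described above.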

By combining the natural language capabilities of LLMs with a structured schema, we simulate a wide range of user interactions and scenarios to generate synthetic, realistic tasks. In particular, we generate three categories of tasks:

  • Question answering: The model is asked to answer questions regarding the content of the screenshots, e.g., “When does the restaurant open?”
  • Screen navigation: The model is asked to convert a natural language utterance into an executable action on a screen, e.g., “Click the search button.”
  • Screen summarization: The model is asked to summarize the screen content in one or two sentences.

Block diagram of our workflow for generating data for QA, summarization and navigation tasks using existing ScreenAI models and LLMs. Each task uses a custom prompt to emphasize desired aspects, like questions related to counting, involving reasoning, etc.

LLM-generated data. Examples for screen QA, navigation and summarization. For navigation, the action bounding box is displayed in red on the screenshot.

Experiments and results

As previously mentioned, ScreenAI is trained in two stages: pre-training and fine-tuning. Pre-training data labels are obtained using self-supervised learning and fine-tuning data labels come from human raters.

We fine-tune ScreenAI using public QA, summarization, and navigation datasets and a variety of tasks related to UIs. For QA, we use well established benchmarks in the multimodal and document understanding field, such as ChartQA, DocVQA, Multi page DocVQA, InfographicVQA, OCR-VQA, WebSRC and ScreenQA. For navigation, datasets used include Referring Expressions, MoTIF, Mug, and Android in the Wild. Finally, we use Screen2Words for screen summarization and Widget Captioning for describing specific UI elements. Along with the fine-tuning datasets, we evaluate the fine-tuned ScreenAI model using three novel benchmarks:

  1. Screen Annotation: Enables the evaluation of the model’s layout annotation and spatial understanding capabilities.
  2. ScreenQA Short: A variation of ScreenQA, where its ground truth answers have been shortened to contain only the relevant information that better aligns with other QA tasks.
  3. Complex ScreenQA: Complements ScreenQA Short with more difficult questions (counting, arithmetic, comparison, and non-answerable questions) and contains screens with various aspect ratios.

The fine-tuned ScreenAI model achieves state-of-the-art results on various UI and infographic-based tasks (WebSRC and MoTIF) and best-in-class performance on ChartQA, DocVQA, and InfographicVQA compared to models of similar size. ScreenAI achieves competitive performance on Screen2Words and OCR-VQA. Additionally, we report results on the new benchmark datasets introduced to serve as a baseline for further research.

Comparing model performance of ScreenAI with state-of-the-art (SOTA) models of similar size.

Next, we examine ScreenAI’s scaling capabilities and observe that across all tasks, increasing the model size improves performance, and the improvements have not saturated at the largest size.

Model performance increases with size, and the performance has not saturated even at the largest size of 5B params.

Conclusion

We introduce the ScreenAI model along with a unified representation that enables us to develop self-supervised learning tasks leveraging data from all these domains. We also illustrate the impact of data generation using LLMs and investigate improving model performance on specific aspects by modifying the training mixture. We apply all of these techniques to build multi-task trained models that perform competitively with state-of-the-art approaches on a number of public benchmarks. However, we also note that our approach still lags behind large models and further research is needed to bridge this gap.

Acknowledgements

This project is the result of joint work with Maria Wang, Fedir Zubach, Hassan Mansoor, Vincent Etter, Victor Carbune, Jason Lin, Jindong Chen and Abhanshu Sharma. We thank Fangyu Liu, Xi Chen, Efi Kokiopoulou, Jesse Berent, Gabriel Barcik, Lukas Zilka, Oriana Riva, Gang Li, Yang Li, Radu Soricut, and Tania Bedrax-Weiss for their insightful feedback and discussions, along with Rahul Aralikatte, Hao Cheng and Daniel Kim for their support in data preparation. We also thank Jay Yagnik, Blaise Aguera y Arcas, Ewa Dominowska, David Petrou, and Matt Sharifi for their leadership, vision and support. We are very grateful to Tom Small for helping us create the animation in this post.

MediaPipe FaceStylizer: On-device real-time few-shot face stylization


In recent years, we have witnessed rising interest across consumers and researchers in integrated augmented reality (AR) experiences using real-time face feature generation and editing functions in mobile applications, including short videos, virtual reality, and gaming. As a result, there is a growing demand for lightweight, yet high-quality face generation and editing models, which are often based on generative adversarial network (GAN) techniques. However, the majority of GAN models suffer from high computational complexity and the need for a large training dataset. In addition, it is also important to employ GAN models responsibly.

In this post, we introduce MediaPipe FaceStylizer, an efficient design for few-shot face stylization that addresses the aforementioned model complexity and data efficiency challenges while being guided by Google’s responsible AI Principles. The model consists of a face generator and a face encoder used for GAN inversion, which maps the image into a latent code for the generator. We introduce a mobile-friendly synthesis network for the face generator with an auxiliary head that converts features to RGB at each level of the generator to generate high quality images from coarse to fine granularities. We also carefully designed the loss functions for the aforementioned auxiliary heads and combined them with the common GAN loss functions to distill the student generator from the teacher StyleGAN model, resulting in a lightweight model that maintains high generation quality. The proposed solution is available in open source through MediaPipe. Users can fine-tune the generator to learn a style from one or a few images using MediaPipe Model Maker, and deploy to on-device face stylization applications with the customized model using MediaPipe FaceStylizer.

Few-shot on-device face stylization

An end-to-end pipeline

Our goal is to build a pipeline that lets users adapt the MediaPipe FaceStylizer to different styles by fine-tuning the model with a few examples. To enable this, we built the pipeline from a GAN inversion encoder and an efficient face generator model (see below). The encoder and generator can then be adapted to different styles via a few-shot learning process. The user first sends one or a few similar style images to MediaPipe Model Maker to fine-tune the model. The fine-tuning process freezes the encoder module and only fine-tunes the generator. During training, multiple latent codes are sampled close to the encoding of the input style images and used as input to the generator. The generator is then trained to reconstruct an image of a person’s face in the style of the input style image by optimizing a joint adversarial loss function that also accounts for style and content. With such a fine-tuning process, the MediaPipe FaceStylizer can adapt to the customized style, which approximates the user’s input. It can then be applied to stylize test images of real human faces.
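The post does not say how latent codes "close to" the style encoding are drawn. As a minimal sketch, assuming small Gaussian perturbations around the encoder output (the function name, `sigma`, and the toy latent vector are all hypothetical), the sampling step could look like this:

```python
import random
from typing import List


def sample_latents_near(style_latent: List[float],
                        num_samples: int = 4,
                        sigma: float = 0.05,
                        seed: int = 0) -> List[List[float]]:
    """Draw latent codes near a style encoding by adding Gaussian noise.

    Assumption: isotropic Gaussian noise; the actual sampling scheme used
    during fine-tuning is not described in the post.
    """
    rng = random.Random(seed)
    return [
        [value + rng.gauss(0.0, sigma) for value in style_latent]
        for _ in range(num_samples)
    ]


# Toy 4-dimensional "encoding" of a style image (illustrative values only).
style_latent = [0.2, -0.1, 0.7, 0.0]
latents = sample_latents_near(style_latent)
```

Each sampled code would then be passed to the generator while the encoder stays frozen.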

Generator: BlazeStyleGAN

The StyleGAN model family has been widely adopted for face generation and various face editing tasks. To support efficient on-device face generation, we based the design of our generator on StyleGAN. This generator, which we call BlazeStyleGAN, is similar to StyleGAN in that it also contains a mapping network and a synthesis network. However, since the synthesis network of StyleGAN is the major contributor to the model’s high computational complexity, we designed and employed a more efficient synthesis network. The improved efficiency and generation quality are achieved by:

  1. Reducing the latent feature dimension in the synthesis network to a quarter of the resolution of the counterpart layers in the teacher StyleGAN,
  2. Designing multiple auxiliary heads to transform the downscaled feature to the image domain to form a coarse-to-fine image pyramid to evaluate the perceptual quality of the reconstruction, and
  3. Skipping all but the final auxiliary head at inference time.

With the newly designed architecture, we train the BlazeStyleGAN model by distilling it from a teacher StyleGAN model. We use a multi-scale perceptual loss and adversarial loss in the distillation to transfer the high fidelity generation capability from the teacher model to the student BlazeStyleGAN model and also to mitigate the artifacts from the teacher model.

More details of the model architecture and training scheme can be found in our paper.

Visual comparison between face samples generated by StyleGAN and BlazeStyleGAN. The images on the first row are generated by the teacher StyleGAN. The images on the second row are generated by the student BlazeStyleGAN. The face generated by BlazeStyleGAN has similar visual quality to the image generated by the teacher model. Some results demonstrate the student BlazeStyleGAN suppresses the artifacts from the teacher model in the distillation.

The figure above shows sample results from BlazeStyleGAN. Compared with the face images generated by the teacher StyleGAN model (top row), the images generated by the student BlazeStyleGAN (bottom row) maintain high visual quality and further reduce artifacts produced by the teacher, owing to the loss function design in our distillation.

An encoder for efficient GAN inversion

To support image-to-image stylization, we also introduced an efficient GAN inversion as the encoder to map input images to the latent space of the generator. The encoder is defined by a MobileNet V2 backbone and trained with natural face images. The loss is defined as a combination of image perceptual quality loss, which measures the content difference, style similarity and embedding distance, as well as the L1 loss between the input images and reconstructed images.

On-device performance

We documented model complexities in terms of parameter numbers and computing FLOPs in the following table. Compared to the teacher StyleGAN (33.2M parameters), BlazeStyleGAN (generator) significantly reduces the model complexity, with only 2.01M parameters and 1.28G FLOPs for output resolution 256×256. Compared to StyleGAN-1024 (generating image size of 1024×1024), the BlazeStyleGAN-1024 can reduce both model size and computation complexity by 95% with no notable quality difference and can even suppress the artifacts from the teacher StyleGAN model.

Model            Image Size    #Params (M)    FLOPs (G)
StyleGAN         1024          33.17          74.3
BlazeStyleGAN    1024          2.07           4.70
BlazeStyleGAN    512           2.05           1.57
BlazeStyleGAN    256           2.01           1.28
Encoder          256           1.44           0.60
Model complexity measured by parameter numbers and FLOPs.
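As a quick check of the reduction quoted above, the savings can be recomputed from the 1024-resolution rows of the table. This is plain arithmetic on the published figures; only the helper name is ours.

```python
# Figures copied from the table: parameters in millions, FLOPs in billions.
TEACHER = {"params_m": 33.17, "flops_g": 74.3}   # StyleGAN, 1024x1024
STUDENT = {"params_m": 2.07, "flops_g": 4.70}    # BlazeStyleGAN, 1024x1024


def reduction_percent(teacher: float, student: float) -> float:
    """Relative saving of the student versus the teacher, in percent."""
    return 100.0 * (1.0 - student / teacher)


param_saving = reduction_percent(TEACHER["params_m"], STUDENT["params_m"])
flop_saving = reduction_percent(TEACHER["flops_g"], STUDENT["flops_g"])
print(f"parameters: {param_saving:.1f}% fewer, FLOPs: {flop_saving:.1f}% fewer")
```

Both savings come out at roughly 94%, close to the approximately 95% figure stated in the text.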

We benchmarked the inference time of the MediaPipe FaceStylizer on various high-end mobile devices and report the results in the table below. Both BlazeStyleGAN-256 and BlazeStyleGAN-512 achieved real-time performance on all GPU devices, and the model can run in less than 10 ms on a high-end phone’s GPU. BlazeStyleGAN-256 can also achieve real-time performance on the iOS devices’ CPU.

Device                BlazeStyleGAN-256 (ms)    Encoder-256 (ms)
iPhone 11             12.14                     11.48
iPhone 12             11.99                     12.25
iPhone 13 Pro         7.22                      5.41
Pixel 6               12.24                     11.23
Samsung Galaxy S10    17.01                     12.70
Samsung Galaxy S20    8.95                      8.20
Latency benchmark of the BlazeStyleGAN, face encoder, and the end-to-end pipeline on various mobile devices.
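To relate these latencies to frame rates, the sketch below converts the table into an approximate end-to-end throughput. It assumes the encoder and generator run back to back for each frame with no other overhead, and it uses 30 frames per second as a stand-in threshold for "real-time"; both are our assumptions, not measurements from the post.

```python
# (generator ms, encoder ms) per device, copied from the table above.
LATENCY_MS = {
    "iPhone 11": (12.14, 11.48),
    "iPhone 12": (11.99, 12.25),
    "iPhone 13 Pro": (7.22, 5.41),
    "Pixel 6": (12.24, 11.23),
    "Samsung Galaxy S10": (17.01, 12.70),
    "Samsung Galaxy S20": (8.95, 8.20),
}


def pipeline_fps(generator_ms: float, encoder_ms: float) -> float:
    """Frames per second if both stages run sequentially on every frame."""
    return 1000.0 / (generator_ms + encoder_ms)


fps = {device: pipeline_fps(*stages) for device, stages in LATENCY_MS.items()}
slowest_device = min(fps, key=fps.get)

for device, value in fps.items():
    print(f"{device}: {value:.1f} fps")
```

Under these assumptions every listed device stays above 30 frames per second, with the Samsung Galaxy S10 being the slowest.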

Fairness evaluation

The model has been trained on a highly diverse dataset of human faces and is expected to be fair across different human faces. The fairness evaluation shows that the model performs well and is balanced across gender, skin tone, and age.

Face stylization visualization

Some face stylization results are shown in the following figure. The images in the top row (in orange boxes) represent the style images used to fine-tune the model. The images in the left column (in the green boxes) are the natural face images used for testing. The 2×4 matrix of images represents the output of the MediaPipe FaceStylizer, which blends the natural faces in the left-most column with the corresponding face styles in the top row. The results demonstrate that our solution can achieve high-quality face stylization for several popular styles.

Sample results of our MediaPipe FaceStylizer.

MediaPipe Solutions

The MediaPipe FaceStylizer is going to be released to public users in MediaPipe Solutions. Users can leverage MediaPipe Model Maker to train a customized face stylization model using their own style images. After training, the exported bundle of TFLite model files can be deployed to applications across platforms (Android, iOS, Web, Python, etc.) using the MediaPipe Tasks FaceStylizer API in just a few lines of code.

Acknowledgements

This work is made possible through a collaboration spanning several teams across Google. We’d like to acknowledge contributions from Omer Tov, Yang Zhao, Andrey Vakunov, Fei Deng, Ariel Ephrat, Inbar Mosseri, Lu Wang, Chuo-Ling Chang, Tingbo Hou, and Matthias Grundmann.


On-device content distillation with graph neural networks

In today’s digital age, smartphones and desktop web browsers serve as the primary tools for accessing news and information. However, the proliferation of website clutter — encompassing complex layouts, navigation elements, and extraneous links — significantly impairs both the reading experience and article navigation. This issue is particularly acute for individuals with accessibility requirements.

To improve the user experience and make reading more accessible, Android and Chrome users may leverage the Reading Mode feature, which enhances accessibility by processing webpages to allow customizable contrast, adjustable text size, more legible fonts, and to enable text-to-speech utilities. Additionally, Android’s Reading Mode is equipped to distill content from apps. Expanding Reading Mode to encompass a wide array of content and improving its performance, while still operating locally on the user’s device without transmitting data externally, poses a unique challenge.

To broaden Reading Mode capabilities without compromising privacy, we have developed a novel on-device content distillation model. Unlike early attempts using DOM Distiller — a heuristic approach limited to news articles — our model excels in both quality and versatility across various types of content. We ensure that article content doesn’t leave the confines of the local environment. Our on-device content distillation model smoothly transforms long-form content into a simple and customizable layout for a more pleasant reading journey while also outperforming the leading alternative approaches. Here we explore details of this research highlighting our approach, methodology, and results.

Graph neural networks

Instead of relying on complicated heuristics that are difficult to maintain and scale to a variety of article layouts, we approach this task as a fully supervised learning problem. This data-driven approach allows the model to generalize better across different layouts, without the constraints and fragility of heuristics. Previous work for optimizing the reading experience relied on HTML or parsing, filtering, and modeling of a document object model (DOM), a programming interface automatically generated by the user’s web browser from site HTML that represents the structure of a document and allows it to be manipulated.

The new Reading Mode model relies on accessibility trees, which provide a streamlined and more accessible representation of the DOM. Accessibility trees are automatically generated from the DOM tree and are utilized by assistive technologies to allow people with disabilities to interact with web content. These are available on Chrome Web browser and on Android through AccessibilityNodeInfo objects, which are provided for both WebView and native application content.

We started by manually collecting and annotating accessibility trees. The Android dataset used for this project comprises on the order of 10k labeled examples, while the Chrome dataset contains approximately 100k labeled examples. We developed a novel tool that uses graph neural networks (GNNs) to distill essential content from the accessibility trees using a multi-class supervised learning approach. The datasets consist of long-form articles sampled from the web and labeled with classes such as headline, paragraph, images, publication date, etc.

GNNs are a natural choice for dealing with tree-like data structures, because unlike traditional models that often demand detailed, hand-crafted features to understand the layout and links within such trees, GNNs learn these connections naturally. To illustrate this, consider the analogy of a family tree. In such a tree, each node represents a family member and the connections denote familial relationships. If one were to predict certain traits using conventional models, features like the “number of immediate family members with a trait” might be needed. However, with GNNs, such manual feature crafting becomes redundant. By directly feeding the tree structure into the model, GNNs utilize a message-passing mechanism where each node communicates with its neighbors. Over time, information gets shared and accumulated across the network, enabling the model to naturally discern intricate relationships.

Returning to the context of accessibility trees, this means that GNNs can efficiently distill content by understanding and leveraging the inherent structure and relationships within the tree. This capability allows them to identify and possibly omit non-essential sections based on the information flow within the tree, ensuring more accurate content distillation.

Our architecture heavily follows the encode-process-decode paradigm using a message-passing neural network to classify text nodes. The overall design is illustrated in the figure below. The tree representation of the article is the input to the model. We compute lightweight features based on bounding box information, text information, and accessibility roles. The GNN then propagates each node’s latent representation through the edges of the tree using a message-passing neural network. This propagation process allows nearby nodes, containers, and text elements to share contextual information with each other, enhancing the model’s understanding of the page’s structure and content. Each node then updates its current state based on the message received, providing a more informed basis for classifying the nodes. After a fixed number of message-passing steps, the now contextualized latent representations of the nodes are decoded into essential or non-essential classes. This approach enables the model to leverage both the inherent relationships in the tree and the hand-crafted features representing each node, thereby enriching the final classification.

A visual demonstration of the algorithm in action, processing an article on a mobile device. A graph neural network (GNN) is used to distill essential content from an article. 1. A tree representation of the article is extracted from the application. 2. Lightweight features are computed for each node, represented as vectors. 3. A message-passing neural network propagates information through the edges of the tree and updates each node representation. 4. Leaf nodes containing text content are classified as essential or non-essential content. 5. A decluttered version of the application is composed based on the GNN output.
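The propagation step can be illustrated with a toy version of message passing on a tree. This is a sketch only: a real GNN uses learned message and update functions, whereas here each node simply averages its neighbors' states and mixes the result with its own state. The node names, the two-element feature vectors, and the 0.5 mixing weight are hypothetical.

```python
from typing import Dict, List, Tuple


def message_passing(features: Dict[str, List[float]],
                    edges: List[Tuple[str, str]],
                    steps: int = 2) -> Dict[str, List[float]]:
    """Propagate node states along tree edges for a fixed number of steps."""
    neighbors: Dict[str, List[str]] = {node: [] for node in features}
    for parent, child in edges:
        neighbors[parent].append(child)
        neighbors[child].append(parent)

    state = {node: list(vec) for node, vec in features.items()}
    for _ in range(steps):
        new_state = {}
        for node, vec in state.items():
            nbrs = neighbors[node]
            if nbrs:
                # Message: mean of the neighbors' current states.
                message = [sum(state[n][i] for n in nbrs) / len(nbrs)
                           for i in range(len(vec))]
            else:
                message = [0.0] * len(vec)
            # Update: fixed 50/50 mix (a learned function in a real GNN).
            new_state[node] = [0.5 * own + 0.5 * msg
                               for own, msg in zip(vec, message)]
        state = new_state
    return state


# Toy accessibility tree; features are [is_text, relative_text_length].
features = {
    "root": [0.0, 0.0], "article": [0.0, 0.0], "nav": [0.0, 0.0],
    "headline": [1.0, 0.1], "paragraph": [1.0, 0.9], "link": [1.0, 0.05],
}
edges = [("root", "article"), ("root", "nav"), ("article", "headline"),
         ("article", "paragraph"), ("nav", "link")]

after_one_step = message_passing(features, edges, steps=1)
after_two_steps = message_passing(features, edges, steps=2)
```

After one step the container nodes already carry information about their text children; after two steps that information has reached the root. A classifier would then decode each node's final state into essential or non-essential.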

We deliberately restrict the feature set used by the model to improve its generalization across languages and reduce inference latency on user devices. This was a unique challenge, as we needed to create an on-device lightweight model that could preserve privacy.

Our final lightweight Android model has 64k parameters and is 334kB in size with a median latency of 800ms, while the Chrome model has 241k parameters, is 928kB in size, and has a 378ms median latency. By employing such on-device processing, we ensure that user data never leaves the device, reinforcing our responsible approach and commitment to user privacy. The features used in the model can be grouped into intermediate node features, leaf-node text features, and element position features. We performed feature engineering and feature selection to optimize the set of features for model performance and model size. The final model was transformed into TensorFlow Lite format to deploy as an on-device model on Android or Chrome.

Results

We trained the GNN for about 50 epochs on a single GPU. The performance of the Android model on webpages and native application test sets is presented below:

The table presents the content distillation metrics in Android for webpages and native apps. We report precision, recall and F1-score for three classes: non-essential content, headline, and main body text, including macro average and weighted average by number of instances in each class. Node metrics assess the classification performance at the granularity of the accessibility tree node, which is analogous to a paragraph level. In contrast, word metrics evaluate classification at an individual word level, meaning each word within a node gets the same classification.

Android Quality
                        Webpages                             Native Apps
node metrics            Precision    Recall    F1-score     Precision    Recall    F1-score
non-essential           0.9842       0.9846    0.9844       0.9744       0.9350    0.9543
headline                0.9187       0.9784    0.9476       0.9183       0.8568    0.8865
main-text               0.9223       0.9172    0.9197       0.8443       0.9424    0.8907
macro-average           0.9417       0.9600    0.9506       0.9124       0.9114    0.9105
weighted average        0.9736       0.9736    0.9736       0.9392       0.9353    0.9363
headline + main-text    0.9510       0.9683    0.9595       0.9473       0.9507    0.9490
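For readers who want to see how the F1-score and macro-average rows relate to the per-class precision and recall, the sketch below recomputes them from the webpage node metrics. The weighted average additionally needs per-class instance counts, which the post does not report, so the counts passed to `weighted_average` in the example are purely hypothetical.

```python
from typing import Dict, List, Tuple


def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2.0 * precision * recall / (precision + recall)


def macro_average(values: List[float]) -> float:
    """Unweighted mean over classes."""
    return sum(values) / len(values)


def weighted_average(values: List[float], supports: List[int]) -> float:
    """Mean over classes, weighted by the number of instances per class."""
    return sum(v * s for v, s in zip(values, supports)) / sum(supports)


# (precision, recall) for webpages at node level, copied from the table.
WEBPAGE_NODE: Dict[str, Tuple[float, float]] = {
    "non-essential": (0.9842, 0.9846),
    "headline": (0.9187, 0.9784),
    "main-text": (0.9223, 0.9172),
}

f1_by_class = {name: f1_score(p, r) for name, (p, r) in WEBPAGE_NODE.items()}
macro_f1 = macro_average(list(f1_by_class.values()))

# Hypothetical class counts, only to show how the weighting works.
example_weighted_f1 = weighted_average(list(f1_by_class.values()),
                                       [800, 20, 180])
```

The recomputed per-class F1-scores and the macro average agree with the table to four decimal places.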

In assessing the results’ quality on commonly visited webpage articles, an F1-score exceeding 0.9 for main-text (essentially paragraphs) corresponds to 88% of these articles being processed without missing any paragraphs. Furthermore, in over 95% of cases, the distillation proves to be valuable for readers. Put simply, the vast majority of readers will perceive the distilled content as both pertinent and precise, with errors or omissions being an infrequent occurrence.

The comparison of Chrome content distillation with other models such as DOM Distiller or Mozilla Readability on a set of English language pages is presented in the table below. We reuse metrics from machine translation to compare the quality of these models: the ground-truth main content serves as the reference text, and the text extracted by each model serves as the hypothesis text. The results show the excellent performance of our models in comparison to other DOM-based approaches.

The table presents the comparison between DOM-Distiller, Mozilla Readability and the new Chrome model. We report text-based metrics, such as BLEU, CHRF and ROUGE, by comparing the main body text distilled from each model to a ground-truth text manually labeled by raters using our annotation policy.

Chrome Model Comparison on Webpages
Metric / Model    DOM Distiller    Mozilla Readability    Our Chrome model
BLEU              78.97            79.16                  94.59
CHRF              0.92             0.92                   0.98
ROUGE1            84.10            84.62                  95.13
ROUGE2            81.84            82.66                  94.81
ROUGE3            80.21            81.45                  94.60
ROUGEL            83.58            84.02                  95.04
ROUGEL-SUM        83.46            84.03                  95.04
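To make the reference/hypothesis comparison concrete, here is a minimal sketch of a ROUGE-1 style score based on clipped unigram overlap. It assumes lowercase whitespace tokenization and returns values between 0 and 1 (the table reports ROUGE on a 0 to 100 scale); the evaluation tooling actually used is not described in the post, and the example strings are made up.

```python
from collections import Counter
from typing import Tuple


def rouge1(reference: str, hypothesis: str) -> Tuple[float, float, float]:
    """Return (precision, recall, F-measure) over unigram overlap."""
    ref_counts = Counter(reference.lower().split())
    hyp_counts = Counter(hypothesis.lower().split())
    # Clipped overlap: each word counts at most as often as in the reference.
    overlap = sum((ref_counts & hyp_counts).values())
    precision = overlap / max(sum(hyp_counts.values()), 1)
    recall = overlap / max(sum(ref_counts.values()), 1)
    if overlap == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)


# The hypothesis keeps most of the article but also leaks clutter
# ("subscribe now") and drops one word ("jumps").
reference_text = "The quick brown fox jumps"
distilled_text = "Subscribe now the quick brown fox"
precision, recall, f_measure = rouge1(reference_text, distilled_text)
```

Leftover clutter lowers precision, while missing article text lowers recall, which is why these metrics suit content distillation.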

The F1-score of the Chrome content distillation model for headline and main text content on the test sets of different widely spoken languages demonstrates that the Chrome model, in particular, is able to support a wide range of languages.

The table presents per-language F1-scores of the Chrome model for the headline and main text classes. The language codes correspond to the following languages: German, English, Spanish, French, Italian, Persian, Japanese, Korean, Portuguese, Vietnamese, simplified Chinese and traditional Chinese.

Chrome Model on Different Languages
F1-score     de      en      es      fr      it      fa      ja      ko      pt      vi      zh-Hans    zh-Hant    average
headline     0.91    0.97    0.99    0.98    0.97    0.89    0.97    0.98    0.99    0.98    0.97       0.93       0.96
main text    0.84    0.90    0.93    0.91    0.93    0.87    0.88    0.91    0.91    0.90    0.90       0.90       0.90

Conclusion

The digital age demands both streamlined content presentation and an unwavering commitment to user privacy. Our research highlights the effectiveness of Reading Mode in platforms like Android and Chrome, offering an innovative, data-driven approach to content parsing through Graph Neural Networks. Crucially, our lightweight on-device model ensures that content distillation occurs without compromising user data, with all processes executed locally. This not only enhances the reading experience but also reinforces our dedication to user privacy. As we navigate the evolving landscape of digital content consumption, our findings underscore the paramount importance of prioritizing the user in both experience and security.

Acknowledgements

This project is the result of joint work with Manuel Tragut, Mihai Popa, Abodunrinwa Toki, Abhanshu Sharma, Matt Sharifi, David Petrou and Blaise Aguera y Arcas. We sincerely thank our collaborators Gang Li and Yang Li. We are very grateful to Tom Small for assisting us in preparing the post.


World scale inverse reinforcement learning in Google Maps

Routing in Google Maps remains one of our most helpful and frequently used features. Determining the best route from A to B requires making complex trade-offs between factors including the estimated time of arrival (ETA), tolls, directness, surface conditions (e.g., paved, unpaved roads), and user preferences, which vary across transportation mode and local geography. Often, the most natural visibility we have into travelers’ preferences is by analyzing real-world travel patterns.

Learning preferences from observed sequential decision making behavior is a classic application of inverse reinforcement learning (IRL). Given a Markov decision process (MDP) — a formalization of the road network — and a set of demonstration trajectories (the traveled routes), the goal of IRL is to recover the users’ latent reward function. Although past research has created increasingly general IRL solutions, these have not been successfully scaled to world-sized MDPs. Scaling IRL algorithms is challenging because they typically require solving an RL subroutine at every update step. At first glance, even attempting to fit a world-scale MDP into memory to compute a single gradient step appears infeasible due to the large number of road segments and limited high bandwidth memory. When applying IRL to routing, one needs to consider all reasonable routes between each demonstration’s origin and destination. This implies that any attempt to break the world-scale MDP into smaller components cannot consider components smaller than a metropolitan area.

To this end, in “Massively Scalable Inverse Reinforcement Learning in Google Maps”, we share the result of a multi-year collaboration among Google Research, Maps, and Google DeepMind to surpass this IRL scalability limitation. We revisit classic algorithms in this space, and introduce advances in graph compression and parallelization, along with a new IRL algorithm called Receding Horizon Inverse Planning (RHIP) that provides fine-grained control over performance trade-offs. The final RHIP policy achieves a 16–24% relative improvement in global route match rate, i.e., the percentage of de-identified traveled routes that exactly match the suggested route in Google Maps. To the best of our knowledge, this represents the largest instance of IRL in a real world setting to date.

Google Maps improvements in route match rate relative to the existing baseline, when using the RHIP inverse reinforcement learning policy.

The benefits of IRL

A subtle but crucial detail about the routing problem is that it is goal conditioned, meaning that every destination state induces a slightly different MDP (specifically, the destination is a terminal, zero-reward state). IRL approaches are well suited for these types of problems because the learned reward function transfers across MDPs, and only the destination state is modified. This is in contrast to approaches that directly learn a policy, which typically require an extra factor of S parameters, where S is the number of MDP states.

Once the reward function is learned via IRL, we take advantage of a powerful inference-time trick. First, we evaluate the entire graph’s rewards once in an offline batch setting. This computation is performed entirely on servers without access to individual trips, and operates only over batches of road segments in the graph. Then, we save the results to an in-memory database and use a fast online graph search algorithm to find the highest reward path for routing requests between any origin and destination. This circumvents the need to perform online inference of a deeply parameterized model or policy, and vastly improves serving costs and latency.
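The two-phase serving pattern described above can be sketched as follows. This is an illustration under stated assumptions, not the production system: the toy reward function, the segment features, and the tiny graph are hypothetical, and we assume all rewards are non-positive so that Dijkstra's algorithm can be run on costs equal to the negated rewards.

```python
import heapq
from typing import Callable, Dict, List, Optional, Tuple

Graph = Dict[str, List[Tuple[str, str]]]  # node -> [(next_node, segment_id)]


def precompute_rewards(segments: Dict[str, dict],
                       reward_fn: Callable[[dict], float]) -> Dict[str, float]:
    """Offline batch step: score every road segment exactly once."""
    return {seg_id: reward_fn(feats) for seg_id, feats in segments.items()}


def highest_reward_path(graph: Graph, rewards: Dict[str, float],
                        origin: str, destination: str
                        ) -> Optional[Tuple[float, List[str]]]:
    """Online step: Dijkstra on cost = -reward (requires rewards <= 0)."""
    frontier = [(0.0, origin, [origin])]
    settled: Dict[str, float] = {}
    while frontier:
        cost, node, path = heapq.heappop(frontier)
        if node == destination:
            return -cost, path
        if node in settled and settled[node] <= cost:
            continue
        settled[node] = cost
        for nxt, seg_id in graph.get(node, []):
            heapq.heappush(frontier,
                           (cost - rewards[seg_id], nxt, path + [nxt]))
    return None  # destination not reachable


# Hypothetical segment features and a stand-in for the learned reward model.
segments = {
    "ab": {"minutes": 5, "unpaved": 0}, "bd": {"minutes": 5, "unpaved": 0},
    "ac": {"minutes": 3, "unpaved": 1}, "cd": {"minutes": 3, "unpaved": 1},
}
graph = {"A": [("B", "ab"), ("C", "ac")], "B": [("D", "bd")], "C": [("D", "cd")]}

rewards = precompute_rewards(segments, lambda f: -(f["minutes"] + 4 * f["unpaved"]))
best_route = highest_reward_path(graph, rewards, "A", "D")
```

The expensive model is evaluated once per segment offline, so each routing request only pays for a lookup-based graph search.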

Reward model deployment using batch inference and fast online planners.

Receding Horizon Inverse Planning

To scale IRL to the world MDP, we compress the graph and shard the global MDP using a sparse Mixture of Experts (MoE) based on geographic regions. We then apply classic IRL algorithms to solve the local MDPs, estimate the loss, and send gradients back to the MoE. The worldwide reward graph is computed by decompressing the final MoE reward model. To provide more control over performance characteristics, we introduce a new generalized IRL algorithm called Receding Horizon Inverse Planning (RHIP).

IRL reward model training using MoE parallelization, graph compression, and RHIP.

RHIP is inspired by people’s tendency to perform extensive local planning (“What am I doing for the next hour?”) and approximate long-term planning (“What will my life look like in 5 years?”). To take advantage of this insight, RHIP uses robust yet expensive stochastic policies in the local region surrounding the demonstration path, and switches to cheaper deterministic planners beyond some horizon. Adjusting the horizon H allows controlling computational costs, and often allows the discovery of the performance sweet spot. Interestingly, RHIP generalizes many classic IRL algorithms and provides the novel insight that they can be viewed along a stochastic vs. deterministic spectrum (specifically, for H=∞ it reduces to MaxEnt, for H=1 it reduces to BIRL, and for H=0 it reduces to MMP).

Given a demonstration from s_o to s_d, (1) RHIP follows a robust yet expensive stochastic policy in the local region surrounding the demonstration (blue region). (2) Beyond some horizon H, RHIP switches to following a cheaper deterministic planner (red lines). Adjusting the horizon enables fine-grained control over performance and computational costs.
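The interplay between the two planners can be illustrated with a deliberately simplified value computation. This sketch is loosely modeled on the description above and is not the algorithm from the paper: it starts from deterministic best-path values and then applies H soft (log-sum-exp) backups, so H = 0 is purely deterministic and larger H behaves more stochastically. The graph, rewards, and function names are hypothetical, and rewards are assumed to be non-positive.

```python
import math
from typing import Dict, List, Tuple

Graph = Dict[str, List[Tuple[str, str]]]  # node -> [(next_node, segment_id)]


def deterministic_values(graph: Graph, rewards: Dict[str, float],
                         destination: str) -> Dict[str, float]:
    """Total reward of the best path from each node to the destination."""
    nodes = set(graph) | {n for outs in graph.values() for n, _ in outs}
    nodes.add(destination)
    values = {n: -math.inf for n in nodes}
    values[destination] = 0.0
    for _ in range(len(nodes)):  # Bellman-Ford style sweeps
        for node in nodes:
            if node == destination:
                continue
            for nxt, seg in graph.get(node, []):
                values[node] = max(values[node], rewards[seg] + values[nxt])
    return values


def receding_horizon_values(graph: Graph, rewards: Dict[str, float],
                            destination: str, horizon: int) -> Dict[str, float]:
    """Apply `horizon` soft backups on top of the deterministic values."""
    values = deterministic_values(graph, rewards, destination)
    for _ in range(horizon):
        updated = dict(values)
        for node in values:
            if node == destination:
                continue
            options = [rewards[seg] + values[nxt]
                       for nxt, seg in graph.get(node, [])
                       if values[nxt] > -math.inf]
            if options:
                top = max(options)  # stabilised log-sum-exp
                updated[node] = top + math.log(
                    sum(math.exp(o - top) for o in options))
        values = updated
    return values


graph = {"A": [("B", "ab"), ("C", "ac")], "B": [("D", "bd")], "C": [("D", "cd")]}
rewards = {"ab": -5.0, "bd": -5.0, "ac": -7.0, "cd": -7.0}

values_h0 = receding_horizon_values(graph, rewards, "D", horizon=0)
values_h1 = receding_horizon_values(graph, rewards, "D", horizon=1)
```

With H = 0 only the single best route contributes to a node's value; each additional soft backup lets alternative routes near the node contribute as well, at extra computational cost.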

Routing wins

The RHIP policy provides a 15.9% and 24.1% lift in global route match rate for driving and two-wheelers (e.g., scooters, motorcycles, mopeds) relative to the well-tuned Maps baseline, respectively. We’re especially excited about the benefits to more sustainable transportation modes, where factors beyond journey time play a substantial role. By tuning RHIP’s horizon H, we’re able to achieve a policy that is both more accurate than all other IRL policies and 70% faster than MaxEnt.

Our 360M parameter reward model provides intuitive wins for Google Maps users in live A/B experiments. Examining road segments with a large absolute difference between the learned rewards and the baseline rewards can help improve certain Google Maps routes. For example:

Nottingham, UK. The preferred route (blue) was previously marked as private property due to the presence of a large gate, which indicated to our systems that the road may be closed at times and would not be ideal for drivers. As a result, Google Maps routed drivers through a longer, alternate detour instead (red). However, because real-world driving patterns showed that users regularly take the preferred route without an issue (as the gate is almost never closed), IRL now learns to route drivers along the preferred route by placing a large positive reward on this road segment.

Conclusion

Increasing performance via increased scale – both in terms of dataset size and model complexity – has proven to be a persistent trend in machine learning. Similar gains for inverse reinforcement learning problems have historically remained elusive, largely due to the challenges with handling practically sized MDPs. By introducing scalability advancements to classic IRL algorithms, we’re now able to train reward models on problems with hundreds of millions of states, demonstration trajectories, and model parameters. To the best of our knowledge, this is the largest instance of IRL in a real-world setting to date. See the paper to learn more about this work.

Acknowledgements

This work is a collaboration across multiple teams at Google. Contributors to the project include Matthew Abueg, Oliver Lange, Matt Deeds, Jason Trader, Denali Molitor, Markus Wulfmeier, Shawn O’Banion, Ryan Epp, Renaud Hartert, Rui Song, Thomas Sharp, Rémi Robert, Zoltan Szego, Beth Luan, Brit Larabee and Agnieszka Madurska.

We’d also like to extend our thanks to Arno Eigenwillig, Jacob Moorman, Jonathan Spencer, Remi Munos, Michael Bloesch and Arun Ahuja for valuable discussions and suggestions.


Differentially private median and more

Differential privacy (DP) is a rigorous mathematical definition of privacy. DP algorithms are randomized to protect user data by ensuring that the probability of any particular output is nearly unchanged when a data point is added or removed. Therefore, the output of a DP algorithm does not disclose the presence of any one data point. There has been significant progress in both foundational research and adoption of differential privacy with contributions such as the Privacy Sandbox and Google Open Source Library.

ML and data analytics algorithms can often be described as performing multiple basic computation steps on the same dataset. When each such step is differentially private, so is the output, but with multiple steps the overall privacy guarantee deteriorates, a phenomenon known as the cost of composition. Composition theorems bound the increase in privacy loss with the number k of computations: In the general case, the privacy loss increases with the square root of k. This means that we need much stricter privacy guarantees for each step in order to meet our overall privacy guarantee goal. But in that case, we lose utility. One way to improve the privacy vs. utility trade-off is to identify when the use cases admit a tighter privacy analysis than what follows from composition theorems.
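To give a feel for the square-root growth, the sketch below compares basic composition (privacy loss adds up linearly) with the advanced composition theorem of Dwork, Rothblum and Vadhan. The post does not name a specific theorem, so choosing this one is our assumption. For k mechanisms that are each (ε, δ)-DP, the theorem gives an overall guarantee of (ε', kδ + δ') with ε' = ε·sqrt(2k·ln(1/δ')) + k·ε·(e^ε − 1).

```python
import math


def basic_composition(epsilon: float, k: int) -> float:
    """Overall epsilon when k losses are simply added."""
    return k * epsilon


def advanced_composition(epsilon: float, k: int, delta_slack: float) -> float:
    """Overall epsilon under advanced composition, with extra slack delta'."""
    return (epsilon * math.sqrt(2.0 * k * math.log(1.0 / delta_slack))
            + k * epsilon * (math.exp(epsilon) - 1.0))


per_step_epsilon = 0.1
steps = 100
basic_total = basic_composition(per_step_epsilon, steps)
advanced_total = advanced_composition(per_step_epsilon, steps, delta_slack=1e-6)
print(f"basic: {basic_total:.2f}, advanced: {advanced_total:.2f}")
```

For these example values the advanced bound is clearly smaller than the linear one, but the overall loss still grows with the number of steps, which is the cost that a tighter, use-case-specific analysis aims to avoid. For very small k the advanced bound can be looser than the basic one.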

Good candidates for such improvement are when each step is applied to a disjoint part (slice) of the dataset. When the slices are selected in a data-independent way, each point affects only one of the k outputs and the privacy guarantees do not deteriorate with k. However, there are applications in which we need to select the slices adaptively (that is, in a way that depends on the output of prior steps). In these cases, a change of a single data point may cascade — changing multiple slices and thus increasing composition cost.

In “Õptimal Differentially Private Learning of Thresholds and Quasi-Concave Optimization”, presented at STOC 2023, we describe a new paradigm that allows for slices to be selected adaptively and yet avoids composition cost. We show that DP algorithms for multiple fundamental aggregation and learning tasks can be expressed in this Reorder-Slice-Compute (RSC) paradigm, gaining significant improvements in utility.

The Reorder-Slice-Compute (RSC) paradigm

An algorithm A falls in the RSC paradigm if it can be expressed in the following general form (see visualization below). The input is a sensitive set D of data points. The algorithm then performs a sequence of k steps as follows:

  1. Select an ordering over data points, a slice size m, and a DP algorithm M. The selection may depend on the output of A in prior steps (and hence is adaptive).
  2. Slice out the (approximately) top m data points according to the order from the dataset D, apply M to the slice, and output the result.
A visualization of three Reorder-Slice-Compute (RSC) steps.
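The general form above can be sketched as a short loop. This is only a structural illustration under simplifying assumptions: the hypothetical `choose_step` callback stands in for the adaptive selection, the caller must supply a mechanism M that is itself differentially private, and the sketch takes exactly m points where the paradigm slices approximately m. The privacy analysis itself is in the paper and is not reproduced here.

```python
from typing import Any, Callable, List, Sequence, Tuple

# A step is (ordering key, slice size m, mechanism M). It is chosen from the
# outputs released so far, which is what makes the selection adaptive.
Step = Tuple[Callable[[Any], float], int, Callable[[Sequence[Any]], Any]]


def reorder_slice_compute(
    data: Sequence[Any],
    choose_step: Callable[[List[Any]], Step],
    k: int,
) -> List[Any]:
    """Run k Reorder-Slice-Compute steps and return the k released outputs."""
    remaining = list(data)
    outputs: List[Any] = []
    for _ in range(k):
        key, m, mechanism = choose_step(outputs)           # adaptive selection
        remaining.sort(key=key)                            # Reorder
        sliced, remaining = remaining[:m], remaining[m:]   # Slice (removed from D)
        outputs.append(mechanism(sliced))                  # Compute
    return outputs
```

Each sliced-out point is removed from the dataset, so every data point is passed to at most one mechanism call.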

If we analyze the overall privacy loss of an RSC algorithm using DP composition theorems, the privacy guarantee suffers from the expected composition cost, i.e., it deteriorates with the square root of the number of steps k. To eliminate this composition cost, we provide a novel analysis that removes the dependence on k altogether: the overall privacy guarantee is close to that of a single step! The idea behind our tighter analysis is a novel technique that limits the potential cascade of affected steps when a single data point is modified (details in the paper).

Tighter privacy analysis means better utility. The effectiveness of DP algorithms is often stated in terms of the smallest input size (number of data points) that suffices in order to release a correct result that meets the privacy requirements. We describe several problems with algorithms that can be expressed in the RSC paradigm and for which our tighter analysis improved utility.

Private interval point

We start with the following basic aggregation task. The input is a dataset D of n points from an ordered domain X (think of the domain as the natural numbers between 1 and |X|). The goal is to return a point y in X that lies in the interval of D, that is, between the minimum and the maximum points in D.

The solution to the interval point problem is trivial without the privacy requirement: simply return any point in the dataset D. But this solution is not privacy-preserving as it discloses the presence of a particular datapoint in the input. We can also see that if there is only one point in the dataset, a privacy-preserving solution is not possible, as it must return that point. We can therefore ask the following fundamental question: What is the smallest input size N for which we can solve the private interval point problem?

It is known that N must increase with the domain size |X| and that this dependence is at least the iterated log function log* |X| [1, 2]. On the other hand, the best prior DP algorithm required the input size to be at least (log* |X|)^1.5. To close this gap, we designed an RSC algorithm that requires only on the order of log* |X| points.

The iterated log function is extremely slow growing: It is the number of times we need to take a logarithm of a value before we reach a value that is equal to or smaller than 1. How does this function arise naturally in the analysis? Each step of the RSC algorithm remaps the domain to a logarithm of its prior size, so there are log* |X| steps in total. The tighter RSC analysis eliminates a factor of the square root of the number of steps from the required input size.
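A few lines of code show how slowly the iterated logarithm grows. The sketch below assumes base-2 logarithms; the text does not fix a base, and a different base changes the value by at most a small additive constant.

```python
import math


def log_star(x: float) -> int:
    """Number of times log2 must be applied before the value drops to <= 1."""
    count = 0
    while x > 1:
        x = math.log2(x)
        count += 1
    return count


if __name__ == "__main__":
    for value in (2, 16, 65536, 2 ** 65536):
        print(log_star(value))
```

Even for a domain of size 2^65536 the function returns only 5.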

Even though the interval point task seems very basic, it captures the essence of the difficulty of private solutions for common aggregation tasks. We next describe two of these tasks and express the required input size to these tasks in terms of N.

Private approximate median

One of these common aggregation tasks is approximate median: The input is a dataset D of n points from an ordered domain X. The goal is to return a point y that is between the ⅓ and ⅔ quantiles of D. That is, at least a third of the points in D are smaller than or equal to y and at least a third of the points are larger than or equal to y. Note that returning an exact median is not possible with differential privacy, since it discloses the presence of a data point. Hence we consider the relaxed requirement of an approximate median (shown below).

We can compute an approximate median by finding an interval point: We slice out the N smallest points and the N largest points and then compute an interval point of the remaining points. The latter must be an approximate median. This works when the dataset size is at least 3N.
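The reduction can be sketched as follows. The `interval_point` argument is a placeholder for a private interval point routine, which is not implemented here, and the sketch sorts and trims exactly rather than using approximate slices. It trims a third of the points from each end, which coincides with slicing N points from each side when the dataset has exactly 3N points; that generalization to larger datasets is an assumption of this sketch.

```python
from typing import Callable, Sequence


def approximate_median(
    data: Sequence[float],
    n_required: int,
    interval_point: Callable[[Sequence[float]], float],
) -> float:
    """Reduce approximate median to interval point.

    n_required plays the role of N, the input size the interval point routine
    needs. interval_point must return a value between the minimum and the
    maximum of the points it is given.
    """
    n = len(data)
    if n_required < 1 or n < 3 * n_required:
        raise ValueError("need at least 3N points")
    ordered = sorted(data)
    trim = n // 3                      # points sliced out at each end
    middle = ordered[trim:n - trim]    # at least N points remain here
    # Any value inside the range of `middle` has more than n/3 points on
    # each side (counting ties), so it is an approximate median.
    return interval_point(middle)
```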

An example of a data D over domain X, the set of interval points, and the set of approximate medians.

Private learning of axis-aligned rectangles

For the next task, the input is a set of n labeled data points, where each point x = (x1, …, xd) is a d-dimensional vector over a domain X. The goal, illustrated below, is to learn values ai, bi for the axes i = 1, …, d that define a d-dimensional rectangle, so that for each example x:

  • If x is positively labeled (shown as red plus signs below) then it lies within the rectangle, that is, for all axes i, xi is in the interval [ai, bi], and
  • If x is negatively labeled (shown as blue minus signs below) then it lies outside the rectangle, that is, for at least one axis i, xi is outside the interval [ai, bi].
A set of 2-dimensional labeled points and a respective rectangle.

Any DP solution for this problem must be approximate in that the learned rectangle must be allowed to mislabel some data points, with some positively labeled points outside the rectangle or negatively labeled points inside it. This is because an exact solution could be very sensitive to the presence of a particular data point and would not be private. The goal is a DP solution that keeps this necessary number of mislabeled points small.

We first consider the one-dimensional case (d = 1). We are looking for an interval [a,b] that covers all positive points and none of the negative points. We show that we can do this with at most 2N mislabeled points. We focus on the positively labeled points. In the first RSC step we slice out the N smallest points and compute a private interval point as a. We then slice out the N largest points and compute a private interval point as b. The solution [a,b] correctly labels all negatively labeled points and mislabels at most 2N of the positively labeled points. Thus, at most ~2N points are mislabeled in total.

Illustration for d = 1: we slice out the N leftmost positive points and compute an interval point a, then slice out the N rightmost positive points and compute an interval point b.

With d > 1, we iterate over the axes i = 1, …, d and apply the above to the ith coordinates of the input points to obtain the values ai, bi. In each iteration, we perform two RSC steps and slice out 2N positively labeled points. In total, we slice out 2dN points, and all remaining points are correctly labeled. That is, all negatively labeled points are outside the final d-dimensional rectangle and all positively labeled points, except perhaps ~2dN, lie inside the rectangle. Note that this algorithm uses the full flexibility of RSC in that the points are ordered differently by each axis. Since we perform 2d steps, the RSC analysis shaves off a factor on the order of the square root of d from the number of mislabeled points.
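The per-axis procedure can be sketched as below. As before, `interval_point` is a placeholder for a private interval point routine that is not implemented here, the slices are exact rather than approximate, and the function and parameter names are hypothetical. The sketch only looks at the positively labeled points, as in the description above.

```python
from typing import Callable, List, Sequence, Tuple

Vector = Tuple[float, ...]


def learn_rectangle(
    positives: Sequence[Vector],
    n_slice: int,
    interval_point: Callable[[Sequence[float]], float],
) -> Tuple[List[Tuple[float, float]], List[Vector]]:
    """Return the bounds (a_i, b_i) per axis and the positives never sliced out."""
    remaining = list(positives)
    if n_slice < 1 or not remaining:
        raise ValueError("need a positive slice size and at least one point")
    d = len(remaining[0])
    if len(remaining) < 2 * d * n_slice:
        raise ValueError("need at least 2dN positively labeled points")
    bounds: List[Tuple[float, float]] = []
    for axis in range(d):
        # Reorder by the current axis, then slice from both ends.
        remaining.sort(key=lambda p: p[axis])
        low, remaining = remaining[:n_slice], remaining[n_slice:]
        a = interval_point([p[axis] for p in low])
        high, remaining = remaining[-n_slice:], remaining[:-n_slice]
        b = interval_point([p[axis] for p in high])
        bounds.append((a, b))
    # Every point left in `remaining` satisfies a_i <= x_i <= b_i on all axes.
    return bounds, remaining
```

Only the 2dN sliced-out positives can fall outside the learned rectangle, which matches the count of mislabeled points given above.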

Training ML models with adaptive selection of training examples

The training efficiency or performance of ML models can sometimes be improved by selecting training examples in a way that depends on the current state of the model, e.g., self-paced curriculum learning or active learning.

The most common method for private training of ML models is DP-SGD, where noise is added to the gradient update from each minibatch of training examples. Privacy analysis with DP-SGD typically assumes that training examples are randomly partitioned into minibatches. But if we impose a data-dependent selection order on training examples, and further modify the selection criteria k times during training, then analysis through DP composition yields privacy guarantees that deteriorate with the square root of k.

Fortunately, example selection with DP-SGD can be naturally expressed in the RSC paradigm: each selection criterion reorders the training examples and each minibatch is a slice (for which we compute a noisy gradient). With RSC analysis, there is no privacy deterioration with k, which brings DP-SGD training with example selection into the practical domain.
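As a rough illustration of this mapping, the sketch below trains a one-parameter linear model with clipped, noised gradients, where each minibatch is the slice of remaining examples ranked best by a selection criterion. Everything specific here is an assumption made for the example: the squared-loss model, the "smallest current loss first" criterion, and the parameter names. The noise multiplier would have to be calibrated by a privacy analysis that this sketch does not perform.

```python
import random
from typing import List, Sequence, Tuple

Example = Tuple[float, float]  # (feature x, target y)


def dp_sgd_with_selection(
    examples: Sequence[Example],
    steps: int,
    batch_size: int,
    clip: float,
    noise_multiplier: float,
    lr: float,
    rng: random.Random,
) -> float:
    """Fit y ~ w * x; each minibatch is an RSC slice that is used only once."""
    w = 0.0
    pool: List[Example] = list(examples)
    for _ in range(steps):
        if len(pool) < batch_size:
            break
        # Reorder: rank remaining examples by their loss under the current
        # model (an illustrative self-paced criterion).
        pool.sort(key=lambda ex: (w * ex[0] - ex[1]) ** 2)
        # Slice: take the top batch_size examples out of the pool.
        batch, pool = pool[:batch_size], pool[batch_size:]
        # Compute: clip each per-example gradient, sum, and add Gaussian noise.
        grad_sum = 0.0
        for x, y in batch:
            grad = 2 * (w * x - y) * x
            grad_sum += max(-clip, min(clip, grad))
        noisy_sum = grad_sum + rng.gauss(0.0, noise_multiplier * clip)
        w -= lr * noisy_sum / batch_size
    return w
```

Because each example leaves the pool once it has been sliced into a minibatch, changing the selection criterion between steps fits the reorder-then-slice pattern directly.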

Conclusion

The RSC paradigm was introduced to tackle an open problem of primarily theoretical significance, but it turns out to be a versatile tool with the potential to enhance data efficiency in production environments.

Acknowledgments

The work described here was done jointly with Xin Lyu, Jelani Nelson, and Tamas Sarlos.