Categories
Offsites

The medical test paradox (well “paradox”)

Three levels of understanding Bayes’ theorem

Ellipses have multiple definitions, how are these the same?

A challenging puzzle about subset sums

How the Mandelbrot set is defined

Simulating the electric field and a moving charge

Health-specific embedding tools for dermatology and pathology

There’s a worldwide shortage of expert interpretation of medical images across specialties including radiology, dermatology and pathology. Machine learning (ML) technology can help ease this burden by powering tools that enable doctors to interpret these images more accurately and efficiently. However, the development and implementation of such ML tools are often limited by the availability of high-quality data, ML expertise, and computational resources.

One way to catalyze the use of ML for medical imaging is via domain-specific models that utilize deep learning (DL) to capture the information in medical images as compressed numerical vectors (called embeddings). These embeddings represent a type of pre-learned understanding of the important features in an image. Identifying patterns in the embeddings reduces the amount of data, expertise, and compute needed to train performant models as compared to working with high-dimensional data, such as images, directly. Indeed, these embeddings can be used to perform a variety of downstream tasks within the specialized domain (see animated graphic below). This framework of leveraging pre-learned understanding to solve related tasks is similar to that of a seasoned guitar player quickly learning a new song by ear. Because the guitar player has already built up a foundation of skill and understanding, they can quickly pick up the patterns and groove of a new song.

Path Foundation is used to convert a small dataset of (image, label) pairs into (embedding, label) pairs. These pairs can then be used to train a task-specific classifier using a linear probe (i.e., a lightweight linear classifier), as represented in this graphic, or other types of models that take the embeddings as input.

Once the linear probe is trained, it can be used to make predictions on embeddings from new images. These predictions can be compared to ground truth information in order to evaluate the linear probe’s performance.
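As a minimal sketch of the linear-probe workflow described above, the code below fits a logistic-regression probe with plain gradient descent on (embedding, label) pairs and then checks its predictions on held-out embeddings against their labels. The toy vectors come from a hypothetical `make_toy_embeddings` helper standing in for the output of an embedding tool; the learning rate and epoch count are illustrative assumptions, and in practice a library classifier would usually be used instead.

```python
import math
import random

def make_toy_embeddings(n, dim=4, seed=0):
    """Hypothetical stand-in for embeddings returned by an embedding tool.

    Label-1 vectors cluster around +1 and label-0 vectors around -1 in every dimension.
    """
    rng = random.Random(seed)
    data = []
    for i in range(n):
        label = i % 2
        center = 1.0 if label else -1.0
        data.append(([rng.gauss(center, 0.3) for _ in range(dim)], label))
    return data

def predict_proba(w, b, x):
    """Sigmoid of the linear score w.x + b."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1.0 / (1.0 + math.exp(-z))

def train_linear_probe(embeddings, labels, lr=0.5, epochs=200):
    """Fit logistic regression on frozen embeddings with full-batch gradient descent."""
    dim, n = len(embeddings[0]), len(embeddings)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(embeddings, labels):
            err = predict_proba(w, b, x) - y  # d(log loss)/dz for one example
            for j in range(dim):
                grad_w[j] += err * x[j]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b

# Train on one set of (embedding, label) pairs, then evaluate on held-out pairs.
train = make_toy_embeddings(100, seed=0)
test = make_toy_embeddings(40, seed=1)
w, b = train_linear_probe([x for x, _ in train], [y for _, y in train])
accuracy = sum((predict_proba(w, b, x) >= 0.5) == bool(y) for x, y in test) / len(test)
print(f"held-out accuracy: {accuracy:.2f}")
```

Only the small probe is trained here; the embedding model stays frozen, which is what keeps the data and compute requirements low.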

In order to make this type of embedding model available and drive further development of ML tools in medical imaging, we are excited to release two domain-specific tools for research use: Derm Foundation and Path Foundation. This builds on the strong response we’ve already received from researchers using the CXR Foundation embedding tool for chest radiographs and represents a portion of our expanding research offerings across multiple specialized medical imaging modalities. These embedding tools take an image as input and produce a numerical vector (the embedding) that is specialized to the domains of dermatology and digital pathology images, respectively. By running a dataset of chest X-ray, dermatology, or pathology images through the respective embedding tool, researchers can obtain embeddings for their own images, and use these embeddings to quickly develop new models for their applications.

Path Foundation

In “Domain-specific optimization and diverse evaluation of self-supervised models for histopathology”, we showed that self-supervised learning (SSL) models for pathology images outperform traditional pre-training approaches and enable efficient training of classifiers for downstream tasks. This effort focused on hematoxylin and eosin (H&E) stained slides, the principal tissue stain in diagnostic pathology that enables pathologists to visualize cellular features under a microscope. The performance of linear classifiers trained using the output of the SSL models matched that of prior DL models trained on orders of magnitude more labeled data.

Due to substantial differences between digital pathology images and “natural image” photos, this work involved several pathology-specific optimizations during model training. One key element is that whole-slide images (WSIs) in pathology can be 100,000 pixels across (thousands of times larger than typical smartphone photos) and are analyzed by experts at multiple magnifications (zoom levels). As such, the WSIs are typically broken down into smaller tiles or patches for computer vision and DL applications. The resulting images are information dense with cells or tissue structures distributed throughout the frame instead of having distinct semantic objects or foreground vs. background variations, thus creating unique challenges for robust SSL and feature extraction. Additionally, physical (e.g., cutting) and chemical (e.g., fixing and staining) processes used to prepare the samples can influence image appearance dramatically.
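To make the tiling step concrete, here is a minimal sketch that enumerates non-overlapping square tiles over a slide of a given pixel size, plus a helper for the slide size at a coarser magnification. The 224-pixel tile size, the 80,000-pixel slide height, and the 4x downsampling factor are illustrative assumptions, not values stated in this post, and the sketch ignores real-world details such as background filtering and slide file formats.

```python
def tile_coordinates(width, height, tile_size, stride=None):
    """Top-left (x, y) corners of tile_size x tile_size patches that fit inside the slide.

    With stride == tile_size the tiles do not overlap; partial tiles at the right
    and bottom edges are dropped for simplicity.
    """
    stride = stride or tile_size
    return [
        (x, y)
        for y in range(0, height - tile_size + 1, stride)
        for x in range(0, width - tile_size + 1, stride)
    ]

def size_at_magnification(width, height, downsample):
    """Slide dimensions after downsampling, e.g. downsample=4 for a 4x coarser zoom level."""
    return width // downsample, height // downsample

# A hypothetical 100,000 x 80,000 pixel slide, tiled at full resolution and at 4x downsampling.
full = tile_coordinates(100_000, 80_000, 224)
coarse = tile_coordinates(*size_at_magnification(100_000, 80_000, 4), 224)
print(len(full), len(coarse))
```

Even under these assumptions a single slide yields well over a hundred thousand full-resolution patches, which is why the size of the embedding model matters when it must run over every patch.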

Taking these important aspects into consideration, pathology-specific SSL optimizations included helping the model learn stain-agnostic features, generalizing the model to patches from multiple magnifications, augmenting the data to mimic scanning and image post-processing, and custom data balancing to improve input heterogeneity for SSL training. These approaches were extensively evaluated using a broad set of benchmark tasks involving 17 different tissue types over 12 different tasks.

Utilizing the vision transformer (ViT-S/16) architecture, Path Foundation was selected as the best performing model from the optimization and evaluation process described above (and illustrated in the figure below). This model thus provides an important balance between performance and model size to enable valuable and scalable use in generating embeddings over the many individual image patches of large pathology WSIs.

SSL training with pathology-specific optimizations for Path Foundation.

The value of domain-specific image representations can also be seen in the figure below, which shows the linear probing performance improvement of Path Foundation (as measured by AUROC) compared to traditional pre-training on natural images (ImageNet-21k). This includes evaluation for tasks such as metastatic breast cancer detection in lymph nodes, prostate cancer grading, and breast cancer grading, among others.
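Since the comparison is reported as AUROC, a short sketch of how that metric can be computed may help: it equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counted as one half (the Mann–Whitney formulation). This quadratic-time version is for illustration only, and the scores and labels below are made up.

```python
def auroc(scores, labels):
    """Area under the ROC curve via pairwise comparison of positives and negatives."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # ties count as half
    return wins / (len(pos) * len(neg))

# Made-up probe scores for four patches: 3 of the 4 positive/negative pairs are ordered correctly.
print(auroc([0.9, 0.8, 0.4, 0.3], [1, 0, 1, 0]))
```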

Path Foundation embeddings significantly outperform traditional ImageNet embeddings as evaluated by linear probing across multiple evaluation tasks in histopathology.

Derm Foundation

Derm Foundation is an embedding tool derived from our research in applying DL to interpret images of dermatology conditions and includes our recent work that adds improvements to generalize better to new datasets. Due to its dermatology-specific pre-training, it has a latent understanding of features present in images of skin conditions and can be used to quickly develop models to classify skin conditions. The model underlying the API is a BiT ResNet-101×3 trained in two stages. The first pre-training stage uses contrastive learning, similar to ConVIRT, to train on a large number of image-text pairs from the internet. In the second stage, the image component of this pre-trained model is then fine-tuned for condition classification using clinical datasets, such as those from teledermatology services.
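The first-stage objective can be illustrated with a symmetric contrastive (InfoNCE-style) loss over a batch of paired image and text embeddings, in the spirit of ConVIRT: each image should score highest against its own text and vice versa, with the other items in the batch acting as negatives. This is a minimal pure-Python sketch; the temperature of 0.1 and the equal weighting of the two directions are illustrative assumptions (ConVIRT allows the two terms to be weighted differently), and it is not the actual Derm Foundation training code.

```python
import math

def _normalize(v):
    """Scale a vector to unit L2 norm (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]

def _log_softmax_at(logits, idx):
    """log(softmax(logits)[idx]), computed stably by subtracting the max."""
    m = max(logits)
    return logits[idx] - (m + math.log(sum(math.exp(x - m) for x in logits)))

def contrastive_loss(image_emb, text_emb, temperature=0.1):
    """Symmetric InfoNCE loss: pair i is the positive for row i, all other pairs are negatives."""
    imgs = [_normalize(v) for v in image_emb]
    txts = [_normalize(v) for v in text_emb]
    n = len(imgs)
    # sim[i][j] = cosine similarity of image i and text j, divided by the temperature
    sim = [[sum(a * b for a, b in zip(imgs[i], txts[j])) / temperature
            for j in range(n)] for i in range(n)]
    image_to_text = -sum(_log_softmax_at(sim[i], i) for i in range(n)) / n
    columns = [[sim[i][j] for i in range(n)] for j in range(n)]
    text_to_image = -sum(_log_softmax_at(columns[j], j) for j in range(n)) / n
    return (image_to_text + text_to_image) / 2

images = [[1.0, 0.0], [0.0, 1.0]]
matched = contrastive_loss(images, [[1.0, 0.0], [0.0, 1.0]])   # each text matches its image
swapped = contrastive_loss(images, [[0.0, 1.0], [1.0, 0.0]])   # texts paired with the wrong image
print(matched, swapped)
```

Minimizing this loss pulls matching image and text embeddings together and pushes mismatched ones apart, so the matched batch scores far lower than the swapped one.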

Unlike histopathology images, dermatology images more closely resemble the real-world images used to train many of today’s computer vision models. However, for specialized dermatology tasks, creating a high-quality model may still require a large dataset. With Derm Foundation, researchers can use their own smaller dataset to retrieve domain-specific embeddings, and use those to build smaller models (e.g., linear classifiers or other small non-linear models) that enable them to validate their research or product ideas. To evaluate this approach, we trained models on a downstream task using teledermatology data. Model training involved varying dataset sizes (12.5%, 25%, 50%, 100%) to compare embedding-based linear classifiers against fine-tuning.

The modeling variants considered were:

  • A linear classifier on frozen embeddings from BiT-M (a standard pre-trained image model)
  • Fine-tuned version of BiT-M with an extra dense layer for the downstream task
  • A linear classifier on frozen embeddings from the Derm Foundation API
  • Fine-tuned version of the model underlying the Derm Foundation API with an extra layer for the downstream task
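One way to set up the dataset-size comparison is to shuffle the training set once and take nested prefixes at each fraction, so every larger subset contains the smaller ones. The post does not say how its subsets were drawn; the nesting, the seed, and the `nested_subsets` helper are assumptions made for illustration.

```python
import random

def nested_subsets(examples, fractions=(0.125, 0.25, 0.5, 1.0), seed=0):
    """Map each fraction to a subset; smaller subsets are prefixes of larger ones."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # one fixed shuffle shared by all sizes
    return {f: shuffled[: round(f * len(shuffled))] for f in fractions}

subsets = nested_subsets(range(800))
print({f: len(s) for f, s in subsets.items()})
```

Each of the four modeling variants would then be trained on every subset and evaluated on the same held-out set.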

We found that models built on top of the Derm Foundation embeddings for dermatology-related tasks achieved significantly higher quality than those built on BiT-M embeddings or fine-tuned from BiT-M. This advantage was most pronounced for smaller training dataset sizes.

These results demonstrate that the Derm Foundation tool can serve as a useful starting point to accelerate skin-related modeling tasks. We aim to enable other researchers to build on the underlying features and representations of dermatology that the model has learned.

However, there are limitations with this analysis. We’re still exploring how well these embeddings generalize across task types, patient populations, and image settings. Downstream models built using Derm Foundation still require careful evaluation to understand their expected performance in the intended setting.

Access Path and Derm Foundation

We envision that the Derm Foundation and Path Foundation embedding tools will enable a range of use cases, including efficient development of models for diagnostic tasks, quality assurance and pre-analytical workflow improvements, image indexing and curation, and biomarker discovery and validation. We are releasing both tools to the research community so they can explore the utility of the embeddings for their own dermatology and pathology data.

To get access, please sign up to each tool’s terms of service using the following Google Forms.

After gaining access to each tool, you can use the API to retrieve embeddings from dermatology images or digital pathology images stored in Google Cloud. Approved users who are just curious to see the model and embeddings in action can use the provided example Colab notebooks to train models using public data for classifying six common skin conditions or identifying tumors in histopathology patches. We look forward to seeing the range of use-cases these tools can unlock.

Acknowledgements

We would like to thank the many collaborators who helped make this work possible including Yun Liu, Can Kirmizi, Fereshteh Mahvar, Bram Sterling, Arman Tajback, Kenneth Philbrik, Arnav Agharwal, Aurora Cheung, Andrew Sellergren, Boris Babenko, Basil Mustafa, Jan Freyberg, Terry Spitz, Yuan Liu, Pinal Bavishi, Ayush Jain, Amit Talreja, Rajeev Rikhye, Abbi Ward, Jeremy Lai, Faruk Ahmed, Supriya Vijay, Tiam Jaroensri, Jessica Loo, Saurabh Vyawahare, Saloni Agarwal, Ellery Wulczyn, Jonathan Krause, Fayaz Jamil, Tom Small, Annisah Um’rani, Lauren Winer, Sami Lachgar, Yossi Matias, Greg Corrado, and Dale Webster.

Chain-of-table: Evolving tables in the reasoning chain for table understanding

People use tables every day to organize and interpret complex information in a structured, easily accessible format. Due to the ubiquity of such tables, reasoning over tabular data has long been a central topic in natural language processing (NLP). Researchers in this field have aimed to leverage language models to help users answer questions, verify statements, and analyze data based on tables. However, language models are trained over large amounts of plain text, so the inherently structured nature of tabular data can be difficult for language models to fully comprehend and utilize.

Recently, large language models (LLMs) have achieved outstanding performance across diverse natural language understanding (NLU) tasks by generating reliable reasoning chains, as shown in works like Chain-of-Thought and Least-to-Most. However, the most suitable way for LLMs to reason over tabular data remains an open question.

In “Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding”, we propose a framework to tackle table understanding tasks, where we guide LLMs to outline their reasoning step by step, updating a given table iteratively to reflect each part of a thought process, akin to how people solve table-based problems. This enables the LLM to transform the table into simpler and more manageable segments so that it can understand and analyze each part of the table in depth. This approach has yielded significant improvements and achieved new state-of-the-art results on the WikiTQ, TabFact, and FeTaQA benchmarks. The figure below shows the high-level overview of the proposed Chain-of-Table and other methods.

Given a complex table where a cyclist’s nationality and name are in the same cell, (a) generic, multi-step reasoning is unable to provide the correct answer; (b) program-aided reasoning generates and executes programs (e.g., SQL queries) to deliver the answer, but falls short in accurately addressing the question. In contrast, (c) Chain-of-Table iteratively samples a chain of operations that effectively transform the complex table into a version specifically tailored to the question.

Chain-of-Table

In Chain-of-Table, we guide LLMs using in-context learning to iteratively generate operations and to update the table to represent its reasoning chain over tabular data. This enables LLMs to dynamically plan the next operation based on the results of previous ones. This continuous evolution of the table forms a chain, which provides a more structured and clear representation of the reasoning process for a given problem and enables more accurate and reliable predictions from the LLM.

For example, when asked, “Which actor has the most NAACP image awards?” the Chain-of-Table framework prompts an LLM to generate tabular operations mirroring tabular reasoning processes. It first identifies the relevant columns. Then, it aggregates rows based on shared content. Finally, it reorders the aggregated results to yield a final table that clearly answers the posed question.
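The three steps in this example can be mimicked with ordinary table operations. In the minimal sketch below, a table is a (header, rows) pair, the rows are made-up placeholders rather than real award data, and the operation names follow the f_ naming used in this post; the exact semantics (for instance, f_group_by emitting a Count column) are assumptions of the sketch, not the paper's implementation.

```python
from collections import Counter

def f_select_column(table, columns):
    """Keep only the named columns, in the given order."""
    header, rows = table
    idx = [header.index(c) for c in columns]
    return list(columns), [[row[i] for i in idx] for row in rows]

def f_group_by(table, column):
    """Collapse rows that share a value in `column` into (value, Count) rows."""
    header, rows = table
    i = header.index(column)
    counts = Counter(row[i] for row in rows)
    return [column, "Count"], [[value, n] for value, n in counts.items()]

def f_sort_by(table, column, descending=True):
    """Reorder rows by `column` (Python's sort is stable, so ties keep their order)."""
    header, rows = table
    i = header.index(column)
    return header, sorted(rows, key=lambda row: row[i], reverse=descending)

# Hypothetical award rows; "Actor A" and the others are placeholders.
awards = (["Year", "Actor", "Category"], [
    [2001, "Actor A", "Lead"], [2003, "Actor B", "Lead"], [2005, "Actor A", "Supporting"],
    [2007, "Actor C", "Lead"], [2009, "Actor A", "Lead"],
])
table = f_select_column(awards, ["Actor"])   # 1. identify the relevant column
table = f_group_by(table, "Actor")           # 2. aggregate rows with shared content
table = f_sort_by(table, "Count")            # 3. reorder so the answer is on top
print(table)
```

Each intermediate table is small and directly inspectable, which is the interpretability benefit described next.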

These operations transform the table to align with the question presented. To balance performance with computational expense on large tables, we construct the operation chain according to a subset of tabular rows. Meanwhile, the step-by-step operations reveal the underlying reasoning process through the display of intermediate results from the tabular operations, fostering enhanced interpretability and understanding.

Illustration of the tabular reasoning process in Chain-of-Table. This iterative process involves dynamically planning an operation chain and accurately storing intermediate results in the transformed tables. These intermediate tables serve as a tabular thought process that can guide the LLM to the correct answer more reliably.

Chain-of-Table consists of three main stages. In the first stage, it instructs the LLM to dynamically plan the next operation by in-context learning. Specifically, the prompt involves three components as shown in the following figure:

  1. The question Q: “Which country had the most cyclists finish in the top 3?”
  2. The operation history chain: f_add_col(Country) and f_select_row(1, 2, 3).
  3. The latest intermediate table T: the transformed intermediate table.

By providing the triplet (T, Q, chain) in the prompt, the LLM can observe the previous tabular reasoning process and select the next operation from the operation pool to complete the reasoning chain step by step.
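A minimal sketch of how the (T, Q, chain) triplet might be assembled into a planning prompt follows. The table serialization, the prompt wording, the contents of `OPERATION_POOL`, and the `<END>` marker for finishing the chain are all hypothetical; the paper's actual prompt format may differ. The cyclist rows are placeholders.

```python
# Hypothetical operation pool; "<END>" lets the planner signal that the chain is complete.
OPERATION_POOL = ["f_add_col", "f_select_row", "f_select_column", "f_group_by", "f_sort_by", "<END>"]

def serialize_table(header, rows):
    """Flatten a table into text, one line per row, with cells separated by ' | '."""
    lines = ["col : " + " | ".join(header)]
    lines += [f"row {i} : " + " | ".join(str(v) for v in row) for i, row in enumerate(rows, 1)]
    return "\n".join(lines)

def build_plan_prompt(table, question, chain):
    """Combine the latest table T, question Q, and operation history into one prompt."""
    header, rows = table
    history = " -> ".join(chain) if chain else "(none)"
    return "\n".join([
        serialize_table(header, rows),
        f"Question: {question}",
        f"Operation history: {history}",
        "Next operation (choose one of " + ", ".join(OPERATION_POOL) + "):",
    ])

table = (["Rank", "Cyclist", "Country"],
         [[1, "Rider A", "ITA"], [2, "Rider B", "ESP"], [3, "Rider C", "ITA"]])
prompt = build_plan_prompt(
    table,
    "Which country had the most cyclists finish in the top 3?",
    ["f_add_col(Country)", "f_select_row(1, 2, 3)"],
)
print(prompt)
```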

Illustration of how Chain-of-Table selects the next operation from the operation pool and generates the arguments for the operation. (a) Chain-of-Table samples the next operation from the operation pool. (b) It takes the selected operation as input and generates its arguments.

After the next operation f is determined, in the second stage, we need to generate the arguments. As above, Chain-of-Table considers three components in the prompt as shown in the figure: (1) the question, (2) the selected operation and its required arguments, and (3) the latest intermediate table.

For instance, when the operation f_group_by is selected, it requires a header name as its argument.

The LLM selects a suitable header within the table. Equipped with the selected operation and the generated arguments, Chain-of-Table executes the operation and constructs a new intermediate table for the following reasoning.

Chain-of-Table iterates the previous two stages to plan the next operation and generate the required arguments. During this process, we create an operation chain acting as a proxy for the tabular reasoning steps. These operations generate intermediate tables presenting the results of each step to the LLM. Consequently, the output table contains comprehensive information about the intermediate phases of tabular reasoning. In our final stage, we employ this output table in formulating the final query and prompt the LLM along with the question for the final answer.
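Putting the stages together, the control flow can be sketched as a loop that plans an operation, generates its arguments, executes it to get a new intermediate table, and stops when the planner signals the end, before asking for the final answer. Here the LLM is replaced by a hypothetical `ScriptedLLM` that replays a fixed plan so the sketch runs offline; a real system would prompt a model at each call, and the operations, the `<END>` marker, the step limit, and the placeholder rows are illustrative assumptions.

```python
from collections import Counter

END = "<END>"

def f_select_row(table, *row_numbers):
    """Keep the listed rows (1-indexed)."""
    header, rows = table
    return header, [rows[i - 1] for i in row_numbers]

def f_group_by(table, column):
    """Count rows per value of `column`, most frequent first."""
    header, rows = table
    i = header.index(column)
    return [column, "Count"], [[v, n] for v, n in Counter(r[i] for r in rows).most_common()]

OPERATIONS = {"f_select_row": f_select_row, "f_group_by": f_group_by}

class ScriptedLLM:
    """Stand-in for an LLM: replays a fixed list of (operation, args) steps."""

    def __init__(self, script):
        self.script = script
        self.step = 0

    def plan(self, table, question, chain):              # stage 1: choose the next operation
        self.step = len(chain)
        return self.script[self.step][0] if self.step < len(self.script) else END

    def generate_args(self, table, question, op_name):   # stage 2: fill in its arguments
        return self.script[self.step][1]

    def answer(self, table, question):                   # stage 3: answer from the final table
        return table[1][0][0]  # a real LLM would read the table; the stub takes the top cell

def chain_of_table(table, question, llm, max_steps=5):
    chain = []
    for _ in range(max_steps):
        op_name = llm.plan(table, question, chain)
        if op_name == END:
            break
        args = llm.generate_args(table, question, op_name)
        table = OPERATIONS[op_name](table, *args)        # new intermediate table
        chain.append(f"{op_name}({', '.join(map(str, args))})")
    return llm.answer(table, question), chain

results = (["Rank", "Cyclist", "Country"],
           [[1, "Rider A", "ITA"], [2, "Rider B", "ESP"], [3, "Rider C", "ITA"], [4, "Rider D", "FRA"]])
llm = ScriptedLLM([("f_select_row", (1, 2, 3)), ("f_group_by", ("Country",))])
answer, chain = chain_of_table(results, "Which country had the most cyclists finish in the top 3?", llm)
print(answer, chain)
```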

Experimental setup

We use PaLM 2-S and GPT 3.5 as the backbone LLMs and conduct the experiments on three public table understanding benchmarks: WikiTQ, TabFact, and FeTaQA. WikiTQ and FeTaQA are datasets for table-based question answering. TabFact is a table-based fact verification benchmark. In this blog post, we will focus on the results on WikiTQ and TabFact. We compare Chain-of-Table with the generic reasoning methods (e.g., End-to-End QA, Few-Shot QA, and Chain-of-Thought) and the program-aided methods (e.g., Text-to-SQL, Binder, and Dater).

More accurate answers

Compared to the generic reasoning methods and program-aided reasoning methods, Chain-of-Table achieves better performance across PaLM 2 and GPT 3.5. This is attributed to the dynamically sampled operations and the informative intermediate tables.

Understanding results on WikiTQ and TabFact with PaLM 2 and GPT 3.5 compared with various models.

Better robustness on harder questions

In Chain-of-Table, longer operation chains indicate greater difficulty and complexity of the questions and their corresponding tables. We categorize the test samples according to their operation lengths in Chain-of-Table. We compare Chain-of-Table with Chain-of-Thought and Dater, as representative generic and program-aided reasoning methods. We illustrate this using results from PaLM 2 on WikiTQ.

Performance of Chain-of-Thought, Dater, and the proposed Chain-of-Table on WikiTQ for questions that require an operation chain of varying lengths. Our proposed atomic operations significantly improve performance over generic and program-aided reasoning counterparts.

Notably, Chain-of-Table consistently surpasses both baseline methods across all operation chain lengths, by margins of up to 11.6% over Chain-of-Thought and up to 7.9% over Dater. Moreover, the performance of Chain-of-Table declines more gracefully than that of the baselines as the number of operations grows, with only a minimal decrease when the number of operations increases from four to five.

Better robustness with larger tables

We categorize the tables from WikiTQ into three groups based on token count: small (<2000 tokens), medium (2000 to 4000 tokens) and large (>4000 tokens). We then compare Chain-of-Table with Dater and Binder, the two latest and strongest baselines.
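The size grouping amounts to a simple bucketing rule followed by per-bucket accuracy. In this sketch the boundary handling (exactly 2000 or 4000 tokens counted as medium) is an assumption, the example results are made up, and real token counts depend on the tokenizer used to serialize each table.

```python
from collections import defaultdict

def size_bucket(num_tokens):
    """Assign a table to small (<2000 tokens), medium (2000 to 4000), or large (>4000)."""
    if num_tokens < 2000:
        return "small"
    if num_tokens <= 4000:
        return "medium"
    return "large"

def accuracy_by_bucket(results):
    """results: iterable of (num_tokens, is_correct) pairs, one per test example."""
    totals = defaultdict(lambda: [0, 0])  # bucket -> [num_correct, num_examples]
    for num_tokens, is_correct in results:
        counts = totals[size_bucket(num_tokens)]
        counts[0] += int(is_correct)
        counts[1] += 1
    return {bucket: correct / total for bucket, (correct, total) in totals.items()}

print(accuracy_by_bucket([(1500, True), (1800, False), (3000, True), (5200, True)]))
```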

Performance of Binder, Dater, and the proposed Chain-of-Table on small (<2000 tokens), medium (2000 to 4000 tokens), and large (>4000 tokens) tables from WikiTQ. We observe that the performance decreases with larger input tables while Chain-of-Table diminishes gracefully, achieving significant improvements over competing methods. (As above, underlined text denotes the second-best performance; bold denotes the best performance.)

As anticipated, the performance decreases with larger input tables, as models are required to reason through longer contexts. Nevertheless, the performance of the proposed Chain-of-Table diminishes gracefully, achieving a significant 10+% improvement over the second best competing method when dealing with large tables. This demonstrates the efficacy of the reasoning chain in handling long tabular inputs.

Conclusion

Our proposed Chain-of-Table method enhances the reasoning capability of LLMs by leveraging the tabular structure to express intermediate steps for table-based reasoning. It instructs LLMs to dynamically plan an operation chain according to the input table and its associated question. This evolving table design sheds new light on the understanding of prompting LLMs for table understanding.

Acknowledgements

This research was conducted by Zilong Wang, Hao Zhang, Chun-Liang Li, Julian Martin Eisenschlos, Vincent Perot, Zifeng Wang, Lesly Miculicich, Yasuhisa Fujii, Jingbo Shang, Chen-Yu Lee, Tomas Pfister. Thanks to Chih-Kuan Yeh and Sergey Ioffe for their valuable feedback.

Talk like a graph: Encoding graphs for large language models

Imagine all the things around you — your friends, tools in your kitchen, or even the parts of your bike. They are all connected in different ways. In computer science, the term graph is used to describe connections between objects. Graphs consist of nodes (the objects themselves) and edges (connections between two nodes, indicating a relationship between them). Graphs are everywhere now. The internet itself is a giant graph of websites linked together. Even the knowledge search engines use is organized in a graph-like way.

Furthermore, consider the remarkable advancements in artificial intelligence — such as chatbots that can write stories in seconds, and even software that can interpret medical reports. This exciting progress is largely thanks to large language models (LLMs). New LLM technology is constantly being developed for different uses.

Since graphs are everywhere and LLM technology is on the rise, in “Talk like a Graph: Encoding Graphs for Large Language Models”, presented at ICLR 2024, we present a way to teach powerful LLMs how to better reason with graph information. Graphs are a useful way to organize information, but LLMs are mostly trained on regular text. The objective is to test different techniques to see what works best and gain practical insights. Translating graphs into text that LLMs can understand is a remarkably complex task. The difficulty stems from the inherent complexity of graph structures with multiple nodes and the intricate web of edges that connect them. Our work studies how to take a graph and translate it into a format that an LLM can understand. We also design a benchmark called GraphQA to study different approaches on different graph reasoning problems and show how to phrase a graph-related problem in a way that enables the LLM to solve the graph problem. We show that LLM performance on graph reasoning tasks varies on three fundamental levels: 1) the graph encoding method, 2) the nature of the graph task itself, and 3) interestingly, the very structure of the graph considered. These findings give us clues on how to best represent graphs for LLMs. Picking the right method can make the LLM up to 60% better at graph tasks!

Pictured, the process of encoding a graph as text using two different approaches and feeding the text and a question about the graph to the LLM.

Graphs as text

To systematically find the best way to translate a graph to text, we first design a benchmark called GraphQA. Think of GraphQA as an exam designed to evaluate powerful LLMs on graph-specific problems. We want to see how well LLMs can understand and solve problems that involve graphs in different setups. To create a comprehensive and realistic exam for LLMs, we don’t just use one type of graph, we use a mix of graphs ensuring breadth in the number of connections. This is mainly because different graph types make solving such problems easier or harder. This way, GraphQA can help expose biases in how an LLM thinks about the graphs, and the whole exam gets closer to a realistic setup that LLMs might encounter in the real world.

Overview of our framework for reasoning with graphs using LLMs.

GraphQA focuses on simple tasks related to graphs, like checking if an edge exists, calculating the number of nodes or edges, finding nodes that are connected to a specific node, and checking for cycles in a graph. These tasks might seem basic, but they require understanding the relationships between nodes and edges. By covering different types of challenges, from identifying patterns to creating new connections, GraphQA helps models learn how to analyze graphs effectively. These basic tasks are crucial for more complex reasoning on graphs, like finding the shortest path between nodes, detecting communities, or identifying influential nodes. Additionally, GraphQA generates random graphs using various algorithms like Erdős–Rényi, scale-free networks, the Barabási–Albert model, and the stochastic block model, as well as simpler graph structures like paths, complete graphs, and star graphs, providing a diverse set of data for training.
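Ground-truth answers for several of these tasks are straightforward to compute directly, which is what makes them convenient exam questions. The sketch below assumes a simple undirected graph given as a node list and an edge list (node and edge counts are just their lengths); the function names are illustrative, not GraphQA's own code. The cycle check uses union-find: an edge that joins two nodes already in the same component must close a cycle.

```python
def edge_exists(edges, u, v):
    """True if u and v are joined by an edge (edges are undirected)."""
    return (u, v) in edges or (v, u) in edges

def connected_nodes(edges, u):
    """Sorted list of nodes sharing an edge with u."""
    return sorted({b for a, b in edges if a == u} | {a for a, b in edges if b == u})

def has_cycle(nodes, edges):
    """Undirected cycle check with union-find."""
    parent = {n: n for n in nodes}

    def find(x):
        # Follow parent pointers to the component's root, halving the path as we go.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for u, v in edges:
        root_u, root_v = find(u), find(v)
        if root_u == root_v:  # u and v were already connected, so this edge closes a cycle
            return True
        parent[root_u] = root_v  # merge the two components
    return False

nodes = [0, 1, 2, 3]
edges = [(0, 1), (1, 2), (2, 0), (2, 3)]
print(len(nodes), len(edges), edge_exists(edges, 0, 2), connected_nodes(edges, 2), has_cycle(nodes, edges))
```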

When working with graphs, we also need to find ways to ask graph-related questions that LLMs can understand. Prompting heuristics are different strategies for doing this. Let’s break down the common ones:

  • Zero-shot: simply describe the task (“Is there a cycle in this graph?”) and tell the LLM to go for it. No examples provided.
  • Few-shot: This is like giving the LLM a mini practice test before the real deal. We provide a few example graph questions and their correct answers.
  • Chain-of-Thought: Here, we show the LLM how to break down a problem step-by-step with examples. The goal is to teach it to generate its own “thought process” when faced with new graphs.
  • Zero-CoT: Similar to CoT, but instead of worked examples, we give the LLM a simple prompt, like “Let’s think step-by-step,” to trigger its own problem-solving breakdown.
  • BAG (build a graph): This is specifically for graph tasks. We add the phrase “Let’s build a graph…” to the description, helping the LLM focus on the graph structure.
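These heuristics differ mainly in what surrounds the graph description and the question. The sketch below assembles each variant; the prompt layout, the method keys, and the BAG wording are placeholders rather than the paper's templates, while the zero-CoT trigger reuses the phrase quoted above.

```python
METHODS = {"zero-shot", "few-shot", "cot", "zero-cot", "bag"}
ZERO_COT_TRIGGER = "Let's think step-by-step."
BAG_PHRASE = "Let's build a graph."  # placeholder wording for the BAG cue

def build_prompt(method, graph_text, question, examples=()):
    """Assemble a prompt for one heuristic.

    examples: (graph_text, question, answer) triples used by few-shot and CoT; for CoT
    each answer should spell out the step-by-step reasoning before the final answer.
    """
    if method not in METHODS:
        raise ValueError(f"unknown method: {method}")
    if examples and method not in ("few-shot", "cot"):
        raise ValueError(f"{method} does not use worked examples")
    parts = [f"{g}\nQ: {q}\nA: {a}" for g, q, a in examples]
    description = f"{graph_text}\n{BAG_PHRASE}" if method == "bag" else graph_text
    answer_cue = f"A: {ZERO_COT_TRIGGER}" if method == "zero-cot" else "A:"
    parts.append(f"{description}\nQ: {question}\n{answer_cue}")
    return "\n\n".join(parts)

example = ("G has edges (0, 1) (1, 2).", "Is there a cycle in this graph?", "No")
few = build_prompt("few-shot", "G has edges (0, 1) (1, 2) (2, 0).", "Is there a cycle in this graph?", [example])
print(few)
```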

We explored different ways to translate graphs into text that LLMs can work with. Our key questions were:

  • Node encoding: How do we represent individual nodes? Options tested include simple integers, common names (people, characters), and letters.
  • Edge encoding: How do we describe the relationships between nodes? Methods involved parenthesis notation, phrases like “are friends”, and symbolic representations like arrows.

Various node and edge encodings were combined systematically. This led to functions like the ones in the following figure:

Examples of graph encoding functions used to encode graphs via text.
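Three such encoders can be sketched in a few lines: integer nodes with parenthesized edges, named nodes with an “are friends” phrase, and an incident-style encoding that lists each node's neighbors. The exact sentences, the `NAMES` pool, and this reading of the incident format are illustrative assumptions; the paper's encoders may word things differently.

```python
NAMES = ["Alice", "Bob", "Carol", "Dave", "Eve", "Frank"]  # hypothetical name pool

def encode_adjacency(nodes, edges):
    """Integer node ids with edges written as parenthesized pairs."""
    node_list = ", ".join(str(n) for n in nodes)
    edge_list = " ".join(f"({u}, {v})" for u, v in edges)
    return f"G has nodes {node_list}. Edges: {edge_list}."

def encode_friendship(nodes, edges):
    """Nodes as people's names, edges as 'are friends' sentences (needs len(nodes) <= len(NAMES))."""
    name = {n: NAMES[i] for i, n in enumerate(nodes)}
    lines = [f"G is a friendship graph among {', '.join(name[n] for n in nodes)}."]
    lines += [f"{name[u]} and {name[v]} are friends." for u, v in edges]
    return "\n".join(lines)

def encode_incident(nodes, edges):
    """One line per node listing its neighbors; nodes with no neighbors are skipped here."""
    lines = []
    for n in nodes:
        neighbors = sorted([v for u, v in edges if u == n] + [u for u, v in edges if v == n])
        if neighbors:
            lines.append(f"Node {n} is connected to nodes {', '.join(map(str, neighbors))}.")
    return "\n".join(lines)

nodes, edges = [0, 1, 2], [(0, 1), (1, 2)]
for encode in (encode_adjacency, encode_friendship, encode_incident):
    print(encode(nodes, edges), end="\n\n")
```

The incident style groups all of a node's connections on one line, which is one plausible reason it helps with node-centric questions such as degree or neighbor lookup.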

Analysis and results

We carried out three key experiments: one to test how LLMs handle graph tasks, and two to understand how the size of the LLM and different graph shapes affected performance. We ran all our experiments on GraphQA.

How LLMs handle graph tasks

In this experiment, we tested how well pre-trained LLMs tackle graph problems like identifying connections, cycles, and node degrees. Here is what we learned:

  • LLMs struggle: On most of these basic tasks, LLMs did not do much better than a random guess.
  • Encoding matters significantly: How we represent the graph as text has a great effect on LLM performance. The “incident” encoding excelled for most of the tasks in general.

Our results are summarized in the following chart.

Comparison of various graph encoder functions based on their accuracy on different graph tasks. The main conclusion from this figure is that the graph encoding functions matter significantly.

Bigger is (usually) better

In this experiment, we wanted to see if the size of the LLM (in terms of the number of parameters) affects how well they can handle graph problems. For that, we tested the same graph tasks on the XXS, XS, S, and L sizes of PaLM 2. Here is a summary of our findings:

  • In general, bigger models did better on graph reasoning tasks. It seems like the extra parameters gave them space to learn more complex patterns.
  • Oddly, size didn’t matter as much for the “edge existence” task (finding out if two nodes in a graph are connected).
  • Even the biggest LLM couldn’t consistently beat a simple baseline solution on the cycle check problem (finding out if a graph contains a cycle or not). This shows LLMs still have room to improve with certain graph tasks.

Effect of model capacity on graph reasoning task for PaLM 2-XXS, XS, S, and L.

Do different graph shapes confuse LLMs

We wondered if the “shape” of a graph (how nodes are connected) influences how well LLMs can solve problems on it. Think of the following figure as different examples of graph shapes.

Samples of graphs generated with different graph generators from GraphQA. ER, BA, SBM, and SFN refer to Erdős–Rényi, Barabási–Albert, Stochastic Block Model, and Scale-Free Network respectively.
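For reference, a few of these generators are short enough to sketch with the standard library. The code below builds path, star, and complete graphs and samples an Erdős–Rényi graph by keeping each possible edge independently with probability p; the Barabási–Albert, stochastic block model, and scale-free generators need more machinery and are omitted. These helpers are illustrative, not GraphQA's generator code.

```python
import itertools
import random

def path_graph(n):
    """Nodes 0..n-1 in a line; a path never contains a cycle."""
    return list(range(n)), [(i, i + 1) for i in range(n - 1)]

def star_graph(n):
    """Node 0 joined to every other node."""
    return list(range(n)), [(0, i) for i in range(1, n)]

def complete_graph(n):
    """Every pair of nodes joined by an edge."""
    return list(range(n)), list(itertools.combinations(range(n), 2))

def erdos_renyi(n, p, seed=None):
    """Keep each of the n*(n-1)/2 possible edges independently with probability p."""
    rng = random.Random(seed)
    return list(range(n)), [e for e in itertools.combinations(range(n), 2) if rng.random() < p]

nodes, edges = erdos_renyi(10, 0.3, seed=1)
print(len(edges), "edges sampled out of", len(complete_graph(10)[1]), "possible")
```

Dense generators such as the complete graph almost always contain cycles while path graphs never do, which is exactly the kind of structural skew the next experiment probes.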

We found that graph structure has a big impact on LLM performance. For example, in a task asking if a cycle exists, LLMs did great on tightly interconnected graphs (cycles are common there) but struggled on path graphs (where cycles never happen). Interestingly, providing some mixed examples helped the LLM adapt. For instance, for cycle check, we added some examples containing a cycle and some examples with no cycles as few-shot examples in our prompt. Similar patterns occurred with other tasks.
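Picking a balanced set of few-shot examples like this can be sketched as follows. The pool format of (prompt text, answer) pairs, the per-class count, and the final shuffle are assumptions for illustration rather than the exact procedure used in the paper.

```python
import random

def balanced_few_shot(pool, k_per_class, seed=0):
    """Pick k examples whose answer is True (e.g. has a cycle) and k whose answer is False.

    pool: list of (prompt_text, answer_bool) pairs. Returns the chosen examples shuffled,
    so the prompt does not present all positives before all negatives.
    """
    rng = random.Random(seed)
    positives = [ex for ex in pool if ex[1]]
    negatives = [ex for ex in pool if not ex[1]]
    if len(positives) < k_per_class or len(negatives) < k_per_class:
        raise ValueError("not enough examples of each class")
    chosen = rng.sample(positives, k_per_class) + rng.sample(negatives, k_per_class)
    rng.shuffle(chosen)
    return chosen

# A skewed pool, like one drawn from dense graphs where cycles dominate.
pool = [(f"cyclic graph {i}", True) for i in range(10)] + [(f"path graph {i}", False) for i in range(3)]
shots = balanced_few_shot(pool, 2)
print(shots)
```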

Comparing different graph generators on different graph tasks. The main observation here is that graph structure has a significant impact on the LLM’s performance. ER, BA, SBM, and SFN refer to Erdős–Rényi, Barabási–Albert, Stochastic Block Model, and Scale-Free Network respectively.

Conclusion

In short, we dug deep into how to best represent graphs as text so LLMs can understand them. We found three major factors that make a difference:

  • How to translate the graph to text: how we represent the graph as text significantly influences LLM performance. The incident encoding excelled for most of the tasks in general.
  • Task type: Certain types of graph questions tend to be harder for LLMs, even with a good translation from graph to text.
  • Graph structure: Surprisingly, the “shape” of the graph on which we do inference (dense with connections, sparse, etc.) influences how well an LLM does.

This study revealed key insights about how to prepare graphs for LLMs. The right encoding techniques can significantly boost an LLM’s accuracy on graph problems (ranging from around 5% to over 60% improvement). Our new benchmark, GraphQA, will help drive further research in this area.

Acknowledgements

We would like to express our gratitude to our co-author, Jonathan Halcrow, for his valuable contributions to this work. We express our sincere gratitude to Anton Tsitsulin, Dustin Zelle, Silvio Lattanzi, Vahab Mirrokni, and the entire graph mining team at Google Research, for their insightful comments, thorough proofreading, and constructive feedback which greatly enhanced the quality of our work. We would also like to extend special thanks to Tom Small for creating the animation used in this post.

Cappy: Outperforming and boosting large multi-task language models with a small scorer

Large language model (LLM) advancements have led to a new paradigm that unifies various natural language processing (NLP) tasks within an instruction-following framework. This paradigm is exemplified by recent multi-task LLMs, such as T0, FLAN, and OPT-IML. First, multi-task data is gathered with each task following a task-specific template, where each labeled example is converted into an instruction (e.g., Put the concepts together to form a sentence: ski, mountain, skier) paired with a corresponding response (e.g., Skier skis down the mountain). These instruction-response pairs are used to train the LLM, resulting in a conditional generation model that takes an instruction as input and generates a response. Moreover, multi-task LLMs have exhibited remarkable task-wise generalization capabilities as they can address unseen tasks by understanding and solving brand-new instructions.
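The template step can be sketched as string formatting over a labeled example. The template dictionary, its field names, and the `to_instruction_pair` helper are hypothetical; the example itself is the one quoted above.

```python
def to_instruction_pair(example, template):
    """Render one labeled example as an (instruction, response) pair via a task template."""
    return template["instruction"].format(**example), template["response"].format(**example)

# Hypothetical template for a concept-to-sentence task.
CONCEPTS_TO_SENTENCE = {
    "instruction": "Put the concepts together to form a sentence: {concepts}",
    "response": "{target}",
}

labeled = {"concepts": "ski, mountain, skier", "target": "Skier skis down the mountain"}
instruction, response = to_instruction_pair(labeled, CONCEPTS_TO_SENTENCE)
print(instruction)
print(response)
```

Applying many such templates across many tasks produces the instruction-response training mixture for a multi-task LLM.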

Illustration of the instruction-following pre-training of multi-task LLMs, e.g., FLAN. Pre-training on tasks under this paradigm improves performance on unseen tasks.
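To make the template step concrete, here is a minimal Python sketch of how a labeled example might be turned into an instruction-response pair, reusing the ski/mountain/skier example from the text. The class and function names are hypothetical; the actual T0/FLAN templates come from PromptSource and are not reproduced here.

```python
from dataclasses import dataclass


@dataclass
class ConceptExample:
    """A labeled example from a hypothetical concept-to-sentence dataset."""
    concepts: list[str]
    target: str


def concepts_to_instruction(example: ConceptExample) -> tuple[str, str]:
    """Hypothetical task-specific template: returns an (instruction, response) pair."""
    instruction = "Put the concepts together to form a sentence: " + ", ".join(example.concepts)
    return instruction, example.target


def review_to_instruction(review: str, recommends: bool) -> tuple[str, str]:
    """A second hypothetical template, for a sentiment-style task with a Yes/No label."""
    instruction = f"Based on this review, would the user recommend this product?: '{review}'"
    return instruction, "Yes" if recommends else "No"


pair = concepts_to_instruction(
    ConceptExample(["ski", "mountain", "skier"], "Skier skis down the mountain")
)
print(pair)
```

Each task gets its own template, but every template produces the same (instruction, response) shape, which is what lets a single model be trained on all tasks at once.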

Due to the complexity of understanding and solving various tasks solely from instructions, multi-task LLMs typically range from several billion to hundreds of billions of parameters (e.g., FLAN-11B, T0-11B and OPT-IML-175B). Operating such sizable models poses significant challenges: they demand considerable computational power and place substantial requirements on GPU and TPU memory, making their training and inference expensive and inefficient. Maintaining a separate copy of the LLM for each downstream task also requires extensive storage. Moreover, the most powerful multi-task LLMs (e.g., FLAN-PaLM-540B) are closed-source, which makes them impossible to adapt. At the same time, in practical applications, using a single multi-task LLM to handle all conceivable tasks in a zero-shot manner remains difficult, particularly for complex tasks, personalized tasks, and tasks that cannot be succinctly defined with instructions. On the other hand, downstream training data is usually too scarce to train a model well without incorporating rich prior knowledge. Hence, there has long been a need to adapt LLMs with downstream supervision while bypassing these storage, memory, and access issues.

Certain parameter-efficient tuning strategies, including prompt tuning and adapters, substantially diminish storage requirements, but they still perform back-propagation through LLM parameters during the tuning process, thereby keeping their memory demands high. Additionally, some in-context learning techniques circumvent parameter tuning by integrating a limited number of supervised examples into the instruction. However, these techniques are constrained by the model’s maximum input length, which permits only a few samples to guide task resolution.
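The input-length constraint on in-context learning can be illustrated with a small Python sketch that greedily packs demonstrations into a fixed token budget. The packing strategy, the whitespace-based token count, and the budget are assumptions for illustration only; real LLMs use subword tokenizers, so actual counts differ.

```python
def pack_demonstrations(instruction, demos, max_tokens):
    """Greedily add (input, output) demonstrations until the token budget is used up.

    Token counting is approximated by whitespace splitting (an assumption).
    Returns the assembled prompt and the number of demonstrations that fit.
    """
    def count_tokens(text):
        return len(text.split())

    used = count_tokens(instruction)
    blocks = []
    for demo_input, demo_output in demos:
        block = f"{demo_input}\n{demo_output}\n\n"
        cost = count_tokens(block)
        if used + cost > max_tokens:
            break  # no room left: all remaining examples are dropped
        blocks.append(block)
        used += cost
    return "".join(blocks) + instruction, len(blocks)


demos = [("Review: great product love it", "Yes")] * 10
prompt, n_used = pack_demonstrations("Would the user recommend it?", demos, max_tokens=20)
print(n_used)  # only a small fraction of the 10 available examples fit
```

However many labeled examples are available, only those that fit within `max_tokens` ever reach the model.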

In “Cappy: Outperforming and Boosting Large Multi-Task LMs with a Small Scorer”, presented at NeurIPS 2023, we propose a novel approach that enhances the performance and efficiency of multi-task LLMs. We introduce Cappy, a lightweight pre-trained scorer with merely 360 million parameters, obtained by continual pre-training on top of RoBERTa. Cappy takes an instruction and a candidate response as input and produces a score between 0 and 1 that estimates the correctness of the response with respect to the instruction. Cappy either functions independently on classification tasks or serves as an auxiliary component that boosts the performance of LLMs. Moreover, Cappy efficiently enables downstream supervision without fine-tuning the LLM, which avoids back-propagation through LLM parameters and reduces memory requirements. Finally, adaptation with Cappy doesn’t require access to LLM parameters, so it is compatible with closed-source multi-task LLMs, such as those only accessible via WebAPIs.

Cappy takes an instruction-response pair as input and outputs a score from 0 to 1 that estimates the correctness of the response with respect to the instruction.

Pre-training

We begin with the same dataset collection, which includes 39 diverse datasets from PromptSource that were used to train T0. This collection encompasses a wide range of task types, such as question answering, sentiment analysis, and summarization. Each dataset is associated with one or more templates that convert each instance from the original datasets into an instruction paired with its ground truth response.

Cappy’s regression modeling requires each pre-training data instance to include an instruction-response pair along with a correctness annotation for the response, so we produce a dataset with correctness annotations that range from 0 to 1. For every instance within a generation task, we use an existing multi-task LLM to generate multiple responses by sampling, conditioned on the given instruction. We then annotate each instruction-response pair with the similarity between the response and the ground-truth response of the instance. Specifically, we compute this similarity with Rouge-L, a commonly used metric for measuring overall multi-task performance that has demonstrated strong alignment with human evaluation, and use it as a form of weak supervision.

As a result, we obtain a weakly supervised regression dataset of 160 million instances paired with correctness score annotations. The final Cappy model is the result of continual pre-training on this regression dataset on top of the RoBERTa model. The pre-training of Cappy is conducted on Google’s TPU-v4, using RedCoast, a lightweight toolkit for automating distributed training.

Data augmentation with a multi-task LLM to construct a weakly supervised regression dataset for Cappy’s pre-training and fine-tuning.
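The annotation step described above can be sketched in Python: each sampled response is scored against the ground truth with Rouge-L, computed here as the F1 of longest-common-subsequence (LCS) precision and recall. This is a simplification under stated assumptions: tokens are lowercased whitespace splits (standard Rouge-L packages use their own tokenization, optional stemming, and a configurable recall weight), the helper names are hypothetical, and the sampled responses are passed in as a fixed list rather than drawn from an actual LLM.

```python
def lcs_length(a: list[str], b: list[str]) -> int:
    """Length of the longest common subsequence, via row-by-row dynamic programming."""
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, start=1):
            # Extend a diagonal match, or carry over the best of the left/upper cells.
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]


def rouge_l_f1(candidate: str, reference: str) -> float:
    """Rouge-L F1 over lowercased whitespace tokens (a simplification)."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    lcs = lcs_length(cand, ref) if cand and ref else 0
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


def annotate(instruction: str, ground_truth: str, sampled: list[str]) -> list[tuple[str, str, float]]:
    """Turn sampled responses into (instruction, response, score) regression examples."""
    return [(instruction, resp, rouge_l_f1(resp, ground_truth)) for resp in sampled]


rows = annotate(
    "Put the concepts together to form a sentence: ski, mountain, skier",
    "Skier skis down the mountain",
    ["The skier skis down the mountain", "A mountain", "Cats sleep"],  # stand-ins for LLM samples
)
for _, response, score in rows:
    print(f"{score:.3f}  {response}")
```

Responses closer to the ground truth receive higher scores, so a single instruction yields several training pairs spanning the 0 to 1 range rather than one positive example.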

Applying Cappy

Cappy solves practical tasks through candidate selection. Given an instruction and a set of candidate responses, Cappy scores each candidate by taking the instruction together with that individual response as input, and then selects the highest-scoring response as its prediction. In classification tasks, all candidate responses are inherently predefined. For example, for the instruction of a sentiment classification task (e.g., “Based on this review, would the user recommend this product?: ‘Stunning even for the non-gamer.’”), the candidate responses are “Yes” and “No”. In such scenarios, Cappy functions independently. In generation tasks, on the other hand, candidate responses are not predefined, so an existing multi-task LLM is needed to produce them. In this case, Cappy serves as an auxiliary component of the multi-task LLM, enhancing its decoding.
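The candidate-selection mechanism amounts to scoring each (instruction, candidate) pair independently and taking the argmax. The Python sketch below shows that logic with a pluggable scorer. `toy_scorer` is a hypothetical word-overlap stand-in for Cappy (it is not Cappy's model), the Yes/No scores come from a made-up stub, and the fixed list of "samples" stands in for responses drawn from a multi-task LLM.

```python
from typing import Callable, Sequence

# A scorer maps (instruction, response) to an estimated correctness in [0, 1].
Scorer = Callable[[str, str], float]


def select_response(instruction: str, candidates: Sequence[str], scorer: Scorer) -> str:
    """Score every candidate independently and return the highest-scoring one."""
    if not candidates:
        raise ValueError("need at least one candidate response")
    scores = [scorer(instruction, c) for c in candidates]
    best = max(range(len(candidates)), key=lambda i: scores[i])
    return candidates[best]


def toy_scorer(instruction: str, response: str) -> float:
    """Hypothetical stand-in for Cappy: share of response words that appear in the instruction."""
    inst_words = set(instruction.lower().replace(",", " ").replace(":", " ").split())
    words = response.lower().split()
    return sum(w in inst_words for w in words) / len(words) if words else 0.0


# Classification: candidates are fixed by the task (scores come from a made-up stub).
review = "Based on this review, would the user recommend this product?: 'Stunning even for the non-gamer.'"
stub = {"Yes": 0.8, "No": 0.2}
print(select_response(review, ["Yes", "No"], lambda _inst, r: stub[r]))

# Generation: candidates would be sampled from a multi-task LLM; a fixed list stands in here.
instruction = "Put the concepts together to form a sentence: ski, mountain, skier"
samples = ["Skier skis down the mountain", "The cat sat on a mat"]
print(select_response(instruction, samples, toy_scorer))
```

Because the scorer only sees text, the LLM that produces the candidates can be treated as a black box, which is why this setup does not need access to the LLM's parameters.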

Adapting multi-task LLMs with Cappy

When downstream training data is available, Cappy enables effective and efficient adaptation of multi-task LLMs on downstream tasks. Specifically, we fine-tune Cappy to integrate downstream task information into LLM predictions. This process involves creating a separate regression dataset from the downstream training data, using the same annotation process that was used to construct the pre-training data. The fine-tuned Cappy then works together with a multi-task LLM, boosting the LLM’s performance on the downstream task.

In contrast to other LLM tuning strategies, adapting LLMs with Cappy significantly reduces device memory demands because it avoids back-propagation through LLM parameters for downstream tasks. Moreover, Cappy adaptation does not require access to LLM parameters, making it compatible with closed-source multi-task LLMs, such as those only accessible via WebAPIs. Compared with in-context learning approaches, which circumvent model tuning by attaching training examples to the instruction prefix, Cappy is not restricted by the LLM’s maximum input length, so the number of downstream training examples it can incorporate is not bounded by the context window. Cappy can also be combined with other adaptation methods, such as fine-tuning and in-context learning, to further boost their overall performance.

Downstream adaptation comparison between Cappy and approaches that rely on an LLM’s parameters, such as fine-tuning and prompt tuning. Applying Cappy enhances multi-task LLMs.

Results

We assess Cappy’s performance across eleven held-out language understanding classification tasks from PromptSource. We demonstrate that Cappy, with 360M parameters, outperforms OPT-175B and OPT-IML-30B, and matches the accuracy of the best existing multi-task LLMs (T0-11B and OPT-IML-175B). These findings highlight Cappy’s capabilities and parameter efficiency, which we attribute to its scoring-based pre-training strategy: it integrates contrastive information by differentiating between high-quality and low-quality responses. In contrast, previous multi-task LLMs depend exclusively on teacher-forcing training that utilizes only the ground-truth responses.

The overall accuracy averaged over eleven test tasks from PromptSource. “RM” refers to a pre-trained RLHF reward model. Cappy matches the best existing multi-task LLMs.

We also examine the adaptation of multi-task LLMs with Cappy on complex tasks from BIG-Bench, a set of manually curated tasks that are considered beyond the capability of many LLMs. We focus on all 45 generation tasks in BIG-Bench, i.e., those that do not offer pre-established answer choices. We evaluate performance using the Rouge-L score (which measures the overall similarity between model generations and the corresponding ground truths) on every test set, and report the average score across the 45 tasks. In this experiment, all variants of FLAN-T5 serve as the backbone LLMs, and the foundational FLAN-T5 models are kept frozen. These results, shown below, suggest that Cappy enhances the performance of FLAN-T5 models by a large margin, consistently outperforming the strongest baseline, which selects among sampled responses using the LLM’s own self-scoring.

The averaged Rouge-L score over 45 complex tasks within BIG-Bench. The x-axis refers to FLAN-T5 models of different sizes. Every dashed line represents an approach working on FLAN-T5s. Self-scoring refers to using the cross-entropy of the LLM to select responses. Cappy enhances the performance of FLAN-T5 models by a large margin.
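For contrast with Cappy's separate scorer, the self-scoring baseline can be sketched as picking the candidate with the lowest cross-entropy under the generating LLM itself. This minimal Python sketch assumes the per-token log-probabilities are already available and that scores are length-normalized by averaging; the post does not specify the normalization, and the log-probabilities below are invented for illustration.

```python
def mean_cross_entropy(token_logprobs: list[float]) -> float:
    """Average negative log-likelihood per token (length normalization is an assumption)."""
    return -sum(token_logprobs) / len(token_logprobs)


def self_score_select(candidates: list[tuple[str, list[float]]]) -> str:
    """Pick the candidate the LLM itself finds most likely, i.e., the lowest cross-entropy."""
    best_text, _ = min(candidates, key=lambda c: mean_cross_entropy(c[1]))
    return best_text


# Made-up per-token log-probabilities, for illustration only.
candidates = [
    ("Skier skis down the mountain", [-0.5, -0.7, -0.2, -0.1, -0.3]),
    ("Mountain skier ski", [-1.2, -2.0, -0.9]),
]
print(self_score_select(candidates))
```

The key difference is where the score comes from: self-scoring reuses the LLM's own likelihoods, whereas Cappy scores each instruction-response pair with a separately trained model.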

Conclusion

We introduce Cappy, a novel approach that enhances the performance and efficiency of multi-task LLMs. In our experiments, we adapt a single LLM to several domains with Cappy. In the future, Cappy as a pre-trained model could potentially be used in other creative ways beyond single LLMs.

Acknowledgments

Thanks to Bowen Tan, Jindong Chen, Lei Meng, Abhanshu Sharma and Ewa Dominowska for their valuable feedback. We would also like to thank Eric Xing and Zhiting Hu for their suggestions.