Categories
Offsites

Health-specific embedding tools for dermatology and pathology

There’s a worldwide shortage of access to medical imaging expert interpretation across specialties including radiology, dermatology and pathology. Machine learning (ML) technology can help ease this burden by powering tools that enable doctors to interpret these images more accurately and efficiently. However, the development and implementation of such ML tools are often limited by the availability of high-quality data, ML expertise, and computational resources.

One way to catalyze the use of ML for medical imaging is via domain-specific models that utilize deep learning (DL) to capture the information in medical images as compressed numerical vectors (called embeddings). These embeddings represent a type of pre-learned understanding of the important features in an image. Identifying patterns in the embeddings reduces the amount of data, expertise, and compute needed to train performant models as compared to working with high-dimensional data, such as images, directly. Indeed, these embeddings can be used to perform a variety of downstream tasks within the specialized domain (see animated graphic below). This framework of leveraging pre-learned understanding to solve related tasks is similar to that of a seasoned guitar player quickly learning a new song by ear. Because the guitar player has already built up a foundation of skill and understanding, they can quickly pick up the patterns and groove of a new song.

Path Foundation is used to convert a small dataset of (image, label) pairs into (embedding, label) pairs. These pairs can then be used to train a task-specific classifier using a linear probe (i.e., a lightweight linear classifier), as represented in this graphic, or other types of models using the embeddings as input.

Once the linear probe is trained, it can be used to make predictions on embeddings from new images. These predictions can be compared to ground truth information in order to evaluate the linear probe’s performance.
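The linear-probe workflow described in the two captions above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code behind either tool: the 3-dimensional vectors stand in for real embeddings, and `train_linear_probe`, `predict`, the learning rate, and the epoch count are all hypothetical choices.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_linear_probe(embeddings, labels, lr=0.5, epochs=200):
    """Fit a logistic-regression probe on frozen embeddings with batch gradient descent."""
    dim = len(embeddings[0])
    weights, bias = [0.0] * dim, 0.0
    n = len(embeddings)
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(embeddings, labels):
            err = sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias) - y
            for i, xi in enumerate(x):
                grad_w[i] += err * xi
            grad_b += err
        weights = [w - lr * g / n for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / n
    return weights, bias


def predict(weights, bias, embedding):
    """Return the probe's probability for the positive class."""
    return sigmoid(sum(w * xi for w, xi in zip(weights, embedding)) + bias)


# Hypothetical (embedding, label) pairs; real embeddings would come from the tool.
train_x = [[0.9, 0.1, 0.0], [0.8, 0.2, 0.1], [0.1, 0.9, 0.2], [0.0, 0.8, 0.3]]
train_y = [1, 1, 0, 0]
weights, bias = train_linear_probe(train_x, train_y)

# Compare predictions on held-out embeddings with their ground-truth labels.
test_x = [[0.85, 0.15, 0.05], [0.05, 0.85, 0.25]]
test_y = [1, 0]
correct = sum((predict(weights, bias, x) > 0.5) == bool(y) for x, y in zip(test_x, test_y))
accuracy = correct / len(test_y)
```

Because only the probe's weights are trained, the image model itself never needs to be fine-tuned, which is what keeps the data and compute requirements low.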

In order to make this type of embedding model available and drive further development of ML tools in medical imaging, we are excited to release two domain-specific tools for research use: Derm Foundation and Path Foundation. This follows on the strong response we’ve already received from researchers using the CXR Foundation embedding tool for chest radiographs and represents a portion of our expanding research offerings across multiple medical-specialized modalities. These embedding tools take an image as input and produce a numerical vector (the embedding) that is specialized to the domains of dermatology and digital pathology images, respectively. By running a dataset of chest X-ray, dermatology, or pathology images through the respective embedding tool, researchers can obtain embeddings for their own images, and use these embeddings to quickly develop new models for their applications.

Path Foundation

In “Domain-specific optimization and diverse evaluation of self-supervised models for histopathology”, we showed that self-supervised learning (SSL) models for pathology images outperform traditional pre-training approaches and enable efficient training of classifiers for downstream tasks. This effort focused on hematoxylin and eosin (H&E) stained slides, the principal tissue stain in diagnostic pathology that enables pathologists to visualize cellular features under a microscope. The performance of linear classifiers trained using the output of the SSL models matched that of prior DL models trained on orders of magnitude more labeled data.

Due to substantial differences between digital pathology images and “natural image” photos, this work involved several pathology-specific optimizations during model training. One key element is that whole-slide images (WSIs) in pathology can be 100,000 pixels across (thousands of times larger than typical smartphone photos) and are analyzed by experts at multiple magnifications (zoom levels). As such, the WSIs are typically broken down into smaller tiles or patches for computer vision and DL applications. The resulting images are information dense with cells or tissue structures distributed throughout the frame instead of having distinct semantic objects or foreground vs. background variations, thus creating unique challenges for robust SSL and feature extraction. Additionally, physical (e.g., cutting) and chemical (e.g., fixing and staining) processes used to prepare the samples can influence image appearance dramatically.
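To make the tiling step concrete, the sketch below enumerates patch coordinates for a slide. The slide dimensions, the 224-pixel patch size, and the function name `tile_coordinates` are assumptions chosen for illustration; the text does not state the patch size used for Path Foundation, and reading real WSI files would need a dedicated library.

```python
def tile_coordinates(width, height, patch_size, stride=None):
    """Yield (x, y) top-left corners of patches that fit fully inside the image."""
    stride = stride or patch_size  # non-overlapping tiles unless a stride is given
    for y in range(0, height - patch_size + 1, stride):
        for x in range(0, width - patch_size + 1, stride):
            yield x, y


# Hypothetical slide of 100,000 x 80,000 pixels cut into 224 x 224 patches.
slide_width, slide_height, patch = 100_000, 80_000, 224
coords = list(tile_coordinates(slide_width, slide_height, patch))
num_patches = len(coords)
```

Even with non-overlapping tiles, a single slide of this size yields well over a hundred thousand patches, which is why the cost of embedding each patch matters.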

Taking these important aspects into consideration, pathology-specific SSL optimizations included helping the model learn stain-agnostic features, generalizing the model to patches from multiple magnifications, augmenting the data to mimic scanning and image post-processing, and applying custom data balancing to improve input heterogeneity for SSL training. These approaches were extensively evaluated using a broad set of benchmark tasks involving 17 different tissue types over 12 different tasks.

Utilizing the vision transformer (ViT-S/16) architecture, Path Foundation was selected as the best performing model from the optimization and evaluation process described above (and illustrated in the figure below). This model thus provides an important balance between performance and model size to enable valuable and scalable use in generating embeddings over the many individual image patches of large pathology WSIs.

SSL training with pathology-specific optimizations for Path Foundation.

The value of domain-specific image representations can also be seen in the figure below, which shows the linear probing performance improvement of Path Foundation (as measured by AUROC) compared to traditional pre-training on natural images (ImageNet-21k). This includes evaluation for tasks such as metastatic breast cancer detection in lymph nodes, prostate cancer grading, and breast cancer grading, among others.

Path Foundation embeddings significantly outperform traditional ImageNet embeddings as evaluated by linear probing across multiple evaluation tasks in histopathology.

Derm Foundation

Derm Foundation is an embedding tool derived from our research in applying DL to interpret images of dermatology conditions and includes our recent work that adds improvements to generalize better to new datasets. Due to its dermatology-specific pre-training, it has a latent understanding of features present in images of skin conditions and can be used to quickly develop models to classify skin conditions. The model underlying the API is a BiT ResNet-101×3 trained in two stages. The first pre-training stage uses contrastive learning, similar to ConVIRT, to train on a large number of image-text pairs from the internet. In the second stage, the image component of this pre-trained model is fine-tuned for condition classification using clinical datasets, such as those from teledermatology services.

Unlike histopathology images, dermatology images more closely resemble the real-world images used to train many of today’s computer vision models. However, for specialized dermatology tasks, creating a high-quality model may still require a large dataset. With Derm Foundation, researchers can use their own smaller dataset to retrieve domain-specific embeddings, and use those to build smaller models (e.g., linear classifiers or other small non-linear models) that enable them to validate their research or product ideas. To evaluate this approach, we trained models on a downstream task using teledermatology data. Model training involved varying dataset sizes (12.5%, 25%, 50%, 100%) to compare embedding-based linear classifiers against fine-tuning.

The modeling variants considered were:

  • A linear classifier on frozen embeddings from BiT-M (a standard pre-trained image model)
  • Fine-tuned version of BiT-M with an extra dense layer for the downstream task
  • A linear classifier on frozen embeddings from the Derm Foundation API
  • Fine-tuned version of the model underlying the Derm Foundation API with an extra layer for the downstream task

We found that models built on top of the Derm Foundation embeddings for dermatology-related tasks achieved significantly higher quality than those built on embeddings from, or fine-tuned from, BiT-M. This advantage was most pronounced for smaller training dataset sizes.

These results demonstrate that the Derm Foundation tool can serve as a useful starting point to accelerate skin-related modeling tasks. We aim to enable other researchers to build on the underlying features and representations of dermatology that the model has learned.

However, there are limitations with this analysis. We’re still exploring how well these embeddings generalize across task types, patient populations, and image settings. Downstream models built using Derm Foundation still require careful evaluation to understand their expected performance in the intended setting.

Access Path and Derm Foundation

We envision that the Derm Foundation and Path Foundation embedding tools will enable a range of use cases, including efficient development of models for diagnostic tasks, quality assurance and pre-analytical workflow improvements, image indexing and curation, and biomarker discovery and validation. We are releasing both tools to the research community so they can explore the utility of the embeddings for their own dermatology and pathology data.

To get access, please sign up to each tool’s terms of service using the following Google Forms.

After gaining access to each tool, you can use the API to retrieve embeddings from dermatology images or digital pathology images stored in Google Cloud. Approved users who are just curious to see the model and embeddings in action can use the provided example Colab notebooks to train models using public data for classifying six common skin conditions or identifying tumors in histopathology patches. We look forward to seeing the range of use-cases these tools can unlock.

Acknowledgements

We would like to thank the many collaborators who helped make this work possible including Yun Liu, Can Kirmizi, Fereshteh Mahvar, Bram Sterling, Arman Tajback, Kenneth Philbrik, Arnav Agharwal, Aurora Cheung, Andrew Sellergren, Boris Babenko, Basil Mustafa, Jan Freyberg, Terry Spitz, Yuan Liu, Pinal Bavishi, Ayush Jain, Amit Talreja, Rajeev Rikhye, Abbi Ward, Jeremy Lai, Faruk Ahmed, Supriya Vijay, Tiam Jaroensri, Jessica Loo, Saurabh Vyawahare, Saloni Agarwal, Ellery Wulczyn, Jonathan Krause, Fayaz Jamil, Tom Small, Annisah Um’rani, Lauren Winer, Sami Lachgar, Yossi Matias, Greg Corrado, and Dale Webster.

Simulating the electric field and a moving charge

How the Mandelbrot set is defined

A challenging puzzle about subset sums

SCIN: A new resource for representative dermatology images

Health datasets play a crucial role in research and medical education, but it can be challenging to create a dataset that represents the real world. For example, dermatology conditions are diverse in their appearance and severity and manifest differently across skin tones. Yet, existing dermatology image datasets often lack representation of everyday conditions (like rashes, allergies and infections) and skew towards lighter skin tones. Furthermore, race and ethnicity information is frequently missing, hindering our ability to assess disparities or create solutions.

To address these limitations, we are releasing the Skin Condition Image Network (SCIN) dataset in collaboration with physicians at Stanford Medicine. We designed SCIN to reflect the broad range of concerns that people search for online, supplementing the types of conditions typically found in clinical datasets. It contains images across various skin tones and body parts, helping to ensure that future AI tools work effectively for all. We’ve made the SCIN dataset freely available as an open-access resource for researchers, educators, and developers, and have taken careful steps to protect contributor privacy.

Example set of images and metadata from the SCIN dataset.

Dataset composition

The SCIN dataset currently contains over 10,000 images of skin, nail, or hair conditions, directly contributed by individuals experiencing them. All contributions were made voluntarily with informed consent by individuals in the US, under an Institutional Review Board-approved study. To provide context for retrospective dermatologist labeling, contributors were asked to take images both close-up and from slightly further away. They were given the option to self-report demographic information and tanning propensity (self-reported Fitzpatrick Skin Type, i.e., sFST), and to describe the texture, duration and symptoms related to their concern.

One to three dermatologists labeled each contribution with up to five dermatology conditions, along with a confidence score for each label. The SCIN dataset contains these individual labels, as well as an aggregated and weighted differential diagnosis derived from them that could be useful for model testing or training. These labels were assigned retrospectively and are not equivalent to a clinical diagnosis, but they allow us to compare the distribution of dermatology conditions in the SCIN dataset with existing datasets.
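One plausible way to turn individual labels and confidence scores into an aggregated, weighted differential is sketched below. The weighting scheme here is an assumption made for illustration (each dermatologist's confidences are normalized to sum to 1, then averaged across dermatologists); the text does not specify the method used for SCIN, and `aggregate_differential` and the example conditions are hypothetical.

```python
from collections import defaultdict


def aggregate_differential(labelers):
    """Combine per-dermatologist labels into one weighted differential.

    labelers: list of dicts mapping condition name -> confidence score (1-5).
    Returns a dict of condition -> weight, sorted by descending weight.
    """
    totals = defaultdict(float)
    counted = 0
    for labels in labelers:
        confidence_sum = sum(labels.values())
        if confidence_sum == 0:
            continue  # skip a labeler who gave no usable labels
        counted += 1
        for condition, confidence in labels.items():
            # Each labeler contributes a total weight of 1, split by confidence.
            totals[condition] += confidence / confidence_sum
    if counted == 0:
        return {}
    ranked = sorted(totals.items(), key=lambda item: item[1], reverse=True)
    return {condition: weight / counted for condition, weight in ranked}


# Hypothetical labels from three dermatologists for a single contribution.
case_labels = [
    {"Eczema": 4, "Contact dermatitis": 2},
    {"Eczema": 5},
    {"Contact dermatitis": 3, "Psoriasis": 3},
]
differential = aggregate_differential(case_labels)
```

Normalizing per dermatologist keeps one labeler who lists many conditions from outweighing another who lists only one.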

The SCIN dataset contains largely allergic, inflammatory and infectious conditions while datasets from clinical sources focus on benign and malignant neoplasms.

While many existing dermatology datasets focus on malignant and benign tumors and are intended to assist with skin cancer diagnosis, the SCIN dataset consists largely of common allergic, inflammatory, and infectious conditions. The majority of images in the SCIN dataset show early-stage concerns — more than half arose less than a week before the photo, and 30% arose less than a day before the image was taken. Conditions within this time window are seldom seen within the health system and therefore are underrepresented in existing dermatology datasets.

We also obtained dermatologist estimates of Fitzpatrick Skin Type (estimated FST or eFST) and layperson labeler estimates of Monk Skin Tone (eMST) for the images. This allowed comparison of the skin condition and skin type distributions to those in existing dermatology datasets. Although we did not selectively target any skin types or skin tones, the SCIN dataset has a more balanced Fitzpatrick skin type distribution (with more of Types 3, 4, 5, and 6) than similar datasets from clinical sources.

Self-reported and dermatologist-estimated Fitzpatrick Skin Type distribution in the SCIN dataset compared with existing un-enriched dermatology datasets (Fitzpatrick17k, PH², SKINL2, and PAD-UFES-20).

The Fitzpatrick Skin Type scale was originally developed as a photo-typing scale to measure the response of skin types to UV radiation, and it is widely used in dermatology research. The Monk Skin Tone scale is a newer 10-shade scale that measures skin tone rather than skin phototype, capturing more nuanced differences between the darker skin tones. While neither scale was intended for retrospective estimation using images, the inclusion of these labels is intended to enable future research into skin type and tone representation in dermatology. For example, the SCIN dataset provides an initial benchmark for the distribution of these skin types and tones in the US population.

The SCIN dataset has a high representation of women and younger individuals, likely reflecting a combination of factors. These could include differences in skin condition incidence, propensity to seek health information online, and variations in willingness to contribute to research across demographics.

Crowdsourcing method

To create the SCIN dataset, we used a novel crowdsourcing method, which we describe in the accompanying research paper co-authored with investigators at Stanford Medicine. This approach empowers individuals to play an active role in healthcare research. It allows us to reach people at earlier stages of their health concerns, potentially before they seek formal care. Crucially, this method uses advertisements on web search result pages — the starting point for many people’s health journey — to connect with participants.

Our results demonstrate that crowdsourcing can yield a high-quality dataset with a low spam rate. Over 97.5% of contributions were genuine images of skin conditions. After performing further filtering steps to exclude images that were out of scope for the SCIN dataset and to remove duplicates, we were able to release nearly 90% of the contributions received over the 8-month study period. Most images were sharp and well-exposed. Approximately half of the contributions include self-reported demographics, and 80% contain self-reported information relating to the skin condition, such as texture, duration, or other symptoms. We found that dermatologists’ ability to retrospectively assign a differential diagnosis depended more on the availability of self-reported information than on image quality.

Dermatologist confidence in their labels (scale from 1-5) depended on the availability of self-reported demographic and symptom information.

While perfect image de-identification can never be guaranteed, protecting the privacy of individuals who contributed their images was a top priority when creating the SCIN dataset. Through informed consent, contributors were made aware of potential re-identification risks and advised to avoid uploading images with identifying features. Post-submission privacy protection measures included manual redaction or cropping to exclude potentially identifying areas, reverse image searches to exclude publicly available copies, and metadata removal or aggregation. The SCIN Data Use License prohibits attempts to re-identify contributors.

We hope the SCIN dataset will be a helpful resource for those working to advance inclusive dermatology research, education, and AI tool development. By demonstrating an alternative to traditional dataset creation methods, SCIN paves the way for more representative datasets in areas where self-reported data or retrospective labeling is feasible.

Acknowledgements

We are grateful to all our co-authors Abbi Ward, Jimmy Li, Julie Wang, Sriram Lakshminarasimhan, Ashley Carrick, Bilson Campana, Jay Hartford, Pradeep Kumar S, Tiya Tiyasirisokchai, Sunny Virmani, Renee Wong, Yossi Matias, Greg S. Corrado, Dale R. Webster, Dawn Siegel (Stanford Medicine), Steven Lin (Stanford Medicine), Justin Ko (Stanford Medicine), Alan Karthikesalingam and Christopher Semturs. We also thank Yetunde Ibitoye, Sami Lachgar, Lisa Lehmann, Javier Perez, Margaret Ann Smith (Stanford Medicine), Rachelle Sico, Amit Talreja, Annisah Um’rani and Wayne Westerlind for their essential contributions to this work. Finally, we are grateful to Heather Cole-Lewis, Naama Hammel, Ivor Horn, Michael Howell, Yun Liu, and Eric Teasley for their insightful comments on the study design and manuscript.

MELON: Reconstructing 3D objects from images with unknown poses

A person’s prior experience and understanding of the world generally enables them to easily infer what an object looks like in whole, even if only looking at a few 2D pictures of it. Yet the capacity for a computer to reconstruct the shape of an object in 3D given only a few images has remained a difficult algorithmic problem for years. This fundamental computer vision task has applications ranging from the creation of e-commerce 3D models to autonomous vehicle navigation.

A key part of the problem is how to determine the exact positions from which images were taken, known as pose inference. If camera poses are known, a range of successful techniques — such as neural radiance fields (NeRF) or 3D Gaussian Splatting — can reconstruct an object in 3D. But if these poses are not available, then we face a difficult “chicken and egg” problem where we could determine the poses if we knew the 3D object, but we can’t reconstruct the 3D object until we know the camera poses. The problem is made harder by pseudo-symmetries — i.e., many objects look similar when viewed from different angles. For example, square objects like a chair tend to look similar every 90° rotation. Pseudo-symmetries of an object can be revealed by rendering it on a turntable from various angles and plotting its photometric self-similarity map.

Self-similarity map of a toy truck model. Left: The model is rendered on a turntable from various azimuthal angles, θ. Right: The average L2 RGB similarity of a rendering from θ with that of θ*. The pseudo-symmetries are indicated by the dashed red lines.

The diagram above only visualizes one dimension of rotation. The problem becomes even more complex (and difficult to visualize) when introducing more degrees of freedom. Pseudo-symmetries make the problem ill-posed, with naïve approaches often converging to local minima. In practice, such an approach might mistake the back view for the front view of an object, because they share a similar silhouette. Previous techniques (such as BARF or SAMURAI) side-step this problem by relying on an initial pose estimate that starts close to the global minimum. But how can we approach this if such estimates aren’t available?

Methods such as GNeRF and VMRF leverage generative adversarial networks (GANs) to overcome the problem. These techniques have the ability to artificially “amplify” a limited number of training views, aiding reconstruction. GAN techniques, however, often have complex, sometimes unstable, training processes, making robust and reliable convergence difficult to achieve in practice. A range of other successful methods, such as SparsePose or RUST, can infer poses from a limited number of views, but require pre-training on a large dataset of posed images, which isn’t always available, and can suffer from “domain-gap” issues when inferring poses for different types of images.

In “MELON: NeRF with Unposed Images in SO(3)”, spotlighted at 3DV 2024, we present a technique that can determine object-centric camera poses entirely from scratch while reconstructing the object in 3D. MELON (Modulo Equivalent Latent Optimization of NeRF) is one of the first techniques that can do this without initial camera pose estimates, complex training schemes or pre-training on labeled data. MELON is a relatively simple technique that can easily be integrated into existing NeRF methods. We demonstrate that MELON can reconstruct a NeRF from unposed images with state-of-the-art accuracy while requiring as few as 4–6 images of an object.

MELON

We leverage two key techniques to aid convergence of this ill-posed problem. The first is a very lightweight, dynamically trained convolutional neural network (CNN) encoder that regresses camera poses from training images. We pass a downscaled training image to a four-layer CNN that infers the camera pose. This CNN is initialized from noise and requires no pre-training. Its capacity is so small that it is forced to map similar-looking images to similar poses, providing an implicit regularization that greatly aids convergence.

The second technique is a modulo loss that simultaneously considers pseudo-symmetries of an object. We render the object from a fixed set of viewpoints for each training image, backpropagating the loss only through the view that best fits the training image. This effectively considers the plausibility of multiple views for each image. In practice, we find N=2 views (viewing an object from the other side) is all that’s required in most cases, but we sometimes get better results with N=4 for square objects.
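The idea behind the modulo loss can be illustrated with a toy example. This is a sketch under simplifying assumptions rather than MELON's implementation: it rotates only the azimuth, replaces the NeRF renderer with a hypothetical `toy_render` function that returns three "pixels", and omits automatic differentiation. In real training, the gradient would flow only through the best-fitting view.

```python
import math


def replicate_azimuths(azimuth, n_views):
    """Return the predicted azimuth and its evenly spaced rotations (radians)."""
    step = 2 * math.pi / n_views
    return [(azimuth + k * step) % (2 * math.pi) for k in range(n_views)]


def l2_loss(rendered, target):
    """Mean squared difference between rendered and target pixel values."""
    return sum((r - t) ** 2 for r, t in zip(rendered, target)) / len(target)


def modulo_loss(render_fn, azimuth, target, n_views=2):
    """Minimum photometric loss over the replicated views.

    Returns (loss, best_azimuth); only the best view would receive gradients.
    """
    candidates = replicate_azimuths(azimuth, n_views)
    losses = [l2_loss(render_fn(a), target) for a in candidates]
    best = min(range(n_views), key=losses.__getitem__)
    return losses[best], candidates[best]


def toy_render(azimuth):
    """Stand-in renderer; the last pixel looks identical from front and back."""
    return [math.cos(azimuth), math.sin(azimuth), math.cos(2 * azimuth)]


target = toy_render(math.pi)  # the training image was taken from the back
# The encoder wrongly guesses the front (azimuth 0); with N=2 the replicated
# view on the other side fits the image, so the loss selects it instead.
loss, best_azimuth = modulo_loss(toy_render, 0.0, target, n_views=2)
```

Taking the minimum means a prediction that is off by a pseudo-symmetric rotation is not penalized for the mismatch, which avoids pushing the optimization into the wrong local minimum.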

These two techniques are integrated into standard NeRF training, except that instead of fixed camera poses, poses are inferred by the CNN and duplicated by the modulo loss. Photometric gradients back-propagate through the best-fitting cameras into the CNN. We observe that cameras generally converge quickly to globally optimal poses (see animation below). After training of the neural field, MELON can synthesize novel views using standard NeRF rendering methods.

We simplify the problem by using the NeRF-Synthetic dataset, a popular benchmark for NeRF research and common in the pose-inference literature. This synthetic dataset has cameras at precisely fixed distances and a consistent “up” orientation, requiring us to infer only the polar coordinates of the camera. This is the same as an object at the center of a globe with a camera always pointing at it, moving along the surface. We then only need the latitude and longitude (2 degrees of freedom) to specify the camera pose.
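The globe analogy maps directly onto a small calculation. The sketch below converts a latitude and longitude into a camera position on a sphere and the direction in which the camera looks. The function names, the z-up axis convention, and the radius of 4.0 are assumptions for illustration only.

```python
import math


def camera_position(latitude_deg, longitude_deg, radius=1.0):
    """Position of a camera on a sphere centred on the object (z axis is 'up')."""
    lat = math.radians(latitude_deg)
    lon = math.radians(longitude_deg)
    x = radius * math.cos(lat) * math.cos(lon)
    y = radius * math.cos(lat) * math.sin(lon)
    z = radius * math.sin(lat)
    return (x, y, z)


def view_direction(position):
    """Unit vector pointing from the camera towards the object at the origin."""
    norm = math.sqrt(sum(c * c for c in position))
    return tuple(-c / norm for c in position)


# Hypothetical camera 30 degrees above the equator, at 45 degrees longitude.
position = camera_position(30.0, 45.0, radius=4.0)
direction = view_direction(position)
```

With the distance and "up" orientation fixed, these two angles are the only quantities the pose encoder has to predict for each image.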

MELON uses a dynamically trained lightweight CNN encoder that predicts a pose for each image. Predicted poses are replicated by the modulo loss, which only penalizes the smallest L2 distance from the ground truth color. At evaluation time, the neural field can be used to generate novel views.

Results

We compute two key metrics to evaluate MELON’s performance on the NeRF-Synthetic dataset. The error in orientation between the ground truth and inferred poses can be quantified as a single angular error, which we average across all training images to obtain the pose error. We then test the accuracy of MELON’s rendered objects from novel views by measuring the peak signal-to-noise ratio (PSNR) against held-out test views. We see that MELON quickly converges to the approximate poses of most cameras within the first 1,000 steps of training, and achieves a competitive PSNR of 27.5 dB after 50k steps.
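Both metrics are simple to compute. The sketch below is a simplified illustration: it measures the angle between viewing directions rather than between full rotations, and it assumes pixel values in the range [0, 1]. The function names are hypothetical, and any global alignment of poses that a full evaluation might need is left out.

```python
import math


def angular_error_deg(direction_a, direction_b):
    """Angle in degrees between two viewing directions."""
    dot = sum(a * b for a, b in zip(direction_a, direction_b))
    norm_a = math.sqrt(sum(a * a for a in direction_a))
    norm_b = math.sqrt(sum(b * b for b in direction_b))
    cosine = max(-1.0, min(1.0, dot / (norm_a * norm_b)))  # guard against rounding
    return math.degrees(math.acos(cosine))


def mean_pose_error(predicted, ground_truth):
    """Average angular error over all training images."""
    errors = [angular_error_deg(p, g) for p, g in zip(predicted, ground_truth)]
    return sum(errors) / len(errors)


def psnr(rendered, reference, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two flattened images."""
    mse = sum((r - t) ** 2 for r, t in zip(rendered, reference)) / len(reference)
    if mse == 0:
        return float("inf")
    return 10 * math.log10(max_value ** 2 / mse)


# Hypothetical values: two predicted view directions and a tiny two-pixel image.
pose_error = mean_pose_error([(1, 0, 0), (0, 1, 0)], [(1, 0, 0), (1, 0, 0)])
image_psnr = psnr([0.5, 0.5], [0.6, 0.4])
```

A lower pose error and a higher PSNR both indicate a better reconstruction; PSNR is logarithmic, so each additional 10 dB corresponds to a tenfold reduction in mean squared error.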

Convergence of MELON on a toy truck model during optimization. Left: Rendering of the NeRF. Right: Polar plot of predicted (blue x), and ground truth (red dot) cameras.

MELON achieves similar results for other scenes in the NeRF-Synthetic dataset.

Reconstruction quality comparison between ground-truth (GT) and MELON on NeRF-Synthetic scenes after 100k training steps.

Noisy images

MELON also works well when performing novel view synthesis from extremely noisy, unposed images. We add varying amounts, σ, of white Gaussian noise to the training images. For example, the object in σ=1.0 below is impossible to make out, yet MELON can determine the pose and generate novel views of the object.
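The noise model can be sketched as follows, assuming pixel values in the range [0, 1] so that σ=1.0 means noise as large as the full intensity range. Whether the noisy values are clipped afterwards is not stated in the text, so this sketch leaves them unclipped; `add_gaussian_noise` and the seed are illustrative choices.

```python
import random


def add_gaussian_noise(pixels, sigma, seed=0):
    """Add zero-mean white Gaussian noise with standard deviation sigma to each pixel."""
    rng = random.Random(seed)  # seeded so the example is reproducible
    return [p + rng.gauss(0.0, sigma) for p in pixels]


# Hypothetical flat grey 128 x 128 single-channel image, flattened to a list.
clean = [0.5] * (128 * 128)
noisy = add_gaussian_noise(clean, sigma=1.0)

# Empirical standard deviation of the added noise.
residual = [n - c for n, c in zip(noisy, clean)]
mean = sum(residual) / len(residual)
noise_std = (sum((r - mean) ** 2 for r in residual) / len(residual)) ** 0.5
```

Because the noise is independent for every pixel and every image, it averages out across views once the poses are consistent, which is part of why a radiance field can recover a clean object.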

Novel view synthesis from noisy unposed 128×128 images. Top: Example of noise level present in training views. Bottom: Reconstructed model from noisy training views and mean angular pose error.

This perhaps shouldn’t be too surprising, given that techniques like RawNeRF have demonstrated NeRF’s excellent de-noising capabilities with known camera poses. Still, we did not expect MELON to work so robustly on noisy images with unknown camera poses.

Conclusion

We present MELON, a technique that can determine object-centric camera poses to reconstruct objects in 3D without the need for approximate pose initializations, complex GAN training schemes or pre-training on labeled data. MELON is a relatively simple technique that can easily be integrated into existing NeRF methods. Though we only demonstrated MELON on synthetic images, we are adapting our technique to work in real-world conditions. See the paper and MELON site to learn more.

Acknowledgements

We would like to thank our paper co-authors Axel Levy, Matan Sela, and Gordon Wetzstein, as well as Florian Schroff and Hartwig Adam for continuous help in building this technology. We also thank Matthew Brown, Ricardo Martin-Brualla and Frederic Poitevin for their helpful feedback on the paper draft. We also acknowledge the use of the computational resources at the SLAC Shared Scientific Data Facility (SDF).

HEAL: A framework for health equity assessment of machine learning performance

Health equity is a major societal concern worldwide, and disparities have many causes. These include limitations in access to healthcare, differences in clinical treatment, and even fundamental differences in the diagnostic technology. In dermatology, for example, skin cancer outcomes are worse for populations such as minorities, those with lower socioeconomic status, or individuals with limited healthcare access. While recent advances in machine learning (ML) and artificial intelligence (AI) hold great promise for improving healthcare, the transition of these technologies from research to bedside must be accompanied by a careful understanding of whether and how they impact health equity.

Health equity is defined by public health organizations as fairness of opportunity for everyone to be as healthy as possible. Importantly, equity may be different from equality. For example, people with greater barriers to improving their health may require more or different effort to experience this fair opportunity. Similarly, equity is not fairness as defined in the AI for healthcare literature. Whereas AI fairness often strives for equal performance of the AI technology across different patient populations, this does not center the goal of prioritizing performance with respect to pre-existing health disparities.

Health equity considerations. An intervention (e.g., an ML-based tool, indicated in dark blue) promotes health equity if it helps reduce existing disparities in health outcomes (indicated in lighter blue).

In “Health Equity Assessment of machine Learning performance (HEAL): a framework and dermatology AI model case study”, published in The Lancet eClinicalMedicine, we propose a methodology to quantitatively assess whether ML-based health technologies perform equitably. In other words, does the ML model perform well for those with the worst health outcomes for the condition(s) the model is meant to address? This goal anchors on the principle that health equity should prioritize and measure model performance with respect to disparate health outcomes, which may be due to a number of factors that include structural inequities (e.g., demographic, social, cultural, political, economic, environmental and geographic).

The health equity framework (HEAL)

The HEAL framework proposes a 4-step process to estimate the likelihood that an ML-based health technology performs equitably:

  1. Identify factors associated with health inequities and define tool performance metrics,
  2. Identify and quantify pre-existing health disparities,
  3. Measure the performance of the tool for each subpopulation,
  4. Measure the likelihood that the tool prioritizes performance with respect to health disparities.

The final step’s output is termed the HEAL metric, which quantifies how anticorrelated the ML model’s performance is with health disparities. In other words, does the model perform better for populations that have worse health outcomes?
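One way to make this kind of measurement concrete is sketched below. It is an illustration under stated assumptions, not the definition used in the paper: it estimates, by bootstrap resampling of cases, how often model performance rises with health burden across subgroups, using a Spearman rank correlation. The name `heal_like_score`, the number of resamples, and all data are hypothetical.

```python
import random


def ranks(values):
    """Rank values from 1 (smallest) upward, averaging ranks across ties."""
    order = sorted(range(len(values)), key=values.__getitem__)
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            result[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return result


def spearman(xs, ys):
    """Spearman rank correlation, computed as the Pearson correlation of the ranks."""
    rx, ry = ranks(xs), ranks(ys)
    mx, my = sum(rx) / len(rx), sum(ry) / len(ry)
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    scale = (sum((a - mx) ** 2 for a in rx) * sum((b - my) ** 2 for b in ry)) ** 0.5
    return cov / scale if scale else 0.0


def heal_like_score(burden, outcomes_by_group, n_boot=1000, seed=0):
    """Fraction of bootstrap resamples in which performance rises with health burden.

    burden: group -> health burden (higher means worse pre-existing outcomes).
    outcomes_by_group: group -> list of 0/1 flags, one per case (1 = model correct).
    """
    rng = random.Random(seed)
    groups = list(burden)
    burdens = [burden[g] for g in groups]
    hits = 0
    for _ in range(n_boot):
        performance = []
        for g in groups:
            cases = outcomes_by_group[g]
            resample = [rng.choice(cases) for _ in cases]
            performance.append(sum(resample) / len(resample))
        if spearman(burdens, performance) > 0:
            hits += 1
    return hits / n_boot


# Hypothetical subgroups: group A has the worst outcomes and the best performance.
burden = {"A": 30.0, "B": 20.0, "C": 10.0}
outcomes = {
    "A": [1] * 180 + [0] * 20,
    "B": [1] * 160 + [0] * 40,
    "C": [1] * 140 + [0] * 60,
}
score = heal_like_score(burden, outcomes)
```

A score near 1 in this sketch means the model reliably performs better for the groups with the greatest burden; a score near 0 means the opposite. As with the framework itself, such a score describes an association and says nothing about causal impact on outcomes.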

This 4-step process is designed to inform improvements for making ML model performance more equitable, and is meant to be iterative and re-evaluated on a regular basis. For example, the availability of health outcomes data in step (2) can inform the choice of demographic factors and brackets in step (1), and the framework can be applied again with new datasets, models and populations.

Framework for Health Equity Assessment of machine Learning performance (HEAL). Our guiding principle is to avoid exacerbating health inequities, and these steps help us identify disparities and assess for inequitable model performance to move towards better outcomes for all.

With this work, we take a step towards encouraging explicit assessment of the health equity considerations of AI technologies, and encourage prioritization of efforts during model development to reduce health inequities for subpopulations exposed to structural inequities that can precipitate disparate outcomes. We should note that the present framework does not model causal relationships and, therefore, cannot quantify the actual impact a new technology will have on reducing health outcome disparities. However, the HEAL metric may help identify opportunities for improvement, where the current performance is not prioritized with respect to pre-existing health disparities.

Case study on a dermatology model

As an illustrative case study, we applied the framework to a dermatology model, which utilizes a convolutional neural network similar to that described in prior work. This example dermatology model was trained to classify 288 skin conditions using a development dataset of 29k cases. The input to the model consists of three photos of a skin concern along with demographic information and a brief structured medical history. The output consists of a ranked list of possible matching skin conditions.

Using the HEAL framework, we evaluated this model by assessing whether it prioritized performance with respect to pre-existing health outcomes. The model was designed to predict possible dermatologic conditions (from a list of hundreds) based on photos of a skin concern and patient metadata. Evaluation of the model is done using a top-3 agreement metric, which quantifies how often the top 3 output conditions match the most likely condition as suggested by a dermatologist panel. The HEAL metric is computed via the anticorrelation of this top-3 agreement with health outcome rankings.
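
The top-3 agreement metric described above can be sketched in a few lines. This is an illustrative version only: the function name `top_k_agreement`, the data layout, and the toy condition names are assumptions, and the real evaluation may treat ties or multiple reference conditions differently.

```python
def top_k_agreement(ranked_predictions, reference_conditions, k=3):
    """Fraction of cases whose reference condition is among the top-k predictions.

    ranked_predictions: one ranked list of condition names per case.
    reference_conditions: the most likely condition per case, e.g., as
    suggested by a dermatologist panel.
    """
    if len(ranked_predictions) != len(reference_conditions):
        raise ValueError("need one reference condition per case")
    if not reference_conditions:
        raise ValueError("need at least one case")
    hits = sum(
        1
        for ranked, reference in zip(ranked_predictions, reference_conditions)
        if reference in ranked[:k]
    )
    return hits / len(reference_conditions)


# Hypothetical toy cases; the condition names are for illustration only.
predictions = [
    ["eczema", "psoriasis", "tinea", "acne"],
    ["acne", "rosacea", "folliculitis", "eczema"],
    ["melanoma", "nevus", "seborrheic keratosis", "wart"],
    ["tinea", "eczema", "psoriasis", "urticaria"],
]
panel_labels = ["psoriasis", "eczema", "nevus", "urticaria"]
print(top_k_agreement(predictions, panel_labels))  # 0.5
```

Computing this value separately for each subpopulation gives the per-group performance figures that are then ranked against health outcomes.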

We used a dataset of 5,420 teledermatology cases, enriched for diversity in age, sex and race/ethnicity, to retrospectively evaluate the model’s HEAL metric. The dataset consisted of “store-and-forward” cases from patients aged 20 years or older from primary care providers in the USA and skin cancer clinics in Australia. Based on a review of the literature, we decided to explore race/ethnicity, sex and age as potential factors of inequity, and used sampling techniques to ensure that our evaluation dataset had sufficient representation of all race/ethnicity, sex and age groups. To quantify pre-existing health outcomes for each subgroup we relied on measurements from public databases endorsed by the World Health Organization, such as Years of Life Lost (YLLs) and Disability-Adjusted Life Years (DALYs; years of life lost plus years lived with disability).

HEAL metric for all dermatologic conditions across race/ethnicity subpopulations, including health outcomes (YLLs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance.
(* Higher is better; measures the likelihood the model performs equitably with respect to the axes in this table.)

HEAL metric for all dermatologic conditions across sexes, including health outcomes (DALYs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance. (* As above.)

Our analysis estimated that the model was 80.5% likely to perform equitably across race/ethnicity subgroups and 92.1% likely to perform equitably across sexes.

However, while the model was likely to perform equitably across age groups for cancer conditions specifically, we discovered that it had room for improvement across age groups for non-cancer conditions. For example, those 70+ have the poorest health outcomes related to non-cancer skin conditions, yet the model didn’t prioritize performance for this subgroup.

HEAL metrics for all cancer and non-cancer dermatologic conditions across age groups, including health outcomes (DALYs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance. (* As above.)

Putting things in context

For holistic evaluation, the HEAL metric cannot be employed in isolation. Instead this metric should be contextualized alongside many other factors ranging from computational efficiency and data privacy to ethical values, and aspects that may influence the results (e.g., selection bias or differences in representativeness of the evaluation data across demographic groups).

As an adversarial example, the HEAL metric can be artificially improved by deliberately reducing model performance for the most advantaged subpopulation until performance for that subpopulation is worse than all others. For illustrative purposes, given subpopulations A and B where A has worse health outcomes than B, consider the choice between two models: Model 1 (M1) performs 5% better for subpopulation A than for subpopulation B. Model 2 (M2) performs 5% worse on subpopulation A than B. The HEAL metric would be higher for M1 because it prioritizes performance on a subpopulation with worse outcomes. However, M1 may have absolute performances of just 75% and 70% for subpopulations A and B respectively, while M2 has absolute performances of 75% and 80% for subpopulations A and B respectively. Choosing M1 over M2 would lead to worse overall performance for all subpopulations because some subpopulations are worse-off while no subpopulation is better-off.

Accordingly, the HEAL metric should be used alongside a Pareto condition (discussed further in the paper), which restricts model changes so that outcomes for each subpopulation are either unchanged or improved compared to the status quo, and performance does not worsen for any subpopulation.
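
A Pareto check of this kind is simple to express in code. The sketch below reuses the illustrative M1 and M2 numbers from the adversarial example; the function name `satisfies_pareto` and the dictionary layout are assumptions, and the paper's formal condition may be stated differently.

```python
def satisfies_pareto(status_quo, candidate):
    """Return True if the candidate is at least as good as the status quo
    for every subpopulation (no subpopulation is made worse off)."""
    if status_quo.keys() != candidate.keys():
        raise ValueError("both models need the same subpopulations")
    return all(candidate[group] >= status_quo[group] for group in status_quo)


# Absolute performances from the illustrative example in the text.
m1 = {"A": 0.75, "B": 0.70}
m2 = {"A": 0.75, "B": 0.80}

print(satisfies_pareto(status_quo=m2, candidate=m1))  # False
print(satisfies_pareto(status_quo=m1, candidate=m2))  # True
```

Moving from M2 to M1 fails the check because subpopulation B gets worse, even though M1 would score higher on the HEAL metric alone.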

The HEAL framework, in its current form, assesses the likelihood that an ML-based model prioritizes performance for subpopulations with respect to pre-existing health disparities. This differs from the goal of understanding whether ML will reduce disparities in outcomes across subpopulations in reality. Specifically, modeling improvements in outcomes requires a causal understanding of steps in the care journey that happen both before and after use of any given model. Future research is needed to address this gap.

Conclusion

The HEAL framework enables a quantitative assessment of the likelihood that health AI technologies prioritize performance with respect to health disparities. The case study demonstrates how to apply the framework in the dermatological domain, indicating a high likelihood that model performance is prioritized with respect to health disparities across sex and race/ethnicity, but also revealing the potential for improvements for non-cancer conditions across age. The case study also illustrates limitations in the ability to apply all recommended aspects of the framework (e.g., mapping societal context, availability of data), thus highlighting the complexity of health equity considerations of ML-based tools.

This work is a proposed approach to address a grand challenge for AI and health equity, and may provide a useful evaluation framework not only during model development, but during pre-implementation and real-world monitoring stages, e.g., in the form of health equity dashboards. We hold that the strength of the HEAL framework is in its future application to various AI tools and use cases and its refinement in the process. Finally, we acknowledge that a successful approach towards understanding the impact of AI technologies on health equity needs to be more than a set of metrics. It will require a set of goals agreed upon by a community that represents those who will be most impacted by a model.

Acknowledgements

The research described here is joint work across many teams at Google. We are grateful to all our co-authors: Terry Spitz, Malcolm Pyles, Heather Cole-Lewis, Ellery Wulczyn, Stephen R. Pfohl, Donald Martin, Jr., Ronnachai Jaroensri, Geoff Keeling, Yuan Liu, Stephanie Farquhar, Qinghan Xue, Jenna Lester, Cían Hughes, Patricia Strachan, Fraser Tan, Peggy Bui, Craig H. Mermel, Lily H. Peng, Yossi Matias, Greg S. Corrado, Dale R. Webster, Sunny Virmani, Christopher Semturs, Yun Liu, and Po-Hsuan Cameron Chen. We also thank Lauren Winer, Sami Lachgar, Ting-An Lin, Aaron Loh, Morgan Du, Jenny Rizk, Renee Wong, Ashley Carrick, Preeti Singh, Annisah Um’rani, Jessica Schrouff, Alexander Brown, and Anna Iurchenko for their support of this project.

Cappy: Outperforming and boosting large multi-task language models with a small scorer

Large language model (LLM) advancements have led to a new paradigm that unifies various natural language processing (NLP) tasks within an instruction-following framework. This paradigm is exemplified by recent multi-task LLMs, such as T0, FLAN, and OPT-IML. First, multi-task data is gathered with each task following a task-specific template, where each labeled example is converted into an instruction (e.g., Put the concepts together to form a sentence: ski, mountain, skier) paired with a corresponding response (e.g., Skier skis down the mountain). These instruction-response pairs are used to train the LLM, resulting in a conditional generation model that takes an instruction as input and generates a response. Moreover, multi-task LLMs have exhibited remarkable task-wise generalization capabilities as they can address unseen tasks by understanding and solving brand-new instructions.

The demonstration of the instruction-following pre-training of multi-task LLMs, e.g., FLAN. Pre-training tasks under this paradigm improves the performance for unseen tasks.

Due to the complexity of understanding and solving various tasks solely using instructions, the size of multi-task LLMs typically spans from several billion parameters to hundreds of billions (e.g., FLAN-11B, T0-11B and OPT-IML-175B). As a result, operating such sizable models poses significant challenges because they demand considerable computational power and impose substantial requirements on the memory capacities of GPUs and TPUs, making their training and inference expensive and inefficient. Extensive storage is required to maintain a unique LLM copy for each downstream task. Moreover, the most powerful multi-task LLMs (e.g., FLAN-PaLM-540B) are closed-source, making them impossible to adapt. At the same time, in practical applications, harnessing a single multi-task LLM to manage all conceivable tasks in a zero-shot manner remains difficult, particularly when dealing with complex tasks, personalized tasks and those that cannot be succinctly defined using instructions. On the other hand, the size of downstream training data is usually insufficient to train a model well without incorporating rich prior knowledge. Hence, it has long been desired to adapt LLMs with downstream supervision while bypassing storage, memory, and access issues.

Certain parameter-efficient tuning strategies, including prompt tuning and adapters, substantially diminish storage requirements, but they still perform back-propagation through LLM parameters during the tuning process, thereby keeping their memory demands high. Additionally, some in-context learning techniques circumvent parameter tuning by integrating a limited number of supervised examples into the instruction. However, these techniques are constrained by the model’s maximum input length, which permits only a few samples to guide task resolution.

In “Cappy: Outperforming and Boosting Large Multi-Task LMs with a Small Scorer”, presented at NeurIPS 2023, we propose a novel approach that enhances the performance and efficiency of multi-task LLMs. We introduce a lightweight pre-trained scorer, Cappy, based on continual pre-training on top of RoBERTa with merely 360 million parameters. Cappy takes in an instruction and a candidate response as input, and produces a score between 0 and 1, indicating an estimated correctness of the response with respect to the instruction. Cappy functions either independently on classification tasks or serves as an auxiliary component for LLMs, boosting their performance. Moreover, Cappy efficiently enables downstream supervision without requiring any fine-tuning of the LLM, which avoids the need for back-propagation through LLM parameters and reduces memory requirements. Finally, adaptation with Cappy doesn’t require access to LLM parameters, so it is compatible with closed-source multi-task LLMs, such as those only accessible via web APIs.

Cappy takes an instruction and response pair as input and outputs a score ranging from 0 to 1, indicating an estimation of the correctness of the response with respect to the instruction.

Pre-training

We begin with the same dataset collection, which includes 39 diverse datasets from PromptSource that were used to train T0. This collection encompasses a wide range of task types, such as question answering, sentiment analysis, and summarization. Each dataset is associated with one or more templates that convert each instance from the original datasets into an instruction paired with its ground truth response.

Cappy’s regression modeling requires each pre-training data instance to include an instruction-response pair along with a correctness annotation for the response, so we produce a dataset with correctness annotations that range from 0 to 1. For every instance within a generation task, we leverage an existing multi-task LLM to generate multiple responses by sampling, conditioned on the given instruction. Subsequently, we assign an annotation to the pair formed by the instruction and every response, using the similarity between the response and the ground truth response of the instance. Specifically, we employ Rouge-L, a commonly-used metric for measuring overall multi-task performance that has demonstrated a strong alignment with human evaluation, to calculate this similarity as a form of weak supervision.
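
The annotation step can be illustrated with a small, self-contained sketch. It is not the code used to build Cappy's dataset: the whitespace tokenization, the F1 form of Rouge-L, the helper names, and the sampled responses are all assumptions for illustration.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for token_a in a:
        curr = [0]
        for j, token_b in enumerate(b, start=1):
            if token_a == token_b:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(candidate, reference):
    """Simplified Rouge-L F1 over lowercased, whitespace-split tokens."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    if not cand or not ref:
        return 0.0
    lcs = lcs_length(cand, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


def build_regression_instances(instruction, ground_truth, sampled_responses):
    """Pair the instruction with each response and a weak correctness label in [0, 1]."""
    return [
        (instruction, response, rouge_l_f1(response, ground_truth))
        for response in sampled_responses
    ]


instruction = "Put the concepts together to form a sentence: ski, mountain, skier"
ground_truth = "Skier skis down the mountain"
# Hypothetical responses standing in for samples from a multi-task LLM.
samples = ["Skier skis down the mountain", "A skier is on the mountain", "I like pizza"]
for _, response, score in build_regression_instances(instruction, ground_truth, samples):
    print(f"{score:.3f}  {response}")
```

Responses close to the ground truth receive labels near 1 and unrelated responses receive labels near 0, which gives a regression model both positive and negative signal.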

As a result, we obtain an effective regression dataset of 160 million instances paired with correctness score annotations. The final Cappy model is the result of continual pre-training using the regression dataset on top of the RoBERTa model. The pre-training of Cappy is conducted on Google’s TPU-v4, with RedCoast, a lightweight toolkit for automating distributed training.

Data augmentation with a multi-task LLM to construct a weakly supervised regression dataset for Cappy’s pre-training and fine-tuning.

Applying Cappy

Cappy solves practical tasks within a candidate-selection mechanism. More specifically, given an instruction and a set of candidate responses, Cappy produces a score for each candidate response. This is achieved by inputting the instruction alongside each individual response, and then assigning the response with the highest score as its prediction. In classification tasks, all candidate responses are inherently predefined. For example, for an instruction of a sentiment classification task (e.g., “Based on this review, would the user recommend this product?: ‘Stunning even for the non-gamer.’”), the candidate responses are “Yes” or “No”. In such scenarios, Cappy functions independently. On the other hand, in generation tasks, candidate responses are not pre-defined, requiring an existing multi-task LLM to yield the candidate responses. In this case, Cappy serves as an auxiliary component of the multi-task LLM, enhancing its decoding.
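
The candidate-selection mechanism can be sketched as follows. The scorer here is a hypothetical lookup-table stub, not the Cappy model, and the scores are made up; only the selection logic (score each candidate, keep the highest) reflects the description above.

```python
def select_best(instruction, candidates, scorer):
    """Score every (instruction, candidate) pair and return the top candidate."""
    if not candidates:
        raise ValueError("need at least one candidate response")
    return max(candidates, key=lambda response: scorer(instruction, response))


def make_table_scorer(scores):
    """Hypothetical stand-in for a learned scorer: looks scores up in a dict
    and returns 0.0 for unknown responses. NOT the real Cappy model."""
    def scorer(instruction, response):
        return scores.get(response, 0.0)
    return scorer


instruction = (
    "Based on this review, would the user recommend this product?: "
    "'Stunning even for the non-gamer.'"
)
# Classification: the candidates are predefined, so the scorer works on its own.
toy_scorer = make_table_scorer({"Yes": 0.9, "No": 0.2})  # made-up scores
print(select_best(instruction, ["Yes", "No"], toy_scorer))  # Yes
```

For generation tasks, the same `select_best` call would receive candidates sampled from a multi-task LLM instead of a fixed label set.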

Adapting multi-task LLMs with Cappy

When there is available downstream training data, Cappy enables effective and efficient adaptation of multi-task LLMs on downstream tasks. Specifically, we fine-tune Cappy to integrate downstream task information into LLM predictions. This process involves creating a separate regression dataset specific to the downstream training data with the same data annotation process used to construct the pre-training data. As a result, the fine-tuned Cappy collaborates with a multi-task LLM, boosting the LLM’s performance on the downstream task.

In contrast to other LLM tuning strategies, adapting LLMs with Cappy significantly reduces the high demand for device memory as it avoids the need for back-propagation through LLM parameters for downstream tasks. Moreover, Cappy adaptation does not rely on access to LLM parameters, making it compatible with closed-source multi-task LLMs, such as the ones only accessible via web APIs. Compared with in-context learning approaches, which circumvent model tuning by attaching training examples to the instruction prefix, Cappy is not restricted by the LLM’s maximum input length. Thus, Cappy can incorporate an unlimited number of downstream training examples. Cappy can also be applied with other adaptation methods, such as fine-tuning and in-context learning, further boosting their overall performance.

Downstream adaptation comparison between Cappy and approaches that rely on an LLM’s parameters, such as fine-tuning and prompt tuning. Cappy’s application enhances multi-task LLMs.

Results

We assess Cappy’s performance across eleven held-out language understanding classification tasks from PromptSource. We demonstrate that Cappy, with 360M parameters, outperforms OPT-175B and OPT-IML-30B, and matches the accuracy of the best existing multi-task LLMs (T0-11B and OPT-IML-175B). These findings highlight Cappy’s capabilities and parameter efficiency, which can be credited to its scoring-based pre-training strategy that integrates contrastive information by differentiating between high-quality and low-quality responses. In contrast, previous multi-task LLMs depend exclusively on teacher-forcing training that utilizes only the ground truth responses.

The overall accuracy averaged over eleven test tasks from PromptSource. “RM” refers to a pre-trained RLHF reward model. Cappy matches the best ones among existing multi-task LLMs.

We also examine the adaptation of multi-task LLMs with Cappy on complex tasks from BIG-Bench, a set of manually curated tasks that are considered beyond the capability of many LLMs. We focus on all the 45 generation BIG-Bench tasks, specifically those that do not offer pre-established answer choices. We evaluate the performance using the Rouge-L score (representing the overall similarity between model generations and corresponding ground truths) on every test set, reporting the average score across 45 tests. In this experiment, all variants of FLAN-T5 serve as the backbone LLMs, and the foundational FLAN-T5 models are frozen. These results, shown below, suggest that Cappy enhances the performance of FLAN-T5 models by a large margin, consistently outperforming the most effective baseline achieved through sample selection using self-scoring of the LLM itself.

The averaged Rouge-L score over 45 complex tasks within BIG-Bench. The x-axis refers to FLAN-T5 models of different sizes. Every dashed line represents an approach working on FLAN-T5s. Self-scoring refers to using the cross-entropy of LLM to select responses. Cappy enhances the performance of FLAN-T5 models by a large margin.

Conclusion

We introduce Cappy, a novel approach that enhances the performance and efficiency of multi-task LLMs. In our experiments, we adapt a single LLM to several domains with Cappy. In the future, Cappy as a pre-trained model can potentially be used in other creative ways beyond single LLMs.

Acknowledgments

Thanks to Bowen Tan, Jindong Chen, Lei Meng, Abhanshu Sharma and Ewa Dominowska for their valuable feedback. We would also like to thank Eric Xing and Zhiting Hu for their suggestions.

Talk like a graph: Encoding graphs for large language models

Imagine all the things around you — your friends, tools in your kitchen, or even the parts of your bike. They are all connected in different ways. In computer science, the term graph is used to describe connections between objects. Graphs consist of nodes (the objects themselves) and edges (connections between two nodes, indicating a relationship between them). Graphs are everywhere now. The internet itself is a giant graph of websites linked together. Even the knowledge search engines use is organized in a graph-like way.

Furthermore, consider the remarkable advancements in artificial intelligence — such as chatbots that can write stories in seconds, and even software that can interpret medical reports. This exciting progress is largely thanks to large language models (LLMs). New LLM technology is constantly being developed for different uses.

Since graphs are everywhere and LLM technology is on the rise, in “Talk like a Graph: Encoding Graphs for Large Language Models”, presented at ICLR 2024, we present a way to teach powerful LLMs how to better reason with graph information. Graphs are a useful way to organize information, but LLMs are mostly trained on regular text. The objective is to test different techniques to see what works best and gain practical insights. Translating graphs into text that LLMs can understand is a remarkably complex task. The difficulty stems from the inherent complexity of graph structures with multiple nodes and the intricate web of edges that connect them. Our work studies how to take a graph and translate it into a format that an LLM can understand. We also design a benchmark called GraphQA to study different approaches on different graph reasoning problems and show how to phrase a graph-related problem in a way that enables the LLM to solve the graph problem. We show that LLM performance on graph reasoning tasks varies on three fundamental levels: 1) the graph encoding method, 2) the nature of the graph task itself, and 3) interestingly, the very structure of the graph considered. These findings give us clues on how to best represent graphs for LLMs. Picking the right method can make the LLM up to 60% better at graph tasks!

Pictured, the process of encoding a graph as text using two different approaches and feeding the text and a question about the graph to the LLM.

Graphs as text

To be able to systematically find out what is the best way to translate a graph to text, we first design a benchmark called GraphQA. Think of GraphQA as an exam designed to evaluate powerful LLMs on graph-specific problems. We want to see how well LLMs can understand and solve problems that involve graphs in different setups. To create a comprehensive and realistic exam for LLMs, we don’t just use one type of graph, we use a mix of graphs ensuring breadth in the number of connections. This is mainly because different graph types make solving such problems easier or harder. This way, GraphQA can help expose biases in how an LLM thinks about the graphs, and the whole exam gets closer to a realistic setup that LLMs might encounter in the real world.

Overview of our framework for reasoning with graphs using LLMs.

GraphQA focuses on simple tasks related to graphs, like checking if an edge exists, calculating the number of nodes or edges, finding nodes that are connected to a specific node, and checking for cycles in a graph. These tasks might seem basic, but they require understanding the relationships between nodes and edges. By covering different types of challenges, from identifying patterns to creating new connections, GraphQA helps models learn how to analyze graphs effectively. These basic tasks are crucial for more complex reasoning on graphs, like finding the shortest path between nodes, detecting communities, or identifying influential nodes. Additionally, GraphQA includes generating random graphs using various algorithms like Erdős–Rényi, scale-free networks, the Barabási–Albert model, and the stochastic block model, as well as simpler graph structures like paths, complete graphs, and star graphs, providing a diverse set of data for training.
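
To give a feel for these tasks, the sketch below generates a few of the graph families mentioned above and computes ground-truth answers for edge existence, node degree, and cycle check. It is a plain-Python illustration, not the GraphQA code: the function names and the edge-list representation are assumptions, and only the simplest generators are shown.

```python
import random


def erdos_renyi(n, p, seed=None):
    """Undirected Erdős–Rényi graph: each possible edge is kept with probability p."""
    rng = random.Random(seed)
    return [(u, v) for u in range(n) for v in range(u + 1, n) if rng.random() < p]


def path_graph(n):
    return [(i, i + 1) for i in range(n - 1)]


def complete_graph(n):
    return [(u, v) for u in range(n) for v in range(u + 1, n)]


def star_graph(n):
    """Node 0 is the hub; nodes 1..n-1 are leaves."""
    return [(0, i) for i in range(1, n)]


def edge_exists(edges, u, v):
    return (u, v) in edges or (v, u) in edges


def node_degree(edges, node):
    return sum(1 for u, v in edges if node in (u, v))


def has_cycle(n, edges):
    """Cycle check for an undirected simple graph using union-find."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for u, v in edges:
        root_u, root_v = find(u), find(v)
        if root_u == root_v:
            return True  # u and v were already connected, so this edge closes a cycle
        parent[root_u] = root_v
    return False


print(has_cycle(5, path_graph(5)))      # False
print(has_cycle(4, complete_graph(4)))  # True
```

Ground-truth functions like these make it possible to grade an LLM's answer automatically for any generated graph.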

When working with graphs, we also need to find ways to ask graph-related questions that LLMs can understand. Prompting heuristics are different strategies for doing this. Let’s break down the common ones:

  • Zero-shot: simply describe the task (“Is there a cycle in this graph?”) and tell the LLM to go for it. No examples provided.
  • Few-shot: This is like giving the LLM a mini practice test before the real deal. We provide a few example graph questions and their correct answers.
  • Chain-of-Thought: Here, we show the LLM how to break down a problem step-by-step with examples. The goal is to teach it to generate its own “thought process” when faced with new graphs.
  • Zero-CoT: Similar to CoT, but instead of training examples, we give the LLM a simple prompt, like “Let’s think step-by-step,” to trigger its own problem-solving breakdown.
  • BAG (build a graph): This is specifically for graph tasks. We add the phrase “Let’s build a graph…” to the description, helping the LLM focus on the graph structure.
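
The heuristics above differ mainly in what is placed around the graph description and the question. The sketch below assembles a prompt for each one. The "Q:/A:" layout, the function name `build_prompt`, and the exact placement of each phrase are assumptions; the two trigger phrases are quoted from the list above.

```python
ZERO_COT_PHRASE = "Let's think step-by-step."
BAG_PHRASE = "Let's build a graph…"  # quoted as given above, including the ellipsis


def build_prompt(graph_text, question, heuristic="zero-shot", examples=()):
    """Assemble a prompt for one prompting heuristic.

    examples: (graph_text, question, answer) triples, used for "few-shot" and
    "cot". For "cot", the answer string is expected to contain the worked
    step-by-step reasoning.
    """
    if heuristic not in {"zero-shot", "few-shot", "cot", "zero-cot", "bag"}:
        raise ValueError(f"unknown heuristic: {heuristic}")
    blocks = []
    if heuristic in ("few-shot", "cot"):
        for ex_graph, ex_question, ex_answer in examples:
            blocks.append(f"{ex_graph}\nQ: {ex_question}\nA: {ex_answer}")
    description = graph_text
    if heuristic == "bag":
        description = f"{graph_text} {BAG_PHRASE}"
    query = f"{description}\nQ: {question}\nA:"
    if heuristic == "zero-cot":
        query += f" {ZERO_COT_PHRASE}"
    blocks.append(query)
    return "\n\n".join(blocks)


print(build_prompt("Node 0 is connected to node 1.", "Is there a cycle in this graph?", "zero-cot"))
```

Keeping the heuristic as a parameter makes it easy to run the same graph and question through every prompting style and compare accuracy.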

We explored different ways to translate graphs into text that LLMs can work with. Our key questions were:

  • Node encoding: How do we represent individual nodes? Options tested include simple integers, common names (people, characters), and letters.
  • Edge encoding: How do we describe the relationships between nodes? Methods involved parenthesis notation, phrases like “are friends”, and symbolic representations like arrows.

Various node and edge encodings were combined systematically. This led to functions like the ones in the following figure:

Examples of graph encoding functions used to encode graphs via text.
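
As a text-only illustration of what such encoding functions might look like, here is a sketch of three encoders. The exact wording used in the paper's encoders is not reproduced here: the sentences, the function names, and the reading of "incident" encoding as listing each node's neighbors are assumptions.

```python
def encode_adjacency(n, edges):
    """Integer node encoding with parenthesis notation for edges."""
    nodes = ", ".join(str(i) for i in range(n))
    edge_text = " ".join(f"({u}, {v})" for u, v in edges)
    return f"G describes a graph among nodes {nodes}. The edges in G are: {edge_text}."


def encode_incident(n, edges):
    """Integer node encoding; each node is described by the nodes it touches."""
    neighbors = {i: [] for i in range(n)}
    for u, v in edges:
        neighbors[u].append(v)
        neighbors[v].append(u)
    lines = ["G describes a graph among nodes " + ", ".join(str(i) for i in range(n)) + "."]
    for node in range(n):
        if neighbors[node]:
            listed = ", ".join(str(v) for v in sorted(neighbors[node]))
            lines.append(f"Node {node} is connected to nodes {listed}.")
    return "\n".join(lines)


def encode_friendship(names, edges):
    """Name-based node encoding with 'are friends' as the edge phrase."""
    lines = ["G describes a friendship graph among " + ", ".join(names) + "."]
    for u, v in edges:
        lines.append(f"{names[u]} and {names[v]} are friends.")
    return "\n".join(lines)


edges = [(0, 1), (1, 2)]
print(encode_adjacency(3, edges))
print(encode_incident(3, edges))
print(encode_friendship(["Ann", "Bob", "Cleo"], edges))  # hypothetical names
```

All three functions describe the same graph; only the surface text differs, which is exactly the variable the encoding experiments isolate.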

Analysis and results

We carried out three key experiments: one to test how LLMs handle graph tasks, and two to understand how the size of the LLM and different graph shapes affected performance. We ran all our experiments on GraphQA.

How LLMs handle graph tasks

In this experiment, we tested how well pre-trained LLMs tackle graph problems like identifying connections, cycles, and node degrees. Here is what we learned:

  • LLMs struggle: On most of these basic tasks, LLMs did not do much better than a random guess.
  • Encoding matters significantly: How we represent the graph as text has a great effect on LLM performance. The “incident” encoding excelled for most of the tasks in general.

Our results are summarized in the following chart.

Comparison of various graph encoder functions based on their accuracy on different graph tasks. The main conclusion from this figure is that the graph encoding functions matter significantly.

Bigger is (usually) better

In this experiment, we wanted to see if the size of the LLM (in terms of the number of parameters) affects how well they can handle graph problems. For that, we tested the same graph tasks on the XXS, XS, S, and L sizes of PaLM 2. Here is a summary of our findings:

  • In general, bigger models did better on graph reasoning tasks. It seems like the extra parameters gave them space to learn more complex patterns.
  • Oddly, size didn’t matter as much for the “edge existence” task (finding out if two nodes in a graph are connected).
  • Even the biggest LLM couldn’t consistently beat a simple baseline solution on the cycle check problem (finding out if a graph contains a cycle or not). This shows LLMs still have room to improve with certain graph tasks.

Effect of model capacity on graph reasoning task for PaLM 2-XXS, XS, S, and L.

Do different graph shapes confuse LLMs?

We wondered if the “shape” of a graph (how nodes are connected) influences how well LLMs can solve problems on it. Think of the following figure as different examples of graph shapes.

Samples of graphs generated with different graph generators from GraphQA. ER, BA, SBM, and SFN refer to Erdős–Rényi, Barabási–Albert, Stochastic Block Model, and Scale-Free Network respectively.

We found that graph structure has a big impact on LLM performance. For example, in a task asking if a cycle exists, LLMs did great on tightly interconnected graphs (cycles are common there) but struggled on path graphs (where cycles never happen). Interestingly, providing some mixed examples helped the LLM adapt. For instance, for cycle check, we added some examples containing a cycle and some examples with no cycles as few-shot examples in our prompt. Similar patterns occurred with other tasks.

Comparing different graph generators on different graph tasks. The main observation here is that graph structure has a significant impact on the LLM’s performance. ER, BA, SBM, and SFN refer to Erdős–Rényi, Barabási–Albert, Stochastic Block Model, and Scale-Free Network respectively.

Conclusion

In short, we dug deep into how to best represent graphs as text so LLMs can understand them. We found three major factors that make a difference:

  • How to translate the graph to text: How we represent the graph as text significantly influences LLM performance. The incident encoding excelled for most of the tasks.
  • Task type: Certain types of graph questions tend to be harder for LLMs, even with a good translation from graph to text.
  • Graph structure: Surprisingly, the “shape” of the graph on which we do inference (dense with connections, sparse, etc.) influences how well an LLM does.

This study revealed key insights about how to prepare graphs for LLMs. The right encoding techniques can significantly boost an LLM’s accuracy on graph problems (ranging from around 5% to over 60% improvement). Our new benchmark, GraphQA, will help drive further research in this area.

Acknowledgements

We would like to express our gratitude to our co-author, Jonathan Halcrow, for his valuable contributions to this work. We express our sincere gratitude to Anton Tsitsulin, Dustin Zelle, Silvio Lattanzi, Vahab Mirrokni, and the entire graph mining team at Google Research, for their insightful comments, thorough proofreading, and constructive feedback which greatly enhanced the quality of our work. We would also like to extend special thanks to Tom Small for creating the animation used in this post.

Chain-of-table: Evolving tables in the reasoning chain for table understanding

People use tables every day to organize and interpret complex information in a structured, easily accessible format. Due to the ubiquity of such tables, reasoning over tabular data has long been a central topic in natural language processing (NLP). Researchers in this field have aimed to leverage language models to help users answer questions, verify statements, and analyze data based on tables. However, language models are trained over large amounts of plain text, so the inherently structured nature of tabular data can be difficult for language models to fully comprehend and utilize.

Recently, large language models (LLMs) have achieved outstanding performance across diverse natural language understanding (NLU) tasks by generating reliable reasoning chains, as shown in works like Chain-of-Thought and Least-to-Most. However, the most suitable way for LLMs to reason over tabular data remains an open question.

In “Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding”, we propose a framework to tackle table understanding tasks, where we train LLMs to outline their reasoning step by step, updating a given table iteratively to reflect each part of a thought process, akin to how people solve the table-based problems. This enables the LLM to transform the table into simpler and more manageable segments so that it can understand and analyze each part of the table in depth. This approach has yielded significant improvements and achieved new state-of-the-art results on the WikiTQ, TabFact, and FeTaQA benchmarks. The figure below shows the high-level overview of the proposed Chain-of-Table and other methods.

Given a complex table where a cyclist’s nationality and name are in the same cell, (a) generic, multi-step reasoning is unable to provide the correct answer, and (b) program-aided reasoning generates and executes programs (e.g., SQL queries) to deliver the answer, but falls short in accurately addressing the question. In contrast, (c) Chain-of-Table iteratively samples a chain of operations that effectively transform the complex table into a version specifically tailored to the question.

Chain-of-Table

In Chain-of-Table, we guide LLMs using in-context learning to iteratively generate operations and to update the table to represent its reasoning chain over tabular data. This enables LLMs to dynamically plan the next operation based on the results of previous ones. This continuous evolution of the table forms a chain, which provides a more structured and clear representation of the reasoning process for a given problem and enables more accurate and reliable predictions from the LLM.

For example, when asked, “Which actor has the most NAACP image awards?” the Chain-of-Table framework prompts an LLM to generate tabular operations mirroring tabular reasoning processes. It first identifies the relevant columns. Then, it aggregates rows based on shared content. Finally, it reorders the aggregated results to yield a final table that clearly answers the posed question.
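
The three steps in this example (pick the relevant column, aggregate rows with shared content, reorder) can be mimicked with ordinary table operations. The sketch below is only an illustration: the operation names `f_select_column`, `f_group_by` and `f_sort_by`, the list-of-dicts table format, and the toy rows with placeholder actor names are all assumptions, and in the framework itself the operations are chosen by the LLM rather than hard-coded.

```python
from collections import Counter


def f_select_column(table, columns):
    """Keep only the listed columns."""
    return [{c: row[c] for c in columns} for row in table]


def f_group_by(table, column):
    """Aggregate rows that share a value in `column`, adding a Count column."""
    counts = Counter(row[column] for row in table)
    return [{column: value, "Count": count} for value, count in counts.items()]


def f_sort_by(table, column, descending=True):
    """Reorder rows by `column`; ties keep their existing order."""
    return sorted(table, key=lambda row: row[column], reverse=descending)


# Hypothetical award table with placeholder names.
table = [
    {"Year": 2019, "Actor": "Actor A", "Result": "Won"},
    {"Year": 2020, "Actor": "Actor B", "Result": "Won"},
    {"Year": 2021, "Actor": "Actor A", "Result": "Won"},
    {"Year": 2022, "Actor": "Actor C", "Result": "Won"},
    {"Year": 2023, "Actor": "Actor A", "Result": "Won"},
]

step1 = f_select_column(table, ["Actor"])  # identify the relevant column
step2 = f_group_by(step1, "Actor")         # aggregate rows with shared content
final = f_sort_by(step2, "Count")          # reorder so the answer is on top
print(final[0])  # {'Actor': 'Actor A', 'Count': 3}
```

Each intermediate table (`step1`, `step2`, `final`) is simpler than the one before it, which is what lets the final table answer the question directly.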

These operations transform the table to align with the question presented. To balance performance with computational expense on large tables, we construct the operation chain according to a subset of tabular rows. Meanwhile, the step-by-step operations reveal the underlying reasoning process by displaying the intermediate results of the tabular operations, which makes the reasoning easier to interpret.
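The three steps in the NAACP example can be sketched in plain Python. This is a minimal illustration rather than the authors' implementation: the rows are invented placeholders, and the helper names `select_column`, `group_by`, and `sort_by` are hypothetical stand-ins for the framework's tabular operations. Each intermediate table is kept in its own variable so the reasoning can be inspected step by step.

```python
from collections import Counter

# Toy rows invented for illustration; not data from the paper.
table = [
    {"Year": 2001, "Actor": "Actor A", "Award": "Outstanding Actor"},
    {"Year": 2002, "Actor": "Actor B", "Award": "Outstanding Actor"},
    {"Year": 2003, "Actor": "Actor A", "Award": "Outstanding Supporting Actor"},
    {"Year": 2004, "Actor": "Actor C", "Award": "Outstanding Actor"},
    {"Year": 2005, "Actor": "Actor A", "Award": "Outstanding Actor"},
    {"Year": 2006, "Actor": "Actor B", "Award": "Outstanding Supporting Actor"},
]

def select_column(rows, columns):
    """Keep only the named columns in every row."""
    return [{c: row[c] for c in columns} for row in rows]

def group_by(rows, column):
    """Collapse rows that share a value in `column` and count them."""
    counts = Counter(row[column] for row in rows)
    return [{column: value, "Count": n} for value, n in counts.items()]

def sort_by(rows, column, descending=True):
    """Reorder rows by the values in `column`."""
    return sorted(rows, key=lambda row: row[column], reverse=descending)

# Each step produces an intermediate table that can be shown to the model.
step1 = select_column(table, ["Actor"])   # identify the relevant column
step2 = group_by(step1, "Actor")          # aggregate rows with shared content
step3 = sort_by(step2, "Count")           # reorder the aggregated results

answer = step3[0]["Actor"]
for name, intermediate in [("select", step1), ("group", step2), ("sort", step3)]:
    print(name, intermediate)
print("answer:", answer)
```

The final table has one row per actor, ordered by count, so the answer can be read from its first row.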

Illustration of the tabular reasoning process in Chain-of-Table. This iterative process involves dynamically planning an operation chain and accurately storing intermediate results in the transformed tables. These intermediate tables serve as a tabular thought process that can guide the LLM to arrive at the correct answer more reliably.

Chain-of-Table consists of three main stages. In the first stage, it instructs the LLM to dynamically plan the next operation by in-context learning. Specifically, the prompt involves three components as shown in the following figure:

  1. The question Q: “Which country had the most cyclists finish in the top 3?”
  2. The operation history chain: f_add_col(Country) and f_select_row(1, 2, 3).
  3. The latest intermediate table T: the transformed intermediate table.

By providing the triplet (T, Q, chain) in the prompt, the LLM can observe the previous tabular reasoning process and select the next operation from the operation pool to complete the reasoning chain step by step.

Illustration of how Chain-of-Table selects the next operation from the operation pool and generates the arguments for the operation. (a) Chain-of-Table samples the next operation from the operation pool. (b) It takes the selected operation as input and generates its arguments.
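As a rough sketch of the planning stage, the triplet (T, Q, chain) could be assembled into a prompt as follows. The prompt wording, the pipe-separated table format, the function names, and the contents of the operation pool are all assumptions made for illustration; the actual prompts used in the paper may differ. The rows are invented placeholders.

```python
def serialize_table(header, rows):
    """Render a table as pipe-separated text, one row per line."""
    lines = [" | ".join(header)]
    lines += [" | ".join(str(cell) for cell in row) for row in rows]
    return "\n".join(lines)

def build_planning_prompt(header, rows, question, chain, operation_pool):
    """Combine the latest table T, the question Q, and the operation chain."""
    history = " -> ".join(chain) if chain else "(none)"
    return "\n".join([
        "Table:",
        serialize_table(header, rows),
        f"Question: {question}",
        f"Operation chain so far: {history}",
        f"Operation pool: {', '.join(operation_pool)}",
        "Next operation:",
    ])

# Placeholder intermediate table after the two operations in the chain.
header = ["Rank", "Cyclist", "Country"]
rows = [[1, "Rider A", "ESP"], [2, "Rider B", "ITA"], [3, "Rider C", "ESP"]]

prompt = build_planning_prompt(
    header,
    rows,
    question="Which country had the most cyclists finish in the top 3?",
    chain=["f_add_col(Country)", "f_select_row(1, 2, 3)"],
    # Assumed pool: the operations named in this post plus a hypothetical end marker.
    operation_pool=["f_add_col", "f_select_row", "f_group_by", "END"],
)
print(prompt)
```

The LLM's completion after "Next operation:" would then be read as the name of the operation to apply next.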

After the next operation f is determined, in the second stage, we need to generate the arguments. As above, Chain-of-Table considers three components in the prompt as shown in the figure: (1) the question, (2) the selected operation and its required arguments, and (3) the latest intermediate table.

For instance, when the operation f_group_by is selected, it requires a header name as its argument.

The LLM selects a suitable header within the table. Equipped with the selected operation and the generated arguments, Chain-of-Table executes the operation and constructs a new intermediate table for the following reasoning.
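The argument-generation stage for `f_group_by` might look like the sketch below. The prompt layout, the helper names `build_argument_prompt` and `parse_header`, and the hard-coded reply that stands in for the LLM are assumptions for illustration; only the idea that `f_group_by` takes a header name comes from the text above. Checking the reply against the table header guards against the model naming a column that does not exist.

```python
from collections import Counter

def build_argument_prompt(header, rows, question, operation, needs):
    """Prompt with the question, the selected operation, and the latest table."""
    table_text = "\n".join(" | ".join(str(c) for c in r) for r in [header] + rows)
    return "\n".join([
        "Table:",
        table_text,
        f"Question: {question}",
        f"Operation: {operation} (requires {needs})",
        "Argument:",
    ])

def parse_header(reply, header):
    """Return the header named in the reply, or raise if it is not in the table."""
    candidate = reply.strip().strip("\"'")
    if candidate not in header:
        raise ValueError(f"{candidate!r} is not a column of the table")
    return candidate

def f_group_by(header, rows, column):
    """Build a new intermediate table counting rows per value of `column`."""
    idx = header.index(column)
    counts = Counter(row[idx] for row in rows)
    return [column, "Count"], [[value, n] for value, n in counts.items()]

# Placeholder intermediate table.
header = ["Rank", "Cyclist", "Country"]
rows = [[1, "Rider A", "ESP"], [2, "Rider B", "ITA"], [3, "Rider C", "ESP"]]

prompt = build_argument_prompt(
    header, rows,
    "Which country had the most cyclists finish in the top 3?",
    "f_group_by", "a header name",
)
llm_reply = ' "Country" '  # stand-in for the model's reply to `prompt`
column = parse_header(llm_reply, header)
new_header, new_rows = f_group_by(header, rows, column)
print(new_header, new_rows)
```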

Chain-of-Table iterates the previous two stages to plan the next operation and generate the required arguments. During this process, we create an operation chain acting as a proxy for the tabular reasoning steps. These operations generate intermediate tables presenting the results of each step to the LLM. Consequently, the output table contains comprehensive information about the intermediate phases of tabular reasoning. In our final stage, we employ this output table in formulating the final query and prompt the LLM along with the question for the final answer.
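Putting the stages together, the iteration can be sketched as a short loop. Everything here is a simplified assumption: `scripted_planner` replays a fixed plan in place of the two LLM calls (operation selection and argument generation), `END` is a hypothetical stop marker, `max_steps` is an invented safety limit, and the rows are placeholders. The final answering prompt is left out; the loop only returns the output table, the operation chain, and the intermediate tables.

```python
from collections import Counter

def f_select_row(table, indices):
    """Keep the rows at the given 1-based positions."""
    rows = [table["rows"][i - 1] for i in indices]
    return {"header": table["header"], "rows": rows}

def f_group_by(table, column):
    """Count rows per value of `column`."""
    idx = table["header"].index(column)
    counts = Counter(row[idx] for row in table["rows"])
    return {"header": [column, "Count"],
            "rows": [[value, n] for value, n in counts.items()]}

OPERATIONS = {"f_select_row": f_select_row, "f_group_by": f_group_by}
END = "END"  # hypothetical marker meaning "the chain is complete"

def scripted_planner(table, question, chain):
    """Stand-in for the LLM: returns (operation name, arguments) from a fixed plan."""
    plan = [("f_select_row", [1, 2, 3]), ("f_group_by", "Country")]
    return plan[len(chain)] if len(chain) < len(plan) else (END, None)

def chain_of_table(table, question, planner, max_steps=5):
    """Alternate planning and execution until the planner signals the end."""
    chain, history = [], [table]
    for _ in range(max_steps):
        name, args = planner(table, question, chain)
        if name == END:
            break
        table = OPERATIONS[name](table, args)  # execute and evolve the table
        chain.append((name, args))
        history.append(table)
    return table, chain, history

start = {
    "header": ["Rank", "Cyclist", "Country"],
    "rows": [[1, "Rider A", "ESP"], [2, "Rider B", "ITA"],
             [3, "Rider C", "ESP"], [4, "Rider D", "FRA"]],
}
question = "Which country had the most cyclists finish in the top 3?"
final, chain, history = chain_of_table(start, question, scripted_planner)
print(chain)
print(final)
```

In a full system, the output table and the question would then be placed in one last prompt to obtain the answer.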

Experimental setup

We use PaLM 2-S and GPT 3.5 as the backbone LLMs and conduct the experiments on three public table understanding benchmarks: WikiTQ, TabFact, and FeTaQA. WikiTQ and FeTaQA are datasets for table-based question answering. TabFact is a table-based fact verification benchmark. In this blog post, we focus on the results on WikiTQ and TabFact. We compare Chain-of-Table with generic reasoning methods (e.g., End-to-End QA, Few-Shot QA, and Chain-of-Thought) and program-aided methods (e.g., Text-to-SQL, Binder, and Dater).

More accurate answers

Compared to the generic reasoning methods and program-aided reasoning methods, Chain-of-Table achieves better performance with both PaLM 2 and GPT 3.5. We attribute this to the dynamically sampled operations and the informative intermediate tables.

Understanding results on WikiTQ and TabFact with PaLM 2 and GPT 3.5 compared with various models.

Better robustness on harder questions

In Chain-of-Table, longer operation chains indicate higher difficulty and complexity of the questions and their corresponding tables. We categorize the test samples according to the length of their operation chains in Chain-of-Table. We compare Chain-of-Table with Chain-of-Thought and Dater, as representative generic and program-aided reasoning methods. We illustrate this using results from PaLM 2 on WikiTQ.

Performance of Chain-of-Thought, Dater, and the proposed Chain-of-Table on WikiTQ for questions that require an operation chain of varying lengths. Our proposed atomic operations significantly improve performance over generic and program-aided reasoning counterparts.

Notably, Chain-of-Table consistently surpasses both baseline methods across all operation chain lengths, by margins of up to 11.6% over Chain-of-Thought and up to 7.9% over Dater. Moreover, the performance of Chain-of-Table declines more gracefully than that of the baseline methods as the number of operations increases, exhibiting only a minimal decrease when the number of operations increases from four to five.
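The grouping by operation chain length described above amounts to a per-bucket accuracy calculation, sketched here. The records are invented solely to show the arithmetic and are not results from the paper; the function name `accuracy_by_chain_length` is hypothetical.

```python
from collections import defaultdict

def accuracy_by_chain_length(records):
    """Map each chain length to the fraction of correctly answered samples."""
    totals = defaultdict(lambda: [0, 0])  # length -> [correct, total]
    for length, is_correct in records:
        totals[length][0] += int(is_correct)
        totals[length][1] += 1
    return {length: correct / total
            for length, (correct, total) in sorted(totals.items())}

# Invented (chain length, answered correctly) pairs for illustration only.
records = [
    (1, True), (1, True),
    (2, True), (2, False),
    (3, False), (3, True), (3, True), (3, False),
]
per_length = accuracy_by_chain_length(records)
print(per_length)
```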

Better robustness with larger tables

We categorize the tables from WikiTQ into three groups based on token number: small (<2000 tokens), medium (2000 to 4000 tokens) and large (>4000 tokens). We then compare Chain-of-Table with Dater and Binder, the two latest and strongest baselines.

Performance of Binder, Dater, and the proposed Chain-of-Table on small (<2000 tokens), medium (2000 to 4000 tokens), and large (>4000 tokens) tables from WikiTQ. We observe that the performance decreases with larger input tables while Chain-of-Table diminishes gracefully, achieving significant improvements over competing methods. (As above, underlined text denotes the second-best performance; bold denotes the best performance.)

As anticipated, performance decreases with larger input tables, as models are required to reason through longer contexts. Nevertheless, the performance of the proposed Chain-of-Table diminishes gracefully, achieving an improvement of more than 10% over the second-best competing method on large tables. This demonstrates the efficacy of the reasoning chain in handling long tabular inputs.
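The size grouping used in this comparison can be written as a small bucketing function. The thresholds come from the text above; the whitespace-based `count_tokens` is only a crude assumed proxy, since the post does not say which tokenizer was used, and both function names are hypothetical.

```python
def count_tokens(table_text):
    """Crude proxy: count whitespace-separated pieces. A real tokenizer would differ."""
    return len(table_text.split())

def size_bucket(num_tokens):
    """Assign a table to small (<2000), medium (2000 to 4000), or large (>4000)."""
    if num_tokens < 2000:
        return "small"
    if num_tokens <= 4000:
        return "medium"
    return "large"

for n in (150, 2000, 4000, 4001):
    print(n, size_bucket(n))
```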

Conclusion

Our proposed Chain-of-Table method enhances the reasoning capability of LLMs by leveraging the tabular structure to express intermediate steps for table-based reasoning. It instructs LLMs to dynamically plan an operation chain according to the input table and its associated question. This evolving table design sheds new light on the understanding of prompting LLMs for table understanding.

Acknowledgements

This research was conducted by Zilong Wang, Hao Zhang, Chun-Liang Li, Julian Martin Eisenschlos, Vincent Perot, Zifeng Wang, Lesly Miculicich, Yasuhisa Fujii, Jingbo Shang, Chen-Yu Lee, and Tomas Pfister. Thanks to Chih-Kuan Yeh and Sergey Ioffe for their valuable feedback.