Introducing LoCalPFN to Improve Tabular Foundation Models

November 13, 2024

Author: Anthony Caterini    Editor: George Stein

Tabular data is essential across numerous industries: it is the primary modality we leverage here at the bank, and it also spans diverse domains such as healthcare, logistics, science, and engineering. Yet despite significant advances in applying deep learning to images and text, effectively adapting similar techniques to tabular data remains a challenge. Our recent work, Retrieval & Fine-Tuning for In-Context Tabular Models [1], addresses this gap. In it, we demonstrate that retrieval and fine-tuning, two hugely popular techniques for leveraging the power of foundation models in other domains, can also push the frontier of tabular foundation models. Our paper will appear in the proceedings of the 2024 Conference on Neural Information Processing Systems (NeurIPS), the premier venue for cutting-edge machine learning research, and we describe it below. We are continuing work in this space, so please check back soon for a follow-up!


The Challenges of Tabular Data

Tabular data can vary greatly, ranging from basic spreadsheets to complex databases. We might encounter something as simple as the classic Iris dataset comprising just 150 observations of 4 floral characteristics, or as complicated as enterprise-level databases with millions of customers, thousands of features, and dozens of possible response variables to predict. This high degree of variability makes it hard to create a single neural network that can easily understand and adapt to the structure of a given tabular dataset. Not only are the sizes of the data different, but the relationship between columns is unknown: continuing the example above, the relationship between features of a flower in the Iris dataset likely does not provide any insight into the structure of banking data!

We can contrast this with two areas where neural networks reign supreme: text has clear connections between subsequent words in a sentence, and images have spatial coherence amongst nearby pixels. In both of these cases, large foundation models with neural network backbones (single models applicable to a wide range of downstream tasks, such as ChatGPT and CLIP [2]) have pushed the state of the art, revolutionizing how we interact with AI. In the tabular space, however, methods based on decision trees, such as XGBoost [3] and Random Forest [4], have shown more robustness to varied inputs, and thus are typically more reliable than deep learning for working with tabular data.

Yet the potential of tree-based methods is limited in comparison to neural networks: it is unlikely that a single large tree-based model will revolutionize how we interact with tabular data, as such techniques have not demonstrated the ability to transfer learnings from one dataset to another. Furthermore, neural networks can close the gap on tree-based methods when working with larger datasets, suggesting that approaches based on deep learning might simply need more data to learn something useful. Therefore, to unlock the full power of neural networks for tabular data, we should construct a large model that can adapt to the unique structure of the data, and then give it a lot of data from which to learn.


TabPFN and In-Context Learning

Fortunately, in recent years, progress has been made in precisely this direction. The most important breakthrough has been the Tabular Prior-Fitted Network [5], also known as TabPFN, which can be seen as a foundation model for tabular data. TabPFN shares many similarities with foundation models for text and images:

  1. It is pre-trained on huge amounts of data drawn from a prior distribution over datasets, which is where “Prior-Fitted” in the name comes from.
  2. The underlying predictive model is transformer-based.

However, the training procedure for TabPFN is unique in that it uses only synthetically-generated data, whereas other foundation models are trained using data scraped from all over the internet. TabPFN’s synthetic pre-training data is designed to be incredibly diverse, meant to capture the inherent heterogeneity of tabular data as seen in the wild.

In-Context Learning

Another useful quirk of the TabPFN training procedure is that the model ingests entire datasets at once, providing predictions on held-out data by dynamically adapting to the relationships between columns in the training dataset. TabPFN therefore performs in-context learning: the structure of the training data passed to the transformer as context is automatically inferred and used when making predictions. In-context learning was first observed as an emergent phenomenon in Large Language Models (LLMs) such as ChatGPT. Quite remarkably, one can place some training examples \((x, y)\) for a prediction task in the prompt (or context) of an LLM, and then have the LLM predict \(y^*\) for a new test example \(x^*\), even though the LLM was trained neither for this particular task nor for in-context learning more generally. TabPFN, on the other hand, is trained directly to perform in-context learning on tabular datasets, although still not for any particular classification task. The result is a model that can make predictions on a new dataset without any further training.

Figure 1 – Depiction of tabular in-context learning as performed by TabPFN. In this case, we want to predict which player \(y^*\) produced the statistics \(x^*\) in the Query, using the history of (statistics, player) tuples \((x, y)\) provided as Context. Based on the knowledge acquired through pre-training, plus the new dataset provided as context, the model is able to predict that it is most likely Player A that produced these statistics.
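To make this workflow concrete, here is a minimal Python sketch of the in-context interface. The class `InContextClassifier` and the stand-in model `distance_weighted_vote` are hypothetical names used only for illustration: `fit` performs no gradient updates and simply stores the training set as context, while `predict` passes context and queries to the model together. The stand-in is a simple distance-weighted vote, not a transformer and not TabPFN; a pre-trained in-context model would take its place.

```python
import numpy as np


class InContextClassifier:
    """Hypothetical wrapper illustrating the in-context workflow.

    fit() only stores the training set, which becomes the context;
    predict() hands context and queries to the model in a single call.
    """

    def __init__(self, model=None):
        # Stand-in model by default; a pre-trained in-context model would go here.
        self.model = model if model is not None else distance_weighted_vote

    def fit(self, X, y):
        # "Training" is just remembering the context: no parameters change.
        self.X_ctx = np.asarray(X, dtype=float)
        self.y_ctx = np.asarray(y)
        return self

    def predict(self, X_query):
        X_query = np.asarray(X_query, dtype=float)
        return self.model(self.X_ctx, self.y_ctx, X_query)


def distance_weighted_vote(X_ctx, y_ctx, X_qy, temperature=1.0):
    """Stand-in for a pre-trained in-context model (an assumption, NOT TabPFN)."""
    classes = np.unique(y_ctx)
    # Pairwise squared Euclidean distances, shape (n_query, n_context).
    d2 = ((X_qy[:, None, :] - X_ctx[None, :, :]) ** 2).sum(-1)
    w = np.exp(-d2 / temperature)
    # Total weight per class, then the heaviest class for each query.
    scores = np.stack([w[:, y_ctx == c].sum(1) for c in classes], axis=1)
    return classes[scores.argmax(1)]
```

Swapping in a different `model` callable leaves the interface unchanged, which is the point: all adaptation to the dataset happens inside a single forward call on context plus queries.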

Context Length Limitation

TabPFN has been competitive with tree-based techniques in modelling completely unseen tabular data, especially when the dataset is small, since training a tree-based model from scratch on little data is challenging. TabPFN also greatly outperforms LLMs adapted to tabular in-context learning; LLMs, designed specifically to capture relationships between words, struggle to make sense of the numerical columns in a dataset. However, TabPFN-style in-context learning is severely limited by context length, i.e., the number of examples provided as context: TabPFN was originally restricted to 1000 training examples. This is problematic because many applications, including banking, involve tabular datasets that are orders of magnitude larger. We expect future tabular foundation models powered by in-context learning to suffer from the same limitation, so we set out to lift it and unleash the power of transformer-based architectures on larger tabular datasets.
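For datasets larger than the context limit, the simplest fallback is to subsample the context uniformly at random, so that every test point sees the same random subset regardless of where it lies in feature space. A minimal sketch, with `random_context` as a hypothetical helper name:

```python
import numpy as np

MAX_CONTEXT = 1000  # TabPFN's original limit on training examples in context


def random_context(X, y, max_context=MAX_CONTEXT, seed=0):
    """Pick a uniformly random subset of at most max_context rows as context."""
    n = len(X)
    if n <= max_context:
        # Small datasets fit entirely in the context.
        return X, y
    rng = np.random.default_rng(seed)
    idx = rng.choice(n, size=max_context, replace=False)  # no duplicate rows
    return X[idx], y[idx]
```

With 50,000 training rows, only 2% of the data ever reaches the model for a given prediction, and which 2% is unrelated to the point being classified.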


LoCalPFN: A New Approach to Tabular Data

Since context size is the bottleneck in existing approaches on larger datasets, one way to improve performance is to be more efficient about what goes into the context. This leads us to our first contribution: retrieve only the most relevant points for the context, rather than selecting context points at random. To illustrate, polling 1000 Torontonians about the best place to grab lunch in Toronto will get you much better results than polling 1000 random people spread across the globe! Figure 2 shows that this intuition holds when applying retrieval on top of TabPFN.

Figure 2 – Displaying the benefits of using retrieval to enhance TabPFN using a concentric circle dataset with two classes. On the left, we see that TabPFN is unable to properly classify the dataset, drawing a very simple decision boundary that places everything within one large circle in the orange class, and everything else in the blue class. On the right, we see that using retrieval on top of TabPFN with 100 nearby context points can properly capture the structure of the dataset, providing better predictions.
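Retrieval-based context selection, as in the right panel of Figure 2, can be sketched as a brute-force \(k\)-nearest-neighbour lookup for each query point. This is a minimal illustration assuming raw Euclidean distance on unnormalized features; `knn_context` is a hypothetical name, and a practical implementation on large datasets would typically use a dedicated nearest-neighbour index rather than computing every distance.

```python
import numpy as np


def knn_context(X_train, y_train, x_query, k=100):
    """Return the k training rows closest to x_query in Euclidean distance.

    Each query gets its own context, so the model only sees the local
    neighbourhood around the point it must classify.
    """
    dists = np.linalg.norm(X_train - x_query, axis=1)
    k = min(k, len(X_train))
    idx = np.argpartition(dists, k - 1)[:k]  # k smallest distances, unordered
    idx = idx[np.argsort(dists[idx])]        # order them nearest first
    return X_train[idx], y_train[idx]
```

The returned rows would then serve as the context when classifying `x_query`, in place of a random subset of the training set.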

Then, with a careful choice of retrieval mechanism, we can fine-tune the underlying foundation model (here, TabPFN) end-to-end to further improve performance on downstream tasks. This approach helps the model adapt better to the specific characteristics of a dataset, creating what we call a “locally-calibrated PFN”, or LoCalPFN. LoCalPFN represents a clear improvement on TabPFN for large tabular datasets, and we describe it in more detail now.

The Need for Approximate Retrieval to Facilitate Fine-Tuning

LoCalPFN uses two fundamental techniques to improve the performance of TabPFN: retrieval to select relevant points for the context and form a “local neighbourhood” of data that shares similar characteristics, and fine-tuning to adjust the model’s parameters to enhance its performance for the specific classification task at hand. However, naively fine-tuning a tabular transformer architecture with retrieved samples in-context is quite inefficient: supposing we would like to predict and pass gradients on \(N_\text{qy}\) samples, the naive approach would require performing forward and backward passes through the network on \(N_\text{qy}\) different neighbourhoods each of size \(L_\text{ctx}\) (context length). The memory requirements to do this are quite large, considering typical values such as \(L_\text{ctx} = 1000\) and \(N_\text{qy} = 100\). We therefore require an approximate retrieval scheme to facilitate our fine-tuning.
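As a rough proxy for memory, count the table rows the transformer must process in one training step with the typical values above, comparing one neighbourhood per query against a hypothetical single context shared by all queries:

\[
\underbrace{N_\text{qy}\,(L_\text{ctx} + 1)}_{\text{one neighbourhood per query}} = 100 \times 1001 = 100{,}100
\qquad\text{vs.}\qquad
\underbrace{L_\text{ctx} + N_\text{qy}}_{\text{one shared context}} = 1000 + 100 = 1{,}100.
\]

The naive approach thus processes 91 times as many rows. This simple count ignores architecture-specific constant factors, but it shows why sharing contexts across queries matters.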

Details of our Novel Retrieval Scheme

The core idea behind our approximate retrieval approach is to greatly reduce the number of unique context neighbourhoods that we provide to the transformer. Assuming again that we would like to use \(N_\text{qy}\) query examples for each training step, we propose to use only \(B \ll N_\text{qy}\) different contexts of size \(L_\text{ctx}\), with each context used to classify \(L_\text{qy} = N_\text{qy} / B\) query points. To accomplish this, we first sample \(B\) unique points from the dataset. We next form a neighbourhood of size \(k = L_\text{ctx} + L_\text{qy}\) around each of the \(B\) points using their \(k\)-nearest neighbours in Euclidean (L2) distance. Finally, we randomly shuffle each of these size-\(k\) neighbourhoods and split it into a context set of size \(L_\text{ctx}\) and a query set of size \(L_\text{qy}\), and stack the \(B\) neighbourhoods into a tensor of size \((B, k, d)\), where \(d\) is the number of features in the data, to input to the transformer.
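One training step's batch construction can be sketched in NumPy as follows. This is an illustrative sketch under simplifying assumptions, not the paper's code: `approximate_retrieval_batch` is a hypothetical name, neighbours are found by brute force (memory grows with \(B\) times the dataset size, so a large dataset would call for a nearest-neighbour index), and no feature preprocessing is applied.

```python
import numpy as np


def approximate_retrieval_batch(X, y, B, L_ctx, L_qy, rng):
    """Build B shared neighbourhoods for one training step.

    Returns context features (B, L_ctx, d), context labels (B, L_ctx),
    query features (B, L_qy, d) and query labels (B, L_qy).
    """
    k = L_ctx + L_qy
    n = X.shape[0]
    # 1. Sample B distinct anchor points from the training set.
    anchors = rng.choice(n, size=B, replace=False)
    # 2. Indices of the k nearest neighbours (Euclidean) of each anchor, (B, k).
    d2 = ((X[anchors][:, None, :] - X[None, :, :]) ** 2).sum(-1)  # (B, n)
    nbrs = np.argpartition(d2, k - 1, axis=1)[:, :k]
    # 3. Shuffle each neighbourhood independently; the anchor loses its role.
    nbrs = rng.permuted(nbrs, axis=1)
    neigh_X, neigh_y = X[nbrs], y[nbrs]  # (B, k, d) and (B, k)
    # Split each neighbourhood into context and query sets.
    return (neigh_X[:, :L_ctx], neigh_y[:, :L_ctx],
            neigh_X[:, L_ctx:], neigh_y[:, L_ctx:])
```

With \(N_\text{qy} = 100\) and, say, \(B = 10\), each neighbourhood holds \(k = L_\text{ctx} + 10\) rows; the intermediate `neigh_X` array is the \((B, k, d)\) tensor described above.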

The parameter \(B\) controls the trade-off between neighbour accuracy and computational efficiency: a lower \(B\) means each context is shared among more query points, but this comes at the cost of a coarser approximation in the neighbour search, because the neighbourhood relation is not transitive, i.e., the neighbour of your neighbour might not be your neighbour.
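A three-point example on a line makes the non-transitivity concrete. With \(k = 1\) for simplicity, point 1.0 is the nearest neighbour of point 0.0, and point 1.8 is the nearest neighbour of point 1.0, yet 1.8 is not the nearest neighbour of 0.0. The helper `nearest_other` is a hypothetical name for illustration.

```python
import numpy as np

points = np.array([0.0, 1.0, 1.8])  # three points on a line


def nearest_other(i):
    """Index of the single nearest neighbour of point i, excluding itself."""
    d = np.abs(points - points[i])
    d[i] = np.inf  # a point is not its own neighbour here
    return int(np.argmin(d))


first = nearest_other(0)       # neighbour of 0.0 is 1.0 (index 1)
second = nearest_other(first)  # neighbour of 1.0 is 1.8 (index 2)
print(first, second)  # prints: 1 2
```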

We also point out that this approximate retrieval scheme is only performed during training. At test time, since we are not performing any backward passes of the network, we can afford to have distinct contexts and neighbourhoods for each point in the test set.

Approximate Retrieval Algorithm and Figure

We illustrate the approximate retrieval technique in Figure 3, for one sample among the set of \(B\) chosen at each training step:

  1. We sample one of the \(B\) points from the training set; let’s call it \(x_b\).
  2. We compute the \(k\)-nearest neighbours of \(x_b\) in the training set using Euclidean distance, represented by the points inside the black line. Note that the \(k\)-nearest neighbours of a given point will typically contain points from several classes, which is represented by the different colours of the points in the figure.
  3. We randomly shuffle the neighbourhood to split into context points and query points. Note that, once the neighbourhood is selected, there is no longer a distinction between the original point \(x_b\) and any of the other points in the neighbourhood, and so \(x_b\) could appear in either the context set or the query set.
Figure 3 – Depiction of the approximate retrieval scheme for one sampled point \(x_b\) highlighted in red.

In-Context Learners Beat Tree-Based Models

To test the quality of LoCalPFN for modelling tabular data, we evaluated its classification performance on 95 diverse datasets curated from the TabZilla [6] suite. These datasets come from domains such as finance, biology, games, banking, industrial applications, and natural signals such as vision or sound, and they range in size from roughly 100 to 200,000 instances and from 2 to 100 features. We compared LoCalPFN to a variety of baselines, including TabPFN itself, TabPFN with the retrieval component shown in the concentric circle figure above (TabPFN-kNN), popular tree-based techniques such as XGBoost, Random Forest, and LightGBM, and several deep learning approaches. The full set of benchmarks and comparisons can be found in the paper.

Scaling Benefits of LoCalPFN

We would mainly like to highlight how LoCalPFN scales with both dataset size and complexity. In both cases, we bin the datasets into groups by size or by complexity (the complexity measure is detailed in the paper), and measure classification performance as the average Area Under the ROC Curve (AUC) in each bin relative to the average AUC of Random Forest in that bin. We report relative rather than absolute AUC because the maximum attainable AUC varies across datasets without any clear correlation to their size or complexity, so normalizing by a common baseline makes the trends easier to read. In Figure 4, we can see that TabPFN really starts to struggle as we scale: its performance degrades with both increasing dataset size and increasing complexity. Meanwhile, LoCalPFN improves relative to Random Forest as both dataset size and complexity increase, demonstrating that we have unlocked the ability to scale tabular in-context learners.

TabPFN-kNN, our retrieval-aided TabPFN variant, performs better than TabPFN but worse than LoCalPFN, showing the importance of fine-tuning on specific tasks. TabPFN-kNN does, however, offer a lower-cost way to improve TabPFN, and therefore remains interesting in its own right.

Figure 4 – Performance of LoCalPFN, TabPFN with retrieval (TabPFN-kNN), TabPFN, and tree-based methods. LoCalPFN has the best scaling both as a function of dataset size and complexity.
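The relative-AUC metric can be sketched as below. This assumes "relative" means the ratio of a model's mean AUC to Random Forest's mean AUC within each bin, so values above 1 indicate outperforming Random Forest; the paper's exact definition and bin edges may differ, and `relative_auc_by_bin` is a hypothetical helper.

```python
import numpy as np


def relative_auc_by_bin(bin_values, auc_model, auc_rf, bin_edges):
    """Mean AUC of a model in each bin divided by Random Forest's mean AUC.

    bin_values: per-dataset quantity used for binning (e.g. number of rows).
    auc_model, auc_rf: per-dataset test AUCs for the model and Random Forest.
    bin_edges: increasing edges; bin i covers [bin_edges[i], bin_edges[i+1]).
    Datasets outside the outermost edges are ignored; empty bins give NaN.
    """
    bin_values = np.asarray(bin_values)
    auc_model, auc_rf = np.asarray(auc_model), np.asarray(auc_rf)
    bins = np.digitize(bin_values, bin_edges) - 1
    ratios = []
    for i in range(len(bin_edges) - 1):
        mask = bins == i
        ratios.append(auc_model[mask].mean() / auc_rf[mask].mean()
                      if mask.any() else np.nan)
    return np.array(ratios)
```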

Looking Ahead

LoCalPFN opens new possibilities for improving predictive performance on tabular data problems across industries, delivering practical value to its users. We anticipate that as tabular foundation models become more mainstream, the core techniques of retrieval and fine-tuning described here will become even more relevant, much like how retrieval-augmented generation [7] is now crucial to many industrial applications of LLMs despite originally being tested on only one language model.

The story is not yet finished though! LoCalPFN is not perfect, and we are working to further improve the state of deep learning in tabular data by addressing the limitations described below.

  1. LoCalPFN & TabPFN cannot do regression. The tasks on which we tested LoCalPFN were all classification tasks because TabPFN cannot natively perform regression. We can consider workarounds such as “regression-as-classification”, i.e., discretizing the continuous target and predicting the target “bin” instead, but such an approach is itself limited by the maximum number of classes TabPFN supports (10). Some initial experiments using LoCalPFN in this way did show promise, but this is more of a hack than a fundamental solution.
  2. Restrictions on number of features and classes. We have already mentioned that TabPFN is restricted to only predicting up to 10 classes, but it is also worth noting that TabPFN can currently only process up to 100 features. LoCalPFN inherits this limitation from TabPFN too.
  3. Dependence on TabPFN as a base model. The limitations above are a direct consequence of the limitations of TabPFN. These are microcosms of the overall limitation imposed by the dependence of LoCalPFN on TabPFN. At this point, however, TabPFN is the only tabular foundation model that provides competitive in-context learning capabilities.
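The regression-as-classification workaround from the first limitation can be sketched as quantile binning of the target into at most 10 classes. Quantile bins, and bin medians as representative values, are assumptions made for illustration; `discretize_target` is a hypothetical helper, and other binning schemes are equally possible.

```python
import numpy as np

MAX_CLASSES = 10  # TabPFN's limit on the number of classes


def discretize_target(y, n_bins=MAX_CLASSES):
    """Map a continuous target to at most n_bins quantile-based classes.

    Returns integer labels in [0, n_bins) and each bin's median, which can
    map a predicted class back to a numeric prediction.
    """
    y = np.asarray(y, dtype=float)
    # Interior quantile edges split the target into roughly equal-count bins.
    edges = np.quantile(y, np.linspace(0, 1, n_bins + 1)[1:-1])
    labels = np.searchsorted(edges, y, side="right")
    centers = np.array([np.median(y[labels == c]) if np.any(labels == c) else np.nan
                        for c in range(n_bins)])
    return labels, centers
```

A classifier would then be trained on `labels`, and a predicted class `c` mapped back to the numeric prediction `centers[c]`; the 10-class cap bounds how finely the target can be resolved.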

The main approach we are using to address these limitations is to train our own tabular foundation model, using many of the ideas of TabPFN but with training and architecture improvements to lift the limitations described above. At Layer 6, we remain keenly interested in tackling the idiosyncrasies of tabular data and continue to push the boundaries of what is possible with tabular foundation models. LoCalPFN thus represents but one step on our overall journey. Stay tuned for more!


References

  1. Thomas, V., Ma, J., Hosseinzadeh, R., Golestan, K., Yu, G., Volkovs, M. and Caterini, A., 2024. Retrieval & Fine-Tuning for In-Context Tabular Models. In Advances in Neural Information Processing Systems (NeurIPS).
  2. Radford, A., Kim, J.W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., Clark, J. and Krueger, G., 2021. Learning Transferable Visual Models from Natural Language Supervision. In International Conference on Machine Learning (ICML).
  3. Chen, T. and Guestrin, C., 2016. XGBoost: A Scalable Tree Boosting System. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining.
  4. Breiman, L., 2001. Random Forests. Machine Learning, 45(1), pp.5-32.
  5. Hollmann, N., Müller, S., Eggensperger, K. and Hutter, F., 2023. TabPFN: A Transformer that Solves Small Tabular Classification Problems in a Second. In International Conference on Learning Representations (ICLR).
  6. McElfresh, D., Khandagale, S., Valverde, J., Prasad C, V., Ramakrishnan, G., Goldblum, M. and White, C., 2023. When do neural nets outperform boosted trees on tabular data? In Advances in Neural Information Processing Systems (NeurIPS) Datasets and Benchmarks Track.
  7. Lewis, P., Perez, E., Piktus, A., Petroni, F., Karpukhin, V., Goyal, N., Küttler, H., Lewis, M., Yih, W.T., Rocktäschel, T. and Riedel, S., 2020. Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks. In Advances in Neural Information Processing Systems (NeurIPS).