III: Scaling to Deep Learning
In the previous chapter, we saw that in classical regimes, there are reliable methods for (a close statistical analog of) the predictive data attribution problem. In this section, we’re going to explore predictive data attribution in modern, large-scale settings—it turns out that the landscape here is quite a bit more complex.
A first attempt: directly applying the influence function
Recall from Part II the leave-one-out problem:
\[LOO(j) = \left(\arg\min_\theta \sum_{i=1}^n \ell_i(\theta)\right) - \left(\arg\min_\theta \sum_{i\neq j} \ell_i(\theta)\right),\]and its corresponding “influence function estimate”
\[\widehat{LOO}_{\text{IF}}(j) = -\left(\sum_{i=1}^n \nabla^2 \ell_i(\theta)\right)^{-1} \cdot \nabla \ell_j(\theta),\]where $\theta$ denotes the minimizer over the full training set. We found that this estimator was remarkably good and can even come with provable guarantees when the loss functions $\ell_i(\cdot)$ are strongly convex. In light of this, a very natural first question to ask is whether we can simply apply the exact same estimator in the context of deep learning.
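To make the estimator concrete, here is a minimal NumPy sketch (our own illustration, not code from Part II) on ridge regression, where every quantity is available in closed form. We assume squared losses $\ell_i(\theta) = \tfrac12 (x_i^\top\theta - y_i)^2$ plus a ridge penalty that is treated as a separate term and is never removed, so its Hessian $\lambda I$ is added to the sum. The helper names `fit_ridge`, `loo_exact`, and `loo_influence` are hypothetical.

```python
import numpy as np

def fit_ridge(X, y, lam):
    # Minimizes sum_i 0.5 * (x_i^T theta - y_i)^2 + 0.5 * lam * ||theta||^2 in closed form.
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

def loo_exact(X, y, lam):
    # Row j: LOO(j) = theta(all data) - theta(all data except j), obtained by actually retraining.
    theta = fit_ridge(X, y, lam)
    return np.stack([theta - fit_ridge(np.delete(X, j, axis=0), np.delete(y, j), lam)
                     for j in range(len(y))])

def loo_influence(X, y, lam):
    # Row j: IF estimate -(sum_i Hessian of loss_i + lam * I)^{-1} grad loss_j(theta),
    # everything evaluated at the full-data optimum theta (the ridge term is never removed).
    theta = fit_ridge(X, y, lam)
    H = X.T @ X + lam * np.eye(X.shape[1])
    grads = (X @ theta - y)[:, None] * X   # row j = grad of 0.5 * (x_j^T theta - y_j)^2
    return -np.linalg.solve(H, grads.T).T
```

For this quadratic objective, one can check (via the Sherman-Morrison formula) that the exact leave-one-out difference equals the IF estimate divided by $1 - h_j$, where $h_j = x_j^\top H^{-1} x_j$ is the leverage of example $j$; the two therefore agree closely whenever no single example has high leverage.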
On one hand, many of the assumptions we made in Part II are notably violated in deep learning—for example, neural networks are non-convex problems; we typically do not train them to convergence; and the parameters, as far as we can tell, do not seem to be inherently meaningful. These violations pose both technical problems (e.g., how do we invert a possibly non-invertible Hessian?) and conceptual ones (e.g., if the global solution is not unique, what is the influence function even computing [BNL+22]?).
On the other hand, many of the most successful techniques in deep learning use intuition from classical problems, and work in practice even when major assumptions are violated (see, e.g., momentum, which was originally developed for convex optimization). Inspired by success stories like these, we might want to try ignoring these issues for now and applying the same Taylor-based approximations from Part II (e.g., the influence function). (It turns out that, assumption violations aside, even just computing the IF estimate is highly nontrivial. This is, in part, due to the high dimensionality and non-invertibility of the Hessian. We’ll sweep these issues under the rug for now and just assume we are able to compute the relevant quantities.)
A pioneering paper by Koh and Liang [KL17] takes exactly this approach. They develop clever techniques to compute influence function (IF) estimates (adapting the estimator above to estimate changes in model predictions, rather than model parameters), and show that the corresponding estimates work well for small MLPs and for linear models trained on the last-layer representations of an Inception-v3 model. In particular, the leave-one-out estimates they get correlate remarkably well with true leave-one-out effects (computed by retraining the model):
(Figure 2 from [KL17].) Right: fidelity of the IF approximation for a small CNN trained on MNIST.
In an accompanying Jupyter notebook, we present a simple (~200 lines of code) implementation of this method, highlighting the main technical ideas from [KL17].
[KL17] and subsequent works leverage the influence function for a variety of applications, such as identifying mislabeled or poisoned data, debugging model behavior, or cleaning datasets, suggesting its potential utility. Often, the attributions appear to pass “sanity checks” like visual similarity—the most “influential” examples bear some resemblance to the target example.
However, other subsequent works have found that under closer scrutiny, variants of these approaches are not all that reliable for DNNs:
- [BPF’21] find that the IF estimates are extremely fragile: they depend strongly on hyperparameters (e.g., regularization), and things generally get worse with wider or deeper networks:
(Figure reproduced from Basu et al. [BPF’21].)
- [HL’22] find that existing methods fail even simple sanity checks: if you train on a combination of two datasets A and B (e.g., CIFAR-10 and MNIST) and compute influence scores for a target example from A, the most influential examples according to existing attributions are often from B rather than A!
- If we think a bit more carefully, it’s not even clear what these influence estimates are predicting, as so many assumptions are violated. Bae et al. [BNL+22] carefully analyze these issues, and suggest that in the context of deep neural networks, influence functions actually estimate a quantity entirely different from $LOO$ (namely, the proximal Bregman response function).
More broadly, these findings highlight that we cannot use heuristics like human inspection or visual similarity to judge the quality of predictive attributions. After all, the reason we are trying to attribute model behavior to data in the first place is precisely because models often behave in strange ways!
What do we do next? On the one hand, we could fiddle around more with the influence function estimates; perhaps there’s a better way to apply them in ML contexts. Or perhaps the approach is doomed and we need other ML-specific techniques to estimate LOO. These are all reasonable approaches, and we’ll actually talk about some of them later, but right now we have a more fundamental problem.
It’s clear that what we have right now isn’t reliable—we have failed sanity checks and theoretical analysis to show this. But in coming up with new methods, how will we know when we’re making progress?
Quantifying predictive data attribution
Before we can blaze ahead designing better methods for predictive data attribution, we need a good way of knowing when we’re actually making progress. There are many things we could do here: we could reuse some of the sanity checks from above (as in [HL’22]); we could try to measure parameter-space differences (as in [BPF’21]); or we could look at leave-one-out scores (as in [KL17; BNL+22]). Recall from Part I, however, that our goal at the end of the day is to predict how training data influences final model behavior.
Inspired by this goal, we formalize an idea that’s been floating around in the literature on training data attribution: the most direct way to evaluate a predictive data attribution method is to measure its prediction error relative to model retraining. The result is what we call the Linear Datamodeling Score (LDS) [IPE+’22; PGI+23]. First, recall from Part I the definition of a predictive data attribution method:
Definition [Predictive data attribution method]. Consider a universe of data $\,\mathcal{U}$, a training algorithm $\smash{\theta(\cdot): 2^\mathcal{U} \to \mathcal{M}}$ mapping training sets $S$ to machine learning models, and a measurement function $\ell: \mathcal{M} \to \mathbb{R}^k$ recording some property of a machine learning model (e.g., its loss on a set of test examples). A predictive data attribution method is a function $\hat{f}: 2^\mathcal{U} \to \mathbb{R}^k$ that directly approximates the result of training a model and evaluating $\ell$, i.e., $\hat{f}(S) \approx \ell(\theta(S))$ for all $S \subset \mathcal{U}$.
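Note that the definition above is given in terms of all subsets $S \subset \mathcal{U}$. While this is indeed the end goal, in practice predictive data attribution methods aren’t quite there yet—so instead, we’re going to measure the average performance of data attribution methods across a distribution $\mathcal{D}$ over training sets $S$. Concretely, we sample subsets $S_1, \ldots, S_m \sim \mathcal{D}$ (e.g., uniformly random subsets containing half of the training set) and compute
\[\text{LDS}(\hat{f}) = \rho\left(\big\{\ell(\theta(S_j))\big\}_{j=1}^m,\ \big\{\hat{f}(S_j)\big\}_{j=1}^m\right),\]where $\rho$ denotes the Spearman rank correlation (for vector-valued outputs, i.e., $k > 1$, we compute it for each coordinate and average).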
We’ll often be interested in the “per-example” LDS, which is when the measurement function $\ell$ measures the loss of a model on a specific test example $z$. We generally denote the resulting example-specific data attribution as $\smash{\hat{f}(z; S)}$.
Let’s digest this definition with some code. First, we create subsets $\lbrace S_1,\ldots,S_m\rbrace $:
# 1. create 10 random subsets of the training set
# (each containing half of the 50,000 training samples)
import numpy as np
np.random.seed(42)  # fix random seed for reproducibility
train_set_subsets = []
for _ in range(10):
    subset = np.random.choice(50_000, 25_000, replace=False)
    train_set_subsets.append(subset)
Then, to obtain the true model output $\ell(\theta(S_j))$, we first need to train a model on each subset $S_j$ and then evaluate it on a target example $z$:
# 2. train a model on each subset and record
# its output on a target example of choice
# (train_on_subset, record_output and get_loader are helpers from utils)
from utils import train_on_subset, record_output, get_loader
val_loader = get_loader(split="val")
z = val_loader.dataset[0][0]  # target example
outputs_per_subset = []
for subset in train_set_subsets:
    model = train_on_subset(subset)
    out = record_output(model, z)
    outputs_per_subset.append(out)
Finally, we use our data attribution method of choice to get the predicted model output $\smash{\hat{f}(z; S_j)}$ for each subset $S_j$, and report the rank correlation between the two arrays:
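# 3. evaluate the rank correlation between the true model
# outputs and the predictions from our attribution method
from scipy.stats import spearmanr
# get_attribution_scores is a placeholder for your attribution method of choice:
# it should return the predicted output f_hat(z; S_j) for the subset S_j
predictions_per_subset = [get_attribution_scores(subset, z) for subset in train_set_subsets]
LDS = spearmanr(outputs_per_subset, predictions_per_subset)
print(f'LDS: {LDS.correlation:.3f} (p value {LDS.pvalue:.6f})')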
For convenience, we have put together the above snippets in this Jupyter notebook.
With this formalization, our goal is to design data attribution methods $\hat{f}$ that achieve a high LDS. (In practice, evaluating the LDS requires training many models, on the order of 100, on different training subsets. Though expensive, this cost is paid upfront and is independent of the attribution method: we can train these models and record their outputs in advance, and then use them to evaluate any attribution method later.) The LDS has recently been used to evaluate the credibility of attributions across various domains [PGI+23; GVS+23; DZM23; ZPD+23; BLLG24; DLZM24].
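The snippets above compute the LDS for a single target example. A minimal self-contained sketch of the batched computation might look as follows. It assumes the attribution method is additive, with scores $\tau$ stored as an $(n, k)$ array (one column per target example), and that the sampled subsets are stored as rows of a 0/1 mask matrix; the function names are hypothetical, and in practice the true outputs would come from retrained models as in step 2.

```python
import numpy as np
from scipy.stats import spearmanr

def predicted_outputs(tau, masks):
    # Additive attribution: f_hat(z; S_j) = sum_{i in S_j} tau[i, z].
    # masks: (m, n) 0/1 matrix, row j = indicator vector of subset S_j.
    # tau:   (n, k) attribution scores, one column per target example z.
    return masks @ tau                     # shape (m, k)

def lds(true_outputs, pred_outputs):
    # true_outputs, pred_outputs: (m, k) arrays of l(theta(S_j)) and f_hat(z; S_j).
    # Spearman correlation across the m subsets, separately for each target, then averaged.
    per_target = np.array([spearmanr(true_outputs[:, t], pred_outputs[:, t])[0]
                           for t in range(true_outputs.shape[1])])
    return per_target.mean(), per_target
```

Keeping the per-target correlations around (rather than only the mean) makes it easy to spot targets on which a method fails badly.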
Revisiting our initial estimator
Now that we have a clearer picture of what to evaluate, we can revisit and evaluate some of the initially proposed estimators. Below, we evaluate the influence function estimator [KL17] computed for ResNet-9 classifiers trained on CIFAR-10:
We observe that the LDS is very low. In other words, the influence function on its own cannot reliably predict what happens to predictions when we train models on random subsets of a larger training pool. (Note that this doesn’t preclude this estimator from being useful—LDS only evaluates methods in their capacity as predictive data attribution techniques.)
Is this even possible? Direct estimation
Note that the influence function estimator we derived in Part II is linear, in the sense that we can decompose it into a sum of per-example scores $\tau_i$:
\[\hat{f}(S) = \sum\limits_{i\in S} \tau_i.\]But large model training (particularly for non-convex models like neural networks) is an incredibly complex process—is it possible that we just can’t predict its outcome with a simple linear model?
To understand this better, we’re going to directly try to learn the data-to-model-output mapping. That is, we treat our task as the statistical learning problem of predicting $S’\mapsto \ell(\theta(S’))$:
- Repeatedly train models $\theta(S_i)$ on random samples $S_i \subset S$, and collect the corresponding model outputs $\lbrace \ell(\theta(S_i))\rbrace$.

Now we have a simple regression problem from subsets of $S$ (equivalently, indicator vectors in $\lbrace 0,1\rbrace^{n}$, where $n = \vert S\vert$) to $\mathbb{R}$, with training data $\lbrace (S_i, \ell(\theta(S_i)))\rbrace$. While we could consider any family of functions, [IPE+22] find that perhaps the simplest instantiation, linear models, suffices, i.e., we can solve:
\[\min_{\tau \in \mathbb{R}^n} \frac{1}{m} \sum_{i=1}^m \left(\tau^\top \mathbb{1}_{S_i} - \ell(\theta(S_i))\right)^2 + \lambda \cdot \text{reg}(\tau),\]where $\mathbb{1}_{S_i} \in \lbrace 0,1\rbrace^n$ is the indicator vector encoding $S_i$ and $\text{reg}$ is a regularization term (e.g., the $\ell_1$ norm to promote sparsity).
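A minimal sketch of this regression step, assuming NumPy and substituting a squared $\ell_2$ penalty for the $\ell_1$ penalty mentioned above so that the solution has a closed form. `fit_datamodel` is a hypothetical name; `masks` holds the indicator vectors $\mathbb{1}_{S_i}$ as rows and `outputs` the measured $\ell(\theta(S_i))$. The demo uses synthetic outputs generated from a known linear datamodel, not real retraining results.

```python
import numpy as np

def fit_datamodel(masks, outputs, lam=1e-3):
    # Closed-form minimizer of (1/m) * sum_i (tau^T 1_{S_i} - output_i)^2 + lam * ||tau||_2^2
    # (an l2 penalty stands in for the l1 penalty suggested in the text).
    # masks: (m, n) 0/1 matrix, row i = indicator of S_i; outputs: (m,) values l(theta(S_i)).
    m, n = masks.shape
    A = masks.T @ masks / m + lam * np.eye(n)
    return np.linalg.solve(A, masks.T @ outputs / m)

if __name__ == "__main__":
    # Synthetic demo: outputs come from a known linear datamodel plus noise.
    rng = np.random.default_rng(0)
    n, m = 50, 2000
    tau_true = rng.normal(size=n)
    masks = (rng.random((m, n)) < 0.5).astype(float)   # each example kept with prob. 1/2
    outputs = masks @ tau_true + 0.1 * rng.normal(size=m)
    tau_hat = fit_datamodel(masks, outputs)
    print("max abs error:", np.max(np.abs(tau_hat - tau_true)))
```

With real models, each row of `masks` costs one full training run, which is why the direct approach is so expensive.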
Interestingly, the resulting estimator ends up being closely related to (i.e., a slight refinement of) previously suggested data attribution methods, both predictive and game-theoretic [GZ19; FZ20; IPE+’22; LZL+’22]!
Note that we estimate a separate linear model for each choice of target input $z$. Here, we overload notation so that $\tau(z)$ is the estimated datamodel for $z$ computed using the attribution method $\tau$.
The resulting datamodels achieve very high LDS (and generally perform better with more samples):
Before we go further, note that this establishes that our goal of predictive attribution is actually attainable in modern settings! This is a crucial point: a priori, it’s not clear that we can model the complicated function $S \mapsto \ell(\theta(S))$ with any fidelity.
Now, in some sense this direct approach should also give optimal estimates given infinite compute. The downside is that this procedure is often prohibitively expensive, as it requires many model retrainings (e.g., one needs ~10,000 models for the CIFAR-10 dataset to get a decent LDS).
Predictive data attribution in modern settings
We saw that it is possible to find predictive linear models even in the DNN setting. Influence functions worked great for classical problems like logistic regression—so what changed?
There are a number of difficulties, including some we alluded to earlier:
- Non-convexity: The Hessian of a DNN is in general no longer positive definite, so its inverse may not even be well defined. One can try accounting for this by adding regularization or other tricks, but it’s not clear if these post-hoc fixes are the right approach.
- Randomness: DNN training is stochastic; retraining a model with a different random seed will lead to a completely different set of parameters. This means that we have to be more careful about reasoning in parameter space directly.
- Large scale/high dimensionality: Due to the high dimensionality of modern overparameterized models, exact second-order computations are infeasible (i.e., it’s prohibitively expensive to compute the exact Hessian, even if our objective were convex).
Various methods have been proposed to deal with some of the problems above. Next, we’ll give a broad overview of the main approaches.
A bird’s-eye view of modern predictive data attribution
Broadly, there are a few lines of approach that people have come up with. Namely:
- Better IF/Hessian approximations
- Approximating training dynamics (“unrolling”)
- Surrogate models
While the direct estimator earlier treated the DNN as a black box, each of these approaches tries to approximate the black box in some way. We’ll now give intuition for each of these approaches, focusing on one or two successful methods for each.
I. Approach: Better IF approximations. Some methods directly tackle the original IF approximation. There are many different variations on this theme, but the key trends that have emerged are:

Using the “Gauss-Newton Hessian” (GNH) approximation to the true Hessian,
\[G := \mathbb{E}[J^T H_{\hat{y}} J],\]where $J = d\hat{y}/d\theta$ is the parameter-output Jacobian and $H_{\hat{y}}$ is the Hessian of the training loss w.r.t. the network’s outputs. This can be interpreted as computing the Hessian of the linearized network. This change addresses non-convexity: as long as the loss is convex in the network’s outputs (e.g., squared error or cross-entropy), $H_{\hat{y}}$ is positive semi-definite, and hence so is $G$.

Further structural approximations that make the above more tractable: for example, K-FAC-based estimation [Martens Grosse ’15] assumes that gradients in different layers are independent, and further that, within each linear layer, the layer’s inputs (activations) are independent of the gradients with respect to its outputs. This introduces a “block-diagonal” structure, with each block factored as a Kronecker product of two small matrices, which allows one to invert the Hessian block by block, significantly speeding up the inverse-Hessian-vector products needed to compute the IF estimates. Furthermore, some type of spectral regularization is also employed (approximating each block by its top eigenspace).
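To see why the GNH sidesteps the positive-definiteness problem, here is a tiny numerical sketch with a hypothetical two-parameter model $h(x;\theta) = \theta_1 \tanh(\theta_0 x)$ and squared loss (so $H_{\hat{y}} = 1$). The exact Hessian of the training loss is the GNH plus a residual-weighted curvature term of the network itself; that extra term is what can introduce negative eigenvalues away from a minimum, while $J^\top J$ is positive semi-definite by construction. We sum over examples instead of taking the expectation in $G$, which only rescales it.

```python
import numpy as np

def tiny_model(theta, x):
    # Hypothetical two-parameter network: h(x; theta) = theta[1] * tanh(theta[0] * x).
    return theta[1] * np.tanh(theta[0] * x)

def jacobian(theta, x):
    # J: (n, 2) matrix whose row i is dh(x_i)/dtheta.
    t = np.tanh(theta[0] * x)
    return np.stack([theta[1] * x * (1 - t**2), t], axis=1)

def gauss_newton(theta, x):
    # Squared loss 0.5 * (h - y)^2 has H_yhat = 1, so G = sum_i J_i^T J_i (PSD by construction).
    J = jacobian(theta, x)
    return J.T @ J

def exact_hessian(theta, x, y):
    # Full Hessian = G + sum_i r_i * (Hessian of h(x_i; theta)), with residuals r_i = h_i - y_i.
    t = np.tanh(theta[0] * x)
    r = tiny_model(theta, x) - y
    d2_00 = -2 * theta[1] * x**2 * t * (1 - t**2)   # d^2 h / d theta_0^2
    d2_01 = x * (1 - t**2)                           # d^2 h / d theta_0 d theta_1
    curvature = np.array([[np.sum(r * d2_00), np.sum(r * d2_01)],
                          [np.sum(r * d2_01), 0.0]])  # d^2 h / d theta_1^2 = 0
    return gauss_newton(theta, x) + curvature
```

Comparing `np.linalg.eigvalsh` of the two matrices at points far from a fit is a quick way to see the difference in practice.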
There is a rich line of work on Hessian-free optimization that provides some intuition for why the above approximations are okay—and perhaps even preferable to the exact Hessian [Martens ’14].
In summary, these methods tackle the challenges by finding better and more efficient structural approximations to the Hessian, the key object in the IF estimate.
For an implementation of this method, check out the Kronfluence repository.
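The Kronecker trick behind K-FAC can be sketched for a single linear layer. For an example with layer input $a$ and gradient $g$ with respect to the layer’s outputs, the weight gradient is $g a^\top$, whose column-major vectorization is $a \otimes g$. Assuming $a$ and $g$ are independent gives $\mathbb{E}[(a\otimes g)(a\otimes g)^\top] \approx A \otimes G$ with $A = \mathbb{E}[aa^\top]$ and $G = \mathbb{E}[gg^\top]$, and then $(A\otimes G)^{-1}\mathrm{vec}(V) = \mathrm{vec}(G^{-1} V A^{-1})$. This is a simplified sketch (one layer, damping added to each factor separately, no eigenvalue correction as in EK-FAC), not the Kronfluence API; `kfac_ihvp` is a hypothetical name.

```python
import numpy as np

def kfac_ihvp(acts, out_grads, V, damping=1e-3):
    # acts: (n, d_in) layer inputs a_i; out_grads: (n, d_out) gradients g_i w.r.t. layer outputs.
    # V: (d_out, d_in) matrix to multiply by the inverse of the Kronecker-factored block.
    n = acts.shape[0]
    A = acts.T @ acts / n + damping * np.eye(acts.shape[1])                  # ~ E[a a^T]
    G = out_grads.T @ out_grads / n + damping * np.eye(out_grads.shape[1])  # ~ E[g g^T]
    # (A kron G)^{-1} vec(V) = vec(G^{-1} V A^{-1}): two small solves instead of one big one.
    return np.linalg.solve(G, V) @ np.linalg.inv(A)
```

The payoff is in the sizes: the explicit block is $(d_{\text{in}} d_{\text{out}}) \times (d_{\text{in}} d_{\text{out}})$, while the factors are only $d_{\text{in}} \times d_{\text{in}}$ and $d_{\text{out}} \times d_{\text{out}}$.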
II. Approach: Approximating training dynamics. The methods thus far only look at the final trained model, but do not account for the trajectory of how we got there. Given the unsuitability of the IF approximation and our understanding of the importance of the implicit biases of training, it makes sense to look at the entire trajectory. At a high level, “unrolling”-based methods try to “trace” the training dynamics across multiple checkpoints to predict the counterfactual.
To begin, recall that the IF approximation is actually computing the derivative \(\frac{d\theta^{(T)}}{dw_j}\) (see Chapter 2 for details), i.e., the infinitesimal effect of upweighting example $j$ (increasing its weight $w_j$ in the training objective) on the final parameters $\theta^{(T)}$. But this required a lot of assumptions on $\theta^{(T)}$ and the loss function. So instead, we’re going to try to compute the above quantity by exactly tracing the entire trajectory of parameters $\lbrace \theta^{(t)}\rbrace_{t=0,\ldots,T}$.
Let’s look at an example of analyzing full-batch gradient descent, where we perform the updates:
\[\theta^{(t+1)} = \theta^{(t)} - \eta_t \cdot \sum_{i=1}^n w_i^{(t)} \nabla \ell_i(\theta^{(t)}),\]where the per-example weights $w_i^{(t)}$ are all equal to $1$ during ordinary training; we include them only so that we can differentiate with respect to them.
First, we start with something simpler: the effect on the final model parameters $\theta^{(T)}$ of upweighting example $j$ only at iteration $t$. By the chain rule,
\[\frac{d\theta^{(T)}}{dw_j^{(t)}} = \frac{d\theta^{(T)}}{d\theta^{(T-1)}}\cdots \frac{d\theta^{(t+2)}}{d\theta^{(t+1)}} \cdot \frac{d\theta^{(t+1)}}{dw_j^{(t)}}.\]
Now, given the gradient update rule, we can compute each factor in the above expression in closed form. For the rightmost term:
\[\frac{d\theta^{(t+1)}}{dw_j^{(t)}} = -\eta_t \nabla \ell_j(\theta^{(t)}).\]For each intermediate term, we differentiate the update rule with respect to $\theta^{(t)}$:
\[\frac{d\theta^{(t+1)}}{d\theta^{(t)}} = I - \eta_t H_t,\]where $H_t = \sum_{i=1}^n \nabla^2 \ell_i(\theta^{(t)})$ is the Hessian of the training loss at time step $t$.
Substituting these back into the original expression, we have:
\[\frac{d\theta^{(T)}}{dw_j^{(t)}} = -(I - \eta_{T-1} H_{T-1}) \cdots (I - \eta_{t+1} H_{t+1}) \cdot \eta_t \nabla \ell_j(\theta^{(t)}).\]

Next, by summing the above expression over all iterations $t$, we get the total effect of upweighting example $j$ throughout training:
\[\frac{d\theta^{(T)}}{dw_j} = \sum_{t=0}^{T-1} \frac{d\theta^{(T)}}{dw_j^{(t)}},\]which is just the total derivative rule (the single weight $w_j$ enters the update at every iteration).
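This derivation can be checked numerically. The sketch below (our own illustration, assuming full-batch gradient descent on an unregularized logistic-regression loss with labels in $\{0,1\}$; function names are hypothetical) accumulates the unrolled sum backwards in time and can be compared against a finite-difference estimate obtained by actually rerunning training with $w_j$ perturbed. Note that it keeps every iterate and a $d\times d$ running product, which is exactly what becomes infeasible at scale.

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def run_gd(X, y, w, theta0, lr, T):
    # Full-batch GD on sum_i w_i * logistic_loss_i (labels y in {0, 1}).
    # Returns the whole trajectory [theta^(0), ..., theta^(T)].
    thetas = [theta0]
    for _ in range(T):
        theta = thetas[-1]
        grad = X.T @ (w * (sigmoid(X @ theta) - y))
        thetas.append(theta - lr * grad)
    return thetas

def unrolled_influence(X, y, theta0, lr, T, j):
    # d theta^(T) / d w_j = sum_t -(I - lr H_{T-1}) ... (I - lr H_{t+1}) lr grad loss_j(theta^(t)),
    # evaluated at w_i = 1 for all i and accumulated backwards in time.
    thetas = run_gd(X, y, np.ones(len(y)), theta0, lr, T)
    d = X.shape[1]
    total = np.zeros(d)
    M = np.eye(d)  # running product (I - lr H_{T-1}) ... (I - lr H_{t+1})
    for t in reversed(range(T)):
        p = sigmoid(X @ thetas[t])
        g_j = (p[j] - y[j]) * X[j]                  # grad of loss_j at theta^(t)
        total -= M @ (lr * g_j)                     # contribution of upweighting at step t
        H_t = X.T @ ((p * (1 - p))[:, None] * X)    # Hessian of the training loss at step t
        M = M @ (np.eye(d) - lr * H_t)
    return total
```

A central finite difference of `run_gd(...)[-1]` with respect to `w[j]` should match `unrolled_influence` up to numerical error.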
In the end, the unrolled expression for $d\theta^{(T)}/dw_j$ is exact and in closed form! However, computing it naively would require storing a checkpoint from every iteration and many Hessian matrices (or, at the least, computing many Hessian-vector products), so it is not really tractable. Different methods therefore approximate this quantity in different ways:
A first-order approximation: One approximation is a popular method called TracIn [PLSK20], which ignores all the second-order effects (i.e., it replaces each $\frac{d\theta^{(t+1)}}{d\theta^{(t)}}$ term with the identity). This simplification removes all the Hessians and requires you to keep track of only the gradients. Then, the overall influence on the final model parameters can be approximated as the sum \(-\sum_{t=0}^{T-1} \eta_t \nabla\ell_j(\theta^{(t)})\). However, this introduces significant bias into the estimates; indeed, if we consider just a linear model, we can find that TracIn already gives a very skewed estimate (an exercise for the reader :).
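Here is a minimal sketch of this first-order approximation on a toy setup (full-batch gradient descent on a logistic-regression loss with labels in $\{0,1\}$; function names are hypothetical). It uses every iterate as a checkpoint, whereas TracIn in practice sums over a handful of saved checkpoints. `tracin_score` computes the TracIn-style score $\sum_t \eta_t \nabla\ell(z_{\text{test}};\theta^{(t)})^\top \nabla\ell_j(\theta^{(t)})$. With a single step ($T=1$) there are no Hessian factors to drop, so the first-order approximation is exact there; the bias only appears over longer trajectories.

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def run_gd(X, y, w, theta0, lr, T):
    # Full-batch GD on sum_i w_i * logistic_loss_i (labels in {0, 1}); returns all iterates.
    thetas = [theta0]
    for _ in range(T):
        theta = thetas[-1]
        thetas.append(theta - lr * X.T @ (w * (sigmoid(X @ theta) - y)))
    return thetas

def grad_loss(x, y, theta):
    # Gradient of the logistic loss of a single example (x, y) at theta.
    return (sigmoid(x @ theta) - y) * x

def tracin_param_effect(X, y, thetas, lr, j):
    # First-order approximation of d theta^(T) / d w_j: every Jacobian
    # d theta^(t+1) / d theta^(t) is replaced by the identity, so no Hessians are needed.
    return -sum(lr * grad_loss(X[j], y[j], th) for th in thetas[:-1])

def tracin_score(X, y, x_test, y_test, thetas, lr, j):
    # TracIn-style score: sum_t lr * <grad of test loss, grad of loss_j>, both at theta^(t).
    return sum(lr * grad_loss(x_test, y_test, th) @ grad_loss(X[j], y[j], th)
               for th in thetas[:-1])
```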
A better approximation: A more recent method, SOURCE [BLLG24], makes a much more faithful approximation. At a high level, the method breaks the time steps into a few segments and assumes that the Hessian is approximately constant within each segment. This allows you to get by with storing only a few intermediate checkpoints. In the end, this approximation yields very good estimates (e.g., it outperforms other approaches on recent LDS benchmarks).
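To see why a piecewise-constant Hessian helps, suppose (as a simplification in the spirit of SOURCE, not its exact estimator) that within a segment of $L$ steps the learning rate $\eta$, the Hessian $H$, and the example’s gradient $g$ are all constant. The segment’s contribution to the unrolled sum is then, up to sign, a matrix geometric series with a closed form, $\sum_{s=0}^{L-1}(I-\eta H)^s \eta g = H^{-1}\left(I - (I-\eta H)^L\right) g$, so no per-step Hessians are needed. As $L \to \infty$ (for positive-definite $H$ with $0 < \eta\lambda < 2$ for every eigenvalue $\lambda$), this tends to $H^{-1}g$, the influence-function quantity. Function names are hypothetical.

```python
import numpy as np

def segment_effect_closed_form(H, g, lr, L):
    # sum_{s=0}^{L-1} (I - lr H)^s (lr g) = H^{-1} (I - (I - lr H)^L) g, for invertible H.
    d = H.shape[0]
    B = np.eye(d) - lr * H
    return np.linalg.solve(H, g - np.linalg.matrix_power(B, L) @ g)

def segment_effect_loop(H, g, lr, L):
    # Direct summation of the same geometric series, one step at a time (for checking).
    d = H.shape[0]
    B = np.eye(d) - lr * H
    total, P = np.zeros(d), np.eye(d)
    for _ in range(L):
        total += P @ (lr * g)
        P = B @ P
    return total
```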
Other works, such as Simfluence [GWP+23], take a more modeling-based approach and model the entire trajectory with a Markov process.
To summarize, these methods address non-convexity by more directly modeling the sequential training process of DNNs. They also implicitly deal with randomness by averaging over multiple checkpoints. However, keeping track of all the information about optimization (gradients, Hessians) quickly gets unwieldy, so different methods leverage various approximations.
III. Approach: Surrogate models. For simple models (such as the linear models we saw in Part II, or other simple classes like kNN), we actually do know how to do predictive data attribution. To that end, the idea behind this final approach is to find a proxy or surrogate model that approximates the neural network, and compute data attributions on the proxy model.
For example, a line of work uses kNN classifiers as a proxy for the original model [JDW+19]. Applying influence functions to last-layer embeddings (as done in a few works, e.g., [KL17]) can also be viewed as using a linear proxy model on those representations.
Here, we briefly describe one particularly successful approach called TRAK. At a high level, TRAK reduces the DNN to a generalized linear model (GLM)—e.g., logistic regression—and applies the influence function approximation to the GLM. We can break this down into the following steps:
- Linearize the original DNN: \(h(x;\theta) \approx h(x;\theta^\star) + \nabla_\theta h(x;\theta^\star)^\top (\theta-\theta^\star)\). This allows us to reason about a linear model whose features are given by $\phi(x) := \nabla_\theta h(x;\theta^\star)$. The motivation for such an approximation is loosely based on recent work on the empirical neural tangent kernel (eNTK) approximation [L21; WHS22; MWY+23].
- We can further reduce this to a more tractable low-dimensional model by randomly projecting the features: $\tilde{\phi}(x) := P^\top \phi(x)$ for a random matrix $P \in \mathbb{R}^{p \times k}$ with $k \ll p$, where $p$ is the number of parameters. Intuitively, as long as the projection approximately preserves the geometry (in the spirit of the Johnson-Lindenstrauss lemma), the proxy model won’t change too much from the true linearized model.
- We can then apply the standard IF approximation (from Part II) to this corresponding linear model to perform predictive data attribution.
- Lastly, we can ensemble the earlier steps over multiple checkpoints and projections.
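A stripped-down sketch of steps 2 to 4, assuming the per-example gradient features $\phi(x)$ from step 1 are already given as arrays (in practice they come from automatic differentiation at a checkpoint). It uses a Gaussian random projection and the plain linear-regression influence $\tilde\phi(z)^\top(\tilde\Phi^\top\tilde\Phi)^{-1}\tilde\Phi^\top$, and deliberately omits details of the actual TRAK estimator (such as its choice of model output function and its reweighting term). `trak_style_scores` and `ensemble` are hypothetical names, not the API of the TRAK library.

```python
import numpy as np

def trak_style_scores(train_feats, test_feats, k, rng, damping=0.0):
    # train_feats: (n, p) features phi(x_i) = grad_theta h(x_i; theta*), assumed precomputed.
    # test_feats:  (n_test, p) the same features for the target examples.
    p = train_feats.shape[1]
    P = rng.normal(size=(p, k)) / np.sqrt(k)          # step 2: random projection (k << p in practice)
    phi = train_feats @ P                             # projected train features, (n, k)
    phi_test = test_feats @ P                         # projected test features, (n_test, k)
    kernel = phi.T @ phi + damping * np.eye(k)        # Gauss-Newton-style matrix of the linear proxy
    return phi_test @ np.linalg.solve(kernel, phi.T)  # step 3: (n_test, n) attribution scores

def ensemble(score_list):
    # Step 4: average scores obtained from different checkpoints and/or projections.
    return np.mean(score_list, axis=0)
```

When $k = p$ and the projection is invertible, the projection has no effect on the scores, which makes a convenient sanity check; the interesting regime is $k \ll p$, where fidelity is traded for large savings in memory and compute.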
Despite the number of approximations involved, the resulting estimates are surprisingly good (in terms of LDS), and allow us to approximate direct estimators while being orders of magnitude more efficient. (It turns out that one can also view TRAK as an efficient approximation of the Gauss-Newton Hessian; in particular, TRAK is an “agnostic” approximation that leverages randomness, rather than the structural approximations made by EK-FAC and related methods. As shown in the plots below, these have different trade-offs.)
For a fast and flexible implementation of TRAK, check out the code on GitHub. For this tutorial, we also provide a simple & hackable implementation in this Jupyter notebook.
Evaluating the landscape
To get a sense of how these different methods perform on different target tasks, we evaluate a representative set of them across three settings: ResNet-9 models trained on CIFAR-10, BERT models fine-tuned on QNLI, and GPT-2 models fine-tuned on WikiText. On the y-axis, we evaluate the effectiveness of each attribution method using the LDS, and on the x-axis we quantify its efficiency based on GPU runtime.
(Note: SOURCE is missing from the figure above; see their paper for comparisons.)
We observe the following trends:
- Commonly used baselines (the original adaptation of influence functions to DL, representation similarity, and TracIn) are not very predictive, regardless of compute.
- Direct estimators (e.g., regression-based datamodels) perform best with enough compute.
- Recent efficient methods (e.g., TRAK, EK-FAC) manage to approach the performance of direct estimators in some cases.
- The best method (in terms of LDS-vs-time trade-offs) depends on the target task: e.g., TRAK is better on vision, EK-FAC on language modeling (we suspect this is due to the difference in assumptions used by TRAK vs. EK-FAC, and to differences between CNNs and Transformers).
Overall, we do have fast, predictive, and reliable methods that seem to adapt well to different modalities!
Takeaways & (lots of) future work
We covered a lot of ground in this chapter, so let’s briefly recap the main takeaways:
- Evaluate counterfactuals: If your goal is predictive attribution, something like the LDS is a good metric to go by. Practically, the LDS can be a good way of choosing the best method for your given application (e.g., since different methods achieve higher LDS on different domains).
- Use good attribution methods: Use recent methods that are predictive, not ones that we know do not work.
- Choose a method appropriate to the modality: Beyond the original settings in which these methods were developed, newer techniques now adapt them pretty reliably across different modalities (e.g., language modeling [PGI+23; GBA+23], diffusion, etc.). A nice thing is that the latest methods are fairly general and adapt readily to different settings out of the box.
- Sanity-check your attribution method in simpler settings: Often, analyzing whether a new attribution method works in a simpler setting, like a linear model, can help rule out some approaches easily (e.g., similarity-based ones).
While the field has matured a lot over the past several years, there are still many interesting open problems. We end with some ideas for future research on data attribution:
- More predictive methods: While we have reasonably predictive methods, they still fall well short of what “oracle” estimates such as the direct approach yield. Can we improve on the above approaches or use new ideas to get even better estimates? For example, are there better proxy models for DNNs?
- Understanding beyond linearity: Most existing attribution methods rely on linearity (and are surprisingly effective), but can we go beyond linear models? Simple toy models suggest that modeling data interactions linearly cannot provide the full picture, so we must look for new methods that incorporate non-linearity.
- “Single-model counterfactuals”: For most of this tutorial, we focused on attributing the average output of models (across retraining with different random seeds). But often in practice, we have a particular model we care about (e.g., GPT4TurboApril03). How do we adapt methods and evaluations to such settings?
- Multiple training stages: Modern models are trained in multiple stages (e.g., LLMs go through pre-training, then supervised fine-tuning and RLHF). Can we find methods that flexibly deal with such scenarios? Recent methods provide some ideas.
- More efficient evaluation: Even knowing whether an attribution method works at all requires some type of counterfactual retraining, which quickly gets expensive. Can we find more efficient proxies to gauge whether attributions are predictive?