This is a post about a new paper, Online Multivalid Learning: Means, Moments, and Prediction Intervals, which is joint work with Varun Gupta, Christopher Jung, Georgy Noarov, and Mallesh Pai. It is cross-posted to the new TOC4Fairness blog. For those who prefer watching to reading, here is a recording of a talk I gave on this paper.
Suppose you go and train the latest, greatest machine learning architecture to predict something important. Say (to pick an example entirely out of thin air) you are in the midst of a pandemic, and want to predict the severity of patients' symptoms in 2 days' time, so as to triage scarce medical resources. Since you will be using these predictions to make decisions, you would like them to be accurate in various ways: for example, at the very least, you will want your predictions to be calibrated, and you may also want to be able to accurately quantify the uncertainty of your predictions (say with 95% prediction intervals). It is a fast-moving situation, and data is coming in dynamically --- and you need to make decisions as you go. What can you do?
The first thing you might do is ask on Twitter! What you will find is that the standard tool for quantifying uncertainty in settings like this is conformal prediction. The conformal prediction literature has a number of elegant techniques for endowing arbitrary point prediction methods with marginal prediction intervals: i.e. intervals $(\ell(x), u(x))$ such that, over the randomness of some data distribution over labelled examples $(x,y)$: $\Pr_{(x,y)}\left[y \in [\ell(x), u(x)]\right] \approx 0.95$. These would be 95% marginal prediction intervals --- but in general you could pick your favorite coverage probability $1-\delta$.
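To make the marginal guarantee concrete, here is a minimal sketch of one standard recipe from that literature, split conformal prediction (not the method of our paper). The `model` object, with a scikit-learn-style `.predict` method, is a hypothetical stand-in for whatever point predictor you trained:

```python
import numpy as np

def split_conformal_interval(model, X_cal, y_cal, X_new, delta=0.05):
    """Marginal prediction intervals around point predictions, via split
    conformal prediction. `model` is any fitted point predictor with a
    .predict method (a hypothetical stand-in for your trained architecture)."""
    # Nonconformity scores on a held-out calibration set.
    scores = np.sort(np.abs(y_cal - model.predict(X_cal)))
    n = len(scores)
    # Conformal rank: yields >= 1 - delta marginal coverage when the
    # calibration and test points are exchangeable.
    k = min(int(np.ceil((n + 1) * (1 - delta))), n)
    q = scores[k - 1]
    preds = model.predict(X_new)
    return preds - q, preds + q
```

The guarantee this buys you is exactly the marginal one above: coverage on average over draws of $(x,y)$, and only under exchangeability.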
Conformal prediction has a lot going for it --- its tools are very general and flexible, and lead to practical algorithms. But it also has two well-known shortcomings:
- Strong Assumptions. Like many tools from statistics and machine learning, conformal prediction methods require that the future look like the past. In particular, they require that the data be drawn i.i.d. from some distribution --- or at least be exchangeable (i.e. their distribution should be invariant under permutation). This is sometimes the case --- but often it is not. In our pandemic scenario, the distribution on patient features might quickly change in unexpected ways as the disease moves between different populations, as might the relationship between features and outcomes as treatments advance. In other settings in which consequential decisions are being made about people --- like lending and hiring decisions --- people might intentionally manipulate their features in response to the predictive algorithms you deploy, in an attempt to get the outcome they want. Or you might be trying to predict outcomes in time series data, in which there are explicit dependencies across time. In all of these scenarios, exchangeability is violated.
- Weak Guarantees. Marginal coverage guarantees are averages over people. 95% marginal coverage means that the true label falls within the predicted interval for 95% of people. It need not mean anything for people like you. For example, if you are part of a demographic group that makes up less than 5% of the population, it is entirely consistent with the guarantees of a 95% marginal prediction interval that labels for people from your demographic group fall outside of their intervals 100% of the time. This can be both an accuracy and a fairness concern --- marginal prediction works well for "typical" members of a population, but not necessarily for everyone else.
What kinds of improvements might we hope for? Let's start with how to strengthen the guarantee:
Multivalidity: Ideally, we would want conditional guarantees --- i.e. the promise that for every $x$, we would have $\Pr_{y}\left[y \in [\ell(x), u(x)] \mid x \right] \approx 0.95$. In other words, that somehow for each individual, the prediction interval was valid for them specifically, over the "unrealized" (or unmeasured) randomness of the world. Of course this is too much to hope for. In a rich feature space, we have likely never seen anyone exactly like you before (i.e. with your feature vector $x$). So strictly speaking, we have no information at all about your conditional label distribution. We still have to average over people. But we don't have to average over everybody. An important idea that has been investigated in several different contexts in recent years in the theory literature on fairness is that we might articulate a very rich collection of (generally intersecting) demographic groups $G$, corresponding to relevant subsets of the data domain, and ask for the things we care about to hold true as averaged over any group $S \in G$ in the collection. In the case of prediction intervals, this would correspond to asking that simultaneously, for every demographic group $S \in G$, $\Pr_{(x,y)}\left[y \in [\ell(x), u(x)] \mid x \in S \right] \approx 0.95$. Note here that an individual might be a member of many different demographic groups, and can interpret the guarantees of their prediction interval as averages over any of those demographic groups, at their option. This is what we can achieve --- at least for any such group that isn't too small.
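To make the multivalid coverage condition concrete, here is a small sketch that audits empirical coverage group by group on a batch of data. The group membership functions are hypothetical, and this is just a check of the condition, not our algorithm:

```python
import numpy as np

def group_coverage(groups, X, y, lower, upper):
    """Empirical coverage of prediction intervals within each group.

    `groups` maps a group name to a membership function x -> bool, and the
    groups may overlap; `lower` and `upper` are the interval endpoints
    produced for each example. (Group definitions here are hypothetical,
    e.g. groups = {"over_65": lambda x: x[0] > 65, ...}.)"""
    covered = (lower <= y) & (y <= upper)
    report = {}
    for name, member in groups.items():
        mask = np.array([member(x) for x in X])
        if mask.any():
            report[name] = covered[mask].mean()
    return report
```

Multivalidity asks that every number this returns be close to 0.95, simultaneously.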
And what kinds of assumptions do we need?
Adversarial Data: Actually, it's not clear that we need any! Many learning problems which initially appear to require distributional assumptions turn out to be solvable even in the worst case over data sequences --- i.e. even if a clever adversary, with full knowledge of your algorithm, and with the intent only to sabotage your learning guarantees, is allowed to adaptively choose data to present to your algorithm. This is the case for calibrated weather prediction, as well as general contextual prediction. It turns out to be the case for us as well. Instead of promising coverage probabilities of $1-\delta + O(1/T)$ on the underlying distribution after $T$ rounds, as conformal prediction is able to (for us there is no underlying distribution), we offer empirical coverage rates of $1-\delta \pm O(1/\sqrt{T})$. This kind of guarantee is quite similar to what conformal prediction guarantees about empirical coverage.
More Generally: Our techniques are not specific to prediction intervals. We can do the same thing for predicting label means, and for predicting variances of the residuals of arbitrary prediction methods. For mean prediction, this corresponds to an algorithm for providing multicalibrated predictions in the sense of Hebert-Johnson et al., in an online adversarial environment. For variances and other higher moments, it corresponds to an online algorithm for making mean-conditioned moment multicalibrated predictions in the sense of Jung et al.
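For concreteness, here is a rough empirical check of the mean multicalibration condition: within every group and every bucket of predicted values, average outcomes should match average predictions. This is an audit of the condition, assuming labels and predictions live in $[0,1]$ and with hypothetical group functions as before; it is not our online algorithm:

```python
import numpy as np

def multicalibration_error(groups, X, y, preds, n_buckets=10):
    """Largest gap between average outcomes and average predictions,
    conditioned jointly on group membership and a prediction bucket.
    Assumes y and preds lie in [0, 1]; an audit of the multicalibration
    condition, not the paper's online algorithm."""
    y, preds = np.asarray(y), np.asarray(preds)
    buckets = np.minimum((preds * n_buckets).astype(int), n_buckets - 1)
    worst = 0.0
    for name, member in groups.items():
        in_group = np.array([member(x) for x in X])
        for b in range(n_buckets):
            mask = in_group & (buckets == b)
            if mask.any():
                worst = max(worst, abs(y[mask].mean() - preds[mask].mean()))
    return worst
```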
Techniques: At the risk of boring my one stubbornly remaining reader, let me say a few words about how we do it. We generalize an idea that dates back to an argument that Fudenberg and Levine first made in 1995 --- and that is closely related to an earlier, beautiful argument by Sergiu Hart --- but that I just learned about this summer, and thought was just amazing. It applies broadly to solving any prediction task that would be easy if only you were facing a known data distribution. This is the case for us. If, for each arriving patient at our hospital, a wizard told us their "true" distribution over outcome severity, we could easily make calibrated predictions by always predicting the mean of this distribution --- and we could similarly read off correct 95% coverage intervals from the CDF of the distribution. So what? That's not the situation we are in, of course. Absent a wizard, we first need to commit to some learning algorithm, and only then will the adversary decide what data to show us.
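To see what the wizard buys you, here is the trivial "algorithm" you would run with a known conditional label distribution in hand (purely illustrative; the distribution is represented by samples for simplicity):

```python
import numpy as np

def wizard_prediction(label_samples, delta=0.05):
    """With the true conditional label distribution in hand (represented
    here by samples from it), predicting its mean is trivially calibrated,
    and the central 1 - delta quantile range of its CDF is a valid
    prediction interval. Purely illustrative of the wizard scenario."""
    samples = np.asarray(label_samples, dtype=float)
    mean = samples.mean()
    lower, upper = np.quantile(samples, [delta / 2, 1 - delta / 2])
    return mean, (lower, upper)
```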
But let's put our game theory hats on. Suppose we've been making predictions for a while. We can write down some measure of our error so far --- say the maximum, over all demographic groups in $G$, of the deviation of our empirical coverage so far from our 95% coverage target (sketched in code below). For the next round, define a zero-sum game, in which we (the learner) want to minimize the increase in this measure of error, and the adversary wants to maximize it. The defining feature of zero-sum games is that how well you can do in them is independent of which player has to announce their distribution on play first --- this is the celebrated Minimax Theorem. So to evaluate how well the learner could do in this game, we can think about the wizard scenario above, in which for each arriving person, before we have to make a prediction for them, we get to observe their true label distribution. Of course in this scenario we can do well, because for all of our goals, our measure of success is based on how well our predictions match observed properties of these distributions. The Minimax Theorem tells us that (at least in principle --- it doesn't give us the algorithm) there must therefore also be a learning algorithm that can do just as well, but against an adversary.
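Here is one way to write down that measure of error, with hypothetical group functions as before; a sketch of the quantity being tracked, not the exact potential function from the paper:

```python
def coverage_error(groups, transcript, target=0.95):
    """Maximum, over groups, of how far empirical coverage so far is from
    the target. `transcript` holds one (x, lower, upper, y) tuple per round
    played so far; group membership functions are hypothetical."""
    worst = 0.0
    for name, member in groups.items():
        hits = [l <= y <= u for (x, l, u, y) in transcript if member(x)]
        if hits:
            worst = max(worst, abs(sum(hits) / len(hits) - target))
    return worst
```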
The minimax argument is slick, but non-constructive. To actually pin down a concrete algorithm, we need to solve for the equilibrium in the corresponding game. That's what we spend much of the paper doing, for each of the prediction tasks that we study. For multicalibration, we get a simple, elementary algorithm --- but for the prediction interval problem, although we get a polynomial-time algorithm, it involves solving a linear program with a separation oracle at each round. Finding more efficient and practical ways to do this strikes me as an important problem.
Finally, I had more fun writing this paper --- learning about old techniques from the game theoretic calibration literature --- than I've had in a while. I hope a few people enjoy reading it!