Wednesday, August 19, 2020

Moment Multicalibration for Uncertainty Estimation

This blog post is about a new paper that I'm excited about, which is joint work with Chris Jung, Changhwa Lee, Mallesh Pai, and Ricky Vohra. If you prefer watching talks, you can watch one I gave to the Wharton statistics department here.

Suppose you are diagnosed with hypertension, and your doctor recommends that you take a certain drug to lower your blood pressure. The latest research, she tells you, finds that the drug lowers diastolic blood pressure by an average of 10 mm Hg. You remember your statistics class from college, and so you ask about confidence intervals. She looks up the paper, and tells you that it reports a 95% confidence interval of [5, 15]. How should you interpret this? 

What you might naively hope is that [5, 15] represents a conditional prediction interval. If you have some set of observable features $x$, and a label $y$ (in this case corresponding to your decrease in diastolic blood pressure after taking the drug), a 95% conditional prediction interval would promise that:
$$\Pr_y [y \in [5, 15] | x] \geq 0.95$$

In other words, a conditional prediction interval would promise that given all of your observed features, over the unrealized/unmeasured randomness of the world, there is a 95% chance that your diastolic blood pressure will decrease by between 5 and 15 points. 

But if you think about it, coming up with a conditional prediction interval is essentially impossible in a rich feature space. If $x$ contains lots of information about you, then probably there was nobody in the original study population that exactly matched your set of features $x$, and so we have no information at all about the conditional distribution on $y$ given $x$ --- i.e. no samples at all from the distribution over which our coverage probability supposedly holds! So how can you expect any sort of promise at all? There are two typical ways around this difficulty. 

The first is to make heroic assumptions about the data generation process. For example, if we assume that the world looks like an ordinary least squares model, and that there is a linear relationship between $y$ and $x$, then we can form a confidence region around the parameters of the model, and from that derive prediction intervals. But these prediction intervals are not valid if the model fails to hold, which it inevitably will. 
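Here is a minimal pure-Python sketch of that failure. It assumes a single feature and uses a normal approximation (z = 1.96) in place of the exact t quantile. It fits OLS, forms the classical prediction interval from the fitted parameters, and then checks that interval in a hypothetical world whose noise grows with $x$, which violates the model's assumptions. The names `fit_ols`, `draw_y`, and `coverage_at` are illustrative and do not come from the paper.

```python
import math
import random

Z_95 = 1.96  # normal approximation to the t quantile; reasonable for large n


def fit_ols(xs, ys):
    """Least-squares fit of y = a + b*x, plus the pieces the interval formula needs."""
    n = len(xs)
    x_bar, y_bar = sum(xs) / n, sum(ys) / n
    sxx = sum((x - x_bar) ** 2 for x in xs)
    b = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys)) / sxx
    a = y_bar - b * x_bar
    rss = sum((y - a - b * x) ** 2 for x, y in zip(xs, ys))
    s = math.sqrt(rss / (n - 2))  # residual standard error
    return {"a": a, "b": b, "s": s, "x_bar": x_bar, "sxx": sxx, "n": n}


def ols_prediction_interval(model, x0, z=Z_95):
    """Classical 95% prediction interval at x0; valid only if y = a + b*x + N(0, s^2)."""
    y_hat = model["a"] + model["b"] * x0
    half = z * model["s"] * math.sqrt(
        1 + 1 / model["n"] + (x0 - model["x_bar"]) ** 2 / model["sxx"]
    )
    return y_hat - half, y_hat + half


def draw_y(x, rng):
    """Hypothetical misspecified world: the noise level grows with x."""
    return 10 + 0.5 * x + rng.gauss(0, 0.2 + x)


def coverage_at(model, x0, rng, trials=20_000):
    """Fraction of fresh labels at feature value x0 that land in the interval."""
    lo, hi = ols_prediction_interval(model, x0)
    return sum(lo <= draw_y(x0, rng) <= hi for _ in range(trials)) / trials


rng = random.Random(0)
xs = [rng.uniform(0, 4) for _ in range(2_000)]
ys = [draw_y(x, rng) for x in xs]
model = fit_ols(xs, ys)

low_noise_coverage = coverage_at(model, 0.1, rng)   # well above 0.95
high_noise_coverage = coverage_at(model, 3.9, rng)  # well below 0.95
print(low_noise_coverage, high_noise_coverage)
```

The formula claims 95% coverage at every $x$. In this world it is far too wide where the noise is small and covers only roughly three quarters of fresh labels where the noise is large.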

The second is to give up on conditional prediction intervals, and instead give marginal prediction intervals. This is what the conformal prediction literature aims to do. A marginal prediction interval looks quite similar to a conditional prediction interval (at least syntactically), and promises:
$$\Pr_{(x,y)} [y \in [5, 15] ] \geq 0.95$$
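Split conformal prediction is one standard way the conformal literature produces marginal intervals of this form. The sketch below is minimal and assumes exchangeable data, a held-out calibration set, and a hypothetical constant point predictor `predict`. It computes the absolute residuals on the calibration set and uses a finite-sample-corrected quantile of them as the interval half-width. `draw_patient` is a made-up data generator used only for illustration.

```python
import math
import random


def split_conformal_halfwidth(residuals, alpha=0.05):
    """Half-width q so that [f(x) - q, f(x) + q] has >= 1 - alpha marginal coverage.

    residuals: |y_i - f(x_i)| on a calibration set exchangeable with the test point.
    Uses the k-th smallest residual with k = ceil((n + 1) * (1 - alpha)).
    """
    n = len(residuals)
    k = math.ceil((n + 1) * (1 - alpha))
    if k > n:
        return math.inf  # too few calibration points for this alpha
    return sorted(residuals)[k - 1]


def draw_patient(rng):
    """Hypothetical patient: feature x in [0, 1]; noisier drug response for larger x."""
    x = rng.random()
    y = 10 + rng.gauss(0, 2 + 3 * x)
    return x, y


def predict(x):
    """A deliberately crude point predictor (hypothetical): always 10 mm Hg."""
    return 10.0


rng = random.Random(0)
calibration = [draw_patient(rng) for _ in range(2_000)]
q = split_conformal_halfwidth([abs(y - predict(x)) for x, y in calibration])

test = [draw_patient(rng) for _ in range(20_000)]
marginal = sum(abs(y - predict(x)) <= q for x, y in test) / len(test)
print(q, marginal)  # marginal coverage close to 0.95
```

The guarantee holds on average over the whole population. It says nothing about coverage for any particular subset of patients.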

Rather than conditioning on your features $x$, a marginal prediction interval averages over all people, and promises that 95% of people who take the drug have their diastolic blood pressure lowered by between 5 and 15 points. But the semantics of this promise are quite different from those of a conditional prediction interval. Because the average is now taken over a large, heterogeneous population, very little is promised to you. For example, it might be that for patients in your demographic group (e.g. middle-aged women with Sephardic Jewish ancestry and a family history of diabetes), the drug is actually expected to raise blood pressure rather than lower it. Because this subgroup represents less than 5% of the population, this is entirely consistent with the marginal prediction interval being correct. Of course, if you are lucky, then perhaps someone has conducted a study of people from this demographic group and has computed marginal prediction intervals over it! But what if there are multiple different groups that you are a member of, over which the results seem to conflict? For example, you might also have a low BMI value and unusually good cholesterol readings: features of a group for which the drug works unusually well. Which uncertainty estimate should you trust, if you are a member of both groups?
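A toy population makes the arithmetic concrete. All of the numbers below are hypothetical. Ninety-seven patients respond as advertised, and a three-patient subgroup sees its blood pressure rise. The interval $[5, 15]$ satisfies the marginal promise while covering nobody in the subgroup.

```python
def coverage(labels, lo, hi):
    """Fraction of labels falling inside the closed interval [lo, hi]."""
    return sum(lo <= y <= hi for y in labels) / len(labels)


# Hypothetical decreases in diastolic blood pressure (mm Hg) for 100 patients.
majority = [5 + 10 * i / 96 for i in range(97)]  # spread evenly over [5, 15]
subgroup = [-4.0, -3.0, -5.0]                    # negative decrease = increase
everyone = majority + subgroup

print(coverage(everyone, 5, 15))  # 0.97: the marginal 95% promise holds
print(coverage(subgroup, 5, 15))  # 0.0: yet no one in the subgroup is covered
```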

These concerns actually arise already when we think about the semantics of mean estimates ("the expected drop in blood pressure amongst patients who take this drug is 10 mm Hg"). Ideally, if you were a patient with features $x$, then 10 would be an estimate of $\mathbb{E}[y | x]$. But just as with uncertainty estimation, in a large feature space we typically have no information about the distribution on $y$ conditional on $x$ (because we have never met anyone exactly like you before), and so instead what we have is just an estimate of $\mathbb{E}[y]$, i.e. an average over people. If you have a method of making predictions $f(x)$ as a function of features $x$, then a standard performance metric is calibration. Informally, calibration asks that for every prediction $p$, amongst all people for whom we predicted $f(x) = p$, the average of the realized labels $y$ should be $p$ (formally, $\mathbb{E}[y | f(x) = p] = p$). Again, estimates of this form promise little to individuals, because they are averages over a large and heterogeneous population.
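The calibration condition can be checked directly on data by grouping people by the prediction they received. In the sketch below, the helper names and the four labels are hypothetical. It shows that a predictor which outputs the population mean for everyone is perfectly calibrated, even though it says nothing about any individual.

```python
from collections import defaultdict


def calibration_table(preds, labels):
    """Average realized label among people who received each distinct prediction."""
    sums, counts = defaultdict(float), defaultdict(int)
    for p, y in zip(preds, labels):
        sums[p] += y
        counts[p] += 1
    return {p: sums[p] / counts[p] for p in sums}


def max_calibration_error(preds, labels):
    """Largest gap |p - average label among those predicted p| over predictions made."""
    return max(abs(p - avg) for p, avg in calibration_table(preds, labels).items())


# Hypothetical drops in diastolic blood pressure (mm Hg) for four patients.
labels = [6.0, 14.0, 2.0, 18.0]

always_ten = [10.0] * 4  # predict the population mean for everyone
print(max_calibration_error(always_ten, labels))  # 0.0: perfectly calibrated

split = [5.0, 15.0, 5.0, 15.0]  # more informative, but slightly off
print(max_calibration_error(split, labels))  # 1.0
```

Real-valued predictors rarely repeat a value exactly, so in practice one would first round predictions into buckets before building the table.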

Several years ago, Hébert-Johnson et al. proposed a nice way to interpolate between the (impossible) ideal of offering conditional mean predictions $f(x) = \mathbb{E}[y | x]$ and the weak guarantee of merely offering calibrated predictions $f$. Roughly speaking, they proposed to specify a very large collection $G$ of potentially intersecting groups (representing e.g. demographic groups like Sephardic Jewish women with a family history of diabetes, hypertensive patients with low cholesterol and BMI values, etc.) and to ask that a trained predictor be simultaneously calibrated on each sufficiently large group in $G$. They showed how to accomplish this using a polynomially sized sample from the underlying distribution, with polynomial running time overhead on top of the cost of solving learning problems over $G$.
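The sketch below is simplified and written in the spirit of the iterative "find a miscalibrated cell, patch it" approach to multicalibration. It is not a reproduction of Hébert-Johnson et al.'s algorithm. It assumes labels in $[0, 1]$, groups given as explicit membership lists (rather than accessed through a learning oracle over $G$), and no fresh samples per round, so on real data it could overfit. The parameter names `alpha` (tolerance), `gamma` (minimum cell mass), and `n_bins` are hypothetical. Each patch lowers the total squared error by at least $\gamma n \alpha^2$, which is why the loop terminates.

```python
import random


def bin_of(p, n_bins):
    """Index of the prediction bucket containing p in [0, 1]."""
    return min(int(p * n_bins), n_bins - 1)


def find_violation(f, labels, groups, n_bins, alpha, gamma):
    """Return (cell, mean residual) for some large (group, bucket) cell that is
    miscalibrated by more than alpha, or None if every large cell is calibrated."""
    n = len(f)
    for members in groups.values():
        cells = {}
        for i, is_member in enumerate(members):
            if is_member:
                cells.setdefault(bin_of(f[i], n_bins), []).append(i)
        for cell in cells.values():
            if len(cell) < gamma * n:
                continue  # cells below the mass threshold are not constrained
            shift = sum(labels[i] - f[i] for i in cell) / len(cell)
            if abs(shift) > alpha:
                return cell, shift
    return None


def multicalibrate(preds, labels, groups, n_bins=10, alpha=0.05, gamma=0.05, max_iters=10_000):
    """Patch predictions until they are alpha-calibrated on every large cell of every group."""
    f = list(preds)
    for _ in range(max_iters):
        violation = find_violation(f, labels, groups, n_bins, alpha, gamma)
        if violation is None:
            return f
        cell, shift = violation
        for i in cell:
            # Shift the cell so its average prediction matches its average label;
            # clipping to [0, 1] cannot move a prediction away from a label in [0, 1].
            f[i] = min(1.0, max(0.0, f[i] + shift))
    raise RuntimeError("no convergence within max_iters")


# Hypothetical data: two binary features; the outcome rate depends on both.
rng = random.Random(0)
people = [(rng.random() < 0.5, rng.random() < 0.3) for _ in range(2_000)]
labels = [float(rng.random() < 0.2 + 0.5 * a + 0.2 * b) for a, b in people]
groups = {
    "everyone": [True] * len(people),
    "feature_a": [a for a, _ in people],
    "feature_b": [b for _, b in people],
    "neither": [not a and not b for a, b in people],
}
start = [sum(labels) / len(labels)] * len(labels)  # calibrated overall, not per group
patched = multicalibrate(start, labels, groups)
print(find_violation(patched, labels, groups, 10, 0.05, 0.05))  # None
```

The constant starting predictor is calibrated on `everyone` but badly miscalibrated on `feature_a`. After patching, no sufficiently large cell of any listed group is off by more than `alpha`.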

In our paper we (roughly speaking) show how to accomplish the same thing for variances and other higher moments, in addition to just means. Our "multicalibrated moment estimates" can be used to construct prediction intervals in exactly the same way that real moments of the conditional label distribution could be used. If you used the real (unknown) moments of the conditional label distribution, you would get conditional prediction intervals. If you use our multicalibrated moments, you get marginal prediction intervals that are simultaneously valid as averages over each of the groups in $G$. So, for example, suppose our hypertensive patient's prediction interval was constructed from multicalibrated moment estimates computed from her features. She could then interpret it as an average over each of the demographic groups she is a member of (so long as they are contained in $G$), and all of those interpretations would be simultaneously valid.
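The post does not spell out here exactly how moment estimates become an interval. One standard route, assumed in this sketch, is Chebyshev's inequality. It needs only a mean and a variance and holds for any label distribution with those moments. The function name `chebyshev_interval` is hypothetical. With a mean of 10 and a variance of 1.25, the 95% interval works out to $[5, 15]$, the interval from the opening example.

```python
import math
import random


def chebyshev_interval(mean, variance, delta=0.05):
    """Interval holding y with probability >= 1 - delta for ANY label distribution
    with this mean and variance, by Chebyshev: Pr[|y - mean| >= k] <= variance / k**2."""
    half_width = math.sqrt(variance / delta)
    return mean - half_width, mean + half_width


lo, hi = chebyshev_interval(10.0, 1.25)  # approximately (5.0, 15.0)

# Sanity check on one hypothetical label distribution with that mean and variance.
rng = random.Random(0)
draws = [rng.gauss(10.0, math.sqrt(1.25)) for _ in range(10_000)]
print(lo, hi, sum(lo <= y <= hi for y in draws) / len(draws))  # coverage well above 0.95
```

Chebyshev's inequality is conservative for well-behaved distributions like this Gaussian. Bounds that use higher moments can produce tighter intervals.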

I'll leave the details to the paper, including what exactly we mean by "moment multicalibration". I'll just note that a major difficulty is that variances and higher moments, unlike expectations, do not combine linearly. So it is no longer sensible to ask that "amongst all people for whom we predicted variance $v$, the true variance should be $v$", because even the true conditional label variances do not satisfy this property. Pooling people whose labels have the same variance but different means inflates the pooled variance by the spread of those means. But it is sensible to ask that a pair of mean and moment predictions be calibrated in this way: "amongst all people for whom we predicted mean $\mu$ and variance $v$, the true mean should be $\mu$ and the true variance should be $v$." This is what we call "mean-conditioned moment calibration", and it is satisfied by the true distributional moments.
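The following tiny worked example uses hypothetical numbers to show the difference. There are two equally sized groups. The true conditional label distributions both have variance 1, with means 0 and 2, and each distribution is represented by its two equally likely label values. If everyone is bucketed by predicted variance alone, the pooled variance is $1 + 1 = 2$: the within-group variance plus the variance of the means. Bucketing by the (mean, variance) pair recovers both moments exactly.

```python
from collections import defaultdict
from statistics import mean, pvariance

# Hypothetical population as (true mean, true variance, label value) triples.
# Labels {-1, 1} have mean 0 and variance 1; labels {1, 3} have mean 2 and variance 1.
people = [
    (0.0, 1.0, -1.0), (0.0, 1.0, 1.0),
    (2.0, 1.0, 1.0), (2.0, 1.0, 3.0),
]


def moments_by(people, key):
    """Group labels by key(mu, var) and return {key: (empirical mean, empirical variance)}."""
    cells = defaultdict(list)
    for mu, var, y in people:
        cells[key(mu, var)].append(y)
    return {k: (mean(ys), pvariance(ys)) for k, ys in cells.items()}


variance_only = moments_by(people, lambda mu, var: var)
mean_conditioned = moments_by(people, lambda mu, var: (mu, var))
print(variance_only)     # {1.0: (1.0, 2.0)}: variance 1 predicted, variance 2 observed
print(mean_conditioned)  # {(0.0, 1.0): (0.0, 1.0), (2.0, 1.0): (2.0, 1.0)}
```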

