A simple method to estimate uncertainty in Machine Learning
When generating predictions about an output, it is sometimes useful to get a confidence score or, similarly, a range of values around the expected value in which the actual value might be found. Practical examples include estimating an upper and lower bound when predicting an ETA or a stock price: you care not only about the average outcome but also about the best-case and worst-case scenarios when trying to minimize risk, e.g. avoiding being late or losing money.
While most Machine Learning techniques do not provide a natural way of doing this, in this article, we will be exploring Quantile Regression as a means of doing so. This technique will allow us to learn some critical statistical properties of our data: the quantiles.
To begin our journey into quantile regression, we will first get a hold on some data:
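As a minimal sketch, assuming synthetic data whose noise is exponentially distributed and grows with x, a dataset of this kind could be generated as shown below. The function name create_data, the sample size, and every constant are illustrative assumptions, not values taken from the article.

```python
import numpy as np

def create_data(n_samples: int = 1000, seed: int = 0):
    """Hypothetical generator: log trend plus exponential noise that widens with x."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(0.3, 10.0, size=n_samples)
    # Exponential noise is one-sided (asymmetric) and its scale grows with x.
    noise = rng.exponential(scale=0.1 + x / 20.0)
    y = np.log(x) + noise
    # Add a trailing feature axis so both arrays have shape (n_samples, 1).
    return x[:, None], y[:, None]

x, y = create_data()
print(x.shape, y.shape)
```

Because the noise is non-negative, every sample lies on or above the log trend, so the mean alone says little about the spread of y.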
Here we have a simple 2D dataset; however, notice that y has some very peculiar statistical properties.
When making predictions for this kind of data, we might be very interested in knowing what range of values our data revolves around such that we can judge if a specific outcome is expected or not, what are the best and worst-case scenarios, and so on.
The only thing special about quantile regression is its loss function. Instead of the usual MAE or MSE losses, quantile regression uses the following function:
$$E = y - f(x), \qquad L_q = \begin{cases} q E, & E \ge 0 \\ (1 - q)(-E), & E < 0 \end{cases}$$
Here $E$ is the error term, and $L_q$ is the loss function for the quantile $q$. So what do we mean by this? Concretely, it means that $L_q$ will bias $f(x)$ to output the value of the $q$-th quantile instead of the usual mean or median statistic. The big question is: how does it do it?
First, let's notice that this formula can be rewritten as follows:
$$L_q = \max(q E, (q - 1) E)$$
Using $\max$ instead of a conditional statement will make it more straightforward to implement on tensor/array libraries. We will do this next in jax.
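A minimal sketch of such a loss in JAX, assuming the error is defined as y_true - y_pred and the function is called quantile_loss (the name used later in the text), might look like this:

```python
import jax.numpy as jnp

def quantile_loss(q, y_true, y_pred):
    # Error term: positive when the prediction is too low.
    e = y_true - y_pred
    # max(q * e, (q - 1) * e) selects q * e when e >= 0 and (1 - q) * (-e) when e < 0.
    return jnp.maximum(q * e, (q - 1.0) * e)

# For q = 0.8, under-predicting by 1 costs more than over-predicting by 1.
under = quantile_loss(0.8, 1.0, 0.0)
over = quantile_loss(0.8, 0.0, 1.0)
```

Since jnp.maximum works elementwise, the same function accepts scalars or arrays without any branching.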
Now that we have this function, let us explore the error landscape for a particular set of predictions. Here we will generate values for y_true over a fixed range and, for a particular value of q (0.8 by default), compute the total error you would get for each value y_pred could take. Ideally, we want to find the value of y_pred where the error is the smallest.
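The experiment can be sketched as follows. The range of y_true, the grid of candidate predictions, and the helper name total_error are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def quantile_loss(q, y_true, y_pred):
    e = y_true - y_pred
    return jnp.maximum(q * e, (q - 1.0) * e)

def total_error(q, y_true, y_pred):
    # Sum of the quantile loss over all observations for one candidate prediction.
    return jnp.sum(quantile_loss(q, y_true, y_pred))

q = 0.8
y_true = jnp.linspace(10.0, 20.0, 101)       # assumed range, for illustration only
candidates = jnp.linspace(10.0, 20.0, 1001)  # values that y_pred could take

# Evaluate the total error once per candidate with jax.vmap.
errors = jax.vmap(lambda y_pred: total_error(q, y_true, y_pred))(candidates)
best = candidates[jnp.argmin(errors)]

# The empirical q-th quantile of y_true, for comparison with the minimizer.
empirical = jnp.quantile(y_true, q)
```

Under these assumptions, best and empirical should land on (almost) the same value, which is the behavior the error plot illustrates.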
If we plot the error, we see that the quantile loss attains its minimum exactly at the value of the q-th quantile. It achieves this because the quantile loss is not symmetrical: for quantiles above 0.5 it penalizes positive errors more strongly than negative errors, and the opposite is true for quantiles below 0.5. In particular, quantile 0.5 is the median, and its formula is equivalent to the MAE up to a constant factor of 1/2.
Generally, we would need to create a model per quantile. However, if we use a neural network, we can output the predictions for all the quantiles simultaneously. Here we will use elegy to create a neural network with two hidden layers with relu activations and a linear output layer with n_quantiles units.
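For readers without elegy at hand, the same architecture can be sketched in plain JAX. This is a stand-in, not the elegy module itself: init_mlp and forward are hypothetical helper names, the He-style initialization is an assumption, and the hidden sizes 128 and 64 are taken from the model summary shown further down.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Create (weights, bias) pairs for consecutive layer sizes."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        w = jax.random.normal(subkey, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def forward(params, x):
    # Hidden layers use relu; the last layer is linear with one unit per quantile.
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

n_quantiles = 7
params = init_mlp(jax.random.PRNGKey(0), [1, 128, 64, n_quantiles])
preds = forward(params, jnp.ones((5, 1)))  # shape (5, n_quantiles)
n_params = sum(w.size + b.size for w, b in params)
```

Each row of preds holds one prediction per quantile, so a single forward pass yields all quantile estimates for an input.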
Now we will properly define a QuantileLoss class that is parameterized by a set of user-defined quantiles.
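A minimal functional sketch of such a loss is given below, assuming y_true has shape (batch, 1) and y_pred has shape (batch, n_quantiles). The name multi_quantile_loss and the choice to sum over quantiles and average over the batch are assumptions, not necessarily what the elegy class does.

```python
import jax
import jax.numpy as jnp

def quantile_loss(q, y_true, y_pred):
    e = y_true - y_pred
    return jnp.maximum(q * e, (q - 1.0) * e)

def multi_quantile_loss(quantiles, y_true, y_pred):
    """quantiles: (n_quantiles,), y_true: (batch, 1), y_pred: (batch, n_quantiles)."""
    # Map over the quantile axis: one q and one column of predictions at a time.
    per_quantile = jax.vmap(quantile_loss, in_axes=(0, None, 1), out_axes=1)
    losses = per_quantile(quantiles, y_true[:, 0], y_pred)  # (batch, n_quantiles)
    # Sum over quantiles, then average over the batch (an assumed reduction).
    return jnp.mean(jnp.sum(losses, axis=-1))

quantiles = jnp.array([0.1, 0.5, 0.9])
y_true = jnp.array([[1.0], [2.0]])
y_pred = jnp.array([[0.0, 1.0, 2.0], [2.0, 2.0, 2.0]])
loss = multi_quantile_loss(quantiles, y_true, y_pred)
```

The in_axes argument tells jax.vmap to pair the i-th quantile with the i-th prediction column while sharing the targets across all quantiles.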
Notice that we use the same quantile_loss that we created previously, along with some jax.vmap magic to properly vectorize the function. Finally, we will create a simple function that creates and trains our model for a set of quantiles using elegy.
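The training step can also be sketched without elegy, using plain full-batch gradient descent in JAX. Everything here is an assumption made for illustration: the synthetic data, the three quantiles, the learning rate, the number of steps, and the helper names. Broadcasting is used in place of jax.vmap to apply each quantile to its output column.

```python
import jax
import jax.numpy as jnp
import numpy as np

def quantile_loss(q, y_true, y_pred):
    e = y_true - y_pred
    return jnp.maximum(q * e, (q - 1.0) * e)

def init_mlp(key, sizes):
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        w = jax.random.normal(subkey, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params.append((w, jnp.zeros(n_out)))
    return params

def forward(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss_fn(params, quantiles, x, y):
    y_pred = forward(params, x)  # (batch, n_quantiles)
    # (n_quantiles,) broadcasts against (batch, n_quantiles): one q per column.
    losses = quantile_loss(quantiles, y, y_pred)
    return jnp.mean(jnp.sum(losses, axis=-1))

@jax.jit
def train_step(params, quantiles, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, quantiles, x, y)
    # Plain gradient descent update applied to every weight and bias.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

# Synthetic stand-in data (assumption): log trend plus exponential noise.
rng = np.random.default_rng(0)
x = rng.uniform(0.3, 10.0, size=(1000, 1)).astype(np.float32)
y = (np.log(x) + rng.exponential(0.1 + x / 20.0)).astype(np.float32)

quantiles = jnp.array([0.1, 0.5, 0.9])
params = init_mlp(jax.random.PRNGKey(0), [1, 128, 64, quantiles.shape[0]])

losses = []
for step in range(300):
    params, loss = train_step(params, quantiles, x, y, 1e-4)
    losses.append(float(loss))

# Predicted quantiles on a grid that spans the input domain.
x_test = jnp.linspace(0.3, 10.0, 50)[:, None]
y_quantiles = forward(params, x_test)  # shape (50, 3)
```

A real run would likely use an adaptive optimizer and mini-batches; gradient descent is used here only to keep the sketch dependency-free.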
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer                ┃ Outputs Shape       ┃ Trainable      ┃ Non-trainable ┃
┃                      ┃                     ┃ Parameters     ┃ Parameters    ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ Inputs               │ (1000, 1) float64   │                │               │
├──────────────────────┼─────────────────────┼────────────────┼───────────────┤
│ linear    Linear     │ (1000, 128) float32 │ 256     1.0 KB │               │
├──────────────────────┼─────────────────────┼────────────────┼───────────────┤
│ linear_1  Linear     │ (1000, 64) float32  │ 8,256  33.0 KB │               │
├──────────────────────┼─────────────────────┼────────────────┼───────────────┤
│ linear_2  Linear     │ (1000, 7) float32   │ 455     1.8 KB │               │
├──────────────────────┼─────────────────────┼────────────────┼───────────────┤
│ * QuantileRegression │ (1000, 7) float32   │                │               │
├──────────────────────┼─────────────────────┼────────────────┼───────────────┤
│                      │ Total               │ 8,967  35.9 KB │               │
└──────────────────────┴─────────────────────┴────────────────┴───────────────┘
Total Parameters: 8,967   35.9 KB
Now that we have a model, let us generate some test data that spans the entire domain and compute the predicted quantiles.
Amazing! Notice how the first few quantiles are tightly packed together while the last ones spread out, capturing the behavior of the exponential distribution. We can also visualize the region between the highest and lowest quantiles, and this gives us some bounds on our predictions.
On the other hand, having multiple quantile values allows us to estimate the density of the data. Since the difference between two adjacent quantile levels represents the probability that a point lies between the corresponding quantile values, we can construct a piecewise function that approximates the density of the data.
For a given x, we can compute the quantile values and then use these to compute the conditional piecewise density function of y given x.
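One way to sketch this computation is a piecewise-constant estimate: on each interval between adjacent quantile values, the density is the probability mass divided by the interval width. The function name piecewise_density and the numbers in the example are hypothetical and are not model output.

```python
import jax.numpy as jnp

def piecewise_density(quantiles, values):
    """Piecewise-constant density from quantile levels and their predicted values.

    quantiles: increasing levels, shape (n,)
    values:    predicted quantile values for one x, shape (n,)
    """
    mass = jnp.diff(quantiles)   # probability of falling between adjacent values
    width = jnp.diff(values)     # length of each interval
    density = mass / width
    midpoints = (values[:-1] + values[1:]) / 2.0
    return midpoints, density

# Illustrative numbers only: three levels and hypothetical predicted values.
quantiles = jnp.array([0.1, 0.5, 0.9])
values = jnp.array([1.0, 2.0, 4.0])
midpoints, density = piecewise_density(quantiles, values)
```

This assumes the predicted values are strictly increasing. A trained network can occasionally produce crossing quantiles, in which case the values would need to be sorted or otherwise repaired before dividing.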
One of the exciting properties of Quantile Regression is that we did not need to know a priori the output distribution, and training is easy compared to other methods.
multimodal = True.
Many thanks to David Cardozo for his proofreading and for getting the notebook to run in Colab.