A simple method to estimate uncertainty in Machine Learning
When generating predictions about an output, it is sometimes useful to get a confidence score or, similarly, a range of values around the expected value in which the actual value might be found. Practical examples include estimating an upper and lower bound when predicting an ETA or a stock price: you care not only about the average outcome but also about the best-case and worst-case scenarios when trying to minimize risk, e.g. to avoid arriving late or losing money.
While most Machine Learning techniques do not provide a natural way of doing this, in this article, we will be exploring Quantile Regression as a means of doing so. This technique will allow us to learn some critical statistical properties of our data: the quantiles.
# uncomment to install dependencies
# ! curl -Ls https://raw.githubusercontent.com/cgarciae/quantile-regression/master/requirements.txt > requirements.txt
# ! pip install -qr requirements.txt
# ! pip install -U matplotlib
To begin our journey into quantile regression, we will first get hold of some data:
import numpy as np
import matplotlib.pyplot as plt
import os
plt.rcParams["figure.dpi"] = int(os.environ.get("FIGURE_DPI", 150))
plt.rcParams["figure.facecolor"] = os.environ.get("FIGURE_FACECOLOR", "white")
np.random.seed(69)
def create_data(multimodal: bool):
    x = np.random.uniform(0.3, 10, 1000)
    y = np.log(x) + np.random.exponential(0.1 + x / 20.0)

    if multimodal:
        x = np.concatenate([x, np.random.uniform(5, 10, 500)])
        y = np.concatenate([y, np.random.normal(6.0, 0.3, 500)])

    return x[..., None], y[..., None]
multimodal: bool = False
x, y = create_data(multimodal)
fig = plt.figure()
plt.scatter(x[..., 0], y[..., 0], s=20, facecolors="none", edgecolors="k")
plt.close()
fig
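Here we have a simple 2D dataset; however, notice that y has some very peculiar statistical properties: its noise is drawn from an exponential distribution, so it is neither normally distributed nor symmetric, and the scale of that noise (0.1 + x / 20 in create_data) grows with x, so the variance is not constant.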
When making predictions for this kind of data, we might be very interested in knowing what range of values our data revolves around such that we can judge if a specific outcome is expected or not, what are the best and worst-case scenarios, and so on.
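The only thing special about quantile regression is its loss function. Instead of the usual MAE or MSE losses, for quantile regression we use the following function:

$$
\mathcal{L}_q(y, \hat{y}) = \begin{cases} q \, (y - \hat{y}) & \text{if } y \geq \hat{y} \\ (q - 1) \, (y - \hat{y}) & \text{otherwise} \end{cases}
$$

Here $q \in (0, 1)$ is the target quantile, $y$ is the true value (y_true), and $\hat{y}$ is the prediction (y_pred). First, let's notice that this formula can be rewritten as follows:

$$
\mathcal{L}_q(y, \hat{y}) = \max\big(q \, (y - \hat{y}),\; (q - 1) \, (y - \hat{y})\big)
$$

Using max instead of a conditional makes the loss straightforward to implement with array libraries, which is what we do next in jax: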
import jax
import jax.numpy as jnp
def quantile_loss(q, y_true, y_pred):
    e = y_true - y_pred
    return jnp.maximum(q * e, (q - 1.0) * e)
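As a quick sanity check, the sketch below compares a conditional form of the loss with the max form on random inputs. It uses plain NumPy rather than jax, and the helper names `quantile_loss_piecewise` and `quantile_loss_max` are illustrative only; they are not part of the original notebook.

```python
import numpy as np

def quantile_loss_piecewise(q, y_true, y_pred):
    # Conditional form: weight q when y_true >= y_pred, weight (1 - q) otherwise.
    e = y_true - y_pred
    return np.where(e >= 0, q * e, (q - 1.0) * e)

def quantile_loss_max(q, y_true, y_pred):
    # Max form, mirroring the jax implementation above.
    e = y_true - y_pred
    return np.maximum(q * e, (q - 1.0) * e)

rng = np.random.default_rng(0)
y_true = rng.normal(size=1000)
y_pred = rng.normal(size=1000)

for q in (0.1, 0.5, 0.9):
    a = quantile_loss_piecewise(q, y_true, y_pred)
    b = quantile_loss_max(q, y_true, y_pred)
    assert np.allclose(a, b)   # both forms agree element-wise
    assert (a >= 0).all()      # the loss is never negative
```

Both forms agree because, for a non-negative error, q * e is the larger of the two terms, while for a negative error (q - 1) * e is.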
Now that we have this function, let us explore the error landscape for a particular set of predictions. Here we will generate values for y_true in the range [10, 20] and, for a particular value of q (0.8 below), compute the mean loss at each value y_pred could take. Ideally, we want to find the value of y_pred where the error is the smallest.
def calculate_error(q):
    y_true = np.linspace(10, 20, 100)
    y_pred = np.linspace(10, 20, 200)

    # evaluate the loss of every candidate y_pred against all y_true values
    loss = jax.vmap(quantile_loss, in_axes=(None, None, 0))(q, y_true, y_pred)
    loss = loss.mean(axis=1)

    return y_true, y_pred, loss
q = 0.8
y_true, y_pred, loss = calculate_error(q)
q_true = np.quantile(y_true, q)
fig = plt.figure()
plt.plot(y_pred, loss)
plt.vlines(q_true, 0, loss.max(), linestyles="dashed", colors="k")
plt.gca().set_xlabel("y_pred")
plt.gca().set_ylabel("loss")
plt.title(f"Q({q:.2f}) = {q_true:.1f}")
plt.close()
fig
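If we plot the error, the quantile loss's minimum lies at the q-th quantile of y_true (the dashed line, here q = 0.8). This works because the quantile loss is not symmetric: for quantiles above 0.5 it penalizes positive errors (y_true above y_pred) more strongly than negative errors, and the opposite is true for quantiles below 0.5. In particular, quantile 0.5 is the median, and its formula is equivalent to the MAE up to a constant factor of 1/2.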
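To make the claim about the minimum concrete, here is a minimal sketch that searches a grid of constant predictions for the one with the lowest mean quantile loss and prints it next to `np.quantile`. The exponential sample and the helper `mean_quantile_loss` are assumptions chosen for illustration, not part of the original notebook.

```python
import numpy as np

def mean_quantile_loss(q, y_true, y_pred):
    # Mean quantile (pinball) loss of a single constant prediction y_pred.
    e = y_true - y_pred
    return np.maximum(q * e, (q - 1.0) * e).mean()

rng = np.random.default_rng(0)
y = rng.exponential(scale=1.0, size=5000)          # a skewed sample
candidates = np.linspace(y.min(), y.max(), 2001)   # candidate constant predictions

for q in (0.2, 0.5, 0.8):
    losses = np.array([mean_quantile_loss(q, y, c) for c in candidates])
    best = candidates[losses.argmin()]
    print(f"q={q:.1f}  argmin={best:.3f}  np.quantile={np.quantile(y, q):.3f}")
```

The two printed values should agree up to the resolution of the grid, and for q = 0.5 the loss is exactly half of the mean absolute error.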
Generally, we would need to create a model per quantile. However, if we use a neural network, we can output the predictions for all the quantiles simultaneously. Here we will use elegy to create a neural network with two hidden layers with relu activations and a final linear layer with n_quantiles output units.
import elegy
class QuantileRegression(elegy.Module):
    def __init__(self, n_quantiles: int):
        super().__init__()
        self.n_quantiles = n_quantiles

    def call(self, x):
        x = elegy.nn.Linear(128)(x)
        x = jax.nn.relu(x)
        x = elegy.nn.Linear(64)(x)
        x = jax.nn.relu(x)
        # one output unit per quantile
        x = elegy.nn.Linear(self.n_quantiles)(x)

        return x
Now we will properly define a QuantileLoss class that is parameterized by a set of user-defined quantiles.
class QuantileLoss(elegy.Loss):
    def __init__(self, quantiles):
        super().__init__()
        self.quantiles = np.array(quantiles)

    def call(self, y_true, y_pred):
        # map over the quantiles and the last axis of y_pred; y_true is shared
        loss = jax.vmap(quantile_loss, in_axes=(0, None, -1), out_axes=1)(
            self.quantiles, y_true[:, 0], y_pred
        )
        # sum the per-quantile losses for each sample
        return jnp.sum(loss, axis=-1)
Notice that we use the same quantile_loss that we created previously, along with some jax.vmap magic to properly vectorize the function. Finally, we will create a simple function that creates and trains our model for a set of quantiles using elegy.
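To unpack what the vmap call computes, the following sketch reproduces the same per-sample loss with plain NumPy broadcasting. It is an illustrative equivalent under the shapes assumed in the comments, not the code used for training, and `quantile_loss_batch` is a hypothetical helper name.

```python
import numpy as np

def quantile_loss_batch(quantiles, y_true, y_pred):
    # Assumed shapes: quantiles (Q,), y_true (N,), y_pred (N, Q).
    q = np.asarray(quantiles)[None, :]        # (1, Q)
    e = y_true[:, None] - y_pred              # (N, Q): one error per sample and quantile
    loss = np.maximum(q * e, (q - 1.0) * e)   # (N, Q)
    return loss.sum(axis=-1)                  # (N,): sum over quantiles

quantiles = (0.1, 0.5, 0.9)
y_true = np.array([1.0, 2.0])
y_pred = np.array([[0.0, 1.0, 2.0],
                   [2.0, 2.0, 2.0]])

per_sample = quantile_loss_batch(quantiles, y_true, y_pred)
print(per_sample.shape, per_sample)
```

Each column of y_pred is scored against its own quantile, while y_true is shared across columns, which is what the in_axes=(0, None, -1) argument expresses.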
import optax
def train_model(quantiles, epochs: int, lr: float, eager: bool):
    model = elegy.Model(
        QuantileRegression(n_quantiles=len(quantiles)),
        loss=QuantileLoss(quantiles),
        optimizer=optax.adamw(lr),
        run_eagerly=eager,
    )

    model.fit(x, y, epochs=epochs, batch_size=64, verbose=0)

    return model
if not multimodal:
    quantiles = (0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95)
else:
    quantiles = np.linspace(0.05, 0.95, 9)
model = train_model(quantiles=quantiles, epochs=3001, lr=1e-4, eager=False)
model.summary(x)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer                        ┃ Outputs Shape        ┃ Trainable        ┃ Non-trainable ┃
┃                              ┃                      ┃ Parameters       ┃ Parameters    ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ Inputs                       │ (1000, 1)    float64 │                  │               │
├──────────────────────────────┼──────────────────────┼──────────────────┼───────────────┤
│ linear    Linear             │ (1000, 128)  float32 │ 256       1.0 KB │               │
├──────────────────────────────┼──────────────────────┼──────────────────┼───────────────┤
│ linear_1  Linear             │ (1000, 64)   float32 │ 8,256    33.0 KB │               │
├──────────────────────────────┼──────────────────────┼──────────────────┼───────────────┤
│ linear_2  Linear             │ (1000, 7)    float32 │ 455       1.8 KB │               │
├──────────────────────────────┼──────────────────────┼──────────────────┼───────────────┤
│ *         QuantileRegression │ (1000, 7)    float32 │                  │               │
├──────────────────────────────┼──────────────────────┼──────────────────┼───────────────┤
│                              │                Total │ 8,967    35.9 KB │               │
└──────────────────────────────┴──────────────────────┴──────────────────┴───────────────┘
Total Parameters: 8,967   35.9 KB
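The parameter counts in the summary can be reproduced by hand. The sketch below assumes each Linear layer holds a weight matrix plus one bias per output unit, which is consistent with the 256 parameters reported for the 1 to 128 layer; `dense_params` is an illustrative helper, not an elegy function.

```python
def dense_params(n_in, n_out):
    # n_in * n_out weights plus n_out biases
    return n_in * n_out + n_out

layer_sizes = [1, 128, 64, 7]   # input, two hidden layers, 7 quantile outputs
counts = [dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]
print(counts, sum(counts))
```

Note that adding another quantile only adds 65 parameters (64 weights and one bias) to the last layer, which is why predicting all quantiles with one network is cheap.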
Now that we have a model, let us generate some test data that spans the entire domain and compute the predicted quantiles.
x_test = np.linspace(x.min(), x.max(), 100)
y_pred = model.predict(x_test[..., None])
fig = plt.figure()
plt.scatter(x, y, s=20, facecolors="none", edgecolors="k")
for i, q_values in enumerate(np.split(y_pred, len(quantiles), axis=-1)):
    plt.plot(x_test, q_values[:, 0], linewidth=2, label=f"Q({quantiles[i]:.2f})")
plt.legend()
plt.close()
fig
Amazing! Notice how the first few quantiles are tightly packed together while the last ones spread out, capturing the behavior of the exponential distribution. We can also visualize the region between the highest and lowest quantiles, and this gives us some bounds on our predictions.
median_idx = np.where(np.isclose(quantiles, 0.5))[0]
fig = plt.figure()
plt.fill_between(x_test, y_pred[:, -1], y_pred[:, 0], alpha=0.5, color="b")
plt.scatter(x, y, s=20, facecolors="none", edgecolors="k")
plt.plot(
    x_test,
    y_pred[:, median_idx],
    color="r",
    linestyle="dashed",
    label="Q(0.5)",
)
plt.legend()
plt.close()
fig
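One way to judge such bounds is to measure how many points fall inside them: a band between Q(0.05) and Q(0.95) should contain roughly 90% of the data. The sketch below illustrates the check using the exact conditional quantiles of the non-multimodal data-generating process instead of a trained model, so it runs on its own; `true_quantile` is a hypothetical helper, and it relies on the fact that quantile q of an exponential distribution with scale s is -s * log(1 - q).

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000
x = rng.uniform(0.3, 10, n)
y = np.log(x) + rng.exponential(0.1 + x / 20.0)   # same process as create_data(False)

def true_quantile(q, x):
    # Exact conditional quantile of y given x for the process above.
    s = 0.1 + x / 20.0
    return np.log(x) - s * np.log(1.0 - q)

lower = true_quantile(0.05, x)
upper = true_quantile(0.95, x)
coverage = np.mean((y >= lower) & (y <= upper))
print(f"fraction of points inside the band: {coverage:.3f}")
```

With a trained model, the same check could be run by replacing lower and upper with the first and last columns of the model's predictions at the same x values; a fraction far from 0.9 would suggest the learned quantiles are off.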
On the other hand, having multiple quantile values allows us to estimate the density of the data. Since the difference between two adjacent quantile levels represents the probability that a point lies between the corresponding quantile values, dividing that probability by the distance between those values gives a piecewise-constant function that approximates the density of the data.
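def get_pdf(quantiles, q_values):
    densities = []

    for i in range(len(quantiles) - 1):
        # probability mass between two adjacent quantiles
        area = quantiles[i + 1] - quantiles[i]
        # width of the interval on the y axis
        b = q_values[i + 1] - q_values[i]
        a = area / b

        densities.append(a)

    return densities


def piecewise(xs):
    # repeat the interior points so each interval gets its own left and right edge
    return [xs[i + j] for i in range(len(xs) - 1) for j in range(2)]


def doubled(xs):
    # repeat each density twice (clipped to [0, 3] for plotting)
    return [np.clip(xs[i], 0, 3) for i in range(len(xs)) for _ in range(2)]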
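A small worked example may help here. The sketch below applies the same mass-over-width computation to three made-up quantiles and checks that the resulting piecewise density integrates back to the probability covered by them; the levels and values are hypothetical numbers, not model outputs.

```python
import numpy as np

quantile_levels = [0.1, 0.5, 0.9]   # hypothetical quantile levels
quantile_values = [1.0, 2.0, 4.0]   # hypothetical predicted values of y

mass = np.diff(quantile_levels)     # probability between adjacent quantiles
width = np.diff(quantile_values)    # interval widths on the y axis
density = mass / width              # constant density on each interval

print(density)
print((density * width).sum())      # total mass covered by the quantiles
```

The total is the gap between the highest and lowest quantile levels rather than 1, since nothing is known about the tails. The computation also assumes the predicted quantile values are strictly increasing; if two of them cross, the width becomes zero or negative and the density is no longer meaningful.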
For a given x, we can compute the quantile values and then use these to compute the conditional piecewise density function of y given x.
xi = 7.0
q_values = model.predict(np.array([[xi]]))[0].tolist()
densities = get_pdf(quantiles, q_values)
fig = plt.figure()
plt.title(f"x = {xi}")
plt.fill_between(piecewise(q_values), 0, doubled(densities))
# plt.fill_between(q_values, 0, densities + [0])
# plt.plot(q_values, densities + [0], color="k")
plt.xlim(0, y.max())
plt.gca().set_xlabel("y")
plt.gca().set_ylabel("p(y)")
plt.close()
fig
One of the exciting properties of Quantile Regression is that we did not need to know a priori the output distribution, and training is easy compared to other methods.
multimodal = True
Many thanks to David Cardozo for proofreading and for getting the notebook to run in Colab.