Many people who work in statistics-adjacent occupations - looking at you, BI analysts, data scientists, and ML engineers - were visited early in their careers by Laurence Fishburne offering them two options. Choose the frequentist (blue) pill and all your parameters will be fixed quantities, your p-values will give binary answers, and well-supported libraries will abound with computationally cheap optimizations. Choose the Bayesian (red) pill, and you’ll see just how much you’ve swept under the “assume everything is IID” rug.
Naturally we almost all picked the blue pill, which has left our braver coworkers with worse documentation and less seamless support for their libraries, including PyMC on Databricks. So today I’m going to offer up some time out of my happy little sklearn Bob Ross painting of a career to help our Bayesian friends get MCMC sampling working on Databricks.
If you try to %pip install PyMC right now on Databricks Runtime 15.4 LTS, you’ll get an error, but if you %pip install 'miniKanren<1.0.4' first, then it seems to work fine. Great, shortest blog ever.
Unfortunately, most commands will still not work, depending on your compiler. For example, the Databricks ML Runtime has its own rabbit hole of compatibility issues with the Numba backend, which is needed to speed up the gradient computations and MCMC bottlenecks by JIT-compiling the Python into machine code. We can get around this by using a JAX backend instead, which proves much more reliable and performant in Databricks environments thanks to its drop-in replacement for NumPy (jax.numpy), automatic gradients (just wrap jax.grad() around a function), and built-in GPU support. It’s not perfect, but it’s serviceable.
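If you haven’t touched JAX before, the automatic gradients are the part doing the heavy lifting for MCMC: NUTS needs the gradient of the log-density at every step. A minimal sketch of the idea (the toy function is mine, not part of the PyMC setup):

import jax
import jax.numpy as jnp

# Toy negative log-density: a standard normal, up to a constant
def neg_log_density(x):
    return 0.5 * jnp.sum(x ** 2)

grad_fn = jax.grad(neg_log_density)     # exact gradients, no hand-derived math
print(grad_fn(jnp.array([1.0, -2.0])))  # -> [ 1. -2.]

This is the function a NUTS-style sampler calls thousands of times per chain, which is why the backend choice matters so much.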
Let’s look at an example. Imagine we own a beautiful forest full of tall trees and woodland creatures. Then we cut it down, because trees only have value as timber or toilet paper and we’ve ground all the woodland interlopers into hot dogs. Much better, but we need to repeat this cycle as fast as possible to keep churning out maximum profit from our land. So we’ve hired some scientists to experiment with soil amendments that maximize tree growth, and we’ll observe the effects of the various treatments.
Our dataset looks something like this: one row per treatment lot per month, with columns for month, treatment, treatment_lot, and tree_size_cm. The first rows cover control_1 across months 1, 2, and 3; we also have control_2, _3, and _4 for months 1, 2, and 3, and the same layout for the organic fertilizer group and the synthetic fertilizer group. So only 36 records in total - after all, this is a toy dataset, not the LHC’s particle physics dataset.
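If you want to follow along without the original table, here’s a hedged stand-in that builds a DataFrame of the same shape (the growth numbers are made up, loosely matching the model priors below; swap in the spark.sql read further down when you have real data):

import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
rows = []
for treatment in ["control", "organic", "synthetic"]:
    for lot in range(1, 5):        # 4 lots per treatment group
        for month in [1, 2, 3]:    # 3 monthly measurements per lot
            rows.append({
                "treatment": treatment,
                "treatment_lot": f"{treatment}_{lot}",
                "month": month,
                "tree_size_cm": 6 - 0.1 * month + rng.normal(0, 0.25),
            })
df_to_analyze = pd.DataFrame(rows)  # 3 * 4 * 3 = 36 records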
Let’s read in that data now with the appropriate imports and then we’ll talk a little more about the model and sampling:
%pip install 'miniKanren<1.0.4'
%pip install pymc
%pip install jax jaxlib blackjax
## Uncomment if you want to try this on GPU
# %pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
%pip install nutpie
%pip install numpyro
%restart_python
We’re already at the first potential hang-up: be sure to set PYTENSOR_FLAGS via environment variables BEFORE any other imports, as attempting to modify the pytensor config post-initialization throws exceptions. Threading configuration through OMP_NUM_THREADS and OPENBLAS_NUM_THREADS also helps optimize BLAS (Basic Linear Algebra Subprograms) operations later on, so we’ll set those now too.
import os
# Set before any other imports
os.environ['OMP_NUM_THREADS'] = '4' # this cluster is 4 core
os.environ['OPENBLAS_NUM_THREADS'] = '4'
os.environ['PYTENSOR_FLAGS'] = 'mode=JAX,device=cpu,floatX=float32,openmp=True'
import pymc as pm
import numpy as np
import pandas as pd
import jax, jaxlib, blackjax
import nutpie
import pytensor
import pytensor.tensor as pt  # needed for pt.dot in the model below
We can quickly validate that PyMC is working and verify that BLAS is performing well by benchmarking it:
import pathlib, pytensor, sys
!{sys.executable} {pathlib.Path(pytensor.__file__).parent / 'misc/check_blas.py'}
Great, I’m getting pretty good results on an r6id.xlarge instance. Let’s also test the pytensor backend configuration:
pytensor.config.mode
This should print ‘JAX’.
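As one more optional sanity check (my own addition, not strictly required), confirm that JAX itself sees your hardware and defaults to float32 as configured:

import jax
import jax.numpy as jnp

print(jax.devices())      # e.g. [CpuDevice(id=0)] on a CPU cluster
print(jnp.ones(3).dtype)  # float32 by default, matching floatX=float32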
We can now read in and view our sample data:
df_to_analyze = spark.sql("SELECT * FROM <catalog>.<schema>.<table>")
df_to_analyze = df_to_analyze.toPandas()
display(df_to_analyze)
Great, we’re now able to turn our attention to the actual model definition. I’m assuming you already know what model you want to use, because you’re the expert on your data and your use case, so I’m only going to describe the model below briefly.
# Prepare data
df_to_analyze["treatment_lot_factor"], treat_lot_vals = df_to_analyze.treatment_lot.factorize()
df_to_analyze["treatment_factor"], treatment_vals = df_to_analyze.treatment.factorize()
coords = {
    "treat_lots": treat_lot_vals,
    "treatments": treatment_vals,
    "param": ["alpha", "beta"],
    "obs_id": range(len(df_to_analyze)),
}
with pm.Model(coords=coords) as simple_hierarchical_model:
    # Data containers
    month_idx = pm.Data("month_idx", df_to_analyze.month, dims="obs_id")
    treat = pm.Data("treat", df_to_analyze.treatment_factor, dims="obs_id")
    treat_lot_idx = pm.Data("treat_lot_idx", df_to_analyze.treatment_lot_factor, dims="obs_id")

    # Treatment-level priors (population means)
    alpha_mu = pm.Normal("alpha_mu", mu=6, sigma=0.5, dims="treatments")
    beta_mu = pm.Normal("beta_mu", mu=-0.1, sigma=0.02, dims="treatments")

    # LKJ correlation structure for lot-level random effects
    sd_lot = pm.Exponential.dist(4)
    chol, corr, stds = pm.LKJCholeskyCov("chol_lot", n=2, eta=2, sd_dist=sd_lot)

    # Lot-level random effects (non-centered parameterization)
    z = pm.Normal("z", 0.0, 1.0, dims=("param", "treat_lots"))
    lot_effects = pm.Deterministic(
        "lot_effects",
        pt.dot(chol, z).T,
        dims=("treat_lots", "param"),
    )

    # Simple homoscedastic error
    sigma = pm.HalfNormal("sigma", sigma=0.25)

    # Expected value: population mean + lot-specific deviations
    y_hat = (
        alpha_mu[treat] + lot_effects[treat_lot_idx, 0] +
        (beta_mu[treat] + lot_effects[treat_lot_idx, 1]) * month_idx
    )

    # Likelihood: observed tree sizes drawn from a normal with the mean and sigma above
    tree_growth = pm.Normal(
        "tree_growth",
        mu=y_hat,
        sigma=sigma,
        observed=df_to_analyze.tree_size_cm,
        dims="obs_id",
    )
This hierarchical Bayesian model helps us process more trees and furry freeloaders by analyzing tree growth over time across different treatment groups and treatment lots, capturing the nested structure of individual lots grouped within broader categories. The model estimates population-level growth patterns (intercept and slope over time) for each treatment, while allowing individual lots to deviate from these population means through correlated random effects. Lot-specific intercepts and slopes are modeled as correlated via an LKJ prior on the Cholesky decomposition - meaning lots that start larger than average for their treatment also tend to grow at a different rate. This approach provides uncertainty quantification at multiple levels: we get credible intervals for treatment effects and lot-specific deviations, and we can make predictions that properly account for both treatment-level and lot-level variability. We’ll assume homoscedastic variance to simplify the code, but this can be reconfigured for heteroscedasticity very easily (remember what I said about the IID rug sweep); see the sketch below.
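As a hedged sketch of that heteroscedastic extension (my variation, not part of the model above): give sigma a treatments dimension and index it the same way as the means, so each treatment group gets its own error scale.

# Inside the same model context, replace the homoscedastic error with:
sigma = pm.HalfNormal("sigma", sigma=0.25, dims="treatments")

tree_growth = pm.Normal(
    "tree_growth",
    mu=y_hat,
    sigma=sigma[treat],  # per-treatment error scale
    observed=df_to_analyze.tree_size_cm,
    dims="obs_id",
)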
Now we get to the actual sampling code. The way people may default to running this is something like:
# DON'T DO THIS
with simple_hierarchical_model:
    trace = pm.sample(2000, tune=1000, chains=4, random_seed=42)
But don’t do this on Databricks! At the very least, Databricks needs us to specify single core for hierarchical models:
# Option 1, takes 3 minutes with the JAX pytensor config
with simple_hierarchical_model:
    trace = pm.sample(
        1000, tune=1000,
        cores=1, chains=2,  # Sequential chains
        return_inferencedata=True,
        progressbar=True,
        random_seed=42,
    )
Why single core? While simple PyMC models can natively leverage multicore sampling without issues, complex hierarchical models - especially those with correlation structures like LKJCholeskyCov - hit JAX multiprocessing deadlocks when using multiple cores on Databricks. However, we can get around this!
# Option 2 - 45 seconds, but progressbar doesn't work ¯\_(ツ)_/¯
with simple_hierarchical_model:
    trace = pm.sample(
        1000, tune=1000,
        nuts_sampler="nutpie",
        nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "pytensor"},
        return_inferencedata=True,
        progressbar=True,
        random_seed=42,
    )
Specifying the nuts_sampler (No-U-Turn Sampler) along with its requisite kwargs cuts this same model’s sampling time dramatically, thanks to nutpie’s Rust-compiled sampling loop and highly optimized automatic differentiation. Unlike the natively implemented NUTS sampler, the nutpie one can run our chains in parallel without the JAX deadlock, resulting in a 4x reduction in duration!
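Whichever option you pick, give the posterior a quick diagnostics pass before trusting it. A minimal check with ArviZ (installed as a PyMC dependency; the variable names match the model above):

import arviz as az

# r_hat should sit near 1.0 and ess_bulk should be comfortably large
print(az.summary(trace, var_names=["alpha_mu", "beta_mu", "sigma"]))
az.plot_trace(trace, var_names=["alpha_mu", "beta_mu"])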
Hopefully that helps - happy coding.