10. Linear regression with multiple predictors¶

When moving from a simple model $y = a + bx + \epsilon$ to a multiple regression model $y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_p x_p + \epsilon$, several new questions arise:

  1. Which predictors should be included in the model?
  2. How do we interpret the coefficients $\beta_1, \beta_2, \ldots, \beta_p$?
  3. How do we model interactions between predictors?
  4. How do we construct new predictors (e.g., polynomial terms, transformations)?
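Items 3 and 4 both come down to adding derived columns to the design matrix. A minimal sketch of the idea (the data frame and column names here are made up purely for illustration):

```python
import numpy as np
import pandas as pd

# Hypothetical data frame with two predictors, just to show the kinds
# of derived columns that interactions and transformations produce.
rng = np.random.default_rng(0)
df = pd.DataFrame({"x1": rng.normal(size=5), "x2": rng.normal(size=5)})

df["x1_x2"] = df["x1"] * df["x2"]   # interaction between predictors
df["x1_sq"] = df["x1"] ** 2         # polynomial (quadratic) term
df["log_abs_x2"] = np.log(np.abs(df["x2"]))  # a transformation

print(df.columns.tolist())
```

Each derived column then enters the regression as just another predictor.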

10.1 Adding predictors to a model¶

Regression coefficients are more complicated to interpret with multiple predictors because each coefficient is in part contingent on the other variables in the model.

The standard reading is: holding the other predictors constant, a coefficient gives the expected change in the response for a one-unit change in the predictor of interest.
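One way to make "holding the other predictors constant" concrete is the Frisch-Waugh-Lovell result: the multiple-regression coefficient on a predictor equals the simple-regression slope of the outcome on the part of that predictor not explained by the other predictors. A sketch with synthetic data (all names and values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(42)
n = 500
x1 = rng.normal(size=n)
x2 = 0.5 * x1 + rng.normal(size=n)              # correlated predictors
y = 1.0 + 2.0 * x1 - 1.5 * x2 + rng.normal(size=n)

# Multiple regression of y on an intercept, x1, and x2
X = np.column_stack([np.ones(n), x1, x2])
beta = np.linalg.lstsq(X, y, rcond=None)[0]

# Residualize x1 on the other columns (intercept and x2)
Z = np.column_stack([np.ones(n), x2])
x1_resid = x1 - Z @ np.linalg.lstsq(Z, x1, rcond=None)[0]

# Simple-regression slope of y on the residualized x1
slope = (x1_resid @ y) / (x1_resid @ x1_resid)

print(beta[1], slope)  # the two estimates agree (up to floating point)
```

So the coefficient on x1 is the slope you get after the overlap with x2 has been stripped out, which is what "adjusted for the other predictors" means.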

Starting with a binary predictor¶

Model children's test scores given an indicator of whether the mother graduated from high school (0 = no, 1 = yes).

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import arviz as az
import pymc as pm 
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.robust.scale import mad
In [2]:
def fit_and_plot_lm(data, predictors, outcome, add_constant=True, show_plot=True, scatter_kws=None, line_kws=None):
    """
    Fit a linear model using statsmodels, print summary, plot, and show formula.
    Args:
        data: pandas DataFrame
        predictors: list of predictor column names (str)
        outcome: outcome column name (str)
        add_constant: whether to add intercept (default True)
        show_plot: whether to plot (default True)
        scatter_kws: dict, kwargs for scatterplot
        line_kws: dict, kwargs for regression line
    """
    X = data[predictors].copy()
    if add_constant:
        X = sm.add_constant(X, prepend=False)
    y = data[outcome]
    model = sm.OLS(y, X)
    results = model.fit()
    print(results.summary())
    # Print formula
    params = results.params
    terms = " + ".join(f"{params[name]:.2f}*{name}" for name in predictors)
    formula = f"{outcome} = {params['const']:.2f} + {terms}" if add_constant else f"{outcome} = {terms}"
    print("Formula:", formula)
    # Print residual standard deviation and its uncertainty
    sigma = np.sqrt(results.mse_resid)
    sigma_se = sigma / np.sqrt(2 * results.df_resid)
    print(f"Residual std dev (σ): {sigma:.2f} ± {sigma_se:.2f}")
    # Print median absolute deviation of residuals
    print("MAD of residuals:", round(mad(results.resid), 2))
    # Plot if only one predictor
    if show_plot and len(predictors) == 1:
        x_name = predictors[0]
        ax = sns.scatterplot(data=data, x=x_name, y=outcome, **(scatter_kws or {}))
        x_vals = np.linspace(data[x_name].min(), data[x_name].max(), 100)
        y_vals = params['const'] + params[x_name] * x_vals if add_constant else params[x_name] * x_vals
        ax.plot(x_vals, y_vals, color='red', **(line_kws or {}))
        ax.set_title('Linear Regression Fit')
        plt.show()
In [3]:
def fit_and_plot_bayes(data, *args,
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=True,
                       show_posterior=True, show_regression=True,
                       n_regression_lines=100):
    """
    Fit a Bayesian linear regression using PyMC and optionally plot diagnostics.
    Supports single or multiple predictors.
    Args:
        data: pandas DataFrame
        *args: predictor column name(s) followed by the outcome column name (last arg)
        intercept_mu, intercept_sigma: prior mean and std for intercept ~ Normal
        slope_mu, slope_sigma: prior mean and std for slope ~ Normal
        sigma_sigma: prior std for residual noise ~ HalfNormal
        samples: number of posterior draws
        tune: number of tuning steps
        hdi_prob: HDI probability for summaries and plots
        show_trace: plot trace and posterior density per parameter
        show_forest: plot forest (posterior means + HDI)
        show_posterior: plot posterior densities
        show_regression: plot data with posterior regression lines
        n_regression_lines: number of posterior draws to overlay on regression plot
    Returns:
        trace: PyMC InferenceData object
    """
    predictors = list(args[:-1])
    outcome = args[-1]
    y = data[outcome].values

    with pm.Model() as model:
        intercept = pm.Normal("intercept", mu=intercept_mu, sigma=intercept_sigma)
        slopes = []
        mu = intercept
        for pred in predictors:
            s = pm.Normal(f"slope_{pred}", mu=slope_mu, sigma=slope_sigma)
            slopes.append(s)
            mu = mu + s * data[pred].values
        sigma = pm.HalfNormal("sigma", sigma=sigma_sigma)
        likelihood = pm.Normal("y", mu=mu, sigma=sigma, observed=y)
        trace = pm.sample(samples, tune=tune)

    summary = pm.summary(trace, hdi_prob=hdi_prob)
    print(summary)

    # Print regression formula
    posterior = trace.posterior
    intercept_mean = posterior["intercept"].values.flatten().mean()
    formula = f"{outcome} = {intercept_mean:.2f}"
    for pred in predictors:
        slope_mean = posterior[f"slope_{pred}"].values.flatten().mean()
        formula += f" + {slope_mean:.2f}*{pred}"
    print(f"\nRegression formula: {formula}")

    if show_trace:
        az.plot_trace(trace)
        plt.tight_layout()
        plt.show()

    if show_forest:
        az.plot_forest(trace, hdi_prob=hdi_prob)
        plt.show()

    if show_posterior:
        az.plot_posterior(trace, hdi_prob=hdi_prob)
        plt.show()

    if show_regression:
        a_samples = posterior["intercept"].values.flatten()
        slope_samples = {pred: posterior[f"slope_{pred}"].values.flatten() for pred in predictors}
        idx = np.random.choice(len(a_samples), n_regression_lines, replace=False)

        fig, axes = plt.subplots(1, len(predictors), figsize=(6 * len(predictors), 5))
        if len(predictors) == 1:
            axes = [axes]

        for ax, pred in zip(axes, predictors):
            x = data[pred].values
            ax.scatter(x, y, alpha=0.5)
            x_grid = np.linspace(x.min(), x.max(), 100)

            # For each posterior draw, compute the line for this predictor
            # holding other predictors at their mean
            other_contribution = np.zeros(len(a_samples))
            for other_pred in predictors:
                if other_pred != pred:
                    other_contribution += slope_samples[other_pred] * data[other_pred].mean()

            for i in idx:
                y_line = a_samples[i] + other_contribution[i] + slope_samples[pred][i] * x_grid
                ax.plot(x_grid, y_line, alpha=0.05, color="gray")

            # Mean regression line
            mean_other = sum(slope_samples[op].mean() * data[op].mean() for op in predictors if op != pred)
            y_mean = a_samples.mean() + mean_other + slope_samples[pred].mean() * x_grid
            ax.plot(x_grid, y_mean, color="red")
            ax.set_xlabel(pred)
            ax.set_ylabel(outcome)
            ax.set_title(f"{outcome} vs {pred} (others at mean)")

        plt.tight_layout()
        plt.show()

    return trace
In [4]:
kidiq = pd.read_csv('../ros_data/kidiq.csv', skiprows=0)
display(kidiq.head())

fit_and_plot_lm(kidiq, ['mom_hs'], 'kid_score', add_constant=True, show_plot=True, scatter_kws=None, line_kws=None)

fit_and_plot_bayes(kidiq, 'mom_hs', 'kid_score',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)

fit_and_plot_bayes(kidiq, 'mom_iq', 'kid_score',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)
   kid_score  mom_hs      mom_iq  mom_work  mom_age
0         65       1  121.117529         4       27
1         98       1   89.361882         4       25
2         85       1  115.443165         4       27
3         83       1   99.449639         3       25
4        115       1   92.745710         4       27
                            OLS Regression Results                            
==============================================================================
Dep. Variable:              kid_score   R-squared:                       0.056
Model:                            OLS   Adj. R-squared:                  0.054
Method:                 Least Squares   F-statistic:                     25.69
Date:                Thu, 02 Apr 2026   Prob (F-statistic):           5.96e-07
Time:                        07:54:29   Log-Likelihood:                -1911.8
No. Observations:                 434   AIC:                             3828.
Df Residuals:                     432   BIC:                             3836.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
mom_hs        11.7713      2.322      5.069      0.000       7.207      16.336
const         77.5484      2.059     37.670      0.000      73.502      81.595
==============================================================================
Omnibus:                       11.077   Durbin-Watson:                   1.464
Prob(Omnibus):                  0.004   Jarque-Bera (JB):               11.316
Skew:                          -0.373   Prob(JB):                      0.00349
Kurtosis:                       2.738   Cond. No.                         4.11
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: kid_score = 77.55 + 11.77*mom_hs
Residual std dev (σ): 19.85 ± 0.68
MAD of residuals: 19.27
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_mom_hs, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
                mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept     77.437  2.067    73.272     81.386      0.036    0.025   
slope_mom_hs  11.872  2.323     7.420     16.505      0.039    0.028   
sigma         19.918  0.689    18.633     21.319      0.011    0.008   

              ess_bulk  ess_tail  r_hat  
intercept       3368.0    3999.0    1.0  
slope_mom_hs    3463.0    4342.0    1.0  
sigma           4309.0    4029.0    1.0  

Regression formula: kid_score = 77.44 + 11.87*mom_hs
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_mom_iq, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 2 seconds.
                mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept     25.447  5.870    13.940     36.734      0.113    0.094   
slope_mom_iq   0.614  0.058     0.501      0.725      0.001    0.001   
sigma         18.311  0.618    17.103     19.495      0.010    0.008   

              ess_bulk  ess_tail  r_hat  
intercept       2690.0    2720.0    1.0  
slope_mom_iq    2699.0    2804.0    1.0  
sigma           4030.0    3707.0    1.0  

Regression formula: kid_score = 25.45 + 0.61*mom_iq
Out[4]:
arviz.InferenceData with groups: posterior (intercept, slope_mom_iq, sigma; 4 chains × 2000 draws), sample_stats, and observed_data (y: 434 observations).

kid_score = 77.55 + 11.77*mom_hs

The intercept 77.55 is the average score for children whose mothers did not graduate from high school (mom_hs = 0). The slope of 11.77 means that children whose mothers graduated from high school score, on average, 11.77 points higher than children whose mothers did not.
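With a single binary predictor, the slope is just the difference between the two group means and the intercept is the mean of the reference group. A minimal simulated check (illustrative data standing in for kidiq; the coefficients 77.5 and 11.8 are chosen to resemble the fit above):

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(7)
sim = pd.DataFrame({"mom_hs": rng.integers(0, 2, 400)})
sim["kid_score"] = 77.5 + 11.8 * sim["mom_hs"] + rng.normal(0, 19, 400)

fit = smf.ols("kid_score ~ mom_hs", data=sim).fit()
group_means = sim.groupby("mom_hs")["kid_score"].mean()

# For OLS with one binary predictor these identities hold exactly:
# intercept == mean of the mom_hs=0 group, slope == difference in group means
print(round(fit.params["Intercept"], 2), round(fit.params["mom_hs"], 2))
```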

A single continuous predictor¶

kid_score = 25.31 + 0.62*mom_iq

The intercept 25.31 is the predicted score for a child whose mother has an IQ of 0, which is not meaningful on its own since no mother has an IQ of 0. The slope of 0.62 means that each additional point of mother's IQ is associated with an expected increase of 0.62 points in the child's score.
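Similarly, the slope of a continuous predictor is the difference in predicted scores between two mothers whose IQs differ by one point. A sketch with simulated data (coefficients chosen to resemble the fit above, not the real kidiq file):

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(0)
sim = pd.DataFrame({"mom_iq": rng.normal(100, 15, 500)})
sim["kid_score"] = 25 + 0.6 * sim["mom_iq"] + rng.normal(0, 18, 500)

fit = smf.ols("kid_score ~ mom_iq", data=sim).fit()

# Predicted scores at mom_iq = 100 vs 101 differ by exactly the slope
pred = fit.predict(pd.DataFrame({"mom_iq": [100, 101]}))
print(round(fit.params["mom_iq"], 3), round(pred.iloc[1] - pred.iloc[0], 3))
```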

Include both predictors in the model¶

In [5]:
fit_and_plot_bayes(kidiq, 'mom_hs', 'mom_iq', 'kid_score',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_mom_hs, slope_mom_iq, sigma]
/opt/anaconda3/envs/ros_pymc/lib/python3.12/site-packages/rich/live.py:260: UserWarning: install "ipywidgets" for 
Jupyter support
  warnings.warn('install "ipywidgets" for Jupyter support')

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 3 seconds.
                mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept     25.477  5.942    12.910     36.462      0.096    0.080   
slope_mom_hs   5.957  2.242     1.493     10.293      0.031    0.029   
slope_mom_iq   0.566  0.061     0.447      0.690      0.001    0.001   
sigma         18.181  0.611    17.039     19.407      0.009    0.008   

              ess_bulk  ess_tail  r_hat  
intercept       3856.0    3739.0    1.0  
slope_mom_hs    5325.0    3939.0    1.0  
slope_mom_iq    3773.0    3815.0    1.0  
sigma           4799.0    4201.0    1.0  

Regression formula: kid_score = 25.48 + 5.96*mom_hs + 0.57*mom_iq
Out[5]:
arviz.InferenceData with groups: posterior, sample_stats, observed_data

In [6]:
# plot mother iq vs kid score colored by whether mother graduated high school and with regression lines for each group
sns.lmplot(data=kidiq, x='mom_iq', y='kid_score', hue='mom_hs', ci=None)
plt.title('Kid Score vs Mom IQ by High School Graduation')
plt.show()

# lm for mothers who graduated high school - filter data to only include those with mom_hs == 1
kidiq_graduated = kidiq[kidiq['mom_hs'] == 1]
kidiq_not_graduated = kidiq[kidiq['mom_hs'] == 0]

fit_and_plot_lm(kidiq_graduated, ['mom_iq'], 'kid_score', add_constant=True, show_plot=True, scatter_kws=None, line_kws=None)

fit_and_plot_lm(kidiq_not_graduated, ['mom_iq'], 'kid_score', add_constant=True, show_plot=True, scatter_kws=None, line_kws=None)
                            OLS Regression Results                            
==============================================================================
Dep. Variable:              kid_score   R-squared:                       0.143
Model:                            OLS   Adj. R-squared:                  0.140
Method:                 Least Squares   F-statistic:                     56.42
Date:                Thu, 02 Apr 2026   Prob (F-statistic):           5.24e-13
Time:                        07:54:38   Log-Likelihood:                -1462.0
No. Observations:                 341   AIC:                             2928.
Df Residuals:                     339   BIC:                             2936.
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
mom_iq         0.4846      0.065      7.511      0.000       0.358       0.612
const         39.7862      6.663      5.971      0.000      26.679      52.893
==============================================================================
Omnibus:                        5.765   Durbin-Watson:                   1.612
Prob(Omnibus):                  0.056   Jarque-Bera (JB):                5.908
Skew:                          -0.314   Prob(JB):                       0.0521
Kurtosis:                       2.851   Cond. No.                         720.
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: kid_score = 39.79 + 0.48*mom_iq
Residual std dev (σ): 17.66 ± 0.68
MAD of residuals: 15.55
                            OLS Regression Results                            
==============================================================================
Dep. Variable:              kid_score   R-squared:                       0.294
Model:                            OLS   Adj. R-squared:                  0.286
Method:                 Least Squares   F-statistic:                     37.87
Date:                Thu, 02 Apr 2026   Prob (F-statistic):           2.00e-08
Time:                        07:54:38   Log-Likelihood:                -405.14
No. Observations:                  93   AIC:                             814.3
Df Residuals:                      91   BIC:                             819.3
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
mom_iq         0.9689      0.157      6.154      0.000       0.656       1.282
const        -11.4820     14.601     -0.786      0.434     -40.485      17.521
==============================================================================
Omnibus:                        2.584   Durbin-Watson:                   1.924
Prob(Omnibus):                  0.275   Jarque-Bera (JB):                2.360
Skew:                          -0.389   Prob(JB):                        0.307
Kurtosis:                       2.950   Cond. No.                         685.
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Formula: kid_score = -11.48 + 0.97*mom_iq
Residual std dev (σ): 19.07 ± 1.41
MAD of residuals: 18.01
Understanding the fitted model¶

kid_score = 25.48 + 5.96*mom_hs + 0.57*mom_iq + error

Intercept = when the mother did not complete high school and has an IQ of 0, the expected score is 25.48. Not directly meaningful, since no mother has an IQ of 0.

Coefficient for mom_hs = 5.96: comparing children whose mothers have the same IQ, those whose mothers graduated from high school are expected to score 5.96 points higher than those whose mothers did not.

Coefficient for mom_iq = 0.57: comparing children whose mothers have the same high-school graduation status, each additional point of mother's IQ is associated with an expected increase of 0.57 points in the child's score.

We can also look at the separate regressions for moms who did and did not graduate from high school to see how the relationship between mom_iq and kid_score differs by mom_hs status.

mom_hs = 0: kid_score = -11.48 + 0.97*mom_iq
mom_hs = 1: kid_score = 39.79 + 0.48*mom_iq

When should we look for interactions?¶

Interactions can be important. A practical rule of thumb is to look for interactions among predictors whose main effects are large in the model without interactions.

For example, smoking is strongly associated with cancer, and it is crucial to adjust for other factors such as radon exposure. People who both smoke and are exposed to radon may have a much higher risk of cancer than those who only smoke or are only exposed to radon; this would be an interaction between smoking and radon exposure.

We can fit models separately for smokers and non-smokers to see if the relationship between radon exposure and cancer risk differs by smoking status. We can also include an interaction term in a single model to formally test for the presence of an interaction between smoking and radon exposure.
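A sketch of that single-model approach, using the statsmodels formula interface on hypothetical simulated data (the variable names and effect sizes below are made up for illustration; `smoker * radon` expands to both main effects plus the `smoker:radon` product term):

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(1)
n = 1000
df = pd.DataFrame({"smoker": rng.integers(0, 2, n),
                   "radon": rng.gamma(2, 2, n)})
# true model has a strong interaction: radon is far more harmful for smokers
df["risk"] = (1 + 2 * df["smoker"] + 0.5 * df["radon"]
              + 1.5 * df["smoker"] * df["radon"] + rng.normal(0, 1, n))

fit = smf.ols("risk ~ smoker * radon", data=df).fit()

# the slope of radon differs by smoking status:
slope_nonsmokers = fit.params["radon"]
slope_smokers = fit.params["radon"] + fit.params["smoker:radon"]
print(round(slope_nonsmokers, 2), round(slope_smokers, 2))
```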

Interpreting regression coefficients in the presence of interactions¶

We can more easily interpret models with interactions by centering the predictors, typically at their means, so that each main effect describes a comparison at the average value of the other predictors rather than at zero.
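A sketch of why centering helps, again on simulated data (values illustrative): the interaction coefficient is unchanged, but after centering the main effects describe slopes at the average value of the other predictor rather than at zero:

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(2)
df = pd.DataFrame({"mom_hs": rng.integers(0, 2, 500),
                   "mom_iq": rng.normal(100, 15, 500)})
df["kid_score"] = (-11 + 51 * df["mom_hs"] + 0.97 * df["mom_iq"]
                   - 0.48 * df["mom_hs"] * df["mom_iq"]
                   + rng.normal(0, 18, 500))

# center each predictor at its sample mean
df["c_mom_hs"] = df["mom_hs"] - df["mom_hs"].mean()
df["c_mom_iq"] = df["mom_iq"] - df["mom_iq"].mean()

raw = smf.ols("kid_score ~ mom_hs * mom_iq", data=df).fit()
ctr = smf.ols("kid_score ~ c_mom_hs * c_mom_iq", data=df).fit()

# interaction term is identical; main effects now refer to average predictors
print(round(raw.params["mom_hs:mom_iq"], 3),
      round(ctr.params["c_mom_hs:c_mom_iq"], 3))
```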

10.2 Interpreting regression coefficients¶

Dummy variables are used to represent categorical predictors in regression models.
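As a sketch of how this looks in practice (simulated stand-in data): with the statsmodels formula interface, wrapping a column in `C()` expands it into 0/1 dummy columns, dropping one reference level that is absorbed into the intercept:

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(3)
df = pd.DataFrame({
    "ethnicity": rng.choice(["White", "Black", "Hispanic", "Other"], 300),
    "height": rng.normal(66, 4, 300),
})
df["weight"] = -160 + 4.8 * df["height"] + rng.normal(0, 29, 300)

fit = smf.ols("weight ~ height + C(ethnicity)", data=df).fit()

# One dummy per non-reference level; 'Black' (alphabetically first) is the
# reference level and lives in the intercept
print(fit.params.index.tolist())
```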

In [14]:
earnings = pd.read_csv('../ros_data/earnings.csv', skiprows=0)
# earnings['earn_k'] = earnings['earn'] / 1000
# earnings['c_height'] = earnings['height'] - 66 # Center height around 66 inches for better interpretability of the intercept

display(earnings.head())

earnings_clean = earnings.dropna(subset=['height', 'weight'])

# Predict weight in pounds from height
# fit_and_plot_lm(earnings_clean, ['height'], 'weight', add_constant=True, show_plot=True, scatter_kws=None, line_kws=None)

earnings_model = fit_and_plot_bayes(earnings_clean, 'height', 'weight',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)

print(earnings_model)
   height  weight  male     earn  earnk ethnicity  education  mother_education  father_education  walk  exercise  smokenow  tense  angry  age
0      74   210.0     1  50000.0   50.0     White       16.0              16.0              16.0     3         3       2.0    0.0    0.0   45
1      66   125.0     0  60000.0   60.0     White       16.0              16.0              16.0     6         5       1.0    0.0    0.0   58
2      64   126.0     0  30000.0   30.0     White       16.0              16.0              16.0     8         1       2.0    1.0    1.0   29
3      65   200.0     0  25000.0   25.0     White       17.0              17.0               NaN     8         1       2.0    0.0    0.0   57
4      63   110.0     0  50000.0   50.0     Other      16.0              16.0              16.0     5         6       2.0    0.0    0.0   91
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_height, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 5 seconds.
                 mean      sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept    -163.790  11.428  -185.933   -141.570      0.231    0.170   
slope_height    4.807   0.171     4.470      5.135      0.003    0.003   
sigma          28.978   0.487    28.075     29.950      0.008    0.008   

              ess_bulk  ess_tail  r_hat  
intercept       2447.0    2858.0    1.0  
slope_height    2461.0    2942.0    1.0  
sigma           3339.0    3024.0    1.0  

Regression formula: weight = -163.79 + 4.81*height
Inference data with groups:
	> posterior
	> sample_stats
	> observed_data
In [15]:
# Posterior prediction for a new observation (PyMC equivalent of R's posterior_predict)
# Predict weight for a person who is 66 inches tall
new_height = 66

posterior = earnings_model.posterior
intercept_samples = posterior["intercept"].values.flatten()  # all posterior draws for intercept (e.g. 4000 samples)
slope_samples = posterior["slope_height"].values.flatten()  # all posterior draws for slope
sigma_samples = posterior["sigma"].values.flatten()  # all posterior draws for residual std dev

# Point prediction for each posterior draw: E[weight | height=66]
mu_samples = intercept_samples + slope_samples * new_height  # linear predictor evaluated at new_height for each draw

# Full posterior predictive: sample a new observation from Normal(mu, sigma) for each draw
# This adds residual noise on top of the mean prediction, capturing both:
#   1. Parameter uncertainty (from the spread of intercept/slope draws)
#   2. Individual-level variation (from sigma)
# This is what makes it a *prediction* interval rather than just a *confidence* interval
pred_samples = np.random.normal(mu_samples, sigma_samples)

print(f"Predicted weight for height={new_height} inches:")
print(f"  Mean: {pred_samples.mean():.1f} lbs")
print(f"  Median: {np.median(pred_samples):.1f} lbs")
print(f"  50% interval: [{np.percentile(pred_samples, 25):.1f}, {np.percentile(pred_samples, 75):.1f}]")
print(f"  95% interval: [{np.percentile(pred_samples, 2.5):.1f}, {np.percentile(pred_samples, 97.5):.1f}]")

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(pred_samples, bins=50, density=True, alpha=0.7)
ax.axvline(pred_samples.mean(), color='red', linestyle='--', label=f'Mean: {pred_samples.mean():.1f}')
ax.set_xlabel('Predicted weight (lbs)')
ax.set_ylabel('Density')
ax.set_title(f'Posterior predictive distribution for height = {new_height} inches')
ax.legend()
plt.show()
Predicted weight for height=66 inches:
  Mean: 154.5 lbs
  Median: 154.5 lbs
  50% interval: [134.7, 174.6]
  95% interval: [96.6, 211.3]
[Image: histogram of the posterior predictive distribution]
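To make the difference between the two kinds of interval concrete, here is a minimal sketch using independently simulated draws whose means and standard deviations roughly match the summary above. It is illustrative only: it ignores the strong posterior correlation between intercept and slope, so the interval for the mean comes out wider than the model's actual one, but the qualitative point stands.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8000  # same number of draws as the sampler produced

# Illustrative, independent draws roughly matching the posterior summary above
intercept = rng.normal(-163.79, 11.43, n)
slope = rng.normal(4.81, 0.17, n)
sigma = rng.normal(28.98, 0.49, n)

mu = intercept + slope * 66   # uncertainty in the expected weight only
pred = rng.normal(mu, sigma)  # plus individual-level residual noise

# The prediction interval is necessarily wider than the interval for the mean
print(np.percentile(mu, [2.5, 97.5]))
print(np.percentile(pred, [2.5, 97.5]))
```

The spread of `mu` reflects parameter uncertainty alone (a confidence-style interval); `pred` adds the `sigma` noise and so is always wider (a prediction interval).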
In [18]:
earnings['c_height'] = earnings['height'] - 66 # Center height around 66 inches for better interpretability of the intercept

display(earnings.head())

earnings_clean = earnings.dropna(subset=['c_height', 'weight', 'male'])

# Predict weight in pounds from height and male

earnings_model = fit_and_plot_bayes(earnings_clean, 'male', 'c_height', 'weight',
                       intercept_mu=0, intercept_sigma=50,
                       slope_mu=0, slope_sigma=50,
                       sigma_sigma=50,
                       samples=2000, tune=1000, hdi_prob=0.95,
                       show_trace=True, show_forest=False,
                       show_posterior=False, show_regression=True,
                       n_regression_lines=100)
| | height | weight | male | earn | earnk | ethnicity | education | mother_education | father_education | walk | exercise | smokenow | tense | angry | age | c_height |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 74 | 210.0 | 1 | 50000.0 | 50.0 | White | 16.0 | 16.0 | 16.0 | 3 | 3 | 2.0 | 0.0 | 0.0 | 45 | 8 |
| 1 | 66 | 125.0 | 0 | 60000.0 | 60.0 | White | 16.0 | 16.0 | 16.0 | 6 | 5 | 1.0 | 0.0 | 0.0 | 58 | 0 |
| 2 | 64 | 126.0 | 0 | 30000.0 | 30.0 | White | 16.0 | 16.0 | 16.0 | 8 | 1 | 2.0 | 1.0 | 1.0 | 29 | -2 |
| 3 | 65 | 200.0 | 0 | 25000.0 | 25.0 | White | 17.0 | 17.0 | NaN | 8 | 1 | 2.0 | 0.0 | 0.0 | 57 | -1 |
| 4 | 63 | 110.0 | 0 | 50000.0 | 50.0 | Other | 16.0 | 16.0 | 16.0 | 5 | 6 | 2.0 | 0.0 | 0.0 | 91 | -3 |
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, slope_male, slope_c_height, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 1 seconds.
                   mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept       149.499  0.931   147.664    151.292      0.013    0.009   
slope_male       11.881  1.962     8.199     15.786      0.029    0.022   
slope_c_height    3.883  0.248     3.431      4.394      0.004    0.003   
sigma            28.696  0.473    27.790     29.653      0.006    0.006   

                ess_bulk  ess_tail  r_hat  
intercept         5292.0    5856.0    1.0  
slope_male        4571.0    5203.0    1.0  
slope_c_height    4731.0    5896.0    1.0  
sigma             5858.0    5555.0    1.0  

Regression formula: weight = 149.50 + 11.88*male + 3.88*c_height
[Image: trace plots]
[Image: fitted regression lines]

The coefficient of about 12 for `male` tells us that, comparing a man to a woman of the same height, the man is expected to weigh about 12 pounds more on average.
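Plugging the posterior means from the summary above into the regression formula shows this comparison directly; a quick sketch, treating the posterior means as point estimates (the `expected_weight` helper is illustrative):

```python
# Posterior means from the summary above, used as point estimates
intercept, slope_male, slope_c_height = 149.50, 11.88, 3.88

def expected_weight(height_in, male):
    # c_height = height - 66, since height was centered at 66 inches
    return intercept + slope_male * male + slope_c_height * (height_in - 66)

man = expected_weight(68, male=1)
woman = expected_weight(68, male=0)
print(round(man - woman, 2))  # 11.88 -- the height terms cancel, leaving the male coefficient
```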

In [22]:
# Predict weight of a 70 inch woman:

# Posterior prediction for a new observation (PyMC equivalent of R's posterior_predict)
# Predict weight for a woman who is 70 inches tall
new_height = 70

posterior = earnings_model.posterior
intercept_samples = posterior["intercept"].values.flatten()  # all posterior draws for intercept
slope_height_samples = posterior["slope_c_height"].values.flatten()  # all posterior draws for the height slope
slope_male_samples = posterior["slope_male"].values.flatten()  # all posterior draws for the male coefficient
sigma_samples = posterior["sigma"].values.flatten()  # all posterior draws for residual std dev

# Point prediction for each posterior draw: E[weight | height=70, male=0]
# height was centered at 66 inches, and male=0 selects a woman
mu_samples = intercept_samples + slope_height_samples * (new_height - 66) + slope_male_samples * 0

# Full posterior predictive: sample a new observation from Normal(mu, sigma) for each draw
# This adds residual noise on top of the mean prediction, capturing both:
#   1. Parameter uncertainty (from the spread of intercept/slope draws)
#   2. Individual-level variation (from sigma)
# This is what makes it a *prediction* interval rather than just a *confidence* interval
pred_samples = np.random.normal(mu_samples, sigma_samples)

print(f"Predicted weight for height={new_height} inches:")
print(f"  Mean: {pred_samples.mean():.1f} lbs")
print(f"  Median: {np.median(pred_samples):.1f} lbs")
print(f"  50% interval: [{np.percentile(pred_samples, 25):.1f}, {np.percentile(pred_samples, 75):.1f}]")
print(f"  95% interval: [{np.percentile(pred_samples, 2.5):.1f}, {np.percentile(pred_samples, 97.5):.1f}]")

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(pred_samples, bins=50, density=True, alpha=0.7)
ax.axvline(pred_samples.mean(), color='red', linestyle='--', label=f'Mean: {pred_samples.mean():.1f}')
ax.set_xlabel('Predicted weight (lbs)')
ax.set_ylabel('Density')
ax.set_title(f'Posterior predictive distribution for height = {new_height} inches')
ax.legend()
plt.show()
Predicted weight for height=70 inches:
  Mean: 164.6 lbs
  Median: 164.7 lbs
  50% interval: [145.3, 184.0]
  95% interval: [107.6, 221.2]
[Image: histogram of the posterior predictive distribution]
Using indicator variables for multiple levels of a categorical predictor¶

Add ethnicity

In [29]:
# Create dummy variables for ethnicity
# Drop any existing eth_ columns first to avoid duplicates on re-run
earnings_clean = earnings_clean[[c for c in earnings_clean.columns if not c.startswith('eth_')]]
eth_dummies = pd.get_dummies(earnings_clean['ethnicity'], prefix='eth', dtype=int)
eth_dummies = eth_dummies.drop(columns='eth_Black')  # Black is reference group
earnings_clean = pd.concat([earnings_clean, eth_dummies], axis=1)
print(eth_dummies.columns.tolist())

with pm.Model() as earnings_eth_model:
    intercept = pm.Normal("intercept", mu=0, sigma=50)
    male = pm.Normal("male", mu=0, sigma=50)
    c_height = pm.Normal("c_height", mu=0, sigma=50)
    eth_Hispanic = pm.Normal("eth_Hispanic", mu=0, sigma=50)
    eth_Other = pm.Normal("eth_Other", mu=0, sigma=50)
    eth_White = pm.Normal("eth_White", mu=0, sigma=50)
    sigma = pm.HalfNormal("sigma", sigma=50)

    # intercept = expected weight for a Black female of average height
    # each eth_ coefficient = difference from Black reference group
    mu = (intercept
          + male * earnings_clean['male'].values
          + c_height * earnings_clean['c_height'].values
          + eth_Hispanic * earnings_clean['eth_Hispanic'].values
          + eth_Other * earnings_clean['eth_Other'].values
          + eth_White * earnings_clean['eth_White'].values)

    y = pm.Normal("y", mu=mu, sigma=sigma, observed=earnings_clean['weight'].values)
    trace_eth = pm.sample(2000, tune=1000)

print(pm.summary(trace_eth, hdi_prob=0.95))
az.plot_trace(trace_eth)
plt.tight_layout()
plt.show()
Initializing NUTS using jitter+adapt_diag...
['eth_Hispanic', 'eth_Other', 'eth_White']
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, male, c_height, eth_Hispanic, eth_Other, eth_White, sigma]

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 3 seconds.
                 mean     sd  hdi_2.5%  hdi_97.5%  mcse_mean  mcse_sd  \
intercept     153.991  2.253   149.426    158.296      0.035    0.027   
male           12.173  2.004     8.175     16.079      0.028    0.022   
c_height        3.845  0.254     3.354      4.354      0.004    0.003   
eth_Hispanic   -5.830  3.529   -12.911      0.919      0.047    0.036   
eth_Other     -11.851  5.186   -21.855     -1.608      0.069    0.055   
eth_White      -4.864  2.281    -9.438     -0.381      0.035    0.026   
sigma          28.644  0.473    27.715     29.552      0.006    0.005   

              ess_bulk  ess_tail  r_hat  
intercept       4101.0    4227.0    1.0  
male            5147.0    5286.0    1.0  
c_height        5274.0    5557.0    1.0  
eth_Hispanic    5525.0    5765.0    1.0  
eth_Other       5699.0    5281.0    1.0  
eth_White       4231.0    4547.0    1.0  
sigma           6828.0    5222.0    1.0  
[Image: trace plots]

When comparing a Black person to a Hispanic person of the same sex and height, the Hispanic person is expected to be about 5.8 pounds lighter on average.

Dummy variables act like switches that turn on or off the effect of a particular category. The coefficients for the dummy variables represent the difference in the response variable between the category represented by the dummy variable and the reference category, holding all else constant.

The reference category is the case where all dummy variables are 0. Here that is Black respondents, since the eth_Black dummy was dropped.
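As a sketch of this switch behaviour, we can plug the posterior means from the summary above into the formula (the `expected_weight` helper and the rounded coefficients are illustrative point estimates):

```python
# Rounded posterior means from the summary above
coef = {"intercept": 153.99, "male": 12.17, "c_height": 3.85,
        "eth_Hispanic": -5.83, "eth_Other": -11.85, "eth_White": -4.86}

def expected_weight(male, c_height, ethnicity):
    # The reference group (Black) has every ethnicity dummy switched off
    mu = coef["intercept"] + coef["male"] * male + coef["c_height"] * c_height
    if ethnicity != "Black":
        mu += coef[f"eth_{ethnicity}"]  # switch on the matching dummy
    return mu

# A man of average height (c_height = 0):
print(round(expected_weight(1, 0, "Black"), 2))     # 166.16
print(round(expected_weight(1, 0, "Hispanic"), 2))  # 160.33 = 166.16 - 5.83
```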

Using an index variable to access a group-level predictor¶

Sometimes we have predictors measured at the group level rather than the individual level. For example, suppose we have data on 1000 students from 20 schools and want to include the average parental income at each school as a predictor. We can create an index variable that assigns a unique number to each school, then use it to merge the group-level predictor (average income) with the individual-level student data. In simple terms, every student in the same school gets that school's average income, and the resulting column enters the regression like any other predictor.
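A minimal pandas sketch of that merge, with hypothetical column names and made-up numbers:

```python
import pandas as pd

# Hypothetical individual-level data: each student carries a school index
students = pd.DataFrame({"student": [1, 2, 3, 4],
                         "school": [0, 0, 1, 1],
                         "score": [72, 85, 90, 64]})

# Hypothetical group-level data: one row per school
schools = pd.DataFrame({"school": [0, 1],
                        "avg_parent_income": [48.0, 61.5]})

# The index variable `school` attaches the group-level predictor to every student
merged = students.merge(schools, on="school", how="left")
print(merged["avg_parent_income"].tolist())  # [48.0, 48.0, 61.5, 61.5]
```

After the merge, `avg_parent_income` is an ordinary column of the student-level table and can be passed to the regression alongside the individual-level predictors.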

10.3 Interactions¶

10.4 Indicator variables¶

10.5 Formulating paired or blocked designs as a regression problem¶

10.6 Example: uncertainty in predicting congressional elections¶

10.7 Mathematical notation and statistical inference¶

10.8 Weighted regression¶

10.9 Fitting the same model to many datasets¶

10.10 Bibliographic note¶

10.11 Exercises¶