SoftmaxWeightedSumFitter

class causalpy.pymc_models.SoftmaxWeightedSumFitter

Weighted sum model with softmax-over-Normal-logits parameterization.

An alternative to WeightedSumFitter for synthetic control experiments. Instead of a Dirichlet prior on the simplex weights, this model places Normal priors on unconstrained logits and maps them to the simplex via the softmax transform. The first logit is pinned to zero to remove the softmax’s shift invariance.

Defines the PyMC model:

\[\begin{split}
\tilde{\beta}_1 &= 0 \\
\tilde{\beta}_{j} &\sim \mathrm{Normal}(0, \sigma) \quad j = 2, \ldots, N \\
\beta &= \mathrm{softmax}(\tilde{\beta}) \\
\mu &= X \cdot \beta \\
y &\sim \mathrm{Normal}(\mu, \sigma_y)
\end{split}\]
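As a concrete illustration of the transform above, this standalone NumPy sketch (illustrative only, not part of causalpy's API) shows how pinning the first logit to zero and applying a numerically stable softmax yields weights on the simplex:

```python
import numpy as np

def pinned_softmax(free_logits):
    # Prepend the pinned zero logit, then apply a numerically
    # stable softmax (subtract the max before exponentiating).
    logits = np.concatenate([[0.0], np.asarray(free_logits, dtype=float)])
    z = np.exp(logits - logits.max())
    return z / z.sum()

# Three free logits produce four simplex weights
# (one per control unit; the first logit is pinned to zero).
weights = pinned_softmax([0.5, -1.0, 2.0])
print(weights)  # non-negative entries that sum to one
```

Because the softmax is invariant to adding a constant to all logits, pinning the first logit removes that redundant degree of freedom, so only \(N - 1\) free parameters are sampled.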

Notes

The softmax-Normal parameterization and the Dirichlet prior used by WeightedSumFitter both produce simplex-valued weights, but they encode different prior beliefs and regularization behavior:

  • Dirichlet (WeightedSumFitter): With concentration a=1 the prior is uniform on the simplex. Setting a < 1 encourages sparsity (weights concentrating on fewer donors), while a > 1 encourages uniformity. Regularization strength is controlled by the concentration parameter.

  • Softmax-Normal (this class): The prior scale sigma on the logits controls regularization. Small sigma shrinks logits toward zero, producing near-uniform weights (DiD-like behavior). Large sigma allows the data to concentrate weight on a few well-matching control units (SC-like behavior). The default sigma=1.0 provides moderate regularization.

This parameterization is motivated by the Bayesian Synthetic Difference-in-Differences (SDiD) formulation, where the prior scale plays the role of the \(\ell_2\) regularization parameter \(\zeta\) in the frequentist SDiD of Arkhangelsky et al. (2021).
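The regularization behavior described above can be mimicked outside the model by rescaling a fixed set of logits before the softmax. This illustrative NumPy snippet (not causalpy code) shows near-uniform weights for a small scale and concentrated weights for a large one:

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())  # numerically stable softmax
    return z / z.sum()

rng = np.random.default_rng(0)
raw = rng.normal(size=6)  # stand-in for data-favoured logit directions

spreads = {}
for sigma in (0.1, 1.0, 10.0):
    # Pin the first logit at zero and scale the rest by sigma,
    # mimicking the effect of the prior scale on the logits.
    w = softmax(np.concatenate([[0.0], sigma * raw]))
    spreads[sigma] = w.max() - w.min()
    print(f"sigma={sigma:>4}: weight spread (max - min) = {spreads[sigma]:.3f}")
# Small sigma -> near-uniform weights (DiD-like);
# large sigma -> weight concentrates on a few units (SC-like).
```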

Example

>>> import causalpy as cp
>>> import numpy as np
>>> import xarray as xr
>>> from causalpy.pymc_models import SoftmaxWeightedSumFitter
>>> sc = cp.load_data("sc")
>>> control_units = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
>>> X = xr.DataArray(
...     sc[control_units].values,
...     dims=["obs_ind", "coeffs"],
...     coords={"obs_ind": sc.index, "coeffs": control_units},
... )
>>> y = xr.DataArray(
...     sc['actual'].values.reshape((sc.shape[0], 1)),
...     dims=["obs_ind", "treated_units"],
...     coords={"obs_ind": sc.index, "treated_units": ["actual"]},
... )
>>> coords = {
...     "coeffs": control_units,
...     "treated_units": ["actual"],
...     "obs_ind": np.arange(sc.shape[0]),
... }
>>> wsf = SoftmaxWeightedSumFitter(sample_kwargs={"progressbar": False})
>>> wsf.fit(X, y, coords=coords)
Inference data...

Methods

SoftmaxWeightedSumFitter.__init__([...])

Initialize the model with optional sampling kwargs and priors.

SoftmaxWeightedSumFitter.add_coord(name[, ...])

Register a dimension coordinate with the model.

SoftmaxWeightedSumFitter.add_coords(coords, *)

Vectorized version of Model.add_coord.

SoftmaxWeightedSumFitter.add_named_variable(var)

Add a random graph variable to the named variables of the model.

SoftmaxWeightedSumFitter.build_model(X, y, ...)

Build the PyMC model with softmax-parameterized simplex weights.

SoftmaxWeightedSumFitter.calculate_cumulative_impact(impact)

SoftmaxWeightedSumFitter.calculate_impact(...)

Calculate the causal impact as the difference between observed and predicted values.

SoftmaxWeightedSumFitter.check_start_vals(...)

Check that the logp is defined and finite at the starting point.

SoftmaxWeightedSumFitter.compile_d2logp([...])

Compiled log probability density hessian function.

SoftmaxWeightedSumFitter.compile_dlogp([...])

Compiled log probability density gradient function.

SoftmaxWeightedSumFitter.compile_fn(outs, *)

SoftmaxWeightedSumFitter.compile_logp([...])

Compiled log probability density function.

SoftmaxWeightedSumFitter.copy()

Clone the model.

SoftmaxWeightedSumFitter.create_value_var(...)

Create a TensorVariable that will be used as the random variable's "value" in log-likelihood graphs.

SoftmaxWeightedSumFitter.d2logp([vars, ...])

Hessian of the model's log-probability with respect to the given variables.

SoftmaxWeightedSumFitter.debug([point, fn, ...])

Debug model function at point.

SoftmaxWeightedSumFitter.dlogp([vars, jacobian])

Gradient of the model's log-probability with respect to the given variables.

SoftmaxWeightedSumFitter.eval_rv_shapes()

Evaluate shapes of untransformed AND transformed free variables.

SoftmaxWeightedSumFitter.fit(X, y[, coords])

Draw samples from posterior, prior predictive, and posterior predictive distributions.

SoftmaxWeightedSumFitter.get_context([...])

SoftmaxWeightedSumFitter.initial_point([...])

Compute the initial point of the model.

SoftmaxWeightedSumFitter.logp([vars, ...])

Elemwise log-probability of the model.

SoftmaxWeightedSumFitter.logp_dlogp_function([...])

Compile a PyTensor function that computes logp and gradient.

SoftmaxWeightedSumFitter.make_obs_var(...)

Create a TensorVariable for an observed random variable.

SoftmaxWeightedSumFitter.name_for(name)

Check if the name has the model prefix and add it if needed.

SoftmaxWeightedSumFitter.name_of(name)

Check if the name has the model prefix and remove it if needed.

SoftmaxWeightedSumFitter.point_logps([...])

Compute the log probability of point for all random variables in the model.

SoftmaxWeightedSumFitter.predict(X[, ...])

Predict data given input data X.

SoftmaxWeightedSumFitter.print_coefficients(labels)

Print the model coefficients with their labels.

SoftmaxWeightedSumFitter.priors_from_data(X, y)

Set Normal prior for logit weights based on number of control units.

SoftmaxWeightedSumFitter.profile(outs, *[, ...])

Compile and profile a PyTensor function which returns outs and takes values of model vars as a dict as an argument.

SoftmaxWeightedSumFitter.register_data_var(data)

Register a data variable with the model.

SoftmaxWeightedSumFitter.register_rv(rv_var, ...)

Register an (un)observed random variable with the model.

SoftmaxWeightedSumFitter.replace_rvs_by_values(...)

Clone and replace random variables in graphs with their value variables.

SoftmaxWeightedSumFitter.score(X, y[, coords])

Score the Bayesian \(R^2\) given inputs X and outputs y.

SoftmaxWeightedSumFitter.set_data(name, values)

Change the values of a data variable in the model.

SoftmaxWeightedSumFitter.set_dim(name, ...)

Update a mutable dimension.

SoftmaxWeightedSumFitter.set_initval(rv_var, ...)

Set an initial value (strategy) for a random variable.

SoftmaxWeightedSumFitter.shape_from_dims(dims)

SoftmaxWeightedSumFitter.to_graphviz(*[, ...])

Produce a graphviz Digraph from a PyMC model.

Attributes

basic_RVs

List of random variables the model is defined in terms of.

continuous_value_vars

All the continuous value variables in the model.

coords

Coordinate values for model dimensions.

datalogp

PyTensor scalar of log-probability of the observed variables and potential terms.

default_priors

dim_lengths

The symbolic lengths of dimensions in the model.

discrete_value_vars

All the discrete value variables in the model.

isroot

observedlogp

PyTensor scalar of log-probability of the observed variables.

parent

potentiallogp

PyTensor scalar of log-probability of the Potential terms.

prefix

root

unobserved_RVs

List of all random variables, including deterministic ones.

unobserved_value_vars

List of all random variables (including untransformed projections), as well as deterministics used as inputs and outputs of the model's log-likelihood graph.

value_vars

List of unobserved random variables used as inputs to the model's log-likelihood (which excludes deterministics).

varlogp

PyTensor scalar of log-probability of the unobserved random variables (excluding deterministic).

varlogp_nojac

PyTensor scalar of log-probability of the unobserved random variables (excluding deterministic) without jacobian term.

__init__(sample_kwargs=None, priors=None)
Parameters:
  • sample_kwargs (dict[str, Any] | None) – Dictionary of kwargs that get unpacked and passed to the pymc.sample() function. Defaults to an empty dictionary if None.

  • priors (dict[str, Any] | None) – Dictionary of priors for the model. Defaults to None, in which case default priors are used.

Return type:

None

classmethod __new__(*args, **kwargs)