diabayes.typedefs.BayesianSolution

class diabayes.typedefs.BayesianSolution(log_params: RSFParams | CNSParams, log_likelihood: Float[Array, 'Nsteps'], nan_count: Float[Array, 'Nsteps'])

The result of a Bayesian inversion. This data class stores the particle trajectories (the equivalent of MCMC chains), the final equilibrium state, and various statistics, as well as helper routines to diagnose and visualise the results.

Examples

>>> result = solver.bayesian_inversion(...)
>>> print(result.nan_count)  # Check for NaN values
>>> print(result.statistics.Dc.median)  # Get the median of Dc
>>> result.plot_convergence()  # Visually inspect convergence
>>> result.cornerplot(nbins=15)  # Create a cornerplot

__init__(log_params: RSFParams | CNSParams, log_likelihood: Float[Array, 'Nsteps'], nan_count: Float[Array, 'Nsteps'])

Methods

__init__(log_params, log_likelihood, nan_count)

cornerplot([nbins])

Create a corner plot from the particle positions.

plot_convergence()

Helper routine to visualise the convergence of the inversion.

sample(solver, y0, t, friction_constants, ...)

Draw random samples from the posterior distribution.

Attributes

statistics

Pre-computed statistics of the parameters

log_params

The particles representing the (log-valued) parameters, including all the iteration steps

chains

The particle trajectories (to inspect convergence)

log_likelihood

The log-likelihood evolution

nan_count

The number of NaNs encountered during each iteration step

final_state

The final state of the particle swarm

cornerplot(nbins=20)

Create a corner plot from the particle positions. For N parameters, a corner plot is the lower triangle of an N x N grid of subplots. Each diagonal panel shows a histogram of the marginal posterior of the i-th parameter. Each panel (i, j) below the diagonal is a scatter plot of parameter i against parameter j, which reveals the covariance between the two quantities.
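The layout described above can be sketched with plain Matplotlib. This is a minimal illustration, not the diabayes implementation: simple_cornerplot is a hypothetical helper that takes an (Nparticles, Nparams) array of particle positions and a list of parameter names.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt


def simple_cornerplot(particles, names, nbins=20):
    """Hypothetical helper (not part of diabayes): lower-triangular corner plot.

    particles: array of shape (Nparticles, Nparams)
    names: list of Nparams parameter labels
    """
    nparams = particles.shape[1]
    fig, axes = plt.subplots(nparams, nparams,
                             figsize=(2.5 * nparams, 2.5 * nparams),
                             squeeze=False)
    for i in range(nparams):
        for j in range(nparams):
            ax = axes[i, j]
            if j > i:
                # Upper triangle mirrors the lower triangle, so hide it
                ax.set_visible(False)
            elif i == j:
                # Diagonal: histogram of the marginal distribution of parameter i
                ax.hist(particles[:, i], bins=nbins)
            else:
                # Below the diagonal: parameter i (y-axis) against parameter j (x-axis)
                ax.scatter(particles[:, j], particles[:, i], s=4)
            if i == nparams - 1:
                ax.set_xlabel(names[j])
            if j == 0 and i > 0:
                ax.set_ylabel(names[i])
    fig.tight_layout()
    return fig
```

For example, simple_cornerplot(np.random.default_rng(0).normal(size=(500, 3)), ["p0", "p1", "p2"], nbins=15) produces a 3 x 3 grid with six visible panels.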

Parameters:

nbins (int) – The number of bins to use for the histograms on the diagonal of the corner plot

Returns:

fig – The figure object that can be manipulated after creation (e.g. to save to disk)

Return type:

matplotlib.figure.Figure

plot_convergence()

Helper routine to visualise the convergence of the inversion. Convergence is achieved when the particles settle in an equilibrium position (i.e., they stop moving).
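One way to quantify "the particles stop moving" is to track the mean distance each particle travels between successive iterations of chains (shape Nsteps x Nparticles x Nparams). The helpers mean_displacement and has_converged, and the tolerance and window values, are hypothetical choices made for this sketch; they are not part of diabayes.

```python
import jax.numpy as jnp


def mean_displacement(chains):
    """Hypothetical helper: mean step size of the swarm between iterations.

    chains: array of shape (Nsteps, Nparticles, Nparams)
    returns: array of shape (Nsteps - 1,)
    """
    steps = jnp.diff(chains, axis=0)        # (Nsteps-1, Nparticles, Nparams)
    dist = jnp.linalg.norm(steps, axis=-1)  # distance moved by each particle
    return dist.mean(axis=-1)               # average over the swarm


def has_converged(chains, tol=1e-3, window=10):
    """Hypothetical criterion: the mean displacement stayed below `tol`
    in each of the last `window` iteration steps."""
    disp = mean_displacement(chains)
    return bool(jnp.all(disp[-window:] < tol))
```

A swarm whose positions decay geometrically towards a fixed point will pass this check once the step size has shrunk below tol.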

Returns:

fig – The figure object that can be manipulated after creation (e.g. to save to disk)

Return type:

matplotlib.figure.Figure

sample(solver, y0: Variables, t: Float[Array, 'Nt'], friction_constants: RSFConstants, block_constants: SpringBlockConstants | InertialSpringBlockConstants, nsamples: int = 10, rng: None | int | Array = None)

Draw random samples from the posterior distribution. Each particle represents one plausible parameter set, and hence one realisation of a friction curve. When nsamples realisations are requested, it is more efficient to vmap the forward simulations than to compute them one by one.

The resulting friction curves can be used for validation purposes (does the posterior match the data?).
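The vectorisation idea can be sketched in JAX with a toy model. toy_forward (an exponential decay) and sample_curves are hypothetical stand-ins, not the diabayes ODESolver; the sketch only shows how randomly selected particles can be pushed through a forward model in a single jax.vmap call instead of a Python loop.

```python
import jax
import jax.numpy as jnp


def toy_forward(params, t):
    """Hypothetical stand-in for a forward simulation (not diabayes)."""
    mu0, rate = params[0], params[1]
    return mu0 * jnp.exp(-rate * t)


def sample_curves(key, particles, t, nsamples=10):
    """Hypothetical helper: pick `nsamples` distinct particles at random
    and run the forward model on all of them at once."""
    idx = jax.random.choice(key, particles.shape[0],
                            shape=(nsamples,), replace=False)
    samples = particles[idx]  # (nsamples, Nparams)
    # vmap maps over the leading axis of `samples`; `t` is shared (in_axes=None)
    curves = jax.vmap(toy_forward, in_axes=(0, None))(samples, t)  # (nsamples, Nt)
    return samples, curves
```

Because vmap traces the forward model once and evaluates all samples together, the cost of Python-level dispatch is paid only once rather than nsamples times.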

Examples

>>> result = solver.bayesian_inversion(...)
>>> samples, sample_results = result.sample(
...     solver=solver, y0=y0, t=t,
...     friction_constants=constants,
...     block_constants=block_constants,
...     nsamples=20, rng=42)

Parameters:
  • solver (ODESolver) – The ODESolver instance created with the forward model (typically the same one as used for the inversion)

  • y0 (Variables) – The initial values of friction and state

  • t (Array) – The time samples at which a solution is requested. This does not need to be the same as the time samples recorded in the experiment (from which the observed friction curve is obtained)

  • friction_constants (_Constants) – The friction and state model constants

  • block_constants (_BlockConstants) – The stress transfer constants

  • nsamples (int) – The number of random samples for which a solution should be computed

  • rng (None, int, jax.random.PRNGKey) – The random seed used to sample from the particle distribution. If None, the current time is used as a seed, so each call gives a different result. If an integer is provided, a new jax.random.PRNGKey is generated from it.

Returns:

  • samples (_Params) – The samples drawn from the posterior distribution

  • sample_results (Variables) – The forward simulation results corresponding to samples
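The documented handling of rng can be mirrored by a small normalisation step. to_prng_key is a hypothetical helper written for this sketch, and reducing the clock value modulo 2**31 is an assumption made to keep the seed within 32-bit range; the library may derive its time-based seed differently.

```python
import time

import jax


def to_prng_key(rng=None):
    """Hypothetical helper (not part of diabayes) mirroring the `rng` rules."""
    if rng is None:
        # Seed from the clock, so every call yields a different key
        return jax.random.PRNGKey(time.time_ns() % (2**31))
    if isinstance(rng, int):
        # Integer seed: the same integer always yields the same key
        return jax.random.PRNGKey(rng)
    # Otherwise assume `rng` already is a PRNG key and pass it through unchanged
    return rng
```

Passing an integer such as rng=42 therefore gives reproducible samples, while rng=None does not.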

chains: Float[Array, 'Nsteps Nparticles Nparams']

The particle trajectories (to inspect convergence)

final_state: Float[Array, 'Nparticles Nparams']

The final state of the particle swarm
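The swarm's final positions can be summarised per parameter with percentiles over the particle axis. This sketch assumes that final_state corresponds to the last iteration of chains (i.e. chains[-1]), which is not stated explicitly here; summarise is a hypothetical helper, and the pre-computed statistics attribute may be computed differently (for example after transforming back from log values).

```python
import jax.numpy as jnp


def summarise(final_state, q=(5.0, 50.0, 95.0)):
    """Hypothetical helper: per-parameter percentiles over the swarm.

    final_state: array of shape (Nparticles, Nparams)
    returns: array of shape (len(q), Nparams); row 1 is the median for the default q
    """
    return jnp.percentile(final_state, jnp.asarray(q), axis=0)
```

Reducing over axis=0 collapses the particle dimension, so each column of the result describes one parameter.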

log_likelihood: Float[Array, 'Nsteps']

The log-likelihood evolution

log_params: RSFParams | CNSParams

The particles representing the (log-valued) parameters, including all the iteration steps

nan_count: Float[Array, 'Nsteps']

The number of NaNs encountered during each iteration step
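A quick way to locate problematic iterations is to look for nonzero entries in nan_count. nan_report is a hypothetical helper for this sketch (not part of diabayes); it works on any one-dimensional array of per-step counts.

```python
import jax.numpy as jnp


def nan_report(nan_count):
    """Hypothetical helper: summarise a per-step NaN counter."""
    bad_steps = jnp.nonzero(nan_count > 0)[0]  # indices of steps with any NaN
    return {
        "total": int(nan_count.sum()),
        "steps_with_nans": bad_steps.tolist(),
        "first_step": int(bad_steps[0]) if bad_steps.size > 0 else None,
    }
```

If first_step is None, no NaNs were encountered during the inversion.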

statistics: RSFStatistics | CNSStatistics | None = None

Pre-computed statistics of the parameters