diabayes.typedefs.BayesianSolution
- class diabayes.typedefs.BayesianSolution(log_params: RSFParams | CNSParams, log_likelihood: Float[Array, 'Nsteps'], nan_count: Float[Array, 'Nsteps'])
The result of a Bayesian inversion. This data class stores the particle trajectories (the equivalent of MCMC chains), the final equilibrium state and various statistics. It also provides helper routines to diagnose or visualise the results.
Examples
>>> result = solver.bayesian_inversion(...)
>>> print(result.nan_count)  # Check for NaN values
>>> print(result.statistics.Dc.median)  # Get the median of Dc
>>> result.plot_convergence()  # Visually inspect convergence
>>> result.cornerplot(nbins=15)  # Create a corner plot
- __init__(log_params: RSFParams | CNSParams, log_likelihood: Float[Array, 'Nsteps'], nan_count: Float[Array, 'Nsteps'])
Methods
__init__(log_params, log_likelihood, nan_count)
cornerplot([nbins]) – Create a corner plot from the particle positions.
plot_convergence() – Helper routine to visualise the convergence of the inversion.
sample(solver, y0, t, friction_constants, ...) – Draw random samples from the posterior distribution.
Attributes
statistics – Pre-computed statistics of the parameters
log_params – The particles representing the (log-valued) parameters, including all the iteration steps
chains – The particle trajectories (to inspect convergence)
log_likelihood – The log-likelihood evolution
nan_count – The number of NaNs encountered during each iteration step
final_state – The final state of the particle swarm
- cornerplot(nbins=20)
Create a corner plot from the particle positions. For N parameters, a corner plot is the lower triangle of an N x N grid of subplots. The i-th panel on the diagonal shows a histogram of the marginal posterior of the i-th parameter. Below the diagonal, the (i, j)-th panel is a scatter plot of parameter i against parameter j. It shows the covariance between the two quantities.
- Parameters:
nbins (int) – The number of bins to use for the histograms on the diagonal of the corner plot
- Returns:
fig – The figure object that can be manipulated after creation (e.g. to save to disk)
- Return type:
matplotlib.figure.Figure
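The layout described above can be reproduced with plain matplotlib. The following is a minimal sketch, not the diabayes implementation. The function name `corner_sketch` is hypothetical. It also assumes the particles are passed as a `(Nparticles, Nparams)` array, like `final_state`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def corner_sketch(particles, labels, nbins=20):
    """Hypothetical corner plot of a (Nparticles, Nparams) array of particle positions."""
    particles = np.asarray(particles)
    nparams = particles.shape[1]
    fig, axes = plt.subplots(nparams, nparams, figsize=(2.5 * nparams, 2.5 * nparams), squeeze=False)
    for i in range(nparams):
        for j in range(nparams):
            ax = axes[i, j]
            if j > i:
                # Upper triangle is left empty
                ax.set_axis_off()
            elif i == j:
                # Diagonal: histogram of the marginal posterior of parameter i
                ax.hist(particles[:, i], bins=nbins)
            else:
                # Below the diagonal: parameter j (x-axis) against parameter i (y-axis)
                ax.scatter(particles[:, j], particles[:, i], s=4)
            if i == nparams - 1:
                ax.set_xlabel(labels[j])
            if j == 0 and i > 0:
                ax.set_ylabel(labels[i])
    fig.tight_layout()
    return fig
```

Returning the `Figure` object instead of calling `plt.show()` lets the caller save or modify the plot afterwards, for example with `fig.savefig(...)`.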
- plot_convergence()
Helper routine to visualise the convergence of the inversion. Convergence is achieved when the particles settle into an equilibrium position (i.e., when they stop moving).
- Returns:
fig – The figure object that can be manipulated after creation (e.g. to save to disk)
- Return type:
matplotlib.figure.Figure
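You can also check numerically whether the particles "stop moving". This is a minimal sketch, assuming the `chains` array has the documented shape `(Nsteps, Nparticles, Nparams)`. The names `mean_step_size` and `has_settled`, and the tolerance criterion, are illustrative choices and not part of diabayes.

```python
import numpy as np


def mean_step_size(chains):
    """Mean particle displacement between consecutive iterations.

    chains: array of shape (Nsteps, Nparticles, Nparams).
    Returns an array of shape (Nsteps - 1,).
    """
    chains = np.asarray(chains)
    steps = np.diff(chains, axis=0)  # (Nsteps - 1, Nparticles, Nparams)
    # Euclidean displacement of each particle in parameter space, averaged over particles
    return np.linalg.norm(steps, axis=-1).mean(axis=-1)


def has_settled(chains, tol=1e-3, window=10):
    """Hypothetical criterion: the last `window` mean displacements are all below `tol`."""
    return bool(np.all(mean_step_size(chains)[-window:] < tol))
```

A suitable `tol` depends on the scale of the (log-valued) parameters. Choose it after inspecting the convergence plot.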
- sample(solver, y0: Variables, t: Float[Array, 'Nt'], friction_constants: RSFConstants, block_constants: SpringBlockConstants | InertialSpringBlockConstants, nsamples: int = 10, rng: None | int | Array = None)
Draw random samples from the posterior distribution. Each particle represents one plausible parameter set, and therefore one realisation of a plausible friction curve. For the nsamples realisations requested by the user, it is more efficient to vmap the forward simulations than to compute them one by one. The resulting friction curves can be used for validation purposes (does the posterior match the data?).
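The vectorisation idea can be illustrated with `jax.vmap`. This is a minimal sketch, not the diabayes implementation. It assumes the particles are stored as a `(Nparticles, Nparams)` array of natural-log parameters, sampled uniformly with replacement. `toy_forward` is a hypothetical stand-in for the forward model: an exponential decay, not a rate-and-state friction simulation.

```python
import jax
import jax.numpy as jnp


def toy_forward(params, t):
    """Hypothetical stand-in forward model: y(t) = a * exp(-t / tau)."""
    a, tau = params
    return a * jnp.exp(-t / tau)


def sample_and_simulate(key, log_particles, t, nsamples=10):
    """Draw `nsamples` particles at random and run all forward simulations at once."""
    # Pick particle indices uniformly at random (with replacement, an assumption)
    idx = jax.random.randint(key, (nsamples,), 0, log_particles.shape[0])
    samples = jnp.exp(log_particles[idx])  # convert log-valued parameters back
    # vmap maps toy_forward over the leading (sample) axis of `samples`;
    # `t` is shared by all samples (in_axes=None)
    curves = jax.vmap(toy_forward, in_axes=(0, None))(samples, t)
    return samples, curves
```

Because `t` is marked with `in_axes=None`, every sample is evaluated on the same time grid. The curves therefore have shape `(nsamples, Nt)`.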
Examples
>>> result = solver.bayesian_inversion(...)
>>> samples, sample_results = result.sample(
...     solver=solver, y0=y0, t=t,
...     friction_constants=constants,
...     block_constants=block_constants,
...     nsamples=20, rng=42)
- Parameters:
solver (ODESolver) – The ODESolver class instantiated with the Forward model (typically the same one as used for the inversion)
y0 (Variables) – The initial values of friction and state
t (Array) – The time samples at which a solution is requested. This does not need to be the same as the time samples recorded in the experiment (from which the observed friction curve is obtained)
friction_constants (_Constants) – The friction and state model constants
block_constants (_BlockConstants) – The stress transfer constants
nsamples (int) – The number of random samples for which a solution should be computed
rng (None, int, jax.random.PRNGKey) – The random seed used to sample from the particle distribution. If None, the current time is used as a seed, which leads to different results on each call. When an integer is provided, a new jax.random.PRNGKey is generated from it.
- Returns:
samples (_Params) – The samples drawn from the posterior distribution
sample_results (Variables) – The forward simulation results corresponding to samples
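The seed handling described for `rng` follows a common JAX pattern. This is a minimal sketch of that logic based only on the documented behaviour, not the actual diabayes code. The helper name `make_key` is hypothetical.

```python
import time

import jax


def make_key(rng=None):
    """Turn the documented `rng` options into a JAX PRNG key (hypothetical helper)."""
    if rng is None:
        # Seed from the current time: results differ from call to call
        return jax.random.PRNGKey(time.time_ns() % (2**31))
    if isinstance(rng, int):
        # Integer seed: reproducible results
        return jax.random.PRNGKey(rng)
    # Otherwise assume an existing key was passed and use it as-is
    return rng
```

Passing an integer makes repeated calls reproducible. This is convenient when comparing posterior samples across runs.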
- chains: Float[Array, 'Nsteps Nparticles Nparams']
The particle trajectories (to inspect convergence)
- final_state: Float[Array, 'Nparticles Nparams']
The final state of the particle swarm
- log_likelihood: Float[Array, 'Nsteps']
The log-likelihood evolution
- log_params: RSFParams | CNSParams
The particles representing the (log-valued) parameters, including all the iteration steps
- nan_count: Float[Array, 'Nsteps']
The number of NaNs encountered during each iteration step
- statistics: RSFStatistics | CNSStatistics | None = None
Pre-computed statistics of the parameters
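Because the parameters are stored as log-values, summary statistics in physical units require exponentiating the particles first. This is a minimal sketch, assuming a `(Nparticles, Nparams)` array such as `final_state` and natural logarithms; the log base is not stated here. The function `summarise` and its choice of statistics are illustrative and do not describe the contents of `RSFStatistics`.

```python
import numpy as np


def summarise(log_particles, names):
    """Per-parameter summary of log-valued particles (hypothetical helper).

    log_particles: array of shape (Nparticles, Nparams), natural-log values (assumed).
    """
    values = np.exp(np.asarray(log_particles))  # back to physical units
    # 2.5th, 50th and 97.5th percentiles for every parameter (axis 0 = particles)
    lo, med, hi = np.percentile(values, [2.5, 50.0, 97.5], axis=0)
    return {
        name: {"median": med[k], "lower": lo[k], "upper": hi[k]}
        for k, name in enumerate(names)
    }
```

Since exp is monotonic, these percentiles equal the exponentiated percentiles of the log-values, up to interpolation between neighbouring particles.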