diabayes package

Submodules

diabayes.SVI module

diabayes.SVI.compute_phi(x: Float[Array, 'N M'], gradp: Float[Array, 'N M'], gradq: Float[Array, 'N M']) → Float[Array, 'N M']

Compute the Stein variational gradients for a set of particles (x), given the gradients of the log-likelihood (gradp) and of the log-prior (gradq) with respect to the particles.

Parameters:
  • x (ParticleArray) – The set of invertible parameters (“particles”), of shape (Nparticles, Ndimensions).

  • gradp (GradientArray) – The gradients of the log-likelihood with respect to the invertible parameters, of shape (Nparticles, Ndimensions).

  • gradq (GradientArray) – The gradients of the log-prior with respect to the invertible parameters, of shape (Nparticles, Ndimensions).

Returns:

grad_x – The directional gradients for the particle updates, e.g. new_x = x + step_size * grad_x. Has the same shape as x, gradp, and gradq.

Return type:

Gradients
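For orientation, a Stein variational direction of this kind can be sketched in JAX as the standard SVGD update with a radial basis function (RBF) kernel and a median-heuristic bandwidth. The kernel, the bandwidth rule, and the name svgd_direction are assumptions for illustration; compute_phi may differ in these details. The log-likelihood and log-prior gradients are summed into the gradient of the log-posterior.

```python
import jax.numpy as jnp


def svgd_direction(x, gradp, gradq):
    """Hypothetical SVGD direction with an RBF kernel (not the diabayes source).

    x, gradp, gradq: arrays of shape (Nparticles, Ndimensions).
    """
    n = x.shape[0]
    # Gradient of the log-posterior: log-likelihood + log-prior
    grad_logpost = gradp + gradq

    # Pairwise differences diff[i, j] = x_i - x_j, shape (N, N, M)
    diff = x[:, None, :] - x[None, :, :]
    sq_dist = jnp.sum(diff**2, axis=-1)

    # Median heuristic for the kernel bandwidth (assumption), kept strictly positive
    h = jnp.maximum(jnp.median(sq_dist) / jnp.log(n + 1.0), 1e-12)
    k = jnp.exp(-sq_dist / h)  # k[i, j] = k(x_i, x_j), symmetric

    # Attractive term: sum_j k(x_j, x_i) * grad log p(x_j)
    attract = k @ grad_logpost
    # Repulsive term: sum_j grad_{x_j} k(x_j, x_i) = sum_j (2 / h) (x_i - x_j) k_ij
    repulse = (2.0 / h) * jnp.sum(diff * k[:, :, None], axis=1)
    return (attract + repulse) / n
```

A particle update then takes the form new_x = x + step_size * svgd_direction(x, gradp, gradq); the repulsive term is what keeps the swarm spread out over the posterior instead of collapsing onto its mode.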

diabayes.SVI.mapped_log_likelihood(log_params: RSFParams | CNSParams, mu_obs: Float[Array, 'Nt'], noise_std: Float, forward_fn: Callable[[RSFParams | CNSParams], Variables]) → Float

Vectorized version of _log_likelihood. Takes similar arguments as _log_likelihood but with additional array axes over which _log_likelihood is mapped.
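As an illustration of this vectorisation, the sketch below maps a Gaussian log-likelihood over a leading particle axis with jax.vmap. The Gaussian form, the exponentiation of log-valued parameters, and the names log_likelihood and mapped_log_likelihood_sketch are assumptions; the actual _log_likelihood may be defined differently. The forward function is closed over rather than passed through vmap, because jax.vmap expects array (pytree) arguments rather than Python callables.

```python
import jax
import jax.numpy as jnp


def log_likelihood(log_params, mu_obs, noise_std, forward_fn):
    # Hypothetical Gaussian log-likelihood, up to an additive constant
    mu_model = forward_fn(jnp.exp(log_params))
    resid = (mu_model - mu_obs) / noise_std
    return -0.5 * jnp.sum(resid**2)


def mapped_log_likelihood_sketch(log_params, mu_obs, noise_std, forward_fn):
    # Close over the non-array arguments and map over the leading particle axis
    single = lambda p: log_likelihood(p, mu_obs, noise_std, forward_fn)
    return jax.vmap(single)(log_params)
```

For log_params of shape (Nparticles, Nparams), the result holds one log-likelihood value per particle.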

diabayes.forward_models module

class diabayes.forward_models.Forward(friction_model: FrictionModel, state_evolution: Dict[str, Callable], stress_transfer: StressTransfer[BC])

Bases: Generic[BC]

The Forward class assembles the various components that comprise a forward model, such that Forward.__call__ takes a set of variables and returns their rate of change:

\[\frac{\mathrm{d} \vec{X}}{\mathrm{d}t} = f \left( \vec{X} \right)\]

This forward ODE can then be solved by any ODE solver.

The Forward class is instantiated by providing a friction model, a “state” evolution law, and a stress transfer model.
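To illustrate how these components combine, the sketch below hand-assembles the right-hand side for the classical rate-and-state law, the ageing law, and a conventional spring-block, using the formulas documented in this module. It is not the Forward implementation: plain arrays stand in for the Variables, RSFParams, and constants containers, and the function names rsf_slip_rate and rhs are hypothetical.

```python
import jax.numpy as jnp


def rsf_slip_rate(mu, theta, a, b, Dc, mu0, v0):
    # v(mu, theta) = v0 exp([mu - mu0 - b log(v0 theta / Dc)] / a)
    return v0 * jnp.exp((mu - mu0 - b * jnp.log(v0 * theta / Dc)) / a)


def rhs(t, X, a, b, Dc, mu0, v0, k, v_lp):
    """Hypothetical dX/dt = f(X) for X = [mu, theta]."""
    mu, theta = X[0], X[1]
    v = rsf_slip_rate(mu, theta, a, b, Dc, mu0, v0)  # friction model
    dtheta = 1.0 - v * theta / Dc                    # ageing law
    dmu = k * (v_lp - v)                             # spring-block loading
    return jnp.array([dmu, dtheta])
```

At steady sliding with v_lp = v0, mu = mu0, and theta = D_c / v0, both rates vanish, which is a quick sanity check for any assembled forward model.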

Examples

>>> from diabayes.forward_models import ageing_law, rsf, springblock, Forward
>>> forward_model = Forward(rsf, {"theta": ageing_law}, springblock)
>>> X_dot = forward_model(t=..., variables=..., params=..., friction_constants=..., block_constants=...)

__call__(t: Float, variables: Variables, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: BC) → Variables

Calculate the rate of change of the variables, defining the ODE to be solved.

Parameters:
  • t (Float) – Current value of time [s]

  • variables (diabayes.Variables) – The instantaneous values of the variables for which the time derivative will be computed

  • params (_Params) – The forward model (invertible) parameters

  • friction_constants (_Constants) – The forward model constants

  • block_constants (_BlockConstants) – The constants associated with the stress transfer

Returns:

dvars – The instantaneous rate of change of the variables

Return type:

Variables

set_initial_values(**kwargs) → None

Set the initial values of the variables. This sets the values as self.variables, which is a Variables object.

Parameters:

**kwargs (dict) – Key-value pairs of variable names and corresponding values

diabayes.forward_models.ageing_law(v: Float, variables: Variables, params: RSFParams, constants: RSFConstants) → Float

The conventional ageing law state evolution formulation

\[\frac{\mathrm{d}\theta}{\mathrm{d}t} = 1 - \frac{v \theta}{D_c}\]

Parameters:
  • v (Float) – Instantaneous fault slip rate [m/s].

  • variables (Variables) – The friction coefficient mu and state parameter theta

  • params (RSFParams) – The rate-and-state parameters, including D_c

  • constants (RSFConstants) – The constant parameters mu0 and v0 (not used)

Returns:

dtheta – The rate of change of the state variable [s/s]

Return type:

Float
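For a constant slip rate v, the ageing law is a linear ODE in θ and can be integrated in closed form (a standard result, included here for reference). It shows that θ relaxes towards the steady-state value D_c/v over a characteristic slip distance D_c, starting from an initial value θ_0:

\[\theta(t) = \frac{D_c}{v} + \left( \theta_0 - \frac{D_c}{v} \right) \exp \left( -\frac{v t}{D_c} \right), \qquad \theta_{ss} = \lim_{t \to \infty} \theta(t) = \frac{D_c}{v}\]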

diabayes.forward_models.aging_law(*args, **kwargs)

Alias for the ageing_law function

diabayes.forward_models.inertial_springblock(t: Float, v: Float, v_partials: Variables, variables: Variables, dstate: Float[Array, '...'], constants: InertialSpringBlockConstants) → Float

An inertial spring-block loading formulation

\[\frac{\mathrm{d} \mu}{\mathrm{d} t} = \left[ \frac{\partial v}{\partial \mu} \right]^{-1} \left( \frac{1}{M} \left[ k \left( v_{lp} t - x \right) - \mu \right] - \frac{\partial v}{\partial \theta} \frac{\mathrm{d} \theta}{\mathrm{d} t} - \dots \right)\]

The acceleration term dv/dt in the classical inertial spring-block formulation is expanded with the chain rule into partial derivatives of v with respect to each variable, which avoids the need to solve a second-order ODE. These partial derivatives (v_partials) are computed using JAX automatic differentiation.

Notes

This formulation is rather stiff and, for certain parameter values, can require extremely small time steps to maintain numerical accuracy. A conventional (non-inertial) spring-block formulation is recommended for basic velocity-step and slide-hold-slide simulations; inertia is only really needed for stick-slip simulations.

Parameters:
  • t (Float) – Current value of time [s]

  • v (Float) – Instantaneous fault slip rate [m/s]

  • v_partials (Variables) – The partial derivatives of the slip rate with respect to the relevant variables (friction and state variables)

  • variables (Variables) – The instantaneous values of the variables: friction (mu), slip (slip), and other state variables (not used)

  • dstate (Array) – The time derivatives of the state variables. The radiation term is v_partials @ dstate (excluding mu)

  • constants (InertialSpringBlockConstants) – The spring-block constants containing the mass term M, the stiffness k, and the load-point velocity v_lp

Returns:

dmu – The rate of change of the friction coefficient [1/s]

Return type:

Float
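The partial derivatives in v_partials can be obtained with jax.grad. The sketch below does this for the classical rate-and-state law and evaluates the formula above, assuming θ is the only state variable (so the trailing terms vanish). The names rsf_slip_rate, dv_partials, and inertial_dmu, and the flat argument list standing in for the Variables and constants containers, are hypothetical.

```python
import jax
import jax.numpy as jnp


def rsf_slip_rate(mu, theta, a, b, Dc, mu0, v0):
    # Classical rate-and-state friction law v(mu, theta)
    return v0 * jnp.exp((mu - mu0 - b * jnp.log(v0 * theta / Dc)) / a)


# Partial derivatives of v with respect to mu and theta via autodiff
dv_partials = jax.grad(rsf_slip_rate, argnums=(0, 1))


def inertial_dmu(t, mu, theta, x, dtheta, a, b, Dc, mu0, v0, k, v_lp, M):
    """Hypothetical inertial spring-block dmu/dt with theta as the only state variable."""
    dv_dmu, dv_dtheta = dv_partials(mu, theta, a, b, Dc, mu0, v0)
    loading = (k * (v_lp * t - x) - mu) / M
    return (loading - dv_dtheta * dtheta) / dv_dmu
```

For this friction law, autodiff reproduces the analytic partials ∂v/∂μ = v/a and ∂v/∂θ = −(b/a) v/θ, but the same pattern keeps working for friction laws whose partials are tedious to derive by hand.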

diabayes.forward_models.rsf(variables: Variables, params: RSFParams, constants: RSFConstants) → Float

The classical rate-and-state friction law

\[v(\mu, \theta) = v_0 \exp \left( \frac{1}{a} \left[ \mu - \mu_0 - b \log \left( \frac{v_0 \theta}{D_c} \right) \right] \right)\]

Parameters:
  • variables (Variables) – The friction coefficient mu and state parameter theta

  • params (RSFParams) – The rate-and-state parameters a, b, and D_c

  • constants (RSFConstants) – The constant parameters mu0 and v0

Returns:

v – The instantaneous slip rate in the same units as v0

Return type:

Float
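Combining this law with the ageing law gives the familiar steady-state friction coefficient. Setting dθ/dt = 0 in the ageing law yields θ_ss = D_c/v, so that log(v_0 θ_ss / D_c) = −log(v/v_0); taking the logarithm of v(μ, θ_ss) then gives:

\[a \log \left( \frac{v}{v_0} \right) = \mu_{ss} - \mu_0 + b \log \left( \frac{v}{v_0} \right) \quad \Rightarrow \quad \mu_{ss} = \mu_0 + (a - b) \log \left( \frac{v}{v_0} \right)\]

Hence a − b > 0 corresponds to velocity-strengthening and a − b < 0 to velocity-weakening friction.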

diabayes.forward_models.slip_rate(v: Float, variables: Variables, params: RSFParams, constants: RSFConstants) → Float

Evolve slip from slip rate

\[\frac{\mathrm{d} x}{\mathrm{d} t} = v\]

Parameters:
  • v (Float) – Instantaneous fault slip rate [m/s]

  • variables (Variables) – The friction coefficient mu and state parameter theta (not used)

  • params (RSFParams) – The rate-and-state parameters (not used)

  • constants (RSFConstants) – The constant parameters mu0 and v0 (not used)

Returns:

v – The rate of change of slip [m/s]

Return type:

Float

diabayes.forward_models.springblock(t: Float, v: Float, v_partials: Variables, variables: Variables, dstate: Float[Array, '...'], constants: SpringBlockConstants) → Float

A conventional (non-inertial) spring-block loading formulation

\[\frac{\mathrm{d} \mu}{\mathrm{d} t} = k \left( v_{lp} - v(t) \right)\]

Parameters:
  • v (Float) – Instantaneous fault slip rate [m/s].

  • variables (Variables) – The instantaneous variables. This argument is not used, but included for call signature consistency.

  • constants (SpringBlockConstants) – The constant parameters stiffness k (units of “friction per metre”) and load-point velocity v_lp (same units as v).

Returns:

dmu – The rate of change of the friction coefficient [1/s]

Return type:

Float

diabayes.solver module

class diabayes.solver.ODESolver(forward_model: Forward, rtol: float = 1e-08, atol: float = 1e-12, checkpoints: int = 100)

Bases: object

The main solver class that contains forward and inverse modelling methods.

bayesian_inversion(t: Float[Array, 'Nt'], mu: Float[Array, 'Nt'], noise_std: Float, y0: Variables, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants | InertialSpringBlockConstants, Nparticles: int = 1000, Nsteps: int = 150, rng: None | int | Array = None) → BayesianSolution

A Bayesian inversion routine using the Stein Variational Inference method.

Parameters:
  • t (Array) – A vector of time values (in units of seconds). The time steps do not need to be uniform

  • mu (Array) – The observed friction curve sampled at t

  • noise_std (Float) – An estimate of the standard deviation of the noise in the measured friction curve. A conservative value is recommended: for example, if the noise has a standard deviation of $10^{-3}$, a good starting point would be to set noise_std = 0.5e-3

  • y0 (Variables) – The initial values for the modelled friction and any state variables

  • params (_Params) – The initial guess for the invertible parameters that characterise the forward problem. It is recommended to use the result from ODESolver.max_likelihood_inversion. The prior distribution will be centered around this initial guess

  • friction_constants (_Constants) – The non-invertible constants that characterise the forward problem

  • block_constants (_BlockConstants) – The stress transfer constants (e.g. stiffness and loading rate)

  • Nparticles (int) – The number of particles (= posterior samples) to include. A higher value gives a more accurate estimation of the posterior distribution, at a higher computational cost.

  • Nsteps (int) – The number of iterations after which convergence is expected. It is recommended to start with a value of 100 and then check whether an equilibrium was actually reached (e.g. with BayesianSolution.plot_convergence). Increasing this value beyond the point of equilibrium does not improve the result.

  • rng (None, int, jax.random.PRNGKey) – The random seed used to initialise the particle swarm distribution. If None, the current time is used as a seed, which leads to a different result for each realisation. When an integer is provided, a new jax.random.PRNGKey is generated from it.

Returns:

result – The result of the inversion encapsulated in a BayesianSolution container, which provides access to the full convergence chains, as well as diagnostic and visualisation routines.

Return type:

diabayes.BayesianSolution

Notes

This method is currently only compatible with standard rate-and-state friction models.
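The documented rng semantics and an initial particle swarm centred on the initial guess can be sketched as follows. The helper names make_key and init_particles, the Gaussian initialisation in log-space, and the fixed scale are assumptions for illustration; the documentation only states that the prior is centred on the initial guess, and the actual particle generation (e.g. through Params.generate) may differ.

```python
import time

import jax
import jax.numpy as jnp


def make_key(rng=None):
    # Hypothetical helper mirroring the documented rng semantics
    if rng is None:
        return jax.random.PRNGKey(time.time_ns() % (2**31))  # time-based seed
    if isinstance(rng, int):
        return jax.random.PRNGKey(rng)  # integer seed -> new PRNG key
    return rng  # assume an existing PRNG key was passed


def init_particles(params_guess, scale, nparticles, rng=None):
    # Hypothetical Gaussian swarm in log-space around the log of the initial guess
    key = make_key(rng)
    loc = jnp.log(jnp.asarray(params_guess))
    noise = jax.random.normal(key, (nparticles, loc.shape[0]))
    return loc + jnp.asarray(scale) * noise
```

Passing the same integer seed twice reproduces the same swarm, whereas rng=None gives a different swarm (and hence a slightly different posterior estimate) on each call.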

max_likelihood_inversion(t: Float[Array, 'Nt'], mu: Float[Array, 'Nt'], y0: Variables, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants | InertialSpringBlockConstants, verbose: bool = False) → Solution

Minimises the least-squares residuals between the observed friction curve and the parameterised one, using the Levenberg-Marquardt algorithm.

Parameters:
  • t (Array) – A vector of time values (in units of seconds). The time steps do not need to be uniform

  • mu (Array) – The observed friction curve sampled at t

  • y0 (Variables) – The initial values for the modelled friction and any state variables

  • params (_Params) – The initial guess for the invertible parameters that characterise the forward problem. These need to be sufficiently close to the “true” values for the algorithm to converge

  • friction_constants (_Constants) – The non-invertible constants that characterise the forward problem

  • block_constants (_BlockConstants) – The stress transfer constants (e.g. stiffness and loading rate)

  • verbose (bool) – Whether or not to output detailed progress of the inversion. Defaults to False

Returns:

sol – The inversion result, including various diagnostics. The inverted parameter values can be accessed as sol.values

Return type:

optimistix.Solution

Notes

If an error is produced in the first iteration step, it is quite possible that the initial guess parameters were too far from the “true” values (i.e., the mismatch between the observed and modelled friction curves is too large), breaking the Gauss-Newton step of the Levenberg-Marquardt algorithm. Manually tuning the initial guess beforehand is recommended.

solve_forward(t: Float[Array, 'Nt'], y0: Variables, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants | InertialSpringBlockConstants, method: str = 'RK45') → Any

Solve a forward problem using SciPy’s solve_ivp routine. This routine does not propagate any gradients, but it is much faster to initialise and to perform a single forward run with. Hence, for experimenting with different parameter values, it is preferred over a JIT-compiled JAX implementation.

Parameters:
  • t (Float[Array, "Nt"]) – A vector of time samples where a solution is requested

  • y0 (Variables) – The initial values (friction and state) wrapped in a Variables container.

  • params (_Params) – The (invertible) parameters that govern the dynamics, wrapped in a Params container.

  • friction_constants (_Constants) – A container object containing the friction constants

  • block_constants (_BlockConstants) – A container object containing the block constants

  • method (str) – The integration method passed to solve_ivp. Defaults to 'RK45'

Returns:

result – Solution time series of friction and state

Return type:

Variables
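A stand-alone forward run of this kind can be sketched with scipy.integrate.solve_ivp directly, here for a velocity step from v0 to 10 v0 with the rate-and-state law, the ageing law, and a conventional spring-block. The parameter values and the function name rhs are illustrative assumptions, not diabayes defaults; only the tolerances mirror the ODESolver defaults (rtol = 1e-8, atol = 1e-12).

```python
import numpy as np
from scipy.integrate import solve_ivp

# Illustrative (hypothetical) parameter values
a, b, Dc = 0.01, 0.008, 1e-5   # rate-and-state parameters (Dc in m)
mu0, v0 = 0.6, 1e-6            # reference friction and slip rate [m/s]
k, v_lp = 1e3, 1e-5            # stiffness [1/m] and load-point velocity [m/s]


def rhs(t, X):
    mu, theta = X
    v = v0 * np.exp((mu - mu0 - b * np.log(v0 * theta / Dc)) / a)  # rsf
    return [k * (v_lp - v), 1.0 - v * theta / Dc]  # springblock, ageing law


# Start at steady state for v0, then load at v_lp = 10 v0 (velocity step)
X0 = [mu0, Dc / v0]
t_eval = np.linspace(0.0, 50.0, 501)
sol = solve_ivp(rhs, (t_eval[0], t_eval[-1]), X0, method="RK45",
                t_eval=t_eval, rtol=1e-8, atol=1e-12)
mu_final = sol.y[0, -1]
mu_ss = mu0 + (a - b) * np.log(v_lp / v0)  # steady-state friction at v_lp
```

After the transient, mu_final settles at the steady-state value mu0 + (a − b) log(v_lp/v0), roughly 0.6046 for these values; sol.y holds the friction and state time series sampled at t_eval.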

atol: float

The absolute tolerance used by the ODE solver

checkpoints: int

The number of checkpoints to use to compute the (adjoint) gradients through the ODE routine. A higher number increases stability and speed, at the expense of more GPU memory

forward_model: Forward

An instantiated diabayes.Forward class, including friction law, state evolution equations, and stress transfer equation

learning_rate: float = 0.01

The initial learning rate provided to the Adam algorithm for the Stein Variational Inference. The default value of 1e-2 seems like a sensible choice for most models.

rtol: float

The relative tolerance used by the ODE solver

diabayes.typedefs module

class diabayes.typedefs.BC

Function signatures

alias of TypeVar(‘BC’, bound=SpringBlockConstants | InertialSpringBlockConstants, contravariant=True)

class diabayes.typedefs.BayesianSolution(log_params: RSFParams | CNSParams, log_likelihood: Float[Array, 'Nsteps'], nan_count: Float[Array, 'Nsteps'])

Bases: object

The result of a Bayesian inversion. This data class stores the particle trajectories (the equivalent of MCMC chains), the final equilibrium state, and various statistics, as well as helper routines to diagnose or visualise the results.

Examples

>>> result = solver.bayesian_inversion(...)
>>> print(result.nan_count)  # Check for NaN values
>>> print(result.statistics.Dc.median)  # Get the median of Dc
>>> result.plot_convergence()  # Visually inspect convergence
>>> result.cornerplot(nbins=15)  # Create a cornerplot

cornerplot(nbins=20)

Create a corner plot from the particle positions. For N parameters, a corner plot is the lower half of an N x N grid of subplots. On the diagonal, the histogram of the marginal posterior over the i-th parameter is plotted. Below the diagonal, the (i,j)-th panel is a scatter plot of parameter i against parameter j, which shows the covariance between the two quantities.

Parameters:

nbins (int) – The number of bins to use for the histograms on the diagonal of the corner plot

Returns:

fig – The figure object that can be manipulated after creation (e.g. to save to disk)

Return type:

matplotlib.pyplot.figure

plot_convergence()

Helper routine to visualise the convergence of the inversion. Convergence is achieved when the particles settle in an equilibrium position (i.e., they stop moving).

Returns:

fig – The figure object that can be manipulated after creation (e.g. to save to disk)

Return type:

matplotlib.pyplot.figure

sample(solver, y0: Variables, t: Float[Array, 'Nt'], friction_constants: RSFConstants, block_constants: SpringBlockConstants | InertialSpringBlockConstants, nsamples: int = 10, rng: None | int | Array = None)

Draw random samples from the posterior distribution. Each particle represents one realisation of a plausible friction curve. For the nsamples realisations requested by the user, it is more efficient to vmap the forward simulations than to compute them one by one.

The resulting friction curves can be used for validation purposes (does the posterior match the data?).

Examples

>>> result = solver.bayesian_inversion(...)
>>> samples, sample_results = result.sample(
...     solver=solver, y0=y0, t=t,
...     friction_constants=constants,
...     block_constants=block_constants,
...     nsamples=20, rng=42,
... )

Parameters:
  • solver (ODESolver) – The ODESolver class instantiated with the Forward model (typically the same one as used for the inversion)

  • y0 (Variables) – The initial values of friction and state

  • t (Array) – The time samples at which a solution is requested. This does not need to be the same as the time samples recorded in the experiment (from which the observed friction curve is obtained)

  • friction_constants (_Constants) – The friction and state model constants

  • block_constants (_BlockConstants) – The stress transfer constants

  • nsamples (int) – The number of random samples for which a solution should be computed

  • rng (None, int, jax.random.PRNGKey) – The random seed used to sample from the particle distribution. If None, the current time is used as a seed, which leads to a different result for each realisation. When an integer is provided, a new jax.random.PRNGKey is generated from it.

Returns:

  • samples (_Params) – The samples drawn from the posterior distribution

  • sample_results (Variables) – The forward simulation results corresponding with samples
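The vmap pattern mentioned above can be sketched as follows: draw nsamples particle indices at random from the final swarm and map a forward simulation over the selected parameter sets in a single call. The toy forward function (an exponential relaxation of friction towards mu_inf) and the names toy_forward and sample_curves are hypothetical stand-ins for an ODESolver forward run.

```python
import jax
import jax.numpy as jnp


def toy_forward(params, t):
    # Hypothetical stand-in for a forward simulation: mu(t) for one parameter set
    mu_inf, tau = params[0], params[1]
    return mu_inf + (0.6 - mu_inf) * jnp.exp(-t / tau)


def sample_curves(final_state, t, nsamples, key):
    # Pick nsamples particles (without replacement) from the final swarm
    idx = jax.random.choice(key, final_state.shape[0], (nsamples,), replace=False)
    samples = final_state[idx]
    # Map the forward run over the sampled parameter sets in a single call
    curves = jax.vmap(toy_forward, in_axes=(0, None))(samples, t)
    return samples, curves
```

With final_state of shape (Nparticles, Nparams), curves has shape (nsamples, Nt), one friction curve per sampled particle, ready to be overlaid on the observed data.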

chains: Float[Array, 'Nsteps Nparticles Nparams']

The particle trajectories (to inspect convergence)

final_state: Float[Array, 'Nparticles Nparams']

The final state of the particle swarm

log_likelihood: Float[Array, 'Nsteps']

The log-likelihood evolution

log_params: RSFParams | CNSParams

The particles representing the (log-valued) parameters, including all the iteration steps

nan_count: Float[Array, 'Nsteps']

The number of NaNs encountered during each iteration step

statistics: RSFStatistics | CNSStatistics | None = None

Pre-computed statistics of the parameters

class diabayes.typedefs.CNSConstants(phi_c: jaxtyping.Float, Z: jaxtyping.Float)

Bases: object

Z: Float
phi_c: Float

class diabayes.typedefs.CNSParams(*args, **kwargs)

Bases: Params

classmethod from_array(x)
classmethod generate(N: int, loc: Float[Array, 'M'], scale: Float[Array, 'M'], key: Array)
to_array()
Z: Float
phi_c: Float

class diabayes.typedefs.CNSStatistics(a: diabayes.typedefs.Statistics, b: diabayes.typedefs.Statistics, Dc: diabayes.typedefs.Statistics)

Bases: ParamStatistics

classmethod from_state(state)
get(x: str) → Statistics
get_param_names() → Tuple
Dc: Statistics
a: Statistics
b: Statistics

class diabayes.typedefs.FrictionModel(*args, **kwargs)

Bases: Protocol

class diabayes.typedefs.InertialSpringBlockConstants(k: jaxtyping.Float, v_lp: jaxtyping.Float, M: jaxtyping.Float)

Bases: object

M: Float
k: Float
v_lp: Float

class diabayes.typedefs.ParamStatistics

Bases: Module

classmethod from_state(state)
get(x: str) → Statistics
get_param_names() → Tuple

class diabayes.typedefs.Params(*args, **kwargs)

Bases: Module

classmethod from_array(x)
classmethod generate(N: int, loc: Float[Array, 'M'], scale: Float[Array, 'M'], key: Array)
to_array()

class diabayes.typedefs.RSFConstants(v0: jaxtyping.Float, mu0: jaxtyping.Float)

Bases: object

mu0: Float
v0: Float

class diabayes.typedefs.RSFParams(*args, **kwargs)

Bases: Params

classmethod from_array(x)
classmethod generate(N: int, loc: Float[Array, 'M'], scale: Float[Array, 'M'], key: Array)
to_array()
Dc: Float | Float[Array, '...']
a: Float | Float[Array, '...']
b: Float | Float[Array, '...']

class diabayes.typedefs.RSFStatistics(a: diabayes.typedefs.Statistics, b: diabayes.typedefs.Statistics, Dc: diabayes.typedefs.Statistics, cov: jaxtyping.Float[Array, 'N N'])

Bases: ParamStatistics

classmethod from_state(state)
get(x: str) → Statistics
get_param_names() → Tuple
Dc: Statistics
a: Statistics
b: Statistics
cov: Float[Array, 'N N']

class diabayes.typedefs.SpringBlockConstants(k: jaxtyping.Float, v_lp: jaxtyping.Float)

Bases: object

k: Float
v_lp: Float

class diabayes.typedefs.StateDict(keys: tuple[str, ...], vals: Array)

Bases: Module

A container class for state variables. A user would typically not interact with this class directly (only through Variables and Forward.set_initial_values).

Examples

>>> from diabayes.typedefs import StateDict
>>> import jax.numpy as jnp
>>> keys = ("x", "y")
>>> vals = jnp.array([1.0, -5.1])
>>> state_dict = StateDict(keys=keys, vals=vals)

replace_values(**kwargs) → StateDict

Update the values of the state variables. Since immutability is required for JIT caching, this method will return a copy of the class with the updated values.

Examples

>>> from diabayes.typedefs import StateDict
>>> import jax.numpy as jnp
>>> state_dict = StateDict(keys=("x",), vals=jnp.asarray(1.0))
>>> state_dict = state_dict.replace_values(x=jnp.asarray(2.0))

Parameters:

**kwargs (dict) – Key-value pairs to update. These can overwrite existing values or create entirely new ones.

Returns:

state_dict – A copy of the current StateDict with updated values

Return type:

StateDict

keys: tuple[str, ...]

A tuple of key names (strings) corresponding with the state variable names

vals: Array

An array of values corresponding (in order) with the variables defined by keys

class diabayes.typedefs.StateEvolution(*args, **kwargs)

Bases: Protocol

class diabayes.typedefs.Statistics(mean, median, std, q5, q95)

Bases: NamedTuple

count(value, /)

Return number of occurrences of value.

classmethod from_array(x: Float[Array, '...'])
get(x: str) → Float
index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

mean: Float

Alias for field number 0

median: Float

Alias for field number 1

q5: Float

Alias for field number 3

q95: Float

Alias for field number 4

std: Float

Alias for field number 2
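A NamedTuple with the same field order can be filled from a one-dimensional array of particle values as sketched below. Interpreting q5 and q95 as the 5th and 95th percentiles is inferred from the field names, and StatisticsSketch is a hypothetical stand-in for diabayes.typedefs.Statistics, not its actual implementation.

```python
from typing import NamedTuple

import jax.numpy as jnp


class StatisticsSketch(NamedTuple):
    # Field order matches Statistics: mean (0), median (1), std (2), q5 (3), q95 (4)
    mean: float
    median: float
    std: float
    q5: float
    q95: float

    @classmethod
    def from_array(cls, x):
        # Summarise a 1-D array of particle values (assumed percentile convention)
        x = jnp.asarray(x)
        return cls(
            mean=jnp.mean(x),
            median=jnp.median(x),
            std=jnp.std(x),
            q5=jnp.percentile(x, 5.0),
            q95=jnp.percentile(x, 95.0),
        )

    def get(self, name: str):
        # Look up a field by name, e.g. stats.get("median")
        return getattr(self, name)
```

Because it is a NamedTuple, each summary is available both by name (stats.median) and by position (stats[1]), which is what the "Alias for field number" entries above describe.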

class diabayes.typedefs.StressTransfer(*args, **kwargs)

Bases: Protocol[BC]

class diabayes.typedefs.Variables(mu: Array, state: StateDict)

Bases: Module

The main container class for variables (friction and state). It contains additional convenience routines for mapping between this class object and JAX arrays.

Users would typically not instantiate a Variables object directly; instead, it is created through the Forward.set_initial_values method and retrieved as Forward.variables.

classmethod from_array(x: Float[Array, '...'], keys: tuple[str, ...]) → Variables

Import a JAX array to instantiate the class. The array x can either be an array of scalars (size n for n variables), or an array of time series of shape (n, t). The first element in the array (i.e., x[0]) is assumed to be the friction coefficient mu. The remaining elements are the state variables, matching the number and order of the keys tuple.

set_values(**kwargs) → Variables

Set the values of friction and the state variables through key-value pair arguments. The state parameters are provided individually as key-value pairs as well. Because this class is immutable, set_values returns a copy of the class with updated values.

Examples

>>> state_dict = StateDict(keys=("x", "y"), vals=jnp.array([0.1, 0.2]))
>>> variables = Variables(mu=jnp.asarray(0.6), state=state_dict)
>>> variables = variables.set_values(mu=..., x=..., y=...)

to_array() → Float[Array, '...']

Convert the container values to a JAX array. The order of the output follows the order of StateDict.keys, with the first item being the friction coefficient. For n state variables, the output is an array of shape (1+n,) for scalars, and (t, 1+n) for time series.

Examples

>>> scalars = Variables(...)
>>> mu = scalars.to_array()[0]  # friction is first element
>>> timeseries = Variables(...)
>>> mu = timeseries.to_array()[:, 0]
>>> all_final_values = timeseries.to_array()[-1, :]

Of course, in this example one could simply do scalars.mu and timeseries.mu to extract mu directly.

mu: Array

The friction coefficient

state: StateDict

A container with the various state variables

Module contents