diabayes package

Submodules

diabayes.SVI module

diabayes.SVI.compute_phi(x: Float[Array, 'N M'], gradp: Float[Array, 'N M'], gradq: Float[Array, 'N M']) → Float[Array, 'N M']

Compute the Stein variational gradients for a set of particles (x), given the gradients of the log-likelihood (gradp) and of the log-prior (gradq) evaluated at those particles.

Parameters:
  • x (ParticleArray) – The set of invertible parameters (“particles”), of shape (Nparticles, Ndimensions).

  • gradp (GradientArray) – The gradients of the log-likelihood with respect to the invertible parameters, of shape (Nparticles, Ndimensions).

  • gradq (GradientArray) – The gradients of the log-prior with respect to the invertible parameters, of shape (Nparticles, Ndimensions).

Returns:

grad_x – The update directions for the particles, applied e.g. as new_x = x + step_size * grad_x. Has the same shape as x, gradp, and gradq.

Return type:

Gradients
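The kernel used by compute_phi is not documented here. The following is a minimal sketch of the standard Stein variational gradient descent (SVGD) update, assuming an RBF kernel whose bandwidth is set by the median heuristic commonly used with SVGD. The helper names rbf_kernel_and_grad and compute_phi_sketch are hypothetical, and the sum gradp + gradq plays the role of the gradient of the unnormalised log-posterior.

```python
import jax.numpy as jnp


def rbf_kernel_and_grad(x):
    """RBF kernel matrix and the summed kernel gradient (the SVGD repulsive term)."""
    n = x.shape[0]
    diff = x[:, None, :] - x[None, :, :]  # diff[i, j] = x_i - x_j, shape (N, N, M)
    sq_dist = jnp.sum(diff**2, axis=-1)  # squared pairwise distances, shape (N, N)
    # Assumed median heuristic for the bandwidth; guarded against h = 0 for one particle
    h = jnp.maximum(jnp.median(sq_dist) / jnp.log(n + 1.0), 1e-12)
    k = jnp.exp(-sq_dist / h)  # k[i, j] = k(x_i, x_j), symmetric
    # sum_j grad_{x_j} k(x_j, x_i) = sum_j (2 / h) (x_i - x_j) k(x_i, x_j)
    grad_k = jnp.sum((2.0 / h) * diff * k[:, :, None], axis=1)  # shape (N, M)
    return k, grad_k


def compute_phi_sketch(x, gradp, gradq):
    """Stein variational update directions for particles x of shape (N, M)."""
    k, grad_k = rbf_kernel_and_grad(x)
    grad_log_post = gradp + gradq  # gradient of the unnormalised log-posterior
    # phi(x_i) = (1/N) sum_j [k(x_j, x_i) grad_log_post(x_j) + grad_{x_j} k(x_j, x_i)]
    return (k @ grad_log_post + grad_k) / x.shape[0]
```

The first term pulls particles towards high posterior density, while the kernel-gradient term pushes nearby particles apart, which keeps the ensemble from collapsing onto a single mode. A particle update then reads x_new = x + step_size * compute_phi_sketch(x, gradp, gradq).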

diabayes.SVI.mapped_log_likelihood(log_params: Float[Array, '...'], mu_obs: Float[Array, 'Nt'], noise_std: Float, v0: Float, forward_fn: Callable[[Float[Array, '2'], Float[Array, '...']], Variables]) → Float

Vectorized version of _log_likelihood. Takes arguments similar to those of _log_likelihood, but with additional array axes over which _log_likelihood is mapped.
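The body of _log_likelihood is not shown here. As a sketch only, assuming independent Gaussian noise with standard deviation noise_std and parameters that are sampled in log-space (as the name log_params suggests), the mapping over a leading particle axis can be written with jax.vmap. The functions below, and the toy forward function in the usage note, are hypothetical; the v0 argument of the real signature is omitted.

```python
import jax
import jax.numpy as jnp


def log_likelihood(log_params, mu_obs, noise_std, forward_fn):
    # Assumption: parameters are sampled in log-space and exponentiated here
    params = jnp.exp(log_params)
    mu_pred = forward_fn(params)  # predicted friction time series, shape (Nt,)
    resid = (mu_obs - mu_pred) / noise_std
    n = mu_obs.shape[0]
    # Independent Gaussian noise with standard deviation noise_std
    return -0.5 * jnp.sum(resid**2) - n * jnp.log(noise_std * jnp.sqrt(2.0 * jnp.pi))


def mapped_log_likelihood_sketch(log_params, mu_obs, noise_std, forward_fn):
    # Close over the non-array arguments and map over the leading (particle) axis
    def single(lp):
        return log_likelihood(lp, mu_obs, noise_std, forward_fn)

    return jax.vmap(single)(log_params)
```

Closing over forward_fn keeps the Python callable out of the mapped arguments, so only log_params carries the extra axis. For example, with a toy forward function such as lambda p: p[0] + p[1] * t, passing a (3, 2) array of candidate log-parameters returns three log-likelihood values.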

diabayes.forward_models module

class diabayes.forward_models.Forward(friction_model: FrictionModel, state_evolution: StateEvolution, block_type: Callable[[Float, Variables, SpringBlockConstants], Float])

Bases: object

The Forward class assembles the components of a forward model, such that Forward.__call__ takes the current values of the variables and returns their rate of change, i.e.:

\[\frac{\mathrm{d} \vec{X}}{\mathrm{d}t} = f \left( \vec{X} \right)\]

This system of ODEs can then be integrated with any ODE solver.

The Forward class is instantiated by providing a friction model (of type FrictionModel), a “state” evolution law (of type StateEvolution), and a stress transfer model (the block_type argument, e.g. springblock).

Examples

>>> from diabayes.forward_models import ageing_law, rsf, springblock, Forward
>>> forward_model = Forward(rsf, ageing_law, springblock)
>>> X_dot = forward_model(variables=..., params=..., friction_constants=..., block_constants=...)
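The order of evaluation inside Forward.__call__ is not spelled out above. Judging from the documented signatures (rsf returns the slip rate v, while ageing_law and springblock take v as their first argument), it plausibly evaluates the friction model first. The following is a minimal sketch under that assumption, using a hypothetical ForwardSketch class and a NamedTuple stand-in for Variables rather than the diabayes types.

```python
from dataclasses import dataclass
from typing import Callable, NamedTuple


class Variables(NamedTuple):
    # Stand-in for diabayes.typedefs.Variables (friction coefficient and state)
    mu: float
    state: float


@dataclass
class ForwardSketch:
    friction_model: Callable  # (variables, params, constants) -> slip rate v
    state_evolution: Callable  # (v, variables, params, constants) -> d(state)/dt
    block_type: Callable  # (v, variables, block_constants) -> d(mu)/dt

    def __call__(self, variables, params, friction_constants, block_constants):
        # 1. Solve the friction law for the instantaneous slip rate
        v = self.friction_model(variables, params, friction_constants)
        # 2. Evaluate the state evolution law at that slip rate
        dstate = self.state_evolution(v, variables, params, friction_constants)
        # 3. Evaluate the stress transfer (loading) model at that slip rate
        dmu = self.block_type(v, variables, block_constants)
        return Variables(mu=dmu, state=dstate)
```

Because the three laws enter as callables, a different friction or state law (for instance one for the CNS model) can be swapped in without changing the assembly code.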

diabayes.forward_models.ageing_law(v: Float, variables: Variables, params: RSFParams, constants: RSFConstants) → Float

The conventional ageing law state evolution formulation:

\[\frac{\mathrm{d}\theta}{\mathrm{d}t} = 1 - \frac{v \theta}{D_c}\]

Parameters:
  • v (Float) – Instantaneous fault slip rate [m/s].

  • variables (Variables) – The friction coefficient mu and state parameter theta

  • params (RSFParams) – The rate-and-state parameters a, b, and D_c

  • constants (RSFConstants) – The constant parameters mu0 and v0

Returns:

dtheta – The rate of change of the state variable theta [s/s]

Return type:

Float
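As an illustration of the formula above, the following sketch evaluates the ageing law on plain floats (a hypothetical ageing_law_sketch, not the container-based diabayes signature). It shows the three characteristic regimes: no change at the steady state θ = D_c / v, decay when θ exceeds it, and linear growth (healing) at v = 0. The numerical values are illustrative only.

```python
def ageing_law_sketch(v, theta, Dc):
    """Ageing law: d(theta)/dt = 1 - v * theta / Dc (theta in s, v in m/s, Dc in m)."""
    return 1.0 - v * theta / Dc


Dc = 1e-5  # characteristic slip distance [m] (illustrative value)
v = 1e-6  # slip rate [m/s] (illustrative value)
theta_ss = Dc / v  # steady-state value of theta, where d(theta)/dt = 0

rate_ss = ageing_law_sketch(v, theta_ss, Dc)  # ~0 at steady state
rate_above = ageing_law_sketch(v, 2 * theta_ss, Dc)  # < 0: theta decays towards Dc / v
rate_static = ageing_law_sketch(0.0, theta_ss, Dc)  # 1.0: theta grows linearly at v = 0
```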

diabayes.forward_models.rsf(variables: Variables, params: RSFParams, constants: RSFConstants) → Float

The classical rate-and-state friction law, solved for the slip rate:

\[v(\mu, \theta) = v_0 \exp \left( \frac{1}{a} \left[ \mu - \mu_0 - b \log \left( \frac{v_0 \theta}{D_c} \right) \right] \right)\]

Parameters:
  • variables (Variables) – The friction coefficient mu and state parameter theta

  • params (RSFParams) – The rate-and-state parameters a, b, and D_c

  • constants (RSFConstants) – The constant parameters mu0 and v0

Returns:

v – The instantaneous slip rate in the same units as v0

Return type:

Float
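The expression above is the familiar form μ = μ0 + a log(v/v0) + b log(v0 θ/Dc) solved for v. A minimal plain-float sketch (the name rsf_sketch and all parameter values are illustrative, not diabayes defaults) checks two consequences: at the reference state the slip rate equals v0, and at steady state for another slip rate v, where θ = Dc/v and μ = μ0 + (a − b) log(v/v0), the inversion recovers v.

```python
import math


def rsf_sketch(mu, theta, a, b, Dc, mu0, v0):
    """Slip rate from the rate-and-state friction law (same formula as above)."""
    return v0 * math.exp((mu - mu0 - b * math.log(v0 * theta / Dc)) / a)


# Illustrative values, not defaults of diabayes
a, b, Dc = 0.01, 0.008, 1e-5
mu0, v0 = 0.6, 1e-6

# At the reference state (mu = mu0, theta = Dc / v0) the slip rate equals v0
v_ref = rsf_sketch(mu0, Dc / v0, a, b, Dc, mu0, v0)

# At steady state for another slip rate v, theta = Dc / v and
# mu = mu0 + (a - b) * log(v / v0); inverting the friction law recovers v
v = 1e-5
mu_ss = mu0 + (a - b) * math.log(v / v0)
v_back = rsf_sketch(mu_ss, Dc / v, a, b, Dc, mu0, v0)
```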

diabayes.forward_models.springblock(v: Float, variables: Variables, constants: SpringBlockConstants) → Float

A conventional (non-inertial) spring-block loading formulation:

\[\frac{\mathrm{d} \mu}{\mathrm{d} t} = k \left ( v_{lp} - v(t) \right)\]

Parameters:
  • v (Float) – Instantaneous fault slip rate [m/s].

  • variables (Variables) – The friction coefficient mu and state parameter theta. This argument is not used, but included for call signature consistency.

  • constants (SpringBlockConstants) – The constant parameters stiffness k (units of “friction per metre”) and load-point velocity v_lp (same units as v).

Returns:

dmu – The rate of change of the friction coefficient mu [1/s]

Return type:

Float
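Combining springblock with the ageing law and rsf, steady sliding follows by setting both rates to zero. The spring-block equation then forces the slip rate to match the load point, the ageing law fixes the state, and substituting both into the friction law (written for μ) gives the steady-state friction coefficient:

\[\frac{\mathrm{d} \mu}{\mathrm{d} t} = 0 \;\Rightarrow\; v = v_{lp}, \qquad \frac{\mathrm{d}\theta}{\mathrm{d}t} = 0 \;\Rightarrow\; \theta_{ss} = \frac{D_c}{v_{lp}}\]

\[\mu_{ss} = \mu_0 + a \log \left( \frac{v_{lp}}{v_0} \right) + b \log \left( \frac{v_0 \theta_{ss}}{D_c} \right) = \mu_0 + (a - b) \log \left( \frac{v_{lp}}{v_0} \right)\]

The sign of a − b therefore decides whether steady-state friction increases (a > b, velocity strengthening) or decreases (a < b, velocity weakening) with the loading rate.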

diabayes.solver module

class diabayes.solver.ODESolver(forward_model: Forward, rtol: float = 1e-10, atol: float = 1e-12, checkpoints: int = 100)

Bases: object

atol: float
bayesian_inversion(t: Float[Array, 'Nt'], mu: Float[Array, 'Nt'], noise_std: Float, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants, Nparticles: int = 1000, Nsteps: int = 150, rng: None | int | Array = None)
checkpoints: int
forward_model: Forward
learning_rate: float = 0.01
max_likelihood_inversion(t: Float[Array, 'Nt'], mu: Float[Array, 'Nt'], params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants, verbose: bool = False) → Solution
rtol: float
solve_forward(t: Float[Array, 'Nt'], y0: Variables, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants) → Any

Solve a forward problem using SciPy’s solve_ivp routine. This routine does not propagate gradients, but it is much faster to initialise and to run once than a JIT-compiled JAX implementation. It is therefore the preferred option for exploring the effect of different parameter values.

Parameters:
  • t (Float[Array, "Nt"]) – A vector of time samples at which a solution is requested.

  • y0 (Variables) – The initial values (friction and state) wrapped in a Variables container.

  • params (_Params) – The (invertible) parameters that govern the dynamics, wrapped in a Params container.

  • friction_constants (_Constants) – A container holding the friction constants.

  • block_constants (_BlockConstants) – A container holding the block constants.

Returns:

result – Solution time series of friction and state

Return type:

Variables
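The internals of solve_forward are not documented here, so the following is only a sketch of an equivalent stand-alone integration with scipy.integrate.solve_ivp. It combines the rsf friction law, the ageing law, and the springblock loading model on plain floats (not the Variables and constants containers), using illustrative parameter values that are not diabayes defaults. The system starts at steady state for v0 and is then loaded at v_lp, i.e. a velocity step.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Illustrative (hypothetical) parameter values, not taken from diabayes
a, b, Dc = 0.01, 0.008, 1e-4  # RSF parameters; Dc in metres
mu0, v0 = 0.6, 1e-6  # reference friction and reference slip rate [m/s]
k, v_lp = 1e3, 1e-5  # stiffness [1/m] and load-point velocity [m/s]


def rhs(t, y):
    mu, theta = y
    # Rate-and-state friction law solved for the slip rate
    v = v0 * np.exp((mu - mu0 - b * np.log(v0 * theta / Dc)) / a)
    dtheta = 1.0 - v * theta / Dc  # ageing law
    dmu = k * (v_lp - v)  # non-inertial spring-block loading
    return [dmu, dtheta]


# Start at steady state for v0, then load at v_lp (a velocity step)
y0 = [mu0, Dc / v0]
t_eval = np.linspace(0.0, 500.0, 1001)
sol = solve_ivp(rhs, (t_eval[0], t_eval[-1]), y0, t_eval=t_eval,
                method="LSODA", rtol=1e-8, atol=1e-12)
mu_t, theta_t = sol.y  # friction and state time series
```

LSODA switches automatically between stiff and non-stiff integration, which suits rate-and-state problems whose stiffness changes during a transient. With these velocity-strengthening values (a > b), friction relaxes to μ0 + (a − b) log(v_lp / v0) and the state to Dc / v_lp.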

diabayes.typedefs module

class diabayes.typedefs.BayesianSolution(chains: jaxtyping.Float[Array, 'Nsteps Nparticles Nparams'], log_likelihood: jaxtyping.Float[Array, 'Nsteps'], nan_count: jaxtyping.Float[Array, 'Nsteps'])

Bases: object

chains: Float[Array, 'Nsteps Nparticles Nparams']
cornerplot(nbins=20)
log_likelihood: Float[Array, 'Nsteps']
nan_count: Float[Array, 'Nsteps']
plot_convergence()
sample(solver, y0: Variables, t: Float[Array, 'Nt'], friction_constants: RSFConstants, block_constants: SpringBlockConstants, nsamples: int = 10, rng: None | int | Array = None)
statistics: RSFStatistics | CNSStatistics | None = None
class diabayes.typedefs.CNSConstants(phi_c: jaxtyping.Float, Z: jaxtyping.Float)

Bases: object

Z: Float
phi_c: Float
class diabayes.typedefs.CNSParams(phi_c: jaxtyping.Float, Z: jaxtyping.Float)

Bases: Container

Z: Float
phi_c: Float
tree_flatten()
class diabayes.typedefs.CNSStatistics(a: diabayes.typedefs.Statistics, b: diabayes.typedefs.Statistics, Dc: diabayes.typedefs.Statistics)

Bases: ParamStatistics

Dc: Statistics
a: Statistics
b: Statistics
class diabayes.typedefs.Container

Bases: object

classmethod from_array(x: Iterable)
to_array()
tree_flatten()
classmethod tree_unflatten(aux_data, children)
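The tree_flatten / tree_unflatten pair is the protocol JAX expects from classes registered with jax.tree_util.register_pytree_node_class, which lets such containers pass through jit, grad, and vmap. Whether Container is registered exactly this way is an assumption; the following is a minimal sketch with a hypothetical RSFParamsSketch class mirroring the fields of RSFParams.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class RSFParamsSketch:
    # Hypothetical stand-in for an RSFParams-like container
    a: float
    b: float
    Dc: float

    def tree_flatten(self):
        children = (self.a, self.b, self.Dc)  # leaves that JAX traces and transforms
        aux_data = None  # no static metadata
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def to_array(self):
        # Pack the leaves into a flat array, e.g. for particle-based inversion
        return jnp.array(self.tree_flatten()[0])

    @classmethod
    def from_array(cls, x):
        return cls(*x)
```

Once registered, an instance can be passed directly to a transformed function; for example, jax.grad(lambda p: p.a * p.b) returns its gradient as another RSFParamsSketch.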
class diabayes.typedefs.FrictionModel(*args, **kwargs)

Bases: Protocol

diabayes.typedefs.Gradients

Inversion results

class diabayes.typedefs.ParamStatistics

Bases: Module

classmethod from_state(state)
get(x: str) → Statistics
get_param_names() → Tuple
class diabayes.typedefs.Particles(_particles: Tuple[diabayes.typedefs.RSFParams, ...] | Tuple[diabayes.typedefs.CNSParams, ...])

Bases: object

classmethod from_array(x: Iterable)
classmethod generate(N: int, loc: Float[Array, 'M'], scale: Float[Array, 'M'], key: Array)
to_array()
tree_flatten()
classmethod tree_unflatten(aux_data, children)
class diabayes.typedefs.RSFConstants(v0: jaxtyping.Float, mu0: jaxtyping.Float)

Bases: object

mu0: Float
v0: Float
class diabayes.typedefs.RSFParams(a: jaxtyping.Float, b: jaxtyping.Float, Dc: jaxtyping.Float)

Bases: Container

Dc: Float
a: Float
b: Float
tree_flatten()
class diabayes.typedefs.RSFParticles(_particles: Tuple[diabayes.typedefs.RSFParams, ...] | Tuple[diabayes.typedefs.CNSParams, ...])

Bases: Particles

classmethod tree_unflatten(aux_data, children)
class diabayes.typedefs.RSFStatistics(a: diabayes.typedefs.Statistics, b: diabayes.typedefs.Statistics, Dc: diabayes.typedefs.Statistics, cov: jaxtyping.Float[Array, 'N N'])

Bases: ParamStatistics

Dc: Statistics
a: Statistics
b: Statistics
cov: Float[Array, 'N N']
class diabayes.typedefs.SpringBlockConstants(k: jaxtyping.Float, v_lp: jaxtyping.Float)

Bases: object

k: Float
v_lp: Float
class diabayes.typedefs.StateEvolution(*args, **kwargs)

Bases: Protocol

class diabayes.typedefs.Statistics(mean, median, std, q5, q95)

Bases: NamedTuple

classmethod from_array(x: Float[Array, '...'])
get(x: str) → Float
mean: Float

Alias for field number 0

median: Float

Alias for field number 1

q5: Float

Alias for field number 3

q95: Float

Alias for field number 4

std: Float

Alias for field number 2
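The field aliases above show that Statistics is a NamedTuple with fields (mean, median, std, q5, q95) in that order. The following sketch of a from_array classmethod assumes that q5 and q95 are the 5th and 95th percentiles of the samples, which the names suggest but the documentation does not state; StatisticsSketch is a hypothetical stand-in, not the diabayes class.

```python
from typing import NamedTuple

import jax.numpy as jnp


class StatisticsSketch(NamedTuple):
    # Same field order as diabayes.typedefs.Statistics
    mean: float
    median: float
    std: float
    q5: float
    q95: float

    @classmethod
    def from_array(cls, x):
        # Summary statistics over all samples (e.g. the particles of one parameter)
        return cls(
            mean=jnp.mean(x),
            median=jnp.median(x),
            std=jnp.std(x),
            q5=jnp.percentile(x, 5.0),  # assumed: 5th percentile
            q95=jnp.percentile(x, 95.0),  # assumed: 95th percentile
        )

    def get(self, name):
        # Look up a field by name, e.g. get("median")
        return getattr(self, name)
```

Being a NamedTuple, each statistic is reachable both by name and by position, so s.std and s[2] refer to the same value.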

class diabayes.typedefs.Variables(mu: jaxtyping.Float, state: jaxtyping.Float)

Bases: Container

mu: Float
state: Float
tree_flatten()

Module contents