diabayes package
Submodules
diabayes.SVI module
- diabayes.SVI.compute_phi(x: Float[Array, 'N M'], gradp: Float[Array, 'N M'], gradq: Float[Array, 'N M']) Float[Array, 'N M'] [source]#
Compute the Stein variational gradients for a set of particles (x), given the gradients of the log-likelihood function (gradp) and of the log-prior (gradq).
- Parameters:
x (ParticleArray) – The set of invertible parameters (“particles”) of shape (Nparticles, Ndimensions)
gradp (GradientArray) – The gradients of the log-likelihood with respect to the invertible parameters. Has a shape (Nparticles, Ndimensions)
gradq (GradientArray) – The gradients of the log-prior with respect to the invertible parameters. Has a shape (Nparticles, Ndimensions)
- Returns:
grad_x – The directional gradients for the particle updates, e.g. new_x = x + step_size * grad_x. Has the same shape as x, gradp, and gradq.
- Return type:
Gradients
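To make the update rule concrete, here is a minimal, dependency-free sketch of a Stein variational gradient. It assumes an RBF kernel with a fixed bandwidth; the kernel, the bandwidth argument, and the name compute_phi_sketch are illustrative assumptions and may differ from what diabayes.SVI.compute_phi actually does. Plain nested lists stand in for JAX arrays.

```python
import math


def compute_phi_sketch(x, gradp, gradq, bandwidth=1.0):
    """Stein variational gradient for N particles in M dimensions.

    x, gradp and gradq are lists of N lists of length M. The RBF kernel
    k(x_i, x_j) = exp(-|x_i - x_j|^2 / bandwidth) is an assumption.
    """
    n = len(x)
    dim = len(x[0])
    # Gradient of the log-posterior: log-likelihood plus log-prior.
    score = [[gp + gq for gp, gq in zip(gradp[j], gradq[j])] for j in range(n)]
    phi = [[0.0] * dim for _ in range(n)]
    for i in range(n):
        for j in range(n):
            diff = [x[i][d] - x[j][d] for d in range(dim)]
            k = math.exp(-sum(c * c for c in diff) / bandwidth)
            for d in range(dim):
                # First term: kernel-weighted attraction towards high posterior density.
                # Second term: kernel gradient, which pushes particle i away from particle j.
                phi[i][d] += (k * score[j][d] + 2.0 / bandwidth * diff[d] * k) / n
    return phi


# Two particles sampling a standard normal target with a flat prior:
# the log-likelihood gradient is -x, the log-prior gradient is zero.
x = [[-1.0], [1.0]]
gradp = [[1.0], [-1.0]]
gradq = [[0.0], [0.0]]
phi = compute_phi_sketch(x, gradp, gradq)

# One update step, as in new_x = x + step_size * grad_x
step_size = 0.1
new_x = [[xi + step_size * pi for xi, pi in zip(row, prow)] for row, prow in zip(x, phi)]
```

In this symmetric setup both particles move towards the mode at zero by the same amount, because the attraction term outweighs the repulsion at this separation.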
- diabayes.SVI.mapped_log_likelihood(log_params: Float[Array, '...'], mu_obs: Float[Array, 'Nt'], noise_std: Float, v0: Float, forward_fn: Callable[[Float[Array, '2'], Float[Array, '...']], Variables]) Float #
Vectorized version of _log_likelihood. Takes similar arguments as _log_likelihood but with additional array axes over which _log_likelihood is mapped.
diabayes.forward_models module
- class diabayes.forward_models.Forward(friction_model: FrictionModel, state_evolution: StateEvolution, block_type: Callable[[Float, Variables, SpringBlockConstants], Float])[source]#
Bases:
object
The Forward class assembles the various components that comprise a forward model, such that Forward.__call__ takes some variables and returns the rate of change of these variables, i.e.:
\[\frac{\mathrm{d} \vec{X}}{\mathrm{d}t} = f \left( \vec{X} \right)\]
This forward ODE can then be solved by any ODE solver.
The Forward class is instantiated by providing a friction model (of type FrictionModel), a “state” evolution law (of type StateEvolution), and a stress transfer model.
Examples
>>> from diabayes.forward_models import ageing_law, rsf, springblock, Forward
>>> forward_model = Forward(rsf, ageing_law, springblock)
>>> X_dot = forward_model(variables=..., params=..., friction_constants=..., block_constants=...)
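The assembly described above can be sketched without the package. The three laws follow the formulas and call signatures documented in this module, but the namedtuple containers, the field names of Variables, and the make_forward helper are hypothetical stand-ins, not the diabayes implementation.

```python
import math
from collections import namedtuple

# Hypothetical stand-ins for the diabayes containers (the Variables field names are assumed).
Variables = namedtuple("Variables", ["mu", "theta"])
Params = namedtuple("Params", ["a", "b", "Dc"])
FrictionConstants = namedtuple("FrictionConstants", ["v0", "mu0"])
BlockConstants = namedtuple("BlockConstants", ["k", "v_lp"])


def rsf(variables, params, constants):
    """Slip rate from the rate-and-state friction law."""
    log_state = math.log(constants.v0 * variables.theta / params.Dc)
    return constants.v0 * math.exp(
        (variables.mu - constants.mu0 - params.b * log_state) / params.a
    )


def ageing_law(v, variables, params, constants):
    """Rate of change of the state variable theta."""
    return 1.0 - v * variables.theta / params.Dc


def springblock(v, variables, constants):
    """Rate of change of friction for a non-inertial spring-block."""
    return constants.k * (constants.v_lp - v)


def make_forward(friction_model, state_evolution, block_type):
    """Compose the three components into a function returning dX/dt."""
    def forward(variables, params, friction_constants, block_constants):
        v = friction_model(variables, params, friction_constants)
        dtheta = state_evolution(v, variables, params, friction_constants)
        dmu = block_type(v, variables, block_constants)
        return Variables(mu=dmu, theta=dtheta)
    return forward


forward_model = make_forward(rsf, ageing_law, springblock)
params = Params(a=0.01, b=0.015, Dc=1e-5)
friction_constants = FrictionConstants(v0=1e-6, mu0=0.6)
steady = Variables(mu=0.6, theta=1e-5 / 1e-6)  # theta = Dc / v0

# Loading at the reference velocity: the system sits at steady state.
X_dot = forward_model(steady, params, friction_constants, BlockConstants(k=1e3, v_lp=1e-6))
# Loading ten times faster: the spring loads up and friction starts to rise.
X_dot_fast = forward_model(steady, params, friction_constants, BlockConstants(k=1e3, v_lp=1e-5))
```

Because each component only sees the arguments in its own signature, swapping in a different friction model or state evolution law leaves the composition unchanged.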
- diabayes.forward_models.ageing_law(v: Float, variables: Variables, params: RSFParams, constants: RSFConstants) Float [source]#
The conventional ageing law state evolution formulation:
\[\frac{\mathrm{d}\theta}{\mathrm{d}t} = 1 - \frac{v \theta}{D_c}\]
- Parameters:
v (Float) – Instantaneous fault slip rate [m/s].
variables (Variables) – The friction coefficient mu and state parameter theta
params (RSFParams) – The rate-and-state parameters a, b, and D_c
constants (RSFConstants) – The constant parameters mu0 and v0
- Returns:
dtheta – The rate of change of the state variable theta [s/s]
- Return type:
Float
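As a quick consistency check on the ageing law, setting the time derivative to zero gives the steady-state value of the state variable at a constant slip rate v. This is a standard rearrangement of the formula above, not additional package behaviour:

\[\frac{\mathrm{d}\theta}{\mathrm{d}t} = 0 \quad \Rightarrow \quad \theta_{ss} = \frac{D_c}{v}\]

For example, D_c = 10 µm and v = 1 µm/s give a steady-state theta of 10 s.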
- diabayes.forward_models.rsf(variables: Variables, params: RSFParams, constants: RSFConstants) Float [source]#
The classical rate-and-state friction law, given by:
\[v(\mu, \theta) = v_0 \exp \left( \frac{1}{a} \left[ \mu - \mu_0 - b \log \left( \frac{v_0 \theta}{D_c} \right) \right] \right)\]
- Parameters:
variables (Variables) – The friction coefficient mu and state parameter theta
params (RSFParams) – The rate-and-state parameters a, b, and D_c
constants (RSFConstants) – The constant parameters mu0 and v0
- Returns:
v – The instantaneous slip rate in the same units as v0
- Return type:
Float
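A worked rearrangement may help relate this velocity form to the more familiar friction form. Assuming the state variable sits at the ageing-law steady state, theta = D_c / v, taking the logarithm of the law above and solving for mu gives:

\[a \log \left( \frac{v}{v_0} \right) = \mu - \mu_0 - b \log \left( \frac{v_0}{v} \right) \quad \Rightarrow \quad \mu_{ss} = \mu_0 + (a - b) \log \left( \frac{v}{v_0} \right)\]

Steady-state friction therefore increases with slip rate when a > b (velocity strengthening) and decreases when a < b (velocity weakening).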
- diabayes.forward_models.springblock(v: Float, variables: Variables, constants: SpringBlockConstants) Float [source]#
A conventional (non-inertial) spring-block loading formulation:
\[\frac{\mathrm{d} \mu}{\mathrm{d} t} = k \left ( v_{lp} - v(t) \right)\]
- Parameters:
v (Float) – Instantaneous fault slip rate [m/s].
variables (Variables) – The friction coefficient mu and state parameter theta. This argument is not used, but included for call signature consistency.
constants (SpringBlockConstants) – The constant parameters stiffness k (units of “friction per metre”) and load-point velocity v_lp (same units as v).
- Returns:
dmu – The rate of change of the friction coefficient mu [1/s]
- Return type:
Float
diabayes.solver module
- class diabayes.solver.ODESolver(forward_model: Forward, rtol: float = 1e-10, atol: float = 1e-12, checkpoints: int = 100)[source]#
Bases:
object
- atol: float#
- bayesian_inversion(t: Float[Array, 'Nt'], mu: Float[Array, 'Nt'], noise_std: Float, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants, Nparticles: int = 1000, Nsteps: int = 150, rng: None | int | Array = None)[source]#
- checkpoints: int#
- learning_rate: float = 0.01#
- max_likelihood_inversion(t: Float[Array, 'Nt'], mu: Float[Array, 'Nt'], params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants, verbose: bool = False) Solution [source]#
- rtol: float#
- solve_forward(t: Float[Array, 'Nt'], y0: Variables, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants) Any [source]#
Solve a forward problem using SciPy’s solve_ivp routine. While this routine doesn’t propagate any gradients, it is much faster to initialise and to perform a single forward run. It is therefore preferred over a JIT-compiled JAX implementation when exploring different parameter values.
- Parameters:
t (Float[Array, "Nt"]) – A vector of time samples where a solution is requested
y0 (Variables) – The initial values (friction and state) wrapped in a Variables container.
params (_Params) – The (invertible) parameters that govern the dynamics, wrapped in a Params container.
friction_constants (_Constants) – A container object containing the friction constants
block_constants (_BlockConstants) – A container object containing the block constants
- Returns:
result – Solution time series of friction and state
- Return type:
Any
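To illustrate what a single forward run produces, the sketch below integrates the spring-block rate-and-state system through a tenfold step in load-point velocity. A fixed-step classical Runge-Kutta loop stands in for SciPy's solve_ivp so that the example has no dependencies; the function names, the parameter values, and the integrator are assumptions for illustration, not how ODESolver.solve_forward is implemented.

```python
import math


def rhs(mu, theta, a, b, Dc, mu0, v0, k, v_lp):
    """Right-hand side (dmu/dt, dtheta/dt) of the spring-block RSF system."""
    v = v0 * math.exp((mu - mu0 - b * math.log(v0 * theta / Dc)) / a)
    return k * (v_lp - v), 1.0 - v * theta / Dc


def rk4_solve(t, y0, **p):
    """Integrate with classical RK4 on the sample times t; returns (mu, theta) lists."""
    mu, theta = y0
    mus, thetas = [mu], [theta]
    for t0, t1 in zip(t[:-1], t[1:]):
        h = t1 - t0
        k1 = rhs(mu, theta, **p)
        k2 = rhs(mu + 0.5 * h * k1[0], theta + 0.5 * h * k1[1], **p)
        k3 = rhs(mu + 0.5 * h * k2[0], theta + 0.5 * h * k2[1], **p)
        k4 = rhs(mu + h * k3[0], theta + h * k3[1], **p)
        mu += h * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0]) / 6.0
        theta += h * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1]) / 6.0
        mus.append(mu)
        thetas.append(theta)
    return mus, thetas


# Illustrative, velocity-strengthening parameters (a > b); SI units throughout.
p = dict(a=0.010, b=0.006, Dc=1e-5, mu0=0.6, v0=1e-6, k=1e3, v_lp=1e-5)
t = [0.01 * i for i in range(3001)]  # 0 to 30 s
y0 = (p["mu0"], p["Dc"] / p["v0"])   # steady state at the reference velocity v0
mus, thetas = rk4_solve(t, y0, **p)

# Expected steady state at the new load-point velocity.
mu_ss = p["mu0"] + (p["a"] - p["b"]) * math.log(p["v_lp"] / p["v0"])
theta_ss = p["Dc"] / p["v_lp"]
```

After the transient has decayed, the last samples of mus and thetas approach mu_ss and theta_ss. A fixed step is adequate for these mild parameters; stiffer settings would call for an adaptive solver.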
diabayes.typedefs module
- class diabayes.typedefs.BayesianSolution(chains: jaxtyping.Float[Array, 'Nsteps Nparticles Nparams'], log_likelihood: jaxtyping.Float[Array, 'Nsteps'], nan_count: jaxtyping.Float[Array, 'Nsteps'])[source]#
Bases:
object
- chains: Float[Array, 'Nsteps Nparticles Nparams']#
- log_likelihood: Float[Array, 'Nsteps']#
- nan_count: Float[Array, 'Nsteps']#
- sample(solver, y0: Variables, t: Float[Array, 'Nt'], friction_constants: RSFConstants, block_constants: SpringBlockConstants, nsamples: int = 10, rng: None | int | Array = None)[source]#
- statistics: RSFStatistics | CNSStatistics | None = None#
- class diabayes.typedefs.CNSConstants(phi_c: jaxtyping.Float, Z: jaxtyping.Float)[source]#
Bases:
object
- Z: Float#
- phi_c: Float#
- class diabayes.typedefs.CNSParams(phi_c: jaxtyping.Float, Z: jaxtyping.Float)[source]#
Bases:
Container
- Z: Float#
- phi_c: Float#
- class diabayes.typedefs.CNSStatistics(a: diabayes.typedefs.Statistics, b: diabayes.typedefs.Statistics, Dc: diabayes.typedefs.Statistics)[source]#
Bases:
ParamStatistics
- Dc: Statistics#
- a: Statistics#
- b: Statistics#
- class diabayes.typedefs.ParamStatistics[source]#
Bases:
Module
- get(x: str) Statistics [source]#
- class diabayes.typedefs.Particles(_particles: Tuple[diabayes.typedefs.RSFParams, ...] | Tuple[diabayes.typedefs.CNSParams, ...])[source]#
Bases:
object
- class diabayes.typedefs.RSFConstants(v0: jaxtyping.Float, mu0: jaxtyping.Float)[source]#
Bases:
object
- mu0: Float#
- v0: Float#
- class diabayes.typedefs.RSFParams(a: jaxtyping.Float, b: jaxtyping.Float, Dc: jaxtyping.Float)[source]#
Bases:
Container
- Dc: Float#
- a: Float#
- b: Float#
- class diabayes.typedefs.RSFParticles(_particles: Tuple[diabayes.typedefs.RSFParams, ...] | Tuple[diabayes.typedefs.CNSParams, ...])[source]#
Bases:
Particles
- class diabayes.typedefs.RSFStatistics(a: diabayes.typedefs.Statistics, b: diabayes.typedefs.Statistics, Dc: diabayes.typedefs.Statistics, cov: jaxtyping.Float[Array, 'N N'])[source]#
Bases:
ParamStatistics
- Dc: Statistics#
- a: Statistics#
- b: Statistics#
- cov: Float[Array, 'N N']#
- class diabayes.typedefs.SpringBlockConstants(k: jaxtyping.Float, v_lp: jaxtyping.Float)[source]#
Bases:
object
- k: Float#
- v_lp: Float#