diabayes package
Submodules
diabayes.SVI module
- diabayes.SVI.compute_phi(x: Float[Array, 'N M'], gradp: Float[Array, 'N M'], gradq: Float[Array, 'N M']) → Float[Array, 'N M']
Compute the Stein variational gradients for a set of particles (x), given the gradients of the log-likelihood function (gradp) and of the log-prior (gradq).
- Parameters:
x (ParticleArray) – The set of invertible parameters (“particles”), of shape (Nparticles, Ndimensions)
gradp (GradientArray) – The gradients of the log-likelihood with respect to the invertible parameters, of shape (Nparticles, Ndimensions)
gradq (GradientArray) – The gradients of the log-prior with respect to the invertible parameters, of shape (Nparticles, Ndimensions)
- Returns:
grad_x – The directional gradients for the particle updates, e.g. new_x = x + step_size * grad_x. Has the same shape as x, gradp, and gradq.
- Return type:
Gradients
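As an illustration of the quantity compute_phi returns, the sketch below implements the standard Stein variational gradient in NumPy, φ(x_i) = (1/N) Σ_j [k(x_j, x_i) ∇ log p(x_j) + ∇_{x_j} k(x_j, x_i)], taking ∇ log p as gradp + gradq. The RBF kernel and the median-heuristic bandwidth are assumptions (this page does not state which kernel diabayes uses), and compute_phi_sketch is a hypothetical name.

```python
import numpy as np

def compute_phi_sketch(x, gradp, gradq):
    """Stein variational gradient with an RBF kernel (assumed kernel choice).

    x, gradp, gradq: arrays of shape (Nparticles, Ndimensions).
    Returns an array of the same shape.
    """
    n = x.shape[0]
    # Gradient of the log-posterior = log-likelihood + log-prior
    grad_logpost = gradp + gradq
    # diff[i, j] = x_i - x_j, shape (N, N, D); squared distances, shape (N, N)
    diff = x[:, None, :] - x[None, :, :]
    sqdist = np.sum(diff**2, axis=-1)
    # Median heuristic for the kernel bandwidth, guarded against zero
    h = max(np.median(sqdist) / np.log(n + 1.0), 1e-12)
    K = np.exp(-sqdist / h)  # symmetric kernel matrix k(x_i, x_j)
    # Driving term: sum_j k(x_j, x_i) * grad log p(x_j)
    drive = K @ grad_logpost
    # Repulsive term: sum_j grad_{x_j} k(x_j, x_i) = (2 / h) * sum_j (x_i - x_j) k_ij
    repulse = (2.0 / h) * np.sum(diff * K[:, :, None], axis=1)
    return (drive + repulse) / n
```

The first term pulls particles towards regions of high posterior density, while the second (repulsive) term keeps them spread out, so that the swarm approximates the whole posterior rather than collapsing onto its mode. The result can be used directly in new_x = x + step_size * grad_x; ODESolver presumably feeds these gradients to the Adam optimiser instead (see learning_rate).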
- diabayes.SVI.mapped_log_likelihood(log_params: RSFParams | CNSParams, mu_obs: Float[Array, 'Nt'], noise_std: Float, forward_fn: Callable[[RSFParams | CNSParams], Variables]) → Float
Vectorized version of _log_likelihood. Takes the same arguments as _log_likelihood, but with additional array axes over which _log_likelihood is mapped.
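The body of _log_likelihood is not shown on this page. As a sketch under the assumption of independent Gaussian noise with standard deviation noise_std, the log-likelihood of a modelled friction curve reduces to a scaled sum of squared residuals. In JAX, a per-particle function would typically be mapped with jax.vmap; in the NumPy version below, broadcasting over leading axes plays the same role. gaussian_log_likelihood is a hypothetical name, and it takes the already-computed model curve rather than log_params and forward_fn.

```python
import numpy as np

def gaussian_log_likelihood(mu_model, mu_obs, noise_std):
    """Log-likelihood of an observed friction curve under i.i.d. Gaussian noise.

    mu_model has shape (..., Nt); any leading axes (e.g. Nparticles) are
    broadcast against mu_obs of shape (Nt,), giving one value per leading index.
    """
    resid = (mu_obs - mu_model) / noise_std
    n_t = mu_obs.shape[-1]
    # Sum of squared normalised residuals plus the Gaussian normalisation constant
    return -0.5 * np.sum(resid**2, axis=-1) - n_t * np.log(noise_std * np.sqrt(2.0 * np.pi))

mu_obs = np.array([0.600, 0.605, 0.610])
# Two hypothetical particles: a perfect fit and a curve offset by one noise_std
mu_models = np.stack([mu_obs, mu_obs + 0.01])
ll = gaussian_log_likelihood(mu_models, mu_obs, noise_std=0.01)  # shape (2,)
```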
diabayes.forward_models module
- class diabayes.forward_models.Forward(friction_model: FrictionModel, state_evolution: Dict[str, Callable], stress_transfer: StressTransfer[BC])
Bases: Generic[BC]
The Forward class assembles the various components that comprise a forward model, such that Forward.__call__ takes some variables and returns the rate of change of these variables:
\[\frac{\mathrm{d} \vec{X}}{\mathrm{d}t} = f \left( \vec{X} \right)\]
This forward ODE can then be solved by any ODE solver.
The Forward class is instantiated by providing a friction model, a “state” evolution law, and a stress transfer model.
Examples
>>> from diabayes.forward_models import ageing_law, rsf, springblock, Forward
>>> forward_model = Forward(rsf, {"theta": ageing_law}, springblock)
>>> X_dot = forward_model(t=..., variables=..., params=..., friction_constants=..., block_constants=...)
- __call__(t: Float, variables: Variables, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: BC) → Variables
Calculate the rate of change of the variables, defining the ODE to be solved.
- Parameters:
t (Float) – Current value of time [s]
variables (diabayes.Variables) – The instantaneous values of the variables for which the time derivative will be computed
params (_Params) – The forward model (invertible) parameters
friction_constants (_Constants) – The forward model constants
block_constants (_BlockConstants) – The constants associated with the stress transfer
- Returns:
dvars – The instantaneous rate of change of the variables
- Return type:
Variables
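The composition that Forward performs can be sketched in plain Python: the friction law gives the slip rate, each entry of the state-evolution dict gives the time derivative of one state variable, and the stress-transfer model gives the time derivative of friction. MiniForward and the dict-based variables, parameters, and constants are hypothetical simplifications (diabayes uses Variables, RSFParams, and constants containers, and JAX), only the non-inertial spring-block is covered, and the parameter values are illustrative.

```python
import math

class MiniForward:
    """Hypothetical, simplified analogue of Forward (not the diabayes class)."""

    def __init__(self, friction_model, state_evolution, stress_transfer):
        self.friction_model = friction_model    # (variables, params, consts) -> v
        self.state_evolution = state_evolution  # {name: law(v, variables, params, consts)}
        self.stress_transfer = stress_transfer  # (v, block_consts) -> dmu/dt

    def __call__(self, t, variables, params, consts, block_consts):
        # 1. Slip rate from the friction law
        v = self.friction_model(variables, params, consts)
        # 2. Time derivative of every state variable
        dstate = {name: law(v, variables, params, consts)
                  for name, law in self.state_evolution.items()}
        # 3. Time derivative of friction from the stress transfer model
        dmu = self.stress_transfer(v, block_consts)
        return {"mu": dmu, **dstate}


def rsf(variables, params, consts):
    log_term = math.log(consts["v0"] * variables["theta"] / params["Dc"])
    return consts["v0"] * math.exp(
        (variables["mu"] - consts["mu0"] - params["b"] * log_term) / params["a"])


def ageing_law(v, variables, params, consts):
    return 1.0 - v * variables["theta"] / params["Dc"]


def springblock(v, block_consts):
    return block_consts["k"] * (block_consts["v_lp"] - v)


forward_model = MiniForward(rsf, {"theta": ageing_law}, springblock)
params = {"a": 0.01, "b": 0.008, "Dc": 1e-5}   # illustrative values
consts = {"mu0": 0.6, "v0": 1e-6}
block_consts = {"k": 1e3, "v_lp": 1e-5}
X = {"mu": 0.6, "theta": params["Dc"] / consts["v0"]}  # steady state at v = v0
X_dot = forward_model(0.0, X, params, consts, block_consts)
```

Keeping the state laws in a dict keyed by variable name is what lets additional state variables be added without changing the friction or stress-transfer components.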
- diabayes.forward_models.ageing_law(v: Float, variables: Variables, params: RSFParams, constants: RSFConstants) → Float
The conventional ageing law state evolution formulation:
\[\frac{\mathrm{d}\theta}{\mathrm{d}t} = 1 - \frac{v \theta}{D_c}\]
- Parameters:
v (Float) – Instantaneous fault slip rate [m/s].
variables (Variables) – The friction coefficient mu and state parameter theta
params (RSFParams) – The rate-and-state parameters, including Dc
constants (RSFConstants) – The constant parameters mu0 and v0 (not used)
- Returns:
dtheta – The rate of change of the state variable [s/s]
- Return type:
Float
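The ageing law has a steady state θ_ss = D_c / v, and during a hold (v ≈ 0) the state grows linearly in time (dθ/dt ≈ 1), which is the "ageing" the law is named after. A minimal standalone check with plain floats (not the diabayes containers; the values are illustrative):

```python
def ageing_law(v, theta, Dc):
    """Ageing-law state evolution: d(theta)/dt = 1 - v * theta / Dc."""
    return 1.0 - v * theta / Dc

Dc = 1e-5          # characteristic slip distance [m] (example value)
v = 1e-6           # slip rate [m/s] (example value)
theta_ss = Dc / v  # steady state, where d(theta)/dt = 0
```

Above the steady state (θ > θ_ss) the state decays, and below it the state grows, so θ always relaxes towards D_c / v over a slip distance of order D_c.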
- diabayes.forward_models.inertial_springblock(t: Float, v: Float, v_partials: Variables, variables: Variables, dstate: Float[Array, '...'], constants: InertialSpringBlockConstants) → Float
An inertial spring-block loading formulation:
\[\frac{\mathrm{d} \mu}{\mathrm{d} t} = \left[ \frac{\partial v}{\partial \mu} \right]^{-1} \left( \frac{1}{M} \left[ k \left( v_{lp} t - x \right) - \mu \right] - \frac{\partial v}{\partial \theta} \frac{\mathrm{d} \theta}{\mathrm{d} t} - \dots \right)\]
The acceleration term in the classical inertial spring-block formulation is decomposed into its partial derivatives, avoiding the need for solving a second-order ODE. These partial derivatives (v_partials) are computed using the JAX autodiff framework.
Notes
This formulation is rather stiff, and for certain parameter values it can require extremely small time steps to maintain numerical accuracy. It is recommended to use the conventional (non-inertial) springblock formulation for basic velocity-step and slide-hold-slide simulations. Inertia is only really needed for stick-slip simulations.
- Parameters:
t (Float) – Current value of time [s]
v (Float) – Instantaneous fault slip rate [m/s]
v_partials (Variables) – The partial derivatives of slip rate to the relevant variables (friction and state variables)
variables (Variables) – The instantaneous values of the variables: friction (mu), slip (slip), and other state variables (not used)
dstate (Array) – The time derivatives of the state variables. The radiation term is v_partials @ dstate (excluding mu)
constants (InertialSpringBlockConstants) – The spring-block constants containing the mass term M, the stiffness k, and the load-point velocity v_lp
- Returns:
dmu – The rate of change of the friction coefficient [1/s]
- Return type:
Float
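The decomposition above follows from the chain rule, dv/dt = (∂v/∂μ) dμ/dt + (∂v/∂θ) dθ/dt, solved for dμ/dt with dv/dt taken from the momentum balance. The sketch below assumes the classical rsf law with a single ageing-law state variable θ, and uses the analytic partials ∂v/∂μ = v/a and ∂v/∂θ = -b v/(a θ) in place of JAX autodiff. Both function names are hypothetical.

```python
import math

def rsf_with_partials(mu, theta, a, b, Dc, mu0, v0):
    """Classical RSF slip rate and its analytic partial derivatives."""
    v = v0 * math.exp((mu - mu0 - b * math.log(v0 * theta / Dc)) / a)
    dv_dmu = v / a                    # exponent depends on mu as mu / a
    dv_dtheta = -b * v / (a * theta)  # exponent depends on theta as -(b / a) * log(theta)
    return v, dv_dmu, dv_dtheta

def inertial_dmu_dt(t, mu, theta, slip, a, b, Dc, mu0, v0, k, v_lp, M):
    """d(mu)/dt for an inertial spring-block with one (ageing-law) state variable."""
    v, dv_dmu, dv_dtheta = rsf_with_partials(mu, theta, a, b, Dc, mu0, v0)
    dtheta = 1.0 - v * theta / Dc             # ageing law
    accel = (k * (v_lp * t - slip) - mu) / M  # dv/dt from the momentum balance
    # dv/dt = dv_dmu * dmu/dt + dv_dtheta * dtheta/dt, solved for dmu/dt
    return (accel - dv_dtheta * dtheta) / dv_dmu
```

Dividing by ∂v/∂μ = v/a is also where the stiffness mentioned in the Notes comes from: at low slip rates this factor is tiny, so small force imbalances produce very large dμ/dt.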
- diabayes.forward_models.rsf(variables: Variables, params: RSFParams, constants: RSFConstants) → Float
The classical rate-and-state friction law:
\[v(\mu, \theta) = v_0 \exp \left( \frac{1}{a} \left[ \mu - \mu_0 - b \log \left( \frac{v_0 \theta}{D_c} \right) \right] \right)\]
- Parameters:
variables (Variables) – The friction coefficient mu and state parameter theta
params (RSFParams) – The rate-and-state parameters a, b, and Dc
constants (RSFConstants) – The constant parameters mu0 and v0
- Returns:
v – The instantaneous slip rate, in the same units as v0
- Return type:
Float
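Evaluating the law at steady state (θ = D_c / v) is a quick sanity check: the friction then reduces to μ_ss = μ0 + (a - b) ln(v / v0), so feeding μ_ss back into the law should recover v. A minimal scalar sketch with plain floats and illustrative values (the diabayes function takes Variables, RSFParams, and RSFConstants containers instead; both function names here are hypothetical):

```python
import math

def rsf_slip_rate(mu, theta, a, b, Dc, mu0, v0):
    """Classical rate-and-state friction law, solved for the slip rate v."""
    return v0 * math.exp((mu - mu0 - b * math.log(v0 * theta / Dc)) / a)

def steady_state_friction(v, a, b, mu0, v0):
    """Friction at steady state (theta = Dc / v): mu0 + (a - b) * ln(v / v0)."""
    return mu0 + (a - b) * math.log(v / v0)

a, b, Dc, mu0, v0 = 0.010, 0.008, 1e-5, 0.6, 1e-6  # illustrative values
v = 1e-5
mu_ss = steady_state_friction(v, a, b, mu0, v0)
v_back = rsf_slip_rate(mu_ss, Dc / v, a, b, Dc, mu0, v0)  # should equal v
```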
- diabayes.forward_models.slip_rate(v: Float, variables: Variables, params: RSFParams, constants: RSFConstants) → Float
Evolve slip from the slip rate:
\[\frac{\mathrm{d} x}{\mathrm{d} t} = v\]
- Parameters:
v (Float) – Instantaneous fault slip rate [m/s]
variables (Variables) – The friction coefficient mu and state parameter theta (not used)
params (RSFParams) – The rate-and-state parameters (not used)
constants (RSFConstants) – The constant parameters mu0 and v0 (not used)
- Returns:
v – The rate of change of slip [m/s]
- Return type:
Float
- diabayes.forward_models.springblock(t: Float, v: Float, v_partials: Variables, variables: Variables, dstate: Float[Array, '...'], constants: SpringBlockConstants) → Float
A conventional (non-inertial) spring-block loading formulation:
\[\frac{\mathrm{d} \mu}{\mathrm{d} t} = k \left( v_{lp} - v(t) \right)\]
- Parameters:
v (Float) – Instantaneous fault slip rate [m/s].
variables (Variables) – The instantaneous variables. This argument is not used, but included for call signature consistency.
constants (SpringBlockConstants) – The constant parameters: stiffness k (units of “friction per metre”) and load-point velocity v_lp (same units as v).
- Returns:
dmu – The rate of change of the friction coefficient [1/s]
- Return type:
Float
diabayes.solver module
- class diabayes.solver.ODESolver(forward_model: Forward, rtol: float = 1e-08, atol: float = 1e-12, checkpoints: int = 100)
Bases:
object
The main solver class that contains forward and inverse modelling methods.
- bayesian_inversion(t: Float[Array, 'Nt'], mu: Float[Array, 'Nt'], noise_std: Float, y0: Variables, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants | InertialSpringBlockConstants, Nparticles: int = 1000, Nsteps: int = 150, rng: None | int | Array = None) → BayesianSolution
A Bayesian inversion routine using the Stein Variational Inference method.
- Parameters:
t (Array) – A vector of time values (in seconds). The time steps do not need to be uniform
mu (Array) – The observed friction curve sampled at t
noise_std (Float) – An estimate of the standard deviation of the noise in the measured friction curve. A conservative value is recommended, i.e. if the noise has a standard deviation of $10^{-3}$, a good starting point would be to set noise_std = 0.5e-3
y0 (Variables) – The initial values for the modelled friction and any state variables
params (_Params) – The initial guess for the invertible parameters that characterise the forward problem. It is recommended to use the result from ODESolver.max_likelihood_inversion. The prior distribution will be centered around this initial guess
friction_constants (_Constants) – The non-invertible constants that characterise the forward problem
block_constants (_BlockConstants) – The stress transfer constants (e.g. stiffness and loading rate)
Nparticles (int) – The number of particles (= posterior samples) to include. A higher value gives a more accurate estimation of the posterior distribution, at a higher computational cost.
Nsteps (int) – The number of iterations before convergence is expected to be achieved. It is recommended to start with a value of 100 and then check whether an equilibrium was actually reached. Increasing this value beyond the point of equilibrium does not improve the result; it only adds computational cost.
rng (None, int, jax.random.PRNGKey) – The random seed used to initialise the particle swarm distribution. If None, the current time will be used as a seed, which leads to a different result for each realisation. When an integer is provided, a new jax.random.PRNGKey is generated from it.
- Returns:
result – The result of the inversion, encapsulated in a BayesianSolution container, which provides access to the full convergence chains, as well as diagnostic and visualisation routines.
- Return type:
diabayes.BayesianSolution
Notes
This method is currently only compatible with standard rate-and-state friction models.
- max_likelihood_inversion(t: Float[Array, 'Nt'], mu: Float[Array, 'Nt'], y0: Variables, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants | InertialSpringBlockConstants, verbose: bool = False) → Solution
Minimises the least-squares residuals between the observed friction curve and the parameterised one, using the Levenberg-Marquardt algorithm.
- Parameters:
t (Array) – A vector of time values (in seconds). The time steps do not need to be uniform
mu (Array) – The observed friction curve sampled at t
y0 (Variables) – The initial values for the modelled friction and any state variables
params (_Params) – The initial guess for the invertible parameters that characterise the forward problem. These need to be sufficiently close to the “true” values for the algorithm to converge
friction_constants (_Constants) – The non-invertible constants that characterise the forward problem
block_constants (_BlockConstants) – The stress transfer constants (e.g. stiffness and loading rate)
verbose (bool) – Whether or not to output detailed progress of the inversion. Defaults to False
- Returns:
sol – The inversion result, including various diagnostics. The inverted parameter values can be accessed as sol.value
- Return type:
optimistix.Solution
Notes
If an error is produced in the first iteration step, it is quite possible that the initial guess parameters were too far off from the “true” values (i.e., the mismatch between the observed and modelled friction curves is too large), breaking the Gauss-Newton step of the Levenberg-Marquardt algorithm. Initial manual tuning is recommended.
- solve_forward(t: Float[Array, 'Nt'], y0: Variables, params: RSFParams | CNSParams, friction_constants: RSFConstants, block_constants: SpringBlockConstants | InertialSpringBlockConstants, method: str = 'RK45') → Any
Solve a forward problem using SciPy’s solve_ivp routine. While this routine doesn’t propagate any gradients, it is much faster to initialise and run for a single forward simulation. Hence, for exploring different parameter values, it is preferred over a JIT-compiled JAX implementation.
- Parameters:
t (Float[Array, "Nt"]) – A vector of time samples where a solution is requested
y0 (Variables) – The initial values (friction and state) wrapped in a Variables container.
params (_Params) – The (invertible) parameters that govern the dynamics, wrapped in a Params container.
friction_constants (_Constants) – A container object containing the friction constants
block_constants (_BlockConstants) – A container object containing the block constants
- Returns:
result – Solution time series of friction and state
- Return type:
Any
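What such a forward run amounts to can be sketched directly against scipy.integrate.solve_ivp for the rsf + ageing_law + springblock combination, with state vector [mu, theta]. This is not the diabayes implementation, and the parameter values are illustrative assumptions chosen for a velocity-strengthening (a > b) velocity step from v0 to v_lp = 10 v0, starting from steady state at v0.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Illustrative (assumed) parameter values
a, b, Dc = 0.010, 0.008, 1e-5   # RSF parameters; Dc in [m]
mu0, v0 = 0.6, 1e-6             # reference friction and slip rate [m/s]
k, v_lp = 1e3, 1e-5             # stiffness [1/m] and load-point velocity [m/s]

def rhs(t, y):
    """Right-hand side dX/dt for X = [mu, theta]."""
    mu, theta = y
    v = v0 * np.exp((mu - mu0 - b * np.log(v0 * theta / Dc)) / a)  # friction law
    dmu = k * (v_lp - v)           # non-inertial spring-block
    dtheta = 1.0 - v * theta / Dc  # ageing law
    return [dmu, dtheta]

t_eval = np.linspace(0.0, 50.0, 501)
y0 = [mu0, Dc / v0]  # steady state for v = v0
sol = solve_ivp(rhs, (t_eval[0], t_eval[-1]), y0, t_eval=t_eval,
                method="RK45", rtol=1e-8, atol=1e-12)
mu_model = sol.y[0]  # friction curve sampled at t_eval
```

Because a > b in this example, friction overshoots after the step and then settles at the steady-state value mu0 + (a - b) ln(v_lp / v0).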
- atol: float
The absolute tolerance used by the ODE solver
- checkpoints: int
The number of checkpoints to use to compute the (adjoint) gradients through the ODE routine. A higher number increases stability and speed, at the expense of more GPU memory
- forward_model: Forward
An instantiated diabayes.Forward class, including friction law, state evolution equations, and stress transfer equation
- learning_rate: float = 0.01
The initial learning rate provided to the Adam algorithm for the Stein Variational Inference. The default value of 1e-2 seems like a sensible choice for most models.
- rtol: float
The relative tolerance used by the ODE solver
diabayes.typedefs module
- class diabayes.typedefs.BC
Function signatures
alias of TypeVar('BC', bound=SpringBlockConstants | InertialSpringBlockConstants, contravariant=True)
- class diabayes.typedefs.BayesianSolution(log_params: RSFParams | CNSParams, log_likelihood: Float[Array, 'Nsteps'], nan_count: Float[Array, 'Nsteps'])
Bases:
object
The result of a Bayesian inversion. This data class stores the particle trajectories (the equivalent of MCMC chains), the final equilibrium state, and various statistics, as well as helper routines to diagnose or visualise the results.
Examples
>>> result = solver.bayesian_inversion(...)
>>> print(result.nan_count)  # Check for NaN values
>>> print(result.statistics.Dc.median)  # Get the median of Dc
>>> result.plot_convergence()  # Visually inspect convergence
>>> result.cornerplot(nbins=15)  # Create a cornerplot
- cornerplot(nbins=20)
Create a corner plot from the particle positions. For N parameters, a corner plot is the lower triangle of an N x N grid of subplots. On the diagonal, the histogram of the marginal posterior over the i-th parameter is plotted. Below the diagonal, the (i, j)-th panel is a scatter plot of parameter i against parameter j, which shows the covariance between the two quantities.
- Parameters:
nbins (int) – The number of bins to use for the histograms on the diagonal of the corner plot
- Returns:
fig – The figure object that can be manipulated after creation (e.g. to save to disk)
- Return type:
matplotlib.pyplot.figure
- plot_convergence()
Helper routine to visualise the convergence of the inversion. Convergence is achieved when the particles settle in an equilibrium position (i.e., they stop moving).
- Returns:
fig – The figure object that can be manipulated after creation (e.g. to save to disk)
- Return type:
matplotlib.pyplot.figure
- sample(solver, y0: Variables, t: Float[Array, 'Nt'], friction_constants: RSFConstants, block_constants: SpringBlockConstants | InertialSpringBlockConstants, nsamples: int = 10, rng: None | int | Array = None)
Draw random samples from the posterior distribution. Each particle represents one realisation of a plausible friction curve. When nsamples realisations are requested, it is more efficient to vmap the forward simulations than to compute them one by one. The resulting friction curves can be used for validation purposes (does the posterior match the data?).
Examples
>>> result = solver.bayesian_inversion(...)
>>> samples, sample_results = result.sample(
...     solver=solver, y0=y0, t=t,
...     friction_constants=constants,
...     block_constants=block_constants,
...     nsamples=20, rng=42)
- Parameters:
solver (ODESolver) – The ODESolver class instantiated with the Forward model (typically the same one as used for the inversion)
y0 (Variables) – The initial values of friction and state
t (Array) – The time samples at which a solution is requested. This does not need to be the same as the time samples recorded in the experiment (from which the observed friction curve is obtained)
friction_constants (_Constants) – The friction and state model constants
block_constants (_BlockConstants) – The stress transfer constants
nsamples (int) – The number of random samples for which a solution should be computed
rng (None, int, jax.random.PRNGKey) – The random seed used to sample from the particle distribution. If None, the current time will be used as a seed, which leads to a different result for each realisation. When an integer is provided, a new jax.random.PRNGKey is generated from it.
- Returns:
samples (_Params) – The samples drawn from the posterior distribution
sample_results (Variables) – The forward simulation results corresponding with samples
- chains: Float[Array, 'Nsteps Nparticles Nparams']
The particle trajectories (to inspect convergence)
- final_state: Float[Array, 'Nparticles Nparams']
The final state of the particle swarm
- log_likelihood: Float[Array, 'Nsteps']
The log-likelihood evolution
- log_params: RSFParams | CNSParams
The particles representing the (log-valued) parameters, including all the iteration steps
- nan_count: Float[Array, 'Nsteps']
The number of NaNs encountered during each iteration step
- statistics: RSFStatistics | CNSStatistics | None = None
Pre-computed statistics of the parameters
- class diabayes.typedefs.CNSConstants(phi_c: jaxtyping.Float, Z: jaxtyping.Float)
Bases:
object
- Z: Float
- phi_c: Float
- class diabayes.typedefs.CNSParams(*args, **kwargs)
Bases:
Params
- classmethod from_array(x)
- classmethod generate(N: int, loc: Float[Array, 'M'], scale: Float[Array, 'M'], key: Array)
- to_array()
- Z: Float
- phi_c: Float
- class diabayes.typedefs.CNSStatistics(a: diabayes.typedefs.Statistics, b: diabayes.typedefs.Statistics, Dc: diabayes.typedefs.Statistics)
Bases:
ParamStatistics
- classmethod from_state(state)
- get(x: str) → Statistics
- get_param_names() → Tuple
- Dc: Statistics
- a: Statistics
- b: Statistics
- class diabayes.typedefs.InertialSpringBlockConstants(k: jaxtyping.Float, v_lp: jaxtyping.Float, M: jaxtyping.Float)
Bases:
object
- M: Float
- k: Float
- v_lp: Float
- class diabayes.typedefs.ParamStatistics
Bases:
Module
- get(x: str) → Statistics
- class diabayes.typedefs.RSFConstants(v0: jaxtyping.Float, mu0: jaxtyping.Float)
Bases:
object
- mu0: Float
- v0: Float
- class diabayes.typedefs.RSFParams(*args, **kwargs)
Bases:
Params
- classmethod from_array(x)
- classmethod generate(N: int, loc: Float[Array, 'M'], scale: Float[Array, 'M'], key: Array)
- to_array()
- Dc: Float | Float[Array, '...']
- a: Float | Float[Array, '...']
- b: Float | Float[Array, '...']
- class diabayes.typedefs.RSFStatistics(a: diabayes.typedefs.Statistics, b: diabayes.typedefs.Statistics, Dc: diabayes.typedefs.Statistics, cov: jaxtyping.Float[Array, 'N N'])
Bases:
ParamStatistics
- classmethod from_state(state)
- get(x: str) → Statistics
- get_param_names() → Tuple
- Dc: Statistics
- a: Statistics
- b: Statistics
- cov: Float[Array, 'N N']
- class diabayes.typedefs.SpringBlockConstants(k: jaxtyping.Float, v_lp: jaxtyping.Float)
Bases:
object
- k: Float
- v_lp: Float
- class diabayes.typedefs.StateDict(keys: tuple[str, ...], vals: Array)
Bases:
Module
A container class for state variables. A user would typically not interact with this class directly (only through Variables and Forward.set_initial_values).
Examples
>>> from diabayes.typedefs import StateDict
>>> import jax.numpy as jnp
>>> keys = ("x", "y")
>>> vals = jnp.array([1.0, -5.1])
>>> state_dict = StateDict(keys=keys, vals=vals)
- replace_values(**kwargs) → StateDict
Update the values of the state variables. Since immutability is required for JIT caching, this method returns a copy of the class with the updated values.
Examples
>>> from diabayes.typedefs import StateDict
>>> import jax.numpy as jnp
>>> state_dict = StateDict(keys=("x",), vals=jnp.asarray(1.0))
>>> state_dict = state_dict.replace_values(x=jnp.asarray(2.0))
- Parameters:
**kwargs (dict) – Key-value pairs to update. These can overwrite existing values or create entirely new ones.
- Returns:
state_dict – A copy of the current StateDict with updated values
- Return type:
StateDict
- keys: tuple[str, ...]
A tuple of key names (strings) corresponding with the state variable names
- vals: Array
An array of values corresponding (in order) with the variables defined by keys
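The copy-on-update behaviour of replace_values can be illustrated without JAX. MiniStateDict below is a hypothetical pure-Python stand-in (not part of diabayes) that uses a frozen dataclass in place of the Module base class and a tuple in place of the JAX array; it mirrors the documented semantics that keyword arguments may overwrite existing values or add new ones, while the original object is left untouched.

```python
from dataclasses import dataclass, replace
from typing import Tuple

@dataclass(frozen=True)
class MiniStateDict:
    """Hypothetical stand-in for StateDict: ordered keys plus matching values."""
    keys: Tuple[str, ...]
    vals: Tuple[float, ...]

    def replace_values(self, **kwargs) -> "MiniStateDict":
        # Start from the current key-value pairs (insertion order is preserved)
        mapping = dict(zip(self.keys, self.vals))
        # Overwrite existing entries or append new ones
        mapping.update(kwargs)
        # Return a new object; the frozen original cannot be modified in place
        return replace(self, keys=tuple(mapping), vals=tuple(mapping.values()))

state_dict = MiniStateDict(keys=("x",), vals=(1.0,))
updated = state_dict.replace_values(x=2.0, y=3.0)
```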
- class diabayes.typedefs.Statistics(mean, median, std, q5, q95)
Bases:
NamedTuple
- count(value, /)
Return number of occurrences of value.
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- mean: Float
Alias for field number 0
- median: Float
Alias for field number 1
- q5: Float
Alias for field number 3
- q95: Float
Alias for field number 4
- std: Float
Alias for field number 2
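For orientation, the five fields can be computed from a one-dimensional set of posterior samples (e.g. one column of final_state) with NumPy. This sketch assumes that q5 and q95 denote the 5th and 95th percentiles and that the population standard deviation is used; how diabayes actually computes them (for instance, whether on log-valued or exponentiated parameters) is not stated on this page. StatisticsSketch and summarise are hypothetical names.

```python
from typing import NamedTuple
import numpy as np

class StatisticsSketch(NamedTuple):
    """Same field order as the documented Statistics: mean, median, std, q5, q95."""
    mean: float
    median: float
    std: float
    q5: float
    q95: float

def summarise(samples):
    """Summary statistics of 1-D posterior samples (percentile definitions assumed)."""
    samples = np.asarray(samples)
    return StatisticsSketch(
        mean=float(np.mean(samples)),
        median=float(np.median(samples)),
        std=float(np.std(samples)),
        q5=float(np.percentile(samples, 5)),
        q95=float(np.percentile(samples, 95)),
    )

stats = summarise(np.arange(101.0))
```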
- class diabayes.typedefs.Variables(mu: Array, state: StateDict)
Bases:
Module
The main container class for variables (friction and state). It contains additional convenience routines for mapping between this class object and JAX arrays.
Users would typically not instantiate a Variables object directly; instead, it is created through the Forward.set_initial_values method and retrieved as Forward.variables.
- classmethod from_array(x: Float[Array, '...'], keys: tuple[str, ...]) → Variables
Import a JAX array to instantiate the class. The array x can either be an array of scalars (size n for n variables), or an array of time series of shape (n, t). The first element in the array (i.e., x[0]) is assumed to be the friction coefficient mu. The remaining elements are the state variables, matching the number and order of the keys tuple.
- set_values(**kwargs) → Variables
Set the values of friction and the state variables through key-value pair arguments. The state variables are provided individually as key-value pairs as well. Because this class is immutable, set_values returns a copy of the class with updated values.
Examples
>>> state_dict = StateDict(keys=("x", "y"), vals=jnp.array([0.1, 0.2]))
>>> variables = Variables(mu=jnp.asarray(0.6), state=state_dict)
>>> variables = variables.set_values(mu=..., x=..., y=...)
- to_array() → Float[Array, '...']
Convert the container values to a JAX array. The order of the output follows the order of StateDict.keys, with the first item being the friction coefficient. For n state variables, the output is an array of shape (1+n,) for scalars, and (t, 1+n) for time series.
Examples
>>> scalars = Variables(...)
>>> mu = scalars.to_array()[0]  # friction is first element
>>> timeseries = Variables(...)
>>> mu = timeseries.to_array()[:, 0]
>>> all_final_values = timeseries.to_array()[-1, :]
Of course, in this example one could simply use scalars.mu and timeseries.mu to extract mu directly.
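The layout documented for to_array (friction first, then the state variables in StateDict.keys order; shape (1+n,) for scalars and (t, 1+n) for time series) can be mimicked with NumPy by stacking along the last axis. The pack helper below is hypothetical and only illustrates that documented layout, not the diabayes implementation.

```python
import numpy as np

def pack(mu, state_vals):
    """Stack friction and state variables along the last axis, friction first."""
    return np.stack([mu, *state_vals], axis=-1)

# Scalars: one friction value and n = 2 state variables -> shape (1 + n,)
scalars = pack(np.asarray(0.6), [np.asarray(10.0), np.asarray(0.0)])

# Time series of length t = 5 -> shape (t, 1 + n)
t = 5
mu_ts = np.linspace(0.6, 0.61, t)
timeseries = pack(mu_ts, [np.full(t, 10.0), np.zeros(t)])
```

Indexing then matches the examples above: scalars[0] is friction, timeseries[:, 0] is the friction time series, and timeseries[-1, :] holds all final values.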
- mu: Array
The friction coefficient