diabayes.typedefs.StateDict

class diabayes.typedefs.StateDict(keys: tuple[str, ...], vals: Array)

A container class for state variables. Users typically do not interact with this class directly, only indirectly through Variables and Forward.set_initial_values.

Examples

>>> from diabayes.typedefs import StateDict
>>> import jax.numpy as jnp
>>> keys = ("x", "y")
>>> vals = jnp.array([1.0, -5.1])
>>> state_dict = StateDict(keys=keys, vals=vals)

__init__(keys: tuple[str, ...], vals: Array) -> None

Methods

__init__(keys, vals)

replace_values(**kwargs)

Update the values of the state variables.

Attributes

keys

A tuple of key names (strings) corresponding to the state variable names

vals

An array of values corresponding (in order) to the variables defined by keys

replace_values(**kwargs) -> StateDict

Update the values of the state variables. Since immutability is required for JIT caching, this method does not modify the instance in place; it returns a copy of the instance with the updated values.

Examples

>>> from diabayes.typedefs import StateDict
>>> import jax.numpy as jnp
>>> state_dict = StateDict(keys=("x",), vals=jnp.asarray(1.0))
>>> state_dict = state_dict.replace_values(x=jnp.asarray(2.0))

Parameters:

**kwargs (dict) – Key-value pairs to update. These can overwrite existing values or create entirely new ones.

Returns:

state_dict – A copy of the current StateDict with updated values

Return type:

StateDict
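
The copy-on-update behaviour described above can be illustrated with a minimal sketch. FrozenState below is a hypothetical stand-in written for this illustration only; it is not the diabayes implementation, and it uses plain Python floats in place of a JAX array so that it runs with the standard library alone.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FrozenState:
    """Hypothetical stand-in for an immutable state container (not diabayes code)."""

    keys: tuple[str, ...]
    vals: tuple[float, ...]  # plain floats stand in for a JAX array

    def replace_values(self, **kwargs: float) -> "FrozenState":
        # Start from the current key/value pairs, then overwrite or append.
        merged = dict(zip(self.keys, self.vals))
        merged.update(kwargs)
        # Build a new instance; the original is left untouched.
        return replace(self, keys=tuple(merged), vals=tuple(merged.values()))


old = FrozenState(keys=("x", "y"), vals=(1.0, -5.1))
new = old.replace_values(x=2.0, z=0.5)  # overwrite "x", add "z"
```

Because the dataclass is frozen, assigning to old.vals raises an error, so the only way to obtain changed values is to rebind a name to the returned copy, as in the example above.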

keys: tuple[str, ...]

A tuple of key names (strings) corresponding to the state variable names

vals: Array

An array of values corresponding (in order) to the variables defined by keys
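
As a small illustration of the ordering convention, the value of a variable sits at the same position in vals as its name does in keys. The helper value_of below is hypothetical (it is not part of diabayes), and a plain tuple stands in for the array:

```python
def value_of(keys, vals, name):
    """Return the entry of vals stored at the same position as name in keys."""
    return vals[keys.index(name)]


keys = ("x", "y")
vals = (1.0, -5.1)  # a plain tuple standing in for the JAX array
y = value_of(keys, vals, "y")  # second key, so the second value
```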