diabayes.typedefs.StateDict
- class diabayes.typedefs.StateDict(keys: tuple[str, ...], vals: Array)
A container class for state variables. Users would typically not interact with this class directly, but only through Variables and Forward.set_initial_values.
Examples
>>> from diabayes.typedefs import StateDict
>>> import jax.numpy as jnp
>>> keys = ("x", "y")
>>> vals = jnp.array([1.0, -5.1])
>>> state_dict = StateDict(keys=keys, vals=vals)
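To illustrate the container idea without depending on diabayes or JAX, the sketch below uses a hypothetical ToyStateDict built on a frozen dataclass, with plain Python floats in a tuple standing in for the JAX array. It is an analogy under those assumptions, not the StateDict implementation: it only shows names and values stored side by side, and that attribute assignment on an immutable instance is rejected.

```python
from dataclasses import dataclass, FrozenInstanceError


@dataclass(frozen=True)
class ToyStateDict:
    """Hypothetical stand-in for StateDict (not the diabayes implementation)."""

    keys: tuple[str, ...]  # variable names
    vals: tuple[float, ...]  # values, in the same order as keys


state = ToyStateDict(keys=("x", "y"), vals=(1.0, -5.1))

# frozen=True makes attribute assignment raise, mimicking an immutable container.
try:
    state.vals = (2.0, -5.1)
except FrozenInstanceError:
    print("state is immutable")

# frozen=True (with the default eq=True) also makes the instance hashable.
print(hash(state) == hash(ToyStateDict(keys=("x", "y"), vals=(1.0, -5.1))))  # True
```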
- __init__(keys: tuple[str, ...], vals: Array) -> None
Methods
- __init__(keys, vals)
- replace_values(**kwargs) – Update the values of the state variables.
Attributes
- keys – A tuple of key names (strings) corresponding to the state variable names.
- vals – An array of values corresponding (in order) to the variables defined by keys.
- replace_values(**kwargs) -> StateDict
Update the values of the state variables. Because instances must be immutable for JIT caching, this method does not modify the object in place; instead, it returns a copy of the instance with the updated values.
Examples
>>> from diabayes.typedefs import StateDict
>>> import jax.numpy as jnp
>>> state_dict = StateDict(keys=("x",), vals=jnp.asarray(1.0))
>>> state_dict = state_dict.replace_values(x=jnp.asarray(2.0))
- Parameters:
**kwargs (dict) – Key-value pairs to update. These can overwrite existing values or create entirely new ones.
- Returns:
state_dict – A copy of the current StateDict with the updated values.
- Return type:
StateDict
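The copy-on-update behaviour described for replace_values can be sketched as follows. ToyStateDict and its replace_values method are hypothetical, pure-Python stand-ins (tuples instead of a JAX array). The sketch also assumes that new keys are appended after the existing ones, which the documentation here does not specify.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyStateDict:
    """Hypothetical stand-in for StateDict (not the diabayes implementation)."""

    keys: tuple[str, ...]
    vals: tuple[float, ...]

    def replace_values(self, **kwargs: float) -> "ToyStateDict":
        # Start from the current name/value pairs (dicts preserve insertion order).
        merged = dict(zip(self.keys, self.vals))
        # Overwrite existing names; unknown names are appended at the end (assumption).
        merged.update(kwargs)
        # Build and return a new instance; self is left untouched.
        return ToyStateDict(keys=tuple(merged), vals=tuple(merged.values()))


old = ToyStateDict(keys=("x",), vals=(1.0,))
new = old.replace_values(x=2.0, y=3.0)
print(old)  # ToyStateDict(keys=('x',), vals=(1.0,))
print(new)  # ToyStateDict(keys=('x', 'y'), vals=(2.0, 3.0))
```

Note that the keyword-argument call style (replace_values(x=...)) follows from the **kwargs signature; passing a dict positionally would raise a TypeError.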
- keys: tuple[str, ...]
A tuple of key names (strings) corresponding to the state variable names.
- vals: Array
An array of values corresponding (in order) to the variables defined by keys.
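Because vals is ordered to match keys, the i-th value belongs to the i-th name. The hypothetical helper as_mapping below (not part of diabayes) pairs them up by position and rejects mismatched lengths; it works on plain tuples so that it runs without JAX.

```python
def as_mapping(keys: tuple[str, ...], vals: tuple[float, ...]) -> dict[str, float]:
    """Pair each name with the value at the same position (hypothetical helper)."""
    # A length mismatch would silently drop entries in zip(), so check explicitly.
    if len(keys) != len(vals):
        raise ValueError(f"got {len(keys)} keys but {len(vals)} values")
    return dict(zip(keys, vals))


print(as_mapping(("x", "y"), (1.0, -5.1)))  # {'x': 1.0, 'y': -5.1}
```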