diabayes.typedefs.Variables
- class diabayes.typedefs.Variables(mu: Array, state: StateDict)
The main container class for the model variables (friction and state). It also provides convenience routines for converting between a Variables object and JAX arrays.
Users would typically not instantiate a Variables object directly; instead, it is created through the Forward.set_initial_values method and retrieved as Forward.variables.
Methods
__init__(mu, state)
from_array(x, keys): Import a JAX array to instantiate the class.
set_values(**kwargs): Set the values of friction and the state variables through key-value pair arguments.
to_array(): Convert the container values to a JAX array.
Attributes
- classmethod from_array(x: Float[Array, '...'], keys: tuple[str, ...]) → Variables
Import a JAX array to instantiate the class. The array x can either be an array of scalars (size n for n variables), or an array of time series of shape (n, t). The first element in the array (i.e., x[0]) is assumed to be the friction coefficient mu. The remaining elements are the state variables, matching the number and order of the keys tuple.
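To illustrate the array layout that from_array expects, here is a minimal sketch under stated assumptions; it is not the diabayes implementation. The helper split_variables is hypothetical, and it returns a plain dict in place of StateDict. It treats x[0] as mu and maps the remaining rows onto keys, for both the scalar (n,) and time-series (n, t) cases.

```python
import jax.numpy as jnp


def split_variables(x, keys):
    """Hypothetical helper mirroring the documented from_array layout.

    x[0] is the friction coefficient mu; x[1:] are the state variables in
    the same order as ``keys``. Accepts shape (n,) or (n, t), where
    n = 1 + len(keys). A plain dict stands in for StateDict here.
    """
    x = jnp.asarray(x)
    if x.shape[0] != 1 + len(keys):
        raise ValueError(f"expected {1 + len(keys)} rows, got {x.shape[0]}")
    mu = x[0]
    state = {key: x[i + 1] for i, key in enumerate(keys)}
    return mu, state


# Scalars: shape (n,) with n = 3 (mu plus two state variables)
mu, state = split_variables(jnp.array([0.6, 0.1, 0.2]), ("x", "y"))

# Time series: shape (n, t) with t = 4
series = jnp.stack([jnp.full(4, 0.6), jnp.linspace(0.0, 1.0, 4), jnp.ones(4)])
mu_t, state_t = split_variables(series, ("x", "y"))
```

The length check reflects the documented requirement that the state entries match the number and order of the keys tuple.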
- set_values(**kwargs) → Variables
Set the values of friction and the state variables through key-value pair arguments. Each state variable is likewise passed individually as its own key-value pair. Because this class is immutable, set_values returns an updated copy of the object rather than modifying it in place.
Examples
>>> state_dict = StateDict(keys=("x", "y"), vals=jnp.array([0.1, 0.2]))
>>> variables = Variables(mu=jnp.asarray(0.6), state=state_dict)
>>> variables = variables.set_values(mu=..., x=..., y=...)
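The copy-on-update behaviour of set_values can be sketched with a frozen dataclass. This is an illustrative stand-in, not the diabayes class: ToyVariables is hypothetical, and it stores the state in a plain dict instead of a StateDict. The point is that the original object is left untouched and a new instance carries the updated values.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyVariables:
    """Hypothetical stand-in for Variables: mu plus a dict of state values."""

    mu: jnp.ndarray
    state: dict

    def set_values(self, **kwargs):
        # "mu" updates friction; every other keyword must name a state variable
        unknown = set(kwargs) - {"mu"} - set(self.state)
        if unknown:
            raise KeyError(f"unknown variables: {sorted(unknown)}")
        new_mu = jnp.asarray(kwargs.get("mu", self.mu))
        new_state = {k: jnp.asarray(kwargs.get(k, v)) for k, v in self.state.items()}
        # Build a new instance; the original object is not modified
        return dataclasses.replace(self, mu=new_mu, state=new_state)


old = ToyVariables(mu=jnp.asarray(0.6), state={"x": jnp.asarray(0.1), "y": jnp.asarray(0.2)})
new = old.set_values(mu=0.7, y=0.3)  # x keeps its previous value
```

Because the dataclass is frozen, assigning to old.mu raises dataclasses.FrozenInstanceError, so reassigning the result (as in the doctest above) is the only way to "update" the variables.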
- to_array() → Float[Array, '...']
Convert the container values to a JAX array. The friction coefficient comes first, followed by the state variables in the order of StateDict.keys. For n state variables, the output is an array of shape (1+n,) for scalars, and (t, 1+n) for time series.
Examples
>>> scalars = Variables(...)
>>> mu = scalars.to_array()[0]  # friction is first element
>>> timeseries = Variables(...)
>>> mu = timeseries.to_array()[:, 0]
>>> all_final_values = timeseries.to_array()[-1, :]
Of course, in this example one could simply use scalars.mu and timeseries.mu to extract mu directly.
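The output layout described for to_array can be reproduced with jax.numpy.stack along the last axis. This is a minimal sketch, not the library's code: stack_variables is a hypothetical helper that takes mu, a plain dict of state values (standing in for StateDict), and the key order. It yields shape (1+n,) for scalars and (t, 1+n) for time series, so the indexing in the examples above applies unchanged.

```python
import jax.numpy as jnp


def stack_variables(mu, state, keys):
    """Hypothetical helper mirroring the documented to_array layout.

    mu comes first, followed by the state variables in the order of ``keys``.
    Scalars give shape (1 + n,); time series of length t give (t, 1 + n).
    """
    columns = [jnp.asarray(mu)] + [jnp.asarray(state[k]) for k in keys]
    # Stacking along the last axis puts each variable in its own column,
    # so time series become (t, 1 + n) and scalars become (1 + n,)
    return jnp.stack(columns, axis=-1)


keys = ("x", "y")
scalars = stack_variables(0.6, {"x": 0.1, "y": 0.2}, keys)
timeseries = stack_variables(
    jnp.linspace(0.6, 0.7, 5), {"x": jnp.zeros(5), "y": jnp.ones(5)}, keys
)
mu = scalars[0]  # friction is the first element
mu_t = timeseries[:, 0]  # friction time series is the first column
final_values = timeseries[-1, :]  # all variables at the last time step
```

Stacking on the last axis (rather than the first) is what places time along the leading dimension for time series, matching the (t, 1+n) shape stated above.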
- mu: Array
The friction coefficient.