diabayes.typedefs.Variables

class diabayes.typedefs.Variables(mu: Array, state: StateDict)

The main container class for the model variables (friction and state). It also provides convenience routines for converting between this class and JAX arrays.

Users would typically not instantiate a Variables object directly; instead, it is created through the Forward.set_initial_values method and retrieved as Forward.variables.

__init__(mu: Array, state: StateDict) -> None

Methods

__init__(mu, state)

from_array(x, keys)

Import a JAX array to instantiate the class.

set_values(**kwargs)

Set the values of friction and the state variables through key-value pair arguments.

to_array()

Convert the container values to a JAX array.

Attributes

mu

The friction coefficient

state

A container with the various state variables

classmethod from_array(x: Float[Array, '...'], keys: tuple[str, ...]) -> Variables

Import a JAX array to instantiate the class. The array x can be either a 1-D array of scalars of size n, or an array of time series of shape (n, t), where n = 1 + len(keys). The first element (x[0]) is taken to be the friction coefficient mu; the remaining elements are the state variables, in the same number and order as the keys tuple.
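The expected layout can be illustrated without diabayes itself. The sketch below uses a hypothetical helper, split_variables, which is not part of the library and is not how from_array is implemented; it only shows how the leading row is read as mu and how the remaining rows are paired with keys, for both the scalar (n,) and the time-series (n, t) case.

```python
import jax.numpy as jnp


def split_variables(x, keys):
    """Hypothetical helper (not part of diabayes): split an (n,) or (n, t)
    array into the friction coefficient and a dict of state variables."""
    if x.shape[0] != 1 + len(keys):
        raise ValueError(f"expected {1 + len(keys)} rows, got {x.shape[0]}")
    mu = x[0]  # first row is the friction coefficient
    # Remaining rows map onto the state keys, in the same order
    state = {key: x[i + 1] for i, key in enumerate(keys)}
    return mu, state


# Scalars: size n = 1 + len(keys)
mu, state = split_variables(jnp.array([0.6, 0.1, 0.2]), ("x", "y"))

# Time series: shape (n, t), here with t = 4 samples
series = jnp.stack([jnp.full(4, 0.6), jnp.linspace(0.1, 0.4, 4), jnp.zeros(4)])
mu_t, state_t = split_variables(series, ("x", "y"))
```

With time series, each entry (mu_t, state_t["x"], ...) keeps its full length-t history.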

set_values(**kwargs) -> Variables

Set the values of friction and the state variables through keyword arguments: friction is passed as mu, and each state variable is passed individually under its own key. Because this class is immutable, set_values returns an updated copy rather than modifying the instance in place.
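The copy-on-update behaviour means the original object is left untouched. The sketch below mimics that pattern with a frozen dataclass; ToyVariables and its set_values are hypothetical stand-ins, not the diabayes implementation, which may rely on a different mechanism. The rejection of unknown keys is likewise an assumption made for illustration.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyVariables:
    """Hypothetical immutable container (not the diabayes class)."""

    mu: jnp.ndarray
    state: dict

    def set_values(self, **kwargs):
        # Friction is set through the "mu" keyword ...
        mu = kwargs.pop("mu", self.mu)
        # ... and every other keyword must name an existing state variable
        unknown = set(kwargs) - set(self.state)
        if unknown:
            raise KeyError(f"unknown state variables: {sorted(unknown)}")
        new_state = {**self.state, **kwargs}
        # Return a new instance; self is never modified
        return dataclasses.replace(self, mu=mu, state=new_state)


old = ToyVariables(mu=jnp.asarray(0.6), state={"x": jnp.asarray(0.1), "y": jnp.asarray(0.2)})
new = old.set_values(mu=jnp.asarray(0.65), y=jnp.asarray(0.25))
```

Here new carries the updated mu and y, while old still holds the original values, and x is carried over unchanged.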

Examples

>>> state_dict = StateDict(keys=("x", "y"), vals=jnp.array([0.1, 0.2]))
>>> variables = Variables(mu=jnp.asarray(0.6), state=state_dict)
>>> variables = variables.set_values(mu=..., x=..., y=...)

to_array() -> Float[Array, '...']

Convert the container values to a JAX array. The first item is the friction coefficient, followed by the state variables in the order of StateDict.keys. For n state variables, the output has shape (1+n,) for scalars and (t, 1+n) for time series of length t.
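The documented output layout can be reproduced with a single jnp.stack along the last axis. The function stack_variables below is a hypothetical sketch, not the diabayes code; it only shows why scalars yield shape (1+n,) while time series yield (t, 1+n), with mu always in column 0.

```python
import jax.numpy as jnp


def stack_variables(mu, state_vals):
    """Hypothetical sketch (not the diabayes code): stack friction and state
    values into the layout described for to_array.

    mu and each entry of state_vals are either scalars or 1-D time series
    of equal length t.
    """
    # Stacking along the last axis puts mu first, followed by the state
    # variables in the given order: (1+n,) for scalars, (t, 1+n) for series.
    return jnp.stack([mu, *state_vals], axis=-1)


scalars = stack_variables(jnp.asarray(0.6), [jnp.asarray(0.1), jnp.asarray(0.2)])

t = 5
series = stack_variables(jnp.full(t, 0.6), [jnp.linspace(0.1, 0.5, t), jnp.zeros(t)])
```

Under this layout, series[:, 0] recovers mu and series[-1, :] gives the final values, matching the indexing used in the Examples. Since from_array documents time-series input as shape (n, t), passing such an output back to from_array would, on that reading, require a transpose (series.T).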

Examples

>>> scalars = Variables(...)
>>> mu = scalars.to_array()[0]  # friction is first element
>>> timeseries = Variables(...)
>>> mu = timeseries.to_array()[:, 0]
>>> all_final_values = timeseries.to_array()[-1, :]

In this example, one could of course simply access scalars.mu and timeseries.mu to extract mu directly.

mu: Array

The friction coefficient

state: StateDict

A container with the various state variables