JaxPlan
pyRDDLGym-jax (referred to as JaxPlan in the literature) is an extension of the pyRDDLGym ecosystem that leverages the JAX library to automatically build differentiable simulators for any RDDL problem and perform gradient-based planning. It supports several planning methods, including deep reactive policy networks, straight-line planning, and risk-aware planning, and provides tools for automatic hyper-parameter tuning.
Purpose
The open-loop planning problem can be succinctly described mathematically as

$$\max_{a_1, \dots, a_T} \; R(a_1, \dots, a_T) = \sum_{t=1}^{T} r(s_t, a_t),$$

where $s_{t+1} = f(s_t, a_t)$ denotes the state transition function and $r$ the reward function, which, for differentiable $f$ and $r$, makes the return $R$ a differentiable function of the action sequence. The actions can then be updated by backpropagating through the return as follows:

$$a_t \leftarrow a_t + \eta \, \nabla_{a_t} R(a_1, \dots, a_T),$$

for some learning rate $\eta > 0$.
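To make the idea concrete, the following JAX snippet (an illustrative sketch with a toy linear transition and quadratic reward, not JaxPlan's internal code) performs exactly this gradient ascent on an open-loop action sequence by differentiating through a simulated rollout:

import jax
import jax.numpy as jnp

def rollout_return(actions, s0=jnp.array(5.0)):
    # simulate the trajectory under toy dynamics f(s, a) = s + a
    # and accumulate the toy reward r(s, a) = -s^2
    def step(s, a):
        s_next = s + a
        return s_next, -s_next ** 2
    _, rewards = jax.lax.scan(step, s0, actions)
    return jnp.sum(rewards)

actions = jnp.zeros(10)              # a_1, ..., a_T initialized to zero
grad_fn = jax.grad(rollout_return)   # gradient of the return with respect to the actions
learning_rate = 0.1
for _ in range(100):
    actions = actions + learning_rate * grad_fn(actions)   # gradient ascent on the return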
JaxPlan leverages JAX automatic differentiation to compute the above gradients for any problem described in the RDDL language, and uses state-of-the-art gradient-based optimizers such as Adam to find an optimal sequence of actions. The planner is versatile and applies model relaxations when dealing with discrete domains, where the exact gradient would otherwise be impossible to compute (Gimelfarb et al., 2024).
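As a rough illustration of this kind of relaxation (the specific relaxations used by JaxPlan are described in Gimelfarb et al., 2024), a non-differentiable branch such as "if c > 0 then a else b" can be replaced by a smooth, sigmoid-weighted mixture of the two branches:

import jax

def relaxed_if(c, a, b, weight=10.0):
    # smooth stand-in for "a if c > 0 else b"; a larger weight sharpens the
    # approximation at the cost of less informative gradients near c = 0
    w = jax.nn.sigmoid(weight * c)
    return w * a + (1.0 - w) * b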
JaxPlan has been able to solve, or make substantial progress on, a number of problems where other planners have failed.
JaxPlan was part of the official evaluation system in the 2023 International Planning Competition (Taitler et al., 2024).
Examples
JaxPlan can be set up easily in any Python environment that has the pyRDDLGym and JAX frameworks preinstalled. Simply create a config file with the planner's hyper-parameters as described here, then run the following code:
import pyRDDLGym
from pyRDDLGym_jax.core.planner import (
    JaxStraightLinePlan, JaxBackpropPlanner, JaxOfflineController, load_config)
# set up the environment (note the vectorized option must be True)
env = pyRDDLGym.make("domain", "instance", vectorized=True)
# load the config file for the problem with hyper-parameters and set up the planner
planner_args, plan_args, train_args = load_config("/path/to/config.cfg")
plan = JaxStraightLinePlan(**plan_args)
planner = JaxBackpropPlanner(rddl=env.model, plan=plan, **planner_args)
controller = JaxOfflineController(planner, **train_args)
# evaluate the planner
controller.evaluate(env, episodes=1, verbose=True, render=True)
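The config file itself is a plain INI-style text file with sections for the model relaxation, the optimizer, and the training loop. The snippet below is only a rough sketch with illustrative values; the exact section and option names depend on the installed pyRDDLGym-jax version and are listed in the project documentation:

[Model]
logic='FuzzyLogic'
logic_kwargs={'weight': 20}

[Optimizer]
method='JaxStraightLinePlan'
method_kwargs={}
optimizer='adam'
optimizer_kwargs={'learning_rate': 0.001}

[Training]
key=42
epochs=5000
train_seconds=60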
JaxPlan is highly configurable and scalable, and can efficiently optimize problems with dozens or even hundreds of observation or action variables.
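The straight-line plan in the example above can also be swapped for the deep reactive policy network mentioned earlier, which yields a closed-loop controller. Assuming the JaxDeepReactivePolicy class and its topology argument (check the current API for the exact signature), the change is minimal:

from pyRDDLGym_jax.core.planner import JaxDeepReactivePolicy

# replace the straight-line plan with a neural policy with two hidden layers
plan = JaxDeepReactivePolicy(topology=[128, 64])
planner = JaxBackpropPlanner(rddl=env.model, plan=plan, **planner_args)
controller = JaxOfflineController(planner, **train_args)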