The idea of training reinforcement learning agents inside a neural network world model was introduced in 2018, and the principle was soon adopted by a series of world-model-based agents. One prominent example is the Dreamer framework, which learns behaviors from the latent space of a recurrent state space model. DreamerV2 showed that using discrete latents might reduce compounding errors, and DreamerV3 achieved human-level performance on a wide range of tasks across different domains with fixed hyperparameters.
Furthermore, parallels can be drawn between image generation models and world models, suggesting that progress in generative vision models could be replicated to benefit world models. Following the success of transformers in natural language processing, frameworks such as DALL-E and VQGAN emerged. These frameworks used discrete autoencoders to convert images into discrete tokens. By leveraging the sequence modeling abilities of autoregressive transformers, they produced powerful and efficient text-to-image generative models. In parallel, diffusion models gained traction, and today they are a dominant paradigm for high-resolution image generation. Given the capabilities of both diffusion models and reinforcement learning, attempts are being made to combine the two. These efforts take advantage of the flexibility of diffusion models as policies, planners, reward models, and trajectory models for data augmentation in offline reinforcement learning.
World models offer a promising method for training reinforcement learning agents safely and efficiently. Traditionally, these models use sequences of discrete latent variables to simulate environment dynamics. However, this compression can overlook visual details that are crucial for reinforcement learning. At the same time, diffusion models have become a dominant approach for image generation, challenging established methods that model discrete latents. Inspired by this shift, in this article we will talk about DIAMOND (DIffusion As a Model Of eNvironment Dreams), a reinforcement learning agent trained within a diffusion world model. We will explore the design choices needed to make diffusion suitable for world modeling and show that improved visual detail leads to better agent performance. DIAMOND sets a new state of the art on the competitive Atari 100k benchmark, achieving a mean human normalized score of 1.46, the highest for agents trained entirely within a world model.
World models, or generative models of environments, are emerging as an important component that lets generalist agents plan and reason about their environments. Reinforcement learning has achieved considerable success in recent years, but reinforcement learning methods are known to be sample inefficient, which significantly limits their real-world applications. World models, on the other hand, have been shown to train reinforcement learning agents efficiently across diverse environments. Their significantly improved sample efficiency allows agents to learn from far less real-world experience. Recent world modeling frameworks usually model environment dynamics as a sequence of discrete latent variables, discretizing the latent space to avoid compounding errors over multi-step time horizons. This approach can deliver strong results, but it also loses information, which reduces reconstruction quality and generality. The loss of information can become a significant roadblock in real-world scenarios that require well-defined visual information, such as training autonomous vehicles. In such tasks, small details in the visual input, like the color of a traffic light or the turn indicator of the vehicle in front, can change an agent's policy. Increasing the number of discrete latents can reduce information loss, but it significantly increases the computational cost.
Furthermore, in recent years, diffusion models have emerged as the dominant approach for high-quality image generation. Models built on diffusion learn to reverse a noising process and compete directly with well-established approaches that model discrete tokens. They therefore offer a promising alternative that removes the need for discretization in world modeling. Diffusion models are known to be easily conditioned and to flexibly model complex, multimodal distributions without mode collapse. These attributes are crucial for world modeling. Conditioning enables a world model to accurately reflect an agent's actions, leading to more reliable credit assignment. Modeling multimodal distributions offers the agent a greater diversity of training scenarios, enhancing its overall performance.
Building upon these characteristics, DIAMOND (DIffusion As a Model Of eNvironment Dreams) is a reinforcement learning agent trained within a diffusion world model. The DIAMOND framework makes careful design choices to ensure its diffusion world model remains efficient and stable over long time horizons. It also provides a qualitative analysis to demonstrate the importance of these choices. DIAMOND sets a new state of the art with a mean human normalized score of 1.46 on the well-established Atari 100k benchmark, the highest for agents trained entirely within a world model. Operating in image space allows DIAMOND's diffusion world model to seamlessly substitute for the environment, offering greater insight into world model and agent behaviors. Notably, the improved performance in certain games is attributed to better modeling of critical visual details.

The DIAMOND framework models the environment as a standard POMDP, or Partially Observable Markov Decision Process, with a set of states, a set of discrete actions, and a set of image observations. The transition function describes the environment dynamics, and the reward function maps transitions to scalar rewards. The observation function describes the observation probabilities and emits image observations. The agent uses these observations to perceive the environment, since it cannot directly access the states. The goal is to obtain a policy that maps observations to actions so as to maximize the expected discounted return, given a discount factor. World models are generative models of the environment that can be used to create simulated environments in which reinforcement learning agents are trained. Data is collected in the real environment, the world model is trained on this data, and the agent is then trained inside the world model. Figure 1 shows how DIAMOND's imagination unrolls over time.
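The objective in the paragraph above is the expected discounted return. The following is a minimal sketch that assumes a finite episode of scalar rewards and a discount factor gamma between 0 and 1. It computes the return of a single trajectory by working backwards through the rewards. The function name `discounted_return` is hypothetical and is not taken from DIAMOND's code.

```python
def discounted_return(rewards, gamma):
    """Return sum_t gamma**t * r_t for one finite trajectory of rewards."""
    if not 0.0 <= gamma <= 1.0:
        raise ValueError("gamma must lie in [0, 1]")
    total = 0.0
    # Accumulate from the last reward backwards: G_t = r_t + gamma * G_{t+1}.
    for reward in reversed(rewards):
        total = reward + gamma * total
    return total


# Three unit rewards with gamma = 0.5: 1 + 0.5 + 0.25.
print(discounted_return([1.0, 1.0, 1.0], 0.5))  # prints 1.75
```

Averaging this quantity over many trajectories gives a Monte Carlo estimate of the expected discounted return that the policy is trained to maximize.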
DIAMOND: Methodology and Architecture
At their core, diffusion models are a class of generative models that generate a sample by reversing a noising process, and they draw heavy inspiration from non-equilibrium thermodynamics. The DIAMOND framework considers a diffusion process indexed by a continuous time variable, with corresponding marginals and two boundary conditions. At time zero the marginal is the data distribution, and at the final time it is a tractable, unstructured prior distribution such as a Gaussian. To obtain a generative model that maps noise to data, the framework must reverse this process. The reversed process is also a diffusion process, running backwards in time. However, the reverse process depends on the score function at each point in time, which is not trivial to estimate because the true score function is unknown. DIAMOND overcomes this hurdle with a score matching objective, which allows a score model to be trained without knowing the underlying score function.

The score-based diffusion model provides an unconditional generative model. A world model, however, requires a conditional generative model of environment dynamics. DIAMOND therefore considers the general POMDP case, in which past observations and actions are used to approximate the unknown Markovian state. As shown in Figure 1, DIAMOND uses this history to condition a diffusion model that estimates and generates the next observation directly. In theory, DIAMOND can use any SDE or ODE solver. In practice, there is a trade-off between the NFE, or Number of Function Evaluations, and sample quality, and this trade-off significantly impacts the inference cost of diffusion models.
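The score-based formulation described above can be written compactly. The block below uses the standard notation of SDE-based score modeling, which is an assumption about notation rather than a transcription of DIAMOND's equations:

- x^τ is the sample at diffusion time τ ∈ [0, T].
- f and g are the drift and diffusion coefficients.
- w is a Wiener process, and w̄ is its reverse-time counterpart.
- S_θ is the learned score model.
- The subscripts t + 1 and ≤ t index environment time, as opposed to diffusion time.

```latex
% Forward (noising) process, with p^0 = p_{\text{data}} and p^T = p_{\text{prior}}
\mathrm{d}\mathbf{x} = \mathbf{f}(\mathbf{x}, \tau)\,\mathrm{d}\tau + g(\tau)\,\mathrm{d}\mathbf{w}

% Reverse-time SDE: requires the score \nabla_{\mathbf{x}} \log p^{\tau}(\mathbf{x})
\mathrm{d}\mathbf{x} = \left[\mathbf{f}(\mathbf{x}, \tau) - g(\tau)^{2}\,\nabla_{\mathbf{x}} \log p^{\tau}(\mathbf{x})\right]\mathrm{d}\tau + g(\tau)\,\mathrm{d}\bar{\mathbf{w}}

% Probability-flow ODE sharing the marginals p^{\tau} (solvable with an ODE solver)
\mathrm{d}\mathbf{x} = \left[\mathbf{f}(\mathbf{x}, \tau) - \tfrac{1}{2}\,g(\tau)^{2}\,\nabla_{\mathbf{x}} \log p^{\tau}(\mathbf{x})\right]\mathrm{d}\tau

% Denoising score matching: regress onto the tractable score of the perturbation kernel p^{0\tau}
\mathcal{L}(\theta) = \mathbb{E}_{\tau,\, \mathbf{x}^{0},\, \mathbf{x}^{\tau}}\left[\left\lVert \mathbf{S}_{\theta}(\mathbf{x}^{\tau}, \tau) - \nabla_{\mathbf{x}^{\tau}} \log p^{0\tau}(\mathbf{x}^{\tau} \mid \mathbf{x}^{0}) \right\rVert^{2}\right]

% Conditional score model for a world model: condition on past observations and actions
\mathbf{S}_{\theta}\big(\mathbf{x}^{\tau}_{t+1}, \tau, \mathbf{x}^{0}_{\le t}, a_{\le t}\big) \approx \nabla_{\mathbf{x}_{t+1}} \log p^{\tau}\big(\mathbf{x}_{t+1} \mid \mathbf{x}^{0}_{\le t}, a_{\le t}\big)
```

Each solver step evaluates the score model at least once, so the number of solver steps directly sets the NFE and hence the inference cost.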
Building on the above, let us now look at the practical realization of DIAMOND's diffusion-based world model. This includes the drift and diffusion coefficients that correspond to a particular choice of diffusion approach. Instead of opting for DDPM, a natural candidate for the task, DIAMOND builds on the EDM formulation. It considers a Gaussian perturbation kernel whose standard deviation is a real-valued function of diffusion time, called the noise schedule. The framework selects preconditioners that keep the network's input and output at unit variance for any noise level. Training adaptively mixes signal and noise depending on the degradation level. When the noise is low, the target becomes the difference between the clean and the perturbed signal, i.e. the added Gaussian noise. Intuitively, this prevents the training objective from becoming trivial in the low-noise regime. In practice, this objective has high variance at the extremes of the noise schedule. The model therefore samples the noise level from an empirically chosen log-normal distribution in order to concentrate training around medium noise levels.

DIAMOND uses a standard 2D U-Net for the vector field and keeps a buffer of past observations and actions that it uses for conditioning. It concatenates these past observations with the next noisy observation along the channel dimension. The actions are fed in through adaptive group normalization layers in the residual blocks of the U-Net.
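The EDM preconditioning described above can be sketched in a few lines. This is a minimal illustration that assumes the preconditioner formulas and defaults from the EDM paper by Karras et al.: σ_data = 0.5, and log-normal noise with P_mean = -1.2 and P_std = 1.2. DIAMOND's own hyperparameters may differ, and all function names here are hypothetical. The full denoiser is D(y; σ) = c_skip·y + c_out·F(c_in·y; c_noise), so the raw network F is trained to regress the target computed below.

```python
import math
import random

SIGMA_DATA = 0.5  # EDM default data standard deviation (assumed; DIAMOND may differ)


def edm_preconditioners(sigma, sigma_data=SIGMA_DATA):
    """Return (c_in, c_skip, c_out, c_noise) for a noise level sigma > 0."""
    denom = sigma ** 2 + sigma_data ** 2
    c_in = 1.0 / math.sqrt(denom)                  # scales the noisy input to unit variance
    c_skip = sigma_data ** 2 / denom               # weight of the skip connection
    c_out = sigma * sigma_data / math.sqrt(denom)  # scales the network output
    c_noise = 0.25 * math.log(sigma)               # noise-level conditioning fed to the network
    return c_in, c_skip, c_out, c_noise


def network_target(clean, noisy, sigma, sigma_data=SIGMA_DATA):
    """Regression target for the raw network F, elementwise over a flat list of pixels."""
    _, c_skip, c_out, _ = edm_preconditioners(sigma, sigma_data)
    return [(x - c_skip * y) / c_out for x, y in zip(clean, noisy)]


def sample_training_sigma(rng, p_mean=-1.2, p_std=1.2):
    """Log-normal noise level: ln(sigma) ~ N(p_mean, p_std**2) (EDM defaults, assumed)."""
    return math.exp(rng.gauss(p_mean, p_std))


if __name__ == "__main__":
    rng = random.Random(0)
    clean = [0.3, -0.2]
    noise = [1.0, -0.5]
    for sigma in (1e-4, 1e4):
        noisy = [x + sigma * e for x, e in zip(clean, noise)]
        print(sigma, network_target(clean, noisy, sigma))
    print(sample_training_sigma(rng))
```

The two extremes behave differently:

- With a tiny σ, the target approaches the negated noise, i.e. the clean-minus-noisy difference rescaled by 1/σ.
- With a huge σ, it approaches clean / σ_data, so the network effectively predicts the clean signal.

This is the adaptive mixing of signal and noise that the paragraph above describes.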
DIAMOND: Experiments and Results
For a comprehensive evaluation, DIAMOND is tested on the Atari 100k benchmark, which consists of 26 games designed to test a wide range of agent capabilities. In each game, an agent is limited to 100k actions in the environment, roughly equivalent to 2 hours of human gameplay, to learn the game before evaluation. For comparison, unconstrained Atari agents typically train for 50 million steps, a 500-fold increase in experience. DIAMOND was trained from scratch with 5 random seeds for each game. Each training run required around 12GB of VRAM and took approximately 2.9 days on a single Nvidia RTX 4090, amounting to 1.03 GPU years in total. The following table provides the score for all games, the mean, and the IQM, or interquartile mean, of human-normalized scores.
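The budget figures above can be checked with a little arithmetic. To convert 100k actions into wall-clock gameplay time, the sketch below assumes the common Atari convention of a frame skip of 4 at 60 frames per second. The paragraph itself does not state this convention.

```python
GAMES = 26
SEEDS = 5
DAYS_PER_RUN = 2.9
ACTIONS = 100_000
UNCONSTRAINED_STEPS = 50_000_000
FRAME_SKIP = 4  # assumed: each agent action is repeated for 4 emulator frames
FPS = 60        # assumed: Atari emulator frame rate

total_gpu_days = GAMES * SEEDS * DAYS_PER_RUN
gpu_years = total_gpu_days / 365.25
gameplay_hours = ACTIONS * FRAME_SKIP / FPS / 3600
experience_ratio = UNCONSTRAINED_STEPS / ACTIONS

print(f"{total_gpu_days:.0f} GPU days = {gpu_years:.2f} GPU years")
print(f"{gameplay_hours:.2f} hours of gameplay")
print(f"{experience_ratio:.0f}x more experience for unconstrained agents")
```

Under these assumptions, 130 runs of 2.9 days each come to roughly 377 GPU days (about 1.03 GPU years). 100k actions correspond to a little under 2 hours of play.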
Because point estimates have limitations, DIAMOND also reports stratified bootstrap confidence intervals for the mean and the IQM of human-normalized scores, along with performance profiles and additional metrics, as summarized in the following figure.
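As a rough illustration of these metrics, the sketch below computes three quantities:

- human-normalized scores;
- the interquartile mean, i.e. the mean of the middle 50% of all run-level scores;
- a percentile confidence interval from a stratified bootstrap that resamples runs independently within each game.

It follows the general recipe popularized by the rliable library, but it is a hypothetical, simplified reimplementation. The toy scores are placeholders for illustration, not DIAMOND's results.

```python
import random
import statistics


def human_normalized(score, random_score, human_score):
    """HNS = (agent - random) / (human - random); 1.0 means human level."""
    return (score - random_score) / (human_score - random_score)


def iqm(values):
    """Interquartile mean: discard the lowest and highest 25% and average the rest."""
    xs = sorted(values)
    k = int(0.25 * len(xs))
    return statistics.fmean(xs[k:len(xs) - k])


def stratified_bootstrap_ci(scores_per_game, stat, reps=2000, alpha=0.05, seed=0):
    """Percentile CI for stat over all runs, resampling runs within each game."""
    rng = random.Random(seed)
    estimates = []
    for _ in range(reps):
        resampled = []
        for runs in scores_per_game.values():
            # Stratified: draw as many runs as the game has, with replacement.
            resampled.extend(rng.choice(runs) for _ in runs)
        estimates.append(stat(resampled))
    estimates.sort()
    lower = estimates[int(alpha / 2 * reps)]
    upper = estimates[int((1 - alpha / 2) * reps) - 1]
    return lower, upper


# Toy human-normalized scores for two hypothetical games with 5 seeds each.
toy = {"game_a": [0.2, 0.4, 0.3, 0.5, 0.1], "game_b": [1.5, 2.0, 1.8, 1.2, 2.2]}
all_runs = [s for runs in toy.values() for s in runs]
print("mean:", statistics.fmean(all_runs), "IQM:", iqm(all_runs))
print("95% CI of IQM:", stratified_bootstrap_ci(toy, iqm))
```

Resampling within each game, rather than across the pooled runs, keeps every game represented in every bootstrap replicate. This is the point of stratification.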
The results show that DIAMOND performs exceptionally well across the benchmark. It surpasses human players in 11 games and achieves a superhuman mean HNS of 1.46, a new record for agents trained entirely within a world model. Additionally, DIAMOND's IQM is comparable to that of STORM and exceeds that of all other baselines. DIAMOND excels in environments where capturing small details is crucial, such as Asterix, Breakout, and RoadRunner.

As discussed earlier, DIAMOND could in principle use any diffusion model in its pipeline. It opts for EDM, although DDPM would have been a natural choice given its use in numerous image generation applications. To compare EDM with DDPM, both variants are trained with the same network architecture on the same shared static dataset of over 100k frames collected with an expert policy. The number of denoising steps directly determines the inference cost of the world model, so fewer steps reduce the cost of training an agent on imagined trajectories. The authors want the world model to remain computationally comparable with other baselines, such as IRIS, which requires 16 NFE per timestep. They therefore aim to use no more than tens of denoising steps, preferably fewer. However, setting the number of denoising steps too low can degrade visual quality, leading to compounding errors. To assess the stability of different diffusion variants, the following figure displays imagined trajectories generated autoregressively up to t = 1000 timesteps, using different numbers of denoising steps n ≤ 10.
Using DDPM (a) in this regime results in severe compounding errors, causing the world model to quickly drift out of distribution. In contrast, the EDM-based diffusion world model (b) remains much more stable over long time horizons, even with a single denoising step. The figure shows imagined trajectories from diffusion world models based on DDPM (left) and EDM (right). The initial observation at t = 0 is the same for both, and each row corresponds to a decreasing number of denoising steps n. With DDPM, fewer denoising steps lead to faster error accumulation.

However, a single step is not always sufficient. The optimal single-step prediction is the expectation over possible reconstructions for a given noisy input, which can be out of distribution if the posterior distribution is multimodal. Some games, like Breakout, have deterministic transitions that can be accurately modeled with a single denoising step. Other games exhibit partial observability, resulting in multimodal observation distributions. In these cases, an iterative solver is necessary to guide the sampling procedure towards a specific mode, as illustrated in the game Boxing in the following figure. Consequently, DIAMOND sets n = 3 in all experiments.
The above figure compares single-step (top row) and multi-step (bottom row) sampling in Boxing. The movements of the black player are unpredictable, causing single-step denoising to interpolate between possible outcomes, resulting in blurry predictions. In contrast, multi-step sampling produces a clear image by guiding the generation towards a specific mode. Interestingly, since the policy controls the white player, his actions are known to the world model, eliminating ambiguity. Thus, both single-step and multi-step sampling correctly predict the white player’s position.
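The single-step versus multi-step behavior can be reproduced on a toy problem. The sketch below assumes a one-dimensional "observation" whose data distribution is two equally likely points at -1 and +1. For this distribution, the ideal denoiser has the closed form D(x; σ) = tanh(x / σ²).

The sketch integrates the EDM probability-flow ODE with Euler steps on the noise schedule of Karras et al. (σ_max = 80, σ_min = 0.002, ρ = 7). These values are EDM defaults assumed here, not DIAMOND's settings, and the function names are hypothetical.

With one step, the sample lands on the posterior mean, between the two modes. This is the 1-D analogue of a blurry frame. With several steps, the sample commits to one mode.

```python
import math


def ideal_denoiser(x, sigma):
    """E[x0 | x] when x0 is +1 or -1 with equal probability and x = x0 + sigma * noise."""
    return math.tanh(x / sigma ** 2)


def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """n decreasing noise levels from sigma_max to sigma_min, followed by a final 0."""
    if n == 1:
        return [sigma_max, 0.0]
    lo, hi = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    sigmas = [(hi + i / (n - 1) * (lo - hi)) ** rho for i in range(n)]
    return sigmas + [0.0]


def euler_sample(x, n_steps, denoiser=ideal_denoiser):
    """Euler integration of dx/dsigma = (x - D(x; sigma)) / sigma, one NFE per step."""
    sigmas = karras_sigmas(n_steps)
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        slope = (x - denoiser(x, s_cur)) / s_cur
        x = x + (s_next - s_cur) * slope
    return x


x_init = 40.0  # fixed starting point at noise level sigma_max, for reproducibility
for n in (1, 3, 10):
    print(n, euler_sample(x_init, n))
```

In this toy, the single-step result sits near 0, halfway between the modes, while three or ten steps land on +1. The extra function evaluations buy a sharp sample at the cost of inference time. This mirrors the Boxing example, but a toy like this only illustrates the mechanism and is not evidence for any particular choice of n.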
In the above figure, the trajectories imagined by DIAMOND generally exhibit higher visual quality and are more faithful to the true environment than those imagined by IRIS. The trajectories generated by IRIS contain visual inconsistencies between frames (highlighted by white boxes), such as enemies being displayed as rewards and vice versa. These inconsistencies appear only in trajectories generated with IRIS. They may affect only a few pixels, but they can significantly impact reinforcement learning. For instance, an agent typically aims to target rewards and avoid enemies, so these small visual discrepancies can make it more challenging to learn an optimal policy. The figure shows consecutive frames imagined with IRIS (left) and DIAMOND (right):

- In Asterix (top row), an enemy (orange) becomes a reward (red) in the second frame, reverts to an enemy in the third, and becomes a reward again in the fourth.
- In Breakout (middle row), the bricks and score are inconsistent between frames.
- In Road Runner (bottom row), the rewards (small blue dots on the road) are rendered inconsistently between frames.

None of these inconsistencies occur with DIAMOND. In Breakout, for example, the score is reliably updated by +7 when a red brick is broken.
Conclusion
In this article, we have talked about DIAMOND, a reinforcement learning agent trained within a diffusion world model. DIAMOND makes careful design choices to ensure its diffusion world model remains efficient and stable over long time horizons, and it provides a qualitative analysis demonstrating the importance of these choices. It sets a new state of the art with a mean human normalized score of 1.46 on the well-established Atari 100k benchmark, the highest for agents trained entirely within a world model. Operating in image space allows DIAMOND's diffusion world model to seamlessly substitute for the environment, offering greater insight into world model and agent behaviors. Its improved performance in certain games is attributed to better modeling of critical visual details.