Comment on “World Models”

Over the last few days, World Models by David Ha and Jürgen Schmidhuber has been getting a lot of attention.

At its core, World Models is about building models of the environment from raw pixel observations and then using those models during training (i.e. learning in imagination). The motivation is to improve the sample efficiency of reinforcement learning methods, but also to enable new avenues of active exploration (e.g. artificial curiosity).

It is a great piece of research on many fronts. The field of model-based reinforcement learning is extensively reviewed. The implementation relies on a relatively small neural model (fewer than 5 million parameters). You will find an in-depth discussion of the strengths, weaknesses and potential extensions of the method. All of this is accompanied by neat diagrams and even a simulation running in real time. Go and read the interactive article or the PDF if you have not yet.

In this post, I share a comment on the training scheme used.

Training Vision and Memory modules separately

As outlined in the training procedure, the parts of the network are trained separately. First, a VAE is optimised to find a visual representation of frames from the environment. Then an MDN-RNN is trained to make predictions in this latent space. Finally, a controller uses both the visual representation and the memory state of the RNN to select actions.
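To make the separation concrete, here is a minimal PyTorch sketch of the staged scheme. Everything in it is my own simplification, not the authors' code: fully-connected stand-ins replace the convolutional VAE, a plain LSTM with an MSE loss replaces the MDN-RNN, and the CMA-ES training of the controller is omitted. The point is only that each stage optimises its own loss on frozen outputs of the previous one.

```python
import torch
import torch.nn as nn

LATENT, HIDDEN, ACTIONS = 32, 256, 3

class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        # fully-connected stand-ins for the convolutional encoder/decoder
        self.enc = nn.Sequential(nn.Flatten(), nn.Linear(64 * 64 * 3, 2 * LATENT))
        self.dec = nn.Sequential(nn.Linear(LATENT, 64 * 64 * 3), nn.Sigmoid())

    def forward(self, x):
        mu, logvar = self.enc(x).chunk(2, dim=-1)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()
        return self.dec(z).view_as(x), z, mu, logvar

vae = VAE()
rnn = nn.LSTM(LATENT + ACTIONS, HIDDEN, batch_first=True)  # stand-in for the MDN-RNN
to_latent = nn.Linear(HIDDEN, LATENT)                      # predicts the next z
controller = nn.Linear(LATENT + HIDDEN, ACTIONS)

frames = torch.rand(16, 10, 3, 64, 64)   # fake rollout: (batch, time, C, H, W)
actions = torch.rand(16, 10, ACTIONS)

# Stage 1: the VAE sees only a reconstruction + KL objective -- no prediction signal.
recon, _, mu, logvar = vae(frames.flatten(0, 1))
kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(-1).mean()
vae_loss = nn.functional.mse_loss(recon, frames.flatten(0, 1)) + 1e-4 * kl
vae_loss.backward()

# Stage 2: the memory model is trained on *frozen* latents -- its loss never
# reaches the encoder.
with torch.no_grad():
    _, z, _, _ = vae(frames.flatten(0, 1))
z = z.view(16, 10, LATENT)
h, _ = rnn(torch.cat([z[:, :-1], actions[:, :-1]], dim=-1))
pred_loss = nn.functional.mse_loss(to_latent(h), z[:, 1:])
pred_loss.backward()

# Stage 3: the controller maps (z, h) to an action; in the paper it is evolved
# with CMA-ES, which is omitted here.
action = controller(torch.cat([z[:, -1], h[:, -1]], dim=-1))
```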

Splitting the training in this manner means that the visual representation is not optimised for making predictions in the environment. The authors are aware of this (see footnote 5). Nonetheless, let me spell out the dangers of this approach.

Optimising the VAE leads to a representation that performs well on the image reconstruction criterion. Under such a loss, little weight is given to reconstructing objects that are small or low-contrast. This might not hurt the agent on the particular problems presented, but in general it is a serious drawback.

In many environments, e.g. Space Invaders, crucial objects, say projectiles shot by enemies, are visually tiny. If an aspect of the environment is not visible in the reconstruction, there is no reason to believe the latent representation contains information about it. Thus, an agent which cannot reconstruct bullets cannot reliably predict their consequences, or avoid getting hit.
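A quick back-of-envelope calculation shows how little such a bullet weighs in a pixel-wise reconstruction loss. The numbers are my own illustrative assumptions (a 64x64 frame, as in the paper's VAE, and a projectile a few pixels across), not measurements:

```python
# Back-of-envelope: share of a per-pixel reconstruction loss that a tiny
# object can account for (illustrative assumptions, not measurements).
frame_pixels = 64 * 64        # the World Models VAE operates on 64x64 frames
object_pixels = 2 * 4         # a hypothetical projectile a few pixels across

print(f"object's share of the loss: {object_pixels / frame_pixels:.2%}")  # ~0.20%
# Erasing the projectile from the reconstruction costs the VAE almost nothing,
# while slightly sharpening a large background region pays off far more.
```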

Training the vision and prediction models together in an end-to-end manner provides an additional incentive to represent information about small objects that strongly affect the future state. Not realising that a projectile is about to hit the player results in a high penalty on the prediction loss.
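A sketch of what such joint training could look like, again with toy stand-ins rather than the authors' architecture (a deterministic autoencoder, a plain LSTM and a pixel-space prediction loss). The only change that matters here is that reconstruction and prediction are combined into a single objective, so the encoder also receives gradient from the prediction error:

```python
import torch
import torch.nn as nn

LATENT, HIDDEN, ACTIONS = 32, 256, 3

encoder = nn.Sequential(nn.Flatten(start_dim=2), nn.Linear(64 * 64 * 3, LATENT))
decoder = nn.Sequential(nn.Linear(LATENT, 64 * 64 * 3), nn.Sigmoid())
rnn = nn.LSTM(LATENT + ACTIONS, HIDDEN, batch_first=True)
to_latent = nn.Linear(HIDDEN, LATENT)

frames = torch.rand(16, 10, 3, 64, 64)   # fake rollout: (batch, time, C, H, W)
actions = torch.rand(16, 10, ACTIONS)

z = encoder(frames)                                    # (batch, time, LATENT)
recon = decoder(z).view_as(frames)
h, _ = rnn(torch.cat([z[:, :-1], actions[:, :-1]], dim=-1))
pred = decoder(to_latent(h)).view_as(frames[:, 1:])    # predicted next frames

recon_loss = nn.functional.mse_loss(recon, frames)
pred_loss = nn.functional.mse_loss(pred, frames[:, 1:])
(recon_loss + pred_loss).backward()   # unlike the staged scheme, the encoder
                                      # now gets gradient from pred_loss too
```

With this wiring, a representation that drops the projectile is penalised twice: once (weakly) for the missing pixels, and once for failing to predict the large, visible consequences of being hit.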

A somewhat similar argument could be made for training all three models together, but: (a) the authors are aware of this and mention it already in the Discussion, and (b) it is not as important, as predicting future frames probably requires at least as much information as action selection (though the reward signal could emphasise different parts of the state or guide exploration better).