Refactoring Policy for Compositional Generalizability using Self-Supervised Object Proposals


Tongzhou Mu* 1
Jiayuan Gu* 1
Zhiwei Jia1
Hao Tang2
Hao Su1

1University of California San Diego
2Shanghai Jiao Tong University
*Equal contribution

NeurIPS 2020





We study how to learn a policy with compositional generalizability. We propose a two-stage framework that refactorizes a high-reward teacher policy into a generalizable student policy with a strong inductive bias. In particular, we implement an object-centric, GNN-based student policy whose input objects are learned from images through self-supervised learning. We evaluate our approach on four difficult tasks that require compositional generalizability and achieve superior performance compared to baselines.
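The second stage, refactorizing the teacher into a student, can be read as supervised imitation of the teacher's actions. The sketch below uses plain behavior cloning with a linear softmax student and a toy tabular teacher. Both are hypothetical stand-ins (the paper's student is an object-centric GNN and its teacher a high-reward policy), and the paper's exact data-collection scheme and loss may differ.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def softmax(z: Sequence[float]) -> Vector:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(x - m) for x in z]
    total = sum(e)
    return [x / total for x in e]


def logits(W: List[Vector], s: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, s)) for row in W]


def distill(states: List[Vector], teacher: Callable[[Vector], int], num_actions: int,
            lr: float = 0.5, steps: int = 100) -> List[Vector]:
    """Stage-2 sketch: fit a linear softmax student to the frozen teacher's actions."""
    dim = len(states[0])
    labels = [teacher(s) for s in states]  # query the teacher once per state
    W = [[0.0] * dim for _ in range(num_actions)]
    n = len(states)
    for _ in range(steps):
        grad = [[0.0] * dim for _ in range(num_actions)]
        for s, a in zip(states, labels):
            p = softmax(logits(W, s))
            for k in range(num_actions):
                # d(mean cross-entropy)/d(logit_k), chained through the linear layer.
                coef = (p[k] - (1.0 if k == a else 0.0)) / n
                for d in range(dim):
                    grad[k][d] += coef * s[d]
        for k in range(num_actions):
            for d in range(dim):
                W[k][d] -= lr * grad[k][d]  # full-batch gradient descent step
    return W


def act(W: List[Vector], s: Vector) -> int:
    z = logits(W, s)
    return max(range(len(z)), key=z.__getitem__)


def toy_teacher(s: Vector) -> int:
    # Hypothetical stand-in for a trained teacher: maps a one-hot state to an action.
    return s.index(1.0) % 3


states = [[1.0 if k == i else 0.0 for k in range(4)] for i in range(4)]
W = distill(states, toy_teacher, num_actions=3)
print([act(W, s) for s in states])  # [0, 1, 2, 0]
```

In practice the student would be a GNN trained with an autodiff framework; the cross-entropy objective on teacher actions is the part this sketch is meant to show.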



Paper and Other Materials


[Paper]
[Supplement]
[Bibtex]
[Arxiv]
[Poster]
[Slides]
[Code]
[Additional Results]



Results


Multi-MNIST


Generalize to more objects





In Multi-MNIST, the task is to calculate the sum of the digits in an image with a complicated background (from CIFAR or ImageNet). Our model (GNN+SPACE), trained on images with 1 to 3 digits, generalizes well to images with 4 digits. This shows that an object-centric graph can serve as a strong inductive bias for compositional generalizability.
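Why can an object-centric graph extend from 1 to 3 digits to 4? The following minimal pure-Python sketch (not the paper's architecture) runs one round of mean message passing over a fully connected graph of object features, then sum-pools. The weights `W_self`, `W_msg`, `w_out` and the one-hot digit features are hypothetical stand-ins for what SPACE proposals and a trained GNN would provide. With hand-set weights the readout computes the digit sum exactly for any number of objects, and for any weights it is invariant to object order.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, v: Vector) -> Vector:
    """Dense matrix-vector product."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def gnn_sum_readout(nodes: List[Vector], W_self: Matrix, W_msg: Matrix, w_out: Vector) -> float:
    """One message-passing step on a fully connected object graph, then sum pooling."""
    updated = []
    for i, h_i in enumerate(nodes):
        # Mean of the other objects' features (zero vector when there is only one object).
        others = [h for j, h in enumerate(nodes) if j != i]
        msg = [sum(col) / len(others) for col in zip(*others)] if others else [0.0] * len(h_i)
        pre = [a + b for a, b in zip(matvec(W_self, h_i), matvec(W_msg, msg))]
        updated.append([max(0.0, x) for x in pre])  # ReLU node update
    # Sum pooling: defined for any number of objects and invariant to their order.
    pooled = [sum(col) for col in zip(*updated)]
    return float(sum(w * x for w, x in zip(w_out, pooled)))


def one_hot(digit: int) -> Vector:
    return [1.0 if k == digit else 0.0 for k in range(10)]


# Hand-set weights: keep each node's own one-hot, ignore messages, read out digit values.
identity = [[1.0 if r == c else 0.0 for c in range(10)] for r in range(10)]
zeros = [[0.0] * 10 for _ in range(10)]
values = [float(k) for k in range(10)]

print(gnn_sum_readout([one_hot(d) for d in (3, 5, 7)], identity, zeros, values))     # 15.0
print(gnn_sum_readout([one_hot(d) for d in (3, 5, 7, 1)], identity, zeros, values))  # 16.0
```

The same weights work for 3 and 4 objects because nothing in the model depends on the object count; a CNN reading the whole image has no such guarantee.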



Task-relevant Knowledge Discovery




Clustering the learned object features reveals object attributes. This figure shows t-SNE visualizations of the object features learned by the self-supervised object detector (CIFAR/ImageNet-Recon) and by our GNN policy (CIFAR/ImageNet-Task). The task-driven object features are more clearly separated than the reconstruction-driven ones.
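The figure judges separability visually. One simple way to put a number on "more clearly separated", which is our assumption and not a metric reported here, is nearest-centroid accuracy: the fraction of object features that lie closer to the centroid of their own class (e.g. digit label) than to any other class centroid. The toy 2-D features below are made up for illustration.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence

Vector = List[float]


def class_centroids(features: Sequence[Vector], labels: Sequence[Hashable]) -> Dict[Hashable, Vector]:
    groups = defaultdict(list)
    for f, y in zip(features, labels):
        groups[y].append(f)
    return {y: [sum(col) / len(fs) for col in zip(*fs)] for y, fs in groups.items()}


def sq_dist(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_centroid_accuracy(features: Sequence[Vector], labels: Sequence[Hashable]) -> float:
    """Fraction of features whose nearest class centroid belongs to their own class."""
    cents = class_centroids(features, labels)
    correct = sum(
        min(cents, key=lambda c: sq_dist(f, cents[c])) == y
        for f, y in zip(features, labels)
    )
    return correct / len(features)


# Made-up 2-D features: well-separated clusters vs. interleaved ones.
separated = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]]
interleaved = [[0.0, 0.0], [3.0, 0.0], [1.0, 0.0], [4.0, 0.0]]
labels = [0, 0, 1, 1]
print(nearest_centroid_accuracy(separated, labels))    # 1.0
print(nearest_centroid_accuracy(interleaved, labels))  # 0.5
```

Unlike a t-SNE plot, this score is computed in the original feature space, so it does not depend on the projection's perplexity or random seed.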



FallingDigit


Visualization of generalization






Train on 3 digits
CNN-based teacher policy fails to generalize to 9 digits
GNN-based student policy generalizes to 9 digits

FallingDigit is a Tetris-like game. A digit falls from the top, and the player must steer it so that it hits the digit at the bottom whose value is closest to its own. We show that a student network with an object-centric graph inductive bias can refactorize the teacher policy into a compositionally generalizable policy.
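The decision rule the policies have to learn can be written directly over a list of objects. The hand-coded sketch below is not the learned policy: the `(value, x)` object encoding, the action names, the tie-breaking and the tolerance `tol` are our assumptions. It only illustrates that a rule stated per object applies unchanged to 3 or 9 target digits, which is the kind of structure a GNN over object proposals can exploit and a CNN over raw pixels is not forced to capture.

```python
from typing import List, Tuple

Obj = Tuple[int, float]  # hypothetical object encoding: (digit value, x position)


def choose_target(falling: Obj, targets: List[Obj]) -> Obj:
    """Pick the bottom digit whose value is closest to the falling digit's value."""
    value, x = falling
    # Ties in value distance are broken by horizontal distance (our assumption).
    return min(targets, key=lambda t: (abs(t[0] - value), abs(t[1] - x)))


def rule_action(falling: Obj, targets: List[Obj], tol: float = 0.5) -> str:
    """Steer horizontally toward the chosen target; works for any number of targets."""
    dx = choose_target(falling, targets)[1] - falling[1]
    if dx > tol:
        return "right"
    if dx < -tol:
        return "left"
    return "stay"


three = [(1, 0.0), (6, 8.0), (9, 4.0)]
nine = three + [(0, 3.0), (2, 2.0), (3, 6.0), (4, 5.0), (5, 1.0), (7, 6.5)]
print(rule_action((5, 2.0), three))  # right (closest value is 6, at x=8.0)
print(rule_action((5, 2.0), nine))   # left (exact match 5, at x=1.0)
```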



Quantitative results of generalization




Mean rewards obtained by different methods in FallingDigit environments with different numbers of target digits. Our refactorized GNN-based policy, trained in the environment with 3 target digits, generalizes well to the environment with 9 target digits.




The website template is adapted from NVIDIA's DefTet project, and we thank the template author David Acuna.