Addressing Negative Transfer in Diffusion Models

NeurIPS 2023

Hyojun Go*, Jinyoung Kim*, Yunsung Lee*, Seunghyun Lee*, Shinhyeok Oh, Hyeongdon Moon, Seungtaek Choi

Twelve Labs    Wrtn    Riiid    EPFL   *Co-first author   Corresponding Author

ANT reconceptualizes the training of diffusion models as a multi-task learning problem. From this perspective, we can identify and mitigate negative transfer, leading to enhanced generation quality.

Summary

  • We rethink the training of diffusion models as multi-task learning, where each task corresponds to the denoising task at a specific timestep (noise level).
  • We analyze diffusion training from a multi-task learning (MTL) perspective and present two key observations: (1) the task affinity between denoising tasks diminishes as the gap in noise levels increases, and (2) negative transfer, which causes performance degradation due to task conflicts, can arise even in diffusion models.
  • We propose leveraging existing MTL methods, such as PCGrad, UW, and NashMTL, through interval clustering, which groups denoising tasks into pairwise disjoint timestep intervals.

Observational study

Denoting by \(D_t\) the denoising task at timestep \(t\) (learned with the loss \(L_t = ||\epsilon - \epsilon_\theta(x_t, t)||_2^2\)), diffusion training can be cast as a multi-task learning problem over the denoising tasks \(D^{[0,T]}=\{D_t\}_{t=1,...,T}\). We analyze the task affinity between denoising tasks at different timesteps and observe that negative transfer occurs for some denoising tasks.
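For concreteness, below is a minimal PyTorch-style sketch of the per-timestep loss \(L_t\), assuming a standard DDPM forward process; `eps_model` and `alphas_cumprod` are placeholders, not the released implementation.

```python
import torch

def denoising_loss(eps_model, x0, t, alphas_cumprod):
    """L_t = ||eps - eps_theta(x_t, t)||^2 for clean images x0 (B,C,H,W) and timesteps t (B,)."""
    eps = torch.randn_like(x0)                            # target noise
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)           # \bar{alpha}_t per sample
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps    # forward diffusion q(x_t | x_0)
    eps_pred = eps_model(x_t, t)                          # predicted noise
    return ((eps - eps_pred) ** 2).mean()                 # squared-error denoising loss
```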

1. Task affinity

We employ a gradient direction-based task affinity score. For two distinct tasks \(D_{t_1}\) and \(D_{t_2}\), we compute the cosine similarity between their loss gradients and average the similarities across training iterations. We observe that the task affinity between denoising tasks decreases as the gap between their timesteps (noise levels) increases.
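A minimal sketch of this affinity score for a single iteration (a hypothetical helper; in practice the per-iteration similarities are averaged over training, as described above):

```python
import torch

def task_affinity(model, loss_t1, loss_t2):
    """Cosine similarity between the gradients of two per-task losses w.r.t. the shared parameters."""
    params = [p for p in model.parameters() if p.requires_grad]
    g1 = torch.autograd.grad(loss_t1, params, retain_graph=True, allow_unused=True)
    g2 = torch.autograd.grad(loss_t2, params, retain_graph=True, allow_unused=True)
    flat = lambda grads: torch.cat(
        [g.flatten() if g is not None else torch.zeros_like(p).flatten()
         for g, p in zip(grads, params)])
    return torch.nn.functional.cosine_similarity(flat(g1), flat(g2), dim=0)
```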


2. Negative transfer

Negative transfer refers to the deterioration of a multi-task learner's performance due to conflicts between tasks. It can be identified by comparing the performance of a multi-task learner with that of a task-specific learner. In this context, we define the negative transfer gap (\(NTG\)), where \(NTG < 0\) indicates that negative transfer occurs. We observe that negative transfer occurs for some denoising tasks in both ADM and LDM.
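To make this convention concrete, one natural form of the gap (a sketch; \(P(\cdot)\) stands for a per-task performance measure where higher is better) is: $$ NTG(D_t) = P(\text{multi-task learner};\, D_t) - P(\text{task-specific learner};\, D_t), $$ so a multi-task learner that underperforms its task-specific counterpart on \(D_t\) yields \(NTG < 0\).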

Method: ANT

1. Leveraging MTL methods

To address negative transfer, we propose leveraging existing MTL methods. We adopt three: PCGrad, UW, and NashMTL (a minimal sketch of the UW-style weighting appears after the list below).

  • PCGrad mitigates conflicting gradients between tasks by projecting out the conflicting components of the gradients.
  • NashMTL balances gradients between tasks by solving a bargaining game.
  • Uncertainty Weighting (UW) balances task losses by weighting each task loss with task-dependent uncertainty.
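As one concrete example, here is a minimal PyTorch-style sketch of uncertainty weighting over \(k\) task clusters, using the common simplified form of Kendall et al.'s loss with learnable log-variances; the names are illustrative, not the paper's released code.

```python
import torch
import torch.nn as nn

class UncertaintyWeighting(nn.Module):
    """Uncertainty weighting: sum_i exp(-s_i) * L_i + s_i, with s_i learned jointly with the model."""
    def __init__(self, num_clusters):
        super().__init__()
        self.log_var = nn.Parameter(torch.zeros(num_clusters))  # s_i, one per task cluster

    def forward(self, cluster_losses):
        # cluster_losses: tensor of shape (num_clusters,), the mean denoising loss per interval cluster
        return (torch.exp(-self.log_var) * cluster_losses + self.log_var).sum()
```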

2. Interval clustering for efficient computation

MTL methods can require a large amount of computation, especially when the number of tasks is large. To address this, we leverage an interval clustering algorithm that groups denoising tasks into interval clusters, motivated by our task-affinity observation, and then apply MTL methods by treating each interval cluster as a single task. In our case, interval clustering assigns the diffusion timesteps \(\mathcal{X} = \{1, \dots, T\}\) to \(k\) contiguous intervals \(I_{1}, \dots, I_{k}\), where \(I_{i} = [l_i, r_i]\) and \(l_i \leq r_i\).

With \(l_{1} = 1\), \(r_{i} = l_{i+1}-1\) for \(i < k\), and \(r_k = T\), the interval clustering problem is defined as: $$ \min_{1 = l_1 < l_2 < \dots < l_k \le T} \sum_{i=1}^k L_{cluster}(I_i \cap \mathcal{X}) $$
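A minimal dynamic-programming sketch of this 1-D interval clustering; the per-interval cost \(L_{cluster}\) is left abstract as a `cluster_cost(l, r)` callable (a placeholder, not the paper's exact solver):

```python
def interval_clustering(T, k, cluster_cost):
    """Split timesteps 1..T into k contiguous intervals minimizing the total cluster cost.

    cluster_cost(l, r) -> cost of the interval [l, r] (1-indexed, inclusive).
    Returns [(l_1, r_1), ..., (l_k, r_k)].
    """
    INF = float("inf")
    # dp[i][r] = best cost of covering 1..r with i intervals; back[i][r] = chosen left endpoint l_i
    dp = [[INF] * (T + 1) for _ in range(k + 1)]
    back = [[0] * (T + 1) for _ in range(k + 1)]
    dp[0][0] = 0.0
    for i in range(1, k + 1):
        for r in range(i, T + 1):
            for l in range(i, r + 1):          # last interval is [l, r]; earlier intervals cover 1..l-1
                cand = dp[i - 1][l - 1] + cluster_cost(l, r)
                if cand < dp[i][r]:
                    dp[i][r], back[i][r] = cand, l
    # Recover the interval boundaries by walking back from (k, T).
    intervals, r = [], T
    for i in range(k, 0, -1):
        l = back[i][r]
        intervals.append((l, r))
        r = l - 1
    return intervals[::-1]
```

Since the triple loop evaluates `cluster_cost` \(O(kT^2)\) times, in practice the per-interval costs would be precomputed or memoized.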

We present three clustering costs: timestep-based, SNR-based, and gradient-based.
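As an illustration of an SNR-based cost, the snippet below scores an interval by the spread of its log-SNR values under a linear \(\beta\) schedule; this particular cost and schedule are assumptions for the sketch, not necessarily the paper's exact choices, and the result plugs directly into `interval_clustering` above.

```python
import torch

# Log-SNR of each timestep under a DDPM-style linear beta schedule (illustrative signal).
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
log_snr = torch.log(alphas_cumprod / (1.0 - alphas_cumprod))

def snr_cost(l, r):
    """Within-interval sum of squared deviations of log-SNR (1-indexed, inclusive), used as L_cluster."""
    vals = log_snr[l - 1 : r]
    return float(vals.var(unbiased=False)) * len(vals)

# Example usage with the DP sketch above:
# intervals = interval_clustering(T, k=5, cluster_cost=snr_cost)
```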

Experimental Results

1. Improved quality of generated images

Table: Leveraging MTL methods through interval clustering improves the quality of generated images in both ADM and LDM trained on CelebA-HQ and FFHQ datasets.


Figure: Leveraging MTL methods through interval clustering improves the quality of generated images in DiT-S trained on ImageNet dataset with classifier-free guidance.

2. Faster convergence

MTL methods mitigate negative transfer in training, achieving faster convergence than vanilla training.

Figure: Faster convergence in ADM and LDM trained on FFHQ dataset.


Figure: Faster convergence in DiT-S trained on ImageNet dataset with classifier-free guidance.

3. Negative transfer mitigated by MTL methods

MTL methods with interval clustering effectively mitigate negative transfer, as indicated by the increased \(NTG\) compared to vanilla training.


4. Comparison to other weighting methods and computational costs

Our method ANT-UW, which employs UW with interval clustering, greatly outperforms MinSNR. Moreover, ANT-UW requires computation and memory costs similar to vanilla training.

BibTeX

@article{go2023addressing,
      title={Addressing Negative Transfer in Diffusion Models},
      author={Go, Hyojun and Kim, JinYoung and Lee, Yunsung and Lee, Seunghyun and Oh, Shinhyeok and Moon, Hyeongdon and Choi, Seungtaek},
      journal={arXiv preprint arXiv:2306.00354},
      year={2023}
}

Acknowledgement

This website is adapted from Nerfies and LLaVA, licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.