https://github.com/cloneofsimo/karras-power-ema-tutorial

Science Score: 23.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (8.6%) to scientific vocabulary
Last synced: 10 months ago · JSON representation

Repository

Basic Info
  • Host: GitHub
  • Owner: cloneofsimo
  • Language: Python
  • Default Branch: master
  • Size: 825 KB
Statistics
  • Stars: 51
  • Watchers: 2
  • Forks: 1
  • Open Issues: 1
  • Releases: 0
Created over 2 years ago · Last pushed over 2 years ago
Metadata Files
Readme

README.md

Karras Power Function EMA (Post-training EMA synthesis)

This tutorial repo implements Karras's power-function EMA, a quite incredible trick introduced in the paper Analyzing and Improving the Training Dynamics of Diffusion Models by Tero Karras, Miika Aittala, Jaakko Lehtinen, Janne Hellsten, Timo Aila, and Samuli Laine.

So What is Karras's Power function EMA?

I recommend reading the paper for the full details, but here is the big picture.

Recall that EMA'ing checkpoints is about keeping track of a smoothed version of the model parameters, $\theta_\beta$, where

$$\theta_\beta(t) = \beta \theta_\beta(t-1) + (1-\beta) \theta(t)$$

and $\beta$ is the decay factor, close to 1. Using EMA typically makes the model more robust, and it is common practice in training deep neural networks.
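As a minimal sketch of the recursion above, here is a constant-$\beta$ EMA applied to a toy 1-D "parameter" trajectory. The function name `ema`, the random-walk data, and the choice to initialize the running average at the first value are all assumptions made for illustration.

```python
import numpy as np

def ema(values, beta):
    """Constant-decay EMA: out[t] = beta * out[t-1] + (1 - beta) * values[t]."""
    out = np.empty(len(values), dtype=float)
    running = float(values[0])  # assumption: start the average at the first value
    for i, v in enumerate(values):
        running = beta * running + (1 - beta) * v
        out[i] = running
    return out

rng = np.random.default_rng(0)
noisy = np.cumsum(rng.normal(size=1000))  # toy random-walk trajectory
smooth = ema(noisy, beta=0.99)            # smoothed version of the trajectory
```

A $\beta$ closer to 1 averages over a longer history, so `smooth` lags behind `noisy` more but fluctuates less.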

You want to use EMA, but...

  1. You don't want the EMA to be too slow, because that makes the random initialization's contribution to the final model too large.
  2. You definitely want the decay factor to be self-similar, because you should be able to extend the training time.
  3. You want to set the decay factor post hoc, because you don't want to retrain the model from scratch with a different decay factor.

Karras's power-function EMA is the answer to all of these problems. First, instead of keeping $\beta$ constant, it uses a power-function schedule, $\beta(t) = (1 - 1/t)^{1 + \gamma}$, where $\gamma$ is a hyperparameter. This makes the contribution of historical parameters self-similar: you can lengthen or shorten training without changing how you expect the EMA to behave (i.e., if the first 10% of training contributes x% of the final model, increasing or decreasing the training time will not change that).
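The self-similarity claim can be checked numerically. The sketch below unrolls the recursion to get the weight each step receives in the final EMA, then measures how much the first 10% of training contributes for different run lengths. The helper name `power_ema_weights` and the value of `gamma` are illustrative assumptions.

```python
import numpy as np

def power_ema_weights(num_steps, gamma):
    """Weight each step's parameters receive in the final power-function EMA."""
    weights = np.zeros(num_steps)
    for t in range(1, num_steps + 1):
        beta = (1 - 1 / t) ** (1 + gamma)
        weights[: t - 1] *= beta   # older contributions decay by beta(t)
        weights[t - 1] = 1 - beta  # the newest parameters enter with 1 - beta(t)
    return weights

gamma = 1.0  # illustrative value
for num_steps in (100, 1000):
    w = power_ema_weights(num_steps, gamma)
    early_share = w[: num_steps // 10].sum()  # contribution of the first 10%
    print(num_steps, early_share)
```

Unrolling the product of the $\beta(t)$ terms shows that the first fraction $f$ of training receives a total weight of $f^{1+\gamma}$, independent of the number of steps, so both run lengths should report the same share up to floating-point rounding.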

Overall Algorithm and Implementation

There are two main parts to the algorithm.

  1. Saving two copies of the EMA model, each with a different width.
  2. Recovering an arbitrary-width EMA.

Think of the width as the decay factor. A larger width means a smoother EMA.

First, save two copies of the EMA, with different width

This is the easy part. You just need to save two copies of the EMA, each with a different width (a different $\gamma$).

```python
import copy

import torch

# Model, dataloader, optimizer, loss, and save_freq are assumed to be defined elsewhere.
gamma_1 = 5
gamma_2 = 10
model = Model()
model_ema_1 = copy.deepcopy(model).cpu()
model_ema_2 = copy.deepcopy(model).cpu()

for i, batch in enumerate(dataloader):
    # Power-function decay; beta is 0 at the first step, so the EMA starts at the model.
    beta_1 = (1 - 1 / (i + 1)) ** (1 + gamma_1)
    beta_2 = (1 - 1 / (i + 1)) ** (1 + gamma_2)
    # train model
    loss.backward()
    optimizer.step()
    for p, p_ema_1, p_ema_2 in zip(
        model.parameters(), model_ema_1.parameters(), model_ema_2.parameters()
    ):
        p_cpu = p.data.cpu()  # the EMA copies live on the CPU
        p_ema_1.data = p_ema_1.data * beta_1 + p_cpu * (1 - beta_1)
        p_ema_2.data = p_ema_2.data * beta_2 + p_cpu * (1 - beta_2)

    if i % save_freq == 0:
        torch.save(model_ema_1.state_dict(), f'./model_ema_1_{i}.pth')
        torch.save(model_ema_2.state_dict(), f'./model_ema_2_{i}.pth')
```
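To experiment without a real model, the same bookkeeping can be sketched on a toy 1-D trajectory with NumPy. Everything here is an assumption made for illustration: the random-walk "parameter" `y_t`, the helper `power_ema`, and the checkpoint interval `save_freq`.

```python
import numpy as np

def power_ema(traj, gamma):
    """Power-function EMA of a 1-D trajectory, one value per step."""
    out = np.empty(len(traj), dtype=float)
    running = 0.0
    for i, theta in enumerate(traj):
        beta = (1 - 1 / (i + 1)) ** (1 + gamma)  # beta = 0 at the first step
        running = beta * running + (1 - beta) * theta
        out[i] = running
    return out

rng = np.random.default_rng(0)
y_t = np.cumsum(rng.normal(size=2000))  # toy random-walk "parameter"
y_t_ema1 = power_ema(y_t, gamma=5)
y_t_ema2 = power_ema(y_t, gamma=10)

save_freq = 100  # hypothetical checkpoint interval
checkpoint_index = np.arange(save_freq - 1, len(y_t), save_freq)
saved_1 = y_t_ema1[checkpoint_index]  # what would be written to disk for gamma = 5
saved_2 = y_t_ema2[checkpoint_index]  # and for gamma = 10
```

Only `saved_1` and `saved_2` would survive training in the real setting; the full EMA curves are kept here just so they can be inspected.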

Second, recover arbitrary-decay EMA after training.

Now what if you want to recover the EMA with $\gamma_3$? Incredibly, you can do this with the saved checkpoints alone. The math behind this in the paper is a bit... not straightforward, so here is my version of the explanation.

EMA, by definition, can be considered an integral over the trajectory of the model parameters. That is, there is some weighting function $w(t)$ such that

$$\theta_e(T) = \int_0^T w(t) \theta(t) dt$$

For a fixed training run, $t \in [0, T]$, we saved two copies of the EMA (one per $\gamma$) at each of, say, $n$ checkpoints. This means we know the values of the trajectory integrals

$$\theta_{i,j} = \int_0^T w_{i,j}(t) \theta(t) dt$$

for $i = 1, 2$ and $j = 1, 2, \cdots, n$. Here $i$ corresponds to the width and $j$ to the $j$-th checkpoint. Notice how

$$ w_{i,j}(t) = \begin{cases} t^{\gamma_i} / g_{i,j} & \text{if } t < j \\ 0 & \text{otherwise} \end{cases} $$

where $t < j$ means that $t$ lies before the $j$-th checkpoint, and $g_{i,j}$ is simply the normalization constant that makes $\int_0^T w_{i,j}(t) dt = 1$.
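As a worked step, and assuming the $j$-th checkpoint is taken at time $t_j$ (a symbol introduced here for illustration, so that the support of $w_{i,j}$ is $t < t_j$), the normalization constant has a closed form:

$$g_{i,j} = \int_0^{t_j} t^{\gamma_i} dt = \frac{t_j^{\gamma_i + 1}}{\gamma_i + 1}, \qquad w_{i,j}(t) = \frac{(\gamma_i + 1)\, t^{\gamma_i}}{t_j^{\gamma_i + 1}} \quad \text{for } t < t_j$$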

Our goal is then to

  1. find an approximation $\hat{w}_3(t)$ that gives us the EMA corresponding to $\gamma_3$.

  2. find the corresponding $\theta_{3,T}$

See where this is going? Our goal is to approximate $w_3(t)$ as a linear combination of $w_{1,j}(t)$ and $w_{2,j}(t)$, i.e.,

$$w_3(t) \approx \sum_{j=1}^n \left( \alpha_j w_{1,j}(t) + \beta_j w_{2,j}(t) \right)$$

where $\alpha_j$ and $\beta_j$ are the coefficients we need to find. This way,

$$\theta_{3,T} = \int_0^T w_3(t) \theta(t) dt \approx \sum_{j=1}^n \left( \alpha_j \theta_{1,j} + \beta_j \theta_{2,j} \right)$$

Aha! Now we can find $\alpha_j$ and $\beta_j$ by solving a linear system of equations. Let's take this one step further.

Goal: project $w_3(t)$ onto the subspace spanned by $w_{1,j}(t)$ and $w_{2,j}(t)$

We have $K$ functions $f_k(t)$ and a target $g(t)$, and we want to find $K$ coefficients $\alpha_i$ such that

$$\min_\alpha \int_0^T \left( g(t) - \sum_{i=1}^K \alpha_i f_i(t) \right)^2 dt$$

How would you solve this?

Define inner product as

$$\langle f, g \rangle = \int_0^T f(t) g(t) dt$$

Then we can rewrite the problem as

$$\min \left\| g - \sum_{i=1}^K \alpha_i f_i \right\|_2^2$$

if we define $\| f \|_2 = \sqrt{\langle f, f \rangle}$. Expanding the norm, we get

$$\min \left\| g \right\|_2^2 - 2 \sum_{i=1}^K \alpha_i \langle g, f_i \rangle + \sum_{i=1}^K \sum_{j=1}^K \alpha_i \alpha_j \langle f_i, f_j \rangle$$

Ha, so substituting $A_{i,j} = \langle f_i, f_j \rangle$ and $b_i = \langle g, f_i \rangle$, we actually just had a linear least-squares problem!

$$\min \left\| g \right\|_2^2 - 2 \alpha^T b + \alpha^T A \alpha$$

where $\alpha = (\alpha_1, \cdots, \alpha_K)^T$.
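To spell out the step from this quadratic to its minimizer: $A$ is symmetric, since $\langle f_i, f_j \rangle = \langle f_j, f_i \rangle$, so setting the gradient with respect to $\alpha$ to zero gives the normal equations:

$$\nabla_\alpha \left( \left\| g \right\|_2^2 - 2 \alpha^T b + \alpha^T A \alpha \right) = -2b + 2A\alpha = 0 \quad \Longrightarrow \quad A \alpha = b$$

Because $A$ is a Gram matrix it is positive semidefinite, so any solution of $A \alpha = b$ is a global minimizer.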

So the solution is simply

$$\alpha = A^{+} b$$

where $A^{+}$ is the pseudo-inverse of $A$. All that is left is to use $\alpha$ to compute $\theta_{3,T}$.
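The projection recipe can be tried on a small example. The sketch below approximates the inner product with a Riemann sum on a grid and projects a target onto three basis functions. The basis $\{1, t, t^2\}$ and the target $t^3$ are hypothetical choices made only for illustration, not the EMA profiles themselves.

```python
import numpy as np

# Discretize [0, T] so inner products become sums (an approximation).
T = 1.0
t = np.linspace(0.0, T, 2001)
dt = t[1] - t[0]

def inner(f, g):
    """Riemann-sum approximation of <f, g> = int_0^T f(t) g(t) dt."""
    return np.sum(f * g) * dt

# Hypothetical basis functions f_i and target g.
basis = [np.ones_like(t), t, t**2]
g = t**3

A = np.array([[inner(fi, fj) for fj in basis] for fi in basis])  # Gram matrix
b = np.array([inner(g, fi) for fi in basis])
alpha = np.linalg.pinv(A) @ b

approx = sum(a * f for a, f in zip(alpha, basis))
residual = g - approx  # orthogonal to every basis function, up to rounding
print(alpha, inner(residual, residual))
```

The defining property of the projection is that the residual is orthogonal to every $f_i$, which is exactly what $A \alpha = b$ encodes.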

Note: if you have ever studied functional analysis, you will recognize that the best approximation exists and is unique, by Hilbert's projection theorem. The above simply finds the projection of $g$ onto the subspace spanned by the $f_i$ in $L^2$ space.

So, the things you learned:

  1. The level of approximation is determined by the number of checkpoints you saved. More checkpoints, better approximation.
  2. This doesn't have to be power-function EMA. You can use any weighting function $w(t)$, as long as you can compute the integral of the trajectory of the model parameters.

I don't care about the math just give me the code?

OK, but remember that this code is written for the power-function EMA; the same recipe works for any weighting function $w(t)$.

In the code above, you saved two copies of the EMA, each with a different $\gamma$. Now you want to recover the EMA with $\gamma_3$. Suppose you saved $n$ checkpoints, at iterations $i_1, i_2, \cdots, i_n$. Then you can do the following.

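```python
import numpy as np

from ema_eq import solve_weights  # implementation lives in ema_eq.py

# t, checkpoint_index, last_index, y_t_ema1, and y_t_ema2 are assumed to be defined.
t_checkpoint = t[checkpoint_index]

# Stack the checkpoint times twice: once for gamma_1, once for gamma_2.
ts = np.concatenate((t_checkpoint, t_checkpoint))
gammas = np.concatenate(
    (
        np.ones_like(checkpoint_index) * gamma_1,
        np.ones_like(checkpoint_index) * gamma_2,
    )
)

# Coefficients for combining the saved EMAs, then the saved EMA values in the same order.
x = solve_weights(ts, gammas, last_index, gamma_3)
emapoints = np.concatenate((y_t_ema1[checkpoint_index], y_t_ema2[checkpoint_index]))

y_t_ema3 = np.dot(x, emapoints)
```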

where `solve_weights` is the function that solves the linear least-squares problem. You can find the implementation in `ema_eq.py`.

The result is the EMA with $\gamma_3$.
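For readers who want to see the least-squares step without opening `ema_eq.py`, here is a minimal sketch. It is not the repository's `solve_weights`: the names `power_inner` and `solve_weights_sketch`, their signatures, the normalized checkpoint times, and the continuous-time profiles $w(t) \propto t^{\gamma}$ on $[0, t_j]$ are all assumptions made for illustration. Under those assumptions the inner products have a closed form, so no numerical integration is needed.

```python
import numpy as np

def power_inner(t_a, g_a, t_b, g_b):
    """Closed-form <w_a, w_b> for w(t) = (g + 1) * t**g / t_end**(g + 1) on [0, t_end]."""
    m = np.minimum(t_a, t_b)  # the product is nonzero only where both profiles are
    num = (g_a + 1) * (g_b + 1) * m ** (g_a + g_b + 1)
    den = (g_a + g_b + 1) * t_a ** (g_a + 1) * t_b ** (g_b + 1)
    return num / den

def solve_weights_sketch(ts, gammas, t_target, gamma_target):
    """Least-squares coefficients that combine the saved profiles into the target one."""
    A = power_inner(ts[:, None], gammas[:, None], ts[None, :], gammas[None, :])
    b = power_inner(ts, gammas, t_target, gamma_target)
    x, *_ = np.linalg.lstsq(A, b, rcond=None)
    return x

# Assumed setup: 10 checkpoints at normalized times 0.1, ..., 1.0, two gammas each.
t_checkpoint = np.linspace(0.1, 1.0, 10)
gamma_1, gamma_2, gamma_3 = 5.0, 10.0, 7.0
ts = np.concatenate((t_checkpoint, t_checkpoint))
gammas = np.concatenate((np.full(10, gamma_1), np.full(10, gamma_2)))

x = solve_weights_sketch(ts, gammas, t_target=1.0, gamma_target=gamma_3)

# Relative squared L2 error of the projected profile: ||g - sum_k x_k f_k||^2 / ||g||^2.
A = power_inner(ts[:, None], gammas[:, None], ts[None, :], gammas[None, :])
b = power_inner(ts, gammas, 1.0, gamma_3)
g_norm_sq = power_inner(1.0, gamma_3, 1.0, gamma_3)
rel_err = (g_norm_sq - 2 * x @ b + x @ A @ x) / g_norm_sq
print(x.shape, rel_err)
```

With the saved EMA values stacked in the same order as `ts`, the recovered EMA would then be the dot product of `x` with those values, as in the `np.dot(x, emapoints)` line above. The printed relative error gives a rough sense of how well the saved profiles can represent the target one.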

Owner

  • Name: Simo Ryu
  • Login: cloneofsimo
  • Kind: user
  • Company: Corca AI

Cats are Turing machines cloneofsimo@gmail.com

GitHub Events

Total
  • Watch event: 2
Last Year
  • Watch event: 2

Issues and Pull Requests

Last synced: about 1 year ago

All Time
  • Total issues: 1
  • Total pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Total issue authors: 1
  • Total pull request authors: 0
  • Average comments per issue: 0.0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 1
  • Pull requests: 0
  • Average time to close issues: N/A
  • Average time to close pull requests: N/A
  • Issue authors: 1
  • Pull request authors: 0
  • Average comments per issue: 0.0
  • Average comments per pull request: 0
  • Merged pull requests: 0
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • XinYu-Andy (1)
Pull Request Authors
Top Labels
Issue Labels
Pull Request Labels