Calculate the Kullback-Leibler (KL) divergence between two normal distributions, each parameterized by its mean and log variance.
This uses the same closed-form expression as torchdiffeq.
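For reference, the textbook closed-form KL divergence of the first normal distribution from the second is shown below. The right-hand form is written in terms of the log variances, with the symbols l1 and l2 standing in for lv1 and lv2. This restates the standard identity; it is not copied from the package source.

```latex
\begin{aligned}
\mathrm{KL}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)
  &= \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \\
  &= \frac{l_2 - l_1}{2} + \frac{e^{l_1} + (\mu_1-\mu_2)^2}{2\,e^{l_2}} - \frac{1}{2},
  \qquad l_i = \log\sigma_i^2 .
\end{aligned}
```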
Usage
normal_kl(mu1, lv1, mu2, lv2)
Arguments
- mu1
Mean of the first distribution (the posterior).
- lv1
Log variance of the first distribution.
- mu2
Mean of the second distribution (the prior).
- lv2
Log variance of the second distribution.
Value
A tensor of KL divergence values, KL(first || second).
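The following base-R sketch mirrors the closed-form expression so the arithmetic can be checked by hand. The name normal_kl_ref is hypothetical. It works on plain numeric vectors, not torch tensors, so it is an illustration of the formula rather than the package's implementation of normal_kl.

```r
# Hypothetical base-R reference sketch (not the package implementation):
# KL( N(mu1, exp(lv1)) || N(mu2, exp(lv2)) ), computed elementwise on numeric vectors.
normal_kl_ref <- function(mu1, lv1, mu2, lv2) {
  v1 <- exp(lv1)     # variance of the first (posterior) distribution
  v2 <- exp(lv2)     # variance of the second (prior) distribution
  lstd1 <- lv1 / 2   # log standard deviation of the first distribution
  lstd2 <- lv2 / 2   # log standard deviation of the second distribution
  lstd2 - lstd1 + (v1 + (mu1 - mu2)^2) / (2 * v2) - 0.5
}

# Identical distributions give zero divergence.
normal_kl_ref(0, 0, 0, 0)
# Posterior N(1, 1) against a standard normal prior: 0 + (1 + 1) / 2 - 0.5 = 0.5
normal_kl_ref(1, 0, 0, 0)
```

Because every operation is elementwise, vector inputs return one divergence per element; the tensor version is expected to behave analogously, which is an assumption here.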
References
Chen, R. T. Q. torchdiffeq. https://github.com/rtqichen/torchdiffeq