
Calculate the Kullback-Leibler (KL) divergence KL(posterior || prior) between two normal distributions, each given by its mean and log variance. This follows the closed-form formulation used in torchdiffeq.
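For reference, this is the standard closed form for two univariate Gaussians, written in terms of the log-variance arguments. Here $\mathrm{lv}_i$ stands for `lv1` and `lv2`. The assumption that `normal_kl()` applies it element-wise (one value per input position) is not stated on this page.

```latex
\mathrm{KL}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)
  = \log\sigma_2 - \log\sigma_1
  + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2},
\qquad \sigma_i^2 = e^{\mathrm{lv}_i},\quad \log\sigma_i = \tfrac{1}{2}\,\mathrm{lv}_i .
```

The divergence is zero when the two distributions coincide ($\mu_1=\mu_2$, $\mathrm{lv}_1=\mathrm{lv}_2$) and positive otherwise.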

Usage

normal_kl(mu1, lv1, mu2, lv2)

Arguments

mu1

Mean of first distribution (posterior)

lv1

Log variance of first distribution

mu2

Mean of second distribution (prior)

lv2

Log variance of second distribution

Value

Tensor of element-wise KL divergence values (no reduction over elements is applied)
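The element-wise behaviour can be checked against a plain base-R version of the same closed form. The sketch below uses a hypothetical helper, `normal_kl_ref()`, which is not part of the package. It works on ordinary numeric vectors, whereas the packaged `normal_kl()` is assumed to operate on torch tensors.

```r
# Hypothetical base-R reference version of the closed-form Gaussian KL.
# Works on numeric vectors; normal_kl() itself is assumed to take tensors.
normal_kl_ref <- function(mu1, lv1, mu2, lv2) {
  v1 <- exp(lv1)      # posterior variance
  v2 <- exp(lv2)      # prior variance
  lstd1 <- lv1 / 2    # log standard deviation of the posterior
  lstd2 <- lv2 / 2    # log standard deviation of the prior
  # One KL value per element; no summing or averaging
  lstd2 - lstd1 + (v1 + (mu1 - mu2)^2) / (2 * v2) - 0.5
}

# Standard-normal prior; the second posterior is shifted by one unit
kl <- normal_kl_ref(mu1 = c(0, 1), lv1 = c(0, 0), mu2 = c(0, 0), lv2 = c(0, 0))
kl
#> [1] 0.0 0.5
```

In a model, such per-element values are typically summed over latent dimensions to obtain a per-sample KL term. That reduction is left to the caller.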

References

Chen, R. T. Q. torchdiffeq. https://github.com/rtqichen/torchdiffeq