
Introduction

CellODE implements a deep generative model that combines Variational Autoencoders (VAE) with Neural Ordinary Differential Equations (Neural ODE) to infer continuous cellular dynamics from single-cell RNA sequencing data.

This vignette provides a detailed explanation of the mathematical foundations and algorithmic principles underlying CellODE.

Mathematical Framework

Problem Formulation

Given a single-cell gene expression matrix $\mathbf{X} \in \mathbb{R}^{N \times G}$, where $N$ is the number of cells and $G$ is the number of genes, we aim to:

  1. Infer a pseudotime $t_i \in [0, 1]$ for each cell $i$
  2. Learn a latent representation $\mathbf{z}_i \in \mathbb{R}^d$ capturing cellular state
  3. Model the continuous dynamics $\frac{d\mathbf{z}}{dt}$ in latent space

Model Architecture

The TNODE (Time Neural ODE) model consists of three main components:

┌─────────────────────────────────────────────────────────────┐
│                        CellODE Model                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   ┌─────────────┐                                           │
│   │   Input     │  X ∈ ℝ^(N×G)                              │
│   │ Expression  │                                           │
│   └──────┬──────┘                                           │
│          │                                                  │
│          ▼                                                  │
│   ┌─────────────────────────────────────────┐               │
│   │             Encoder Network             │               │
│   │      q(t, z | x) = q(t|x) · q(z|x)      │               │
│   └────────────────────┬────────────────────┘               │
│                        │                                    │
│          ┌─────────────┴─────────────┐                      │
│          ▼                           ▼                      │
│   ┌─────────────┐             ┌─────────────┐               │
│   │   Time t    │             │  Latent z   │               │
│   │  (sigmoid)  │             │   (μ, σ²)   │               │
│   └──────┬──────┘             └──────┬──────┘               │
│          │                           │                      │
│          └─────────────┬─────────────┘                      │
│                        ▼                                    │
│   ┌─────────────────────────────────────────┐               │
│   │            Neural ODE Solver            │               │
│   │            dz/dt = f_θ(z, t)            │               │
│   │                                         │               │
│   │ z(t₂) = z(t₁) + ∫_{t₁}^{t₂} f_θ(z,t)dt  │               │
│   └────────────────────┬────────────────────┘               │
│                        │                                    │
│                        ▼                                    │
│   ┌─────────────────────────────────────────┐               │
│   │             Decoder Network             │               │
│   │         p(x | z) ~ NB(μ(z), θ)          │               │
│   └─────────────────────────────────────────┘               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Encoder Network

The encoder maps gene expression to time and latent space:

$$q_\phi(t, \mathbf{z} | \mathbf{x}) = q_\phi(t | \mathbf{x}) \cdot q_\phi(\mathbf{z} | \mathbf{x})$$

Time Inference

The time is modeled as a deterministic function with a sigmoid activation that constrains values to $[0, 1]$:

$$t = \sigma(f_t(\mathbf{x}))$$

where $\sigma(\cdot)$ is the sigmoid function.
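As an illustration only, here is a base-R sketch of the time head. It assumes, hypothetically, that $f_t$ is a single linear layer (weights `w`, bias `b`) applied to log-transformed counts; in CellODE $f_t$ is a learned network, so only the final sigmoid step mirrors the text.

```r
# Hypothetical time head: t_i = sigmoid(f_t(x_i)), with f_t assumed to be linear
infer_time <- function(X, w, b) {
  # X: N x G expression matrix, w: length-G weight vector, b: scalar bias
  logits <- as.vector(X %*% w + b)
  plogis(logits)  # plogis() is the logistic sigmoid 1 / (1 + exp(-x))
}

set.seed(1)
X <- matrix(rpois(5 * 10, lambda = 3), nrow = 5)  # toy counts: 5 cells, 10 genes
w <- rnorm(10, sd = 0.1)                          # hypothetical weights
t_hat <- infer_time(log1p(X), w, b = 0)
print(round(t_hat, 3))
```

Because the sigmoid saturates, large positive or negative logits map close to 1 or 0 but never outside the interval.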

Latent Space Inference

The approximate posterior over the latent state is a diagonal Gaussian:

$$q_\phi(\mathbf{z} | \mathbf{x}) = \mathcal{N}\left(\mathbf{z}; \boldsymbol{\mu}_\phi(\mathbf{x}), \text{diag}(\boldsymbol{\sigma}^2_\phi(\mathbf{x}))\right)$$

We use the reparameterization trick so that sampling $\mathbf{z}$ remains differentiable with respect to $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}$:

$$\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$
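A minimal base-R sketch of the reparameterization step. The function name `sample_latent` and the use of log-variances as input are assumptions (a common VAE convention), not CellODE's API. All randomness sits in `eps`, so in an autodiff framework gradients would flow through `mu` and `sigma`; base R has no autodiff, so this only checks that the samples have the intended mean and spread.

```r
# Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
# Assumes (hypothetically) that the encoder outputs log-variances.
sample_latent <- function(mu, logvar) {
  sigma <- exp(0.5 * logvar)
  eps <- matrix(rnorm(length(mu)), nrow = nrow(mu))  # the only random input
  mu + sigma * eps
}

set.seed(42)
n <- 10000; d <- 2
mu <- matrix(rep(c(1, -2), each = n), nrow = n)                 # column means 1 and -2
logvar <- matrix(log(c(0.25, 4)), nrow = n, ncol = d, byrow = TRUE)  # sigma = 0.5 and 2
z <- sample_latent(mu, logvar)
print(round(colMeans(z), 2))        # close to (1, -2)
print(round(apply(z, 2, sd), 2))    # close to (0.5, 2)
```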

Neural ODE

Continuous Dynamics

The core innovation is modeling latent dynamics as a continuous-time process:

$$\frac{d\mathbf{z}(t)}{dt} = f_\theta(\mathbf{z}(t), t)$$

where $f_\theta$ is a neural network parameterized by $\theta$.

Integration

Given an initial state $\mathbf{z}_0$ at time $t_0$, the state at any time $t$ is:

$$\mathbf{z}(t) = \mathbf{z}_0 + \int_{t_0}^{t} f_\theta(\mathbf{z}(\tau), \tau)\, d\tau$$

CellODE uses the Euler method for numerical integration:

$$\mathbf{z}_{n+1} = \mathbf{z}_n + \Delta t \cdot f_\theta(\mathbf{z}_n, t_n)$$

where $\Delta t = t_{n+1} - t_n$ is the integration step size.
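A base-R sketch of fixed-step Euler integration; `euler_integrate` is a hypothetical helper, not part of CellODE. To make the result checkable, it is run on $d\mathbf{z}/dt = -\mathbf{z}$, whose exact solution $\mathbf{z}_0 e^{-t}$ is known, rather than on a learned $f_\theta$.

```r
# Fixed-step (explicit) Euler integration of dz/dt = f(z, t) over a time grid
euler_integrate <- function(f, z0, times) {
  traj <- matrix(NA_real_, nrow = length(times), ncol = length(z0))
  z <- z0
  traj[1, ] <- z
  for (n in seq_len(length(times) - 1)) {
    dt <- times[n + 1] - times[n]
    z <- z + dt * f(z, times[n])   # z_{n+1} = z_n + dt * f(z_n, t_n)
    traj[n + 1, ] <- z
  }
  traj
}

# Sanity check on dz/dt = -z, whose exact solution is z0 * exp(-t)
times <- seq(0, 1, length.out = 1001)
traj <- euler_integrate(function(z, t) -z, z0 = c(1, 2), times = times)
print(traj[nrow(traj), ])   # approximately exp(-1) * c(1, 2)
```

Euler's global error shrinks roughly in proportion to $\Delta t$, so the number of time points trades accuracy against compute.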

ODE Function Architecture

The latent ODE function $f_\theta$ is implemented as a simple MLP that takes $\mathbf{z}$ as input ($h$ denotes the hidden-layer width):

Input: z ∈ ℝ^d
       │
       ▼
┌─────────────┐
│ Linear(d→h) │
└──────┬──────┘
       │
       ▼
┌─────────────┐
│     ELU     │
└──────┬──────┘
       │
       ▼
┌─────────────┐
│ Linear(h→d) │
└──────┬──────┘
       │
       ▼
Output: dz/dt ∈ ℝ^d
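The same architecture written as a plain base-R forward pass, with randomly drawn weights standing in for trained ones. The names `elu`, `ode_mlp`, and `params` are hypothetical; `d = 2` and `h = 32` match the demonstration later in this vignette, and the bias terms reflect that `torch::nn_linear` includes a bias by default.

```r
# ELU activation with alpha = 1: x for x > 0, exp(x) - 1 otherwise
elu <- function(x) ifelse(x > 0, x, exp(x) - 1)

# Forward pass of the latent ODE MLP: Linear(d -> h) -> ELU -> Linear(h -> d)
ode_mlp <- function(z, params) {
  hidden <- elu(params$W1 %*% z + params$b1)   # h x 1 hidden activations
  as.vector(params$W2 %*% hidden + params$b2)  # length-d velocity dz/dt
}

set.seed(7)
d <- 2; h <- 32
params <- list(  # hypothetical random weights, not trained values
  W1 = matrix(rnorm(h * d, sd = 0.5), nrow = h), b1 = rnorm(h, sd = 0.1),
  W2 = matrix(rnorm(d * h, sd = 0.5), nrow = d), b2 = rnorm(d, sd = 0.1)
)
dzdt <- ode_mlp(c(0.5, 0), params)
print(dzdt)  # one velocity component per latent dimension
```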

Decoder Network

Reconstruction Likelihood

CellODE supports three likelihood models:

1. Negative Binomial (NB)

For UMI count data, the negative binomial distribution is typically the most appropriate choice because it accounts for overdispersed counts:

$$p(x_{ig} | \mathbf{z}_i) = \text{NB}(x_{ig}; \mu_{ig}, \theta_g)$$

where:

- $\mu_{ig} = s_i \cdot \text{softmax}(f_\text{dec}(\mathbf{z}_i))_g$ is the expected count
- $s_i$ is the library size (total UMI count)
- $\theta_g$ is the gene-specific dispersion parameter; in this parameterization the variance is $\mu_{ig} + \mu_{ig}^2/\theta_g$, so a larger $\theta_g$ means less overdispersion

The log-likelihood is:

$$\log p(x | \mu, \theta) = \log\Gamma(x + \theta) - \log\Gamma(\theta) - \log\Gamma(x+1) + \theta\log\frac{\theta}{\theta+\mu} + x\log\frac{\mu}{\theta+\mu}$$
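A base-R sketch that computes $\mu_{ig}$ from a softmax over hypothetical decoder outputs and then evaluates the log-likelihood term by term. All numeric values are made up for illustration. Base R's `dnbinom(x, size = theta, mu = mu, log = TRUE)` uses the same mean/inverse-dispersion parameterization, so it serves as an independent cross-check.

```r
# Numerically stable softmax over a vector of decoder outputs
softmax <- function(v) {
  e <- exp(v - max(v))
  e / sum(e)
}

# NB log-likelihood, term by term as in the equation above
nb_loglik <- function(x, mu, theta) {
  lgamma(x + theta) - lgamma(theta) - lgamma(x + 1) +
    theta * log(theta / (theta + mu)) +
    x * log(mu / (theta + mu))
}

# Toy values for one cell and four genes (all hypothetical)
decoder_out <- c(0.2, -1.0, 1.5, 0.0)  # stands in for f_dec(z_i)
s <- 1000                              # library size s_i
mu <- s * softmax(decoder_out)         # expected counts mu_ig
theta <- c(2, 5, 1, 10)                # per-gene dispersions theta_g
x <- c(150, 20, 700, 90)               # observed counts x_ig

ll <- nb_loglik(x, mu, theta)
# Cross-check against base R's NB with the same (size = theta, mu) parameterization
print(all.equal(ll, dnbinom(x, size = theta, mu = mu, log = TRUE)))
```

Used as a reconstruction loss, these values would be negated, for example `-sum(ll)` for this cell.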

2. Zero-Inflated Negative Binomial (ZINB)

For data with excess zeros, a point mass at zero with weight $\pi$ is mixed with the NB distribution:

$$p(x | \mu, \theta, \pi) = \pi \cdot \mathbf{1}_{x=0} + (1-\pi) \cdot \text{NB}(x; \mu, \theta)$$
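A base-R sketch of the ZINB log-likelihood built on `dnbinom()`. The helper name `zinb_loglik` and the values of `mu`, `theta`, and `pi0` (the zero-inflation weight $\pi$) are illustrative. Zeros can arise from either mixture component, while nonzero counts can only come from the NB part.

```r
# Zero-inflated NB log-likelihood: point mass at zero (weight pi0)
# mixed with an NB(mu, theta) component (weight 1 - pi0)
zinb_loglik <- function(x, mu, theta, pi0) {
  nb_log <- dnbinom(x, size = theta, mu = mu, log = TRUE)
  ifelse(x == 0,
         log(pi0 + (1 - pi0) * exp(nb_log)),  # zeros: either component
         log(1 - pi0) + nb_log)               # nonzero counts: NB part only
}

ll <- zinb_loglik(0:5, mu = 5, theta = 2, pi0 = 0.3)
print(exp(ll))  # probabilities of observing 0, 1, ..., 5 counts
```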

3. Mean Squared Error (MSE)

For log-normalized data:

$$\mathcal{L}_\text{MSE} = \frac{1}{N}\sum_{i=1}^N \|\mathbf{x}_i - \hat{\mathbf{x}}_i\|^2$$

where $\hat{\mathbf{x}}_i$ is the decoder's reconstruction of cell $i$.

Loss Function

The total loss combines multiple components:

$$\mathcal{L} = \alpha_1 \mathcal{L}_\text{recon}^\text{enc} + \alpha_2 \mathcal{L}_\text{recon}^\text{ode} + \mathcal{L}_\text{z-div} + \alpha_\text{kl} \mathcal{L}_\text{KL}$$

Components

  1. Encoder Reconstruction Loss ($\mathcal{L}_\text{recon}^\text{enc}$): reconstruction from the encoder-derived latent state

  2. ODE Reconstruction Loss ($\mathcal{L}_\text{recon}^\text{ode}$): reconstruction from the ODE-integrated latent state

  3. Latent Divergence ($\mathcal{L}_\text{z-div}$): $\mathcal{L}_\text{z-div} = \|\mathbf{z}_\text{enc} - \mathbf{z}_\text{ode}\|^2$

  4. KL Divergence ($\mathcal{L}_\text{KL}$): $\mathcal{L}_\text{KL} = \text{KL}\left[q(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\right]$

    For a diagonal Gaussian posterior and a standard normal prior $p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$: $\text{KL} = \frac{1}{2}\sum_{j=1}^d \left(\sigma_j^2 + \mu_j^2 - 1 - \log\sigma_j^2\right)$
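A base-R sketch of how the four terms combine, given per-cell reconstruction negative log-likelihoods and latent states computed elsewhere. The helper names, the averaging over cells, and the default weights (`alpha1 = 0.5`, `alpha2 = 0.5`, `alpha_kl = 1`) are hypothetical placeholders, not values taken from CellODE.

```r
# Closed-form KL[N(mu, diag(sigma^2)) || N(0, I)], one value per cell (row)
kl_gaussian <- function(mu, sigma) {
  0.5 * rowSums(sigma^2 + mu^2 - 1 - log(sigma^2))
}

# Squared Euclidean distance between encoder and ODE latent states, per cell
z_divergence <- function(z_enc, z_ode) rowSums((z_enc - z_ode)^2)

# Total loss; the alpha defaults are placeholders
total_loss <- function(nll_enc, nll_ode, z_enc, z_ode, mu, sigma,
                       alpha1 = 0.5, alpha2 = 0.5, alpha_kl = 1) {
  alpha1 * mean(nll_enc) + alpha2 * mean(nll_ode) +
    mean(z_divergence(z_enc, z_ode)) + alpha_kl * mean(kl_gaussian(mu, sigma))
}

# The KL term vanishes exactly when the posterior equals the prior
mu0 <- matrix(0, nrow = 3, ncol = 2)
sigma1 <- matrix(1, nrow = 3, ncol = 2)
print(kl_gaussian(mu0, sigma1))
```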

Time Direction Determination

The model may learn pseudotime in either direction. CellODE automatically determines the correct direction from the association between the inferred time and the number of detected genes per cell:

$$\beta = \text{cov}(t, \log(n_\text{genes}))$$

(the covariance has the same sign as the correlation, so either gives the same decision). If $\beta > 0$, the inferred time increases with the number of detected genes; because more mature cells typically have fewer detected genes, the time is reversed.
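A base-R sketch of the direction check on synthetic data. `orient_time` is a hypothetical helper, and implementing "reversal" as $t \mapsto 1 - t$ is an assumption (it keeps values in $[0, 1]$). The toy data are generated so that the number of detected genes falls with the true maturation time, while the "inferred" time runs backwards.

```r
# Hypothetical helper: flip pseudotime if it increases with log(n_genes).
# Assumes reversal means t -> 1 - t, which keeps t in [0, 1].
orient_time <- function(t, n_genes) {
  beta <- cov(t, log(n_genes))
  if (beta > 0) 1 - t else t
}

# Toy data: detected genes decline as cells mature (all values synthetic)
set.seed(3)
true_t <- runif(200)
n_genes <- round(3000 * exp(-1.5 * true_t) + rnorm(200, sd = 50))
t_raw <- 1 - true_t                   # inferred time pointing the wrong way
t_fixed <- orient_time(t_raw, n_genes)
print(cor(t_fixed, true_t))           # 1, up to floating-point error
```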

Demonstration

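Let's visualize how a Neural ODE moves points through latent space. The example below integrates a small, randomly initialized (untrained) ODE function with the Euler method, starting from 20 points on a circle: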

library(torch)

# Create a simple ODE function
ode_func <- torch::nn_module(
  initialize = function() {
    self$fc1 <- torch::nn_linear(2, 32)
    self$fc2 <- torch::nn_linear(32, 2)
  },
  forward = function(t, z) {
    out <- torch::nnf_elu(self$fc1(z))
    self$fc2(out)
  }
)

# Initialize: set the seed first so the random weights are reproducible
torch::torch_manual_seed(42)
func <- ode_func()

# Create initial points on a circle
n_points <- 20
theta <- seq(0, 2*pi, length.out = n_points + 1)[1:n_points]
z0 <- torch::torch_stack(list(
  torch::torch_tensor(0.5 * cos(theta)),
  torch::torch_tensor(0.5 * sin(theta))
), dim = 2)

# Time points
t <- torch::torch_linspace(0, 1, 20)

# Simple Euler integration
func$eval()
trajectories <- list()

torch::with_no_grad({
  for (i in 1:n_points) {
    z <- z0[i, ]
    traj <- matrix(0, nrow = 20, ncol = 2)
    traj[1, ] <- as.numeric(z)
    
    for (j in 2:20) {
      dt <- (t[j] - t[j-1])$item()
      dz <- func(t[j-1], z)
      z <- z + dt * dz
      traj[j, ] <- as.numeric(z)
    }
    trajectories[[i]] <- traj
  }
})

# Plot trajectories
plot(NULL, xlim = c(-1.5, 1.5), ylim = c(-1.5, 1.5),
     xlab = "z1", ylab = "z2", main = "Neural ODE Trajectories")

colors <- rainbow(n_points)
for (i in 1:n_points) {
  lines(trajectories[[i]], col = colors[i], lwd = 1.5)
  points(trajectories[[i]][1, 1], trajectories[[i]][1, 2], 
         pch = 19, col = colors[i], cex = 1)
  points(trajectories[[i]][20, 1], trajectories[[i]][20, 2], 
         pch = 17, col = colors[i], cex = 1)
}
legend("topright", legend = c("Start", "End"), pch = c(19, 17), bty = "n")

Key Innovations

1. Automatic Time Inference

Unlike methods requiring specification of root cells, CellODE infers pseudotime directly from the data through the encoder network.

2. Continuous Dynamics

The Neural ODE framework provides:

- Continuous (not discrete) state transitions
- Physically interpretable dynamics (velocity field)
- Memory-efficient training via the adjoint method

3. Unified Framework

CellODE jointly learns:

- Temporal ordering (pseudotime)
- Low-dimensional representation (latent space)
- Dynamical model (vector field)

References

  1. Li, S. et al. (2023). scTour: a deep learning architecture for robust inference and accurate prediction of cellular dynamics. Genome Biology, 24, 149.

  2. Chen, R.T.Q. et al. (2018). Neural Ordinary Differential Equations. NeurIPS.

  3. Kingma, D.P. & Welling, M. (2014). Auto-Encoding Variational Bayes. ICLR.

  4. Lopez, R. et al. (2018). Deep generative modeling for single-cell transcriptomics. Nature Methods, 15, 1053-1058.

Session Info