Skip to contents

Introduction

This vignette covers advanced features and best practices for using CellODE effectively in your single-cell analysis workflows. We’ll discuss:

  • Model hyperparameter tuning
  • Handling complex trajectories
  • Prediction on query datasets
  • Integration with Seurat workflows
  • Performance optimization

Hyperparameter Tuning

Latent Space Dimensionality

The n_latent parameter controls the dimensionality of the learned latent space:

# Simple, linear trajectory: a small latent space is enough
trainer <- Trainer$new(
  seurat_obj,
  n_latent = 3
)

# Branching trajectory: give the model more latent dimensions
trainer <- Trainer$new(
  seurat_obj,
  n_latent = 10
)

# Very complex data with several lineages
trainer <- Trainer$new(
  seurat_obj,
  n_latent = 20
)

Guidelines:

| Data complexity   | Recommended n_latent |
|-------------------|----------------------|
| Simple linear     | 3-5                  |
| Single branch     | 5-10                 |
| Multiple branches | 10-20                |
| Very complex      | 20-50                |

ODE Network Architecture

# Default: a lightweight ODE network (25 hidden units in the ODE function)
trainer <- Trainer$new(seurat_obj, n_ode_hidden = 25)

# A wider ODE network for more complex dynamics
trainer <- Trainer$new(seurat_obj, n_ode_hidden = 100)

VAE Network Architecture

# Default VAE hidden-layer width
trainer <- Trainer$new(seurat_obj, n_vae_hidden = 128)

# Wider VAE for large gene sets or complex data; batch normalization can
# help when the network gets larger
trainer <- Trainer$new(seurat_obj, n_vae_hidden = 256, batch_norm = TRUE)

Loss Function Selection

# Negative binomial (the default): best suited to UMI counts
trainer <- Trainer$new(
  seurat_obj,
  loss_mode = "nb"
)

# Zero-inflated negative binomial: for data with many zeros (e.g., SMART-seq)
trainer <- Trainer$new(
  seurat_obj,
  loss_mode = "zinb"
)

# Mean squared error: for already-normalized data, taken from the "data" slot
trainer <- Trainer$new(
  seurat_obj,
  loss_mode = "mse",
  slot = "data"
)

Loss Weight Balancing

The loss function combines multiple terms with adjustable weights:

trainer <- Trainer$new(
  seurat_obj,
  alpha_recon_lec = 0.5,   # Encoder reconstruction weight
  alpha_recon_lode = 0.5,  # ODE reconstruction weight
  alpha_kl = 1.0           # KL divergence weight
)

# The KL term pulls the latent space toward the prior: a larger alpha_kl
# gives a smoother, more regularized latent space at the cost of
# reconstruction quality; a smaller alpha_kl does the opposite.

# If the latent space is noisy or fragmented, increase the KL weight
trainer <- Trainer$new(seurat_obj, alpha_kl = 2.0)

# If reconstruction is poor (or the latent space is over-smoothed),
# decrease the KL weight
trainer <- Trainer$new(seurat_obj, alpha_kl = 0.5)

Training Configuration

Learning Rate Scheduling

# Standard training configuration
trainer <- Trainer$new(
  seurat_obj,
  lr = 1e-3,        # learning rate
  wt_decay = 1e-6,  # L2 regularization (weight decay)
  eps = 0.01        # epsilon of the Adam optimizer
)

# Unstable training: lower the learning rate
trainer <- Trainer$new(
  seurat_obj,
  lr = 1e-4
)

# Simple data: a higher learning rate converges faster
trainer <- Trainer$new(
  seurat_obj,
  lr = 5e-3
)

Batch Size and Epochs

# Recommended: let batch size and epoch count be determined automatically
trainer <- Trainer$new(seurat_obj)

# Or set them explicitly:
#   batch_size - smaller batches add gradient noise, larger ones need more memory
#   nepoch     - complex data usually benefits from more epochs
trainer <- Trainer$new(seurat_obj, batch_size = 512, nepoch = 200)

Data Subsampling

For large datasets, train on a subset:

# By default, the fraction of cells used for training depends on dataset size:
#   more than 10,000 cells -> 20%
#   10,000 cells or fewer  -> 90%

# Override it manually, e.g. to train on 30% of the cells
trainer <- Trainer$new(
  seurat_obj,
  percent = 0.3
)

Handling Complex Trajectories

Branching Trajectories

For data with multiple differentiation branches:

# More latent dimensions to capture the branches, a larger ODE network,
# and more training epochs
trainer <- Trainer$new(
  seurat_obj,
  n_latent = 15,
  n_ode_hidden = 50,
  nepoch = 300
)

# Once trained, step-wise integration gives more accurate latent states
latent <- trainer$get_latentsp(step_wise = TRUE)

Cyclic Trajectories

For cell cycle data:

# Pseudotime may wrap around on cyclic data; a larger latent space can help
trainer <- Trainer$new(
  seurat_obj,
  n_latent = 10
)

pseudotime <- trainer$get_time()

# The inferred direction may need to be flipped by hand. If the time
# direction looks wrong, reverse it:
pseudotime <- reverse_time(pseudotime)

Prediction on Query Data

Coarse vs Fine Mode

# "coarse": fast, and does not depend on the training data
query_latent <- predict_latentsp(
  trainer,
  query_seurat,
  mode = "coarse"
)

# "fine": more accurate, uses the training data as a reference
query_latent <- predict_latentsp(
  trainer,
  query_seurat,
  mode = "fine"
)

Predicting Future States

Predict latent representations at pseudotime points that were not observed, by interpolation:

# Pseudotime points at which to predict latent states
query_times <- seq(from = 0.5, to = 1.0, by = 0.1)

# k: number of neighbors used for interpolation
# step_wise: step-wise integration, which is more accurate
future_latent <- predict_ltsp_from_time(
  trainer, t = query_times, k = 20, step_wise = TRUE
)

Integration with Seurat Workflows

Adding Results to Seurat Object

# Pseudotime -> cell metadata column
seurat_obj$cellode_time <- trainer$get_time()

# Latent space -> dimensional reduction named "cellode"
latent <- trainer$get_latentsp()
seurat_obj[["cellode"]] <- Seurat::CreateDimReducObject(
  embeddings = latent$mix_zs, key = "CELLODE_", assay = "RNA"
)

# Vector field and latent coordinates -> the object's @misc slot
vf <- trainer$get_vector_field(
  seurat_obj$cellode_time,
  latent$mix_zs
)
seurat_obj@misc$X_VF <- vf
seurat_obj@misc$X_zs <- latent$mix_zs

Downstream Analysis

# Find genes correlated with pseudotime
library(Seurat)

# Pearson correlation between each gene's normalized expression and
# pseudotime (one value per gene). GetAssayData(layer = "data") is the
# Seurat v5 accessor; the `@data` slot does not exist on v5 (Assay5)
# objects. With Seurat v4, use `slot = "data"` instead of `layer`.
# Note: as.matrix() densifies the full cells x genes matrix; on large data,
# subset to variable features first.
expr <- GetAssayData(seurat_obj, assay = "RNA", layer = "data")
gene_time_cor <- cor(
  as.matrix(t(expr)),
  seurat_obj$cellode_time
)

# Identify trajectory-associated genes by absolute correlation. sort() drops
# NA values (e.g. zero-variance genes), and head() avoids NA entries when
# fewer than 100 genes remain.
abs_cor <- abs(gene_time_cor[, 1])
top_genes <- names(head(sort(abs_cor, decreasing = TRUE), 100))

Performance Optimization

GPU Acceleration

# Default (use_gpu = TRUE): a GPU is auto-detected and used when available
trainer <- Trainer$new(seurat_obj, use_gpu = TRUE)

# Force CPU (for debugging or memory issues)
trainer <- Trainer$new(seurat_obj, use_gpu = FALSE)

Memory Management

# For large datasets, use a smaller batch size and train on a subset
trainer <- Trainer$new(
  seurat_obj,
  batch_size = 256,  # Reduce memory usage
  percent = 0.2      # Train on 20% of the cells
)

# Release GPU memory after training. Run R's garbage collector first so that
# unreferenced torch tensors are freed; emptying the CUDA cache only applies
# when CUDA is actually available.
gc()
if (torch::cuda_is_available()) {
  torch::cuda_empty_cache()
}

Batched Inference

# Compute the latent space in batches (batch_size = 1000) to limit memory
# use on large datasets
latent <- trainer$get_latentsp(
  batch_size = 1000
)

Model Persistence

Saving and Loading

# Write the trained model to disk
trainer$save_model("path/to/model")

# Reload it together with a Seurat object, e.g. for prediction
loaded_trainer <- load_model("path/to/model", seurat_obj)

# Optionally continue training the reloaded model
loaded_trainer$train()

Model Inspection

# Print the network architecture
print(trainer$model)

# Plot the training history to check convergence
plot_training_history(trainer)

# Inspect the stored model parameters
trainer$model_kwargs

Troubleshooting

Training Issues

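| Problem             | Solution                                                             |
|---------------------|----------------------------------------------------------------------|
| Loss not decreasing | Reduce learning rate                                                 |
| Loss oscillating    | Increase batch size                                                  |
| NaN/Inf loss        | Check input data normalization (raw counts for nb/zinb, normalized data for mse) |
| Slow training       | Use GPU, reduce n_latent                                             |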

Quality Issues

| Problem                  | Solution                               |
|--------------------------|----------------------------------------|
| Poor pseudotime ordering | Check time direction, increase epochs  |
| Noisy vector field       | Increase n_neigh in cosine_similarity  |
| Discontinuous trajectory | Increase n_latent                      |

Memory Issues

| Problem           | Solution                            |
|-------------------|-------------------------------------|
| GPU out of memory | Reduce batch_size                   |
| CPU out of memory | Reduce percent, use sparse matrices |

Reproducibility

# Fix the random seed so repeated runs can be reproduced
trainer <- Trainer$new(seurat_obj, random_state = 42)

# Full reproducibility additionally requires:
#   1. the same random_state,
#   2. the same ordering of the input data, and
#   3. the same hardware (GPU results may vary).

Best Practices Summary

  1. Start simple: Begin with default parameters, adjust as needed

  2. Monitor training: plot the training loss history (see Model Inspection) to check convergence

  3. Validate results: Compare with known markers or annotations

  4. Document settings: Save model parameters for reproducibility

  5. Use appropriate loss: NB for UMI counts, ZINB for data with many zeros (e.g., SMART-seq), MSE for normalized data

  6. Consider data size: train on a subset (percent) for large datasets; use more epochs for small ones

Session Info

# Report the R version, platform and loaded packages for this session
utils::sessionInfo()
#> R version 4.4.0 (2024-04-24)
#> Platform: aarch64-apple-darwin20
#> Running under: macOS 15.6.1
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRblas.0.dylib 
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.0
#> 
#> locale:
#> [1] C
#> 
#> time zone: Asia/Shanghai
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> loaded via a namespace (and not attached):
#>  [1] digest_0.6.39     desc_1.4.3        R6_2.6.1          fastmap_1.2.0    
#>  [5] xfun_0.56         cachem_1.1.0      knitr_1.51        htmltools_0.5.9  
#>  [9] rmarkdown_2.30    lifecycle_1.0.5   cli_3.6.5         sass_0.4.10      
#> [13] pkgdown_2.1.3     textshaping_1.0.4 jquerylib_0.1.4   systemfonts_1.3.1
#> [17] compiler_4.4.0    tools_4.4.0       ragg_1.5.0        bslib_0.9.0      
#> [21] evaluate_1.0.5    yaml_2.3.12       otel_0.2.0        jsonlite_2.0.0   
#> [25] rlang_1.1.7       fs_1.6.6          htmlwidgets_1.6.4