Advanced Usage and Best Practices
Zaoqu Liu
2026-01-26
Source: vignettes/advanced-usage.Rmd
Introduction
This vignette covers advanced features and best practices for using CellODE effectively in your single-cell analysis workflows. We’ll discuss:
- Model hyperparameter tuning
- Handling complex trajectories
- Prediction on query datasets
- Integration with Seurat workflows
- Performance optimization
Hyperparameter Tuning
Latent Space Dimensionality
The n_latent parameter controls the dimensionality of the learned latent space:
# For simple linear trajectories
trainer <- Trainer$new(seurat_obj, n_latent = 3)
# For complex branching trajectories
trainer <- Trainer$new(seurat_obj, n_latent = 10)
# For very complex data with multiple lineages
trainer <- Trainer$new(seurat_obj, n_latent = 20)
Guidelines:
| Data Complexity | Recommended n_latent |
|---|---|
| Simple linear | 3-5 |
| Single branch | 5-10 |
| Multiple branches | 10-20 |
| Very complex | 20-50 |
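The guideline table can be encoded as a small lookup so that the chosen value is recorded alongside the analysis. The sketch below is illustrative only: suggest_n_latent() and its category labels are hypothetical helpers, not part of CellODE, and the suggested starting value is simply the midpoint of the range given in the table.

```r
# Hypothetical helper (not a CellODE function): encode the guideline table
# as a lookup and return the recommended range plus a midpoint starting value.
suggest_n_latent <- function(complexity = c("simple_linear", "single_branch",
                                            "multiple_branches", "very_complex")) {
  complexity <- match.arg(complexity)
  ranges <- list(
    simple_linear     = c(3, 5),
    single_branch     = c(5, 10),
    multiple_branches = c(10, 20),
    very_complex      = c(20, 50)
  )
  rng <- ranges[[complexity]]
  list(range = rng, start = as.integer(round(mean(rng))))
}

s <- suggest_n_latent("multiple_branches")
s$range  # 10 20
s$start  # 15
```

The midpoint is only a starting point; the ranges themselves come from the table, and the final value should still be checked against the resulting trajectory.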
Loss Function Selection
# Negative Binomial (default) - Best for UMI counts
trainer <- Trainer$new(seurat_obj, loss_mode = "nb")
# Zero-Inflated NB - For data with many zeros (e.g., SMART-seq)
trainer <- Trainer$new(seurat_obj, loss_mode = "zinb")
# MSE - For already normalized data
trainer <- Trainer$new(seurat_obj, loss_mode = "mse", slot = "data")
Loss Weight Balancing
The loss function combines multiple terms with adjustable weights:
trainer <- Trainer$new(
  seurat_obj,
  alpha_recon_lec = 0.5,   # Encoder reconstruction weight
  alpha_recon_lode = 0.5,  # ODE reconstruction weight
  alpha_kl = 1.0           # KL divergence weight
)
# If latent space is too smooth, increase KL weight
trainer <- Trainer$new(seurat_obj, alpha_kl = 2.0)
# If reconstruction is poor, decrease KL weight
trainer <- Trainer$new(seurat_obj, alpha_kl = 0.5)
Training Configuration
Learning Rate Scheduling
# Standard training
trainer <- Trainer$new(
  seurat_obj,
  lr = 1e-3,        # Learning rate
  wt_decay = 1e-6,  # L2 regularization
  eps = 0.01        # Adam optimizer epsilon
)
# For unstable training, use a smaller learning rate
trainer <- Trainer$new(seurat_obj, lr = 1e-4)
# For faster convergence on simple data
trainer <- Trainer$new(seurat_obj, lr = 5e-3)
Data Subsampling
For large datasets, train on a subset:
# By default, percent is determined automatically from the dataset size:
#   > 10,000 cells:  20%
#   <= 10,000 cells: 90%
# Manual override
trainer <- Trainer$new(seurat_obj, percent = 0.3)  # Use 30% of cells
Handling Complex Trajectories
Branching Trajectories
For data with multiple differentiation branches:
trainer <- Trainer$new(
  seurat_obj,
  n_latent = 15,      # More dimensions to capture branches
  n_ode_hidden = 50,  # Larger ODE network
  nepoch = 300        # More training epochs
)
# After training, use step-wise integration for better accuracy
latent <- trainer$get_latentsp(step_wise = TRUE)
Cyclic Trajectories
For cell cycle data:
# Cyclic data may have pseudotime wrapping issues
# Consider using larger latent space
trainer <- Trainer$new(seurat_obj, n_latent = 10)
# Manual time reversal may be needed
pseudotime <- trainer$get_time()
# If time direction seems wrong:
pseudotime <- reverse_time(pseudotime)
Prediction on Query Data
Coarse vs Fine Mode
# Coarse mode: Fast, independent of training data
query_latent <- predict_latentsp(trainer, query_seurat, mode = "coarse")
# Fine mode: More accurate, uses training data as reference
query_latent <- predict_latentsp(trainer, query_seurat, mode = "fine")
Predicting Future States
Interpolate latent space at unobserved time points:
# Define query time points
query_times <- seq(0.5, 1.0, by = 0.1)
# Predict latent representations
future_latent <- predict_ltsp_from_time(
  trainer,
  t = query_times,
  k = 20,           # Number of neighbors for interpolation
  step_wise = TRUE  # More accurate integration
)
Integration with Seurat Workflows
Adding Results to Seurat Object
# Pseudotime
seurat_obj$cellode_time <- trainer$get_time()
# Latent space as dimensional reduction
latent <- trainer$get_latentsp()
seurat_obj[["cellode"]] <- Seurat::CreateDimReducObject(
  embeddings = latent$mix_zs,
  key = "CELLODE_",
  assay = "RNA"
)
# Vector field
vf <- trainer$get_vector_field(seurat_obj$cellode_time, latent$mix_zs)
seurat_obj@misc$X_VF <- vf
seurat_obj@misc$X_zs <- latent$mix_zs
Downstream Analysis
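Ranking genes by their association with pseudotime is a typical downstream step, and the idea can be checked on a toy example before running it on a full Seurat object. The sketch below uses only base R and simulated data. The matrix expr, its gene names, and the choice of Spearman correlation are assumptions made for illustration, not CellODE behaviour.

```r
set.seed(1)
n_cells <- 200
pseudotime <- runif(n_cells)  # simulated pseudotime in [0, 1]

# Simulated expression (cells x genes): one increasing gene, one decreasing
# gene, one pure-noise gene, and one constant gene.
expr <- cbind(
  up    = 2 * pseudotime + rnorm(n_cells, sd = 0.2),
  down  = -2 * pseudotime + rnorm(n_cells, sd = 0.2),
  noise = rnorm(n_cells),
  flat  = rep(1, n_cells)
)

# Drop zero-variance genes first; cor() returns NA for them (with a warning).
keep <- apply(expr, 2, stats::var) > 0
rho <- stats::cor(expr[, keep, drop = FALSE], pseudotime,
                  method = "spearman")[, 1]

# Rank by absolute correlation so that both increasing and decreasing
# genes surface at the top of the list.
ranked <- sort(abs(rho), decreasing = TRUE)
names(ranked)
```

Sorting on the absolute value keeps genes that are switched off along the trajectory, which a signed ranking would push to the bottom.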
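# Find genes correlated with pseudotime
library(Seurat)
# Pearson correlation between each gene's normalized expression and pseudotime.
# Note: the @data slot assumes a Seurat v4-style Assay; for a v5 Assay5 object,
# retrieve the matrix with GetAssayData() instead.
gene_time_cor <- cor(
  as.matrix(t(seurat_obj[["RNA"]]@data)),
  seurat_obj$cellode_time
)
# Identify trajectory-associated genes (top 100 by absolute correlation)
top_genes <- names(sort(abs(gene_time_cor[, 1]), decreasing = TRUE)[1:100])
Performance Optimization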
Memory Management
# For large datasets, use smaller batch size
trainer <- Trainer$new(
  seurat_obj,
  batch_size = 256,  # Reduce memory usage
  percent = 0.2      # Train on subset
)
# Clear GPU memory after training
gc()
if (torch::cuda_is_available()) torch::cuda_empty_cache()  # Only relevant when using CUDA
Model Persistence
Saving and Loading
# Save model
trainer$save_model("path/to/model")
# Load model for prediction
loaded_trainer <- load_model("path/to/model", seurat_obj)
# Continue training (if needed)
loaded_trainer$train()
Model Inspection
# View model architecture
print(trainer$model)
# Check training history
plot_training_history(trainer)
# Access model parameters
trainer$model_kwargs
Troubleshooting
Training Issues
| Problem | Solution |
|---|---|
| Loss not decreasing | Reduce learning rate |
| Loss oscillating | Increase batch size |
| NaN/Inf loss | Check input data normalization |
| Slow training | Use GPU, reduce n_latent |
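For the NaN/Inf row, a quick inspection of the input matrix often narrows down the cause before any retraining. The following is a minimal sketch in base R: check_expression_matrix() is a hypothetical helper rather than a CellODE function, and it assumes a genes x cells layout. The checks reflect the general expectation that count-based losses such as NB receive non-negative integer counts, while MSE is meant for already normalized values.

```r
# Hypothetical helper (not part of CellODE): summarise properties of an
# expression matrix that commonly lead to NaN/Inf losses.
# Assumes rows are genes and columns are cells.
check_expression_matrix <- function(x) {
  x <- as.matrix(x)
  finite <- is.finite(x)
  x0 <- x
  x0[!finite] <- 0  # treat non-finite entries as zero for the empty-cell check
  list(
    n_nonfinite   = sum(!finite),                         # NA, NaN, Inf, -Inf
    n_negative    = sum(x[finite] < 0),                   # invalid for count losses
    looks_integer = all(x[finite] == round(x[finite])),   # FALSE suggests normalized data
    n_empty_cells = sum(colSums(x0 != 0) == 0)            # cells with no detected genes
  )
}

# Small example: the third cell has no counts at all.
counts <- matrix(c(0, 3, 1,
                   2, 0, 5,
                   0, 0, 0), nrow = 3,
                 dimnames = list(paste0("gene", 1:3), paste0("cell", 1:3)))

res <- check_expression_matrix(counts)
res$n_empty_cells                                   # 1
check_expression_matrix(log1p(counts))$looks_integer  # FALSE
```

If looks_integer is FALSE while a count-based loss is selected, the loss mode and the input slot probably do not match.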
Reproducibility
# Set seeds for reproducibility
trainer <- Trainer$new(
  seurat_obj,
  random_state = 42  # Fixed seed
)
# Full reproducibility requires:
# 1. Same random_state
# 2. Same data ordering
# 3. Same hardware (GPU results may vary)
Best Practices Summary
- Start simple: Begin with default parameters and adjust as needed
- Monitor training: Use plot_loss() to check convergence
- Validate results: Compare with known markers or annotations
- Document settings: Save model parameters for reproducibility
- Use appropriate loss: NB for UMI counts, MSE for normalized data
- Consider data size: Use percent for large datasets, more epochs for small ones
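The validation step can be made concrete by comparing pseudotime against an annotation whose order is known. The sketch below is one possible check, using simulated data in base R. The stage labels, the [0, 1] pseudotime scale, the flip via 1 - pseudotime, and the -0.5 threshold are all assumptions for illustration and do not describe how CellODE itself orients or reverses time.

```r
set.seed(42)
# Simulated annotation: three stages with a known biological order.
stage <- factor(rep(c("early", "mid", "late"), each = 50),
                levels = c("early", "mid", "late"), ordered = TRUE)

# Simulated pseudotime that happens to run in the wrong direction.
pseudotime <- 1 - (as.integer(stage) - 1) / 2 + rnorm(150, sd = 0.1)
pseudotime <- (pseudotime - min(pseudotime)) / diff(range(pseudotime))  # rescale to [0, 1]

# Spearman correlation between pseudotime and the stage rank.
rho <- cor(pseudotime, as.integer(stage), method = "spearman")

# Flip only when the ordering is clearly inverted (threshold is arbitrary).
if (rho < -0.5) {
  pseudotime <- 1 - pseudotime
}
rho_after <- cor(pseudotime, as.integer(stage), method = "spearman")

# Per-stage medians should now increase from early to late.
stage_medians <- tapply(pseudotime, stage, median)
stage_medians
```

A correlation near zero, rather than strongly negative, points to a poorly resolved trajectory instead of a reversed one, and flipping will not help in that case.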
Session Info
sessionInfo()
#> R version 4.4.0 (2024-04-24)
#> Platform: aarch64-apple-darwin20
#> Running under: macOS 15.6.1
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.0
#>
#> locale:
#> [1] C
#>
#> time zone: Asia/Shanghai
#> tzcode source: internal
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> loaded via a namespace (and not attached):
#> [1] digest_0.6.39 desc_1.4.3 R6_2.6.1 fastmap_1.2.0
#> [5] xfun_0.56 cachem_1.1.0 knitr_1.51 htmltools_0.5.9
#> [9] rmarkdown_2.30 lifecycle_1.0.5 cli_3.6.5 sass_0.4.10
#> [13] pkgdown_2.1.3 textshaping_1.0.4 jquerylib_0.1.4 systemfonts_1.3.1
#> [17] compiler_4.4.0 tools_4.4.0 ragg_1.5.0 bslib_0.9.0
#> [21] evaluate_1.0.5 yaml_2.3.12 otel_0.2.0 jsonlite_2.0.0
#> [25] rlang_1.1.7 fs_1.6.6 htmlwidgets_1.6.4