R6 class for implementing the CellODE training process.
Public fields
seurat_obj: Seurat object for training
model: TNODE model
optimizer: Adam optimizer
device: Computation device
log: Training log
time_reverse: Whether to reverse time
model_kwargs: Model configuration
Methods
Method new()
Initialize a new Trainer object.
Usage
Trainer$new(
  seurat_obj,
  assay = "RNA",
  slot = NULL,
  percent = NULL,
  n_latent = 5L,
  n_ode_hidden = 25L,
  n_vae_hidden = 128L,
  batch_norm = FALSE,
  ode_method = "euler",
  step_size = NULL,
  alpha_recon_lec = 0.5,
  alpha_recon_lode = 0.5,
  alpha_kl = 1,
  loss_mode = "nb",
  nepoch = NULL,
  batch_size = 1024L,
  drop_last = FALSE,
  lr = 0.001,
  wt_decay = 1e-06,
  eps = 0.01,
  random_state = 0L,
  val_frac = 0.1,
  use_gpu = TRUE
)
Arguments
seurat_obj: Seurat object containing the expression data
assay: Assay to use (default: "RNA")
slot: Slot to use (default: NULL, which resolves to "counts" for "nb"/"zinb" and to "data" for "mse")
percent: Percentage of cells used for training (default: NULL, chosen automatically)
n_latent: Number of latent space dimensions (default: 5)
n_ode_hidden: Size of the ODE hidden layer (default: 25)
n_vae_hidden: Size of the VAE hidden layer (default: 128)
batch_norm: Whether to use batch normalization (default: FALSE)
ode_method: ODE solver method (default: "euler")
step_size: Step size multiplier (default: NULL)
alpha_recon_lec: Weight of the encoder reconstruction loss (default: 0.5)
alpha_recon_lode: Weight of the ODE reconstruction loss (default: 0.5)
alpha_kl: Weight of the KL divergence term (default: 1.0)
loss_mode: Loss mode, one of "mse", "nb", or "zinb" (default: "nb")
nepoch: Number of training epochs (default: NULL, chosen automatically)
batch_size: Batch size (default: 1024)
drop_last: Whether to drop the last incomplete batch (default: FALSE)
lr: Learning rate (default: 1e-3)
wt_decay: Weight decay (default: 1e-6)
eps: Adam epsilon (default: 0.01)
random_state: Random seed (default: 0)
val_frac: Fraction of cells held out for validation (default: 0.1)
use_gpu: Use a GPU if one is available (default: TRUE)
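Two of the argument rules above can be sketched in base R. `resolve_slot()` mirrors the documented default for `slot` (raw counts for the count-based "nb"/"zinb" likelihoods, normalized "data" for "mse"). `combine_losses()` shows one plausible way the three `alpha_*` weights enter training; it assumes a plain weighted sum of the encoder reconstruction loss, the ODE reconstruction loss, and the KL term. Both function names are hypothetical and are not part of the Trainer API.

```r
# Hypothetical helper mirroring the documented default for `slot`:
# count-based likelihoods ("nb", "zinb") read raw counts, "mse" reads "data".
resolve_slot <- function(loss_mode = c("nb", "zinb", "mse"), slot = NULL) {
  loss_mode <- match.arg(loss_mode)
  if (!is.null(slot)) return(slot)  # an explicit choice always wins
  if (loss_mode == "mse") "data" else "counts"
}

# Hypothetical helper: combines per-batch loss terms into one objective.
# Assumption: the Trainer uses a plain weighted sum of the three terms.
combine_losses <- function(recon_lec, recon_lode, kl,
                           alpha_recon_lec = 0.5,
                           alpha_recon_lode = 0.5,
                           alpha_kl = 1) {
  stopifnot(is.numeric(recon_lec), is.numeric(recon_lode), is.numeric(kl))
  alpha_recon_lec * recon_lec + alpha_recon_lode * recon_lode + alpha_kl * kl
}

resolve_slot("zinb")   # "counts"
resolve_slot("mse")    # "data"

# With the default weights both reconstruction terms count equally
# and the KL term is added at full strength: 0.5*2 + 0.5*4 + 1*0.5
combine_losses(recon_lec = 2, recon_lode = 4, kl = 0.5)  # 3.5
```

Under this assumption, raising `alpha_kl` pulls the latent space more strongly toward the prior, while shifting weight between `alpha_recon_lec` and `alpha_recon_lode` changes which latent path (encoder or ODE) dominates reconstruction.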
Method get_latentsp()
Get the latent space representation.
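The `alpha_z` and `alpha_predz` arguments of `get_latentsp()` suggest that the returned representation blends two latent estimates. The sketch below assumes a linear mix of the encoder's latent matrix `z` and the ODE-predicted latent matrix `pred_z` (cells in rows, `n_latent` columns); the helper `mix_latent()` and both input matrices are hypothetical stand-ins, not objects exposed by the Trainer.

```r
# Hypothetical helper: blends two cells-by-latent matrices.
# Assumption: z comes from the VAE encoder and pred_z from solving the
# latent ODE; get_latentsp() is assumed to mix them linearly.
mix_latent <- function(z, pred_z, alpha_z = 0.5, alpha_predz = 0.5) {
  stopifnot(identical(dim(z), dim(pred_z)))
  alpha_z * z + alpha_predz * pred_z
}

set.seed(0)
n_cells <- 4
n_latent <- 5  # matches the n_latent default of Trainer$new()
z      <- matrix(rnorm(n_cells * n_latent), n_cells, n_latent)
pred_z <- matrix(rnorm(n_cells * n_latent), n_cells, n_latent)

mixed <- mix_latent(z, pred_z)
dim(mixed)  # 4 5
```

Under this assumption, `alpha_z = 1, alpha_predz = 0` would return the encoder's latent space alone, and the reverse setting would return only the ODE-derived trajectory.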
Usage
Trainer$get_latentsp(
  alpha_z = 0.5,
  alpha_predz = 0.5,
  step_size = NULL,
  step_wise = FALSE,
  batch_size = NULL
)
Method load_model()
Load a trained model.