
R6 class that implements the CellODE training process.

Public fields

seurat_obj

Seurat object for training

model

TNODE model

optimizer

Adam optimizer

device

Computation device

log

Training log

time_reverse

Whether to reverse time

model_kwargs

Model configuration
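The `optimizer` field is documented only as an Adam optimizer, configured through the `lr`, `wt_decay` and `eps` arguments of `new()`. The base-R sketch below shows what a single Adam update does with those three settings. It assumes the update follows torch's Adam rule, where weight decay is added to the gradient as an L2 term. `adam_step()` is a hypothetical helper for illustration, not part of the package.

```r
# Hypothetical helper (not part of the package): one Adam update with L2
# weight decay, following the update rule of torch's Adam optimizer.
adam_step <- function(param, grad, state, lr = 1e-3, wt_decay = 1e-6,
                      eps = 0.01, betas = c(0.9, 0.999)) {
  grad <- grad + wt_decay * param                           # L2 weight decay
  state$step <- state$step + 1
  state$m <- betas[1] * state$m + (1 - betas[1]) * grad     # 1st moment
  state$v <- betas[2] * state$v + (1 - betas[2]) * grad^2   # 2nd moment
  m_hat <- state$m / (1 - betas[1]^state$step)              # bias correction
  v_hat <- state$v / (1 - betas[2]^state$step)
  list(param = param - lr * m_hat / (sqrt(v_hat) + eps), state = state)
}

# Minimize f(p) = sum(p^2), whose gradient is 2 * p.
p <- c(1, -2)
state <- list(step = 0, m = 0, v = 0)
for (i in seq_len(2000)) {
  res <- adam_step(p, 2 * p, state, lr = 0.01)
  p <- res$param
  state <- res$state
}
sum(p^2)   # small after 2000 steps
```

With the relatively large default `eps = 0.01`, the denominator `sqrt(v_hat) + eps` never becomes tiny. This caps the effective step size when gradients are small.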

Methods


Method new()

Create a new Trainer object.

Usage

Trainer$new(
  seurat_obj,
  assay = "RNA",
  slot = NULL,
  percent = NULL,
  n_latent = 5L,
  n_ode_hidden = 25L,
  n_vae_hidden = 128L,
  batch_norm = FALSE,
  ode_method = "euler",
  step_size = NULL,
  alpha_recon_lec = 0.5,
  alpha_recon_lode = 0.5,
  alpha_kl = 1,
  loss_mode = "nb",
  nepoch = NULL,
  batch_size = 1024L,
  drop_last = FALSE,
  lr = 0.001,
  wt_decay = 1e-06,
  eps = 0.01,
  random_state = 0L,
  val_frac = 0.1,
  use_gpu = TRUE
)

Arguments

seurat_obj

Seurat object with expression data

assay

Assay to use (default: "RNA")

slot

Slot to use (default: NULL, which selects "counts" for "nb"/"zinb" and "data" for "mse")

percent

Percentage of cells used for training (default: NULL, chosen automatically)

n_latent

Latent space dimensions (default: 5)

n_ode_hidden

ODE hidden layer size (default: 25)

n_vae_hidden

VAE hidden layer size (default: 128)

batch_norm

Use batch normalization (default: FALSE)

ode_method

ODE solver (default: "euler")

step_size

Step size multiplier (default: NULL)

alpha_recon_lec

Encoder reconstruction weight (default: 0.5)

alpha_recon_lode

ODE reconstruction weight (default: 0.5)

alpha_kl

KL divergence weight (default: 1.0)

loss_mode

Loss mode, one of "mse" (mean squared error), "nb" (negative binomial) or "zinb" (zero-inflated negative binomial) (default: "nb")

nepoch

Number of epochs (default: NULL, chosen automatically)

batch_size

Batch size (default: 1024)

drop_last

Drop last incomplete batch (default: FALSE)

lr

Learning rate (default: 1e-3)

wt_decay

Weight decay (default: 1e-6)

eps

Adam epsilon (default: 0.01)

random_state

Random seed (default: 0)

val_frac

Fraction of cells held out for validation (default: 0.1)

use_gpu

Use GPU if available (default: TRUE)
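According to the argument descriptions, `loss_mode` controls two things. It selects the reconstruction loss, and it sets the default `slot`: raw "counts" for the count-based "nb"/"zinb" losses, and "data" for "mse". The base-R sketch below illustrates both, plus a weighted objective. Two points are assumptions here. First, the objective is taken to be a linear combination of the two reconstruction losses and the KL term, weighted by `alpha_recon_lec`, `alpha_recon_lode` and `alpha_kl`, as their names suggest. Second, `resolve_slot()`, `nb_loglik()` and `total_loss()` are hypothetical helpers, not the package's internals.

```r
# Hypothetical helpers (not the package API) illustrating loss_mode-related
# defaults and a weighted training objective.

# Default slot when slot = NULL, following the `slot` argument description.
resolve_slot <- function(slot = NULL, loss_mode = c("nb", "zinb", "mse")) {
  loss_mode <- match.arg(loss_mode)
  if (!is.null(slot)) return(slot)
  if (loss_mode %in% c("nb", "zinb")) "counts" else "data"
}

# Negative binomial log-likelihood with mean `mu` and inverse dispersion
# `theta` (same parameterization as dnbinom(x, size = theta, mu = mu)).
nb_loglik <- function(x, mu, theta) {
  lgamma(x + theta) - lgamma(theta) - lgamma(x + 1) +
    theta * log(theta / (theta + mu)) + x * log(mu / (theta + mu))
}

# Assumed objective: weighted sum of two reconstruction losses and the KL term.
total_loss <- function(recon_ec, recon_ode, kl, alpha_recon_lec = 0.5,
                       alpha_recon_lode = 0.5, alpha_kl = 1) {
  alpha_recon_lec * recon_ec + alpha_recon_lode * recon_ode + alpha_kl * kl
}

x <- c(0, 3, 10)   # observed counts for one cell
theta <- 4
recon_ec <- -sum(nb_loglik(x, mu = c(1.5, 2, 8), theta = theta))
recon_ode <- -sum(nb_loglik(x, mu = c(1, 3, 9), theta = theta))
total_loss(recon_ec, recon_ode, kl = 0.2)
```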


Method train()

Train the model

Usage

Trainer$train()
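This page does not document how `train()` partitions cells. The base-R sketch below illustrates, under stated assumptions, how `val_frac`, `random_state`, `batch_size` and `drop_last` typically shape one epoch: seed the RNG, shuffle, hold out a validation fraction, then cut the rest into mini-batches. It assumes the validation cells are drawn from all input cells and ignores `percent`. `make_batches()` is a hypothetical helper.

```r
# Hypothetical sketch (not the package's code): shuffle, hold out a
# validation set, then split the training indices into mini-batches.
make_batches <- function(n_cells, val_frac = 0.1, batch_size = 1024L,
                         drop_last = FALSE, random_state = 0L) {
  set.seed(random_state)                 # reproducible shuffle
  idx <- sample.int(n_cells)
  n_val <- floor(n_cells * val_frac)
  val_idx <- idx[seq_len(n_val)]
  train_idx <- tail(idx, n_cells - n_val)
  # Consecutive chunks of batch_size; the last chunk may be smaller.
  batches <- split(train_idx, ceiling(seq_along(train_idx) / batch_size))
  batches <- unname(batches)
  # drop_last = TRUE discards a final incomplete batch.
  if (drop_last && length(batches) > 0 &&
      length(batches[[length(batches)]]) < batch_size) {
    batches <- batches[-length(batches)]
  }
  list(train = train_idx, val = val_idx, batches = batches)
}

b <- make_batches(2500, batch_size = 1024L)
lengths(b$batches)   # 1024 1024 202
```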


Method get_time()

Get pseudotime for all cells

Usage

Trainer$get_time()

Returns

Numeric vector of pseudotime values, one per cell


Method get_vector_field()

Evaluate the vector field of the latent ODE at the given pseudotimes and latent states

Usage

Trainer$get_vector_field(t, z)

Arguments

t

Pseudotime vector

z

Latent space matrix

Returns

Matrix of vector field values
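A vector field `f(t, z)` gives the time derivative of the latent state. An ODE solver such as the default `ode_method = "euler"` integrates it forward. The base-R sketch below shows explicit Euler integration, `z_next = z + h * f(t, z)`. It uses a toy linear field standing in for the learned one, on the assumption that `get_vector_field()` plays the role of `f`. `euler_integrate()` and `toy_field()` are hypothetical names for illustration.

```r
# Hypothetical sketch: explicit Euler integration of a vector field f(t, z).
# toy_field() stands in for the learned field; it is not the package's model.
euler_integrate <- function(f, z0, t_grid) {
  z <- z0
  out <- vector("list", length(t_grid))
  out[[1]] <- z
  for (i in seq_along(t_grid)[-1]) {
    h <- t_grid[i] - t_grid[i - 1]
    z <- z + h * f(t_grid[i - 1], z)   # z_{k+1} = z_k + h * f(t_k, z_k)
    out[[i]] <- z
  }
  out
}

# Toy linear field dz/dt = -z, with exact solution z0 * exp(-t).
toy_field <- function(t, z) -z
z0 <- matrix(c(1, 2, -1, 0.5), nrow = 2)   # 2 cells x 2 latent dims
traj <- euler_integrate(toy_field, z0, seq(0, 1, by = 0.001))
z_end <- traj[[length(traj)]]
max(abs(z_end - z0 * exp(-1)))             # small Euler discretization error
```

Euler's error shrinks roughly in proportion to the step `h`, which is why a step-size setting matters when using this solver.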


Method get_latentsp()

Get latent space representation

Usage

Trainer$get_latentsp(
  alpha_z = 0.5,
  alpha_predz = 0.5,
  step_size = NULL,
  step_wise = FALSE,
  batch_size = NULL
)

Arguments

alpha_z

Weight for encoder-derived latent (default: 0.5)

alpha_predz

Weight for ODE-derived latent (default: 0.5)

step_size

Step size for integration (default: NULL)

step_wise

Use step-wise integration (default: FALSE)

batch_size

Batch size (default: NULL, which processes all cells in one batch)

Returns

List with three matrices: mix_zs (weighted combination of the two latents), zs (encoder-derived latent) and pred_zs (ODE-derived latent)
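The `alpha_z` and `alpha_predz` arguments suggest that `mix_zs` weights the encoder-derived and ODE-derived latents. The sketch below assumes it is computed as `alpha_z * zs + alpha_predz * pred_zs`. It uses random stand-in matrices rather than real model output. `mix_latent()` is a hypothetical helper.

```r
# Hypothetical sketch of the assumed weighting; zs and pred_zs are random
# stand-ins for the encoder-derived and ODE-derived latents.
mix_latent <- function(zs, pred_zs, alpha_z = 0.5, alpha_predz = 0.5) {
  stopifnot(identical(dim(zs), dim(pred_zs)))   # same cells x dims
  alpha_z * zs + alpha_predz * pred_zs
}

set.seed(1)
zs <- matrix(rnorm(10 * 5), nrow = 10)        # 10 cells x 5 latent dims
pred_zs <- matrix(rnorm(10 * 5), nrow = 10)
mix_zs <- mix_latent(zs, pred_zs)             # default: equal weights
```

Under this assumption, setting `alpha_z = 1, alpha_predz = 0` returns the encoder latent unchanged, and the reverse returns the ODE latent.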


Method save_model()

Save trained model

Usage

Trainer$save_model(path)

Arguments

path

File path (without extension)


Method load_model()

Load trained model

Usage

Trainer$load_model(path)

Arguments

path

File path (without extension)


Method clone()

The objects of this class are cloneable with this method.

Usage

Trainer$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.
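With R6's default behavior, `deep = FALSE` copies field bindings only, so fields that hold R6 objects are shared between the original and the clone. `deep = TRUE` clones those R6 fields recursively. Other reference objects (plain environments, external pointers) are still shared unless the class overrides `deep_clone`, and this page does not say whether Trainer does. The toy classes below are illustrative and not part of the package.

```r
library(R6)

# Toy classes (not part of the package) showing what `deep` changes.
Inner <- R6Class("Inner", public = list(x = 1))
Outer <- R6Class("Outer", public = list(
  inner = NULL,
  initialize = function() self$inner <- Inner$new()
))

a <- Outer$new()
shallow <- a$clone()            # deep = FALSE: the Inner object is shared
shallow$inner$x <- 2
a$inner$x                       # 2, the change is visible through `a`

deep <- a$clone(deep = TRUE)    # the Inner object is cloned as well
deep$inner$x <- 3
a$inner$x                       # still 2
```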