
Train a TorchDecon model or ensemble on processed training data.

Usage

TrainModel(
  model,
  data,
  batch_size = 128L,
  learning_rate = 1e-04,
  num_steps = 1000L,
  validation_split = 0,
  early_stopping = FALSE,
  patience = 500L,
  checkpoint_dir = NULL,
  verbose = TRUE,
  seed = NULL
)

Arguments

model

A TorchDeconModel or TorchDeconEnsemble object.

data

A TorchDeconProcessed object from ProcessTrainingData, or a list containing X (features) and Y (labels) matrices.

batch_size

Integer. Batch size for training. Default is 128.

learning_rate

Numeric. Learning rate for the Adam optimizer. Default is 0.0001.

num_steps

Integer. Number of training steps. Default is 1000.

validation_split

Numeric. Fraction of data to use for validation (0-1). Default is 0 (no validation).

early_stopping

Logical. Enable early stopping based on validation loss. Default is FALSE.

patience

Integer. Number of steps without improvement in validation loss before early stopping halts training. Default is 500.

checkpoint_dir

Character. Directory to save model checkpoints. Default is NULL.

verbose

Logical. Print training progress. Default is TRUE.

seed

Integer. Random seed for reproducible training. Default is NULL.

Value

The trained model object (modified in place and returned).

Details

The training process uses:

  • Adam optimizer with configurable learning rate

  • Mean Squared Error (MSE) loss function

  • Mini-batch gradient descent

  • Optional validation and early stopping

For ensemble models, each sub-model is trained sequentially.
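
The ingredients listed above can be sketched in base R to show how they fit together. This is a minimal illustration on a toy linear model, not TorchDecon's actual implementation: the synthetic data, the linear model, and every variable name are assumptions made for the example, and the hyperparameter names merely mirror the arguments of TrainModel.

```r
# Illustrative sketch only: Adam + MSE + mini-batches + early stopping
# on a toy linear model. Not the TorchDecon training code.
set.seed(1)
n <- 500; p <- 5; k <- 3
X <- matrix(rnorm(n * p), n, p)                      # synthetic features
W_true <- matrix(rnorm(p * k), p, k)
Y <- X %*% W_true + matrix(rnorm(n * k, sd = 0.05), n, k)  # synthetic labels

# Hold out a fraction of rows for validation (analogous to validation_split)
validation_split <- 0.2
val_idx <- sample(n, size = floor(validation_split * n))
X_tr <- X[-val_idx, , drop = FALSE]; Y_tr <- Y[-val_idx, , drop = FALSE]
X_va <- X[val_idx, , drop = FALSE];  Y_va <- Y[val_idx, , drop = FALSE]

# Mean squared error over all entries of the prediction matrix
mse <- function(W, X, Y) mean((X %*% W - Y)^2)

batch_size <- 64; learning_rate <- 0.01; num_steps <- 2000; patience <- 200
beta1 <- 0.9; beta2 <- 0.999; eps <- 1e-8            # standard Adam constants
W <- matrix(0, p, k); m <- W; v <- W                 # weights and Adam moments
initial_loss <- mse(W, X_va, Y_va)
best_loss <- Inf; best_W <- W; since_best <- 0

for (step in seq_len(num_steps)) {
  # Draw one mini-batch without replacement
  idx <- sample(nrow(X_tr), batch_size)
  Xb <- X_tr[idx, , drop = FALSE]; Yb <- Y_tr[idx, , drop = FALSE]

  # Gradient of the MSE loss with respect to W
  grad <- 2 * crossprod(Xb, Xb %*% W - Yb) / length(Yb)

  # Adam update with bias-corrected moment estimates
  m <- beta1 * m + (1 - beta1) * grad
  v <- beta2 * v + (1 - beta2) * grad^2
  m_hat <- m / (1 - beta1^step)
  v_hat <- v / (1 - beta2^step)
  W <- W - learning_rate * m_hat / (sqrt(v_hat) + eps)

  # Early stopping: track the best validation loss seen so far
  val_loss <- mse(W, X_va, Y_va)
  if (val_loss < best_loss) {
    best_loss <- val_loss; best_W <- W; since_best <- 0
  } else {
    since_best <- since_best + 1
    if (since_best >= patience) break
  }
}
```

In this sketch the weights with the lowest validation loss are kept in best_W; whether TrainModel restores the best weights or keeps the final ones is not stated on this page.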

Examples

if (FALSE) { # \dontrun{
# Train a single model
model <- CreateTorchDecon(n_features = 5000, n_classes = 10)
model <- TrainModel(model, processed_data, num_steps = 5000)

# Train an ensemble
ensemble <- CreateTorchDeconEnsemble(n_features = 5000, n_classes = 10)
ensemble <- TrainModel(ensemble, processed_data, num_steps = 5000)
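
# Train with a validation split and early stopping.
# A sketch using only the arguments documented above; the values are
# illustrative rather than recommended, and "checkpoints" is a
# hypothetical directory name.
model <- TrainModel(
  model,
  processed_data,
  num_steps = 5000L,
  validation_split = 0.1,   # hold out 10% of the data for validation
  early_stopping = TRUE,    # stop based on validation loss
  patience = 200L,          # steps without improvement before stopping
  checkpoint_dir = "checkpoints",
  seed = 42L
)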
} # }