sakura.model_controllers.extractor_controller.ExtractorController.train
- ExtractorController.train(batch, backward_reconstruction_loss=True, backward_main_latent_regularization=True, backward_pheno_loss=True, selected_pheno: dict | None = None, backward_signature_loss=True, selected_signature: dict | None = None, suppress_backward=False, detach=False, detach_from='', save_raw_loss=False)
Train the model using the specified batch of data.
This function performs a training step by computing the losses associated with reconstruction, main latent regularization, phenotype, and signature components. It then performs backpropagation to update the model weights based on the computed total loss.
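The flag logic described above can be illustrated with a toy sketch. It assumes the total loss is a plain sum of the enabled components and uses plain floats instead of tensors; sakura's actual weighting and backward pass may differ. `toy_train_step`, the component names, and the `backward_fn` callback are hypothetical names chosen for illustration, not part of sakura.

```python
def toy_train_step(losses,
                   backward_reconstruction_loss=True,
                   backward_main_latent_regularization=True,
                   backward_pheno_loss=True,
                   backward_signature_loss=True,
                   suppress_backward=False,
                   backward_fn=None):
    """Toy stand-in for the flag logic; NOT sakura's implementation.

    `losses` maps a (hypothetical) component name to a float.
    `backward_fn` is a hypothetical callback standing in for autograd.
    """
    enabled = {
        "reconstruction": backward_reconstruction_loss,
        "main_latent_regularization": backward_main_latent_regularization,
        "pheno": backward_pheno_loss,
        "signature": backward_signature_loss,
    }
    # Only components whose flag is on contribute to the total.
    total = sum(v for name, v in losses.items() if enabled.get(name, False))
    # The backward pass is skipped when an external agent takes control.
    if not suppress_backward and backward_fn is not None:
        backward_fn(total)
    return {"total": total, **losses}


losses = {
    "reconstruction": 1.0,
    "main_latent_regularization": 0.5,
    "pheno": 0.25,
    "signature": 0.25,
}
# Exclude the signature loss and leave the backward pass to the caller.
report = toy_train_step(losses, backward_signature_loss=False,
                        suppress_backward=True)
```

Here `report["total"]` covers only the three enabled components, while the individual entries are still reported, mirroring the "total loss and individual component losses" shape of the return value.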
- Parameters:
batch (dict) – A dictionary containing a batch of data for training, typically obtained from an rna_count dataset
backward_reconstruction_loss (bool, optional) – Whether to backpropagate and optimize the reconstruction loss, defaults to True
backward_main_latent_regularization (bool, optional) – Whether to backpropagate and optimize the regularization of the main latent space, defaults to True
backward_pheno_loss (bool, optional) – Whether to backpropagate and optimize phenotype-related losses, defaults to True
selected_pheno (dict, optional) – Phenotype selection for backpropagation, should be None (selecting all phenotypes and their related losses and regularizations), or a dictionary of the form: {'pheno_name': {'loss': [list of loss names] or '*' (selecting all), 'regularization': [list of regularization keys, could be Null] or '*' or None (no regularization)}}
backward_signature_loss (bool, optional) – Whether to backpropagate and optimize signature-related losses, defaults to True
selected_signature (dict, optional) – Signature selection for backpropagation, should be None (selecting all signatures and their related losses and regularizations), or a dictionary structured like selected_pheno
suppress_backward (bool, optional) – Whether to suppress the backward pass on the sum of losses (useful when an external training agent overrides the control)
detach (bool, optional) – Whether losses should be detached from the computation graph at the point specified by detach_from, defaults to False
detach_from (str, optional) – Starting point in the model from which the loss should be detached
save_raw_loss (bool, optional) – Whether to record unweighted, raw losses apart from the weighted losses, defaults to False
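As an illustration of the selection format described for selected_pheno and selected_signature, the sketch below builds one such dictionary and checks its shape. The phenotype names, loss names, and regularization keys are hypothetical placeholders (real ones come from your own configuration), and `check_selection` is an illustrative helper that is not part of sakura.

```python
# Hypothetical phenotype names, loss names, and regularization keys.
selected_pheno = {
    # All losses of this phenotype, no regularization.
    "cell_type": {"loss": "*", "regularization": None},
    # One named loss and one named regularization.
    "age": {"loss": ["mse"], "regularization": ["uniform_shape"]},
}


def check_selection(selection):
    """Validate the documented shape of a selection dictionary.

    Illustrative helper only; it is NOT part of sakura.
    """
    if selection is None:  # None selects everything
        return True
    for name, spec in selection.items():
        loss = spec["loss"]
        reg = spec.get("regularization")
        if not (loss == "*" or isinstance(loss, list)):
            raise ValueError(f"{name}: 'loss' must be a list or '*'")
        if not (reg is None or reg == "*" or isinstance(reg, list)):
            raise ValueError(
                f"{name}: 'regularization' must be a list, '*' or None"
            )
    return True


check_selection(selected_pheno)
```

A dictionary like this would be passed as selected_pheno (or, with signature names, as selected_signature); passing None instead selects every phenotype or signature.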
- Returns:
A dictionary containing the losses computed during this training tick, including total loss and individual component losses
- Return type: