sakura.sakuraAE.sakuraAE.train
- sakuraAE.train(split_id, train_main=True, train_pheno=True, selected_pheno=None, train_signature=True, selected_signature=None, epoch=50, batch_size=100, tick_controller_epoch=True, make_logs=True, log_prefix='train', log_loss_groups=['loss', 'regularization'], save_raw_loss=False, test_every_epoch=False, test_on_segment=False, test_segment=2000, tests=None, checkpoint_on_segment=False, checkpoint_segment=2000, checkpoint_prefix='', checkpoint_save_arch=False, resume=False, resume_dict=None, detach=False, detach_from='')
Train the model in mini-batches for at least one epoch.
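The following is a minimal sketch of how epoch and batch_size turn into mini-batch iterations. It is not sakura's implementation. The helpers iter_minibatches and count_batches are hypothetical. Per-epoch shuffling is an assumption, and so is the idea that the training loop advances once per mini-batch.

```python
import math
import random
from typing import Iterator, List, Tuple


def iter_minibatches(n_samples: int, epoch: int, batch_size: int,
                     seed: int = 0) -> Iterator[Tuple[int, List[int]]]:
    """Hypothetical helper: yield (epoch_index, sample_indices) per mini-batch.

    Assumption: samples are reshuffled at the start of every epoch.
    """
    if epoch < 1:
        raise ValueError("epoch must be >= 1 (train runs at least one epoch)")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    rng = random.Random(seed)
    indices = list(range(n_samples))
    for e in range(epoch):
        rng.shuffle(indices)
        # The last batch of an epoch may be smaller than batch_size.
        for start in range(0, n_samples, batch_size):
            yield e, indices[start:start + batch_size]


def count_batches(n_samples: int, epoch: int, batch_size: int) -> int:
    """Total number of mini-batches across all epochs."""
    return epoch * math.ceil(n_samples / batch_size)


if __name__ == "__main__":
    batches = list(iter_minibatches(250, epoch=2, batch_size=100))
    print(len(batches), count_batches(250, 2, 100))
```

With 250 samples, epoch=2 and batch_size=100, this gives two full batches and one batch of 50 per epoch, for six mini-batches in total.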
- Parameters:
split_id (str) – ID of the data split to use for this training run
train_main (bool, optional) – Whether to forward the main latent space part during training, defaults to True
train_pheno (bool, optional) – Whether to forward phenotype side task(s) during training, defaults to True
selected_pheno (optional) – Phenotype ID(s) used for phenotype side tasks during training (see Note), defaults to None
train_signature (bool, optional) – Whether to forward gene signature side tasks during training, defaults to True
selected_signature (optional) – Signature ID(s) used for gene signature side tasks during training, analogous to selected_pheno (see Note), defaults to None
epoch (int, optional) – Number of epochs to be trained in this round of training, defaults to 50
batch_size (int, optional) – Batch size to be used in this round of training, defaults to 100
tick_controller_epoch (bool, optional) – Whether to tick the controller epoch, defaults to True
make_logs (bool, optional) – Whether to log training information, including losses, defaults to True
log_prefix (str, optional) – Prefix for training logs; it is prepended to loss item names in TensorBoard and to the filenames of latent embeddings, defaults to ‘train’
log_loss_groups (list[str], optional) – Loss group(s) to log, defaults to [‘loss’, ‘regularization’]
save_raw_loss (bool, optional) – Whether to record raw losses, defaults to False
test_every_epoch (bool, optional) – Whether to run tests/evaluation after each epoch, defaults to False
test_on_segment (bool, optional) – Whether to run segmental tests, i.e. tests every test_segment ticks, defaults to False
test_segment (int, optional) – Tick interval of test segment, defaults to 2000
tests (list[dict[str, Any]], optional) – List of test configuration dictionaries, each containing the keys ‘on_split’, ‘make_logs’, ‘dump_latent’ and ‘latent_prefix’, defaults to None
checkpoint_on_segment (bool, optional) – Whether to checkpoint the model every checkpoint_segment ticks, defaults to False
checkpoint_segment (int, optional) – Tick interval of checkpoint segment, defaults to 2000
checkpoint_prefix (str, optional) – Prefix of checkpoint files, defaults to an empty string
checkpoint_save_arch (bool, optional) – Whether to also checkpoint the model architecture, defaults to False
resume (bool, optional) – Whether to resume from saved training session, defaults to False
resume_dict (dict[str, Any], optional) – Session state dictionary used to resume a previous training session, defaults to None
detach (bool, optional) – Whether to detach losses from the computation graph at the point specified by detach_from, defaults to False
detach_from (str, optional) – Point in the model from which losses are detached, defaults to an empty string
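Below is a minimal sketch for building the tests argument and for estimating when segmental tests or checkpoints fire. The helpers make_test_config and segment_triggers are hypothetical, and so are the value types and defaults chosen for the four required keys. The schedule assumes ticks are counted from 1 and that an action fires whenever the tick is a multiple of the segment length. sakura's actual tick semantics may differ.

```python
from typing import Any, Dict, List


def make_test_config(on_split: str, make_logs: bool = True,
                     dump_latent: bool = False,
                     latent_prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: one entry of `tests` with the four documented keys.

    The value types and defaults here are assumptions, not sakura's.
    """
    return {
        "on_split": on_split,
        "make_logs": make_logs,
        "dump_latent": dump_latent,
        "latent_prefix": latent_prefix,
    }


def segment_triggers(total_ticks: int, segment: int) -> List[int]:
    """Ticks at which a segmental test/checkpoint would fire.

    Assumption: ticks start at 1 and an action fires when tick % segment == 0.
    """
    if segment < 1:
        raise ValueError("segment must be >= 1")
    return [t for t in range(1, total_ticks + 1) if t % segment == 0]


if __name__ == "__main__":
    tests = [make_test_config("test", dump_latent=True, latent_prefix="seg_")]
    print(tests)
    print(segment_triggers(5000, 2000))
    # Hypothetical call on an existing sakuraAE instance `model`:
    # model.train(split_id="train", test_on_segment=True, test_segment=2000,
    #             tests=tests)
```

Under these assumptions, a run of 5000 ticks with test_segment=2000 tests at ticks 2000 and 4000 only.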
Note
The phenotypes (signatures) passed as selected_pheno (selected_signature) should already be configured and stored in self.selected_pheno (self.selected_signature). If the argument is None, self.selected_pheno (self.selected_signature) is used as the default, so all configured phenotypes (signatures) are trained. This is intended for complex training scenarios in which the neural network (NN) is only partially forwarded.
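The defaulting rule in the Note can be sketched as follows. The helper resolve_selection is hypothetical. It assumes the configured selection is a list of ID strings and that a single ID may be passed as a plain string. Rejecting IDs that were not configured is also an assumption about how sakura behaves.

```python
from typing import Iterable, List, Optional, Union


def resolve_selection(requested: Optional[Union[str, Iterable[str]]],
                      configured: List[str]) -> List[str]:
    """Hypothetical helper: pick the side tasks to train in this call.

    None  -> fall back to every configured ID (the documented default).
    str   -> treated as a single ID (assumption).
    other -> must be a subset of the configured IDs (assumption).
    """
    if requested is None:
        return list(configured)
    if isinstance(requested, str):
        requested = [requested]
    requested = list(requested)
    unknown = [r for r in requested if r not in configured]
    if unknown:
        raise ValueError(f"not configured: {unknown}")
    return requested


if __name__ == "__main__":
    configured_pheno = ["age", "sex"]
    print(resolve_selection(None, configured_pheno))   # all configured
    print(resolve_selection("age", configured_pheno))  # partial forward
```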
- Returns:
None