sakura.sakuraAE.sakuraAE.train_hybrid
- sakuraAE.train_hybrid(split_configs: dict, ticks=50000, hybrid_mode='interleave', prog_loss_weight_mode='epoch_end', make_logs=True, log_prefix='', log_loss_groups=['loss', 'regularization'], save_raw_loss=False, perform_test=False, test_segment=2000, tests: dict | None = None, perform_checkpoint=False, checkpoint_segment=2000, checkpoint_prefix='', checkpoint_save_arch=False, loss_prog_on_test: dict | None = None, resume=False, resume_dict=None)
Train the model in hybrid mode, in which several module splits (configured through split_configs) are trained within a single session according to the schedule selected by hybrid_mode.
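As a rough mental model of the default hybrid_mode=‘interleave’, the sketch below assigns each training tick to one module split in turn. It assumes that exactly one module split is trained per tick and that splits are visited in the order they appear in split_configs. The function interleave_schedule and the split names are hypothetical and are not part of sakura.

```python
from itertools import cycle, islice


def interleave_schedule(split_names: list[str], ticks: int) -> list[str]:
    """Hypothetical round-robin schedule: the module split trained at each tick.

    Assumes one module split is trained per tick, visited in the given order.
    """
    return list(islice(cycle(split_names), ticks))


# Hypothetical split names; in practice these would be the keys of split_configs.
schedule = interleave_schedule(["main", "pheno", "signature"], ticks=7)
print(schedule)
# ['main', 'pheno', 'signature', 'main', 'pheno', 'signature', 'main']
```

Under this assumption, with ticks=50000 and three module splits, each split would receive roughly a third of the batches.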
- Parameters:
split_configs (dict[str, str or int]) – A dictionary of module split configurations used for training; the entry for each module split should contain the keys ‘use_split’, ‘batch_size’, ‘train_main_latent’, ‘train_pheno’ and ‘train_signature’
ticks (int, optional) – The total number of training iterations, where each tick trains on one batch of data, defaults to 50000
hybrid_mode (Literal['interleave', 'pattern', 'sum'], optional) – How the module splits are combined during training, defaults to ‘interleave’, in which the module splits are trained in turn (round-robin)
prog_loss_weight_mode (Literal['on_test', 'epoch_end'], optional) – The mode for progressive loss weighting, defaults to ‘epoch_end’, in which loss weights progress at the end of each epoch; with ‘on_test’, progression is configured by loss_prog_on_test
make_logs (bool, optional) – Whether to log training information, including losses, defaults to True
log_prefix (str, optional) – Prefix for training logs; it is prepended to loss item names in TensorBoard and to the filenames of dumped latent embeddings, defaults to ‘’
log_loss_groups (list[str], optional) – Loss groups to be logged, defaults to [‘loss’, ‘regularization’]
save_raw_loss (bool, optional) – Whether to record raw losses, defaults to False
perform_test (bool, optional) – Whether to run tests during training every test_segment ticks, defaults to False
test_segment (int, optional) – Tick interval at which testing is performed, defaults to 2000
tests (list[dict[str, Any]], optional) – A list of test configuration dictionaries, each of which should contain the keys ‘on_split’, ‘make_logs’, ‘dump_latent’ and ‘latent_prefix’, defaults to None
perform_checkpoint (bool, optional) – Whether to checkpoint the model every checkpoint_segment ticks, defaults to False
checkpoint_segment (int, optional) – Tick interval between model checkpoints, defaults to 2000
checkpoint_prefix (str, optional) – Prefix for checkpoint filenames, defaults to ‘’
checkpoint_save_arch (bool, optional) – Whether to also save the model architecture in checkpoints, defaults to False
loss_prog_on_test (dict[str, Any], optional) – A dictionary configuring loss weight progression at test time when prog_loss_weight_mode is ‘on_test’; it should contain the keys ‘prog_main’, ‘train_pheno’, ‘selected_pheno’, ‘train_signature’ and ‘selected_signature’, defaults to None
resume (bool, optional) – Whether to resume from a saved training session, defaults to False
resume_dict (dict[str, Any], optional) – Session state dictionary used to resume a previous training session, defaults to None
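The nested dictionaries expected by split_configs, tests and loss_prog_on_test are easy to get wrong. The sketch below builds one example of each and checks only that the keys listed above are present. The split names, the value types (booleans, integers, strings, None) and the helper validate_hybrid_configs are assumptions for illustration; sakura itself may validate or interpret these values differently.

```python
from typing import Any, Optional

# Required keys, taken from the parameter descriptions above.
SPLIT_KEYS = {"use_split", "batch_size", "train_main_latent", "train_pheno", "train_signature"}
TEST_KEYS = {"on_split", "make_logs", "dump_latent", "latent_prefix"}
LOSS_PROG_KEYS = {"prog_main", "train_pheno", "selected_pheno",
                  "train_signature", "selected_signature"}


def validate_hybrid_configs(split_configs: dict[str, dict[str, Any]],
                            tests: Optional[list[dict[str, Any]]] = None,
                            loss_prog_on_test: Optional[dict[str, Any]] = None) -> None:
    """Hypothetical helper: raise ValueError if any required key is missing."""
    problems = []
    for name, cfg in split_configs.items():
        missing = SPLIT_KEYS.difference(cfg)
        if missing:
            problems.append(f"split_configs[{name!r}] lacks {sorted(missing)}")
    for i, cfg in enumerate(tests or []):
        missing = TEST_KEYS.difference(cfg)
        if missing:
            problems.append(f"tests[{i}] lacks {sorted(missing)}")
    if loss_prog_on_test is not None:
        missing = LOSS_PROG_KEYS.difference(loss_prog_on_test)
        if missing:
            problems.append(f"loss_prog_on_test lacks {sorted(missing)}")
    if problems:
        raise ValueError("; ".join(problems))


# Illustrative (hypothetical) values only; their meaning is defined by sakura, not here.
split_configs = {
    "main": {"use_split": "train", "batch_size": 128, "train_main_latent": True,
             "train_pheno": False, "train_signature": False},
    "pheno": {"use_split": "train", "batch_size": 64, "train_main_latent": False,
              "train_pheno": True, "train_signature": False},
}
tests = [{"on_split": "test", "make_logs": True, "dump_latent": True, "latent_prefix": "test_"}]
loss_prog_on_test = {"prog_main": True, "train_pheno": True, "selected_pheno": None,
                     "train_signature": False, "selected_signature": None}

validate_hybrid_configs(split_configs, tests, loss_prog_on_test)  # passes silently
```

Once the keys check out, the dictionaries would be passed to an already-constructed sakuraAE instance, e.g. model.train_hybrid(split_configs=split_configs, perform_test=True, tests=tests, prog_loss_weight_mode='on_test', loss_prog_on_test=loss_prog_on_test).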
Note
When prog_loss_weight_mode is ‘epoch_end’, loss weights progress only for the selected losses, and only when an epoch ends (i.e., when the tick count reaches the end of an epoch).
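The epoch-end behaviour can be illustrated with a toy loop in which loss weights are advanced only at epoch boundaries and only for the selected losses. The linear step, the cap of 1.0, the fixed ticks_per_epoch, the loss names and the helpers progress_weights and run_epoch_end_progression are all hypothetical; the actual progression rule used by sakura is not described here.

```python
def progress_weights(weights: dict[str, float], selected: set[str],
                     step: float, cap: float = 1.0) -> dict[str, float]:
    """Hypothetical rule: raise each selected weight by `step`, up to `cap`."""
    return {name: min(w + step, cap) if name in selected else w
            for name, w in weights.items()}


def run_epoch_end_progression(ticks: int, ticks_per_epoch: int,
                              weights: dict[str, float], selected: set[str],
                              step: float) -> dict[str, float]:
    """Toy training loop: weights change only when an epoch ends."""
    for tick in range(1, ticks + 1):
        # ... one batch would be trained here ...
        if tick % ticks_per_epoch == 0:  # the last tick of an epoch
            weights = progress_weights(weights, selected, step)
    return weights


# Hypothetical loss names; only "pheno" is selected for progression.
final = run_epoch_end_progression(ticks=10, ticks_per_epoch=4,
                                  weights={"pheno": 0.0, "signature": 0.0},
                                  selected={"pheno"}, step=0.25)
print(final)  # {'pheno': 0.5, 'signature': 0.0}
```

Here two epochs end (at ticks 4 and 8), so the selected weight advances twice while the unselected one never changes.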
- Returns:
None