sakura.sakuraAE.sakuraAE
- class sakura.sakuraAE.sakuraAE(config_json_path, verbose=False, suppress_train=False, suppress_tensorboardX=False)
Bases: object
A comprehensive class for the SAKURA pipeline.
This class manages the overall SAKURA workflow, including model initialization, training, testing, and either model inference or merging of external modules, according to the configuration file and constructor arguments.
- Parameters:
config_json_path (str) – Path to the configuration JSON file, which contains all the necessary settings for the class
verbose (bool, optional) – Whether to enable verbose console logging, defaults to False
suppress_train (bool, optional) – Whether to suppress model training and only set up the dataset and model, defaults to False
suppress_tensorboardX (bool, optional) – Whether to prevent the Logger from initializing tensorboardX (to avoid flushing logs), defaults to False
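The constructor needs only the path to a JSON configuration file plus optional flags. Below is a minimal sketch of preparing and reading such a file with Python's standard `json` module. The keys shown (`dataset`, `path`, `train`, `epochs`) are hypothetical placeholders and are not SAKURA's actual configuration schema. The constructor call is left as a comment because it requires SAKURA to be installed.

```python
import json
import os
import tempfile


def load_config(config_json_path: str) -> dict:
    """Read a JSON configuration file into a plain dict."""
    with open(config_json_path, "r", encoding="utf-8") as f:
        return json.load(f)


# HYPOTHETICAL keys for illustration only; consult SAKURA's own config templates.
example_config = {
    "dataset": {"path": "data/expression.csv"},
    "train": {"epochs": 10},
}

with tempfile.TemporaryDirectory() as tmp_dir:
    config_path = os.path.join(tmp_dir, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(example_config, f, indent=2)

    config = load_config(config_path)

    # With SAKURA installed, the same path would be passed to the constructor,
    # e.g. to set up the dataset and model without training:
    # from sakura.sakuraAE import sakuraAE
    # pipeline = sakuraAE(config_path, verbose=True, suppress_train=True)

print(config["train"]["epochs"])  # 10
```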
Methods
Perform inference on the given tasks represented as a list of stories.
Generate dataset split masks for model training and testing.
Insert an external module and merge it with the SAKURA model.
Perform integrity check on selected phenotypes/signatures against the input dataset.
Load a checkpoint file and resume the model's state, including parameters, random states, training progress, etc.
Save the current state of the model and training process as a checkpoint.
Set up dataset for SAKURA model.
Test all latents of the model, with options to evaluate specific loss groups and dump selected latents.
Train the model in batches for at least one epoch.
Train the model in hybrid mode, in which the model's module splits can be trained flexibly.
Multithreaded-dataloader version of hybrid-mode training, in which the model's module splits can be trained flexibly.
Train the model on the given sets of tasks represented as a list of storylines.
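One of the methods above generates dataset split masks for training and testing. The sketch below illustrates the general idea: a seeded random shuffle yields two complementary boolean masks. The helper `make_split_masks` and its parameters are hypothetical, and SAKURA's own splitting logic (for example, stratified or config-driven splits) may differ.

```python
import random
from typing import List, Tuple


def make_split_masks(
    n_samples: int, test_fraction: float = 0.2, seed: int = 0
) -> Tuple[List[bool], List[bool]]:
    """Return complementary (train_mask, test_mask) boolean lists.

    HYPOTHETICAL helper: illustrates a seeded random train/test split,
    not SAKURA's actual implementation.
    """
    if not 0.0 <= test_fraction <= 1.0:
        raise ValueError("test_fraction must be between 0 and 1")
    rng = random.Random(seed)  # local RNG so the split is reproducible
    indices = list(range(n_samples))
    rng.shuffle(indices)
    n_test = round(n_samples * test_fraction)
    test_indices = set(indices[:n_test])
    test_mask = [i in test_indices for i in range(n_samples)]
    train_mask = [not flag for flag in test_mask]  # every sample lands in exactly one split
    return train_mask, test_mask


train_mask, test_mask = make_split_masks(10, test_fraction=0.2, seed=42)
print(sum(train_mask), sum(test_mask))  # 8 2
```

Using a dedicated `random.Random(seed)` instance, rather than the module-level generator, keeps the split reproducible without disturbing any other random state in the process.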
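Another method checks selected phenotypes or signatures against the input dataset. A minimal sketch of such an integrity check follows. It assumes the check reduces to verifying that every selected name exists in the dataset, which is a simplification. The helper `check_selection` and the column names are hypothetical; the real check may also inspect data types or values.

```python
from typing import Iterable, List


def check_selection(selected: Iterable[str], available: Iterable[str], kind: str) -> List[str]:
    """Raise if any selected name is absent from the dataset; otherwise return the selection.

    HYPOTHETICAL helper illustrating a name-level integrity check.
    """
    available_set = set(available)
    selected_list = list(selected)
    missing = [name for name in selected_list if name not in available_set]
    if missing:
        raise ValueError(f"{kind} not found in dataset: {missing}")
    return selected_list


# Example with made-up phenotype columns from a metadata table.
metadata_columns = ["cell_type", "batch", "donor"]
print(check_selection(["cell_type", "batch"], metadata_columns, "phenotype"))
# ['cell_type', 'batch']
```

Collecting all missing names before raising, rather than failing on the first one, reports every problem in a single pass.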
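The checkpoint methods save and restore parameters, random states, and training progress together, so that a resumed run can continue where it stopped. The sketch below shows that idea using only the standard library (`pickle`, `random.getstate`, `random.setstate`). The helper names and the dict layout are hypothetical. A model built on a framework such as PyTorch would also need framework-specific state, for example `model.state_dict()` and `torch.get_rng_state()`.

```python
import os
import pickle
import random
import tempfile
from typing import Any, Dict, Tuple


def save_checkpoint(path: str, params: Dict[str, Any], epoch: int) -> None:
    """Write parameters, training progress, and the RNG state to one file (HYPOTHETICAL helper)."""
    state = {
        "params": params,
        "epoch": epoch,
        "python_random_state": random.getstate(),
    }
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(path: str) -> Tuple[Dict[str, Any], int]:
    """Restore the RNG state and return (params, epoch) from a checkpoint (HYPOTHETICAL helper)."""
    with open(path, "rb") as f:
        state = pickle.load(f)
    random.setstate(state["python_random_state"])
    return state["params"], state["epoch"]


random.seed(0)
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "checkpoint.pkl")
    save_checkpoint(ckpt_path, {"weight": [0.1, 0.2]}, epoch=3)
    draw_before = random.random()  # advances the RNG after saving
    params, epoch = load_checkpoint(ckpt_path)
    draw_after = random.random()  # RNG was rewound, so this repeats the same draw

print(epoch, draw_before == draw_after)  # 3 True
```

Because `pickle` can execute arbitrary code while loading, checkpoints stored this way should only be loaded from trusted sources.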