sakura.models.extractor.Extractor.forward
- Extractor.forward(batch, forward_signature=True, selected_signature=None, forward_pheno=True, selected_pheno=None, forward_main_latent=True, forward_reconstruction=True, detach=False, detach_from='')
Run a forward pass through the extractor, with control over which computation branches are evaluated.
Orchestrates data flow through the assembled modular architecture, enabling selective activation of task branches and gradient flow control.
- Parameters:
batch (torch.Tensor) – Gene expression tensor of shape (N, M), where N is the number of cells and M is the number of genes
forward_signature (bool, optional) – Whether to forward signature supervision part, defaults to True
selected_signature (list[str], optional) – A list of signatures to forward; None forwards all signatures, defaults to None
forward_pheno (bool, optional) – Whether to forward phenotype supervision part, defaults to True
selected_pheno (list[str], optional) – A list of phenotypes to forward; None forwards all phenotypes, defaults to None
forward_main_latent (bool, optional) – Whether to forward main latent part, defaults to True
forward_reconstruction (bool, optional) – Whether to forward the decoder reconstruction part, defaults to True
detach (bool, optional) – Whether to block gradient flow partway through the network, at the point given by <detach_from>, defaults to False
detach_from (Literal['pre_encoder', 'encoder'] or str, optional) – The component at which gradient flow is blocked when <detach> is True, defaults to ''
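The way the branch flags and selection lists combine can be sketched in plain Python. This is a minimal illustration, not the library's code: `plan_branches` and `_resolve` are hypothetical helpers, and raising `KeyError` for unknown names is our own assumption. The sketch only resolves which latent branches a call would compute, treating `None` as "all members" as the parameter list describes.

```python
from typing import Dict, List, Optional, Sequence


def _resolve(enabled: bool, selected: Optional[Sequence[str]],
             available: Sequence[str]) -> List[str]:
    """Resolve one branch: disabled -> [], None -> all members, else the subset."""
    if not enabled:
        return []
    if selected is None:
        return list(available)
    unknown = set(selected) - set(available)
    if unknown:
        # Assumption (not from sakura): unknown names are treated as an error.
        raise KeyError(f"unknown branch name(s): {sorted(unknown)}")
    return list(selected)


def plan_branches(signature_names: Sequence[str],
                  pheno_names: Sequence[str],
                  forward_signature: bool = True,
                  selected_signature: Optional[Sequence[str]] = None,
                  forward_pheno: bool = True,
                  selected_pheno: Optional[Sequence[str]] = None,
                  forward_main_latent: bool = True) -> Dict[str, List[str]]:
    """Hypothetical helper: list the latent branches a forward call would compute."""
    return {
        "main": ["main"] if forward_main_latent else [],
        "signature": _resolve(forward_signature, selected_signature, signature_names),
        "pheno": _resolve(forward_pheno, selected_pheno, pheno_names),
    }


# Forward only the "B_cell" signature head and skip phenotypes entirely.
print(plan_branches(["T_cell", "B_cell"], ["age"],
                    selected_signature=["B_cell"], forward_pheno=False))
# {'main': ['main'], 'signature': ['B_cell'], 'pheno': []}
```

Keeping the "disabled" and "None means all" cases separate mirrors the two-level control in the signature: a `forward_*` flag switches a whole task off, while `selected_*` narrows it.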
Note
<forward_reconstruction>: The decoder reconstruction part can only be forwarded when all latent dimensions are forwarded.
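A minimal sketch of the constraint in this note, assuming (our assumption, not stated by the source) that the decoder consumes the full latent, so the main, every signature and every phenotype branch must all be present. `reconstruction_allowed` and `_branch_complete` are hypothetical and not part of `sakura`.

```python
from typing import Optional, Sequence


def _branch_complete(enabled: bool, selected: Optional[Sequence[str]],
                     available: Sequence[str]) -> bool:
    """True if every member of this latent branch would be forwarded."""
    if not available:      # a branch with no members is trivially complete
        return True
    if not enabled:        # branch switched off entirely
        return False
    # None means "all members"; otherwise the selection must cover them all.
    return selected is None or set(selected) >= set(available)


def reconstruction_allowed(signature_names: Sequence[str],
                           pheno_names: Sequence[str],
                           forward_main_latent: bool = True,
                           forward_signature: bool = True,
                           selected_signature: Optional[Sequence[str]] = None,
                           forward_pheno: bool = True,
                           selected_pheno: Optional[Sequence[str]] = None) -> bool:
    """Hypothetical check: decoding needs the main, signature and pheno latents."""
    return (forward_main_latent
            and _branch_complete(forward_signature, selected_signature, signature_names)
            and _branch_complete(forward_pheno, selected_pheno, pheno_names))


print(reconstruction_allowed(["T_cell", "B_cell"], ["age"]))          # True
print(reconstruction_allowed(["T_cell", "B_cell"], ["age"],
                             selected_signature=["B_cell"]))          # False
```

A caller could use a check of this kind to decide whether to pass `forward_reconstruction=False` for a given branch selection.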
- <detach_from> options:
‘pre_encoder’ (lat_pre will be detached, so the pre-encoder will not be trained);
‘encoder’ (main_lat, pheno_lat and signature_lat will be detached, so neither the pre-encoder nor the encoder will be trained).
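To make the two options concrete, here is a toy PyTorch module. It is not `sakura`'s `Extractor`: `ToyExtractor` and `modules_with_grad` are hypothetical, and a single latent stands in for main_lat, pheno_lat and signature_lat. It applies `Tensor.detach()` at the chosen point and reports which submodules receive gradients after one backward pass.

```python
import torch
from torch import nn


class ToyExtractor(nn.Module):
    """Hypothetical stand-in for Extractor: pre_encoder -> encoder -> head."""

    def __init__(self, n_genes: int = 8, n_hidden: int = 4, n_latent: int = 2):
        super().__init__()
        self.pre_encoder = nn.Linear(n_genes, n_hidden)
        self.encoder = nn.Linear(n_hidden, n_latent)
        self.head = nn.Linear(n_latent, 1)  # stands in for any supervised head

    def forward(self, batch: torch.Tensor, detach: bool = False,
                detach_from: str = "") -> torch.Tensor:
        lat_pre = self.pre_encoder(batch)
        if detach and detach_from == "pre_encoder":
            lat_pre = lat_pre.detach()  # gradients stop before pre_encoder
        lat = self.encoder(lat_pre)
        if detach and detach_from == "encoder":
            lat = lat.detach()          # gradients stop before encoder
        return self.head(lat)


def modules_with_grad(model: ToyExtractor, **kwargs) -> list:
    """Run one backward pass and list the submodules that received gradients."""
    model.zero_grad(set_to_none=True)
    out = model(torch.randn(5, 8), **kwargs)  # N=5 cells, M=8 genes
    out.sum().backward()
    return [name for name in ("pre_encoder", "encoder", "head")
            if getattr(model, name).weight.grad is not None]


torch.manual_seed(0)
model = ToyExtractor()
print(modules_with_grad(model))                                          # ['pre_encoder', 'encoder', 'head']
print(modules_with_grad(model, detach=True, detach_from="pre_encoder"))  # ['encoder', 'head']
print(modules_with_grad(model, detach=True, detach_from="encoder"))      # ['head']
```

Detaching only cuts the autograd graph; modules downstream of the cut still receive gradients, which is why the head is trained in every case.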
Computations related to the gradient reverse layer and the gradient neutralize layer are performed in model_controllers.extractor_controller, not in this method.
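For readers unfamiliar with the first of these, a gradient reversal layer is commonly written as a custom `torch.autograd.Function` that is the identity in the forward pass and multiplies the incoming gradient by a negative factor in the backward pass. The sketch below is that standard construction, with a hypothetical scaling parameter `alpha`. The implementation in `model_controllers.extractor_controller` may differ, and the gradient neutralize layer is not shown.

```python
import torch


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by -alpha on the way back."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        ctx.alpha = alpha
        return x.view_as(x)  # same values, new autograd node

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient per forward input; alpha is a plain float, so it gets None.
        return -ctx.alpha * grad_output, None


def grad_reverse(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    return GradReverse.apply(x, alpha)


x = torch.tensor([1.0, 2.0], requires_grad=True)
y = grad_reverse(x, alpha=0.5)
y.sum().backward()
print(y.detach())  # tensor([1., 2.])
print(x.grad)      # tensor([-0.5000, -0.5000])
```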