sakura.utils.size_estimator.SizeEstimator
- class sakura.utils.size_estimator.SizeEstimator(model, input_size=(1, 1, 32, 32), bits=32)
Bases: object

Estimates the memory consumption of PyTorch models.
- Calculates:
  - Parameter storage requirements
  - Activation memory for the forward pass
  - Gradient memory for the backward pass
  - Input tensor memory
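Each of these components comes down to the same arithmetic: a dense tensor's storage in bits is its element count times the bit width of one element. The sketch below assumes that is how the estimator counts memory. The helpers `tensor_bits` and `bits_to_megabytes` are hypothetical and are not part of `sakura`. The example works through the input tensor for the default `input_size=(1, 1, 32, 32)` and `bits=32`.

```python
from math import prod


def tensor_bits(shape, bits=32):
    """Hypothetical helper: bits needed to store a dense tensor of `shape`,
    assuming every element occupies `bits` bits."""
    return prod(shape) * bits


def bits_to_megabytes(total_bits):
    """Convert bits to megabytes (8 bits per byte, 1024**2 bytes per MB)."""
    return total_bits / 8 / 1024**2


# Input tensor for the default constructor arguments.
input_size = (1, 1, 32, 32)
input_bits = tensor_bits(input_size, bits=32)
print(input_bits)                     # 1 * 1 * 32 * 32 elements * 32 bits = 32768
print(input_bits // 8)                # 4096 bytes
print(bits_to_megabytes(input_bits))  # 0.00390625
```

Halving `bits` to 16 halves every such figure, because each term is linear in the bit width.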
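- Parameters:
  - model: the PyTorch model to analyze.
  - input_size: shape of the sample input tensor; defaults to (1, 1, 32, 32).
  - bits: number of bits used to store each tensor element; defaults to 32.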
Methods
- Calculate the bits needed for activation storage during the forward and backward passes.
- Calculate the bits required for input tensor storage.
- Calculate the total bits required for parameter storage.
- Calculate the total memory requirements.
- Determine the output dimensions of each layer by running a sample input through the model.
- Collect the dimensions of all parameters in the model.
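Taken together, the methods describe a pipeline: collect parameter shapes and per-layer output shapes, convert each group to bits, add the input, and sum. The sketch below is a pure-Python approximation of that pipeline under stated assumptions, not the `sakura` implementation. The shapes are passed in as plain tuples; in PyTorch, parameter shapes could come from `tuple(p.shape) for p in model.parameters()`, and output shapes from the layer outputs of a sample forward pass. Activation bits are doubled on the assumption that the backward pass keeps one gradient per stored activation. The function name `estimate_total_bits` is hypothetical.

```python
from math import prod


def estimate_total_bits(param_shapes, output_shapes, input_size, bits=32):
    """Hypothetical estimate of total memory in bits.

    param_shapes  -- shapes of every learnable parameter tensor
    output_shapes -- shapes of each layer's output for one sample input
    input_size    -- shape of the input tensor
    bits          -- assumed bit width of every element
    """
    # Parameter storage: total element count across all parameter tensors.
    param_bits = sum(prod(s) for s in param_shapes) * bits
    # Forward-pass activations, doubled as an assumed allowance for the
    # matching gradients held during the backward pass.
    forward_backward_bits = 2 * sum(prod(s) for s in output_shapes) * bits
    # Input tensor storage.
    input_bits = prod(input_size) * bits
    return param_bits + forward_backward_bits + input_bits


# Example: one fully connected 1024-to-10 layer applied to a batch of one
# 1024-element vector (shapes written out by hand for illustration).
param_shapes = [(10, 1024), (10,)]  # weight and bias
output_shapes = [(1, 10)]           # layer output for batch size 1
input_size = (1, 1024)

total = estimate_total_bits(param_shapes, output_shapes, input_size)
print(total)                # 328000 + 640 + 32768 = 361408
print(total / 8 / 1024**2)  # the same total in megabytes
```

Summing element counts before multiplying by `bits` gives the same result as converting each tensor separately, since every element is assumed to share one bit width.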