35 lines
1.2 KiB
Python
Executable File
35 lines
1.2 KiB
Python
Executable File
import torch
|
|
|
|
|
|
class TorchAutocast:
    """TorchAutocast utility class.

    Allows you to enable and disable autocast. This is especially useful
    when dealing with different architectures and clusters with different
    levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast
    """

    def __init__(self, enabled: bool, *args, **kwargs) -> None:
        # When disabled, keep ``None`` instead of a torch.autocast object so that
        # entering/exiting this context manager becomes a no-op.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self) -> None:
        """Enter the wrapped torch.autocast context, if enabled.

        Raises:
            RuntimeError: if entering torch.autocast fails; the new error names
                the configured dtype/device and chains the original error as
                its cause.
        """
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError as err:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Chain explicitly so the original torch error is reported as the
            # direct cause instead of an incidental "during handling" context.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from err

    def __exit__(self, *args, **kwargs):
        """Exit the wrapped torch.autocast context, if enabled.

        Returns whatever the wrapped context manager's ``__exit__`` returns
        (``None`` when autocast is disabled), so exception handling behaves the
        same as using torch.autocast directly.
        """
        if self.autocast is None:
            return None
        return self.autocast.__exit__(*args, **kwargs)
|