"""Source code for the ``nanoml.dtype`` module: helpers for choosing a half-precision dtype."""

import torch


def is_bf16_supported(device=None) -> bool:
    """Check whether bfloat16 is supported on a CUDA device.

    Support is decided from the device's CUDA compute capability: a major
    version of 8 or higher is treated as bfloat16-capable.

    Args:
        device: Optional device to query. It is passed straight through to
            ``torch.cuda.get_device_capability``. ``None`` (the default)
            queries the current CUDA device, matching the previous behavior.

    Returns:
        bool: True if bfloat16 is supported, False otherwise. Returns False
        (rather than raising) when CUDA is not available.
    """
    # get_device_capability raises when CUDA is unavailable (CPU-only torch
    # builds or no visible GPU); a capability predicate should answer False.
    if not torch.cuda.is_available():
        return False
    major_version, _ = torch.cuda.get_device_capability(device)
    return major_version >= 8
def get_half_dtype() -> torch.dtype:
    """Pick the 16-bit floating-point dtype to use on this device.

    Returns:
        torch.dtype: ``torch.bfloat16`` when ``is_bf16_supported()`` reports
        support, otherwise ``torch.float16``.
    """
    return torch.bfloat16 if is_bf16_supported() else torch.float16
def get_half_dtype_string() -> str:
    """Name the 16-bit floating-point dtype to use on this device.

    Returns:
        str: ``"bfloat16"`` when ``is_bf16_supported()`` reports support,
        otherwise ``"float16"``.
    """
    return "bfloat16" if is_bf16_supported() else "float16"