nanoml.dtype module

nanoml.dtype.get_half_dtype() → dtype

Get the half-precision dtype for the current device.

Returns:

torch.dtype: The half-precision dtype (torch.float16 or torch.bfloat16).
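The source for this function is not shown here, so the following is only a minimal sketch of how such a selection could work. It assumes that bfloat16 is preferred whenever a CUDA device reports support for it, with float16 as the fallback. `pick_half_dtype` is a hypothetical stand-in, not the nanoml implementation.

```python
import torch


def pick_half_dtype() -> torch.dtype:
    """Hypothetical stand-in for get_half_dtype(); the real logic may differ."""
    # Assumption: prefer bfloat16 when a CUDA device reports support for it.
    # The is_available() check short-circuits so the bf16 query only runs
    # when a CUDA device is actually present.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    # Otherwise fall back to float16.
    return torch.float16


dtype = pick_half_dtype()
# The returned dtype can be passed anywhere PyTorch accepts a dtype argument.
x = torch.ones(4, dtype=dtype)
print(x.dtype)
```

The returned value is an ordinary torch.dtype, so it can be handed directly to tensor constructors or to `.to(...)` calls.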

nanoml.dtype.get_half_dtype_string() → str

Get the name of the half-precision dtype for the current device, as a string.

Returns:

str: The half-precision dtype as a string ("bfloat16" or "float16").
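A string form is convenient for configuration files and command-line flags that expect a dtype name. The sketch below shows one way to convert between the two representations with plain PyTorch. `dtype_to_name` and `name_to_dtype` are hypothetical helpers for illustration and are not part of nanoml.

```python
import torch


def dtype_to_name(dtype: torch.dtype) -> str:
    """Hypothetical helper: torch.bfloat16 -> "bfloat16"."""
    # str(torch.bfloat16) is "torch.bfloat16"; keep the part after the dot.
    return str(dtype).split(".")[-1]


def name_to_dtype(name: str) -> torch.dtype:
    """Hypothetical helper: "float16" -> torch.float16."""
    dtype = getattr(torch, name)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"{name!r} is not a torch dtype name")
    return dtype


print(dtype_to_name(torch.bfloat16))
print(name_to_dtype("float16"))
```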

nanoml.dtype.is_bf16_supported() → bool

Check if bfloat16 is supported on the current device.

Returns:

bool: True if bfloat16 is supported, False otherwise.
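One common use of a check like this is to downgrade a requested precision when bfloat16 is unavailable. The sketch below illustrates that pattern in isolation. `resolve_precision` is a hypothetical helper, not part of nanoml, and its boolean argument is assumed to come from a call such as `is_bf16_supported()`.

```python
def resolve_precision(requested: str, bf16_supported: bool) -> str:
    """Hypothetical helper: fall back to float16 when bfloat16 is unavailable."""
    if requested == "bfloat16" and not bf16_supported:
        return "float16"
    # Any other request is passed through unchanged.
    return requested


print(resolve_precision("bfloat16", bf16_supported=False))  # float16
print(resolve_precision("bfloat16", bf16_supported=True))   # bfloat16
```

Passing the flag in as an argument keeps the fallback logic testable on machines without an accelerator.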