
bfloat16 #545

Open
ryao opened this issue Dec 19, 2024 · 0 comments


ryao commented Dec 19, 2024

The weights are natively bfloat16. Rather than converting them to float, you could keep them as bfloat16 and convert between float and bfloat16 on the fly using a union type and a bitshift (bfloat16 is simply the upper 16 bits of an IEEE 754 single-precision float). Since the forward() function is memory-bandwidth bound, halving the bytes read per weight should roughly double its performance. The only caveat is that you would need to handle subnormal numbers when converting from float to bfloat16.

There are two ways of doing this:

  1. Check for subnormal numbers via issubnormal() and zero them when converting from float to bfloat16.
  2. Set bit 15 (FTZ, flush-to-zero) of the MXCSR register on amd64 CPUs (non-portable).

Presumably, both could be combined behind a C preprocessor (CPP) check: use the issubnormal() check on non-amd64 processors and set bit 15 of the MXCSR on amd64 processors.
