The weights are natively bfloat16. Rather than converting them to float, you could keep them as bfloat16 and convert between float and bfloat16 on the fly using a union type and a bitshift. This should roughly double the performance of the forward() function: it is memory-bandwidth bound, and bf16 weights are half the size of float32 weights. The only caveat is that you would need to handle subnormal numbers when converting from float to bfloat16.
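Here is a minimal sketch of the on-the-fly conversion in C. It assumes the weights are stored as raw 16-bit bf16 values in a uint16_t array. The names f32_bits, bf16_to_float, float_to_bf16 and matmul_bf16 are hypothetical. In particular, matmul_bf16 is only an illustration of how a matrix-vector product in forward() might read bf16 weights, not the project's actual code. This version of float_to_bf16 truncates (rounds toward zero) and does not handle subnormals yet.

```c
#include <stddef.h>
#include <stdint.h>

/* Type-punning union: reinterpret a 32-bit float as its raw bits. */
typedef union {
    float f;
    uint32_t u;
} f32_bits;

/* bf16 -> float: the 16 bf16 bits become the top half of a float32. */
static inline float bf16_to_float(uint16_t h) {
    f32_bits v;
    v.u = (uint32_t)h << 16;
    return v.f;
}

/* float -> bf16 by truncation: keep the top 16 bits (round toward zero).
   No subnormal handling here. */
static inline uint16_t float_to_bf16(float x) {
    f32_bits v;
    v.f = x;
    return (uint16_t)(v.u >> 16);
}

/* Hypothetical matvec: out[i] = sum_j w[i*n + j] * x[j].
   The d x n weight matrix w stays in bf16; each element is widened
   to float only when it is used, and accumulation is done in float. */
static void matmul_bf16(float *out, const float *x, const uint16_t *w,
                        int n, int d) {
    for (int i = 0; i < d; i++) {
        float sum = 0.0f;
        for (int j = 0; j < n; j++) {
            sum += bf16_to_float(w[(size_t)i * n + j]) * x[j];
        }
        out[i] = sum;
    }
}
```

Only the weight array shrinks. Activations and the accumulator stay in float, so the expected gain comes purely from reading half as many bytes of weights per matmul.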
There are two ways of doing this:
1. Check for subnormal numbers via issubnormal() and zero them when converting from float to BF16.
2. Set bit 15 (FTZ, flush-to-zero) of the MXCSR register on amd64 CPUs (non-portable).
Presumably, both could be used via a C preprocessor (CPP) check: the issubnormal() check on non-amd64 processors, and bit 15 of the MXCSR set on amd64 processors.
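The preprocessor check might look like the sketch below. The names BF16_USE_MXCSR_FTZ, bf16_init_fp_mode and float_to_bf16 are hypothetical. fpclassify(x) == FP_SUBNORMAL (C99) stands in for issubnormal(), which is not available in older C standards.

On x86-64, the code sets FTZ through _mm_setcsr() and relies on it. On other targets, it zeroes subnormals in software. Two assumptions apply to the amd64 path. First, FTZ only affects the results of SSE/AVX arithmetic, so this path assumes every float being converted was computed after the mode was set. Second, MXCSR is per-thread state, so each worker thread has to set it.

```c
#include <math.h>
#include <stdint.h>

#if defined(__x86_64__) || defined(_M_X64)
#include <xmmintrin.h>  /* _mm_getcsr / _mm_setcsr */
#define BF16_USE_MXCSR_FTZ 1
#else
#define BF16_USE_MXCSR_FTZ 0
#endif

typedef union {
    float f;
    uint32_t u;
} f32_bits;

/* Call once per thread before running forward(). */
static void bf16_init_fp_mode(void) {
#if BF16_USE_MXCSR_FTZ
    /* Bit 15 of MXCSR is FTZ: SSE/AVX arithmetic that would produce a
       subnormal result produces zero instead. */
    _mm_setcsr(_mm_getcsr() | (1u << 15));
#endif
}

/* float -> bf16 by truncation (round toward zero). */
static inline uint16_t float_to_bf16(float x) {
    f32_bits v;
    v.f = x;
#if !BF16_USE_MXCSR_FTZ
    /* Software fallback: zero subnormal inputs, keeping only the sign bit. */
    if (fpclassify(x) == FP_SUBNORMAL) {
        return (uint16_t)((v.u >> 16) & 0x8000u);
    }
#endif
    return (uint16_t)(v.u >> 16);
}
```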