Run export script on CPU #78

Closed
export_meta_llama_bin.py (100 changes: 60 additions & 40 deletions)
@@ -4,14 +4,17 @@
 Place it into the root directory of:
 https://github.com/facebookresearch/llama
 
-And then run it similar to their other examples, via torchrun sadly:
-torchrun --nproc_per_node 1 export_meta_llama_bin.py
+And then run:
+python export_meta_llama_bin.py
 """
 
-from llama import Llama
+import json
+from pathlib import Path
+
+import torch
 
 # -----------------------------------------------------------------------------
-def export(self, filepath='model.bin'):
+def export(checkpoint, params, filepath='model.bin'):
     """export the model weights in fp32 into .bin file to be read from C"""
 
     f = open(filepath, 'wb')
@@ -24,68 +27,85 @@ def serialize(t):
         f.write(b)
 
     # first write out the header
-    hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
-    p = self.params
-    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
-    header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
-                         n_kv_heads, -p.vocab_size, p.max_seq_len)
+    hidden_dim = checkpoint['layers.0.feed_forward.w1.weight'].shape[0]
+    p = params
+    p['max_seq_len'] = 4096
+    p['vocab_size'] = 32000
+    n_layers = p['n_layers']
+    n_kv_heads = p.get('n_kv_heads', p['n_heads'])
+    header = struct.pack('iiiiiii', p['dim'], hidden_dim, n_layers, p['n_heads'],
+                         n_kv_heads, -p['vocab_size'], p['max_seq_len'])
     # NOTE ABOVE: a negative vocab_size indicates that the classifier weights
     # are present in the checkpoint and should be loaded.
     f.write(header)
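For reference (not part of the diff): the header is seven int32s packed with native byte order, so an exported file can be sanity-checked by reading them back. A minimal sketch, assuming the export has already produced llama2_7b.bin:

import struct

with open("llama2_7b.bin", "rb") as fin:
    # 7 int32s = 28 bytes: dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len
    dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len = \
        struct.unpack('iiiiiii', fin.read(28))
# per the NOTE above, a negative vocab_size means the classifier (output)
# weights are present in the file and should be loaded
classifier_present = vocab_size < 0
vocab_size = abs(vocab_size)
print(dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, max_seq_len)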

     # next write out the embedding weights
     print("writing tok_embeddings...")
-    serialize(self.tok_embeddings.weight)
+    serialize(checkpoint['tok_embeddings.weight'].type(torch.HalfTensor))
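For reference (not part of the diff): the .type(torch.HalfTensor) casts here and below are most likely a CPU workaround. Meta ships these weights in bfloat16, which numpy has no dtype for, so converting to float16 first lets the (presumably numpy-based) serialize succeed before everything is written as fp32 anyway, as the docstring says. A tiny illustration of the failure mode:

import torch

t = torch.zeros(4, dtype=torch.bfloat16)
# t.numpy() raises a TypeError here (numpy has no bfloat16 dtype)
a = t.type(torch.HalfTensor).numpy().astype("float32")  # fine: bf16 -> fp16 -> fp32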

     # now all the layers
     # attention weights
-    for i, layer in enumerate(self.layers):
+    for i in range(n_layers):
         print(f"writing attention_norm layer {i}...")
-        serialize(layer.attention_norm.weight)
-    for i, layer in enumerate(self.layers):
+        serialize(checkpoint['layers.'+str(i)+'.attention_norm.weight'].type(torch.HalfTensor))
+    for i in range(n_layers):
         print(f"writing attention.wq layer {i}...")
-        serialize(layer.attention.wq.weight)
-    for i, layer in enumerate(self.layers):
+        serialize(checkpoint['layers.'+str(i)+'.attention.wq.weight'].type(torch.HalfTensor))
+    for i in range(n_layers):
         print(f"writing attention.wk layer {i}...")
-        serialize(layer.attention.wk.weight)
-    for i, layer in enumerate(self.layers):
+        serialize(checkpoint['layers.'+str(i)+'.attention.wk.weight'].type(torch.HalfTensor))
+    for i in range(n_layers):
         print(f"writing attention.wv layer {i}...")
-        serialize(layer.attention.wv.weight)
-    for i, layer in enumerate(self.layers):
+        serialize(checkpoint['layers.'+str(i)+'.attention.wv.weight'].type(torch.HalfTensor))
+    for i in range(n_layers):
         print(f"writing attention.wo layer {i}...")
-        serialize(layer.attention.wo.weight)
+        serialize(checkpoint['layers.'+str(i)+'.attention.wo.weight'].type(torch.HalfTensor))
     # ffn weights
-    for i, layer in enumerate(self.layers):
+    for i in range(n_layers):
         print(f"writing ffn_norm layer {i}...")
-        serialize(layer.ffn_norm.weight)
-    for i, layer in enumerate(self.layers):
+        serialize(checkpoint['layers.'+str(i)+'.ffn_norm.weight'].type(torch.HalfTensor))
+    for i in range(n_layers):
         print(f"writing feed_forward.w1 layer {i}...")
-        serialize(layer.feed_forward.w1.weight)
-    for i, layer in enumerate(self.layers):
+        serialize(checkpoint['layers.'+str(i)+'.feed_forward.w1.weight'].type(torch.HalfTensor))
+    for i in range(n_layers):
         print(f"writing feed_forward.w2 layer {i}...")
-        serialize(layer.feed_forward.w2.weight)
-    for i, layer in enumerate(self.layers):
+        serialize(checkpoint['layers.'+str(i)+'.feed_forward.w2.weight'].type(torch.HalfTensor))
+    for i in range(n_layers):
         print(f"writing feed_forward.w3 layer {i}...")
-        serialize(layer.feed_forward.w3.weight)
+        serialize(checkpoint['layers.'+str(i)+'.feed_forward.w3.weight'].type(torch.HalfTensor))
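For reference (not part of the diff): the output is a flat byte stream, so the reader must consume tensors in exactly this name-major order, all layers of one weight before the next. The nine loops above could be collapsed without changing a single byte written, e.g.:

weight_names = ['attention_norm', 'attention.wq', 'attention.wk', 'attention.wv',
                'attention.wo', 'ffn_norm', 'feed_forward.w1', 'feed_forward.w2',
                'feed_forward.w3']
for name in weight_names:
    for i in range(n_layers):
        print(f"writing {name} layer {i}...")
        serialize(checkpoint[f'layers.{i}.{name}.weight'].type(torch.HalfTensor))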


+    def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+        t = torch.arange(end, device=freqs.device)  # type: ignore
+        freqs = torch.outer(t, freqs).float()  # type: ignore
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+        return freqs_cis
+
+    freqs_cis = precompute_freqs_cis(
+        p['dim'] // p['n_heads'], p['max_seq_len'] * 2
+    )
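For reference (not part of the diff): this helper mirrors precompute_freqs_cis from Meta's model.py, which the old torchrun path got for free via the built model. Each entry is exp(i · t · theta_j), so the .real and .imag planes serialized below are just cosine and sine tables of position times frequency. A self-contained check of that identity:

import torch

dim, end, theta = 8, 4, 10000.0  # toy head_dim and sequence length
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(end).float(), freqs)
fc = torch.polar(torch.ones_like(angles), angles)  # what precompute_freqs_cis(8, 4) returns
assert torch.allclose(fc.real, torch.cos(angles))
assert torch.allclose(fc.imag, torch.sin(angles))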

     # final rmsnorm
     print("writing final rmsnorm, classifier and freq_cis...")
-    serialize(self.norm.weight)
+    serialize(checkpoint['norm.weight'].type(torch.HalfTensor))
     # freqs_cis
-    serialize(self.freqs_cis.real[:p.max_seq_len])
-    serialize(self.freqs_cis.imag[:p.max_seq_len])
+    serialize(freqs_cis.real[:p['max_seq_len']].type(torch.HalfTensor))
+    serialize(freqs_cis.imag[:p['max_seq_len']].type(torch.HalfTensor))
     # finally write the output weights
-    serialize(self.output.weight)
+    serialize(checkpoint['output.weight'].type(torch.HalfTensor))
 
     # write to binary file
     f.close()
     print(f"wrote {filepath}")
 # -----------------------------------------------------------------------------
 
-# init Llama as normal
-generator = Llama.build(
-    ckpt_dir="llama-2-7b",
-    tokenizer_path="tokenizer.model",
-    max_seq_len=4096,
-    max_batch_size=1,
-)
-export(generator.model, "llama2_7b.bin")
+if __name__ == '__main__':
+    ckpt_dir = "llama-2-7b"
+    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+    ckpt_path = checkpoints[0]
+    checkpoint = torch.load(ckpt_path, map_location="cpu")
+    with open(Path(ckpt_dir) / "params.json", "r") as f:
+        params = json.loads(f.read())
+
+    export(checkpoint, params, "llama2_7b.bin")
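For reference (not part of the diff): taking checkpoints[0] works for the llama-2-7b download, which ships a single shard (consolidated.00.pth next to params.json); the larger models split their weights across several .pth files, and this script does not merge shards. A hypothetical guard that makes the assumption explicit:

assert len(checkpoints) == 1, "this CPU export assumes a single-shard (7B) checkpoint"

As the docstring says, run it with plain python from the root of the facebookresearch/llama checkout; no torchrun and no GPU are needed, which is the point of this PR.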