# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
|
|
Convert llama weight.
|
|
Support huggingface format and Meta format.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import mindspore as ms
|
|
from mindspore import ops
|
|
|
|
|
|
def convert_meta_torch_ckpt(ckpt_dir, output_name, dtype=ms.float16):
    """Convert a Meta (PyTorch) checkpoint, including checkpoints split into shards."""
    print(f"Trying to convert pytorch checkpoint in '{ckpt_dir}'.", flush=True)
    try:
        from torch import load
    except ImportError as e:
        raise ImportError("Failed to load pytorch checkpoint. Please make sure pytorch is available.") from e
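
    # Axis along which each Meta parameter is concatenated when the checkpoint is
    # split across several .pth shards; None means only the first shard's value is
    # used (these parameters are expected to be identical across shards).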
    dic = {
        'tok_embeddings.weight': 1,
        'norm.weight': None,
        'output.weight': 0,
        'attention.wq.weight': 0,
        'attention.wk.weight': 0,
        'attention.wv.weight': 0,
        'attention.wo.weight': 1,
        'feed_forward.w1.weight': 0,
        'feed_forward.w2.weight': 1,
        'feed_forward.w3.weight': 0,
        'attention_norm.weight': None,
        'ffn_norm.weight': None,
        'rope.freqs': None,
    }
    ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
    if not ckpt_paths:
        print(f"No pytorch checkpoint found in '{ckpt_dir}'.", flush=True)
        return False
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        model_args = json.loads(f.read())
    n_heads = model_args["n_heads"]
    dim = model_args["dim"]
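
    # Note: the permutation below appears to re-order the rotary (RoPE) dimension
    # pairs of the wq/wk projection weights; this reading is inferred from the
    # shape manipulation and is not documented in the original script.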
    def permute(w):
        # w is a numpy array at the call sites below, so use numpy reshape/transpose.
        return w.reshape(n_heads, dim // n_heads // 2, 2, dim).transpose(0, 2, 1, 3).reshape(dim, dim)

    checkpoints = []
    for ckpt_path in ckpt_paths:
        checkpoints.append(load(ckpt_path, map_location="cpu"))
    ckpt_list = []
    for name in checkpoints[0].keys():
        for k, v in dic.items():
            if k in name:
                if v is not None:
                    value = np.concatenate(
                        [checkpoints[i][name].numpy() for i in range(len(checkpoints))], v)
                else:
                    value = checkpoints[0][name].numpy()
        if name == 'norm.weight':
            name = 'norm_out.weight'
        if name == 'output.weight':
            name = 'lm_head.weight'
        else:
            name = 'model.' + name
        if 'rope.freqs' in name:
            continue

        if 'wq' in name or 'wk' in name:
            value = permute(value)
        print(f'\rprocessing parameter: {name} {value.shape} ', end='', flush=True)
        ckpt_list.append({'name': name, 'data': ms.Tensor(value, dtype=dtype)})

    ckpt_file = os.path.join(ckpt_dir, output_name)
    ms.save_checkpoint(ckpt_list, ckpt_file)
    print(f"\rPytorch checkpoint conversion finished; the mindspore checkpoint is saved in '{ckpt_file}'.", flush=True)
    return True


def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def name_replace(name: str):
    """Replace a Hugging Face parameter name with the corresponding MindSpore name."""
    name = name.replace('embed_tokens.weight', 'tok_embeddings.embedding_weight')
    name = name.replace('.self_attn.q_proj.', '.attention.wq.')
    name = name.replace('.self_attn.k_proj.', '.attention.wk.')
    name = name.replace('.self_attn.v_proj.', '.attention.wv.')
    name = name.replace('.self_attn.o_proj.', '.attention.wo.')
    name = name.replace('.mlp.gate_proj.', '.feed_forward.w1.')
    name = name.replace('.mlp.down_proj.', '.feed_forward.w2.')
    name = name.replace('.mlp.up_proj.', '.feed_forward.w3.')
    name = name.replace('.input_layernorm.', '.attention_norm.')
    name = name.replace('.post_attention_layernorm.', '.ffn_norm.')
    name = name.replace('.norm.', '.norm_out.')
    return name


def convert_hf_ckpt(ckpt_dir, output_name, dtype=ms.float16):
    """Convert a Hugging Face checkpoint to a MindSpore checkpoint."""
    print(f"Trying to convert huggingface checkpoint in '{ckpt_dir}'.", flush=True)
    try:
        from transformers import LlamaForCausalLM
    except ImportError as e:
        raise ImportError("Failed to load huggingface checkpoint. Please make sure transformers is available.") from e

    try:
        model_hf = LlamaForCausalLM.from_pretrained(ckpt_dir)
    # pylint: disable=W0703
    except Exception as e:
        print(f"Could not load a huggingface checkpoint from '{ckpt_dir}'. Error: {e}", flush=True)
        return False
    ckpt_list = []
    for name, value in model_hf.named_parameters():
        name = name_replace(name)
        if name == 'norm.weight':
            name = 'norm_out.weight'
        if name[:7] == 'layers.':
            name = name[7:]
        value = value.detach().numpy()
        print(f'\rprocessing parameter: {name} {value.shape} ', end='', flush=True)
        ckpt_list.append({'name': name, 'data': ms.Tensor(value, dtype=dtype)})

    ckpt_file = os.path.join(ckpt_dir, output_name)
    ms.save_checkpoint(ckpt_list, ckpt_file)
    print(f"\rHuggingface checkpoint conversion finished; the mindspore checkpoint is saved in '{ckpt_file}'.", flush=True)
    return True


def convert_to_new_ckpt(ckpt_path, config_path):
    """Convert a previously saved MindSpore checkpoint to the new format by permuting the wq/wk weights."""
    load_path = ckpt_path.split('.ckpt')[0]
    save_path = load_path + "_hf"
    params = ms.load_checkpoint(load_path + '.ckpt')
    with open(config_path, "r") as f:
        model_args = json.loads(f.read())
    n_heads = model_args["n_heads"]
    dim = model_args["dim"]
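
    # Same rotary re-ordering as permute() in convert_meta_torch_ckpt, expressed with
    # MindSpore ops so it can operate on the loaded Parameter values.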
    def permute(w):
        return ops.transpose(w.reshape(n_heads, dim // n_heads // 2, 2, dim), (0, 2, 1, 3)).reshape(dim, dim)

    ckpt_list = []
    for name in params.keys():
        value = params[name].value()
        if '.wq' in name or '.wk' in name:
            value = permute(value)
        ckpt_list.append({'name': name, 'data': value})
        print("\r", name, value.shape, end=" ")

    ms.save_checkpoint(ckpt_list, save_path)


if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--torch_ckpt_dir', default='./llama_model/llama-13b-hf/')
|
|
parser.add_argument('--mindspore_ckpt_path', default='transform.ckpt')
|
|
parser.add_argument('--pre_ckpt_path', default=None)
|
|
parser.add_argument('--config_path', default=None)
|
|
args = parser.parse_args()
|
|
if args.pre_ckpt_path is not None and args.config_path is not None:
|
|
convert_to_new_ckpt(args.pre_ckpt_path, args.config_path)
|
|
else:
|
|
convert_hf_ckpt(ckpt_dir=args.torch_ckpt_dir, output_name=args.mindspore_ckpt_path)
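
# Example invocations (the script file name and paths below are illustrative, not
# taken from the repository):
#   Convert a Hugging Face checkpoint directory:
#     python convert_weight.py --torch_ckpt_dir ./llama_model/llama-13b-hf/ --mindspore_ckpt_path transform.ckpt
#   Permute the wq/wk weights of a previously converted MindSpore checkpoint:
#     python convert_weight.py --pre_ckpt_path ./old_llama.ckpt --config_path ./params.json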