大模型推理

如有分布式推理需求,执行第 2 部分(第 2 部分的输入须为完整权重;若手头的权重是训练得到的切分权重,需先按第 1 部分合并);没有分布式推理需求的话,执行第 1、3 部分。

1.合并权重

1.1 先用 filter_ckpt_param.py 过滤权重:把优化器状态、loss scale、训练步数等非模型参数过滤掉(参数名中包含下方 ignore_keys 任一子串的都会被去掉),只保留模型参数。过滤后的权重按原 rank_* 目录结构保存在 filter_out 目录下。

import os
from glob import glob
import mindspore as ms

# Substrings that mark training-only entries (optimizer state, loss scaling,
# step/epoch counters). Any checkpoint entry whose name contains one of them
# is dropped by only_save_model_param().
ignore_keys = [
    'accu_grads',
    'scale_sense',
    'global_step',
    'adam',
    'current_iterator_step',
    'last_overflow_iterator_step',
    'epoch_num',
    'step_num',
    'loss_scale',
]


def only_save_model_param(ckpt_path, save_path):
    """Re-save a checkpoint, keeping only the model parameters.

    Args:
        ckpt_path (str): Path of the source ``.ckpt`` file.
        save_path (str): Path where the filtered ``.ckpt`` file is written.
    """
    params = ms.load_checkpoint(ckpt_path)
    # Keep an entry only if none of the ignore markers occurs in its name;
    # the original parameter order is preserved.
    kept = [
        {"name": param_name, "data": value}
        for param_name, value in params.items()
        if not any(marker in param_name for marker in ignore_keys)
    ]
    ms.save_checkpoint(kept, save_path)
    print(f"process {ckpt_path} finished!")
    
def build_save_path(ckpt_path, src_root, save_root):
    """Map a source checkpoint path to its destination under ``save_root``.

    If ``ckpt_path`` is ``src_root`` itself (a single checkpoint file was
    given), the file is placed directly in ``save_root``. Otherwise the path
    relative to ``src_root`` (e.g. ``rank_0/xxx.ckpt``) is preserved, so the
    per-rank directory layout survives filtering.

    Args:
        ckpt_path (str): Path of one source checkpoint file.
        src_root (str): The file or directory the user pointed the script at.
        save_root (str): Output root directory.

    Returns:
        str: Destination path of the filtered checkpoint.
    """
    if os.path.normpath(ckpt_path) == os.path.normpath(src_root):
        rel_path = os.path.basename(ckpt_path)
    else:
        rel_path = os.path.relpath(ckpt_path, src_root)
    return os.path.join(save_root, rel_path)


if __name__ == '__main__':
    # Either a single .ckpt file or a directory laid out as <dir>/rank*/*.ckpt.
    ckpt_path_or_dir = '/home/huawei/output/output'
    # Explicit check instead of assert (asserts are stripped under -O).
    if not os.path.exists(ckpt_path_or_dir):
        raise FileNotFoundError(f'{ckpt_path_or_dir} not exists!')
    if os.path.isfile(ckpt_path_or_dir):
        ckpt_paths = [ckpt_path_or_dir]
    else:
        ckpt_paths = sorted(glob(os.path.join(ckpt_path_or_dir, 'rank*/*.ckpt')))
    if not ckpt_paths:
        raise FileNotFoundError(f'no checkpoint found under {ckpt_path_or_dir}')

    save_root = "filter_out"
    for ckpt_path in ckpt_paths:
        save_path = build_save_path(ckpt_path, ckpt_path_or_dir, save_root)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        only_save_model_param(ckpt_path, save_path)

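1.2 然后用 transform_ckpt.py 合并权重:--src_ckpt_dir 指向 1.1 的输出目录(filter_out),--src_ckpt_strategy 填训练时保存的分布式策略文件(或策略文件所在目录),--dst_ckpt_strategy 不填(不填即表示将权重合并为完整权重)。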

import os
import argparse
import mindspore as ms

def merged_strategy_path(strategy_dir, rank_id=None):
    """Return the path of the merged strategy file inside ``strategy_dir``.

    A rank-specific name is used whenever ``rank_id`` is given, including
    rank 0.

    Args:
        strategy_dir (str): Directory holding the per-stage strategy files.
        rank_id (int, optional): Rank id of the device.

    Returns:
        str: ``<strategy_dir>/merged_ckpt_strategy_<rank_id>.ckpt`` or,
        without a rank id, ``<strategy_dir>/merged_ckpt_strategy.ckpt``.
    """
    if rank_id is not None:
        return os.path.join(strategy_dir, f'merged_ckpt_strategy_{rank_id}.ckpt')
    return os.path.join(strategy_dir, 'merged_ckpt_strategy.ckpt')


def get_strategy(startegy_path, rank_id=None):
    """Resolve a strategy argument to a single strategy file, merging if needed.

    Args:
        startegy_path (str): Path of a strategy file, or of a directory of
            strategy files to merge (needed for pipeline-parallel strategies).
            The misspelled name is kept for keyword-argument compatibility.
        rank_id (int, optional): Rank id of the device; when given (including
            0) it is appended to the merged file name.

    Returns:
        str or None: ``None`` if ``startegy_path`` is empty, the path itself if
        it is a file, or the merged strategy file path if it is a directory.

    Raises:
        FileNotFoundError: If ``startegy_path`` is non-empty but does not exist.
    """
    if not startegy_path:
        return None

    # Explicit check instead of assert (asserts are stripped under -O).
    if not os.path.exists(startegy_path):
        raise FileNotFoundError(f'{startegy_path} not found!')

    if os.path.isfile(startegy_path):
        return startegy_path

    if os.path.isdir(startegy_path):
        merge_path = merged_strategy_path(startegy_path, rank_id)

        # Remove any previous merge result first — presumably so a stale file
        # is not read back as an input or left behind; confirm against
        # mindspore.merge_pipeline_strategys.
        if os.path.exists(merge_path):
            os.remove(merge_path)

        ms.merge_pipeline_strategys(startegy_path, merge_path)
        return merge_path

    return None

def build_arg_parser():
    """Build the command-line parser for checkpoint transformation.

    Returns:
        argparse.ArgumentParser: Parser for the source/destination strategy
        paths, checkpoint directories and the output file prefix.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--src_ckpt_strategy',
                        default="",
                        help='path of src ckpt strategy')
    parser.add_argument('--dst_ckpt_strategy',
                        default="",
                        help='path of dst ckpt strategy')
    parser.add_argument('--src_ckpt_dir',
                        default="",
                        type=str,
                        help='path of src ckpt')
    parser.add_argument('--dst_ckpt_dir',
                        default="",
                        type=str,
                        help='path where to save dst ckpt')
    parser.add_argument('--prefix',
                        default='checkpoint_',
                        type=str,
                        help='prefix of transformed checkpoint')
    return parser


if __name__ == '__main__':
    parser = build_arg_parser()
    args = parser.parse_args()

    # Validate the source directory before get_strategy() runs, because that
    # call may delete/create merged strategy files on disk. parser.error()
    # prints usage and exits with status 2, unlike an assert (stripped under -O).
    if not os.path.exists(args.src_ckpt_dir):
        parser.error(f'{args.src_ckpt_dir} not found!')

    # An empty strategy resolves to None: complete weights as input (src) or
    # merge into complete weights (dst).
    src_ckpt_strategy = get_strategy(args.src_ckpt_strategy)
    dst_ckpt_strategy = get_strategy(args.dst_ckpt_strategy)
    src_ckpt_dir = args.src_ckpt_dir
    dst_ckpt_dir = args.dst_ckpt_dir
    prefix = args.prefix

    print(f"src_ckpt_strategy: {src_ckpt_strategy}")
    print(f"dst_ckpt_strategy: {dst_ckpt_strategy}")
    print(f"src_ckpt_dir: {src_ckpt_dir}")
    print(f"dst_ckpt_dir: {dst_ckpt_dir}")
    print(f"prefix: {prefix}")

    print("......Start transform......")
    ms.transform_checkpoints(src_ckpt_dir, dst_ckpt_dir, prefix, src_ckpt_strategy, dst_ckpt_strategy)
    print("......Transform succeed!......")

2.推理策略转换

代码

2.1生成目标策略文件

先修改 conv_stragety.py 中的配置,然后运行 run_conv.sh,生成目标(推理)分布式策略文件。

2.2按照策略进行权重转换

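python mindformers/tools/transform_ckpt.py --src_ckpt_strategy "" --dst_ckpt_strategy 刚生成的策略文件目录 --src_ckpt_dir SRC_CKPT_DIR --dst_ckpt_dir DST_CKPT_DIR
# 注:--src_ckpt_strategy 传空字符串 "" 或直接省略该参数,都表示待转权重为完整权重;不能只写 --src_ckpt_strategy 而后面不带值,否则 argparse 会报错。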
# 参数说明
# src_ckpt_strategy:待转权重的分布式策略文件路径。
  若为None,表示待转权重为完整权重;
  若为切分策略文件,表示原始的权重对应的策略文件;
  若为文件夹,表示需要合并文件夹内策略文件(仅在流水并行生成的策略文件时需要),合并后的策略文件保存在`SRC_CKPT_STRATEGY/merged_ckpt_strategy.ckpt`路径下;
# dst_ckpt_strategy:目标权重的分布式策略文件路径,即 2.1 中生成的分布式策略文件路径。
  若为None,表示将待转权重合并为完整权重;
  若为切分策略文件,表示目标卡数对应的策略文件;
  若为文件夹,表示需要合并文件夹内策略文件(仅在流水并行生成的策略文件时需要),合并后的策略文件保存在`DST_CKPT_STRATEGY/merged_ckpt_strategy.ckpt`路径下;
# src_ckpt_dir: 待转权重路径,须按照`SRC_CKPT_DIR/rank_{i}/checkpoint_{i}.ckpt`存放,比如单一权重存放格式为`SRC_CKPT_DIR/rank_0/checkpoint_0.ckpt`。
# dst_ckpt_dir:目标权重保存路径,为自定义空文件夹路径,转换后模型以`DST_CKPT_DIR/rank_{i}/xxx.ckpt`存放。
具体参考:https://gitee.com/mindspore/mindformers/blob/dev/docs/README.md#%E5%88%86%E5%B8%83%E5%BC%8F%E6%8E%A8%E7%90%86

2.3按照策略进行推理

修改 chat.py,使其中的 config 与 conv_stragety.py 保持一致,然后运行 run_chat.sh。

3.基于generate的推理

以下是示例 1,根据实际情况修改 tokenizer 路径、config 等即可;切分后的权重目录通过环境变量 DISTRIBUTED_CKPT_PATH 指定(按 DISTRIBUTED_CKPT_PATH/rank_{i}/ 存放)。
import os
import time
import numpy as np
import mindspore as ms
from mindspore.train import Model
from mindspore import load_checkpoint, load_param_into_net
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindformers import pipeline
from mindformers import LlamaForCausalLM, LlamaConfig, AutoTokenizer, LlamaTokenizer
from mindformers import init_context
from mindformers.modules import TransformerOpParallelConfig
from mindformers.trainer.utils import get_last_checkpoint
from mindformers.tools import logger

# Sequence length for the model config and for the dummy compile input in chat().
SEQ_LENGTH = 256
# Root of the sharded checkpoints, read as <root>/rank_<rank_id>/ in chat().
# chat() only loads weights when this is non-empty.
DISTRIBUTED_CKPT_PATH = os.getenv("DISTRIBUTED_CKPT_PATH", "")


# set context
# Ascend device, mode 0 (MindSpore GRAPH_MODE); parallel_mode=1 is presumably
# semi-auto parallel — confirm against mindformers.init_context.
# init_context returns this process's rank id and the device count.
context_config = {"device_target": "Ascend", "mode": 0,  "max_device_memory": "31GB"}
parallel_context_config = {"parallel_mode": 1, "gradients_mean": False, "full_batch": True}
rank_id, device_num = init_context(use_parallel=True, context_config=context_config, parallel_config=parallel_context_config)
# Auto-parallel strategy options (see mindspore.parallel.set_algo_parameters).
set_algo_parameters(elementwise_op_strategy_follow=True, fully_use_devices=True)
_set_multi_subgraphs()


# Model config. NOTE(review): this used to be labelled "bloom 7.1b", but it is a
# LlamaConfig whose shape (hidden 8192, 80 layers, 64 heads, 8 KV heads) looks
# like a LLaMA-2-70B-class model — confirm.
config = LlamaConfig(
    run_mode='predict',
    use_parallel= True,
    embedding_init_type="float32",
    # Empty: presumably no pretrained weights are loaded at construction;
    # chat() loads shards from DISTRIBUTED_CKPT_PATH instead.
    checkpoint_name_or_path="",
    seq_length=SEQ_LENGTH,
    hidden_size=8192,
    num_layers=80,
    num_heads=64,
    n_kv_heads=8,
    # NOTE(review): must match the tokenizer loaded in chat() — confirm.
    vocab_size=50000,
    multiple_of=256,
    # NOTE(review): pad_token_id == vocab_size is one past the last valid
    # embedding row; confirm the pad id is never looked up in the embedding,
    # or use an in-vocab pad id.
    pad_token_id=50000,
    max_decode_length=1024,
    ffn_dim_multiplier=1.3,
    hidden_dropout_rate=0.0,
    attention_dropout_rate=0.0,
    # NOTE(review): chat() calls generate(..., do_sample=False), which
    # presumably overrides do_sample=True here.
    top_k=3, top_p=1, do_sample=True,
    use_past = True,
    # dp=1 x mp=8 x pp=1: presumably expects 8 launched devices.
    parallel_config=TransformerOpParallelConfig(
        data_parallel=1,
        model_parallel=8,
        pipeline_stage=1,
        vocab_emb_dp=True
        )
    )

def chat():
    """Generate completions for a fixed list of prompts with the LLaMA model.

    Builds the network from the module-level ``config``. If the environment
    variable ``DISTRIBUTED_CKPT_PATH`` is set, the network is compiled for
    auto-parallel and this rank's shard (``rank_<rank_id>``) is loaded.
    Otherwise no checkpoint is loaded and a warning is logged.

    Raises:
        FileNotFoundError: If ``DISTRIBUTED_CKPT_PATH`` is set but no
            checkpoint is found for this rank.
    """
    # Tokenizer path is deployment-specific; adjust as needed.
    tokenizer = LlamaTokenizer("/home/ma-user/work/checkpoint_download/llama/tokenizer.model")
    llama = LlamaForCausalLM(config)
    llama.set_train(False)
    print(llama.config)
    print("*********************************************")
    print(llama.lm_head.weight.shape)
    print("*********************************************")
    print(llama.config.parallel_config.vocab_emb_dp)
    if DISTRIBUTED_CKPT_PATH:
        # Locate the sharded checkpoint for this rank.
        rank_dir = os.path.join(DISTRIBUTED_CKPT_PATH, "rank_{}".format(rank_id))
        ckpt_path = get_last_checkpoint(rank_dir)
        # Fail with a clear message rather than inside load_checkpoint when the
        # rank directory holds no checkpoint (presumably get_last_checkpoint
        # then returns None — confirm).
        if not ckpt_path:
            raise FileNotFoundError(f"no checkpoint found in {rank_dir}")
        logger.info("ckpt path: %s", str(ckpt_path))

        # Compile once on a dummy (1, SEQ_LENGTH) batch so parameters are laid
        # out per the parallel config before this rank's slice is loaded.
        infer_data = (ms.Tensor(np.ones(shape=(1, SEQ_LENGTH)), ms.int32),)
        llama.set_auto_parallel()
        llama.compile(*infer_data)
        print(llama.lm_head.weight.shape)
        print("*******************************************")
        checkpoint_dict = load_checkpoint(ckpt_path)
        not_load_network_params = load_param_into_net(llama, checkpoint_dict)
        logger.info("Network parameters are not loaded: %s", str(not_load_network_params))
    else:
        logger.warning("DISTRIBUTED_CKPT_PATH is not set; no checkpoint will be loaded.")

    question_list = [
        "This is my motivation letter to apply the master course of Global Business in the college:",
        "大型网站建设最关心的问题就是网站速度",
        "糖渍板栗是一种以新鲜板栗",
        "呼伦贝尔是全世界最大的市",
        "我是练习时长两年半的个人练习生"
        ]

    for question in question_list:
        start = time.time()
        input_ids = tokenizer.encode(question)
        input_ids = np.array([input_ids]).astype(np.int32)  # add batch dim
        # eos_token_id=2 is presumably the LLaMA tokenizer's </s> id — confirm
        # for custom tokenizers.
        outputs = llama.generate(input_ids, max_length=None, do_sample=False, eos_token_id=2)
        outputs = outputs[0]  # remove batch dim
        print(tokenizer.decode(outputs))
        print("chat time :", time.time() - start)


if __name__ == "__main__":
    chat()
以下是示例2,根据实际情况,只需修改tokenizer、config等
https://gitee.com/mindspore/mindformers/blob/dev/run_chat_web.py