1. FastChat 框架概述
随着 ChatGPT 的兴起,出现了各类用于训练、推理大模型的框架, FastChat [1]就是其中之一。 FastChat 是一个开源的用于训练,评估以及部署大模型的框架,本文聚焦在模型的推理部分,从源码的角度介绍其基本的工作原理。
利用 FastChat 框架部署一个完整的模型服务主要分为三个部分,分别为: Controller , Server 以及多个 Worker 。这三者之间的关系如官方给出的下图所示:
这三个部分之间的关系为:
- Server 部分用于接收请求,并将不同的请求分发到对应的 Worker 上,对应了上图中蓝色的链路 Data plane;
- 对于 Server ,如何知道有哪些 Worker (即对应的 IP 地址),这个时候就需要 Controller 部分, Controller 部分不但需要与 Worker 通信,还需要与 Server 通信:
- 与 Worker 部分的通信,是将 Worker 的信息注册到 Controller 部分的对应数据结构中, Controller 还需要记录每一个 Worker 的健康状况;
- 与 Server 部分的通信,则是让 Server 查找到对应请求的 Worker 信息,并将请求转发给具体的 Worker。
- Worker 是参与真实的模型计算的模块。
在参考文献[1]中提供了多种的部署方式,类似上述的三个部分的部署代码如下所示:
# Controller
python3 -m fastchat.serve.controller
# Worker
# worker 0
CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 --controller http://localhost:21001 --port 31000 --worker http://localhost:31000
# worker 1
CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path lmsys/fastchat-t5-3b-v1.0 --controller http://localhost:21001 --port 31001 --worker http://localhost:31001
# Web UI
python3 -m fastchat.serve.gradio_web_server
2. 三个重要的部分
在 FastChat 框架中,对于模型的推理,主要是由三个部分组成,分别为: Server , Worker 和 Controller 。这三个部分一起就组成了完整的大模型服务框架。 Server 部分负责接收用户的请求,并将请求转发给对应的 Worker,在计算完成后,将 Worker 返回的结果返回给用户; Worker 部分是整个模型计算的核心部分; Controller 部分用于保存和更新模型的信息。
2.1. Server
Server 部分有不同的 Server 方式,在 FastChat 中提供了两种,分别为:
- Web GUI ,其对应的启动脚本为:
gradio_web_server.py
。 - Web API ,其对应的启动脚本为:
open_api_server.py
。
第一种启动后会有一个利用 gradio 生成的页面。我们以第一种为例。 Gradio web server 的启动方式为:
python3 -m fastchat.serve.gradio_web_server
在 Server 模块中,需要处理的工作包括:1. 构建 UI ;2. 与 Controller 通信,用于取得当前的模型以及对应模型的 Worker 地址;3. 处理请求 Worker 以及处理 Worker 的返回,其过程如下图所示:
其中,与 Controller 通信包括两个方面,一个是在构建 UI 之前需要取到当前已有的模型列表,供用户在页面上选择对应的模型,对应的函数为 get_model_list()
,具体代码如下:
def get_model_list(
controller_url, register_openai_compatible_models, add_chatgpt, add_claude, add_palm
):
# Controller 地址不为空时取 controller 模块中的对应接口
if controller_url:
# 1. 先刷新
ret = requests.post(controller_url + "/refresh_all_workers")
assert ret.status_code == 200
# 2. 取list
ret = requests.post(controller_url + "/list_models")
models = ret.json()["models"]
else:
models = []
# Add API providers
if register_openai_compatible_models:
global openai_compatible_models_info
openai_compatible_models_info = json.load(
open(register_openai_compatible_models)
)
models += list(openai_compatible_models_info.keys())
if add_chatgpt:
models += ["gpt-3.5-turbo", "gpt-4"]
if add_claude:
models += ["claude-2", "claude-instant-1"]
if add_palm:
models += ["palm-2"]
# 去重
models = list(set(models))
priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
models.sort(key=lambda x: priority.get(x, x))
logger.info(f"Models: {models}")
return models
第二个是在请求的过程中,根据模型名称获取模型对应的 Worker 的地址,对应的代码在 bot_response()
函数内,其具体代码如下所示:
# Query worker address
# 请求获取模型对应的 Worker 地址
ret = requests.post(
controller_url + "/get_worker_address", json={"model": model_name}
)
worker_addr = ret.json()["address"]
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
由上可知, Server 与 Controller 通信,涉及到三个接口,主要关系如下图所示:
Server 与 Worker 的交互主要是在获取到模型对应的地址后,将请求发到对应的接口上,对应的代码在 model_worker_stream_iter()
函数内,其具体代码如下所示:
# Stream output
response = requests.post(
worker_addr + "/worker_generate_stream",
headers=headers,
json=gen_params,
stream=True,
timeout=WORKER_API_TIMEOUT,
)
可见,Server 部分请求的是 Worker 中的 worker_generate_stream
接口。
2.2. Worker
Worker 模块的主要功能是加载模型以及模型的推理,其启动的命令如下:
# worker 0
CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.3 --controller http://localhost:21001 --port 31000 --worker http://localhost:31000
# worker 1
CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path lmsys/fastchat-t5-3b-v1.0 --controller http://localhost:21001 --port 31001 --worker http://localhost:31001
在 model_worker.py
文件中,Worker 模块的主要函数为create_model_worker()
,该函数即完成了模型的加载以及模型的推理。对于模型加载和推理主要的代码在 BaseModelWorker
和 ModelWorker
两个类中,其中, ModelWorker
继承自 BaseModelWorker
,在 ModelWorker
类的实例化过程中,实现了两件事,第一是加载模型,第二是初始化心跳。首先,我们看下模型的加载,模型加载是在 ModelWorker
类的 __init__
函数中,其具体代码如下:
logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
self.model, self.tokenizer = load_model(
model_path,
device=device,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
dtype=dtype,
load_8bit=load_8bit,
cpu_offloading=cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
)
self.device = device
if self.tokenizer.pad_token == None:
self.tokenizer.pad_token = self.tokenizer.eos_token
Worker 中除了加载模型之外,还涉及到处理请求,Server 模块发过来的请求会对应到 worker_generate_stream
接口上,其具体的代码如下所示:
@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
# 获取请求参数
params = await request.json()
await acquire_worker_semaphore()
# 取得 Worker 生成响应的具体的函数
generator = worker.generate_stream_gate(params)
background_tasks = create_background_tasks()
return StreamingResponse(generator, background=background_tasks)
最终调用的是 inference.py
中的 generate_stream()
函数,其具体函数如下:
@torch.inference_mode()
def generate_stream(
model,
tokenizer,
params: Dict,
device: str,
context_len: int,
stream_interval: int = 2,
judge_sent_end: bool = False,
):
除了提供了 worker_generate_stream
接口外,还提了以下的几个接口:
- worker_generate
- worker_get_embeddings
- worker_get_status
- count_token
- worker_get_conv_template
- model_details
除了上述与具体任务相关的处理外, Worker 还需要处理与 Controller 之间的关系,包括注册,发心跳等,主要涉及到如下的几个函数:
# 初始化
def init_heart_beat(self):
# 注册到 Controller
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=heart_beat_worker,
args=(self,),
daemon=True,
)
self.heart_beat_thread.start()
# 注册到 Controller
def register_to_controller(self):
logger.info("Register to controller")
url = self.controller_addr + "/register_worker"
# 注册信息
data = {
"worker_name": self.worker_addr,
"check_heart_beat": True,
"worker_status": self.get_status(),
}
r = requests.post(url, json=data)
assert r.status_code == 200
# 发送心跳信息
def send_heart_beat(self):
logger.info(
f"Send heart beat. Models: {self.model_names}. "
f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
f"call_ct: {self.call_ct}. "
f"worker_id: {self.worker_id}. "
)
url = self.controller_addr + "/receive_heart_beat"
# 每隔 5 秒发送一次心跳信息
while True:
try:
ret = requests.post(
url,
json={
"worker_name": self.worker_addr,
"queue_length": self.get_queue_length(),
},
timeout=5,
)
exist = ret.json()["exist"]
break
except (requests.exceptions.RequestException, KeyError) as e:
logger.error(f"heart beat error: {e}")
time.sleep(5)
# 不存在该 Worker 重新注册
if not exist:
self.register_to_controller()
2.3. Controller
由上述可知,Controller 起到了 Server 和 Worker 之间的桥梁的作用,Controller 的启动命令如下:
python3 -m fastchat.serve.controller
Controller 的作用一方面是存储模型的相关信息,另外一方面是对外提供一些接口。首先看一下 Controller 类的构造,其具体代码如下所示:
class Controller:
def __init__(self, dispatch_method: str):
# Dict[str -> WorkerInfo]
# 使用 dict 存储模型的信息
self.worker_info = {}
# dispatch_method的默认值是shortest_queue
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
# 线程
self.heart_beat_thread = threading.Thread(
target=heart_beat_controller, args=(self,)
)
self.heart_beat_thread.start()
# 以下是其他一些具体函数
在Controller中,提供了以下的接口:
- register_worker
- refresh_all_workers
- list_models
- get_worker_address
- worker_generate_stream
- worker_get_status
- test_connection
其中,在 Server 部分通过 refresh_all_workers 接口刷新所有 Worker ,通过 list_models 接口获取当前所有的 Worker ,通过 get_worker_address 接口根据模型名称获取对应的模型信息,这三个接口的具体代码如下所示:
@app.post("/refresh_all_workers")
async def refresh_all_workers():
models = controller.refresh_all_workers()
@app.post("/list_models")
async def list_models():
models = controller.list_models()
return {"models": models}
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
data = await request.json()
addr = controller.get_worker_address(data["model"])
return {"address": addr}
由上代码可见,这三个接口分别调用了 Controller 类中的 refresh_all_workers()
, list_models()
和 get_worker_address()
三个函数。类似地, Worker 部分通过 register_worker 接口将模型的信息注册到 Controller ,通过 receive_heart_beat 接口更新模型的信息。
至此,利用 FastChat 搭建模型服务的主要的三个部分的大致部分已经介绍完,也只是大致介绍其工作原理,对于模型的计算, Prompt 的构造等其他涉及到算法的具体部分则没有涉及。
3. 其他
在上述的框架的搭建过程中,使用到了诸如 threading 库, FastAPI 和 Uvicorn 库,这些库对于搭建高效的服务框架有着重要的作用。
3.1. threading
threading 库是 Python 的线程模型,利用 threading 库可以轻松实现多线程任务。在 init_heart_beat()
函数内,使用到了threading 库,如下代码:
def init_heart_beat(self):
# 注册到 Controller
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=heart_beat_worker,
args=(self,),
daemon=True,
)
self.heart_beat_thread.start()
通过 threading.Thread
初始化一个线程对象,并调用 start()
函数启动线程,线程内执行的函数体便是 heart_beat_worker
,其具体内容如下代码所示:
def heart_beat_worker(obj):
while True:
# 每隔一段时间,发送心跳信息
time.sleep(WORKER_HEART_BEAT_INTERVAL)
obj.send_heart_beat()
3.2. FastAPI 和 Uvicorn
FastAPI[2] 是什么?官方给出的解释是: FastAPI 是一个现代、快速(高性能)的 Web 框架,用于构建基于 Python 的 API 。它是一个开源项目,基于 Starlette 和 Pydantic 库构建而成,提供了强大的功能和高效的性能。而 Uvicorn 是一个快速的 ASGI 服务器。简单来说, FastAPI 是一个构建 API 的 Python 框架,它使用了 Python 的异步编程特性并基于 ASGI(Asynchronous Server Gateway Interface)规范。而 Uvicorn 则是一个高性能的 ASGI 服务器,它可以运行支持 ASGI 的 Web 应用程序,如 FastAPI 。以如下的一个简单的例子介绍如何使用:
from fastapi import FastAPI
import uvicorn
app = FastAPI()
@app.get("/msg")
async def msg():
return {"msg": "Hello World"}
@app.get("/msg/{name}")
async def msg(name: str):
return {"msg": f"Hello World, {name}"}
if __name__ == "__main__":
uvicorn.run(app, host='127.0.0.1', port=9999, log_level="info")
运行上述代码,将在端口 9999 上启动服务,通过 Postman 或者 ApiPost 可测试上述的两个接口。