FastChat 框架中的服务解析

1. FastChat 框架概述

随着 ChatGPT 的兴起,出现了各类用于训练、推理大模型的框架, FastChat [1]就是其中之一。 FastChat 是一个开源的用于训练,评估以及部署大模型的框架,本文聚焦在模型的推理部分,从源码的角度介绍其基本的工作原理。

利用 FastChat 框架部署一个完整的模型服务主要分为三个部分,分别为: Controller , Server 以及多个 Worker 。这三者之间的关系如官方给出的下图所示:

输入图片说明

这三个部分之间的关系为:

  • Server 部分用于接收请求,并将不同的请求分发到对应的 Worker 上,对应了上图中蓝色的链路 Data plane;
  • 对于 Server ,如何知道有哪些 Worker (即对应的 IP 地址),这个时候就需要 Controller 部分, Controller 部分不但需要与 Worker 通信,还需要与 Server 通信:
    • 与 Worker 部分的通信,是将 Worker 的信息注册到 Controller 部分的对应数据结构中, Controller 还需要记录每一个 Worker 的健康状况;
    • 与 Server 部分的通信,则是让 Server 查找到对应请求的 Worker 信息,并将请求转发给具体的 Worker。
  • Worker 是参与真实的模型计算的模块。

在参考文献[1]中提供了多种的部署方式,类似上述的三个部分的部署代码如下所示:

# Controller
python3 -m fastchat.serve.controller
# Worker
# worker 0 
CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.5 --controller http://localhost:21001 --port 31000 --worker http://localhost:31000 
# worker 1 
CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path lmsys/fastchat-t5-3b-v1.0 --controller http://localhost:21001 --port 31001 --worker http://localhost:31001 
# Web UI 
python3 -m fastchat.serve.gradio_web_server

2. 三个重要的部分

在 FastChat 框架中,对于模型的推理,主要是由三个部分组成,分别为: Server , Worker 和 Controller 。这三个部分一起就组成了完整的大模型服务框架。 Server 部分负责接收用户的请求,并将请求转发给对应的 Worker,在计算完成后,将 Worker 返回的结果返回给用户; Worker 部分是整个模型计算的核心部分; Controller 部分用于保存和更新模型的信息。

2.1. Server

Server 部分有不同的 Server 方式,在 FastChat 中提供了两种,分别为:

  • Web GUI ,其对应的启动脚本为: gradio_web_server.py
  • Web API ,其对应的启动脚本为: open_api_server.py

第一种启动后会有一个利用 gradio 生成的页面。我们以第一种为例。 Gradio web server 的启动方式为:

python3 -m fastchat.serve.gradio_web_server

在 Server 模块中,需要处理的工作包括:1. 构建 UI ;2. 与 Controller 通信,用于取得当前的模型以及对应模型的 Worker 地址;3. 处理请求 Worker 以及处理 Worker 的返回,其过程如下图所示:

输入图片说明

其中,与 Controller 通信包括两个方面,一个是在构建 UI 之前需要取到当前已有的模型列表,供用户在页面上选择对应的模型,对应的函数为 get_model_list() ,具体代码如下:

def get_model_list(
	controller_url, register_openai_compatible_models, add_chatgpt, add_claude, add_palm 
): 
	# Controller 地址不为空时取 controller 模块中的对应接口 
	if controller_url: 
		# 1. 先刷新 
		ret = requests.post(controller_url + "/refresh_all_workers") 
		assert ret.status_code == 200 
		# 2. 取list 
		ret = requests.post(controller_url + "/list_models") 
		models = ret.json()["models"] 
	else: 
		models = [] 
	
	# Add API providers 
	if register_openai_compatible_models: 
		global openai_compatible_models_info 
		openai_compatible_models_info = json.load( 
			open(register_openai_compatible_models) 
		) 
		models += list(openai_compatible_models_info.keys()) 
	
	if add_chatgpt: 
		models += ["gpt-3.5-turbo", "gpt-4"] 
	if add_claude: 
		models += ["claude-2", "claude-instant-1"] 
	if add_palm: 
		models += ["palm-2"] 
	# 去重 
	models = list(set(models)) 
	
	priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)} 
	models.sort(key=lambda x: priority.get(x, x)) 
	logger.info(f"Models: {models}") 
	return models

第二个是在请求的过程中,根据模型名称获取模型对应的 Worker 的地址,对应的代码在 bot_response() 函数内,其具体代码如下所示:

# Query worker address 
# 请求获取模型对应的 Worker 地址 
ret = requests.post( 
	controller_url + "/get_worker_address", json={"model": model_name} 
) 
worker_addr = ret.json()["address"] 
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

由上可知, Server 与 Controller 通信,涉及到三个接口,主要关系如下图所示:

输入图片说明

Server 与 Worker 的交互主要是在获取到模型对应的地址后,将请求发到对应的接口上,对应的代码在 model_worker_stream_iter() 函数内,其具体代码如下所示:

# Stream output 
response = requests.post( 
	worker_addr + "/worker_generate_stream", 
	headers=headers, 
	json=gen_params, 
	stream=True, 
	timeout=WORKER_API_TIMEOUT, 
)

可见,Server 部分请求的是 Worker 中的 worker_generate_stream 接口。

2.2. Worker

Worker 模块的主要功能是加载模型以及模型的推理,其启动的命令如下:

# worker 0 
CUDA_VISIBLE_DEVICES=0 python3 -m fastchat.serve.model_worker --model-path lmsys/vicuna-7b-v1.3 --controller http://localhost:21001 --port 31000 --worker http://localhost:31000 
# worker 1 
CUDA_VISIBLE_DEVICES=1 python3 -m fastchat.serve.model_worker --model-path lmsys/fastchat-t5-3b-v1.0 --controller http://localhost:21001 --port 31001 --worker http://localhost:31001

model_worker.py 文件中,Worker 模块的主要函数为create_model_worker() ,该函数即完成了模型的加载以及模型的推理。对于模型加载和推理主要的代码在 BaseModelWorkerModelWorker 两个类中,其中, ModelWorker 继承自 BaseModelWorker ,在 ModelWorker 类的实例化过程中,实现了两件事,第一是加载模型,第二是初始化心跳。首先,我们看下模型的加载,模型加载是在 ModelWorker 类的 __init__ 函数中,其具体代码如下:

logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...") 
self.model, self.tokenizer = load_model( 
	model_path, 
	device=device, 
	num_gpus=num_gpus, 
	max_gpu_memory=max_gpu_memory, 
	dtype=dtype, 
	load_8bit=load_8bit, 
	cpu_offloading=cpu_offloading, 
	gptq_config=gptq_config, 
	awq_config=awq_config, 
) 
self.device = device 
if self.tokenizer.pad_token == None: 
	self.tokenizer.pad_token = self.tokenizer.eos_token

Worker 中除了加载模型之外,还涉及到处理请求,Server 模块发过来的请求会对应到 worker_generate_stream 接口上,其具体的代码如下所示:

@app.post("/worker_generate_stream") 
async def api_generate_stream(request: Request): 
	# 获取请求参数 
	params = await request.json() 
	await acquire_worker_semaphore() 
	# 取得 Worker 生成响应的具体的函数 
	generator = worker.generate_stream_gate(params) 
	background_tasks = create_background_tasks() 
	return StreamingResponse(generator, background=background_tasks)

最终调用的是 inference.py 中的 generate_stream() 函数,其具体函数如下:

@torch.inference_mode() 
def generate_stream( 
	model, 
	tokenizer, 
	params: Dict, 
	device: str, 
	context_len: int, 
	stream_interval: int = 2, 
	judge_sent_end: bool = False, 
):

除了提供了 worker_generate_stream 接口外,还提了以下的几个接口:

  • worker_generate
  • worker_get_embeddings
  • worker_get_status
  • count_token
  • worker_get_conv_template
  • model_details

除了上述与具体任务相关的处理外, Worker 还需要处理与 Controller 之间的关系,包括注册,发心跳等,主要涉及到如下的几个函数:

# 初始化 
def init_heart_beat(self): 
	# 注册到 Controller 
	self.register_to_controller() 
	self.heart_beat_thread = threading.Thread( 
		target=heart_beat_worker, 
		args=(self,), 
		daemon=True, 
	) 
	self.heart_beat_thread.start() 

# 注册到 Controller 
def register_to_controller(self): 
	logger.info("Register to controller") 

	url = self.controller_addr + "/register_worker" 
	# 注册信息 
	data = { 
		"worker_name": self.worker_addr, 
		"check_heart_beat": True, 
		"worker_status": self.get_status(), 
	} 
	r = requests.post(url, json=data) 
	assert r.status_code == 200 

# 发送心跳信息 
def send_heart_beat(self): 
	logger.info( 
		f"Send heart beat. Models: {self.model_names}. " 
		f"Semaphore: {pretty_print_semaphore(self.semaphore)}. " 
		f"call_ct: {self.call_ct}. " 
		f"worker_id: {self.worker_id}. " 
	) 

	url = self.controller_addr + "/receive_heart_beat" 

	# 每隔 5 秒发送一次心跳信息 
	while True: 
		try: 
			ret = requests.post( 
				url, 
				json={ 
					"worker_name": self.worker_addr, 
					"queue_length": self.get_queue_length(), 
				}, 
				timeout=5, 
			) 
			exist = ret.json()["exist"] 
			break 
		except (requests.exceptions.RequestException, KeyError) as e: 
			logger.error(f"heart beat error: {e}") 
			time.sleep(5) 
	
	# 不存在该 Worker 重新注册 
	if not exist: 
		self.register_to_controller()

2.3. Controller

由上述可知,Controller 起到了 Server 和 Worker 之间的桥梁的作用,Controller 的启动命令如下:

python3 -m fastchat.serve.controller

Controller 的作用一方面是存储模型的相关信息,另外一方面是对外提供一些接口。首先看一下 Controller 类的构造,其具体代码如下所示:

class Controller: 
	def __init__(self, dispatch_method: str): 
		# Dict[str -> WorkerInfo] 
		# 使用 dict 存储模型的信息 
		self.worker_info = {} 
		# dispatch_method的默认值是shortest_queue 
		self.dispatch_method = DispatchMethod.from_str(dispatch_method) 
		
		# 线程 
		self.heart_beat_thread = threading.Thread( 
			target=heart_beat_controller, args=(self,) 
		) 
		self.heart_beat_thread.start() 
	# 以下是其他一些具体函数

在Controller中,提供了以下的接口:

  • register_worker
  • refresh_all_workers
  • list_models
  • get_worker_address
  • worker_generate_stream
  • worker_get_status
  • test_connection

其中,在 Server 部分通过 refresh_all_workers 接口刷新所有 Worker ,通过 list_models 接口获取当前所有的 Worker ,通过 get_worker_address 接口根据模型名称获取对应的模型信息,这三个接口的具体代码如下所示:

@app.post("/refresh_all_workers") 
async def refresh_all_workers(): 
	models = controller.refresh_all_workers() 

@app.post("/list_models") 
async def list_models(): 
	models = controller.list_models() 
	return {"models": models} 

@app.post("/get_worker_address") 
async def get_worker_address(request: Request): 
	data = await request.json() 
	addr = controller.get_worker_address(data["model"]) 
	return {"address": addr}

由上代码可见,这三个接口分别调用了 Controller 类中的 refresh_all_workers()list_models()get_worker_address() 三个函数。类似地, Worker 部分通过 register_worker 接口将模型的信息注册到 Controller ,通过 receive_heart_beat 接口更新模型的信息。

至此,利用 FastChat 搭建模型服务的主要的三个部分的大致部分已经介绍完,也只是大致介绍其工作原理,对于模型的计算, Prompt 的构造等其他涉及到算法的具体部分则没有涉及。

3. 其他

在上述的框架的搭建过程中,使用到了诸如 threading 库, FastAPI 和 Uvicorn 库,这些库对于搭建高效的服务框架有着重要的作用。

3.1. threading

threading 库是 Python 的线程模型,利用 threading 库可以轻松实现多线程任务。在 init_heart_beat() 函数内,使用到了threading 库,如下代码:

def init_heart_beat(self): 
	# 注册到 Controller 
	self.register_to_controller() 
	self.heart_beat_thread = threading.Thread( 
		target=heart_beat_worker, 
		args=(self,), 
		daemon=True, 
	) 
	self.heart_beat_thread.start()

通过 threading.Thread 初始化一个线程对象,并调用 start() 函数启动线程,线程内执行的函数体便是 heart_beat_worker ,其具体内容如下代码所示:

def heart_beat_worker(obj): 
	while True: 
		# 每隔一段时间,发送心跳信息 
		time.sleep(WORKER_HEART_BEAT_INTERVAL) 
		obj.send_heart_beat()

3.2. FastAPI 和 Uvicorn

FastAPI[2] 是什么?官方给出的解释是: FastAPI 是一个现代、快速(高性能)的 Web 框架,用于构建基于 Python 的 API 。它是一个开源项目,基于 Starlette 和 Pydantic 库构建而成,提供了强大的功能和高效的性能。而 Uvicorn 是一个快速的 ASGI 服务器。简单来说, FastAPI 是一个构建 API 的 Python 框架,它使用了 Python 的异步编程特性并基于 ASGI(Asynchronous Server Gateway Interface)规范。而 Uvicorn 则是一个高性能的 ASGI 服务器,它可以运行支持 ASGI 的 Web 应用程序,如 FastAPI 。以如下的一个简单的例子介绍如何使用:

from fastapi import FastAPI 
import uvicorn 

app = FastAPI() 

@app.get("/msg") 
async def msg(): 
	return {"msg": "Hello World"} 

@app.get("/msg/{name}") 
async def msg(name: str): 
	return {"msg": f"Hello World, {name}"} 

if __name__ == "__main__": 
	uvicorn.run(app, host='127.0.0.1', port=9999, log_level="info")

运行上述代码,将在端口 9999 上启动服务,通过 Postman 或者 ApiPost 可测试上述的两个接口。

参考文献

[1] https://github.com/lm-sys/FastChat

[2] https://fastapi.tiangolo.com/