Flash Attention 2 提升巨大吗?不见得

该文章测试了分别使用xformers与flash attention 2作为后端注意力机制时,vllm的性能差距。
勘误
后续测试中发现30系列即使使用 VLLM_ATTENTION_BACKEND=XFORMERS 环境变量启动 vllm 推理api 服务,并且在启动设置中显示为 XFORMERS,貌似实际使用的 attention 后端也经过了某种加速(很大概率就是 FLASH_ATTN2),所以重新做了一个测试,实际发现 flash attention 2 确实有不错的提升,详细结果请见新文章。
下面是本文原内容,未作修改:
前言
写这篇文章的起因是,一直没看到一个对于使用xformers与flash attention 2的推理框架之间性能差距任何定量的分析,你能查到的资料无一不是说fa2真的好棒棒,显存占用减少xx%,attention计算加快xx%。没有一个测试是真正把使用这两个attention计算的推理引擎性能(例如,vllm)真正做一个公平的比较。我的主力设备是2080ti(环保主义者~~,绝对不是因为穷~~),他在很久前(一年多)就已经在fa2的支持的计划列表了,但是随着rtx50系列的发售,估计以后都不可能支持了。
我特别想知道我因为使用这样老旧的设备到底损失了多少潜在的提升,于是简单测一下这玩意在真实的llm推理环境下,到底有没有效果。
提前声明,测试结果只是在我机器上的真实结果,可能有局限性,如果你有不同看法或见解欢迎联系我。
测试环境
硬件环境
# nvidia-smi
Thu Feb 6 15:37:54 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.65 Driver Version: 566.07 CUDA Version: 12.7 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3070 ... On | 00000000:01:00.0 On | N/A |
| N/A 49C P8 14W / 95W | 672MiB / 8192MiB | 27% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
软件环境
# pip show vllm
Name: vllm
Version: 0.7.1
Summary: A high-throughput and memory-efficient inference and serving engine for LLMs
Home-page: https://github.com/vllm-project/vllm
Author: vLLM Team
Author-email:
License: Apache 2.0
Location: /root/miniconda3/envs/vllm/lib/python3.11/site-packages
Requires: aiohttp, blake3, cloudpickle, compressed-tensors, depyf, einops, fastapi, filelock, gguf, importlib_metadata, lark, lm-format-enforcer, mistral_common, msgspec, numpy, nvidia-ml-py, openai, outlines, partial-json-parser, pillow, prometheus-fastapi-instrumentator, prometheus_client, protobuf, psutil, py-cpuinfo, pydantic, pyyaml, pyzmq, ray, requests, sentencepiece, tiktoken, tokenizers, torch, torchaudio, torchvision, tqdm, transformers, typing_extensions, uvicorn, xformers, xgrammar
Required-by:
测试工具
https://github.com/Yoosu-L/llmapibenchmark
测试结果
XFORMERS
export VLLM_ATTENTION_BACKEND=XFORMERS
python -m vllm.entrypoints.openai.api_server --model=/root/llm/models/Qwen/Qwen2.5-7B-Instruct-AWQ --served-model=Qwen2.5-7B-Instruct-AWQ --dtype=float16 --tensor-parallel-size=1 --trust-remote-code --host=0.0.0.0 --port=8008 --gpu-memory-utilization=0.9 --max-model-len=5000
short input
Input Tokens: 45
Output Tokens: 512
Test Model: Qwen2.5-7B-Instruct-AWQ
Latency: 2.20 ms
Concurrency | Generation Throughput (tokens/s) | Prompt Throughput (tokens/s) | Min TTFT (s) | Max TTFT (s) |
---|---|---|---|---|
1 | 58.49 | 846.81 | 0.05 | 0.05 |
2 | 114.09 | 989.94 | 0.08 | 0.09 |
4 | 222.62 | 1193.99 | 0.11 | 0.15 |
8 | 414.35 | 1479.76 | 0.11 | 0.24 |
16 | 752.26 | 1543.29 | 0.13 | 0.47 |
32 | 653.94 | 1625.07 | 0.14 | 0.89 |
long input
Input Tokens: 2771
Output Tokens: 512
Test Model: Qwen2.5-7B-Instruct-AWQ
Latency: 3.60 ms
Concurrency | Generation Throughput (tokens/s) | Prompt Throughput (tokens/s) | Min TTFT (s) | Max TTFT (s) |
---|---|---|---|---|
1 | 45.64 | 1767.91 | 1.62 | 1.62 |
2 | 77.52 | 1743.44 | 1.67 | 3.28 |
4 | 71.76 | 1763.34 | 1.69 | 6.48 |
FLASH_ATTN
export VLLM_ATTENTION_BACKEND=FLASH_ATTN
python -m vllm.entrypoints.openai.api_server --model=/root/llm/models/Qwen/Qwen2.5-7B-Instruct-AWQ --served-model=Qwen2.5-7B-Instruct-AWQ --dtype=float16 --tensor-parallel-size=1 --trust-remote-code --host=0.0.0.0 --port=8008 --gpu-memory-utilization=0.9 --max-model-len=5000
short input
Input Tokens: 45
Output Tokens: 512
Test Model: Qwen2.5-7B-Instruct-AWQ
Latency: 3.00 ms
Concurrency | Generation Throughput (tokens/s) | Prompt Throughput (tokens/s) | Min TTFT (s) | Max TTFT (s) |
---|---|---|---|---|
1 | 60.04 | 648.85 | 0.07 | 0.07 |
2 | 118.09 | 804.13 | 0.09 | 0.11 |
4 | 229.75 | 1030.40 | 0.13 | 0.17 |
8 | 431.84 | 1384.16 | 0.13 | 0.26 |
16 | 730.86 | 1538.19 | 0.13 | 0.47 |
32 | 692.52 | 1609.80 | 0.14 | 0.89 |
long input
Input Tokens: 2796
Output Tokens: 512
Test Model: Qwen2.5-7B-Instruct-AWQ
Latency: 3.20 ms
Concurrency | Generation Throughput (tokens/s) | Prompt Throughput (tokens/s) | Min TTFT (s) | Max TTFT (s) |
---|---|---|---|---|
1 | 44.18 | 1679.05 | 1.69 | 1.69 |
2 | 78.83 | 1654.61 | 1.74 | 3.41 |
4 | 75.70 | 1636.30 | 1.74 | 6.91 |
结论
看到这个结果其实没有一点意外,因为根据硬件规格、模型参数量以及别人使用H100 rtx4090放出来的推理速度的数据来看,就隐隐猜到是这样。
不管是从prompt处理速度还是生成速度来看,提升不能说是一点没有,只能说是微乎其微,就连显存占用也是一样,在batch size>16后都是几乎占满gpu内存。我们不能下定论说fa2真的没用,也许是我哪些参数搞错了,也许只是在我这张30系列显卡上没有效果(但他确实是支持fa2的)。
这篇测试对于大多数人是没有意义的,你只要选择更好的显卡就够了,新的特性总会有人为你支持,较真在意这些新特性是否有用的人总是少数。不过我更希望是我哪里弄错了,毕竟老显卡总有需要退役的时候,谁不希望新的更好呢。