奥多码,点击查看详情 97CDN云盾,点击查看详情

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100

       
广告2021-06-03到期2021-07-03广告2021-06-03到期2021-07-03
       
广告2021-06-03到期2021-07-03广告2021-06-03到期2021-07-03
随着 AI 模型的参数量越来越大,对算力的需求也水涨船高。
比如最近,Llama-3.1 登上了最强开源大模型的宝座,但超大杯 405B 版本的内存就高达 900 多 GB,这对算力构成了更加苛刻的挑战。
如何降低算力的使用成本和使用门槛,已经成为许多公司寻求突破的关键。Felafax 就是其中的一家创业公司,致力于简化 AI 训练集群的搭建流程。
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
Nikhil Sonti 和 Nikhin Sonti 创立了 Felafy E d L } Gax,他们的口号是在构建开源 AI 平台,为下一代 AI 硬件服务,将机器学习的训练成本降低 30%。
与英伟达相比,AMD 的 GPU,尤其是 MI300X 系列,提供了更高的性价比,按每美元计算,其性能表现更为出色。
最近,Felafax 的联合创始人 Nikhil Sonti 发布( * Y p d f了一篇博客,详细分享了如何通过 8 张 AMD MI300X GPU 和 JAX 微调 LLaMA 3.1 405B\ 6 ^ 2 0 H k 模型的方法,所有代码现已开源。

微调大模型,amd mi300x就够了!跟着这篇博客微调llama 3.1 405b,效果媲美h100

Github 链接:https:\ P 5 S//github.com/felafax/felafax
本站对博客内容进行了不改变原意的编译、整理,以下是博客内容:
JAX 尤其适合非英伟达硬件
J+ c H d 4 _ D e BAX 是一个强大的机器学习库,结合了类似 NumPy 的 API、自动微6 W l 0 _分功能以及 Google 的 XLA 编q / – U L P译器。它在模型并行化方面提供了优秀的 API,因此非常适合像 LLaMA 3.1 405B 这样的6 \ z ; X C K ,超大模型训练。
在使用 AMDC J e 硬件时,JAX 有几个明显的优势:
  • 多硬件并行支持:JAX 采用 XLA(加速线性代数)编译器,将计算编译为硬件无关的中间表示(HLO),这意味着同样的 JA, # N ! SX 代码无需修改便可高效运行在不同硬件后端,t e d v Xo l p括 AMD} ( ` $ x Y 9 GPU。
  • 独立于底层硬件:XLA 编译器的优化策略是通用的,不针对某个特定的硬件平台。这使得任何支持 XLA 的硬件设备(如 CPU、GPU、TPU)都能受益于这些优化,获/ k k o U F得更好的性能表现。
  • 极高的适应性:从 NVIDIA 转移到 AMD(或其他硬件)时,JAX 只需做极少的代码改动。而相较之下,PyTorch 与英伟达的 CUDA 生态系统紧密耦合,迁移过程相对复杂。
因此,JAX 成为了我们在非英伟达硬件上的最佳选择。
拉取 Docker 镜像:
docker pull rocm/jax:latest
登录后复制
启动 Docker 容器:
# Pull the Docker Image:docker pull rocm/jax:latest # Start the Docker Container:docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest# Verify the Installation: python3 -c 'import jax; print(jax.devices())'
登录后复制\ Y L d q
验证安装
python3 -c 'import jax; print (jax.devices ())'
登录后复制
训练使用了一个配备了 8 张 AMD MI300x GPU 的 AMD 节点。每张 MI300x 拥有 192GB 的 HBM3 内存,性能表现与最新的英伟达 H100 GPU 相比非常出色。g K P ( 1 j
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
与英伟达 H100 的比较,来源:TensorWave
训练 LLaMA 405B:性能与可扩展性
使用 JAX,可以成功地在 AMDA , A o . T GPU 上训练 LZ ; H b gLaMA 405Bo c z s * } W . = 模型。我们使用 LoRw O _ 1A 微调,将所有模型权重和 LoRA 参数都设为 bf( g \ Lloat16,LoRA rank 设为 8,LoRA alpha 设为 16:
  • ( u ; o +型大小:LLaMA 模型S D Y的权重占用了约 800GB 的显存。
  • LoRA 权重 + 优化器状态:大约占用了 400GB 的显存。
  • 显存总使用量:占总显存的 77%,约 1200GB。
  • 限制:由于 405B 模型的规模过大,batch 大小和序列长度的空间有限,使用的 batcm 8 ^ h 9 Wh size 为 16,序列长度为 64。
  • JIT 编译:由于空间限制,无法运行 JIT 编译版本;它可能需要比急切模式/ = r = o ~ 6 v ,稍多的空间。
  • 训练速度:使用 JAX 急切模式,约为 35 toke, z B 1 \ 1 8 Zns / 秒。
  • 内存效率:稳定在约 70% 左右。
  • 扩展性:在 8 张 GPU 上,使用 JAX 的扩展性接近线性。
由于硬件和显存的限制,我们无法运行 JIT 编D \ B L n O o C译版本的 405B 模型,整个训练过程是在 JAX 的急切模式下执行的,因此还有很大的进步空间。– I e K f
下图中显示了在一次微调训练步骤中,8 张 GPU 的显存利用率和 rocm-smi 输出:
GPU 利用率:
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
显存利用率:
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
rocm-smi 输出:

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100

训练设置
将 LLaMA 3.1 从 PyT{ q b X ] ) Z B }orch 移植到 JAX
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
此前,Nikhil Sonti 分享过如何将 LLaMA 3.1 从 PyTorchj P q [ [ e G B = 移植到 JAX。他指出,目前 90% 的大型语言模型R . C i q(LLM)都运行在 NVIDIA GPU 上,但实际上还有一些同样强大且: t D @ Z Y性价比更高Y F G ! $的替代方案。例如,在 Google TPU 上训练和部署 Llama 3.1 的I . h J成本比 NVIDIA GPU 低约 30%。
然而,支持非 NVIDIA 硬件的开发工具较为匮乏。Sonti 最初尝试使用 PyTorch XLA 在 TPU 上训练 Llama 3.1,但过程并不顺利。XY ^ X % i ( ) .LA 与 PyTorch 的集成不够完善,缺少一些关键的库(如 bitsandbytes 无法正常运行),同时还遇到了一些难以解决的 HuggingFace 错误。
为此,他决定调整Z $ 6 7 O % @策略,将 Llama 3.1 从 PyTorch 移植到 JAX,成功解决了这些问题。Sonti| | D – G V r s 还录制了详细的教程视l ( * ! i Q频,并开源了所有代码:

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100

  • 方法演示:https://dub.sh/felafax-demo
  • 代码仓库:https://github.com/felafax/felafax
加载模型,并把1 x T ; H L : . s模型参数分片
处理像 LLaMA 405B 这样的超大模型,需要在多个设备之间高效地进行参数分片。以下是如何通过 JAX 实现这一点的。
在 JAX 中进行参数分片
为了将巨大的 LH ; r r W P * QLaMA 405B 模型高效地分布到 8 张 AMD GPU 上,需要使用 JAX 的设备网格(device mesh)功能。
部署代码:https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/traiK C A a / ;ner_enD O c 2 S R }gine/jax_utils.py#L69
JAX 的设备网格可以帮助我们把可用的设备组织成一个网格,让_ 9 * Y我们可以指定如何把模型的参数和计算分配到不同的 GPU 上。
在本文的设置中,需要创建一个形状为(1, 8, 1)的网格,并将轴分别命名为数据并行(dp)、全分片数T ] ( s 5 I o N j据并行(fsdp)和模型并行(mp)。然后,为模型的每个张量定义特定的分片规则,指定这些维度如何沿着这些网格轴进行分片。
DEVICES = jax.devices () DEVICE_COUNT = len (DEVICES) DEVICE_MESH = mesh_utils.create_device_mesh ((1, 8, 1)) MESH = Mesh (devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))
登录后复制
可视化分片
可以使~ Z $ 2 ~ . ( / v用以下代码来可视化分片结果,从而方便地验证分片规则是否按预期应用。
jax.debug.visualize_array_sharding
登录后复制
分片规则
模型不同组件的分片规则如下所示:
  • 参数如何分片:
参数要在 8 个 GPU9 P m 7 { ` 之间分配y { h x E。例如,LM head(lm_head/kernel)张量有两个轴,按照 PS (“fsdp”8 k ) 8 4 – I D 7, “f X b . : z M v Pmp”) 进行分片。在本例中是 8 和 1,因此可以看到该张量在第一个轴上沿着 8 个 GPU 被拆分。
  • Non-Replicj 9 & , . O Mated 参数:
没有任何分片规范的参数会在所有设备上进行复制。例如,层归一化(attention_norm/kernel 和 ffn_norm/kernel)没有设置分片规范,是 PS (None)。
应用分片函数
在加载模型时,使用以下分片函数逐步对模型权重进行分片:
def make_shard_and_gather_fns (partition_specs):def make_shard_fn (partition_spec):out_sharding = NamedSharding (mesh, partition_spec)def shard_fn (tensor):return jax.device_put (tensor, out_sharding).block_until_ready ()return shard_fnshard_fns = jax.tree_util.tree_map (make_shard_fn, partition_specs)return shard_fns# Create shard functions based on partitioning rulesshard_fns = make_shard_and_gather_fns (partitioning_rules)
登录后复制
这使得我们能够将每个参\ . c & r =数放置在指定的设备上,并按照设定的分片进行处理。
分片训练 Batch
最初,训练 Batch 是正常创建的,但在输入模型之前,需要按照下面的代码在 GPU 上进行分片:
train_batch = jax.device_put ( train_batch,NamedSharding (self.mesh, PS ("dp", "fsdp")))
登录后@ ! E n I y复制
在这里,我m X F Y &们指定训练 Batch 应该在 “dp” 和 “fsdp” 轴上进行分片,在本例中分别对应于被分成 1 和 8 份,如果把结果可视化出来,如下所示:
分片前:
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
在调用 jax.device_put 之后:
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
加入 LoRA
LoRA 通过将权重更新分解为低秩矩阵,减少了可训练参数的数量,这对于微调大型模型特别有效。以下是在 AMD GPU 上微调 Llama 3.1-405 的 LoRA 的要点:
  • 将 LoRA 参数(lora_a 和 lora_b)与主模型参数分开。
  • 使用 jax.lax.stopu ? G s A_gradient (kernel) 来防止对主x & 2 I模型权重的更新。
  • 使用 lax.dot_general 进行快速、精确控制的矩阵运算。
  • LoRA 输出在添加到主输出之前会被缩放为 (self.lora_alpha/self.lora_rank)。
LoRADense 层
在此设定一个自定义的 LoRADense 层,该层集成了 LoRA 参数:
class LoRADense (nn.Module):features: intlora_rank: int = 8lora_alpha: float = 16.0@nn.compactdef __call__(self, inputs: Any) -> Any:# Original kernel parameter (frozen)kernel = self.param ('kernel', ...)y = lax.dot_general (inputs, jax.lax.stop_gradient (kernel), ...)# LoRA parameters (trainable)lora_a = self.variable ('lora_params', 'lora_a', ..., ...)lora_b = self.variable ('lora_params', 'lora_b', ..., ...)# Compute LoRA outputlora_output = lax.dot_general (inputs, lora_a.value, ...)lora_output = lax.dot_general (lora_output, lora_b.value, ...)# Combine original output with LoRA modificationsy += (self.lora_alpha/self.lora_rank) * lora_outputreturn y.astype (self.dtype)
登录后复制
分片 LoRA 参数
为了高效地在设备之间分配 LoRA 参数,我们也通过 JAX 设定了分A l B片规则,这确保了 LoRA 参数与主模型参数的分片一致,优化了内存使用和计算效率。
LoRA A matrices (lora_a)
登录后复制
LoRA A 矩阵(lora_a)
  • 分片规则:PS (“fsdp”, “mJ / @p”)
  • 可视化结果:如下图所示,lora_a 参数被分片为 (8, 1),这意味着第一个轴在 8 个设备上进行分片(”fsdm ] 3 i # Q i Lp” 轴),而第二个轴未进行分片。
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
LoRA B 矩阵(lora_b)
  • 分片规w 2 ) , u P k [ O则:PS (“mp”, “fsdp”)
  • 可视化结A 7 f S n / 6 &果:如下图所示,lora_b 参数被分片为 (1, 8),这意味着第二个轴在 8 个设备上进行分U ! 1 Z l * l片(fsdp 轴),而第一个轴未进行分片。
微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100
这种分片策略优化了参数的分配,减少了通信开销,并在训练过程中z | a @ 2 Y ` 1 7增强了并行性。它确保每个设备仅持有一部分 LoRA 参数,使得大模型如 LLaMA 405B 的高效扩展成为可能。
仅更新 LoRA 参数
为了优化训练,在微调 LLaMA 405B 模型,只计算 LoRA 参数的梯度,保持主模型参数j A & m n P B Z Z不变。这个方法减少了内存使用,并加速了训练,因为只更新较少的参数n \ ~ [ M D z W。可以移步w z x W k P V h S GitHubS j $ 仓库,查看实现细节。
在训练过程中,每一步都涉及将一批输入数据通} 0 L N u y r l过模型进行处} J 7 n理。由于只有 LoRA 参数是可训练的,因此模型的预测x g ] = M ] ) –和计算的损失仅依赖于这些参数,然后对 LoRA 参数进行反向传播。只更新这些参\ D t数简化了训练过程,使得在多个 GPU 上高效微调像 LLaMA 405B 这样的大型Q Y [模型成为可能。
更多研究细节,请参考原博客。

以上就是微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B,效果媲美H100的详细内容!

本文由【好易之】整理自网络!
原创文章,作者:【好易之】如转载请注明出处:https://www.zhengjiaxi.com/zxwd/itzx/117965.html
如有侵权,请邮件联系 aoduoye@qq.com 删除。
本站发布的文章及附件仅限用于学习和研究目的;不得将上述内容用于商业或非法用途,否则后果请用户自负。
本站信息来自网络,版权争议与本站无关。您必须在下载后的24个小时之内,从您的电脑中彻底删除上述内容。
如果您喜欢该程序,请支持正版软件,购买注册,得到更好的正版服务。如有侵权请邮件与我们联系处理。
(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
好易之的头像好易之站长
上一篇 2024-12-25 13:24
下一篇 2024-12-25 13:24

相关推荐

发表回复

登录后才能评论

联系我们

400-800-8888

在线咨询:点击这里给我发消息

 

工作时间:周一至周五,9:30-18:30,节假日休息

关注公众号
请查看头部文章来源地址!本站所有内容均为互联网收集整理和网友上传。仅限于学习研究,切勿用于商业用途。否则由此引发的法律纠纷及连带责任本站概不承担。