選自oxen.ai
作者:Greg Schoeninger
編譯:陳陳、澤南
RTX 3080 移動版能訓練哪種大模型?本文為那些 GPU 資源有限時使用 GRPO 訓練的開發者提供了寶貴的指導。
自 DeepSeek-R1 發布以來,群組相對策略優化(GRPO)因其有效性和易于訓練而成為大型語言模型強化學習的熱門話題。R1 論文展示了如何使用 GRPO 從遵循 LLM(DeepSeek-v3)的基本指令轉變為推理模型(DeepSeek-R1)。
GRPO 是一種在線學習算法(online learning algorithm),它通過使用訓練過程中由訓練模型自身生成的數據來進行迭代改進。GRPO 的目標是最大化生成補全(completions)的優勢函數(advantage),同時確保模型保持在參考策略(reference policy)附近。
本文的目的是幫你節省一些時間,讓你根據硬件預算選擇合適的模型大小。在開始微調時,你必須做出的重要決定是選擇模型大小,以及你是執行完全微調還是參數高效微調(PEFT)。
文章作者來自 AI 公司 Oxen.ai 的 CEO Greg Schoeninger。
原文鏈接:https://www.oxen.ai/blog/grpo-vram-requirements-for-the-gpu-poor
作者表示,他發現 trl 庫中已經有一個易于使用的 GRPO 實現,便立刻開始了訓練,使用的硬件是配備了 16GB 顯存的 Nvidia GeForce RTX 3080 的小型筆記本電腦。正如大家可能遇到的問題,作者發現示例代碼中的參數設置導致了一個巨大的顯存不足(OOM,out of memory )錯誤。
- torch
- OutOfMemoryError
- CUDA
- out
- of memory
- Tried
- to allocate
- 1.90
- GiB
- GPU
- 0
- has a total capacity of
- GiB
- of which
- 1.28
- GiB
- is
- free
- Including
- non
- PyTorch
- memory
- this
- process has
- GiB
- memory
- in
- use
- Of
- the allocated memory
- GiB
- is
- allocated
- by
- PyTorch
- and
- 2.41
- GiB
- is
- reserved
- by
- PyTorch
- but unallocated
- If
- reserved but unallocated memory
- is
- large
- try
- setting PYTORCH_CUDA_ALLOC_CONF
- expandable_segments
- True
- to avoid fragmentation
- See
- documentation
- for
- Memory
- Management
- //pytorch.org/docs/stable/notes/cuda.html#environment-variables)
實際使用情況
作者表示,他們進行了一系列實驗,以確定訓練各種大小的模型所需的顯存(VRAM)要求。參數數量從 5 億到 140 億不等,他們比較了權重的完全微調與參數高效微調(使用 LoRA),所有訓練運行都在英偉達 H100 上完成,因此這里的 OOM 意味著 >80GB 的 VRAM。
在表格中,你可以找到 GSM8K 數據集上訓練的前 100 步中的峰值內存使用情況。用于實驗的模型是:
所有實驗均使用 Shadeform 的 GPU 市場完成,因此每次實驗只需要花費幾美元 H100。
實驗結果表明,內存需求隨著模型大小和訓練方式的不同而顯著變化。例如,全參數微調比 PEFT 需要更多的內存。
為什么 GRPO 對內存需求較高
這要從 GRPO 的原理說起,這是它的流程圖。
GRPO 對內存需求較高的原因在于,其內部涉及多個模型,并且在訓練數據中每個查詢會產生多個輸出。上圖中的策略模型、參考模型和獎勵模型各自都是一個需要進行推理的 LLM。(盡管從技術上講,獎勵模型可能不需要參數化,可以只是一個 Python 函數或正則表達式,但不影響 GRPO 對內存的高需求。)
為什么 8-Bit 優化和梯度檢查點有助于減少內存占用?
通常來講,訓練一個大型語言模型需要在內存中存儲三種主要類型的信息:模型參數、模型學習所需的梯度、優化器的跟蹤數據。
對上述內容我們可以這樣理解:如果模型的參數占用了 X 的空間,那么梯度也會占用大約相同的空間。然后,像 AdamW 這樣的優化器需要更多的空間,因為它們就像一個記錄員,跟蹤最近的更新歷史,以便更好地決定未來的優化。
為了減輕這種內存負擔,通常采用兩種技術:
- 首先,可以使用像 AdamW 這樣的 8-bit 優化器版本,它們能更高效地存儲跟蹤數據,同時仍保持良好的性能 —— 類似于壓縮照片可以節省空間,同時保留大部分圖像質量;
- 其次,使用梯度檢查點技術,這就像在訓練過程中拍攝快照,而不是記錄所有內容。雖然這會使訓練速度減慢約 20-30%,但它顯著減少了內存使用。
結合這些技術,即使對 GPU 資源有限的人來說,也能夠訓練更大的模型。
代碼示例
像 trl 這樣的庫已經開始支持 GRPO,使得微調由 transformers 構成的 LLM 變得非常簡單。代碼也非常簡潔,只需將訓練器替換為 GRPOTrainer 并定義一些獎勵即可。GRPO 的最小代碼量大約只有 99 行,如果你使用的是像 meta-llama/Llama-3.2-1B-Instruct 這樣的小型模型和像 openai/GSM8K 這樣的數據集,可以非常快速地啟動。
trl 項目地址:https://github.com/huggingface/trl?ref=ghost.oxen.ai
- import
- torch
- from
- datasets
- import
- load_dataset
- Dataset
- from
- transformers
- import
- AutoTokenizer
- AutoModelForCausalLM
- from
- trl
- import
- GRPOConfig
- GRPOTrainer
- import
- re
- SYSTEM_PROMPT
- Respond in the following format:
- def
- extract_hash_answer
- text
- str
- str
- None
- if
- "####"
- not
- in
- text
- return
- None
- return
- text
- split
- "####"
- 1
- strip
- def
- get_gsm8k_questions
- split
- "train"
- Dataset
- data
- load_dataset
- 'openai/gsm8k'
- 'main'
- split
- data
- data
- map
- lambda
- 'prompt'
- 'role'
- 'system'
- 'content'
- SYSTEM_PROMPT
- },
- 'role'
- 'user'
- 'content'
- 'question'
- ],
- 'answer'
- extract_hash_answer
- 'answer'
- return
- data
- def
- extract_xml_answer
- text
- str
- str
- answer
- text
- split
- 1
- answer
- answer
- split
- ""
- 0
- return
- answer
- strip
- def
- format_reward_func
- completions
- kwargs
- list
- float
- """Reward function that checks if the completion has a specific format."""
- pattern
- r
- "^\n\n$"
- \n.*?\n
- \n.*?\n
- responses
- completion
- 0
- "content"
- for
- completion
- in
- completions
- matches
- re
- match
- pattern
- r
- for
- r
- in
- responses
- return
- 0.5
- if
- match
- else
- 0.0
- for
- match
- in
- matches
- def
- accuracy_reward_func
- prompts
- completions
- answer
- kwargs
- list
- float
- """Reward function that extracts the answer from the xml tags and compares it to the correct answer."""
- responses
- completion
- 0
- 'content'
- for
- completion
- in
- completions
- extracted_responses
- extract_xml_answer
- r
- for
- r
- in
- responses
- return
- 2.0
- if
- r
- a
- else
- 0.0
- for
- r
- a
- in
- zip
- extracted_responses
- answer
- def
- main
- dataset
- get_gsm8k_questions
- model_name
- "meta-llama/Llama-3.2-1B-Instruct"
- model
- AutoModelForCausalLM
- from_pretrained
- model_name
- torch_dtype
- torch
- bfloat16
- attn_implementation
- "flash_attention_2"
- device_map
- None
- to
- "cuda"
- tokenizer
- AutoTokenizer
- from_pretrained
- model_name
- tokenizer
- pad_token
- tokenizer
- eos_token
- training_args
- GRPOConfig
- output_dir
- "output"
- learning_rate
- 5e-6
- adam_beta1
- 0.9
- adam_beta2
- 0.99
- weight_decay
- 0.1
- warmup_ratio
- 0.1
- lr_scheduler_type
- 'cosine'
- logging_steps
- 1
- bf16
- True
- per_device_train_batch_size
- 1
- gradient_accumulation_steps
- 4
- num_generations
- 4
- max_prompt_length
- 256
- max_completion_length
- 786
- num_train_epochs
- 1
- save_steps
- 100
- save_total_limit
- 1
- max_grad_norm
- 0.1
- log_on_each_node
- False
- trainer
- GRPOTrainer
- model
- model
- processing_class
- tokenizer
- reward_funcs
- format_reward_func
- accuracy_reward_func
- ],
- args
- training_args
- train_dataset
- dataset
- trainer
- train
- if
- __name__
- "__main__"
- main
Num Generations 有什么用
Num Generations 是一個超參數,它決定了我們將在訓練數據中對每個查詢采樣多少個補全。然而,這會顯著增加 VRAM 的消耗。
目前有一個開放的 GitHub 問題,可能會幫助解決內存瓶頸問題,可以參考如下鏈接
地址:https://github.com/huggingface/trl/issues/2709?ref=ghost.oxen.ai
對于 num_completions=8,16,64 (DeepSeekMath 論文使用的 64),作者表示,不用再次計算上述所有值,而是使用了 1B 參數模型進行了測試,以顯示內存增長。不過,作者還是建議大家在內存瓶頸得到修復之前使用 num_generations=4,也能獲得不錯的性能。
影響 VRAM 的一些因素
要對所有影響顯存(VRAM)使用的因素進行全面的超參數驗證,需要進行大量的實驗。簡單起見,這里只指出了需要注意的設置,以及實驗中使用的具體數值。
- batch_size=1,由于 GRPO 為每個查詢生成多個響應,batch size 會迅速失控。
- gradient_accumulation_steps=4,優化器是另一個占用大量 VRAM 的地方。此參數決定了我們將存儲的梯度以幫助優化器進行其「爬山」過程。
- num_completions=4,DeepSeekMath 論文中使用了 64。這完全超出了有些人的計算預算。
- max_prompt_length=256,如果你想訓練模型擁有更大上下文的推理能力,將不得不增加 VRAM。GSM8K 的提示相對較小,適合此測試。
- max_completion_length=786,同樣,由于計算注意力的內存有限,推理鏈在這里受到限制。上下文或生成的 token 越多,需要的內存就越大。
- LoRA target_modules=["q_proj", "k_proj", "o_proj", "up_proj", "down_proj"] 在這方面可以嘗試幾種不同的迭代。target_modules="all-linear" 是一種流行的方式,可以從你的 LoRA 中擠出最多的性能(就準確性而言)。
對 VRAM 使用的粗略估算
如果你正在使用 FP16 精度進行訓練,以下是一些簡單的估算方法,可以幫助你了解內存主要用在了哪些地方:
- 模型參數:每個參數占用 2 字節。
- 參考模型參數:每個參數占用 2 字節。
- 梯度:每個參數占用 2 字節。
- 優化器狀態:每個參數占用 8 字節。
- 8 位優化器:每個參數占用 4 字節。
- PEFT:有助于減少梯度的顯存占用。
最后是關于準確率的。作者完成了一個 10 億參數的 Llama 3.2 模型的完整訓練。在應用 GRPO 之前,該模型在保留測試集上達到了約 19% 的準確率,而在經過一個訓練周期后,模型的準確率飆升至約 40.5%。雖然這離 SOTA 水平還差得很遠,但這展示了 GRPO 的強大潛力。
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.