What Is Unsloth? The Tool That Makes Your AI Smarter in an Hour!
March 22, 2025
Source: Practical Exercise: GRPO with Unsloth
Want to make an AI smarter without burning through VRAM? Unsloth is the answer. It's a toolkit that accelerates fine-tuning of large language models (LLMs), making training faster and lighter on resources, light enough to run on the free T4 GPU in Google Colab.
Today we pair Unsloth with GRPO (Group Relative Policy Optimization) to teach Google's Gemma 3 1B model to reason, so it can work through grade-school math problems in a structured way.
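A quick aside on what "group relative" means, before we touch the real code. For each question, GRPO samples a group of completions, scores each one with reward functions, and normalizes every reward against the group's mean and standard deviation to get an advantage. The sketch below is my own illustration of that idea, not the actual trl implementation:
# Illustration only: how GRPO turns a group of rewards into advantages.
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # rewards: scores for the num_generations completions of a single prompt
    return (rewards - rewards.mean()) / (rewards.std() + eps)

rewards = torch.tensor([2.5, 0.0, 0.5, 3.0, 0.5, 0.0])  # e.g. summed reward-function scores
print(group_relative_advantages(rewards))  # above-average completions get positive advantage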
The result first: the AI "thinks" before it answers
After fine-tuning, the model no longer blurts out an answer. It reasons first, then responds. Ask it to "calculate pi" and it outputs:
<reasoning>
π is a mathematical constant that cannot be computed exactly, but it can be approximated, commonly as 3.14 or 3.14159.
</reasoning>
<answer>
3.14
</answer>
Pretty satisfying logic! Below is the full workflow, with every line of code included.
How do we get there? Step by step
1. Install dependencies
Unsloth speeds up fine-tuning and vLLM speeds up inference. Install both first:
pip install unsloth vllm
pip install --upgrade pillow
The free T4 GPU on Colab is enough to run this; a home machine works too if it has enough VRAM.
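Optional sanity check (my own addition, not part of the original walkthrough): confirm a CUDA GPU is actually visible before loading anything, since every later step assumes one.
import torch

# Make sure a CUDA GPU (e.g. Colab's free T4) is available
assert torch.cuda.is_available(), "No CUDA GPU found"
print(torch.cuda.get_device_name(0))
print(f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB VRAM")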
2. Load the model
Use Unsloth's FastLanguageModel to load Gemma 3 1B with 4-bit quantization and LoRA:
from unsloth import FastLanguageModel
import torch

max_seq_length = 1024  # increase to allow longer reasoning traces
lora_rank = 32         # higher rank = more capacity, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="google/gemma-3-1b-it",
    max_seq_length=max_seq_length,
    load_in_4bit=True,          # set to False for 16-bit
    fast_inference=True,        # enable vLLM-accelerated generation
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6, # lower this if you run out of VRAM
)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # 8, 16, 32, 64, or 128 are reasonable choices
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],  # drop the QKVO modules if VRAM is tight
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # enables long-context training
    random_state=3407,
)
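To see how small the trainable LoRA update is compared with the frozen 1B base, you can use the standard peft helper (assuming the object returned by get_peft_model behaves like a regular PeftModel, which it normally does):
# Only the LoRA adapter weights are trainable; the base model stays frozen.
model.print_trainable_parameters()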
3. Prepare the data
We use the GSM8K dataset (grade-school math word problems) and require the model to answer in a fixed format:
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
Data-processing code:
import re
from datasets import load_dataset, Dataset

def extract_xml_answer(text: str) -> str:
    # Pull out whatever sits between <answer> and </answer>
    answer = text.split("<answer>")[-1].split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    # GSM8K stores the ground-truth answer after "####"
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset("openai/gsm8k", "main")[split]
    data = data.map(
        lambda x: {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question"]},
            ],
            "answer": extract_hash_answer(x["answer"]),
        }
    )
    return data

dataset = get_gsm8k_questions()
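It's worth printing one processed example to confirm the mapping did what we wanted: the prompt is a two-message chat and the answer is the bare number pulled from GSM8K's "#### ..." line. A small inspection snippet:
# Peek at one processed training example
example = dataset[0]
print(example["prompt"][0]["content"])  # the system prompt
print(example["prompt"][1]["content"])  # the math word problem
print(example["answer"])                # the ground-truth number as a string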
4. Define the reward functions
GRPO steers the model with reward functions, so let's define the key ones:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    # 2.0 if the extracted answer exactly matches the ground truth
    responses = [completion[0]["content"] for completion in completions]
    q = prompts[0][-1]["content"]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print("-" * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
          f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    # 0.5 if the extracted answer is an integer
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    # 0.5 if the whole completion follows the exact <reasoning>/<answer> layout, newlines included
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    # 0.5 if a looser <reasoning>...<answer> pattern matches from the start of the response
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    # Partial credit (up to 0.5) for each correctly placed tag,
    # with a small penalty for trailing text after </answer>
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
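Before training, you can sanity-check these reward functions on a hand-written completion, in the same nested list-of-chat-messages shape the trainer passes them. The fake question and completion below are mine, purely for illustration:
# Hand-rolled example to exercise the reward functions.
fake_prompt = [[{"role": "user", "content": "7 people eat 2 apples each. How many apples?"}]]
fake_completion = [[{"role": "assistant", "content":
    "<reasoning>\n7 people eat 2 apples each, so 7 * 2 = 14.\n</reasoning>\n<answer>\n14\n</answer>\n"}]]

print(correctness_reward_func(fake_prompt, fake_completion, ["14"]))  # [2.0]: extracted answer matches
print(int_reward_func(fake_completion))                               # [0.5]: answer is an integer
print(strict_format_reward_func(fake_completion))                     # [0.5]: exact newline layout matches
print(xmlcount_reward_func(fake_completion))                          # [0.5]: full partial-format credit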
5. Train
Run 250 steps with GRPOTrainer:
from trl import GRPOConfig, GRPOTrainer

max_prompt_length = 256

training_args = GRPOConfig(
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,  # bump to 4 for smoother training
    num_generations=6,              # lower this if you run out of VRAM
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    max_steps=250,
    save_steps=250,
    max_grad_norm=0.1,
    report_to="none",
    output_dir="outputs",
)
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
Be patient: the effect only becomes obvious after about 150-200 steps.
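With logging_steps=1 you can also pull the logged rewards out of trainer.state.log_history afterwards and watch the curve climb. The metric key name ("reward") is an assumption on my part and may differ between trl versions, so check the logged keys first:
# Rough look at how the mean reward evolved over training.
# "reward" is the key GRPOTrainer usually logs; verify with trainer.state.log_history[0].keys().
history = trainer.state.log_history
rewards = [log["reward"] for log in history if "reward" in log]
print(rewards[:5], "...", rewards[-5:])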
6. Test
Once training is done (and the LoRA has been saved with model.save_lora("grpo_saved_lora"), see step 7), give it a try:
from vllm import SamplingParams

text = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "Calculate pi."},
    ],
    tokenize=False,
    add_generation_prompt=True,
)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)
output = model.fast_generate(
    text,
    sampling_params=sampling_params,
    lora_request=model.load_lora("grpo_saved_lora"),  # the adapter saved in step 7
)[0].outputs[0].text
print(output)
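For contrast, you can run the same prompt without loading the trained adapter (lora_request=None, as in Unsloth's own GRPO notebooks) and see how the base model answers without the reasoning/answer structure:
# Same prompt, base model only: no LoRA adapter loaded.
baseline = model.fast_generate(
    text,
    sampling_params=sampling_params,
    lora_request=None,
)[0].outputs[0].text
print(baseline)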
7. Save
Save the LoRA weights first:
model.save_lora("grpo_saved_lora")
Then save a merged 16-bit model:
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
Push it to the Hugging Face Hub:
model.push_to_hub_merged("your-username/model-name", tokenizer, save_method="merged_16bit", token="your-token")
Or convert to GGUF for llama.cpp:
model.push_to_hub_gguf("your-username/model-name", tokenizer, quantization_method=["q4_k_m", "q8_0", "q5_k_m"], token="your-token")
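If you want to double-check the merged 16-bit export, the "model" directory can be loaded back with plain transformers. A quick verification sketch, independent of Unsloth:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the merged checkpoint written by save_pretrained_merged above
merged = AutoModelForCausalLM.from_pretrained("model", torch_dtype="auto", device_map="auto")
merged_tok = AutoTokenizer.from_pretrained("model")
prompt = merged_tok.apply_chat_template(
    [{"role": "user", "content": "What is 12 * 7?"}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = merged_tok(prompt, return_tensors="pt").to(merged.device)
print(merged_tok.decode(merged.generate(**inputs, max_new_tokens=200)[0], skip_special_tokens=True))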
Why is this so satisfying?
Unsloth makes fine-tuning fast and cheap, and GRPO uses reward functions to teach the model to be both accurate and well-formatted. All the code is here, so run it yourself and watch the model go from random answers to structured reasoning. That feeling of progress is hard to beat. Want to dig deeper? Go check out the Unsloth docs!