What Is Unsloth? A Tool That Makes Your AI Smarter in an Hour!

March 22, 2025

Source: Practical Exercise: GRPO with Unsloth

Want to make your AI smarter without burning through VRAM? Unsloth is the answer. It is a tool that accelerates fine-tuning of large language models (LLMs), making training faster and lighter on resources; even the free T4 GPU on Colab can handle it.

Today we will pair Unsloth with GRPO (Group Relative Policy Optimization) to teach Google's Gemma 3 1B model to reason, so it can work through grade-school math problems in a structured way.

The result first: the AI "thinks" before it answers

After fine-tuning, the AI no longer blurts out an answer; it reasons first, then gives the result. For example, ask it to "calculate π" and it outputs:

<reasoning>
π is a mathematical constant that cannot be computed exactly, but it can be approximated, commonly as 3.14 or 3.14159.
</reasoning>
<answer>
3.14
</answer>

That's the kind of logic we want! Below is the full workflow, with every line of code included.

How to pull it off, step by step

1. Install the dependencies

Unsloth speeds up fine-tuning and vLLM speeds up inference; install both first:

pip install unsloth vllm
pip install --upgrade pillow

The free T4 GPU on Colab is enough to run this, or a home PC with sufficient VRAM.
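Optionally, run a quick check of the GPU before loading anything (a minimal sketch, assuming a CUDA runtime; torch is pulled in by the packages above):

import torch

print(torch.cuda.get_device_name(0))  # e.g. "Tesla T4" on free Colab
print(f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB of VRAM")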

2. Load the model

Load Gemma 3 1B with Unsloth's FastLanguageModel, with 4-bit quantization and LoRA:

from unsloth import FastLanguageModel
import torch

max_seq_length = 1024  # increase to support longer reasoning chains
lora_rank = 32  # larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="google/gemma-3-1b-it",
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # set to False for 16-bit
    fast_inference=True,  # enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6  # lower this if you run short on VRAM
)

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # suggested values: 8, 16, 32, 64, 128
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],  # drop the QKVO modules if VRAM is tight
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # enables long-context training
    random_state=3407
)
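To confirm that only a small slice of the model is actually trainable, you can call the standard PEFT helper (this assumes the Unsloth-wrapped model keeps the usual PEFT interface, which it normally does):

model.print_trainable_parameters()
# prints trainable vs. total parameter counts; only the LoRA adapter weights are trainable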

3. Prepare the data

We use the GSM8K dataset (grade-school math problems) and require the AI to answer in a fixed format:

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
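XML_COT_FORMAT isn't used again in this walkthrough, but filling it in once makes the target shape concrete (the reasoning text is just a made-up example):

print(XML_COT_FORMAT.format(
    reasoning="48 clips in April plus half as many in May: 48 + 24 = 72.",
    answer="72",
))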

The data-processing code:

import re
from datasets import load_dataset, Dataset

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1].split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text: return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset("openai/gsm8k", "main")[split]
    data = data.map(
        lambda x: {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question"]}
            ],
            "answer": extract_hash_answer(x["answer"])
        }
    )
    return data

dataset = get_gsm8k_questions()
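A quick peek at one processed row confirms the prompt and answer columns look right (the printed values below are indicative):

sample = dataset[0]
print(sample["prompt"])  # [{'role': 'system', ...}, {'role': 'user', 'content': '...'}]
print(sample["answer"])  # the number after "####", e.g. "72" for the first training question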

4. Define the reward functions

GRPO steers the AI with reward functions; let's define the key ones:

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    q = prompts[0][-1]["content"]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print("-" * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001  # penalize trailing text after </answer>
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001  # same penalty, allowing the final newline
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
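Before training, it helps to poke the reward stack with a hand-written completion and see how the scores come out (a small sketch; the nested list structure mimics what GRPOTrainer passes to each function):

mock_prompts = [[{"role": "user", "content": "What is 2 + 2?"}]]
mock_completions = [[{
    "role": "assistant",
    "content": "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n",
}]]
mock_answer = ["4"]

print(correctness_reward_func(mock_prompts, mock_completions, mock_answer))  # prints its debug block, then [2.0]
print(int_reward_func(mock_completions))            # [0.5]
print(strict_format_reward_func(mock_completions))  # [0.5]
print(soft_format_reward_func(mock_completions))    # [0.0], since '.' in its pattern doesn't cross newlines
print(xmlcount_reward_func(mock_completions))       # [0.5]

During training, GRPOTrainer adds up the scores from all reward functions to get each completion's total reward, so a correct, well-formatted answer earns noticeably more than a correct but sloppy one.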

5. Train

Run 250 steps with GRPOTrainer:

from trl import GRPOConfig, GRPOTrainer

max_prompt_length = 256

training_args = GRPOConfig(
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,  # bump to 4 for smoother training
    num_generations=6,  # reduce if you run short on VRAM
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    max_steps=250,
    save_steps=250,
    max_grad_norm=0.1,
    report_to="none",
    output_dir="outputs"
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[xmlcount_reward_func, soft_format_reward_func, strict_format_reward_func, int_reward_func, correctness_reward_func],
    args=training_args,
    train_dataset=dataset
)

trainer.train()

Be patient: the effect only becomes obvious after 150-200 steps.

6. Test

Once training finishes, give it a try. Note that the generation call below loads the LoRA from disk, so run model.save_lora("grpo_saved_lora") from step 7 first:

from vllm import SamplingParams

text = tokenizer.apply_chat_template(
    [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": "Calculate pi."}],
    tokenize=False,
    add_generation_prompt=True
)

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)
output = model.fast_generate(
    text,
    sampling_params=sampling_params,
    lora_request=model.load_lora("grpo_saved_lora")
)[0].outputs[0].text
print(output)
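Since the model was trained to wrap its result in <answer> tags, you can reuse extract_xml_answer from step 3 to pull out just the final answer:

print(extract_xml_answer(output))  # e.g. "3.14"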

7. Save

First, save the LoRA weights:

model.save_lora("grpo_saved_lora")

Then save the full merged model in 16-bit:

model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")

Push it to the Hugging Face Hub:

model.push_to_hub_merged("your-username/model-name", tokenizer, save_method="merged_16bit", token="your-token")

Or convert it to GGUF for llama.cpp:

model.push_to_hub_gguf("your-username/model-name", tokenizer, quantization_method=["q4_k_m", "q8_0", "q5_k_m"], token="your-token")

Why is this so satisfying?

Unsloth makes fine-tuning fast and resource-friendly, and GRPO uses reward functions to teach the AI to answer both accurately and in the right format. All the code is here; run it yourself and watch the AI go from random answers to structured reasoning. Want to dig deeper? Check out the Unsloth documentation!