首页
社区
课程
招聘
[原创]大模型微调分享
发表于: 2026-5-17 21:16 1419

[原创]大模型微调分享

2026-5-17 21:16
1419

前言

由于在针对单片机和freertos类系统做逆向分析时,IDA中完全看不到任何符号,所以我想到了使用现有的qwen3-coder模型做微调以适应垂类任务

本文章全部基于我自己上传的unsloth项目制作,固如有不同架构,请另寻不同方法

开始

准备工作

本地需要安装docker,但是PGX上已经集成好了所有的环境,所以第一步应该是使用 apt 等命令更新所有可以更新的包,由于PGX机器的系统是NVIDIA定制的,所以源用的也是NVIDIA的软件源

代码部分

完整的微调代码很长,我放到最后了,这里只讲重要的部分

基座模型选择

选择基座模型时,我选择的是 unsloth/Qwen3-Coder-30B-A3B-Instruct 

# =========================
# 1) 选择Qwen3 Coder基座模型
#    说明:你可以改成你要微调的Qwen3 Coder具体型号
#    常见:Qwen/Qwen3-Coder-7B / 14B / 32B 等(以HF实际存在为准)
# =========================
base_model = "unsloth/Qwen3-Coder-30B-A3B-Instruct"   # ← 按需修改
output_dir = "./qwen3-coder-decompilebench-lora"

主要是因为在我做微调的时候,unsloth在huggingface上还没有上传qwen3其他的coder模型,我的目标是针对IDA的伪代码进行符号表的还原,这类任务基本来说就是对代码进行审计,所以使用coder模型的效果应当是最好的

基座模型之间的区别简单来说:base < instruction < thinking

base模型只是做了预训练

instruction在base的基础上做了指令对齐,学会了结构化输出

thinking在instruction的基础上加入了思考模式,使得模型可以做更复杂的推理

数据选择

对于所有AI训练,数据都是最重要的,这里我只使用了一个开源的数据集并随机截取了其中15k条数据

# =========================
# 4) 加载数据集:LLM4Binary/decompile-bench
#    说明:该数据集的字段名可能会随版本变化
#    这里先拉取并查看列名,然后做自适配映射
# =========================
ds = load_dataset("LLM4Binary/decompile-bench")

print(ds)
print("Train columns:", ds["train"].column_names if "train" in ds else None)
# 有些数据集只有一个split,比如 "test" 或 "validation"
# 你可以按实际情况选择split
split_name = "train" if "train" in ds else list(ds.keys())[0]
raw_train = ds[split_name]

# =========================
# 4.1) 关键优化:先抽样/截取到 10k~15k,再做 map/filter
#      避免对全量大数据做 format_sft/tokenize/filter,节省时间和CPU
# =========================
TARGET_N = 15000  # 你希望保留的样本数量(建议 10000~15000 之间)
SHUFFLE = True    # True:先打乱再取子集(更推荐,样本分布更随机)
SEED = 42         # shuffle 随机种子,保证可复现

n_total = len(raw_train)
n_take = min(TARGET_N, n_total)

if SHUFFLE:
    # 先洗牌再取前 n_take 条
    # 优点:相比直接取前N条更不容易出现“数据按来源/难度排序导致的偏差”
    raw_train_small = raw_train.shuffle(seed=SEED).select(range(n_take))
else:
    # 直接取前 n_take 条(速度最快,但可能存在顺序偏差)
    raw_train_small = raw_train.select(range(n_take))

print(f"Subset samples: {len(raw_train_small)}/{n_total}")

由于我们使用Lora做微调,所以数据不用太多,10k~15k就很完美,具体多少可以在训练时参考loss下降的速度

训练参数

# =========================
# 7) 训练参数
#    说明:GB10显存/统一内存很大,但仍建议从保守配置开始
# =========================
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,  # 等效batch=8
    learning_rate=2e-4,
    warmup_ratio=0.03,
    num_train_epochs=1,             # 先跑通流程;再加大到 2-3
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_steps=200,
    save_total_limit=2,
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=torch.cuda.is_available() and (not torch.cuda.is_bf16_supported()),
    optim="paged_adamw_8bit",       # bitsandbytes 8bit优化器,省显存
    weight_decay=0.0,
    report_to="none",
    dataloader_num_workers=16,      # 按CPU核心数调大,例如 8/16/32
    dataloader_pin_memory=True,     # GPU拷贝(通常有帮助)
)

这里有几个参数非常重要

per_device_train_batch_size=8 这个参数代表了每次训练时喂进去的数据量,由于使用的PGX有128G统一内存,所以这里可以设置高一些

gradient_accumulation_steps=4 这个参数代表了梯度下降的速度,这里设置的效果会在后续训练时通过loss观察到,建议4或者8

learning_rate=2e-4 这个参数代表学习率,步长小会导致训练速度变慢,步长大会导致收敛结果不稳定

dataloader_num_workers=16 这个参数是用来读取工作进程数的,由于之前我把 per_device_train_batch_size 参数设置成了1,所以导致学习非常慢,且GPU跑不满(但是CPU占用很高),总之这里的值是根据你本机的CPU核数来的,一般是2的倍数

max_seq_length = 8192 这个参数是写在了加载模型的位置,这个参数代表了训练时的上下文长度,由于我们的数据是汇编和C对应的代码,可能会很长,所以这里上下文长度可以给大一点;这里的上下文长度不会影响最终模型的可用上下文长度

训练时

训练的代码非常简单,这里就不贴了

需要注意的是一般对于微调任务来说,时间不会超过三天,事实上一天半到两天就是一个正常范围,如果运行了一晚上,发现可预见的训练时长会超过三天,那么应当停止训练并检查上面提到的:数据是否过多;per_device_train_batch_size 参数是否过低;以及检查docker是否挂载了GPU

训练后生成GGUF文件

训练完成后我们会在指定的output文件夹下看到一个safetensors文件,如果想要使用Ollama来启动微调完成后的大模型,这个还不能用,需要使用 llama.cpp 工具做转换才行,整体的流程是先将safetensors文件转换成merge_hf文件,此时微调的增量参数会和基座模型合并成一个完整的模型,然后使用 llama.cpp 提供的 convert_hf_to_gguf.py 脚本将merge文件转成gguf文件,在转成gguf文件之前需要指定一个模型的量化类型,根据测试,就算我使用bf16,最终的对话返回也很快

Ollama如何调用

Ollama是没办法直接调用gguf文件的,主要是Ollama需要一个调用规定,写明如何与指定的大模型进行对话,这个规定文件称为 modelfile ,由于Ollama官方对于这个 modelfile 的教程很少,所以我们可以参考一下官网提供的模型的 modelfile 是怎么写的,我们的基座模型用的是 Qwen3-Coder-30B-A3B-Instruct ,那么我们就可以下载一个官方的qwen3-coder模型,然后使用 ollama show --modelfile 命令进行查看

# Modelfile generated by "ollama show"
# To build a new Modelfile based on this, replace FROM with:
# FROM qwen3-coder:30b-a3b-fp16

FROM /root/.ollama/models/blobs/sha256-2f3c93d7adf85fcfeb6620d80058b22c51d5a8b21ce18f1c58bd3004c0a63f45
TEMPLATE {{ .Prompt }}
RENDERER qwen3-coder
PARSER qwen3-coder
PARAMETER top_k 20
PARAMETER top_p 0.8
PARAMETER repeat_penalty 1.05
PARAMETER stop <|im_start|>
PARAMETER stop <|im_end|>
PARAMETER stop <|endoftext|>
PARAMETER temperature 0.7
LICENSE """                                 Apache License
                           Version 2.0, January 2004
                        38eK9s2c8@1M7q4)9K6b7g2)9J5c8W2)9J5c8Y4N6%4N6#2)9J5k6h3q4H3j5h3y4Z5k6g2)9J5k6h3!0J5k6#2)9J5c8X3I4A6j5$3g2F1M7$3g2K6i4K6u0r3

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

这里我只节选了最重要的一部分(后面全都是证书内容),这里需要着重解释一下对话模板

对话模板

对话模板(chat template)是告诉Ollama应该如何调用和拼接对话的,如果不设置,虽然大模型也能启动,但是会完全无法使用

TEMPLATE {{ .Prompt }}
RENDERER qwen3-coder
PARSER qwen3-coder

这三个参数是同一个作用,告诉Ollama对话模板使用源码中的模板

PARAMETER stop <|im_start|>
PARAMETER stop <|im_end|>
PARAMETER stop <|endoftext|>

这三个参数称之为停止符,用来告诉Ollama当模型返回到什么内容时应当截断和停止

我们可以假设我们给大模型发送一个“你好”,那么实际上发送给大模型的内容应该是这样的

<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks.<|im_end|>\n<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n

而大模型的返回会接着assistant继续,由于qwen3的对话模板直接写死在了Ollama源码中(开源社区对这种行为也是深恶痛绝),所以我们并不能看到和修改对话模板中的内容,这里我将 GPT-OSS:120b 的对话模板拷贝出来

TEMPLATE """<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: {{ currentDate }}
{{- if and .IsThinkSet .Think (ne .ThinkLevel "") }}

Reasoning: {{ .ThinkLevel }}
{{- else if or (not .IsThinkSet) (and .IsThinkSet .Think) }}

Reasoning: medium
{{- end }}

{{- $hasNonBuiltinTools := false }}
{{- if .Tools -}}
{{- $hasBrowserSearch := false }}
{{- $hasBrowserOpen := false }}
{{- $hasBrowserFind := false }}
{{- $hasPython := false }}
  {{- range .Tools }}
    {{- if eq .Function.Name "browser.search" -}}{{- $hasBrowserSearch = true -}}
    {{- else if eq .Function.Name "browser.open" -}}{{- $hasBrowserOpen = true -}}
    {{- else if eq .Function.Name "browser.find" -}}{{- $hasBrowserFind = true -}}
    {{- else if eq .Function.Name "python" -}}{{- $hasPython = true -}}
    {{- else }}{{ $hasNonBuiltinTools = true -}}
    {{- end }}
  {{- end }}
{{- if or $hasBrowserSearch $hasBrowserOpen $hasBrowserFind $hasPython }}

# Tools
{{- if or $hasBrowserSearch $hasBrowserOpen $hasBrowserFind }}

## browser

// Tool for browsing.
// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.
// Cite information from the tool using the following format:
// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.
// Do not quote more than 10 words directly from the tool output.
// sources=web (default: web)
namespace browser {
{{- if $hasBrowserSearch }}

// Searches for information related to `query` and displays `topn` results.
type search = (_: {
query: string,
topn?: number, // default: 10
source?: string,
}) => any;
{{- end }}
{{- if $hasBrowserOpen }}

// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.
// Valid link ids are displayed with the formatting: `【{id}†.*】`.
// If `cursor` is not provided, the most recent page is implied.
// If `id` is a string, it is treated as a fully qualified URL associated with `source`.
// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.
// Use this function without `id` to scroll to a new location of an opened page.
type open = (_: {
id?: number | string, // default: -1
cursor?: number, // default: -1
loc?: number, // default: -1
num_lines?: number, // default: -1
view_source?: boolean, // default: false
source?: string,
}) => any;
{{- end }}
{{- if $hasBrowserFind }}

// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.
type find = (_: {
pattern: string,
cursor?: number, // default: -1
}) => any;
{{- end }}

} // namespace browser
{{- end }}{{/* end if has browser tools */}}
{{- if $hasPython }}

## python

Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).

When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.
{{- end }}{{/* end if hasPython */}}
{{- end }}{{/* end if has any built-in tools */}}
{{- end }}{{/* end if .Tools */}}

# Valid channels: analysis, commentary, final. Channel must be included for every message.{{ if $hasNonBuiltinTools }}
Calls to these tools must go to the commentary channel: 'functions'.
{{- end -}}<|end|>{{/* end of system */ -}}
{{- if or $hasNonBuiltinTools .System -}}
<|start|>developer<|message|>{{- if $hasNonBuiltinTools }}# Tools

## functions

namespace functions {
{{- range .Tools }}
{{- if not (or (eq .Function.Name "browser.search") (eq .Function.Name "browser.open") (eq .Function.Name "browser.find") (eq .Function.Name "python")) }}
{{if .Function.Description }}
// {{ .Function.Description }}
{{- end }}
{{- if and .Function.Parameters.Properties (gt (len .Function.Parameters.Properties) 0) }}
type {{ .Function.Name }} = (_: {
{{- range $name, $prop := .Function.Parameters.Properties }}
{{- if $prop.Description }}
  // {{ $prop.Description }}
{{- end }}
  {{ $name }}: {{ $prop | toTypeScriptType }},
{{- end }}
}) => any;
{{- else }}
type {{ .Function.Name }} = () => any;
{{- end }}
{{- end }}{{/* end if not browser tool */}}
{{- end }}{{/* end of range .Tools */}}

} // namespace functions
{{- end }}{{/* end if hasNonBuiltinTools */}}
{{- if .System}}

# Instructions

{{ .System }}
{{- end -}}
<|end|>
{{- end -}}
{{- /* Find the index of the last user message */ -}}
{{- $lastUserIdx := -1 }}
{{- $prefillingContent := false }}
{{- $prefillingThinkingOnly := false }}
{{- range $i, $msg := .Messages }}
  {{- $last := eq (len (slice $.Messages $i)) 1 -}}
  {{- if eq $msg.Role "user" }}
    {{- $lastUserIdx = $i }}
  {{- end -}}
  {{- if and $last (eq $msg.Role "assistant") (gt (len $msg.Content) 0) }}
    {{- $prefillingContent = true }}
  {{- else if and $last (eq $msg.Role "assistant") (gt (len $msg.Thinking) 0) }}
    {{- $prefillingThinkingOnly = true }}
  {{- end }}
{{- end -}}
{{- /* Now render messages */ -}}
{{- range $i, $msg := .Messages }}
  {{- $last := eq (len (slice $.Messages $i)) 1 -}}
  {{- if (ne $msg.Role "system") -}}
    {{- if eq $msg.Role "tool" -}}
      {{- if or (eq $msg.ToolName "python") (eq $msg.ToolName "browser.search") (eq $msg.ToolName "browser.open") (eq $msg.ToolName "browser.find") -}}
        <|start|>{{ $msg.ToolName }} to=assistant<|message|>{{ $msg.Content }}<|end|>
      {{- else -}}
        <|start|>functions.{{ $msg.ToolName }} to=assistant<|message|>{{ $msg.Content }}<|end|>
      {{- end -}}
    {{- else if eq $msg.Role "assistant" -}}
      {{- if and $msg.Thinking (gt $i $lastUserIdx) -}}{{- /* Show thinking only after last user message */ -}}
      <|start|>assistant<|channel|>analysis<|message|>{{ $msg.Thinking }}{{- if not $prefillingThinkingOnly -}}<|end|>{{- end -}}
      {{- end -}}
      {{- if gt (len $msg.Content) 0 -}}
        <|start|>assistant<|channel|>final<|message|>{{ $msg.Content }}{{- if not $prefillingContent -}}<|end|>{{- end -}}
      {{- end -}}
      {{- if gt (len $msg.ToolCalls) 0 -}}
        {{- range $j, $toolCall := $msg.ToolCalls -}}
          {{- $isBuiltin := or (eq $toolCall.Function.Name "python") (eq $toolCall.Function.Name "browser.search") (eq $toolCall.Function.Name "browser.open") (eq $toolCall.Function.Name "browser.find") -}}
          <|start|>assistant<|channel|>{{ if $isBuiltin }}analysis{{ else }}commentary{{ end }} to={{ if not $isBuiltin}}functions.{{end}}{{ $toolCall.Function.Name }} <|constrain|>json<|message|>{{ $toolCall.Function.Arguments }}<|call|>
        {{- end -}}
      {{- end -}}
    {{- else if eq $msg.Role "user" -}}
      <|start|>{{ $msg.Role }}<|message|>{{ $msg.Content }}<|end|>
    {{- end }}
  {{- else }}
  {{- end }}
{{- end -}}
{{- if not (or $prefillingContent $prefillingThinkingOnly) -}}
<|start|>assistant
{{- end -}}"""

修改对话模板本身就已经可以单独写一长篇文章(并且这也是一个单独的工种),这里我们只需要关注在对话模板中,这个大模型支持哪些功能,以及各个功能在使用过程中(如:mcp、function)需要注意的格式问题;事实上在微调后你发现模型并不复合你的预期,这并不一定是模型微调的不好或者数据不好,也有可能是对话模板的问题,这里我给出GPT-OSS系列模型的对话模板链接以供参考(6dbK9s2c8@1M7s2y4Q4x3@1q4Q4x3V1k6Q4x3V1k6Z5N6h3N6Y4K9h3&6Y4k6X3q4U0k6g2)9J5k6h3y4G2i4K6u0r3L8%4m8W2L8X3q4A6i4K6u0r3k6%4m8@1i4K6u0V1L8%4y4K6i4K6u0V1x3e0t1H3j5W2)9J5c8X3u0D9L8$3u0Q4x3V1k6E0j5h3W2F1i4K6u0r3j5$3S2S2N6q4)9#2k6Y4c8W2L8i4m8D9j5i4c8W2i4K6u0W2K9X3W2F1K9X3q4Q4c8f1k6Q4b7V1y4Q4z5o6V1`.

写在最后

上面的微调针对的是逆向还原符号表的任务,由于微调的时间跨度非常大,在我构思时coder类模型还没多少,等我完成微调并测试时,coder类模型就已经百花齐放了,个人用开源数据做的微调是无论如何也达不到专业团队用付费数据训练的大模型的效果的,所以我的微调不出意外的效果比新发布的大模型差,但是学习和踩坑过程还是很值得记录的

在你对现有大模型在垂类任务效果不好,想要微调之前请思考如下内容:

  1. 数据是否是永恒不变的,如果数据内容会改变,那么RAG更好
  2. 你的调用方式是否没有和对话模板对齐,这点非常重要!!!
  3. 是否有其他类型的模型也能满足你的需求,不要在一个树上吊死
  4. 是否可以开发工具或修改对话模板来满足垂类任务的需求,能不动模型本体就别动
  5. 你的数据是否足够好?普通的数据是微调不出满足要求的模型的

完整的微调代码

import os
import torch
import subprocess
import time
from datasets import load_dataset, Dataset
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForLanguageModeling, TrainerCallback
import os, subprocess, textwrap, json, shutil, time

print("CUDA available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0))
print("Mem free/total:", torch.cuda.mem_get_info())
os.environ["TOKENIZERS_PARALLELISM"] = "true"          # 让HF fast tokenizer多线程

# =========================
# 1) 选择Qwen3 Coder基座模型
#    说明:你可以改成你要微调的Qwen3 Coder具体型号
#    常见:Qwen/Qwen3-Coder-7B / 14B / 32B 等(以HF实际存在为准)
# =========================
base_model = "unsloth/Qwen3-Coder-30B-A3B-Instruct"   # ← 按需修改
output_dir = "./qwen3-coder-decompilebench-lora"

# =========================
# 2) 加载模型(Unsloth加速 + 4bit量化)
#    说明:GB10 128GB统一内存很大,一般可用4bit + LoRA
# =========================
max_seq_length = 8192   # 反编译数据可能较长;显存够可调大(如 3072/4096)
dtype = None            # 让Unsloth自动选bf16/fp16
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name       = base_model,
    max_seq_length   = max_seq_length,
    dtype            = dtype,
    load_in_4bit     = load_in_4bit,
)

# Qwen系通常需要右侧padding以配合训练(尤其是Flash Attention/packing)
tokenizer.padding_side = "right"

# =========================
# 3) 配置LoRA
#    说明:rank可按显存/效果权衡;target_modules给一个通用集合
# =========================
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing=False,  # “unsloth”省显存; False取消
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# =========================
# 4) 加载数据集:LLM4Binary/decompile-bench
#    说明:该数据集的字段名可能会随版本变化
#    这里先拉取并查看列名,然后做自适配映射
# =========================
ds = load_dataset("LLM4Binary/decompile-bench")

print(ds)
print("Train columns:", ds["train"].column_names if "train" in ds else None)
# 有些数据集只有一个split,比如 "test" 或 "validation"
# 你可以按实际情况选择split
split_name = "train" if "train" in ds else list(ds.keys())[0]
raw_train = ds[split_name]

# =========================
# 4.1) 关键优化:先抽样/截取到 10k~15k,再做 map/filter
#      避免对全量大数据做 format_sft/tokenize/filter,节省时间和CPU
# =========================
TARGET_N = 15000  # 你希望保留的样本数量(建议 10000~15000 之间)
SHUFFLE = True    # True:先打乱再取子集(更推荐,样本分布更随机)
SEED = 42         # shuffle 随机种子,保证可复现

n_total = len(raw_train)
n_take = min(TARGET_N, n_total)

if SHUFFLE:
    # 先洗牌再取前 n_take 条
    # 优点:相比直接取前N条更不容易出现“数据按来源/难度排序导致的偏差”
    raw_train_small = raw_train.shuffle(seed=SEED).select(range(n_take))
else:
    # 直接取前 n_take 条(速度最快,但可能存在顺序偏差)
    raw_train_small = raw_train.select(range(n_take))

print(f"Subset samples: {len(raw_train_small)}/{n_total}")

# =========================
# 5) 探测输入/输出字段名
#    数据集字段名可能随版本变化,这里用候选字段列表自动匹配
# =========================
CAND_IN  = ["input", "prompt", "asm", "assembly", "disasm", "disassembly", "source_asm", "decompiled_input"]
CAND_OUT = ["output", "completion", "answer", "c", "code", "decompile", "decompiled", "target"]

def pick_first_existing(example, candidates):
    """
    从 candidates 里按顺序挑第一个满足:
    1) 字段存在
    2) 非 None
    3) 去掉空白后非空字符串
    的字段名
    """
    for k in candidates:
        if k in example and example[k] is not None and str(example[k]).strip() != "":
            return k
    return None

# 用子集第一条数据探测字段
first = raw_train_small[0]
in_key  = pick_first_existing(first, CAND_IN)
out_key = pick_first_existing(first, CAND_OUT)

print("Detected in_key :", in_key)
print("Detected out_key:", out_key)

# 如果自动探测失败,建议 print(first) 查看真实字段名,再手动指定
if in_key is None or out_key is None:
    print("First example keys:", list(first.keys()))
    # print("First example:", first)  # 需要的话可以打开这一行看看完整样本
    raise ValueError("无法自动探测 input/output 字段,请根据样本手动设置 in_key/out_key。")

# =========================
# 构造 SFT 训练文本(ChatML/聊天格式)
# =========================
SYSTEM_PROMPT = "You are an expert reverse engineer. Decompile the given assembly/disassembly into readable, correct C code."

def format_sft(example):
    """
    把一条样本拼成一条聊天格式的训练文本:
    system: 角色设定
    user:   输入(汇编/反汇编/伪代码等)
    assistant: 输出(目标C代码)
    """
    user_content = f"""Decompile the following code to C.
[INPUT]
{example[in_key]}
"""

    assistant_content = f"""{example[out_key]}"""

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]

    # tokenize=False:这里只生成纯文本,真正tokenize一般由 trainer/数据整理器做
    # add_generation_prompt=False:训练时不需要额外的 generation prompt
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": text}

# =========================
# map:把子集映射成只有 "text" 一列的训练集
# 注意:remove_columns 会删掉原始列,减少内存占用
# =========================
train_ds = raw_train_small.map(
    format_sft,
    remove_columns=raw_train_small.column_names,
    # num_proc 不建议无脑开满,否则可能导致机器卡顿/频繁上下文切换
    num_proc=min(os.cpu_count(), 8),
    desc="Formatting SFT text",
)

# =========================
# 可选:过滤过长样本,避免OOM或严重截断
# 这里仍然只对“小子集”做过滤,成本可控
# =========================
def length_filter(example):
    """
    粗略长度过滤:
    - truncation=False:不截断,拿到真实长度
    - add_special_tokens=False:避免额外特殊token影响估算
    """
    ids = tokenizer(
        example["text"],
        truncation=False,
        add_special_tokens=False
    ).input_ids

    # 允许比 max_seq_length 稍微长一点点(比如 1.2倍),避免过滤过狠
    return len(ids) <= (max_seq_length * 1.2)

train_ds = train_ds.filter(
    length_filter,
    num_proc=min(os.cpu_count(), 8),
    desc="Filtering by length",
)

print(train_ds)
print(train_ds[0]["text"][:500])

# =========================
# 6) 数据整理:packing可显著提升吞吐(尤其是短样本多时)
# =========================
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# =========================
# 7) 训练参数
#    说明:GB10显存/统一内存很大,但仍建议从保守配置开始
# =========================
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,  # 等效batch=8
    learning_rate=2e-4,
    warmup_ratio=0.03,
    num_train_epochs=1,             # 先跑通流程;再加大到 2-3
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_steps=200,
    save_total_limit=2,
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=torch.cuda.is_available() and (not torch.cuda.is_bf16_supported()),
    optim="paged_adamw_8bit",       # bitsandbytes 8bit优化器,省显存
    weight_decay=0.0,
    report_to="none",
    dataloader_num_workers=16,      # 按CPU核心数调大,例如 8/16/32
    dataloader_pin_memory=True,     # 加速CPU->GPU拷贝(通常有帮助)
)

class NvidiaSMICallback(TrainerCallback):
    """
    周期性调用 nvidia-smi 打印GPU利用率/显存等信息,便于在训练日志中实时观察
    """
    def __init__(self, every_n_steps=10, gpu_index=0):
        self.every_n_steps = every_n_steps
        self.gpu_index = gpu_index
        self._last_time = 0.0

    def _query(self):
        # 用nvidia-smi查询关键指标(单位:MiB、% 、W)
        cmd = [
            "nvidia-smi",
            f"--id={self.gpu_index}",
            "--query-gpu=timestamp,name,utilization.gpu,utilization.memory,temperature.gpu,power.draw,memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ]
        out = subprocess.check_output(cmd, text=True).strip()
        return out

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Trainer每次log时会调用这里;我们按step间隔打印GPU信息
        if state.global_step == 0:
            return
        if (state.global_step % self.every_n_steps) != 0:
            return

        try:
            info = self._query()
            # English output:
            print(f"[GPU MONITOR] step={state.global_step} | {info}")
        except Exception as e:
            print(f"[GPU MONITOR] step={state.global_step} | nvidia-smi query failed: {e}")

class TorchCudaMemCallback(TrainerCallback):
    """
    补充输出PyTorch视角的显存统计(allocated/reserved),便于定位碎片/缓存
    """
    def __init__(self, every_n_steps=10):
        self.every_n_steps = every_n_steps

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not torch.cuda.is_available():
            return
        if state.global_step == 0:
            return
        if (state.global_step % self.every_n_steps) != 0:
            return

        alloc = torch.cuda.memory_allocated() / (1024**3)
        reserv = torch.cuda.memory_reserved() / (1024**3)
        max_alloc = torch.cuda.max_memory_allocated() / (1024**3)
        max_reserv = torch.cuda.max_memory_reserved() / (1024**3)

        # English output:
        print(
            f"[TORCH CUDA MEM] step={state.global_step} | "
            f"allocated={alloc:.2f} GB, reserved={reserv:.2f} GB, "
            f"max_allocated={max_alloc:.2f} GB, max_reserved={max_reserv:.2f} GB"
        )
        
# =========================
# 8) 使用TRL的SFTTrainer进行监督微调
#    注意:transformers 4.57 + trl + unsloth 组合下,SFTTrainer是常用方式
# =========================
gpu_callbacks = [
    NvidiaSMICallback(every_n_steps=10, gpu_index=0),
    TorchCudaMemCallback(every_n_steps=10),
]

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_ds,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    packing=False,                   # 把多个样本拼到同一序列提升效率
    args=training_args,
    data_collator=data_collator,
    callbacks=gpu_callbacks,
)

# =========================
# 9) 开始训练
# ========================= 
train_result = trainer.train()
print(train_result)

# =========================
# 10) 保存LoRA权重与tokenizer
# =========================
trainer.model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print("Saved to:", output_dir)

# 你已有的 output_dir,例如: "./qwen3-coder-decompilebench-lora"
LORA_DIR = output_dir

# 额外输出目录(合并后的 HF 模型)
MERGED_DIR = os.path.join(LORA_DIR, "merged_hf")

# GGUF 输出目录
GGUF_DIR = os.path.join(LORA_DIR, "gguf")
os.makedirs(GGUF_DIR, exist_ok=True)

# llama.cpp 路径(你需要自己改成实际位置)
LLAMA_CPP_DIR = os.environ.get("LLAMA_CPP_DIR", "../llama.cpp/")

# 你的 Ollama 模型名(最终 `ollama run` 用这个)
OLLAMA_MODEL_NAME = "qwen3-coder-decompilebench"

# 量化类型:常用 Q4_K_M / Q5_K_M / Q8_0
# Q4_K_M 体积小、速度快、质量还不错;Q8_0 质量好但大
GGUF_QUANT = "Q8_0"


def run(cmd, cwd=None):
    print(">>", " ".join(cmd))
    subprocess.run(cmd, cwd=cwd, check=True)
    
# 合并 LoRA 到基座,并保存为标准 HF 格式(fp16/bf16)
# 注意:你训练时 load_in_4bit=True,合并导出时最好转成 16-bit 权重再导出
# Unsloth 通常提供 save_pretrained_merged(不同版本函数名可能略有差异)
print("Merging LoRA into base model and saving HF merged model to:", MERGED_DIR)

try:
    # Unsloth 常见用法:FastLanguageModel.save_pretrained_merged
    # 保存为 16-bit(更适合后续转 GGUF)
    FastLanguageModel.save_pretrained_merged(
        model,
        tokenizer,
        MERGED_DIR,
        save_method="merged_16bit",   # 若你的 Unsloth 版本不支持该字段,见下面 except 的兜底
    )
except Exception as e:
    print("Unsloth merged export failed, fallback to manual merge. Error:", e)
    # 兜底:尝试使用 transformers/peft 的 merge(需要你是 PeftModel)
    from peft import PeftModel
    base = FastLanguageModel.from_pretrained(
        model_name=base_model,
        max_seq_length=max_seq_length,
        dtype=None,
        load_in_4bit=False,  # 合并时不要 4bit
    )[0]
    peft_model = PeftModel.from_pretrained(base, LORA_DIR)
    merged = peft_model.merge_and_unload()
    merged.save_pretrained(MERGED_DIR, safe_serialization=True)
    tokenizer.save_pretrained(MERGED_DIR)

[招生]科锐逆向工程师培训(2026年7月3日实地,远程教学同时开班, 第56期)!

收藏
免费 5
打赏
分享
最新回复 (0)
游客
登录 | 注册 方可回帖
返回