核心目标不是学会很多 API,而是先把最小训练闭环真正串起来:
dataset -> dataloader -> model -> forward -> loss -> backward -> optimizer.step() -> eval -> save/load
最小训练闭环是什么
训练一个神经网络,本质上是在重复下面几步:
- 从
dataloader中取出一个 batch 的数据 - 把数据送进
model做前向计算,得到预测值 - 用
loss function比较预测值和真实标签,得到loss - 对
loss调用backward(),计算梯度 - 调用
optimizer.step(),根据梯度更新参数 - 调用
optimizer.zero_grad(),清空旧梯度
需要理解的概念
- batch:一次送进模型的一批样本
- step:处理一个 batch 并更新一次参数
- epoch:完整遍历整个训练集一次
- loss:模型当前预测得有多差
- gradient:loss 对参数的变化率
- learning rate (lr):每次更新参数走多大一步
为什么参数更新时用减号
因为梯度指向的是 loss 增大的方向。
训练的目标是让 loss 下降,所以参数更新要沿着 负梯度方向 走:
Dataset 和 DataLoader
Dataset 是什么
Dataset 定义的是:第 i 个样本是什么。
它通常负责:
- 读取样本
- 返回输入和标签
- 支持按索引取数据
DataLoader 是什么
DataLoader 定义的是:这些样本怎么一批一批取出来。
它通常负责:
- 按 batch 组织样本
- 是否打乱顺序(
shuffle=True/False) - 多进程加载(
num_workers)
两者关系
Dataset:单个样本级别DataLoader:batch 级别
对应代码
Model、forward 和 autograd
nn.Module 是什么
在 PyTorch 中,模型通常继承自 nn.Module。
一般会在:
__init__()中定义层forward()中定义数据如何流过这些层
forward() 是什么
forward() 描述的是:输入如何变成输出。
例如:
- 输入先经过线性层
- 再经过激活函数
- 最后输出 logits
autograd 是什么
PyTorch 在前向计算时会自动构建计算图。
调用 loss.backward() 后,会自动按链式法则计算梯度,并把结果存到参数的 .grad 里。
对应代码
最小 autograd 例子
需要记住的点
requires_grad=True表示要跟踪梯度loss.backward()会把梯度写到参数的.grad上- 梯度默认会累积
optimizer、train/eval、checkpoint
optimizer 是什么
optimizer 负责根据梯度更新参数。
例如:
这里:
model.parameters()表示要训练哪些参数lr表示学习率
train 和 eval 的区别
model.train():训练模式model.eval():评估/推理模式
它们本身不会自动更新参数,也不会自动计算准确率。 它们的作用是切换某些层的行为,例如:
- Dropout
- BatchNorm
为什么推理前要 eval()
因为如果模型里有 Dropout / BatchNorm,而你没有切到 eval():
- Dropout 还会随机丢弃神经元
- BatchNorm 还会用当前 batch 的统计量
这样同一个输入可能得到不稳定的输出。
checkpoint 是什么
checkpoint 就是把训练好的模型参数保存下来,之后再加载恢复。
推荐方式是保存 state_dict()。
最小 train / test loop
保存和加载模型
最重要的结论
训练闭环的核心是:
forward -> loss -> backward -> optimizer.step()Dataset 和 DataLoader 的区别是:
- Dataset 负责单个样本
- DataLoader 负责 batch
模型通常继承
nn.Module,并在forward()中定义前向计算。autograd 会自动求梯度,
loss.backward()后梯度会写到参数的.grad上。optimizer 根据梯度更新参数,学习率控制更新步长。
推理时要用
model.eval(),并通常配合torch.no_grad()。推荐保存
state_dict(),加载时先重新创建模型再load_state_dict()。
到这里,应该已经能做到:
- 看懂一个最小 PyTorch 训练脚本
- 理解 batch / step / epoch 的关系
- 理解 loss、gradient、learning rate 的基本作用
- 写出一个最小
nn.Module - 写出最小 train/test loop
- 保存并加载模型参数