使用社交账号登录
核心目标不是学会很多 API,而是先把最小训练闭环真正串起来:
dataset -> dataloader -> model -> forward -> loss -> backward -> optimizer.step() -> eval -> save/load
训练一个神经网络,本质上是在重复下面几步:
dataloader 中取出一个 batch 的数据model 做前向计算,得到预测值loss function 比较预测值和真实标签,得到 lossloss 调用 backward(),计算梯度optimizer.step(),根据梯度更新参数optimizer.zero_grad(),清空旧梯度因为梯度指向的是 loss 增大的方向。
训练的目标是让 loss 下降,所以参数更新要沿着 负梯度方向 走:
Dataset 定义的是:第 i 个样本是什么。
它通常负责:
DataLoader 定义的是:这些样本怎么一批一批取出来。
它通常负责:
shuffle=True/False)num_workers)Dataset:单个样本级别DataLoader:batch 级别nn.Module 是什么在 PyTorch 中,模型通常继承自 nn.Module。
一般会在:
__init__() 中定义层forward() 中定义数据如何流过这些层forward() 是什么forward() 描述的是:输入如何变成输出。
例如:
PyTorch 在前向计算时会自动构建计算图。
调用 loss.backward() 后,会自动按链式法则计算梯度,并把结果存到参数的 .grad 里。
requires_grad=True 表示要跟踪梯度loss.backward() 会把梯度写到参数的 .grad 上optimizer 负责根据梯度更新参数。
例如:
这里:
model.parameters() 表示要训练哪些参数lr 表示学习率model.train():训练模式model.eval():评估/推理模式它们本身不会自动更新参数,也不会自动计算准确率。 它们的作用是切换某些层的行为,例如:
eval()因为如果模型里有 Dropout / BatchNorm,而你没有切到 eval():
这样同一个输入可能得到不稳定的输出。
checkpoint 就是把训练好的模型参数保存下来,之后再加载恢复。
推荐方式是保存 state_dict()。
训练闭环的核心是:
forward -> loss -> backward -> optimizer.step()
Dataset 和 DataLoader 的区别是:
模型通常继承 nn.Module,并在 forward() 中定义前向计算。
autograd 会自动求梯度,loss.backward() 后梯度会写到参数的 .grad 上。
optimizer 根据梯度更新参数,学习率控制更新步长。
推理时要用 model.eval(),并通常配合 torch.no_grad()。
推荐保存 state_dict(),加载时先重新创建模型再 load_state_dict()。
nn.Modulew = w - lr * w.gradfrom torch.utils.data import Dataset, DataLoader
import torch
class ToyDataset(Dataset):
def __init__(self):
self.X = torch.tensor([
[0.0, 0.0],
[0.0, 1.0],
[1.0, 0.0],
[1.0, 1.0],
])
self.y = torch.tensor([0, 1, 1, 0])
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
return self.X[idx], self.y[idx]
dataset = ToyDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
for X, y in dataloader:
print("X shape:", X.shape)
print("y shape:", y.shape)
breakimport torch
from torch import nn
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(2, 8)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(8, 2)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
model = SimpleNet()
X = torch.tensor([[1.0, 0.0]])
logits = model(X)
print("logits:", logits)
print("shape:", logits.shape)import torch
x = torch.tensor([[1.0, 2.0]])
y = torch.tensor([[1.0]])
w = torch.tensor([[0.5], [-1.0]], requires_grad=True)
b = torch.tensor([0.1], requires_grad=True)
z = x @ w + b
loss = (z - y).pow(2).mean()
print("loss:", loss.item())
loss.backward()
print("w.grad:", w.grad)
print("b.grad:", b.grad)optimizer = torch.optim.SGD(model.parameters(), lr=0.1)import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
class ToyDataset(Dataset):
def __init__(self):
self.X = torch.tensor([
[0.0, 0.0],
[0.0, 1.0],
[1.0, 0.0],
[1.0, 1.0],
])
self.y = torch.tensor([0, 1, 1, 0])
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
return self.X[idx], self.y[idx]
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(2, 8)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(8, 2)
def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))
def train_loop(dataloader, model, loss_fn, optimizer):
model.train()
for X, y in dataloader:
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
def test_loop(dataloader, model, loss_fn):
model.eval()
total_loss = 0.0
total_correct = 0
total_samples = 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
total_loss += loss_fn(pred, y).item()
total_correct += (pred.argmax(dim=1) == y).sum().item()
total_samples += y.size(0)
avg_loss = total_loss / len(dataloader)
acc = total_correct / total_samples
return avg_loss, acc
dataset = ToyDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
model = SimpleNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(10):
train_loss = train_loop(dataloader, model, loss_fn, optimizer)
test_loss, test_acc = test_loop(dataloader, model, loss_fn)
print(f"epoch={epoch+1}, train_loss={train_loss:.4f}, test_loss={test_loss:.4f}, test_acc={test_acc:.2%}")import torch
# save
torch.save(model.state_dict(), "model_weights.pth")
# load
new_model = SimpleNet()
new_model.load_state_dict(torch.load("model_weights.pth", weights_only=True))
new_model.eval()