跳到主要内容

PyTorch 求导实例

在深度学习中,自动微分是一个非常重要的概念。它允许我们自动计算函数的导数,从而简化了梯度下降等优化算法的实现。PyTorch 提供了强大的自动微分功能,使得我们可以轻松地计算复杂函数的导数。本文将带你通过实例学习如何在 PyTorch 中进行求导。

什么是自动微分?

自动微分(Automatic Differentiation,简称 AD)是一种计算函数导数的方法。它通过将函数分解为一系列基本操作(如加法、乘法、指数等),并利用链式法则逐步计算导数。PyTorch 中的 autograd 模块实现了自动微分功能。

基本概念

在 PyTorch 中,Tensor 是核心数据结构。当我们创建一个 Tensor 并设置 requires_grad=True 时,PyTorch 会跟踪所有对该 Tensor 的操作,以便后续计算梯度。

示例:简单函数的求导

让我们从一个简单的例子开始,计算函数 y = x^2x = 2 处的导数。

python
import torch

# 创建一个 Tensor 并设置 requires_grad=True 以跟踪操作
x = torch.tensor(2.0, requires_grad=True)

# 定义函数 y = x^2
y = x ** 2

# 计算 y 对 x 的导数
y.backward()

# 打印导数
print(x.grad) # 输出: tensor(4.)

在这个例子中,我们首先创建了一个标量 x,并设置了 requires_grad=True。然后我们定义了函数 y = x^2,并通过调用 y.backward() 来计算 yx 的导数。最后,我们通过 x.grad 获取导数值。

解释

  • x = torch.tensor(2.0, requires_grad=True):创建一个标量 x,并启用梯度跟踪。
  • y = x ** 2:定义函数 y = x^2
  • y.backward():计算 yx 的导数。
  • x.grad:获取导数值。

实际应用:线性回归

让我们通过一个更实际的例子来理解自动微分的应用。假设我们有一个简单的线性回归模型 y = wx + b,其中 wb 是需要学习的参数。我们的目标是通过梯度下降法来优化这些参数。

步骤 1:定义模型和损失函数

python
import torch

# 定义模型参数
w = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 定义输入数据和目标值
x = torch.tensor([1.0, 2.0, 3.0])
y_true = torch.tensor([2.0, 4.0, 6.0])

# 定义模型
def linear_model(x):
return w * x + b

# 定义损失函数(均方误差)
def loss_fn(y_pred, y_true):
return ((y_pred - y_true) ** 2).mean()

步骤 2:计算损失并更新参数

python
# 计算预测值
y_pred = linear_model(x)

# 计算损失
loss = loss_fn(y_pred, y_true)

# 计算梯度
loss.backward()

# 打印梯度
print(w.grad) # 输出: tensor(-6.)
print(b.grad) # 输出: tensor(-2.)

# 更新参数(假设学习率为 0.1)
learning_rate = 0.1
with torch.no_grad():
w -= learning_rate * w.grad
b -= learning_rate * b.grad

# 清零梯度
w.grad.zero_()
b.grad.zero_()

解释

  • wb 是需要学习的参数,我们设置了 requires_grad=True 以跟踪它们的梯度。
  • linear_model(x) 定义了线性模型 y = wx + b
  • loss_fn(y_pred, y_true) 计算了预测值与真实值之间的均方误差。
  • loss.backward() 计算了损失函数对 wb 的梯度。
  • 通过 w.gradb.grad 获取梯度值,并使用梯度下降法更新参数。
  • 最后,我们使用 zero_() 方法清零梯度,以便进行下一次迭代。

总结

通过本文的实例,我们学习了如何在 PyTorch 中使用自动微分功能进行求导。我们从简单的函数求导开始,逐步深入到线性回归模型的实际应用。自动微分是深度学习中不可或缺的工具,掌握它将帮助你更好地理解和实现各种机器学习算法。

附加资源

练习

  1. 尝试修改线性回归模型中的学习率,观察对训练过程的影响。
  2. 实现一个简单的神经网络,并使用自动微分功能进行训练。