PyTorch 求导实例
在深度学习中,自动微分是一个非常重要的概念。它允许我们自动计算函数的导数,从而简化了梯度下降等优化算法的实现。PyTorch 提供了强大的自动微分功能,使得我们可以轻松地计算复杂函数的导数。本文将带你通过实例学习如何在 PyTorch 中进行求导。
什么是自动微分?
自动微分(Automatic Differentiation,简称 AD)是一种计算函数导数的方法。它通过将函数分解为一系列基本操作(如加法、乘法、指数等),并利用链式法则逐步计算导数。PyTorch 中的 autograd
模块实现了自动微分功能。
基本概念
在 PyTorch 中,Tensor
是核心数据结构。当我们创建一个 Tensor
并设置 requires_grad=True
时,PyTorch 会跟踪所有对该 Tensor
的操作,以便后续计算梯度。
示例:简单函数的求导
让我们从一个简单的例子开始,计算函数 y = x^2
在 x = 2
处的导数。
python
import torch
# 创建一个 Tensor 并设置 requires_grad=True 以跟踪操作
x = torch.tensor(2.0, requires_grad=True)
# 定义函数 y = x^2
y = x ** 2
# 计算 y 对 x 的导数
y.backward()
# 打印导数
print(x.grad) # 输出: tensor(4.)
在这个例子中,我们首先创建了一个标量 x
,并设置了 requires_grad=True
。然后我们定义了函数 y = x^2
,并通过调用 y.backward()
来计算 y
对 x
的导数。最后,我们通过 x.grad
获取导数值。
解释
x = torch.tensor(2.0, requires_grad=True)
:创建一个标量x
,并启用梯度跟踪。y = x ** 2
:定义函数y = x^2
。y.backward()
:计算y
对x
的导数。x.grad
:获取导数值。
实际应用:线性回归
让我们通过一个更实际的例子来理解自动微分的应用。假设我们有一个简单的线性回归模型 y = wx + b
,其中 w
和 b
是需要学习的参数。我们的目标是通过梯度下降法来优化这些参数。
步骤 1:定义模型和损失函数
python
import torch
# 定义模型参数
w = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)
# 定义输入数据和目标值
x = torch.tensor([1.0, 2.0, 3.0])
y_true = torch.tensor([2.0, 4.0, 6.0])
# 定义模型
def linear_model(x):
return w * x + b
# 定义损失函数(均方误差)
def loss_fn(y_pred, y_true):
return ((y_pred - y_true) ** 2).mean()
步骤 2:计算损失并更新参数
python
# 计算预测值
y_pred = linear_model(x)
# 计算损失
loss = loss_fn(y_pred, y_true)
# 计算梯度
loss.backward()
# 打印梯度
print(w.grad) # 输出: tensor(-6.)
print(b.grad) # 输出: tensor(-2.)
# 更新参数(假设学习率为 0.1)
learning_rate = 0.1
with torch.no_grad():
w -= learning_rate * w.grad
b -= learning_rate * b.grad
# 清零梯度
w.grad.zero_()
b.grad.zero_()
解释
w
和b
是需要学习的参数,我们设置了requires_grad=True
以跟踪它们的梯度。linear_model(x)
定义了线性模型y = wx + b
。loss_fn(y_pred, y_true)
计算了预测值与真实值之间的均方误差。loss.backward()
计算了损失函数对w
和b
的梯度。- 通过
w.grad
和b.grad
获取梯度值,并使用梯度下降法更新参数。 - 最后,我们使用
zero_()
方法清零梯度,以便进行下一次迭代。
总结
通过本文的实例,我们学习了如何在 PyTorch 中使用自动微分功能进行求导。我们从简单的函数求导开始,逐步深入到线性回归模型的实际应用。自动微分是深度学习中不可或缺的工具,掌握它将帮助你更好地理解和实现各种机器学习算法。
附加资源
练习
- 尝试修改线性回归模型中的学习率,观察对训练过程的影响。
- 实现一个简单的神经网络,并使用自动微分功能进行训练。