# Day 42
# hook函数 (hook functions):
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Seed both the torch and the NumPy global RNGs with one shared value
# so every run of this script is reproducible.
_SEED = 42
torch.manual_seed(_SEED)
np.random.seed(_SEED)
# 张量钩子 (tensor hooks):
# Leaf tensor x = 2.0 with gradient tracking; y = x**2 and z = y**3 are
# intermediate (non-leaf) results recorded by autograd.
x = torch.tensor([2.0], requires_grad=True)
y = x.pow(2)  # 4.0
z = y.pow(3)  # 64.0
def tensor_hook(grad):
    """Print the incoming gradient and return it halved.

    Intended for ``Tensor.register_hook``: a tensor returned by the hook
    replaces the gradient that continues to flow backward from the hooked
    tensor.

    Args:
        grad: gradient of the backward pass with respect to the hooked tensor.

    Returns:
        ``grad / 2``.
    """
    print(f"原始梯度: {grad}")
    return grad / 2
# Attach the hook to the intermediate tensor y; it intercepts dz/dy during
# backward and its return value replaces that gradient.
hook_handle = y.register_hook(tensor_hook)
try:
    # With x = 2.0: dz/dy = 3*y**2 = 48, halved by the hook to 24,
    # then multiplied by dy/dx = 2*x = 4, so x.grad = 96 (192 without hook).
    z.backward()
    print(f"x的梯度: {x.grad}")
finally:
    # Always detach the hook, even if backward() raises, so it cannot
    # affect later backward passes through y.
    hook_handle.remove()