当前位置: 首页 > news >正文

PyTorch 张量(Tensors)全面指南:从基础到实战

文章目录

    • 什么是张量?
    • 张量初始化方法
      • 1. 直接从数据创建
      • 2. 从 NumPy 数组转换
      • 3. 基于现有张量创建
      • 4. 使用随机值或常量
    • 张量属性
    • 张量操作
      • 设备转移
      • 索引和切片
      • 连接张量
      • 算术运算
      • 单元素张量转换
    • 原地操作(In-place Operations)
    • PyTorch 与 NumPy 互操作
      • 张量转 NumPy 数组
      • NumPy 数组转张量
    • 张量操作总结表
    • 最佳实践与注意事项

PyTorch Tensor Visualization

什么是张量?

张量(Tensors)是 PyTorch 中的核心数据结构,类似于数组和矩阵,但具有更强大的功能。在深度学习中,我们使用张量来表示:

  • 模型的输入和输出数据
  • 模型的参数(权重和偏置)
  • 中间计算过程中的数据

张量与 NumPy 的 ndarrays 类似,但有两大关键优势:

  1. GPU 加速:可在 GPU 或其他硬件加速器上运行
  2. 自动微分:支持自动求导,这对深度学习至关重要
import torch
import numpy as np

张量初始化方法

1. 直接从数据创建

data = [[1, 2], [3, 4]]
x_data = torch.tensor(data)

2. 从 NumPy 数组转换

np_array = np.array(data)
x_np = torch.from_numpy(np_array)

3. 基于现有张量创建

x_ones = torch.ones_like(x_data)  # 保留原始张量属性
x_rand = torch.rand_like(x_data, dtype=torch.float)  # 覆盖数据类型print(f"Ones Tensor:\n{x_ones}")
print(f"Random Tensor:\n{x_rand}")

4. 使用随机值或常量

shape = (2, 3)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)print(f"Random Tensor:\n{rand_tensor}")
print(f"Ones Tensor:\n{ones_tensor}")
print(f"Zeros Tensor:\n{zeros_tensor}")

张量属性

每个张量都有三个关键属性:

tensor = torch.rand(3, 4)print(f"Shape: {tensor.shape}")    # 形状
print(f"Datatype: {tensor.dtype}") # 数据类型
print(f"Device: {tensor.device}")  # 存储设备 (CPU/GPU)

张量操作

设备转移

# 转移到GPU(如果可用)
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = tensor.to(device)
print(f"Device after transfer: {tensor.device}")

索引和切片

tensor = torch.ones(4, 4)
print(f"First row: {tensor[0]}")
print(f"First column: {tensor[:, 0]}")
print(f"Last column: {tensor[..., -1]}")tensor[:, 1] = 0  # 修改第二列
print(tensor)

连接张量

t1 = torch.cat([tensor, tensor, tensor], dim=1)
print(f"Concatenated tensor:\n{t1}")

算术运算

# 矩阵乘法(三种等效方式)
y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)
y3 = torch.rand_like(y1)
torch.matmul(tensor, tensor.T, out=y3)# 逐元素乘法(三种等效方式)
z1 = tensor * tensor
z2 = tensor.mul(tensor)
z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)

单元素张量转换

agg = tensor.sum()
agg_item = agg.item()  # 转换为Python标量
print(f"Sum value: {agg_item}, Type: {type(agg_item)}")

原地操作(In-place Operations)

原地操作直接修改张量内容,使用 _ 后缀表示:

print("Original tensor:")
print(tensor)tensor.add_(5)  # 原地加5
print("\nAfter in-place addition:")
print(tensor)

注意:虽然原地操作节省内存,但在自动微分中可能导致梯度计算问题,应谨慎使用。

PyTorch 与 NumPy 互操作

张量转 NumPy 数组

t = torch.ones(5)
n = t.numpy()
print(f"Tensor: {t}\nNumPy: {n}")# 修改张量会影响NumPy数组
t.add_(1)
print(f"\nAfter modification:\nTensor: {t}\nNumPy: {n}")

NumPy 数组转张量

n = np.ones(5)
t = torch.from_numpy(n)
print(f"NumPy: {n}\nTensor: {t}")# 修改NumPy数组会影响张量
np.add(n, 1, out=n)
print(f"\nAfter modification:\nNumPy: {n}\nTensor: {t}")

张量操作总结表

操作类型方法示例说明
创建torch.tensor(), torch.rand(), torch.zeros()多种初始化方式
属性.shape, .dtype, .device获取张量元数据
索引tensor[0], tensor[:, 1]类似NumPy的索引
运算torch.matmul(), tensor.sum()矩阵运算和归约
连接torch.cat(), torch.stack()合并多个张量
转换.numpy(), torch.from_numpy()与NumPy互转

最佳实践与注意事项

  1. 设备管理:明确张量所在的设备(CPU/GPU),避免不必要的设备间传输
  2. 数据类型:注意操作中的数据类型一致性,使用 .dtype 检查
  3. 内存共享:PyTorch 和 NumPy 数组共享内存,修改一个会影响另一个
  4. 自动微分:避免在需要梯度的计算图中使用原地操作
  5. 性能优化:对大规模数据使用 GPU 加速,对小规模操作可能 CPU 更高效
# 高效设备转移示例
if torch.cuda.is_available():tensor = tensor.to('cuda')# 保持数据类型一致
float_tensor = torch.rand(3, dtype=torch.float32)
int_tensor = torch.tensor([1, 2, 3], dtype=torch.int32)
result = float_tensor + int_tensor.float()  # 显式转换

掌握张量操作是使用 PyTorch 进行深度学习的基础。通过本文介绍的各种方法,您可以高效地创建、操作和转换张量,为构建复杂模型奠定坚实基础!

官方文档:https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html

http://www.lqws.cn/news/489493.html

相关文章:

  • WebSocket长连接在小程序中的实践:消息推送与断线重连机制设计
  • 全链接神经网络,CNN,RNN各自擅长解决什么问题
  • qt常用控件--02
  • uniapp+vue3做小程序,获取容器高度
  • 相机标定与3D重建技术通俗讲解
  • <tauri><threejs><rust><GUI>基于tauri和threejs,实现一个3D图形浏览程序
  • UE5 AnimMontage 的混合(Blend)模式
  • npm install时,遇到digital envelope routines::unsupported
  • BlazorWebView微软跨平台浏览器控件,UI组件
  • .NET多线程任务实现的几种方法及线程等待全面分析
  • Redis Stream 消息队列详解及 PHP 实现
  • 认识鸿蒙之了解应用结构
  • 关于华为Pura70Pro+升级鸿蒙NEXT和回退
  • 【Oracle篇】Windows平台单进程多线程架构设计与实现(比对Linux多进程架构)
  • 【Linux篇章】线程同步与互斥2:打破多线程并发困境,开启高效程序运行新境界
  • Gartner《Generative AI Use - Case Comparison for Legal Departments》
  • 【机器学习1】线性回归与逻辑回归
  • AI大模型之机器学习理论及实践:监督学习-机器学习的核心基石
  • 跟着AI学习C#之项目实践Day3
  • 【Linux网络编程】序列化与反序列化
  • 1个翠绿联网状态指示灯,闪烁未连接,常亮连接正常;软件如何实现
  • 浅析std::atomic<T>::compare_exchange_weak和std::atomic<T>::compare_exchange_strong
  • 【C++】C++中的虚函数和多态的定义与使用
  • AI 领航设计模式学习:飞算 JavaAI 解锁单例模式实践新路径
  • PROFIBUS DP转ETHERNET/IP在热电项目中的创新应用
  • WinUI3入门9:自制SplitPanel
  • Java基础(三):逻辑运算符详解
  • 提高WordPress网站加载速度和用户体验
  • C# SolidWorks二次开发-实战2,解决SolidWorks2024转step文件名乱码问题
  • 【25】木材表面缺陷数据集(有v5/v8模型)/YOLO木材表面缺陷检测