博客
关于我
PyTorch系列
阅读量:270 次
发布时间:2019-03-01

本文共 5635 字,大约阅读时间需要 18 分钟。

第1章 PyTorch中的张量

1.1 概念与介绍

1.1.1 概念

张量是标量、向量、矩阵的高维拓展,统一了多维数据的表示方式,是PyTorch中核心的数据结构。

1.1.2 Tensor与Variable

在PyTorch中,Tensor是核心的数据类型,而Variable则是用于封装Tensor并支持梯度计算的数据结构。Variable主要包含以下属性:

  • data:被包装的Tensor
  • grad:Tensor的梯度
  • grad_fn:创建Tensor的函数
  • requires_grad:指示是否需要梯度
  • is_leaf:指示是否是叶子节点(不可拆分的张量)

从PyTorch 0.4.0版本开始,Variable与Tensor合并,Tensor已经内置了Variable的功能。Tensor的主要属性包括:

  • dtype:数据类型
  • shape:张量的形状
  • device:所在设备(CPU/GPU)

1.2 创建张量

1.2.1 直接创建

PyTorch提供了多种方法来创建张量:

使用torch.tensor()

  • 功能:从数据创建Tensor
  • 参数
    • data:数据来源,可以是list、ndarray或Tensor
    • dtype:数据类型,默认与data一致
    • device:所在设备
    • requires_grad:是否需要梯度
    • pin_memory:是否存于锁页内存

例子

import torch
import numpy as np
# 创建Tensor并移动到GPU
arr = np.ones((3, 3))
t = torch.tensor(arr, device='cuda')
print(t) # 输出: tensor([[1., 1., 1.], ...], device='cuda:0', dtype=torch.float64)

使用torch.from_numpy()

  • 功能:从numpy创建Tensor
  • 注意:与numpy数组共享内存,修改一方会影响另一方

例子

arr = np.array([[1, 2, 3], [4, 5, 6]])
t = torch.from_numpy(arr)
# 修改arr时,t也会随之改变
arr[0, 0] = 0
print(arr) # numpy array: [[0, 2, 3], ...]
print(t) # tensor: tensor([[0, 2, 3], ...], dtype=torch.int32)

1.2.2 依据数值创建

PyTorch提供了多种方法依据数值创建张量:

torch.zeros()

  • 功能:创建全0张量
  • 参数
    • size:张量的形状
    • device:所在设备
    • requires_grad:是否需要梯度

例子

t = torch.zeros((3, 3))
print(t) # tensor([[0., 0., 0.], ...], dtype=torch.float64)

torch.full()和torch.full_like()

  • 功能:创建全0张量
  • 参数
    • size:张量的形状
    • fill_value:填充值

例子

t = torch.full((3, 3), 1)
print(t) # tensor([[1., 1., 1.], ..., dtype=torch.float64)

torch.arange()

  • 功能:创建等差的一维张量
  • 参数
    • start:数列起始值
    • end:数列结束值
    • steps:步长,默认为1

例子

t = torch.arange(2, 10, 2)
print(t) # tensor([2, 4, 6, 8])

torch.linspace()

  • 功能:创建均分的一维张量
  • 参数
    • start:数列起始值
    • end:数列结束值
    • steps:数列长度

例子

t = torch.linspace(2, 10, 5)
print(t) # tensor([2., 4., 6., 8., 10.])

torch.logspace()

  • 功能:创建对数均分的一维张量
  • 参数
    • start:数列起始值
    • end:数列结束值
    • steps:数列长度
    • base:对数函数的底,默认为10

例子

t = torch.logspace(0, 10, 100, base=10)
print(t) # tensor([1., 10., 100., ...])

torch.eye()

  • 功能:创建单位对角矩阵
  • 参数
    • n:矩阵的行数
    • m:矩阵的列数,默认与n相同

例子

t = torch.eye(2, 2)
print(t) # tensor([[1., 0.], [0., 1.]], dtype=torch.float64)

1.2.3 依据概率分布创建

PyTorch提供了多种方法依据概率分布创建张量:

torch.normal()

  • 功能:生成正态分布(高斯分布)
  • 参数
    • mean:均值
    • std:标准差

例子

mean = torch.arange(1, 5, dtype=torch.float)
std = torch.arange(1, 5, dtype=torch.float)
t_normal = torch.normal(mean, std)
print(mean, std, t_normal) # 输出: mean: tensor([1., 2., 3., 4.]), std: tensor([1., 2., 3., 4.]), t_normal: tensor([...])

torch.randn()和torch.randn_like()

  • 功能:在区间[0, 1)上生成均匀分布的随机数
  • 参数
    • size:张量的形状

例子

t = torch.rand((3, 3))
print(t) # tensor([[0.682..., 0.682..., ...], ...], dtype=torch.float64)

torch.randint()和torch.randint_like()

  • 功能:在区间[low, high)生成整数均匀分布
  • 参数
    • low:区间下界
    • high:区间上界
    • size:张量的形状

例子

t = torch.randint(0, 9, (3, 3))
print(t) # tensor([[3, 8, 8], ..., dtype=torch.long)

torch.bernoulli()

  • 功能:生成比努力分布(0-1分布)
  • 参数
    • input:概率值

例子

t = torch.bernoulli(torch.tensor([0.5, 0.3, 0.7]))
print(t) # tensor([1, 0, 1], dtype=torch.long)

torch.randperm()

  • 功能:生成从0到n-1的随机排列
  • 参数
    • n:排列的长度

例子

t = torch.randperm(8)
print(t) # tensor([0, 6, 4, 7, 2, 5, 1, 3])

1.3 张量的操作

拼接与切分

PyTorch提供了多种方法进行张量的拼接和切分:

torch.cat()
  • 功能:按维度拼接张量
  • 参数
    • tensors:张量序列
    • dim:拼接的维度
例子
t = torch.ones((2, 3))
t0 = torch.cat([t, t], dim=0)
t1 = torch.cat([t, t], dim=1)
print(t0, t1) # 输出: tensor([[1., 1., 1.], ...], shape: [4, 3]), tensor([[1., 1., 1., 1., 1., 1.], ...], shape: [2, 6])
torch.stack()
  • 功能:按维度拼接张量
  • 参数
    • tensors:张量序列
    • dim:拼接的维度
例子
t = torch.stack([t, t, t], dim=0)
print(t) # tensor([[[1., 1., 1.], [1., 1., 1.]], ...], shape: [3, 2, 3])
torch.chunk()
  • 功能:按维度切分张量
  • 参数
    • input:要切分的张量
    • chunks:切分的份数
    • dim:切分的维度
例子
t = torch.ones((2, 5))
list_of_tensors = torch.chunk(t, dim=1, chunks=2)
for idx, t in enumerate(list_of_tensors):
print(t.shape)
# 输出: shape: [2, 3], [2, 2], [2, 1]
torch.split()
  • 功能:按维度切分张量
  • 参数
    • tensor:切分的张量
    • split_size_or_sections:切分的份数或切分长度
    • dim:切分的维度
例子
t = torch.ones((2, 5))
list_of_tensors = torch.split(t, 2, dim=1)
for idx, t in enumerate(list_of_tensors):
print(t.shape)
# 输出: shape: [2, 2], [2, 2], [2, 1]

索引

PyTorch提供了多种方法进行张量的索引操作:

torch.index_select()
  • 功能:按维度索引张量
例子
t = torch.randint(0, 9, (3, 3))
idx = torch.tensor([0, 2], dtype=torch.long)
t_select = torch.index_select(t, dim=1, index=idx)
print(t, t_select) # 输出: tensor([[5, 8, 8], ...], tensor([[5, 8], ..., [1, 3]])
torch.masked_select()
  • 功能:按mask中的True索引张量
  • 参数
    • input:要索引的张量
    • mask:与input同形状的布尔张量
例子
t = torch.randint(0, 9, (3, 3))
mask = t.le(5)
t_select = torch.masked_select(t, mask)
print(t, t_select) # 输出: tensor([[5, 8, 8], ...], tensor([5, 8, 8]))

变换

PyTorch提供了多种方法进行张量的变换:

torch.reshape()
  • 功能:变换张量的形状
例子
t = torch.randperm(8)
t_reshape = torch.reshape(t, (-1, 2, 2))
print(t, t_reshape) # 输出: tensor([...], shape: [8]), tensor([...], shape: [2, 4])
torch.transpose()
  • 功能:交换张量的维度
例子
t = torch.rand((2, 3, 4))
t_transpose = torch.transpose(t, 1, 2)
print(t, t_transpose) # 输出: tensor([...], shape: [2, 3, 4]), tensor([...], shape: [2, 4, 3])
torch.t()
  • 功能:将二维张量转置
例子
t = torch.rand((1, 2))
t_t = torch.t(t)
print(t, t_t) # 输出: tensor([...], shape: [2, 1]), tensor([...], shape: [1, 2])
torch.squeeze()
  • 功能:压缩长度为1的维度
例子
t = torch.rand((1, 2, 3, 1))
t_sq = torch.squeeze(t)
t_sq1 = torch.squeeze(t, dim=0)
t_sq2 = torch.squeeze(t, dim=1)
print(t.shape, t_sq.shape, t_sq1.shape, t_sq2.shape) # 输出: (1, 2, 3, 1), (2, 3), (2, 3, 1), (1, 2, 3, 1)
torch.unsqueeze()
  • 功能:扩展张量的维度
例子
t = torch.rand((2, 3))
t_unsqueezed = torch.unsqueeze(t, dim=1)
print(t.shape, t_unsqueezed.shape) # 输出: (2, 3), (2, 1, 3)

数学运算

PyTorch提供了丰富的数学运算功能,包括加减乘除、对数、指数、幂函数、三角函数等。以下是常用方法示例:

加减乘除
  • torch.add():逐元素加法
  • torch.sub():逐元素减法
  • torch.mul():逐元素乘法
  • torch.div():逐元素除法
对数、指数、幂函数
  • torch.log():对数
  • torch.exp():指数
  • torch.pow():幂运算
三角函数
  • torch.acos():反余弦
  • torch.asin():反正弦
  • torch.atan():反正切

这些操作都可以按元素方式进行,适用于张量计算。

转载地址:http://iedo.baihongyu.com/

你可能感兴趣的文章
Nodejs异步回调的处理方法总结
查看>>
NodeJS报错 Fatal error: ENOSPC: System limit for number of file watchers reached, watch ‘...path...‘
查看>>
Nodejs教程09:实现一个带接口请求的简单服务器
查看>>
nodejs服务端实现post请求
查看>>
nodejs框架,原理,组件,核心,跟npm和vue的关系
查看>>
Nodejs模块、自定义模块、CommonJs的概念和使用
查看>>
nodejs生成多层目录和生成文件的通用方法
查看>>
nodejs端口被占用原因及解决方案
查看>>
Nodejs简介以及Windows上安装Nodejs
查看>>
nodejs系列之express
查看>>
nodejs系列之Koa2
查看>>
Nodejs连接mysql
查看>>
nodejs连接mysql
查看>>
NodeJs连接Oracle数据库
查看>>
nodejs配置express服务器,运行自动打开浏览器
查看>>
Nodemon 深入解析与使用
查看>>
node~ http缓存
查看>>
node不是内部命令时配置node环境变量
查看>>
node中fs模块之文件操作
查看>>
Node中同步与异步的方式读取文件
查看>>