Torch: 如何通过行洗牌张量?
2021-8-20 10:24:35
收藏:0
阅读:130
评论:4
我目前正在使用 Torch 实现一些输入数据的随机洗牌(在这种情况下是按行或第一维)。但是我新手 Torch,因此在 figuring out 如何使用排列时遇到了一些问题。
以下被认为是洗牌数据的:
if argshuffle then
local perm = torch.randperm(sids:size(1)):long()
print("\n\n\nSize of X and y before")
print(X:view(-1, 1000, 128):size())
print(y:size())
print(sids:size())
print("\nPerm size is: ")
print(perm:size())
X = X:view(-1, 1000, 128)[{{perm},{},{}}]
y = y[{{perm},{}}]
print(sids[{{1}, {}}])
sids = sids[{{perm},{}}]
print(sids[{{1}, {}}])
print(X:size())
print(y:size())
print(sids:size())
os.exit(69)
end
这会输出
Size of X and y before
99
1000
128
[torch.LongStorage of size 3]
99
1
[torch.LongStorage of size 2]
99
1
[torch.LongStorage of size 2]
Perm size is:
99
[torch.LongStorage of size 1]
5
[torch.LongStorage of size 1x1]
5
[torch.LongStorage of size 1x1]
99
1000
128
[torch.LongStorage of size 3]
99
1
[torch.LongStorage of size 2]
99
1
[torch.LongStorage of size 2]
从这些值中,我可以推断出该函数未对数据进行洗牌。我该如何正确洗牌,lua/torch 的常见解决方案是什么?
点赞
用户7630458
我也遇到了类似的问题。文档中没有针对张量的洗牌功能(对于数据集加载器,有相应的洗牌功能)。我用 torch.randperm 找到了解决这个问题的方法。
>>> a=torch.rand(3,5)
>>> print(a)
tensor([[0.4896, 0.3708, 0.2183, 0.8157, 0.7861],
[0.0845, 0.7596, 0.5231, 0.4861, 0.9237],
[0.4496, 0.5980, 0.7473, 0.2005, 0.8990]])
>>> # 行洗牌
...
>>> a=a[torch.randperm(a.size()[0])]
>>> print(a)
tensor([[0.4496, 0.5980, 0.7473, 0.2005, 0.8990],
[0.0845, 0.7596, 0.5231, 0.4861, 0.9237],
[0.4896, 0.3708, 0.2183, 0.8157, 0.7861]])
>>> # 列洗牌
...
>>> a=a[:,torch.randperm(a.size()[1])]
>>> print(a)
tensor([[0.2005, 0.7473, 0.5980, 0.8990, 0.4496],
[0.4861, 0.5231, 0.7596, 0.9237, 0.0845],
[0.8157, 0.2183, 0.3708, 0.7861, 0.4896]])
希望它能解决这个问题!
2018-11-13 15:47:07
用户10531501
根据您的语法,我认为您是在使用带有lua的torch而不是PyTorch。 [torch.Tensor.index](https://github.com/torch/torch7/blob/master/doc/tensor.md)是您的函数,它的工作方式如下:
x = torch.rand(4, 4)
p = torch.randperm(4)
print(x)
print(p)
print(x:index(1,p:long())
2019-07-27 08:35:54
用户9067615
如果你的张量的形状是 CxNxF(通道数 by 行数 by 特征数),那么你可以通过以下方式沿着第二个维度进行洗牌:
dim = 1
idx = torch.randperm(t.shape[dim])
t_shuffled = t[:,idx]
2021-06-27 15:22:37
评论区的留言会收到邮件通知哦~
推荐文章
- Lua 虚拟机加密load(string.dump(function)) 后执行失败问题如何解决
- 我想创建一个 Nginx 规则,禁止访问
- 如何将两个不同的lua文件合成一个 东西有点长 大佬请耐心看完 我是小白研究几天了都没搞定
- 如何在roblox studio中1:1导入真实世界的地形?
- 求解,lua_resume的第二次调用继续执行协程问题。
- 【上海普陀区】内向猫网络招募【Skynet游戏框架Lua后端程序员】
- SF爱好求教:如何用lua实现游戏内调用数据库函数实现账号密码注册?
- Lua实现网站后台开发
- LUA错误显式返回,社区常见的规约是怎么样的
- lua5.3下载库失败
- 请问如何实现文本框内容和某个网页搜索框内容连接,并把网页输出来的结果反馈到另外一个文本框上
- lua lanes多线程使用
- 一个kv数据库
- openresty 有没有比较轻量的 docker 镜像
- 想问一下,有大佬用过luacurl吗
- 在Lua执行过程中使用Load函数出现问题
- 为什么 neovim 里没有显示一些特殊字符?
- Lua比较两个表的值(不考虑键的顺序)
- 有个lua简单的项目,外包,有意者加微信 liuheng600456详谈,最好在成都
- 如何在 Visual Studio 2022 中运行 Lua 代码?

一种简单的解决方法是使用置换矩阵(在线性代数中通常使用)。由于您似乎对三维情况感兴趣,我们必须先展开您的三维张量。因此,以下是我想出的一个示例代码(可直接使用)
data=torch.floor(torch.rand(5,3,2)*100):float() reordered_data=data:view(5,-1) perm=torch.randperm(5); perm_rep=torch.repeatTensor(perm,5,1):transpose(1,2) indexes=torch.range(1,5); indexes_rep=torch.repeatTensor(indexes,5,1) permutation_matrix=indexes_rep:eq(perm_rep):float() permuted=permutation_matrix*reordered_data print("perm") print(perm) print("before permutation") print(data) print("after permutation") print(permuted:view(5,3,2))从执行结果可以看出,它根据
perm中给定的行索引重新排序了张量data。