Torch - 像numpy repeat那样重复张量

我正在尝试用两种方法在torch中重复张量。例如将张量{1,2,3,4}两种方式各重复3次得到;

{1,2,3,4,1,2,3,4,1,2,3,4}
{1,1,1,2,2,2,3,3,3,4,4,4}

有一个内置的torch:repeatTensor函数将生成前面的一个(像numpy.tile()),但我找不到后面的一个功能(像numpy.repeat())。我相信我可以对第一个调用sort来得到第二个,但对于更大的数组来说可能计算代价比较高?

谢谢。

点赞
用户1257954
用户1257954
a = torch.Tensor([1,2,3,4])

想要得到 [1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.],我们需要将张量在第1维度上重复三次:

a.repeat(3)

想要得到 [1,1,1,2,2,2,3,3,3,4,4,4],我们需要在张量上添加一个维度并在第2维度上重复三次,以得到一个4 x 3的张量,然后再将其压平。

b = a.reshape(4,1).repeat(1,3).flatten()

或者

b = a.reshape(4,1).repeat(1,3).view(-1)
2016-02-14 03:15:42
用户4982729
用户4982729

以下是https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853中的引用 -

z = torch.FloatTensor([[1,2,3],[4,5,6],[7,8,9]])
1 2 3
4 5 6
7 8 9
z.transpose(0,1).repeat(1,3).view(-1, 3).transpose(0,1)
1 1 1 2 2 2 3 3 3
4 4 4 5 5 5 6 6 6
7 7 7 8 8 8 9 9 9

这将让你对它的工作原理有直观的感觉。

2018-09-10 13:38:47
用户1155034
用户1155034

以下是一个通用的函数,用于在张量中重复元素。

def repeat(tensor, dims):
    if len(dims) != len(tensor.shape):
        raise ValueError("第二个参数的长度必须等于第一个参数的维度数。")
    for index, dim in enumerate(dims):
        repetition_vector = [1]*(len(dims)+1)
        repetition_vector[index+1] = dim
        new_tensor_shape = list(tensor.shape)
        new_tensor_shape[index] *= dim
        tensor = tensor.unsqueeze(index+1).repeat(repetition_vector).reshape(new_tensor_shape)
    return tensor

如果你有

foo = tensor([[1, 2],
              [3, 4]])

通过调用 repeat(foo, [2,1]),你会得到

tensor([[1, 2],
        [1, 2],
        [3, 4],
        [3, 4]])

因此,您已经沿着维度0复制了每个元素,同时保留了维度1上的元素。

2019-04-23 12:34:27
用户6830722
用户6830722

你能试试像这样的东西吗:

import torch as pt

#1 与numpy中的tile函数相似

b = pt.arange(10)
print(b.repeat(3))

#2 与numpy中的tile函数相似

b = pt.tensor(1).repeat(10).reshape(2,-1)
print(b)

#3 与numpy中的repeat函数相似

t = pt.tensor([1,2,3])
t.repeat(2).reshape(2,-1).transpose(1,0).reshape(-1)
2019-05-16 05:44:37
用户6731799
用户6731799

尝试使用torch.repeat_interleave()方法:https://pytorch.org/docs/stable/torch.html#torch.repeat_interleave

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])
2020-01-03 11:08:01
用户498892
用户498892

使用 einops:

from einops import repeat

repeat(x, 'i -> (repeat i)', repeat=3)
# 结果: {1,2,3,4,1,2,3,4,1,2,3,4}

repeat(x, 'i -> (i repeat)', repeat=3)
# 结果: {1,1,1,2,2,2,3,3,3,4,4,4}

这段代码对于任何框架(如numpy,torch,tf等)均适用。

2020-08-30 00:06:27