Torch:分割张量

我想将我的数据集(10,000个50x50的RGB图像)分成两个数据集,类似于:

X = torch.rand(10000,3,50,50)
inds =torch.randperm(X:size(1))[{{1, nTrain}}]:long()
X_selected=X:index(1, inds)
X_remaining=X:delete(1,inds)

无论我搜索什么,我都只会得到Torch的GitHub文档,我该如何做?

点赞
用户3754413
用户3754413

你可以试试这个方法

X = torch.rand(10000, 3, 50, 50)
inds = torch.randperm(X:size(1)):long()
train_inds = inds:narrow(1, 1, nTrain)
valid_inds = inds:narrow(1, nTrain + 1, X:size(1) - nTrain)
X_train = X:index(1, train_inds)
X_valid = X:index(1, valid_inds)
2017-04-17 15:39:51