admin管理员组

文章数量:1611403

成功解决
RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

报错内容

程序在这一步报错
objv_all = torch.cat((objv_all, objv))

于是分别打印出了objv_all, objv
结果:

原来是输入一个是数一个是列表

原因分析

直接上代码

① 两个变量都是tensor([])的形式

import torch
objv_all = torch.tensor([0.1])
objv = torch.tensor([0.2])
print(objv_all)
print(objv)
objv_all = torch.cat((objv_all, objv))
print(objv_all)

运行结果:
tensor([0.1000])
tensor([0.2000])
tensor([0.1000, 0.2000])
结果分析:
如果两个变量都是tensor([])的形式,则不报错。

② 两个变量都是tensor([])的形式

import torch
objv_all = torch.tensor(0.1)
objv = torch.tensor(0.2)
print(objv_all)
print(objv)
objv_all = torch.cat((objv_all, objv))
print(objv_all)

运行结果

此时报错

③两个变量是tensor()和tensor([])的形式

import torch
objv_all = torch.tensor(0.1)
objv = torch.tensor([0.2])
print(objv_all)
print(objv)
objv_all = torch.cat((objv_all, objv))
print(objv_all)

运行结果:

此时也报错

解决方法

由上分析可得,如果两个变量都是**tensor([])**的形式,则不报错。
有一方不是这个形式,则objv_all = torch.cat((objv_all, objv))则会报错。

将tensor(0.1)转化为tensor([0.1])的代码如下:

【使用np.expand_dims扩充维度】

import numpy as np
objv_all = torch.tensor(0.1).cuda()
print(objv_all)
objv_all = objv_all.cpu()
objv_all = np.expand_dims(objv_all, 0)
print(objv_all)
objv_all = torch.tensor(objv_all).cuda()
print(objv_all)

运行结果:

将③代码修改成正确代码

import torch
import numpy as np
objv_all = torch.tensor(0.1)
objv = torch.tensor([0.2])
objv_all = np.expand_dims(objv_all, 0)
print(objv_all)
objv_all = torch.tensor(objv_all)
print(objv_all)
print(objv)
objv_all = torch.cat((objv_all, objv))
print(objv_all)

运行结果:

修改成功!不报错了!

本文标签: DimensionalRuntimeErrorTensorconcatenatedPosition