When a list containing several NDArrays is passed directly to nd.array to convert it into a single NDArray, MXNet raises TypeError: source_array must be array like object. The code is as follows:
from mxnet import nd
import numpy as np
a = nd.ones((3,3))
print(a)
b = nd.zeros((3,3))
print(b)
c = nd.array([a,b])
print(c)
Console output:
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
<NDArray 3x3 @cpu(0)>
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
<NDArray 3x3 @cpu(0)>
Traceback (most recent call last):
  File "/home/zw/anaconda3/envs/arcFace/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 2501, in array
    source_array = np.array(source_array, dtype=dtype)
ValueError: setting an array element with a sequence.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/zw/workspace/dtn-antispoofing/DTN.py", line 202, in <module>
    c = nd.array([a,b])
  File "/home/zw/anaconda3/envs/arcFace/lib/python3.6/site-packages/mxnet/ndarray/utils.py", line 146, in array
    return _array(source_array, ctx=ctx, dtype=dtype)
  File "/home/zw/anaconda3/envs/arcFace/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 2503, in array
    raise TypeError('source_array must be array like object')
TypeError: source_array must be array like object
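Solution
As the traceback shows, nd.array falls back to np.array(source_array) for anything that is not already an NDArray, and NumPy cannot build an array from a list of NDArray objects. Convert each NDArray to a NumPy array first with its asnumpy() method, then pass the resulting list to nd.array() to get an NDArray.
Complete code: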
from mxnet import nd
import numpy as np
a = nd.ones((3,3))
print(a)
b = nd.zeros((3,3))
print(b)
c = nd.array([a.asnumpy(),b.asnumpy()])
print(c)
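The same fix generalizes to a list of any length. Below is a minimal sketch, assuming every element is either a NumPy array or an object with an asnumpy() method like NDArray; the helper name stack_as_numpy is hypothetical, not part of MXNet. It only needs NumPy, so it can be tried without MXNet installed.

```python
import numpy as np


def stack_as_numpy(arrays):
    """Hypothetical helper: convert each element to a NumPy array and stack
    them along a new leading axis.

    Elements with an asnumpy() method (such as an MXNet NDArray) are converted
    with it; anything else is passed through np.asarray.
    """
    converted = [x.asnumpy() if hasattr(x, "asnumpy") else np.asarray(x) for x in arrays]
    # np.stack requires identical shapes and raises ValueError otherwise.
    return np.stack(converted, axis=0)
```

With MXNet installed, c = nd.array(stack_as_numpy([a, b])) then yields the same 2x3x3 NDArray as the code above. Stacking in NumPy before calling nd.array means MXNet receives one regular array instead of a list of objects it cannot interpret.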
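If the NDArrays have the same shape, you can also join them directly in MXNet without going through NumPy. Put the arrays themselves in the list: d.extend(a) would append the rows of a one by one rather than a itself. Note that nd.concat joins along an existing axis, so concatenating two 3x3 arrays along dim=0 gives a 6x3 array, not the 2x3x3 array produced above; nd.stack adds a new axis and gives the 2x3x3 result:
d = [a, b]
c = nd.concat(*d, dim=0)  # shape (6, 3): rows of a followed by rows of b
print(c.shape)
e = nd.stack(*d, axis=0)  # shape (2, 3, 3), same as nd.array([a.asnumpy(), b.asnumpy()])
print(e.shape)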
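The difference between joining along an existing axis and stacking along a new one is easy to check with NumPy, whose np.concatenate and np.stack follow the same shape rules as nd.concat and nd.stack. The sketch below uses NumPy arrays as stand-ins for the NDArrays a and b, and also shows the list.extend pitfall: iterating over a 2D array yields its rows, so concatenating the extended list flattens everything into one axis.

```python
import numpy as np

# NumPy stand-ins for the 3x3 NDArrays a and b
a = np.ones((3, 3))
b = np.zeros((3, 3))

# concatenate joins along an existing axis: 3 + 3 rows -> (6, 3)
joined = np.concatenate([a, b], axis=0)
# stack adds a new leading axis: two 3x3 arrays -> (2, 3, 3)
stacked = np.stack([a, b], axis=0)

# Pitfall: list.extend iterates over each 2D array and appends its rows
# (each of shape (3,)) instead of the array itself.
d = []
d.extend(a)
d.extend(b)
flat = np.concatenate(d, axis=0)  # six 1D rows joined end to end -> (18,)

print(joined.shape, stacked.shape, len(d), flat.shape)
# prints: (6, 3) (2, 3, 3) 6 (18,)
```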