TensorFlow中的双线性张量积

问题描述 投票:16回答:2

我正在努力重新实施this paper,关键操作是双线性张量产品。我几乎不知道这意味着什么,但是纸张有一个很好的小图形,我明白了。

enter image description here

关键操作是e_1 * W * e_2,我想知道如何在tensorflow中实现它,因为其余部分应该很容易。

基本上,给定3D张量W,将其切片为矩阵,对于第j个切片(矩阵),将其在每一侧乘以e_1和e_2,得到标量,这是结果向量中的第j个条目(输出此操作)。

所以我想要执行e_1的乘积,d维向量,W,d x d x k张量,以及e_2,另一个d维向量。这个产品能否像现在一样在TensorFlow中简明扼要地表达,还是我必须以某种方式定义自己的操作?

早期编辑

为什么这些张量的乘法不起作用,是否有某种方法可以更明确地定义它以使其有效?

>>> import tensorflow as tf
>>> tf.InteractiveSession()
>>> a = tf.ones([3, 3, 3])
>>> a.eval()
array([[[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]],

       [[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]],

       [[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]]], dtype=float32)
>>> b = tf.ones([3, 1, 1])
>>> b.eval()
array([[[ 1.]],

       [[ 1.]],

       [[ 1.]]], dtype=float32)
>>> 

错误消息是

ValueError: Shapes TensorShape([Dimension(3), Dimension(3), Dimension(3)]) and TensorShape([Dimension(None), Dimension(None)]) must have the same rank

目前

事实证明,乘以两个3D张量不能与tf.matmul一起使用,所以tf.batch_matmul可以。 tf.batch_matmul也会做3D张量和矩阵。然后我尝试了3D和矢量:

ValueError: Dimensions Dimension(3) and Dimension(1) are not compatible
tensorflow
2个回答
12
投票

您可以通过简单的重塑来完成此操作。对于两个矩阵乘法中的第一个,你有k * d,长度为d的矢量与dot product。

这应该是接近的:

temp = tf.matmul(E1,tf.reshape(Wddk,[d,d*k]))
result = tf.matmul(E2,tf.reshape(temp,[d,k]))

0
投票

您可以在W和e2之间执行秩3张量和矢量乘法,从而产生2D阵列,然后将结果乘以e1。以下函数利用张量积和张量收缩来定义该乘积(例如W * e3)

import sympy as sp

def tensor3_vector_product(T, v):
    """Implements a product of a rank 3 tensor (3D array) with a 
       vector using tensor product and tensor contraction.

    Parameters
    ----------
    T: sp.Array of dimensions n x m x k

    v: sp.Array of dimensions k x 1

    Returns
    -------
    A: sp.Array of dimensions n x m

    """
    assert(T.rank() == 3)
    # reshape v to ensure a 1D vector so that contraction do 
    # not contain x 1 dimension
    v.reshape(v.shape[0], )
    p = sp.tensorproduct(T, v)
    return sp.tensorcontraction(p, (2, 3))

您可以使用ref中提供的示例验证此乘法。上面的函数收缩了第二和第三轴,在你的情况下我认为你应该收缩(1,2)因为W被定义为d x d x k而不是k x d x d,就像我的情况一样。

© www.soinside.com 2019 - 2024. All rights reserved.