Tensorflow embedding_lookup на нескольких измерениях

Я хотел бы выбрать часть этого тензора.

A = tf.constant([[[1,1],[2,2],[3,3]], [[4,4],[5,5],[6,6]]])

Выход А будет

[[[1 1]
  [2 2]
  [3 3]]

 [[4 4]
  [5 5]
  [6 6]]]

Индекс, который я хочу выбрать из A, равен [1, 0]. Я имею в виду [2 2] первой части и [4 4] второй части этого тензора, поэтому мой ожидаемый результат

[2 2]
[4 4]

Как я могу сделать это с помощью функции embedding_lookup?

B = tf.nn.embedding_lookup(A, [1, 0])

Я уже пробовал это

но это не мое ожидание.

[[[4 4]
  [5 5]
  [6 6]]

 [[1 1]
  [2 2]
  [3 3]]]

Может кто-нибудь помочь мне и объяснить, как это сделать?

Всего 1 ответ


Попробуйте следующее,

A = tf.constant([[[1,1],[2,2],[3,3]], [[4,4],[5,5],[6,6]]])
B = [1,0]
inds = [(a,b) for a,b in zip(np.arange(len(B)), B)]

C = tf.gather_nd(params=A, indices=inds)