Переместить ось в тензорном потоке

У меня есть два тензора. Основной тензор выглядит следующим образом:

array([[[ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217]],

       [[ 450,  607,  493,  662],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[ 950, 1277, 1028, 1335],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]]], dtype=int32)

Я хочу переместить этот тензор в соответствии со следующим тензором:

array([0, 2, 5], dtype=int32)

Вышеупомянутый тензор содержит ось, к которой мы хотим, чтобы текущая ось двигалась.

Конечный тензор должен выглядеть так:

array([[[ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217],
        [ 298, 1217,  298, 1217]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[ 450,  607,  493,  662],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[ 950, 1277, 1028, 1335],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]],

       [[   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0],
        [   0,    0,    0,    0]]], dtype=int32)

Всего 1 ответ


Для этого вы можете использовать функцию разброса tf.scatter_nd .

Определите ваш input тензор:

input = tf.constant([[[ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217]],

   [[ 450,  607,  493,  662],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[ 950, 1277, 1028, 1335],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]]])

Поскольку нас интересуют только первые 3 элемента по нулевому измерению, давайте разделим его на новый тензор:

sliced_input = tf.slice(input, [0, 0, 0], [3, -1, -1])

Определите ваши целевые indices :

indices = tf.constant([[0], [2], [5]])

Определите shapes вашего целевого output , здесь так же, как ваша input форма:

shape = tf.shape(input)

Теперь используйте функцию разброса, чтобы получить ваш output :

output = tf.scatter_nd(indices, sliced_input, shape)

output :

array([[[ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217],
    [ 298, 1217,  298, 1217]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[ 450,  607,  493,  662],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[ 950, 1277, 1028, 1335],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]],

   [[   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0],
    [   0,    0,    0,    0]]], dtype=int32)

Есть идеи?

10000