Многомерная диагональная матрица, где диагональные элементы - это векторы строк или столбцов

У меня есть трехмерная матрица х с формой (2,4,6)

x = np.arange(2*4*3).reshape(2,4,3)
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]],

       [[12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [21, 22, 23]]])

Я хочу создать две диагональные матрицы ei и ej с формами (2,4,4,1,3) и (2,4,4,3,1), чтобы диагональные элементы ei и ej были векторами строк и столбцов элементы в х.

0-й элемент ei: ei [0]

0-й элемент ej: ej [0]

Например

ei = array([[
        [[[ 0.,  1.,  2.]],
         [[ 0.,  0.,  0.]],
         [[ 0.,  0.,  0.]],
         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],
         [[ 3.,  4.,  5.]],
         [[ 0.,  0.,  0.]],
         [[ 0.,  0.,  0.]]],

        ...

        [[[ 0.,  0.,  0.]],
         [[ 0.,  0.,  0.]],
         [[ 0.,  0.,  0.]],
         [[21., 22., 23.]]]]])

Мой наивный путь приведен ниже

ei = np.zeros([x.shape[0],x.shape[1],x.shape[1],1,x.shape[2]])
ej = np.zeros([x.shape[0],x.shape[1],x.shape[1],x.shape[2],1])

for j in range(x.shape[0]):
  for i in range(x.shape[1]):
    ei[j,i,i] = x[j,i]
    ej[j,i,i] = np.transpose([x[j,i]])

Есть ли альтернативный способ сделать то же самое?

Всего 2 ответа


Замените циклы np.arange индексированием np.arange :

In [260]: arr = np.arange(12).reshape(3,4)                                                                           
In [261]: res = np.zeros((3,3,4),int)                                                                                
In [262]: res[np.arange(3), np.arange(3),:] = arr                                                                    
In [263]: res                                                                                                        
Out[263]: 
array([[[ 0,  1,  2,  3],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0]],

       [[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [ 0,  0,  0,  0]],

       [[ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        [ 8,  9, 10, 11]]])

Мы можем определить n-мерную нулевую матрицу и заполнить ее диагональ, используя расширенную индексацию:

def to_nd_diagonal(x):
    *i, j, k = x.shape
    a = np.zeros((*i,j,j,k))
    I = np.arange(j)
    a[...,I,I,:] = x
    out=a[...,None]
    return out, out.swapaxes(-1,-2)

Где мы получаем:

ej, ei = to_nd_diagonal(x)

ei.shape
# (2, 4, 4, 1, 3)
ej.shape
# (2, 4, 4, 3, 1)

print(ei)

array([[[[[ 0.,  1.,  2.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[ 3.,  4.,  5.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 6.,  7.,  8.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 9., 10., 11.]]]],



       [[[[12., 13., 14.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[15., 16., 17.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[18., 19., 20.]],

         [[ 0.,  0.,  0.]]],


        [[[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[ 0.,  0.,  0.]],

         [[21., 22., 23.]]]]])

print(ej)

array([[[[[ 0.],
          [ 1.],
          [ 2.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[ 3.],
          [ 4.],
          [ 5.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 6.],
          [ 7.],
          [ 8.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 9.],
          [10.],
          [11.]]]],



       [[[[12.],
          [13.],
          [14.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[15.],
          [16.],
          [17.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[18.],
          [19.],
          [20.]],

         [[ 0.],
          [ 0.],
          [ 0.]]],


        [[[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[ 0.],
          [ 0.],
          [ 0.]],

         [[21.],
          [22.],
          [23.]]]]])

Есть идеи?

10000