Использование np.newaxis для вычисления суммы квадратов разностей

В главе 2 "Справочника по науке о данных Python" Джейк Вандерплас, он вычисляет сумму квадратов разностей нескольких двумерных точек, используя следующий код:

rand = np.random.RandomState(42)
X = rand.rand(10,2)
dist_sq = np.sum(X[:,np.newaxis,:] - X[np.newaxis,:,:]) ** 2, axis=-1)

Два вопроса:

  1. Почему создается третья ось? Как лучше всего представить себе, что происходит?
  2. Есть ли более интуитивный способ выполнить этот расчет?

Всего 1 ответ


Почему создается третья ось? Как лучше всего визуализировать происходящее?

Уловка добавления новых размеров перед добавлением / вычитанием является относительно распространенной уловкой для генерации всех пар с помощью широковещательной передачи (None здесь то же, что и np.newaxis):

>>> a = np.arange(10)
>>> a[:,None]
array([[0],
       [1],
       [2],
       [3],
       [4],
       [5],
       [6],
       [7],
       [8],
       [9]])

>>> a[None,:]
array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

>>> a[:,None] + 100*a[None,:]
array([[  0, 100, 200, 300, 400, 500, 600, 700, 800, 900],
       [  1, 101, 201, 301, 401, 501, 601, 701, 801, 901],
       [  2, 102, 202, 302, 402, 502, 602, 702, 802, 902],
       [  3, 103, 203, 303, 403, 503, 603, 703, 803, 903],
       [  4, 104, 204, 304, 404, 504, 604, 704, 804, 904],
       [  5, 105, 205, 305, 405, 505, 605, 705, 805, 905],
       [  6, 106, 206, 306, 406, 506, 606, 706, 806, 906],
       [  7, 107, 207, 307, 407, 507, 607, 707, 807, 907],
       [  8, 108, 208, 308, 408, 508, 608, 708, 808, 908],
       [  9, 109, 209, 309, 409, 509, 609, 709, 809, 909]])

Ваш пример делает то же самое, только с 2-векторами вместо скаляров на самом внутреннем уровне:

>>> X[:,np.newaxis,:].shape
(10, 1, 2)

>>> X[np.newaxis,:,:].shape
(1, 10, 2)

>>> (X[:,np.newaxis,:] - X[np.newaxis,:,:]).shape
(10, 10, 2)

Таким образом, мы находим, что «магическое вычитание» - это просто все комбинации координат X, вычитаемые друг из друга.

Есть ли более интуитивный способ выполнить этот расчет?

Да, используйте scipy.spatial.distance.pdist для попарных расстояний. Чтобы получить результат, эквивалентный вашему примеру:

from scipy.spatial.distance import pdist, squareform
dist_sq = squareform(pdist(X))**2

Есть идеи?

10000