小能豆

使用转置进行 Numpy 数组乘法

py

我有以下 numpy 数组

import numpy as np
A = np.array([[0.5, 0.5]])

现在我想计算 A^t*A,为此我想到了以下内容

np.dot(A.T,A)

我想要得到的是一个数组形式

A_new = np.array([[0.0025, 0.0025], [0.0025,0.0025]])

但我实际上得到的只是一个数字

A_new = 0.005

我该如何进行这种数组乘法?2x1 形状乘以 1x2 形状不应该得到 2x2 形状吗?


阅读 19

收藏
2024-12-29

共1个答案

小能豆

你遇到的结果并不是一个数字,而是形状为 (1, 1) 的二维数组。让我们详细分析这个问题并提供解决方案。

原因分析

  1. 你的数组 A 的形状是 (1, 2)(1 行 2 列)。
    py A = np.array([[0.5, 0.5]]) print(A.shape) # (1, 2)

  2. 当你使用 A.T,你得到的是形状为 (2, 1) 的转置矩阵。

  3. np.dot(A.T, A) 对应的是矩阵乘法 (2, 1)(1, 2),结果是一个形状为 (2, 2) 的矩阵,如你预期:
    py result = np.dot(A.T, A) print(result) # 输出: # [[0.25 0.25] # [0.25 0.25]]

  4. 你的代码中得到单个值的原因可能是输入了错误的 A,比如形状为 (2,) 的 1D 数组。这会让 np.dot(A.T, A) 变成向量点积,返回一个标量。

检查问题

确保你的 A 是二维数组 (1, 2),而不是一维数组。可以使用以下方式强制二维形状:

A = np.array([[0.5, 0.5]])  # 明确为二维数组

或者:

A = np.array([0.5, 0.5]).reshape(1, -1)  # 使用 reshape 强制二维

解决方案

计算 A^T * A 的正确代码如下:

import numpy as np

A = np.array([[0.5, 0.5]])  # 确保是二维数组
A_new = np.dot(A.T, A)
print(A_new)
# 输出:
# [[0.25 0.25]
#  [0.25 0.25]]

其他有用的操作

如果你使用的是较新的 NumPy 版本,可以直接使用 @ 符号进行矩阵乘法:

A_new = A.T @ A
print(A_new)
# 输出:
# [[0.25 0.25]
#  [0.25 0.25]]

结论

  • 确保你的输入 A 是二维数组 (1, 2)
  • 使用 np.dot(A.T, A)A.T @ A 来进行矩阵乘法。
  • 如果输入是错误的(如一维数组),结果会和你的预期不同。
2024-12-29