我有一个由单个矩阵的下对角线元素组成的数组。我可以按照如何在 NumPy 中将三角形矩阵转换为正方形中的方法将其转换为完整矩阵?. 对于单个矩阵,示例如下所示:
# Create the lower diagonal elements of a 6x6 matrix. ld = np.arange(21) # Create full 6x6 matrix x = np.zeros((6,6)) # Stuff lower triangular values into it x[np.tril_indices(6)] = ld # Populate upper triangular elements x = x + x.T # Fix diagonals (they got doubled) diag_idx = [0, 2, 5, 9, 14, 20] np.fill_diagonal(x, ld[diag_idx]) print(x)
我们得到了预期的完整矩阵
[[ 0. 1. 3. 6. 10. 15.] [ 1. 2. 4. 7. 11. 16.] [ 3. 4. 5. 8. 12. 17.] [ 6. 7. 8. 9. 13. 18.] [10. 11. 12. 13. 14. 19.] [15. 16. 17. 18. 19. 20.]]
现在我想将其扩展到具有 N 组下对角线元素的数组,并想要返回 N 个完整矩阵的数组。前者有形状(N, 21),后者有形状(N, 6, 6)。我将单个矩阵示例扩展为包含 2 个矩阵的示例
# Two sets of lower diagonal elements ld = np.arange(2*21).reshape(2, 21) # Two sets of full 6x6 matrices x = np.zeros((ld.shape[0], 6, 6)) # Find the lower triangular indices of each row and stuff them with the # values from the corresponding row in the lower diagonal array x[:, np.tril_indices(6)] = ld[:] # Populate upper triangular elements x[:] = x[:] + x[:].T # Fix diagonals (they got doubled) diag_idx = [0, 2, 5, 9, 14, 20] np.fill_diagonal(x[:], ld[:][diag_idx])
但我在线上出现形状不匹配x[:, np.tril_indices(6)] = ld[:]
x[:, np.tril_indices(6)] = ld[:]
ValueError: shape mismatch: value array of shape (2,21) could not be broadcast to indexing result of shape (2,2,21,6)
我可以对 N 组较低的对角线值进行正常的 Python 循环,但我试图通过 Numpy 来完成这一切。关于我的索引出错的地方有什么建议吗?
X 中的期望值为:
[[[ 0. 1. 3. 6. 10. 15.] [ 1. 2. 4. 7. 11. 16.] [ 3. 4. 5. 8. 12. 17.] [ 6. 7. 8. 9. 13. 18.] [10. 11. 12. 13. 14. 19.] [15. 16. 17. 18. 19. 20.]], [[21., 22., 24., 27., 31., 36.], [22., 23., 25., 28., 32., 37.], [24., 25., 26., 29., 33., 38.], [27., 28., 29., 30., 34., 39.], [31., 32., 33., 34., 35., 40.], [36., 37., 38., 39., 40., 41.]]]
你可以这样做,适用于任何N.
N
import numpy as np N = 3 ld = np.arange(N*21).reshape(N, 21) x = np.zeros((ld.shape[0], 6, 6)) tril_ind = np.tril_indices(6) x[:, tril_ind[0], tril_ind[1]] = ld x += np.transpose(x, (0, 2, 1)) diag_ind = np.diag_indices(6) x[:, diag_ind[0], diag_ind[1]] /= 2 print(x)
这打印
[[[ 0. 1. 3. 6. 10. 15.] [ 1. 2. 4. 7. 11. 16.] [ 3. 4. 5. 8. 12. 17.] [ 6. 7. 8. 9. 13. 18.] [10. 11. 12. 13. 14. 19.] [15. 16. 17. 18. 19. 20.]] [[21. 22. 24. 27. 31. 36.] [22. 23. 25. 28. 32. 37.] [24. 25. 26. 29. 33. 38.] [27. 28. 29. 30. 34. 39.] [31. 32. 33. 34. 35. 40.] [36. 37. 38. 39. 40. 41.]] [[42. 43. 45. 48. 52. 57.] [43. 44. 46. 49. 53. 58.] [45. 46. 47. 50. 54. 59.] [48. 49. 50. 51. 55. 60.] [52. 53. 54. 55. 56. 61.] [57. 58. 59. 60. 61. 62.]]]