小能豆

matshow 的问题

py

我正在尝试使用 matplotlib matshow 一起显示矩阵和相关矢量数据。

vec_data = np.array([[ 0.,  1.,  1.,  1.,  0.,  1.,  1.,  0.,  0.,  0.]])

mat_data = np.array([
       [ 0. ,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0. ],
       [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  1. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0.5,  0. ,  0.5,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  1. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0.1,  0.1,  0.1,  0.1,  0. ,  0.1,  0.1,  0.1,  0. ,  0.1],
       [ 0. ,  0. ,  1. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  1. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0. ,  0.1,  0.1],
       [ 0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0.1,  0. ,  0. ],
       [ 0.1,  0.1,  0.1,  0.1,  0. ,  0.1,  0.1,  0.1,  0.1,  0. ]])

fig, axes = plt.subplots(2,1,figsize=(4,4),sharey=False,sharex=True,gridspec_kw = {'height_ratios':[25,1]})
axes[0].matshow(mat_data)
axes[1].matshow(vec_data)
axes[1].tick_params(direction='out', length=6, width=0)
axes[1].set_yticklabels([''])
axes[1].set_xlabel('vector')

生成的图像如下:

1.png

这里的问题是,当将这两个 matshow 图像放在一起时,第一个图像的 ylim 会混乱:它应该显示从 0 到 9 的值,但它只显示 0.5 到 8.5 的范围。如果我使用命令单独绘制图像

plt.matshow(mat_data)

我通过正确的 ylim 获得了所需的图像。

2.png

有人知道是什么原因导致了这个问题,我该如何解决?我尝试使用

axes[0].set_ylim([-0.5,9.5])

但它不起作用。

PS:我使用了关键字 gridspec_kw = {‘height_ratios’:[25,1]},以便矢量显示为矢量 - 否则它将显示为具有空白值的矩阵,如下所示。

3.png

plt.subplots 使用参数 sharex = True 来对齐向量和矩阵。如果没有该参数,则图形将如下所示

4.png

但请注意,ylim 的问题已经消失——因此该参数可能是导致此问题的主要原因。我想如果我能找到另一种不使用“sharex = True”来对齐两幅图像的方法,就可以解决这个问题。


阅读 16

收藏
2025-01-09

共1个答案

小能豆

子图的使用sharex=True会过度约束系统。因此,Matplotlib 将释放绘图限制,以便能够显示具有给定规格的绘图。

解决方案是使用sharex=False(默认)。然后高度比需要与图像的尺寸相匹配,即

fig, axes = plt.subplots(2,1,figsize=(4,4),sharey=False,sharex=False,
                    gridspec_kw = {'height_ratios':[mat_data.shape[0],vec_data.shape[0]]})

完整示例:

import numpy as np
import matplotlib.pyplot as plt

vec_data = np.array([[ 0.,  1.,  1.,  1.,  0.,  1.,  1.,  0.,  0.,  0.]])
mat_data = np.random.choice([0,.1,.5,1], size=(10,10))

fig, axes = plt.subplots(2,1,figsize=(4,4),sharey=False,sharex=False,
                    gridspec_kw = {'height_ratios':[mat_data.shape[0],vec_data.shape[0]]})
axes[0].matshow(mat_data)
axes[1].matshow(vec_data)
axes[1].tick_params(direction='out', length=6, width=0)
axes[1].set_yticklabels([''])
axes[1].set_xlabel('vector')

plt.show()

1.png

2025-01-09