Обеспечьте согласованную ширину столбцов при использовании тепловой карты с subplot2grid

#python #matplotlib #seaborn #heatmap #subplot

Вопрос:

Я пытаюсь отформатировать свои подзаголовки, но по какой-то причине я не могу понять, почему позиция не остается плоской для всех из них. Прямо сейчас они выглядят так:

введите описание изображения здесь

Как вы можете видеть, у меня есть две проблемы: 1. я не знаю, как исключить текстовые метки (например, «даты») и 2. Мне нужно отформатировать подзаголовки так, чтобы они имели одну и ту же ось, чтобы они оставались выровненными. Мой код до сих пор:

 fig = plt.figure(figsize=(25, 15))
ax1 = plt.subplot2grid((23,20), (0,0), colspan=19, rowspan=17)
ax2 = plt.subplot2grid((23,20), (19,0), colspan=19, rowspan=1)

sns.set(font_scale=0.95)

sns.heatmap(pivot, ax= ax1, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=True, cmap="Blues")
sns.heatmap((pd.DataFrame(pivot.sum(axis=0))).transpose(), ax=ax2, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=True, cmap="Blues", xticklabels=False, yticklabels=False)

plt.show()
 

Мой фрейм данных таков:

 dates   2020Q1  2020Q2  2020Q3  2020Q4  2021Q1  2021Q2  2021Q3
inicio                                                        
2020Q1    56.0    45.0    15.0     7.0     4.0     4.0     3.0
2020Q2     NaN   418.0   277.0    86.0    46.0    33.0    28.0
2020Q3     NaN     NaN   619.0   398.0   167.0   122.0    93.0
2020Q4     NaN     NaN     NaN  1163.0   916.0   521.0   319.0
2021Q1     NaN     NaN     NaN     NaN   976.0   680.0   363.0
2021Q2     NaN     NaN     NaN     NaN     NaN   811.0   559.0
2021Q3     NaN     NaN     NaN     NaN     NaN     NaN  1879.0
 

Комментарии:

1. Вам нужны метки x и y (например 2020Q1 )?

2. Я просто не могу удалить этикетки с рисунка. Я имею в виду ярлыки «даты» и «иничио». Другая часть состоит в том, чтобы выровнять оба участка…

3. Самый простой способ сопоставить ширину подзаголовков-это установить square=False . метки можно удалить с ax1.set_ylabel('') помощью , а метки-щекотки можно удалить с помощью ax1.set_xticklabels([])

4. Может быть, также поверните ярлыки с помощью ax1.set_yticklabels(pivot.columns, rotation=0)

5. Именно то, что мне было нужно. Спасибо, чувак! Вы хотите опубликовать его, чтобы я мог отметить как разрешенный?

Ответ №1:

  • Изменение square=True на square=False in seaborn.heatmap приведет к тому, что все столбцы будут иметь одинаковую ширину.
  • Метки можно удалить, установив их в виде пустой строки: ax1.set(xlabel='', ylabel='')
  • Протестировано в python 3.8.11 , pandas 1.3.3 , matplotlib 3.4.3 , seaborn 0.11.2
 import panda as pd

# test dataframe
data = {'dates': ['2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q3', '2020Q3', '2020Q3', '2020Q3', '2020Q3', '2020Q4', '2020Q4', '2020Q4', '2020Q4', '2021Q1', '2021Q1', '2021Q1', '2021Q2', '2021Q2', '2021Q3'],
        'inicio': ['2020Q1', '2020Q2', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q2', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2021Q1', '2021Q2', '2021Q3', '2021Q2', '2021Q3', '2021Q3'],
        'values': [56.0, 45.0, 15.0, 7.0, 4.0, 4.0, 3.0, 418.0, 277.0, 86.0, 46.0, 33.0, 28.0, 619.0, 398.0, 167.0, 122.0, 93.0, 1163.0, 916.0, 521.0, 319.0, 976.0, 680.0, 363.0, 811.0, 559.0, 1879.0]}
df = pd.DataFrame(data)

# pivot the dataframe
pv = df.pivot(index='dates', columns='inicio', values='values')

# create figure and subplots
fig = plt.figure(figsize=(20, 10))
ax1 = plt.subplot2grid((20, 10), (0, 0), colspan=19, rowspan=17)
ax2 = plt.subplot2grid((20, 10), (19, 0), colspan=19, rowspan=1)

sns.set(font_scale=0.95)

# create heatmap with square=False instead of True
sns.heatmap(pv, ax=ax1, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=False, cmap="Blues")
sns.heatmap(pv.sum().to_frame().T, ax=ax2, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=False, cmap="Blues", xticklabels=False, yticklabels=False)

ax1.set_yticklabels(pv.columns, rotation=0)  # rotate the yticklabels
ax1.set(xlabel='', ylabel='')  # remove x amp; y labels
ax2.set(xlabel='', ylabel='')  # remove x amp; y labels

plt.show()
 

введите описание изображения здесь