Дублирование кода с подзаголовками динамического размера: как это решить?

#python #matplotlib #subplot

#python #matplotlib #подзаголовок

Вопрос:

Я пытаюсь сохранить изображение, состоящее из нескольких подзаголовков: проблема в том, что изначально я не знаю, сколько подзаголовков мне нужно (это зависит от num_plots ), поэтому я хочу получить функцию, которая генерирует графики по-разному в разных случаях:

  1. Если num_plots равно 3 или меньше, создайте строку из ровно num_plots подзаголовков
  2. Если их больше 3, сгенерируйте необходимое количество строк (каждая размером 3)

Проблема в том, что мне не удается сделать это, не избегая дублирования кода. Вот код, который я написал:

 def foo(num_plots, ...):
    # ... obtain data to plot ('all_results')
    total_cols = 3 if num_plots > 2 else num_plots
    total_rows = math.ceil(num_plots / total_cols)
    fig, axs = plt.subplots(nrows=total_rows, ncols=total_cols, figsize=(7 * total_cols, 7 * total_rows),
                            constrained_layout=True)
    for i in range(num_plots):
        row = i // total_cols
        pos = i % total_cols

        # psi_name and psi_value depend on the parameters of the function
    
        if num_plots == 1:
            sns.kdeplot(all_results[:, i], color='magenta', ax=axs)
            axs.hist(all_results[:, i], bins=nbins)
            axs.axvline(x=psi_value, linestyle='--', color='red')
            axs.set_title(f"Estimation test for {psi_name} = {psi_value}")
        elif total_rows == 1:
            sns.kdeplot(all_results[:, i], color='magenta', ax=axs[pos])
            axs[pos].hist(all_results[:, i], bins=nbins)
            axs[pos].axvline(x=psi_value, linestyle='--', color='red')
            axs[pos].set_title(f"Estimation test for {psi_name} = {psi_value}")
        else:
            sns.kdeplot(all_results[:, i], color='magenta', ax=axs[row, pos])
            axs[row, pos].hist(all_results[:, i], bins=nbins)
            axs[row, pos].axvline(x=psi_value, linestyle='--', color='red')
            axs[row, pos].set_title(f"Estimation test for {psi_name} = {psi_value}")
    plt.savefig(token_hex(8))
  

Как вы можете видеть, проблема заключается в том, что в зависимости от количества необходимых графиков я должен использовать axs или axs[pos] или axs[row, pos] вместо этого. Обратите внимание, что простое использование axs[row, pos] выдает ошибку. Как я могу решить эту проблему дублирования кода?

Вот несколько примеров графиков, полученных в результате этой функции:

Если num_plots равно 1:

Если num_plots равно 1

Если num_plots равно 3:

Если num_plots равно 3

Если num_plots равно 4:

Если num_plots равно 4

Ответ №1:

Работает ли метод add_subplots() последовательного рисования для ваших целей? Однако он настроен на рисование нескольких графиков на одной диаграмме, поэтому один из них увеличивается на размер диаграммы. Вы получите предупреждение, но его можно проигнорировать.

 import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(12,9))

x = np.linspace(-3, 3, 20)
cnt = 5
cols = 3
rows = round(cnt / cols,0)

for i in range(cnt):
    ax = fig.add_subplot(rows,cols,i 1)
    ax.plot(x, x**i)
  

введите описание изображения здесь

Комментарии:

1. Спасибо, это то, что я искал! Еще один небольшой вопрос: результирующие графики кажутся мне слишком далекими, если я попытаюсь использовать constrained_layout=True их, они становятся слишком близкими (перекрывая метки осей). Можете ли вы предложить мне способ получить более тесное построение без преувеличения? Заранее спасибо!

2. constrained_layout=False Что произойдет, если я разблокирую его? Попробуйте это, если это улучшится. fig.subplots_adjust(wspace=0.2, hspace=0.2)

Ответ №2:

Если вы хотите, чтобы ограниченный макет работал, вы можете создать спецификацию сетки и добавлять к ней по частям:

 import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(12,9), constrained_layout=True)


x = np.linspace(-3, 3, 20)
cnt = 5
cols = 3
rows = round(cnt / cols,0)

gs = fig.add_gridspec(int(rows), int(cols))

for ind in range(cnt):
    i = int(np.floor(ind / cols))
    j = ind % cols

    ax = fig.add_subplot(gs[i, j])
    ax.plot(x, x**ind)

plt.show()
  

Обратите внимание, что будущий matplotlib 3.4, add_subplot как указано выше, также будет работать с constrained_layout , но он еще не выпущен: https://github.com/matplotlib/matplotlib/pull/17347