#matplotlib #plot #legend #scatter-plot
#matplotlib #график #легенда #точечный график
Вопрос:
Вот код
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
import numpy as np
fig, subs = plt.subplots(4,3) #setting the shape of the figure in one line as opposed to creating 12 variables
iris = load_iris() ##code as per the example
data = np.array(iris['data'])
targets = np.array(iris['target'])
cd = {0:'r',1:'b',2:"g"}
cols = np.array([cd[target] for target in targets])
# Row 1
subs[0][0].scatter(data[:,0], data[:,1], c=cols)
subs[0][1].scatter(data[:,0], data[:,2], c=cols)
subs[0][2].scatter(data[:,0], data[:,3], c=cols)
# Row 2
subs[1][0].scatter(data[:,1], data[:,0], c=cols)
subs[1][1].scatter(data[:,1], data[:,2], c=cols)
subs[1][2].scatter(data[:,1], data[:,3], c=cols)
# Row 3
subs[2][0].scatter(data[:,2], data[:,0], c=cols)
subs[2][1].scatter(data[:,2], data[:,1], c=cols)
subs[2][2].scatter(data[:,2], data[:,3], c=cols)
#Row 4
subs[3][0].scatter(data[:,3], data[:,0], c=cols)
subs[3][1].scatter(data[:,3], data[:,1], c=cols)
subs[3][2].scatter(data[:,3], data[:,2], c=cols)
plt.show()
Мне было бы интересно добавить легенду, указывающую, что красные точки представляют 'setosa'
, зеленые точки 'versicolor'
и синие точки 'virginica'
. Эти легенды будут внизу и в центре рисунка выше. Как я могу это сделать?
Я думаю, что мне нужно поиграть fig.legend
, но я совсем не уверен, как это сделать.
Ответ №1:
Вы можете перебирать цели в одном из подзаголовков и выводить легенду за пределы этого графика. Вот что я получил с вашим кодом:
Вот код:
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
import numpy as np
fig, subs = plt.subplots(4,3, constrained_layout=True) #setting the shape of the figure in one line as opposed to creating 12 variables
iris = load_iris() ##code as per the example
data = np.array(iris['data'])
target_names = iris['target_names']
targets = np.array(iris['target'])
cd = {0:'r',1:'b',2:"g"}
cols = np.array([cd[target] for target in targets])
# Row 1
subs[0][0].scatter(data[:,0], data[:,1], c=cols)
subs[0][1].scatter(data[:,0], data[:,2], c=cols)
subs[0][2].scatter(data[:,0], data[:,3], c=cols)
# Row 2
subs[1][0].scatter(data[:,1], data[:,0], c=cols)
subs[1][1].scatter(data[:,1], data[:,2], c=cols)
subs[1][2].scatter(data[:,1], data[:,3], c=cols)
# Row 3
subs[2][0].scatter(data[:,2], data[:,0], c=cols)
subs[2][1].scatter(data[:,2], data[:,1], c=cols)
subs[2][2].scatter(data[:,2], data[:,3], c=cols)
# Row 4
subs[3][0].scatter(data[:,3], data[:,0], c=cols)
# loop for central subplot at last row
for t, name in zip(np.unique(targets), target_names):
subs[3][1].scatter(data[targets==t,3], data[targets==t,1], c=cd[t], label=name)
subs[3][1].legend(bbox_to_anchor=(2, -.2), ncol=len(target_names)) # you can play with bbox_to_anchor for legend position
subs[3][2].scatter(data[:,3], data[:,2], c=cols)
plt.savefig('legend')
РЕДАКТИРОВАТЬ: я также нашел этот пост в документации matplotlib, где вы можете напрямую извлекать элементы рассеяния из точечной диаграммы (без использования for
цикла). Я пробовал использовать набор данных IRIS, но не смог заставить его работать.
Комментарии:
1. Вы создали вертикальную легенду. Можете ли вы разместить эту легенду горизонтально под фигурой?
2. Да, настройка
bbox_to_anchor=(2, -.2)
иncol=len(target_names)
сделает свое дело (если вы попросите указать 3 столбца, метки будут расширяться по горизонтали). Я отредактирую ответ.3. Обратите внимание, что оси необходимо уменьшить, чтобы предотвратить слишком сильное перекрытие легенды с другими осями. Вероятно, это можно решить, добавив a
constrained_layout=True
вplt.subplots
функцию.4. Это идеально! : D
Ответ №2:
Добавьте label='versicor'
etc только к одному из ваших подзаголовков:
# Row 1
subs[0][0].scatter(data[:,0], data[:,1], c=cols, label='virginica')
subs[0][1].scatter(data[:,0], data[:,2], c=cols, label='setosa')
subs[0][2].scatter(data[:,0], data[:,3], c=cols, label='versicolor')
Затем вы можете вызвать fig.legend(loc='lower center', ncol=3)
before plt.show()
.
Я установил ncol=3
, чтобы сделать его коротким и широким.
Смотрите мой пример ниже:
import numpy as np
import matplotlib.pyplot as plt
a = range(1,11)
b = np.random.randn(10).cumsum()
c = np.random.randn(10).cumsum()
d = np.random.randn(10).cumsum()
fig, (ax1, ax2) = plt.subplots(2, figsize=(9,5))
ax1.plot(a,b, label = 'one')
ax1.plot(a,c, label = 'two')
ax1.plot(a,d, label = 'two')
ax2.plot(a,c)
fig.legend(loc='lower center', ncol=3)
plt.show()
Комментарии:
1. Это не ответило на мой вопрос. Можете ли вы заставить его работать с примером, который я вам предоставил?
2. Я показал вам, что нужно добавить в ваш код, и привел пример, не уверен, что это не ответило на вопрос. Вы добавили
label={species}
, как я показал? (У меня нет вашего набора данных).3. Да, у вас есть набор данных. Вам просто нужно использовать строку
from sklearn.datasets import load_iris
для ее импорта. На данный момент этот ответ не подходит. Я не могу с этим работать4. У меня не установлен sklearn, и я не хочу его устанавливать.