Как отобразить легенды в точечной диаграмме, чтобы различать классы

#python #matplotlib #scikit-learn #scatter-plot #iris-dataset

Вопрос:

Я работаю над набором данных радужной оболочки глаза от sklearn. Как вы, возможно, знаете, набор данных iris состоит из 3 классов [«сетоза», «разноцветный», «виргиния»]. Я сделал точечную диаграмму для этого набора данных. Подробности заключаются в следующем

 from sklearn.datasets import load_iris
iris=load_iris()
Y_train=iris.target
X_train=iris.data
class_labels=iris.target_names
plt.scatter(X_train[:,0], X_train[:,1], c=Y_train)
plt.xlabel('attr1')
plt.ylabel('attr2')
plt.show()
 

Участок Сакктера:

У меня есть точечная диаграмма, на которой вы можете видеть желтые, зеленые и фиолетовые точки. Я хочу знать, какая цветовая точка принадлежит к какому классу («setosa», «versicolor», «virginica»). Я хотел бы отобразить легенды, чтобы знать, какой цвет представляет какой класс

Ответ №1:

В этом случае вы можете создать пользовательскую легенду, прокручивая метки и используя ту же цветовую карту и норму, что и для точечной диаграммы. По умолчанию 'viridis' используется цветовая карта и норма, которая сопоставляет минимальное значение цвета с нулем, а максимальное-с единицей.

 import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()
Y_train = iris.target
X_train = iris.data
class_labels = iris.target_names
cmap = plt.get_cmap('viridis')
norm = plt.Normalize(Y_train.min(), Y_train.max())
plt.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, cmap='viridis', norm=norm)
handles = [plt.Line2D([0, 0], [0, 0], color=cmap(norm(i)), marker='o', linestyle='', label=label)
           for i, label in enumerate(class_labels)]
plt.legend(handles=handles, title='Species')
plt.show()
 

точечная диаграмма с легендой

Вы также можете использовать seaborn, хотя в настоящее время установка меток легенд не является простой.

 import seaborn as sns

sns.set()
ax = sns.scatterplot(x=X_train[:, 0], y=X_train[:, 1], hue=Y_train, palette='viridis')
ax.legend(ax.legend_.legendHandles, class_labels, title='Species')