J'essaie de créer un diagramme de dispersion simple dans pyplot en utilisant un objet DataFrame Pandas), mais je veux un moyen efficace de tracer deux variables, mais les symboles sont dictés par une troisième colonne (clé). essayé de différentes manières en utilisant df.groupby, mais sans succès. Un exemple de script df est présenté ci-dessous. Il permet de colorer les marqueurs en fonction de 'key1', mais je voudrais voir une légende avec les catégories 'key1'. Suis-je fermé? Merci.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()
Vous pouvez utiliser scatter
pour cela, mais cela nécessite d’avoir des valeurs numériques pour votre key1
, et vous n'aurez pas de légende, comme vous l'avez remarqué.
Il vaut mieux utiliser plot
pour des catégories discrètes comme celle-ci. Par exemple:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)
# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))
groups = df.groupby('label')
# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()
plt.show()
Si vous souhaitez que les choses ressemblent au style par défaut pandas
, mettez simplement à jour rcParams
avec la feuille de style pandas) et utilisez son générateur de couleur. (I Je modifie aussi légèrement la légende):
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)
# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))
groups = df.groupby('label')
# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')
fig, ax = plt.subplots()
ax.set_color_cycle(colors)
ax.margins(0.05)
for name, group in groups:
ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')
plt.show()
C’est simple à faire avec Seaborn (pip install seaborn
) En tant que doublure
sns.pairplot(x_vars=["one"], y_vars=["two"], data=df, hue="key1", size=5)
:
import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)
df = pd.DataFrame(
np.random.normal(10, 1, 30).reshape(10, 3),
index=pd.date_range('2010-01-01', freq='M', periods=10),
columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)
sns.pairplot(x_vars=["one"], y_vars=["two"], data=df, hue="key1", size=5)
Voici le dataframe pour référence:
Comme vous disposez de trois colonnes variables dans vos données, vous pouvez tracer toutes les dimensions par paires avec:
sns.pairplot(vars=["one","two","three"], data=df, hue="key1", size=5)
https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/ est une autre option.
Avec plt.scatter
, Je ne peux penser qu’à un seul: utiliser un artiste proxy:
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ccm=x.get_cmap()
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)
Et le résultat est:
Vous pouvez utiliser df.plot.scatter et passer un tableau à l'argument c = définissant la couleur de chaque point:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()
Vous pouvez également essayer Altair ou ggpot qui sont centrés sur les visualisations déclaratives.
import numpy as np
import pandas as pd
np.random.seed(1974)
# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))
from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')
from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()
C'est plutôt hacky, mais vous pouvez utiliser one1
comme un Float64Index
tout faire en une fois:
df.set_index('one').sort_index().groupby('key1')['two'].plot(style='--o', legend=True)
Notez qu'à partir de 0.20.3, il est nécessaire de trier l'index , et la légende est n peu confus .
À partir de matplotlib 3.1, vous pouvez utiliser .legend_elements()
. Un exemple est présenté dans Création automatique de légende . L'avantage est qu'un seul appel scatter peut être utilisé.
Dans ce cas:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3),
index = pd.date_range('2010-01-01', freq = 'M', periods = 10),
columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()
Au cas où les clés ne seraient pas directement attribuées sous forme de chiffres, cela ressemblerait à
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3),
index = pd.date_range('2010-01-01', freq = 'M', periods = 10),
columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")
labels, index = np.unique(df["key1"], return_inverse=True)
fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()