web-dev-qa-db-fra.com

En utilisant Seaborn, comment puis-je tracer une ligne de mon choix sur mon nuage de points?

Je veux pouvoir tracer une ligne de mes spécifications à travers une parcelle générée dans seaborn. L'intrigue que j'ai choisie était JointGrid, mais n'importe quel nuage de points fera l'affaire. Je soupçonne que Seaborn ne facilite peut-être pas cela?

Voici le code traçant les données (cadres de données de l'ensemble de données Iris de longueur et largeur de pétale):

import seaborn as sns
iris = sns.load_dataset("iris")    
grid = sns.JointGrid(iris.petal_length, iris.petal_width, space=0, size=6, ratio=50)
    grid.plot_joint(plt.scatter, color="g")

enter image description here

Si vous prenez ce graphique à partir du jeu de données iris, comment puis-je tracer une ligne de mon choix à travers? Par exemple, une ligne de pente négative peut séparer les grappes et une pente positive peut les traverser.

10
user391339

Il semble que vous ayez importé matplotlib.pyplot as plt pour obtenir plt.scatter dans votre code. Vous pouvez simplement utiliser les fonctions matplotlib pour tracer la ligne:

import seaborn as sns
import matplotlib.pyplot as plt

iris = sns.load_dataset("iris")    
grid = sns.JointGrid(iris.petal_length, iris.petal_width, space=0, size=6, ratio=50)
grid.plot_joint(plt.scatter, color="g")
plt.plot([0, 4], [1.5, 0], linewidth=2)

enter image description here

16
chthonicdaemon

En créant un JointGrid dans seaborn, vous avez créé trois axes, le principal ax_joint, et les deux axes marginaux.

Pour tracer autre chose sur les axes de joint, nous pouvons accéder à la grille de joint en utilisant grid.ax_joint, puis créez des objets de tracé comme vous le feriez avec n'importe quel autre objet matplotlibAxes.

Par exemple:

import seaborn as sns
import matplotlib.pyplot as plt

iris = sns.load_dataset("iris")    
grid = sns.JointGrid(iris.petal_length, iris.petal_width, space=0, size=6, ratio=50)

# Create your scatter plot
grid.plot_joint(plt.scatter, color="g")

# Create your line plot.
grid.ax_joint.plot([0,4], [1.5,0], 'b-', linewidth = 2)

En passant, vous pouvez également accéder aux axes marginaux d'un JointGrid de la même manière:

grid.ax_marg_x.plot(...)
grid.ax_marg_y.plot(...)
5
tmdavison