IOK, donc je veux pouvoir choisir des valeurs dans une distribution normale qui ne se situent que entre 0 et 1. Dans certains cas, je veux pouvoir simplement renvoyer une distribution complètement aléatoire, et dans d'autres cas, je veux retourner des valeurs qui prendre la forme d'un gaussien.
En ce moment j'utilise la fonction suivante:
def blockedgauss(mu,sigma):
while True:
numb = random.gauss(mu,sigma)
if (numb > 0 and numb < 1):
break
return numb
Il prend une valeur dans une distribution normale, puis la rejette si elle se situe en dehors de la plage de 0 à 1, mais je pense qu'il doit y avoir une meilleure façon de procéder.
Il semble que vous souhaitiez une distribution normale tronquée . En utilisant scipy, vous pouvez utiliser scipy.stats.truncnorm
pour générer des variables aléatoires à partir d'une telle distribution:
import matplotlib.pyplot as plt
import scipy.stats as stats
lower, upper = 3.5, 6
mu, sigma = 5, 0.7
X = stats.truncnorm(
(lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
N = stats.norm(loc=mu, scale=sigma)
fig, ax = plt.subplots(2, sharex=True)
ax[0].hist(X.rvs(10000), normed=True)
ax[1].hist(N.rvs(10000), normed=True)
plt.show()
La figure du haut montre la distribution normale tronquée, la figure du bas montre la distribution normale avec la même moyenne mu
et l'écart-type sigma
.
Je suis tombé sur ce post en cherchant un moyen de retourner une série de valeurs échantillonnées à partir d'une distribution normale tronquée entre zéro et 1 (c'est-à-dire les probabilités). Pour aider quelqu'un d'autre qui a le même problème, je voulais juste noter que scipy.stats.truncnorm a la capacité intégrée ".rvs".
Donc, si vous vouliez 100 000 échantillons avec une moyenne de 0,5 et un écart type de 0,1:
import scipy.stats
lower = 0
upper = 1
mu = 0.5
sigma = 0.1
N = 100000
samples = scipy.stats.truncnorm.rvs(
(lower-mu)/sigma,(upper-mu)/sigma,loc=mu,scale=sigma,size=N)
Cela donne un comportement très similaire à numpy.random.normal, mais dans les limites souhaitées. L'utilisation de la fonction intégrée sera beaucoup plus rapide que le bouclage pour recueillir des échantillons, en particulier pour les grandes valeurs de N.
Au cas où quelqu'un voudrait une solution utilisant uniquement numpy, voici une implémentation simple utilisant une fonction normale et un clip (l'approche de MacGyver):
import numpy as np
def truncated_normal(mean, stddev, minval, maxval):
return np.clip(np.random.normal(mean, stddev), minval, maxval)
EDIT: ne l'utilisez PAS !! c'est ainsi que vous ne devriez pas le faire !! par exemple,a = truncated_normal(np.zeros(10000), 1, -10, 10)
peut sembler fonctionner, maisb = truncated_normal(np.zeros(10000), 100, -1, 1)
ne dessinera certainement pas une normale tronquée , comme vous pouvez le voir dans l'histogramme suivant:
Désolé pour cela, j'espère que personne ne s'est blessé! Je suppose que la leçon est, n'essayez pas d'imiter MacGyver au codage ...
Andres
J'ai créé un exemple de script comme suit. Il montre comment utiliser les API pour implémenter les fonctions que nous voulions, telles que générer des échantillons avec des paramètres connus, comment calculer CDF, PDF, etc. J'attache également une image pour le montrer.
#load libraries
import scipy.stats as stats
#lower, upper, mu, and sigma are four parameters
lower, upper = 0.5, 1
mu, sigma = 0.6, 0.1
#instantiate an object X using the above four parameters,
X = stats.truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
#generate 1000 sample data
samples = X.rvs(1000)
#compute the PDF of the sample data
pdf_probs = stats.truncnorm.pdf(samples, (lower-mu)/sigma, (upper-mu)/sigma, mu, sigma)
#compute the CDF of the sample data
cdf_probs = stas.truncnorm.cdf(samples, (lower-mu)/sigma, (upper-mu)/sigma, mu, sigma)
#make a histogram for the samples
plt.hist(samples, bins= 50,normed=True,alpha=0.3,label='histogram');
#plot the PDF curves
plt.plot(samples[samples.argsort()],pdf_probs[samples.argsort()],linewidth=2.3,label='PDF curve')
#plot CDF curve
plt.plot(samples[samples.argsort()],cdf_probs[samples.argsort()],linewidth=2.3,label='CDF curve')
#legend
plt.legend(loc='best')