web-dev-qa-db-fra.com

Mergesort Python

Je ne pouvais pas trouver de code de fusion Python 3.3 fonctionnel, alors j'en ai créé un moi-même. Y a-t-il un moyen d'accélérer les choses? Il trie 20000 numéros en environ 0.3-0.5 secondes

def msort(x):
    result = []
    if len(x) < 2:
        return x
    mid = int(len(x)/2)
    y = msort(x[:mid])
    z = msort(x[mid:])
    while (len(y) > 0) or (len(z) > 0):
        if len(y) > 0 and len(z) > 0:
            if y[0] > z[0]:
                result.append(z[0])
                z.pop(0)
            else:
                result.append(y[0])
                y.pop(0)
        Elif len(z) > 0:
            for i in z:
                result.append(i)
                z.pop(0)
        else:
            for i in y:
                result.append(i)
                y.pop(0)
    return result
33
Hans

Vous pouvez initialiser la liste complète des résultats dans l’appel de niveau supérieur de mergesort:

result = [0]*len(x)   # replace 0 with a suitable default element if necessary. 
                      # or just copy x (result = x[:])

Ensuite, pour les appels récursifs, vous pouvez utiliser une fonction d'assistance à laquelle vous passez non pas des sous-listes, mais des index dans x. Et les appels de niveau inférieur lisent leurs valeurs à partir de x et écrivent directement dans result.

De cette façon, vous éviterez tout cela poping et appending qui devrait améliorer les performances.

13
firefrorefiddle

La première amélioration consisterait à simplifier les trois cas de la boucle principale: Plutôt que d’itérer alors que certaines séquences contiennent des éléments, itérer tant que les deux ont des éléments. En sortant de la boucle, l'un d'eux sera vide, on ne sait pas lequel, mais on s'en fiche: on les ajoute à la fin du résultat.

def msort2(x):
    if len(x) < 2:
        return x
    result = []          # moved!
    mid = int(len(x) / 2)
    y = msort2(x[:mid])
    z = msort2(x[mid:])
    while (len(y) > 0) and (len(z) > 0):
        if y[0] > z[0]:
            result.append(z[0])
            z.pop(0)
        else:
            result.append(y[0])
            y.pop(0)
    result += y
    result += z
    return result

La deuxième optimisation consiste à éviter popping les éléments. Plutôt, avoir deux indices:

def msort3(x):
    if len(x) < 2:
        return x
    result = []
    mid = int(len(x) / 2)
    y = msort3(x[:mid])
    z = msort3(x[mid:])
    i = 0
    j = 0
    while i < len(y) and j < len(z):
        if y[i] > z[j]:
            result.append(z[j])
            j += 1
        else:
            result.append(y[i])
            i += 1
    result += y[i:]
    result += z[j:]
    return result

Une dernière amélioration consiste à utiliser un algorithme non récursif pour trier les séquences courtes. Dans ce cas, j'utilise la fonction sorted intégrée et je l'utilise lorsque la taille de l'entrée est inférieure à 20:

def msort4(x):
    if len(x) < 20:
        return sorted(x)
    result = []
    mid = int(len(x) / 2)
    y = msort4(x[:mid])
    z = msort4(x[mid:])
    i = 0
    j = 0
    while i < len(y) and j < len(z):
        if y[i] > z[j]:
            result.append(z[j])
            j += 1
        else:
            result.append(y[i])
            i += 1
    result += y[i:]
    result += z[j:]
    return result

Mes mesures pour trier une liste aléatoire de 100 000 entiers sont de 2,46 secondes pour la version d'origine, 2,33 pour msort2, 0,60 pour msort3 et 0,40 pour msort4. Pour référence, le tri de la liste avec sorted prend 0,03 seconde.

59
anumi

Code du cours MIT. (avec coopérateur générique)

import operator


def merge(left, right, compare):
    result = []
    i, j = 0, 0
    while i < len(left) and j < len(right):
        if compare(left[i], right[j]):
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    while i < len(left):
        result.append(left[i])
        i += 1
    while j < len(right):
        result.append(right[j])
        j += 1
    return result


def mergeSort(L, compare=operator.lt):
    if len(L) < 2:
        return L[:]
    else:
        middle = int(len(L) / 2)
        left = mergeSort(L[:middle], compare)
        right = mergeSort(L[middle:], compare)
        return merge(left, right, compare)
21
David Yachnis
def merge_sort(x):

    if len(x) < 2:return x

    result,mid = [],int(len(x)/2)

    y = merge_sort(x[:mid])
    z = merge_sort(x[mid:])

    while (len(y) > 0) and (len(z) > 0):
            if y[0] > z[0]:result.append(z.pop(0))   
            else:result.append(y.pop(0))

    result.extend(y+z)
    return result
18
sarath joseph

Prends ma mise en œuvre

def merge_sort(sequence):
    """
    Sequence of numbers is taken as input, and is split into two halves, following which they are recursively sorted.
    """
    if len(sequence) < 2:
        return sequence

    mid = len(sequence) // 2     # note: 7//2 = 3, whereas 7/2 = 3.5

    left_sequence = merge_sort(sequence[:mid])
    right_sequence = merge_sort(sequence[mid:])

    return merge(left_sequence, right_sequence)

def merge(left, right):
    """
    Traverse both sorted sub-arrays (left and right), and populate the result array
    """
    result = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] < right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    result += left[i:]
    result += right[j:]

    return result


print merge_sort([5, 2, 6, 8, 5, 8, 1])
8
Boubakr

Comme déjà dit, l.pop(0) est une opération O(len(l)) et doit être évité, la fonction msort ci-dessus est O (n ** 2). Si l’efficacité importe, l’indexation est préférable, mais elle a aussi un coût. Le for x in l est plus rapide mais difficile à implémenter pour mergesort: iter peut être utilisé ici. Enfin, la vérification de i < len(l) est effectuée deux fois car testée de nouveau lors de l'accès à l'élément: le mécanisme d'exception (try except) est meilleur et donne une dernière amélioration de 30%.

def msort(l):
    if len(l)>1:
        t=len(l)//2
        it1=iter(msort(l[:t]));x1=next(it1)
        it2=iter(msort(l[t:]));x2=next(it2)
        l=[]
        try:
            while True:
                if x1<=x2: l.append(x1);x1=next(it1)
                else     : l.append(x2);x2=next(it2)
        except:
            if x1<=x2: l.append(x2);l.extend(it2)
            else:      l.append(x1);l.extend(it1)
    return l
7
B. M.

Des boucles comme celle-ci peuvent probablement être accélérées:

for i in z:
    result.append(i)
    z.pop(0)

Au lieu de cela, faites simplement ceci:

result.extend(z)

Notez qu'il n'est pas nécessaire de nettoyer le contenu de z car vous ne l'utiliserez pas de toute façon.

6
Tamás

Une version plus longue qui compte les inversions et adhère à l’interface sorted. Il est trivial de modifier cela pour en faire une méthode d'objet qui trie sur place.

import operator

class MergeSorted:

    def __init__(self):
        self.inversions = 0

    def __call__(self, l, key=None, reverse=False):

        self.inversions = 0

        if key is None:
            self.key = lambda x: x
        else:
            self.key = key

        if reverse:
            self.compare = operator.gt
        else:
            self.compare = operator.lt

        dest = list(l)
        working = [0] * len(l)
        self.inversions = self._merge_sort(dest, working, 0, len(dest))
        return dest

    def _merge_sort(self, dest, working, low, high):
        if low < high - 1:
            mid = (low + high) // 2
            x = self._merge_sort(dest, working, low, mid)
            y = self._merge_sort(dest, working, mid, high)
            z = self._merge(dest, working, low, mid, high)
            return (x + y + z)
        else:
            return 0

    def _merge(self, dest, working, low, mid, high):
        i = 0
        j = 0
        inversions = 0

        while (low + i < mid) and (mid + j < high):
            if self.compare(self.key(dest[low + i]), self.key(dest[mid + j])):
                working[low + i + j] = dest[low + i]
                i += 1
            else:
                working[low + i + j] = dest[mid + j]
                inversions += (mid - (low + i))
                j += 1

        while low + i < mid:
            working[low + i + j] = dest[low + i]
            i += 1

        while mid + j < high:
            working[low + i + j] = dest[mid + j]
            j += 1

        for k in range(low, high):
            dest[k] = working[k]

        return inversions


msorted = MergeSorted()

Les usages

>>> l = [5, 2, 3, 1, 4]
>>> s = msorted(l)
>>> s
[1, 2, 3, 4, 5]
>>> msorted.inversions
6

>>> l = ['e', 'b', 'c', 'a', 'd']
>>> d = {'a': 10,
...      'b': 4,
...      'c': 2,
...      'd': 5,
...      'e': 9}
>>> key = lambda x: d[x]
>>> s = msorted(l, key=key)
>>> s
['c', 'b', 'd', 'e', 'a']
>>> msorted.inversions
5

>>> l = [5, 2, 3, 1, 4]
>>> s = msorted(l, reverse=True)
>>> s
[5, 4, 3, 2, 1]
>>> msorted.inversions
4

>>> l = ['e', 'b', 'c', 'a', 'd']
>>> d = {'a': 10,
...      'b': 4,
...      'c': 2,
...      'd': 5,
...      'e': 9}
>>> key = lambda x: d[x]
>>> s = msorted(l, key=key, reverse=True)
>>> s
['a', 'e', 'd', 'b', 'c']
>>> msorted.inversions
5
5
MikeRand

Un peu tard dans la soirée, mais je me suis dit que je mettrais mon chapeau dans le ring, ma solution semblant aller plus vite que celle des OP (sur ma machine, en tout cas):

# [Python 3]
def merge_sort(arr):
    if len(arr) < 2:
        return arr
    half = len(arr) // 2
    left = merge_sort(arr[:half])
    right = merge_sort(arr[half:])
    out = []
    li = ri = 0  # index of next element from left, right halves
    while True:
        if li >= len(left):  # left half is exhausted
            out.extend(right[ri:])
            break
        if ri >= len(right): # right half is exhausted
            out.extend(left[li:])
            break
        if left[li] < right[ri]:
            out.append(left[li])
            li += 1
        else:
            out.append(right[ri])
            ri += 1
    return out

Cela n'a pas de pop()s lente, et une fois qu'un des demi-tableaux est épuisé, il étend immédiatement l'autre sur le tableau de sortie plutôt que de démarrer une nouvelle boucle. 

Je sais que cela dépend de la machine, mais pour 100 000 éléments aléatoires (au-dessus de merge_sort() par rapport à sorted() de Python):

merge sort: 1.03605 seconds
Python sort: 0.045 seconds
Ratio merge / Python sort: 23.0229
2
Charles

Voici le CLRS Implementation:

def merge(arr, p, q, r):
    n1 = q - p + 1
    n2 = r - q
    right, left = [], []
    for i in range(n1):
        left.append(arr[p + i])
    for j in range(n2):
        right.append(arr[q + j + 1])
    left.append(float('inf'))
    right.append(float('inf'))
    i = j = 0
    for k in range(p, r + 1):
        if left[i] <= right[j]:
            arr[k] = left[i]
            i += 1
        else:
            arr[k] = right[j]
            j += 1


def merge_sort(arr, p, r):
    if p < r:
        q = (p + r) // 2
        merge_sort(arr, p, q)
        merge_sort(arr, q + 1, r)
        merge(arr, p, q, r)


if __== '__main__':
    test = [5, 2, 4, 7, 1, 3, 2, 6]
    merge_sort(test, 0, len(test) - 1)
    print test

Résultat:

[1, 2, 2, 3, 4, 5, 6, 7]
2
Little Roys
def mergeSort(alist):
    print("Splitting ",alist)
    if len(alist)>1:
        mid = len(alist)//2
        lefthalf = alist[:mid]
        righthalf = alist[mid:]

        mergeSort(lefthalf)
        mergeSort(righthalf)

        i=0
        j=0
        k=0
        while i < len(lefthalf) and j < len(righthalf):
            if lefthalf[i] < righthalf[j]:
                alist[k]=lefthalf[i]
                i=i+1
            else:
                alist[k]=righthalf[j]
                j=j+1
            k=k+1

        while i < len(lefthalf):
            alist[k]=lefthalf[i]
            i=i+1
            k=k+1

        while j < len(righthalf):
            alist[k]=righthalf[j]
            j=j+1
            k=k+1
    print("Merging ",alist)

alist = [54,26,93,17,77,31,44,55,20]
mergeSort(alist)
print(alist)
2
Hoslack

voici une autre solution

class MergeSort(object):
    def _merge(self,left, right):
        nl = len(left)
        nr = len(right)
        result = [0]*(nl+nr)
        i=0
        j=0
        for k in range(len(result)):
            if nl>i and nr>j:
                if left[i] <= right[j]:
                    result[k]=left[i]
                    i+=1
                else:
                    result[k]=right[j]
                    j+=1
            Elif nl==i:
                result[k] = right[j]
                j+=1
            else: #nr>j:
                result[k] = left[i]
                i+=1
        return result

    def sort(self,arr):
        n = len(arr)
        if n<=1:
            return arr 
        left = self.sort(arr[:n/2])
        right = self.sort(arr[n/2:] )
        return self._merge(left, right)
def main():
    import random
    a= range(100000)
    random.shuffle(a)
    mr_clss = MergeSort()
    result = mr_clss.sort(a)
    #print result

if __== '__main__':
    main()

et voici le temps d'exécution pour la liste avec 100000 éléments:

real    0m1.073s
user    0m1.053s
sys         0m0.017s
2
Moj
def merge(l1, l2, out=[]):
    if l1==[]: return out+l2
    if l2==[]: return out+l1
    if l1[0]<l2[0]: return merge(l1[1:], l2, out+l1[0:1])
    return merge(l1, l2[1:], out+l2[0:1])
def merge_sort(l): return (lambda h: l if h<1 else merge(merge_sort(l[:h]), merge_sort(l[h:])))(len(l)/2)
print(merge_sort([1,4,6,3,2,5,78,4,2,1,4,6,8]))
2
zack

Beaucoup ont répondu correctement à cette question, ceci est juste une autre solution (bien que ma solution soit très similaire à Max Montana) mais j'ai quelques différences pour la mise en œuvre:

passons en revue l’idée générale ici avant d’arriver au code:

  • Divisez la liste en deux moitiés à peu près égales. 
  • Triez la moitié gauche.
  • Triez la moitié droite. 
  • Fusionner les deux moitiés triées dans une liste triée.

voici le code (testé avec python 3.7):

def merge(left,right):
    result=[] 
    i,j=0,0
    while i<len(left) and j<len(right):
        if left[i] < right[j]:
            result.append(left[i])
            i+=1
        else:
            result.append(right[j])
            j+=1
    result.extend(left[i:]) # since we want to add each element and not the object list
    result.extend(right[j:])
    return result

def merge_sort(data):
    if len(data)==1:
        return data
    middle=len(data)//2
    left_data=merge_sort(data[:middle])
    right_data=merge_sort(data[middle:])
    return merge(left_data,right_data)


data=[100,5,200,3,100,4,8,9] 
print(merge_sort(data))
2
grepit
def merge(x):
    if len(x) == 1:
        return x
    else:
        mid = int(len(x) / 2)
        l = merge(x[:mid])
        r = merge(x[mid:])
    i = j = 0
    result = []
    while i < len(l) and j < len(r):
        if l[i] < r[j]:
            result.append(l[i])
            i += 1
        else:
            result.append(r[j])
            j += 1
    result += l[i:]
    result += r[j:]
    return result
1
Max Montana

Essayez cette version récursive

def mergeList(l1,l2):
    l3=[]
    Tlen=len(l1)+len(l2)
    inf= float("inf")
    for i in range(Tlen):
        print   "l1= ",l1[0]," l2= ",l2[0]
        if l1[0]<=l2[0]:
            l3.append(l1[0])
            del l1[0]
            l1.append(inf)
        else:
            l3.append(l2[0])
            del l2[0]
            l2.append(inf)
    return l3

def main():
    l1=[2,10,7,6,8]
    print mergeSort(breaklist(l1))

def breaklist(rawlist):
    newlist=[]
    for atom in rawlist:
        print atom
        list_atom=[atom]
        newlist.append(list_atom)
    return newlist

def mergeSort(inputList):
    listlen=len(inputList)
    if listlen ==1:
        return inputList
    else:
        newlist=[]
        if listlen % 2==0:
            for i in range(listlen/2):
                newlist.append(mergeList(inputList[2*i],inputList[2*i+1]))
        else:
            for i in range((listlen+1)/2):
                if 2*i+1<listlen:
                    newlist.append(mergeList(inputList[2*i],inputList[2*i+1]))
                else:
                    newlist.append(inputList[2*i])
        return  mergeSort(newlist)

if __== '__main__':
    main()
1
mojians

Si vous changez votre code comme ça, ça va marcher.

def merge_sort(arr):
    if len(arr) < 2:
        return arr[:]
    middle_of_arr = len(arr) / 2
    left = arr[0:middle_of_arr]
    right = arr[middle_of_arr:]
    left_side = merge_sort(left)
    right_side = merge_sort(right)
    return merge(left_side, right_side)

def merge(left_side, right_side):
    result = []
    while len(left_side) > 0 or len(right_side) > 0:
        if len(left_side) > 0 and len(right_side) > 0:
            if left_side[0] <= right_side[0]:
                result.append(left_side.pop(0))
            else:
                result.append(right_side.pop(0))
        Elif len(left_side) > 0:
            result.append(left_side.pop(0))
        Elif len(right_side) > 0:
            result.append(right_side.pop(0))
    return result

arr = [6, 5, 4, 3, 2, 1]
# print merge_sort(arr)
# [1, 2, 3, 4, 5, 6]
1
    def merge(a,low,mid,high):
        l=a[low:mid+1]
        r=a[mid+1:high+1]
        #print(l,r)
        k=0;i=0;j=0;
        c=[0 for i in range(low,high+1)]
        while(i<len(l) and j<len(r)):
            if(l[i]<=r[j]):

                c[k]=(l[i])
                k+=1

                i+=1
            else:
                c[k]=(r[j])
                j+=1
                k+=1
        while(i<len(l)):
            c[k]=(l[i])
            k+=1
            i+=1

        while(j<len(r)):
            c[k]=(r[j])
            k+=1
            j+=1
        #print(c)  
        a[low:high+1]=c  

    def mergesort(a,low,high):
        if(high>low):
            mid=(low+high)//2


            mergesort(a,low,mid)
            mergesort(a,mid+1,high)
            merge(a,low,mid,high)

    a=[12,8,3,2,9,0]
    mergesort(a,0,len(a)-1)
    print(a)
1
kamran shaik

Le code suivant apparaît à la fin (assez efficace) et est trié sur place en dépit du retour.

def mergesort(lis):
    if len(lis) > 1:
        left, right = map(lambda l: list(reversed(mergesort(l))), (lis[::2], lis[1::2]))
        lis.clear()
        while left and right:
            lis.append(left.pop() if left[-1] < right[-1] else right.pop())
        lis.extend(left[::-1])
        lis.extend(right[::-1])
    return lis
0
pba