Tous les conseils sur la façon de résoudre efficacement la fonction suivante dans Haskell, pour les grands nombres (n > 108)
f(n) = max(n, f(n/2) + f(n/3) + f(n/4))
J'ai vu des exemples de mémorisation dans Haskell pour résoudre les nombres de fibonacci, qui impliquaient de calculer (paresseusement) tous les nombres de fibonacci jusqu'au n requis. Mais dans ce cas, pour un n donné, il suffit de calculer très peu de résultats intermédiaires.
Merci
Nous pouvons le faire très efficacement en créant une structure que nous pouvons indexer en temps sub-linéaire.
Mais d'abord,
{-# LANGUAGE BangPatterns #-}
import Data.Function (fix)
Définissons f
, mais faisons en sorte qu'il utilise la "récursion ouverte" plutôt que de s'appeler directement.
f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
mf (n `div` 3) +
mf (n `div` 4)
Vous pouvez obtenir un f
non remémoré en utilisant fix f
Cela vous permettra de tester que f
fait ce que vous voulez dire pour les petites valeurs de f
en appelant, par exemple: fix f 123 = 144
Nous pourrions mémoriser cela en définissant:
f_list :: [Int]
f_list = map (f faster_f) [0..]
faster_f :: Int -> Int
faster_f n = f_list !! n
Cela fonctionne passablement bien et remplace ce qui allait prendre O (n ^ 3) temps avec quelque chose qui mémorise les résultats intermédiaires.
Mais il faut toujours du temps linéaire pour simplement indexer pour trouver la réponse mémorisée pour mf
. Cela signifie que des résultats comme:
*Main Data.List> faster_f 123801
248604
sont tolérables, mais le résultat n'évolue pas beaucoup mieux que cela. On peut faire mieux!
Définissons d'abord un arbre infini:
data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)
Et puis nous allons définir un moyen de l'indexer, afin que nous puissions trouver un nœud avec l'index n
dans O (log n) temps à la place:
index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
(q,0) -> index l q
(q,1) -> index r q
... et nous pouvons trouver un arbre plein de nombres naturels pour être pratique afin que nous n'ayons pas à jouer avec ces indices:
nats :: Tree Int
nats = go 0 1
where
go !n !s = Tree (go l s') n (go r s')
where
l = n + s
r = l + s
s' = s * 2
Puisque nous pouvons indexer, vous pouvez simplement convertir un arbre en liste:
toList :: Tree a -> [a]
toList as = map (index as) [0..]
Vous pouvez vérifier le travail jusqu'à présent en vérifiant que toList nats
vous donne [0..]
Maintenant,
f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats
fastest_f :: Int -> Int
fastest_f = index f_tree
fonctionne comme avec la liste ci-dessus, mais au lieu de prendre du temps linéaire pour trouver chaque nœud, vous pouvez le poursuivre en temps logarithmique.
Le résultat est considérablement plus rapide:
*Main> fastest_f 12380192300
67652175206
*Main> fastest_f 12793129379123
120695231674999
En fait, c'est tellement plus rapide que vous pouvez parcourir et remplacer Int
par Integer
ci-dessus et obtenir des réponses ridiculement grandes presque instantanément
*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489
*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
la réponse d'Edward est un joyau si merveilleux que je l'ai dupliqué et fourni des implémentations de combinateurs memoList
et memoTree
qui mémorisent une fonction sous une forme récursive ouverte.
{-# LANGUAGE BangPatterns #-}
import Data.Function (fix)
f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
mf (div n 3) +
mf (div n 4)
-- Memoizing using a list
-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
where memoList_f = (memo !!) . fromInteger
memo = map (f memoList_f) [0..]
faster_f :: Integer -> Integer
faster_f = memoList f
-- Memoizing using a tree
data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)
index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
(q,0) -> index l q
(q,1) -> index r q
nats :: Tree Integer
nats = go 0 1
where
go !n !s = Tree (go l s') n (go r s')
where
l = n + s
r = l + s
s' = s * 2
toList :: Tree a -> [a]
toList as = map (index as) [0..]
-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
where memoTree_f = index memo
memo = fmap (f memoTree_f) nats
fastest_f :: Integer -> Integer
fastest_f = memoTree f
Pas le moyen le plus efficace, mais mémorise:
f = 0 : [ g n | n <- [1..] ]
where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)
lors de la demande f !! 144
, il est vérifié que f !! 143
existe, mais sa valeur exacte n'est pas calculée. Il est toujours défini comme un résultat inconnu d'un calcul. Les seules valeurs exactes calculées sont celles nécessaires.
Donc, initialement, en ce qui concerne le montant calculé, le programme ne sait rien.
f = ....
Lorsque nous faisons la demande f !! 12
, il commence à faire une correspondance de modèle:
f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...
Maintenant, il commence à calculer
f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3
Cela fait récursivement une autre demande sur f, donc nous calculons
f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0
Maintenant, nous pouvons remonter un peu
f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1
Ce qui signifie que le programme sait maintenant:
f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...
Continuer à ruisseler:
f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3
Ce qui signifie que le programme sait maintenant:
f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...
Maintenant, nous continuons notre calcul de f!!6
:
f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6
Ce qui signifie que le programme sait maintenant:
f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...
Maintenant, nous continuons notre calcul de f!!12
:
f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13
Ce qui signifie que le programme sait maintenant:
f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...
Le calcul se fait donc assez paresseusement. Le programme sait qu'une certaine valeur pour f !! 8
existe, qu'il est égal à g 8
, mais il n'a aucune idée de ce que g 8
est.
Comme indiqué dans la réponse d'Edward Kmett, pour accélérer les choses, vous devez mettre en cache des calculs coûteux et pouvoir y accéder rapidement.
Pour garder la fonction non monadique, la solution de construction d'un arbre paresseux infini, avec une manière appropriée de l'indexer (comme indiqué dans les articles précédents) remplit cet objectif. Si vous abandonnez la nature non monadique de la fonction, vous pouvez utiliser les conteneurs associatifs standard disponibles dans Haskell en combinaison avec des monades "d'état" (comme State ou ST).
Alors que le principal inconvénient est que vous obtenez une fonction non monadique, vous n'avez plus besoin d'indexer la structure vous-même, et vous pouvez simplement utiliser des implémentations standard de conteneurs associatifs.
Pour ce faire, vous devez d'abord réécrire votre fonction pour accepter tout type de monade:
fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _ 0 = return 0
fm recf n = do
recs <- mapM recf $ div n <$> [2, 3, 4]
return $ max n (sum recs)
Pour vos tests, vous pouvez toujours définir une fonction qui ne fait aucune mémorisation en utilisant Data.Function.fix, bien qu'elle soit un peu plus verbeuse:
noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm
Vous pouvez ensuite utiliser State monad en combinaison avec Data.Map pour accélérer les choses:
import qualified Data.Map.Strict as MS
withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
where
recF i = do
v <- MS.lookup i <$> get
case v of
Just v' -> return v'
Nothing -> do
v' <- fm recF i
modify $ MS.insert i v'
return v'
Avec des modifications mineures, vous pouvez adapter le code pour qu'il fonctionne avec Data.HashMap à la place:
import qualified Data.HashMap.Strict as HMS
withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
where
recF i = do
v <- HMS.lookup i <$> get
case v of
Just v' -> return v'
Nothing -> do
v' <- fm recF i
modify $ HMS.insert i v'
return v'
Au lieu de structures de données persistantes, vous pouvez également essayer des structures de données mutables (comme le Data.HashTable) en combinaison avec la monade ST:
import qualified Data.HashTable.ST.Linear as MHM
withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
do ht <- MHM.new
recF ht n
where
recF ht i = do
k <- MHM.lookup ht i
case k of
Just k' -> return k'
Nothing -> do
k' <- fm (recF ht) i
MHM.insert ht i k'
return k'
Comparé à l'implémentation sans aucune mémorisation, n'importe laquelle de ces implémentations vous permet, pour des entrées énormes, d'obtenir des résultats en micro-secondes au lieu d'avoir à attendre plusieurs secondes.
En utilisant Criterion comme référence, j'ai pu observer que l'implémentation avec Data.HashMap fonctionnait en fait légèrement mieux (environ 20%) que celle de Data.Map et Data.HashTable pour lesquelles les timings étaient très similaires.
J'ai trouvé les résultats du benchmark un peu surprenants. Mon sentiment initial était que le HashTable surpasserait l'implémentation de HashMap car il est mutable. Il peut y avoir un défaut de performance caché dans cette dernière implémentation.
Ceci est un addendum à l'excellente réponse d'Edward Kmett.
Lorsque j'ai essayé son code, les définitions de nats
et index
semblaient assez mystérieuses, alors j'écris une version alternative que j'ai trouvé plus facile à comprendre.
Je définis index
et nats
en termes de index'
et nats'
.
index' t n
est défini sur la plage [1..]
. (Rappeler que index t
est défini sur la plage [0..]
.) Il fonctionne dans l'arborescence en traitant n
comme une chaîne de bits et en parcourant les bits à l'envers. Si le bit est 1
, il prend la branche de droite. Si le bit est 0
, il prend la branche de gauche. Il s'arrête lorsqu'il atteint le dernier bit (qui doit être un 1
).
index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
(n', 0) -> index' l n'
(n', 1) -> index' r n'
Tout comme nats
est défini pour index
de sorte que index nats n == n
est toujours vrai, nats'
est défini pour index'
.
nats' = Tree l 1 r
where
l = fmap (\n -> n*2) nats'
r = fmap (\n -> n*2 + 1) nats'
nats' = Tree l 1 r
Maintenant, nats
et index
sont simplement nats'
et index'
mais avec des valeurs décalées de 1:
index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
Quelques années plus tard, j'ai regardé cela et j'ai réalisé qu'il y avait un moyen simple de le mémoriser en temps linéaire en utilisant zipWith
et une fonction d'aide:
dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs
dilate
a la propriété pratique que dilate n xs !! i == xs !! div i n
.
Donc, en supposant qu'on nous donne f (0), cela simplifie le calcul pour
fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
where (.+.) = zipWith (+)
infixl 6 .+.
(#/) = flip dilate
infixl 7 #/
Ressemblant beaucoup à notre description originale du problème et donnant une solution linéaire (sum $ take n fs
prendra O (n)).
Encore un addendum à la réponse d'Edward Kmett: un exemple autonome:
data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)
memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
where nats = go 0 1
go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
where s' = 2*s
index (NatTrie l v r) i
| i < 0 = f (index_to_arg i)
| i == 0 = v
| otherwise = case (i-1) `divMod` 2 of
(i',0) -> index l i'
(i',1) -> index r i'
memoNat = memo1 id id
Utilisez-le comme suit pour mémoriser une fonction avec un seul argument entier (par exemple fibonacci):
fib = memoNat f
where f 0 = 0
f 1 = 1
f n = fib (n-1) + fib (n-2)
Seules les valeurs des arguments non négatifs seront mises en cache.
Pour mettre également en cache les valeurs des arguments négatifs, utilisez memoInt
, défini comme suit:
memoInt = memo1 arg_to_index index_to_arg
where arg_to_index n
| n < 0 = -2*n
| otherwise = 2*n + 1
index_to_arg i = case i `divMod` 2 of
(n,0) -> -n
(n,1) -> n
Pour mettre en cache les valeurs des fonctions avec deux arguments entiers, utilisez memoIntInt
, défini comme suit:
memoIntInt f = memoInt (\n -> memoInt (f n))
Une solution sans indexation, et non basée sur celle d'Edward KMETT.
Je factorise les sous-arbres communs vers un parent commun (f(n/4)
est partagé entre f(n/2)
et f(n/4)
, et f(n/6)
est partagé entre f(2)
et f(3)
). En les enregistrant comme une seule variable dans le parent, le calcul du sous-arbre est effectué une fois.
data Tree a =
Node {datum :: a, child2 :: Tree a, child3 :: Tree a}
f :: Int -> Int
f n = datum root
where root = f' n Nothing Nothing
-- Pass in the arg
-- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
f' 0 _ _ = leaf
where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
where
d = if n < 12 then n
else max n (d2 + d3 + d4)
[n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
[d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
c2 = case m2 of -- Check for a passed-in subtree before recursing.
Just c2' -> c2'
Nothing -> f' n2 Nothing (Just c6)
c3 = case m3 of
Just c3' -> c3'
Nothing -> f' n3 (Just c6) Nothing
c4 = child2 c2
c6 = f' n6 Nothing Nothing
main =
print (f 123801)
-- Should print 248604.
Le code ne s'étend pas facilement à une fonction de mémorisation générale (au moins, je ne saurais pas comment le faire), et vous devez vraiment réfléchir à la façon dont les sous-problèmes se chevauchent, mais la stratégie devrait fonctionner pour les paramètres généraux multiples non entiers. (Je l'ai pensé pour deux paramètres de chaîne.)
Le mémo est supprimé après chaque calcul. (Encore une fois, je pensais à deux paramètres de chaîne.)
Je ne sais pas si c'est plus efficace que les autres réponses. Chaque recherche n'est techniquement qu'une ou deux étapes ("Regardez votre enfant ou l'enfant de votre enfant"), mais il peut y avoir beaucoup de mémoire supplémentaire.
Edit: Cette solution n'est pas encore correcte. Le partage est incomplet.
Edit: Il devrait partager correctement les sous-enfants maintenant, mais j'ai réalisé que ce problème a beaucoup de partage non trivial: n/2/2/2
et n/3/3
pourrait être le même. Le problème ne convient pas à ma stratégie.