HaskellのStateモナドを定義から理解する

本記事の目的

Stateモナドの挙動を定義から導出したい. それが本記事の目標です.

Stateモナドについて調べると, いくつかの記事がヒットします. そこではStateモナドの使い方として以下のようなサンプルコードが引き合いに出されることが多いでしょう.

import Control.Monad.State

pickHead :: State [Int] Int
pickHead = do
    x:xs <- get    -- 状態(整数のリスト)を取得
    put xs         -- 先頭以外の部分を新たな状態としてセット
    return x       -- 先頭を返す

sumHeads :: State [Int] Int
sumHeads = do
    x <- pickHead  -- 先頭を取得(このとき状態が更新される)
    y <- pickHead  -- 先頭(初期状態では2番目の要素)を取得
    return $ x + y -- 足して返す

main = do
    print $ runState sumHeads [1, 1, 1, 1, 1]
    print $ runState sumHeads [1, 2, 3, 4, 5]

実行結果は以下の通りです.

(2,[1,1,1])
(3,[3,4,5])

コメントが書かれているため, どのように処理が進んでいくのかを追うことはできます. しかし釈然としません. なぜStateモナドはまるで状態を保持しているように振る舞うのかがわからないのです.

本記事ではそれを解決したいと思います.

解説

Stateモナドの定義

コードを定義から繙いていくわけですから, 定義の把握が必要ですね.

newtype State s a = State { runState :: (s -> (a,s)) } 
 
instance Monad (State s) where 
    return a        = State $ \s -> (a,s)
    (State x) >>= f = State $ \s -> let (v,s') = x s in runState (f v) s' 
class MonadState m s | m -> s where 
    get :: m s
    put :: s -> m ()

instance MonadState (State s) s where 
    get   = State $ \s -> (s,s) 
    put s = State $ \_ -> ((),s) 

ひとつの補題

サンプルコードを理解する上で助けになる補題をひとつ挙げておきます. 補題といっても定義をそのまんま適用したものにすぎません.

まずは補題に必要な登場人物の紹介です.

m :: State s a
f :: a -> State s b
x :: s

上記の設定のもとにおいて, 以下の2式は等価です.

runState (m >>= f) x
runState (f a) x'

ただし以下の言い換えをしています.

(a, x') = runState m x

記号aが重複していますが, そのほうが対応がわかりやすいだろうとの判断です.

以上が補題の主張です. >>=の定義を追っていけば確かめられるでしょう. この補題により, >>=をひとつずつほぐしていけるようになりました.

コードを読み解いていく

最初に挙げたサンプルコードを解読していきます.

まずは部分的に

まずは以下の式が何をどう返すのかをみていきましょう.

runState pickHead [1,2,3]

pickHeadの定義をそのまま適用します.

runState (
    x:xs <- get
    put xs
    return x
    ) [1,2,3]

ただし上の式は文法的に正しくありません. <-等の記号はdoブロックの中でしか使えないからですね. 見辛くはなりますが, 正しく書き換えましょう.

runState (
    get >>= \(x:xs) -> 
    put xs >> 
    return x
    ) [1,2,3]

こんな風にラムダ式の中でもパターンマッチが使えるんですね. さてこれで, 上で述べた補題が使える形になりました. では一つ目の>>=を解消しましょう.

getの定義から

runState get [1,2,3] == ([1,2,3],[1,2,3])

ですから, もとの式は次の式に等価です.

runState (f [1,2,3]) [1,2,3]

ただしf = \(x:xs) -> put xs >> return xという置き換えをしています.

さて,

f [1,2,3] == put [2,3] >> return 1

ですから, 結局もとの式は次の式に等価です.

runState (put [2,3] >> return 1) [1,2,3]

>>>>=の特殊なバージョンですから, この式にも補題を適用することができますね.

putの定義から,

runState (put [2,3]) [1,2,3] == ((), [2,3])

ですから, 結局もとの式は次の式に等価です.

runState (return 1) [2,3]

これを計算することにより, 以下が得られました.

runState pickHead [1,2,3] == (1, [2,3])

そして全体へ

今得られた等式を利用して, コード全体を読み解いていきましょう.

評価したいのは以下の式です.

runState sumHeads [1,2,3]

これが(3, [3])に等しいことを定義から導出していきます.

上の式に, sumHeadsの定義を文法が正しくなるように適用します.

runState (
    pickHead >>= \x ->
    pickHead >>= \y ->
    return $ x + y
    ) [1,2,3]

先と同様に補題を適用して>>=を解消していきます.

上で得られた結果から

runState pickHead [1,2,3] == (1, [2,3])

ですから, もとの式は次の式に等価です.

runState (
    pickHead >>= \y ->
    return $ 1 + y
    ) [2,3]

同様にして以下のように変形できます.

runState (
    return $ 1 + 2
    ) [3]

これを計算することにより, 以下の式が得られました.

runState sumHeads [1,2,3] == (3,[3])

まとめ

なんとか処理を追いかけることができました. 変化させた状態を次々にリレーしていく様子が明らかになったのではないでしょうか. このようにしてStateモナド状態を保持しているように振る舞うことができるわけですね.