GHCで動くcall/ccの実装

IOモナド上で動作するcall/ccを実装できることが分かったので書いておく。ただし実用に耐えるものではない。

これを使うとたとえばこういうコードを書くことができる。

import Control.Applicative
import Control.Monad
import Data.IORef
import System.IO.Unsafe

import Continuation (callCC, withContinuationsDo)
test :: IO ()
test = do
  r <- callCC $ \cc -> return $ Right $ cc . Left
  case r of
    Left val -> do
      putStrLn $ "left: " ++ show val
      return ()
    Right cc -> do
      putStrLn "right"
      cc "foo"
      putStrLn "this should not be printed"
-- 実行結果:
--  right
--  left: "foo"

非決定計算もできる

ambNext :: IORef (IO a)
ambNext = unsafePerformIO $ newIORef $ fail "Dead end"
{-# NOINLINE ambNext #-}

amb :: [a] -> IO a
amb vs = foldr amb2' (join $ readIORef ambNext) vs
  where
    amb2' a b = join $ amb2 (return a) b

amb2 :: a -> a -> IO a
amb2 a b = do
  next <- readIORef ambNext
  callCC $ \cc -> do
    writeIORef ambNext $ do
      writeIORef ambNext next
      cc b
    return a

8クイーン問題を解く例。

eightQueens :: IO ()
eightQueens = loop 8 allPoints >>= print
  where
    allPoints = (,) <$> [0..7] <*> [0..7]

    loop :: Int -> [(Int, Int)] -> IO [(Int, Int)]
    loop 0 _ = return []
    loop k avail = do
      pos <- amb avail
      (pos:) <$> loop (k-1) (narrow pos avail)

    narrow pos avail = filter (not . reachable pos) avail
    reachable (x0, y0) (x1, y1)
      = x0 == x1
      || y0 == y1
      || x0 + y0 == x1 + y1
      || x0 - y0 == x1 - y1

mainは次のようになる。

main :: IO ()
main = withContinuationsDo $ do
  test
  eightQueens

種明かし

実装は以下。

{-# OPTIONS_GHC -O #-}
  -- 最適化をオフにするとちゃんと動かない

{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE UnliftedFFITypes #-}

module Continuation
  ( withContinuationsDo
  , getCC
  , callCC
  ) where

import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.IORef
import Data.Typeable
import Foreign.Marshal.Utils
import Foreign.Storable
import System.IO.Unsafe

import GHC.Exts
import GHC.IO(IO(..), unIO)

-- | (Right 継続)を返す。返ってきた継続に値vを渡して呼ぶと、
-- getCCがもう一度返り、(Left v)を返す。
getCC :: IO (Either a (a -> IO b))
getCC = do
  ref <- newIORef $ error "getCC: bug"
  t <- myThreadId
  -- 自分自身に非同期例外を投げることで、基底ループに制御を移す。
  throwTo t $ GetCC $ writeIORef ref
  -- 再開。基底ループによりrefに適切な値が書き込まれている。
  readIORef ref

-- | Scheme風call/cc
callCC :: ((a -> IO b) -> IO a) -> IO a
callCC f = either return f =<< getCC

-- | メインルーチンから基底ループへのメッセージ。現在の継続を要求する。
data GetCCRequest = forall a b. GetCC (Either a (a -> IO b) -> IO ())
  deriving (Typeable)

instance Show GetCCRequest where
  show _ = "<getcc>"

instance Exception GetCCRequest

-- | 基底ループが内部的に使う例外。
data ResumeRequest = Resume () -- ^ 評価継続すべきサンク
  deriving (Typeable, Show)

instance Exception ResumeRequest

-- | @main = withContinuationsDo $ ...@ のように使う
withContinuationsDo :: IO () -> IO ()
withContinuationsDo x = do
  -- サンクの評価が非同期例外によって中断された場合、そのサンクの値を
  -- 再度要求することで評価を継続することができるのを利用する。
  let thunk = unsafePerformIO x
  loop thunk
  where
    -- 基底ループ本体。
    loop thunk = do
      debug "evaluating"
      -- サンクが例外を投げるまで評価する。
      r <- try $ try $ evaluate thunk
      -- GHC 7.6.3では、thunkの評価が例外によって書き換えられた場合、
      -- thunkはAP_STACKオブジェクトを指すINDオブジェクトになっている。
      case r of
        -- 現在の継続を要求してきた。
        Left (GetCC ccSink) -> do
          debug "caught gcc"
          -- まずRightを返す。
          ccSink $ Right $ \val -> do
            debug "entering cc"
            -- 継続が呼ばれた。こんどはLeftを返す。
            ccSink $ Left val
            debug "value added"
            -- ここはメインルーチンの内部なので、
            -- 基底ループに制御を移し、thunkの評価を再開する。
            throwIO $ Resume thunk
            -- ここにはこない。
          -- 戻り値を設定した上でthunkの評価を再開する。ただし、thunkを
          -- 上書きするのを防ぐため、コピーをとってそちらを評価する。
          -- こうすることで、一つの継続を複数回呼び出すことができる。
          loop =<< dupAP_STACK thunk
        -- thunk'の評価の再開を要求してきた。
        Right (Left (Resume thunk')) ->
          -- 上と同じ。thunkには二度と戻らないので捨ててよい。
          loop =<< dupAP_STACK thunk'
        -- 実行終了。
        Right (Right ()) -> return ()
{-# NOINLINE withContinuationsDo #-}

-- | @dupAP_STACK x@はxが指すサンクを複製してそれを返す。
-- xはAP_STACKか、それを指すINDオブジェクトでなければならない。
dupAP_STACK :: a -> IO a
dupAP_STACK thing = IO $ \s00 -> let
  -- 1ワードのバイト数
  !(I# wsz) = sizeOf (undefined::Int)
  -- コピー先のメモリを確保する。これには、まず
  -- MutableByteArray#を作り、それをByteArray#に変換してから
  -- アドレスを得る。手抜きのためサイズは固定
  !bufsize = 100# *# wsz
  !(# s10, mbarr #) = newByteArray# bufsize s00
  !(# s20, barr #) = unsafeFreezeByteArray# mbarr s10
  !dest = byteArrayContents# barr
  -- ここ以降、このletブロックの終了までにGCが走ると
  -- ポインタが失効してひどいことになるので、
  -- メモリ確保を一切しないようにする。
  !s21 = ndebug "critical block"# nullAddr# s20
  -- 与えられたオブジェクトのアドレス。
  !(# s30, obj #) = thunkToAddr thing s21
  !s31 = ndebug "obj"# obj s30
  -- AP_STACKサンクのアドレスを得る。
  !(# s40, thunk #) = derefIndirections obj s31
  -- AP_STACKオブジェクトのsizeフィールドを読んで、
  -- payloadのワード数を得る。
  !(# s50, size #) = readIntOffAddr# thunk 2# s40
  -- サンクのバイト数。これは以下からなる
  --   header (2ワード)
  --   size (1ワード)
  --   fun (1ワード)
  --   payload (sizeワード)
  !bytes = wsz *# (size +# 4#)
  !s51 = ndebug "bytes"# (int2Addr# bytes) s50
  -- サンクをコピーする。ここで被せている構築子は最適化で
  -- 取り除かれることを期待している
  !(# s60, _ #) =
    if bufsize <# bytes
      then error "buffer too short"
      else unIO
    (copyBytes (Ptr dest :: Ptr Int) (Ptr thunk) (I# bytes)) s51
  -- コピーをサンクとして返す
  !(# s70, r #) = addrToThunk dest s60
  -- ここ以降、GCが発生しても問題ない
  !s71 = ndebug "end critical block"# nullAddr# s70
  in (# s71, r #)

-- | 間接参照を再帰的に辿り、AP_STACKオブジェクトを返す。
-- この関数はメモリを確保しない。
derefIndirections
  :: Addr# -> State# RealWorld -> (# State# RealWorld, Addr# #)
derefIndirections obj s00 = let
  -- objのinfoポインタを読む
  !(# s10, info #) = readAddrOffAddr# obj 0# s00
  !s11 = ndebug "info"# info s10
  -- infoテーブルのtypeフィールドを読む。infoポインタはinfoテーブルの
  -- 終端を指すので、ここの添字は負になる。
  !(# s20, typ #) = readIntOffAddr# info -1# s11
  !s21 = ndebug "type"# (int2Addr# typ) s20
  -- includes/rts/storage/ClosureTypes.h:
  -- #define AP_STACK 27
  in if typ ==# 27#
    then (# s21, obj #)
    else let
      -- objはAP_STACKオブジェクトではなかったので、おそらく間接参照。
      -- indirecteeフィールドを読む。
      !(# s30, next #) = readAddrOffAddr# obj 1# s21
      !s31 = ndebug "indirectee"# next s30
      in derefIndirections next s31

-- | デバッグ文字列を表示する
debug :: String -> IO ()
debug msg = when debugEnabled $ putStrLn msg

-- | メモリを確保せずにデバッグ文字列を表示する
ndebug :: Addr# -> Addr# -> State# RealWorld -> State# RealWorld
ndebug loc addr s = if debugEnabled
  then let (# s1, () #) = unIO (c_printf_pp "%s: %p\n"# loc addr) s
    in s1
  else s

foreign import ccall unsafe "printf"
  c_printf_pp :: Addr# -> Addr# -> Addr# -> IO ()

debugEnabled :: Bool
debugEnabled = False

-- | サンクやINDのアドレスを得る。GCの前後でアドレスが変わり得る
-- という点でこれは外界依存の操作なので、その旨を型で表した上で
-- NOINLINEにしておく。
-- INDやサンクのタグビットは0なのでビット演算の必要なし。
thunkToAddr :: a -> State# RealWorld -> (# State# RealWorld, Addr# #)
thunkToAddr v s = (# s, unsafeCoerce# v #)
{-# NOINLINE thunkToAddr #-}

-- | サンクオブジェクトのアドレスをHaskellの値に変換する。
-- 'thunkToAddr' と同じ理由でNOINLINE。
addrToThunk :: Addr# -> State# RealWorld -> (# State# RealWorld, a #)
addrToThunk addr s = (# s, unsafeCoerce# addr #)
{-# NOINLINE addrToThunk #-}

見てのとおりGHCの実装詳細にかなり依存しているので、Linux amd64 ghc-7.6.3の組み合わせ以外では動かないかもしれない。加えて、例外やスレッドと一緒に使うと変なことになると思われるので実用性はない。

RULESによるコンパイル時プログラミング

これはHaskell Advent Calendar 2013の(3+π)日目の記事です。

(3 + pi)や(quot 7 8)のような単純な定数式は、ghc -Oが行なう定数畳み込みによってコンパイル時に計算される。uncurry (*) (3, max 5 2)のようなやや複雑な式も、インライン展開してから定数畳み込みをすることでやはりコンパイル時に整数リテラルにまで簡約される。

これは一見万能だが、再帰的な関数が一つでもあると何もできなくなる。GHC再帰関数をインライン化しないからだ。(sum [1])ですら実行時のループにコンパイルされる*1

どうしてもコンパイル時に計算してほしい関数がある場合はどうしたら良いか。一つの方法はTemplate Haskell(ja)を使うことだが、特別な構文を使わなければいけないこと、-fwarn-unused-binds(ja)をはじめとして警告がいくつか効かなくなることなど副作用がある。この記事は、もう一つの方法、書き換え規則(ja)を使ってこの種のコンパイル時計算を実装する術を紹介する。

例として、二つの整数の最大公約数を計算する関数を考える。Haskellで普通に書くと以下のようになる。

-- | 二つの整数の最大公約数を求める
-- Preludeのgcdと区別するためにgcd'という名前にしておく
gcd' :: Int -> Int -> Int
gcd' x 0 = x
gcd' x y = pgcd (abs x) (abs y)
{-# INLINE gcd' #-}

-- | 非負の整数と正の整数の最大公約数を求める
pgcd :: Int -> Int -> Int
pgcd x y
  | q == 0 = y
  | otherwise = pgcd y q
  where
    q = mod x y

関数gcd'は非再帰的なので、INLINEプラグマ(ja)を付けておけば間違いなくインライン展開される。問題は再帰的に定義されたpgcdだ。これをRULESプラグマで何とかすることを考えよう。すぐに思い付くのは、以下のような規則を使ってコンパイル時に強引に再帰を起こすことだ*2

{-# RULES
"pgcd" forall x y. pgcd x y =
    let q = rem x y
    in if q == 0
        then y
        else pgcd y q
  #-}

これは要するに、pgcdの定義をそのままRULEとして書き直したものだ。残念なことにこれはうまく行かない。試してみると、GHCがパニックを起こす。

ghc: panic! (the 'impossible' happened)
  (GHC version 7.6.3 for x86_64-unknown-linux):
        Simplifier ticks exhausted
    When trying RuleFired Class op ==
    To increase the limit, use -fsimpl-tick-factor=N (default 100)
    If you need to do this, let GHC HQ know, and what factor you needed
    To see detailed counts use -ddump-simpl-stats
    Total ticks: 7320

このエラーはGHCの単純化器(simplifier)が無限ループを起こしたときに発生する*3。問題は、pgcd規則が適合して、左辺から右辺への書き換えが起こった後、すぐに右辺のpgcdの呼び出しがまた規則に適合することだ。これはqが0かどうかにかかわりなく起こるので、再帰が止まらない。

この問題に見覚えがある人も居るはずだ。正格な言語でifを関数として実装した場合に、両方の分岐が評価されてしまう問題に似ているのだ。こちらの問題の良く知られた解決策は、評価されて欲しくない部分をラムダで包むことだが、ここではそれだけでは足りない。GHCの単純化器は、ラムダの中であっても簡約できそうなところは何でも簡約するからだ。それでも基本的な考え方は流用できて、要するにifの分岐部分を、「評価が進まない形」にして渡してやれば良い。次の関数を定義しておく。

data Key = Key

delay :: Key -> a -> a
delay _ x = x
{-# NOINLINE delay #-}

{-# RULES "delay/Key" forall x. delay Key x = x #-}

delayは第二引数をそのまま返す関数だが、その事実はNOINLINEプラグマによって隠蔽されるので単純化器が知ることができない*4。例外は、delayの第一引数が具体的にKeyだと分かっている場合であり、このときは規則"delay/Key"が発動して実質的にdelayがインライン展開される。これを使って、(pgcd y q :: Int)という式を次のように書き換える。

(\k -> delay k pgcd y q) :: Key -> Int

この形は、単純化器によって簡約され得る項を一つも含まないので、規則の右辺に持ってきても安全だ。一般には、一つの式の中に簡約を防ぎたいところが複数箇所あっても良い。たとえば(f 3 + g 4)は(\k -> delay k f 3 + delay k g 4)のように変形できる。

このようにして作った「評価の止まった」式を再び解凍して簡約を再開させる関数を用意する。

force :: (Key -> a) -> a
force x = x Key
{-# INLINE force #-}

これによってx内部のdelayの第一引数がKeyに確定するので、"delay/Key"規則によってx内のdelayが除去される。

式の評価を止めることができるようになったので、これを使ってifを実装する。

if_ :: Bool -> (Key -> a) -> (Key -> a) -> a
if_ c x y = if GHC.Exts.lazy c then force x else force y
{-# NOINLINE if_ #-}

{-# RULES
"if_/True" forall x y. if_ True x y = force x
"if_/False" forall x y. if_ False x y = force y
  #-}

"if_/True"と"if_/False"の二つの規則によって、if_の第一引数がTrueかFalseに確定した場合、それに対応する分岐のみが簡約再開される。それ以外の場合、どちらの選択肢も簡約されない*5

あとは上の規則"pgcd"をif_を使った形に書き直せば良い。

{-# RULES
"pgcd" forall x y. pgcd x y =
    let q = rem x y
    in if_ (q == 0)
        (\_ -> y)
        (\k -> delay k pgcd y q)
  #-}

コード

module M where

import qualified GHC.Exts as GHC

-- | 二つの整数の最大公約数を求める
gcd' :: Int -> Int -> Int
gcd' x 0 = x
gcd' x y = pgcd (abs x) (abs y)
{-# INLINE gcd' #-}

-- | 非負の整数と正の整数の最大公約数を求める
pgcd :: Int -> Int -> Int
pgcd x y
  | q == 0 = y
  | otherwise = pgcd y q
  where
    q = mod x y
{-# NOINLINE pgcd #-}

data Key = Key

delay :: Key -> a -> a
delay _ x = x
{-# NOINLINE delay #-}

{-# RULES "delay/Key" forall x. delay Key x = x #-}

force :: (Key -> a) -> a
force x = x Key
{-# INLINE force #-}

{-# RULES
"pgcd" forall x y. pgcd x y =
    let q = rem x y
    in if_ (q == 0)
        (\_ -> y)
        (\k -> delay k pgcd y q)
  #-}

if_ :: Bool -> (Key -> a) -> (Key -> a) -> a
if_ c x y = if GHC.lazy c then force x else force y
{-# NOINLINE if_ #-}

{-# RULES
"if_/True" forall x y. if_ True x y = force x
"if_/False" forall x y. if_ False x y = force y
  #-}

test :: Int
test = gcd' 120 84

これをコンパイルしてみる。*6

% ghc rulegcd.hs -O -ddump-simpl -fforce-recomp | grep 'M.test' 
M.test :: GHC.Types.Int
M.test = GHC.Types.I# 12

やったね!

*1:将来sumが融合変換の対象になれば、この例はコンパイル時に評価されるようになるだろう

*2:modの代わりにコンパイル時に評価しやすいremを使っている。引数が非負なので結果は同じ

*3:他の状況でも発生するが

*4:実用する場合は、delayの呼び出しが残っても性能に悪影響がないようにNOINLINE[0]とするのが良いかもしれない。この場合後述のRULESにはすべて[~0]を付けること

*5:if_の定義の中でGHC.Exts.lazyを呼んでいるのはやっつけだが、これがないと(if_ (case e of { p -> True; _ -> False }) x y)が(case e of p -> if_ True x y; _ -> if_ False x y)に書き換えられてしまい、結果としてif_の第一引数が確定していないのに"if_/True"と"if_/False"が発動することになる。

*6:このままだと、pgcdの定義の右辺に出てくるpgcdの呼び出しまでもが規則によって書き換えられてしまうので、もうちょっと工夫の余地がある

圏論と数式の練習

変な議論してたら教えてください!

命題0

Aを集合、\mathcal{C}を圏、|-| : \mathcal{C} \to \mathbf{Set}を忘却関手、F : \mathbf{Set} \to \mathcal{C}を自由関手とすると、以下が成り立つ。

Nat(Hom(A, |-|), |-|) \simeq |F(A)|

証明

米田の補題から|F(A)| \simeq Nat(Hom(A, -), |F(-)|)なので、Nat(Hom(A, |-|), |-|) \simeq Nat(Hom(A, -), |F(-)|)を示せば良い。

F|-|の左随伴だから、単位\eta : Id \to |F(-)|と余単位\epsilon : F(|-|) \to Idが存在して以下が成り立つ。

|\epsilon| \circ \eta = id

\epsilon \circ F(\eta) = id

以下で具体的に同型射を構成する。写像fgを次のように定義する。

f : Nat(Hom(A, |-|), |-|) \to Nat(Hom(A, -), |F(-)|)
f(h)_X = h_{F(X)} \circ Hom(A, \eta_X)
g : Nat(Hom(A, -), |F(-)|) \to Nat(Hom(A, |-|), |-|)
g(r)_Y = |\epsilon_Y| \circ r_{|Y|}

fgが互いに逆であることを示せば良い。

(g \circ f)(h)_X
= |\epsilon_X| \circ h_{F(|X|)} \circ Hom(A, \eta_{|X|}) \quad (定義から)
= h_X \circ Hom(A, |\epsilon_X|) \circ Hom(A, \eta_{|X|}) \quad (hの自然性)
= h_X \circ Hom(A, |\epsilon_X| \circ \eta_{|X|}) (Hom(A,-)が関手なので)
= h_X \circ Hom(A, id) (単位と余単位の合成)
= h_X (Hom(A,-)が関手なので)

(f \circ g)(r)_Y
= |\epsilon_{F(Y)}| \circ r_{|F(Y)|} \circ Hom(A, \eta_Y) (定義から)
= |\epsilon_{F(Y)}| \circ |F(\eta_Y)| \circ r_Y (rの自然性)
= |\epsilon_{F(Y)} \circ F(\eta_Y)| \circ r_Y (|-|が関手なので)
= |id| \circ r_Y (単位と余単位の合成)
= r_Y (|-|が関手なので)

これで示された。

以下はインフォーマルな(というよりいい加減な)議論。

主張1

forall m. (Monoid m) => (A -> m) -> m

という型は、Vector Aとだいたい同型である。

forall m. (CommutativeIdempotentMonoid m) => (A -> m) -> m

という型は、Set Aとだいたい同型である。

説得

これらの型は多相型なので、パラメトリシティの制約を受ける。Theorems for free!によると、以下の命題が成り立つ

閉じた項tが型Tを持つならば、(t,t) ∈ R(T)である。

ただし、Rは型を値の集合間の関係に写す写像で、以下のように再帰的に定義される。(この説得に関係ありそうなところだけ)

(f0, f1) ∈ R(A -> B) iff
  ∀ (a0, a1) ∈ R(A). (f0 a0, f1 a1) ∈ R(B)
(t0, t1) ∈ R(∀a. F(a)) iff
  ∀ A : A0 <=> A1. (t0[a=A0], t1[a=A1]) ∈ R(F)A
TがInt,Charなどの基本型なら、
(x0, x1) ∈ R(T, T) iff x0 = x1

ただし、X <=> YはXの値の集合とYの値の集合の関係を表す。

これに従って上の型が課す制約を同値変形していく。

(t, t) ∈ R(forall m. (Monoid m) => (A -> m) -> m)
(t, t) ∈ R(forall m. (m -> m -> m) -> m -> (A -> m) -> m) -- Monoidを辞書渡しに
∀M : M0 <=> M1.
  (t, t) ∈ ((M -> M -> M) -> M -> (A -> M) -> M) -- forallを処理
∀M : M0 <=> M1.
  ∀(<>, <>') ∈ R(M -> M -> M).
    (t (<>), t (<>')) ∈ R(M -> (A -> M) -> M) -- ->を処理
∀M : M0 <=> M1.
  ∀(<>, <>') ∈ R(M -> M -> M).
    ∀(e, e') ∈ R(M).
      (t (<>) e, t (<>') e') ∈ R((A -> M) -> M)
∀M : M0 <=> M1.
  ∀(<>, <>') ∈ R(M -> M -> M).
    ∀(e, e') ∈ M.
      ∀(f, f') ∈ R(A -> M).
        (t (<>) e f, t (<>') e' f') ∈ M

さらに、

(<>, <>') ∈ R(M -> M -> M)
∀(a, a') ∈ M.
  (<> a, <>' a') ∈ R(M -> M)
∀(a, a') ∈ M.
  ∀(b, b') ∈ M.
    (<> a b, <>' a' b') ∈ M
(f, f') ∈ R(A -> M)
∀(x, x') ∈ R(A).
  (f x, f' x') ∈ M
∀x :: A.
  (f x, f' x) ∈ M -- Aを基本型として扱う

ここで、Mが(関係の特別な場合としての)関数の場合だけを考える*1。すなわち、この関数をuと書くことにして、

(a, a') ∈ M iff a' = u a

すると、さらに以下のように簡略化できる。

(<>, <>') ∈ R(M -> M -> M)
∀a,b. u (a <> b) = u a <> u b
(f, f') ∈ R(A -> M)
∀x :: A. u (f x) = f' x
u . f = f'

元の条件に代入する。

(t, t) ∈ R(forall m. (Monoid m) => (A -> m) -> m)
∀u :: M0 -> M1.
  ∀(<>, <>').
    (∀a,b. u (a <> b) = u a <> u b) =>
      ∀(e, e').
        (u e = e') =>
          ∀(f, f').
            (u . f = f') =>
              u (t (<>) e f) = t (<>') e' f'
∀u :: M0 -> M1.
  ∀(<>, <>').
    ∀(e, e').
      ∀(f, f').
        (uがモノイド準同型) =>
              u (t (<>) e f) = t (<>') e' (u . f)
∀u :: M0 -> M1.
  ∀f.
    (uがモノイド準同型) =>
      u (t f) = t (u . f) -- 辞書渡しを暗黙に

結局、満たされるべき条件は、fが任意の関数で、uが任意のモノイド準同型のとき、u (t f) = t (u . f)が成り立つことである。

これは要するに、tの要素がHom_{\mathbf{Mon}}(A,|-|)から|-|への自然変換だということである。命題0より、このような自然変換の集合は|F(A)|と同型である。\mathbf{Mon}におけるAの自由対象はAの有限列だから、これはVector Aに同型。

制約がMonoidでなくCommutativeIdempotentMonoidの場合も同様の議論。

以上が説得されるべきことであった。

感想と反省

  • 圏論の議論が面倒。全部Haskellの等式推論でやりたい。(実際fとgの構成はHaskellでやった)。たぶん圏論上手い人はもっとうまく証明できるはずなんだけど、どうしていいか分からない。
  • 米田の補題を使わずにやろうとしたらひどいことになった。なぜかは分からない。
  • パラメトリシティから自然性にもっていく議論があやしい。必要性しか示してない。
  • seqどころか⊥の存在すら無視している。モノイド則を満たさないMonoidインスタンスの存在も。
  • Lens' s a (定義は forall f. (Functor f) => (a -> f a) -> s -> f s)が(s -> (a, (a -> s)))と同型であることも似た議論で示せる?
  • 集合としての同型しか示してないけど、実際には命題0の左辺の自然変換の集合には(CがMonなら)モノイド構造が入るはず。本当はこれを主張したかったが、述べる方法すら分からなかった。もしかして同型が存在するだけでなく自然であることを示せばいい?

*1:こうするとなぜかうまくいくらしい

doの乱用

メモ。

Haskell Reportによれば、(do a)という式の意味は(a)と同じなので、aがモナドな型を持っている必要はない。これを利用して、括弧を減らすためだけにdoを使うことができる。

import Data.Complex
import Data.Monoid
import Data.Text.Lazy.Builder
import Data.Text.Lazy.Builder.RealFloat

-- | 普通
complexInPolar :: Complex Double -> Builder
complexInPolar x
  =  fromString "("
  <> realFloat (magnitude x)
  <> fromString ":"
  <> realFloat (phase x)
  <> fromString ")"

-- | doの乱用
complexInPolar' :: Complex Double -> Builder
complexInPolar' x
  =  fromString "("
  <> do realFloat $ magnitude x
  <> fromString ":"
  <> do realFloat $ phase x
  <> fromString ")"

多相関数のprintfデバッグをGHCiで

Haskellでのデバッグといえばprintfデバッグなのではないかと思う。printfデバッグは大抵の場合うまくいくが、多相的な関数を書いているときは不便なことがある。表示したい値を文字列にする手段がない場合だ。

import Debug.Trace
import Data.List

mysort :: (Ord a) => [a] -> [a]
mysort [] = []
mysort (x:xs) =
  trace ("inserting " ++ show x) $ -- エラー! xはshowできない
  insert x $ mysort xs

main = print $ mysort [1,3,2,0] -- デバッグ用入力

実際のデバッグ入力であるIntegerはshowできるのだが、mysortの型からはそれが演繹できないのでshowを呼ぶことができない。ここでxを表示するためにはmysortの型を変えて(Show a)制約を追加せねばならない。このように小さな関数ならそれも我慢できるが、呼び出し元にも再帰的に制約が必要になって何十行も変更することになるのは辛い。

ここでやりたいことは、実行時にxの型を調べて、それが表示可能であれば表示することだ。GHCでは実行時の値に型情報が付かないので一見不可能だが、GHCiにはこれを可能にする魔術が実装されている。

GHCiデバッガ

最近のGHCiにはデバッガが搭載されていて、プログラムの任意の位置にブレークポイントを仕掛けることができる。ブレークポイントの付いた式が評価され始めるとそのタイミングで実行が中断され、ローカル変数の値などを調べることができる。詳しくはマニュアル(和訳)を参照。

静的には型の決まっていないローカル変数の値を調べる際、GHCiは実行時のヒープオブジェクトを調べて可能な限り型情報を復元しようとし、成功すればshowを使って表示することができる。この能力をprintfデバッグに転用しよう。

実装

以下のMyDebug.hsを用意する。

module MyDebug where

myTrace :: a -> b -> b
myTrace message _body = const _body message

myTraceM :: (Monad m) => a -> m ()
myTraceM toShow = myTrace toShow $ return ()

myTraceにブレークポイントを仕掛け、ローカル変数messageの値をGHCiで覗くというのが意図だ。

Debug.Trace.traceの代わりにMyDebug.myTraceを使うようにコードを書き換える。

import MyDebug
import Data.List

mysort :: (Ord a) => [a] -> [a]
mysort [] = []
mysort (x:xs) =
  myTrace ("inserting", x) $
  insert x $ mysort xs

main = print $ mysort [1,3,2,0] -- デバッグ用入力

ブレークポイントを設定したり、messageの値を調べた後に実行を再開するのが面倒なので、GHCiスクリプトを書いて自動化する。

:{
let
  ghci_debug_impl expr = return $ unlines
    [ ":break MyDebug 4 25" -- myTraceの定義本体にブレークポイントを設定
    , ":set stop :debug_stop_handler" -- ブレークポイントに到達するたびに自動でdebug_stop_handlerが呼ばれるように
    , expr -- 与えられた式を評価
    , ":unset stop" -- :set stopを消す
    , ":delete *" -- ブレークポイントを消す
    ]
  debug_stop_handler_impl _ = return $ unlines
    [ ":force message" -- 「message」という名前のローカル変数を表示
    , ":continue" -- 実行再開
    ];
:}
:def debug ghci_debug_impl
:def debug_stop_handler debug_stop_handler_impl

-- :force時、可能ならshow関数を使う
:set -fprint-evld-with-show

これで、「:debug <式>」と入力すれば<式>を評価しつつmyTraceが呼ばれるたびに第一引数の値が表示されるようになった。

% ghci sort.hs
GHCi, version 7.4.2: http://www.haskell.org/ghc/  :? for help
Loading package ghc-prim ... linking ... done.
Loading package integer-gmp ... linking ... done.
Loading package base ... linking ... done.
[1 of 2] Compiling MyDebug          ( MyDebug.hs, interpreted )
[2 of 2] Compiling Main             ( sort.hs, interpreted )
Ok, modules loaded: MyDebug, Main.
*Main> :debug main
Breakpoint 0 activated at MyDebug.hs:4:25-43
Stopped at MyDebug.hs:4:25-43
_body :: b = _
_result :: b = _
message :: (a, Integer) = (_,1)
message = ("inserting",1)
Stopped at MyDebug.hs:4:25-43
_body :: b = _
_result :: b = _
message :: (a, Integer) = (_,3)
message = ("inserting",3)
Stopped at MyDebug.hs:4:25-43
_body :: b = _
_result :: b = _
message :: (a, Integer) = (_,2)
message = ("inserting",2)
Stopped at MyDebug.hs:4:25-43
_body :: b = _
_result :: b = _
message :: (a, Integer) = (_,0)
message = ("inserting",0)
[0,1,2,3]
*Main>

ブレークポイントに到達するたびに邪魔な表示が出るのでフィルタする(強引に)。以下のスクリプトをdebugfilter.shという名前で保存しておく。

#!/bin/sh
grep -v '^\(_body :: .*_$\)\|\(_result :: .*_$\)\|\(message :: .*\)\|\(Stopped at MyDebug\.hs:4:25-43$\)\|\(Breakpoint 0 activated at MyDebug\.hs:4:25-43$\)'

「ghci 」の代わりに「ghci | ./debugfilter.sh」のように起動すれば綺麗な出力が得られる。

% ghci sort.hs | ./debugfilter.sh
GHCi, version 7.4.2: http://www.haskell.org/ghc/  :? for help
Loading package ghc-prim ... linking ... done.
Loading package integer-gmp ... linking ... done.
Loading package base ... linking ... done.
[1 of 2] Compiling MyDebug          ( MyDebug.hs, interpreted )
[2 of 2] Compiling Main             ( sort.hs, interpreted )
Ok, modules loaded: MyDebug, Main.
*Main> :debug main
message = ("inserting",1)
message = ("inserting",3)
message = ("inserting",2)
message = ("inserting",0)
[0,1,2,3]
*Main> 

まとめ

もうちょっと邪悪じゃない方法はないものか。

後からフィールドを追加できるレコード

通常のレコード型は一旦定義したらフィールドを追加したりはできないが、Template Haskellを使えば似たようなことができるのに気付いたので書いてみた。

次のように使う。

{-# LANGUAGE TemplateHaskell, TypeFamilies, DeriveDataTypeable #-}
import OpenProduct

defineOpenProduct "Foo"
  -- レコード型Fooを定義する

defineOpField [t|Foo|] "FieldA" [t|Int|] [|0|]
  -- FooのフィールドとしてFieldAを定義する
  -- 型はIntで初期値は0

x :: Foo
x = opSetField FieldA 4 opEmpty
  -- x = Foo{ fieldA = 4 } みたいな感じ

defineOpField [t|Foo|] "FieldB" [t|Maybe String|] [|Nothing|]
  -- FooのフィールドとしてFieldBを定義する
  -- 型はMaybe Stringで初期値はNothing

y :: Foo
y = opSetField FieldB (Just "y") x
  -- y = x{ fieldB = Just "y" } みたいな感じ

main = do
  print $ opGetField FieldA x -- 4
  print $ opGetField FieldB x -- Nothing
  print $ opGetField FieldA y -- 4
  print $ opGetField FieldB y -- Just "y"

実装は以下。

{-# LANGUAGE TypeFamilies, GADTs, FlexibleContexts #-}
{-# OPTIONS_GHC -Wall #-}
module OpenProduct
  ( OpenProduct -- abstract!
  , OPField -- abstract!
  , opEmpty
  , opGetField
  , opSetField
  , (%<)
  , (%=)
  , defineOpenProduct
  , defineOpField
  ) where

import Control.Applicative
import qualified Data.Map as M
import Data.Maybe
import Data.Typeable
import GHC.Exts (Any)
import Language.Haskell.TH
import Unsafe.Coerce

-- | 開レコードのクラス
class OpenProduct a

-- 全ての開レコードはこの型のnewtypeになる
type OpRep = M.Map TypeRep Any

toOpRep :: (OpenProduct a) => a -> OpRep
toOpRep = unsafeCoerce

fromOpRep :: (OpenProduct a) => OpRep -> a
fromOpRep = unsafeCoerce

-- | 空の開レコード
opEmpty :: (OpenProduct a) => a
opEmpty = fromOpRep M.empty

-- | 開レコードのフィールドのクラス
class (OpenProduct (OPContaining f)) => OPField f where
  type OPFieldType f -- ^ フィールドfの型
  type OPContaining f -- ^ フィールドfを含む開レコードの型

  opfKey :: f -> TypeRep
  opfDefaultValue :: f -> OPFieldType f

-- | 開レコードのフィールドを読む
opGetField, (%<) :: (OPField f) => f -> OPContaining f -> OPFieldType f
opGetField fld rec = fromMaybe (opfDefaultValue fld) $
  unsafeCoerce $ M.lookup (opfKey fld) $ toOpRep rec
(%<) = opGetField

-- | 開レコードのフィールドを更新する
opSetField, (%=) :: (OPField f) => f -> OPFieldType f -> OPContaining f -> OPContaining f
opSetField fld val rec = fromOpRep $ M.insert (opfKey fld) (unsafeCoerce val) $ toOpRep rec
(%=) = opSetField

---- ここからマクロ ----

-- | 開レコードを定義する
defineOpenProduct :: String -> Q [Dec]
-- defineOpenProduct "Foo" =>
--   newtype Foo = Fooabc OpRep
--   instance OpenProduct Foo
defineOpenProduct nameS = do
  conName <- newName nameS
  return
    [ NewtypeD [] name [] (con conName) []
    , InstanceD [] (AppT (ConT ''OpenProduct) (ConT name)) []
    ]
  where
    name = mkName nameS
    con conName = NormalC conName [(NotStrict, ConT ''OpRep)]

-- | フィールドを定義する
defineOpField :: TypeQ -> String -> TypeQ -> ExpQ -> Q [Dec]
-- defineOpField [t|Foo|] "Fld" [t|Int|] [|4|] =>
--   data Fld = Fld
--     deriving (Typeable)
--   instance OPField Fld where
--     type OPFieldType Fld = Int
--     type OPContaining Fld = Foo
--     opfKey = typeOf
--     opfDefaultValue = \_ -> 4
defineOpField recType nameS fldType defexp = sequence
  [ dataD (pure []) name [] [normalC name []] [''Typeable]
  , instanceD (pure []) (appT (conT ''OPField) (conT name))
    [ tySynInstD ''OPFieldType [conT name] fldType
    , tySynInstD ''OPContaining [conT name] recType
    , valD (varP 'opfKey) (normalB $ varE 'typeOf) []
    , valD (varP 'opfDefaultValue) (normalB $ lamE [wildP] defexp) []
    ]
  ]
  where name = mkName nameS

使いみちはなんだろう。モナディックな関数のメモ化とか、packrat parsingとかに使えなくもないか

induceBackwardを高速化しようとした

Haskell vs F# - Life Goes Onが気になったのでやってみた。

手元にF#の実行環境がないので、元のコードを2倍高速化することを目標にしてみる。環境はLinux x64, GHC 7.0.4.

最初のコード。

import Data.Array.Unboxed

data Node = Node {
  df :: Double,
  branch :: [(Int, Double)]
  }

induceBackward :: Array Int Node -> UArray Int Double -> UArray Int Double
induceBackward nodes values = accumArray (+) 0 (bounds nodes)
  [(j, p * values ! k * df) | (j, Node df branch) <- assocs nodes, (k, p) <- branch]

iteration = 1000

main :: IO()
main = print (maximum [value i | i <- [1..iteration]])
  where
  value i = foldr induceBackward (lastValues i) testTree ! 0
  lastValues i = listArray (-100, 100) (repeat (fromIntegral i))
  testTree = [listArray (-i, i)
    [Node 1.0 [(j-1, 1.0/6.0), (j, 2.0/3.0), (j+1, 1.0/6.0)] | j <- [-i..i]]
    | i <- [0..99]]

実行時間。

% time ./backward
999.9999999999998
./backward  1.12s user 0.01s system 99% cpu 1.126 total

まず気が付いたのは、一番内側のループ(induceBackward内の「(k, p) <- branch」部分)でリストを辿っていること。これを配列上のfoldにしてしまいたい。そのためにはもう一段外のループ(assocs nodeの部分)を配列上のmapにしたい。しかしこれはArrayからUArrayへの変換なので効率的にやるのは面倒。そこでarrayをやめてvectorパッケージを使うことにする。

vectorパッケージを使うと、vがboxed vectorなら次のようにしてunboxed vectorに変換しつつmapできる。

-- import qualified Data.Vector as V
V.convert (V.map f v)

V.convertはboxed vectorをunboxed vectorに変換する関数だが、融合変換の魔法によってV.mapで作られてV.convertで消費される中間のboxed vectorが排除されるので、全体としてmap一回分のコストで済む。

vectorパッケージのvectorはarrayパッケージの配列とちがって添字のオフセットを指定できないので、それを合わせて型ArrとUArrを作ることにした。

import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U

data Node = Node {
  df :: Double,
  branch :: [(Int, Double)]
  }

type Arr a = (Int, V.Vector a)
type UArr a = (Int, U.Vector a)

(!) :: (U.Unbox a) => UArr a -> Int -> a
(offset, vec) ! k = vec U.! (k - offset)

induceBackward :: Arr Node -> UArr Double -> UArr Double
induceBackward (nodesOffset, nodes) values = (nodesOffset, newValues)
  where
    newValues = V.convert $ V.map f nodes
    f (Node df branch) = sum [p * values ! k * df | (k, p) <- branch]

iteration = 1000

main :: IO()
main = print (maximum [value i | i <- [1..iteration]])
  where
  value i = foldr induceBackward (lastValues i) testTree ! 0
  lastValues i = (-100, U.replicate 201 (fromIntegral i))
  testTree = [(-i, V.fromList [Node 1.0 [(j-1, 1.0/6.0), (j, 2.0/3.0), (j+1, 1.0/6.0)] | j <- [-i..i]])
    | i <- [0..99]]
./backward2  3.01s user 0.02s system 99% cpu 3.034 total

かなり遅くなったが、気にしないで次にいく。branchをvector型に変更。

data Node = Node { 
  df :: Double,
  branch :: U.Vector (Int, Double)
  } 
f (Node df branch) = U.sum $ U.map (\(k, p) -> p * values ! k * df) branch

この(U.sum $ U.map ...)もやはり融合変換で中間vectorのない形になっていることを期待している。

./backward3  1.15s user 0.01s system 99% cpu 1.163 total

実行時間は最初のコードとほぼ同じに戻った。

induceBackwardの引数valuesが内側のループで頻繁に使われているので正格にする。

induceBackward (nodesOffset, nodes) values@(!_, !_) = (nodesOffset, newValues)
./backward4  0.76s user 0.01s system 99% cpu 0.766 total

この時点でコアを読んでみるが明かな無駄が見つからなかった。Nodeのdfフィールドを正格にしてUNPACKプラグマを付けてみる。

./backward6  0.73s user 0.01s system 99% cpu 0.743 total

同様にbranchフィールドにもUNPACK指定したいところだが、なぜか型族をUNPACKすることはできない。

LLVMバックエンドが速いようなので試してみる。手元のLLVMが新しすぎる(3.0)ためGHC 7.0では対応していないのでGHC 7.4.1をインストールした。

./backward6-7.4.1  0.74s user 0.01s system 99% cpu 0.751 total
./backward6-7.4.1-llvm  0.71s user 0.01s system 99% cpu 0.724 total

このあたりが限界かと思ったが、branchの長さが3固定であることを利用してループをアンロールし、さらに二箇所ある(U.!)をU.unsafeIndexに置き換えることで劇的な高速化ができることに偶然気づいた。

./backward7-7.4.1-llvm  0.39s user 0.00s system 99% cpu 0.401 total

まとめ

2倍を越える高速化は達成したが釈然としない。なぜアンロールにこれほどの効果があるのか分からないし、unsafeIndexが必須なのも気分が悪い。あとはもっと詳しい人に任せたい。

あとF#速いですね。

最終的なコード。

{-# LANGUAGE BangPatterns #-}
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U

data Node = Node {
  df :: {-# UNPACK #-} !Double,
  branch :: U.Vector (Int, Double)
  }

type Arr a = (Int, V.Vector a)
type UArr a = (Int, U.Vector a)

(!) :: (U.Unbox a) => UArr a -> Int -> a
(!) (offset, vec) k = vec `U.unsafeIndex` (k - offset)

induceBackward :: Arr Node -> UArr Double -> UArr Double
induceBackward (nodesOffset, nodes) values@(!_, !_) = (nodesOffset, newValues)
  where
    newValues = V.convert $ V.map f nodes
    f (Node df branch) = fold3 0 $ \i s -> case branch `U.unsafeIndex` i of
      (k, p) -> p * values ! k * df + s

fold3 :: a -> (Int -> a -> a) -> a
fold3 x f = f 0 $ f 1 $ f 2 x

iteration = 1000

main :: IO()
main = print (maximum [value i | i <- [1..iteration]])
  where
  value i = foldr induceBackward (lastValues i) testTree ! 0
  lastValues i = (-100, U.replicate 201 (fromIntegral i))
  testTree = [(-i, V.fromList [Node 1.0 $ U.fromList [(j-1, 1.0/6.0), (j, 2.0/3.0), (j+1, 1.0/6.0)] | j <- [-i..i]])
    | i <- [0..99]]

おまけ

保守性とか安全性とか無視して速さを追求したらこうなった。

./backward-fast  0.20s user 0.00s system 99% cpu 0.205 total

ここまでしてやっと、適当に書いたC++コードと同等。

./a.out  0.21s user 0.00s system 99% cpu 0.214 total

コードは以下。

{-# LANGUAGE BangPatterns #-}

import Control.Monad
import Control.Monad.ST
import qualified Data.Primitive as P
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U

data Node = Node {
  df :: {-# UNPACK #-} !Double,
  branchIndex :: {-# UNPACK #-} !P.ByteArray{- Int -},
  branchCoefficient :: {-# UNPACK #-} !P.ByteArray{- Double -}
  }

type Arr a = (Int, V.Vector a)
type UArr a = (Int, U.Vector a)

(!) :: (U.Unbox a) => UArr a -> Int -> a
(!) (offset, vec) k = vec `U.unsafeIndex` (k - offset)

induceBackward :: Arr Node -> UArr Double -> UArr Double
induceBackward (nodesOffset, nodes) values@(!_, !_) = (nodesOffset, newValues)
  where
    newValues = V.convert $ V.map f nodes
    f (Node df branchIndex branchCoefficient) = fold3 0 $ \i s ->
      s +
      P.indexByteArray branchCoefficient i *
      values ! P.indexByteArray branchIndex i *
      df

fold3 :: a -> (Int -> a -> a) -> a
fold3 x f = f 0 $ f 1 $ f 2 x

iteration = 1000

main :: IO()
main = print (maximum [value i | i <- [1..iteration]])
  where
  value i = foldr induceBackward (lastValues i) testTree ! 0
  lastValues i = (-100, U.replicate 201 (fromIntegral i))
  testTree = [(-i, V.fromList [Node 1.0 (byteArrayFromList [j-1, j, j+1]) coefficients | j <- [-i..i]])
    | i <- [0..99]]
  coefficients = byteArrayFromList [1.0/6.0, 2.0/3.0, 1.0/6.0::Double]

byteArrayFromList :: (P.Prim a) => [a] -> P.ByteArray
byteArrayFromList xs = runST $ do
  mut <- P.newByteArray (length xs * P.sizeOf (head xs))
  forM_ (zip [0..] xs) $ \(i, v) -> P.writeByteArray mut i v
  P.unsafeFreezeByteArray mut
# include <vector>
# include <utility>
# include <memory>
# include <cstdio>

using namespace std;

struct node
{
  double df;
  vector<pair<int, double> > branch;
};

auto_ptr<vector<double> > induce_backward(const vector<node> &nodes, const vector<double> &values)
{
  const int n_nodes = nodes.size();
  auto_ptr<vector<double> > ret(new vector<double>());
  ret->reserve(n_nodes);
  const int n = values.size() / 2;
  for(int j = 0; j < n_nodes; j++)
  {
    double sum = 0;
    const node &node = nodes[j];
    for(int k = 0; k < 3; k++)
      sum += node.branch[k].second * values[n + node.branch[k].first] * node.df;
    ret->push_back(sum);
  }
  return ret;
}

const int iteration = 1000;

int main()
{
  vector<vector<node> > test_tree;
  for(int i = 0; i < 100; i++)
  {
    test_tree.push_back(vector<node>());
    vector<node> &nodes = test_tree.back();
    for(int j = -i; j <= i; j++)
    {
      nodes.push_back(node());
      node &node = nodes.back();
      node.df = 1.0;
      node.branch.push_back(make_pair(j-1, 1.0/6.0));
      node.branch.push_back(make_pair(j, 2.0/3.0));
      node.branch.push_back(make_pair(j+1, 1.0/6.0));
    }
  }
  double maxval = 0;
  for(int i = 1; i <= iteration; i++)
  {
    auto_ptr<vector<double> > values(new vector<double>(201, (double)i));
    for(int j = 99; j >= 0; j--)
      values = induce_backward(test_tree[j], *values);
    maxval = max(maxval, (*values)[0]);
  }
  printf("%f\n", maxval);
}