GHCで動くcall/ccの実装

IOモナド上で動作するcall/ccを実装できることが分かったので書いておく。ただし実用に耐えるものではない。

これを使うとたとえばこういうコードを書くことができる。

import Control.Applicative
import Control.Monad
import Data.IORef
import System.IO.Unsafe

import Continuation (callCC, withContinuationsDo)
test :: IO ()
test = do
  r <- callCC $ \cc -> return $ Right $ cc . Left
  case r of
    Left val -> do
      putStrLn $ "left: " ++ show val
      return ()
    Right cc -> do
      putStrLn "right"
      cc "foo"
      putStrLn "this should not be printed"
-- 実行結果:
--  right
--  left: "foo"

非決定計算もできる

ambNext :: IORef (IO a)
ambNext = unsafePerformIO $ newIORef $ fail "Dead end"
{-# NOINLINE ambNext #-}

amb :: [a] -> IO a
amb vs = foldr amb2' (join $ readIORef ambNext) vs
  where
    amb2' a b = join $ amb2 (return a) b

amb2 :: a -> a -> IO a
amb2 a b = do
  next <- readIORef ambNext
  callCC $ \cc -> do
    writeIORef ambNext $ do
      writeIORef ambNext next
      cc b
    return a

8クイーン問題を解く例。

eightQueens :: IO ()
eightQueens = loop 8 allPoints >>= print
  where
    allPoints = (,) <$> [0..7] <*> [0..7]

    loop :: Int -> [(Int, Int)] -> IO [(Int, Int)]
    loop 0 _ = return []
    loop k avail = do
      pos <- amb avail
      (pos:) <$> loop (k-1) (narrow pos avail)

    narrow pos avail = filter (not . reachable pos) avail
    reachable (x0, y0) (x1, y1)
      = x0 == x1
      || y0 == y1
      || x0 + y0 == x1 + y1
      || x0 - y0 == x1 - y1

mainは次のようになる。

main :: IO ()
main = withContinuationsDo $ do
  test
  eightQueens

種明かし

実装は以下。

{-# OPTIONS_GHC -O #-}
  -- 最適化をオフにするとちゃんと動かない

{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE UnliftedFFITypes #-}

module Continuation
  ( withContinuationsDo
  , getCC
  , callCC
  ) where

import Control.Concurrent
import Control.Exception
import Control.Monad
import Data.IORef
import Data.Typeable
import Foreign.Marshal.Utils
import Foreign.Storable
import System.IO.Unsafe

import GHC.Exts
import GHC.IO(IO(..), unIO)

-- | (Right 継続)を返す。返ってきた継続に値vを渡して呼ぶと、
-- getCCがもう一度返り、(Left v)を返す。
getCC :: IO (Either a (a -> IO b))
getCC = do
  ref <- newIORef $ error "getCC: bug"
  t <- myThreadId
  -- 自分自身に非同期例外を投げることで、基底ループに制御を移す。
  throwTo t $ GetCC $ writeIORef ref
  -- 再開。基底ループによりrefに適切な値が書き込まれている。
  readIORef ref

-- | Scheme風call/cc
callCC :: ((a -> IO b) -> IO a) -> IO a
callCC f = either return f =<< getCC

-- | メインルーチンから基底ループへのメッセージ。現在の継続を要求する。
data GetCCRequest = forall a b. GetCC (Either a (a -> IO b) -> IO ())
  deriving (Typeable)

instance Show GetCCRequest where
  show _ = "<getcc>"

instance Exception GetCCRequest

-- | 基底ループが内部的に使う例外。
data ResumeRequest = Resume () -- ^ 評価継続すべきサンク
  deriving (Typeable, Show)

instance Exception ResumeRequest

-- | @main = withContinuationsDo $ ...@ のように使う
withContinuationsDo :: IO () -> IO ()
withContinuationsDo x = do
  -- サンクの評価が非同期例外によって中断された場合、そのサンクの値を
  -- 再度要求することで評価を継続することができるのを利用する。
  let thunk = unsafePerformIO x
  loop thunk
  where
    -- 基底ループ本体。
    loop thunk = do
      debug "evaluating"
      -- サンクが例外を投げるまで評価する。
      r <- try $ try $ evaluate thunk
      -- GHC 7.6.3では、thunkの評価が例外によって書き換えられた場合、
      -- thunkはAP_STACKオブジェクトを指すINDオブジェクトになっている。
      case r of
        -- 現在の継続を要求してきた。
        Left (GetCC ccSink) -> do
          debug "caught gcc"
          -- まずRightを返す。
          ccSink $ Right $ \val -> do
            debug "entering cc"
            -- 継続が呼ばれた。こんどはLeftを返す。
            ccSink $ Left val
            debug "value added"
            -- ここはメインルーチンの内部なので、
            -- 基底ループに制御を移し、thunkの評価を再開する。
            throwIO $ Resume thunk
            -- ここにはこない。
          -- 戻り値を設定した上でthunkの評価を再開する。ただし、thunkを
          -- 上書きするのを防ぐため、コピーをとってそちらを評価する。
          -- こうすることで、一つの継続を複数回呼び出すことができる。
          loop =<< dupAP_STACK thunk
        -- thunk'の評価の再開を要求してきた。
        Right (Left (Resume thunk')) ->
          -- 上と同じ。thunkには二度と戻らないので捨ててよい。
          loop =<< dupAP_STACK thunk'
        -- 実行終了。
        Right (Right ()) -> return ()
{-# NOINLINE withContinuationsDo #-}

-- | @dupAP_STACK x@はxが指すサンクを複製してそれを返す。
-- xはAP_STACKか、それを指すINDオブジェクトでなければならない。
dupAP_STACK :: a -> IO a
dupAP_STACK thing = IO $ \s00 -> let
  -- 1ワードのバイト数
  !(I# wsz) = sizeOf (undefined::Int)
  -- コピー先のメモリを確保する。これには、まず
  -- MutableByteArray#を作り、それをByteArray#に変換してから
  -- アドレスを得る。手抜きのためサイズは固定
  !bufsize = 100# *# wsz
  !(# s10, mbarr #) = newByteArray# bufsize s00
  !(# s20, barr #) = unsafeFreezeByteArray# mbarr s10
  !dest = byteArrayContents# barr
  -- ここ以降、このletブロックの終了までにGCが走ると
  -- ポインタが失効してひどいことになるので、
  -- メモリ確保を一切しないようにする。
  !s21 = ndebug "critical block"# nullAddr# s20
  -- 与えられたオブジェクトのアドレス。
  !(# s30, obj #) = thunkToAddr thing s21
  !s31 = ndebug "obj"# obj s30
  -- AP_STACKサンクのアドレスを得る。
  !(# s40, thunk #) = derefIndirections obj s31
  -- AP_STACKオブジェクトのsizeフィールドを読んで、
  -- payloadのワード数を得る。
  !(# s50, size #) = readIntOffAddr# thunk 2# s40
  -- サンクのバイト数。これは以下からなる
  --   header (2ワード)
  --   size (1ワード)
  --   fun (1ワード)
  --   payload (sizeワード)
  !bytes = wsz *# (size +# 4#)
  !s51 = ndebug "bytes"# (int2Addr# bytes) s50
  -- サンクをコピーする。ここで被せている構築子は最適化で
  -- 取り除かれることを期待している
  !(# s60, _ #) =
    if bufsize <# bytes
      then error "buffer too short"
      else unIO
    (copyBytes (Ptr dest :: Ptr Int) (Ptr thunk) (I# bytes)) s51
  -- コピーをサンクとして返す
  !(# s70, r #) = addrToThunk dest s60
  -- ここ以降、GCが発生しても問題ない
  !s71 = ndebug "end critical block"# nullAddr# s70
  in (# s71, r #)

-- | 間接参照を再帰的に辿り、AP_STACKオブジェクトを返す。
-- この関数はメモリを確保しない。
derefIndirections
  :: Addr# -> State# RealWorld -> (# State# RealWorld, Addr# #)
derefIndirections obj s00 = let
  -- objのinfoポインタを読む
  !(# s10, info #) = readAddrOffAddr# obj 0# s00
  !s11 = ndebug "info"# info s10
  -- infoテーブルのtypeフィールドを読む。infoポインタはinfoテーブルの
  -- 終端を指すので、ここの添字は負になる。
  !(# s20, typ #) = readIntOffAddr# info -1# s11
  !s21 = ndebug "type"# (int2Addr# typ) s20
  -- includes/rts/storage/ClosureTypes.h:
  -- #define AP_STACK 27
  in if typ ==# 27#
    then (# s21, obj #)
    else let
      -- objはAP_STACKオブジェクトではなかったので、おそらく間接参照。
      -- indirecteeフィールドを読む。
      !(# s30, next #) = readAddrOffAddr# obj 1# s21
      !s31 = ndebug "indirectee"# next s30
      in derefIndirections next s31

-- | デバッグ文字列を表示する
debug :: String -> IO ()
debug msg = when debugEnabled $ putStrLn msg

-- | メモリを確保せずにデバッグ文字列を表示する
ndebug :: Addr# -> Addr# -> State# RealWorld -> State# RealWorld
ndebug loc addr s = if debugEnabled
  then let (# s1, () #) = unIO (c_printf_pp "%s: %p\n"# loc addr) s
    in s1
  else s

foreign import ccall unsafe "printf"
  c_printf_pp :: Addr# -> Addr# -> Addr# -> IO ()

debugEnabled :: Bool
debugEnabled = False

-- | サンクやINDのアドレスを得る。GCの前後でアドレスが変わり得る
-- という点でこれは外界依存の操作なので、その旨を型で表した上で
-- NOINLINEにしておく。
-- INDやサンクのタグビットは0なのでビット演算の必要なし。
thunkToAddr :: a -> State# RealWorld -> (# State# RealWorld, Addr# #)
thunkToAddr v s = (# s, unsafeCoerce# v #)
{-# NOINLINE thunkToAddr #-}

-- | サンクオブジェクトのアドレスをHaskellの値に変換する。
-- 'thunkToAddr' と同じ理由でNOINLINE。
addrToThunk :: Addr# -> State# RealWorld -> (# State# RealWorld, a #)
addrToThunk addr s = (# s, unsafeCoerce# addr #)
{-# NOINLINE addrToThunk #-}

見てのとおりGHCの実装詳細にかなり依存しているので、Linux amd64 ghc-7.6.3の組み合わせ以外では動かないかもしれない。加えて、例外やスレッドと一緒に使うと変なことになると思われるので実用性はない。