ノイズ耐性のある二分探索
観測にノイズが乗っても対処できる二分探索が意外と簡単に書けることが分かったのでメモ。
問題
(n-1)個の整数からなる列がある。そのうち左からi個は(-1)であり、それ以外は1である。
n=6, i=3の例 -1 -1 -1 1 1
クエリを繰り返すことでiを求めたい。各クエリは整数kであり、答として左からk番目の整数の値が得られる。ただし、この答にはノイズが加算される。ノイズは標準偏差σ(既知とする)の正規分布に従う。
解法
iがどの値を取るかの確率分布を持っておいて、クエリの答が得られるたびにベイズの定理に従って更新する。最初はn通りの一様分布。クエリは、得られる情報量の期待値を最大化するように選ぶ。これには、(-1)と1の境界よりも右にあるか左にあるかが半々に近い位置を選べば良い。
iの確率分布が十分に偏ったら終了。
実験
以下では、iが特定の値を取る確率が95%を越えた時点で探索を終了するようにしている。
ノイズが小さい場合は普通の二分探索のように動作する。
# n=8, σ=0.3, i=3 % ./noisy-bsearch 8 0.3 --answer 3 ? 3 -> 1.031 ? 1 -> -0.579 ? 2 -> -0.859 Found: 3 posterior=0.9999999944164625
ノイズが大きくなると、解を確定するのに同じクエリを何度も発行する必要が出てくる。
# n=8, σ=1, i=3 % ./noisy-bsearch 8 1 --answer 3 ? 3 -> 1.214 ? 1 -> -0.943 ? 2 -> 0.184 ? 2 -> 0.901 ? 1 -> -0.704 ? 2 -> -2.098 ? 3 -> 1.097 ? 2 -> -1.715 Found: 3 posterior=0.9579616340601826
たまに間違える。
% ./noisy-bsearch 8 1 --answer 3 ? 3 -> 1.813 ? 1 -> 1.609 ? 0 -> -1.378 ? 1 -> 0.376 ? 0 -> -1.202 Found: 1 posterior=0.9565230889174586
nを大きくすると、適当に当たりを付けながら試行錯誤する様子が見える。
# n=100, σ=2, i=77 % ./noisy-bsearch 100 2 --answer 77 ? 49 -> -1.182 ? 60 -> -1.373 ? 70 -> -0.116 ? 71 -> 0.242 ? 69 -> -0.160 ? 70 -> -4.324 ? 83 -> -0.343 ? 84 -> -0.377 ? 85 -> 2.448 ? 79 -> 4.538 ? 73 -> -0.290 ? 74 -> -0.890 ? 75 -> -1.684 ? 77 -> 2.514 ? 75 -> -1.450 ? 76 -> -1.166 ? 76 -> -4.089 ? 77 -> 0.621 ? 77 -> -2.006 ? 78 -> 2.714 ? 77 -> -0.352 ? 77 -> 1.402 ? 77 -> -0.629 ? 77 -> -2.606 ? 77 -> -0.347 ? 78 -> 1.579 ? 77 -> 7.286 ? 76 -> 2.773 ? 76 -> 0.334 ? 76 -> -1.448 ? 76 -> 0.915 ? 76 -> 1.719 ? 76 -> -2.039 ? 76 -> -3.672 ? 76 -> -0.300 ? 77 -> 1.227 ? 76 -> -3.140 Found: 77 posterior=0.9527731713033516
まとめ
推定すべき確率分布が離散的だと、頑張って近似しなくても教科書通りのベイズの定理が使えて楽しいことが分かった。
コード
{-# LANGUAGE BangPatterns #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE TemplateHaskell #-} import Control.Applicative import Control.Monad import Control.Lens import Data.Function import Data.List import Data.Maybe import qualified Data.Vector.Unboxed as U import Statistics.Distribution import Statistics.Distribution.Normal import System.Environment import System.Console.GetOpt import System.IO import qualified System.Random.MWC as MWC import Text.Printf import Text.Read data Interaction = Human | Simulated !Int data Mode = Help | Search data Conf = Conf { _observationFile :: !(Maybe FilePath) , _interaction :: !Interaction , _verbose :: !Bool , _mode :: !Mode } makeLenses ''Conf options :: [OptDescr (Conf -> Conf)] options = [ Option "f" ["observation-file"] (ReqArg (\f -> observationFile .~ Just f) "FILE") "Apply observations from FILE before starting search" , Option "a" ["answer"] (ReqArg (\a -> interaction .~ Simulated (read' a)) "N") "Automatically respond to queries" , Option "v" ["verbose"] (NoArg (verbose .~ True)) "Display posterior distribution at each step" , Option "h" ["help"] (NoArg (mode .~ Help)) "Show this help" ] defaultConf :: Conf defaultConf = Conf { _observationFile = Nothing , _interaction = Human , _verbose = False , _mode = Search } read' :: (Read a) => String -> a read' s = fromMaybe (error $ "cannot parse " ++ s) $ readMaybe s main :: IO () main = do args <- getArgs either fail id $ do let !(fns, nonOptions, errs) = getOpt Permute options args case errs of [] -> return () _ -> Left $ unlines errs let conf = foldr ($) defaultConf fns case conf^.mode of Help -> return $ putStr $ usageInfo "noisy-bsearch N SIGMA" options Search -> do (n, sigma) <- case nonOptions of [read' -> n, read' -> sigma] -> return (n, sigma) _ -> Left "expecting 2 arguments" return $ doSearch conf n sigma doSearch :: Conf -> Int -> Double -> IO () doSearch conf n sigma = do obss <- case conf^.observationFile of Nothing -> return [] Just path -> map parse . lines <$> readFile path getAns <- getAnswer sigma $ conf^.interaction search conf sigma getAns $ foldl' (update sigma) (uniform n) obss where parse (words -> [a, b]) = (read a, read b) parse str = error $ "parse errro: " ++ str getAnswer :: Double -> Interaction -> IO (Int -> IO Double) getAnswer _ Human = return askInteractive getAnswer sigma (Simulated ansIdx) = do gen <- MWC.createSystemRandom return $ \k -> do r <- genContVar (normalDistr (if ansIdx <= k then 1 else -1) sigma) gen _ <- printf "? %d -> %.3f\n" k r return r askInteractive :: Int -> IO Double askInteractive k = do putStr $ show k ++ "? " hFlush stdout ans <- getLine case readMaybe ans of Just r -> return r Nothing -> do putStrLn "Try again" askInteractive k -- | 実際の探索ループ search :: Conf -> Double -> (Int -> IO Double) -> Dist -> IO () search conf sigma ask = loop where loop dist | Just k <- U.findIndex (>0.95) dist = printf "Found: %d posterior=%f\n" k (dist U.! k) | otherwise = do when (conf ^. verbose) $ print dist let k = medianIndex dist obs <- ask k loop $ update sigma dist (k, obs) -- | いくつかの選択肢から一つを選ぶ確率分布 type Dist = U.Vector Double -- | 一様分布 uniform :: Int -> Dist uniform n = U.replicate n (1 / fromIntegral n) -- | 観測データと事前分布から事後分布を得る update :: Double -> Dist -> (Int, Double) -> Dist update sigma dist (k, obs) = normalize $ U.imap upd dist where upd i x | i <= k = x * likR | otherwise = x * likL -- kが正解の右側にある場合の尤度 likR = exp (- ((obs - 1) / sigma) ^ (2::Int) / 2) / (sigma * pi * sqrt 2) -- kが正解の左側にある場合の尤度 likL = exp (- ((obs + 1) / sigma) ^ (2::Int) / 2) / (sigma * pi * sqrt 2) -- | P(answer <= k) が0.5に最も近いkを返す medianIndex :: Dist -> Int medianIndex = U.minIndexBy (compare `on` (abs . subtract 0.5)) . U.scanl1 (+) -- | 合計が1になるように定数倍する normalize :: U.Vector Double -> Dist normalize xs = U.map (/U.sum xs) xs