ノイズ耐性のある二分探索

観測にノイズが乗っても対処できる二分探索が意外と簡単に書けることが分かったのでメモ。

問題

(n-1)個の整数からなる列がある。そのうち左からi個は(-1)であり、それ以外は1である。

n=6, i=3の例
-1 -1 -1 1 1

クエリを繰り返すことでiを求めたい。各クエリは整数kであり、答として左からk番目の整数の値が得られる。ただし、この答にはノイズが加算される。ノイズは標準偏差σ(既知とする)の正規分布に従う。

解法

iがどの値を取るかの確率分布を持っておいて、クエリの答が得られるたびにベイズの定理に従って更新する。最初はn通りの一様分布。クエリは、得られる情報量の期待値を最大化するように選ぶ。これには、(-1)と1の境界よりも右にあるか左にあるかが半々に近い位置を選べば良い。

iの確率分布が十分に偏ったら終了。

実験

以下では、iが特定の値を取る確率が95%を越えた時点で探索を終了するようにしている。

ノイズが小さい場合は普通の二分探索のように動作する。

# n=8, σ=0.3, i=3
% ./noisy-bsearch 8 0.3 --answer 3 
? 3 -> 1.031
? 1 -> -0.579
? 2 -> -0.859
Found: 3 posterior=0.9999999944164625

ノイズが大きくなると、解を確定するのに同じクエリを何度も発行する必要が出てくる。

# n=8, σ=1, i=3
% ./noisy-bsearch 8 1 --answer 3
? 3 -> 1.214
? 1 -> -0.943
? 2 -> 0.184
? 2 -> 0.901
? 1 -> -0.704
? 2 -> -2.098
? 3 -> 1.097
? 2 -> -1.715
Found: 3 posterior=0.9579616340601826

たまに間違える。

% ./noisy-bsearch 8 1 --answer 3
? 3 -> 1.813
? 1 -> 1.609
? 0 -> -1.378
? 1 -> 0.376
? 0 -> -1.202
Found: 1 posterior=0.9565230889174586

nを大きくすると、適当に当たりを付けながら試行錯誤する様子が見える。

# n=100, σ=2, i=77
% ./noisy-bsearch 100 2 --answer 77
? 49 -> -1.182
? 60 -> -1.373
? 70 -> -0.116
? 71 -> 0.242
? 69 -> -0.160
? 70 -> -4.324
? 83 -> -0.343
? 84 -> -0.377
? 85 -> 2.448
? 79 -> 4.538
? 73 -> -0.290
? 74 -> -0.890
? 75 -> -1.684
? 77 -> 2.514
? 75 -> -1.450
? 76 -> -1.166
? 76 -> -4.089
? 77 -> 0.621
? 77 -> -2.006
? 78 -> 2.714
? 77 -> -0.352
? 77 -> 1.402
? 77 -> -0.629
? 77 -> -2.606
? 77 -> -0.347
? 78 -> 1.579
? 77 -> 7.286
? 76 -> 2.773
? 76 -> 0.334
? 76 -> -1.448
? 76 -> 0.915
? 76 -> 1.719
? 76 -> -2.039
? 76 -> -3.672
? 76 -> -0.300
? 77 -> 1.227
? 76 -> -3.140
Found: 77 posterior=0.9527731713033516

まとめ

推定すべき確率分布が離散的だと、頑張って近似しなくても教科書通りのベイズの定理が使えて楽しいことが分かった。

コード

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TemplateHaskell #-}

import Control.Applicative
import Control.Monad
import Control.Lens
import Data.Function
import Data.List
import Data.Maybe
import qualified Data.Vector.Unboxed as U
import Statistics.Distribution
import Statistics.Distribution.Normal
import System.Environment
import System.Console.GetOpt
import System.IO
import qualified System.Random.MWC as MWC
import Text.Printf
import Text.Read

data Interaction = Human | Simulated !Int
data Mode = Help | Search

data Conf = Conf
  { _observationFile :: !(Maybe FilePath)
  , _interaction :: !Interaction
  , _verbose :: !Bool
  , _mode :: !Mode
  }

makeLenses ''Conf

options :: [OptDescr (Conf -> Conf)]
options =
  [ Option "f" ["observation-file"]
      (ReqArg (\f -> observationFile .~ Just f) "FILE")
      "Apply observations from FILE before starting search"
  , Option "a" ["answer"]
      (ReqArg (\a -> interaction .~ Simulated (read' a)) "N")
      "Automatically respond to queries"
  , Option "v" ["verbose"]
      (NoArg (verbose .~ True))
      "Display posterior distribution at each step"
  , Option "h" ["help"] (NoArg (mode .~ Help)) "Show this help"
  ]

defaultConf :: Conf
defaultConf = Conf
  { _observationFile = Nothing
  , _interaction = Human
  , _verbose = False
  , _mode = Search
  }

read' :: (Read a) => String -> a
read' s = fromMaybe (error $ "cannot parse " ++ s) $ readMaybe s

main :: IO ()
main = do
  args <- getArgs
  either fail id $ do
    let !(fns, nonOptions, errs) = getOpt Permute options args
    case errs of
      [] -> return ()
      _ -> Left $ unlines errs
    let conf = foldr ($) defaultConf fns
    case conf^.mode of
      Help -> return $
        putStr $ usageInfo "noisy-bsearch N SIGMA" options
      Search -> do
        (n, sigma) <- case nonOptions of
          [read' -> n, read' -> sigma] -> return (n, sigma)
          _ -> Left "expecting 2 arguments"
        return $ doSearch conf n sigma 

doSearch :: Conf -> Int -> Double -> IO ()
doSearch conf n sigma = do
  obss <- case conf^.observationFile of
    Nothing -> return []
    Just path -> map parse . lines <$> readFile path
  getAns <- getAnswer sigma $ conf^.interaction
  search conf sigma getAns
    $ foldl' (update sigma) (uniform n) obss
  where
    parse (words -> [a, b]) = (read a, read b)
    parse str = error $ "parse errro: "  ++ str

getAnswer :: Double -> Interaction -> IO (Int -> IO Double)
getAnswer _ Human = return askInteractive
getAnswer sigma (Simulated ansIdx) = do
  gen <- MWC.createSystemRandom
  return $ \k -> do
    r <- genContVar
      (normalDistr (if ansIdx <= k then 1 else -1) sigma) gen
    _ <- printf "? %d -> %.3f\n" k r
    return r

askInteractive :: Int -> IO Double
askInteractive k = do
  putStr $ show k ++ "? "
  hFlush stdout
  ans <- getLine
  case readMaybe ans of
    Just r -> return r
    Nothing -> do
      putStrLn "Try again"
      askInteractive k

-- | 実際の探索ループ
search :: Conf -> Double -> (Int -> IO Double) -> Dist -> IO ()
search conf sigma ask = loop
  where
    loop dist
      | Just k <- U.findIndex (>0.95) dist =
        printf "Found: %d posterior=%f\n" k (dist U.! k)
      | otherwise = do
        when (conf ^. verbose) $ print dist
        let k = medianIndex dist
        obs <- ask k
        loop $ update sigma dist (k, obs)

-- | いくつかの選択肢から一つを選ぶ確率分布
type Dist = U.Vector Double

-- | 一様分布
uniform :: Int -> Dist
uniform n = U.replicate n (1 / fromIntegral n)

-- | 観測データと事前分布から事後分布を得る
update :: Double -> Dist -> (Int, Double) -> Dist
update sigma dist (k, obs) = normalize $ U.imap upd dist
  where
    upd i x
      | i <= k = x * likR
      | otherwise = x * likL
    -- kが正解の右側にある場合の尤度
    likR = exp (- ((obs - 1) / sigma) ^ (2::Int) / 2) / (sigma * pi * sqrt 2)
    -- kが正解の左側にある場合の尤度
    likL = exp (- ((obs + 1) / sigma) ^ (2::Int) / 2) / (sigma * pi * sqrt 2)

-- | P(answer <= k) が0.5に最も近いkを返す
medianIndex :: Dist -> Int
medianIndex = U.minIndexBy (compare `on` (abs . subtract 0.5)) . U.scanl1 (+)

-- | 合計が1になるように定数倍する
normalize :: U.Vector Double -> Dist
normalize xs = U.map (/U.sum xs) xs