-- PGF.Probabilistic
--
-- Plain source file: src/runtime/haskell/PGF/Probabilistic.hs (2015-03-03)
--
-- PGF.Probabilistic is imported by: ...
module PGF.Probabilistic
         ( Probabilities(..)
         , mkProbabilities                 -- :: PGF -> M.Map CId Double -> Probabilities
         , defaultProbabilities            -- :: PGF -> Probabilities
         , getProbabilities
         , setProbabilities
         , showProbabilities               -- :: Probabilities -> String
         , readProbabilitiesFromFile       -- :: FilePath -> PGF -> IO Probabilities

         , probTree
         , rankTreesByProbs
         , mkProbDefs
         ) where

import PGF.CId
import PGF.Data
import PGF.Macros

import qualified Data.Map as Map
import Data.List (sortBy,partition,nub,mapAccumL)
import Data.Maybe (fromMaybe) --, fromJust

-- | An abstract data structure which represents
-- the probabilities for the different functions in a grammar.
data Probabilities = Probs {
  -- | Probability of each function, keyed by the function identifier.
  funProbs :: Map.Map CId Double,
  -- | For each category: the probability stored for the category itself,
  -- paired with the list of (probability, function) entries kept for it.
  catProbs :: Map.Map CId (Double, [(Double, CId)])
  }

-- | Renders the probability structure as a string.
-- Every category contributes one line for itself followed by one line
-- per function entry; each line is an identifier, a tab and the probability.
showProbabilities :: Probabilities -> String
showProbabilities probs =
  unlines [ row
          | (cat,(catP,fns)) <- Map.toList (catProbs probs)
          , row <- entry cat catP : [entry f p | (p,f) <- fns]
          ]
  where
    entry name p = showCId name ++ "\t" ++ show p

-- | Reads the probabilities from a file.
-- This should be a text file where on every line
-- there is a function name followed by a real number.
-- The number represents the probability mass allocated for that function.
-- The function name and the probability should be separated by whitespace.
--
-- Lines with fewer than two words are ignored, and so is anything after
-- the second word of a line. If a name occurs more than once, the last
-- occurrence wins.
--
-- A line whose second word is not a valid number makes the action fail
-- with a user 'IOError' that names the file and the line number. The
-- numbers are parsed while the action runs, so a malformed file is
-- reported here rather than later, when a probability is first used.
readProbabilitiesFromFile :: FilePath -> PGF -> IO Probabilities
readProbabilitiesFromFile file pgf = do
  s <- readFile file
  entries <- mapM parseLine (zip [1 :: Int ..] (lines s))
  return $ mkProbabilities pgf (Map.fromList (concat entries))
  where
    parseLine :: (Int,String) -> IO [(CId,Double)]
    parseLine (n,l) =
      case words l of
        f:p:_ -> case reads p of
                   -- accept the number only if the whole word was consumed
                   [(d,"")] -> return [(mkCId f,d)]
                   _        -> ioError (userError (file ++ ":" ++ show n ++
                                                   ": invalid probability " ++ show p))
        _     -> return []

-- | Builds probability tables. The second argument is a map
-- which contains the known probabilities. A function that is
-- not in the map gets an equal share of the probability mass
-- left unallocated among the functions of its category (never
-- less than zero). A category that is not in the map gets
-- probability 0. The functions of every category are listed
-- from the highest to the lowest probability.
mkProbabilities :: PGF -> Map.Map CId Double -> Probabilities
mkProbabilities pgf probs = Probs funTable catTable
  where
    catTable = Map.mapWithKey catEntry (cats (abstract pgf))

    -- flatten the per-category lists into a single function table
    funTable = Map.fromList [(f,p) | (_,fns) <- Map.elems catTable, (p,f) <- fns]

    catEntry c (_,fns,_) =
      (Map.findWithDefault 0 c probs, sortBy descending (distribute (map snd fns)))

    descending (p1,_) (p2,_) = compare p2 p1

    -- attach to every function its known probability, or the even share
    distribute fs = [(fromMaybe share (Map.lookup f probs),f) | f <- fs]
      where
        known   = [d | Just d <- map (\f -> Map.lookup f probs) fs]
        missing = length fs - length known
        share
          | missing == 0 = 0
          | otherwise    = max 0 ((1 - sum known) / fromIntegral missing)

-- | Returns the default even distribution, i.e. the tables that
-- 'mkProbabilities' builds when no probability is known in advance.
defaultProbabilities :: PGF -> Probabilities
defaultProbabilities = flip mkProbabilities Map.empty

-- | Extracts the probabilities that are currently stored in the
-- function and category tables of the abstract syntax of a grammar.
getProbabilities :: PGF -> Probabilities
getProbabilities pgf = Probs funTable catTable
  where
    abstr    = abstract pgf
    funTable = Map.map funProb  (funs abstr)
    catTable = Map.map catEntry (cats abstr)

    funProb  (_,_,_,p) = p
    catEntry (_,fns,p) = (p,fns)

-- | Stores the given probabilities in the abstract syntax of a grammar.
-- Only functions and categories that already exist in the grammar are
-- updated; entries of the grammar without a counterpart in the
-- probability tables are left untouched.
setProbabilities :: Probabilities -> PGF -> PGF
setProbabilities probs pgf = pgf { abstract = abstr' }
  where
    abstr  = abstract pgf
    abstr' = abstr { funs = overlay setFun (funs abstr) (funProbs probs)
                   , cats = overlay setCat (cats abstr) (catProbs probs)
                   }

    setFun (ty,a,df,_) p       = (ty,a,df,p)
    setCat (hypos,_,_) (p,fns) = (hypos,fns,p)

    -- rewrite the values of 'old' whose key is also present in 'new'
    overlay combine old new = Map.mapWithKey step old
      where
        step key val = case Map.lookup key new of
                         Just upd -> combine val upd
                         Nothing  -> val

-- | Computes the probability of a given tree: the product of the
-- probabilities stored for the functions found along its applications.
-- A function unknown to the grammar, and any expression that is neither
-- an application nor a function, counts as 1.
probTree :: PGF -> Expr -> Double
probTree pgf (EApp f e) = probTree pgf f * probTree pgf e
probTree pgf (EFun f)   = maybe 1 (\(_,_,_,p) -> p) (Map.lookup f (funs (abstract pgf)))
probTree _   _          = 1

-- | Pairs every tree with its 'probTree' value and orders the pairs
-- from the highest to the lowest probability.
rankTreesByProbs :: PGF -> [Expr] -> [(Expr,Double)]
rankTreesByProbs pgf ts = sortBy byProbDesc (zip ts (map (probTree pgf) ts))
  where
    byProbDesc (_,p) (_,q) = compare q p


-- | Generates auxiliary definitions from the categories of the abstract syntax.
--
-- For every category other than the built-in 'cidString', 'cidInt' and
-- 'cidFloat', a definition named @p_@ followed by the category name is
-- produced. Its type takes the hypotheses of the category plus one value
-- of the category and returns 'cidFloat'; it has one equation per function
-- listed for the category. Further definitions of type 'cidFloat' are
-- emitted for the non-empty equation lists gathered by 'computeConstrs'.
--
-- The first component of the result collects the identifier groups
-- returned by 'computeConstrs'; the second one is the list of
-- (name, type, equations) definitions.
--
-- NOTE(review): presumably the generated definitions compute tree
-- probabilities inside the grammar itself - confirm against the callers.
mkProbDefs :: PGF -> ([[CId]],[(CId,Type,[Equation])])
mkProbDefs pgf =
  -- cs: one (category, renamed hypotheses, functions with their types)
  -- triple per non-built-in category. The hypotheses get the names v1, v2, ...
  -- The 'Just' pattern on the function lookup is partial: it fails (when the
  -- type is demanded) if a listed function is missing from the function table.
  let cs = [(c,hyps,fns) | (c,(hyps0,fs,_)) <- Map.toList (cats (abstract pgf)),
                           not (elem c [cidString,cidInt,cidFloat]),
	                       let hyps = zipWith (\(bt,_,ty) n -> (bt,mkCId ('v':show n),ty))
	                                          hyps0
	                                          [1..]
	                           fns  = [(f,ty) | (_,f) <- fs, 
	                                            let Just (ty,_,_,_) = Map.lookup f (funs (abstract pgf))]
           ]
      -- The accumulator carries the counter handed to mkEquation (it starts
      -- at 1 and runs across all categories) and the identifier groups
      -- collected so far. 'computeConstrs' starts from a fresh state st0
      -- for every category.
      ((_,css),eqss) = mapAccumL (\(ngen,css) (c,hyps,fns) -> 
              let st0      = (1,Map.empty)
                  ((_,eqs_map),cs) = computeConstrs pgf st0 [(fn,[],es) | (fn,(DTyp _ _ es)) <- fns]
                  (ngen', eqs) = mapAccumL (mkEquation eqs_map hyps) ngen fns
                  ceqs     = [(id,DTyp [] cidFloat [],reverse eqs) | (id,eqs) <- Map.toList eqs_map, not (null eqs)]
              in ((ngen',cs:css),(p_f c, mkType c hyps, eqs):ceqs)) (1,[]) cs
  in (reverse (concat css),concat eqss)
  where
    -- Wraps an expression in EImplArg unless the binding type is Explicit.
    mkEImplArg bt e
      | bt == Explicit = e
      | otherwise      = EImplArg e
      
    -- Wraps a pattern in PImplArg unless the binding type is Explicit.
    mkPImplArg bt p
      | bt == Explicit = p
      | otherwise      = PImplArg p

    -- Type of the definition generated for category c: the hypotheses,
    -- then one argument of type c applied to variables that refer back to
    -- the hypotheses (the first hypothesis gets the highest index), with
    -- result type cidFloat.
    mkType c hyps =
      DTyp (hyps++[mkHypo (DTyp [] c es)]) cidFloat []
      where
        is = reverse [0..length hyps-1]
        es = [mkEImplArg bt (EVar i) | (i,(bt,_,_)) <- zip is hyps]

    -- Passed to normalForm: the function table of the grammar together
    -- with a lookup that always answers Nothing.
    sig = (funs (abstract pgf), \_ -> Nothing)
    
    -- Builds the equation for one function fn of the category and threads
    -- the counter ngen. The left-hand side consists of tilde patterns for
    -- the normalised indices of the result type followed by fn applied to
    -- the variables v1, v2, ...; the right-hand side is the combination
    -- (see mkMult) of the factors fs1, fs2 and fs3.
    -- NOTE(review): the hyps argument is not used in this definition.
    mkEquation ceqs hyps ngen (fn,ty@(DTyp args _ es)) =
      let fs1         = case Map.lookup (p_f fn) ceqs of
                          Nothing              -> [mkApp (k_f fn) (map (\(i,_) -> EVar (k-i-1)) vs1)]
                          Just eqs | null eqs  -> []
                                   | otherwise -> [mkApp (p_f fn) (map (\(i,_) -> EVar (k-i-1)) vs1)]
          (ngen',fs2) = mapAccumL mkFactor2 ngen vs2
          fs3         = map mkFactor3 vs3
          eq = Equ (map mkTildeP xes++[PApp fn (zipWith mkArgP [1..] args)])
                   (mkMult (fs1++fs2++fs3))
      in (ngen',eq)
      where
        xes = map (normalForm sig k env) es

        -- Tilde pattern for an expression; an implicit-argument wrapper
        -- is kept on the outside of the pattern.
        mkTildeP e =
          case e of
            EImplArg e -> PImplArg (PTilde e)
            e          ->           PTilde e

        mkArgP n (bt,_,_) = mkPImplArg bt (PVar (mkCId ('v':show n)))

        -- No factor: the literal 1; a single factor: the factor itself;
        -- otherwise an application of "mult" to all the factors.
        mkMult []  = ELit (LFlt 1)
        mkMult [e] = e
        mkMult es  = mkApp (mkCId "mult") es

        -- One factor per group in vs2: the helper named by p_i for the
        -- current counter applied to the variables of the group. The
        -- counter is incremented; dst and the types in src are not used.
        mkFactor2 ngen (src,dst) =
          let vs = [EVar (k-i-1) | (i,ty) <- src]
          in (ngen+1,mkApp (p_i ngen) vs)

        -- One factor per argument in vs3: the definition generated for the
        -- argument's category applied to the normalised indices of its
        -- type and to the argument variable itself.
        mkFactor3 (i,DTyp _ c es) =
          let v = EVar (k-i-1)
          in mkApp (p_f c) (map (normalForm sig k env) es++[v])

        (k,env,vs1,vs2,vs3) = mkDeps ty

        -- Dependency analysis of the function type. k ends up as the number
        -- of arguments. The dependency list has one entry per argument:
        -- src holds (position, type) pairs and dst the positions whose types
        -- mention a variable bound by src; the result type counts as
        -- position k. vs1 are the sources referenced from the result type,
        -- vs3 the arguments nobody refers to and vs2 the groups produced by
        -- 'closure'.
        mkDeps (DTyp args _ es) =
          let (k,env,dep1) = updateArgs 0 [] [] args
              dep2         = foldl (update k env) dep1 es
              (vs2,vs3)    = closure k dep2 [] []
              vs1          = concat [src | (src,dst) <- dep2, elem k dst]
          in (k,map (\k -> VGen k []) env,vs1,reverse vs2,vs3)
          where
            -- Walks the arguments left to right. Arguments named with the
            -- wildcard identifier are not pushed on the environment.
            updateArgs k env dep []                              = (k,env,dep)
            updateArgs k env dep ((_,x,ty@(DTyp _ _ es)) : args) =
              let dep1 = foldl (update k env) dep es ++ [([(k,ty)],[])]
                  env1 | x == wildCId =     env
                       | otherwise    = k : env
              in updateArgs (k+1) env1 dep1 args

            -- Records that position k refers to every variable occurring in e.
            -- NOTE(review): only EApp, EFun and EVar are matched; any other
            -- expression form makes this case fail. The EVar branch also
            -- relies on (env !! i) being a valid index into dep.
            update k env dep e =
              case e of
                EApp e1 e2 -> update k env (update k env dep e1) e2
                EFun _     -> dep
                EVar i     -> let (dep1,(src,dst):dep2) = splitAt (env !! i) dep
                              in dep1++(src,k:dst):dep2

            -- Entries without dependants go to vs3. An entry that shares a
            -- dependant with later entries is merged with them and examined
            -- again. An unshared entry is dropped when its only dependant is
            -- k and is added to vs2 otherwise.
            closure k []               vs2 vs3 = (vs2,vs3)
            closure k ((src,dst):deps) vs2 vs3
              | null dst   = closure k deps vs2 (vs3++src)
              | otherwise  =
                  let (deps1,deps2) = partition (\(src',dst') -> not (null [v1 | v1 <- dst, v2 <- dst', v1 == v2])) deps
                      deps3 = (src,dst):deps1
                      src2  = concatMap fst deps3
                      dst2  = [v | v <- concatMap snd deps3
                                 , lookup v src2 == Nothing]
                      dep2  = (src2,dst2)
                      dst'  = nub dst
                  in if null deps1
                       then if dst' == [k]
                              then closure k deps2 vs2 vs3
                              else closure k deps2 ((src,dst') : vs2) vs3
                       else closure k (dep2 : deps2) vs2 vs3

        -- NOTE(review): mkNewSig is not referenced anywhere else in this
        -- definition. It builds a type returning cidFloat whose arguments
        -- are the listed source positions with normalised types.
        mkNewSig src =
          DTyp (mkArgs 0 0 [] src) cidFloat []
          where
            mkArgs k l env []                      = []
            mkArgs k l env ((i,DTyp _ c es) : src)
               | i == k    = let ty = DTyp [] c (map (normalForm sig k env) es)
                             in (Explicit,wildCId,ty) : mkArgs (k+1) (l+1) (VGen l [] : env) src
               | otherwise = mkArgs (k+1) l (VMeta 0 env [] : env) src

-- | State threaded through 'computeConstrs': a counter used to generate
-- fresh @k_\<n\>@ names, and the equations accumulated so far, keyed by
-- the identifier they are stored under.
type CState = (Int,Map.Map CId [Equation])

-- | Processes a list of (function, patterns built so far, expressions still
-- to be examined) entries and returns the updated state together with
-- groups of identifiers.
--
-- When the first entry has no expressions left:
--
-- * a single entry records, under the @p_@ name of its function, either no
--   equation (no patterns) or an equation from the patterns to the literal
--   1.0, and yields no group;
--
-- * several entries yield one group containing one identifier per entry:
--   the @k_@ name of the function when there are no patterns, otherwise a
--   fresh @k_\<n\>@ name for which an equation from the patterns to that
--   name is recorded under the @p_@ name of the function.
--
-- Otherwise the entries are partitioned by the head of their first
-- remaining expression and every partition is processed recursively.
--
-- NOTE(review): several matches here are partial - the second branch expects
-- every entry to have run out of expressions, 'split' expects every entry to
-- still have one, 'extract' handles only EFun, EVar, EApp, ETyped and
-- EImplArg, and the constructor lookup in 'addArgs' expects the function
-- to be present in the grammar.
computeConstrs :: PGF -> CState -> [(CId,[Patt],[Expr])] -> (CState,[[CId]])
computeConstrs pgf (ngen,eqs_map) fns@((id,pts,[]):rest)
  | null rest =
     let eqs_map' = 
           Map.insertWith (++) (p_f id)
                               (if null pts
                                  then []
                                  else [Equ pts (ELit (LFlt 1.0))])
                               eqs_map
     in ((ngen,eqs_map'),[])
  | otherwise =
     let (st,ks) = mapAccumL mk_k (ngen,eqs_map) fns

         -- Chooses the identifier for one entry; the counter advances only
         -- when a fresh k_<n> name is allocated.
         mk_k (ngen,eqs_map) (id,pts,[])
           | null pts  = ((ngen,eqs_map),k_f id)
           | otherwise = let eqs_map' = 
                               Map.insertWith (++) 
                                              (p_f id) 
                                              [Equ pts (EFun (k_i ngen))]
                                              eqs_map
                         in ((ngen+1,eqs_map'),k_i ngen)

     in (st,[ks])
computeConstrs pgf st fns =
  -- The pattern attached to each partition is not needed here; only the
  -- entries of the partition are processed further.
  let (st',res) = mapAccumL (\st (p,fns) -> computeConstrs pgf st fns)
                            st
                            (computeConstr fns)
  in (st',concat res)
  where
    computeConstr fns = merge (split fns (Map.empty,[]))

    -- Produces one (pattern, entries) partition per constructor pattern.
    -- The variable-headed entries are added to every one of them, extended
    -- with that pattern, and, if there are any, also form a final partition
    -- under the wildcard pattern.
    merge (cns,vrs) =
      [(p,fns++[(id,ps++[p],es) | (id,ps,es) <- vrs])
                                | (p,fns) <- concatMap addArgs (Map.toList cns)]
      ++
      if null vrs 
        then []
        else [(PWild,[(id,ps++[PWild],es) | (id,ps,es) <- vrs])]
      where
        -- The number of argument patterns is taken from the type of the
        -- constructor in the function table.
        addArgs (cn,fns) = addArg (length args) cn [] fns
          where
            Just (ty@(DTyp args _ es),_,_,_) = Map.lookup cn (funs (abstract pgf))

        -- Computes the argument patterns one at a time by partitioning the
        -- entries again; ps is collected in reverse order.
        addArg 0 cn ps fns = [(PApp cn (reverse ps),fns)]
        addArg n cn ps fns = concat [addArg (n-1) cn (arg:ps) fns' | (arg,fns') <- computeConstr fns]

    -- Distributes the entries by the head of their first expression:
    -- entries headed by a function are grouped under that function, with
    -- its arguments put in front of the remaining expressions; entries
    -- headed by a variable are collected separately.
    split []                   (cns,vrs) = (cns,vrs)
    split ((id, ps, e:es):fns) (cns,vrs) = split fns (extract e [])
      where
        extract (EFun cn)     args = (Map.insertWith (++) cn [(id,ps,args++es)] cns, vrs)
        extract (EVar i)      args = (cns, (id,ps,es):vrs)
        extract (EApp e1 e2)  args = extract e1 (e2:args)
        extract (ETyped e ty) args = extract e args
        extract (EImplArg e)  args = extract e args

-- | Identifier made of the prefix @p_@ and the text of an existing identifier.
p_f :: CId -> CId
p_f = mkCId . ("p_" ++) . showCId

-- | Identifier made of the prefix @p_@ and the shown form of a value
-- (used with numeric counters).
p_i :: Show a => a -> CId
p_i = mkCId . ("p_" ++) . show

-- | Identifier made of the prefix @k_@ and the text of an existing identifier.
k_f :: CId -> CId
k_f = mkCId . ("k_" ++) . showCId

-- | Identifier made of the prefix @k_@ and the shown form of a value
-- (used with numeric counters).
k_i :: Show a => a -> CId
k_i = mkCId . ("k_" ++) . show

-- Index
--
-- (HTML for this module was generated on 2015-03-03. About the conversion tool.)