# Probabilistic.hs

Plain text version of Probabilistic.hs

```module PGF.Probabilistic
( Probabilities(..)
, mkProbabilities                 -- :: PGF -> Map.Map CId Double -> Probabilities
, defaultProbabilities            -- :: PGF -> Probabilities
, getProbabilities                -- :: PGF -> Probabilities
, setProbabilities                -- :: Probabilities -> PGF -> PGF
, showProbabilities               -- :: Probabilities -> String
, readProbabilitiesFromFile       -- :: FilePath -> PGF -> IO Probabilities

, probTree                        -- :: PGF -> Expr -> Double
, rankTreesByProbs                -- :: PGF -> [Expr] -> [(Expr,Double)]
, mkProbDefs                      -- :: PGF -> ([[CId]],[(CId,Type,[Equation])])
) where

import PGF.CId
import PGF.Data
import PGF.Macros

import qualified Data.Map as Map
import Data.List (sortBy,partition,nub,mapAccumL)
import Data.Maybe (fromMaybe) --, fromJust

-- | An abstract data structure holding the probabilities assigned to
-- the functions and categories of a grammar.
data Probabilities = Probs
  { funProbs :: Map.Map CId Double
    -- ^ probability of each function, keyed by function name
  , catProbs :: Map.Map CId (Double, [(Double, CId)])
    -- ^ for each category: its own probability paired with the
    -- (probability, function) entries listed for that category
  }

-- | Renders the probability structure as a string.
-- Every category contributes one line for itself followed by one line
-- per function listed under it; each line is the name, a tab, and the
-- probability.
showProbabilities :: Probabilities -> String
showProbabilities probs =
  unlines [ line | catEntry <- Map.toList (catProbs probs), line <- render catEntry ]
  where
    -- the category's own line first, then its functions in stored order
    render (cat, (catProb, funEntries)) =
      entryLine (catProb, cat) : [ entryLine fe | fe <- funEntries ]

    entryLine (prob, ident) = showCId ident ++ "\t" ++ show prob

-- | Reads the probabilities from a file.
-- This should be a text file where on every line
-- there is a function name followed by a real number.
-- The number represents the probability mass allocated for that function.
-- The function name and the probability should be separated by whitespace.
--
-- Lines with fewer than two words are skipped, and anything after the
-- second word on a line is ignored. If a name occurs more than once the
-- last occurrence wins ('Map.fromList' semantics). The number is parsed
-- with 'read', so a malformed number raises an error when the
-- corresponding probability is first demanded.
--
-- The resulting map is completed into full tables by 'mkProbabilities'.
readProbabilitiesFromFile :: FilePath -> PGF -> IO Probabilities
readProbabilitiesFromFile file pgf = do
  s <- readFile file
  let ps0 = Map.fromList [(mkCId f,read p) | f:p:_ <- map words (lines s)]
  return $ mkProbabilities pgf ps0

-- | Builds probability tables. The second argument is a map
-- which contains the known probabilities. If some function is
-- not in the map then it gets assigned some probability based
-- on the even distribution of the unallocated probability mass
-- for the result category. A category that is not in the map
-- gets probability 0.
mkProbabilities :: PGF -> Map.Map CId Double -> Probabilities
mkProbabilities pgf probs = Probs funTable catTable
  where
    -- one entry per category of the abstract syntax
    catTable = Map.mapWithKey perCat (cats (abstract pgf))

    -- the function table is just the category table flattened
    funTable = Map.fromList
      [ (fn, pr) | (_, fnPrs) <- Map.elems catTable, (pr, fn) <- fnPrs ]

    perCat cat (_, fns, _) =
      ( Map.findWithDefault 0 cat probs
      , sortBy descending (distribute (map snd fns))
      )

    -- highest probability first
    descending (a, _) (b, _) = compare b a

    -- Functions with a known probability keep it; the remaining ones
    -- share whatever mass is left (never less than 0) in equal parts.
    distribute fns =
      [ (fromMaybe share (Map.lookup fn probs), fn) | fn <- fns ]
      where
        known   = [ pr | fn <- fns, Just pr <- [Map.lookup fn probs] ]
        missing = length fns - length known
        share
          | missing == 0 = 0
          | otherwise    = max 0 ((1 - sum known) / fromIntegral missing)

-- | Returns the default even distribution, i.e. the tables that
-- 'mkProbabilities' builds when no probability is known in advance.
defaultProbabilities :: PGF -> Probabilities
defaultProbabilities pgf = mkProbabilities pgf nothingKnown
  where
    nothingKnown = Map.empty

-- | Extracts the probabilities currently stored in the abstract syntax
-- of a grammar: the last component of every function entry and, for
-- every category entry, its probability together with its function list.
getProbabilities :: PGF -> Probabilities
getProbabilities pgf =
  Probs (Map.map funProb (funs abstr)) (Map.map catEntry (cats abstr))
  where
    abstr = abstract pgf

    funProb (_, _, _, p) = p
    catEntry (_, fns, p) = (p, fns)

-- | Stores the given probabilities in the abstract syntax of a grammar.
-- Only functions and categories that already exist in the grammar are
-- touched; entries of the grammar without a counterpart in the
-- 'Probabilities' value are left unchanged, and extra entries in the
-- 'Probabilities' value are ignored.
setProbabilities :: Probabilities -> PGF -> PGF
setProbabilities probs pgf =
  pgf { abstract = abstr { funs = newFuns, cats = newCats } }
  where
    abstr = abstract pgf

    newFuns = Map.mapWithKey (overlay (funProbs probs) putFunProb)  (funs abstr)
    newCats = Map.mapWithKey (overlay (catProbs probs) putCatProbs) (cats abstr)

    -- replace only the probability component of a function entry
    putFunProb (ty, x2, x3, _) p = (ty, x2, x3, p)

    -- replace the function list and the probability of a category entry
    putCatProbs (hypos, _, _) (p, fns) = (hypos, fns, p)

    -- keep the old value unless the update map has something for this key
    overlay updates combine key old =
      case Map.lookup key updates of
        Just new -> combine old new
        Nothing  -> old

-- | Computes the probability of a given tree as the product of the
-- probabilities of the functions occurring in application positions.
-- Functions unknown to the grammar, and every expression form other
-- than application and function constant, count as 1.
probTree :: PGF -> Expr -> Double
probTree pgf = go
  where
    funTable = funs (abstract pgf)

    go (EApp fun arg) = go fun * go arg
    go (EFun name)    = maybe 1 (\(_, _, _, p) -> p) (Map.lookup name funTable)
    go _              = 1

-- | Pairs every tree with its 'probTree' probability and ranks the
-- pairs from highest to lowest probability.
rankTreesByProbs :: PGF -> [Expr] -> [(Expr,Double)]
rankTreesByProbs pgf ts = sortBy byProbDesc scored
  where
    scored = map (\t -> (t, probTree pgf t)) ts

    -- arguments swapped in 'compare' to get descending order
    byProbDesc (_, p) (_, q) = compare q p

-- | Generates auxiliary definitions from the abstract syntax of a grammar.
-- For every category except 'cidString', 'cidInt' and 'cidFloat' it emits
-- a definition named @p_\<category\>@ whose type ends in 'cidFloat' and
-- which has one equation per function of that category, plus the
-- non-empty equation sets collected by 'computeConstrs'.
--
-- The result is a pair of
--
--   * the groups of names returned by the 'computeConstrs' calls, and
--   * the generated (name, type, equations) triples.
--
-- NOTE(review): presumably these definitions compute tree probabilities
-- inside the grammar itself -- confirm against the callers.
mkProbDefs :: PGF -> ([[CId]],[(CId,Type,[Equation])])
mkProbDefs pgf =
-- cs: one (category, renamed hypotheses, functions-with-types) triple per
-- category; the hypotheses get the fresh names v1, v2, ...
let cs = [(c,hyps,fns) | (c,(hyps0,fs,_)) <- Map.toList (cats (abstract pgf)),
not (elem c [cidString,cidInt,cidFloat]),
let hyps = zipWith (\(bt,_,ty) n -> (bt,mkCId ('v':show n),ty))
hyps0
[1..]
-- partial pattern: fails if a function listed under a category is
-- missing from the function table
fns  = [(f,ty) | (_,f) <- fs,
let Just (ty,_,_,_) = Map.lookup f (funs (abstract pgf))]
]
-- The accumulator threads a counter (ngen, starting at 1) and the list of
-- name groups collected so far; note that 'cs' bound inside the lambda
-- shadows the outer 'cs'.
((_,css),eqss) = mapAccumL (\(ngen,css) (c,hyps,fns) ->
let st0      = (1,Map.empty)
((_,eqs_map),cs) = computeConstrs pgf st0 [(fn,[],es) | (fn,(DTyp _ _ es)) <- fns]
(ngen', eqs) = mapAccumL (mkEquation eqs_map hyps) ngen fns
-- only entries with at least one equation are emitted; their equations
-- are reversed before being emitted
ceqs     = [(id,DTyp [] cidFloat [],reverse eqs) | (id,eqs) <- Map.toList eqs_map, not (null eqs)]
in ((ngen',cs:css),(p_f c, mkType c hyps, eqs):ceqs)) (1,[]) cs
in (reverse (concat css),concat eqss)
where
-- wrap an expression as an implicit argument unless the binding is explicit
mkEImplArg bt e
| bt == Explicit = e
| otherwise      = EImplArg e

-- same as mkEImplArg, but for patterns
mkPImplArg bt p
| bt == Explicit = p
| otherwise      = PImplArg p

-- type of the generated p_<category> definition: the category's hypotheses,
-- then one argument of the category itself applied to variables referring
-- to those hypotheses, with result category cidFloat
mkType c hyps =
DTyp (hyps++[mkHypo (DTyp [] c es)]) cidFloat []
where
is = reverse [0..length hyps-1]
es = [mkEImplArg bt (EVar i) | (i,(bt,_,_)) <- zip is hyps]

sig = (funs (abstract pgf), \_ -> Nothing)

-- Builds the equation for one function 'fn' and returns it together with
-- the updated counter. The left-hand side consists of tilde patterns for
-- the normalized expressions 'es' followed by 'fn' applied to the variable
-- patterns v1, v2, ...; the right-hand side is the product (mkMult) of the
-- factors fs1, fs2 and fs3.
mkEquation ceqs hyps ngen (fn,ty@(DTyp args _ es)) =
-- fs1: k_<fn> if p_<fn> has no entry in ceqs, nothing if its entry is
-- empty, p_<fn> otherwise -- in each case applied to the variables of vs1
let fs1         = case Map.lookup (p_f fn) ceqs of
Nothing              -> [mkApp (k_f fn) (map (\(i,_) -> EVar (k-i-1)) vs1)]
Just eqs | null eqs  -> []
| otherwise -> [mkApp (p_f fn) (map (\(i,_) -> EVar (k-i-1)) vs1)]
(ngen',fs2) = mapAccumL mkFactor2 ngen vs2
fs3         = map mkFactor3 vs3
eq = Equ (map mkTildeP xes++[PApp fn (zipWith mkArgP [1..] args)])
(mkMult (fs1++fs2++fs3))
in (ngen',eq)
where
xes = map (normalForm sig k env) es

-- tilde pattern for an expression, keeping an implicit-argument wrapper
-- on the outside
mkTildeP e =
case e of
EImplArg e -> PImplArg (PTilde e)
e          ->           PTilde e

mkArgP n (bt,_,_) = mkPImplArg bt (PVar (mkCId ('v':show n)))

-- product of the factors: literal 1 for none, the factor itself for one,
-- otherwise an application of the function named "mult"
mkMult []  = ELit (LFlt 1)
mkMult [e] = e
mkMult es  = mkApp (mkCId "mult") es

-- one factor per entry of vs2: p_<ngen> applied to the entry's source
-- variables; consumes one number from the counter
mkFactor2 ngen (src,dst) =
let vs = [EVar (k-i-1) | (i,ty) <- src]
in (ngen+1,mkApp (p_i ngen) vs)

-- one factor per entry of vs3: p_<category of the argument> applied to the
-- normalized arguments of its type and then to the variable itself
mkFactor3 (i,DTyp _ c es) =
let v = EVar (k-i-1)
in mkApp (p_f c) (map (normalForm sig k env) es++[v])

(k,env,vs1,vs2,vs3) = mkDeps ty

-- Dependency analysis over the arguments of a type. Returns the number of
-- arguments k, an environment of VGen values, and three groups of
-- (index, type) entries: vs1 (sources of dependencies whose targets include
-- k, i.e. the result expressions), vs2 (remaining dependency groups) and
-- vs3 (arguments on which nothing depends).
mkDeps (DTyp args _ es) =
let (k,env,dep1) = updateArgs 0 [] [] args
dep2         = foldl (update k env) dep1 es
(vs2,vs3)    = closure k dep2 [] []
vs1          = concat [src | (src,dst) <- dep2, elem k dst]
in (k,map (\k -> VGen k []) env,vs1,reverse vs2,vs3)
where
-- walks the arguments left to right, numbering them from 0; arguments bound
-- to the wildcard name are not added to the environment
updateArgs k env dep []                              = (k,env,dep)
updateArgs k env dep ((_,x,ty@(DTyp _ _ es)) : args) =
let dep1 = foldl (update k env) dep es ++ [([(k,ty)],[])]
env1 | x == wildCId =     env
| otherwise    = k : env
in updateArgs (k+1) env1 dep1 args

-- for every variable occurrence EVar i in e, adds k to the target list of
-- the dependency entry at position (env !! i)
-- NOTE(review): the case is not exhaustive (only EApp, EFun and EVar are
-- handled) and both '!!' and the splitAt pattern are partial
update k env dep e =
case e of
EApp e1 e2 -> update k env (update k env dep e1) e2
EFun _     -> dep
EVar i     -> let (dep1,(src,dst):dep2) = splitAt (env !! i) dep
in dep1++(src,k:dst):dep2

-- Sorts the dependency entries into vs2 and vs3: entries without targets
-- go to vs3; an entry sharing a target with later entries is merged with
-- them and re-examined; a non-overlapping entry is dropped if its only
-- target is k and is pushed onto vs2 otherwise.
closure k []               vs2 vs3 = (vs2,vs3)
closure k ((src,dst):deps) vs2 vs3
| null dst   = closure k deps vs2 (vs3++src)
| otherwise  =
let (deps1,deps2) = partition (\(src',dst') -> not (null [v1 | v1 <- dst, v2 <- dst', v1 == v2])) deps
deps3 = (src,dst):deps1
src2  = concatMap fst deps3
-- targets that are themselves sources of the merged group are removed
dst2  = [v | v <- concatMap snd deps3
, lookup v src2 == Nothing]
dep2  = (src2,dst2)
dst'  = nub dst
in if null deps1
then if dst' == [k]
then closure k deps2 vs2 vs3
else closure k deps2 ((src,dst') : vs2) vs3
else closure k (dep2 : deps2) vs2 vs3
{-
mkNewSig src =
DTyp (mkArgs 0 0 [] src) cidFloat []
where
mkArgs k l env []                      = []
mkArgs k l env ((i,DTyp _ c es) : src)
| i == k    = let ty = DTyp [] c (map (normalForm sig k env) es)
in (Explicit,wildCId,ty) : mkArgs (k+1) (l+1) (VGen l [] : env) src
| otherwise = mkArgs (k+1) l (VMeta 0 env [] : env) src
-}
-- | State threaded through 'computeConstrs': a counter used for generating
-- fresh @k_\<n\>@ names, and the equations collected so far, keyed by name.
type CState = (Int,Map.Map CId [Equation])

-- | Processes a list of (name, patterns built so far, expressions still to
-- be matched) items, collecting equations in the 'CState' and returning
-- groups of names.
--
-- NOTE(review): presumably the items are the functions of one category
-- together with the index expressions of their result types (this is how
-- 'mkProbDefs' calls it) -- confirm.
computeConstrs :: PGF -> CState -> [(CId,[Patt],[Expr])] -> (CState,[[CId]])
-- First clause: the first item has no expressions left.
computeConstrs pgf (ngen,eqs_map) fns@((id,pts,[]):rest)
-- A single item: register p_<id>, with no equation when there are no
-- patterns and with an equation yielding 1.0 otherwise; no names are returned.
| null rest =
let eqs_map' =
Map.insertWith (++) (p_f id)
(if null pts
then []
else [Equ pts (ELit (LFlt 1.0))])
eqs_map
in ((ngen,eqs_map'),[])
-- Several items: each one gets a name, and the names form one group.
| otherwise =
let (st,ks) = mapAccumL mk_k (ngen,eqs_map) fns

-- Without patterns the name is k_<id>; otherwise a fresh k_<ngen> is used
-- and an equation mapping the patterns to it is added under p_<id>.
-- NOTE(review): only matches items whose expression list is empty.
mk_k (ngen,eqs_map) (id,pts,[])
| null pts  = ((ngen,eqs_map),k_f id)
| otherwise = let eqs_map' =
Map.insertWith (++)
(p_f id)
[Equ pts (EFun (k_i ngen))]
eqs_map
in ((ngen+1,eqs_map'),k_i ngen)

in (st,[ks])
-- Second clause: group the items by their first remaining expression
-- (see computeConstr) and recurse into every group, threading the state.
computeConstrs pgf st fns =
let (st',res) = mapAccumL (\st (p,fns) -> computeConstrs pgf st fns)
st
(computeConstr fns)
in (st',concat res)
where
computeConstr fns = merge (split fns (Map.empty,[]))

-- Combines the constructor-headed groups with the variable-headed items:
-- the variable-headed items are appended to every constructor group and,
-- if there are any, also form a wildcard group of their own.
merge (cns,vrs) =
[(p,fns++[(id,ps++[p],es) | (id,ps,es) <- vrs])
| (p,fns) <- concatMap addArgs (Map.toList cns)]
++
if null vrs
then []
else [(PWild,[(id,ps++[PWild],es) | (id,ps,es) <- vrs])]
where
-- NOTE(review): the definition of addArgs (used in merge above and binding
-- the names cn and args used below) is not present in this text; the two
-- consecutive 'where' lines suggest that a line was lost here.
where
Just (DTyp args _ _es,_,_,_) = Map.lookup cn (funs (abstract pgf))

-- builds the constructor pattern for cn one argument at a time; the
-- argument patterns are accumulated in reverse
addArg 0 cn ps fns = [(PApp cn (reverse ps),fns)]
addArg n cn ps fns = concat [addArg (n-1) cn (arg:ps) fns' | (arg,fns') <- computeConstr fns]

-- Partitions the items by the head of their first expression: items headed
-- by a function constant are grouped under that constant (its arguments
-- are put in front of the remaining expressions), items headed by a
-- variable are collected separately.
-- NOTE(review): neither split nor extract is exhaustive (an item with an
-- empty expression list, or any other expression form, is not handled).
split []                   (cns,vrs) = (cns,vrs)
split ((id, ps, e:es):fns) (cns,vrs) = split fns (extract e [])
where
extract (EFun cn)     args = (Map.insertWith (++) cn [(id,ps,args++es)] cns, vrs)
extract (EVar i)      args = (cns, (id,ps,es):vrs)
extract (EApp e1 e2)  args = extract e1 (e2:args)
extract (ETyped e ty) args = extract e args
extract (EImplArg e)  args = extract e args

-- Names of the generated definitions: a "p_" or "k_" prefix followed
-- either by an existing identifier (p_f, k_f) or by the 'show'
-- rendering of a counter value (p_i, k_i).
p_f, k_f :: CId -> CId
p_f = mkCId . ("p_" ++) . showCId
k_f = mkCId . ("k_" ++) . showCId

p_i, k_i :: Show a => a -> CId
p_i = mkCId . ("p_" ++) . show
k_i = mkCId . ("k_" ++) . show
```