{- $Id: XSG.hs,v 1.12.2.18.2.33 2009/02/20 18:00:02 orlov Exp $ -}
{- vim: set syntax=haskell expandtab tabstop=4: -}

module XSG (
 Index, Var(..), Vars, Exp(..), Exps, Clash(..), Cond, Term(..),
 Terms, Branch, Func(..), Prog, Bind(..), Subst, State, IOExps,

 MyShow(myShow), ConVE(e2v, v2e), Apply((/.)), Composition(o),
 SetOp((+.), (-.), (*.)),
 GetVars(getVars, getWVars, getLVars, getXVars, getYVars),
 GetTopVars(getTopVars, getTopWVars, getTopLVars, getTopXVars,
 getTopYVars),

 intersperseShow, shiftStr, initIdx, newIdxs, fresh, mkSubst,
 unmixSubst, mkCond, mgu, findFunc, funcName, funcArity,
 funcCoarity
) where

import Prelude
import List (elemIndex, intersect, union, (\\))

------------------------------ DATA ------------------------------
-- | A hierarchical name: a list of Ints. 'newIdxs' extends an index by
-- consing a new component onto its head, 'renum' appends a suffix to it,
-- and 'myShow' prints its components in reverse order joined by "-".
type Index = [Int]

-- | Index naming a W variable.
type WName = Index

-- | Index naming an L variable.
type LName = Index

-- | Index naming an X variable.
type XName = Index

-- | Index naming a Y variable.
type YName = Index

-- | A variable of one of four sorts (W, L, X, Y), each named by an index.
-- The sort matters to 'mgu', which never binds Y variables and, between two
-- variables, binds a W before an L before an X; and to 'unmixSubst', which
-- only sorts W, L and X bindings.
-- NOTE(review): the intended meaning of each sort is not visible in this
-- module; confirm with the rest of the project.
data Var = W WName
         | L LName
         | X XName
         | Y YName
         deriving Eq

-- | A list of variables.
type Vars = [Var]

-- | A constructor name paired with an index. Two constructor occurrences
-- with equal CName (name and index) are treated as the same node. 'myShow',
-- 'getVars', 'mkCond' and 'mgu' rely on this to notice a constructor nested
-- inside itself, i.e. a cycle in a lazily built, possibly infinite term.
type CName = (String, Index)

-- | An expression: a variable, or a constructor applied to arguments.
-- Expressions may be cyclic; 'mkSubst' builds such terms lazily.
data Exp = VAR Var
         | C CName Exps

-- | A list of expressions.
type Exps = [Exp]

-- | An equation between two expressions, as solved by 'mgu'.
data Clash = Exp :=: Exp

-- | A conjunction of equations (printed joined by "&&").
type Cond = [Clash]

-- | A function name.
type FName = String

-- | A call @vs := (f, es)@ of function @f@ on arguments @es@.
-- NOTE(review): @vs@ presumably names the call's results; confirm.
data Term = Vars := (FName, Exps)

-- | A list of calls.
type Terms = [Term]

-- | One branch of a function body: its calls, its condition, and the
-- expressions it returns ('funcCoarity' reports how many there are).
type Branch = (Terms, Cond, Exps)

-- | A function body: its list of branches.
type Body = [Branch]

-- | A function definition: its name, its parameter variables ('funcArity'
-- counts them) and its body.
data Func = FUNC FName Vars Body

-- | A program: a list of function definitions, searched by 'findFunc'.
type Prog = [Func]

-- | A single binding of a variable to an expression.
data Bind = Var :-> Exp

-- | A substitution. Applying it with '/.' uses the first binding found for a
-- variable.
type Subst = [Bind]

-- | An alias for 'Branch'.
-- NOTE(review): its role is not visible in this module; confirm with callers.
type State = Branch

-- | A pair of expression lists.
-- NOTE(review): presumably (inputs, outputs), judging by the name; confirm.
type IOExps = (Exps, Exps)

------------------------------ SHOW ------------------------------
-- | Concatenates the strings, placing the separator between adjacent
-- elements. An empty input yields the empty string.
intersperseStr :: String -> [String] -> String
intersperseStr _   []           = ""
intersperseStr sep (first:rest) = first ++ concat [sep ++ w | w <- rest]

-- | Renders every element with 'myShow' and joins the results with the
-- given separator.
intersperseShow :: MyShow a => String -> [a] -> String
intersperseShow sep xs = intersperseStr sep (map myShow xs)

-- | Indents text: the prefix is put at the start of the text and after
-- every newline, except a final trailing newline. The empty string stays
-- empty.
--
-- Input need not end in a newline. The previous version had no equation
-- for the end of a non-newline-terminated string and crashed on it.
shiftStr :: String -> String -> String
shiftStr _      ""   = ""
shiftStr prefix text = prefix ++ go text
 where
  -- A trailing newline ends the text without starting a new indented line.
  go "\n"          = "\n"
  go ('\n' : rest) = '\n' : (prefix ++ go rest)
  go (c : rest)    = c : go rest
  go ""            = ""


-- | Pretty-printing class for this module's program representation.
-- Unlike 'Show', instances are given directly for concrete list and tuple
-- types (Index, Vars, Exps, Branch, ...), each with its own layout.
class MyShow a where
 myShow :: a -> String

-- | An Int prints exactly as 'show' renders it.
instance MyShow Int where
 myShow n = show n

-- | A String prints quoted and escaped, exactly as 'show' renders it.
instance MyShow String where
 myShow str = show str

-- | An index prints its components last-to-first, separated by "-".
instance MyShow Index where
 myShow = intersperseShow "-" . reverse

-- | A variable prints as its sort letter, a dot, then its index.
instance MyShow Var where
 myShow v = tag ++ "." ++ myShow n
  where (tag, n) = case v of
                     W i -> ("W", i)
                     L i -> ("L", i)
                     X i -> ("X", i)
                     Y i -> ("Y", i)

-- | A variable list prints comma-separated inside square brackets.
instance MyShow Vars where
 myShow vs = '[' : intersperseShow ", " vs ++ "]"

-- | A constructor name prints as its string part only; the index is omitted.
instance MyShow CName where
 myShow = fst

-- | Prints an expression. A variable prints via its own 'myShow'. A
-- constructor prints its name, followed by its bracketed arguments when it
-- has any. If a constructor with arguments occurs inside an ancestor with
-- the same 'CName', it prints as "name!i" instead, where i is the position
-- of that ancestor counting outward from the immediate parent (0). This is
-- how cyclic terms print finitely.
instance MyShow Exp where
 myShow = render []
  where
   render _     (VAR v)   = myShow v
   render _     (C cN []) = myShow cN
   render stack (C cN es) = myShow cN ++ rest
    where
     rest   = maybe opened (\i -> "!" ++ myShow i) (elemIndex cN stack)
     opened = " [" ++ intersperseStr ", " (map (render (cN : stack)) es) ++ "]"

-- | An expression list prints comma-separated inside square brackets.
instance MyShow Exps where
 myShow es = concat ["[", intersperseShow ", " es, "]"]

-- | An equation prints as "lhs :=: rhs".
instance MyShow Clash where
 myShow (lhs :=: rhs) = concat [myShow lhs, " :=: ", myShow rhs]

-- | A condition prints its equations joined by "  &&  ".
instance MyShow Cond where
 myShow = intersperseShow "  &&  "

-- | A call prints as "vars := fname args", separated by single spaces.
instance MyShow Term where
 myShow (vs := (fN, es)) = unwords [myShow vs, ":=", fN, myShow es]

-- | Calls print one per line, each terminated by a newline.
instance MyShow Terms where
 myShow = concatMap ((++ "\n") . myShow)

-- | A branch prints its result expressions on one line, then its condition
-- (indented, only if non-empty), then its calls indented by two spaces.
instance MyShow Branch where
 myShow (ts, cd, es) =
   concat [myShow es, "\n", condLine, shiftStr "  " (myShow ts)]
  where condLine | null cd   = ""
                 | otherwise = "  " ++ myShow cd ++ "\n"

-- | A body prints its branches one after another.
instance MyShow Body where
 myShow = concatMap myShow

-- | A function prints a "FUNC name params" header, then its indented body,
-- then a blank line.
instance MyShow Func where
 myShow (FUNC fN vs body) =
   concat ["FUNC ", myShow fN, " ", myShow vs, "\n",
           shiftStr "  " (myShow body), "\n"]

-- | A program prints its functions one after another.
instance MyShow Prog where
 myShow funcs = concat (map myShow funcs)

-- | A binding prints as "var :-> expr".
instance MyShow Bind where
 myShow (v :-> e) = concat [myShow v, " :-> ", myShow e]

-- | A substitution prints its bindings comma-separated inside brackets.
instance MyShow Subst where
 myShow s = '[' : intersperseShow ", " s ++ "]"

----------------------------- INDEX ------------------------------
-- | The empty index; 'newIdxs' derives child indices from it.
initIdx :: Index
initIdx = []

-- | The infinite list of children of an index: @0:idx@, @1:idx@, ...
newIdxs :: Index -> [Index]
newIdxs idx = [i : idx | i <- [0 ..]]

-- | @fresh n mk idx@ makes @n@ variables of the sort chosen by @mk@, named
-- by the first @n@ children of @idx@.
fresh :: Int -> (Index -> Var) -> Index -> Vars
fresh n mk idx = take n [mk child | child <- newIdxs idx]

----------------------------- Renum ------------------------------
-- | Renaming by index suffix: @renum x idx@ appends @idx@ to the index of
-- every variable and every constructor name inside @x@. Function names are
-- left unchanged.
class Renum a where
 renum :: a -> Index -> a

instance Renum Var where
 renum v idx = case v of
   W n -> W (n ++ idx)
   L n -> L (n ++ idx)
   X n -> X (n ++ idx)
   Y n -> Y (n ++ idx)

instance Renum Exp where
 renum e idx = case e of
   VAR v            -> VAR (renum v idx)
   C (name, i) args -> C (name, i ++ idx) (renum args idx)

instance Renum Clash where
 renum (lhs :=: rhs) idx = renum lhs idx :=: renum rhs idx

instance Renum Term where
 renum (vs := (fN, es)) idx = renum vs idx := (fN, renum es idx)

instance Renum Func where
 renum (FUNC fN vs body) idx = FUNC fN (renum vs idx) (renum body idx)

-- | Lists are renamed element-wise.
instance Renum a => Renum [a] where
 renum xs idx = [renum x idx | x <- xs]

-- | Triples (such as 'Branch') are renamed component-wise.
instance (Renum a, Renum b, Renum c) => Renum (a, b, c) where
 renum (p, q, r) idx = (renum p idx, renum q idx, renum r idx)

----------------------------- ConVE ------------------------------
-- | Conversion between an "expression" form and a "variable" form of
-- related types. 'e2v' extracts the variable form; 'v2e' embeds the
-- variable form back into the expression form.
class ConVE a b where
 e2v :: a -> b
 v2e :: b -> a

-- | Only a 'VAR' expression has a variable form. Applying 'e2v' to a
-- constructor application is an error; the message names the constructor
-- instead of giving an anonymous pattern-match failure.
instance ConVE Exp Var where
 e2v (VAR v)       = v
 e2v (C (cN, _) _) = error ("e2v: expected a variable, got constructor "
                            ++ show cN)
 v2e v             = VAR v

-- | An equation whose left side is a variable, viewed as a binding, and back.
instance ConVE Clash Bind where
 e2v (v:=:e) = (e2v v):->e
 v2e (v:->e) = (v2e v):=:e

-- | Element-wise conversion of lists.
instance ConVE a b => ConVE [a] [b] where
 e2v xs = map e2v xs
 v2e xs = map v2e xs

----------------------------- APPLY ------------------------------
infixl 7 /.

-- | Application of a substitution: @x /. s@ replaces each variable in @x@
-- that has a binding in @s@ by its bound expression. Only the first binding
-- for a variable is used, and the inserted expression is not substituted
-- again.
class Apply a b where
 (/.) :: a -> b -> a

-- | A variable maps to the variable it is bound to. A variable bound to a
-- constructor expression has no variable form, so 'e2v' fails on it.
instance Apply Var Subst where
 v /. s = case [e | (w :-> e) <- s, w == v] of
            (e:_) -> e2v e
            []    -> v

instance Apply Exp Subst where
 (VAR v)     /. s = case [e | (w :-> e) <- s, w == v] of
                      (e:_) -> e
                      []    -> VAR v
 (C cN args) /. s = C cN (args /. s)

instance Apply Clash Subst where
 (lhs :=: rhs) /. s = (lhs /. s) :=: (rhs /. s)

instance Apply Term Subst where
 (vs := (fN, es)) /. s = (vs /. s) := (fN, es /. s)

instance Apply Func Subst where
 (FUNC fN vs body) /. s = FUNC fN (vs /. s) (body /. s)

-- | Only the right-hand side of a binding is substituted.
instance Apply Bind Subst where
 (v :-> e) /. s = v :-> (e /. s)

instance Apply a b => Apply [a] b where
 xs /. s = [x /. s | x <- xs]

instance (Apply a c, Apply b c) => Apply (a, b) c where
 (p, q) /. s = (p /. s, q /. s)

instance (Apply a d, Apply b d, Apply c d) => Apply (a, b, c) d where
 (p, q, r) /. s = (p /. s, q /. s, r /. s)

-------------------------- COMPOSITION ---------------------------
infixl 9 `o`

-- | Composition of values of the same type.
class Composition a where
 o :: a -> a -> a

-- | @s1 `o` s2@ applies @s2@ to the right-hand sides of @s1@'s bindings and
-- then appends @s2@'s bindings. Substitution lookup takes the first
-- matching binding, so @s1@'s bindings take precedence over @s2@'s.
instance Composition Subst where
 o first second = map (/. second) first ++ second

----------------------------- SET OP -----------------------------
infixl 6 +.
infixl 7 *.

-- | Set operations on lists: union (+.), difference (-.) and
-- intersection (*.).
-- NOTE(review): (-.) has no fixity declaration, so it defaults to infixl 9
-- and binds tighter than (+.) and (*.). Parenthesise mixed expressions.
class SetOp a where
 (+.)  :: a -> a -> a
 (-.)  :: a -> a -> a
 (*.)  :: a -> a -> a

-- | Lists as sets, using the Data.List set functions (equality via 'Eq').
instance Eq a => SetOp [a] where
 (+.) = union
 (-.) = (\\)
 (*.) = intersect

---------------------------- GET VARS ----------------------------
-- | Collects the variables occurring anywhere in a value. Partial results
-- are merged with (+.) (list union). Cyclic expressions are handled by not
-- descending into a constructor (same name and index) that is already
-- being visited. The getWVars/getLVars/getXVars/getYVars defaults keep only
-- the variables of one sort.
class GetVars a where
 getVars :: a -> Vars
 getWVars :: a -> Vars
 getLVars :: a -> Vars
 getXVars :: a -> Vars
 getYVars :: a -> Vars
 getWVars val = [v | v@(W _) <- getVars val]
 getLVars val = [v | v@(L _) <- getVars val]
 getXVars val = [v | v@(X _) <- getVars val]
 getYVars val = [v | v@(Y _) <- getVars val]

instance GetVars Var where
 getVars v = v : []

instance GetVars Exp where
 getVars = walk []
  where
   walk _        (VAR v)   = [v]
   walk visiting (C cN es)
     | cN `elem` visiting = []
     | otherwise          = foldr ((+.) . walk (cN : visiting)) [] es

instance GetVars Clash where
 getVars (lhs :=: rhs) = getVars lhs +. getVars rhs

instance GetVars Bind where
 getVars (v :-> e) = getVars v +. getVars e

instance GetVars Term where
 getVars (vs := (_, es)) = getVars vs +. getVars es

instance GetVars Func where
 getVars (FUNC _ vs body) = getVars vs +. getVars body

instance GetVars a => GetVars [a] where
 getVars = foldr ((+.) . getVars) []

instance (GetVars a, GetVars b) => GetVars (a, b) where
 getVars (p, q) = getVars p +. getVars q

instance (GetVars a, GetVars b, GetVars c) => GetVars (a, b, c) where
 getVars (p, q, r) = getVars p +. getVars q +. getVars r

-- | Like 'GetVars', but an expression contributes a variable only if it is
-- itself a 'VAR'. Variables nested under constructors are ignored. Results
-- are merged with (+.) (list union). The per-sort defaults keep only the
-- variables of one sort.
class GetTopVars a where
 getTopVars :: a -> Vars
 getTopWVars :: a -> Vars
 getTopLVars :: a -> Vars
 getTopXVars :: a -> Vars
 getTopYVars :: a -> Vars
 getTopWVars val = [v | v@(W _) <- getTopVars val]
 getTopLVars val = [v | v@(L _) <- getTopVars val]
 getTopXVars val = [v | v@(X _) <- getTopVars val]
 getTopYVars val = [v | v@(Y _) <- getTopVars val]

instance GetTopVars Var where
 getTopVars v = v : []

instance GetTopVars Exp where
 getTopVars e = case e of
   VAR v -> [v]
   C _ _ -> []

instance GetTopVars Clash where
 getTopVars (lhs :=: rhs) = getTopVars lhs +. getTopVars rhs

instance GetTopVars a => GetTopVars [a] where
 getTopVars = foldr ((+.) . getTopVars) []

instance (GetTopVars a, GetTopVars b) => GetTopVars (a, b) where
 getTopVars (p, q) = getTopVars p +. getTopVars q

----------------------------- SUBST ------------------------------
-- | A one-binding substitution built with 'mkSubst'. If @e@ mentions @v@,
-- the bound expression becomes cyclic.
mkBind :: Var -> Exp -> Subst
mkBind var expr = mkSubst [var] [expr]

-- | Pairs each variable with its expression (zip-style, so the longer list
-- is truncated). Any occurrence, inside the expressions, of a variable that
-- is being bound is replaced by that variable's own (recursively resolved)
-- expression; the first pair wins for a repeated variable. Resolution is
-- lazy, so a cycle through a constructor yields a cyclic expression.
-- A variable bound directly to itself (or through a chain of bare
-- variables back to itself) does not terminate.
mkSubst :: Vars -> Exps -> Subst
mkSubst vars exps = zipWith (:->) vars resolved
 where
  resolved = map resolve exps
  table    = zip vars exps
  resolve (VAR v)   = maybe (VAR v) resolve (lookup v table)
  resolve (C cN es) = C cN (map resolve es)

-- | Splits a substitution by the sort of each bound variable into
-- (W bindings, L bindings, X bindings). Order is preserved within each part.
--
-- There is no slot for Y variables. 'mgu' never binds them, but 'mkSubst'
-- can. A Y binding is therefore reported with an explicit error instead of
-- an anonymous non-exhaustive-pattern failure.
unmixSubst :: Subst -> (Subst, Subst, Subst)
unmixSubst = foldr place ([], [], [])
 where
  place b@(v :-> _) (sW, sL, sX) = case v of
    W _ -> (b : sW,     sL,     sX)
    L _ -> (    sW, b : sL,     sX)
    X _ -> (    sW,     sL, b : sX)
    Y _ -> error ("unmixSubst: unexpected binding of Y variable "
                  ++ myShow v)

-- | Turns possibly cyclic expressions into finite ones plus equations.
--
-- Sometimes a constructor occurrence @C cN args@ is found nested inside an
-- occurrence with the same 'CName', i.e. a cycle in a lazily built term.
-- Then every occurrence of @cN@ is replaced by the variable @X (snd cN)@,
-- and the equation @VAR (X (snd cN)) :=: C cN args'@ is recorded. Here
-- @args'@ are the inner occurrence's arguments with the same replacement
-- applied. This repeats until no such nesting remains.
--
-- Returns the recorded equations (the most recently cut cycle first) and
-- the rewritten input expressions.
--
-- NOTE(review): the new X variable is named by the constructor's index, so
-- it is assumed not to clash with other X variables; confirm with callers.
mkCond :: Exps -> (Cond, Exps)
mkCond es = (sg (init ess), last ess)
 where find cNs (VAR _  ) = []
       -- cNs holds the constructors on the path from the root. Meeting
       -- one again returns it with its arguments, and the search stops
       -- going deeper along that path.
       find cNs (C cN es) | cN `elem` cNs = [(cN, es)]
       find cNs (C cN es) = concatMap (find (cN:cNs)) es
       -- All repeats found in every expression of every list, each
       -- expression searched from the top with an empty path.
       find' = concatMap $ concatMap $ find []
       -- Replace every occurrence of constructor cN (with its whole
       -- subtree) by the variable v.
       replace cN v e@(VAR _ ) = e
       replace cN v (C cN' _ ) | cN==cN' = VAR v
       replace cN v (C cN' es) = C cN' (map (replace cN v) es)
       replace' cN v = map $ map $ replace cN v
       -- Cut one cycle per step. The new pair [VAR v], [C cN es'] is put at
       -- the front; the original expression list always stays last.
       sf ess =
          case find' ess of
           []           -> ess
           ((cN, es):_) -> let v = X (snd cN)
                               (es':ess') = replace' cN v (es:ess)
                           in sf ([VAR v]:[C cN es']:ess')
       ess = sf [es]
       -- Pair up the singleton lists added by sf into equations. init ess
       -- always has this even-length, singleton-list shape.
       sg [] = []
       sg ([x1]:[x2]:xs) = (x1 :=: x2):(sg xs)

------------------------------ MGU -------------------------------
-- | The most general unifier of a list of equations, or Nothing if they
-- cannot be unified.
--
-- * Constructors are compared by name. For constructors with equal names,
--   the arguments are unified pairwise (via zipWith, so extra arguments on
--   either side are ignored).
-- * Two constructors with equal 'CName' index are treated as identical
--   without comparing their arguments.
-- * To terminate on cyclic terms, each pair of constructor names met so far
--   is remembered in @ps@ and not expanded again.
-- * There is no occurs check. Binding a variable to an expression that
--   contains it yields a cyclic expression (via 'mkBind').
-- * Which side of a variable equation gets bound is decided by the local
--   ordering @<=@ below. Y variables are never bound. Between two
--   variables, a W is bound before an L, and an L before an X.
mgu :: Cond -> Maybe Subst
mgu = mgu' []
 where
  -- ps: pairs of constructor names already being unified.
  mgu' _  []       = Just []
  mgu' ps (eq:eqs) =
   case eq of
    VAR v1    :=: VAR v2  | v1==v2 -> mgu' ps eqs

    -- A Y variable may only appear on the bound (right) side, and only
    -- opposite a W or an L variable; anything else fails.
    VAR (v@(Y _)) :=: e | v<=e -> Nothing -- Y variables are never bound
    e :=: VAR (v@(Y _)) | v<=e -> Nothing -- Y variables are never bound

    -- Bind the left variable if the ordering allows it, otherwise the
    -- right one.
    VAR v1    :=: e2      | v1<=e2 -> mgu'' (mkBind v1 e2)
    e1        :=: VAR v2           -> mgu'' (mkBind v2 e1)
    C cN1 _   :=: C cN2 _ | (fst cN1)/=(fst cN2) -> Nothing
    -- Same constructor name: add argument equations, unless the two nodes
    -- share an index or this pair of nodes is already being unified.
    C cN1 es1 :=: C cN2 es2 -> mgu' ((cN1, cN2):ps) $
                                if ((snd cN1)==(snd cN2) ||
                                    (cN1, cN2) `elem` ps ||
                                    (cN2, cN1) `elem` ps)
                                then eqs
                                else eqs++(zipWith (:=:) es1 es2)
   -- Bind via s, apply s to the remaining equations, and compose s with
   -- their unifier.
   where mgu'' s = fmap (s `o`) (mgu' ps (eqs/.s))
         -- Local ordering (shadows Prelude's <=): v <= e means "v may be
         -- bound to e". For Y, True here leads to failure above.
         (Y _) <= (VAR (Y _)) = True
         (Y _) <= (VAR (X _)) = True
         (Y _) <= (VAR _    ) = False
         (X _) <= (VAR (W _)) = False
         (X _) <= (VAR (L _)) = False
         (L _) <= (VAR (W _)) = False
         _     <= _           = True

------------------------------ FUNC ------------------------------
-- | Looks up a function by name in a program and renames it with the given
-- index suffix (see 'renum'). If several functions share the name, the
-- first one wins. A missing function is reported with an explicit error
-- naming it, instead of an anonymous failure of 'head' on an empty list.
-- The lookup still happens lazily, only when the result is demanded.
findFunc :: FName -> Prog -> Index -> Func
findFunc fN prog idx = renum found idx
 where found = case [f | f@(FUNC fN' _ _) <- prog, fN' == fN] of
                 (f:_) -> f
                 []    -> error ("findFunc: no function named " ++ show fN)

-- | The name of a function definition.
funcName :: Func -> FName
funcName f = case f of FUNC name _ _ -> name

-- | The number of parameter variables of a function definition.
funcArity :: Func -> Int
funcArity (FUNC _ params _) = length params

-- | The number of result expressions of a function, taken from its first
-- branch.
-- NOTE(review): other branches are assumed to return the same number; this
-- is not checked here.
-- A function with an empty body has no branch to inspect. It is reported
-- with an explicit error naming the function, instead of an anonymous
-- pattern-match failure.
funcCoarity :: Func -> Int
funcCoarity (FUNC _ _ ((_, _, es):_)) = length es
funcCoarity (FUNC fN _ [])            =
  error ("funcCoarity: function " ++ show fN ++ " has no branches")