{- $Id: Parser.hs,v 1.1.2.33 2009/02/20 18:00:03 orlov Exp $ -}
{- vim: set syntax=haskell expandtab tabstop=4: -}

module Parser {-(
 eval, compile, uncompile
)-} where

import Prelude
import System.IO.Unsafe (unsafePerformIO)
import List (elemIndices, mapAccumL)
import Monad (ap, foldM, liftM, when, zipWithM_)
import UU_Offside
import UU_Parsing hiding (Exp, Symbol, Parser, parse)
import UU_Scanner
import StateParser
import qualified XSG
import qualified URA

---------------------------------------------------------------------------------------------------
-- | Evaluate the query @str@ against the program compiled from @file@ and
-- render every result back into surface syntax with 'uncompile'.
eval :: FileName -> String -> [Branch]
eval file = map uncompile . eval' file

-- | Compile the source file at the given path into an XSG program; the
-- symbol table built along the way is discarded.
compile :: FileName -> XSG.Prog
compile = snd . compileFile

-- | Turn a list of XSG result expressions back into a surface 'Branch'.
--
-- 'XSG.mkCond' yields a list of equality constraints and the remaining
-- expressions. Each constraint @a :=: b@ contributes two single-expression
-- groups @[a]@ and @[b]@; the remaining expressions form the last group.
-- The groups are tagged alternately 'ActAssign' / 'ActCompare'.
--
-- Variables are named after their kind ('uncompileVar') and their position
-- in the list of variables seen so far; the first occurrence is printed as
-- a fresh binder ('ExpNew'), later ones as plain identifiers.
uncompile :: XSG.Exps -> Branch
uncompile es = zipWith ($) tags printed
  where
    tags = cycle [ActAssign, ActCompare]
    (conds, rest) = XSG.mkCond es
    groups = concatMap (\(lhs XSG.:=: rhs) -> [[lhs], [rhs]]) conds ++ [rest]
    (_, printed) = mapAccumL (mapAccumL printExp) [] groups

    printExp seen (XSG.VAR v) =
        case elemIndices v (seen ++ [v]) of
            [n]    -> (seen ++ [v], ExpNew vt (label n, noPos))
            [n, _] -> (seen, ExpId (label n, noPos))
      where
        vt = uncompileVar v
        label n = XSG.myShow vt ++ XSG.myShow n
    printExp seen (XSG.C (cN, _) []) = (seen, ExpId (cN, noPos))
    printExp seen (XSG.C (cN, _) args) = (seen', ExpConcat (ExpId (cN, noPos) : args'))
      where
        (seen', args') = mapAccumL printExp seen args

-- | Surface-syntax kind of an XSG variable. XSG local variables ('XSG.L',
-- which 'compileSymbol' creates for function results) have no marker of
-- their own in 'VarType' and are rendered as W variables.
uncompileVar :: XSG.Var -> VarType
uncompileVar (XSG.X _) = X
uncompileVar (XSG.Y _) = Y
uncompileVar (XSG.W _) = W
uncompileVar (XSG.L _) = W

-- | Convert a compiled XSG branch back into surface syntax, given the
-- argument variables of the function it belongs to.
--
-- The resulting 'Branch' is an 'ActCompare' listing the argument variables,
-- followed by acts tagged alternately 'ActAssign' / 'ActCompare' (starting
-- with 'ActAssign') over these groups, in order:
--
--   * for each call term @vs := (f, args)@ in @ts@: the result variables
--     @vs@, then the call @(f args)@ as a single expression;
--   * for each equality constraint: its two sides, one group each;
--   * the branch's result expressions.
--
-- Before printing, all call arguments, constraint sides and result
-- expressions are passed together through 'XSG.mkCond' (which returns a
-- list of ':=:' constraints and an expression list), and the returned
-- expression list is cut back into the original groups by length.
--
-- Variables are printed as @x@ followed by their position in the list of
-- variables seen so far. Numbering follows the printing traversal:
-- arguments, call expressions, call result variables, then the
-- constraint/result groups. A variable's first occurrence is printed as a
-- fresh binder ('ExpNew').
uncompileBranch :: (XSG.Vars, XSG.Branch) -> Branch
uncompileBranch (args,(ts, cd, es)) = (ActCompare qs0) : zipWith ($) (cycle [ActAssign, ActCompare]) (qs4++qs3)
 where ess1 = map (\(_ XSG.:= (_, es)) -> es) ts
       -- ess1: argument list of every call term; ls1: their lengths.
       -- es2 / l2: both sides of every constraint, flattened, and their count.
       ls1  = map length ess1
       es2  = concatMap (\(e1 XSG.:=: e2) -> [e1, e2]) cd
       l2   = length es2
       -- All expressions in one list: call arguments, constraint sides, results.
       esA  = (concat ess1)++es2++es
       (cdF, esF) = XSG.mkCond esA
       -- Cut esF back into per-call argument lists (ess1F), then the
       -- constraint sides (es2F) and the result expressions (esF2).
       -- NOTE(review): assumes XSG.mkCond returns the expressions in the
       -- same order and with at least as many elements as it received.
       (esF1, ess1F) = (mapAccumL $ flip (\t -> (\(x,y)->(y,x)) . splitAt t)) esF ls1
       (es2F, esF2) = splitAt l2 esF1
       -- Constraints from mkCond plus the original constraints, rebuilt
       -- pairwise from their rewritten sides.
       cdF2 = cdF++(map (\[x,y] -> (x XSG.:=: y)) $ snd $ (mapAccumL $ flip (\t -> (\(x,y)->(y,x)) . splitAt t)) es2F (take (l2 `div` 2) $ repeat 2))
       -- Groups for the trailing acts: each constraint side alone, then the results.
       essG = (concatMap (\(e1 XSG.:=: e2) -> [[e1], [e2]]) cdF2)++[esF2]

       -- Print everything, threading the list of already-seen variables:
       -- arguments (qs0), calls (qs1), call result variables (qs2),
       -- constraint/result groups (qs3).
       vs0 = []
       (vs1, qs0) = mapAccumL printExp vs0 $ map XSG.v2e args
       (vs2, qs1) = mapAccumL (\vs (_ XSG.:= (fn, _), es) -> printFun vs (fn,es)) vs1 (zip ts ess1F)
       (vs3, qs2) = mapAccumL (mapAccumL printVar) vs2 (map (\(vs XSG.:= _) -> vs) ts)
       (vs4, qs3) = mapAccumL (mapAccumL printExp) vs3 essG
       -- For each call: its result-variable group, then the call as a one-element group.
       qs4 = concat $ zipWith (\x y -> [x, [y]]) qs2 qs1

       printExp vs (XSG.VAR v) = printVar vs v
       printExp vs (XSG.C (cN, _) []) = (vs , ExpId (cN, noPos))
       printExp vs (XSG.C (cN, _) es) = (vs', ExpConcat $ (ExpId (cN, noPos)):es')
        where (vs', es') = mapAccumL printExp vs es
       -- First occurrence: append to the seen list and print as a binder.
       printVar vs v = case elemIndices v (vs++[v]) of
                         [n]    -> (vs++[v], ExpNew (uncompileVar v) ('x':XSG.myShow n, noPos))
                         [n, _] -> (vs     , ExpId ('x':XSG.myShow n, noPos) )
       -- A call is printed as a group: the function name followed by its arguments.
       printFun vs (fN, es) = (vs', ExpConcat $ (ExpId (fN, noPos)):es')
        where (vs', es') = mapAccumL printExp vs es
       

---------------------------------------------------------------------------------------------------
-- | Compile @file@ and parse @str@ as a branch. A bare leading expression
-- list defaults to an assignment ('pBranchA'). The branch is wrapped in a
-- synthetic zero-argument function @$EVAL$@, which is placed first in the
-- program handed to 'URA.int'.
eval' :: FileName -> String -> [XSG.Exps]
eval' file str = URA.int (entry : funcs) [] XSG.initIdx
 where (symbols, funcs) = compileFile file
       query = parse' pBranchA str
       entry = XSG.FUNC "$EVAL$" [] [compileBranch symbols XSG.initIdx [] query]

-- | Read, parse and compile the source file at the given path.
--
-- NOTE: the file is read through 'unsafePerformIO' and lazy 'readFile'.
-- Reading happens when the result is first demanded, so I/O errors such as
-- a missing file surface at that point rather than at the call site.
compileFile :: FileName -> (SymbolTable, XSG.Prog)
compileFile file = compileModule (parse source)
 where source = unsafePerformIO (readFile file)

-- | Parse a whole input string with the canonical 'Parsible' parser of the
-- requested result type.
parse :: Parsible a => String -> a
parse input = parse' parser input

-- | Run an offside parser over a string. Carriage returns are normalised
-- with 'convert', the text is tokenised with 'scanXSG', and the tokens are
-- fed to UU_Parsing's 'parseIO' with no enclosing offside context. The
-- resulting 'IO' action is escaped with 'unsafePerformIO'.
parse' prsr input = unsafePerformIO (parseIO (unwrap prsr) offsideInput) where
    offsideInput = Input (scanXSG (convert input), NoContext)
    unwrap (OP inner) = inner

-- | Tokenise XSG source text with UU_Scanner. The syntax has no keywords
-- and no operator characters. Every character in @specchars@ is scanned as
-- a special-symbol token, which is what 'pSpec' matches. Scanning starts at
-- 'initPos'; the string "online.xsg" passed to 'scan' is presumably the
-- file name used for token positions (confirm against UU_Scanner).
scanXSG input = scan keywordstxt keywordsops specchars opchars "online.xsg" initPos $ convert input where
    keywordstxt = []
    keywordsops = []
    specchars   = "{;}()!?@=,:"
    opchars     = ""

---------------------------------------- Parsing datatypes ----------------------------------------
-- | Path of a source file (see 'compileFile').
type FileName = String

-- | Identifier text as produced by the scanner.
type Id = String

-- | Constructor type: its arity, i.e. the number of arguments it takes.
type CtorType = Int

-- | Function type: (arity, coarity), i.e. the number of arguments and the
-- number of results.
type FuncType = (Int, Int)

-- | Top-level declaration of one or more constructors or functions that
-- share the same type; every declared name keeps its source position.
data Decl = CtorDecl [(Id, Pos)] CtorType
          | FuncDecl [(Id, Pos)] FuncType 

-- | Kind of a freshly bound variable. In source text the kinds are selected
-- by the marker characters symNewX, symNewY and symNewW.
data VarType = X | Y | W

-- | Surface expression: an identifier reference, a fresh-variable binder,
-- or a parenthesised group of expressions.
data Exp = ExpId (Id, Pos)
         | ExpNew VarType (Id, Pos) -- @x !y ?w
         | ExpConcat Exps

type Exps = [Exp]

-- | One step of a branch: an assignment (introduced by symAssign) or a
-- comparison (introduced by symCompare).
data Act = ActAssign Exps
         | ActCompare Exps

-- | A branch is a sequence of acts (the parsers only build non-empty ones).
type Branch = [Act]

-- | A function definition: its name and its branches.
data Func = Func (Id, Pos) [Branch]

-- | A program: top-level declarations ('Left') and function definitions
-- ('Right'), in source order.
type Prog = [Either Decl Func]

-- | @sep c x y@ joins @x@ and @y@ with the separator @c@ placed between them.
sep :: [a] -> [a] -> [a] -> [a]
sep c x y = concat [x, c, y]

-- | Variable kinds are rendered as their constructor letter.
instance XSG.MyShow VarType where
    myShow vt = case vt of
        X -> "X"
        Y -> "Y"
        W -> "W"

-- | Positions are rendered as a Haskell pair of their two components.
instance XSG.MyShow Pos where
    myShow (Pos a b) = show (a, b)

-- | Declarations: comma-separated names, the type marker, then the arity
-- (and, for functions, symAssign followed by the coarity).
instance XSG.MyShow Decl where
    myShow decl = case decl of
        CtorDecl ids ar       -> header ids ++ XSG.myShow ar ++ "\n"
        FuncDecl ids (ar, co) -> header ids ++ XSG.myShow ar ++ [symAssign] ++ " " ++ XSG.myShow co ++ "\n"
      where
        header ids = XSG.intersperseShow "," (map fst ids) ++ [symType] ++ " "

-- | Expressions in concrete syntax; groups are parenthesised.
instance XSG.MyShow Exp where
    myShow e = case e of
        ExpId (name, _)     -> name
        ExpConcat parts     -> "(" ++ XSG.myShow parts ++ ")"
        ExpNew vt (name, _) -> newMarker vt : name
      where
        newMarker X = symNewX
        newMarker Y = symNewY
        newMarker W = symNewW

-- | Expression lists are space-separated; the empty list is shown as "()".
instance XSG.MyShow Exps where
    myShow es
        | null es   = "()"
        | otherwise = foldr1 (sep " ") (map XSG.myShow es)

-- | Acts are prefixed by their marker character.
instance XSG.MyShow Act where
    myShow act = case act of
        ActAssign es  -> symAssign : ' ' : XSG.myShow es
        ActCompare es -> symCompare : XSG.myShow es

-- | A branch is its acts written one after another.
instance XSG.MyShow Branch where
    myShow = concatMap XSG.myShow

-- | A list of branches: one branch per line.
instance XSG.MyShow [Branch] where
    myShow = concatMap ((++ "\n") . XSG.myShow)

-- | A function: its name, then its branches on indented lines, terminated
-- by symLast.
instance XSG.MyShow Func where
    myShow (Func (name, _) branches)
        | null branches = name ++ " " ++ [symLast] ++ "\n"
        | otherwise     = name ++ "\n  " ++ XSG.intersperseShow "\n  " branches ++ [symLast] ++ "\n"

-- | A program: every declaration and function definition in order.
instance XSG.MyShow Prog where
    myShow = concatMap (either XSG.myShow XSG.myShow)

-- | Source position of the first token of an expression list, used in error
-- messages. Nested groups are searched for their first token. An empty list
-- (or a leading empty group) yields 'noPos' instead of crashing, so building
-- an error message can never itself fail.
getExpsPos :: Exps -> Pos
getExpsPos []      = noPos
getExpsPos (e : _) = firstPos e
  where
    firstPos (ExpId (_, pos))    = pos
    firstPos (ExpNew _ (_, pos)) = pos
    firstPos (ExpConcat inner)   = getExpsPos inner

-- | Position used to report an error about a whole branch: the position of
-- the first token of the branch's last act. An empty branch yields 'noPos'
-- instead of crashing on 'last'.
getBranchPos :: Branch -> Pos
getBranchPos []   = noPos
getBranchPos acts = actPos (last acts)
  where
    actPos (ActAssign es)  = getExpsPos es
    actPos (ActCompare es) = getExpsPos es

--------------------------------------------- Syntax ----------------------------------------------
-- | The concrete characters of the surface syntax.
symCompare, symAssign, symNewX, symNewY, symNewW, symNext, symBegin, symEnd, symType, symLast :: Char
symCompare = '='  -- introduces a comparison act
symAssign  = ','  -- introduces an assignment act; separates arity and coarity in declarations
symNewX    = '@'  -- binds a fresh X variable
symNewY    = '!'  -- binds a fresh Y variable
symNewW    = '?'  -- binds a fresh W variable
symNext    = ';'  -- separates the items of a block
symBegin   = '{'  -- opens an explicit block
symEnd     = '}'  -- closes an explicit block
symType    = ':'  -- separates declared names from their type
symLast    = ' '  -- written after the branches of a function when pretty-printing

--------------------------------------------- Parsers ---------------------------------------------
-- | Types that have a canonical parser. 'parse' selects the parser from the
-- expected result type.
class Parsible a where
 parser :: OffsideParser [] Pair Token a

-- | Shorthand for the offside-aware UU parser type used by the parsers below.
type Parser = OffsideParser [] Pair Token

-- | An identifier (lower- or upper-case) together with its source position.
pId :: Parser (Id,Pos)
pId = toIdPos <$> (pSym (idTok TkVarid) <|> pSym (idTok TkConid))
 where idTok kind = Tok kind "" "?id?" noPos ""
       toIdPos (Tok _ _ name pos _) = (name, pos)

-- | Parsers for the single-character symbols of the syntax.
pCompare, pAssign, pNext, pBegin, pEnd, pType :: Parser String
pBegin   = pSpec symBegin
pEnd     = pSpec symEnd
pCompare = pSpec symCompare
pAssign  = pSpec symAssign
pType    = pSpec symType
-- block-item separator, wrapped in 'pOnside' for offside (layout) handling
pNext    = pOnside $ pSpec symNext

-- | The marker that introduces a fresh variable, mapped to its kind.
pNew :: Parser VarType
pNew = marked X symNewX <|> marked Y symNewY <|> marked W symNewW
 where marked :: VarType -> Char -> Parser VarType
       marked vt c = vt <$ pSpec c

-- | A single expression: an identifier, a fresh-variable binder (marker
-- followed by an identifier), or a parenthesised group of expressions.
pExp :: Parser Exp
pExp = pRef <|> pFresh <|> pGroup
 where pRef   = (ExpId <$> pId)
       pFresh = (ExpNew <$> pNew <*> pId)
       pGroup = (ExpConcat <$ pOParen <*> pExps <* pCParen)
instance Parsible Exp where parser = pExp

-- | One or more expressions (non-greedy list).
pExps :: Parser Exps
pExps = pList1_ng pExp
instance Parsible Exps where parser = pExps

-- | An act introduced by its marker: symAssign for an assignment,
-- symCompare for a comparison.
pAct :: Parser Act
pAct = pAssignAct <|> pCompareAct
 where pAssignAct  = (ActAssign <$ pAssign <*> pExps)
       pCompareAct = (ActCompare <$ pCompare <*> pExps)
instance Parsible Act where parser = pAct

pBranchA, pBranchC :: Parser Branch
-- | A branch: one or more acts. Only the first act may omit its marker; an
-- unmarked leading expression list is wrapped with the given constructor.
pBranch' :: (Exps->Act) -> Parser Branch
pBranch' dflt = (:) <$> pHead <*> pList_ng pAct
 where pHead = (dflt <$> pExps <|> pAct)
-- unmarked leading expressions are an assignment (queries parsed by eval')
pBranchA = pBranch' ActAssign
-- unmarked leading expressions are a comparison (function bodies, see pFunc)
pBranchC = pBranch' ActCompare
instance Parsible Branch where parser = pBranchC

-- | A function definition: its name followed by a block of branches.
pFunc :: Parser Func
pFunc = Func <$> pId <*> pBranches
 where pBranches = pBlock1 pBegin pNext pEnd pBranchC
instance Parsible Func where parser = pFunc

-- | A decimal integer literal.
pInt :: Parser Int
pInt = read <$> pInteger10

-- | Declarations: a comma-separated list of names, symType, then the type.
-- A constructor type is either a list of example argument names (only their
-- count matters) or an arity number. A function type is two name lists or
-- two numbers separated by symAssign: (arity, coarity).
pCtorDecl,pFuncDecl :: Parser Decl
pCtorDecl = CtorDecl <$> pNames <* pType <*> pCtorType where
    pNames    = pListSep (pSpec ',') pId
    pCtorType = (length <$> pList_ng pId) <|> pInt
pFuncDecl = FuncDecl <$> pNames <* pType <*> pFuncType where
    pNames    = pListSep (pSpec ',') pId
    pFuncType = (counts <$> pList pId <* pAssign <*> pList_ng pId) <|>
                ((,) <$> pInt <* pAssign <*> pInt)
    counts args results = (length args, length results)
instance Parsible Decl where parser = pFuncDecl <|> pCtorDecl

-- | A program: a block whose items are declarations or function definitions.
instance Parsible Prog where parser = pBlock1 pBegin pNext pEnd (Left <$> parser <|> Right <$> parser)

-------------------------------------------- COMPILER ---------------------------------------------
-- | Symbol-table entries: declared constructors (with their arity), declared
-- functions (with their (arity, coarity)), variables bound by a fresh-variable
-- expression (with the XSG variable they compile to), and 'Unknown', which a
-- lookup returns for an identifier that is not in the table.
data Symbol = Constructor Id CtorType 
            | Function Id FuncType 
            | Variable Id XSG.Var
            | Unknown

-- | A persistent symbol table represented by two closures: a lookup
-- function, and an insertion function that returns the extended table.
data SymbolTable = SymbolTable (Id -> Symbol) (Id -> Symbol -> SymbolTable)

-- | The empty symbol table: every lookup yields 'Unknown'. Inserting a
-- symbol returns a new table whose lookup answers for that id and defers
-- every other id to the previous table, which itself stays unchanged. The
-- most recently inserted entry for an id therefore shadows older ones.
emptySymbolTable :: SymbolTable
emptySymbolTable = tableFrom (const Unknown)
 where tableFrom lookupId = SymbolTable lookupId (insertInto lookupId)
       insertInto lookupId key sym = tableFrom answer
        where answer key' | key == key' = sym
                          | otherwise   = lookupId key'

-- | Insert a symbol into a table under its own name, using the table's
-- insertion closure. 'Unknown' is the lookup-failure placeholder and
-- carries no name; inserting it is reported with an explicit error instead
-- of an anonymous pattern-match failure.
addSymbol :: Symbol -> SymbolTable -> SymbolTable
addSymbol sym (SymbolTable _ add) = case sym of
    Constructor name _ -> add name sym
    Function name _    -> add name sym
    Variable name _    -> add name sym
    Unknown            -> error "addSymbol: cannot insert the Unknown placeholder symbol"

-- | Look up an identifier. If the table has no entry for it, fail with an
-- "undefined" error that reports the identifier's source position.
getSymbol :: SymbolTable -> (Id, Pos) -> Symbol
getSymbol (SymbolTable find _) (name, pos) =
    case find name of
        Unknown -> error ("symbol '"++(XSG.myShow name)++"' at position "++(XSG.myShow pos)++" is undefined")
        found   -> found

-- | Monad used by the compiler. A computation receives the current symbol
-- table and an 'XSG.Index', and returns the (possibly extended) table
-- together with its result. The table is threaded sequentially. The index
-- is not threaded but split: '>>=' derives two indices from the incoming one
-- with 'XSG.newIdxs', giving the first to the left computation and the
-- second to the continuation.
data MONAD a = MONAD (SymbolTable -> XSG.Index -> (SymbolTable, a))

instance Prelude.Monad MONAD where
 return x = MONAD (\tbl _ -> (tbl, x))
 MONAD run >>= k = MONAD step
  where step tbl idx = cont tbl' idx2
         where idx1:idx2:_ = XSG.newIdxs idx
               (tbl', a)   = run tbl idx1
               MONAD cont  = k a

-- | Apply one symbol to the compilation stack @(terms, exprs)@, where
-- @exprs@ are the already-compiled expressions that follow the symbol.
--
--   * A constructor of arity @ar@ takes the next @ar@ expressions as its
--     arguments.
--   * A function of arity @ar@ takes the next @ar@ expressions as
--     arguments, records a call term binding @XSG.fresh co XSG.L idx@ to the
--     call, and is replaced on the stack by those result variables.
--   * A variable is pushed as-is.
--
-- If too few expressions follow a constructor or function, an arity
-- mismatch error is raised, reporting @pos@.
compileSymbol :: Pos -> (Symbol, XSG.Index) -> (XSG.Terms, XSG.Exps) -> (XSG.Terms, XSG.Exps)
compileSymbol pos (Constructor name ar, idx) (ts, es)
    | ar /= length args = error ("constructor arity mismatch at position "++(XSG.myShow pos)++" ("++(XSG.myShow $ ar)++"/="++(XSG.myShow $ length args)++")")
    | otherwise         = (ts, XSG.C (name, idx) args : rest)
  where (args, rest) = splitAt ar es
compileSymbol pos (Function name (ar, co), idx) (ts, es)
    | ar /= length args = error ("function arity mismatch at position "++(XSG.myShow pos)++" ("++(XSG.myShow $ ar)++"/="++(XSG.myShow $ length args)++")")
    | otherwise         = ((results XSG.:= (name, args)) : ts, XSG.v2e results ++ rest)
  where (args, rest) = splitAt ar es
        results      = XSG.fresh co XSG.L idx
compileSymbol _ (Variable _ var, _) (ts, es) = (ts, XSG.v2e var : es)

-- | XSG variable constructor for a surface fresh-variable kind; the result
-- still needs the index that makes the variable fresh.
compileVar :: VarType -> XSG.Index -> XSG.Var
compileVar X = XSG.X
compileVar Y = XSG.Y
compileVar W = XSG.W

-- | Compile one expression into a transformer of the compilation stack (see
-- 'compileSymbol'), run in 'MONAD' so it can consult and extend the symbol
-- table and use the index it is given.
--
--   * An identifier is looked up with 'getSymbol', which fails if it is
--     undefined.
--   * A fresh binder builds a 'Variable' from the supplied index, adds it to
--     the table, and is then compiled like a reference to it.
--   * A parenthesised group is compiled on its own, and its terms and
--     expressions are prepended to the stack.
compileExp :: Exp -> MONAD ((XSG.Terms, XSG.Exps) -> (XSG.Terms, XSG.Exps))
compileExp (ExpId idp) = liftM (compileSymbol (snd idp)) lookupSym
  where lookupSym = MONAD (\tbl idx -> (tbl, (getSymbol tbl idp, idx)))
compileExp (ExpNew typ idp) = do
    var <- liftM (Variable (fst idp) . compileVar typ) currentIdx
    liftM (compileSymbol (snd idp)) (bindVar var)
  where currentIdx  = MONAD (,)
        bindVar var = MONAD (\tbl idx -> (addSymbol var tbl, (var, idx)))
compileExp (ExpConcat esC) = liftM prepend (compileExps esC)
  where prepend (ts, es) (ts', es') = (ts++ts', es++es')

-- | Compile a sequence of expressions. The fold runs from the right, so each
-- expression's stack transformer is applied to the already-compiled
-- remainder. As a result, a symbol takes its arguments from the
-- expressions that follow it (prefix application).
compileExps :: Exps -> MONAD (XSG.Terms, XSG.Exps)
compileExps = foldr step (return ([], []))
  where step e rest = compileExp e `ap` rest

-- | Compile one act on top of the branch state (terms, constraints, current
-- expressions).
--
--   * An assignment compiles its expressions; they become the current
--     expressions.
--   * A comparison compiles its expressions and equates them pairwise with
--     the current ones (adding ':=:' constraints); they then become the
--     current expressions. Both lists must have the same length, otherwise a
--     "clash coarity mismatch" error is raised.
compileAct :: XSG.State -> Act -> MONAD XSG.State
compileAct (ts, cd, es) act = case act of
    ActAssign esA -> do
        (ts', es') <- compileExps esA
        return (ts++ts', cd, es')
    ActCompare esC -> do
        (ts', es') <- compileExps esC
        when (length es /= length es') $ error ("clash coarity mismatch at position "++(XSG.myShow $ getExpsPos esC)++" ("++(XSG.myShow $ length es)++"/="++(XSG.myShow $ length es')++")")
        return (ts++ts', cd++(zipWith (XSG.:=:) es es'), es')

-- | Compile a branch. Compilation starts from no terms, no constraints and
-- the given initial expressions, then compiles each act in order with
-- 'compileAct'. Returns the final state; the extended symbol table is
-- discarded.
compileBranch :: SymbolTable -> XSG.Index -> XSG.Exps -> Branch -> XSG.State
compileBranch tbl idx es br = finalState
 where MONAD run = foldM compileAct ([], [], es) br
       (_, finalState) = run tbl idx

-- | Compile one function definition against the symbol table.
--
-- The parameters are @XSG.fresh ar XSG.X@ variables. The branch list begins
-- with a synthetic state that has no terms or constraints and whose
-- expressions are @XSG.fresh co XSG.Y@ variables. It is followed by every
-- user branch, each compiled with the parameters as its initial
-- expressions.
--
-- Errors:
--
--   * If the name is not declared as a function, a descriptive error is
--     raised.
--   * If a user branch does not end with exactly @co@ expressions, a
--     "function coarity mismatch" error is raised.
compileFunc :: SymbolTable -> Func -> XSG.Func
compileFunc tbl (Func idp brs) = check `seq` XSG.FUNC name vs (st:sts)
 where (name, ar, co) = case getSymbol tbl idp of
           Function n (a, c) -> (n, a, c)
           _                 -> error ("symbol '"++(XSG.myShow (fst idp))++"' at position "++(XSG.myShow (snd idp))++" is not declared as a function")
       vs  = XSG.fresh ar XSG.X idx1
       sts = map (compileBranch tbl idx2 (XSG.v2e vs)) brs
       st  = ([],[],XSG.v2e $ XSG.fresh co XSG.Y idx3)
       -- Runs in the list monad: 'when' yields [()] when a branch passes, so
       -- forcing this value to WHNF runs the coarity check on every branch.
       check = zipWithM_ (\(_, _, es) br -> when (length es /= co) $ error ("function coarity mismatch at position "++(XSG.myShow $ getBranchPos br)++" ("++(XSG.myShow co)++"/="++(XSG.myShow $ length es)++")")) sts brs :: [()]
       idx1:idx2:idx3:_ = XSG.newIdxs XSG.initIdx

-- | Compile a parsed program. A symbol table is built from all declarations,
-- and every function definition is compiled against it. The table is
-- returned alongside the compiled functions. Because 'addSymbol' is folded
-- from the right, the earliest declaration of a name is inserted last and
-- so shadows any later duplicate.
compileModule :: Prog -> (SymbolTable, XSG.Prog)
compileModule prog = (tbl, map (compileFunc tbl) funcs)
 where decls = [d | Left d <- prog]
       funcs = [f | Right f <- prog]
       tbl = foldr addSymbol emptySymbolTable (concatMap declSymbols decls)
       declSymbols (FuncDecl ids tp) = [Function (fst d) tp | d <- ids]
       declSymbols (CtorDecl ids tp) = [Constructor (fst d) tp | d <- ids]

-------------------------------------- UU_Parsing internals ---------------------------------------
-- | Normalise line endings before scanning. A CRLF pair and a lone CR each
-- become a single '\n'. Collapsing CRLF, rather than mapping each CR to a
-- newline, keeps the line numbers in token positions correct for files with
-- Windows line endings. Idempotent, so applying it twice is harmless.
convert :: String -> String
convert ('\CR' : '\LF' : rest) = '\n' : convert rest
convert ('\CR' : rest)         = '\n' : convert rest
convert (c : rest)             = c : convert rest
convert []                     = []

-- | Offside-rule position of a token: the (column, line) of its source position.
instance Offside Token where
 getPos tok = (column at, line at)
  where at = pos tok

{--------------------------------------------- TRASH ----------------------------------------------

getArity (Function _ (x,_)) = x
getArity (Constructor _ x) = x
getArity (Variable _ _) = 0

getCoarity (Function _ (_,x)) = x
getCoarity (Constructor _ _) = 1
getCoarity (Variable _ _) = 1


getSymbol (SymbolTable find _) x = find x


solve :: SymbolTable -> Exp -> (SymbolTable, Int, Int) -- find arity/coarity/new symbols
--solve tbl (ExpNew (id,_)) = (tbl !+! Variable id (XSG.P "_"), 0, 1) --TODO?
solve tbl (ExpId (id,pos)) = (tbl, getArity s, getCoarity s) where
    s = mustGetSymbol tbl (id,pos)
solve tbl (ExpConcat es) = (tbl', 0, c0) where
    (tbl', _, c0) = checkNoArity (foldl process (tbl, 0, 0) es)
    checkNoArity x@(_, 0, _) = x
    checkNoArity (_, a, _) = error ("expression " ++ showWithPos' es ++ " is not ground (arity " ++ XSG.myShow a ++ ")")
    process (tbl, a1, c1) e2 = (tbl', a0, c0) where
        (tbl', a2, c2) = solve tbl e2
        (a0, c0) = (a1 -. c2 + a2 , c1 + c2 -. a1)
        x -. y = if x>y then x-y else 0

coarity tbl exp = c where
    (_, _, c) = solve tbl exp

class Cast a b where
    cast :: a -> b

---------------------------------------------------------------------------------------------------
-- dont use offside 2d syntax right now
pOnsideSemi :: (InputState state Token, OutputState out) => OffsideParser state out Token String
pOnsideSemi = pOnside pSemi

---------------------------------------------------------------------------------------------------
compile (Func (id, pos) brs) = XSG.FUNC id (genVars 0 arity) body where
    tbl = table
    arity = getArity (getSymbol tbl id)
    genVars k n = map xVar [k+1..k+n]
    body = map compileBranch brs
    compileBranch [ActAssign e] = [] XSG.:=> [XSG.RES res] where
        res = map XSG.VAR (genVars arity (coarity tbl e))

patch :: XSG.Exp -> XSG.Exp
patch = flip XSG.renum []

testFun = "test , nil; ,cons !x !y."

main = coarity table (parseXSG "cons (concat !x !y) x (cons y nil)" ::Exp)
main' = (parseXSG prog ::Prog)
prog = concat [
    "concat nil !ys , ys;",
    "       (cons !x !xs) !ys , cons x (concat xs ys).",
    "split !xs , nil xs;",
    "      (cons !x !xs) , (!ys !zs) = split xs , (cons x ys) zs.",
    "split' concat !xs !ys, xs ys."]

test str = map patch exps :: XSG.Exps where
    (_,_,stack) = generateCalls table 0 (parseXSG str :: Exp)
    exps = map cast (reverse stack)

--compile :: SymbolTable -> Func -> XSG.Func
xVar n = XSG.X [n] -- FIXME!! replace with XSG.P
--lVar n = XSG.P ("l" ++ XSG.myShow n) -- TODO?

data StackItem = StackItem Symbol [StackItem]
instance Cast StackItem XSG.Exp where
    cast (StackItem (Variable _ bind) [] ) = XSG.VAR bind
    cast (StackItem (Constructor id _) args) = XSG.C (id,[]) (map cast args)
type Stack = [StackItem]

type Index = Int

generateCalls :: SymbolTable -> Index -> Exp -> (SymbolTable, Index, Stack)
generateCalls tbl idx (ExpConcat es) = (tbl', idx', res) where
    res = checkGround res'
    (tbl', idx', res') = foldl process (tbl, idx, []) es
    checkGround (res1:_) | isNotGround res1 = error ("expression " ++ showWithPos' es ++ " is not ground")
    checkGround res = res

    process (tbl, idx, stack) (ExpId (id,pos)) = (tbl, idx, eval (symbol `push` stack)) where
        symbol = mustGetSymbol tbl (id,pos)
    process (tbl, idx, stack) (ExpNew (id,pos)) = (tbl !+! var, idx+1, eval (var `push` stack )) where
        var = Variable id (xVar idx)
    process (tbl, idx, stack) (ExpConcat es) = (tbl', idx', foldr eval' stack stack') where
        eval' x xs = eval (x:xs)
        (tbl', idx', stack') = generateCalls tbl idx (ExpConcat es)
    push symbol stack = StackItem symbol [] : stack
    eval stack@(top:_) | isNotGround top = stack
    eval (top1:top2:stack) | isNotGround top2 = eval (apply top2 top1 : stack) where
        apply (StackItem ctor@(Constructor _ _) args) arg = StackItem ctor (args++[arg])
    eval stack = stack -- dont need?
    isGround (StackItem (Variable _ _ ) _ ) = True
    isGround (StackItem (Constructor _ arity) args) = arity == length args
    isNotGround = not.isGround

--------------------------------------------------------------------------------------------------}