From 3d17813eb4acb73bf77d9175e315173446e2c68c Mon Sep 17 00:00:00 2001
From: Ethan Girouard
Date: Wed, 11 Dec 2024 16:11:05 -0500
Subject: [PATCH] Implement early version of CodeGen

---
 src/Windows12/CodeGen.hs | 353 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 350 insertions(+), 3 deletions(-)

diff --git a/src/Windows12/CodeGen.hs b/src/Windows12/CodeGen.hs
index 498e10e..fd70c38 100644
--- a/src/Windows12/CodeGen.hs
+++ b/src/Windows12/CodeGen.hs
@@ -1,8 +1,355 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RecursiveDo #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+
 module Windows12.CodeGen where
 
-import Data.Text (Text)
+import Windows12.Ast (BinOp(..), UnOp(..), AssignOp(..), Type(..),
+                      Bind(..), TLStruct(..), TLEnum(..))
 import Windows12.TAst
-import LLVM.AST (Module)
+import LLVM.AST hiding (ArrayType, VoidType, Call, function)
+import LLVM.AST.Type (i32, i1, i8, double, ptr, void)
+import qualified LLVM.AST.Constant as C
+import LLVM.IRBuilder hiding (double, IRBuilder, ModuleBuilder)
+import LLVM.AST.Typed (typeOf)
+import LLVM.Prelude (ShortByteString)
+
+import qualified LLVM.AST.IntegerPredicate as IP
+import qualified LLVM.AST.FloatingPointPredicate as FP
+
+import Control.Monad.State hiding (void)
+
+import Data.Text (Text, unpack)
+import Data.String.Conversions
+import Data.String
 
+-- Global program context, threaded through both builders as State.
+--   operands: name -> operand bindings; newest first, so lookup returns the
+--             most recently created binding for a name
+--   structs:  struct definitions registered by emitTypeDef
+--   enums:    enum definitions (nothing in this file populates it yet)
+--   strings:  global string constants already emitted, keyed by contents
+data Ctx = Ctx { operands :: [(Text, Operand)],
+                 structs :: [TLStruct],
+                 enums :: [TLEnum],
+                 strings :: [(Text, Operand)] }
+  deriving (Eq, Show)
+
+type ModuleBuilder = ModuleBuilderT (State Ctx)
+type IRBuilder = IRBuilderT ModuleBuilder
+
+-- Allow easy string conversion, so cs can turn Text into LLVM names
+instance ConvertibleStrings Text ShortByteString where
+  convertString = Data.String.fromString . Data.Text.unpack
+
+-- Put an operand into the context with a name.
+-- The binding is prepended, so it shadows any older binding of the same name.
+createOperand :: MonadState Ctx m => Text -> Operand -> m ()
+createOperand name op = do
+  ctx <- get
+  put $ ctx { operands = (name, op) : operands ctx }
+
+-- Take in a source file name, the AST, and return the LLVM IR module
 codegen :: Text -> TProgram -> Module
-codegen filename (TProgram structs enums funcs) = undefined
+codegen filename (TProgram structs enums funcs) =
+  flip evalState (Ctx [] [] [] [])
+    $ buildModuleT (cs filename)
+    $ do
+      -- Declare printf up front so generated code can call it by name
+      printf <- externVarArgs (mkName "printf") [ptr i8] i32
+      createOperand "printf" printf
+      mapM_ emitTypeDef structs
+      mapM_ codegenFunc funcs
+
+-- Given a struct name, search the context for the struct and return its fields.
+-- Calls error if the name is unknown or registered more than once.
+getStructFields :: MonadState Ctx m => Text -> m [Bind]
+getStructFields name = do
+  ctx <- get
+  case filter (\(Struct n _) -> n == name) (structs ctx) of
+    [] -> error $ "Struct " ++ show name ++ " not found. Valid structs: " ++ show (map (\(Struct n _) -> n) (structs ctx))
+    [Struct _ fields] -> return fields
+    _ -> error $ "Multiple structs with name " ++ show name
+
+-- Convert a Windows12 type to an LLVM type
+convertType :: MonadState Ctx m => Windows12.Ast.Type -> m LLVM.AST.Type
+convertType IntType = return i32
+convertType UIntType = return i32
+convertType FloatType = return double
+convertType StrType = convertType (PtrType CharType)
+convertType BoolType = return i1
+convertType CharType = return i8
+convertType (PtrType t) = ptr <$> convertType t
+convertType (ArrayType t) = convertType (PtrType t)
+convertType (StructType name) = do
+  fields <- getStructFields name
+  types <- mapM (convertType . bindType) fields
+  return $ StructureType True types -- True indicates packed
+convertType (EnumType name) = return i32
+convertType VoidType = return void
+
+-- Get the size of a type in bytes. This has to agree with convertType:
+-- structs are packed, so a struct is exactly the sum of its fields.
+size :: MonadState Ctx m => Windows12.Ast.Type -> m Int
+size IntType = return 4
+size UIntType = return 4
+size FloatType = return 8
+size StrType = size (PtrType CharType)
+size BoolType = return 1
+size CharType = return 1
+-- Pointers are 8 bytes: the module sets no data layout and LLVM's default
+-- layout uses 64-bit pointers. NOTE(review): revisit if a 32-bit target is added.
+size (PtrType _) = return 8
+size (ArrayType t) = size (PtrType t)
+size (StructType name) = do
+  fields <- getStructFields name
+  sizes <- mapM (size . bindType) fields
+  return $ sum sizes
+-- Enums are lowered to i32 by convertType, so they occupy 4 bytes
+size (EnumType _) = return 4
+size VoidType = return 0
+
+-- CodeGen for LValues: produces the address that a store should write to
+codegenLVal :: TLVal -> IRBuilder Operand
+codegenLVal (t, (TId name)) = do
+  ctx <- get
+  case lookup name (operands ctx) of
+    Just op -> return op
+    Nothing -> error $ "Variable " ++ show name ++ " not found"
+
+-- TODO support members of members
+codegenLVal ((StructType t), (LTMember ((_, TId sName)) field)) = do
+  ctx <- get
+  case lookup sName (operands ctx) of
+    Just struct -> do
+      fields <- getStructFields t
+      offset <- structFieldOffset (Struct t fields) field
+      gep struct [ConstantOperand (C.Int 32 0), ConstantOperand (C.Int 32 (fromIntegral offset))]
+    Nothing -> error $ "Struct " ++ show sName ++ " not found"
+
+-- The address behind a dereference is the value of the dereferenced expression
+codegenLVal (t, (TDeref e)) = codegenExpr e
+
+codegenLVal (t, _) = error "Unimplemented or invalid LValue"
+
+-- Given a struct and a field name, return the offset of the field in the struct.
+-- In LLVM each field is actually size 1, so the offset is the field's position.
+-- Calls error for an unknown field rather than returning an out-of-range index.
+structFieldOffset :: MonadState Ctx m => TLStruct -> Text -> m Int
+structFieldOffset (Struct name fields) field =
+  case break (\(Bind n _) -> n == field) fields of
+    (before, _ : _) -> return $ length before
+    (_, []) -> error $ "Field " ++ show field ++ " not found in struct " ++ show name
+
+-- CodeGen for expressions
+codegenExpr :: TExpr -> IRBuilder Operand
+codegenExpr (t, (TVar name)) = flip load 0 =<< codegenLVal (t, (TId name))
+codegenExpr (t, (TIntLit i)) = return $ ConstantOperand (C.Int 32 (fromIntegral i))
+codegenExpr (t, (TUIntLit i)) = return $ ConstantOperand (C.Int 32 (fromIntegral i))
+codegenExpr (t, (TFloatLit f)) = undefined -- TODO floats
+codegenExpr (t, (TStrLit s)) = do
+  strs <- gets strings
+  case lookup s strs of
+    -- If the string is already in the context, return it
+    Just str -> return str
+    -- Otherwise, create a new global string and add it to the context
+    Nothing -> do
+      let str_name = mkName ("str." <> show (length strs))
+      op <- globalStringPtr (cs s) str_name
+      modify $ \ctx -> ctx { strings = (s, (ConstantOperand op)) : strs }
+      return (ConstantOperand op)
+codegenExpr (t, (TBoolLit b)) = return $ ConstantOperand (C.Int 1 (if b then 1 else 0))
+codegenExpr (t, (TCharLit c)) = return $ ConstantOperand (C.Int 8 (fromIntegral (fromEnum c)))
+
+codegenExpr (t, (TBinOp op lhs rhs)) = do
+  lhs' <- codegenExpr lhs
+  rhs' <- codegenExpr rhs
+
+  -- Every operator implemented so far requires both operands to be i32
+  let int32Op label instr = case (typeOf lhs', typeOf rhs') of
+        (IntegerType 32, IntegerType 32) -> instr lhs' rhs'
+        _ -> error $ "Invalid types for " ++ label
+
+  -- TODO pointers, floating points
+  -- NOTE(review): UIntType is also i32, so it gets the signed instructions below
+  case op of
+    Windows12.Ast.Add -> int32Op "add" add
+    Windows12.Ast.Sub -> int32Op "sub" sub
+    Windows12.Ast.Mul -> int32Op "mul" mul
+    Windows12.Ast.Div -> int32Op "div" sdiv
+    Windows12.Ast.Mod -> int32Op "mod" srem
+    Windows12.Ast.Eq -> int32Op "eq" (icmp IP.EQ)
+    Windows12.Ast.Ne -> int32Op "ne" (icmp IP.NE)
+    Windows12.Ast.Lt -> int32Op "lt" (icmp IP.SLT)
+    Windows12.Ast.Gt -> int32Op "gt" (icmp IP.SGT)
+    Windows12.Ast.Le -> int32Op "le" (icmp IP.SLE)
+    Windows12.Ast.Ge -> int32Op "ge" (icmp IP.SGE)
+
+    other -> error $ "Operator " ++ show other ++ " not implemented"
+
+codegenExpr (t, (TUnOp op e)) = undefined -- TODO handle unary operators
+
+-- Function calls: look up the function in operands, then call it with the args
+codegenExpr (t, (TCall f args)) = do
+  ctx <- get
+  f <- case lookup f (operands ctx) of
+    Just f -> return f
+    Nothing -> error $ "Function " ++ show f ++ " not found"
+  args <- mapM (fmap (, []) . codegenExpr) args
+  call f args
+
+codegenExpr (t, (TIndex arr idx)) = undefined -- TODO arrays
+
+-- Get the address of the struct field and load it
+codegenExpr (t, (TMember ((StructType sName), (TVar sVarName)) m)) = do
+  ctx <- get
+  case lookup sVarName (operands ctx) of
+    Just struct -> do
+      fields <- getStructFields sName
+      offset <- structFieldOffset (Struct sName fields) m
+      addr <- gep struct [ConstantOperand (C.Int 32 0), ConstantOperand (C.Int 32 (fromIntegral offset))]
+      load addr 0
+    Nothing -> error $ "Struct operand " ++ show sVarName ++ " not found"
+
+codegenExpr (_, (TCast t e)) = undefined -- TODO casts
+
+codegenExpr (_, (TSizeof t)) = ConstantOperand . C.Int 32 . fromIntegral <$> size t
+
+-- Run the given terminator instruction only if the current block has none yet
+mkTerminator :: IRBuilder () -> IRBuilder ()
+mkTerminator instr = do
+  check <- hasTerminator
+  unless check instr
+
+-- Codegen for statements
+codegenStmt :: TStmt -> IRBuilder ()
+
+-- For expression statements, just evaluate the expression and discard the result
+codegenStmt (TExprStmt e) = do
+  _expr <- codegenExpr e
+  return ()
+
+codegenStmt (TReturn e) = ret =<< codegenExpr e
+
+-- Generate if statements, with a merge block at the end
+codegenStmt (TIf cond t f) = mdo
+  cond' <- codegenExpr cond
+  condBr cond' then' else'
+
+  then' <- block `named` "then"
+  codegenStmt (TBlock t)
+  mkTerminator $ br merge
+
+  else' <- block `named` "else"
+  codegenStmt (case f of
+    Just f' -> TBlock f'
+    Nothing -> TBlock [])
+  mkTerminator $ br merge
+
+  merge <- block `named` "merge"
+  return ()
+
+-- Generate while loops, with a merge block at the end
+codegenStmt (TWhile cond body) = mdo
+  br condBlock
+  condBlock <- block `named` "cond"
+  cond' <- codegenExpr cond
+  condBr cond' loop end
+
+  loop <- block `named` "loop"
+  codegenStmt (TBlock body)
+  mkTerminator $ br condBlock
+
+  end <- block `named` "end"
+  return ()
+
+codegenStmt (TAssign BaseAssign l@(t, (TId name)) e) = do
+  op <- codegenExpr e
+  var <- codegenLVal l
+  store var 0 op
+
+-- Member assignment stores through the field address returned by codegenLVal
+codegenStmt (TAssign BaseAssign l@((StructType tName), (LTMember ((_, TId sName)) field)) e) = do
+  op <- codegenExpr e
+  struct <- codegenLVal l
+  store struct 0 op
+
+codegenStmt (TAssign AddAssign l@(t, (TId name)) e) = do
+  op <- codegenExpr e
+  var <- codegenLVal l
+  val <- load var 0
+  store var 0 =<< add val op
+
+codegenStmt (TAssign SubAssign l@(t, (TId name)) e) = do
+  op <- codegenExpr e
+  var <- codegenLVal l
+  val <- load var 0
+  store var 0 =<< sub val op
+
+-- A block is just a list of statements
+codegenStmt (TBlock stmts) = mapM_ codegenStmt stmts
+
+-- Since the vars are already allocated by genBody, we just need to assign the value
+codegenStmt (TDeclVar name t (Just e)) = codegenStmt (TAssign BaseAssign (t, (TId name)) e)
+
+-- Do nothing with variable declaration if no expression is given
+-- This is because allocation is done already
+codegenStmt (TDeclVar name _ Nothing) = return ()
+
+codegenStmt s = error $ "Unimplemented or invalid statement " ++ show s
+
+-- Generate code for a function
+-- First create the function, then allocate space for the arguments and locals
+codegenFunc :: TTLFunc -> ModuleBuilder ()
+codegenFunc func@(TTLFunc name args retType body) = mdo
+  -- mdo lets the function be registered before its body exists, so recursive calls resolve
+  createOperand name f
+  (f, strs) <- do
+    params' <- mapM mkParam args
+    retType' <- convertType retType
+    f <- function (mkName (cs name)) params' retType' genBody
+    strs <- gets strings
+    return (f, strs)
+  modify $ \ctx -> ctx { strings = strs }
+  where
+    mkParam (Bind name t) = (,) <$> convertType t <*> pure (ParameterName (cs name))
+
+    genBody :: [Operand] -> IRBuilder ()
+    genBody ops = do
+      -- Copy each parameter into a stack slot and bind its name to that slot
+      forM_ (zip ops args) $ \(op, (Bind name t)) -> do
+        addr <- alloca (typeOf op) Nothing 0
+        store addr 0 op
+        createOperand name addr
+
+      -- Allocate every local up front; declarations later only store to the slot
+      forM_ (getLocals func) $ \(Bind name t) -> do
+        ltype <- convertType t
+        addr <- alloca ltype Nothing 0
+        createOperand name addr
+
+      codegenStmt (TBlock body)
+
+-- Given a function, get all the local variables
+-- Used so allocation can be done before the function body
+getLocals :: TTLFunc -> [Bind]
+getLocals (TTLFunc _ args _ body) = blockGetLocals body
+
+blockGetLocals :: [TStmt] -> [Bind]
+blockGetLocals = concatMap stmtGetLocals
+
+stmtGetLocals :: TStmt -> [Bind]
+stmtGetLocals (TDeclVar n t _) = [Bind n t]
+stmtGetLocals (TBlock stmts) = blockGetLocals stmts
+stmtGetLocals (TIf _ t f) = blockGetLocals t ++ maybe [] blockGetLocals f
+stmtGetLocals (TWhile _ body) = blockGetLocals body
+stmtGetLocals _ = []
+
+-- Create structs
+emitTypeDef :: TLStruct -> ModuleBuilder LLVM.AST.Type
+emitTypeDef (Struct name fields) = do
+  modify $ \ctx -> ctx { structs = Struct name fields : structs ctx }
+  sType <- convertType (StructType name)
+  typedef (mkName (cs ("struct." <> name))) (Just sType)