Boostで抽象構文木を作ってllvm言語に変換してみた
Boost.Variant、Boost.Spirit.Qi、Boost.Spirit.Phoenixを使って簡単な中置記法の計算するための抽象構文木(AST)を作って、そのASTをllvm言語に変換するというのを作ってみました。コード書き殴り。 BoostでASTが割とすっきりと書け、そこからllvm言語に変換することも割と簡単に出来ます。
コードと出力結果はGistのほうに置いておきます。
雑ですが説明。
ASTの表現
Boost.VariantでASTを表します。
namespace ast { struct add; struct sub; struct mul; struct div; template <class Op> struct binary_op; using expr = boost::variant< int, boost::recursive_wrapper< binary_op< add > >, boost::recursive_wrapper< binary_op< sub > >, boost::recursive_wrapper< binary_op< mul > >, boost::recursive_wrapper< binary_op< div > > >; template <class Op> struct binary_op { expr lhs; expr rhs; binary_op(expr const& lhs_, expr const& rhs_) : lhs( lhs_ ), rhs( rhs_ ) { } }; } // namespace ast
値または式を格納するためにboost::variantを使ってast::exprを定義しています。 ここでは値はintまたは二項演算を表すast::binary_opで四則演算のいずれかを格納できるようにしています。 2 * 3 + 1のような式の場合、以下の図のような木になります。この図のbinary_opやintは全てast::exprで表すことができます。
ast::binary_opについて、演算子の左、右でそれぞれast::exprのlhs、rhsで表現しています。 しかし、ast::exprを定義するためにast::binary_opが必要で、ast::binary_opを定義するためにast::exprが必要という循環になってしまってます。なので、ast::binary_opをast::exprの前方で宣言しておいてboost::recursive_wrapperを使って解決しています。
パーサ
Boost.Spirit.Qi、Boost.Spirit.Phoenixを使ってパーサを作ります。
namespace parser { namespace qi = boost::spirit::qi; template <class Iterator> struct arith_grammar : qi::grammar< Iterator, ast::expr(), qi::ascii::space_type > { template <class T> using rule_t = qi::rule< Iterator, T, qi::ascii::space_type >; rule_t< ast::expr() > expr; rule_t< ast::expr() > term; rule_t< ast::expr() > factor; arith_grammar() : arith_grammar::base_type( expr ) { namespace phx = boost::phoenix; expr %= term[qi::_val = qi::_1] >> *( ( '+' >> term[qi::_val = phx::construct< ast::binary_op< ast::add > >( qi::_val, qi::_1 )] ) | ( '-' >> term[qi::_val = phx::construct< ast::binary_op< ast::sub > >( qi::_val, qi::_1 )] ) ); term %= factor[qi::_val = qi::_1] >> *( ( '*' >> factor[qi::_val = phx::construct< ast::binary_op< ast::mul > >( qi::_val, qi::_1 )] ) | ( '/' >> factor[qi::_val = phx::construct< ast::binary_op< ast::div > >( qi::_val, qi::_1 )] ) ); factor %= qi::int_ | ( '(' >> expr >> ')' )[qi::_val = qi::_1]; } }; } // namespace parser
セマンティックアクションの中でast::binary_opを構築するときに、直接ast::binary_opと書くとセマンティックアション前に構築されてしまってエラーになるのでboost::phoenix::constructを使って構築を遅延しています。
llvm言語に変換
ASTが出来たらllvm言語に変換します。 llvm言語周りはまだわかってないので過程だけ説明します。
218行目で、まずLLVM側の準備、main関数の作成、printfを使えるようにしています。
auto& context = llvm::getGlobalContext(); std::unique_ptr< llvm::Module > module( new llvm::Module( "arith", context ) ); llvm::IRBuilder<> builder( context ); make_main_func( context, module, builder ); auto* printf_func = make_printf( module, builder ); auto* format = builder.CreateGlobalStringPtr( "%d\n" );
ASTをasemmblyを使ってllvm言語に変換します。
class asemmbly : public boost::static_visitor< llvm::Value* > { llvm::IRBuilder<>& builder_; public: asemmbly(llvm::IRBuilder<>& builder) : boost::static_visitor< llvm::Value* >(), builder_( builder ) { } llvm::Value* operator()(int value) { return llvm::ConstantInt::get( builder_.getInt32Ty(), value ); } template <class Op> llvm::Value* operator()(ast::binary_op< Op > const& op) { llvm::Value* lhs = boost::apply_visitor( *this, op.lhs ); llvm::Value* rhs = boost::apply_visitor( *this, op.rhs ); return apply_op( op, lhs, rhs ); } private: llvm::Value* apply_op(ast::binary_op< ast::add > const&, llvm::Value* lhs, llvm::Value* rhs) { return builder_.CreateAdd( lhs, rhs ); } llvm::Value* apply_op(ast::binary_op< ast::sub > const&, llvm::Value* lhs, llvm::Value* rhs) { return builder_.CreateSub( lhs, rhs ); } llvm::Value* apply_op(ast::binary_op< ast::mul > const&, llvm::Value* lhs, llvm::Value* rhs) { return builder_.CreateMul( lhs, rhs ); } llvm::Value* apply_op(ast::binary_op< ast::div > const&, llvm::Value* lhs, llvm::Value* rhs) { return builder_.CreateSDiv( lhs, rhs ); } };
返り値をllvm::Value*に合わせて再帰的にllvm言語に変換していきます。
227行目、std::vector< llvm::Value* > argsはprintfに渡す引数をとなっており、前述のprintf_funcと合わせてbuilder.CreateCallするとllvm言語で表されたprintfの呼び出しコードが生成されます。
asemmbly asm_obj( builder ); for( auto const& i : asts ) { std::vector< llvm::Value* > args = { format, boost::apply_visitor( asm_obj, i ) }; builder.CreateCall( printf_func, llvm::ArrayRef< llvm::Value* >( args ) ); }
235行目でmake_ret_mainを呼び出してmain関数の戻り値を書き込みます。
237行目でstd::cout << print_llvm_lang( module );
として生成したllvm言語のコードを標準出力しています。
出力
入力を次のようにすると、
2 * ( 1 + 2 ) 2 * 1 + 2
下記のように出力されます。
; ModuleID = 'arith' @0 = private unnamed_addr constant [4 x i8] c"%d\0A\00" define i32 @main() { entry: %0 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x i8]* @0, i32 0, i32 0), i32 6) %1 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x i8]* @0, i32 0, i32 0), i32 4) ret i32 0 } declare i32 @printf(i8*, ...)
これをとりあえず動かすにはlliを使うといいでしょう。
参考文献
Boost -- Spirit 2.5.2
LLVM Tutorial: Table of Contents
LLVM API Documentation
redboltzの日記 -- qiを使って構文ツリーを構築
stackoverflow -- Retrieving AST from boost::spirit parser
letsboost::spirit
letsboost::variant