From 9a0f6d226496fa7a3bf4a308919d84cf2bfebc0b Mon Sep 17 00:00:00 2001 From: Konstantin Nazarov Date: Sat, 17 Aug 2024 12:33:45 +0100 Subject: [PATCH] Initial implementation of lambda compilation --- src/compiler.cpp | 152 ++++++++++++++++++++++++++++++++++++++++++++++- src/compiler.hpp | 3 + src/vli.cpp | 2 +- 3 files changed, 153 insertions(+), 4 deletions(-) diff --git a/src/compiler.cpp b/src/compiler.cpp index e38a51c..4f458c0 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -5,19 +5,22 @@ struct Context { Context() {} - Context(Value&& env, Array&& constants, Dict&& constants_dict) + Context(Value&& env, Array&& constants, Dict&& constants_dict, + Dict&& variables_dict) : env(std::move(env)), constants(std::move(constants)), constants_dict(std::move(constants_dict)), + variables_dict(std::move(variables_dict)), maxreg(0) {} static Result create() { auto env = TRY(Nil::create()); auto constants = TRY(Array::create()); auto constants_dict = TRY(Dict::create()); + auto variables_dict = TRY(Dict::create()); return Context(std::move(env), std::move(constants), - std::move(constants_dict)); + std::move(constants_dict), std::move(variables_dict)); } uint64_t alloc_reg() { @@ -42,9 +45,37 @@ struct Context { return i; } + Result add_var(const Value& sym) { + auto idx = variables_dict.get(sym); + if (idx.has_value()) { + if (!idx.value().is()) return ERROR(TypeMismatch); + + return idx.value().to()->value(); + } + + int64_t i = maxreg; + + variables_dict = TRY(variables_dict.insert(sym, TRY(Value::create(i)))); + + maxreg++; + return i; + } + + Result get_var(const Value& sym) { + auto idx = variables_dict.get(sym); + if (idx.has_value()) { + if (!idx.value().is()) return ERROR(TypeMismatch); + + return idx.value().to()->value(); + } + + return ERROR(KeyError); + } + Value env; Array constants; Dict constants_dict; + Dict variables_dict; uint64_t maxreg; }; @@ -83,10 +114,11 @@ Result Compiler::compile_expr(Context& context, Value& expr) { return compile_int64(context, *expr.to()); case Tag::Bool: return compile_bool(context, *expr.to()); + case Tag::Symbol: + return compile_symbol(context, *expr.to()); case Tag::Nil: case Tag::Float: case Tag::String: - case Tag::Symbol: case Tag::Syntax: case Tag::Array: case Tag::ByteArray: @@ -229,6 +261,100 @@ Result Compiler::compile_if(Context& context, Symbol& op, return std::move(ex); } +Result Compiler::compile_lambda(Context& context, Symbol& op, + Pair& expr) { + Context ctx = TRY(Context::create()); + + auto first = TRY(expr.rest()); + + if (first.is() || !first.is()) { + return ERROR(CompilationError); + } + + Pair& first_pair = *first.to(); + + auto param = TRY(first_pair.first()); + if (param.is()) { + return ERROR(CompilationError); + } + + uint64_t arity = 0; + while (!param.is()) { + if (!param.is()) { + return ERROR(CompilationError); + } + Pair& param_pair = *param.to(); + + auto param_first = TRY(param_pair.first()); + + if (!param_first.is()) { + return ERROR(CompilationError); + } + + int64_t reg = TRY(ctx.add_var(param_first)); + std::cout << "reg: " << reg << "\n"; + + param = TRY(param_pair.rest()); + arity++; + } + + Value second = TRY(first_pair.rest()); + + if (second.is() || !second.is()) { + return ERROR(CompilationError); + } + + Pair& second_pair = *second.to(); + + auto ex = TRY(compile_body(ctx, second_pair)); + + Value name = TRY(Nil::create()); + auto fun = TRY(Function::create(name, arity, context.constants, ex.code)); + + Expression ex_res = TRY(Expression::create()); + + int64_t c = TRY(context.add_const(TRY(fun.copy_value()))); + + uint64_t reg = context.alloc_reg(); + TRY(ex_res.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)c})); + + ex.reg = reg; + + return ex_res; +} + +Result Compiler::compile_body(Context& context, Pair& expr) { + auto cur = TRY(expr.copy_value()); + + Expression ex_res = TRY(Expression::create()); + + int64_t maxreg = context.maxreg; + + while (!cur.is()) { + if (!cur.is()) { + return ERROR(CompilationError); + } + Pair& cur_pair = *cur.to(); + auto expr_val = TRY(cur_pair.first()); + + debug_print(expr_val); + + auto expr = TRY(compile_expr(context, expr_val)); + + TRY(ex_res.add_code(expr.code)); + + cur = TRY(cur_pair.rest()); + + if (cur.is()) { + ex_res.reg = expr.reg; + } else { + context.maxreg = maxreg; + } + } + + return ex_res; +} + Result Compiler::compile_list(Context& context, Pair& expr) { auto first = TRY(expr.first()); @@ -238,6 +364,8 @@ Result Compiler::compile_list(Context& context, Pair& expr) { return compile_primop(context, sym, expr); } else if (TRY(sym.cmp("if")) == 0) { return compile_if(context, sym, expr); + } else if (TRY(sym.cmp("lambda")) == 0) { + return compile_lambda(context, sym, expr); } } return ERROR(TypeMismatch); @@ -255,6 +383,24 @@ Result Compiler::compile_int64(Context& context, Int64& value) { return std::move(ex); } +Result Compiler::compile_symbol(Context& context, Symbol& value) { + Expression ex = TRY(Expression::create()); + + auto maybe_reg = context.get_var(TRY(value.copy_value())); + + if (maybe_reg.has_error()) { + return ERROR(CompilationError); + } + + auto var_reg = maybe_reg.value(); + + uint64_t reg = context.alloc_reg(); + TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)var_reg})); + + ex.reg = reg; + return std::move(ex); +} + Result Compiler::compile_bool(Context& context, Bool& value) { Expression ex = TRY(Expression::create()); uint64_t reg = context.alloc_reg(); diff --git a/src/compiler.hpp b/src/compiler.hpp index d09d698..2878a41 100644 --- a/src/compiler.hpp +++ b/src/compiler.hpp @@ -33,8 +33,11 @@ class Compiler { Result compile_list(Context& context, Pair& expr); Result compile_primop(Context& context, Symbol& op, Pair& expr); Result compile_int64(Context& context, Int64& value); + Result compile_symbol(Context& context, Symbol& value); Result compile_bool(Context& context, Bool& value); Result compile_if(Context& context, Symbol& op, Pair& expr); + Result compile_lambda(Context& context, Symbol& op, Pair& expr); + Result compile_body(Context& context, Pair& expr); }; Result compile(Value& expr); diff --git a/src/vli.cpp b/src/vli.cpp index c67310a..b95e124 100644 --- a/src/vli.cpp +++ b/src/vli.cpp @@ -11,7 +11,7 @@ StaticArena<64 * 1024 * 1024> arena; Result run() { // auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))")); - auto code_str = TRY(String::create("(if true 1 2)")); + auto code_str = TRY(String::create("(lambda (x) (* x x))")); auto reader = Reader(code_str); auto parsed = TRY(reader.read_one());