Initial implementation of lambda compilation

This commit is contained in:
Konstantin Nazarov 2024-08-17 12:33:45 +01:00
parent 81eb5e43ce
commit 9a0f6d2264
Signed by: knazarov
GPG key ID: 4CFE0A42FA409C22
3 changed files with 153 additions and 4 deletions

View file

@ -5,19 +5,22 @@
struct Context { struct Context {
Context() {} Context() {}
Context(Value&& env, Array&& constants, Dict&& constants_dict) Context(Value&& env, Array&& constants, Dict&& constants_dict,
Dict&& variables_dict)
: env(std::move(env)), : env(std::move(env)),
constants(std::move(constants)), constants(std::move(constants)),
constants_dict(std::move(constants_dict)), constants_dict(std::move(constants_dict)),
variables_dict(std::move(variables_dict)),
maxreg(0) {} maxreg(0) {}
static Result<Context> create() { static Result<Context> create() {
auto env = TRY(Nil::create()); auto env = TRY(Nil::create());
auto constants = TRY(Array::create()); auto constants = TRY(Array::create());
auto constants_dict = TRY(Dict::create()); auto constants_dict = TRY(Dict::create());
auto variables_dict = TRY(Dict::create());
return Context(std::move(env), std::move(constants), return Context(std::move(env), std::move(constants),
std::move(constants_dict)); std::move(constants_dict), std::move(variables_dict));
} }
uint64_t alloc_reg() { uint64_t alloc_reg() {
@ -42,9 +45,37 @@ struct Context {
return i; return i;
} }
Result<int64_t> add_var(const Value& sym) {
auto idx = variables_dict.get(sym);
if (idx.has_value()) {
if (!idx.value().is<Int64>()) return ERROR(TypeMismatch);
return idx.value().to<Int64>()->value();
}
int64_t i = maxreg;
variables_dict = TRY(variables_dict.insert(sym, TRY(Value::create(i))));
maxreg++;
return i;
}
Result<int64_t> get_var(const Value& sym) {
auto idx = variables_dict.get(sym);
if (idx.has_value()) {
if (!idx.value().is<Int64>()) return ERROR(TypeMismatch);
return idx.value().to<Int64>()->value();
}
return ERROR(KeyError);
}
Value env; Value env;
Array constants; Array constants;
Dict constants_dict; Dict constants_dict;
Dict variables_dict;
uint64_t maxreg; uint64_t maxreg;
}; };
@ -83,10 +114,11 @@ Result<Expression> Compiler::compile_expr(Context& context, Value& expr) {
return compile_int64(context, *expr.to<Int64>()); return compile_int64(context, *expr.to<Int64>());
case Tag::Bool: case Tag::Bool:
return compile_bool(context, *expr.to<Bool>()); return compile_bool(context, *expr.to<Bool>());
case Tag::Symbol:
return compile_symbol(context, *expr.to<Symbol>());
case Tag::Nil: case Tag::Nil:
case Tag::Float: case Tag::Float:
case Tag::String: case Tag::String:
case Tag::Symbol:
case Tag::Syntax: case Tag::Syntax:
case Tag::Array: case Tag::Array:
case Tag::ByteArray: case Tag::ByteArray:
@ -229,6 +261,100 @@ Result<Expression> Compiler::compile_if(Context& context, Symbol& op,
return std::move(ex); return std::move(ex);
} }
Result<Expression> Compiler::compile_lambda(Context& context, Symbol& op,
Pair& expr) {
Context ctx = TRY(Context::create());
auto first = TRY(expr.rest());
if (first.is<Nil>() || !first.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& first_pair = *first.to<Pair>();
auto param = TRY(first_pair.first());
if (param.is<Nil>()) {
return ERROR(CompilationError);
}
uint64_t arity = 0;
while (!param.is<Nil>()) {
if (!param.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& param_pair = *param.to<Pair>();
auto param_first = TRY(param_pair.first());
if (!param_first.is<Symbol>()) {
return ERROR(CompilationError);
}
int64_t reg = TRY(ctx.add_var(param_first));
std::cout << "reg: " << reg << "\n";
param = TRY(param_pair.rest());
arity++;
}
Value second = TRY(first_pair.rest());
if (second.is<Nil>() || !second.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& second_pair = *second.to<Pair>();
auto ex = TRY(compile_body(ctx, second_pair));
Value name = TRY(Nil::create());
auto fun = TRY(Function::create(name, arity, context.constants, ex.code));
Expression ex_res = TRY(Expression::create());
int64_t c = TRY(context.add_const(TRY(fun.copy_value())));
uint64_t reg = context.alloc_reg();
TRY(ex_res.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)c}));
ex.reg = reg;
return ex_res;
}
Result<Expression> Compiler::compile_body(Context& context, Pair& expr) {
auto cur = TRY(expr.copy_value());
Expression ex_res = TRY(Expression::create());
int64_t maxreg = context.maxreg;
while (!cur.is<Nil>()) {
if (!cur.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& cur_pair = *cur.to<Pair>();
auto expr_val = TRY(cur_pair.first());
debug_print(expr_val);
auto expr = TRY(compile_expr(context, expr_val));
TRY(ex_res.add_code(expr.code));
cur = TRY(cur_pair.rest());
if (cur.is<Nil>()) {
ex_res.reg = expr.reg;
} else {
context.maxreg = maxreg;
}
}
return ex_res;
}
Result<Expression> Compiler::compile_list(Context& context, Pair& expr) { Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
auto first = TRY(expr.first()); auto first = TRY(expr.first());
@ -238,6 +364,8 @@ Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
return compile_primop(context, sym, expr); return compile_primop(context, sym, expr);
} else if (TRY(sym.cmp("if")) == 0) { } else if (TRY(sym.cmp("if")) == 0) {
return compile_if(context, sym, expr); return compile_if(context, sym, expr);
} else if (TRY(sym.cmp("lambda")) == 0) {
return compile_lambda(context, sym, expr);
} }
} }
return ERROR(TypeMismatch); return ERROR(TypeMismatch);
@ -255,6 +383,24 @@ Result<Expression> Compiler::compile_int64(Context& context, Int64& value) {
return std::move(ex); return std::move(ex);
} }
Result<Expression> Compiler::compile_symbol(Context& context, Symbol& value) {
Expression ex = TRY(Expression::create());
auto maybe_reg = context.get_var(TRY(value.copy_value()));
if (maybe_reg.has_error()) {
return ERROR(CompilationError);
}
auto var_reg = maybe_reg.value();
uint64_t reg = context.alloc_reg();
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)var_reg}));
ex.reg = reg;
return std::move(ex);
}
Result<Expression> Compiler::compile_bool(Context& context, Bool& value) { Result<Expression> Compiler::compile_bool(Context& context, Bool& value) {
Expression ex = TRY(Expression::create()); Expression ex = TRY(Expression::create());
uint64_t reg = context.alloc_reg(); uint64_t reg = context.alloc_reg();

View file

@ -33,8 +33,11 @@ class Compiler {
Result<Expression> compile_list(Context& context, Pair& expr); Result<Expression> compile_list(Context& context, Pair& expr);
Result<Expression> compile_primop(Context& context, Symbol& op, Pair& expr); Result<Expression> compile_primop(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_int64(Context& context, Int64& value); Result<Expression> compile_int64(Context& context, Int64& value);
Result<Expression> compile_symbol(Context& context, Symbol& value);
Result<Expression> compile_bool(Context& context, Bool& value); Result<Expression> compile_bool(Context& context, Bool& value);
Result<Expression> compile_if(Context& context, Symbol& op, Pair& expr); Result<Expression> compile_if(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_lambda(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_body(Context& context, Pair& expr);
}; };
Result<Value> compile(Value& expr); Result<Value> compile(Value& expr);

View file

@ -11,7 +11,7 @@ StaticArena<64 * 1024 * 1024> arena;
Result<void> run() { Result<void> run() {
// auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))")); // auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))"));
auto code_str = TRY(String::create("(if true 1 2)")); auto code_str = TRY(String::create("(lambda (x) (* x x))"));
auto reader = Reader(code_str); auto reader = Reader(code_str);
auto parsed = TRY(reader.read_one()); auto parsed = TRY(reader.read_one());