Initial implementation of lambda compilation
This commit is contained in:
parent
81eb5e43ce
commit
9a0f6d2264
3 changed files with 153 additions and 4 deletions
152
src/compiler.cpp
152
src/compiler.cpp
|
@ -5,19 +5,22 @@
|
||||||
struct Context {
|
struct Context {
|
||||||
Context() {}
|
Context() {}
|
||||||
|
|
||||||
Context(Value&& env, Array&& constants, Dict&& constants_dict)
|
Context(Value&& env, Array&& constants, Dict&& constants_dict,
|
||||||
|
Dict&& variables_dict)
|
||||||
: env(std::move(env)),
|
: env(std::move(env)),
|
||||||
constants(std::move(constants)),
|
constants(std::move(constants)),
|
||||||
constants_dict(std::move(constants_dict)),
|
constants_dict(std::move(constants_dict)),
|
||||||
|
variables_dict(std::move(variables_dict)),
|
||||||
maxreg(0) {}
|
maxreg(0) {}
|
||||||
|
|
||||||
static Result<Context> create() {
|
static Result<Context> create() {
|
||||||
auto env = TRY(Nil::create());
|
auto env = TRY(Nil::create());
|
||||||
auto constants = TRY(Array::create());
|
auto constants = TRY(Array::create());
|
||||||
auto constants_dict = TRY(Dict::create());
|
auto constants_dict = TRY(Dict::create());
|
||||||
|
auto variables_dict = TRY(Dict::create());
|
||||||
|
|
||||||
return Context(std::move(env), std::move(constants),
|
return Context(std::move(env), std::move(constants),
|
||||||
std::move(constants_dict));
|
std::move(constants_dict), std::move(variables_dict));
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t alloc_reg() {
|
uint64_t alloc_reg() {
|
||||||
|
@ -42,9 +45,37 @@ struct Context {
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Result<int64_t> add_var(const Value& sym) {
|
||||||
|
auto idx = variables_dict.get(sym);
|
||||||
|
if (idx.has_value()) {
|
||||||
|
if (!idx.value().is<Int64>()) return ERROR(TypeMismatch);
|
||||||
|
|
||||||
|
return idx.value().to<Int64>()->value();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t i = maxreg;
|
||||||
|
|
||||||
|
variables_dict = TRY(variables_dict.insert(sym, TRY(Value::create(i))));
|
||||||
|
|
||||||
|
maxreg++;
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
Result<int64_t> get_var(const Value& sym) {
|
||||||
|
auto idx = variables_dict.get(sym);
|
||||||
|
if (idx.has_value()) {
|
||||||
|
if (!idx.value().is<Int64>()) return ERROR(TypeMismatch);
|
||||||
|
|
||||||
|
return idx.value().to<Int64>()->value();
|
||||||
|
}
|
||||||
|
|
||||||
|
return ERROR(KeyError);
|
||||||
|
}
|
||||||
|
|
||||||
Value env;
|
Value env;
|
||||||
Array constants;
|
Array constants;
|
||||||
Dict constants_dict;
|
Dict constants_dict;
|
||||||
|
Dict variables_dict;
|
||||||
uint64_t maxreg;
|
uint64_t maxreg;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -83,10 +114,11 @@ Result<Expression> Compiler::compile_expr(Context& context, Value& expr) {
|
||||||
return compile_int64(context, *expr.to<Int64>());
|
return compile_int64(context, *expr.to<Int64>());
|
||||||
case Tag::Bool:
|
case Tag::Bool:
|
||||||
return compile_bool(context, *expr.to<Bool>());
|
return compile_bool(context, *expr.to<Bool>());
|
||||||
|
case Tag::Symbol:
|
||||||
|
return compile_symbol(context, *expr.to<Symbol>());
|
||||||
case Tag::Nil:
|
case Tag::Nil:
|
||||||
case Tag::Float:
|
case Tag::Float:
|
||||||
case Tag::String:
|
case Tag::String:
|
||||||
case Tag::Symbol:
|
|
||||||
case Tag::Syntax:
|
case Tag::Syntax:
|
||||||
case Tag::Array:
|
case Tag::Array:
|
||||||
case Tag::ByteArray:
|
case Tag::ByteArray:
|
||||||
|
@ -229,6 +261,100 @@ Result<Expression> Compiler::compile_if(Context& context, Symbol& op,
|
||||||
return std::move(ex);
|
return std::move(ex);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Result<Expression> Compiler::compile_lambda(Context& context, Symbol& op,
|
||||||
|
Pair& expr) {
|
||||||
|
Context ctx = TRY(Context::create());
|
||||||
|
|
||||||
|
auto first = TRY(expr.rest());
|
||||||
|
|
||||||
|
if (first.is<Nil>() || !first.is<Pair>()) {
|
||||||
|
return ERROR(CompilationError);
|
||||||
|
}
|
||||||
|
|
||||||
|
Pair& first_pair = *first.to<Pair>();
|
||||||
|
|
||||||
|
auto param = TRY(first_pair.first());
|
||||||
|
if (param.is<Nil>()) {
|
||||||
|
return ERROR(CompilationError);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t arity = 0;
|
||||||
|
while (!param.is<Nil>()) {
|
||||||
|
if (!param.is<Pair>()) {
|
||||||
|
return ERROR(CompilationError);
|
||||||
|
}
|
||||||
|
Pair& param_pair = *param.to<Pair>();
|
||||||
|
|
||||||
|
auto param_first = TRY(param_pair.first());
|
||||||
|
|
||||||
|
if (!param_first.is<Symbol>()) {
|
||||||
|
return ERROR(CompilationError);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t reg = TRY(ctx.add_var(param_first));
|
||||||
|
std::cout << "reg: " << reg << "\n";
|
||||||
|
|
||||||
|
param = TRY(param_pair.rest());
|
||||||
|
arity++;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value second = TRY(first_pair.rest());
|
||||||
|
|
||||||
|
if (second.is<Nil>() || !second.is<Pair>()) {
|
||||||
|
return ERROR(CompilationError);
|
||||||
|
}
|
||||||
|
|
||||||
|
Pair& second_pair = *second.to<Pair>();
|
||||||
|
|
||||||
|
auto ex = TRY(compile_body(ctx, second_pair));
|
||||||
|
|
||||||
|
Value name = TRY(Nil::create());
|
||||||
|
auto fun = TRY(Function::create(name, arity, context.constants, ex.code));
|
||||||
|
|
||||||
|
Expression ex_res = TRY(Expression::create());
|
||||||
|
|
||||||
|
int64_t c = TRY(context.add_const(TRY(fun.copy_value())));
|
||||||
|
|
||||||
|
uint64_t reg = context.alloc_reg();
|
||||||
|
TRY(ex_res.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)c}));
|
||||||
|
|
||||||
|
ex.reg = reg;
|
||||||
|
|
||||||
|
return ex_res;
|
||||||
|
}
|
||||||
|
|
||||||
|
Result<Expression> Compiler::compile_body(Context& context, Pair& expr) {
|
||||||
|
auto cur = TRY(expr.copy_value());
|
||||||
|
|
||||||
|
Expression ex_res = TRY(Expression::create());
|
||||||
|
|
||||||
|
int64_t maxreg = context.maxreg;
|
||||||
|
|
||||||
|
while (!cur.is<Nil>()) {
|
||||||
|
if (!cur.is<Pair>()) {
|
||||||
|
return ERROR(CompilationError);
|
||||||
|
}
|
||||||
|
Pair& cur_pair = *cur.to<Pair>();
|
||||||
|
auto expr_val = TRY(cur_pair.first());
|
||||||
|
|
||||||
|
debug_print(expr_val);
|
||||||
|
|
||||||
|
auto expr = TRY(compile_expr(context, expr_val));
|
||||||
|
|
||||||
|
TRY(ex_res.add_code(expr.code));
|
||||||
|
|
||||||
|
cur = TRY(cur_pair.rest());
|
||||||
|
|
||||||
|
if (cur.is<Nil>()) {
|
||||||
|
ex_res.reg = expr.reg;
|
||||||
|
} else {
|
||||||
|
context.maxreg = maxreg;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ex_res;
|
||||||
|
}
|
||||||
|
|
||||||
Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
|
Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
|
||||||
auto first = TRY(expr.first());
|
auto first = TRY(expr.first());
|
||||||
|
|
||||||
|
@ -238,6 +364,8 @@ Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
|
||||||
return compile_primop(context, sym, expr);
|
return compile_primop(context, sym, expr);
|
||||||
} else if (TRY(sym.cmp("if")) == 0) {
|
} else if (TRY(sym.cmp("if")) == 0) {
|
||||||
return compile_if(context, sym, expr);
|
return compile_if(context, sym, expr);
|
||||||
|
} else if (TRY(sym.cmp("lambda")) == 0) {
|
||||||
|
return compile_lambda(context, sym, expr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ERROR(TypeMismatch);
|
return ERROR(TypeMismatch);
|
||||||
|
@ -255,6 +383,24 @@ Result<Expression> Compiler::compile_int64(Context& context, Int64& value) {
|
||||||
return std::move(ex);
|
return std::move(ex);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Result<Expression> Compiler::compile_symbol(Context& context, Symbol& value) {
|
||||||
|
Expression ex = TRY(Expression::create());
|
||||||
|
|
||||||
|
auto maybe_reg = context.get_var(TRY(value.copy_value()));
|
||||||
|
|
||||||
|
if (maybe_reg.has_error()) {
|
||||||
|
return ERROR(CompilationError);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto var_reg = maybe_reg.value();
|
||||||
|
|
||||||
|
uint64_t reg = context.alloc_reg();
|
||||||
|
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {1, (int64_t)var_reg}));
|
||||||
|
|
||||||
|
ex.reg = reg;
|
||||||
|
return std::move(ex);
|
||||||
|
}
|
||||||
|
|
||||||
Result<Expression> Compiler::compile_bool(Context& context, Bool& value) {
|
Result<Expression> Compiler::compile_bool(Context& context, Bool& value) {
|
||||||
Expression ex = TRY(Expression::create());
|
Expression ex = TRY(Expression::create());
|
||||||
uint64_t reg = context.alloc_reg();
|
uint64_t reg = context.alloc_reg();
|
||||||
|
|
|
@ -33,8 +33,11 @@ class Compiler {
|
||||||
Result<Expression> compile_list(Context& context, Pair& expr);
|
Result<Expression> compile_list(Context& context, Pair& expr);
|
||||||
Result<Expression> compile_primop(Context& context, Symbol& op, Pair& expr);
|
Result<Expression> compile_primop(Context& context, Symbol& op, Pair& expr);
|
||||||
Result<Expression> compile_int64(Context& context, Int64& value);
|
Result<Expression> compile_int64(Context& context, Int64& value);
|
||||||
|
Result<Expression> compile_symbol(Context& context, Symbol& value);
|
||||||
Result<Expression> compile_bool(Context& context, Bool& value);
|
Result<Expression> compile_bool(Context& context, Bool& value);
|
||||||
Result<Expression> compile_if(Context& context, Symbol& op, Pair& expr);
|
Result<Expression> compile_if(Context& context, Symbol& op, Pair& expr);
|
||||||
|
Result<Expression> compile_lambda(Context& context, Symbol& op, Pair& expr);
|
||||||
|
Result<Expression> compile_body(Context& context, Pair& expr);
|
||||||
};
|
};
|
||||||
|
|
||||||
Result<Value> compile(Value& expr);
|
Result<Value> compile(Value& expr);
|
||||||
|
|
|
@ -11,7 +11,7 @@ StaticArena<64 * 1024 * 1024> arena;
|
||||||
|
|
||||||
Result<void> run() {
|
Result<void> run() {
|
||||||
// auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))"));
|
// auto code_str = TRY(String::create("(* (+ 1 2 3) (/ 4 2))"));
|
||||||
auto code_str = TRY(String::create("(if true 1 2)"));
|
auto code_str = TRY(String::create("(lambda (x) (* x x))"));
|
||||||
auto reader = Reader(code_str);
|
auto reader = Reader(code_str);
|
||||||
|
|
||||||
auto parsed = TRY(reader.read_one());
|
auto parsed = TRY(reader.read_one());
|
||||||
|
|
Loading…
Reference in a new issue