Implement tail calls

This commit is contained in:
Konstantin Nazarov 2024-11-03 18:18:15 +00:00
parent ab9377fe33
commit d017884c5b
Signed by: knazarov
GPG key ID: 4CFE0A42FA409C22
8 changed files with 83 additions and 23 deletions

View file

@ -29,6 +29,9 @@ Result<void> Arena::gc_roots() {
node = TRY(gc_root(node));
node = node->next();
}
if (_gc_hold != 0) {
_gc_hold = TRY(gc_pod(_gc_hold));
}
return Result<void>();
}

View file

@ -128,7 +128,7 @@ class ArenaHeap {
class Arena {
public:
Arena(ArenaHeap* first, ArenaHeap* second)
: _first(first), _second(second), _gcroot() {}
: _first(first), _second(second), _gcroot(), _gc_hold(0) {}
template <class T>
Result<T*> alloc(uint64_t extra = 0) {
@ -197,10 +197,15 @@ class Arena {
return res;
};
// TODO: this should probably check for _gc_hold being 0
void sethold(PodObject* ptr) { _gc_hold = ptr; }
PodObject* gethold() { return _gc_hold; }
private:
ArenaHeap* _first;
ArenaHeap* _second;
GcRootList _gcroot;
PodObject* _gc_hold;
};
template <uint64_t size>
@ -237,9 +242,10 @@ Result<void> arena_gc();
template <class T>
requires std::derived_from<T, PodObject>
Result<GcRoot<T>> GcRoot<T>::create(T* ptr) {
get_arena().sethold(ptr);
auto lst = TRY(arena_alloc<GcRootList>());
get_arena().add_root(lst);
return std::move(GcRoot<T>(ptr, lst));
return std::move(GcRoot<T>((T*)get_arena().gethold(), lst));
}
template <class T>

View file

@ -514,6 +514,7 @@ Result<Value> StackFrame::get(uint64_t idx) const {
Result<StackFrame> StackFrame::set(uint64_t idx, const Value& val) const {
uint64_t size = std::max(_value->size, idx + 1);
auto nil = Value(TRY(Nil::create()));
auto pod = TRY(arena_alloc<PodStackFrame>(sizeof(OffPtr<PodObject>) * size));
pod->parent = _value->parent;
@ -526,7 +527,6 @@ Result<StackFrame> StackFrame::set(uint64_t idx, const Value& val) const {
pod->data[i] = _value->data[i];
}
if (size > _value->size) {
auto nil = Value(TRY(Nil::create()));
for (uint64_t i = _value->size; i < size; i++) {
pod->data[i] = nil.pod();
}
@ -537,6 +537,7 @@ Result<StackFrame> StackFrame::set(uint64_t idx, const Value& val) const {
}
Result<StackFrame> StackFrame::settop(uint64_t idx) const {
auto nil = Value(TRY(Nil::create()));
auto pod = TRY(arena_alloc<PodStackFrame>(sizeof(OffPtr<PodObject>) * idx));
pod->parent = _value->parent;
@ -551,7 +552,6 @@ Result<StackFrame> StackFrame::settop(uint64_t idx) const {
}
if (min_idx < idx) {
auto nil = Value(TRY(Nil::create()));
for (uint64_t i = min_idx; i < idx; i++) {
pod->data[i] = nil.pod();
}
@ -588,9 +588,16 @@ Result<StackFrame> StackFrame::setpc(uint64_t pc) {
Result<StackFrame> StackFrame::incpc() { return setpc(pc() + 1); }
Result<StackFrame> StackFrame::call(const Value& fun, uint64_t start,
uint64_t end) const {
auto trunc_stack = TRY(settop(start));
auto new_stack = TRY(StackFrame::create(TRY(trunc_stack.copy_value()), fun));
uint64_t end, bool tail) const {
StackFrame new_stack;
if (!tail) {
auto trunc_stack = TRY(settop(start));
new_stack = TRY(StackFrame::create(TRY(trunc_stack.copy_value()), fun));
} else {
auto par = TRY(this->parent());
new_stack = TRY(StackFrame::create(par, fun));
}
if (fun.is<Function>()) {
uint64_t arity = fun.to<Function>()->arity();

View file

@ -1162,7 +1162,8 @@ class StackFrame : public Object {
return Array(TRY(MkGcRoot(pod)));
}
Result<StackFrame> call(const Value& fun, uint64_t start, uint64_t end) const;
Result<StackFrame> call(const Value& fun, uint64_t start, uint64_t end,
bool tail = false) const;
Result<StackFrame> ret(uint64_t regno) const;
Result<StackFrame> ret(const Value& val) const;

View file

@ -996,6 +996,42 @@ Result<Expression> Compiler::compile_function_call(Context& context,
return ex;
}
Result<Expression> Compiler::compile_tail_call(Context& context,
const Value& expr) {
auto num_params = TRY(expr.size());
if (num_params < 2) {
return syntax_error(TRY(expr.first()),
"\"tailcall\" form must have at least 1 argument");
}
auto ex = TRY(Expression::create());
auto second = TRY(expr.second());
auto second_unwrapped = TRY(syntax_unwrap(second));
auto param = TRY(TRY(expr.rest()).rest());
uint64_t firstreg = context.maxreg;
auto fun_ex = TRY(compile_expr(context, second));
TRY(ex.add_code(fun_ex.code));
int64_t maxreg = context.maxreg;
while (!param.is<Nil>()) {
Value param_val = TRY(param.first());
auto param_ex = TRY(compile_expr(context, param_val));
TRY(ex.add_code(param_ex.code));
param = TRY(param.rest());
}
TRY(ex.add_opcode(Oc::TailCall, {0, (int64_t)firstreg},
{0, (int64_t)context.maxreg}));
ex.reg = firstreg;
context.maxreg = maxreg;
return ex;
}
Result<Expression> Compiler::compile_list(Context& context, const Value& expr) {
auto first = TRY(expr.first());
auto unwrapped = TRY(syntax_unwrap(first));
@ -1028,6 +1064,8 @@ Result<Expression> Compiler::compile_list(Context& context, const Value& expr) {
return compile_quote(context, sym, expr);
} else if (TRY(sym.cmp("syntax")) == 0) {
return compile_syntax(context, sym, expr);
} else if (TRY(sym.cmp("tailcall")) == 0) {
return compile_tail_call(context, expr);
} else {
return compile_function_call(context, expr);
}

View file

@ -61,9 +61,10 @@ class Compiler {
Result<Expression> compile_let(Context& context, Symbol& op,
const Value& expr);
Result<Expression> compile_do(Context& context, Symbol& op,
const Value& expr);
const Value& expr);
Result<Expression> compile_body(Context& context, const Value& expr);
Result<Expression> compile_function_call(Context& context, const Value& expr);
Result<Expression> compile_tail_call(Context& context, const Value& expr);
Result<Expression> syntax_error(const Value& expr, const char* message);

View file

@ -101,17 +101,17 @@ Result<void> VM::vm_less_equal(Opcode& oc) {
return Result<void>();
}
Result<void> VM::vm_call_lisp(Opcode& oc, Function& fun) {
Result<void> VM::vm_call_lisp(Opcode& oc, Function& fun, bool tail) {
uint64_t reg_start = (uint64_t)oc.arg1().arg;
uint64_t reg_end = (uint64_t)oc.arg2().arg;
_stack = TRY(_stack.incpc());
_stack = TRY(_stack.call(TRY(fun.copy_value()), reg_start, reg_end));
_stack = TRY(_stack.call(TRY(fun.copy_value()), reg_start, reg_end, tail));
return Result<void>();
}
Result<void> VM::vm_call_stdlib(Opcode& oc, StdlibFunction& fun) {
Result<void> VM::vm_call_stdlib(Opcode& oc, StdlibFunction& fun, bool tail) {
uint64_t reg_start = (uint64_t)oc.arg1().arg;
uint64_t reg_end = (uint64_t)oc.arg2().arg;
@ -119,12 +119,13 @@ Result<void> VM::vm_call_stdlib(Opcode& oc, StdlibFunction& fun) {
_stack = TRY(_stack.set(reg_start + 1, TRY(params.copy_value())));
_stack = TRY(_stack.settop(reg_start + 2));
_stack = TRY(_stack.incpc());
_stack = TRY(_stack.call(TRY(fun.copy_value()), reg_start, reg_start + 2));
_stack =
TRY(_stack.call(TRY(fun.copy_value()), reg_start, reg_start + 2, tail));
return Result<void>();
}
Result<void> VM::vm_call_cont(Opcode& oc, Continuation& fun) {
Result<void> VM::vm_call_cont(Opcode& oc, Continuation& fun, bool tail) {
uint64_t reg_start = (uint64_t)oc.arg1().arg;
uint64_t reg_end = (uint64_t)oc.arg2().arg;
@ -141,21 +142,21 @@ Result<void> VM::vm_call_cont(Opcode& oc, Continuation& fun) {
return Result<void>();
}
Result<void> VM::vm_call(Opcode& oc) {
Result<void> VM::vm_call(Opcode& oc, bool tail) {
Value fun_value = TRY(get(oc.arg1().is_const, oc.arg1().arg));
if (fun_value.is<Function>()) {
Function& fun = *fun_value.to<Function>();
return vm_call_lisp(oc, fun);
return vm_call_lisp(oc, fun, tail);
}
if (fun_value.is<StdlibFunction>()) {
StdlibFunction& fun = *fun_value.to<StdlibFunction>();
return vm_call_stdlib(oc, fun);
return vm_call_stdlib(oc, fun, tail);
}
if (fun_value.is<Continuation>()) {
Continuation& fun = *fun_value.to<Continuation>();
return vm_call_cont(oc, fun);
return vm_call_cont(oc, fun, tail);
}
return ERROR(TypeMismatch);
@ -275,7 +276,10 @@ Result<void> VM::step_bytecode() {
TRY(vm_jump(oc));
break;
case Oc::Call:
TRY(vm_call(oc));
TRY(vm_call(oc, false));
break;
case Oc::TailCall:
TRY(vm_call(oc, true));
break;
case Oc::Self:
TRY(vm_self(oc));

View file

@ -45,11 +45,11 @@ class VM {
Result<void> vm_mul(Opcode& oc);
Result<void> vm_sub(Opcode& oc);
Result<void> vm_div(Opcode& oc);
Result<void> vm_call(Opcode& oc);
Result<void> vm_call(Opcode& oc, bool tail);
Result<void> vm_self(Opcode& oc);
Result<void> vm_call_lisp(Opcode& oc, Function& fun);
Result<void> vm_call_stdlib(Opcode& oc, StdlibFunction& fun);
Result<void> vm_call_cont(Opcode& oc, Continuation& fun);
Result<void> vm_call_lisp(Opcode& oc, Function& fun, bool tail);
Result<void> vm_call_stdlib(Opcode& oc, StdlibFunction& fun, bool tail);
Result<void> vm_call_cont(Opcode& oc, Continuation& fun, bool tail);
Result<void> vm_ret(Opcode& oc);
Result<void> vm_equal(Opcode& oc);