Implement tail calls
This commit is contained in:
parent
ab9377fe33
commit
d017884c5b
8 changed files with 83 additions and 23 deletions
|
@ -29,6 +29,9 @@ Result<void> Arena::gc_roots() {
|
|||
node = TRY(gc_root(node));
|
||||
node = node->next();
|
||||
}
|
||||
if (_gc_hold != 0) {
|
||||
_gc_hold = TRY(gc_pod(_gc_hold));
|
||||
}
|
||||
return Result<void>();
|
||||
}
|
||||
|
||||
|
|
|
@ -128,7 +128,7 @@ class ArenaHeap {
|
|||
class Arena {
|
||||
public:
|
||||
Arena(ArenaHeap* first, ArenaHeap* second)
|
||||
: _first(first), _second(second), _gcroot() {}
|
||||
: _first(first), _second(second), _gcroot(), _gc_hold(0) {}
|
||||
|
||||
template <class T>
|
||||
Result<T*> alloc(uint64_t extra = 0) {
|
||||
|
@ -197,10 +197,15 @@ class Arena {
|
|||
return res;
|
||||
};
|
||||
|
||||
// TODO: this should probably check for _gc_hold being 0
|
||||
void sethold(PodObject* ptr) { _gc_hold = ptr; }
|
||||
PodObject* gethold() { return _gc_hold; }
|
||||
|
||||
private:
|
||||
ArenaHeap* _first;
|
||||
ArenaHeap* _second;
|
||||
GcRootList _gcroot;
|
||||
PodObject* _gc_hold;
|
||||
};
|
||||
|
||||
template <uint64_t size>
|
||||
|
@ -237,9 +242,10 @@ Result<void> arena_gc();
|
|||
template <class T>
|
||||
requires std::derived_from<T, PodObject>
|
||||
Result<GcRoot<T>> GcRoot<T>::create(T* ptr) {
|
||||
get_arena().sethold(ptr);
|
||||
auto lst = TRY(arena_alloc<GcRootList>());
|
||||
get_arena().add_root(lst);
|
||||
return std::move(GcRoot<T>(ptr, lst));
|
||||
return std::move(GcRoot<T>((T*)get_arena().gethold(), lst));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
|
|
|
@ -514,6 +514,7 @@ Result<Value> StackFrame::get(uint64_t idx) const {
|
|||
|
||||
Result<StackFrame> StackFrame::set(uint64_t idx, const Value& val) const {
|
||||
uint64_t size = std::max(_value->size, idx + 1);
|
||||
auto nil = Value(TRY(Nil::create()));
|
||||
auto pod = TRY(arena_alloc<PodStackFrame>(sizeof(OffPtr<PodObject>) * size));
|
||||
|
||||
pod->parent = _value->parent;
|
||||
|
@ -526,7 +527,6 @@ Result<StackFrame> StackFrame::set(uint64_t idx, const Value& val) const {
|
|||
pod->data[i] = _value->data[i];
|
||||
}
|
||||
if (size > _value->size) {
|
||||
auto nil = Value(TRY(Nil::create()));
|
||||
for (uint64_t i = _value->size; i < size; i++) {
|
||||
pod->data[i] = nil.pod();
|
||||
}
|
||||
|
@ -537,6 +537,7 @@ Result<StackFrame> StackFrame::set(uint64_t idx, const Value& val) const {
|
|||
}
|
||||
|
||||
Result<StackFrame> StackFrame::settop(uint64_t idx) const {
|
||||
auto nil = Value(TRY(Nil::create()));
|
||||
auto pod = TRY(arena_alloc<PodStackFrame>(sizeof(OffPtr<PodObject>) * idx));
|
||||
|
||||
pod->parent = _value->parent;
|
||||
|
@ -551,7 +552,6 @@ Result<StackFrame> StackFrame::settop(uint64_t idx) const {
|
|||
}
|
||||
|
||||
if (min_idx < idx) {
|
||||
auto nil = Value(TRY(Nil::create()));
|
||||
for (uint64_t i = min_idx; i < idx; i++) {
|
||||
pod->data[i] = nil.pod();
|
||||
}
|
||||
|
@ -588,9 +588,16 @@ Result<StackFrame> StackFrame::setpc(uint64_t pc) {
|
|||
Result<StackFrame> StackFrame::incpc() { return setpc(pc() + 1); }
|
||||
|
||||
Result<StackFrame> StackFrame::call(const Value& fun, uint64_t start,
|
||||
uint64_t end) const {
|
||||
auto trunc_stack = TRY(settop(start));
|
||||
auto new_stack = TRY(StackFrame::create(TRY(trunc_stack.copy_value()), fun));
|
||||
uint64_t end, bool tail) const {
|
||||
StackFrame new_stack;
|
||||
if (!tail) {
|
||||
auto trunc_stack = TRY(settop(start));
|
||||
new_stack = TRY(StackFrame::create(TRY(trunc_stack.copy_value()), fun));
|
||||
} else {
|
||||
auto par = TRY(this->parent());
|
||||
|
||||
new_stack = TRY(StackFrame::create(par, fun));
|
||||
}
|
||||
|
||||
if (fun.is<Function>()) {
|
||||
uint64_t arity = fun.to<Function>()->arity();
|
||||
|
|
|
@ -1162,7 +1162,8 @@ class StackFrame : public Object {
|
|||
return Array(TRY(MkGcRoot(pod)));
|
||||
}
|
||||
|
||||
Result<StackFrame> call(const Value& fun, uint64_t start, uint64_t end) const;
|
||||
Result<StackFrame> call(const Value& fun, uint64_t start, uint64_t end,
|
||||
bool tail = false) const;
|
||||
Result<StackFrame> ret(uint64_t regno) const;
|
||||
Result<StackFrame> ret(const Value& val) const;
|
||||
|
||||
|
|
|
@ -996,6 +996,42 @@ Result<Expression> Compiler::compile_function_call(Context& context,
|
|||
return ex;
|
||||
}
|
||||
|
||||
Result<Expression> Compiler::compile_tail_call(Context& context,
|
||||
const Value& expr) {
|
||||
auto num_params = TRY(expr.size());
|
||||
if (num_params < 2) {
|
||||
return syntax_error(TRY(expr.first()),
|
||||
"\"tailcall\" form must have at least 1 argument");
|
||||
}
|
||||
|
||||
auto ex = TRY(Expression::create());
|
||||
|
||||
auto second = TRY(expr.second());
|
||||
auto second_unwrapped = TRY(syntax_unwrap(second));
|
||||
auto param = TRY(TRY(expr.rest()).rest());
|
||||
|
||||
uint64_t firstreg = context.maxreg;
|
||||
auto fun_ex = TRY(compile_expr(context, second));
|
||||
TRY(ex.add_code(fun_ex.code));
|
||||
|
||||
int64_t maxreg = context.maxreg;
|
||||
while (!param.is<Nil>()) {
|
||||
Value param_val = TRY(param.first());
|
||||
|
||||
auto param_ex = TRY(compile_expr(context, param_val));
|
||||
TRY(ex.add_code(param_ex.code));
|
||||
|
||||
param = TRY(param.rest());
|
||||
}
|
||||
|
||||
TRY(ex.add_opcode(Oc::TailCall, {0, (int64_t)firstreg},
|
||||
{0, (int64_t)context.maxreg}));
|
||||
|
||||
ex.reg = firstreg;
|
||||
context.maxreg = maxreg;
|
||||
return ex;
|
||||
}
|
||||
|
||||
Result<Expression> Compiler::compile_list(Context& context, const Value& expr) {
|
||||
auto first = TRY(expr.first());
|
||||
auto unwrapped = TRY(syntax_unwrap(first));
|
||||
|
@ -1028,6 +1064,8 @@ Result<Expression> Compiler::compile_list(Context& context, const Value& expr) {
|
|||
return compile_quote(context, sym, expr);
|
||||
} else if (TRY(sym.cmp("syntax")) == 0) {
|
||||
return compile_syntax(context, sym, expr);
|
||||
} else if (TRY(sym.cmp("tailcall")) == 0) {
|
||||
return compile_tail_call(context, expr);
|
||||
} else {
|
||||
return compile_function_call(context, expr);
|
||||
}
|
||||
|
|
|
@ -61,9 +61,10 @@ class Compiler {
|
|||
Result<Expression> compile_let(Context& context, Symbol& op,
|
||||
const Value& expr);
|
||||
Result<Expression> compile_do(Context& context, Symbol& op,
|
||||
const Value& expr);
|
||||
const Value& expr);
|
||||
Result<Expression> compile_body(Context& context, const Value& expr);
|
||||
Result<Expression> compile_function_call(Context& context, const Value& expr);
|
||||
Result<Expression> compile_tail_call(Context& context, const Value& expr);
|
||||
|
||||
Result<Expression> syntax_error(const Value& expr, const char* message);
|
||||
|
||||
|
|
24
src/vm.cpp
24
src/vm.cpp
|
@ -101,17 +101,17 @@ Result<void> VM::vm_less_equal(Opcode& oc) {
|
|||
return Result<void>();
|
||||
}
|
||||
|
||||
Result<void> VM::vm_call_lisp(Opcode& oc, Function& fun) {
|
||||
Result<void> VM::vm_call_lisp(Opcode& oc, Function& fun, bool tail) {
|
||||
uint64_t reg_start = (uint64_t)oc.arg1().arg;
|
||||
uint64_t reg_end = (uint64_t)oc.arg2().arg;
|
||||
|
||||
_stack = TRY(_stack.incpc());
|
||||
_stack = TRY(_stack.call(TRY(fun.copy_value()), reg_start, reg_end));
|
||||
_stack = TRY(_stack.call(TRY(fun.copy_value()), reg_start, reg_end, tail));
|
||||
|
||||
return Result<void>();
|
||||
}
|
||||
|
||||
Result<void> VM::vm_call_stdlib(Opcode& oc, StdlibFunction& fun) {
|
||||
Result<void> VM::vm_call_stdlib(Opcode& oc, StdlibFunction& fun, bool tail) {
|
||||
uint64_t reg_start = (uint64_t)oc.arg1().arg;
|
||||
uint64_t reg_end = (uint64_t)oc.arg2().arg;
|
||||
|
||||
|
@ -119,12 +119,13 @@ Result<void> VM::vm_call_stdlib(Opcode& oc, StdlibFunction& fun) {
|
|||
_stack = TRY(_stack.set(reg_start + 1, TRY(params.copy_value())));
|
||||
_stack = TRY(_stack.settop(reg_start + 2));
|
||||
_stack = TRY(_stack.incpc());
|
||||
_stack = TRY(_stack.call(TRY(fun.copy_value()), reg_start, reg_start + 2));
|
||||
_stack =
|
||||
TRY(_stack.call(TRY(fun.copy_value()), reg_start, reg_start + 2, tail));
|
||||
|
||||
return Result<void>();
|
||||
}
|
||||
|
||||
Result<void> VM::vm_call_cont(Opcode& oc, Continuation& fun) {
|
||||
Result<void> VM::vm_call_cont(Opcode& oc, Continuation& fun, bool tail) {
|
||||
uint64_t reg_start = (uint64_t)oc.arg1().arg;
|
||||
uint64_t reg_end = (uint64_t)oc.arg2().arg;
|
||||
|
||||
|
@ -141,21 +142,21 @@ Result<void> VM::vm_call_cont(Opcode& oc, Continuation& fun) {
|
|||
return Result<void>();
|
||||
}
|
||||
|
||||
Result<void> VM::vm_call(Opcode& oc) {
|
||||
Result<void> VM::vm_call(Opcode& oc, bool tail) {
|
||||
Value fun_value = TRY(get(oc.arg1().is_const, oc.arg1().arg));
|
||||
if (fun_value.is<Function>()) {
|
||||
Function& fun = *fun_value.to<Function>();
|
||||
return vm_call_lisp(oc, fun);
|
||||
return vm_call_lisp(oc, fun, tail);
|
||||
}
|
||||
|
||||
if (fun_value.is<StdlibFunction>()) {
|
||||
StdlibFunction& fun = *fun_value.to<StdlibFunction>();
|
||||
return vm_call_stdlib(oc, fun);
|
||||
return vm_call_stdlib(oc, fun, tail);
|
||||
}
|
||||
|
||||
if (fun_value.is<Continuation>()) {
|
||||
Continuation& fun = *fun_value.to<Continuation>();
|
||||
return vm_call_cont(oc, fun);
|
||||
return vm_call_cont(oc, fun, tail);
|
||||
}
|
||||
|
||||
return ERROR(TypeMismatch);
|
||||
|
@ -275,7 +276,10 @@ Result<void> VM::step_bytecode() {
|
|||
TRY(vm_jump(oc));
|
||||
break;
|
||||
case Oc::Call:
|
||||
TRY(vm_call(oc));
|
||||
TRY(vm_call(oc, false));
|
||||
break;
|
||||
case Oc::TailCall:
|
||||
TRY(vm_call(oc, true));
|
||||
break;
|
||||
case Oc::Self:
|
||||
TRY(vm_self(oc));
|
||||
|
|
|
@ -45,11 +45,11 @@ class VM {
|
|||
Result<void> vm_mul(Opcode& oc);
|
||||
Result<void> vm_sub(Opcode& oc);
|
||||
Result<void> vm_div(Opcode& oc);
|
||||
Result<void> vm_call(Opcode& oc);
|
||||
Result<void> vm_call(Opcode& oc, bool tail);
|
||||
Result<void> vm_self(Opcode& oc);
|
||||
Result<void> vm_call_lisp(Opcode& oc, Function& fun);
|
||||
Result<void> vm_call_stdlib(Opcode& oc, StdlibFunction& fun);
|
||||
Result<void> vm_call_cont(Opcode& oc, Continuation& fun);
|
||||
Result<void> vm_call_lisp(Opcode& oc, Function& fun, bool tail);
|
||||
Result<void> vm_call_stdlib(Opcode& oc, StdlibFunction& fun, bool tail);
|
||||
Result<void> vm_call_cont(Opcode& oc, Continuation& fun, bool tail);
|
||||
Result<void> vm_ret(Opcode& oc);
|
||||
|
||||
Result<void> vm_equal(Opcode& oc);
|
||||
|
|
Loading…
Reference in a new issue