Implement VM opcode for closure capture

This commit is contained in:
Konstantin Nazarov 2024-08-21 23:45:19 +01:00
parent edc0a89ed9
commit 6fcd231694
Signed by: knazarov
GPG key ID: 4CFE0A42FA409C22
7 changed files with 99 additions and 25 deletions

View file

@ -203,6 +203,7 @@ Result<PodObject*> Arena::gc_function(PodFunction* obj) {
nobj->name = TRY(gc_pod(obj->name.get())); nobj->name = TRY(gc_pod(obj->name.get()));
nobj->constants = TRY(gc_pod(obj->constants.get())); nobj->constants = TRY(gc_pod(obj->constants.get()));
nobj->code = TRY(gc_pod(obj->code.get())); nobj->code = TRY(gc_pod(obj->code.get()));
nobj->closure = TRY(gc_pod(obj->closure.get()));
return nobj; return nobj;
} }

View file

@ -149,6 +149,11 @@ Result<short> Function::cmp(const Function& rhs) const {
res = TRY(lhs_code.cmp(rhs_code)); res = TRY(lhs_code.cmp(rhs_code));
if (res != 0) return res; if (res != 0) return res;
auto lhs_closure = TRY(closure());
auto rhs_closure = TRY(rhs.closure());
res = TRY(lhs_closure.cmp(rhs_closure));
if (res != 0) return res;
auto lhs_constants = TRY(constants()); auto lhs_constants = TRY(constants());
auto rhs_constants = TRY(rhs.constants()); auto rhs_constants = TRY(rhs.constants());
res = TRY(lhs_constants.cmp(rhs_constants)); res = TRY(lhs_constants.cmp(rhs_constants));
@ -205,13 +210,28 @@ Result<Value> Pair::rest() {
} }
Result<Function> Function::create(const Value& name, uint64_t arity, Result<Function> Function::create(const Value& name, uint64_t arity,
const Array& constants, const Array& code) { const Array& constants, const Array& code,
const Array& closure) {
auto pod = TRY(arena_alloc<PodFunction>()); auto pod = TRY(arena_alloc<PodFunction>());
pod->header.tag = Tag::Function; pod->header.tag = Tag::Function;
pod->arity = arity; pod->arity = arity;
pod->name = name.pod(); pod->name = name.pod();
pod->constants = constants.pod(); pod->constants = constants.pod();
pod->code = code.pod(); pod->code = code.pod();
pod->closure = closure.pod();
return Function(TRY(MkGcRoot(pod)));
}
Result<Function> Function::create(const Function& prototype,
const Array& closure) {
auto pod = TRY(arena_alloc<PodFunction>());
pod->header.tag = Tag::Function;
pod->arity = prototype._value->arity;
pod->name = prototype._value->name;
pod->constants = prototype._value->constants;
pod->code = prototype._value->code;
pod->closure = closure.pod();
return Function(TRY(MkGcRoot(pod))); return Function(TRY(MkGcRoot(pod)));
} }
@ -228,6 +248,10 @@ Result<Array> Function::constants() const {
return Array::create((PodArray*)_value->constants.get()); return Array::create((PodArray*)_value->constants.get());
} }
Result<Array> Function::closure() const {
return Array::create((PodArray*)_value->closure.get());
}
Result<Value> reverse(Value& val) { Result<Value> reverse(Value& val) {
if (val.is<Nil>()) return Value(TRY(Nil::create())); if (val.is<Nil>()) return Value(TRY(Nil::create()));
if (!val.is<Pair>()) return ERROR(TypeMismatch); if (!val.is<Pair>()) return ERROR(TypeMismatch);

View file

@ -867,14 +867,19 @@ class Function : public Object {
return Function(TRY(MkGcRoot(obj))); return Function(TRY(MkGcRoot(obj)));
} }
static Result<Function> create(const Function& prototype,
const Array& closure);
static Result<Function> create(const Value& name, uint64_t arity, static Result<Function> create(const Value& name, uint64_t arity,
const Array& constants, const Array& code); const Array& constants, const Array& code,
const Array& closure);
uint64_t arity() const { return _value->arity; } uint64_t arity() const { return _value->arity; }
Result<Value> name() const; Result<Value> name() const;
Result<Array> code() const; Result<Array> code() const;
Result<Array> constants() const; Result<Array> constants() const;
Result<Array> closure() const;
virtual Result<Value> copy_value() const final; virtual Result<Value> copy_value() const final;
Result<Function> copy() const; Result<Function> copy() const;
@ -922,6 +927,18 @@ class Stack : public Object {
return Stack(TRY(MkGcRoot(pod))); return Stack(TRY(MkGcRoot(pod)));
} }
Result<Array> slice(uint64_t start, uint64_t end) {
if (start > end || end > gettop()) return ERROR(IndexOutOfRange);
uint64_t res_size = end - start;
auto pod = TRY(arena_alloc<PodArray>(res_size * sizeof(PodObject*)));
pod->size = res_size;
for (uint64_t i = 0; i < end - start; i++) {
pod->data[i] = _value->data[start + i];
}
return Array(TRY(MkGcRoot(pod)));
}
virtual Result<Value> copy_value() const final; virtual Result<Value> copy_value() const final;
Result<Stack> copy() const; Result<Stack> copy() const;

View file

@ -6,14 +6,16 @@ struct Context {
Context() {} Context() {}
Context(Value&& env, Array&& constants, Dict&& constants_dict, Context(Value&& env, Array&& constants, Dict&& constants_dict,
Dict&& variables_dict, Array&& closures, Dict&& closures_dict, Context* parent) Dict&& variables_dict, Array&& closures, Dict&& closures_dict,
Context* parent)
: env(std::move(env)), : env(std::move(env)),
constants(std::move(constants)), constants(std::move(constants)),
constants_dict(std::move(constants_dict)), constants_dict(std::move(constants_dict)),
variables_dict(std::move(variables_dict)), variables_dict(std::move(variables_dict)),
closures(std::move(closures)), closures(std::move(closures)),
closures_dict(std::move(closures_dict)), closures_dict(std::move(closures_dict)),
maxreg(0), parent(parent) {} maxreg(0),
parent(parent) {}
static Result<Context> create() { static Result<Context> create() {
auto env = TRY(Nil::create()); auto env = TRY(Nil::create());
@ -24,7 +26,8 @@ struct Context {
auto closures_dict = TRY(Dict::create()); auto closures_dict = TRY(Dict::create());
return Context(std::move(env), std::move(constants), return Context(std::move(env), std::move(constants),
std::move(constants_dict), std::move(variables_dict), std::move(closures), std::move(closures_dict), 0); std::move(constants_dict), std::move(variables_dict),
std::move(closures), std::move(closures_dict), 0);
} }
static Result<Context> create(Context& parent) { static Result<Context> create(Context& parent) {
@ -36,7 +39,8 @@ struct Context {
auto closures_dict = TRY(Dict::create()); auto closures_dict = TRY(Dict::create());
return Context(std::move(env), std::move(constants), return Context(std::move(env), std::move(constants),
std::move(constants_dict), std::move(variables_dict), std::move(closures), std::move(closures_dict), &parent); std::move(constants_dict), std::move(variables_dict),
std::move(closures), std::move(closures_dict), &parent);
} }
uint64_t alloc_reg() { uint64_t alloc_reg() {
@ -146,7 +150,8 @@ Result<Value> Compiler::compile(Value& expr) {
// TRY(debug_print(context.constants)); // TRY(debug_print(context.constants));
// TRY(debug_print(ex.code)); // TRY(debug_print(ex.code));
auto fun = TRY(Function::create(name, 0, context.constants, ex.code)); auto fun = TRY(Function::create(name, 0, context.constants, ex.code,
TRY(Array::create())));
return Value(std::move(fun)); return Value(std::move(fun));
} }
@ -353,7 +358,8 @@ Result<Expression> Compiler::compile_lambda(Context& context, Symbol& op,
TRY(ex.add_opcode(Oc::Ret, {0, (int64_t)0})); TRY(ex.add_opcode(Oc::Ret, {0, (int64_t)0}));
Value name = TRY(Nil::create()); Value name = TRY(Nil::create());
auto fun = TRY(Function::create(name, arity, ctx.constants, ex.code)); auto fun = TRY(Function::create(name, arity, ctx.constants, ex.code,
TRY(Array::create())));
// std::cout << "--------------- LAMBDA " << arity << "\n"; // std::cout << "--------------- LAMBDA " << arity << "\n";
// TRY(debug_print(TRY(expr.copy_value()))); // TRY(debug_print(TRY(expr.copy_value())));
@ -384,7 +390,8 @@ Result<Expression> Compiler::compile_lambda(Context& context, Symbol& op,
TRY(ex_res.add_opcode(Oc::Mov, {0, (int64_t)vr}, {0, (int64_t)var_reg})); TRY(ex_res.add_opcode(Oc::Mov, {0, (int64_t)vr}, {0, (int64_t)var_reg}));
} }
TRY(ex_res.add_opcode(Oc::MakeClosure, {0, (int64_t)reg}, {0, (int64_t)reg + (int64_t)ctx.closures.size() + 1})); TRY(ex_res.add_opcode(Oc::MakeClosure, {0, (int64_t)reg},
{0, (int64_t)reg + (int64_t)ctx.closures.size() + 1}));
context.maxreg = reg; context.maxreg = reg;
ex_res.reg = reg; ex_res.reg = reg;
@ -512,7 +519,8 @@ Result<Expression> Compiler::compile_symbol(Context& context, Symbol& value) {
auto var_closure = maybe_closure.value(); auto var_closure = maybe_closure.value();
uint64_t reg = context.alloc_reg(); uint64_t reg = context.alloc_reg();
TRY(ex.add_opcode(Oc::ClosureLoad, {0, (int64_t)reg}, {0, (int64_t)var_closure})); TRY(ex.add_opcode(Oc::ClosureLoad, {0, (int64_t)reg},
{0, (int64_t)var_closure}));
ex.reg = reg; ex.reg = reg;
return std::move(ex); return std::move(ex);

View file

@ -165,6 +165,7 @@ class PodFunction final : public PodObject {
OffPtr<PodObject> name; OffPtr<PodObject> name;
OffPtr<PodObject> constants; OffPtr<PodObject> constants;
OffPtr<PodObject> code; OffPtr<PodObject> code;
OffPtr<PodObject> closure;
}; };
class PodStack final : public PodObject { class PodStack final : public PodObject {

View file

@ -137,6 +137,24 @@ Result<void> VM::vm_jump(Opcode& oc) {
return Result<void>(); return Result<void>();
} }
Result<void> VM::vm_make_closure(Opcode& oc) {
uint64_t begin = (uint64_t)_base + oc.arg1().arg + 1;
uint64_t end = (uint64_t)_base + oc.arg2().arg;
Value fun_proto = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
if (!fun_proto.is<Function>()) return ERROR(TypeMismatch);
auto closure = TRY(_stack.slice(begin, end));
Value fun = TRY(Function::create(*fun_proto.to<Function>(), closure));
setreg(oc.arg1().arg, fun);
_pc++;
return Result<void>();
}
Result<void> VM::step() { Result<void> VM::step() {
auto opcode = TRY(_code.get(_pc)); auto opcode = TRY(_code.get(_pc));
if (!opcode.is<Opcode>()) return ERROR(TypeMismatch); if (!opcode.is<Opcode>()) return ERROR(TypeMismatch);
@ -170,6 +188,9 @@ Result<void> VM::step() {
case Oc::Call: case Oc::Call:
TRY(vm_call(oc)); TRY(vm_call(oc));
break; break;
case Oc::MakeClosure:
TRY(vm_make_closure(oc));
break;
default: default:
return ERROR(NotImplemented); return ERROR(NotImplemented);
} }

View file

@ -31,6 +31,8 @@ class VM {
Result<void> vm_jump_not_equal(Opcode& oc); Result<void> vm_jump_not_equal(Opcode& oc);
Result<void> vm_jump(Opcode& oc); Result<void> vm_jump(Opcode& oc);
Result<void> vm_make_closure(Opcode& oc);
Result<Value> get(bool is_const, uint64_t idx); Result<Value> get(bool is_const, uint64_t idx);
Result<Value> getconst(uint64_t idx); Result<Value> getconst(uint64_t idx);
Result<Value> getreg(uint64_t idx); Result<Value> getreg(uint64_t idx);