Implement VM opcode for closure capture

This commit is contained in:
Konstantin Nazarov 2024-08-21 23:45:19 +01:00
parent edc0a89ed9
commit 6fcd231694
Signed by: knazarov
GPG key ID: 4CFE0A42FA409C22
7 changed files with 99 additions and 25 deletions

View file

@ -203,6 +203,7 @@ Result<PodObject*> Arena::gc_function(PodFunction* obj) {
nobj->name = TRY(gc_pod(obj->name.get())); nobj->name = TRY(gc_pod(obj->name.get()));
nobj->constants = TRY(gc_pod(obj->constants.get())); nobj->constants = TRY(gc_pod(obj->constants.get()));
nobj->code = TRY(gc_pod(obj->code.get())); nobj->code = TRY(gc_pod(obj->code.get()));
nobj->closure = TRY(gc_pod(obj->closure.get()));
return nobj; return nobj;
} }

View file

@ -149,6 +149,11 @@ Result<short> Function::cmp(const Function& rhs) const {
res = TRY(lhs_code.cmp(rhs_code)); res = TRY(lhs_code.cmp(rhs_code));
if (res != 0) return res; if (res != 0) return res;
auto lhs_closure = TRY(closure());
auto rhs_closure = TRY(rhs.closure());
res = TRY(lhs_closure.cmp(rhs_closure));
if (res != 0) return res;
auto lhs_constants = TRY(constants()); auto lhs_constants = TRY(constants());
auto rhs_constants = TRY(rhs.constants()); auto rhs_constants = TRY(rhs.constants());
res = TRY(lhs_constants.cmp(rhs_constants)); res = TRY(lhs_constants.cmp(rhs_constants));
@ -205,13 +210,28 @@ Result<Value> Pair::rest() {
} }
Result<Function> Function::create(const Value& name, uint64_t arity, Result<Function> Function::create(const Value& name, uint64_t arity,
const Array& constants, const Array& code) { const Array& constants, const Array& code,
const Array& closure) {
auto pod = TRY(arena_alloc<PodFunction>()); auto pod = TRY(arena_alloc<PodFunction>());
pod->header.tag = Tag::Function; pod->header.tag = Tag::Function;
pod->arity = arity; pod->arity = arity;
pod->name = name.pod(); pod->name = name.pod();
pod->constants = constants.pod(); pod->constants = constants.pod();
pod->code = code.pod(); pod->code = code.pod();
pod->closure = closure.pod();
return Function(TRY(MkGcRoot(pod)));
}
Result<Function> Function::create(const Function& prototype,
const Array& closure) {
auto pod = TRY(arena_alloc<PodFunction>());
pod->header.tag = Tag::Function;
pod->arity = prototype._value->arity;
pod->name = prototype._value->name;
pod->constants = prototype._value->constants;
pod->code = prototype._value->code;
pod->closure = closure.pod();
return Function(TRY(MkGcRoot(pod))); return Function(TRY(MkGcRoot(pod)));
} }
@ -228,6 +248,10 @@ Result<Array> Function::constants() const {
return Array::create((PodArray*)_value->constants.get()); return Array::create((PodArray*)_value->constants.get());
} }
Result<Array> Function::closure() const {
return Array::create((PodArray*)_value->closure.get());
}
Result<Value> reverse(Value& val) { Result<Value> reverse(Value& val) {
if (val.is<Nil>()) return Value(TRY(Nil::create())); if (val.is<Nil>()) return Value(TRY(Nil::create()));
if (!val.is<Pair>()) return ERROR(TypeMismatch); if (!val.is<Pair>()) return ERROR(TypeMismatch);

View file

@ -867,14 +867,19 @@ class Function : public Object {
return Function(TRY(MkGcRoot(obj))); return Function(TRY(MkGcRoot(obj)));
} }
static Result<Function> create(const Function& prototype,
const Array& closure);
static Result<Function> create(const Value& name, uint64_t arity, static Result<Function> create(const Value& name, uint64_t arity,
const Array& constants, const Array& code); const Array& constants, const Array& code,
const Array& closure);
uint64_t arity() const { return _value->arity; } uint64_t arity() const { return _value->arity; }
Result<Value> name() const; Result<Value> name() const;
Result<Array> code() const; Result<Array> code() const;
Result<Array> constants() const; Result<Array> constants() const;
Result<Array> closure() const;
virtual Result<Value> copy_value() const final; virtual Result<Value> copy_value() const final;
Result<Function> copy() const; Result<Function> copy() const;
@ -922,6 +927,18 @@ class Stack : public Object {
return Stack(TRY(MkGcRoot(pod))); return Stack(TRY(MkGcRoot(pod)));
} }
Result<Array> slice(uint64_t start, uint64_t end) {
if (start > end || end > gettop()) return ERROR(IndexOutOfRange);
uint64_t res_size = end - start;
auto pod = TRY(arena_alloc<PodArray>(res_size * sizeof(PodObject*)));
pod->size = res_size;
for (uint64_t i = 0; i < end - start; i++) {
pod->data[i] = _value->data[start + i];
}
return Array(TRY(MkGcRoot(pod)));
}
virtual Result<Value> copy_value() const final; virtual Result<Value> copy_value() const final;
Result<Stack> copy() const; Result<Stack> copy() const;

View file

@ -6,14 +6,16 @@ struct Context {
Context() {} Context() {}
Context(Value&& env, Array&& constants, Dict&& constants_dict, Context(Value&& env, Array&& constants, Dict&& constants_dict,
Dict&& variables_dict, Array&& closures, Dict&& closures_dict, Context* parent) Dict&& variables_dict, Array&& closures, Dict&& closures_dict,
Context* parent)
: env(std::move(env)), : env(std::move(env)),
constants(std::move(constants)), constants(std::move(constants)),
constants_dict(std::move(constants_dict)), constants_dict(std::move(constants_dict)),
variables_dict(std::move(variables_dict)), variables_dict(std::move(variables_dict)),
closures(std::move(closures)), closures(std::move(closures)),
closures_dict(std::move(closures_dict)), closures_dict(std::move(closures_dict)),
maxreg(0), parent(parent) {} maxreg(0),
parent(parent) {}
static Result<Context> create() { static Result<Context> create() {
auto env = TRY(Nil::create()); auto env = TRY(Nil::create());
@ -24,7 +26,8 @@ struct Context {
auto closures_dict = TRY(Dict::create()); auto closures_dict = TRY(Dict::create());
return Context(std::move(env), std::move(constants), return Context(std::move(env), std::move(constants),
std::move(constants_dict), std::move(variables_dict), std::move(closures), std::move(closures_dict), 0); std::move(constants_dict), std::move(variables_dict),
std::move(closures), std::move(closures_dict), 0);
} }
static Result<Context> create(Context& parent) { static Result<Context> create(Context& parent) {
@ -36,7 +39,8 @@ struct Context {
auto closures_dict = TRY(Dict::create()); auto closures_dict = TRY(Dict::create());
return Context(std::move(env), std::move(constants), return Context(std::move(env), std::move(constants),
std::move(constants_dict), std::move(variables_dict), std::move(closures), std::move(closures_dict), &parent); std::move(constants_dict), std::move(variables_dict),
std::move(closures), std::move(closures_dict), &parent);
} }
uint64_t alloc_reg() { uint64_t alloc_reg() {
@ -146,7 +150,8 @@ Result<Value> Compiler::compile(Value& expr) {
// TRY(debug_print(context.constants)); // TRY(debug_print(context.constants));
// TRY(debug_print(ex.code)); // TRY(debug_print(ex.code));
auto fun = TRY(Function::create(name, 0, context.constants, ex.code)); auto fun = TRY(Function::create(name, 0, context.constants, ex.code,
TRY(Array::create())));
return Value(std::move(fun)); return Value(std::move(fun));
} }
@ -353,7 +358,8 @@ Result<Expression> Compiler::compile_lambda(Context& context, Symbol& op,
TRY(ex.add_opcode(Oc::Ret, {0, (int64_t)0})); TRY(ex.add_opcode(Oc::Ret, {0, (int64_t)0}));
Value name = TRY(Nil::create()); Value name = TRY(Nil::create());
auto fun = TRY(Function::create(name, arity, ctx.constants, ex.code)); auto fun = TRY(Function::create(name, arity, ctx.constants, ex.code,
TRY(Array::create())));
// std::cout << "--------------- LAMBDA " << arity << "\n"; // std::cout << "--------------- LAMBDA " << arity << "\n";
// TRY(debug_print(TRY(expr.copy_value()))); // TRY(debug_print(TRY(expr.copy_value())));
@ -384,7 +390,8 @@ Result<Expression> Compiler::compile_lambda(Context& context, Symbol& op,
TRY(ex_res.add_opcode(Oc::Mov, {0, (int64_t)vr}, {0, (int64_t)var_reg})); TRY(ex_res.add_opcode(Oc::Mov, {0, (int64_t)vr}, {0, (int64_t)var_reg}));
} }
TRY(ex_res.add_opcode(Oc::MakeClosure, {0, (int64_t)reg}, {0, (int64_t)reg + (int64_t)ctx.closures.size() + 1})); TRY(ex_res.add_opcode(Oc::MakeClosure, {0, (int64_t)reg},
{0, (int64_t)reg + (int64_t)ctx.closures.size() + 1}));
context.maxreg = reg; context.maxreg = reg;
ex_res.reg = reg; ex_res.reg = reg;
@ -512,7 +519,8 @@ Result<Expression> Compiler::compile_symbol(Context& context, Symbol& value) {
auto var_closure = maybe_closure.value(); auto var_closure = maybe_closure.value();
uint64_t reg = context.alloc_reg(); uint64_t reg = context.alloc_reg();
TRY(ex.add_opcode(Oc::ClosureLoad, {0, (int64_t)reg}, {0, (int64_t)var_closure})); TRY(ex.add_opcode(Oc::ClosureLoad, {0, (int64_t)reg},
{0, (int64_t)var_closure}));
ex.reg = reg; ex.reg = reg;
return std::move(ex); return std::move(ex);

View file

@ -66,33 +66,33 @@ struct PodObject {
class PodNil final : public PodObject { class PodNil final : public PodObject {
public: public:
PodNil() : PodObject(Tag::Nil){}; PodNil() : PodObject(Tag::Nil) {};
}; };
class PodInt64 final : public PodObject { class PodInt64 final : public PodObject {
public: public:
PodInt64() : PodObject(Tag::Int64){}; PodInt64() : PodObject(Tag::Int64) {};
int64_t value; int64_t value;
}; };
class PodFloat final : public PodObject { class PodFloat final : public PodObject {
public: public:
PodFloat() : PodObject(Tag::Float){}; PodFloat() : PodObject(Tag::Float) {};
double value; double value;
}; };
class PodBool final : public PodObject { class PodBool final : public PodObject {
public: public:
PodBool() : PodObject(Tag::Bool){}; PodBool() : PodObject(Tag::Bool) {};
bool value; bool value;
}; };
class PodArray final : public PodObject { class PodArray final : public PodObject {
public: public:
PodArray() : PodObject(Tag::Array){}; PodArray() : PodObject(Tag::Array) {};
uint64_t size; uint64_t size;
OffPtr<PodObject> data[]; OffPtr<PodObject> data[];
@ -100,7 +100,7 @@ class PodArray final : public PodObject {
class PodByteArray final : public PodObject { class PodByteArray final : public PodObject {
public: public:
PodByteArray() : PodObject(Tag::ByteArray){}; PodByteArray() : PodObject(Tag::ByteArray) {};
uint64_t size; uint64_t size;
char data[]; char data[];
@ -108,7 +108,7 @@ class PodByteArray final : public PodObject {
class PodDict final : public PodObject { class PodDict final : public PodObject {
public: public:
PodDict() : PodObject(Tag::Dict){}; PodDict() : PodObject(Tag::Dict) {};
uint64_t size; uint64_t size;
OffPtr<PodObject> data[]; OffPtr<PodObject> data[];
@ -116,7 +116,7 @@ class PodDict final : public PodObject {
class PodString final : public PodObject { class PodString final : public PodObject {
public: public:
PodString() : PodObject(Tag::String){}; PodString() : PodObject(Tag::String) {};
uint64_t size; uint64_t size;
char32_t data[]; char32_t data[];
@ -124,7 +124,7 @@ class PodString final : public PodObject {
class PodSymbol final : public PodObject { class PodSymbol final : public PodObject {
public: public:
PodSymbol() : PodObject(Tag::Symbol){}; PodSymbol() : PodObject(Tag::Symbol) {};
uint64_t size; uint64_t size;
char32_t data[]; char32_t data[];
@ -132,7 +132,7 @@ class PodSymbol final : public PodObject {
class PodSyntax : public PodObject { class PodSyntax : public PodObject {
public: public:
PodSyntax() : PodObject(Tag::Syntax){}; PodSyntax() : PodObject(Tag::Syntax) {};
OffPtr<PodString> filename; OffPtr<PodString> filename;
OffPtr<PodString> modulename; OffPtr<PodString> modulename;
OffPtr<PodObject> expression; OffPtr<PodObject> expression;
@ -141,14 +141,14 @@ class PodSyntax : public PodObject {
class PodPair : public PodObject { class PodPair : public PodObject {
public: public:
PodPair() : PodObject(Tag::Pair){}; PodPair() : PodObject(Tag::Pair) {};
OffPtr<PodObject> first; OffPtr<PodObject> first;
OffPtr<PodObject> rest; OffPtr<PodObject> rest;
}; };
class PodOpcode final : public PodObject { class PodOpcode final : public PodObject {
public: public:
PodOpcode() : PodObject(Tag::Opcode){}; PodOpcode() : PodObject(Tag::Opcode) {};
Oc opcode; Oc opcode;
OpArg arg1; OpArg arg1;
@ -159,17 +159,18 @@ class PodOpcode final : public PodObject {
class PodFunction final : public PodObject { class PodFunction final : public PodObject {
public: public:
PodFunction() : PodObject(Tag::Function){}; PodFunction() : PodObject(Tag::Function) {};
uint64_t arity; uint64_t arity;
OffPtr<PodObject> name; OffPtr<PodObject> name;
OffPtr<PodObject> constants; OffPtr<PodObject> constants;
OffPtr<PodObject> code; OffPtr<PodObject> code;
OffPtr<PodObject> closure;
}; };
class PodStack final : public PodObject { class PodStack final : public PodObject {
public: public:
PodStack() : PodObject(Tag::Stack){}; PodStack() : PodObject(Tag::Stack) {};
uint64_t size; uint64_t size;
uint64_t top; uint64_t top;
@ -180,7 +181,7 @@ template <class T>
class Ptr { class Ptr {
public: public:
Ptr() : _ptr(0) {} Ptr() : _ptr(0) {}
Ptr(T* ptr) : _ptr(ptr){}; Ptr(T* ptr) : _ptr(ptr) {};
T* get() { return _ptr; } T* get() { return _ptr; }
T& operator*() { return *_ptr; } T& operator*() { return *_ptr; }

View file

@ -137,6 +137,24 @@ Result<void> VM::vm_jump(Opcode& oc) {
return Result<void>(); return Result<void>();
} }
Result<void> VM::vm_make_closure(Opcode& oc) {
uint64_t begin = (uint64_t)_base + oc.arg1().arg + 1;
uint64_t end = (uint64_t)_base + oc.arg2().arg;
Value fun_proto = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg));
if (!fun_proto.is<Function>()) return ERROR(TypeMismatch);
auto closure = TRY(_stack.slice(begin, end));
Value fun = TRY(Function::create(*fun_proto.to<Function>(), closure));
setreg(oc.arg1().arg, fun);
_pc++;
return Result<void>();
}
Result<void> VM::step() { Result<void> VM::step() {
auto opcode = TRY(_code.get(_pc)); auto opcode = TRY(_code.get(_pc));
if (!opcode.is<Opcode>()) return ERROR(TypeMismatch); if (!opcode.is<Opcode>()) return ERROR(TypeMismatch);
@ -170,6 +188,9 @@ Result<void> VM::step() {
case Oc::Call: case Oc::Call:
TRY(vm_call(oc)); TRY(vm_call(oc));
break; break;
case Oc::MakeClosure:
TRY(vm_make_closure(oc));
break;
default: default:
return ERROR(NotImplemented); return ERROR(NotImplemented);
} }

View file

@ -31,6 +31,8 @@ class VM {
Result<void> vm_jump_not_equal(Opcode& oc); Result<void> vm_jump_not_equal(Opcode& oc);
Result<void> vm_jump(Opcode& oc); Result<void> vm_jump(Opcode& oc);
Result<void> vm_make_closure(Opcode& oc);
Result<Value> get(bool is_const, uint64_t idx); Result<Value> get(bool is_const, uint64_t idx);
Result<Value> getconst(uint64_t idx); Result<Value> getconst(uint64_t idx);
Result<Value> getreg(uint64_t idx); Result<Value> getreg(uint64_t idx);