From 6fcd2316949cf687717cb7878506768ce8aa9e9f Mon Sep 17 00:00:00 2001 From: Konstantin Nazarov Date: Wed, 21 Aug 2024 23:45:19 +0100 Subject: [PATCH] Implement VM opcode for closure capture --- src/arena.cpp | 1 + src/common.cpp | 26 +++++++++++++++++++++++++- src/common.hpp | 19 ++++++++++++++++++- src/compiler.cpp | 24 ++++++++++++++++-------- src/pod.hpp | 31 ++++++++++++++++--------------- src/vm.cpp | 21 +++++++++++++++++++++ src/vm.hpp | 2 ++ 7 files changed, 99 insertions(+), 25 deletions(-) diff --git a/src/arena.cpp b/src/arena.cpp index 52860b5..6000574 100644 --- a/src/arena.cpp +++ b/src/arena.cpp @@ -203,6 +203,7 @@ Result Arena::gc_function(PodFunction* obj) { nobj->name = TRY(gc_pod(obj->name.get())); nobj->constants = TRY(gc_pod(obj->constants.get())); nobj->code = TRY(gc_pod(obj->code.get())); + nobj->closure = TRY(gc_pod(obj->closure.get())); return nobj; } diff --git a/src/common.cpp b/src/common.cpp index 7ce007a..f64b55d 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -149,6 +149,11 @@ Result Function::cmp(const Function& rhs) const { res = TRY(lhs_code.cmp(rhs_code)); if (res != 0) return res; + auto lhs_closure = TRY(closure()); + auto rhs_closure = TRY(rhs.closure()); + res = TRY(lhs_closure.cmp(rhs_closure)); + if (res != 0) return res; + auto lhs_constants = TRY(constants()); auto rhs_constants = TRY(rhs.constants()); res = TRY(lhs_constants.cmp(rhs_constants)); @@ -205,13 +210,28 @@ Result Pair::rest() { } Result Function::create(const Value& name, uint64_t arity, - const Array& constants, const Array& code) { + const Array& constants, const Array& code, + const Array& closure) { auto pod = TRY(arena_alloc()); pod->header.tag = Tag::Function; pod->arity = arity; pod->name = name.pod(); pod->constants = constants.pod(); pod->code = code.pod(); + pod->closure = closure.pod(); + + return Function(TRY(MkGcRoot(pod))); +} + +Result Function::create(const Function& prototype, + const Array& closure) { + auto pod = TRY(arena_alloc()); + pod->header.tag = Tag::Function; + pod->arity = prototype._value->arity; + pod->name = prototype._value->name; + pod->constants = prototype._value->constants; + pod->code = prototype._value->code; + pod->closure = closure.pod(); return Function(TRY(MkGcRoot(pod))); } @@ -228,6 +248,10 @@ Result Function::constants() const { return Array::create((PodArray*)_value->constants.get()); } +Result Function::closure() const { + return Array::create((PodArray*)_value->closure.get()); +} + Result reverse(Value& val) { if (val.is()) return Value(TRY(Nil::create())); if (!val.is()) return ERROR(TypeMismatch); diff --git a/src/common.hpp b/src/common.hpp index a591b62..a8710b7 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -867,14 +867,19 @@ class Function : public Object { return Function(TRY(MkGcRoot(obj))); } + static Result create(const Function& prototype, + const Array& closure); + static Result create(const Value& name, uint64_t arity, - const Array& constants, const Array& code); + const Array& constants, const Array& code, + const Array& closure); uint64_t arity() const { return _value->arity; } Result name() const; Result code() const; Result constants() const; + Result closure() const; virtual Result copy_value() const final; Result copy() const; @@ -922,6 +927,18 @@ class Stack : public Object { return Stack(TRY(MkGcRoot(pod))); } + Result slice(uint64_t start, uint64_t end) { + if (start > end || end > gettop()) return ERROR(IndexOutOfRange); + uint64_t res_size = end - start; + auto pod = TRY(arena_alloc(res_size * sizeof(PodObject*))); + pod->size = res_size; + for (uint64_t i = 0; i < end - start; i++) { + pod->data[i] = _value->data[start + i]; + } + + return Array(TRY(MkGcRoot(pod))); + } + virtual Result copy_value() const final; Result copy() const; diff --git a/src/compiler.cpp b/src/compiler.cpp index a963276..757bc5a 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -6,14 +6,16 @@ struct Context { Context() {} Context(Value&& env, Array&& constants, Dict&& constants_dict, - Dict&& variables_dict, Array&& closures, Dict&& closures_dict, Context* parent) + Dict&& variables_dict, Array&& closures, Dict&& closures_dict, + Context* parent) : env(std::move(env)), constants(std::move(constants)), constants_dict(std::move(constants_dict)), variables_dict(std::move(variables_dict)), closures(std::move(closures)), closures_dict(std::move(closures_dict)), - maxreg(0), parent(parent) {} + maxreg(0), + parent(parent) {} static Result create() { auto env = TRY(Nil::create()); @@ -24,7 +26,8 @@ struct Context { auto closures_dict = TRY(Dict::create()); return Context(std::move(env), std::move(constants), - std::move(constants_dict), std::move(variables_dict), std::move(closures), std::move(closures_dict), 0); + std::move(constants_dict), std::move(variables_dict), + std::move(closures), std::move(closures_dict), 0); } static Result create(Context& parent) { @@ -36,7 +39,8 @@ struct Context { auto closures_dict = TRY(Dict::create()); return Context(std::move(env), std::move(constants), - std::move(constants_dict), std::move(variables_dict), std::move(closures), std::move(closures_dict), &parent); + std::move(constants_dict), std::move(variables_dict), + std::move(closures), std::move(closures_dict), &parent); } uint64_t alloc_reg() { @@ -146,7 +150,8 @@ Result Compiler::compile(Value& expr) { // TRY(debug_print(context.constants)); // TRY(debug_print(ex.code)); - auto fun = TRY(Function::create(name, 0, context.constants, ex.code)); + auto fun = TRY(Function::create(name, 0, context.constants, ex.code, + TRY(Array::create()))); return Value(std::move(fun)); } @@ -353,7 +358,8 @@ Result Compiler::compile_lambda(Context& context, Symbol& op, TRY(ex.add_opcode(Oc::Ret, {0, (int64_t)0})); Value name = TRY(Nil::create()); - auto fun = TRY(Function::create(name, arity, ctx.constants, ex.code)); + auto fun = TRY(Function::create(name, arity, ctx.constants, ex.code, + TRY(Array::create()))); // std::cout << "--------------- LAMBDA " << arity << "\n"; // TRY(debug_print(TRY(expr.copy_value()))); @@ -384,7 +390,8 @@ Result Compiler::compile_lambda(Context& context, Symbol& op, TRY(ex_res.add_opcode(Oc::Mov, {0, (int64_t)vr}, {0, (int64_t)var_reg})); } - TRY(ex_res.add_opcode(Oc::MakeClosure, {0, (int64_t)reg}, {0, (int64_t)reg + (int64_t)ctx.closures.size() + 1})); + TRY(ex_res.add_opcode(Oc::MakeClosure, {0, (int64_t)reg}, + {0, (int64_t)reg + (int64_t)ctx.closures.size() + 1})); context.maxreg = reg; ex_res.reg = reg; @@ -512,7 +519,8 @@ Result Compiler::compile_symbol(Context& context, Symbol& value) { auto var_closure = maybe_closure.value(); uint64_t reg = context.alloc_reg(); - TRY(ex.add_opcode(Oc::ClosureLoad, {0, (int64_t)reg}, {0, (int64_t)var_closure})); + TRY(ex.add_opcode(Oc::ClosureLoad, {0, (int64_t)reg}, + {0, (int64_t)var_closure})); ex.reg = reg; return std::move(ex); diff --git a/src/pod.hpp b/src/pod.hpp index c9f0257..c1450b9 100644 --- a/src/pod.hpp +++ b/src/pod.hpp @@ -66,33 +66,33 @@ struct PodObject { class PodNil final : public PodObject { public: - PodNil() : PodObject(Tag::Nil){}; + PodNil() : PodObject(Tag::Nil) {}; }; class PodInt64 final : public PodObject { public: - PodInt64() : PodObject(Tag::Int64){}; + PodInt64() : PodObject(Tag::Int64) {}; int64_t value; }; class PodFloat final : public PodObject { public: - PodFloat() : PodObject(Tag::Float){}; + PodFloat() : PodObject(Tag::Float) {}; double value; }; class PodBool final : public PodObject { public: - PodBool() : PodObject(Tag::Bool){}; + PodBool() : PodObject(Tag::Bool) {}; bool value; }; class PodArray final : public PodObject { public: - PodArray() : PodObject(Tag::Array){}; + PodArray() : PodObject(Tag::Array) {}; uint64_t size; OffPtr data[]; @@ -100,7 +100,7 @@ class PodArray final : public PodObject { class PodByteArray final : public PodObject { public: - PodByteArray() : PodObject(Tag::ByteArray){}; + PodByteArray() : PodObject(Tag::ByteArray) {}; uint64_t size; char data[]; @@ -108,7 +108,7 @@ class PodByteArray final : public PodObject { class PodDict final : public PodObject { public: - PodDict() : PodObject(Tag::Dict){}; + PodDict() : PodObject(Tag::Dict) {}; uint64_t size; OffPtr data[]; @@ -116,7 +116,7 @@ class PodDict final : public PodObject { class PodString final : public PodObject { public: - PodString() : PodObject(Tag::String){}; + PodString() : PodObject(Tag::String) {}; uint64_t size; char32_t data[]; @@ -124,7 +124,7 @@ class PodString final : public PodObject { class PodSymbol final : public PodObject { public: - PodSymbol() : PodObject(Tag::Symbol){}; + PodSymbol() : PodObject(Tag::Symbol) {}; uint64_t size; char32_t data[]; @@ -132,7 +132,7 @@ class PodSymbol final : public PodObject { class PodSyntax : public PodObject { public: - PodSyntax() : PodObject(Tag::Syntax){}; + PodSyntax() : PodObject(Tag::Syntax) {}; OffPtr filename; OffPtr modulename; OffPtr expression; @@ -141,14 +141,14 @@ class PodSyntax : public PodObject { class PodPair : public PodObject { public: - PodPair() : PodObject(Tag::Pair){}; + PodPair() : PodObject(Tag::Pair) {}; OffPtr first; OffPtr rest; }; class PodOpcode final : public PodObject { public: - PodOpcode() : PodObject(Tag::Opcode){}; + PodOpcode() : PodObject(Tag::Opcode) {}; Oc opcode; OpArg arg1; @@ -159,17 +159,18 @@ class PodOpcode final : public PodObject { class PodFunction final : public PodObject { public: - PodFunction() : PodObject(Tag::Function){}; + PodFunction() : PodObject(Tag::Function) {}; uint64_t arity; OffPtr name; OffPtr constants; OffPtr code; + OffPtr closure; }; class PodStack final : public PodObject { public: - PodStack() : PodObject(Tag::Stack){}; + PodStack() : PodObject(Tag::Stack) {}; uint64_t size; uint64_t top; @@ -180,7 +181,7 @@ template class Ptr { public: Ptr() : _ptr(0) {} - Ptr(T* ptr) : _ptr(ptr){}; + Ptr(T* ptr) : _ptr(ptr) {}; T* get() { return _ptr; } T& operator*() { return *_ptr; } diff --git a/src/vm.cpp b/src/vm.cpp index bd6ad67..6a9b38f 100644 --- a/src/vm.cpp +++ b/src/vm.cpp @@ -137,6 +137,24 @@ Result VM::vm_jump(Opcode& oc) { return Result(); } +Result VM::vm_make_closure(Opcode& oc) { + uint64_t begin = (uint64_t)_base + oc.arg1().arg + 1; + uint64_t end = (uint64_t)_base + oc.arg2().arg; + + Value fun_proto = TRY(get(oc.arg1().is_const, (uint64_t)oc.arg1().arg)); + if (!fun_proto.is()) return ERROR(TypeMismatch); + + auto closure = TRY(_stack.slice(begin, end)); + + Value fun = TRY(Function::create(*fun_proto.to(), closure)); + + setreg(oc.arg1().arg, fun); + + _pc++; + + return Result(); +} + Result VM::step() { auto opcode = TRY(_code.get(_pc)); if (!opcode.is()) return ERROR(TypeMismatch); @@ -170,6 +188,9 @@ Result VM::step() { case Oc::Call: TRY(vm_call(oc)); break; + case Oc::MakeClosure: + TRY(vm_make_closure(oc)); + break; default: return ERROR(NotImplemented); } diff --git a/src/vm.hpp b/src/vm.hpp index 4d35f6e..6d78b8e 100644 --- a/src/vm.hpp +++ b/src/vm.hpp @@ -31,6 +31,8 @@ class VM { Result vm_jump_not_equal(Opcode& oc); Result vm_jump(Opcode& oc); + Result vm_make_closure(Opcode& oc); + Result get(bool is_const, uint64_t idx); Result getconst(uint64_t idx); Result getreg(uint64_t idx);