From 2e62a67490b52d722bbb2739a114134cf542e6c8 Mon Sep 17 00:00:00 2001 From: Konstantin Nazarov Date: Wed, 28 Aug 2024 23:57:01 +0100 Subject: [PATCH] Implement "let" construct in the compiler --- CMakeLists.txt | 1 + src/common.cpp | 47 ++++++++++++++++++++++++++++++-- src/common.hpp | 14 ++++++++-- src/compiler.cpp | 70 ++++++++++++++++++++++++++++++++++++++++++++++++ src/compiler.hpp | 1 + src/stdlib.cpp | 14 ++++++++++ src/stdlib.hpp | 2 ++ test/dict.vli | 6 +++++ 8 files changed, 151 insertions(+), 4 deletions(-) create mode 100644 test/dict.vli diff --git a/CMakeLists.txt b/CMakeLists.txt index 12f8951..5cb717f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,7 @@ set(CPP_TESTS set(LISP_TESTS numeric logic + dict ) diff --git a/src/common.cpp b/src/common.cpp index 33e82d1..3883360 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -111,6 +111,8 @@ Result Dict::copy_value() const { return Value(Dict(TRY(_value.copy()))); } +Result Dict::copy() const { return Dict(TRY(_value.copy())); } + Result String::copy_value() const { return Value(String(TRY(_value.copy()))); } @@ -127,6 +129,8 @@ Result Pair::copy_value() const { return Value(Pair(TRY(_value.copy()))); } +Result Pair::copy() const { return Pair(TRY(_value.copy())); } + Result Bool::copy_value() const { return Value(Bool(TRY(_value.copy()))); } @@ -239,15 +243,49 @@ Result Pair::create(Value& first, Value& rest) { return Pair(TRY(MkGcRoot(pod))); } -Result Pair::first() { +Result Pair::create(const Array& arr) { + Value cur = TRY(Nil::create()); + + for (uint64_t i = 0; i < arr.size(); i++) { + Value arr_elt = TRY(arr.get(arr.size() - i - 1)); + cur = TRY(Pair::create(arr_elt, cur)); + } + + Pair& res = *cur.to(); + return TRY(res.copy()); +} + +Result Pair::first() const { auto val = _value->first.get(); return Value::create(val); } -Result Pair::rest() { + +Result Pair::second() const { return TRY(TRY(rest()).first()); } + +Result Pair::rest() const { auto val = _value->rest.get(); return Value::create(val); } +Result Pair::get(const Value& key) const { + if (!key.is()) return ERROR(TypeMismatch); + + uint64_t i = key.to()->value(); + + if (i == 0) return first(); + i--; + + Value cur = TRY(rest()); + + while (!cur.is()) { + if (i == 0) return cur.first(); + i--; + cur = TRY(cur.rest()); + } + + return ERROR(KeyError); +} + Result Function::create(const Value& name, uint64_t arity, const Array& constants, const Array& code, const Array& closure) { @@ -630,3 +668,8 @@ Result Array::cmp(const Array& rhs) const { Result Object::get(const Value& key) const { return ERROR(TypeMismatch); } + +Result Object::first() const { return ERROR(TypeMismatch); } +Result Object::second() const { return ERROR(TypeMismatch); } + +Result Object::rest() const { return ERROR(TypeMismatch); } diff --git a/src/common.hpp b/src/common.hpp index e268729..e021eb9 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -110,6 +110,9 @@ class Object { virtual Result div_inv(const Float&) const; virtual Result get(const Value& key) const; + virtual Result first() const; + virtual Result second() const; + virtual Result rest() const; Object() = default; Object(const Object&) = delete; @@ -670,9 +673,13 @@ class Pair : public Object { static Result create(PodPair* obj) { return Pair(TRY(MkGcRoot(obj))); } static Result create(Value& first, Value& rest); + static Result create(const Array& arr); - Result first(); - Result rest(); + virtual Result get(const Value& key) const final; + + Result first() const final; + Result second() const final; + Result rest() const final; virtual Result copy_value() const final; Result copy() const; @@ -1135,6 +1142,9 @@ class Value { Value k = TRY(Int64::create(key)); return ((Object*)buf)->get(k); } + Result first() { return ((Object*)buf)->first(); } + Result second() { return ((Object*)buf)->second(); } + Result rest() { return ((Object*)buf)->rest(); } // TODO: cmp() probably doesn't need arena parameter // Result operator==(Value& rhs) { return TRY(cmp(rhs)) == 0; } diff --git a/src/compiler.cpp b/src/compiler.cpp index faa7738..95255d0 100644 --- a/src/compiler.cpp +++ b/src/compiler.cpp @@ -88,6 +88,15 @@ struct Context { return i; } + Result update_var(const Value& sym) { + int64_t i = maxreg; + + variables_dict = TRY(variables_dict.insert(sym, TRY(Value::create(i)))); + + maxreg++; + return i; + } + Result get_var(const Value& sym) { auto idx = variables_dict.get(sym); if (idx.has_value()) { @@ -668,6 +677,65 @@ Result Compiler::compile_fn(Context& context, Symbol& op, return ex_res; } +Result Compiler::compile_let(Context& context, Symbol& op, + Pair& expr) { + // Save the variable bindings to restore it later + Dict saved_vars = TRY(context.variables_dict.copy()); + + auto first = TRY(expr.rest()); + + if (first.is() || !first.is()) { + return ERROR(CompilationError); + } + + uint64_t maxreg = context.maxreg; + + auto bindings = TRY(first.first()); + + Expression ex_res = TRY(Expression::create()); + + while (!bindings.is()) { + auto binding = TRY(bindings.first()); + + if (!binding.is()) { + return ERROR(CompilationError); + } + + auto binding_name = TRY(binding.first()); + auto binding_expr = TRY(binding.second()); + + if (!binding_name.is()) return ERROR(CompilationError); + + int64_t reg = TRY(context.update_var(binding_name)); + auto ex = TRY(compile_expr(context, binding_expr)); + TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {0, (int64_t)ex.reg})); + context.maxreg = reg + 1; + + ex_res.add_code(ex.code); + + bindings = TRY(bindings.rest()); + } + + Value second = TRY(first.rest()); + + if (second.is() || !second.is()) { + return ERROR(CompilationError); + } + + Pair& second_pair = *second.to(); + + auto ex = TRY(compile_body(context, second_pair)); + + TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)maxreg}, {0, (int64_t)ex.reg})); + ex_res.add_code(ex.code); + + // Restore the variables back + context.variables_dict = std::move(saved_vars); + ex_res.reg = maxreg; + + return std::move(ex_res); +} + Result Compiler::compile_body(Context& context, Pair& expr) { auto cur = TRY(expr.copy_value()); @@ -751,6 +819,8 @@ Result Compiler::compile_list(Context& context, Pair& expr) { return compile_not(context, sym, expr); } else if (TRY(sym.cmp("fn")) == 0) { return compile_fn(context, sym, expr); + } else if (TRY(sym.cmp("let")) == 0) { + return compile_let(context, sym, expr); } else { return compile_function_call(context, expr); } diff --git a/src/compiler.hpp b/src/compiler.hpp index bf39865..9069db0 100644 --- a/src/compiler.hpp +++ b/src/compiler.hpp @@ -42,6 +42,7 @@ class Compiler { Result compile_or(Context& context, Symbol& op, Pair& expr); Result compile_not(Context& context, Symbol& op, Pair& expr); Result compile_fn(Context& context, Symbol& op, Pair& expr); + Result compile_let(Context& context, Symbol& op, Pair& expr); Result compile_body(Context& context, Pair& expr); Result compile_function_call(Context& context, Pair& expr); }; diff --git a/src/stdlib.cpp b/src/stdlib.cpp index 17222e3..bee33ac 100644 --- a/src/stdlib.cpp +++ b/src/stdlib.cpp @@ -69,6 +69,18 @@ Result stdlib_dict(const Array& params) { return d; } +Result stdlib_list(const Array& params) { + Value d = TRY(Pair::create(params)); + return d; +} + +Result stdlib_get(const Array& params) { + if (params.size() != 2) return ERROR(ArgumentCountMismatch); + Value collection = TRY(params.get(0)); + Value key = TRY(params.get(1)); + return TRY(collection.get(key)); +} + #define STDLIB_FUNCTION(name, id) \ [(uint64_t)StdlibFunctionId::id] = {#name, StdlibFunctionId::id, \ stdlib_##name} @@ -80,6 +92,8 @@ static StdlibFunctionEntry function_entries[] = { STDLIB_FUNCTION(prn, Prn), STDLIB_FUNCTION(assert, Assert), STDLIB_FUNCTION(dict, Dict), + STDLIB_FUNCTION(list, List), + STDLIB_FUNCTION(get, Get), [(uint64_t)StdlibFunctionId::Max] = {0, StdlibFunctionId::Max, stdlib_unknown}, }; diff --git a/src/stdlib.hpp b/src/stdlib.hpp index 9403ecb..db1ceb6 100644 --- a/src/stdlib.hpp +++ b/src/stdlib.hpp @@ -11,6 +11,8 @@ enum class StdlibFunctionId : uint64_t { Prn, Assert, Dict, + List, + Get, Max, }; diff --git a/test/dict.vli b/test/dict.vli new file mode 100644 index 0000000..d64169f --- /dev/null +++ b/test/dict.vli @@ -0,0 +1,6 @@ +;; -*- mode: lisp; -*- + +(let ((d (dict 1 2 3 4))) + (assert (= (get d 1) 2)) + (assert (= (get d 3) 4)) + )