Implement "let" construct in the compiler

This commit is contained in:
Konstantin Nazarov 2024-08-28 23:57:01 +01:00
parent f78b6e67cd
commit 2e62a67490
Signed by: knazarov
GPG key ID: 4CFE0A42FA409C22
8 changed files with 151 additions and 4 deletions

View file

@ -63,6 +63,7 @@ set(CPP_TESTS
set(LISP_TESTS
numeric
logic
dict
)

View file

@ -111,6 +111,8 @@ Result<Value> Dict::copy_value() const {
return Value(Dict(TRY(_value.copy())));
}
Result<Dict> Dict::copy() const { return Dict(TRY(_value.copy())); }
Result<Value> String::copy_value() const {
return Value(String(TRY(_value.copy())));
}
@ -127,6 +129,8 @@ Result<Value> Pair::copy_value() const {
return Value(Pair(TRY(_value.copy())));
}
Result<Pair> Pair::copy() const { return Pair(TRY(_value.copy())); }
Result<Value> Bool::copy_value() const {
return Value(Bool(TRY(_value.copy())));
}
@ -239,15 +243,49 @@ Result<Pair> Pair::create(Value& first, Value& rest) {
return Pair(TRY(MkGcRoot(pod)));
}
Result<Value> Pair::first() {
Result<Pair> Pair::create(const Array& arr) {
Value cur = TRY(Nil::create());
for (uint64_t i = 0; i < arr.size(); i++) {
Value arr_elt = TRY(arr.get(arr.size() - i - 1));
cur = TRY(Pair::create(arr_elt, cur));
}
Pair& res = *cur.to<Pair>();
return TRY(res.copy());
}
Result<Value> Pair::first() const {
auto val = _value->first.get();
return Value::create(val);
}
Result<Value> Pair::rest() {
Result<Value> Pair::second() const { return TRY(TRY(rest()).first()); }
Result<Value> Pair::rest() const {
auto val = _value->rest.get();
return Value::create(val);
}
Result<Value> Pair::get(const Value& key) const {
if (!key.is<Int64>()) return ERROR(TypeMismatch);
uint64_t i = key.to<Int64>()->value();
if (i == 0) return first();
i--;
Value cur = TRY(rest());
while (!cur.is<Nil>()) {
if (i == 0) return cur.first();
i--;
cur = TRY(cur.rest());
}
return ERROR(KeyError);
}
Result<Function> Function::create(const Value& name, uint64_t arity,
const Array& constants, const Array& code,
const Array& closure) {
@ -630,3 +668,8 @@ Result<short> Array::cmp(const Array& rhs) const {
Result<Value> Object::get(const Value& key) const {
return ERROR(TypeMismatch);
}
Result<Value> Object::first() const { return ERROR(TypeMismatch); }
Result<Value> Object::second() const { return ERROR(TypeMismatch); }
Result<Value> Object::rest() const { return ERROR(TypeMismatch); }

View file

@ -110,6 +110,9 @@ class Object {
virtual Result<Value> div_inv(const Float&) const;
virtual Result<Value> get(const Value& key) const;
virtual Result<Value> first() const;
virtual Result<Value> second() const;
virtual Result<Value> rest() const;
Object() = default;
Object(const Object&) = delete;
@ -670,9 +673,13 @@ class Pair : public Object {
static Result<Pair> create(PodPair* obj) { return Pair(TRY(MkGcRoot(obj))); }
static Result<Pair> create(Value& first, Value& rest);
static Result<Pair> create(const Array& arr);
Result<Value> first();
Result<Value> rest();
virtual Result<Value> get(const Value& key) const final;
Result<Value> first() const final;
Result<Value> second() const final;
Result<Value> rest() const final;
virtual Result<Value> copy_value() const final;
Result<Pair> copy() const;
@ -1135,6 +1142,9 @@ class Value {
Value k = TRY(Int64::create(key));
return ((Object*)buf)->get(k);
}
Result<Value> first() { return ((Object*)buf)->first(); }
Result<Value> second() { return ((Object*)buf)->second(); }
Result<Value> rest() { return ((Object*)buf)->rest(); }
// TODO: cmp() probably doesn't need arena parameter
// Result<bool> operator==(Value& rhs) { return TRY(cmp(rhs)) == 0; }

View file

@ -88,6 +88,15 @@ struct Context {
return i;
}
Result<int64_t> update_var(const Value& sym) {
int64_t i = maxreg;
variables_dict = TRY(variables_dict.insert(sym, TRY(Value::create(i))));
maxreg++;
return i;
}
Result<int64_t> get_var(const Value& sym) {
auto idx = variables_dict.get(sym);
if (idx.has_value()) {
@ -668,6 +677,65 @@ Result<Expression> Compiler::compile_fn(Context& context, Symbol& op,
return ex_res;
}
Result<Expression> Compiler::compile_let(Context& context, Symbol& op,
Pair& expr) {
// Save the variable bindings to restore it later
Dict saved_vars = TRY(context.variables_dict.copy());
auto first = TRY(expr.rest());
if (first.is<Nil>() || !first.is<Pair>()) {
return ERROR(CompilationError);
}
uint64_t maxreg = context.maxreg;
auto bindings = TRY(first.first());
Expression ex_res = TRY(Expression::create());
while (!bindings.is<Nil>()) {
auto binding = TRY(bindings.first());
if (!binding.is<Pair>()) {
return ERROR(CompilationError);
}
auto binding_name = TRY(binding.first());
auto binding_expr = TRY(binding.second());
if (!binding_name.is<Symbol>()) return ERROR(CompilationError);
int64_t reg = TRY(context.update_var(binding_name));
auto ex = TRY(compile_expr(context, binding_expr));
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)reg}, {0, (int64_t)ex.reg}));
context.maxreg = reg + 1;
ex_res.add_code(ex.code);
bindings = TRY(bindings.rest());
}
Value second = TRY(first.rest());
if (second.is<Nil>() || !second.is<Pair>()) {
return ERROR(CompilationError);
}
Pair& second_pair = *second.to<Pair>();
auto ex = TRY(compile_body(context, second_pair));
TRY(ex.add_opcode(Oc::Mov, {0, (int64_t)maxreg}, {0, (int64_t)ex.reg}));
ex_res.add_code(ex.code);
// Restore the variables back
context.variables_dict = std::move(saved_vars);
ex_res.reg = maxreg;
return std::move(ex_res);
}
Result<Expression> Compiler::compile_body(Context& context, Pair& expr) {
auto cur = TRY(expr.copy_value());
@ -751,6 +819,8 @@ Result<Expression> Compiler::compile_list(Context& context, Pair& expr) {
return compile_not(context, sym, expr);
} else if (TRY(sym.cmp("fn")) == 0) {
return compile_fn(context, sym, expr);
} else if (TRY(sym.cmp("let")) == 0) {
return compile_let(context, sym, expr);
} else {
return compile_function_call(context, expr);
}

View file

@ -42,6 +42,7 @@ class Compiler {
Result<Expression> compile_or(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_not(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_fn(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_let(Context& context, Symbol& op, Pair& expr);
Result<Expression> compile_body(Context& context, Pair& expr);
Result<Expression> compile_function_call(Context& context, Pair& expr);
};

View file

@ -69,6 +69,18 @@ Result<Value> stdlib_dict(const Array& params) {
return d;
}
Result<Value> stdlib_list(const Array& params) {
Value d = TRY(Pair::create(params));
return d;
}
Result<Value> stdlib_get(const Array& params) {
if (params.size() != 2) return ERROR(ArgumentCountMismatch);
Value collection = TRY(params.get(0));
Value key = TRY(params.get(1));
return TRY(collection.get(key));
}
#define STDLIB_FUNCTION(name, id) \
[(uint64_t)StdlibFunctionId::id] = {#name, StdlibFunctionId::id, \
stdlib_##name}
@ -80,6 +92,8 @@ static StdlibFunctionEntry function_entries[] = {
STDLIB_FUNCTION(prn, Prn),
STDLIB_FUNCTION(assert, Assert),
STDLIB_FUNCTION(dict, Dict),
STDLIB_FUNCTION(list, List),
STDLIB_FUNCTION(get, Get),
[(uint64_t)StdlibFunctionId::Max] = {0, StdlibFunctionId::Max,
stdlib_unknown},
};

View file

@ -11,6 +11,8 @@ enum class StdlibFunctionId : uint64_t {
Prn,
Assert,
Dict,
List,
Get,
Max,
};

6
test/dict.vli Normal file
View file

@ -0,0 +1,6 @@
;; -*- mode: lisp; -*-
(let ((d (dict 1 2 3 4)))
(assert (= (get d 1) 2))
(assert (= (get d 3) 4))
)