Skip to content

Commit

Permalink
improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ross96D committed Dec 22, 2024
1 parent 4a412e7 commit 302281b
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 40 deletions.
79 changes: 70 additions & 9 deletions bindings/go/rdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ extern bool rdb_go_callback(uintptr_t, struct Bytes, struct Bytes);
import "C"
import (
"errors"
"fmt"
"runtime/cgo"
"unsafe"
)
Expand Down Expand Up @@ -49,33 +50,90 @@ func fromCBytes(b C.struct_Bytes) []byte {
return C.GoBytes(unsafe.Pointer(b.ptr), C.int(b.len))
}

var (
ErrCloseDB = errors.New("database is closed")
ErrUnexpectedEOF = errors.New("unexpected end of file")
ErrUnseekable = errors.New("file is unseekable")
ErrBrokenPipe = errors.New("broken pipe")
ErrNotOpenForWriting = errors.New("file not open for writing")
ErrLockViolation = errors.New("lock violation in file")
ErrIsDir = errors.New("database path is directory")
ErrOutOfMemory = errors.New("out of memory")
ErrAccessDenied = errors.New("access denied")
ErrUnexpected = errors.New("unexpected")
ErrNotDocumented = errors.New("error code not documented")
)

func rdb_error() error {
errcode := C.rdb_error_code()
switch errcode {
case 0:
return nil
case 1:
return ErrCloseDB
case 2:
return ErrUnexpectedEOF
case 50:
return ErrUnseekable
case 51:
return ErrBrokenPipe
case 52:
return ErrNotOpenForWriting
case 53:
return ErrLockViolation
case 54:
return ErrIsDir
case 55:
return ErrOutOfMemory
case 56:
return ErrAccessDenied
case 99:
return ErrUnexpected
case 100:
return ErrNotDocumented
default:
return fmt.Errorf("error code %d %w", errcode, ErrNotDocumented)
}
}

type Database struct {
pointer unsafe.Pointer
}

func New(path []byte) (Database, error) {
r := C.rdb_open(toCBytes(path))
if r.error != nil {
return Database{}, errors.New(C.GoString(r.error))
if r.database == nil {
return Database{}, rdb_error()
}
return Database{pointer: r.database}, nil

}

func (db Database) Set(key []byte, value []byte) bool {
return bool(C.rdb_set(db.pointer, toCBytes(key), toCBytes(value)))
func (db Database) Set(key []byte, value []byte) error {
if !bool(C.rdb_set(db.pointer, toCBytes(key), toCBytes(value))) {
return rdb_error()
}
return nil
}

func (db Database) Get(key []byte) (AllocatedBytes, error) {
ret := C.rdb_get(db.pointer, toCBytes(key))
if !ret.valid {
return AllocatedBytes{}, ErrNotFound{}
err := rdb_error()
if err != nil {
return AllocatedBytes{}, err
} else {
return AllocatedBytes{}, ErrNotFound{}
}
}
return AllocatedBytes{Bytes: fromCBytes(ret.bytes)}, nil
}

func (db Database) Remove(key []byte) bool {
return bool(C.rdb_remove(db.pointer, toCBytes(key)))
func (db Database) Remove(key []byte) error {
if !bool(C.rdb_remove(db.pointer, toCBytes(key))) {
return rdb_error()
}
return nil
}

type GoCallback = func([]byte, []byte) bool
Expand All @@ -92,11 +150,14 @@ func rdb_go_callback(handle C.uintptr_t, key C.struct_Bytes, value C.struct_Byte
// golang object
//
// calling a [Database] function inside the body is ilegal behaviour
func (db Database) ForEach(fn GoCallback) {
func (db Database) ForEach(fn GoCallback) error {
handle := cgo.NewHandle(fn)
defer handle.Delete()

C.rdb_foreach(db.pointer, (unsafe.Pointer)(handle), C.Callback(C.rdb_go_callback))
if !bool(C.rdb_foreach(db.pointer, (unsafe.Pointer)(handle), C.Callback(C.rdb_go_callback))) {
return rdb_error()
}
return nil
}

func (db Database) Close() {
Expand Down
46 changes: 30 additions & 16 deletions bindings/go/rdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,22 @@ func Get(db rdb.Database, t *testing.T, key string) []byte {
func TestDatabase(t *testing.T) {
db, err := rdb.New([]byte("test_db"))
require.NoError(t, err)
defer db.Close()
defer os.Remove("test_db")

db.Set([]byte("key1"), []byte("value1"))
db.Set([]byte("key2"), []byte("value2"))
db.Set([]byte("key3"), []byte("value3"))
db.Set([]byte("key4"), []byte("value4"))
require.NoError(t, db.Set([]byte("key1"), []byte("value1")))
require.NoError(t, db.Set([]byte("key2"), []byte("value2")))
require.NoError(t, db.Set([]byte("key3"), []byte("value3")))
require.NoError(t, db.Set([]byte("key4"), []byte("value4")))

assert.Equal(t, []byte("value1"), Get(db, t, "key1"))
assert.Equal(t, []byte("value2"), Get(db, t, "key2"))
assert.Equal(t, []byte("value3"), Get(db, t, "key3"))
assert.Equal(t, []byte("value4"), Get(db, t, "key4"))

db.Set([]byte("key1"), []byte("val1"))
db.Set([]byte("key2"), []byte("val2"))
db.Set([]byte("key3"), []byte("val3"))
db.Set([]byte("key4"), []byte("val4"))
require.NoError(t, db.Set([]byte("key1"), []byte("val1")))
require.NoError(t, db.Set([]byte("key2"), []byte("val2")))
require.NoError(t, db.Set([]byte("key3"), []byte("val3")))
require.NoError(t, db.Set([]byte("key4"), []byte("val4")))

assert.Equal(t, []byte("val1"), Get(db, t, "key1"))
assert.Equal(t, []byte("val2"), Get(db, t, "key2"))
Expand All @@ -52,26 +51,41 @@ func TestDatabase(t *testing.T) {
var count int = 0
var keys [][]byte = [][]byte{}
var values [][]byte = [][]byte{}
db.ForEach(func(b1, b2 []byte) bool {
require.NoError(t, db.ForEach(func(b1, b2 []byte) bool {
count += 1
keys = append(keys, b1)
values = append(values, b2)
return true
})
}))
assert.Equal(t, 4, count)

for i, s := range keys {
db.Remove(s)
require.NoError(t, db.Remove(s))
assert.True(t, strings.HasPrefix(string(values[i]), "val"))
assert.True(t, strings.HasPrefix(string(s), "key"))
}

_, err = db.Get([]byte("key1"))
assert.Error(t, err)
assert.ErrorAs(t, err, &rdb.ErrNotFound{})
_, err = db.Get([]byte("key2"))
assert.Error(t, err)
assert.ErrorAs(t, err, &rdb.ErrNotFound{})
_, err = db.Get([]byte("key3"))
assert.Error(t, err)
assert.ErrorAs(t, err, &rdb.ErrNotFound{})
_, err = db.Get([]byte("key4"))
assert.Error(t, err)
assert.ErrorAs(t, err, &rdb.ErrNotFound{})

db.Close()
_, err = db.Get([]byte("key1"))
assert.ErrorIs(t, err, rdb.ErrCloseDB)
err = db.Remove([]byte("key1"))
assert.ErrorIs(t, err, rdb.ErrCloseDB)
err = db.Set([]byte("key1"), []byte("value1"))
assert.ErrorIs(t, err, rdb.ErrCloseDB)
assert.ErrorIs(t, db.ForEach(func(b1, b2 []byte) bool {
count += 1
keys = append(keys, b1)
values = append(values, b2)
return true
}), rdb.ErrCloseDB)

}
2 changes: 2 additions & 0 deletions rdb.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ struct OptionalBytes
bool valid;
};

uint64_t rdb_error_code();

struct Result rdb_open(struct Bytes path);

void rdb_close(void* db);
Expand Down
14 changes: 9 additions & 5 deletions src/db.zig
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ pub const Scope = enum {

const METADATA_SIZE = 1 << 12;

pub const RdbError = std.mem.Allocator.Error ||
std.fs.File.GetSeekPosError || std.fs.File.OpenError ||
std.fs.File.WriteError || std.fs.File.ReadError ||
std.fs.File.SeekError || error{ ClosedDB, UnexpectedEOF };

pub const DB = struct {
path: bytes,
allocator: std.mem.Allocator,
Expand Down Expand Up @@ -169,12 +174,11 @@ pub const DB = struct {

fn _create_tree(allocator: std.mem.Allocator, file: std.fs.File, tree: *zart.Tree(DataPtr)) !void {
try file.seekTo(METADATA_SIZE);
while (true) {
loop: while (true) {
const entry = _read_key_value(allocator, file, false) catch |err| {
if (err == error.EOF) {
break;
} else {
return err;
switch (err) {
error.EOF => break :loop,
else => |v| return v,
}
};
// TODO this should be done inside _read_key_value to avoid allocations
Expand Down
56 changes: 46 additions & 10 deletions src/root.zig
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// TODO define how to destroy the memory of the database after deinit. If we destroy on close than
// TODO code that is executing concurrently would get SIGSEGV
const _root = @This();

const std = @import("std");
Expand All @@ -18,7 +20,6 @@ pub const DB = db.DB;

pub const Result = extern struct {
database: ?*db.DB = null,
err: ?[*:0]const u8 = null,
};

pub const Bytes = extern struct {
Expand All @@ -38,23 +39,57 @@ inline fn toSlice(ptr: [*]const u8, len: u64) []const u8 {
return resp;
}

fn handle(err: db.RdbError) void {
rdb_error = switch (err) {
// call a function on a closed database
error.ClosedDB => 1,
// file got an unexpected end, fail to parse file
error.UnexpectedEOF => 2,
// error returned on file.seek functions
error.Unseekable => 50,
// file operation related error
error.BrokenPipe => 51,
// the file was not opened for writing but a set function was called
error.NotOpenForWriting => 52,
// Windows-only
error.LockViolation => 53,
// database path is a directory instead of a file
error.IsDir => 54,
// allocation error
error.OutOfMemory => 55,
// the doc on the error says this is caused by wsam
error.AccessDenied => 56,
// unknown error returned from the zig std library
error.Unexpected => 99,
// unexpected error returned
else => 100,
};
}

threadlocal var rdb_error: u64 = 0;

pub export fn rdb_error_code() u64 {
return rdb_error;
}

pub export fn rdb_open(path: Bytes) Result {
const database = global_allocator.create(db.DB) catch unreachable;

database.* = db.DB.init(global_allocator, toSlice(path.ptr, path.len)) catch |err| {
const error_str = std.fmt.allocPrintZ(global_allocator, "{}", .{err}) catch unreachable;
return Result{ .err = error_str.ptr };
handle(err);
return Result{};
};
return Result{ .database = database };
}

pub export fn rdb_close(database: *db.DB) void {
database.deinit();
global_allocator.destroy(database);
}

pub export fn rdb_get(database: *db.DB, key: Bytes) OptionalBytes {
const val = database.search(toSlice(key.ptr, key.len)) catch {
rdb_error = 0;
const val = database.search(toSlice(key.ptr, key.len)) catch |err| {
handle(err);
return OptionalBytes{};
};
if (val) |v| {
Expand All @@ -71,16 +106,16 @@ pub export fn rdb_get(database: *db.DB, key: Bytes) OptionalBytes {
}

pub export fn rdb_set(database: *db.DB, key: Bytes, value: Bytes) bool {
database.set(toSlice(key.ptr, key.len), toSlice(value.ptr, value.len), .{ .own = true }) catch {
// TODO handle error
database.set(toSlice(key.ptr, key.len), toSlice(value.ptr, value.len), .{ .own = true }) catch |err| {
handle(err);
return false;
};
return true;
}

pub export fn rdb_remove(database: *db.DB, key: Bytes) bool {
database.delete(toSlice(key.ptr, key.len)) catch {
// TODO handle error
database.delete(toSlice(key.ptr, key.len)) catch |err| {
handle(err);
return false;
};
return true;
Expand All @@ -103,7 +138,8 @@ pub export fn rdb_foreach(database: *db.DB, caller_ctx_: *anyopaque, cfun: *cons
}
}.f;
const ctx = T{ .fun = cfun, .caller_ctx = caller_ctx_ };
database.for_each(arena.allocator(), T, ctx, fun) catch {
database.for_each(arena.allocator(), T, ctx, fun) catch |err| {
handle(err);
return false;
};
return true;
Expand Down

0 comments on commit 302281b

Please sign in to comment.