Skip to content
This repository has been archived by the owner on May 6, 2024. It is now read-only.

[WIP] Libnethack shared #140

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ endif()

message(STATUS "Building nle backend version: ${NLE_VERSION}")

# Compile C++ sources as C++17. CMAKE_CXX_STANDARD alone is only a request:
# if the compiler cannot provide C++17, CMake quietly falls back to an older
# standard. CMAKE_CXX_STANDARD_REQUIRED turns that fallback into an error.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Default every target created below to position-independent code.
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# We use this to decide where the root of the nle/ package is. Normally it
Expand Down Expand Up @@ -99,7 +100,7 @@ target_link_directories(nethack PUBLIC /usr/local/lib)
target_link_libraries(nethack PUBLIC m fcontext bz2)

# dlopen wrapper library
# Static wrapper around the dynamic loader: the plain-C dlopen front end plus
# the C++ helper used for "shared" mode. The target must be declared exactly
# once; a second add_library() with the same name is a configure error.
add_library(nethackdl STATIC
  "sys/unix/nledl.c"
  "sys/unix/nleshared.cc"
)
target_include_directories(nethackdl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
# CMAKE_DL_LIBS names whatever library provides dlopen()/dlsym() on the
# current platform, instead of hard-coding "dl".
target_link_libraries(nethackdl PUBLIC ${CMAKE_DL_LIBS})

Expand Down
14 changes: 5 additions & 9 deletions include/nledl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,9 @@
#include "nleobs.h"

/* TODO: Don't call this nle_ctx_t as well. */
typedef struct nledl_ctx {
char dlpath[1024];
void *dlhandle;
void *nle_ctx;
void *(*step)(void *, nle_obs *);
FILE *ttyrec;
} nle_ctx_t;

nle_ctx_t *nle_start(const char *, nle_obs *, FILE *, nle_seeds_init_t *);
/*
 * Opaque handle for one game instance driven through the loader wrapper.
 * Only the forward declaration is exposed here; the struct layout is private
 * to the implementation (sys/unix/nledl.c).
 */
typedef struct nledl_ctx nle_ctx_t;

/*
 * Allocate a context for the library at the given path, load it and start a
 * game. Arguments, in order: library path, observation buffers, ttyrec
 * stream, initial seeds, and `shared` (non-zero selects "shared" loading
 * instead of a plain dlopen). The implementation exits the process on load
 * or symbol-lookup failure.
 */
nle_ctx_t *nle_start(const char *, nle_obs *, FILE *, nle_seeds_init_t *, int shared);
/*
 * Advance the game by one step, filling the observation struct. Returns the
 * same context pointer that was passed in.
 */
nle_ctx_t *nle_step(nle_ctx_t *, nle_obs *);

/*
 * End the current game and start a new one on the same context. A non-NULL
 * FILE * replaces the ttyrec stream stored in the context.
 */
void nle_reset(nle_ctx_t *, nle_obs *, FILE *, nle_seeds_init_t *);
Expand All @@ -27,4 +21,6 @@ void nle_end(nle_ctx_t *);
/*
 * Forward (core seed, display seed, reseed flag) to the loaded library's
 * nle_set_seed for this context.
 */
void nle_set_seed(nle_ctx_t *, unsigned long, unsigned long, char);
/*
 * Read back (core seed, display seed, reseed flag) from the loaded library's
 * nle_get_seed into the caller-provided locations.
 */
void nle_get_seed(nle_ctx_t *, unsigned long *, unsigned long *, char *);

/*
 * Report whether "shared" mode can be used on this system. Callers treat a
 * non-zero result as "supported".
 */
int nle_supports_shared(void);

#endif /* NLEDL_H */
72 changes: 45 additions & 27 deletions nle/nethack/nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ def _set_env_vars(options, hackdir, wizkit=None):
# which should allow several instances of this. On MacOS, that seems
# a tough call.
class Nethack:
_instances = 0

def __init__(
self,
observation_keys=OBSERVATION_DESC.keys(),
Expand All @@ -102,25 +100,38 @@ def __init__(
"Couldn't find NetHack installation at '%s'." % hackdir
)

# Create a HACKDIR for us.
self._tempdir = tempfile.TemporaryDirectory(prefix="nle")
self._vardir = self._tempdir.name
self.shared = False

if _pynethack.supports_shared():
# "shared" mode does some hacky things to enable using a
# shared libnethack.so, prevents writing to any files, and does
# not chdir.
self.shared = True
dlpath = DLPATH
self._hackdir = hackdir
else:

# Create a HACKDIR for us.
self._tempdir = tempfile.TemporaryDirectory(prefix="nle")
self._vardir = self._tempdir.name

self._hackdir = self._vardir

# Save cwd and restore later. Currently libnethack changes
# directory on loading.
self._oldcwd = os.getcwd()
# Save cwd and restore later. Currently libnethack changes
# directory on loading.
self._oldcwd = os.getcwd()

# Symlink a few files.
for fn in ["nhdat", "sysconf"]:
os.symlink(os.path.join(hackdir, fn), os.path.join(self._vardir, fn))
# Touch a few files.
for fn in ["perm", "logfile", "xlogfile"]:
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
os.mkdir(os.path.join(self._vardir, "save"))
# Symlink a few files.
for fn in ["nhdat", "sysconf"]:
os.symlink(os.path.join(hackdir, fn), os.path.join(self._vardir, fn))
# Touch a few files.
for fn in ["perm", "logfile", "xlogfile"]:
os.close(os.open(os.path.join(self._vardir, fn), os.O_CREAT))
os.mkdir(os.path.join(self._vardir, "save"))

# Hacky AF: Copy our so into this directory to load several copies ...
dlpath = os.path.join(self._vardir, "libnethack.so")
shutil.copyfile(DLPATH, dlpath)
# Hacky AF: Copy our so into this directory to load several copies ...
dlpath = os.path.join(self._vardir, "libnethack.so")
shutil.copyfile(DLPATH, dlpath)

if options is None:
options = NETHACKOPTIONS
Expand All @@ -129,10 +140,10 @@ def __init__(
self._options.append("playmode:debug")
self._wizard = wizard

_set_env_vars(self._options, self._vardir)
_set_env_vars(self._options, self._hackdir)
self._ttyrec = ttyrec

self._pynethack = _pynethack.Nethack(dlpath, ttyrec)
self._pynethack = _pynethack.Nethack(dlpath, ttyrec, self.shared)

self._obs_buffers = {}

Expand All @@ -154,6 +165,11 @@ def step(self, action):
return self._step_return(), self._pynethack.done()

def _write_wizkit_file(self, wizkit_items):
if self._vardir is None:
raise RuntimeError(
"FIXME: shared wizkit: can't write to HACKDIR as "
"it is a shared directory"
)
# TODO ideally we need to check the validity of the requested items
with open(os.path.join(self._vardir, WIZKIT_FNAME), "w") as f:
for item in wizkit_items:
Expand All @@ -164,9 +180,9 @@ def reset(self, new_ttyrec=None, wizkit_items=None):
if not self._wizard:
raise ValueError("Set wizard=True to use the wizkit option.")
self._write_wizkit_file(wizkit_items)
_set_env_vars(self._options, self._vardir, wizkit=WIZKIT_FNAME)
_set_env_vars(self._options, self._hackdir, wizkit=WIZKIT_FNAME)
else:
_set_env_vars(self._options, self._vardir)
_set_env_vars(self._options, self._hackdir)
if new_ttyrec is None:
self._pynethack.reset()
else:
Expand All @@ -178,11 +194,13 @@ def reset(self, new_ttyrec=None, wizkit_items=None):

def close(self):
    """Shut down the game and undo the filesystem side effects of __init__.

    The underlying _pynethack object is always closed. In non-shared mode
    the working directory saved at construction time is restored (loading
    libnethack changes directory), falling back to this module's directory
    if the saved one can no longer be entered. The private temporary
    HACKDIR is removed whenever one was created.
    """
    self._pynethack.close()

    if not self.shared:
        try:
            os.chdir(self._oldcwd)
        except OSError:  # IOError is an alias of OSError on Python 3.
            os.chdir(os.path.dirname(os.path.realpath(__file__)))

    # Shared mode never creates a temporary directory, so the attribute may
    # not exist at all; look it up defensively instead of assuming it does.
    tempdir = getattr(self, "_tempdir", None)
    if tempdir is not None:
        tempdir.cleanup()

def set_initial_seeds(self, core, disp, reseed=False):
self._pynethack.set_initial_seeds(core, disp, reseed)
Expand Down
3 changes: 2 additions & 1 deletion nle/tests/test_nethack.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def test_run_n_episodes(self, tmpdir, game, episodes=3):

nethackdir = tmpdir.chdir()

assert nethackdir.fnmatch("nle*")
if not game.shared:
assert nethackdir.fnmatch("nle*")
assert tmpdir.ensure("nle.ttyrec")

if mean_sps < 15000:
Expand Down
156 changes: 100 additions & 56 deletions sys/unix/nledl.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,74 +6,113 @@

#include "nledl.h"

void
nledl_init(nle_ctx_t *nledl, nle_obs *obs, nle_seeds_init_t *seed_init)
{
nledl->dlhandle = dlopen(nledl->dlpath, RTLD_LAZY);
void *nleshared_open(const char *dlpath);
void nleshared_close(void *handle);
void nleshared_reset(void *handle);
void *nleshared_sym(void *handle, const char *symname);
void nleshared_set_current(void *handle);
int nleshared_supported(void);

/*
 * State for one game instance managed by this wrapper. A context is backed
 * either by a "shared" handle or by a plain dlopen handle.
 */
typedef struct nledl_ctx {
/* Handle from nleshared_open(); NULL when shared mode is not in use. */
void *shared;
/* Library path copied in by nle_start(); passed to dlopen/nleshared_open. */
char dlpath[1024];
/* Handle from dlopen() in non-shared mode. */
void *dlhandle;
/* Game context returned by the library's nle_start; passed to step/end. */
void *nle_ctx;
/* Entry points resolved from the library: "nle_start", "nle_step", "nle_end". */
void *(*start)(nle_obs *, FILE *, nle_seeds_init_t *);
void *(*step)(void *, nle_obs *);
void (*end)(void *);
/* Stream handed to the library's nle_start as its ttyrec argument. */
FILE *ttyrec;
} nle_ctx_t;

if (!nledl->dlhandle) {
fprintf(stderr, "%s\n", dlerror());
exit(EXIT_FAILURE);
/*
 * Resolve `name` in the library backing `nledl` and return its address.
 * In shared mode the lookup is delegated to nleshared_sym(); otherwise
 * dlsym() is used on the dlopen handle, and any error reported by dlerror()
 * is printed to stderr before the process exits.
 */
static void *
sym(nle_ctx_t *nledl, const char *name)
{
    void *addr;
    char *err;

    if (nledl->shared)
        return nleshared_sym(nledl->shared, name);

    /* Drop any stale error so the check below reflects this lookup only. */
    dlerror();
    addr = dlsym(nledl->dlhandle, name);
    err = dlerror();
    if (err) {
        fprintf(stderr, "%s\n", err);
        exit(EXIT_FAILURE);
    }
    return addr;
}

dlerror(); /* Clear any existing error */

void *(*start)(nle_obs *, FILE *, nle_seeds_init_t *);
start = dlsym(nledl->dlhandle, "nle_start");
nledl->nle_ctx = start(obs, nledl->ttyrec, seed_init);

char *error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
/*
 * Load the library named by nledl->dlpath, resolve its entry points and
 * start a game.
 *
 *   nledl      context whose dlpath and ttyrec fields are already filled in
 *   obs        observation buffers handed to the library's nle_start
 *   seed_init  initial seeds handed to the library's nle_start
 *   shared     non-zero: load through the nleshared_* helpers;
 *              zero: load with dlopen()
 *
 * Any failure is reported on stderr and terminates the process.
 */
void
nledl_init(nle_ctx_t *nledl, nle_obs *obs, nle_seeds_init_t *seed_init,
           int shared)
{
    /*
     * Clear both handles first. Only one of them is assigned below, and
     * nle_step() inspects both, so neither may be left uninitialised.
     */
    nledl->shared = NULL;
    nledl->dlhandle = NULL;

    if (shared) {
        if (!nleshared_supported()) {
            fprintf(stderr, "Shared mode not supported on this system!\n");
            exit(EXIT_FAILURE);
        }
        nledl->shared = nleshared_open(nledl->dlpath);
        /*
         * A NULL handle would make sym() fall through to dlsym() with a NULL
         * dlhandle, so refuse to continue. NOTE(review): assumes
         * nleshared_open() signals failure with NULL -- confirm.
         */
        if (!nledl->shared) {
            fprintf(stderr, "Failed to open %s in shared mode\n",
                    nledl->dlpath);
            exit(EXIT_FAILURE);
        }
        nleshared_set_current(nledl->shared);
    } else {
        nledl->dlhandle = dlopen(nledl->dlpath, RTLD_LAZY);
        if (!nledl->dlhandle) {
            fprintf(stderr, "%s\n", dlerror());
            exit(EXIT_FAILURE);
        }
    }

    /* sym() exits on lookup failure, so no further error checks are needed. */
    nledl->start = sym(nledl, "nle_start");
    nledl->step = sym(nledl, "nle_step");
    nledl->end = sym(nledl, "nle_end");

    nledl->nle_ctx = nledl->start(obs, nledl->ttyrec, seed_init);
}

/*
 * End the running game and release the library backing `nledl`.
 *
 * The library's nle_end is called exactly once on the game context. The
 * backing handle is then released with nleshared_close() in shared mode or
 * dlclose() otherwise; a dlclose() failure is reported on stderr and
 * terminates the process. The context struct itself is not freed.
 */
void
nledl_close(nle_ctx_t *nledl)
{
    if (nledl->shared) {
        nleshared_set_current(nledl->shared);
    }
    nledl->end(nledl->nle_ctx);
    /* Clear released pointers so nle_step()'s validity check catches reuse. */
    nledl->nle_ctx = NULL;

    if (nledl->shared) {
        nleshared_close(nledl->shared);
        nledl->shared = NULL;
    } else {
        if (dlclose(nledl->dlhandle)) {
            fprintf(stderr, "Error in dlclose: %s\n", dlerror());
            exit(EXIT_FAILURE);
        }
        nledl->dlhandle = NULL;

        dlerror();
    }
}

/*
 * Create a context for the library at `dlpath` and start a game in it.
 *
 *   dlpath     path of the library to load; must fit in the context's
 *              fixed-size dlpath buffer including the terminating NUL
 *   obs        observation buffers handed to the library
 *   ttyrec     stream stored in the context and handed to the library
 *   seed_init  initial seeds handed to the library
 *   shared     non-zero selects shared loading, zero a plain dlopen()
 *
 * Returns the newly allocated context. Invalid paths, allocation failure and
 * load failures are reported on stderr and terminate the process.
 */
nle_ctx_t *
nle_start(const char *dlpath, nle_obs *obs, FILE *ttyrec,
          nle_seeds_init_t *seed_init, int shared)
{
    /* TODO: Consider getting ttyrec path from caller? */
    struct nledl_ctx *nledl;
    size_t pathlen;

    if (!dlpath) {
        fprintf(stderr, "No library path given\n");
        exit(EXIT_FAILURE);
    }

    /*
     * strncpy() would silently truncate an over-long path and leave the
     * buffer without a terminating NUL; reject such paths up front instead.
     */
    pathlen = strlen(dlpath);
    if (pathlen >= sizeof(nledl->dlpath)) {
        fprintf(stderr, "Library path too long: %s\n", dlpath);
        exit(EXIT_FAILURE);
    }

    nledl = malloc(sizeof(*nledl));
    if (!nledl) {
        fprintf(stderr, "Out of memory allocating nledl context\n");
        exit(EXIT_FAILURE);
    }

    nledl->ttyrec = ttyrec;
    memcpy(nledl->dlpath, dlpath, pathlen + 1); /* includes the NUL */

    nledl_init(nledl, obs, seed_init, shared);
    return nledl;
}

nle_ctx_t *
nle_step(nle_ctx_t *nledl, nle_obs *obs)
{
if (!nledl || !nledl->dlhandle || !nledl->nle_ctx) {
if (!nledl || (!nledl->dlhandle && !nledl->shared) || !nledl->nle_ctx) {
fprintf(stderr, "Illegal nledl_ctx\n");
exit(EXIT_FAILURE);
}

if (nledl->shared) {
nleshared_set_current(nledl->shared);
}
nledl->step(nledl->nle_ctx, obs);

return nledl;
Expand All @@ -85,14 +124,25 @@ void
nle_reset(nle_ctx_t *nledl, nle_obs *obs, FILE *ttyrec,
          nle_seeds_init_t *seed_init)
{
    /*
     * End the current game and start a new one on the same context.
     *
     *   obs        observation buffers for the new game
     *   ttyrec     new ttyrec stream, or NULL to keep the stored one
     *   seed_init  initial seeds for the new game
     *
     * Shared mode keeps the library loaded: it ends the game, resets the
     * shared handle and calls the library's start function again.
     * Non-shared mode closes and reloads the library.
     */
    if (nledl->shared) {
        nleshared_set_current(nledl->shared);
        nledl->end(nledl->nle_ctx);
        nleshared_reset(nledl->shared);
        /* Reset file only if not-NULL. */
        if (ttyrec)
            nledl->ttyrec = ttyrec;
        /*
         * Pass the stored stream rather than the raw argument: the argument
         * is NULL when the caller wants to keep the current ttyrec, exactly
         * as in the non-shared branch where nledl_init() uses nledl->ttyrec.
         */
        nledl->nle_ctx = nledl->start(obs, nledl->ttyrec, seed_init);
    } else {
        nledl_close(nledl);
        /* Reset file only if not-NULL. */
        if (ttyrec)
            nledl->ttyrec = ttyrec;

        // TODO: Consider refactoring nledl.h such that we expose this init
        // function but drop reset.
        nledl_init(nledl, obs, seed_init, 0);
    }
}

void
Expand All @@ -108,13 +158,7 @@ nle_set_seed(nle_ctx_t *nledl, unsigned long core, unsigned long disp,
{
void (*set_seed)(void *, unsigned long, unsigned long, char);

set_seed = dlsym(nledl->dlhandle, "nle_set_seed");

char *error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
set_seed = sym(nledl, "nle_set_seed");

set_seed(nledl->nle_ctx, core, disp, reseed);
}
Expand All @@ -125,16 +169,16 @@ nle_get_seed(nle_ctx_t *nledl, unsigned long *core, unsigned long *disp,
{
void (*get_seed)(void *, unsigned long *, unsigned long *, char *);

get_seed = dlsym(nledl->dlhandle, "nle_get_seed");

char *error = dlerror();
if (error != NULL) {
fprintf(stderr, "%s\n", error);
exit(EXIT_FAILURE);
}
get_seed = sym(nledl, "nle_get_seed");

/* Careful here. NetHack has different ideas of what a boolean is
* than C++ (see global.h and SKIP_BOOLEAN). But one byte should be fine.
*/
get_seed(nledl->nle_ctx, core, disp, reseed);
}

/*
 * Public query for shared-mode availability: returns whatever
 * nleshared_supported() reports.
 */
int
nle_supports_shared(void)
{
    const int supported = nleshared_supported();

    return supported;
}
Loading