-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_dataset_custom.py
68 lines (63 loc) · 2.1 KB
/
create_dataset_custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import pygame
import os
import cv2
from pygame.locals import *
from ImageEnvironment import GridWorldEnv
from MooreMachine import MinecraftMoore
# Interactive dataset recorder: plays GridWorldEnv with the keyboard and saves
# every observation reached as a JPEG, one directory per episode.
#   S -> action 0, D -> action 1, W -> action 2, A -> action 3
#   Esc or closing the window ends the whole recording session.
EXECUTIONS = 20  # number of episodes to record
MAX_LENGTH = 30  # maximum number of recorded steps per episode
dataset_path = "dataset_whole"

# Keyboard key -> discrete action id passed to Env.step().
KEY_TO_ACTION = {K_s: 0, K_d: 1, K_w: 2, K_a: 3}

minecraft_machine = MinecraftMoore
Env = GridWorldEnv(minecraft_machine, "human", train=False)


def _save_episode(path, obs_list, reward_list):
    """Write each observation to ``<path>img<step>_<reward>.jpg``.

    Failed writes are reported instead of silently ignored
    (``cv2.imwrite`` signals failure by returning False, not by raising).
    """
    for step, (obs, reward) in enumerate(zip(obs_list, reward_list)):
        filename = path + "img" + str(step) + "_" + str(reward) + ".jpg"
        # NOTE(review): cv2.imwrite expects BGR channel order; if the env
        # returns RGB frames the saved colours are swapped -- confirm.
        if not cv2.imwrite(filename, obs):
            print("Could not write " + filename)


quit_requested = False
for i in range(EXECUTIONS):
    obs, reward, info = Env.reset()
    # NOTE(review): the observation returned by reset() is not saved, so img0
    # is the frame after the first action -- confirm this is intended.
    print("Execution " + str(i))
    # Episodes are numbered from EXECUTIONS upward (episode_20, episode_21, ...),
    # presumably so a new session does not overwrite an earlier batch -- confirm.
    path = os.path.join(dataset_path, "episode_" + str(EXECUTIONS + i) + "/")
    try:
        # makedirs also creates dataset_path if it is missing (os.mkdir would
        # fail there and every image write would then fail silently).
        os.makedirs(path, exist_ok=True)
    except OSError as error:
        print(error)

    terminated = False
    obs_list = []
    reward_list = []
    action_record = []
    info_record = []
    # Stop when the env reports termination, the step budget is used up, or
    # the user asks to quit.
    while not (terminated or quit_requested or len(obs_list) >= MAX_LENGTH):
        for e in pygame.event.get():
            if e.type == QUIT or (e.type == KEYDOWN and e.key == K_ESCAPE):
                quit_requested = True
                break
            if e.type != KEYDOWN or e.key not in KEY_TO_ACTION:
                continue
            action = KEY_TO_ACTION[e.key]
            obs, reward, info, terminated = Env.step(action)
            obs_list.append(obs)
            reward_list.append(reward)
            action_record.append(action)
            info_record.append(info)
            # Ignore further key presses queued in this batch once the episode
            # is over, so it never exceeds MAX_LENGTH steps or steps a
            # terminated env.
            if terminated or len(obs_list) >= MAX_LENGTH:
                break

    # Frames captured so far are kept even if the user quit mid-episode.
    _save_episode(path, obs_list, reward_list)
    if quit_requested:
        break

# Close the env exactly once, whether the session finished or was aborted.
Env.close()