-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasets.py
156 lines (125 loc) · 5.01 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from itertools import permutations, product
import logging
from math import factorial
import os
import numpy as np
import torch
###
# some utility functions
###
def isPrime(n):
    """
    Checks whether n is a prime number.

    Returns False for n < 2 (0, 1 and negative numbers are not prime) and True
    for 2, the only even prime. Odd n >= 3 is checked by trial division with
    odd divisors d while d * d <= n.
    """
    if n < 2:
        return False
    if n % 2 == 0:
        # 2 is the only even prime; every other even number is composite.
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True
def get_inverse_perm(perm):
    """
    Return the inverse of a permutation as a list.

    ``perm`` must contain each of 0 .. len(perm) - 1 exactly once. The result
    ``inverse`` satisfies ``inverse[perm[i]] == i`` for every i. Its elements
    are NumPy integer scalars with the dtype of ``np.array(perm)``.
    """
    as_array = np.array(perm)
    inverse = np.empty_like(as_array)
    # Scatter each position i into the slot named by perm[i].
    positions = np.arange(len(as_array), dtype=as_array.dtype)
    inverse[as_array] = positions
    return list(inverse)
def compose_perms(perm1, perm2):
    """
    Return the composition perm1(perm2) as a tuple.

    Entry k of the result is ``perm1[perm2[k]]``: perm2 is applied first,
    then perm1. Elements are NumPy integer scalars.
    """
    outer, inner = np.array(perm1), np.array(perm2)
    composed = outer[inner]
    return tuple(composed)
def get_dataset(descr, num_elements, data_dir=None, force_data=False):
    """
    Build the dataset named by ``descr``.

    A description starting with ``'perm'`` selects ``PermData`` over the
    permutations of ``range(num_elements)``. Any other description is used as
    the arithmetic function name for ``ArithmeticData``, with ``num_elements``
    as the modulus.
    """
    if descr.startswith('perm'):
        return PermData(data_dir, force_data, num_elements, descr)
    return ArithmeticData(data_dir, force_data, num_elements, descr)
def get_arithmetic_func(func_name):
    """
    Look up a binary operation modulo p by name.

    Each returned callable takes ``(x, y, p)`` and returns a triple
    ``(left, right, result)`` forming the equation ``left <op> right = result``.
    For every operation except ``'div'``, ``left`` and ``right`` are simply
    ``x`` and ``y``. ``'div'`` returns ``((x * y) % p, y, x)``, so the equation
    reads ``(x*y) / y = x``. An unknown name raises ``KeyError``.
    """
    def keep_operands(result_of):
        # Turn f(x, y, p) -> result into f(x, y, p) -> (x, y, result).
        return lambda x, y, p: (x, y, result_of(x, y, p))

    operations = {
        'plus': keep_operands(lambda x, y, p: (x + y) % p),
        'minus': keep_operands(lambda x, y, p: (x - y) % p),
        # Division is encoded via multiplication; y == 0 would make the
        # equation ambiguous, so callers should skip it.
        'div': lambda x, y, p: ((x * y) % p, y, x),
        'div_odd': keep_operands(
            lambda x, y, p: (x // y) % p if y % 2 == 1 else (x - y) % p
        ),
        'x2y2': keep_operands(lambda x, y, p: (x ** 2 + y ** 2) % p),
        'x2xyy2': keep_operands(lambda x, y, p: (x ** 2 + x * y + y ** 2) % p),
        'x2xyy2x': keep_operands(lambda x, y, p: (x ** 2 + x * y + y ** 2 + x) % p),
        'x3xy': keep_operands(lambda x, y, p: (x ** 3 + x * y) % p),
        'x3xy2y': keep_operands(lambda x, y, p: (x ** 3 + x * y ** 2 + y) % p),
    }
    return operations[func_name]
###
# Dataset classes
###
class ArithmeticData(torch.utils.data.Dataset):
    """
    Dataset of modular arithmetic equations ``x <op> y = res``.

    Each row is the 5-token sequence ``[x, op, y, eq, res]``. ``op`` and ``eq``
    are the extra tokens ``prime`` and ``prime + 1``, so every token lies in
    ``0 .. prime + 1``. Rows are produced by ``generate_data``, cached as
    ``{func_name}_{prime}.npy`` inside ``data_dir``, and loaded from there.
    """

    def __init__(self, data_dir=None, force_data=False, prime=97, func_name="plus"):
        """
        Args:
            data_dir: directory holding the cached ``.npy`` file (required).
            force_data: if True, (re)generate and save the data before loading.
            prime: modulus of the arithmetic; must be prime.
            func_name: operation name understood by ``get_arithmetic_func``.

        Raises:
            AssertionError: if ``data_dir`` is None or ``prime`` is not prime.
            FileNotFoundError: if the cached file does not exist (and
                ``force_data`` was False).
        """
        assert data_dir is not None, "data_dir is None"
        assert isPrime(prime), "prime is not prime"
        if force_data:
            logging.info("Creating data and saving to %s", data_dir)
            self.generate_data(data_dir, func_name, prime)
        logging.info("Loading data from %s", data_dir)
        path = self._data_path(data_dir, func_name, prime)
        try:
            self.data = np.load(path)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Could not find {path}. Run with force_data=True to generate data"
            ) from err

    def __getitem__(self, index):
        """Return a copy of row ``index`` as a NumPy array."""
        return np.array(self.data[index])

    def __len__(self):
        """Return the number of equations (rows) in the dataset."""
        return len(self.data)

    @staticmethod
    def _data_path(data_dir, func_name, prime):
        """Return the cache file path for a (func_name, prime) dataset."""
        return os.path.join(data_dir, f'{func_name}_{prime}.npy')

    @staticmethod
    def generate_data(data_dir, func_name, prime=97):
        """
        Enumerate all operand pairs, apply ``func_name`` modulo ``prime``, and
        save the rows to ``{data_dir}/{func_name}_{prime}.npy``.

        ``data_dir`` is created if it does not exist yet.
        """
        func = get_arithmetic_func(func_name)
        op = prime       # operator token, just past the number tokens 0..prime-1
        eq = prime + 1   # '=' token
        # 'div' encodes (x*y) / y = x; with y == 0 every x gives the same
        # left-hand side, so division by zero is excluded.
        y_range = range(1, prime) if func_name == 'div' else range(prime)
        data = []
        for x, y in product(range(prime), y_range):
            lhs, rhs, res = func(x, y, prime)
            data.append([lhs, op, rhs, eq, res])
        if data_dir:
            # An empty data_dir means the current directory, which always exists.
            os.makedirs(data_dir, exist_ok=True)
        np.save(ArithmeticData._data_path(data_dir, func_name, prime), data)
class PermData(torch.utils.data.Dataset):
    """
    Dataset of composition equations over all permutations of ``range(group_size)``.

    The ``group_size!`` permutations are numbered in ``itertools.permutations``
    order. Each row is ``[i, op, j, eq, res]`` with ``op = group_size!`` and
    ``eq = group_size! + 1``. Writing ``a∘b`` for ``compose_perms(a, b)``,
    with x the i-th and y the j-th permutation, ``res`` is the index of:
      - ``"perm_xyx"``:  x∘y∘x
      - ``"perm_xyx1"``: x∘y∘x⁻¹
      - any other name (e.g. the default ``"perm_xy"``): x∘y
    Rows are cached as ``{func_name}_{group_size}.npy`` inside ``data_dir``.
    """

    def __init__(self, data_dir=None, force_data=False, group_size=5, func_name="perm_xy"):
        """
        Args:
            data_dir: directory holding the cached ``.npy`` file (required).
            force_data: if True, (re)generate and save the data before loading.
            group_size: number of permuted elements; the dataset has
                ``group_size! ** 2`` rows.
            func_name: which equation to build (see the class docstring).

        Raises:
            AssertionError: if ``data_dir`` is None or ``group_size > 10``.
            FileNotFoundError: if the cached file does not exist (and
                ``force_data`` was False).
        """
        assert data_dir is not None, "data_dir is None"
        assert group_size <= 10, "group_size should not be > 10, otherwise you will run out of RAM"
        if force_data:
            logging.info("Creating data and saving to %s", data_dir)
            self.generate_data(data_dir, group_size, func_name)
        logging.info("Loading data from %s", data_dir)
        path = self._data_path(data_dir, func_name, group_size)
        try:
            self.data = np.load(path)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Could not find {path}. Run with force_data=True to generate data"
            ) from err

    def __getitem__(self, index):
        """Return a copy of row ``index`` as a NumPy array."""
        return np.array(self.data[index])

    def __len__(self):
        """Return the number of equations (rows) in the dataset."""
        return len(self.data)

    @staticmethod
    def _data_path(data_dir, func_name, group_size):
        """Return the cache file path for a (func_name, group_size) dataset."""
        return os.path.join(data_dir, f'{func_name}_{group_size}.npy')

    @staticmethod
    def generate_data(data_dir, group_size=5, func_name="perm_xy"):
        """
        Build every (i, j) composition equation and save the rows to
        ``{data_dir}/{func_name}_{group_size}.npy``.

        ``data_dir`` is created if it does not exist yet.
        """
        num_tokens = factorial(group_size)
        op = num_tokens       # operator token, just past the permutation indices
        eq = num_tokens + 1   # '=' token
        all_permutations = list(permutations(range(group_size)))
        # Constant-time permutation -> index lookup; list.index would rescan
        # all group_size! permutations for every one of the group_size!**2 pairs.
        perm_index = {perm: idx for idx, perm in enumerate(all_permutations)}
        data = []
        for i, j in product(range(num_tokens), repeat=2):
            perm1, perm2 = all_permutations[i], all_permutations[j]
            combined = compose_perms(perm1, perm2)
            if func_name == "perm_xyx":
                combined = compose_perms(combined, perm1)
            elif func_name == "perm_xyx1":
                combined = compose_perms(combined, get_inverse_perm(perm1))
            # compose_perms returns NumPy integer scalars; normalise to plain
            # ints so the key matches the tuples from itertools.permutations.
            res = perm_index[tuple(int(v) for v in combined)]
            data.append([i, op, j, eq, res])
        if data_dir:
            # An empty data_dir means the current directory, which always exists.
            os.makedirs(data_dir, exist_ok=True)
        np.save(PermData._data_path(data_dir, func_name, group_size), data)