Skip to content

Commit

Permalink
issue with dmrg_interpolate
Browse files Browse the repository at this point in the history
  • Loading branch information
Ion committed Dec 14, 2024
1 parent 3556f9a commit 29e7ce9
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:

# ADJUST THIS: install all dependencies
- run: pip install torch
- run: python setup.py install
- run: pip install -e .
- run: pip install sphinx sphinx_rtd_theme
# ADJUST THIS: build your documentation into docs/.
# We use a custom build script for pdoc itself, ideally you just run `pdoc -o docs/ ...` here.
Expand Down
11 changes: 11 additions & 0 deletions tests/test_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ def test_dmrg_cross_interpolation():

assert err_rel(x.full(), x_ref) < 1e-6

def test_dmrg_cross_interpolation_nonvect():
"""
Test the DMRG cross interpolation method for non vectorized function.
"""
func1 = lambda I,J,K,L: 1 / (6 + I + J + K + L)
N = [20] * 4
x = tntt.interpolate.dmrg_cross(func1, N, eps=1e-7, eval_vect=False)
Is = tntt.meshgrid([tn.arange(0, n, dtype=tn.float64) for n in N])
x_ref = 1 / (2 + Is[0].full() + Is[1].full() + Is[2].full() + Is[3].full() + 4)

assert err_rel(x.full(), x_ref) < 1e-6

def test_function_interpolate_multivariable():
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def test_qtt(dtype):
xq = x.to_qtt()
xx = xq.qtt_to_tens(N)

assert np.abs((x-xx).norm(True)/x.norm(True)) < 1e-12, 'TT to QTT and back not working.'
assert tn.abs((x-xx).norm(True)/x.norm(True)) < 1e-12, 'TT to QTT and back not working.'

@pytest.mark.parametrize("dtype", parameters)
def test_reshape(dtype):
Expand Down
30 changes: 20 additions & 10 deletions torchtt/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _LU(M):
def _max_matrix(M):

values, indices = M.flatten().topk(1)
indices = [np.unravel_index(i, M.shape) for i in indices]
indices = [tn.unravel_index(i, M.shape) for i in indices]
return values, indices


Expand Down Expand Up @@ -166,7 +166,7 @@ def function_interpolate(function, x, eps=1e-9, start_tens=None, nswp=20, kick=2
rnew = min(N[k]*rank[k+1], rank[k])
Jk = _maxvol(core)
# print(Jk)
tmp = np.unravel_index(Jk[:rnew], (rank[k+1], N[k]))
tmp = tn.unravel_index(Jk[:rnew], (rank[k+1], N[k]))
# if k==d-1:
# idx_new = tn.tensor(tmp[1].reshape([1,-1]))
# else:
Expand Down Expand Up @@ -305,7 +305,7 @@ def function_interpolate(function, x, eps=1e-9, start_tens=None, nswp=20, kick=2
_, Ps[k+1] = QR(tn.reshape(tmp, [rank[k]*N[k], rank[k+1]]))

# calc Idx
tmp = np.unravel_index(idx[:rank[k+1]], (rank[k], N[k]))
tmp = tn.unravel_index(idx[:rank[k+1]], (rank[k], N[k]))
idx_new = tn.tensor(
np.hstack((Idx[k][tmp[0], :], tmp[1].reshape([-1, 1]))))
Idx[k+1] = idx_new+0
Expand Down Expand Up @@ -423,7 +423,7 @@ def function_interpolate(function, x, eps=1e-9, start_tens=None, nswp=20, kick=2
_, tmp = QR(tn.reshape(tmp, [rank[k+1], -1]).t())
Ps[k+1] = tmp
# calc Idx
tmp = np.unravel_index(idx[:rank[k+1]], (N[k+1], rank[k+2]))
tmp = tn.unravel_index(idx[:rank[k+1]], (N[k+1], rank[k+2]))
idx_new = tn.tensor(
np.vstack((tmp[0].reshape([1, -1]), Idx[k+2][:, tmp[1]])))
Idx[k+1] = idx_new+0
Expand Down Expand Up @@ -511,7 +511,7 @@ def dmrg_cross(function, N, eps=1e-9, nswp=10, x_start=None, kick=2, dtype=tn.fl
rnew = min(N[k]*rank[k+1], rank[k])
Jk = _maxvol(core)
# print(Jk)
tmp = np.unravel_index(Jk[:rnew], (rank[k+1], N[k]))
tmp = tn.unravel_index(Jk[:rnew], (rank[k+1], N[k]))
# if k==d-1:
# idx_new = tn.tensor(tmp[1].reshape([1,-1]))
# else:
Expand Down Expand Up @@ -568,6 +568,12 @@ def dmrg_cross(function, N, eps=1e-9, nswp=10, x_start=None, kick=2, dtype=tn.fl
supercore = tn.reshape(function(eval_index), [
rank[k], N[k], N[k+1], rank[k+2]])
n_eval += eval_index.shape[0]
else:
supercore = tn.zeros(eval_index.shape[0], dtype=dtype, device=device)
for ind in range(eval_index.shape[0]):
supercore[ind] = function(*eval_index[ind,:])
supercore = tn.reshape(supercore, [rank[k], N[k], N[k+1], rank[k+2]])
n_eval += eval_index.shape[0]

# multiply with P_k left and right
supercore = tn.einsum('ij,jklm,mn->ikln',
Expand Down Expand Up @@ -633,7 +639,7 @@ def dmrg_cross(function, N, eps=1e-9, nswp=10, x_start=None, kick=2, dtype=tn.fl
_, Ps[k+1] = QR(tn.reshape(tmp, [rank[k]*N[k], rank[k+1]]))

# calc Idx
tmp = np.unravel_index(idx[:rank[k+1]], (rank[k], N[k]))
tmp = tn.unravel_index(idx[:rank[k+1]], (rank[k], N[k]))
idx_new = tn.tensor(
np.hstack((Idx[k][tmp[0], :], tmp[1].reshape([-1, 1]))))
Idx[k+1] = idx_new+0
Expand Down Expand Up @@ -662,6 +668,12 @@ def dmrg_cross(function, N, eps=1e-9, nswp=10, x_start=None, kick=2, dtype=tn.fl
supercore = tn.reshape(function(eval_index).to(dtype=dtype), [
rank[k], N[k], N[k+1], rank[k+2]])
n_eval += eval_index.shape[0]
else:
supercore = tn.zeros(eval_index.shape[0], dtype=dtype, device=device)
for ind in range(eval_index.shape[0]):
supercore[ind] = function(*eval_index[ind,:])
supercore = tn.reshape(supercore, [rank[k], N[k], N[k+1], rank[k+2]])
n_eval += eval_index.shape[0]

# multiply with P_k left and right
supercore = tn.einsum('ij,jklm,mn->ikln',
Expand All @@ -680,20 +692,18 @@ def dmrg_cross(function, N, eps=1e-9, nswp=10, x_start=None, kick=2, dtype=tn.fl
U = U[:, :rnew]
S = S[:rnew]
V = V[:rnew, :]
# print('kkt new',tn.linalg.norm([email protected](S)@V))

# kick the rank
U = U @ tn.diag(S)
VK = tn.randn((kick, V.shape[1]), dtype=dtype, device=device)
V, Rtemp = QR(tn.cat((V, VK), 0).t())
radd = V.shape[1] - rnew
radd = Rtemp.shape[1] - rnew
if radd > 0:
U = tn.cat(
(U, tn.zeros((U.shape[0], radd), dtype=dtype, device=device)), 1)
U = U @ Rtemp.T
V = V.t()

# print('kkt new',tn.linalg.norm(supercore-U@V))
# compute err (dx)
super_prev = tn.einsum('ijk,kmn->ijmn', cores[k], cores[k+1])
super_prev = tn.einsum(
Expand Down Expand Up @@ -731,7 +741,7 @@ def dmrg_cross(function, N, eps=1e-9, nswp=10, x_start=None, kick=2, dtype=tn.fl
_, tmp = QR(tn.reshape(tmp, [rank[k+1], -1]).t())
Ps[k+1] = tmp
# calc Idx
tmp = np.unravel_index(idx[:rank[k+1]], (N[k+1], rank[k+2]))
tmp = tn.unravel_index(idx[:rank[k+1]], (N[k+1], rank[k+2]))
idx_new = tn.tensor(
np.vstack((tmp[0].reshape([1, -1]), Idx[k+2][:, tmp[1]])))
Idx[k+1] = idx_new+0
Expand Down

0 comments on commit 29e7ce9

Please sign in to comment.