Skip to content

Commit

Permalink
Refactor to align with published works
Browse files Browse the repository at this point in the history
  • Loading branch information
RadostW committed Jan 29, 2025
1 parent acbba93 commit d677e01
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
20 changes: 10 additions & 10 deletions pygrpy/grpy_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def mu(centres,radii,blockmatrix = False):
#indicies: bead, bead, coord, coord
muTT = np.empty([n,n,3,3]) #translation-translation
muRR = np.empty([n,n,3,3]) #rotation-rotation
muRT = np.empty([n,n,3,3]) #rotation-translation coupling
muTR = np.empty([n,n,3,3]) #translation-rotation coupling

for i in range(0,n):
for j in range(0,n):
Expand All @@ -157,8 +157,8 @@ def mu(centres,radii,blockmatrix = False):
RRidentityScale = (1.0 / (8 * math.pi * (aSmall**3)))
RRrHatScale = 0.0

# Rotation-translation
RTScale = 0.0
# Translation-rotation
TRScale = 0.0

elif distances[i][j] > a[i]+a[j]: #Far apart
# Translation-translation
Expand All @@ -169,8 +169,8 @@ def mu(centres,radii,blockmatrix = False):
RRidentityScale = (-1.0 / (16.0 * math.pi * (distances[i][j]**3)))
RRrHatScale = (1.0 / (16.0 * math.pi * (distances[i][j]**3)))*3

# Rotation-translation
RTScale = (1.0 / (8 * math.pi * (distances[i][j]**2) ))
# Translation-rotation
TRScale = (1.0 / (8 * math.pi * (distances[i][j]**2) ))

elif distances[i][j] > aBig - aSmall and distances[i][j] <= a[i]+a[j]: #Close together
# Translation-translation
Expand All @@ -187,21 +187,21 @@ def mu(centres,radii,blockmatrix = False):
RRidentityScale = (1.0 / (8.0 * math.pi * (a[i]**3) * (a[j]**3))) * mathcalA
RRrHatScale = (1.0 / (8.0 * math.pi * (a[i]**3) * (a[j]**3))) * mathcalB

# Rotation-translation
RTScale = (1.0 / (16.0 * math.pi * (a[j]**3) * a[i])) * ( ( ((a[j] - a[i] + distances[i][j])**2)*(a[i]**2+2.0*a[i]*(a[j]+distances[i][j])-3.0*((a[j]-distances[i][j])**2)) ) / (8.0 * (distances[i][j]**2)))
# Translation-rotation
TRScale = (1.0 / (16.0 * math.pi * (a[j]**3) * a[i])) * ( ( ((a[j] - a[i] + distances[i][j])**2)*(a[i]**2+2.0*a[i]*(a[j]+distances[i][j])-3.0*((a[j]-distances[i][j])**2)) ) / (8.0 * (distances[i][j]**2)))

else:
raise NotImplementedError("One bead entirely inside another")

# GRPY approximation is of form scalar * matrix + scalar * matrix
muTT[i,j,:,:] = TTidentityScale * np.identity(3) + TTrHatScale * np.outer(rHatMatrix[i][j],rHatMatrix[i][j])
muRR[i,j,:,:] = RRidentityScale * np.identity(3) + RRrHatScale * np.outer(rHatMatrix[i][j],rHatMatrix[i][j])
muRT[i,j,:,:] = RTScale * _epsilonVec(rHatMatrix[i][j])
muTR[i,j,:,:] = TRScale * _epsilonVec(rHatMatrix[i][j])

if blockmatrix:
return np.array([[muTT,muRT],[_transTranspose(muRT),muRR]])
return np.array([[muTT,muTR],[_transTranspose(muTR),muRR]])
else:
return np.hstack(np.hstack(np.hstack(np.hstack(np.array([[muTT,muRT],[_transTranspose(muRT),muRR]])))))
return np.hstack(np.hstack(np.hstack(np.hstack(np.array([[muTT,muTR],[_transTranspose(muTR),muRR]])))))



Expand Down
14 changes: 7 additions & 7 deletions pygrpy/jax_grpy_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ def mu(centres, radii):
muRRrHatScaleClose = (1.0 / (8.0 * math.pi * (ai ** 3) * (aj ** 3))) * mathcalB

# ### coupling matricies
muRTScaleDiag = 0.0
muTRScaleDiag = 0.0

muRTScaleFar = 1.0 / (8 * math.pi * (dist ** 2))
muTRScaleFar = 1.0 / (8 * math.pi * (dist ** 2))

muRTScaleClose = (1.0 / (16.0 * math.pi * (aj ** 3) * ai)) * ( ( ((aj - ai + dist) ** 2) * ( ai ** 2 + 2.0 * ai * (aj + dist) - 3.0 * ((aj - dist) ** 2))) / (8.0 * (dist ** 2)) )
muTRScaleClose = (1.0 / (16.0 * math.pi * (aj ** 3) * ai)) * ( ( ((aj - ai + dist) ** 2) * ( ai ** 2 + 2.0 * ai * (aj + dist) - 3.0 * ((aj - dist) ** 2))) / (8.0 * (dist ** 2)) )

# solution branch indicators
isFar = 1.0*(dist > ai + aj)
Expand All @@ -129,7 +129,7 @@ def mu(centres, radii):
muRRidentityScale = isDiag * muRRidentityScaleDiag + (1.0 - isDiag) * (isFar * muRRidentityScaleFar + (1.0 - isFar) * muRRidentityScaleClose)
muRRrHatScale = (1.0 - isDiag) * (isFar * muRRrHatScaleFar + (1.0 - isFar) * muRRrHatScaleClose)

muRTScale = (1.0 - isDiag) * (isFar * muRTScaleFar + (1.0 - isFar) * muRTScaleClose)
muTRScale = (1.0 - isDiag) * (isFar * muTRScaleFar + (1.0 - isFar) * muTRScaleClose)

# construct large matricies
muTT = (
Expand All @@ -140,13 +140,13 @@ def mu(centres, radii):
muRRidentityScale[:,:,jnp.newaxis,jnp.newaxis] * jnp.identity(3)[jnp.newaxis,jnp.newaxis,:,:]
+ muRRrHatScale[:,:,jnp.newaxis,jnp.newaxis] * rHatMatrix[:,:,jnp.newaxis,:] * rHatMatrix[:,:,:,jnp.newaxis]
)
muRT = (
muRTScale[:,:,jnp.newaxis,jnp.newaxis] * epsilonRHatMatrix[:,:,:,:]
muTR = (
muTRScale[:,:,jnp.newaxis,jnp.newaxis] * epsilonRHatMatrix[:,:,:,:]
)

# flatten (2,2,n,n,3,3) tensor in the correct order
return jax.lax.reshape(
jnp.array([[muTT,muRT],[_transTranspose(muRT),muRR]]),
jnp.array([[muTT,muTR],[_transTranspose(muTR),muRR]]),
(6*n,6*n),
dimensions = (0,2,4,1,3,5)
)
Expand Down

0 comments on commit d677e01

Please sign in to comment.