Skip to content

Commit

Permalink
Merge pull request #35 from OpenFreeEnergy/issue_33_vis_refactor
Browse files Browse the repository at this point in the history
visualisation review/refactor
  • Loading branch information
RiesBen authored May 6, 2024
2 parents f9a60a9 + 2b394b3 commit 6da9ed2
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 43 deletions.
1 change: 1 addition & 0 deletions src/konnektor/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .color_schemes import OFE_COLORS
from .visualization import draw_ligand_network
from .widget import draw_network_widget
9 changes: 9 additions & 0 deletions src/konnektor/visualization/color_schemes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
OFE_COLORS = (
(49 / 256, 57 / 256, 77 / 256), # Badass Blue
(184 / 256, 87 / 256, 65 / 256), # Feeling spicy
(0, 147 / 256, 132 / 256), # Feeling sick
(217 / 256, 196 / 256, 177 / 256), # Beastly grey
(217 / 256, 196 / 256, 177 / 256), # Sandy Sergio
(238 / 256, 192 / 256, 68 / 256), # Gold
(0 / 256, 47 / 256, 74 / 256), # otherBlue
)
48 changes: 23 additions & 25 deletions src/konnektor/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,61 +2,59 @@
import networkx as nx
from matplotlib import pyplot as plt

ofe_colors = [(49 / 256, 57 / 256, 77 / 256), # Badass Blue
(184 / 256, 87 / 256, 65 / 256), # Feeling spicy
(0, 147 / 256, 132 / 256), # Feeling sick
(217 / 256, 196 / 256, 177 / 256), # Beastlygrey
(217 / 256, 196 / 256, 177 / 256), # Sandy Sergio
(238/256, 192/256, 68/256), #GOld
(0 / 256, 47 / 256, 74 / 256), ] # otherBlue]

def color_gradient(c1=ofe_colors[1], c2=ofe_colors[2], c3=ofe_colors[1], mix=0):
c1=np.array(c1)
c2=np.array(c2)
c3=np.array(c3)
from . import OFE_COLORS


def color_gradient(c1=OFE_COLORS[1], c2=OFE_COLORS[2], c3=OFE_COLORS[1], mix=0):
c1 = np.array(c1)
c2 = np.array(c2)
c3 = np.array(c3)
mix = np.array(mix, ndmin=1)

if(mix > 0.5):
if mix > 0.5:
m = mix-0.5
c = (0.5-m)*c2 + m*c3
else:
m = mix
c = (0.5-m)*c1 + m*c2
return c

def get_node_connectivities(cg):
return [sum([n in e for e in cg.edges]) for n in cg.nodes]

def get_node_connectivities(cg) -> list[int]:
"""The connectivity of each node"""
# TODO: why is this summing n?
# shouldn't it be [sum(1 for e in edges_of(n)) for n in cg.nodes]
return [sum([n in e for e in cg.edges]) for n in cg.nodes]


def draw_ligand_network(network, title="", ax=None, node_size=2050, edge_width=3, fontsize=18):

ligands = list(network.nodes)
edge_map = {(m.componentA.name, m.componentB.name): m for m in network.edges}
edges = list(sorted(edge_map.keys()))
weights = [edge_map[k].annotations['score'] for k in edges]

#g = network.graph
g = nx.Graph()
[g.add_node(n.name) for n in ligands]
g.add_weighted_edges_from(ebunch_to_add=[(e[0], e[1], w) for e,w in zip(
edges,weights)])
for n in ligands:
g.add_node(n.name)
g.add_weighted_edges_from(ebunch_to_add=[(e[0], e[1], w)

# graph vis layout
pos = nx.spring_layout(g, weight=1)
# graph vis layout
pos = nx.spring_layout(g, weight=1)
for e, w in zip(edges, weights)])

if(ax is None):
fig, ax = plt.subplots(figsize=[16,9])
if ax is None:
fig, ax = plt.subplots(figsize=[16, 9])
else:
fig=None
fig = None

connectivities = np.array(get_node_connectivities(network))
mixins = np.clip(connectivities / (sum(connectivities)/len(connectivities)), a_min=0, a_max=2)/2

cs = list(map(lambda x: color_gradient(mix=x), mixins))

nx.draw_networkx(g, pos=pos, with_labels=True, ax=ax, node_size=node_size, width=edge_width,
node_color=cs, edge_color=ofe_colors[3], font_color=[1,1,1])
node_color=cs, edge_color=OFE_COLORS[3], font_color=[1, 1, 1])
ax.set_title(title, fontsize=fontsize) #+" #edges "+str(len(g.edges))

return fig
55 changes: 37 additions & 18 deletions src/konnektor/visualization/widget.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gufe
import networkx as nx
import numpy as np
from urllib import parse
Expand All @@ -9,18 +10,12 @@

from gufe.visualization.mapping_visualization import draw_mapping

ofe_colors = [(49 / 256, 57 / 256, 77 / 256), # Badass Blue
(184 / 256, 87 / 256, 65 / 256), # Feeling spicy
(0, 147 / 256, 132 / 256), # Feeling sick
(217 / 256, 196 / 256, 177 / 256), # Beastlygrey
(217 / 256, 196 / 256, 177 / 256), # Sandy Sergio
(238 / 256, 192 / 256, 68 / 256), # GOld
(0 / 256, 47 / 256, 74 / 256), ] # otherBlue]
from . import OFE_COLORS

rgb2hex = lambda r, g, b: '#%02x%02x%02x' % (int(r * 256), int(g * 256), int(b * 256))


def color_gradient(c1=ofe_colors[1], c2=ofe_colors[2], c3=ofe_colors[1], mix=0):
def color_gradient(c1=OFE_COLORS[1], c2=OFE_COLORS[2], c3=OFE_COLORS[1], mix=0):
c1 = np.array(c1)
c2 = np.array(c2)
c3 = np.array(c3)
Expand Down Expand Up @@ -72,27 +67,37 @@ def map2svg(mapping):
return impath


def build_cytoscape(network, layout="concentric", show_molecules=True, show_mappings=False):
def build_cytoscape(network: gufe.LigandNetwork, layout="concentric", show_molecules=True, show_mappings=False):
ligands = list(network.nodes)
edge_map = {(m.componentA.name, m.componentB.name): m for m in network.edges}
edges = list(sorted(edge_map.keys()))
weights = [edge_map[k].annotations['score'] for k in edges]

connectivities = np.array(get_node_connectivities(network))
if(len(connectivities) == 0):
mixins=np.array([0])
if not len(connectivities):
mixins = np.array([0])
cs = list(map(lambda x: color_gradient(mix=x), mixins))

else:
mixins = np.clip(connectivities / (sum(connectivities) / len(connectivities)), a_min=0, a_max=2) / 2
cs = list(map(lambda x: color_gradient(mix=x), mixins))

# build a graph
g = nx.Graph()

[g.add_node(n.name, name=n.name, classes="ligand", img=mol2svg(n.to_rdkit()), col=c) for n, c in zip(ligands, cs)]
[g.add_node(f"{e[0]}-{e[1]}", classes="mapping", name=f"{e[0]}-{e[1]}", lab=f"{e[0]} - {e[1]}\nscore: {w:2.2F}",
weight="{:2.2F}".format(w), img=map2svg(edge_map[e])) for e, w in zip(edges, weights)]
g.add_nodes_from(
(n.name,
{'name': n.name, 'classes': 'ligand',
'img': mol2svg(n.to_rdkit()), 'col': c},
)
for n, c in zip(ligands, cs)
)
g.add_nodes_from(
(f'{e[0]}-{e[1]}',
{'name': f'{e[0]}-{e[1]}', 'classes': 'mapping',
'lab': f'{e[0]} - {e[1]}\nscore: {w:2.2F}',
'weight': f'{w:2.2F}', 'img': map2svg(edge_map[e])},
)
for e, w in zip(edges, weights)
)

for e, w in zip(edges, weights):
g.add_edge(e[0], f"{e[0]}-{e[1]}")
Expand Down Expand Up @@ -187,9 +192,23 @@ def build_cytoscape(network, layout="concentric", show_molecules=True, show_mapp
return undirected


def draw_network_widget(network, layout="cose", show_molecules=True, show_mappings=False, ):
def draw_network_widget(network: gufe.LigandNetwork, layout="cose", show_molecules=True, show_mappings=False):
"""For use in a jupyter noterbook, visualise a LigandNetwork
Parameters
----------
network: gufe.LigandNetwork
the network to visualise
layout : str, optional
how to initially layout the nodes, can be one of X/Y/Z
defaults to 'cose'
show_molecule: bool, optional
if to show molecule images on the representation, default True
show_mappings: bool, optional
if to show mapping images on the representation, default False
"""
@interact(network=fixed(network), layout=['dagre', 'cola', 'breadthfirst',
'circular', 'preset', 'concentric', 'cose'])
'concentric', 'cose'])
def interactive_widget(network=network, layout=layout, show_molecules=show_molecules, show_mappings=show_mappings):
v = build_cytoscape(network=network, layout=layout, show_molecules=show_molecules, show_mappings=show_mappings)
return v
Expand Down

0 comments on commit 6da9ed2

Please sign in to comment.