Skip to content

Commit

Permalink
Implement label to polyhedron
Browse files Browse the repository at this point in the history
  • Loading branch information
bobleesj committed Jun 18, 2024
1 parent 41338ca commit 64f53b8
Show file tree
Hide file tree
Showing 17 changed files with 280 additions and 22 deletions.
67 changes: 67 additions & 0 deletions example.ipynb

Large diffs are not rendered by default.

96 changes: 80 additions & 16 deletions src/cifkit/figures/polyhedron.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,85 @@
import os
import pyvista as pv
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull
from scipy.spatial import Delaunay


def plot(points, labels, file_path, is_displayed, output_dir=None):
def generate_contrasting_colors():
"""Return a list of manually selected contrasting colors."""
return [
"#00FFFF", # Cyan
"#0000FF", # Blue
"#FF0000", # Red
"#800080", # Purple
"#FF00FF", # Magenta
"#FFFF00", # Yellow
"#00FF00", # Lime
"#FF4500", # Orange Red
"#2E8B57", # Sea Green
"#1E90FF", # Dodger Blue
"#FF1493", # Deep Pink
"#FFD700", # Gold
]


def generate_color_mapping(labels):
"""Generate a dictionary mapping labels to contrasting colors."""
colors = generate_contrasting_colors()
color_map = {}
unique_labels = set(labels)
for i, label in enumerate(unique_labels):
color_map[label] = colors[i % len(colors)]
return color_map


def plot(points, vertex_labels, file_path, formula, is_displayed, output_dir=None):
"""
Generate and save a 3D plot of a molecular structure.
"""
plotter = pv.Plotter(off_screen = not is_displayed)

points = np.array(points)
central_atom_coord = points[-1]
central_atom_label = labels[-1]

# plotter = pv.Plotter()
plotter = pv.Plotter(off_screen=not is_displayed, window_size=(1600, 1200))
label_colors = generate_color_mapping(vertex_labels)

for point, label in zip(points, labels):
radius = 0.6 if np.array_equal(point, central_atom_coord) else 0.4 # Central atom larger
points = np.array(points)
central_atom_coord = points[-1]
central_atom_label = vertex_labels[-1]
# Coordination numbers
coordination_number = len(points) - 1

# Title
title = f"Formula: {formula}, Central atom: {central_atom_label}, CN: {coordination_number},\n{file_path}"
plotter.add_title(title, font="arial")
# Constructing title and subtitle

for idx, (point, label) in enumerate(zip(points, vertex_labels)):
radius = (
0.4 if np.array_equal(point, central_atom_coord) else 0.4
) # Central atom larger
sphere = pv.Sphere(radius=radius, center=point)
plotter.add_mesh(sphere, color='#D3D3D3') # Light grey color

plotter.add_mesh(sphere, color=label_colors[label])

# Add labels with index
indexed_label = (
f"{idx + 1}. {label}" # Creating a label with numbering
)
adjusted_point = point + [
0.3,
0.3,
0.3,
] # Offset to avoid overlapping with the sphere
if idx != len(points) - 1:
plotter.add_point_labels(
adjusted_point,
[indexed_label], # Use the indexed label
font_size=50,
text_color=label_colors[label],
always_visible=True,
shape=None,
margin=0,
reset_camera=False,
)

delaunay = Delaunay(points)
hull = ConvexHull(points)
Expand All @@ -43,23 +101,27 @@ def plot(points, labels, file_path, is_displayed, output_dir=None):
for edge in edges:
if edge in hull_edges:
start, end = points[edge[0]], points[edge[1]]
cylinder = pv.Cylinder(center=(start + end) / 2, direction=end - start, radius=0.05, height=np.linalg.norm(end - start))
plotter.add_mesh(cylinder, color='grey')
cylinder = pv.Cylinder(
center=(start + end) / 2,
direction=end - start,
radius=0.05,
height=np.linalg.norm(end - start),
)
plotter.add_mesh(cylinder, color="grey")

faces = []
for simplex in hull.simplices:
faces.append([3] + list(simplex))
poly_data = pv.PolyData(points, faces)

plotter.add_mesh(poly_data, color='aqua', opacity=0.5, show_edges=True)
plotter.add_mesh(poly_data, color="aqua", opacity=0.5, show_edges=True)

plotter.show()


"""
Output
"""

# Determine the output directory based on provided path
if not output_dir:
output_dir = os.path.join(os.path.dirname(file_path), "polyhedrons")
Expand All @@ -82,3 +144,5 @@ def plot(points, labels, file_path, is_displayed, output_dir=None):
"""
# Save the screenshot
plotter.screenshot(save_path)


11 changes: 9 additions & 2 deletions src/cifkit/models/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,14 @@ def get_polyhedron_labels_by_CN_best_methods(

@ensure_connections
def plot_polyhedron(self, site_label, is_displayed=False, output_dir=None):
coords, labels = get_polyhedron_coordinates_labels(
coords, vertex_labels = get_polyhedron_coordinates_labels(
self.CN_connections_by_best_methods, site_label
)
polyhedron.plot(coords, labels, self.file_path, is_displayed, output_dir)
polyhedron.plot(
coords,
vertex_labels,
self.file_path,
self.formula,
is_displayed,
output_dir,
)
114 changes: 114 additions & 0 deletions test_polyhedron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# points = np.array(
# [
# [0.0, 0.0, 3.881],
# [0.0, 0.0, 0.0],
# [3.738, 2.158, 1.94],
# [3.738, -2.158, 1.94],
# [4.43, 0.0, 0.0],
# [4.43, 0.0, 3.881],
# [-0.936, 1.622, 1.94],
# [-0.936, -1.622, 1.94],
# [1.523, 2.638, 0.0],
# [1.523, -2.638, 0.0],
# [1.523, -2.638, 3.881],
# [1.523, 2.638, 3.881],
# [1.873, 0.0, -1.94],
# [1.873, 0.0, 5.821],
# [1.873, 0.0, 1.94],
# ]
# )

# labels = [
# "Rh2",
# "Rh2",
# "Rh1",
# "Rh1",
# "U1",
# "U1",
# "In1",
# "In1",
# "U1",
# "U1",
# "U1",
# "U1",
# "In1",
# "In1",
# "In1",
# ]


# colors = (
# generate_contrasting_colors()
# ) # or generate_color_palette_from_colormaps(len(labels))

# label_colors = {label: color for label, color in zip(labels, colors)}


# plotter = pv.Plotter(window_size=(1600, 1200))

# central_atom_index = np.argmin(np.linalg.norm(points, axis=1))
# central_atom = points[central_atom_index]

# for idx, (point, label) in enumerate(zip(points, labels)):
# radius = (
# 0.4 if np.array_equal(point, central_atom) else 0.4
# ) # Central atom larger
# sphere = pv.Sphere(radius=radius, center=point)
# plotter.add_mesh(sphere, color=label_colors[label])

# # Add labels with index
# indexed_label = f"{idx + 1}. {label}" # Creating a label with numbering
# adjusted_point = point + [
# 0.3,
# 0.3,
# 0.3,
# ] # Offset to avoid overlapping with the sphere
# if idx != len(points) - 1:
# plotter.add_point_labels(
# adjusted_point,
# [indexed_label], # Use the indexed label
# font_size=50,
# text_color=label_colors[label],
# always_visible=True,
# shape=None,
# margin=0,
# reset_camera=False,
# )


# delaunay = Delaunay(points)
# hull = ConvexHull(points)

# edges = set()
# for simplex in delaunay.simplices:
# for i in range(4):
# for j in range(i + 1, 4):
# edge = tuple(sorted([simplex[i], simplex[j]]))
# edges.add(edge)

# hull_edges = set()
# for simplex in hull.simplices:
# for i in range(len(simplex)):
# for j in range(i + 1, len(simplex)):
# hull_edge = tuple(sorted([simplex[i], simplex[j]]))
# hull_edges.add(hull_edge)

# for edge in edges:
# if edge in hull_edges:
# start, end = points[edge[0]], points[edge[1]]
# cylinder = pv.Cylinder(
# center=(start + end) / 2,
# direction=end - start,
# radius=0.05,
# height=np.linalg.norm(end - start),
# )
# plotter.add_mesh(cylinder, color="grey")

# faces = []
# for simplex in hull.simplices:
# faces.append([3] + list(simplex))
# poly_data = pv.PolyData(points, faces)

# plotter.add_mesh(poly_data, color="aqua", opacity=0.3, show_edges=True)

# plotter.show()
14 changes: 10 additions & 4 deletions tests/core/models/test_cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from cifkit.utils.error_messages import CifParserError
from cifkit.utils import folder


@pytest.mark.fast
def test_cif_static_properties(cif_URhIn):
assert cif_URhIn.file_path == "tests/data/cif/URhIn.cif"
Expand Down Expand Up @@ -363,33 +364,38 @@ def test_plot_polyhedron_with_output_folder_given(cif_URhIn):
os.makedirs(expected_output_dir)

# Define the output file path
cif_URhIn.plot_polyhedron("In1", output_dir="tests/data/cif/polyhedrons_user")
cif_URhIn.plot_polyhedron(
"In1", is_displayed=True, output_dir="tests/data/cif/polyhedrons_user"
)

assert os.path.exists(output_file_path)
assert os.path.getsize(output_file_path) > 1024

shutil.rmtree(expected_output_dir)


def test_plot_polyhedrons(cif_ensemble_test):
# Define the directory to store the output
expected_output_dir = "tests/data/cif/ensemble_test/polyhedrons"
# Ensure the directory exists
if not os.path.exists(expected_output_dir):
os.makedirs(expected_output_dir)

cifs = cif_ensemble_test.cifs
for cif in cifs:
labels = cif.site_labels
print("Label from the cif object")
print(labels)
for label in labels:
cif.plot_polyhedron(label)


# Check the number of files
image_file_count = folder.get_file_count(expected_output_dir, ".png")
assert image_file_count == 12

shutil.rmtree(expected_output_dir)


"""
Test error during init
"""
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 64f53b8

Please sign in to comment.