Skip to content

Commit

Permalink
Add testing framework (#25)
Browse files Browse the repository at this point in the history
* add fixture

* add test files

* add pytest to dev env

* setup logging in main

* format

* format

* add debug statements

* setup for future testing

* setup for future testing

* test analyze functions in quantify

* move fixture

* import fixtures

* make tests a python package

* add workflow for tests

* fix workflow

* set logging level to INFO

* format
  • Loading branch information
eberrigan authored Jan 14, 2025
1 parent 9c728d2 commit b46fd64
Show file tree
Hide file tree
Showing 10 changed files with 305 additions and 32 deletions.
43 changes: 43 additions & 0 deletions .github/workflows/test-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
name: CI Workflow

on:
push:
branches:
- '**'
pull_request:
types: [opened, reopened, synchronize]

jobs:
tests:
name: Tests (${{ matrix.os }})
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: ["ubuntu-22.04", "windows-2022", "macos-14"]
include:
- env_file: environment.yaml

steps:
- name: Checkout repo
uses: actions/checkout@v4

- name: Setup Conda
uses: conda-incubator/[email protected]
with:
miniforge-version: latest
conda-solver: "libmamba"
environment-file: ${{ matrix.env_file }}
activate-environment: ariadne_dev

- name: Print environment info
shell: bash -l {0}
run: |
which python
conda info
conda list
pip freeze
- name: Run tests
shell: bash -l {0}
run: pytest tests/
1 change: 1 addition & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies:
- numpy
- scipy
- matplotlib
- pytest # for testing
- pip
- pip:
- -e .
54 changes: 35 additions & 19 deletions src/ariadne_roots/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import copy
import networkx as nx
import json
import logging

from pathlib import Path
from queue import Queue
Expand All @@ -28,14 +29,20 @@
from ariadne_roots import quantify


# Set up logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)


class StartupUI:
"""Startup window interface."""

def __init__(self, base):
self.base = base
self.base.geometry("350x200")


# master frame
self.frame = tk.Frame(self.base)
self.frame.pack(side="top", fill="both", expand=True)
Expand Down Expand Up @@ -123,7 +130,7 @@ def __init__(self, base):
self.menu, text="Zoom In (+)", command=None, state="disabled"
)
self.button_zoom_out = tk.Button(
self.menu, text="Zoom Out (-)", command=None, state="disabled"
self.menu, text="Zoom Out (-)", command=None, state="disabled"
)
self.button_import.pack(fill="x", side="top")
self.button_prev.pack(fill="x", side="top")
Expand Down Expand Up @@ -208,11 +215,14 @@ def __init__(self, base):

def ask_zoom_factor(self):
"""Prompt user to select a zoom factor after importing the first image."""

def on_ok():
try:
# Retrieve the scale factor from the statusbar
zoom = self.scale_factor
print(f"Zoom factor selected: {zoom}") # Debug print, remove if unnecessary
print(
f"Zoom factor selected: {zoom}"
) # Debug print, remove if unnecessary
self.zoom_factor = zoom # Store zoom factor
zoom_popup.destroy()
except ValueError:
Expand All @@ -235,7 +245,6 @@ def on_cancel():
cancel_button = tk.Button(zoom_popup, text="Cancel", command=on_cancel)
cancel_button.pack(side="right", padx=20)


def click_info(self, event):
"""Show node metadata on right click (for debugging)."""
for n in self.tree.nodes: # check click proximity to existing points
Expand Down Expand Up @@ -298,7 +307,6 @@ def import_image(self):

self.frame_id = self.canvas.create_image(0, 0, image=self.img, anchor="nw")


# create gif iterator for pagination
self.iterframes = ImageSequence.Iterator(self.file)
self.frame_index = 0
Expand Down Expand Up @@ -343,9 +351,8 @@ def import_image(self):
self.canvas.bind("-", self.zoom_out)

# Prompt user to choose zoom factor after importing the first image
#if not hasattr(self, 'zoom_factor'):
#self.ask_zoom_factor()

# if not hasattr(self, 'zoom_factor'):
# self.ask_zoom_factor()

def change_frame(self, next_index):
"""Move frames in the GIF."""
Expand Down Expand Up @@ -494,8 +501,6 @@ def insert(self, event=None):
if not self.prox_override:
self.override()



def change_root(self, event=None):
"""Clear current tree, prompt for a new root, and reinitialize."""

Expand All @@ -517,19 +522,18 @@ def change_root(self, event=None):
# Prompt for a new plant ID assignment and create a new tree
self.tree.popup(self.base)


# Zoom function
#zoom in
# Zoom function
# zoom in
def zoom_in(self):
self.scale_factor *= 1.5 # Increase scale
self.update_image()
self.update_statusbar() # Update the status bar with zoom info
self.update_statusbar() # Update the status bar with zoom info

# Zoom out
def zoom_out(self):
self.scale_factor /= 1.5 # Decrease scale
self.update_image()
self.update_statusbar() # Update the status bar with zoom info
self.update_statusbar() # Update the status bar with zoom info

def update_image(self):
"""Update the image on the canvas based on the scale factor."""
Expand Down Expand Up @@ -557,8 +561,8 @@ def update_statusbar(self):
"""Update the status bar text with scale factor and other information."""
self.statusbar.config(
text=f"{self.canvas.curr_coords}, {self.day_indicator}, "
f"{self.override_indicator}, {self.inserting_indicator}, "
f"Zoom Scale: {self.scale_factor}"
f"{self.override_indicator}, {self.inserting_indicator}, "
f"Zoom Scale: {self.scale_factor}"
)

def draw_edge(self, parent_node, child_node):
Expand Down Expand Up @@ -808,6 +812,7 @@ def make_file(self, event=None):
json.dump(s, h)
print(f"wrote to output {output_name}")


class Node:
"""An (x,y,0) point along a root."""

Expand Down Expand Up @@ -918,7 +923,6 @@ def updater():

base.wait_window(top) # wait for a button to be pressed


##########################
def insert_child(self, current_node, new):
"""Assign child when using insertion mode."""
Expand Down Expand Up @@ -977,7 +981,9 @@ def index_LRs(self):
current_node = q.get()
# arbitrarily, we assign LR indices left-to-right
# sort by x-coordinate
current_node_children = sorted(current_node.children, key=lambda x: x.relcoords[0])
current_node_children = sorted(
current_node.children, key=lambda x: x.relcoords[0]
)

for n in current_node_children:
if n.root_degree is None: # only index nodes that haven't been already
Expand Down Expand Up @@ -1079,6 +1085,16 @@ def import_file(self):
w = csv.DictWriter(csvfile, fieldnames=results.keys())
w.writerow(results)

# debug
logging.debug(f"Total root length: {results['Total root length']}")
logging.debug(f"Travel distance: {results['Travel distance']}")
logging.debug(
f"Total root length (random): {results['Total root length (random)']}"
)
logging.debug(
f"Travel distance (random): {results['Travel distance (random)']}"
)

# make pareto plot and save
quantify.plot_all(
front,
Expand Down
34 changes: 21 additions & 13 deletions src/ariadne_roots/pareto_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ def pareto_cost(total_root_length, total_travel_distance, alpha):
return cost


def pareto_cost_3d_path_tortuosity(total_root_length, total_travel_distance, total_path_coverage, alpha, beta):
def pareto_cost_3d_path_tortuosity(
total_root_length, total_travel_distance, total_path_coverage, alpha, beta
):
"""
Computes the pareto cost.
Expand All @@ -193,21 +195,25 @@ def pareto_cost_3d_path_tortuosity(total_root_length, total_travel_distance, tot
When alpha = gamma = 0, beta = 1 => cost = total_travel_distance will be minimized
When beta = gamma = 0, alpha = 1 => cost = total_root_length will be minimized
total_root_length: the sum of the lengths of the edges in the root network
total_root_length: the sum of the lengths of the edges in the root network
(a.k.a. material cost, wiring cost)
total_travel_distance: the sum of the lengths of the shortest paths from every
lateral root tip to the base node of the network. (a.k.a. the satellite cost,
lateral root tip to the base node of the network. (a.k.a. the satellite cost,
conduction delay)
total_path_coverage: the sum of the tortuosity of all the root paths. The tortuosity per
path is defined as the ratio of the actual path length to the shortest path
length between the root and the root tip. The total root coverage is the sum of
total_path_coverage: the sum of the tortuosity of all the root paths. The tortuosity per
path is defined as the ratio of the actual path length to the shortest path
length between the root and the root tip. The total root coverage is the sum of
the tortuosity of all the root paths.
"""
assert 0 <= alpha <= 1
assert 0 <= beta <= 1

gamma = 1 - alpha - beta
cost = alpha * total_root_length + beta * total_travel_distance - gamma * total_path_coverage
cost = (
alpha * total_root_length
+ beta * total_travel_distance
- gamma * total_path_coverage
)

return cost

Expand Down Expand Up @@ -501,14 +507,14 @@ def pareto_steiner_fast_3d_path_tortuosity(G, alpha, beta):
When alpha = gamma = 0, beta = 1 => cost = total_travel_distance will be minimized
When beta = gamma = 0, alpha = 1 => cost = total_root_length will be minimized
total_root_length: the sum of the lengths of the edges in the root network
total_root_length: the sum of the lengths of the edges in the root network
(a.k.a. material cost, wiring cost)
total_travel_distance: the sum of the lengths of the shortest paths from every
lateral root tip to the base node of the network. (a.k.a. the satellite cost,
lateral root tip to the base node of the network. (a.k.a. the satellite cost,
conduction delay)
total_path_coverage: the sum of the tortuosity of all the root paths. The tortuosity per
path is defined as the ratio of the actual path length to the shortest path
length between the root and the root tip. The total root coverage is the sum of
total_path_coverage: the sum of the tortuosity of all the root paths. The tortuosity per
path is defined as the ratio of the actual path length to the shortest path
length between the root and the root tip. The total root coverage is the sum of
the tortuosity of all the root paths.
The algorithm uses a greedy approach: always take the edge that will reduce the
Expand Down Expand Up @@ -703,7 +709,9 @@ def pareto_steiner_fast_3d_path_tortuosity(G, alpha, beta):
H.nodes[n2]["distance_to_base"] = (
node_dist(H, n2, u) + H.nodes[u]["distance_to_base"]
)
H.nodes[n2]["straight_distance_to_base"] = H.nodes[n2]["distance_to_base"] / node_dist(H, n2, base_node)
H.nodes[n2]["straight_distance_to_base"] = H.nodes[n2][
"distance_to_base"
] / node_dist(H, n2, base_node)

added_nodes += 1
return H
Expand Down
Empty file added tests/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from tests.fixtures import *

# Import the test data from the fixtures.py file
Loading

0 comments on commit b46fd64

Please sign in to comment.