Skip to content

Commit

Permalink
correct version of pearson_correlation_coef and is_strictly_increasin…
Browse files Browse the repository at this point in the history
…g + update make_ascending
  • Loading branch information
lhaibach committed Jan 31, 2025
1 parent 5faf79d commit 5b60643
Showing 1 changed file with 19 additions and 20 deletions.
39 changes: 19 additions & 20 deletions src/stratigraphy/sidebar/a_above_b_sidebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def strictly_contains(self, other: AAboveBSidebar) -> bool:
)

def is_strictly_increasing(self) -> bool:
return all(i.value < j.value for i, j in zip(self.entries, self.entries[1:], strict=False))
return all(self.entries[i].value < self.entries[i + 1].value for i in range(len(self.entries) - 1))

def depth_intervals(self) -> list[AAboveBInterval]:
"""Creates a list of depth intervals from the depth column entries.
Expand Down Expand Up @@ -95,11 +95,15 @@ def pearson_correlation_coef(self) -> float:
positions = np.array([entry.rect.y1 for entry in self.entries])
entries = np.array([entry.value for entry in self.entries])

# Avoid warnings in the np.corrcoef call, as the correlation coef is undefined if the standard deviation is 0.
if np.std(entries) == 0 or np.std(positions) == 0:
std_positions = np.std(positions)
std_entries = np.std(entries)
if std_positions == 0 or std_entries == 0:
return 0

return np.corrcoef(positions, entries)[0, 1].item()
# We calculate the Pearson correlation coefficient manually
# to avoid redundant standard deviation calculations that would occur with np.corrcoef.
covariance = np.mean((positions - np.mean(positions)) * (entries - np.mean(entries)))
return covariance / (std_positions * std_entries)

def remove_entry_by_correlation_gradient(self) -> AAboveBSidebar | None:
if len(self.entries) < 3:
Expand All @@ -113,33 +117,28 @@ def remove_entry_by_correlation_gradient(self) -> AAboveBSidebar | None:

def make_ascending(self):
median_value = np.median(np.array([entry.value for entry in self.entries]))

for i, entry in enumerate(self.entries):
new_values = []

if entry.value.is_integer() and entry.value > median_value:
new_values.extend([entry.value / 100, entry.value / 10])
for new_value in new_values:
if self._valid_value(i, new_value):
# Create a new entry instead of modifying the value of the current one, as this entry might be
# used in different sidebars as well.
self.entries[i] = DepthColumnEntry(rect=entry.rect, value=new_value)
break

# Correct common OCR mistakes where "4" is recognized instead of "1"
# We don't control for OCR mistakes recognizing "9" as "3" (example zurich/680244005-bp.pdf)
if "4" in str(entry.value) and not self._valid_value(i, entry.value):
# Correct common OCR mistakes where "4" is recognized instead of "1"
# Edge case: OCR also can also replace "3" with "9"
alternative_values = generate_alternatives(entry.value)
for alternative_value in alternative_values:
if self._valid_value(i, alternative_value):
self.entries[i] = DepthColumnEntry(rect=entry.rect, value=alternative_value)
break
new_values.extend(generate_alternatives(entry.value))

# Assign the first valid correction
for new_value in new_values:
if self._valid_value(i, new_value):
self.entries[i] = DepthColumnEntry(rect=entry.rect, value=new_value)
break
return self

def _valid_value(self, index: int, new_value: float) -> bool:
"""Check if new value at given index is maintaining ascending order."""
previous_ok = index == 0 or all(
other_entry.value < new_value for other_entry in self.entries[:index]
) ## too strict?
previous_ok = index == 0 or all(other_entry.value < new_value for other_entry in self.entries[:index])
next_ok = index + 1 == len(self.entries) or new_value < self.entries[index + 1].value
return previous_ok and next_ok

Expand Down

0 comments on commit 5b60643

Please sign in to comment.