added a test for drop_duplicate_visits
ChaoPang committed Oct 31, 2024
1 parent 6770b26 commit 07930d7
Showing 2 changed files with 92 additions and 15 deletions.
58 changes: 44 additions & 14 deletions src/cehrbert_data/tools/ehrshot_to_omop.py
@@ -487,6 +487,49 @@ def generate_visit_id(data: DataFrame, time_interval: int = 12) -> DataFrame:
return data.join(visit, on=["patient_id", "visit_order"]).drop("visit_order", "patient_event_order")


def drop_duplicate_visits(data: DataFrame) -> DataFrame:
"""
Removes duplicate visits based on visit priority, retaining a single record per `visit_id`.

This function identifies duplicate visits by `visit_id` and assigns a priority to each visit type.
Visits with the highest priority (lowest priority value) are retained, while others are dropped.
Priority is assigned based on the `code` column:

- "Visit/IP" and "Visit/ERIP" have the highest priority (1),
- "Visit/ER" has medium priority (2),
- All other visit types have the lowest priority (3).

The function returns a DataFrame with only the highest-priority visit per `visit_id`.

Parameters
----------
data : DataFrame
    A PySpark DataFrame containing the following columns:
    - `visit_id`: Unique identifier for each visit.
    - `code`: String code indicating the type of visit, which determines visit priority.

Returns
-------
DataFrame
    The input DataFrame with duplicates removed based on `visit_id` and `code` priority.
    Only the highest-priority visit is retained for each `visit_id`.
"""
data = data.withColumn(
"priority",
f.when(f.col("code").isin(["Visit/IP", "Visit/ERIP"]), 1).otherwise(
f.when(f.col("code") == "Visit/ER", 2).otherwise(3)
)
).withColumn(
"visit_rank",
f.row_number().over(Window.partitionBy("visit_id").orderBy(f.col("priority")))
).where(
f.col("visit_rank") == 1
).drop(
"visit_rank",
"priority"
)
return data
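As an aside, the nested `otherwise(f.when(...))` above can also be written with chained `.when` calls, and a secondary sort key makes the tie-breaking deterministic when two rows share a priority (with `orderBy` on priority alone, `row_number` resolves ties arbitrarily). A minimal sketch, assuming a local SparkSession; this is an equivalent alternative for illustration, not the committed code:

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as f

spark = SparkSession.builder.master("local[1]").appName("priority-sketch").getOrCreate()

df = spark.createDataFrame(
    [(1, "Visit/ER"), (1, "Visit/IP"), (2, "Visit/OP"), (2, "Visit/Home")],
    ["visit_id", "code"],
)

# Chained .when reads flatter than nested otherwise(when(...)) and is equivalent.
priority = (
    f.when(f.col("code").isin(["Visit/IP", "Visit/ERIP"]), 1)
    .when(f.col("code") == "Visit/ER", 2)
    .otherwise(3)
)

# The extra sort key ("code") is an assumption added for determinism; the committed
# version orders by priority alone, so equal-priority duplicates resolve arbitrarily.
window = Window.partitionBy("visit_id").orderBy(priority, f.col("code"))
deduped = df.withColumn("rn", f.row_number().over(window)).where(f.col("rn") == 1).drop("rn")
deduped.show()  # keeps (1, "Visit/IP") and (2, "Visit/Home")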


def main(args):
spark = SparkSession.builder.appName("Convert EHRShot Data").getOrCreate()

@@ -528,20 +571,7 @@ def main(args):

# There could be multiple records for the same visit_id
if domain_table_name == "visit_occurrence":
domain_table = domain_table.withColumn(
"priority",
f.when(f.col("code").isin(["Visit/IP", "Visit/ERIP"]), 1).otherwise(
f.when(f.col("code") == "Visit/ER", 2).otherwise(3)
)
).withColumn(
"visit_rank",
f.row_number().over(Window.partitionBy("visit_id").orderBy(f.col("priority")))
).where(
f.col("visit_rank") == 1
).drop(
"visit_rank",
"priority"
)
domain_table = drop_duplicate_visits(domain_table)

domain_table.drop(
*original_columns
49 changes: 48 additions & 1 deletion tests/unit_tests/test_ehrshot_to_omop.py
@@ -4,7 +4,13 @@
from pyspark.sql import functions as f
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType, FloatType
from cehrbert_data.tools.ehrshot_to_omop import (
map_unit, map_answer, create_omop_person, convert_code_to_omop_concept, extract_value, generate_visit_id
map_unit,
map_answer,
create_omop_person,
convert_code_to_omop_concept,
extract_value,
generate_visit_id,
drop_duplicate_visits
)


@@ -92,6 +98,47 @@ def test_generate_visit_id(self):
1, patient_3_visits, "Patient 3 should have one visit as events are within the time interval."
)

def test_drop_duplicate_visits(self):
# Define schema for input DataFrame
schema = StructType([
StructField("visit_id", IntegerType(), True),
StructField("code", StringType(), True)
])

# Sample data with duplicate visit IDs and varying priorities
data = [
(1, "Visit/IP"), # Highest priority for visit_id 1
(1, "Visit/ER"), # Lower priority for visit_id 1
(2, "Visit/OP"), # Lowest priority for visit_id 2
(2, "Visit/ER"), # Medium priority for visit_id 2
(3, "Visit/ERIP"), # Highest priority for visit_id 3
(3, "Visit/OP"), # Lower priority for visit_id 3
(4, "Visit/OP") # Highest priority for visit_id 4
]

# Create DataFrame
data = self.spark.createDataFrame(data, schema=schema)

# Run the function to drop duplicates
result_df = drop_duplicate_visits(data)

# Define expected data and schema
expected_data = [
(1, "Visit/IP"), # Only highest priority Visit/IP retained for visit_id 1
(2, "Visit/ER"), # Only medium priority Visit/ER retained for visit_id 2
(3, "Visit/ERIP"), # Only highest priority Visit/ERIP retained for visit_id 3
(4, "Visit/OP") # Only highest priority Visit/OP retained for visit_id 3
]

expected_df = self.spark.createDataFrame(expected_data, schema=data.schema)

# Collect results for comparison
actual_data = result_df.sort("visit_id").collect()
expected_data = expected_df.sort("visit_id").collect()

# Check that the actual data matches the expected data
self.assertEqual(actual_data, expected_data, "The DataFrames do not match the expected result.")
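One case the fixture above leaves untested is two rows with equal priority for the same visit_id; `row_number` ordered by priority alone breaks such ties arbitrarily. A hypothetical extension (reusing `schema` and `self.spark` from the test above) showing why an exact-row assertion would be flaky there:

# Hypothetical tie case, not part of the commit: both codes map to priority 3,
# so drop_duplicate_visits may keep either row depending on partition ordering.
tie_df = self.spark.createDataFrame([(5, "Visit/OP"), (5, "Visit/Home")], schema=schema)
tie_result = drop_duplicate_visits(tie_df).collect()
# Safe assertion: exactly one row survives; asserting which code survives would be flaky.
self.assertEqual(1, len(tie_result))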

def test_extract_value(self):
current_time = datetime.now()
# Create DataFrames
