Skip to content

Commit

Permalink
added the clean_up_unit function for cleaning up the unit string
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Oct 3, 2024
1 parent 4b2c630 commit 84e1075
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/cehrbert_data/utils/spark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def join_domain_tables(domain_tables: List[DataFrame]) -> DataFrame:
datetime_field,
table_domain_field
) in get_key_fields(domain_table):
domain_id_field = get_domain_id_field(domain_table)
# Remove records that don't have a date or standard_concept_id
filtered_domain_table = domain_table.where(F.col(date_field).isNotNull()).where(
F.col(concept_id_field).isNotNull()
Expand All @@ -165,7 +164,6 @@ def join_domain_tables(domain_tables: List[DataFrame]) -> DataFrame:
filtered_domain_table.where(F.col(concept_id_field).cast("string") != "0")
.withColumn("date", F.to_date(F.col(date_field)))
.withColumn("datetime", datetime_field_udf)
.withColumn("domain_id", F.col(domain_id_field).cast("string"))
)

unit_udf = F.col("unit") if domain_has_unit(filtered_domain_table) else F.lit(None).cast("string")
Expand All @@ -176,7 +174,7 @@ def join_domain_tables(domain_tables: List[DataFrame]) -> DataFrame:
filtered_domain_table["datetime"],
filtered_domain_table["visit_occurrence_id"],
F.lit(table_domain_field).alias("domain"),
F.concat(F.lit(table_domain_field), F.lit("-"), F.col("domain_id")).alias("event_group_id"),
F.lit(None).cast("string").alias("event_group_id"),
F.lit(0.0).alias("concept_value"),
unit_udf.alias("unit"),
).distinct()
Expand Down Expand Up @@ -1349,6 +1347,16 @@ def pandas_udf_to_att(time_intervals: pd.Series) -> pd.Series:
return visit_occurrence.join(person, "person_id")


def clean_up_unit(dataframe: DataFrame) -> DataFrame:
    """Normalize the ``unit`` column of a measurement DataFrame.

    Clean-up passes, applied in order:

    1. Strip UCUM-style annotations, i.e. any ``{...}`` segment
       (e.g. ``mg/dL{adult}`` -> ``mg/dL``). The pattern is non-greedy so
       multiple annotations in one string are each removed individually.
    2. Drop a leading ``/`` left on the unit (e.g. ``/min`` -> ``min``,
       or ``{adult}/mmHg`` -> ``mmHg`` after step 1).
    3. Trim surrounding whitespace that removing an annotation may leave
       behind (e.g. ``mg/dL {adult}`` -> ``mg/dL``); a no-op for units
       that had no padding, so existing behavior is preserved.

    :param dataframe: DataFrame with a string ``unit`` column.
    :return: the same DataFrame with its ``unit`` column cleaned.
    """
    return dataframe.withColumn(
        "unit",
        F.trim(
            F.regexp_replace(
                F.regexp_replace(F.col("unit"), r"\{.*?\}", ""),
                r"^/",
                "",
            )
        ),
    )


def process_measurement(
spark,
measurement: DataFrame,
Expand All @@ -1369,9 +1377,8 @@ def process_measurement(
# Get the standard units from the concept_name
measurement = measurement.join(
concept.select("concept_id", "concept_code"), measurement.unit_concept_id == concept.concept_id, "left"
).withColumn(
"unit", F.coalesce(F.col("concept_code"), F.lit(None).cast("string"))
).drop("concept_id", "concept_name")
measurement = clean_up_unit(measurement)

# Register the tables in spark context
measurement.createOrReplaceTempView(MEASUREMENT)
Expand All @@ -1388,7 +1395,7 @@ def process_measurement(
'measurement' AS domain,
m.unit,
m.value_as_number AS concept_value,
CONCAT('measurement-', CAST(m.measurement_id AS STRING)) AS event_group_id
CAST(NULL AS STRING) AS event_group_id
FROM measurement AS m
WHERE m.visit_occurrence_id IS NOT NULL
AND m.value_as_number IS NOT NULL
Expand All @@ -1412,7 +1419,7 @@ def process_measurement(
'categorical_measurement' AS domain,
CAST(NULL AS STRING) AS unit,
0.0 AS concept_value,
CONCAT('measurement-', CAST(m.measurement_id AS STRING)) AS event_group_id
CONCAT('mea-', CAST(m.measurement_id AS STRING)) AS event_group_id
FROM measurement AS m
WHERE EXISTS (
SELECT
Expand Down
Empty file added tests/unit_tests/__init__.py
Empty file.
53 changes: 53 additions & 0 deletions tests/unit_tests/test_spark_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import unittest
from pyspark.sql import SparkSession
from cehrbert_data.utils.spark_utils import clean_up_unit


# Define the test case
# Unit tests for the clean_up_unit helper.
class CleanUpUnitTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        # Spin up one local Spark session shared by every test in this class.
        cls.spark = SparkSession.builder.appName("UnitTest").getOrCreate()

    @classmethod
    def tearDownClass(cls):
        # Tear the shared Spark session down once all tests have finished.
        cls.spark.stop()

    def test_clean_up_unit(self):
        """clean_up_unit strips {...} annotations and a leading slash."""
        # Table of (raw unit, expected cleaned unit) pairs.
        cases = [
            ("/mg/dL{adult}", "mg/dL"),  # leading slash AND annotation
            ("kg/m2{child}", "kg/m2"),   # annotation only
            ("{adult}/mmHg", "mmHg"),    # annotation, then exposed slash
            ("g/L", "g/L"),              # nothing to clean
            ("/min", "min"),             # leading slash only
        ]

        raw_df = self.spark.createDataFrame(
            [(raw,) for raw, _ in cases], ["unit"]
        )
        expected_df = self.spark.createDataFrame(
            [(want,) for _, want in cases], ["unit"]
        )

        # Run the function under test and compare row-for-row.
        actual_rows = clean_up_unit(raw_df).collect()
        expected_rows = expected_df.collect()
        self.assertEqual(actual_rows, expected_rows)


# Entry point for running the tests
# Lets this module be executed directly, e.g. `python test_spark_utils.py`;
# unittest.main() discovers and runs the TestCase classes defined above.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 84e1075

Please sign in to comment.