From 8cb0a4f24fbacef67960baadb9e081d2e1988e7d Mon Sep 17 00:00:00 2001
From: Chao Pang
Date: Wed, 30 Oct 2024 22:48:57 -0400
Subject: [PATCH] added test_convert_code_to_omop_concept

---
 tests/unit_tests/test_ehrshot_to_omop.py | 59 +++++++++++++++++++++++-
 1 file changed, 58 insertions(+), 1 deletion(-)

diff --git a/tests/unit_tests/test_ehrshot_to_omop.py b/tests/unit_tests/test_ehrshot_to_omop.py
index 1f6217b..ef7a119 100644
--- a/tests/unit_tests/test_ehrshot_to_omop.py
+++ b/tests/unit_tests/test_ehrshot_to_omop.py
@@ -2,7 +2,9 @@
 from datetime import datetime
 from pyspark.sql import SparkSession
 from pyspark.sql.types import StructType, StructField, StringType, IntegerType, TimestampType
-from cehrbert_data.tools.ehrshot_to_omop import map_unit, map_answer, create_omop_person
+from cehrbert_data.tools.ehrshot_to_omop import (
+    map_unit, map_answer, create_omop_person, convert_code_to_omop_concept
+)
 
 
 # Define the test case
@@ -39,6 +41,61 @@ def setUp(self):
             StructField("start", TimestampType(), True)
         ])
 
+    def test_convert_code_to_omop_concept(self):
+        """Check convert_code_to_omop_concept against a three-row sample.
+
+        Two codes have a concept row; the third has none and should get concept_id 0.
+        """
+        # Input rows: one "<vocabulary>/<code>" string per patient
+        event_df = self.spark.createDataFrame(
+            [
+                (1, "ICD10/1234"),
+                (2, "SNOMED/5678"),
+                (3, "ICD10/0000"),  # No matching concept
+            ],
+            schema=StructType([
+                StructField("patient_id", IntegerType(), True),
+                StructField("code", StringType(), True),
+            ]),
+        )
+
+        # Concept rows supplied to the function for the lookup
+        concept_df = self.spark.createDataFrame(
+            [
+                ("ICD10", "1234", 1001),
+                ("SNOMED", "5678", 1002),
+            ],
+            schema=StructType([
+                StructField("vocabulary_id", StringType(), True),
+                StructField("concept_code", StringType(), True),
+                StructField("concept_id", IntegerType(), True),
+            ]),
+        )
+
+        # Function under test, applied to the "code" column
+        result_df = convert_code_to_omop_concept(event_df, concept_df, "code")
+
+        # Expected rows: the input columns followed by vocabulary_id, concept_code, concept_id
+        expected_df = self.spark.createDataFrame(
+            [
+                (1, "ICD10/1234", "ICD10", "1234", 1001),  # Match with concept_id 1001
+                (2, "SNOMED/5678", "SNOMED", "5678", 1002),  # Match with concept_id 1002
+                (3, "ICD10/0000", "ICD10", "0000", 0),  # No match, default concept_id 0
+            ],
+            schema=StructType([
+                StructField("patient_id", IntegerType(), True),
+                StructField("code", StringType(), True),
+                StructField("vocabulary_id", StringType(), True),
+                StructField("concept_code", StringType(), True),
+                StructField("concept_id", IntegerType(), True),
+            ]),
+        )
+
+        # Sort both sides by patient_id so rows are compared in the same order
+        result_rows = result_df.sort("patient_id").collect()
+        expected_rows = expected_df.sort("patient_id").collect()
+        self.assertEqual(result_rows, expected_rows)
+
     def test_create_omop_person(self):
         # Sample concept data for mapping demographic codes to concept_ids
         concept_data = [