Skip to content

Commit

Permalink
move metric tests
Browse files Browse the repository at this point in the history
Signed-off-by: Yotam-Perlitz <[email protected]>
  • Loading branch information
perlitz committed Jan 21, 2025
1 parent be557ff commit 9ae0061
Showing 1 changed file with 180 additions and 180 deletions.
360 changes: 180 additions & 180 deletions tests/library/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,186 @@ def test_perplexity_with_prefix(self):
instance_outputs=[outputs[0]["score"]["instance"]],
)

def test_execution_accuracy_correct_query_mock_db(self):
    """Expect a score of 1.0 when the prediction equals the reference query.

    The two queries differ only by the reference's trailing semicolon.
    """
    # In-memory "employees" table shared by the prediction and the reference.
    employees_table = {
        "columns": ["id", "name", "department", "salary"],
        "rows": [
            (1, "Alice", "Sales", 50000),
            (2, "Bob", "Engineering", 60000),
            (3, "Charlie", "Sales", 55000),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "mock_db",
            "db_type": "in_memory",
            "data": {"employees": employees_table},
        }
    }
    prediction = "SELECT name FROM employees WHERE department = 'Sales'"
    gold_queries = ["SELECT name FROM employees WHERE department = 'Sales';"]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(1.0, result["score"])

def test_execution_accuracy_different_db_schema(self):
    """Expect a score of 1.0 for a matching query against a "products" table."""
    products_table = {
        "columns": [
            "product_id",
            "product_name",
            "category",
            "price",
        ],
        "rows": [
            (1, "Laptop", "Electronics", 1200),
            (2, "Mouse", "Electronics", 25),
            (3, "Shirt", "Clothing", 50),
            (4, "Monitor", "Electronics", 300),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "products_db",
            "db_type": "in_memory",
            "data": {"products": products_table},
        }
    }
    prediction = (
        "SELECT product_name, price FROM products WHERE category = 'Electronics'"
    )
    gold_queries = [
        "SELECT product_name, price FROM products WHERE category = 'Electronics';"
    ]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(1.0, result["score"])

def test_execution_accuracy_multiple_tables(self):
    """Expect a score of 1.0 for a two-table join query.

    The prediction spells the join as ``JOIN`` while the reference uses
    ``INNER JOIN``; the test asserts they are scored as a match.
    """
    customers_table = {
        "columns": ["customer_id", "name", "city"],
        "rows": [
            (1, "John Doe", "New York"),
            (2, "Jane Smith", "Los Angeles"),
            (3, "David Lee", "Chicago"),
        ],
    }
    orders_table = {
        "columns": ["order_id", "customer_id", "status"],
        "rows": [
            (101, 1, "Shipped"),
            (102, 2, "Pending"),
            (103, 1, "Shipped"),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "sales_db",
            "db_type": "in_memory",
            "data": {
                "customers": customers_table,
                "orders": orders_table,
            },
        }
    }
    prediction = "SELECT o.order_id, c.name FROM orders AS o JOIN customers AS c ON o.customer_id = c.customer_id WHERE o.status = 'Shipped'"
    gold_queries = [
        "SELECT o.order_id, c.name FROM orders AS o INNER JOIN customers AS c ON o.customer_id = c.customer_id WHERE o.status = 'Shipped';"
    ]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(1.0, result["score"])

def test_execution_accuracy_empty_result(self):
    """Expect a score of 1.0 when both queries filter on a department with no rows.

    None of the fixture rows has department 'HR'.
    """
    employees_table = {
        "columns": ["id", "name", "department", "salary"],
        "rows": [
            (1, "Alice", "Sales", 50000),
            (2, "Bob", "Engineering", 60000),
            (3, "Charlie", "Sales", 55000),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "mock_db",
            "db_type": "in_memory",
            "data": {"employees": employees_table},
        }
    }
    prediction = "SELECT name FROM employees WHERE department = 'HR'"
    gold_queries = ["SELECT name FROM employees WHERE department = 'HR';"]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(1.0, result["score"])

def test_execution_accuracy_aggregation_query(self):
    """Expect a score of 1.0 for matching ``AVG(salary)`` aggregation queries."""
    employees_table = {
        "columns": ["id", "name", "department", "salary"],
        "rows": [
            (1, "Alice", "Sales", 50000),
            (2, "Bob", "Engineering", 60000),
            (3, "Charlie", "Sales", 55000),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "mock_db",
            "db_type": "in_memory",
            "data": {"employees": employees_table},
        }
    }
    prediction = "SELECT AVG(salary) FROM employees"
    gold_queries = ["SELECT AVG(salary) FROM employees;"]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(1.0, result["score"])

def test_execution_accuracy_incorrect_query(self):
    """Expect a score of 0.0 when the prediction references a misspelled column."""
    employees_table = {
        "columns": ["id", "name", "department", "salary"],
        "rows": [
            (1, "Alice", "Sales", 50000),
            (2, "Bob", "Engineering", 60000),
            (3, "Charlie", "Sales", 55000),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "mock_db",
            "db_type": "in_memory",
            "data": {"employees": employees_table},
        }
    }
    # 'nme' is deliberately wrong: the table's column is 'name'.
    prediction = "SELECT nme FROM employees WHERE department = 'Sales'"
    gold_queries = ["SELECT name FROM employees WHERE department = 'Sales';"]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(0.0, result["score"])


class TestConfidenceIntervals(UnitxtTestCase):
def test_confidence_interval_off(self):
Expand Down Expand Up @@ -2114,183 +2294,3 @@ def test_metrics_ensemble(self):
instance_targets=instance_targets,
global_target=global_target,
)

def test_execution_accuracy_correct_query_mock_db(self):
    """Expect a score of 1.0 when the prediction equals the reference query.

    The two queries differ only by the reference's trailing semicolon.
    """
    # In-memory "employees" table shared by the prediction and the reference.
    employees_table = {
        "columns": ["id", "name", "department", "salary"],
        "rows": [
            (1, "Alice", "Sales", 50000),
            (2, "Bob", "Engineering", 60000),
            (3, "Charlie", "Sales", 55000),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "mock_db",
            "db_type": "in_memory",
            "data": {"employees": employees_table},
        }
    }
    prediction = "SELECT name FROM employees WHERE department = 'Sales'"
    gold_queries = ["SELECT name FROM employees WHERE department = 'Sales';"]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(1.0, result["score"])

def test_execution_accuracy_different_db_schema(self):
    """Expect a score of 1.0 for a matching query against a "products" table."""
    products_table = {
        "columns": [
            "product_id",
            "product_name",
            "category",
            "price",
        ],
        "rows": [
            (1, "Laptop", "Electronics", 1200),
            (2, "Mouse", "Electronics", 25),
            (3, "Shirt", "Clothing", 50),
            (4, "Monitor", "Electronics", 300),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "products_db",
            "db_type": "in_memory",
            "data": {"products": products_table},
        }
    }
    prediction = (
        "SELECT product_name, price FROM products WHERE category = 'Electronics'"
    )
    gold_queries = [
        "SELECT product_name, price FROM products WHERE category = 'Electronics';"
    ]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(1.0, result["score"])

def test_execution_accuracy_multiple_tables(self):
    """Expect a score of 1.0 for a two-table join query.

    The prediction spells the join as ``JOIN`` while the reference uses
    ``INNER JOIN``; the test asserts they are scored as a match.
    """
    customers_table = {
        "columns": ["customer_id", "name", "city"],
        "rows": [
            (1, "John Doe", "New York"),
            (2, "Jane Smith", "Los Angeles"),
            (3, "David Lee", "Chicago"),
        ],
    }
    orders_table = {
        "columns": ["order_id", "customer_id", "status"],
        "rows": [
            (101, 1, "Shipped"),
            (102, 2, "Pending"),
            (103, 1, "Shipped"),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "sales_db",
            "db_type": "in_memory",
            "data": {
                "customers": customers_table,
                "orders": orders_table,
            },
        }
    }
    prediction = "SELECT o.order_id, c.name FROM orders AS o JOIN customers AS c ON o.customer_id = c.customer_id WHERE o.status = 'Shipped'"
    gold_queries = [
        "SELECT o.order_id, c.name FROM orders AS o INNER JOIN customers AS c ON o.customer_id = c.customer_id WHERE o.status = 'Shipped';"
    ]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(1.0, result["score"])

def test_execution_accuracy_empty_result(self):
    """Expect a score of 1.0 when both queries filter on a department with no rows.

    None of the fixture rows has department 'HR'.
    """
    employees_table = {
        "columns": ["id", "name", "department", "salary"],
        "rows": [
            (1, "Alice", "Sales", 50000),
            (2, "Bob", "Engineering", 60000),
            (3, "Charlie", "Sales", 55000),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "mock_db",
            "db_type": "in_memory",
            "data": {"employees": employees_table},
        }
    }
    prediction = "SELECT name FROM employees WHERE department = 'HR'"
    gold_queries = ["SELECT name FROM employees WHERE department = 'HR';"]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(1.0, result["score"])

def test_execution_accuracy_aggregation_query(self):
    """Expect a score of 1.0 for matching ``AVG(salary)`` aggregation queries."""
    employees_table = {
        "columns": ["id", "name", "department", "salary"],
        "rows": [
            (1, "Alice", "Sales", 50000),
            (2, "Bob", "Engineering", 60000),
            (3, "Charlie", "Sales", 55000),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "mock_db",
            "db_type": "in_memory",
            "data": {"employees": employees_table},
        }
    }
    prediction = "SELECT AVG(salary) FROM employees"
    gold_queries = ["SELECT AVG(salary) FROM employees;"]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(1.0, result["score"])

def test_execution_accuracy_incorrect_query(self):
    """Expect a score of 0.0 when the prediction references a misspelled column."""
    employees_table = {
        "columns": ["id", "name", "department", "salary"],
        "rows": [
            (1, "Alice", "Sales", 50000),
            (2, "Bob", "Engineering", 60000),
            (3, "Charlie", "Sales", 55000),
        ],
    }
    instance_task_data = {
        "db": {
            "db_id": "mock_db",
            "db_type": "in_memory",
            "data": {"employees": employees_table},
        }
    }
    # 'nme' is deliberately wrong: the table's column is 'name'.
    prediction = "SELECT nme FROM employees WHERE department = 'Sales'"
    gold_queries = ["SELECT name FROM employees WHERE department = 'Sales';"]

    result = ExecutionAccuracy().compute(gold_queries, prediction, instance_task_data)

    self.assertEqual(0.0, result["score"])

0 comments on commit 9ae0061

Please sign in to comment.