From 2293aaa881575e367e6f7cdf8b15424231beb05c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Albuquerque?= Date: Fri, 8 Mar 2024 11:21:08 -0300 Subject: [PATCH] feat(updating-auto-fs-creation): add docstring and auto-infer by df --- butterfree/automated/feature_set_creation.py | 34 ++++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/butterfree/automated/feature_set_creation.py b/butterfree/automated/feature_set_creation.py index a56ad30d..bc8580ba 100644 --- a/butterfree/automated/feature_set_creation.py +++ b/butterfree/automated/feature_set_creation.py @@ -111,6 +111,15 @@ def _get_tables_with_regex(self, sql_query: str) -> Tuple[List[Table], str]: return tables, modified_sql_query def get_readers(self, sql_query: str) -> str: + """ + Extracts table readers from a SQL query and formats them as a string. + + Args: + sql_query (str): The SQL query from which to extract table readers. + + Returns: + str: A formatted string containing the table readers. + """ tables, modified_sql_query = self._get_tables_with_regex(sql_query.lower()) readers = [] for table in tables: @@ -122,6 +131,7 @@ def get_readers(self, sql_query: str) -> str: ), """ readers.append(table_reader_string) + final_string = """ source=Source( readers=[ @@ -139,7 +149,23 @@ def get_readers(self, sql_query: str) -> str: return final_string - def get_features(self, sql_query: str, df: Optional[DataFrame]) -> str: + def get_features(self, sql_query: str, df: Optional[DataFrame] = None) -> str: + """ + Extract features from a SQL query and return them formatted as a string. + + Args: + sql_query (str): The SQL query used to extract features. + df (Optional[DataFrame], optional): Optional DataFrame used to infer data types. Defaults to None. + + Returns: + str: A formatted string containing the extracted features. + + This should be used on Databricks. + + Especially if you want automatic type inference without passing a reference dataframe. 
+ The utility will only work in an environment where a Spark session is available + """ + features = self._get_features_with_regex(sql_query) features_formatted = [] for feature in features: @@ -147,8 +173,10 @@ def get_features(self, sql_query: str, df: Optional[DataFrame]) -> str: data_type = "." - if df and isinstance(df, DataFrame): - data_type = self._get_data_type(feature, df) + if df is None: + df = spark.sql(sql_query) + + data_type = self._get_data_type(feature, df) feature_string = f""" Feature(