take schema and catalog off the DatabricksConnection class
DeanSherwin committed Oct 10, 2024
1 parent 5c0dc4a commit 8b0be62
Showing 2 changed files with 44 additions and 90 deletions.
29 changes: 8 additions & 21 deletions raster_loader/cli/databricks.py
@@ -31,19 +31,14 @@ def databricks(args=None):
@click.option("--host", help="The Databricks host URL.", required=True)
@click.option("--token", help="The Databricks access token.", required=True)
@click.option("--cluster-id", help="The Databricks cluster ID.", required=True) # New option
@click.option(
"--file_path", help="The path to the raster file.", required=False, default=None
)
@click.option(
"--file_url", help="The URL to the raster file.", required=False, default=None
)
@click.option("--file_path", help="The path to the raster file.", required=False, default=None)
@click.option("--file_url", help="The URL to the raster file.", required=False, default=None)
@click.option("--catalog", help="The name of the catalog.", required=True)
@click.option("--schema", help="The name of the schema.", required=True)
@click.option("--table", help="The name of the table.", default=None)
@click.option(
"--band",
help="Band(s) within raster to upload. "
"Could repeat --band to specify multiple bands.",
help="Band(s) within raster to upload. " "Could repeat --band to specify multiple bands.",
default=[1],
multiple=True,
)
@@ -55,9 +50,7 @@ def databricks(args=None):
default=[None],
multiple=True,
)
@click.option(
"--chunk_size", help="The number of blocks to upload in each chunk.", default=10000
)
@click.option("--chunk_size", help="The number of blocks to upload in each chunk.", default=10000)
@click.option(
"--overwrite",
help="Overwrite existing data in the table if it already exists.",
@@ -121,16 +114,10 @@ def upload(

# Create default table name if not provided
if table is None:
table = get_default_table_name(
file_path if is_local_file else urlparse(file_url).path, band
)
table = get_default_table_name(file_path if is_local_file else urlparse(file_url).path, band)

connector = DatabricksConnection(
host=host,
token=token,
cluster_id=cluster_id, # Pass cluster_id to DatabricksConnection
catalog=catalog,
schema=schema,
host=host, token=token, cluster_id=cluster_id # Pass cluster_id to DatabricksConnection
)

source = file_path if is_local_file else file_url
@@ -156,9 +143,10 @@ def upload(

click.echo("Uploading Raster to Databricks")

fqn = f"`{catalog}`.{schema}.{table}"
connector.upload_raster(
source,
table,
fqn,
bands_info,
chunk_size,
overwrite=overwrite,
Expand All @@ -168,4 +156,3 @@ def upload(

click.echo("Raster file uploaded to Databricks")
return 0
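
For reference, a minimal sketch of the refactored upload path on the CLI side; the host, token, cluster ID, catalog, schema, and table values below are illustrative placeholders, not values from this commit:

    from raster_loader.io.databricks import DatabricksConnection

    # Illustrative placeholder values -- not taken from this commit.
    catalog, schema, table = "main", "raster_schema", "my_raster_table"

    connector = DatabricksConnection(
        host="https://example.cloud.databricks.com",
        token="dapi-xxxxxxxx",
        cluster_id="0123-456789-abcdefgh",
    )

    # Catalog and schema now live at the call site; the connector only
    # receives a fully qualified name.
    fqn = f"`{catalog}`.{schema}.{table}"
    connector.upload_raster("raster.tif", fqn, bands_info=[(1, None)], chunk_size=10000)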

105 changes: 36 additions & 69 deletions raster_loader/io/databricks.py
@@ -34,44 +34,41 @@
else:
_has_databricks = True


class DatabricksConnection(DataWarehouseConnection):
def __init__(self, host, token, cluster_id, catalog, schema):
def __init__(self, host, token, cluster_id):
if not _has_databricks:
import_error_databricks()

self.host = host
self.token = token
self.cluster_id = cluster_id
self.catalog = catalog
self.schema = schema

self.client = self.get_connection()

def get_connection(self):
# Initialize DatabricksSession
session = DatabricksSession.builder.remote(host=self.host, token=self.token, cluster_id=self.cluster_id).getOrCreate()
session = DatabricksSession.builder.remote(
host=self.host, token=self.token, cluster_id=self.cluster_id
).getOrCreate()
session.conf.set("spark.databricks.session.timeout", "6h")
return session

def get_table_fqn(self, table):
return f"`{self.catalog}`.{self.schema}.{table}"

def execute(self, sql):
# NOTE: if you get empty sql statement errors check runtime v databricks-connect version
# https://community.databricks.com/t5/data-engineering/parse-empty-statement-error-when-trying-to-use-spark-sql-via/td-p/80770
return self.client.sql(sql)

def execute_to_dataframe(self, sql):
df = self.execute(sql)
return df.toPandas()

def create_schema_if_not_exists(self):
self.execute(f"CREATE SCHEMA IF NOT EXISTS `{self.catalog}`.{self.schema}")
def create_schema_if_not_exists(self, fqn):
schema_name = fqn.split(".")[1] # Extract schema from FQN
self.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}")

def create_table_if_not_exists(self, table):
def create_table_if_not_exists(self, fqn):
self.execute(
f"""
CREATE TABLE IF NOT EXISTS `{self.catalog}`.{self.schema}.{table} (
CREATE TABLE IF NOT EXISTS {fqn} (
BLOCK BIGINT,
METADATA STRING,
{self.band_columns}
@@ -86,9 +83,8 @@ def write_metadata(
self,
metadata,
append_records,
table,
fqn,
):
# Create a DataFrame with the metadata
schema = StructType(
[
StructField("BLOCK", LongType(), True),
@@ -101,13 +97,11 @@ def upload(
metadata_df = self.client.createDataFrame(data, schema)

# Write to table
fqn = self.get_table_fqn(table)
metadata_df.write.format("delta").mode("append").saveAsTable(fqn)

return True

def get_metadata(self, table):
fqn = self.get_table_fqn(table)
def get_metadata(self, fqn):
query = f"""
SELECT METADATA
FROM {fqn}
@@ -118,52 +112,44 @@ def get_metadata(self, table):
return None
return json.loads(result.iloc[0]["METADATA"])

def check_if_table_exists(self, table):
def check_if_table_exists(self, fqn):
schema_name, table_name = fqn.split(".")[1:3] # Extract schema and table
sql = f"""
SELECT *
FROM `{self.catalog}`.INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = '{self.schema}'
AND TABLE_NAME = '{table}'
FROM {schema_name}.INFORMATION_SCHEMA.TABLES
WHERE TABLE_NAME = '{table_name}'
"""
df = self.execute(sql)
# If the count is greater than 0, the table exists
return df.count() > 0

def check_if_table_is_empty(self, table):
fqn = self.get_table_fqn(table)
def check_if_table_is_empty(self, fqn):
df = self.client.table(fqn)
return df.count() == 0

def upload_records(
self,
records: Iterable,
table: str,
fqn: str,
overwrite: bool,
):
fqn = self.get_table_fqn(table)
records_list = []
for record in records:
# Remove 'METADATA' from records, as it's handled separately
if "METADATA" in record:
del record["METADATA"]
records_list.append(record)

data_df = pd.DataFrame(records_list)
spark_df = self.client.createDataFrame(data_df)

if overwrite:
mode = "overwrite"
else:
mode = "append"

mode = "overwrite" if overwrite else "append"
spark_df.write.format("delta").mode(mode).saveAsTable(fqn)

return True

def upload_raster(
self,
file_path: str,
table: str,
fqn: str,
bands_info: List[Tuple[int, str]] = None,
chunk_size: int = None,
overwrite: bool = False,
@@ -173,34 +159,25 @@ def upload_raster(
print("Loading raster file to Databricks...")

bands_info = bands_info or [(1, None)]

append_records = False

try:
if (
self.check_if_table_exists(table)
and not self.check_if_table_is_empty(table)
and not overwrite
):
if self.check_if_table_exists(fqn) and not self.check_if_table_is_empty(fqn) and not overwrite:
append_records = append or ask_yes_no_question(
f"Table `{self.catalog}`.{self.schema}.{table} already exists "
"and is not empty. Append records? [yes/no] "
f"Table {fqn} already exists and is not empty. Append records? [yes/no] "
)

if not append_records:
exit()

# Prepare band columns
self.band_columns = ", ".join(
[
f"{self.band_rename_function(band_name or f'band_{band}')} BINARY"
for band, band_name in bands_info
]
[f"{self.band_rename_function(band_name or f'band_{band}')} BINARY" for band, band_name in bands_info]
)

# Create schema and table if not exists
self.create_schema_if_not_exists()
self.create_table_if_not_exists(table)
self.create_schema_if_not_exists(fqn)
self.create_table_if_not_exists(fqn)

metadata = rasterio_metadata(file_path, bands_info, self.band_rename_function)

@@ -213,7 +190,7 @@ def upload(
total_blocks = get_number_of_blocks(file_path)

if chunk_size is None:
ret = self.upload_records(records_gen, table, overwrite)
ret = self.upload_records(records_gen, fqn, overwrite)
if not ret:
raise IOError("Error uploading to Databricks.")
else:
@@ -225,21 +202,19 @@ def upload(
chunk_size = total_blocks
isFirstBatch = True
for records in batched(records_gen, chunk_size):
ret = self.upload_records(
records, table, overwrite and isFirstBatch
)
ret = self.upload_records(records, fqn, overwrite and isFirstBatch)
pbar.update(len(records))
if not ret:
raise IOError("Error uploading to Databricks.")
isFirstBatch = False

print("Writing metadata to Databricks...")
if append_records:
old_metadata = self.get_metadata(table)
old_metadata = self.get_metadata(fqn)
check_metadata_is_compatible(metadata, old_metadata)
update_metadata(metadata, old_metadata)

self.write_metadata(metadata, append_records, table)
self.write_metadata(metadata, append_records, fqn)

except IncompatibleRasterException as e:
raise IOError(f"Error uploading to Databricks: {e.message}")
@@ -250,48 +225,40 @@ def upload_raster(
)

if delete:
self.delete_table(table)
self.delete_table(fqn)

raise KeyboardInterrupt

except Exception as e:
delete = cleanup_on_failure or ask_yes_no_question(
(
"Error uploading to Databricks. "
"Would you like to delete the partially uploaded table? [yes/no] "
)
("Error uploading to Databricks. " "Would you like to delete the partially uploaded table? [yes/no] ")
)

if delete:
self.delete_table(table)
self.delete_table(fqn)

raise IOError(f"Error uploading to Databricks: {e}")

print("Done.")
return True

def delete_table(self, table):
fqn = self.get_table_fqn(table)
def delete_table(self, fqn):
self.execute(f"DROP TABLE IF EXISTS {fqn}")

def get_records(self, table: str, limit=10) -> pd.DataFrame:
fqn = self.get_table_fqn(table)
def get_records(self, fqn: str, limit=10) -> pd.DataFrame:
query = f"SELECT * FROM {fqn} LIMIT {limit}"
df = self.execute_to_dataframe(query)
return df

def insert_in_table(
self,
rows: List[dict],
table: str,
fqn: str,
) -> bool:
fqn = self.get_table_fqn(table)
data_df = pd.DataFrame(rows)
spark_df = self.client.createDataFrame(data_df)
spark_df.write.format("delta").mode("append").saveAsTable(fqn)
return True


def quote_name(self, name):
return f"`{name}`"
