Skip to content

Commit

Permalink
feat: add column update
Browse files Browse the repository at this point in the history
  • Loading branch information
Eli Yarson committed Apr 18, 2024
1 parent 94f7bc0 commit b7b0d29
Show file tree
Hide file tree
Showing 5 changed files with 753 additions and 4 deletions.
31 changes: 30 additions & 1 deletion snowflake_utils/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import typer
from typing_extensions import Annotated

from .models import FileFormat, InlineFileFormat, Table
from .models import FileFormat, InlineFileFormat, Table, Schema, Column
from .queries import connect
import logging

app = typer.Typer()

Expand All @@ -27,5 +29,32 @@ def copy(
)


@app.command()
def mass_single_column_update(
schema: Annotated[str, typer.Argument()],
target_column: Annotated[str, typer.Argument()],
new_column: Annotated[str, typer.Argument()],
data_type: Annotated[str, typer.Argument()],
) -> None:
db_schema = Schema(name=schema)
target_column = Column(name=target_column, data_type=data_type)
new_column = Column(name=new_column, data_type=data_type)
logging.getLogger().setLevel("DEBUG")
with connect() as conn, conn.cursor() as cursor:
tables = db_schema.get_tables(cursor=cursor)
for table in tables:
columns = table.get_columns(cursor=cursor)
column_names = [str.upper(column.name) for column in columns]
if (
str.upper(target_column.name) in column_names
and str.upper(new_column.name) in column_names
):
table.single_column_update(
cursor=cursor, target_column=target_column, new_column=new_column
)
else:
logging.debug("One or both of the columns don't exist in the table")


if __name__ == "__main__":
app()
33 changes: 33 additions & 0 deletions snowflake_utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,28 @@ class Column(BaseModel):
data_type: str


class Schema(BaseModel):
name: str
database: str | None = None

@property
def fully_qualified_name(self):
if self.database:
return f"{self.database}.{self.name}"
else:
return self.name

def get_tables(self, cursor: SnowflakeCursor):
cursor.execute(f"show tables in schema {self.fully_qualified_name};")
data = cursor.execute(
'select "name", "database_name", "schema_name" FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()));'
).fetchall()
return [
Table(name=name, schema_=schema, database=database)
for (name, database, schema, *_) in data
]


class Table(BaseModel):
name: str
schema_: str
Expand Down Expand Up @@ -300,6 +322,17 @@ def drop(self, cursor: SnowflakeCursor) -> None:
logging.debug(f"Dropping table:{self.schema_}.{self.name}")
cursor.execute(f"drop table {self.schema_}.{self.name}")

def single_column_update(
self, cursor: SnowflakeCursor, target_column: Column, new_column: Column
):
"""Updates the value of one column with the value of another column in the same table."""
logging.debug(
f"Swapping the value of {target_column.name} with {new_column.name} in the table {self.name}"
)
cursor.execute(
f"UPDATE {self.schema_}.{self.name} SET {target_column.name} = {new_column.name};"
)


def _possibly_cast(s: str, old_column_type: str, new_column_type: str) -> str:
if old_column_type == "VARIANT" and new_column_type != "VARIANT":
Expand Down
Loading

0 comments on commit b7b0d29

Please sign in to comment.