
Commit

run pre commit

blublinsky committed Aug 14, 2024
1 parent d33f3dc commit 625132c
Showing 21 changed files with 163 additions and 147 deletions.
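The hunks below are the mechanical reformattings a pre-commit run typically produces: long call signatures re-wrapped black-style (one argument per line with a trailing comma, or collapsed onto a single line when they fit), imports regrouped and sorted isort-style, operators given surrounding spaces, and trailing whitespace and missing end-of-file newlines normalized. The exact hook set is not part of this commit, so the following is an illustrative sketch only; the values are echoed from this diff, the module itself is not in the repository.

from data_processing.utils import ParamsUtils, get_logger

logger = get_logger(__name__)

# black-style layout: once a call or literal no longer fits on one line, each
# element goes on its own line and the last one gets a trailing comma.
code_location = {
    "github": "github",
    "commit_hash": "12345",
    "path": "path",
}
logger.info(f"code location (ast form): {ParamsUtils.convert_to_ast(code_location)}")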
@@ -22,7 +22,9 @@ class TransformConfiguration(CLIArgumentProvider):
This is a base transform configuration class defining a transform's input/output parameters
"""

def __init__(self, name: str, transform_class: type[AbstractBinaryTransform], remove_from_metadata: list[str] = []):
def __init__(
self, name: str, transform_class: type[AbstractBinaryTransform], remove_from_metadata: list[str] = []
):
"""
Initialization
:param name: transformer name
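A minimal sketch of how this constructor is typically used, assuming a hypothetical transform named MyTransform; neither class below is part of this commit, and a real transform would override the processing methods rather than pass.

from data_processing.transform import AbstractBinaryTransform, TransformConfiguration


class MyTransform(AbstractBinaryTransform):
    # Placeholder body; a real transform overrides the binary/table processing methods.
    pass


class MyTransformConfiguration(TransformConfiguration):
    def __init__(self):
        super().__init__(
            name="my_transform",
            transform_class=MyTransform,
            remove_from_metadata=["my_transform_credential"],  # keys to drop from saved metadata
        )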
@@ -53,9 +53,7 @@ def add_input_params(self, parser: argparse.ArgumentParser) -> None:
typically determined based on the cluster configuration or the available resources
(number of workers).
"""
parser.add_argument(f"--{cli_prefix}parallelization", type=int,
default=-1,
help="parallelization.")
parser.add_argument(f"--{cli_prefix}parallelization", type=int, default=-1, help="parallelization.")
return TransformExecutionConfiguration.add_input_params(self, parser=parser)

def apply_input_params(self, args: argparse.Namespace) -> bool:
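For reference, a hedged sketch of how a prefixed flag like the one above parses; the runtime_ prefix is an assumption inferred from the runtime_* keys in the test configuration later in this diff, not something this hunk defines.

import argparse

cli_prefix = "runtime_"  # assumed prefix value
parser = argparse.ArgumentParser()
parser.add_argument(f"--{cli_prefix}parallelization", type=int, default=-1, help="parallelization.")

args = parser.parse_args(["--runtime_parallelization", "4"])
# The -1 default leaves parallelization to be derived from the available workers,
# per the docstring above.
print(args.runtime_parallelization)  # 4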
@@ -12,8 +12,8 @@

from typing import Any

from data_processing.runtime import AbstractTransformFileProcessor
from data_processing.data_access import DataAccessFactoryBase
from data_processing.runtime import AbstractTransformFileProcessor
from data_processing_spark.transform import SparkTransformRuntimeConfiguration


@@ -23,15 +23,17 @@ class SparkTransformFileProcessor(AbstractTransformFileProcessor):
"""

def __init__(
self, data_access_factory: DataAccessFactoryBase,
runtime_configuration: SparkTransformRuntimeConfiguration,
statistics: dict[str, Any]
self,
data_access_factory: DataAccessFactoryBase,
runtime_configuration: SparkTransformRuntimeConfiguration,
statistics: dict[str, Any],
):
"""
Init method
"""
super().__init__(data_access_factory=data_access_factory,
transform_parameters=runtime_configuration.get_transform_params())
super().__init__(
data_access_factory=data_access_factory, transform_parameters=runtime_configuration.get_transform_params()
)
# Add data access and statistics to the processor parameters
self.runtime_configuration = runtime_configuration
self.transform = None
@@ -46,8 +48,9 @@ def create_transform(self, partition: int):
:return: None
"""
# Create local processor
self.transform = (self.runtime_configuration.get_transform_class()
(self.transform_params | {"partition_index": partition}))
self.transform = self.runtime_configuration.get_transform_class()(
self.transform_params | {"partition_index": partition}
)

def _publish_stats(self, stats: dict[str, Any]) -> None:
"""
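Taken together, the constructor and create_transform calls above are driven once per partition. A hedged sketch of that wiring, assuming the library objects are created elsewhere; make_partition_processor is a hypothetical helper, not part of this commit.

from typing import Any

from data_processing.data_access import DataAccessFactoryBase
from data_processing_spark.runtime.spark import SparkTransformFileProcessor
from data_processing_spark.transform import SparkTransformRuntimeConfiguration


def make_partition_processor(
    data_access_factory: DataAccessFactoryBase,
    runtime_configuration: SparkTransformRuntimeConfiguration,
    partition: int,
) -> SparkTransformFileProcessor:
    # Per-partition statistics dictionary, shared with the processor.
    statistics: dict[str, Any] = {}
    processor = SparkTransformFileProcessor(
        data_access_factory=data_access_factory,
        runtime_configuration=runtime_configuration,
        statistics=statistics,
    )
    # create_transform merges {"partition_index": partition} into the transform parameters.
    processor.create_transform(partition=partition)
    return processor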
@@ -14,10 +14,13 @@
import time

from data_processing.data_access import DataAccessFactory, DataAccessFactoryBase
from data_processing_spark.runtime.spark import orchestrate, SparkTransformExecutionConfiguration
from data_processing_spark.transform import SparkTransformRuntimeConfiguration
from data_processing.runtime.transform_launcher import AbstractTransformLauncher
from data_processing.utils import get_logger
from data_processing_spark.runtime.spark import (
SparkTransformExecutionConfiguration,
orchestrate,
)
from data_processing_spark.transform import SparkTransformRuntimeConfiguration


logger = get_logger(__name__)
@@ -10,15 +10,18 @@
# limitations under the License.
################################################################################

import time
import traceback
from datetime import datetime
import time

from data_processing.data_access import DataAccessFactoryBase
from data_processing.runtime import TransformRuntimeConfiguration
from data_processing.utils import get_logger
from pyspark import SparkContext, SparkConf
from data_processing_spark.runtime.spark import SparkTransformFileProcessor, SparkTransformExecutionConfiguration
from data_processing_spark.runtime.spark import (
SparkTransformExecutionConfiguration,
SparkTransformFileProcessor,
)
from pyspark import SparkConf, SparkContext


logger = get_logger(__name__)
@@ -45,8 +48,7 @@ def orchestrate(
logger.error("No DataAccess instance provided - exiting")
return 1
# initialize Spark
conf = (SparkConf().setAppName(runtime_config.get_name())
.set('spark.driver.host', '127.0.0.1'))
conf = SparkConf().setAppName(runtime_config.get_name()).set("spark.driver.host", "127.0.0.1")
sc = SparkContext(conf=conf)
transform_config = sc.broadcast(runtime_config)
daf = sc.broadcast(data_access_factory)
@@ -60,9 +62,9 @@ def process_partition(iterator):
# local statistics dictionary
statistics = {}
# create file processor
file_processor = SparkTransformFileProcessor(data_access_factory=daf.value,
runtime_configuration=transform_config.value,
statistics=statistics)
file_processor = SparkTransformFileProcessor(
data_access_factory=daf.value, runtime_configuration=transform_config.value, statistics=statistics
)
first = True
for f in iterator:
# for every file
@@ -98,7 +100,7 @@ def process_partition(iterator):
logger.info(f"Parallelizing execution. Using {num_partitions} partitions")
stats_rdd = source_rdd.zipWithIndex().mapPartitions(process_partition)
# build overall statistics
stats = dict(stats_rdd.reduceByKey(lambda a, b: a+b).collect())
stats = dict(stats_rdd.reduceByKey(lambda a, b: a + b).collect())
return_code = 0
status = "success"
except Exception as e:
@@ -113,15 +115,18 @@ def process_partition(iterator):
input_params = runtime_config.get_transform_metadata() | execution_config.get_input_params()
metadata = {
"pipeline": execution_config.pipeline_id,
"job details": execution_config.job_details |
{
"start_time": start_ts,
"end_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"status": status,
},
"job details": execution_config.job_details
| {
"start_time": start_ts,
"end_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"status": status,
},
"code": execution_config.code_location,
"job_input_params": input_params | data_access_factory.get_input_params(),
"execution_stats": {"num partitions": num_partitions, "execution time, min": (time.time() - start_time) / 60},
"execution_stats": {
"num partitions": num_partitions,
"execution time, min": (time.time() - start_time) / 60,
},
"job_output_stats": stats,
}
logger.debug(f"Saving job metadata: {metadata}.")
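The orchestration above follows a standard Spark pattern: zipWithIndex tags each file with its position, mapPartitions processes a whole partition with one file processor, and reduceByKey folds per-partition statistics into a single dictionary. A stripped-down, runnable sketch of that aggregation shape; the toy per-file "processing" is an assumption, only the RDD plumbing mirrors the diff.

from pyspark import SparkConf, SparkContext

conf = SparkConf().setAppName("stats-sketch").set("spark.driver.host", "127.0.0.1")
sc = SparkContext(conf=conf)


def process_partition(iterator):
    # Local statistics dictionary, emitted as (key, value) pairs for reduceByKey.
    statistics = {}
    for path, index in iterator:  # items come from zipWithIndex()
        statistics["source_files"] = statistics.get("source_files", 0) + 1
    return list(statistics.items())


source_rdd = sc.parallelize([f"file_{i}.parquet" for i in range(10)], numSlices=4)
stats_rdd = source_rdd.zipWithIndex().mapPartitions(process_partition)
stats = dict(stats_rdd.reduceByKey(lambda a, b: a + b).collect())
print(stats)  # {'source_files': 10}
sc.stop()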
@@ -10,4 +10,4 @@
# limitations under the License.
################################################################################

from data_processing_spark.test_support.transform.noop_transform import NOOPSparkTransformConfiguration
from data_processing_spark.test_support.transform.noop_transform import NOOPSparkTransformConfiguration
@@ -9,11 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from data_processing.test_support.transform.noop_transform import NOOPTransformConfiguration
from data_processing.test_support.transform.noop_transform import (
NOOPTransformConfiguration,
)
from data_processing.utils import get_logger
from data_processing_spark.runtime.spark import SparkTransformLauncher
from data_processing_spark.transform import SparkTransformRuntimeConfiguration


logger = get_logger(__name__)


@@ -146,7 +146,7 @@ def test_s3_config_validate():
"data_s3_cred": ParamsUtils.convert_to_ast(s3_cred),
"runtime_pipeline_id": "pipeline_id",
"runtime_job_id": "job_id",
"runtime_code_location": ParamsUtils.convert_to_ast(code_location),
"runtime_code_location": ParamsUtils.convert_to_ast(code_location),
}
# invalid local configurations, driver launch should fail with any of these
s3_conf_empty = {}
@@ -12,7 +12,9 @@

import os

from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest
from data_processing.test_support.launch.transform_test import (
AbstractTransformLauncherTest,
)
from data_processing_spark.runtime.spark import SparkTransformLauncher
from data_processing_spark.test_support.transform import NOOPSparkTransformConfiguration

@@ -22,10 +22,7 @@

import pyarrow as pa
from data_processing.data_access import DataAccess, DataAccessFactory
from data_processing.transform import (
AbstractBinaryTransform,
TransformConfiguration,
)
from data_processing.transform import AbstractBinaryTransform, TransformConfiguration
from data_processing.utils import CLIArgumentProvider, TransformUtils, str2bool


5 changes: 3 additions & 2 deletions transforms/universal/doc_id/spark/src/doc_id_local_spark.py
@@ -12,13 +12,14 @@
import os
import sys

from data_processing_spark.runtime.spark import SparkTransformLauncher
from data_processing.utils import ParamsUtils
from data_processing_spark.runtime.spark import SparkTransformLauncher
from doc_id_transform_spark import (
DocIDSparkTransformConfiguration,
doc_column_name_cli_param,
hash_column_name_cli_param,
int_column_name_cli_param,
DocIDSparkTransformConfiguration)
)


# create parameters
@@ -16,8 +16,9 @@
import pyarrow as pa
from data_processing.transform import AbstractTableTransform, TransformConfiguration
from data_processing.utils import CLIArgumentProvider, TransformUtils
from data_processing_spark.transform import SparkTransformRuntimeConfiguration
from data_processing_spark.runtime.spark import SparkTransformLauncher
from data_processing_spark.transform import SparkTransformRuntimeConfiguration


short_name = "doc_id"
cli_prefix = f"{short_name}_"
@@ -147,7 +148,6 @@ def __init__(self):
super().__init__(transform_config=DocIDTransformConfiguration())



if __name__ == "__main__":
launcher = SparkTransformLauncher(DocIDSparkTransformConfiguration())
launcher.launch()
92 changes: 45 additions & 47 deletions transforms/universal/doc_id/spark/test-data/expected/metadata.json
@@ -1,48 +1,46 @@
{
"pipeline": "pipeline_id",
"job details": {
"job category": "preprocessing",
"job name": "doc_id",
"job type": "spark",
"job id": "job_id",
"start_time": "2024-08-03 22:04:58",
"end_time": "2024-08-03 22:05:15",
"status": "success"
},
"code": {
"github": "github",
"commit_hash": "12345",
"path": "path"
},
"job_input_params": {
"doc_column": "contents",
"hash_column": "hash_column",
"int_column": "int_id_column",
"checkpointing": false,
"max_files": -1,
"random_samples": -1,
"files_to_use": [
".parquet"
]
},
"execution_stats": {
"execution time, min": 0.29759878317515054
},
"job_output_stats": {
"source_size": 36132,
"result_size": 36668,
"result_doc_count": 5,
"source_files": 1,
"result_files": 1,
"processing_time": 0.08469605445861816,
"source_doc_count": 5
},
"source": {
"name": "/Users/borisl/Projects/data-prep-kit/transforms/universal/doc_id/spark/test-data/input",
"type": "path"
},
"target": {
"name": "/Users/borisl/Projects/data-prep-kit/transforms/universal/doc_id/spark/output",
"type": "path"
}
}
"pipeline": "pipeline_id",
"job details": {
"job category": "preprocessing",
"job name": "doc_id",
"job type": "spark",
"job id": "job_id",
"start_time": "2024-08-03 22:04:58",
"end_time": "2024-08-03 22:05:15",
"status": "success"
},
"code": {
"github": "github",
"commit_hash": "12345",
"path": "path"
},
"job_input_params": {
"doc_column": "contents",
"hash_column": "hash_column",
"int_column": "int_id_column",
"checkpointing": false,
"max_files": -1,
"random_samples": -1,
"files_to_use": [".parquet"]
},
"execution_stats": {
"execution time, min": 0.29759878317515054
},
"job_output_stats": {
"source_size": 36132,
"result_size": 36668,
"result_doc_count": 5,
"source_files": 1,
"result_files": 1,
"processing_time": 0.08469605445861816,
"source_doc_count": 5
},
"source": {
"name": "/Users/borisl/Projects/data-prep-kit/transforms/universal/doc_id/spark/test-data/input",
"type": "path"
},
"target": {
"name": "/Users/borisl/Projects/data-prep-kit/transforms/universal/doc_id/spark/output",
"type": "path"
}
}
6 changes: 4 additions & 2 deletions transforms/universal/doc_id/spark/test/test_doc_id_spark.py
@@ -12,7 +12,9 @@

import os

from data_processing.test_support.launch.transform_test import AbstractTransformLauncherTest
from data_processing.test_support.launch.transform_test import (
AbstractTransformLauncherTest,
)
from data_processing_spark.runtime.spark import SparkTransformLauncher
from doc_id_transform_spark import (
DocIDSparkTransformConfiguration,
@@ -40,4 +42,4 @@ def get_test_transform_fixtures(self) -> list[tuple]:
}

fixtures.append((launcher, transform_config, basedir + "/input", basedir + "/expected"))
return fixtures
return fixtures
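The doc_id test above and the NOOP test earlier in this diff share the same fixture shape. A hedged sketch of that pattern, with the test-data directory layout and the runtime parameter values assumed rather than taken from this commit:

import os

from data_processing.test_support.launch.transform_test import (
    AbstractTransformLauncherTest,
)
from data_processing_spark.runtime.spark import SparkTransformLauncher
from data_processing_spark.test_support.transform import NOOPSparkTransformConfiguration


class TestSparkNOOPTransform(AbstractTransformLauncherTest):
    def get_test_transform_fixtures(self) -> list[tuple]:
        basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../test-data"))  # assumed layout
        launcher = SparkTransformLauncher(NOOPSparkTransformConfiguration())
        transform_config = {"runtime_pipeline_id": "pipeline_id", "runtime_job_id": "job_id"}
        fixtures = []
        fixtures.append((launcher, transform_config, basedir + "/input", basedir + "/expected"))
        return fixtures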
@@ -13,8 +13,8 @@
import os
import sys

from data_processing_spark.runtime.spark import SparkTransformLauncher
from data_processing.utils import ParamsUtils
from data_processing_spark.runtime.spark import SparkTransformLauncher
from filter_transform import (
filter_columns_to_drop_cli_param,
filter_criteria_cli_param,
@@ -10,13 +10,15 @@
# limitations under the License.
################################################################################

from filter_transform import FilterTransformConfiguration
from data_processing.utils import get_logger
from data_processing_spark.runtime.spark import SparkTransformLauncher
from data_processing_spark.transform import SparkTransformRuntimeConfiguration
from filter_transform import FilterTransformConfiguration


logger = get_logger(__name__)


class FilterSparkTransformConfiguration(SparkTransformRuntimeConfiguration):
"""
Implements the SparkTransformRuntimeConfiguration for Filter as required by the SparkTransformLauncher.