From 4ce87253b63fd5d6b6af35f71eaffa09d0e1f3c7 Mon Sep 17 00:00:00 2001 From: Jose Esparza <28990958+pebeto@users.noreply.github.com> Date: Thu, 7 Mar 2024 13:34:44 -0500 Subject: [PATCH] Adding lock to control experiment creation during multiprocessing --- Project.toml | 4 +++- src/base.jl | 39 +++++++++++++++++++++++---------------- test/multiprocessing.jl | 38 ++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 +++- 4 files changed, 67 insertions(+), 18 deletions(-) create mode 100644 test/multiprocessing.jl diff --git a/Project.toml b/Project.toml index bf446bc..2ca179f 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.4.2" MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" +MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f" [compat] MLFlowClient = "0.5.1" @@ -18,8 +19,9 @@ julia = "1.6" MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83" MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f" StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface", "StatisticalMeasures"] +test = ["MLFlowClient", "MLJDecisionTreeInterface", "MLJModels", "MLJTuning", "StatisticalMeasures", "Test"] diff --git a/src/base.jl b/src/base.jl index 82413a0..06d0c55 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,22 +1,29 @@ +LOG_EVALUATION_LOCK = ReentrantLock() + function log_evaluation(logger::Logger, performance_evaluation) - experiment = getorcreateexperiment(logger.service, logger.experiment_name; - artifact_location=logger.artifact_location) - run = createrun(logger.service, experiment; - tags=[ - Dict( - "key" => "resampling", - "value" => string(performance_evaluation.resampling) - ), - Dict("key" => "repeats", "value" => string(performance_evaluation.repeats)), - Dict("key" => "model type", "value" => name(performance_evaluation.model)), - ] - ) + lock(LOG_EVALUATION_LOCK) + try + experiment = getorcreateexperiment(logger.service, logger.experiment_name; + artifact_location=logger.artifact_location) + run = createrun(logger.service, experiment; + tags=[ + Dict( + "key" => "resampling", + "value" => string(performance_evaluation.resampling) + ), + Dict("key" => "repeats", "value" => string(performance_evaluation.repeats)), + Dict("key" => "model type", "value" => name(performance_evaluation.model)), + ] + ) - logmodelparams(logger.service, run, performance_evaluation.model) - logmachinemeasures(logger.service, run, performance_evaluation.measure, - performance_evaluation.measurement) + logmodelparams(logger.service, run, performance_evaluation.model) + logmachinemeasures(logger.service, run, performance_evaluation.measure, + performance_evaluation.measurement) - updaterun(logger.service, run, "FINISHED") + updaterun(logger.service, run, "FINISHED") + finally + unlock(LOG_EVALUATION_LOCK) + end end function save(logger::Logger, machine:: Machine) diff --git a/test/multiprocessing.jl b/test/multiprocessing.jl new file mode 100644 index 0000000..9ffbefe --- /dev/null +++ b/test/multiprocessing.jl @@ -0,0 +1,38 @@ +@testset verbose = true "multiprocessing" begin + logger = MLJFlow.Logger(ENV["MLFLOW_URI"]; + experiment_name="MLJFlow multiprocessing tests", + artifact_location="/tmp/mlj-test") + + X, y = make_moons(100) + DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree + + model = DecisionTreeClassifier() + r = range(model, :max_depth, lower=1, upper=6) + + function test_tuned_model(acceleration_method) + tuned_model = TunedModel( + model=model, + range=r, + logger=logger, + acceleration=acceleration_method, + n=100, + ) + tuned_model_mach = machine(tuned_model, X, y) + fit!(tuned_model_mach) + + experiment = getorcreateexperiment(logger.service, logger.experiment_name) + runs = searchruns(logger.service, experiment) + + @assert length(runs) == 100 + + deleteexperiment(logger.service, experiment) + end + + @testset "log_evaluation_with_cpu_threads" begin + test_tuned_model(CPUThreads()) + end + + @testset "log_evaluation_with_cpu_processes" begin + test_tuned_model(CPUProcesses()) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index fd090ce..5122038 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,11 @@ using Test +using .Threads using MLJFlow using MLJBase using MLJModels +using MLJTuning using MLFlowClient using MLJModelInterface using StatisticalMeasures @@ -21,4 +23,4 @@ end include("base.jl") include("types.jl") include("service.jl") - +include("multiprocessing.jl")