From f5df937ca5a9426bd25190a84c1713c414181c77 Mon Sep 17 00:00:00 2001 From: Evgueni Ovtchinnikov Date: Wed, 7 Aug 2024 14:38:37 +0000 Subject: [PATCH 1/2] got rid of code overlaps in computing gradients of objective functions --- src/xSTIR/cSTIR/cstir.cpp | 42 +++++----------------- src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h | 29 ++++++++++----- 2 files changed, 29 insertions(+), 42 deletions(-) diff --git a/src/xSTIR/cSTIR/cstir.cpp b/src/xSTIR/cSTIR/cstir.cpp index 7757d665e..9824c8530 100644 --- a/src/xSTIR/cSTIR/cstir.cpp +++ b/src/xSTIR/cSTIR/cstir.cpp @@ -1198,26 +1198,13 @@ void* cSTIR_objectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset) { try { - ObjectiveFunction3DF& fun = objectFromHandle< ObjectiveFunction3DF>(ptr_f); + xSTIR_ObjFun3DF& fun = objectFromHandle(ptr_f); STIRImageData& id = objectFromHandle(ptr_i); - Image3DF& image = id.data(); - STIRImageData* ptr_id = new STIRImageData(image); - shared_ptr sptr(ptr_id); - Image3DF& grad = sptr->data(); - if (subset >= 0) - fun.compute_sub_gradient(grad, image, subset); - else { - int nsub = fun.get_num_subsets(); - grad.fill(0.0); - STIRImageData* ptr_id = new STIRImageData(image); - shared_ptr sptr_sub(ptr_id); - Image3DF& subgrad = sptr_sub->data(); - for (int sub = 0; sub < nsub; sub++) { - fun.compute_sub_gradient(subgrad, image, sub); - grad += subgrad; - } - } - return newObjectHandle(sptr); + STIRImageData* ptr_gd = new STIRImageData(id); + shared_ptr sptr_gd(ptr_gd); + STIRImageData& gd = *sptr_gd; + fun.compute_gradient(id, subset, gd); + return newObjectHandle(sptr_gd); } CATCH; } @@ -1227,23 +1214,10 @@ void* cSTIR_computeObjectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset, void* ptr_g) { try { - ObjectiveFunction3DF& fun = objectFromHandle< ObjectiveFunction3DF>(ptr_f); + xSTIR_ObjFun3DF& fun = objectFromHandle(ptr_f); STIRImageData& id = objectFromHandle(ptr_i); STIRImageData& gd = objectFromHandle(ptr_g); - Image3DF& image = id.data(); - Image3DF& grad = gd.data(); - if (subset >= 0) - fun.compute_sub_gradient(grad, image, subset); - else { - int nsub = fun.get_num_subsets(); - grad.fill(0.0); - shared_ptr sptr_sub(new STIRImageData(image)); - Image3DF& subgrad = sptr_sub->data(); - for (int sub = 0; sub < nsub; sub++) { - fun.compute_sub_gradient(subgrad, image, sub); - grad += subgrad; - } - } + fun.compute_gradient(id, subset, gd); return (void*) new DataHandle; } CATCH; diff --git a/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h b/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h index 649c5fd51..973ddef14 100644 --- a/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h +++ b/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h @@ -1105,11 +1105,28 @@ The actual algorithm is described in } }; - class xSTIR_GeneralisedObjectiveFunction3DF : - public stir::GeneralisedObjectiveFunction < Image3DF > { + class xSTIR_GeneralisedObjectiveFunction3DF : public ObjectiveFunction3DF { public: + void compute_gradient(const STIRImageData& id, int subset, STIRImageData& gd) + { + const Image3DF& image = id.data(); + Image3DF& grad = gd.data(); + if (subset >= 0) + compute_sub_gradient(grad, image, subset); + else { + int nsub = get_num_subsets(); + grad.fill(0.0); + shared_ptr sptr_sub(new STIRImageData(image)); + Image3DF& subgrad = sptr_sub->data(); + for (int sub = 0; sub < nsub; sub++) { + compute_sub_gradient(subgrad, image, sub); + grad += subgrad; + } + } + } + void multiply_with_Hessian(Image3DF& output, const Image3DF& curr_image_est, - const Image3DF& input, const int subset) const + const Image3DF& input, const int subset) const { output.fill(0.0); if (subset >= 0) @@ -1120,13 +1137,9 @@ The actual algorithm is described in } } } - -// bool post_process() { -// return post_processing(); -// } }; - //typedef xSTIR_GeneralisedObjectiveFunction3DF ObjectiveFunction3DF; + typedef xSTIR_GeneralisedObjectiveFunction3DF xSTIR_ObjFun3DF; class xSTIR_PoissonLogLikelihoodWithLinearModelForMeanAndProjData3DF : public stir::PoissonLogLikelihoodWithLinearModelForMeanAndProjData < Image3DF > { From 8068e2307d5b29d9fda59670da24c736d202bea2 Mon Sep 17 00:00:00 2001 From: Evgueni Ovtchinnikov Date: Tue, 13 Aug 2024 12:48:32 +0000 Subject: [PATCH 2/2] slimmed C interface for computing gradients of objective functions --- src/xSTIR/cSTIR/cstir.cpp | 10 ++++------ src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h | 5 +++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/xSTIR/cSTIR/cstir.cpp b/src/xSTIR/cSTIR/cstir.cpp index 9824c8530..d74be8de7 100644 --- a/src/xSTIR/cSTIR/cstir.cpp +++ b/src/xSTIR/cSTIR/cstir.cpp @@ -1198,12 +1198,10 @@ void* cSTIR_objectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset) { try { - xSTIR_ObjFun3DF& fun = objectFromHandle(ptr_f); - STIRImageData& id = objectFromHandle(ptr_i); - STIRImageData* ptr_gd = new STIRImageData(id); - shared_ptr sptr_gd(ptr_gd); - STIRImageData& gd = *sptr_gd; - fun.compute_gradient(id, subset, gd); + auto& fun = objectFromHandle(ptr_f); + auto& id = objectFromHandle(ptr_i); + auto sptr_gd = std::make_shared(id); + fun.compute_gradient(id, subset, *sptr_gd); return newObjectHandle(sptr_gd); } CATCH; diff --git a/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h b/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h index 973ddef14..742773015 100644 --- a/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h +++ b/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h @@ -1107,6 +1107,11 @@ The actual algorithm is described in class xSTIR_GeneralisedObjectiveFunction3DF : public ObjectiveFunction3DF { public: + //! computes the gradientof an objective function + /*! if the subset number is non-negative, computes the gradient of + this objective function for that subset, otherwise computes + the sum of gradients for all subsets + */ void compute_gradient(const STIRImageData& id, int subset, STIRImageData& gd) { const Image3DF& image = id.data();