diff --git a/src/xSTIR/cSTIR/cstir.cpp b/src/xSTIR/cSTIR/cstir.cpp index 7757d665e..d74be8de7 100644 --- a/src/xSTIR/cSTIR/cstir.cpp +++ b/src/xSTIR/cSTIR/cstir.cpp @@ -1198,26 +1198,11 @@ void* cSTIR_objectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset) { try { - ObjectiveFunction3DF& fun = objectFromHandle< ObjectiveFunction3DF>(ptr_f); - STIRImageData& id = objectFromHandle(ptr_i); - Image3DF& image = id.data(); - STIRImageData* ptr_id = new STIRImageData(image); - shared_ptr sptr(ptr_id); - Image3DF& grad = sptr->data(); - if (subset >= 0) - fun.compute_sub_gradient(grad, image, subset); - else { - int nsub = fun.get_num_subsets(); - grad.fill(0.0); - STIRImageData* ptr_id = new STIRImageData(image); - shared_ptr sptr_sub(ptr_id); - Image3DF& subgrad = sptr_sub->data(); - for (int sub = 0; sub < nsub; sub++) { - fun.compute_sub_gradient(subgrad, image, sub); - grad += subgrad; - } - } - return newObjectHandle(sptr); + auto& fun = objectFromHandle(ptr_f); + auto& id = objectFromHandle(ptr_i); + auto sptr_gd = std::make_shared(id); + fun.compute_gradient(id, subset, *sptr_gd); + return newObjectHandle(sptr_gd); } CATCH; } @@ -1227,23 +1212,10 @@ void* cSTIR_computeObjectiveFunctionGradient(void* ptr_f, void* ptr_i, int subset, void* ptr_g) { try { - ObjectiveFunction3DF& fun = objectFromHandle< ObjectiveFunction3DF>(ptr_f); + xSTIR_ObjFun3DF& fun = objectFromHandle(ptr_f); STIRImageData& id = objectFromHandle(ptr_i); STIRImageData& gd = objectFromHandle(ptr_g); - Image3DF& image = id.data(); - Image3DF& grad = gd.data(); - if (subset >= 0) - fun.compute_sub_gradient(grad, image, subset); - else { - int nsub = fun.get_num_subsets(); - grad.fill(0.0); - shared_ptr sptr_sub(new STIRImageData(image)); - Image3DF& subgrad = sptr_sub->data(); - for (int sub = 0; sub < nsub; sub++) { - fun.compute_sub_gradient(subgrad, image, sub); - grad += subgrad; - } - } + fun.compute_gradient(id, subset, gd); return (void*) new DataHandle; } CATCH; diff --git a/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h b/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h index 649c5fd51..742773015 100644 --- a/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h +++ b/src/xSTIR/cSTIR/include/sirf/STIR/stir_x.h @@ -1105,11 +1105,33 @@ The actual algorithm is described in } }; - class xSTIR_GeneralisedObjectiveFunction3DF : - public stir::GeneralisedObjectiveFunction < Image3DF > { + class xSTIR_GeneralisedObjectiveFunction3DF : public ObjectiveFunction3DF { public: + //! computes the gradientof an objective function + /*! if the subset number is non-negative, computes the gradient of + this objective function for that subset, otherwise computes + the sum of gradients for all subsets + */ + void compute_gradient(const STIRImageData& id, int subset, STIRImageData& gd) + { + const Image3DF& image = id.data(); + Image3DF& grad = gd.data(); + if (subset >= 0) + compute_sub_gradient(grad, image, subset); + else { + int nsub = get_num_subsets(); + grad.fill(0.0); + shared_ptr sptr_sub(new STIRImageData(image)); + Image3DF& subgrad = sptr_sub->data(); + for (int sub = 0; sub < nsub; sub++) { + compute_sub_gradient(subgrad, image, sub); + grad += subgrad; + } + } + } + void multiply_with_Hessian(Image3DF& output, const Image3DF& curr_image_est, - const Image3DF& input, const int subset) const + const Image3DF& input, const int subset) const { output.fill(0.0); if (subset >= 0) @@ -1120,13 +1142,9 @@ The actual algorithm is described in } } } - -// bool post_process() { -// return post_processing(); -// } }; - //typedef xSTIR_GeneralisedObjectiveFunction3DF ObjectiveFunction3DF; + typedef xSTIR_GeneralisedObjectiveFunction3DF xSTIR_ObjFun3DF; class xSTIR_PoissonLogLikelihoodWithLinearModelForMeanAndProjData3DF : public stir::PoissonLogLikelihoodWithLinearModelForMeanAndProjData < Image3DF > {