From 15ec0c735cba2780d3da37f4e1e3b5d4a969b4ff Mon Sep 17 00:00:00 2001 From: uramirez8707 <49168881+uramirez8707@users.noreply.github.com> Date: Thu, 16 Jan 2025 16:24:07 -0500 Subject: [PATCH] fix: improve modern diag manager performance (#1634) --- diag_manager/fms_diag_object.F90 | 13 + diag_manager/fms_diag_reduction_methods.F90 | 21 ++ .../include/fms_diag_reduction_methods.inc | 234 ++++++++++++++---- .../include/fms_diag_reduction_methods_r4.fh | 11 +- .../include/fms_diag_reduction_methods_r8.fh | 11 +- 5 files changed, 239 insertions(+), 51 deletions(-) diff --git a/diag_manager/fms_diag_object.F90 b/diag_manager/fms_diag_object.F90 index dad97e13f3..9bb7ef3081 100644 --- a/diag_manager/fms_diag_object.F90 +++ b/diag_manager/fms_diag_object.F90 @@ -61,6 +61,8 @@ module fms_diag_object_mod type(fmsDiagField_type), allocatable :: FMS_diag_fields(:) !< Array of diag fields type(fmsDiagOutputBuffer_type), allocatable :: FMS_diag_output_buffers(:) !< array of output buffer objects !! one for each variable in the diag_table.yaml + logical, private :: data_was_send !< True if send_data has been successfully called for at least one variable + !< diag_send_complete does nothing if it is .false. integer, private :: registered_buffers = 0 !< number of registered buffers, per dimension class(fmsDiagAxisContainer_type), allocatable :: diag_axis(:) !< Array of diag_axis integer, private :: registered_variables !< Number of registered variables @@ -144,6 +146,7 @@ subroutine fms_diag_object_init (this,diag_subset_output, time_init) this%buffers_initialized =fms_diag_output_buffer_init(this%FMS_diag_output_buffers,SIZE(diag_yaml%get_diag_fields())) this%registered_variables = 0 this%registered_axis = 0 + this%data_was_send = .false. this%initialized = .true. #else call mpp_error("fms_diag_object_init",& @@ -657,6 +660,8 @@ subroutine fms_diag_accept_data (this, diag_field_id, field_data, mask, rmask, & main_if: if (buffer_the_data) then !> Only 1 thread allocates the output buffer and sets set_math_needs_to_be_done !$omp critical + !< Let diag_send_complete that there is new data to procress + if (.not. this%data_was_send) this%data_was_send = .true. !< These set_* calls need to be done inside an omp_critical to avoid any race conditions !! and allocation issues @@ -686,6 +691,9 @@ subroutine fms_diag_accept_data (this, diag_field_id, field_data, mask, rmask, & is, js, ks, ie, je, ke) else + !< Let diag_send_complete that there is new data to procress + if (.not. this%data_was_send) this%data_was_send = .true. + !< At this point if we are no longer in an openmp region or running with 1 thread !! so it is safe to have these set_* calls if(has_halos) call this%FMS_diag_fields(diag_field_id)%set_halo_present() @@ -783,8 +791,13 @@ subroutine fms_diag_send_complete(this, time_step) #ifndef use_yaml CALL MPP_ERROR(FATAL,"You can not use the modern diag manager without compiling with -Duse_yaml") #else + !< Go away if there is no new data + if (.not. this%data_was_send) return + call this%do_buffer_math() call this%fms_diag_do_io() + + this%data_was_send = .false. #endif end subroutine fms_diag_send_complete diff --git a/diag_manager/fms_diag_reduction_methods.F90 b/diag_manager/fms_diag_reduction_methods.F90 index 86fe98aedf..802d251377 100644 --- a/diag_manager/fms_diag_reduction_methods.F90 +++ b/diag_manager/fms_diag_reduction_methods.F90 @@ -71,6 +71,27 @@ module fms_diag_reduction_methods_mod module procedure sum_update_done_r4, sum_update_done_r8 end interface + !> @brief Updates the buffer for any reductions that involve summation + !! (ie. time_sum, avg, rms, pow) + !! In this case the mask is present + interface sum_mask + module procedure sum_mask_r4, sum_mask_r8 + end interface + + !> @brief Updates the buffer for any reductions that involve summation + !! (ie. time_sum, avg, rms, pow) + !! In this case the mask is present and it varies over time + interface sum_mask_variant + module procedure sum_mask_variant_r4, sum_mask_variant_r8 + end interface sum_mask_variant + + !> @brief Updates the buffer for any reductions that involve summation + !! (ie. time_sum, avg, rms, pow) + !! In this case the mask is not present + interface sum_no_mask + module procedure sum_no_mask_r4, sum_no_mask_r8 + end interface sum_no_mask + contains !> @brief Checks improper combinations of is, ie, js, and je. diff --git a/diag_manager/include/fms_diag_reduction_methods.inc b/diag_manager/include/fms_diag_reduction_methods.inc index 52bd7d9a9a..b3c28c6bfd 100644 --- a/diag_manager/include/fms_diag_reduction_methods.inc +++ b/diag_manager/include/fms_diag_reduction_methods.inc @@ -235,13 +235,7 @@ subroutine DO_TIME_SUM_UPDATE_(data_out, weight_sum, data_in, mask, is_masked, m integer ,optional, intent(in) :: pow !< Used for pow(er) reduction, !! calculates field_data^pow before adding to buffer - integer :: is_in, ie_in, js_in, je_in, ks_in, ke_in !< Starting and ending indices of each dimention for - !! the input buffer - integer :: is_out, ie_out, js_out, je_out, ks_out, ke_out !< Starting and ending indices of each dimention for - !! the output buffer - integer :: i, j, k, l !< For looping real(FMS_TRM_KIND_) :: weight_scale !< local copy of optional weight - integer :: pow_loc !> local copy of optional pow value (set if using pow reduction) integer, parameter :: kindl = FMS_TRM_KIND_ !< real kind size as set by macro integer :: diurnal !< diurnal index to indicate which daily section is updated !! will be 1 unless using a diurnal reduction @@ -252,18 +246,49 @@ subroutine DO_TIME_SUM_UPDATE_(data_out, weight_sum, data_in, mask, is_masked, m weight_scale = 1.0_kindl endif - if(present(pow)) then - pow_loc = pow - else - pow_loc = 1.0_kindl - endif - if(diurnal_section .lt. 0) then diurnal = 1 else diurnal = diurnal_section endif + if (is_masked) then + if (mask_variant) then + ! Mask changes over time so the weight is an array + call sum_mask_variant(data_out, data_in, weight_sum, bounds_in, bounds_out, mask, diurnal, weight_scale, pow) + else + call sum_mask(data_out, data_in, weight_sum, bounds_in, bounds_out, mask, diurnal, & + missing_value, weight_scale, pow) + endif + else + call sum_no_mask(data_out, data_in, weight_sum, bounds_in, bounds_out, diurnal, weight_scale, pow) + endif +end subroutine DO_TIME_SUM_UPDATE_ + +subroutine SUM_MASK_(data_out, data_in, weight_sum, bounds_in, bounds_out, mask, diurnal, missing_value, & + weight_scale, pow) + real(FMS_TRM_KIND_), intent(inout) :: data_out(:,:,:,:,:) !< output data + real(FMS_TRM_KIND_), intent(in) :: data_in(:,:,:,:) !< data to update the buffer with + real(r8_kind), intent(inout) :: weight_sum(:,:,:,:) !< Sum of weights from the output buffer object + type(fmsDiagIbounds_type), intent(in) :: bounds_in !< indices indicating the correct portion + !! of the input buffer + type(fmsDiagIbounds_type), intent(in) :: bounds_out !< indices indicating the correct portion + !! of the output buffer + logical, intent(in) :: mask(:,:,:,:) !< mask + integer, intent(in) :: diurnal !< diurnal index to indicate which daily section is + !! updated will be 1 unless using a diurnal reduction + real(FMS_TRM_KIND_), intent(in) :: missing_value !< Missing_value for data points that are masked + real(FMS_TRM_KIND_), intent(in) :: weight_scale !< weight scale to use + integer ,optional, intent(in) :: pow !< Used for pow(er) reduction, + !! calculates field_data^pow before adding to buffer + + integer :: is_in, ie_in, js_in, je_in, ks_in, ke_in !< Starting and ending indices of each dimention for + !! the input buffer + integer :: is_out, ie_out, js_out, je_out, ks_out, ke_out !< Starting and ending indices of each dimention for + !! the output buffer + integer :: pow_loc !> local copy of optional pow value (set if using pow reduction) + integer :: i, j, k, l !< For looping + is_out = bounds_out%get_imin() ie_out = bounds_out%get_imax() js_out = bounds_out%get_jmin() @@ -278,56 +303,167 @@ subroutine DO_TIME_SUM_UPDATE_(data_out, weight_sum, data_in, mask, is_masked, m ks_in = bounds_in%get_kmin() ke_in = bounds_in%get_kmax() - !> Seperated this loops for performance. If is_masked = .false. (i.e "mask" and "rmask" were never passed in) - !! then mask will always be .True. so the if (mask) is redudant. - ! TODO check if performance gain by not doing weight and pow if not needed - if (is_masked) then - if (mask_variant) then - ! Mask changes over time so the weight is an array - do k = 0, ke_out - ks_out - do j = 0, je_out - js_out - do i = 0, ie_out - is_out - where (mask(is_in + i, js_in + j, ks_in + k, :)) - data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = & - data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) & - + (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) ** pow_loc - !Increase the weight sum for the grid point that was not masked - weight_sum(is_out + i, js_out + j, ks_out + k, :) = & - weight_sum(is_out + i, js_out + j, ks_out + k, :) + weight_scale - endwhere - enddo + weight_sum = weight_sum + weight_scale + if (present(pow)) then + do k = 0, ke_out - ks_out + do j = 0, je_out - js_out + do i = 0, ie_out - is_out + where (mask(is_in + i, js_in + j, ks_in + k, :)) + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = & + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) & + + (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) ** pow + elsewhere + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = missing_value + endwhere enddo enddo - else - weight_sum = weight_sum + weight_scale - do k = 0, ke_out - ks_out - do j = 0, je_out - js_out - do i = 0, ie_out - is_out - where (mask(is_in + i, js_in + j, ks_in + k, :)) - data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = & - data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) & - + (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) ** pow_loc - elsewhere - data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = missing_value - endwhere - enddo + enddo + else + do k = 0, ke_out - ks_out + do j = 0, je_out - js_out + do i = 0, ie_out - is_out + where (mask(is_in + i, js_in + j, ks_in + k, :)) + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = & + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) & + + (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) + elsewhere + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = missing_value + endwhere enddo enddo - endif + enddo + endif +end subroutine SUM_MASK_ + +subroutine SUM_MASK_VARIANT_(data_out, data_in, weight_sum, bounds_in, bounds_out, mask, diurnal, weight_scale, pow) + real(FMS_TRM_KIND_), intent(inout) :: data_out(:,:,:,:,:) !< output data + real(FMS_TRM_KIND_), intent(in) :: data_in(:,:,:,:) !< data to update the buffer with + real(r8_kind), intent(inout) :: weight_sum(:,:,:,:) !< Sum of weights from the output buffer object + type(fmsDiagIbounds_type), intent(in) :: bounds_in !< indices indicating the correct portion + !! of the input buffer + type(fmsDiagIbounds_type), intent(in) :: bounds_out !< indices indicating the correct portion + !! of the output buffer + logical, intent(in) :: mask(:,:,:,:) !< mask + integer, intent(in) :: diurnal !< diurnal index to indicate which daily section is + !! updated will be 1 unless using a diurnal reduction + real(FMS_TRM_KIND_), intent(in) :: weight_scale !< weight scale to use + integer ,optional, intent(in) :: pow !< Used for pow(er) reduction, + !! calculates field_data^pow before adding to buffer + + integer :: is_in, ie_in, js_in, je_in, ks_in, ke_in !< Starting and ending indices of each dimention for + !! the input buffer + integer :: is_out, ie_out, js_out, je_out, ks_out, ke_out !< Starting and ending indices of each dimention for + !! the output buffer + integer :: pow_loc !> local copy of optional pow value (set if using pow reduction) + integer :: i, j, k, l !< For looping + + is_out = bounds_out%get_imin() + ie_out = bounds_out%get_imax() + js_out = bounds_out%get_jmin() + je_out = bounds_out%get_jmax() + ks_out = bounds_out%get_kmin() + ke_out = bounds_out%get_kmax() + + is_in = bounds_in%get_imin() + ie_in = bounds_in%get_imax() + js_in = bounds_in%get_jmin() + je_in = bounds_in%get_jmax() + ks_in = bounds_in%get_kmin() + ke_in = bounds_in%get_kmax() + + if (present(pow)) then + do k = 0, ke_out - ks_out + do j = 0, je_out - js_out + do i = 0, ie_out - is_out + where (mask(is_in + i, js_in + j, ks_in + k, :)) + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = & + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) & + + (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) ** pow + + !Increase the weight sum for the grid point that was not masked + weight_sum(is_out + i, js_out + j, ks_out + k, :) = & + weight_sum(is_out + i, js_out + j, ks_out + k, :) + weight_scale + endwhere + enddo + enddo + enddo else - weight_sum = weight_sum + weight_scale - ! doesn't need to loop through l if no mask, just sums the 1d slices + do k = 0, ke_out - ks_out + do j = 0, je_out - js_out + do i = 0, ie_out - is_out + where (mask(is_in + i, js_in + j, ks_in + k, :)) + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = & + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) & + + (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) + + !Increase the weight sum for the grid point that was not masked + weight_sum(is_out + i, js_out + j, ks_out + k, :) = & + weight_sum(is_out + i, js_out + j, ks_out + k, :) + weight_scale + endwhere + enddo + enddo + enddo + endif +end subroutine SUM_MASK_VARIANT_ + +subroutine SUM_NO_MASK_(data_out, data_in, weight_sum, bounds_in, bounds_out, diurnal, weight_scale, pow) + real(FMS_TRM_KIND_), intent(inout) :: data_out(:,:,:,:,:) !< output data + real(FMS_TRM_KIND_), intent(in) :: data_in(:,:,:,:) !< data to update the buffer with + real(r8_kind), intent(inout) :: weight_sum(:,:,:,:) !< Sum of weights from the output buffer object + type(fmsDiagIbounds_type), intent(in) :: bounds_in !< indices indicating the correct portion + !! of the input buffer + type(fmsDiagIbounds_type), intent(in) :: bounds_out !< indices indicating the correct portion + !! of the output buffer + integer, intent(in) :: diurnal !< diurnal index to indicate which daily section is + !! updated will be 1 unless using a diurnal reduction + real(FMS_TRM_KIND_), intent(in) :: weight_scale !< weight scale to use + integer ,optional, intent(in) :: pow !< Used for pow(er) reduction, + !! calculates field_data^pow before adding to buffer + + integer :: is_in, ie_in, js_in, je_in, ks_in, ke_in !< Starting and ending indices of each dimention for + !! the input buffer + integer :: is_out, ie_out, js_out, je_out, ks_out, ke_out !< Starting and ending indices of each dimention for + !! the output buffer + integer :: i, j, k, l !< For looping + + is_out = bounds_out%get_imin() + ie_out = bounds_out%get_imax() + js_out = bounds_out%get_jmin() + je_out = bounds_out%get_jmax() + ks_out = bounds_out%get_kmin() + ke_out = bounds_out%get_kmax() + + is_in = bounds_in%get_imin() + ie_in = bounds_in%get_imax() + js_in = bounds_in%get_jmin() + je_in = bounds_in%get_jmax() + ks_in = bounds_in%get_kmin() + ke_in = bounds_in%get_kmax() + + weight_sum = weight_sum + weight_scale + + if (present(pow)) then do k = 0, ke_out - ks_out do j = 0, je_out - js_out do i = 0, ie_out - is_out data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = & - data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) & - + (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) ** pow_loc + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) & + + (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) ** pow + enddo + enddo + enddo + else + do k = 0, ke_out - ks_out + do j = 0, je_out - js_out + do i = 0, ie_out - is_out + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) = & + data_out(is_out + i, js_out + j, ks_out + k, :, diurnal) & + + (data_in(is_in +i, js_in + j, ks_in + k, :) * weight_scale) enddo enddo enddo endif -end subroutine DO_TIME_SUM_UPDATE_ +end subroutine SUM_NO_MASK_ !> To be called with diag_send_complete, finishes reductions !! Just divides the buffer by the counter array(which is just the sum of the weights used in the buffer's reduction) diff --git a/diag_manager/include/fms_diag_reduction_methods_r4.fh b/diag_manager/include/fms_diag_reduction_methods_r4.fh index 04a4f4f0ba..1fe2b05539 100644 --- a/diag_manager/include/fms_diag_reduction_methods_r4.fh +++ b/diag_manager/include/fms_diag_reduction_methods_r4.fh @@ -41,7 +41,16 @@ #undef SUM_UPDATE_DONE_ #define SUM_UPDATE_DONE_ sum_update_done_r4 +#undef SUM_MASK_ +#define SUM_MASK_ sum_mask_r4 + +#undef SUM_NO_MASK_ +#define SUM_NO_MASK_ sum_no_mask_r4 + +#undef SUM_MASK_VARIANT_ +#define SUM_MASK_VARIANT_ sum_mask_variant_r4 + #include "fms_diag_reduction_methods.inc" !> @} -! close documentation grouping \ No newline at end of file +! close documentation grouping diff --git a/diag_manager/include/fms_diag_reduction_methods_r8.fh b/diag_manager/include/fms_diag_reduction_methods_r8.fh index bff7f44ac2..2e60a64e73 100644 --- a/diag_manager/include/fms_diag_reduction_methods_r8.fh +++ b/diag_manager/include/fms_diag_reduction_methods_r8.fh @@ -41,7 +41,16 @@ #undef SUM_UPDATE_DONE_ #define SUM_UPDATE_DONE_ sum_update_done_r8 +#undef SUM_MASK_ +#define SUM_MASK_ sum_mask_r8 + +#undef SUM_NO_MASK_ +#define SUM_NO_MASK_ sum_no_mask_r8 + +#undef SUM_MASK_VARIANT_ +#define SUM_MASK_VARIANT_ sum_mask_variant_r8 + #include "fms_diag_reduction_methods.inc" !> @} -! close documentation grouping \ No newline at end of file +! close documentation grouping