From 49aa6ecb28541e8ec95f25e142012634a7c8a3ba Mon Sep 17 00:00:00 2001 From: jslhcl Date: Fri, 22 Mar 2024 14:36:55 -0700 Subject: [PATCH 1/5] support GroupQueryAttention --- CMakeLists.txt | 9 +- cmake/ext_cuda.cmake | 16 + cmake/externals/cutlass.cmake | 11 + includes/custom_op_lite.h | 6 + includes/onnxruntime_customop.hpp | 14 + includes/onnxruntime_no_customop.h | 12 +- operators/contrib/attention_common.h | 148 ++ operators/contrib/contrib.cc | 4 +- .../cuda/cutlass_fmha/fmha_launch_template.h | 276 ++++ .../contrib/cuda/cutlass_fmha/fmha_sm50.cu | 22 + .../contrib/cuda/cutlass_fmha/fmha_sm70.cu | 22 + .../contrib/cuda/cutlass_fmha/fmha_sm75.cu | 22 + .../contrib/cuda/cutlass_fmha/fmha_sm80.cu | 22 + .../memory_efficient_attention.cu | 29 + .../cutlass_fmha/memory_efficient_attention.h | 59 + .../contrib/cuda/flash_attention/block_info.h | 44 + .../contrib/cuda/flash_attention/flash.h | 114 ++ .../contrib/cuda/flash_attention/flash_api.cc | 465 ++++++ .../contrib/cuda/flash_attention/flash_api.h | 92 ++ .../flash_fwd_hdim128_bf16_sm80.cu | 16 + .../flash_fwd_hdim128_fp16_sm80.cu | 16 + .../flash_fwd_hdim160_bf16_sm80.cu | 16 + .../flash_fwd_hdim160_fp16_sm80.cu | 16 + .../flash_fwd_hdim192_bf16_sm80.cu | 16 + .../flash_fwd_hdim192_fp16_sm80.cu | 16 + .../flash_fwd_hdim224_bf16_sm80.cu | 16 + .../flash_fwd_hdim224_fp16_sm80.cu | 16 + .../flash_fwd_hdim256_bf16_sm80.cu | 16 + .../flash_fwd_hdim256_fp16_sm80.cu | 16 + .../flash_fwd_hdim32_bf16_sm80.cu | 16 + .../flash_fwd_hdim32_fp16_sm80.cu | 16 + .../flash_fwd_hdim64_bf16_sm80.cu | 16 + .../flash_fwd_hdim64_fp16_sm80.cu | 16 + .../flash_fwd_hdim96_bf16_sm80.cu | 16 + .../flash_fwd_hdim96_fp16_sm80.cu | 16 + .../cuda/flash_attention/flash_fwd_kernel.h | 1259 +++++++++++++++++ .../flash_fwd_launch_template.h | 294 ++++ .../flash_fwd_split_hdim128_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim128_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim160_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim160_fp16_sm80.cu | 13 + 
.../flash_fwd_split_hdim192_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim192_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim224_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim224_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim256_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim256_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim32_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim32_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim64_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim64_fp16_sm80.cu | 13 + .../flash_fwd_split_hdim96_bf16_sm80.cu | 13 + .../flash_fwd_split_hdim96_fp16_sm80.cu | 13 + .../cuda/flash_attention/kernel_traits.h | 367 +++++ .../contrib/cuda/flash_attention/softmax.h | 215 +++ .../cuda/flash_attention/static_switch.h | 64 + .../contrib/cuda/flash_attention/utils.h | 499 +++++++ .../contrib/cuda/group_query_attention.h | 489 +++++++ .../cuda/group_query_attention_impl.cu | 661 +++++++++ .../cuda/group_query_attention_impl.cuh | 50 + operators/contrib/cuda/utils.cuh | 5 + test/cuda/test_cudaops.py | 668 ++++++++- 62 files changed, 6415 insertions(+), 7 deletions(-) create mode 100644 cmake/externals/cutlass.cmake create mode 100644 operators/contrib/attention_common.h create mode 100644 operators/contrib/cuda/cutlass_fmha/fmha_launch_template.h create mode 100644 operators/contrib/cuda/cutlass_fmha/fmha_sm50.cu create mode 100644 operators/contrib/cuda/cutlass_fmha/fmha_sm70.cu create mode 100644 operators/contrib/cuda/cutlass_fmha/fmha_sm75.cu create mode 100644 operators/contrib/cuda/cutlass_fmha/fmha_sm80.cu create mode 100644 operators/contrib/cuda/cutlass_fmha/memory_efficient_attention.cu create mode 100644 operators/contrib/cuda/cutlass_fmha/memory_efficient_attention.h create mode 100644 operators/contrib/cuda/flash_attention/block_info.h create mode 100644 operators/contrib/cuda/flash_attention/flash.h create mode 100644 operators/contrib/cuda/flash_attention/flash_api.cc create mode 100644 operators/contrib/cuda/flash_attention/flash_api.h create mode 100644 
operators/contrib/cuda/flash_attention/flash_fwd_hdim128_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim128_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim160_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim160_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim192_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim192_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim224_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim224_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim256_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim256_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim32_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim32_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim64_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim64_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim96_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_hdim96_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_kernel.h create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_launch_template.h create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu create mode 100644 
operators/contrib/cuda/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu create mode 100644 operators/contrib/cuda/flash_attention/kernel_traits.h create mode 100644 operators/contrib/cuda/flash_attention/softmax.h create mode 100644 operators/contrib/cuda/flash_attention/static_switch.h create mode 100644 operators/contrib/cuda/flash_attention/utils.h create mode 100644 operators/contrib/cuda/group_query_attention.h create mode 100644 operators/contrib/cuda/group_query_attention_impl.cu create mode 100644 operators/contrib/cuda/group_query_attention_impl.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index f67e31e49..4170c473c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,6 +3,8 @@ # Minimum CMake required cmake_minimum_required(VERSION 3.25) +# Don't let cmake set a default value for CMAKE_CUDA_ARCHITECTURES +cmake_policy(SET CMP0104 OLD) project(onnxruntime_extensions LANGUAGES C CXX) # set(CMAKE_VERBOSE_MAKEFILE ON) @@ -281,6 
+283,7 @@ endmacro() if(OCOS_USE_CUDA) include(ext_cuda) + include(cutlass) endif() ####################################################################################################################### @@ -347,7 +350,7 @@ endif() file(GLOB TARGET_SRC_CONTRIB "operators/contrib/*.cc" "operators/contrib/*.h*") if (OCOS_USE_CUDA) - file(GLOB TARGET_SRC_CONTRIB_CUDA "operators/contrib/cuda/*.*") + file(GLOB_RECURSE TARGET_SRC_CONTRIB_CUDA "operators/contrib/cuda/*.*") list(APPEND TARGET_SRC_CONTRIB ${TARGET_SRC_CONTRIB_CUDA}) endif() list(APPEND TARGET_SRC ${TARGET_SRC_CONTRIB}) @@ -561,6 +564,10 @@ target_include_directories(ocos_operators PUBLIC ${PROJECT_SOURCE_DIR}/base ${PROJECT_SOURCE_DIR}/operators) +if (OCOS_USE_CUDA) + target_include_directories(ocos_operators PUBLIC ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) +endif() + set(ocos_libraries) set(OCOS_COMPILE_DEFINITIONS) diff --git a/cmake/ext_cuda.cmake b/cmake/ext_cuda.cmake index ac48dcb84..b115f54ed 100644 --- a/cmake/ext_cuda.cmake +++ b/cmake/ext_cuda.cmake @@ -5,6 +5,14 @@ enable_language(CUDA) set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) set(CMAKE_CUDA_STANDARD 17) +include(CMakeDependentOption) +cmake_dependent_option(USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32" OFF) +option(USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) +if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) + message( STATUS "Turn off flash attention and memory efficient attention since CUDA compiler version < 11.6") + set(USE_FLASH_ATTENTION OFF) + set(USE_MEMORY_EFFICIENT_ATTENTION OFF) +endif() if(NOT CMAKE_CUDA_ARCHITECTURES) @@ -69,3 +77,11 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=unsigned_co set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no_effect\"") add_compile_definitions(USE_CUDA) +if (USE_FLASH_ATTENTION) + message( STATUS "Enable flash 
attention") + add_compile_definitions(USE_FLASH_ATTENTION) +endif() +if (USE_MEMORY_EFFICIENT_ATTENTION) + message( STATUS "Enable memory efficient attention") + add_compile_definitions(USE_MEMORY_EFFICIENT_ATTENTION) +endif() \ No newline at end of file diff --git a/cmake/externals/cutlass.cmake b/cmake/externals/cutlass.cmake new file mode 100644 index 000000000..24b9bf72e --- /dev/null +++ b/cmake/externals/cutlass.cmake @@ -0,0 +1,11 @@ +include(FetchContent) +FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git + GIT_TAG v3.1.0 +) + +FetchContent_GetProperties(cutlass) +if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) +endif() diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h index 784e2b2bd..0e2caa3cb 100644 --- a/includes/custom_op_lite.h +++ b/includes/custom_op_lite.h @@ -1033,6 +1033,12 @@ struct OrtLiteCustomOp : public OrtCustomOp { return INPUT_OUTPUT_OPTIONAL; }; #endif + +#if ORT_API_VERSION >= 18 + OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t { + return 0; + }; +#endif } const std::string op_name_; diff --git a/includes/onnxruntime_customop.hpp b/includes/onnxruntime_customop.hpp index 6144338a2..a0d965e47 100644 --- a/includes/onnxruntime_customop.hpp +++ b/includes/onnxruntime_customop.hpp @@ -68,6 +68,12 @@ struct CustomOp_defined_getInputMemoryType : std::false_type {}; template struct CustomOp_defined_getInputMemoryType> : std::true_type {}; +template +struct CustomOp_defined_getMayInplace : std::false_type {}; + +template +struct CustomOp_defined_getMayInplace> : std::true_type {}; + template struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { using ComputeFunction = decltype(&CustomOpKernel::Compute); @@ -146,6 +152,14 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { }; } +#if ORT_API_VERSION >= 18 + if constexpr (CustomOp_defined_getMayInplace::value) { + OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t { + return 
CustomOpKernel::GetMayInplace(input_index, output_index); + }; + } +#endif + OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr { if (api == nullptr) { diff --git a/includes/onnxruntime_no_customop.h b/includes/onnxruntime_no_customop.h index 008980be4..6bd9f87e9 100644 --- a/includes/onnxruntime_no_customop.h +++ b/includes/onnxruntime_no_customop.h @@ -36,10 +36,14 @@ class API { static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept { return instance()->CreateMemoryInfo(name, type, id, mem_type, out); } -#if ORT_API_VERSION >= 15 - // Caller is responsible for releasing OrtAllocator object: delete static_cast (allocator) - static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) { - return instance()->KernelContext_GetAllocator(context, mem_info, out); + + static void ReleaseMemoryInfo(OrtMemoryInfo* mem_info) { + return instance()->ReleaseMemoryInfo(mem_info); + } + +#if ORT_API_VERSION >= 18 + static OrtStatusPtr KernelContextGetScratchBuffer(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, size_t count_or_bytes, void** out) { + return instance()->KernelContext_GetScratchBuffer(context, mem_info, count_or_bytes, out); } #endif private: diff --git a/operators/contrib/attention_common.h b/operators/contrib/attention_common.h new file mode 100644 index 000000000..6f32bb94b --- /dev/null +++ b/operators/contrib/attention_common.h @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +namespace contrib { + +enum AttentionMaskType { + MASK_NONE, // No mask + MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length + MASK_1D_END_START, // [2 * batch_size] with end positions and start positions + MASK_1D_KEY_SEQ_LEN_START, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], + // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., + // key_start[batch_size - 1], key_end[batch_size - 1]] + MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. + MASK_2D_KEY_PADDING, // [batch_size, total_sequence_length] + MASK_3D_ATTENTION, // [batch_size, sequence_length, total_sequence_length] + MASK_4D_MEGATRON, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + MASK_UNKNOWN +}; + +enum AttentionQkvFormat { + UNKNOWN, // enum value not set, or depends on qkv projection implementation details + Q_K_V_BNSH, // for non-packed qkv, permuted + Q_K_V_BSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + QKV_BSN3H, // for TRT fused attention, qkv are packed + Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed + Q_K_V_TNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + QKV_TN3H, // for TRT fused attention, qkv are packed and paddings are removed +}; + +enum AttentionKernelType { + AttentionKernel_Unfused, + AttentionKernel_TrtFusedAttention, + AttentionKernel_TrtFlashAttention, + AttentionKernel_TrtFusedCrossAttention, + AttentionKernel_CutlassMemoryEfficientAttention, + AttentionKernel_FlashAttention, + AttentionKernel_Default +}; + +// Parameters deduced from node attributes and inputs/outputs. 
+struct AttentionParameters { + int batch_size; + int sequence_length; + int kv_sequence_length; // input sequence length of K or V + int past_sequence_length; // sequence length in past state of K or V + int total_sequence_length; // total sequence length of K or V + int max_sequence_length; // max sequence length from 4D mask + int input_hidden_size; // first dimension of weights for input projection + int hidden_size; // hidden size of Q or K + int head_size; // hidden size per head of Q or K + int v_hidden_size; // hidden size of V + int v_head_size; // hidden size per head of V + int num_heads; + int num_splits; + int rotary_embedding; + bool is_unidirectional; + bool past_present_share_buffer; + bool do_rotary; + bool broadcast_res_pos_bias; + bool pass_past_in_kv; + float mask_filter_value; + float scale; + bool use_tf32; + AttentionMaskType mask_type; + AttentionQkvFormat qkv_format; +}; + +// Parameters deduced from node attributes and inputs/outputs. +struct PackedAttentionParameters { + int batch_size; + int sequence_length; + int input_hidden_size; // hidden size of input + int hidden_size; // hidden size of Q or K + int head_size; // hidden size per head of Q or K + int v_hidden_size; // hidden size of V + int v_head_size; // hidden size per head of V + int num_heads; + float scale; + int token_count; + bool has_relative_position_bias; + bool broadcast_res_pos_bias; + bool use_tf32; +}; + +// Parameters deduced from node attributes and inputs/outputs. 
+struct GroupQueryAttentionParameters { + int batch_size; + int sequence_length; // sequence length of input query, key, value + int seqlen_past_kv_cache; // sequence length of past kv tensor + int seqlen_present_kv_cache; // sequence length of present kv tensor + int hidden_size; + int num_heads; + int head_size; + int kv_hidden_size; + int kv_num_heads; + int num_splits; // number of splits for splitkv + bool is_unidirectional; // causal + int local_window_size; + bool kv_share_buffer; + bool is_packed_qkv; + bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool do_rotary; + bool rotary_interleaved; + float scale; + AttentionQkvFormat qkv_format; + AttentionQkvFormat past_kv_format; + int zeros_count; + int* zero_ptr; +}; + +namespace attention { +// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled). +constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION"; + +// Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled). +constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION"; + +// Environment variable to enable or disable TRT fused causal attention kernels. Default is 0 (disabled). +// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels. +constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION"; + +// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled). +constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION"; + +// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled). +constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION"; + +// Environment variable to enable or disable flash attention. 
Default is 0 (enabled). +constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; + +// Minimum sequence length to enable memory efficient attention in FP32. +constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256; + +// Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention +constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"; +// Default value for the above setting. +constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513; + +// Environment variable to enable loading more KV data in flight in +// DecoderMaskedMultiHeadAttention/DecoderMaskedSelfAttention kernels +constexpr const char* kDecoderMaskedAttentionLoadKVDataInFlight = "ORT_DECODER_MASKED_ATTENTION_LOAD_KV_DATA_IN_FLIGHT"; + +} // namespace attention + +} // namespace contrib diff --git a/operators/contrib/contrib.cc b/operators/contrib/contrib.cc index 39cc02f85..3efb99a51 100644 --- a/operators/contrib/contrib.cc +++ b/operators/contrib/contrib.cc @@ -5,6 +5,7 @@ #ifdef USE_CUDA #include "cuda/fast_gelu.h" +#include "cuda/group_query_attention.h" #endif FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { @@ -14,7 +15,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { , CustomCudaStructV2("FastGelu", contrib::FastGelu), #if ORT_API_VERSION >= 16 - + CustomCudaStructV2("GroupQueryAttention", contrib::GroupQueryAttention), + CustomCudaStructV2("GroupQueryAttention", contrib::GroupQueryAttention), CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("FastGelu", contrib::FastGelu) #endif diff --git a/operators/contrib/cuda/cutlass_fmha/fmha_launch_template.h b/operators/contrib/cuda/cutlass_fmha/fmha_launch_template.h new file mode 100644 index 000000000..bc4bd278a --- /dev/null +++ b/operators/contrib/cuda/cutlass_fmha/fmha_launch_template.h @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft 
Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_MEMORY_EFFICIENT_ATTENTION + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "memory_efficient_attention.h" +#include "41_fused_multi_head_attention/kernel_forward.h" + +namespace contrib { +namespace cuda { + +template +struct RightPaddingBatchHook { + using scalar_t = typename AttentionKernel::scalar_t; + using accum_t = typename AttentionKernel::accum_t; + using lse_scalar_t = typename AttentionKernel::lse_scalar_t; + using output_t = typename AttentionKernel::output_t; + using output_accum_t = typename AttentionKernel::output_accum_t; + + static constexpr bool kSupportsDropout = AttentionKernel::kSupportsDropout; + static constexpr bool kSupportsBias = AttentionKernel::kSupportsBias; + static constexpr int kKeysPerBlock = AttentionKernel::kKeysPerBlock; + static constexpr bool kIsAligned = AttentionKernel::kIsAligned; + static constexpr bool kSingleValueIteration = AttentionKernel::kSingleValueIteration; + static constexpr int32_t kAlignLSE = AttentionKernel::kAlignLSE; // block size of backward + static constexpr bool kPreloadV = AttentionKernel::kPreloadV; + static constexpr bool kKeepOutputInRF = AttentionKernel::kKeepOutputInRF; + static constexpr bool kNeedsOutputAccumulatorBuffer = AttentionKernel::kNeedsOutputAccumulatorBuffer; + + template + static CUTLASS_DEVICE bool AdvanceToBlockForGQA(Params& p) { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE; + + // Advance to current batch - in case of different sequence lengths + if (p.seqlen_k_ptr) { + p.num_keys = p.seqlen_k_ptr[batch_id]; + } + + if (query_start >= p.num_queries) { + return false; + } + + // Advance to the current batch / head / query_start + 
p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH; + p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH; + p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH; + p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value; + + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH); + } + if (p.output_accum_ptr != nullptr) { + p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) + + int64_t(query_start) * (p.head_dim_value * p.num_heads) + + head_id * p.head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + p.output_accum_ptr = (accum_t*)(p.output_ptr); + } + + if (p.logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + p.logsumexp_ptr += + batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (p.causal_diagonal_ptr) { + p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id]; + } + if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + p.causal_diagonal_offset += p.num_keys - p.num_queries; + } + if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft || + p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations + p.num_keys = cutlass::fast_min( + int32_t(query_start + p.causal_diagonal_offset + kQueriesPerBlock), + p.num_keys); + } + + p.num_queries -= query_start; + p.num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only 
launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (p.num_queries == 1 && p.k_strideH == 0 && p.v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + p.q_strideM = p.q_strideH; + p.num_queries = p.num_heads; + p.num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + p.custom_mask_type = AttentionKernel::NoCustomMask; + p.o_strideM = p.head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + p.query_ptr = warp_uniform(p.query_ptr); + p.key_ptr = warp_uniform(p.key_ptr); + p.value_ptr = warp_uniform(p.value_ptr); + if (kSupportsBias) { + p.attn_bias_ptr = warp_uniform(p.attn_bias_ptr); + } + p.output_ptr = warp_uniform(p.output_ptr); + p.output_accum_ptr = warp_uniform(p.output_accum_ptr); + p.logsumexp_ptr = warp_uniform(p.logsumexp_ptr); + p.num_queries = warp_uniform(p.num_queries); + p.num_keys = warp_uniform(p.num_keys); + p.num_heads = warp_uniform(p.num_heads); + p.head_dim = warp_uniform(p.head_dim); + p.head_dim_value = warp_uniform(p.head_dim_value); + p.o_strideM = warp_uniform(p.o_strideM); + p.custom_mask_type = warp_uniform(p.custom_mask_type); + return true; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl_right_padding(typename AK::Params p) { + if (!RightPaddingBatchHook::AdvanceToBlockForGQA(p)) { + return; + } + AK::attention_kernel(p); +} + +template +void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { + using Attention = AttentionKernel; + typename Attention::Params p; + { // set parameters + p.query_ptr = const_cast(reinterpret_cast(params.query)); + p.key_ptr = const_cast(reinterpret_cast(params.key)); + p.value_ptr = const_cast(reinterpret_cast(params.value)); + p.attn_bias_ptr = 
const_cast(reinterpret_cast(params.attn_bias)); + p.seqstart_q_ptr = params.seqstart_q_ptr; + p.seqstart_k_ptr = params.seqstart_k_ptr; + p.seqlen_k_ptr = params.seqlen_k_ptr; + + p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward + p.output_ptr = reinterpret_cast(params.output); + if (Attention::kNeedsOutputAccumulatorBuffer) { + using Acc = typename Attention::accum_t; + // workspace size: batch_size * sequence_length * num_heads * v_head_size * sizeof(float) + // TODO: ORT_ENFORCE(params.workspace != nullptr, "Need output accumulator buffer but no workspace provided"); + p.output_accum_ptr = reinterpret_cast(params.workspace); + } else { + p.output_accum_ptr = nullptr; + } + p.num_heads = params.num_heads; + p.num_batches = params.batch_size; + p.head_dim = params.qk_head_size; + p.head_dim_value = params.v_head_size; + + p.scale = params.scale; + + // When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel + p.num_queries = params.sequence_length; + p.num_keys = params.kv_sequence_length; + + if (params.causal) { + p.custom_mask_type = Attention::CausalFromBottomRight; + } + + // We use max_sequence_length to calculate KV stride + if (params.is_kv_bsnh) { + // Input Q, K, V format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 
0 : p.num_keys; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } else { + // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.max_sequence_length * params.qk_head_size; + p.v_strideH = params.max_sequence_length * params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.qk_head_size; + p.v_strideM = params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; + p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; + p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? 
static_cast(p.bias_strideH) * params.num_heads : 0; + } + } + + auto kernel_fn = attention_kernel_batched_impl; + if (params.has_custom_right_padding) { + kernel_fn = attention_kernel_batched_impl_right_padding; + } + + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + // TODO: ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!"); + static bool once = [&]() { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + return true; + }(); + } + + // TODO: ORT_ENFORCE(Attention::check_supported(p)); + kernel_fn<<>>(p); +} + +template +void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { + using AlignedAK = AttentionKernel; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +#pragma warning(disable : 6287) +#endif + // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned. + bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 && + params.qk_head_size % AlignedAK::kAlignmentK == 0 && + params.v_head_size % AlignedAK::kAlignmentV == 0; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { + LaunchCutlassFmha(params); + })); +} + +template +void DispatchBlockSize(const MemoryEfficientAttentionParams& params) { + if (params.v_head_size <= 64) { + DispatchIsAligned(params); + } else if (params.v_head_size <= 128) { + DispatchIsAligned(params); + } else { + DispatchIsAligned(params); + } +} + +} // namespace cuda +} // namespace contrib + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/contrib/cuda/cutlass_fmha/fmha_sm50.cu b/operators/contrib/cuda/cutlass_fmha/fmha_sm50.cu new file mode 100644 index 000000000..1900ee46a --- /dev/null +++ b/operators/contrib/cuda/cutlass_fmha/fmha_sm50.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft 
Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib + +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/contrib/cuda/cutlass_fmha/fmha_sm70.cu b/operators/contrib/cuda/cutlass_fmha/fmha_sm70.cu new file mode 100644 index 000000000..8cd8d6f89 --- /dev/null +++ b/operators/contrib/cuda/cutlass_fmha/fmha_sm70.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib + +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/contrib/cuda/cutlass_fmha/fmha_sm75.cu b/operators/contrib/cuda/cutlass_fmha/fmha_sm75.cu new file mode 100644 index 000000000..9454953d9 --- /dev/null +++ b/operators/contrib/cuda/cutlass_fmha/fmha_sm75.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#if USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib + +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/contrib/cuda/cutlass_fmha/fmha_sm80.cu b/operators/contrib/cuda/cutlass_fmha/fmha_sm80.cu new file mode 100644 index 000000000..f5d956fb7 --- /dev/null +++ b/operators/contrib/cuda/cutlass_fmha/fmha_sm80.cu @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_MEMORY_EFFICIENT_ATTENTION + +#include "fmha_launch_template.h" + +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib + +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/contrib/cuda/cutlass_fmha/memory_efficient_attention.cu b/operators/contrib/cuda/cutlass_fmha/memory_efficient_attention.cu new file mode 100644 index 000000000..608b79798 --- /dev/null +++ b/operators/contrib/cuda/cutlass_fmha/memory_efficient_attention.cu @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#if USE_MEMORY_EFFICIENT_ATTENTION + +#include "memory_efficient_attention.h" +#include + +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params) { + const int32_t& sm = params.sm; + if (sm >= 80) { + run_memory_efficient_attention_sm80(params); + } else if (sm >= 75) { + run_memory_efficient_attention_sm75(params); + } else if (sm >= 70) { + run_memory_efficient_attention_sm70(params); + } else if (sm >= 50) { + run_memory_efficient_attention_sm50(params); + } else { + assert(false); // shall not reach here. + } +} + +} // namespace cuda +} // namespace contrib + +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/operators/contrib/cuda/cutlass_fmha/memory_efficient_attention.h b/operators/contrib/cuda/cutlass_fmha/memory_efficient_attention.h new file mode 100644 index 000000000..99188ba01 --- /dev/null +++ b/operators/contrib/cuda/cutlass_fmha/memory_efficient_attention.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#if USE_MEMORY_EFFICIENT_ATTENTION +#include + +namespace contrib { +namespace cuda { + +struct MemoryEfficientAttentionParams { + int32_t sm; + bool is_half; + bool is_kv_bsnh = true; + int32_t batch_size; + int32_t num_heads; + int32_t sequence_length; + int32_t kv_sequence_length; + int32_t max_sequence_length; + int32_t qk_head_size; + int32_t v_head_size; + bool causal; + // The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models. 
+ bool is_attn_bias_batched; + + float scale; + + int32_t* seqstart_q_ptr; + int32_t* seqstart_k_ptr; + int32_t* seqlen_k_ptr; + + const void* query; // [B, S, N, H] + const void* key; // [B, L, N, H], where L is kv_sequence_length + const void* value; // [B, L, N, H_v] + const void* attn_bias; // [N, S, S*] or null + void* output; // [B, S, N, H_v] + void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise + cudaStream_t stream; + + static bool need_workspace(size_t v_head_size, bool is_float) { + return (v_head_size > 128 && !is_float); + } + + bool has_custom_right_padding = false; +}; + +void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params); + +inline bool has_memory_efficient_attention(int32_t sm, bool is_half) { + return sm >= (is_half ? 53 : 50); +} + +void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params); + +} +} +#endif diff --git a/operators/contrib/cuda/flash_attention/block_info.h b/operators/contrib/cuda/flash_attention/block_info.h new file mode 100644 index 000000000..1ec632658 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/block_info.h @@ -0,0 +1,44 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +namespace flash { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + template + __device__ BlockInfo(const Params& params, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? 
-1 : params.cu_seqlens_q[bidb]), + sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]), + actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , + seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])), + actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { + } + + template + inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. 
+ const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace flash diff --git a/operators/contrib/cuda/flash_attention/flash.h b/operators/contrib/cuda/flash_attention/flash.h new file mode 100644 index 000000000..603a6e068 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash.h @@ -0,0 +1,114 @@ +#pragma once +#include + +namespace flash { +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void* __restrict__ q_ptr = nullptr; + void* __restrict__ k_ptr = nullptr; + void* __restrict__ v_ptr = nullptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride = 0; + index_t k_batch_stride = 0; + index_t v_batch_stride = 0; + index_t q_row_stride = 0; + index_t k_row_stride = 0; + index_t v_row_stride = 0; + index_t q_head_stride = 0; + index_t k_head_stride = 0; + index_t v_head_stride = 0; + + // The number of heads. + int h = 0; + int h_k = 0; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio = 0; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr = nullptr; + void* __restrict__ oaccum_ptr = nullptr; + + // The stride between rows of O. + index_t o_batch_stride = 0; + index_t o_row_stride = 0; + index_t o_head_stride = 0; + + // The pointer to the P matrix. + void* __restrict__ p_ptr = nullptr; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr = nullptr; + void* __restrict__ softmax_lseaccum_ptr = nullptr; + + // The dimensions. 
+ int b = 0; + int seqlen_q = 0; + int seqlen_k = 0; + int seqlen_knew = 0; + int d = 0; + int seqlen_q_rounded = 0; + int seqlen_k_rounded = 0; + int d_rounded = 0; + int rotary_dim = 0; + + // The scaling factors for the kernel. + float scale_softmax = 0.0; + float scale_softmax_log2 = 0.0; + + // array of length b+1 holding starting offset of each sequence. + int* __restrict__ cu_seqlens_q = nullptr; + int* __restrict__ cu_seqlens_k = nullptr; + + int* __restrict__ blockmask = nullptr; + + // The K_new and V_new matrices. + void* __restrict__ knew_ptr = nullptr; + void* __restrict__ vnew_ptr = nullptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride = 0; + index_t vnew_batch_stride = 0; + index_t knew_row_stride = 0; + index_t vnew_row_stride = 0; + index_t knew_head_stride = 0; + index_t vnew_head_stride = 0; + + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr = nullptr; + void* __restrict__ rotary_sin_ptr = nullptr; + + // The indices to index into the KV cache. + int* __restrict__ cache_batch_idx = nullptr; + + // Local window size + int window_size_left = -1; + int window_size_right = -1; + + bool is_bf16 = false; + bool is_causal = false; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. 
+ bool is_seqlens_k_cumulative = true; + + bool is_rotary_interleaved = false; + + int num_splits = 0; // For split-KV version + + const cudaDeviceProp* dprops = nullptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream); +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); +} \ No newline at end of file diff --git a/operators/contrib/cuda/flash_attention/flash_api.cc b/operators/contrib/cuda/flash_attention/flash_api.cc new file mode 100644 index 000000000..73dd51fec --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_api.cc @@ -0,0 +1,465 @@ +#if USE_FLASH_ATTENTION + +#include "flash_api.h" +#include "flash.h" +#include "static_switch.h" +#include + +namespace flash { + +void set_params_fprop(Flash_fwd_params& params, + // sizes + size_t batch_size, + size_t seqlen_q, + size_t seqlen_k, + size_t seqlen_q_rounded, + size_t seqlen_k_rounded, + size_t num_heads, + size_t num_heads_k, + size_t head_size, + size_t head_size_rounded, + // device pointers + void* q, + void* k, + void* v, + void* out, + void* cu_seqlens_q_d, + void* cu_seqlens_k_d, + void* p_d, + void* softmax_lse_d, + float softmax_scale, + bool is_causal, + bool is_bf16, + bool kv_bsnh = true, + int window_size_left = -1, + int window_size_right = -1) { + // Set the pointers and strides. + params.q_ptr = q; + params.k_ptr = k; + params.v_ptr = v; + params.o_ptr = out; + + params.is_bf16 = is_bf16; + + // All stride are in elements, not bytes. 
+ if (kv_bsnh) { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_heads_k * head_size; + params.v_row_stride = num_heads_k * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } else { + params.q_row_stride = num_heads * head_size; + params.k_row_stride = head_size; + params.v_row_stride = head_size; + params.q_head_stride = head_size; + params.k_head_stride = seqlen_k * head_size; + params.v_head_stride = seqlen_k * head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + } + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + } else { + params.q_batch_stride = 0; + params.k_batch_stride = 0; + params.v_batch_stride = 0; + params.o_batch_stride = 0; + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.h_h_k_ratio = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + // In our API, causal/unidirectional determines if we only look at prior tokens. 
However, the flash API seperates + // local and causal, meaning when we have local window size + params.is_causal = is_causal; + if (is_causal && (window_size_left >= 0 || window_size_right != 0)) { + params.is_causal = false; + } + if (window_size_left < 0 && window_size_right >= 0) { + window_size_left = seqlen_k; + } + if (window_size_left >= 0 && window_size_right < 0) { + window_size_right = seqlen_k; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; +} + +size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) { + size_t bytes = sizeof(float) * batch_size * num_heads * seqlen; + return bytes; +} + +size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads; + return bytes; +} + +size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded) { + size_t bytes = sizeof(float) * num_splits * batch_size * seqlen_q * num_heads * head_size_rounded; + return bytes; +} + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split_kernel = false) { + FP16_SWITCH(!params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); + }); +} + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. 
+int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, + int max_splits) { + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. + const int num_m_blocks = (seqlen_q + 64 - 1) / 64; + int batch_nheads_mblocks = batch_size * num_heads * num_m_blocks; + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. 
+ auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + continue; + } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs) { + int max_splits = 128; + // split kv buffers + int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); + if (num_splits > 1) { + // softmax_lse_accum buffer + int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; + } else { + return {0, 0, 0}; + } +} + +OrtStatusPtr mha_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* k, // batch_size x seqlen_k x num_heads_k x head_size + void* v, // batch_size x 
seqlen_k x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16, + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh, + int local_window_size) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + /*cu_seqlens_q*/ nullptr, + /*cu_seqlens_k*/ nullptr, + nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + kv_bsnh, + local_window_size, + is_causal ? 
0 : -1); + params.dprops = &dprops; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + run_mha_fwd(params, stream); + return nullptr; +} + +OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // half (total_q, num_heads, head_size) + void* k, // half (total_k, num_heads, head_size) + void* v, // half (total_k, num_heads, head_size) + void* out, // half (total_q, num_heads, head_size) + int* cu_seqlens_q, // int (batch_size + 1) + int* cu_seqlens_k, // int (batch_size + 1) + void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int max_seqlen_q, + int max_seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + cu_seqlens_q, + cu_seqlens_k, + nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + true, + -1, + is_causal ? 
0 : -1); + params.dprops = &dprops; + params.num_splits = 0; + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + run_mha_fwd(params, stream); + return nullptr; +} + +bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k) { + bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; + bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0); +} + +// This API is used when past key and value are present... since cached, these are assumed to have sequence length +// of max_sequence_length, so seqlen_k == max_sequence_length. The actual past sequence length is held in seqlens_k_. +OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits, + void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum, // num_splits x batch_size x seqlen_q x 
num_heads x head_size_rounded + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // In kv-cache case, seqlen_k_max as kv sequence length + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*p_ptr=*/nullptr, + softmax_lse, + softmax_scale, + is_causal, + is_bf16, + past_bsnh, + local_window_size, + is_causal ? 0 : -1); + params.dprops = &dprops; + + if (k_new != nullptr && v_new != nullptr) { + params.seqlen_knew = seqlen_k_new; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; + // All stride are in elements, not bytes. 
+ if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } + params.knew_head_stride = head_size; + params.vnew_head_stride = head_size; + } else { + params.seqlen_knew = 0; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; + params.knew_batch_stride = 0; + params.vnew_batch_stride = 0; + params.knew_row_stride = 0; + params.vnew_row_stride = 0; + params.knew_head_stride = 0; + params.vnew_head_stride = 0; + } + + params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; + if (seqlens_k_ != nullptr) { + params.cu_seqlens_k = static_cast(seqlens_k_); + } + + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = (head_size / 16) * 16; + } + + params.num_splits = num_splits; + if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { + params.softmax_lseaccum_ptr = softmax_lse_accum; + params.oaccum_ptr = out_accum; + } else { + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + } + + // Only split kernel supports appending to KV cache + run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != 
nullptr); + + return nullptr; +} + +} // namespace flash + +#endif // USE_FLASH_ATTENTION diff --git a/operators/contrib/cuda/flash_attention/flash_api.h b/operators/contrib/cuda/flash_attention/flash_api.h new file mode 100644 index 000000000..512b7a6d9 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_api.h @@ -0,0 +1,92 @@ +#pragma once + +#if USE_FLASH_ATTENTION + +#include +#include +#include +#include "onnxruntime_c_api.h" + +namespace flash { + +OrtStatusPtr mha_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* k, // batch_size x seqlen_k x num_heads_k x head_size + void* v, // batch_size x seqlen_k x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16, + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + bool kv_bsnh = true, + int local_window_size = -1); + +OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // half (total_q, num_heads, head_size) + void* k, // half (total_k, num_heads, head_size) + void* v, // half (total_k, num_heads, v_head_size) + void* out, // half (total_q, num_heads, v_head_size) + int* cu_seqlens_q, // int (batch_size + 1) + int* cu_seqlens_k, // int (batch_size + 1) + void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int max_seqlen_q, + int max_seqlen_k, + float softmax_scale, + bool is_causal, + bool is_bf16); + +OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // 
batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size + void* k, // batch_size x seqlen_k_new x num_heads_k x head_size + void* v, // batch_size x seqlen_k_new x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + void* seqlens_k_, // batch_size + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + int seqlen_k_new, + const float softmax_scale, + bool is_causal, + bool is_bf16, + bool past_bsnh, // otherwise bnsh + int num_splits = 0, + void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads + void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false); + +size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); + +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs); + +bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); + +} // namespace flash + +#endif // USE_FLASH_ATTENTION diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim128_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 000000000..778941e8e --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. 
+ +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim128_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..5cfb3019f --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim160_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim160_bf16_sm80.cu new file mode 100644 index 000000000..dda68cafc --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim160_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim160_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim160_fp16_sm80.cu new file mode 100644 index 000000000..3eb91029e --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim160_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim192_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim192_bf16_sm80.cu new file mode 100644 index 000000000..1d6cec57b --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim192_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim192_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 000000000..166b9a124 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. 
+ +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim224_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim224_bf16_sm80.cu new file mode 100644 index 000000000..fd6e6693c --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim224_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim224_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..520c5482f --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim256_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..6b93f9627 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim256_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..12def28cb --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim32_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..6400a4829 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim32_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. 
+ +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim32_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..81d19b481 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim64_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..d84464cc8 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. 
+#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim64_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..98fbc9a2e --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim96_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..c788cc92f --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template<> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_hdim96_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..377d6118a --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,16 @@ +// Copyright (c) 2023, Tri Dao. 
+ +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_kernel.h b/operators/contrib/cuda/flash_attention/flash_fwd_kernel.h new file mode 100644 index 000000000..c44a470f6 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_kernel.h @@ -0,0 +1,1259 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#endif + +#include +#include +#include + +#include +#include +#include +#include + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" + +namespace flash { +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, + Tensor2& acc_o, float softmax_scale_log2) { + if (Is_first) { + flash::template reduce_max(scores, scores_max); + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + flash::reduce_sum(scores, scores_sum); + } else { + cute::Tensor scores_max_prev = make_fragment_like(scores_max); + cute::copy(scores_max, scores_max_prev); + flash::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), 
flash::convert_layout_acc_rowcol(acc_o.layout())); +#pragma unroll + for (int mi = 0; mi < cute::size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; +#pragma unroll + for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale; + } + } + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + cute::Tensor scores_sum_cur = make_fragment_like(scores_sum); + flash::reduce_sum(scores, scores_sum_cur); +#pragma unroll + for (int mi = 0; mi < cute::size(scores_sum); ++mi) { + scores_sum(mi) += scores_sum_cur(mi); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void write_softmax_to_gmem( + cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_tiled_copy_P) { + // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) + cute::Layout l = tOrP.layout(); + cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); + CUTE_STATIC_ASSERT_V(cute::size<2>(tPgP) == _1{}); + CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP)); +#pragma unroll + for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) { + cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // We exit early and write 0 to gO and gLSE. + // Otherwise we might read OOB elements from gK and gV. 
+ if (n_block_max <= n_block_min) { + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSE(row) = INFINITY; + } + } + return; + } + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). 
+ + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + cute::Shape, cute::Int>{}, + make_stride(params.q_row_stride, _1{})); + cute::Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + cute::Shape, cute::Int>{}, + make_stride(params.k_row_stride, _1{})); + cute::Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + cute::Shape, cute::Int>{}, + make_stride(params.v_row_stride, _1{})); + cute::Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + cute::Shape, cute::Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + cute::Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + cute::Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 
0 : cute::size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + cute::Tensor sV = make_tensor(sK.data() + cute::size(sK), typename Kernel_traits::SmemLayoutKV{}); + cute::Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + cute::Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); + + cute::Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + cute::Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + cute::Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + cute::Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + cute::Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + cute::Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + cute::Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + cute::Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + cute::Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + cute::Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + cute::Tensor acc_o = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + cute::Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + cute::Tensor tSsK = 
smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + cute::Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + cute::Tensor scores_max = make_tensor(cute::Shape(acc_o)>>{}); + cute::Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // Construct identity layout for sQ and sK + cute::Tensor cQ = make_identity_tensor(make_shape(cute::size<0>(sQ), cute::size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + cute::Tensor cKV = make_identity_tensor(make_shape(cute::size<0>(sK), cute::size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + cute::Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + cute::Tensor tQpQ = make_tensor(make_shape(cute::size<2>(tQsQ))); + cute::Tensor tKVpKV = make_tensor(make_shape(cute::size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < cute::size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < cute::size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + + cute::Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { + cute::cp_async_fence(); + } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + cute::Tensor tSrQ_copy_view = 
smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. 
+ if (!Is_causal && !Is_local) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + // I can't get the stride from idx_row + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + cute::Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (Return_softmax) { + // cute::Tensor tOrP_copy = make_fragment_like(tOrP); + // copy(tOrP, tOrP_copy); + // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + // tPgP.data() = tPgP.data() + (-kBlockN); + // } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. 
+ cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + cute::Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (Return_softmax) { + // cute::Tensor tOrP_copy = make_fragment_like(tOrP); + // copy(tOrP, tOrP_copy); + // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + // tPgP.data() = tPgP.data() + (-kBlockN); + // } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + cute::Tensor lse = make_fragment_like(scores_sum); +#pragma unroll + for (int mi = 0; mi < cute::size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? 
INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + + // Convert acc_o from fp32 to fp16/bf16 + cute::Tensor rO = flash::convert_type(acc_o); + cute::Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + cute::Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + cute::Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. 
+ if (Kernel_traits::Share_Q_K_smem) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + cute::Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + cute::Shape, cute::Int>{}, + make_stride(params.o_row_stride, _1{})); + cute::Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + cute::Shape>{}, cute::Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + cute::Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + cute::Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + cute::Tensor tOrO = make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + cute::Tensor caccO = make_identity_tensor(cute::Shape, cute::Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + cute::Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(cute::size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
+ cute::Tensor taccOcO_row = logical_divide(taccOcO, cute::Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(cute::size(lse) == cute::size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < cute::size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSE(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + cute::Tensor cO = make_identity_tensor(make_shape(cute::size<0>(sO), cute::size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + cute::Tensor tOpO = make_tensor(make_shape(cute::size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < cute::size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyOaccum, + typename Kernel_traits::GmemTiledCopyO>; + using ElementO = std::conditional_t; + + const BlockInfo binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. 
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { + gLSEaccum(row) = Split ? 
-INFINITY : INFINITY; + } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + Shape, Int>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), 
typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + Tensor scores_max = make_tensor(Shape(acc_o)>>{}); + Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // // Allocate 
predicate tensors for m and n + // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. 
+ const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. 
+ // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr) + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape, Int>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr) + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape, Int>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + 
flash::copy_rotary_contiguous( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + } + } + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + if (n_block_max > n_block_copy_min) { + tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; + tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; + } + } + + // Read Q from gmem to smem, optionally apply rotary embedding. + Tensor tQrQ = make_fragment_like(tQgQ); + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? 
params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), + Shape, Int>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } else { + flash::copy_rotary_contiguous( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim); + } + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. 
when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + // if (cute::thread0()) { print(scores); } + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. 
+ if (!Is_causal && !Is_local) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. 
+ cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { + flash::apply_mask_local( + scores, n_block * kBlockN, binfo.actual_seqlen_k, + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16, + params.window_size_left, params.window_size_right); + } + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + // if (cute::thread0()) { print(acc_o_rowcol); } + Tensor lse = make_fragment_like(scores_sum); +#pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? 
-INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + // if (cute::thread0()) { print(lse); } + // if (cute::thread0()) { print(acc_o_rowcol); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum>; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? 
kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
+ Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSEaccum(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); + // __syncthreads(); + // if (cute::thread0()) { print(tOgOaccum); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. 
This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_splitkv(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ? gridDim.y : 1; + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void combine_attn_seqk_parallel(const Params& params) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + static_assert(kNThreads == 128, "We assume that each block has 128 threads"); + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. 
+ const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape, Int>{}, + make_stride(params.b * params.h * params.seqlen_q, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then tranpose them. + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { + sLSE[row][col] = lse; + } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor(Shape>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // 16 rows, so each time we load we can load 8 rows). 
+ // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); +#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_max = max(lse_max, lse_accum(l)); + } + MaxOp max_op; + lse_max = Allreduce::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); +#pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { + lse_sum += expf(lse_accum(l) - lse_max); + } + SumOp sum_op; + lse_sum = Allreduce::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } +// Store the scales exp(lse - lse_logsum) in shared memory. 
+#pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { + sLSE[row][col] = expf(lse_accum(l) - lse_logsum); + } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape, Int>{}, + Stride, _1>{}); + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape, Int>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { + tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; + } + } + // Load Oaccum in then scale and accumulate to O + for (int split = 0; split < params.num_splits; ++split) { + flash::copy( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM); +#pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; +#pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { +#pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); 
++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print(tOrO); } + + Tensor rO = flash::convert_type(tOrO); +// Write to gO +#pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; +#pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); + } + } + } + } +} + +} // namespace flash + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_launch_template.h b/operators/contrib/cuda/flash_attention/flash_fwd_launch_template.h new file mode 100644 index 000000000..e2f2505a7 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_launch_template.h @@ -0,0 +1,294 @@ 
+/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_kernel.h" + +namespace flash { + +template +__global__ void flash_fwd_kernel(Flash_fwd_params params) { + static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + flash::compute_attn(params); +#else + (void)params; +#endif +} + +template +__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + flash::compute_attn_splitkv(params); +#else + (void)params; +#endif +} + +template +__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel(params); +#else + (void)params; +#endif +} + +template +void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. 
+ // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, false > ; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); + }); + }); +} + +template +void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation 
does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // printf("About to launch, Split = %d, Append_KV = %d, knew_ptr = %p\n", Split, Append_KV, params.knew_ptr); + auto kernel = &flash_fwd_splitkv_kernel < Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV > ; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>(params); + }); + }); + }); + }); + }); + }); + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 
4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int kBlockM = 64; // Fixed for all head dimensions + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 
128 : 64); + run_flash_splitkv_fwd>(params, stream); +} + +template +void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 32; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 64; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 96; + const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 128; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. 
+ if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + }); +} + +template +void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 160; + bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. 
+ if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 192; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 224; + int max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. 
+ // run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { + constexpr static int Headdim = 256; + size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; + size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +} // namespace flash diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu new file mode 100644 index 000000000..8553913b2 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim128_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu new file mode 100644 index 000000000..8ed5afc6d --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim128_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu new file mode 100644 index 000000000..9d74a13ce --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim160_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu new file mode 100644 index 000000000..235eeaf69 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim160_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu new file mode 100644 index 000000000..a95bda783 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim192_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu new file mode 100644 index 000000000..23546d313 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim192_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu new file mode 100644 index 000000000..18b0d8010 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim224_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu new file mode 100644 index 000000000..5df080b63 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim224_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu new file mode 100644 index 000000000..f6eb273df --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim256_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu new file mode 100644 index 000000000..d2a3dfdc8 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim256_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu new file mode 100644 index 000000000..bbfe75396 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim32_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu new file mode 100644 index 000000000..75123f1d2 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim32_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu new file mode 100644 index 000000000..366ecefef --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim64_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+ +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu new file mode 100644 index 000000000..90845fb31 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim64_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu new file mode 100644 index 000000000..b71a69fce --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim96_bf16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#if USE_FLASH_ATTENTION + +#include "flash_fwd_launch_template.h" + +namespace flash { + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +#endif diff --git a/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu new file mode 100644 index 000000000..81d87e161 --- /dev/null +++ b/operators/contrib/cuda/flash_attention/flash_fwd_split_hdim96_fp16_sm80.cu @@ -0,0 +1,13 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
+
+#if USE_FLASH_ATTENTION
+
+#include "flash_fwd_launch_template.h"
+
+namespace flash {
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params& params, cudaStream_t stream);
+
+}  // namespace flash
+#endif
diff --git a/operators/contrib/cuda/flash_attention/kernel_traits.h b/operators/contrib/cuda/flash_attention/kernel_traits.h
new file mode 100644
index 000000000..48e899c2a
--- /dev/null
+++ b/operators/contrib/cuda/flash_attention/kernel_traits.h
@@ -0,0 +1,367 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#include <cute/algorithm/copy.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/layout/layout.h>
+#include <cutlass/numeric_types.h>
+
+using namespace cute;
+
+namespace flash {
+
+template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type = cutlass::half_t>
+struct Flash_kernel_traits {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  using Element = elem_type;
+  static constexpr bool Has_cp_async = true;
+#else
+  using Element = cutlass::half_t;  // below SM80 the element type is forced to fp16
+  static constexpr bool Has_cp_async = false;
+#endif
+
+  using ElementAccum = float;
+  using index_t = uint32_t;  // NOTE(review): 32-bit offsets; confirm callers never address >= 2^32 elements
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  using MMA_Atom_Arch = std::conditional_t<
+      std::is_same_v<elem_type, cutlass::half_t>,
+      MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+      MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>>;
+  using ValLayoutMNK = cute::Layout<cute::Shape<_1, _2, _1>>;
+#else
+  using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
+  using ValLayoutMNK = cute::Layout<cute::Shape<_1, _2, _2>>;
+#endif
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+  using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
+  using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
+#else
+  using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
+  using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
+#endif
+};
+
+// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
+template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_ = false, bool Share_Q_K_smem_ = false, typename elem_type = cutlass::half_t,
+          typename Base = Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type>>
+struct Flash_fwd_kernel_traits : public Base {
+  using Element = typename Base::Element;
+  using ElementAccum = typename Base::ElementAccum;
+  using index_t = typename Base::index_t;
+  static constexpr bool Has_cp_async = Base::Has_cp_async;
+  using SmemCopyAtom = typename Base::SmemCopyAtom;
+  using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
+
+  static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
+  static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
+
+  // The number of threads.
+  static constexpr int kNWarps = kNWarps_;
+  static constexpr int kNThreads = kNWarps * 32;  // 32 threads per warp
+
+  static constexpr int kBlockM = kBlockM_;
+  static constexpr int kBlockN = kBlockN_;
+  static constexpr int kHeadDim = kHeadDim_;
+  static_assert(kHeadDim % 32 == 0);
+  static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
+  static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
+  static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
+
+  using TiledMma = TiledMMA<
+      typename Base::MMA_Atom_Arch,
+      Layout<Shape<Int<kNWarps>, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
+      typename Base::ValLayoutMNK>;         // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+  using SmemLayoutAtomQ = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                               // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
+                                               Layout<Shape<_8, Int<kBlockKSmem>>,
+                                                      Stride<Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutQ = decltype(tile_to_shape(
+      SmemLayoutAtomQ{},
+      Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+
+  using SmemLayoutKV = decltype(tile_to_shape(
+      SmemLayoutAtomQ{},
+      Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+
+  // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
+  using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+                                                    Stride<_1, Int<kBlockKSmem>>>;
+  using SmemLayoutAtomVtransposed = decltype(composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
+  using SmemLayoutVtransposed = decltype(tile_to_shape(
+      SmemLayoutAtomVtransposed{},
+      Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+  // Maybe the VtransposeNoSwizzle just needs to have the right shape
+  // And the strides don't matter?
+  using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
+      SmemLayoutAtomVtransposedNoSwizzle{},
+      Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+  // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+
+  using SmemLayoutAtomO = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                               Layout<Shape<Int<8>, Int<kBlockKSmem>>,
+                                                      Stride<Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutO = decltype(tile_to_shape(
+      SmemLayoutAtomO{},
+      Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+  using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
+  using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
+
+  static constexpr int kSmemQCount = cute::size(SmemLayoutQ{});
+  static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2;  // presumably one K tile plus one V tile
+  static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
+  static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+  static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
+
+  static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
+  static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
+  // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
+  // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
+  // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
+  // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
+  // to the same banks.
+  static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
+  static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
+  using GmemLayoutAtom = cute::Layout<cute::Shape<cute::Int<kNThreads / kGmemThreadsPerRow>, cute::Int<kGmemThreadsPerRow>>,
+                                      cute::Stride<cute::Int<kGmemThreadsPerRow>, _1>>;
+
+  // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
+  // from the same address by the same threadblock. This is slightly faster.
+  using Gmem_copy_struct = std::conditional_t<
+      Has_cp_async,
+      SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
+      DefaultCopy>;
+  using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
+                                                    GmemLayoutAtom{},
+                                                    cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+  using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                                  GmemLayoutAtom{},
+                                                  cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+  static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
+  static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
+  using GmemLayoutAtomP = cute::Layout<cute::Shape<cute::Int<kNThreads / kGmemThreadsPerRowP>, cute::Int<kGmemThreadsPerRowP>>,
+                                       cute::Stride<cute::Int<kGmemThreadsPerRowP>, _1>>;
+
+  using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                                  GmemLayoutAtomP{},
+                                                  cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+
+  using GmemLayoutAtomOaccum = std::conditional_t<
+      kBlockKSmem == 32,
+      cute::Layout<cute::Shape<_16, _8>,  // Thread layout, 8 threads per row
+                   cute::Stride<_8, _1>>,
+      cute::Layout<cute::Shape<_8, _16>,  // Thread layout, 16 threads per row
+                   cute::Stride<_16, _1>>>;
+  using GmemTiledCopyOaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                                                       GmemLayoutAtomOaccum{},
+                                                       Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+  using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+  using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                                                          GmemLayoutAtomRotcossin{},
+                                                          Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per load
+  using GmemTiledCopyRotcossinCont = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                                                              GmemLayoutAtomRotcossin{},
+                                                              Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per load
+};
+
+// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
+// No_double_buffer is another option to reduce smem usage, but will slow things down.
+template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
+          int AtomLayoutMSdP_ = 1, int AtomLayoutNdKV = 2, int AtomLayoutMdQ = 2,
+          bool Is_V_in_regs_ = false, bool No_double_buffer_ = false, typename elem_type = cutlass::half_t,
+          typename Base = Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type>>
+struct Flash_bwd_kernel_traits : public Base {
+  using Element = typename Base::Element;
+  using ElementAccum = typename Base::ElementAccum;
+  using index_t = typename Base::index_t;
+  static constexpr bool Has_cp_async = Base::Has_cp_async;
+  using SmemCopyAtom = typename Base::SmemCopyAtom;
+  using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
+
+  static constexpr bool Is_V_in_regs = Is_V_in_regs_;
+  static constexpr bool No_double_buffer = No_double_buffer_;
+
+  // The number of threads.
+  static constexpr int kNWarps = kNWarps_;
+  static constexpr int kNThreads = kNWarps * 32;  // 32 threads per warp
+
+  static constexpr int kBlockM = kBlockM_;
+  static constexpr int kBlockN = kBlockN_;
+  static constexpr int kHeadDim = kHeadDim_;
+  static_assert(kHeadDim % 32 == 0);
+  static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
+  static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
+  static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
+
+  static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
+  static_assert(kNWarps % AtomLayoutMSdP == 0);
+  static_assert(kNWarps % AtomLayoutNdKV == 0);
+  static_assert(kNWarps % AtomLayoutMdQ == 0);
+
+  using TiledMmaSdP = TiledMMA<
+      typename Base::MMA_Atom_Arch,
+      cute::Layout<cute::Shape<cute::Int<AtomLayoutMSdP>, cute::Int<kNWarps / AtomLayoutMSdP>, _1>>,
+      typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+  using TiledMmadKV = TiledMMA<
+      typename Base::MMA_Atom_Arch,
+      cute::Layout<cute::Shape<cute::Int<AtomLayoutNdKV>, cute::Int<kNWarps / AtomLayoutNdKV>, _1>>,
+      typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+  using TiledMmadQ = TiledMMA<
+      typename Base::MMA_Atom_Arch,
+      cute::Layout<cute::Shape<cute::Int<AtomLayoutMdQ>, cute::Int<kNWarps / AtomLayoutMdQ>, _1>>,  // 2x4x1 or 4x2x1 thread group
+      typename Base::ValLayoutMNK>;                                                                // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+  using SmemLayoutAtomQdO = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                                 cute::Layout<cute::Shape<_8, cute::Int<kBlockKSmem>>,
+                                                              cute::Stride<cute::Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutQdO = decltype(tile_to_shape(
+      SmemLayoutAtomQdO{},
+      cute::make_shape(cute::Int<kBlockM>{}, cute::Int<kHeadDim>{})));
+
+  using SmemLayoutAtomKV = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                                cute::Layout<cute::Shape<cute::Int<kBlockM / kNWarps>, cute::Int<kBlockKSmem>>,
+                                                             cute::Stride<cute::Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutKV = decltype(tile_to_shape(
+      // SmemLayoutAtomQdO{},
+      SmemLayoutAtomKV{},
+      cute::make_shape(cute::Int<kBlockN>{}, cute::Int<kHeadDim>{})));
+
+  using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+                                                    Stride<_1, Int<kBlockKSmem>>>;
+  using SmemLayoutAtomKtransposed = decltype(composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
+  using SmemLayoutKtransposed = decltype(tile_to_shape(
+      SmemLayoutAtomKtransposed{},
+      make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
+  // Maybe the KtransposeNoSwizzle just needs to have the right shape
+  // And the strides don't matter?
+  using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
+      SmemLayoutAtomKtransposedNoSwizzle{},
+      make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
+  // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
+
+  // TODO: generalize to other values of kBlockN
+  // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
+  // static constexpr int kPBlockN = kBlockN;
+  static_assert(kBlockN >= 64);
+  // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
+  static constexpr int kPBlockN = 64;
+  static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
+  // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
+  static constexpr int kSwizzlePdS = 3;
+  using SmemLayoutAtomPdS = decltype(composition(Swizzle<kSwizzlePdS, 3, 3>{},
+                                                 cute::Layout<cute::Shape<cute::Int<kBlockM>, cute::Int<kPBlockN>>,
+                                                              cute::Stride<cute::Int<kPBlockN>, _1>>{}));
+  using SmemLayoutPdS = decltype(tile_to_shape(
+      SmemLayoutAtomPdS{},
+      cute::make_shape(cute::Int<kBlockM>{}, cute::Int<kBlockN>{})));
+  using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
+                                                      Stride<_1, Int<kPBlockN>>>;
+  using SmemLayoutAtomPdStransposed = decltype(composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
+  using SmemLayoutPdStransposed = decltype(tile_to_shape(
+      SmemLayoutAtomPdStransposed{},
+      make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
+  using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
+      SmemLayoutAtomPdStransposedNoSwizzle{},
+      make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
+  // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
+  using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
+
+  using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
+                                                      Stride<_1, Int<kBlockKSmem>>>;
+  using SmemLayoutAtomQdOtransposed = decltype(composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
+  using SmemLayoutQdOtransposed = decltype(tile_to_shape(
+      SmemLayoutAtomQdOtransposed{},
+      make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
+  using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
+      SmemLayoutAtomQdOtransposedNoSwizzle{},
+      make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
+  // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
+
+  using SmemLayoutAtomdKV = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                                 Layout<Shape<_8, Int<kBlockKSmem>>,
+                                                        Stride<Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutdKV = decltype(tile_to_shape(
+      SmemLayoutAtomdKV{},
+      make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
+  using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
+
+  using SmemLayoutAtomdQ = decltype(composition(Swizzle<kSwizzle, 3, 3>{},
+                                                Layout<Shape<_8, Int<kBlockKSmem>>,
+                                                       Stride<Int<kBlockKSmem>, _1>>{}));
+  using SmemLayoutdQ = decltype(tile_to_shape(
+      SmemLayoutAtomdQ{},
+      make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
+  using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
+
+  static constexpr int kSmemQdOCount = cute::size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3);  // Double buffer for sQ
+  static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2;
+  static constexpr int kSmemdSCount = cute::size(SmemLayoutPdS{});
+  static constexpr int kSmemPCount = cute::size(SmemLayoutPdS{});
+  static constexpr int kSmemdQCount = cute::size(SmemLayoutdQ{});
+  // static constexpr int kSmemdPsumCount = kBlockM;
+  static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
+  static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+  static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
+  static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
+  static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+  // static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
+  static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs
+                                                       ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
+                                                       : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
+  static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs
+                                                                ? kSmemKVSize + kSmemdSSize + kSmemPSize
+                                                                : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
+  static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + kSmemdSSize + kSmemPSize;
+
+  static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
+  static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
+  // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
+  // to affect speed in practice.
+  static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
+  static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
+  using GmemLayoutAtom = cute::Layout<cute::Shape<cute::Int<kNThreads / kGmemThreadsPerRow>, cute::Int<kGmemThreadsPerRow>>,
+                                      cute::Stride<cute::Int<kGmemThreadsPerRow>, _1>>;
+
+  // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
+  // from the same address by the same threadblock. This is slightly faster.
+  using Gmem_copy_struct = std::conditional_t<
+      Has_cp_async,
+      SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
+      DefaultCopy>;
+  using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+                                                    GmemLayoutAtom{},
+                                                    cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+  using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+                                                   GmemLayoutAtom{},
+                                                   cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+  using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+                                                    GmemLayoutAtom{},
+                                                    cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+  using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+                                                   GmemLayoutAtom{},
+                                                   cute::Layout<cute::Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+  using GmemLayoutAtomdQaccum = std::conditional_t<
+      kBlockKSmem == 32,
+      cute::Layout<cute::Shape<_32, _8>,  // Thread layout, 8 threads per row
+                   cute::Stride<_8, _1>>,
+      cute::Layout<cute::Shape<_16, _16>,  // Thread layout, 16 threads per row
+                   cute::Stride<_16, _1>>>;
+  using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                                                        GmemLayoutAtomdQaccum{},
+                                                        cute::Layout<cute::Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+
+  using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                                                                 cute::Layout<cute::Shape<_8, _32>,  // Thread layout, 32 threads per row
+                                                                              cute::Stride<_32, _1>>{},
+                                                                 cute::Layout<cute::Shape<_1, _1>>{}));  // Val layout, 1 val per store
+};
+
+}  // namespace flash
diff --git a/operators/contrib/cuda/flash_attention/softmax.h b/operators/contrib/cuda/flash_attention/softmax.h
new file mode 100644
index 000000000..9c31336c9
--- /dev/null
+++ b/operators/contrib/cuda/flash_attention/softmax.h
@@ -0,0 +1,215 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#include <cmath>
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+
+#include "utils.h"
+
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
+__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op) {
+  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+  static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+  CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
+#pragma unroll
+  for (int mi = 0; mi < size<0>(tensor); mi++) {
+    summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));  // zero_init: overwrite instead of folding into the old summary
+#pragma unroll
+    for (int ni = 1; ni < size<1>(tensor); ni++) {
+      summary(mi) = op(summary(mi), tensor(mi, ni));
+    }
+  }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
+__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0>& dst, Tensor<Engine1, Layout1>& src, Operator& op) {
+  CUTE_STATIC_ASSERT_V(size(dst) == size(src));
+#pragma unroll
+  for (int i = 0; i < size(dst); i++) {
+    dst(i) = Allreduce<4>::run(src(i), op);  // butterfly reduction across groups of 4 lanes
+  }
+}
+
+template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
+__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& summary, Operator& op) {
+  thread_reduce_<zero_init>(tensor, summary, op);
+  quad_allreduce_(summary, summary, op);
+}
+
+template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& max) {
+  MaxOp<float> max_op;
+  reduce_<zero_init>(tensor, max, max_op);
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1>& sum) {
+  SumOp<float> sum_op;
+  reduce_(tensor, sum, sum_op);
+}
+
+// Apply the exp to all the elements.
+template <bool Scale_max = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1> const& max, const float scale) {
+  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+  static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+  CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
+#pragma unroll
+  for (int mi = 0; mi < size<0>(tensor); ++mi) {
+    // If max is -inf, then all elements must have been -inf (possibly due to masking).
+    // We don't want (-inf - (-inf)) since that would give NaN.
+    // If we don't have float around M_LOG2E the multiplication is done in fp64.
+    const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
+#pragma unroll
+    for (int ni = 0; ni < size<1>(tensor); ++ni) {
+      // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
+      // max * log_2(e)) This allows the compiler to use the ffma
+      // instruction instead of fadd and fmul separately.
+      tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+    }
+  }
+}
+
+// Apply the exp to all the elements.
+template <bool zero_init = true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1>& max, Tensor<Engine1, Layout1>& sum, const float scale) {
+  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+  static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+  CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
+#pragma unroll
+  for (int mi = 0; mi < size<0>(tensor); ++mi) {
+    MaxOp<float> max_op;
+    max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
+#pragma unroll
+    for (int ni = 1; ni < size<1>(tensor); ni++) {
+      max(mi) = max_op(max(mi), tensor(mi, ni));
+    }
+    max(mi) = Allreduce<4>::run(max(mi), max_op);
+    // If max is -inf, then all elements must have been -inf (possibly due to masking).
+    // We don't want (-inf - (-inf)) since that would give NaN.
+    const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
+    sum(mi) = 0;
+#pragma unroll
+    for (int ni = 0; ni < size<1>(tensor); ++ni) {
+      // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
+      // max * log_2(e)) This allows the compiler to use the ffma
+      // instruction instead of fadd and fmul separately.
+      tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+      sum(mi) += tensor(mi, ni);
+    }
+    SumOp<float> sum_op;
+    sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
+  }
+}
+
+template <typename Engine, typename Layout>
+inline __device__ void apply_mask(Tensor<Engine, Layout>& tensor, const int max_seqlen_k,
+                                  const int col_idx_offset_ = 0) {
+  // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+  static_assert(Layout::rank == 2, "Only support 2D Tensor");
+  const int lane_id = threadIdx.x % 32;
+  const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+#pragma unroll
+  for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+    const int col_idx_base = col_idx_offset + nj * 8;
+#pragma unroll
+    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+      const int col_idx = col_idx_base + j;
+      if (col_idx >= max_seqlen_k) {
+// Without the "make_coord" we get wrong results
+#pragma unroll
+        for (int mi = 0; mi < size<0>(tensor); ++mi) {
+          tensor(mi, make_coord(j, nj)) = -INFINITY;
+        }
+      }
+    }
+  }
+}
+
+template <bool HasWSLeft = true, typename Engine, typename Layout>
+inline __device__ void apply_mask_local(Tensor<Engine, Layout>& tensor, const int col_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset_,
+                                        const int max_seqlen_q, const int warp_row_stride,
+                                        const int window_size_left, const int window_size_right) {
+  // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+  static_assert(Layout::rank == 2, "Only support 2D Tensor");
+  const int lane_id = threadIdx.x % 32;
+  // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
+  const int row_idx_offset = row_idx_offset_;
+  const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+#pragma unroll
+  for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+#pragma unroll
+    for (int i = 0; i < size<0, 0>(tensor); ++i) {
+      const int row_idx = row_idx_base + i * 8;
+      const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+      const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
+#pragma unroll
+      for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        const int col_idx_base = col_idx_offset + nj * 8;
+#pragma unroll
+        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+          const int col_idx = col_idx_base + j;
+          if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
+            tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+          }
+        }
+      }
+      // if (cute::thread0()) {
+      //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
+      //     print(tensor(make_coord(i, mi), _));
+      //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
+      // }
+    }
+  }
+}
+
+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout>& tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset_,
+                                         const int max_seqlen_q, const int warp_row_stride) {
+  // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+  apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
+                                        max_seqlen_q, warp_row_stride, -1, 0);
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+inline __device__ void apply_mask_causal_w_idx(
+    Tensor<Engine0, Layout0>& tensor, Tensor<Engine1, Layout1> const& idx_rowcol,
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) {
+  // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+  static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+  static_assert(Layout1::rank == 2, "Only support 2D Tensor");
+  CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
+  CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
+#pragma unroll
+  for (int mi = 0; mi < size<0>(tensor); ++mi) {
+    const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+#pragma unroll
+    for (int ni = 0; ni < size<1>(tensor); ++ni) {  // ni is a flat column index, so iterate the whole column mode (was size<1, 1>)
+      if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
+        tensor(mi, ni) = -INFINITY;
+      }
+    }
+    // if (cute::thread0()) {
+    //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
+    //     print(tensor(_, make_coord(j, ni)));
+    //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
+    // }
+  }
+}
+
+}  // namespace flash
diff --git a/operators/contrib/cuda/flash_attention/static_switch.h b/operators/contrib/cuda/flash_attention/static_switch.h
new file mode 100644
index 000000000..5b7098894
--- /dev/null
+++ b/operators/contrib/cuda/flash_attention/static_switch.h
@@ -0,0 +1,64 @@
+// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+#pragma once
+
+/// @param COND       - a boolean expression to switch by
+/// @param CONST_NAME - a name given for the constexpr bool variable.
+/// @param ...       - code to execute for true and false
+///
+/// Usage:
+/// ```
+/// BOOL_SWITCH(flag, BoolConst, [&] {
+///     some_function<BoolConst>(...);
+/// });
+/// ```
+#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
+  [&] {                                         \
+    if (COND) {                                 \
+      constexpr static bool CONST_NAME = true;  \
+      return __VA_ARGS__();                     \
+    } else {                                    \
+      constexpr static bool CONST_NAME = false; \
+      return __VA_ARGS__();                     \
+    }                                           \
+  }()
+
+#define FP16_SWITCH(COND, ...)               \
+  [&] {                                      \
+    if (COND) {                              \
+      using elem_type = cutlass::half_t;     \
+      return __VA_ARGS__();                  \
+    } else {                                 \
+      using elem_type = cutlass::bfloat16_t; \
+      return __VA_ARGS__();                  \
+    }                                        \
+  }()
+
+#define FWD_HEADDIM_SWITCH(HEADDIM, ...)   \
+  [&] {                                    \
+    if (HEADDIM <= 32) {                   \
+      constexpr static int kHeadDim = 32;  \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 64) {            \
+      constexpr static int kHeadDim = 64;  \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 96) {            \
+      constexpr static int kHeadDim = 96;  \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 128) {           \
+      constexpr static int kHeadDim = 128; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 160) {           \
+      constexpr static int kHeadDim = 160; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 192) {           \
+      constexpr static int kHeadDim = 192; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 224) {           \
+      constexpr static int kHeadDim = 224; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 256) {           \
+      constexpr static int kHeadDim = 256; \
+      return __VA_ARGS__();                \
+    }                                      \
+  }()
diff --git a/operators/contrib/cuda/flash_attention/utils.h b/operators/contrib/cuda/flash_attention/utils.h
new file mode 100644
index 000000000..cd10bd534
--- /dev/null
+++ b/operators/contrib/cuda/flash_attention/utils.h
@@ -0,0 +1,499 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#include <assert.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <cuda_fp16.h>
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+#endif
+
+#include <cute/algorithm/copy.hpp>
+#include <cute/algorithm/gemm.hpp>
+
+#include <cutlass/array.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/numeric_conversion.h>
+#include <cutlass/numeric_types.h>
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+namespace flash {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+inline __device__ uint32_t relu2(const uint32_t x);
+
+template <>
+inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
+  uint32_t res;
+  const uint32_t zero = 0u;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  asm volatile("max.f16x2 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(x), "r"(zero));
+#else
+  asm volatile(
+      "{\n"
+      "\t .reg .f16x2 sela;\n"
+      "\t set.gtu.u32.f16x2 sela, %1, %2;\n"
+      "\t and.b32 %0, sela, %1;\n"
+      "}\n"
+      : "=r"(res)
+      : "r"(x), "r"(zero));
+#endif
+  return res;
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template <>
+inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
+  uint32_t res;
+  const uint32_t zero = 0u;
+  asm volatile("max.bf16x2 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(x), "r"(zero));
+  return res;
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+
+template <typename T>
+inline __device__ uint32_t convert_relu2(const float2 x);
+
+template <>
+inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
+  uint32_t res;
+  const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
+  const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
+  asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(b), "r"(a));
+  return res;
+}
+
+template <>
+inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
+  uint32_t res;
+  const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
+  const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
+  asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n"
+               : "=r"(res)
+               : "r"(b), "r"(a));
+  return res;
+}
+
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+struct MaxOp {
+  __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; }
+};
+
+template <>
+struct MaxOp<float> {
+  // This is slightly faster
+  __device__ inline float operator()(float const& x, float const& y) { return max(x, y); }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+struct SumOp {
+  __device__ inline T operator()(T const& x, T const& y) { return x + y; }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int THREADS>
+struct Allreduce {
+  static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+  template <typename T, typename Operator>
+  static __device__ inline T run(T x, Operator& op) {
+    constexpr int OFFSET = THREADS / 2;
+    x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));  // full-warp mask: every lane must reach this call
+    return Allreduce<OFFSET>::run(x, op);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+struct Allreduce<2> {
+  template <typename T, typename Operator>
+  static __device__ inline T run(T x, Operator& op) {
+    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+    return x;
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool A_in_regs = false, bool B_in_regs = false, typename Tensor0, typename Tensor1,
+          typename Tensor2, typename Tensor3, typename Tensor4,
+          typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+          typename ThrCopyA, typename ThrCopyB>
+inline __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA,
+                            Tensor4 const& tCsB, TiledMma tiled_mma,
+                            TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+                            ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
+  CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
+  CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
+  CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
+  Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
+  CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
+  Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+  CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
+  if (!A_in_regs) {
+    cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{}));
+  }
+  if (!B_in_regs) {
+    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+  }
+#pragma unroll
+  for (int i = 0; i < size<2>(tCrA); ++i) {
+    if (i < size<2>(tCrA) - 1) {  // load the next k-slice before issuing the MMA on the current one
+      if (!A_in_regs) {
+        cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1));
+      }
+      if (!B_in_regs) {
+        cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+      }
+    }
+    cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
+          typename TiledMma, typename TiledCopy, typename ThrCopy>
+inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB,
+                                      TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                                      ThrCopy smem_thr_copy_B) {
+  CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));   // MMA_M
+  CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));   // MMA_N
+  CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
+  Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+  CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
+  cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+#pragma unroll
+  for (int i = 0; i < size<2>(tCrA); ++i) {
+    if (i < size<2>(tCrA) - 1) {
+      cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+    }
+    cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+template <typename Layout>
+inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
+  static_assert(decltype(size<0>(acc_layout))::value == 4);
+  static_assert(decltype(rank(acc_layout))::value == 3);
+  auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
+  // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
+  // "int_tuple.hpp(74): error: conversion to inaccessible base class"
+  // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+  return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+template <typename MMA_traits, typename Layout>
+inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
+  using X = Underscore;
+  static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
+  static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
+  constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
+  static_assert(mma_shape_K == 8 || mma_shape_K == 16);
+  constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
+  auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
+  // TD [2023-08-13]: Same error as above on Cutlass 3.2
+  // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
+  //                    get<0, 1>(l),
+  //                    get<1, 1, 1>(l));
+  return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
+                     get<1>(get<0>(l)),
+                     get<1>(get<1>(get<1>(l))));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename To_type, typename Engine, typename Layout>
+inline __device__ auto convert_type(Tensor<Engine, Layout> const& tensor) {
+  using From_type = typename Engine::value_type;
+  constexpr int numel = decltype(size(tensor))::value;
+  cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
+  // HACK: this requires tensor to be "contiguous"
+  auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel>*>(tensor.data()));
+  return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());  // NOTE(review): view over the local `frag`; presumably relies on inlining - confirm
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine, typename Layout>
+inline __device__ void relu_(Tensor<Engine, Layout>& tensor) {
+  constexpr int numel = decltype(size(tensor))::value;
+  static_assert(numel % 2 == 0);
+  using value_t = typename Engine::value_type;
+  // HACK: this requires tensor to be "contiguous"
+  Tensor tensor_uint32 = recast<uint32_t>(tensor);
+#pragma unroll
+  for (int i = 0; i < size(tensor_uint32); ++i) {
+    tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
+template <typename To_type, typename Engine, typename Layout>
+inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const& tensor) {
+  using From_type = typename Engine::value_type;
+  static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
+  static_assert(std::is_same_v<float, From_type>);
+  constexpr int numel = decltype(size(tensor))::value;
+  static_assert(numel % 2 == 0);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  // HACK: this requires tensor to be "contiguous"
+  Tensor tensor_float2 = recast<float2>(tensor);
+  Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
+#pragma unroll
+  for (int i = 0; i < size(out_uint32); ++i) {
+    out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
+  }
+  Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
+#else
+  Tensor out = flash::convert_type<To_type>(tensor);
+  flash::relu_(out);
+#endif
+  return out;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Blocks until all but N previous cp.async.commit_group operations have committed.
+// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
+// (which is equivalent to commit_group then wait_group 0).
+// Instead we just call cp.async.wait_group 0, which is slightly faster.
+// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
+template <int N>
+CUTE_HOST_DEVICE void cp_async_wait() {
+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
+  asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_MN = true, bool Is_even_K = true, bool Clear_OOB_MN = false, bool Clear_OOB_K = true,
+          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const& S,
+                            Tensor<Engine1, Layout1>& D, Tensor<Engine2, Layout2> const& identity_MN,
+                            Tensor<Engine3, Layout3> const& predicate_K, const int max_MN = 0) {
+  CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
+  // There's no case where !Clear_OOB_K && Clear_OOB_MN
+  static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+#pragma unroll
+  for (int m = 0; m < size<1>(S); ++m) {
+    if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+#pragma unroll
+      for (int k = 0; k < size<2>(S); ++k) {
+        if (Is_even_K || predicate_K(k)) {
+          cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+        } else if (Clear_OOB_K) {
+          cute::clear(D(_, m, k));
+        }
+      }
+    } else if (Clear_OOB_MN) {
+      cute::clear(D(_, m, _));
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K = true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const& S,
+                                      Tensor<Engine1, Layout1>& D, Tensor<Engine2, Layout2> const& identity_MN,
+                                      Tensor<Engine3, Layout3> const& predicate_K,
+                                      const int max_MN = 0, const int min_MN = 0) {
+  CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
+// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+#pragma unroll
+  for (int m = 0; m < size<1>(S); ++m) {
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+    if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+#pragma unroll
+      for (int k = 0; k < size<2>(S); ++k) {
+        if (Is_even_K || predicate_K(k)) {
+          cute::copy(S(_, m, k), D(_, m, k));
+        }
+      }
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template  // NOTE(review): parameter list lost in extraction; this definition continues past the visible chunk, tokens left as found
+inline __device__ void copy_rotary_interleaved(Tensor const& S,
+                                               Tensor& D,
+                                               Tensor const& Cos,
+                                               Tensor const& Sin,
+                                               Tensor const& identity_MN,
+                                               const int max_MN, const int min_MN,
+                                               const int dim, const int rotary_dim) {
+  CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+  CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
+  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
+  CUTE_STATIC_ASSERT_V(size<2>(S) ==
size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+inline __device__ void copy_rotary_contiguous(Tensor const& S, + Tensor& D, + Tensor const& Cos, + Tensor const& Sin, + Tensor const& identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 
0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); +#pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/operators/contrib/cuda/group_query_attention.h b/operators/contrib/cuda/group_query_attention.h new file mode 100644 index 000000000..69206e105 --- /dev/null +++ b/operators/contrib/cuda/group_query_attention.h @@ -0,0 +1,489 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#pragma once
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <optional>
+#include "ocos.h"
+#include "cuda_type.h"
+#include "ortx_common.h"
+#include "../attention_common.h"
+#include "group_query_attention_impl.cuh"
+#include "device_prop.cuh"
+#if USE_FLASH_ATTENTION
+#include "flash_attention/flash_api.h"
+#endif
+#if USE_MEMORY_EFFICIENT_ATTENTION
+#include "cutlass_fmha/memory_efficient_attention.h"
+#endif
+
+namespace contrib {
+
+// Owning pointer whose release action is supplied at construction time.
+template <typename T>
+using UniquePtrWithDeletor = std::unique_ptr<T, std::function<void(T*)>>;
+
+// Wraps the raw buffer `p` so that it is handed back to `allocator` via OrtAllocator::Free
+// when the returned pointer is destroyed.
+template <typename T>
+inline UniquePtrWithDeletor<T> GetScratchBuffer(void* p, OrtAllocator* allocator) {
+  return UniquePtrWithDeletor<T>{static_cast<T*>(p), [allocator](T* buffer) {
+                                   // Without an allocator there is nothing to return the buffer to;
+                                   // dereferencing a null allocator here would crash in a destructor.
+                                   if (allocator != nullptr) {
+                                     allocator->Free(allocator, buffer);
+                                   }
+                                 }};
+}
+
+// Validates the shapes of the GroupQueryAttention inputs and, when `parameters` is non-null,
+// fills it (interpreted as GroupQueryAttentionParameters*) with the derived sizes.
+//
+// Packed QKV is assumed when `key` is absent; `value` must then be absent as well.
+// `total_seqlen` is dereferenced on the host, so its data must be CPU-accessible.
+// A non-positive `max_threads_per_block` disables the num_heads limit check.
+//
+// Returns nullptr on success, otherwise an ORT_INVALID_ARGUMENT status describing the first
+// violated constraint.
+template <typename T>
+OrtStatusPtr CheckInputs(const Ort::Custom::Tensor<T>& query,
+                         std::optional<const Ort::Custom::Tensor<T>*> key,
+                         std::optional<const Ort::Custom::Tensor<T>*> value,
+                         std::optional<const Ort::Custom::Tensor<T>*> past_key,
+                         std::optional<const Ort::Custom::Tensor<T>*> past_value,
+                         std::optional<const Ort::Custom::Tensor<T>*> cos_cache,
+                         std::optional<const Ort::Custom::Tensor<T>*> sin_cache,
+                         void* parameters,
+                         int num_heads,
+                         int kv_num_heads,
+                         const Ort::Custom::Tensor<int>& seqlens_k,
+                         const Ort::Custom::Tensor<int>& total_seqlen,
+                         bool is_past_bsnh,
+                         float scale,
+                         int max_threads_per_block) {
+  if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
+    return OrtW::CreateStatus(MakeString("num_heads should be no larger than ", max_threads_per_block), ORT_INVALID_ARGUMENT);
+  }
+
+  // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length
+  //     past_key                   : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr
+  //     past_value                 : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr
+  // no packing for q/k/v:
+  //     query            (Q)       : (B, S, D) or (B, S, (D_q + 2 D_kv))
+  //     key              (K)       : (B, S, D_kv) or nullptr
+  //     value            (V)       : (B, S, D_kv) or nullptr
+
+  AttentionQkvFormat qkv_format = Q_K_V_BSNH;
+  AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH;
+  const bool is_packed_qkv = !key.has_value();
+  const auto& query_dims = query.Shape();
+
+  if (query_dims.size() != 3) {
+    return OrtW::CreateStatus(MakeString("Input 'query' is expected to have 3 dimensions, got ", query_dims.size()), ORT_INVALID_ARGUMENT);
+  }
+
+  int batch_size = static_cast<int>(query_dims[0]);
+  int sequence_length = static_cast<int>(query_dims[1]);
+  int q_hidden_size = static_cast<int>(query_dims[2]);
+  int head_size = 0;
+
+  // Both counts are used as divisors below; reject zero/negative values before dividing.
+  if (num_heads <= 0 || kv_num_heads <= 0) {
+    return OrtW::CreateStatus("num_heads and kv_num_heads must be positive.", ORT_INVALID_ARGUMENT);
+  }
+  if (num_heads % kv_num_heads != 0) {
+    return OrtW::CreateStatus(MakeString("num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", num_heads % kv_num_heads), ORT_INVALID_ARGUMENT);
+  }
+
+  int kv_hidden_size = 0;
+  // Check key and value when not packed
+  if (!is_packed_qkv) {
+    head_size = q_hidden_size / num_heads;
+    if (head_size % 8 != 0) {
+      return OrtW::CreateStatus(MakeString("head_size must be a multiple of 8. Got head_size % 8 == ", head_size % 8), ORT_INVALID_ARGUMENT);
+    }
+    if (!value.has_value()) {
+      return OrtW::CreateStatus("Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.", ORT_INVALID_ARGUMENT);
+    }
+    const auto& key_dims = (*key)->Shape();
+    if (key_dims.size() != 3) {
+      return OrtW::CreateStatus(MakeString("Input 'key' is expected to have 3 dimensions, got ", key_dims.size()), ORT_INVALID_ARGUMENT);
+    } else if (query_dims[0] != key_dims[0]) {
+      return OrtW::CreateStatus("Input 'query' and 'key' shall have same dim 0 (batch size)", ORT_INVALID_ARGUMENT);
+    } else if (query_dims[1] != key_dims[1]) {
+      return OrtW::CreateStatus("Input 'query' and 'key' shall have same dim 1 (sequence length)", ORT_INVALID_ARGUMENT);
+    }
+    kv_hidden_size = static_cast<int>(key_dims[2]);
+    const auto& value_dims = (*value)->Shape();
+    if (value_dims.size() != 3) {
+      return OrtW::CreateStatus(MakeString("Input 'value' is expected to have 3 dimensions, got ", value_dims.size()), ORT_INVALID_ARGUMENT);
+    } else if (query_dims[0] != value_dims[0]) {
+      return OrtW::CreateStatus("Input 'query' and 'value' shall have same dim 0 (batch size)", ORT_INVALID_ARGUMENT);
+    } else if (query_dims[1] != value_dims[1]) {
+      return OrtW::CreateStatus("Input 'query' and 'value' shall have same dim 1 (sequence length)", ORT_INVALID_ARGUMENT);
+    } else if (value_dims[2] != kv_hidden_size) {
+      return OrtW::CreateStatus("Input 'value' is expected to have same hidden size as key.", ORT_INVALID_ARGUMENT);
+    }
+  } else {
+    // Check packed qkv
+    head_size = q_hidden_size / (num_heads + 2 * kv_num_heads);
+    if (head_size % 8 != 0) {
+      return OrtW::CreateStatus(MakeString("head_size must be a multiple of 8. Got head_size % 8 == ", head_size % 8), ORT_INVALID_ARGUMENT);
+    }
+    if (value.has_value()) {
+      return OrtW::CreateStatus("Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.", ORT_INVALID_ARGUMENT);
+    }
+    q_hidden_size = head_size * num_heads;
+    kv_hidden_size = head_size * kv_num_heads;
+  }
+
+  // Check past-present KV
+  int32_t past_sequence_length = 0;
+  if (past_key.has_value() && past_value.has_value()) {
+    const auto& past_key_dims = (*past_key)->Shape();
+    const auto& past_value_dims = (*past_value)->Shape();
+
+    if (past_key_dims.size() != 4) {
+      return OrtW::CreateStatus(MakeString("Input 'past_key' is expected to have 4 dimensions, got ", past_key_dims.size()), ORT_INVALID_ARGUMENT);
+    }
+    if (past_value_dims.size() != 4) {
+      return OrtW::CreateStatus(MakeString("Input 'past_value' is expected to have 4 dimensions, got ", past_value_dims.size()), ORT_INVALID_ARGUMENT);
+    }
+
+    if (past_key_dims[0] != batch_size) {
+      return OrtW::CreateStatus(MakeString("Input 'past_key' dimension 0 should be batch_size, got ", past_key_dims[0]), ORT_INVALID_ARGUMENT);
+    }
+    if (past_value_dims[0] != batch_size) {
+      return OrtW::CreateStatus(MakeString("Input 'past_value' dimension 0 should be batch_size, got ", past_value_dims[0]), ORT_INVALID_ARGUMENT);
+    }
+
+    // Axis layout of the past tensors: BSNH keeps the sequence on axis 1 and the heads on
+    // axis 2; BNSH is the other way round.
+    if (!is_past_bsnh) {
+      // BNSH
+      if (past_key_dims[2] != past_value_dims[2]) {
+        return OrtW::CreateStatus(MakeString("BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence length or past sequence length), got ", past_key_dims[2]), ORT_INVALID_ARGUMENT);
+      }
+      if (past_key_dims[1] != kv_num_heads) {
+        return OrtW::CreateStatus("Input 'past_key' shall have kv_num_heads", ORT_INVALID_ARGUMENT);
+      }
+      if (past_value_dims[1] != kv_num_heads) {
+        return OrtW::CreateStatus("Input 'past_value' shall have kv_num_heads", ORT_INVALID_ARGUMENT);
+      }
+      // We assume all sequence in past kv are right-padded to max or past sequence length
+      past_sequence_length = static_cast<int32_t>(past_key_dims[2]);
+    } else {
+      // BSNH
+      if (past_key_dims[1] != past_value_dims[1]) {
+        return OrtW::CreateStatus(MakeString("BSNH Input 'past_key' and 'past_value' should have same dimension 1 (max sequence length or past sequence length), got ", past_key_dims[1]), ORT_INVALID_ARGUMENT);
+      }
+      if (past_key_dims[2] != kv_num_heads) {
+        return OrtW::CreateStatus("Input 'past_key' shall have kv_num_heads", ORT_INVALID_ARGUMENT);
+      }
+      if (past_value_dims[2] != kv_num_heads) {
+        return OrtW::CreateStatus("Input 'past_value' shall have kv_num_heads", ORT_INVALID_ARGUMENT);
+      }
+      // We assume all sequence in past kv are right-padded to max or past sequence length
+      past_sequence_length = static_cast<int32_t>(past_key_dims[1]);
+    }
+
+    if (past_key_dims[3] != head_size) {
+      return OrtW::CreateStatus(MakeString("Input 'past_key' dimension 3 should be same as head_size, got ", past_key_dims[3]), ORT_INVALID_ARGUMENT);
+    }
+    if (past_value_dims[3] != head_size) {
+      return OrtW::CreateStatus(MakeString("Input 'past_value' dimension 3 should be same as head_size, got ", past_value_dims[3]), ORT_INVALID_ARGUMENT);
+    }
+  } else if (past_key.has_value() || past_value.has_value()) {
+    return OrtW::CreateStatus("Input 'past_key' and 'past_value' shall be both present or both absent.", ORT_INVALID_ARGUMENT);
+  }
+
+  // Check seqlens_k tensor (holding past seqlen for token gen).
+  // Either violation alone must reject the input; testing the rank first also keeps
+  // seqlens_dim[0] from being read on a rank-0 tensor.
+  const auto& seqlens_dim = seqlens_k.Shape();
+  if (seqlens_dim.size() != 1 || seqlens_dim[0] != batch_size) {
+    return OrtW::CreateStatus("seqlens_k must be shape (batch_size).", ORT_INVALID_ARGUMENT);
+  }
+
+  // Set present sequence length and kv_share_buffer from input total_seqlen tensor
+  size_t num_dimensions = total_seqlen.Shape().size();
+  int64_t shape_size = total_seqlen.NumberOfElement();
+  if (!IsScalarOr1ElementVector(num_dimensions, shape_size)) {
+    return OrtW::CreateStatus("total_sequence_length tensor must be of one element.", ORT_INVALID_ARGUMENT);
+  }
+  int total_sequence_length = *(total_seqlen.Data());
+  int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
+
+  if (cos_cache.has_value() && sin_cache.has_value()) {
+    const auto& cos_dims = (*cos_cache)->Shape();
+    const auto& sin_dims = (*sin_cache)->Shape();
+
+    // Dimensions 0 and 1 are indexed below, so the rank has to be established first.
+    if (cos_dims.size() != 2 || sin_dims.size() != 2) {
+      return OrtW::CreateStatus("Input 'cos_cache' and 'sin_cache' are expected to have 2 dimensions.", ORT_INVALID_ARGUMENT);
+    }
+    if (head_size % 16 != 0) {
+      return OrtW::CreateStatus(MakeString("head_size shall be a multiple of 16. Got head_size % 16 == ", head_size % 16), ORT_INVALID_ARGUMENT);
+    }
+    if (cos_dims[0] != present_sequence_length) {
+      return OrtW::CreateStatus("cos_cache dimension 0 must be of present_sequence_length.", ORT_INVALID_ARGUMENT);
+    }
+    if (sin_dims[0] != present_sequence_length) {
+      return OrtW::CreateStatus("sin_cache dimension 0 must be of present_sequence_length.", ORT_INVALID_ARGUMENT);
+    }
+    if (cos_dims[1] != (head_size / 16) * 8) {
+      return OrtW::CreateStatus("cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.", ORT_INVALID_ARGUMENT);
+    }
+    if (sin_dims[1] != (head_size / 16) * 8) {
+      return OrtW::CreateStatus("sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.", ORT_INVALID_ARGUMENT);
+    }
+  } else if (cos_cache.has_value() || sin_cache.has_value()) {
+    return OrtW::CreateStatus("Input 'cos_cache' and 'sin_cache' shall be both present or both absent.", ORT_INVALID_ARGUMENT);
+  }
+
+  // Any query longer than a single token is treated as the prompt phase.
+  bool is_prompt = sequence_length != 1;
+
+  if (parameters != nullptr) {
+    GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
+    output_parameters->batch_size = batch_size;
+    output_parameters->sequence_length = sequence_length;                  // sequence length of Q
+    output_parameters->seqlen_past_kv_cache = past_sequence_length;        // max sequence length of past kv tensors
+    output_parameters->seqlen_present_kv_cache = present_sequence_length;  // max sequence length of present kv tensors
+    output_parameters->hidden_size = q_hidden_size;
+    output_parameters->num_heads = num_heads;
+    output_parameters->head_size = head_size;
+    output_parameters->kv_hidden_size = kv_hidden_size;
+    output_parameters->kv_num_heads = kv_num_heads;
+    output_parameters->is_packed_qkv = is_packed_qkv;
+    output_parameters->is_unidirectional = true;
+    output_parameters->is_prompt = is_prompt;
+    output_parameters->scale = scale;
+    output_parameters->qkv_format = qkv_format;
+    output_parameters->past_kv_format = past_kv_format;
+  }
+
+  return nullptr;
+}
+
+// CUDA GroupQueryAttention custom-op kernel. OnModelAttach reads the operator attributes,
+// Compute validates the inputs, acquires scratch buffers and dispatches to cuda::QkvToContext.
+template <typename T>
+struct GroupQueryAttention {
+  // Input 6 is total_seqlen in Compute's parameter order; CheckInputs dereferences its data
+  // on the host, so it is requested in CPU memory.
+  static OrtMemType GetInputMemoryType(size_t input_index) {
+    if (input_index == 6) return OrtMemType::OrtMemTypeCPUInput;
+    return OrtMemType::OrtMemTypeDefault;
+  }
+
+  // Reports the (input, output) index pairs that may share a buffer:
+  // past_key (3) -> present_key (1) and past_value (4) -> present_value (2).
+  // The two arrays are malloc'ed here. Returns the number of pairs, or 0 if allocation failed.
+  static size_t GetMayInplace(int** input_index, int** output_index) {
+    constexpr size_t kPairs = 2;
+    int* inputs = static_cast<int*>(malloc(kPairs * sizeof(int)));
+    int* outputs = static_cast<int*>(malloc(kPairs * sizeof(int)));
+    if (inputs == nullptr || outputs == nullptr) {
+      free(inputs);
+      free(outputs);
+      *input_index = nullptr;
+      *output_index = nullptr;
+      return 0;
+    }
+    inputs[0] = 3;
+    inputs[1] = 4;
+    outputs[0] = 1;
+    outputs[1] = 2;
+    *input_index = inputs;
+    *output_index = outputs;
+    return kPairs;
+  }
+
+  // Reads the operator attributes and prepares the per-kernel allocator / zero buffer.
+  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
+    int64_t num_heads = 0, kv_num_heads = 0;
+    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "num_heads", num_heads));
+    ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "kv_num_heads", kv_num_heads));
+    num_heads_ = static_cast<int>(num_heads);
+    kv_num_heads_ = static_cast<int>(kv_num_heads);
+    is_past_bsnh_ = false;
+    is_unidirectional_ = true;
+    local_window_size_ = static_cast<int>(OrtW::GetOpAttributeOrDefault<int64_t>(info, "local_window_size", -1));
+    do_rotary_ = OrtW::GetOpAttributeOrDefault<int64_t>(info, "do_rotary", 0) == 1;
+    rotary_interleaved_ = OrtW::GetOpAttributeOrDefault<int64_t>(info, "rotary_interleaved", 0) == 1;
+    scale_ = OrtW::GetOpAttributeOrDefault<float>(info, "scale", 0.0f);
+
+#if USE_FLASH_ATTENTION
+    disable_flash_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
+#else
+    disable_flash_attention_ = true;
+#endif
+
+#if USE_MEMORY_EFFICIENT_ATTENTION
+    disable_memory_efficient_attention_ = sizeof(T) != 2 ||
+                                          ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
+#else
+    disable_memory_efficient_attention_ = true;
+#endif
+
+#if ORT_API_VERSION >= 18
+    // The allocator releases every scratch buffer taken in Compute (not only the flash ones),
+    // so it is acquired regardless of which attention path ends up enabled.
+    OrtAllocator* allocator = nullptr;
+    ORTX_RETURN_IF_ERROR(api.KernelInfoGetAllocator(&info, OrtMemType::OrtMemTypeDefault, &allocator));
+    allocator_ = UniquePtrWithDeletor<OrtAllocator>{allocator, [&api](OrtAllocator* p) { api.ReleaseAllocator(p); }};
+    if (!disable_flash_attention_) {
+      // kZerosCount is an element count; the allocator takes a size in bytes.
+      // NOTE(review): the buffer is not cleared here - confirm it is zero-filled before first use.
+      zeros_ = GetScratchBuffer<int>(allocator_->Alloc(allocator_.get(), kZerosCount * sizeof(int)), allocator_.get());
+    }
+#endif
+    return nullptr;
+  }
+
+  // Runs group-query attention for one invocation of the operator.
+  // Returns nullptr on success or the first failing status.
+  OrtStatusPtr Compute(OrtKernelContext* kernel_context, const Ort::Custom::CudaContext& ctx, const ortc::Tensor<T>& query, std::optional<const ortc::Tensor<T>*> key,
+                       std::optional<const ortc::Tensor<T>*> value, std::optional<const ortc::Tensor<T>*> past_key, std::optional<const ortc::Tensor<T>*> past_value,
+                       const ortc::Tensor<int>& seqlens_k, const ortc::Tensor<int>& total_seqlen, std::optional<const ortc::Tensor<T>*> cos_cache,
+                       std::optional<const ortc::Tensor<T>*> sin_cache, ortc::Tensor<T>& attn_out, std::optional<ortc::Tensor<T>*> present_key, std::optional<ortc::Tensor<T>*> present_value) const {
+    GroupQueryAttentionParameters parameters;
+    ORTX_RETURN_IF_ERROR(CheckInputs<T>(query, key, value, past_key, past_value, cos_cache, sin_cache, &parameters, num_heads_, kv_num_heads_,
+                                        seqlens_k, total_seqlen, is_past_bsnh_, scale_, DeviceProp::GetCudaDeviceProp().maxThreadsPerBlock));
+    // The rotary caches are dereferenced further down whenever do_rotary is set.
+    if (do_rotary_ && !(cos_cache.has_value() && sin_cache.has_value())) {
+      return OrtW::CreateStatus("Input 'cos_cache' and 'sin_cache' shall be both present when do_rotary is set.", ORT_INVALID_ARGUMENT);
+    }
+    parameters.local_window_size = local_window_size_;
+    parameters.is_unidirectional = is_unidirectional_;
+    parameters.zeros_count = kZerosCount;
+    parameters.zero_ptr = zeros_.get();
+    // parameters.left_padding = left_padding_;
+    int sequence_length = parameters.sequence_length;
+    parameters.do_rotary = do_rotary_;
+    parameters.rotary_interleaved = rotary_interleaved_;
+
+    std::vector<int64_t> output_shape(3, 0);
+    output_shape[0] = static_cast<int64_t>(parameters.batch_size);
+    output_shape[1] = static_cast<int64_t>(sequence_length);
+    output_shape[2] = static_cast<int64_t>(parameters.hidden_size);
+
+    OrtMemoryInfo* mem_info = nullptr;
+    ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx.device_id, OrtMemTypeDefault, &mem_info));
+    // Released on every exit path, including the early error returns below.
+    UniquePtrWithDeletor<OrtMemoryInfo> mem_info_guard{mem_info, [](OrtMemoryInfo* p) { OrtW::API::ReleaseMemoryInfo(p); }};
+
+    // Requests `bytes` of scratch memory and stores the owning pointer in `out`.
+    // A zero-byte request leaves `out` empty.
+    auto alloc_scratch = [&](size_t bytes, UniquePtrWithDeletor<void>& out) -> OrtStatusPtr {
+      if (bytes == 0) return nullptr;
+#if ORT_API_VERSION >= 18
+      void* p = nullptr;
+      ORTX_RETURN_IF_ERROR(OrtW::API::KernelContextGetScratchBuffer(kernel_context, mem_info, bytes, &p));
+      out = GetScratchBuffer<void>(p, allocator_.get());
+      return nullptr;
+#else
+      return OrtW::CreateStatus("GroupQueryAttention scratch buffers require ORT_API_VERSION >= 18.", ORT_NOT_IMPLEMENTED);
+#endif
+    };
+
+    // Flash attention buffers. Declared unconditionally so that the code below compiles for
+    // every combination of build flags; they stay empty when the path is not taken.
+    UniquePtrWithDeletor<void> softmax_lse_buffer;
+    UniquePtrWithDeletor<void> softmax_lse_accum_buffer;
+    UniquePtrWithDeletor<void> out_accum_buffer;
+#if USE_FLASH_ATTENTION
+    const bool use_flash_attention = !disable_flash_attention_ && flash::is_supported(DeviceProp::GetCudaDeviceProp(), parameters.head_size, parameters.num_heads, parameters.kv_num_heads);
+    if (use_flash_attention) {
+      // softmax buffer
+      const size_t softmax_lse_bytes = flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads);
+      // split kv buffer
+      using namespace std;
+      auto [num_splits, slse_accum_bytes, o_accum_bytes] = flash::get_num_splits_and_buffer_sizes(
+          parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.num_heads,
+          parameters.head_size, DeviceProp::GetCudaDeviceProp().multiProcessorCount);
+      parameters.num_splits = num_splits;
+      ORTX_RETURN_IF_ERROR(alloc_scratch(softmax_lse_bytes, softmax_lse_buffer));
+      ORTX_RETURN_IF_ERROR(alloc_scratch(slse_accum_bytes, softmax_lse_accum_buffer));
+      ORTX_RETURN_IF_ERROR(alloc_scratch(o_accum_bytes, out_accum_buffer));
+    }
+#else
+    constexpr bool use_flash_attention = false;
+#endif
+
+    // Memory efficient attention buffers.
+    UniquePtrWithDeletor<void> k_buffer;
+    UniquePtrWithDeletor<void> v_buffer;
+    UniquePtrWithDeletor<void> fmha_buffer;
+#if USE_MEMORY_EFFICIENT_ATTENTION
+    const int sm = (DeviceProp::GetCudaDeviceProp().major * 10) + DeviceProp::GetCudaDeviceProp().minor;
+    // `key` is an optional: comparing it to nullptr directly yields true when it is absent,
+    // which would select this path for packed QKV. Test for presence explicitly.
+    const bool has_separate_key = key.has_value() && *key != nullptr;
+    const bool use_memory_efficient_attention =
+        !use_flash_attention &&
+        !disable_memory_efficient_attention_ &&
+        local_window_size_ == -1 &&
+        do_rotary_ == false &&
+        has_separate_key &&
+        (parameters.head_size & 7) == 0 &&
+        parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
+        (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
+        cuda::has_memory_efficient_attention(sm, sizeof(T) == 2);
+    if (use_memory_efficient_attention) {
+      // need a buffer if we must ungroup kv
+      if (parameters.num_heads != parameters.kv_num_heads) {
+        const size_t kv_buffer_bytes = sizeof(T) * static_cast<size_t>(parameters.batch_size) * parameters.num_heads *
+                                       parameters.seqlen_present_kv_cache * parameters.head_size;
+        ORTX_RETURN_IF_ERROR(alloc_scratch(kv_buffer_bytes, k_buffer));
+        ORTX_RETURN_IF_ERROR(alloc_scratch(kv_buffer_bytes, v_buffer));
+      }
+      if (cuda::MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
+        // Widen before multiplying so the product is not computed in int.
+        const size_t fmha_buffer_bytes = sizeof(float) * static_cast<size_t>(parameters.batch_size) * parameters.sequence_length *
+                                         parameters.num_heads * parameters.head_size;
+        ORTX_RETURN_IF_ERROR(alloc_scratch(fmha_buffer_bytes, fmha_buffer));
+      }
+    }
+#else
+    constexpr bool use_memory_efficient_attention = false;
+#endif
+
+    // seqlens_k buffer
+    UniquePtrWithDeletor<void> seqlens_k_buffer;
+    ORTX_RETURN_IF_ERROR(alloc_scratch(sizeof(int) * static_cast<size_t>(parameters.batch_size), seqlens_k_buffer));
+
+    std::vector<int64_t> present_dims;
+    if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+      present_dims = {
+          parameters.batch_size, parameters.seqlen_present_kv_cache, parameters.kv_num_heads, parameters.head_size};
+    } else {  // BNSH
+      present_dims = {
+          parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size};
+    }
+
+    using TT = typename CudaT<T>::MappedType;
+    cuda::GroupQueryAttentionData<TT> data;
+    data.query = reinterpret_cast<const TT*>(query.Data());
+    data.key = key.has_value() ? reinterpret_cast<const TT*>((*key)->Data()) : nullptr;
+    data.value = value.has_value() ? reinterpret_cast<const TT*>((*value)->Data()) : nullptr;
+    data.past_key = past_key.has_value() ? reinterpret_cast<const TT*>((*past_key)->Data()) : nullptr;
+    data.past_value = past_value.has_value() ? reinterpret_cast<const TT*>((*past_value)->Data()) : nullptr;
+    data.output = reinterpret_cast<TT*>(attn_out.Allocate(output_shape));
+    data.present_key = present_key.has_value() ? reinterpret_cast<TT*>((*present_key)->Allocate(present_dims)) : nullptr;
+    data.present_value = present_value.has_value() ? reinterpret_cast<TT*>((*present_value)->Allocate(present_dims)) : nullptr;
+    data.seqlens_k = const_cast<int*>(seqlens_k.Data());
+    data.use_flash_attention = use_flash_attention;
+    data.use_memory_efficient_attention = use_memory_efficient_attention;
+    // Past and present alias when the runtime honoured the in-place hint from GetMayInplace.
+    parameters.kv_share_buffer = (data.past_key == data.present_key);
+
+    // The scratch buffers are untyped; cast each one to the type of the field it feeds.
+    // Flash Buffers
+    if (softmax_lse_buffer != nullptr) {
+      data.softmax_lse = static_cast<decltype(data.softmax_lse)>(softmax_lse_buffer.get());
+    }
+    if (softmax_lse_accum_buffer != nullptr) {
+      data.softmax_lse_accum = static_cast<decltype(data.softmax_lse_accum)>(softmax_lse_accum_buffer.get());
+    }
+    if (out_accum_buffer != nullptr) {
+      data.out_accum = static_cast<decltype(data.out_accum)>(out_accum_buffer.get());
+    }
+    if (seqlens_k_buffer != nullptr) {
+      data.seqlens_k_total = static_cast<decltype(data.seqlens_k_total)>(seqlens_k_buffer.get());
+    }
+    // Memory Efficient Buffers
+    if (k_buffer != nullptr) {
+      data.k = static_cast<decltype(data.k)>(k_buffer.get());
+      data.v = static_cast<decltype(data.v)>(v_buffer.get());
+    }
+    if (fmha_buffer != nullptr) {
+      data.fmha_buffer = static_cast<decltype(data.fmha_buffer)>(fmha_buffer.get());
+    }
+    // Rotary
+    if (parameters.do_rotary) {
+      data.cos_cache = reinterpret_cast<const TT*>((*cos_cache)->Data());
+      data.sin_cache = reinterpret_cast<const TT*>((*sin_cache)->Data());
+    }
+
+    return cuda::QkvToContext<TT>(
+        /*device_prop, ctx.cublas,*/ reinterpret_cast<cudaStream_t>(ctx.cuda_stream), parameters, data);
+  }
+
+ private:
+  int num_heads_;     // number of attention heads
+  int kv_num_heads_;  // different for k and v for group query attention
+  int local_window_size_;
+  bool is_unidirectional_ = true;
+  bool is_past_bsnh_;
+  bool do_rotary_;
+  bool rotary_interleaved_;
+  float scale_;
+  bool disable_flash_attention_;
+  bool disable_memory_efficient_attention_;
+  static constexpr int kZerosCount = 256;  // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
+  // Members are destroyed in reverse declaration order. allocator_ is declared first so that
+  // zeros_, whose deleter calls into the allocator, is freed while the allocator is still alive.
+  UniquePtrWithDeletor<OrtAllocator> allocator_;
+  UniquePtrWithDeletor<int> zeros_;
+};
+
+}  // namespace contrib
\ No newline at end of file
diff --git a/operators/contrib/cuda/group_query_attention_impl.cu b/operators/contrib/cuda/group_query_attention_impl.cu
new file mode 100644
index 000000000..550e3a251
--- /dev/null
+++ b/operators/contrib/cuda/group_query_attention_impl.cu
@@ -0,0 +1,661 @@
+// NOTE(review): the names of the two angle-bracket includes below were lost when this patch
+// was extracted; <cassert> (assert is used in this file) and <cuda_fp16.h> are assumed.
+#include <cassert>
+#include <cuda_fp16.h>
+#include "group_query_attention_impl.cuh"
+#include "utils.cuh"
+#include "device_prop.cuh"
+#include "onnxruntime_no_customop.h"
+#include "ortx_common.h"
+#ifdef USE_FLASH_ATTENTION
+#include "flash_attention/flash_api.h"
+#endif
+#ifdef USE_MEMORY_EFFICIENT_ATTENTION
+#include "cutlass_fmha/memory_efficient_attention.h"
+#endif
+
+namespace contrib {
+namespace cuda {
+
+////////// Auxiliary Kernels for KV prep
+
+// Kernel for seqlens_k: writes `seqlen` into the first `batch_size` entries of `seqlens_k`.
+// One thread per entry over a 1-D launch; threads past `batch_size` do nothing.
+__global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) {
+  int id = blockDim.x * blockIdx.x + threadIdx.x;
+  if (id < batch_size) seqlens_k[id] = seqlen;
+}
+
+// Kernel to append new and past kv in either BSNH or BNSH format
+// Adapted from ConcatTensorToTensor kernel in attention_kv_cache.cu file
+//
+// Launch layout read by the kernel: blockDim = (H, num_heads), gridDim = (present_buffer_seqlen, batch).
+// Each thread copies one T element of one (batch, position, head) row into present_kv.
+// Positions at or beyond past_seqlen + new_seqlen are left untouched.
+template <typename T>
+__global__ void ConcatNewToPastKV(const int new_seqlen,
+                                  const int past_buffer_seqlen,
+                                  const T* past_kv,
+                                  const T* new_kv,
+                                  T* present_kv,
+                                  const int* seqlens_k,
+                                  const bool is_bsnh) {  // refers to past; otherwise bnsh
+  const int h = threadIdx.x;
+  const int n = threadIdx.y;
+  const int s = blockIdx.x;
+  const int b = blockIdx.y;
+
+  const int present_buffer_seqlen = gridDim.x;
+  const int num_heads = blockDim.y;
+  const int H = blockDim.x;
+
+  const int present_batch_stride = present_buffer_seqlen * num_heads * H;
+  const int row_stride = is_bsnh ? num_heads * H : H;
+  const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H;
+
+  // past_kv:     BPNH or BNPH
+  // new_kv:      BLNH
+  // present_kv:  BTNH or BNTH, where T = P + L
+  const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b];
+
+  int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
+  if (s < past_seqlen) {
+    const int past_batch_stride = past_buffer_seqlen * num_heads * H;
+    const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H;
+    const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
+    present_kv[out_offset] = past_kv[in_offset];
+  } else if (s < past_seqlen + new_seqlen) {
+    // Note: new KV always BSNH
+    const int new_batch_stride = new_seqlen * num_heads * H;
+    const int new_row_stride = num_heads * H;
+    const int new_head_stride = H;
+    const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
+    present_kv[out_offset] = new_kv[in_offset];
+  }
+}
+
+// Use when (H*)*num_heads > 1024
+//
+// Same copy as ConcatNewToPastKV, but H and num_heads are passed explicitly and the
+// (head, element) pair is recovered from a flat index along x.
+// Launch layout read by the kernel: gridDim = (blocks covering H * num_heads, present_buffer_seqlen, batch).
+template <typename T>
+__global__ void ConcatNewToPastKVLarge(const int new_seqlen,
+                                       const int past_buffer_seqlen,
+                                       const int H,
+                                       const int num_heads,
+                                       const T* past_kv,
+                                       const T* new_kv,
+                                       T* present_kv,
+                                       const int* seqlens_k,
+                                       const bool is_bsnh) {
+  int i = threadIdx.x + (blockDim.x * blockIdx.x);
+  if (i < H * num_heads) {
+    const int h = i % H;
+    const int n = i / H;
+    const int s = blockIdx.y;
+    const int b = blockIdx.z;
+    const int present_buffer_seqlen = gridDim.y;
+
+    const int present_batch_stride = present_buffer_seqlen * num_heads * H;
+    const int row_stride = is_bsnh ? num_heads * H : H;
+    const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H;
+
+    // past_kv:     BPNH or BNPH
+    // new_kv:      BLNH
+    // present_kv:  BTNH or BNTH, where T = P + L
+    const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b];
+
+    int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
+    if (s < past_seqlen) {
+      const int past_batch_stride = past_buffer_seqlen * num_heads * H;
+      const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H;
+      const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
+      present_kv[out_offset] = past_kv[in_offset];
+    } else if (s < past_seqlen + new_seqlen) {
+      const int new_batch_stride = new_seqlen * num_heads * H;
+      const int new_row_stride = num_heads * H;
+      const int new_head_stride = H;
+      const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
+      present_kv[out_offset] = new_kv[in_offset];
+    }
+  }
+}
+
+// Concat new to past in present. Supports past BSNH or past BNSH
+template <typename T>
+OrtStatusPtr LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters,
+                                     GroupQueryAttentionData<T>& data,
+                                     cudaStream_t stream,
+                                     const int max_threads_per_block,
+                                     const bool past_only = false) {
+  const int batch_size = parameters.batch_size;
+  const int kv_sequence_length = past_only ? 0 : parameters.sequence_length;
+  const int past_sequence_length = parameters.seqlen_past_kv_cache;
+  const int present_sequence_length = parameters.seqlen_present_kv_cache;
+  const int kv_num_heads = parameters.kv_num_heads;
+  const int head_size = parameters.head_size;
+  const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast<const int*>(data.seqlens_k);
+
+  AttentionQkvFormat past_kv_format = parameters.past_kv_format;
+
+  assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH);
+  const int H = head_size / 4;  // divide by 4 so kernel can operate on 4 float16 elements at a time.
+ if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(present_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatNewToPastKV<<>>(kv_sequence_length, + past_sequence_length, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKV<<>>(kv_sequence_length, + past_sequence_length, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = (H * kv_num_heads + 255) / 256; + const dim3 grid(steps, present_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + past_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + past_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CudaCall(cudaGetLastError()); +} + +// Kernel to append new kv to kv buffer in place +template +__global__ void ConcatKVInPlace(const int max_seqlen, + T* kv_buff, + const T* new_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int new_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = max_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? 
H : max_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; +} + +template +__global__ void ConcatKVInPlaceLarge(const int max_seqlen, + const int H, + const int num_heads, + T* kv_buff, + const T* new_kv, + const int* seqlens_k, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int new_seqlen = gridDim.y; + + const int present_batch_stride = max_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : max_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + const int past_seq_len = seqlens_k == nullptr ? 
0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; + } +} + +// Concat new to kv buffer in place +template +OrtStatusPtr LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + // Indicates past sequence_length of each sequence + const int* seqlens_k = parameters.is_prompt ? 
nullptr : reinterpret_cast(data.seqlens_k); + + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(kv_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatKVInPlace<<>>(present_sequence_length, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlace<<>>(present_sequence_length, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = int(ceil(float(H * kv_num_heads) / 256.0)); + const dim3 grid(steps, kv_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatKVInPlaceLarge<<>>(present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlaceLarge<<>>(present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + seqlens_k, + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CudaCall(cudaGetLastError()); +} + +// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... 
kv_out is ungrouped and bsnh +template +__global__ void Ungroup(const T* kv_in, + T* kv_out, + const int in_seqlen, + const int kv_num_heads, + const bool is_bsnh) { + const int h = threadIdx.x; + const int out_n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int out_seqlen = gridDim.x; + const int q_num_heads = blockDim.y; + const int H = blockDim.x; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + const int in_n = out_n / q_kv_head_ratio; + + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; +} + +template +__global__ void UngroupLarge(const T* kv_in, + T* kv_out, + const int H, + const int in_seqlen, + const int q_num_heads, + const int kv_num_heads, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements + if (i < H * q_num_heads) { + const int out_seqlen = gridDim.y; + const int s = blockIdx.y; + const int b = blockIdx.z; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? 
H : in_seqlen * H; + + const int h = i % H; + const int out_n = i / H; + const int in_n = out_n / q_kv_head_ratio; + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; + } +} + +// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. +OrtStatusPtr LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 grid(buff_seqlen, batch_size, 1); + const dim3 block(H, num_heads, 1); + Ungroup<<>>(k_og, + k_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + Ungroup<<>>(v_og, + v_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + } else { + int steps = int(ceil(float(H * num_heads) / 256.0)); + const dim3 grid(steps, buff_seqlen, batch_size); + const dim3 block(256, 1, 1); + UngroupLarge<<>>(k_og, + k_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + UngroupLarge<<>>(v_og, + v_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + } + return CudaCall(cudaGetLastError()); +} + +__global__ void PastToTotalSeqlen(int32_t* seqlens_k, + int32_t* seqlens_k_buff, + const int add_seqlen) { + seqlens_k_buff[threadIdx.x] = seqlens_k[threadIdx.x] + add_seqlen; +} + +// Convert Past to Total sequence length tensor +OrtStatusPtr LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, + int32_t* 
seqlens_k_buff, bool is_total, cudaStream_t stream, + const int threads_per_block) { + if (parameters.is_prompt) { + return nullptr; + } + const int batch_size = parameters.batch_size; + const int add_seqlen = is_total ? parameters.sequence_length : 0; + + const dim3 grid(1, 1, 1); + // TODO(aciddelgado): unlikely but could have a bigger batch_size than max_threads + const dim3 block(batch_size, 1, 1); + + // TODO(aciddelgado): small version + PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, add_seqlen); + + return CudaCall(cudaGetLastError()); +} + +////////// Launch Kernels + +#if USE_FLASH_ATTENTION +template +OrtStatusPtr FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + bool is_causal = true; + bool is_bf16 = std::is_same::value; + + void* query = reinterpret_cast(const_cast(data.query)); + void* key; + void* value; + + if (!parameters.is_packed_qkv) { + key = reinterpret_cast(const_cast(data.key)); + value = reinterpret_cast(const_cast(data.value)); + } else { + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key = reinterpret_cast(query) + key_offset; + value = reinterpret_cast(key) + value_offset; + } + + void* seqlens_k = reinterpret_cast(data.seqlens_k); + if (parameters.is_prompt) { + // set seqlens_k to zeros... 
flash api uses seqlens_k to indicate where to append key and value + // user should use seqlens_k to index into output to get new tokens + if (batch_size <= parameters.zeros_count) { + seqlens_k = parameters.zero_ptr; + } else { + // Launch kernel to create larger seqlen tensor when batch_size > 256 + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); + seqlens_k = data.seqlens_k_total; + } + } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + ORTX_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true)); + } + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); + void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORTX_RETURN_IF_ERROR(flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, + batch_size, num_heads, kv_num_heads, head_size, sequence_length, + parameters.seqlen_present_kv_cache, kv_sequence_length, + scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, + parameters.is_packed_qkv)); + + // if (parameters.left_padding && parameters.is_prompt) { + // ORTX_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + + // TODO: DUMP_TENSOR_INIT(); + // TODO: DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return nullptr; +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template 
+OrtStatusPtr EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + const void* query = reinterpret_cast(data.query); + const void* key = reinterpret_cast(data.key); + const void* value = reinterpret_cast(data.value); + + if (parameters.is_prompt) { + // Launch kernel to copy seqlen + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + } else { + ORTX_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + } + + if (parameters.kv_share_buffer) { + // Share buffer case + if (data.past_key == nullptr || data.past_key != data.present_key) { + return OrtW::API::CreateStatus(ORT_INVALID_ARGUMENT, "Past and present kv shall share the same tensor when kv_share_buffer is on."); + } + // Concatenate new kv in place + ORTX_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + } else { + // Not share buffer case + if (data.past_key != nullptr && data.past_key == data.present_key) { + return OrtW::API::CreateStatus(ORT_INVALID_ARGUMENT, "Past and present kv share the same tensor but kv_share_buffer is not on."); + } + // Copy past and concat new KV to present buffer + ORTX_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + } + + // 
Ungroup if grouped, otherwise use present kv directly + const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + if (num_heads == kv_num_heads) { + // Use present kv directly if not grouped + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORTX_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + + // TODO: DUMP_TENSOR_INIT(); + // TODO: DUMP_TENSOR("seqlens_k", data.seqlens_k_total, batch_size, 1); + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = batch_size; + p.num_heads = num_heads; + p.sequence_length = sequence_length; + p.kv_sequence_length = present_sequence_length; // TOTALLY UNNECESSARY IF WE HAVE SEQLENS_K, maybe remove + p.max_sequence_length = present_sequence_length; + p.qk_head_size = head_size; + p.v_head_size = head_size; + p.causal = true; + p.scale = scale; + p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr; + p.is_attn_bias_batched = false; + p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) + ? 
data.fmha_buffer + : nullptr; + p.stream = stream; + p.has_custom_right_padding = true; + run_memory_efficient_attention(p); + + // TODO: DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return nullptr; +} +#endif + +////////// API Functions + +template +OrtStatusPtr QkvToContext( +// const cudaDeviceProp& device_prop, +// cublasHandle_t& cublas, + cudaStream_t cuda_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data) { + const cudaDeviceProp& device_prop = DeviceProp::GetCudaDeviceProp(); + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; + +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, cuda_stream, parameters, data, scale); + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, cuda_stream, parameters, data, scale); + } +#endif + + return OrtW::API::CreateStatus(ORT_INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); +} + +template struct GroupQueryAttentionData; + +template OrtStatusPtr QkvToContext( +// const cudaDeviceProp& device_prop, +// cublasHandle_t& cublas, + cudaStream_t cuda_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +template struct GroupQueryAttentionData; + +template OrtStatusPtr QkvToContext( +// const cudaDeviceProp& device_prop, +// cublasHandle_t& cublas, + cudaStream_t cuda_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); +} +} // namespace contrib \ No newline at end of file diff --git a/operators/contrib/cuda/group_query_attention_impl.cuh b/operators/contrib/cuda/group_query_attention_impl.cuh new file mode 100644 index 000000000..df8021e75 --- /dev/null +++ b/operators/contrib/cuda/group_query_attention_impl.cuh @@ -0,0 +1,50 
@@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "onnxruntime_c_api.h" +#include "../attention_common.h" + +namespace contrib { +namespace cuda { + +template +struct GroupQueryAttentionData { + // Input Tensors + const T* query = nullptr; + const T* key = nullptr; + const T* value = nullptr; + const T* past_key = nullptr; + const T* past_value = nullptr; + int* seqlens_k = nullptr; + const T* cos_cache = nullptr; + const T* sin_cache = nullptr; + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; + int* seqlens_k_total = nullptr; + // Memory Efficient buffers + T* fmha_buffer = nullptr; + T* k = nullptr; + T* v = nullptr; + // Output Tensors + T* output = nullptr; + T* present_key = nullptr; + T* present_value = nullptr; + // Kernel Flags + bool use_flash_attention = false; + bool use_memory_efficient_attention = false; +}; + +template +OrtStatusPtr QkvToContext( +// const cudaDeviceProp& device_prop, +// cublasHandle_t& cublas, // TODO: cublas is not used at all + cudaStream_t cuda_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +} // namespace cuda +} // namespace contrib diff --git a/operators/contrib/cuda/utils.cuh b/operators/contrib/cuda/utils.cuh index a40bf8f39..6376952d8 100644 --- a/operators/contrib/cuda/utils.cuh +++ b/operators/contrib/cuda/utils.cuh @@ -192,3 +192,8 @@ __device__ __inline__ half2 _Tanh(half2 a) { template <> __device__ __inline__ BFloat16 _Tanh(BFloat16 a) { return tanhf(static_cast(a)); } + +inline OrtStatusPtr CudaCall(cudaError_t cuda_error) { + if (cuda_error == cudaSuccess) return nullptr; + return OrtW::API::CreateStatus(ORT_FAIL, MakeString("cuda error:", (int)cuda_error).c_str()); +} \ No newline at end of file diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index d868fe675..eafbda328 100644 --- 
a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -1,12 +1,602 @@ import unittest import numpy as np from numpy.testing import assert_almost_equal -from onnx import helper, onnx_pb as onnx_proto +from onnx import helper, onnx_pb as onnx_proto, TensorProto from onnxruntime_extensions import make_onnx_model from onnxruntime_extensions import get_library_path as _get_library_path import onnxruntime as _ort +import math +import os +import platform +import random +import torch +from einops import rearrange, repeat +from onnxruntime import InferenceSession, OrtValue, SessionOptions + +torch.manual_seed(0) +class Formats: + BSNH = 0 + BNSH = 1 + + +class Config: + batch_size = 0 + sequence_length = 0 + kv_sequence_length = 0 + past_sequence_length = 0 + num_heads = 0 + kv_num_heads = 0 + head_size = 0 + + def __init__(self, b, s, s2, sp, n, n2, h): + self.batch_size = b + self.sequence_length = s + self.kv_sequence_length = s2 + self.past_sequence_length = sp + self.num_heads = n + self.kv_num_heads = n2 + self.head_size = h + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + 
dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + upcast=True, + reorder_ops=False, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + if causal: + window_size = (window_size[0], 0) + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + attention = torch.softmax(scores, dim=-1) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if 
window_size[0] >= 0 or window_size[1] >= 0: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + +def create_group_query_attention_graph_past( + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, +): + past_kv_seqlen = config.kv_sequence_length + present_kv_seqlen = ( + config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length + ) + #pdb.set_trace() + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key" if not packed else "", + "value" if not packed else "", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", +# "cos_cache" if rotary else "", +# "sin_cache" if rotary else "", + ], + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, + # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, + # kv_share_buffer=1 if share_buffer else 0, + domain="ai.onnx.contrib", + ), + ] + + graph_input = [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + 
(config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), + ], + ), + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "seqlens_k", + TensorProto.INT32, + [config.batch_size], + ), + helper.make_tensor_value_info( + "total_sequence_length", + TensorProto.INT32, + [1], + ), + ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, 
+ [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, + config.head_size, + ], + ), + ] + + graph = helper.make_graph( + nodes, + "GroupQueryAttention_Graph", + graph_input, + graph_output, + ) + + model = make_onnx_model(graph) + return model + +def gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + past_kv_format=Formats.BSNH, + share_buffer=True, + window_size=-1, + rotary_interleaved=False, +): + onnx_model = create_group_query_attention_graph_past( + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, + ) + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + past_k = k.clone() + past_v = v.clone() + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + if share_buffer: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), + "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(np.int32), + "total_sequence_length": torch.tensor([config.kv_sequence_length], dtype=torch.int32) + .detach() + .cpu() + .numpy(), + } + sess_options = SessionOptions() + sess_options.register_custom_ops_library(_get_library_path()) + ort_session = 
InferenceSession(onnx_model.SerializeToString(), sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_input( + "past_key", "cuda", 0, np.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + ) + io_binding.bind_input( + "past_value", + "cuda", + 0, + np.float16, + ort_inputs["past_value"].shape(), + ort_inputs["past_value"].data_ptr(), + ) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) + io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = np.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + else: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "past_key": past_k.detach().cpu().numpy(), + "past_value": past_v.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(np.int32), + "total_sequence_length": torch.tensor( + [config.kv_sequence_length + config.sequence_length], dtype=torch.int32 + ) + .detach() + .cpu() + .numpy(), + } + sess_options = SessionOptions() + 
sess_options.register_custom_ops_library(_get_library_path()) + ort_session = InferenceSession(onnx_model.SerializeToString(), sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) + io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = np.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + +def parity_check_gqa_past_no_buff( + config, + causal=False, + local=False, + past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, + rtol=1e-3, + atol=1e-3, +): + torch.manual_seed(69) + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else 
config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + window_size = (-1, -1) + left_window_size = -1 + if local: + left_window_size = random.randint(0, config.kv_sequence_length) + window_size = (left_window_size, 0) + elif causal: + left_window_size = -1 + window_size = (-1, 0) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + k_cache_ref = torch.cat((k_cache_ref, new_k), 1) + v_cache_ref = torch.cat((v_cache_ref, new_v), 1) + # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) + cache_seqlens = torch.randint( + 0, + config.kv_sequence_length, + (config.batch_size,), + dtype=torch.int32, + device="cuda", + ) + cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length + + cos, sin = None, None + q_ro, k_ro = q, new_k + + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length + ) + 
k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length + out_ref, _ = attention_ref( + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + ) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + # Compare results + print( + "NO buff", + " packed:", + packed, + " causal:", + causal, + " local:", + local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + np.mean(np.abs(out - out_ref)), + np.allclose( + out, + out_ref, + 
rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) class TestCudaOps(unittest.TestCase): @staticmethod @@ -116,6 +706,82 @@ def test_cuda_fastgelu_f16(self): else: print ('CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.') + @staticmethod + def _create_GroupQueryAttention_test_model(domain='ai.onnx.contrib'): + nodes = [ + helper.make_node( + 'GroupQueryAttention', + #['query', 'key', 'value', 'past_key', 'past_value', 'seqlens_k', 'total_seqlen', 'cos_cache', 'sin_cache'], + ['query', 'key', 'value', 'past_key', 'past_value', 'seqlens_k', 'total_seqlen'], + ['attn_out', 'present_key', 'present_value'], + #domain=domain, num_heads=32, kv_num_heads=32, scale=0.0, local_window_size=-1, do_rotary=0, rotary_interleaved=0) + domain=domain, num_heads=32, kv_num_heads=32) + ] + + query = helper.make_tensor_value_info( + 'query', onnx_proto.TensorProto.FLOAT16, [5,1,512]) + key = helper.make_tensor_value_info( + 'key', onnx_proto.TensorProto.FLOAT16, [5,1,512]) + value = helper.make_tensor_value_info( + 'value', onnx_proto.TensorProto.FLOAT16, [5,1,512]) + past_key = helper.make_tensor_value_info( + 'past_key', onnx_proto.TensorProto.FLOAT16, [5,32,128,16]) + past_value = helper.make_tensor_value_info( + 'past_value', onnx_proto.TensorProto.FLOAT16, [5,32,128,16]) + seqlens_k = helper.make_tensor_value_info( + 'seqlens_k', onnx_proto.TensorProto.INT32, [5]) + total_seqlen = helper.make_tensor_value_info( + 'total_seqlen', onnx_proto.TensorProto.INT32, [1]) +# cos_cache = helper.make_tensor_value_info( +# 'cos_cache', onnx_proto.TensorProto.FLOAT, []) +# sin_cache = helper.make_tensor_value_info( +# 'sin_cache', onnx_proto.TensorProto.FLOAT, []) + attn_out = helper.make_tensor_value_info( + 'attn_out', onnx_proto.TensorProto.FLOAT16, [5,1,512]) + present_key = helper.make_tensor_value_info( + 'present_key', onnx_proto.TensorProto.FLOAT16, [5,32,129,16]) + present_value = helper.make_tensor_value_info( + 'present_value', onnx_proto.TensorProto.FLOAT16, 
[5,32,129,16]) + + graph = helper.make_graph(nodes, 'testgqa', + #[query, key, value, past_key, past_value, seqlens_k, total_seqlen, cos_cache, sin_cache], + [query, key, value, past_key, past_value, seqlens_k, total_seqlen], + [attn_out, present_key, present_value]) + model = make_onnx_model(graph) + return model + + def test_cuda_GroupQueryAttention(self): + so = _ort.SessionOptions() + so.register_custom_ops_library(_get_library_path()) + onnx_model = self._create_GroupQueryAttention_test_model() + #self.assertIn('op_type: "NegPos"', str(onnx_model)) + sess = _ort.InferenceSession(onnx_model.SerializeToString(), + so, + providers=['CUDAExecutionProvider']) + query = np.random.randn(5,1,512).astype(np.float16) + key = np.random.randn(5,1,512).astype(np.float16) + value = np.random.randn(5,1,512).astype(np.float16) + past_key = np.random.randn(5,32,128,16).astype(np.float16) + past_value = np.random.randn(5,32,128,16).astype(np.float16) + seqlens_k = np.array([128, 87, 0, 22, 125]).astype(np.int32) + total_seqlen = np.array([129]).astype(np.int32) + y = sess.run(None, {'query':query, 'key':key, 'value':value, 'past_key':past_key, 'past_value':past_value, 'seqlens_k':seqlens_k, 'total_seqlen':total_seqlen}) + + def test_cuda_GroupQueryAttention_iobinding(self): + random.seed(69) + for b in [5]: + for s, s2 in [(1,128)]: + for n, n2 in [(32, 32)]: + for h in [16]: + for past_kv_format in [Formats.BNSH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past_no_buff( + config, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) if __name__ == "__main__": unittest.main() From a7eeca40e36e8c4bbc9891025ae51195561f9b0b Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 1 May 2024 23:30:29 +0000 Subject: [PATCH 2/5] fix build break --- operators/contrib/contrib.cc | 6 +++++- operators/contrib/cuda/group_query_attention.h | 10 ++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git 
a/operators/contrib/contrib.cc b/operators/contrib/contrib.cc index 3efb99a51..f4b58493e 100644 --- a/operators/contrib/contrib.cc +++ b/operators/contrib/contrib.cc @@ -5,8 +5,10 @@ #ifdef USE_CUDA #include "cuda/fast_gelu.h" +#if ORT_API_VERSION >= 18 #include "cuda/group_query_attention.h" #endif +#endif FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { static OrtOpLoader op_loader( @@ -14,9 +16,11 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { #ifdef USE_CUDA , CustomCudaStructV2("FastGelu", contrib::FastGelu), -#if ORT_API_VERSION >= 16 +#if ORT_API_VERSION >= 18 CustomCudaStructV2("GroupQueryAttention", contrib::GroupQueryAttention), CustomCudaStructV2("GroupQueryAttention", contrib::GroupQueryAttention), +#endif +#if ORT_API_VERSION >= 16 CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("FastGelu", contrib::FastGelu) #endif diff --git a/operators/contrib/cuda/group_query_attention.h b/operators/contrib/cuda/group_query_attention.h index 69206e105..edf26b82a 100644 --- a/operators/contrib/cuda/group_query_attention.h +++ b/operators/contrib/cuda/group_query_attention.h @@ -278,14 +278,12 @@ struct GroupQueryAttention { disable_memory_efficient_attention_ = true; #endif -#if ORT_API_VERSION >= 18 if (!disable_flash_attention_) { OrtAllocator* allocator = nullptr; ORTX_RETURN_IF_ERROR(api.KernelInfoGetAllocator(&info, OrtMemType::OrtMemTypeDefault, &allocator)); allocator_ = UniquePtrWithDeletor{allocator, [&api](OrtAllocator* p){api.ReleaseAllocator(p);}}; zeros_ = GetScratchBuffer(allocator_->Alloc(allocator_.get(), kZerosCount), allocator_.get()); } -#endif return nullptr; } @@ -331,19 +329,17 @@ struct GroupQueryAttention { softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; } -#if ORT_API_VERSION >= 18 void* softmax_lse_p = nullptr; - ORTX_RETURN_IF_ERROR(OrtW::API::KernelContextGetScratchBuffer(kernel_context, mem_info, softmax_lse_bytes, 
&soft_lse_p)); + ORTX_RETURN_IF_ERROR(OrtW::API::KernelContextGetScratchBuffer(kernel_context, mem_info, softmax_lse_bytes, &softmax_lse_p)); auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_p, allocator_.get()); void* softmax_lse_accum_p = nullptr; - ORTX_RETURN_IF_ERROR(OrtW::API::KernelContextGetScratchBuffer(kernel_context, mem_info, softmax_lse_accum_bytes, &soft_lse_accum_p)); + ORTX_RETURN_IF_ERROR(OrtW::API::KernelContextGetScratchBuffer(kernel_context, mem_info, softmax_lse_accum_bytes, &softmax_lse_accum_p)); auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_p, allocator_.get()); void* out_accum_p = nullptr; ORTX_RETURN_IF_ERROR(OrtW::API::KernelContextGetScratchBuffer(kernel_context, mem_info, out_accum_bytes, &out_accum_p)); auto out_accum_buffer = GetScratchBuffer(out_accum_p, allocator_.get()); -#endif #else constexpr bool use_flash_attention = false; UniquePtrWithDeletor softmax_lse_buffer = nullptr; @@ -374,7 +370,6 @@ struct GroupQueryAttention { if (use_memory_efficient_attention && cuda::MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); } -#if ORT_API_VERSION >= 18 void* k_p = nullptr; ORTX_RETURN_IF_ERROR(OrtW::API::KernelContextGetScratchBuffer(kernel_context, mem_info, kv_buffer_bytes, &k_p)); auto k_buffer = GetScratchBuffer(k_p, allocator_.get()); @@ -386,7 +381,6 @@ struct GroupQueryAttention { void* fmha_p = nullptr; ORTX_RETURN_IF_ERROR(OrtW::API::KernelContextGetScratchBuffer(kernel_context, mem_info, fmha_buffer_bytes, &fmha_p)); auto fmha_buffer = GetScratchBuffer(fmha_p, allocator_.get()); -#endif #else constexpr bool use_memory_efficient_attention = false; UniquePtrWithDeletor k_buffer = nullptr; From ce44a244ea8d317853372c1dc854200d70b8e8e5 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Thu, 2 May 2024 20:09:53 +0000 Subject: 
[PATCH 3/5] add ReleaseMayInplace support and GQA is working under UT --- includes/custom_op_lite.h | 1 + includes/onnxruntime_customop.hpp | 11 +++++++++++ operators/contrib/cuda/group_query_attention.h | 5 +++++ 3 files changed, 17 insertions(+) diff --git a/includes/custom_op_lite.h b/includes/custom_op_lite.h index 0e2caa3cb..2282b7dd6 100644 --- a/includes/custom_op_lite.h +++ b/includes/custom_op_lite.h @@ -1038,6 +1038,7 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t { return 0; }; + OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {}; #endif } diff --git a/includes/onnxruntime_customop.hpp b/includes/onnxruntime_customop.hpp index a0d965e47..d8f5d7c70 100644 --- a/includes/onnxruntime_customop.hpp +++ b/includes/onnxruntime_customop.hpp @@ -74,6 +74,12 @@ struct CustomOp_defined_getMayInplace : std::false_type {}; template struct CustomOp_defined_getMayInplace> : std::true_type {}; +template +struct CustomOp_defined_releaseMayInplace : std::false_type {}; + +template +struct CustomOp_defined_releaseMayInplace> : std::true_type {}; + template struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { using ComputeFunction = decltype(&CustomOpKernel::Compute); @@ -158,6 +164,11 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { return CustomOpKernel::GetMayInplace(input_index, output_index); }; } + if constexpr (CustomOp_defined_releaseMayInplace::value) { + OrtCustomOp::ReleaseMayInplace = [](int* input_index, int* output_index) -> void { + CustomOpKernel::ReleaseMayInplace(input_index, output_index); + }; + } #endif OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, diff --git a/operators/contrib/cuda/group_query_attention.h b/operators/contrib/cuda/group_query_attention.h index edf26b82a..3bf5abeaf 100644 --- a/operators/contrib/cuda/group_query_attention.h +++ b/operators/contrib/cuda/group_query_attention.h @@ -253,6 +253,11 @@ struct GroupQueryAttention { return 2; } + static 
void ReleaseMayInplace(int* input_index, int* output_index) { + free(input_index); + free(output_index); + } + OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { int64_t num_heads = 0, kv_num_heads = 0; ORTX_RETURN_IF_ERROR(OrtW::GetOpAttribute(info, "num_heads", num_heads)); From 9cc74a34f6defbed4bc576bf48990a6e919ac294 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Tue, 21 May 2024 00:53:52 +0000 Subject: [PATCH 4/5] fix UniquePtr order to avoid crash, start UT to validate paged attention --- .pyproject/cmdclass.py | 7 +++- .../contrib/cuda/group_query_attention.h | 8 +++- test/cuda/test_cudaops.py | 40 +++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/.pyproject/cmdclass.py b/.pyproject/cmdclass.py index 04a69120e..be4e253ff 100644 --- a/.pyproject/cmdclass.py +++ b/.pyproject/cmdclass.py @@ -147,6 +147,7 @@ def initialize_options(self): self.no_opencv = None self.cc_debug = None self.cuda_archs = None + self.ort_pkg_dir = None def _parse_options(self, options): for segment in options.split(','): @@ -188,7 +189,8 @@ def build_cmake(self, extension): ext_fullpath = pathlib.Path( self.get_ext_fullpath(extension.name)).absolute() - config = 'RelWithDebInfo' if self.debug else 'Release' +# config = 'RelWithDebInfo' if self.debug else 'Release' + config = 'Debug' if self.debug else 'Release' cmake_args = [ '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(ext_fullpath.parent.absolute()), @@ -198,6 +200,9 @@ def build_cmake(self, extension): '-DCMAKE_BUILD_TYPE=' + config ] + if self.ort_pkg_dir: + cmake_args += ['-DONNXRUNTIME_PKG_DIR=' + self.ort_pkg_dir] + if self.no_opencv: # Disabling openCV can drastically reduce the build time. 
cmake_args += [ diff --git a/operators/contrib/cuda/group_query_attention.h b/operators/contrib/cuda/group_query_attention.h index 3bf5abeaf..24d03dbd6 100644 --- a/operators/contrib/cuda/group_query_attention.h +++ b/operators/contrib/cuda/group_query_attention.h @@ -15,6 +15,12 @@ #include "cutlass_fmha/memory_efficient_attention.h" #endif +/* + * Usage: + * pip3 install . --config-settings "ortx-user-option=use-cuda,cc_debug,ort_pkg_dir=/home/leca/ort_pkg_19" + * python3 test_cudaops.py TestCudaOps.test_cuda_GroupQueryAttention + */ + namespace contrib { template @@ -481,8 +487,8 @@ struct GroupQueryAttention { bool disable_flash_attention_; bool disable_memory_efficient_attention_; static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) - UniquePtrWithDeletor zeros_; UniquePtrWithDeletor allocator_; // TODO(leca): Does the release order of allocator_ and zeros_ matter? + UniquePtrWithDeletor zeros_; }; } // namespace contrib \ No newline at end of file diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index eafbda328..910a57bd0 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -783,5 +783,45 @@ def test_cuda_GroupQueryAttention_iobinding(self): atol=1e-3, ) + @staticmethod + def _create_GroupQueryAttention_test_model_validate_PA(domain='ai.onnx.contrib'): + nodes = [ + helper.make_node( + 'GroupQueryAttention', + ['query', 'key', 'value', 'seqlens_k', 'total_seqlen'], + ['attn_out'], + domain=domain, num_heads=32, kv_num_heads=32) + ] + + query = helper.make_tensor_value_info( + 'query', onnx_proto.TensorProto.FLOAT16, [5,34,512]) + key = helper.make_tensor_value_info( + 'key', onnx_proto.TensorProto.FLOAT16, [5,34,512]) + value = helper.make_tensor_value_info( + 'value', onnx_proto.TensorProto.FLOAT16, [5,34,512]) + seqlens_k = helper.make_tensor_value_info( + 'seqlens_k', onnx_proto.TensorProto.INT32, [5]) + total_seqlen = 
helper.make_tensor_value_info( + 'total_seqlen', onnx_proto.TensorProto.INT32, [1]) + attn_out = helper.make_tensor_value_info( + 'attn_out', onnx_proto.TensorProto.FLOAT16, [5,34,512]) + + graph = helper.make_graph(nodes, 'testgqa', + [query, key, value, seqlens_k, total_seqlen], + [attn_out]) + model = make_onnx_model(graph) + return model + + def test_cuda_GroupQueryAttention_validate_PagedAttention(self): + query = np.load('query.npy') + key = np.load('key.npy') + value = np.load('value.npy') + query_batch = np.random.randn(5, 34, 512).astype(np.float16) + key_batch = np.random.randn(5, 34, 512).astype(np.float16) + value_batch = np.random.randn(5, 34, 512).astype(np.float16) + #query_batch[0, 0:] + seqlens_k = np.array([5, 12, 16, 20, 34]).astype(np.int32) + total_seqlen = np.array([87]).astype(np.int32) + if __name__ == "__main__": unittest.main() From 81c33a481f73d2c09f43bf805086d607fffda16a Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 22 May 2024 00:03:34 +0000 Subject: [PATCH 5/5] hack GQA to change the order of parameters to make GQA run without past_key nor past_value. 
Runtime CUDA error when copying result from GPU to CPU --- .../contrib/cuda/group_query_attention.h | 17 ++++++---- test/cuda/key.npy | Bin 0 -> 89216 bytes test/cuda/query.npy | Bin 0 -> 89216 bytes test/cuda/test_cudaops.py | 30 +++++++++++++++++- test/cuda/value.npy | Bin 0 -> 89216 bytes 5 files changed, 40 insertions(+), 7 deletions(-) create mode 100644 test/cuda/key.npy create mode 100644 test/cuda/query.npy create mode 100644 test/cuda/value.npy diff --git a/operators/contrib/cuda/group_query_attention.h b/operators/contrib/cuda/group_query_attention.h index 24d03dbd6..8979e87ab 100644 --- a/operators/contrib/cuda/group_query_attention.h +++ b/operators/contrib/cuda/group_query_attention.h @@ -244,15 +244,18 @@ OrtStatusPtr CheckInputs(const Ort::Custom::Tensor& query, template struct GroupQueryAttention { static OrtMemType GetInputMemoryType(size_t input_index) { - if (input_index == 6) return OrtMemType::OrtMemTypeCPUInput; +// if (input_index == 6) return OrtMemType::OrtMemTypeCPUInput; // total_seqlen + if (input_index == 4) return OrtMemType::OrtMemTypeCPUInput; // total_seqlen return OrtMemType::OrtMemTypeDefault; } - static size_t GetMayInplace(int** input_index, int** output_index) { + static size_t GetMayInplace(int** input_index, int** output_index) { // past_key <=> key, past_value <=> value size_t ret = 2; *input_index = static_cast(malloc(ret * sizeof(int))); - (*input_index)[0] = 3; - (*input_index)[1] = 4; +// (*input_index)[0] = 3; +// (*input_index)[1] = 4; + (*input_index)[0] = 5; + (*input_index)[1] = 6; *output_index = static_cast(malloc(ret * sizeof(int))); (*output_index)[0] = 1; (*output_index)[1] = 2; @@ -299,8 +302,10 @@ struct GroupQueryAttention { } OrtStatusPtr Compute(OrtKernelContext* kernel_context, const Ort::Custom::CudaContext& ctx, const ortc::Tensor& query, std::optional*> key, - std::optional*> value, std::optional*> past_key, std::optional*> past_value, - const ortc::Tensor& seqlens_k, const ortc::Tensor& total_seqlen, 
std::optional*> cos_cache, +// std::optional*> value, std::optional*> past_key, std::optional*> past_value, +// const ortc::Tensor& seqlens_k, const ortc::Tensor& total_seqlen, std::optional*> cos_cache, + std::optional*> value, const ortc::Tensor& seqlens_k, const ortc::Tensor& total_seqlen, + std::optional*> past_key, std::optional*> past_value, std::optional*> cos_cache, std::optional*> sin_cache, ortc::Tensor& attn_out, std::optional*> present_key, std::optional*> present_value) const { GroupQueryAttentionParameters parameters; ORTX_RETURN_IF_ERROR(CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, ¶meters, num_heads_, kv_num_heads_, diff --git a/test/cuda/key.npy b/test/cuda/key.npy new file mode 100644 index 0000000000000000000000000000000000000000..a9ec93b04abd90a19413687499949a2a8180177f GIT binary patch literal 89216 zcmbT7WnV`NFTtO@K z$9ktYr&|yfaNXUn!Um&+CjIL(Aq=wY+>ilUtqwuiP6 zq~!doeFyd6K6E^m!P!J*_+QY;*Vc*v4IKA2M$OUV;B$Bn{|YaZd=PZyB5kDA(RYkF zm0g+cuJu!w(#?q0L52;ejf`=0I-JdSS7ES?MwtUKiTA+q{B~bWlZP7qLG(8N#$!;u zgyqISkK`D@?qRCHfbo!hW302hkq5;r*bloM$N33j4LsfG;DlUh3k3VHL&-mIZ>1ra z$VcH6SV>-^1?m>h`;5974A+J%VBE^{>;qsZYGIW(Ml(|zGgeC+iOS(iatWxB-}tiX zZQ-=B&+Npv1%HC(YIV?Qe^Txgbc{O1k4BGU1~t0J{^m+pU*QsEY+eVo1FG+jO8Tf> z5=U4QQiF!*uNZqMwioE7tTEs4SxI{WpNT@!PjVYb*TcBoQX_U6c@`byGcB6Q2!7=& z(5)lpkV#}kV>Ek+xZ!>dR|(C)TK%UdHP{@sxBmoRxoV6AM;`7WjAVvVQt30=!&3khNbZ19wGAf%0e*EQ$7OndBJkCla_0>`zAz7~yJ`P@wkE{myialWmajj7@|GqFNp!jENubrpVFf#**@VE30hBKS>Wx*xK9so70BtqX?z1Vn*T*k z;GfWo;X1ks-QH?$^&>6jCHK-iYR^|bX`kc~WCL^>C&}N*_tIcM9 zK1w|l4xWPEz+Nl1%ztDHJ zKnS?XdjFsXQU}#IbUV-h9VEjn!n6_^(dD&K=#jSA*az>qV!1WU40A1U4SSifiM_>& zuoQe0@7K>+oM!TCnU&@c!i5KCy?_mqCgv0B0Bt`SsLyiK-~oTf>lR3W`%U@^`KXw}DD?|7M)QnM}N=ESbdAamQx=!=>8#*lR-qH}Nf!ef&<8%zb7PtN@A; zTe?T6>GD0Muh&}g_7 z6+vFK9?!P2k9^}#l9kk0u0|j{Z%$5EW+q=fci!HQs55NL^vhX~N)e~k%i3sMM;wAK z!7O!t@K0)A!~0Z;sYk|r}Rr4ME)PtO1_1~5&)XaCpW6i1T`9^As~ znE-BTd{u8qZS@ISfygQCktW3HI=6v2BzIw+gEr=*Mfu9e5W)AcGaPIY2%Bsz#6NDLa 
zu(`kyJeh371p8Xw`Q`E(4Z$PM_BmZ(O{z4%4Q_>t8rG$M1%A5_E|8Z{6HL~sObn(9 z>y_%U${Er=Xh*Zl&1tU(`?P(WFMj*+CV=;MymJuh2|%5mc5~{+QX1d z4!+H?S>udthEq-BW|$?3?3^THm5`>7Q5_&pA82)z%J@>$?x14UcJlCgk)&W0+NP za}Apna>B`9xE#7IR;SMgA9`-^T)N2X1;XQtVvcNsweFmN(Xnu_@f$Nl!{+$`Ct`GGatUbdIIIR3suM&Sr z-_UI9i+TX=1dY|hOm}gWFbmOG%wA3v79Ru`@l)IzatQ6anyp?i>iZjzvGAKT!|Lsh zWzi8TI}Hv*@tOA%Pby91jln4963XF6L?$v6vmzLXUz^j59_P)^8JQR+&4ZT%arRWb z8|bKy=F8xxkr6;=k^S`qG$H$`kuHq5* z*(iCc&^GFCHAP(?70wPP9L9b975j=SshnUusEcMs*m5Q*5`QPTv$D^52G!FJv+)+` zSs;7CR;Hm)QhlcGB~~4pjWoPlkM~Sjl|4ZwudV^K;=J$4>uQ z*vLi3weWn`e$&l@!{lzDMPRg52VU?D@dLd*xQdJF3)wBo-ki&HDa6K<()K#eGTT6< z;4XF^$RY-)E2Hn~e}fA2zj`}e7I!i~0cQ5AB)TQPVg@mc@c?EWX=4{*tKbTL0&&); z!yWM#5a)&N;18~{0QIBx@8(7I9Qba!4rDmad%F;{=(2{`3**C}62XWKwC=uK?|UT{ z#2S~}b^O<*FxwBNBY5R$#neZ~`G!Qg_K|9UC*^OHOPSq_@meJ322t9ajECZ8pW85q z|K)5Z&r0)s6ys*jz-h1)*?@RV)um2SUzrP_uydBwM4qkHa*yC@WDhd><}7j^2pogI z9UR%rx`*Y^YU^t64WII-Ts|o!jS2mW38)epstr=k=0q5!s6xRW#w@a=z9Qqad0Z|f zK9ibiFVg-JZlNeRLtn&y#{Y<2jZF2dHQYL;*FijR%Z1@DEjF;7FR4E^7`RzzqCS9F z8_y3Cbs-Pb0k5cck)PR_@>MF@oXP(nONo6PH`50xeL)gTPa zS}tcczfoDsH;|7TT@x1yw}n&5AW^YTA>QK~iTktHfI9@sA0xf^v3C>Y$h%;=lP-9B znEUy;Np1bS)YP%wzg$V=NAYuoRD4oeqHPS+6OW;0C^o1Vo$)?Tcd0Qu4vt2()e@c# z0##_sqSVyEs{BfD_IrUmF!7#tgtNcyeos8 zhF^jXTtIB5K2WYH8`Za-Z=%Vk_+Nll5nlN?COnHl`{>m~%_Sfk05xxc*X( zM#a|;)!uqwPpBeIWwW_tc#!`sE#(zyBQum9sx1w(5e<#hoIG=2#wg*5r^f!~QnY$L zYa3dN9>I>oXbuT=Qbj6vic?O zy1R&Ul^JeK_qBzW9Qkm7V+;z{^HFnLg5S*_)IZ5n)y~{$*f!`edYDVeFQg$=Fs?`$ z+B4WS_a|Ewj*~0M?bPDnWo#}kuAe8TX!$;u;iK(LOzxTdyu=5Y^^Fnw1yGSWZ}h>> zw7qh5oFLXFRQ7w~0NqV5g)Q{egv(aMo9c20TXFB{G2AIbgb&iY7_1%(*5&keHRcZ6 zSE~!KNVvq7d}Y)WEaeyRky`&yCM_fNL(_6kxvOi7xCJJX+j2AoaB29=Z{+l_C;7r~ zwp0`yBcG}-aS29{S2FooGnhS67++W|Pgms_&P4?1PHM6!hI>fHcrL$HO|HIiHu(%U zP75QtI@72U*$4FiJK1<{o2i~7gXSNO-Si*S&)^`nvl+v5fa8?r;{V8jXgSdmF?B-pLbnTn#Ev{gjhr3ffJ>5|QwJEHn zNuiFKppJ6Q5~u6Y>TrFIW3|x9eD7+;bu%iWT87PeRlK3kM?*aO)SKo|j_}MUW6jCd zD3bgU(0+ff~^7SV&*;@aX*1JKE8qBU9zC@tmk++*Kh 
zWrk>Xo^f?hP9yH%g~&E&x^|xvm7>VhCV<&?PZpo>OfLZ~;xsYQz6qVp-vaKV1!(Yr zf6Yt6H8W0Z!3+kmCRXdIf5DsdQF|rpqj5pq&#{Is&l0OzRrnd8nOqSq)XKOmeVpsQ z-r1~gB;$EPvgJo%W_czz`!!LSzeF`cnXr!<4M_a}o8mo%+WANMW`P2ApYsTxWl%Cf zUv#x}Omqyf7O|A_(tg3L!Zjl5!#Vzj^Z@5K)q|6bo!m}RM6=9(%s{mz*Hu{)XEU4Y z{{vUdRyZNnVha)%!gbtVa6g=?U-I|hir3!_`l92RzmW-*NcQrjzYooR=ggH6T&Zk1L8*7W~NV~-5)3bD}x z_@2f&Y|vJdZmGwUzo%;XeW!>ZY%9 zSOh_vU=`Kbo#>B{-s{cS9om0te`!DRsiWxG&S^1}mZzCV`NFP3DR3?BoHvfzrk930wSms}${b=Z8kI3KzK6Pr zTfkp}``||1qe=8X`VqA}UsYWMikJrD^p}B8@jmKr-)>Yq>o`{m@ASpuamqt&gj(9F z$zM@t_-gnU8|!k4!-cHr{4Aw{k+vlJ@}w)F)xkIlyLzkB%#Q3}!Vda!t(|QW-urbt zL4RyN1Aoe6`JHK&k(>2eDvCdl54F01&!H#V0TlCZB+Ig=m1($-?)2}_314KOverbp zYIlHx#!w-bQ(_M@k6=B%p?j?s#?+Nwnh|-CMQ1ccLgIM61=-P?ii#Mdn(5zeTOprEExeadk}up`iZYam<}b-^wlofs zkJVf>PfH3MR~+v8R-WStm%uIL-Jv|w9#3-|biAZH>s9b`K*RDv3;lz>+O$Uv5c~3z zot>VW1MI0kiD$)GORm@`|=lTZ>+~|8|$<2GeKU z(~R20x-8^9CM(v5F(@{0g za$FA{FzzZlsX^k1I1CJ_nLZrQyhx-)%yQon(hseZI(SDzIjUaT|Ck(RCz}9I!;a=# zuvG6!9Vc__FJw1v%6?P)WC4wQMdh0|g`4JG>UWETJ)Knj>nx9c+A%k$1=BqDaCSX* z8fs=7g_jf~H51P9Tu5Dv-lM{KJD85;-ZQu=T9|FP4$>37XS2r})%5x@1Iv(anab`Y4vAj~BBnQU z+C1UvZ^r34)Dev_E~@1drn4i>ddxBN6AlR&)>4ps=nMCeE$S*{Rmp4(M&O>omrM~r%kRNSBHFATJZ0n)&AE;K z9K9wu%I)=Dmp(gQv4>H>^q`>HTUpQc7USjZ;t;E&euJEe(&%Av0|Zvh)@FIHpnF;+ z^@UZbWz>%DfLz{I+9$15C=HYkrqEYFFm{lAvNukh9_j;2;cH%2KEd@|yfRW)~#|oR@xBfarmW z>R07Vay*)3tpT&)bCjn3*6*Px?k?KHH^BQ{+u%R=fZWMZ54BBN3Q2Bx>>N2vZKibO zwx(C3c%h$J|41Fzeai#FKo>)XkI^94S>m;vPe-%msjj#pc4!Br;yG@d1Pi<}*GY&5 z2LlCE`KW)0-XSU#cC3&4s?t(qPgx^f)DSPyN{CajDDBz>B2dlFg zg?vV=Hg-xy6hp8FX2@TaHmJbp1pmYfxhMgkk&>6z)NtH|7$YUidjDAb$H!% z1RJ=BqZDy8{V?AqxQ?kuRrF5}eCA7`83w1$*18GXVSVav<2L1_ed2kw_`cynSFMV> ztWqmDf@`e(2XCUr(iFa~x=KCA?q)^>jxb@2(_29Igf+;~+$uhi&UIcz$JM*!ACOW< zve)DTIUnFLvTQJeGm;AhxAKy82$cvwoRdkg3N)tOucg1IDM3qTUPkQqXsMstM@n5DMG?&AA7W+p{2WxYLw6WRfk!JR<= z5(Vz6`c|@Y+yOqpQAuxati|okBHVc2(p_`uM}pIz=ZIc0IYe3GT^GzTmOba3R}g*X z84%&wA}lkzaP75sK?z*~@zQBw51x{i2Zv_;EzP6v5qt3}{+GW1$J8tB{R3qtb@57O 
zk+h{o4R|8_JUt?Gci}=$&oL>dnSa@rS$S#V!UqqX_WUts~&J%pF${Xy`-}`3~*9^b-VFKD}*du8?*!S`7&t##kSuVSRxX@gy2XR%=U~cC*gr4SDe}}wJoFDEm>jbarcIq%{rc%Ni;)jta z<_EgdjS}0UYobefK>udm7zOq$Zi<#b&T3H8cZ7+QPvM2m-pMBtZH{!csq7C;RwA`o zN{zrj$`U;ljYS({4_fE43W@jh6XH;-n%LV(mHu_TYj1&eBS&Mv?Sr<7xxXPOyYzH@>c~ohl0arUK zi`#m1u)3Xl#E*Bk)Bg?hS2Fl&TnoyXQ=O^r85-i^yB+0-IE~dW)7Pm^S~cQYVln=J zInePIc*wueZnuZoNStD-;t#H=!cFo4w^As}Ek>WTw$1_UX!@-o>ethzI2-tuC691V zmMZG&V7jdlc|c7P!?}FvqZA_uxP*C0>!qRKIq9KqxcpfC>G zsJnaUu6XppT>}EP9i9fpq9ynV_WJL5Yltn>Io`W`1G6x$Esmn%sL!5O;(HL6`AfI zYQhex8tA6=jLG-ZKuzU2DAV=QsBDyWpGBkSMqoK3I#Vo-eeHj&%&JtEhdTiaI1dUCH$MVjCNu8r^o!Vqd|(58q}c z@@2SJX`_i;bPf%n=1~2>2S#CH@GqvZc-;6meJ>jWOK~ibKFV8j za_O5l27ObCu_MKNp&lFWZpIeC(+1ExgGIq?bqHMwuL#aZg{6x|Xhow&a2Ws4;N^b< zuVqB+!_DRQw3Bbh9YuHew{jxoBCA`X;b%?-N1AjjS4CSy%1u z7xPZnI|j};UT8PiaL@4U@@T^D?qmt=3Fk1bqZ#xmwxT?rN(qtbDZJfS3!^>VotMSp zxj0nO`gyL~m#eGUnfiiIC923LpsQ+Sc0_QWGD8}OIrTX69M)Ces?{@&NR5e~ey=q( zezw|${x?_||D;#5``n1hH;15z?6-ktAWp94%<*rAwV;i=LTyi~LLK9ODz8aW_yHah zEj_Ca7uN@~7rBPs2xcALj0E%q(ZVdK;Hvy}ZNI*ne5{uv&W5OWAGNjo1(9qi=eh`= z=C`8#dGKs*d3@VbA@4q&3!W2h>P2V@|3oX<@E<%GBk+EUTopiHTH3iGPDAI4Joz0=S(kRm##g0Xr6j&YpgFBb!aXc6_j^ z$6r;;5%2Z3UQG*dLlX`LnzE;jIPn0;X5#dR@F|+6PBSd@QDtI#>AswqOetF>_Yk3U zaDf({EO}ll$8)bs4=9TtA0Prl_TQtLWzU5#osD5*wzuQ9|3Cdyu%xH6FP47{%Af$d zPz|&4sKsOqIUSEtP7?3wLi&h=(cBAI!S+#VDLzQt?%d37R7cw9L4VR?y={JTcJ7h2 zoDdffFT)LtOJF=&X!cnDs zCwmDsX3y~Fx$@5TXek#bZg)F8?X;JCaiPQh;zYc?sh(>5HCR`PpQkTp1KspBqFfGXhC#CfR<8loMin#o6m+GbV$D45{UmBy;T zltYd6M|w~0G{XhM)yKXX)E08SKFfG&o{os6(}M0BRwX`Xz+qvr@h`L+zz!d+y}O2KF}5)I1G!ZrW8T+k-m#3%>Cw3vXuE* zED+bj^+aj7+1X5~h>AMS(|=Q;Bjat|HDy{2Q)9V#(kgK3-n~ zt{Oa@t8CX7qq;CUZzR>5`b6|IN*PP}ww_n^OZ;07G23ngZ<$%$2@JWb9PR_ zk-2y(%u*`zKj3Yp8B^DpW-gO!8}+m%Oc{4OykOTC{|8W#+hiU)QZAV4EI}*_)vsS6 zb>^+PQtyGD2J3?wXpN^on34F@Yq%?^GuYO8U85@B-_J7b_@A(zqdu3fXCHhZM3N8b z5_lRz1iLevJ(BObD!a=tXJQneDACfn#iY62skZwUKBe zSfFg7-_bF#HQg$$pyB*m_MXxf3}t$H>VesgcG3@Gij;#Od>TZge(BI|t~CD_eT}(pw^L&_w6Xm#XQ_L%7bq3p;TD2E>dL4ed@FM^**MR~ 
zCgnxc>`HxN#|Lxu8@M{!uO`_5Z~7&pS*RaP9oePOQd+A-TMW@y4}dRxL<`zyRaZy_#Ihnv7SuQtn^fR=z2d=_(s zE=OK5IvJ7VFRB#!jPLlisWXJiaUN?{{(Nx;Y|A7t&58d>8?7>^ztMrN!?#Ij=C7-^ zCQF;`*+zlF+NHcAd=GaCs)ZSqu$SMf4bWmd6!RZEqK{Q8NLAc6KGn)|UBGAU&kT+2 z!f)1ESut>k^B@=NZ6zv1e(^AX?4n>7V%* zFcz$_Bw>s@UTw&9&l}Fx^j?A6qR!bo;urR{Zv^-mG9^YEqdX4-O|)~)2O&c#OI#;R zH>Z$8i7R|5&Ln$+8EQ4^COqMMA-;AtM3an6I!)z5&SCwu?b;h-D5&dAX*^5)?S8E0 zXkRcCL}M;KhWl&baId_YbOck)jWB_pMp)iIZLQ_R1*-E%NvwKEok z<0pJY;+k|;;L#7)Eni{zPpSp8!@e@OL~i1lab$tqnW?DeDpUC9*$>eS^gp$=IwlbA zX(qPe?C^|zC$rDj*qZ`f^m;kNzD#+<-$YsV$>eEopI{dT@tsA=*3-L^LXiWrH}V-| zVfQMcBRdDxux5l1_A2tIHy5Yq`(%gfI%=!;Qh+C0k@q!9OxD7UYvftCqQ;O*;B{xM z5PfZ{zH~}#$c+o0pze@^&9TfLqFi<<^F46c{k~;JxG_gdSF5QdBX{c#dUnDSavV`f zKPe7lOKDwk6|FtoPK0P)&OR=cU2G?*k92|ggL%teB%8|Pa}4_;Z4K3+(KJ1axeub% z4#G6HEf?i@Wj+Iuq5PTWb(*c2(eB5b8e9lZurKZXNyg_z)#~-i-o!=&%i1W0c%xV~ zSin7Whr80;DdJ+-iu(<#CeK2H%x8(~LK$EfxRpGbeacr&uo<5NofF?15$=2g)#C_>Y(josw45FA*I)0<&6ME;aQ>k!^wsUD#?QoL3<+g&ib2a_**^RyDtr@vwrQmvR@+G(z6-Zhk9b(gBavV1>oIw#_HMkl?RuoG4Ao))Gg zj*44o+L>kcE?K?Qk_lyHgIS`L#9taV547fo`%1HI@l3eEcxH<7H2w%|DR0+f@&+Hc zC;mv^4=cM*@?L`{>MOt02hj($uSRjsCk<8Ck?WD2CxnjPHA-#kZ@RBummZIti65CZ z+$%FzUQe}vy-+>+{(+2`Cuos}LuIpON52g&cbDSpIXeXZ#@l2a zMJ;FIbFmsGc1KlUX)-$f5Sy#YhR1PEz8|Q{9^%r;{$2r;4bk@Yj`Bh~ZWwWl-)L{l z6n5_e&(z&&VMnAjK>w~>qfd}F^n+34zqv(5C8JPm88iaUW5d*K?ut!DYn`p_s2d&t ze-Yba#@ikSi!fExK8~j>t#x-S)?b4<;saCQPw@dx%mIOEL=I>MUgC#vd&rBK&%{fQ z)e_me$yum~@{xFiD)E#28}SFnDt$~iwiT>FsmEa#?DU`U+Zm@8e}z7AD%d{z?iC8cAC66%3uxrKSs^#$yKOTG}(12y5oV3J%553+_hU+;Uw-clFS+l)8F zf#51>ownCLL(Gr_aw@JX#x`(iWv!CZB8pVM=KOEhP$Dg%qqCS)NNXVOP^+N`ZgO%p zdJ7r|On!|qgJ?%y(DxZ#g!SYuxw*5pj3k^hohQMSL%Wy<$}>D#Z;O+~4)o!Wue?y` z?oSUCbywnF2p!Fdybn48TZI>DBX)93Y~Ul`o|g*qnYct=qY;SS!!P(PD zVxqx3vXu^~ew?KpYtlkXfwS4~o_q3vyvB)L&7GbiP+_C<&g+ZJ+S;o|PWrw0(k#dF zT1~`yYO3{3mon7WGPq2*7HCDR6iVo4g$!aD*B6#&%(&Zpx_SasW|yO_%s14YNXBKj zVe%Z}8$jr~JWbbZQG6hwbcpq5fPdAUcuZ1Nst47d`ycx`rzpVEUU55nJaa7B2L4b_ z(~pvVkbRs3wXsoYB#RF#{n1YKge#Fxb`~>kz%D{CcoLQJNT|M=B$S3};&4^qQ}WVr 
zF=>ao613#X1j1r^;SA$OYR5eb^2SYLtjLD!$lBa^X1wD7szsPkq%=ncbm{-v zUo+FC`_fb;R?mYNFXh^!e(V<|L^s$2d`iFti?~yb%WMZBs0HJj@rNDW+;`cnj8RHa zHU&R}G($6+{S9z;eYiS`FEG|oE9?h}m3VUQ6I+s!O1IHl1xMtzrA~7r@Fahg;9383 zb$#qPqqX9KbDSMfIKRYI22{~XunVc~$O0pcU&LpygS(eiIqNbQOgscV^)aqYqc`jb zGc%&ag}y#}xAWJWk6^abNPEdffG$coPg3m{4v}60vVy0ynr3@pDjGx{RTiV-?tMp^ zQP9nxrAKN8Zfb?No6HcYjkj0sZS!5AhBO~uz}2kIV*AE5GA1WX+5Ddz0o4r7y$I*S zFQ_N?OKYfJG=IZM=0a(d@pjL}v?6FNIAr8SrO?RSi$>7X1C>$VWDk2@vk5h=UDP$E z5n8DZa*pC*$f6jQd6ak|KVfUhYKX;cBp#NNmYAmxcNw>Lxps3~LB;uKyGn?WR04~o+|8*@(-vjTWRovHzI6TWkfS1H= za{`kpKKGu0M?;k-4ZTK|`doXd>tu#rJ9Cc}@ElEsp1FKYb{~93)<(HjGcca*>E0c@ z3m%wbU3={H&8tj5RD+fHVN|=2chbkRKG?vFL)Fzr!30mXj@5bADspgyDy;LN^s7nL z(d0T`!8z>`*+5m4;lPF)q2`Q0^Y(MWEx0Z>l{7phwIv2JQgY%!M)(n?bD+3+h2+fL z{C`4KPlPXw;Xs6XCfG)l85nwVi@HwZ-#n`|k1gT(g#Tla;|?`~ zE-6(}-dfeQ#q24MteupzL0wxU{Vl1#@y^VMPci<78%QO^p7=lfOx!EAgJa|txI4LD z|Lxt$bd<6}-uotQJF`NKz}xfz&O-WTJPOaSlEk5|8ZZi#GMX~Oi22svcnTaZ9j@CD zyR?4UeZ~LqY_n;wmR6J~%dV1!d-n%H`V$2*Jo|$iPbI*7MkWmMmh5(BfyyD4L>iT$N` zKT}XKTEM>2dZ<(KE}Q$rBB(t>r(Zi*QC%9j87)^o>Xo(d+%(V;CNU?2w>(?12KsT2 zUAc*c)F`*=tmAv*mjkckkkAU>HCyXhWX<3b_<`$=et<>hIr%m|4K64A78l^7c#vlu zkkEHfR`2EZ;YLwuR$rq#)Vv9J5mShGqqy)A+C{H(ncklCCc(x3a+L8jN7MNb(W3LH zufeZyHWiWhCsW?R3M5;Rb5&@fL+Nu~q$gGJ4{ zL<-+D7%*$lHSDMS5olnLt*X<#W;xaR| zX2Fr53NZ8cq|=Z%kRDv3ESVU2@`XY9BBL9p$R33+yN4X?~g= zD?L>g%hhaa*<>PGUnw?b;*y=_A$pZ}lxtOVytTnINSo+=WrQnZ*`Io7H_SfF8y%{k z!_+9dO?a#3k-for_HwY6?hHJ0CDG;VYqe_Vf7s~(p8SOK#B5(X?Y@Ziw`7)%}&-a%$7x)|kKUr{|%2j_+9&ZSgK^d{$`J`Cln#X~%@y!!=nKd>aSp#gAQ zR44r-`PB6q%_Kv|Pw1AgM@Xb->! 
zMjOstSC$hCvfOeZN0*p)yhA%}y2;@h${y{vj4$plnDn#^z9Sa6VsPPL6~kpd)Dfgr z3ApiK;V9Cyc6=d!U#*T7Z4AS^5_9=sQITMi`!e4UC838-DiG=E1j0lYtbj@oGv!`p zHFbiy6BRZxz%VqCxv5s>yOF2&r@$$&xUekmY-sYG#=KFc!B)a|!o{b9;s%i4X&anp zxR%mPeP!-iq>~R-uXmtj6LxFa^dYT9Mm{@8s6mfYTbu3mn5<`Tx3_ds^T=9kr6xPF zdZu0V{KWCF4*2UPRsE} z+(J?EE?TelaFyb@%$bh%Xuth$2X>0> z7Jo}MLg&VYF~4=Y)n1AuZtGd@nOuFXj*^xYDU5_9I0!z=#pA1){nTx^L5PXA@YFHW zGFo}k&_Ks2Ej)T!UIkDBCMUPlJFxwwHG$P38@pymNT1>835Ude)=AyUUZF#7qxLqd zCH=#nY_#<3#5S+O420pz1|yz{6b2AAd=$E&-jg@TtF%(F8b}dd^7WO-oD_90bJNu( zbY4LgjEoojPz%B3j2D&-Idu*A8ekavHYyUGR_)X~wHlPIFm`p|kBsB&0coxjA^!+a zRHjhGxRBa8xvvL?_WVN5En9l9Zk_~AxyOf2elEy8Ep9YwI;;>stDrX2*0F=c{mdJh zg5T6qAmCY`<_PgHmj8xQq`u~UTrRn!{6fGu$y*`ByFwgA`v=ikcjpo91wb0=XE z{6qh$1GA)Kp?3z1h3#y06CP7r(P;B8@uGc${$7d+7B%+ZU)ls}5L$%V^ZoOZefO1* zYUv=4UaC9vGMSf6HLm^J;-K<9-68GjK=og?6 zpPjon>$H2EYhlRho|ajIi{#gPXY;r0>s%GQQ>i3kp%Je!S}&y`a}C}X&IUff(ezQJ zRrUo1siT;7au0G^csL@=?Kn23m_J&`LKCg;@_OTM^%8Li{NQGiweZ5-f4~aNdwwL_ zO32uIRv!va*}J0?%t!z|+oc8>wR{Vxl7K*~j6F%+z!A9zD3W(sO>@mO^Qgc0f$lHh zomox4D3;Wh!PBwV^~Th!*hh{Gsdf5R^g+KR`ou|+K@U)S@e`Sap;-|}s-O{@3Kppo zBtN>A57d8*3+P1X)V3#Yo0#D)i|*P_uS2eT-?}Os}Ij#oN9#&-gut{7BXSx5A?sC^?CN}rb1osH}KebC@ z3iB{`m@|!wO3dd!knQOK(3NzaPLa2(vtgX(${(qhVv2|(^uDl=l*c~+6hB{XCqB=4 z#E^PFhRy5aYzHgrqIeS~==+X*MZ2j&;wvdjj`FNyii3=t4Jh0D!FM8ga%lEdGdK~w z@~lk#YCJ@f13hv(Gt0H3xmWo_G(kQ=ZR4L31)+qO2yH|yV!U%dNHHHt#XU*d7WC2> zCOi=bYKhc=K#u=GL{&J$PdOe5Cy2xLSA2aG2YQMtVI*;fsSRx0kYG(=SXMh^hkKfg z0e7N?QM2=YC;bbidYWZMhYOj_;9SbR1)2ckoW#>&Rx?Tb^=>AVsdv6|?-00P1|5 zWi@8B{GCqaAIixKF@KFSzy{7w0o1?xv`kQost9pLIW81jwvr;IR zNwG~Ed+JiD`T#aL=aBwfW9c$rEw+a8>TTfi#8zM!+$8f5leZ z9t>8By|L2#u*{wM@+N<%NJ~TQyK!IH9vo&*&_l!`KE^v1<(hYx z%eJ!o-W`{iVcEc$Ltil!%w%mNnV@|IPWV_^=1N9#@R790Jzji{1nyGAVE(4%-{3F$ zndiCu6qMyYL5(ZQKE&o3kBuer2Es~|1dQhdoF{H%mLz1t@oGP7ujEtYHfL4qPzz}c z*9zIsV4j*T-+&**8rXhMcowc)Q4xyJE*U-WiOz7P4YwIx5K^UN)X1aRkL2e#`y_nL zZXwnswQqfF#x#@A%-7Sv~dF`Kn3^cZx7U`z%9NBv^u z;AV3#-ekXMj3JJqPrjDO=l@rIrZl6^(8-P(OcC&ce!^2&Qt&=#OUKZeuCqB^h)L`} 
zS`UMuw(>mRfbWCX220g*>=UmtVRv+vrZFz5nZxBt6XVoU%onkL%vKc5p^2WaECgE$88i=#A{p|Y_7e6Wl9;*lDc=BtHY|=@uqhC))>1}m4Keel zz*N)Uyg_3V)|z?n>+uzp&sq^!YsU&`Mc?pyNzY_E-^gChQdbVK`}Da`Hs90#A~l|{ zj#L^uS0C)d^YSO360E4?z|z5%aShal{t4P&zB67=Y_9*-mdj)Ork<^wnJ)v{y z$K>6>zkVNF#{uP#)}1~_Z6LKkq;XcN#T=n$8WX`R5DoY84*r0$Mv4Zd!5yhG^%Jk- zu1>E_SSzlM>V{YffEx5D{x~Y<pLL+ul=%ETPj zi6)aZts|scrXzBdnCN#AT+pdE%gdEEMYo?#h&V1J;CISv`yNqC`3CACyVf%snMv2CI+tZ52z_oAE|u7@b&bZ$i4-# zn8)@u;vf{4KNQTxyv$*79ovNa22NAA&Gd_dUV$5E3T)+g67fLYE)hEKDWgr%ba>Jf z0;_qyQWM;j{PpDrMg!UaW};pF>ffzFy|%6?6NB@Wjl3r*>hMqSOfQBt(CfMfdvK+U z7$VNq-U}(jL$$t{uboDnP;2HFr)=Q3vBmH*>yZ`zsomg1jhXiUiQy`vcIF%L>HgA? zXIhzh5D$t(O|q;NBI%M`1!)J|neVbwmT_Vk_DjJ@FxQw7JDR&U^3)40VH~0{}U*Lo~o84(0>-8}$6AUINrlR^kaXEX|y+&WfM{uv=TMG|RJ7ukw z$Gf$F+B$iHu#l)p{N?VTVbp#t8~F4)=BayG2nlv}w%zwto@gClzCQ?9KHb5#PGqz% z1tk^Ih$h4#3~d^b?pxz$MXZ8raR2l6@^ z$yUaQbry%Ixv9n(qbfB;Dwkg$6{i*3ApUH^W^Zw73fdvxOWI}J&ONP{Rt{jVxEd(I zu_=m@Yock<9&&Z|V+=DhKtG5W+~T?z7_CT+LyS?8`=a8w>y7}}ffXgE7o-Ob@Lc|% zF*NW(P+=W<5futrB$pz~Qlr?mdKx{0nv#Tzzis<1uMy2h(W&uUou$|}dWh1%t-=j- zHdoVO?k2{Vy!}Zbd<5zRp1Zyq9!b{!h3io^9W90`gX8wn-I0YU3M|R-j*WN}vyDGj zP$gK5cn^N^8MI4}lm6rONNp_5s2J}{<~Gc*W?7=t+Qx0UH#-a~6+EC!V8?6Um?~^4 zR>G113;B2IEOsY$Rry=B+0p%}(#&h-U*%sdQ&u^| z&(tR4bD>LKE-#Q}V>>0VQ?idX8^#J!Q6gEbzgRt(-GCjmcX_|@r_2X77Hg-E2Rb%P{SP+tI33CO3HqCECqj7-u~qq|WE#sr zG`_6PP@{~tI53x;CUTi|#0JjY%1it9%o}nSeP_%AP{A32x`T@K+d5~^L(mF#e16@q zYcNySnLVc0>n@y2o{D~I1l3<4K`)9`ma2m;1(l@*;1(U=Hbq?^Iq+Y`IdKFtR>?L> z^MBM^;!Gh7g!mPDCGEyz@m_Ga(o#%CISfji8eg0Jk^iqJ2EPtkkoUQV!SPgq`Zsw5 zs;^`!Q^}EHMnY|4FlX{2=$*v9f~Q6)-5Y2u&dRB)r%^e9UNA{mL5~4erKw3oS_xVz z%_Sd|rb*gtW;^f*^{HX}Q~HcoR~orjiHnpN)Xv_CIG&dvraF45VfH)r&Gb?G1&C9o za>VszPhdFJU*EtCat}qko(~_hXN3Vw8ke3k-*`jM@cku#9&g*7bSI~=vVu!-x5tC1 z6h0YD;XA03zYN*HG7HXDW8td6WFY`9DP^SsZ$nQ8HAd`xV6?v>Gcj?o&>59d-RuO| zmpx4lcl`yefS`@kpQ?3(rrEeXb3WsGg-S+4iI7#9^A1?vSOZRL>ogo$-oZ8BHo101>kv^;c?%Nz!i7i#wlBKu<+G@EczJs5wt-}7x z=}PBRaXkeD1_8Rd@Ph^HZU`7gsUI-&d~iPUyNRb2eDl33QPn?!6YTb*%oY; 
z5)(#Q2D6h<1!1k6tIpK^5oWNvl)lP5eHmSah(u|oelc7gN{rUJWYkEn#{T0gsogON zQ{~YeZH7L>*~QGXZM0+L5@opcl=PV$Pw8rN&7d#T5m8o)C-c)3>b_%{$O4-|KMf87@9_2npjZix(%iUe=$aSI> z{(Im8xIi7E?hD76>%vD+5&-~ z?y;rZH4~mXTdHyTIM~By6$XPQUR<6d8pI=(HMs{4laCcW3F;semmEC+9fVbqOa4-uB|7PP-ZeKPj#J`(3YyqOhsF6C4d3MllJi z+|RYy!4cX*t~L88#;Is}Y0zJ_vm)x6ou-l218f013GYu%0sP=? zu6x1djy>{3V>5gUk231P!5YUFJM>FiM319hf;i_lZx?VsX^Z8P(2U&TA3#^~eFh(~ zu7P$^5#JG%51N5A|HlSHU@`SKx+0un0O|-w`Jd85y$6i|a&_@3h+iNR;Q- zD6y(cPY4_ozN$I)X#WID>%18--8xk-a0Ab2uoqzHpy&JAQmjC*uj0%*;}Zj1eU3j#aadP(NWNt zx&VKfd^nc-3lfr!iRR6Wy@1YImeI?NW)7fsWEXlHIV5%wJz3q05BAl9@5RxsGHfZ| zV0(Yd3;B>|lu1JyuCJ5g=&k(KjK_Ru^`*2bygD_=R1r^vO?9qLWd{Spm&MO zh;^ihE6K}2C@Al0sC4p{=NqGFydz_CPJz$)2kPL0F3NQn#I_d9r}n_gXdsmhk`sQY zO}yRpQ;u?c16193PUXP|k$r-%m6gLL%ciW zIf`4|rS#Sgg64n;WGJbe({zeOstpAwH& zO|Uihe$sOTGl>kV11tk9HjS608C1>e-;9r62)yu{_dd;%9k?#;8Zzl0q-UzdRotH= zcJiE}hVuq?RQw4V2&=htZ349b?313+l?k1Z*fN z+DehFQ4`RS43o|)bkGX-VXc4&Snqp$P2xTCT)p_$VQb+d zvAVKC9jR_}7-SdzfYh6MYLeD=#N6_4Q9Iek>!Z;{?g40PS!5r|gavnn`n1pDI*O$& zzI_OyM)vO7PB6wh(b%Kk@mw|&&(+`<>O#}q*cdBiHH5BORimt#*LDZGN>_uIVZF3e z(}$Ykc*~8{W+jz{b3s$=J;sZ@_-gthtfH_g@LMQrJ)049oJQZL-a6vi zP~Duy8}ti4#+MZExQgMsq|)%t!C^#MDgynh_ZKQ+E4b%A4W<$wv=Yn_Asn@5?@(pI zpI`zQp>%}pJ^!(*K^LwjSXG$B&(7}VUP#xoBmAS%3xueD6IwWHP`yzW|J>h9nG^Yu zYhj;Q+D8`GolEPf`3}Fg;wE8hy=s z2mk9$qLw5P`hMEw`6F~z=KFS8rs#X|WojNZ(B$BDM0;U^wFy((lF8JtTAj^ci9i_2 z!hV{HI4c}>Pq9w+-csw*A#B&=Z>ny52brJ=RmXV=&ci!!gK0vk>^+c|td!Q=+yxq# ze78k>N$Mh6X4|2)BX+}PMmYUhf5Y!GQh}RFrl)#Z#hcj*y12B1JFYZ~wZqE98L5I> zRI`;9;1c%>ejzJ)2C*Mdl;7=wYKVEq9;FloE%_4KmV?zyx9|kDIQ3I$$!B>-X<6_* zsF8e#in4?#wDJ?#P=VAzSSI}pug_MfQ-Ya7+_M~47rcY1{-IT;?5~0tqwe3-eC75xJ3Ob0kodk5AFO@JXiRDMY#J>}^cW^!tnF37(zT>4+oFfQiZFz;wiygj-kgepJu zXhjm5TUvuhq|1LisT*-u>Oh9l{l#KnySiSwk7zO(4Wlav7N+BY=;Re#12c8;&QxTl z>b%x7$ZH{7sP%~dv1OD!#oB^ifj(H@;^)9jcgwrP|J#@AvJ@7o$ZsrW{l@m=o-Ul4w-) zue4R;eo~)^_hu@~Lp(7m2KM*|S|@U}_7`jdjs35|dF~Wg!(G{@(-r7Qy*1O^W(RLc zM0ZkJp~`?(AK3;vcS@|cg73W0-CL+k)SFg@OkehnqqSIzDWUzq(tQ;jMHt=8K>Sh| 
zZdt=CmKLrGcoBML?krDhqCPixUo8+CSvKh>;gkk?F0{-S$?n&ZZ$AmnVK8UVb=|CNz_?BXOu><@)y0i+6RrxRI zQ0qhUdw+{N#%z!7?C(NPC8k+0I7Y1shVqmx2}UTGemf^#7)@?~xq&gBsQe?jKdFDP zP8!rIIQoM$CC%9A-^sV&vV5DUF5m?y5(qa8xIt+k7pI$(6Kr*ui9(K@HpaWEv#qcz zibtEKFThWFPJjmqQS>0>|0z>MULHNRro?MH15PU}yEM-T>rkOXP~` zG^H8&+wxpoVtCO!-vAhHdRJQuCG>cqEBDpbC-BFzpL&fZ=|9;D+Q7WQNvXnRDl;;c z^m=b|j`+Re3jug5NZ+ZO@-+@AH~km2wx|MEuV9jUycDbW@O!}R9>jKY{v}>mHnNG_ zfABN-&-9Enr7yCJfT~_6t{HPgm7Yf(G(STzHU;z`-Fz+WGgdpJyl+439mwLwt7TCv zS3YrG;JbaBR#S|m%Hd_jW7cBqjmCHLor2M7UE{3Uo-M)~ z!9@gT`kKS&B;CTrfPRi{C`h)^gTAfIFuFx_4tY(!D7=HSj2YNK&@lKTsyTB6d&*Gg zZC*a;lT!829y@ePOXl5 zU!6y}ZGBNQX0de7=wmuFiY6D4-n+({&dh0sM=r{?HNR&?c|@5GRz92jF8yLUa)QaF zN~J#QAEh3FUqVYmSKAcyr>4mc{HhkduWDcjyq} z1wNxiN(90!x2zwNZVQR}3!zVNtzK5l$4j|3@^y8>a)W6rG^ZLmkNSR#d+@VZiR8q< zPNgc{v%skw&xr^=it&OX!gP0`5am836eBx|!{G%p(boh{WLv@%!X-0ULm(SRToFw1 zU=|Iu)%pFnHR^Xc$<(C|i6evmaZgo8AptVAk;XQvld#?X9AD((cpk|UPGl5-LILx@uH>Mag0ltaeUN8bi$abT+x{_(M_nz`ldek&ZA1cg^Ijnai z*2hkvJMy(aj$Ma~9KQo6Oje(y%vW!t5dDz&-8&zR#EyD;(@-rJxM~q&8ydd?6t+$r z2=7tls4)F3b$|&&RoFV#;q9 z>J4gF(8ew>(~L!h^>lJ(ef(zJ7xsd#(F;K-Avs9fj<7cpp9pIppPZ)EaPA8F@Nrc8 zzC?XgzG8mb7tiY)?5oBoR07!IC7K5yFo^Hj3(#+q)d)kJC@ zm+UB&|6RPDU{?mBD0qVkCA&foQ5rU~CgTT^qUbBYE!*XBdR44t(5}Z*s~L@+%Jx$C zGDVEALq{y_?Q!fw{Rh9nGh0vB_u)a6O67@eii(Er4@Fa0<671IE+8+!@r(hu& zBG(EQs@4#Uhs9}0#BOI*;a|E0OoQRm>-q` z`UyH?&fu-VhS6WqakeuW$yN1MV{S9!VV1BO+d$pdBOF%iVw23@m}w>bcFfT1+5}{E zM3Ws7FVicrxwzt&Z4S||r?TTQS;@uQQAa{4SGsGNaF6W>Ur?)A77XSaQ9iaFAHmke z`k)KyRWjY{+?{ z(vUOEL{Pr-aE&5jim85E5K=~=3CdG}4usL9KGstfnmh<ZK5?ADl$Z6|pB+9L{P9E023zT*P{tQvdRypA{3dU`x zuQ6XV55LdZo_lU{Xp4og03j%7Zp=<la zJ+T$?ZmtOTm)>-!I9o#IXNtfhX zsI4W*2jMRMhxHN}#~9$7T8AE=XpVx=bH`DtgwPJHG3*o#OBI?lY4f{1%HEYPqt#4I zR_>W+&sg7CM|T^O-@;VZ%aMJB4Ro&9pJ|TPiaQ-6=wBvj<~RI`)iE5R&+-Q3fHzuI zoC!9z)ONM@h0+mNCF-15UaZ82CU!%Z_LNMAKXRYRou#fGT>qB$K)#-9x;lJaul+q7?PxEy05-p2ty9E+h$}RH;N~xq6wCV@Lz0> z<#}}U&c2Q;J!Q{C^viBHeI5&JWz*xNIr=kaK*Tje{;A)gUEEpqFq4Oe(C?{R#1;p~ zouB|u(^=L`&_k>dOv`O;O!Hj97yEDU$JiwAMQyb?Crjl#}jLPD5jiPXVe 
z=&D2XkUA*3l)~potI%yv9er=X9Y<58t}tEsE!K4KHsE7%etF&QB&v% zQM}_Hc)_dqgVLJxXslA^bP+L^8yf5j24Z8oyuRqcp9(~g+2 zFh;o(OhH>vjx$HcsAEc3+?RMi#;@x1zlnG9Pin0ahKOI0hQEO``OZu-xj5eD;_&&zmh>6t0O>9$$JG`FnaW5aRs|KaI*~#*323%gmKK8B-pZW@mV#2? zecVM?745Lp#TgqoC?BOCP_6lEOewVu@jR9=iM>08ZMH(a7LJ3WoL}tdSi$^adt)mB z%h=F1aFv_uJ|G-cJCdx;uHO+WCQKua>HdNn=s(R|x}ZJYoh(kZl!L@@ahlcxZBe?w zCmAJ#5A0|50<(@wLf4(I-P=%>6jHaf^2f|v9mX1HcbMX~e5MCVfP=vgsfN@oVJo%M zQOI{=<`N(ILG&fqiK^@QX02*p3S}QF72|NW64wHSp|$2jz)r8m`r&S+l|0!zSqhzH z_;vVBB~6J|25~pE;YzgCpZHB(n;RtO=#1r|+yk2mN^@>og93_t01Uhw+KP&zj#xc( z)Ke_}2DUA@4G_UCfp9X-Hd80eo!2RH_%B-}P+QE>tH(E2mu1>{JKCr>D%=IesVWM3 z4;kwMtx&Fij$UX;$<489N$rd!1qdEXDnqrmp4S^=3*ZuSoBeG%l1>X=y1Ly(a*X7R z67$dmwsO=|uAbvMp3L~+5;O~D^PeDRXe60Pm|a3c!FBRAb!p!T?JPc(X+tbVkJL)A z16Bk3$^SvJR7xoUvY`#D$=5^`=m4IN1(ch77wm#EHaG$*@?B6`9mY)xm%(GbN8DYx zS$;HXVNXG0&>RzNzo%4~Oesz9d%pTaWfY}0Co9@l(ifzX)L!p5 zXjnfooopimjBgJ0N$(}zWHLR;80()1-zig6QuxPlpIYaj*ktUtuaGO}IPHv479}pE z`kI8@rf4?s(bobY9s%w`v2sH&1}tM+nG>^!z<8#&aGI_dC}D}FCcq8MF8QH0&r;G` zomvQqaZ3OX30BN9l~(ym4_+1{bx*%zJv7ZM8R+Rqzgm-*o3L zL;30)^|p+$6Px~(L^MT=@HMw@aCFDFg%`(T#go?0%pep55onA1qsUNGoHjVh*(0eq z?5&S8%&skQBhed`H|m0~@F}Y0t!JbO8YtixbBv0C^xM&gYEcS_4a`-CQrVWsm`P*2WVVm9RH9y@Y`Nu-RR^Tk5-po!Tr=$ zz@o4K>>(aFcP6|fAG1#8r?M?FvM?flrE?_kIR@!b)GriWc$exc>`%(@wJ}#sE3>(& z9u`L()8}AyZSAmmSY*@L)Mfc-;G#Jf9mf}urs)IBx#b`(T>X=M&A)(s=fA2YqoQCA z`^|yC?aFye{xLFe6Pq>~>!hNVMwT3%n-XaDh zoCV*viQpT5L7c~QqxP`5NnU!jcEV^2BKgH=lbO4?K(%EzJLej>@zKc{ZLoFzW(67C z5Um~7NmukgdW@QC-kUyiuet}QM$IZ5KNFHWuIDyKjDeeM>v<0=e#s^!c_8pgJSwLF;;!Y zlqc(g=X4Hut0kc?%3on`^4UNe#XRMN^7Q^-3GAa|wszmI(Yw(gI-5FXrlgL`K|Qwq zp8S`3(cmqvC+Hqr22|A6KZj32IkpB)-gp$)Ds}^H`Es5-g)%2U0=I+d$%ImGm{ZDH z`nvH1FJd`L?>92kwZt`3CprZqG$%60WAYjN7X8(3aypaV=C3q!-WITiUSd)(*K&8j z1nseMR=%oFWd|}_+2mkj;a|%trGPH)aDyz=+N9^y<@dmjbQajkKg_jK2h~}Q@6>Qg zCFbg58?8Y7qy_A6lO~icS4+4}y|ARS59s6cK9f5*%HJ8kLb=4Zcn$EK>;hld5~w-+ z4irQ;^v={aA{hixHFr8FVAr5u;Gm;=e7Ia0t?dUnegRq=}89VLIOiw@@N|J71#TlMo 
zO*UlqIj1w5KoVI&*{82Y*C;#u!%Rq=gp%>&s3cW_Y7f@iyXAfW)ujpS(6D56SKLJC zQJ27poMthEaM4L$^)yGH?Pp?l`k&$_xgPpR@|NUv#06FUiS0SQz%kLWlkVdw#5Xh7 z^~LHn;;tB_GQJ+h65Do+rXJfPY<8sYlztiTVM*^GLC1HrWx_&f$DOD${y2wJAy03Hr1L*pJN&cm%6MPmp<%u(IfKhS1 znFZ<=p;E>&q{UD7b>Pm@<+<|m7SzwoOJ>I-$1|Z6ILeJACyB-20KNqy_{v(3M%sV{ z8U@v1q(mQRY5EHypbLMT{Y&g*eAI?xtA*;YiC)6_25t2pM(^!S_$zt}=2EJfw1y>3 zOY$ecRG~~LMwKyJ214Xz^0`=Zw#n8-&A|_J+g$w-)K2n_T7w%#u2HHu7r{LUcYKh4 zYK!Tm)--TCS~4k%`!$cXdf=AYN%yPiNhRp%mJnBF{0KeFDCNBg4`6Qde$tp($GnZx zl+NH~gd&XqpP5r~F4IxD?rFil!AF46j@{Ir;BjrT^`IlI@Retxu$Uf1KY;rTGl@?> zrN$pSK{;S;MINEn zOI_4ibOWx6=`wjh-_>Iq7Anwe16gUlEb0iVKFhm_8HQlIKXRq7lYb#KJ+TDULAwHe zW?#qpI3I{9v`wtV&zFnB<<|0+^=j+fa@q>BQ?5APk`awHS~2{cIzcam{91P=A#Rc3 zOyjh2&H>)qW{1>5%Vz2iPlGSgP3|gBP+zH2)=%UPut8qnO|cZwfO0{ZCl!-}pcgjM zs2BLl-sSs>o3RyQmf(U6)k?m_+*jU)WB6lcv}KBACN@dysjv175f>^0*eeI4s9q+= zXDRKHc%f8af;taN@zkWAdLA-2T=)BR@udb8m*NOHdz~&aRP);0oyC` zGQk7JBG6|Yd4bUQ65Kfu3;yWOENPTUSvJi4MQ|xvsXBSPr0HFP5&Al_g6*Nyk!p(f zQ^-I!Fpef1B~5kyh}orb92J#sYBMC9IIaz*e#g9##DE{%4)5$dM2qzO@R`I@fl+It z9>y(vl@uObRh8KkSisiD=Ieh==3-5znKl~z2*lvI!BY52@5+E{FQz?CJcYZd9o9Df zbbJiTl#@&@S|mygYFZC2+g_8H=YOG`Ht!ru!&(tu=DS`-IRL}eAK>}HQ1g}ZuM)}a zgKhbS`f1x5u3f=A?sdYIn9u4d+3(*eq?uhkJxs>7UE}cX`U~DBKG50{=s(`D142??VB&4M)RqO1{4?^_UzCR)Y-U2n1vYpWy<_`ZMs3Tbujb z??N5Sz4--Y(e4G3oU@gd(8=GX7Dx>ewSuko4-BW400*VU#y$#oZpgdDM7l_RnP7x+ zQ-nr@+#cL>y;rBk?h(>F`}j1^bW{&iXBO&R$#c+;Uh_7lim{#Sf}1P^cCA*ApUEjm z(e?*Z0#!(>zDb-4=U^h9$=Q8%%>KlDmBYil8_0H&pv9AU+-hT#vffO&N9o6iSTF{k zOiVIe$)Q@G`0M!*%t;zg_*OeJZ;RH-`0CqScm&Lo&Rdq6v`k8NQ{|Mg8CUa?upB&# zTh7PmeS_6ebU`xv6g6b`7EHuPdD`olmWPzan2;q?SB>R)=8^tZn(X{W-j_GZm(3o< zHv9eDXML7q3b_jWW4;)#nl9Bny30~hIV^QnJ0(p=d2pUE*85qANGO&0*`Oq31jSA4 zVvffbdaYd1z;JFW(bQf>Xm70#cBQtZeqsNEH`r8R7x&K9UTenfY1%@p2RHjVN-N|4 z-QP;ygF;xzbnH`(?(}Z3m+GJwDr>39@Qm6wsyls4m~Qqm zpCRvO_vr!OV)?e2`>dD#NZANYHxe(=2{CgD8{=c_9689D%Ern)EjP^`ydow&F^(}E zaqcAF7gq)DATWCs(LtD|Rt?r+D&T$KN6iIZag=i2+kkMwRCtDcXKjx*p+r||LOK?v z9pp2GInIgD7I>y^AkH#KAiyek(j+L&Mucq=Uy*kDD#%`#5V&By&-M~n?{IdX^4$Mh 
zt{zygl)#$^HBc1FCc(uPq!IMVbh zZB(v64xD!%OSCF?$*tm&0Pd`=%oNvQgNf_5j-U}DovDT6lxL1k)H1jfZ3)IP3&4fI z$^2}!c)?|>`;F#;UV+PIfK}fH>B{hKV3WNh%w%w6F8F_l1x#Ey)KJ?PNFQB;QN6Y zHssjPfY{T%>EsBwi48b{N-O*%-$NZ~x`VoN^;F7xA33bm?TzTSAi*HfaAmJjnJUIj zF(YNwU>4EYJ>>joZNY z(T65%W*hovm?0ij2Z;Y2A!^Qoh$G4+JG?#Li{nP_D^a zh$VegSn_NsU893fxgxe6CNDe`w8duWbDSd_{VfBzJ!oWLxiHQ<5J2Xbp;OWPaqJ_X zV&=iJ5{tk@lipm%ch+POCfkRjv(AT}V_Jc%5VweOV4YOKaM?O0nwBicxQ7~ zWIW<=ZX|WnZhDxIOpY<>O|m^YH%vdS-=`l4K0pWCF|MTdusB%Cs?M*;<#GmD|MRbv zmeD0Nugzk+;&|&=>gg!|;_h%~$c2`ybSzk$CaJV!J2dEzA{u_4Z{gd3Xs<3b$$-{OCCa}5>B1*t`qO++! zXX;q3JaV|Zqv)g*ZXz4z7{!Nj2YovgyFDy<5Gin9eYtE6y*3(ywS%#ypC=S7Cu=8f zQ4XQiYBtrw)SOR7KL9@WlH1PYAgV)I@nopn`Lda8uc2vT(i#)eQ5!6@8oi zDICUU#x0`O0#18mDX*U(ajBYp2|Walr0(UHO{mMf)RFd6c!e6^1h^LSjQt>H+b&@{ zgptm{=raG1S1Ad+RWeNSm0cg7lj7{>yyV^OpJJ@`-9}g8Ro`}Hb$lNpT8iT`N z=MUk>@Qt)-Xp;>(7bfK3>jPu>T&hFUENEA^DIK_3Oh0~pb^@B8Gev1i%tiP3qWpb$ zo}Ixq%)A%a4P)SKJ>b2R)D;Ymx!t6sumHUkPKZ(ZYx22x6ILZlo211v!hGrzTFF*K zW%R#s!-8uatM!s#sDDRN6
    JrXX`t?@(lVb(+{*EZVW`~M?_>b|cuBdc}66IfDm zYP&5**lzWod&cwkiQoRZ4LIXFAEy7d&jw%dDnv2kI{V6pqCa!M?4R{nvrtllq45=9 zS!D?Gi^`G`thYVuwf>f8(ml1hw!83x7ZPKYE&(U!fTN8vOs2mU{nv;kuG$@H9=jX0 zBO?RTsZ(kVsj<93;|Pn*sh(#lDdQ!YD8qD8LYbJ#ZIXTpM~IXA?&M;W&*(K9R;2{`V$)p zW}s|r2R8+uZzhaKCLCldNib%8G65HIzsN_v1S6lU7q$c6ByWlj1*=dJpy(ddJ%0!) z;y6mKMWMP?TVs4>ngnN@>RuuAKo?b3uLYI>x2CX_3ggffDpe^_@SL%m zPO&L)h(Ex0^u(*_J`43pdy#*f>j?aO4clKc!{$aq)Mu``S{q)s)zAt#LtLXICst-{ zN*5;7hbhGqcl%4=<~)1hJJKm3cus zua**h=vjPE?<#MYwTpO+U1hx2XU9(o48^PEf204Jdw79-$68&c(Bb%Ys5gA+7-8w+ zXv(*UEUIrT{7Wq*t}C?zja{F)DfC1&#gYylYiZOvd7%eO>f-H*);elh>pJ6-?Mch< zsqr;@GqjENDoQbBu&uc^R$fNF)H&@V6|cp|9)%Uan?P+g9t;v+T8bK@G@JMHvGwr$d$2$znOf$zL<+GS#`Xg>rLF$PbZESK*LGf@m!cC4>F}1WlhYO9_ z;GEzs%7q+S$PRY==NT0`MQCR%M;C3!l!p+=7yZrTFTnGE6rBZQ)L7ewabH}EL#fM1 zJSUkcu-M|ZxVyXi!lK393lv(~X-6U_$w?-~-QC?;V9^)0@SX1;v@>mT&hy;&b@|?A z@K`bXgxd=WkUeOXJ`nuCQAUeoja-_uNBe;rBoDnMtcl!n<}=ugJE4Ce%9smIm_@W< z>Swtn@dFLh?(>%pj*!>eq4eJDW}&hC#{=tc&ZpKnYf!Dk2v=eCh*A*V z1Ec&2>?~@wL9+YM+rUXrBl<2p;5e;!5!QP9S%bnom@e{ZPksBodkOWzdcp_T{9r$r zM=z&faJ%W#`T$@I@i*0bc|$vY*CrCP(>o-MmluLxm@0E0zL$F@*syh|j(%P4=1hiF zxWU<#*@D=~S6 z+3lIFtmK|s-N-@P*ZHTCZ+!%Q#y6BJX$OVl+77suD9cKGC;3ufEPnUQ&{xX&)HCNr z=DE2`Tthg;V|HE4buC4?GZc=eC(;AY^(NMZBYg7-kLL zXHWBPB~sxNoV!iK(&H%I;4u~K7Msnc(CedSbBAMwW!;rpgC^P)YOvl(ZKhSmnZ~NzF8^KR$M)@` zw#j{Wt{}e&|53}uBXS8xd-)VoO5f|ahWAb#lml)%+(A;?>g9Qvpi*Wp>wc)o~4fqc7OSZF09{L&14!5AvjO%ng`h{fYE;K*J`n>tsZRA(X z8SoY@C5M7eXbQK2pfJ(uH)koXL{_o}lW*}(_jI{_$eX20k)PjB3KlBE64YrN_ z*r+YgaGZWsgwCQ|oy^A|7S@Y(zXD)ivsi z$?ObvzEmxxM4&g^XWuqIiVsbV?QNW8=Ccl%N}kge2)fZeXH9rVTF(B{ zoS#h4SBKb)@!E;TVXCZmru3)M3dRzH5DR|yW(h~=eRcqzRez`d)J6!^9Byl$`5;`` z9;Q8GUbqA%%DRCLqV&YcX8+8HL*D}5ttFJi&hy@cwD&HY7Di~GuvDxTC;>Eh+!tvU zhMhfSgtmEqO9I@_rkG_stolxB=AH+)N@qPIU=?kYr;FGE{GC|YGXre4c<&U}k}9yH zLVu$Z*xk3?P$8KxEfAmkj9*7@1q(u7m`kp8o{fR(p8uo`a#MMo(I>K@W4@Niz7f}B zx>>rI7kyO;!8gnusd3~nt&^vZI#<}P7NptmJ+O^^rrH)K&PZ7eGx!v_1#Fb_2iU9} z0reAto`p8Aoeq%R?VuTG4ZDl?xq?hzOf%Ch9V|eFl`MX1&N%Y4t3-yVvf^d49M{+P zKeGcW 
z#%|R&WS`)N*h%cu#A)<6>Os6WRP9hbA@~BavqX_e6TQ02VjKfoap4YL<9|2Yq1OI|w7YE${6VLxc?DJU)?=VAKCIPx=m zYgvi^Wlh&3J(KLN!rw73G1b4bHr0Jxh;_CiSJ0Wt0DXfV;a$KNX7^A$?gO|oej}?J z5JH2!`0-XNG=i)R7iof6f!!_Sf+cY??WekXBQ3k;8R4P1x=|rm&NWv(ETs9N^^qzQQyx|cL{Vn!GJT1C)SRZ*j2(^gp)WiE zOi{PEPlYe~O!S$N=v&StZC0|wY=Eil&DblL^GOyD=Lz(EF$ZjToybXpj5^1HW;j1v%8mJ`}e2v)X)=Y~7YM zv5?)`ul}7jvhBmPeJTJ0W zhq}AE5Xr*#;3as%l^b}+4PsjOtNQN{3YMtD(jS>1ewW{0&ZN$BV*`x?W zWsG%RO@?jhTWo3PFi(>31IIKvfi~tJY^L-+J)I~Oo9Ic09rw-xvAjpq-M)kxghMMW z4I&0|mcOjAlj)1g2gjri*w4_2YQJ}~T#WAyj!?2H$WxefP@AYC>GZGgPhpGqcXgj* z3g(A2VaB0t&T{4z?SxQ=t_2G*$B4VcSq^Wc`HjXibcX9mju6*5Qq|SYk-&ymlajzs z-vfCzaab9j-HGd(R29>>h7kvIHbA8Iu@_sUatIxPujSoLYjmDHkwl?Z;zqAe=)?P% zbFP2LnXrT{MXc~_wKgSpR!^wr+J#g#G@5@Lysvthra*@Ka?TnSe<$h<8k-ZTY?l($ zV<=9334R(l^O4v_8_V@%O8O}fr4RE@RL=y@rI%7O@}A;;P8DsW7!?YHKABZe�T= zp8lIX1LEc+Q$(&UDss)>o5cFuRjLO6$)qIM^)MpWci7)#L@(-CB z!e>qr5DHjFgI@Oj1X1smbM*4h$-Yg_JM!7)<{TJKC)v0k&_KA?W0mV_PMe} zIRpl&dzt$Q)7TRNWnNagFl`cEi{)*MuyD8L?lpW59OE@{nM2P;^(pr_zPyCl!AGUamV>)!b$A}g_-5l zH@&z1Kb)ii=#iGG{=zhxpT_@43$_U7!KKVhY!KOEfpGtDL}W4d8{&Xa9?i=h4kYGJ zKvUC#Jl?pMa1|7q2p{zWcu^#nPQn4Ni3S~k1k}vi9xk-wH9jgY-yC087M9 zwF@&E859nAet{WT=NN{_PI&0!%!bB}JqZ814R70)}1vx2ZAb-TQt zuf9~$C;Ox6lhW@@SAIuq5)`3~U8M8ihm6hPXl9puH!=oIVk%_xWIL8g<7{@FcGNlF zyT!RfA4DAzPs<0WW@2OEsQ8tNcL1}M@3{ThtQIKgnVi1MKR^uz>q;I>g=xc@)^_0+ zY%4UyIh6x?AI5gHN%*csreG?&HbwsIX%h94-x9eV-&-$)Yk|$e5?Gph%#IPSNvXsp z-wfrNvs&bBe22cn4si~$X2sm*{!F>(?yp5_mg5?d%}l(NziYk*eK3simsvTN5-d8~ zh!N7n0`)#IUN!1mIT(rH~75V z!x~31+*l^T4oRNif`rjJ&$PkK<)=F;2P=p_lx^U?6_iKpE2cS z2EKVGi8fltIq4XF2nECSv^wmczC~ytwr*tw&!P$-Cs4$0W7&dV??eVAhdDfNw&#ER zoZSA(bAd~#z#Nvpvju=pE2Tw(2Hr52Efh?d!FEU8#hKh|t}u3h=tOo>1NJJWEk%h< z^`(}j7eFht#d`~2qQPuPa+eN{W11^0T^%$?Dc~06Us4^VWB8<2%NKIL#dNG5IWN!_ z{fvtAWvH;dA2bS;h5eIL(1GNG1SY%2uhr&yss_I@{{z(yolTj^Z_0S2Jz*5`Wq`K3 z=AJML$s@JD*c%5X82!ChLobL5q|Q^eL+dNm$?9Vj=aW)=Y+JV@XDoWjoj>IC)s(;D z=Y>Ot;!?>Ywv}3OI{zk6#NE%jgi}WIlmu{Cm>c)V^VD668!IMoX?jjf3fmTq4zy!O 
z(BGZI#m`(b;S@cGS>ZM||8Op))+Qmg1asRO2}`8k&=>L9R!l}ihf{GX!1!ri<7OFa z+05(#>{p@5ff=+8ANop4-L((7r?462usqh;Oy5AZF#3g#;_?B_0OgZx(pm;2rcQehMof*G&%<~&lJ*NfkwMK$*u)O1gI+Yn{E?0js3sWp}F3x+G zuySDq$3nOOHI`RFD`BsGla(6R0B`)gsr6{AumT3Xg*-nI#SXyjV^V)@xmlmJozGC$ z%D2h1{WZ-Z{JhNZ0DzgujRhAK%Qap2uUv1?jY{XRYa86-VuyI3kcKC3( zCebY%3z}sw=C5ITaD?_*`xkO3kE&vqV`_#ng-cS1Sw@|Vc@MfX3zUM^n_zurxfMYq z0gKq>Oo5$(eS=;w-P>G@0axS++(1%=A3e~09L~#qD0~baf^Uq1+0l7!t733>+Lz#P zt5@`PW1h1Y*;eR`pH)XO&vJo$5^(e=@0yf4$sO3QB;kR=Oe;%WXm(&mI4=vU`BJe? zp=r`PJk@CXJbxAyZW=^!_pLNKwd$A6FO?8`9H+em^J3w$V}}mGZk|! zu98XoD(yJ=93@%>V-dhKJF(n8m+K|_YS;QtW>p&!%=$(T8;os@)Tz;_B|EKGs zx|FZuKJ1+&T_=m9^V|=yA=;4W^L@l`<5~15>m|DxH%3k=vxQB{Xmx4G&!4k;@k5;e zO$fZyw(*$S>aHXvW9IB7cUQ6%zu5|e#!9safASCDb=3@erBruWVq1KwGEcos=S+2>Y2({`Ip(eYQhG1^W(}cVk(KwSa}R`Rfil{G zdKSM)H~A(G*WN1Xf1q{Vaq|k;iSLwin7QyT#|k3QxHh+qP^6<^sMSTP=j6;qU?uS| z`UdMW2ZR`Yr~4mbP-t!_d3R#cB*`i4^1KD3Jd@~$`~<6C*rxN@AF!2Kmsul?VpmgR zV~*<$eO$ORSeMX3n+%uGz2%Y&=WQf(4E-c%*HSn<;fIM+Zc25}xj=EIH`~;FX|5Mt zR?Xyo%3Viv;18|?ykn*Ev8Wi=OKv1w=I+IQfD!6T;-0Wf*dYwWQ^5z!Eb}95CDPna zDInvP^#0YF4SS(Wfz8Sl`h{>3`~zH6inBKHPO)qcrd3znpDJ!v_gP;BXbopsqY`|E zryFzNM01F%G0GLEGv(MYzB#wSwz7O@J@+@G9$(pCRSu{ttmlRclTh~yL3G$V7&Hne zI<6c4aAde4>W>K_VFM@El5klBJ7H{Wt$lW|yEWC`LH!5o!4*OU)k~>K_p{0yO*~il zfua()hi6`Iu+1%5Sx!*O6n=g1cWHKLyg5MAOom!34}s_6SEzL~hMcQ);=RTbe4F2s zkdZUglZg9A6|r$Kz%&(8jlZgx?(eBg>XDwTeyj6>ze)Ay9un7_!;J6dF6U_eXiNro0EFb@Df=CZzk<+I(ZjNwTuBJopgyk%wuG zBAcML!2nLR?q?CaD$FyJnBVvkD3`AQE(xOGRywK6!x3yh&s4_aogt3qH^5q85nRd= zn0IZc!?b8Po+uvb304asb-P$!e-?VgQgyE-Ro4iv5<4q1$LmHTlu=w+?gRImUCp>| zREZnJKMVZEmbX_sPGiT{V&e>TP3q5r!5#g4(NB}v@Fn@C~ znl3#(;M7M8-x%EXSBtvav5VQRFgNL?xhTyBm*b0x%cwR~Ds!A(1FCxeHsi=y&hC7I z{e444=7b~9tZBr6+k9`KnBLo!lDpe?lqwr~D_IFoNp_<+-sW04}vfdocSAja( zr60`R1MA0LMA)*eM@u~ddFGv9AAIUbV2^RbJXywBWsg63%WTlGzkC5soV=X475E=`9oo+i+91u<(@p#I5pv#@}y8Z^5`n z^dOv2>d);ExvW~;>JTP!3u|e_rMN@nQq))bM{g)R@~c`!)~EeN-L)LzWV!3UH{Kpr zn?Mp%aQhzg#ymq}i5iiPfY-yup&x8nHYKBqR6J{H9&E|kv<^VT$P 
zv)gkTyCTA4a&xs|OkK8<_6imO-+aGt!*vBSS}I5U>y0tm8fTc_$f;;}a42>teG%?j z)e>9aU29z>KbG=VqkMK*bcQ*Yc?<2p(}SAu|D{kRgmJl#giCgB6!dfv(cdQU`xkSVxNz5?Ky($)iHmQ`$-GKs-A+G+u6uxmP{eN7l_C z2$VxxjeK`;JAlnV*VXhyzfy*%&O9(4dGBa0X<}kMe-X6hpNHqUE~H?MGOnK$fiNXn#@ zVeBx`jx;vpY-_9$N-~t*M}O<=?N!KALrdi+vf?l zNWt5VDNwd9BvrxZnR>iL{VsNwYP0KtSww`qGte6Ue^xQp>cfL?*%{2J%)a&KWp8As z&^zHjNMx7g#*o8}{#IdSL&7t8lYbLa+pN6uF+Yu&C*MoSj$kvYph4EFtQ{%6*+Z}) z8sI%7wa*#NJqe@0IPlyk0w=Q>RB!E#G(#UK9jEr=^p_h)pbx^&;-;9Y_E?o<#h}Bz zTFLdUL(rxqme6{4y(%kIM z2M^nWtzO|Fpa}mld|KG0SE0u;JB8wA4d#|G0({edg5iuT%^~X2ec*8aOQs&^mi31= zVBZwEI-BWP>FQ(4Sc#avb)C`UT5{u|gx( zutj_hYR5%*$Y4193@%DItW5NEraf|3dxchvo=Mg<^Y#TA9peU|v#wnYrxsB8#IXN}gNUbOW8P-JXnoL6z8SpD zz6p7i*-|;~hE|ej9Vi|yrHn8J3VA|pf<*yp5S1^_78_!MT5GU0`=xYO^~tN<1BKqw zSH2oHycP04)=Go^aACM6Ihnqz8RjywpYlb@)z_eO_sXPO!Y9;_dMj>-?_WQkyZUdt17yhfAU+EO^I3>D4aPJ6N}Fqw{wS~ z#po8_4i(bg5e=yxyt`~wd#YXsj$*Dlol-ieFC7xU@B^`#tDQMaSOtc8cj;Q7GJjWU zPZ-#_x+GC_E!JDwb91(5pP^I3hqKenos5m?jiu=u^=29+GpLEWUNqk&HWn5UNNx!hdY7YvOq5uN|6MX*`RG{vvBheuVpwJt_#tJ1 zni&O{>A`E^I?UhZY%-R=M|@KKW+XjB+n`KhI%L&!ozZ;iLhos))GulJyvRr0%`TuS>@?Q>@l=E zd5w94?12+Ym+B4(4gnedvy`QmJP7pBN?YOvXo9C*c?3zzE!T4 z62xiX3!XXHC0U*oT-Cf=(HSB#aTW6f4%P=yy^KoMA93ZF?!+4Nl+{;R6>Q7(GA6nP z2px$9W;wZ^Ry^}e^4V}HUFM#-Q+=h4Qf5cxm~e?~z&Eu|Iyb z4v+qbIpCG|Q|Jsz1AjP6h$H!`!gTs4Pg|U8BBpxy^=oXH7^xiqDymA9#Z-thXuRtd zk)}0;3HBCigAVy|;Fs^4t4AO)q?*(YlXaX0@ZL?CnJs4`%=WL_sRR zF*2`KhUwbx{?V`;HiEMf1&d%i?R}z+SR%)(mN73a0l$Ng?2J zRwaj#=gAjvn*LM)>`P}?eH197-3o3Xim(l^&&x}8Px~*ShrC*tWgg#|uRrB`QibSR zIwy4YY-MgT15o{pmgXX9OQPvgm8nET^^N59hA)|8>_&eTwe$XTB{K7|{3odA?aG%0 zE2tLaZ2meLu9nj48~5lj-dtgf{bWz$?Bd!4@Y60yj0Z2^{m>t96yL;f3X>8Bf`_u5 zxq?|$eG21bM75Xqv{HY^EXpf{YVfub1f74IyHdUhJG8AR)m;EBVVZ@bi7i|e z>b-P3_(@MAcQb!*pU?$6PI#1?L>zXILI$b}+p0~}R`nx@l0sqTl6pnli!LgaQ9Dm_ z@J!F;mJmY&hj0R9hijxbJ!D%Gs0^Vd@|*dbO#I8uCwB46!9sHzL!mnCD7p$~FixtdzN}6zwoDlQM`z~~yo5`H?zBjjmw`@K0NyaPhR{DmP z9G0A3FkX5stwOI*HEy|foH?uaRJt7Ojxx>Pz;i7hT!5v(bbb@_Li*-PW-n2S+sVJaBme4%ea 
zH5^-rGUm*>gV4xFAC?oL6{+_bTSmE(mZP6pg(AclxqxyK?LS~^}i+hH9unqNC zG)>$pOy@tQ6lYsW8^At(ADU{76FNc{B|Fnmhq!jZ3;t`co7^IAGaSU7i5!WIF-Ofu zj@?>0;R@qEI1CKsIeAkMH(gB*bxs_kk2#c@zS{e#;WM(lT7K^}`bxZHRYpq!U73np zJ!dlb1mgm3wy2bz6QzW}I%*g@k~;)cwN2VbW2%~g6OHXq903#jR&rTK!xc4(Vg8-aS^2FpdOjaV4TU3-#&cXc=nb+^lb-7Jz<2W5*bj={{t1 z=WeN=#p3E#wh%zr^Ku1VHHJ%z=vl}^C)+dPm)d`$7^56lht~05&2c7Ak4~JxN8?*W zcad;}ar1GWLc5A9bHO;BG}`du!F=hP+JOIvdqb;IV!hXs7@-yA)^RR{>MeeCPckRc zX}SPRx+;22*B43?M-vvKE&NeqBoh?tZCk=^Ht!@{I4}+VV7KbafDVVU$Mg?eI#Yx@ zPGm4A#HD-N<9rbId79nTBsk05rj@~)tO-niM-+2P0>n2b3yM1Kh;x*I!W!lXSD&wj zyP(hPx*iv?NE~mwl8WX0N=W8|ibJ1(KkJU+E?grP!I$$nYM!k1-1K`P(>LRwVck)516g`cZD!f|{UFnKuK= zJ>~66dO^J;-z#NrsF>y9>M_Xu5KNSlbH=#V1{abw)a<+#z`-rzO6Wz5g6=Q`)T8L< z^m&vAG-jjR(mI>eMmc6xgQu8H#0kvfI1z3p&f`jilGp?ysBaZ|xj*siv#vP1YpEis zMzWPno_i{^G#y|8ZrP<-74uGqH(5tvNxc&*FeaNt{tFff+v!A4F-(?N22L|YU8KM-}inTJXMXSdDFm6WoI=O8CP)m^RpD zsGeRO-IrD?Bh=oYEAxn{!z$Wuv~?hY{Ar%GAL40HmtZnimV6m(#H`j{a)Sv@FCq4H ztY&^E55wq`iTwITNBK&aSprZAXBGP&aq%Oiea|l(8vaOxO7bS?j^+njVN)Z>69Rk(0qBf}&5vsAc_pz~oKAru`6Jd7~ z4r!N!6WGOj5ta)Vig2^Ghtn>E5~{gG>&Qaf}Il?`=P4jQHCD%xdzSt#1f zQ>3H=;s&eit{diEKWXH-9IIj7VI?c z0jqHB4;(Q2fDEI&(NjMUF9|P*NeK;%FjPF9tu5e;d`%cA7lQru8`zz&6Ik?bf+D^b zH=^C@CS@Df1&oa!#;+x|fNcG5t{c4p-$zf%#q^azHNFlr0Fw<%+S|}4qaX8j>j|{H zT3PTcqcZsprck@yO} z3*$k$RYYsxEyNg&ri%UD^`!PzMs9!DT-?hyWA>TD{rz_@_qH_(?-lZ15D9b%Posbf z8>+u@9nCY^5i{O-pI*(W#8mAI`02lyv`BL~XPGaFqKOCLP0Qg#*r57QTW+0(1mpJn z$_KB}9XziRTJJxlO@X$O;n;#LnmeL*J5mC}wC6!XeiC&btY#}2 zujDCWC9)!R^zB9h@fYggy~iQeYZq($#xY+E@EzE0;4yJj-^M-i4CG3CwoxnrodS|#Z>I56pm<1U*D9x)H|K53*t+W~$=;IsCQ6U`)Ulu?q!Jz7Zfd+28g z=dHH%9k|#TNnEjihid~ha1OAoUB$F>DpE(7qA3w-JJ40Tle{aYESe459}#dr#fBDyX?yalX0GCgu#;>-OIjTX{q)E@k4wRG!!9r07`?aiBc)4{%rX*TxuOa+q7i4EjC=nq(tTwW&Ve% z&@JH`?K9b0iAKBGTii!JiH*|RNCBgpFo%2)H&Xd3*20aYflT4861qp#k&a{2>}P6~ z5Ro`BeAle4=#ImBKcXmARqqzqkyA~c<2i@QAf4SueQIHkWOc`a7> z=udL(g^ihdc^+&Zozys2E(Ox}rGzpS-c!Jsgt{h>pvbQIeqXpx{G+ILOnv=}YcO2G z433;=6!m^l&x1~R+o=Z}LB48uNbZY1+DjD;E*ATE#n`8G5%5Xtn%i8fNk+T9)@Ar7 
zU)4G-M1$&>oxFrgWZJvC!}sitKnr~yyg`;qp|X)3?Wrex2uDgaaQ~;h{VHZNGmq;o zTo?be*O_;uRe`c(b?%U-h2DkTt1KaFTc!%=f&4S$5;NEDG)5(M(yE|FlE_cnKgp5L z-{Sl7JIHMIO}v1*=y&zLVrpQ#<-}W@F8*}Ld_Y`oHH^VjEGyStg7d;gnN<5gb0~{g`8q?H2yh~g+^Gd55J;SVhIa&Hq zBZXZA0j7Ih(o1_{F^Q@wcS)@x4Im=*lVBoSHS{$+mHxnx{1&@R&T}r=eH2s>dpQ;r zycP=ld$2jx!5kKz5_C&9_|?Sk;%Mt4yGtJKE{Lj#eL*GePk9Rw3Hyk@tjVy0rwB2D z7-en-6J)mD4C{Z)U44Gzx#oTLfl%G3z@(z_nT1{7{3+ODvQ+4aKKV+Uo7v&kb8({1 z!d2{Mc2TY)yMperN9JDEzvKU=jWT8fncrbq=V-2k5p-6?{aC;b@Vt~rqZav>Sx!x6 z`CvV68*UCR*COD{jNi?4s)>GDEMb2DONC28ogBrKQlG+VM9IJ*YO89K-PHd{7aL7v z0V?WP9G-n>Cdy;h@$20|t{MSVhD*nv=^f5G-bP3EX>$?NFPdTG?d zypi)>6VV{!J71qYC#9Jit?S|T`f>BJx`OXQ9SILHpSrrh`&w;ynwid2qswRPA_MN4 z@*=;(4(x4fH)0Q>W6n$Rde~h{hI1)XP9qkYO+}o=0Ta#7^wFqQFh6+2ND8a={-y~I6Tp3Ow2h?lCD~GdoW!|g$$i1ETA5;`*a`p z&a4S+007fuUb_c{uVaGI2iXz&0XH_+y8bfo!1kD%cCrDF`%XY zT)r2468xXhJNvSt!yV>eI54o7{0ru2l{iw%p>UHSroO)lP|Q!>OUwXyuGYldYaR!? z$x5Jxvu!9My14RFNKn4JwS8?9gCNhC0Cf91(^C75x@gC7x7d+VC1FtZcwf)J!|-{X z1u1n;LCPrT92~eDx{L46hW)pBL7Q$pr%uR_pOzdg)P1h+8tjJtppW}yWyXz$3Hvz1ZqLd2$tY4$KT}=ltiI({B1QsT?emn8=SR$pQ4sC zvv6ZI)z?UVqV@%iF+uj!;0L6|G8^$E@IHGHmgVQFSHhRD`Dm=MIO+?ZiBmzn^p)T{ zzAN%MN4*ZdI`f1qsORh5$tz4t^aFoWHdBjj>lltfa(C6!+%fSt>VYJAo}mN61hf&m zyJ?(S+>YtUJlO8P%lrYFgjbTAe5`iTlaH;8&zx)NVrId=eFB4&mVEy{g&@c4NpYI3$e?t+dmHLDot+h{hObpAJ=1su!&y)JH z#4-uH&CTc&bH^%~^Ik#TD@1n&GPl`tN}{~2$=bwv!LRBGc?dr7kqL2u$zET0QIbpl zY9<-Quyd$_TkuSiKgD0r`iqV7-hyeMLe^txHQ3^dJM4s8*)Dn`M|Jib;<<{PgY@t< z?LF*4(USBSgMa8~Yb0^vA(1I(9}@?<7Q3&4DL4&1H)*mvQG@iq{t?L5R)Q6dgW49y zKBEbK652QuH^t~5^6W38+(v71hj@phEnu#AM4OMF^L^5zU=&+N8)V$D#)j+ZYqh)D z=G?NdRBT1~EAXEV1Wg>TPtoS2WhvvFRX~}%al-eI;@QfVNUDO?K>~N{t9Vzc#rZdS zF+Uw!Ja;tvf_W)alj0n4_7&<-j!XDpI@~i*h`1)5an|k)rx9nkf_$7+7(C}18ROW{!M#b1 zF>|Y!Fq;v4e>0^VDWD!vH&D-4K;4f$uCvu@W+~;l<7@IjzI$LJbJ!(N4Y(h2h@0Y^ zEppx?DZ`mY)Ne#V_bT$P(JW}?499P)VamV0Ucw9QPjfzq!CCwV!LDor`l4Kot`F*? 
z9gY=7Fk^b4HC0DEOP4fCr#mTKg7bw0&t*rN_>nJURbrYG zUg@~Fz`nwy1v_}hveiLL??`E?CBTnlSMU|*^gPB|#cLkb2BUlI-q;StG4pEnAHr~K zVp?xRF}0MC@YT^Wbl&kz-6Fb(g`t^flTbhWn{gq1A|EL%02j60;7{%)PU+T7oPjyr zCOp9oOgPT@)DeQN4acX*+3u$1DklK_L0kPJ|K&hK#%&E19gIixa7Az;xF& zrh0dK5{1v&Yan`#XvxAe$*+Y`o>3%co$^kPlPA;d(QIZ3*Vo#E{u5{A4r1>s-RYx% z;9oKSaIx^Zbx{b2gUNne3GH&|g}N8Kp^De-=yNzA7?hNzp2NxVal(CY8ziKZ7CWjt zsId9H)(~o-5~-bJDluo-ADJVIFw8};6?Dc4y0OYC*W)nj=#n=NC(^Og8tuj8sc|^Z z#3`MXL)>*^7`9X12L-^%q*dB(&liSJb!{Kvr9T_f#BA-Z$%sqACHODSJxy=a-_;ro z3{(kT!Z)an`(bDu%-8F1Ht=vSIMeQH1cjl*3i*mrH#`h(WCt-4JJ529D7?ok>i@x= zFq`e~3uc@6Ckf>UiQ)&<611GPcuSq5kGF5q9Y7{ZG20WRwyKz#T0=WW^)owqaAG8R zoV1TV&-9{OFuUdEfj0KV~{ zgr!U+uVg0c5#$_Hg*=NLjefDVT8Y^fEKF=*3(L2}Xkioc(HhJq#VyQ9G}tCYdaEnjD%j`nUPC21>cojo@Ht^ zX016IcGN%YY|Kq$MDWnM$rKG^o{iDck%gHe<(OfP-TEHpuC#+4Cp4BTyE9yG#o5*y z@{#%Z;BZvbWq|_BaIZVADqR5G)(48!jn#2Wv@^~U{2F5)?mFBDlR?L*&3u)d7x`C4ReI~_U)3M6vq$bKi>#b`bm@XW^Z~X!4ReEBIm-m70Xc+eX zBk>dW_F%O`XU?ql|I zmxTuE&Vx18y2cj1DmHs`4My33SuV6oD&dOIXl628>3;9LfKtMf0$)8kcaTYChUMM~ z9+!p?+1!6zhW(0-W0D-lLw7-4yy4r5^S+(sk>W|o8IkEnVOoZ4?9UMAf7G>Y*mv!0;n98_CZF}Yk-N0b-om`JO zE!98;GB1&B;7F?&^A${l4z-){UvOZOzyzpgXpMT*`YCONUU&wjqma2(y2P|HoQd<8 zTUseS*HzYB2&Zw!^rhbM`d{#zSXpb&pT>U<&DEh?4EL2;0shs0fPXl@OZD4mJhxJs zLLEL>3AQ(*m}0^#*=3c{F2QDe0eZL40sg~I^^e8x@z&PSaCUNB(oxTy@Dj&PwJr0_ z--xbYMVW5lZ_8($3{=<4+IztC+yZ7*It`U&55p>aFLOEn1@=1lKdG?CVQ0d6;g z*4cibmBLJ(RYnT^6?8+NaMJuM+sm9GY;xVmTga?Ve&Bcw=drQWE&MIL;!X+$(T?QS z^bKV<=}J293FvQ#lgv%>7*)ZpBLBv+Vmay_><=$dJ&n6Uz!-+R=r_b;zCOevgvr&| zbo{^gV`duvH`7g-k1`b#RPj{TdwQddjoM@P(*zIB-^?W%;jK>B|D)(E+@rYOHXK|6 z!Gn83bd!~lGc#vqH@LgIyA%oTUff**BqSu8WMzEL%x<8#Q=~{K@GDxN#i3C6&UfW6 z$X>fM=e+Op+&3niKLjt6YAOY(Ye`6V0!_8L#uxr%(pWOj5TL{+g<{bH?S@b>e}k|_ zSwsFVjyGE>l8}_hFx$a5amtcLVxY;yUxV7^&`U+=o% zuLD><%q5B!l+Lk<+@^r#^Lcyn^%5Pl$=5UoxE^B1un-Dz<>=)SpyIf{#FD~Nmf=^M z1>XLsa$$iM)<07__}SQqw~Q7;C(U-c2;=ga^3H;*sH9XDRDx55<>r1^kI%JFjB9PH zu8qvSueY=MC{0jxehX|?IFDOlz_1V8K$YNDVh5h27AAJoZWCKkZ=;=d$|!bx5pxKQ 
zGrypY##>=ZDB66`ye8CxrGur|VR|w@S$z=NNJSNr=qS5V3sSQv;n-uMq&~pCTPj3b zsMj&6V0ccm<^`d)u$xd3-IKcGZB8UNow)12j;4}*_qzGAF9Gl={}xW0$}veTu6gJ6`n~)BMd)1IRH$U|6eS zR?+b;)^<<((`|G0QHG;2$z9&<_LDNfjy4~tU-+)@19soOH6HjOypVWaDKC$t zPdU7xI_5an%;kIsVA&7{Vx3*%8z=mh_Duf3{c5?;8RNM9iIt()J*}lPL^ z-zC)`{zRSCRD3fXY?^2z70Eo-sqo{>pHP|573VFcg@L;sQsr&8u^}3_Fx0frpX5q> zd+D3fo*7`*2)o`u@5a2QW4Z2BP3wl0MLf|Pf=X6~D}_kOu+d%Flc-lrY2R;r9g}yj zSKsURjTJB|yJ@J0*fjJIt<;mn(@2o(5-S-AttF}{qlgG_B01M@haIU7QmSiJ%rea8 z7^7WLrsDpKFPyDCk~@Kj&|z^a`3kI|j|G2M`haL_EpJOUQ6K9Fco&YE3i>8z39*ULBthmi*Q?O zs7E-fx&Bgm(4}1U!Uc*8M%8dj-+gC@nV`leuW@~a z3#?&$$TPYiLOU!nR{u~abl!}U+^Ct{6x?+WL8oCBTo!l|+^tPb{=rTI$Ki1^2cK|8 zp+7^Pn6~jzx~fiUzMOD@Dey~b4WVIZp2eor!@cgxWH&29d*{8xcV@SGI;wrNG^H$8 zJmFBFbDMj|9hwkJ6mmQAL^`NRWe=h3?`Q( z5z{e)voD+xs^@>AUo^`awC=(Epozv6`hLM6AI`z((P1MlN^3wG!Ze~Z_$m98y>@V7 zuv);aAL1W++XOe;o8`|`f1)pVhA>tA3E?qKa?FfSmG}w45||J4b3!{VEHw6HV83bG z&_G%dR$4!p>HdnId*m*VgNnIZS`YMl#y8xHYcKEDR=XLrQ^*npxEA^Fz~sAGz3ktAkEztXzfY3Vs3|i(1QbU5||jq6bZDy#+(N!xrEq zVT8VucnX`sFNO;@ZJQgn(iH7aYMWyoZ}Xf{inQ+PPjPF)#lWwgl5#R;(~nW2K-H8k z%qY=|PiLpuebhXbq)NFLg9lPIxQ35Xx>z=TICgojqxz{`n!Ai^TBoDAS_He5nVh!4 z)gUeCh)fFy`>NrhQA}%krehn>v;TF};`)R7;%gR~C3LUY*t)}C$bX>v)rG*P;>fsI z9N!m7JhxwDyrFALM=;-Fo^z{JOCBl>=l+!^hg+Zl>W@Him~Jek@A*s1K_<@nA@*he zK!ccAcfS6WaAN{PJilG533fAEK{sK5R;a&MpTuQWu5uzD`*I#L50ZJq&WUZ#rc7pN1QBT$81lkxy|5LQx=uo?R&x|CR2e(irHJ+}>udn{_^ zy8JpP4g@fBgD*o~qj8s3T+VIauke{L501x{vrF!Yy6o_13zGZ6E#eL4u>KI5=pJsG zj%QwApYnTdA5cP5t-q+-u!?73_*UTn>40LP6~Xf6_ zTjG+I@2RCOvCrVn6Q#%t)HEd{-_Cz<4@1+f`;NhSU)06%!QNV^Ku%y&sR&q7Zso_l zH&PTIlYLw?(%n@v z%`0NM^$1<{)Jkj0yU1e1d-O7Sa^ML2H%RxcVgk(>N10#`_<$LnKg<1A++#4x76Xwj z_#?swu)1jl)Sa2dEXg^6En_b7fw;`d^j!egs6{Y9T|@1O-Q2>X->v6_oylsvR%u{t zcRvk91Yf5VY1KWIf+4U_9iiHVOZEr6$G?GiX*850;XQ}J^mgQ$TZ_KyFCZY7>4)H6 zFdEeqtI_=!pE1E4Oda`_8< zkDmTeYL(ZYOG_C=I_{`Qjm`@hC_kYgXrXtMXBdB*!>qDmk@;RTbE^4w_!}|1Jxfq$;{kfY?7Klqpb3;YGv=k~bc66}VaTo7A#8g)Aakat=jA58|HcK1A+>+~R*Akq@C$#}tME_#E 
zM1Sh@t!>0%ZM6B?H!*mFx(l}I7vMmB7Y{*qewg!uB)WU3TavmkOZ^#!tgcm0plZyo zwuUK}KrT^V@2&Jhh!74}61Ep@Q$g)`Sn@zxvrbb7EB^z_6m*SmlxQqfRxw$jBjpBs&mVLexO?l4*d zAIH7{9_pjkn{uFu;E%u+cUn<^dC2V2tFxbMqohq}5*+OO6aFtyKgSBZMj!1X*&Sv} zv3|G(lL3n(XX<;^&8O>E#UgZyX=3b`=Lp{5QFceZfuCz9Vx3HyRL%^;^~66^cdi(y z<4tCkak+d0eIT9&R$^}<7izU=uU1t(rbh9_^(63~j}`6| zieS4|!z>xBW?g4jr~c%)lA3@`n30|zCNPuK4>)gG!p`Yq{V^e?@T4ale1*v<*6zWc z-nH0pw!{$puc!}+Y`vvvqFRp1xLw`bT1Z!5hT5lz4GY@v9G=MU<`)&zH7DDy z3o{)_DGSYuYLs?X3z_F&G)hMYgbcR2>66=+A35xk&%<2z-uAuh{_yxD%G1MDK^57y z!P@3wGm8#L{e7*$bu=G7mwIr;{SD$Su<2B7rlXmb`jamVyChr!<(Sh_IX2@ z@CIKHYzc6GkuOoShRvabYD#nblc`VYZ@8CJnfaHuNneTETzp_<;1C%D>H@@2LCz-kQ3s{B{^{({4sIHJf zUKgrrpD;`1p52@GJwVV?gLTEj>L{f?8xj0j`^mnx@Q%(FJ<|Up`l@f_PSRJMg1>PK z%?#VF@OpJd-UVf4;olITix4sMwI@mpH%5JrACK85X_yhP7eGR+xmnovT&uFr=Bo-9s*dpvcBpWW;A zvC0M83GJ1#(`qAU`+Lg)@j!OJ+$Q1-vz9OzQ)A9aZK!v6M_dg)N4}JFD~)f82Bre$ zFK?aW8_`4`B{;0b)C=q;`D{I4?r4oeb4Z3eO!Pqi*%!b*Xq;g;p4djK-IFF_zHvM| zGVNbif6Qld@vZ$4;99(z+a1(wy5HX1U0OdbuVyFEE!i&GeefqdPxLdIa1qkKsxy?u zY@$QjOK<)gM1xqS2}_gy=Me)w%dUtuXjnb$!V%rQ77O<-ob=7^2?m+mHF zFSeosa$9tojV!3C)m7u@jnX`>z4wdU+gV*-tf#mu;`a@IOw$g<-@_TXm4*lVtLimh3tNak zE1KMsDa}*3Ti=f9FWpx6pmvRT>YRT7v`hE+m9CM6w?jRJA6zePq_GYTi;qDMF;D74 zgTxRfJTd3#hSC^$HNQq}W&KH)RL9Y4Wv|iBRUavswK$5*b9WYQdvE&vhoYIM;8!Wf zcbKHyzvZ@dROFs%ixsGB!epsS!O`LYY`n>mM(C%cRo*bS(J$4^%JgP{ zN?En|r*NfGk~pDuXXC7RE9i3Chp2nhH$-=<18UDjYn5^rN`zPwow718i|~_sgFL{! 
z*Jk4g=^p#(a23#A7{gE0fi}b}^fdIgPQDI4puzAA(?MRvA7}p7cIty&Ci6bDkDur{ z#CF5%>*aZ!*`_!nP|^Q6?h}@#Pi8P&67F*UP^+;A#kWTF)cSi!d|%uMNwKx>IoqDC zLs!*06*O|+MXEjkoFz}fe4%Ej5AK^3c@kY0%sZeearAg5Qk81rdI$E#3N9^{_OMr# z)y$5(MMwov!Vcz;F`?OM_oqNCbDCad>xJ{l2Zg4fgR?kTp`>Z+!4L5`H!yw}@s(d< zEMV>WU&b_Utuz~F4BY%AIzt_SyOW=_6ky`cG-bP?KJeX?B9e2osq_h^P;4fhNAvYM zcG7;4+--GK?~-Zxl|6jQF0hsPQJBu(J`9>x`P^}%`H@qJ= zR*s7IUAd5#k!Y?6-3i3;{hPHirxATTkv_}XpeCxXe7)!f){o$jg8wl?nI6IN+8f+n9hWjI zSW9hf_9rh>^-!1}Cv0<{VWT)r-yyb_CIG zyipVB(>QDB#0L88L|UN9Ao?FWL;PqeDP@G8`1O3W7U}pd&EhuXG&Fy5LoNz5aQ`Pv zp{Sx^F{`cGXtMdM zEs`olGKAo`zdJ&bHkTG1enX+VY@nO>RxFY=9@RvhFIgxPT}{A z%WUrJ&J6eGWXIPK;_iXjhKPCls<-5agjW9Cs-hOr{Y!)*=Yri|Id0I&!;uG0Q zIE~xEF`-Tcy?D9cv`PBPd#iABY%Z>1upav_*?=1;9icyh;ShOqxz!#Ob8!}ie`Sg> zf;kX-Z7BUFn~CRZb*)mNccGP@=UO-Bl6pctjdSd$EitR3oag>zO@J5d1U1A;as|bXoc*K72fzZPhO-X=5@Zty?MBdMt_r*ux;Oz%{j_$;=M*$;k3%{9Y%sDHFN znBLTM^n22HI9~36*i?}#6<;E`klm>S%!P3=QY-&5^_%aUqdq20-)1h-L)1Aa$#Im+ zJeFWZu+{YP%&X7~w_I?*_>XFfG_j39`+LGg___KC^ba>@+J;lbXP9;Q7e(Qx;<`{{ ze5x--on*6>jpiFG%Ph;Qjtk&7u+lRr;cosRFpOTs&chk*y|4n_k)-N2P=_@EQ(JFHa`@M{;S;3&-DCg3$xT3pQ9wyzYs!B857h3Sw_C zO}U}}YZpvT+a|XJzcLN!4Z@Iu)xOn9u0SP6tT)kexTorFSc^YzY}H@MuT2O$H{X7= zgS;kfkhz+!0yBiWYNl4zI?=4ZQZKZfn;GaUk2h1ZSD2*;&XM97MSclQ3qLd6M!Qfy z}Ynkvz5K^utQIg^AWOmdB5D#Cf_ zhy4V5iM!|hN&SaAor)6)psYv4r`v9+n!HB+OxuKp%#f72X1q6_i_zIM_*iQSrYX0h{n~j~qP)avm9tNoQrI&7I<=K;nO)XsYTR_agI$~l zsO5NnRE+*;G{odMBivp+6&kN};^!s)!tOGHy-R7^Wj7mPITuptB_6HxS z0qB(}tOA}*Z`RwY+wGZYy+KWWkeUt$r#vHo+!EDvofZ4Dja!t*q?2~;5#ENlF}zbr zU~)~DiA0CLD&L#U%h?@&8r`%656@S&uh5Dx<+QWfgS!U4rTpZpjM^8FdPS6+&`GcX zOX+0A81pedWt#Yvj|qLjoRrJPD5@fT%bZrQz*>VEYtPINV6#Wiv$X`~zGX3e;|{>i z85x+_e2|$+&*7J`{{tyXL10xVxo96~9Qxt;#8%eE8qq|J>^sZ{qCVdMMHe*D3-k&F z^+G!=F}@{yrH%wOb4HqXsY}lMg2}L=@s9pSJIbaCFN#85Igo73ms%<3H9B>juO<@< zUl$Be6Wsrz3ub%g3hRxt4N;0qH)L*x3;OOezX`OXsT2UbeihFHp8DG{>nIOgugx%; zPywQ~dJei9l_d7VcFJnuH@POPqjs=1N+XgQz{Xl6^I=aF=VGe0EiAL63skfcMd-+z>aWg60Ev=_o zTG&MVhmIDlCL@ZPiW6|A;fYQfFzRXR;>Bx@MATt%z 
zBB#PA{x@_C-iFiso8%Ob3hV1VoV&OUD3)4wEXjXa$OK-b@r`gI^((baC|UGTBdB$0 zDPl?B2I=%v-&nMRjPxZP{41x5e!#ZYa|$ejQ}G$VDR#Ec<4(X0b`sS|YUR7hZc)!_ z|61c+w~tN-c1&L%AX_=w2W1~Zyoi<#w5r0eK6i8H>VvdJ+_ z0K1hFsWB?SjAR~~^Qq#v={A!cfnI5{{(Iaj2hv; zL)>xpy8E1pk_RHgY#;6@>`LK1XY75!f93|Ul?IurAhNKP)I)8|?ZXMxGD0kR<9MT! zTwCu2wg&hWMSzCVF56hx)b)Vi`Tsczm}LGBo#&jugttPXdQV@#jnqyT7NxGW9vhRe ziIGj|qc_F0fHOoVay6^|rUhb*R0VCL6{3@-ZNKs|pS6`dIck^DVz0yD;S9W#wzqy39XlpG!bR;>OT;Lqhwtdc=D3tiCce6k^2E z-0C8QX^z=pcd#GgnCuRH#b)=PVoR_*h{EL#sQb8Hr_t7wS^U z!@c4L6H|4shoysN4{nCojT)&Qw^lKP&X`r@2>PXS8uK4r1Fmwd3Qy-3`{s*YzOD9v zFCSkZ_t8!h4!D7Pu0C|k#iadCK5ue?5UCGxR5F~x+E7S0^ATcDF;5fnu1g7)Qlmoe z)Fi4W-jbIguNLTZPj{;Dcm5XUwEPHI=!n^n9IjW?r&!hX|9lKtqNtQ2xN4)`2(UI` zJTV>2!wB0Wro1)$F^dyh_NoBV3 zm)S>aqTCKGaJ_}2rIqp~GdjPt`q>=;_tWz+WpX1L$d53Gs+4meZzrz60oY#L!F;BD zHRp1*@R>6f9I+*nE&Ms!6q5#p@#$(^Vv-dp*MhZ8C^a&5zC{Q#$82lq>DE`|kLfF& zH$32or^piQWx^-vZsGy|VWT(fY+r+I9n;j>)W7;X_L0|gb@MD`=Xjh}X*cFs2QE@HZ?p1lVB;LL_o#6p?|m+d^YAo z4m54*O5AN;hR@2mjW>C`))CZIe{AkjE9l*YcuUm=FtOedLfyhG`~x@>=gLR=7H@Y_6Ybeb$=nqfd9d6@rkGz9HBmh z_k>N(n|$Ok!IjSL*6So67ry8!-vMuKABd%$|2b(hkzZ!ihZ%GQeibUF>aJhJT=#cS z)l-6Pi%pQlk(c`o{X{oYJGp)OOQ&iMC-#cP3uR3XuJUBThUO->4$lb<@aF4`at-xX znp>U4&17BP<6Nv>m2~T=3D{d&efAvt(G}3I$=zy4Zr%Zr&n`kx#H+rjyNvom|!7D6Uj!pyN7xEYAxM zQ_}!6{9*$8b6iQfh9jHLwr+}1VJ0|UJIgp-JDg|ZC$lQrXRUG%qiou4r6!hVCkD%AiSNZz)H(5bP65W-=TvQyF@LxU?Wq@2O598YX=fJl1vk2pFWNI6mnP>sPsl z>N(Qc{26S~E2b`x26+ambJg!~78|y=4GhxTFitc?T9o|4(T>@v$HL}%AE9*ZbA>IO zZ&3H(-)ste0H)-vg(>BuoRyOG@Hc_NH$XTZuew{{@sM zM9pXPglp6lOea{;8q9XF8iZP_EgeGqy6{3THn94(?5m}7R#mcwr?`-;{*V134W(j} zzXhto4ZfS^e<}Bf4S4>Ho7Kuv>6AN#rjebwhJSHomFXy2xS3p=JVC$BcNa8q4#YXGtDpq82Huj>Ju9p@?p-`?Oxk*D>9C)1jL86-FmrW)Z6nv+9LbdxpP5hC z`l=~pgVo{D@BvkkEl)Q0m4`nOeBhJu!5O4>phEJ8-ErtBNEGtaC@LGY)aG!>Msu@0 z{KDnqJ7%GE(hjjgFrQuSUc#A7r0GfS%n$G^h?n4QbPI1z%d0VL zI?>hh!>DI(<_<7N#HWQF)U9b-{i5qIX>i(zu>%nlDU~FXB}csBPX#p@HNoNY#h7+_G`t^uiOxZLc2WY9aVXX_{=UB z`y5<^p7Ay`gKclrf<yXX4`4&L?CSSq@|YO!=5|TDLBYK3JmsC&uI_ZS(CV%snr~_)E6OpDE;Vh 
z>`t$Nq7w${XCRV>!G7u<*8=6*8Gcub8PBM>1`o!f`N`{@ zN6Z`axKMS+X0;Ou!lbpDnQI#2NWPe(5BMTx>ZxH{pqpA#%&_dF3mkCSJaxl`%vHR5 z+7b#&*=b^$Og*xy_=ixn=*@n@jKJxG9r&sE2uu$De|xnj`+;|8C14y8R`dR!x`wlI z2?Z2Y{`hZjJm1kcpp+Ar@&6;&gW*9*{*O-dMCV^n%H^(;KB+5nru){>gP4nfi_AlN zOCdUK8H(ml8UNFMkk#RAa57Y#`l@eD=%;&NGuoIR&K$4j>VUG+x6 z-R5EbWm3R1%N>x@{r%OS^FQkm{8RoB(J+*h8sfN2R?J4@)>GcvJt&46E_61 zVVn@Dxj6?n3csmGD;x7Sd19UIG>%VZ*5}h)@#NR6>Nl|Udk^u^7rSqq!08PWd})KM&33?x~+srloz+ z9^)y@7MO;~Jm2(o9vfSp-cJP0j{0zPoIXxRn)m1z!c_heJf>zh%2lh9OE5Ef0b0$w z495B8;4`H%j866wchM8B4Lg{qrv2%*dQ-w6@Su zS*;6>2~DUvWRz_Mh!9G%*B$R!51v|o1GqP;?IsQff3?06t@%Oh7J8U5++YF|jcMo~ z{VyO9dtK$5|H?*k2e<`k{on?iK%Qpy6^X`Y&7U3n7@NB-Hal%HUy+l*5xsVUudXcp z0lcNA z>6{|`V*H?w>NdWJ&GuhT{9Bu(`?E9Uq}b;}ux{Jt?_5xW6xQEw4j16_HJiUIklKk)~GXU+U(J@kBES)-Z&9IFhS~~~sdZY@W7pNPaBN<2S3CDr=IaSB zLux)MM*`*y@E)yX_lfJmmAK=+Im9M<229hAYK2;&+Yk1rKkRoK*Fbo@ zn9DzjiO-vd++wlhfm+wVcrlG#4}scNFv9PlIHds^Y78(Vz*OfzwE~(GjyKjw zOYxTO5YB8wq~3Nf1k2^`1xuN?VzSs2{+(jC9VZ(*&u72Q-Q^if^F_(R8Z+MUh-o5R z;h))Ns9|jtt0<@V?aU2kAv>g~uDXyL$~v8ok*V%-%rb5UCP=E$JE5wQtD>TFq=%0} zKlpEO9aus#EOIrsm(5(rwiTMXma|6-%jt8KSN3AeYOY&q9p)2dlZQI{<-TLrh8Ng} z9IcCrs|4Ryz0Y?|S}lnBDTPs|0+;%aedf{YMjJC7oMewXo!p(&{krKs3C^gs;k=x& zFv!(m)%`crH^B$uJYuPLxm=RlsfJ0n8vp}SKFTCla|ci<;R+Yh81xf-Y;{U`=h;P^ z2@F#x$2B=EVToBeu#xU9T*aM!57)%#V*cZb7wW3f!an$1=?-@y>;+f;lqTV=Vmvum@!oE>3C`ELXXfoJ4F?v0pb)&m3h7v>nE znshpOvRIXvW|U=saXh($JBS*9LeF?aa}z=?>`~unTtCN9$F{*8YIDag zU=m9Al}kRPKXC#n&i&e3ykPaQHD>4B*}`h3hPcwPN**YDRDIx%qOd)aqcGK&@)m~P zXs6h`r1|JGm`V-Np25F;n~7rBX4XlINmvMN?yiE@ao+t4SyFq5riCNa{#q+O4Mfs* zu`gO~`q>{jokKB!2SzDw7*i6~fNQL1VmQiY6O`&XTZD634<|$=>`~@5t8XEdN@|74 zy!#q8G$g`^nxI`*shRBs8hJ!A03Y-5PQ z-L=tbhVlh7Kodj1!Ey!3!gmfR_Gdd4ojPjk*@YbEsyULyi7iH zD_yd-s@Iv;fs$-vqE7ZQF5s(gt|@53o`l=QJ`!!OOU#0P`E}L;>I9mM3EQ_bx}%{e z%A8}qpvM8mylLl(h8bM5XljtU0K~F~q0If~bRiaX6g%k);yreM`$eGWKNK7xVG7s^NaY~u1H{N zhd^hdC1zpR30QCg4I*0^Ax8`vVV0&yp@f)Am*>lQN?Yx;SUB8!OGWM?vvl%a?N2&n z%xBA6b*Ry%Xy3r@bx_halV*em6SQU9M=Hm{u zTFIXuo8R7 
zY2GW;3tu9KCr?s0Dd&wD1rabsno>B7xuIp*mm6i-<<<^7?HOm30T1Fo;BTle^gx}f zRASmWig|EvCh37c%{<`E@mx|2YDZpa_#HOD`IeWgCS{BDl->CL@rBD|H`A+)bw&qn zwm;vVrqz(j#r$mG`-!-O9t}?u$y_()Z(}(ac+AE7?ZuN11eWN*)DHYT{des)f5}gw zMQka|iT;~wE$jn}j9dI_E>mAh{Sj!Y_F$eG*Yy=r4fURO!;CeyD);5tco*C|&@LG3 zYw9@%(h_RvIp$kqnEBEBjw%7B6Oz~=WH;7$e}*H$x3qkvi;TRbi|0HwFqeCe zwt_w<@=V!q4l};+q4sO2H;7kvlM;L0sK9)Hd^iWyHHY9fH_aTP1hOq#&OKM%02Z11 zTvLgod^KvX%EOaKd3TuVn)a1HqHXf0%FBWRe8w!$<|-|;&j2XXBrBYuZO}uo5&B@& zQOmOPS>Rl)bwTf4He!zVFjL!g4s;QQ=uy@^Ga6=q9HT4yzrac~iYSD64jVNB)uLB% zrRB)rA?zHTFOBpBaE7_8+C}Zaq=bJlQu$&6o@Nyl)0&az`4`xpP({tj$t0VJAxzGn zDohneab-*&J3zPq{_$oKyMmjYbx^Uq31TVbcuJO7im#D;B=ne>DEx;?DhhCN&Cou7 z6>_nlC*CJ4w9fQ%zk5E1+UxzWP2qn=9o`Ph8Sz4i(17Hr(h>3u*+QA4cEHKLM%q>| z&wn8lLlzqAz-?tHT}jwTraC?5LU7RjV$ZGouHpWAB{=BdB5JuY6gT-Bsbj?$*ptai zuC5FsxA4oT7EosD=zINT=~rMusQa;MbQ1e2Fw9)9&j~-4$2x0U@#5S2z<`skW+>Ky)vo*g|Nfpx%4>HE6o6TscE;o-#4lj~&xvTb* zVk$G#_f&ZnlP%UJ&eI#sG5B`2Ipc@HO+jc%^L;Izn9tv3nQ8I&_d}K_i$lUV4~>;YrG`; zMSrYYL_nyc|1R`1_VHVEOgYQ%m}|>tiAoIT-b81H(u|SLEaDMcjXi2@PSaAB3ygeF zX(4@<2jf=b0ey%4p>G77?&sM$iBHrrMj5iT$>eB+2I8qGw!fAR#zU_et)3!ELtQN6 z=*=DR{K+>^E#We~#yQPesm?dv zYaRGFJ|&(_c`f#WhOkkN)^EeP!WEp~>gfELXzQF38cjZ6=F`>V_vg1_EzpENZ6pLY zX8cMyt?~RosteO2a}?Z1UN>Kx)42!~$G(&+CQWe<0yW)zP&m*;t1NE`SJ4DDMypM9 zcU>=BC=9S7xm#jCqP6p~*&cOM?O>6A2QkfkP}poGz>lGAdDTIy)K2KX(38|jS|nFC zR5|YtW&*4mn1;5=N7;=1{kYBQE@rSiQP^xZ)m3D@ULD!A9Qi@89QEK3HNr>;Jkt8XKhX|%t-YZAk6~Ns@`T7xZPoRKFmU$RQ@=hbpXG>Fw^kv-@u zOh-0R42$2bG}knuy>J=qJRW3u_fm+>!)WQTt87X2exNhlojjdc9(4Ezhkk@6n2T|~ z|1!=lkDo)ZSbv#oaGp8_F*wnXid(ehFyZ;Uv!}aqLbSO!gbg6tFcgtP(-*`Op%B`Dd(}(8 zbUv)VMEwfF_@4PYYHRg%_F`$wV>)m4@-(quM*~26G>Sh+osdm9u`sOV`_oc72L6!F zDO;h9Uh6x<^x?M!CZN@FJF_pyAQIf0;b3JkJc{>~Q{u<52b@bSAFPhH)1TGGs0qwS zOSkGUHElA=Wgp{4zm4eQo);!lE}|MmZ*faFPkN#7_#E0?AL;rHmZG5a9bU-&*D>5$ zifXY3TqOd#ifAj<*nxR|mqNb_acV$ahz&)J_3q9H^_qjW8ae8R{tV4UUfT`)?29rs z8t-vJ`?=Q9ZpXB=F;+Txk!{Y^MLon7U?$nXuw!2lAv7{}hAN^u#v5c&vjU^-@p5G0 z0#g=_qev~}Ddzcv?~>D%#_CeFeC9a)pjK6U+L%u)BYr0@W=u(&YVFXow7;qEu2iwD 
zr!{7Xb%Ag5aqdKN2m8ZGMPG3e=7BBU|CXH|bCUlT?ALw7WwolHtL?G(44QXfzPVc- z0v0>x1!enL_dC;&>LFm-V<)+*5D{`2GEunV61o1)1g#8v-)g3<@Qks(!Y1r`aMu>E zzH-I{Wj>D3)ZN-b_=Y|$MrXfH>4^@>IY}D0DSKnW z2^ixTW#5A2+>_)Ty?`I4L`k`>8Y!@^>Bc3aoNu1^Ou~oSqIzstorR1TIF3j=t_1C1Y!%RKh z^+6r2-vFGs&tF;_iRV-I#e8JJWObge7x_EP!_y$f%n>#AUFvRknp#=gEQI9mj>aKH z?dAHxT-FZhj%-F*>`ybg*xNhznWxzM)M9HBI1K(X%BD;JFNk|=nsp%cF3EUq`Ieaf zQ?D7r#NE3ua7oY|UC^|yHk~Pr z9YVWOB{&QLrLS_s`bK`X|6-I2Mx!QjW3jSWDAz-^j6FP3d!SOTIB`2B9(w}QnPS#C z+L*Y?b)8~mMUb=QzO+YrMFah^|-#d7O`(~bpl`$P+wOQZj$|Qytx}5%a z^j+@Nq@~R~aHO`@nbUor5*oUoa+i>?JNOfnY!kauVPT|M$ zbAu~A!`PY3zad<=EclIF*&wVx0S}r&JKZ?$Ry=|mx!)0b>P8x7x5^E$oQjZX{wf-dfx@kp&LM$j&4E|Uswv9GDC6vR$xy|f}} z9A#QLpso19Gm3gFwr4cZ1l&_Qs&^a{?0n#=d0chivwTBgW#w-)Z>(!PeMr}6=M8!^%eGk12QNY zSC^yoQtUAoCd_0m>peqm*8}PtOS0+pHh~_&lF}NvzOg;77t_uC2|gEm%$86AZa-B@ zd;-$&ci^uCk5Sz2nl(#~=l#qb)LoTH*#{$Z5_YZv;N3h?lBI%0QnM;Oa6EC<1WclP` z`gqV6FrXRznVN$A%}x(rOc1LR0Ajumro3X#}4U z&!kt%Bv?n<$es2S3ctcnv-@LCV@XyQILcoFIf?VhfEo!;!F~sNF^%vnt_mF^kCCIT z8!&HrQzKt}j%G-c6q<}}R{mwW%Wre4n*X|^)Dip! 
z?{TFt|3tm3v{%k?zt#Nm3B))Vq{qeW#^V6auE1Z!&t&?-4*W<*cTc)IP8tU5ay!(% zU@*8Mx_p>tQ^#>*9uTKFo&uISLlPW2V!DcO*<^N@h`Sgncd2P%<|ZF z8@1;A#k`!o0zT5J!1BT(N35`&4LTU~EVS}|F=v~@(9Ln0@!Yds=&GNm4uf*sYYvAk zl)8}w_c*nlT2O65eIl1}Lk>1nTBQG?eDaKxA$X$no-D{TgAJ*^j_!DPZ|WSTFg+%< zR@jxTV5X~+JcJ&A&Ka$Uz4|;mpK*h9njefzb~>1BpD?Eg(fAmJv0O1zxnb@+-Yr~L z^^G!_3JQz0;gW()R#KIVSP4ieQ33(7bxxR(@X091Tr@8bQL{<^-#5kZ} zN?!FPJF>RPt?^gz-!S{C9l%)rwj0+2R2KEZ^|$d^AK(zEvIM~uWI@&n%}07bMQXNo zDtoQ^m6ihowEw~Z>xb9RX7cF{pCHg*k=3OZn+>3!RKFqx9Ser>3ydpt9nT3!dRe=Y zz6fgrrishbKO!bfvac5`?<4d*B&7|e$B-EKFZ5ih;q{YS8*;kErqKV;2X;q0s`^Lz zW(2!sFEaOwx$-lPX4(;LVNZL#^Jda-i2VJ`LD7_JCE4o<(0|$NRp-h z_5H9`v7M;vc17Pq|89DY{?ynDx{%cjJhOsXhNZPaju*@yX9?w(>yHA%am*SfHL9qw zSz48ppAbMxX=dhF)MjR5Yax^BPUd|)jzy#c=3}?b_|4A9V&4?L(LG=Qeu8HQ{7ug}!8mV7rK{rY3X*PrU7oYudWle5hY^AIyP1 zv^lvc?2DAPTXLn@BJ3z>8`&XqD4s$tWG7&ch!b^gX%&LY+?Uv5MsXihqTn{)7;zeX zLBFePAx2vkwUKLP70RjRi^**pxy{tJi$}7QBVozAnM+g{rbeV$MykC>pD0IbL9Qa( z#y^g^h&JSduyEEgPCML$nd>Q`Y;5?=Kl0En?ucFtBf0OMXQ6X@s)1|KKdcW}&^r)F z-XHF2`YLy8%n#P%Wk6fl2{_n&pmS>T#7@*f{;__}7)YM8M#9G$jH|$HgbKGLvo08E zPT-I6zv^_M`m%ndMWAO+N$QeTR~oPPcAfQJ!Uwxd^k%81)NWV_JaLSJdBcI6BCw{i z7Y9ZzONmRg{#0kM%hL>2)VsrNXi}r`I$|aLQe!m9cTC(4dmzrybGEqIOE|4I;TdWI zH5ya2W0{$#SG3YB7OT^BsmokcPJ+=Jb>m9wI6W#nf?W&R225I@t2*`?(e@Da6%nU& zjZp#SIBAd2dU2>bpH&DRQ9iq>1*hr9h=ce|s<+;j3c5y8Kf!v$Wio@k!+VWX=72(L zLzKSUFnCQnLbOn7MBLmv#Q1;RaAIb$U^OVzM&B;~5A&XCg*xL$n8#dqv6fN_`(mxq zcE=7^%OyQAa)=GGrO;4TXNE2k#SIUEX@|HDB97(SV`3i&R8?3h=Fau3@!%j*1jp-lzYrI<04&Lm<94WpRsFv9kuaB7rG0NF)Q7j zjR-^X2aU1nIjja8l(f&>Ym^h`=|kiaOb@HJk}35UcDe?k+W_HDP?6`?TCFhPgoJSc zsslvtL6o7nP9#S?A>y^&-a=qbqd3HcT1j7_z6&S7L*Xc5Dh~Hgf{FA3YM-kwT;{r) zwFrL-?=k;!kF+OVBI%e(;%nXKoNbgL_OvVv)+WQ~-FaKB!0%Tib6EOl?I80Qd1w#A zDeB75D)W{QiXSAXXp*#$h(Gj$c?njKzp=+c6jw^ETjyG&xlhp7kk{nn`nA|&Y7)GL zOm;PFhinV=6WY1lv3ad##A5lm)Fd&Zfy+NtYi?B1=H^~!R;U$;AamX`Fl)XqLm9v2 zsnI#e;0Kvu()&n#P=+2Zc9RRbzw`Nw>((blGfo)g$zNO*W36$D`5JC4{AV;XZ)laV z4d*xigXzLX`J7$`YfL>i55k|3O|W((1RitA%&PQe)MlYQvZc)M 
z^i^9aH<_t=b=SbGOVoL7bZ7vg7mkY==-%TrQ2O?wT2QjN9VXhP(G$5RUH{#`eg%yn0n!t>L$b=e1Mn{c|Aq> z62ulTpD^Vr$|fo!@hW`^(MAi>W%=$_1kYwaQX|0=ctPxA{Go43b3$7+w8!=q_8stL z^50#@d?`S~+HxE8#ksh4B0HHMC~S=^^Zv0LbLFVca1P^xcjN}R2o49!dz-p1qUlOQ zuCkhC6sC@)t;(8WGvP1t*+Y_Ek6$Gu=@vdfS9^#aYETv z$FuiPyY3$<3ci(ZY^Rm_^kuu0syUPB=f*mvG<;577#ttj^ zi%JH^l)|tqVgr_!^TI53m&G%!)C&};mk!49lIO0J*BBdW5xS4%@Jar*MzQ*r-6^4I z*ha0sKgPa*cwHG(1#p?F6zQN;V8zHq^yIw;*~BIIik)8jFD=PN-QG^_vLpu z@~Hy|PAg}y%*1q}7(X60Q=Tdrp;$NiUvX^-omVF^*A$$tPEBUk!E1UYvMTCqHc^^u zXK6Q6Kz6}tFjn<4LGlQeV#?dEsr#kT++JOIiP)Ls%Pb#}AVF zD$lUd=?8pqT!!FQo5&Agy@ahW!DvDpMbjOY`6>+(lf7M}H|$J$EIOxNbPpHM)I9cy z`5$cuF}j`N2Lpv((tM+6+9DCxlYNc2pZXuJy7DJ-Vk9R9-rI_?b~OoKJWIHla^91x-~C@qet<;$kXM%(u0c)Ew_qp*&qg_#;l@ z!n)H)hSkYt;qRnR@61==P4gRBS{*^=2G85?=)TerCSCYo_EVZ@Th_ zH&j20Hu^&(_DC2pNqZp!NFnQ;+lr}!Eg+-WrfO1rSvUmvUGI(KXdZt+aF``wKCOs! z&5n~lFijGCuz=7krwAy|Wx$i>Vzj~D?<$oe1A@vIGfa#3|Bpnu)nFw=pkJ%mv4yB@ z*0b7MwHGX5HkWUuf6-5d9~%?gYhWmAy}p$1rJP7i=BC9TGn4dZ@UeQxRnwlGRVw?r zktN#tCEq*oyKx-lOGKrSdPGV};-!a9M)<82w7*L)wQ=Yiur_|AwH!>0F#dD=VddOb zCE_ubijBld24{JxwN@*mfh&o>$JEeg+D(+ZS^MzD)VG|;S|?XsGnO987GsY3S|XQG zGAsr1`98W^s2ez+@f{{BHUyFHaj%r0bTvM&Zw9@KoZx$){)2DiiV>BCyJS9dfv1uu zmD}hZn$yUN4^N<0Dfz(%>y7|O8*D|rhPH5_P=ae5xh^jhw!02U-H8VruUL_rIfIB- zpcR@UjX}oMdRlA5#En<>NM3bZ_G`^e6hw*Wb*?`8blwP@)7tuqk(WU)yfS0B?g)8~ zJU3^0N4v%ZXKO{wqecb2cS1ve#_r&GthVswb-if#2z45kn_p~7x}uNMcWGsO zG3+rqG3x+w^e*Phr+o#t_3CmXMeaEZSscU?%3-w6%aFP_a{D*vzE%6f|RhtrsnK}PGP`pggD=pl?WBtct@|)!vT^A2x?~rwOPUdOktTeeiR$BZ~U-|4bVx80`BVVI6>R`4q z9nqHA57e4;Icz?%U^L@yfV1f}^oh#(;8*xb>_$+`GvyNVQ+q_ChP%DdY;6SbwrDSM zmF}U8fEJl1=P}m$PQVY$EYLdZst@N;HZ6l@_eibu=6qX1Bv0y%?V#tKIaFVeUc!E^ z$mFl=#;%dXFtak^3${2^X&C$)!7_Ks?BzH3o$|3xySzoV(9npW}95bi5--;KA zw=_Ug@h3*AGS=UX{v3M4z0iN@E8s+KkN>K-EpZ6@O^@WV>K04Vw8vUcWsZp}fOCab zD=nd-9OwOyNF z%z&G7^T68L2+&D<1|B-P2wjZoVn4PfyIojR`643!WJ;f?7tBoKj;A*pC719Ivz`2I znbgY|wVeAh_IumGEn2$oJM0c)U`Ky#Fpx}hma+>oXSs^%RwH1n-s1nh5Z`LRBP7FASK8N?)MLknV-<<@Or85Yt29bZ0^Ue8|olx}tt2%A#yT#3*RJ 
z(r(Eyxhstvs$-65v=1k19<=%DN6Oh>G%uxyUFl@}C-MZf(qj*ul5e>@+-JzU^HHa; zliHbjxAE6ZMgAV*BJbD!QFcnj5j*W;;3G(g*@|Bc)1IG14m#lyur1lk)!I=)NFkfD zYuIu4K)tCjk*Z|RqUTCMX|(=H5R|&~WH7WkgMt_W*D z1AJ3bY4qu3aLAcaa1e`oXTc(vEzilQ&rKzy z*A*NQ-qAz25%^n*1pjEgIVGWyatx*C21rl%YmvT23hFx6V;iUuZMf@+e%Pnr|Ehv- zDz}N7svP3pW2J?|fsX1Eu!MN54yQN@7@g2mcmil7b})WPsoX?$StwUHhb-We!nV68 zH{YlWUWRIl<@HAM=9N~V!&Xs=qKVos>B|1{Xdy<& zg-|P~se{$8p?L=mM>sYl81M=8kvMOC3jWKn{0%&bUXpo8c&qn-=}dlZIBEOl(s59w zmgr@)d5H8^$up1t!XHGeBnwj=gG`H`02{$(uD<`~-&D*2By zJE$S9J<%DN)s^ThA3q0Pdu|e} zhv9JkMa~2NY>1mBtT9=C<&Nazg~OS{rCRnb+h51;lt zCXbn2or$Pz`59%FO5nQp7ju*Qppqt=So z!gF~VBPK~gen73sj|R2udJW!K1vET- zR}P94-;NARPLRenq)2X?a7uO_%)`rEEc=|fp14pcNnK_fA&V}8W{~;)6FuX+xv+<; zFY%Sa-9OA|{13Yw*<645CV?O9IB`j2HcEW2GlJ@1c}GM=lZF7J*_X9(BHQFBBL-vO zJa91fH?rccNWb1-fmS}fD0M^ItbR=UjPH0~f=*xdc8s zP{{}zJ=ob)QRA#uj?Is5w7WxnGV|!i$q!N&_*@xkev%rvR;VZ0ZvGGWdgMZxPru=> z(FF^B|l09h7G7)s0wpE77qz0q6~^MR00bCNUJ zET&%4AoXtI1%F+1f*uwO3iY$wp;>4dwf(_1_;P-&H=+?S+n5fjhdV@7^Y^nRV7;hZ zcYd1=?26oS|5COGf2Fo#N-2wB_w+4jgEhPk%H>#7ec!Vu8@1q|G?<=)?A{&mXI6eV z*f_+uN6mxh#(4avA(Ab@P`X0y9W0h9Oa8B6Q9TKaXUp=J48xjbZ8#Vbzfe{6?uiDY zFkxRg;=A_>^E24nmkgqerCDK>)eiB)ebZwGJ(Z{!x zXP$D~<<)LG>ll4eS`Y5n9nuctKbbf!f!Rtd@r}kx0**S#K&}?tp4-@14$je?%@5=s z{1{zPF0U;DmpK>RmEK#+b_azrETAj6x~hYyKR)^J6|F-esaywl<7XNRKvN;w$mB9S zKjk0999Sg1N+OS?%c<%Pau!=exu4k>!_j`1B+K6;iYB=YsHQDM zZ>=xSfz^mKXR%EP21W~en4Q>QyQF%<^93K3S{#uz3xn^hAJ+9R_sx|LO2ZJdyc@Wx zT!*zv*?`x-rdBPb>c^X6Q6s8F5Eq z6Q6{?lI79as-|ySLUnn0RtEBJHXv$qKIFsDw-kVb^)j><r0V+ir0!hdMj|s)t7<*dUMm~Lr_=@}ua%GLikD=EnTOPXxSkoKSE2f0 zwalqTSWjg&>pZ>DwMT1-93UORTx%^8ga^#e+)ed{8S5_$^6Ty49~tsCO@k^_Wkl3E z88{w0M^rL=LNDQ~U{j@X3Z^znI0Nrt_kA_>e(YWS4`oONxHpU(xg_NvlE*83CjW*_ z)cg2ef<^vSTsy5L{2EGTdIt{kPUVNXRvw^K#aaF`b4_lBjpp)r)>=9W6Dr*=RS)5cs;yuC5PAHXEGHGZ*XX=ZNHY&g$4 z?J2>oS4*>%$?V3Tn0m}ocT=gjR3lO>Qk}lRj$kf}G@V!4BunrX*8v=K7NrZ8udFP~ zDVMuKye3BB|0oJmi9ZMMpoX;u6Jkyxzu;f~l3=PY1~m3>hBNii_yKM-xQcdMi>bw^ zeOsQ%FOTKtGG_27-Z%V0FNr-so3Y;Nc=M=sIBt*?qOO4}!ejPt=AxVuUm>?Mz5%?@ 
z8+wa+b}NH&y@)Z+(hZ=9?-oj0ow0HHjAyQSLF*v>WK%&onzHv$o45j4C9G+H2VOD;ld_}jQtr~U2+M~%3742ux<4E) z!_f*orI9e3QGy7j*3e#g$8139SgIiF)}Hypil=Sld?t9ycc0H-lFQjjf4W`W{#q zFB*yB|K~Vj&QY7Hvv4K2fZ?(BT!5{^m9SWUj~3eSYtTgNcIbM#)2I*YgJf5Kup53;*J1(hFlHrO9T?8{2TmoAiUjniF|r(6 z1FIlsZ-cXXE~d2lYP#tVc1mfe@0r)Dta{Z z+6_d)cwjeBYG5^u3w&EoK1_EP_jZk{XcqLQBA<@M->1(E;iR64T@TB#S>3ZTY*6Tz}5!- zGy3Dx0|}9~;ZwlCHt4gBjjRCs#id9!%xG!={ndCScB4=0zl3Oduf?M7IcDxscH>vT zbXN_t9n~htb0Z?zc%rb{>TA`C`vu6oPPJ0PO7l5i*PN|pNXg7$ZJ=v|zZ+kez}b-o z&AkLwh&_0&?obkZRp~F<`bN(n1_S%MlT(Fc?7MbIp2S3}`NHq0R?H=+g1fNPL53U! zDqA8i8)ZR8LVe#RFgzN9m}$zFgmuKn)_X8m zY>v9K3lQt88|z~{Fjd`XpHx~ZqUO%Nti81&S(mY1{>IvKY@k(9cAFEqAijrwohE7mHj%L3KgPljXpl=NI);(X$52|E0ht;WtW`4-ZTG|Hl zIlLb#CJ&&?;}>!UnGrFA3W2fAd{B%%=O0AYPh7#aWhTL)ovoCdv(TC2qO@kZs_0l%dV`D$Ep&20H`@W`##Ks;K&js#Ya1AiNkJqf?pf;6$V> zQ5aL*O{_xf%&qV?lD(39jT)={G#K1FM$E@#8W z!<;B^%{vs15bH*^V0YniCfDVLw?hBZ-SJ21F8+JWw#=qk--#zobMVmn9FZud5+Eay zdx5Xk%BzQw|6-voXssgG_-S>|8x!DK_gWK?3&1S)r#(-Nk>66y=m%U)>nLReXmts9 z%Tq?F#n4)1d2aUbq}C{t&rX|`NfaA$HX?5nvgE*y6`H}I9r2xk<+QTJDEtPyNF$gTc+w8t^A_q zhA(0c`I9e-n=2dKxP&aRc*Kvky<4fLumZEtGr&G5UZ7K5m_(IFd!u}wmT2#vH_l^~FgJwmIJ>0%7OQhC|Fa5UhomauTi&x5n~ut-T$n1>^_iJ#mSc0tO}hr;pN> znfbUbd^Gi1J)Av-L`hIOS>8()#d9>qj8GHJT4W_SO0TCdp8CjRS5=l=w9$?k$3$^} z6&n;3%Q639ucVspfAA=yHFPE{AzD!vKpaS5r|R>_O?X4e4=q~pwFow(OXFLKjpjwW zPPmy;NO@*m2UGD(egW}?a{4Pn%IzfzY=$zF+(eeiEe_wLH5cr3npudi zWhNTk)f8w(faLVj?r-uMj~!Z^^T0Wt+#;mM#~byqS}83wOL-X`Flrh3Lad2vHLP+8 z?Tl%}8gI39PTOPEBA%i|+#Dt&@g<$hO?U02+}s;5o$|YmazA7LVjI8%o>`tKt|@0` z3G6Vmk(z8gfyFoho8@Rt#aUwqJjBJZUEw-rxccsp z>&Q-aif@D7EoqptO1>ZkBfWFVH+ofdXj&^cN~xgrR{s`0Quma+=3A{6xUGETI$5`j z{4N)>Jt(Ws;aasIXn;25T|Lc|*G8VuJNka4sy5HmxSYszW@2DHIJ4$i@Gz_b>$|@g zFT@J$`Ki$%&vOT*E8?sK@K-$k$Z zm)c;|#c9d|EeYS{|j>erAZco&rA z_14O8-mIorAN?2eEmRIWjl$2B^eYA}-_%O6aUrY5SR)Fiir3&|^0C;51;A&gDcj&k z)N}WQ4uU+sW3UMvmy$ux&Mv7$VfC~-fu1gciz9!ixkNc7LBGzg5srueQ1o1XJ&>N> zP?(4&37f1!%r5S-(NIn&FzPR?Ug$}}JgfrBUi?Mto68bAwrwElWOP 
zdC*%Y(1L$~c#QgicVw5*-vt+bLH!7Nalh$0R_V~y?6={X{vn*Hj#e{0#ij0{#nv|8 zA~;5Q?!Ff}v5x_>#SbV!J;QM!GFdnd{=y!DYxX_&Ade}|qplGNjH>X)PPoVUDr=}V zR2A`S@>c$)fx+Qe3*#*;CX7~*6AYWhR8X4;68+It-xq1KPTJC9jvCPJuwizpfxYfXES5?9A=1}P8Fd)g}SgG z!ne7Rt_SKE9VQXtr`Y1`Mq)nykqxO{&+80>97llgkR{T<5tlTw}Ank--m5+@*E#6i&acuAn3AdHQdrv(-nO z&t|G8v88E~Jp^8xF(Qe0O`Gr5(mK^ozG*yK53#7uM3Yv#017y(h}z zJO_!9U9LrRQR8oM4*6X@YAmC&>G`3V$|!P?HlWU^q+I;Feo)(JImPMJC)n30M6YVt zhrrk&@(Hs)GaM_9Hn8pCH!u>h7_U{`*3+`}1`(J`? zP>}j)+`B7=duT$A5Vn z^V!BF)IHv-yhb^w$NW)wCcYc7KP4s^o!H0IlijVQBf&U$hRJp{M1I9jS{3tuC=pyl zU8Pa}boj3s6gnwKK^tYYc}M4cCC!(JLtUMEL!1rn13GyduVhT*&L*Bj4J<#=oGz=B zPh1ESnf=&p`ZXv=bX9Kg@s0(223G_ugGA&2(TMBk>ICv?<*7E#jS-RIjd|2Ml!Caz z&PC?PIOc*cAL6*zg--bar6=s9O==0&g9zXoj3roAUY2!O7NGxZZYVcaX+v-2=ZD9s z9oV(_e2xW@J4#Dr4n)ei=LIW-3JHhw#oD~kWTKaE$+&tV_s~CEtE|=C2ePLdEJ^xVo z_uyFxl0fRj6wI2M>(DQ`6R1nR;_7I^8+l!IGx>@sTW_wJOdsQQ_SeA-$6F!?pJI-H zSF^`jGWZyJj}72HM#jn${0UjzaW3hBr$yC)`eRSSMkTxhGRL46Kfp}<7UE`XQ*O$g zv(GT2cYUCXA%^U3Hs_4XCHCXO(dO#n4P~?@cn98V6#-^Mi3rXZ#a*X?@|d zTzMe{2M0<~tBgbbiS#I@J;3mh4iI9slUildHcJ_Iz(KJ%SS-ZQy{+GPFciUN$KTOw z>SxgJaCZ#}h@N(c#Tw4~1aGLD)xCJMvD^rXTbPerchgpD#_e)%ax+>r(uO$mEVvBi z>A&(Du{k()U?`Pl>@cckuChvl58jzfj@pGg0`{vfor9=K=0-QmA0W4@*=h%^Z_Fi- zkt(9x#5J~__ZC(Xb!xs+&8(&wzuB#ZK~2<33Ps}Y79^qBU5a8H~D6|7ur22jaZ_IATY1*pFEOoW?fP&QEby-P~87oA90j zNrlZBkfz$Z51=&3o;vMZbIh#}*96)04h^P7Wr~m0?TqOfqqgzI(Iu#dekV=&%9HiM zHqg~oK!C;>W)9U4pQwvUJMNBlU0kT@Q8y#a5fym{tSMDdYWvf%7m>Ye0?c0a>jsf_~}zGysCJ78RQ zo^_o`xxy659jx}pm%`(CaXH&O#pVyKWEW>P#g~Zdg8QwDK|c71DueEJKR`2mocA** zi0>z+1O`b}{bdt78&wof+(3G-tkJ{aI5E~9iM$De@mJ~y)g=!14A-uu-gQaRK*14~ zm9s==*W@rkp7(ZF=S9DU^QfW5cK&ht8{#uF#&H6c_T}Y^h=sK+zS=N>OhI&%>bVz` z#>{8=3qFy0?wguEDD)dyIRoG>H%A*K#&GLhPe5j35529jj`|g5xc%ySbC9=%-OE}; zTS@(OjmqFRhM&=hPs+|RHkk*lMcA1mTlgyWBDe*fWHwpjxxUi&1c7?!yiMHWzj4+0 z=XeKf4eb(Z5rgsHN;b$bM(H)UD+!z>v&rGc#%IJw=z_TW!>Ki5UcG{}1v>m+>9%yY z`u9};ns};b6c8HQFZupre-k%{VDs25S+i{#HQ6d-+rqdh zlDE|lAmx_5adfZ|>nFPnSPdgMN;7xj(F?Iebk3I>iBF&lAx^}9?5&15$rR_=Z{ 
zI`Wv$aqZ$G(xRM-{$Vs8EJswgYKJCJcfcg|ECQ;xPjmV%A~s%E&vW*@XGNq&T19<0 z6ve^rve-3{q^zdOu`Uy${QeFqOy*eIv`XO1j))ZU&1G6pe?m{>-gXQ#+g>g%b=T$o zHS5!R@zT!!nL zT1j^jb0ZEqN`5RC5L$|BjfL7sj?WDH+Tf+A?p9slqoXZ0QTytP^>~T9Fbg<5)7fL5 zfxaBv#~fr-tb_E9E$JF5ZX;L0Ueq+ThUm}+3Ri`tYGZm`aEkMhQ6*KU!rU#keNI4X z5aEpkVxF2z{Lgq#v7QHJUjI`->W>@}>Q<~V%A$=>5k^yEHH%=3iC8599VdZ+HI@7@~<7R?*f}{ zh5892^$kiFJSde>S_>Gz8gNQ^$cb6*Cv*@0O?TnAIPV#)5}dDHWk$MYYM;6PQOl`B z_ZBvm>n7h~cB#SasqzG>IkE;g+}kLcc}KSs>nE;7F3|7vY!O#2x+=XGc|>)kL0TZ# zpVfL1e!AC{5j!uDqwO z*%Mx(XX`7p1miZFLcY--VA`$#D zQ;n|$N6VY&K9NiK8Bd)4GIX5kABr*$2m1IstEHF=QdN^f(}@b=6>hNfSa}9ynqa!=B=|t_7Mba5(F5bK5 z6s3*OpV&=p0%hvo85Lm+~*wR2g#6Z8UUmmsm7o8Kyqd+o~hjDJZ9MGR=cB6p3+d zC>UAiEUgSC@9}4lS)?m_U#so>p`7G*@C|Z19QsQvqqanF#|d%)cmlNdO(q7oq2Atj z=h&@%VkM~>TuSXxs|1HzJ?!m?HHb0#RJyqLomDU~pL>;Wm-=1LQrD9o(*|mI44}Fw z_rinWe*Kt`7JQ+8@*Jk_g8R6`u)Gt3wY)ChuJVb#*E*y9av7`SFUk6KiPGPCqZSQ%hE9Mtv@uyS|&{8q@SaOE)QLUK+Ofhk*KPAgUyb9k|tFtW( zhuQ;FRIcW(f(8iVn%-OQuY8FdtD7&be=4W#)4mCV>9WQY?Ettfqx~>97`c%~h!G{# z{FAUZ>q96u)V$6N`noD}tY&j;+5j+3)m%TZ4xqZ7pE~Z1;+_KoURCQF(TRz16?{cp zZlZ4Rr`S8VLmOc)P@L*?-&FYCcrI}w#;@Q5;!e>C!iS3Zqts7|PuWN&0bfSx&|AHI z_8zG>sK`8(+hNh%V}A}x=*>z}5hd{|-2l%|r;1ZZbGFOnl zw@WG;=_4QZ)`FL`Y1nV%vpOfQGC1))crAFb&BP>nh0uo_uf@`W`y%+j{|^pkFNoKa zcE%aLgm);{lqsRkqk>v7BDzlHxEruoPGjYBa0)XYr8g)GQcb`xHeMB^Crq-_CPoH% zN@5DjbLp|$@pBMx2DA<4;!#za* z>~8{F64IppIsY?0Dr#C58KB>Ay**cjR$wa3Z(KV3!97J7qt0M+$^DAcNYwV~ox)et zJNN9(|vVg_=7JrOzBo^^)qJMQp1rQXI>dvWd-(9xKI z&Se4O5n|v!j+gQA2@m%?77#zi`k=37R#M&3{r&;lkREL;Bi=BZLAqzT_bmvL3U$rh zP~SnU5C*HqWBwOBqW$NH<-4eVJejdxp|g-J`m6!)Iyus~g$STC?0>N72OUI2YG`GI zi;IiG%fVx617DOlUb8ov1?7{L>HBCC?9&qvwYCcU89d|f$_>;eH60ndjREu}W_V;N zvxj+z%%^)`$xt_6jw>bm6kj2`+y6Cm9u8GpX&bkEpdglDiL!`D5fHd_%9)unb1%9Y zHI`_IEm31b6ExT>mS|SRg0fN3C?X0f(r%ePXXak4iQUCr5qmEY6T2psZ~lilbKdhl z&+~iqp%u}l##^ggQ80On*cQef>W+U6ZzVS;CaXtG0^%KW_YjLJ)EVRxsSxYTJ`=NT zbE*bwl5XLm-HR*dmp_pK%Cz)$$1qdPzV9k^28>3W=_E2LU!v~hU49fai5<)~SYXOY zbG+U7?cVlEvOP0xo!5(xOx|gp>`FR{$#D!71B^{d_A@f#GVvaAl<7>Ej~}rt(kfiT 
zD3@>mC74ZKlO-K3Zxhe2qi|W32vnaRWC#j56+Le;K`vMst$bs(;Le*8~nPPjGj7Eu#cuQ)qL-R z>NwCvIYWKnF5@db1zLkd`fA<%J!aoz`YvOsmW7Ae!XfUsFBfBHWSNSFE4blyi~K~7 zrAQhJq{}_DPq+qnRw+_x`&M=}JIq^yrlmLQ;DUD3ixFi%W`5tKsbf5{UjbH>B@MJp z)ynCu{1f&9k>x+`F5zQL&qp(QkJ&kTmzsifJ|z%spMkP{Kfn=8BtNXj&EO-33H+)Q zgD#k+4V2nrt0bj(6S`;A-~K)>;oa*9`FS@QM;S=r=08KTz!aNt4h2lFohHN zr*Sy|h|TC^PpSIMInH&O{)AsGm<~yKGB#0RjJ5D6<_YZQI;3W4`$d#874?$}^mbq% z@h27~Und8VQOr~LH66*mPaCT&5m%Tbp$&TI{^>?bw=9*`CqkJ&=uJmqI^K_7qI zGM*hlg!{hS^u6e;-pg2~ zuIzTmbJG5t9vc|W)}s{Hbmq2f(Q1t?#5wOt?u9=T@ZcP0X7d6f7jF6sZyNR3FJhhW zv)&pjGUhtJga!Ic_8W16(!=P?Hiw&pCh9s@1Ann>xjc-y;_#RtkWlfkvs^u>cd4>- zr-+#c9A;Zoq1xNkH71jIg!RA@s15FxmCfyq)m4yEV?Y!9xc{))+C2d*1srt&=2KU} zF>Wb4$2KMSK>pJ+9G^jT!9VhQU~kVn&bAJc+mDYt=Mt*O+CByh)VT}G!AI^ocC0tj~ zj^4$5u!iH|j@EFe+$pY5ZYj5OJOh8Jz!`^aXMXe?)j5!I$coRYP6+PNOyHvFwp$9T zuv^j;lZj@iZ^=~h7O2wqV*~9|)oW;+(TiEg-*GhPXGC4Dg!TA6_|V@;wP*(K46c<( z|0C|HJJP(%2YVh1VM-m-%)IjllC71B^nu1;(w;QVR>n%?VKRr_7TJMlCazWDoExak zRZW>3DxX?H?#8!sO-oC-p?(9ehezc#N+%6&2x z&h;-yydiL@VKBaDZ>r3DTUiEX@EcLNf2zRQy7f5A*wA*-p_N%X>Y_H3`4Q%FZlNDL zKB*10Tc0F-V$7za^h5qW+}NrdO9*{ZxXyXNa(5}SNIK%J^BLS}$;B6;^O_BFJ0tBp z(T|X*EO#2J*P3vN2{Rp^QWKeT)lTrhPBF&oas!~Pop>FU+|fIk!uXG!XBqw;Er@Z&TO<_O!1V_!bZ3PTSXpb(mki_ z_mnzGq&fyG)%iS$>t<5oM1T&yEBVFW#8h#0M<*Ft0cvfx!J)24zG9^}H@TNztU-<3 z6k`^&CP;SLoYAs~c-s%|d8khJ!5H#%X^H=viY9?W3AY1_{Ii`MwH@g@e2@x)QYLJ7 zn0*QSk-cdo`nRhk6d;1~H$IV5sTRx(?O!lF_^Fsqe?jC^U2P$bP-;y~nmLkAQZJUL zDs#0#_F}rkeGZ>cTUy?ize7Fq_A~xKeb9X_$PZ50XkV+{O|iKTB)>4{uRLNA2vM3B zeHYB82B9xeUiFUPMkPDrD0h}!Wp=<1pr+$#wY`i2$CpM!nr5~!LDAp^&)~{qq?7v6 zk=dG>$x&}1re%zVo$*aZp07f#t22cIt`ELbSiJqb=z*=!3+5gCN9Aa0Cc8^|s$Tac z2BeH=dt2J#y+W>XwO61?e`e)%e6G=G9P-~hSnSv-vA;!$1Z8F!_vErh+I9(M+=JmY2e9*y!Nr>~j#B;^ikN^fliG zCG36pNP6eH!RBe>(*MuB2K~kjbzibhJrwWy37?X11N@Bkc!p4&*lEmPYOb zp!6)WpV$T$m`vM#dI5?j^W9c9Q1d71#c?-#SQ9n9RQ3h$WMRx7MF*9nE*qdu6&m(mmVKXiN}p zqHeyQ*vr@*c@(3VB=jUCNGeTnWT>dJ~*VVm&~iV%YXSeeRpy`gWTIw{+3 zr{jo_Y8$LYPf=M_2HgfYd#(&}T(k5KfJ1LRP}-V?xn2v~#qWHDman*Sxs$#WVOkB` 
zUsfO=qZ*`hsSi+SaDj?bZ@u;S9PBr`nkBrJ4xtJt%#`v$aG~p~ZP^s~~ z^M^9s-RsVg-q;UZ1+z&orw+jb z|L!@I8pJ$PCi1v@3%L=Si|1wQ|i<4=~3Gz`b$w!`}P)8aA{MJl2z{ zUZqBwzB`9%!Q0C<>{$3(d1r2}1T;u}E-r@*>>}=}sX;#3>$kX0!ZmW+;H5yvgQ7pg zm+Re%a{TF37C7(RZFYrO!Hdc)?=Yohpu65XfreM87IbR~hzXT{S5M!0lv`Um#$+Aj zfH)AQ9>hg;Jl>nSb#OhiO9)#`KIf7_K>c_`v@I9NB&z literal 0 HcmV?d00001 diff --git a/test/cuda/query.npy b/test/cuda/query.npy new file mode 100644 index 0000000000000000000000000000000000000000..274a1851c857fb0bfbf52ccdb6974da53554ff9e GIT binary patch literal 89216 zcmbT7b(7mj`?bTFOkl#y8AgMkCADNtn3LC}&~k|)V-m^RGGKG`tS zhV%3L6<+`0Dyv*}snu}qb6;m;+b(U|_b6CkYJpiXgNIBSG%<#)8v}>BViN1d3>`Od z^27mS`;D78c*y^|-gLm|Nkc-{Ck-Dken{wZET9wXl1^7#V%<4)|Nma$CFL59Gmf&L zx_1pD%PXxUPptKxy8)U}wo$_|RNs<#Tp6YeF|P^x(Ba@c843xu!K5=abCcO_N(Q>f z=!~JC7XFrOeLI_v7t5RU!~7F1{PCQI>Y4F&bF-9{?LTf-fK!vZ({s=jvp`@! zYbkZvU*1~M8o5|vBi2o15yQ1PxKn&ddyY8>76q5UG+(B?9$tZLc0uR^(eNz{H*I$u zuFqFBuW;|cRhWj})8(Wgcnm+qe-Pbt{4{}H#VCymnSGeDcAWTydBZKNpQl#^X>fyh z$813SB)8`d5F};KLOcPg7)e?{i80^#$BM6Uep$MHqKJ@*N&i{bT z`I2f8g2iX-{h1alZ`f#l45JJJJ%hE`czOia)jF!S)2@=+W18l*Fe6P4G{&8&8o|b3 zJK97~-hUzb3Xn1`sXs+4-+?P=v1AYGk+y(sz|Ubv8tq|oN6s4FKL~PtM+@m5C5j`sJq;xU6b!?j|yhAJ!0e>=u<=YDkQz~hMw3b}tUo7^g8fr`JuV%oi z2fpd!xSIS)FiT&gz4gsy?&DEbQ{ORpjZoa#l`cqrWJ;iZP`5a1HTT-Q#f*aE*~0+_ zbmM00+pLW+&A*c?=CHC6J}v_ z0U+ZgoTIhVM}a+-&9{@*@nbnrDixRj8NQR+J$i#SLa*(q2gZ>j1D_ZdGbKpi6r(_V zVP{wV3{db?ex~B|{GqLFSWw?>y~5$X7E(Bzhr8O1nM`JfGmLG-W$A%?3Xjz`vI-TC4uF1MTK1R)=$6LZ=@;OK)|91X_~&m zv)zuj9+@x9xz-xGV(v}Q)may=!lZi99CtKM>#VoZ%LfkdPtlYrF|b%{WhRNK!RLA- zT@BI&lXRCrTVXvp0!3f@`pLqAtvv3?w1FUg<(4Ci3@`2=(6E* zXxvou161=X1|s|`%zj2Mp^sT!U#@S?U!)&r=IBTL_waVTfY}u{YA^{@27Tpn3I7E9 zWb{C-jJH}tM;YHYI0U@Zf0)&i{vlMYG%Lt7-n?@pTxKOQqXRwjUMKz(r|VhKJvq~V zK|dUKllh--2A-Sv+885@_ub|QyOba+OQ^$?tiBRnBor`s01}deet!yb##f13CmaWT zGsC=w*+X5SF9WBHmP#Ak1x?%apR<%Tf~#egbd0Do(WoPnad@9w~=y z3y$(`r7U$7Pr+OBtMVQ3-)b|q1U=MRj{aeGN{#$e?M?6}JS>v@KFn~(nHi4Vp1q)U z5^F?I{gWngzv33_!@Vpy){~Ed+$5`GZU&H@V{_k|Rp4yyII&#`XR^qq^jWP|pk{nC 
z+!gaoW1}-n!Hso}Xu)3N(r`2UmANiY(zAtiiH%dsv5GXx=pa{yHH~_%8y-bZmbzE3 z%bqga#u+;g&4=67`}v8u9l<1TBf^67@FcBE{tM%8UstmI&S;SK4WQI|=>_9Hp$M=TPw*qF zw*E$bq>1X})F6Mk{#E8v-fBHxYzizqiitPx2uQ1BPL>vf2CzAE(2VnLVEPBI0af~2 zu9~+!b1AB&w^^TL zlMK#t)HGZBtL8ib`=kYUALT`h0==b1VghW%w*ouuGfZRkyM9b-=%a*Hakr9I)vs;$ zkp+VYji#)eMq4yD{*}IrzmCSsU-@TfswUzZ)MI~)|`beOHw?4T#=CxJK#PV$Z zp?(ryy^;Pfqa@y_*XADhU($8hmf8xom2fS1R;cVb%L1touEQ_1#v`Y(+FgjK zmp+eZFHAR@V27&-pHGirTJj|`64<&K|IzVWIQ)lsM-FiA@n3*T(JAr^Jb~XQc!XDC zie?DKMeHk;&E{Q&FViCPo%xAiJQBlMqjO?c?rU?qD>Gj|zpFdQ#8 zDjMsoc|;`tG-h^?Pi$e&6G}2`@pdyDHwE=wA6={b4ebuJo6M9_J*TzaYTJyzbO6pW z3xZztP-1MdrMH3S*dcl`Cee=9R5IMtpKYqIp!)c+Q59Xm9=i|I5!JydERBX8sU%y@ z%J7fahA>9o5Bn$ym`C5JD*P_6JE^&+0n>`OrN!W?TnBJfyC)PjnhRCTD6~a?A05dG z?kd)9G>z$wn!D}^Q<#Kc3$+Sw!~gui|3I5h&BjoPQC_M?yd{{4!5)0Nyl~eGNJkxI z!=)Is(K?!+0TF z^j3nZBWQd=Yje8u8TgR7RC9;|F`W938Yo2LD@qWbGkUQHB-xoFWt#om2fe-hTi|+q z1S-J);faLTttj@Gbpvn13C3ok9@j`Y$vg%k(MjHC9e}Oe*WovCKtR$>j?tgte!v+# zY9Zk!KNZ&DH7UX@ZL|z7*4~i)lrd@zYKA8k$1*!W2SW3-F+g@zUwu@Py{#UHhG#PV zoKugUi%VK(+21G~FRni?b~YPJjMb`SM0hjUVp_N^9C;AT$NeL-t4`|t+w_^3*_1KRQfq?t1PF!Fx7)KtVh9{A947{>QjQKJq@K)}lfOK5}h3{MEJ)hKO`IVAHD{2u)R>VX8C$Z08@=OTQgv`P%H%jIv^zPqoe7#pt>6(TdjPo? 
zuk}VGJ<>PPJ840?L$%3`!=0>4zKP6-fSp;9FhFE-jf^?c(#%1mGrKamN$DO^dw+NY z{IvdjSRJ+Fj*)TrcFq#+1Z{h^W`04JP~G4&+{`hAD~f!+KQadXQCWS8)}w3q2DdCv zG3ra3&{ircHAkL^rWu!l8O#Y>E2RbD61;qr(a)U)kH`&RwB(XDX#_B%Jl-Bmft2TJ z9j>F@E4MsfQmn@`BueMphO5CSt)ev)6m%zur4mQ;iH>F*VKs=75ztfLQe z40c~|y-w~BynhU!o71IJmTjgn?L^Reqy!a#a7^^}UFdv5I<=&n@}6d2wD=XdZ$$!D!Bd1=DPxS=Xm4Hx@*#*5 zo4U@T^-41(4(|Y`(NI*BX9G<9Su5Q45=NQfexGx#**dYaIbWM=UFRE%KM{wUnUkTW z#n}stHu75d$nJtHcLVM@J5qh6Y;>NFzHa=>{;ibNsw%tHD4OQxJO41|N!?sm&=c;b zREk~5u4i8v7pxJPcKImEbtP$koOydm1$Zd|qVWUJAw2x{x!2 z+=^mUo&F#r~!CzWae?XcJFq4frq4 z7$s9)Aa+)6$Wczii9H+5p{QKKII@56qyG|f-Sdj* z*D%X4KZyJ8BWyKpo;uQ4hm+)dU=kyxB1aTlD|4ZR~yht@o*U(=|P2F&g4%N7ZIVDpAH~1~5IP??g>z1T`N=p>3Y8 zCQFpb?m|WInM4`3f8YQdV8nXZ;4(eajOU&@&Lk>$W6THVQmsF1DAn@qgo#>9b2v`s zG_AMQRI5m?1xvNRBVN165Nz>R~eI7pF9F&R7W_j)Fa!;{c|~e8ft317E5Md@K|hH-leRA6oR;#$6L0o% zUgKU-iukyAP zsv}`JYPR`>zbfpMTVNJ+Gy8idI2N-ry-kCe>{4%6Emofo_MlJtE>4Zr#ew#6d#~Bi z98djV3gE%SDaT`Q9G>BxQ`68E;J^$~#59Dp>M-9a>bmoiVu-higVrKAh~JyDi0R_k z?pd+@i?cm05v-18nKgMr`w0HSB@P#(z6RbK5&p8^qjJ~U&$q!d!57#l=_J(v&Vlv# zvGgJOvDjB29SNw95f*&L4<`B$IYuR%%_S#RF)LaxhzUSW>fErjyh5J@oXj_qg)dQM z^f31?`vba6b~4(#+q=^IKbR$WgEUR$jV%8PSS0Bf8WpNHHU#&xATZx}Y~44~s3Cgw z)Op$fyBvJX-2s(d@u~B2bjAT&T0QV2xGliB!a!UPW7Tg7ZyB>1@zXE-^FU(C=#fv=`(o~crbu+meRpz%u70lqZ) z1!gk);WFO{b*~Ufl?AW-SIJpfnN_1;TPBQHZw;n~XWJ%V`>pxa4CB0Z4EMA4+eMkpa4(mX@n>EbyrhuA89Cc}=@hizu9G47B<9H%u_8+q56rQIVO)pS;XY#iUoUrr#jbNhQp&!vw1enJe@ z%FjIq%oX{(_JcxHWu_MYIONbpYZc*TJ_s%ez5Jhp6k6svD0XJ1>%H{$z6T6u-Bc#& z2Swy5!ZQ8?Gg!ax%hkOgkthYijL+;xSGH$O=6CIoNYd)z$s#c6%MF z;LqTakd1QaXLM=iB-NR2#m2~4o~x)StcUM$m5t%~B7aqpUEhhPj@eveW+1Z*)@K{o z$M~bbFblY*n>CQdKce)|8a{#-^QLPkHsoE5Xv`NsxT=|R9LF|cv0X?{I7k?B<|d^M zbsiuG?Yu|84+e|So)`Nbb;ebEbYjm>nq1v?;TnrE*Z2HqcCsBfhF)Cd&xRPkSf zH#HY`o*F0Fm^Q|SoVtL~NYYV-lok+*)yzBilP66sqpi+vky9Q{#@Fm5cY3gI{#;KV zX(e0Rx~MHtmKbCC-dxw(C-{w#eDZPWk4;cO?L03nP4GyGfOQ9nl)gkfMDktKe#;JByeD7c3o56hZ< zaTik#_edT`1d|P_sdo@+2)61+!QZHa-jRJSPc_TgPW7Hxk~<}wtveQe@V}1QXqU<= 
zZ+GJBin~L9gZ0~7nD8gFf&4@2YGvbp%>}3~dV2T*pJ}Hmee=tjQG9vyTg%065+B)M z_rsOAp~3yIy->yTEOP*!1MY)~0tPEGlj&8g;H@nNT_55 z0QI;tk(2mBaG∈(fK;Ur-*$g8S%TE~ zFarBlqT9-&ko#N?78ja?v{Sv1o-3Ak$6c1MthdxFI(isUp1bsMWlZdKYj%2oiNM8# z{z=uVHeyMsb+9kHDd&lo*UoWem`rdqHo|gymas#Y!lXrxf6z>AAq-doIg8oMG-Ybr z<*cTGeo7KNNS_LJP;y`;J&TwvZfEzJ?Gx{)Ilgdl6oUFZtu<36cUAU!kcm02L0~*O z2~L9d;4x#VjqQ%fU%jo=(oCdZQ~SF%uv@iS{8;A(eM|l{&vRF&yxmM|v^{8nFI;_a zhB~Mp0E5_SEJ>uKeka42J9!TCsz=Xm&z_@x;N{`~rGkA@KZ_be$XD}T6k<{UJI;>B zo2Xs6FxX03#1;+dqdmk8IYxO57Qiw|_Y+5h5y4MMnRtA~2FiZ51zVOZZ1mf`%G%{S z>iEWT`dF0~|4Ohc-N#C#6H+M^W#m_4PQ^-ClxXA}Flld2w6T85*=6GhI z>(89yCdC}~)`E+?*&)Autp2Y#C~yl@2)Ow-#=7LHD9y7A+6m?oyd0u zr}h=_q9XRR#;QZ z5hmf=)F-k-`e@iZbsYDt)*+9W^UH`re|hudiukdA8th^2VLwTc+%9lCqaHJsbx<9Q zX4>$$Ah%xpgd$0(?&UcjC1P{mj#W#V&*jRC3 zY9o09qhyT4?~>AdJB4wfdD2fWj9)PWca3TgiLXk>4M{&lotN{}f@~k6iRrM%W^njV zZJLlmFOXho3q5P81KQ}s)_g#A@}nyc0V*DfF9avSx!e`?L-bfM#rnaP70M}BoJ+-Y zTv01PoFmRNTUlq2Fk-j{usYj{K`;qH_}y1sNb*hg7GY0oEeI?Z6G|Bm)phx2tpDwu z=Xfa3Fw;Okd5zQM|ScQ!+YpA&ZBBq;S#%_Il~@eSBev4 z*6Qb5i=GBg+aEDz-gpN|8ALMfsZDa-ls{%=pg)WRVG$MWd8)QF*LffKy&ywe5xAxw zgqZyfNN~`IMq{oIRCqs7i#9)FvE4d>$@;3*Xo+!xX*1sDw5CSPHmW<_gG^MK>uk+RR+B0}eS+#SBt&PxNl_2IB!3;Uv^;l9w}+)W*n8|M6J1Z&L) zZye#86Hjy3g;gld@5n1jw$ch{jBrcpB^C^N;B_#P|AUHzkGzkSK`4{@A@AfWGWpU@ z?=GfVpebKbxyL7))6ib<3FJwXV}!U?*+*7@_kznE&$JO{6?&g`I3p*p$DVAS5$Ay@ zQ`JR$-kQYJVGkQMaO#ml{74X%8Eua+hERPRJ)|~-KcuLxPz%*B+HU_+<&f1|)3p8I zHN7aaIht#3^R;pR!@Bvx-qWx)HNZZkCL3#vGkVX4sqPZ+7wY0P@iolSGw4a|DfOh4 z*6>_z6-QtBaPQb;f!oiB{4DBQe zClof!_|yEL1<2^^cj|C%H13nTbpP4ldbyl(7$$I2Lo5ApLSNJw?j=plWennr^XtG8 zW<$aNrZJx+w^jVOoiI?|KxaxaJ`I=CbNC+d9sC>_>FKB?pn3SLbbz`AA1H~YL_bJh znM4G=2~5AN!11B5kc>8MU6Gxs~K)*AeXuba=U zr9%Gh2% zgPCXoI~#US?k39Uk<>#S=bV<}vA+QULln#1rk6&=^P`M4zNKs*^^rD&?gNIpd?=6c zGJ|tJr_a)O>4?8=&~5z1uXn1_WLTT2s=c#^quKfrVoTck$Z*pqbkp`bCacr!fTKON z6(~C7)BIj33LOMcbUR?k2Xn*KAa)#2ep;jFjxA^e+&$>=czNZ?^~T@ zP8yp34n>#+%`;}fl=A!%*(FrfAA(O-XO;jT;^H;Z{g9h3y%5u_&w4Q5OFM#10^PxT 
zp0Mh{8GM3rr{;IQh@=^HF`?beYA%=X#lb4U2kPFCzbexcgig#bR8hVI=9>FGbF1n>Kb;43wM>JJ@dQU&IsE7JFqxolSO{FOaC`%mkkRzuvO+|rMwzTtau zr=)e}F0-?snm?uSQYqzLT?ZE@jNto;ZL;U%Wb+F*hl#PafJ>~TH#Bx?JCcry%TaNo zH{6GZnylay2BKrZCvZ8uh);S7fO=$q|9uc?p0rO=2lQ+sE_D%;PMKs5iOdCj(?Anv zVXZh_RPO27!nc*TBwj(|gY$8>#7Y`8?!me6JR^(En2F}oOwC`}*dXk-BE%1g75G`` zSX2v76dtY3HLLLuO-j`40_YPQO?d5JTKVh;WE9;pFeX2iTo1mHKl2-_T>2<9jDJpT zG~xokqflN!>5MWM2d(MmSfw=CI%_Yp7=A!4nQ5tir_Y6#v)0acoP?~(B>ZVi>AW}{c&o5``4<4@=%-IKM8_NT+Aa9fc`_#`%n-LAZ+ zHpxj!Lo?qO@J}?Oxn)Knb0exQElwUqYS@iGqw{R@oK&v7Iz#H<@iV{Wu`th>j%RZ) z%I{pu%)`5^sTrq(tajP7`9Ww>KqRbyfJVwq%o917Yw3O_+d^S}5;1@(-EbUUwBM8O z0xzgy@N@MSn6c!f68M(*R$BudsDDagWJz~-wl*)rt@bg;H=|Uwy!@pe3iVIQXJA}Yn?K}SCM({C@6k0Pdh%l3$ahkf#zNACHkDya zPOLOf`{ci#dIB=UlUOf%L)oGfWwwHY+(od97lOUqK60qm*4iB#=1Jo^k^51%nO~r|7Ut`0KX#Om$Jn(cA5jZ_1+BGgepmh~ zYLK2^r=JDfbF?fc;t79Oz17C@IhE0;iE}91iwxD~>__ zpJr8U2E&N+_0=ZL^fEsw|JlX#3Ranfe_X4Hf!S@$fzDCBqN!a_kH8suHW#L)MCWP4 zm5E$meWP$CajsS|pl}hX>ta7^Q?+A8Wk*-9r92ReDjl`bd>$wuoRwvL4t{SO)*nh& z?BQ5+w6q-gTh%pS6__Vd+uOMTCtLYbVcjdeq z@CVr*at}Dy>IbUm|6@Aq)1_6Of^bA&9KBGP%WTbE?#SaCC6=&L`0?yCI7z!e4L}`n zS#e0t9Jec2$l4*iNq-E}V3sxyT+uP#(Y|+dfYF9|WoJoS-L2X4A^v7CUt5gFBv-_G zwYMF6oj)Qs1LaXouB!dbEXvl=KL^&}!e#TURxFdZiJxi~wI8FnU7E+O~ksWcq;0^4AXv~JFd zuB*yZZZoRLRiUf;f6DjG36`Dw#;K?M#V<(;={2oppnc*cO%Ns-ZSY=VeMkdckar0A z{(tO)xJv3q+}*VZ&2&u@#zeYv;vGY!R5^^9&TY?Z2rn89gA=_wbIMtZLprw%=s};$ zyl$5@Uyw!e%UDYzEk~?3m#^kp!o=YDaWB0`^=ITEwwuR|{=g{!@M*L^)qscTuTnA? 
z;Jt>!giOy9xvDYO{1iLP-xghCZ_=Br>c)03)LWPsZKkWm!8^5!bV~2;|Ec##8A)w6 zGI8CbZ{7LE6A+A9LbYds-O)2xzE~UZ#rQoWL9dj$7)|nh=4-Z>vOz!Lc#_;Gu_GK8 zcyq|+3&t)r7U6f|z+e)sy3TXe^q~!BnEiqy*b8g~)g1mC>MY*IoxHqtUypFt5z?ro zu#Qs4@n2?7YY{a!udSMnBMYWXL0{9JXqCZACJ)vNZZ;ZI&8&xXz+J!v?aJyzBZqCS?9u$@ z4rj7<`S993*&*_2lK-qPlWdFM_}pd#t`;){$XW_nM)-(s;&aR>?DiiPX36iQA?7P# zJ@~@7Li(@_b26%p5wA563z^^XR`e3N^l1HrEGs(QFJ_Xlm?^}UXSAs5L|5NG-s)B= zdC@!$`)UWd^Tv>1C+fbSGb@O$$bxOrIsXCqzYtCPSy`C#Q(9#`X8sIyEL=oat*&dh zzm)Zq%Mk;hkLN_lx6R6}mDM6?vJA{SQhS7KO!mfGalvAXx3?~UYVJD^)8Q7N7Oj5=k98$sr(y~J9NVwju!IsUF3 z(tqfM)?4bLPEG1+?lFEg{3Z|P|4mkf>)wCTvdnF9$D|Q@mk_EL?7QwNg~j^S6kEFg{Fqqn&e6SPFGvhw9I`Z;A8xbJ9l z&b9hUWARKC>n?v{6+tw}~}SCxVx z4sjtIliSTb+O>o|LM0fkyB;HmS7)DkWy%3+sPORN#1`6gt3GM;Ov z*k_nfCb_f8{znRuZ;VH&PVPGT8kb^}Bp(SsT!+CpxXIpWRuhu*?t|U_U370ELSM`I>t+GU$1HqEv_Z3yjIzlyneQU@u}1 zdz+jjZcCj^Gy-G%-`K~l(PA4qPAiKB!f{>-MYG$o>zI=8VgIZU`Ta+FsyGqUg++4X z_|o<>G|^Qbm6jT4`Je?`M7$3>d)8-fCpIu2OgC?vWi%aqKoWb`yMEtyE|L$YU*dng zwe>y*p(Pr_5;s}7=#6;;^$IqVtH7;RQ@Mokhql?cPOT8i0@AC!fL>V^Y}1q(BW3-`Zrj z*Dd*>!`Z-dNWCSLvoGs^X2r4& z;szLlyEx+AL#Y|`J$NwZvhdLFR&!!+#2#^JWFfY&+BD(5e2m)2UKei!S}~2x&G<=R z4?1Tx7aP!|b>9vT(SiBSOYw-Yw3uK=X9=`}dyrQN_cA-lP=Dy(&c2ba=j^a%vAO0D zBEUx&AJkXK3r~gSA$mZ`J;a%A4)j;anH4 z1N+NtB0wdJsO9j(W7Z^gptdG>8!wP&fUWKYj&|if%N6*Rp=8kfz#?`MQ_U`E_9IGD zlZma$h{WEO%R4^ytm^?DpcWL0@$-{9+ii(8upgD9G@%wUOQe=)8l0IuFA%A{B-7~q z-dwU2(?e;D1Idfz9>Q0AU;Ih@#*l$vo-of0b=5N%Q+040&UwW^|^J;wl0{&m4 zHE~F%P-)YEQ}hi9eFG=i`TlRDijU#lTpag5{I|G@YO9s8%Nj|{R9rgaHnSro4V-|Z zvIoGmWKk~0d=*;;{)KuPzmpn)T)n+E5WAwG+BUVb=eAJ;KI8Vm0OYNP2E)_*6zdm0 z%`N3io0s|aFang~#x)%6SSC#~r)jT!VdQ3MJAIMy5Lu4BMwogPmZO`38gQi6%if_; zX7{AZK@gh#kzNx2r{4yTh1vL1)_bxZ`yCH6)~~-Pw==@w1$I7}40gI#n)h5S_(*Wd z%%URz<#7h5Q>~=c(q1Ie1W3=Rp-xq5lQs25Y8KcHdNXH*=GHCPmMaX0hHA$X=rk7| zN_nhcCrC|&iDY9g*Y3v-gExgsFiYP=4&^V1H;Ar!IJhiFWwNOeI<9pF7C&@cYXes3 z6--KAt9}E;jk4TkPd$CHSraMvgQpSMfO}#;bp=71Hdd?QOK|F>LhcCE#>>HN*H7cI zxjMA-tf5=L8|*OGJmYrYy!94-@y5w5>B8(rww>f-{ul6Z52PN<*PPy3b+pxAKh*vF 
zODzw_5xP`T{amF2wFXy@Y{^g63u2qfH-DNau@BV6Ie3luz_^h-Ct)drf4O~=7O*xEWbKCIWiJQbbA(Z4uy30Q|OEU}fD9?F)vozH{q?_Ps z&ak}4cpQ^#lxEhjcc|B)WW;_?LnAF?7*mQZnVGDvadtK;dnQu7FyJToo|7Z-GRL7^ zV}$z53$BQmfGc~GgJVMpsaLQiF%8yEf5cg458-&yDsU^YtvnGP;=YqbrE%&bD0^RN z6El+FS6>bv2wWt}i0iod=Ha6UrQ69p`FkLnouO8jBx9^lR(xcBVp`Zvajx$2WeF3# zPxazp!lC*OnNDYa*cu8TRKwTaGg*M8$ge=dd4jFl;T(0 z!qlr)H82_f)c`J~cgn8hp_oa6#F%jN?CHZ-DIW-Jy z&?}jXLub_Q+!~qBL+9BD{hnPcWr9eDp2>{BY3-(?gG9w8I#lV3UfF1W^?)qjiSMBIa9l^td|S

    B0M%TCiD&z-}G5h9p7RzVTxIT-o%7*W7J+VnqVB=;%JZ8{SB>FlXBLG zGdU;NPQ3+3V4me9TYEa_dF=2|;&Y__%CjN3gq$Y7@r+dam@80$sMFf1;08NO%!e(- z;>NtJLS#dE5mQTgN=#J6P=@Fyb5C&UHHiQ4v!sj6GVQE>+AN|~H8Rx^@zc?B@)uHl z+t3$&PhKg0k&+l+2iyR>&;)Ac>7-xixI>mRr1456PxstY>`(AXGw9vP2l(^J`;0~4 zKh>MtP5dC8kq@GZ+F<3lQUN|gZIezU&u9M(c|lEbi?it=2jO|rH?&WAV_(W{r0qsK ziMaI9 zWF`3oI7W;X<_H_Je8w&(A6%c=(bZ7c;ur3iuc?=C`Zw-{r9_St;u z*`{>PTcnIr3mXrRBp)&shG=pfkj@X;^@ZDDU~rTCME^;j*UIAp!us@Q*>FwouWz^`oQnbcWhs-{Br-Fafh@V zqOo}ypU?T4S)XSxCv4xrHGr!sUgsuo z7jkM?wX}2oRi3^$$?T{{^G5>HU4mMlJ?hKDe-c8Vw_Hg_xI0c39-9{!h5yH}d~=2G zs066#?qS?wZ*j%sPg<7ozu-Vb+eQ76V3st@V7aDjJX(@;)HxLG;M;&oAVAIXG}qrF z$#oLe;Opvr>{DF4|DHZE#7{n9zktm+i(T#c8y_<}qiwG65U&**Xu=$Hb&CCE?9@AO zQ}G@vl}%3?RktQppUIMBdAR_n3(_XXKiW$39h{9%=iAJ3Gn{m z5Nx5i&=eE7$2!K#r_C0~5Rbzy+?4~byh%7YL`hi83Y#r$d2oeENTWoPnc`Z;q+~2s^DB>I!h?IT8pv^DtvHl!Jj02*sy<4+u#;n+=(FHC zX|{ffseqj^2(BjB!(+u8j6r>(E2zEIr#47G6|X8yjL@7Ck;hO?aF2Y-K4foNp_#L? zPH@fyJS+1bxE*{OriZmxuO_~RbJ)1Oz0`eHE3 z#aS}PVWvKIUQ@I##~MEmG;-=@yUY&sWXB-&QdS?-oJ~m>A$9holIxVMIgX4HNjKyT zX-;7pm!{S8?l$si`_XH@ggzUnDAv6+dnqE!omK}<@Uh|+tF7D;c9(TypHw<$Kk8?^ z*1Dl)!JFni>23;swxNrUgA0h#o<5EqXvo2G>;uqQEaJXvwo~gRz2{;<2HP%qx)4a1 zt)|;1Xur8Ngl^O~pn3p*(~`k&W`=E*z|h;+7~d`}U>23y9GI#*gHuQnETa#`@3X4m z&n=yM=<>1;`Dmxgmy36xbhDDPty)Ww**EMhWMkT!YDO%Owe;8#Fi&x~khx*`$!Xw>XS$;eQ-)sx^Ry>)BXVFMO}MBvGrl_k zIwIGiw`)1*582sP(^nMq=jvs%%5c|`u*K*YswjlJn?|=`vXw3(b~@2n+#0ouAw&2T9MK z;Kj(#a45bLcV_;ftoq}WD%LIPj_(TSLT*rJirX}YQajkn>}B?gUkwt7eBQ@qtHc** z2<*xYf!knDxhZ>E|4TIi8?7)(Wv8H*Y$-OCz2u(Zy9HaOMG(o{G2u{J%d(Y{_GrHN z5GJ6jTqfRt$MIqCfc%|nEsd2{_%G2fu$T8JF$DYHn#=ifwmEWrg}5912Uta&N3l4; z+|M~XJ<|A;y%Tl@8E6w3aMzYjr7ae=ajVqwftGQF*$rvsK|ou=zr#dP zy%~L18m49iGS&CgF}Enn`dz&YC}}pe78&!6C{3`=p$$UK_;|B}x{UsoxPWPjlO#{e4uVH~kej8SbuZ*cnfHY(a<@7S+{YP> z!CGgWyog~(BSiiYYXm2lw_QslH#HNMm)pyi+{@f2r1i>SEkE8}RUsRy)Ad%&9N1LI zp)QI))I4an)uu{dzh(qqhzwDF(4FN6+D~nOni}xO&LBI3hg_lHAnrH5pUu>x8oWhM z^%>cXt^V4#Lof8HzE9j^g7qAOi~JwpXJ{v83icL 
zsJLxEtVc(o5>6ZU!9-|d;v1%C?$zK{eQfZ%Yfk26ID)L>zhxDTQTeTuKYOb(*~%iG znw)(eh(}dGS8`jh3u>>ww6Bp!DEJkBMmP)$$qkHT zu7nYEDZYV$Rl%KlG1Nah+K9~R053VKaBuj2T9^&#=7Ny0l{odH$!3QuNt%AoujQ*o46$vhApl0AWCT^6;%Jog}`lp zoZ+z_B~6Qw%Lya7U+PHB1qNt_85j3gZKf3i5ptruF?hh+g$V(@?4)BkNEL)335r~TfFHEs@vJ*;1l0R-^iRE{x+ydw$m9;_d7Vky~FBczoiGV zFZ<)+Aaw*-$}u^s`QgF}GlE>}UaL5%5o{5bf?l&2|A?qS750W}3CvA)2Qw1%_XN3! zBFnT8uNrWKH&S$E0n^#g?au(N|zHS;95o@<-A(skT<9A1V0xH|&m5}cE@a#jpG zRa~OJ!6}zV>{jNB6vy^pZ2k!AqD8?G!VFeoYDxpGTlnmH&a?_l&9R}M=C+J5zf4BLC#!GK0>?-jPVJ4j#A4VL5()v=p};T!cLj8UlJS13)wxS zz!!vlq|qp}&H-%`+0Lh^411DH$o}ITD*a_I1iL05z|Yx5qbmv%FUCjbuEtx{N9ulF z)6$WFb&ZqWMnYYFzxpNWE?B3XIdIfC#?<$JM1`qSIb`NM5{SdpQo)l<8c_!};(suW zd~d)^eort&fnrnG8K*kj2N%}koCWZXegnBHdkpv`WrAnwM%5&1vG@3BzMhjdKPTMM ziukt(DUyLaLWw|GP{^#F`_wFi`gu+`ILv?f#q3g>8kLE1MuN3XU2K4qP&mqaNIPSU zF?WNBpgpYVaf2qJpZ!FPg^AjTOi?<=JOL5fB4;IY0w1IH6DO&XyiQ}dPWm!cjcPi$6~{C8oaa0pc4j;3CoG3)L@f9RH`p689sI+HInKgFp`4v)vNnm0 zB>ikC^)x45nq_uJ^W#>d4w=oF8m?T>GVlZx!yc1(wy>(&^NW-H&)E9xU4FA%-Frn1 zgT?t&&mZMDTb|#b4a@A#%m$B{czv4k)c24L*bmX+5q&)jJKrA+CVA|^1ozj3w(@v= zrnxlnod2G&*_CVRVx73%dT;Xpz7_9Nm!M)vCCowiyg!Bt>Ur5mFad27-O9U``0Od8 z=696flI7>t0wb3$VzjIGin*25$rF;;o^1<18)w+>>{8=1w?a&}56zhv+=8=ImzYN} z0m~HEp=8Hs*K~2PaV7YjnVj&PB3v_hk2zWz<0(cJvL?9hskv-hF(bOKeK_4VGlrU+ zW>N#027FVs5!mW$!|a!@qC@fTysZrUe*y;hWbBtdYs1)T%5U`*A4QI58~Veg^(fg^ zL3U_oy(^s}=JK_oRBJRfAn7}wYjl$|&;~Y93&fV^ z4eOO!4}CVUYV16fI!3m6n!?|xb5fob!CW;T@zcfpTI<-Jd%55}|8z6nDwz@utD(i} z@a&$37XlpW|ncq<>ml#Mi z-rFa0SFA;dqOY+d+3EUjHH)ikUD1u8!QD3B!8CeAN~Gt4d?xrz-O03weWVep%heu5 z7&%O&kev4aOu5a@E=hkqot#a;b6>H8pY)Gp^ zd7nB)t1MR+GLR$iQ(p*siACtLR&P|u^)_yj)P-$e`yStmZ_V8Zd!wyI1BkbCAewvN+QC4#rtx8?7087SUzy8GCsD}A-Uhl;C~F#u=P z=fDG20Zy^jByRQ{;Nz2eW-erWz8uGFxWgHZ^Q072$$7+TB+pEEfD|c@ZRZ-LW}$Xi z*&O0-sXljik^d4i)7>Vmu5?xoyGGV@KZHjdBTTF_WVhD)^0(}n%5$-zF^OM-mB*g^ zNV+P{_b{Lgk!EaWF7uD9Ny<-d7iy~S!ES@AmLm?biEbZ1j!!YFF&N-rk;d%Zw`xQ0>Z>3zLIUGWl zag-x&2G*f5W)jwWLK}#O3j>EaEc`qZ!@F5PeD}1}n*{@UWve_piC%=+KkmR5rJ*&? 
z_@yj17i7GU-Y5^u2`TyP`{2n9&Dde`2yZD@9G8zdoSy6Raz|7rS|Qg2GIO{i~c`6gAbWJr(|=S@OQh`jI-?e#lVuTY=r$NB%VtK^*5!x+@3P(AD)soOAdsU4RFmq7@8wQSa%OoC!hE zd)oOXP*d;9I-uwqWJW6yuA;%Aa$Wm!ScfCDF4YDo7x)WtGjcYYKcb_|ZA7xVi25xL zvDF}!feWC!_}09xOSX=1xNE%nfq_CZ;e*8JSE(}YX-XNDz-MktUeP%AFx$s*MoYDa z>r>Ess*8R>wk&IJ9y3iW4XbdQdC@gBF3k8N4JVxZ#-v8*RaA^PDe;nFK*5;=YigB^ zO2V+j0K~a(oTwiquMnHtRx73!lvXEv1_#_F;FojuAs(J|8Mb^|$XAR^pn~RBf5HCFf)yVIl z2?|JUP`;#2{ER>YrKr}(d~EN*J_H8^9`vLZ5aqS=#B@=Y9|ux=AJ zr8cl&PVbCYAPSVlKH={MATFs5V61_Cpki08^kFJ@6MLHLg<2@PdCPj}YDLtPE-EXt zo2wys)0?e~r1q2B(Gg!s_XWeuz3Hth0)B%@nNs33>L+!_|140Q{+{CX>=G_`JL42i z1a_phW!g&3t@GTis1cZk1?hXRcihmVn<)#_2`HaAEN}~~qc#&m@TO#PSQFwHb1&7ST$^#4CiQ9JU?*)!I3JR$?$=sN zTEatmHvgC@$2I3u2%ohVIf&hvs=Sq0rsd-{!H{6A+%o!_)>-X~>Q8%`g$(RSRf4?g^R~l!{R=kuCE_1Co5E6{nx}r^jq&$FwSrMIa z+!ed!{7;&uyfSxX?PEmkU%(BWNf}f#mq&R4whN07)zMNNi$QI6RZf)fIq@)C7~c+G z3lGdx+g2r2?SzU6F&kdlUwC#fGdycDmTU7O#^I^sW>$i;2K>P8KsI5m<2`#!?Uh*? z-{#|#M#>2^g>68#XP?kR%#HSw%u(V|dSmf9;G`W!N5=zxJDF^>BZKNLwX1J2Y`}Gm z5a4LP+tEw&83l=+%n))Yx{>`QQC5%F&IL%(G~qP#fu7^oZ#F)hXRL`?$(>M#p`xfTxs@&H7y%0#LFWpv z8#v%Vvr%}n`13~dm}B<7m|CY$D+dM0u^n@4L zsnS8}m3fo@6WdFx@2ZZ^_1W47eRp7`J}VIB?x9sD{()8X$;3SFnVl!kv9azx%)fvBxDbhpI*djsd2{}@iVgRNpF!jD=Jbq9ga!_*d36>r9#@=_|@3UgLQ zB$q&{8k;mFCO^*4Y@yfV333pd0s5sik6CQX7O%o+Uqv>Ht)VwkJ^V(l7hlGFr?ysZ zusih9!c*vTECSOLV<}cXZ2jP)P#bx(tp|QLRY9l49cCpnBW6LucIC03CcgQq#LwoZ zu=%NS%wDRdzMmV1YPvKw#xqR%1Pigvh^?p}Gs{{Gie>&%omOlRyI(Cg>ZlCY$7WWi z@-v0C3GQ<^A6?KGEv;am0~U@{#~YK$4WMAKt;J%csk&a14bv|AN2+@mSISY8Myynu zy7)S)wWY>Pc6LG=eKL8#b-_`W%axVx@B z&3>18Bd1#~W^ry^N9>`#Vk4`gd>_@-EUq7>-yIt4B;5z1#UA#9OgXv|e@oNi~UNgwQ~W=)|xtmnO^pEg2CH~T#Cl$Gv(@45?FHZNYJ zXZTCnX4`9kLFk25lxq@?{O#1~Qt8-cL^_F1#udGZ>%D!7%yB2MW>p8R)`ZB{{3W)R zzcWoR55Q{=jR`MtsFhwPKH56LUu6HMuM0e-Z=+w-JK-mt15R^0;%Tum^IDjL2J^nG zgH}?!lU0Qy{8HunS){^Y?<>j~Xe3*sb6Yiyczft^2^SL@mbX>UVvj7)udiE%y+WWRuz6(jc5?*iLzp zw-YZ?7m#^AhOjHY*dNgc$+qYS>?fwfj^b};OYT1~%3qXUi9fqI{}vyk)puB+GFJ&a z%szi;yBrq#pQBsCBRG@1l@yxrM&n&CnQO`t^CH{ZR)lK7j^KOvb-W83K}@1Ha|ELa 
z{pqg89by2}NH^))TpZZUU(LNKR9C7AWy#UTx`ek(D{#lwn4gH&$sOI3g{$IgYPnXN zOLojhE46l@o5SZkFFrP7kQ=BwWu>Buz$?cK6bBlzv%}KDtBVI%9_u`n%|$w820$r? zO-=`Gf-5lXa)-AMoR8CJt21jb-%+GHo_`FU=v8t{(y_t{qc?hDtT1ELY5P<3lX!+| z$I4oN|7EF)b18Wxj`#-0kR)Jqqk&(=bizn{>s5ZA;l3OH#W>=)0FphM0U6A2 zg~RCJY&`~a(JiqIQ9*pjjR5!DlbCJP7-_ZkM2*FX?fJrMzN*5?idT}%xBAD z5Aaj_a%;5fsO<^WU+SkiiKg0ja8MlME(?AV_t9fxPVRp>;YriPNtk%Z=^wCGudKXXYbj7k+>VKlU7(hjp${o zNT0R!Q->-2l%>W|+lHw85w9Hc?GV**T!s$0T}mEzPVB8egr9x+%yVL0&q3=2D#Q4> z{4f`})yitBvoKT9>H#xd3&E@-y^We=O|=zQ77e0*8V%KVSRW$QP+l_g;A_`RFdw~@ zrw4rp)+aE+|5Vbx7n$b!1z)E0>^N zQ?u2-?z$SUrot#D;7#O3yLn(7O%fYHmR>pGB6uZS(L1uM#5CqN43nE12B-)Rt9u+% zQ)@9IJQQA;xT=R!6;TCt5G;>714fk^&I}12@kctwfKU80wzvC;`Wcm_hMOgr!bYNu zy=CN0W*vPD{fS#@m9{D*yfQ`@L%>+i;$UxnW`b_qRFO<$j(?0+(;be{Oq<+`^b+t?;_sLq{K>6NgoX@5C&OWB zUwz^5VelQ@9xUW<#IEvg;PP`(m?UwDSe;r?yrie;?d?0QXk({m5I2|Fja?ZJY(Jd) zv5vWxZ7DZPs_&^GNpP1f)4XW6#s3y6XQXl<(Q6Ktchmdz{@fE`Eh>$QM-}0IDxq2j z^HE%DdMj2y+KKrX%N$K5arfeE)_%UdQb5SGHjtyGD#0jB6L|n%!fAY%naWP2bv;k9 zU`I0GC>|&r?5tE`%5$BNld4L*G;M}uhB>-2u}VcIBSs|Vsuexe=$h_fRz9Y-(Fg0{ zJ;)a3ks37;(hzGeBvoKj;QQb&;Q}aZ*5FrKTkT1n*AX5*6!x3{#Msn2;vL=RV*r#B}!uv1hEGFvM(4Zg#In zAA=L&C}S(7stc^WvJZAeHOYrmeXFlJI{i1ZTm3~ni9C+VlmB2U-VH)(06z$kfyv0v z*TuQoOIYu^=sF-)M-Tn2!^gNw>Z^m>wav<6;ge^O*v~AToyh$(eHcbs8qGACxX0`g zq*4Zy&dalOktyde#fKCFd?-g~tp1JS`N??tzGp;ZE#$HK8>i}*F|BzKCtokXQpv?o z8DTtXjzZ9X`XBVn9Gm`Cn?~1mo>zZI-8LY3jPWpip>BpU2l-88`(S-bQTv%gv>SLf zUX{S3CwA7=jJUuJR8OZ>K#qjH_6gj~%=XscxQp(&@k_xUs&dR3{ynvT8)eT4?&kZD zXQ)k%qHJ$2PB_6wMz6QI^&jSWeJJrJ^&4mlUl{+C>bxSdXomY|lt?ULe2IeD-*rh& zi4U>0QMM=a5Y7^FJd@BYGZ`nkyAY$iVeyed4B{O5!D!na@s%1DJmSk`S__lZe8Cv4 zXyRRaXKk$Y&iE=%K!UZ;ztk8cg)?|Zo^V$`1tahiREWVmd~GoQirb@h6IQ4KEE8YL z`o}!uZmDfi8e2j#8>y3TVjH15Nwpxp0Cd@ zPwPp2HiUu5z+#7xBzoZ`Xt(=yLi&6q404CmBpb7l=t#Y5l9uZ9byEWv7sZ z%|>iTXJgkg{<%IGHRY9x-^@5+sJa)1p#4q?rn5VVn(Vq9?E4^d;3~eD&yPGW6a6ae=v*xp6lS>2GPRjX>~pp=6DMsk7x4SD=9)FV zU!^lBQurFTOZ&?0_e~YwC(Z%Ah$+%wsWle_?X^pZ4z`Z^#*lG`hTS_;_{B}-{n~w# 
z5GIR{)nT!voecvLD1{k{A5js0ppgN77xM&ITs`m@`vg?f#mXZ^?q}7( z>ON`#(*Xn=LTU!Nn)@G@%*+TI=KHKg*sjBtegTHqz47_Uc4kSUhJ3902IYGOC3Zwr zsF<8P{#4gr@QQ7veKxmCEzDfo5iS6}QZ$;1Gkke?W~}Bcc;E&1lJ9^&$&JiODIm^E z{e&4#N_2Dd$oT8*z)Vs*&|A2+zI3S^9g@kSxzvs9uF@#?8mxT{<-coAZmii_?F?u5 zasJq46?~-o4t1#Z18bpjz@Esxg_*%8On=`P z?V3_rEDB24+RO2-fuNvR#s3Qx^{ge!1X`g9zSFh;2P?DFdAr~vCF~?PN*fLzD$8N_ zeNHFSX{?cTd6`@@c#*v8u1Sw1G(DExkXRo5rOKK;*=F1h{FD!gd7OdoZf2-%GZ5Su zoC3BaZ>GD6t#K>NKkk?{09*@12npOg-#q&T@WrT#Udvs~WvG!_D5j!oywTCq(o;ig z$9|ig_Hr}ixon!+6SR?^6IIj-U<`dzs?W6sHnEwi$Ol13J)hAEK2P6A zjfJ7X3Di4BAV^WFH3SOJ`aJrv#%Gu66cLgpCj zyk1gc)q_+zKN*Ao8Qi2s)%)ssB@{64sHgC)`kw8C^M&w-c}x_hD~l88juOj;1y^y? zwCAeoAnj9f8e0{krjzxwkjy(qDibTNGF0|fc07dsJUN82#Pvj|x0A0f*hx>0N^wHmUj_bwwlfW2mjprTc_{J~XPGuth` z3afDAb2?G^#XU&IJv?WKJ$Skut)FDKdK%=+Pkg}5(a(dk+9QeJu`TcO*2Y}6ZLjzt6q@05a>V-h23)V%v$Pd zyz#sMe`#HCTBHv9G^HrtN{AsyzIAq}IwiQ9TmUEUUtpHWrQmrZ8gfi5n`&-lqJ$T= zIRQ~03%mMLopmGX>lygB`N3IR2{qLh{4}c%)el}sV$Cj|uIgwpfAB8Zd*5(dTlfT=Cr)sxJPRk( zG;<1>|4=O@{NQV!#B3LPGcWiZ{OQa>M3iqJCNVVNk!K;lS5H!}ck9nPRSM_P}`a;)O zAr-acc1aDjppAGIM@p zZVaX}9l=ueE&Nusy?2p)0-MXe<2teH`3}NY!;s>+b4nk1fj8UVN$H6;xi>+WNT)0z z|A9o`X6B_Z)L%VgBD||!a1^CJVQ%(xwL7~Y{sDMHjyf_6f1h&>5@>6}WVFNj!YAl! 
zm6F_7RT1xUN5P!jiri{nA9_2}*ge-+Y}5b}*+Lh^idkEUaO}Q&&pqY(fcgF{WMjBN zO=T*|pX@E!>1dHqm@7}A=;rqD+M@&E2fOk|nWfyB#9I7DeTv$a3{l@{`Q6RUtNblS z5cgtloxnddW^t|Q+ia{*lYghww|#^s$fntkh-Kmnoz#cJG3G`$Dak4OxczFR_0L-E zOig+aDDM6b4h!8CoKt4LyclXpyFoNx%_jFVVy^85pip)}FNTwhv%{63yMEoh+sm1(2{h zD1!{S4t;};h5Hj)s!x*kkax*8t~kt(pNTp1l`&^+v^_^$NvvT%P#z_aIEn43uh$ls z59uQ;8(H7qkXphg%FoO=cnlQQ2l-N6J^4_~T7O49uZ>xq()~o4pQgX_#u(>=BV8eK zU2iC8Hg1vz$Yp8IaXAifq%GIzi(*t=0j&(T}LG!U+8;%wl=ib!vL zW7hv{NBO}dq$}i(();7?CoQna57BR90Z}4t4j}0Q>J`s6G?ZC}zs-B8U!aqkWb6v$ zr*0YJh>bN<;aIf==%Ec&tEywo1rPYm8P~j@HM6xwdm)i6@vAg+AkR z@c>}7bWq9KBRBMQHRix))H(K_TtD_|^dNqR{RAxUp_Fn~W9AZdsm=yrmYxkKq0zp+ zAWJ|(x}+>9p|Le2{flScLCAy_oOx z>nNL58P#`q>EHYwxP*8wQD_!?3l{Omz&TY=J2U-`ex%bT_e@4&5B>o#TjSiRh|OmWd5vws<47>OD?5yAY`WS4S7y3l~{_{ zm)?~4D!cWv=(JG^_3_wSW@Qro823D`3I9e} z#ac)dKQgD_bYr^igDv%?-XHu=I6U_y`+=?I{cZh6GY-OD+SSW_QIq&kJio2c!n8YK z<;Y8p+jJ4NzxGglCLgs|lvMQ+TtarK&`^nS^ins)@%oQI2>B`{iQKN$5+;yatmgUDMx^1qF*`giG>b_fw)HbIjY!`G?3CaWw6ELZ_gIpUxbL>@J>(tF+%Gq@m+< z4M)B3oUq>HB6R~MPFB#CV>-oC?ls%aR@C__@vb>MkcoCnwU}FuD`p!jMb1hqLU%X% zITOGux^Cuub(B8dB+RDXr)n$k(zQjY0R?qeW*uIz9&p#?i>_=UR3NETqA#eW40T?~>8m=&avbRYT5H?}UQfzGK<~hsk&Iu~9u0;1+ zxh~b*`apKlU&gofZ7>Q&#tM>oU%YNs;@c7p1HF_5{?+tcVk_B^i1#+Y-IrUUx_~&i zgIg_(OQ7Q7l@T@Cfky@O&1hP1~*X2a!Y3=X8 z`>I>%7fCY5=^SYe^3&da=($kFTB&Ye=Np5NU<9=E~azNggpu5?fnPz zHGC-jR__6>NNDu2 zx@U~pOe<=ZM$?2=*l8t^C!r5)jay|6()#6ewDMfbwITd<`U0>r=4D;w2hczH1Ge$> zZY9Mxhbe-aI*thigv;=Yx?ehvNmHawv&}ODQLoYqfIf6td5y3>kSZ1ALI^M-!H z8j!T!SyDLxr^dz^OXM#`H9FBWGL_va_+lKMe z19tr+*Cg=R=Kuz?HLDbN+PzEL7N~{_Nj>31%;CGL6$QUgaGX&JJQq+?8G#Ku17?jOs)i5dt>o9+PO>8Y<=71zFEAUsGsudK@Q9hNS zS}|qC+V-wYC_fi(IX4h{m6tGR3$V}Qrn5KHyZTmi7}km4aJHmWd)umi&_cN} zHIc7EE*4kHm&{#(yqHvHV}7Bg0?J!CI9%OH)?&uvBu2Vo_kPa#CMOy;^v@)?q-t4w zO{$7#vUj1e19x3iWCXUQqNpRyP}DR~Pk5Ogr)c(>%yg|hF@(tsT%f+tW#kXyIisAG znpTIs3?Vm9t;L)(9-7zS^>CJLt`u<%Wv|1L{$9izAUQi`A>q)kbyIMpt!zbJe9! 
z(wXSHG{ybSD9rotY2H@8f}Vy=(|!X_AcTJa4(pnlD32iO#SbJJgNis$Fw^zRBL=23 z5x5=cqWTJKw2p8Kn08n*XEI&h>5@v7(((ruCR{+ilwZnKZk=&i%QOC>B5W?<^sdG4 zprcBQ_!@52flp=gKg>5scKrZ}d}L}0>JF+qDrGeBO;iHm+r?OojG9VpPzMJAvCz5! zQsqnPUS^4Zwswg0@~e13@1t)u>hYVHmab7*yW$Rzb6uB&x|u^^f_-XsHJoycH|Eo% z($#!tYg{j2TO*9CvXOQex2+4pWM)FWHmIff8tPdttpT@;eQ3@HK}Qkm1oc!e6qN-k zph#jw@D5tQX9v&Y_S`(oz`mt4fCYn-$m8@>qJo~nu8LmnEbr|_J*6~di24MrQjdbO za!6eRBpN2(sPLGbRX8}6JEJ8W9_)Q-t~BT95uR4sTeJe?gXN;mbB*Xiu1oYf^&3{5 z<`Y@!3u`Bltp@!4nJ+>fJC$jS34@Z-H}Zc>m#{P=%s;k%tk6(=u1=PvfP4h>B>cQjUD7KL!8hr>nf{WQm?aJP7R6Rk;I0y6aUvW79NmcY{bDI&O=Le0s zAFgU_Gsj!kU_J)E3)rY#egij47L-1veb36sYC~Ni`-&aG1Q;Qmb9`rN@@{Sc|Bl*d zeaG3pd3to*;UK;rxU1j<)q731Pd6JdS5W`W8Ny3g*mc_6hJ1pi@`8XXgvW%{B16vPKPuO@sukCVztY z6JNB(Z&D6S)K1Blf+dMdnT#t&Eouym?ZE8;XEe-Yv94iSaHzj4+>_ZlFdZCHRlTRi zdMY_X1NDt&-m`35qp&s={@@>z9|QMXv-dXPYQf$(b5xj5AYP$qL6#kr7y=$diNbuD zl_zu2%t&p#r-1zg(>}P6$_@BYtGJ-zimeD2gIsYmU61+8ZIVgxKlM)PeeI`mR=kM) zjHUG)?-}}-GbG^wvopGjZK#$1P`_Pm7>Z8Tu4na2$*&hA8-qL&+V}V`Vp{hbzBNC~ zoT+Zo!mPruEgIzsNxDdMlL^?)I4lhaZe-3lizYr`+LGnX`Yw-kp4|hs`Aa0_(f9p5 zyd7{FXmoCzdW_+vR4x{?URL-$2J5dQwPQX5U4FoCv9)Cnz_M2LTI1|(_&->u=1fu0Cg$B~WX{oC}am}V5g^(G5yv(X-S#*Bi?d>5P-ZQr?-EWh<1 zn_zCmg!sLfO80~s1S#W_$f}v#333ETrLOB+=-KvI-tOprI8D2*Z_m2K*NZLRRwvHw><@Bq=jcSU9r>0$LaX3Ak&-dU+7(#KMrlbd$-AiZDM)<=l4gkeFz> z&1vRi=OO7d{z)dd+uN2dpSueu-5csr)+}X-{tOfVE>=>WaR0D>W*%^OH1lHoB6v?5 z?ah@<>bO#r9i!|N3##$BSMQEc0{?v%flKil#qIiA#7f)YGMFVB`ZLs!V2ugFX=^AP zqV83vxF%aKVP?W>D<41Cw*zYroFEBD)nr#m=M3WtQ-lpORw7q;-@r!PBG_MYxO*6v z;R60Y^M$oPCJVorPbX~`Kgr>N=ISl#i+UfD#Q(@mpigj5W;>wcYTgt8156UJ-;f9%`dus&q$?xDItej4Ql@C*{&spn~7U*2!aojcb z8?B;xqdDkhV4cqi7HT-3;;YB*GS2y%=MFK8IKT1%`KMehI!C#r-GPOa)?_YEs^9S( zy%Z*)KJ41SdYpW#?;-f(K9_4=dLDLl#se+*j|_7PXlB+$b(a2=IBr_F|3J|P@_UFK z++eg#iPfe^qOZEND`E{%6dsNY)T-sJhP9mrN+?o6OK$^yr83P?$Th}Z%6HFLnl%&D zF0p6Fb!V0yD&Xnv z?F1h4r_`_ZVcz{}No|lEm(W)YCU?kd;c7t@9 zJ+pne%|Mhn)z*VAgifdrP|MV#mA#G9N+IUHOi zwowaeUBFf`jK6{AV;-U8zCy<;)s!LV1I0@RwG$-m{HM)i!s4!~H>5GJwN#L=KolUB 
zF_(fx$Rwc=JnN(6wi$a6&q(9YQMEu@d)I%F-X>NupJ~Iad14m!&_?_IblqiQZ43Eo2dP%pBkx+|6(WAaY&NH*__1rkk`9+#}tC;@5)_+J3zN@8>62 z1HCn@Xf?&&1Y0?JsGKkgUgTP`&t;o7hZ}0_hOLcg`c@=Hzos^XyTCSPr@39O8LFke z)|NA^lwYm7dP&y~+`QVajN&&}`kP&-&-65*I-9+~RXxT$=xxEy0{PAaI-ML2!#Q8LB@4Nx!Od19G3>e_^$$N|hKX-=+`B=Ap5MVGw!67b znu#;sPr#c%kw_&}klPqUqpsc%-<@2C}|pnp#tBKvif)V`3JN{pjDTqFL=PgA>EFJNme zCenyl~B=&kB9bHq!I-cG9o&db|(?qVj3&#?t8JrAFK722rHJ>|yI zDWQ34C9yhHf*_ST*f2CmmT-1`oOheL!5D{Uo5l5jv#9YuJU|HW8ZyIqD zsF<5~4?YDW~Yw4qdOy4F=n_iDYfunQP#y zr#p;t{Boui6T5qu)C>yfGuuf!gJ+N)wzJo<_wTIkHr=cu_0fj7r-Y-^J(Y| zEEJ~jz1jAB2Xhv2KJRE!Mf8%o!4;P;=zoyYlzeoZ?qwwH_ky`Xwvs1#T8Ljf>t=Ug z!;a0`N3Ud*;HxwJ5WuI%D zTo1LeE=bwTMo+4k#dzIAZLw4$rXlo8346DBz6q1DiZ)AlrhC?Hy;}T-*^_ z$KHqQQ8Bg+)lT~+egjWMCA5flDmhFoY$z>DHazSK*RU$%{l-;psd9yBA3jGE0aszO z+%dsofQ0R|=3xKdC%Ul?)QWfpdsyj%IrJ~(=^7W=!`9KFU2o#JaIz;R zzlwL0JOGUmtNZpqlgElQBIZ(VxG%wd0aighDn8`F^-%B=Ti*js}Q0_AO-7pu3hjyW&F!$iD%9%Whk5aBEY zHhDtRPq3vwNhkot`HkVofi%&Pewn+-XM}d``q%#uu2fFp^jryj3{yrcPW99}8Uu|f zD2~n3I{2!?QI3-GFj!4n2#_I`8=mqZ4&X&USLwPs7}U~+ zx}HW>5dGc7z1975gpTT4tuAVc35>g$cJy#xcSjAOYK9lq=ZeWE$Pus=c68q1yXBvB z)!cOhQ}sUR?e#zXg6ACWV|`~t5-r(!yVX*{J22Hlww2&vXaP0Huz^UyF<0vM!|OGO8*1>q^{Ag zz)E*rYC`rcnqr@Vq;QUOEafz4iE})sVdvm-)}y_%wyFEsS;jgwnD&cz0V=I*5!bOL zf{R`C*x|mF#%<@7l;TE$fr)G00iJ?<+&fFiL%q=$&Y?t6uT!3|J}u%-6X)xb&AQIH z-WSFi{SYQqj!j>oek3b-tMQ4>y`EYy$r5BuYZR$r6b%hz5H(mIYz+5b)5;UQQwJMy zd{lUyRYXg}leeyN549JuBfE|9jy%nLz?0Hya!Y+G*Vfu;wnW408El~s`r4SIz;1VM zZ#i^RezuetAV1F>yyVPHXKC8Kf53Qzz1|H4sdl<=NDNH$R- z+B0_%-+(6Th$y|Af|G#WmgF`XPlDvNbbseU{h~Tdm0btrAMP9GRNh4IY}?s#>@s$6 z%zs8ZxRKiq$K(Bmpk%PcIS8t%tEs`Md$SUfS2FjEAwn#@ ziK~+JhHbC!&zcyy4h`vrcWm+}IGx&QwFV<3TCeXPD|9ud35tI(sA+ynWxX}kMPLt= zYhKg-OnHZF`3+!qrh$H!jq@+ox~l`dU#&sf8GXGz&X}GWD^6z>1EUxCG$>2SposGl z?ZK&AY;NbidM+E^K#|nc{NJ$WBLZGoe@IusThNmDto$Y98mB{s`wQ8^8$}ta;+0*8 z^s$&nP!ByJ4rDBXB~_DY8Xo2)-2Y&Xs7`GoD=VEXr>9n=s(XRUW|v?ow1O>f^#|8Q zTWU?CJqln3`aGouJJdX_s&L@JvhEg2Lur4ggbR4)U(o8_8DJ^uYgmx(_j&$rcvgo-FyaVut={NKBP8ApB(vk 
z=k&;vV9ecIJv5P(y*+A7Mq#mIcg51Tks3WJ4tcgHSR9j68H!SY=-_o z@_-}oJo6js$5-%Z`dK~KEX3?*9k#Z1>J6->`jOnrdKDqcyGt$27v(AzRJRtpezHo! z@5I~i1$zWAf7y*~Z`i=CsCwnF+aQr8Yh(m3@!apqBh}zK60~v+aAJPbTK;VZE_$ zoU#*s!L8CWwsv{@w7SInP+9eCK{@TJ^wiZG2#$Z)YitlI;a$GN>LKAf{z=h0Xw8_H z2bOB9{dF?h<8yd1XRI`{>=n9@zwfLA9M*a(4HaejNXuZ!K&!}RprgV15?bqa5@Wy=M~*O647e^fDnn$060%yv}&3i_d4ZBPZu| z@D{57q2t(Ra8&CRYGb}OfFzJ}6ZIJXINfoPV!*$F9ak0au&ph+>t&hD1&_ zPP$wheDnW_vmRf(*BVy(+6{><~hEFV}xAC_QM$mO&==FVX z&?nSPYh>Vyi&@V%6+BAGRGrc$>lCPJ&eey2+1W$*9U|muFfp`Cn$EhdyVhm#n%XvO z^6kv)>M%=5o~FNN+J9aVvO; z$&~Z>`+8J*kSQJd+leAOt@EzR?48hw@O}Q8^2uzC{YyV{og6pJ8kk|`)OJ*_ES_&vEj^39&auAp8?%W;SA0M*;EP^=GHMz#i7*AlY|H6d$qFu@w^{T}HH ztH_e$vqFMcwur{%oJDW+%SN3*7ovJj5T8Hq^KZ|D)3z;3{6wun2 z;;U^oOR#0v(Fl|lSfRNSC292WsfOYc^z|UN9bb~A*wesGI z30rB@04=Rxllyo2Xv!ySnw#zVBwO*CvRD7ZMZ+w$N;pP+p=i{8YE{rHwY7bV$u_Hj z_9mY+04|SQHJ5nHXbe-`SjX%$%BeTWb6(a;V(v!n%Tl z+D0v9b_Q0vv$Jy1c4>TQynTuNHrtKPw>F5_4Z*z$z?kL+&rK$7${+QAh>_&{NM+|7 zCEqH>J@ii{zcD4KJ8IS7E@d#z0aoGXuxr(O;`zWe{U=Boo22??P4k%X2Yb0`YvGlv zF|{YZM@$E?0$0aXo?qsoILz*&a_K(0-+6%U@9PQf=bn~2;5O@YuAMqc9}m8x4eT^6 zi9h45X(kbE@b4Xza`h_u9=MN4@=c3`tX~|@m@U}Q4z zyTLm|f4GJ=0=*^6P^BZbNI#rlEMymnA7CWc z0bEm>f=;TgU-tjw{tc&+iibv1)pI+dHQ^iLRw14(2%aO$_%8moFVc~3DHB{(M@pxtiL-+-YsapFHU7$zNyQ;Ent3oW9KRd0dkt(JLrXCR-sig(sY)VA<4Ye6Gr(k^fKd2X;f% z6KWIjY*}qC&I|VC@>!SG+8hUe^UQ}ibQG|eV+@zKq%zD(W^L4X)DPs4^{K|(B2u~stxZLqvUZcLy_Q4kZ3VMI;B6?}IH8&X;LJR*Qe#$<@w~$*9 zrhjd6M`g6HoK9PVNGUuB{iehuCCk${`K>cT1s!-lYq9)#9+R#!el7DV#Uxx=Yw zhSf_iYW}L;l)ncz$?2dEwToz~_UBf)j=Exqsd*z=*|S|rL%OJ@LfUqhAB0({*sa*)@LxzAfDb*H{z&op{1q zl)6W5GETr}d4CX{ehev7005xuvYMhuQ4~=AFzuVS@k zMk=crHe4SYe4D$L%Xc%{n5YXK-RX;gOY=H#hSjmYe~*gm@<(;ek=ZH1}%6%_|F!q+qZ0Qgs-<7XOv zeRb#n-V-lluNlRa?et=J%X=i~BsY5h2dZ&RO`q99y64%Y`HZTe4>8}VGCc1+i`{P* zVI6%M=KGEgH$=-ob$x!GzCQ0?w6?t13Ie9VQ!n|*q{tPpbny}47jdc&YZM1aXX#f8w7q%9A8c)tfk~vH?PJ|3H3o)boh4ERN<10_PjAwkTIxBKfUY-}r zJ=EWTM6soC7xOxonS=5Ebs{%8E=`%pm-Mws?ZhoKuOnXG!;eEHgFycbH=<-=BQaR- z;w_V1B0}Ik*)%#r=h# 
zY}@lo3hDmi`bVOfuW@d9Y+N?+8G=Z=@KM-7-E1YJ%Q_?93jW9+ zr;jq53SRh8FRIoh>j!h0^MR5|96!bVHQbEo>kTTyxvrUSv;*>HwMjUn&EyJ+%l4wu zh?D~OGIJdACG4U9^mpT%)8o+#oU;3*9+INnq3qK9GS!Z@6Q{I(zI~uLEkMbbf!Rx5 zeKIV%GlRGrsfRfZJ>AzMv#c1kF+RhEv_RO6o2P^v#or=0qeNt^9L~9iMskC-vmpmE zkykwu>?hu^l}zmqM!7K`OM3~Aac>!tpQ`x55@oEbc}`DxuV;ak<9`d5p;G!IxQ@1Y znhNLm`N6-`wPEOwg~cfsR|fxGuTzS9dXbmJG1?Tlrx(*KXkSimP@U{g-saAFUr3vU zcHt`7&-GexuK#fGBhDY61HF)2xrTWq8?;n~djb|dn44b@+~ai7ab`tuHlWtpYU!g24=Bg4HUR6=ZXZ`=*8C%mStr zF;dv;x`y)k8N?fAFZF|sGe_z6rLXqef~1%6m5@rX!21)t6*@*#5vyn`;kjU} zs|~(OZpiNnYN;plUWh|^WJlPKdP=H6l>v8L)yWFP&-yX_E>3Ma*kZU-k!k%b5#}4S zGFuJKW+di%=&*2xrs2{^TJ{K~H)j9uBoA>9_yhhJ_HS~KStU2siPSmd3TIv4Hcr*- z;t+w))s+wQBC(CUB=dvEm9!77)Kf)>HUu3xtEsPXXUQ~aecm+o1lL6FOf__`(`n*= zu(Y?P?}>Ip|J_u`ZLW%`=`@|vgK1~XwszWHXzKyZEXgb2#&dPm%lgmE-`pl`jegnr z0NV_kOGCjwTq$`p^CGvR_Coh2{N(?kx2Ai!YJoe(M8_}sRU+4Wo*#rBM5u6~v6|VP zQjQns@9JEVWFvBZN-B74d(0G)mz91_Rei6nu+MWOhWDa%LTNwG@8SN9{0>r?4(OP? z)9jWqU+7dU=_j~TqDMSmjKq54`%_;TPP zW>L;G|8y7S0@h-FHXwb8M(&}d>^`jz=nDtqTUDM>kDr0JM$L)O#zIF8D;Wk=JM(YG z46~W`CE4ywp@(6|V}URi)sW+@0^c*EuQ`sHk$;XK&+N7w*wdAfa+ZkFsyffBG5NE& z{qC4ZDP%X6@%`XtwILG~tdjoLPw4g3nXadZHOpFm=rNT$o2_gm^&8~F!@q+t(umN3 zn>hut=Scftdw4%^gB%dz?0bb2Zn8d-$VjPY_6$78-Jjb!v?QS}*%)qvOVuv&Sc6k% z2I@gewyAri|BOCXq27hN9JXumB7P_c=A#SQFJxu!|ngxLJgK1AlCN~)*-KfyUb>CKG~04KsExUz#Q)YNBNM?x>vBmb`cC_Oe51< z6L!_3;0J9g9h(&Z^{le)s?sr~Gt#J)jEw`{QJ4` z>nOurh5U8vnDL9YIV5M?h$S$KH8U6A=hT!4!C*VIHA%Raa|~~}x{5Dc6rTO93{SvW zGFRBCK4#0&N6{eoznm?h&YVkWqs;XAvVs( z(X61>XVSn%S3CZu-aZe|r%|h{vO1D`TlFH9{RR`w&rNxQ4Pu??Q&Ml7UV4sQlz(V{ z`M3D^hMm{} zxgoY4(ZPDe_5AVV@PgO5ZNYMWh_r$o>UpZng`=S9JLdb$6-45|WVJ0k0c#Vb(2P*u z)PclY+Y4AO_zsj+Q_aU#6HPbVR)M%7{e;pJ^fJ1lS?M=(qY~E`bpYkgN1Qjth(ePr zI@l_7z}uRsV0?hx{e!L3W=&2vw;1>yt9IAgh}xk=pp}0~vmb!xFfBV zS9pi)BGpd%j~>J935MkoYA<#gbzAFC49vEZ8~I-hr!AE`4n^00#720EeTWlDTjcZD zJ{?9cqlfFa*uO%C^&k3R)`I1wRiIcWA-n%ul)Ni6`oV zKy~V}I7c6_Z=zTgBn57|i%4C}@k*5b1eT~zFrFI{e^45TH`_P! 
zt`aef&Xyzk%X8uE;GLAgnnivlHj3ViEPW)oQCOdo9{LYIq%2e-W;#X}4O0*OskvYjekEA&hBlW;xBQKy?3n)tm z%1i!l|3A!VFb`z$Yv^Zz-n2_ui<>8Bjq_4$9v*5m$BR& z$4#}$P#Xd+0E|`ee9oVd4%~j+T%E?x4Ae)KEM?_Fk~@$zkv;4CV%YNwnw^8af;~ZX zHI3>d&h}1$MWg5Ql|oyLb=oQxQ})m`=Xi6d^K-6%GYl&OA!#I>z_+6d=}F-N_^Z2%Mrv?x37hdS+!od%=!9PY#arelJn5KbbW0Z0a=6X0 z5LT2UaFSVQKS`LLeR&t9nZkG@jPr!i1T8IhU*xa_z9W@2jT2_dZ=ii z1n%TDFd=LskCD$(?Fj(Zu>tTGu>lN)3;7*-HRd|bqzUp-VV7%%+R~qj9xFA?gcwbC zP^PVkYo}E}>`1wnazb72?;x%KeFZn}gj|$W?XG`XQWqnKA8n4v>da2p+6G^6ZHxpr zR2GJpIe!Ws_6(sm<5Xh+OtZLPwtmd4$&L-pG&=c1!4mmBP(QkZU(VVo4UFslnrZ2&C>6N6B zA!oQ@N)zL}yRG5~Z3z!zDq0TEl_Zn0I_$sC*gdc5R@^CdH5*cYxEpxiSS7PcqN}0J zqCvjim#Va+$Kxgyu%56p)LvkXSq*OB#z?axH~Fj5Pq-65Mh-I)CDIK+Ll+2C1v z@4R~YUOLU)7R0*;(Wj+52A*`|9Ov6W^?!IivA>KOL}rmQ;N zznc5t8PD~%ekV85?>r7PF{iTeRa_}<5B?-=W}0hF%oD;id5C!1tZHz!(`1hQhgKZb zQTqi7_#MhI(A8=~m**wwDc6szqjV$wb;TC6;m;G2B4z%-B;7PfeHQMsE@3fwYL-W;wh8KCjb z-}(XK+>Zh!y(7(8)W4WeR0mAccVXAJrFoDh7@(a7`17up=r!`u=gBTwd3lg>G$i8x zW<_G2nij5X%u;JoPsGgd5Un~lo_`b!lCQxq=8bgXj~hZ6RKn5BRm|2X6hI>am-J+_ zZ@9Z}C?BsLC$>AAN8f-s2}7|3t3*tpaC6xnsjYTlCg$dcE2_!f z;YNL`k2`|x1wHjxy6B#3j>U8v1KR?$nkE5%X11czVGy~-w$*kx`$%Wl8mcsRU`D51##655FEj19vB*z%sK@ND7&XCua!;>epRpJf_Ttct`rW+kDo>#FgD z-NnbzBiRFV1J7I6aIOV+9ZrQk0^ZOj;t80UeK~hIea_uZi!;jPiTo%3TX_OmOR2== zdEMGH=67NaZqqK8LV@k(VJgj>scdCDT&%f+t7R{q$YiY2TML_ryL_f^nU$^%eaORRwA>EUn<1;X{MRkP>P0rx}QFS?0|1j0aw0rfwzm;i#n;;@b`KZ zb0Ym!aEA7rOl4BEr=gcdYfua|S2Ovcc&bL?t=whlQ)n8wHLz7nBmW|JcX_&eaE7%V zwMSp+(MAtYLOp?|fZ1S>ekU?fslm)j_QHy;Ev~f21@PhiU#ue54ENc+mDLH6Cy`yy zN9?mmb&d1||A+RYc|xpOSsd@0jW&aBY6jfuS*31a51EyGb(QDhPPG(x$_*sSsBdus zj|4IB0=qi3oAn9qB%}RR&`Y17zsD^040$;^g~q6Zo`H-d1dI6{Wz;J1Ty zuq)S=X#z)s0p$E(1$h>!z_KQv&(On(cvUoKOU3*ROlRmcy6GI6^_x7+bio>&ez+2h8vLmf?`21un? 
zO{;-s%Ehn)aTPctZUSjmbF+wC&b%44LCxCl=uao=pT!K`lsZ!fg=gv{{|T#;w*en* z{-JgBT^Az3q1r_=gzYz_oad|A^mW z)Ns43baoE8PI&0TP6?deIHb?h(t`iRMTL&J27#tPOs*FgZP(c}{cFTjQeci83V~M6B1%7Xa&8w}FXkU=BYD|etoPN=8O4AlJhcy%#)9_pN~1+; z#5j#k)4SsLMp%%ZSzN0RyJDxo!u$tpN7od@C2pne+L!Q#avQt7gZPtQ42>i&dQ5an zX+QwAq1A=)t4i-X?1 zVHmj23{JaXGgSLs}B{N4u?yz7C*<@13tPu|9GI zUC2y;6Pa4!w4{BwVg9qVB2dj&7N6Z~X%(Hf^sr+Nx@6cR!z>AO)20#Em>tq9aZd0{ zb_3s~+}7%1qBwn}saq(IDd5Y5T1KoqI`mMYj8m{U*TlTab<($mKZ3Q|PZpKB&(#Lf z@UA#9H#XI#PE<%&6=QtTTsSv%OY#qayShlcfjbXd;LY4hxWn~7c5Wt0`Uuw$`Re3QQFKD@ zNFI`|hwoxF*%##uYlv|FZPZ8fF`u+Af!)m2=znd8J_~UR$EYs2fw( zgSFuN{Pzxh=2E;1CTpE7L93vC zW_~eS2IxJI3%t<$>Ws__u!F*r#=Y z*OP8=D1Q|2>YI!*$fdNzjP3W4s`iU;23MI;bXTaM_A{O(7U!n=e{n7+UJDfoJAVz& zp!;&&K{Y-i3o0`CKss7zRWq-1g`lkXoL-S4FdLblT`$>xl?!TjzD~rg|KUDqv;ohB z&w+*boI8kZ%C(_>4UvJMK0OGSH@I)|l?dSPHd4d!$SwlP`n$?RC6& zn87I9pK9Farnv3;>6pIyOY0)Dgm_Mz5#OqTW~rEZyg&4xe}S2)mLN-lhHzHqYFM3e zyS_jY9Z%Wk?Qawd%%HP^iIwIf=K^kLfr1*w8Y*s`m?fly6B#OA&Nu&wkHJ zwPtE(u7U7Fqrfn#r(9Qg9J5%PhGIAsZVhKy+h7kiQRt!{6YsvPl?$q3!~`H zR$rzL)t8yT??!hWVY6;>f>zz^B_Ba`g>`77@k%=c4ta0;1#ffh9?qx?^b@{)>To{B z52!!k8MagU26SJ)kMhj0biufbjfp*ogDscHs zk?;m^if%#lmAKS*?1&>vr6TZJ>i>7m`k^T~$M~&!@7VK1J-g|1gNKkJM^L@gS4>^= z1=`^JMNqUc>U`%Xp(g(y*;y!J9t`e6DeOy}5bo=pz}U%=G`8cIIrLC@e3o8Nln;R# z$`HUhKI$u|)rb@?F;C5CuFx#4alDtk3fs6>u_64Hk+wv=KpOlbRI&hC!>F>KQi|UKnGNX(=qn#A&B=?nyN4hElxy5Li_MAZGDeAml!T3TPw%$tr6GnQXoO@A!aKP!{y2=gs z=FJN|5xE4@B+@Um1TD6YqVB+CG)fN<+g#oKia0K111U%eIWN8C!%-}eenOt^t|QOX zKB`UVcm7%ETi`sXi8$B{9`frA@9BSQVUXztYNA#i&WU{ESNi&Z*$e@;pc?iFuZr*4 zDds_a2R6%$I8=l_U@yaNaQ@0a0kQhd;9(<`RPbEHP49BtXZL73xxX=$E1!PjtF4r< zTFV@D=uidaFL;rQF_R-tdmm6Txk(oOmLdBhN19Nr8TgqI9pgJ^n_QDk>S7`{q0Qg zluY)txj0j%I+~0>n!O@!Hd%9l@Ff(lRRpcL(%xj1#q{j^s++GyrMqggor2$}%El|S zBeMZl!V9@G&E`T+`yh;7V74_ocGNCn4!uv*gAN~ zUZAFV9~ec|a|z8eqhbbmSm7EwP`_ubZhC{Bj@`s`M9pXhRsxAxX-)IC(-(Pqfb#UG z*qQ1q!uknjZ~7Q~V_R>Y#y6nF>_I&PZybv2EfojE2 zT$y?^xHtPNVMBx2`nXY5#GKA;_ct-IOD_^m#iRmmv|PsZM2;5bD=orXVJod0eJ14$ zYD^Xcn-cY5LP#Z!l6@%M+lF1}3Aqbg&8> 
zF4OA<&g7hDHU>*2uYoI>+QEt0Yh6B9MNQNSAp!4rUNHgYt=Px8Cs4| z_-r-GJc4}LO|%u-S>>KBNlhXL(@9*u(*x{DQ_)%KOYTj*BC|yc&=G7~{6KnGlvPWw zk1nX+)dYQo_n5X-nIX-jD@$K6k#%omh7Nh0cr~^fLHJo{q`e9(piCxH|A}kP)zBx0 z?PP3Lca-f|5U`{vaseR7a?ZEJ0$*M*i_S?oDbCiq1sh5InBOBmNvAVmp)sYt zNuNqo$EKyaW(T1PO2XvUW05}k$865uGqbez2b|8dAv3w#Oq94Dv)Cqk{Kh1DnK1?I zF$R)fEM7-6EBbu~-lcv)b=6K{O5Es;qqp`+#7h7G1Qy^#60{@J%y>nDGtL1Zy>O7#g5?$$~es}tSK zL*51N!l*z%SOe1VoS+}^IE11H8~3p7xIi5fJjX9Fha1I-3w)MQ+I&dmrk9i&SqpY` z3(Pd!WFh)poQdx5*aY7Nz7RFQ0BeBRz-Yj|maEcdwF%nO;27LdiAMnSz_ft2jure& zqBibqI_bQ`pS&LVH!e!6FK}q3o6KCO_dpAwGzJa@$Cej#IAoKh--&U5O8M>mn3yn@KGz;iQW?$^dnJY~aqf#xZ zKe*wYp`MO3m2tAw7@)G)?vf%-L9_8q@Dymn{O2tp%nizZI&lJjvD81PMsgADhkZ1= z+|6Pl!Y*qA1#!ORh0uoGB6q;tk_O;wfKiv(uh`prMM7j>x+*fi8CqmFc2iV}>%>Ur z!zM&9Gv@*-O@DVx=1E3`XT*nOjy}LlF$bWaww9~pX=<x!($Ms%&V{T4 zb4%V(efVR_RO3%iE54hzjAZw{fYEICa5vEN>^A>#jRjrVA#j5a(~YE&&NAvp;Tt*M(VG}$JmWI0Vqq4&XFjX5 z6A1i7Zkj)j4#6C?{EmYB0yZ1;w|*6u;EqRkvz~T7bF`kAm_$KHK_;?F*$vwTAso%F zjzsekuc8q-OJnDvRjxktJR~V^^nB11%`$46PBE{^y;+6*|wVAwA>3N$r|ERc(mQ+d79OOVP^u`_#yEt9=oH0JrxXtY_}f*% zdDaT1)IxR42bh3;MXkP9lDnB1;HeSA%rQoXQ}uMD7|Z#RbXRRVRVT7q-NsTVCS`J7 zq55-*L+=g$u$G3e@R!p4z7Fmhy29e~8C#thjqL_@y>wnpIyN#>sHhI$5`{`)Mf!>S zCE!sfa?>n_r)gxZ)G5~MtL=Hr_QFrHO2j-{AM3J_AROm!^PMBFX=D_2HrE%}>nJx8 z+q$u-*G%HC=~r;_wFdiL{$adU%W$#oWcrN%7gd)f{ks1@*ntl~JKX6$numnTn#DFZ zhr;WKVupvO!Fs+kOc!&$Ny^vR33%4{TrbBqioC-PvnAn?asd(VJ{S2Jt&k^^5#eRh zTJk4(nc4_9O-=u^gkK`jL{kC@&z(b+ef(9$$7V^p90LSVyGDI>H&DG|J-D9hD76ht zRVITLibl&gYvbp}!fYwZMYFH%yQzw(3DZ`uX1-9a_(zy#-W)JEcs4vpO$qG`oDx%s z+xkmqKemnBfcibMM8ANl8i;-z-_%>Ue|_pr(1?6!wG+RhdCdP&be2(4BW)YT-Q8h; zVYH_$o=Vcg;O@TI;xf3qySqD#x5ZOQXK>eLabDcreZTsCa}I}d(j-;SbKloBPOk{6 z!@i)G|BiOm%Jtl&Hv|i~b`TBeYyNgArsFSI0Q-w*&726-Gsme-;4Za_5kQQr; z2gszcnM`KyM_PyzJcmL<`~zWQbFB6_c#q0q-jTiNwt@EA1HA#$Cbz%Ui_J}I7)daz zv%Qsb#1C~D-Ws`IXHw6(OUT?)6|Q$ouX#Y4O0BXFD`Vjok9PANy#!{h0N$^q{PM7HaY+22|f#{}ntTso+t zEuH=rCJ679UZ@3HK@*G=rZLEl%d4ysjz&~tt5Kc1PcK0|awk6uUiIxX>Vph*Co&yU 
z+F9osLJZ}CqjFpIKzKUPncsM6sshu~*Gx_2x(cM{GO=1Ng-1s)<_fcrrkHm`hzMXk zwBPE!>`#iIcksna+vR_-(}xCtF8H*>6#6p0nm9uL0;zh0?Bw<-ht(?fbUV|()|!b< z=?|1$svsNQmiW#q+1l61tDZ&F&OxXj(fn7RF36H;M)Iu6#*w zA?&AUu$P^LI2jz##T;*M6{;~0)Pl-_bQPDygyag@#h~SDL^Z*S`O1sqtoAs?Y-D!j zOwYE6zr`f0nX8o+(rDChsoY?UHiZ7E&l2O=W9nhM0@ImKNGcW4xwWg?av1fBc}v_Q zJa~TZ7puFHm3@jHWVMX#rvt-{6(++>Z*Ye)!%fK2W;OnhcZhH*vJz1qm%vJ%1?ETa zFT2`xmG0wCMzqVSiCFFlo=};CM~GAONb(!@S5_^@Ttwx{bPq5fBf^U2ShI{?onICj z=(z)zfxEDX`6YZ5yi;3-vY9w#1JjA>FYF35^{o$yoX0=;&`V|%pvcDw?TtD_U*Bn1 zG*r|&XD*Q&<#dF9)R8;?p{^T0(f;we1l6Y6WwASb7W+rN2Ryl>T4;Ta}VGG&$jfL_NZ)@u(8w~wMcaWy^LG!VeY<^ZQTcU zhhcNg)bPo`KD4PG zY^7g}zeiM!bmmU0MSwd}M0e*%(n>6gro8XSEyNV!n>C4b$a&#NWaAOl@#X?+A-m2< zCUU%u%r4+2{vRA@xs2h;K(;Bp4*m~z4E3N!ao55{zz@9wIEwlib0gE!QeBe|)Pzx> zC9Hsc^M2UDe~NDhVuG`Xk-|U7%asan{kLX=^6nwAeYk+KS$t>oG`6^_*8a@jp{JPz zBiU?A*RkL_dN4?r%R7(BocDa7zWqj=?Ro^~?Qrl>f$zF*o@bh&e0?7Ts#@*IK;!*+ z*d(~n8cIwIm6eZk_2uaS&k~ZM?g)zN?Sx>zJ zJ7K>GN1|JLOSiFkx8(atBodABE!N$H$#AR@kELlQz95_s=|EB{Zh48;V52eGLOB46 z2){S0=dju-tOGehOM-jUYLP9To%9g-B2hqXC_9KaZi3b++r_oulPMEU$D^~`;{L=< zuol_Ex(BC8w~7B$)JkB^h5q3Wn1!PK>U%1P_)}eAEuXr70#jb-7O8+0(XZ1rK?%O3 zuQ(qCcGJD_Og+F4baM(&??xY_Dp4J&O19JN3HF8hN4}$UpJq_nS8$k|qiszujJ7il zXTe0%s&w!tYCn|*C6g5QFTlwD`j7CmP$6@Ooy8^{t_SO&eZniQzctM_EO{{`YFlM1KH4rF9K1M>TFwNY>_V$@#< z{12sy3&e+=-(oHj9ohcbg}@VanyV7{A+hWNsfxJLkU%VX7Y;{M34#4B)i!5kRP-06 zKQnoxQ|@`7qL`u1=BD6cD20}tQ$Pjpb!t4_j=!q(p?1KkcximTz7)1{wMR^bo%BlX zSW?wsl1>qP_!*<;b`jdbWs%H;DQb|89G<{Wb=NjmnJdi#o`vwER!ZInTiQYTFFXxP zlAZ21l&uby3!97a_Tq@_)gi{U7;Obta~W$}#I7;!(Rr-R=8bwCrF)UVXnrMFd(D56 zULZ?$9&Vh0!*tJ0I8o1myRpAaLT}0c%O$WKO-4WJtmRsIue+Z7jml|E1UhAyB0SG#UT2Y_XyU1|WO0FJrY<~yy zAs@$!Ojnb&hew*4@5QThFV`xPVQ=c^lUl;v_zUOhSI8I0 z6DHZI`(G)Ikr}r|kdHk`bvD~s%juN;i-W1KzH}R9m7=%>VOze09LqI9X{I;+iD&iq zaE5q>swlnM7U7#TgNsB2(=h5XruYmMr1UwSV!f@~^}u`6nO*gS?4guxR- z4t@ejTB^21?x^o#S0uMqM#r`R6XZX!O+9^;k^CoEINPBNK&i1&tXsga2Z2m$oK*t$ zQa9*l#4s^a22?&XKm9W-k_(t3_B!;=NhRL|if9M5LxNy0g6-wPWHpmh!`^zzKDm9C zpJ{04A*Y~3XFlt9wqTYA 
zJh*^n`|YWgN*((KUWYmecziTg12%G30hhoBxtZDpHqo1ShI1_Z6mB5?#SEqs_(_)A zd5FWL%fL=F@CRO8oGpII8V$<(mc#cU5qU)tkrDX0n&}!__o{j%JXC1kU;$QE{}l5P z`C*eu9v*V@bX+o;22lAxA5cc~iL0Ooyfr6MMdgNUA@5#oJ~dW*2^)BNC)#8QtuQr) z?=LF4(#X}1dmyRl3t)c`3PDu)d?rK$ky^a4%AKR&ripnGgov%BFIXP?^3bE z8K3~v;%8yS`~=b4%;pwRul=1t0n5$RV~U8w)qnYo!W?meFP6HNev9}VX+=fvsV;Y* zRWWa*y!(^Z1Q}OS6SmV8s1tN8ZWDio7#`J@dF%LUwnF*+k=k3Kk#Q(ET-d6#fjwa? zw*rjUhcXrUMBi^qjMShO`YMZ4%}-ikeWgzdmW&LQWn~zHsb94N{`oM68@cH_a&Z*m zHxjd5Q^_l!d0bE%9y?6^Y?LK-D3!1ag+Aa7!F}pBPf4pmE&Q<>qkk~l@UPW8!a(&J5tXz| z9;Vx(UuJ{nz)E+$@OSv1wgRz?tSmx`6gXzG1X8Y9sIGK6uUO#&yr#%dh3mVz_--veJIQfe)1)lZn@P3G%LzkszAY${saxzzo>4vomoaZ+3dTqZlf}g8B zqHAj{FfTpDaYX*@9xqQ0bP6}uHriX`D<>o?QS=sFQzxiB%?8vU_bZgSxa&TrRhQO? z%W|et4%|j3oWg_zgeA8K~*m z3TRp%qkQ&Avzc2X#}SB_o4%BrjyPkn@C4b88ZSGo!{{8hP&yL)ZS-`vl`ETtrD>oQ zcPku}=Nhq*CT29fUS6c+*?SOd)`He4^$73>p2?|{PCYTUi0zUyq_tG1#Mk0!EkCKS zDRdNb7gp*`n#4691a1%IB zpVe2XYmL5S1H_wcMn!Q&tmSMKOi{g&F|a{s2ekyuGM^`DIZZJ)VyAU8AGkIJ7qYD| zCgU9RDT=SHP$74;;b5mbZ!mOZzx+_Y=1(+#(Uo9jzMNXfL|qUpIs+GQaNTvl44`gt zC0hi~aHI$*WKy66{347rPI`X{Td+9k5?#YgQ`+%U;vrX)ssY#NxnN<|QK>sG!+7`_ zJ5Q8LeET2oIF5a$){&?mGFMs}8!^TU1NiBadKaax)LKlPXiMvpR_T zU~O#^x^1;V@67v5R>BU=s(U9=%@#ok@1{_Cu(qKSQ}xb(5+^3);CXxn!Bfx)PBh0d zYt4rAN!^rDwnrBD#!Maht)9=!&^q|0>rai-oxw7kKnSod2)VQtA7n~n@SOf zHCm_=V(g==0ZNe(VJ%JxN9D)NT6=O1DZH_sIUL?1#Q*5(+#)a!rJ81n1?=DMGI)Oe z7+7SLcfHeJ(x<~O(lhxG@RKiLqxd1_gHR3Nte=PZl-@)e3|AdsK}F9E^;c%ZtV|W) z>gvg$l)D1=h?r*M_Fn9N<`j4l<%wG6?gFcPKg2I!GO}5eHRh??%>TiXR>X6P?JKnw zi;(?I$r=vh^>(3ha^S#Q<(b%u`DSJFL*XW^B6ZXBGOgIW!UWi?e!7yPxkPBrjc+3y zz`fjFb5M3)@Sdto6crm91?l$Lvk&dFie;WnEbpGndeZkQ?Ugf>X;o1BGDG>k$c9M} zOZ<;Ri?lL1uf+y1g>!&B*m6WC37PTo2C!9cAWp{etK0ODvkl#!7A!KW7KmeJXT%-M z&ol(RND7W;M}YEhpqS#H7gr$b9!HxiT&3+TGVpy;A5tHk)94tb7)nZTM0~n-8Q;!Ti=CmKLoAZ< zxsSce;=56&_%B*#a1w5?6E%MS7|_x|s&CvLy9f6Hv4>wqs?xu-rNRExAm%naYd6QM zgQc*pl7Zc#DhRoVA5kySrdmh(8Bd6R^*)*9xas6U=SkQAduXl3z7U8;LVHcW_K;~C zh>onLY6#KFMj#}q$^-qJFCc2d2|X1Kpri3JETm3?U&1YE;r^LYU#5&N-9(ZTp^AQj 
z2vB*g;{00Wy)m8+)4##d!)fdf^DpNyvo_}oH$)kz+s;eDN9e8-$a1pt;|c1m-cP<{ z4#QMtL0TeHgKO@IxBkXDo0rYQ(3aX`WqiGLx68$^Le%reV0O5erxD1{JQPYKu5uRV z`tT0BwlYY)l-HTo$jLaQa6v1qF5NIF;H-r;3DmHCV&0fbVcUfN#`+Qe6BE^Dfnna7U&ZJ%t_h zv%w@~Ip7%EJ($avwr0ZtyA!k-fITe0Q{ z`-69OxI1L|(um0OQ?Z!a^cwwCW=bYWh_Gm6D$y640a-6@U51wx18tiwI|=wr-xs&X ztd;D*#xl2*>f#8Pi4*P@EvftGF#Mx>$+zz3!=1wC>4IV>Bt zqIG#ye0dfaVq}rap;czS@Pq6z;E%o&akQJO3%Twzll9D1H1a%@mig6A<70DP^93tw zdSmlexRq5AC&7}W_S~-q7hTm6SJw%STPD?3`xnN!M&o^mS>#LCQ<#RxDPPoD+CC#o zKV)6weT0g(Q2pV=NEL2Gq!HFjD(3tzXEbQwZV0X>G?$lW&)8EF-%BOnrLZ_Vqd|@q zo$M07(V{ScuY~AuSCs2Yf;EsI;#-1GB8AWa>y7nSU<7&6JgUty-V3QHtMtlKm`Qau zfyby6?LMbe>A{UB3}A|R9}h523n8IE#5wnt$V5&>(w>-op4>#1^|2%FVG+H zR-qx{;RbWjBYPjdO!q_EkEM~1RCL={kRN0qJmATTA4E^Ed+3&`xThKUeFf5! z<*{il^G;-TxLYo#Cp#L+|1|M0V@3Cy?HL{)U37W_;NQVl-n}e1#5gL}^FjNSo+lU0%V2)FWv~$R3jM4dOgx;#J-6D^!yMzl9ZbeA z;xi5pr$l&DsE#aVzmT)%u8Am(#7*rc zV5VnuuujrbgbUK`-=MDd z%qE`@rd|Mhsy5_Ks&SDtSl<|<|81OR3T}U`jbj?YS^{GpH*XUqgHCt_81fwZ63bFs z<+XNifUPk@$y= z5()&)>Ho$4tM%hfXc1yCeIfEZo5a5mukp9WZS|dcF2I;^_^PZU_6Po=QJG$CboBjV zrVDOkHvKMV4_*P;S5xS@;sR}^D@~M@RrEaY(cDdKlkb`5(|0FZiayL-+!W_-?V zUqgy?4_7;)fA_S+GI-n2bEPZU!WimqV&C$jXO&w5-s6*`XzoYM$;esiDz#7?PjnY< zna{YUdN3F;Y%?4aGXIP2%FUpsF_oPIOadfxA#y%2xn25La9VH3?bNDhTIeGATpyfO z1*<_VL6gL1Oq`yVyXmV%7c~P?IpvG6limgj9L(i9`d?Olh#22>#r>I1L`HiQ@c4P~ zK%2x>1j!o3tUpwOdu9)zk|RZoEarW{Av`e7nUmRvW@DWJ@mw3GB0C_d2jWQ76h3=p5#KgK4hSZFp#_Fce!hzS%HxGbGW>H)7{8-wkc zj&vh_m^|7w$D7Sxfiw6zR#-hNv9_Dp#2ljU&M6bfBu_ZJNPkXDgT0AH-X*>fY0Hh) zp-yZyc*TB7*M<3|g2FWvCOJmEjBUzng(O{$njf(viJm;pPJ9afLhH*f^Zsg32G;bQ z*BjE~x$ExI)+=TWzmB%_dGyiHJ!~%5f@)(P_8;{LHOes91F+~-x`#~rlJCrC2YeSQcjnu*j9M=o0W$>j$7D{T@h0Cehj!r}= z#6p-5X0-}r5OmwU&Raw0P~PDrN2nOkbXVd7?iAIVYlR#c88tx65vpKpBjyIokUL5x z!Ca$OmfyT?o>V{bZ}GLd&oF}Z*^S8G?$E1%wup6;ind0fQ2!gaz_|)fYw&)ZIE^GxfSm!jGSD(hK;0*7oq16$#9p$&v}fzLPMc>!cWGl`Se0&Z%_eD zC45u>H3RYyueoi)Y_$uSY9ODlmMI(otyDr2 zpt$)*JAZfF zHM&~8#4(=Z%nj~!@(^JQdELIFpHu2n4`E5)L@^J$m2TUhAYUNqI(JH%!GC5~IPZl5 
zXeQOzDr(Hk?Ga8?^U*EfAGV9xmOsap7XD@qQ&rH{*CTq9J|KG@_^NcHX0V6Z zg^0MYn4T>v1T+r@Yodh0jThQ@yEf6OL8ghKj? zoEHkOfX9tZKZAh-4FPfg|xxl5X#Xug_k?gj|MtT`TBOjGs;4fr!4#>Uu zFW6xwotTl^-s8ZA@FhaOeXH0DLPzqJBgl;ePv9E0ygZWnLYDv;2lMh_Qx^?F!D z_!HTuj??<&3f6^|SV4S{@!)HBA)X-<^omhf?syzeiBbOpkObm>Qot|{XOPGb}P1G8c z9eThX;oi~9h&+MI#&{;3@=2{-70fP?2iR!6xH6g4wM9NtO6RJYS@c)7I+m$#qIW{z zUG3kD=zfdTTwfGCq*GWO;eb{}9n2+feJhnVu1c?{{lfA37qq_AZE2<$nQC z6gAc^ZthXbSfvp&B^n;2kBgV_rTRIlde%76!LP5jz(0>FwRa^ohD^3PaX+n}8M7mY z11GUpdOG%!-{u_Z9!y_hi>q@?UL5aD$ehFFF?s)9tCzf#I|i$GCi|}IAl^moCJ!Le z?<$q0{&K%j&S5KA0S>h$@Rwnj$b(Fv6GaU#q#iPcdL!x8GPitGfMd4;6A72(ei+msh$>n{d?XvWo zN~#)Mr*Bl2f>m6qyi4haG7+n@Z7@*UfmfuqXhO6|&|j7BXVM@qo4>&~tpxiS_Tf&k zlTg3(9n%@@jC!h7;0uc*JI43d*O9G=EaMq*!TQ^W+?51wUXfPP6M0jrp}C?KgOz+1 z?CD$Wo&$rJX||H~;g{%Zu90N2;mLAQDc~629rOY>4L?@{t7X>0+oP6q5w@iIC+nAb z6t!M{Qz`mDZh`kr);8ZeP)r<2Y-cl-$=Wb2;A@Z0B)h%+0Uz-FKUG2Yf_8Q)@)a$@ z7pW=cEcOrIMd_%zm<8fE8)qxBQDi^mBVFEF!4JUZY8{xrv<=v0_H6n>qMmgrZm*ii zHdRl-o@}A?dt^WxkDkhGu8Vxik)5+i>*ifSHVls1a3nik_^mVy(+7K4vEpu!ge+IT zFfJ<=1Sp4Gj~lHG6i>Kx_+QRseluHu8(>{eKWvoL#+ct=HS?Rf+k3ZxVRQ6u>j&#c zgq<)q%~3)-2G8IR$P@m1vA=;N-3GT@|3vkoPy4%>KD{E`&JU_zsNvi|YV0VYx7Nbl zMcV|L>lMW3bn`$rdpjK!9iPy|wLn9*YW|3M4v~V>tikdyEuT3fm>TLQmDWOVN;FGv zfVS8bUx@b1Rr$}s>B%!`%e_1Mt#B@>xPIjDJaZ=bluK2PhpFsYtfZY|QT3>Jcy zZLhV;J&GPG>4Edw0&_C#VkUTd8C76Y=@;q|vO-wLO*CPaDf0v~Jf#BmpGqi0DYvwh z*~C3$wt~*iRgpVv6DrMLF?dVd1zrW%NC7quMqrYEBABS$;#|}b;vBiwJkGqpdT~RP z_ew6gE`FN6SRE%Vg;8(`zm{0Qt%tM*xZ>Mqh()>6CQY3TzvI^sgGtj$ro!NM{L$c7 zB_TngMxy=%W+$pk;I#UQ#@O8DR6$b-{Li_}ES21u9?H*i&6cNPg_6!KKBw;FdKnGS z+aeFy4BP7|gnIrWon>v{FEK;ePKXqKBdH!1kJ|d*u@PveQzbCo+D2@YQ@CGLO%i9C zGgY{sN_Tq_%2sfmMy%I$ip)%l1|}ZmnapLWO#)X;3yg%z+zz@-@RmMOeW-by6^T^& zQD6msAr1ZN%x+piYG`^iY$g;1;J{X58@rF21pE=Fup}eJ^8~#E{pK3H6xhnQ2L(l4 zx+Vg+rSC+h5gp*36VIF%1396_LB(hoplj zovE;&gBOgvS<`aY^R?k2h)m~F9FxqxQPb7Ql+UqUz8tIssxy1y#)@u?2?%Bbqp`U{ zeCplcSmOJ{wDsl*Z}2(n>-37)zEV%+s&A8L-@ZSp7cUvw%)FzkpzrS(M`Nq7+oS%W zx^Tm?#w(L`Jy;=hH2LGHcy)Hxil6yo?UaIV?|avJjfN#Vp8@um6`ohSOh 
zl}d504*eRG${4B5H=I^ExYnqKav*J(7ui4d&b6jH^YaW*+4!VYSDeXjy!m1`Q%~LJ zv1tC5&``W-z6Ia;RQOh4Il?8L?cHD}d2`$y#0~J4Sb_VbZO5*|QfQ0iqt-`ux!$mVIz)X4 zbnvaH^*XR1=Ye%eA0+*gS(`KyC;J*2-F-*mZtJnsKW6ef~-|;syN{Ul;gE*<$#04aeE)#j2<4LN=v`)BAO$T?q_t{n0Xrpp4%_^u4 zCW6pIorULvh3zK(+K9EKD7ngLtW&tIyjE+ckM{HUMby8!%~k`Al;45*n9q30EaqU2 zNA2!AEe*q&$UT%?xT);|bHyTbP^+rHlfP1%KwbBJwMJAx{Nb91rBH`mF13e@ra|T^ z_A?$qw~SZnJFB<56J^1NI_q9bPBNb(bCn} zoW1B|egU7PR|{=Y@*`KnKcH#oPRtkSVfCK?VzY=Cc-Ir{F#=tzG44LTH+os1F}N8m zR+9TWB;h5Zc~?>x)o%|-N@rFN|J=nuPOke!R9` zbN_j+Ae)aLLQkXCXgh`9;4Lg|%>dQ7>26$GgO?Q37^j}8tu%gPUkp`kj*SaDPY-3X`WSlzHu?J~^Ym7BU-X>@3RLEz7D7knq|r|8h74Q#;YRDM z$?kru?9(awfVs(CBzPZX0Y9-cN5)l2JfRlCM(^h%{rP+mAu`os$iDbNE;s8HP8vt8 zW^kUpz6Rj!d#;tTQLbkt@ah!m3<$iA>1plAgMMuCPk~AlUC_LqbIc^zL^=4 zcDl@*UsR6zC2-3aYFgp^R@-W~#A{3=1*PxgvJ9uS7Eh+N%SS2!5*b{yU=4K7|f8(BVmDp2C2NZAk>V3wf*k_ck z+!Xexj)-_%l=gxd47(ZyDAHQLs9g43y{c8t`v;$a*yXjoqqOV6=j>*fs?TRXs&9#k z$$f<_#wxffy9?N*Un179HM|31IbV6ZAQm$F>vO0_W+UpReF@H_{-<|hdJ~V>wW3R0 z$M>T~vi;G$pdtIkXvkF12dKxvNPcw(GYKcE`urcYs(K|Xph z@r?PKPF8k`@6vYFE#qCv%~OgYGH<+@1-Hn}#cK}QRoX~0_v`IZwwFow>xtPVXfjq# z`)RZY?~5tu`)KqLy?U95tTW7XYCRJ`e`8Gv-LAyJiE(Usu$Zc1uQInXEj$lI-K@_@ z&TMW1wo`JNJ%ocs7Pu-m6o>KO+3{3st6^X#tQQg@6A7BF$2HV#9D!Srf$g8rZfZFG zHe3i>VqK}$-q+V%pPQ%G0KJtCve&L^Q`9=|7p4q<#(1nOmVT--=*KUU`lw~q(pC{< znyuzq17C2pj9JLz_$OytpdF|bX$xbM#(7MtKUW0I=96JFwurBip5VI2)C>K9dzi)g z#H>56J~2Vl4c}3Z7$-X|u+}O^RPauqKX7w2H*+kq)V;t}4Se#nu%LFx7?@=0SCDhM z95smOVE^NY5|X5$<|N&#8>s(OUtOuffwoX^*v&H_A$N|)(s5x zmeDp?e>)4QGri4o3hD>wD)x3`o=1pzkI$#_gImOLC>!%K+Ok!p+uX(APAji@)i;UV zzzN@%-ZMZq|b!6T)R-RP<14#&+rz851!x->mw*lyy@G+J@wGO z?SaAWVF$m)KC@%_9o|;bEBHsY?TgZCEC+vba3`IgSwls63Pxsud0-S(R@-Sa-iwsu zV5Yb+t6p70xtEy7SHxE;K+x~GJ-PekQ|c)2j6GYqpfu3_B0rQn5HB2gbGz!VsJ`64 zh{PhAao$hLDQv7-m3H__25MnP;bho7P*9yM|Kw+Zs5)Z+#kIkEFqSsZGnBMw)HC4U zW?W}3hs$Rkmn$)Y1Sda@7zR6H@8~#R|G2+E`A{>Ibn`^2P`{XBL>y>k#~3>lm;Oo2 z2P%g*!qe*0-9x}*V_VL@;DX8cp5SvL{i5=MMr8RQ%k5{H@~VOw=vacwi{+Eg3i;hJ 
z=9=V6NyA`urY_jQmSmb+*U-k~JUv?Zj|XId*x%|a=lSqz@R2NqO}3B9{fJ$#RK4fQ z_rMlDF8iZ;H8LDjbX8*B!#Udi@I!h|s0kj%hr@lwE+WpF3D+?>;P}?p>{oaru$1Y| zO-|ckE;VGfdJ+ewYln5z&LLI^@n8&}LhQ%y!Onsg{fcDQ2<9D{zuy&Vux-+Z`xD7d z+BLhn{z*vG8p2lI6vNH8t_42g2K_g7Is8y>VF#sG z`XX~Y+~HUumeYFjpW{MiY5gN$h0XgrB+tlb5H;6Xl~ zn}n56aH}x&m3{>0>Yw0MybC@okRSr8JkvRfBJ$dC*7>!0*e)+@m&f6+^N zhIsR$bLA46bfn|Qsa;fOIK+BOyrp`v@3Gz1Kk-}qN5j)I3W9CiX?-XUQ9ElK_K=Gq zhs7+xE8=rPmDK6@4m!;3CRB7bSnE&L$3*@$%F3cVLHjp5#Z`y78u&<0vw;9%IQyAZ z&|7eKTTjWFYs|l`Fg-t`lTZrSS_)N>-HVH$p48IHBlV#%3(Nkd)n&1 z%$>$G}jdN z7Ym`A?{YNB{LWTp^E=~|J*q7=W)=`L;ZRdiZ=0J#>DW-Nr&@q7;9JhTk^j>g!LD$x zewi3!e4#Hg4Lqk00_6|5Y4zkDCa#ZLq?SVt#_|5%V4gCIsSKT-3g&#eleU!`;F-dW zFh|jylfB^=<~gv;STFCedh;ci<3ua6kq?BvT2Z`!`YPr5i_BoLv-?r&sg#Ub&9kz$Sp^TSz-YPzogCiD7B%z1sbFeks(xS1wZ0ltX@>^M!DAmIOoa zC}y5|2UHh(a;50n$jatrO86wSC*#3=aza`Mt~t0K{AwHkQ>1@^p_Y(XrewIFH3&_} z(#2>vnVYX1=2p{D{7Lv9RvL2P2UZ^S!y238TqDDEH#<$#ivxWe$IA8OW>6jAq?~c^ z8e+EY%Bjv>CKe^vmY;bWf#S?<=0otTXQR6bxmfLimki%kKdLV>Y8ZH=rI00E!Rm>x zm@QgFIpBMN`ga|(iqrLlpN>k#&j1Dc>7vw+-@xtWBUvG{u6roeiyDbHQ#Y#}xmoZn zs28e49YA+=zw1}9Z%iF(t?~$U`IBAOayEcmpBo;7b9}GJ@@XweNj-?@QIE6N2gmSJ zWQu*vUQ+fFsTpnR|Bw4-yh@vu-H9uhRUUrldWtd9K{&%~FTJq`d;f}bH6{gr>WApv zo-N{h`Yu<)NE7FQT{$Fqz*i<|H=U?2RPTv}gcm5;wu$(Q_yT)z=gG}{k*r^GEq}3q zU;TZWfMhJ;p&7@g0xgQ+aM7FW~#)I*84lyZPq{}z4f5%H3(QIF(n z_#%`@-AMf)*ElO$a}`ScOm*W9;Tx2SIcGAX=vIM9q^Z(~M-z1OyfyjWZy z{CnB%=1os;ke9A&J;S=w$3ff3OXtPxnaL-Vrh!I8-C$}07b>5e!%W0DW))w|Vbfco z5xPLFh#XYr=9K1M`JT(O5!GuQQ`Waz>*^`5zouK`|B0)$-f$E9S^q>;(nVxvaR@CV zZli>@QK*J?0l92Z?0;4y@JC*%zoiGDUWVwt3dfrdLlsqvUrMc2pET&o+ysa8fqp*+ zlrx|imt)?iOzjMnh_^(h#q0V^bb{(cl(qi{|6!B0DcT)t1 z@9@QiQqno)tJKdX*cZ%NxHKE3p+G3AnRt$%xR0o}Hcsg88-jYGPFNqTLH^-kwaQQx z=_qcL6y_Y_K^*r@%`R_@H#X>ftk2qEaEtjCI_Rsya?F5;hx-pc@>cf|!dAPJPHS1P z9&wB*N?%h)NyB!IQCpF-(fRG3H?N$-rl_oEtT1+8Sek6SLU-)%+6!s5DYF&GSKL{x zpH?GqUwEp%lY7Hy$}d+_?v4EgY_Su_W}r@vhb$64K{g8U$aZuEFD>=3FN!Arh^`U% zWjzzp*e?1#Uw80Lz33ZYrYi^Vhtz7YNlg*j`;XAy;Z`G&UlBgeRwMo~?s+HiY4kul 
zKebrPZ%+bqVBH8J90#AH!p=3A%5G=>2^B$q)3d-Bm=4D#)boDS&_sEe4SYvEp0YsA-?N`=JJ?)Tp9#I3~Y zd>L7hIy(-!H~F?>GisNThdL>86x%~=*PsEofyn1=>i40OGV*6nqh_j;%}KDJZ%eR` z`GXw4|D64aD`jpYzEg*l>$FANtb@5BeZUy+etNt5)EEN?s|m1ZurXNf@`JA21^!B{ z3-Gxz?g$q>TCL}CvS}(|Ugi)z)XcOdXmf}(Y6&+9KP<+hJ!!OAjO%WOm?7cw+Cq6m zunSw49UfRlr&=wNifN1Fi^#xpL%L2T?_Gr`*z5IK#6y|_5o(9jk?hVU;(f%+jOi-| zviW=3HvP9yooOlL`bv?hnUc^$o2m^4Uui`dK%R}MU-u`XSUyD!jtq0FHCWp(F7}Gp zCgLdl)@bCdAv?VXwV8TZ$Cr8b=P9H=&gRkUWEle$y4Rps@tK82y=oYk{#y)0-IVr8YdMNAxUST)kroWS& zV2xn6P!&-Gt=Zp!r`An7H*N^u!`?_8^Jjq!_bnsW@OaN!&nSXesZ2y%ij8m%Xqw%g zK5J}t*8-Q-4SWy01V7Dmi5*mueT8iWC8?32jZ(?v=46@Nk^C}iHnUa3$$G+cZC}#U z#M6M)xVYZ(a;$VNhBz@Ac9I`tc3~@uvsjys6~Bdx!c+sdr{U@9^oF@y7-~{O>NB}Y zYrw3q%S%HdDX2F-T3g_*3Z7vluSs_!%ps8OELzWTgJO+y)um&tii2O!7Z$ZChjYfWHf z)bp9a=V$6MMTPa&6Y3oQUTcK77EKYUtdsgGr=J(1??M5qCD|RHi`1w}Y7n zUVAR6{bC!qOTxHN4eNs0NAtVR>v!m*Y!kHe>I!4PO|%Q`WX{Rh&by>D%t5<}zJRL6 zw7?eA?@SS9qrB5KunC^f@&uZzHOztde}|IQOsb*2IK7Oi)2Fn)^djXpHIQ+E6Imf| zmh7dwiD_sz*abTgo})gqK5B#2HDYRHmQe&$WgP6710|AAa=GTM-~ehE^;P@|wqY6S zI<}1Cqk37rr9CH(2;-O+LLt}(Bnt!70_JMldnMlwZ`mOcc;Po!}xR6&D)#Xaj;BpX(w1bqO9~>*nnZZDLNe8B#=O1i7@nTAD3ju^TOISYOHs=5n>CV~ChPu!0c3vUs3 zljE(aOtzbaI^kuCnrPp~AtxxaKk+5IfG8wrmM?o7A0ghk7s7eko1_}*R_+cm>F!0R zsnsYAn?b(V{Q=UjOU5gDn=9EQ_|irx9kL`Ci;U&$^*KcK5MpqOuZ{y=}NrPjvu zsmML7U4=(#UXg5SB^JxZ!h7aU_g=O#d0g*G*f*cjtYy1H@LB&~D>%|qr zE2>xUX8wh)aF_=F&_dm!ST%Jxs8xQQZwUB# zKgJMhzx3M|5g(DonPYpdxD(J6ut#W6<{AUfedni5g&AX>rUFzc<(52(+KK;xrM-3H z&*MFVZIpRQt?=}~Rfdul3YOfU`)# zH=oGIbq2SNBRuv7o6ywVN(K-OAcHLH@jGt`|8nqTpd z(iFcGo?^Nk?TkjdXHlEUp3+o^TX9-rVU3y#_ffZXJ~!UT&nzWd>t8%7F_vRJ{JyjN zE;>n{5bimWHU0n)ErwERnrT)eJO${Svx({=2p+J}Q{bVafZ0?bA-+M+J3+7XO z)DN(^R!T2|$bDx$oP6k`M4kHQc&QJR(0K>6WY2Y%zH;Hl_FGRyWP#}_=H^<-bF zwQ(O;K8eIq#ERN3t})>SQ!(V)ruF;Ze}iT$4S~Xz%OgZfKQ4j_1BcjfP7Yw(y*58T?4y z=N}-BMHedpyo@#NV`UHM?s&>iwq~+*LT$OPN#(HXN^P+c;nxSzUBEu!Y*ro@Ve&dgpTeqY zH(lr1Q~D6y&PjSHah&IefHPx3Z+-+*Jgym{$Z*;=ece%%BNQXO!N4DCBJl`Yl6k@W zq8!S);P#mFu^UFL-UueCV-cg^s$+^Pn;D1t-*?H+Nhy|JStCVr=cs7zFRMKjVsFaD 
zz+7iiFos*L#uN??MX4QOA0w5kMHkn7LIU>#WGN0K(f!_jMa^auD- zl<{Xt6G%PuIIEOWS=>d|*UE9#lS>&1Cg4ty{{)w#M$H*LBEc8!q24>cYbjKFqQ5XG zFM_PYUhwBo#fXdQJF`5nAv1k!;HzSjZiAXlPr{!0=V5#LGk?R>P+KL$`4ZB(w_GpO zBrA@!&7G=l&1b}I38nbz;uF{gZVH~}c;h;=fa<7CR+}kx=pLSOTvf164#(zFgS;Fy z-N-yNT3bTyuxD~tq<;8%!@Ef}N!4?%xXh zT3h!|*OAQE#~&$|K?%afz9eI8dqMs?9RyAfX8|PnCde z6~Z-xsYF`d9cm1_RhP-LG0oJb`tN8X9VG@}OVW092drQYP`i5eOW*t~c%%(<-BHfT z7s$8R8McYu3#%pdrso8Z_rP!wpX`+hXYe|C(Ok}Tb{~+Oo?3Pl@$S!p>dK*v<@!Nb z-CGx;SuDSldSz8ct&=Fv6}U!Er~G7j@Z{)zy)jq}x+*QT(n(X*!^pEW7T;uMWHwR% z2<$4ntK6fIqb2YfP<95*alAwY%ATF3ya4m_TfjfG1pvK$`6h~>Js|y53c5T6;Be}6 z?A_2nwHU_+TW1gPt2(DIVmgsOkb5{at9hsfeI+m+c92%Z$8jH(T3P$as;=@pMgeu8 zF+cQEnamGF()S3SFpbH&_oe>m%8d<^UCl@sdu=#j-7!Ha{C17?=T{FsTwqd3INzF8z>Q5>x zR403u(JD8}aW!dNxBtCk1oj@gaI|8O=T}j82top$L~gUF?{fRJ*Zjf8fTxwJe&Es+~KkQXdg9@oY_p>uYujB;f zJ@;?IT=qD7_+TaFt#Auv*=*7dX$HzX-p)TA`cGx4r}(aB<@6fK@3mCqjB(mUwmX*2 zAHkE9OEL@lC?5GNoL9J>ix|^UYOD_AI4s7GLtgHu%sF{38|`Qh%F#hK8#D{% zz-`(NV?4l-yCNcIo7s}-##TqaOi!Yp`oq~toQ>J9R7Ni68t#*z9GEyeL^k-e>M&YI4F_r$o-6M1F;kVEDr_G8W9C|FGIt!;5FG5z{j z^M!Vr$<~|V*W(s@7JK`{)4|Tx5NZrnS>J~E#D|zgdFKM{jXRBc_?B=bj63iM_Rc|P zy5Wn!XLAMC!1-s+S=2epf&=v}c~!v{bUQvI{~B_r@qWmzq=D~^{y@2yzkzv`xK=L+ z8e&NsE6w71#P{@fKT<_FqQsfgw7qm$<}(~3lp*S~pKN&0pPB&lqHKkk#l@2 znc6v>WZAqWKi3C4H;XX^XHmN436;*A0tbvgl$SY&j5u&5uN$!kMyWqwytYvG3FFP9 zW@CH?9VI-<$@CnO=9=e}LF80$$R3-o5>9KQ%_OSke-B>eyD2<3S6D`tfYZ5H{V;kf zt)^4782p?<9+_n<4mK-V!B!W#YMj2oF(W%c`KaZY9h6aw)BQqa)n=MQt;E%&=E>95 z_Dm1|By|BD!iI7Oh?~}KDyUQ<+fW6;Ed|$&a$thHguEB!!fq3#J)KzODs!P8SXV8}pr@<_0 znX(z4kr&au@z~~xp=az2^9;R9oob};lgt%tn&4K~8YS>u`Ku5gWnS<+ClFzU(e`J~ z<GdturUK`TR_Qoad%-Ys~Yc%-f5P875xShft@ zMT&eQBSGm6`|;DT6pa&@;LnV3xwSTeOK=9cdraTJXKIW*mh9;o!T-wdOE|;Y!9}Hq z`j;Bz8)5ApeEGBw?~DDchp)o6jepmR8hxuC4aRM^h z8$fE}AMR{)VPc%}c>!Et`KVX=G`&IgY^aI|)n-ITK zCG%$nJb@}o?xDZ&(ODmjO^LUu3kkc}-ppM&7S`~r<9m@DHr9BdK4%=zY5a2d6FxE6 zhv4bI_>J7DD7$$=c!z6f*He$nW{2w+g#Lb;Lc3DgfIdS>7UnQHWs 
zyv%M=ZJyTFXd@Q9(h~3p?IvH{vAkf1@>igV^+v4c8^g7W511{jaEqIUY=***5}%nYErcR^FTXhBrKuB5Wp6R0rQ)CsM3_Wg*`NE_&X(D`BkiJY!n!EgY$i$ zz@H_)7QWktfU>S*glR1>*NFpcbKoD&@~}?UR@54NV~?WU`fBcPdM-avY8mq*(Q-|I zx#9-%8=RpZSGG5~tB`rG<23GL=nQ?CFKrgd*L5Cm%e4wcfHe88^J3vpZjo~zH_V`n z+K#*V%N=4Af1GG!{N9P6q;(ZSGiTz7xnq0^EO+AOGm>21( zbO$|PhN900j*wT?5}+Kf$vJYF!YHekP{X$Y&InXa=t}{srEuRI7wZ#RGCj0(Y#2Cc zjivv@7nA*9ZDBQ6#(0wym)FOW<4!R0wQ5wn{TFsAmI<~}kvxv&605YJktF1LhM}Bq zYj%g$(Ge{?Gpfhc%mQ3h=>){3w-DGr@dshOn2w2etQatZk(oGc8Qq_`O|^>{f=K+0 zKz;nHxd(YTP8@p#xUU$h&{OroQ;`e-ybNP5{_j%f~3Sm4K{DZ+bC8`}$CkE4)wxav;klGH)4B0Wo=57+5E)nWMMn8Vb8&TGJf9_t*1PZ#&o|7r8|&h%q# z5c31&rUu0Pk?^04>z@!ItUA9*8KkWsJJ6e?Uf2m_ziP=;bDcnbbSF3fTNUN94FcPt zU7H3zyWer0#3=6>X)|4ha%JzOBi%UlH=zZ(B7R1Wk_SF9m2Cm4sXEqnpBTvAO>7`_ zpa|vw_z(4XQ$TXe98eYXqPicwPf=z!I49As%)pzAjk({+B<&p$2gZqgL$`=p+84o2 zjaPeG8h!DIjT+8Aq~8VKF#X{?BOLWJGt6PG*7PIcsPck7ix1$cd2HNS>JYaSG`Byb zmfF7f&cL$R3Z@2k#N1~3ov%Sr{wcP(v4SY?c^#-Jiq8JzKF>mfVJ=cP*q+WcXm-@k}3_bHF}}8DB6*>&liVrupBPaasrD!i>xrCS#u7 z+H(3r^DyO+;E!#DuX1a)IM)!9L^p+FZX0%x8IN4Ros^|~25bi=>CKsv$Wi)UTdtQu zT|^r`So_42cp+cG2nBRRwP{vZ!j(#LiA{;k@-Mbh!WgP1xi2#bZRKM5n{XhVYR$y} z#)%bW8~aud1D|0pW|?UEI#OPHk*jJEf7qsd(yF*c*AbX4exkGZxAG2s7V|BzS9sv7 z%v@D`P0Fb!To&;Q7a~48)2JM}O=z)Ezvy5n%_t9JjaaghZvlMHeqnp7q;L-W1S|F5 zeFIe-j}w+NG2HB|)A$j(4EPc(FdE4<9G9`uDB(Opp9BtYi`Cl3V@I^KHLfFlhaP93 zDK#9(0|EtVQ?7B~K63zWC(__UX{SGh`v_XXpUsvVMSvyNQ~ADA`Ms@! 
zTs7=N=qh&|Jag|rd%OSCNqBmr1RdGwFGGL!Wjr7~TP1F`R4@6qujJrHM7^)&TOS@+;ZrI zf4HhM5wJCM?xA1xi{Y zz-{kQ2-Vtj4ZEEhjFon$`YW65jMW+ee^Y-3+?FFb-5O=2NHxs2*^x>P7AdB%AH5s$ zzgbgaYhn};?kpjm5f`hS_|C)#BQ|b`eTTjj0*^2h~;Ip98xceNGEpNN$}zT1!Rp>d2A zkyFOT{iGJsi?V8q9sK?EHK3k!k3BEFhXZMH zCE`OYG7X}Af$@=PhS&29gW20LypXT<>KhA@u}QP<{hky)k3uU8p>sOH#(oc8aj!sGbn8g_fbF2Zj|0o zKQP*wC)KgK;$7Sgm0yjAYJIFMwt#z*ecOm;_CmLr6=5i^6#kjR)Cp`G9?i|cjh-n<*-@(;$Ol6HyXD zjT?0%$CFtgZxc{5CRmZ0z9Tl6)?W#b~H%Zyl=;ix9;$Nx`w zP7h(3v=Rl!`>J9dtC&&OGr{>&{{|K-PZf`CxG~t4PW^P3Q78T{Wf;33YoB~q=%^NF z?Q*zQLK&=;RR@L~z8LE#)ekwt##p1ZD|9ErjrTBi?>y<5MAFJ>z;Jc~K;5oS8cV77^CTY`%vIzlbnN7f2J&~E0Y>4wA zi}DLY63@o1F$NeRdug(1NTqpmGinIU#DO|5@Ruvw&cph`nouR9HCHb%%-=_9%s&cD zl7D9}n|{+T4}-6A=Tqt2HMlvWC7bG=isr5}JdZ#FWswjAefkc-unarcH3r^PDi}fK zy0~0>7n@7YrcZk6lhc&7!d^JW8VbhyW@?BkYUI*;VYt>UxybCnzQXr`v(`l47paOe zg(=1T>T6(bF}om|NOPl|T2g&OrMnMkL*N|cHUonm$3%A+Tmb*3UGzWh71aKM=k5^L zXq&-o(*F$s*FJK;kPe>fkEIf(0gsb4h_OfhI-W#K%=!+&N7y{>P$y?~0A^QCABlk%dH=zH$8 zjHUq+6oAc!oh+fI7d|_>Os$}eWs|sh+#ZxlUa0?qgS-*WtJEGn69y8tQJ(_KLz|3? zT&C6;dGl6l3zILvl*Gp1;o%f%A4oBhsTd+5S&;tWqLV9mPbC)_kQ*j`IP#a8hOA)? z>E2R{{8FgZempl$EStE+*o=Lky9y^lOSy4}XOa_`IM&S%<=zwB;&u_m_T9sR*kZiU>3sI2^HTu3^t7c*|z8VC{ohh&dXRdP7? zMR^TY8QFA>zk{)jo=)9{N3%{CbHPkcJ!UFP}xxU&W>`%ER zPn{q1aHo`x!cj2SeL|0LXTx-`gk6xcQVQfWP-}QjY7@a_;ifc{TEPc{knadnb6%?) 
z)6>LZdFrNGjX1@x65jzU_(dGebyTjSv-~Hszk4ZDA(+PU$}4d^7_XMr2ykE4B{#m` zal(E~{5|wQyPQ``eZ@CdKBiCB+RNqW=N2*z7@h6&;30Q{^AKgX&0s%+4Z086S(<`u z#MCT^ImVZS)A*-ie=LWLutu;aeYKz~lmreN3zQCY6?T5~dC`ausd-u(E-Bazxf7a# zJmD!G&4)r7XcepoNrKNEL)>w>gj4W_Ydx1D)V7y{4YI3y%9_*E9o$5%6>I_O7A5%J zeeM?sxeI{?+H^a zV=)$CUS^%v9QQ!@fn3bY&w7c?q%*Y13JWhYG$`x#XpaLEnE65@#Bpe?ju+aX40VUl z1Nag)6|#IqtOh!ULXI}hI)NMZj|IJqtfW-huk&gNlwYc?EKonW$58dC*%>`GS==Gs z^;d?Nc3Umw=;^7aqI(Ti4_uZO5tJGSd!hvPTg4t33k)$*7d>^rl$cY(Cq6>j;Q2_* z6sm(ybdr!q58%VK{&;=!IXxfr*HVRv#xh@#7Uug~9c>ghXDREHFW?M!ob3(%F<)}4 zt*x9%FL%d@E#zL&YT!#!3%ZT+)^$L@@MH1cJaZ_FVWC2y`xD#BTxGF(_@_z0rPpUX%Tf`BW~m zUydayt{kKJ!~pk@%B70UNew8sRf zM6b31vDRwl&(LenXJ)4ISnjKK6e~AuuX^(f=|tpz-52t6-Sx(n23yesP14oN_{E4w zt_pzB6HEbZsoHs+sm?;#yfOMBy^rvyd*yJtQoxllosH(C*%oo%T?X()t z48+sjJ-Jb26aA-r22qg4QC-AlzQr|!8}v8M^5cy2Mo#7+Wg_T=KA-98e62_sq}B#M*<b|vGjRadO51_0RqK%zxp_t!wa}S?H+2<>AH-IYMrc#8p(#XwE=G4s3 z>SlO1XREsb+RZ%XRlPhuin^3|oQgroj=`P_bV;^6{zMG34r)n!IkHAhM4)e6jrh`7 zEzfdrAYmds(U|FrmAlF^$t$9H*QjN+$X%mm&^6%ip)heS80YUT4d)}_wZubOMY@eK znQ8@^u_L(!Vt+1@JBbWa-`Pg$S}>V3Q5d9^*)@Bo(}&IjQ`w2qMLjoo-?c0`i>j%d zBKx6}$@btG*{~%u1M^#H1G##>2(~XfJ-FPs0&g*G)Vb=s{0|8a&0%Q9-%ebrF1F8g zP7;B;g*iPT7Nmt@6%r*9dn!AarMgoYq2_RVnezNCqn=P+9m39Kr@)w`q=SGZ58 zB&5c~p(OTfMke#=4q`L&i;FYrAw%s|YotAlatPz;pVTM&UH4W*f;%T=a`(t7{C?~KNXvka*kWFdy(!Kh_Z<9zo`NG;*~VJ_6g!?eBS&&y#i6jS z=cBpMN@lLf=d(Ac$7mIA16XdXa?z;cNmn#+W^yCgL+KT#seE!4zCE$O(bi*2dP?w= zOJ}h*c=0C10@=a=#$snCeK8x$2Ut3^*S^eLCa;Nkqo#A?JqNfLpX%UTU8(Yc6`EJM z;<>3^QR?~^#w|7&v@3ZMSwmb?FwlCb))zVjBF%rLJ-PpYAJka)1+y1+R$50Ez}H$? 
z`kc{%dPX${Z9ufz17e^*7f-!t)B(M>8TkrZid`)arCKm?g`F6It)$gvE`qkoEb_d* z#g%GqSDwN1ws`C-m=Dh3>Ee6Xh0Z&2+lb+gH2X%}_6>2a2_p9*EMc2m@Sgh-J|86* ziju0aF6%#bjc2_xnUVQOsVDiJS*CSyp^S?WDI|06VQXTK?^sfMVWTHDG*F3AXK`)W z$%Ja>UWEXvz@Dw&s<`SEUTF~z6tDe&<8s$U0eg>9N`=~g4 zYjhMdHD}Qt@;M%x-%R-_v`9>qyP%EsWpxwp13$#_^g2&@bE-1X7|883stZj#b+K`z zBk4}j6}FNjqEy6n`fCn+;e12_FNWo-VuKdEeDmk z4=I%yU5sFs@Rz0O&FhB}xf)6*aGJ}JcELB;N}@0Q=I{z(7IByJO@;F3QFI?+h`Qjaje!DXx|Z zz&^6H_%}->OjbJbm$b@8L&URwNWFslvb%xa-~~A|xrnQ8))5NqJ(=UqFGSNOhhZ~- z8Vlf?!M-Cy}Q^kYd)-k_9j)l+l8Sf>HAFgqO*_A zMm)EgMrCh{yc5PKu*+I}7yc~@%g789Frr+6oQ(MxKF9QP6`!>38z=)%AVJ((Y6%;jPX zBQZq;T3@0T){bhxFHu`be2tNXO$eSxsb?YB78n0%RZIiVl=y*sZt zjUt#yL>g&C6=N1VJ83bf%$y`AYr~aTB01*wU=gfsku->06pUu$+%2uuMo)DMY8}@N zc$A5l6JDgEd?mbTM-d&{SjvB4%a|E@Iqrs8#839c$?bsU?1g9xB2U>5i4jN7lijQ& zQtz;N(3sXHkO!=>wTo;fe{H+>xu4JL6v@kpw?(H`%NZ_?_K zd#LOF#%2$)BqJ)_U6!89)TMcy#3ADYO=Db~1I~k!@FVO3`;+{YMkB`nDUmy5r?uja zpueNhOJ5^5fF+H_Oby>MSkn2)`%U_tPnDXwmlNrxjaucLjc5*=waqa}!(DH($|FKCCZMgZS~j0_yn)l%4g4@{F1EN}klL8pN0j%S zXaA)>$ji-#TI*20Zyfo9E;9Q#dManlvY!6XK@V^{iwmH`f_LK?ZIQSA~o5}DB8+l z>T2$>_p=fw=NS@R6Z|b#*Dd^=T#38q^>N|8BE$&ya%be5@;fQXHeQdUhVeDPQ>9X9 zv02uW5yOa8_@2zCPz82oWIIJ>K5KiK9#952+fXUtZK*t@8%lY`Z$diT1NnV#;}7Vw z!dtGpN{U5lxV+pq%jlg`P3ei=9cT40dKTCK>(cq^7kZKBPnEP5*cw=fI;SpD99%jG z!={A>GF1(ZPS4wGm9ln_*Il%{0*moila?r@tbRg67)E4kO{jeEtH9`jnI4#DHqz5< zOC=iaHf#X*jBwlYkLKr#zbFl`jVN1s7*cudj*{#<@ROWQ*LU%G`TAc>ad=PmI$joh zcSMT=LcL&i@>z2=qN?xUki`K|dM}fJ`N|hu6s-8U+9$*#E2o~t`X~gG#_r$_fEkW< z$yenK>O_BkA}g?j8Kx~#ir_h+q!J}L<(AG+sHwI=t0+E_-q2pqGx)L@r1~Z=4Tj~_ zg9Di|j?ON{vyCYYH7mthjrRT<1(BX(&EYD#xB4EDZ|N^Y50oV+V=ToJ@vcm|SjeU* zKk*)zu3+q7ZcJWTvA#K&NtO37^Et~HY!qk4YD?g1R|+EE`NgS`Fsoic9@Q2&eA}6_ z@@rub9bI(Lz2Rtewjb=C|BI(3M^Lr%@2mZl-(2;b;n*~(ntznD5~$&x;IbUmooKm)6L@Ca`$9g!qI)oc_oNdjISb&7Ib|- zJePYtySl%AlS=Ns>_PI(29KqMQ;d+>oWb6fJ#3NX9qqMaVKH^3&#p-wU=VR^V zvS3>aj}UgKL&0!8jVgn^RtlI^+(M^EJDEQX-;IbS#pybbQi_a)Vr#fSuRyz-f0rsI zA#=EAjWHP`rBuVKn5-&D{)OG^#iiLatG8eMCV)#gl#VhwxYh%*Z4Ka(!G#U 
zgvr_~trtBy^g+9dj0hh=^f`&>n#obZhs zYTY*Q#M%mHj3sUhG3vgkUchU+1P_C=#sk}Bqfoo$c!?Orc31|V=2}awGY><#AM+=vU0}BRwDGtwmafU{ zu~fKMiF5p^{Pf+iNx_xu1+FAW_wT|knGeKn$bP(CzYG7f&qm+PJw_d0B)1+-jB{^Ud#Q{&*H_cXe0h9NtXC<{wgU3p4H- z;oMQKKRQY7!&Z<;C++a_KeNZ--$U_SC*qnkG3F^r=a)bks$)nP=K(d0wPH;$NYsq3 z^fvk(A4U!bcgZ7M8CL?moUZJ<%DafFAd`>s)ycW7z4D&RxB=>k2N*(Z5MN;Ai%swg zIV(Xc-e$I!9}@ZSDnE&v0Bf*eh?=sQ>BjX69!G{zCwCK8CDzdjdfKjvO=1_*O{DH% zl+Etmuglg#sU$WW9AQeT@3=VgH{Hk{3V!9Buo$~Q{T+7l0pdv^TCnVbVVNxEXoF|A@Ms@CA@I}V8F-@)s3vFS- ze5;{5Rt>`&7(a>eY7g$c8s>dxU*Z!XdfO^FI3Omf*Y#zHjyf^!kQ-v{_3dC4++|+_ zdP`gNQ%Z61fLC+;>l#cJgLAEyv2y?$s^~pZa8u7$i_$-;iGdBqQ|@gPC#E~n9sP`z z&M|JM@ZL5lX`?bbv$2eep@eW{x7^e?%XC1q`IU5lJfF*%=cz({l<)%7H@aHk#uU#B z$gsaNZZ?w14XslWrFLptV`uJQsfpem+d)}u4*J&s{*)IAU9>AwjrhHw9deER%PcQC zM?8h2t);x%^GK|VNX*N$`T#ORg)0y!XQFN^uMWo(A@^U-imocT9%2TF7w6+G%oq4R zeH~F6*(9@#_A;vuE3BV56s2C%104aA-No}eUz*UE1hcH@gM-z1`~rKC%n7B9zY{*k zzOW_}UUd}SJ+u)-s#Do{HEZ#Q^=;sMVGE`idePU=R= z<;YRP>|2y9`l4IqhI)EX5%R)}>RexSwO0uSWv z7hpK;)-GA+d=+wkNL}F{;2+;WvpZg!>aC0spDDeK9aIooN?%~(eK!mD1dG|ud27%! 
z_`TQx<^)kpdTEA)$<##{D0mxuua=b7lMlqf@Vz`0Hc|HG`IVostGs_Ijr1hg+dB!r zXxd)!mos4Lp7V`N{YtNT-Ex84VD&FthE{XGX& zNqJ>PifJ$c8^LcQiwo|EG-A6^9{bHNQaz)$Qe(^pp%n8cwGQ444QASSzMxy+%)BTj zGIY;aYb2RFm32fna3{N@R?pFj!nCC@@z8hTSyFje1)MQbnMGztr8+eLhG`|OyUKqm zE!UB1tG_d^L4h6J1i89q3L{aUwK|}j%V%DcP8;=n?NnXq%0;T}=*Wcbip$vRNmh!{ zvnY%6QPqfmv4yN<-sUjPndp=wz*3M&?bm*&J&0+P2UbaDxJ$trSzGeE1ctEbMr(In zca^~EyaA!#d{>nNU{CB)rJH!%=oXA-9tU~`C-XU`5@@6WPemr(^Uz2E&5)a`C9|7} z^>i&5q_2;E83^;ILFf)DwWSPxg<8&-g3`H|XFb?tOb9G>9ss*xJxeiG>MuE+vA7Rn zj<7YZBWO9&DuVcVPkDmsK1z* z+CgFp%C$_d8?rCs((IC2~lAIu3@H1Rps zK8 zQ{e5RNqAZ5F5WM;thqNd%A&Q#-p*8A*d#ASxpL5B+@ljcSCQQ&49qI{8iK)M>Sel> zqr2jy8U}mm6R1YwA-#e-4`sWauw4R+i3q8dv=5!5%%+PMjHV8Q#`a8M3g{xPGvA}`QDydH?FydGXtvET z7x3%&B^d28sUAWLEHddQzD+#sErmKo^QkBz73-4Q2$VEBh7yvuQR|ty)CqVf=^B^m zd9P+jO=$*9Vv?m_xS8Bp^+I3~N_uYw52!5ih`~ud=lC4l$W!lzZZ-`kGChs_>;1Ql z6R;1HDCU!k{2jCcVJCHyD&^qC&xcc#Ch}%}T8po~G-PPl2xehBg#D~L>AG}QYz$mh zKXnBD(W*cW<-U`nD3hvU@97xiJR8}<{>8YZ%_Ba`AI(F7W^6aQACpNZyGDXEZHkf3 zW%+sse<6+~rx*587GkZywx*M538LJvz8jHH45JF!yD&%GM!hj|kp*a~=8$aS3q<{R z0zPveT|v4*;U4^~uM8}vc0@$5t~8|objx&4brwHVy@A;6Ny;wgUdk?<;kW0ncT0K; zs-!iLzb)M*pBaCP8&EQRnK8!rsx7nL>D%o%-j~T_DkL%7Vj73|>$UWByQl$|}lD?{m4d zgCMl;1^NVP=q2T)?nH1=AEiDw$BOr;BIy@4h1d+ryYtyD?033J=qx;+)DSo0J8-=f z6K&sK3T4&vd>$?m|H5>ONSbcBm4$pOUSm6IEsQ+QrHKK8UjxbnD-x}+=?N=PcKmsa zU{u8?P-~SRG4~iB^B?)q?ace2#~>1yH!zt#M4e^&`&)1vO8N-KD1AP5O1+Gq)AH1K zSVR0vY$Yex?Sbb>E9ed4{=7R{M&3zt0DKq~rotnt+FUtc3u^Usjz2B*2BWA#?FG5h zpG34buf*>ow!3ecC(uTrf}@h*(OW4GTs?Mlpv$k zZ|d+kQ+%(C5=Mi%YChft@4@r~&HPh?|Co)fUUFAvh?I}`&sod}SJyE+VHw$A&*-`^3Ji7?_hv#HG2V#5{(?)5p59}`PcesI zisv&8^wNxyoJa2l-;8%8EB4dBYJ^eGe@xXde^#}`O~^&}Cp%3aU_UA!0skf2X%iY+ zFJJ$l67yQE>~Vq*p)O22?HQU_GXBPiS!jpWJ98&jMd_UPA%6)|T|J-gGk!CRW9eW| z;((&H;JIIQz4Vl!UeZUEWmFsD=g}{)6TB5VinCcD&23M6*_R?+#>u+U=#LtHPMkmp7r(zP2D~ky#YmNPR ze}hN4d%0%n$qWt^G=iz(db ziviz60zD}oXqzzId&z<5#s|OUR`l0TY#p4;JSFyu!_c(m02QN@$1=cfM=#H3`UGh} zA^()8WA-ts9js7LLLUboh*@ZQv(fXz{JfxLD5c{-PH{-@oUeB8)e)5!t&T 
zQ+!tz>vES=D>i>4)M47YCm_T48R?B}Qfy;t4>n%utzXj*s0mgJ<{6{rCX#<^v+1um zl-zbvE`fWQ_u5JP592o8SG}n2GQSB!MALMtE0meUJ))F{ajjFfs2_ytJWF5ZZs66C zK_^eC4)1D{e3j@F=0~<2{2f2Z?s3KVPVt22Z!JEy5qOMur@qO9{OgmCD0fVn@+oaW zb@OktK&qCwHi&+0tR-4iC4srmJt_EVzeZ7^M#?a5tb^iiv2XY_=uLJKF%+h0HImj9 z+=R>W4D{YQ$PLZ<3XY>G@KAA)*$OO0E{+w7?k%a6kH4Ze^K8!FD>*HjS)DJ3Pe<#1>tX{Fbhm`IcU!Ps_O@rNH%S2J8h%?mAPOIs%V#)1Xeo!V*RA z)Y}QQsJhvaSWPv`T!zl`B*fi~IE{+54<&Ck?yXg%x^Xo_k3bbp;R)>p_bIZJA?l4n z3veFocoPuAu$8kBIHK3UZ2wo$nLoyG|9?Cmgv3!1L6D;PL`2S=-I>{W&%9@LLlJey z&6n0uwYY1ZAIE9Nq~*Q-+`XgZ0)AWp{CG~KH!qUQjWMbtZh~=tB1WkIv4tv zM-Hy$23hY4=in>4M(ZVQGzP*r;aikoG{GtQ{p#~&1*}IO8F1k}a(|=VDyFw`CLd2_ z(mPx;TuDlb|0B<2ccRi#ZSejMGlXfLReT-$AM3KvLx0~4+cun6Y6CTNlHZM~g$jb6kRYi+?}W+gK% zbSh=KKDG2h@yyh3v`v8olU(6^UogXz2ZDthHZQf>H-vUMKg1`+%fc<|FTI6v8SXVU zkVLj_D6VQMA>3UH9#9gVEp;Zi-_@sC4gT-xzXRDw4qhQIX_SeI()qQ{ zPGE*7f$%D(Znr6WDx*`o0j`;}mj^x$K}e~8!J4fw1xjK8hK z`)4+8(ka zM*L;&t)v3&DS1JDCL7!_)u)6{aGXBNwYz3r#`}Vdu7W>EQpX2igRvm6nM-zFg(`i4 z>TNjm&&l07QJpA^);4R?gys4kgXuF`<@GwW5goU$tp19mundI{xoMsdCRXR8pUX!; zfprp&DxHL{(GC0(qGJeJE72zG=0ciZ86tdB^*0;x4l#$+WZ7D)BbhW(FQVE`*@{n4 zAA**l+pv{q%n^a_sFX-P@UH8ktEKg)wVOH4VrFg{1N4+SE>}9(>vYJ5hXiLZ-MA{Q zgxh2|+1vZUisKII;8>V+5 zlWMLMq3kF7zE?~Ew~D@^_9fqir)FYkGW*z51QP;X%E>E;385mmD80WkTCLMR2YDxN zIrEIuFpJWRb?PiUUeVb!{l5E}wSjbo&EPRTlCP)dxTTbVCu4$z=3USfjb*wCBdWul z4XM+dhxAf|P=3^eKrTEgEcdU|T%mi+94Lm4nh%Ws&4&r0t#E80t5y_h^b|7VEHOs< zG~P=O!b|uQV5a9UW-}~MTNtmBmQ|dkCz$PVOiiJ*-I1srChWF|6Z1rt|tF+prYFk8*m9tf2RJ40A_l!tGP^)IPiRvJQmVV1`NDQ{aF#>V(1H){kSAog}v9{-jNAkN-g0 zJadpys$G}9lu~>Pg&gbS)Ej1Ve1YHO-{)9re2i}3SiFIQ#x}4(x@5iWo#_0Qs7E(f zHx0d$#BetVr{)gzl3A$r!;!A-WOpOFICpln6zCERM~!?U*9ON3%f(Zf4YvG>q^hp! zK|=dJV&&7twnDTXJ#+xrPMG5eK*?}2H?=Xo-X}J?Y^09!;eBB%AE94WGx!5kEa|%? zG&g4sncLRa;64SXwY z4F07I#bdO^eCrwpKa}cxgSj|T7rCgY{4_UN`q@#gb+=kEuap@H+|YZHYC!KW_v9f4mI+{`Ti+#3w+*~B%JO&2TWkkfNglS_$m7V z*-d8SM|zGK=N=1&@xkhkGFGDXW;WBSLRFurO+9NviEtKbiLNT5e4QEyY{5p*J~hd? 
zy0Vite@~3qS}gQLq51d$(fe0`X!Mx=e=W~U6xOqu!DM*1@}Yc&tE2tezpe8^tY;UU zM7)33si~DQN#m4H)tlts)gX7WuAN*}Nn!uLa`Dr-%KotIoma+81Jv;42pD*O@RA?|iU2u;+o zHMeU(d9~RcBw|0PG`v(Habj)Em=3m!ZE5T9XXYD_ZQiC*)uWD^Fvf)P(>?=hCX2s; zFzGIOp={w!h>dB(<%qNs;A?MR_MC5=HeEO%4CMCHTa0O-f@=-qV6r()pR6s%7XE=g zuCJB%DZ{;`%s)=j#x`hE_G$yjFON4r4^*-}v=VTWNj84K=M+0pZcPFweIJU?G+nlo zT5!#tfJZwY8++v1a?D0M3(Sei2|lr`N$n)G5>B@|poz})#ziHEy+FI@DfWx%3!jLc)4Xuu4`|7`Pv4Lmhi?w)bwfx9=j@`>Wlluv=au=?H?m(>hjyRUV z!gr}9@&iwqTJ9WV^i`ak0-qFm_zdGAUnn*Tz4XJvwTieC%cu%9oS7u%yFd1CIr2!j z4;zCktSD}#*<1SAY+ln5IGXL7?Rm}fL|F%hyQ{2#5`-2zE8JX-S{Qq ze%R64L1(LLQrjx$&3DjGI%}J&MwwHUgGP98xOb(P>u3TXde0n%dV(x}u{xYzi-&KOUv@xWTc zhQtCc(wF6wdwY%M+9<*+=&9dfqa1T9|BzL1tm>3hpf)4QdbbJn9@=S7wA7hQ-HJI& zJFdNhuP}M&PcPEOIKK8rDi1BT`keHY`OO U8z&o5ei1(7D63XHW4sCd7hEM35&!@I literal 0 HcmV?d00001 diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index 910a57bd0..c2222c5cc 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -799,6 +799,10 @@ def _create_GroupQueryAttention_test_model_validate_PA(domain='ai.onnx.contrib') 'key', onnx_proto.TensorProto.FLOAT16, [5,34,512]) value = helper.make_tensor_value_info( 'value', onnx_proto.TensorProto.FLOAT16, [5,34,512]) +# past_key = helper.make_tensor_value_info( +# 'past_key', onnx_proto.TensorProto.FLOAT16, [5,32,128,16]) +# past_value = helper.make_tensor_value_info( +# 'past_value', onnx_proto.TensorProto.FLOAT16, [5,32,128,16]) seqlens_k = helper.make_tensor_value_info( 'seqlens_k', onnx_proto.TensorProto.INT32, [5]) total_seqlen = helper.make_tensor_value_info( @@ -813,15 +817,39 @@ def _create_GroupQueryAttention_test_model_validate_PA(domain='ai.onnx.contrib') return model def test_cuda_GroupQueryAttention_validate_PagedAttention(self): + so = _ort.SessionOptions() + so.register_custom_ops_library(_get_library_path()) + onnx_model = self._create_GroupQueryAttention_test_model_validate_PA() + sess = _ort.InferenceSession(onnx_model.SerializeToString(), + so, + 
providers=['CUDAExecutionProvider']) query = np.load('query.npy') key = np.load('key.npy') value = np.load('value.npy') query_batch = np.random.randn(5, 34, 512).astype(np.float16) key_batch = np.random.randn(5, 34, 512).astype(np.float16) value_batch = np.random.randn(5, 34, 512).astype(np.float16) - #query_batch[0, 0:] +# query_batch[0, 0:5] = query[0:5] # TODO(leca): Need padding for the rest? +# query_batch[1, 0:12] = query[5:17] +# query_batch[2, 0:16] = query[17:33] +# query_batch[3, 0:20] = query[33:53] +# query_batch[4, 0:34] = query[53:87] +# key_batch[0, 0:5] = key[0:5] +# key_batch[1, 0:12] = key[5:17] +# key_batch[2, 0:16] = key[17:33] +# key_batch[3, 0:20] = key[33:53] +# key_batch[4, 0:34] = key[53:87] +# value_batch[0, 0:5] = value[0:5] +# value_batch[1, 0:12] = value[5:17] +# value_batch[2, 0:16] = value[17:33] +# value_batch[3, 0:20] = value[33:53] +# value_batch[4, 0:34] = value[53:87] seqlens_k = np.array([5, 12, 16, 20, 34]).astype(np.int32) total_seqlen = np.array([87]).astype(np.int32) + y = sess.run(None, {'query':query_batch, 'key':key_batch, 'value':value_batch, 'seqlens_k':seqlens_k, 'total_seqlen':total_seqlen}) + print('y=') + print(y) + if __name__ == "__main__": unittest.main() diff --git a/test/cuda/value.npy b/test/cuda/value.npy new file mode 100644 index 0000000000000000000000000000000000000000..ffedf1bc23b32c4366ec17fc9d2c47f08f33979d GIT binary patch literal 89216 zcmbT7<$v79*Tuuk%!wU4W-kdw(nwl6%*@Qp%*=5Znlv0IW-kdw(nz}wbJ9SA4MUrT z+Hgb9{{Dq$->vjwSNhDId+s^k+0vnV`%b+J7n)IMZp^UZQ-@B8VH?H35zd&zMlmBM zPMJ1k@PvUArwkkZzw0dqkDEF?e0}Pu!IOrEzsCYPu@UKT#w9kI-{}AMDZ{f(+1Hc> zX7Sze6)I>RL8Fwq^lW-IFwq@fMO+WKsbX-n+EIC!wOU&!ounS;W`M4F1F)>DOm1Az zoL#7&A}64)lIRMc-EcoMK6IF`Zn>GN@(yzvb|aVmUc15Npi}-As3%h^a0Tt;7n&)? 
zG}Mt8ro;+%{7SlvtNG2``FuohBy&J}1;2qe-bh@EC_z=zFK+x{Oj$X zk`nl4Vz%@G$8%ZCYrb1F-^o{P1teRN<38g{p(P z>$&Du^gaJU+%8xr{x+P)&9^^evW$<;>dGW}bnv%e%is=QdD142=KBX@;5Du$GX}v> zJF5&@0_%~d_^ysCxL;tqI)Ebro#|)Hyx?JbGqD7!n-r(?_e(k}u$+ZnF`wl7!q=2$ zzP{=?X&?NguH*NFM(2&v4#Dw89_=$83d`UU-N=3)Y7)w%rxWc%BMz12Yp`qBC$`nj z4a(Wr4%!}~0@%e6XwuO-WHUVioWfNCMaV*dW=xQ*BTR}Oo&O8o>b;)ZjQ_5dEVYmChxKm#mLYSVS3X3*q#~3iM4|1IH;btEkAscTY~P>CAA{k62U5`jVO8UTTB(LYxVJhoql^ph^t!@sd)hK9{H#WX3 zyH{GnJ!eXj5qK%W`e%84XsGL7wX%=Eqw!Mfr28WL6_s}{ml~NDK`XrpcU0?|TiY0m z=7*A*i;f?BcdAfm6)vm4Q!~Y~+&LhVf#hwO#5L#Q%~8=M!}NKzQMLo%KL)b z!jrk5nGyU;saWXWpy2z$jL+8Zy#$hJcs72i}Tj6J>pJp)4 zBI9>_*jyi9T&jatbH5P1K%VB&u8G58AL9>hzCPaQAuKljfvbaqLbK?z?p=|*L^ER@ zhxk6Qnldn}DXLG6a@5s!JHKVLgI)MuRO16%_?^Z-+KdK``S&nGt*y)ze!AXSCH*a^ z63i;CiaZUTF&7wK&r9JDn&kY#UEmrj?Zx`Pq%lL_NFxS!)vJe!*sqbZBirR4)|Kdh zr#ruh`9V3r2y!;8#1=8W#)-O9vOChe)d3H4jMtm;v?9U9crs#eyI>X60u=%G*dn}? zRY|YL&W;|Vx6+2t#{#+5g$AAMo4LQVNgz&{ha!a|xGC|;56#O#C*PL`Txq>2 zSvz=#c1tbkh`e6Vfy+QKbdi$Dea4|cDkIY`JIj~n6g>JM)! 
z?oPad(3Xj-GEja|A@u#AaoRC$7Jyfny0^!O3L4<@`q^7sx(xK^!nDI~Ln zsVZ@8;XWTj4`$ohhA6YVkZXlJ{z}20#J75L^bA7^4H7tZwft0X>HN(67o8?GHe>9Y zP)~1b{c50rzFP03o$#&3rLu>>F6JO6Q%DMyMN6r@$<6K8)ne))>$9_gSO>3zvvC=A zI*MR`S`;tkEmuc*FxxS=5wX+0#r(@&D)b7Jz_paQX6L{rR^lk;keLHlbJNXkzMz$e ze`K%WBGgxWYoR;2l+WP1giGvCUpek9+OLe}rs1aiSviK8E01KV<4fwe;214j`=Bc< zWNsrz;J#Kl{T&Ipk5qxuQUvBgTumviuW(FKcV|oBy#jOJ>Z5T%Xdhsd3g%jAlQ=48 za8e!@Z*CV~!D`$v;!bqY!z0<}auszD+#p9PLsSUvq1w!tZw%Y^sn zitLePm0ft#3T8!cC+OxF&+(M6H5!;gzmCb$s0`8ZzoGKMvJYF3`2%a-~jB$R{+q>6MUka+=MS%szG;SDvLzl-1 z+-z`;>6P@4Zzb-VW8C9(v#pwoOKl24& zKlI#6$y*r8hkdQ#-0S?k)(HJJtjABlqxBznx4zX(^+&=6+#qyCdSX8F_2>3+KaIsv z3-z+{wd(QioxbsOc77qp{XklH)xQyXqCIpJ9Hae;ZCa%NJh*F=H)f;C$%mXjmBjeg z>U5zk>Zz}1qx2k{1V*zjofpUj&RPog_tyr7x;N=8&Jd;t@4NrRzc4@fpHiOQIZ%h+ z=4#dWgml?)!1zUYPW&N9g`athFwj{YRpIARYuW3vGjLUE2706y!Oi4~Vs*}$mCGiW zAEZP+S7mJJh+w{i7$nLAXrumWH3cL$3%B>a2sRCsq35Yz%--Gy zpto<0cV1{+j>G*)lXMw9H7c_gjRkKN3j!OR|Y5W($Nm5im9SSF+ zN_L|{dcj5hd{SZaLjFdKlgHrK_zN7UFY{8mlV`lixB#p_fT=C&1G=p{8E#Uq%hv+u zss`U3+I+0^%6Sxig#DR#XB!;P?ju{U%cKux4p^(T47PJ^Bid^_#D+XDZ&1-< zH*s(9cYY<)6!*ygAjc|2m>)sE*@NDnJq*d{A{eCQTi;SdW1)Urcd;j2iJ*pa7X$J3 z{E(i*d=Xy2lX4ONILY7}FwH@&w+aO(%SFw_^gZt`^?L9<-NF3IH?ydezi=V> zsPU8txOynn{b^cL{tMsMWfAR<9YDOJYGw4=HM4ZM9N%L zSA#m^uE>UTXKo(V%=3mQq7hne&N7dILFl@*tdZNBZ%#93__v`$@&Vikj^wWUeA;7H z=gUYt&@^GJc}(rjPovj5=Y|&MT<3noZHR9G&ze8s-|&?%1(Y`zv%72G*GteJL6lXz z;4_($Gf~S#mE2DueYg)A<7@6LOH3jKW~&ixzR`0*G52MuqngJLvsK0w=zLe6Q9d%s zTI2NMCZS_7Z}J#NeL`p7%e%mJIL9@VIq07eY5*%LCBRg6CA*(;!bI)6Cst?>ervax zD$Gvxs&&)b9j0^5=~lQXwc20NvY5z&1VxSAma#pnhhzlXTIsV5HPiEll z>QZyBQ9lpa|MYz_H|j$|8hV#B&9#S#;0H3P!S)HacwOyk!r%aNl{1f>g<@cJX|ggt zy`@-MIfthUwNSKD#VjYt{{4G9B;S(0@G8@bpXnHD&xOFAraMC_bH+Y~nT5ymDP|jE ztEsT#<)PMldNeAZH{KeQJ?5w=veF(SOf?%@06eN=|JlhF|4fRnL0xx1l(a%ZVs@HLu> zw`k8)f35_}Qg-2{)Sv9Bm(vE*!*MVEyX$+PlDST(ratj)vHu!q<4I?_haSTk-VN|7 zSxkN#oQhvYb>~xHP0Xn!LsOE7*llnDY~b7!+_OK~aBJ`U%Ttys>u%h_z5 zna;Wy8l!dN7bF50kV}Nl>I?Z>+;6ZK1KBfF9U+eSi|l4S_oRq}3x?7aeA_b<$f@9{ 
zt3%+9`KsVIZ3>r<{$R3*M&dr0p#Rhl&^=He)P!lSu0%@?R|2FvsP>0@)+BCH(OUgXU1ymoFp>-LIl(fWT7DgBPFqrOLlT?Lz9e0lWQ{=4;#uxTVO2~W4 zobN9IL~0T9S=i+0!HhfDgZs)=L91|Wyg2xjA4pCR`{e!x+k^VnW1$=xBTpq$VJbNl ze#JGNZ;gHKS=v(8qjuy~x~g}w$f1s*99K#36F8-4^TE#BY%V6j2+p=D-aIN))SjRK zKU%JqFoxg3zNLr3vrQ_1+5As%$I&-N$R}H1{ZlpGILU+cwOBUNKwf69h4WBUuo;Qk^my!z~ zULCX5`Jd83yUmWnBaKIi0$)G4j3A9%KTX4IwUWJv?JMbpB7F;y0 zgUSAOne7S^b6%*^94~@@N^Qgo%2MMO-%3d#;`x>AGCP19`4c#ejnM$uHhTSz7H$)Ua&9NBdB^dyarA;Nw+4VWImnlD#b82#06RoW)drKSE6@Q zGxV+g5B7S*+ZdKvop>8=3f$7R2Q%Oa=X5ZGFF_Nac#4lV`6q!+%)a6T^mg}fG4Nv|oDNEIInK~Gk9N(JQI^z}w5Esrk;oEY<#)kAn4UT=2P8w}Jmv<$uo z>%@IJwB%KNOe`3!?|_ zW{S)f!G9}rxDae=m9jQCtLv(E0kx|C*gGSbE(N)}=_VO##M5;-8&MJcwE8n8bc5>) zXEeFM6gpf*x@3O^)7WZens(Cthk75qXGRBR`4{PLQDtXsmVsl;iKs*4Zb0NdX)UBn z{1wg3)w<}a$)VCqjcZ~X=gN59bfP=yJ_N1m>Aj!KLXM-*nS$U zWu9?8p(yqBP31?xnyMkq%E*EFL>~o)x@+Ch7!zyB-VgYRIRTx78T2?XKXhHKN^b&7 znSRt_)LY=0S^B6XjjL~TC>6TtUzK%cynhDvgpy(J z#PeVeaVD^aycUo6^Nu-GHKUkZS5@>^fF?JQ>*HJVK~>aRGhAXtpC4bf*2}}J4rUHl zgWnT&v6hFs{Q*xGe|+d~;W7?^c6bQWRL`Pn@>yakXdkSlH3kb%Oz;QZ<2s~uJai?+oID{$qcOXwDS?zD$S!{IQ)xCbbyoSe;Z0nIYRx^E}`RW zBypU`kgkC4${yo{cN4DTt&2|KkwP=EYe93ZWJ7T7ow^2R5#36sz@{ z{#a{Gc0`}xS$zu~=U5f$VZJtwf~zj&&{}qrySV<;)jzAak?QEn6k*GO>q3Ppz4fAC9C4%)%&kSBtZx|2{rUzZzvW@?Cg?1S~2 zOsu)oGdYw(x8*mv=Yb;bN#;pR675_g!9@K+z{AZ1$$S7-u~nofc18X*FwZJRF5wTs zNK{2~a_994)ROS;G(&{zGnlR}(=DTkp5$H*B2!3fjPVNn?F|zg~tChhB zg{oFTr>MV3!}wWZRjmnENAueDYv;BdJGh!BK}j%FtSROTmyn^7da@P?69Nac z?36TqRAvcx6hECvV7}s>v31S6M60Ze5ZN1|E^!gUT=6%soaAs>(Q77hYW9fipkU{x z)75cN?k62E&f)duH)5cK`L#M5+?L-#?QQl%L#4NX%bkSk$loZW9SZhCi{KqJOz9$h zRvm#{br|wkq^oeSkQlhDtTjW#B17ndhM2uj{ zuo87Mt1ngG-iZ2BSi&`8CZZwgeC`^oCG{sJ$n)`AJPV)3*IOO&! 
zu#;YcE=vzE-=Gz^NwAE2EUJLY96I6ZtMwM-=qtL{z1{Xks7%+Sms+wpS^1mogo+2P zP)cA4{ts`l?xR%n3j7DZ(I3>^mczX4c%di3MZ`6431=gCjL%6d%zKqYyi_eARR=Yt zY5G;|Z^#=mLFrNUQTU!viSe1lku(+fyJSb8@_a)6O8QD%lvNVdqDJG6N2i2VIzV`; z2w6R2cbnJMY5sF+4PQW<4GsiX=@PI>+jK%%6YiZG#dgdXVgUOJZ6Ujd-H7eDk9I(@ zk*lQZ1s&|oUEk$=?Z5cZNrj;k+qi!IrKl4V!#ro3g*#U5+U*iwxC^W4U?#sANAUOL zsFYu5jXvQ0BxLJ4%5s`oLkRpdXGObQZBJ+=a(5)AD*6 zTi7ke6;Qz#eDrMo&G6*h(c0_|KLtZF&8OYHN)&i=1zekh%IX%1&pIp8nHi(DE0t6_IE z<5wi!V4pYYz*Q>vo%q8lt&B|hl>5<$(r$wg&dPqlytH0|D1Ha{D{p`@+Sq1AqOxA# zcwr2&7o~zk(BGV&DE-6SLrX&s<0_F~@GL1uZX9B81#%?0F}E}Qgh+~wNqR#%g$hPF zd{Wt`I=RwJBTqTyGgzM2n%~YliTCVnCLJz!Pv>@PRkTDaHTj`&!kNaz81oZv>xBNA z=Z%imci5XI`59IZ&CSHpjim;TETUMjnfpERfba4OVH_XjKE*}yL|h$y0X*rNi^`eD z_;PGUbXoUlI87X)FJ?y_DXE{sPo$cm0-ug&$sO?Lz)fox|5D$LFCl`xaA=~rI;EU6 znOoj)op>T{2{k)eH%LG-}Bv{25&Ht|7Gm|82oM$U( zb$JKu8eN(X>N^8d>;nA{xht`Zvpej~%l=2iG}s?~51a@meED#vI+0EF0HGh(SX}1h z@NA|SiY$oKzG+niU$X4&ZI*=)zBh*$8N?YsO4;O^4Nu?+uo~PeoS@IRNYn&RB8v-E z`E&YV>APJ^zU=8^C{l8GkK7CANdHCMU^;8HSR2vX)iP9o@`59cQa4 z?H*-x(~1^Yio@ChGLPm^#cXGoyZI~AXPA$2Z}7{Vm*9SGG5${)!Y@*n`u4+p>I6^| zi0nIICfgTGQ>$|$;V;2Hw&v_lq6t?lbal_vLuHv={(3y=*n<;==0Q_9CLA@^Wm)ib z^6+{Y`VOrTH_{lO<%RZ1HL70+m`y>UpLUjO$k2op!rE#q%_TbmblH_VQwLNgnfI)TUklA zF2OTsCbL88WW3Nayan_X{}`@24iH$4JiGyo5BuNenem7(c!O>+Q?bb&Nc-#8@XrT2Q_bam|;f2*YWhhZXKgYzI zSNO)ns=!IB2|7V6)rVSM{xt5)Zz*^g{32|GGmIpq9kZOD%T3_6hg0tPrcqGRnyRO} zCxWiRP34w$$wUD{?^M=pOq!12fyN45O;B9K3U>z zrOAP~ztkm5;EKq1ohwmN@K0N7*wg-wGV$x&O39^Mi;R4O1FQVY`DOM)xI}O>lcrt) zFMKh;?%#~l$S8lu#G=M}+fAi)Tz_GQGT!r0uUK%!ENxyfS0>bSKMkv&$7nd{U=nmg zt3LB8*ob`|*Ef%aJUvJ3D8=HtmdUKL4`P1!%Q;(vkJ5kIa~x3qhC^_H=MYRuxgK1? 
z&et!pF}9+pS6~#L=$(ZV?2~+r^iAeq_|e$z{V2Ry-%f8gElhd&D5-9H-c1Qw_)bm7-{#+j|1XX9K{Cc83ewRs)xKn>6+_hz!J|47mV zTXQm6pRSA*AF}6@FQ_-Qmf8}&rCvbK7t5G`aJz^$sIMGLkLAvU9-EPPX_W5hg6<^5 z(Khgp`A^_kAWE4C+gVLhW_l)pKlyjuXypQT%e_z>WG+EdyoF#Xv!Ayuyyzb;+&6N} z8}bbFB-C0s?*Hr^5tzUgX!D3JTu)bM~YF1a+j!H8BGaq_8yC^_f z$JGvgxBkMrQdtMn_1z5QlTFJ^=4T5V$ujq336Eu1`Hi_$zHy2d@@|K* z8V%F-W_JNlZth=%S12>s^Wo~OYW|K!Yf<`Y%cploW97HRNo^HXCb^&jNjv&cYM!J^N9=#=jRdtFGBO zG{!zst&cN79p$Fk05#=in$0C46!4bZvkQ*V_i^_ahFWU7?f*ULtlUm&LiS>+!c~R@ zcL$r;PA3|k4N3>x+<7ppL~g?zGOUj^i5emD!7|Lk!yJynQ^f;J8Z#bT;-})G4lbrP z-v*B0S@NAWkTJksR0R;^x$MsPym`-TxRCuf~?+rn)TSR zGVdo|#{vV9oj;?S8L{qc2}XYL0vSL zei+<<9vN-*8usC~r`#XvRJ9>JC`_xB5$`MWjFnanu1sd?#c^pjfv!ri(m)g~Fz}M| zpji?2FUZq}Xl`zu_qkGDPg6@e=UE!Qq;;^$3e(we?uI{FTS3LBJ-iuppSnm4^_|j_ zxm&T>DG>>(JXe&Np6pMdseU)HmMDg2@Y_~xx|4WTU#w?YrI~WdaQ7TK(#M)#^;Tpd zTwgzrmlq6!Q^}vaob-&JKu7uKqhF+X@S*abRuoqDj?}+s9h_a{rokqt3hiOz^>jzF z)tIiBKMpQ|T^xJ0=Alho6~|aD8WQ?iYxb{gn02YaD0HvFM3rOIq@U-jDl*p zyc1lc*jso;jSM@-L(GM~+}ODAjK3arCr7X?xF^Ic+&FHo{qJ-`@EKFh(L$`a_; z^gelBE8-qPD^_MkC1Ul#_@=+1`|=@yBU<2Xcxc=LOAkQ4z4$lnBo(b9{ea?3mC1k~ zO~um8X<;sGuDY14=Dix8cE<}VT^~YUVK;INYmo=_a{P35aPls_2Kx**6<6UiWY6G= zyp!r}euq9+o{JCZ(7j!%<8%h<6x8J>T8VIpdXH+iYJ}}VPpZa9x@*j-9 zx#!9tyq7OSJzyr|6G|NalUkV5%ap7`<_BBz(3R}S;5H?b*N7(CJ1!=$6z-*$DBWGJ za9ciAaa;NLA?^LPrAeBr7QdyW_kW93w`G)^;%n-_vdg4JZe z+IYJ}+0X5MW;Vq!7=;W_HRK7Em&yiDCElm+DZNaWbWwPMw+jCiJmzKNF+YN?4Lb4J z<^wh=@Q7F{KT2+=&j>DO9vQJ{X2QaRU4kw2j1{OM{9JZ5-GXVaj}vNhCxZzo1>j$f z539&bZ4~i^4dzVcOXd7Q9EsYNk)@8pp^OwQI?PBfqAC-E$Td_g_>fOv0*!;9a;Ry- zGB1&J7bj@r9F)*enyr08gRQ?&E8>KD+}J_hl&-t5_S4gy3!r1lEV-G+1*gk{lLjjp zW(xBs{v59VUdba8*UKb3pP9sYiPq{Ib`id8I<#$irC1`^N2w?JxKTcgOQ;$44s`V&?o(bz5o`$&S83-j`kVL@}6guRz>3=S}YBO zPk{+HqXM5d;v;bYACk)pQ<%G1z2F|SUHT?X*;59jLQGt=Pb8OEulYZ;o7Pk3P!LYg z;q^dC_-}Z|E2Z{DquztrInVS|m`gCce7Ig_U@DP)0}({bdZ$2XU{#^@cMvC)$&9 zQu~HBilrd}8aoqiWR63(^O(g2HuF31Hfpa2kz=clcXtZ-^n5&wO*U@A<@A^QNB&1t zA+4LXf@zIT=pP7Ht0w#g4Sbu6hTEi<gEqMxp?I?zDl2WpHPBr+&bmVVjx&vU2|D=B29+?GXG~yx 
zsGnSTw*&RjN;y9f5t-dIw|0%F=`^`b>MStVoDM>;8h;uM^CR;kJyY8v-9}q-n`@(u zbJBEnq0m$iT)XNfV!B1l)%GxaQ&7Jt*4X#qbnX z85O231BcKF<+80DJ*0tw9qa+vBK(?Lo0qkgiKnD9)J3NUMTGt5I0*z*!fyWM?d<5J z_Tw|vI;_X^P%Fc^6juB4tlG*fWDUoSsWsff`gKz}QLF5&y_@0d&?jn_t$~)}8J;)8 zd`-sW?+Zk;GBiAQDLWF3aRC$Z>>1A?_KUFGb{;emhM9Fy6p=1pGrp>$+#j^fY9h$S zSEYfTZ&(bT;R_}8kSUWvUr`(Sok%jP*izZ2jp|;nQY^fRmuu&g9oi=16LBJ?v!1QK z4s;eanQ?48eNC8Efz%>!UN2SP&*|kTEtjd(I#fWm3gw8T5Fb9blV-O!S_v9?<|GaZ zbqC+%%IZQn9j#&4`1dh2kQd*DRZK*QL=ECVAarX!Vf$`pY5pa?hWp0TTI8^06$avYN;?n-xc=z zN1=htNAQVXBsPxpFj2UMx`J2?K6Aymz0rL`cU>in|KjcH3|u6C02ps|cPsEXahj;% zugR?PKFp|)xEr?>FTj`Nio{;>X*^mvEL7LCUf0CAX3HTx3lNt>3F|329tTP7e7;!P5Fc{ZA?0qSYbj31Xy{(HDdqy@r=E?4EqCA@!@- zi(SAzz)O@xYz4I5nC0GO{H=k*lkt(1BI@Uahx{hIN6VLo9P;J=4esil#P*CWsS>)$ z9C7snKeapBSsrFxJCdYdHTW>wT^b`cB-|&zLx+BzeJ2Wdwc4`A1;sfg&PmQ^6!W@3<&;-I!kHH5q`aMzH~yHt+tAL>nLWQ zqU7keL+zLe;%jFr+Y=t(F0k#D-;{jmpgOhQV@DM<9d<`A^e^-iwX%JNtEzcHtSVOv z{m%BI8o~~gO+IK&RO^|)!5Y@TNi4HKJ5A-I=GH`a2dCS3sO9P>n8(C`luG`9NvAaR z$Jy&EDokwBEXg{OncF$nZ-2-94laS2ImgXQ`YEnhY9k^&FONM(+MFBEbGSxm0>fv)E{;EN{+L=BcxaaSz_HwRs#U?)D zCb6;RD*a8Eo1yi4%0E<-?DvM8;fV9E#7p0qPG*{F$CQw^&1lMVd+|ZgG z)x~@fYA?*AM{qkOl51ym2vriD;IsF-xeHG>_n@t2T-Kf4DN=MncmG#=c~lwg2yWwF zYGd3@A~D^?81MQ6-lZEk?%;}6i0P!KTWMC3k+b`c++vPW>}RpCJ1tC=?qNHLH^_?a z*k88Gg+d~^QA>1v2Y-eF@X!3Cc%D_o*&w);&P_footJu{+j{K%+y^YAM#H)ymI- z|B|~kN0Z*pI>d}{y7rUb8QPIJwjhx_Z=3@edX)JaU+aJVf_x<^=HMttZ~cXIoWMf1 z+K=jG{nV;aD-_J;f|>eqW>EM9Uz%$&UlHsu?~xw$FD40>cl5$f#AezHgz!bV5OYJ5 z`7`V+@&(>v{2SJ(d*i!my1k^2ME&6M(0{rDZf1vopgu;uQwnLMTP4>J1BjVyNIXs~ zfjmNdSz?a3n<@%2L7I3`?+tve`VHgA0caDsg|!>Y1FM1GdO$XmBE$#iv}ULw`FCk6 zXh^;dp33x7@0=f{o*3eD@_KD1ULGvZ1oDg6M~D$%2NMTlJx%4pRD=g>t3-K= zv&+daiy7mxv7oNP3GeJJ>iy;+*|}kYd#!Ii)zhId_&sQh$2h2J|vhxfpk)TBo=Qe#V+F$M7=K9kylL#?;Cl6P^azF^^JIs~dPHI`J-a5Y#8~ z+3CXYFn8PrH1Q&;nx;S(%BJ}6KJ$ewi5efi8>~dXN3CP$^J8!$@6UjnX_xgxi*&wA zO~+HEuHsm@SYK!P@@FRvWZxSP5~@jEjIG&gwGU>QtPWZnTQ{6TZ%*{n_vmb9rbMTj zNwYDKztFXqXycGKmA}e$4xR{m5(mvu!ned>Flqa*&O>r1?I>tKofLj0`Uh&M4dXwE 
z2jE0{9zT-{YE|9i#cZw{8)5t_F2;HI8A$b=mY#&No$IAndn2`)R2~{8{U@ypWNTaK z5h%jehJ6~Cub*(`82|8TbOv9YJVEi+2y>TN%R7_*kp9JaLn~$$bL1L3+%3$Cs7&a* zl*>o+Lml5;MZvPP;(^7WN#0I$D)gtgh1v-=(M}=m=q<&`+4TKdlSr3W9U`NRKvqy9)A%$3&~6OTftg$K;$K=A-= zOt7bg&pM2jWBtdPobe7v1H|D+Zh*gF8`O?y7|xsX*s^pMR%t59?#nPw(*t!`A05u$ zhk#MhjhV*ylHZwGiRlXd)f1GZlmlPT`k2e4&q-A+FD}A12o=Re^$U96l-ZP1iHYiD z4%GWwgPeC|%ALu4P(KF7S+SueLDNsu8=7tcn#gzn)VASu z=#jgxbcT4~Izg61?Udr)uVF59E3-Yt!xRx5LXOkxKW#Yiyx=6Zt-CY)U|2Hg-HbPz zweSR{C$5TD@HI5IR4s}2t?5I-=7A#z3-m&)u54ER z(Mf6$ohdI*KE*DTQp6-QKk6lWmMDcEIYZzmQz9qc>dU&zRIsDD2tZl)V zsi>vVmmR&def56vF2I8Ky=nTRj7)kf1rBv(Jp3arhOBI+k+a-@Xa-u+83#=0XmrpA&B22|ghy!Ka;y47Ny|-<7#yW%7dhj6DVX zWP84(v6iUrbCGtVDmo2jtA~j#t`w)lO}3TDzb=(kFYx84a%d^~W--AXfuGu{5Ah@_ zGu3*&Z^~=O&Ir$TRaff2+ZKS3yP!NPrvrJx=@To3C(DG58hUT!gX7ga*_DS zw&rPmCTk`XHAs@MzV{7gL;a)YIai$gVHRR*tDJ2_Uddo8R}dF*{lgc=%P7;cf`i6U zWH;A|QE-ZWI?}{l( z6Pu20?(1f9WNs9eU6GzS{j_EDbZ$jP5GO`I)KBB>k&E&u@diHwT+AfA8R1PzcW$M# zr1r%a?TM4FaRU^?K-6h}L45_iN@C$q`6{=DTWxw-K~IX^$Gz3oa#;+?3QRqvV4$x* zn!0NG7Q8zKmGG1|GoVolmWi9oEIwn^YXKM?CE!krznrNxK*KZ5;wed`n z#+&W&6ZX9G8}T)7k-`HjUnNFxUFMnbQe*j)*nP%quRVV_zQdhyL`VOY)*H{nHP~4O zZ4A@)DC^``M?Mi%=oNMb4#q3hqT(Suja+MF^Ka=&gr6xbMf*p=_Hz430+PT$?W-dt zcLjrts!d;O-|=*AzU(pfc+2WgI>&BRj^OU_mA=t;j_9gAP(V|%2DegUxm$W2E*f_bNDf(N z+Z*OepqO!sTf{y|IcezHfv^~M5Tg@MYbpFUbzG(`x|BMV-kOkLR&*sXyY*ctiN8k= z$y|}~L~I$KocR}S=qUt5oM&`T=zi@D1G@N`7K4$^rm*eU0>nP_ReOM!z}nIF z{6YRDdJ2uA8ks!P79@KLz=DZm`SpM>(>zD@ zM)U)t5`_ac)A}!pYvfOvjg=5H6xC#wo6*jF@E7HsD6*^dzWEdNb>bUgbMUBEz`d4j z!JGky#1`CR@;|e)m2s~FlB67S3-gjKMJj4-<0ks$Zn!Zm<9S&ISX+aqMUA6k2g8PFBS0OhFIJ*GXPhm(<#21y0do$FMM}mZ!H)~4<;jKFDr8jiHb{VF z(f;5QW&^Dx9z*Mbe*TNOIkKIQ0uPNRYH#h7=#-O`qhtl|cD29K+qg;efcLWZX$Q1D zLMHLl4u}MPsktELwf!z`sLqo5qGRZZ^d)s7+s&Gb-e$sFE>+Uvc=u_WBRFHdQInp- zzh;j~&y6aMN_jl|NOYFhnM>3axB|@tMlltv3OtXT463hbI8RlxL!bw`H{^B$xDN;^Z-4C(1iYSLvA}t^0)H5ky|@rBd4r42oXcG z>Dp^GiBjULH9s#E2In26%$7nAKAVkI|5b)72c1y_FiHqBJcPTH^9zRrjxWN6QNK_v zF$+!ylFcSsP4)c 
z^^tVY{L_^TO?sJqKpp^h@b54N4$W_-6%KCD_mP1ne@U|>(VRqVqokN(%s>Yw8|r2B zoA5N{Gc=_wNX+gFx{%9k16+%_O#KFK&efKJj)_85^cTNVe1gtV8_cH6N#|qE$JRO6 zOl+L+x3CUwLxY2Ivp8^&R5b|9PHB(A4R`*P)Y`lDJc@<98c zjdlIdOA4!57EYr&=*t|1Lpj<&J=f71-VWacGs%K%n_zam7}f~|{3N!FL8!G8tLHXw zfV|7hOSC92R&Ntu9%q>osWF;Aqb~OkyFwprB*`17f<`j%;Psw6j*{W$N-ypfze&Ew zFXrDXy+AFZ06o%snfZ;K#xyrqF^#m+;FLSsy-(Od{7Y_St_w`)vVRBRcV(gz;4hSS zU^XkbUfQ1^+04?OYMsMXgio-S=es_Ey{8w4Nn$&Z5Q<5?sZ8OMF;yw=uc%6FPvh|3 zayhx+jWUTzf%BN7kuW&!x=p2KHxrMOl`J-+J@3{8~WY1in~Z0 z4_$|%nut15>&(ZL+h|W^6MeuKR5^c-Q_A^i$PMb*AHjne&zN^`sX9ccmv9%pO)RP3 zm+Bg4$WdBP?VGT}E~8t#pQ#2;+09L3=?5^an1y3D9RYg`LS zxrLn@r4%(E8L)+t%=C_Ujp1~>w8G-!A80##nERS`E__M4EIo@P5xv~yIJ>I{Stqb5 zdoVLwE{qET6`s^i8-KHpje?Gqc&V!^7U3{sAe!eY&cz8G(PZ3K-D!Mm(9>xrcgr)m z@7W%6IH+hYhMVCKBg5)HCcA2B-RYg4`|irRs@mdv!0kxk2Y}+lM=DL+#QXw^@`cpT z%szIqIz+wi$RjSeHiI?Jx4#YdWsCKU26!Gwib%|ObuhQiQ3TLA|Ec-f@ZDwn?Qo2K zm!709(XCf9Ra2MM|Ku1noT?&a#!S-caqZ-xdd1*%ZzVX>R-pM$-s-HgSJ-RnG3i`- zeQ$NH4jm5<;Nf91w`wR;oo_!9-!4pYe(+h?1=oR1kt3iD^_{OPMo+V~)e1F)V)(-swF`6J{;wg~hF|y(m@1g3%ZkVA>|W2;}gi-D0HoaZ|N)+pKifHw6?jC$MK7z&TlpArHr==`@U>pwLzd+YcEhk;wJsOL5Vh zm!Ulq@Ottj(>ZV#4Md%Szfng$EA*0SPfUt`#`D}7@QCx6|AETfA+(a6&r!lkA+C9! 
z0CpB*B|~!y=p!f#7hArkZ5Bhh#oJi%;F9VR@c>nj@5Qw8j)y;$LHcwe8x9N9RO%Za zh4=s_S~p%z@XTS48njJ;+| z;JJ+As+-ignc=4R7(8y(Sv9yeaDlszvLL&-%m<3=7mfG&5x%Z7L$YfFjV}(nagR9N zSXdK-uwT%0qNR*>N4jl|7B?xc@p#zD~aja z0;N^fO*&-M0B1wh%q;h0jF@STB=@l<7MIA^$98!$`(5oWFY%1_71w)1k-3svK#zrw zgGY#$xfdMEwN;KCL7f{17vhJms?>7aw#gf^DrMJl)sSP0Fp)(*7fTZMoF>Kwe3)oU zHG(ybkZ)9|QMh#Yn=uijTgE6G{}X4Lf+tV-6VW&6=Rr{`8%S_R%3{N1&t_TEO+fsTp&y`%F0f?1+Tj1hT^rhOj$6T zK1Cem25_4_yzZy38RO}~YZ1g$H588HOs_SwRSH$S+C%6=V21J$7WXjsgCV7x{)N5CRiozO^S1fgNHA7C zAIOu6daB_T*8Ip&+j_nmm6W;UW-k%h#~xEl23wlLiOOaz%_1nEA5;?B(=Iu|GmQ6> zaj31nhpQP%VU{!dtsK%Bq&rl;Sxmh^PyYcXN|+ZmnvM-U^xQeJ0d@n!nC~c4I_7LI zc&)jq>Bc;7d|odF@=}`OkGTdL-MOFQ(_g%0~Ub5&w8}Nk5a{gBjs1&o5V| zJ5tH}KnGput+=@`)A*@fwUxr{U8nS(u6Q{>eF*e+trb(PTeY}jkUBxx0%s}}wI#X$ zD}s({JE078FiD{@(iAj1GM7DoKlwLEz4)_2rx4BTAYa1O{%rKSo)CD!sQ zSuJ>n>ulLuescbtT1uQ0au&3bcx40m(k0^?^}shR_pqx7D5(kVbmEAaq9ZyDtGJX@ zj2_EdkqVID%vPh#D+ztA&iKKc#jr>?j{50)O_y!jh+&kEaarK$p-ix022Bm1tF#-V9F=l#4Qrpl^r!C;2GU9(Bpj22K!8HQ} zdm&!o4m%9$JzvOqoq5h3Yq8YK2V1Z|*vq{`2okT=J^CC$Pzb$_7z1N61@^M?SWMIt zh_`e{@_{EBo}i-mR-tWyASfSib&>S5#46ZMAMO0zA4`R_zGgD-14C6v#+g@|j#m@R zT4r|?!O>u_F+dz4cQn$iq-``_xqJDyY!%U4A zEL+z|B|Cf#z1L&;iS!-fvU&$Bu{{a+c)PX(Rb@@JF0N@DXTQMOdC61|R{#OmZs$jd z-1+V=uQ0Vz+z?@ zEGTEODOPpZRZ3-kC`qQory5aVK;7UfGO=72-Wh5NMA`+{8}a%--pXJEF_eC;IoQJL zSlisASZ}QUp+B)c^H=7bXOKROT;MMTmWpH1d|;@xqo-N6p=_8zALZUOJHQulm#5>* z+1wM5WDJ$-PzAIO?BN6z{q7y6Z!r>mmBCw(u8{bkc?2TlTy)JCw{f(9+k&;GDT``8qnM3zkx@smr5$%hUujh0`GBK zCO+|ib~Kp8_B7X$|B}_higN~Xl9}#bVN8bgX_40!pHa*+fw>*(kH2IWitLt4=v%b! 
zaWpkr7*99y{ieq#1H79;yR_=!J-CBB6HbED@aL#AY^u5r#5G7(i`c7rKPeWO9hpdf z4G!0BaDoOET}$&9flU%ynJLul=-K?|yfE`AAqh1y4CyVLsu$q<Ck1%iHB46115G6@JVCjT~oQ}hKl;aP=$#6n> zGQzIcx8u?1h>Wv*2Ej7MHB|r}l`__Nuy8!8CTKvP3=D$hBk^oBUlKQv6llpLm~Uzv zEMcxjGyD~pKb_N4uQF>R8T#%>u;DW9Z^;dcqlElF^J=n?{70F|jpq+m9PLiF_>`y0 z57VuVlxM&;8H>S2h7p=NCMuUKw)+c@q{n%FBK_b>?hiWE-IAX~04~Hg7c%pZ=0Wd` zd!~whi}Ysfe4B*vnNy<)JcFBNkii(^8zjgd+)DctstPv=|4vu2z2O@*|BCZ;8~z$| z9Y3WDD5Y=}9HC*;V!5wq%KS_vtNqYya}bWA�{lHZUeVhzVQ?;&QCL3%sp3QP&@S7xcS*i6YTEUZ)~857CNy6(8fACJ5O z$6Ou=jHT+|dBdrGLVeUHbX00+uJa)OX`=)?82#j9h1SGEywuz$T@T%7vZw@cA{R|G zG7ghnHxW=zuYe9PVpbi`O=2S$uWxa+bRWf$Kp|NHxFV|NwONi} z6JBGh!(@6(utcDzUW@L?ujZ>#D{Q@txMs^Nnr=$m!|<&5U9`s5@Kk~Ejy=jK^Q-5m zHqJ8V&PaR0^j7!LulaPzr~A~(`dYQIK*My?ZGPi+8)Ng@5}86%b``r#IL8h{pB;b4 zU62pK-O|j+OMWZg%+pb*W>m0W1x=_pVXd4X>JCY|jf)}P)i2yJ{M_8=$c^8~m3LKP zUeO6~jMr{5_%I)3%v97+uY`ilA5dYgJM$iM_9y&J_*SW7OJyi?3NC82X0E{sK0&PR zxvaM_JFQe%Yt-d%iE9?y#jK&t zXnNNrGfNYmTm9MZQ1s2sJ;pC!A4o-XC;JdvccE|=*Ve03QSy4KzVA0wDSQy$Hdf$& z#4piyFu9MdwX8 z+ITBeM9p%Kdrp|Ah=R5;P*%T+*Wo+cE%?In_f?j0V3tuhloZc3f5TVvY_lgZ`;Sy& z+gcsb7`TCakM=q5%h`A{-fnv>McZ0h_P-PIeP(pR5WY5&@lfVpsGDlsbGYSP_HWrM>?*Ijkqf@=tT!xE4!iY}U!Ze9lfr0tQTGD+$pp z;&Zm@@5Ce=MJ0t=M20ap4%0X8GGL8D-5pc|O0%Q&?M+n?#G6-9L@0Kvf0?^<8=6X@>f$?S*w(5`P!f z@}FU|nLRNj$Tcpzyj$qvAFf}feBm=(LTm*v&^tVTmG~JnCKt&wkpyQOV*=O7Fia7e zStf!;`AcCl5aUTuWl9p(8(p~?peWOgEhV-`?Z{068_o^w-ua&U7Mw-zVG64}GYJ+7 zeq^Jxp@APcQ?1vWAT$%5;3wX0^$9xI?m7?px58NdC_3%wA}yBcgL`CZ!0q{m{%M5N zf;kIj^dJ;_P(d_iz!S1RO7Z0XOj=7%LYrFX(5O zA!tr+lJ20_v$e$A-m^kicT3#GeUO>wZ!XpFjMt(~%2rbGYxTnS@o{v*b;>d>*p=Pj zF@4H8C4YUmRFgzlO0Nwv;j;Kc;4>!TV1vQTJ8quZ6pqn~@XHg*%74bZ(;6_>BfGe> zMpbQVc()@lX0>^ace{_^EI5@&z~h;}=^Q2peK5H^Hy6vN@|8S^*mQl^`Dc6z8w1`r zR6JkVY7C3SXa5uqbLYgf&gSr>`HFzncbY7p7sgYs*u`vRGcJQmxhXI93&FYWM~(VL z&V-ONrW zcw(7W=mt?96bi541^+%gGt^KmsuY7q-SwDCVqs+&dKt#(ID0#+u_Y-XvN&cIE|T8> zmy=G$SJTVsYqj0nR%R>n4tG_$G?{11RGGv#{_D&It~qrTpJl5B*9+^#y;i0k!Y6tc zVG}otxerQO?C?VE5S8diLE93&=z(^aJHw8J#lRPSgZDc#SSV^=h@0_+YzLTu*z$p! 
zcVG`|7JnJ^30@@T2z%=p=x**|^$2VLJNtVxtD*(Np5e(z1eM?rvLnp*4bUWGnMbE> zMk#z-crTrU8|6mCb>h!3VO$5zw1CC>Hx=8$Pr=&eJ9CIxl_&~}2cjeoI2rD^?>j7_ zH`TL=4Mqj6m%>;^-a(dGy(y2aWuP790}xMoQoU{4F^zxHoUD*}P#+N)1@taty4|TCrK`A51Ks3tG6gxK83IYt}zM z;x?}7E6rcU_|yirVxH%YF{fJ)PrutbumMq5P$>p^)H5bfPYvS|rcrM#wDY_UZWV_zq zu{`cLx0LA&Iwd+ipF}00h9R-t_%F&`(AYDct7LJ4j5j@eyRO1{)dk>&^G&e#|KG?= zZT(k0ksHJBvv{oHfi7ADSk*Z0E*!`rs#(-~l=w5TvJq#ctai$CQhO;kz5V#lQmn8Q zC0OZ*;^Mrt04T)9>YcepxsQw+-d5lX+dJIKpo9<{XFMdkNee9ov5D}5+{o<7$_j)7 zx>cPyO(s~(wbsjNT_Oz>s}{9q=-Hy~FC7;zOu-+tGRza9J`Tz(_Y`vZg6furvz-{i zI}m{l`cls<{{?Y~;~kwz)d){i+iFF!&#G^Ujwb652no9r?eT2po5gJ< zFoN}48IsdkMBsqGq0%C+J!;LIWI4-H`<)TAE#XyKJ-K$cwELmiSKEvylFdEuNx``# zzAB&3lsg<3QwqJ!T*lwFPu7~DrS8{49ImS}Af$eltP;dNBx70ds?`H+M!i?lxdH~s z4Z=yT-Y_f5obqHepFRB7W#qT(eOHZV);f`5tP8ZyA#$H`6XOjL41uhyP*U82eHw2bW5l z*aGAfw!Dh?=gfIH_) ziux5DGEW#~rG;9YurV^qcQ#?Rep*lFZjtS@$Jr~@x!N~$ncEaR;YV*8CDWi zwWgqn(n6zvvugHc?>~WG_}k1O@|AZnKLu5UYmKCQr%}f#Yawjt)s(8>p_VI6G0JK` zsh-Le<|e0#FZiQstI*Ee0c3sSnDLp+mY-W0on)jL?-TZ!S5cYVzG@4-TsCXvK25a| zn#630cu+A^DL*T^DKZZ4H*)L^ax3^}>{J~M>@RfwS=y&=gx!DuXRiK|}d2u_{QI>V| z7Md=lXlsr7T&}s?F^_AAfH*~$oKKjh!1CF{`D!PW7wVzg1D12$-T>@y6@hE5-(;3j zMtBU9^Ot2uU`w$M{#d_Za1$LJx~UbyQ)Ngs)<>X)`pLXh@g-YW%Ao2~7mY{hgY^}G zyUcD!-^f$&UM~rj$@6hv@gUU-)}iz8D6RthHL{qQNRE-lS=pe+sF|P|yw+M3(c;)) zY6IBbnIg|}f5L6a>T;r50$jCUgtT}|UBPwV*Hxb>4|cX!`mpZK2|64jFlfLr%SOHaT3!V{6I=p4|xVf8%g{^?iswUIH;D6kV7@| z*|@Az0%K4pOS|W)4_d=m?Yxqv_J(hw$1_7^hz3L&S=PL6dcM(~?#xc6Mrc0ah&462 ztKMUdu)W!4p?XqHc9?URyOw$`vxOfe z&Rvp_Rtc9y zrZu-f?(tw{&hK!8V`Xq&Xap`m^=#^N^){cX^Mo*(9sif9;np}>t>au+V={Z24Ee9J zJ;j0g+eoZmChm%o^n!otT%-TWH5F2b9+q#Tr0{2`BC>1K$)Gi(RE-t^r%p(tgua%g zk+b|*56~umcln9&C449( z!}f(c1_r|_>iN7{Aj;}M#OXZ~H)E4Y;oWK#^@2qZT!WNal##TP=0f7T`AJ)bIHathAAq>fB5dq#5?`~9`%^@n~hvH&v60ZM^1}YE{6Lz;u}Bo8dnTo*2iu`Tj!o zSz?-6Bl3B7DeM1JUf!W}il3q!H&!Tav!B&zxWJDELvazDOzei;f^VZ)Fw>YG+uprV z68V)JQllp z=c4uOPmSlJ^xont_DxQIje{4&VZIVVHRml@UtFVC1oPG2>KIgzbcae|(Y-;btBzx; 
zf-GS=^Q*;E^us-s^=d@)X)o31TpDqW`wGsJwZ-dfGP-Ix$7siMkx~A&O|z(xYHX%a z+xtxXiYiK{#gbUSruUX1^3|J8bv%t1jDd+CkA$;U8a6OiY3~5>c6AURgiDRAyn>EF z!d?EH=LYi^C;@-CAHzTL09xxB#_7x_I*mv(ed=GRqbJq#T}{>mc0GPa6m&fyyTM<> z1+CNjS9-41hrQfbC8AQO zWPY`WtZ7jSTbF$dO6OEy%S0asdxRANBof)O2|LAwT6bnqYA45ra9d%>!QECD;s-IA zYp;*fN5))4or!WrIkkncmaJo$j*f>&a|KZoE;mL-ropqWNy-rr4a#e;WXrmS=P?^4 zx7HFAGWNP?$1^0dE9rva(#9TiI+%sqN7}IQxULDo^hjHO3SZrOS*g#p&6vXEz=>*0 zYwOUGxZA`TP?O1U$5<|q<{<)X=6adMSj*B#UUY3$y5TBJF;^A9z*6Q(X1M<+-sSwL ztY=24rhbXF@s&U_ITk+ypNxv&G|?y!A9|)4LNZ?&CI?p$dmJY+x8{{4%cx^EQ+({VhhCXT7iV$ z1FPYEbxdTSJVMy2oQOY69ts<|8*Aqz3gQ6BEIKeYpe}y z(7OwSaX4<8R7+pZuT*_NcjY%@-S_!0KUT_MhDM^p^y%nTuq>d zd5VX_N9OR%9A$@iFdPsD8`K}{wVPW-EP|}RN!%B{wV%jc z57$^V;|t>zoq5@P(-na${kwMZR?>*ZoryFMw+v4X5 zvxH@2GqsSvxc)%otfbf)rLfp1?>Dtz?yq`3*8$E;q_Vu$p8myq5bOYT)r%C56U3R$ zc=INDj~;Nn_Ax?rc{1sz7lmi|8|oc}(o7$#5}TxdO6Uw|&n9b{mSH5IneYWJ5Z@I3 zM;8PhZ9db*coE%O`9^h*d_kVv&t_@(o}ZzP6JtOf2&EUsETKrV=0;U&nB$?a&XA!JaE=?1)Dq<5d%oxDO*XK#T8g*9RXDs*6fQ^{9%#FP6*6eq*9_CX?QNG5#q=0na zNvqXhM;18>R&4LMd7}=SrXeN0}=RX&_#eL5+wD+J0;lX#sdZHt>*p1E@i*P8&?j*kQQSo!sZTvFysBSv8+yG)p`Vg~4tc_+- zF&?+oOSms|5U)GhhK?i-m;TXx{-@05@F=e&6bgI_+~-NE5}mIPx473Q%1M6&baA#c z7!N2GV^BJ{M^T?0rjP@B0 zz+Zg@^oC|D_z1s_JmXJ=a?HylL%dIXZ$^WQ@RKl(?#aL8&T92wUv(qs9l8Q}MUubBLAUdVBd>x39j#mL}5Q|f(Y7dvfq50^3)QBtDU+B{Go^w~R2 zE}6fCI%Ql%sd!-d?aK4S$$ATCNAjmNy+d%VWhp#mc{K+xZSos?vzVLIGImeoin*68 z22-8WiJ7K{K44ZfihBL9yt9R}O~}yun4MjLCif-4_x4Z6JK+#~gsWjn@GqhgKUbTF z|8v&>Wz|%C)cH4m!-x}?J6@uo?95!ku5lD{9G82-XE0Y;PM)$-iS4Y<+sA)UJM-rg zlVfT$?SZ#C?%=UhEAEQsq4$Spcm}FJ*meHN={2l;o!AUU5aFOm%slHBXj%J5R<>BNu8Hz!=Nqk?c>8L==TPKtD|Q?Qkq|Fc?9r^=9AtmP^USt?t zIb$@@(Q%%C15@mo#$~h#tiWTfH4#54$&tX!(|?2SjTzdGKoPzSe3bG`dt*++oTD<( zRlzZaXr<)I_0!u&TDq3HJA?k#eN;?LvZmpc^e|K2TN6&j6P#(pjl}ZCOJ+qNn_U?k zMJ!e=zY*V^K;%XL4A9lPPBh(BvD;A$jO3rHvpvVD2kaoNn8n1jpvtQs5NB)rv5m=JL-j4w-!yl7XMN#_*i`ClrBY{FzUMDuF=7oZI;Ux6z_HL8 z^*28do%9XinWF~t0rplO>AXEwJ;`oV-&kB&6KjK2K-pq+XWpu@s1UbKi>SM}()v`_ 
z3xj710Vh9$abatUOxbvm6;uM%#b4R2O{dy>DP3|Wz|D%sCAvP7|3*^j%cx(KCEzK^%N@R#r$b|!YQEu}-$+x*eYW6(dW_*Zy(NlwqW@I|I*^W$O#QD!?Iv35-% zf{q0q0kSs5XOtJqx~uY!@=ah{vs5%k>1sqdMOt(O^jM)b_&_A6b@6e%5?;V0 zdwrmgPUPIym#U|Y9nsJAVdk>xmUfqK5xMDU5-@sz%=Xaxs>^hd26uIC;14@qKYyla1<s~I5;;qctJ>13WGNtf!#m_>Y z?1Rn?*x@P?d0@n=&-sI)X0Q_K8Jul#W}gCgIa-_|73D5TQQ^fAl0PR*&2;hG!k6)C zV+3DNKSfkj9C?R8Q!6h!iK(uIU9G7$dbC2>5~VK8dMg<`TDmU9=8p{Qg+E;<+;6nD z;07ASH*h0Y8Tnye4(oGF7xzp5xUM=rrdJTENY^ccK}_I%HJB9nhnMI#Hj1wn=tAy91SGvLxkU4_=LkEP zZ{;rx6NHQ42OKFhu=0%`GItx^X4CYRYN|WMc*uvXoK}F`9;}vMk*k4ku${&0u}R@1 zYy}>T)K~>e_f8q>sZGyjn&WGzlu}weBNSAQB5T(k)XnkE$zXNdGJcz0z?IDm zM!sktyI1a--!gllV=Md#lRzCKI`B7jMeS%z!Q0Vf^bZlA{wVhp^F-gR0+$icq{&u# zb_Vee+{2BF$?&yRuXzqQ`jW}n4m8l4$fs#4{!2(Z@{c%EJE^up{d|8j&b;KryF!EC zE(WYpUGK>ph?F%48l4h~v4z>ctgL2L{E}sixI7TyQ4E`@mB?FS_OY1V#mpk5hWARE zlP=G%Wd`#_AUrVBTnn!mYV;C#6P35D;3)Eigr@s)c$Q9wB1`pV}blADi z*v!Pa>MQ5OPijJVEwdEQxA=xR*5ak7?whjZ$)Vyf*qhuPeKt0o{n#{{ zNih3qh2**T5Sk&l&?zvW6$A8-vvl4oT0V2MyRSY!z$8`^HVB@?Cj9D|J} z;ji-a(A#XO>3Q}5bC;WHZM%wP26(#hkF?%kLHuIbrp8z&!!)=BH-&@Pw@&Kt05~-A z08B!$frZ|_!b!3uS)WK#?lXH)$g=28^j9%%F^1&fBEeks8c`24cdd@UBivAOaNF3c zbhcJ8=1(--Tid6H?DQ+SBiRWhE3X_~Jgup>#4LIf@r)ZuHUhhdJ?NJEC3b;*9*3u< zN;Cb?3{YE74;J7rXdYj-5te$%+4MYpBOVg_f5K^!dfET4_Y-HJXuFAK1gbHo^lW~& zFd57sTUnMM9erb(J6h=@o#nOTxUqH{7uKh+Q7i@Zto>Th+I%|Iy&wE=Wr(xg_u+Hx zFTJB=i%(FlsG0%2o8V`0Ua&Arnq$y2)QXy+X0a{Q<@`l;3*S}U=z0a4=!3&vJ(vA~ zK03ybKaam4bYUgBg=Hn6rZizCRY z>T|+kxIx{xaeg0OVC-k6ay`i>2?J#y#&KN;%LR;g@xO9SvW8jC61xGdd<_cKRPv-& zt^vOOps_KC?yLM4KQz(+b%oc3a~2^T&9}rqf@`eIc$8Qf*Wijf_Xf&ob~Z-8fJ;D8 zsM2Jtr!m|rR}0d(8#g+$7Rc5KVpUdAHYrObPuM;fgZb+g|B}w#uy>KVLL&Puxr$n> zH(-#j0yvXsr*GJo=sT>b$2YkXH#jdJJg`sn{!{|!rn#H#Y3m%@SYAh-RUQQirJ+l- zzUhCpc~o!mu;cICY-%4bP;GN`*L7)t-VsIT)s{M|cgXE|1*8&aqP_>03tndD>OGZ( zu{RIXN+kG`-E0JMj&chPfpUh*P@~xm+$sLDP*#4APv#b(`v*xQN-Kl^uwO#$`C8_L zCcSKB1V9;rramziOOMHytfm#_)*6S||M-snL^TOaXRiiMi{RZ+qxW41#2@@bwG78t)~0f3jNpsf7O}k3=$$jh zs-w0+MQ$WOrnQHR{Nb4FsNf%_M%0edn!1}@owEv}hmy~tow5!!3suuH#Whq$s-otH 
zx5YD{7+4hS!bPKU+#IVbG>7xiiqx81OVQ^dgBD$SlWeIhy%fJ%FBpE23%^ zjOXJkX-1$VN9iu}Dzj91&i0L@G9%20r$dHMtuA#`{jjU)w29P+2m_yKZO}0?3yfeU z=1z8YM#0>#WE)&4fk@lPKzTC%gmmEI!e)nQz1#OEg!#G3GP9-iY7cSCGhC`~zY-kG z3^Hi-K%}~rfrnOhg|`f=}r`J9lvG7S{fWTPu`3~l)U`=I!`Ugm zLGTiHQ@4k^d!JH;$QpIrs1jR2=xo^llWA5NNJdkD%jG`GlL;?<$*R@e!nLe9hgS26 zY<0`pH^^|Q!{HTmsxng#nwxw%@>Yur9jNYPw|dvZ&HQcj*g0D4WaT}Pu}^>I8*l#F z^f4%Au@u|1N8(XW5nL&MKF^BF@N0X`xVhv#@RhozYT`6RvR?k2wLSV4T56Uc3A&&z zX%BG=<%PBF@V7c9m5$FijEtBPYnctp)h?azB2pLti}u@cG@}t z7rw3^Q@_U#wFqeri;jxwALK?gQ(q8BZng@XhAIBK5l&tTV;x=iB5K1xFZx#V0_J%A z2l1HmlI57@(j4x-THQ*^SBqNgxUNnhK65+ax&Kjg7VJ?gZ5YO#;toYib+;?=Oppinp!B>sDyJILIc>GNmwMs-r48UC zVG~`i>yby~yw_(DJK-&>rok(F$brO4ED{{lesBZKF8T(~S*lIA8-pCLOf~BSxm*p2 zlR!Z;FWE_o;C^e5ng0AX!Xev~7(Ocv2mgX{dJ16~_klsA9 z9FsfSEhw3N8otdn8R`#lHNg;k z?yfrUDwQA3B0Esat-<7!>}md8)^Kw-U5PBD8cy|l&rrryIk zESoeKrdhL@`|1ZR=-cXT!L8;0@jc@jYMa7mxPQ{eYDeKPoY;N}vfg`oymBwtoL=Q9 zr491T^0wnnLPVFTa8b2cjm}GdlBR)MJY_Cao#>nS!k7tWz`vjsobQX*3&XCqWjf$y z!NEBFe=`+9V|%~!8vBRN$Ku*!VZCo9e8$Ey-PBlhlU0^4Fh!EB02 z^ewKi-iCVViy)X}!U>WpDq*KgtGA6)bc}L< zE<^`>Y3OwE)>m2`sUf>HxSUus1~ux01V&YB>)rW2;2*G1>_c`nXMwf)e0W>S1Zxis zV@NK6i^Q|sJ=q%YQ6NprQ~N8(R0k%!HhO?9L zsaKQ(rt_gB(mZLMgJY}>{E%D{O^AHPV(^mx9$3H~1CNBW)GN8MHp?hyxcz9uB#dS& zd5@y))_<|7`*t`E7&V2(_I5Se{}en5=V6+#Y1BeBMNcq_I!Ay)+-GwI+z;Lv5!Nbw zv>I!q@RN8;Kf#PLaObGl_lRovFa3{QSLPU=yQ;mEW?0R@OC&DO&T9BpnQ`uGElbGVxUH>d9jSI0~;1RPkHHbLJEP^*G zUk0DqM+sG}-D1P^IxyL{9b_>j{I~3i%0%S`8Iu&96~VvNA=nNx#D{2Fu!DI|Jx~h^ ztb7DCBOij_pa<84XMA;GNF7%rRl9`uut%U==yZN`@)Bk$2&>D%uF91He|t9RCGpo{ zi=?w^W+I8Tg>AF1s;js^nSZ0S+XXR^?#qz4U_{{nMq?6D9`t}S!jBagp2M%z0_H+| z9N6UQ$wZpJvu4uMKm)B8KS}(AALgU1v6{ykF77nnLCb4{QBqyyG5*@z18#v-V+dXF zAP);MPvI~3N>B5w-~O%=ZC`N~S`k5`4IKA6Coi&Q&h+Ql6Ri@W%go0ruAXFR<#*^| zsNwGX;;X&=po0wJN0D1%K2_BGCA5TNwO4dIdp1hJ%n#koEX7U^HHCX~SCgx>N>qZ{ zG3c=Gz#Zs@nUC6Knwf1J6n{qvQp1#HT0M3p*g!V&o<-TU&G?*zTJ#DbS6`L2)8yFS zLUa9C`fg^IKT0f%ICt+4#MAkx?r19?PYu%NEZLY*#ru^Eu@jt_HSIuW?wGJGU=V)h 
zH#p8!R=LYHg+;i=o=|8Rdpaz#HPr8j;ygVeAG5~+U?b}+g5B-W(|IR2;hD`g;}47d z+(e~_|`OfZ?PDn?Kf+0NjUaYHwXu#`Axus$ zBK@abPW@iN&$|wQ3m?)KWwhw~YN26QH}WN9kH_3V$KqL0*)h#?-(k zg+}bEMm`FT8Y|lHIMmJPNDuHqRa;7IL~1_gpyZH43q%}k@6aTi}4Pt$ws(h zVIhgnUCfK@bmJCWP1RC4bspu$KD)bu#nMc4OKK|bW+$uwx_?YO)!f#R7Xbue#JcD;KC*Voz8 zDp-erlXlKfx@A7qTgvCu=3*;XO>;6t+gH#+zs7dvA7>B2D|r{N9YGJi6LXu%w%+Ou zXhI)@&5$0mg`udfFdv)KK#A~qW)*xxjDUHaSGC>rKdzy%&!{2w6!14X>GoA$al>g7 z-cQbTi>}#ZkjV&m>(iQ(`nJ8E&RDhZAMj7&b?h{YX2||awCtcAv$zytRMw85cxb$L zhuId~RXY2hfIXIvje;%E+h(PHz?DbRtVz-js=7$n+w}LzyCo{?u@-Pq#APc5`{s@3 zX25G~HGHhIb2wJHA$Bzgd?3nXWvdT_X~ML0nv1pXGqtrH#!Y1gwTbjI`@M{@l%2<3 zp;z&RjjiyqsiQqSqB1hQDF?W$Z{*zO-|5rMMtX|O$+7WO%%AEHqg-$=dCmWV+J;Zw zy_v4BFSqIwt<*3br~1^UWJk89+MX+jc=)gR2EN~tM_Z?w)G_und35g#F$O#U&uynM z3}vS)`#8{5P>z5 zuBe}lzsz1{l9>&Y>g4B3=!@n4%?SuEH}St= zYrYF879xqdRHXS{+@;kBo`4UzzwPJIarja>N2wxQ1@ripW`|52dqmGhMmx2Z?Vc=i z>Kh_Cw2taRDc;oy8`K+&p>7d|!3M$oC~tU8IxPF-wpyI~U%V7oqu!gWwb>hm6;fMf z0o+EvNhn85Bc*Utyc;vcUl!kQo~AihQP4hNG~Ai+Q`@FogKd39pyuk$p`;?}@=R5< zea#Ii+!MPMP-t{#u`Us3ZX8sSWZ^*SWBrecTa5=+Mu6{Bj9^e06NxG--1}v z*ZeKvBlSS2sBaIlOnhRK-5cl{<^;s)c*_3PdZ09*SEbb2NekKjL|kAMZ+Oaqnc_ma zA#)LaMZbtFrl2^3kCG#d4GFUGgwerJtDL_dQNtU`jAq}HQP9Ar#GRvl%9X$c>6`Zl z*&eir%S*irK1MzKvaT)qzt(2UgFfAKsV$g|N$dpbqPj6N6Gkva;kKYboUreD`)K9O zX1@HCLoEf5TB+)7YOL9-{&v2QzS-SGn}xsD9w;rD0Z`C;^Y_Dp;6G^tP%iTUTTX9}%RM5U+mcjO@bKzor23%}>3sj7qBBhGczz%Fg8by~P8*!b? 
zzX4DNYg_b#d{s0TE-63EDTo>hE%%lRxGyiD{O+&*do_}BP`qqkNEj`w^4v1TdS&T>N<$OoXME@K!a{YC2Yw{&BR=#3p zg`QCxIRQUPt&~U5Q~91ChHk;o^hzL^*Q{f(3$t1ka*sGuGvQ*mOL#H(gLpklFg|%0IB)&ZCsJjEuFP)gdrV!bMr>@-S8XESKQv!>@46q= zoBfql8Z^iIsrkVneT*FCX|CB^EXp9?00e>Q*RaUUck1{(FMQhqMGobQjlv>{-=Qbw zIc19ak-l=+#cVRCfXT=|FdQ-T3G*S^zZat((3$vFcXa&mP)AtC804R*#tS17TPWk~ zhH;5r7e3D#>UocHxSIM9OxK4}_td_L&qzj%0bjL$=-I}Lz+jh5Re&SJ64ZbE6K*^H zL|Ug!llH}&47En&t+ntyQ-)hB{nYp9r{Pn!Hl1iBxGRf)W%tF$>kAnLxo#Iwzp4D( zu-v-FBJ~_|&ivtQ%Z=eC>uvOH&K_3066G0&%uc0YX*93t!xHdi=KpuYKeg1{XJ#Dc z6jGU3+ok`Vo}@P7F9J{Kv$2sKVn6llrXxKt^EP)z`U?(rZ{RD_d$d|iZ2A>tjPD8Q zWb@IKcOOdJY>lae*t)XbEi^g2Oa?jDzgyTCQ z=U5I(wsw?KvFA!Sah6I3&Z=M7)P1c3ROV^z1ksWcP}r)f_O2ZKHluXYf~z zPWoO`N{YzH7oKZv6TLuX3owJZU*wf=62B7>yS^Zs%VMI0d53Ewo~8=HLJ%<5s6*^f zYAoIgpJQ}GJd07GqAbT2$4jQ~;f^S+;x~aS$fGr$e24DyZG^(g3~{H{AoIQ(G3eoK zs*GOFSOunNGAu(cw}G?@R@VYxhV(u#oo&eVOjyT_3a7z8U~j$J*i%qByg0VL@0bX+ zs#fj5U(yZzoc7VoD=kdO1vRXa{sP!TPl!7hY^>+YJPht&3343K4HFO_E=c_6y0IBk zKKjovi+Bp`1%chIiC@T_UYs$70E&{fvj)+9#Zp))IF%|+zodS+R%09E3aGU#7ENA; z2$(h&rlM(n5qO)eL-s;O%kS`S=L`DGhSgWa>=7Ox^S{Hqv zFyB>0uBm$Yomq3mN9J0(9IRuvq^=ru6BcPF9qZ+$@`nTGGak4eW0PRe{LPoOr&2c$ zR&hpp7N?bv*P(vX2_;fx=qAij^kxpAC-+ZK+U_q?SsBV{mh`VwP~6oD^JBU0@R5UzPGGg> z(~J2WSYNZTdKlK=X6h)LWo`}Lb{B>#*+@^ta1Jpzlx8%ozufmYrx{g1>q@>1rQtcA z;_O>$lKs;7Oy@zp^Cr7AYJpl9#4BH;M(aK7Y z-9_l}dNtI>bxMWQr3oB1Tx&}n41c30=#BYeYBz3_(vAFxOt>rIe{$NW zN{-cA8y&eh$&>kcT5%@pfQ|J80s0Ez#gF160EJQu3#kbGyXTIvl=*5_M*pSO+(UaE zpF%Zpwr68u&Cm|d{cx&znkz3Hzz15#i0@>R9b>&I!VI{8eaCg??;G1uBX2<5c#zjh zb0^RQwg@Xr&GifN6J$4>$E78#rEa!D>93rpPi<>OpwmN$&GxfoD*GATqVd@E09qPHG7GBFv zL=B55y|TUx{jV%TyQMNY7-#T$LU5w98IGw$KY-WgPy0cR43+@N@ zfkbi(UpxEOzVE?eNtuEpI1gazFz0MZmG3eU(q`-vm}iUrZqj~FAMXMU0~UBfRdgTp zFSRODLBYve%1#iaT=DbjPNRXrP`E2gT~5!JmzY|1ZE+1-gc${s;Shg5d>+=u-3eVX zOKKm;vs`I&hw+_@vF5>w8qN#o*NLQthye*nhIR$c#&_k6MhU=%L|cL^6F1 zfmVhf8X9F!`;U`&{s&vr7# z?nv+)McLs?8*&_$gVk6Pc%rRsfZgg(%JOKXM4QaaXIJV?z6*g`qB zn`fDv*EfRhE-WJdfK3TCjhgmUeQ%b{U11Azg}@T+6FUNSRx=R6bc4Hq)!oHl18^g> 
z&AG^aCl`QLxQ%;cdX{!reyhKvt8in8Z`gi#TklJyFkjT4IoDktv1axF$8P0mrX&15 z_My7f*UD_9<#TuNbTUW6NWOv1)DD5;@IIbjKAC;nQ318*C$U%X2GmUHo;6aPO;hR$ zY_|#Vjplj~t6pQg^zX#~h*Q?=sy($XvCrrmba8o%D;P|X2l!V|d+1mGAFh4zD?uUn z2Ig=f_W_A!ckm6k>*_A~khd1iQICSXK_5!p{T}UbE6YO@o13-e8=>N9H^9)0-efgT zL;V9{;5l8UP>aKjLXhL@Fk8bB_G-sCa}l+J4G<5k@0n>@89k5DH*{9N3b%t#W}jed zDatwxe`~B&((VhNDC_la>`VSS*D2?RZ*gK?`l-5%?c{44w6p_MH|(g^S&MYv^gW^X zs=48+cnX{CUCIEvE0@o_1gg`C2|w&+MkV7_@Q&e8%IVi&4KRxT#OJ~PwGSlJQOl6y zjM3%+Y(c^b?p3Z!dY^4F*BF75GIp8glJZj*6f*k-Q#9P$rrkHpN$hB_&)i81pcDG` z$7SUuz6m0v=+e2h)2bLHbKaqBaohif)oc7ILS0_B-Eiwogbrl|Fh zYUH25#cNLx>1(yPgL$Agr*6@ob0X=cO1Hpx{VYmIP@OcVF`S-~!>AYK43KPhRpRAiYz5d$I6*96 zFG*{}q&SCG$v@DQqLyRZP-lZvKq4U3Gtqrf0QRrYV~LMzt8^#F#T^4Tl^sq-#f?aH zzO@f3<}WkkSz~K@6e27J;f|g8xGblDPCtHUOXXH}GPW2Y2}*E*&=7b|9mn(%*5|xo z-+~*|BnN4fmbWv-h!+ISkw$0FacKlJl{VG_PpQP%?6bj%>Jlx<*l!g8%8sl4*@5|T zKW;r^C>7Ny{^Hb5JV#wDwXa{)ELwl4`4AKqQt-x}?XgXiGU60!sjp*99kMFdCut>0 z@n5h;!^vuvSe-iVsFAaqn-BY&d#EU?JRg@cML6$j&UBH0K2K}IEax99E$r7=QQhfT zh9;=5fWTz?d!cF3B=1>z-XE9K>{fxY|o zlC=?0;;Z_vJPwS&?o`$Svhx@JzV-&{j($&GE#KizsZ{Q`@PhCWyp}nK8p8bNiZw86 zpZ16>?RzLyuwS6NK~XpKRZaY%{z@vy-;y??#OE`;92SFnInic~C{P^B;Whh~uavqQ zHbhisKrE?V7d5v{s7iP4;+3T@WcAdC+<0r`2J)EWL0

    RSE1~xURAGpJoH#Ju~{a^k!;t^a% z9+u%wmE<@EFUYOXGsW&wBUiM!!mlCU$3WHx??l<@5-ThIs#~e!$Xf9Wl;f7cRKy6_ zj^F2dQir%Y_(~y@*0#T;7fN&Diim&1XTCa`J3K59oAp&`2}LVQL8Mh{IcOzP^mk(j z*AQHl�@zd8m8LK$pVJqkFQPZGuSdb)cv-hIvay3;m5L2r+e*D2g}F^2Aqxm9}@$ z8q+O&@lr#p7uA=jA>JlV+l$O5NsWlR!g_CSbubq6+3HNiNn}$kg#SQO`5k>PZoR#6 z*I_lC?as>bV575^YR2+O)}Z(;Xp?@HCF|6&XK*$B9%de*?{%P8B@`yByI=ax0XC~R zm_j5A8|m8AW8$ekmdQhB{(iCm+z>TQ}QAz2WrmypTKljo$Hm6ZO6!E^x@J$ zYPwxfnw(P|^Wzlz3!HGik!q-vy~@6j)>nOoXXWz_UL2s{?v$DWvtR4sw zg)}Ew$G6wt)4c;KWcS<-_`>u)&I_E(Uo=m92W-0q=+5;|Y)`O2rr*Qm43ip=_hF^|>9Pav!zUOn_deka zi<|LU6}@+$pglH}?brqWp;~zQbo zJNjbcms8Kf)nP9p3`6cu`Z9HmYc^lZD#TsFUrANCW6oXvg;js^o9bU~`D zxdyx!-|wGc{>*-^#}MA6RX&flO@E>~6K3LXJ-zwunSY60^rs{aX*eF9us?=xI@0yV z^w`AnS|N5XotLBFMU(^j2Q=b(8d01r-&fb5*83`DFWWWq4`TH%RXHefWr0I)w6_VI zE6!36xe_@RF7}o;r|w>?CW-KTXa0PfsTTA*bDM-YPCC!?~_IWU3a19 z%Utu3el=KtJ%d+HD&^`1lX3!ZrSySo$b5wDeBZc*EXNdKhQ<7s^_i;4tNw5BHu5J) zN_VghWzAPnrMWg_FY^QVNR8)@nsp>Ip+g{5ylTiw9qE+i&Qd*d)Dh$wvVx)nFKI34 z%jQdS32HmG$T==Bp%MICA`f7^5-n)<}-Ldc9)H=#XO zgn9$kue-?qp{wJKTtQ)~P6u+}B&(~oM!ak03BL?1%r39a@>Q`%x~9|Zkk&py>Fasu zYKVwV5y%sJKnfX?g~7x-h9pYq=tQTj<^}dnQ6meCpRXMhSFPspQ0YSjdX{+pUX%+E=7 zG^nzJ`W|Vk=lCT;ZSVoSOnC0S?F*PWOgAhARyM9OMZmx80_$N`w67a#Zq>Bvr_Dvp z=-TiHzT7CLjy2oWKZ9*B+G#JD#s|BZ#qEh9ue%bk@Cb!dn^Ep{AZ<)yQEaHpWgsW^9p-x=~^ z$=GqG7r7>qU466ZYjFYXo;?_@Q!nJ4(-YHJW|LYz=8gXd+g8n`#;aeY?#3skkFn9H zCw8}M1}{n&eFG2DdwjhudsjK{Tae+h&1?L2W?0e(v4K3nUy68drWmzADJCx;O%Ebd z%|xoIc1OP?pCun)b74S^@JX>UBG~TG19qf&AMDNas9T~~E-lb6v>ojR8W?klyE=jK zdfPxXB>bya;VPGwlqw(5s+wob!}K2_g3V>-qE1C^^=m?gc|!ed+(eQ3UV2B^SBM0o zuu7mBJqJ|fyoeRj1#S95;xSs+;?Z-}We>FOgsQ=6-tSZ)PgOf2rWWcf*YGtBH=vFP z9r^CeVRKk!Yoa;w1Q1jyL?;R`+rwL+-yaXY*snt+;q>qWb(UI|Z>1fjX0yN9?`$DN z&=?DCeSns#^#%g)z-zD!Ii8B)*Wo)uL+IRKE%jyQYCNw!D5OD7$BplL?a+0(N8(o_ zB6~@V_U;P7^2T8cd+5gJLfY%ASduQp-I|at1Wv-%E(v@*0b?R z)yePdE%~6?TRp|Sq6dP{xn02|?=N_jY%JAc7luZISM_g7+t4q$CpR5?Nv?)9DT;{+ z=MPo||1sm4s;EBu16Fb5BI;{*PhqoI@=m>%)RAZDn)G{dy#B&tihtWR@g9c6j4;a@ 
zoAi2ylK9o%L+D52YQaP*vzDPMyQylvrrb!FzzX29b&B6gH-*bHCnnwjmUhtmj_>B( znJc|J;4Acb+i*kU->d&v9rfwC^{7Ec3}|N8(LSl9`j6T^y(#(3SPvJ9`Mm$_zakBV zlkk4^`$$#jy5L{+kaB~^$h*#Q%ww)QHJn*1cG-V5yQy``9!2Ch2X1Og?Fy};mV(#Z zT6(&*RNIUC+po-1>LlNDvx9I%9Rc$r5=x}~Aa^nIU3KfrO(VP)e$@EbD)cw-7!ltrFju%gy@X!_0T)qkf2(jo)Gm#PRwc45`1mz8oC)Lv42tHFc93`4Q-`u2)N1}i+yZ$Ca^mGd38Sho4|KSnGM$67?P}S>l%sl) z1jr|oVfB_LmZ?tVwdkzx%5FM;nAtnsc{Q9j{;8E;C9P%5z|bQ8hH(XSAxdyJ;5+N7 z&xgOOKUnSQzZ?i*#poXHUuF}j@xj6BXl|Ug*{bC45MIPpA}nQ>crmmJ^r0KTUH;0* zH-T7Icm$P$uUCiD?O@N?u86y;Ny8a$z-P$X!FAo3S=jLOsc2gr2S>~6V?%q_KD^bs>3 z_-S6wX|MciXw*FGm(q*>O2mZPc*cc`8pH6BD1B4U+LF1-1fYvOrA zi_|-0QPV;^&bXvh?H}=mrzKg)Q3pcY2}`FQETM8lwHpJOAqd!?K%-g zZ&U9Ec9myRc2G(aI@6X~otP56NqND24zx5Uk;Skq5F?HR*L=mmdTY9w3RLbD9AjeO zm|!R`v;zEY?jOS|l#!oXpU8@49Q;e$Dz`}LE3J#`kIb*2h0H_H+&&Z@4Q`1CZBaP^Q@!L)dc!JMgij4FaOZ^L?Bevd>^Iglcquzsy;*;t*igBF_$3vfg{`Bj z1!r?V+!fhmaVVLWJ+=)66))% zsiR^My*jNW7PY(LTZmIkFETb111pAqnuWE0y<2nlk(-q?qO;`4`~wD>-=#jzIi7-M z0j#`;ThGNrI2Lc`IT4oPCNnLt*=%QHLC^y~1^yGyd8)#r}9|&S5y&Y1T$P;MZWa+K&RVU8UGPG z)PG@vurKJ5kg1r?;Lh`R~9SChHou=>H@OcA9x-InVFZc6)b5hvU~L#?7) zyAz#$u_tm!pAqt^5%wjf4D6xxjQ*V6&RFa$%Lw!v?;3Jeur_E(Gy~(z&A1=54^0tm zhl84+eMG!I)u`_~n3Ye^Grkc#RRp%8TkM@hpJu0U^Mu_{1P5S$t1*=%q?#oT=R?eg zNb>=m1DeGw%$y8!5QU^X8LeL2&pY?A)3hSoMc4T7hS;k+?xAzWCh;q>*f)3FBksc( zB{|TEeFb{ftEpGTf5G!sQ?Y~4k;p5SP>a)>kiC2X0|En>8?O3hbM9<%4LZ>Zkv^?y zsIYr?7E0_E38d=9gjpF+|n$qtKNAR?_o&)WV z=|ix+*<1H!HPriorj#=|D&5UYw4hm#EghQc$&m<9(R|+)NN9r9xs(C55{LWVTXKYR4AN;*b)1NWt z)$^fj`97?jl;Bvw@9_T8p2+yIqs;$)X z;9H%$+E{!kvZ-!!)iZwu8j4ydiisB6u?vVaSh8kYs+TA3Py@QOn9dx;KF|aGdD)%v zE&p-vNaY~+7#UwCSt)2^_BlYra=w<@7nD)~xrNa5IWM+YKV$cIcQGa#orPxBMm9e> z7l#l%rbbqVI$H{;Cz-|FRE=raepKe7$= zXtN}H2mNLiO6${FNtM97@O-@=|IXMOswn2*(d?x_5BSYc;V0snY@6-OADU=AwlFS6 z-+@l@re z`?q+ll)2OwqtXe7E!)?sn{>tn8<*6Ewm61FLWq ztwm~Cxe>EiD=+kgcatjxo2ov&2sckl=X*lV*VanGkGYFCc+X$p?)mn~g~1Z)e&`wb z+%e0r35V4Rr-YJfl&7P&n|KDRrtjxIaa%o01C_ZbS1tEj?V`Jkk7cWI5u9z*@D9~} zi5ad6_4E1GvE!rW!Z>YU06TQo_{v_!>T4I+?p9Ug0Lh>S8WjK0xT_t@OmiQ?1#X0f 
zTc4@2%4hFrv9Dg1s_rNeBuy$b7hiz*Pg}&R)Q9j{BMW;bWrgoz`@(hf0rV)QrBI!! zWYx=_ZHx|>Mg{t zc7;A?7J*y!!*HXgd7WszGJRfHWE7C<2t}x4_~zKgu&qveDyWU+IK&qynwcLw#NOab zJlkP0W(i-;TbtSKvZ$*lsZ^e;B<~UzWIg9j8(HERwS`r}|1LHmHwBEb7UL7t#qtYk zuM$KVc@7>m*K-5uk%x9MRfJb|agn!nO7peV!T?weZISL_r?|Ribzy$shOdn#m>02Z zvT0HU44Gx&Xg{R$#fO7e*e>!}ZkFw)3h8zH|FIXf!`RAj4=#n~l*Mt;$kW@Jo1vXb zxD-gYZ7px^VXZ#{xr=E2mEk#@SzM`!#Uy&1FU14+Q|Xy>TpPsQ&d#N48pvrCdZC~3 zcSEg_Ca#C{7$#35uNBcxgQc$F{tId$qb0qbeQZR6dB$DG8}135l~kE)k}<&SEsX=! z;VhtJ1=J;S1$qfLGpmZLBbuszCZ-xMuy)Qn@QFFfdI>AgldZiwXJU`Y2|)8Vvep|y z*oP(hr(!3}F~awNuKlZyWm>`v?SOYTw%Q)dTHub2#BsZIic;FN=z1zlZYYx`ZB5OLscc8YPs+ z;m4`#=2h61pU-^M7SqR=HtaIKTxcN7OD3sp!}-;US|SRJo=E7Z7BPw>3HI#lUpehX zJ$JqQocTwdL0s`YqMbwvH%j|T&g5q%973Ig+;B7BVQPep(ffUw-W~idlp!}m6Y2eM ziqXZ_(DO288QTs$i}TpGU3-WaW(t7^Ix83Az8m?(9^z+@$&GQNDHcjcFW{bFH2rTz zj_Z7?tR#D?nCaSU7!$w_K;#3TSoL>wC`r5)wkamTr7OgUJEL48oVEk2K&0Y|gb z?17FjXr?a;7PDf>D^y;rT5zozGJNp2zQ;PAPzKxur@ZsPY4^lX4t`AEkL*6fP#$h7 zd)+Q%hLnQz6aN;Rr_X0U_T)ituvDcnKGYY`RD7Sgizq8EagDLg67`tI)*y3o<&Oyk zcOCZC4MpgGTx0x;oGvgO<|wta@@V>U!CcG@h0keDZQ(BwE~;KO*N_@Elr5+$AS5*n zEH*eUhs^70OvS28@iYGUY2RYoCVMmrQ6@XVS85&iL&Xzphn@hNLJj1;fp1o6Fo`Qi zeoeS;MyhS#bk8AENV=Ih{oq^feNs32j})t&CYz`&Kmn9Pc!XUd6o?y5&^xOMOSn_C zP5MEpkE@Ci;e3|Uk3Ubv=1yn-Ie1R(q&D*T(uDO$H zrKZrgxZa#CFo|`{^o}ss*2KzcA9x>@1C)S|WV6&AL zih~UYMC_HjauqDR5maQ`*?Ea;#t(6;_FCvIHs&a$jQ5gO4*X)8t3^3ZSrC2#W}~^% z$@IBqS-oLSt}$Kz8UH}LBkVEi@dA=$foT-4ab^AZX8>#%X1G9n!erxFl52c9|QH?>HYE zW32biCqvqn*g;$k?qv9Qo%^7cIgu-4UqJcu&WOc&6On#dhd_khpFSpR$s zodNiZ(dX;S{5sc^$OL(re%Xr6st!*%zw)iQ`r6V^as7e$P`pJiWdFcx*hcmNDp&yf zsWC$RK>U+A3>GzhnAOY{u$@?pE@0e2JK1USs({a$t1QC&u0zHK#;L87i%~;Bac(YM z+?ZyL;Ty;^sT=Mk+6M3?Q`T!Bl5Q`qw4lT7#0zv^`=+IX!lW0Yj6IA=U!t2)zx8!^ zRnU`IZ#$9-8~g3`PzkzY3pq+?L#2D+)98*}SPh2@(&sQIBFQxrhrs3LbmpR316=Sd z2~)um@+NF;xc$MUUfC0J9Z1K)YaHsy8f7FO6Sr?Q-{wyCJRhGCtC$ zaFe72d>UC1K5>82M(r(YekB7SOL^dD^zv-bcbsuk)j(^+3<-+_e}}6Eqo9i(?XAK^ zdMbltEX&)1smAksOnfWxfO~;f+q{A%0{KBYo55x6{H*TTzXTw160uW_^i8J*Go9F! 
z@P4&@u(~)b`!9BY?{D}EZTd{&IL#`HnFm>=;3j)>pll#oY?VgR)pN^>VSj)9z8Ycv zHl}$ZvohG|m}=TgeU4nlGuuV!GpIs*yCxdCp1aZuW5&8t9Er)= zri6&F$rjUV*p@xh_r{>u18NAd{`$cYH1wshbBx8zV7Qwj*wa`GM}&TcQP|JgtW0nA zLv7K43I}GeZGAsMXT5+l&3`ZEd(ceSgf`P2E!QbRxO8NRG z7_Qx-n)8vKzBzega~Y2{1Q8tO`V6*lTwajJ{3JCnnj}2~-K`B=X|1=ZQmd)oLKpLd z7~}$Kgf+y!*Hd$!CffL7vbedHdZ%ya0`_aRD~ssEN@KOK_dey%EKRSa#|s(SS#F!Q zL2C_$$nC|L7}{2>qRe|_XtfG{*33{WpXXp<;ErZxO^m=(!=KRP=fKv zKg~Zs{;Cm0uQbYn^K^6AEhmFvd{vlM{;g^{^ux03B)FK_VjR!Jh)4c5S_1nj@K$aG zx}j$IWTKee#azct5wGd*qMDN;-qS>xN-HJioG~oeom~ZTh!4Jh!+)4^AewwnZy+MD zJZQ(p>kd#|y$Rx&Z)AV{o9X3;>31m0wJ0l>SkGnpx+P{Sv&3!IR%*KNNLglfmv@o# zq`|N&;S07BGZIU}5x!)Up6usXPL;qHn(6M%!WQk0+b50n^~K}qGyFu9hH8bkHl}FJ zl|}4Q>xNp>>d7|LlJNriZ3g&%1$$UkRgA%?V~+Z2m?QDhxraca{>TwDOK9Eja@;m@ zBkY0NWe;IkEfqdTcBMDyvN%$nsa#<0xUS@ul0Iwi#GrB9QBM1)MKB*ZSskprc93RG z$uBs<->faZ?Pi&&NnCNQxX-r7!`2+u(!ZA`e1d9b6eevz*TjueNr7GH^jzi^T^S50qR!x zj?qWjtTk13Izn07RhGV~2Z(j?ogMRYu6S%xe^E|?JMOUNXO8Mk@Mp|cETC=kuC`V)$IWWoM6)C| zmRYV`vSr^fs#^9>xW=3$((os9INfPNug;?k!j`;~}Uo zEn`b_f186mW9Y-uA$FBkSJ=YLl{Pw8AwQ(cFOl!824(?kku~MeRC+tRjZBkAQZu

    {aQo62!2MLhuKz|aqAogGW;2PqYHd?)znKxL_(GK)xawYfwC^`%OsIj&Uqe`ThZxADcd4db!=Ko{WdThGo}oKnHaczSTn>A^*^H501`G)iqqA=3N4vhffz(>&3evqAP!W(>eBOo zQ;Jca>b)I9=zYpU@d8)R*^9fXF8~L!>U!pRi<$@VCweV9f)(PkgXK%77VR~TfUgGY zD9kl+P4G+&R?kwbpXd$qm6?}F7{mA!>Rv8OB-!2sqp9(jn>xx}iSNg6AqK?s$DC$? z&~U1y=SpxO*$P}2YO*Wr_$v?!F=jXFC8jIg;TNSynLYeC|B}4N!W!h@_>GbX1{{ySVQAk5VyUa3{SRJHt&X_- zBM!K@>*#zo-RO*u37)4}!!#%8Z<)uqU;iw9wYMVcDq`F__8IkWV68GLt_rf(mL&Qc z3)re;WA41t@z8H|3_S~<%=Dq1@I2NWMktSn3HVQJdT<6k)BA>kddct!#pHBY$EYE6 zVh(_XiS3Dn+{(~bVzbdKrx+-&cXRGHeo~pPh4@KWm?>hv#~q;Cf>KH=ZkKbAwwj-h zG8qk=8G3gvN8|To>V`ebLj1}!v zH&Ddc98r|}Gr8V(0&lCEc|?iOnzL0>zPm25f29!cIldA4tyI@Z<_mtopw!lQgz=VG zM~s*H)N^R}tB9%gQ=q|rMjD_fSI!5v4pAg$arR9AQ-2qGf9Im?X;^Wtoz^8> zE)W&WCNudjd9A>6`!u?~zFZwGKKHIRj-kH(A#e|68eio;1|3a2=cMsExLCEMw>qhP zMDyVS?BPR~<1w_=$mcqTQ?wW=Ki8|Aryjt$SU1sURy7JEe{mZ&18YU~0q^LZ%r&bX z|B*dryrZt09Z^EFihLotinB7@YD*W1?GC6 z1iPwJQho$Ch!IXs6r^BsAy}PznsPSrkg!^-BLi!i6+$x@iy`p5a6xX2cbw``)(feK zQ2s?|M*R@8^ILIcwGsiwGf|0@Uf_+yLdI*=@4g3aQFFBxU@WyrkMQO)lc?*|YtiJY zh!WTi%8HGxQ&bhDw0noui5;nq_YGo`!wr-S0;lc~ZS&iZ`{b$gG44M2SE>w$dafx3 z+?RI@xd)yLeYHZt3HWe*G1!!OSdLKI@|a&!2CM7YX+pK^ooqjMSA8P=gxqXf zMKo0d;w7gc98{iZC)LAw+h9>nSEhoYykLg1s;if*4LPrPhk&UwxDk-XXpqA<r#yj4(InaUt$ZGt5>wRcRu?CEb5($nRg=18I~{Q{12l}0z*eKbSw(r#JJ{0D#) ztO!pa8^v?&Pq>~Zg|_Ro=dM`O0Ok_)Vd6o0mpW8fuKAo@L5N?@-;tVv#rZAR%2?;b z<)LPLF%uWVp(jQTanRQtoy87AuOlo}ww09P8gxtCM{Uh*h@1#J};tQ}l#5VE-W|))9-4xh`tz>o*vzfBYMe3>6h`k|SWwJu!)Hj)oWs%`P8K;wt zf-4a{m!rbMGg#Q_t7R~;i5j&4F&n!m^{rY|T}tOt*@WOOw4Z3LuafKWakK;9jdJt3 z&Ur?fdWN5zG&Q_V>w}FA%rxV`>+o`F6w?q(g#l|Md>&e79+dW?r0)@}V?GB)vIBx4 zwuhCRcR>BbbcUbunmXT`_kl#!clLJDdQ&lz@`C(W2PBDS?rQ|gKBvaBhxi5bDeDiu zjU7P^$=|7`YlYF7y1pZejuiZf9KnUxvRzcdY6WjO6P0%8l$^|ehg;R7Vk5_H3gh}J zqm0I|xD+E`|x>C%>KiQ-0PE81$zQ%jkJv877;te4JMm)G-M ztIzeOhskSo1p8zqtD8|*XiM7>-|alCP;Vx^(%gsDr@AM-#AWe3T{7HDTO)oGFRFK? 
zVahb2g~o?1wxX z5P9KrC`kX}4s&;fJn$-}V2)rFaP(tv>k4&L8SAb>q=MUQS^pp53pEOkrx($td!?~X z*qVEm+bkCro9FPLjM|v*lf76{y@i>x@^m;oZ<~E#T1E^~@u%f%=X7Npk?y=ZJM& zXUqzrs-TTEfE*KQ5uWExGCzm=c#FxE`49LmVG0qK(mmi3{*FmiWwqaF-)7U(B)A&99p8Q#(wTL@=QUBxfQIWV7 zUtOQ7eI{jkSJo|HN6nE`;Fcay$--S)qj;{#p%?bg)O_+fJ4Kln9t`~SS;M9Ih!*lz zsv7&!Ha3q&X$KzEWS1L-YTf5I84ci?q}}kiAtiLvdI;y4bIK`ZuepdW=D@jsnUN_c z_Ty$7SYFJf_v(_PkQBk(B|F3H&|PDY`)XcESVYRw^2~YyX(fA)IghhPq&3PU>|g#; z@T;SW;~7((N)VS*{i&`Tq7?*UU?sG}yOh_O-=Z|qcz&_^BRmNiNS0f5$v6HP{>SVu z)Nw2#oK_Zt8vIS}9mNNd;nVzOW-CV3*HD9*9(q%*m^=cW<4RjLt(S5|eMwc)`ziBs zYQ)c%R*~n#zXj3X1^z+pnZe3GxhRK6KBf^RYkMEF8p{zI;7x6a4*1l+LVA%0><2W4?OQj-=mee}-oNp}&T5;rZYIadFDCuk=_NiQTrn5_s`)|2ohZJC-F zzX$up_Mj%|w~0ycqcSQlDrDHMQ5Ce}p&9IO-<;SiT(H$etvVZggSm*}>W8E{Lch=s zZxeQ?e-GbHDW#_{adK(+)HBbqAo4L|lPA&lWGN>X=dBZ^J~8Lf=BXfts2NI0XvYh~ zD-FuZB%ajh4g11Vfk?s@!=YXxD?rVWn~}1=1=v&nKUj-;t@JfE%SSAmQO`F-8zr3u zYp~TR3tiXtyEqT0OT!h3TZjJ-{si+tXL&llkuvaxj0N(IG+0?t^*#QYlox+^sOF(u zdNH$@U#1RLVuSU>2(6Q5rRs0 z^>6<{+d#dP93UNL86(-XTZ{p<0=QAZY#U->0jf0GZN;(&crPp{e)p|*S2H82_WomB zRnUO@M$MP^^O9cLO7_uICH5bDsQ0#SE!3zkp}EWi)QPzfsulY`QmCpZKyR`sc~?drwrn54v~w%j$9O0mD`PIjFeno zxk2}o?KBJG9po^LLl^&>kn(_S0Lo`?NBcbJ;)0_;AtjUI0N3SY)L z1rO(~h7H6ltO_?=KFA%zDkOHzm$ath2$L7n^5&46sjqo<7N?OBfJ%{ugm(wOD37^H zDY!U0YZqYkA;D?n8eIq$WJ_QV_{M5SdYG1z_m*jmbrSXyT~X#VN4UTia_-D|?L4Yv zvTvLX`E;YCxsWf(-RD+YSCTVKH)^9*ptn0qGtE3wRw>iJ5ZEOc^ z6q1%733p%uDN|_6^_TC0kq%CaSCr?v2S1I!CAtHW?!w(Ck3zSR$ZdhgnKZcw9(0I& zTeF20ml1WSAM7apB|FSZ(qivW@4qN-Lm&D9UU65MNbI87N=VY%#;sjNXv5EhxI8^jH)U%e zL!Ba*3q00(8U2h7>MdfV<6m`^q;pQLb#4T=4y+R{avmd1;=pq*+WT4kX)nP}q^fhn z!CJZ+oXsthQ38(3NSebAg$0z|!gIAfI0ZN2{pEjbvz>uJ4tiHrP=ibz{Q^aSnbI0= z5?_=W3+}l0b8DG5`nU!I4)h_O(WlwfY89iV#R@N#zr}v~*5GS%vs{DDgHP&0Ph|M3 zI-Jm{%|s=AH|K$!GasNGfTa!+Kam5h80&|t6?RYi!3g4BqPY;Kd{5b`4xlG0DM>?M zS>mLY?#=K}R5e(EZj63QmBfC5d7wn{A;gRQ?2dI*A}Yx?=UeYceH}K_kf2U?-B?Ta-G~UGFYz86CqG z(`q7aP*bjX$`$@N$P6Fl@5w*e@61f1M%D4=f8KKJ8)<=fg?$18#zmqUuE`(xVu<#= 
z-5FtyWD}ihh0&lnc;FuF4jNrtcj@-*M@R_Woui54c(%|#)W@i;^+U6P@2-=2d22fM z4)LU5(o1ZW@ySdxTMDGVqHsd_LVtiisKK5)RE8^C`Ps0E{4u#EN@3N8r`52uMal7h z3H;bMfLKXNZp8jX{?T-{B(t9z0!I>)jbv|mkulaA51HO5t65HOfec(lR7pEZwm`h8 z;r1~7!ZVy1Z$!(_m2)m$Si&~5EZ9N*Dz>IcY`l`7R>fCJO*}vG^zD5`e{RjtKKL4- z-+tzPz9Lw`*tOMK1ZuiPf^H%7v@$idEL@f5u_NTcMsNL$x@K2Vs{r1~XyLhPlnaZ1 zmz~Jcp36c|f!`=BCnK31o)I}W8IZry8p;imcSk+s{vkW#F76qOU<%ZACL)`ieuBOp zj3gIf`6g-2qqdVh>>KIdnT2GiuT+Op>C_LsuY02V8a2d`ft6FBx!x!FMysW)|MWL# zL%!L0JmlbNhF`EviRyuSunl@ML>jJ?wZTH(Lex@Rg?DXTg#w6{l!SJ3*~pOSLN=0h zcw_aAH(u0UVEYks~dwwC zStWo89hFt=LTVVZ6q}_b+gx%jJ|{aGc$5fl5;VDou0rZgG0iM2eMA)8_3SyvNurzk ziTXsi3HPF{QyDSROm#kU?uN*r%U#qasW;#{*FqgxpFPdlaeAHPEof`LhF|J$oU~jG zvM;H2WI*7(omGqN?muEdC#SZ_Gw}vuYM5bPsYi)LNjIJKLVqa>*_G@(G0WP>U5C&0 zs*VrNF(s%_&#;cG4AhYUaZ>JKW3Dlot^_tqmxOnW%C8P|cm6TPtIt6{ewB7p z=s-VETH~w1Nc3B4gPJ2_u;$9K96`#%CUE1qOu&0P!G}T-#FpL~m`%kRj8`%I!h328 z^`5GyJ`na|RfBOT;c={Canq{}V#)}WjoFC0yM^4sbji%+YWe5F74BWe2|zUTi*MX- zvhCC^HQoCelu{q*L&L8U^5JpU>9|_tQc5w0nWSSo)5KX!FAEyF&$_1?=P4{Sm(4sB z67OgCayH{~)&0~nwlOvcTyiE^I@5q1#yi8!={unaei#U0Lz1hSQ&9%`l3pTrv#Slt zU<7h)AjZsRWxRcewjVJ|kkKht{s+-V^?;=ORIWtRhN=PgXDSRXGnuFx9YAief0TIl zbN-ap%M_Hb=rGn%-#C$|r(83xQnPb5Gvo1Bj`qsSyked(xtmeQkJ=}^HM`L55#=tZI{^R0cRaR%hd&&ZORHz4lMngaoP2y@$?~N*Ebc6V$ zkK7Ag1+ioiW4ZE>I&3W@zOoIev1CoWe15LGV~l|eDo^=)_(=E?%)`pS)yOMY6Aom4 z^1b0=x~2xSMBK85dfTxfP)8{Q9&j)6iU#}By_MG?L*5R`dm$6#7Wli`w=sRGXH8evr!I)Cb(1(rJx@DBW4ImtqD|pwpwl*&u{sX*CDMB zmQ}}6Az_H1EptGU++GAM%m zm9s4{1s`oT18w+tss`4XPV#oePrGPrPw2k0rv90Uv)*#Mr0Qz<+*bxrTFMOd7QSE@ z_90?o59iAQNw2Rt>G|qGZ&~J#uf^6S=1;gVeH@$$BG5@fbOiOjd~;k<7n_}#)`55> zJ!z{omPufDvOc*@Mq^dtj$|)#u=%rdQxEzL59{!>SN3o#jRM>pAnT&OJg)@xn3zZ$ zH&1b?@Ef%dR^&u^3!BczrzGo-$OvU8Vs^fApbqB#ZBZP(1LYO>xzgP;xJld-ETA6= zxD&f1)KRjKiRc5InOQ2Py%${0%Ng*&?El8kedKCZiDbswq(~Ka(oU zCGlm%s9@pTI7L?@7)@7UilHdGz;;Ct3Sy^o0!ln+M-(>S@XJ$|&@E-z@rmB0*T(Ni z!x3qxIr~U0$Mw{|N-Mods5zW;DAtjIHv?zzInoESGSkeMOC2Y4(AvsY`rBsm&8#_~ z1W`JiFg8wD2bbEOVZVp}0%E|G`CFCvY=0G|1mN|xg&!O9OI0PR%r(oB#9m*P^ 
zzRJjBD3uiI&C2=8HMSeHdozHKal=Ztzq`byBv7T0kF}0##jD8eG5Esa9)=E8(+mm#f*(5CUulByf+uO3RPMOnk z&W2XHY>x5TX?`UrDGgzl=T&uv;V)6-pK2HQNW-RV53Kav;OfHO_$0kDyrV4%G=T$+ z62>awn{BEdrKg2L;f=`RR469_p3>SF7_)(1#|@_sOPJaP)Z+$cFA}1K1I+TIX2IqD z4&jN~RG|nr!-`jj`v(|`Uc@5QjiF>ll{1uES}ku4tq`-udm6ZGyjd!@je>eE)=;g1 z<3HU=919nMLH0;~IqoT*1G~BCoQc*n_jhSB6Q`7<+bCnBgt)%027wfZKNM{|31i#> z*SF+V%C^+r=ouWVfAW2&W>SY4#c&vI>R+P@EQr5P*vK}poCz1id5*Ej-_?cd>7kSl zF^xRcxMJ+8e1*Lc*DCXWW-nm}s1Ju2BZEI+J=j&afwbrgR zo)2hi|BT;m9>V6J`GAWZgN$2-Yq5C-x_w{ubYd$o;U&$2wZI{T^*5CB9ewz^VJAP8 zdc-GtE*-pz@(cz`k8-DH=rX^PZ|wV-mB7y=uZdla%bvZ`-)2^5F4`nVf>6pm zx{MuoS8%^9SzF6F@#lfA64g*A8@>`gXp}PBafOJl ziPfl1^x5FI;L|u$e4?Mj{PuJ zDfeDJ87_~`kB4b1)|C>KIoPwGT*$6Kgwyq)x!fu*hCT!P=0w5LN`fnfdP~JQBZEZJ zszd(^X7m4`GxPEk9IJ0cS(gBBm3Q1DS)+vG75+giBwj{2wQ+gbst>=z(x5ds3@`qF zqB8d`IVg7)=x)^a7Scn^|LDVXQ8+`brDS><<(C9Ya_Z=R2G+kFKawJCwNo4lvc>3sdCO}< zgsy~}lCc9An>Nfb!<7=gU{WZ5A3$Wa=VJ8XzC218=3S$X`L}t?fuQkMJfl}c-}?S3 zL&e3$4!W1tFt|sLrRG@yE|LzQw#y`Ys&_m{#07emD;f!rO7WizTV4s|bZZa(@@z+w zkSy&amY~0JRiUDo37)sq8OIm+K)s@NV@q)Dy~Xnj>xH!e;xA=GC^x*7YUzv2=)jdGU>y8NEtdC#+rEmFYVRmRSlYdSyX5F+MG1L`-JFC>J*@XB zqR?LuhDZvUW|d^l8sEqaVWX?dUYu$z7SW3F&5VEO=7*+&t@dV2S8a`J5?x+v$nVF` zgxBdklZS!3u2cTb?&;$26b9`;hr_LVM=CqTgP zpTr^6Pfw;3b`=f93bC~=I#y7O*3H>OX-rvx{6Hb4GV?gJH1ob0)Lv7=$ztBKp=ZPZ z%2XzJkIS32`u->SNRMK;=oP3Tz_4!lPzpdZ*$rs_KUrzxxSa2yDXb;m*U?C>tg`$) zcCI!9c`KL7J6o4XM6{_9$YcX(sS6A=ss9Moo|gKsv|c9=2^@A z?Ru&334&d1Ysn4 zHz7v7Wa0JcZ_HnEK~}|*NFq!UAN*4@KRA!-{p@`J;_z_S`82rA%;GIGk0FgBptQLU zR-raWH<7+c*MbkY_GogP%72tg=vBbP{44Anucp@$YdW7YZ-RwjN?@6RzO?}?xhS%} z1a0*KHI?@g>7nHl!cgP5zfdp?BaQdjPvLPPI{`O*MiqA?eVgflhz};!l6n*QQr$}2 zVB%~!{`2%UT!%PLAXiJWB3D7QZW~=n(}%5KdC<*1?uws95a-m zj%Zkd>Xmg&=>+4Xzo`&x2+w#_P#OEf>pLP=2mOVRQ6Vxz*$o`c&dSF(ptU6!b5k2H76~ zUEQNJHGAf5&<}Zwn?;lj)@N@`IEgOET(o{EB|Hy9y||Cze1L3!`c7maJ1Bedhvl3? 
z6tmf6QCrWL4h%uR))nQ8R>gdXpHL35U6~=t{i#i4d3EUFa%lR~NWKPhg)X7-SV#OV zy9FC-I=B*?!Bff~vZiqa4|A1->(ok6+9HvyaH3v7J&EsOJ9D>;srUl91vn$TF{=mK z(&PCB>OJ(u`M~eCJ%DaR<(`v}sM0vF_w{jFrs`&sg9WN3k?;9_R2+Vd`vQj~ZSgwj zsr)=WPKz<#Gh^&`L@r!Oft1_T2=9ql@!5Ev+#3gH%66qOw;y)K_G&*sKz%NC!`Gqz zLWXOv_*uV3eWW^t%UX7;mUKeg21XbYlD{jRwMytjJC<+lWsC|!InQ#l6~92gmAs#r z!i?48iNSntQ0@Pz`w0^r+v$sJ0-crIS1fFv@$68Gd5R`H7WZmiEF$+em`YF5i?Y?( zpvUdmKyAsfdD0+?hG@Se=u4&spbhGO*bI7-wb$Cow`5KcZv8Aj#ob#<()#lQ!(X_~ zT6OCQCC)n;CWT;NpC%|P$bwiWI4Dup4g0lVb@3l% zm-q^gQY>{T^---P93x6_S3OJELMm!($bIA+sw%b*3|4cn(Z+?~eawt0$=m6FkP+-3 zx?@~f;{{hjuV5simdbiG^WCD&XQ~Pp!3%r{f0^o$J2&RNJdg@<)4>VWO*RfApDy1) z8}CR69Aq`~Ct^V@^?$%FiAk<5$}6>myV#0*sJVb&IGFGDER*}2CAjLy(_Ty(?Q4$z)NZ33Z|(d_s4Y7(P{}x4 z=d&E6|IYfGm`(t#tI$V61|Is2r=ChN4UEUyRNM~s$!;)2xt0@0moYvfn;beTg(p*& zTyEX{4pfuwc4@qn%W0?PMd%UDh+3qJU2dU6U-KS5gSK(36P;E%d z@o)8d2Elcce*yv-V?On0!VLRy=W=d7n=QLxMY*V&M5ZDV@kM!+)QTyX zkj1~#rz!cieo4oPeV`0;I(G;$I8@hDEtG|3`mz*hOiil5-}Cs?b$VTA7td>|yc|cR z$qYOX;tZVYk5_W!vj6E@nMr7q6P14@Z)>o9_!GZ@Z{-{taB%NKoy5V~Gs8!9hC}l9 z=VwuF>nl5;N`3Q)z6DDorssb`Nr!h}NN!!V6uZq$m?O9$%otG4sD(&Oi(p0jeSsto zs!8Bt?w{DFc18B-G3-G4xjwB%4B|mM*)beAd|R9oE>F{Nw6uv%#AB7=6}Kc;sPh(UqjqPbnRcWY%nV!x z1wEsK4yqC2iY+60*<^7qwv!o0Een0dFJZlz=Hd&52Km0%a%txfa(1{=pc06;mvBs7 z>32;aJ|xUUMz$ZqV5P(M=}JRtsPhRCt5NnM@tv`S4MO;R_epVy?j&0%DNehnB5p)L z+shV(4U}avhuczRzeDSXw{ZS|*I)qk6Xqt4g+uuc!VIMsIE?i|H?LDl2jMdtq7(TH zVwQHnQ%ue_CK)@~hVpQ%7R_M=NgI08PcXL8sDhyUVn#Nckg(B6XVz0cT{+>>a(^6? 
zx4=|emwZ}PnA7sEP^=nn#Cm zN45`K#P(9YC%zX4lg*QVG125bqpENkxPo9ET{dnNGV+r4mf{~tqSKU9!OM6p zx;jZ(WZVkJsb4nnmR)4 zM*c8$hf6bF_^G_t%)j7>SPB2e)slINLmN8)%;Udc@gVhwq;jL@xR1v<5s!?UV#ghV53#Uy>(+X<` z`N{FP-dTvEWpb+P2$zR4qBZpm9N=%N1B2)I80If3OI+;=VX@#D7zpkN--IM1laF+C zH+nK-`Q4ztr$GEXa*=k&$tP!0R8pV73(d-|h1K+wvIf&KF^^kV{xsJ$>Xo+Cc`I%V z^UF7by6I?Y4vntEmNEC~OO3F+l|2Q3m?>Q$aTyyOck*@ejhKw?ZKO?J&04%2eCF~SI!4b z?-V@<{`31lL17mCMXg3%N6kE1LMJ7D#{NS_ab~OcopDl4V-mQ>;2gEm6Q+k__cLi& znW>`^;WIo?X29S20OwAk7|nuJU@~9Tx~cq2cMVVT&Ndf|g@IEotDn%icn@$(nB`zc z#yMf3za9)8NT%)@2h>ICaCLxcXRjI0U?=>JnLz&r$QccCZ7s=|aACST^96oLc_@3c zt}6xD7fQbOvGpQQ(N*PSSqs^fu>O)K) z+)gYcx8xSe?FD7Wp7>SBZG2B{OZNo(-9_kuaH-Z;d0-ac8LoHkc38{XC3b=6*WaM$ zelWPfP6{sJJyD&NV)O;%r9TR$a2LQb34=W`Hy*%ibFslY#NXZ;;5XdHFJP+fPvZml z7ffbatBn;8)nDJ{InF_|5c}FwMLB>S$>Bn?9H-58_*iit6N{2abN!t36d*sJ0`_h& zj>&8NK5~T7g8B%Pq$}E9tVw}55A+QH)v2P&W2q3VgR>~3`IQ<47J~|K zli0)4-9mSZLx+^PbYJ4JlA}aAOMytKJC_gr<}Ij zToC`DVY&lT3T#FHch4|axq~1~9d#^KLU1y&z?P@y_Vtyy#hGj@5Ih z3}L6##2BHMi@z1+)}5aJ`TRDLAbCw0>j@;7Uhxt88kbpEdDNtJ8p%Z%3C@NgfcU%pRj>nMvk z_^V0-d%C`fUxYd$x!^P7R3>QUjjQ@)@e006E+&@(GdvEXAG(9A$1>Rp)*8Cc!6Mmv zvGV?Cvw(I^?!cT-gQ(XLk@_#X$)sTiseeq|hAdIU67U+YVD=JAm~ZlT3x`1i&p6|4 za0xRUWMM_I-TPC z^K>1)6t$5ncz6kD1xhnnsDa`ZPpBKQL%@RtrUPA~V@uHIa_jwmrqIx>XAae_IXV#DV=UweV2j;Odjlb-#FyIt)Xuzx<4o`YJfdo#bwxM`?h^dC68WhC8I1|pPbvXMpa;t!%m=y>sruboyKblka_S;+NsJj7oE$r#BsU=EW< z;NS6fx`=e0yced_r^t1%0N&)=;GM(I=%>tS_fAmBwlaK^J3)WQapUju$MTWk%*1HN zbT}3-rDVERI;-!GRl4COcgNA!s*1SR1EgQdgupeq9dFk$tCv=sdv1;aVb^N1FL>g+ zC8uJa*K=lt+`DXB?F9ITvb@JVz10_D19mc}6D~$KcNkx^X|ngwPMp+>o6S-x;5CGM zG0AjyXJf9RewQn(_wv3^_VC@f3f5V6Jb%UbrMnYX#C^!1h3oEKjy?_vmI}3#nt2CP zq40;K2s)5;falaG_65E@kf{3AZOlG>ANOB69m6R(c9Xdsdz0QM)XEmbaRM^w1vV)g z=sx>>YD=sIJju-F28SfJGqUU4r=Mx}Lz9Ag;2i2-D+~{S4g4+GjI(2+Tt!?YTo-sW z3AG7#_>LH1yh2-C z8|4;dBg6aqi9JnIu(UlcJ0iXepetkjjql1aO{rZ#~)6qR0>sAtslPSQWPeB_WSV}uCCuAj1&gI2!6>NfTiH{73WG}Sv3Bjq;i z{J6=~QDdoC3ZD)<$XUHU;~F*>AlfttP*3m#!rNdUS)8hEjwMoa$AB{89%KSxj3{-U z_J@_kh4K%*in7CR`D&-|wp@A}NcA=bQ4eiYNp 
z_~=Z|ooy?aI})rhUsFH)7tBSV96Q7KZe|!s(*MF|w02%m+@fc>D$_%yb-+!3qaw8i z7NlqS-zq=yXSp95PNRbBx3wL8pQ3RPerwc&7r7|dT-|NH6RvOOouhpglgWlX53mfq z5&qVBQ^~^~yV2iucJ0Go5od|Au6po*yP!pJDR3+GpRx|bsoflJ^S=!z~E7*FJc|R`usdpmD@Nf*%HjqjccCKT) z29}~sTV&=cZY-P31l8xL*;AQl3Xkw5O}A?vnx!t)_u{E=oRCK~O^OVBR6TV3fobY% zk;Kc36gsa?;LZy-=vskkparr6>TDC(!x%%wNUgYj-V=Ed=(j?+W@}TyeCuzuwegCq zeK<#0Prnh$l6c-OP*R=hU>z--Q^mJA$Qz~1K=+8{Vj1?B`Xu*~zl%D;ieP838Q>Uo zRIN- zgiA%6&3ia7A_6|+ms(Vqq>m*|gNT&IT5WE1?j4?XBpWkPKXV4xB`}Me%4H+Vb9va8 zEp2^%i7!JpMc;`>%wa>98itDp6GF@LS~Inkd|w@9 z5SsihCvUPx8QS}kDJk_;)8H&^H>M?i4VN&hc(y4{q7C|@p27QqjZ|uYWiy$YAcfSN z_VADmvtyR}|A1V42qK(})$6&sroVRdH8j+FJ>gvR59g(`Z$n@PTo<7 zr!fyk61$Zy$rJb$V7tq28*82cbCX+$tsGyp>D*TKU;jvuu1nx=rVQ04^Ql~fP2pSM z^TAx0NbaY4+3k8+%Asdc9_79{!_|}sc%o#dR^Qnyw=vxqj0w%fq%g&OCa-`)p$0gc zG%c$ViylZSf%_j(F}5*kuT0dI8*kZ(N-1!k9?jmM80Hk|W_~gukdQmrX1H|r#lG1o zPWC03issN`xG8uK<&Crje&vsPt^WFE8CU|# z7Br=n+KjDdofKM9$L)PV1JflBG7aVzY{(6yXF+e=^!WOEZ{v>C6SQO2P#WansXQQ` zQ$|pELQQU)SXgiGE8umDKo!y7wyoniO>o5&MAo$^zOOc&JxsipuK6Jl^19@d(>oaH z?DfFV017inMY&j`IC|=$9b>pjYSq}^OdI%@5lby}Zl~)wI#5-lqHHf?XMQhvlP7-X zGOm%>rGbmuuEf|n`^MrsRF^u=jPVQwi7CsN&-x4ZY3~Da3&}GU;*>I=N4k3OexNy( zil)cAvnLX08%4K<$MmlF8OJ7ycFhz%$ayPMPBVNKDU7 zA?G_!lgZMb==V4w=P}>7-9&Bnj4+7(C>OO>@OJ7;&gXdTNo749wYAH@L-@TCc*;!qk z*OA18zHHb%G4q~xK47gQzDwM3w9kEk(ygOVyS$E3&z-XU0eftLJ>6A zUJ>l$-{^npmL$r23DU{aWXRY9>jA&|Lv#_NrRw@Lvp)Anslxm#=W&Fmm3D>t4quSh ze3e0i;1Y~w`XNi(Ub7NAnp*=re8>D$c8I&pzG?b}xL=H-cNJ%_<$PoGE1(B{Ap5=f z)F>znphxMSi9yk)m0rduITCLOI&DuCNm zf*1Fv_@8oq>6QBtTREo;+Ns9oH!~`8J-O4&gXAM{q|WGjXfGR|oJDJRC+)4;Ons?Z z$%{Q6y0R4jUaqFu4Z9nPp^vD)L5#P)^2#xj-6d4xd$QH^7<7w%78r;n<{qG3d@`=- zna(1#&V0k;>|=~2+F`wn`5u&HJ@_^HZ_g@$mFmk|Gmmqz;Dh-?Q0YOi7JGuP>4*Wh z6KmjCgw27dl+T9G-(F(GPfUnjLFFNe(Ms-V(%wT`RRCI3Gx%dk>9nBKJT#2`OG5t= zMz-;pjmSPB70|+=iox<)y>K0hR9-W`8A7=X-1)nu$@(MrEtqz=E4T0eC_3x^sIj*V zG5=5tr;i)?s0FV@t~7Z3H;%23`dX-{hf3M3radF z_o-IOdiw+>GcjsLph5aD*i$vBd~Bysk)*Z~YqWyDd!xTK3bA9_rbCiZEGdZSR3^|v*E8>ke)j=gVMhVPNxn|6!6`C7&bVVSEb 
z^~+T$xu~8FiewfNie3bOajmx1467o(Brh2LdSd8W_6@pRUdh02?hCz?;fNTkv3wD`>+4H5xl_V8@PmJ@ z{pBx~_&oZCVbEROy}V!OJ?RB~1#{1d>w+iQI{u2fDGlLe;WG6M-4<@;6hMWw1#%nn zpY(ssH~b}jh+=w+qBGj5U<2hVQ4loXef)hw7tiWFglGIu_~EdUM@1t2tPzyk24 zKq}p((}~!al>HP?@TdKfnZis#L07bK-Og{t0ZMkKRAhrnEwvO_0S3j5P;0vSLRE@@ zGmv43gE#b^+zx$+i^crI%0dBkrP9?r=c!CRVkUr7fHM!I&v93ys+l|&CN89<`88BM zSfD{km=mGa5QPeCB|Dek!HHKHEp6;df z4mkf|um6AZ!gmvAq_*Qc(mmmz-IzVj%*83mPec(eA9at3OJSLq$SG>PGmOqCZeN-^ zusV^6Gdq|Ug%|v=oQu>*ZHe@+`ji{4-*>GF$Y7uL10D1B0SnEp##|F7YzZuv@PsF& znNXZbC0dGNXoIXH(;NwO;-2J12jMB>W2CpVm*^Ec2@XVS)ew~wI;wQyC8j6;N8qD5 znQk`^V9nsd(#_6ks3H15M zUiUTMV`4QnVl6hay@kcz{!Ag)8Wp?P*HyY>HQw9YE|go_?C(w1d+Ns$MHC~gWhM%| z?^9@k(pT?|=kLw^BeK@8*Aqv;O@fLh#nqLju$WBcJPr4Yh0SNyZXz@5iQZfL!-db! z+9bLne^+k+CIqH45_2&)L(bAapg#H({(NK!yO8T-4UEjtg6bfSg$uz#gQkk5l!?DU zGv*7t9jiwcl7AS(B}xnPqmB0NxL{ASRord&4gR%ri#A+NqmP1|lw?{IO2-ZM7NJ`l zTcqC#cMg8F3tDx=^Hvr2InaR`&O8HCv6sG~K3biw7f+63+p3{J-zp`+#N1o_UiNr+ zDQw~2#!ME5rW7(K(p7lPA20OfNYo07%=ysr;Agg7u!Ro@VSc%OoZ3bV6F#AxzFS6r z=8U~rNi^qyO2ORt_n~T4|5t6dMBK6AF&!Rm(M#i`(0;CXONMyhe& zaKKsYNqd7@+o&sa;bVn1Xk=)&Ba+?I>H>!tB?G1KDJLKQRw@?#4AtFS6e`9R&{X=% zNfCK)j^G2?^7z=_^!+qIIbq9qiJ63V8#LgqvMExQPy;(Mda_T2YJ9ZP2)#^P5R-D` z3@q*4ZOtS<$ebqQ%tk!3JdXNQw|zG&`>+Y{cXp{=-TMQSH3vA0v$yyg#C@Zg za}toDK}hK%-KnPOO`W}&;{?Hk|v@CWG(TvbF%oW z@}F>A!&@2G3=ji6Bw^O@Qp9nX?4HP!XaC0!_VwXzbE}>ABA9i?eu|ry;S*MpyU<-D z+I-+TDSv04`6`+Fl=pODJ3cYD_BO%h_Hjq76QNmpO=CQ>P9l^OSu4FV>_SKDtK*8g z7VF*Fs*%yKCi@buPAZ+3Pszb~@Wtq~+C=NEg)yaw#m?toF36j%9_cN$Re|}=LJX}B4{yqwL49X5aZji%m#wTo zv*pskW2k6ORc}PTr!PdFU`=k8)QyT3vzhrOHl0V(xL#yYa1LGdL1v=j6aS@OIk0b- z%)w0Grto9>;q;C49OZ{GH|WjTq!#DuD*niR$ps2w4r?-Q#C{{M8S{innn%#J4)HJS zBw?WON~tElN98jB{KYQOP8q+PqCSkj%&brLmLdKw9+uoEM3m8mDY~a?dT5d+MIG=xRzkXYYl`p z{7FXT|MfMCoF)Y$4m&8S>UjdTv=Lpf^WlxQ!hJq|8P)S>Xa0#(Q981+enBY`u8KFB zLHx!V?tG_2GHVb+BZra+ut@JmiozP!g)>mO!h7LjUB+EcN-?g$M@Apv1a8HoNj(qU zK&R*c*erO!7kD*mq1KZ8S4cVf#g#5l^mV2i^QT&$osEAJIC37<4}Nl%A1S63>{;Sj%hx zF32D_K!^(T)HnF*AdV|eq#Bpifh;2#-mP$$SVR 
z%|Ih|5qr^Bkg-r{ungZ+MrX~TXBxxRE6OQ{qU4}GR^vebcnQ4$>$TD5Piu`}$c@6p(u?a^D+!_p)?i#O-~RCG2Mb!U=wrzx*&yO>lHIg@T=YTsG)jg z?GV_=R5D)WZuW3siIK~ufM>x(G>_7!tm(X8-$4Q_R9HNSH;jK0^$+q8{v)#* zeL)rV*J^hrM;!;JfM5J&yz%P=CuleDS9i9nt_6)evwUQ$^^yagPbiG%Q7my)xyg*< zN(z3w;j2Mj58G%s)17KUy+T(rOAz(;H_&DOWUVRagDQYt;FrLN*=$LWs@7#|nTz?y z+V}>PlM+Hn^~viguf(6slC-GE9y=YAyZZ>k#9P$U#Ii<>2C0Ec#NW6rUdI9!)o`u+ z)$}tt2@dp?K7rzM-e->XJ=BAQX?5cs?5`~yw;K6cdqw89ywp+*hvXub8^>Io%-7r| zp&xr+EF<;e#LNq5S>TQ)2M>Vu>{oMv_E1yp&BlJUjQSS1g)~==);X|p`D%(}}qVTZw9WOdY)OK_A>mwFfSX9FJJI8Nlk&@{y~Mszry(ZC{?vvgjYjTC`W%YX(Nbjz<+^3 zX>UV8{VQ(emllT`Y1U}2y1NdXWwu6z)EMq^Xa_mXXaU0jQZw)jKLb2v%R;~VI3=^U znbPKJe_>&h>NHc+FUHo0ls|Na->GjAn$Zy3wZ@tEs5xMQ|2R{JZcQ2Ryg@UWo=Zd& zvx^9TE!0KfC@`|0Fb%w6;STP<{gR$V#zay;K6i2K4QY_}S(~j#5<5Cdq#QDbhCf;t zJip)!)Gx7TWH_tnd(g4q|3n{s+S@>{FZ``l^Cr@qT*SX9WpKb*cR6@XRL{C$|CBb0 zWzG6=(dv@eh^R61nY!5a7_0YVT8abxqfonyE=-zllwDM97sU}fMUQ>cK9kok=m*>6 zJ&CHemAeDaI-j{~!&-c9uujr1*Xl;?oww5pSnUJjSvSW?ZGqAcTyw7cjwo^AVcVWt4D&~|uVg%g&ye_^5WYRwUEo#Y@QAV35A|uhKz#Zdl z(l6|HDwi;sEi65w;yu0K;m8p1*t}tPGg@k6s4~(Jv3z1*{X6R8cIV#H+fv(Y#niM$ zcrP$XyU4H5Ci+SuPFilA_sn2aUwy8yvIqU8JPYKa>L%gd8Zppv*JbORY`ZIhdwMCR zgDs$uj<+7irWiYn7vfs6K3h&NC43EiW)8w%_)b`f8tz_T zrm%_LW~iZmulYB#OZbj+WKQEPJeV{+Pz&r3QuS79)9j6T`=oN_Zm`pJOv`XMs1@$% zaxBqQ8%ng6a#2mG2zU}&5D*#~)nPG7D>C9^~sr-SRxo^yGo~>Mt&|PgIPHMzkvxO$wv9OE^Hw^kp>Tjq6 z>PoE_war>e<@&30i>3b31@jy;o4?^s^sQkhV#0S7u+MvnUjyFa8`(&$h;h-gF#CV< zZBPreP3j3btBrAJQviS;nk2FX?TqZ^kF}nZ7)xqdd^=2v+j%^nKZLwvV;m3Ny71%z0MS z2JuV=9<6W(iP{F8e%xptu&kSMkH}7qQZ3hFt-tVi?*K7P%R>plIA2NaJ`1G6dRy)V z(I3Px&(KrM;W(~$c4m7~ zt?5+buEJ%sF)N6D^7i30)ST0zJqC2FqOneS=+? 
z{;oOHE?0l5K1;HBHK!@Z@AU_EsMoU>~Rz*m6$jb_;{=!?G7`wq6Ob6Y7u_+u}H z2Io}ADdN8Fic+Tb6B4q_O=DKO>Vo;6)zld{7nC6Y`$0b;Z!w00bnhnJX6q}1$Zh;L zIDshY-oS+eXUvl5vGGUnwRDu3CXVHoM`+c|T&H?rBWEpnExahM7uvyvnEqWnato_5 zTWZamqU#hnu2y|>45$LUAOaT|fk;iYv)+f<&Q(b+$^OlM^LW{bc)s*b4g(QwB+`7{ zgne)<`^#LzW5}iZ&7o9S-v0zA5&j0xq}q`lLN-yr8;`%>H#h>bQJ*@hFb7Ppr(Vd0 z&CTDeFWAn}1!e~K^8>A2D(@T=-eeTT+#k{mO3koGWHS|xjN-TB6VgtW;2L2|E|)W( zs2Q5YOy@30|J1Gue#I3Ir6vn_={eadku%0Qi7$~1UE`V{vI*RI~TNIoazgH zP-I8$W$P~?RUB#F_x?pHQhwYOzriPgR)GgFf===+$(d47Qy|nHj>}vwI_#dg%fcDh3Y*_b(tYT6h>h~W6WSSdqOZAKA^bNu$v?)q z^C$W-6kzs&W`N*&68)?zS`DpYIL92|y5pUoAC3Q}ttSUZ3X3Lt0XYFj@!lr7m@e3d6|^-w!RkNG=DaevFher0*vpq(?W?t zReibrMW~92EPd>`z@&#MhxAcW8GeBjixbHctZFa|bn+h$9Rl<9Gs5py2D8$P7M$;@PqdABKkUg@T9O#iV~h4{tuAmoPrKki95f(mv{`Lc!o% z?!5R_tM2{DpGJQ$X=(z<61S^8#lh|q!rWZVTq*GSI6jwY#5Gse$S<{}FeK6-4a_0a z@fNKK>PeO*YSSfHztMqDad*f&nb;Vv%r1+rvJy6eR_1D(!!h}+nAR87<7-LNrK9#Z z=VZAyca-WXEQ@3keavxzFn!uCiu<*foqwsHT_OKYKlc-zzhzR_MlnzFLcd%XU50=>C7Tr!b<0CG?x_VJn}MJ!^N?~M9iOH&Vf`s zQ}2~~Qmf>ks91Xf?|Dn4e&)JzN&EGI4%Ms*RQlIlY7)S}pUXU4a=3Vj`3D9P?=Sq_Gcn z(VAzD<9ylzZ5th6hg*ffDo5F!)%g70#h`O$O~W`s@o3vxnqj%9(m}S z(T$vj8Mdo8H?QATU<>%ebSK zG(KcuZv`kAvCQ{qCzADjM0*yA%a}9b3wAlZ$L}E4qRFmy!SZlOkYY(u5pGKb;9A?>4rL4-SojU9p7HeuiPNwy!BB{ zI26_5uHh3*d+KnOEI-BR^CxhsP!H9|EWhmu$F!HiYo|k8rdBbB2D*|*bW8ni6~RvL z8o~`w1kYy|`5%BrU^1H=-->IQ^e@+{#uiMo`3l+)6`dD!+C*{Htg4=F+B>Tx`7b+$ zJ0Cf;vt3|BWWRd~=6zq2FDK{If}wiKJt3$c^R+-L?TXs?q#en_i4gOrdzpNjDaoAl zopkh*JDaE6UqMlJg*FcMHKtNEy}if}YHN^*?F)0rza1A*O!gb5f;fdvlB=mhs0MQ5 zq|rjNJXdIAVlG`XwxZe0*A+IBO0s1T^iLCqiJVlA>|y^jH#7CnB;#4mKD|4rC`KhG z5y@^-KW-G^2YMVNADk-E@L#%n0w5;28YUIm=|pwlKJzz61J|AO2hs+m4tnA0Qg^2L zCD0gU2Ma{DD@&9-b1YHWeFPoU9t6H@dk!?fY&cNsr#KxSrRL5@RDV=m+aaxGr-7ltzd#jFz3^l!G1vmE zlvrjPY9~Cjr{eqW-|{cVBJFEdOZ>5~?(gXx4mt|wH;ZL=@I-_+%D zLxsgBwQp9f&_8TH?z(#s=froJXtWp3@-9awchq2PW{Iw&Nfhpu8@#tO<6|EHlG#U8 z)L((;yaKL_@`M=ylWq$&X~s> ze^ba)Bj>s`pN|3C^tNCX7;3bnj(G!M5BmQF*o-C$4Uk{B(xq*SvLjC!AKBRymp{xdo=|`vXrO*R3 
zBCofRkFFyMk-AU-?ZL9hui$SmUh=D1u0AQ53E$oOP%{(LCOO9aftl*gr!R+3_*m;# zU`6~Cm(w_{sQmqy6g)#OfNH4?kYavk3+ho!3*UFGeaOjr!UxG!u(GmAD8)R>d>m+q zxr;g+t81`|9%~P8m~Q?!^c!^!Y-7IgL(Om6aD5T>Z5C4o;N0p~yzgydjyKn`yUbQr zAJmMhdDLU;)DqMbjAg%blR;DBh!W!Ki9O+-L%nj=0q8kT`C+>BAD*Aw6N|w;utcLn z=BjWByr$h2s*hOnZ(*iwxQnAieiT^Xc`d|PBNE3Dw@kO?()!RD(5I!Ev*m*F6PA_E zvHip%LiywyIO}qTDW4G1YU@2CTVNwC(fvuNoJRX9DicE$(?6)yIG32AT@h~ZG30lq zkDkLc<6dZ0sU_N(q*T7Bw*lEPd7ILgIVk_f&Q;o?Iw(K0$Tv|+*AAeELJ@Hy^B}v7 zb<`@v?iU{8^OvmsBz84C$PrpSZGso_J;nZ@2TajdIG>@%`2EoUlT~HR4gKzJ%LaUB z_2E7jI^+%56LaT~Cy7h&Be#WXB(9cobvL}GT+=pK5}Fr$%0HLNJ3a;O36&bM)?pZ< z4a)3eKY^^)-`gG@0G-WPZ2nvmIOGjetN1ve@fJC==OW2D9!e)6%Ipf3}j zT_q)nzYFubi!jwAgY&fDNl;F6GIt$}Y$uMAH=)HzQK3@a9P@~L(yS@nMZ3&6Pd#QK zF_3}$V(<2ZBm7bKQ|l2qi?|<5XKHgZ7#w9ZiwWiBzCsCiT|1vxRhx+lsO^Zlv@<-h7wSC+P zcBEc`vwdB}5}=gP!dD|V)v-yMrH}R|H@d(d2wS;MPzN_f3+tQpvdU7jeOzIo7puB_ z)^6u|b5hJlS2xmjT zQyui*^ihnDYl5EU?sRp_spNi{%{Xj*o^=H$Ez05CWjA<>TcP)WX>hqP2OUl+YwhQU zL`IlbnWEdX_n+{}>`kIAx}Uj)T*wZwuz#Hw(M$0M&YKiT?5^%N6RGccKJEnFF0X#p zZ^_53ufe4tMI-sP!fDgvpnSYAoJd2(&~~^7zv7X(DsO7`a6Uf3s` zWNjg}ys2aUcAxMJd@|fZoi$(RKC_Iuj#rIMI_VvvN3-itY{ai!fM?CaL>hV7Tm)Vt z!6;_F;wR;v2z!}bVVvY)R_k8jRI)GK38fA|d_!yj zIIOtxD)JF3)44~lN?w$9Sl#u<;#v5YRfp(~HuGzU3rxlAAJ}SAC6IxZaH(2F&Ih(? 
zZRGjt$GQV35?*t7QdZkdm}qVpXs=r=Z5Ujf(9YFIYNOl_s>~|>q}aeUDcCzD7eX`$ zEnusK_lXxGKbc<*FIo*FOMkn>{PYx}-@;sMMV}5wuu zf$!`BzEyZc;I!3-Jji?50mftbS#lG1PQZ4&B^GiHt%~}_W4bf&Z~qRTx}PyusT-Qd z{vZ>szWPIrwN*YBEP~gKMfLB)CeWq#A|@(jjX%*2tsK5z)?u zLA~5g{2*+#>m5HIT?fO!6n(OOojO4EV$xs%E*o2y+RH=bhO*+^g4!7GwIh+@>@Ip9 zsEs%##HXWndHFn@rJ`EX+?8l<_&Yd_Le{y&E`d*UUhtHWgXfkxY->0;rvKV+|Kz_y!7#N--N?Z+)oV1g5HOg%-gsMo)f+1oQ`lBsXW* zhDN%gnXTbQ;x2o1Ui(0ixHrCQ^#4$6YaMYOv;4bzr-D-6tITS+pV)3RFxp0nc~ZoU zu%K`vV-ctC6vn}#ICGhq z*;%`a>eDyP1^OXk49u@3qY_+Ws9P-)MP?N$MebvoXgB$CVqcvjXh&V@|MSAjW!X1R zo5H@*9|uchPr`4tYKDnYZ2*Zu< z<{IBSqY3gB|(o>2lkdLhW_T)8B0UIf$>~l zZy}{w=AW(lQ%%h3j%_im7@u){pH@X4;CvBO%$H}U*3 zI>c3_%2I{20`&ZvfXX&2@`bc4a);SddZ=7BtH6oqYj*j}|i(HiY#bk?_+Dv1sFifcm z5`um3UHCVxmzbPUSN=zzqbxJ}Fc-Ktp=!v8&}`l4@t(bQF!8Z{gz2Jnv^((igB6sY zbR4q~1m$5&`JM)le{`*pWL7R>Yb(Lam|g^xqpj{*3y~byb@WjtQJm* zI$JfbGrNteX-|`$nPmgjg)m&td<&;Zt-W=%+RActBUF!Ff!~h@E%(ki;ZClM`pYhW zcZi{q{s*v!H5k+HzsUayseD_aikN9Pf%Vm9Y<{LvDD2HXm`HCzgIF^1BDfznB|m#g zVC!Wv^C5O2m>%kXaG1T`*B~p^m8S00|Nn_!AudLZ)Rk-=>`kwfoaQfVV{ejiTmBg? 
z7Kv4RB$ja>3qIHCGYx$2Ea0t(srbB=FVGA9N4Cxw58|a>Xe(R{x(9QWa%hQbKi(p) zOx}cIoOkqrVw_oy?xP9Xm4u;zDaODHVePjc50P< zSGa6*V(bpb3}00(Sw9(JQr_eD6p&iVN6>roqavHfA#FNgHDMr5$W&I}6O)93WPx6ZijpUQr~TqHjC8u3^k%Z-!!<@JD(Pg%wY64THyd4#mo?MDqfjcJ#;37#QOIy12c zZgMu`y{qicZetWS_IpO7AeWj~TYSs4G@j@WTvz4K#y)(19c!;4t&C2-0g*0XvAd^t zwq6LcF7BesxOp|m`zoU~9OipRjFwsn8Dd#IAnpPO_;FwZJV2Eud&!qjQ@sGguqD*b z##y)5djsC%#JC(y1FE&0pCAl&4~{U>1>9wA4!OuxHeNg7j?%oY$L1O~6F$oP>JCwN z%owvo_BzeQ6+-Wr{K8uCAL$?JvQQ7M793~}CT*7B>Sh(TJAsG#P*6vXW17P}cDXfI zTBDVsQ>f3$s{-SliTVTDB@BhD*jaH6P$_;ccQ!JMEN-t?d26Th8hM1y<6m+&a1LWU zd)McsvV?ZP_P&*B2_EsNu>&0UH4$g)L$XJk4}s5U$NwXL0}T7XysGvV7X{Wbm8?GM z3_H>2+2|NJ7I&A6QQuOno$1C1D~o(Ur}|E~(_ozZIXsm*=Glts`unkuw6Rizsem{9 zO}FG(o82X-z2-}#Q{BK8Z6H@j&rMklUh#{;w$N1ckA5@I+N6bQ=rI}&r@`y|O#3tX zDvVbpZF0VN_ze_OM>zkcyCjY=e-P*N6>=G+vlLrz68)Z>fd0V@&nbK*u^=}-l3;f9 zw^a@4By}77k6axnL=@$ZCoO~1jsB5zTV#bu<>Z3)E%cO0Q#d8;UBNzp^HF~zL3<$o z7MYu~M;>HNh-l8Ry@f1o5uB|z42+Xbpc6^^aayT8bJ2Er_ZS^PSD22@<`vODaLIhw z_T*g(9*}y1$;LvZ86?~zeXr48a|qMHLprzFGr{gacLrhS@(|&NJlFVw87mKPC-EaJ zn#*~~NIB0J(1)lFud?UFqFEW{KJ|Z~JyQr=R(j;MMj5KbpP5iW*{`vhnw&tIEI!-hz9VZS{V=1+3*cAq0&Zw$z^-Z! 
ziI*2)hC)wxLp~=Z$z_tBhqlY7;8QS?xo?JHA4c_bQVTe1FvH2(Isd}fl*(()!om&Y zAZzFe?5%(O~rE!ZXI5Pybx z2e|kGk%3k}^hZjjbXl{ga@h0J*V!6ggY&r!TC{bIZ|iiy+d`8-3+K7yV%$Cy!xe!2 z<&NeH?SikjK0(;vOE#A%N6|4bpLs2|^DX6iaD#{;n1xaDFap)$mYLtgU8oA%M9=Sw z!gN8Tdh7=o7A1V>~QFNv@c+5xV72B=aq_rA%&eM!~N+`HPTgD}`li@mkKKfG`r8%b*c>y;uTgB0dpK`9icESJ21K~bz2Daoepsji$aVC0W9ObvO zi}bL3t5HzDBIwp%dZ+L-yEDHIYT7BeHlJX2;eR4`0P5St-kvdh40lOt2r7|x6Jqt- zuBn)1RUoYoQGmTX{pi`;9S}4NdIS6_dO~z#X%O|2Ef+{Lm+D*j z^JEpdFzO|o(oB7|*@bD6`ybcea|NZN4Z#yhcw(+Qg@2{l)MfdWYrQ$xe5gkf-YXfA zTt_M2W+n|aCzv5X+1 zBf&C}b#jw%WqCAw0V`vYeLH~$YjK~Zo-(=NysS|S$sHiJu~!3Ytq3xhw%+q#JvgEl z)ZS;;XA<-~$`SWY)RMWH(ULw5TvB;ZR&1u7)?3iaQ645qT*uv!fkq?YFT7=3iF1Gr zy)XGykqw3i-jN4;Pp3UHd}xlhhW3WZ)K+skRRf*$G;}q=@0--5R>9Hni|xABSfx|e zS7rd{z*k{0(~@kh?FXqWugBZP?4=;hN@nD&F~(0kO}djfTX=2!ZY7Fx(s_Q6oVCBZ zSc>$Z#_SHF6MHnPfLVymO+4vJ3#3KXYw5vu)>j~#)for7QScehowG5WFl=py3+ev> zA5N=$MK;wEzK63t9kiBGbV?ENsdGbQKK+IKXm(duGR;iC(^8872fLYN>5+i(T;Y3@ z9zdxJ%~i%$y|96fHdm{Go9bxMb|es%>wtAh8AqM*s4xbz(MH>Ta(Y@#W|!-YQ4kDe z{HY6f%1UBTyZce7RqhpoqY0efN6dwwV0OWojpMhJ`DX-eOwWG z2K$gc6Z$Q5ACwiEvu+b&ZhL#89aS}wosq<~cm4;13gxNb+a;EuhdAn*kJ#bWf6%e) zD}J>!ChX+fW9`9X1*rnxQ{ zNjfdf&ir3)SG+f^s?S2pw9A+TKgvbhKBYXjoayK+ri{{~!8<8a0Kv}eaI>V?glpv( zDC@2ZD(+o#)8u^F@9jcl8_)Q>maGX5;yucBF2<;1*>k*tds%nddsU7VOf$6vmSznpsp1Tf9~6U_h{wN&=8iw8Nz#z;Zpt~B0UeYixVs5%3C z?;Cq(s@nn;^m674;LNQE6ZAh=%J<$q9djc-gL+^ss)&}^Lg0knUaf5Qq`jElw;y*@ z3!8Zj2kCY53Pkr7pNP8-P3HPrtBmPvqaffLF{Oktc5C5Aq!e5d>1|TlRL?dP46o6i zCM`D>z)nJ3Fgp5ugj5`O((zC%A0FcN24&wOxYrhOCT+0RE>u_^jR{$++}~hTmc~X| z2YUf-%A}0`Xhrx9HzKQ_&>*S~l^N57U8-(KzNlTqZSniuZ$e4;oy6@%B-oPs5v=2D z>AS^G7pjBzu2e3@exr@S>8WOfjgxz2(K7C&+{#l*eZc1vJnZzu^Fq3iNHW&Oo=>jjxJAdJQIoIkI~*UuLziRB4r#tX^?rZYG%y{-qx> zErqt4#&yEvqK)<$qBPzHe5Mw$2f=BI8L~Q<;M^(h2@vFZ{u^cw*U5R$8b*7FtF}gSoOh{jg76&mcEY&aRa}Mi3Yqty ziX0$p5n4s_nY(IIb><3lpurVv%$-ShJlNd*BD4ZL!{(3v#$sO)sg9V*zxGTs4jb`g z)ASWFd-y=;bLueWVxV!fL;e&h09PhdhKsaH+y}fJWC%6V#9XAjK5~go)b{Dsg@fR^ zcLrI5k4+rQY&R2)Pt^GA;|-9sn*rn|!FH_2mXz+bi-lieZg)EKwpNrmHm{6SSF6J; 
zj@*`_BX#uufsTPzXrpx%C%bV|PA_eL_l|)@xDMX>6btqVwe5FA?LaTzBc`5xoz2g` zQRb5zN4nbt%LTSrkL;J|ALCj8fZv1dgbR3gImo<^pXs7RS)r5nZ+;+>craHF|Zt`p~)&@rLN$wK9k}31yMpW7^PLAfUTQztyPy>?T zx6Z-I{@@VUpHe9|T$z{3zL#TwKQ6%;Kx@ps>dwR%BLRO$YUu5yD)MEx!}FtI$3&8T z%O!G8LxbI2LFqu6_LUa=Q<+olXX;aP4uQ<#T4QkDDx}?Voe96^8gm|bB-2GKn_QYL zW>hhwwNhpkV^#buyrV5k0KFruD3%U&B(G`pP=3z>>3~(-wKi}$ep6^gWG)AZTWm5@ zO+9j`K1kMcP{24O@Qwk5JDlNOWDcUw!v)4a#sIX3N`jSy3i<%1JC=Ms1o?ylfb;#% zWhJG@GzhJM=>f0mX5V7+)fP~ZspV~^&-G!_vi9CN6pio7!XCJWuMYCJS@O8sHBu+N5M`es*p* z56UIY!LY0Szwn!IO`*9|mHi4vabrXC*(af+p%KVq^wn=-$4!B(EX*Ez6zt+SrA3%D zU^s8-ZJoPN9iw7ys=S|j=Q>$0WbI1ptv&T!=jXG$xLYfaJC8|Rvz#oUOgKHfOGk#r zd?cGVo8x}+gUC&#j<05Ly}Bw?m8~Z9s&7R4{9i@q{Sd{qw{Z(%!G;>5iCD0KEU;zz zoHM)JXaX7wvEoAmJXoHPC;Z`=uAS2{Tb+LoC2}@y zOj!;80T)fZDO;zq>>l$^oO>>W-HGPr8|PE^L0g8SlE0?@%S_ZpsmDzt_zh|;y7UKv zU3h}du_sG6$4m)Cx`yfp%ct@U%-qoDnkh^OEHCR{z7SK6BK1qkQQnl7nz`C6FrJMP z_T&4-sd07iICfNZVV+thnw_OoIWja_e&zO1r^uzCN$6Rij<{@U@$dBB%q!H4oiAL$ z-1&ph6KrX80#S}s<{TKtyUDr07s@q$WyaW3)fspVyUDM0MwMGYQovh!*EuLSD`650 zbCm?!nD_UuP+#+3**25ed;TO-6dBC5mTDJ?Y`oUi_-+f2*%~Gv*80ML4EjhpTC7}& z--Ii~4yFs}VwAfE5-DO0`!5K@=%Ps_tMnh`ZR!Z~1jv(`=`rS3=V(kY8w-ZZJMk0< zpF6k>=%GB)SLG_>)@YZGOi+046SEBd&HiRvV{5^T)h~FDYC+I6G!Z@Ly1+-AUCHrp zM6a29#16x>(V<(w8eF3nz-|?*6$jlCjRwQ{z5KtiRbh*{3cLPtxtY{wU^nJ(c2&x` zao8=iFX*Jhn$+shf_;@;a)2QbfnL_S1m~2aTTHs}WxH|z)R~2p;?5;A$ zKGK_%J}p=hyWUbq>D+Im7j_Q5LERMSxWf)&VsJB_cQt}k^^&@(6rlO=B^a0-BmM;M zY7RQxa$8$wn}xPhL}0wwj)H7U%=lZ*GzR*z-zk}xvpZJ33v$_|&d$->f)Gi=}DCS|OR8IbbL1!5u=Q6ny7KH<5m}Y(NprGdY$0%rONUKHJNfGmG2| ztuMDEvi1I2x8qPg*E%#`Xo~j}we;oCveI9$L;Q77jy%{{5GwJ5bVG;;?x`fKQij*YCOUBSejZ{!C_SMitp zskU2e3LAjpTS=xm@9PbqUVhEpbY^)wplfP-^oyqzooZe+kF!0onfRxq-12mJt=Z3k!M~^}Ms1)#KT6Y~m3kH%EhD#$-HvW^VZq^;)ju3os{!LC zZIN6%0#FOVCgP$3m99)4tAW{o@(d$RNQ;bu*k;A=isOI5*K==*PX*Ps+~VJlig*2LI;$l9J36;1sJ-H3|!O6XW!2 z_Avbi|NF!#G!o}~5IiXkAiAMyb*^)?`k|Gg-;b-&k0s^7j^gyP$6BT`x7}vZdhLTEg0u zmx1rqRCf$#RocSw_>3n6)0sgqN*%*r3FLtP7jtmCG=lwu5h})(6HeJUW~moLwBZX^ 
z?wk3-*9uAVS}UCB8$xan58@=o4A`+WE~%35Osr!j@@e2hr3>ccZes08kYE6d!ejUCc*^mtnZRETPvqpc4I8C$C>%c0`?O7xu=X1 za3?LtdqTO$&ZRlRicB?L$pAOasQ+c1rA>UkdBGK@RPr;JkdofY=FG9tx^?NxA_h`T;5^z0c7BQdef3f3?0=&l? z>nz|7z<2)h{AqQCx~O0>f4%JB}y#&Hyq6sqP|+E&~<-y#Z&%!u8=K~H)`*UAN04x zE=k6-voiSz{i$)@OqTZeXvL9K$XzIS!z^akh8ECmxb^a2x|yMYyJ{O4&4zIx z_}Df|2^W{}k3zsULtU@mBU#&>;6`Gq+M7JWq!E+7SJUwp1Fz%=E8T z&Sdk!MsGaviSY!^>9dr>*H=lwP9L;RM8T??|IwdJib#dYz(X zOVLKp-kkL#;tunkLr`bh-heFBv9!TTXw9f8=!9z;?1}ariwBnzDrrAX^;Loi_6oSw zU{D`Sft!Tw%C8u!s|r;a`P_D`+Guh348U_-das71pB(ta4J!Jz zNfHM7xGs7xwyeB`IE;Mhtf$M3vd|uLA6?>H4*Sx_Lk~fVyl zi1iQ?%o^J?+$`%~cAl?fabn#$h5SNtR4Bm?aJTYYEeYh~uT)R|6o1KX(I3Kv{%^qo zI>r1l@LIX8kN9Ay^t&=JzlV2EqK(R+E2V^xs$Br7#v#^*sw`#RI4+S6?2a|E*g|$2 z*<0~DQcRj@&)@V$OAOv2E&y9#ESRciqY2<0o|H}0cL_z7R-l)@)A*ZjZCK$*{S>=b z7ag%gy_2%tU{|nblAJ`abO^aqnF$Ke0p5yo04W`1KX%44leiart$&UY2C;M05rd9U zhYZ2K9qRH~YqF53Co$i^C+^(5sl*;;KJ%lQ%M2AdTA@RUTXk@QeH*(Lv?Wr!ZOTu8 zpX?FDBkz2}Me*uXr300NuG%{&+k|Rin!6JTJ#&E{k%>#){|pGYyZo^juQx{r^jG4i zI*;m@dF~;^QO9R$G9LfaxV7E3q6~LS4<}X85yDWW(aa)R7H2B0Y-gCkg|Gaz+D22N z?`R_&)B6rFzVLpc2#Ht;%Bn@I--)ZR%ixIsSQpRFEO_UVtwk&gEv5qeHPK6H+QBm10^#Gd3&QVUe|WO3)r=9SF~*=H z{5`6be@oZ4TuZ)(@&n&d{7`RRF%AsT_ba_&tvQ3g(tQ@wS-fjlTsw>;vz5M1I%_V$ zIL;V*PQ8ZTozrSeL{|r9=Foc#N&mX*g912w}6`tM0STa zedo&$AJ{^V(u$P-6K^p2)OT8%{xcbFwvuwh_TZK?WKGk{)J{m2ny|FE-ps^w#caL3 zyq_qs*RzxKakxJ>0sbMp@{SFRmh#YH5QFIMjXTZ@m-Xk^7up;(X9tHGnRNCc`KxYP zN%NvVlIn+zU(etgt;)Sr`QG0F@9U?dv4MpZ;b4MMOQx6&`nMv_-C!0go!F6jAMO-K z(+h}?^jKn<<_wPLzaQ@*a*|Sf9YepO-o`AUkm?f1FIW>oOe%QG93oxVz+Izl5x*+C$7LrLd)63D+1=ExkdM2p zU60S_b=7aaI2LENAZr3!3sX&(xXML}x?*GZdvcsMUgXZR+u>j#+Hf)V4T~5m&tu=1 z3UagClTXT&h2hS_F5u^Q9;am^$|u6->{0x%o0o&zQG%2QUxOizPf=pv6Um}CP;Hov z38jX^pwTTkMGj+B;xBA-TcKY^qsSd3PU@Y&-oWolni7`&b_k5oViz_Y>LBe*t4FkfQ(q literal 0 HcmV?d00001