Skip to content

Commit

Permalink
[worker] Limit concurrent messages processing
Browse files Browse the repository at this point in the history
  • Loading branch information
sbocahu committed Jan 23, 2025
1 parent 7d5bb14 commit f55e541
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
3 changes: 2 additions & 1 deletion opencti-worker/src/config.yml.sample
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ opencti:

worker:
log_level: 'info'
telemetry_enabled: false
telemetry_enabled: false
max_concurrent_processing: 2 # default=4
25 changes: 15 additions & 10 deletions opencti-worker/src/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import time
from dataclasses import dataclass, field
from threading import Thread
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List, Optional, Union

import pika
Expand All @@ -28,9 +29,6 @@
get_config_variable,
)

PROCESSING_COUNT: int = 4
MAX_PROCESSING_COUNT: int = 60

# Telemetry variables definition
meter = metrics.get_meter(__name__)
resource = Resource(attributes={SERVICE_NAME: "opencti-worker"})
Expand Down Expand Up @@ -83,6 +81,7 @@ def stop(self) -> None:

@dataclass(unsafe_hash=True)
class Consumer(Thread): # pylint: disable=too-many-instance-attributes
pool: ThreadPool
connector: Dict[str, Any] = field(hash=False)
config: Dict[str, Any] = field(hash=False)
opencti_url: str
Expand Down Expand Up @@ -191,13 +190,7 @@ def _process_message(
"Processing a new message, launching a thread...",
{"tag": method.delivery_tag},
)
thread = Thread(
target=self.data_handler,
args=[self.pika_connection, channel, method.delivery_tag, data],
)
thread.start()
while thread.is_alive(): # Loop while the thread is processing
self.pika_connection.sleep(0.05)
self.pool.apply(self.data_handler, [self.pika_connection, channel, method.delivery_tag, data])
self.worker_logger.info("Message processed, thread terminated")

# Data handling
Expand Down Expand Up @@ -395,6 +388,13 @@ def __post_init__(self) -> None:
self.log_level = get_config_variable(
"WORKER_LOG_LEVEL", ["worker", "log_level"], config
)
self.max_concurrent_processing = get_config_variable(
"WORKER_MAX_CONCURRENT_PROCESSING",
["worker", "max_concurrent_processing"],
config,
True,
4,
)
# Telemetry
self.telemetry_enabled = get_config_variable(
"WORKER_TELEMETRY_ENABLED",
Expand Down Expand Up @@ -443,6 +443,7 @@ def __post_init__(self) -> None:

# Start the main loop
def start(self) -> None:
self.pool = ThreadPool(self.max_concurrent_processing)
sleep_delay = 60
while True:
try:
Expand All @@ -461,6 +462,7 @@ def start(self) -> None:
{"queue": queue},
)
self.consumer_threads[queue] = Consumer(
self.pool,
connector,
self.config,
self.opencti_url,
Expand All @@ -472,6 +474,7 @@ def start(self) -> None:
self.consumer_threads[queue].start()
else:
self.consumer_threads[queue] = Consumer(
self.pool,
connector,
self.config,
self.opencti_url,
Expand Down Expand Up @@ -500,6 +503,8 @@ def start(self) -> None:
time.sleep(sleep_delay)
except KeyboardInterrupt:
# Graceful stop
pool.close()
pool.join()
for thread in self.consumer_threads:
if thread not in self.queues:
self.consumer_threads[thread].terminate()
Expand Down

0 comments on commit f55e541

Please sign in to comment.