Skip to content

Commit

Permalink
conditionally show correlation rule output
Browse files Browse the repository at this point in the history
  • Loading branch information
maxrichie5 committed Dec 6, 2023
1 parent 68904e8 commit 0ecf0f0
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 82 deletions.
28 changes: 28 additions & 0 deletions panther_analysis_tool/backend/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ class BulkUploadResponse:
data_models: BulkUploadStatistics
lookup_tables: BulkUploadStatistics
global_helpers: BulkUploadStatistics
correlation_rules: BulkUploadStatistics


@dataclass(frozen=True)
Expand Down Expand Up @@ -437,6 +438,28 @@ class GenerateEnrichedEventResponse:
enriched_event: Dict[str, Any] # json


@dataclass(frozen=True)
class FeatureFlagWithDefault:
flag: str
defaultTreatment: Optional[bool] = None


@dataclass(frozen=True)
class FeatureFlagTreatment:
flag: str
treatment: bool


@dataclass(frozen=True)
class FeatureFlagsParams:
flags: List[FeatureFlagWithDefault]


@dataclass(frozen=True)
class FeatureFlagsResponse:
flags: List[FeatureFlagTreatment]


class Client(ABC):
@abstractmethod
def check(self) -> BackendCheckResponse:
Expand Down Expand Up @@ -516,6 +539,10 @@ def generate_enriched_event_input(
) -> BackendResponse[GenerateEnrichedEventResponse]:
pass

@abstractmethod
def feature_flags(self, params: FeatureFlagsParams) -> BackendResponse[FeatureFlagsResponse]:
pass


def backend_response_failed(resp: BackendResponse) -> bool:
return resp.status_code >= 400 or resp.data.get("statusCode", 0) >= 400
Expand All @@ -532,6 +559,7 @@ def to_bulk_upload_response(data: Any) -> BackendResponse[BulkUploadResponse]:
data_models=BulkUploadStatistics(**data.get("dataModels", default_stats)),
lookup_tables=BulkUploadStatistics(**data.get("lookupTables", default_stats)),
global_helpers=BulkUploadStatistics(**data.get("globalHelpers", default_stats)),
correlation_rules=BulkUploadStatistics(**data.get("correlationRules", default_stats)),
),
)

Expand Down
8 changes: 8 additions & 0 deletions panther_analysis_tool/backend/graphql/feature_flags.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
query GetFeatureFlags($input: GetFeatureFlagsInput!) {
featureFlags(input: $input) {
flags {
flag
treatment
}
}
}
26 changes: 24 additions & 2 deletions panther_analysis_tool/backend/public_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
import os
import time
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
Expand All @@ -31,7 +31,6 @@
from gql.transport.exceptions import TransportQueryError
from graphql import DocumentNode, ExecutionResult

from ..constants import VERSION_STRING, ReplayStatus
from .client import (
BackendCheckResponse,
BackendError,
Expand Down Expand Up @@ -63,8 +62,12 @@
UpdateSchemaParams,
UpdateSchemaResponse,
to_bulk_upload_response,
FeatureFlagsParams,
FeatureFlagsResponse,
FeatureFlagTreatment,
)
from .errors import is_retryable_error, is_retryable_error_str
from ..constants import VERSION_STRING, ReplayStatus


@dataclass(frozen=True)
Expand Down Expand Up @@ -134,6 +137,9 @@ def stop_replay_mutation(self) -> DocumentNode:
def generate_enriched_event_query(self) -> DocumentNode:
return self._load("generate_enriched_event")

def feature_flags_query(self) -> DocumentNode:
return self._load("feature_flags")

def _load(self, name: str) -> DocumentNode:
if name not in self._cache:
self._cache[name] = Path(_get_graphql_content_filepath(name)).read_text()
Expand Down Expand Up @@ -517,6 +523,22 @@ def generate_enriched_event_input(
),
)

def feature_flags(self, params: FeatureFlagsParams) -> BackendResponse[FeatureFlagsResponse]:
query = self._requests.feature_flags_query()
query_input = {"input": asdict(params)}
res = self._safe_execute(query, variable_values=query_input)
data = res.data.get("featureFlags", {}) # type: ignore

return BackendResponse(
status_code=200,
data=FeatureFlagsResponse(
flags=[
FeatureFlagTreatment(flag=flag.get("flag"), treatment=flag.get("treatment"))
for flag in data.get("flags") or []
]
),
)

def _execute(
self,
request: DocumentNode,
Expand Down
3 changes: 3 additions & 0 deletions panther_analysis_tool/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,6 @@ class ReplayStatus:
ERROR_COMPUTATION = "ERROR_COMPUTATION"
EVALUATION_IN_PROGRESS = "EVALUATION_IN_PROGRESS"
COMPUTATION_IN_PROGRESS = "COMPUTATION_IN_PROGRESS"


ENABLE_CORRELATION_RULES_FLAG = "EnableCorrelationRules"
Loading

0 comments on commit 0ecf0f0

Please sign in to comment.