Skip to content

Commit

Permalink
fix: add input validations to monitor fields before creating/updating…
Browse files Browse the repository at this point in the history
… a monitor. add tests to ensure that functionality of existing monitors do not break.

Signed-off-by: vikhy-aws <[email protected]>
  • Loading branch information
vikhy-aws committed Jan 23, 2025
1 parent 2e1cc91 commit c97aa53
Show file tree
Hide file tree
Showing 2 changed files with 279 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ import org.opensearch.commons.alerting.model.ScheduledJob
import org.opensearch.commons.alerting.util.AlertingException
import org.opensearch.commons.alerting.util.isMonitorOfStandardType
import org.opensearch.commons.utils.getInvalidNameChars
import org.opensearch.commons.utils.isValidId
import org.opensearch.commons.utils.isValidName
import org.opensearch.commons.utils.isValidQueryName
import org.opensearch.core.rest.RestStatus
import org.opensearch.core.xcontent.ToXContent
import org.opensearch.core.xcontent.XContentParser.Token
Expand Down Expand Up @@ -86,6 +88,14 @@ class RestIndexMonitorAction : BaseRestHandler() {
throw AlertingException.wrap(IllegalArgumentException("Missing monitor ID"))
}

// Check if the ID is valid
if (request.method() == PUT && !isValidId(id)) {
throw IllegalArgumentException(
"Invalid monitor ID [$id]. " +
"Monitor ID should be alphanumeric string with +, /, _, or - characters only"
)
}

// Validate request by parsing JSON to Monitor
val xcp = request.contentParser()
ensureExpectedToken(Token.START_OBJECT, xcp.nextToken(), xcp)
Expand All @@ -95,6 +105,14 @@ class RestIndexMonitorAction : BaseRestHandler() {
try {
monitor = Monitor.parse(xcp, id).copy(lastUpdateTime = Instant.now())

// Validate if the monitor name is valid
if (!isValidName(monitor.name)) {
throw IllegalArgumentException(
"Invalid monitor name [${monitor.name}]. " +
"Monitor Name should be alphanumeric (4-50 chars) starting with letter or underscore"
)
}

rbacRoles = request.contentParser().map()["rbac_roles"] as List<String>?

validateDataSources(monitor)
Expand All @@ -108,6 +126,21 @@ class RestIndexMonitorAction : BaseRestHandler() {
if (it !is QueryLevelTrigger) {
throw (IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for query level monitor"))
}
if (!isValidName(it.name)) {
throw IllegalArgumentException(
"Invalid trigger name [${it.name}]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore"
)
}
it.actions.forEach { action ->
val destinationId = action.destinationId
if (!isValidId(destinationId)) {
throw IllegalArgumentException(
"Invalid destination ID [$destinationId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only"
)
}
}
}
}

Expand All @@ -116,6 +149,21 @@ class RestIndexMonitorAction : BaseRestHandler() {
if (it !is BucketLevelTrigger) {
throw IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for bucket level monitor")
}
if (!isValidName(it.name)) {
throw IllegalArgumentException(
"Invalid trigger name [${it.name}]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore"
)
}
it.actions.forEach { action ->
val destinationId = action.destinationId
if (!isValidId(destinationId)) {
throw IllegalArgumentException(
"Invalid destination ID [$destinationId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only"
)
}
}
}
}

Expand All @@ -124,6 +172,21 @@ class RestIndexMonitorAction : BaseRestHandler() {
if (it !is QueryLevelTrigger) {
throw IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for cluster metrics monitor")
}
if (!isValidName(it.name)) {
throw IllegalArgumentException(
"Invalid trigger name [${it.name}]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore"
)
}
it.actions.forEach { action ->
val destinationId = action.destinationId
if (!isValidId(destinationId)) {
throw IllegalArgumentException(
"Invalid destination ID [$destinationId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only"
)
}
}
}
}

Expand All @@ -133,6 +196,21 @@ class RestIndexMonitorAction : BaseRestHandler() {
if (it !is DocumentLevelTrigger) {
throw IllegalArgumentException("Illegal trigger type, ${it.javaClass.name}, for document level monitor")
}
if (!isValidName(it.name)) {
throw IllegalArgumentException(
"Invalid trigger name [${it.name}]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore"
)
}
it.actions.forEach { action ->
val destinationId = action.destinationId
if (!isValidId(destinationId)) {
throw IllegalArgumentException(
"Invalid destination ID [$destinationId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only"
)
}
}
}
}
}
Expand All @@ -158,7 +236,7 @@ class RestIndexMonitorAction : BaseRestHandler() {
private fun validateDocLevelQueryName(monitor: Monitor) {
monitor.inputs.filterIsInstance<DocLevelMonitorInput>().forEach { docLevelMonitorInput ->
docLevelMonitorInput.queries.forEach { dlq ->
if (!isValidName(dlq.name)) {
if (!isValidQueryName(dlq.name)) {
throw IllegalArgumentException(
"Doc level query name may not start with [_, +, -], contain '..', or contain: " +
getInvalidNameChars().replace("\\", "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,35 @@ import org.opensearch.alerting.toJsonString
import org.opensearch.alerting.util.DestinationType
import org.opensearch.client.ResponseException
import org.opensearch.client.WarningFailureException
import org.opensearch.common.UUIDs
import org.opensearch.common.unit.TimeValue
import org.opensearch.common.xcontent.LoggingDeprecationHandler
import org.opensearch.common.xcontent.XContentFactory.jsonBuilder
import org.opensearch.common.xcontent.XContentType
import org.opensearch.common.xcontent.json.JsonXContent.jsonXContent
import org.opensearch.commons.alerting.model.Alert
import org.opensearch.commons.alerting.model.CronSchedule
import org.opensearch.commons.alerting.model.DataSources
import org.opensearch.commons.alerting.model.DocLevelMonitorInput
import org.opensearch.commons.alerting.model.DocLevelQuery
import org.opensearch.commons.alerting.model.DocumentLevelTrigger
import org.opensearch.commons.alerting.model.IntervalSchedule
import org.opensearch.commons.alerting.model.Monitor
import org.opensearch.commons.alerting.model.Monitor.Companion.NO_ID
import org.opensearch.commons.alerting.model.QueryLevelTrigger
import org.opensearch.commons.alerting.model.ScheduledJob
import org.opensearch.commons.alerting.model.SearchInput
import org.opensearch.commons.alerting.util.IndexUtils.Companion.NO_SCHEMA_VERSION
import org.opensearch.commons.alerting.util.string
import org.opensearch.commons.utils.getInvalidNameChars
import org.opensearch.core.common.bytes.BytesReference
import org.opensearch.core.rest.RestStatus
import org.opensearch.core.xcontent.NamedXContentRegistry
import org.opensearch.core.xcontent.ToXContent
import org.opensearch.core.xcontent.XContentBuilder
import org.opensearch.index.query.BoolQueryBuilder
import org.opensearch.index.query.QueryBuilders
import org.opensearch.index.query.RangeQueryBuilder
import org.opensearch.script.Script
import org.opensearch.search.aggregations.AggregationBuilders
import org.opensearch.search.builder.SearchSourceBuilder
Expand All @@ -66,6 +78,8 @@ import org.opensearch.test.rest.OpenSearchRestTestCase
import java.time.Instant
import java.time.ZoneId
import java.time.temporal.ChronoUnit
import java.time.temporal.ChronoUnit.MINUTES
import java.util.*
import java.util.concurrent.TimeUnit

@TestLogging("level:DEBUG", reason = "Debug for tests.")
Expand Down Expand Up @@ -249,6 +263,192 @@ class MonitorRestApiIT : AlertingRestTestCase() {
}
}

fun `test creating a monitor with invalid monitor name`() {
val invalidName = """1~`!@#$%^&*()_+-=[]/<>?;':\""""
val exception = assertThrows(ResponseException::class.java) {
createMonitor(randomQueryLevelMonitor(name = invalidName), refresh = true)
}
val errorResponse = createParser(XContentType.JSON.xContent(), exception.response.entity.content).map()
// Expected error
val expectedError = mapOf(
"error" to mapOf(
"reason" to "Invalid monitor name [$invalidName]. " +
"Monitor Name should be alphanumeric (4-50 chars) starting with letter or underscore",
"caused_by" to mapOf(
"reason" to "java.lang.IllegalArgumentException: Invalid monitor name [$invalidName]. " +
"Monitor Name should be alphanumeric (4-50 chars) starting with letter or underscore",
"type" to "exception"
),
"type" to "alerting_exception",
"root_cause" to listOf(
mapOf(
"reason" to "Invalid monitor name [$invalidName]. " +
"Monitor Name should be alphanumeric (4-50 chars) starting with letter or underscore",
"type" to "alerting_exception",
)
)
),
"status" to 400
)
assertEquals(expectedError, errorResponse)
}

fun `test creating a monitor with invalid trigger name`() {
val invalidName = """1~`!@#$%^&*()_+-=[]/<>?;':\""""
val trigger = randomQueryLevelTrigger(name = invalidName)
val exception = assertThrows(ResponseException::class.java) {
createMonitor(randomQueryLevelMonitor(triggers = listOf(trigger)), refresh = true)
}
val errorResponse = createParser(XContentType.JSON.xContent(), exception.response.entity.content).map()
// Expected error
val expectedError = mapOf(
"error" to mapOf(
"reason" to "Invalid trigger name [$invalidName]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore",
"caused_by" to mapOf(
"reason" to "java.lang.IllegalArgumentException: Invalid trigger name [$invalidName]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore",
"type" to "exception"
),
"type" to "alerting_exception",
"root_cause" to listOf(
mapOf(
"reason" to "Invalid trigger name [$invalidName]. " +
"Trigger Name should be alphanumeric (4-50 chars) starting with letter or underscore",
"type" to "alerting_exception",
)
)
),
"status" to 400
)
assertEquals(expectedError, errorResponse)
}

fun `test creating a monitor with invalid destination id`() {
val invalidId = """1~`!@#$%^&*()_+-=[]/<>?;':\""""
val trigger = randomQueryLevelTrigger(destinationId = invalidId)
val exception = assertThrows(ResponseException::class.java) {
createMonitor(randomQueryLevelMonitor(triggers = listOf(trigger)), refresh = true)
}
val errorResponse = createParser(XContentType.JSON.xContent(), exception.response.entity.content).map()
// Expected error
val expectedError = mapOf(
"error" to mapOf(
"reason" to "Invalid destination ID [$invalidId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only",
"caused_by" to mapOf(
"reason" to "java.lang.IllegalArgumentException: Invalid destination ID [$invalidId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only",
"type" to "exception"
),
"type" to "alerting_exception",
"root_cause" to listOf(
mapOf(
"reason" to "Invalid destination ID [$invalidId]. " +
"Destination ID should be alphanumeric string with +, /, _, or - characters only",
"type" to "alerting_exception",
)
)
),
"status" to 400
)
assertEquals(expectedError, errorResponse)
}

protected fun Monitor.toJsonStringWithType(): String {
val builder = jsonBuilder()
return shuffleXContent(
toXContent(builder, ToXContent.MapParams(mapOf("with_type" to "true")))
).string()
}

protected fun createMonitorUsingAdminClient(monitor: Monitor, refresh: Boolean = true): Monitor {
createAlertingConfigIndex()

val response = indexDocWithAdminClient(
ScheduledJob.SCHEDULED_JOBS_INDEX,
UUIDs.base64UUID(),
monitor.toJsonStringWithType(),
refresh

)

val monitorJson = jsonXContent.createParser(
NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE,
response.entity.content
).map()

return monitor.copy(
id = monitorJson["_id"] as String,
version = (monitorJson["_version"] as Int).toLong()
)
}

fun `test existing monitors do not break with new validations`() {
val invalidValue = "~'`!@#$%^&*()_+-=[]/<>?;':\""
val monitor = Monitor(
id = NO_ID,
version = 1L,
name = invalidValue,
enabled = true,
schedule = IntervalSchedule(
interval = 1,
unit = MINUTES
),
lastUpdateTime = Instant.now(),
enabledTime = Instant.now(),
monitorType = Monitor.MonitorType.QUERY_LEVEL_MONITOR.value,
user = null,
schemaVersion = NO_SCHEMA_VERSION,
inputs = listOf(
SearchInput(
indices = listOf(invalidValue),
query = SearchSourceBuilder().apply {
size(2147483647)
query(
BoolQueryBuilder().filter(
RangeQueryBuilder("order_date")
.gte(invalidValue)
.lte(invalidValue)
)
)
aggregations()
}
)
),
triggers = listOf(
QueryLevelTrigger(
id = UUID.randomUUID().toString(),
name = invalidValue,
severity = invalidValue,
condition = ALWAYS_RUN,
actions = listOf()
)
),
uiMetadata = mapOf(),
dataSources = DataSources(),
deleteQueryIndexInEveryRun = false,
shouldCreateSingleAlertForFindings = false,
owner = "alerting"
)

// Monitor should be created
val createdMonitor = createMonitorUsingAdminClient(monitor)
assertNotNull("Created monitor should have an ID", createdMonitor.id)

// getMonitor should work
val retrievedMonitor = getMonitor(createdMonitor.id)
assertEquals("Retrieved monitor should have the same name", retrievedMonitor.name, invalidValue)

// executeMonitor should work
val executedMonitor = executeMonitor(createdMonitor.id)
assertEquals("Monitor execution should return OK status", RestStatus.OK, executedMonitor.restStatus())

// searchAlerts should work
val alerts = searchAlerts(createdMonitor, ScheduledJob.SCHEDULED_JOBS_INDEX)
assertEquals("No alerts raised, but searchAlerts must work", 0, alerts.size)
}

/*
fun `test creating an AD monitor with detector has monitor backend role`() {
createAnomalyDetectorIndex()
Expand Down

0 comments on commit c97aa53

Please sign in to comment.