Streaming Join work
sb2k16 committed Aug 5, 2024
1 parent 0ca6cf0 commit 8cd96ed
Showing 9 changed files with 300 additions and 3 deletions.
@@ -0,0 +1,52 @@
package org.opensearch.dataprepper.plugins.processor.aggregate.actions;

import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.plugins.hasher.IdentificationKeysHasher;
import org.opensearch.dataprepper.plugins.processor.aggregate.AggregateAction;
import org.opensearch.dataprepper.plugins.processor.aggregate.AggregateActionInput;
import org.opensearch.dataprepper.plugins.processor.aggregate.AggregateActionOutput;
import org.opensearch.dataprepper.plugins.processor.aggregate.AggregateActionResponse;
import org.opensearch.dataprepper.plugins.processor.aggregate.GroupState;

import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

@DataPrepperPlugin(name = "groupby", pluginType = AggregateAction.class, pluginConfigurationType = GroupByAggregateActionConfig.class)
public class GroupByAggregateAction implements AggregateAction {
    static final String LAST_RECEIVED_TIME_KEY = "last_received_time";
    static final String SHOULD_CONCLUDE_CHECK_SET_KEY = "should_conclude_check_set";
    static final String EVENTS_KEY = "events";
    static final String ERROR_STATUS_KEY = "error_status";
    private final Random random;
    private final IdentificationKeysHasher identificationKeysHasher;

    @DataPrepperPluginConstructor
    public GroupByAggregateAction(final GroupByAggregateActionConfig groupByAggregateActionConfig) {
        this.identificationKeysHasher = new IdentificationKeysHasher(groupByAggregateActionConfig.getIdentificationKeys());
        this.random = new Random();
    }

    @Override
    public AggregateActionResponse handleEvent(final Event event, final AggregateActionInput aggregateActionInput) {
        final GroupState groupState = aggregateActionInput.getGroupState();
        // Accumulate every event for this group until the group concludes.
        List<Event> events = (List<Event>) groupState.getOrDefault(EVENTS_KEY, new ArrayList<>());
        events.add(event);
        final IdentificationKeysHasher.IdentificationKeysMap identificationKeysMap = identificationKeysHasher.createIdentificationKeysMapFromEvent(event);
        // Tag the event with a partition key derived from its identification-key values.
        event.getMetadata().setAttribute("partition_key", identificationKeysMap.hashCode());
        groupState.put(EVENTS_KEY, events);
        groupState.put(LAST_RECEIVED_TIME_KEY, Instant.now());
        return AggregateActionResponse.nullEventResponse();
    }

    @Override
    public AggregateActionOutput concludeGroup(final AggregateActionInput aggregateActionInput) {
        GroupState groupState = aggregateActionInput.getGroupState();
        // Emit all buffered events for the group; an empty list if none arrived.
        return new AggregateActionOutput((List<Event>) groupState.getOrDefault(EVENTS_KEY, List.of()));
    }
}
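Taken together, handleEvent buffers each incoming event into the group's shared state and concludeGroup flushes the buffer in one output. A minimal self-contained sketch of that accumulation pattern, assuming a plain Map in place of Data Prepper's GroupState (all names below are illustrative, not plugin API):

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustrative stand-in for the per-group accumulation done by GroupByAggregateAction.
// GroupState is modeled as a plain Map; events are modeled as Maps as well.
public class GroupByAccumulationSketch {
    public static void main(String[] args) {
        Map<String, Object> groupState = new HashMap<>();

        // Two events arriving for the same group (same identification-key hash).
        for (String msg : new String[]{"login", "logout"}) {
            @SuppressWarnings("unchecked")
            List<Map<String, Object>> events =
                    (List<Map<String, Object>>) groupState.getOrDefault("events", new ArrayList<>());
            events.add(Map.of("message", msg));
            groupState.put("events", events);
        }

        // concludeGroup would emit the buffered list as a single output.
        System.out.println(groupState.get("events")); // [{message=login}, {message=logout}]
    }
}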
@@ -0,0 +1,16 @@
package org.opensearch.dataprepper.plugins.processor.aggregate.actions;

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.NotEmpty;

import java.util.List;

public class GroupByAggregateActionConfig {
    @JsonProperty("identification_keys")
    @NotEmpty
    private List<String> identificationKeys;

    public List<String> getIdentificationKeys() {
        return identificationKeys;
    }
}
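The config class binds from the pipeline definition via Jackson. A quick sketch of that binding in isolation, assuming an ObjectMapper and the config class on the classpath (in a real pipeline, Data Prepper's plugin framework performs this deserialization):

import com.fasterxml.jackson.databind.ObjectMapper;

public class GroupByConfigBindingSketch {
    public static void main(String[] args) throws Exception {
        // Equivalent of the pipeline setting: identification_keys: ["sourceIp", "destinationIp"]
        String json = "{\"identification_keys\": [\"sourceIp\", \"destinationIp\"]}";
        ObjectMapper mapper = new ObjectMapper();
        GroupByAggregateActionConfig config =
                mapper.readValue(json, GroupByAggregateActionConfig.class);
        System.out.println(config.getIdentificationKeys()); // [sourceIp, destinationIp]
    }
}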
@@ -0,0 +1,117 @@
package org.opensearch.dataprepper.plugins.processor.aggregate.actions;

import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventType;
import org.opensearch.dataprepper.model.event.JacksonEvent;
import org.opensearch.dataprepper.plugins.hasher.IdentificationKeysHasher;
import org.opensearch.dataprepper.plugins.processor.aggregate.AggregateAction;
import org.opensearch.dataprepper.plugins.processor.aggregate.AggregateActionInput;
import org.opensearch.dataprepper.plugins.processor.aggregate.AggregateActionOutput;
import org.opensearch.dataprepper.plugins.processor.aggregate.AggregateActionResponse;
import org.opensearch.dataprepper.plugins.processor.aggregate.GroupState;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;

@DataPrepperPlugin(name = "join", pluginType = AggregateAction.class, pluginConfigurationType = JoinAggregateActionConfig.class)
public class JoinAggregateAction implements AggregateAction {
    private static final String JOINED_KEY = "JOINED_KEY";
    private static final String STREAMS_KEY = "STREAMS_KEY";
    private final IdentificationKeysHasher identificationKeysHasher;
    private final Random random;
    private final JoinAggregateActionConfig joinAggregateActionConfig;

    @DataPrepperPluginConstructor
    public JoinAggregateAction(final JoinAggregateActionConfig joinAggregateActionConfig) {
        this.identificationKeysHasher = new IdentificationKeysHasher(joinAggregateActionConfig.getIdentificationKeys());
        this.joinAggregateActionConfig = joinAggregateActionConfig;
        this.random = new Random();
    }

    @Override
    public AggregateActionResponse handleEvent(final Event event, final AggregateActionInput aggregateActionInput) {
        final GroupState groupState = aggregateActionInput.getGroupState();

        // Bucket the event first by its source stream, then by its identification-key values.
        String stream = event.get("stream", String.class);
        Map<String, List<Object>> streamMap = (Map<String, List<Object>>) groupState.getOrDefault(stream, new HashMap<>());

        IdentificationKeysHasher.IdentificationKeysMap keysMap = identificationKeysHasher.createIdentificationKeysMapFromEvent(event);
        streamMap.computeIfAbsent(keysMap.toString(), k -> new ArrayList<>()).add(event);

        groupState.put(stream, streamMap);

        return AggregateActionResponse.nullEventResponse();
    }

    @Override
    public AggregateActionOutput concludeGroup(final AggregateActionInput aggregateActionInput) {
        GroupState groupState = aggregateActionInput.getGroupState();

        // Inner-join semantics: if any configured stream saw no events, emit nothing.
        List<Map<String, List<Object>>> listOfAllStreams = new ArrayList<>();
        for (String stream : joinAggregateActionConfig.getStreams()) {
            Map<String, List<Object>> streamMap = (Map<String, List<Object>>) groupState.getOrDefault(stream, new HashMap<>());
            if (streamMap.isEmpty()) {
                return new AggregateActionOutput(List.of());
            }
            listOfAllStreams.add(streamMap);
        }

        // Collect the per-stream event lists under each join key across all streams.
        List<Event> events = new ArrayList<>();
        Map<String, List<List<Object>>> aggregatedMap = new HashMap<>();
        for (Map<String, List<Object>> map : listOfAllStreams) {
            for (Map.Entry<String, List<Object>> entry : map.entrySet()) {
                aggregatedMap.computeIfAbsent(entry.getKey(), k -> new ArrayList<>()).add(entry.getValue());
            }
        }

        // For each join key, merge the per-stream lists into joined events.
        Map<String, List<Object>> mergedJoinedMap = aggregatedMap.entrySet().stream()
                .collect(Collectors.toMap(
                        Map.Entry::getKey,
                        entry -> mergeAndJoin(entry.getValue())
                ));

        for (Map.Entry<String, List<Object>> entry : mergedJoinedMap.entrySet()) {
            List<Object> records = entry.getValue();
            records.forEach(record -> events.add((Event) record));
        }

        return new AggregateActionOutput(events);
    }

    private static List<Object> mergeAndJoin(List<List<Object>> lists) {
        if (lists.isEmpty()) {
            return Collections.emptyList();
        }

        // Start with the first list
        List<Object> result = new ArrayList<>(lists.get(0));

        // Compute the Cartesian product across all lists, merging event fields pairwise;
        // on key collisions the later stream's value wins (m1.putAll(m2)).
        for (int i = 1; i < lists.size(); i++) {
            List<Object> currentList = lists.get(i);
            result = result.stream()
                    .flatMap(v1 -> currentList.stream()
                            .map(v2 -> {
                                Event e1 = (Event) v1;
                                Event e2 = (Event) v2;
                                Map<String, Object> m1 = e1.toMap();
                                Map<String, Object> m2 = e2.toMap();
                                m1.putAll(m2);
                                return JacksonEvent.builder().withEventType(EventType.DOCUMENT.toString()).withData(m1).build();
                            }))
                    .collect(Collectors.toList());
        }

        return result;
    }
}
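The join itself is a per-key Cartesian product: for a shared key, two events from one stream and three from another yield six merged events. A self-contained sketch of the same merge over plain Maps (illustrative only; the action operates on Data Prepper Event objects):

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CartesianJoinSketch {
    public static void main(String[] args) {
        // Events from two streams sharing the same identification-key value.
        List<Map<String, Object>> streamA = List.of(
                Map.of("id", 1, "a", "x"), Map.of("id", 1, "a", "y"));
        List<Map<String, Object>> streamB = List.of(
                Map.of("id", 1, "b", "p"), Map.of("id", 1, "b", "q"), Map.of("id", 1, "b", "r"));

        List<Map<String, Object>> joined = new ArrayList<>();
        for (Map<String, Object> left : streamA) {
            for (Map<String, Object> right : streamB) {
                Map<String, Object> merged = new HashMap<>(left);
                merged.putAll(right); // later stream wins on key collisions
                joined.add(merged);
            }
        }
        System.out.println(joined.size()); // 6 = 2 x 3
    }
}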
@@ -0,0 +1,32 @@
package org.opensearch.dataprepper.plugins.processor.aggregate.actions;

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.constraints.NotEmpty;

import java.util.List;

public class JoinAggregateActionConfig {
    @JsonProperty("identification_keys")
    @NotEmpty
    private List<String> identificationKeys;

    @JsonProperty("streams")
    @NotEmpty
    private List<String> streams;

    @JsonProperty("primary_key")
    @NotEmpty
    private String primaryKey;

    public List<String> getIdentificationKeys() {
        return identificationKeys;
    }

    public List<String> getStreams() {
        return streams;
    }

    public String getPrimaryKey() {
        return primaryKey;
    }
}
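Because the fields carry jakarta @NotEmpty constraints, an incomplete config should be rejected before the action ever runs. A hedged sketch using the standard jakarta.validation API, assuming a validator provider such as hibernate-validator is on the classpath:

import jakarta.validation.ConstraintViolation;
import jakarta.validation.Validation;
import jakarta.validation.Validator;

import java.util.Set;

public class JoinConfigValidationSketch {
    public static void main(String[] args) {
        Validator validator = Validation.buildDefaultValidatorFactory().getValidator();
        // An empty config violates @NotEmpty on all three fields.
        JoinAggregateActionConfig config = new JoinAggregateActionConfig();
        Set<ConstraintViolation<JoinAggregateActionConfig>> violations = validator.validate(config);
        violations.forEach(v -> System.out.println(v.getPropertyPath() + ": " + v.getMessage()));
    }
}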
8 changes: 7 additions & 1 deletion data-prepper-plugins/kinesis-source/build.gradle
@@ -23,7 +23,13 @@ dependencies {
    implementation 'software.amazon.kinesis:amazon-kinesis-client:2.6.0'
    compileOnly 'org.projectlombok:lombok:1.18.20'
    annotationProcessor 'org.projectlombok:lombok:1.18.20'

    implementation("com.amazonaws:aws-encryption-sdk-java:3.0.0")
    implementation("software.amazon.cryptography:aws-cryptographic-material-providers:1.0.2")
    implementation(platform("software.amazon.awssdk:bom:2.20.91"))
    implementation("software.amazon.awssdk:kms")
    implementation("software.amazon.awssdk:dynamodb")
    // The following is optional:
    implementation("com.amazonaws:aws-java-sdk:1.12.394")
    testImplementation platform('org.junit:junit-bom:5.9.1')
    testImplementation 'org.junit.jupiter:junit-jupiter'
}
@@ -19,6 +19,9 @@ public class KinesisStreamConfig {
@JsonProperty("stream_arn")
private String arn;

@JsonProperty("kms_key")
private String kmsKey;

@JsonProperty("initial_position")
private InitialPositionInStream initialPosition = InitialPositionInStream.LATEST;

@@ -3,7 +3,9 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.opensearch.dataprepper.model.codec.InputCodec;
import org.opensearch.dataprepper.model.event.DefaultEventMetadata;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventMetadata;
import org.opensearch.dataprepper.model.log.JacksonLog;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.source.kinesis.KinesisSource;
@@ -14,6 +16,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Map;
import java.util.function.Consumer;

public class DefaultCodec implements InputCodec {
@@ -5,20 +5,29 @@
import com.amazonaws.services.schemaregistry.deserializers.GlueSchemaRegistryDeserializer;
import com.amazonaws.services.schemaregistry.deserializers.GlueSchemaRegistryDeserializerFactory;
import com.amazonaws.services.schemaregistry.deserializers.GlueSchemaRegistryDeserializerImpl;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.micrometer.core.instrument.Counter;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.codec.InputCodec;
import org.opensearch.dataprepper.model.event.DefaultEventMetadata;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventType;
import org.opensearch.dataprepper.model.event.JacksonEvent;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.source.kinesis.KinesisSource;
import org.opensearch.dataprepper.plugins.source.kinesis.KinesisSourceConfig;
import org.opensearch.dataprepper.plugins.source.kinesis.KinesisStreamConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.glue.model.DataFormat;
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.kinesis.common.StreamIdentifier;
import software.amazon.kinesis.exceptions.InvalidStateException;
import software.amazon.kinesis.exceptions.ShutdownException;
import software.amazon.kinesis.exceptions.ThrottlingException;
@@ -31,20 +40,25 @@
import software.amazon.kinesis.processor.ShardRecordProcessor;
import software.amazon.kinesis.retrieval.KinesisClientRecord;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

public class KinesisRecordProcessor implements ShardRecordProcessor {
    private static final Logger LOG = LoggerFactory.getLogger(KinesisRecordProcessor.class);

    private static final ObjectMapper mapper = new ObjectMapper();
    // Checkpointing interval
    private static final int MINIMAL_CHECKPOINT_INTERVAL_MILLIS = 2 * 60 * 1000; // 2 minutes
    private final boolean acknowledgementsEnabled;
    private final BufferAccumulator<Record<Event>> bufferAccumulator;
    private final StreamIdentifier streamIdentifier;
    private String kmsKeyId;
    private String kinesisShardId;
    private final Buffer<Record<Event>> buffer;
    private final InputCodec codec;
@@ -58,6 +72,7 @@ public class KinesisRecordProcessor {
    static final Duration ACKNOWLEDGEMENT_SET_TIMEOUT = Duration.ofSeconds(20);
    private final Counter acknowledgementSetCallbackCounter;
    static final String ACKNOWLEDGEMENT_SET_CALLBACK_METRIC_NAME = "acknowledgementSetCallbackCounter";
    private KmsClient kmsClient;

    public KinesisRecordProcessor(Buffer<Record<Event>> buffer, InputCodec codec, KinesisSourceConfig kinesisSourceConfig, final AcknowledgementSetManager acknowledgementSetManager, final PluginMetrics pluginMetrics) {
        this.buffer = buffer;
@@ -72,6 +87,37 @@ public KinesisRecordProcessor(Buffer<Record<Event>> buffer, InputCodec codec, Ki
        this.acknowledgementsEnabled = kinesisSourceConfig.isAcknowledgments();
        this.bufferAccumulator = BufferAccumulator.create(buffer, 1, Duration.ofSeconds(1));
        acknowledgementSetCallbackCounter = pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLBACK_METRIC_NAME);
        this.streamIdentifier = null;
        this.kmsClient = null;
        this.kmsKeyId = null;
    }

    public KinesisRecordProcessor(Buffer<Record<Event>> buffer, InputCodec codec, KinesisSourceConfig kinesisSourceConfig, final AcknowledgementSetManager acknowledgementSetManager, final PluginMetrics pluginMetrics, final StreamIdentifier streamIdentifier) {
        this.buffer = buffer;
        this.codec = codec;
        this.enableCheckpoint = kinesisSourceConfig.isEnableCheckPoint();
        this.bufferTimeoutMillis = (int) kinesisSourceConfig.getBufferTimeout().toMillis();
        this.acknowledgementSetManager = acknowledgementSetManager;
        this.streamIdentifier = streamIdentifier;
        GlueSchemaRegistryConfiguration gsrConfig = new GlueSchemaRegistryConfiguration("us-east-1");
        glueSchemaRegistryDeserializer = new GlueSchemaRegistryDeserializerImpl(kinesisSourceConfig.getAwsAuthenticationOptions().authenticateAwsConfiguration(), gsrConfig);
        GlueSchemaRegistryDeserializerFactory glueSchemaRegistryDeserializerFactory = new GlueSchemaRegistryDeserializerFactory();
        gsrDataFormatDeserializer = glueSchemaRegistryDeserializerFactory.getInstance(dataFormat, gsrConfig);
        this.acknowledgementsEnabled = kinesisSourceConfig.isAcknowledgments();
        this.bufferAccumulator = BufferAccumulator.create(buffer, 1, Duration.ofSeconds(1));
        acknowledgementSetCallbackCounter = pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLBACK_METRIC_NAME);
        this.kmsClient = null;
        this.kmsKeyId = null;
        // Remember the KMS client and key configured for this processor's stream;
        // neither is used by processRecord below in this change.
        for (KinesisStreamConfig streamConfig : kinesisSourceConfig.getStreams()) {
            if (streamConfig.getName().equals(streamIdentifier.streamName())) {
                this.kmsClient = KmsClient.builder()
                        .credentialsProvider(kinesisSourceConfig.getAwsAuthenticationOptions().authenticateAwsConfiguration())
                        .region(kinesisSourceConfig.getAwsAuthenticationOptions().getAwsRegion())
                        .build();
                this.kmsKeyId = streamConfig.getKmsKey();
                break;
            }
        }
    }
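Note that the KmsClient and kmsKeyId are initialized here but never invoked in this diff. A hedged sketch of what the eventual decryption step might look like with the AWS SDK v2 KMS Decrypt API (the helper below is an assumption, not part of this commit):

import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.awssdk.services.kms.model.DecryptRequest;
import software.amazon.awssdk.services.kms.model.DecryptResponse;

public class KmsDecryptSketch {
    // Hypothetical helper: decrypt a Kinesis record payload with the stream's KMS key.
    static byte[] decryptPayload(KmsClient kmsClient, String kmsKeyId, byte[] ciphertext) {
        DecryptRequest request = DecryptRequest.builder()
                .keyId(kmsKeyId)
                .ciphertextBlob(SdkBytes.fromByteArray(ciphertext))
                .build();
        DecryptResponse response = kmsClient.decrypt(request);
        return response.plaintext().asByteArray();
    }
}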


@@ -146,7 +192,23 @@ private void processRecord(KinesisClientRecord record, Consumer<Record<Event>> e
        record.data().get(arr);
        ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arr);
        // Invoke codec
        //codec.parse(byteArrayInputStream, eventConsumer);

        // Parse the payload as newline-delimited JSON instead of delegating to the codec,
        // tagging each event with the name of its source stream (consumed by the join action).
        final BufferedReader reader = new BufferedReader(new InputStreamReader(byteArrayInputStream));
        String line;
        while ((line = reader.readLine()) != null) {
            LOG.debug("Codec to parse line message: {}", line);
            JsonNode jsonNode = mapper.readTree(line);
            Map<String, Object> rec = mapper.convertValue(jsonNode, new TypeReference<Map<String, Object>>() {});
            rec.put("stream", streamIdentifier.streamName());
            Record<Event> eventRecord = new Record<>(JacksonEvent.builder()
                    .withEventMetadata(DefaultEventMetadata.builder()
                            .withEventType(EventType.DOCUMENT.toString())
                            .withAttributes(Map.of("stream", streamIdentifier.streamName()))
                            .build())
                    .withData(rec).build());
            eventConsumer.accept(eventRecord);
        }
    }
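For a record whose payload contains two newline-delimited JSON documents, the loop above emits two events, each tagged with its source stream. A standalone sketch of just that parsing step, using plain Jackson and a made-up payload and stream name (no Data Prepper types):

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.Map;

public class NdjsonParseSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        String payload = "{\"orderId\":1}\n{\"orderId\":2}";

        BufferedReader reader = new BufferedReader(new StringReader(payload));
        String line;
        while ((line = reader.readLine()) != null) {
            Map<String, Object> rec = mapper.readValue(line, new TypeReference<Map<String, Object>>() {});
            rec.put("stream", "orders"); // stream tag consumed downstream by the join action
            System.out.println(rec);     // {orderId=1, stream=orders} then {orderId=2, stream=orders}
        }
    }
}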

    @Override