Skip to content

Commit

Permalink
Support for ID generation by sequence.
Browse files Browse the repository at this point in the history
Ids can be annotated with @sequence to specify a sequence to pull id values from.

Closes #1923
Original pull request #1955

Signed-off-by: mipo256 <[email protected]>

Some accidential changes removed.
Signed-off-by: schauder <[email protected]>
  • Loading branch information
mipo256 authored and schauder committed Feb 4, 2025
1 parent b51c77b commit d1c9960
Show file tree
Hide file tree
Showing 33 changed files with 716 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,13 @@ public <T> Object[] insert(List<InsertSubject<T>> insertSubjects, Class<T> domai

Assert.notEmpty(insertSubjects, "Batch insert must contain at least one InsertSubject");
SqlIdentifierParameterSource[] sqlParameterSources = insertSubjects.stream()
.map(insertSubject -> sqlParametersFactory.forInsert(insertSubject.getInstance(), domainType,
insertSubject.getIdentifier(), idValueSource))
.map(insertSubject -> sqlParametersFactory.forInsert( //
insertSubject.getInstance(), //
domainType, //
insertSubject.getIdentifier(), //
idValueSource //
) //
) //
.toArray(SqlIdentifierParameterSource[]::new);

String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers());
Expand Down Expand Up @@ -280,7 +285,8 @@ public <T> List<T> findAll(Class<T> domainType) {

@Override
public <T> Stream<T> streamAll(Class<T> domainType) {
return operations.queryForStream(sql(domainType).getFindAll(), new MapSqlParameterSource(), getEntityRowMapper(domainType));
return operations.queryForStream(sql(domainType).getFindAll(), new MapSqlParameterSource(),
getEntityRowMapper(domainType));
}

@Override
Expand Down Expand Up @@ -364,7 +370,8 @@ public <T> List<T> findAll(Class<T> domainType, Sort sort) {

@Override
public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
return operations.queryForStream(sql(domainType).getFindAll(sort), new MapSqlParameterSource(), getEntityRowMapper(domainType));
return operations.queryForStream(sql(domainType).getFindAll(sort), new MapSqlParameterSource(),
getEntityRowMapper(domainType));
}

@Override
Expand Down Expand Up @@ -479,5 +486,4 @@ private Class<?> getBaseType(PersistentPropertyPath<RelationalPersistentProperty

return baseProperty.getOwner().getType();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.springframework.data.jdbc.core.mapping;

import java.util.Map;
import java.util.Optional;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.data.jdbc.repository.config.AbstractJdbcConfiguration;
import org.springframework.data.mapping.PersistentPropertyAccessor;
import org.springframework.data.relational.core.conversion.MutableAggregateChange;
import org.springframework.data.relational.core.dialect.Dialect;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
import org.springframework.data.relational.core.mapping.event.BeforeSaveCallback;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.util.Assert;

/**
* Callback for generating ID via the database sequence. By default, it is registered as a
* bean in {@link AbstractJdbcConfiguration}
*
* @author Mikhail Polivakha
*/
public class IdGeneratingBeforeSaveCallback implements BeforeSaveCallback<Object> {

private static final Log LOG = LogFactory.getLog(IdGeneratingBeforeSaveCallback.class);

private final RelationalMappingContext relationalMappingContext;
private final Dialect dialect;
private final NamedParameterJdbcOperations operations;

public IdGeneratingBeforeSaveCallback(
RelationalMappingContext relationalMappingContext,
Dialect dialect,
NamedParameterJdbcOperations namedParameterJdbcOperations
) {
this.relationalMappingContext = relationalMappingContext;
this.dialect = dialect;
this.operations = namedParameterJdbcOperations;
}

@Override
public Object onBeforeSave(Object aggregate, MutableAggregateChange<Object> aggregateChange) {
Assert.notNull(aggregate, "The aggregate cannot be null at this point");
RelationalPersistentEntity<?> persistentEntity = relationalMappingContext.getPersistentEntity(aggregate.getClass());
Optional<String> idTargetSequence = persistentEntity.getIdTargetSequence();

if (dialect.getIdGeneration().sequencesSupported()) {

if (persistentEntity.getIdProperty() != null) {
idTargetSequence
.map(s -> dialect.getIdGeneration().nextValueFromSequenceSelect(s))
.ifPresent(sql -> {
Long idValue = operations.queryForObject(sql, Map.of(), (rs, rowNum) -> rs.getLong(1));
PersistentPropertyAccessor<Object> propertyAccessor = persistentEntity.getPropertyAccessor(aggregate);
propertyAccessor.setProperty(persistentEntity.getRequiredIdProperty(), idValue);
});
}
} else {
if (idTargetSequence.isPresent()) {
LOG.warn("""
It seems you're trying to insert an aggregate of type '%s' annotated with @TargetSequence, but the problem is RDBMS you're
working with does not support sequences as such. Falling back to identity columns
"""
.formatted(aggregate.getClass().getName())
);
}
}

return aggregate;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.springframework.data.jdbc.core.JdbcAggregateTemplate;
import org.springframework.data.jdbc.core.convert.*;
import org.springframework.data.jdbc.core.dialect.JdbcDialect;
import org.springframework.data.jdbc.core.mapping.IdGeneratingBeforeSaveCallback;
import org.springframework.data.jdbc.core.mapping.JdbcMappingContext;
import org.springframework.data.jdbc.core.mapping.JdbcSimpleTypes;
import org.springframework.data.mapping.model.SimpleTypeHolder;
Expand Down Expand Up @@ -119,6 +120,22 @@ public JdbcMappingContext jdbcMappingContext(Optional<NamingStrategy> namingStra
return mappingContext;
}

/**
* Creates a {@link IdGeneratingBeforeSaveCallback} bean using the configured
* {@link #jdbcMappingContext(Optional, JdbcCustomConversions, RelationalManagedTypes)} and
* {@link #jdbcDialect(NamedParameterJdbcOperations)}.
*
* @return must not be {@literal null}.
*/
@Bean
public IdGeneratingBeforeSaveCallback idGeneratingBeforeSaveCallback(
JdbcMappingContext mappingContext,
NamedParameterJdbcOperations operations,
Dialect dialect
) {
return new IdGeneratingBeforeSaveCallback(mappingContext, dialect, operations);
}

/**
* Creates a {@link RelationalConverter} using the configured
* {@link #jdbcMappingContext(Optional, JdbcCustomConversions, RelationalManagedTypes)}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.springframework.data.convert.WritingConverter;
import org.springframework.data.jdbc.core.mapping.JdbcMappingContext;
import org.springframework.data.relational.core.conversion.IdValueSource;
import org.springframework.data.relational.core.dialect.AnsiDialect;
import org.springframework.data.relational.core.mapping.Column;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.sql.SqlIdentifier;
Expand All @@ -49,7 +48,6 @@ class SqlParametersFactoryTest {
RelationalMappingContext context = new JdbcMappingContext();
RelationResolver relationResolver = mock(RelationResolver.class);
MappingJdbcConverter converter = new MappingJdbcConverter(context, relationResolver);
AnsiDialect dialect = AnsiDialect.INSTANCE;
SqlParametersFactory sqlParametersFactory = new SqlParametersFactory(context, converter);

@Test // DATAJDBC-412
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package org.springframework.data.jdbc.core.mapping;

import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.data.annotation.Id;
import org.springframework.data.relational.core.conversion.MutableAggregateChange;
import org.springframework.data.relational.core.dialect.MySqlDialect;
import org.springframework.data.relational.core.dialect.PostgresDialect;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.core.mapping.Table;
import org.springframework.data.relational.core.mapping.TargetSequence;
import org.springframework.data.relational.core.sql.IdentifierProcessing;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;

/**
* Unit tests for {@link IdGeneratingBeforeSaveCallback}
*
* @author Mikhail Polivakha
*/
class IdGeneratingBeforeSaveCallbackTest {

@Test
void test_mySqlDialect_sequenceGenerationIsNotSupported() {
// given
RelationalMappingContext relationalMappingContext = new RelationalMappingContext();
MySqlDialect mySqlDialect = new MySqlDialect(IdentifierProcessing.NONE);
NamedParameterJdbcOperations operations = mock(NamedParameterJdbcOperations.class);

// and
IdGeneratingBeforeSaveCallback subject = new IdGeneratingBeforeSaveCallback(relationalMappingContext, mySqlDialect, operations);

NoSequenceEntity entity = new NoSequenceEntity();

// when
Object processed = subject.onBeforeSave(entity, MutableAggregateChange.forSave(entity));

// then
Assertions.assertThat(processed).isSameAs(entity);
Assertions.assertThat(processed).usingRecursiveComparison().isEqualTo(entity);
}

@Test
void test_EntityIsNotMarkedWithTargetSequence() {
// given
RelationalMappingContext relationalMappingContext = new RelationalMappingContext();
PostgresDialect mySqlDialect = PostgresDialect.INSTANCE;
NamedParameterJdbcOperations operations = mock(NamedParameterJdbcOperations.class);

// and
IdGeneratingBeforeSaveCallback subject = new IdGeneratingBeforeSaveCallback(relationalMappingContext, mySqlDialect, operations);

NoSequenceEntity entity = new NoSequenceEntity();

// when
Object processed = subject.onBeforeSave(entity, MutableAggregateChange.forSave(entity));

// then
Assertions.assertThat(processed).isSameAs(entity);
Assertions.assertThat(processed).usingRecursiveComparison().isEqualTo(entity);
}

@Test
void test_EntityIdIsPopulatedFromSequence() {
// given
RelationalMappingContext relationalMappingContext = new RelationalMappingContext();
relationalMappingContext.getRequiredPersistentEntity(EntityWithSequence.class);

PostgresDialect mySqlDialect = PostgresDialect.INSTANCE;
NamedParameterJdbcOperations operations = mock(NamedParameterJdbcOperations.class);

// and
long generatedId = 112L;
when(operations.queryForObject(anyString(), anyMap(), any(RowMapper.class))).thenReturn(generatedId);

// and
IdGeneratingBeforeSaveCallback subject = new IdGeneratingBeforeSaveCallback(relationalMappingContext, mySqlDialect, operations);

EntityWithSequence entity = new EntityWithSequence();

// when
Object processed = subject.onBeforeSave(entity, MutableAggregateChange.forSave(entity));

// then
Assertions.assertThat(processed).isSameAs(entity);
Assertions
.assertThat(processed)
.usingRecursiveComparison()
.ignoringFields("id")
.isEqualTo(entity);
Assertions.assertThat(entity.getId()).isEqualTo(generatedId);
}

@Table
static class NoSequenceEntity {

@Id
private Long id;
private Long name;
}

@Table
static class EntityWithSequence {

@Id
@TargetSequence(value = "id_seq", schema = "public")
private Long id;

private Long name;

public Long getId() {
return id;
}
}
}
Loading

0 comments on commit d1c9960

Please sign in to comment.