From 51596456c7ff08db0a4728a30d89dd5e8f6952ed Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Mon, 15 Jan 2024 18:39:41 +0200 Subject: [PATCH] New rails * Fixed parameter validation * Introduced TransformerContext that united parameters that were passed to transformer during initialization procedure and initialized transformer * Updated RandomDate transformer implementation, adapted to the dynamic parametrization * Fixed static and dynamic parameters initialization * Deprecated value-parameters methods in ParameterDefinition --- internal/db/postgres/context/table.go | 2 +- internal/db/postgres/context/transformers.go | 2 +- internal/db/postgres/dump/table.go | 6 +- internal/db/postgres/dumpers/table.go | 2 +- .../dumpers/transformation_pipeline.go | 43 ++++--- .../db/postgres/transformers/random_date.go | 120 +++++++++++------- .../postgres/transformers/utils/definition.go | 27 +++- .../transformers/utils/schema_validation.go | 27 ++-- internal/db/postgres/validate.go | 6 +- pkg/toolkit/driver.go | 4 + pkg/toolkit/dynamic_parameter.go | 26 +++- pkg/toolkit/parameter.go | 42 +++--- pkg/toolkit/static_parameter.go | 30 +++-- pkg/toolkit/types.go | 33 +++-- 14 files changed, 238 insertions(+), 132 deletions(-) diff --git a/internal/db/postgres/context/table.go b/internal/db/postgres/context/table.go index 0b827c7b..f353b934 100644 --- a/internal/db/postgres/context/table.go +++ b/internal/db/postgres/context/table.go @@ -108,7 +108,7 @@ func validateAndBuildTablesConfig( return nil, warnings, err } warnings = append(warnings, initWarnings...) - table.Transformers = append(table.Transformers, transformer) + table.TransformersContext = append(table.TransformersContext, transformer) } } diff --git a/internal/db/postgres/context/transformers.go b/internal/db/postgres/context/transformers.go index 5430ffab..0c611ad5 100644 --- a/internal/db/postgres/context/transformers.go +++ b/internal/db/postgres/context/transformers.go @@ -28,7 +28,7 @@ func initTransformer( c *domains.TransformerConfig, r *transformersUtils.TransformerRegistry, types []*toolkit.Type, -) (transformersUtils.Transformer, toolkit.ValidationWarnings, error) { +) (*transformersUtils.TransformerContext, toolkit.ValidationWarnings, error) { var totalWarnings toolkit.ValidationWarnings td, ok := r.Get(c.Name) if !ok { diff --git a/internal/db/postgres/dump/table.go b/internal/db/postgres/dump/table.go index de1358a9..7aacd256 100644 --- a/internal/db/postgres/dump/table.go +++ b/internal/db/postgres/dump/table.go @@ -35,7 +35,7 @@ type Table struct { RootPtName string LoadViaPartitionRoot bool RootOid toolkit.Oid - Transformers []utils.Transformer + TransformersContext []*utils.TransformerContext Dependencies []int32 DumpId int32 OriginalSize int64 @@ -47,8 +47,8 @@ type Table struct { } func (t *Table) HasCustomTransformer() bool { - return slices.ContainsFunc(t.Transformers, func(transformer utils.Transformer) bool { - _, ok := transformer.(*custom.CmdTransformer) + return slices.ContainsFunc(t.TransformersContext, func(transformer *utils.TransformerContext) bool { + _, ok := transformer.Transformer.(*custom.CmdTransformer) return ok }) } diff --git a/internal/db/postgres/dumpers/table.go b/internal/db/postgres/dumpers/table.go index 5dd69233..75508015 100644 --- a/internal/db/postgres/dumpers/table.go +++ b/internal/db/postgres/dumpers/table.go @@ -71,7 +71,7 @@ func (td *TableDumper) Execute(ctx context.Context, tx pgx.Tx, st storages.Stora func() error { var pipeline Pipeliner var err error - if len(td.table.Transformers) > 0 { + if len(td.table.TransformersContext) > 0 { if td.validate { pipeline, err = NewValidationPipeline(gtx, eg, td.table, w, td.validateWithOriginal) if err != nil { diff --git a/internal/db/postgres/dumpers/transformation_pipeline.go b/internal/db/postgres/dumpers/transformation_pipeline.go index 71c0db4c..48dffd3b 100644 --- a/internal/db/postgres/dumpers/transformation_pipeline.go +++ b/internal/db/postgres/dumpers/transformation_pipeline.go @@ -53,32 +53,39 @@ func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *d // TODO: Fix this hint. Async execution cannot be performed with template record because it is unsafe. // For overcoming it - implement sequence transformer wrapper - that wraps internal (non CMD) transformers - hasTemplateRecordTransformer := slices.ContainsFunc(table.Transformers, func(transformer utils.Transformer) bool { - _, ok := transformer.(*transformers.TemplateRecordTransformer) + hasTemplateRecordTransformer := slices.ContainsFunc(table.TransformersContext, func(transformer *utils.TransformerContext) bool { + _, ok := transformer.Transformer.(*transformers.TemplateRecordTransformer) return ok }) - if !hasTemplateRecordTransformer && table.HasCustomTransformer() && len(table.Transformers) > 1 { + if !hasTemplateRecordTransformer && table.HasCustomTransformer() && len(table.TransformersContext) > 1 { isAsync = true tw := NewTransformationWindow(ctx, eg) tws = append(tws, tw) - for _, t := range table.Transformers { - if !tw.TryAdd(table, t) { + for _, t := range table.TransformersContext { + if !tw.TryAdd(table, t.Transformer) { tw = NewTransformationWindow(ctx, eg) tws = append(tws, tw) - tw.TryAdd(table, t) + tw.TryAdd(table, t.Transformer) } } } + record := toolkit.NewRecord(table.Driver) + + for _, tc := range table.TransformersContext { + for _, dp := range tc.DynamicParameters { + dp.SetRecord(record) + } + } + tp := &TransformationPipeline{ - table: table, - //buf: bytes.NewBuffer(nil), + table: table, w: w, row: pgcopy.NewRow(len(table.Columns)), transformationWindows: tws, isAsync: true, - record: toolkit.NewRecord(table.Driver), + record: record, } var tf TransformationFunc = tp.TransformSync @@ -93,9 +100,9 @@ func NewTransformationPipeline(ctx context.Context, eg *errgroup.Group, table *d func (tp *TransformationPipeline) Init(ctx context.Context) error { var lastInitErr error var idx int - var t utils.Transformer - for idx, t = range tp.table.Transformers { - if err := t.Init(ctx); err != nil { + var t *utils.TransformerContext + for idx, t = range tp.table.TransformersContext { + if err := t.Transformer.Init(ctx); err != nil { lastInitErr = err log.Warn().Err(err).Msg("error initializing transformer") } @@ -103,8 +110,8 @@ func (tp *TransformationPipeline) Init(ctx context.Context) error { if lastInitErr != nil { lastInitialized := idx - for _, t = range tp.table.Transformers[:lastInitialized] { - if err := t.Done(ctx); err != nil { + for _, t = range tp.table.TransformersContext[:lastInitialized] { + if err := t.Transformer.Done(ctx); err != nil { log.Warn().Err(err).Msg("error terminating previously initialized transformer") } } @@ -123,8 +130,8 @@ func (tp *TransformationPipeline) Init(ctx context.Context) error { func (tp *TransformationPipeline) TransformSync(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { var err error - for _, t := range tp.table.Transformers { - _, err = t.Transform(ctx, r) + for _, t := range tp.table.TransformersContext { + _, err = t.Transformer.Transform(ctx, r) if err != nil { return nil, NewDumpError(tp.table.Schema, tp.table.Name, tp.line, err) } @@ -187,8 +194,8 @@ func (tp *TransformationPipeline) CompleteDump() (err error) { func (tp *TransformationPipeline) Done(ctx context.Context) error { var lastErr error - for _, t := range tp.table.Transformers { - if err := t.Done(ctx); err != nil { + for _, t := range tp.table.TransformersContext { + if err := t.Transformer.Done(ctx); err != nil { lastErr = err log.Warn().Err(err).Msg("error terminating initialized transformer") } diff --git a/internal/db/postgres/transformers/random_date.go b/internal/db/postgres/transformers/random_date.go index ec952ad0..7d31ece6 100644 --- a/internal/db/postgres/transformers/random_date.go +++ b/internal/db/postgres/transformers/random_date.go @@ -16,7 +16,6 @@ package transformers import ( "context" - "errors" "fmt" "math/rand" "strings" @@ -48,13 +47,15 @@ var RandomDateTransformerDefinition = utils.NewTransformerDefinition( "min", "min threshold date (and/or time) of random value", ).SetRequired(true). - SetLinkParameter("column"), + SetLinkParameter("column"). + SetDynamicModeSupport(true), toolkit.MustNewParameterDefinition( "max", "max threshold date (and/or time) of random value", ).SetRequired(true). - SetLinkParameter("column"), + SetLinkParameter("column"). + SetDynamicModeSupport(true), toolkit.MustNewParameterDefinition( "truncate", @@ -82,22 +83,30 @@ type RandomDateTransformer struct { columnIdx int rand *rand.Rand generate dateGeneratorFunc - min *time.Time - max *time.Time truncate string keepNull bool - delta *int64 affectedColumns map[int]string + + columnParam toolkit.Parameterizer + maxParam toolkit.Parameterizer + minParam toolkit.Parameterizer + truncateParam toolkit.Parameterizer + keepNullParam toolkit.Parameterizer } func NewRandomDateTransformer(ctx context.Context, driver *toolkit.Driver, parameters map[string]toolkit.Parameterizer) (utils.Transformer, toolkit.ValidationWarnings, error) { + + columnParam := parameters["column"] + maxParam := parameters["max"] + minParam := parameters["min"] + truncateParam := parameters["truncate"] + keepNullParam := parameters["keep_null"] + var columnName, truncate string - var minTime, maxTime time.Time var generator dateGeneratorFunc = generateRandomTime var keepNull bool - p := parameters["column"] - if _, err := p.Scan(&columnName); err != nil { + if _, err := columnParam.Scan(&columnName); err != nil { return nil, nil, fmt.Errorf(`unable to scan "column" param: %w`, err) } @@ -108,34 +117,32 @@ func NewRandomDateTransformer(ctx context.Context, driver *toolkit.Driver, param affectedColumns := make(map[int]string) affectedColumns[idx] = columnName - p = parameters["min"] - v, err := p.Value() - if err != nil { - return nil, nil, fmt.Errorf(`error parsing "min" parameter: %w`, err) - } - minTime, ok = v.(time.Time) - if !ok { - return nil, nil, errors.New(`unexpected type for "min" parameter`) - } - - p = parameters["max"] - v, err = p.Value() - if err != nil { - return nil, nil, fmt.Errorf(`error parsing "max" parameter: %w`, err) - } - - maxTime, ok = v.(time.Time) - if !ok { - return nil, nil, errors.New(`unexpected type for "max" parameter`) - } - - p = parameters["keep_null"] - if _, err := p.Scan(&keepNull); err != nil { + //p = parameters["min"] + //v, err := p.Value() + //if err != nil { + // return nil, nil, fmt.Errorf(`error parsing "min" parameter: %w`, err) + //} + //minTime, ok = v.(time.Time) + //if !ok { + // return nil, nil, errors.New(`unexpected type for "min" parameter`) + //} + // + //p = parameters["max"] + //v, err = p.Value() + //if err != nil { + // return nil, nil, fmt.Errorf(`error parsing "max" parameter: %w`, err) + //} + // + //maxTime, ok = v.(time.Time) + //if !ok { + // return nil, nil, errors.New(`unexpected type for "max" parameter`) + //} + + if _, err := keepNullParam.Scan(&keepNull); err != nil { return nil, nil, fmt.Errorf(`unable to scan "keep_null" param: %w`, err) } - p = parameters["truncate"] - if _, err := p.Scan(&truncate); err != nil { + if _, err := truncateParam.Scan(&truncate); err != nil { return nil, nil, fmt.Errorf(`unable to scan "truncate" param: %w`, err) } @@ -143,26 +150,20 @@ func NewRandomDateTransformer(ctx context.Context, driver *toolkit.Driver, param generator = generateRandomTimeTruncate } - if minTime.After(maxTime) { - return nil, toolkit.ValidationWarnings{ - toolkit.NewValidationWarning(). - AddMeta("max", maxTime). - AddMeta("min", minTime). - SetMsg("max value must be greater than min"), - }, nil - } - delta := int64(maxTime.Sub(minTime)) return &RandomDateTransformer{ keepNull: keepNull, truncate: truncate, columnName: columnName, columnIdx: idx, - min: &minTime, - max: &maxTime, generate: generator, rand: rand.New(rand.NewSource(time.Now().UnixMicro())), affectedColumns: affectedColumns, - delta: &delta, + + columnParam: columnParam, + minParam: minParam, + maxParam: maxParam, + truncateParam: truncateParam, + keepNullParam: keepNullParam, }, nil, nil } @@ -180,6 +181,29 @@ func (rdt *RandomDateTransformer) Done(ctx context.Context) error { } func (rdt *RandomDateTransformer) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { + + minTime := &time.Time{} + empty, err := rdt.minParam.Scan(minTime) + if err != nil { + return nil, fmt.Errorf(`error getting "min" parameter value: %w`, err) + } + if empty { + return nil, fmt.Errorf("parameter \"min\" cannot be empty") + } + + maxTime := &time.Time{} + empty, err = rdt.maxParam.Scan(maxTime) + if err != nil { + return nil, fmt.Errorf(`error getting "max" parameter value: %w`, err) + } + if empty { + return nil, fmt.Errorf("parameter \"max\" cannot be empty") + } + + if minTime.After(*maxTime) { + return nil, fmt.Errorf("max value must be greater than min: got min = %s max = %s", minTime.String(), maxTime.String()) + } + valAny, err := r.GetRawColumnValueByIdx(rdt.columnIdx) if err != nil { return nil, fmt.Errorf("unable to scan value: %w", err) @@ -188,7 +212,9 @@ func (rdt *RandomDateTransformer) Transform(ctx context.Context, r *toolkit.Reco return r, nil } - res := rdt.generate(rdt.rand, rdt.min, rdt.delta, &rdt.truncate) + delta := int64(maxTime.Sub(*minTime)) + + res := rdt.generate(rdt.rand, minTime, &delta, &rdt.truncate) if err := r.SetColumnValueByIdx(rdt.columnIdx, res); err != nil { return nil, fmt.Errorf("unable to set new value: %w", err) } diff --git a/internal/db/postgres/transformers/utils/definition.go b/internal/db/postgres/transformers/utils/definition.go index ed419480..d0f9bfea 100644 --- a/internal/db/postgres/transformers/utils/definition.go +++ b/internal/db/postgres/transformers/utils/definition.go @@ -110,9 +110,15 @@ func (d *TransformerDefinition) SetSchemaValidator(v SchemaValidationFunc) *Tran // return totalWarnings, params, nil //} +type TransformerContext struct { + Transformer Transformer + StaticParameters map[string]*toolkit.StaticParameter + DynamicParameters map[string]*toolkit.DynamicParameter +} + func (d *TransformerDefinition) Instance( ctx context.Context, driver *toolkit.Driver, rawParams map[string]toolkit.ParamsValue, dynamicParameters map[string]*toolkit.DynamicParamValue, -) (Transformer, toolkit.ValidationWarnings, error) { +) (*TransformerContext, toolkit.ValidationWarnings, error) { // Decode parameters and get the pgcopy of parsed params, parametersWarnings, err := toolkit.InitParametersV2(driver, d.Parameters, rawParams, dynamicParameters) if err != nil { @@ -123,12 +129,23 @@ func (d *TransformerDefinition) Instance( return nil, parametersWarnings, nil } + dynamicParams := make(map[string]*toolkit.DynamicParameter) + staticParams := make(map[string]*toolkit.StaticParameter) + for name, p := range params { + switch v := p.(type) { + case *toolkit.StaticParameter: + staticParams[name] = v + case *toolkit.DynamicParameter: + dynamicParams[name] = v + } + } + paramDefs := make(map[string]*toolkit.ParameterDefinition, len(d.Parameters)) for _, pd := range d.Parameters { paramDefs[pd.Name] = pd } // Validate schema - schemaWarnings, err := d.SchemaValidator(ctx, driver, d.Properties, paramDefs) + schemaWarnings, err := d.SchemaValidator(ctx, driver, d.Properties, staticParams) if err != nil { return nil, nil, fmt.Errorf("schema validation error: %w", err) } @@ -144,5 +161,9 @@ func (d *TransformerDefinition) Instance( res = append(res, schemaWarnings...) res = append(res, transformerWarnings...) - return t, res, nil + return &TransformerContext{ + Transformer: t, + StaticParameters: staticParams, + DynamicParameters: dynamicParams, + }, res, nil } diff --git a/internal/db/postgres/transformers/utils/schema_validation.go b/internal/db/postgres/transformers/utils/schema_validation.go index 7ba55fba..7f139f0c 100644 --- a/internal/db/postgres/transformers/utils/schema_validation.go +++ b/internal/db/postgres/transformers/utils/schema_validation.go @@ -21,7 +21,7 @@ import ( "github.com/greenmaskio/greenmask/pkg/toolkit" ) -type SchemaValidationFunc func(ctx context.Context, table *toolkit.Driver, properties *TransformerProperties, parameters map[string]*toolkit.ParameterDefinition) (toolkit.ValidationWarnings, error) +type SchemaValidationFunc func(ctx context.Context, table *toolkit.Driver, properties *TransformerProperties, parameters map[string]*toolkit.StaticParameter) (toolkit.ValidationWarnings, error) func ValidateSchema( table *toolkit.Table, column *toolkit.Column, columnProperties *toolkit.ColumnProperties, @@ -37,7 +37,7 @@ func ValidateSchema( func DefaultSchemaValidator( ctx context.Context, driver *toolkit.Driver, properties *TransformerProperties, - parameters map[string]*toolkit.ParameterDefinition) (toolkit.ValidationWarnings, error) { + parameters map[string]*toolkit.StaticParameter) (toolkit.ValidationWarnings, error) { var warnings toolkit.ValidationWarnings if parameters == nil { @@ -45,43 +45,44 @@ func DefaultSchemaValidator( } for _, p := range parameters { - if !p.IsColumn || p.IsColumn && !p.ColumnProperties.Affected { + if !p.GetDefinition().IsColumn || p.GetDefinition().IsColumn && !p.GetDefinition().ColumnProperties.Affected { // We assume that if parameter is not a column or is a column but not affected - it should not // violate constraints continue } // Checking is transformer can produce NULL value - if p.ColumnProperties.Nullable && p.Column.NotNull { + if p.GetDefinition().ColumnProperties.Nullable && p.Column.NotNull { warnings = append(warnings, toolkit.NewValidationWarning(). SetMsg("transformer may produce NULL values but column has NOT NULL constraint"). SetSeverity(toolkit.WarningValidationSeverity). AddMeta("ConstraintType", toolkit.NotNullConstraintType). - AddMeta("ParameterName", p.Name). - AddMeta("ColumnName", p.Column.Name), + AddMeta("ParameterName", p.GetDefinition().Name). + AddMeta("ColumnName", p.GetDefinition().Column.Name), ) } // Checking transformed value will not exceed the column length - if p.ColumnProperties.MaxLength != toolkit.WithoutMaxLength && - p.Column.Length < p.ColumnProperties.MaxLength { + if p.GetDefinition().ColumnProperties.MaxLength != toolkit.WithoutMaxLength && + p.GetDefinition().Column.Length < p.GetDefinition().ColumnProperties.MaxLength { warnings = append(warnings, toolkit.NewValidationWarning(). SetMsg("transformer value might be out of length range: column has a length"). SetSeverity(toolkit.WarningValidationSeverity). AddMeta("ConstraintType", toolkit.LengthConstraintType). - AddMeta("ParameterName", p.Name). + AddMeta("ParameterName", p.GetDefinition().Name). AddMeta("ColumnName", p.Column.Name). AddMeta("ColumnMaxLength", p.Column.Length). - AddMeta("TransformerMaxLength", p.ColumnProperties.MaxLength), + AddMeta("TransformerMaxLength", p.GetDefinition().ColumnProperties.MaxLength), ) } // Performing checks constraint checks with the affected column for _, c := range driver.Table.Constraints { - if p.IsColumn && (p.ColumnProperties == nil || p.ColumnProperties != nil && p.ColumnProperties.Affected) { - if warns := c.IsAffected(p.Column, p.ColumnProperties); len(warns) > 0 { + if p.GetDefinition().IsColumn && (p.GetDefinition().ColumnProperties == nil || + p.GetDefinition().ColumnProperties != nil && p.GetDefinition().ColumnProperties.Affected) { + if warns := c.IsAffected(p.Column, p.GetDefinition().ColumnProperties); len(warns) > 0 { for _, w := range warns { - w.AddMeta("ParameterName", p.Name) + w.AddMeta("ParameterName", p.GetDefinition().Name) } warnings = append(warnings, warns...) } diff --git a/internal/db/postgres/validate.go b/internal/db/postgres/validate.go index 2fb3efec..14df5180 100644 --- a/internal/db/postgres/validate.go +++ b/internal/db/postgres/validate.go @@ -184,7 +184,7 @@ func (v *Validate) Run(ctx context.Context) error { var tablesWithTransformers []dump.Entry for _, item := range v.context.DataSectionObjects { - if t, ok := item.(*dump.Table); ok && len(t.Transformers) > 0 { + if t, ok := item.(*dump.Table); ok && len(t.TransformersContext) > 0 { t.ValidateLimitedRecords = v.config.Validate.RowsLimit tablesWithTransformers = append(tablesWithTransformers, t) } @@ -249,8 +249,8 @@ func (v *Validate) getVerticalRowColors(affectedColumns map[int]struct{}, column func (v *Validate) getAffectedColumns(t *dump.Table) map[int]struct{} { affectedColumns := make(map[int]struct{}) - for _, tr := range t.Transformers { - ac := tr.GetAffectedColumns() + for _, tr := range t.TransformersContext { + ac := tr.Transformer.GetAffectedColumns() for idx := range ac { affectedColumns[idx] = struct{}{} } diff --git a/pkg/toolkit/driver.go b/pkg/toolkit/driver.go index 3cbdc8c9..c55dfbdf 100644 --- a/pkg/toolkit/driver.go +++ b/pkg/toolkit/driver.go @@ -117,6 +117,10 @@ func NewDriver(table *Table, customTypes []*Type) (*Driver, ValidationWarnings, return d, warnings, nil } +func (d *Driver) GetTypeMap() *pgtype.Map { + return d.SharedTypeMap +} + func (d *Driver) EncodeValueByColumnIdx(idx int, src any, buf []byte) ([]byte, error) { if typeName, ok := d.unsupportedColumnIdxs[idx]; ok { return nil, fmt.Errorf("encode-decode operation is not supported for column %d with type %s", idx, typeName) diff --git a/pkg/toolkit/dynamic_parameter.go b/pkg/toolkit/dynamic_parameter.go index 63d5ca1d..db9cd1b2 100644 --- a/pkg/toolkit/dynamic_parameter.go +++ b/pkg/toolkit/dynamic_parameter.go @@ -40,6 +40,16 @@ func (dp *DynamicParameter) Init(columnParameters map[string]*StaticParameter, d // 1. If it has CastDbType check that type is the same as in CastDbType iof not - raise warning // 2. If it has linked parameter check that it has the same types otherwise raise validation error + if !dp.definition.DynamicModeSupport { + warnings = append( + warnings, + NewValidationWarning(). + SetSeverity(ErrorValidationSeverity). + SetMsg("parameter does not support dynamic mode"), + ) + return warnings, nil + } + if dynamicValue == nil { panic("DynamicValue is nil") } @@ -134,10 +144,11 @@ func (dp *DynamicParameter) Init(columnParameters map[string]*StaticParameter, d } if dp.definition.CastDbType != "" && - !IsTypeAllowed( + !IsTypeAllowedWithTypeMap( + dp.driver, []string{dp.definition.CastDbType}, - dp.driver.CustomTypes, - column.Name, + column.TypeName, + column.TypeOid, true, ) { warnings = append(warnings, NewValidationWarning(). @@ -156,6 +167,9 @@ func (dp *DynamicParameter) Init(columnParameters map[string]*StaticParameter, d } func (dp *DynamicParameter) Value() (value any, err error) { + if dp.record == nil { + return nil, fmt.Errorf("check transformer implementation: dynamic parameter usage during initialization stage is prohibited") + } // TODO: Add logic for using cst template and null behaviour v, err := dp.record.GetColumnValueByIdx(dp.columnIdx) if err != nil { @@ -165,6 +179,9 @@ func (dp *DynamicParameter) Value() (value any, err error) { } func (dp *DynamicParameter) RawValue() (rawValue ParamsValue, err error) { + if dp.record == nil { + return nil, fmt.Errorf("check transformer implementation: dynamic parameter usage during initialization stage is prohibited") + } // TODO: Add logic for using cst template and null behaviour v, err := dp.record.GetRawColumnValueByIdx(dp.columnIdx) if err != nil { @@ -174,6 +191,9 @@ func (dp *DynamicParameter) RawValue() (rawValue ParamsValue, err error) { } func (dp *DynamicParameter) Scan(dest any) (bool, error) { + if dp.record == nil { + return false, fmt.Errorf("check transformer implementation: dynamic parameter usage during initialization stage is prohibited") + } // TODO: Add logic for using cst template and null behaviour empty, err := dp.record.ScanColumnValueByIdx(dp.columnIdx, dest) if err != nil { diff --git a/pkg/toolkit/parameter.go b/pkg/toolkit/parameter.go index c8937f10..b0b69fb5 100644 --- a/pkg/toolkit/parameter.go +++ b/pkg/toolkit/parameter.go @@ -133,15 +133,20 @@ type ParameterDefinition struct { RawValueValidator RawValueValidator `json:"-"` // LinkedParameter - column-like parameter that has been linked during parsing procedure. Warning, do not // assign it manually, if you don't know the consequences + // Deprecated LinkedColumnParameter *ParameterDefinition `json:"-"` - // Column - column of the table that was assigned in the parsing procedure according to provided column name in + // Column - column of the table that was assigned in the parsing procedure according to provided Column name in // parameter value. In this case value has textual column name + // Deprecated Column *Column `json:"-"` // Driver - initialized used for decoding raw value to database type mentioned in CastDbType + // Deprecated Driver *Driver `mapstructure:"-" json:"-"` // value - cached parsed value after Scan or Value + // Deprecated value any // rawValue - original raw value received from config + // Deprecated rawValue ParamsValue } @@ -334,6 +339,11 @@ func (p *ParameterDefinition) SetDefaultValue(v ParamsValue) *ParameterDefinitio return p } +func (p *ParameterDefinition) SetDynamicModeSupport(v bool) *ParameterDefinition { + p.DynamicModeSupport = v + return p +} + // Deprecated func (p *ParameterDefinition) Copy() *ParameterDefinition { cp := *p @@ -643,20 +653,6 @@ func InitParametersV2( } for _, pd := range otherParamsDef { - staticValue, ok := staticValues[pd.Name] - var p Parameterizer - if ok { - sp := NewStaticParameter(pd, driver) - initWarns, err := sp.Init(columnParams, staticValue) - for _, w := range initWarns { - w.AddMeta("ParameterName", pd.Name) - } - warnings = append(warnings, initWarns...) - if err != nil { - return nil, warnings, fmt.Errorf("error initializing static parameter \"%s\": %w", pd.Name, err) - } - p = sp - } dynamicValue, ok := dynamicValues[pd.Name] if ok { dp := NewDynamicParameter(pd, driver) @@ -668,9 +664,21 @@ func InitParametersV2( if err != nil { return nil, warnings, fmt.Errorf("error initializing static parameter \"%s\": %w", pd.Name, err) } - p = dp + params[pd.Name] = dp + continue + } + + staticValue := staticValues[pd.Name] + sp := NewStaticParameter(pd, driver) + initWarns, err := sp.Init(columnParams, staticValue) + for _, w := range initWarns { + w.AddMeta("ParameterName", pd.Name) + } + warnings = append(warnings, initWarns...) + if err != nil { + return nil, warnings, fmt.Errorf("error initializing static parameter \"%s\": %w", pd.Name, err) } - params[pd.Name] = p + params[pd.Name] = sp } return params, warnings, nil diff --git a/pkg/toolkit/static_parameter.go b/pkg/toolkit/static_parameter.go index cb01b196..49b56b80 100644 --- a/pkg/toolkit/static_parameter.go +++ b/pkg/toolkit/static_parameter.go @@ -13,7 +13,7 @@ type StaticParameter struct { driver *Driver linkedColumnParameter *StaticParameter rawValue ParamsValue - column *Column + Column *Column value any } @@ -88,18 +88,26 @@ func (p *StaticParameter) Init(columnParams map[string]*StaticParameter, rawValu ) return warnings, nil } - p.column = column + p.Column = column - columnTypeName := p.column.TypeName - if p.column.OverriddenTypeName != "" { - columnTypeName = p.column.OverriddenTypeName + columnTypeName := p.Column.TypeName + columnTypeOid := p.Column.TypeOid + if p.Column.OverriddenTypeName != "" { + columnTypeName = p.Column.OverriddenTypeName + columnTypeOid = 0 } if p.definition.ColumnProperties != nil { if len(p.definition.ColumnProperties.AllowedTypes) > 0 { - if !IsTypeAllowed(p.definition.ColumnProperties.AllowedTypes, p.driver.CustomTypes, columnTypeName, true) { + if !IsTypeAllowedWithTypeMap( + p.driver, + p.definition.ColumnProperties.AllowedTypes, + columnTypeName, + columnTypeOid, + true, + ) { warnings = append(warnings, NewValidationWarning(). SetSeverity(ErrorValidationSeverity). SetMsg("unsupported column type"). @@ -149,7 +157,7 @@ func (p *StaticParameter) Value() (any, error) { } else if p.definition.LinkedColumnParameter != nil { // Parsing dynamically - default value and type are unknown // TODO: Be careful - this may cause an error in Scan func if the the returning value is not a pointer - val, err := p.driver.DecodeValueByTypeOid(uint32(p.linkedColumnParameter.column.TypeOid), p.rawValue) + val, err := p.driver.DecodeValueByTypeOid(uint32(p.linkedColumnParameter.Column.TypeOid), p.rawValue) if err != nil { return nil, fmt.Errorf("unable to scan parameter via Driver: %w", err) } @@ -209,20 +217,20 @@ func (p *StaticParameter) Scan(dest any) (bool, error) { } else if p.linkedColumnParameter != nil { // Try to scan value using pgx Driver and pgtype defined in the linked column - if p.linkedColumnParameter.column == nil { - return false, fmt.Errorf("parameter is linked but column was not assigned") + if p.linkedColumnParameter.Column == nil { + return false, fmt.Errorf("parameter is linked but Column was not assigned") } switch p.value.(type) { case *time.Time: - val, err := p.driver.DecodeValueByTypeOid(uint32(p.linkedColumnParameter.column.TypeOid), p.rawValue) + val, err := p.driver.DecodeValueByTypeOid(uint32(p.linkedColumnParameter.Column.TypeOid), p.rawValue) if err != nil { return false, fmt.Errorf("unable to scan parameter via Driver: %w", err) } valTime := val.(time.Time) p.value = &valTime default: - if err := p.driver.ScanValueByTypeOid(uint32(p.linkedColumnParameter.column.TypeOid), p.rawValue, p.value); err != nil { + if err := p.driver.ScanValueByTypeOid(uint32(p.linkedColumnParameter.Column.TypeOid), p.rawValue, p.value); err != nil { return false, fmt.Errorf("unable to scan parameter via Driver: %w", err) } } diff --git a/pkg/toolkit/types.go b/pkg/toolkit/types.go index 659be618..b0dd7c1a 100644 --- a/pkg/toolkit/types.go +++ b/pkg/toolkit/types.go @@ -70,32 +70,32 @@ type Type struct { RootBuiltInTypeName string `json:"root_built_in_type_name,omitempty"` } -func (t *Type) IsAffected(p *ParameterDefinition) (w ValidationWarnings) { +func (t *Type) IsAffected(p *StaticParameter) (w ValidationWarnings) { if p.Column == nil { panic("parameter Column must not be nil") } - if p.Column == nil { + if p.GetDefinition().ColumnProperties == nil { panic("parameter ColumnProperties must not be nil") } - if !p.ColumnProperties.Affected { + if !p.GetDefinition().ColumnProperties.Affected { return } if p.Column.TypeOid != t.Oid { return } - if p.ColumnProperties.Nullable && p.Column.NotNull { + if p.GetDefinition().ColumnProperties.Nullable && p.GetDefinition().Column.NotNull { w = append(w, NewValidationWarning(). SetSeverity(WarningValidationSeverity). - AddMeta("ParameterName", p.Name). + AddMeta("ParameterName", p.GetDefinition().Name). AddMeta("ColumnName", p.Column.Name). - AddMeta("TypeName", p.Name). + AddMeta("TypeName", p.GetDefinition().Name). SetMsg("transformer may produce NULL values but column type has NOT NULL constraint"), ) } if t.Check != nil { w = append(w, NewValidationWarning(). SetSeverity(WarningValidationSeverity). - AddMeta("ParameterName", p.Name). + AddMeta("ParameterName", p.GetDefinition().Name). AddMeta("ColumnName", p.Column.Name). AddMeta("TypeSchema", t.Schema). AddMeta("TypeName", t.Name). @@ -105,11 +105,11 @@ func (t *Type) IsAffected(p *ParameterDefinition) (w ValidationWarnings) { SetMsg("possible check constraint violation: column has domain type with constraint"), ) } - if t.Length != WithoutMaxLength && t.Length < p.ColumnProperties.MaxLength { + if t.Length != WithoutMaxLength && t.Length < p.GetDefinition().ColumnProperties.MaxLength { w = append(w, NewValidationWarning(). SetSeverity(WarningValidationSeverity). SetMsg("transformer value might be out of length range: domain has length higher than column"). - AddMeta("ParameterName", p.Name). + AddMeta("ParameterName", p.GetDefinition().Name). AddMeta("ColumnName", p.Column.Name). AddMeta("TypeSchema", t.Schema). AddMeta("TypeName", t.Name). @@ -166,8 +166,19 @@ func TryRegisterCustomTypes(typeMap *pgtype.Map, types []*Type, silent bool) { } } +func IsTypeAllowedWithTypeMap( + driver *Driver, allowedTypes []string, typeName string, typeOid Oid, checkInherited bool, +) bool { + // Get canonical type name by type Oid if exists otherwise use provided name + pgType, ok := driver.GetTypeMap().TypeForOID(uint32(typeOid)) + if ok { + typeName = pgType.Name + } + return IsTypeAllowed(driver, allowedTypes, typeName, checkInherited) +} + func IsTypeAllowed( - allowedTypes []string, customTypes []*Type, typeName string, checkInherited bool, + driver *Driver, allowedTypes []string, typeName string, checkInherited bool, ) bool { if slices.Contains(allowedTypes, typeName) { @@ -179,7 +190,7 @@ func IsTypeAllowed( } // If custom type is found check that the root type is allowed - pgCustomRootType := GetCustomType(customTypes, typeName) + pgCustomRootType := GetCustomType(driver.CustomTypes, typeName) if pgCustomRootType == nil { return false }