Skip to content

Commit

Permalink
Merge pull request #26 from GreenmaskIO/partial_restoration_fixes
Browse files Browse the repository at this point in the history
Restore command fixes
  • Loading branch information
wwoytenko authored Mar 15, 2024
2 parents e549611 + 52a675e commit ad470ac
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 70 deletions.
8 changes: 1 addition & 7 deletions cmd/greenmask/cmd/restore/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,13 @@ func getDumpId(ctx context.Context, st storages.Storager, dumpId string) (string
}

// TODO: Options currently are not implemented:
// * data-only
// * exit-on-error
// * use-list
// * schema
// * exclude-schema
// * schema-only
// * table
// * single-transaction
// * disable-triggers
// * enable-row-security
// * no-data-for-failed-tables
// * section
// * strict-names
// * use-set-session-authorization

func init() {
// General options:
Expand Down
216 changes: 153 additions & 63 deletions internal/db/postgres/cmd/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"os"
"path"
"regexp"
"slices"
"strconv"
"time"

Expand All @@ -41,14 +42,26 @@ import (
)

const (
ScriptPreDataSection = "pre-data"
ScriptDataSection = "data"
ScriptPostDataSection = "post-data"
scriptPreDataSection = "pre-data"
scriptDataSection = "data"
scriptPostDataSection = "post-data"
)

const (
ScriptExecuteBefore = "before"
ScriptExecuteAfter = "after"
preDataSection = "pre-data"
dataSection = "data"
postDataSection = "post-data"
)

const (
scriptExecuteBefore = "before"
scriptExecuteAfter = "after"
)

const (
jsonListFormat = "json"
yamlListFormat = "yaml"
textListFormat = "text"
)

type Restore struct {
Expand All @@ -58,7 +71,7 @@ type Restore struct {
pgRestore *pgrestore.PgRestore
restoreOpt *pgrestore.Options
st storages.Storager
dumpIdList map[int32]bool
dumpIdList []int32
tocObj *toc.Toc
tmpDir string
}
Expand All @@ -78,8 +91,8 @@ func NewRestore(
}

func (r *Restore) RunScripts(ctx context.Context, conn *pgx.Conn, section, when string) error {
if section != ScriptPreDataSection &&
section != ScriptDataSection && section != ScriptPostDataSection {
if section != scriptPreDataSection &&
section != scriptDataSection && section != scriptPostDataSection {
return fmt.Errorf(`unknown "section" value: %s`, section)
}

Expand Down Expand Up @@ -120,6 +133,10 @@ func (r *Restore) RunScripts(ctx context.Context, conn *pgx.Conn, section, when
}

func (r *Restore) prepare() error {
if err := os.Mkdir(r.tmpDir, 0700); err != nil {
return fmt.Errorf("error creating temp dir: %w", err)
}

if r.restoreOpt.UseList != "" {
// TODO: Implement toc entries ordering according to use-list
log.Warn().Msgf("FIXME: Implement toc entries ordering according to use-list")
Expand All @@ -137,10 +154,6 @@ func (r *Restore) prepare() error {

func (r *Restore) preFlightRestore(ctx context.Context, conn *pgx.Conn) error {

if err := os.Mkdir(r.tmpDir, 0700); err != nil {
return fmt.Errorf("error creating temp dir: %w", err)
}

tocFile, err := r.st.GetObject(ctx, "toc.dat")
if err != nil {
return fmt.Errorf("cannot open toc file: %w", err)
Expand All @@ -165,18 +178,58 @@ func (r *Restore) preFlightRestore(ctx context.Context, conn *pgx.Conn) error {
return fmt.Errorf("unable to read toc file: %w", err)
}

if r.dumpIdList != nil {
if err = r.sortAndFilterEntriesByRestoreList(); err != nil {
return fmt.Errorf("unable to sort entries by the provided list: %w", err)
}
}

if len(r.restoreOpt.Schema) > 0 {
for idx, name := range r.restoreOpt.Schema {
r.restoreOpt.Schema[idx] = removeEscapeQuotes(name)
}
}

if len(r.restoreOpt.Table) > 0 {
for idx, name := range r.restoreOpt.Table {
r.restoreOpt.Table[idx] = removeEscapeQuotes(name)
}
}

if len(r.restoreOpt.ExcludeSchema) > 0 {
for idx, name := range r.restoreOpt.ExcludeSchema {
r.restoreOpt.ExcludeSchema[idx] = removeEscapeQuotes(name)
}
}

return nil
}

// sortAndFilterEntriesByRestoreList replaces r.tocObj.Entries with the entries
// named in r.dumpIdList, in the order the list gives them. TOC entries whose
// dump id does not appear in the list are dropped.
//
// It returns an error, leaving r.tocObj.Entries untouched, when a listed dump
// id has no matching TOC entry.
func (r *Restore) sortAndFilterEntriesByRestoreList() error {
	// Index the TOC once so each listed id is resolved with a single map
	// lookup instead of a linear scan over every entry.
	entryByDumpId := make(map[int32]*toc.Entry, len(r.tocObj.Entries))
	for _, entry := range r.tocObj.Entries {
		// Keep the first entry seen for a dump id, which is what a
		// first-match linear search would have returned.
		if _, seen := entryByDumpId[entry.DumpId]; !seen {
			entryByDumpId[entry.DumpId] = entry
		}
	}

	sortedEntries := make([]*toc.Entry, 0, len(r.dumpIdList))
	for _, dumpId := range r.dumpIdList {
		entry, ok := entryByDumpId[dumpId]
		if !ok {
			return fmt.Errorf("entry from provided list with dump id %d is not found", dumpId)
		}
		sortedEntries = append(sortedEntries, entry)
	}

	r.tocObj.Entries = sortedEntries
	return nil
}

func (r *Restore) preDataRestore(ctx context.Context, conn *pgx.Conn) error {
// Check restore options
// Do not restore this section if implicitly provided
if r.restoreOpt.DataOnly ||
r.restoreOpt.Section != "" && r.restoreOpt.Section != ScriptPreDataSection {
r.restoreOpt.Section != "" && r.restoreOpt.Section != preDataSection {
return nil
}

// Execute PreData Before scripts
if err := r.RunScripts(ctx, conn, ScriptPreDataSection, ScriptExecuteBefore); err != nil {
if err := r.RunScripts(ctx, conn, scriptPreDataSection, scriptExecuteBefore); err != nil {
return err
}

Expand All @@ -189,7 +242,7 @@ func (r *Restore) preDataRestore(ctx context.Context, conn *pgx.Conn) error {
}

// Execute PreData After scripts
if err := r.RunScripts(ctx, conn, ScriptPreDataSection, ScriptExecuteAfter); err != nil {
if err := r.RunScripts(ctx, conn, scriptPreDataSection, scriptExecuteAfter); err != nil {
return err
}

Expand All @@ -199,12 +252,13 @@ func (r *Restore) preDataRestore(ctx context.Context, conn *pgx.Conn) error {
func (r *Restore) dataRestore(ctx context.Context, conn *pgx.Conn) error {
// Execute Data Before scripts

// Do not restore this section if implicitly provided
if r.restoreOpt.SchemaOnly ||
(r.restoreOpt.Section != "" && r.restoreOpt.Section != ScriptDataSection) {
(r.restoreOpt.Section != "" && r.restoreOpt.Section != dataSection) {
return nil
}

if err := r.RunScripts(ctx, conn, ScriptDataSection, ScriptExecuteBefore); err != nil {
if err := r.RunScripts(ctx, conn, scriptDataSection, scriptExecuteBefore); err != nil {
return err
}

Expand All @@ -229,18 +283,8 @@ func (r *Restore) dataRestore(ctx context.Context, conn *pgx.Conn) error {

if entry.Section == toc.SectionData {

if r.restoreOpt.UseList != "" {
_, apply := r.dumpIdList[entry.DumpId]
if !apply {
log.Info().
Int32("dumpId", entry.DumpId).
Str("section", toc.SectionMap[entry.Section]).
Str("type", *entry.Desc).
Str("name", *entry.Tag).
Str("schema", *entry.Namespace).
Msg("toc entry was skipped")
continue
}
if !r.isNeedRestore(entry) {
continue
}

var task restorers.RestoreTask
Expand Down Expand Up @@ -271,22 +315,51 @@ func (r *Restore) dataRestore(ctx context.Context, conn *pgx.Conn) error {
}

// Execute Data After scripts
if err := r.RunScripts(ctx, conn, ScriptDataSection, ScriptExecuteAfter); err != nil {
if err := r.RunScripts(ctx, conn, scriptDataSection, scriptExecuteAfter); err != nil {
return err
}

return nil
}

// isNeedRestore reports whether the TOC entry e passes the schema and table
// filters set in the restore options.
//
// Only entries described as toc.TableDataDesc or toc.SequenceSetDesc are
// filtered; every other entry is always accepted. For filtered entries the
// rules are applied in this order:
//
//   - the entry's namespace is listed in ExcludeSchema: rejected;
//   - Schema is non-empty and does not list the namespace: rejected;
//   - Table is non-empty and does not list the entry's tag: rejected;
//   - otherwise: accepted.
//
// Names are compared after stripping surrounding double quotes with
// removeEscapeQuotes.
func (r *Restore) isNeedRestore(e *toc.Entry) bool {
	if *e.Desc != toc.TableDataDesc && *e.Desc != toc.SequenceSetDesc {
		return true
	}

	// An excluded schema must win over every other filter, so the entry is
	// skipped here rather than restored.
	if len(r.restoreOpt.ExcludeSchema) > 0 &&
		slices.Contains(r.restoreOpt.ExcludeSchema, removeEscapeQuotes(*e.Namespace)) {
		return false
	}

	if len(r.restoreOpt.Schema) > 0 &&
		!slices.Contains(r.restoreOpt.Schema, removeEscapeQuotes(*e.Namespace)) {
		return false
	}

	if len(r.restoreOpt.Table) > 0 &&
		!slices.Contains(r.restoreOpt.Table, removeEscapeQuotes(*e.Tag)) {
		return false
	}

	return true
}

func (r *Restore) postDataRestore(ctx context.Context, conn *pgx.Conn) error {
// Execute Post Data Before scripts

// Do not restore this section if implicitly provided
if r.restoreOpt.DataOnly ||
(r.restoreOpt.Section != "" && r.restoreOpt.Section != ScriptPostDataSection) {
(r.restoreOpt.Section != "" && r.restoreOpt.Section != postDataSection) {
return nil
}

if err := r.RunScripts(ctx, conn, ScriptPostDataSection, ScriptExecuteBefore); err != nil {
if err := r.RunScripts(ctx, conn, scriptPostDataSection, scriptExecuteBefore); err != nil {
return err
}

Expand All @@ -297,7 +370,7 @@ func (r *Restore) postDataRestore(ctx context.Context, conn *pgx.Conn) error {
return fmt.Errorf("cannot restore post-data section using pg_restore: %w", err)
}

if err := r.RunScripts(ctx, conn, ScriptPostDataSection, ScriptExecuteAfter); err != nil {
if err := r.RunScripts(ctx, conn, scriptPostDataSection, scriptExecuteAfter); err != nil {
return err
}

Expand Down Expand Up @@ -407,22 +480,16 @@ func (r *Restore) setRestoreList(fileName string, format string) (err error) {
return fmt.Errorf("unable to open list file: %w", err)
}
defer f.Close()
var res map[int32]bool
switch format {
case "text":
res, err = r.parseTextList(f)
case "yaml":
res, err = r.parseYamlList(f)
case "json":
res, err = r.parseJsonList(f)
}
if err != nil {
r.dumpIdList = res
case textListFormat:
r.dumpIdList, err = r.parseTextList(f)
case yamlListFormat, jsonListFormat:
r.dumpIdList, err = r.parseStructuredList(f, format)
}
return err
}

func (r *Restore) parseTextList(f *os.File) (map[int32]bool, error) {
func (r *Restore) parseTextList(f *os.File) ([]int32, error) {
const dumpIdGroup = 1
var lineNumber int
var lineBuf = make([]byte, 0, 1024)
Expand All @@ -432,7 +499,9 @@ func (r *Restore) parseTextList(f *os.File) (map[int32]bool, error) {
if err != nil {
return nil, fmt.Errorf("cannot compile regexp: %s", err)
}
res := make(map[int32]bool)
//res := make(map[int32]bool)
var res []int32
idx := 0
for {
line, isPrefix, err := lr.ReadLine()
if err != nil {
Expand All @@ -458,37 +527,58 @@ func (r *Restore) parseTextList(f *os.File) (map[int32]bool, error) {
if err != nil {
return nil, fmt.Errorf("cannot parse dumpId at line %d", lineNumber)
}
res[int32(dumpId)] = true
res = append(res, int32(dumpId))
buf.Reset()
idx++
}
}

func (r *Restore) parseYamlList(f *os.File) (map[int32]bool, error) {
func (r *Restore) parseStructuredList(f *os.File, format string) ([]int32, error) {
meta := &storage.Metadata{}
if err := yaml.NewDecoder(f).Decode(meta); err != nil {
return nil, fmt.Errorf("metadata parsing error: %w", err)

switch format {
case jsonListFormat:
if err := json.NewDecoder(f).Decode(meta); err != nil {
return nil, fmt.Errorf("metadata parsing error in json format: %w", err)
}
case yamlListFormat:
if err := yaml.NewDecoder(f).Decode(meta); err != nil {
return nil, fmt.Errorf("metadata parsing error in yaml format: %w", err)
}
default:
return nil, fmt.Errorf("unknown format %s", format)
}
res := make(map[int32]bool)

// Build entries by the provided list and create temporal file for pg_restore call

r.restoreOpt.UseList = path.Join(r.tmpDir, "restoration.list")
tmpListFile, err := os.Create(r.restoreOpt.UseList)
if err != nil {
return nil, fmt.Errorf("unable to create temporal use-list file: %w", err)
}
defer tmpListFile.Close()

var res []int32
for idx, entry := range meta.Entries {
if entry.DumpId == 0 {
return nil, fmt.Errorf("broken list file dumpId: must not be 0: entry number %d", idx)
}
res[entry.DumpId] = true
res = append(res, entry.DumpId)
_, err = tmpListFile.Write([]byte(fmt.Sprintf("%d; \n", entry.DumpId)))
if err != nil {
return nil, fmt.Errorf("unable to write line into list file: %w", err)
}
}

return res, nil
}

func (r *Restore) parseJsonList(f *os.File) (map[int32]bool, error) {
meta := &storage.Metadata{}
if err := json.NewDecoder(f).Decode(meta); err != nil {
return nil, fmt.Errorf("metadata parsing error: %w", err)
// removeEscapeQuotes strips at most one leading and one trailing double quote
// from v and returns the result. The two ends are handled independently, so an
// unbalanced quote on either side is removed as well.
//
// An empty string, or a string consisting only of a single quote, yields "".
func removeEscapeQuotes(v string) string {
	// Each index is guarded by a length check: v may be empty on entry, and
	// may become empty once the leading quote has been removed.
	if len(v) > 0 && v[0] == '"' {
		v = v[1:]
	}
	if len(v) > 0 && v[len(v)-1] == '"' {
		v = v[:len(v)-1]
	}
	return v
}

0 comments on commit ad470ac

Please sign in to comment.