Merge pull request #13 from GwonsooLee/fix_snapshot_bug
fix snapshot bug
GwonsooLee authored Jul 1, 2021
2 parents 719e4bc + ad56b5f commit f35657b
Showing 2 changed files with 119 additions and 81 deletions.
8 changes: 5 additions & 3 deletions internal/constants/constants.go
@@ -21,6 +21,7 @@ import "os"
const (
DefaultRegion = "ap-northeast-2"
EmptyString = ""
GlacierType = "GLACIER"
)

const (
@@ -36,9 +37,10 @@ const (
)

var (
ConfigDirectoryPath = HomeDir() + "/.escli"
BaseFilePath = ConfigDirectoryPath + "/config.yaml"
ValidRestoreTier = []string{"Standard", "Bulk", "Expedited"}
ConfigDirectoryPath = HomeDir() + "/.escli"
BaseFilePath = ConfigDirectoryPath + "/config.yaml"
ValidRestoreTier = []string{"Standard", "Bulk", "Expedited"}
SupportedRepositoryType = []string{"s3"}
)

// Get Home Directory
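For context, the two new entries in this file back call sites added later in the diff: GlacierType replaces the bare "GLACIER" string literals in snapshot.go, and SupportedRepositoryType feeds the new checkRepositoryType helper. A minimal, self-contained sketch of that intent follows; the constant values mirror this file, but every other name is illustrative and not part of this commit.

// Illustrative sketch only; the real call sites are in internal/runner/snapshot.go below.
package main

import (
	"fmt"
	"strings"
)

const GlacierType = "GLACIER"

var SupportedRepositoryType = []string{"s3"}

// isSupported reports whether a repository type is one escli knows how to handle.
func isSupported(repositoryType string) bool {
	for _, t := range SupportedRepositoryType {
		if strings.ToLower(repositoryType) == t {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(isSupported("S3"))         // true: matching is case-insensitive
	fmt.Println("STANDARD" != GlacierType) // true: such an object still needs archiving
}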
192 changes: 114 additions & 78 deletions internal/runner/snapshot.go
@@ -193,13 +193,13 @@ func (r Runner) ArchiveSnapshot(out io.Writer, args []string) error {
for {
for _, item := range objs.Contents {
fmt.Fprintf(out, "%s\n", *item.Key)
if *item.StorageClass != "GLACIER" {
if *item.StorageClass != constants.GlacierType {
if r.Flag.Force {
color.Green("Change Storage Class to %s -> GLACIER", *item.StorageClass)
wait.Add(1)
go func(key string) {
defer wait.Done()
_, err := r.Client.TransitObject(aws.String(repository.Settings.Bucket), aws.String(key), "GLACIER")
_, err := r.Client.TransitObject(aws.String(repository.Settings.Bucket), aws.String(key), constants.GlacierType)
if err != nil {
panic(err)
}
@@ -212,7 +212,7 @@ func (r Runner) ArchiveSnapshot(out io.Writer, args []string) error {

color.Green("Change Storage Class to %s -> GLACIER", *item.StorageClass)
wait.Add(1)
_, err := r.Client.TransitObject(aws.String(repository.Settings.Bucket), item.Key, "GLACIER")
_, err := r.Client.TransitObject(aws.String(repository.Settings.Bucket), item.Key, constants.GlacierType)
if err != nil {
panic(err)
}
@@ -266,25 +266,24 @@ func (r Runner) getIndexIDFromS3(bucket *string, prefix *string) (*snapshotSchem
return &snapshotsIndicesS3, nil
}

func (r Runner) RestoreSnapshot(out io.Writer, args []string) error {
maxSemaphore := r.Flag.MaxConcurrentJob
if maxSemaphore == 0 {
maxSemaphore = constants.DefaultMaxConcurrentJob
}

// RestoreSnapshot is the main function of the snapshot restore process
func (r *Runner) RestoreSnapshot(out io.Writer, args []string) error {
repositoryID := args[0]
snapshotID := args[1]
indexName := args[2]
areAllObjectsStandard := true

repository := r.Client.GetRepository(repositoryID)

if err := checkRepositoryType(repository.Type); err != nil {
return err
}

if repository.Type == "s3" {
fmt.Fprintf(out, "bucket name : %s\n", util.StringWithColor(repository.Settings.Bucket))
fmt.Fprintf(out, "base path : %s\n", util.StringWithColor(repository.Settings.BasePath))

var basePath string
var result []SegmentError

if repository.Settings.BasePath == constants.EmptyString {
basePath = constants.EmptyString
@@ -304,78 +303,33 @@ func (r Runner) RestoreSnapshot(out io.Writer, args []string) error {
metaData := snapshotsIndicesS3.Indices[indexName]
prefix := repository.Settings.BasePath + "/indices/" + metaData.ID + "/"

// collect all objects under the index prefix into segments
segments := r.AddObjectSegments(repository.Settings.Bucket, prefix, nil)
bar := pb.New(len(segments))
bar.SetRefreshRate(time.Second)
bar.SetWriter(out)

var wg sync.WaitGroup
semaphore := make(chan int, maxSemaphore)
output := make(chan []SegmentError)
input := make(chan SegmentError)
defer close(output)

go func(input chan SegmentError, output chan []SegmentError, wg *sync.WaitGroup, bar *pb.ProgressBar) {
var ret []SegmentError
for se := range input {
ret = append(ret, se)
bar.Add(1)
wg.Done()
}
output <- ret
}(input, output, &wg, bar)

f := func(out io.Writer, bucket string, segment SnapshotSegment, force bool, ch chan SegmentError, sem chan int) {
sem <- 1
time.Sleep(1 * time.Second)
if force {
//color.Green("Restore Storage Class to %s -> STANDARD", segment.StorageClass)
err := r.restoreObject(out, aws.String(bucket), aws.String(segment.Key))
ch <- SegmentError{
Key: segment.Key,
Error: err,
}
} else {
reader := bufio.NewReader(os.Stdin)

color.Blue("Change Storage Class to STANDARD [y/n]: ")

resp, _ := reader.ReadString('\n')
if strings.ToLower(strings.TrimSpace(resp)) == "y" {
color.Green("Change Storage Class to %s -> STANDARD", segment.StorageClass)
err := r.restoreObject(out, aws.String(bucket), aws.String(segment.Key))
ch <- SegmentError{
Key: segment.Key,
Error: err,
}
} else {
color.Red("Don't change storage class %s", segment.Key)
}
}
<-sem
}

bar.Start()
for _, s := range segments {
if s.StorageClass == "GLACIER" {
areAllObjectsStandard = false
wg.Add(1)
go f(out, repository.Settings.Bucket, s, r.Flag.Force, input, semaphore)
} else {
bar.Add(1)
if r.Flag.Force {
// if --force is enabled by the user, run the restore concurrently
areAllObjectsStandard, err = r.RunConcurrentRestore(out, repository.Settings.Bucket, segments, r.Flag.MaxConcurrentJob)
if err != nil {
return err
}
}
wg.Wait()
close(input)

bar.Finish()
} else {
for _, s := range segments {
if s.StorageClass == "GLACIER" {
areAllObjectsStandard = false
reader := bufio.NewReader(os.Stdin)

result = <-output
color.Blue("Change Storage Class to STANDARD [y/n]: ")

if len(result) > 0 {
for _, s := range result {
if s.Error != nil {
s.PrintError()
resp, _ := reader.ReadString('\n')
if strings.ToLower(strings.TrimSpace(resp)) == "y" {
color.Green("Change Storage Class to %s -> STANDARD", s.StorageClass)
err := r.restoreObject(out, aws.String(repository.Settings.Bucket), aws.String(s.Key))
if err != nil {
return err
}
} else {
color.Red("Don't change storage class %s", s.Key)
}
}
}
}
@@ -395,7 +349,8 @@ func (r Runner) RestoreSnapshot(out io.Writer, args []string) error {
return nil
}

func (r Runner) restoreObject(_ io.Writer, bucket *string, key *string) error {
// restoreObject restores a single object
func (r *Runner) restoreObject(_ io.Writer, bucket *string, key *string) error {
resp, err := r.Client.HeadObject(bucket, key)

if err != nil {
@@ -435,3 +390,84 @@ func (r *Runner) AddObjectSegments(bucket, prefix string, token *string) []Snaps

return segments
}

// RunConcurrentRestore runs the restore process concurrently
// It only runs when the user passes the --force option
func (r *Runner) RunConcurrentRestore(out io.Writer, bucket string, segments []SnapshotSegment, maxSemaphore int64) (bool, error) {
var result []SegmentError
var areAllObjectsStandard bool

// maxSemaphore limits the number of goroutines
if maxSemaphore == 0 {
maxSemaphore = constants.DefaultMaxConcurrentJob
}

bar := pb.New(len(segments))
bar.SetRefreshRate(time.Second)
bar.SetWriter(out)

var wg sync.WaitGroup
semaphore := make(chan int, maxSemaphore)
output := make(chan []SegmentError)
input := make(chan SegmentError)
defer close(output)

go func(input chan SegmentError, output chan []SegmentError, wg *sync.WaitGroup, bar *pb.ProgressBar) {
var ret []SegmentError
for se := range input {
ret = append(ret, se)
bar.Add(1)
wg.Done()
}
output <- ret
}(input, output, &wg, bar)

f := func(out io.Writer, bucket string, segment SnapshotSegment, ch chan SegmentError, sem chan int) {
sem <- 1
time.Sleep(1 * time.Second)

//color.Green("Restore Storage Class to %s -> STANDARD", segment.StorageClass)
err := r.restoreObject(out, aws.String(bucket), aws.String(segment.Key))
ch <- SegmentError{
Key: segment.Key,
Error: err,
}

<-sem
}

bar.Start()
for _, s := range segments {
if s.StorageClass == "GLACIER" {
areAllObjectsStandard = false
wg.Add(1)
go f(out, bucket, s, input, semaphore)
} else {
bar.Add(1)
}
}
wg.Wait()
close(input)

bar.Finish()

result = <-output

if len(result) > 0 {
for _, s := range result {
if s.Error != nil {
s.PrintError()
}
}
}

return areAllObjectsStandard, nil
}

// checkRepositoryType checks whether the repository type is supported by escli
func checkRepositoryType(repositoryType string) error {
if !util.IsStringInArray(strings.ToLower(repositoryType), constants.SupportedRepositoryType) {
return fmt.Errorf("unsupported repository type: %s", repositoryType)
}
return nil
}
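
For readers skimming the diff, the heart of RunConcurrentRestore is a bounded-worker pattern: a buffered channel acts as a semaphore, each restore runs in its own goroutine, and a single collector goroutine drains results (and, in escli, advances the progress bar). Below is a minimal, self-contained sketch of that same pattern; restore, runConcurrent, and result are illustrative names, and only the channel/WaitGroup choreography mirrors the code above.

package main

import (
	"fmt"
	"sync"
)

type result struct {
	key string
	err error
}

// restore is a stand-in for the real per-object restore call.
func restore(key string) error { return nil }

func runConcurrent(keys []string, maxJobs int) []result {
	var wg sync.WaitGroup
	sem := make(chan struct{}, maxJobs) // bounds the number of in-flight goroutines
	input := make(chan result)
	output := make(chan []result)

	// Collector: drains per-object results and marks each one done.
	go func() {
		var all []result
		for r := range input {
			all = append(all, r)
			wg.Done()
		}
		output <- all
	}()

	for _, k := range keys {
		wg.Add(1)
		go func(key string) {
			sem <- struct{}{}        // acquire a worker slot
			defer func() { <-sem }() // release it when finished
			input <- result{key: key, err: restore(key)}
		}(k)
	}

	wg.Wait()    // every result has reached the collector
	close(input) // lets the collector publish the aggregate slice
	return <-output
}

func main() {
	for _, r := range runConcurrent([]string{"a", "b", "c"}, 2) {
		fmt.Println(r.key, r.err)
	}
}

The ordering matters: input is closed only after wg.Wait() returns, so no worker can still be sending, and the aggregated slice is then read from output exactly once, which is the same sequencing RunConcurrentRestore relies on.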
