From 2f535323f9b7506dc4cbfe47cc2b1f8067739ef1 Mon Sep 17 00:00:00 2001 From: Piotr Heilman Date: Thu, 23 Jan 2025 12:51:13 +0100 Subject: [PATCH] Add start-from-s3 command. --- go.mod | 19 +++++ go.sum | 38 ++++++++++ main.go | 72 ++++++++++++++++++- .../poseidon_tree.go | 2 +- prover/marshal.go | 51 +++++++++++++ 5 files changed, 179 insertions(+), 3 deletions(-) rename test_tree.go => poseidon_tree/poseidon_tree.go (99%) diff --git a/go.mod b/go.mod index 9128e3e..83e8416 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,10 @@ go 1.23 toolchain go1.23.3 require ( + github.com/aws/aws-sdk-go-v2 v1.33.0 + github.com/aws/aws-sdk-go-v2/config v1.29.1 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.53 + github.com/aws/aws-sdk-go-v2/service/s3 v1.74.0 github.com/consensys/gnark v0.8.0 github.com/iden3/go-iden3-crypto v0.0.13 github.com/prometheus/client_golang v1.14.0 @@ -22,6 +26,21 @@ require ( github.com/DataDog/go-tuf v1.1.0-0.5.2 // indirect github.com/DataDog/sketches-go v1.4.5 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.54 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.24 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.28 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.28 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.28 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.9 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.9 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.11 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.10 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.9 // indirect + github.com/aws/smithy-go v1.22.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/go.sum b/go.sum index b0aae08..379568c 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,44 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/aws/aws-sdk-go-v2 v1.33.0 h1:Evgm4DI9imD81V0WwD+TN4DCwjUMdc94TrduMLbgZJs= +github.com/aws/aws-sdk-go-v2 v1.33.0/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= +github.com/aws/aws-sdk-go-v2/config v1.29.1 h1:JZhGawAyZ/EuJeBtbQYnaoftczcb2drR2Iq36Wgz4sQ= +github.com/aws/aws-sdk-go-v2/config v1.29.1/go.mod h1:7bR2YD5euaxBhzt2y/oDkt3uNRb6tjFp98GlTFueRwk= +github.com/aws/aws-sdk-go-v2/credentials v1.17.54 h1:4UmqeOqJPvdvASZWrKlhzpRahAulBfyTJQUaYy4+hEI= +github.com/aws/aws-sdk-go-v2/credentials v1.17.54/go.mod h1:RTdfo0P0hbbTxIhmQrOsC/PquBZGabEPnCaxxKRPSnI= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.24 h1:5grmdTdMsovn9kPZPI23Hhvp0ZyNm5cRO+IZFIYiAfw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.24/go.mod h1:zqi7TVKTswH3Ozq28PkmBmgzG1tona7mo9G2IJg4Cis= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.53 h1:3jYpOndmkKtmlPOhMNIV7Q92GD61x/KNjmxUcB95btw= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.53/go.mod h1:+s7tPUl4uy7FMpT5qnjkY5YJNuKU2HZL6trkYxQNtb4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.28 h1:igORFSiH3bfq4lxKFkTSYDhJEUCYo6C8VKiWJjYwQuQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.28/go.mod h1:3So8EA/aAYm36L7XIvCVwLa0s5N0P7o2b1oqnx/2R4g= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.28 h1:1mOW9zAUMhTSrMDssEHS/ajx8JcAj/IcftzcmNlmVLI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.28/go.mod h1:kGlXVIWDfvt2Ox5zEaNglmq0hXPHgQFNMix33Tw22jA= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.28 h1:7kpeALOUeThs2kEjlAxlADAVfxKmkYAedlpZ3kdoSJ4= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.28/go.mod h1:pyaOYEdp1MJWgtXLy6q80r3DhsVdOIOZNB9hdTcJIvI= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.2 h1:e6um6+DWYQP1XCa+E9YVtG/9v1qk5lyAOelMOVwSyO8= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.2/go.mod h1:dIW8puxSbYLSPv/ju0d9A3CpwXdtqvJtYKDMVmPLOWE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.9 h1:TQmKDyETFGiXVhZfQ/I0cCFziqqX58pi4tKJGYGFSz0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.9/go.mod h1:HVLPK2iHQBUx7HfZeOQSEu3v2ubZaAY2YPbAm5/WUyY= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.9 h1:2aInXbh02XsbO0KobPGMNXyv2QP73VDKsWPNJARj/+4= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.9/go.mod h1:dgXS1i+HgWnYkPXqNoPIPKeUsUUYHaUbThC90aDnNiE= +github.com/aws/aws-sdk-go-v2/service/s3 v1.74.0 h1:ncCHiFU9Eq4qnKCNlzMZXfFmvb9R8OVNfU8SFOskxdI= +github.com/aws/aws-sdk-go-v2/service/s3 v1.74.0/go.mod h1:jGJ/v7FIi7Ys9t54tmEFnrxuaWeJLpwNgKp2DXAVhOU= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.11 h1:kuIyu4fTT38Kj7YCC7ouNbVZSSpqkZ+LzIfhCr6Dg+I= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.11/go.mod h1:Ro744S4fKiCCuZECXgOi760TiYylUM8ZBf6OGiZzJtY= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.10 h1:l+dgv/64iVlQ3WsBbnn+JSbkj01jIi+SM0wYsj3y/hY= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.10/go.mod h1:Fzsj6lZEb8AkTE5S68OhcbBqeWPsR8RnGuKPr8Todl8= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.9 h1:BRVDbewN6VZcwr+FBOszDKvYeXY1kJ+GGMCcpghlw0U= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.9/go.mod h1:f6vjfZER1M17Fokn0IzssOTMT2N8ZSq+7jnNF0tArvw= +github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro= +github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/main.go b/main.go index 863b994..5f0eeb1 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "time" "worldcoin/gnark-mbu/logging" + "worldcoin/gnark-mbu/poseidon_tree" "worldcoin/gnark-mbu/prover" "worldcoin/gnark-mbu/server" @@ -226,7 +227,7 @@ func main() { if mode == server.InsertionMode { params := prover.InsertionParameters{} - tree := NewTree(treeDepth) + tree := poseidon_tree.NewTree(treeDepth) params.StartIndex = 0 params.PreRoot = tree.Root() @@ -241,7 +242,7 @@ func main() { r, err = json.Marshal(¶ms) } else if mode == server.DeletionMode { params := prover.DeletionParameters{} - tree := NewTree(treeDepth) + tree := poseidon_tree.NewTree(treeDepth) params.DeletionIndices = make([]uint32, batchSize) params.IdComms = make([]big.Int, batchSize) @@ -320,6 +321,73 @@ func main() { return nil }, }, + { + Name: "start-from-s3", + Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", EnvVars: []string{"MTB_MODE"}, DefaultText: "insertion"}, + &cli.BoolFlag{Name: "json-logging", Usage: "enable JSON logging", Required: false}, + &cli.StringFlag{Name: "prover-address", Usage: "address for the prover server", Value: "localhost:3001", Required: false}, + &cli.StringFlag{Name: "metrics-address", Usage: "address for the metrics server", Value: "localhost:9998", Required: false}, + &cli.StringFlag{Name: "s3-region", Usage: "s3 region of bucket", EnvVars: []string{"S3_REGION"}, DefaultText: "us-east1"}, + &cli.StringFlag{Name: "s3-bucket", Usage: "s3 bucket name", EnvVars: []string{"S3_BUCKET"}, Required: true}, + &cli.StringFlag{Name: "s3-object-key", Usage: "s3 object key (path)", EnvVars: []string{"S3_OBJECT_KEY"}, Required: true}, + &cli.IntFlag{Name: "s3-concurrency", Usage: "number of concurrent connections to download from s3", EnvVars: []string{"S3_CONCURRENCY"}, DefaultText: "8"}, + &cli.Int64Flag{Name: "s3-part-mibs", Usage: "size of part to download from s3", EnvVars: []string{"S3_PART_MIBS"}, DefaultText: "64"}, + }, + Action: func(context *cli.Context) error { + if context.Bool("json-logging") { + logging.SetJSONOutput() + } + region := context.String("s3-region") + bucket := context.String("s3-bucket") + objectKey := context.String("s3-object-key") + concurrency := context.Int("s3-concurrency") + partMibs := context.Int64("s3-part-mibs") + mode := context.String("mode") + + if mode != server.DeletionMode && mode != server.InsertionMode { + return fmt.Errorf("invalid mode: %s", mode) + } + + logging.Logger(). + Info(). + Str("region", region). + Str("bucket", bucket). + Str("objectKey", objectKey). + Str("objectKey", objectKey). + Int("concurrency", concurrency). + Int64("partMibs", partMibs). + Msg("Loading proving system from S3") + + start := time.Now() + ps, err := prover.ReadSystemFromS3(region, bucket, objectKey, concurrency, partMibs) + if err != nil { + return err + } + duration := time.Since(start) + logging.Logger(). + Info(). + Uint32("treeDepth", ps.TreeDepth). + Uint32("batchSize", ps.BatchSize). + Dur("duration", duration). + Msg("Proving system loaded") + + config := server.Config{ + ProverAddress: context.String("prover-address"), + MetricsAddress: context.String("metrics-address"), + Mode: mode, + } + instance := server.Run(&config, ps) + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, os.Interrupt) + <-sigint + logging.Logger().Info().Msg("Received sigint, shutting down") + instance.RequestStop() + logging.Logger().Info().Msg("Waiting for server to close") + instance.AwaitStop() + return nil + }, + }, { Name: "prove", Flags: []cli.Flag{ diff --git a/test_tree.go b/poseidon_tree/poseidon_tree.go similarity index 99% rename from test_tree.go rename to poseidon_tree/poseidon_tree.go index 2b01fb3..f54b73f 100644 --- a/test_tree.go +++ b/poseidon_tree/poseidon_tree.go @@ -1,4 +1,4 @@ -package main +package poseidon_tree import ( "github.com/iden3/go-iden3-crypto/poseidon" diff --git a/prover/marshal.go b/prover/marshal.go index 2562c28..ed063ba 100644 --- a/prover/marshal.go +++ b/prover/marshal.go @@ -3,15 +3,20 @@ package prover import ( "bufio" "bytes" + "context" "encoding/binary" "encoding/json" "fmt" + "github.com/aws/aws-sdk-go-v2/config" "io" "math/big" "os" "worldcoin/gnark-mbu/logging" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend/groth16" ) @@ -392,3 +397,49 @@ func ReadSystemFromFile(path string) (ps *ProvingSystem, err error) { } return } + +func ReadSystemFromS3(region, bucket, objectKey string, concurrency int, partMiBs int64) (ps *ProvingSystem, err error) { + ps = new(ProvingSystem) + + ctx := context.TODO() + + awsConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return + } + + client := s3.NewFromConfig(awsConfig) + if err != nil { + return + } + + hObj, err := client.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(objectKey), + }) + if err != nil { + return + } + + downloader := manager.NewDownloader(client, func(d *manager.Downloader) { + d.PartSize = partMiBs * 1024 * 1024 + d.Concurrency = concurrency + }) + + buff := make([]byte, *hObj.ContentLength) + w := manager.NewWriteAtBuffer(buff) + _, err = downloader.Download(context.TODO(), w, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(objectKey), + }) + if err != nil { + return + } + + bufferedReader := bytes.NewReader(w.Bytes()) + _, err = ps.UnsafeReadFrom(bufferedReader) + if err != nil { + return + } + return +}