Skip to content

Commit

Permalink
Merge pull request #3 from dev6699/dev
Browse files Browse the repository at this point in the history
Extract config, add MinProbability, MaxIOU flag
  • Loading branch information
dev6699 authored Sep 19, 2023
2 parents bdb0aa0 + f338c40 commit ba93240
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 45 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ go run cmd/main.go --help
Name of model being served (Required) (default "yolonas")
-n int
Number of benchmark run. (default 1)
-o float
Intersection over Union (IoU) (default 0.7)
-p float
Minimum probability (default 0.5)
-t string
Type of model. Available options: [yolonas, yolonasint8, yolov8] (default "yolonas")
-u string
Expand Down
2 changes: 1 addition & 1 deletion class.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package yolotriton

var yoloClasses = []string{
var YoloClasses = []string{
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
"traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
"sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
Expand Down
22 changes: 19 additions & 3 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ type Flags struct {
ModelName string
ModelVersion string
ModelType string
MinProbability float64
MaxIOU float64
URL string
Image string
Benchmark bool
Expand All @@ -25,6 +27,8 @@ func parseFlags() Flags {
flag.StringVar(&flags.ModelName, "m", "yolonas", "Name of model being served (Required)")
flag.StringVar(&flags.ModelVersion, "x", "", "Version of model. Default: Latest Version")
flag.StringVar(&flags.ModelType, "t", "yolonas", "Type of model. Available options: [yolonas, yolonasint8, yolov8]")
flag.Float64Var(&flags.MinProbability, "p", 0.5, "Minimum probability")
flag.Float64Var(&flags.MaxIOU, "o", 0.7, "Intersection over Union (IoU)")
flag.StringVar(&flags.URL, "u", "tritonserver:8001", "Inference Server URL.")
flag.StringVar(&flags.Image, "i", "images/1.jpg", "Inference Image.")
flag.BoolVar(&flags.Benchmark, "b", false, "Run benchmark.")
Expand All @@ -37,14 +41,26 @@ func main() {
FLAGS := parseFlags()
fmt.Println("FLAGS:", FLAGS)

cfg := yolotriton.YoloTritonConfig{
ModelName: FLAGS.ModelName,
ModelVersion: FLAGS.ModelVersion,
MinProbability: float32(FLAGS.MinProbability),
MaxIOU: FLAGS.MaxIOU,
Classes: yolotriton.YoloClasses,
}

var model yolotriton.Model
switch yolotriton.ModelType(FLAGS.ModelType) {
case yolotriton.ModelTypeYoloV8:
model = yolotriton.NewYoloV8(FLAGS.ModelName, FLAGS.ModelVersion)
cfg.NumClasses = 80
cfg.NumObjects = 8400
model = yolotriton.NewYoloV8(cfg)
case yolotriton.ModelTypeYoloNAS:
model = yolotriton.NewYoloNAS(FLAGS.ModelName, FLAGS.ModelVersion)
cfg.NumClasses = 80
cfg.NumObjects = 8400
model = yolotriton.NewYoloNAS(cfg)
case yolotriton.ModelTypeYoloNASInt8:
model = yolotriton.NewYoloNASInt8(FLAGS.ModelName, FLAGS.ModelVersion)
model = yolotriton.NewYoloNASInt8(cfg)
default:
log.Fatalf("Unsupported model: %s. Available options: [yolonas, yolonasint8, yolov8]", FLAGS.ModelType)
}
Expand Down
2 changes: 1 addition & 1 deletion yolo.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ type Model interface {
GetConfig() YoloTritonConfig
PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error)
PostProcess(rawOutputContents [][]byte) ([]Box, error)
GetClass(index int) string
}

type YoloTritonConfig struct {
Expand All @@ -32,6 +31,7 @@ type YoloTritonConfig struct {
ModelVersion string
MinProbability float32
MaxIOU float64
Classes []string
}

func New(url string, model Model) (*YoloTriton, error) {
Expand Down
17 changes: 3 additions & 14 deletions yolonas.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,9 @@ type YoloNAS struct {
}
}

func NewYoloNAS(modelName string, modelVersion string) Model {
func NewYoloNAS(cfg YoloTritonConfig) Model {
return &YoloNAS{
YoloTritonConfig: YoloTritonConfig{
NumClasses: 80,
NumObjects: 8400,
MinProbability: 0.5,
MaxIOU: 0.7,
ModelName: modelName,
ModelVersion: modelVersion,
},
YoloTritonConfig: cfg,
}
}

Expand All @@ -35,10 +28,6 @@ func (y *YoloNAS) GetConfig() YoloTritonConfig {
return y.YoloTritonConfig
}

func (y *YoloNAS) GetClass(index int) string {
return yoloClasses[index]
}

func (y *YoloNAS) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
height := img.Bounds().Dy()
width := img.Bounds().Dx()
Expand Down Expand Up @@ -94,7 +83,7 @@ func (y *YoloNAS) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
continue
}

label := y.GetClass(classID)
label := y.Classes[classID]
idx := (index * 4)
x1raw := predBoxes[idx]
y1raw := predBoxes[idx+1]
Expand Down
15 changes: 3 additions & 12 deletions yolonasint8.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,9 @@ type YoloNASInt8 struct {
}
}

func NewYoloNASInt8(modelName string, modelVersion string) Model {
func NewYoloNASInt8(cfg YoloTritonConfig) Model {
return &YoloNASInt8{
YoloTritonConfig: YoloTritonConfig{
MinProbability: 0.5,
MaxIOU: 0.7,
ModelName: modelName,
ModelVersion: modelVersion,
},
YoloTritonConfig: cfg,
}
}

Expand All @@ -33,10 +28,6 @@ func (y *YoloNASInt8) GetConfig() YoloTritonConfig {
return y.YoloTritonConfig
}

func (y *YoloNASInt8) GetClass(index int) string {
return yoloClasses[index]
}

func (y *YoloNASInt8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
height := img.Bounds().Dy()
width := img.Bounds().Dx()
Expand Down Expand Up @@ -89,7 +80,7 @@ func (y *YoloNASInt8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
}

classID := predClasses[index]
label := y.GetClass(int(classID))
label := y.Classes[classID]
idx := (index * 4)
x1raw := predBoxes[idx]
y1raw := predBoxes[idx+1]
Expand Down
17 changes: 3 additions & 14 deletions yolov8.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,9 @@ type YoloV8 struct {
}
}

func NewYoloV8(modelName string, modelVersion string) Model {
func NewYoloV8(cfg YoloTritonConfig) Model {
return &YoloV8{
YoloTritonConfig: YoloTritonConfig{
NumClasses: 80,
NumObjects: 8400,
MinProbability: 0.5,
MaxIOU: 0.7,
ModelName: modelName,
ModelVersion: modelVersion,
},
YoloTritonConfig: cfg,
}
}

Expand All @@ -33,10 +26,6 @@ func (y *YoloV8) GetConfig() YoloTritonConfig {
return y.YoloTritonConfig
}

func (y *YoloV8) GetClass(index int) string {
return yoloClasses[index]
}

func (y *YoloV8) PreProcess(img image.Image, targetWidth uint, targetHeight uint) (*triton.InferTensorContents, error) {
width := img.Bounds().Dx()
height := img.Bounds().Dy()
Expand Down Expand Up @@ -81,7 +70,7 @@ func (y *YoloV8) PostProcess(rawOutputContents [][]byte) ([]Box, error) {
continue
}

label := y.GetClass(classID)
label := y.Classes[classID]
x1raw := output[index]
y1raw := output[numObjects+index]
w := output[2*numObjects+index]
Expand Down

0 comments on commit ba93240

Please sign in to comment.