Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom filters for processing accelerators #395

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,16 @@ Each `ghw.AcceleratorDevice` struct contains the following fields:
describing the processing accelerator card. This may be `nil` if no PCI device
information could be determined for the card.

#### filters
The `ghw.Accelerator()` function accepts a slice of filters, of type string, as parameter
in format `[<vendor>]:[<device>][:<class>]`, (same is the _lspci_ command).

Some filter examples:
* `::0302`. Select 3D controller cards.
* `10de::0302`. Select Nvidia (`10de`) 3D controller cards (`0302`).
* `1da3:1060:1200`. Select Habana Labs (`1da3`) Gaudi3 (`1060`) processing accelerator cards (`1200`).
* `1002::`. Select AMD ATI hardware.

```go
package main

Expand All @@ -976,7 +986,11 @@ import (
)

func main() {
accel, err := ghw.Accelerator()
filter := make([]string, 0)
// example of a filter to detect 3D controllers
// filter = append(filter, "::0302")

accel, err := ghw.Accelerator(filter)
if err != nil {
fmt.Printf("Error getting processing accelerator info: %v", err)
}
Expand Down
4 changes: 3 additions & 1 deletion cmd/ghwc/commands/accelerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ var acceleratorCmd = &cobra.Command{

// showAccelerator show processing accelerators information for the host system.
func showAccelerator(cmd *cobra.Command, args []string) error {
accel, err := ghw.Accelerator()
filter := make([]string, 0)

accel, err := ghw.Accelerator(filter)
if err != nil {
return errors.Wrap(err, "error getting Accelerator info")
}
Expand Down
2 changes: 1 addition & 1 deletion host.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func Host(opts ...*WithOption) (*HostInfo, error) {
if err != nil {
return nil, err
}
acceleratorInfo, err := accelerator.New(opts...)
acceleratorInfo, err := accelerator.New([]string{}, opts...)
if err != nil {
return nil, err
}
Expand Down
12 changes: 8 additions & 4 deletions pkg/accelerator/accelerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,19 @@ func (dev *AcceleratorDevice) String() string {
}

type Info struct {
ctx *context.Context
Devices []*AcceleratorDevice `json:"devices"`
ctx *context.Context
Devices []*AcceleratorDevice `json:"devices"`
DiscoveryFilters []string
}

// New returns a pointer to an Info struct that contains information about the
// accelerator devices on the host system
func New(opts ...*option.Option) (*Info, error) {
func New(filter []string, opts ...*option.Option) (*Info, error) {
ctx := context.New(opts...)
info := &Info{ctx: ctx}
info := &Info{
ctx: ctx,
DiscoveryFilters: filter,
}

if err := ctx.Do(info.load); err != nil {
return nil, err
Expand Down
68 changes: 63 additions & 5 deletions pkg/accelerator/accelerator_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
package accelerator

import (
"github.com/samber/lo"
"fmt"
"strings"

"github.com/jaypipes/ghw/pkg/context"
"github.com/jaypipes/ghw/pkg/pci"
"github.com/samber/lo"
)

// PCI IDs list available at https://admin.pci-ids.ucw.cz/read/PD
Expand Down Expand Up @@ -60,13 +62,69 @@ func (i *Info) load() error {
if !isAccelerator(device) {
continue
}
accelDev := &AcceleratorDevice{
Address: device.Address,
PCIDevice: device,
if len(i.DiscoveryFilters) > 0 {
for _, filter := range i.DiscoveryFilters {
if validate(filter, device) {
accelDev := &AcceleratorDevice{
Address: device.Address,
PCIDevice: device,
}
accelDevices = append(accelDevices, accelDev)
break
}
}
} else {
accelDev := &AcceleratorDevice{
Address: device.Address,
PCIDevice: device,
}
accelDevices = append(accelDevices, accelDev)
}
accelDevices = append(accelDevices, accelDev)
}

i.Devices = accelDevices
return nil
}

// validate checks if a given PCI device matches the provided filter string.
//
// The filter string is expected to be in the format "VendorID:ProductID:Class+Subclass".
// Each part of the filter (VendorID, ProductID, Class+Subclass) is optional and can be
// left empty, in which case the corresponding attribute is ignored during validation.
//
// Parameters:
// - filter: A string in the form "VendorID:ProductID:Class+Subclass", where
// any part of the string may be empty to represent a wildcard match.
// - device: A pointer to a `pci.Device` structure.
//
// Returns:
// - true: If the device matches the filter criteria (wildcards are supported).
// - false: If the device does not match the filter criteria.
//
// Matching criteria:
// - VendorID must match `device.Vendor.ID` if provided.
// - ProductID must match `device.Product.ID` if provided.
// - Class and Subclass must match the concatenated result of `device.Class.ID` and `device.Subclass.ID` if provided.
//
// Example:
//
// filter := "8086:1234:1200"
// device := pci.Device{Vendor: Vendor{ID: "8086"}, Product: Product{ID: "1234"}, Class: Class{ID: "12"}, Subclass: Subclass{ID: "00"}}
// isValid := validate(filter, &device) // returns true
//
// filter := "8086::1200" // Wildcard for ProductID
// isValid := validate(filter, &device) // returns true
//
// filter := "::1200" // Wildcard for ProductID and VendorID
// isValid := validate(filter, &device) // returns true
func validate(filter string, device *pci.Device) bool {
ids := strings.Split(filter, ":")

if (ids[0] == "" || ids[0] == device.Vendor.ID) &&
(len(ids) < 2 || (ids[1] == "" || ids[1] == device.Product.ID)) &&
(len(ids) < 3 || (ids[2] == "" || ids[2] == fmt.Sprintf("%s%s", device.Class.ID, device.Subclass.ID))) {
return true
}

return false
}
21 changes: 17 additions & 4 deletions pkg/accelerator/accelerator_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/jaypipes/ghw/testdata"
)

func testScenario(t *testing.T, filename string, expectedDevs int) {
func testScenario(t *testing.T, filename string, hwFilter []string, expectedDevs int) {
testdataPath, err := testdata.SnapshotsDirectory()
if err != nil {
t.Fatalf("Expected nil err, but got %v", err)
Expand All @@ -41,7 +41,7 @@ func testScenario(t *testing.T, filename string, expectedDevs int) {
_ = snapshot.Cleanup(tmpRoot)
}()

info, err := accelerator.New(option.WithChroot(tmpRoot))
info, err := accelerator.New(hwFilter, option.WithChroot(tmpRoot))
if err != nil {
t.Fatalf("Expected nil err, but got %v", err)
}
Expand All @@ -59,7 +59,7 @@ func TestAcceleratorDefault(t *testing.T) {
}

// In this scenario we have 1 processing accelerator device
testScenario(t, "linux-amd64-accel.tar.gz", 1)
testScenario(t, "linux-amd64-accel.tar.gz", []string{}, 1)

}

Expand All @@ -69,5 +69,18 @@ func TestAcceleratorNvidia(t *testing.T) {
}

// In this scenario we have 1 Nvidia 3D controller device
testScenario(t, "linux-amd64-accel-nvidia.tar.gz", 1)
testScenario(t, "linux-amd64-accel-nvidia.tar.gz", []string{}, 1)
}

func TestAcceleratorFilter(t *testing.T) {
if _, ok := os.LookupEnv("GHW_TESTING_SKIP_ACCELERATOR"); ok {
t.Skip("Skipping PCI tests.")
}

// Set the filter to detect only processing accelerators (Nvidia not included)
discoveryFilter := make([]string, 0)
discoveryFilter = append(discoveryFilter, "::1200")

// In this scenario we have 1 Nvidia 3D controller device
testScenario(t, "linux-amd64-accel-nvidia.tar.gz", discoveryFilter, 0)
}
Loading