diff --git a/README.md b/README.md index bad7c86e..e9fa79df 100644 --- a/README.md +++ b/README.md @@ -966,6 +966,16 @@ Each `ghw.AcceleratorDevice` struct contains the following fields: describing the processing accelerator card. This may be `nil` if no PCI device information could be determined for the card. +#### filters +The `ghw.Accelerator()` function accepts a slice of filters, of type string, as parameter +in format `[]:[][:]`, (same is the _lspci_ command). + +Some filter examples: +* `::0302`. Select 3D controller cards. +* `10de::0302`. Select Nvidia (`10de`) 3D controller cards (`0302`). +* `1da3:1060:1200`. Select Habana Labs (`1da3`) Gaudi3 (`1060`) processing accelerator cards (`1200`). +* `1002::`. Select AMD ATI hardware. + ```go package main @@ -976,7 +986,11 @@ import ( ) func main() { - accel, err := ghw.Accelerator() + filter := make([]string, 0) + // example of a filter to detect 3D controllers + // filter = append(filter, "::0302") + + accel, err := ghw.Accelerator(filter) if err != nil { fmt.Printf("Error getting processing accelerator info: %v", err) } diff --git a/cmd/ghwc/commands/accelerator.go b/cmd/ghwc/commands/accelerator.go index 24cf28dc..ca993cc2 100644 --- a/cmd/ghwc/commands/accelerator.go +++ b/cmd/ghwc/commands/accelerator.go @@ -23,7 +23,9 @@ var acceleratorCmd = &cobra.Command{ // showAccelerator show processing accelerators information for the host system. func showAccelerator(cmd *cobra.Command, args []string) error { - accel, err := ghw.Accelerator() + filter := make([]string, 0) + + accel, err := ghw.Accelerator(filter) if err != nil { return errors.Wrap(err, "error getting Accelerator info") } diff --git a/host.go b/host.go index 89b1ad27..c96d385a 100644 --- a/host.go +++ b/host.go @@ -73,7 +73,7 @@ func Host(opts ...*WithOption) (*HostInfo, error) { if err != nil { return nil, err } - acceleratorInfo, err := accelerator.New(opts...) + acceleratorInfo, err := accelerator.New([]string{}, opts...) if err != nil { return nil, err } diff --git a/pkg/accelerator/accelerator.go b/pkg/accelerator/accelerator.go index b51ef2e2..7be9c8b0 100644 --- a/pkg/accelerator/accelerator.go +++ b/pkg/accelerator/accelerator.go @@ -37,15 +37,19 @@ func (dev *AcceleratorDevice) String() string { } type Info struct { - ctx *context.Context - Devices []*AcceleratorDevice `json:"devices"` + ctx *context.Context + Devices []*AcceleratorDevice `json:"devices"` + DiscoveryFilters []string } // New returns a pointer to an Info struct that contains information about the // accelerator devices on the host system -func New(opts ...*option.Option) (*Info, error) { +func New(filter []string, opts ...*option.Option) (*Info, error) { ctx := context.New(opts...) - info := &Info{ctx: ctx} + info := &Info{ + ctx: ctx, + DiscoveryFilters: filter, + } if err := ctx.Do(info.load); err != nil { return nil, err diff --git a/pkg/accelerator/accelerator_linux.go b/pkg/accelerator/accelerator_linux.go index 67b9aa3a..fb4f4a57 100644 --- a/pkg/accelerator/accelerator_linux.go +++ b/pkg/accelerator/accelerator_linux.go @@ -6,10 +6,12 @@ package accelerator import ( - "github.com/samber/lo" + "fmt" + "strings" "github.com/jaypipes/ghw/pkg/context" "github.com/jaypipes/ghw/pkg/pci" + "github.com/samber/lo" ) // PCI IDs list available at https://admin.pci-ids.ucw.cz/read/PD @@ -60,13 +62,69 @@ func (i *Info) load() error { if !isAccelerator(device) { continue } - accelDev := &AcceleratorDevice{ - Address: device.Address, - PCIDevice: device, + if len(i.DiscoveryFilters) > 0 { + for _, filter := range i.DiscoveryFilters { + if validate(filter, device) { + accelDev := &AcceleratorDevice{ + Address: device.Address, + PCIDevice: device, + } + accelDevices = append(accelDevices, accelDev) + break + } + } + } else { + accelDev := &AcceleratorDevice{ + Address: device.Address, + PCIDevice: device, + } + accelDevices = append(accelDevices, accelDev) } - accelDevices = append(accelDevices, accelDev) } i.Devices = accelDevices return nil } + +// validate checks if a given PCI device matches the provided filter string. +// +// The filter string is expected to be in the format "VendorID:ProductID:Class+Subclass". +// Each part of the filter (VendorID, ProductID, Class+Subclass) is optional and can be +// left empty, in which case the corresponding attribute is ignored during validation. +// +// Parameters: +// - filter: A string in the form "VendorID:ProductID:Class+Subclass", where +// any part of the string may be empty to represent a wildcard match. +// - device: A pointer to a `pci.Device` structure. +// +// Returns: +// - true: If the device matches the filter criteria (wildcards are supported). +// - false: If the device does not match the filter criteria. +// +// Matching criteria: +// - VendorID must match `device.Vendor.ID` if provided. +// - ProductID must match `device.Product.ID` if provided. +// - Class and Subclass must match the concatenated result of `device.Class.ID` and `device.Subclass.ID` if provided. +// +// Example: +// +// filter := "8086:1234:1200" +// device := pci.Device{Vendor: Vendor{ID: "8086"}, Product: Product{ID: "1234"}, Class: Class{ID: "12"}, Subclass: Subclass{ID: "00"}} +// isValid := validate(filter, &device) // returns true +// +// filter := "8086::1200" // Wildcard for ProductID +// isValid := validate(filter, &device) // returns true +// +// filter := "::1200" // Wildcard for ProductID and VendorID +// isValid := validate(filter, &device) // returns true +func validate(filter string, device *pci.Device) bool { + ids := strings.Split(filter, ":") + + if (ids[0] == "" || ids[0] == device.Vendor.ID) && + (len(ids) < 2 || (ids[1] == "" || ids[1] == device.Product.ID)) && + (len(ids) < 3 || (ids[2] == "" || ids[2] == fmt.Sprintf("%s%s", device.Class.ID, device.Subclass.ID))) { + return true + } + + return false +} diff --git a/pkg/accelerator/accelerator_linux_test.go b/pkg/accelerator/accelerator_linux_test.go index 6a42c487..76a0b863 100644 --- a/pkg/accelerator/accelerator_linux_test.go +++ b/pkg/accelerator/accelerator_linux_test.go @@ -17,7 +17,7 @@ import ( "github.com/jaypipes/ghw/testdata" ) -func testScenario(t *testing.T, filename string, expectedDevs int) { +func testScenario(t *testing.T, filename string, hwFilter []string, expectedDevs int) { testdataPath, err := testdata.SnapshotsDirectory() if err != nil { t.Fatalf("Expected nil err, but got %v", err) @@ -41,7 +41,7 @@ func testScenario(t *testing.T, filename string, expectedDevs int) { _ = snapshot.Cleanup(tmpRoot) }() - info, err := accelerator.New(option.WithChroot(tmpRoot)) + info, err := accelerator.New(hwFilter, option.WithChroot(tmpRoot)) if err != nil { t.Fatalf("Expected nil err, but got %v", err) } @@ -59,7 +59,7 @@ func TestAcceleratorDefault(t *testing.T) { } // In this scenario we have 1 processing accelerator device - testScenario(t, "linux-amd64-accel.tar.gz", 1) + testScenario(t, "linux-amd64-accel.tar.gz", []string{}, 1) } @@ -69,5 +69,18 @@ func TestAcceleratorNvidia(t *testing.T) { } // In this scenario we have 1 Nvidia 3D controller device - testScenario(t, "linux-amd64-accel-nvidia.tar.gz", 1) + testScenario(t, "linux-amd64-accel-nvidia.tar.gz", []string{}, 1) +} + +func TestAcceleratorFilter(t *testing.T) { + if _, ok := os.LookupEnv("GHW_TESTING_SKIP_ACCELERATOR"); ok { + t.Skip("Skipping PCI tests.") + } + + // Set the filter to detect only processing accelerators (Nvidia not included) + discoveryFilter := make([]string, 0) + discoveryFilter = append(discoveryFilter, "::1200") + + // In this scenario we have 1 Nvidia 3D controller device + testScenario(t, "linux-amd64-accel-nvidia.tar.gz", discoveryFilter, 0) }