Skip to content

Commit

Permalink
Merge pull request #91 from abdelaziz-mahdy/load-labels-from-absoulte
Browse files Browse the repository at this point in the history
feat: add LabelsLocation enum and update labels loading methods
  • Loading branch information
abdelaziz-mahdy authored Jan 9, 2025
2 parents e215703 + 6a39d74 commit af69bae
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 14 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 4.3.2

- Add LabelsLocation enum and update labels loading methods

## 4.3.1+1

- fixed example
Expand Down
26 changes: 25 additions & 1 deletion lib/enums/model_type.dart
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,28 @@ enum CameraPreProcessingMethod {
byteList,
}

enum ModelLocation { asset, path }
/// Enum representing the location of the model.
///
/// `asset` indicates that the model is located in the application's assets.
///
/// `path` indicates that the model is located at a specific file path.
enum ModelLocation {
/// `asset` indicates that the model is located in the application's assets.
asset,

/// `path` indicates that the model is located at a specific file path.
path
}

/// Enum representing the location of the labels.
///
/// `asset` indicates that the labels are located in the application's assets.
///
/// `path` indicates that the labels are located at a specific file path.
enum LabelsLocation {
/// `asset` indicates that the model is located in the application's assets.
asset,

/// `path` indicates that the model is located at a specific file path.
path
}
44 changes: 32 additions & 12 deletions lib/pytorch_lite.dart
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,22 @@ class PytorchLite {
String path, int imageWidth, int imageHeight, int numberOfClasses,
{String? labelPath,
bool ensureMatchingNumberOfClasses = true,
ModelLocation modelLocation = ModelLocation.asset}) async {
ModelLocation modelLocation = ModelLocation.asset,
LabelsLocation labelsLocation = LabelsLocation.asset}) async {
if (modelLocation == ModelLocation.asset) {
path = await _getAbsolutePath(path);
}

int index =
await ModelApi().loadModel(path, null, imageWidth, imageHeight, null);
List<String> labels = [];
if (labelPath != null) {
String labelData =
await _loadLabelsFile(labelPath, labelsLocation: labelsLocation);
if (labelPath.endsWith(".txt")) {
labels = await _getLabelsTxt(labelPath);
labels = await _getLabelsTxt(labelData);
} else {
labels = await _getLabelsCsv(labelPath);
labels = await _getLabelsCsv(labelData);
}
if (ensureMatchingNumberOfClasses) {
if (labels.length != numberOfClasses) {
Expand All @@ -67,7 +71,8 @@ class PytorchLite {
{String? labelPath,
ObjectDetectionModelType objectDetectionModelType =
ObjectDetectionModelType.yolov5,
ModelLocation modelLocation = ModelLocation.asset}) async {
ModelLocation modelLocation = ModelLocation.asset,
LabelsLocation labelsLocation = LabelsLocation.asset}) async {
if (modelLocation == ModelLocation.asset) {
path = await _getAbsolutePath(path);
}
Expand All @@ -76,10 +81,12 @@ class PytorchLite {
imageHeight, objectDetectionModelType.index);
List<String> labels = [];
if (labelPath != null) {
String labelData =
await _loadLabelsFile(labelPath, labelsLocation: labelsLocation);
if (labelPath.endsWith(".txt")) {
labels = await _getLabelsTxt(labelPath);
labels = await _getLabelsTxt(labelData);
} else {
labels = await _getLabelsCsv(labelPath);
labels = await _getLabelsCsv(labelData);
}
}
return ModelObjectDetection(index, imageWidth, imageHeight, labels,
Expand Down Expand Up @@ -110,18 +117,31 @@ class PytorchLite {
}
}

Future<String> _loadLabelsFile(String labelPath,
{LabelsLocation labelsLocation = LabelsLocation.asset}) async {
String labelsData;
if (labelsLocation == LabelsLocation.asset) {
labelsData = await rootBundle.loadString(labelPath);
} else {
labelsData = await File(labelPath).readAsString();
}
return labelsData;
}

///get labels in csv format
///labels are separated by commas
Future<List<String>> _getLabelsCsv(String labelPath) async {
String labelsData = await rootBundle.loadString(labelPath);
return labelsData.split(",");
Future<List<String>> _getLabelsCsv(
String fileContent,
) async {
return fileContent.split(",");
}

///get labels in txt format
///each line is a label
Future<List<String>> _getLabelsTxt(String labelPath) async {
String labelsData = await rootBundle.loadString(labelPath);
return labelsData.split("\n");
Future<List<String>> _getLabelsTxt(
String fileContent,
) async {
return fileContent.split("\n");
}

/*
Expand Down
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: pytorch_lite
description: Flutter package to help run pytorch lite models classification and yolov5 and yolov8
version: 4.3.1+1
version: 4.3.2
homepage: https://github.com/abdelaziz-mahdy/pytorch_lite

environment:
Expand Down

0 comments on commit af69bae

Please sign in to comment.