Changed the way we're specifying the path for image in app.py. Directly mention it while execting now.
This commit is contained in:
+11
-3
@@ -1,10 +1,17 @@
|
|||||||
|
# For executing: python3 app.py /path/to/file
|
||||||
# Import libraries
|
# Import libraries
|
||||||
from PIL import Image, UnidentifiedImageError
|
from PIL import Image, UnidentifiedImageError
|
||||||
from transformers import ViTImageProcessor, ViTForImageClassification
|
from transformers import ViTImageProcessor, ViTForImageClassification
|
||||||
|
import sys
|
||||||
|
|
||||||
# Specify the local directory where the model files are stored
|
# Specify the local directory where the model files are stored
|
||||||
local_model_path = '/home/overnion/Status200/models/pretrained'
|
local_model_path = '/home/overnion/Status200/models/pretrained'
|
||||||
|
|
||||||
|
# Check if the image path is provided
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print("Error: No image path provided. Please provide the path to the image as an argument.")
|
||||||
|
exit()
|
||||||
|
|
||||||
# Load the image processor and model from the local directory
|
# Load the image processor and model from the local directory
|
||||||
image_processor = ViTImageProcessor.from_pretrained(local_model_path)
|
image_processor = ViTImageProcessor.from_pretrained(local_model_path)
|
||||||
model = ViTForImageClassification.from_pretrained(
|
model = ViTForImageClassification.from_pretrained(
|
||||||
@@ -13,10 +20,11 @@ model = ViTForImageClassification.from_pretrained(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Load image
|
# Load image
|
||||||
|
image_path = sys.argv[1] # Get the image path from command line arguments
|
||||||
try:
|
try:
|
||||||
image = Image.open('/home/overnion/Status200/models/samples/apple.png')
|
image = Image.open(image_path)
|
||||||
# Convert the image to RGB if it's not already
|
# Convert the image to RGB if it's not already
|
||||||
if (image.mode != 'RGB'):
|
if image.mode != 'RGB':
|
||||||
image = image.convert('RGB')
|
image = image.convert('RGB')
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print("Error: Image file not found.")
|
print("Error: Image file not found.")
|
||||||
@@ -37,5 +45,5 @@ logits = outputs.logits
|
|||||||
predicted_class_idx = logits.argmax(-1).item()
|
predicted_class_idx = logits.argmax(-1).item()
|
||||||
|
|
||||||
# Print the predicted class
|
# Print the predicted class
|
||||||
print("Predicted class:", model.config.id2label[predicted_class_idx])
|
print(model.config.id2label[predicted_class_idx])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user