diff --git a/models/app.py b/models/app.py index cfacb75..db271c0 100644 --- a/models/app.py +++ b/models/app.py @@ -1,10 +1,17 @@ +# For executing: python3 app.py /path/to/file # Import libraries from PIL import Image, UnidentifiedImageError from transformers import ViTImageProcessor, ViTForImageClassification +import sys # Specify the local directory where the model files are stored local_model_path = '/home/overnion/Status200/models/pretrained' +# Check if the image path is provided +if len(sys.argv) < 2: + print("Error: No image path provided. Please provide the path to the image as an argument.") + exit() + # Load the image processor and model from the local directory image_processor = ViTImageProcessor.from_pretrained(local_model_path) model = ViTForImageClassification.from_pretrained( @@ -13,10 +20,11 @@ model = ViTForImageClassification.from_pretrained( ) # Load image +image_path = sys.argv[1] # Get the image path from command line arguments try: - image = Image.open('/home/overnion/Status200/models/samples/apple.png') + image = Image.open(image_path) # Convert the image to RGB if it's not already - if (image.mode != 'RGB'): + if image.mode != 'RGB': image = image.convert('RGB') except FileNotFoundError: print("Error: Image file not found.") @@ -37,5 +45,5 @@ logits = outputs.logits predicted_class_idx = logits.argmax(-1).item() # Print the predicted class -print("Predicted class:", model.config.id2label[predicted_class_idx]) +print(model.config.id2label[predicted_class_idx])