Mehyaar's picture
Upload 2 files
debf278 verified
Raw
History Blame Contribute Delete
2.98 kB
import traceback

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
# Inference-time preprocessing pipeline:
#   1. resize so the shorter image edge is 256 px,
#   2. take the central 224x224 crop (the input size the classifier expects),
#   3. convert to a float tensor,
#   4. normalize each RGB channel with the mean/std below (the commonly used
#      ImageNet channel statistics).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
transform_test = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

# Human-readable category labels; position i names the classifier's i-th output.
class_names = [
    'Auto Rickshaws',
    'Bikes',
    'Cars',
    'Motorcycles',
    'Planes',
    'Ships',
    'Trains',
]
class VehicleClassifier(nn.Module):
    """Small CNN that maps a 3-channel 224x224 image to class logits.

    Four stages of conv(3x3, padding=1) -> ReLU -> 2x2 max-pool raise the
    channel count 3 -> 32 -> 64 -> 128 -> 256 while halving the spatial size
    each time (224 -> 112 -> 56 -> 28 -> 14). Three fully connected layers,
    with dropout after the first two, then produce the logits.

    Args:
        num_classes: Number of output logits. Defaults to 7, one per vehicle
            category used by this app (and the size of the saved checkpoint).
    """

    def __init__(self, num_classes: int = 7):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        # One parameter-free pooling module, reused after every conv stage.
        self.pool = nn.MaxPool2d(2, 2)
        # Classifier head. fc1's input size (256 channels x 14 x 14) is what a
        # 224x224 input leaves after four 2x poolings.
        self.fc1 = nn.Linear(256 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes).

        Args:
            x: Float tensor of shape (batch, 3, 224, 224).
        """
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.pool(F.relu(conv(x)))
        # Flatten per sample while keeping the batch dimension. With an input of
        # the wrong spatial size this yields a mismatch that fc1 reports as an
        # error, instead of silently regrouping pixels into a different batch size.
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        return self.fc3(x)
# Build the network on the CPU and restore the trained weights. map_location
# remaps tensors that were saved from another device onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# on torch>=1.13, passing weights_only=True restricts what can be deserialized.
_cpu = torch.device('cpu')
model = VehicleClassifier()
model.to(_cpu)
model.load_state_dict(torch.load('vehicle_classifier.pth', map_location=_cpu))
# Inference mode: disables dropout.
model.eval()
def predict(image):
    """Classify one image and return the predicted vehicle category name.

    Args:
        image: Path to an image file (what ``gr.Image(type='filepath')``
            passes), or an already-opened ``PIL.Image.Image``.

    Returns:
        The entry of ``class_names`` with the highest logit. If anything fails,
        the error and its traceback are printed and the string
        "An error occurred during prediction." is returned instead.
    """
    try:
        if isinstance(image, Image.Image):
            rgb = image.convert('RGB')
        else:
            # Open in a context manager so the file handle is released right
            # away; convert() returns a new, fully loaded image we can keep.
            with Image.open(image) as img:
                rgb = img.convert('RGB')
        batch = transform_test(rgb).unsqueeze(0)  # Add batch dimension
        print(f"Image shape after transformation: {batch.shape}")
        with torch.no_grad():
            outputs = model(batch)
        print(f"Model output: {outputs}")
        _, predicted = torch.max(outputs, 1)
        prediction = class_names[predicted.item()]
        print(f"Predicted class: {prediction}")
        return prediction
    except Exception as e:
        # Top-level boundary of the UI callback: report the failure and return a
        # message rather than letting the exception escape to the UI.
        print(f"Error during prediction: {e}")
        traceback.print_exc()
        return "An error occurred during prediction."
# Gradio UI: one image upload (handed to `predict` as a file path) mapped to a
# label component that shows the single top predicted class.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='filepath'),
    outputs=gr.Label(num_top_classes=1),
    title="Vehicle Classification",
    description="Upload an image of a vehicle, and the model will predict its type.",
)

# Launch the interface only when executed as a script, so importing this module
# (e.g. to reuse `predict` or `model`) does not start a web server.
# NOTE(review): share=True asks Gradio for a publicly accessible share link.
if __name__ == "__main__":
    interface.launch(share=True)