Skip to content
Snippets Groups Projects
Commit c770bb46 authored by haonan2's avatar haonan2
Browse files

Upload New File

parent 8ab4ec49
Branches main
No related tags found
No related merge requests found
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from utils.lane_detector import LaneDetector
from models.enet import ENet
import os
# Define dataset and checkpoint paths
DATASET_PATH = "/opt/data/TUSimple/test_set"
CHECKPOINT_PATH = "checkpoints/enet_checkpoint_epoch_best.pth" # Path to the trained model checkpoint
# Function to load the ENet model
def load_enet_model(checkpoint_path, device="cuda"):
enet_model = ENet(binary_seg=2, embedding_dim=4).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
enet_model.load_state_dict(checkpoint['model_state_dict'])
enet_model.eval()
return enet_model
def perspective_transform(image):
"""
Transform an image into a bird's eye view.
1. Calculate the image height and width.
2. Define source points on the original image and corresponding destination points.
3. Compute the perspective transform matrix using cv2.getPerspectiveTransform.
4. Warp the original image using cv2.warpPerspective to get the transformed output.
"""
####################### TODO: Your code starts Here #######################
####################### TODO: Your code ends Here #######################
return transformed_image
# Function to visualize lane predictions for multiple images in a single row
def visualize_lanes_row(images, instances_maps, alpha=0.7):
"""
Visualize lane predictions for multiple images in a single row
For each image:
1. Resize it to 512 x 256 for consistent visualization.
2. Apply perspective transform to both the original image and its instance map.
3. Overlay the instance map to a plot with the corresponding original image using a specified alpha value.
"""
num_images = len(images)
fig, axes = plt.subplots(1, num_images, figsize=(15, 5))
####################### TODO: Your code starts Here #######################
####################### TODO: Your code ends Here #######################
plt.tight_layout()
plt.show()
def main():
# Initialize device and model
device = "cuda" if torch.cuda.is_available() else "cpu"
enet_model = load_enet_model(CHECKPOINT_PATH, device)
lane_predictor = LaneDetector(enet_model, device=device)
# List of test image paths
sub_paths = [
"clips/0530/1492626047222176976_0/20.jpg",
"clips/0530/1492626286076989589_0/20.jpg",
"clips/0531/1492626674406553912/20.jpg",
"clips/0601/1494452381594376146/20.jpg",
"clips/0601/1494452431571697487/20.jpg"
]
test_image_paths = [os.path.join(DATASET_PATH, sub_path) for sub_path in sub_paths]
# Load and process images
images = []
instances_maps = []
for path in test_image_paths:
image = cv2.imread(path)
if image is None:
print(f"Error: Unable to load image at {path}")
continue
print(f"Processing image: {path}")
instances_map = lane_predictor(image)
images.append(image)
instances_maps.append(instances_map)
# Visualize all lane predictions in a single row
if images and instances_maps:
visualize_lanes_row(images, instances_maps)
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment