221 lines
7.3 KiB
Python
221 lines
7.3 KiB
Python
import random
|
|
import os
|
|
from efficientnet_pytorch import EfficientNet
|
|
from PIL import Image
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from tqdm import tqdm
|
|
import time
|
|
import torch.nn as nn
|
|
import matplotlib
|
|
import cv2
|
|
from torch.utils.data import WeightedRandomSampler
|
|
import numpy as np, torch
|
|
import albumentations as A
|
|
from albumentations.pytorch import ToTensorV2
|
|
matplotlib.use('TkAgg')
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Random seeding
def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and CUDA), and puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    # NOTE: the running interpreter reads PYTHONHASHSEED only at startup, so
    # this assignment affects child processes launched afterwards, not the
    # current process's str hashing.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Prefer reproducible cuDNN kernels over autotuned (faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(0)
|
|
|
|
# Side length (pixels) of the square model input; every resize/pad below
# uses it so the input size is changed in one place.
HW = 260

# Gray fill used for letterbox padding and for rotation/shift borders.
PAD_VALUE = (114, 114, 114)
# Standard ImageNet per-channel mean / std used by A.Normalize.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# NOTE(review): recent albumentations releases rename the `value` argument of
# PadIfNeeded/ShiftScaleRotate to `fill`, and ImageCompression's
# `quality_lower`/`quality_upper` to `quality_range`; confirm against the
# installed version before upgrading.

# Training pipeline: resize the longest side to HW, pad to HW x HW, apply mild
# geometric and photometric augmentation, normalize, convert to a tensor.
# NOTE: there is no horizontal flip — presumably deliberate, since mirroring
# would turn "front-left" views into "front-right" ones and corrupt labels.
train_transform = A.Compose([
    A.LongestMaxSize(max_size=HW, interpolation=cv2.INTER_AREA),
    A.PadIfNeeded(min_height=HW, min_width=HW, border_mode=cv2.BORDER_CONSTANT, value=PAD_VALUE),
    # Small shift/scale and up to +/-15 degrees of rotation, half the time.
    A.ShiftScaleRotate(
        shift_limit=0.02,
        scale_limit=0.05,
        rotate_limit=15,
        interpolation=cv2.INTER_LINEAR,
        border_mode=cv2.BORDER_CONSTANT, value=PAD_VALUE,
        p=0.5,
    ),
    A.ColorJitter(
        brightness=0.3,
        contrast=0.3,
        saturation=0.15,
        hue=0.05,
        p=0.5,
    ),
    # Lossy re-compression artifacts (quality 75-100) on 20% of samples.
    A.ImageCompression(quality_lower=75, quality_upper=100, p=0.2),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Validation pipeline: the same resize/pad/normalize steps, no augmentation.
val_transform = A.Compose([
    A.LongestMaxSize(max_size=HW, interpolation=cv2.INTER_AREA),
    A.PadIfNeeded(min_height=HW, min_width=HW, border_mode=cv2.BORDER_CONSTANT, value=PAD_VALUE),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])
|
|
|
|
def read_file(path):
    """Collect image paths and integer labels from a class-per-folder split.

    Every immediate sub-directory of *path* whose name is a key of
    ``label_map`` holds the images of that class; any other entry is ignored.
    Folder and file listings are sorted so the sample order (and therefore any
    seeded sampling over indices) does not depend on the filesystem.

    Args:
        path: Root directory of one dataset split (e.g. ``train`` or ``val``).

    Returns:
        ``(img_paths, labels)``: a list of file paths and a ``torch.long``
        tensor of the same length holding each file's class index.
    """
    label_map = {
        "front": 0,
        "rear": 1,
        "side": 2,
        "front-left": 3,
        "front-right": 4,
        "rear-left": 5,
        "rear-right": 6,
        "others": 7,
    }
    img_paths, labels = [], []
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for folder_name in tqdm(sorted(os.listdir(path))):
        folder_path = os.path.join(path, folder_name)
        if not os.path.isdir(folder_path) or folder_name not in label_map:
            continue
        label = label_map[folder_name]
        for img_name in sorted(os.listdir(folder_path)):
            img_path = os.path.join(folder_path, img_name)
            # Nested directories cannot be opened as images; skip them here
            # rather than failing later inside the Dataset.
            if not os.path.isfile(img_path):
                continue
            img_paths.append(img_path)
            labels.append(label)
    # Message reads "loaded %d training samples" (also printed for the val split).
    print("读到了%d个训练数据" % len(labels))
    return img_paths, torch.tensor(labels, dtype=torch.long)
|
|
|
|
class Carview_Dataset(Dataset):
    """Car-viewpoint image dataset backed by a class-per-folder directory.

    Args:
        path: Split root directory, scanned with ``read_file``.
        mode: ``"train"`` selects ``train_transform`` (with augmentation);
            any other value selects ``val_transform``.
    """

    def __init__(self, path, mode="train"):
        self.paths, self.Y = read_file(path)
        self.transform = train_transform if mode == "train" else val_transform

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx* after the transform."""
        # The context manager closes the file handle deterministically instead
        # of relying on Pillow/garbage collection to release it.
        with Image.open(self.paths[idx]) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        if self.transform:
            img = self.transform(image=img)['image']
        return img, self.Y[idx]

    def __len__(self):
        """Number of samples in the split."""
        return len(self.Y)
|
|
|
|
|
|
# Dataset split roots (machine-specific absolute paths).
train_path = r"D:\jili\car_viewpoint_classification\dataset\train"
val_path = r"D:\jili\car_viewpoint_classification\dataset\val"

train_set = Carview_Dataset(train_path, "train")
val_set = Carview_Dataset(val_path, "val")

# Class-balanced sampling: weight every training sample by the inverse count
# of its class, so each class present in the training set is drawn about
# equally often per epoch.
labels = train_set.Y.cpu().numpy()
class_counts = np.bincount(labels, minlength=8)
# The epsilon only avoids division by zero for empty classes; their weight is
# never looked up because no sample carries that label.
class_weights = 1.0 / (class_counts + 1e-6)
sample_weights = class_weights[labels]

sampler = WeightedRandomSampler(
    weights=torch.from_numpy(sample_weights).double(),
    num_samples=len(sample_weights),  # draw this many samples per epoch
    replacement=True,  # with replacement so minority classes can be over-sampled
)

# Pinned host memory only helps when batches are copied to a CUDA device;
# on a CPU-only machine requesting it merely produces a warning.
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_set, batch_size=16, sampler=sampler,
    num_workers=0, pin_memory=pin_memory,
)

# Validation keeps a fixed order for evaluation.
val_loader = DataLoader(
    val_set, batch_size=16, shuffle=False,
    num_workers=0, pin_memory=pin_memory,
)
|
|
|
|
|
|
def build_efficientnet_model(num_classes=8):
    """Return a pretrained EfficientNet-B2 with a fresh classification head.

    The weights come from ``EfficientNet.from_pretrained('efficientnet-b2')``.
    Its final fully-connected layer ``_fc`` is replaced by a newly initialised
    ``nn.Linear`` that keeps the same input width and outputs *num_classes*
    logits.
    """
    net = EfficientNet.from_pretrained('efficientnet-b2')
    net._fc = nn.Linear(net._fc.in_features, num_classes)
    return net
|
|
|
|
|
|
def _run_epoch(model, loader, device, loss, optimizer=None):
    """Run one pass over *loader*; return ``(mean_batch_loss, accuracy)``.

    If *optimizer* is given, the model is put in training mode and updated
    after every batch. Otherwise it is put in eval mode and run with gradients
    disabled. Accuracy is divided by the number of samples actually seen, so it
    stays correct when a sampler draws a different count than
    ``len(loader.dataset)``. An empty loader yields ``(0.0, 0.0)``.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen, n_batches = 0.0, 0, 0, 0
    with torch.set_grad_enabled(training):
        for batch_x, batch_y in loader:
            x, target = batch_x.to(device), batch_y.to(device)
            pred = model(x)  # forward pass
            batch_loss = loss(pred, target)  # loss between predictions and labels
            if training:
                optimizer.zero_grad()  # clear stale gradients
                batch_loss.backward()  # back-propagate this batch's gradients
                optimizer.step()  # update parameters
            total_loss += batch_loss.item()
            correct += (torch.argmax(pred, dim=1) == target).sum().item()
            seen += target.size(0)
            n_batches += 1
    return total_loss / max(n_batches, 1), correct / max(seen, 1)


def _plot_history(history):
    """Show train-vs-val curves for loss and for accuracy, one figure each."""
    for metric in ("loss", "acc"):
        plt.figure()
        plt.plot(history["train_" + metric])
        plt.plot(history["val_" + metric])
        plt.title(metric)
        plt.legend(["train", "val"])
        plt.show()


def train_val(model, train_loader, val_loader, device, epochs, optimizer, loss, scheduler, save_path,
              patience=10, plot=True):
    """Train *model*, validate after every epoch and checkpoint the best one.

    Args:
        model: Network to train; moved to *device*.
        train_loader: Loader yielding ``(inputs, targets)`` training batches.
        val_loader: Loader yielding ``(inputs, targets)`` validation batches.
        device: Device string/object batches and model are moved to.
        epochs: Maximum number of epochs.
        optimizer: Optimizer over the model's parameters.
        loss: Criterion called as ``loss(pred, target)``.
        scheduler: LR scheduler, stepped once per epoch.
        save_path: File the whole model is ``torch.save``-d to whenever
            validation accuracy improves.
        patience: Stop after this many consecutive epochs without a
            validation-accuracy improvement; ``None`` disables early stopping.
        plot: Show loss/accuracy curves with matplotlib when training ends.

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss``, ``train_acc``,
        ``val_acc`` and the best validation accuracy under ``best_val_acc``.
    """
    model = model.to(device)
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even when its validation accuracy is 0.
    best_acc = float("-inf")
    epochs_without_improvement = 0

    for epoch in range(epochs):
        start_time = time.time()
        train_loss, train_acc = _run_epoch(model, train_loader, device, loss, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, device, loss)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        if val_acc > best_acc:
            best_acc = val_acc
            epochs_without_improvement = 0
            # Pickles the whole module object, so loading requires the model
            # class to be importable; saving model.state_dict() would be more
            # portable but would change what consumers of save_path receive.
            torch.save(model, save_path)
        else:
            epochs_without_improvement += 1

        print('[%03d/%03d] %2.2f sec(s) TrainLoss : %.6f | valLoss: %.6f Trainacc : %.6f | valacc: %.6f' %
              (epoch, epochs, time.time() - start_time, train_loss, val_loss, train_acc, val_acc))
        scheduler.step()

        if patience is not None and epochs_without_improvement >= patience:
            print("Early stopping: no val acc improvement for %d epochs" % patience)
            break

    history["best_val_acc"] = best_acc
    if plot:
        _plot_history(history)
    return history
|
|
|
|
|
|
# ---- Training configuration ----
epochs = 15
lr = 0.0003
save_path = "model_save/best_model.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = build_efficientnet_model(num_classes=8)
# Label smoothing moves 10% of each target's probability mass uniformly
# across all classes.
loss = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
# Halve the learning rate every 7 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.5)

# Create the checkpoint directory (derived from save_path) before training.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
train_val(model, train_loader, val_loader, device, epochs, optimizer, loss, scheduler, save_path)
|
|
|
|
|
|
|