728x90
이미지를 Image.open()으로 불러오기 때문에 'PIL resize'로 검색해서 resize 방법을 찾았고, data_lowlight = data_lowlight.resize(self.resize)와 data_highlight = data_highlight.resize(self.resize)를 if self.mode == 'train': 블록과 elif self.mode == 'test': 블록에 각각 추가했다.
def __getitem__(self, index):
    """Load one (low-light, normal-light) image pair and return it as two tensors.

    The low-light path is ``self.data_list[index]``.  The matching ground-truth
    path is derived from it by replacing every ``'low'`` with ``'normal'`` and
    every ``'Low'`` with ``'Normal'``.  Both images are first resized to
    ``self.resize`` (a PIL ``(width, height)`` tuple).  In ``'train'`` mode the
    pair then goes through ``FLIP_LR``, ``FLIP_UD`` and ``Random_Crop`` and is
    resized to ``(self.w, self.h)``.  Pixel values are scaled by ``1 / 255.0``.

    Args:
        index: Position in ``self.data_list``.

    Returns:
        A ``(low, high)`` pair of channel-first float32 tensors.  When
        ``self.normalize`` is true, the low-light input is additionally
        normalized with mean 0.5 / std 0.5 per channel and the ground truth is
        left as-is.  Otherwise both arrays are converted with
        ``torch.from_numpy`` and permuted to channel-first.
        NOTE(review): assumes 3-channel (RGB) images. ``Normalize`` uses
        3-channel statistics and ``permute(2, 0, 1)`` needs an HxWxC array.

    Raises:
        ValueError: If ``self.mode`` is neither ``'train'`` nor ``'test'``.
            Previously such a mode silently returned ``None``.
    """
    if self.mode not in ('train', 'test'):
        raise ValueError(f"mode must be 'train' or 'test', got {self.mode!r}")

    low_path = self.data_list[index]
    # Every occurrence is replaced, so both folder and file-name parts change.
    high_path = low_path.replace('low', 'normal').replace('Low', 'Normal')

    # The context manager releases the file handle promptly. resize() returns a
    # new, fully loaded image that does not depend on the open file.
    with Image.open(low_path) as img:
        data_lowlight = img.resize(self.resize)
    with Image.open(high_path) as img:
        data_highlight = img.resize(self.resize)

    if self.mode == 'train':
        # Paired augmentation: the same flips/crop must hit both images.
        data_lowlight, data_highlight = self.FLIP_LR(data_lowlight, data_highlight)
        data_lowlight, data_highlight = self.FLIP_UD(data_lowlight, data_highlight)
        data_lowlight, data_highlight = self.Random_Crop(data_lowlight, data_highlight)
        # NOTE(review): self.w / self.h are not assigned in the visible
        # __init__. Presumably Random_Crop (or a helper it calls) records
        # them; confirm.
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        data_lowlight = data_lowlight.resize((self.w, self.h), Image.LANCZOS)
        data_highlight = data_highlight.resize((self.w, self.h), Image.LANCZOS)

    data_lowlight = np.asarray(data_lowlight) / 255.0
    data_highlight = np.asarray(data_highlight) / 255.0

    if self.normalize:
        transform_input = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ConvertImageDtype(torch.float), ])
        transform_gt = Compose([ToTensor(), ConvertImageDtype(torch.float), ])
        return transform_input(data_lowlight), transform_gt(data_highlight)

    data_lowlight = torch.from_numpy(data_lowlight).float()
    data_highlight = torch.from_numpy(data_highlight).float()
    return data_lowlight.permute(2, 0, 1), data_highlight.permute(2, 0, 1)
# NOTE(review): this looks like a stray snippet from the write-up rather than
# working code. It sits outside any function, where `self` and `data_lowlight`
# are presumably undefined, so it would raise NameError if executed. It also
# discards its result; inside __getitem__ the resized image is assigned back
# (data_lowlight = data_lowlight.resize(self.resize)).
data_lowlight.resize(self.resize)
class lowlight_loader(data.Dataset):
    """Dataset of paired low-light / normal-light images.

    Args:
        images_path: Forwarded, together with ``mode``, to
            ``populate_train_list`` to build the list of low-light image paths.
        mode: ``'train'`` or ``'test'``; stored as ``self.mode``.
        normalize: Stored as ``self.normalize``. ``__getitem__`` uses it to
            choose between the torchvision normalizing pipeline and a plain
            tensor conversion.
        resize: Target size passed to PIL ``Image.resize``; defaults to
            ``(512, 512)``.
    """

    def __init__(self, images_path, mode='train', normalize=True, resize=(512, 512)):
        image_paths = populate_train_list(images_path, mode)
        # `train_list` and `data_list` are aliases for the same list object,
        # regardless of mode.
        self.train_list = image_paths
        self.mode = mode
        self.data_list = image_paths
        self.normalize = normalize
        self.resize = resize
        print("Total examples:", len(image_paths))
__init__에 self.resize = resize를 추가하고, 인자 resize의 기본값 resize=(512,512)도 추가했다. (PIL의 resize는 (width, height) 순서의 튜플을 받는다.)
'AI > Deep Learning' 카테고리의 다른 글
IAT 에러 해결법 (0) | 2023.03.08 |
---|---|
[펌] 사전 훈련된 CNN 사용하기(ImageNet 데이터셋, VGG16 모델), 미세 조정 (0) | 2021.10.22 |
[펌] 신경망 기계 번역 (0) | 2021.04.12 |