本文共 23976 字,大约阅读时间需要 79 分钟。
本代码用于超分辨率或图像复原任务：对比两个模型的输出，找出二者 PSNR 差距较大的区域。
import os
import math
import numpy as np
import cv2
import glob
from skimage import transform
from skimage import measure
from collections import OrderedDict
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Model tag embedded in the result file names; stripped to recover the image id.
MODEL_TAG = 'Layer_HRLRadd_SRResNet_16B64C_alpha=0.5'
# Connected regions whose bounding box covers fewer pixels than this are ignored.
MIN_BBOX_AREA = 100


def bgr2ycbcr(img, only_y=True):
    """Convert a BGR image to YCbCr using MATLAB ``rgb2ycbcr`` coefficients.

    Args:
        img: BGR image, uint8 in [0, 255] or float in [0, 1]. Not modified.
        only_y: return only the Y channel when True, all three channels otherwise.

    Returns:
        The converted image, in the same dtype and value range as ``img``.
    """
    in_img_type = img.dtype
    # ``astype`` returns a copy; it must be kept, otherwise the in-place
    # ``*=`` below silently rescales the caller's array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214],
                              [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def calculate_psnr(img1, img2):
    """Return the PSNR (dB) between two images whose values lie in [0, 255].

    Returns ``inf`` when the images are identical.
    """
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


def mse2psnr(mse):
    """Convert an MSE measured on [0, 1]-ranged data to PSNR (dB); 0 maps to ``inf``."""
    if mse == 0:
        return float('inf')
    return 20 * math.log10(1.0 / math.sqrt(mse))


def _resize_and_normalize(heat_map, height, width):
    """Resize ``heat_map`` to (height, width) and min-max normalize it to [0, 1].

    Non-finite entries (e.g. the ``inf`` that ``mse2psnr`` returns for a zero
    MSE) are replaced by the finite extremes first so they cannot poison the
    resize or the normalization. A constant map normalizes to all zeros
    instead of dividing by zero.
    """
    heat_map = np.asarray(heat_map, dtype=np.float64)
    finite = np.isfinite(heat_map)
    if finite.any():
        lo = heat_map[finite].min()
        hi = heat_map[finite].max()
        heat_map = np.nan_to_num(heat_map, nan=lo, posinf=hi, neginf=lo)
    else:
        heat_map = np.zeros_like(heat_map)
    resized = transform.resize(heat_map, (height, width))
    min_value = np.min(resized)
    value_range = np.max(resized) - min_value
    if value_range == 0:
        return np.zeros_like(resized)
    return (resized - min_value) / value_range


def plot_heatmap(image, heat_map, alpha=0.5, display=False, save=None,
                 cmap='viridis', axis='on', dpi=80, verbose=False):
    """Overlay a normalized ``heat_map`` on ``image``; optionally show and/or save it.

    The saved canvas is exactly the image size in pixels (figsize = size / dpi).
    """
    height, width = image.shape[:2]
    normalized_heat_map = _resize_and_normalize(heat_map, height, width)
    if display:
        plt.imshow(image)
        plt.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        plt.axis(axis)
        plt.show()
    if save is not None:
        if verbose:
            print('save image: ' + save)
        fig = plt.figure(figsize=(width / float(dpi), height / float(dpi)))
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(image)
        ax.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        ax.set(xlim=[0, width], ylim=[height, 0], aspect=1)
        fig.savefig(save, dpi=dpi, transparent=True)
        plt.close(fig)  # free the figure; this runs once per image


def to_bin(img, lower, upper):
    """Return a boolean mask of the elements with ``lower < img <= upper``.

    The upper bound is inclusive so the maximum of a [0, 1]-normalized map
    (exactly 1.0) is kept; a strict ``<`` would always drop it.
    """
    return (lower < img) & (img <= upper)


def _find_regions(heatmap, height, width, thres):
    """Normalize ``heatmap`` to the image size and extract high-value regions.

    Regions are 8-connected components of ``thres < map <= 1``. Only those whose
    bounding box covers at least MIN_BBOX_AREA pixels are kept.

    Returns:
        (normalized map, list of (min_row, min_col, max_row, max_col) boxes in label order).
    """
    normalized = _resize_and_normalize(heatmap, height, width)
    label_map = measure.label(to_bin(normalized, thres, 1.0), connectivity=2)
    boxes = []
    for prop in measure.regionprops(label_map):
        r0, c0, r1, c1 = prop.bbox
        if (r1 - r0) * (c1 - c0) >= MIN_BBOX_AREA:
            boxes.append((r0, c0, r1, c1))
    return normalized, boxes


def _crop(img, box):
    """Return the ``box`` = (min_row, min_col, max_row, max_col) window of ``img``."""
    r0, c0, r1, c1 = box
    return img[r0:r1, c0:c1]


def _psnr_gain(im_BSL, im_OCT, im_GT, box):
    """PSNR(OCT, GT) - PSNR(BSL, GT) inside ``box``; images are in [0, 1]."""
    gt = _crop(im_GT, box) * 255.
    return (calculate_psnr(_crop(im_OCT, box) * 255., gt)
            - calculate_psnr(_crop(im_BSL, box) * 255., gt))


def plot_diffmap(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                 save=None, cmap='viridis', axis='on', dpi=80, verbose=False):
    """Overlay the heatmap on ``im_BSL`` and box its high-value regions.

    Each qualifying region is outlined and labelled with the PSNR gain of
    ``im_OCT`` over ``im_BSL`` (both measured against ``im_GT``) inside its box.
    Images are float arrays in [0, 1].

    Args:
        thres: threshold on the normalized heatmap used to extract regions.
        alpha, cmap: opacity and colormap of the heatmap overlay.
        display: show the figure interactively.
        save: output path. Nothing is drawn unless saving or displaying.
        axis: passed to ``Axes.axis``.
        dpi: figure resolution; the canvas matches the image size in pixels.
        verbose: print the output path before saving.
    """
    H, W = im_BSL.shape[:2]
    normalized_heatmap, boxes = _find_regions(heatmap, H, W, thres)
    if save is None and not display:
        return
    fig = plt.figure(figsize=(W / float(dpi), H / float(dpi)))
    ax = fig.add_axes([0, 0, 1, 1])
    ax.axis('off')
    ax.imshow(im_BSL)
    ax.imshow(normalized_heatmap, alpha=alpha, cmap=cmap)
    ax.axis(axis)
    for box in boxes:
        r0, c0, r1, c1 = box
        ax.add_patch(patches.Rectangle((c0, r0), c1 - c0, r1 - r0,
                                       edgecolor='y', linewidth=6, fill=False))
        label = "{:+.2f}".format(_psnr_gain(im_BSL, im_OCT, im_GT, box))
        # Keep the label on the canvas: right-align boxes that start near the
        # right edge, and put the label below boxes that start near the top.
        h_aln = 'right' if W - c0 < 50 else 'left'
        if r0 < 20:
            ax.text(c0, r1, label, color='r', verticalalignment='top',
                    horizontalalignment=h_aln, fontsize=26)
        else:
            ax.text(c0, r0, label, color='r', verticalalignment='bottom',
                    horizontalalignment=h_aln, fontsize=26)
    ax.set(xlim=[0, W], ylim=[H, 0], aspect=1)
    if save is not None:
        if verbose:
            print('save image: ' + save)
        fig.savefig(save, dpi=dpi, transparent=True)
    if display:
        plt.show()
    plt.close(fig)


def plot_diff_patch(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                    save=None, cmap='viridis', axis='on', dpi=80, verbose=False,
                    diff_dir='/data1/cropimage/diff6_cvpr', max_patches=5):
    """Save a grid of the strongest regions, one row each: GT | BSL | OCT | diff.

    Regions are extracted as in ``plot_diffmap`` and ranked by their mean
    normalized heatmap value; the top ``max_patches`` are drawn. The full
    difference image ``clip(OCT - BSL + 0.5, 0, 1)`` is also written to
    ``diff_dir`` under the same file name as ``save``. Nothing is written when
    ``save`` is None.

    ``alpha``, ``cmap``, ``axis`` and ``dpi`` are accepted for signature parity
    with ``plot_diffmap`` and are not used here.
    """
    H, W = im_BSL.shape[:2]
    normalized_heatmap, boxes = _find_regions(heatmap, H, W, thres)
    # sorted() is stable, so ties keep label order.
    ranked = sorted(
        ((box, np.mean(_crop(normalized_heatmap, box)), _psnr_gain(im_BSL, im_OCT, im_GT, box))
         for box in boxes),
        key=lambda item: item[1], reverse=True)
    im_diff = np.clip(im_OCT - im_BSL + 0.5, 0.0, 1.0)
    if save is None:
        return
    if verbose:
        print('save image: ' + save)

    os.makedirs(diff_dir, exist_ok=True)  # cv2.imwrite fails silently on a missing dir
    diff_path = os.path.join(diff_dir, os.path.basename(save))
    diff_u8 = np.round(im_diff * 255).astype(np.uint8)
    if not cv2.imwrite(diff_path, diff_u8[:, :, [2, 1, 0]]):
        print('warning: failed to write ' + diff_path)

    rows = ranked[:max_patches]
    if not rows:
        return  # plt.subplots cannot build a grid with zero rows
    fig, axes = plt.subplots(nrows=len(rows), ncols=4, figsize=(15, 15), squeeze=False)
    for row_axes, (box, _, gain) in zip(axes, rows):
        r0, c0, r1, c1 = box
        for ax, img in zip(row_axes, (im_GT, im_BSL, im_OCT, im_diff)):
            ax.imshow(_crop(img, box))
        row_axes[3].text(c1 - c0, r1 - r0, "{:+.2f}".format(gain), color='r', fontsize=16)
        row_axes[3].text(c1 - c0, 0, "{}".format(box), color='r', fontsize=16)
    fig.savefig(save, dpi=300, bbox_inches='tight', transparent=False)
    if display:
        plt.show()
    plt.close(fig)


def _read_rgb(path):
    """Read an image with OpenCV and return it as RGB float64 in [0, 1]."""
    img = cv2.imread(path)
    if img is None:  # cv2.imread reports a missing/unreadable file with None
        raise FileNotFoundError('cannot read image: ' + path)
    return img[:, :, [2, 1, 0]] / 255.


def _compute_inv_map(im_OCT, im_BSL, patch_size, stride):
    """Score each ``patch_size`` window (step ``stride``) by how much it differs.

    ``patch_err`` is the window's contribution to the global MSE between the
    two images: squared error summed over the window, divided by the total
    element count. ``total_err - patch_err`` is therefore the global MSE with
    that window's differences removed. Its PSNR is largest where the two
    images differ most.

    NOTE: window starts are ``arange(0, size - patch_size, stride)``, so the
    last valid start ``size - patch_size`` itself is never included.
    """
    H, W, C = im_OCT.shape
    H_axis = np.arange(0, H - patch_size, stride)
    W_axis = np.arange(0, W - patch_size, stride)
    sq_err = (im_OCT - im_BSL) ** 2  # computed once instead of per window
    total_err = np.mean(sq_err)
    inv_map = np.zeros((len(H_axis), len(W_axis)))
    for i, h in enumerate(H_axis):
        for j, w in enumerate(W_axis):
            patch_err = np.sum(sq_err[h:h + patch_size, w:w + patch_size, :]) / (H * W * C)
            # Clamp rounding residue below zero, which math.sqrt would reject.
            inv_map[i, j] = mse2psnr(max(total_err - patch_err, 0.0))
    return inv_map


folder_BSL = "/data1/results10.09/(v1)Layer_HRLR_withoutconnection_SRResNet_16B64C_alpha=0.5/DIV2K_VAL/"
folder_OCT = "/data1/results/(ture)_1X1_directshare_SRResNet_44B64C_alpha=0.5/DIV2K_VAL/"
folder_GT = '/data1/data/DIV2K_VAL/DIV2K_valid_HR/'
save_dir = '/data1/cropimage/diff_cvpr/'       # patch grids from plot_diff_patch
save_dir6 = '/data1/cropimage/heatdiff_cvpr/'  # boxed heatmaps from plot_diffmap
suffix = ''  # suffix for Gen images
test_Y = False  # only reported; every metric in this script is computed on RGB
patch_size = 32
stride = 10
max_images = 100


def main():
    """Compare the OCT and BSL results against GT for the first ``max_images`` images."""
    print('Testing Y channel.' if test_Y else 'Testing RGB channels.')
    img_list = sorted(glob.glob(os.path.join(folder_OCT, '*')))[:max_images]
    for out_dir in (save_dir, save_dir6):
        os.makedirs(out_dir, exist_ok=True)
    for img_path in img_list:
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        out_name = base_name.replace(MODEL_TAG, '') + '.png'
        im_OCT = _read_rgb(img_path)
        im_BSL = _read_rgb(os.path.join(
            folder_BSL, base_name.replace(MODEL_TAG, 'SRResNet_16B64C') + suffix + '.png'))
        im_GT = _read_rgb(os.path.join(
            folder_GT, base_name.replace(MODEL_TAG, '').replace('_bicLRx4', '') + suffix + '.png'))
        inv_map = _compute_inv_map(im_OCT, im_BSL, patch_size, stride)
        plot_diffmap(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5,
                     save=os.path.join(save_dir6, out_name), axis='off', display=False)
        plot_diff_patch(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5,
                        save=os.path.join(save_dir, out_name), axis='off', display=False)


if __name__ == '__main__':
    main()
改进版（更换了数据路径与 GT 文件名映射，并在保存前创建输出目录）：
import os
import math
import numpy as np
import cv2
import glob
from skimage import transform
from skimage import measure
from collections import OrderedDict
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Model tag embedded in the result file names; stripped to recover the image id.
MODEL_TAG = 'Layer_HRLRadd_SRResNet_16B64C_alpha=0.5'
# Connected regions whose bounding box covers fewer pixels than this are ignored.
MIN_BBOX_AREA = 100


def bgr2ycbcr(img, only_y=True):
    """Convert a BGR image to YCbCr using MATLAB ``rgb2ycbcr`` coefficients.

    Args:
        img: BGR image, uint8 in [0, 255] or float in [0, 1]. Not modified.
        only_y: return only the Y channel when True, all three channels otherwise.

    Returns:
        The converted image, in the same dtype and value range as ``img``.
    """
    in_img_type = img.dtype
    # ``astype`` returns a copy; it must be kept, otherwise the in-place
    # ``*=`` below silently rescales the caller's array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214],
                              [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def calculate_psnr(img1, img2):
    """Return the PSNR (dB) between two images whose values lie in [0, 255].

    Returns ``inf`` when the images are identical.
    """
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


def mse2psnr(mse):
    """Convert an MSE measured on [0, 1]-ranged data to PSNR (dB); 0 maps to ``inf``."""
    if mse == 0:
        return float('inf')
    return 20 * math.log10(1.0 / math.sqrt(mse))


def _resize_and_normalize(heat_map, height, width):
    """Resize ``heat_map`` to (height, width) and min-max normalize it to [0, 1].

    Non-finite entries (e.g. the ``inf`` that ``mse2psnr`` returns for a zero
    MSE) are replaced by the finite extremes first so they cannot poison the
    resize or the normalization. A constant map normalizes to all zeros
    instead of dividing by zero.
    """
    heat_map = np.asarray(heat_map, dtype=np.float64)
    finite = np.isfinite(heat_map)
    if finite.any():
        lo = heat_map[finite].min()
        hi = heat_map[finite].max()
        heat_map = np.nan_to_num(heat_map, nan=lo, posinf=hi, neginf=lo)
    else:
        heat_map = np.zeros_like(heat_map)
    resized = transform.resize(heat_map, (height, width))
    min_value = np.min(resized)
    value_range = np.max(resized) - min_value
    if value_range == 0:
        return np.zeros_like(resized)
    return (resized - min_value) / value_range


def plot_heatmap(image, heat_map, alpha=0.5, display=False, save=None,
                 cmap='viridis', axis='on', dpi=80, verbose=False):
    """Overlay a normalized ``heat_map`` on ``image``; optionally show and/or save it.

    The saved canvas is exactly the image size in pixels (figsize = size / dpi).
    """
    height, width = image.shape[:2]
    normalized_heat_map = _resize_and_normalize(heat_map, height, width)
    if display:
        plt.imshow(image)
        plt.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        plt.axis(axis)
        plt.show()
    if save is not None:
        if verbose:
            print('save image: ' + save)
        fig = plt.figure(figsize=(width / float(dpi), height / float(dpi)))
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(image)
        ax.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        ax.set(xlim=[0, width], ylim=[height, 0], aspect=1)
        fig.savefig(save, dpi=dpi, transparent=True)
        plt.close(fig)  # free the figure; this runs once per image


def to_bin(img, lower, upper):
    """Return a boolean mask of the elements with ``lower < img <= upper``.

    The upper bound is inclusive so the maximum of a [0, 1]-normalized map
    (exactly 1.0) is kept; a strict ``<`` would always drop it.
    """
    return (lower < img) & (img <= upper)


def _find_regions(heatmap, height, width, thres):
    """Normalize ``heatmap`` to the image size and extract high-value regions.

    Regions are 8-connected components of ``thres < map <= 1``. Only those whose
    bounding box covers at least MIN_BBOX_AREA pixels are kept.

    Returns:
        (normalized map, list of (min_row, min_col, max_row, max_col) boxes in label order).
    """
    normalized = _resize_and_normalize(heatmap, height, width)
    label_map = measure.label(to_bin(normalized, thres, 1.0), connectivity=2)
    boxes = []
    for prop in measure.regionprops(label_map):
        r0, c0, r1, c1 = prop.bbox
        if (r1 - r0) * (c1 - c0) >= MIN_BBOX_AREA:
            boxes.append((r0, c0, r1, c1))
    return normalized, boxes


def _crop(img, box):
    """Return the ``box`` = (min_row, min_col, max_row, max_col) window of ``img``."""
    r0, c0, r1, c1 = box
    return img[r0:r1, c0:c1]


def _psnr_gain(im_BSL, im_OCT, im_GT, box):
    """PSNR(OCT, GT) - PSNR(BSL, GT) inside ``box``; images are in [0, 1]."""
    gt = _crop(im_GT, box) * 255.
    return (calculate_psnr(_crop(im_OCT, box) * 255., gt)
            - calculate_psnr(_crop(im_BSL, box) * 255., gt))


def plot_diffmap(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                 save=None, cmap='viridis', axis='on', dpi=80, verbose=False):
    """Overlay the heatmap on ``im_BSL`` and box its high-value regions.

    Each qualifying region is outlined and labelled with the PSNR gain of
    ``im_OCT`` over ``im_BSL`` (both measured against ``im_GT``) inside its box.
    Images are float arrays in [0, 1].

    Args:
        thres: threshold on the normalized heatmap used to extract regions.
        alpha, cmap: opacity and colormap of the heatmap overlay.
        display: show the figure interactively.
        save: output path. Nothing is drawn unless saving or displaying.
        axis: passed to ``Axes.axis``.
        dpi: figure resolution; the canvas matches the image size in pixels.
        verbose: print the output path before saving.
    """
    H, W = im_BSL.shape[:2]
    normalized_heatmap, boxes = _find_regions(heatmap, H, W, thres)
    if save is None and not display:
        return
    fig = plt.figure(figsize=(W / float(dpi), H / float(dpi)))
    ax = fig.add_axes([0, 0, 1, 1])
    ax.axis('off')
    ax.imshow(im_BSL)
    ax.imshow(normalized_heatmap, alpha=alpha, cmap=cmap)
    ax.axis(axis)
    for box in boxes:
        r0, c0, r1, c1 = box
        ax.add_patch(patches.Rectangle((c0, r0), c1 - c0, r1 - r0,
                                       edgecolor='y', linewidth=6, fill=False))
        label = "{:+.2f}".format(_psnr_gain(im_BSL, im_OCT, im_GT, box))
        # Keep the label on the canvas: right-align boxes that start near the
        # right edge, and put the label below boxes that start near the top.
        h_aln = 'right' if W - c0 < 50 else 'left'
        if r0 < 20:
            ax.text(c0, r1, label, color='r', verticalalignment='top',
                    horizontalalignment=h_aln, fontsize=26)
        else:
            ax.text(c0, r0, label, color='r', verticalalignment='bottom',
                    horizontalalignment=h_aln, fontsize=26)
    ax.set(xlim=[0, W], ylim=[H, 0], aspect=1)
    if save is not None:
        if verbose:
            print('save image: ' + save)
        fig.savefig(save, dpi=dpi, transparent=True)
    if display:
        plt.show()
    plt.close(fig)


def plot_diff_patch(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                    save=None, cmap='viridis', axis='on', dpi=80, verbose=False,
                    diff_dir='/data1/cropimage/diff6_cvpr', max_patches=5):
    """Save a grid of the strongest regions, one row each: GT | BSL | OCT | diff.

    Regions are extracted as in ``plot_diffmap`` and ranked by their mean
    normalized heatmap value; the top ``max_patches`` are drawn. The full
    difference image ``clip(OCT - BSL + 0.5, 0, 1)`` is also written to
    ``diff_dir`` under the same file name as ``save``. Nothing is written when
    ``save`` is None.

    ``alpha``, ``cmap``, ``axis`` and ``dpi`` are accepted for signature parity
    with ``plot_diffmap`` and are not used here.
    """
    H, W = im_BSL.shape[:2]
    normalized_heatmap, boxes = _find_regions(heatmap, H, W, thres)
    # sorted() is stable, so ties keep label order.
    ranked = sorted(
        ((box, np.mean(_crop(normalized_heatmap, box)), _psnr_gain(im_BSL, im_OCT, im_GT, box))
         for box in boxes),
        key=lambda item: item[1], reverse=True)
    im_diff = np.clip(im_OCT - im_BSL + 0.5, 0.0, 1.0)
    if save is None:
        return
    if verbose:
        print('save image: ' + save)

    os.makedirs(diff_dir, exist_ok=True)  # cv2.imwrite fails silently on a missing dir
    diff_path = os.path.join(diff_dir, os.path.basename(save))
    diff_u8 = np.round(im_diff * 255).astype(np.uint8)
    if not cv2.imwrite(diff_path, diff_u8[:, :, [2, 1, 0]]):
        print('warning: failed to write ' + diff_path)

    rows = ranked[:max_patches]
    if not rows:
        return  # plt.subplots cannot build a grid with zero rows
    fig, axes = plt.subplots(nrows=len(rows), ncols=4, figsize=(15, 15), squeeze=False)
    for row_axes, (box, _, gain) in zip(axes, rows):
        r0, c0, r1, c1 = box
        for ax, img in zip(row_axes, (im_GT, im_BSL, im_OCT, im_diff)):
            ax.imshow(_crop(img, box))
        row_axes[3].text(c1 - c0, r1 - r0, "{:+.2f}".format(gain), color='r', fontsize=16)
        row_axes[3].text(c1 - c0, 0, "{}".format(box), color='r', fontsize=16)
    fig.savefig(save, dpi=300, bbox_inches='tight', transparent=False)
    if display:
        plt.show()
    plt.close(fig)


def _read_rgb(path):
    """Read an image with OpenCV and return it as RGB float64 in [0, 1]."""
    img = cv2.imread(path)
    if img is None:  # cv2.imread reports a missing/unreadable file with None
        raise FileNotFoundError('cannot read image: ' + path)
    return img[:, :, [2, 1, 0]] / 255.


def _compute_inv_map(im_OCT, im_BSL, patch_size, stride):
    """Score each ``patch_size`` window (step ``stride``) by how much it differs.

    ``patch_err`` is the window's contribution to the global MSE between the
    two images: squared error summed over the window, divided by the total
    element count. ``total_err - patch_err`` is therefore the global MSE with
    that window's differences removed. Its PSNR is largest where the two
    images differ most.

    NOTE: window starts are ``arange(0, size - patch_size, stride)``, so the
    last valid start ``size - patch_size`` itself is never included.
    """
    H, W, C = im_OCT.shape
    H_axis = np.arange(0, H - patch_size, stride)
    W_axis = np.arange(0, W - patch_size, stride)
    sq_err = (im_OCT - im_BSL) ** 2  # computed once instead of per window
    total_err = np.mean(sq_err)
    inv_map = np.zeros((len(H_axis), len(W_axis)))
    for i, h in enumerate(H_axis):
        for j, w in enumerate(W_axis):
            patch_err = np.sum(sq_err[h:h + patch_size, w:w + patch_size, :]) / (H * W * C)
            # Clamp rounding residue below zero, which math.sqrt would reject.
            inv_map[i, j] = mse2psnr(max(total_err - patch_err, 0.0))
    return inv_map


folder_BSL = "/data1/results10.09/(multi_scale)_SRResNet_16B64C/DIV2K_VAL0.8/"
folder_OCT = "/data1/results/(multi)1X1_directshare_SRResNet_48B64C_alpha=0.5/DIV2K_VAL0.8/"
folder_GT = "/data1/data/multiscale_dataset/DIV2K_valid_HR_0.8/"
save_dir = '/data1/cropimage/heatmap/diff_DIV2K_VAL0.8/'       # patch grids (plot_diff_patch)
save_dir6 = '/data1/cropimage/heatmap/heatdiff_DIV2K_VAL0.8/'  # boxed heatmaps (plot_diffmap)
suffix = ''  # suffix for Gen images
test_Y = False  # only reported; every metric in this script is computed on RGB
patch_size = 32
stride = 10
max_images = 100


def main():
    """Compare the OCT and BSL results against GT for the first ``max_images`` images."""
    print('Testing Y channel.' if test_Y else 'Testing RGB channels.')
    img_list = sorted(glob.glob(os.path.join(folder_OCT, '*')))[:max_images]
    # makedirs also creates missing parents (e.g. /data1/cropimage/heatmap/),
    # where the previous os.mkdir would raise.
    for out_dir in (save_dir, save_dir6):
        os.makedirs(out_dir, exist_ok=True)
    for img_path in img_list:
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        out_name = base_name.replace(MODEL_TAG, '') + '.png'
        im_OCT = _read_rgb(img_path)
        im_BSL = _read_rgb(os.path.join(
            folder_BSL, base_name.replace(MODEL_TAG, 'SRResNet_16B64C') + suffix + '.png'))
        im_GT = _read_rgb(os.path.join(
            folder_GT, base_name.replace('_bicLRx4', '_bicLRx0.6') + suffix + '.png'))
        inv_map = _compute_inv_map(im_OCT, im_BSL, patch_size, stride)
        plot_diffmap(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5,
                     save=os.path.join(save_dir6, out_name), axis='off', display=False)
        plot_diff_patch(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5,
                        save=os.path.join(save_dir, out_name), axis='off', display=False)


if __name__ == '__main__':
    main()
转载地址:http://myajz.baihongyu.com/