博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
图像超分之——寻找两张图差异的区域
阅读量:508 次
发布时间:2019-03-07

本文共 23976 字,大约阅读时间需要 79 分钟。

本代码用于超分辨率（或图像复原）任务：比较两种方法的输出，找出二者 PSNR 差距较大的区域。做法是用滑动窗口逐块计算两张结果图之间的误差并生成热力图，然后框出热力图中的高响应区域，并在每个框旁标注该区域内两种方法相对 GT 的 PSNR 差值。

import os
import math
import numpy as np
import cv2
import glob
from skimage import transform
from skimage import measure
from collections import OrderedDict
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Substring of the compared method's result file names; it is stripped (or swapped for the
# baseline's tag) to derive the matching baseline / ground-truth / output file names.
_MODEL_TAG = 'Layer_HRLRadd_SRResNet_16B64C_alpha=0.5'


def bgr2ycbcr(img, only_y=True):
    """Convert a BGR image to YCbCr, matching MATLAB ``rgb2ycbcr`` (BGR channel order).

    Args:
        img: H x W x 3 image in BGR order, either uint8 in [0, 255] or float in [0, 1].
        only_y: if True, return only the Y channel.

    Returns:
        The converted image with the same dtype (and value range) as ``img``.
    """
    in_img_type = img.dtype
    # Work on a float copy. The previous code discarded the astype() result and then
    # scaled float inputs in place, silently modifying the caller's array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def calculate_psnr(img1, img2):
    """Return the PSNR (dB) between two images whose values are in [0, 255].

    Returns ``inf`` when the images are identical.
    """
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


def mse2psnr(mse):
    """Convert an MSE measured on [0, 1]-range data to PSNR (dB); ``inf`` for zero MSE."""
    if mse == 0:
        return float('inf')
    return 20 * math.log10(1.0 / math.sqrt(mse))


def _normalize_heatmap(heatmap, height, width):
    """Resize ``heatmap`` to (height, width) and min-max normalize it to [0, 1].

    Non-finite entries (``mse2psnr`` yields ``inf`` for zero error) are clamped to the
    largest finite value so they do not turn the whole map into NaN. A constant map
    normalizes to all zeros instead of dividing by zero.
    """
    heatmap = np.asarray(heatmap, dtype=np.float64)
    finite = np.isfinite(heatmap)
    if not finite.all():
        fill = heatmap[finite].max() if finite.any() else 0.0
        heatmap = np.where(finite, heatmap, fill)
    resized = transform.resize(heatmap, (height, width))
    lo, hi = np.min(resized), np.max(resized)
    if hi == lo:
        return np.zeros_like(resized)
    return (resized - lo) / (hi - lo)


def to_bin(img, lower, upper):
    """Return a boolean mask of the pixels strictly between ``lower`` and ``upper``."""
    return (lower < img) & (img < upper)


def _crop(img, bbox):
    """Return the part of ``img`` covered by a regionprops ``bbox`` (r0, c0, r1, c1)."""
    min_row, min_col, max_row, max_col = bbox
    return img[min_row:max_row, min_col:max_col]


def _find_regions(normalized_heatmap, thres, min_area):
    """Return bboxes of 8-connected regions of ``normalized_heatmap`` above ``thres``.

    Only regions whose bounding box covers at least ``min_area`` pixels are kept.
    """
    # The map is normalized to [0, 1]. The old to_bin(map, thres, 1.0) used a strict
    # upper bound and so always excluded the hottest pixel(s); keep them.
    bin_map = normalized_heatmap > thres
    label_map = measure.label(bin_map, connectivity=2)
    # Compute the bbox area directly: ``bbox_area`` was renamed ``area_bbox`` in
    # scikit-image 0.19, and this works on both sides of the rename.
    return [prop.bbox for prop in measure.regionprops(label_map)
            if (prop.bbox[2] - prop.bbox[0]) * (prop.bbox[3] - prop.bbox[1]) >= min_area]


def _region_psnr_gain(im_BSL, im_OCT, im_GT, bbox):
    """Return PSNR(OCT, GT) - PSNR(BSL, GT) inside ``bbox``; images are in [0, 1]."""
    gt = _crop(im_GT, bbox) * 255.
    return (calculate_psnr(_crop(im_OCT, bbox) * 255., gt)
            - calculate_psnr(_crop(im_BSL, bbox) * 255., gt))


def plot_heatmap(image, heat_map, alpha=0.5, display=False, save=None, cmap='viridis', axis='on',
                 dpi=80, verbose=False):
    """Overlay ``heat_map`` (resized to the image size, min-max normalized) on ``image``.

    Args:
        image: image accepted by ``imshow``; its first two dims give the output size.
        heat_map: 2-D map of any size; it is resized to the image size.
        alpha: opacity of the heat-map overlay.
        display: show the overlay in an interactive pyplot window.
        save: if given, path of the image written with exactly H x W pixels.
        cmap: matplotlib colormap for the heat map.
        axis: argument passed to ``plt.axis`` for the interactive display.
        dpi: figure DPI, used to make the saved image pixel-exact.
        verbose: print the output path before saving.
    """
    H, W = image.shape[:2]
    normalized_heat_map = _normalize_heatmap(heat_map, H, W)

    if display:
        plt.imshow(image)
        plt.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        plt.axis(axis)
        plt.show()

    if save is not None:
        if verbose:
            print('save image: ' + save)
        fig = plt.figure(figsize=(W / float(dpi), H / float(dpi)))
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(image)
        ax.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        ax.set(xlim=[0, W], ylim=[H, 0], aspect=1)
        fig.savefig(save, dpi=dpi, transparent=True)
        plt.close(fig)  # avoid accumulating open figures across images


def plot_diffmap(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                 save=None, cmap='viridis', axis='on', dpi=80, verbose=False, min_area=100):
    """Box the hot regions of ``heatmap`` over ``im_BSL`` and label each box with the PSNR
    gain of ``im_OCT`` over ``im_BSL``, both measured against ``im_GT`` inside the box.

    Args:
        im_BSL, im_OCT, im_GT: H x W x C images (baseline, compared method, ground truth)
            with values in [0, 1].
        heatmap: 2-D map of any size; resized to H x W and min-max normalized.
        thres: normalized-heat threshold; pixels above it form candidate regions.
        alpha: opacity of the heat-map overlay.
        display: unused; kept for signature compatibility.
        save: output path; nothing is produced when it is None.
        cmap: colormap of the heat-map overlay.
        axis: argument passed to ``ax.axis``.
        dpi: figure DPI, used to make the saved image pixel-exact.
        verbose: print the output path before saving.
        min_area: minimum bounding-box area (pixels) for a region to be drawn.
    """
    if save is None:
        return
    H, W = im_BSL.shape[:2]
    normalized_heatmap = _normalize_heatmap(heatmap, H, W)
    bboxes = _find_regions(normalized_heatmap, thres, min_area)

    if verbose:
        print('save image: ' + save)
    fig = plt.figure(figsize=(W / float(dpi), H / float(dpi)))
    ax = fig.add_axes([0, 0, 1, 1])
    ax.axis('off')
    ax.imshow(im_BSL)
    ax.imshow(normalized_heatmap, alpha=alpha, cmap=cmap)
    ax.axis(axis)
    for bbox in bboxes:
        min_row, min_col, max_row, max_col = bbox
        ax.add_patch(patches.Rectangle((min_col, min_row), max_col - min_col, max_row - min_row,
                                       edgecolor='y', linewidth=6, fill=False))
        gain = _region_psnr_gain(im_BSL, im_OCT, im_GT, bbox)
        # Right-align labels of boxes starting near the right edge so they stay visible.
        h_aln = 'right' if W - min_col < 50 else 'left'
        if min_row < 20:
            # Too close to the top edge: put the label below the box instead of above it.
            ax.text(min_col, max_row, '{:+.2f}'.format(gain), color='r',
                    verticalalignment='top', horizontalalignment=h_aln, fontsize=26)
        else:
            ax.text(min_col, min_row, '{:+.2f}'.format(gain), color='r',
                    verticalalignment='bottom', horizontalalignment=h_aln, fontsize=26)
    ax.set(xlim=[0, W], ylim=[H, 0], aspect=1)
    fig.savefig(save, dpi=dpi, transparent=True)
    plt.close(fig)


def plot_diff_patch(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                    save=None, cmap='viridis', axis='on', dpi=80, verbose=False,
                    diff_save=None, max_patches=5, min_area=100):
    """Plot the hottest regions of ``heatmap`` as rows of crops: GT | BSL | OCT | diff.

    Regions are ranked by their mean normalized heat. The diff column shows
    ``clip(OCT - BSL + 0.5, 0, 1)`` and is annotated with the region's PSNR gain
    (PSNR(OCT, GT) - PSNR(BSL, GT)) and its bbox.

    Args:
        im_BSL, im_OCT, im_GT: H x W x C images with values in [0, 1].
        heatmap: 2-D map of any size; resized to H x W and min-max normalized.
        thres: normalized-heat threshold for candidate regions (previously ignored).
        alpha, display, cmap, axis, dpi: unused; kept for signature compatibility.
        save: path of the patch-grid figure. Previously the global ``save_path`` was used.
        verbose: print the output path before saving.
        diff_save: if given, the full-size diff image is written there as 8-bit BGR.
            Previously this path was built from the global ``base_name``.
        max_patches: maximum number of regions (rows) to plot.
        min_area: minimum bounding-box area (pixels) for a region to be considered.
    """
    H, W = im_BSL.shape[:2]
    normalized_heatmap = _normalize_heatmap(heatmap, H, W)

    regions = []
    for bbox in _find_regions(normalized_heatmap, thres, min_area):
        err = np.mean(_crop(normalized_heatmap, bbox))
        regions.append((err, bbox, _region_psnr_gain(im_BSL, im_OCT, im_GT, bbox)))
    regions.sort(key=lambda region: region[0], reverse=True)

    # 0.5 (mid-gray) means no difference; brighter / darker means OCT is above / below BSL.
    im_diff = np.clip(im_OCT - im_BSL + 0.5, 0.0, 1.0)
    if diff_save is not None:
        cv2.imwrite(diff_save, np.round(im_diff * 255).astype(np.uint8)[:, :, [2, 1, 0]])

    top = regions[:max_patches]
    if not top or save is None:
        # plt.subplots(nrows=0) raises, and there is nothing to draw without an output path.
        return
    fig, axes = plt.subplots(nrows=len(top), ncols=4, figsize=(15, 15), squeeze=False)
    for row, (_, bbox, gain) in zip(axes, top):
        for ax, img in zip(row, (im_GT, im_BSL, im_OCT, im_diff)):
            ax.imshow(_crop(img, bbox))
        patch_h, patch_w = bbox[2] - bbox[0], bbox[3] - bbox[1]
        row[3].text(patch_w, patch_h, '{:+.2f}'.format(gain), color='r', fontsize=16)
        row[3].text(patch_w, 0, '{}'.format(bbox), color='r', fontsize=16)
    if verbose:
        print('save image: ' + save)
    fig.savefig(save, dpi=300, bbox_inches='tight', transparent=False)
    plt.close(fig)


def _read_rgb01(path):
    """Read an image with OpenCV and return it as an RGB float array in [0, 1]."""
    img = cv2.imread(path)
    if img is None:  # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError('cannot read image: {}'.format(path))
    return img[:, :, [2, 1, 0]] / 255.


def compute_inv_map(im_OCT, im_BSL, patch_size=32, stride=10):
    """Build a leave-one-patch-out PSNR map between two images in [0, 1].

    For every patch position, that patch's contribution to the global MSE between
    ``im_OCT`` and ``im_BSL`` is removed and the remaining MSE is converted to PSNR.
    Patches that carry more of the difference therefore get a higher value.

    Returns:
        Array of shape (n_rows, n_cols), one entry per patch position (empty when the
        image is smaller than ``patch_size``).
    """
    H, W, C = im_OCT.shape
    # +1 so the last valid window start (H - patch_size) is included.
    H_axis = np.arange(0, H - patch_size + 1, stride)
    W_axis = np.arange(0, W - patch_size + 1, stride)
    sq_err = (im_OCT - im_BSL) ** 2
    total_err = np.mean(sq_err)
    n_values = float(H * W * C)
    inv_map = np.zeros((len(H_axis), len(W_axis)))
    for i, h in enumerate(H_axis):
        for j, w in enumerate(W_axis):
            patch_err = np.sum(sq_err[h:h + patch_size, w:w + patch_size, :]) / n_values
            # Clamp tiny negative round-off so math.sqrt inside mse2psnr cannot raise.
            inv_map[i, j] = mse2psnr(max(total_err - patch_err, 0.0))
    return inv_map


folder_BSL = "/data1/results10.09/(v1)Layer_HRLR_withoutconnection_SRResNet_16B64C_alpha=0.5/DIV2K_VAL/"
folder_OCT = "/data1/results/(ture)_1X1_directshare_SRResNet_44B64C_alpha=0.5/DIV2K_VAL/"
folder_GT = '/data1/data/DIV2K_VAL/DIV2K_valid_HR/'
# NOTE: crop_border, PSNR_all and SSIM_all are not used below; test_Y only selects the
# message printed at start-up.
crop_border = 4
suffix = ''  # suffix for Gen images
test_Y = False  # True: test Y channel only; False: test RGB channels
PSNR_all = []
SSIM_all = []
patch_size = 32
stride = 10


def main():
    """Compare the first 100 OCT results with their baseline and write the diff figures."""
    img_list = sorted(glob.glob(os.path.join(folder_OCT, '*')))[:100]
    if test_Y:
        print('Testing Y channel.')
    else:
        print('Testing RGB channels.')

    save_dir = '/data1/cropimage/diff_cvpr/'
    save_dir6 = '/data1/cropimage/heatdiff_cvpr/'
    diff_dir = '/data1/cropimage/diff6_cvpr'
    for directory in (save_dir, save_dir6, diff_dir):
        os.makedirs(directory, exist_ok=True)

    for img_path in img_list:
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        out_name = base_name.replace(_MODEL_TAG, '') + '.png'
        im_OCT = _read_rgb01(img_path)
        im_BSL = _read_rgb01(os.path.join(
            folder_BSL, base_name.replace(_MODEL_TAG, 'SRResNet_16B64C') + suffix + '.png'))
        im_GT = _read_rgb01(os.path.join(
            folder_GT, base_name.replace(_MODEL_TAG, '').replace('_bicLRx4', '') + suffix + '.png'))

        inv_map = compute_inv_map(im_OCT, im_BSL, patch_size, stride)
        if inv_map.size == 0:
            print('skip {}: image smaller than patch_size'.format(base_name))
            continue

        # plot_heatmap(im_BSL, inv_map, alpha=0.7, save=..., axis='off', display=False)
        plot_diffmap(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5,
                     save=os.path.join(save_dir6, out_name), axis='off', display=False)
        plot_diff_patch(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5,
                        save=os.path.join(save_dir, out_name), axis='off', display=False,
                        diff_save=os.path.join(diff_dir, out_name))


if __name__ == '__main__':
    main()

改进版（换用了另一组数据路径，并在输出目录不存在时自动创建）：

import os
import math
import numpy as np
import cv2
import glob
from skimage import transform
from skimage import measure
from collections import OrderedDict
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Substring of the compared method's result file names; it is stripped (or swapped for the
# baseline's tag) to derive the matching baseline / output file names.
_MODEL_TAG = 'Layer_HRLRadd_SRResNet_16B64C_alpha=0.5'


def bgr2ycbcr(img, only_y=True):
    """Convert a BGR image to YCbCr, matching MATLAB ``rgb2ycbcr`` (BGR channel order).

    Args:
        img: H x W x 3 image in BGR order, either uint8 in [0, 255] or float in [0, 1].
        only_y: if True, return only the Y channel.

    Returns:
        The converted image with the same dtype (and value range) as ``img``.
    """
    in_img_type = img.dtype
    # Work on a float copy. The previous code discarded the astype() result and then
    # scaled float inputs in place, silently modifying the caller's array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def calculate_psnr(img1, img2):
    """Return the PSNR (dB) between two images whose values are in [0, 255].

    Returns ``inf`` when the images are identical.
    """
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


def mse2psnr(mse):
    """Convert an MSE measured on [0, 1]-range data to PSNR (dB); ``inf`` for zero MSE."""
    if mse == 0:
        return float('inf')
    return 20 * math.log10(1.0 / math.sqrt(mse))


def _normalize_heatmap(heatmap, height, width):
    """Resize ``heatmap`` to (height, width) and min-max normalize it to [0, 1].

    Non-finite entries (``mse2psnr`` yields ``inf`` for zero error) are clamped to the
    largest finite value so they do not turn the whole map into NaN. A constant map
    normalizes to all zeros instead of dividing by zero.
    """
    heatmap = np.asarray(heatmap, dtype=np.float64)
    finite = np.isfinite(heatmap)
    if not finite.all():
        fill = heatmap[finite].max() if finite.any() else 0.0
        heatmap = np.where(finite, heatmap, fill)
    resized = transform.resize(heatmap, (height, width))
    lo, hi = np.min(resized), np.max(resized)
    if hi == lo:
        return np.zeros_like(resized)
    return (resized - lo) / (hi - lo)


def to_bin(img, lower, upper):
    """Return a boolean mask of the pixels strictly between ``lower`` and ``upper``."""
    return (lower < img) & (img < upper)


def _crop(img, bbox):
    """Return the part of ``img`` covered by a regionprops ``bbox`` (r0, c0, r1, c1)."""
    min_row, min_col, max_row, max_col = bbox
    return img[min_row:max_row, min_col:max_col]


def _find_regions(normalized_heatmap, thres, min_area):
    """Return bboxes of 8-connected regions of ``normalized_heatmap`` above ``thres``.

    Only regions whose bounding box covers at least ``min_area`` pixels are kept.
    """
    # The map is normalized to [0, 1]. The old to_bin(map, thres, 1.0) used a strict
    # upper bound and so always excluded the hottest pixel(s); keep them.
    bin_map = normalized_heatmap > thres
    label_map = measure.label(bin_map, connectivity=2)
    # Compute the bbox area directly: ``bbox_area`` was renamed ``area_bbox`` in
    # scikit-image 0.19, and this works on both sides of the rename.
    return [prop.bbox for prop in measure.regionprops(label_map)
            if (prop.bbox[2] - prop.bbox[0]) * (prop.bbox[3] - prop.bbox[1]) >= min_area]


def _region_psnr_gain(im_BSL, im_OCT, im_GT, bbox):
    """Return PSNR(OCT, GT) - PSNR(BSL, GT) inside ``bbox``; images are in [0, 1]."""
    gt = _crop(im_GT, bbox) * 255.
    return (calculate_psnr(_crop(im_OCT, bbox) * 255., gt)
            - calculate_psnr(_crop(im_BSL, bbox) * 255., gt))


def plot_heatmap(image, heat_map, alpha=0.5, display=False, save=None, cmap='viridis', axis='on',
                 dpi=80, verbose=False):
    """Overlay ``heat_map`` (resized to the image size, min-max normalized) on ``image``.

    Args:
        image: image accepted by ``imshow``; its first two dims give the output size.
        heat_map: 2-D map of any size; it is resized to the image size.
        alpha: opacity of the heat-map overlay.
        display: show the overlay in an interactive pyplot window.
        save: if given, path of the image written with exactly H x W pixels.
        cmap: matplotlib colormap for the heat map.
        axis: argument passed to ``plt.axis`` for the interactive display.
        dpi: figure DPI, used to make the saved image pixel-exact.
        verbose: print the output path before saving.
    """
    H, W = image.shape[:2]
    normalized_heat_map = _normalize_heatmap(heat_map, H, W)

    if display:
        plt.imshow(image)
        plt.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        plt.axis(axis)
        plt.show()

    if save is not None:
        if verbose:
            print('save image: ' + save)
        fig = plt.figure(figsize=(W / float(dpi), H / float(dpi)))
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(image)
        ax.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        ax.set(xlim=[0, W], ylim=[H, 0], aspect=1)
        fig.savefig(save, dpi=dpi, transparent=True)
        plt.close(fig)  # avoid accumulating open figures across images


def plot_diffmap(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                 save=None, cmap='viridis', axis='on', dpi=80, verbose=False, min_area=100):
    """Box the hot regions of ``heatmap`` over ``im_BSL`` and label each box with the PSNR
    gain of ``im_OCT`` over ``im_BSL``, both measured against ``im_GT`` inside the box.

    Args:
        im_BSL, im_OCT, im_GT: H x W x C images (baseline, compared method, ground truth)
            with values in [0, 1].
        heatmap: 2-D map of any size; resized to H x W and min-max normalized.
        thres: normalized-heat threshold; pixels above it form candidate regions.
        alpha: opacity of the heat-map overlay.
        display: unused; kept for signature compatibility.
        save: output path; nothing is produced when it is None.
        cmap: colormap of the heat-map overlay.
        axis: argument passed to ``ax.axis``.
        dpi: figure DPI, used to make the saved image pixel-exact.
        verbose: print the output path before saving.
        min_area: minimum bounding-box area (pixels) for a region to be drawn.
    """
    if save is None:
        return
    H, W = im_BSL.shape[:2]
    normalized_heatmap = _normalize_heatmap(heatmap, H, W)
    bboxes = _find_regions(normalized_heatmap, thres, min_area)

    if verbose:
        print('save image: ' + save)
    fig = plt.figure(figsize=(W / float(dpi), H / float(dpi)))
    ax = fig.add_axes([0, 0, 1, 1])
    ax.axis('off')
    ax.imshow(im_BSL)
    ax.imshow(normalized_heatmap, alpha=alpha, cmap=cmap)
    ax.axis(axis)
    for bbox in bboxes:
        min_row, min_col, max_row, max_col = bbox
        ax.add_patch(patches.Rectangle((min_col, min_row), max_col - min_col, max_row - min_row,
                                       edgecolor='y', linewidth=6, fill=False))
        gain = _region_psnr_gain(im_BSL, im_OCT, im_GT, bbox)
        # Right-align labels of boxes starting near the right edge so they stay visible.
        h_aln = 'right' if W - min_col < 50 else 'left'
        if min_row < 20:
            # Too close to the top edge: put the label below the box instead of above it.
            ax.text(min_col, max_row, '{:+.2f}'.format(gain), color='r',
                    verticalalignment='top', horizontalalignment=h_aln, fontsize=26)
        else:
            ax.text(min_col, min_row, '{:+.2f}'.format(gain), color='r',
                    verticalalignment='bottom', horizontalalignment=h_aln, fontsize=26)
    ax.set(xlim=[0, W], ylim=[H, 0], aspect=1)
    fig.savefig(save, dpi=dpi, transparent=True)
    plt.close(fig)


def plot_diff_patch(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                    save=None, cmap='viridis', axis='on', dpi=80, verbose=False,
                    diff_save=None, max_patches=5, min_area=100):
    """Plot the hottest regions of ``heatmap`` as rows of crops: GT | BSL | OCT | diff.

    Regions are ranked by their mean normalized heat. The diff column shows
    ``clip(OCT - BSL + 0.5, 0, 1)`` and is annotated with the region's PSNR gain
    (PSNR(OCT, GT) - PSNR(BSL, GT)) and its bbox.

    Args:
        im_BSL, im_OCT, im_GT: H x W x C images with values in [0, 1].
        heatmap: 2-D map of any size; resized to H x W and min-max normalized.
        thres: normalized-heat threshold for candidate regions (previously ignored).
        alpha, display, cmap, axis, dpi: unused; kept for signature compatibility.
        save: path of the patch-grid figure. Previously the global ``save_path`` was used.
        verbose: print the output path before saving.
        diff_save: if given, the full-size diff image is written there as 8-bit BGR.
            Previously this path was built from the global ``base_name``.
        max_patches: maximum number of regions (rows) to plot.
        min_area: minimum bounding-box area (pixels) for a region to be considered.
    """
    H, W = im_BSL.shape[:2]
    normalized_heatmap = _normalize_heatmap(heatmap, H, W)

    regions = []
    for bbox in _find_regions(normalized_heatmap, thres, min_area):
        err = np.mean(_crop(normalized_heatmap, bbox))
        regions.append((err, bbox, _region_psnr_gain(im_BSL, im_OCT, im_GT, bbox)))
    regions.sort(key=lambda region: region[0], reverse=True)

    # 0.5 (mid-gray) means no difference; brighter / darker means OCT is above / below BSL.
    im_diff = np.clip(im_OCT - im_BSL + 0.5, 0.0, 1.0)
    if diff_save is not None:
        cv2.imwrite(diff_save, np.round(im_diff * 255).astype(np.uint8)[:, :, [2, 1, 0]])

    top = regions[:max_patches]
    if not top or save is None:
        # plt.subplots(nrows=0) raises, and there is nothing to draw without an output path.
        return
    fig, axes = plt.subplots(nrows=len(top), ncols=4, figsize=(15, 15), squeeze=False)
    for row, (_, bbox, gain) in zip(axes, top):
        for ax, img in zip(row, (im_GT, im_BSL, im_OCT, im_diff)):
            ax.imshow(_crop(img, bbox))
        patch_h, patch_w = bbox[2] - bbox[0], bbox[3] - bbox[1]
        row[3].text(patch_w, patch_h, '{:+.2f}'.format(gain), color='r', fontsize=16)
        row[3].text(patch_w, 0, '{}'.format(bbox), color='r', fontsize=16)
    if verbose:
        print('save image: ' + save)
    fig.savefig(save, dpi=300, bbox_inches='tight', transparent=False)
    plt.close(fig)


def _read_rgb01(path):
    """Read an image with OpenCV and return it as an RGB float array in [0, 1]."""
    img = cv2.imread(path)
    if img is None:  # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError('cannot read image: {}'.format(path))
    return img[:, :, [2, 1, 0]] / 255.


def compute_inv_map(im_OCT, im_BSL, patch_size=32, stride=10):
    """Build a leave-one-patch-out PSNR map between two images in [0, 1].

    For every patch position, that patch's contribution to the global MSE between
    ``im_OCT`` and ``im_BSL`` is removed and the remaining MSE is converted to PSNR.
    Patches that carry more of the difference therefore get a higher value.

    Returns:
        Array of shape (n_rows, n_cols), one entry per patch position (empty when the
        image is smaller than ``patch_size``).
    """
    H, W, C = im_OCT.shape
    # +1 so the last valid window start (H - patch_size) is included.
    H_axis = np.arange(0, H - patch_size + 1, stride)
    W_axis = np.arange(0, W - patch_size + 1, stride)
    sq_err = (im_OCT - im_BSL) ** 2
    total_err = np.mean(sq_err)
    n_values = float(H * W * C)
    inv_map = np.zeros((len(H_axis), len(W_axis)))
    for i, h in enumerate(H_axis):
        for j, w in enumerate(W_axis):
            patch_err = np.sum(sq_err[h:h + patch_size, w:w + patch_size, :]) / n_values
            # Clamp tiny negative round-off so math.sqrt inside mse2psnr cannot raise.
            inv_map[i, j] = mse2psnr(max(total_err - patch_err, 0.0))
    return inv_map


folder_BSL = "/data1/results10.09/(multi_scale)_SRResNet_16B64C/DIV2K_VAL0.8/"
folder_OCT = "/data1/results/(multi)1X1_directshare_SRResNet_48B64C_alpha=0.5/DIV2K_VAL0.8/"
folder_GT = "/data1/data/multiscale_dataset/DIV2K_valid_HR_0.8/"
# NOTE: crop_border, PSNR_all and SSIM_all are not used below; test_Y only selects the
# message printed at start-up.
crop_border = 4
suffix = ''  # suffix for Gen images
test_Y = False  # True: test Y channel only; False: test RGB channels
PSNR_all = []
SSIM_all = []
patch_size = 32
stride = 10


def main():
    """Compare the first 100 OCT results with their baseline and write the diff figures."""
    img_list = sorted(glob.glob(os.path.join(folder_OCT, '*')))[:100]
    if test_Y:
        print('Testing Y channel.')
    else:
        print('Testing RGB channels.')

    save_dir = '/data1/cropimage/heatmap/diff_DIV2K_VAL0.8/'
    save_dir6 = '/data1/cropimage/heatmap/heatdiff_DIV2K_VAL0.8/'
    diff_dir = '/data1/cropimage/diff6_cvpr'
    for directory in (save_dir, save_dir6, diff_dir):
        # makedirs also creates missing parents (os.mkdir failed on those).
        os.makedirs(directory, exist_ok=True)

    for img_path in img_list:
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        out_name = base_name.replace(_MODEL_TAG, '') + '.png'
        im_OCT = _read_rgb01(img_path)
        im_BSL = _read_rgb01(os.path.join(
            folder_BSL, base_name.replace(_MODEL_TAG, 'SRResNet_16B64C') + suffix + '.png'))
        # NOTE(review): the folders are for scale 0.8 but GT names map to '_bicLRx0.6';
        # confirm this naming against the dataset.
        im_GT = _read_rgb01(os.path.join(
            folder_GT, base_name.replace('_bicLRx4', '_bicLRx0.6') + suffix + '.png'))

        inv_map = compute_inv_map(im_OCT, im_BSL, patch_size, stride)
        if inv_map.size == 0:
            print('skip {}: image smaller than patch_size'.format(base_name))
            continue

        # plot_heatmap(im_BSL, inv_map, alpha=0.7, save=..., axis='off', display=False)
        plot_diffmap(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5,
                     save=os.path.join(save_dir6, out_name), axis='off', display=False)
        plot_diff_patch(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5,
                        save=os.path.join(save_dir, out_name), axis='off', display=False,
                        diff_save=os.path.join(diff_dir, out_name))


if __name__ == '__main__':
    main()

 

 

转载地址:http://myajz.baihongyu.com/

你可能感兴趣的文章