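# "Edit code" -- stray UI label left over from a copy-paste; kept as a comment so this file parses as Python.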

      
import shutil
import argparse
import os
import time
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def load_image(data_path: str, mode="rgb") -> np.ndarray:
    """Read an image file into a NumPy array.

    Args:
        data_path: path of the image file, opened with PIL.
        mode: when ``"rgb"`` (the default) the image is converted to 3-channel
            RGB; any other value keeps the mode PIL opened the file in.

    Returns:
        The pixel data produced by ``np.array`` on the (possibly converted) image.
    """
    # The context manager releases the file handle deterministically instead of
    # leaving it open until the Image object happens to be garbage-collected.
    with Image.open(data_path) as img:
        if mode == "rgb":
            img = img.convert("RGB")
        # Materialise the pixels while the source file is still open.
        return np.array(img)

# NOTE(review): cat_images below is disabled. It is wrapped in a bare
# triple-quoted string, so it is never defined or executed. If re-enabled, it
# centres every image on a black uint8 canvas of the largest height/width
# (plus `pad` pixels on each side), overwrites the entries of `image_list`
# in place, and concatenates the canvases along `axis`. It assumes 3-channel
# images (it unpacks `h, w, _ = image.shape`).
'''
def cat_images(image_list: list[np.ndarray], axis=1, pad=20) -> np.ndarray:
    shape_list = [image.shape for image in image_list]
    max_h = max([shape[0] for shape in shape_list]) + pad * 2
    max_w = max([shape[1] for shape in shape_list]) + pad * 2

    for i, image in enumerate(image_list):
        canvas = np.zeros((max_h, max_w, 3), dtype=np.uint8)
        h, w, _ = image.shape
        crop_y = (max_h - h) // 2
        crop_x = (max_w - w) // 2
        canvas[crop_y : crop_y + h, crop_x : crop_x + w] = image
        image_list[i] = canvas

    image = np.concatenate(image_list, axis=axis)
    return image
'''


def show_anns(anns) -> None:
    """Paint annotation masks as one translucent RGBA overlay on the current axes.

    Args:
        anns: list of dicts, each holding an ``"area"`` value (the sort key) and
            a ``"segmentation"`` array used to index the overlay (presumably a
            boolean (H, W) mask -- TODO confirm against callers). The overlay is
            sized from the largest-area entry's segmentation.

    Masks are painted from largest to smallest area with a random RGB colour at
    alpha 0.35, so smaller masks are drawn over larger ones. An empty ``anns``
    returns immediately without touching matplotlib.
    """
    if len(anns) < 1:
        return
    by_area_desc = sorted(anns, key=lambda entry: entry["area"], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    first_seg = by_area_desc[0]["segmentation"]
    overlay = np.ones((first_seg.shape[0], first_seg.shape[1], 4))
    overlay[:, :, 3] = 0  # fully transparent wherever no mask gets painted
    for entry in by_area_desc:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[entry["segmentation"]] = rgba
    axes.imshow(overlay)

# NOTE(review): the drawing helpers below (draw_binary_mask, draw_binary_masks,
# draw_bbox, draw_scatter) are disabled. The whole section is one bare
# triple-quoted string and is never executed. If re-enabled, note:
#   - draw_bbox uses `Rectangle`, which this file does not import
#     (presumably matplotlib.patches.Rectangle -- confirm);
#   - draw_binary_masks still contains debug print() calls;
#   - draw_bbox/draw_scatter build figsize as (oh / dpi, ow / dpi), i.e.
#     (height, width), while matplotlib's figsize is (width, height); this looks
#     swapped, so verify before relying on it (the final cv2.resize back to
#     (ow, oh) hides the mismatch in the output size).
'''
def draw_binary_mask(raw_image: np.ndarray, binary_mask: np.ndarray, mask_color=(0, 0, 255)) -> np.ndarray:
    color_mask = np.zeros_like(raw_image, dtype=np.uint8)
    color_mask[binary_mask == 1] = mask_color
    mix = color_mask * 0.5 + raw_image * (1 - 0.5)
    binary_mask = np.expand_dims(binary_mask, axis=2)
    canvas = binary_mask * mix + (1 - binary_mask) * raw_image
    canvas = np.asarray(canvas, dtype=np.uint8)
    return canvas


def draw_binary_masks(raw_image: np.ndarray, binary_masks: list[np.ndarray]) -> np.ndarray:
    color_mask = np.zeros_like(raw_image, dtype=np.uint8)
    # import pdb; pdb.set_trace()
    for binary_mask in binary_masks:
        mask_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        color_mask[binary_mask == 1] = mask_color
    mix = color_mask * 0.5 + raw_image * (1 - 0.5)
    print(mix.shape)

    print(np.array(binary_masks).shape)


    vis_mask = np.array(binary_masks).sum(axis=0).astype('bool')
    print(vis_mask.shape)


    vis_mask = np.expand_dims(vis_mask, axis=2)
    canvas = vis_mask * mix + (1 - vis_mask) * raw_image
    canvas = np.asarray(canvas, dtype=np.uint8)
    return canvas


def draw_bbox(
    image: np.ndarray,
    bbox: list[list[int]],
    color: str or list[str] = "g",
    linewidth=1,
    tmp_name=".tmp.png",
) -> np.ndarray:
    dpi = 300
    oh, ow, _ = image.shape
    plt.close()
    plt.figure(1, figsize=(oh / dpi, ow / dpi))
    plt.imshow(image)
    if isinstance(color, str):
        color = [color for _ in bbox]
    for (x0, y0, x1, y1), c in zip(bbox, color):
        plt.gca().add_patch(Rectangle((x0, y0), x1 - x0, y1 - y0, lw=linewidth, edgecolor=c, facecolor=(0, 0, 0, 0)))
    plt.axis("off")
    plt.savefig(tmp_name, format="png", dpi=dpi, bbox_inches="tight", pad_inches=0.0)
    image = cv2.resize(load_image(tmp_name), dsize=(ow, oh))
    os.remove(tmp_name)
    plt.close()
    return image


def draw_scatter(
    image: np.ndarray,
    points: list[list[int]],
    color: str or list[str] = "g",
    marker="*",
    s=10,
    ew=0.25,
    tmp_name=".tmp.png",
) -> np.ndarray:
    dpi = 300
    oh, ow, _ = image.shape
    plt.close()
    plt.figure(1, figsize=(oh / dpi, ow / dpi))
    plt.imshow(image)
    if isinstance(color, str):
        color = [color for _ in points]
    for (x, y), c in zip(points, color):
        plt.scatter(x, y, color=c, marker=marker, s=s, edgecolors="white", linewidths=ew)
    plt.axis("off")
    plt.savefig(tmp_name, format="png", dpi=dpi, bbox_inches="tight", pad_inches=0.0)
    image = cv2.resize(load_image(tmp_name), dsize=(ow, oh))
    os.remove(tmp_name)
    plt.close()
    return image
'''
def compute_polygon_dist(poly1, poly2):
    """Find the closest pair of vertices between two polygons.

    Brute force over every vertex pair, O(len(poly1) * len(poly2)).

    Args:
        poly1: sequence of 2-element point arrays.
        poly2: sequence of 2-element point arrays; must be non-empty when
            ``poly1`` is non-empty (``np.argmin`` raises on an empty sequence).

    Returns:
        ``([i, j], dist)``: ``poly1[i]`` and ``poly2[j]`` are the closest vertex
        pair and ``dist`` is their Euclidean distance. Ties keep the first pair
        found (lowest ``i``, then lowest ``j``). For an empty ``poly1`` the
        result is ``([-1, -1], 1e10)``.
    """
    # Convert poly2 to an (M, 2) array once, outside the loop, instead of
    # re-converting the whole list for every vertex of poly1.
    poly2_arr = np.asarray(poly2)
    best_pair = [-1, -1]
    min_dist = 1e10
    for i, point in enumerate(poly1):
        dists = np.linalg.norm(point - poly2_arr, axis=1)
        j = np.argmin(dists)
        # Strict '<' keeps the earliest pair on ties.
        if dists[j] < min_dist:
            min_dist = dists[j]
            best_pair = [i, j]
    return best_pair, min_dist


def merge_two_polygons(poly1, poly2, nearest_point_idx=None):
    """Splice ``poly2`` into ``poly1`` through a narrow cut, returning one polygon.

    With ``idx1, idx2 = nearest_point_idx``, the result is:
    a point just past ``poly1[idx1]``, then every vertex of ``poly1`` walking on
    from ``idx1 + 1`` (wrapping) back to ``idx1``, then a point just past
    ``poly2[idx2]``, then every vertex of ``poly2`` from ``idx2 + 1`` (wrapping)
    back to ``idx2``. The two connecting edges (``poly1[idx1]`` -> near point on
    ``poly2`` and, when the polygon closes, ``poly2[idx2]`` -> near point on
    ``poly1``) therefore form a narrow slit rather than a zero-width one.

    A "near point" lies ``pixel_offset`` (1) along the edge from a vertex to the
    next vertex. If that edge is no longer than the offset, the next vertex
    itself is used.

    Args:
        poly1: sequence of 2-element numpy point arrays.
        poly2: sequence of 2-element numpy point arrays.
        nearest_point_idx: ``[idx1, idx2]`` vertex indices to connect. When
            ``None`` (the default), they are computed with
            ``compute_polygon_dist``.

    Returns:
        A new list of points: the original vertex arrays plus the two near
        points, which are float arrays when computed from the offset.
    """
    if nearest_point_idx is None:
        nearest_point_idx, _ = compute_polygon_dist(poly1, poly2)
    idx1, idx2 = nearest_point_idx[0], nearest_point_idx[1]
    pixel_offset = 1

    def near_point(poly, idx):
        """Return the point ``pixel_offset`` along poly[idx] -> next vertex (wrapping)."""
        point = poly[idx]
        next_point = poly[(idx + 1) % len(poly)]
        edge = next_point - point
        edge_len = np.linalg.norm(edge, axis=0)
        # Short (or zero-length) edge: reuse the next vertex. This also avoids
        # dividing by zero for duplicate consecutive vertices.
        if edge_len <= pixel_offset:
            return next_point
        return point + edge / edge_len * pixel_offset

    poly_final = [near_point(poly1, idx1)]
    poly_final.extend(poly1[idx1 + 1:])
    poly_final.extend(poly1[:idx1 + 1])
    poly_final.append(near_point(poly2, idx2))
    poly_final.extend(poly2[idx2 + 1:])
    poly_final.extend(poly2[:idx2 + 1])
    return poly_final

def save_yolo_txt(yolo_txt_dir, basename, yolo_info, img_shape):
    """Write ``<yolo_txt_dir>/<basename>.txt`` with one label line per polygon.

    Each line is ``<class_id> x1 y1 x2 y2 ... `` followed by a newline. Every x
    is divided by the image width and every y by the image height, and every
    value (class id included) is followed by a single space. The directory is
    always created, but the file is only written when ``yolo_info`` is
    non-empty. Prints the basename/directory and, when writing, the image size.

    Args:
        yolo_txt_dir: output directory (created if missing).
        basename: label file name without extension.
        yolo_info: list of ``(class_id, polygon)`` pairs. ``polygon`` is a
            sequence of points indexable as ``point[0]`` (x) and ``point[1]`` (y).
        img_shape: shape whose first two entries are (height, width).
    """
    os.makedirs(yolo_txt_dir, exist_ok=True)
    print(basename, ' ', yolo_txt_dir)
    if len(yolo_info) == 0:
        return

    height, width = img_shape[:2]
    print(height, width)
    label_path = os.path.join(yolo_txt_dir, basename + '.txt')
    with open(label_path, 'w') as label_file:
        for class_id, polygon in yolo_info:
            label_file.write(str(class_id) + ' ')
            for point in polygon:
                x_norm = point[0] / width
                y_norm = point[1] / height
                label_file.write(str(x_norm) + ' ' + str(y_norm) + ' ')
            label_file.write('\n')


def mkdir(path):
    """Create ``path`` (including missing parents) unless it already exists.

    Prints a message only when the directory was actually created. Calling
    ``os.makedirs`` and catching ``FileExistsError`` avoids a check-then-create
    race: an entry appearing between an existence check and the creation is
    not an error.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present (a directory or any other entry at that path): nothing to do.
        return
    print("The new directory is created!")

def mask2polygon(mask):
    """Convert a binary mask into a single polygon (one point list).

    ``save_yolo_txt`` writes exactly one point list per object, so every
    contour after the first is reversed and spliced into the running polygon.
    With ``cv2.RETR_TREE``, that means holes as well as separate components.
    The splice happens at the closest vertex pair (``compute_polygon_dist``)
    through a narrow cut (``merge_two_polygons``).

    Args:
        mask: single-channel 8-bit image accepted by ``cv2.findContours``
            (non-zero pixels are foreground).

    Returns:
        ``(True, points)`` where ``points`` is a list of 2-element (x, y) arrays,
        or ``(False, [])`` when the mask contains no contour.
    """
    # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
    # (contours, hierarchy); index -2 selects the contours in both versions.
    contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return False, []

    # Each contour has shape (N, 1, 2); drop the middle axis to get (x, y) points.
    final_polygon = list(contours[0][:, 0])
    for contour in contours[1:]:
        # Reverse the traversal direction of each additional contour before splicing it in.
        polygon = list(contour[:, 0])[::-1]
        nearest_point_idx, _ = compute_polygon_dist(final_polygon, polygon)
        final_polygon = merge_two_polygons(final_polygon, polygon, nearest_point_idx)
    return True, final_polygon


def find_images(root_dir):
    """List the image file names (not full paths) inside ``<root_dir>/images``.

    An entry counts as an image when its lower-cased name ends with one of
    .png, .jpg, .jpeg, .tiff, .bmp or .gif. Order follows ``os.listdir``.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')
    folder = os.path.join(root_dir, 'images')
    names = []
    for entry in os.listdir(folder):
        if entry.lower().endswith(image_exts):
            names.append(entry)
    return names


def process_masks(root_dir, image_name):
    """Collect every object mask stored for one image.

    Looks for ``<root_dir>/soft_masks/<subdir>/<image_name>`` in every
    sub-directory of ``soft_masks``. Each sub-directory name is used as the id
    of the mask found in it.

    Args:
        root_dir: dataset root containing a ``soft_masks`` directory.
        image_name: file name of the mask to look up in each sub-directory.

    Returns:
        ``(binary_masks, mask_ids)``: uint8 masks with values 0/255 (pixels
        > 128 become 255) and the matching sub-directory names, in sorted
        sub-directory order. Mask files that cannot be decoded are skipped with
        a warning.
    """
    soft_masks_dir = os.path.join(root_dir, 'soft_masks')
    binary_masks = []
    mask_ids = []
    # os.listdir order is arbitrary; sorting makes the mask order (and hence
    # the label-file line order) reproducible.
    for subdir in sorted(os.listdir(soft_masks_dir)):
        subdir_path = os.path.join(soft_masks_dir, subdir)
        if not os.path.isdir(subdir_path):
            continue
        mask_path = os.path.join(subdir_path, image_name)
        if not os.path.exists(mask_path):
            continue
        # Decode from raw bytes rather than cv2.imread: imread does not reliably
        # open non-ASCII paths on Windows (the default --root_dir contains Chinese
        # characters), and it signals failure by returning None instead of raising.
        raw = np.fromfile(mask_path, dtype=np.uint8)
        mask = cv2.imdecode(raw, cv2.IMREAD_GRAYSCALE) if raw.size else None
        if mask is None:
            print(f"Warning: could not decode mask {mask_path}, skipping")
            continue
        binary_masks.append((mask > 128).astype(np.uint8) * 255)
        mask_ids.append(subdir)

    return binary_masks, mask_ids

def remove_extension(filename):
    """Return ``filename`` without its final extension (``'a/b.tar.gz'`` -> ``'a/b.tar'``).

    Names without a dot, and leading-dot names such as ``'.hidden'``, come back unchanged.
    """
    stem, _ext = os.path.splitext(filename)
    return stem

def add_png_extension(filename):
    """Swap the final extension of ``filename`` for ``.png`` (append it when there is none)."""
    stem = remove_extension(filename)
    return stem + '.png'

def main():
    """Convert per-object soft masks into YOLO segmentation label files.

    Reads image names from ``<root_dir>/images`` and loads the matching
    ``<root_dir>/soft_masks/<subdir>/<name>.png`` masks. Each mask becomes a
    single polygon, and the polygons are written to
    ``<root_dir>/pred_yolo_txt/<name>.txt``. The output directory is deleted
    and recreated on every run. Unrecognised command-line arguments are ignored
    (``parse_known_args``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--class_id",
        type=int,
        default=0,
        help="class index written for every mask; must match the annotation spec",
    )
    parser.add_argument(
        "--root_dir",
        type=str,
        # Change to your local path: the parent directory of the images, masks,
        # soft_masks and visualization folders.
        default=r"C:\Users\ubtech\Cutie\workspace\发动机前悬置软垫总成01",
        help="dataset root: parent directory of the images, masks, soft_masks and visualization folders",
    )
    args, _ = parser.parse_known_args()
    root_dir = args.root_dir
    # Every mask of this dataset gets the same class index. For the Dongfeng
    # project, pick it from the annotation document (e.g. the engine front
    # mount soft-pad assembly is class 0).
    cls_id = str(args.class_id)

    image_names = find_images(root_dir)
    pred_yolo_txt_dir = os.path.join(root_dir, 'pred_yolo_txt')
    # Start from an empty output directory so labels from a previous run cannot linger.
    if os.path.exists(pred_yolo_txt_dir):
        shutil.rmtree(pred_yolo_txt_dir)
    os.makedirs(pred_yolo_txt_dir)

    total_masks = 0
    for image_name in image_names:
        basename = remove_extension(image_name)
        # Masks are looked up as PNG files whatever the source image's extension is.
        binary_masks, _mask_ids = process_masks(root_dir, add_png_extension(image_name))
        if len(binary_masks) == 0:
            print('No masks found')
            continue
        total_masks += len(binary_masks)

        yolo_info = []
        for mask in binary_masks:
            if_success, polygon = mask2polygon(mask)
            if if_success:
                yolo_info.append([cls_id, polygon])
        print(len(yolo_info))
        # Coordinates are normalised by the mask size; the last mask's shape is used.
        save_yolo_txt(pred_yolo_txt_dir, basename, yolo_info, binary_masks[-1].shape)

    print(f"Found {len(image_names)} images.")
    print(f"Created {total_masks} binary masks.")


# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()