import os
from PIL import ImageFilter
import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms import transforms


class Countgwhd(Dataset):
    """Image/label dataset driven by a CSV annotation file.

    Column 0 of the CSV holds the image file name (relative to ``img_path``)
    and column 1 holds the label returned alongside the image
    (presumably an object count -- confirm against the annotation file).
    """

    def __init__(self, img_path, ann_path, resize_shape, pretrain=False):
        """Load the annotation table and build the tensor transform.

        Args:
            img_path: Directory that contains the image files.
            ann_path: Path to the CSV annotation file; read with
                ``pandas.read_csv`` and stored as a DataFrame on
                ``self.ann_path``.
            resize_shape: Target side length; used for both dimensions of
                the resize, so outputs are square.
            pretrain: When true, images are only resized. Otherwise a random
                horizontal flip (p=0.5) is applied before resizing.
        """
        self.img_path = img_path
        # NOTE: despite the name, this attribute holds the parsed DataFrame,
        # not the path. The name is kept for backward compatibility.
        self.ann_path = pd.read_csv(ann_path)
        self.shape = resize_shape

        resize = transforms.Resize([self.shape, self.shape])
        if pretrain:
            self.transform = resize
        else:
            self.transform = transforms.Compose(
                [transforms.RandomHorizontalFlip(0.5), resize]
            )

    def _image_path(self, idex: int) -> str:
        """Return the full path of the image listed in annotation row ``idex``."""
        # Pass the directory and file name as separate components so a
        # directory without a trailing separator still yields a valid path.
        return os.path.join(self.img_path, self.ann_path.iloc[idex, 0])

    def __getitem__(self, idex):
        """Return ``(image, label)`` for annotation row ``idex``.

        The image is loaded as RGB, sharpened twice, converted to a float
        tensor rescaled to the 0-255 range, and passed through
        ``self.transform``. The label is the raw value from column 1.
        """
        # Close the underlying file handle as soon as the pixel data has been
        # copied out; convert() returns a fully loaded, independent image.
        with Image.open(self._image_path(idex)) as raw:
            image = raw.convert("RGB")

        # Two sharpening passes, applied back to back.
        for _ in range(2):
            image = image.filter(ImageFilter.SHARPEN)

        # ToTensor scales pixel values to [0, 1]; multiply to restore 0-255.
        to_tensor = transforms.ToTensor()
        image = to_tensor(image) * 255
        image = self.transform(image)

        label = self.ann_path.iloc[idex, 1]
        return image, label

    def __len__(self) -> int:
        """Return the number of annotation rows."""
        return len(self.ann_path)



