본문 바로가기
Study/머신러닝

[tensorflow] tf.data를 사용해 이미지 데이터 학습시키기

by 투말치 2021. 1. 25.

목차

    반응형

    tf.data는 datagenerator를 사용하지 않는다.

     

    이미지 가져오기

    - 읽어온 이미지를 batch size만큼 묶어줘야 한다.

    - shuffle을 사용해 읽어온 이미지들의 순서를 섞어준다. 

    dataset=tf.data.Dataset.from_tensor_slices(data_paths)
    dataset=dataset.map(read_image)
    dataset = dataset.batch(batch_size)
    dataset=dataset.shuffle(buffer_size=len(data_paths))  

     

    이미지의 레이블 값 가져오기

     def get_label(path):
        fname=tf.strings.split(path,'_')[-1]
        cls_name=tf.strings.regex_replace(fname,'.png','')
        onehot_encoding=tf.cast(class_names==cls_name, tf.uint8)
        return onehot_encoding

     

    Image Augmentation 작업

    - 이전에 이미지를 회전하고 움직이는 변형작업을 했던 것처럼 이미지를 변형하는 작업을 해준다.

    def image_preprocess(image, label):
        image=tf.image.random_flip_left_right(image)
        image=tf.image.random_flip_up_down(image)
        return image, label

     

     

    학습 진행

    - epoch 값은 batch size를 사용해서 직접 계산해야 한다.

    - fit_generator를 사용해서 학습을 진행한다.

    steps_per_epoch=len(train_paths)//batch_size
    validation_steps=len(test_paths)//batch_size  #둘다 직접 계산해야 함
    
    model.fit_generator(
        train_dataset,
        steps_per_epoch=steps_per_epoch,
        validation_data=test_dataset,
        validation_steps=validation_steps,
        epochs=num_epochs
    
    
    )

     

     

     

     

     

     

     

     

     

    반응형