# Source code for tensorlayerx.vision.transforms.utils

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cv2
import os
import numpy as np
import threading


# Public API of this module: the names exported by ``from <module> import *``.
# ``threading_data`` is defined in this file as well but is not listed here.
__all__ = [
    'load_image',
    'save_image',
    'load_images',
    'save_images',
]

def load_image(path):
    """Load an image from disk and return it with RGB channel order.

    Parameters
    ----------
    path : str
        Path of the image file.

    Returns
    -------
    numpy.ndarray
        The decoded image, converted from OpenCV's BGR order to RGB.

    Raises
    ------
    cv2.error
        If ``cv2.imread`` produces no image for ``path``, or if the colour
        conversion itself fails.

    Examples
    ----------
    With TensorLayerX

    >>> import tensorlayerx as tlx
    >>> path = './data/1.png'
    >>> image = tlx.vision.load_image(path)
    >>> print(image)

    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports an unreadable file by returning None instead of
        # raising; handing that None to cvtColor only yields an opaque
        # assertion message. Raise the same exception type (cv2.error) so
        # existing handlers keep working, but name the offending path.
        raise cv2.error("load_image: could not read an image from {!r}".format(path))
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def save_image(image, file_name, path):
    """Write an RGB image to ``path/file_name``.

    The image is converted from RGB to BGR before being handed to
    ``cv2.imwrite``. The value returned by ``cv2.imwrite`` is discarded, so
    this function always returns ``None``.

    Parameters
    ----------
    image : numpy.ndarray
        The image to save, in RGB channel order.
    file_name : str
        Name of the file to write.
    path : str
        Directory the file is written into.

    Examples
    ----------
    With TensorLayerX

    >>> import tensorlayerx as tlx
    >>> load_path = './data/1.png'
    >>> save_path = './test/'
    >>> image = tlx.vision.load_image(load_path)
    >>> tlx.vision.save_image(image, file_name='1.png', path=save_path)

    """
    bgr_pixels = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    destination = os.path.join(path, file_name)
    cv2.imwrite(destination, bgr_pixels)
def load_images(path, n_threads=10):
    """Load every entry of a directory as an RGB image.

    The directory listing is sorted by name and each entry is passed to
    ``load_image``; the returned list follows that sorted order.

    Parameters
    ----------
    path : str
        Directory containing the images.
    n_threads : int
        Maximum number of worker threads used to read images. ``0`` (as well
        as ``None`` or a negative value) reads the files one by one in the
        calling thread.

    Returns
    -------
    list
        A list of numpy RGB images, exactly as returned by ``load_image``.

    Raises
    ------
    Exception
        Whatever ``load_image`` raises for an entry is propagated to the
        caller, in the threaded mode as well as in the sequential one.

    Examples
    ----------
    With TensorLayerX

    >>> import tensorlayerx as tlx
    >>> load_path = './data/'
    >>> images = tlx.vision.load_images(load_path)

    """
    file_paths = [os.path.join(path, name) for name in sorted(os.listdir(path))]

    # Sequential mode. Also used when there is nothing to parallelise, which
    # avoids creating a pool for an empty or single-file directory.
    if n_threads is None or n_threads <= 0 or len(file_paths) < 2:
        return [load_image(file_path) for file_path in file_paths]

    # Local import: keeps this function self-contained without touching the
    # module's import block.
    from concurrent.futures import ThreadPoolExecutor

    # Executor.map yields results in input order, so the output stays sorted.
    # Each image is returned untouched (no packing into an object-dtype array),
    # which keeps the threaded result identical to the sequential one.
    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        return list(pool.map(load_image, file_paths))
def save_images(images, file_names, path):
    """Write a sequence of RGB images into a directory.

    Each image is converted from RGB to BGR and written with ``cv2.imwrite``
    under the matching entry of ``file_names``. The input ``images`` container
    is left untouched: the BGR conversion is kept in a local variable instead
    of being stored back into the caller's list.

    Parameters
    ----------
    images : list
        A list of numpy RGB images.
    file_names : list
        A list of image names to save; each entry is converted with ``str``.
    path : str
        Directory the images are written into.

    Raises
    ------
    ValueError
        If ``images`` and ``file_names`` do not have the same length.

    Examples
    ----------
    With TensorLayerX

    >>> import tensorlayerx as tlx
    >>> load_path = './data/'
    >>> save_path = './test/'
    >>> images = tlx.vision.load_images(load_path)
    >>> name_list = user_define
    >>> tlx.vision.save_images(images, file_names=name_list, path=save_path)

    """
    if len(images) != len(file_names):
        raise ValueError(" The number of images should be equal to the number of file names.")
    for image, file_name in zip(images, file_names):
        # Convert into a local so the caller's RGB images are not replaced by
        # their BGR counterparts.
        bgr_pixels = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(path, str(file_name)), bgr_pixels)
def _pack_results(results):
    """Bundle the per-item results into the container ``threading_data`` returns."""
    head = results[0] if results else None
    if isinstance(head, np.ndarray) and all(
        isinstance(item, np.ndarray) and item.shape == head.shape and item.dtype == head.dtype
        for item in results
    ):
        # np.asarray(..., dtype=object) descends into equal-shaped arrays and
        # boxes every element as a Python object; stacking keeps the same
        # overall shape while preserving the native dtype.
        return np.stack(results)
    try:
        return np.asarray(results, dtype=object)
    except Exception:
        # Best effort, as before: hand back the plain list when numpy cannot
        # build an array from the results.
        return results


def threading_data(data=None, fn=None, thread_count=None, path=None):
    """Process a batch of data by given function by threading.

    ``fn`` is called once per item of ``data``. When ``path`` is given, the
    item is first joined onto it with ``os.path.join``; otherwise the item is
    passed to ``fn`` unchanged. Results are returned in the order of ``data``.

    An exception raised by ``fn`` is not propagated to the caller: it ends the
    worker thread it occurred in, and the results that thread had not produced
    yet stay ``None``.

    Parameters
    -----------
    data : numpy.array or others
        The data to be processed.
    fn : function
        The function for data processing; it receives a single argument.
    thread_count : int
        The number of threads to use. ``data`` is split into that many
        contiguous chunks and each thread processes one chunk item by item.
        When ``None``, one thread is started per item.
    path : str
        Optional directory prefix joined in front of every item before it is
        passed to ``fn``.

    Returns
    -------
    list or numpyarray
        The processed results: a stacked array when every result is an
        ndarray of the same shape and dtype, otherwise an object array, or
        the plain list if numpy cannot build one.

    Raises
    ------
    ValueError
        If ``thread_count`` is given but smaller than 1.

    References
    ----------
    - `python queue <https://pymotw.com/2/Queue/index.html#module-Queue>`__
    - `run with limited queue <http://effbot.org/librarybook/queue.htm>`__

    """
    items = list(data)
    results = [None] * len(items)

    def resolve(item):
        # Only build a filesystem path when a directory prefix was supplied.
        return item if path is None else os.path.join(path, item)

    def worker(start, stop):
        # Every worker owns a disjoint index range of the shared result list.
        for index in range(start, stop):
            results[index] = fn(resolve(items[index]))

    if thread_count is None:
        # One single-item range per element.
        bounds = list(range(len(items) + 1))
    else:
        if thread_count < 1:
            raise ValueError("thread_count should be a positive integer or None.")
        # Evenly spaced chunk boundaries over the data.
        bounds = np.round(np.linspace(0, len(items), thread_count + 1)).astype(int).tolist()

    threads = []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        if start == stop:
            # More threads than items: nothing to do for this chunk.
            continue
        thread = threading.Thread(name='threading_and_return', target=worker, args=(start, stop))
        thread.start()
        threads.append(thread)

    for thread in threads:
        thread.join()

    return _pack_results(results)