Python tensorflow.python.platform.gfile module: Glob() example source code
The following 47 code examples were extracted from open-source Python projects to illustrate how to use tensorflow.python.platform.gfile.Glob().
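Before the project snippets, here is a minimal, self-contained sketch of the call itself (the directory and pattern are made up for illustration): gfile.Glob() takes a glob pattern and returns a list of matching file paths, much like the standard glob.glob but routed through TensorFlow's file-system layer.

# Minimal usage sketch (TensorFlow 1.x-style import; the path is hypothetical).
from tensorflow.python.platform import gfile

tfrecord_files = gfile.Glob("/tmp/dataset/train-*.tfrecord")  # list of matching paths
for path in sorted(tfrecord_files):
    print(path)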
def create_data_list(image_dir):
if not gfile.Exists(image_dir):
print("Image director '" + image_dir + "' not found.")
return None
extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']
print("Looking for images in '" + image_dir + "'")
file_list = []
for extension in extensions:
file_glob = os.path.join(image_dir, '*.' + extension)
file_list.extend(gfile.Glob(file_glob))
if not file_list:
print("No files found in '" + image_dir + "'")
return None
images = []
labels = []
for file_name in file_list:
image = Image.open(file_name)
image_gray = image.convert('L')
image_resize = image_gray.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT))
input_img = np.array(image_resize, dtype='int16')
image.close()
label_name = os.path.basename(file_name).split('_')[0]
images.append(input_img)
labels.append(label_name)
return zip(images, labels)
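A hypothetical call site for the helper above (the directory is illustrative; IMAGE_WIDTH and IMAGE_HEIGHT are assumed module-level constants, and the zip() result is materialized because it is a lazy iterator on Python 3):

# Illustrative only: consume the (image, label) pairs produced by create_data_list.
data = create_data_list("/tmp/captcha_images")  # hypothetical image directory
if data is not None:
    pairs = list(data)  # zip() is lazy on Python 3
    print("Loaded %d labeled images" % len(pairs))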
def test_read_text_lines_multifile(self):
gfile.Glob = self._orig_glob
filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])
batch_size = 1
queue_capacity = 5
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
inputs = tf.contrib.learn.io.read_batch_examples(
filenames, batch_size, reader=tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
name=name)
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run(inputs), [b"ABC"])
self.assertAllEqual(session.run(inputs), [b"DEF"])
self.assertAllEqual(session.run(inputs), [b"GHK"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
def test_batch_text_lines(self):
gfile.Glob = self._orig_glob
filename = self._create_temp_file("A\nB\nC\nD\nE\n")
batch_size = 3
queue_capacity = 10
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
inputs = tf.contrib.learn.io.read_batch_examples(
[filename], batch_size, reader=tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
read_batch_size=10, name=name)
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
self.assertAllEqual(session.run(inputs), [b"D", b"E"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
def restore(self, sess, save_path):
"""Restores previously saved variables.
This method runs the ops added by the constructor for restoring variables.
It requires a session in which the graph was launched. The variables to
restore do not have to have been initialized, as restoring is itself a way
to initialize variables.
The `save_path` argument is typically a value previously returned from a
`save()` call, or a call to `latest_checkpoint()`.
Args:
sess: A `Session` to use to restore the parameters.
save_path: Path where parameters were previously saved.
Raises:
ValueError: If the given `save_path` does not point to a file.
"""
if not gfile.Glob(save_path):
raise ValueError("Restore called with invalid save path %s" % save_path)
sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path})
def latest_checkpoint(checkpoint_dir, latest_filename=None):
"""Finds the filename of latest saved checkpoint file.
Args:
checkpoint_dir: Directory where the variables were saved.
latest_filename: Optional name for the protocol buffer file that
contains the list of most recent checkpoint filenames.
See the corresponding argument to `Saver.save()`.
Returns:
The full path to the latest checkpoint or `None` if no checkpoint was found.
"""
# Pick the latest checkpoint based on checkpoint state.
ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
if ckpt and ckpt.model_checkpoint_path:
if gfile.Glob(ckpt.model_checkpoint_path):
return ckpt.model_checkpoint_path
return None
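The two routines above are typically used together: find the newest checkpoint, then restore it. A minimal sketch (the directory is hypothetical; `saver` and `sess` are assumed to be an existing tf.train.Saver and tf.Session):

# Illustrative restore flow.
ckpt_path = latest_checkpoint("/tmp/train_logs")  # hypothetical checkpoint directory
if ckpt_path is not None:
    saver.restore(sess, ckpt_path)  # `saver` and `sess` assumed created elsewhere
else:
    print("No checkpoint found; training from scratch.")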
def input_data(image_dir):
if not gfile.Exists(image_dir):
print(">> Image director '" + image_dir + "' not found.")
return None
extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']
print(">> Looking for images in '" + image_dir + "'")
file_list = []
for extension in extensions:
file_glob = os.path.join(image_dir, '*.' + extension)
file_list.extend(gfile.Glob(file_glob))
if not file_list:
print(">> No files found in '" + image_dir + "'")
return None
batch_size = len(file_list)
images = np.zeros([batch_size, IMAGE_HEIGHT*IMAGE_WIDTH], dtype='float32')
files = []
i = 0
for file_name in file_list:
image = Image.open(file_name)
image_gray = image.convert('L')
image_resize = image_gray.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT))
image.close()
input_img = np.array(image_resize, dtype='float32')
input_img = np.multiply(input_img.flatten(), 1./255) - 0.5
images[i,:] = input_img
base_name = os.path.basename(file_name)
files.append(base_name)
i += 1
return images, files
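A hypothetical way to feed the flattened images returned above into a model (the placeholder `x`, the tensor `logits`, and the session are assumptions, not part of the original snippet):

# Illustrative only: run inference on every image found in a directory.
images, file_names = input_data("/tmp/unlabeled_images")  # hypothetical path
# `x` is assumed to be tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH]).
predictions = sess.run(logits, feed_dict={x: images})
for name, pred in zip(file_names, predictions):
    print(name, pred)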
def get_data(data_path,
data_usedfor,
data_lvl,
feature_type="rgb",
preprocess=None,
shuffle=True,
num_epochs=1):
files_pattern = data_usedfor+"*.tfrecord"
data_files = gfile.Glob(data_path + files_pattern)
filename_queue = tf.train.string_input_producer(data_files, num_epochs=num_epochs, shuffle=shuffle)
tfrecord_list = tfrecord_reader(filename_queue, data_lvl)
vids = np.array([tfrecord_list[i][GLOBAL_FEAT_NAMES[0]] for i, _ in enumerate(tfrecord_list)])
labels = np.array([tfrecord_list[i][GLOBAL_FEAT_NAMES[1]] for i, _ in enumerate(tfrecord_list)])
if data_lvl == "video":
if feature_type == "rgb":
X = [tfrecord_list[i][VID_LVL_FEAT_NAMES[0]] for i, _ in enumerate(tfrecord_list)]
elif feature_type == "audio":
X = [tfrecord_list[i][VID_LVL_FEAT_NAMES[1]] for i, _ in enumerate(tfrecord_list)]
elif data_lvl == "frame":
if feature_type == "rgb":
X = [tfrecord_list[i][FRM_LVL_FEAT_NAMES[0]] for i, _ in enumerate(tfrecord_list)]
#X = [np.concatenate((tfrecord_list[i][FRM_LVL_FEAT_NAMES[0]],
# get_framediff(tfrecord_list[i][FRM_LVL_FEAT_NAMES[0]])))
# for i, _ in enumerate(tfrecord_list)]
elif feature_type == "audio":
X = [tfrecord_list[i][FRM_LVL_FEAT_NAMES[1]] for i, _ in enumerate(tfrecord_list)]
Y = to_multi_categorical(labels, NUM_CLASSES)
print "get_data done."
return X, Y
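A hypothetical invocation of the loader above (the constants GLOBAL_FEAT_NAMES, VID_LVL_FEAT_NAMES, FRM_LVL_FEAT_NAMES, NUM_CLASSES and the helpers tfrecord_reader and to_multi_categorical are assumed to be defined elsewhere in the project):

# Illustrative only: load video-level RGB features from training tfrecords.
X_train, Y_train = get_data("/data/yt8m/", "train", "video",
                            feature_type="rgb", num_epochs=1)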
def read_batch_record_features(file_pattern, batch_size, features,
randomize_input=True, num_epochs=None,
queue_capacity=10000, reader_num_threads=1,
parser_num_threads=1,
name='dequeue_record_examples'):
"""Reads TFRecord, queues, batches and parses `Example` proto.
See more detailed description in `read_examples`.
Args:
file_pattern: List of files or pattern of file paths containing
`Example` records. See `tf.gfile.Glob` for pattern rules.
batch_size: An int or scalar `Tensor` specifying the batch size to use.
features: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values.
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.initialize_local_variables() as shown in the tests.
queue_capacity: Capacity for input queue.
reader_num_threads: The number of threads to read examples.
parser_num_threads: The number of threads to parse examples.
name: Name of resulting op.
Returns:
A dict of `Tensor` or `SparseTensor` objects for each in `features`.
Raises:
ValueError: for invalid inputs.
"""
return read_batch_features(
file_pattern=file_pattern, batch_size=batch_size, features=features,
reader=io_ops.TFRecordReader,
randomize_input=randomize_input, num_epochs=num_epochs,
queue_capacity=queue_capacity, reader_num_threads=reader_num_threads,
parser_num_threads=parser_num_threads, name=name)
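A sketch of how such a wrapper is typically driven, in the same queue-runner style as the tests elsewhere on this page (the file pattern and feature spec are illustrative; TensorFlow 1.x assumed):

# Illustrative only: batch-parse Example protos from TFRecord files.
import tensorflow as tf

features = read_batch_record_features(
    file_pattern="/tmp/data/train-*.tfrecord",  # hypothetical pattern
    batch_size=32,
    features={"age": tf.FixedLenFeature([1], tf.int64)},
    num_epochs=1)
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # needed because num_epochs is set
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord=coord)
    try:
        while not coord.should_stop():
            ages = sess.run(features["age"])
    except tf.errors.OutOfRangeError:
        pass
    finally:
        coord.request_stop()
        coord.join(threads)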
def setUp(self):
super(GraphIOTest, self).setUp()
random.seed(42)
self._orig_glob = gfile.Glob
gfile.Glob = self._mock_glob
def tearDown(self):
gfile.Glob = self._orig_glob
super(GraphIOTest, self).tearDown()
def test_keyed_read_text_lines(self):
gfile.Glob = self._orig_glob
filename = self._create_temp_file("ABC\nDEF\nGHK\n")
batch_size = 1
queue_capacity = 5
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
keys, inputs = tf.contrib.learn.io.read_keyed_batch_examples(
filename, batch_size,
reader=tf.TextLineReader, randomize_input=False,
num_epochs=1, queue_capacity=queue_capacity, name=name)
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run([keys, inputs]),
[[filename.encode("utf-8") + b":1"], [b"ABC"]])
self.assertAllEqual(session.run([keys, inputs]),
[[filename.encode("utf-8") + b":2"], [b"DEF"]])
self.assertAllEqual(session.run([keys, inputs]),
[[filename.encode("utf-8") + b":3"], [b"GHK"]])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
def test_keyed_parse_json(self):
gfile.Glob = self._orig_glob
filename = self._create_temp_file(
'{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
'{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
'{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n'
)
batch_size = 1
queue_capacity = 5
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
dtypes = {"age": tf.FixedLenFeature([1], tf.int64)}
parse_fn = lambda example: tf.parse_single_example( # pylint: disable=g-long-lambda
tf.decode_json_example(example), dtypes)
keys, inputs = tf.contrib.learn.io.read_keyed_batch_examples(
filename, batch_size,
reader=tf.TextLineReader, randomize_input=False,
num_epochs=1, queue_capacity=queue_capacity,
parse_fn=parse_fn, name=name)
session.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
key, age = session.run([keys, inputs["age"]])
self.assertAllEqual(age, [[0]])
self.assertAllEqual(key, [filename.encode("utf-8") + b":1"])
key, age = session.run([keys, inputs["age"]])
self.assertAllEqual(age, [[1]])
self.assertAllEqual(key, [filename.encode("utf-8") + b":2"])
key, age = session.run([keys, inputs["age"]])
self.assertAllEqual(age, [[2]])
self.assertAllEqual(key, [filename.encode("utf-8") + b":3"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
def _expand_file_names(filepatterns):
"""Takes a list of file patterns and returns a list of resolved file names."""
if not isinstance(filepatterns, (list, tuple, set)):
filepatterns = [filepatterns]
filenames = set()
for filepattern in filepatterns:
names = set(gfile.Glob(filepattern))
filenames |= names
return list(filenames)
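For illustration, the helper above accepts either a single pattern or a collection of patterns and de-duplicates the matches (paths hypothetical; result order is unspecified because a set is used):

# Illustrative only.
names = _expand_file_names(["/tmp/logs/*.txt", "/tmp/logs/2017-*.txt"])
print(sorted(names))  # overlapping matches appear only once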
def _get_file_names(file_pattern, randomize_input):
"""Parse list of file names from pattern, optionally shuffled.
Args:
file_pattern: File glob pattern, or list of strings.
randomize_input: Whether to shuffle the order of file names.
Returns:
List of file names matching `file_pattern`.
Raises:
ValueError: If `file_pattern` is empty, or pattern matches no files.
"""
if isinstance(file_pattern, list):
file_names = file_pattern
if not file_names:
raise ValueError('No files given to dequeue_examples.')
else:
file_names = list(gfile.Glob(file_pattern))
if not file_names:
raise ValueError('No files match %s.' % file_pattern)
# Sort files so it will be deterministic for unit tests. They'll be shuffled
# in `string_input_producer` if `randomize_input` is enabled.
if not randomize_input:
file_names = sorted(file_names)
return file_names
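Two illustrative calls to the helper above (paths hypothetical): a glob pattern is expanded and sorted when randomize_input is False, while an explicit list is passed through as given.

# Illustrative only.
from_pattern = _get_file_names("/tmp/data/part-*", randomize_input=False)  # sorted matches
explicit = _get_file_names(["b.csv", "a.csv"], randomize_input=True)       # kept as given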
def read_batch_record_features(file_pattern, batch_size, features,
randomize_input=True, num_epochs=None,
queue_capacity=10000, reader_num_threads=1,
name='dequeue_record_examples'):
"""Reads TFRecord, queues, batches and parses `Example` proto.
See more detailed description in `read_examples`.
Args:
file_pattern: List of files or pattern of file paths containing
`Example` records. See `tf.gfile.Glob` for pattern rules.
batch_size: An int or scalar `Tensor` specifying the batch size to use.
features: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values.
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.local_variables_initializer() as shown in the tests.
queue_capacity: Capacity for input queue.
reader_num_threads: The number of threads to read examples.
name: Name of resulting op.
Returns:
A dict of `Tensor` or `SparseTensor` objects for each in `features`.
Raises:
ValueError: for invalid inputs.
"""
return read_batch_features(
file_pattern=file_pattern,
batch_size=batch_size,
features=features,
reader=io_ops.TFRecordReader,
randomize_input=randomize_input,
num_epochs=num_epochs,
queue_capacity=queue_capacity,
reader_num_threads=reader_num_threads,
name=name)
def test_read_text_lines(self):
gfile.Glob = self._orig_glob
filename = self._create_temp_file("ABC\nDEF\nGHK\n")
batch_size = 1
queue_capacity = 5
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
inputs = tf.contrib.learn.io.read_batch_examples(
filename, batch_size, reader=tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run(inputs), [b"ABC"])
self.assertAllEqual(session.run(inputs), [b"DEF"])
self.assertAllEqual(session.run(inputs), [b"GHK"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_read_text_lines_multifile(self):
gfile.Glob = self._orig_glob
filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])
batch_size = 1
queue_capacity = 5
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
inputs = tf.contrib.learn.io.read_batch_examples(
filenames, batch_size, reader=tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(session, coord=coord)
self.assertEqual("%s:1" % name, inputs.name)
file_name_queue_name = "%s/file_name_queue" % name
file_names_name = "%s/input" % file_name_queue_name
example_queue_name = "%s/fifo_queue" % name
test_util.assert_ops_in_graph({
file_names_name: "Const",
file_name_queue_name: "FIFOQueue",
"%s/read/TextLineReader" % name: "TextLineReader",
example_queue_name: "FIFOQueue",
name: "QueueDequeueUpTo"
}, g)
self.assertAllEqual(session.run(inputs), [b"ABC"])
self.assertAllEqual(session.run(inputs), [b"DEF"])
self.assertAllEqual(session.run(inputs), [b"GHK"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_keyed_read_text_lines(self):
gfile.Glob = self._orig_glob
filename = self._create_temp_file("ABC\nDEF\nGHK\n")
batch_size = 1
queue_capacity = 5
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
keys, inputs = tf.contrib.learn.io.read_keyed_batch_examples(
filename, batch_size,
reader=tf.TextLineReader, randomize_input=False,
num_epochs=1, queue_capacity=queue_capacity, name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run([keys, inputs]),
[[filename.encode("utf-8") + b":1"], [b"ABC"]])
self.assertAllEqual(session.run([keys, inputs]),
[[filename.encode("utf-8") + b":2"], [b"DEF"]])
self.assertAllEqual(session.run([keys, inputs]),
[[filename.encode("utf-8") + b":3"], [b"GHK"]])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_read_text_lines(self):
gfile.Glob = self._orig_glob
filename = self._create_temp_file("ABC\nDEF\nGHK\n")
batch_size = 1
queue_capacity = 5
name = "my_batch"
with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
inputs = graph_io.read_batch_examples(
filename,
batch_size,
reader=io_ops.TextLineReader,
randomize_input=False,
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(variables.local_variables_initializer())
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run(inputs), [b"ABC"])
self.assertAllEqual(session.run(inputs), [b"DEF"])
self.assertAllEqual(session.run(inputs), [b"GHK"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_batch_text_lines(self):
gfile.Glob = self._orig_glob
filename = self._create_temp_file("A\nB\nC\nD\nE\n")
batch_size = 3
queue_capacity = 10
name = "my_batch"
with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
inputs = graph_io.read_batch_examples(
[filename],
batch_size,
reader=io_ops.TextLineReader,
randomize_input=False,
num_epochs=1,
queue_capacity=queue_capacity,
read_batch_size=10,
name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(variables.local_variables_initializer())
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
self.assertAllEqual(session.run(inputs), [b"D", b"E"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_keyed_read_text_lines(self):
gfile.Glob = self._orig_glob
filename = self._create_temp_file("ABC\nDEF\nGHK\n")
batch_size = 1
queue_capacity = 5
name = "my_batch"
with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
keys, inputs = graph_io.read_keyed_batch_examples(
filename,
batch_size,
reader=io_ops.TextLineReader,
randomize_input=False,
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(variables.local_variables_initializer())
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(session, coord=coord)
self.assertAllEqual(
session.run([keys, inputs]),
[[filename.encode("utf-8") + b":1"], [b"ABC"]])
self.assertAllEqual(
session.run([keys, inputs]),
[[filename.encode("utf-8") + b":2"], [b"DEF"]])
self.assertAllEqual(
session.run([keys, inputs]),
[[filename.encode("utf-8") + b":3"], [b"GHK"]])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def createImageLists(imageDir, testingPercentage, validationPercentage):
if not gfile.Exists(imageDir):
print("Image dir'" + imageDir +"'not found.'")
return None
result = {}
subDirs = [x[0] for x in gfile.Walk(imageDir)]
isRootDir = True
for subDir in subDirs:
if isRootDir:
isRootDir = False
continue
extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
fileList = []
dirName = os.path.basename(subDir)
if dirName == imageDir:
continue
print("Looking for images in '" + dirName + "'")
for extension in extensions:
fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
fileList.extend(gfile.Glob(fileGlob))
if not fileList:
print('No file found')
continue
labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
trainingImages = []
testingImages = []
validationImages = []
for fileName in fileList:
baseName = os.path.basename(fileName)
hashName = re.sub(r'_nohash_.*$', '', fileName)
hashNameHashed = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
percentHash = ((int(hashNameHashed, 16) %
(MAX_NUM_IMAGES_PER_CLASS + 1)) *
(100.0 / MAX_NUM_IMAGES_PER_CLASS))
if percentHash < validationPercentage:
validationImages.append(baseName)
elif percentHash < (testingPercentage + validationPercentage):
testingImages.append(baseName)
else:
trainingImages.append(baseName)
result[labelName] = {
'dir': dirName,
'training': trainingImages,
'testing': testingImages,
'validation': validationImages,
}
return result
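A hypothetical call and a look at the resulting dictionary (the directory is illustrative; MAX_NUM_IMAGES_PER_CLASS is assumed to be a module-level constant):

# Illustrative only: roughly an 80/10/10 training/testing/validation split per label.
image_lists = createImageLists("/tmp/flower_photos", testingPercentage=10,
                               validationPercentage=10)
if image_lists:
    for label, splits in image_lists.items():
        print(label, len(splits['training']), len(splits['testing']), len(splits['validation']))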
def read_batch_examples(file_pattern, batch_size, reader,
randomize_input=True, num_epochs=None,
queue_capacity=10000, num_threads=1,
read_batch_size=1, parse_fn=None,
name=None):
"""Adds operations to read, queue, batch `Example` protos.
Given file pattern (or list of files), will setup a queue for file names,
read `Example` proto using provided `reader`, use batch queue to create
batches of examples of size `batch_size`.
All queue runners are added to the queue runners collection, and may be
started via `start_queue_runners`.
All ops are added to the default graph.
Use `parse_fn` if you need to do parsing / processing on single examples.
Args:
file_pattern: List of files or pattern of file paths containing
`Example` records. See `tf.gfile.Glob` for pattern rules.
batch_size: An int or scalar `Tensor` specifying the batch size to use.
reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If `None`, cycles through the dataset forever.
NOTE - If specified, creates a variable that must be initialized, so call
`tf.initialize_local_variables()` as shown in the tests.
queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.
Returns:
String `Tensor` of batched `Example` proto.
Raises:
ValueError: for invalid inputs.
"""
_, examples = read_keyed_batch_examples(
file_pattern=file_pattern, batch_size=batch_size, reader=reader,
randomize_input=randomize_input, num_epochs=num_epochs,
queue_capacity=queue_capacity, num_threads=num_threads,
read_batch_size=read_batch_size, parse_fn=parse_fn, name=name)
return examples
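A compact usage sketch for the function above (the file name is hypothetical; TensorFlow 1.x assumed; the queue-runner startup is identical to the tests elsewhere on this page):

# Illustrative only: read lines of a text file in batches of two.
import tensorflow as tf

lines = read_batch_examples(
    "/tmp/input.txt",  # hypothetical file, glob pattern, or list of files
    batch_size=2,
    reader=tf.TextLineReader,
    randomize_input=False,
    num_epochs=1)
# Initialize local variables, start queue runners, then sess.run(lines), as in the tests above.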
def read_batch_features(file_pattern, batch_size, features, reader,
randomize_input=True, num_epochs=None,
queue_capacity=10000, feature_queue_capacity=100,
reader_num_threads=1, parser_num_threads=1,
parse_fn=None, name=None):
"""Adds operations to read, queue, batch and parse `Example` protos.
Given file pattern (or list of files), will setup a queue for file names,
read `Example` proto using provided `reader`, use batch queue to create
batches of examples of size `batch_size` and parse example given `features`
specification.
All queue runners are added to the queue runners collection, and may be
started via `start_queue_runners`.
All ops are added to the default graph.
Args:
file_pattern: List of files or pattern of file paths containing
`Example` records. See `tf.gfile.Glob` for pattern rules.
batch_size: An int or scalar `Tensor` specifying the batch size to use.
features: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values.
reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.initialize_local_variables() as shown in the tests.
queue_capacity: Capacity for input queue.
feature_queue_capacity: Capacity of the parsed features queue. Set this
value to a small number, for example 5 if the parsed features are large.
reader_num_threads: The number of threads to read examples.
parser_num_threads: The number of threads to parse examples.
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.
Returns:
A dict of `Tensor` or `SparseTensor` objects for each in `features`.
Raises:
ValueError: for invalid inputs.
"""
_, features = read_keyed_batch_features(
file_pattern, batch_size, features, reader,
randomize_input=randomize_input, num_epochs=num_epochs,
queue_capacity=queue_capacity,
feature_queue_capacity=feature_queue_capacity,
reader_num_threads=reader_num_threads,
parser_num_threads=parser_num_threads,
parse_fn=parse_fn, name=name)
return features
def read_batch_examples(file_pattern, batch_size, reader,
randomize_input=True, num_epochs=None,
queue_capacity=10000, num_threads=1,
read_batch_size=1, parse_fn=None,
name=None):
"""Adds operations to read, queue, batch `Example` protos.
Given file pattern (or list of files), will setup a queue for file names,
read `Example` proto using provided `reader`, use batch queue to create
batches of examples of size `batch_size`.
All queue runners are added to the queue runners collection, and may be
started via `start_queue_runners`.
All ops are added to the default graph.
Use `parse_fn` if you need to do parsing / processing on single examples.
Args:
file_pattern: List of files or pattern of file paths containing
`Example` records. See `tf.gfile.Glob` for pattern rules.
batch_size: An int or scalar `Tensor` specifying the batch size to use.
reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If `None`, cycles through the dataset forever.
NOTE - If specified, creates a variable that must be initialized, so call
`tf.local_variables_initializer()` as shown in the tests.
queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.
Returns:
String `Tensor` of batched `Example` proto.
Raises:
ValueError: for invalid inputs.
"""
_, examples = read_keyed_batch_examples(
file_pattern=file_pattern, batch_size=batch_size, reader=reader,
randomize_input=randomize_input, num_epochs=num_epochs,
queue_capacity=queue_capacity, num_threads=num_threads,
read_batch_size=read_batch_size, parse_fn=parse_fn, name=name)
return examples
def read_keyed_batch_examples(
file_pattern, batch_size, reader,
randomize_input=True, num_epochs=None,
queue_capacity=10000, num_threads=1,
read_batch_size=1, parse_fn=None,
name=None):
"""Adds operations to read, queue, batch `Example` protos.
Given file pattern (or list of files), will setup a queue for file names,
read `Example` proto using provided `reader`, use batch queue to create
batches of examples of size `batch_size`.
All queue runners are added to the queue runners collection, and may be
started via `start_queue_runners`.
All ops are added to the default graph.
Use `parse_fn` if you need to do parsing / processing on single examples.
Args:
file_pattern: List of files or pattern of file paths containing
`Example` records. See `tf.gfile.Glob` for pattern rules.
batch_size: An int or scalar `Tensor` specifying the batch size to use.
reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If `None`, cycles through the dataset forever.
NOTE - If specified, creates a variable that must be initialized, so call
`tf.local_variables_initializer()` as shown in the tests.
queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.
Returns:
Returns tuple of:
- `Tensor` of string keys.
- String `Tensor` of batched `Example` proto.
Raises:
ValueError: for invalid inputs.
"""
return _read_keyed_batch_examples_helper(
file_pattern,
batch_size,
reader,
randomize_input,
num_epochs,
queue_capacity,
num_threads,
read_batch_size,
parse_fn,
setup_shared_queue=False,
name=name)
def read_batch_features(file_pattern,
batch_size,
features,
reader,
randomize_input=True,
num_epochs=None,
queue_capacity=10000,
feature_queue_capacity=100,
reader_num_threads=1,
parse_fn=None,
name=None):
"""Adds operations to read, queue, batch and parse `Example` protos.
Given file pattern (or list of files), will setup a queue for file names,
read `Example` proto using provided `reader`, use batch queue to create
batches of examples of size `batch_size` and parse example given `features`
specification.
All queue runners are added to the queue runners collection, and may be
started via `start_queue_runners`.
All ops are added to the default graph.
Args:
file_pattern: List of files or pattern of file paths containing
`Example` records. See `tf.gfile.Glob` for pattern rules.
batch_size: An int or scalar `Tensor` specifying the batch size to use.
features: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values.
reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.local_variables_initializer() as shown in the tests.
queue_capacity: Capacity for input queue.
feature_queue_capacity: Capacity of the parsed features queue. Set this
value to a small number, for example 5 if the parsed features are large.
reader_num_threads: The number of threads to read examples.
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.
Returns:
A dict of `Tensor` or `SparseTensor` objects for each in `features`.
Raises:
ValueError: for invalid inputs.
"""
_, features = read_keyed_batch_features(
file_pattern, batch_size, features, reader,
randomize_input=randomize_input, num_epochs=num_epochs,
queue_capacity=queue_capacity,
feature_queue_capacity=feature_queue_capacity,
reader_num_threads=reader_num_threads,
parse_fn=parse_fn, name=name)
return features
def test_read_text_lines_large(self):
gfile.Glob = self._orig_glob
sequence_prefix = "abcdefghijklmnopqrstuvwxyz123456789"
num_records = 49999
lines = ["".join([sequence_prefix, str(l)]).encode("ascii")
for l in xrange(num_records)]
json_lines = ["".join(['{"features": { "feature": { "sequence": {',
'"bytes_list": { "value": ["',
base64.b64encode(l).decode("ascii"),
'"]}}}}}\n']) for l in lines]
filename = self._create_temp_file("".join(json_lines))
batch_size = 10000
queue_capacity = 10000
name = "my_large_batch"
features = {"sequence": tf.FixedLenFeature([], tf.string)}
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
keys, result = tf.contrib.learn.read_keyed_batch_features(
filename, batch_size, features, tf.TextLineReader,
randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
num_enqueue_threads=2, parse_fn=tf.decode_json_example, name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertEqual(1, len(result))
self.assertAllEqual((None,), result["sequence"].get_shape().as_list())
session.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(session, coord=coord)
data = []
try:
while not coord.should_stop():
data.append(session.run(result))
except errors.OutOfRangeError:
pass
finally:
coord.request_stop()
coord.join(threads)
parsed_records = [item for sublist in [d["sequence"] for d in data]
for item in sublist]
# Check that the number of records matches expected and all records
# are present.
self.assertEqual(len(parsed_records), num_records)
self.assertEqual(set(parsed_records), set(lines))
def test_read_text_lines_multifile_with_shared_queue(self):
gfile.Glob = self._orig_glob
filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])
batch_size = 1
queue_capacity = 5
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
keys, inputs = _read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
reader=tf.TextLineReader,
randomize_input=False,
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(session, coord=coord)
self.assertEqual("%s:1" % name, inputs.name)
shared_file_name_queue_name = "%s/file_name_queue" % name
file_names_name = "%s/input" % shared_file_name_queue_name
example_queue_name = "%s/fifo_queue" % name
worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name
test_util.assert_ops_in_graph({
file_names_name: "Const",
shared_file_name_queue_name: "FIFOQueue",
"%s/read/TextLineReader" % name: "TextLineReader",
example_queue_name: "FIFOQueue",
worker_file_name_queue_name: "FIFOQueue",
name: "QueueDequeueUpTo"
}, g)
self.assertAllEqual(session.run(inputs), [b"ABC"])
self.assertAllEqual(session.run(inputs), [b"DEF"])
self.assertAllEqual(session.run(inputs), [b"GHK"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_keyed_parse_json(self):
gfile.Glob = self._orig_glob
filename = self._create_temp_file(
'{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
'{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
'{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n'
)
batch_size = 1
queue_capacity = 5
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
dtypes = {"age": tf.FixedLenFeature([1], tf.int64)}
parse_fn = lambda example: tf.parse_single_example( # pylint: disable=g-long-lambda
tf.decode_json_example(example), dtypes)
keys, inputs = tf.contrib.learn.io.read_keyed_batch_examples(
filename, batch_size,
reader=tf.TextLineReader, randomize_input=False,
num_epochs=1, queue_capacity=queue_capacity,
parse_fn=parse_fn, name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertEqual(1, len(inputs))
self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
session.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(session, coord=coord)
key, age = session.run([keys, inputs["age"]])
self.assertAllEqual(age, [[0]])
self.assertAllEqual(key, [filename.encode("utf-8") + b":1"])
key, age = session.run([keys, inputs["age"]])
self.assertAllEqual(age, [[1]])
self.assertAllEqual(key, [filename.encode("utf-8") + b":2"])
key, age = session.run([keys, inputs["age"]])
self.assertAllEqual(age, [[2]])
self.assertAllEqual(key, [filename.encode("utf-8") + b":3"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def _MaybeDeleteOldCheckpoints(self, latest_save_path,
meta_graph_suffix="meta"):
"""Deletes old checkpoints if necessary.
Always keep the last `max_to_keep` checkpoints. If
`keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
kept for every 0.5 hours of training; if `N` is 10, an additional
checkpoint is kept for every 10 hours of training.
Args:
latest_save_path: Name including path of checkpoint file to save.
meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
"""
if not self.saver_def.max_to_keep:
return
# Remove first from list if the same name was used before.
for p in self._last_checkpoints:
if latest_save_path == self._CheckpointFilename(p):
self._last_checkpoints.remove(p)
# Append new path to list
self._last_checkpoints.append((latest_save_path, time.time()))
# If more than max_to_keep, remove oldest.
if len(self._last_checkpoints) > self.saver_def.max_to_keep:
p = self._last_checkpoints.pop(0)
# Do not delete the file if keep_checkpoint_every_n_hours is set and we
# have reached N hours of training.
should_keep = p[1] > self._next_checkpoint_time
if should_keep:
self._next_checkpoint_time += (
self.saver_def.keep_checkpoint_every_n_hours * 3600)
return
# Otherwise delete the files.
for f in gfile.Glob(self._CheckpointFilename(p)):
try:
gfile.Remove(f)
meta_graph_filename = self._MetaGraphFilename(
f, meta_graph_suffix=meta_graph_suffix)
if gfile.Exists(meta_graph_filename):
gfile.Remove(meta_graph_filename)
except OSError as e:
logging.warning("Ignoring: %s", str(e))
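The retention policy described in the docstring above is configured when the Saver is constructed; a brief sketch of the two knobs this method reads (the values are arbitrary):

# Illustrative only: keep the 5 newest checkpoints, plus one extra every 2 hours.
import tensorflow as tf

saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2.0)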
def read_batch_record_features(file_pattern,
batch_size,
features,
randomize_input=True,
num_epochs=None,
queue_capacity=10000,
reader_num_threads=1,
name='dequeue_record_examples'):
"""Reads TFRecord, queues, batches and parses `Example` proto.
See more detailed description in `read_examples`.
Args:
file_pattern: List of files or pattern of file paths containing
`Example` records. See `tf.gfile.Glob` for pattern rules.
batch_size: An int or scalar `Tensor` specifying the batch size to use.
features: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values.
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.local_variables_initializer() and run the op in a session.
queue_capacity: Capacity for input queue.
reader_num_threads: The number of threads to read examples.
name: Name of resulting op.
Returns:
A dict of `Tensor` or `SparseTensor` objects for each in `features`.
Raises:
ValueError: for invalid inputs.
"""
return read_batch_features(
file_pattern=file_pattern,
batch_size=batch_size,
features=features,
reader=io_ops.TFRecordReader,
randomize_input=randomize_input,
num_epochs=num_epochs,
queue_capacity=queue_capacity,
reader_num_threads=reader_num_threads,
name=name)
def test_read_text_lines_multifile(self):
gfile.Glob = self._orig_glob
filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])
batch_size = 1
queue_capacity = 5
name = "my_batch"
with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
inputs = graph_io.read_batch_examples(
filenames,
batch_size,
reader=io_ops.TextLineReader,
randomize_input=False,
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run(variables.local_variables_initializer())
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(session, coord=coord)
self.assertEqual("%s:1" % name, inputs.name)
file_name_queue_name = "%s/file_name_queue" % name
file_names_name = "%s/input" % file_name_queue_name
example_queue_name = "%s/fifo_queue" % name
test_util.assert_ops_in_graph({
file_names_name: "Const",
file_name_queue_name: "FIFOQueueV2",
"%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
example_queue_name: "FIFOQueueV2",
name: "QueueDequeueUpToV2"
}, g)
self.assertAllEqual(session.run(inputs), [b"ABC"])
self.assertAllEqual(session.run(inputs), [b"DEF"])
self.assertAllEqual(session.run(inputs), [b"GHK"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_read_text_lines_multifile_with_shared_queue(self):
gfile.Glob = self._orig_glob
filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"])
batch_size = 1
queue_capacity = 5
name = "my_batch"
with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
keys, inputs = _read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
reader=io_ops.TextLineReader,
randomize_input=False,
num_epochs=1,
queue_capacity=queue_capacity,
name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertAllEqual((None,), inputs.get_shape().as_list())
session.run([
variables.local_variables_initializer(),
variables.global_variables_initializer()
])
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(session, coord=coord)
self.assertEqual("%s:1" % name, inputs.name)
example_queue_name = "%s/fifo_queue" % name
worker_file_name_queue_name = "%s/file_name_queue/fifo_queue" % name
test_util.assert_ops_in_graph({
"%s/read/TextLineReaderV2" % name: "TextLineReaderV2",
example_queue_name: "FIFOQueueV2",
worker_file_name_queue_name: "FIFOQueueV2",
name: "QueueDequeueUpToV2"
}, g)
self.assertAllEqual(session.run(inputs), [b"ABC"])
self.assertAllEqual(session.run(inputs), [b"DEF"])
self.assertAllEqual(session.run(inputs), [b"GHK"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)
def test_keyed_parse_json(self):
gfile.Glob = self._orig_glob
filename = self._create_temp_file(
'{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}\n'
'{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}\n'
'{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}\n')
batch_size = 1
queue_capacity = 5
name = "my_batch"
with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
dtypes = {"age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)}
parse_fn = lambda example: parsing_ops.parse_single_example( # pylint: disable=g-long-lambda
parsing_ops.decode_json_example(example), dtypes)
keys, inputs = graph_io.read_keyed_batch_examples(
filename,
batch_size,
reader=io_ops.TextLineReader,
randomize_input=False,
num_epochs=1,
queue_capacity=queue_capacity,
parse_fn=parse_fn,
name=name)
self.assertAllEqual((None,), keys.get_shape().as_list())
self.assertEqual(1, len(inputs))
self.assertAllEqual((None, 1), inputs["age"].get_shape().as_list())
session.run(variables.local_variables_initializer())
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(session, coord=coord)
key, age = session.run([keys, inputs["age"]])
self.assertAllEqual(age, [[0]])
self.assertAllEqual(key, [filename.encode("utf-8") + b":1"])
key, age = session.run([keys, inputs["age"]])
self.assertAllEqual(age, [[1]])
self.assertAllEqual(key, [filename.encode("utf-8") + b":2"])
key, age = session.run([keys, inputs["age"]])
self.assertAllEqual(age, [[2]])
self.assertAllEqual(key, [filename.encode("utf-8") + b":3"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
coord.join(threads)