我正在将字符(x_train)输入到此链接中示例 13 所定义的 RNN 模型中。以下是与模型定义、输入预处理和训练相对应的代码。
def char_rnn_model(features, target):
    """Character level recurrent neural network model to predict classes.

    Model function for a ``tf.contrib.learn`` ``Estimator``. The byte ids in
    ``features`` are one-hot encoded over 256 values and fed one time step at
    a time to a GRU. The final GRU state is then projected to per-class
    logits.

    Args:
        features: Integer tensor of byte ids. It is unstacked along axis 1 into
            time steps, so it is presumably (batch, MAX_DOCUMENT_LENGTH) --
            TODO confirm.
        target: Integer class ids, one-hot encoded over ``n_classes`` values.

    Returns:
        A ``(predictions, loss, train_op)`` tuple. ``predictions`` maps
        ``'class'`` to the arg-max class id and ``'prob'`` to the softmax
        class probabilities.
    """
    # Number of output classes. It is shared by the label encoding and the
    # output layer width so the two cannot drift apart.
    n_classes = 15
    target = tf.one_hot(target, n_classes, 1, 0)
    # One float32 one-hot vector per byte (256 possible byte values).
    byte_list = tf.one_hot(features, 256, dtype=tf.float32)
    # static_rnn consumes a Python list with one tensor per time step.
    byte_list = tf.unstack(byte_list, axis=1)
    cell = tf.contrib.rnn.GRUCell(HIDDEN_SIZE)
    # Only the final state is kept; it serves as the document encoding.
    _, encoding = tf.contrib.rnn.static_rnn(cell, byte_list, dtype=tf.float32)
    logits = tf.contrib.layers.fully_connected(
        encoding, n_classes, activation_fn=None)
    # tf.losses replaces the deprecated tf.contrib.losses version.
    loss = tf.losses.softmax_cross_entropy(
        onehot_labels=target, logits=logits)
    train_op = tf.contrib.layers.optimize_loss(
        loss,
        tf.train.get_global_step(),
        optimizer='Adam',
        learning_rate=0.001)
    return ({
        'class': tf.argmax(logits, 1),
        'prob': tf.nn.softmax(logits)
    }, loss, train_op)
# Pre-process: encode every document as a row of byte values, zero-padded
# to MAX_DOCUMENT_LENGTH.
char_processor = learn.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH)
x_train = np.array(list(char_processor.fit_transform(x_train)))
x_test = np.array(list(char_processor.transform(x_test)))

# Train: each "epoch" is 1000 training steps followed by an evaluation on
# the test set. Checkpoints are written to model_dir.
model_dir = "model"
classifier = learn.Estimator(model_fn=char_rnn_model, model_dir=model_dir)
n_epoch = 20
for count in range(n_epoch):
    print("\nEPOCH " + str(count))
    classifier.fit(x_train, y_train, steps=1000, batch_size=10)
    y_predicted = [
        p['class'] for p in classifier.predict(
            x_test, as_iterable=True, batch_size=10)
    ]
    score = metrics.accuracy_score(y_test, y_predicted)
    print('Accuracy: {0:f}'.format(score))
# Per-class report for the predictions made after the last epoch.
print(metrics.classification_report(y_test, y_predicted))
训练结束后,目录 model_dir 中会生成一系列检查点文件(例如 model.ckpt-??????.meta),
它们保存了模型的权重和计算图。我想用它们进行*推理*。
我设法使用以下代码加载它们:
# Locate the newest checkpoint prefix in model_dir.
checkpoint_path = tf.train.latest_checkpoint(model_dir)
# Rebuild the saved graph structure in the default graph.
new_saver = tf.train.import_meta_graph(meta_file)
sess = tf.Session()
# Load the trained variable values into the new session.
new_saver.restore(sess, checkpoint_path)
其中 meta_file 是某个 model.ckpt-??????.meta 文件的路径。
我想将训练好的模型应用于新的字符序列。因此我输入:
new_input = ["Some Sequence of character"]
# Encode the text with the char_processor (ByteProcessor) used for training.
new_input_processed = np.array(list(char_processor.transform(new_input)))
# NOTE: this is the failing call. sess.run() treats its first argument as the
# fetches, and a NumPy array is not a Tensor or Operation, so it raises
# the TypeError shown in the traceback.
output = sess.run(new_input_processed)
但我收到以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-13-982f2b9b18b3> in <module>()
----> 1 output = sess.run(new_input_processed)
/home/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
898 try:
899 result = self._run(None, fetches, feed_dict, options_ptr,
--> 900 run_metadata_ptr)
901 if run_metadata:
902 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/home/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
1118 # Create a fetch handler to take care of the structure of fetches.
1119 fetch_handler = _FetchHandler(
-> 1120 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1121
1122 # Run request and get response.
/home/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, graph, fetches, feeds, feed_handles)
425 """
426 with graph.as_default():
--> 427 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
428 self._fetches = []
429 self._targets = []
/home/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in for_fetch(fetch)
251 if isinstance(fetch, tensor_type):
252 fetches, contraction_fn = fetch_fn(fetch)
--> 253 return _ElementFetchMapper(fetches, contraction_fn)
254 # Did not find anything.
255 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
/home/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, fetches, contraction_fn)
284 raise TypeError('Fetch argument %r has invalid type %r, '
285 'must be a string or Tensor. (%s)' %
--> 286 (fetch, type(fetch), str(e)))
287 except ValueError as e:
288 raise ValueError('Fetch argument %r cannot be interpreted as a '
TypeError: Fetch argument array([[ 83, 111, 109, 101, 32, 83, 101, 113, 117, 101, 110, 99, 101,
32, 111, 102, 32, 99, 104, 97, 114, 97, 99, 116, 101, 114,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=uint8) has invalid type <type 'numpy.ndarray'>, must be a string or Tensor. (Can not convert a ndarray into a Tensor or Operation.)
我正在使用 Tensorflow 1.8.0 和 python 2.7.14。
您遇到的错误是由于传给 sess.run() 的参数不符合要求。在 TensorFlow 中,sess.run() 的第一个参数 fetches 必须是一个或多个张量(Tensor)、操作(Operation)或它们的名称。而您把 new_input_processed(一个 numpy.ndarray)作为 fetches 传给了 sess.run()。这个参数不是张量,因此发生了错误。
您的代码试图把预处理后的字符序列 new_input_processed 直接作为 sess.run() 要获取(fetch)的对象,而 fetches 应当是图中的张量,而不是 NumPy 数组。
解决方法并不是把 new_input_processed 转换为 TensorFlow 张量:那样 sess.run() 只会原样返回这份数据,而不会经过模型。正确的做法是把模型的输出张量作为 fetches,并通过 feed_dict 把 NumPy 数组传给模型的输入张量。
具体步骤如下:
1. 获取输入张量:假设模型中有一个名为 features 的输入张量,我们通过 graph.get_tensor_by_name() 获取它。
2. 使用 feed_dict:通过 feed_dict 将输入数据传递给模型。
3. 调用 sess.run() 执行推理。
假设您知道输入张量的名称(如 features),可以按如下方式修改代码:
import numpy as np
import tensorflow as tf

# NOTE(review): both tensor names are assumptions. List the real ones with
# [op.name for op in graph.get_operations()].
# A Python parameter name ('features') or a key of the model_fn's prediction
# dict ('class') does not name a tensor. The unnamed tf.argmax(logits, 1)
# in char_rnn_model gets TensorFlow's default op name 'ArgMax'. The input
# placeholder that a tf.contrib.learn Estimator creates when fed NumPy
# arrays is presumably named 'input' -- confirm.
INPUT_TENSOR_NAME = 'input:0'
OUTPUT_TENSOR_NAME = 'ArgMax:0'

# Load the model.
model_dir = "model"
# Derive the .meta path from the newest checkpoint instead of hard-coding
# "model/model.ckpt-??????.meta". The imported graph then always matches
# the weights restored below.
checkpoint_path = tf.train.latest_checkpoint(model_dir)
if checkpoint_path is None:
    raise IOError('No checkpoint found in %r' % model_dir)
meta_file = checkpoint_path + '.meta'
new_saver = tf.train.import_meta_graph(meta_file)

# Look up the input and output tensors in the imported graph.
graph = tf.get_default_graph()
input_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME)
output_tensor = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)

# Pre-process the input. char_processor must be the same
# ByteProcessor(MAX_DOCUMENT_LENGTH) used for training; it is not defined
# in this snippet.
new_input = ["Some Sequence of character"]
new_input_processed = np.array(list(char_processor.transform(new_input)))

# Feed the NumPy array to the input tensor and fetch the output tensor.
feed_dict = {input_tensor: new_input_processed}
# The with-block closes the session when inference is done.
with tf.Session() as sess:
    new_saver.restore(sess, checkpoint_path)
    output = sess.run(output_tensor, feed_dict=feed_dict)
print(output)
代码说明:
获取输入/输出张量:input_tensor = graph.get_tensor_by_name('features:0') 获取图中名为 features 的输入张量;output_tensor = graph.get_tensor_by_name('class:0') 获取模型的输出张量。注意,您需要根据实际的模型修改这些张量名称。
预处理输入数据:new_input_processed = np.array(list(char_processor.transform(new_input))) 对输入字符序列进行预处理,并将其转换为 NumPy 数组。
使用 feed_dict 传递数据:feed_dict = {input_tensor: new_input_processed} 将预处理后的输入数据传递给图中的输入张量。
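运行模型推理:output = sess.run(output_tensor, feed_dict=feed_dict) 通过 sess.run() 运行推理,得到输出结果。

注意事项:
- features:0 和 class:0 只是假设的张量名称,实际名称取决于模型的构建方式。model_fn 的参数名(features)和它返回的预测字典的键(如 'class')都不会成为张量名称;例如未显式命名的 tf.argmax 默认名为 ArgMax。您可以通过 graph.get_operations() 查看图中的所有操作和张量,确定它们的确切名称。
- 推理时应使用与训练时定义相同的 char_processor(即 ByteProcessor(MAX_DOCUMENT_LENGTH)),以便把输入数据填充或调整为模型期望的长度。
- 还有一种更简单的做法:用相同的 model_fn 和 model_dir 重新创建 learn.Estimator,然后直接调用 classifier.predict(new_input_processed, as_iterable=True),它会从 model_dir 中的最新检查点加载权重。

这应该能解决您遇到的 TypeError,并使您能够在加载的模型上进行推理。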