小能豆

TensorFlow 模型的 tf.data 管道中存在问题

python

我有一个应用程序,需要使用 tf.data 设置管道。我所拥有的数据存储在Matlab中创建的.mat文件中,包含三个变量“s_matrix”,它是一个224x224x3双精度数组,一个“框架”,它是1024x1复数双精度,最后是一个数字标签。管道将这样加载,以便我可以将数据提供给 model.fit 函数。到目前为止,我一直在使用的用于加载和处理数据的代码添加在下面,但我不断收到几个类型错误和意外的字节错误。

更新2:进行一些修改

通过包含 Giorgos 建议的更改以及更改数据集生成器并使用 tf.data.Dataset.from_generator 函数来更新代码。有一些明显的改进,但现在的问题是只有两个输入之一被传递。

# Define the shape of the input image
input_shape = (224, 224, 3)

# Define the shape of the complex vector after conversion
complex_shape = (1024, 2, 1)

# Define a function to load and preprocess each sample
def load_and_preprocess_sample(sample_path):
    # Load the sample from the mat file
    sample = scipy.io.loadmat(sample_path)
    matrix = sample['s_matrix']
    complex_vector = sample['frame']
    label = sample['numeric_label']

    # Preprocess the matrix, complex vector, and label as needed
    real = tf.reshape(tf.math.real(complex_vector), [1024, 1])
    imag = tf.reshape(tf.math.imag(complex_vector), [1024, 1])
    signal_tensor = tf.concat([real, imag], axis=-1)
    signal_tensor = tf.reshape(signal_tensor, [1024, 2, 1])
    signal = signal_tensor
    # Normalize the matrix values between 0 and 1
    matrix = matrix / 255.0

    return matrix, signal, label


# Define a generator function to generate the samples
def sample_generator(file_paths):
    for file_path in file_paths:
        #yield load_and_preprocess_sample(file_path)
        matrix, complex_vector, label = load_and_preprocess_sample(file_path)
        yield (matrix, complex_vector), label

# Modify the create_dataset() function to use from_generator
def create_dataset(file_paths):
    dataset = tf.data.Dataset.from_generator(
        generator=lambda: sample_generator(file_paths),
        output_signature=(
            tf.TensorSpec(shape=input_shape, dtype=tf.float32),
            tf.TensorSpec(shape=complex_shape, dtype=tf.float32),
            tf.TensorSpec(shape=(1,), dtype=tf.float32)
        )
    )

    dataset = dataset.shuffle(buffer_size=len(file_paths))
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset

# Get a list of all file paths in the data folder
file_paths = [os.path.join(data_path, f) for f in os.listdir(data_path) if f.endswith('.mat')]

# Split file paths into training and validation sets
train_file_paths = file_paths[:-num_val_samples]
val_file_paths = file_paths[-num_val_samples:]

生成的数据和模型调用如下:

# Create datasets for training and validation sets
train_dataset = create_dataset(train_file_paths)
val_dataset = create_dataset(val_file_paths)
...
...
...
model = tf.keras.Model(inputs=[input1, input2], outputs=output)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train your model
model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=num_train_samples // batch_size,
          validation_data=val_dataset,
          validation_steps=num_val_samples // batch_size)

当前的错误输出共享如下:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[7], line 90
     87 model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
     89 # Train your model
---> 90 model.fit(train_dataset,
     91           epochs=5,
     92           steps_per_epoch=num_train_samples // batch_size,
     93           validation_data=val_dataset,
     94           validation_steps=num_val_samples // batch_size)

File ~\miniconda3\envs\tf2\lib\site-packages\keras\utils\traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File ~\AppData\Local\Temp\__autograph_generated_filea4_9b7hv.py:15, in outer_factory.<locals>.inner_factory.<locals>.tf__train_function(iterator)
     13 try:
     14     do_return = True
---> 15     retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
     16 except:
     17     do_return = False

ValueError: in user code:

    File "C:\Users\Admin\miniconda3\envs\tf2\lib\site-packages\keras\engine\training.py", line 1160, in train_function  *
        return step_function(self, iterator)
    File "C:\Users\Admin\miniconda3\envs\tf2\lib\site-packages\keras\engine\training.py", line 1146, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\Admin\miniconda3\envs\tf2\lib\site-packages\keras\engine\training.py", line 1135, in run_step  **
        outputs = model.train_step(data)
    File "C:\Users\Admin\miniconda3\envs\tf2\lib\site-packages\keras\engine\training.py", line 993, in train_step
        y_pred = self(x, training=True)
    File "C:\Users\Admin\miniconda3\envs\tf2\lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "C:\Users\Admin\miniconda3\envs\tf2\lib\site-packages\keras\engine\input_spec.py", line 216, in assert_input_compatibility
        raise ValueError(

    ValueError: Layer "model" expects 2 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, 224, 224, 3) dtype=float32>]

阅读 153

收藏
2023-06-28

共1个答案

小能豆

根据您提供的更新后的代码,以下是一些需要修改的地方:

  1. 您定义的输入张量的名称应与模型中的输入名称相匹配。在您的代码中,您将输入张量命名为 input1input2,但是在模型定义中,您没有使用这些名称。请确保输入张量的名称匹配。

input1 = tf.keras.layers.Input(shape=input_shape) input2 = tf.keras.layers.Input(shape=complex_shape)

然后,将这些输入传递给模型的 inputs 参数。

model = tf.keras.Model(inputs=[input1, input2], outputs=output)

  1. 在训练模型之前,确保您的数据集生成器生成正确的输入。您的 sample_generator 函数应该返回一个元组 (input, label),其中 input 是一个包含两个元素的元组 (matrix, complex_vector)

yield (matrix, complex_vector), label

  1. 在创建数据集时,根据模型的输入定义更新 output_signature 参数,确保与输入张量的形状和数据类型匹配。

output_signature=( ( tf.TensorSpec(shape=input_shape, dtype=tf.float32), tf.TensorSpec(shape=complex_shape, dtype=tf.float32) ), tf.TensorSpec(shape=(1,), dtype=tf.float32) )

这样可以确保数据集的元素是正确的形状和数据类型。

请根据上述修改尝试更新您的代码,并确保输入名称匹配,数据生成器返回正确的输入形状,以及数据集的 output_signature 与模型的输入匹配。这样应该可以解决您遇到的问题。

2023-06-28