这是一个非常初级的问题,希望大家不要介意。
我正在尝试使用 MAPS 数据集训练 GitHub 上的这个模型,并用下面的代码为训练集创建了新的 .tfrecords 文件。它基于这里的代码,但我做了一些修改,以便支持一个额外的输入(另一个 MIDI 文件,我把它称为“tempo MIDI”)。
def _last_existing_path(prefix, suffixes=('.mid', '.midi')):
    """Return `prefix + suffix` for the last suffix whose file exists.

    Every candidate is checked in order and a later match overrides an
    earlier one, so with the default suffixes a `.midi` file wins over a
    `.mid` file when both are present.

    Args:
        prefix: Path without extension.
        suffixes: Extensions to try, in increasing order of precedence.

    Returns:
        The matching path, or '' when none of the candidates exists.
    """
    found = ''
    for suffix in suffixes:
        candidate = prefix + suffix
        if os.path.isfile(candidate):
            found = candidate
    return found


def create_train_set(tempopath, train_list, outdir, min_length, max_length):
    """Write `train.tfrecord` in `outdir` from wav / MIDI / tempo-MIDI triples.

    For every entry of `train_list` the matching label MIDI is looked up next
    to the wav (`<wav_path>.mid` / `.midi`) and the tempo MIDI is looked up as
    `<tempopath><basename>_tempo.mid` / `.midi`. Each triple is then handed to
    `aldu.process_record` (aldu = audio_label_data_utils.py) and every example
    it yields is serialized into the TFRecord file.

    Args:
        tempopath: String prefix for the tempo MIDI files. It is concatenated
            directly with the wav basename, so it must already end with a
            path separator if it names a directory.
        train_list: Iterable of wav paths WITHOUT the `.wav` extension.
        outdir: Directory that receives `train.tfrecord`.
        min_length: Forwarded unchanged to `aldu.process_record`.
        max_length: Forwarded unchanged to `aldu.process_record`.

    Note:
        `sample_rate` is not a parameter; it is read from the enclosing
        module scope. When no MIDI / tempo MIDI file is found the path stays
        '' and is passed to `midi_io.midi_file_to_note_sequence` as-is.
    """
    # Collect (wav, midi, tempo_midi) path triples for every training item.
    train_file_pairs = []
    for wav_path in train_list:
        midi_file = _last_existing_path(wav_path)
        tempo_midi_file = _last_existing_path(
            tempopath + os.path.basename(wav_path) + '_tempo')
        train_file_pairs.append(
            (wav_path + '.wav', midi_file, tempo_midi_file))

    train_output_name = os.path.join(outdir, 'train.tfrecord')
    total = len(train_file_pairs)

    with tf.python_io.TFRecordWriter(train_output_name) as writer:
        for idx, (wav_file, midi_file, tempo_midi_file) in enumerate(
                train_file_pairs):
            print('{} of {}: {}'.format(idx, total, wav_file))
            # Load the wav data; the context manager closes the handle
            # instead of leaking one open file per training item.
            with tf.gfile.Open(wav_file, 'rb') as wav_handle:
                wav_data = wav_handle.read()
            # Load both MIDI files and convert them to NoteSequences.
            ns = midi_io.midi_file_to_note_sequence(midi_file)
            tempo = midi_io.midi_file_to_note_sequence(tempo_midi_file)
            for example in aldu.process_record(
                    wav_data, ns, tempo, wav_file, min_length, max_length,
                    sample_rate):
                writer.write(example.SerializeToString())
其中 tf.Example 的构造如下:
# Build one tf.train.Example holding five bytes features:
# 'id', 'sequence', 'audio', 'tempo' and 'velocity_range'.
example = tf.train.Example(
features=tf.train.Features(
feature={
# UTF-8 encoded example identifier.
'id':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[example_id.encode('utf-8')])),
# Serialized form of `ns` (via SerializeToString()).
'sequence':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[ns.SerializeToString()])),
# Raw wav bytes, stored unchanged.
'audio':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[wav_data])),
# NOTE(review): this feature serializes `velocity_range`, i.e. exactly the
# same payload as the 'velocity_range' feature below. This looks like a
# copy-paste slip; presumably the tempo NoteSequence should be serialized
# here instead -- confirm against the enclosing function's arguments.
'tempo':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[velocity_range.SerializeToString()])),
# Serialized form of `velocity_range` (via SerializeToString()).
'velocity_range':
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[velocity_range.SerializeToString()])),
}))
但是,当我尝试训练模型时,收到了下面的错误消息(我在 py 脚本中加入了 print 语句,这样就能知道程序执行到了哪里):
Running wav_to_spec from data.py
Running _wav_to_mel in data.py
Running wav_to_num_frames from data.py
Running wav_to_spec from data.py
Running _wav_to_mel in data.py
Running wav_to_num_frames from data.py
E0611 07:56:55.419340 8436 error_handling.py:70] Error recorded from training_loop: Input to reshape is a tensor with 0 values, but the requested shape has 54912
[[{{node Reshape_8}}]]
[[IteratorGetNext]]
I0611 07:56:55.420338 8436 error_handling.py:96] training_loop marked as finished
W0611 07:56:55.421335 8436 error_handling.py:130] Reraising captured error
Traceback (most recent call last):
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
return fn(*args)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1341, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1429, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 0 values, but the requested shape has 54912
[[{{node Reshape_8}}]]
[[IteratorGetNext]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "onsets_frames_transcription_train.py", line 128, in <module>
console_entry_point()
File "onsets_frames_transcription_train.py", line 124, in console_entry_point
tf.app.run(main)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\platform\app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\absl\app.py", line 300, in run
_run_main(main, args)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\absl\app.py", line 251, in _run_main
sys.exit(main(argv))
File "onsets_frames_transcription_train.py", line 120, in main
additional_trial_info=additional_trial_info)
File "onsets_frames_transcription_train.py", line 95, in run
num_steps=FLAGS.num_steps)
File "C:\Users\User\magenta\magenta\models\onsets_frames_transcription\train_util.py", line 134, in train
estimator.train(input_fn=transcription_data, max_steps=num_steps)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\tpu_estimator.py", line 2876, in train
rendezvous.raise_errors()
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\error_handling.py", line 131, in raise_errors
six.reraise(typ, value, traceback)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\six.py", line 693, in reraise
raise value
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\tpu_estimator.py", line 2871, in train
saving_listeners=saving_listeners)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 367, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1158, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1192, in _train_model_default
saving_listeners)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1484, in _train_with_estimator_spec
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 754, in run
run_metadata=run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1252, in run
run_metadata=run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1353, in run
raise six.reraise(*original_exc_info)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\six.py", line 693, in reraise
raise value
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1338, in run
return self._sess.run(*args, **kwargs)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1411, in run
run_metadata=run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1169, in run
return self._sess.run(*args, **kwargs)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 950, in run
run_metadata_ptr)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1173, in _run
feed_dict_tensor, options, run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_run
run_metadata)
File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 0 values, but the requested shape has 54912
[[{{node Reshape_8}}]]
[[IteratorGetNext]]
由此我推断问题出在 wav_to_num_frames 中,但这个函数只有下面这几行代码:
def wav_to_num_frames(wav_audio, frames_per_second):
    """Transforms a wav-encoded audio string into number of frames.

    The duration is computed from the wav header as
    `nframes / framerate` seconds, multiplied by `frames_per_second` and
    truncated to an integer.

    Args:
        wav_audio: Bytes of a complete wav file (header included).
        frames_per_second: Number of output frames per second of audio.

    Returns:
        The frame count as `np.int32`.

    Raises:
        wave.Error: If `wav_audio` is not a valid wav file.
    """
    # Local import keeps this block self-contained; io.BytesIO is the
    # standard-library in-memory bytes buffer.
    import io

    print("Running wav_to_num_frames from data")
    w = wave.open(io.BytesIO(wav_audio))
    try:
        # float() forces true division, so the duration is not truncated to
        # whole seconds when the interpreter uses integer division.
        duration_seconds = float(w.getnframes()) / w.getframerate()
    finally:
        # Release the reader deterministically instead of relying on GC.
        w.close()
    return np.int32(duration_seconds * frames_per_second)
当我尝试使用由原始代码创建的tfrecords训练模型时,我没有遇到这个问题,所以我不知道出了什么问题。
事实证明,问题并不在于创建的 .tfrecords 本身,而在于我为新增数据分配的张量大小。这个问题没有通用的答案,因为它只针对这一特定情况。