model.fit not working... I need help
I'm using a MobileNet model. How can I fix this problem?
history = model.fit(train_flow,steps_per_epoch=32,epochs=15,verbose=1,validation_data = val_flow,validation_steps=32,callbacks = [checkpoint, lr_decay, early_stopping])
Training the model...
Epoch 1/15
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-54-6b9880db1065> in <cell line: 13>()
11 print("Training the model...")
12
---> 13 history = model.fit(train_flow,steps_per_epoch=32,epochs=15,verbose=1,validation_data = val_flow,validation_steps=32,callbacks = callbacksX)
14 print("Done!")
/usr/local/lib/python3.9/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
50
51 Raises:
---> 52 An exception on error.
53 """
54 device_name = ctx.device_name
InvalidArgumentError: Graph execution error:
Detected at node 'categorical_crossentropy/softmax_cross_entropy_with_logits' defined at (most recent call last):
File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.9/dist-packages/ipykernel_launcher.py", line 16, in <module>
app.launch_new_instance()
File "/usr/local/lib/python3.9/dist-packages/traitlets/config/application.py", line 992, in launch_instance
app.start()
File "/usr/local/lib/python3.9/dist-packages/ipykernel/kernelapp.py", line 619, in start
self.io_loop.start()
File "/usr/local/lib/python3.9/dist-packages/tornado/platform/asyncio.py", line 215, in start
self.asyncio_loop.run_forever()
File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
self._run_once()
File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
handle._run()
File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/usr/local/lib/python3.9/dist-packages/tornado/ioloop.py", line 687, in <lambda>
lambda f: self._run_callback(functools.partial(callback, future))
File "/usr/local/lib/python3.9/dist-packages/tornado/ioloop.py", line 740, in _run_callback
ret = callback()
File "/usr/local/lib/python3.9/dist-packages/tornado/gen.py", line 821, in inner
self.ctx_run(self.run)
File "/usr/local/lib/python3.9/dist-packages/tornado/gen.py", line 782, in run
yielded = self.gen.send(value)
File "/usr/local/lib/python3.9/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
yield gen.maybe_future(dispatch(*args))
File "/usr/local/lib/python3.9/dist-packages/tornado/gen.py", line 234, in wrapper
yielded = ctx_run(next, result)
File "/usr/local/lib/python3.9/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
yield gen.maybe_future(handler(stream, idents, msg))
File "/usr/local/lib/python3.9/dist-packages/tornado/gen.py", line 234, in wrapper
yielded = ctx_run(next, result)
File "/usr/local/lib/python3.9/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
self.do_execute(
File "/usr/local/lib/python3.9/dist-packages/tornado/gen.py", line 234, in wrapper
yielded = ctx_run(next, result)
File "/usr/local/lib/python3.9/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/usr/local/lib/python3.9/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
result = self._run_cell(
File "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
return runner(coro)
File "/usr/local/lib/python3.9/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
coro.send(None)
File "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
if (await self.run_code(code, result, async_=asy)):
File "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-52-6b9880db1065>", line 13, in <cell line: 13>
history = model.fit(train_flow,steps_per_epoch=32,epochs=15,verbose=1,validation_data = val_flow,validation_steps=32,callbacks = callbacksX)
File "/usr/local/lib/python3.9/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
except Exception as e: # pylint: disable=broad-except
File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1685, in fit
Ignored with the default value of `None`. If x is a `tf.data`
File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1284, in train_function
This function should contain the mathematical logic for one step of
File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1268, in step_function
training_logs = epoch_logs
File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1249, in run_step
use_multiprocessing=use_multiprocessing,
File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1051, in train_step
When passing an infinitely repeating dataset, you must specify the
File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1109, in compute_loss
`namedtuple("example_tuple", ["y", "x"])`
File "/usr/local/lib/python3.9/dist-packages/keras/engine/compile_utils.py", line 265, in __call__
loss: A string, function, or `Loss` object.
File "/usr/local/lib/python3.9/dist-packages/keras/losses.py", line 142, in __call__
return losses_utils.compute_weighted_loss(
File "/usr/local/lib/python3.9/dist-packages/keras/losses.py", line 268, in call
0.5
File "/usr/local/lib/python3.9/dist-packages/keras/losses.py", line 1984, in categorical_crossentropy
File "/usr/local/lib/python3.9/dist-packages/keras/backend.py", line 5565, in categorical_crossentropy
Output tensor.
Node: 'categorical_crossentropy/softmax_cross_entropy_with_logits'
logits and labels must be broadcastable: logits_size=[32,7] labels_size=[32,9]
[[{{node categorical_crossentropy/softmax_cross_entropy_with_logits}}]] [Op:__inference_train_function_63960]
I tried installing an older version of TensorFlow, but that didn't help either.
Below is the code I use to train the model:
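# Imports this snippet needs (TensorFlow 2.x Keras API).
# IMAGE_SHAPE, train_flow and val_flow are defined elsewhere (not shown).
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.layers import BatchNormalization, Dense, Dropout
from tensorflow.keras.metrics import categorical_accuracy, top_k_categorical_accuracy
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

dropout_dense = 0.1
# MobileNet backbone without its ImageNet classifier head; global max pooling
# reduces the final feature map to one vector per image.
mobilenet_model = MobileNet(input_shape=IMAGE_SHAPE, include_top=False, pooling="max")
model = Sequential()
model.add(mobilenet_model)
model.add(Dropout(dropout_dense))
model.add(BatchNormalization())
model.add(Dense(256, activation="relu"))
model.add(Dropout(dropout_dense))
model.add(BatchNormalization())
# Classification head: 7 output units, one softmax probability per class
model.add(Dense(7, activation="softmax"))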
def top_2_acc(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=2)

def top_3_acc(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=3)

model.compile(Adam(0.01), loss="categorical_crossentropy", metrics=[categorical_accuracy, top_2_acc, top_3_acc])
filepath = "model.h5"
checkpoint_param = {
"filepath": filepath,
"monitor": "val_categorical_accuracy",
"verbose": 1,
"save_best_only": True,
"mode": "max"
}
checkpoint = ModelCheckpoint(**checkpoint_param)
lr_decay_params = {
"monitor": "val_loss",
"factor": 0.5,
"patience": 2,
"min_lr": 1e-5
}
lr_decay = ReduceLROnPlateau(**lr_decay_params)
early_stopping = EarlyStopping(monitor="val_loss", patience=4, verbose=1)
print("Training the model...")
history = model.fit(train_flow,steps_per_epoch=32,epochs=15,verbose=1,validation_data = val_flow,validation_steps=32,callbacks = [checkpoint, lr_decay, early_stopping])
print("Done!")
Node: 'categorical_crossentropy/softmax_cross_entropy_with_logits' logits and labels must be broadcastable: logits_size=[32,7] labels_size=[32,9]
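Categorical cross-entropy multiplies each label entry by the log of the matching predicted probability and sums over classes, so the class dimension of the predictions and the labels must be identical. The following is a simplified NumPy version of that formula (an illustration, not Keras's actual implementation); it fails on the same (32, 7) vs (32, 9) pair and works once both sides have 9 classes. The random data is made up purely for the demonstration.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean categorical cross-entropy over a batch of shape (batch, classes)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    # Element-wise product: both arrays must have the same class dimension,
    # otherwise NumPy raises "operands could not be broadcast together".
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))

batch_size = 32
rng = np.random.default_rng(0)

# Model output: softmax-style probabilities over 7 classes -> shape (32, 7)
scores = rng.random((batch_size, 7))
y_pred = scores / scores.sum(axis=-1, keepdims=True)

# One-hot labels over 9 classes -> shape (32, 9)
y_true = np.eye(9)[rng.integers(0, 9, size=batch_size)]

try:
    categorical_crossentropy(y_true, y_pred)
except ValueError as err:
    print("shape mismatch:", err)

# With a 9-unit output layer the shapes agree and the loss can be computed.
scores9 = rng.random((batch_size, 9))
y_pred9 = scores9 / scores9.sum(axis=-1, keepdims=True)
print(categorical_crossentropy(y_true, y_pred9))
```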
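Your last layer has 7 units:
model.add(Dense(7, activation="softmax"))
while your labels (presumably) have 9 possible classes. In both sizes the first number is the batch size (32) and the second is the number of classes, so the model predicts 7 classes for labels that are one-hot encoded over 9. Change the last layer to:
model.add(Dense(9, activation="softmax"))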
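Rather than hard-coding the number of classes, you can size the output layer from the data and check one batch before calling fit. This sketch assumes train_flow comes from Keras's ImageDataGenerator (for example flow_from_directory with class_mode="categorical", the default), which exposes a class_indices dict and returns (images, labels) batches by index. check_label_width is a hypothetical helper, not part of Keras; the demo at the bottom uses a fake NumPy label batch so it runs without TensorFlow.

```python
import numpy as np

def check_label_width(output_units, y_batch):
    """Raise early if one-hot labels don't match the model's output width.

    Hypothetical helper: output_units is the size of the final Dense layer,
    y_batch is one batch of one-hot labels with shape (batch, num_classes).
    """
    label_width = np.asarray(y_batch).shape[-1]
    if label_width != output_units:
        raise ValueError(
            f"model outputs {output_units} classes but labels have {label_width}; "
            f"use Dense({label_width}, activation='softmax') as the last layer"
        )
    return label_width

# In the notebook (assumes train_flow comes from ImageDataGenerator):
#   num_classes = len(train_flow.class_indices)
#   model.add(Dense(num_classes, activation="softmax"))
#   _, y_batch = train_flow[0]
#   check_label_width(model.output_shape[-1], y_batch)

# Stand-alone demo: a fake batch of 32 one-hot labels over 9 classes.
fake_labels = np.eye(9)[np.arange(32) % 9]
print(check_label_width(9, fake_labels))  # 9
try:
    check_label_width(7, fake_labels)
except ValueError as err:
    print(err)
```

Deriving the unit count from class_indices keeps the model and the generator in sync if you later add or remove class folders.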