我需要帮助我想拟合模型但是它提供了这个错误

问题描述 投票:0回答:1

model.fit 不工作......我需要帮助__我正在使用 mobilenet 模型 那么如何解决这个问题? history = model.fit(train_flow,steps_per_epoch=32,epochs=15,verbose=1,validation_data = val_flow,validation_steps=32,callbacks = [checkpoint, lr_decay, early_stopping])

Training the model...
Epoch 1/15
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-54-6b9880db1065> in <cell line: 13>()
     11 print("Training the model...")
     12 
---> 13 history = model.fit(train_flow,steps_per_epoch=32,epochs=15,verbose=1,validation_data = val_flow,validation_steps=32,callbacks = callbacksX)
     14 print("Done!")

1 frames
/usr/local/lib/python3.9/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     50 
     51   Raises:
---> 52     An exception on error.
     53   """
     54   device_name = ctx.device_name

InvalidArgumentError: Graph execution error:

Detected at node 'categorical_crossentropy/softmax_cross_entropy_with_logits' defined at (most recent call last):
    File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/usr/local/lib/python3.9/dist-packages/ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "/usr/local/lib/python3.9/dist-packages/traitlets/config/application.py", line 992, in launch_instance
      app.start()
    File "/usr/local/lib/python3.9/dist-packages/ipykernel/kernelapp.py", line 619, in start
      self.io_loop.start()
    File "/usr/local/lib/python3.9/dist-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/usr/local/lib/python3.9/dist-packages/tornado/ioloop.py", line 687, in <lambda>
      lambda f: self._run_callback(functools.partial(callback, future))
    File "/usr/local/lib/python3.9/dist-packages/tornado/ioloop.py", line 740, in _run_callback
      ret = callback()
    File "/usr/local/lib/python3.9/dist-packages/tornado/gen.py", line 821, in inner
      self.ctx_run(self.run)
    File "/usr/local/lib/python3.9/dist-packages/tornado/gen.py", line 782, in run
      yielded = self.gen.send(value)
    File "/usr/local/lib/python3.9/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
      yield gen.maybe_future(dispatch(*args))
    File "/usr/local/lib/python3.9/dist-packages/tornado/gen.py", line 234, in wrapper
      yielded = ctx_run(next, result)
    File "/usr/local/lib/python3.9/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
      yield gen.maybe_future(handler(stream, idents, msg))
    File "/usr/local/lib/python3.9/dist-packages/tornado/gen.py", line 234, in wrapper
      yielded = ctx_run(next, result)
    File "/usr/local/lib/python3.9/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
      self.do_execute(
    File "/usr/local/lib/python3.9/dist-packages/tornado/gen.py", line 234, in wrapper
      yielded = ctx_run(next, result)
    File "/usr/local/lib/python3.9/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/usr/local/lib/python3.9/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
      return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
    File "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
      result = self._run_cell(
    File "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
      return runner(coro)
    File "/usr/local/lib/python3.9/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
      coro.send(None)
    File "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
      if (await self.run_code(code, result,  async_=asy)):
    File "/usr/local/lib/python3.9/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "<ipython-input-52-6b9880db1065>", line 13, in <cell line: 13>
      history = model.fit(train_flow,steps_per_epoch=32,epochs=15,verbose=1,validation_data = val_flow,validation_steps=32,callbacks = callbacksX)
    File "/usr/local/lib/python3.9/dist-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      except Exception as e:  # pylint: disable=broad-except
    File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1685, in fit
      Ignored with the default value of `None`. If x is a `tf.data`
    File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1284, in train_function
      This function should contain the mathematical logic for one step of
    File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1268, in step_function
      training_logs = epoch_logs
    File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1249, in run_step
      use_multiprocessing=use_multiprocessing,
    File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1051, in train_step
      When passing an infinitely repeating dataset, you must specify the
    File "/usr/local/lib/python3.9/dist-packages/keras/engine/training.py", line 1109, in compute_loss
      `namedtuple("example_tuple", ["y", "x"])`
    File "/usr/local/lib/python3.9/dist-packages/keras/engine/compile_utils.py", line 265, in __call__
      loss: A string, function, or `Loss` object.
    File "/usr/local/lib/python3.9/dist-packages/keras/losses.py", line 142, in __call__
      return losses_utils.compute_weighted_loss(
    File "/usr/local/lib/python3.9/dist-packages/keras/losses.py", line 268, in call
      0.5
    File "/usr/local/lib/python3.9/dist-packages/keras/losses.py", line 1984, in categorical_crossentropy
    File "/usr/local/lib/python3.9/dist-packages/keras/backend.py", line 5565, in categorical_crossentropy
      Output tensor.
Node: 'categorical_crossentropy/softmax_cross_entropy_with_logits'
logits and labels must be broadcastable: logits_size=[32,7] labels_size=[32,9]
     [[{{node categorical_crossentropy/softmax_cross_entropy_with_logits}}]] [Op:__inference_train_function_63960]

我尝试安装旧版本的张量,但也没有用

下面是我用来训练模型的代码

dropout_dense = 0.1
mobilenet_model = MobileNet(input_shape=IMAGE_SHAPE, include_top=False, pooling="max")

model = Sequential()
model.add(mobilenet_model)
model.add(Dropout(dropout_dense))
model.add(BatchNormalization())
model.add(Dense(256, activation="relu"))
model.add(Dropout(dropout_dense))
model.add(BatchNormalization())
model.add(Dense(7, activation="softmax"))

def top_2_acc(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=2)

def top_3_acc(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=3)

model.compile(Adam(0.01), loss="categorical_crossentropy", metrics=[categorical_accuracy, top_2_acc, top_3_acc])
filepath = "model.h5"

checkpoint_param = {
    "filepath": filepath,
    "monitor": "val_categorical_accuracy",
    "verbose": 1,
    "save_best_only": True,
    "mode": "max"
}
checkpoint = ModelCheckpoint(**checkpoint_param)

lr_decay_params = {
    "monitor": "val_loss",
    "factor": 0.5,
    "patience": 2,
    "min_lr": 1e-5
}
lr_decay = ReduceLROnPlateau(**lr_decay_params)
early_stopping = EarlyStopping(monitor="val_loss", patience=4, verbose=1)
print("Training the model...")

history = model.fit(train_flow,steps_per_epoch=32,epochs=15,verbose=1,validation_data = val_flow,validation_steps=32,callbacks = [checkpoint, lr_decay, early_stopping])
print("Done!")
python deep-learning error-handling conv-neural-network google-colaboratory
1个回答
0
投票

节点:'categorical_crossentropy/softmax_cross_entropy_with_logits' logits 和 labels 必须是可广播的:logits_size=[32,7] labels_size=[32,9]

你的最后一层有 7 个节点:

model.add(Dense(7, activation="softmax"))
,而(大概)你有 9 个可能的标签。您必须将最后一行更改为
model.add(Dense(9, activation="softmax"))
.

© www.soinside.com 2019 - 2024. All rights reserved.