TF 2.3.0: training a Keras model on a tf.data Dataset with sample weights does not apply them to metrics

Problem description:

I am passing sample_weight as the 3rd element of the tuples in my tf.data.Dataset (I use it as a mask, so each sample_weight is either 0 or 1). The problem is that this sample_weight doesn't seem to get applied to the metric calculation. (Ref: https://www.tensorflow.org/guide/keras/train_and_evaluate#sample_weights)

Here is the code snippet:

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy

AUTO = tf.data.experimental.AUTOTUNE

train_ds = tf.data.Dataset.from_tensor_slices((imgs, labels, masks))
train_ds = train_ds.shuffle(1024).repeat().batch(32).prefetch(buffer_size=AUTO)

model.compile(optimizer=Adam(learning_rate=1e-4),
              loss=SparseCategoricalCrossentropy(),
              metrics=['sparse_categorical_accuracy'])

model.fit(train_ds, steps_per_epoch=len(imgs) // 32, epochs=20)

The loss after training is very close to zero, but sparse_categorical_accuracy is not (about 0.89). So I highly suspect that whatever sample_weight (the masks) is passed in when constructing the tf.data.Dataset does NOT get applied when metrics are reported during training, while the loss appears to be correct. I further confirmed this by running predictions separately on the subset that is not masked out, and the accuracy there is 1.0.

Also, according to the docs:

https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy

The metric takes three arguments: y_true, y_pred, and sample_weight.
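For intuition (this is a minimal pure-Python sketch, not TensorFlow's actual implementation), a metric that honors 0/1 sample weights reduces to a weighted mean over the unmasked examples only:

```python
def weighted_accuracy(y_true, y_pred, sample_weight):
    # Weighted mean of per-example matches: sum(w * correct) / sum(w).
    # With 0/1 weights, this is accuracy over the unmasked examples only.
    total = sum(w * (t == p) for t, p, w in zip(y_true, y_pred, sample_weight))
    return total / sum(sample_weight)

y_true = [1, 0, 2, 1]
y_pred = [1, 0, 2, 0]   # last prediction is wrong...
mask   = [1, 1, 1, 0]   # ...but it is masked out

print(weighted_accuracy(y_true, y_pred, mask))          # 1.0
print(weighted_accuracy(y_true, y_pred, [1, 1, 1, 1]))  # 0.75
```

This matches the symptom above: the masked-out mistakes drag the unweighted accuracy below 1.0 even when every unmasked prediction is correct.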

So how does one pass sample_weight during metric computation? Is this the responsibility of model.fit(...) within the Keras framework? I couldn't find any example by googling around so far.

Upon some debugging and doc reading, I found there's a weighted_metrics argument in .compile, which I should use instead of metrics=. I confirmed this fixed my test case in the shared Colab.

from tensorflow.keras.metrics import SparseCategoricalAccuracy

model.compile(optimizer=Adam(learning_rate=1e-4),
              loss=SparseCategoricalCrossentropy(),
              weighted_metrics=[SparseCategoricalAccuracy()])
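A quick standalone check of this behavior (a sketch assuming TensorFlow is installed): the metric object itself does honor sample_weight when passed explicitly; model.fit only forwards the dataset's third tuple element to metrics listed under weighted_metrics=, not to those under metrics=.

```python
import tensorflow as tf

m = tf.keras.metrics.SparseCategoricalAccuracy()
y_true = [1, 2]
y_pred = [[0.1, 0.9, 0.0],   # correct prediction
          [0.8, 0.1, 0.1]]   # wrong prediction, but masked out below
m.update_state(y_true, y_pred, sample_weight=[1.0, 0.0])
print(float(m.result()))  # 1.0 -- the weight-0 example is ignored
```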