Kaydet (Commit) f4017fdf authored tarafından Hans Gaiser's avatar Hans Gaiser

Fix unit test for FilterDetections.

üst 1886f22e
......@@ -17,6 +17,12 @@ limitations under the License.
import tensorflow
def ones(*args, **kwargs):
""" See https://www.tensorflow.org/versions/master/api_docs/python/tf/ones .
"""
return tensorflow.ones(*args, **kwargs)
def transpose(*args, **kwargs):
""" See https://www.tensorflow.org/versions/master/api_docs/python/tf/transpose .
"""
......
......@@ -73,7 +73,7 @@ def filter_detections(
# perform per class filtering
for c in range(int(classification.shape[1])):
scores = classification[:, c]
labels = c * keras.backend.ones((keras.backend.shape(scores)[0],), dtype='int64')
labels = c * backend.ones((keras.backend.shape(scores)[0],), dtype='int64')
all_indices.append(_filter_detections(scores, labels))
# concatenate indices to single tensor
......
......@@ -18,10 +18,8 @@ import keras
import keras_retinanet.layers
import numpy as np
import pytest
@pytest.mark.skip()
class TestFilterDetections(object):
def test_simple(self):
# create simple FilterDetections layer
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment