Skip to content

Commit 6457261

Browse files
committed
Minor fixes to unit tests for orchestration/portable
PiperOrigin-RevId: 327320340
1 parent ca8eddc commit 6457261

12 files changed

+77
-58
lines changed

RELEASE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
python (2.7/3.5/3.6) is not available anymore in `tensorflow/tfx` images
6767
on docker hub. Virtualenv is not used anymore.
6868
* Depends on `pyarrow>=0.17,<0.18`.
69+
* Depends on `attrs>=19.3.0,<20`.
6970

7071
## Breaking changes
7172
* Changed the URIs of the value artifacts to point to files.

tfx/dependencies.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def make_required_install_packages():
2424
# LINT.IfChange
2525
'apache-beam[gcp]>=2.22,<3',
2626
# LINT.ThenChange(examples/chicago_taxi_pipeline/setup/setup_beam.sh)
27+
'attrs>=19.3.0,<20',
2728
'click>=7,<8',
2829
'docker>=4.1,<5',
2930
'google-api-python-client>=1.7.8,<2',

tfx/orchestration/portable/base_executor_operator.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"""Base class to define how to operator an executor."""
1515

1616
import abc
17-
from typing import Any, Dict, List, Optional, Text
17+
from typing import Any, Dict, List, Optional
1818

1919
import attr
2020
import six
@@ -30,30 +30,31 @@
3030
# TODO(b/150979622): We should introduce an id that is not changed across
3131
# retries of the same component run and pass it to executor operators for
3232
# human-readability purposes.
33-
@attr.s(auto_attribs=True)
33+
# TODO(b/165359991): Restore 'auto_attribs=True' once we drop Python3.5 support.
34+
@attr.s
3435
class ExecutionInfo:
3536
"""A struct to store information for an execution."""
3637
# The metadata of this execution that is registered in MLMD.
37-
execution_metadata: metadata_store_pb2.Execution = None
38+
execution_metadata = attr.ib(type=metadata_store_pb2.Execution, default=None)
3839
# The input map to feed to executor
39-
input_dict: Dict[Text, List[types.Artifact]] = None
40+
input_dict = attr.ib(type=Dict[str, List[types.Artifact]], default=None)
4041
# The output map to feed to executor
41-
output_dict: Dict[Text, List[types.Artifact]] = None
42+
output_dict = attr.ib(type=Dict[str, List[types.Artifact]], default=None)
4243
# The exec_properties to feed to executor
43-
exec_properties: Dict[Text, Any] = None
44+
exec_properties = attr.ib(type=Dict[str, Any], default=None)
4445
# The uri to executor result, note that Executors and Launchers may not run
4546
# in the same process, so executors should use this uri to "return"
4647
# ExecutorOutput to the launcher.
47-
executor_output_uri: Text = None
48+
executor_output_uri = attr.ib(type=str, default=None)
4849
# Stateful working dir will be deterministic given pipeline, node and run_id.
4950
# The typical usecase is to restore long running executor's state after
5051
# eviction. For examples, a Trainer can use this directory to store
5152
# checkpoints.
52-
stateful_working_dir: Text = None
53+
stateful_working_dir = attr.ib(type=str, default=None)
5354
# The config of this Node.
54-
pipeline_node: pipeline_pb2.PipelineNode = None
55+
pipeline_node = attr.ib(type=pipeline_pb2.PipelineNode, default=None)
5556
# The config of the pipeline that this node is running in.
56-
pipeline_info: pipeline_pb2.PipelineInfo = None
57+
pipeline_info = attr.ib(type=pipeline_pb2.PipelineInfo, default=None)
5758

5859

5960
class BaseExecutorOperator(six.with_metaclass(abc.ABCMeta, object)):

tfx/orchestration/portable/beam_dag_runner_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Tests for tfx.orchestration.portable.beam_dag_runner."""
15-
15+
import os
1616
import mock
1717
import tensorflow as tf
18+
1819
from tfx.orchestration import metadata
1920
from tfx.orchestration.portable import beam_dag_runner
2021
from tfx.orchestration.portable import test_utils
@@ -46,8 +47,10 @@ def setUp(self):
4647
super(BeamDagRunnerTest, self).setUp()
4748
# Setup pipelines
4849
self._pipeline = pipeline_pb2.Pipeline()
49-
self.load_proto_from_text('pipeline_for_launcher_test.pbtxt',
50-
self._pipeline)
50+
self.load_proto_from_text(
51+
os.path.join(
52+
os.path.dirname(__file__), 'testdata',
53+
'pipeline_for_launcher_test.pbtxt'), self._pipeline)
5154

5255
@mock.patch.multiple(
5356
beam_dag_runner,

tfx/orchestration/portable/inputs_utils_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Tests for tfx.orchestration.portable.inputs_utils."""
15+
import os
1516
import tensorflow as tf
1617

1718
from tfx import types
@@ -28,6 +29,10 @@
2829

2930
class InputsUtilsTest(test_utils.TfxTest):
3031

32+
def setUp(self):
33+
super().setUp()
34+
self._testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
35+
3136
def testResolveParameters(self):
3237
parameters = pipeline_pb2.NodeParameters()
3338
text_format.Parse(
@@ -65,8 +70,9 @@ def testResolveParametersFail(self):
6570

6671
def testResolverInputsArtifacts(self):
6772
pipeline = pipeline_pb2.Pipeline()
68-
self.load_proto_from_text('pipeline_for_input_resolver_test.pbtxt',
69-
pipeline)
73+
self.load_proto_from_text(
74+
os.path.join(self._testdata_dir,
75+
'pipeline_for_input_resolver_test.pbtxt'), pipeline)
7076
my_example_gen = pipeline.nodes[0].pipeline_node
7177
another_example_gen = pipeline.nodes[1].pipeline_node
7278
my_transform = pipeline.nodes[2].pipeline_node

tfx/orchestration/portable/launcher.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,23 @@
4545
}
4646

4747

48-
@attr.s(auto_attribs=True)
48+
# TODO(b/165359991): Restore 'auto_attribs=True' once we drop Python3.5 support.
49+
@attr.s
4950
class _PrepareExecutionResult:
5051
"""A wrapper class using as the return value of _prepare_execution()."""
5152

5253
# The information used by executor operators.
53-
execution_info: Optional[base_executor_operator.ExecutionInfo]
54+
execution_info = attr.ib(
55+
type=base_executor_operator.ExecutionInfo, default=None)
5456
# Contexts of the execution, usually used by Publisher.
55-
contexts: List[metadata_store_pb2.Context]
57+
contexts = attr.ib(type=List[metadata_store_pb2.Context], default=None)
5658
# TODO(b/156126088): Update the following documentation when this bug is
5759
# closed.
5860
# Whether an execution is needed. An execution is not needed when:
5961
# 1) Not all the required input are ready.
6062
# 2) The input value doesn't meet the driver's requirement.
6163
# 3) Cache result is used.
62-
is_execution_needed: bool = False
64+
is_execution_needed = attr.ib(type=bool, default=False)
6365

6466

6567
class Launcher(object):
@@ -241,7 +243,7 @@ def _publish_failed_execution(
241243
metadata_handler=m, execution_id=execution_id, contexts=contexts)
242244

243245
def _clean_up(self, execution_info: base_executor_operator.ExecutionInfo):
244-
tf.io.gfile.remove(execution_info.stateful_working_dir)
246+
tf.io.gfile.rmtree(execution_info.stateful_working_dir)
245247

246248
def launch(self) -> Optional[metadata_store_pb2.Execution]:
247249
"""Executes the component, includes driver, executor and publisher.

tfx/orchestration/portable/launcher_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,14 @@ def setUp(self):
6767
connection_config.sqlite.SetInParent()
6868
self._mlmd_connection = metadata.Metadata(
6969
connection_config=connection_config)
70+
self._testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
7071

7172
# Sets up pipelines
7273
pipeline = pipeline_pb2.Pipeline()
73-
self.load_proto_from_text('pipeline_for_launcher_test.pbtxt', pipeline)
74+
self.load_proto_from_text(
75+
os.path.join(
76+
os.path.dirname(__file__), 'testdata',
77+
'pipeline_for_launcher_test.pbtxt'), pipeline)
7478
self._pipeline_info = pipeline.pipeline_info
7579
self._pipeline_runtime_spec = pipeline.runtime_spec
7680
self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (

tfx/orchestration/portable/mlmd/context_lib_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Tests for tfx.orchestration.portable.mlmd.context_lib."""
15+
import os
1516
import tensorflow as tf
17+
1618
from tfx.orchestration import metadata
1719
from tfx.orchestration.portable import test_utils
1820
from tfx.orchestration.portable.mlmd import context_lib
@@ -26,10 +28,14 @@ def setUp(self):
2628
super().setUp()
2729
self._connection_config = metadata_store_pb2.ConnectionConfig()
2830
self._connection_config.sqlite.SetInParent()
31+
self._testdata_dir = os.path.join(
32+
os.path.dirname(os.path.dirname(__file__)), 'testdata')
2933

3034
def testRegisterContexts(self):
3135
node_contexts = pipeline_pb2.NodeContexts()
32-
self.load_proto_from_text('node_context_spec.pbtxt', node_contexts)
36+
self.load_proto_from_text(
37+
os.path.join(self._testdata_dir, 'node_context_spec.pbtxt'),
38+
node_contexts)
3339
with metadata.Metadata(connection_config=self._connection_config) as m:
3440
context_lib.register_contexts_if_not_exists(
3541
metadata_handler=m, node_contexts=node_contexts)

tfx/orchestration/portable/mlmd/execution_lib_test.py

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -96,35 +96,17 @@ def testPrepareExecution(self):
9696
""", result)
9797

9898
def testArtifactAndEventPairs(self):
99-
model = standard_artifacts.Model()
100-
model.uri = 'model'
10199
example = standard_artifacts.Examples()
102100
example.uri = 'example'
103101
example.id = 1
104102

105-
expected_artifact_one = metadata_store_pb2.Artifact()
106-
expected_artifact_two = metadata_store_pb2.Artifact()
107-
text_format.Parse("""
108-
type_id: 1
109-
uri: 'model'""", expected_artifact_one)
103+
expected_artifact = metadata_store_pb2.Artifact()
110104
text_format.Parse(
111105
"""
112106
id: 1
113-
type_id: 2
114-
uri: 'example'""", expected_artifact_two)
115-
expected_event_one = metadata_store_pb2.Event()
116-
expected_event_two = metadata_store_pb2.Event()
117-
text_format.Parse(
118-
"""
119-
path {
120-
steps {
121-
key: 'model'
122-
}
123-
steps {
124-
index: 0
125-
}
126-
}
127-
type: INPUT""", expected_event_one)
107+
type_id: 1
108+
uri: 'example'""", expected_artifact)
109+
expected_event = metadata_store_pb2.Event()
128110
text_format.Parse(
129111
"""
130112
path {
@@ -135,18 +117,15 @@ def testArtifactAndEventPairs(self):
135117
index: 0
136118
}
137119
}
138-
type: INPUT""", expected_event_two)
120+
type: INPUT""", expected_event)
139121

140122
with metadata.Metadata(connection_config=self._connection_config) as m:
141123
result = execution_lib._create_artifact_and_event_pairs(
142124
m, {
143-
'model': [model],
144125
'example': [example],
145126
}, metadata_store_pb2.Event.INPUT)
146127

147-
self.assertListEqual([(expected_artifact_one, expected_event_one),
148-
(expected_artifact_two, expected_event_two)],
149-
result)
128+
self.assertCountEqual([(expected_artifact, expected_event)], result)
150129

151130
def testPutExecutionGraph(self):
152131
with metadata.Metadata(connection_config=self._connection_config) as m:

tfx/orchestration/portable/outputs_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ def remove_output_dirs(output_dict: Dict[Text, List[types.Artifact]]) -> None:
3838
"""Remove dirs of output artifacts' URI."""
3939
for _, artifact_list in output_dict.items():
4040
for artifact in artifact_list:
41-
tf.io.gfile.remove(artifact.uri)
41+
if tf.io.gfile.isdir(artifact.uri):
42+
tf.io.gfile.rmtree(artifact.uri)
43+
else:
44+
tf.io.gfile.remove(artifact.uri)
4245

4346

4447
class OutputsResolver:

tfx/orchestration/portable/runtime_parameter_utils_test.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Tests for tfx.orchestration.portable.runtime_parameter_utils."""
15+
import os
16+
1517
import tensorflow as tf
1618

17-
from ml_metadata.proto import metadata_store_pb2
1819
from tfx.orchestration.portable import runtime_parameter_utils
1920
from tfx.orchestration.portable import test_utils
2021
from tfx.proto.orchestration import pipeline_pb2
22+
from ml_metadata.proto import metadata_store_pb2
2123

2224

2325
class RuntimeParameterUtilsTest(test_utils.TfxTest):
@@ -26,13 +28,18 @@ def setUp(self):
2628
super().setUp()
2729
self._connection_config = metadata_store_pb2.ConnectionConfig()
2830
self._connection_config.sqlite.SetInParent()
31+
self._testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
2932

3033
def testFullySubstituteRuntimeParameter(self):
3134
pipeline = pipeline_pb2.Pipeline()
3235
expected = pipeline_pb2.Pipeline()
33-
self.load_proto_from_text('pipeline_with_runtime_parameter.pbtxt', pipeline)
3436
self.load_proto_from_text(
35-
'pipeline_with_runtime_parameter_substituted.pbtxt', expected)
37+
os.path.join(self._testdata_dir,
38+
'pipeline_with_runtime_parameter.pbtxt'), pipeline)
39+
self.load_proto_from_text(
40+
os.path.join(self._testdata_dir,
41+
'pipeline_with_runtime_parameter_substituted.pbtxt'),
42+
expected)
3643
runtime_parameter_utils.substitute_runtime_parameter(
3744
pipeline, {
3845
'context_name_rp': 'my_context',
@@ -44,9 +51,14 @@ def testFullySubstituteRuntimeParameter(self):
4451
def testPartiallySubstituteRuntimeParameter(self):
4552
pipeline = pipeline_pb2.Pipeline()
4653
expected = pipeline_pb2.Pipeline()
47-
self.load_proto_from_text('pipeline_with_runtime_parameter.pbtxt', pipeline)
4854
self.load_proto_from_text(
49-
'pipeline_with_runtime_parameter_partially_substituted.pbtxt', expected)
55+
os.path.join(self._testdata_dir,
56+
'pipeline_with_runtime_parameter.pbtxt'), pipeline)
57+
self.load_proto_from_text(
58+
os.path.join(
59+
self._testdata_dir,
60+
'pipeline_with_runtime_parameter_partially_substituted.pbtxt'),
61+
expected)
5062
runtime_parameter_utils.substitute_runtime_parameter(
5163
pipeline, {
5264
'context_name_rp': 'my_context',
@@ -55,7 +67,9 @@ def testPartiallySubstituteRuntimeParameter(self):
5567

5668
def testSubstituteRuntimeParameterFail(self):
5769
pipeline = pipeline_pb2.Pipeline()
58-
self.load_proto_from_text('pipeline_with_runtime_parameter.pbtxt', pipeline)
70+
self.load_proto_from_text(
71+
os.path.join(self._testdata_dir,
72+
'pipeline_with_runtime_parameter.pbtxt'), pipeline)
5973
with self.assertRaisesRegex(RuntimeError, 'Runtime parameter type'):
6074
runtime_parameter_utils.substitute_runtime_parameter(
6175
pipeline,

tfx/orchestration/portable/test_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@ def setUp(self):
3434
self._testMethodName)
3535
tf.io.gfile.makedirs(self.tmp_dir)
3636

37-
def load_proto_from_text(self, file_name: Text,
37+
def load_proto_from_text(self, path: Text,
3838
proto_message: message.Message) -> message.Message:
3939
"""Loads proto message from serialized text."""
40-
path = os.path.join(os.path.dirname(__file__), 'testdata', file_name)
4140
return io_utils.parse_pbtxt_file(path, proto_message)
4241

4342
def assertProtoPartiallyEquals(

0 commit comments

Comments
 (0)