Posted to commits@beam.apache.org by bh...@apache.org on 2022/04/12 20:55:16 UTC

[beam] branch master updated: [BEAM-13982] A base class for run inference (#16970)

This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new dffa7c160fb [BEAM-13982] A base class for run inference (#16970)
dffa7c160fb is described below

commit dffa7c160fb19728b2ed4d9462459915a9bf737a
Author: Ryan Thompson <ry...@gmail.com>
AuthorDate: Tue Apr 12 16:55:07 2022 -0400

    [BEAM-13982] A base class for run inference (#16970)
    
    * added initial commit
    
    * removed modified file
    
    * removed params that don't exist
    
    * added clock, removed generics that were causing pickle error, fixed metrics name
    
    * fixed class names removed class that goes in apis
    
    * added base test file
    
    * Added unit tests
    
    * reordered imports
    
    * replied to comments
    
    * apis to api
    
    * added license
    
    * added mock clock test for metrics, realized our metric wouldn't work right with a generator
    
    * Minor changes from Andy's comments. Push metric namespace decision to ModelLoader class
    
    * Update sdks/python/apache_beam/ml/inference/base.py
    
    Typo fix from Valentyn's suggestion.
    
    Co-authored-by: tvalentyn <tv...@users.noreply.github.com>
    
    * updated with changes from Valentyn's comments
    
    * merged from tfx version
    
    * added comment
    
    * linted
    
    * changed import order for jenkins linter
    
    * added a bug to track a todo
    
    * fixed for Robert's comments
    
    * make clock and metrics collector private
    
    * make shared second parameter
    
    * mark RunInferenceDoFn private
    
    * moved initialization of shared.Shared into constructor
    
    * added todo
    
    * Update sdks/python/apache_beam/ml/inference/base.py
    
    Co-authored-by: Brian Hulette <hu...@gmail.com>
    
    * Update sdks/python/apache_beam/ml/inference/base.py
    
    Co-authored-by: Brian Hulette <hu...@gmail.com>
    
    * Update sdks/python/apache_beam/ml/inference/base.py
    
    Co-authored-by: Brian Hulette <hu...@gmail.com>
    
    * updated to correct variable names
    
    * updated variable names
    
    * added typevar
    
    * remove unbatch
    
    * added note that users should expect changes
    
    * Update python container version
    
    * Add --dataflowServiceOptions=enable_prime to useUnifiedWorker conditions (#17213)
    
    * Add self-descriptive message for expected errors.
    
    Ideally we would not log these in the first place, but this is an easy hack.
    
    * [BEAM-10529] nullable xlang coder (#16923)
    
    * [BEAM-10529] add java and generic components of nullable xlang tests
    
    * [BEAM-10529] fix test case
    
    * [BEAM-10529] add coders and typehints to support nullable xlang coders
    
    * [BEAM-10529] update external builder to support nullable coder
    
    * [BEAM-10529] clean up coders.py
    
    * [BEAM-10529] add coder translation test
    
    * [BEAM-10529] add additional check to typecoder to not accidentally misidentify coders as nullable
    
    * [BEAM-10529] add test to retrieve nullable coder from typehint
    
    * [BEAM-10529] run spotless
    
    * [BEAM-10529] add go nullable coder
    
    * [BEAM-10529] cleanup extra println
    
    * [BEAM-10529] improve comments, clean up python
    
    * [BEAM-10529] remove changes to kafkaIO to simplify pr
    
    * [BEAM-10529] add coders to go exec, add asf license text
    
    * [BEAM-10529] clean up error handling
    
    * [BEAM-10529] update go fromyaml to handle nullable cases
    
    * [BEAM-10529] add unit test, register nullable coder in dataflow.go
    
    * [BEAM-10529] remove mistaken commit
    
    * [BEAM-10529] add argument check to CoderTranslators
    
    * [BEAM-10529] Address python comments & cleanup
    
    * [BEAM-10529] address go comments
    
    * [BEAM-10529] remove extra check that was added in error
    
    * [BEAM-10529] fix typo
    
    * [BEAM-10529] re-order check for nonetype to prevent attribute errors
    
    * [BEAM-10529] change isinstance to ==
    
    * Fix go fmt break in core/typex/special.go (#17266)
    
    * [BEAM-8970] Add docs to run wordcount example on portable Spark Runner
    
    * [BEAM-8970] Add period to end of sentence
    
    * [BEAM-5436] Add doc page on Go cross compilation. (#17256)
    
    * Pr-bot Don't count all reviews as approvals (#17269)
    
    * Fix postcommits (#17263)
    
    * [BEAM-14241] Address staticcheck warnings in boot.go (#17264)
    
    * [BEAM-14157] GrpcWindmillServer: Use stream specific boolean to do client closed check (#17191)
    
    * [BEAM-14157] GrpcWindmillServer: Use stream specific boolean to do client closed check
    
    This is a follow-up to #17162. An AbstractWindmillStream can have more than one grpc stream during its lifetime; new streams can be created after the client is closed, to send pending requests. So it is not correct to check `if(clientClosed)` in `send()`; this PR adds a new grpc-stream-level boolean to do the closed check in `send()`.
    
    * [BEAM-14157] Add unit test testing CommitWorkStream retries around stream closing
    
    * [BEAM-14157] review comments
    
    * [BEAM-14157] review comments
    
    * [BEAM-14157] review comments
    
    * [BEAM-14157] fix test
    
    * [BEAM-14157] fix test
    
    Co-authored-by: Arun Pandian <pa...@google.com>
    
    * [BEAM-10582] Allow (and test) pyarrow 7 (#17229)
    
    * [BEAM-13519] Solve race issues when the server responds with an error before the GrpcStateClient finishes being constructed. (#17240)
    
    * [BEAM-13519] Solve race issues when the server responds with an error before the GrpcStateClient finishes.
    
    The issue was that the InboundObserver can be invoked before outboundObserverFactory#outboundObserverFor returns, meaning that
    the server is waiting for a response for cache.remove but cache.computeIfAbsent is being invoked at the same time.
    
    Another issue was that the outstandingRequests map could be updated with another request within GrpcStateClient during closeAndCleanup, meaning that the CompletableFuture would never be completed exceptionally.
    
    Passes 1000 times locally now without getting stuck or failing.
    
    * [BEAM-14256] update SpEL dependency to 5.3.18.RELEASE
    
    * [BEAM-14256] remove .RELEASE
    
    * [BEAM-13015] Disable retries for fnapi grpc channels which otherwise defaults on. (#17243)
    
    * [BEAM-13015] Disable retries for grpc channels which otherwise default to true.
    
    Since the channel is to the local runner process, retries are not expected to
    help. This simplifies the grpc stream stack to not involve a RetryStream object.
    
    * fixup comment
    
    * Update sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/channel/ManagedChannelFactory.java
    
    * Update sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/channel/ManagedChannelFactory.java
    
    Co-authored-by: Lukasz Cwik <lc...@google.com>
    
    * [BEAM-9649] Add region option to Mongo Dataflow test.
    
    * Fix dependency issue causing failures
    
    * [BEAM-13952] Sickbay testAfterProcessingTimeContinuationTriggerUsingState (#17214)
    
    * BEAM-14235 parquetio module does not parse PEP-440 compliant Pyarrow version (#17275)
    
    * Update parquetio.py
    
    * Update CHANGES.md
    
    * Fix import order
    
    * [BEAM-14250] Fix coder registration for types defined in __main__.
    
    Until all runners are portable and we can get rid of all round trips
    between Pipeline and proto representations, register types in __main__
    according to their string representations as pickling does not
    preserve identity.
    
    * Allow get_coder(None).
    
    Co-authored-by: Andy Ye <an...@gmail.com>
    
    * [Website] Contribution guide page indent bug fix (#17287)
    
    * Fix markdown indent issue in Development Setup section
    
    * update query
    
    * [BEAM-10976] Document go sdk bundle finalization (#17048)
    
    * [BEAM-13829] Expose status API from Go SDK Harness (#16957)
    
    * Avoid pr-bot state desync (#17299)
    
    * [BEAM-14259] Clean up staticcheck warnings in the exec package (#17285)
    
    * Minor: Prefer registered schema in SQL docs (#17298)
    
    * Prefer registered schema in SQL docs
    
    * address review comments
    
    * [Playground] add meta tags (#17207)
    
    * playground add meta tags
    
    * playground fix meta tags
    
    * fixes golint and deprecated issues in recent Go SDK import (#17304)
    
    * [BEAM-14262] Update plugins for Dockerized Jenkins.
    
    I copied the list from the cwiki and removed all of the ones that failed to install. https://cwiki.apache.org/confluence/display/INFRA/ci-beam.apache.org
    
    * Add ansicolor and ws-cleanup plugins.
    
    Without them, the seed job prints warnings:
    
    Warning: (CommonJobProperties.groovy, line 107) plugin 'ansicolor' needs to be installed
    Warning: (CommonJobProperties.groovy, line 113) plugin 'ws-cleanup' needs to be installed
    
    * [BEAM-14266] Replace deprecated ptypes package uses (#17302)
    
    * [BEAM-11936] Fix rawtypes warnings in SnowflakeIO (#17257)
    
    * [BEAM-10556] Fix rawtypes warnings in SnowflakeIO
    
    * fixup! [BEAM-10556] Fix rawtypes warnings in SnowflakeIO
    
    * Merge pull request #17262: [BEAM-14244] Use the supplied output timestamp for processing time timers rather than the input watermark
    
    * removed unused typing
    
    * added list typing
    
    * linted
    
    Co-authored-by: tvalentyn <tv...@users.noreply.github.com>
    Co-authored-by: Brian Hulette <hu...@gmail.com>
    Co-authored-by: kileys <ki...@google.com>
    Co-authored-by: Yichi Zhang <zy...@google.com>
    Co-authored-by: Kyle Weaver <kc...@google.com>
    Co-authored-by: johnjcasey <95...@users.noreply.github.com>
    Co-authored-by: Jack McCluskey <34...@users.noreply.github.com>
    Co-authored-by: Benjamin Gonzalez <be...@wizeline.com>
    Co-authored-by: Robert Burke <lo...@users.noreply.github.com>
    Co-authored-by: Danny McCormick <da...@google.com>
    Co-authored-by: Arun Pandian <ar...@gmail.com>
    Co-authored-by: Arun Pandian <pa...@google.com>
    Co-authored-by: Brian Hulette <bh...@google.com>
    Co-authored-by: Lukasz Cwik <lc...@google.com>
    Co-authored-by: johnjcasey <jo...@google.com>
    Co-authored-by: scwhittle <sc...@users.noreply.github.com>
    Co-authored-by: Arwin Tio <ar...@adroll.com>
    Co-authored-by: Robert Bradshaw <ro...@gmail.com>
    Co-authored-by: Andy Ye <an...@gmail.com>
    Co-authored-by: Yi Hu <ya...@google.com>
    Co-authored-by: Michael Li <bi...@google.com>
    Co-authored-by: Ritesh Ghorse <ri...@gmail.com>
    Co-authored-by: Aydar Farrakhov <st...@gmail.com>
    Co-authored-by: Kamil Breguła <ka...@snowflake.com>
    Co-authored-by: Steven Niemitz <st...@gmail.com>
---
 sdks/python/apache_beam/ml/inference/base.py      | 263 ++++++++++++++++++++++
 sdks/python/apache_beam/ml/inference/base_test.py | 147 ++++++++++++
 2 files changed, 410 insertions(+)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
new file mode 100644
index 00000000000..80490285edd
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -0,0 +1,263 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""An extensible run inference transform.
+
+Users of this module can extend the ModelLoader class for any ML framework,
+then pass their extended ModelLoader object into RunInference to create a
+RunInference Beam transform for that framework.
+
+The transform handles standard inference functionality such as metric
+collection, sharing the model between threads, and batching elements.
+
+Note: This module is still under active development, and users should expect
+these interfaces to change.
+"""
+
+import logging
+import os
+import pickle
+import platform
+import sys
+import time
+from typing import Any
+from typing import Generic
+from typing import Iterable
+from typing import List
+from typing import TypeVar
+
+import apache_beam as beam
+from apache_beam.utils import shared
+
+try:
+  # pylint: disable=g-import-not-at-top
+  import resource
+except ImportError:
+  resource = None
+
+_MICROSECOND_TO_MILLISECOND = 1000
+_NANOSECOND_TO_MICROSECOND = 1000
+_SECOND_TO_MICROSECOND = 1_000_000
+
+T = TypeVar('T')
+
+
+class InferenceRunner():
+  """Implements running inferences for a framework."""
+  def run_inference(self, batch: List[Any], model: Any) -> Iterable[Any]:
+    """Runs inferences on a batch of examples and returns an Iterable of Predictions."""
+    raise NotImplementedError(type(self))
+
+  def get_num_bytes(self, batch: Any) -> int:
+    """Returns the number of bytes of data for a batch."""
+    return len(pickle.dumps(batch))
+
+  def get_metrics_namespace(self) -> str:
+    """Returns a namespace for metrics collected by the RunInference transform."""
+    return 'RunInference'
+
+
+class ModelLoader(Generic[T]):
+  """Loads an ML model and provides the InferenceRunner that runs it."""
+  def load_model(self) -> T:
+    """Loads and initializes a model for processing."""
+    raise NotImplementedError(type(self))
+
+  def get_inference_runner(self) -> InferenceRunner:
+    """Returns an implementation of InferenceRunner for this model."""
+    raise NotImplementedError(type(self))
+
+
+class RunInference(beam.PTransform):
+  """An extensible transform for running inferences."""
+  def __init__(self, model_loader: ModelLoader, clock=None):
+    self._model_loader = model_loader
+    self._clock = clock
+
+  # TODO(BEAM-14208): Add batch_size back off in the case there
+  # are functional reasons large batch sizes cannot be handled.
+  def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
+    return (
+        pcoll
+        # TODO(BEAM-14044): Hook into the batching DoFn APIs.
+        | beam.BatchElements()
+        | beam.ParDo(_RunInferenceDoFn(self._model_loader, self._clock)))
+
+
+class _MetricsCollector:
+  """A metrics collector that tracks ML related performance and memory usage."""
+  def __init__(self, namespace: str):
+    # Metrics
+    self._inference_counter = beam.metrics.Metrics.counter(
+        namespace, 'num_inferences')
+    self._inference_request_batch_size = beam.metrics.Metrics.distribution(
+        namespace, 'inference_request_batch_size')
+    self._inference_request_batch_byte_size = (
+        beam.metrics.Metrics.distribution(
+            namespace, 'inference_request_batch_byte_size'))
+    # Batch inference latency in microseconds.
+    self._inference_batch_latency_micro_secs = (
+        beam.metrics.Metrics.distribution(
+            namespace, 'inference_batch_latency_micro_secs'))
+    self._model_byte_size = beam.metrics.Metrics.distribution(
+        namespace, 'model_byte_size')
+    # Model load latency in milliseconds.
+    self._load_model_latency_milli_secs = beam.metrics.Metrics.distribution(
+        namespace, 'load_model_latency_milli_secs')
+
+    # Metrics cache
+    self._load_model_latency_milli_secs_cache = None
+    self._model_byte_size_cache = None
+
+  def update_metrics_with_cache(self):
+    if self._load_model_latency_milli_secs_cache is not None:
+      self._load_model_latency_milli_secs.update(
+          self._load_model_latency_milli_secs_cache)
+      self._load_model_latency_milli_secs_cache = None
+    if self._model_byte_size_cache is not None:
+      self._model_byte_size.update(self._model_byte_size_cache)
+      self._model_byte_size_cache = None
+
+  def cache_load_model_metrics(self, load_model_latency_ms, model_byte_size):
+    self._load_model_latency_milli_secs_cache = load_model_latency_ms
+    self._model_byte_size_cache = model_byte_size
+
+  def update(
+      self,
+      examples_count: int,
+      examples_byte_size: int,
+      latency_micro_secs: int):
+    self._inference_batch_latency_micro_secs.update(latency_micro_secs)
+    self._inference_counter.inc(examples_count)
+    self._inference_request_batch_size.update(examples_count)
+    self._inference_request_batch_byte_size.update(examples_byte_size)
+
+
+class _RunInferenceDoFn(beam.DoFn):
+  """A framework-agnostic DoFn that runs inference on batched examples."""
+  def __init__(self, model_loader: ModelLoader, clock=None):
+    self._model_loader = model_loader
+    self._inference_runner = model_loader.get_inference_runner()
+    self._shared_model_handle = shared.Shared()
+    self._metrics_collector = _MetricsCollector(
+        self._inference_runner.get_metrics_namespace())
+    self._clock = clock
+    if not clock:
+      self._clock = _ClockFactory.make_clock()
+    self._model = None
+
+  def _load_model(self):
+    def load():
+      """Function for constructing shared LoadedModel."""
+      memory_before = _get_current_process_memory_in_bytes()
+      start_time = self._clock.get_current_time_in_microseconds()
+      model = self._model_loader.load_model()
+      end_time = self._clock.get_current_time_in_microseconds()
+      memory_after = _get_current_process_memory_in_bytes()
+      load_model_latency_ms = ((end_time - start_time) /
+                               _MICROSECOND_TO_MILLISECOND)
+      model_byte_size = memory_after - memory_before
+      self._metrics_collector.cache_load_model_metrics(
+          load_model_latency_ms, model_byte_size)
+      return model
+
+    # TODO(BEAM-14207): Investigate releasing model.
+    return self._shared_model_handle.acquire(load)
+
+  def setup(self):
+    self._model = self._load_model()
+
+  def process(self, batch):
+    # process() supports both keyed (key, example) tuples and bare examples.
+    # If the batch is keyed, separate the keys from the examples first.
+    has_keys = isinstance(batch[0], tuple)
+    if has_keys:
+      examples = [example for _, example in batch]
+      keys = [key for key, _ in batch]
+    else:
+      examples = batch
+      keys = None
+
+    start_time = self._clock.get_current_time_in_microseconds()
+    result_generator = self._inference_runner.run_inference(
+        examples, self._model)
+    predictions = list(result_generator)
+
+    inference_latency = self._clock.get_current_time_in_microseconds(
+    ) - start_time
+    num_bytes = self._inference_runner.get_num_bytes(examples)
+    num_elements = len(batch)
+    self._metrics_collector.update(num_elements, num_bytes, inference_latency)
+
+    # Recombine each key with its prediction; order matches the batch.
+    if has_keys:
+      yield from zip(keys, predictions)
+    else:
+      yield from predictions
+
+  def finish_bundle(self):
+    # TODO(BEAM-13970): Figure out why there is a cache.
+    self._metrics_collector.update_metrics_with_cache()
+
+
+def _is_darwin() -> bool:
+  return sys.platform == 'darwin'
+
+
+def _get_current_process_memory_in_bytes():
+  """Returns the peak resident set size (max RSS) of this process in bytes."""
+
+  if resource is not None:
+    usage = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+    if _is_darwin():
+      return usage
+    return usage * 1024
+  else:
+    logging.warning(
+        'Resource module is not available for current platform, '
+        'memory usage cannot be fetched.')
+  return 0
+
+
+def _is_windows() -> bool:
+  return platform.system() == 'Windows' or os.name == 'nt'
+
+
+def _is_cygwin() -> bool:
+  return platform.system().startswith('CYGWIN_NT')
+
+
+class _Clock(object):
+  def get_current_time_in_microseconds(self) -> int:
+    return int(time.time() * _SECOND_TO_MICROSECOND)
+
+
+class _FineGrainedClock(_Clock):
+  def get_current_time_in_microseconds(self) -> int:
+    return int(
+        time.clock_gettime_ns(time.CLOCK_REALTIME) //  # pytype: disable=module-attr
+        _NANOSECOND_TO_MICROSECOND)
+
+
+# TODO(BEAM-14255): Research simplifying the internal clock and just using time.
+class _ClockFactory(object):
+  @staticmethod
+  def make_clock() -> _Clock:
+    if (hasattr(time, 'clock_gettime_ns') and not _is_windows() and
+        not _is_cygwin()):
+      return _FineGrainedClock()
+    return _Clock()
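_RunInferenceDoFn.process in base.py above accepts both bare examples and (key, example) tuples. It decides which it has by checking whether the first element of the batch is a tuple, strips the keys before calling the InferenceRunner, and zips them back onto the predictions afterwards. The following is a minimal, framework-free sketch of that split-and-rejoin step. The helper names split_keys and run_batch and the add_one stand-in model are hypothetical and are not part of apache_beam.ml.inference.base.

```python
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple


def split_keys(batch: Sequence[Any]) -> Tuple[Optional[List[Any]], List[Any]]:
  """Separates keys from examples, using the same test as process().

  A batch counts as keyed when its first element is a tuple.
  """
  if isinstance(batch[0], tuple):
    keys = [key for key, _ in batch]
    examples = [example for _, example in batch]
    return keys, examples
  return None, list(batch)


def run_batch(batch: Sequence[Any],
              predict: Callable[[List[Any]], Iterable[Any]]) -> List[Any]:
  """Runs `predict` on the examples and re-attaches keys if there were any."""
  keys, examples = split_keys(batch)
  predictions = list(predict(examples))
  if keys is None:
    return predictions
  # zip pairs keys and predictions positionally, so order must be preserved.
  return list(zip(keys, predictions))


def add_one(examples: List[int]) -> Iterable[int]:
  # Hypothetical stand-in model: one prediction per example, in order.
  return (example + 1 for example in examples)


print(run_batch([1, 5, 3, 10], add_one))     # [2, 6, 4, 11]
print(run_batch([(0, 1), (1, 5)], add_one))  # [(0, 2), (1, 6)]
```

Because the check only inspects batch[0], an unkeyed PCollection whose examples are themselves tuples would be treated as keyed: run_batch([(1, 2)], add_one) treats 1 as a key and returns [(1, 3)]. zip also stops at the shorter input, so a runner that yields fewer predictions than examples would silently drop keyed results.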
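_get_current_process_memory_in_bytes in base.py above reads ru_maxrss from resource.getrusage. That value is the peak resident set size rather than the current usage, and its unit depends on the platform: bytes on macOS, kilobytes on Linux. Because the peak never decreases, model_byte_size (memory_after - memory_before) is never negative. It reflects only how much loading the model raised the peak. The following is a small sketch of the conversion. The helper names max_rss_to_bytes and peak_rss_bytes are hypothetical.

```python
import sys

try:
  import resource  # Unix only
except ImportError:  # e.g. Windows
  resource = None


def max_rss_to_bytes(ru_maxrss: int, platform: str = sys.platform) -> int:
  """Converts getrusage()'s ru_maxrss value to bytes."""
  if platform == 'darwin':
    return ru_maxrss  # macOS reports bytes
  return ru_maxrss * 1024  # Linux reports kilobytes


def peak_rss_bytes() -> int:
  """Peak resident set size of this process in bytes, or 0 if unavailable."""
  if resource is None:
    return 0
  return max_rss_to_bytes(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)


before = peak_rss_bytes()
blob = bytearray(16 * 1024 * 1024)  # zero-fills ~16 MiB, so the peak may rise
after = peak_rss_bytes()
print(after - before)  # never negative; 0 if the peak was already higher
```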
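_ClockFactory in base.py above picks _FineGrainedClock when time.clock_gettime_ns exists (it is Unix-only) and otherwise falls back to time.time(). Both read wall-clock time, which can jump if the system clock is adjusted. TODO(BEAM-14255) proposes simplifying this to just use time. The following is a sketch of one possible simplification. It assumes the latency metrics need durations rather than timestamps: time.monotonic_ns() is available on all platforms since Python 3.7 and never goes backwards. The function names are hypothetical.

```python
import time

_NANOSECOND_TO_MICROSECOND = 1000


def wall_clock_microseconds() -> int:
  """Wall-clock time in microseconds, as _FineGrainedClock/_Clock compute it."""
  if hasattr(time, 'clock_gettime_ns'):
    # Integer division avoids rounding the nanosecond count through a float.
    return (time.clock_gettime_ns(time.CLOCK_REALTIME) //
            _NANOSECOND_TO_MICROSECOND)
  return time.time_ns() // _NANOSECOND_TO_MICROSECOND


def monotonic_microseconds() -> int:
  """Monotonic time in microseconds; suitable for measuring durations."""
  return time.monotonic_ns() // _NANOSECOND_TO_MICROSECOND


start = monotonic_microseconds()
sum(range(100_000))  # some work to time
elapsed_us = monotonic_microseconds() - start
print(elapsed_us >= 0)  # True
```

A monotonic clock would still let tests inject a fake clock the way base_test.py does, since only get_current_time_in_microseconds() is called.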
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
new file mode 100644
index 00000000000..ab7bb16383e
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -0,0 +1,147 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for apache_beam.ml.inference.base."""
+
+import pickle
+import unittest
+from typing import Any
+from typing import Iterable
+
+import apache_beam as beam
+import apache_beam.ml.inference.base as base
+from apache_beam.metrics.metric import MetricsFilter
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+
+class FakeModel:
+  def predict(self, example: int) -> int:
+    return example + 1
+
+
+class FakeInferenceRunner(base.InferenceRunner):
+  def __init__(self, clock=None):
+    self._mock_clock = clock
+
+  def run_inference(self, batch: Any, model: Any) -> Iterable[Any]:
+    if self._mock_clock:
+      self._mock_clock.current_time += 3000
+    for example in batch:
+      yield model.predict(example)
+
+
+class FakeModelLoader(base.ModelLoader):
+  def __init__(self, clock=None):
+    self._mock_clock = clock
+
+  def load_model(self):
+    if self._mock_clock:
+      self._mock_clock.current_time += 50000
+    return FakeModel()
+
+  def get_inference_runner(self):
+    return FakeInferenceRunner(self._mock_clock)
+
+
+class MockClock(base._Clock):
+  def __init__(self):
+    self.current_time = 10000
+
+  def get_current_time_in_microseconds(self) -> int:
+    return self.current_time
+
+
+class ExtractInferences(beam.DoFn):
+  def process(self, prediction_result):
+    yield prediction_result.inference
+
+
+class RunInferenceBaseTest(unittest.TestCase):
+  def test_run_inference_impl_simple_examples(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(FakeModelLoader())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_with_keyed_examples(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [(i, example + 1) for i, example in enumerate(examples)]
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      actual = pcoll | base.RunInference(FakeModelLoader())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_counted_metrics(self):
+    pipeline = TestPipeline()
+    examples = [1, 5, 3, 10]
+    pcoll = pipeline | 'start' >> beam.Create(examples)
+    _ = pcoll | base.RunInference(FakeModelLoader())
+    run_result = pipeline.run()
+    run_result.wait_until_finish()
+
+    metric_results = (
+        run_result.metrics().query(MetricsFilter().with_name('num_inferences')))
+    num_inferences_counter = metric_results['counters'][0]
+    self.assertEqual(num_inferences_counter.committed, 4)
+
+    inference_request_batch_size = run_result.metrics().query(
+        MetricsFilter().with_name('inference_request_batch_size'))
+    self.assertTrue(inference_request_batch_size['distributions'])
+    self.assertEqual(
+        inference_request_batch_size['distributions'][0].result.sum, 4)
+    inference_request_batch_byte_size = run_result.metrics().query(
+        MetricsFilter().with_name('inference_request_batch_byte_size'))
+    self.assertTrue(inference_request_batch_byte_size['distributions'])
+    self.assertGreaterEqual(
+        inference_request_batch_byte_size['distributions'][0].result.sum,
+        len(pickle.dumps(examples)))
+    model_byte_size = run_result.metrics().query(
+        MetricsFilter().with_name('model_byte_size'))
+    self.assertTrue(model_byte_size['distributions'])
+
+  def test_timing_metrics(self):
+    pipeline = TestPipeline()
+    examples = [1, 5, 3, 10]
+    pcoll = pipeline | 'start' >> beam.Create(examples)
+    mock_clock = MockClock()
+    _ = pcoll | base.RunInference(
+        FakeModelLoader(clock=mock_clock), clock=mock_clock)
+    res = pipeline.run()
+    res.wait_until_finish()
+
+    metric_results = (
+        res.metrics().query(
+            MetricsFilter().with_name('inference_batch_latency_micro_secs')))
+    batch_latency = metric_results['distributions'][0]
+    self.assertEqual(batch_latency.result.count, 3)
+    self.assertEqual(batch_latency.result.mean, 3000)
+
+    metric_results = (
+        res.metrics().query(
+            MetricsFilter().with_name('load_model_latency_milli_secs')))
+    load_model_latency = metric_results['distributions'][0]
+    self.assertEqual(load_model_latency.result.count, 1)
+    self.assertEqual(load_model_latency.result.mean, 50)
+
+
+if __name__ == '__main__':
+  unittest.main()
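The commit note "realized our metric wouldn't work right with a generator" explains why _RunInferenceDoFn.process calls list(result_generator) before reading the clock a second time. An InferenceRunner.run_inference written as a generator (as FakeInferenceRunner is) does no work until it is iterated, so timing only the call measures nothing. The following is a minimal sketch of both orderings. It uses a hypothetical mock clock modeled on MockClock in base_test.py; fake_run_inference and the two latency helpers are also hypothetical.

```python
from typing import Iterator, List


class MockClock:
  """Hypothetical fake clock in microseconds, modeled on base_test.MockClock."""
  def __init__(self, start: int = 10000):
    self.current_time = start

  def get_current_time_in_microseconds(self) -> int:
    return self.current_time


def fake_run_inference(batch: List[int], clock: MockClock) -> Iterator[int]:
  # Generator: none of this body runs until the result is iterated.
  clock.current_time += 3000  # pretend inference takes 3000 us
  for example in batch:
    yield example + 1


def latency_timing_call_only(batch: List[int], clock: MockClock) -> int:
  start = clock.get_current_time_in_microseconds()
  result = fake_run_inference(batch, clock)  # only creates the generator
  latency = clock.get_current_time_in_microseconds() - start
  list(result)  # the simulated work happens here, after the measurement
  return latency


def latency_timing_consumed(batch: List[int], clock: MockClock) -> int:
  start = clock.get_current_time_in_microseconds()
  predictions = list(fake_run_inference(batch, clock))  # runs the body
  assert predictions == [example + 1 for example in batch]
  return clock.get_current_time_in_microseconds() - start


print(latency_timing_call_only([1, 5, 3, 10], MockClock()))  # 0
print(latency_timing_consumed([1, 5, 3, 10], MockClock()))   # 3000
```

The same mock-clock arithmetic gives the values test_timing_metrics expects. Each batch advances the clock by 3000 us, so inference_batch_latency_micro_secs has mean 3000. FakeModelLoader.load_model advances it by 50000 us, which _load_model divides by _MICROSECOND_TO_MILLISECOND (1000) to report 50 ms.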
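_MetricsCollector in base.py does not report the model-load metrics at the moment the model is loaded in setup(). It caches them, and _RunInferenceDoFn.finish_bundle flushes the cache and then clears it. TODO(BEAM-13970) notes that the reason for the cache is still an open question. The following is a minimal sketch of that cache-then-flush pattern. FakeDistribution is a hypothetical stand-in for a Beam distribution metric, and LoadMetricsCache is a hypothetical simplification of the collector.

```python
from typing import List, Optional


class FakeDistribution:
  """Hypothetical stand-in for a Beam distribution metric; records updates."""
  def __init__(self):
    self.values: List[float] = []

  def update(self, value: float) -> None:
    self.values.append(value)


class LoadMetricsCache:
  """Holds a model-load metric until it is reported, then clears it."""
  def __init__(self, distribution: FakeDistribution):
    self._distribution = distribution
    self._cached_latency_ms: Optional[float] = None

  def cache(self, latency_ms: float) -> None:
    # Corresponds to cache_load_model_metrics(), called when the model loads.
    self._cached_latency_ms = latency_ms

  def flush(self) -> None:
    # Corresponds to update_metrics_with_cache(), called from finish_bundle().
    if self._cached_latency_ms is not None:
      self._distribution.update(self._cached_latency_ms)
      self._cached_latency_ms = None


dist = FakeDistribution()
metrics = LoadMetricsCache(dist)
metrics.cache(50.0)  # e.g. a 50 ms model load
metrics.flush()      # the first finish_bundle reports it
metrics.flush()      # later bundles report nothing new
print(dist.values)   # [50.0]
```

Clearing the cache after the flush keeps each load reported once, even though finish_bundle runs once per bundle. test_timing_metrics checks for exactly one load_model_latency_milli_secs update.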