You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@beam.apache.org by "jaxpr (via GitHub)" <gi...@apache.org> on 2023/04/03 12:21:33 UTC

[GitHub] [beam] jaxpr commented on a diff in pull request #25905: Added windowin example

jaxpr commented on code in PR #25905:
URL: https://github.com/apache/beam/pull/25905#discussion_r1155886350


##########
sdks/python/apache_beam/examples/inference/milk_quality_prediction_windowing.py:
##########
@@ -0,0 +1,150 @@
+import argparse
+import logging
+import time
+from typing import Iterable
+from typing import NamedTuple
+
+import pandas
+import pandas as pd
+import xgboost
+from sklearn.model_selection import train_test_split
+
+import apache_beam as beam
+from apache_beam import window
+from apache_beam.ml.inference import RunInference
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.xgboost_inference import XGBoostModelHandlerPandas
+from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
+from apache_beam.runners.runner import PipelineResult
+from apache_beam.testing.test_stream import TestStream
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--input_data',
+      dest='input_data',
+      required=True,
+      help='Path to the csv containing the dataset.')
+  parser.add_argument(
+      '--model_state',
+      dest='model_state',
+      required=True,
+      help='Path to the state of the XGBoost model loaded for Inference.')
+  return parser.parse_known_args(argv)
+
+
+def train_model(
+    samples: pandas.DataFrame,
+    labels: pandas.DataFrame,
+    model_state_output_path: str):
+  """Function to train the XGBoost model.
+    Args:
+      samples: Dataframe contiaing the training data
+      labels: Dataframe containing the labels for the training data
+      model_state_output_path: Path to store the trained model
+  """
+  xgb = xgboost.XGBClassifier(max_depth=3)
+  xgb.fit(samples, labels)
+  xgb.save_model(model_state_output_path)
+  return xgb
+
+
+class MilkQualityAggregation(NamedTuple):
+  bad_quality_measurements: int
+  medium_quality_measurements: int
+  high_quality_measurements: int
+
+
+class AggregateMilkQualityResults(beam.CombineFn):
+  """Simple aggregation to keep track of the number of samples with good, bad and medium quality milk."""
+  def create_accumulator(self):
+    return MilkQualityAggregation(0, 0, 0)
+
+  def add_input(
+      self, accumulator: MilkQualityAggregation, element: PredictionResult):
+    quality = element.inference[0]
+    if quality == 0:
+      return MilkQualityAggregation(
+          accumulator.bad_quality_measurements + 1,
+          accumulator.medium_quality_measurements,
+          accumulator.high_quality_measurements)
+    elif quality == 1:
+      return MilkQualityAggregation(
+          accumulator.bad_quality_measurements,
+          accumulator.medium_quality_measurements + 1,
+          accumulator.high_quality_measurements)
+    else:
+      return MilkQualityAggregation(
+          accumulator.bad_quality_measurements,
+          accumulator.medium_quality_measurements,
+          accumulator.high_quality_measurements + 1)
+
+  def merge_accumulators(self, accumulators: MilkQualityAggregation):
+    return MilkQualityAggregation(
+        sum(
+            aggregation.bad_quality_measurements
+            for aggregation in accumulators),
+        sum(
+            aggregation.medium_quality_measurements
+            for aggregation in accumulators),
+        sum(
+            aggregation.high_quality_measurements
+            for aggregation in accumulators),
+    )
+
+  def extract_output(self, accumulator: MilkQualityAggregation):
+    return accumulator
+
+
+def run(
+    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
+  """
+    Args:
+      argv: Command line arguments defined for this example.
+      save_main_session: Used for internal testing.
+      test_pipeline: Used for internal testing.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  df = pd.read_csv(known_args.input_data)
+  df['Grade'].replace(['low', 'medium', 'high'], [0, 1, 2], inplace=True)
+  x = df.drop(columns=['Grade'])
+  y = df['Grade']
+  x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.60, random_state=99)

Review Comment:
   I updated the code a bit so it is clearer how the data is used: the training set is used to train the model, and the test set serves as the input data for the streaming pipeline.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org