Skip to content

Commit

Permalink
Fix lint warning and import error in data_types_and_io tf example (#1762
Browse files Browse the repository at this point in the history
)

* Fix lint warning and import error in data_types_and_io tf example

Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>

* Remove use of is_container in tensorflow_type.py example

Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>

* Fix lint warning

Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>

---------

Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
Co-authored-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
  • Loading branch information
eapolinario and eapolinario authored Oct 22, 2024
1 parent 90184f9 commit 9aadec2
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 47 deletions.
98 changes: 52 additions & 46 deletions examples/data_types_and_io/data_types_and_io/tensorflow_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Import necessary libraries and modules

from flytekit import task, workflow
from flytekit import ImageSpec, task, workflow
from flytekit.types.directory import TFRecordsDirectory
from flytekit.types.file import TFRecordFile

Expand All @@ -9,48 +9,54 @@
registry="ghcr.io/flyteorg",
)

if custom_image.is_container():
import tensorflow as tf

# TensorFlow Model
@task
def train_model() -> tf.keras.Model:
model = tf.keras.Sequential(
[tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax")]
)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
return model

@task
def evaluate_model(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor) -> float:
loss, accuracy = model.evaluate(x, y)
return accuracy

@workflow
def training_workflow(x: tf.Tensor, y: tf.Tensor) -> float:
model = train_model()
return evaluate_model(model=model, x=x, y=y)

# TFRecord Files
@task
def process_tfrecord(file: TFRecordFile) -> int:
count = 0
for record in tf.data.TFRecordDataset(file):
count += 1
return count

@workflow
def tfrecord_workflow(file: TFRecordFile) -> int:
return process_tfrecord(file=file)

# TFRecord Directories
@task
def process_tfrecords_dir(dir: TFRecordsDirectory) -> int:
count = 0
for record in tf.data.TFRecordDataset(dir.path):
count += 1
return count

@workflow
def tfrecords_dir_workflow(dir: TFRecordsDirectory) -> int:
return process_tfrecords_dir(dir=dir)
import tensorflow as tf


# TensorFlow Model
@task
def train_model() -> tf.keras.Model:
model = tf.keras.Sequential(
[tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax")]
)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
return model


@task
def evaluate_model(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor) -> float:
loss, accuracy = model.evaluate(x, y)
return accuracy


@workflow
def training_workflow(x: tf.Tensor, y: tf.Tensor) -> float:
model = train_model()
return evaluate_model(model=model, x=x, y=y)


# TFRecord Files
@task
def process_tfrecord(file: TFRecordFile) -> int:
count = 0
for record in tf.data.TFRecordDataset(file):
count += 1
return count


@workflow
def tfrecord_workflow(file: TFRecordFile) -> int:
return process_tfrecord(file=file)


# TFRecord Directories
@task
def process_tfrecords_dir(dir: TFRecordsDirectory) -> int:
count = 0
for record in tf.data.TFRecordDataset(dir.path):
count += 1
return count


@workflow
def tfrecords_dir_workflow(dir: TFRecordsDirectory) -> int:
return process_tfrecords_dir(dir=dir)
1 change: 1 addition & 0 deletions examples/data_types_and_io/requirements.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pandas
torch
tabulate
tensorflow
pyarrow
1 change: 0 additions & 1 deletion examples/kfmpi_plugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,3 @@ If your MPI workflow hangs or times out, it may be caused by an incorrect workfl

1. Verify Registration Method:
When using a custom image, refer to the Flyte documentation on [Registering workflows](https://docs.flyte.org/en/latest/user_guide/flyte_fundamentals/registering_workflows.html#registration-patterns) to ensure you're following the correct registration method.

0 comments on commit 9aadec2

Please sign in to comment.