Skip to content

Commit

Permalink
Added examples for tensorflow types in Datatypes and IO section
Browse files Browse the repository at this point in the history
Signed-off-by: sumana sree <sumanasree2705@gmail.com>
  • Loading branch information
sumana-2705 committed Oct 4, 2024
1 parent a1dde19 commit f6be5ad
Showing 1 changed file with 85 additions and 0 deletions.
85 changes: 85 additions & 0 deletions examples/data_types_and_io/data_types_and_io/tensorflow_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Tensorflow Model
import tensorflow as tf
from flytekit import task, workflow

@task
def train_model() -> tf.keras.Model:
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
return model

@task
def evaluate_model(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor) -> float:
loss, accuracy = model.evaluate(x, y)
return accuracy

@workflow
def training_workflow(x: tf.Tensor, y: tf.Tensor) -> float:
model = train_model()
accuracy = evaluate_model(model=model, x=x, y=y)
return accuracy


# TensorFlow Record File
from flytekit.types.file import TFRecordFile
from flytekit import task, workflow

@task
def process_tfrecord(file: TFRecordFile) -> int:
dataset = tf.data.TFRecordDataset(file)
count = 0
for raw_record in dataset:
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
count += 1
return count

@workflow
def tfrecord_workflow(file: TFRecordFile) -> int:
return process_tfrecord(file=file)


# TensorFlow Records Directory
from flytekit.types.directory import TFRecordsDirectory
from flytekit import task, workflow
import os
import tensorflow as tf

@task
def process_tfrecords_dir(dir: TFRecordsDirectory) -> int:
files = [f.path for f in os.scandir(dir) if f.is_file()]
dataset = tf.data.TFRecordDataset(files)
count = 0
for raw_record in dataset:
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
count += 1
return count

@workflow
def tfrecords_dir_workflow(dir: TFRecordsDirectory) -> int:
return process_tfrecords_dir(dir=dir)

# TFRecordDatasetConfig
from flytekit.types.directory import TFRecordsDirectory
from flytekit import task, workflow
import os
import tensorflow as tf

@task
def process_tfrecords_dir(dir: TFRecordsDirectory) -> int:
files = [f.path for f in os.scandir(dir) if f.is_file()]
dataset = tf.data.TFRecordDataset(files)
count = 0
for raw_record in dataset:
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
count += 1
return count

@workflow
def tfrecords_dir_workflow(dir: TFRecordsDirectory) -> int:
return process_tfrecords_dir(dir=dir)

0 comments on commit f6be5ad

Please sign in to comment.