Make example generation pipeline output Parquet files in addition to TFRecords.

This allows for more efficient searching of specific examples.

Also add a stand-alone tool to convert existing TFRecords to Parquet files.

PiperOrigin-RevId: 688804769
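As a sketch of the lookup this enables: once the Parquet shards exist, a reader can push an example_id filter down to the Parquet row groups instead of scanning whole TFRecord shards. The column names below come from the schema added in this change; the path and ID are placeholders, not values from this commit.

    import pyarrow.parquet as pq

    # Placeholder path and ID; the pipeline writes shards under
    # <output_dir>/examples/unlabeled-parquet/ (or labeled-parquet/).
    table = pq.read_table(
        'gs://my-bucket/examples/unlabeled-parquet',
        columns=['example_id', 'longitude', 'latitude', 'post_image_png_large'],
        filters=[('example_id', '=', 'abc123')],
    )
    matches = table.to_pandas()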
jzxu authored and copybara-github committed Oct 23, 2024
1 parent 6819943 commit 5bfb8c7
Showing 3 changed files with 172 additions and 24 deletions.
150 changes: 129 additions & 21 deletions src/skai/generate_examples.py
@@ -33,6 +33,7 @@
import geopandas as gpd
import numpy as np
from openlocationcode import openlocationcode
import pyarrow
import shapely.geometry
import shapely.wkb
from skai import beam_utils
@@ -119,6 +120,7 @@ class ExamplesGenerationConfig:
cloud_region: If using Dataflow, the Cloud region to run under.
use_dataflow: If true, execute pipeline in Cloud Dataflow.
output_metadata_file: Output a CSV metadata file for all generated examples.
output_parquet: Output a Parquet file for all generated examples.
worker_service_account: If using Dataflow, the service account to run as.
min_dataflow_workers: If using Dataflow, the minimum number of workers to
instantiate.
@@ -173,14 +175,15 @@ class ExamplesGenerationConfig:
cloud_region: Optional[str] = None
use_dataflow: bool = False
output_metadata_file: bool = True
output_parquet: bool = False
worker_service_account: Optional[str] = None
min_dataflow_workers: int = 10
max_dataflow_workers: int = 20
example_patch_size: int = 64
large_patch_size: int = 256
resolution: float = 0.5
output_shards: int = 20
gdal_env: List[str] = dataclasses.field(default_factory=list)
gdal_env: list[str] = dataclasses.field(default_factory=list)
buildings_method: str = 'file' # file, open_street_map, open_buildings, none
buildings_file: Optional[str] = None
overpass_url: Optional[str] = 'https://lz4.overpass-api.de/api/interpreter'
@@ -878,7 +881,9 @@ def _generate_examples_pipeline(
max_workers: int,
wait_for_dataflow_job: bool,
cloud_detector_model_path: Optional[str],
output_metadata_file: bool) -> None:
output_metadata_file: bool,
output_parquet: bool,
) -> None:
"""Runs example generation pipeline.
Args:
@@ -903,8 +908,8 @@
wait_for_dataflow_job: If true, wait for dataflow job to complete before
returning.
cloud_detector_model_path: Path to tflite cloud detector model.
output_metadata_file: Enable true to generate a file of example metadata, or
disable to skip this step.
output_metadata_file: If true, write a CSV file containing example metadata.
output_parquet: If true, write out examples in Parquet format.
"""

temp_dir = os.path.join(output_dir, 'temp')
@@ -925,9 +930,15 @@
if buildings_labeled:
examples_output_prefix = (
os.path.join(output_dir, 'examples', 'labeled-large', 'labeled'))
parquet_prefix = os.path.join(
output_dir, 'examples', 'labeled-parquet', 'examples'
)
else:
examples_output_prefix = (
os.path.join(output_dir, 'examples', 'unlabeled-large', 'unlabeled'))
parquet_prefix = os.path.join(
output_dir, 'examples', 'unlabeled-parquet', 'examples'
)

pipeline = beam.Pipeline(options=pipeline_options)
examples = _generate_examples(
@@ -944,10 +955,13 @@
file_name_suffix='.tfrecord',
num_shards=num_output_shards))

if output_parquet:
_write_examples_to_parquet(examples, parquet_prefix)

if output_metadata_file:
rows = (
examples
| 'extract_metadata_rows' >> beam.Map(_get_example_metadata)
| 'extract_example_metadata' >> beam.Map(_get_example_metadata)
| 'remove_duplicates' >> beam.Distinct()
)
df = apache_beam.dataframe.convert.to_dataframe(rows)
@@ -1086,12 +1100,42 @@ def run_example_generation(
config.max_dataflow_workers,
wait_for_dataflow,
config.cloud_detector_model_path,
config.output_metadata_file
config.output_metadata_file,
config.output_parquet,
)


class ExampleType(typing.NamedTuple):
def _example_to_dict(
e: tf.train.Example, include_images: bool
) -> dict[str, Any]:
"""Extracts features from an Example into a dict."""
Metrics.counter('skai', 'examples_processed').inc()
longitude, latitude = utils.get_float_feature(e, 'coordinates')
features = {
'int64_id': utils.get_int64_feature(e, 'int64_id')[0],
'example_id': utils.get_bytes_feature(e, 'example_id')[0].decode(),
'encoded_coordinates': utils.get_bytes_feature(e, 'encoded_coordinates')[
0
].decode(),
'longitude': longitude,
'latitude': latitude,
'pre_image_id': utils.get_bytes_feature(e, 'pre_image_id')[0].decode(),
'post_image_id': utils.get_bytes_feature(e, 'post_image_id')[0].decode(),
'plus_code': utils.get_bytes_feature(e, 'plus_code')[0].decode(),
}
if include_images:
features['pre_image_png_large'] = utils.get_bytes_feature(
e, 'pre_image_png_large'
)[0]
features['post_image_png_large'] = utils.get_bytes_feature(
e, 'post_image_png_large'
)[0]
return features


class ExampleMetadata(typing.NamedTuple):
example_id: str
int64_id: int
encoded_coordinates: str
longitude: float
latitude: float
@@ -1100,18 +1144,82 @@ class ExampleType(typing.NamedTuple):
plus_code: str


@beam.typehints.with_output_types(ExampleType)
def _get_example_metadata(example: tf.train.Example) -> ExampleType:
return ExampleType(
example_id=utils.get_bytes_feature(example, 'example_id')[0].decode(),
encoded_coordinates=utils.get_bytes_feature(
example, 'encoded_coordinates'
)[0].decode(),
longitude=utils.get_float_feature(example, 'coordinates')[0],
latitude=utils.get_float_feature(example, 'coordinates')[1],
post_image_id=utils.get_bytes_feature(example, 'post_image_id')[
0
].decode(),
pre_image_id=utils.get_bytes_feature(example, 'pre_image_id')[0].decode(),
plus_code=utils.get_bytes_feature(example, 'plus_code')[0].decode(),
def _get_example_metadata(example: tf.train.Example) -> ExampleMetadata:
return ExampleMetadata(**_example_to_dict(example, False))


def _write_examples_to_parquet(
examples: beam.PCollection,
parquet_prefix: str,
) -> None:
"""Writes a PCollection of Examples into parquet files.
Args:
examples: PCollection of examples.
parquet_prefix: Path prefix for output Parquet files.
"""
schema = pyarrow.schema([
('int64_id', pyarrow.int64()),
('example_id', pyarrow.string()),
('encoded_coordinates', pyarrow.string()),
('longitude', pyarrow.float64()),
('latitude', pyarrow.float64()),
('pre_image_id', pyarrow.string()),
('post_image_id', pyarrow.string()),
('plus_code', pyarrow.string()),
('pre_image_png_large', pyarrow.binary()),
('post_image_png_large', pyarrow.binary()),
])
_ = (
examples
| 'extract_features_to_dict' >> beam.Map(_example_to_dict, True)
| 'write_parquet'
>> beam.io.parquetio.WriteToParquet(
parquet_prefix,
schema=schema,
codec='snappy',
file_name_suffix='.parquet',
row_group_buffer_size=1,
record_batch_size=100,
num_shards=100,
)
)


def convert_tfrecords_to_parquet(
tfrecords_pattern: str,
parquet_prefix: str,
project: str,
region: str,
service_account: str,
temp_dir: str,
) -> None:
"""Converts TFRecords to Parquet format.
Args:
tfrecords_pattern: Pattern matching input TFRecords.
parquet_prefix: Path prefix for output Parquet files.
project: GCP project.
region: GCP region.
service_account: Service account to run Dataflow job.
temp_dir: Beam temporary directory path.
"""
pipeline_options = beam_utils.get_pipeline_options(
True,
'convert-tfrecords-to-parquet',
project,
region,
temp_dir,
10,
100,
service_account,
machine_type=None,
accelerator=None,
accelerator_count=0,
)
pipeline = beam.Pipeline(options=pipeline_options)
examples = pipeline | 'read_tfrecords' >> beam.io.tfrecordio.ReadFromTFRecord(
tfrecords_pattern, coder=beam.coders.ProtoCoder(tf.train.Example)
)
_write_examples_to_parquet(examples, parquet_prefix)
_ = pipeline.run()
7 changes: 4 additions & 3 deletions src/skai/generate_examples_test.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for generate_examples.py."""

import glob
import os
import pathlib
@@ -482,7 +480,9 @@ def testGenerateExamplesPipeline(self):
max_workers=0,
wait_for_dataflow_job=True,
cloud_detector_model_path=None,
output_metadata_file=False)
output_metadata_file=False,
output_parquet=False,
)

tfrecords = os.listdir(
os.path.join(output_dir, 'examples', 'unlabeled-large')
@@ -522,6 +522,7 @@ def testGenerateExamplesWithOutputMetaDataFile(self):
wait_for_dataflow_job=True,
cloud_detector_model_path=None,
output_metadata_file=True,
output_parquet=False,
)

tfrecords = os.listdir(
39 changes: 39 additions & 0 deletions src/tools/tfrecords_to_parquet.py
@@ -0,0 +1,39 @@
"""Converts TFRecords to Parquet files.
This allows for more efficient searching of specific examples.
"""

from typing import Sequence

from absl import app
from absl import flags

from skai import generate_examples

FLAGS = flags.FLAGS
flags.DEFINE_string('examples_pattern', None, 'Examples pattern', required=True)
flags.DEFINE_string('output_prefix', None, 'Output prefix.', required=True)
flags.DEFINE_string('cloud_project', None, 'GCP project name.')
flags.DEFINE_string('cloud_region', None, 'GCP region, e.g. us-central1.')
flags.DEFINE_string(
'worker_service_account', None,
'Service account that will launch Dataflow workers. If unset, workers will '
'run with the project\'s default Compute Engine service account.')
flags.DEFINE_string('temp_dir', None, 'Temporary directory for Beam.')


def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')

generate_examples.convert_tfrecords_to_parquet(
FLAGS.examples_pattern,
FLAGS.output_prefix,
FLAGS.cloud_project,
FLAGS.cloud_region,
FLAGS.worker_service_account,
FLAGS.temp_dir,
)

if __name__ == '__main__':
app.run(main)
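For reference, the conversion entry point can also be driven directly from Python. The sketch below uses the convert_tfrecords_to_parquet signature added above in generate_examples.py; every argument value is a placeholder, and the call launches a Dataflow job under the given project, region, and service account.

    from skai import generate_examples

    # All values are placeholders for illustration.
    generate_examples.convert_tfrecords_to_parquet(
        tfrecords_pattern='gs://my-bucket/examples/unlabeled-large/*.tfrecord',
        parquet_prefix='gs://my-bucket/examples/unlabeled-parquet/examples',
        project='my-project',
        region='us-central1',
        service_account='dataflow@my-project.iam.gserviceaccount.com',
        temp_dir='gs://my-bucket/temp',
    )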
