Skip to content

Commit

Permalink
Support multiple attributes needing label selection (#916)
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasbansal authored Oct 2, 2024
1 parent 5d266f2 commit eb145aa
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 34 deletions.
24 changes: 12 additions & 12 deletions src/autolabel/few_shot/fixed_example_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,22 @@ def select_examples(
**kwargs,
) -> List[dict]:
"""Select which examples to use based on the input lengths."""
label_column = kwargs.get("label_column")
selected_labels = kwargs.get("selected_labels")
selected_labels_map = kwargs.get("selected_labels_map")

if not selected_labels:
return self.examples[: self.k]

if not label_column:
print("No label column provided, returning all examples")
if not selected_labels_map:
return self.examples[: self.k]

# get the examples where label matches the selected labels
valid_examples = [
example
for example in self.examples
if example.get(label_column) in selected_labels
]
valid_examples = []
for example in self.examples:
valid = True
for label_column, selected_labels in selected_labels_map.items():
if example.get(label_column) not in selected_labels:
valid = False
break
if valid:
valid_examples.append(example)

return valid_examples[: min(self.k, len(valid_examples))]

@classmethod
Expand Down
46 changes: 28 additions & 18 deletions src/autolabel/labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
config: Union[AutolabelConfig, str, dict],
cache: Optional[bool] = True,
example_selector: Optional[BaseExampleSelector] = None,
label_selector: Optional[BaseLabelSelector] = None,
label_selector_map: Optional[Dict[str, BaseLabelSelector]] = {},
console_output: Optional[bool] = True,
generation_cache: Optional[BaseCache] = SQLAlchemyGenerationCache(),
transform_cache: Optional[BaseCache] = SQLAlchemyTransformCache(),
Expand Down Expand Up @@ -135,7 +135,7 @@ def __init__(
)

self.example_selector = example_selector
self.label_selector = label_selector
self.label_selector_map = label_selector_map

if in_notebook():
import nest_asyncio
Expand Down Expand Up @@ -221,14 +221,22 @@ async def arun(
if (
self.config.label_selection()
and self.config.task_type() == TaskType.ATTRIBUTE_EXTRACTION
and not self.label_selector
and not self.label_selector_map
):
self.label_selector = LabelSelector(
config=self.config,
embedding_func=PROVIDER_TO_MODEL.get(
self.config.embedding_provider(), DEFAULT_EMBEDDING_PROVIDER
)(model=self.config.embedding_model_name()),
)
self.label_selector_map = {}
for attribute in self.config.attributes():
label_selection_count = attribute.get(
AutolabelConfig.LABEL_SELECTION_KEY
)
if label_selection_count:
label_selector = LabelSelector(
config=self.config,
embedding_func=PROVIDER_TO_MODEL.get(
self.config.embedding_provider(),
DEFAULT_EMBEDDING_PROVIDER,
)(model=self.config.embedding_model_name()),
)
self.label_selector_map[attribute["name"]] = label_selector

current_index = 0
cost = 0.0
Expand All @@ -251,19 +259,21 @@ async def arun(
chunk = dataset.inputs[current_index]
examples = []

if self.label_selector:
if self.label_selector_map:
# Create toEmbed string using the example template from the config
example_template = self.config.example_template()
toEmbed = example_template.format_map(defaultdict(str, chunk))
selected_labels = self.label_selector.select_labels(toEmbed)
selected_labels_map = {
self.config.label_selection_attribute(): selected_labels
}
selected_labels_map = {}
for attribute in self.config.attributes():
attribute_name = attribute.get("name")
label_selector = self.label_selector_map.get(attribute_name)
if label_selector:
selected_labels = label_selector.select_labels(toEmbed)
selected_labels_map[attribute_name] = selected_labels
if self.example_selector:
examples = self.example_selector.select_examples(
safe_serialize_to_string(chunk),
selected_labels=selected_labels,
label_column=self.config.label_selection_attribute(),
selected_labels_map=selected_labels_map,
)
else:
if self.example_selector:
Expand All @@ -276,7 +286,7 @@ async def arun(
chunk,
examples,
selected_labels_map=selected_labels_map
if self.label_selector
if self.label_selector_map
else None,
max_input_tokens=self.llm.max_context_length,
get_num_tokens=self.llm.get_num_tokens,
Expand Down Expand Up @@ -310,7 +320,7 @@ async def arun(
chunk,
final_prompt,
selected_labels_map=selected_labels_map
if self.label_selector
if self.label_selector_map
else None,
)
annotation.input_tokens = input_tokens
Expand Down
8 changes: 4 additions & 4 deletions src/autolabel/task_chain/task_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(
confidence_tokenizer: Optional[AutoTokenizer] = None,
confidence_endpoint: Optional[str] = None,
column_name_map: Optional[Dict[str, str]] = None,
label_selector: Optional[BaseLabelSelector] = None,
label_selector_map: Optional[BaseLabelSelector] = None,
):
self.task_chain_config = task_chain_config
self.cache = cache
Expand All @@ -125,7 +125,7 @@ def __init__(
self.confidence_tokenizer = confidence_tokenizer
self.confidence_endpoint = confidence_endpoint
self.column_name_map = column_name_map
self.label_selector = label_selector
self.label_selector_map = label_selector_map

# TODO: For now, we run each separate step of the task chain serially and aggregate at the end.
# We can optimize this with parallelization where possible/no dependencies.
Expand Down Expand Up @@ -155,7 +155,7 @@ async def run(self, dataset_df: pd.DataFrame):
confidence_tokenizer=self.confidence_tokenizer,
confidence_endpoint=self.confidence_endpoint,
console_output=False,
label_selector=self.label_selector,
label_selector_map=self.label_selector_map,
)
for transform_dict in autolabel_config.transforms():
transform = TransformFactory.from_dict(
Expand All @@ -174,7 +174,7 @@ async def run(self, dataset_df: pd.DataFrame):
confidence_tokenizer=self.confidence_tokenizer,
confidence_endpoint=self.confidence_endpoint,
console_output=False,
label_selector=self.label_selector,
label_selector_map=self.label_selector_map,
)
dataset = await agent.arun(
dataset,
Expand Down

0 comments on commit eb145aa

Please sign in to comment.