-
Notifications
You must be signed in to change notification settings - Fork 163
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* New module for the integration with instructor * Move common functions related to structured outputs to its own module * Draft instructor integration with openai * Add tests for openai integration * Add unit tests for the instructor integrations * Add tests for anthropic integration * Fix including anthropic wrapper * Update llms to deal with instructor * Update dependencies with instructor * Run tests with instructor only on python>=3.9 * Fix circular import with create_distiset * Define _prepare_structured_output as staticmethod * Remove rewritten variable * Remove dead code * Check on Enum.value instead of Enum class as it isn't pickleable * Add tests for utilities related to generation of BaseModel objects from json schema dicts * Add fix to deal with nested BaseModel objects * Fix call from instructor, this should be done on instructor end, but works for the moment * Add docstrings and typing info * Add script to generate a sample dataset and visualize the result * Update the docstring of the structured output expected format * Add reference in the docs to structured outputs with instructor * Add reference to the dependency installation * Update typing info * Fix test with new mocked client for mistral * Update docs/sections/learn/advanced/structured_generation.md Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com> * Update docs/sections/learn/advanced/structured_generation.md Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com> * Update src/distilabel/steps/tasks/structured_outputs/instructor.py Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com> * Update src/distilabel/steps/tasks/structured_outputs/instructor.py Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com> * Update src/distilabel/steps/tasks/structured_outputs/utils.py Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com> * Update src/distilabel/steps/tasks/structured_outputs/instructor.py Co-authored-by: Gabriel Martín Blázquez 
<gmartinbdev@gmail.com> * Update src/distilabel/steps/tasks/structured_outputs/utils.py Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com> * Update src/distilabel/steps/tasks/structured_outputs/utils.py Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com> * Update src/distilabel/steps/tasks/structured_outputs/utils.py Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com> * Add changes from code review * Fix type hint per code review * Update docs/sections/learn/advanced/structured_generation.md Co-authored-by: Alvaro Bartolome <alvaro@argilla.io> * Remove repeated line --------- Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com> Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>
- Loading branch information
1 parent
7e9230b
commit 01b4292
Showing
28 changed files
with
1,472 additions
and
171 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright 2023-present, Argilla, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
from typing import Any, Dict, List, Union | ||
|
||
from graphviz import Digraph | ||
from pydantic import BaseModel, Field | ||
|
||
|
||
class Node(BaseModel):
    # A single vertex of the knowledge graph.
    # (Documented with comments instead of a docstring so the model's
    # generated JSON schema is left exactly as it was.)

    # Identifier that `Edge.source` / `Edge.target` refer to.
    id: int
    # Text displayed for the node when the graph is drawn.
    label: str
    # Colour name handed to graphviz for this node.
    color: str
|
||
|
||
class Edge(BaseModel):
    # A directed, labelled connection between two nodes.
    # (Documented with comments instead of a docstring so the model's
    # generated JSON schema is left exactly as it was.)

    # `id` of the node the edge starts from.
    source: int
    # `id` of the node the edge points to.
    target: int
    # Text displayed next to the edge when the graph is drawn.
    label: str
    # Colour name handed to graphviz; black unless stated otherwise.
    color: str = "black"
|
||
|
||
class KnowledgeGraph(BaseModel):
    # Container for a whole graph: its nodes plus the edges between them.
    # (Documented with comments instead of a docstring so the model's
    # generated JSON schema is left exactly as it was.)
    #
    # Both fields fall back to an empty list.  The previous declaration,
    # `Field(..., default_factory=list)`, combined the "required" marker
    # (`...`) with a default factory; that pair is contradictory, so only
    # the factory is kept.
    nodes: List[Node] = Field(default_factory=list)
    edges: List[Edge] = Field(default_factory=list)
|
||
|
||
def visualize_knowledge_graph(kg: KnowledgeGraph):
    """Draw `kg` with graphviz and open the rendered result.

    Every node and edge of `kg` is copied into a `Digraph`, which is then
    rendered to "knowledge_graph.gv" with `view=True`.

    Args:
        kg: the knowledge graph to draw.
    """
    graph = Digraph(comment="Knowledge Graph")

    # Graphviz identifies nodes by string, hence the `str(...)` conversions.
    for vertex in kg.nodes:
        graph.node(str(vertex.id), vertex.label, color=vertex.color)

    for link in kg.edges:
        # An empty colour falls back to black.
        edge_color = link.color if link.color else "black"
        graph.edge(
            str(link.source),
            str(link.target),
            label=link.label,
            color=edge_color,
        )

    graph.render("knowledge_graph.gv", view=True)
|
||
|
||
def create_knowledge_graph(data: str) -> Union[KnowledgeGraph, None]:
    """Build a `KnowledgeGraph` from a JSON-encoded generation.

    Args:
        data: JSON document holding a "nodes" list and an "edges" list whose
            items carry the fields of `Node` and `Edge` respectively. `None`
            is tolerated and yields `None`.

    Returns:
        The parsed graph, or `None` when `data` is `None`, as the return
        annotation already advertised.

    Raises:
        json.JSONDecodeError: if `data` is not valid JSON.
        KeyError: if the document lacks the "nodes" or "edges" key.
        pydantic.ValidationError: if a node or edge does not match its model.
    """
    if data is None:
        return None

    # Parsed into a new name: re-annotating the `data` parameter with a
    # different type is rejected by type checkers.
    payload: Dict[str, Any] = json.loads(data)

    nodes = [Node(**node) for node in payload["nodes"]]

    edges = []
    for edge in payload["edges"]:
        # `Edge.color` is typed `str`, so a missing or explicit null colour is
        # replaced by the default. A copy is used to leave `payload` untouched.
        if edge.get("color") is None:
            edge = {**edge, "color": "black"}
        edges.append(Edge(**edge))

    return KnowledgeGraph(nodes=nodes, edges=edges)
|
||
|
||
if __name__ == "__main__":
    # Usage: python <this script> INDEX
    # Draws the knowledge graph stored in row INDEX of the dataset below.
    import sys

    from datasets import load_dataset

    args = sys.argv[1:]
    if not args:
        # Fail with a readable message instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} INDEX  (row of the dataset to visualize)")
    index = int(args[0])

    ds = load_dataset("distilabel-internal-testing/knowledge_graphs", split="train")

    # Only the requested row is parsed, so a malformed generation elsewhere in
    # the dataset cannot prevent this one from being displayed.
    graph = create_knowledge_graph(ds["generation"][index])
    if graph is None:
        sys.exit(f"row {index} has no generation to visualize")
    visualize_knowledge_graph(graph)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright 2023-present, Argilla, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import List | ||
|
||
from distilabel.llms import MistralLLM | ||
from distilabel.pipeline import Pipeline | ||
from distilabel.steps import LoadDataFromDicts | ||
from distilabel.steps.tasks import TextGeneration | ||
from pydantic import BaseModel, Field | ||
|
||
|
||
class Node(BaseModel):
    # A single vertex of the knowledge graph the model is asked to produce.
    # Documented with comments rather than a docstring: pydantic copies a
    # model's docstring into its JSON schema, and `KnowledgeGraph` (which
    # nests this model) is passed to the LLM as the structured-output schema.

    # Identifier that `Edge.source` / `Edge.target` refer to.
    id: int
    # Human-readable text for the node.
    label: str
    # Colour name for the node.
    color: str
|
||
|
||
class Edge(BaseModel):
    # A directed, labelled connection between two nodes.
    # Documented with comments rather than a docstring: pydantic copies a
    # model's docstring into its JSON schema, and `KnowledgeGraph` (which
    # nests this model) is passed to the LLM as the structured-output schema.

    # `id` of the node the edge starts from.
    source: int
    # `id` of the node the edge points to.
    target: int
    # Human-readable text for the relation.
    label: str
    # Colour name for the edge; black unless the model states otherwise.
    color: str = "black"
|
||
|
||
class KnowledgeGraph(BaseModel):
    # Schema of the structured output requested from the LLM: the nodes of
    # the graph plus the edges between them.
    # Documented with comments rather than a docstring: pydantic copies a
    # model's docstring into its JSON schema, which would change what is sent
    # to the LLM.
    #
    # Both fields fall back to an empty list.  The previous declaration,
    # `Field(..., default_factory=list)`, combined the "required" marker
    # (`...`) with a default factory; that pair is contradictory, so only
    # the factory is kept.
    nodes: List[Node] = Field(default_factory=list)
    edges: List[Edge] = Field(default_factory=list)
|
||
|
||
# Two-step pipeline: load a handful of questions, then ask the LLM to answer
# each one as a `KnowledgeGraph`.
with Pipeline(
    name="Knowledge-Graphs",
    description=(
        "Generate knowledge graphs to answer questions, this type of dataset can be used to "
        "steer a model to answer questions with a knowledge graph."
    ),
) as pipeline:
    sample_questions = [
        "Teach me about quantum mechanics",
        "Who is who in The Simpsons family?",
        "Tell me about the evolution of programming languages",
    ]

    # One input row per question; every row shares the same system prompt.
    load_dataset = LoadDataFromDicts(
        name="load_instructions",
        data=[
            {
                "system_prompt": "You are a knowledge graph expert generator. Help me understand by describing everything as a detailed knowledge graph.",
                # The question is already a `str`; no f-string wrapper needed.
                "instruction": question,
            }
            for question in sample_questions
        ],
    )

    text_generation = TextGeneration(
        name="knowledge_graph_generation",
        # `structured_output` hands the `KnowledgeGraph` model to the LLM as
        # the schema its answers must follow.
        llm=MistralLLM(
            model="open-mixtral-8x22b", structured_output={"schema": KnowledgeGraph}
        ),
        input_batch_size=8,
        # NOTE(review): presumably renames the "model_name" output column to
        # "generation_model" — confirm against the step's documentation.
        output_mappings={"model_name": "generation_model"},
    )

    # Connect the steps: loaded rows flow into the generation task.
    load_dataset >> text_generation
|
||
|
||
if __name__ == "__main__":
    # Runtime parameters are keyed by step name; here only the generation
    # step's LLM is configured, capping each answer at 2048 new tokens.
    generation_kwargs = {"max_new_tokens": 2048}
    run_parameters = {
        text_generation.name: {"llm": {"generation_kwargs": generation_kwargs}}
    }

    # Run with `use_cache=False`, then publish the resulting dataset.
    distiset = pipeline.run(parameters=run_parameters, use_cache=False)

    distiset.push_to_hub("distilabel-internal-testing/knowledge_graphs")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.