# Registering a Custom Type
This notebook shows how to register a new type so that it can be used as the return annotation for `@prompt`, `@promptchain`, and `@chatprompt`. This is done by creating a new `FunctionSchema`, which defines the parameters required to create the type and how to parse/serialize these from/to the LLM.
See https://platform.openai.com/docs/guides/function-calling for more information on function calling, which enables this.
In [1]:
Copied!
# Create FunctionSchema for pd.DataFrame
import json
from collections.abc import Iterable
from typing import Any
import pandas as pd
from magentic.chat_model.function_schema import FunctionSchema, register_function_schema
@register_function_schema(pd.DataFrame)
class DataFrameFunctionSchema(FunctionSchema[pd.DataFrame]):
    """FunctionSchema that represents a ``pd.DataFrame`` as function arguments.

    The frame is described by two arguments: ``columns`` (the column labels)
    and ``data`` (a list of rows). The row index is not part of the schema.
    """

    @property
    def name(self) -> str:
        """The name of the function.

        Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.
        """
        return "dataframe"

    @property
    def description(self) -> str | None:
        """A short human-readable description of the function."""
        return "A DataFrame object."

    @property
    def parameters(self) -> dict[str, Any]:
        """The parameters the function accepts as a JSON Schema object."""
        return {
            "properties": {
                "columns": {"type": "array", "items": {"type": "string"}},
                "data": {
                    "type": "array",
                    "items": {"type": "array", "items": {"type": "string"}},
                },
            },
            # Every required key must also be declared under "properties";
            # "index" was listed here without a definition, so it is dropped.
            "required": ["columns", "data"],
            "type": "object",
        }

    def parse_args(self, chunks: Iterable[str]) -> pd.DataFrame:
        """Parse an iterable of string chunks into a DataFrame.

        The chunks are concatenated and decoded as a single JSON object whose
        keys are forwarded as keyword arguments to ``pd.DataFrame``.

        Raises:
            json.JSONDecodeError: If the joined chunks are not valid JSON.
        """
        args = json.loads("".join(chunks))
        return pd.DataFrame(**args)

    def serialize_args(self, value: pd.DataFrame) -> dict:
        """Serialize a DataFrame into a dict of function arguments.

        Returns a dict with the column labels under ``"columns"`` and the rows
        under ``"data"``. The index of ``value`` is not included.
        """
        return {
            "columns": value.columns.tolist(),
            "data": value.to_numpy().tolist(),
        }
# Create FunctionSchema for pd.DataFrame
import json
from collections.abc import Iterable
from typing import Any
import pandas as pd
from magentic.chat_model.function_schema import FunctionSchema, register_function_schema
@register_function_schema(pd.DataFrame)
class DataFrameFunctionSchema(FunctionSchema[pd.DataFrame]):
    """FunctionSchema that represents a ``pd.DataFrame`` as function arguments.

    The frame is described by two arguments: ``columns`` (the column labels)
    and ``data`` (a list of rows). The row index is not part of the schema.
    """

    @property
    def name(self) -> str:
        """The name of the function.

        Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.
        """
        return "dataframe"

    @property
    def description(self) -> str | None:
        """A short human-readable description of the function."""
        return "A DataFrame object."

    @property
    def parameters(self) -> dict[str, Any]:
        """The parameters the function accepts as a JSON Schema object."""
        return {
            "properties": {
                "columns": {"type": "array", "items": {"type": "string"}},
                "data": {
                    "type": "array",
                    "items": {"type": "array", "items": {"type": "string"}},
                },
            },
            # Every required key must also be declared under "properties";
            # "index" was listed here without a definition, so it is dropped.
            "required": ["columns", "data"],
            "type": "object",
        }

    def parse_args(self, chunks: Iterable[str]) -> pd.DataFrame:
        """Parse an iterable of string chunks into a DataFrame.

        The chunks are concatenated and decoded as a single JSON object whose
        keys are forwarded as keyword arguments to ``pd.DataFrame``.

        Raises:
            json.JSONDecodeError: If the joined chunks are not valid JSON.
        """
        args = json.loads("".join(chunks))
        return pd.DataFrame(**args)

    def serialize_args(self, value: pd.DataFrame) -> dict:
        """Serialize a DataFrame into a dict of function arguments.

        Returns a dict with the column labels under ``"columns"`` and the rows
        under ``"data"``. The index of ``value`` is not included.
        """
        return {
            "columns": value.columns.tolist(),
            "data": value.to_numpy().tolist(),
        }
In [2]:
Copied!
# Roundtrip test the new FunctionSchema
function_schema = DataFrameFunctionSchema(pd.DataFrame)
df_test = pd.DataFrame(
    {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    },
)
# Serialize the frame to function arguments and show them.
args = function_schema.serialize_args(df_test)
print(args)
# parse_args takes an iterable of string chunks; a single JSON string works
# because its characters are joined back together before decoding.
obj = function_schema.parse_args(json.dumps(args))
# Actually verify the roundtrip instead of relying on visual inspection.
pd.testing.assert_frame_equal(obj, df_test)
obj
# Roundtrip test the new FunctionSchema
function_schema = DataFrameFunctionSchema(pd.DataFrame)
df_test = pd.DataFrame(
    {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    },
)
# Serialize the frame to function arguments and show them.
args = function_schema.serialize_args(df_test)
print(args)
# parse_args takes an iterable of string chunks; a single JSON string works
# because its characters are joined back together before decoding.
obj = function_schema.parse_args(json.dumps(args))
# Actually verify the roundtrip instead of relying on visual inspection.
pd.testing.assert_frame_equal(obj, df_test)
obj
{'columns': ['A', 'B'], 'data': [[1, 4], [2, 5], [3, 6]]}
Out[2]:
A | B | |
---|---|---|
0 | 1 | 4 |
1 | 2 | 5 |
2 | 3 | 6 |
In [3]:
Copied!
# Use pd.DataFrame as the return type of a prompt function
import pandas as pd
from magentic import prompt
# Adjacent string literals are concatenated with nothing in between, so each
# sentence needs its own trailing space to keep the prompt readable.
@prompt(
    "Create a table listing the ingredients needed to cook {dish}. "
    "Include a column for the quantity of each ingredient. "
    "Also include a column with allergy information."
)
def list_ingredients(dish: str) -> pd.DataFrame: ...


list_ingredients("lasagna")
# Use pd.DataFrame as the return type of a prompt function
import pandas as pd
from magentic import prompt
# Adjacent string literals are concatenated with nothing in between, so each
# sentence needs its own trailing space to keep the prompt readable.
@prompt(
    "Create a table listing the ingredients needed to cook {dish}. "
    "Include a column for the quantity of each ingredient. "
    "Also include a column with allergy information."
)
def list_ingredients(dish: str) -> pd.DataFrame: ...


list_ingredients("lasagna")
Out[3]:
Ingredient | Quantity | Allergy Information | |
---|---|---|---|
0 | Lasagna noodles | 16 oz | Contains wheat, may contain egg and soy |
1 | Ground beef | 1 lb | Contains beef, may contain soy and gluten |
2 | Tomato sauce | 24 oz | Contains tomatoes, may contain soy and garlic |
3 | Mozzarella cheese | 16 oz | Contains milk, may contain soy |
4 | Ricotta cheese | 15 oz | Contains milk, may contain soy and eggs |
5 | Parmesan cheese | 1 cup | Contains milk, may contain soy and eggs |
6 | Garlic | 3 cloves | No known allergies |
7 | Onion | 1 | No known allergies |
8 | Olive oil | 2 tbsp | No known allergies |
9 | Salt | 1 tsp | No known allergies |
10 | Pepper | 1/2 tsp | No known allergies |
11 | Italian seasoning | 1 tsp | No known allergies |
12 | Sugar | 1 tsp | No known allergies |