Update C# codegen to consistent filtering rules (#1277)

* Update C# codegen to consistent filtering rules

- Limit types as per proposal.
- Add `Query` client-side SDK helper for API parity with server-side modules (on client-side it's a simple wrapper around Iter + Where).
- Change return type of `FilterBy` to always be iterable, with new `FindBy` function for unique fields.
- Simplify the way primary keys are handled - must go with https://github.com/clockworklabs/spacetimedb-csharp-sdk/pull/93 for the client SDK counterpart.

* Add using for System.Linq

* Update snapshot
This commit is contained in:
Ingvar Stepanyan
2024-05-28 15:22:12 +01:00
committed by GitHub
parent a7c75751f2
commit 307dfeebd3
2 changed files with 137 additions and 323 deletions
+45 -175
View File
@@ -570,7 +570,7 @@ fn autogen_csharp_product_table_common(
schema: Option<TableSchema>,
namespace: &str,
) -> String {
let mut output = CsharpAutogen::new(namespace, &["System.Collections.Generic"]);
let mut output = CsharpAutogen::new(namespace, &["System.Collections.Generic", "System.Linq"]);
writeln!(
output,
@@ -634,15 +634,10 @@ fn autogen_csharp_product_table_common(
continue;
}
let type_name = ty_fmt(ctx, &col.col_type, namespace);
let comparer = if format!("{type_name}") == "byte[]" {
", new SpacetimeDB.ByteArrayComparer()"
} else {
""
};
writeln!(
output,
"private static Dictionary<{type_name}, {name}> {field_name}_Index = new Dictionary<{type_name}, {name}>(16{comparer});"
);
output,
"private static Dictionary<{type_name}, {name}> {field_name}_Index = new(16);"
);
}
writeln!(output);
// OnInsert method for updating indexes
@@ -832,25 +827,30 @@ fn autogen_csharp_access_funcs_for_struct(
) -> bool {
let primary_col_idx = schema.pk();
writeln!(
output,
"public static System.Collections.Generic.IEnumerable<{struct_name_pascal_case}> Iter()"
);
writeln!(output, "public static IEnumerable<{struct_name_pascal_case}> Iter()");
indented_block(output, |output| {
writeln!(
output,
"foreach(var entry in SpacetimeDBClient.clientDB.GetEntries(\"{table_name}\"))",
"return SpacetimeDBClient.clientDB.GetObjects(\"{table_name}\").Cast<{struct_name_pascal_case}>();",
);
indented_block(output, |output| {
// TODO: best way to handle this?
writeln!(output, "yield return ({struct_name_pascal_case})entry.Item2;");
});
});
writeln!(output);
// Simple alias for Iter().Where(...) for API parity with C# server-side modules.
writeln!(
output,
"public static IEnumerable<{struct_name_pascal_case}> Query(Func<{struct_name_pascal_case}, bool> filter)"
);
indented_block(output, |output| {
writeln!(output, "return Iter().Where(filter);",);
});
writeln!(output);
writeln!(output, "public static int Count()");
indented_block(output, |output| {
writeln!(output, "return SpacetimeDBClient.clientDB.Count(\"{table_name}\");");
});
writeln!(output);
let constraints = schema.column_constraints();
for col in schema.columns() {
@@ -863,188 +863,58 @@ fn autogen_csharp_access_funcs_for_struct(
let field_type = &field.algebraic_type;
let csharp_field_name_pascal = field_name.replace("r#", "").to_case(Case::Pascal);
let (field_type, csharp_field_type, is_option) = match field_type {
let csharp_field_type = match field_type {
AlgebraicType::Product(product) => {
if product.is_identity() {
("Identity".into(), "SpacetimeDB.Identity".into(), false)
"SpacetimeDB.Identity"
} else if product.is_address() {
("Address".into(), "SpacetimeDB.Address".into(), false)
} else {
// TODO: We don't allow filtering on tuples right now,
// it's possible we may consider it for the future.
continue;
}
}
AlgebraicType::Sum(sum) => {
if let Some(Builtin(b)) = sum.as_option() {
match maybe_primitive(b) {
MaybePrimitive::Primitive(ty) => (format!("{b:?}"), format!("{ty}?"), true),
_ => {
continue;
}
}
"SpacetimeDB.Address"
} else {
continue;
}
}
AlgebraicType::Ref(_) => {
// TODO: We don't allow filtering on enums or tuples right now;
// it's possible we may consider it for the future.
continue;
}
AlgebraicType::Sum(_) | AlgebraicType::Ref(_) => continue,
AlgebraicType::Builtin(b) => match maybe_primitive(b) {
MaybePrimitive::Primitive(ty) => (format!("{b:?}"), ty.into(), false),
MaybePrimitive::Array(ArrayType { elem_ty }) => {
if let Some(BuiltinType::U8) = elem_ty.as_builtin() {
// Do allow filtering for byte arrays
("Bytes".into(), "byte[]".into(), false)
} else {
// TODO: We don't allow filtering based on an array type, but we might want other functionality here in the future.
continue;
}
}
MaybePrimitive::Map(_) => {
// TODO: It would be nice to be able to say, give me all entries where this vec contains this value, which we can do.
continue;
}
MaybePrimitive::Primitive(ty) => ty,
_ => continue,
},
};
let filter_return_type = fmt_fn(|f| {
if is_unique {
f.write_str(struct_name_pascal_case)
} else {
write!(f, "System.Collections.Generic.IEnumerable<{struct_name_pascal_case}>")
}
});
writeln!(
output,
"public static {filter_return_type} FilterBy{csharp_field_name_pascal}({csharp_field_type} value)"
);
indented_block(output, |output| {
if is_unique {
if is_unique {
writeln!(
output,
"public static {struct_name_pascal_case} FindBy{csharp_field_name_pascal}({csharp_field_type} value)"
);
indented_block(output, |output| {
writeln!(
output,
"{csharp_field_name_pascal}_Index.TryGetValue(value, out var r);"
);
writeln!(output, "return r;");
});
writeln!(output);
}
writeln!(
output,
"public static IEnumerable<{struct_name_pascal_case}> FilterBy{csharp_field_name_pascal}({csharp_field_type} value)"
);
indented_block(output, |output| {
if is_unique {
writeln!(output, "return new[] {{ FindBy{csharp_field_name_pascal}(value) }};");
} else {
writeln!(
output,
"foreach(var entry in SpacetimeDBClient.clientDB.GetEntries(\"{table_name}\"))"
);
indented_block(output, |output| {
writeln!(output, "var productValue = entry.Item1.AsProductValue();");
if field_type == "Identity" {
writeln!(
output,
"var compareValue = Identity.From(productValue.elements[{col_i}].AsProductValue().elements[0].AsBytes());"
);
} else if is_option {
writeln!(
output,
"var compareValue = ({csharp_field_type})(productValue.elements[{col_i}].AsSumValue().tag == 1 ? null : productValue.elements[{col_i}].AsSumValue().value.As{field_type}());"
);
} else if field_type == "Address" {
writeln!(
output,
"var compareValue = (Address)Address.From(productValue.elements[{col_i}].AsProductValue().elements[0].AsBytes());"
);
} else {
writeln!(
output,
"var compareValue = ({csharp_field_type})productValue.elements[{col_i}].As{field_type}();"
);
}
if csharp_field_type == "byte[]" {
writeln!(
output,
"static bool ByteArrayCompare(byte[] a1, byte[] a2)
{{
if (a1.Length != a2.Length)
return false;
for (int i=0; i<a1.Length; i++)
if (a1[i]!=a2[i])
return false;
return true;
}}"
);
writeln!(output);
writeln!(output, "if (ByteArrayCompare(compareValue, value))");
indented_block(output, |output| {
writeln!(output, "yield return ({struct_name_pascal_case})entry.Item2;");
});
} else {
writeln!(output, "if (compareValue == value)");
indented_block(output, |output| {
writeln!(output, "yield return ({struct_name_pascal_case})entry.Item2;");
});
}
});
writeln!(output, "return Query(x => x.{csharp_field_name_pascal} == value);");
}
});
writeln!(output);
}
if let Some(primary_col_index) = primary_col_idx {
if let Some(primary_col_index) = schema.pk() {
writeln!(
output,
"public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue v1, SpacetimeDB.SATS.AlgebraicValue v2)"
"private static object GetPrimaryKeyValue(object row) => (({struct_name_pascal_case})row).{col_name_pascal_case};",
col_name_pascal_case = primary_col_index.col_name.replace("r#", "").to_case(Case::Pascal)
);
indented_block(output, |output| {
writeln!(
output,
"var primaryColumnValue1 = v1.AsProductValue().elements[{}];",
primary_col_index.col_pos
);
writeln!(
output,
"var primaryColumnValue2 = v2.AsProductValue().elements[{}];",
primary_col_index.col_pos
);
writeln!(
output,
"return SpacetimeDB.SATS.AlgebraicValue.Compare(t.product.elements[0].algebraicType, primaryColumnValue1, primaryColumnValue2);"
);
});
writeln!(output);
writeln!(
output,
"public static SpacetimeDB.SATS.AlgebraicValue GetPrimaryKeyValue(SpacetimeDB.SATS.AlgebraicValue v)"
);
indented_block(output, |output| {
writeln!(
output,
"return v.AsProductValue().elements[{}];",
primary_col_index.col_pos
);
});
writeln!(output);
writeln!(
output,
"public static SpacetimeDB.SATS.AlgebraicType GetPrimaryKeyType(SpacetimeDB.SATS.AlgebraicType t)"
);
indented_block(output, |output| {
writeln!(
output,
"return t.product.elements[{}].algebraicType;",
primary_col_index.col_pos
);
});
} else {
writeln!(
output,
"public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2)"
);
indented_block(output, |output| {
writeln!(output, "return false;");
});
}
primary_col_idx.is_some()
@@ -302,6 +302,7 @@ namespace SpacetimeDB
using System;
using System.Collections.Generic;
using System.Linq;
namespace SpacetimeDB
{
@@ -313,8 +314,8 @@ namespace SpacetimeDB
[Newtonsoft.Json.JsonProperty("other")]
public uint Other;
private static Dictionary<uint, PkMultiIdentity> Id_Index = new Dictionary<uint, PkMultiIdentity>(16);
private static Dictionary<uint, PkMultiIdentity> Other_Index = new Dictionary<uint, PkMultiIdentity>(16);
private static Dictionary<uint, PkMultiIdentity> Id_Index = new(16);
private static Dictionary<uint, PkMultiIdentity> Other_Index = new(16);
private static void InternalOnValueInserted(object insertedValue)
{
@@ -350,45 +351,44 @@ namespace SpacetimeDB
};
}
public static System.Collections.Generic.IEnumerable<PkMultiIdentity> Iter()
public static IEnumerable<PkMultiIdentity> Iter()
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("PkMultiIdentity"))
{
yield return (PkMultiIdentity)entry.Item2;
}
return SpacetimeDBClient.clientDB.GetObjects("PkMultiIdentity").Cast<PkMultiIdentity>();
}
public static IEnumerable<PkMultiIdentity> Query(Func<PkMultiIdentity, bool> filter)
{
return Iter().Where(filter);
}
public static int Count()
{
return SpacetimeDBClient.clientDB.Count("PkMultiIdentity");
}
public static PkMultiIdentity FilterById(uint value)
public static PkMultiIdentity FindById(uint value)
{
Id_Index.TryGetValue(value, out var r);
return r;
}
public static PkMultiIdentity FilterByOther(uint value)
public static IEnumerable<PkMultiIdentity> FilterById(uint value)
{
return new[] { FindById(value) };
}
public static PkMultiIdentity FindByOther(uint value)
{
Other_Index.TryGetValue(value, out var r);
return r;
}
public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue v1, SpacetimeDB.SATS.AlgebraicValue v2)
public static IEnumerable<PkMultiIdentity> FilterByOther(uint value)
{
var primaryColumnValue1 = v1.AsProductValue().elements[0];
var primaryColumnValue2 = v2.AsProductValue().elements[0];
return SpacetimeDB.SATS.AlgebraicValue.Compare(t.product.elements[0].algebraicType, primaryColumnValue1, primaryColumnValue2);
return new[] { FindByOther(value) };
}
public static SpacetimeDB.SATS.AlgebraicValue GetPrimaryKeyValue(SpacetimeDB.SATS.AlgebraicValue v)
{
return v.AsProductValue().elements[0];
}
public static SpacetimeDB.SATS.AlgebraicType GetPrimaryKeyType(SpacetimeDB.SATS.AlgebraicType t)
{
return t.product.elements[0].algebraicType;
}
private static object GetPrimaryKeyValue(object row) => ((PkMultiIdentity)row).Id;
public delegate void InsertEventHandler(PkMultiIdentity insertedValue, SpacetimeDB.ReducerEvent dbEvent);
public delegate void UpdateEventHandler(PkMultiIdentity oldValue, PkMultiIdentity newValue, SpacetimeDB.ReducerEvent dbEvent);
@@ -426,6 +426,7 @@ namespace SpacetimeDB
using System;
using System.Collections.Generic;
using System.Linq;
namespace SpacetimeDB
{
@@ -468,47 +469,31 @@ namespace SpacetimeDB
};
}
public static System.Collections.Generic.IEnumerable<Point> Iter()
public static IEnumerable<Point> Iter()
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("Point"))
{
yield return (Point)entry.Item2;
}
return SpacetimeDBClient.clientDB.GetObjects("Point").Cast<Point>();
}
public static IEnumerable<Point> Query(Func<Point, bool> filter)
{
return Iter().Where(filter);
}
public static int Count()
{
return SpacetimeDBClient.clientDB.Count("Point");
}
public static System.Collections.Generic.IEnumerable<Point> FilterByX(long value)
public static IEnumerable<Point> FilterByX(long value)
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("Point"))
{
var productValue = entry.Item1.AsProductValue();
var compareValue = (long)productValue.elements[0].AsI64();
if (compareValue == value)
{
yield return (Point)entry.Item2;
}
}
return Query(x => x.X == value);
}
public static System.Collections.Generic.IEnumerable<Point> FilterByY(long value)
public static IEnumerable<Point> FilterByY(long value)
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("Point"))
{
var productValue = entry.Item1.AsProductValue();
var compareValue = (long)productValue.elements[1].AsI64();
if (compareValue == value)
{
yield return (Point)entry.Item2;
}
}
return Query(x => x.Y == value);
}
public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2)
{
return false;
}
public delegate void InsertEventHandler(Point insertedValue, SpacetimeDB.ReducerEvent dbEvent);
public delegate void DeleteEventHandler(Point deletedValue, SpacetimeDB.ReducerEvent dbEvent);
@@ -769,6 +754,7 @@ namespace SpacetimeDB
using System;
using System.Collections.Generic;
using System.Linq;
namespace SpacetimeDB
{
@@ -815,60 +801,36 @@ namespace SpacetimeDB
};
}
public static System.Collections.Generic.IEnumerable<TestA> Iter()
public static IEnumerable<TestA> Iter()
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestA"))
{
yield return (TestA)entry.Item2;
}
return SpacetimeDBClient.clientDB.GetObjects("TestA").Cast<TestA>();
}
public static IEnumerable<TestA> Query(Func<TestA, bool> filter)
{
return Iter().Where(filter);
}
public static int Count()
{
return SpacetimeDBClient.clientDB.Count("TestA");
}
public static System.Collections.Generic.IEnumerable<TestA> FilterByX(uint value)
public static IEnumerable<TestA> FilterByX(uint value)
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestA"))
{
var productValue = entry.Item1.AsProductValue();
var compareValue = (uint)productValue.elements[0].AsU32();
if (compareValue == value)
{
yield return (TestA)entry.Item2;
}
}
return Query(x => x.X == value);
}
public static System.Collections.Generic.IEnumerable<TestA> FilterByY(uint value)
public static IEnumerable<TestA> FilterByY(uint value)
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestA"))
{
var productValue = entry.Item1.AsProductValue();
var compareValue = (uint)productValue.elements[1].AsU32();
if (compareValue == value)
{
yield return (TestA)entry.Item2;
}
}
return Query(x => x.Y == value);
}
public static System.Collections.Generic.IEnumerable<TestA> FilterByZ(string value)
public static IEnumerable<TestA> FilterByZ(string value)
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestA"))
{
var productValue = entry.Item1.AsProductValue();
var compareValue = (string)productValue.elements[2].AsString();
if (compareValue == value)
{
yield return (TestA)entry.Item2;
}
}
return Query(x => x.Z == value);
}
public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2)
{
return false;
}
public delegate void InsertEventHandler(TestA insertedValue, SpacetimeDB.ReducerEvent dbEvent);
public delegate void DeleteEventHandler(TestA deletedValue, SpacetimeDB.ReducerEvent dbEvent);
@@ -899,6 +861,7 @@ namespace SpacetimeDB
using System;
using System.Collections.Generic;
using System.Linq;
namespace SpacetimeDB
{
@@ -935,6 +898,7 @@ namespace SpacetimeDB
using System;
using System.Collections.Generic;
using System.Linq;
namespace SpacetimeDB
{
@@ -980,21 +944,21 @@ namespace SpacetimeDB
};
}
public static System.Collections.Generic.IEnumerable<TestD> Iter()
public static IEnumerable<TestD> Iter()
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestD"))
{
yield return (TestD)entry.Item2;
}
return SpacetimeDBClient.clientDB.GetObjects("TestD").Cast<TestD>();
}
public static IEnumerable<TestD> Query(Func<TestD, bool> filter)
{
return Iter().Where(filter);
}
public static int Count()
{
return SpacetimeDBClient.clientDB.Count("TestD");
}
public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2)
{
return false;
}
public delegate void InsertEventHandler(TestD insertedValue, SpacetimeDB.ReducerEvent dbEvent);
public delegate void DeleteEventHandler(TestD deletedValue, SpacetimeDB.ReducerEvent dbEvent);
@@ -1025,6 +989,7 @@ namespace SpacetimeDB
using System;
using System.Collections.Generic;
using System.Linq;
namespace SpacetimeDB
{
@@ -1036,7 +1001,7 @@ namespace SpacetimeDB
[Newtonsoft.Json.JsonProperty("name")]
public string Name;
private static Dictionary<ulong, TestE> Id_Index = new Dictionary<ulong, TestE>(16);
private static Dictionary<ulong, TestE> Id_Index = new(16);
private static void InternalOnValueInserted(object insertedValue)
{
@@ -1070,52 +1035,38 @@ namespace SpacetimeDB
};
}
public static System.Collections.Generic.IEnumerable<TestE> Iter()
public static IEnumerable<TestE> Iter()
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestE"))
{
yield return (TestE)entry.Item2;
}
return SpacetimeDBClient.clientDB.GetObjects("TestE").Cast<TestE>();
}
public static IEnumerable<TestE> Query(Func<TestE, bool> filter)
{
return Iter().Where(filter);
}
public static int Count()
{
return SpacetimeDBClient.clientDB.Count("TestE");
}
public static TestE FilterById(ulong value)
public static TestE FindById(ulong value)
{
Id_Index.TryGetValue(value, out var r);
return r;
}
public static System.Collections.Generic.IEnumerable<TestE> FilterByName(string value)
public static IEnumerable<TestE> FilterById(ulong value)
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestE"))
{
var productValue = entry.Item1.AsProductValue();
var compareValue = (string)productValue.elements[1].AsString();
if (compareValue == value)
{
yield return (TestE)entry.Item2;
}
}
return new[] { FindById(value) };
}
public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue v1, SpacetimeDB.SATS.AlgebraicValue v2)
public static IEnumerable<TestE> FilterByName(string value)
{
var primaryColumnValue1 = v1.AsProductValue().elements[0];
var primaryColumnValue2 = v2.AsProductValue().elements[0];
return SpacetimeDB.SATS.AlgebraicValue.Compare(t.product.elements[0].algebraicType, primaryColumnValue1, primaryColumnValue2);
return Query(x => x.Name == value);
}
public static SpacetimeDB.SATS.AlgebraicValue GetPrimaryKeyValue(SpacetimeDB.SATS.AlgebraicValue v)
{
return v.AsProductValue().elements[0];
}
public static SpacetimeDB.SATS.AlgebraicType GetPrimaryKeyType(SpacetimeDB.SATS.AlgebraicType t)
{
return t.product.elements[0].algebraicType;
}
private static object GetPrimaryKeyValue(object row) => ((TestE)row).Id;
public delegate void InsertEventHandler(TestE insertedValue, SpacetimeDB.ReducerEvent dbEvent);
public delegate void UpdateEventHandler(TestE oldValue, TestE newValue, SpacetimeDB.ReducerEvent dbEvent);
@@ -1223,6 +1174,7 @@ namespace SpacetimeDB
using System;
using System.Collections.Generic;
using System.Linq;
namespace SpacetimeDB
{
@@ -1261,34 +1213,26 @@ namespace SpacetimeDB
};
}
public static System.Collections.Generic.IEnumerable<_Private> Iter()
public static IEnumerable<_Private> Iter()
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("_Private"))
{
yield return (_Private)entry.Item2;
}
return SpacetimeDBClient.clientDB.GetObjects("_Private").Cast<_Private>();
}
public static IEnumerable<_Private> Query(Func<_Private, bool> filter)
{
return Iter().Where(filter);
}
public static int Count()
{
return SpacetimeDBClient.clientDB.Count("_Private");
}
public static System.Collections.Generic.IEnumerable<_Private> FilterByName(string value)
public static IEnumerable<_Private> FilterByName(string value)
{
foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("_Private"))
{
var productValue = entry.Item1.AsProductValue();
var compareValue = (string)productValue.elements[0].AsString();
if (compareValue == value)
{
yield return (_Private)entry.Item2;
}
}
return Query(x => x.Name == value);
}
public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2)
{
return false;
}
public delegate void InsertEventHandler(_Private insertedValue, SpacetimeDB.ReducerEvent dbEvent);
public delegate void DeleteEventHandler(_Private deletedValue, SpacetimeDB.ReducerEvent dbEvent);