diff --git a/crates/cli/src/subcommands/generate/csharp.rs b/crates/cli/src/subcommands/generate/csharp.rs index 4ae655345e..efaf2349ea 100644 --- a/crates/cli/src/subcommands/generate/csharp.rs +++ b/crates/cli/src/subcommands/generate/csharp.rs @@ -570,7 +570,7 @@ fn autogen_csharp_product_table_common( schema: Option, namespace: &str, ) -> String { - let mut output = CsharpAutogen::new(namespace, &["System.Collections.Generic"]); + let mut output = CsharpAutogen::new(namespace, &["System.Collections.Generic", "System.Linq"]); writeln!( output, @@ -634,15 +634,10 @@ fn autogen_csharp_product_table_common( continue; } let type_name = ty_fmt(ctx, &col.col_type, namespace); - let comparer = if format!("{type_name}") == "byte[]" { - ", new SpacetimeDB.ByteArrayComparer()" - } else { - "" - }; writeln!( - output, - "private static Dictionary<{type_name}, {name}> {field_name}_Index = new Dictionary<{type_name}, {name}>(16{comparer});" - ); + output, + "private static Dictionary<{type_name}, {name}> {field_name}_Index = new(16);" + ); } writeln!(output); // OnInsert method for updating indexes @@ -832,25 +827,30 @@ fn autogen_csharp_access_funcs_for_struct( ) -> bool { let primary_col_idx = schema.pk(); - writeln!( - output, - "public static System.Collections.Generic.IEnumerable<{struct_name_pascal_case}> Iter()" - ); + writeln!(output, "public static IEnumerable<{struct_name_pascal_case}> Iter()"); indented_block(output, |output| { writeln!( output, - "foreach(var entry in SpacetimeDBClient.clientDB.GetEntries(\"{table_name}\"))", + "return SpacetimeDBClient.clientDB.GetObjects(\"{table_name}\").Cast<{struct_name_pascal_case}>();", ); - indented_block(output, |output| { - // TODO: best way to handle this? - writeln!(output, "yield return ({struct_name_pascal_case})entry.Item2;"); - }); }); + writeln!(output); + + // Simple alias for Iter().Where(...) for API parity with C# server-side modules. + writeln!( + output, + "public static IEnumerable<{struct_name_pascal_case}> Query(Func<{struct_name_pascal_case}, bool> filter)" + ); + indented_block(output, |output| { + writeln!(output, "return Iter().Where(filter);",); + }); + writeln!(output); writeln!(output, "public static int Count()"); indented_block(output, |output| { writeln!(output, "return SpacetimeDBClient.clientDB.Count(\"{table_name}\");"); }); + writeln!(output); let constraints = schema.column_constraints(); for col in schema.columns() { @@ -863,188 +863,58 @@ fn autogen_csharp_access_funcs_for_struct( let field_type = &field.algebraic_type; let csharp_field_name_pascal = field_name.replace("r#", "").to_case(Case::Pascal); - let (field_type, csharp_field_type, is_option) = match field_type { + let csharp_field_type = match field_type { AlgebraicType::Product(product) => { if product.is_identity() { - ("Identity".into(), "SpacetimeDB.Identity".into(), false) + "SpacetimeDB.Identity" } else if product.is_address() { - ("Address".into(), "SpacetimeDB.Address".into(), false) - } else { - // TODO: We don't allow filtering on tuples right now, - // it's possible we may consider it for the future. - continue; - } - } - AlgebraicType::Sum(sum) => { - if let Some(Builtin(b)) = sum.as_option() { - match maybe_primitive(b) { - MaybePrimitive::Primitive(ty) => (format!("{b:?}"), format!("{ty}?"), true), - _ => { - continue; - } - } + "SpacetimeDB.Address" } else { continue; } } - AlgebraicType::Ref(_) => { - // TODO: We don't allow filtering on enums or tuples right now; - // it's possible we may consider it for the future. - continue; - } + AlgebraicType::Sum(_) | AlgebraicType::Ref(_) => continue, AlgebraicType::Builtin(b) => match maybe_primitive(b) { - MaybePrimitive::Primitive(ty) => (format!("{b:?}"), ty.into(), false), - MaybePrimitive::Array(ArrayType { elem_ty }) => { - if let Some(BuiltinType::U8) = elem_ty.as_builtin() { - // Do allow filtering for byte arrays - ("Bytes".into(), "byte[]".into(), false) - } else { - // TODO: We don't allow filtering based on an array type, but we might want other functionality here in the future. - continue; - } - } - MaybePrimitive::Map(_) => { - // TODO: It would be nice to be able to say, give me all entries where this vec contains this value, which we can do. - continue; - } + MaybePrimitive::Primitive(ty) => ty, + _ => continue, }, }; - let filter_return_type = fmt_fn(|f| { - if is_unique { - f.write_str(struct_name_pascal_case) - } else { - write!(f, "System.Collections.Generic.IEnumerable<{struct_name_pascal_case}>") - } - }); - - writeln!( - output, - "public static {filter_return_type} FilterBy{csharp_field_name_pascal}({csharp_field_type} value)" - ); - - indented_block(output, |output| { - if is_unique { + if is_unique { + writeln!( + output, + "public static {struct_name_pascal_case} FindBy{csharp_field_name_pascal}({csharp_field_type} value)" + ); + indented_block(output, |output| { writeln!( output, "{csharp_field_name_pascal}_Index.TryGetValue(value, out var r);" ); writeln!(output, "return r;"); + }); + writeln!(output); + } + + writeln!( + output, + "public static IEnumerable<{struct_name_pascal_case}> FilterBy{csharp_field_name_pascal}({csharp_field_type} value)" + ); + indented_block(output, |output| { + if is_unique { + writeln!(output, "return new[] {{ FindBy{csharp_field_name_pascal}(value) }};"); } else { - writeln!( - output, - "foreach(var entry in SpacetimeDBClient.clientDB.GetEntries(\"{table_name}\"))" - ); - indented_block(output, |output| { - writeln!(output, "var productValue = entry.Item1.AsProductValue();"); - if field_type == "Identity" { - writeln!( - output, - "var compareValue = Identity.From(productValue.elements[{col_i}].AsProductValue().elements[0].AsBytes());" - ); - } else if is_option { - writeln!( - output, - "var compareValue = ({csharp_field_type})(productValue.elements[{col_i}].AsSumValue().tag == 1 ? null : productValue.elements[{col_i}].AsSumValue().value.As{field_type}());" - ); - } else if field_type == "Address" { - writeln!( - output, - "var compareValue = (Address)Address.From(productValue.elements[{col_i}].AsProductValue().elements[0].AsBytes());" - ); - } else { - writeln!( - output, - "var compareValue = ({csharp_field_type})productValue.elements[{col_i}].As{field_type}();" - ); - } - if csharp_field_type == "byte[]" { - writeln!( - output, - "static bool ByteArrayCompare(byte[] a1, byte[] a2) - {{ - if (a1.Length != a2.Length) - return false; - - for (int i=0; i x.{csharp_field_name_pascal} == value);"); } }); writeln!(output); } - if let Some(primary_col_index) = primary_col_idx { + if let Some(primary_col_index) = schema.pk() { writeln!( output, - "public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue v1, SpacetimeDB.SATS.AlgebraicValue v2)" + "private static object GetPrimaryKeyValue(object row) => (({struct_name_pascal_case})row).{col_name_pascal_case};", + col_name_pascal_case = primary_col_index.col_name.replace("r#", "").to_case(Case::Pascal) ); - indented_block(output, |output| { - writeln!( - output, - "var primaryColumnValue1 = v1.AsProductValue().elements[{}];", - primary_col_index.col_pos - ); - writeln!( - output, - "var primaryColumnValue2 = v2.AsProductValue().elements[{}];", - primary_col_index.col_pos - ); - writeln!( - output, - "return SpacetimeDB.SATS.AlgebraicValue.Compare(t.product.elements[0].algebraicType, primaryColumnValue1, primaryColumnValue2);" - ); - }); - writeln!(output); - - writeln!( - output, - "public static SpacetimeDB.SATS.AlgebraicValue GetPrimaryKeyValue(SpacetimeDB.SATS.AlgebraicValue v)" - ); - indented_block(output, |output| { - writeln!( - output, - "return v.AsProductValue().elements[{}];", - primary_col_index.col_pos - ); - }); - writeln!(output); - - writeln!( - output, - "public static SpacetimeDB.SATS.AlgebraicType GetPrimaryKeyType(SpacetimeDB.SATS.AlgebraicType t)" - ); - indented_block(output, |output| { - writeln!( - output, - "return t.product.elements[{}].algebraicType;", - primary_col_index.col_pos - ); - }); - } else { - writeln!( - output, - "public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2)" - ); - indented_block(output, |output| { - writeln!(output, "return false;"); - }); } primary_col_idx.is_some() diff --git a/crates/cli/tests/snapshots/codegen__codegen_csharp.snap b/crates/cli/tests/snapshots/codegen__codegen_csharp.snap index b60e0ac399..cc2d82e8b3 100644 --- a/crates/cli/tests/snapshots/codegen__codegen_csharp.snap +++ b/crates/cli/tests/snapshots/codegen__codegen_csharp.snap @@ -302,6 +302,7 @@ namespace SpacetimeDB using System; using System.Collections.Generic; +using System.Linq; namespace SpacetimeDB { @@ -313,8 +314,8 @@ namespace SpacetimeDB [Newtonsoft.Json.JsonProperty("other")] public uint Other; - private static Dictionary Id_Index = new Dictionary(16); - private static Dictionary Other_Index = new Dictionary(16); + private static Dictionary Id_Index = new(16); + private static Dictionary Other_Index = new(16); private static void InternalOnValueInserted(object insertedValue) { @@ -350,45 +351,44 @@ namespace SpacetimeDB }; } - public static System.Collections.Generic.IEnumerable Iter() + public static IEnumerable Iter() { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("PkMultiIdentity")) - { - yield return (PkMultiIdentity)entry.Item2; - } + return SpacetimeDBClient.clientDB.GetObjects("PkMultiIdentity").Cast(); } + + public static IEnumerable Query(Func filter) + { + return Iter().Where(filter); + } + public static int Count() { return SpacetimeDBClient.clientDB.Count("PkMultiIdentity"); } - public static PkMultiIdentity FilterById(uint value) + + public static PkMultiIdentity FindById(uint value) { Id_Index.TryGetValue(value, out var r); return r; } - public static PkMultiIdentity FilterByOther(uint value) + public static IEnumerable FilterById(uint value) + { + return new[] { FindById(value) }; + } + + public static PkMultiIdentity FindByOther(uint value) { Other_Index.TryGetValue(value, out var r); return r; } - public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue v1, SpacetimeDB.SATS.AlgebraicValue v2) + public static IEnumerable FilterByOther(uint value) { - var primaryColumnValue1 = v1.AsProductValue().elements[0]; - var primaryColumnValue2 = v2.AsProductValue().elements[0]; - return SpacetimeDB.SATS.AlgebraicValue.Compare(t.product.elements[0].algebraicType, primaryColumnValue1, primaryColumnValue2); + return new[] { FindByOther(value) }; } - public static SpacetimeDB.SATS.AlgebraicValue GetPrimaryKeyValue(SpacetimeDB.SATS.AlgebraicValue v) - { - return v.AsProductValue().elements[0]; - } - - public static SpacetimeDB.SATS.AlgebraicType GetPrimaryKeyType(SpacetimeDB.SATS.AlgebraicType t) - { - return t.product.elements[0].algebraicType; - } + private static object GetPrimaryKeyValue(object row) => ((PkMultiIdentity)row).Id; public delegate void InsertEventHandler(PkMultiIdentity insertedValue, SpacetimeDB.ReducerEvent dbEvent); public delegate void UpdateEventHandler(PkMultiIdentity oldValue, PkMultiIdentity newValue, SpacetimeDB.ReducerEvent dbEvent); @@ -426,6 +426,7 @@ namespace SpacetimeDB using System; using System.Collections.Generic; +using System.Linq; namespace SpacetimeDB { @@ -468,47 +469,31 @@ namespace SpacetimeDB }; } - public static System.Collections.Generic.IEnumerable Iter() + public static IEnumerable Iter() { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("Point")) - { - yield return (Point)entry.Item2; - } + return SpacetimeDBClient.clientDB.GetObjects("Point").Cast(); } + + public static IEnumerable Query(Func filter) + { + return Iter().Where(filter); + } + public static int Count() { return SpacetimeDBClient.clientDB.Count("Point"); } - public static System.Collections.Generic.IEnumerable FilterByX(long value) + + public static IEnumerable FilterByX(long value) { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("Point")) - { - var productValue = entry.Item1.AsProductValue(); - var compareValue = (long)productValue.elements[0].AsI64(); - if (compareValue == value) - { - yield return (Point)entry.Item2; - } - } + return Query(x => x.X == value); } - public static System.Collections.Generic.IEnumerable FilterByY(long value) + public static IEnumerable FilterByY(long value) { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("Point")) - { - var productValue = entry.Item1.AsProductValue(); - var compareValue = (long)productValue.elements[1].AsI64(); - if (compareValue == value) - { - yield return (Point)entry.Item2; - } - } + return Query(x => x.Y == value); } - public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2) - { - return false; - } public delegate void InsertEventHandler(Point insertedValue, SpacetimeDB.ReducerEvent dbEvent); public delegate void DeleteEventHandler(Point deletedValue, SpacetimeDB.ReducerEvent dbEvent); @@ -769,6 +754,7 @@ namespace SpacetimeDB using System; using System.Collections.Generic; +using System.Linq; namespace SpacetimeDB { @@ -815,60 +801,36 @@ namespace SpacetimeDB }; } - public static System.Collections.Generic.IEnumerable Iter() + public static IEnumerable Iter() { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestA")) - { - yield return (TestA)entry.Item2; - } + return SpacetimeDBClient.clientDB.GetObjects("TestA").Cast(); } + + public static IEnumerable Query(Func filter) + { + return Iter().Where(filter); + } + public static int Count() { return SpacetimeDBClient.clientDB.Count("TestA"); } - public static System.Collections.Generic.IEnumerable FilterByX(uint value) + + public static IEnumerable FilterByX(uint value) { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestA")) - { - var productValue = entry.Item1.AsProductValue(); - var compareValue = (uint)productValue.elements[0].AsU32(); - if (compareValue == value) - { - yield return (TestA)entry.Item2; - } - } + return Query(x => x.X == value); } - public static System.Collections.Generic.IEnumerable FilterByY(uint value) + public static IEnumerable FilterByY(uint value) { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestA")) - { - var productValue = entry.Item1.AsProductValue(); - var compareValue = (uint)productValue.elements[1].AsU32(); - if (compareValue == value) - { - yield return (TestA)entry.Item2; - } - } + return Query(x => x.Y == value); } - public static System.Collections.Generic.IEnumerable FilterByZ(string value) + public static IEnumerable FilterByZ(string value) { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestA")) - { - var productValue = entry.Item1.AsProductValue(); - var compareValue = (string)productValue.elements[2].AsString(); - if (compareValue == value) - { - yield return (TestA)entry.Item2; - } - } + return Query(x => x.Z == value); } - public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2) - { - return false; - } public delegate void InsertEventHandler(TestA insertedValue, SpacetimeDB.ReducerEvent dbEvent); public delegate void DeleteEventHandler(TestA deletedValue, SpacetimeDB.ReducerEvent dbEvent); @@ -899,6 +861,7 @@ namespace SpacetimeDB using System; using System.Collections.Generic; +using System.Linq; namespace SpacetimeDB { @@ -935,6 +898,7 @@ namespace SpacetimeDB using System; using System.Collections.Generic; +using System.Linq; namespace SpacetimeDB { @@ -980,21 +944,21 @@ namespace SpacetimeDB }; } - public static System.Collections.Generic.IEnumerable Iter() + public static IEnumerable Iter() { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestD")) - { - yield return (TestD)entry.Item2; - } + return SpacetimeDBClient.clientDB.GetObjects("TestD").Cast(); } + + public static IEnumerable Query(Func filter) + { + return Iter().Where(filter); + } + public static int Count() { return SpacetimeDBClient.clientDB.Count("TestD"); } - public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2) - { - return false; - } + public delegate void InsertEventHandler(TestD insertedValue, SpacetimeDB.ReducerEvent dbEvent); public delegate void DeleteEventHandler(TestD deletedValue, SpacetimeDB.ReducerEvent dbEvent); @@ -1025,6 +989,7 @@ namespace SpacetimeDB using System; using System.Collections.Generic; +using System.Linq; namespace SpacetimeDB { @@ -1036,7 +1001,7 @@ namespace SpacetimeDB [Newtonsoft.Json.JsonProperty("name")] public string Name; - private static Dictionary Id_Index = new Dictionary(16); + private static Dictionary Id_Index = new(16); private static void InternalOnValueInserted(object insertedValue) { @@ -1070,52 +1035,38 @@ namespace SpacetimeDB }; } - public static System.Collections.Generic.IEnumerable Iter() + public static IEnumerable Iter() { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestE")) - { - yield return (TestE)entry.Item2; - } + return SpacetimeDBClient.clientDB.GetObjects("TestE").Cast(); } + + public static IEnumerable Query(Func filter) + { + return Iter().Where(filter); + } + public static int Count() { return SpacetimeDBClient.clientDB.Count("TestE"); } - public static TestE FilterById(ulong value) + + public static TestE FindById(ulong value) { Id_Index.TryGetValue(value, out var r); return r; } - public static System.Collections.Generic.IEnumerable FilterByName(string value) + public static IEnumerable FilterById(ulong value) { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("TestE")) - { - var productValue = entry.Item1.AsProductValue(); - var compareValue = (string)productValue.elements[1].AsString(); - if (compareValue == value) - { - yield return (TestE)entry.Item2; - } - } + return new[] { FindById(value) }; } - public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue v1, SpacetimeDB.SATS.AlgebraicValue v2) + public static IEnumerable FilterByName(string value) { - var primaryColumnValue1 = v1.AsProductValue().elements[0]; - var primaryColumnValue2 = v2.AsProductValue().elements[0]; - return SpacetimeDB.SATS.AlgebraicValue.Compare(t.product.elements[0].algebraicType, primaryColumnValue1, primaryColumnValue2); + return Query(x => x.Name == value); } - public static SpacetimeDB.SATS.AlgebraicValue GetPrimaryKeyValue(SpacetimeDB.SATS.AlgebraicValue v) - { - return v.AsProductValue().elements[0]; - } - - public static SpacetimeDB.SATS.AlgebraicType GetPrimaryKeyType(SpacetimeDB.SATS.AlgebraicType t) - { - return t.product.elements[0].algebraicType; - } + private static object GetPrimaryKeyValue(object row) => ((TestE)row).Id; public delegate void InsertEventHandler(TestE insertedValue, SpacetimeDB.ReducerEvent dbEvent); public delegate void UpdateEventHandler(TestE oldValue, TestE newValue, SpacetimeDB.ReducerEvent dbEvent); @@ -1223,6 +1174,7 @@ namespace SpacetimeDB using System; using System.Collections.Generic; +using System.Linq; namespace SpacetimeDB { @@ -1261,34 +1213,26 @@ namespace SpacetimeDB }; } - public static System.Collections.Generic.IEnumerable<_Private> Iter() + public static IEnumerable<_Private> Iter() { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("_Private")) - { - yield return (_Private)entry.Item2; - } + return SpacetimeDBClient.clientDB.GetObjects("_Private").Cast<_Private>(); } + + public static IEnumerable<_Private> Query(Func<_Private, bool> filter) + { + return Iter().Where(filter); + } + public static int Count() { return SpacetimeDBClient.clientDB.Count("_Private"); } - public static System.Collections.Generic.IEnumerable<_Private> FilterByName(string value) + + public static IEnumerable<_Private> FilterByName(string value) { - foreach(var entry in SpacetimeDBClient.clientDB.GetEntries("_Private")) - { - var productValue = entry.Item1.AsProductValue(); - var compareValue = (string)productValue.elements[0].AsString(); - if (compareValue == value) - { - yield return (_Private)entry.Item2; - } - } + return Query(x => x.Name == value); } - public static bool ComparePrimaryKey(SpacetimeDB.SATS.AlgebraicType t, SpacetimeDB.SATS.AlgebraicValue _v1, SpacetimeDB.SATS.AlgebraicValue _v2) - { - return false; - } public delegate void InsertEventHandler(_Private insertedValue, SpacetimeDB.ReducerEvent dbEvent); public delegate void DeleteEventHandler(_Private deletedValue, SpacetimeDB.ReducerEvent dbEvent);