diff --git a/builder/buffer.go b/builder/buffer.go index e807a62..53138d3 100644 --- a/builder/buffer.go +++ b/builder/buffer.go @@ -120,6 +120,19 @@ func (b *Buffer) WriteEscape(value string) { b.WriteString(b.escape("", value)) } +func (b Buffer) escapeSchema(table string) string { + if b.AllowTableSchema && strings.IndexByte(table, '.') >= 0 { + parts := strings.Split(table, ".") + for i, part := range parts { + part = strings.TrimSpace(part) + parts[i] = b.Quoter.ID(part) + } + return strings.Join(parts, ".") + } else { + return b.Quoter.ID(strings.ReplaceAll(strings.TrimSpace(table), ".", "_")) + } +} + func (b Buffer) escape(table, value string) string { if table == "" && value == "*" { return value @@ -134,17 +147,9 @@ func (b Buffer) escape(table, value string) string { var escaped_table string if table != "" { if i := strings.Index(strings.ToLower(table), " as "); i > -1 { - return b.escape(table[:i], "") + " AS " + b.Quoter.ID(table[i+4:]) - } - if b.AllowTableSchema && strings.IndexByte(table, '.') >= 0 { - parts := strings.Split(table, ".") - for i, part := range parts { - part = strings.TrimSpace(part) - parts[i] = b.Quoter.ID(part) - } - escaped_table = strings.Join(parts, ".") + escaped_table = b.escapeSchema(table[:i]) + " AS " + b.Quoter.ID(strings.TrimSpace(table[i+4:])) } else { - escaped_table = b.Quoter.ID(strings.ReplaceAll(table, ".", "_")) + escaped_table = b.escapeSchema(table) } } diff --git a/builder/query_test.go b/builder/query_test.go index 3c33b18..2dfb45b 100644 --- a/builder/query_test.go +++ b/builder/query_test.go @@ -105,6 +105,10 @@ func TestQuery_Build(t *testing.T) { result: "SELECT `users`.* FROM `users` FOR UPDATE;", query: rel.From("users").Lock("FOR UPDATE"), }, + { + result: "SELECT `c`.`id`,`c`.`name` FROM `contacts` AS `c`;", + query: rel.Select("c.id", "c.name").From("contacts as c"), + }, } for _, test := range tests { @@ -185,6 +189,10 @@ func TestQuery_Build_ordinal(t *testing.T) { result: "SELECT \"users\".* FROM \"users\" FOR UPDATE;", query: rel.From("users").Lock("FOR UPDATE"), }, + { + result: "SELECT \"c\".\"id\",\"c\".\"name\" FROM \"contacts\" AS \"c\";", + query: rel.Select("c.id", "c.name").From("contacts as c"), + }, } for _, test := range tests {