Skip to content

Commit

Permalink
fix+feat: cysql fixes - BED-5141, BED-5142, BED-5307, BED-5173, BED-5…
Browse files Browse the repository at this point in the history
…193, BED-5284, BED-5263 (#1112)
  • Loading branch information
zinic authored Feb 3, 2025
1 parent 22397aa commit 9847889
Show file tree
Hide file tree
Showing 36 changed files with 1,664 additions and 861 deletions.
41 changes: 30 additions & 11 deletions packages/go/cypher/frontend/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,44 +353,63 @@ func (s *SinglePartQueryVisitor) ExitOC_UpdatingClause(ctx *parser.OC_UpdatingCl
type MultiPartQueryVisitor struct {
BaseVisitor

Query *cypher.MultiPartQuery
Query *cypher.MultiPartQuery
partIdx int
}

func NewMultiPartQueryVisitor() *MultiPartQueryVisitor {
return &MultiPartQueryVisitor{
Query: cypher.NewMultiPartQuery(),
Query: cypher.NewMultiPartQuery(),
partIdx: 0,
}
}

func (s *MultiPartQueryVisitor) EnterOC_ReadingClause(ctx *parser.OC_ReadingClauseContext) {
// If the part index is equal to the length of parts then this signifies that a new query part
// is required. We do not advance the index here - this is done with the following `with`
// cypher AST component
if len(s.Query.Parts) == s.partIdx {
s.Query.Parts = append(s.Query.Parts, cypher.NewMultiPartQueryPart())
}

s.ctx.Enter(NewReadingClauseVisitor())
}

func (s *MultiPartQueryVisitor) ExitOC_ReadingClause(ctx *parser.OC_ReadingClauseContext) {
part := cypher.NewMultiPartQueryPart()
part.AddReadingClause(s.ctx.Exit().(*ReadingClauseVisitor).ReadingClause)
s.Query.Parts = append(s.Query.Parts, part)
s.Query.CurrentPart().AddReadingClause(s.ctx.Exit().(*ReadingClauseVisitor).ReadingClause)
}

func (s *MultiPartQueryVisitor) EnterOC_UpdatingClause(ctx *parser.OC_UpdatingClauseContext) {
if len(s.Query.Parts) == s.partIdx {
s.Query.Parts = append(s.Query.Parts, cypher.NewMultiPartQueryPart())
}

s.ctx.Enter(NewUpdatingClauseVisitor())
}

func (s *MultiPartQueryVisitor) ExitOC_UpdatingClause(ctx *parser.OC_UpdatingClauseContext) {
// Make sure to mark that this multipart query part contains a mutation (non-read operation). This
// field is being set to make it easier for downstream consumers of the openCypher AST to identify
// if this query contains a mutation.
s.ctx.HasMutation = true
part := cypher.NewMultiPartQueryPart()
part.AddUpdatingClause(s.ctx.Exit().(*UpdatingClauseVisitor).UpdatingClause)
s.Query.Parts = append(s.Query.Parts, part)

s.Query.CurrentPart().AddUpdatingClause(s.ctx.Exit().(*UpdatingClauseVisitor).UpdatingClause)
}

func (s *MultiPartQueryVisitor) EnterOC_With(ctx *parser.OC_WithContext) {
if len(s.Query.Parts) == s.partIdx {
s.Query.Parts = append(s.Query.Parts, cypher.NewMultiPartQueryPart())
}

s.ctx.Enter(NewWithVisitor())
}

func (s *MultiPartQueryVisitor) ExitOC_With(ctx *parser.OC_WithContext) {
part := cypher.NewMultiPartQueryPart()
part.With = s.ctx.Exit().(*WithVisitor).With
s.Query.Parts = append(s.Query.Parts, part)
s.Query.CurrentPart().With = s.ctx.Exit().(*WithVisitor).With

// Advance the part index so a new multipart query part gets allocated for the next reading
// or updating clause
s.partIdx += 1
}

func (s *MultiPartQueryVisitor) EnterOC_SinglePartQuery(ctx *parser.OC_SinglePartQueryContext) {
Expand Down
1 change: 1 addition & 0 deletions packages/go/cypher/models/cypher/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const (
ToIntegerFunction = "toint"
ListSizeFunction = "size"
CoalesceFunction = "coalesce"
CollectFunction = "collect"

// ITTC - Instant Type; Temporal Component (https://neo4j.com/docs/cypher-manual/current/functions/temporal/)
ITTCYear = "year"
Expand Down
4 changes: 4 additions & 0 deletions packages/go/cypher/models/cypher/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ type MultiPartQuery struct {
SinglePartQuery *SinglePartQuery
}

func (s *MultiPartQuery) CurrentPart() *MultiPartQueryPart {
return s.Parts[len(s.Parts)-1]
}

func NewMultiPartQuery() *MultiPartQuery {
return &MultiPartQuery{}
}
Expand Down
5 changes: 4 additions & 1 deletion packages/go/cypher/models/pgsql/format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error {
}

case pgsql.ExistsExpression:
exprStack = append(exprStack, pgsql.FormattingLiteral(")"), typedNextExpr.Subquery, pgsql.FormattingLiteral("exists ("))
exprStack = append(exprStack, typedNextExpr.Subquery, pgsql.FormattingLiteral("exists "))

if typedNextExpr.Negated {
exprStack = append(exprStack, pgsql.FormattingLiteral("not "))
Expand All @@ -517,6 +517,9 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error {
}
}

case pgsql.Subquery:
exprStack = append(exprStack, pgsql.FormattingLiteral(")"), typedNextExpr.Query, pgsql.FormattingLiteral("("))

default:
return fmt.Errorf("unable to format pgsql node type: %T", nextExpr)
}
Expand Down
10 changes: 10 additions & 0 deletions packages/go/cypher/models/pgsql/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,13 @@ const (
FunctionEdgesToPath Identifier = "edges_to_path"
FunctionExtract Identifier = "extract"
)

func IsAggregateFunction(function Identifier) bool {
switch function {
case FunctionCount, FunctionArrayAggregate:
return true

default:
return false
}
}
4 changes: 4 additions & 0 deletions packages/go/cypher/models/pgsql/identifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ func AsIdentifierSet(identifiers ...Identifier) *IdentifierSet {
return newSet
}

func (s *IdentifierSet) Clear() {
clear(s.identifiers)
}

func (s *IdentifierSet) Len() int {
return len(s.identifiers)
}
Expand Down
27 changes: 13 additions & 14 deletions packages/go/cypher/models/pgsql/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ type Subquery struct {
Query Query
}

func (s Subquery) NodeType() string {
return "subquery"
}

func (s Subquery) AsExpression() Expression {
return s
}

// not <expr>
type UnaryExpression struct {
Operator Expression
Expand Down Expand Up @@ -321,6 +329,10 @@ type Parenthetical struct {
Expression Expression
}

func (s Parenthetical) AsSelectItem() SelectItem {
return s
}

func (s Parenthetical) NodeType() string {
return "parenthetical"
}
Expand Down Expand Up @@ -1022,7 +1034,6 @@ func (s Select) NodeType() string {
// select 1
// union
// select 2;

type SetOperation struct {
Operator Operator
LOperand SetExpression
Expand All @@ -1047,7 +1058,7 @@ func (s SetOperation) NodeType() string {
//
// [not] exists(<query>)
type ExistsExpression struct {
Subquery Query
Subquery Subquery
Negated bool
}

Expand Down Expand Up @@ -1132,18 +1143,6 @@ func (s Query) NodeType() string {
return "query"
}

func BinaryExpressionJoinTyped(optional Expression, operator Operator, conjoined *BinaryExpression) *BinaryExpression {
if optional == nil {
return conjoined
}

return NewBinaryExpression(
conjoined,
operator,
optional,
)
}

func BinaryExpressionJoin(optional Expression, operator Operator, conjoined Expression) Expression {
if optional == nil {
return conjoined
Expand Down
1 change: 1 addition & 0 deletions packages/go/cypher/models/pgsql/pgtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ const (
TimestampWithoutTimeZone DataType = "timestamp without time zone"

Scope DataType = "scope"
InlineProjection DataType = "inline_projection"
ParameterIdentifier DataType = "parameter_identifier"
ExpansionPattern DataType = "expansion_pattern"
ExpansionPath DataType = "expansion_path"
Expand Down
Loading

0 comments on commit 9847889

Please sign in to comment.