diff --git a/docs/modules/components/pages/outputs/snowflake_streaming.adoc b/docs/modules/components/pages/outputs/snowflake_streaming.adoc new file mode 100644 index 0000000000..b4d8a83397 --- /dev/null +++ b/docs/modules/components/pages/outputs/snowflake_streaming.adoc @@ -0,0 +1,357 @@ += snowflake_streaming +:type: output +:status: experimental +:categories: ["Services"] + + + +//// + THIS FILE IS AUTOGENERATED! + + To make changes, edit the corresponding source file under: + + https://github.com/redpanda-data/connect/tree/main/internal/impl/. + + And: + + https://github.com/redpanda-data/connect/tree/main/cmd/tools/docs_gen/templates/plugin.adoc.tmpl +//// + +// © 2024 Redpanda Data Inc. + + +component_type_dropdown::[] + + +Ingest data into Snowflake using Snowpipe Streaming. + +Introduced in version 4.39.0. + + +[tabs] +====== +Common:: ++ +-- + +```yml +# Common config fields, showing default values +output: + label: "" + snowflake_streaming: + account: AAAAAAA-AAAAAAA # No default (required) + user: "" # No default (required) + role: ACCOUNTADMIN # No default (required) + database: "" # No default (required) + schema: "" # No default (required) + table: "" # No default (required) + private_key: "" # No default (optional) + private_key_file: "" # No default (optional) + private_key_pass: "" # No default (optional) + mapping: "" # No default (optional) + batching: + count: 0 + byte_size: 0 + period: "" + check: "" + max_in_flight: 64 +``` + +-- +Advanced:: ++ +-- + +```yml +# All config fields, showing default values +output: + label: "" + snowflake_streaming: + account: AAAAAAA-AAAAAAA # No default (required) + user: "" # No default (required) + role: ACCOUNTADMIN # No default (required) + database: "" # No default (required) + schema: "" # No default (required) + table: "" # No default (required) + private_key: "" # No default (optional) + private_key_file: "" # No default (optional) + private_key_pass: "" # No default (optional) + mapping: "" # No default (optional) + batching: + count: 0 + byte_size: 0 + period: "" + check: "" + processors: [] # No default (optional) + max_in_flight: 64 + channel_prefix: "" # No default (optional) +``` + +-- +====== + +Ingest data into Snowflake using Snowpipe Streaming. + +[%header,format=dsv] +|=== +Snowflake column type:Allowed format in Redpanda Connect +CHAR, VARCHAR:string +BINARY:[]byte +NUMBER:any numeric type, string +FLOAT:any numeric type +BOOLEAN:bool,any numeric type,string parsable according to `strconv.ParseBool` +TIME,DATE,TIMESTAMP:unix or RFC 3339 with nanoseconds timestamps +VARIANT,ARRAY,OBJECT:any data type is converted into JSON +GEOGRAPHY,GEOMETRY: Not supported +|=== + +For TIMESTAMP, TIME and DATE columns, you can parse different string formats using a bloblang `mapping`. + +Authentication can be configured using a https://docs.snowflake.com/en/user-guide/key-pair-auth[RSA Key Pair^]. + +There are https://docs.snowflake.com/en/user-guide/data-load-snowpipe-streaming-overview#limitations[limitations^] of what data types can be loaded into Snowflake using this method. + + +== Performance + +This output benefits from sending multiple messages in flight in parallel for improved performance. You can tune the max number of in flight messages (or message batches) with the field `max_in_flight`. + +This output benefits from sending messages as a batch for improved performance. Batches can be formed at both the input and output level. You can find out more xref:configuration:batching.adoc[in this doc]. 
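+
+As a rough sketch, a batching policy along the lines shown below flushes a batch whenever the first of its three triggers fires. The values are purely illustrative and should be tuned for your workload:
+
+```yml
+output:
+  snowflake_streaming:
+    # ... connection fields omitted ...
+    batching:
+      count: 50000        # flush after 50k messages
+      byte_size: 33554432 # or after 32MiB of buffered input
+      period: 10s         # or after 10 seconds, whichever comes first
+    max_in_flight: 64
+```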
+ +It is recommended that each batches results in at least 16MiB of compressed output being written to Snowflake. +You can monitor the output batch size using the `snowflake_compressed_output_size_bytes` metric. + + +== Fields + +=== `account` + +Account name, which is the same as the https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#where-are-account-identifiers-used[Account Identifier^]. + However, when using an https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#using-an-account-locator-as-an-identifier[Account Locator^], + the Account Identifier is formatted as `..` and this field needs to be + populated using the `` part. + + +*Type*: `string` + + +```yml +# Examples + +account: AAAAAAA-AAAAAAA +``` + +=== `user` + +The user to run the Snowpipe Stream as. See https://docs.snowflake.com/en/user-guide/admin-user-management[Snowflake Documentation^] on how to create a user. + + +*Type*: `string` + + +=== `role` + +The role for the `user` field. The role must have the https://docs.snowflake.com/en/user-guide/data-load-snowpipe-streaming-overview#required-access-privileges[required privileges^] to call the Snowpipe Streaming APIs. See https://docs.snowflake.com/en/user-guide/admin-user-management#user-roles[Snowflake Documentation^] for more information about roles. + + +*Type*: `string` + + +```yml +# Examples + +role: ACCOUNTADMIN +``` + +=== `database` + +The Snowflake database to ingest data into. + + +*Type*: `string` + + +=== `schema` + +The Snowflake schema to ingest data into. + + +*Type*: `string` + + +=== `table` + +The Snowflake table to ingest data into. + + +*Type*: `string` + + +=== `private_key` + +The PEM encoded private RSA key to use for authenticating with Snowflake. Either this or `private_key_file` must be specified. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + + +=== `private_key_file` + +The file to load the private RSA key from. This should be a `.p8` PEM encoded file. Either this or `private_key` must be specified. + + +*Type*: `string` + + +=== `private_key_pass` + +The RSA key passphrase if the RSA key is encrypted. +[CAUTION] +==== +This field contains sensitive information that usually shouldn't be added to a config directly, read our xref:configuration:secrets.adoc[secrets page for more info]. +==== + + + +*Type*: `string` + + +=== `mapping` + +A bloblang mapping to execute on each message. + + +*Type*: `string` + + +=== `batching` + +Allows you to configure a xref:configuration:batching.adoc[batching policy]. + + +*Type*: `object` + + +```yml +# Examples + +batching: + byte_size: 5000 + count: 0 + period: 1s + +batching: + count: 10 + period: 1s + +batching: + check: this.contains("END BATCH") + count: 0 + period: 1m +``` + +=== `batching.count` + +A number of messages at which the batch should be flushed. If `0` disables count based batching. + + +*Type*: `int` + +*Default*: `0` + +=== `batching.byte_size` + +An amount of bytes at which the batch should be flushed. If `0` disables size based batching. + + +*Type*: `int` + +*Default*: `0` + +=== `batching.period` + +A period in which an incomplete batch should be flushed regardless of its size. 
+ + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +period: 1s + +period: 1m + +period: 500ms +``` + +=== `batching.check` + +A xref:guides:bloblang/about.adoc[Bloblang query] that should return a boolean value indicating whether a message should end a batch. + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +check: this.type == "end_of_transaction" +``` + +=== `batching.processors` + +A list of xref:components:processors/about.adoc[processors] to apply to a batch as it is flushed. This allows you to aggregate and archive the batch however you see fit. Please note that all resulting messages are flushed as a single batch, therefore splitting the batch into smaller batches using these processors is a no-op. + + +*Type*: `array` + + +```yml +# Examples + +processors: + - archive: + format: concatenate + +processors: + - archive: + format: lines + +processors: + - archive: + format: json_array +``` + +=== `max_in_flight` + +The maximum number of messages to have in flight at a given time. Increase this to improve throughput. + + +*Type*: `int` + +*Default*: `64` + +=== `channel_prefix` + +The prefix to use when creating a channel name. +Duplicate channel names will result in errors and prevent multiple instances of Redpanda Connect from writing at the same time. +By default this will create a channel name that is based on the table FQN so there will only be a single stream per table. + +At most `max_in_flight` channels will be opened. + +NOTE: There is a limit of 10,000 streams per table - if using more than 10k streams please reach out to Snowflake support. + + +*Type*: `string` + + + diff --git a/go.mod b/go.mod index 6d6659c240..7ed5c978b1 100644 --- a/go.mod +++ b/go.mod @@ -65,6 +65,7 @@ require ( github.com/go-sql-driver/mysql v1.8.1 github.com/gocql/gocql v1.6.0 github.com/gofrs/uuid v4.4.0+incompatible + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-jwt/jwt/v5 v5.2.1 github.com/gosimple/slug v1.14.0 github.com/influxdata/influxdb1-client v0.0.0-20220302092344-a9ab5670611c @@ -133,11 +134,11 @@ require ( go.opentelemetry.io/otel/sdk v1.28.0 go.opentelemetry.io/otel/trace v1.28.0 go.uber.org/multierr v1.11.0 - golang.org/x/crypto v0.26.0 - golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa - golang.org/x/net v0.28.0 + golang.org/x/crypto v0.28.0 + golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c + golang.org/x/net v0.30.0 golang.org/x/sync v0.8.0 - golang.org/x/text v0.17.0 + golang.org/x/text v0.19.0 google.golang.org/api v0.188.0 google.golang.org/protobuf v1.34.2 modernc.org/sqlite v1.32.0 @@ -178,10 +179,10 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/OneOfOne/xxhash v1.2.8 // indirect - github.com/andybalholm/brotli v1.1.0 // indirect + github.com/andybalholm/brotli v1.1.1 // indirect github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect github.com/apache/arrow/go/v15 v15.0.2 // indirect - github.com/apache/thrift v0.18.1 // indirect + github.com/apache/thrift v0.21.0 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/ardielle/ardielle-go v1.5.2 // indirect github.com/armon/go-metrics v0.3.4 // indirect @@ -200,7 +201,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.15 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect - github.com/aws/smithy-go v1.20.4 // indirect + 
github.com/aws/smithy-go v1.20.4 github.com/aymerick/douceur v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bits-and-blooms/bitset v1.4.0 // indirect @@ -247,7 +248,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-sourcemap/sourcemap v2.1.4+incompatible // indirect - github.com/goccy/go-json v0.10.2 // indirect + github.com/goccy/go-json v0.10.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect @@ -259,7 +260,7 @@ require ( github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/google/uuid v1.6.0 // indirect + github.com/google/uuid v1.6.0 github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.5 // indirect github.com/gorilla/css v1.0.1 // indirect @@ -297,8 +298,8 @@ require ( github.com/jcmturner/rpc/v2 v2.0.3 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect - github.com/klauspost/compress v1.17.9 // indirect - github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/klauspost/compress v1.17.11 // indirect + github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/klauspost/pgzip v1.2.6 // indirect github.com/kr/fs v0.1.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect @@ -346,7 +347,7 @@ require ( github.com/robfig/cron/v3 v3.0.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/segmentio/asm v1.2.0 // indirect - github.com/segmentio/encoding v0.4.0 // indirect + github.com/segmentio/encoding v0.4.0 github.com/segmentio/ksuid v1.0.4 // indirect github.com/shirou/gopsutil/v3 v3.24.2 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect @@ -374,13 +375,13 @@ require ( go.opentelemetry.io/proto/otlp v1.3.1 // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect - golang.org/x/mod v0.20.0 // indirect - golang.org/x/oauth2 v0.22.0 // indirect - golang.org/x/sys v0.24.0 // indirect - golang.org/x/term v0.23.0 // indirect + golang.org/x/mod v0.21.0 // indirect + golang.org/x/oauth2 v0.22.0 + golang.org/x/sys v0.26.0 // indirect + golang.org/x/term v0.25.0 // indirect golang.org/x/time v0.6.0 // indirect - golang.org/x/tools v0.24.0 // indirect - golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect + golang.org/x/tools v0.26.0 // indirect + golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/genproto v0.0.0-20240708141625-4ad9e859172b // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240709173604-40e1e62336c5 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed // indirect diff --git a/go.sum b/go.sum index b7316a5711..48e9ab178b 100644 --- a/go.sum +++ b/go.sum @@ -163,8 +163,8 @@ github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuy github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/andybalholm/brotli 
v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= -github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/arrow v0.0.0-20200730104253-651201b0f516/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6ICHXqG5hm0ZW5IHyeEJXoIJSOZeBLmWPNeIQ= @@ -175,8 +175,8 @@ github.com/apache/pulsar-client-go v0.13.1 h1:XAAKXjF99du7LP6qu/nBII1HC2nS483/vQ github.com/apache/pulsar-client-go v0.13.1/go.mod h1:0X5UCs+Cv5w6Ds38EZebUMfyVUFIh+URF2BeipEVhIU= github.com/apache/thrift v0.0.0-20181112125854-24918abba929/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/apache/thrift v0.14.2/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= -github.com/apache/thrift v0.18.1 h1:lNhK/1nqjbwbiOPDBPFJVKxgDEGSepKuTh6OLiXW8kg= -github.com/apache/thrift v0.18.1/go.mod h1:rdQn/dCcDKEWjjylUeueum4vQEjG2v8v2PqriUnbr+I= +github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE= +github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw= github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= github.com/ardielle/ardielle-go v1.5.2 h1:TilHTpHIQJ27R1Tl/iITBzMwiUGSlVfiVhwDNGM3Zj4= @@ -497,8 +497,8 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= -github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= -github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU= github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= @@ -776,10 +776,10 @@ github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0 github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= -github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= -github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= -github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod 
h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= +github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= +github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -1187,6 +1187,8 @@ github.com/xitongsys/parquet-go-source v0.0.0-20211228015320-b4f792c43cd0 h1:ti/ github.com/xitongsys/parquet-go-source v0.0.0-20211228015320-b4f792c43cd0/go.mod h1:qLb2Itmdcp7KPa5KZKvhE9U1q5bYSOmgeOckF/H2rQA= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4= @@ -1289,8 +1291,8 @@ golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.20.0/go.mod h1:Xwo95rrVNIoSMx9wa1JroENMToLWn3RNVrTBpLHgZPQ= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1305,8 +1307,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI= -golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= +golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= 
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -1337,8 +1339,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= -golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1376,8 +1378,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1458,8 +1460,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= -golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -1468,8 +1470,8 @@ golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= -golang.org/x/term v0.23.0/go.mod 
h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1485,8 +1487,8 @@ golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1537,16 +1539,16 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= -golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= +golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= +golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= -golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/gonum v0.9.3/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0= diff --git 
a/internal/impl/snowflake/output_snowflake_put.go b/internal/impl/snowflake/output_snowflake_put.go index 8a6587250f..d4c5716cd9 100644 --- a/internal/impl/snowflake/output_snowflake_put.go +++ b/internal/impl/snowflake/output_snowflake_put.go @@ -417,10 +417,17 @@ func init() { //------------------------------------------------------------------------------ +func wipeSlice(b []byte) { + for i := range b { + b[i] = '~' + } +} + // getPrivateKeyFromFile reads and parses the private key // Inspired from https://github.com/chanzuckerberg/terraform-provider-snowflake/blob/c07d5820bea7ac3d8a5037b0486c405fdf58420e/pkg/provider/provider.go#L367 func getPrivateKeyFromFile(f fs.FS, path, passphrase string) (*rsa.PrivateKey, error) { privateKeyBytes, err := service.ReadFile(f, path) + defer wipeSlice(privateKeyBytes) if err != nil { return nil, fmt.Errorf("failed to read private key %s: %s", path, err) } diff --git a/internal/impl/snowflake/output_snowflake_streaming.go b/internal/impl/snowflake/output_snowflake_streaming.go new file mode 100644 index 0000000000..6787ea3592 --- /dev/null +++ b/internal/impl/snowflake/output_snowflake_streaming.go @@ -0,0 +1,328 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + +package snowflake + +import ( + "context" + "crypto/rsa" + "fmt" + "strings" + "sync" + + "github.com/redpanda-data/benthos/v4/public/bloblang" + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/snowflake/streaming" +) + +const ( + ssoFieldAccount = "account" + ssoFieldUser = "user" + ssoFieldRole = "role" + ssoFieldDB = "database" + ssoFieldSchema = "schema" + ssoFieldTable = "table" + ssoFieldKey = "private_key" + ssoFieldKeyFile = "private_key_file" + ssoFieldKeyPass = "private_key_pass" + ssoFieldBatching = "batching" + ssoFieldChannelPrefix = "channel_prefix" + ssoFieldMapping = "mapping" +) + +func snowflakeStreamingOutputConfig() *service.ConfigSpec { + return service.NewConfigSpec(). + Categories("Services"). + Version("4.39.0"). + Summary("Ingest data into Snowflake using Snowpipe Streaming."). + Description(` +Ingest data into Snowflake using Snowpipe Streaming. + +[%header,format=dsv] +|=== +Snowflake column type:Allowed format in Redpanda Connect +CHAR, VARCHAR:string +BINARY:[]byte +NUMBER:any numeric type, string +FLOAT:any numeric type +BOOLEAN:bool,any numeric type,string parsable according to `+"`strconv.ParseBool`"+` +TIME,DATE,TIMESTAMP:unix or RFC 3339 with nanoseconds timestamps +VARIANT,ARRAY,OBJECT:any data type is converted into JSON +GEOGRAPHY,GEOMETRY: Not supported +|=== + +For TIMESTAMP, TIME and DATE columns, you can parse different string formats using a bloblang `+"`"+ssoFieldMapping+"`"+`. + +Authentication can be configured using a https://docs.snowflake.com/en/user-guide/key-pair-auth[RSA Key Pair^]. + +There are https://docs.snowflake.com/en/user-guide/data-load-snowpipe-streaming-overview#limitations[limitations^] of what data types can be loaded into Snowflake using this method. +`+service.OutputPerformanceDocs(true, true)+` + +It is recommended that each batches results in at least 16MiB of compressed output being written to Snowflake. 
+You can monitor the output batch size using the `+"`snowflake_compressed_output_size_bytes`"+` metric. +`). + Fields( + service.NewStringField(ssoFieldAccount). + Description(`Account name, which is the same as the https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#where-are-account-identifiers-used[Account Identifier^]. + However, when using an https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#using-an-account-locator-as-an-identifier[Account Locator^], + the Account Identifier is formatted as `+"`..`"+` and this field needs to be + populated using the `+"``"+` part. +`).Example("AAAAAAA-AAAAAAA"), + service.NewStringField(ssoFieldUser).Description("The user to run the Snowpipe Stream as. See https://docs.snowflake.com/en/user-guide/admin-user-management[Snowflake Documentation^] on how to create a user."), + service.NewStringField(ssoFieldRole).Description("The role for the `user` field. The role must have the https://docs.snowflake.com/en/user-guide/data-load-snowpipe-streaming-overview#required-access-privileges[required privileges^] to call the Snowpipe Streaming APIs. See https://docs.snowflake.com/en/user-guide/admin-user-management#user-roles[Snowflake Documentation^] for more information about roles.").Example("ACCOUNTADMIN"), + service.NewStringField(ssoFieldDB).Description("The Snowflake database to ingest data into."), + service.NewStringField(ssoFieldSchema).Description("The Snowflake schema to ingest data into."), + service.NewStringField(ssoFieldTable).Description("The Snowflake table to ingest data into."), + service.NewStringField(ssoFieldKey).Description("The PEM encoded private RSA key to use for authenticating with Snowflake. Either this or `private_key_file` must be specified.").Optional().Secret(), + service.NewStringField(ssoFieldKeyFile).Description("The file to load the private RSA key from. This should be a `.p8` PEM encoded file. Either this or `private_key` must be specified.").Optional(), + service.NewStringField(ssoFieldKeyPass).Description("The RSA key passphrase if the RSA key is encrypted.").Optional().Secret(), + service.NewBloblangField(ssoFieldMapping).Description("A bloblang mapping to execute on each message.").Optional(), + service.NewBatchPolicyField(ssoFieldBatching), + service.NewOutputMaxInFlightField(), + service.NewStringField(ssoFieldChannelPrefix). + Description(`The prefix to use when creating a channel name. +Duplicate channel names will result in errors and prevent multiple instances of Redpanda Connect from writing at the same time. +By default this will create a channel name that is based on the table FQN so there will only be a single stream per table. + +At most `+"`max_in_flight`"+` channels will be opened. + +NOTE: There is a limit of 10,000 streams per table - if using more than 10k streams please reach out to Snowflake support.`). + Optional(). 
+ Advanced(), + ).LintRule(`root = match { + this.exists("private_key") && this.exists("private_key_file") => [ "both ` + "`private_key`" + ` and ` + "`private_key_file`" + ` can't be set simultaneously" ], +}`) +} + +func init() { + err := service.RegisterBatchOutput( + "snowflake_streaming", + snowflakeStreamingOutputConfig(), + func(conf *service.ParsedConfig, mgr *service.Resources) ( + output service.BatchOutput, + batchPolicy service.BatchPolicy, + maxInFlight int, + err error, + ) { + if maxInFlight, err = conf.FieldMaxInFlight(); err != nil { + return + } + if batchPolicy, err = conf.FieldBatchPolicy(ssoFieldBatching); err != nil { + return + } + output, err = newSnowflakeStreamer(conf, mgr) + return + }) + if err != nil { + panic(err) + } +} + +func newSnowflakeStreamer( + conf *service.ParsedConfig, + mgr *service.Resources, +) (service.BatchOutput, error) { + keypass := "" + if conf.Contains(ssoFieldKeyPass) { + pass, err := conf.FieldString(ssoFieldKey) + if err != nil { + return nil, err + } + keypass = pass + } + var rsaKey *rsa.PrivateKey + if conf.Contains(ssoFieldKey) { + key, err := conf.FieldString(ssoFieldKey) + if err != nil { + return nil, err + } + rsaKey, err = getPrivateKey([]byte(key), keypass) + if err != nil { + return nil, err + } + } else if conf.Contains(ssoFieldKeyFile) { + keyFile, err := conf.FieldString(ssoFieldKeyFile) + if err != nil { + return nil, err + } + rsaKey, err = getPrivateKeyFromFile(mgr.FS(), keyFile, keypass) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("one of `%s` or `%s` is required", ssoFieldKey, ssoFieldKeyFile) + } + account, err := conf.FieldString(ssoFieldAccount) + if err != nil { + return nil, err + } + user, err := conf.FieldString(ssoFieldUser) + if err != nil { + return nil, err + } + role, err := conf.FieldString(ssoFieldRole) + if err != nil { + return nil, err + } + db, err := conf.FieldString(ssoFieldDB) + if err != nil { + return nil, err + } + schema, err := conf.FieldString(ssoFieldSchema) + if err != nil { + return nil, err + } + table, err := conf.FieldString(ssoFieldTable) + if err != nil { + return nil, err + } + var mapping *bloblang.Executor + if conf.Contains(ssoFieldMapping) { + mapping, err = conf.FieldBloblang(ssoFieldMapping) + if err != nil { + return nil, err + } + } + var channelPrefix string + if conf.Contains(ssoFieldChannelPrefix) { + channelPrefix, err = conf.FieldString(ssoFieldChannelPrefix) + if err != nil { + return nil, err + } + } else { + // There is a limit of 10k channels, so we can't dynamically create them. + // The only other good default is to create one and only allow a single + // stream to write to a single table. 
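+		// For example, with the default prefix a hypothetical config targeting
+		// MY_DB.PUBLIC.EVENTS produces channel names such as
+		// "Redpanda_Connect_MY_DB.PUBLIC.EVENTS_0", "_1", ... (see openNewChannel),
+		// so only one Redpanda Connect instance can safely write to that table.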
+ channelPrefix = fmt.Sprintf("Redpanda_Connect_%s.%s.%s", db, schema, table) + } + client, err := streaming.NewSnowflakeServiceClient( + context.Background(), + streaming.ClientOptions{ + Account: account, + User: user, + Role: role, + PrivateKey: rsaKey, + Logger: mgr.Logger(), + ConnectVersion: mgr.EngineVersion(), + Application: strings.TrimPrefix(channelPrefix, "Redpanda_Connect_"), + }) + if err != nil { + return nil, err + } + o := &snowflakeStreamerOutput{ + channelPrefix: channelPrefix, + client: client, + db: db, + schema: schema, + table: table, + mapping: mapping, + logger: mgr.Logger(), + buildTime: mgr.Metrics().NewTimer("snowflake_build_output_latency_ns"), + uploadTime: mgr.Metrics().NewTimer("snowflake_upload_latency_ns"), + compressedOutput: mgr.Metrics().NewCounter("snowflake_compressed_output_size_bytes"), + } + return o, nil +} + +type snowflakeStreamerOutput struct { + client *streaming.SnowflakeServiceClient + channelPool sync.Pool + channelCreationMu sync.Mutex + poolSize int + compressedOutput *service.MetricCounter + uploadTime *service.MetricTimer + buildTime *service.MetricTimer + + channelPrefix, db, schema, table string + mapping *bloblang.Executor + logger *service.Logger +} + +func (o *snowflakeStreamerOutput) openNewChannel(ctx context.Context) (*streaming.SnowflakeIngestionChannel, error) { + // Use a lock here instead of an atomic because this should not be called at steady state and it's better to limit + // creating extra channels when there is a limit of 10K. + o.channelCreationMu.Lock() + defer o.channelCreationMu.Unlock() + name := fmt.Sprintf("%s_%d", o.channelPrefix, o.poolSize) + client, err := o.openChannel(ctx, name, int16(o.poolSize)) + if err == nil { + o.poolSize++ + } + return client, err +} + +func (o *snowflakeStreamerOutput) openChannel(ctx context.Context, name string, id int16) (*streaming.SnowflakeIngestionChannel, error) { + o.logger.Debugf("opening snowflake streaming channel: %s", name) + return o.client.OpenChannel(ctx, streaming.ChannelOptions{ + ID: id, + Name: name, + DatabaseName: o.db, + SchemaName: o.schema, + TableName: o.table, + }) +} + +func (o *snowflakeStreamerOutput) Connect(ctx context.Context) error { + // Precreate a single channel so we know stuff works, otherwise we'll create them on demand. 
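+	// Additional channels are opened lazily in WriteBatch whenever the pool is
+	// empty, so at most max_in_flight channels are created in total.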
+ c, err := o.openNewChannel(ctx) + if err != nil { + return fmt.Errorf("unable to open snowflake streaming channel: %w", err) + } + o.channelPool.Put(c) + return nil +} + +func (o *snowflakeStreamerOutput) WriteBatch(ctx context.Context, batch service.MessageBatch) error { + if o.mapping != nil { + mapped := make(service.MessageBatch, len(batch)) + exec := batch.BloblangExecutor(o.mapping) + for i := range batch { + msg, err := exec.Query(i) + if err != nil { + return fmt.Errorf("error executing %s: %w", ssoFieldMapping, err) + } + mapped[i] = msg + } + batch = mapped + } + var channel *streaming.SnowflakeIngestionChannel + if maybeChan := o.channelPool.Get(); maybeChan != nil { + channel = maybeChan.(*streaming.SnowflakeIngestionChannel) + } else { + var err error + if channel, err = o.openNewChannel(ctx); err != nil { + return fmt.Errorf("unable to open snowflake streaming channel: %w", err) + } + } + stats, err := channel.InsertRows(ctx, batch) + o.compressedOutput.Incr(int64(stats.CompressedOutputSize)) + o.uploadTime.Timing(stats.UploadTime.Nanoseconds()) + o.buildTime.Timing(stats.BuildTime.Nanoseconds()) + // If there is some kind of failure, try to reopen the channel + if err != nil { + reopened, reopenErr := o.openChannel(ctx, channel.Name, channel.ID) + if reopenErr == nil { + o.channelPool.Put(reopened) + } else { + o.logger.Warnf("unable to reopen channel %q after failure: %v", channel.Name, reopenErr) + // Keep around the same channel just in case so we don't keep creating new channels. + o.channelPool.Put(channel) + } + return err + } + polls, err := channel.WaitUntilCommitted(ctx) + if err == nil { + o.logger.Tracef("batch committed in snowflake after %d polls", polls) + } + o.channelPool.Put(channel) + return err +} + +func (o *snowflakeStreamerOutput) Close(ctx context.Context) error { + return o.client.Close() +} diff --git a/internal/impl/snowflake/streaming/.gitignore b/internal/impl/snowflake/streaming/.gitignore new file mode 100644 index 0000000000..4bed5da93f --- /dev/null +++ b/internal/impl/snowflake/streaming/.gitignore @@ -0,0 +1 @@ +*.parquet diff --git a/internal/impl/snowflake/streaming/README.md b/internal/impl/snowflake/streaming/README.md new file mode 100644 index 0000000000..f3631df5fd --- /dev/null +++ b/internal/impl/snowflake/streaming/README.md @@ -0,0 +1,15 @@ +# Snowflake Integration SDK for Redpanda Connect + + +### Testing + +To enable integration tests, you need to follow the instructions here to generate a public/private key for snowflake: https://docs.snowflake.com/en/user-guide/key-pair-auth + +Run the `openssl` commands from that guide in the `resources` directory to generate the correct keys for the integration test (the test requires the private key is unencrypted), then run the following: + +``` +SNOWFLAKE_USER=XXX \ + SNOWFLAKE_ACCOUNT=alskjd-asdaks \ + SNOWFLAKE_DB=xxx \ + go test -v . +``` diff --git a/internal/impl/snowflake/streaming/compat.go b/internal/impl/snowflake/streaming/compat.go new file mode 100644 index 0000000000..facbb9e735 --- /dev/null +++ b/internal/impl/snowflake/streaming/compat.go @@ -0,0 +1,175 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "fmt" + "strconv" + "strings" + "time" + + "github.com/redpanda-data/connect/v4/internal/impl/snowflake/streaming/int128" +) + +var ( + pow10TableInt32 []int32 + pow10TableInt64 []int64 +) + +func init() { + { + pow10TableInt64 = make([]int64, 19) + n := int64(1) + pow10TableInt64[0] = n + for i := range pow10TableInt64[1:] { + n = 10 * n + pow10TableInt64[i+1] = n + } + } + { + pow10TableInt32 = make([]int32, 19) + n := int32(1) + pow10TableInt32[0] = n + for i := range pow10TableInt32[1:] { + n = 10 * n + pow10TableInt32[i+1] = n + } + } +} + +func deriveKey(encryptionKey, diversifier string) ([]byte, error) { + decodedKey, err := base64.StdEncoding.DecodeString(encryptionKey) + if err != nil { + return nil, err + } + hash := sha256.New() + hash.Write(decodedKey) + hash.Write([]byte(diversifier)) + return hash.Sum(nil)[:], nil +} + +// See Encyptor.encrypt in the Java SDK +func encrypt(buf []byte, encryptionKey string, diversifier string, iv int64) ([]byte, error) { + // Derive the key from the diversifier and the original encryptionKey from server + key, err := deriveKey(encryptionKey, diversifier) + if err != nil { + return nil, err + } + // Using our derived key and padded input, encrypt the thing. + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + // Create our cypher using the iv + ivBytes := make([]byte, aes.BlockSize) + binary.BigEndian.PutUint64(ivBytes[8:], uint64(iv)) + stream := cipher.NewCTR(block, ivBytes) + // Actually do the encryption in place + stream.XORKeyStream(buf, buf) + return buf, nil +} + +func padBuffer(buf []byte, alignmentSize int) []byte { + padding := alignmentSize - len(buf)%alignmentSize + return append(buf, make([]byte, padding)...) +} + +func md5Hash(b []byte) string { + s := md5.Sum(b) + return hex.EncodeToString(s[:]) +} + +// Generate the path for a blob when uploading to an internal snowflake table. +// +// Never change, this must exactly match the java SDK, don't think you can be fancy and change something. +func generateBlobPath(clientPrefix string, threadID, counter int) string { + now := time.Now().UTC() + year := now.Year() + month := int(now.Month()) + day := now.Day() + hour := now.Hour() + minute := now.Minute() + blobShortName := fmt.Sprintf("%s_%s_%d_%d.bdec", strconv.FormatInt(now.Unix(), 36), clientPrefix, threadID, counter) + return fmt.Sprintf("%d/%d/%d/%d/%d/%s", year, month, day, hour, minute, blobShortName) +} + +// truncateBytesAsHex truncates an array of bytes up to 32 bytes and optionally increment the last byte(s). +// More the one byte can be incremented in case it overflows. +// +// NOTE: This can mutate `bytes` +func truncateBytesAsHex(bytes []byte, truncateUp bool) string { + const maxLobLen int = 32 + if len(bytes) <= maxLobLen { + return hex.EncodeToString(bytes) + } + if truncateUp { + var i int + for i = maxLobLen - 1; i >= 0; i-- { + bytes[i]++ + if bytes[i] != 0 { + break + } + } + if i < 0 { + return "Z" + } + } + return hex.EncodeToString(bytes[:maxLobLen]) +} + +// normalizeColumnName normalizes the column to the same as Snowflake's +// internal representation. See LiteralQuoteUtils.unquoteColumnName in +// the Java SDK for reference, although that code is quite hard to read. 
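+//
+// Roughly: unquoted identifiers are upper-cased with escaped spaces collapsed
+// (`foo` -> FOO, `how\ are\ you` -> HOW ARE YOU), while double-quoted
+// identifiers keep their case, lose the surrounding quotes and have doubled
+// quotes unescaped (`"C1"` -> C1, `"foo"" bar ""baz"` -> foo" bar "baz).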
+func normalizeColumnName(name string) string { + if strings.HasPrefix(name, `"`) && strings.HasSuffix(name, `"`) { + unquoted := name[1 : len(name)-1] + noDoubleQuotes := strings.ReplaceAll(unquoted, `""`, ``) + if !strings.ContainsRune(noDoubleQuotes, '"') { + return strings.ReplaceAll(unquoted, `""`, `"`) + } + if !strings.ContainsRune(unquoted, '"') { + return unquoted + } + // fallthrough + } + return strings.ToUpper(strings.ReplaceAll(name, `\ `, ` `)) +} + +// snowflakeTimestampInt computes the same result as the logic in TimestampWrapper +// in the Java SDK. It converts a timestamp to the integer representation that +// is used internally within Snowflake. +func snowflakeTimestampInt(t time.Time, scale int32, includeTZ bool) int128.Num { + epoch := int128.FromInt64(t.Unix()) + // this calculation is intentionally done at low resolution to truncate the nanoseconds + // according to our scale. + fraction := (int32(t.Nanosecond()) / pow10TableInt32[9-scale]) * pow10TableInt32[9-scale] + timeInNanos := int128.Add( + int128.Mul(epoch, int128.Pow10Table[9]), + int128.FromInt64(int64(fraction)), + ) + scaledTime := int128.Div(timeInNanos, int128.Pow10Table[9-scale]) + if includeTZ { + _, tzOffsetSec := t.Zone() + offsetMinutes := tzOffsetSec / 60 + offsetMinutes += 1440 + scaledTime = int128.Shl(scaledTime, 14) + const tzMask = (1 << 14) - 1 + scaledTime = int128.Add(scaledTime, int128.FromInt64(int64(offsetMinutes&tzMask))) + } + return scaledTime +} diff --git a/internal/impl/snowflake/streaming/compat_test.go b/internal/impl/snowflake/streaming/compat_test.go new file mode 100644 index 0000000000..e5310fda77 --- /dev/null +++ b/internal/impl/snowflake/streaming/compat_test.go @@ -0,0 +1,204 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "crypto/aes" + "encoding/base64" + "encoding/hex" + "slices" + "testing" + "time" + + "github.com/redpanda-data/connect/v4/internal/impl/snowflake/streaming/int128" + "github.com/stretchr/testify/require" +) + +func TestEncryption(t *testing.T) { + data := []byte("testEncryptionDecryption") + key := base64.StdEncoding.EncodeToString([]byte("encryption_key")) + diversifier := "2021/08/10/blob.bdec" + actual, err := encrypt(data, key, diversifier, 0) + require.NoError(t, err) + // this value was obtained from the Cryptor unit tests in the Java SDK + expected := []byte{133, 80, 92, 68, 33, 84, 54, 127, 139, 26, 89, 42, 80, 118, 6, 27, 56, 48, 149, 113, 118, 62, 50, 158} + require.Equal(t, expected, actual) +} + +func mustHexDecode(s string) []byte { + decoded, err := hex.DecodeString(s) + if err != nil { + panic(err) + } + return decoded +} + +func TestTruncateBytesAsHex(t *testing.T) { + // Test empty input + require.Equal(t, "", truncateBytesAsHex([]byte{}, false)) + require.Equal(t, "", truncateBytesAsHex([]byte{}, true)) + + // Test basic case + decoded := mustHexDecode("aa") + require.Equal(t, "aa", truncateBytesAsHex(decoded, false)) + require.Equal(t, "aa", truncateBytesAsHex(decoded, true)) + + // Test exactly 32 bytes + decoded = mustHexDecode("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + require.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", truncateBytesAsHex(decoded, false)) + require.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", truncateBytesAsHex(decoded, true)) + + decoded = mustHexDecode("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + require.Equal(t, "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", truncateBytesAsHex(decoded, false)) + require.Equal(t, "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", truncateBytesAsHex(decoded, true)) + + // Test 1 truncate up + decoded = mustHexDecode("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa") + require.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", truncateBytesAsHex(decoded, false)) + require.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab", truncateBytesAsHex(decoded, true)) + + // Test one overflow + decoded = mustHexDecode("aaaaaaaaaaaaaaaaaaaaaaaaaaaaafffffffffffffffffffffffffffffffaaffffffff") + require.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaafffffffffffffffffffffffffffffffaaff", truncateBytesAsHex(decoded, false)) + require.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaafffffffffffffffffffffffffffffffab00", truncateBytesAsHex(decoded, true)) + + // Test many overflow + decoded = mustHexDecode("aaaaaaaaaaaaaaaaaaaaaaaaaaaaafffffffffffffffffffffffffffffffffffffffffffffffffffff") + require.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaafffffffffffffffffffffffffffffffffff", truncateBytesAsHex(decoded, false)) + require.Equal(t, "aaaaaaaaaaaaaaaaaaaaaaaaaaaab00000000000000000000000000000000000", truncateBytesAsHex(decoded, true)) + + // Test infinity + decoded = mustHexDecode("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffcccccccccccc") + require.Equal(t, "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", truncateBytesAsHex(decoded, false)) + require.Equal(t, "Z", truncateBytesAsHex(decoded, true)) +} + +func mustBase64Decode(s string) 
[]byte { + b, err := base64.StdEncoding.DecodeString(s) + if err != nil { + panic(err) + } + return b +} + +// TestCompat takes each stage of transforms that are applied in the JavaSDK and ensures that this SDK is byte for byte the same. +func TestCompat(t *testing.T) { + unpadded := mustBase64Decode("UEFSMRUAFUwVPhWUpsKLARwVBBUAFQYVCAAAH4sIAAAAAAAA/2NiYGBgZmZABT7ofABnJDzZJgAAABUAFSgVQhXlo/S6CRwVBBUAFQYVCAAAH4sIAAAAAAAA/2NiYGBgZmYGkoWlFVAKAA+YiDUUAAAAFQAVDhU2FZ/44TAcFQQVABUGFQgAAB+LCAAAAAAAAP9jYmBgYGZmBgB3cpG6BwAAABkRAhkYEAAAAAAAAAAAAAAAAAAAAEwZGBAAAAAAAAAAAAAAAAAAAABMFQIZFgAZFgQZJgAEABkRAhkYA3F1eBkYA3F1eBUCGRYAGRYEGSYABAAZEQIZGAEBGRgBARUCGRYAGRYEGSYABAAZHBYIFWwWAAAAGRwWdBVwFgAAGRYMABkcFuQBFWIWAAAAFQIZTEgEYmRlYxUGABUOFSAVAhgBQSUKFQAVTBUCHFwVABVMAAAAFQwlAhgBQiUANQQcHAAAABUAJQIYAUNVBgAWBBkcGTwmCBwVDhk1BggAGRgBQRUEFgQWehZsJgg8GBAAAAAAAAAAAAAAAAAAAABMGBAAAAAAAAAAAAAAAAAAAABMFgAoEAAAAAAAAAAAAAAAAAAAAEwYEAAAAAAAAAAAAAAAAAAAAEwAGRwVABUAFQIAPCkWBBkmAAQAABaaBBUUFsYCFWwAJnQcFQwZNQYIABkYAUIVBBYEFlYWcCZ0PBgDcXV4GANxdXgWACgDcXV4GANxdXgAGRwVABUAFQIAPBYMGRYEGSYABAAAFq4EFRoWsgMVOAAm5AEcFQAZNQYIABkYAUMVBBYEFjoWYibkATwYAQEYAQEWACgBARgBAQAZHBUAFQAVAgA8KRYEGSYABAAAFsgEFRYW6gMVMAAWigIWBCYIFr4CFAAAGVwYATEYAzIsNQAYATIYAzksOAAYATMYAzEsMQAYBXNmVmVyGAMxLDEAGA1wcmltYXJ5RmlsZUlkGENzbDFpejVfOVFqUVVKRDJZeGhrQ0hOZFZmUVR0dDBoR1JPR2tiMzdJTlIzM3BoRU00c0NDXzMwMDFfMzRfMC5iZGVjABhKcGFycXVldC1tciB2ZXJzaW9uIDEuMTQuMSAoYnVpbGQgOTdlZGU5NjgzNzc0MDBkMWQ3OWUzMTk2NjM2YmEzZGUzOTIxOTZiYSkZPBwAABwAABwAAABFAgAAUEFSMQ==") + actualPadded := padBuffer(slices.Clone(unpadded), aes.BlockSize) + padded := mustBase64Decode("UEFSMRUAFUwVPhWUpsKLARwVBBUAFQYVCAAAH4sIAAAAAAAA/2NiYGBgZmZABT7ofABnJDzZJgAAABUAFSgVQhXlo/S6CRwVBBUAFQYVCAAAH4sIAAAAAAAA/2NiYGBgZmYGkoWlFVAKAA+YiDUUAAAAFQAVDhU2FZ/44TAcFQQVABUGFQgAAB+LCAAAAAAAAP9jYmBgYGZmBgB3cpG6BwAAABkRAhkYEAAAAAAAAAAAAAAAAAAAAEwZGBAAAAAAAAAAAAAAAAAAAABMFQIZFgAZFgQZJgAEABkRAhkYA3F1eBkYA3F1eBUCGRYAGRYEGSYABAAZEQIZGAEBGRgBARUCGRYAGRYEGSYABAAZHBYIFWwWAAAAGRwWdBVwFgAAGRYMABkcFuQBFWIWAAAAFQIZTEgEYmRlYxUGABUOFSAVAhgBQSUKFQAVTBUCHFwVABVMAAAAFQwlAhgBQiUANQQcHAAAABUAJQIYAUNVBgAWBBkcGTwmCBwVDhk1BggAGRgBQRUEFgQWehZsJgg8GBAAAAAAAAAAAAAAAAAAAABMGBAAAAAAAAAAAAAAAAAAAABMFgAoEAAAAAAAAAAAAAAAAAAAAEwYEAAAAAAAAAAAAAAAAAAAAEwAGRwVABUAFQIAPCkWBBkmAAQAABaaBBUUFsYCFWwAJnQcFQwZNQYIABkYAUIVBBYEFlYWcCZ0PBgDcXV4GANxdXgWACgDcXV4GANxdXgAGRwVABUAFQIAPBYMGRYEGSYABAAAFq4EFRoWsgMVOAAm5AEcFQAZNQYIABkYAUMVBBYEFjoWYibkATwYAQEYAQEWACgBARgBAQAZHBUAFQAVAgA8KRYEGSYABAAAFsgEFRYW6gMVMAAWigIWBCYIFr4CFAAAGVwYATEYAzIsNQAYATIYAzksOAAYATMYAzEsMQAYBXNmVmVyGAMxLDEAGA1wcmltYXJ5RmlsZUlkGENzbDFpejVfOVFqUVVKRDJZeGhrQ0hOZFZmUVR0dDBoR1JPR2tiMzdJTlIzM3BoRU00c0NDXzMwMDFfMzRfMC5iZGVjABhKcGFycXVldC1tciB2ZXJzaW9uIDEuMTQuMSAoYnVpbGQgOTdlZGU5NjgzNzc0MDBkMWQ3OWUzMTk2NjM2YmEzZGUzOTIxOTZiYSkZPBwAABwAABwAAABFAgAAUEFSMQAAAAA=") + require.Equal(t, padded, actualPadded) + encryptionKey := "i3aoKhzaBpbgJ7NtZHagllmUxTDJEbcEObJg+OMbZio=" + blobPath := "2024/10/8/14/1/sl1iz5_9QjQUJD2YxhkCHNdVfQTtt0hGROGkb37INR33phEM4sCC_3001_34_0.bdec" + actualEncrypted, err := encrypt(slices.Clone(padded), encryptionKey, blobPath, 0) + require.NoError(t, err) + encrypted := 
mustBase64Decode("ZBVRKvbk6yq2rtif+3FeYsuVP6bh0JSvaViL843qnI+Nqcvl74xBYaFQ0YKbxRTg2pBGW2VHDQOPk03Fbg7ENHJGJFbv0Dr7R1sMQyMyHXQdQMEknrpinkomPA04K5EnNlJTY21pDqL4xpTBdeZWzX0SPGvhwQnSCmMPvNWsdeTq5fnqtunNfJES9FwKvVU1DVGoOewOs/sR7j7/IjVkcK8YElO+pqAMbf8OqFsoeVpWcaroT5fxZiSMZQ6jBRoBSRAtkFi9WFwEW6eGq+iMu9CGccumSOb48wj4aa8EuyZRWYa5vDqnJYz76+ea91Akvp1+OKkoA7QTUY7iBi4emH8AdeRlG35F5O/JCbZ1sNUhEoJSTQfRID582lK1MRsVaxwamJw/2Ty3NG80S22dVV2ILhjl38GZjypJHihCFjkU8g9qkEvhuwNrEeK6xwWJ6DF+OtxE6PzVUdNgOWzwFxRMASayZWyAH/+1KCVCIbURS5lDbT/Mv+fEA6waKasgiynqAIw/1z2c39h+ThtxNKWVaZzENGOOjAWpaKTSxQ8UiaiSG7WBtFtAmYJlQ5mAJO+i133Xipv86mVJv8OudRoIzYM8pZMVIP/Y7RD3kCkP3IzGS9QDQOhC8aXomHcEaXK+Z9iCewe9T+atdUX18OSuEr9owcI0Eu7gvWnpRK5fWVRqi3i+uz/HdmKF0qcmEDTzuMs+PvUl84J9kJjR1Savr4UKmZlp3u/i+nXTx0zgrV/NtdX4eXJMeaCaP2AJfKQzY1UCSFZS/5mSzsRzk/R3SiFLee7caWq7HsAQEAdpMz2pvylSxS0YCxL5KivGk/sKAMjaDRvQpblO5zcKH+mFaTgehpVr4oqaIwdMVw5Q7aRrjol97zMNu95kdCk8m2vyFvZKLzk+WWVxK645fJYUE2v/B8M3H3phVDJqn4//gGsQG/xLdwBWFpI1W9GZq4F3qvAxeB3XldKV1IsgH+ygBkxAAvlexba3Qb+rWnE9B+KjX+r8u8qI1WIDObF71NQ0m/bDgCz1KhIyUaYUu7O++U4vUK/e2TD2nX5+m3m3DAxHQousdiodh1C5dr249v0GTcbnKlCNLOMRCLdB222Xd2pQPI5M7p0Dj+yNrecD6FlIeLavEJF3QvE6urwmO8nMaJJ3WmX+euCO1Yia1m5gFBVnaSGSI1RmqxAiSUQ=") + require.Equal(t, encrypted, actualEncrypted) + fileMD5Hash := "c211779e08513408f0a8b28a17c230b0" + require.Equal(t, md5Hash(actualEncrypted), fileMD5Hash) + chunkMD5Hash := "1ca9f885bedc25ded3abf3df045543be" + require.Equal(t, md5Hash(actualEncrypted[:len(unpadded)]), chunkMD5Hash) +} + +func TestColumnNormalization(t *testing.T) { + require.Equal(t, "", normalizeColumnName("")) + require.Equal(t, "FOO", normalizeColumnName("foo")) + require.Equal(t, `bar`, normalizeColumnName(`"bar"`)) + require.Equal(t, "'BAR'", normalizeColumnName(`'bar'`)) + require.Equal(t, "BAR", normalizeColumnName(`bar`)) + require.Equal(t, `C1`, normalizeColumnName(`"C1"`)) + require.Equal(t, `how are you`, normalizeColumnName(`"how are you"`)) + require.Equal(t, `HOW ARE YOU`, normalizeColumnName(`how are you`)) + require.Equal(t, `how\ are\ you`, normalizeColumnName(`"how\ are\ you"`)) + require.Equal(t, `HOW ARE YOU`, normalizeColumnName(`how\ are\ you`)) + require.Equal(t, `"FOO`, normalizeColumnName(`"foo`)) + require.Equal(t, `FOO"`, normalizeColumnName(`foo"`)) + require.Equal(t, `FOO" BAR "BAZ`, normalizeColumnName(`foo" bar "baz`)) + require.Equal(t, `"FOO \"BAZ"`, normalizeColumnName(`"foo \"baz"`)) + require.Equal(t, `"FOO \"BAZ"`, normalizeColumnName(`"foo \"baz"`)) + require.Equal(t, `foo" bar "baz`, normalizeColumnName(`"foo"" bar ""baz"`)) +} + +func TestSnowflakeTimestamp(t *testing.T) { + type TestCase struct { + timestamp string + value int128.Num + scale int32 + keepTZ bool + tz bool + } + cases := [...]TestCase{ + { + timestamp: "2021-01-01 01:00:00.123", + value: int128.FromInt64(1609462800123000000), + scale: 9, + }, + { + timestamp: "1971-01-01 00:00:00.001", + value: int128.Mul(int128.FromInt64(31536000001), int128.FromInt64(1000000)), + scale: 9, + }, + { + timestamp: "1971-01-01 00:00:00.000", + value: int128.Mul(int128.FromInt64(31536000000), int128.FromInt64(1000000)), + scale: 9, + }, + { + timestamp: "2021-01-01 01:00:00.123", + value: int128.FromInt64(1609462800123000000), + scale: 9, + }, + { + timestamp: "2021-01-01 01:00:00.123", + value: int128.FromInt64(16094628001230), + scale: 4, + }, + { + timestamp: "2021-01-01 01:00:00.123+01:00", + value: int128.FromInt64(263693795348153820), + scale: 4, + keepTZ: true, + tz: true, + }, + { + timestamp: "2021-01-01 01:00:00.123+01:00", + value: 
int128.MustParse("26369379534815232001500"), + scale: 9, + keepTZ: true, + tz: true, + }, + { + timestamp: "2024-01-01 12:00:00.000-08:00", + value: int128.MustParse("1704139200000000000"), + scale: 9, + keepTZ: true, + tz: false, + }, + { + timestamp: "2024-01-01 12:00:00.000-08:00", + value: int128.MustParse("27920616652800000000960"), + scale: 9, + keepTZ: true, + tz: true, + }, + } + for _, c := range cases { + c := c + t.Run("", func(t *testing.T) { + layout := "2006-01-02 15:04:05.000" + if c.keepTZ { + layout = "2006-01-02 15:04:05.000-07:00" + } + parsed, err := time.Parse(layout, c.timestamp) + require.NoError(t, err) + require.Equal(t, c.value, snowflakeTimestampInt(parsed, c.scale, c.tz)) + }) + } +} diff --git a/internal/impl/snowflake/streaming/int128/decimal.go b/internal/impl/snowflake/streaming/int128/decimal.go new file mode 100644 index 0000000000..48cfdd317d --- /dev/null +++ b/internal/impl/snowflake/streaming/int128/decimal.go @@ -0,0 +1,318 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Functionality in this file was derived (with modifications) from +// arrow-go and it's decimal128 package. We currently don't use that +// package directly due to bugs in the implementation, but hopefully +// we can upstream some fixes from that and then remove this package. + +package int128 + +import ( + "errors" + "fmt" + "math" + "math/big" +) + +// FitsInPrecision returns true or false if the value currently held by +// n would fit within precision (0 < prec <= 38) without losing any data. 
+func (i Num) FitsInPrecision(prec int32) bool { + if prec == 0 { + // Precision 0 is valid in snowflake, even if it seems useless + return i == Num{} + } + // The abs call does nothing for this value, so we need to handle it properly + if i == MinInt128 { + return false + } + return Less(i.Abs(), Pow10Table[prec]) +} + +func scalePositiveFloat64(v float64, prec, scale int32) (float64, error) { + var pscale float64 + if scale >= -38 && scale <= 38 { + pscale = float64PowersOfTen[scale+38] + } else { + pscale = math.Pow10(int(scale)) + } + + v *= pscale + v = math.RoundToEven(v) + maxabs := float64PowersOfTen[prec+38] + if v <= -maxabs || v >= maxabs { + return 0, fmt.Errorf("cannot convert %f to Int128(precision=%d, scale=%d): overflow", v, prec, scale) + } + return v, nil +} + +func fromPositiveFloat64(v float64, prec, scale int32) (Num, error) { + v, err := scalePositiveFloat64(v, prec, scale) + if err != nil { + return Num{}, err + } + + hi := math.Floor(math.Ldexp(v, -64)) + low := v - math.Ldexp(hi, 64) + return Num{hi: int64(hi), lo: uint64(low)}, nil +} + +// this has to exist despite sharing some code with fromPositiveFloat64 +// because if we don't do the casts back to float32 in between each +// step, we end up with a significantly different answer! +// Aren't floating point values so much fun? +// +// example value to use: +// +// v := float32(1.8446746e+15) +// +// You'll end up with a different values if you do: +// +// FromFloat64(float64(v), 20, 4) +// +// vs +// +// FromFloat32(v, 20, 4) +// +// because float64(v) == 1844674629206016 rather than 1844674600000000 +func fromPositiveFloat32(v float32, prec, scale int32) (Num, error) { + val, err := scalePositiveFloat64(float64(v), prec, scale) + if err != nil { + return Num{}, err + } + + hi := float32(math.Floor(math.Ldexp(float64(float32(val)), -64))) + low := float32(val) - float32(math.Ldexp(float64(hi), 64)) + return Num{hi: int64(hi), lo: uint64(low)}, nil +} + +// FromFloat32 returns a new Int128 constructed from the given float32 +// value using the provided precision and scale. Will return an error if the +// value cannot be accurately represented with the desired precision and scale. +func FromFloat32(v float32, prec, scale int32) (Num, error) { + if v < 0 { + dec, err := fromPositiveFloat32(-v, prec, scale) + if err != nil { + return dec, err + } + return Neg(dec), nil + } + return fromPositiveFloat32(v, prec, scale) +} + +// FromFloat64 returns a new Int128 constructed from the given float64 +// value using the provided precision and scale. Will return an error if the +// value cannot be accurately represented with the desired precision and scale. +func FromFloat64(v float64, prec, scale int32) (Num, error) { + if v < 0 { + dec, err := fromPositiveFloat64(-v, prec, scale) + if err != nil { + return dec, err + } + return Neg(dec), nil + } + return fromPositiveFloat64(v, prec, scale) +} + +var ( + pt5 = big.NewFloat(0.5) +) + +// FromString converts a string into an Int128 as long as it fits within the given precision and scale. 
+func FromString(v string, prec, scale int32) (n Num, err error) { + n, err = fromStringFast(v, prec, scale) + if err != nil { + n, err = fromStringSlow(v, prec, scale) + } + return +} + +var errFallbackNeeded = errors.New("fallback to slowpath needed") + +// A parsing fast path +func fromStringFast(s string, prec, scale int32) (n Num, err error) { + sLen := int32(len(s)) + // Even though there could be decimal points or negative/positive signs + // we need to limit the length of the string to prevent overflow. + // + // Using numbers this large is probably rare anyways. + if sLen == 0 || sLen > 38 { + err = errFallbackNeeded + return + } + s0 := s + if s[0] == '-' || s[0] == '+' { + s = s[1:] + if len(s) == 0 { + err = errFallbackNeeded + return + } + } + + // The value between '.' - '0' + // we can't write that expression because + // go is strict about overflow in constants + const dotMinusZero = 254 + for i, ch := range []byte(s) { + ch -= '0' + if ch > 9 { + if ch == dotMinusZero { + s = s[i+1:] + goto fraction + } + return n, errFallbackNeeded + } + n = Add(Mul(n, ten), FromUint64(uint64(ch))) + } +finish: + if s0[0] == '-' { + n = Neg(n) + } + // Rescale validates the the new number fits within the precision + n, err = Rescale(n, prec, scale) + return +fraction: + for i, ch := range []byte(s) { + ch -= '0' + if ch > 9 { + return n, errFallbackNeeded + } + if scale == 0 { + // Round! + if ch >= 5 { + n = Add(n, one) + } + // We need to validate the rest of the number is valid + // ie is not scientific notation + for _, ch := range []byte(s[i+1:]) { + ch -= '0' + if ch > 9 { + return n, errFallbackNeeded + } + } + break + } + n = Add(Mul(n, ten), FromUint64(uint64(ch))) + scale-- + } + goto finish +} + +func fromStringSlow(v string, prec, scale int32) (n Num, err error) { + var out *big.Float + out, _, err = big.ParseFloat(v, 10, 128, big.ToNearestAway) + if err != nil { + return + } + + var ok bool + if scale < 0 { + var tmp big.Int + val, _ := out.Int(&tmp) + n, ok = bigInt(val) + if !ok { + err = fmt.Errorf("value out of range: %s", v) + return + } + n = Div(n, Pow10Table[-scale]) + } else { + p := (&big.Float{}).SetPrec(128).SetInt(Pow10Table[scale].bigInt()) + out = out.Mul(out, p) + var tmp big.Int + val, _ := out.Int(&tmp) + // Round by subtracting the whole number so we only have the + // fractional bit left, then compare it to 0.5, then adjust + // the whole number according to IEEE RoundTiesToAway rounding + // mode, which is to round away from zero if the fractional + // part is |>=0.5|. + p = p.SetInt(val) + out = out.Sub(out, p) + if out.Signbit() { + if out.Cmp(pt5) <= 0 { + val = val.Sub(val, big.NewInt(1)) + } + } else { + if out.Cmp(pt5) >= 0 { + val = val.Add(val, big.NewInt(1)) + } + } + n, ok = bigInt(val) + if !ok { + err = fmt.Errorf("value out of range: %s", v) + return + } + } + + if !n.FitsInPrecision(prec) { + err = fmt.Errorf("val %s doesn't fit in precision %d", n.String(), prec) + } + return +} + +// ToFloat32 returns a float32 value representative of this Int128, +// but with the given scale. +func (i Num) ToFloat32(scale int32) float32 { + return float32(i.ToFloat64(scale)) +} + +func float64Positive(n Num, scale int32) float64 { + const twoTo64 float64 = 1.8446744073709552e+19 + x := float64(n.hi) * twoTo64 + x += float64(n.lo) + if scale >= -38 && scale <= 38 { + return x * float64PowersOfTen[-scale+38] + } + + return x * math.Pow10(-int(scale)) +} + +// ToFloat64 returns a float64 value representative of this Int128, +// but with the given scale. 
+func (i Num) ToFloat64(scale int32) float64 { + if i.hi < 0 { + return -float64Positive(Neg(i), scale) + } + return float64Positive(i, scale) +} + +// Rescale returns a new number such that it is scaled to |scale| (the current +// scale is assumed to be zero). It also validates that the scaled value fits +// within the specified precision. +func Rescale(n Num, precision, scale int32) (out Num, err error) { + if !n.FitsInPrecision(precision - scale) { + err = fmt.Errorf("value (%s) out of range (precision=%d,scale=%d)", n.String(), precision, scale) + return + } + if scale == 0 { + out = n + return + } + out = Mul(n, Pow10Table[scale]) + return +} + +var ( + float64PowersOfTen = [...]float64{ + 1e-38, 1e-37, 1e-36, 1e-35, 1e-34, 1e-33, 1e-32, 1e-31, 1e-30, 1e-29, + 1e-28, 1e-27, 1e-26, 1e-25, 1e-24, 1e-23, 1e-22, 1e-21, 1e-20, 1e-19, + 1e-18, 1e-17, 1e-16, 1e-15, 1e-14, 1e-13, 1e-12, 1e-11, 1e-10, 1e-9, + 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, + 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, + 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, 1e20, 1e21, + 1e22, 1e23, 1e24, 1e25, 1e26, 1e27, 1e28, 1e29, 1e30, 1e31, + 1e32, 1e33, 1e34, 1e35, 1e36, 1e37, 1e38, + } +) diff --git a/internal/impl/snowflake/streaming/int128/decimal_test.go b/internal/impl/snowflake/streaming/int128/decimal_test.go new file mode 100644 index 0000000000..2db939ef06 --- /dev/null +++ b/internal/impl/snowflake/streaming/int128/decimal_test.go @@ -0,0 +1,471 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Functionality in this file was derived (with modifications) from +// arrow-go and it's decimal128 package. We currently don't use that +// package directly due to bugs in the implementation, but hopefully +// we can upstream some fixes from that and then remove this package. 
+ +package int128 + +import ( + "fmt" + "math" + "math/big" + "math/rand/v2" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func ulps64(actual, expected float64) int64 { + ulp := math.Nextafter(actual, math.Inf(1)) - actual + return int64(math.Abs((expected - actual) / ulp)) +} + +func ulps32(actual, expected float32) int64 { + ulp := math.Nextafter32(actual, float32(math.Inf(1))) - actual + return int64(math.Abs(float64((expected - actual) / ulp))) +} + +func assertFloat32Approx(t *testing.T, x, y float32) bool { + t.Helper() + const maxulps int64 = 4 + ulps := ulps32(x, y) + return assert.LessOrEqualf(t, ulps, maxulps, "%f not equal to %f (%d ulps)", x, y, ulps) +} + +func assertFloat64Approx(t *testing.T, x, y float64) bool { + t.Helper() + const maxulps int64 = 4 + ulps := ulps64(x, y) + return assert.LessOrEqualf(t, ulps, maxulps, "%f not equal to %f (%d ulps)", x, y, ulps) +} + +func TestDecimalToReal(t *testing.T) { + tests := []struct { + decimalVal string + scale int32 + exp float64 + }{ + {"0", 0, 0}, + {"0", 10, 0.0}, + {"0", -10, 0.0}, + {"1", 0, 1.0}, + {"12345", 0, 12345.0}, + {"12345", 1, 1234.5}, + // 2**62 + {"4611686018427387904", 0, math.Pow(2, 62)}, + // 2**63 + 2**62 + {"13835058055282163712", 0, math.Pow(2, 63) + math.Pow(2, 62)}, + // 2**64 + 2**62 + {"23058430092136939520", 0, math.Pow(2, 64) + math.Pow(2, 62)}, + // 10**38 - 2**103 + {"99999989858795198174164788026374356992", 0, math.Pow10(38) - math.Pow(2, 103)}, + } + + t.Run("float32", func(t *testing.T) { + checkDecimalToFloat := func(t *testing.T, str string, v float32, scale int32) { + bi, _ := (&big.Int{}).SetString(str, 10) + dec, ok := bigInt(bi) + assert.True(t, ok) + assert.Equalf(t, v, dec.ToFloat32(scale), "Decimal Val: %s, Scale: %d, Val: %s", str, scale, dec.String()) + } + for _, tt := range tests { + t.Run(tt.decimalVal, func(t *testing.T) { + checkDecimalToFloat(t, tt.decimalVal, float32(tt.exp), tt.scale) + if tt.decimalVal != "0" { + checkDecimalToFloat(t, "-"+tt.decimalVal, float32(-tt.exp), tt.scale) + } + }) + } + + t.Run("precision", func(t *testing.T) { + // 2**63 + 2**40 (exactly representable in a float's 24 bits of precision) + checkDecimalToFloat(t, "9223373136366403584", float32(9.223373e+18), 0) + checkDecimalToFloat(t, "-9223373136366403584", float32(-9.223373e+18), 0) + // 2**64 + 2**41 exactly representable in a float + checkDecimalToFloat(t, "18446746272732807168", float32(1.8446746e+19), 0) + checkDecimalToFloat(t, "-18446746272732807168", float32(-1.8446746e+19), 0) + }) + + t.Run("large values", func(t *testing.T) { + checkApproxDecimalToFloat := func(str string, v float32, scale int32) { + bi, _ := (&big.Int{}).SetString(str, 10) + dec, ok := bigInt(bi) + assert.True(t, ok) + assertFloat32Approx(t, v, dec.ToFloat32(scale)) + } + // exact comparisons would succeed on most platforms, but not all power-of-ten + // factors are exactly representable in binary floating point, so we'll use + // approx and ensure that the values are within 4 ULP (unit of least precision) + for scale := int32(-38); scale <= 38; scale++ { + checkApproxDecimalToFloat("1", float32(math.Pow10(-int(scale))), scale) + checkApproxDecimalToFloat("123", float32(123)*float32(math.Pow10(-int(scale))), scale) + } + }) + }) + + t.Run("float64", func(t *testing.T) { + checkDecimalToFloat := func(t *testing.T, str string, v float64, scale int32) { + bi, _ := (&big.Int{}).SetString(str, 10) + dec, ok := bigInt(bi) + assert.True(t, ok) + 
assert.Equalf(t, v, dec.ToFloat64(scale), "Decimal Val: %s, Scale: %d", str, scale) + } + for _, tt := range tests { + t.Run(tt.decimalVal, func(t *testing.T) { + checkDecimalToFloat(t, tt.decimalVal, tt.exp, tt.scale) + if tt.decimalVal != "0" { + checkDecimalToFloat(t, "-"+tt.decimalVal, -tt.exp, tt.scale) + } + }) + } + + t.Run("precision", func(t *testing.T) { + // 2**63 + 2**11 (exactly representable in float64's 53 bits of precision) + checkDecimalToFloat(t, "9223373136366403584", float64(9.223373136366404e+18), 0) + checkDecimalToFloat(t, "-9223373136366403584", float64(-9.223373136366404e+18), 0) + + // 2**64 - 2**11 (exactly representable in a float64) + checkDecimalToFloat(t, "18446746272732807168", float64(1.8446746272732807e+19), 0) + checkDecimalToFloat(t, "-18446746272732807168", float64(-1.8446746272732807e+19), 0) + + // 2**64 + 2**11 (exactly representable in a float64) + checkDecimalToFloat(t, "18446744073709555712", float64(1.8446744073709556e+19), 0) + checkDecimalToFloat(t, "-18446744073709555712", float64(-1.8446744073709556e+19), 0) + + // Almost 10**38 (minus 2**73) + checkDecimalToFloat(t, "99999999999999978859343891977453174784", 9.999999999999998e+37, 0) + checkDecimalToFloat(t, "-99999999999999978859343891977453174784", -9.999999999999998e+37, 0) + checkDecimalToFloat(t, "99999999999999978859343891977453174784", 9.999999999999998e+27, 10) + checkDecimalToFloat(t, "-99999999999999978859343891977453174784", -9.999999999999998e+27, 10) + checkDecimalToFloat(t, "99999999999999978859343891977453174784", 9.999999999999998e+47, -10) + checkDecimalToFloat(t, "-99999999999999978859343891977453174784", -9.999999999999998e+47, -10) + }) + + t.Run("large values", func(t *testing.T) { + checkApproxDecimalToFloat := func(str string, v float64, scale int32) { + bi, _ := (&big.Int{}).SetString(str, 10) + dec, ok := bigInt(bi) + assert.True(t, ok) + assertFloat64Approx(t, v, dec.ToFloat64(scale)) + } + // exact comparisons would succeed on most platforms, but not all power-of-ten + // factors are exactly representable in binary floating point, so we'll use + // approx and ensure that the values are within 4 ULP (unit of least precision) + for scale := int32(-308); scale <= 306; scale++ { + checkApproxDecimalToFloat("1", math.Pow10(-int(scale)), scale) + checkApproxDecimalToFloat("123", float64(123)*math.Pow10(-int(scale)), scale) + } + }) + }) +} + +func TestDecimalFromFloat(t *testing.T) { + tests := []struct { + val float64 + precision, scale int32 + expected string + }{ + {0, 1, 0, "0"}, + {-0, 1, 0, "0"}, + {0, 19, 4, "0.0000"}, + {math.Copysign(0.0, -1), 19, 4, "0.0000"}, + {123, 7, 4, "123.0000"}, + {-123, 7, 4, "-123.0000"}, + {456.78, 7, 4, "456.7800"}, + {-456.78, 7, 4, "-456.7800"}, + {456.784, 5, 2, "456.78"}, + {-456.784, 5, 2, "-456.78"}, + {456.786, 5, 2, "456.79"}, + {-456.786, 5, 2, "-456.79"}, + {999.99, 5, 2, "999.99"}, + {-999.99, 5, 2, "-999.99"}, + {123, 19, 0, "123"}, + {-123, 19, 0, "-123"}, + {123.4, 19, 0, "123"}, + {-123.4, 19, 0, "-123"}, + {123.6, 19, 0, "124"}, + {-123.6, 19, 0, "-124"}, + // 2**62 + {4.611686018427387904e+18, 19, 0, "4611686018427387904"}, + {-4.611686018427387904e+18, 19, 0, "-4611686018427387904"}, + // 2**63 + {9.223372036854775808e+18, 19, 0, "9223372036854775808"}, + {-9.223372036854775808e+18, 19, 0, "-9223372036854775808"}, + // 2**64 + {1.8446744073709551616e+19, 20, 0, "18446744073709551616"}, + {-1.8446744073709551616e+19, 20, 0, "-18446744073709551616"}, + } + + t.Run("float64", func(t *testing.T) { + for _, tt := range 
tests { + t.Run(tt.expected, func(t *testing.T) { + n, err := FromFloat64(tt.val, tt.precision, tt.scale) + assert.NoError(t, err) + + assert.Equal(t, tt.expected, big.NewFloat(n.ToFloat64(tt.scale)).Text('f', int(tt.scale))) + }) + } + + t.Run("large values", func(t *testing.T) { + // test entire float64 range + for scale := int32(-308); scale <= 308; scale++ { + val := math.Pow10(int(scale)) + n, err := FromFloat64(val, 1, -scale) + assert.NoError(t, err) + assert.Equal(t, "1", n.bigInt().String()) + } + + for scale := int32(-307); scale <= 306; scale++ { + val := 123 * math.Pow10(int(scale)) + n, err := FromFloat64(val, 2, -scale-1) + assert.NoError(t, err) + assert.Equal(t, "12", n.bigInt().String()) + n, err = FromFloat64(val, 3, -scale) + assert.NoError(t, err) + assert.Equal(t, "123", n.bigInt().String()) + n, err = FromFloat64(val, 4, -scale+1) + assert.NoError(t, err) + assert.Equal(t, "1230", n.bigInt().String()) + } + }) + }) + + t.Run("float32", func(t *testing.T) { + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + n, err := FromFloat32(float32(tt.val), tt.precision, tt.scale) + assert.NoError(t, err) + + assert.Equal(t, tt.expected, big.NewFloat(float64(n.ToFloat32(tt.scale))).Text('f', int(tt.scale))) + }) + } + + t.Run("large values", func(t *testing.T) { + // test entire float32 range + for scale := int32(-38); scale <= 38; scale++ { + val := float32(math.Pow10(int(scale))) + n, err := FromFloat32(val, 1, -scale) + assert.NoError(t, err) + assert.Equal(t, "1", n.bigInt().String()) + } + + for scale := int32(-37); scale <= 36; scale++ { + val := 123 * float32(math.Pow10(int(scale))) + n, err := FromFloat32(val, 2, -scale-1) + assert.NoError(t, err) + assert.Equal(t, "12", n.bigInt().String()) + n, err = FromFloat32(val, 3, -scale) + assert.NoError(t, err) + assert.Equal(t, "123", n.bigInt().String()) + n, err = FromFloat32(val, 4, -scale+1) + assert.NoError(t, err) + assert.Equal(t, "1230", n.bigInt().String()) + } + }) + }) +} + +func TestFromString(t *testing.T) { + tests := []struct { + s string + expected int64 + expectedScale int32 + }{ + {"12.3", 123, 1}, + {"0.00123", 123, 5}, + {"1.23e-8", 123, 10}, + {"-1.23E-8", -123, 10}, + {"1.23e+3", 1230, 0}, + {"-1.23E+3", -1230, 0}, + {"1.23e+5", 123000, 0}, + {"1.2345E+7", 12345000, 0}, + {"1.23e-8", 123, 10}, + {"-1.23E-8", -123, 10}, + {"1.23E+3", 1230, 0}, + {"-1.23e+3", -1230, 0}, + {"1.23e+5", 123000, 0}, + {"1.2345e+7", 12345000, 0}, + {"0000000", 0, 0}, + {"000.0000", 0, 4}, + {".00000", 0, 5}, + {"1e1", 10, 0}, + {"+234.567", 234567, 3}, + {"1e-37", 1, 37}, + {"2112.33", 211233, 2}, + {"-2112.33", -211233, 2}, + {"12E2", 12, -2}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("%s_%d", tt.s, tt.expectedScale), func(t *testing.T) { + n, err := FromString(tt.s, 37, tt.expectedScale) + assert.NoError(t, err) + + ex := FromInt64(tt.expected) + assert.Equal(t, ex, n, "got: %s, want: %d", n.String(), tt.expected) + }) + } +} + +func TestFromStringFast(t *testing.T) { + tests := []string{ + "0", + "0924535.11610", + "480754368.9554427", + "1", + "11", + "11.1", + "12345.12345", + "999999999999999999999999999999999999.9", + } + + for _, str := range tests { + str := str + digitCount, leadingDigits := computeDecimalParameters(str) + t.Run(str, func(t *testing.T) { + cases := 0 + for prec := int32(38); prec >= digitCount; prec-- { + maxScale := prec - leadingDigits + for scale := maxScale; scale >= 0; scale-- { + actual, actualErr := fromStringFast(str, prec, scale) + assert.NoError(t, actualErr) 
+ expected, expectedErr := fromStringSlow(str, prec, scale) + assert.NoError(t, expectedErr) + assert.Equal( + t, + expected, + actual, + "NUMBER(%d, %d): want: %s, got: %s", + prec, scale, + expected.String(), + actual.String(), + ) + cases++ + } + } + }) + } + // Try to stress some edge cases where we could overflow but result in something + // valid after + t.Run("OverflowEdgeCase", func(t *testing.T) { + v, err := fromStringFast(strings.Repeat("9", 40), 38, 0) + assert.Error(t, err, "got: %v", v) + v, err = fromStringFast(strings.Repeat("9", 40), 38, 37) + assert.Error(t, err, "got: %v", v) + v, err = fromStringFast(strings.Repeat("9", 40), 38, 38) + assert.Error(t, err, "got: %v", v) + v, err = fromStringFast("9"+strings.Repeat("0", 39), 38, 0) + assert.Error(t, err, "got: %v", v) + v, err = fromStringFast("9"+strings.Repeat("0", 39), 38, 37) + assert.Error(t, err, "got: %v", v) + v, err = fromStringFast("9"+strings.Repeat("0", 39), 38, 38) + assert.Error(t, err, "got: %v", v) + v, err = fromStringFast("76063353390654101946871725586039877751.7", 38, 1) + assert.Error(t, err, "got: %v", v) + v, err = fromStringFast("99999999999999999999999999999999999999.9", 38, 1) + assert.Error(t, err, "got: %v", v) + v, err = fromStringFast("999999999999999999999999999999999999.9", 38, 3) + assert.Error(t, err, "got: %v", v) + for i := 1; i <= 38; i++ { + v, err = fromStringFast(strings.Repeat("9", 38), 38, int32(i)) + assert.Error(t, err, "got: %v", v) + } + }) +} + +func TestFromStringFastVsSlowRandomized(t *testing.T) { + for i := 0; i < 1000; i++ { + precision := rand.N(36) + 2 + scale := rand.N(precision - 1) + str := "" + for j := 0; j < precision; j++ { + str += strconv.Itoa(rand.N(10)) + } + if scale > 0 { + str += "." + for j := 0; j < scale; j++ { + str += strconv.Itoa(rand.N(10)) + } + } + fastN, fastErr := fromStringFast(str, int32(precision), int32(scale)) + if fastErr == errFallbackNeeded { + continue + } + slowN, slowErr := fromStringSlow(str, int32(precision), int32(scale)) + require.Equal(t, slowErr == nil, fastErr == nil, "%s (scale=%d,precision=%d): slowErr=%v, fastErr=%v", str, scale, precision, slowErr, fastErr) + if slowErr == nil && fastErr == nil { + require.Equal(t, fastN, slowN, "%s (scale=%d,precision=%d): %s vs %s", str, scale, precision, fastN, slowN) + } + } +} + +func BenchmarkParsing(b *testing.B) { + tests := []string{ + "1", + "11", + "11.1", + "12345.12345", + "99999999999999999999999999999999999999", + "-9999999999999999999999999999999999999", + "1234567890.1234567890", + } + for _, test := range tests { + test := test + digitCount, leadingDigits := computeDecimalParameters(test) + scale := digitCount - leadingDigits + b.Run("fast_"+test, func(b *testing.B) { + b.SetBytes(int64(len(test))) + for i := 0; i < b.N; i++ { + _, err := fromStringFast(test, digitCount, scale) + if err != nil { + b.Fatal(err) + } + } + }) + b.Run("slow_"+test, func(b *testing.B) { + b.SetBytes(int64(len(test))) + for i := 0; i < b.N; i++ { + _, err := fromStringSlow(test, digitCount, scale) + if err != nil { + b.Fatal(err) + } + } + }) + } +} + +func computeDecimalParameters(str string) (digitCount int32, leadingDigits int32) { + foundFraction := false + for _, r := range str { + if r == '.' 
{ + foundFraction = true + continue + } + if r != '-' { + digitCount++ + if !foundFraction { + leadingDigits++ + } + } + } + return +} diff --git a/internal/impl/snowflake/streaming/int128/division.go b/internal/impl/snowflake/streaming/int128/division.go new file mode 100644 index 0000000000..95b37a7dd5 --- /dev/null +++ b/internal/impl/snowflake/streaming/int128/division.go @@ -0,0 +1,78 @@ +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The algorithm here is ported from absl so we attribute changes in this file +// under the same license, even though it's golang. + +package int128 + +// Div computes a / b +// +// Division by zero panics +func Div(dividend, divisor Num) Num { + // algorithm is ported from absl::int128 + if divisor == (Num{}) { + panic("int128 division by zero") + } + negateQuotient := (dividend.hi < 0) != (divisor.hi < 0) + if dividend.IsNegative() { + dividend = Neg(dividend) + } + if divisor.IsNegative() { + divisor = Neg(divisor) + } + if divisor == dividend { + return FromInt64(1) + } + if uGt(divisor, dividend) { + return Num{} + } + denominator := divisor + var quotient Num + shift := fls128(dividend) - fls128(denominator) + denominator = Shl(denominator, uint(shift)) + // Uses shift-subtract algorithm to divide dividend by denominator. The + // remainder will be left in dividend. + for i := 0; i <= shift; i++ { + quotient = Shl(quotient, 1) + if uGt(dividend, denominator) { + dividend = Sub(dividend, denominator) + quotient = Or(quotient, FromInt64(1)) + } + denominator = uShr(denominator, 1) + } + if negateQuotient { + quotient = Neg(quotient) + } + return quotient +} + +// uShr is unsigned shift right (no sign extending) +func uShr(v Num, amt uint) Num { + n := amt - 64 + m := 64 - amt + return Num{ + hi: int64(uint64(v.hi) >> amt), + lo: v.lo>>amt | uint64(v.hi)>>n | uint64(v.hi)<= b.lo + } else { + return uint64(a.hi) >= uint64(b.hi) + } +} diff --git a/internal/impl/snowflake/streaming/int128/int128.go b/internal/impl/snowflake/streaming/int128/int128.go new file mode 100644 index 0000000000..3eedb8fd97 --- /dev/null +++ b/internal/impl/snowflake/streaming/int128/int128.go @@ -0,0 +1,373 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +// package int128 contains an implmentation of int128 that is more +// efficent (no allocations) compared to math/big.Int +// +// Several Snowflake data types are under the hood int128 (date/time), +// so we can use this type and not hurt performance. 
+package int128 + +import ( + "encoding/binary" + "fmt" + "math" + "math/big" + "math/bits" +) + +// Common constant values for int128 +var ( + MaxInt128 = FromBigEndian([]byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) + MinInt128 = FromBigEndian([]byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + MaxInt64 = FromInt64(math.MaxInt64) + MinInt64 = FromInt64(math.MinInt64) + MaxInt32 = FromInt64(math.MaxInt32) + MinInt32 = FromInt64(math.MinInt32) + MaxInt16 = FromInt64(math.MaxInt16) + MinInt16 = FromInt64(math.MinInt16) + MaxInt8 = FromInt64(math.MaxInt8) + MinInt8 = FromInt64(math.MinInt8) + one = FromUint64(1) + ten = FromUint64(10) + + // For Snowflake, we need to do some quick multiplication to scale numbers + // to make that fast we precompute some powers of 10 in a lookup table. + Pow10Table = [...]Num{ + FromUint64(1e00), + FromUint64(1e01), + FromUint64(1e02), + FromUint64(1e03), + FromUint64(1e04), + FromUint64(1e05), + FromUint64(1e06), + FromUint64(1e07), + FromUint64(1e08), + FromUint64(1e09), + FromUint64(1e10), + FromUint64(1e11), + FromUint64(1e12), + FromUint64(1e13), + FromUint64(1e14), + FromUint64(1e15), + FromUint64(1e16), + FromUint64(1e17), + FromUint64(1e18), + FromUint64(1e19), + New(5, 7766279631452241920), + New(54, 3875820019684212736), + New(542, 1864712049423024128), + New(5421, 200376420520689664), + New(54210, 2003764205206896640), + New(542101, 1590897978359414784), + New(5421010, 15908979783594147840), + New(54210108, 11515845246265065472), + New(542101086, 4477988020393345024), + New(5421010862, 7886392056514347008), + New(54210108624, 5076944270305263616), + New(542101086242, 13875954555633532928), + New(5421010862427, 9632337040368467968), + New(54210108624275, 4089650035136921600), + New(542101086242752, 4003012203950112768), + New(5421010862427522, 3136633892082024448), + New(54210108624275221, 12919594847110692864), + New(542101086242752217, 68739955140067328), + New(5421010862427522170, 687399551400673280), + } +) + +// Num is a *signed* int128 type that is more efficent than big.Int +// +// Default value is 0 +type Num struct { + hi int64 + lo uint64 +} + +// New constructs an Int128 from two 64 bit integers. +func New(hi int64, lo uint64) Num { + return Num{ + hi: hi, + lo: lo, + } +} + +// FromInt64 casts an signed int64 to uint128 +func FromInt64(v int64) Num { + hi := int64(0) + // sign extend + if v < 0 { + hi = ^hi + } + return Num{ + hi: hi, + lo: uint64(v), + } +} + +// FromUint64 casts an unsigned int64 to uint128 +func FromUint64(v uint64) Num { + return Num{ + hi: 0, + lo: v, + } +} + +// Add computes a + b +func Add(a, b Num) Num { + lo, carry := bits.Add64(a.lo, b.lo, 0) + hi, _ := bits.Add64(uint64(a.hi), uint64(b.hi), carry) + return Num{int64(hi), lo} +} + +// Sub computes a - b +func Sub(a, b Num) Num { + lo, carry := bits.Sub64(a.lo, b.lo, 0) + hi, _ := bits.Sub64(uint64(a.hi), uint64(b.hi), carry) + return Num{int64(hi), lo} +} + +// Mul computes a * b +func Mul(a, b Num) Num { + hi, lo := bits.Mul64(a.lo, b.lo) + hi += (uint64(a.hi) * b.lo) + (a.lo * uint64(b.hi)) + return Num{hi: int64(hi), lo: lo} +} + +func fls128(n Num) int { + if n.hi != 0 { + return 127 - bits.LeadingZeros64(uint64(n.hi)) + } + return 64 - bits.LeadingZeros64(n.lo) +} + +// Neg computes -v +func Neg(n Num) Num { + n.lo = ^n.lo + 1 + n.hi = ^n.hi + if n.lo == 0 { + n.hi += 1 + } + return n +} + +// Abs computes v < 0 ? 
-v : v +func (i Num) Abs() Num { + if i.IsNegative() { + return Neg(i) + } + return i +} + +// IsNegative returns true if `i` is negative +func (i Num) IsNegative() bool { + return i.hi < 0 +} + +// Shl returns a << i +func Shl(v Num, amt uint) Num { + n := amt - 64 + m := 64 - amt + return Num{ + hi: v.hi<>m), + lo: v.lo << amt, + } +} + +// Or returns a | i +func Or(a Num, b Num) Num { + return Num{ + hi: a.hi | b.hi, + lo: a.lo | b.lo, + } +} + +// Less returns a < b +func Less(a, b Num) bool { + if a.hi == b.hi { + return a.lo < b.lo + } else { + return a.hi < b.hi + } +} + +// Greater returns a > b +func Greater(a, b Num) bool { + if a.hi == b.hi { + return a.lo > b.lo + } else { + return a.hi > b.hi + } +} + +// FromBigEndian converts bi endian bytes to Int128 +func FromBigEndian(b []byte) Num { + hi := int64(binary.BigEndian.Uint64(b[0:8])) + lo := binary.BigEndian.Uint64(b[8:16]) + return Num{ + hi: hi, + lo: lo, + } +} + +// ToBigEndian converts an Int128 into big endian bytes +func (i Num) ToBigEndian() []byte { + b := make([]byte, 16) + binary.BigEndian.PutUint64(b[0:8], uint64(i.hi)) + binary.BigEndian.PutUint64(b[8:16], i.lo) + return b +} + +// AppendBigEndian converts an Int128 into big endian bytes +func (i Num) AppendBigEndian(b []byte) []byte { + b = binary.BigEndian.AppendUint64(b[0:8], uint64(i.hi)) + return binary.BigEndian.AppendUint64(b[8:16], i.lo) +} + +// ToInt64 casts an Int128 to a int64 by truncating the bytes. +func (i Num) ToInt64() int64 { + return int64(i.lo) +} + +// ToInt32 casts an Int128 to a int32 by truncating the bytes. +func (i Num) ToInt32() int32 { + return int32(i.lo) +} + +// ToInt16 casts an Int128 to a int16 by truncating the bytes. +func (i Num) ToInt16() int16 { + return int16(i.lo) +} + +// ToInt8 casts an Int128 to a int8 by truncating the bytes. +func (i Num) ToInt8() int8 { + return int8(i.lo) +} + +// Min computes min(a, b) +func Min(a, b Num) Num { + if Less(a, b) { + return a + } else { + return b + } +} + +// Max computes min(a, b) +func Max(a, b Num) Num { + if Greater(a, b) { + return a + } else { + return b + } +} + +// MustParse converted a base 10 formatted string into an Int128 +// and panics otherwise +// +// Only use for testing. +func MustParse(str string) Num { + n, ok := Parse(str) + if !ok { + panic(fmt.Sprintf("unable to parse %q into Int128", str)) + } + return n +} + +// Parse converted a base 10 formatted string into an Int128 +// +// Not fast, but simple +func Parse(str string) (n Num, ok bool) { + var bi *big.Int + bi, ok = big.NewInt(0).SetString(str, 10) + if !ok { + return + } + return bigInt(bi) +} + +// String returns the number as base 10 formatted string. +// +// This is not fast but it isn't on a hot path. +func (i Num) String() string { + return string(i.bigInt().Append(nil, 10)) +} + +// MarshalJSON implements JSON serialization of +// an int128 like BigInteger in the Snowflake +// Java SDK with Jackson. +// +// This is not fast but it isn't on a hot path. +func (i Num) MarshalJSON() ([]byte, error) { + return i.bigInt().Append(nil, 10), nil +} + +func (i Num) bigInt() *big.Int { + hi := big.NewInt(i.hi) // Preserves sign + hi = hi.Lsh(hi, 64) + lo := &big.Int{} + lo.SetUint64(i.lo) + return hi.Or(hi, lo) +} + +var ( + maxBigInt128 = MaxInt128.bigInt() + minBigInt128 = MinInt128.bigInt() +) + +func bigInt(bi *big.Int) (n Num, ok bool) { + // One cannot check BitLen here because that misses that MinInt128 + // requires 128 bits along with other out of range values. 
Instead + // the better check is to explicitly compare our allowed bounds + ok = bi.Cmp(minBigInt128) >= 0 && bi.Cmp(maxBigInt128) <= 0 + if !ok { + return + } + b := bi.Bits() + if len(b) == 0 { + return + } + n.lo = uint64(b[0]) + if len(b) > 1 { + n.hi = int64(b[1]) + } + if bi.Sign() < 0 { + n = Neg(n) + } + return +} + +// ByteWidth returns the maximum number of bytes needed to store v +func ByteWidth(v Num) int { + if v.IsNegative() { + switch { + case !Less(v, MinInt8): + return 1 + case !Less(v, MinInt16): + return 2 + case !Less(v, MinInt32): + return 4 + case !Less(v, MinInt64): + return 8 + } + return 16 + } + switch { + case !Greater(v, MaxInt8): + return 1 + case !Greater(v, MaxInt16): + return 2 + case !Greater(v, MaxInt32): + return 4 + case !Greater(v, MaxInt64): + return 8 + } + return 16 +} diff --git a/internal/impl/snowflake/streaming/int128/int128_test.go b/internal/impl/snowflake/streaming/int128/int128_test.go new file mode 100644 index 0000000000..765b6ce731 --- /dev/null +++ b/internal/impl/snowflake/streaming/int128/int128_test.go @@ -0,0 +1,390 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package int128 + +import ( + "fmt" + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAdd(t *testing.T) { + require.Equal(t, MinInt128, Add(MaxInt128, FromInt64(1))) + require.Equal(t, MaxInt128, Add(MinInt128, FromInt64(-1))) + require.Equal(t, FromInt64(2), Add(FromInt64(1), FromInt64(1))) + require.Equal( + t, + FromBigEndian([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE}), + Add(FromUint64(math.MaxUint64), FromUint64(math.MaxUint64)), + ) + require.Equal( + t, + FromBigEndian([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + Add(FromInt64(math.MaxInt64), FromInt64(1)), + ) + require.Equal( + t, + FromBigEndian([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + Add(FromUint64(math.MaxUint64), FromInt64(1)), + ) +} + +func TestSub(t *testing.T) { + require.Equal(t, MaxInt128, Sub(MinInt128, FromInt64(1))) + require.Equal(t, MinInt128, Sub(MaxInt128, FromInt64(-1))) + require.Equal( + t, + FromBigEndian([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}), + Sub(FromInt64(0), FromInt64(math.MaxInt64)), + ) + require.Equal( + t, + FromBigEndian([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}), + Sub(FromInt64(0), FromUint64(math.MaxUint64)), + ) +} + +func SlowMul(a Num, b Num) Num { + delta := FromInt64(-1) + deltaFn := Add + if Less(b, FromInt64(0)) { + delta = FromInt64(1) + deltaFn = Sub + } + r := FromInt64(0) + for i := b; i != FromInt64(0); i = Add(i, delta) { + r = deltaFn(r, a) + } + return r +} + +func TestMul(t *testing.T) { + tc := [][2]Num{ + {FromInt64(10), FromInt64(10)}, + {FromInt64(1), FromInt64(10)}, + {FromInt64(0), FromInt64(10)}, + {FromInt64(0), FromInt64(0)}, + {FromInt64(math.MaxInt64), FromInt64(0)}, + {FromInt64(math.MaxInt64), FromInt64(1)}, + {FromInt64(math.MaxInt64), FromInt64(2)}, + {FromInt64(math.MaxInt64), FromInt64(3)}, + 
{FromInt64(math.MaxInt64), FromInt64(4)}, + {FromInt64(math.MaxInt64), FromInt64(10)}, + {FromUint64(math.MaxUint64), FromInt64(10)}, + {FromUint64(math.MaxUint64), FromInt64(2)}, + {FromUint64(math.MaxUint64), FromInt64(100)}, + {MaxInt128, FromInt64(100)}, + {MaxInt128, FromInt64(10)}, + {MinInt128, FromInt64(10)}, + {MinInt128, FromInt64(-1)}, + {MaxInt128, FromInt64(-1)}, + {FromInt64(-1), FromInt64(-1)}, + } + for _, c := range tc { + a, b := c[0], c[1] + expected := SlowMul(a, b) + actual := Mul(a, b) + require.Equal( + t, + expected, + actual, + "%s x %s, got: %s, want: %s", + a.String(), + b.String(), + actual.String(), + expected.String(), + ) + actual = Mul(b, a) + require.Equal( + t, + expected, + actual, + "%s x %s, got: %s, want: %s", + b.String(), + a.String(), + actual.String(), + expected.String(), + ) + } +} + +func TestShl(t *testing.T) { + for i := uint(0); i < 64; i++ { + require.Equal(t, Num{lo: 1 << i}, Shl(FromInt64(1), i)) + require.Equal(t, Num{hi: 1 << i}, Shl(FromInt64(1), i+64)) + require.Equal(t, Num{hi: ^0, lo: uint64(int64(-1) << i)}, Shl(FromInt64(-1), i)) + require.Equal(t, Num{hi: -1 << i}, Shl(FromInt64(-1), i+64)) + } + require.Equal(t, Num{}, Shl(FromInt64(1), 128)) + require.Equal(t, Num{}, Shl(FromInt64(-1), 128)) +} + +func TestUshr(t *testing.T) { + for i := uint(0); i < 64; i++ { + require.Equal(t, Num{hi: int64(uint64(1<<63) >> i)}, uShr(MinInt128, i), i) + require.Equal(t, Num{lo: (1 << 63) >> i}, uShr(MinInt128, i+64), i) + } + require.Equal(t, Num{}, uShr(MinInt128, 128)) + require.Equal(t, Num{}, uShr(FromInt64(-1), 128)) +} + +func TestNeg(t *testing.T) { + require.Equal(t, FromInt64(-1), Neg(FromInt64(1))) + require.Equal(t, FromInt64(1), Neg(FromInt64(-1))) + require.Equal(t, Sub(FromInt64(0), MaxInt64), Neg(MaxInt64)) + require.Equal(t, Add(MinInt128, FromInt64(1)), Neg(MaxInt128)) + require.Equal(t, MinInt128, Neg(MinInt128)) +} + +func TestDiv(t *testing.T) { + type TestCase struct { + dividend, divisor, quotient Num + } + cases := []TestCase{ + {FromInt64(100), FromInt64(10), FromInt64(10)}, + {FromInt64(64), FromInt64(8), FromInt64(8)}, + {FromInt64(10), FromInt64(3), FromInt64(3)}, + {FromInt64(99), FromInt64(25), FromInt64(3)}, + { + FromInt64(0x15f2a64138), + FromInt64(0x67da05), + FromInt64(0x15f2a64138 / 0x67da05), + }, + { + FromInt64(0x5e56d194af43045f), + FromInt64(0xcf1543fb99), + FromInt64(0x5e56d194af43045f / 0xcf1543fb99), + }, + { + FromInt64(0x15e61ed052036a), + FromInt64(-0xc8e6), + FromInt64(0x15e61ed052036a / -0xc8e6), + }, + { + FromInt64(0x88125a341e85), + FromInt64(-0xd23fb77683), + FromInt64(0x88125a341e85 / -0xd23fb77683), + }, + { + FromInt64(-0xc06e20), + FromInt64(0x5a), + FromInt64(-0xc06e20 / 0x5a), + }, + { + FromInt64(-0x4f100219aea3e85d), + FromInt64(0xdcc56cb4efe993), + FromInt64(-0x4f100219aea3e85d / 0xdcc56cb4efe993), + }, + { + FromInt64(-0x168d629105), + FromInt64(-0xa7), + FromInt64(-0x168d629105 / -0xa7), + }, + { + FromInt64(-0x7b44e92f03ab2375), + FromInt64(-0x6516), + FromInt64(-0x7b44e92f03ab2375 / -0x6516), + }, + { + Num{0x6ada48d489007966, 0x3c9c5c98150d5d69}, + Num{0x8bc308fb, 0x8cb9cc9a3b803344}, + FromInt64(0xc3b87e08), + }, + { + Num{0xd6946511b5b, 0x4886c5c96546bf5f}, + Neg(Num{0x263b, 0xfd516279efcfe2dc}), + FromInt64(-0x59cbabf0), + }, + { + Neg(Num{0x33db734f9e8d1399, 0x8447ac92482bca4d}), + FromInt64(0x37495078240), + Neg(Num{0xf01f1, 0xbc0368bf9a77eae8}), + }, + { + Neg(Num{0x13f837b409a07e7d, 0x7fc8e248a7d73560}), + FromInt64(-0x1b9f), + Num{0xb9157556d724, 0xb14f635714d7563e}, + 
}, + } + for _, c := range cases { + c := c + t.Run("", func(t *testing.T) { + require.Equal( + t, + c.quotient, + Div(c.dividend, c.divisor), + "%s / %s = %s", + c.dividend, + c.divisor, + c.quotient, + ) + }) + } +} + +func TestPow10(t *testing.T) { + expected := FromInt64(1) + for _, v := range Pow10Table { + require.Equal(t, expected, v) + expected = Mul(expected, FromInt64(10)) + } +} + +func TestCompare(t *testing.T) { + tc := [][2]Num{ + {FromInt64(0), FromInt64(1)}, + {FromInt64(-1), FromInt64(0)}, + {MinInt128, FromInt64(0)}, + {MinInt128, FromInt64(-1)}, + {MinInt128, FromInt64(math.MinInt64)}, + {MinInt128, FromUint64(math.MaxUint64)}, + {MinInt128, MaxInt128}, + {FromInt64(0), MaxInt128}, + {FromInt64(-1), MaxInt128}, + {FromInt64(math.MinInt64), MaxInt128}, + {FromInt64(math.MaxInt64), MaxInt128}, + {FromUint64(math.MaxUint64), MaxInt128}, + } + for _, vals := range tc { + a, b := vals[0], vals[1] + require.True(t, Less(a, b)) + require.False(t, Less(b, a)) + require.True(t, Greater(b, a)) + require.False(t, Greater(a, b)) + require.NotEqual(t, a, b) + require.Equal(t, a, a) + require.Equal(t, b, b) + } + require.Equal(t, FromInt64(0), FromInt64(0)) + require.NotEqual(t, FromInt64(1), FromInt64(0)) + require.Equal(t, Shl(FromInt64(1), 64), Add(FromUint64(math.MaxUint64), FromInt64(1))) +} + +func TestParse(t *testing.T) { + for _, expected := range [...]Num{ + MinInt128, + MaxInt128, + FromInt64(0), + FromInt64(-1), + FromInt64(1), + MinInt8, + MaxInt8, + MinInt16, + MaxInt16, + MinInt32, + MaxInt32, + MinInt64, + MaxInt64, + Add(MaxInt64, FromUint64(1)), + } { + actual, ok := Parse(expected.String()) + require.True(t, ok, "%s", expected) + require.Equal(t, expected, actual) + } + // One less than min + _, ok := Parse("-170141183460469231731687303715884105729") + require.False(t, ok) + // One more than max + _, ok = Parse("170141183460469231731687303715884105728") + require.False(t, ok) +} + +func TestString(t *testing.T) { + require.Equal(t, "-170141183460469231731687303715884105728", MinInt128.String()) + require.Equal(t, "170141183460469231731687303715884105727", MaxInt128.String()) +} + +func TestByteWidth(t *testing.T) { + tests := [][2]int64{ + {0, 1}, + {1, 1}, + {-1, 1}, + {-16, 1}, + {16, 1}, + {math.MaxInt8 - 1, 1}, + {math.MaxInt8, 1}, + {math.MaxInt8 + 1, 2}, + {math.MinInt8 - 1, 2}, + {math.MinInt8, 1}, + {math.MinInt8 + 1, 1}, + {math.MaxInt16 - 1, 2}, + {math.MaxInt16, 2}, + {math.MaxInt16 + 1, 4}, + {math.MinInt16 - 1, 4}, + {math.MinInt16, 2}, + {math.MinInt16 + 1, 2}, + {math.MaxInt32 - 1, 4}, + {math.MaxInt32, 4}, + {math.MaxInt32 + 1, 8}, + {math.MinInt32 - 1, 8}, + {math.MinInt32, 4}, + {math.MinInt32 + 1, 4}, + {math.MaxInt64 - 1, 8}, + {math.MaxInt64, 8}, + // {math.MaxInt64 + 1, 8}, + // {math.MinInt64 - 1, 8}, + {math.MinInt64, 8}, + {math.MinInt64 + 1, 8}, + } + for _, tc := range tests { + tc := tc + t.Run(fmt.Sprintf("byteWidth(%d)", tc[0]), func(t *testing.T) { + require.Equal(t, int(tc[1]), ByteWidth(FromInt64(tc[0]))) + }) + } + require.Equal(t, 16, ByteWidth(Sub(MinInt64, FromInt64(1)))) + require.Equal(t, 16, ByteWidth(MinInt128)) + require.Equal(t, 16, ByteWidth(Add(MaxInt64, FromInt64(1)))) + require.Equal(t, 16, ByteWidth(MaxInt128)) +} + +func TestIncreaseScaleBy(t *testing.T) { + type TestCase struct { + n Num + scale int32 + overflow bool + } + tests := []TestCase{ + {MinInt64, 1, false}, + {MaxInt64, 1, false}, + {MaxInt64, 2, false}, + {MinInt64, 2, false}, + {MaxInt128, 1, true}, + {MinInt128, 1, true}, + {MinInt128, 0, true}, + } + 
for _, tc := range tests { + tc := tc + t.Run("", func(t *testing.T) { + v, err := Rescale(tc.n, 38, tc.scale) + if tc.overflow { + require.Error(t, err, "got: %v, err: %v", v) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestFitsInPrec(t *testing.T) { + // Examples from snowflake documentation + snowflakeNumberMax := "+99999999999999999999999999999999999999" + snowflakeNumberMin := "-99999999999999999999999999999999999999" + require.True(t, MustParse(snowflakeNumberMax).FitsInPrecision(38), snowflakeNumberMax) + require.True(t, MustParse(snowflakeNumberMin).FitsInPrecision(38), snowflakeNumberMin) + require.True(t, MustParse("80068800064664092541968040996862354605").FitsInPrecision(38), "80068800064664092541968040996862354605") + snowflakeNumberTiny := "1.2e-36" + n, err := FromString(snowflakeNumberTiny, 38, 37) + require.NoError(t, err) + require.True(t, n.FitsInPrecision(38), snowflakeNumberTiny) +} diff --git a/internal/impl/snowflake/streaming/integration_test.go b/internal/impl/snowflake/streaming/integration_test.go new file mode 100644 index 0000000000..d2c7cde20d --- /dev/null +++ b/internal/impl/snowflake/streaming/integration_test.go @@ -0,0 +1,511 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + +package streaming_test + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "math" + "os" + "strconv" + "strings" + "testing" + "time" + + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/snowflake/streaming" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func msg(s string) *service.Message { + return service.NewMessage([]byte(s)) +} + +func structuredMsg(v any) *service.Message { + msg := service.NewMessage(nil) + msg.SetStructured(v) + return msg +} + +func envOr(name, dflt string) string { + val := os.Getenv(name) + if val != "" { + return val + } + return dflt +} + +func setup(t *testing.T) (*streaming.SnowflakeRestClient, *streaming.SnowflakeServiceClient) { + t.Helper() + ctx := context.Background() + privateKeyFile, err := os.ReadFile("./resources/rsa_key.p8") + if errors.Is(err, os.ErrNotExist) { + t.Skip("no RSA private key, skipping snowflake test") + } + require.NoError(t, err) + block, _ := pem.Decode(privateKeyFile) + require.NoError(t, err) + parseResult, err := x509.ParsePKCS8PrivateKey(block.Bytes) + require.NoError(t, err) + clientOptions := streaming.ClientOptions{ + Account: envOr("SNOWFLAKE_ACCOUNT", "WQKFXQQ-WI77362"), + User: envOr("SNOWFLAKE_USER", "ROCKWOODREDPANDA"), + Role: "ACCOUNTADMIN", + PrivateKey: parseResult.(*rsa.PrivateKey), + ConnectVersion: "", + Application: "development", + } + restClient, err := streaming.NewRestClient( + clientOptions.Account, + clientOptions.User, + clientOptions.ConnectVersion, + "Redpanda_Connect_"+clientOptions.Application, + clientOptions.PrivateKey, + clientOptions.Logger, + ) + require.NoError(t, err) + t.Cleanup(restClient.Close) + streamClient, err := streaming.NewSnowflakeServiceClient(ctx, clientOptions) + require.NoError(t, err) + t.Cleanup(func() { _ = streamClient.Close() }) + return restClient, streamClient +} + +func TestAllSnowflakeDatatypes(t *testing.T) { + ctx 
:= context.Background() + restClient, streamClient := setup(t) + channelOpts := streaming.ChannelOptions{ + Name: t.Name(), + DatabaseName: envOr("SNOWFLAKE_DB", "BABY_DATABASE"), + SchemaName: "PUBLIC", + TableName: "TEST_TABLE_KITCHEN_SINK", + } + _, err := restClient.RunSQL(ctx, streaming.RunSQLRequest{ + Database: channelOpts.DatabaseName, + Schema: channelOpts.SchemaName, + Statement: fmt.Sprintf(` + DROP TABLE IF EXISTS %s; + CREATE TABLE %s ( + A STRING, + B BOOLEAN, + C VARIANT, + D ARRAY, + E OBJECT, + F REAL, + G NUMBER, + H TIME, + I DATE, + J TIMESTAMP_LTZ, + K TIMESTAMP_NTZ, + L TIMESTAMP_TZ + );`, channelOpts.TableName, channelOpts.TableName), + Parameters: map[string]string{ + "MULTI_STATEMENT_COUNT": "0", + }, + }) + require.NoError(t, err) + t.Cleanup(func() { + err = streamClient.DropChannel(ctx, channelOpts) + if err != nil { + t.Log("unable to cleanup stream in SNOW:", err) + } + }) + channel, err := streamClient.OpenChannel(ctx, channelOpts) + require.NoError(t, err) + _, err = channel.InsertRows(ctx, service.MessageBatch{ + msg(`{ + "A": "bar", + "B": true, + "C": {"foo": "bar"}, + "D": [[42], null, {"A":"B"}], + "E": {"foo":"bar"}, + "F": 3.14, + "G": -1, + "H": "2024-01-01T13:02:06Z", + "I": "2007-11-03T00:00:00Z", + "J": "2024-01-01T12:00:00.000Z", + "K": "2024-01-01T12:00:00.000-08:00", + "L": "2024-01-01T12:00:00.000-08:00" + }`), + msg(`{ + "A": "baz", + "B": "false", + "C": {"a":"b"}, + "D": [1, 2, 3], + "E": {"foo":"baz"}, + "F": 42.12345, + "G": 9, + "H": "2024-01-02T13:02:06.123456789Z", + "I": "2019-03-04T00:00:00.12345Z", + "J": "1970-01-02T12:00:00.000Z", + "K": "2024-02-01T12:00:00.000-08:00", + "L": "2024-01-01T12:00:01.000-08:00" + }`), + msg(`{ + "A": "foo", + "B": null, + "C": [1, 2, 3], + "D": ["a", 9, "z"], + "E": {"baz":"qux"}, + "F": -0.0, + "G": 42, + "H": 1728680106, + "I": 1728680106, + "J": "2024-01-03T12:00:00.000-08:00", + "K": "2024-01-01T13:00:00.000-08:00", + "L": "2024-01-01T12:30:00.000-08:00" + }`), + }) + require.NoError(t, err) + time.Sleep(time.Second) + // Always order by A so we get consistent ordering for our test + resp, err := restClient.RunSQL(ctx, streaming.RunSQLRequest{ + Database: channelOpts.DatabaseName, + Schema: channelOpts.SchemaName, + Statement: fmt.Sprintf(`SELECT * FROM %s ORDER BY A;`, channelOpts.TableName), + Parameters: map[string]string{ + "TIMESTAMP_OUTPUT_FORMAT": "YYYY-MM-DD HH24:MI:SS.FF3 TZHTZM", + "DATE_OUTPUT_FORMAT": "YYYY-MM-DD", + "TIME_OUTPUT_FORMAT": "HH24:MI:SS", + }, + }) + assert.Equal(t, "00000", resp.SQLState) + expected := [][]string{ + { + `bar`, + `true`, + `{"foo":"bar"}`, + `[[42], null, {"A":"B"}]`, + `{"foo": "bar"}`, + `3.14`, + `-1`, + `13:02:06`, + `2007-11-03`, + `2024-01-01 04:00:00.000 -0800`, + `2024-01-01 20:00:00.000`, + `2024-01-01 12:00:00.000 -0800`, + }, + { + `baz`, + `false`, + `{"a":"b"}`, + `[1, 2, 3]`, + `{"foo":"baz"}`, + `42.12345`, + `9`, + `13:02:06`, + `2019-03-04`, + `1970-01-02 04:00:00.000 -0800`, + `2024-02-01 20:00:00.000`, + `2024-01-01 12:00:01.000 -0800`, + }, + { + `foo`, + ``, + `[1, 2, 3]`, + `["a", 9, "z"]`, + `{"baz":"qux"}`, + `-0.0`, + `42`, + `20:55:06`, + `2024-10-11`, + `2024-01-03 12:00:00.000 -0800`, + `2024-01-01 21:00:00.000`, + `2024-01-01 12:30:00.000 -0800`, + }, + } + assert.Equal(t, parseSnowflakeData(expected), parseSnowflakeData(resp.Data)) + require.EventuallyWithT(t, func(collect *assert.CollectT) { + // Make sure stats are written correctly by doing a query that only needs to read from epInfo + resp, err := restClient.RunSQL(ctx, 
streaming.RunSQLRequest{ + Database: channelOpts.DatabaseName, + Schema: channelOpts.SchemaName, + Statement: fmt.Sprintf(`SELECT + MAX(A), MAX(B), MAX(C), + MAX(F), + MAX(G), MAX(H), MAX(I), + MAX(J), MAX(K), MAX(L) + FROM %s`, channelOpts.TableName), + Parameters: map[string]string{ + "TIMESTAMP_OUTPUT_FORMAT": "YYYY-MM-DD HH24:MI:SS.FF3 TZHTZM", + "DATE_OUTPUT_FORMAT": "YYYY-MM-DD", + "TIME_OUTPUT_FORMAT": "HH24:MI:SS", + }, + }) + if !assert.NoError(collect, err) { + t.Logf("failed to scan table: %s", err) + return + } + assert.Equal(collect, "00000", resp.SQLState) + expected := [][]string{ + { + `foo`, + `true`, + `[1, 2, 3]`, + `42.12345`, + `42`, + `20:55:06`, + `2024-10-11`, + `2024-01-03 12:00:00.000 -0800`, + `2024-02-01 20:00:00.000`, + `2024-01-01 12:30:00.000 -0800`, + }, + } + assert.Equal(collect, parseSnowflakeData(expected), parseSnowflakeData(resp.Data)) + }, 3*time.Second, time.Second) +} + +func TestIntegerCompat(t *testing.T) { + ctx := context.Background() + restClient, streamClient := setup(t) + channelOpts := streaming.ChannelOptions{ + Name: t.Name(), + DatabaseName: envOr("SNOWFLAKE_DB", "BABY_DATABASE"), + SchemaName: "PUBLIC", + TableName: "TEST_INT_TABLE", + } + _, err := restClient.RunSQL(ctx, streaming.RunSQLRequest{ + Database: channelOpts.DatabaseName, + Schema: channelOpts.SchemaName, + Statement: fmt.Sprintf(` + DROP TABLE IF EXISTS %s; + CREATE TABLE IF NOT EXISTS %s ( + A NUMBER, + B NUMBER(38, 8), + C NUMBER(18, 0), + D NUMBER(28, 8) + );`, channelOpts.TableName, channelOpts.TableName), + Parameters: map[string]string{ + "MULTI_STATEMENT_COUNT": "0", + }, + }) + require.NoError(t, err) + t.Cleanup(func() { + err = streamClient.DropChannel(ctx, channelOpts) + if err != nil { + t.Log("unable to cleanup stream in SNOW:", err) + } + }) + channel, err := streamClient.OpenChannel(ctx, channelOpts) + require.NoError(t, err) + _, err = channel.InsertRows(ctx, service.MessageBatch{ + structuredMsg(map[string]any{ + "a": math.MinInt64, + "b": math.MinInt8, + "c": math.MaxInt32, + "d": math.MinInt8, + }), + structuredMsg(map[string]any{ + "a": 0, + "b": "0.12345678", + "c": 0, + }), + structuredMsg(map[string]any{ + "a": math.MaxInt64, + "b": math.MaxInt8, + "c": math.MaxInt16, + "d": "1234.12345678", + }), + }) + require.NoError(t, err) + require.EventuallyWithT(t, func(collect *assert.CollectT) { + // Always order by A so we get consistent ordering for our test + resp, err := restClient.RunSQL(ctx, streaming.RunSQLRequest{ + Database: channelOpts.DatabaseName, + Schema: channelOpts.SchemaName, + Statement: fmt.Sprintf(`SELECT * FROM %s ORDER BY A;`, channelOpts.TableName), + }) + if !assert.NoError(collect, err) { + t.Logf("failed to scan table: %s", err) + return + } + assert.Equal(collect, "00000", resp.SQLState) + itoa := strconv.Itoa + assert.Equal(collect, parseSnowflakeData([][]string{ + {itoa(math.MinInt64), itoa(math.MinInt8), itoa(math.MaxInt32), itoa(math.MinInt8)}, + {"0", "0.12345678", "0", ""}, + {itoa(math.MaxInt64), itoa(math.MaxInt8), itoa(math.MaxInt16), "1234.12345678"}, + }), parseSnowflakeData(resp.Data)) + }, 3*time.Second, time.Second) +} + +func TestTimestampCompat(t *testing.T) { + ctx := context.Background() + restClient, streamClient := setup(t) + channelOpts := streaming.ChannelOptions{ + Name: t.Name(), + DatabaseName: envOr("SNOWFLAKE_DB", "BABY_DATABASE"), + SchemaName: "PUBLIC", + TableName: "TEST_TIMESTAMP_TABLE", + } + var columnDefs []string + var columnNames []string + for _, tsType := range []string{"_NTZ", "_TZ", "_LTZ"} { 
+ for precision := range make([]int, 10) { + name := fmt.Sprintf("TS%s_%d", tsType, precision) + columnNames = append(columnNames, name) + columnDefs = append(columnDefs, name+fmt.Sprintf(" TIMESTAMP%s(%d)", tsType, precision)) + } + } + _, err := restClient.RunSQL(ctx, streaming.RunSQLRequest{ + Database: channelOpts.DatabaseName, + Schema: channelOpts.SchemaName, + Statement: fmt.Sprintf(` + DROP TABLE IF EXISTS %s; + CREATE TABLE IF NOT EXISTS %s ( + %s + );`, channelOpts.TableName, channelOpts.TableName, strings.Join(columnDefs, ", ")), + Parameters: map[string]string{ + "MULTI_STATEMENT_COUNT": "0", + }, + }) + require.NoError(t, err) + t.Cleanup(func() { + err = streamClient.DropChannel(ctx, channelOpts) + if err != nil { + t.Log("unable to cleanup stream in SNOW:", err) + } + }) + channel, err := streamClient.OpenChannel(ctx, channelOpts) + require.NoError(t, err) + timestamps1 := map[string]any{} + timestamps2 := map[string]any{} + easternTz, err := time.LoadLocation("America/New_York") + require.NoError(t, err) + for _, col := range columnNames { + timestamps1[col] = time.Date( + 2024, 1, 01, + 12, 30, 05, + int(time.Nanosecond+time.Microsecond+time.Millisecond), + time.UTC, + ) + timestamps2[col] = time.Date( + 2024, 1, 01, + 20, 45, 55, + int(time.Nanosecond+time.Microsecond+time.Millisecond), + easternTz, + ) + } + _, err = channel.InsertRows(ctx, service.MessageBatch{ + structuredMsg(timestamps1), + structuredMsg(timestamps2), + msg(`{}`), // all nulls + }) + require.NoError(t, err) + expectedRows := [][]string{ + { + "2024-01-01 12:30:05.000", + "2024-01-01 12:30:05.000", + "2024-01-01 12:30:05.000", + "2024-01-01 12:30:05.001", + "2024-01-01 12:30:05.001", + "2024-01-01 12:30:05.001", + "2024-01-01 12:30:05.001", + "2024-01-01 12:30:05.001", + "2024-01-01 12:30:05.001", + "2024-01-01 12:30:05.001", + "2024-01-01 12:30:05. Z", + "2024-01-01 12:30:05.0 Z", + "2024-01-01 12:30:05.00 Z", + "2024-01-01 12:30:05.001 Z", + "2024-01-01 12:30:05.0010 Z", + "2024-01-01 12:30:05.00100 Z", + "2024-01-01 12:30:05.001001 Z", + "2024-01-01 12:30:05.0010010 Z", + "2024-01-01 12:30:05.00100100 Z", + "2024-01-01 12:30:05.001001001 Z", + "2024-01-01 04:30:05. -0800", + "2024-01-01 04:30:05.0 -0800", + "2024-01-01 04:30:05.00 -0800", + "2024-01-01 04:30:05.001 -0800", + "2024-01-01 04:30:05.0010 -0800", + "2024-01-01 04:30:05.00100 -0800", + "2024-01-01 04:30:05.001001 -0800", + "2024-01-01 04:30:05.0010010 -0800", + "2024-01-01 04:30:05.00100100 -0800", + "2024-01-01 04:30:05.001001001 -0800", + }, + { + "2024-01-02 01:45:55.000", + "2024-01-02 01:45:55.000", + "2024-01-02 01:45:55.000", + "2024-01-02 01:45:55.001", + "2024-01-02 01:45:55.001", + "2024-01-02 01:45:55.001", + "2024-01-02 01:45:55.001", + "2024-01-02 01:45:55.001", + "2024-01-02 01:45:55.001", + "2024-01-02 01:45:55.001", + "2024-01-01 20:45:55. -0500", + "2024-01-01 20:45:55.0 -0500", + "2024-01-01 20:45:55.00 -0500", + "2024-01-01 20:45:55.001 -0500", + "2024-01-01 20:45:55.0010 -0500", + "2024-01-01 20:45:55.00100 -0500", + "2024-01-01 20:45:55.001001 -0500", + "2024-01-01 20:45:55.0010010 -0500", + "2024-01-01 20:45:55.00100100 -0500", + "2024-01-01 20:45:55.001001001 -0500", + "2024-01-01 17:45:55. 
-0800", + "2024-01-01 17:45:55.0 -0800", + "2024-01-01 17:45:55.00 -0800", + "2024-01-01 17:45:55.001 -0800", + "2024-01-01 17:45:55.0010 -0800", + "2024-01-01 17:45:55.00100 -0800", + "2024-01-01 17:45:55.001001 -0800", + "2024-01-01 17:45:55.0010010 -0800", + "2024-01-01 17:45:55.00100100 -0800", + "2024-01-01 17:45:55.001001001 -0800", + }, + make([]string, 30), + } + require.EventuallyWithT(t, func(collect *assert.CollectT) { + resp, err := restClient.RunSQL(ctx, streaming.RunSQLRequest{ + Database: channelOpts.DatabaseName, + Schema: channelOpts.SchemaName, + Statement: fmt.Sprintf(`SELECT * FROM %s ORDER BY TS_NTZ_9;`, channelOpts.TableName), + Parameters: map[string]string{ + "TIMESTAMP_OUTPUT_FORMAT": "YYYY-MM-DD HH24:MI:SS.FF TZHTZM", + }, + }) + if !assert.NoError(t, err) { + t.Logf("failed to scan table: %s", err) + return + } + assert.Equal(t, "00000", resp.SQLState) + assert.Equal(t, parseSnowflakeData(expectedRows), parseSnowflakeData(resp.Data)) + }, 3*time.Second, time.Second) +} + +// parseSnowflakeData returns "json-ish" data that can be JSON or could be just a raw string. +// We want to parse for the JSON rows have whitespace, so this gives us a more semantic comparison. +func parseSnowflakeData(rawData [][]string) [][]any { + var parsedData [][]any + for _, rawRow := range rawData { + var parsedRow []any + for _, rawCol := range rawRow { + var parsedCol any + if rawCol != `` { + err := json.Unmarshal([]byte(rawCol), &parsedCol) + if err != nil { + parsedCol = rawCol + } + } + parsedRow = append(parsedRow, parsedCol) + } + parsedData = append(parsedData, parsedRow) + } + return parsedData +} diff --git a/internal/impl/snowflake/streaming/parquet.go b/internal/impl/snowflake/streaming/parquet.go new file mode 100644 index 0000000000..7b5db3868d --- /dev/null +++ b/internal/impl/snowflake/streaming/parquet.go @@ -0,0 +1,152 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/parquet-go/parquet-go" + "github.com/parquet-go/parquet-go/format" + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/segmentio/encoding/thrift" +) + +func messageToRow(msg *service.Message) (map[string]any, error) { + v, err := msg.AsStructured() + if err != nil { + return nil, fmt.Errorf("error extracting object from message: %w", err) + } + row, ok := v.(map[string]any) + if !ok { + return nil, fmt.Errorf("expected object, got: %T", v) + } + mapped := make(map[string]any, len(row)) + for k, v := range row { + mapped[normalizeColumnName(k)] = v + } + return mapped, nil +} + +// TODO: If the memory pressure is too great from writing all +// records buffered as a single row group, then consider +// return some kind of iterator of chunks of rows that we can +// then feed into the actual parquet construction process. +// +// If a single parquet file is too much, we can consider having multiple +// parquet files in a single bdec file. 
+func constructRowGroup( + batch service.MessageBatch, + schema *parquet.Schema, + transformers map[string]*dataTransformer, +) ([]parquet.Row, error) { + // We write all of our data in a columnar fashion, but need to pivot that data so that we can feed it into + // out parquet library (which sadly will redo the pivot - maybe we need a lower level abstraction...). + // So create a massive matrix that we will write stuff in columnar form, but then we don't need to move any + // data to create rows of the data via an in-place transpose operation. + // + // TODO: Consider caching/pooling this matrix as I expect many are similarily sized. + matrix := make([]parquet.Value, len(batch)*len(schema.Fields())) + rowWidth := len(schema.Fields()) + for idx, field := range schema.Fields() { + // The column index is consistent between two schemas that are the same because the schema fields are always + // in sorted order. + columnIndex := idx + t := transformers[field.Name()] + t.buf.Prepare(matrix, columnIndex, rowWidth) + t.stats.Reset() + } + // First we need to shred our record into columns, snowflake's data model + // is thankfully a flat list of columns, so no dremel style record shredding + // is needed + for _, msg := range batch { + row, err := messageToRow(msg) + if err != nil { + return nil, err + } + // We **must** write a null, so iterate over the schema not the record, + // which might be sparse + for name, t := range transformers { + v := row[name] + err = t.converter.ValidateAndConvert(t.stats, v, t.buf) + if err != nil { + return nil, fmt.Errorf("invalid data for column %s: %w", name, err) + } + } + } + // Now all our values have been written to each buffer - here is where we do our matrix + // transpose mentioned above + rows := make([]parquet.Row, len(batch)) + for i := range rows { + rowStart := i * rowWidth + rows[i] = matrix[rowStart : rowStart+rowWidth] + } + return rows, nil +} + +type parquetFileData struct { + schema *parquet.Schema + rows []parquet.Row + metadata map[string]string +} + +func writeParquetFile(writer io.Writer, rpcnVersion string, data parquetFileData) (err error) { + pw := parquet.NewGenericWriter[map[string]any]( + writer, + data.schema, + parquet.CreatedBy("RedpandaConnect", rpcnVersion, "unknown"), + // Recommended by the Snowflake team to enable data page stats + parquet.DataPageStatistics(true), + parquet.Compression(&parquet.Zstd), + ) + for k, v := range data.metadata { + pw.SetKeyValueMetadata(k, v) + } + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("encoding panic: %v", r) + } + }() + _, err = pw.WriteRows(data.rows) + if err != nil { + return + } + err = pw.Close() + return +} + +func readParquetMetadata(parquetFile []byte) (metadata format.FileMetaData, err error) { + if len(parquetFile) < 8 { + return format.FileMetaData{}, fmt.Errorf("too small of parquet file: %d", len(parquetFile)) + } + trailingBytes := parquetFile[len(parquetFile)-8:] + if string(trailingBytes[4:]) != "PAR1" { + return metadata, fmt.Errorf("missing magic bytes, got: %q", trailingBytes[4:]) + } + footerSize := int(binary.LittleEndian.Uint32(trailingBytes)) + if len(parquetFile) < footerSize+8 { + return metadata, fmt.Errorf("too small of parquet file: %d, footer size: %d", len(parquetFile), footerSize) + } + footerBytes := parquetFile[len(parquetFile)-(footerSize+8) : len(parquetFile)-8] + if err := thrift.Unmarshal(new(thrift.CompactProtocol), footerBytes, &metadata); err != nil { + return metadata, fmt.Errorf("unable to extract parquet metadata: %w", err) 
+ } + return +} + +func totalUncompressedSize(metadata format.FileMetaData) int32 { + var size int64 + for _, rowGroup := range metadata.RowGroups { + size += rowGroup.TotalByteSize + } + return int32(size) +} diff --git a/internal/impl/snowflake/streaming/parquet_test.go b/internal/impl/snowflake/streaming/parquet_test.go new file mode 100644 index 0000000000..5449a7772a --- /dev/null +++ b/internal/impl/snowflake/streaming/parquet_test.go @@ -0,0 +1,101 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "bytes" + "io" + "testing" + + "github.com/aws/smithy-go/ptr" + "github.com/parquet-go/parquet-go" + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/stretchr/testify/require" +) + +func msg(s string) *service.Message { + return service.NewMessage([]byte(s)) +} + +func TestWriteParquet(t *testing.T) { + b := bytes.NewBuffer(nil) + batch := service.MessageBatch{ + msg(`{"a":2}`), + msg(`{"a":12353}`), + } + inputDataSchema := parquet.Group{ + "A": parquet.Decimal(0, 18, parquet.Int32Type), + } + transformers := map[string]*dataTransformer{ + "A": { + converter: numberConverter{ + nullable: true, + scale: 0, + precision: 38, + }, + stats: &statsBuffer{columnID: 1}, + column: &columnMetadata{ + Name: "A", + Ordinal: 1, + Type: "NUMBER(18,0)", + LogicalType: "fixed", + PhysicalType: "SB8", + Precision: ptr.Int32(18), + Scale: ptr.Int32(0), + Nullable: true, + }, + buf: &int32Buffer{}, + }, + } + schema := parquet.NewSchema("bdec", inputDataSchema) + rows, err := constructRowGroup( + batch, + schema, + transformers, + ) + require.NoError(t, err) + err = writeParquetFile(b, "latest", parquetFileData{ + schema, rows, nil, + }) + require.NoError(t, err) + actual, err := readGeneric( + bytes.NewReader(b.Bytes()), + int64(b.Len()), + parquet.NewSchema("bdec", inputDataSchema), + ) + require.NoError(t, err) + require.Equal(t, []map[string]any{ + {"A": int32(2)}, + {"A": int32(12353)}, + }, actual) +} + +func readGeneric(r io.ReaderAt, size int64, schema *parquet.Schema) (rows []map[string]any, err error) { + config, err := parquet.NewReaderConfig(schema) + if err != nil { + return nil, err + } + file, err := parquet.OpenFile(r, size) + if err != nil { + return nil, err + } + reader := parquet.NewGenericReader[map[string]any](file, config) + rows = make([]map[string]any, file.NumRows()) + for i := range rows { + rows[i] = map[string]any{} + } + n, err := reader.Read(rows) + if err == io.EOF { + err = nil + } + reader.Close() + return rows[:n], err +} diff --git a/internal/impl/snowflake/streaming/resources/.gitignore b/internal/impl/snowflake/streaming/resources/.gitignore new file mode 100644 index 0000000000..72e8ffc0db --- /dev/null +++ b/internal/impl/snowflake/streaming/resources/.gitignore @@ -0,0 +1 @@ +* diff --git a/internal/impl/snowflake/streaming/rest.go b/internal/impl/snowflake/streaming/rest.go new file mode 100644 index 0000000000..16c9a91cf0 --- /dev/null +++ b/internal/impl/snowflake/streaming/rest.go @@ -0,0 +1,487 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "bytes" + "context" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/golang-jwt/jwt" + "github.com/google/uuid" + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/snowflake/streaming/int128" + "github.com/redpanda-data/connect/v4/internal/periodic" + "github.com/redpanda-data/connect/v4/internal/typed" +) + +const ( + responseSuccess = 0 + responseErrRetryRequest = 10 + responseErrQueueFull = 7 +) + +type ( + clientConfigureRequest struct { + Role string `json:"role"` + FileName string `json:"file_name,omitempty"` + } + fileLocationInfo struct { + // The stage type + LocationType string + // The container or bucket + Location string + // The path of the target file + Path string + // The credentials required for the stage + Creds map[string]string + // AWS/S3/GCS Region (s3/GCS only) + Region string + // The Azure Storage endpoint (Azure only) + EndPoint string + // The Azure Storage Account (Azure only) + StorageAccount string + // GCS gives us back a presigned URL instead of a cred (obsolete) + PresignedURL string + // Whether to encrypt/decrypt files on the stage + IsClientSideEncrypted bool + // Whether to use s3 regional URL (AWS only) + UseS3RegionalURL bool + // A unique ID for volume assigned by server + VolumeHash string + } + clientConfigureResponse struct { + Prefix string `json:"prefix"` + StatusCode int64 `json:"status_code"` + Message string `json:"message"` + StageLocation fileLocationInfo `json:"stage_location"` + DeploymentID int64 `json:"deployment_id"` + } + channelStatusRequest struct { + Table string `json:"table"` + Database string `json:"database"` + Schema string `json:"schema"` + Name string `json:"channel_name"` + ClientSequencer *int64 `json:"client_sequencer,omitempty"` + } + batchChannelStatusRequest struct { + Role string `json:"role"` + Channels []channelStatusRequest `json:"channels"` + } + channelStatusResponse struct { + StatusCode int64 `json:"status_code"` + PersistedOffsetToken string `json:"persisted_offset_token"` + PersistedClientSequencer int64 `json:"persisted_client_sequencer"` + PersistedRowSequencer int64 `json:"persisted_row_sequencer"` + } + batchChannelStatusResponse struct { + StatusCode int64 `json:"status_code"` + Message string `json:"message"` + Channels []channelStatusResponse `json:"channels"` + } + openChannelRequest struct { + RequestID string `json:"request_id"` + Role string `json:"role"` + Channel string `json:"channel"` + Table string `json:"table"` + Database string `json:"database"` + Schema string `json:"schema"` + WriteMode string `json:"write_mode"` + IsIceberg bool `json:"is_iceberg,omitempty"` + OffsetToken string `json:"offset_token,omitempty"` + } + columnMetadata struct { + Name string `json:"name"` + Type string `json:"type"` + LogicalType string `json:"logical_type"` + PhysicalType string `json:"physical_type"` + Precision *int32 `json:"precision"` + Scale *int32 `json:"scale"` + ByteLength *int32 `json:"byte_length"` + Length *int32 `json:"length"` + Nullable bool `json:"nullable"` + Collation *string `json:"collation"` + // The JSON serialization of Iceberg data type of the column, + // see 
https://iceberg.apache.org/spec/#appendix-c-json-serialization for more details. + SourceIcebergDataType *string `json:"source_iceberg_data_type"` + // The column ordinal is an internal id of the column used by server scanner for the column identification. + Ordinal int32 `json:"ordinal"` + } + openChannelResponse struct { + StatusCode int64 `json:"status_code"` + Message string `json:"message"` + Database string `json:"database"` + Schema string `json:"schema"` + Table string `json:"table"` + Channel string `json:"channel"` + ClientSequencer int64 `json:"client_sequencer"` + RowSequencer int64 `json:"row_sequencer"` + TableColumns []columnMetadata `json:"table_columns"` + EncryptionKey string `json:"encryption_key"` + EncryptionKeyID int64 `json:"encryption_key_id"` + IcebergLocationInfo fileLocationInfo `json:"iceberg_location"` + } + dropChannelRequest struct { + RequestID string `json:"request_id"` + Role string `json:"role"` + Channel string `json:"channel"` + Table string `json:"table"` + Database string `json:"database"` + Schema string `json:"schema"` + IsIceberg bool `json:"is_iceberg"` + // Optionally specify at a specific version + ClientSequencer *int64 `json:"client_sequencer,omitempty"` + } + dropChannelResponse struct { + StatusCode int64 `json:"status_code"` + Message string `json:"message"` + Database string `json:"database"` + Schema string `json:"schema"` + Table string `json:"table"` + Channel string `json:"channel"` + } + fileColumnProperties struct { + ColumnOrdinal int32 `json:"columnId"` + FieldID *int32 `json:"field_id,omitempty"` + // current hex-encoded max value, truncated down to 32 bytes + MinStrValue *string `json:"minStrValue"` + // current hex-encoded max value, truncated up to 32 bytes + MaxStrValue *string `json:"maxStrValue"` + MinIntValue int128.Num `json:"minIntValue"` + MaxIntValue int128.Num `json:"maxIntValue"` + MinRealValue float64 `json:"minRealValue"` + MaxRealValue float64 `json:"maxRealValue"` + NullCount int64 `json:"nullCount"` + // Currently not tracked + DistinctValues int64 `json:"distinctValues"` + MaxLength int64 `json:"maxLength"` + // collated columns do not support ingestion + // they are always null + Collation *string `json:"collation"` + MinStrNonCollated *string `json:"minStrNonCollated"` + MaxStrNonCollated *string `json:"maxStrNonCollated"` + } + epInfo struct { + Rows int64 `json:"rows"` + Columns map[string]fileColumnProperties `json:"columns"` + } + channelMetadata struct { + Channel string `json:"channel_name"` + ClientSequencer int64 `json:"client_sequencer"` + RowSequencer int64 `json:"row_sequencer"` + StartOffsetToken *string `json:"start_offset_token"` + EndOffsetToken *string `json:"end_offset_token"` + // In the JavaSDK this is always just the end offset version + OffsetToken *string `json:"offset_token"` + } + chunkMetadata struct { + Database string `json:"database"` + Schema string `json:"schema"` + Table string `json:"table"` + ChunkStartOffset int64 `json:"chunk_start_offset"` + ChunkLength int32 `json:"chunk_length"` + ChunkLengthUncompressed int32 `json:"chunk_length_uncompressed"` + Channels []channelMetadata `json:"channels"` + ChunkMD5 string `json:"chunk_md5"` + EPS *epInfo `json:"eps,omitempty"` + EncryptionKeyID int64 `json:"encryption_key_id,omitempty"` + FirstInsertTimeInMillis int64 `json:"first_insert_time_in_ms"` + LastInsertTimeInMillis int64 `json:"last_insert_time_in_ms"` + } + blobStats struct { + FlushStartMs int64 `json:"flush_start_ms"` + BuildDurationMs int64 `json:"build_duration_ms"` + 
UploadDurationMs int64 `json:"upload_duration_ms"` + } + blobMetadata struct { + Path string `json:"path"` + MD5 string `json:"md5"` + Chunks []chunkMetadata `json:"chunks"` + // Currently always 3 + BDECVersion int8 `json:"bdec_version"` + SpansMixedTables bool `json:"spans_mixed_tables"` + BlobStats blobStats `json:"blob_stats"` + } + registerBlobRequest struct { + RequestID string `json:"request_id"` + Role string `json:"role"` + Blobs []blobMetadata `json:"blobs"` + IsIceberg bool `json:"is_iceberg"` + } + channelRegisterStatus struct { + StatusCode int64 `json:"status_code"` + Message string `json:"message"` + Channel string `json:"channel"` + ClientSequencer int64 `json:"client_sequencer"` + } + chunkRegisterStatus struct { + Channels []channelRegisterStatus `json:"channels"` + Database string `json:"database"` + Schema string `json:"schema"` + Table string `json:"table"` + } + blobRegisterStatus struct { + Chunks []chunkRegisterStatus `json:"chunks"` + } + registerBlobResponse struct { + StatusCode int64 `json:"status_code"` + Message string `json:"message"` + Blobs []blobRegisterStatus `json:"blobs"` + } + // RunSQLRequest is the way to run a SQL statement + RunSQLRequest struct { + Statement string `json:"statement"` + Timeout int64 `json:"timeout"` + Database string `json:"database,omitempty"` + Schema string `json:"schema,omitempty"` + Warehouse string `json:"warehouse,omitempty"` + Role string `json:"role,omitempty"` + // https://docs.snowflake.com/en/sql-reference/parameters + Parameters map[string]string `json:"parameters,omitempty"` + } + // RowType holds metadata for a row + RowType struct { + Name string `json:"name"` + Type string `json:"type"` + Length int64 `json:"length"` + Precision int64 `json:"precision"` + Scale int64 `json:"scale"` + Nullable bool `json:"nullable"` + } + // ResultSetMetadata holds metadata for the result set + ResultSetMetadata struct { + NumRows int64 `json:"numRows"` + Format string `json:"format"` + RowType []RowType `json:"rowType"` + } + // RunSQLResponse is the completed SQL query response + RunSQLResponse struct { + ResultSetMetadata ResultSetMetadata `json:"resultSetMetaData"` + Data [][]string `json:"data"` + Code string `json:"code"` + StatementStatusURL string `json:"statementStatusURL"` + SQLState string `json:"sqlState"` + StatementHandle string `json:"statementHandle"` + Message string `json:"message"` + CreatedOn int64 `json:"createdOn"` + } +) + +// SnowflakeRestClient allows you to make REST API calls against Snowflake APIs. +type SnowflakeRestClient struct { + account string + user string + app string + privateKey *rsa.PrivateKey + client *http.Client + userAgent string + logger *service.Logger + + authRefreshLoop *periodic.Periodic + cachedJWT *typed.AtomicValue[string] +} + +// NewRestClient creates a new REST client for the given parameters. 
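+// The supplied version is normalized: a leading "v" and any "-rc" style suffix are
+// dropped, and an empty version falls back to "99.0.0" so that the client version
+// reported in the User-Agent is always acceptable to Snowflake.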
+func NewRestClient(account, user, version, app string, privateKey *rsa.PrivateKey, logger *service.Logger) (c *SnowflakeRestClient, err error) { + version = strings.TrimLeft(version, "v") + // Drop any -rc suffix, Snowflake doesn't like it + splits := strings.SplitN(version, "-", 2) + if len(splits) > 1 { + version = splits[0] + } + if version == "" { + // We can't use a major version <2 so just use 99 as the unknown version + // this should only show up in development, not released binaries + version = "99.0.0" + } + userAgent := fmt.Sprintf("RedpandaConnect/%v", version) + debugf(logger, "making snowflake HTTP requests using User-Agent: %s", userAgent) + c = &SnowflakeRestClient{ + account: account, + user: user, + client: http.DefaultClient, + privateKey: privateKey, + userAgent: userAgent, + logger: logger, + app: url.QueryEscape(app), + cachedJWT: typed.NewAtomicValue(""), + authRefreshLoop: periodic.New( + time.Hour-(2*time.Minute), + func() { + jwt, err := c.computeJWT() + // We've already done this once, and there is no external component here + // so this should never fail, but log just in case... + if err != nil { + logger.Errorf("unable to mint JWT for snowflake output: %s", err) + return + } + c.cachedJWT.Store(jwt) + }, + ), + } + jwt, err := c.computeJWT() + if err != nil { + return nil, err + } + c.cachedJWT.Store(jwt) + c.authRefreshLoop.Start() + return c, nil +} + +// Close stops the auth refresh loop for a REST client. +func (c *SnowflakeRestClient) Close() { + c.authRefreshLoop.Stop() +} + +func (c *SnowflakeRestClient) computeJWT() (string, error) { + pubBytes, err := x509.MarshalPKIXPublicKey(c.privateKey.Public()) + if err != nil { + return "", err + } + hash := sha256.Sum256(pubBytes) + accountName := strings.ToUpper(c.account) + userName := strings.ToUpper(c.user) + issueAtTime := time.Now().UTC() + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{ + "iss": fmt.Sprintf("%s.%s.%s", accountName, userName, "SHA256:"+base64.StdEncoding.EncodeToString(hash[:])), + "sub": fmt.Sprintf("%s.%s", accountName, userName), + "iat": issueAtTime.Unix(), + "exp": issueAtTime.Add(time.Hour).Unix(), + }) + return token.SignedString(c.privateKey) +} + +// RunSQL executes a series of SQL statements. It's expected that these statements execute in less than 45 seconds so +// we don't have to handle async requests. +func (c *SnowflakeRestClient) RunSQL(ctx context.Context, req RunSQLRequest) (resp RunSQLResponse, err error) { + requestID := uuid.NewString() + err = c.doPost(ctx, fmt.Sprintf("https://%s.snowflakecomputing.com/api/v2/statements?application=%s&requestId=%s", c.account, c.app, requestID), req, &resp) + return +} + +// configureClient configures a client for Snowpipe Streaming. 
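+// The response carries the client prefix, deployment ID and stage location that the
+// streaming client needs before it can open channels or upload blobs.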
+func (c *SnowflakeRestClient) configureClient(ctx context.Context, req clientConfigureRequest) (resp clientConfigureResponse, err error) { + requestID := uuid.NewString() + err = c.doPost(ctx, fmt.Sprintf("https://%s.snowflakecomputing.com/v1/streaming/client/configure?application=%s&requestId=%s", c.account, c.app, requestID), req, &resp) + return +} + +// channelStatus returns the status of a given channel +func (c *SnowflakeRestClient) channelStatus(ctx context.Context, req batchChannelStatusRequest) (resp batchChannelStatusResponse, err error) { + requestID := uuid.NewString() + err = c.doPost(ctx, fmt.Sprintf("https://%s.snowflakecomputing.com/v1/streaming/channels/status?application=%s&requestId=%s", c.account, c.app, requestID), req, &resp) + return +} + +// openChannel opens a channel for writing +func (c *SnowflakeRestClient) openChannel(ctx context.Context, req openChannelRequest) (resp openChannelResponse, err error) { + requestID := uuid.NewString() + err = c.doPost(ctx, fmt.Sprintf("https://%s.snowflakecomputing.com/v1/streaming/channels/open?application=%s&requestId=%s", c.account, c.app, requestID), req, &resp) + return +} + +// dropChannel drops a channel when it's no longer in use. +func (c *SnowflakeRestClient) dropChannel(ctx context.Context, req dropChannelRequest) (resp dropChannelResponse, err error) { + requestID := uuid.NewString() + err = c.doPost(ctx, fmt.Sprintf("https://%s.snowflakecomputing.com/v1/streaming/channels/drop?application=%s&requestId=%s", c.account, c.app, requestID), req, &resp) + return +} + +// registerBlob registers a blob in object storage to be ingested into Snowflake. +func (c *SnowflakeRestClient) registerBlob(ctx context.Context, req registerBlobRequest) (resp registerBlobResponse, err error) { + requestID := uuid.NewString() + err = c.doPost(ctx, fmt.Sprintf("https://%s.snowflakecomputing.com/v1/streaming/channels/write/blobs?application=%s&requestId=%s", c.account, c.app, requestID), req, &resp) + return +} + +func debugf(l *service.Logger, msg string, args ...any) { + if debug { + fmt.Printf("%s\n", fmt.Sprintf(msg, args...)) + } + l.Tracef(msg, args...) 
+} + +func (c *SnowflakeRestClient) doPost(ctx context.Context, url string, req any, resp any) error { + marshaller := json.Marshal + if debug { + marshaller = func(v any) ([]byte, error) { + return json.MarshalIndent(v, "", " ") + } + } + reqBody, err := marshaller(req) + if err != nil { + return err + } + respBody, err := backoff.RetryNotifyWithData(func() ([]byte, error) { + debugf(c.logger, "making request to %s with body %s", url, reqBody) + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody)) + if errors.Is(err, context.Canceled) { + return nil, backoff.Permanent(err) + } else if err != nil { + return nil, fmt.Errorf("unable to make http request: %w", err) + } + httpReq.Header.Add("Content-Type", "application/json") + httpReq.Header.Add("Accept", "application/json") + httpReq.Header.Add("User-Agent", c.userAgent) + jwt := c.cachedJWT.Load() + httpReq.Header.Add("Authorization", "Bearer "+jwt) + httpReq.Header.Add("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT") + r, err := c.client.Do(httpReq) + if errors.Is(err, context.Canceled) { + return nil, backoff.Permanent(err) + } else if err != nil { + return nil, fmt.Errorf("unable to perform http request: %w", err) + } + respBody, err := io.ReadAll(r.Body) + _ = r.Body.Close() + if errors.Is(err, context.Canceled) { + return nil, backoff.Permanent(err) + } else if err != nil { + return nil, fmt.Errorf("unable to read http response: %w", err) + } + if r.StatusCode != 200 { + return nil, fmt.Errorf("non successful status code (%d): %s", r.StatusCode, respBody) + } + debugf(c.logger, "got response to %s with body %s", url, respBody) + return respBody, nil + }, + backoff.WithContext( + backoff.WithMaxRetries( + backoff.NewConstantBackOff(100*time.Millisecond), + 3, + ), + ctx, + ), + func(err error, _ time.Duration) { + debugf(c.logger, "failed request at %s: %s", url, err) + }, + ) + if err != nil { + return err + } + err = json.Unmarshal(respBody, resp) + if err != nil { + return fmt.Errorf("invalid response: %w, full response: %s", err, respBody[:min(128, len(respBody))]) + } + return err +} diff --git a/internal/impl/snowflake/streaming/schema.go b/internal/impl/snowflake/streaming/schema.go new file mode 100644 index 0000000000..b4590c8adb --- /dev/null +++ b/internal/impl/snowflake/streaming/schema.go @@ -0,0 +1,343 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/dustin/go-humanize" + "github.com/parquet-go/parquet-go" + "github.com/redpanda-data/connect/v4/internal/impl/snowflake/streaming/int128" +) + +type dataTransformer struct { + converter dataConverter + stats *statsBuffer + column *columnMetadata + buf typedBuffer +} + +func convertFixedType(column columnMetadata) (parquet.Node, dataConverter, typedBuffer, error) { + var scale int32 + var precision int32 + if column.Scale != nil { + scale = *column.Scale + } + if column.Precision != nil { + precision = *column.Precision + } + isDecimal := column.Scale != nil && column.Precision != nil + if (column.Scale != nil && *column.Scale != 0) || strings.ToUpper(column.PhysicalType) == "SB16" { + c := numberConverter{nullable: column.Nullable, scale: scale, precision: precision} + b := &typedBufferImpl{} + t := parquet.FixedLenByteArrayType(16) + if isDecimal { + return parquet.Decimal(int(scale), int(precision), t), c, b, nil + } + return parquet.Leaf(t), c, b, nil + } + var ptype parquet.Type + var defaultPrecision int32 + var buffer typedBuffer + switch strings.ToUpper(column.PhysicalType) { + case "SB1": + ptype = parquet.Int32Type + defaultPrecision = maxPrecisionForByteWidth(1) + buffer = &int32Buffer{} + case "SB2": + ptype = parquet.Int32Type + defaultPrecision = maxPrecisionForByteWidth(2) + buffer = &int32Buffer{} + case "SB4": + ptype = parquet.Int32Type + defaultPrecision = maxPrecisionForByteWidth(4) + buffer = &int32Buffer{} + case "SB8": + ptype = parquet.Int64Type + defaultPrecision = maxPrecisionForByteWidth(8) + buffer = &int64Buffer{} + default: + return nil, nil, nil, fmt.Errorf("unsupported physical column type: %s", column.PhysicalType) + } + validationPrecision := precision + if column.Precision == nil { + validationPrecision = defaultPrecision + } + c := numberConverter{nullable: column.Nullable, scale: scale, precision: validationPrecision} + if isDecimal { + return parquet.Decimal(int(scale), int(precision), ptype), c, buffer, nil + } + return parquet.Leaf(ptype), c, buffer, nil +} + +// maxJSONSize is the size that any kind of semi-structured data can be, which is 16MiB minus a small overhead +const maxJSONSize = 16*humanize.MiByte - 64 + +// See ParquetTypeGenerator +func constructParquetSchema(columns []columnMetadata) (*parquet.Schema, map[string]*dataTransformer, map[string]string, error) { + groupNode := parquet.Group{} + transformers := map[string]*dataTransformer{} + // Don't write the sfVer key as it allows us to not have to narrow the numeric types in parquet. 
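+	// typeMetadata is embedded into the parquet footer as key/value metadata. It records
+	// the logical/physical type ordinals per column ordinal (plus "N:obj_enc" flags for
+	// semi-structured columns), which the Snowflake scanner presumably reads back.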
+ typeMetadata := map[string]string{ /*"sfVer": "1,1"*/ } + var err error + for _, column := range columns { + id := int(column.Ordinal) + var n parquet.Node + var converter dataConverter + var buffer typedBuffer + logicalType := strings.ToLower(column.LogicalType) + switch logicalType { + case "fixed": + n, converter, buffer, err = convertFixedType(column) + if err != nil { + return nil, nil, nil, err + } + case "array": + typeMetadata[fmt.Sprintf("%d:obj_enc", id)] = "1" + n = parquet.String() + converter = jsonArrayConverter{jsonConverter{column.Nullable, maxJSONSize}} + buffer = &typedBufferImpl{} + case "object": + typeMetadata[fmt.Sprintf("%d:obj_enc", id)] = "1" + n = parquet.String() + converter = jsonObjectConverter{jsonConverter{column.Nullable, maxJSONSize}} + buffer = &typedBufferImpl{} + case "variant": + typeMetadata[fmt.Sprintf("%d:obj_enc", id)] = "1" + n = parquet.String() + converter = jsonConverter{column.Nullable, maxJSONSize} + buffer = &typedBufferImpl{} + case "any", "text", "char": + n = parquet.String() + byteLength := 16 * humanize.MiByte + if column.ByteLength != nil { + byteLength = int(*column.ByteLength) + } + byteLength = min(byteLength, 16*humanize.MiByte) + converter = binaryConverter{nullable: column.Nullable, maxLength: byteLength, utf8: true} + buffer = &typedBufferImpl{} + case "binary": + n = parquet.Leaf(parquet.ByteArrayType) + // Why binary data defaults to 8MiB instead of the 16MiB for strings... ¯\_(ツ)_/¯ + byteLength := 8 * humanize.MiByte + if column.ByteLength != nil { + byteLength = int(*column.ByteLength) + } + byteLength = min(byteLength, 16*humanize.MiByte) + converter = binaryConverter{nullable: column.Nullable, maxLength: byteLength} + buffer = &typedBufferImpl{} + case "boolean": + n = parquet.Leaf(parquet.BooleanType) + converter = boolConverter{column.Nullable} + buffer = &typedBufferImpl{} + case "real": + n = parquet.Leaf(parquet.DoubleType) + converter = doubleConverter{column.Nullable} + buffer = &typedBufferImpl{} + case "timestamp_tz", "timestamp_ltz", "timestamp_ntz": + var scale, precision int32 + var pt parquet.Type + if column.PhysicalType == "SB8" { + pt = parquet.Int64Type + precision = maxPrecisionForByteWidth(8) + buffer = &int64Buffer{} + } else { + pt = parquet.FixedLenByteArrayType(16) + precision = maxPrecisionForByteWidth(16) + buffer = &typedBufferImpl{} + } + if column.Scale != nil { + scale = *column.Scale + } + // The server always returns 0 precision for timestamp columns, + // the Java SDK also seems to not validate precision of timestamps + // so ignore it and use the default precision for the column type + n = parquet.Decimal(int(scale), int(precision), pt) + converter = timestampConverter{ + nullable: column.Nullable, + scale: scale, + precision: precision, + includeTZ: logicalType == "timestamp_tz", + trimTZ: logicalType == "timestamp_ntz", + defaultTZ: time.UTC, + } + case "time": + t := parquet.Int32Type + precision := 9 + buffer = &int32Buffer{} + if column.PhysicalType == "SB8" { + t = parquet.Int64Type + precision = 18 + buffer = &int64Buffer{} + } + scale := int32(9) + if column.Scale != nil { + scale = *column.Scale + } + n = parquet.Decimal(int(scale), precision, t) + converter = timeConverter{column.Nullable, scale} + case "date": + n = parquet.Leaf(parquet.Int32Type) + converter = dateConverter{column.Nullable} + buffer = &int32Buffer{} + default: + return nil, nil, nil, fmt.Errorf("unsupported logical column type: %s", column.LogicalType) + } + if column.Nullable { + n = parquet.Optional(n) + } + 
n = parquet.FieldID(n, id) + // Use plain encoding for now as there seems to be compatibility issues with the default settings + // we might be able to tune this more. + n = parquet.Encoded(n, &parquet.Plain) + typeMetadata[strconv.Itoa(id)] = fmt.Sprintf( + "%d,%d", + logicalTypeOrdinal(column.LogicalType), + physicalTypeOrdinal(column.PhysicalType), + ) + name := normalizeColumnName(column.Name) + groupNode[name] = n + transformers[name] = &dataTransformer{ + converter: converter, + stats: &statsBuffer{columnID: id}, + column: &column, + buf: buffer, + } + } + return parquet.NewSchema("bdec", groupNode), transformers, typeMetadata, nil +} + +type statsBuffer struct { + columnID int + minIntVal, maxIntVal int128.Num + minRealVal, maxRealVal float64 + minStrVal, maxStrVal []byte + maxStrLen int + nullCount int64 + first bool +} + +func (s *statsBuffer) Reset() { + s.first = true + s.minIntVal = int128.FromInt64(0) + s.maxIntVal = int128.FromInt64(0) + s.minRealVal = 0 + s.maxRealVal = 0 + s.minStrVal = nil + s.maxStrVal = nil + s.maxStrLen = 0 + s.nullCount = 0 +} + +func computeColumnEpInfo(stats map[string]*dataTransformer) map[string]fileColumnProperties { + info := map[string]fileColumnProperties{} + for _, transformer := range stats { + stat := transformer.stats + var minStrVal *string = nil + if stat.minStrVal != nil { + s := truncateBytesAsHex(stat.minStrVal, false) + minStrVal = &s + } + var maxStrVal *string = nil + if stat.maxStrVal != nil { + s := truncateBytesAsHex(stat.maxStrVal, true) + maxStrVal = &s + } + info[transformer.column.Name] = fileColumnProperties{ + ColumnOrdinal: int32(stat.columnID), + NullCount: stat.nullCount, + MinStrValue: minStrVal, + MaxStrValue: maxStrVal, + MaxLength: int64(stat.maxStrLen), + MinIntValue: stat.minIntVal, + MaxIntValue: stat.maxIntVal, + MinRealValue: stat.minRealVal, + MaxRealValue: stat.maxRealVal, + DistinctValues: -1, + } + } + return info +} + +func physicalTypeOrdinal(str string) int { + switch strings.ToUpper(str) { + case "ROWINDEX": + return 9 + case "DOUBLE": + return 7 + case "SB1": + return 1 + case "SB2": + return 2 + case "SB4": + return 3 + case "SB8": + return 4 + case "SB16": + return 5 + case "LOB": + return 8 + case "ROW": + return 10 + } + return -1 +} + +func logicalTypeOrdinal(str string) int { + switch strings.ToUpper(str) { + case "BOOLEAN": + return 1 + case "NULL": + return 15 + case "REAL": + return 8 + case "FIXED": + return 2 + case "TEXT": + return 9 + case "BINARY": + return 10 + case "DATE": + return 7 + case "TIME": + return 6 + case "TIMESTAMP_LTZ": + return 3 + case "TIMESTAMP_NTZ": + return 4 + case "TIMESTAMP_TZ": + return 5 + case "ARRAY": + return 13 + case "OBJECT": + return 12 + case "VARIANT": + return 11 + } + return -1 +} + +func maxPrecisionForByteWidth(byteWidth int) int32 { + switch byteWidth { + case 1: + return 3 + case 2: + return 5 + case 4: + return 9 + case 8: + return 18 + } + return 38 +} diff --git a/internal/impl/snowflake/streaming/streaming.go b/internal/impl/snowflake/streaming/streaming.go new file mode 100644 index 0000000000..96c9cb3ed2 --- /dev/null +++ b/internal/impl/snowflake/streaming/streaming.go @@ -0,0 +1,457 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/md5" + "crypto/rsa" + "encoding/hex" + "fmt" + "math/rand/v2" + "os" + "path" + "sync/atomic" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/parquet-go/parquet-go" + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/periodic" + "github.com/redpanda-data/connect/v4/internal/typed" +) + +const debug = false + +// ClientOptions is the options to create a Snowflake Snowpipe API Client +type ClientOptions struct { + // Account name + Account string + // username + User string + // Snowflake Role (i.e. ACCOUNTADMIN) + Role string + // Private key for the user + PrivateKey *rsa.PrivateKey + // Logger for... logging? + Logger *service.Logger + ConnectVersion string + Application string +} + +type stageUploaderResult struct { + uploader uploader + err error +} + +// SnowflakeServiceClient is a port from Java :) +type SnowflakeServiceClient struct { + client *SnowflakeRestClient + clientPrefix string + deploymentID int64 + options ClientOptions + requestIDCounter *atomic.Int64 + + uploader *typed.AtomicValue[stageUploaderResult] + uploadRefreshLoop *periodic.Periodic +} + +// NewSnowflakeServiceClient creates a new API client for the Snowpipe Streaming API +func NewSnowflakeServiceClient(ctx context.Context, opts ClientOptions) (*SnowflakeServiceClient, error) { + client, err := NewRestClient( + opts.Account, + opts.User, + opts.ConnectVersion, + "Redpanda_Connect_"+opts.Application, + opts.PrivateKey, + opts.Logger, + ) + if err != nil { + return nil, err + } + resp, err := client.configureClient(ctx, clientConfigureRequest{Role: opts.Role}) + if err != nil { + return nil, err + } + if resp.StatusCode != responseSuccess { + return nil, fmt.Errorf("unable to initialize client - status: %d, message: %s", resp.StatusCode, resp.Message) + } + uploader, err := newUploader(resp.StageLocation) + if err != nil { + return nil, fmt.Errorf("unable to initialize stage uploader: %w", err) + } + uploaderAtomic := typed.NewAtomicValue(stageUploaderResult{ + uploader: uploader, + }) + ssc := &SnowflakeServiceClient{ + client: client, + clientPrefix: fmt.Sprintf("%s_%d", resp.Prefix, resp.DeploymentID), + deploymentID: resp.DeploymentID, + options: opts, + + uploader: uploaderAtomic, + // Tokens expire every hour, so refresh a bit before that + uploadRefreshLoop: periodic.NewWithContext(time.Hour-(2*time.Minute), func(ctx context.Context) { + resp, err := client.configureClient(ctx, clientConfigureRequest{Role: opts.Role}) + if err != nil { + uploaderAtomic.Store(stageUploaderResult{err: err}) + return + } + // TODO: Do the other checks here that the Java SDK does (deploymentID, etc) + uploader, err := newUploader(resp.StageLocation) + uploaderAtomic.Store(stageUploaderResult{uploader: uploader, err: err}) + }), + requestIDCounter: &atomic.Int64{}, + } + ssc.uploadRefreshLoop.Start() + return ssc, nil +} + +// Close closes the client and future requests have undefined behavior. 
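+// It stops the periodic stage-credential refresh and closes the underlying REST
+// client, which in turn stops its JWT refresh loop.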
+func (c *SnowflakeServiceClient) Close() error { + c.uploadRefreshLoop.Stop() + c.client.Close() + return nil +} + +func (c *SnowflakeServiceClient) nextRequestID() string { + rid := c.requestIDCounter.Add(1) + return fmt.Sprintf("%s_%d", c.clientPrefix, rid) +} + +// ChannelOptions the parameters to opening a channel using SnowflakeServiceClient +type ChannelOptions struct { + // ID of this channel, should be unique per channel + ID int16 + // Name is the name of the channel + Name string + // DatabaseName is the name of the database + DatabaseName string + // SchemaName is the name of the schema + SchemaName string + // TableName is the name of the table + TableName string +} + +type encryptionInfo struct { + encryptionKeyID int64 + encryptionKey string +} + +// OpenChannel creates a new or reuses a channel to load data into a Snowflake table. +func (c *SnowflakeServiceClient) OpenChannel(ctx context.Context, opts ChannelOptions) (*SnowflakeIngestionChannel, error) { + resp, err := c.client.openChannel(ctx, openChannelRequest{ + RequestID: c.nextRequestID(), + Role: c.options.Role, + Channel: opts.Name, + Database: opts.DatabaseName, + Schema: opts.SchemaName, + Table: opts.TableName, + WriteMode: "CLOUD_STORAGE", + }) + if err != nil { + return nil, err + } + if resp.StatusCode != responseSuccess { + return nil, fmt.Errorf("unable to open channel %s - status: %d, message: %s", opts.Name, resp.StatusCode, resp.Message) + } + schema, transformers, typeMetadata, err := constructParquetSchema(resp.TableColumns) + if err != nil { + return nil, err + } + ch := &SnowflakeIngestionChannel{ + ChannelOptions: opts, + clientPrefix: c.clientPrefix, + schema: schema, + version: c.options.ConnectVersion, + client: c.client, + role: c.options.Role, + uploader: c.uploader, + encryptionInfo: &encryptionInfo{ + encryptionKeyID: resp.EncryptionKeyID, + encryptionKey: resp.EncryptionKey, + }, + clientSequencer: resp.ClientSequencer, + rowSequencer: resp.RowSequencer, + transformers: transformers, + fileMetadata: typeMetadata, + buffer: bytes.NewBuffer(nil), + requestIDCounter: c.requestIDCounter, + } + return ch, nil +} + +// OffsetToken is the persisted client offset of a stream. This can be used to implement exactly-once +// processing. 
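+//
+// As a hypothetical example (names are illustrative, not part of this package), a
+// caller could compare the persisted token against the last offset it sent before
+// re-sending a batch after a restart:
+//
+//	tok, err := svc.ChannelStatus(ctx, opts)
+//	if err == nil && string(tok) == lastSentOffset {
+//		// already committed, safe to skip the resend
+//	}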
+type OffsetToken string + +// ChannelStatus returns the offset token for a channel or an error +func (c *SnowflakeServiceClient) ChannelStatus(ctx context.Context, opts ChannelOptions) (OffsetToken, error) { + resp, err := c.client.channelStatus(ctx, batchChannelStatusRequest{ + Role: c.options.Role, + Channels: []channelStatusRequest{ + { + Name: opts.Name, + Table: opts.TableName, + Database: opts.DatabaseName, + Schema: opts.SchemaName, + }, + }, + }) + if err != nil { + return "", err + } + if resp.StatusCode != responseSuccess { + return "", fmt.Errorf("unable to status channel %s - status: %d, message: %s", opts.Name, resp.StatusCode, resp.Message) + } + if len(resp.Channels) != 1 { + return "", fmt.Errorf("failed to fetch channel %s, got %d channels in response", opts.Name, len(resp.Channels)) + } + channel := resp.Channels[0] + if channel.StatusCode != responseSuccess { + return "", fmt.Errorf("unable to status channel %s - status: %d", opts.Name, resp.StatusCode) + } + return OffsetToken(channel.PersistedOffsetToken), nil +} + +// DropChannel drops it like it's hot 🔥 +func (c *SnowflakeServiceClient) DropChannel(ctx context.Context, opts ChannelOptions) error { + resp, err := c.client.dropChannel(ctx, dropChannelRequest{ + RequestID: c.nextRequestID(), + Role: c.options.Role, + Channel: opts.Name, + Table: opts.TableName, + Database: opts.DatabaseName, + Schema: opts.SchemaName, + }) + if err != nil { + return err + } + if resp.StatusCode != responseSuccess { + return fmt.Errorf("unable to drop channel %s - status: %d, message: %s", opts.Name, resp.StatusCode, resp.Message) + } + return nil +} + +// SnowflakeIngestionChannel is a write connection to a single table in Snowflake +type SnowflakeIngestionChannel struct { + ChannelOptions + role string + clientPrefix string + version string + schema *parquet.Schema + client *SnowflakeRestClient + uploader *typed.AtomicValue[stageUploaderResult] + encryptionInfo *encryptionInfo + clientSequencer int64 + rowSequencer int64 + transformers map[string]*dataTransformer + fileMetadata map[string]string + buffer *bytes.Buffer + // This is shared among the various open channels to get some uniqueness + // when naming bdec files + requestIDCounter *atomic.Int64 +} + +func (c *SnowflakeIngestionChannel) nextRequestID() string { + rid := c.requestIDCounter.Add(1) + return fmt.Sprintf("%s_%d", c.clientPrefix, rid) +} + +// InsertStats holds some basic statistics about the InsertRows operation +type InsertStats struct { + BuildTime time.Duration + UploadTime time.Duration + CompressedOutputSize int +} + +// InsertRows creates a parquet file using the schema from the data, +// then writes that file into the Snowflake table +func (c *SnowflakeIngestionChannel) InsertRows(ctx context.Context, batch service.MessageBatch) (InsertStats, error) { + stats := InsertStats{} + startTime := time.Now() + rows, err := constructRowGroup(batch, c.schema, c.transformers) + if err != nil { + return stats, err + } + // Prevent multiple channels from having the same bdec file (it must be unique) + // so add the ID of the channel in the upper 16 bits and then get 48 bits of + // randomness outside that. + fakeThreadID := (int(c.ID) << 48) | rand.N(1<<48) + blobPath := generateBlobPath(c.clientPrefix, fakeThreadID, int(c.requestIDCounter.Add(1))) + // This is extra metadata that is required for functionality in snowflake. 
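+	// primaryFileId must match the file name of the blob that is uploaded below.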
+ c.fileMetadata["primaryFileId"] = path.Base(blobPath) + c.buffer.Reset() + err = writeParquetFile(c.buffer, c.version, parquetFileData{ + schema: c.schema, + rows: rows, + metadata: c.fileMetadata, + }) + if err != nil { + return stats, err + } + unencrypted := c.buffer.Bytes() + metadata, err := readParquetMetadata(unencrypted) + if err != nil { + return stats, fmt.Errorf("unable to parse parquet metadata: %w", err) + } + if debug { + _ = os.WriteFile("latest_test.parquet", unencrypted, 0o644) + } + unencryptedLen := len(unencrypted) + unencrypted = padBuffer(unencrypted, aes.BlockSize) + encrypted, err := encrypt(unencrypted, c.encryptionInfo.encryptionKey, blobPath, 0) + if err != nil { + return stats, err + } + uploadStartTime := time.Now() + fileMD5Hash := md5.Sum(encrypted) + uploaderResult := c.uploader.Load() + if uploaderResult.err != nil { + return stats, fmt.Errorf("failed to acquire stage uploader: %w", uploaderResult.err) + } + uploader := uploaderResult.uploader + err = backoff.Retry(func() error { + return uploader.upload(ctx, blobPath, encrypted, fileMD5Hash[:]) + }, backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Second), 3)) + if err != nil { + return stats, err + } + + uploadFinishTime := time.Now() + columnEpInfo := computeColumnEpInfo(c.transformers) + resp, err := c.client.registerBlob(ctx, registerBlobRequest{ + RequestID: c.nextRequestID(), + Role: c.role, + Blobs: []blobMetadata{ + { + Path: blobPath, + MD5: hex.EncodeToString(fileMD5Hash[:]), + BDECVersion: 3, + BlobStats: blobStats{ + FlushStartMs: startTime.UnixMilli(), + BuildDurationMs: uploadStartTime.UnixMilli() - startTime.UnixMilli(), + UploadDurationMs: uploadFinishTime.UnixMilli() - uploadStartTime.UnixMilli(), + }, + Chunks: []chunkMetadata{ + { + Database: c.DatabaseName, + Schema: c.SchemaName, + Table: c.TableName, + ChunkStartOffset: 0, + ChunkLength: int32(unencryptedLen), + ChunkLengthUncompressed: totalUncompressedSize(metadata), + ChunkMD5: md5Hash(encrypted[:unencryptedLen]), + EncryptionKeyID: c.encryptionInfo.encryptionKeyID, + FirstInsertTimeInMillis: startTime.UnixMilli(), + LastInsertTimeInMillis: startTime.UnixMilli(), + EPS: &epInfo{ + Rows: metadata.NumRows, + Columns: columnEpInfo, + }, + Channels: []channelMetadata{ + { + Channel: c.Name, + ClientSequencer: c.clientSequencer, + RowSequencer: c.rowSequencer + 1, + StartOffsetToken: nil, + EndOffsetToken: nil, + OffsetToken: nil, + }, + }, + }, + }, + }, + }, + }) + if err != nil { + return stats, err + } + if len(resp.Blobs) != 1 { + return stats, fmt.Errorf("unexpected number of response blobs: %d", len(resp.Blobs)) + } + status := resp.Blobs[0] + if len(status.Chunks) != 1 { + return stats, fmt.Errorf("unexpected number of response blob chunks: %d", len(status.Chunks)) + } + chunk := status.Chunks[0] + if len(chunk.Channels) != 1 { + return stats, fmt.Errorf("unexpected number of channels for blob chunk: %d", len(chunk.Channels)) + } + channel := chunk.Channels[0] + if channel.StatusCode != responseSuccess { + msg := channel.Message + if msg == "" { + msg = "(no message)" + } + return stats, fmt.Errorf("error response injesting data (%d): %s", channel.StatusCode, msg) + } + c.rowSequencer++ + c.clientSequencer = channel.ClientSequencer + stats.CompressedOutputSize = unencryptedLen + stats.BuildTime = uploadStartTime.Sub(startTime) + stats.UploadTime = uploadFinishTime.Sub(uploadStartTime) + return stats, err +} + +// WaitUntilCommitted waits until all the data in the channel has been committed +// along with how many polls 
it took to get that. +func (c *SnowflakeIngestionChannel) WaitUntilCommitted(ctx context.Context) (int, error) { + var polls int + err := backoff.Retry(func() error { + polls++ + resp, err := c.client.channelStatus(ctx, batchChannelStatusRequest{ + Role: c.role, + Channels: []channelStatusRequest{ + { + Table: c.TableName, + Database: c.DatabaseName, + Schema: c.SchemaName, + Name: c.Name, + ClientSequencer: &c.clientSequencer, + }, + }, + }) + if err != nil { + return err + } + if resp.StatusCode != responseSuccess { + msg := resp.Message + if msg == "" { + msg = "(no message)" + } + return fmt.Errorf("error fetching channel status (%d): %s", resp.StatusCode, msg) + } + if len(resp.Channels) != 1 { + return fmt.Errorf("unexpected number of channels for status request: %d", len(resp.Channels)) + } + status := resp.Channels[0] + if status.PersistedClientSequencer != c.clientSequencer { + return backoff.Permanent(fmt.Errorf("unexpected number of channels for status request: %d", len(resp.Channels))) + } + if status.PersistedRowSequencer < c.rowSequencer { + return fmt.Errorf("row sequencer not yet committed: %d < %d", status.PersistedRowSequencer, c.rowSequencer) + } + return nil + }, backoff.WithContext( + // 1, 10, 100, 1000, 1000, ... + backoff.NewExponentialBackOff( + backoff.WithInitialInterval(time.Millisecond), + backoff.WithMultiplier(10), + backoff.WithMaxInterval(time.Second), + backoff.WithMaxElapsedTime(10*time.Minute), + ), + ctx, + )) + return polls, err +} diff --git a/internal/impl/snowflake/streaming/streaming_test.go b/internal/impl/snowflake/streaming/streaming_test.go new file mode 100644 index 0000000000..13ab4eac17 --- /dev/null +++ b/internal/impl/snowflake/streaming/streaming_test.go @@ -0,0 +1,22 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDebugModeDisabled(t *testing.T) { + // So I can't forget to disable this! + require.False(t, debug) +} diff --git a/internal/impl/snowflake/streaming/uploader.go b/internal/impl/snowflake/streaming/uploader.go new file mode 100644 index 0000000000..b0092be53a --- /dev/null +++ b/internal/impl/snowflake/streaming/uploader.go @@ -0,0 +1,180 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "net/url" + "path/filepath" + "strings" + + gcs "cloud.google.com/go/storage" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" + "golang.org/x/oauth2" + gcsopt "google.golang.org/api/option" +) + +type uploader interface { + upload(ctx context.Context, path string, encrypted, md5Hash []byte) error +} + +func newUploader(fileLocationInfo fileLocationInfo) (uploader, error) { + switch fileLocationInfo.LocationType { + case "S3": + creds := fileLocationInfo.Creds + awsKeyID := creds["AWS_KEY_ID"] + awsSecretKey := creds["AWS_SECRET_KEY"] + awsToken := creds["AWS_TOKEN"] + // TODO: Handle regional URLs + if fileLocationInfo.UseS3RegionalURL { + return nil, errors.New("S3 Regional URLs are not supported") + } + // TODO: Handle EndPoint, the Java SDK says this is only for Azure, but + // that doesn't seem to be the case from reading the Java JDBC driver, + // the Golang driver says this is used for FIPS in GovCloud. + if fileLocationInfo.EndPoint != "" { + return nil, errors.New("custom S3 endpoint is not supported") + } + client := s3.New(s3.Options{ + Region: fileLocationInfo.Region, + Credentials: credentials.NewStaticCredentialsProvider( + awsKeyID, + awsSecretKey, + awsToken, + ), + }) + bucket, pathPrefix, err := splitBucketAndPath(fileLocationInfo.Location) + if err != nil { + return nil, err + } + uploader := manager.NewUploader(client) + return &s3Uploader{ + client: uploader, + bucket: bucket, + pathPrefix: pathPrefix, + }, nil + case "GCS": + accessToken := fileLocationInfo.Creds["GCS_ACCESS_TOKEN"] + // Even though the GCS uploader takes a context, it's not used because we configure + // static access token credentials. The context is only used for service account + // auth via the instance metadata server. 
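+		// The access token is short lived; the periodic client re-configuration in
+		// streaming.go swaps in a new uploader built from fresh credentials.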
+ client, err := gcs.NewClient(context.Background(), gcsopt.WithTokenSource( + oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: accessToken, + TokenType: "Bearer", + }), + )) + if err != nil { + return nil, err + } + bucket, prefix, err := splitBucketAndPath(fileLocationInfo.Location) + if err != nil { + return nil, err + } + return &gcsUploader{ + bucket: client.Bucket(bucket), + pathPrefix: prefix, + }, err + case "AZURE": + sasToken := fileLocationInfo.Creds["AZURE_SAS_TOKEN"] + urlString := fmt.Sprintf("https://%s.%s/%s", fileLocationInfo.StorageAccount, fileLocationInfo.EndPoint, sasToken) + u, err := url.Parse(urlString) + if err != nil { + return nil, fmt.Errorf("invalid azure blob storage url: %w", err) + } + client, err := azblob.NewClientWithNoCredential(u.String(), nil) + if err != nil { + return nil, fmt.Errorf("unable to create azure blob storage client: %w", err) + } + container, prefix, err := splitBucketAndPath(fileLocationInfo.Location) + if err != nil { + return nil, err + } + return &azureUploader{ + client: client, + container: container, + pathPrefix: prefix, + }, nil + } + return nil, fmt.Errorf("unsupported location type: %s", fileLocationInfo.LocationType) +} + +type azureUploader struct { + client *azblob.Client + container, pathPrefix string +} + +func (u *azureUploader) upload(ctx context.Context, path string, encrypted, md5Hash []byte) error { + // We upload in multiple parts, so we have to validate ourselves post upload 😒 + resp, err := u.client.UploadBuffer(ctx, u.container, filepath.Join(u.pathPrefix, path), encrypted, nil) + if err != nil { + return err + } + if !bytes.Equal(resp.ContentMD5, md5Hash) { + return fmt.Errorf("invalid md5 hash got: %s want: %s", hex.EncodeToString(resp.ContentMD5), md5Hash) + } + return nil +} + +type s3Uploader struct { + client *manager.Uploader + bucket, pathPrefix string +} + +func (u *s3Uploader) upload(ctx context.Context, path string, encrypted, md5Hash []byte) error { + input := &s3.PutObjectInput{ + Bucket: &u.bucket, + Key: aws.String(filepath.Join(u.pathPrefix, path)), + Body: bytes.NewReader(encrypted), + ContentMD5: aws.String(base64.StdEncoding.EncodeToString(md5Hash)), + } + _, err := u.client.Upload(ctx, input) + return err +} + +type gcsUploader struct { + bucket *gcs.BucketHandle + pathPrefix string +} + +func (u *gcsUploader) upload(ctx context.Context, path string, encrypted, md5Hash []byte) error { + object := u.bucket.Object(filepath.Join(u.pathPrefix, path)) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + ow := object.NewWriter(ctx) + ow.MD5 = md5Hash + for len(encrypted) > 0 { + n, err := ow.Write(encrypted) + if err != nil { + _ = ow.Close() + return err + } + encrypted = encrypted[n:] + } + return ow.Close() +} + +func splitBucketAndPath(stageLocation string) (string, string, error) { + bucketAndPath := strings.SplitN(stageLocation, "/", 2) + if len(bucketAndPath) != 2 { + return "", "", fmt.Errorf("unexpected stage location: %s", stageLocation) + } + return bucketAndPath[0], bucketAndPath[1], nil +} diff --git a/internal/impl/snowflake/streaming/userdata_converter.go b/internal/impl/snowflake/streaming/userdata_converter.go new file mode 100644 index 0000000000..b351a81113 --- /dev/null +++ b/internal/impl/snowflake/streaming/userdata_converter.go @@ -0,0 +1,482 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "time" + "unicode/utf8" + + "github.com/Jeffail/gabs/v2" + "github.com/parquet-go/parquet-go" + "github.com/redpanda-data/benthos/v4/public/bloblang" + "github.com/redpanda-data/connect/v4/internal/impl/snowflake/streaming/int128" +) + +// typedBuffer is the buffer that holds columnar data before we write to the parquet file +type typedBuffer interface { + WriteNull() + WriteInt128(int128.Num) + WriteBool(bool) + WriteFloat64(float64) + WriteBytes([]byte) // should never be nil + + // Prepare for writing values to the following matrix. + // Must be called before writing + // The matrix size must be pre-allocated to be the size of + // the data that will be written - this buffer will not modify + // the size of the data. + Prepare(matrix []parquet.Value, columnIndex, rowWidth int) + Reset() +} + +type typedBufferImpl struct { + matrix []parquet.Value + columnIndex int + rowWidth int + currentRow int +} + +func (b *typedBufferImpl) WriteValue(v parquet.Value) { + b.matrix[(b.currentRow*b.rowWidth)+b.columnIndex] = v + b.currentRow++ +} +func (b *typedBufferImpl) WriteNull() { + b.WriteValue(parquet.NullValue()) +} +func (b *typedBufferImpl) WriteInt128(v int128.Num) { + b.WriteValue(parquet.FixedLenByteArrayValue(v.ToBigEndian()).Level(0, 1, b.columnIndex)) +} +func (b *typedBufferImpl) WriteBool(v bool) { + b.WriteValue(parquet.BooleanValue(v).Level(0, 1, b.columnIndex)) +} +func (b *typedBufferImpl) WriteFloat64(v float64) { + b.WriteValue(parquet.DoubleValue(v).Level(0, 1, b.columnIndex)) +} +func (b *typedBufferImpl) WriteBytes(v []byte) { + b.WriteValue(parquet.ByteArrayValue(v).Level(0, 1, b.columnIndex)) +} +func (b *typedBufferImpl) Prepare(matrix []parquet.Value, columnIndex, rowWidth int) { + b.currentRow = 0 + b.matrix = matrix + b.columnIndex = columnIndex + b.rowWidth = rowWidth +} +func (b *typedBufferImpl) Reset() { + b.Prepare(nil, 0, 0) +} + +type int64Buffer struct { + typedBufferImpl +} + +func (b *int64Buffer) WriteInt128(v int128.Num) { + b.WriteValue(parquet.Int64Value(v.ToInt64()).Level(0, 1, b.columnIndex)) +} + +type int32Buffer struct { + typedBufferImpl +} + +func (b *int32Buffer) WriteInt128(v int128.Num) { + b.WriteValue(parquet.Int32Value(int32(v.ToInt64())).Level(0, 1, b.columnIndex)) +} + +type dataConverter interface { + ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error +} + +var errNullValue = errors.New("unexpected null value") + +type boolConverter struct { + nullable bool +} + +func (c boolConverter) ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error { + if val == nil { + if !c.nullable { + return errNullValue + } + stats.nullCount++ + buf.WriteNull() + return nil + } + v, err := bloblang.ValueAsBool(val) + if err != nil { + return err + } + i := int128.FromUint64(0) + if v { + i = int128.FromUint64(1) + } + if stats.first { + stats.minIntVal = i + stats.maxIntVal = i + stats.first = false + } else { + stats.minIntVal = int128.Min(stats.minIntVal, i) + stats.maxIntVal = int128.Max(stats.maxIntVal, i) + } + buf.WriteBool(v) + return nil +} + +type numberConverter struct { + nullable bool + scale int32 + precision int32 +} + +func (c numberConverter) ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error { + if val == nil { + if !c.nullable { + return errNullValue + } + stats.nullCount++ + 
buf.WriteNull() + return nil + } + var v int128.Num + var err error + switch t := val.(type) { + case int: + v = int128.FromInt64(int64(t)) + v, err = int128.Rescale(v, c.precision, c.scale) + case int8: + v = int128.FromInt64(int64(t)) + v, err = int128.Rescale(v, c.precision, c.scale) + case int16: + v = int128.FromInt64(int64(t)) + v, err = int128.Rescale(v, c.precision, c.scale) + case int32: + v = int128.FromInt64(int64(t)) + v, err = int128.Rescale(v, c.precision, c.scale) + case int64: + v = int128.FromInt64(t) + v, err = int128.Rescale(v, c.precision, c.scale) + case uint: + v = int128.FromUint64(uint64(t)) + v, err = int128.Rescale(v, c.precision, c.scale) + case uint8: + v = int128.FromUint64(uint64(t)) + v, err = int128.Rescale(v, c.precision, c.scale) + case uint16: + v = int128.FromUint64(uint64(t)) + v, err = int128.Rescale(v, c.precision, c.scale) + case uint32: + v = int128.FromUint64(uint64(t)) + v, err = int128.Rescale(v, c.precision, c.scale) + case uint64: + v = int128.FromUint64(t) + v, err = int128.Rescale(v, c.precision, c.scale) + case float32: + v, err = int128.FromFloat32(t, c.precision, c.scale) + case float64: + v, err = int128.FromFloat64(t, c.precision, c.scale) + case string: + v, err = int128.FromString(t, c.precision, c.scale) + case json.Number: + v, err = int128.FromString(t.String(), c.precision, c.scale) + default: + // fallback to the good error message that bloblang provides + var i int64 + i, err = bloblang.ValueAsInt64(val) + if err != nil { + return err + } + v = int128.FromInt64(i) + v, err = int128.Rescale(v, c.precision, c.scale) + } + if err != nil { + return err + } + if stats.first { + stats.minIntVal = v + stats.maxIntVal = v + stats.first = false + } else { + stats.minIntVal = int128.Min(stats.minIntVal, v) + stats.maxIntVal = int128.Max(stats.maxIntVal, v) + } + buf.WriteInt128(v) + return nil +} + +type doubleConverter struct { + nullable bool +} + +func (c doubleConverter) ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error { + if val == nil { + if !c.nullable { + return errNullValue + } + stats.nullCount++ + buf.WriteNull() + return nil + } + v, err := bloblang.ValueAsFloat64(val) + if err != nil { + return err + } + if stats.first { + stats.minRealVal = v + stats.maxRealVal = v + stats.first = false + } else { + stats.minRealVal = min(stats.minRealVal, v) + stats.maxRealVal = max(stats.maxRealVal, v) + } + buf.WriteFloat64(v) + return nil +} + +type binaryConverter struct { + nullable bool + maxLength int + utf8 bool +} + +func (c binaryConverter) ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error { + if val == nil { + if !c.nullable { + return errNullValue + } + stats.nullCount++ + buf.WriteNull() + return nil + } + v, err := bloblang.ValueAsBytes(val) + if err != nil { + return err + } + if len(v) > c.maxLength { + return fmt.Errorf("value too long, length: %d, max: %d", len(v), c.maxLength) + } + if c.utf8 && !utf8.Valid(v) { + return errors.New("invalid UTF8") + } + if stats.first { + stats.minStrVal = v + stats.maxStrVal = v + stats.maxStrLen = len(v) + stats.first = false + } else { + if bytes.Compare(v, stats.minStrVal) < 0 { + stats.minStrVal = v + } + if bytes.Compare(v, stats.maxStrVal) > 0 { + stats.maxStrVal = v + } + stats.maxStrLen = max(stats.maxStrLen, len(v)) + } + buf.WriteBytes(v) + return nil +} + +type jsonConverter struct { + nullable bool + maxLength int +} + +func (c jsonConverter) ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error { + if val == nil { + 
if !c.nullable { + return errNullValue + } + stats.nullCount++ + buf.WriteNull() + return nil + } + v := gabs.Wrap(val).Bytes() + if len(v) > c.maxLength { + return fmt.Errorf("value too long, length: %d, max: %d", len(v), c.maxLength) + } + if stats.first { + stats.minStrVal = v + stats.maxStrVal = v + stats.maxStrLen = len(v) + stats.first = false + } else { + if bytes.Compare(v, stats.minStrVal) < 0 { + stats.minStrVal = v + } + if bytes.Compare(v, stats.maxStrVal) > 0 { + stats.maxStrVal = v + } + stats.maxStrLen = max(stats.maxStrLen, len(v)) + } + buf.WriteBytes(v) + return nil +} + +type jsonArrayConverter struct { + jsonConverter +} + +func (c jsonArrayConverter) ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error { + if val != nil { + if _, ok := val.([]any); !ok { + return errors.New("not a JSON array") + } + } + return c.jsonConverter.ValidateAndConvert(stats, val, buf) +} + +type jsonObjectConverter struct { + jsonConverter +} + +func (c jsonObjectConverter) ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error { + if val != nil { + if _, ok := val.(map[string]any); !ok { + return errors.New("not a JSON object") + } + } + return c.jsonConverter.ValidateAndConvert(stats, val, buf) +} + +type timestampConverter struct { + nullable bool + scale, precision int32 + includeTZ bool + trimTZ bool + defaultTZ *time.Location +} + +func (c timestampConverter) ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error { + if val == nil { + if !c.nullable { + return errNullValue + } + stats.nullCount++ + buf.WriteNull() + return nil + } + var s string + var t time.Time + var err error + switch v := val.(type) { + case []byte: + s = string(v) + case string: + s = v + default: + t, err = bloblang.ValueAsTimestamp(val) + if err != nil { + return err + } + } + if s != "" { + location := c.defaultTZ + t, err = time.ParseInLocation(time.RFC3339Nano, s, location) + if err != nil { + return fmt.Errorf("unable to parse timestamp value from %q", s) + } + } + if c.trimTZ { + t = t.UTC() + } + v := snowflakeTimestampInt(t, c.scale, c.includeTZ) + if !v.FitsInPrecision(c.precision) { + return fmt.Errorf( + "unable to fit timestamp (%s -> %s) within required precision: %v", + t.Format(time.RFC3339Nano), + v.String(), + c.precision, + ) + } + if stats.first { + stats.minIntVal = v + stats.maxIntVal = v + stats.first = false + } else { + stats.minIntVal = int128.Min(stats.minIntVal, v) + stats.maxIntVal = int128.Max(stats.maxIntVal, v) + } + buf.WriteInt128(v) + return nil +} + +type timeConverter struct { + nullable bool + scale int32 +} + +func (c timeConverter) ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error { + if val == nil { + if !c.nullable { + return errNullValue + } + stats.nullCount++ + buf.WriteNull() + return nil + } + t, err := bloblang.ValueAsTimestamp(val) + if err != nil { + return err + } + // 24 hours in nanoseconds fits within uint64, so we can't overflow + nanos := t.Hour()*int(time.Hour.Nanoseconds()) + + t.Minute()*int(time.Minute.Nanoseconds()) + + t.Second()*int(time.Second.Nanoseconds()) + + t.Nanosecond() + v := int128.FromInt64(int64(nanos) / pow10TableInt64[9-c.scale]) + if stats.first { + stats.minIntVal = v + stats.maxIntVal = v + stats.first = false + } else { + stats.minIntVal = int128.Min(stats.minIntVal, v) + stats.maxIntVal = int128.Max(stats.maxIntVal, v) + } + // TODO(perf): consider switching to int64 buffers so more stuff can fit in cache + buf.WriteInt128(v) + return nil +} + +type dateConverter 
struct { + nullable bool +} + +func (c dateConverter) ValidateAndConvert(stats *statsBuffer, val any, buf typedBuffer) error { + if val == nil { + if !c.nullable { + return errNullValue + } + stats.nullCount++ + buf.WriteNull() + return nil + } + t, err := bloblang.ValueAsTimestamp(val) + if err != nil { + return err + } + t = t.UTC() + if t.Year() < -9999 || t.Year() > 9999 { + return fmt.Errorf("DATE columns out of range, year: %d", t.Year()) + } + v := int128.FromInt64(t.Unix() / int64(24*60*60)) + if stats.first { + stats.minIntVal = v + stats.maxIntVal = v + stats.first = false + } else { + stats.minIntVal = int128.Min(stats.minIntVal, v) + stats.maxIntVal = int128.Max(stats.maxIntVal, v) + } + // TODO(perf): consider switching to int64 buffers so more stuff can fit in cache + buf.WriteInt128(v) + return nil +} diff --git a/internal/impl/snowflake/streaming/userdata_converter_test.go b/internal/impl/snowflake/streaming/userdata_converter_test.go new file mode 100644 index 0000000000..5a4ea6a6cb --- /dev/null +++ b/internal/impl/snowflake/streaming/userdata_converter_test.go @@ -0,0 +1,508 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package streaming + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/parquet-go/parquet-go" + "github.com/redpanda-data/connect/v4/internal/impl/snowflake/streaming/int128" + "github.com/stretchr/testify/require" +) + +type validateTestCase struct { + name string + input any + output any + err bool + scale int32 + precision int32 +} + +func TestTimeConverter(t *testing.T) { + tests := []validateTestCase{ + { + input: "2020-01-01T13:02:00.0Z", + output: 46920, + scale: 0, + }, + { + input: "2020-01-01T13:02:06.0Z", + output: 46926, + scale: 0, + }, + { + input: "2020-01-01T13:02:06Z", + output: 469260, + scale: 1, + }, + { + input: "2020-01-01T13:02:06Z", + output: 46926000000000, + scale: 9, + }, + { + input: "2020-01-01T13:02:06.1234Z", + output: 46926, + scale: 0, + }, + { + input: "2020-01-01T13:02:06.1234Z", + output: 469261, + scale: 1, + }, + { + input: "2020-01-01T13:02:06.1234Z", + output: 46926123400000, + scale: 9, + }, + { + input: "2020-01-01T13:02:06.123456789Z", + output: 46926, + scale: 0, + }, + { + input: "2020-01-01T13:02:06.123456789Z", + output: 469261, + scale: 1, + }, + { + input: "2020-01-01T13:02:06.123456789Z", + output: 46926123456789, + scale: 9, + }, + { + input: 46926, + output: 46926, + scale: 0, + }, + { + input: 1728680106, + output: 75306000000000, + scale: 9, + }, + { + input: "2023-01-19T14:23:55.878137", + scale: 9, + err: true, + }, + { + input: nil, + output: nil, + }, + } + for _, tc := range tests { + tc := tc + t.Run("", func(t *testing.T) { + c := &timeConverter{nullable: true, scale: tc.scale} + runTestcase(t, c, tc) + }) + } +} + +func TestNumberConverter(t *testing.T) { + tests := []validateTestCase{ + { + name: "Number(2, 0)", + input: 12, + output: 12, + precision: 2, + }, + { + name: "Number(4, 0)", + input: 1234, + output: 1234, + precision: 4, + }, + { + name: "Number(9, 0)", + input: 123456789, + output: 123456789, + precision: 9, + }, + { + name: "Number(18, 0)", + input: 123456789987654321, + output: 123456789987654321, + precision: 18, + }, + { + name: "Number(38, 0)", + input: 
json.Number("91234567899876543219876543211234567891"), + output: int128.MustParse("91234567899876543219876543211234567891"), + precision: 38, + }, + { + name: "Number(38, 37)", + input: json.Number("9.1234567899876543219876543211234567891"), + output: int128.MustParse("91234567899876543219876543211234567891"), + precision: 38, + scale: 37, + }, + { + name: "Number(38, 28)", + input: json.Number("9123456789.9876543219876543211234567891"), + output: int128.MustParse("91234567899876543219876543211234567891"), + precision: 38, + scale: 28, + }, + { + name: "Number(19, 0) Error", + input: json.Number("91234567899876543219876543211234567891"), + err: true, + precision: 19, // too small + }, + { + name: "Number(19, 4)", + input: json.Number("123.4321"), + output: 1234321, + scale: 4, + precision: 19, + }, + { + name: "Number(19, 10)", + input: json.Number("123.4321"), + output: 1234321000000, + scale: 10, + precision: 19, + }, + { + name: "Number(26, 4)", + input: 123456789987654321, + output: int128.MustParse("1234567899876543210000"), + scale: 4, + precision: 26, + }, + { + name: "Number(19, 4) Error", + input: 123456789987654321, + err: true, + scale: 4, + precision: 19, + }, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + c := &numberConverter{ + nullable: true, + scale: tc.scale, + precision: tc.precision, + } + runTestcase(t, c, tc) + }) + } +} + +func TestRealConverter(t *testing.T) { + tests := []validateTestCase{ + { + input: 12345.54321, + output: 12345.54321, + }, + { + input: 3.415, + output: 3.415, + }, + } + for _, tc := range tests { + tc := tc + t.Run("", func(t *testing.T) { + c := &doubleConverter{nullable: true} + runTestcase(t, c, tc) + }) + } +} + +func TestBoolConverter(t *testing.T) { + tests := []validateTestCase{ + { + input: true, + output: true, + }, + { + input: false, + output: false, + }, + { + input: nil, + output: nil, + }, + { + input: "false", + output: false, + }, + } + for _, tc := range tests { + tc := tc + t.Run("", func(t *testing.T) { + c := &boolConverter{nullable: true} + runTestcase(t, c, tc) + }) + } +} + +func TestBinaryConverter(t *testing.T) { + tests := []validateTestCase{ + { + input: []byte("1234abcd"), + output: []byte("1234abcd"), + }, + { + input: []byte(strings.Repeat("a", 57)), + err: true, + }, + } + for _, tc := range tests { + tc := tc + t.Run("", func(t *testing.T) { + c := &binaryConverter{nullable: true, maxLength: 56} + runTestcase(t, c, tc) + }) + } +} + +func TestStringConverter(t *testing.T) { + tests := []validateTestCase{ + { + input: "1234abcd", + output: []byte("1234abcd"), + }, + { + input: strings.Repeat("a", 57), + err: true, + }, + { + input: "a\xc5z", + err: true, + }, + } + for _, tc := range tests { + tc := tc + t.Run("", func(t *testing.T) { + c := &binaryConverter{nullable: true, maxLength: 56, utf8: true} + runTestcase(t, c, tc) + }) + } +} + +func TestTimestampNTZConverter(t *testing.T) { + tests := []validateTestCase{ + { + input: "2013-04-28T20:57:00.0Z", + output: 1367182620, + scale: 0, + precision: 18, + }, + { + input: "2013-04-28T20:57:01.000Z", + output: 1367182621000, + scale: 3, + precision: 18, + }, + { + input: "2013-04-28T20:57:01.000Z", + output: 1367182621, + scale: 0, + precision: 18, + }, + { + input: "2013-04-28T20:57:01.000+01:00", + output: 1367179021000, + scale: 3, + precision: 18, + }, + { + input: "2022-09-18T22:05:07.123456789Z", + output: 1663538707123456789, + scale: 9, + precision: 38, + }, + { + input: "2022-09-18T22:05:07.123456789+01:00", + output: 
1663535107123456789, + scale: 9, + precision: 38, + }, + { + input: "2013-04-28T20:57:01.000Z", + output: 1367182621000, + scale: 3, + precision: 18, + }, + } + for _, tc := range tests { + tc := tc + t.Run("", func(t *testing.T) { + loc, err := time.LoadLocation("America/New_York") + require.NoError(t, err) + c := ×tampConverter{ + nullable: true, + scale: tc.scale, + precision: tc.precision, + includeTZ: false, + trimTZ: true, + defaultTZ: loc, + } + runTestcase(t, c, tc) + }) + } +} + +func TestTimestampTZConverter(t *testing.T) { + tests := []validateTestCase{ + { + input: "2013-04-28T20:57:01.000Z", + output: 22399920062465440, + scale: 3, + precision: 18, + }, + } + for _, tc := range tests { + tc := tc + t.Run("", func(t *testing.T) { + loc, err := time.LoadLocation("America/New_York") + require.NoError(t, err) + c := ×tampConverter{ + nullable: true, + scale: tc.scale, + precision: tc.precision, + includeTZ: true, + trimTZ: false, + defaultTZ: loc, + } + runTestcase(t, c, tc) + }) + } +} + +func TestTimestampLTZConverter(t *testing.T) { + tests := []validateTestCase{ + { + input: "2013-04-28T20:57:00Z", + output: 1367182620, + scale: 0, + precision: 18, + }, + { + input: "2013-04-28T20:57:00Z", + output: 136718262000, + scale: 2, + precision: 18, + }, + { + input: "2013-04-28T20:57:00Z", + err: true, + scale: 0, + precision: 9, // Mor precision needed + }, + } + for _, tc := range tests { + tc := tc + t.Run("", func(t *testing.T) { + loc, err := time.LoadLocation("America/New_York") + require.NoError(t, err) + c := ×tampConverter{ + nullable: true, + scale: tc.scale, + precision: tc.precision, + includeTZ: false, + trimTZ: false, + defaultTZ: loc, + } + runTestcase(t, c, tc) + }) + } +} + +func TestDateConverter(t *testing.T) { + tests := []validateTestCase{ + { + input: "1970-01-10T00:00:00Z", + output: 9, + }, + { + input: 1674478926, + output: 19380, + }, + { + input: "1967-06-23T00:00:00Z", + output: -923, + }, + { + input: "2020-07-21T00:00:00Z", + output: 18464, + }, + { + input: time.Time{}.AddDate(10_000, 0, 0), + err: true, + }, + { + input: time.Time{}.AddDate(-10_001, 0, 0), + err: true, + }, + } + for _, tc := range tests { + tc := tc + t.Run("", func(t *testing.T) { + c := &dateConverter{nullable: true} + runTestcase(t, c, tc) + }) + } +} + +type testTypedBuffer struct { + output any +} + +func (b *testTypedBuffer) WriteNull() { + b.output = nil +} +func (b *testTypedBuffer) WriteInt128(v int128.Num) { + switch { + case int128.Less(v, int128.MinInt64): + b.output = v + case int128.Greater(v, int128.MaxInt64): + b.output = v + default: + b.output = int(v.ToInt64()) + } +} + +func (b *testTypedBuffer) WriteBool(v bool) { + b.output = v +} +func (b *testTypedBuffer) WriteFloat64(v float64) { + b.output = v +} +func (b *testTypedBuffer) WriteBytes(v []byte) { + b.output = v +} +func (b *testTypedBuffer) Prepare([]parquet.Value, int, int) { + b.output = nil +} +func (b *testTypedBuffer) Reset() {} + +func runTestcase(t *testing.T, dc dataConverter, tc validateTestCase) { + t.Helper() + s := statsBuffer{} + b := testTypedBuffer{} + err := dc.ValidateAndConvert(&s, tc.input, &b) + if tc.err { + require.Errorf(t, err, "instead got: %#v", b.output) + } else { + require.NoError(t, err) + require.Equal(t, tc.output, b.output) + } +} diff --git a/internal/periodic/periodic.go b/internal/periodic/periodic.go new file mode 100644 index 0000000000..1f7ad90ab3 --- /dev/null +++ b/internal/periodic/periodic.go @@ -0,0 +1,93 @@ +// Copyright 2024 Redpanda Data, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package periodic + +import ( + "context" + "time" +) + +// Periodic holds a background goroutine that can do periodic work. +// +// The work here cannot communicate errors directly, so it must +// communicate with channels or swallow errors. +// +// NOTE: It's expected that Start and Stop are called on the same +// goroutine or be externally synchronized as to not race. +type Periodic struct { + duration time.Duration + work func(context.Context) + + cancel context.CancelFunc + done chan any +} + +// New creates new background work that runs every `duration` and performs `work`. +func New(duration time.Duration, work func()) *Periodic { + return &Periodic{ + duration: duration, + work: func(context.Context) { work() }, + } +} + +// NewWithContext creates new background work that runs every `duration` and performs `work`. +// +// Work is passed a context that is cancelled when the overall periodic is cancelled. +func NewWithContext(duration time.Duration, work func(context.Context)) *Periodic { + return &Periodic{ + duration: duration, + work: work, + } +} + +// Start starts the `Periodic` work. +// +// It does not do work immedately, only after the time has passed. +func (p *Periodic) Start() { + if p.cancel != nil { + return + } + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan any) + go runBackgroundLoop(ctx, p.duration, done, p.work) + p.cancel = cancel + p.done = done +} + +func runBackgroundLoop(ctx context.Context, d time.Duration, done chan any, work func(context.Context)) { + refreshTimer := time.NewTicker(d) + defer func() { + refreshTimer.Stop() + close(done) + }() + for ctx.Err() == nil { + select { + case <-refreshTimer.C: + work(ctx) + case <-ctx.Done(): + return + } + } +} + +// Stop stops the periodic work and waits for the background goroutine to exit. +func (p *Periodic) Stop() { + if p.cancel == nil { + return + } + p.cancel() + <-p.done + p.done = nil +} diff --git a/internal/periodic/periodic_test.go b/internal/periodic/periodic_test.go new file mode 100644 index 0000000000..b3a6fb2511 --- /dev/null +++ b/internal/periodic/periodic_test.go @@ -0,0 +1,62 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package periodic + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCancellation(t *testing.T) { + counter := atomic.Int32{} + p := New(time.Hour, func() { + counter.Add(1) + }) + p.Start() + require.Equal(t, int32(0), counter.Load()) + p.Stop() + require.Equal(t, int32(0), counter.Load()) +} + +func TestWorks(t *testing.T) { + counter := atomic.Int32{} + p := New(time.Millisecond, func() { + counter.Add(1) + }) + p.Start() + require.Eventually(t, func() bool { return counter.Load() > 5 }, time.Second, time.Millisecond) + p.Stop() + snapshot := counter.Load() + time.Sleep(time.Millisecond * 250) + require.Equal(t, snapshot, counter.Load()) +} + +func TestWorksWithContext(t *testing.T) { + active := atomic.Bool{} + p := NewWithContext(time.Millisecond, func(ctx context.Context) { + active.Store(true) + // Block until context is cancelled + <-ctx.Done() + active.Store(false) + }) + p.Start() + require.Eventually(t, func() bool { return active.Load() }, 10*time.Millisecond, time.Millisecond) + p.Stop() + require.False(t, active.Load()) +} diff --git a/internal/plugins/info.csv b/internal/plugins/info.csv index 33f6289696..025df1fe28 100644 --- a/internal/plugins/info.csv +++ b/internal/plugins/info.csv @@ -218,6 +218,7 @@ sftp ,output ,sftp ,3.39.0 ,certif skip_bom ,scanner ,skip_bom ,0.0.0 ,certified ,n ,y ,y sleep ,processor ,sleep ,0.0.0 ,certified ,n ,y ,y snowflake_put ,output ,Snowflake ,4.0.0 ,enterprise ,n ,y ,y +snowflake_streaming ,output ,Snowflake Streaming ,4.39.0 ,enterprise ,n ,y ,y socket ,input ,Socket ,0.0.0 ,certified ,n ,n ,n socket ,output ,Socket ,0.0.0 ,certified ,n ,n ,n socket_server ,input ,socket_server ,0.0.0 ,certified ,n ,n ,n diff --git a/internal/typed/atomic_value.go b/internal/typed/atomic_value.go new file mode 100644 index 0000000000..2dea69cca6 --- /dev/null +++ b/internal/typed/atomic_value.go @@ -0,0 +1,56 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package typed + +import "sync/atomic" + +// AtomicValue is a small type safe generic wrapper over atomic.Value +// +// Must not be copied (use NewAtomicValue). +// +// Who doesn't like generics? +type AtomicValue[T any] struct { + noCopy + val atomic.Value +} + +// NewAtomicValue creates a new AtomicValue holding `v`. +func NewAtomicValue[T any](v T) *AtomicValue[T] { + a := &AtomicValue[T]{} + a.Store(v) + return a +} + +// Load returns the value set by the latest store. +func (a *AtomicValue[T]) Load() T { + // This dereference is safe because we only create these with values + return *a.val.Load().(*T) +} + +// Store sets the value of the atomic to `v`. +func (a *AtomicValue[T]) Store(v T) { + a.val.Store(&v) +} + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. 
+func (*noCopy) Lock()   {}
+func (*noCopy) Unlock() {}
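
The `typedBuffer` contract in `userdata_converter.go` (call `Prepare`, then one write per row) stores values in row-major order: each column writer owns one slot per row in a shared flat matrix. A minimal standalone sketch of that indexing, using plain strings instead of `parquet.Value` so it runs without the parquet or int128 packages:

```go
package main

import "fmt"

// Mirrors typedBufferImpl.WriteValue: row r of column c in a table that is
// w columns wide lands at matrix[r*w+c], so all column writers share one
// pre-allocated flat slice.
func main() {
	const rows, cols = 3, 2
	matrix := make([]string, rows*cols)

	writeColumn := func(columnIndex int, values []string) {
		for row, v := range values {
			matrix[row*cols+columnIndex] = v
		}
	}
	writeColumn(0, []string{"a0", "a1", "a2"}) // column 0
	writeColumn(1, []string{"b0", "b1", "b2"}) // column 1

	fmt.Println(matrix) // [a0 b0 a1 b1 a2 b2]
}
```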
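`timeConverter` encodes TIME values as nanoseconds since midnight divided by `10^(9-scale)` (via `pow10TableInt64`). As a worked example of that arithmetic only, not code lifted from the converter, the `13:02:06.1234Z` scale-1 test case reduces like this:

```go
package main

import "fmt"

func main() {
	// 13:02:06.1234 since midnight, expressed in nanoseconds.
	nanos := int64(13)*3_600_000_000_000 +
		int64(2)*60_000_000_000 +
		int64(6)*1_000_000_000 +
		123_400_000

	// Scale 1 keeps one fractional digit: divide by 10^(9-1).
	scale := int64(1)
	pow := int64(1)
	for i := int64(0); i < 9-scale; i++ {
		pow *= 10
	}
	fmt.Println(nanos / pow) // 469261, matching the scale-1 test case above
}
```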
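The new `periodic` and `typed` packages are general-purpose helpers rather than Snowflake-specific code. A rough usage sketch under an assumed scenario (a periodically refreshed cached string); both packages sit under `internal/`, so this only compiles from inside the connect module:

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/redpanda-data/connect/v4/internal/periodic"
	"github.com/redpanda-data/connect/v4/internal/typed"
)

func main() {
	// Readers can Load the latest value at any time without locking.
	latest := typed.NewAtomicValue("initial")

	// Refresh the value in the background; the context is cancelled by Stop,
	// so a slow refresh can bail out instead of blocking shutdown.
	refresher := periodic.NewWithContext(50*time.Millisecond, func(ctx context.Context) {
		if ctx.Err() != nil {
			return
		}
		latest.Store(time.Now().Format(time.RFC3339Nano))
	})
	refresher.Start()
	defer refresher.Stop()

	time.Sleep(200 * time.Millisecond)
	fmt.Println("latest refresh:", latest.Load())
}
```

Note that `AtomicValue` stores a `*T` in the underlying `atomic.Value`, which keeps the stored concrete type stable across `Store` calls and makes the `Load` dereference safe.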