// Provenance (GitHub Gist page header, converted to MLIR comments so the file parses):
// Gist by @GleasonK, created September 17, 2024 21:23
// https://gist.github.com/GleasonK/cf68c91196bc6beb4017d1a2e53ef8bf
// "Instantly share code, notes, and snippets."
// Save GleasonK/cf68c91196bc6beb4017d1a2e53ef8bf to your computer and use it in GitHub Desktop.
// This is a dump from PyTorch/XLA of the following LLAMA2 model file:
// https://github.com/pytorch/xla/blob/master/test/stablehlo/llama_model2.py
//
module @IrToHlo.7509 attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
func.func @main(%arg0: tensor<32000x4096xf32>, %arg1: tensor<4096xf32>, %arg2: tensor<4096x11008xf32>, %arg3: tensor<11008x4096xf32>, %arg4: tensor<4096xf32>, %arg5: tensor<4096x4096xf32>, %arg6: tensor<4096x4096xf32>, %arg7: tensor<4096xf32>, %arg8: tensor<4096x11008xf32>, %arg9: tensor<11008x4096xf32>, %arg10: tensor<4096xf32>, %arg11: tensor<4096x4096xf32>, %arg12: tensor<4096x4096xf32>, %arg13: tensor<4096xf32>, %arg14: tensor<4096x11008xf32>, %arg15: tensor<11008x4096xf32>, %arg16: tensor<4096xf32>, %arg17: tensor<4096x4096xf32>, %arg18: tensor<4096x4096xf32>, %arg19: tensor<4096xf32>, %arg20: tensor<4096x11008xf32>, %arg21: tensor<11008x4096xf32>, %arg22: tensor<4096xf32>, %arg23: tensor<4096x4096xf32>, %arg24: tensor<4096x4096xf32>, %arg25: tensor<4096xf32>, %arg26: tensor<4096x11008xf32>, %arg27: tensor<11008x4096xf32>, %arg28: tensor<4096xf32>, %arg29: tensor<4096x4096xf32>, %arg30: tensor<4096x4096xf32>, %arg31: tensor<4096xf32>, %arg32: tensor<4096x11008xf32>, %arg33: tensor<11008x4096xf32>, %arg34: tensor<4096xf32>, %arg35: tensor<4096x4096xf32>, %arg36: tensor<4096x4096xf32>, %arg37: tensor<4096xf32>, %arg38: tensor<4096x11008xf32>, %arg39: tensor<11008x4096xf32>, %arg40: tensor<4096xf32>, %arg41: tensor<4096x4096xf32>, %arg42: tensor<4096x4096xf32>, %arg43: tensor<4096xf32>, %arg44: tensor<4096x11008xf32>, %arg45: tensor<11008x4096xf32>, %arg46: tensor<4096xf32>, %arg47: tensor<4096x4096xf32>, %arg48: tensor<4096x4096xf32>, %arg49: tensor<4096xf32>, %arg50: tensor<4096x11008xf32>, %arg51: tensor<11008x4096xf32>, %arg52: tensor<4096xf32>, %arg53: tensor<4096x4096xf32>, %arg54: tensor<4096x4096xf32>, %arg55: tensor<4096xf32>, %arg56: tensor<4096x11008xf32>, %arg57: tensor<11008x4096xf32>, %arg58: tensor<4096xf32>, %arg59: tensor<4096x4096xf32>, %arg60: tensor<4096x4096xf32>, %arg61: tensor<4096xf32>, %arg62: tensor<4096x11008xf32>, %arg63: tensor<11008x4096xf32>, %arg64: tensor<4096xf32>, %arg65: tensor<4096x4096xf32>, %arg66: tensor<4096x4096xf32>, 
%arg67: tensor<4096xf32>, %arg68: tensor<4096x11008xf32>, %arg69: tensor<11008x4096xf32>, %arg70: tensor<4096xf32>, %arg71: tensor<4096x4096xf32>, %arg72: tensor<4096x4096xf32>, %arg73: tensor<4096xf32>, %arg74: tensor<4096x11008xf32>, %arg75: tensor<11008x4096xf32>, %arg76: tensor<4096xf32>, %arg77: tensor<4096x4096xf32>, %arg78: tensor<4096x4096xf32>, %arg79: tensor<4096xf32>, %arg80: tensor<4096x11008xf32>, %arg81: tensor<11008x4096xf32>, %arg82: tensor<4096xf32>, %arg83: tensor<4096x4096xf32>, %arg84: tensor<4096x4096xf32>, %arg85: tensor<4096xf32>, %arg86: tensor<4096x11008xf32>, %arg87: tensor<11008x4096xf32>, %arg88: tensor<4096xf32>, %arg89: tensor<4096x4096xf32>, %arg90: tensor<4096x4096xf32>, %arg91: tensor<4096xf32>, %arg92: tensor<4096x11008xf32>, %arg93: tensor<11008x4096xf32>, %arg94: tensor<4096xf32>, %arg95: tensor<4096x4096xf32>, %arg96: tensor<4096x4096xf32>, %arg97: tensor<4096xf32>, %arg98: tensor<4096x11008xf32>, %arg99: tensor<11008x4096xf32>, %arg100: tensor<4096xf32>, %arg101: tensor<4096x4096xf32>, %arg102: tensor<4096x4096xf32>, %arg103: tensor<4096xf32>, %arg104: tensor<4096x11008xf32>, %arg105: tensor<11008x4096xf32>, %arg106: tensor<4096xf32>, %arg107: tensor<4096x4096xf32>, %arg108: tensor<4096x4096xf32>, %arg109: tensor<4096xf32>, %arg110: tensor<4096x11008xf32>, %arg111: tensor<11008x4096xf32>, %arg112: tensor<4096xf32>, %arg113: tensor<4096x4096xf32>, %arg114: tensor<4096x4096xf32>, %arg115: tensor<4096xf32>, %arg116: tensor<4096x11008xf32>, %arg117: tensor<11008x4096xf32>, %arg118: tensor<4096xf32>, %arg119: tensor<4096x4096xf32>, %arg120: tensor<4096x4096xf32>, %arg121: tensor<4096xf32>, %arg122: tensor<4096x11008xf32>, %arg123: tensor<11008x4096xf32>, %arg124: tensor<4096xf32>, %arg125: tensor<4096x4096xf32>, %arg126: tensor<4096x4096xf32>, %arg127: tensor<4096xf32>, %arg128: tensor<4096x11008xf32>, %arg129: tensor<11008x4096xf32>, %arg130: tensor<4096xf32>, %arg131: tensor<4096x4096xf32>, %arg132: tensor<4096x4096xf32>, %arg133: 
tensor<4096xf32>, %arg134: tensor<4096x11008xf32>, %arg135: tensor<11008x4096xf32>, %arg136: tensor<4096xf32>, %arg137: tensor<4096x4096xf32>, %arg138: tensor<4096x4096xf32>, %arg139: tensor<4096xf32>, %arg140: tensor<4096x11008xf32>, %arg141: tensor<11008x4096xf32>, %arg142: tensor<4096xf32>, %arg143: tensor<4096x4096xf32>, %arg144: tensor<4096x4096xf32>, %arg145: tensor<4096xf32>, %arg146: tensor<4096x11008xf32>, %arg147: tensor<11008x4096xf32>, %arg148: tensor<4096xf32>, %arg149: tensor<4096x4096xf32>, %arg150: tensor<4096x4096xf32>, %arg151: tensor<4096xf32>, %arg152: tensor<4096x11008xf32>, %arg153: tensor<11008x4096xf32>, %arg154: tensor<4096xf32>, %arg155: tensor<4096x4096xf32>, %arg156: tensor<4096x4096xf32>, %arg157: tensor<4096xf32>, %arg158: tensor<4096x11008xf32>, %arg159: tensor<11008x4096xf32>, %arg160: tensor<4096xf32>, %arg161: tensor<4096x4096xf32>, %arg162: tensor<4096x4096xf32>, %arg163: tensor<4096xf32>, %arg164: tensor<4096x11008xf32>, %arg165: tensor<11008x4096xf32>, %arg166: tensor<4096xf32>, %arg167: tensor<4096x4096xf32>, %arg168: tensor<4096x4096xf32>, %arg169: tensor<4096xf32>, %arg170: tensor<4096x11008xf32>, %arg171: tensor<11008x4096xf32>, %arg172: tensor<4096xf32>, %arg173: tensor<4096x4096xf32>, %arg174: tensor<4096x4096xf32>, %arg175: tensor<4096xf32>, %arg176: tensor<4096x11008xf32>, %arg177: tensor<11008x4096xf32>, %arg178: tensor<4096xf32>, %arg179: tensor<4096x4096xf32>, %arg180: tensor<4096x4096xf32>, %arg181: tensor<4096xf32>, %arg182: tensor<4096x11008xf32>, %arg183: tensor<11008x4096xf32>, %arg184: tensor<4096xf32>, %arg185: tensor<4096x4096xf32>, %arg186: tensor<4096x4096xf32>, %arg187: tensor<4096xf32>, %arg188: tensor<4096x11008xf32>, %arg189: tensor<11008x4096xf32>, %arg190: tensor<4096xf32>, %arg191: tensor<4096x4096xf32>, %arg192: tensor<4096x4096xf32>, %arg193: tensor<4096xf32>, %arg194: tensor<8x100xi64>, %arg195: tensor<32000x4096xf32>, %arg196: tensor<100xi64>, %arg197: tensor<8x1024x32x128xf32>, %arg198: 
tensor<1x1x1024x1024xf32>, %arg199: tensor<2048x64xcomplex<f32>>, %arg200: tensor<4096x4096xf32>, %arg201: tensor<8x1024x32x128xf32>, %arg202: tensor<4096x4096xf32>, %arg203: tensor<11008x4096xf32>, %arg204: tensor<8x1024x32x128xf32>, %arg205: tensor<4096x4096xf32>, %arg206: tensor<8x1024x32x128xf32>, %arg207: tensor<4096x4096xf32>, %arg208: tensor<11008x4096xf32>, %arg209: tensor<8x1024x32x128xf32>, %arg210: tensor<4096x4096xf32>, %arg211: tensor<8x1024x32x128xf32>, %arg212: tensor<4096x4096xf32>, %arg213: tensor<11008x4096xf32>, %arg214: tensor<8x1024x32x128xf32>, %arg215: tensor<4096x4096xf32>, %arg216: tensor<8x1024x32x128xf32>, %arg217: tensor<4096x4096xf32>, %arg218: tensor<11008x4096xf32>, %arg219: tensor<8x1024x32x128xf32>, %arg220: tensor<4096x4096xf32>, %arg221: tensor<8x1024x32x128xf32>, %arg222: tensor<4096x4096xf32>, %arg223: tensor<11008x4096xf32>, %arg224: tensor<8x1024x32x128xf32>, %arg225: tensor<4096x4096xf32>, %arg226: tensor<8x1024x32x128xf32>, %arg227: tensor<4096x4096xf32>, %arg228: tensor<11008x4096xf32>, %arg229: tensor<8x1024x32x128xf32>, %arg230: tensor<4096x4096xf32>, %arg231: tensor<8x1024x32x128xf32>, %arg232: tensor<4096x4096xf32>, %arg233: tensor<11008x4096xf32>, %arg234: tensor<8x1024x32x128xf32>, %arg235: tensor<4096x4096xf32>, %arg236: tensor<8x1024x32x128xf32>, %arg237: tensor<4096x4096xf32>, %arg238: tensor<11008x4096xf32>, %arg239: tensor<8x1024x32x128xf32>, %arg240: tensor<4096x4096xf32>, %arg241: tensor<8x1024x32x128xf32>, %arg242: tensor<4096x4096xf32>, %arg243: tensor<11008x4096xf32>, %arg244: tensor<8x1024x32x128xf32>, %arg245: tensor<4096x4096xf32>, %arg246: tensor<8x1024x32x128xf32>, %arg247: tensor<4096x4096xf32>, %arg248: tensor<11008x4096xf32>, %arg249: tensor<8x1024x32x128xf32>, %arg250: tensor<4096x4096xf32>, %arg251: tensor<8x1024x32x128xf32>, %arg252: tensor<4096x4096xf32>, %arg253: tensor<11008x4096xf32>, %arg254: tensor<8x1024x32x128xf32>, %arg255: tensor<4096x4096xf32>, %arg256: tensor<8x1024x32x128xf32>, 
%arg257: tensor<4096x4096xf32>, %arg258: tensor<11008x4096xf32>, %arg259: tensor<8x1024x32x128xf32>, %arg260: tensor<4096x4096xf32>, %arg261: tensor<8x1024x32x128xf32>, %arg262: tensor<4096x4096xf32>, %arg263: tensor<11008x4096xf32>, %arg264: tensor<8x1024x32x128xf32>, %arg265: tensor<4096x4096xf32>, %arg266: tensor<8x1024x32x128xf32>, %arg267: tensor<4096x4096xf32>, %arg268: tensor<11008x4096xf32>, %arg269: tensor<8x1024x32x128xf32>, %arg270: tensor<4096x4096xf32>, %arg271: tensor<8x1024x32x128xf32>, %arg272: tensor<4096x4096xf32>, %arg273: tensor<11008x4096xf32>, %arg274: tensor<8x1024x32x128xf32>, %arg275: tensor<4096x4096xf32>, %arg276: tensor<8x1024x32x128xf32>, %arg277: tensor<4096x4096xf32>, %arg278: tensor<11008x4096xf32>, %arg279: tensor<8x1024x32x128xf32>, %arg280: tensor<4096x4096xf32>, %arg281: tensor<8x1024x32x128xf32>, %arg282: tensor<4096x4096xf32>, %arg283: tensor<11008x4096xf32>, %arg284: tensor<8x1024x32x128xf32>, %arg285: tensor<4096x4096xf32>, %arg286: tensor<8x1024x32x128xf32>, %arg287: tensor<4096x4096xf32>, %arg288: tensor<11008x4096xf32>, %arg289: tensor<8x1024x32x128xf32>, %arg290: tensor<4096x4096xf32>, %arg291: tensor<8x1024x32x128xf32>, %arg292: tensor<4096x4096xf32>, %arg293: tensor<11008x4096xf32>, %arg294: tensor<8x1024x32x128xf32>, %arg295: tensor<4096x4096xf32>, %arg296: tensor<8x1024x32x128xf32>, %arg297: tensor<4096x4096xf32>, %arg298: tensor<11008x4096xf32>, %arg299: tensor<8x1024x32x128xf32>, %arg300: tensor<4096x4096xf32>, %arg301: tensor<8x1024x32x128xf32>, %arg302: tensor<4096x4096xf32>, %arg303: tensor<11008x4096xf32>, %arg304: tensor<8x1024x32x128xf32>, %arg305: tensor<4096x4096xf32>, %arg306: tensor<8x1024x32x128xf32>, %arg307: tensor<4096x4096xf32>, %arg308: tensor<11008x4096xf32>, %arg309: tensor<8x1024x32x128xf32>, %arg310: tensor<4096x4096xf32>, %arg311: tensor<8x1024x32x128xf32>, %arg312: tensor<4096x4096xf32>, %arg313: tensor<11008x4096xf32>, %arg314: tensor<8x1024x32x128xf32>, %arg315: tensor<4096x4096xf32>, 
%arg316: tensor<8x1024x32x128xf32>, %arg317: tensor<4096x4096xf32>, %arg318: tensor<11008x4096xf32>, %arg319: tensor<8x1024x32x128xf32>, %arg320: tensor<4096x4096xf32>, %arg321: tensor<8x1024x32x128xf32>, %arg322: tensor<4096x4096xf32>, %arg323: tensor<11008x4096xf32>, %arg324: tensor<8x1024x32x128xf32>, %arg325: tensor<4096x4096xf32>, %arg326: tensor<8x1024x32x128xf32>, %arg327: tensor<4096x4096xf32>, %arg328: tensor<11008x4096xf32>, %arg329: tensor<8x1024x32x128xf32>, %arg330: tensor<4096x4096xf32>, %arg331: tensor<8x1024x32x128xf32>, %arg332: tensor<4096x4096xf32>, %arg333: tensor<11008x4096xf32>, %arg334: tensor<8x1024x32x128xf32>, %arg335: tensor<4096x4096xf32>, %arg336: tensor<8x1024x32x128xf32>, %arg337: tensor<4096x4096xf32>, %arg338: tensor<11008x4096xf32>, %arg339: tensor<8x1024x32x128xf32>, %arg340: tensor<4096x4096xf32>, %arg341: tensor<8x1024x32x128xf32>, %arg342: tensor<4096x4096xf32>, %arg343: tensor<11008x4096xf32>, %arg344: tensor<8x1024x32x128xf32>, %arg345: tensor<4096x4096xf32>, %arg346: tensor<8x1024x32x128xf32>, %arg347: tensor<4096x4096xf32>, %arg348: tensor<11008x4096xf32>, %arg349: tensor<8x1024x32x128xf32>, %arg350: tensor<4096x4096xf32>, %arg351: tensor<8x1024x32x128xf32>, %arg352: tensor<4096x4096xf32>, %arg353: tensor<11008x4096xf32>, %arg354: tensor<8x1024x32x128xf32>, %arg355: tensor<4096x4096xf32>, %arg356: tensor<8x1024x32x128xf32>, %arg357: tensor<4096x4096xf32>, %arg358: tensor<11008x4096xf32>) -> tensor<8x100x32000xf32> {
%cst = stablehlo.constant dense<11.3137083> : tensor<8x32x100x1024xf32>
%c = stablehlo.constant dense<1024> : tensor<100xi64>
%c_0 = stablehlo.constant dense<0> : tensor<100xi64>
%cst_1 = stablehlo.constant dense<9.99999974E-6> : tensor<8x100x1xf32>
%cst_2 = stablehlo.constant dense<2.44140625E-4> : tensor<8x100xf32>
%cst_3 = stablehlo.constant dense<2.000000e+00> : tensor<8x100x4096xf32>
%cst_4 = stablehlo.constant dense<0xFF800000> : tensor<f32>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%0 = stablehlo.reshape %arg194 : (tensor<8x100xi64>) -> tensor<800xi64>
%1 = stablehlo.convert %0 : (tensor<800xi64>) -> tensor<800xui32>
%2 = "stablehlo.gather"(%arg195, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 4096>}> : (tensor<32000x4096xf32>, tensor<800xui32>) -> tensor<800x4096xf32>
%3 = stablehlo.reshape %2 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%4 = stablehlo.power %3, %cst_3 : tensor<8x100x4096xf32>
%5 = stablehlo.reduce(%4 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%6 = stablehlo.multiply %5, %cst_2 : tensor<8x100xf32>
%7 = stablehlo.reshape %6 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%8 = stablehlo.add %7, %cst_1 : tensor<8x100x1xf32>
%9 = stablehlo.rsqrt %8 : tensor<8x100x1xf32>
%10 = stablehlo.reshape %9 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%11 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%12 = stablehlo.multiply %3, %11 : tensor<8x100x4096xf32>
%13 = stablehlo.broadcast_in_dim %arg193, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%14 = stablehlo.multiply %12, %13 : tensor<8x100x4096xf32>
%15 = stablehlo.reshape %14 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%16 = stablehlo.transpose %arg202, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%17 = stablehlo.dot_general %15, %16, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%18 = stablehlo.reshape %17 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%19 = stablehlo.transpose %18, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%20 = stablehlo.reshape %19 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%21 = stablehlo.slice %20 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%22 = stablehlo.reshape %21 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%23 = stablehlo.slice %20 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%24 = stablehlo.reshape %23 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%25 = stablehlo.complex %22, %24 : tensor<256x100x64xcomplex<f32>>
%26 = stablehlo.convert %arg196 : (tensor<100xi64>) -> tensor<100xui32>
%27 = "stablehlo.gather"(%arg199, %26) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 64>}> : (tensor<2048x64xcomplex<f32>>, tensor<100xui32>) -> tensor<100x64xcomplex<f32>>
%28 = stablehlo.broadcast_in_dim %27, dims = [1, 2] : (tensor<100x64xcomplex<f32>>) -> tensor<256x100x64xcomplex<f32>>
%29 = stablehlo.multiply %25, %28 : tensor<256x100x64xcomplex<f32>>
%30 = stablehlo.real %29 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%31 = stablehlo.reshape %30 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%32 = stablehlo.imag %29 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%33 = stablehlo.reshape %32 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%34 = stablehlo.concatenate %31, %33, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%35 = stablehlo.reshape %34 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%36 = stablehlo.compare LT, %arg196, %c_0 : (tensor<100xi64>, tensor<100xi64>) -> tensor<100xi1>
%37 = stablehlo.add %arg196, %c : tensor<100xi64>
%38 = stablehlo.select %36, %37, %arg196 : tensor<100xi1>, tensor<100xi64>
%39 = stablehlo.reshape %38 : (tensor<100xi64>) -> tensor<100x1xi64>
%40 = stablehlo.transpose %arg200, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%41 = stablehlo.dot_general %15, %40, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%42 = stablehlo.reshape %41 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%43 = stablehlo.transpose %42, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%44 = stablehlo.reshape %43 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%45 = stablehlo.slice %44 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%46 = stablehlo.reshape %45 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%47 = stablehlo.slice %44 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%48 = stablehlo.reshape %47 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%49 = stablehlo.complex %46, %48 : tensor<256x100x64xcomplex<f32>>
%50 = stablehlo.multiply %49, %28 : tensor<256x100x64xcomplex<f32>>
%51 = stablehlo.real %50 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%52 = stablehlo.reshape %51 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%53 = stablehlo.imag %50 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%54 = stablehlo.reshape %53 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%55 = stablehlo.concatenate %52, %54, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%56 = stablehlo.reshape %55 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%57 = stablehlo.transpose %56, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%58 = "stablehlo.scatter"(%arg201, %39, %57) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%59 = stablehlo.transpose %58, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%60 = stablehlo.reshape %59 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%61 = stablehlo.dot_general %35, %60, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%62 = stablehlo.reshape %61 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%63 = stablehlo.divide %62, %cst : tensor<8x32x100x1024xf32>
%64 = "stablehlo.gather"(%arg198, %26) <{dimension_numbers = #stablehlo.gather<offset_dims = [0, 1, 3], collapsed_slice_dims = [2], start_index_map = [2], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1, 1, 1024>}> : (tensor<1x1x1024x1024xf32>, tensor<100xui32>) -> tensor<1x1x100x1024xf32>
%65 = stablehlo.reshape %64 : (tensor<1x1x100x1024xf32>) -> tensor<100x1024xf32>
%66 = stablehlo.broadcast_in_dim %65, dims = [2, 3] : (tensor<100x1024xf32>) -> tensor<8x32x100x1024xf32>
%67 = stablehlo.add %63, %66 : tensor<8x32x100x1024xf32>
%68 = stablehlo.reduce(%67 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%69 = stablehlo.broadcast_in_dim %68, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%70 = stablehlo.subtract %67, %69 : tensor<8x32x100x1024xf32>
%71 = stablehlo.exponential %70 : tensor<8x32x100x1024xf32>
%72 = stablehlo.reduce(%71 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%73 = stablehlo.broadcast_in_dim %72, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%74 = stablehlo.divide %71, %73 : tensor<8x32x100x1024xf32>
%75 = stablehlo.reshape %74 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%76 = stablehlo.transpose %arg192, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%77 = stablehlo.dot_general %15, %76, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%78 = stablehlo.reshape %77 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%79 = "stablehlo.scatter"(%arg197, %39, %78) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%80 = stablehlo.transpose %79, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%81 = stablehlo.reshape %80 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%82 = stablehlo.dot_general %75, %81, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%83 = stablehlo.reshape %82 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%84 = stablehlo.transpose %83, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%85 = stablehlo.reshape %84 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%86 = stablehlo.transpose %arg191, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%87 = stablehlo.dot_general %85, %86, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%88 = stablehlo.reshape %87 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%89 = stablehlo.add %3, %88 : tensor<8x100x4096xf32>
%90 = stablehlo.power %89, %cst_3 : tensor<8x100x4096xf32>
%91 = stablehlo.reduce(%90 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%92 = stablehlo.multiply %91, %cst_2 : tensor<8x100xf32>
%93 = stablehlo.reshape %92 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%94 = stablehlo.add %93, %cst_1 : tensor<8x100x1xf32>
%95 = stablehlo.rsqrt %94 : tensor<8x100x1xf32>
%96 = stablehlo.reshape %95 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%97 = stablehlo.broadcast_in_dim %96, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%98 = stablehlo.multiply %89, %97 : tensor<8x100x4096xf32>
%99 = stablehlo.broadcast_in_dim %arg190, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%100 = stablehlo.multiply %98, %99 : tensor<8x100x4096xf32>
%101 = stablehlo.reshape %100 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%102 = stablehlo.transpose %arg203, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%103 = stablehlo.dot_general %101, %102, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%104 = stablehlo.reshape %103 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%105 = stablehlo.logistic %104 : tensor<8x100x11008xf32>
%106 = stablehlo.multiply %104, %105 : tensor<8x100x11008xf32>
%107 = stablehlo.transpose %arg189, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%108 = stablehlo.dot_general %101, %107, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%109 = stablehlo.reshape %108 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%110 = stablehlo.multiply %106, %109 : tensor<8x100x11008xf32>
%111 = stablehlo.reshape %110 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%112 = stablehlo.transpose %arg188, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%113 = stablehlo.dot_general %111, %112, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%114 = stablehlo.reshape %113 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%115 = stablehlo.add %89, %114 : tensor<8x100x4096xf32>
%116 = stablehlo.power %115, %cst_3 : tensor<8x100x4096xf32>
%117 = stablehlo.reduce(%116 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%118 = stablehlo.multiply %117, %cst_2 : tensor<8x100xf32>
%119 = stablehlo.reshape %118 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%120 = stablehlo.add %119, %cst_1 : tensor<8x100x1xf32>
%121 = stablehlo.rsqrt %120 : tensor<8x100x1xf32>
%122 = stablehlo.reshape %121 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%123 = stablehlo.broadcast_in_dim %122, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%124 = stablehlo.multiply %115, %123 : tensor<8x100x4096xf32>
%125 = stablehlo.broadcast_in_dim %arg187, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%126 = stablehlo.multiply %124, %125 : tensor<8x100x4096xf32>
%127 = stablehlo.reshape %126 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%128 = stablehlo.transpose %arg207, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%129 = stablehlo.dot_general %127, %128, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%130 = stablehlo.reshape %129 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%131 = stablehlo.transpose %130, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%132 = stablehlo.reshape %131 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%133 = stablehlo.slice %132 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%134 = stablehlo.reshape %133 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%135 = stablehlo.slice %132 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%136 = stablehlo.reshape %135 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%137 = stablehlo.complex %134, %136 : tensor<256x100x64xcomplex<f32>>
%138 = stablehlo.multiply %137, %28 : tensor<256x100x64xcomplex<f32>>
%139 = stablehlo.real %138 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%140 = stablehlo.reshape %139 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%141 = stablehlo.imag %138 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%142 = stablehlo.reshape %141 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%143 = stablehlo.concatenate %140, %142, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%144 = stablehlo.reshape %143 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%145 = stablehlo.transpose %arg205, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%146 = stablehlo.dot_general %127, %145, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%147 = stablehlo.reshape %146 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%148 = stablehlo.transpose %147, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%149 = stablehlo.reshape %148 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%150 = stablehlo.slice %149 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%151 = stablehlo.reshape %150 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%152 = stablehlo.slice %149 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%153 = stablehlo.reshape %152 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%154 = stablehlo.complex %151, %153 : tensor<256x100x64xcomplex<f32>>
%155 = stablehlo.multiply %154, %28 : tensor<256x100x64xcomplex<f32>>
%156 = stablehlo.real %155 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%157 = stablehlo.reshape %156 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%158 = stablehlo.imag %155 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%159 = stablehlo.reshape %158 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%160 = stablehlo.concatenate %157, %159, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%161 = stablehlo.reshape %160 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%162 = stablehlo.transpose %161, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%163 = "stablehlo.scatter"(%arg206, %39, %162) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%164 = stablehlo.transpose %163, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%165 = stablehlo.reshape %164 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%166 = stablehlo.dot_general %144, %165, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%167 = stablehlo.reshape %166 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%168 = stablehlo.divide %167, %cst : tensor<8x32x100x1024xf32>
%169 = stablehlo.add %168, %66 : tensor<8x32x100x1024xf32>
%170 = stablehlo.reduce(%169 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%171 = stablehlo.broadcast_in_dim %170, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%172 = stablehlo.subtract %169, %171 : tensor<8x32x100x1024xf32>
%173 = stablehlo.exponential %172 : tensor<8x32x100x1024xf32>
%174 = stablehlo.reduce(%173 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%175 = stablehlo.broadcast_in_dim %174, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%176 = stablehlo.divide %173, %175 : tensor<8x32x100x1024xf32>
%177 = stablehlo.reshape %176 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%178 = stablehlo.transpose %arg186, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%179 = stablehlo.dot_general %127, %178, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%180 = stablehlo.reshape %179 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%181 = "stablehlo.scatter"(%arg204, %39, %180) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%182 = stablehlo.transpose %181, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%183 = stablehlo.reshape %182 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%184 = stablehlo.dot_general %177, %183, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%185 = stablehlo.reshape %184 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%186 = stablehlo.transpose %185, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%187 = stablehlo.reshape %186 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%188 = stablehlo.transpose %arg185, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%189 = stablehlo.dot_general %187, %188, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%190 = stablehlo.reshape %189 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%191 = stablehlo.add %115, %190 : tensor<8x100x4096xf32>
%192 = stablehlo.power %191, %cst_3 : tensor<8x100x4096xf32>
%193 = stablehlo.reduce(%192 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%194 = stablehlo.multiply %193, %cst_2 : tensor<8x100xf32>
%195 = stablehlo.reshape %194 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%196 = stablehlo.add %195, %cst_1 : tensor<8x100x1xf32>
%197 = stablehlo.rsqrt %196 : tensor<8x100x1xf32>
%198 = stablehlo.reshape %197 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%199 = stablehlo.broadcast_in_dim %198, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%200 = stablehlo.multiply %191, %199 : tensor<8x100x4096xf32>
%201 = stablehlo.broadcast_in_dim %arg184, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%202 = stablehlo.multiply %200, %201 : tensor<8x100x4096xf32>
%203 = stablehlo.reshape %202 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%204 = stablehlo.transpose %arg208, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%205 = stablehlo.dot_general %203, %204, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%206 = stablehlo.reshape %205 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%207 = stablehlo.logistic %206 : tensor<8x100x11008xf32>
%208 = stablehlo.multiply %206, %207 : tensor<8x100x11008xf32>
%209 = stablehlo.transpose %arg183, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%210 = stablehlo.dot_general %203, %209, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%211 = stablehlo.reshape %210 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%212 = stablehlo.multiply %208, %211 : tensor<8x100x11008xf32>
%213 = stablehlo.reshape %212 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%214 = stablehlo.transpose %arg182, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%215 = stablehlo.dot_general %213, %214, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%216 = stablehlo.reshape %215 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%217 = stablehlo.add %191, %216 : tensor<8x100x4096xf32>
%218 = stablehlo.power %217, %cst_3 : tensor<8x100x4096xf32>
%219 = stablehlo.reduce(%218 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%220 = stablehlo.multiply %219, %cst_2 : tensor<8x100xf32>
%221 = stablehlo.reshape %220 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%222 = stablehlo.add %221, %cst_1 : tensor<8x100x1xf32>
%223 = stablehlo.rsqrt %222 : tensor<8x100x1xf32>
%224 = stablehlo.reshape %223 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%225 = stablehlo.broadcast_in_dim %224, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%226 = stablehlo.multiply %217, %225 : tensor<8x100x4096xf32>
%227 = stablehlo.broadcast_in_dim %arg181, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%228 = stablehlo.multiply %226, %227 : tensor<8x100x4096xf32>
%229 = stablehlo.reshape %228 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%230 = stablehlo.transpose %arg212, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%231 = stablehlo.dot_general %229, %230, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%232 = stablehlo.reshape %231 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%233 = stablehlo.transpose %232, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%234 = stablehlo.reshape %233 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%235 = stablehlo.slice %234 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%236 = stablehlo.reshape %235 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%237 = stablehlo.slice %234 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%238 = stablehlo.reshape %237 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%239 = stablehlo.complex %236, %238 : tensor<256x100x64xcomplex<f32>>
%240 = stablehlo.multiply %239, %28 : tensor<256x100x64xcomplex<f32>>
%241 = stablehlo.real %240 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%242 = stablehlo.reshape %241 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%243 = stablehlo.imag %240 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%244 = stablehlo.reshape %243 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%245 = stablehlo.concatenate %242, %244, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%246 = stablehlo.reshape %245 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%247 = stablehlo.transpose %arg210, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%248 = stablehlo.dot_general %229, %247, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%249 = stablehlo.reshape %248 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%250 = stablehlo.transpose %249, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%251 = stablehlo.reshape %250 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%252 = stablehlo.slice %251 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%253 = stablehlo.reshape %252 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%254 = stablehlo.slice %251 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%255 = stablehlo.reshape %254 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%256 = stablehlo.complex %253, %255 : tensor<256x100x64xcomplex<f32>>
%257 = stablehlo.multiply %256, %28 : tensor<256x100x64xcomplex<f32>>
%258 = stablehlo.real %257 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%259 = stablehlo.reshape %258 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%260 = stablehlo.imag %257 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%261 = stablehlo.reshape %260 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%262 = stablehlo.concatenate %259, %261, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%263 = stablehlo.reshape %262 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%264 = stablehlo.transpose %263, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%265 = "stablehlo.scatter"(%arg211, %39, %264) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%266 = stablehlo.transpose %265, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%267 = stablehlo.reshape %266 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%268 = stablehlo.dot_general %246, %267, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%269 = stablehlo.reshape %268 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%270 = stablehlo.divide %269, %cst : tensor<8x32x100x1024xf32>
%271 = stablehlo.add %270, %66 : tensor<8x32x100x1024xf32>
%272 = stablehlo.reduce(%271 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%273 = stablehlo.broadcast_in_dim %272, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%274 = stablehlo.subtract %271, %273 : tensor<8x32x100x1024xf32>
%275 = stablehlo.exponential %274 : tensor<8x32x100x1024xf32>
%276 = stablehlo.reduce(%275 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%277 = stablehlo.broadcast_in_dim %276, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%278 = stablehlo.divide %275, %277 : tensor<8x32x100x1024xf32>
%279 = stablehlo.reshape %278 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%280 = stablehlo.transpose %arg180, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%281 = stablehlo.dot_general %229, %280, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%282 = stablehlo.reshape %281 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%283 = "stablehlo.scatter"(%arg209, %39, %282) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%284 = stablehlo.transpose %283, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%285 = stablehlo.reshape %284 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%286 = stablehlo.dot_general %279, %285, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%287 = stablehlo.reshape %286 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%288 = stablehlo.transpose %287, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%289 = stablehlo.reshape %288 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%290 = stablehlo.transpose %arg179, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%291 = stablehlo.dot_general %289, %290, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%292 = stablehlo.reshape %291 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%293 = stablehlo.add %217, %292 : tensor<8x100x4096xf32>
%294 = stablehlo.power %293, %cst_3 : tensor<8x100x4096xf32>
%295 = stablehlo.reduce(%294 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%296 = stablehlo.multiply %295, %cst_2 : tensor<8x100xf32>
%297 = stablehlo.reshape %296 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%298 = stablehlo.add %297, %cst_1 : tensor<8x100x1xf32>
%299 = stablehlo.rsqrt %298 : tensor<8x100x1xf32>
%300 = stablehlo.reshape %299 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%301 = stablehlo.broadcast_in_dim %300, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%302 = stablehlo.multiply %293, %301 : tensor<8x100x4096xf32>
%303 = stablehlo.broadcast_in_dim %arg178, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%304 = stablehlo.multiply %302, %303 : tensor<8x100x4096xf32>
%305 = stablehlo.reshape %304 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%306 = stablehlo.transpose %arg213, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%307 = stablehlo.dot_general %305, %306, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%308 = stablehlo.reshape %307 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%309 = stablehlo.logistic %308 : tensor<8x100x11008xf32>
%310 = stablehlo.multiply %308, %309 : tensor<8x100x11008xf32>
%311 = stablehlo.transpose %arg177, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%312 = stablehlo.dot_general %305, %311, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%313 = stablehlo.reshape %312 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%314 = stablehlo.multiply %310, %313 : tensor<8x100x11008xf32>
%315 = stablehlo.reshape %314 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%316 = stablehlo.transpose %arg176, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%317 = stablehlo.dot_general %315, %316, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%318 = stablehlo.reshape %317 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%319 = stablehlo.add %293, %318 : tensor<8x100x4096xf32>
%320 = stablehlo.power %319, %cst_3 : tensor<8x100x4096xf32>
%321 = stablehlo.reduce(%320 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%322 = stablehlo.multiply %321, %cst_2 : tensor<8x100xf32>
%323 = stablehlo.reshape %322 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%324 = stablehlo.add %323, %cst_1 : tensor<8x100x1xf32>
%325 = stablehlo.rsqrt %324 : tensor<8x100x1xf32>
%326 = stablehlo.reshape %325 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%327 = stablehlo.broadcast_in_dim %326, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%328 = stablehlo.multiply %319, %327 : tensor<8x100x4096xf32>
%329 = stablehlo.broadcast_in_dim %arg175, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%330 = stablehlo.multiply %328, %329 : tensor<8x100x4096xf32>
%331 = stablehlo.reshape %330 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%332 = stablehlo.transpose %arg217, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%333 = stablehlo.dot_general %331, %332, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%334 = stablehlo.reshape %333 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%335 = stablehlo.transpose %334, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%336 = stablehlo.reshape %335 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%337 = stablehlo.slice %336 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%338 = stablehlo.reshape %337 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%339 = stablehlo.slice %336 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%340 = stablehlo.reshape %339 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%341 = stablehlo.complex %338, %340 : tensor<256x100x64xcomplex<f32>>
%342 = stablehlo.multiply %341, %28 : tensor<256x100x64xcomplex<f32>>
%343 = stablehlo.real %342 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%344 = stablehlo.reshape %343 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%345 = stablehlo.imag %342 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%346 = stablehlo.reshape %345 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%347 = stablehlo.concatenate %344, %346, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%348 = stablehlo.reshape %347 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%349 = stablehlo.transpose %arg215, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%350 = stablehlo.dot_general %331, %349, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%351 = stablehlo.reshape %350 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%352 = stablehlo.transpose %351, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%353 = stablehlo.reshape %352 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%354 = stablehlo.slice %353 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%355 = stablehlo.reshape %354 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%356 = stablehlo.slice %353 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%357 = stablehlo.reshape %356 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%358 = stablehlo.complex %355, %357 : tensor<256x100x64xcomplex<f32>>
%359 = stablehlo.multiply %358, %28 : tensor<256x100x64xcomplex<f32>>
%360 = stablehlo.real %359 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%361 = stablehlo.reshape %360 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%362 = stablehlo.imag %359 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%363 = stablehlo.reshape %362 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%364 = stablehlo.concatenate %361, %363, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%365 = stablehlo.reshape %364 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%366 = stablehlo.transpose %365, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%367 = "stablehlo.scatter"(%arg216, %39, %366) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%368 = stablehlo.transpose %367, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%369 = stablehlo.reshape %368 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%370 = stablehlo.dot_general %348, %369, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%371 = stablehlo.reshape %370 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%372 = stablehlo.divide %371, %cst : tensor<8x32x100x1024xf32>
%373 = stablehlo.add %372, %66 : tensor<8x32x100x1024xf32>
%374 = stablehlo.reduce(%373 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%375 = stablehlo.broadcast_in_dim %374, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%376 = stablehlo.subtract %373, %375 : tensor<8x32x100x1024xf32>
%377 = stablehlo.exponential %376 : tensor<8x32x100x1024xf32>
%378 = stablehlo.reduce(%377 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%379 = stablehlo.broadcast_in_dim %378, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%380 = stablehlo.divide %377, %379 : tensor<8x32x100x1024xf32>
%381 = stablehlo.reshape %380 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%382 = stablehlo.transpose %arg174, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%383 = stablehlo.dot_general %331, %382, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%384 = stablehlo.reshape %383 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%385 = "stablehlo.scatter"(%arg214, %39, %384) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%386 = stablehlo.transpose %385, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%387 = stablehlo.reshape %386 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%388 = stablehlo.dot_general %381, %387, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%389 = stablehlo.reshape %388 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%390 = stablehlo.transpose %389, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%391 = stablehlo.reshape %390 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%392 = stablehlo.transpose %arg173, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%393 = stablehlo.dot_general %391, %392, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%394 = stablehlo.reshape %393 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%395 = stablehlo.add %319, %394 : tensor<8x100x4096xf32>
%396 = stablehlo.power %395, %cst_3 : tensor<8x100x4096xf32>
%397 = stablehlo.reduce(%396 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%398 = stablehlo.multiply %397, %cst_2 : tensor<8x100xf32>
%399 = stablehlo.reshape %398 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%400 = stablehlo.add %399, %cst_1 : tensor<8x100x1xf32>
%401 = stablehlo.rsqrt %400 : tensor<8x100x1xf32>
%402 = stablehlo.reshape %401 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%403 = stablehlo.broadcast_in_dim %402, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%404 = stablehlo.multiply %395, %403 : tensor<8x100x4096xf32>
%405 = stablehlo.broadcast_in_dim %arg172, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%406 = stablehlo.multiply %404, %405 : tensor<8x100x4096xf32>
%407 = stablehlo.reshape %406 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%408 = stablehlo.transpose %arg218, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%409 = stablehlo.dot_general %407, %408, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%410 = stablehlo.reshape %409 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%411 = stablehlo.logistic %410 : tensor<8x100x11008xf32>
%412 = stablehlo.multiply %410, %411 : tensor<8x100x11008xf32>
%413 = stablehlo.transpose %arg171, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%414 = stablehlo.dot_general %407, %413, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%415 = stablehlo.reshape %414 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%416 = stablehlo.multiply %412, %415 : tensor<8x100x11008xf32>
%417 = stablehlo.reshape %416 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%418 = stablehlo.transpose %arg170, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%419 = stablehlo.dot_general %417, %418, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%420 = stablehlo.reshape %419 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%421 = stablehlo.add %395, %420 : tensor<8x100x4096xf32>
%422 = stablehlo.power %421, %cst_3 : tensor<8x100x4096xf32>
%423 = stablehlo.reduce(%422 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%424 = stablehlo.multiply %423, %cst_2 : tensor<8x100xf32>
%425 = stablehlo.reshape %424 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%426 = stablehlo.add %425, %cst_1 : tensor<8x100x1xf32>
%427 = stablehlo.rsqrt %426 : tensor<8x100x1xf32>
%428 = stablehlo.reshape %427 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%429 = stablehlo.broadcast_in_dim %428, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%430 = stablehlo.multiply %421, %429 : tensor<8x100x4096xf32>
%431 = stablehlo.broadcast_in_dim %arg169, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%432 = stablehlo.multiply %430, %431 : tensor<8x100x4096xf32>
%433 = stablehlo.reshape %432 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%434 = stablehlo.transpose %arg222, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%435 = stablehlo.dot_general %433, %434, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%436 = stablehlo.reshape %435 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%437 = stablehlo.transpose %436, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%438 = stablehlo.reshape %437 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%439 = stablehlo.slice %438 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%440 = stablehlo.reshape %439 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%441 = stablehlo.slice %438 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%442 = stablehlo.reshape %441 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%443 = stablehlo.complex %440, %442 : tensor<256x100x64xcomplex<f32>>
%444 = stablehlo.multiply %443, %28 : tensor<256x100x64xcomplex<f32>>
%445 = stablehlo.real %444 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%446 = stablehlo.reshape %445 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%447 = stablehlo.imag %444 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%448 = stablehlo.reshape %447 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%449 = stablehlo.concatenate %446, %448, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%450 = stablehlo.reshape %449 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%451 = stablehlo.transpose %arg220, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%452 = stablehlo.dot_general %433, %451, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%453 = stablehlo.reshape %452 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%454 = stablehlo.transpose %453, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%455 = stablehlo.reshape %454 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%456 = stablehlo.slice %455 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%457 = stablehlo.reshape %456 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%458 = stablehlo.slice %455 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%459 = stablehlo.reshape %458 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%460 = stablehlo.complex %457, %459 : tensor<256x100x64xcomplex<f32>>
%461 = stablehlo.multiply %460, %28 : tensor<256x100x64xcomplex<f32>>
%462 = stablehlo.real %461 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%463 = stablehlo.reshape %462 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%464 = stablehlo.imag %461 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%465 = stablehlo.reshape %464 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%466 = stablehlo.concatenate %463, %465, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%467 = stablehlo.reshape %466 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%468 = stablehlo.transpose %467, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%469 = "stablehlo.scatter"(%arg221, %39, %468) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%470 = stablehlo.transpose %469, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%471 = stablehlo.reshape %470 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%472 = stablehlo.dot_general %450, %471, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%473 = stablehlo.reshape %472 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%474 = stablehlo.divide %473, %cst : tensor<8x32x100x1024xf32>
%475 = stablehlo.add %474, %66 : tensor<8x32x100x1024xf32>
%476 = stablehlo.reduce(%475 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%477 = stablehlo.broadcast_in_dim %476, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%478 = stablehlo.subtract %475, %477 : tensor<8x32x100x1024xf32>
%479 = stablehlo.exponential %478 : tensor<8x32x100x1024xf32>
%480 = stablehlo.reduce(%479 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%481 = stablehlo.broadcast_in_dim %480, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%482 = stablehlo.divide %479, %481 : tensor<8x32x100x1024xf32>
%483 = stablehlo.reshape %482 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%484 = stablehlo.transpose %arg168, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%485 = stablehlo.dot_general %433, %484, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%486 = stablehlo.reshape %485 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%487 = "stablehlo.scatter"(%arg219, %39, %486) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%488 = stablehlo.transpose %487, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%489 = stablehlo.reshape %488 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%490 = stablehlo.dot_general %483, %489, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%491 = stablehlo.reshape %490 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%492 = stablehlo.transpose %491, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%493 = stablehlo.reshape %492 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%494 = stablehlo.transpose %arg167, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%495 = stablehlo.dot_general %493, %494, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%496 = stablehlo.reshape %495 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%497 = stablehlo.add %421, %496 : tensor<8x100x4096xf32>
%498 = stablehlo.power %497, %cst_3 : tensor<8x100x4096xf32>
%499 = stablehlo.reduce(%498 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%500 = stablehlo.multiply %499, %cst_2 : tensor<8x100xf32>
%501 = stablehlo.reshape %500 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%502 = stablehlo.add %501, %cst_1 : tensor<8x100x1xf32>
%503 = stablehlo.rsqrt %502 : tensor<8x100x1xf32>
%504 = stablehlo.reshape %503 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%505 = stablehlo.broadcast_in_dim %504, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%506 = stablehlo.multiply %497, %505 : tensor<8x100x4096xf32>
%507 = stablehlo.broadcast_in_dim %arg166, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%508 = stablehlo.multiply %506, %507 : tensor<8x100x4096xf32>
%509 = stablehlo.reshape %508 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%510 = stablehlo.transpose %arg223, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%511 = stablehlo.dot_general %509, %510, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%512 = stablehlo.reshape %511 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%513 = stablehlo.logistic %512 : tensor<8x100x11008xf32>
%514 = stablehlo.multiply %512, %513 : tensor<8x100x11008xf32>
%515 = stablehlo.transpose %arg165, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%516 = stablehlo.dot_general %509, %515, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%517 = stablehlo.reshape %516 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%518 = stablehlo.multiply %514, %517 : tensor<8x100x11008xf32>
%519 = stablehlo.reshape %518 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%520 = stablehlo.transpose %arg164, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%521 = stablehlo.dot_general %519, %520, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%522 = stablehlo.reshape %521 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%523 = stablehlo.add %497, %522 : tensor<8x100x4096xf32>
%524 = stablehlo.power %523, %cst_3 : tensor<8x100x4096xf32>
%525 = stablehlo.reduce(%524 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%526 = stablehlo.multiply %525, %cst_2 : tensor<8x100xf32>
%527 = stablehlo.reshape %526 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%528 = stablehlo.add %527, %cst_1 : tensor<8x100x1xf32>
%529 = stablehlo.rsqrt %528 : tensor<8x100x1xf32>
%530 = stablehlo.reshape %529 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%531 = stablehlo.broadcast_in_dim %530, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%532 = stablehlo.multiply %523, %531 : tensor<8x100x4096xf32>
%533 = stablehlo.broadcast_in_dim %arg163, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%534 = stablehlo.multiply %532, %533 : tensor<8x100x4096xf32>
%535 = stablehlo.reshape %534 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%536 = stablehlo.transpose %arg227, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%537 = stablehlo.dot_general %535, %536, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%538 = stablehlo.reshape %537 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%539 = stablehlo.transpose %538, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%540 = stablehlo.reshape %539 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%541 = stablehlo.slice %540 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%542 = stablehlo.reshape %541 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%543 = stablehlo.slice %540 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%544 = stablehlo.reshape %543 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%545 = stablehlo.complex %542, %544 : tensor<256x100x64xcomplex<f32>>
%546 = stablehlo.multiply %545, %28 : tensor<256x100x64xcomplex<f32>>
%547 = stablehlo.real %546 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%548 = stablehlo.reshape %547 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%549 = stablehlo.imag %546 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%550 = stablehlo.reshape %549 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%551 = stablehlo.concatenate %548, %550, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%552 = stablehlo.reshape %551 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%553 = stablehlo.transpose %arg225, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%554 = stablehlo.dot_general %535, %553, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%555 = stablehlo.reshape %554 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%556 = stablehlo.transpose %555, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%557 = stablehlo.reshape %556 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%558 = stablehlo.slice %557 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%559 = stablehlo.reshape %558 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%560 = stablehlo.slice %557 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%561 = stablehlo.reshape %560 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%562 = stablehlo.complex %559, %561 : tensor<256x100x64xcomplex<f32>>
%563 = stablehlo.multiply %562, %28 : tensor<256x100x64xcomplex<f32>>
%564 = stablehlo.real %563 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%565 = stablehlo.reshape %564 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%566 = stablehlo.imag %563 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%567 = stablehlo.reshape %566 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%568 = stablehlo.concatenate %565, %567, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%569 = stablehlo.reshape %568 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%570 = stablehlo.transpose %569, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%571 = "stablehlo.scatter"(%arg226, %39, %570) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%572 = stablehlo.transpose %571, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%573 = stablehlo.reshape %572 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%574 = stablehlo.dot_general %552, %573, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%575 = stablehlo.reshape %574 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%576 = stablehlo.divide %575, %cst : tensor<8x32x100x1024xf32>
%577 = stablehlo.add %576, %66 : tensor<8x32x100x1024xf32>
%578 = stablehlo.reduce(%577 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%579 = stablehlo.broadcast_in_dim %578, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%580 = stablehlo.subtract %577, %579 : tensor<8x32x100x1024xf32>
%581 = stablehlo.exponential %580 : tensor<8x32x100x1024xf32>
%582 = stablehlo.reduce(%581 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%583 = stablehlo.broadcast_in_dim %582, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%584 = stablehlo.divide %581, %583 : tensor<8x32x100x1024xf32>
%585 = stablehlo.reshape %584 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%586 = stablehlo.transpose %arg162, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%587 = stablehlo.dot_general %535, %586, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%588 = stablehlo.reshape %587 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%589 = "stablehlo.scatter"(%arg224, %39, %588) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%590 = stablehlo.transpose %589, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%591 = stablehlo.reshape %590 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%592 = stablehlo.dot_general %585, %591, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%593 = stablehlo.reshape %592 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%594 = stablehlo.transpose %593, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%595 = stablehlo.reshape %594 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%596 = stablehlo.transpose %arg161, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%597 = stablehlo.dot_general %595, %596, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%598 = stablehlo.reshape %597 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%599 = stablehlo.add %523, %598 : tensor<8x100x4096xf32>
%600 = stablehlo.power %599, %cst_3 : tensor<8x100x4096xf32>
%601 = stablehlo.reduce(%600 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%602 = stablehlo.multiply %601, %cst_2 : tensor<8x100xf32>
%603 = stablehlo.reshape %602 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%604 = stablehlo.add %603, %cst_1 : tensor<8x100x1xf32>
%605 = stablehlo.rsqrt %604 : tensor<8x100x1xf32>
%606 = stablehlo.reshape %605 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%607 = stablehlo.broadcast_in_dim %606, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%608 = stablehlo.multiply %599, %607 : tensor<8x100x4096xf32>
%609 = stablehlo.broadcast_in_dim %arg160, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%610 = stablehlo.multiply %608, %609 : tensor<8x100x4096xf32>
%611 = stablehlo.reshape %610 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%612 = stablehlo.transpose %arg228, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%613 = stablehlo.dot_general %611, %612, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%614 = stablehlo.reshape %613 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%615 = stablehlo.logistic %614 : tensor<8x100x11008xf32>
%616 = stablehlo.multiply %614, %615 : tensor<8x100x11008xf32>
%617 = stablehlo.transpose %arg159, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%618 = stablehlo.dot_general %611, %617, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%619 = stablehlo.reshape %618 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%620 = stablehlo.multiply %616, %619 : tensor<8x100x11008xf32>
%621 = stablehlo.reshape %620 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%622 = stablehlo.transpose %arg158, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%623 = stablehlo.dot_general %621, %622, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%624 = stablehlo.reshape %623 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%625 = stablehlo.add %599, %624 : tensor<8x100x4096xf32>
%626 = stablehlo.power %625, %cst_3 : tensor<8x100x4096xf32>
%627 = stablehlo.reduce(%626 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%628 = stablehlo.multiply %627, %cst_2 : tensor<8x100xf32>
%629 = stablehlo.reshape %628 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%630 = stablehlo.add %629, %cst_1 : tensor<8x100x1xf32>
%631 = stablehlo.rsqrt %630 : tensor<8x100x1xf32>
%632 = stablehlo.reshape %631 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%633 = stablehlo.broadcast_in_dim %632, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%634 = stablehlo.multiply %625, %633 : tensor<8x100x4096xf32>
%635 = stablehlo.broadcast_in_dim %arg157, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%636 = stablehlo.multiply %634, %635 : tensor<8x100x4096xf32>
%637 = stablehlo.reshape %636 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%638 = stablehlo.transpose %arg232, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%639 = stablehlo.dot_general %637, %638, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%640 = stablehlo.reshape %639 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%641 = stablehlo.transpose %640, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%642 = stablehlo.reshape %641 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%643 = stablehlo.slice %642 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%644 = stablehlo.reshape %643 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%645 = stablehlo.slice %642 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%646 = stablehlo.reshape %645 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%647 = stablehlo.complex %644, %646 : tensor<256x100x64xcomplex<f32>>
%648 = stablehlo.multiply %647, %28 : tensor<256x100x64xcomplex<f32>>
%649 = stablehlo.real %648 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%650 = stablehlo.reshape %649 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%651 = stablehlo.imag %648 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%652 = stablehlo.reshape %651 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%653 = stablehlo.concatenate %650, %652, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%654 = stablehlo.reshape %653 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%655 = stablehlo.transpose %arg230, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%656 = stablehlo.dot_general %637, %655, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%657 = stablehlo.reshape %656 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%658 = stablehlo.transpose %657, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%659 = stablehlo.reshape %658 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%660 = stablehlo.slice %659 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%661 = stablehlo.reshape %660 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%662 = stablehlo.slice %659 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%663 = stablehlo.reshape %662 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%664 = stablehlo.complex %661, %663 : tensor<256x100x64xcomplex<f32>>
%665 = stablehlo.multiply %664, %28 : tensor<256x100x64xcomplex<f32>>
%666 = stablehlo.real %665 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%667 = stablehlo.reshape %666 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%668 = stablehlo.imag %665 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%669 = stablehlo.reshape %668 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%670 = stablehlo.concatenate %667, %669, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%671 = stablehlo.reshape %670 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%672 = stablehlo.transpose %671, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%673 = "stablehlo.scatter"(%arg231, %39, %672) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%674 = stablehlo.transpose %673, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%675 = stablehlo.reshape %674 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%676 = stablehlo.dot_general %654, %675, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%677 = stablehlo.reshape %676 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%678 = stablehlo.divide %677, %cst : tensor<8x32x100x1024xf32>
%679 = stablehlo.add %678, %66 : tensor<8x32x100x1024xf32>
%680 = stablehlo.reduce(%679 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%681 = stablehlo.broadcast_in_dim %680, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%682 = stablehlo.subtract %679, %681 : tensor<8x32x100x1024xf32>
%683 = stablehlo.exponential %682 : tensor<8x32x100x1024xf32>
%684 = stablehlo.reduce(%683 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%685 = stablehlo.broadcast_in_dim %684, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%686 = stablehlo.divide %683, %685 : tensor<8x32x100x1024xf32>
%687 = stablehlo.reshape %686 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%688 = stablehlo.transpose %arg156, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%689 = stablehlo.dot_general %637, %688, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%690 = stablehlo.reshape %689 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%691 = "stablehlo.scatter"(%arg229, %39, %690) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%692 = stablehlo.transpose %691, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%693 = stablehlo.reshape %692 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%694 = stablehlo.dot_general %687, %693, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%695 = stablehlo.reshape %694 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%696 = stablehlo.transpose %695, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%697 = stablehlo.reshape %696 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%698 = stablehlo.transpose %arg155, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%699 = stablehlo.dot_general %697, %698, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%700 = stablehlo.reshape %699 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%701 = stablehlo.add %625, %700 : tensor<8x100x4096xf32>
%702 = stablehlo.power %701, %cst_3 : tensor<8x100x4096xf32>
%703 = stablehlo.reduce(%702 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%704 = stablehlo.multiply %703, %cst_2 : tensor<8x100xf32>
%705 = stablehlo.reshape %704 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%706 = stablehlo.add %705, %cst_1 : tensor<8x100x1xf32>
%707 = stablehlo.rsqrt %706 : tensor<8x100x1xf32>
%708 = stablehlo.reshape %707 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%709 = stablehlo.broadcast_in_dim %708, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%710 = stablehlo.multiply %701, %709 : tensor<8x100x4096xf32>
%711 = stablehlo.broadcast_in_dim %arg154, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%712 = stablehlo.multiply %710, %711 : tensor<8x100x4096xf32>
%713 = stablehlo.reshape %712 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%714 = stablehlo.transpose %arg233, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%715 = stablehlo.dot_general %713, %714, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%716 = stablehlo.reshape %715 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%717 = stablehlo.logistic %716 : tensor<8x100x11008xf32>
%718 = stablehlo.multiply %716, %717 : tensor<8x100x11008xf32>
%719 = stablehlo.transpose %arg153, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%720 = stablehlo.dot_general %713, %719, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%721 = stablehlo.reshape %720 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%722 = stablehlo.multiply %718, %721 : tensor<8x100x11008xf32>
%723 = stablehlo.reshape %722 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%724 = stablehlo.transpose %arg152, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%725 = stablehlo.dot_general %723, %724, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%726 = stablehlo.reshape %725 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%727 = stablehlo.add %701, %726 : tensor<8x100x4096xf32>
%728 = stablehlo.power %727, %cst_3 : tensor<8x100x4096xf32>
%729 = stablehlo.reduce(%728 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%730 = stablehlo.multiply %729, %cst_2 : tensor<8x100xf32>
%731 = stablehlo.reshape %730 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%732 = stablehlo.add %731, %cst_1 : tensor<8x100x1xf32>
%733 = stablehlo.rsqrt %732 : tensor<8x100x1xf32>
%734 = stablehlo.reshape %733 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%735 = stablehlo.broadcast_in_dim %734, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%736 = stablehlo.multiply %727, %735 : tensor<8x100x4096xf32>
%737 = stablehlo.broadcast_in_dim %arg151, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%738 = stablehlo.multiply %736, %737 : tensor<8x100x4096xf32>
%739 = stablehlo.reshape %738 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%740 = stablehlo.transpose %arg237, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%741 = stablehlo.dot_general %739, %740, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%742 = stablehlo.reshape %741 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%743 = stablehlo.transpose %742, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%744 = stablehlo.reshape %743 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%745 = stablehlo.slice %744 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%746 = stablehlo.reshape %745 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%747 = stablehlo.slice %744 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%748 = stablehlo.reshape %747 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%749 = stablehlo.complex %746, %748 : tensor<256x100x64xcomplex<f32>>
%750 = stablehlo.multiply %749, %28 : tensor<256x100x64xcomplex<f32>>
%751 = stablehlo.real %750 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%752 = stablehlo.reshape %751 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%753 = stablehlo.imag %750 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%754 = stablehlo.reshape %753 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%755 = stablehlo.concatenate %752, %754, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%756 = stablehlo.reshape %755 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%757 = stablehlo.transpose %arg235, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%758 = stablehlo.dot_general %739, %757, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%759 = stablehlo.reshape %758 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%760 = stablehlo.transpose %759, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%761 = stablehlo.reshape %760 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%762 = stablehlo.slice %761 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%763 = stablehlo.reshape %762 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%764 = stablehlo.slice %761 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%765 = stablehlo.reshape %764 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%766 = stablehlo.complex %763, %765 : tensor<256x100x64xcomplex<f32>>
%767 = stablehlo.multiply %766, %28 : tensor<256x100x64xcomplex<f32>>
%768 = stablehlo.real %767 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%769 = stablehlo.reshape %768 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%770 = stablehlo.imag %767 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%771 = stablehlo.reshape %770 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%772 = stablehlo.concatenate %769, %771, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%773 = stablehlo.reshape %772 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%774 = stablehlo.transpose %773, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%775 = "stablehlo.scatter"(%arg236, %39, %774) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%776 = stablehlo.transpose %775, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%777 = stablehlo.reshape %776 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%778 = stablehlo.dot_general %756, %777, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%779 = stablehlo.reshape %778 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%780 = stablehlo.divide %779, %cst : tensor<8x32x100x1024xf32>
%781 = stablehlo.add %780, %66 : tensor<8x32x100x1024xf32>
%782 = stablehlo.reduce(%781 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%783 = stablehlo.broadcast_in_dim %782, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%784 = stablehlo.subtract %781, %783 : tensor<8x32x100x1024xf32>
%785 = stablehlo.exponential %784 : tensor<8x32x100x1024xf32>
%786 = stablehlo.reduce(%785 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%787 = stablehlo.broadcast_in_dim %786, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%788 = stablehlo.divide %785, %787 : tensor<8x32x100x1024xf32>
%789 = stablehlo.reshape %788 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%790 = stablehlo.transpose %arg150, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%791 = stablehlo.dot_general %739, %790, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%792 = stablehlo.reshape %791 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%793 = "stablehlo.scatter"(%arg234, %39, %792) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%794 = stablehlo.transpose %793, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%795 = stablehlo.reshape %794 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%796 = stablehlo.dot_general %789, %795, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%797 = stablehlo.reshape %796 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%798 = stablehlo.transpose %797, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%799 = stablehlo.reshape %798 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%800 = stablehlo.transpose %arg149, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%801 = stablehlo.dot_general %799, %800, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%802 = stablehlo.reshape %801 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%803 = stablehlo.add %727, %802 : tensor<8x100x4096xf32>
%804 = stablehlo.power %803, %cst_3 : tensor<8x100x4096xf32>
%805 = stablehlo.reduce(%804 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%806 = stablehlo.multiply %805, %cst_2 : tensor<8x100xf32>
%807 = stablehlo.reshape %806 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%808 = stablehlo.add %807, %cst_1 : tensor<8x100x1xf32>
%809 = stablehlo.rsqrt %808 : tensor<8x100x1xf32>
%810 = stablehlo.reshape %809 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%811 = stablehlo.broadcast_in_dim %810, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%812 = stablehlo.multiply %803, %811 : tensor<8x100x4096xf32>
%813 = stablehlo.broadcast_in_dim %arg148, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%814 = stablehlo.multiply %812, %813 : tensor<8x100x4096xf32>
%815 = stablehlo.reshape %814 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%816 = stablehlo.transpose %arg238, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%817 = stablehlo.dot_general %815, %816, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%818 = stablehlo.reshape %817 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%819 = stablehlo.logistic %818 : tensor<8x100x11008xf32>
%820 = stablehlo.multiply %818, %819 : tensor<8x100x11008xf32>
%821 = stablehlo.transpose %arg147, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%822 = stablehlo.dot_general %815, %821, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%823 = stablehlo.reshape %822 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%824 = stablehlo.multiply %820, %823 : tensor<8x100x11008xf32>
%825 = stablehlo.reshape %824 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%826 = stablehlo.transpose %arg146, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%827 = stablehlo.dot_general %825, %826, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%828 = stablehlo.reshape %827 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%829 = stablehlo.add %803, %828 : tensor<8x100x4096xf32>
%830 = stablehlo.power %829, %cst_3 : tensor<8x100x4096xf32>
%831 = stablehlo.reduce(%830 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%832 = stablehlo.multiply %831, %cst_2 : tensor<8x100xf32>
%833 = stablehlo.reshape %832 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%834 = stablehlo.add %833, %cst_1 : tensor<8x100x1xf32>
%835 = stablehlo.rsqrt %834 : tensor<8x100x1xf32>
%836 = stablehlo.reshape %835 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%837 = stablehlo.broadcast_in_dim %836, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%838 = stablehlo.multiply %829, %837 : tensor<8x100x4096xf32>
%839 = stablehlo.broadcast_in_dim %arg145, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%840 = stablehlo.multiply %838, %839 : tensor<8x100x4096xf32>
%841 = stablehlo.reshape %840 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%842 = stablehlo.transpose %arg242, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%843 = stablehlo.dot_general %841, %842, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%844 = stablehlo.reshape %843 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%845 = stablehlo.transpose %844, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%846 = stablehlo.reshape %845 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%847 = stablehlo.slice %846 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%848 = stablehlo.reshape %847 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%849 = stablehlo.slice %846 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%850 = stablehlo.reshape %849 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%851 = stablehlo.complex %848, %850 : tensor<256x100x64xcomplex<f32>>
%852 = stablehlo.multiply %851, %28 : tensor<256x100x64xcomplex<f32>>
%853 = stablehlo.real %852 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%854 = stablehlo.reshape %853 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%855 = stablehlo.imag %852 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%856 = stablehlo.reshape %855 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%857 = stablehlo.concatenate %854, %856, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%858 = stablehlo.reshape %857 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%859 = stablehlo.transpose %arg240, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%860 = stablehlo.dot_general %841, %859, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%861 = stablehlo.reshape %860 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%862 = stablehlo.transpose %861, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%863 = stablehlo.reshape %862 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%864 = stablehlo.slice %863 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%865 = stablehlo.reshape %864 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%866 = stablehlo.slice %863 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%867 = stablehlo.reshape %866 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%868 = stablehlo.complex %865, %867 : tensor<256x100x64xcomplex<f32>>
%869 = stablehlo.multiply %868, %28 : tensor<256x100x64xcomplex<f32>>
%870 = stablehlo.real %869 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%871 = stablehlo.reshape %870 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%872 = stablehlo.imag %869 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%873 = stablehlo.reshape %872 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%874 = stablehlo.concatenate %871, %873, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%875 = stablehlo.reshape %874 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%876 = stablehlo.transpose %875, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%877 = "stablehlo.scatter"(%arg241, %39, %876) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%878 = stablehlo.transpose %877, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%879 = stablehlo.reshape %878 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%880 = stablehlo.dot_general %858, %879, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%881 = stablehlo.reshape %880 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%882 = stablehlo.divide %881, %cst : tensor<8x32x100x1024xf32>
%883 = stablehlo.add %882, %66 : tensor<8x32x100x1024xf32>
%884 = stablehlo.reduce(%883 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%885 = stablehlo.broadcast_in_dim %884, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%886 = stablehlo.subtract %883, %885 : tensor<8x32x100x1024xf32>
%887 = stablehlo.exponential %886 : tensor<8x32x100x1024xf32>
%888 = stablehlo.reduce(%887 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%889 = stablehlo.broadcast_in_dim %888, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%890 = stablehlo.divide %887, %889 : tensor<8x32x100x1024xf32>
%891 = stablehlo.reshape %890 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%892 = stablehlo.transpose %arg144, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%893 = stablehlo.dot_general %841, %892, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%894 = stablehlo.reshape %893 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%895 = "stablehlo.scatter"(%arg239, %39, %894) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%896 = stablehlo.transpose %895, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%897 = stablehlo.reshape %896 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%898 = stablehlo.dot_general %891, %897, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%899 = stablehlo.reshape %898 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%900 = stablehlo.transpose %899, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%901 = stablehlo.reshape %900 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%902 = stablehlo.transpose %arg143, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%903 = stablehlo.dot_general %901, %902, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%904 = stablehlo.reshape %903 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%905 = stablehlo.add %829, %904 : tensor<8x100x4096xf32>
%906 = stablehlo.power %905, %cst_3 : tensor<8x100x4096xf32>
%907 = stablehlo.reduce(%906 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%908 = stablehlo.multiply %907, %cst_2 : tensor<8x100xf32>
%909 = stablehlo.reshape %908 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%910 = stablehlo.add %909, %cst_1 : tensor<8x100x1xf32>
%911 = stablehlo.rsqrt %910 : tensor<8x100x1xf32>
%912 = stablehlo.reshape %911 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%913 = stablehlo.broadcast_in_dim %912, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%914 = stablehlo.multiply %905, %913 : tensor<8x100x4096xf32>
%915 = stablehlo.broadcast_in_dim %arg142, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%916 = stablehlo.multiply %914, %915 : tensor<8x100x4096xf32>
%917 = stablehlo.reshape %916 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%918 = stablehlo.transpose %arg243, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%919 = stablehlo.dot_general %917, %918, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%920 = stablehlo.reshape %919 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%921 = stablehlo.logistic %920 : tensor<8x100x11008xf32>
%922 = stablehlo.multiply %920, %921 : tensor<8x100x11008xf32>
%923 = stablehlo.transpose %arg141, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%924 = stablehlo.dot_general %917, %923, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%925 = stablehlo.reshape %924 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%926 = stablehlo.multiply %922, %925 : tensor<8x100x11008xf32>
%927 = stablehlo.reshape %926 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%928 = stablehlo.transpose %arg140, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%929 = stablehlo.dot_general %927, %928, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%930 = stablehlo.reshape %929 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%931 = stablehlo.add %905, %930 : tensor<8x100x4096xf32>
%932 = stablehlo.power %931, %cst_3 : tensor<8x100x4096xf32>
%933 = stablehlo.reduce(%932 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%934 = stablehlo.multiply %933, %cst_2 : tensor<8x100xf32>
%935 = stablehlo.reshape %934 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%936 = stablehlo.add %935, %cst_1 : tensor<8x100x1xf32>
%937 = stablehlo.rsqrt %936 : tensor<8x100x1xf32>
%938 = stablehlo.reshape %937 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%939 = stablehlo.broadcast_in_dim %938, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%940 = stablehlo.multiply %931, %939 : tensor<8x100x4096xf32>
%941 = stablehlo.broadcast_in_dim %arg139, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%942 = stablehlo.multiply %940, %941 : tensor<8x100x4096xf32>
%943 = stablehlo.reshape %942 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%944 = stablehlo.transpose %arg247, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%945 = stablehlo.dot_general %943, %944, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%946 = stablehlo.reshape %945 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%947 = stablehlo.transpose %946, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%948 = stablehlo.reshape %947 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%949 = stablehlo.slice %948 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%950 = stablehlo.reshape %949 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%951 = stablehlo.slice %948 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%952 = stablehlo.reshape %951 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%953 = stablehlo.complex %950, %952 : tensor<256x100x64xcomplex<f32>>
%954 = stablehlo.multiply %953, %28 : tensor<256x100x64xcomplex<f32>>
%955 = stablehlo.real %954 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%956 = stablehlo.reshape %955 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%957 = stablehlo.imag %954 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%958 = stablehlo.reshape %957 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%959 = stablehlo.concatenate %956, %958, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%960 = stablehlo.reshape %959 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%961 = stablehlo.transpose %arg245, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%962 = stablehlo.dot_general %943, %961, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%963 = stablehlo.reshape %962 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%964 = stablehlo.transpose %963, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%965 = stablehlo.reshape %964 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%966 = stablehlo.slice %965 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%967 = stablehlo.reshape %966 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%968 = stablehlo.slice %965 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%969 = stablehlo.reshape %968 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%970 = stablehlo.complex %967, %969 : tensor<256x100x64xcomplex<f32>>
%971 = stablehlo.multiply %970, %28 : tensor<256x100x64xcomplex<f32>>
%972 = stablehlo.real %971 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%973 = stablehlo.reshape %972 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%974 = stablehlo.imag %971 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%975 = stablehlo.reshape %974 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%976 = stablehlo.concatenate %973, %975, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%977 = stablehlo.reshape %976 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%978 = stablehlo.transpose %977, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%979 = "stablehlo.scatter"(%arg246, %39, %978) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%980 = stablehlo.transpose %979, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%981 = stablehlo.reshape %980 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%982 = stablehlo.dot_general %960, %981, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%983 = stablehlo.reshape %982 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%984 = stablehlo.divide %983, %cst : tensor<8x32x100x1024xf32>
%985 = stablehlo.add %984, %66 : tensor<8x32x100x1024xf32>
%986 = stablehlo.reduce(%985 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%987 = stablehlo.broadcast_in_dim %986, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%988 = stablehlo.subtract %985, %987 : tensor<8x32x100x1024xf32>
%989 = stablehlo.exponential %988 : tensor<8x32x100x1024xf32>
%990 = stablehlo.reduce(%989 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%991 = stablehlo.broadcast_in_dim %990, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%992 = stablehlo.divide %989, %991 : tensor<8x32x100x1024xf32>
%993 = stablehlo.reshape %992 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%994 = stablehlo.transpose %arg138, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%995 = stablehlo.dot_general %943, %994, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%996 = stablehlo.reshape %995 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%997 = "stablehlo.scatter"(%arg244, %39, %996) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%998 = stablehlo.transpose %997, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%999 = stablehlo.reshape %998 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%1000 = stablehlo.dot_general %993, %999, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%1001 = stablehlo.reshape %1000 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%1002 = stablehlo.transpose %1001, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1003 = stablehlo.reshape %1002 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%1004 = stablehlo.transpose %arg137, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1005 = stablehlo.dot_general %1003, %1004, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1006 = stablehlo.reshape %1005 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1007 = stablehlo.add %931, %1006 : tensor<8x100x4096xf32>
%1008 = stablehlo.power %1007, %cst_3 : tensor<8x100x4096xf32>
%1009 = stablehlo.reduce(%1008 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1010 = stablehlo.multiply %1009, %cst_2 : tensor<8x100xf32>
%1011 = stablehlo.reshape %1010 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1012 = stablehlo.add %1011, %cst_1 : tensor<8x100x1xf32>
%1013 = stablehlo.rsqrt %1012 : tensor<8x100x1xf32>
%1014 = stablehlo.reshape %1013 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1015 = stablehlo.broadcast_in_dim %1014, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1016 = stablehlo.multiply %1007, %1015 : tensor<8x100x4096xf32>
%1017 = stablehlo.broadcast_in_dim %arg136, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1018 = stablehlo.multiply %1016, %1017 : tensor<8x100x4096xf32>
%1019 = stablehlo.reshape %1018 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1020 = stablehlo.transpose %arg248, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1021 = stablehlo.dot_general %1019, %1020, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1022 = stablehlo.reshape %1021 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1023 = stablehlo.logistic %1022 : tensor<8x100x11008xf32>
%1024 = stablehlo.multiply %1022, %1023 : tensor<8x100x11008xf32>
%1025 = stablehlo.transpose %arg135, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1026 = stablehlo.dot_general %1019, %1025, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1027 = stablehlo.reshape %1026 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1028 = stablehlo.multiply %1024, %1027 : tensor<8x100x11008xf32>
%1029 = stablehlo.reshape %1028 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%1030 = stablehlo.transpose %arg134, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%1031 = stablehlo.dot_general %1029, %1030, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%1032 = stablehlo.reshape %1031 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1033 = stablehlo.add %1007, %1032 : tensor<8x100x4096xf32>
%1034 = stablehlo.power %1033, %cst_3 : tensor<8x100x4096xf32>
%1035 = stablehlo.reduce(%1034 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1036 = stablehlo.multiply %1035, %cst_2 : tensor<8x100xf32>
%1037 = stablehlo.reshape %1036 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1038 = stablehlo.add %1037, %cst_1 : tensor<8x100x1xf32>
%1039 = stablehlo.rsqrt %1038 : tensor<8x100x1xf32>
%1040 = stablehlo.reshape %1039 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1041 = stablehlo.broadcast_in_dim %1040, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1042 = stablehlo.multiply %1033, %1041 : tensor<8x100x4096xf32>
%1043 = stablehlo.broadcast_in_dim %arg133, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1044 = stablehlo.multiply %1042, %1043 : tensor<8x100x4096xf32>
%1045 = stablehlo.reshape %1044 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1046 = stablehlo.transpose %arg252, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1047 = stablehlo.dot_general %1045, %1046, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1048 = stablehlo.reshape %1047 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1049 = stablehlo.transpose %1048, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1050 = stablehlo.reshape %1049 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1051 = stablehlo.slice %1050 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1052 = stablehlo.reshape %1051 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1053 = stablehlo.slice %1050 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1054 = stablehlo.reshape %1053 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1055 = stablehlo.complex %1052, %1054 : tensor<256x100x64xcomplex<f32>>
%1056 = stablehlo.multiply %1055, %28 : tensor<256x100x64xcomplex<f32>>
%1057 = stablehlo.real %1056 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1058 = stablehlo.reshape %1057 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1059 = stablehlo.imag %1056 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1060 = stablehlo.reshape %1059 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1061 = stablehlo.concatenate %1058, %1060, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1062 = stablehlo.reshape %1061 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%1063 = stablehlo.transpose %arg250, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1064 = stablehlo.dot_general %1045, %1063, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1065 = stablehlo.reshape %1064 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1066 = stablehlo.transpose %1065, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1067 = stablehlo.reshape %1066 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1068 = stablehlo.slice %1067 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1069 = stablehlo.reshape %1068 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1070 = stablehlo.slice %1067 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1071 = stablehlo.reshape %1070 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1072 = stablehlo.complex %1069, %1071 : tensor<256x100x64xcomplex<f32>>
%1073 = stablehlo.multiply %1072, %28 : tensor<256x100x64xcomplex<f32>>
%1074 = stablehlo.real %1073 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1075 = stablehlo.reshape %1074 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1076 = stablehlo.imag %1073 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1077 = stablehlo.reshape %1076 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1078 = stablehlo.concatenate %1075, %1077, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1079 = stablehlo.reshape %1078 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%1080 = stablehlo.transpose %1079, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1081 = "stablehlo.scatter"(%arg251, %39, %1080) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1082 = stablehlo.transpose %1081, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%1083 = stablehlo.reshape %1082 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%1084 = stablehlo.dot_general %1062, %1083, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%1085 = stablehlo.reshape %1084 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%1086 = stablehlo.divide %1085, %cst : tensor<8x32x100x1024xf32>
%1087 = stablehlo.add %1086, %66 : tensor<8x32x100x1024xf32>
%1088 = stablehlo.reduce(%1087 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1089 = stablehlo.broadcast_in_dim %1088, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1090 = stablehlo.subtract %1087, %1089 : tensor<8x32x100x1024xf32>
%1091 = stablehlo.exponential %1090 : tensor<8x32x100x1024xf32>
%1092 = stablehlo.reduce(%1091 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1093 = stablehlo.broadcast_in_dim %1092, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1094 = stablehlo.divide %1091, %1093 : tensor<8x32x100x1024xf32>
%1095 = stablehlo.reshape %1094 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%1096 = stablehlo.transpose %arg132, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1097 = stablehlo.dot_general %1045, %1096, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1098 = stablehlo.reshape %1097 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1099 = "stablehlo.scatter"(%arg249, %39, %1098) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1100 = stablehlo.transpose %1099, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%1101 = stablehlo.reshape %1100 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%1102 = stablehlo.dot_general %1095, %1101, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%1103 = stablehlo.reshape %1102 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%1104 = stablehlo.transpose %1103, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1105 = stablehlo.reshape %1104 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%1106 = stablehlo.transpose %arg131, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1107 = stablehlo.dot_general %1105, %1106, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1108 = stablehlo.reshape %1107 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1109 = stablehlo.add %1033, %1108 : tensor<8x100x4096xf32>
%1110 = stablehlo.power %1109, %cst_3 : tensor<8x100x4096xf32>
%1111 = stablehlo.reduce(%1110 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1112 = stablehlo.multiply %1111, %cst_2 : tensor<8x100xf32>
%1113 = stablehlo.reshape %1112 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1114 = stablehlo.add %1113, %cst_1 : tensor<8x100x1xf32>
%1115 = stablehlo.rsqrt %1114 : tensor<8x100x1xf32>
%1116 = stablehlo.reshape %1115 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1117 = stablehlo.broadcast_in_dim %1116, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1118 = stablehlo.multiply %1109, %1117 : tensor<8x100x4096xf32>
%1119 = stablehlo.broadcast_in_dim %arg130, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1120 = stablehlo.multiply %1118, %1119 : tensor<8x100x4096xf32>
%1121 = stablehlo.reshape %1120 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1122 = stablehlo.transpose %arg253, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1123 = stablehlo.dot_general %1121, %1122, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1124 = stablehlo.reshape %1123 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1125 = stablehlo.logistic %1124 : tensor<8x100x11008xf32>
%1126 = stablehlo.multiply %1124, %1125 : tensor<8x100x11008xf32>
%1127 = stablehlo.transpose %arg129, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1128 = stablehlo.dot_general %1121, %1127, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1129 = stablehlo.reshape %1128 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1130 = stablehlo.multiply %1126, %1129 : tensor<8x100x11008xf32>
%1131 = stablehlo.reshape %1130 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%1132 = stablehlo.transpose %arg128, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%1133 = stablehlo.dot_general %1131, %1132, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%1134 = stablehlo.reshape %1133 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1135 = stablehlo.add %1109, %1134 : tensor<8x100x4096xf32>
%1136 = stablehlo.power %1135, %cst_3 : tensor<8x100x4096xf32>
%1137 = stablehlo.reduce(%1136 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1138 = stablehlo.multiply %1137, %cst_2 : tensor<8x100xf32>
%1139 = stablehlo.reshape %1138 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1140 = stablehlo.add %1139, %cst_1 : tensor<8x100x1xf32>
%1141 = stablehlo.rsqrt %1140 : tensor<8x100x1xf32>
%1142 = stablehlo.reshape %1141 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1143 = stablehlo.broadcast_in_dim %1142, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1144 = stablehlo.multiply %1135, %1143 : tensor<8x100x4096xf32>
%1145 = stablehlo.broadcast_in_dim %arg127, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1146 = stablehlo.multiply %1144, %1145 : tensor<8x100x4096xf32>
%1147 = stablehlo.reshape %1146 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1148 = stablehlo.transpose %arg257, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1149 = stablehlo.dot_general %1147, %1148, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1150 = stablehlo.reshape %1149 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1151 = stablehlo.transpose %1150, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1152 = stablehlo.reshape %1151 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1153 = stablehlo.slice %1152 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1154 = stablehlo.reshape %1153 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1155 = stablehlo.slice %1152 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1156 = stablehlo.reshape %1155 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1157 = stablehlo.complex %1154, %1156 : tensor<256x100x64xcomplex<f32>>
%1158 = stablehlo.multiply %1157, %28 : tensor<256x100x64xcomplex<f32>>
%1159 = stablehlo.real %1158 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1160 = stablehlo.reshape %1159 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1161 = stablehlo.imag %1158 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1162 = stablehlo.reshape %1161 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1163 = stablehlo.concatenate %1160, %1162, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1164 = stablehlo.reshape %1163 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%1165 = stablehlo.transpose %arg255, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1166 = stablehlo.dot_general %1147, %1165, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1167 = stablehlo.reshape %1166 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1168 = stablehlo.transpose %1167, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1169 = stablehlo.reshape %1168 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1170 = stablehlo.slice %1169 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1171 = stablehlo.reshape %1170 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1172 = stablehlo.slice %1169 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1173 = stablehlo.reshape %1172 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1174 = stablehlo.complex %1171, %1173 : tensor<256x100x64xcomplex<f32>>
%1175 = stablehlo.multiply %1174, %28 : tensor<256x100x64xcomplex<f32>>
%1176 = stablehlo.real %1175 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1177 = stablehlo.reshape %1176 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1178 = stablehlo.imag %1175 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1179 = stablehlo.reshape %1178 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1180 = stablehlo.concatenate %1177, %1179, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1181 = stablehlo.reshape %1180 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%1182 = stablehlo.transpose %1181, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1183 = "stablehlo.scatter"(%arg256, %39, %1182) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1184 = stablehlo.transpose %1183, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%1185 = stablehlo.reshape %1184 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%1186 = stablehlo.dot_general %1164, %1185, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%1187 = stablehlo.reshape %1186 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%1188 = stablehlo.divide %1187, %cst : tensor<8x32x100x1024xf32>
%1189 = stablehlo.add %1188, %66 : tensor<8x32x100x1024xf32>
%1190 = stablehlo.reduce(%1189 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1191 = stablehlo.broadcast_in_dim %1190, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1192 = stablehlo.subtract %1189, %1191 : tensor<8x32x100x1024xf32>
%1193 = stablehlo.exponential %1192 : tensor<8x32x100x1024xf32>
%1194 = stablehlo.reduce(%1193 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1195 = stablehlo.broadcast_in_dim %1194, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1196 = stablehlo.divide %1193, %1195 : tensor<8x32x100x1024xf32>
%1197 = stablehlo.reshape %1196 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%1198 = stablehlo.transpose %arg126, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1199 = stablehlo.dot_general %1147, %1198, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1200 = stablehlo.reshape %1199 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1201 = "stablehlo.scatter"(%arg254, %39, %1200) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1202 = stablehlo.transpose %1201, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%1203 = stablehlo.reshape %1202 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%1204 = stablehlo.dot_general %1197, %1203, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%1205 = stablehlo.reshape %1204 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%1206 = stablehlo.transpose %1205, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1207 = stablehlo.reshape %1206 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%1208 = stablehlo.transpose %arg125, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1209 = stablehlo.dot_general %1207, %1208, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1210 = stablehlo.reshape %1209 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1211 = stablehlo.add %1135, %1210 : tensor<8x100x4096xf32>
%1212 = stablehlo.power %1211, %cst_3 : tensor<8x100x4096xf32>
%1213 = stablehlo.reduce(%1212 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1214 = stablehlo.multiply %1213, %cst_2 : tensor<8x100xf32>
%1215 = stablehlo.reshape %1214 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1216 = stablehlo.add %1215, %cst_1 : tensor<8x100x1xf32>
%1217 = stablehlo.rsqrt %1216 : tensor<8x100x1xf32>
%1218 = stablehlo.reshape %1217 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1219 = stablehlo.broadcast_in_dim %1218, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1220 = stablehlo.multiply %1211, %1219 : tensor<8x100x4096xf32>
%1221 = stablehlo.broadcast_in_dim %arg124, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1222 = stablehlo.multiply %1220, %1221 : tensor<8x100x4096xf32>
%1223 = stablehlo.reshape %1222 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1224 = stablehlo.transpose %arg258, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1225 = stablehlo.dot_general %1223, %1224, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1226 = stablehlo.reshape %1225 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1227 = stablehlo.logistic %1226 : tensor<8x100x11008xf32>
%1228 = stablehlo.multiply %1226, %1227 : tensor<8x100x11008xf32>
%1229 = stablehlo.transpose %arg123, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1230 = stablehlo.dot_general %1223, %1229, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1231 = stablehlo.reshape %1230 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1232 = stablehlo.multiply %1228, %1231 : tensor<8x100x11008xf32>
%1233 = stablehlo.reshape %1232 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%1234 = stablehlo.transpose %arg122, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%1235 = stablehlo.dot_general %1233, %1234, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%1236 = stablehlo.reshape %1235 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1237 = stablehlo.add %1211, %1236 : tensor<8x100x4096xf32>
%1238 = stablehlo.power %1237, %cst_3 : tensor<8x100x4096xf32>
%1239 = stablehlo.reduce(%1238 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1240 = stablehlo.multiply %1239, %cst_2 : tensor<8x100xf32>
%1241 = stablehlo.reshape %1240 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1242 = stablehlo.add %1241, %cst_1 : tensor<8x100x1xf32>
%1243 = stablehlo.rsqrt %1242 : tensor<8x100x1xf32>
%1244 = stablehlo.reshape %1243 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1245 = stablehlo.broadcast_in_dim %1244, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1246 = stablehlo.multiply %1237, %1245 : tensor<8x100x4096xf32>
%1247 = stablehlo.broadcast_in_dim %arg121, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1248 = stablehlo.multiply %1246, %1247 : tensor<8x100x4096xf32>
%1249 = stablehlo.reshape %1248 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1250 = stablehlo.transpose %arg262, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1251 = stablehlo.dot_general %1249, %1250, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1252 = stablehlo.reshape %1251 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1253 = stablehlo.transpose %1252, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1254 = stablehlo.reshape %1253 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1255 = stablehlo.slice %1254 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1256 = stablehlo.reshape %1255 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1257 = stablehlo.slice %1254 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1258 = stablehlo.reshape %1257 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1259 = stablehlo.complex %1256, %1258 : tensor<256x100x64xcomplex<f32>>
%1260 = stablehlo.multiply %1259, %28 : tensor<256x100x64xcomplex<f32>>
%1261 = stablehlo.real %1260 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1262 = stablehlo.reshape %1261 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1263 = stablehlo.imag %1260 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1264 = stablehlo.reshape %1263 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1265 = stablehlo.concatenate %1262, %1264, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1266 = stablehlo.reshape %1265 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%1267 = stablehlo.transpose %arg260, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1268 = stablehlo.dot_general %1249, %1267, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1269 = stablehlo.reshape %1268 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1270 = stablehlo.transpose %1269, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1271 = stablehlo.reshape %1270 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1272 = stablehlo.slice %1271 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1273 = stablehlo.reshape %1272 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1274 = stablehlo.slice %1271 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1275 = stablehlo.reshape %1274 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1276 = stablehlo.complex %1273, %1275 : tensor<256x100x64xcomplex<f32>>
%1277 = stablehlo.multiply %1276, %28 : tensor<256x100x64xcomplex<f32>>
%1278 = stablehlo.real %1277 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1279 = stablehlo.reshape %1278 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1280 = stablehlo.imag %1277 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1281 = stablehlo.reshape %1280 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1282 = stablehlo.concatenate %1279, %1281, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1283 = stablehlo.reshape %1282 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%1284 = stablehlo.transpose %1283, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1285 = "stablehlo.scatter"(%arg261, %39, %1284) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1286 = stablehlo.transpose %1285, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%1287 = stablehlo.reshape %1286 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%1288 = stablehlo.dot_general %1266, %1287, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%1289 = stablehlo.reshape %1288 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%1290 = stablehlo.divide %1289, %cst : tensor<8x32x100x1024xf32>
%1291 = stablehlo.add %1290, %66 : tensor<8x32x100x1024xf32>
%1292 = stablehlo.reduce(%1291 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1293 = stablehlo.broadcast_in_dim %1292, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1294 = stablehlo.subtract %1291, %1293 : tensor<8x32x100x1024xf32>
%1295 = stablehlo.exponential %1294 : tensor<8x32x100x1024xf32>
%1296 = stablehlo.reduce(%1295 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1297 = stablehlo.broadcast_in_dim %1296, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1298 = stablehlo.divide %1295, %1297 : tensor<8x32x100x1024xf32>
%1299 = stablehlo.reshape %1298 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%1300 = stablehlo.transpose %arg120, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1301 = stablehlo.dot_general %1249, %1300, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1302 = stablehlo.reshape %1301 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1303 = "stablehlo.scatter"(%arg259, %39, %1302) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1304 = stablehlo.transpose %1303, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%1305 = stablehlo.reshape %1304 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%1306 = stablehlo.dot_general %1299, %1305, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%1307 = stablehlo.reshape %1306 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%1308 = stablehlo.transpose %1307, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1309 = stablehlo.reshape %1308 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%1310 = stablehlo.transpose %arg119, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1311 = stablehlo.dot_general %1309, %1310, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1312 = stablehlo.reshape %1311 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1313 = stablehlo.add %1237, %1312 : tensor<8x100x4096xf32>
%1314 = stablehlo.power %1313, %cst_3 : tensor<8x100x4096xf32>
%1315 = stablehlo.reduce(%1314 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1316 = stablehlo.multiply %1315, %cst_2 : tensor<8x100xf32>
%1317 = stablehlo.reshape %1316 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1318 = stablehlo.add %1317, %cst_1 : tensor<8x100x1xf32>
%1319 = stablehlo.rsqrt %1318 : tensor<8x100x1xf32>
%1320 = stablehlo.reshape %1319 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1321 = stablehlo.broadcast_in_dim %1320, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1322 = stablehlo.multiply %1313, %1321 : tensor<8x100x4096xf32>
%1323 = stablehlo.broadcast_in_dim %arg118, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1324 = stablehlo.multiply %1322, %1323 : tensor<8x100x4096xf32>
%1325 = stablehlo.reshape %1324 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1326 = stablehlo.transpose %arg263, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1327 = stablehlo.dot_general %1325, %1326, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1328 = stablehlo.reshape %1327 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1329 = stablehlo.logistic %1328 : tensor<8x100x11008xf32>
%1330 = stablehlo.multiply %1328, %1329 : tensor<8x100x11008xf32>
%1331 = stablehlo.transpose %arg117, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1332 = stablehlo.dot_general %1325, %1331, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1333 = stablehlo.reshape %1332 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1334 = stablehlo.multiply %1330, %1333 : tensor<8x100x11008xf32>
%1335 = stablehlo.reshape %1334 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%1336 = stablehlo.transpose %arg116, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%1337 = stablehlo.dot_general %1335, %1336, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%1338 = stablehlo.reshape %1337 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1339 = stablehlo.add %1313, %1338 : tensor<8x100x4096xf32>
%1340 = stablehlo.power %1339, %cst_3 : tensor<8x100x4096xf32>
%1341 = stablehlo.reduce(%1340 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1342 = stablehlo.multiply %1341, %cst_2 : tensor<8x100xf32>
%1343 = stablehlo.reshape %1342 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1344 = stablehlo.add %1343, %cst_1 : tensor<8x100x1xf32>
%1345 = stablehlo.rsqrt %1344 : tensor<8x100x1xf32>
%1346 = stablehlo.reshape %1345 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1347 = stablehlo.broadcast_in_dim %1346, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1348 = stablehlo.multiply %1339, %1347 : tensor<8x100x4096xf32>
%1349 = stablehlo.broadcast_in_dim %arg115, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1350 = stablehlo.multiply %1348, %1349 : tensor<8x100x4096xf32>
%1351 = stablehlo.reshape %1350 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1352 = stablehlo.transpose %arg267, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1353 = stablehlo.dot_general %1351, %1352, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1354 = stablehlo.reshape %1353 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1355 = stablehlo.transpose %1354, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1356 = stablehlo.reshape %1355 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1357 = stablehlo.slice %1356 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1358 = stablehlo.reshape %1357 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1359 = stablehlo.slice %1356 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1360 = stablehlo.reshape %1359 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1361 = stablehlo.complex %1358, %1360 : tensor<256x100x64xcomplex<f32>>
%1362 = stablehlo.multiply %1361, %28 : tensor<256x100x64xcomplex<f32>>
%1363 = stablehlo.real %1362 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1364 = stablehlo.reshape %1363 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1365 = stablehlo.imag %1362 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1366 = stablehlo.reshape %1365 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1367 = stablehlo.concatenate %1364, %1366, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1368 = stablehlo.reshape %1367 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%1369 = stablehlo.transpose %arg265, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1370 = stablehlo.dot_general %1351, %1369, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1371 = stablehlo.reshape %1370 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1372 = stablehlo.transpose %1371, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1373 = stablehlo.reshape %1372 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1374 = stablehlo.slice %1373 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1375 = stablehlo.reshape %1374 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1376 = stablehlo.slice %1373 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1377 = stablehlo.reshape %1376 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1378 = stablehlo.complex %1375, %1377 : tensor<256x100x64xcomplex<f32>>
%1379 = stablehlo.multiply %1378, %28 : tensor<256x100x64xcomplex<f32>>
%1380 = stablehlo.real %1379 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1381 = stablehlo.reshape %1380 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1382 = stablehlo.imag %1379 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1383 = stablehlo.reshape %1382 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1384 = stablehlo.concatenate %1381, %1383, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1385 = stablehlo.reshape %1384 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%1386 = stablehlo.transpose %1385, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1387 = "stablehlo.scatter"(%arg266, %39, %1386) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1388 = stablehlo.transpose %1387, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%1389 = stablehlo.reshape %1388 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%1390 = stablehlo.dot_general %1368, %1389, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%1391 = stablehlo.reshape %1390 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%1392 = stablehlo.divide %1391, %cst : tensor<8x32x100x1024xf32>
%1393 = stablehlo.add %1392, %66 : tensor<8x32x100x1024xf32>
%1394 = stablehlo.reduce(%1393 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1395 = stablehlo.broadcast_in_dim %1394, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1396 = stablehlo.subtract %1393, %1395 : tensor<8x32x100x1024xf32>
%1397 = stablehlo.exponential %1396 : tensor<8x32x100x1024xf32>
%1398 = stablehlo.reduce(%1397 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1399 = stablehlo.broadcast_in_dim %1398, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1400 = stablehlo.divide %1397, %1399 : tensor<8x32x100x1024xf32>
%1401 = stablehlo.reshape %1400 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%1402 = stablehlo.transpose %arg114, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1403 = stablehlo.dot_general %1351, %1402, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1404 = stablehlo.reshape %1403 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1405 = "stablehlo.scatter"(%arg264, %39, %1404) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1406 = stablehlo.transpose %1405, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%1407 = stablehlo.reshape %1406 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%1408 = stablehlo.dot_general %1401, %1407, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%1409 = stablehlo.reshape %1408 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%1410 = stablehlo.transpose %1409, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1411 = stablehlo.reshape %1410 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%1412 = stablehlo.transpose %arg113, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1413 = stablehlo.dot_general %1411, %1412, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1414 = stablehlo.reshape %1413 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1415 = stablehlo.add %1339, %1414 : tensor<8x100x4096xf32>
%1416 = stablehlo.power %1415, %cst_3 : tensor<8x100x4096xf32>
%1417 = stablehlo.reduce(%1416 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1418 = stablehlo.multiply %1417, %cst_2 : tensor<8x100xf32>
%1419 = stablehlo.reshape %1418 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1420 = stablehlo.add %1419, %cst_1 : tensor<8x100x1xf32>
%1421 = stablehlo.rsqrt %1420 : tensor<8x100x1xf32>
%1422 = stablehlo.reshape %1421 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1423 = stablehlo.broadcast_in_dim %1422, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1424 = stablehlo.multiply %1415, %1423 : tensor<8x100x4096xf32>
%1425 = stablehlo.broadcast_in_dim %arg112, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1426 = stablehlo.multiply %1424, %1425 : tensor<8x100x4096xf32>
%1427 = stablehlo.reshape %1426 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1428 = stablehlo.transpose %arg268, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1429 = stablehlo.dot_general %1427, %1428, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1430 = stablehlo.reshape %1429 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1431 = stablehlo.logistic %1430 : tensor<8x100x11008xf32>
%1432 = stablehlo.multiply %1430, %1431 : tensor<8x100x11008xf32>
%1433 = stablehlo.transpose %arg111, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1434 = stablehlo.dot_general %1427, %1433, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1435 = stablehlo.reshape %1434 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1436 = stablehlo.multiply %1432, %1435 : tensor<8x100x11008xf32>
%1437 = stablehlo.reshape %1436 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%1438 = stablehlo.transpose %arg110, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%1439 = stablehlo.dot_general %1437, %1438, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%1440 = stablehlo.reshape %1439 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1441 = stablehlo.add %1415, %1440 : tensor<8x100x4096xf32>
%1442 = stablehlo.power %1441, %cst_3 : tensor<8x100x4096xf32>
%1443 = stablehlo.reduce(%1442 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1444 = stablehlo.multiply %1443, %cst_2 : tensor<8x100xf32>
%1445 = stablehlo.reshape %1444 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1446 = stablehlo.add %1445, %cst_1 : tensor<8x100x1xf32>
%1447 = stablehlo.rsqrt %1446 : tensor<8x100x1xf32>
%1448 = stablehlo.reshape %1447 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1449 = stablehlo.broadcast_in_dim %1448, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1450 = stablehlo.multiply %1441, %1449 : tensor<8x100x4096xf32>
%1451 = stablehlo.broadcast_in_dim %arg109, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1452 = stablehlo.multiply %1450, %1451 : tensor<8x100x4096xf32>
%1453 = stablehlo.reshape %1452 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1454 = stablehlo.transpose %arg272, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1455 = stablehlo.dot_general %1453, %1454, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1456 = stablehlo.reshape %1455 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1457 = stablehlo.transpose %1456, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1458 = stablehlo.reshape %1457 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1459 = stablehlo.slice %1458 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1460 = stablehlo.reshape %1459 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1461 = stablehlo.slice %1458 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1462 = stablehlo.reshape %1461 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1463 = stablehlo.complex %1460, %1462 : tensor<256x100x64xcomplex<f32>>
%1464 = stablehlo.multiply %1463, %28 : tensor<256x100x64xcomplex<f32>>
%1465 = stablehlo.real %1464 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1466 = stablehlo.reshape %1465 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1467 = stablehlo.imag %1464 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1468 = stablehlo.reshape %1467 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1469 = stablehlo.concatenate %1466, %1468, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1470 = stablehlo.reshape %1469 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%1471 = stablehlo.transpose %arg270, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1472 = stablehlo.dot_general %1453, %1471, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1473 = stablehlo.reshape %1472 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1474 = stablehlo.transpose %1473, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1475 = stablehlo.reshape %1474 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1476 = stablehlo.slice %1475 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1477 = stablehlo.reshape %1476 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1478 = stablehlo.slice %1475 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1479 = stablehlo.reshape %1478 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1480 = stablehlo.complex %1477, %1479 : tensor<256x100x64xcomplex<f32>>
%1481 = stablehlo.multiply %1480, %28 : tensor<256x100x64xcomplex<f32>>
%1482 = stablehlo.real %1481 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1483 = stablehlo.reshape %1482 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1484 = stablehlo.imag %1481 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1485 = stablehlo.reshape %1484 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1486 = stablehlo.concatenate %1483, %1485, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1487 = stablehlo.reshape %1486 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%1488 = stablehlo.transpose %1487, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1489 = "stablehlo.scatter"(%arg271, %39, %1488) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1490 = stablehlo.transpose %1489, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%1491 = stablehlo.reshape %1490 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%1492 = stablehlo.dot_general %1470, %1491, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%1493 = stablehlo.reshape %1492 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%1494 = stablehlo.divide %1493, %cst : tensor<8x32x100x1024xf32>
%1495 = stablehlo.add %1494, %66 : tensor<8x32x100x1024xf32>
%1496 = stablehlo.reduce(%1495 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1497 = stablehlo.broadcast_in_dim %1496, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1498 = stablehlo.subtract %1495, %1497 : tensor<8x32x100x1024xf32>
%1499 = stablehlo.exponential %1498 : tensor<8x32x100x1024xf32>
%1500 = stablehlo.reduce(%1499 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1501 = stablehlo.broadcast_in_dim %1500, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1502 = stablehlo.divide %1499, %1501 : tensor<8x32x100x1024xf32>
%1503 = stablehlo.reshape %1502 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%1504 = stablehlo.transpose %arg108, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1505 = stablehlo.dot_general %1453, %1504, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1506 = stablehlo.reshape %1505 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1507 = "stablehlo.scatter"(%arg269, %39, %1506) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1508 = stablehlo.transpose %1507, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%1509 = stablehlo.reshape %1508 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%1510 = stablehlo.dot_general %1503, %1509, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%1511 = stablehlo.reshape %1510 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%1512 = stablehlo.transpose %1511, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1513 = stablehlo.reshape %1512 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%1514 = stablehlo.transpose %arg107, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1515 = stablehlo.dot_general %1513, %1514, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1516 = stablehlo.reshape %1515 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1517 = stablehlo.add %1441, %1516 : tensor<8x100x4096xf32>
%1518 = stablehlo.power %1517, %cst_3 : tensor<8x100x4096xf32>
%1519 = stablehlo.reduce(%1518 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1520 = stablehlo.multiply %1519, %cst_2 : tensor<8x100xf32>
%1521 = stablehlo.reshape %1520 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1522 = stablehlo.add %1521, %cst_1 : tensor<8x100x1xf32>
%1523 = stablehlo.rsqrt %1522 : tensor<8x100x1xf32>
%1524 = stablehlo.reshape %1523 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1525 = stablehlo.broadcast_in_dim %1524, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1526 = stablehlo.multiply %1517, %1525 : tensor<8x100x4096xf32>
%1527 = stablehlo.broadcast_in_dim %arg106, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1528 = stablehlo.multiply %1526, %1527 : tensor<8x100x4096xf32>
%1529 = stablehlo.reshape %1528 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1530 = stablehlo.transpose %arg273, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1531 = stablehlo.dot_general %1529, %1530, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1532 = stablehlo.reshape %1531 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1533 = stablehlo.logistic %1532 : tensor<8x100x11008xf32>
%1534 = stablehlo.multiply %1532, %1533 : tensor<8x100x11008xf32>
%1535 = stablehlo.transpose %arg105, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1536 = stablehlo.dot_general %1529, %1535, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1537 = stablehlo.reshape %1536 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1538 = stablehlo.multiply %1534, %1537 : tensor<8x100x11008xf32>
%1539 = stablehlo.reshape %1538 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%1540 = stablehlo.transpose %arg104, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%1541 = stablehlo.dot_general %1539, %1540, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%1542 = stablehlo.reshape %1541 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1543 = stablehlo.add %1517, %1542 : tensor<8x100x4096xf32>
%1544 = stablehlo.power %1543, %cst_3 : tensor<8x100x4096xf32>
%1545 = stablehlo.reduce(%1544 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1546 = stablehlo.multiply %1545, %cst_2 : tensor<8x100xf32>
%1547 = stablehlo.reshape %1546 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1548 = stablehlo.add %1547, %cst_1 : tensor<8x100x1xf32>
%1549 = stablehlo.rsqrt %1548 : tensor<8x100x1xf32>
%1550 = stablehlo.reshape %1549 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1551 = stablehlo.broadcast_in_dim %1550, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1552 = stablehlo.multiply %1543, %1551 : tensor<8x100x4096xf32>
%1553 = stablehlo.broadcast_in_dim %arg103, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1554 = stablehlo.multiply %1552, %1553 : tensor<8x100x4096xf32>
%1555 = stablehlo.reshape %1554 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1556 = stablehlo.transpose %arg277, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1557 = stablehlo.dot_general %1555, %1556, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1558 = stablehlo.reshape %1557 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1559 = stablehlo.transpose %1558, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1560 = stablehlo.reshape %1559 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1561 = stablehlo.slice %1560 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1562 = stablehlo.reshape %1561 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1563 = stablehlo.slice %1560 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1564 = stablehlo.reshape %1563 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1565 = stablehlo.complex %1562, %1564 : tensor<256x100x64xcomplex<f32>>
%1566 = stablehlo.multiply %1565, %28 : tensor<256x100x64xcomplex<f32>>
%1567 = stablehlo.real %1566 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1568 = stablehlo.reshape %1567 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1569 = stablehlo.imag %1566 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1570 = stablehlo.reshape %1569 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1571 = stablehlo.concatenate %1568, %1570, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1572 = stablehlo.reshape %1571 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%1573 = stablehlo.transpose %arg275, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1574 = stablehlo.dot_general %1555, %1573, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1575 = stablehlo.reshape %1574 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1576 = stablehlo.transpose %1575, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1577 = stablehlo.reshape %1576 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1578 = stablehlo.slice %1577 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1579 = stablehlo.reshape %1578 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1580 = stablehlo.slice %1577 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1581 = stablehlo.reshape %1580 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1582 = stablehlo.complex %1579, %1581 : tensor<256x100x64xcomplex<f32>>
%1583 = stablehlo.multiply %1582, %28 : tensor<256x100x64xcomplex<f32>>
%1584 = stablehlo.real %1583 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1585 = stablehlo.reshape %1584 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1586 = stablehlo.imag %1583 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1587 = stablehlo.reshape %1586 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1588 = stablehlo.concatenate %1585, %1587, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1589 = stablehlo.reshape %1588 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%1590 = stablehlo.transpose %1589, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1591 = "stablehlo.scatter"(%arg276, %39, %1590) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1592 = stablehlo.transpose %1591, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%1593 = stablehlo.reshape %1592 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%1594 = stablehlo.dot_general %1572, %1593, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%1595 = stablehlo.reshape %1594 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%1596 = stablehlo.divide %1595, %cst : tensor<8x32x100x1024xf32>
%1597 = stablehlo.add %1596, %66 : tensor<8x32x100x1024xf32>
%1598 = stablehlo.reduce(%1597 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1599 = stablehlo.broadcast_in_dim %1598, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1600 = stablehlo.subtract %1597, %1599 : tensor<8x32x100x1024xf32>
%1601 = stablehlo.exponential %1600 : tensor<8x32x100x1024xf32>
%1602 = stablehlo.reduce(%1601 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1603 = stablehlo.broadcast_in_dim %1602, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1604 = stablehlo.divide %1601, %1603 : tensor<8x32x100x1024xf32>
%1605 = stablehlo.reshape %1604 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%1606 = stablehlo.transpose %arg102, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1607 = stablehlo.dot_general %1555, %1606, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1608 = stablehlo.reshape %1607 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1609 = "stablehlo.scatter"(%arg274, %39, %1608) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1610 = stablehlo.transpose %1609, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%1611 = stablehlo.reshape %1610 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%1612 = stablehlo.dot_general %1605, %1611, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%1613 = stablehlo.reshape %1612 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%1614 = stablehlo.transpose %1613, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1615 = stablehlo.reshape %1614 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%1616 = stablehlo.transpose %arg101, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1617 = stablehlo.dot_general %1615, %1616, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1618 = stablehlo.reshape %1617 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1619 = stablehlo.add %1543, %1618 : tensor<8x100x4096xf32>
%1620 = stablehlo.power %1619, %cst_3 : tensor<8x100x4096xf32>
%1621 = stablehlo.reduce(%1620 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1622 = stablehlo.multiply %1621, %cst_2 : tensor<8x100xf32>
%1623 = stablehlo.reshape %1622 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1624 = stablehlo.add %1623, %cst_1 : tensor<8x100x1xf32>
%1625 = stablehlo.rsqrt %1624 : tensor<8x100x1xf32>
%1626 = stablehlo.reshape %1625 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1627 = stablehlo.broadcast_in_dim %1626, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1628 = stablehlo.multiply %1619, %1627 : tensor<8x100x4096xf32>
%1629 = stablehlo.broadcast_in_dim %arg100, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1630 = stablehlo.multiply %1628, %1629 : tensor<8x100x4096xf32>
%1631 = stablehlo.reshape %1630 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1632 = stablehlo.transpose %arg278, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1633 = stablehlo.dot_general %1631, %1632, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1634 = stablehlo.reshape %1633 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1635 = stablehlo.logistic %1634 : tensor<8x100x11008xf32>
%1636 = stablehlo.multiply %1634, %1635 : tensor<8x100x11008xf32>
%1637 = stablehlo.transpose %arg99, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1638 = stablehlo.dot_general %1631, %1637, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1639 = stablehlo.reshape %1638 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1640 = stablehlo.multiply %1636, %1639 : tensor<8x100x11008xf32>
%1641 = stablehlo.reshape %1640 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%1642 = stablehlo.transpose %arg98, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%1643 = stablehlo.dot_general %1641, %1642, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%1644 = stablehlo.reshape %1643 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1645 = stablehlo.add %1619, %1644 : tensor<8x100x4096xf32>
%1646 = stablehlo.power %1645, %cst_3 : tensor<8x100x4096xf32>
%1647 = stablehlo.reduce(%1646 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1648 = stablehlo.multiply %1647, %cst_2 : tensor<8x100xf32>
%1649 = stablehlo.reshape %1648 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1650 = stablehlo.add %1649, %cst_1 : tensor<8x100x1xf32>
%1651 = stablehlo.rsqrt %1650 : tensor<8x100x1xf32>
%1652 = stablehlo.reshape %1651 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1653 = stablehlo.broadcast_in_dim %1652, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1654 = stablehlo.multiply %1645, %1653 : tensor<8x100x4096xf32>
%1655 = stablehlo.broadcast_in_dim %arg97, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1656 = stablehlo.multiply %1654, %1655 : tensor<8x100x4096xf32>
%1657 = stablehlo.reshape %1656 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1658 = stablehlo.transpose %arg282, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1659 = stablehlo.dot_general %1657, %1658, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1660 = stablehlo.reshape %1659 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1661 = stablehlo.transpose %1660, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1662 = stablehlo.reshape %1661 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1663 = stablehlo.slice %1662 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1664 = stablehlo.reshape %1663 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1665 = stablehlo.slice %1662 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1666 = stablehlo.reshape %1665 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1667 = stablehlo.complex %1664, %1666 : tensor<256x100x64xcomplex<f32>>
%1668 = stablehlo.multiply %1667, %28 : tensor<256x100x64xcomplex<f32>>
%1669 = stablehlo.real %1668 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1670 = stablehlo.reshape %1669 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1671 = stablehlo.imag %1668 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1672 = stablehlo.reshape %1671 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1673 = stablehlo.concatenate %1670, %1672, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1674 = stablehlo.reshape %1673 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%1675 = stablehlo.transpose %arg280, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1676 = stablehlo.dot_general %1657, %1675, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1677 = stablehlo.reshape %1676 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1678 = stablehlo.transpose %1677, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1679 = stablehlo.reshape %1678 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1680 = stablehlo.slice %1679 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1681 = stablehlo.reshape %1680 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1682 = stablehlo.slice %1679 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1683 = stablehlo.reshape %1682 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1684 = stablehlo.complex %1681, %1683 : tensor<256x100x64xcomplex<f32>>
%1685 = stablehlo.multiply %1684, %28 : tensor<256x100x64xcomplex<f32>>
%1686 = stablehlo.real %1685 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1687 = stablehlo.reshape %1686 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1688 = stablehlo.imag %1685 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1689 = stablehlo.reshape %1688 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1690 = stablehlo.concatenate %1687, %1689, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1691 = stablehlo.reshape %1690 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%1692 = stablehlo.transpose %1691, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1693 = "stablehlo.scatter"(%arg281, %39, %1692) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1694 = stablehlo.transpose %1693, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%1695 = stablehlo.reshape %1694 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%1696 = stablehlo.dot_general %1674, %1695, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%1697 = stablehlo.reshape %1696 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%1698 = stablehlo.divide %1697, %cst : tensor<8x32x100x1024xf32>
%1699 = stablehlo.add %1698, %66 : tensor<8x32x100x1024xf32>
%1700 = stablehlo.reduce(%1699 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1701 = stablehlo.broadcast_in_dim %1700, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1702 = stablehlo.subtract %1699, %1701 : tensor<8x32x100x1024xf32>
%1703 = stablehlo.exponential %1702 : tensor<8x32x100x1024xf32>
%1704 = stablehlo.reduce(%1703 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1705 = stablehlo.broadcast_in_dim %1704, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1706 = stablehlo.divide %1703, %1705 : tensor<8x32x100x1024xf32>
%1707 = stablehlo.reshape %1706 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%1708 = stablehlo.transpose %arg96, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1709 = stablehlo.dot_general %1657, %1708, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1710 = stablehlo.reshape %1709 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1711 = "stablehlo.scatter"(%arg279, %39, %1710) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1712 = stablehlo.transpose %1711, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%1713 = stablehlo.reshape %1712 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%1714 = stablehlo.dot_general %1707, %1713, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%1715 = stablehlo.reshape %1714 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%1716 = stablehlo.transpose %1715, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1717 = stablehlo.reshape %1716 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%1718 = stablehlo.transpose %arg95, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1719 = stablehlo.dot_general %1717, %1718, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1720 = stablehlo.reshape %1719 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1721 = stablehlo.add %1645, %1720 : tensor<8x100x4096xf32>
%1722 = stablehlo.power %1721, %cst_3 : tensor<8x100x4096xf32>
%1723 = stablehlo.reduce(%1722 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1724 = stablehlo.multiply %1723, %cst_2 : tensor<8x100xf32>
%1725 = stablehlo.reshape %1724 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1726 = stablehlo.add %1725, %cst_1 : tensor<8x100x1xf32>
%1727 = stablehlo.rsqrt %1726 : tensor<8x100x1xf32>
%1728 = stablehlo.reshape %1727 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1729 = stablehlo.broadcast_in_dim %1728, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1730 = stablehlo.multiply %1721, %1729 : tensor<8x100x4096xf32>
%1731 = stablehlo.broadcast_in_dim %arg94, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1732 = stablehlo.multiply %1730, %1731 : tensor<8x100x4096xf32>
%1733 = stablehlo.reshape %1732 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1734 = stablehlo.transpose %arg283, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1735 = stablehlo.dot_general %1733, %1734, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1736 = stablehlo.reshape %1735 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1737 = stablehlo.logistic %1736 : tensor<8x100x11008xf32>
%1738 = stablehlo.multiply %1736, %1737 : tensor<8x100x11008xf32>
%1739 = stablehlo.transpose %arg93, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1740 = stablehlo.dot_general %1733, %1739, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1741 = stablehlo.reshape %1740 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1742 = stablehlo.multiply %1738, %1741 : tensor<8x100x11008xf32>
%1743 = stablehlo.reshape %1742 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%1744 = stablehlo.transpose %arg92, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%1745 = stablehlo.dot_general %1743, %1744, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%1746 = stablehlo.reshape %1745 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1747 = stablehlo.add %1721, %1746 : tensor<8x100x4096xf32>
%1748 = stablehlo.power %1747, %cst_3 : tensor<8x100x4096xf32>
%1749 = stablehlo.reduce(%1748 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1750 = stablehlo.multiply %1749, %cst_2 : tensor<8x100xf32>
%1751 = stablehlo.reshape %1750 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1752 = stablehlo.add %1751, %cst_1 : tensor<8x100x1xf32>
%1753 = stablehlo.rsqrt %1752 : tensor<8x100x1xf32>
%1754 = stablehlo.reshape %1753 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1755 = stablehlo.broadcast_in_dim %1754, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1756 = stablehlo.multiply %1747, %1755 : tensor<8x100x4096xf32>
%1757 = stablehlo.broadcast_in_dim %arg91, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1758 = stablehlo.multiply %1756, %1757 : tensor<8x100x4096xf32>
%1759 = stablehlo.reshape %1758 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1760 = stablehlo.transpose %arg287, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1761 = stablehlo.dot_general %1759, %1760, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1762 = stablehlo.reshape %1761 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1763 = stablehlo.transpose %1762, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1764 = stablehlo.reshape %1763 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1765 = stablehlo.slice %1764 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1766 = stablehlo.reshape %1765 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1767 = stablehlo.slice %1764 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1768 = stablehlo.reshape %1767 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1769 = stablehlo.complex %1766, %1768 : tensor<256x100x64xcomplex<f32>>
%1770 = stablehlo.multiply %1769, %28 : tensor<256x100x64xcomplex<f32>>
%1771 = stablehlo.real %1770 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1772 = stablehlo.reshape %1771 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1773 = stablehlo.imag %1770 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1774 = stablehlo.reshape %1773 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1775 = stablehlo.concatenate %1772, %1774, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1776 = stablehlo.reshape %1775 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%1777 = stablehlo.transpose %arg285, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1778 = stablehlo.dot_general %1759, %1777, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1779 = stablehlo.reshape %1778 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1780 = stablehlo.transpose %1779, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1781 = stablehlo.reshape %1780 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1782 = stablehlo.slice %1781 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1783 = stablehlo.reshape %1782 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1784 = stablehlo.slice %1781 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1785 = stablehlo.reshape %1784 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1786 = stablehlo.complex %1783, %1785 : tensor<256x100x64xcomplex<f32>>
%1787 = stablehlo.multiply %1786, %28 : tensor<256x100x64xcomplex<f32>>
%1788 = stablehlo.real %1787 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1789 = stablehlo.reshape %1788 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1790 = stablehlo.imag %1787 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1791 = stablehlo.reshape %1790 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1792 = stablehlo.concatenate %1789, %1791, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1793 = stablehlo.reshape %1792 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%1794 = stablehlo.transpose %1793, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1795 = "stablehlo.scatter"(%arg286, %39, %1794) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1796 = stablehlo.transpose %1795, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%1797 = stablehlo.reshape %1796 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%1798 = stablehlo.dot_general %1776, %1797, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%1799 = stablehlo.reshape %1798 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%1800 = stablehlo.divide %1799, %cst : tensor<8x32x100x1024xf32>
%1801 = stablehlo.add %1800, %66 : tensor<8x32x100x1024xf32>
%1802 = stablehlo.reduce(%1801 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1803 = stablehlo.broadcast_in_dim %1802, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1804 = stablehlo.subtract %1801, %1803 : tensor<8x32x100x1024xf32>
%1805 = stablehlo.exponential %1804 : tensor<8x32x100x1024xf32>
%1806 = stablehlo.reduce(%1805 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1807 = stablehlo.broadcast_in_dim %1806, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1808 = stablehlo.divide %1805, %1807 : tensor<8x32x100x1024xf32>
%1809 = stablehlo.reshape %1808 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%1810 = stablehlo.transpose %arg90, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1811 = stablehlo.dot_general %1759, %1810, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1812 = stablehlo.reshape %1811 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1813 = "stablehlo.scatter"(%arg284, %39, %1812) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1814 = stablehlo.transpose %1813, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%1815 = stablehlo.reshape %1814 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%1816 = stablehlo.dot_general %1809, %1815, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%1817 = stablehlo.reshape %1816 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%1818 = stablehlo.transpose %1817, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1819 = stablehlo.reshape %1818 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%1820 = stablehlo.transpose %arg89, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1821 = stablehlo.dot_general %1819, %1820, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1822 = stablehlo.reshape %1821 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1823 = stablehlo.add %1747, %1822 : tensor<8x100x4096xf32>
%1824 = stablehlo.power %1823, %cst_3 : tensor<8x100x4096xf32>
%1825 = stablehlo.reduce(%1824 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1826 = stablehlo.multiply %1825, %cst_2 : tensor<8x100xf32>
%1827 = stablehlo.reshape %1826 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1828 = stablehlo.add %1827, %cst_1 : tensor<8x100x1xf32>
%1829 = stablehlo.rsqrt %1828 : tensor<8x100x1xf32>
%1830 = stablehlo.reshape %1829 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1831 = stablehlo.broadcast_in_dim %1830, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1832 = stablehlo.multiply %1823, %1831 : tensor<8x100x4096xf32>
%1833 = stablehlo.broadcast_in_dim %arg88, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1834 = stablehlo.multiply %1832, %1833 : tensor<8x100x4096xf32>
%1835 = stablehlo.reshape %1834 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1836 = stablehlo.transpose %arg288, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1837 = stablehlo.dot_general %1835, %1836, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1838 = stablehlo.reshape %1837 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1839 = stablehlo.logistic %1838 : tensor<8x100x11008xf32>
%1840 = stablehlo.multiply %1838, %1839 : tensor<8x100x11008xf32>
%1841 = stablehlo.transpose %arg87, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1842 = stablehlo.dot_general %1835, %1841, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1843 = stablehlo.reshape %1842 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1844 = stablehlo.multiply %1840, %1843 : tensor<8x100x11008xf32>
%1845 = stablehlo.reshape %1844 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%1846 = stablehlo.transpose %arg86, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%1847 = stablehlo.dot_general %1845, %1846, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%1848 = stablehlo.reshape %1847 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1849 = stablehlo.add %1823, %1848 : tensor<8x100x4096xf32>
%1850 = stablehlo.power %1849, %cst_3 : tensor<8x100x4096xf32>
%1851 = stablehlo.reduce(%1850 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1852 = stablehlo.multiply %1851, %cst_2 : tensor<8x100xf32>
%1853 = stablehlo.reshape %1852 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1854 = stablehlo.add %1853, %cst_1 : tensor<8x100x1xf32>
%1855 = stablehlo.rsqrt %1854 : tensor<8x100x1xf32>
%1856 = stablehlo.reshape %1855 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1857 = stablehlo.broadcast_in_dim %1856, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1858 = stablehlo.multiply %1849, %1857 : tensor<8x100x4096xf32>
%1859 = stablehlo.broadcast_in_dim %arg85, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1860 = stablehlo.multiply %1858, %1859 : tensor<8x100x4096xf32>
%1861 = stablehlo.reshape %1860 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1862 = stablehlo.transpose %arg292, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1863 = stablehlo.dot_general %1861, %1862, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1864 = stablehlo.reshape %1863 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1865 = stablehlo.transpose %1864, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1866 = stablehlo.reshape %1865 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1867 = stablehlo.slice %1866 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1868 = stablehlo.reshape %1867 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1869 = stablehlo.slice %1866 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1870 = stablehlo.reshape %1869 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1871 = stablehlo.complex %1868, %1870 : tensor<256x100x64xcomplex<f32>>
%1872 = stablehlo.multiply %1871, %28 : tensor<256x100x64xcomplex<f32>>
%1873 = stablehlo.real %1872 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1874 = stablehlo.reshape %1873 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1875 = stablehlo.imag %1872 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1876 = stablehlo.reshape %1875 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1877 = stablehlo.concatenate %1874, %1876, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1878 = stablehlo.reshape %1877 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%1879 = stablehlo.transpose %arg290, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1880 = stablehlo.dot_general %1861, %1879, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1881 = stablehlo.reshape %1880 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1882 = stablehlo.transpose %1881, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1883 = stablehlo.reshape %1882 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1884 = stablehlo.slice %1883 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1885 = stablehlo.reshape %1884 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1886 = stablehlo.slice %1883 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1887 = stablehlo.reshape %1886 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1888 = stablehlo.complex %1885, %1887 : tensor<256x100x64xcomplex<f32>>
%1889 = stablehlo.multiply %1888, %28 : tensor<256x100x64xcomplex<f32>>
%1890 = stablehlo.real %1889 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1891 = stablehlo.reshape %1890 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1892 = stablehlo.imag %1889 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1893 = stablehlo.reshape %1892 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1894 = stablehlo.concatenate %1891, %1893, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1895 = stablehlo.reshape %1894 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%1896 = stablehlo.transpose %1895, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1897 = "stablehlo.scatter"(%arg291, %39, %1896) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1898 = stablehlo.transpose %1897, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%1899 = stablehlo.reshape %1898 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%1900 = stablehlo.dot_general %1878, %1899, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%1901 = stablehlo.reshape %1900 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%1902 = stablehlo.divide %1901, %cst : tensor<8x32x100x1024xf32>
%1903 = stablehlo.add %1902, %66 : tensor<8x32x100x1024xf32>
%1904 = stablehlo.reduce(%1903 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1905 = stablehlo.broadcast_in_dim %1904, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1906 = stablehlo.subtract %1903, %1905 : tensor<8x32x100x1024xf32>
%1907 = stablehlo.exponential %1906 : tensor<8x32x100x1024xf32>
%1908 = stablehlo.reduce(%1907 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%1909 = stablehlo.broadcast_in_dim %1908, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%1910 = stablehlo.divide %1907, %1909 : tensor<8x32x100x1024xf32>
%1911 = stablehlo.reshape %1910 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%1912 = stablehlo.transpose %arg84, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1913 = stablehlo.dot_general %1861, %1912, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1914 = stablehlo.reshape %1913 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1915 = "stablehlo.scatter"(%arg289, %39, %1914) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%1916 = stablehlo.transpose %1915, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%1917 = stablehlo.reshape %1916 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%1918 = stablehlo.dot_general %1911, %1917, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%1919 = stablehlo.reshape %1918 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%1920 = stablehlo.transpose %1919, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1921 = stablehlo.reshape %1920 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%1922 = stablehlo.transpose %arg83, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1923 = stablehlo.dot_general %1921, %1922, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1924 = stablehlo.reshape %1923 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1925 = stablehlo.add %1849, %1924 : tensor<8x100x4096xf32>
%1926 = stablehlo.power %1925, %cst_3 : tensor<8x100x4096xf32>
%1927 = stablehlo.reduce(%1926 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1928 = stablehlo.multiply %1927, %cst_2 : tensor<8x100xf32>
%1929 = stablehlo.reshape %1928 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1930 = stablehlo.add %1929, %cst_1 : tensor<8x100x1xf32>
%1931 = stablehlo.rsqrt %1930 : tensor<8x100x1xf32>
%1932 = stablehlo.reshape %1931 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1933 = stablehlo.broadcast_in_dim %1932, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1934 = stablehlo.multiply %1925, %1933 : tensor<8x100x4096xf32>
%1935 = stablehlo.broadcast_in_dim %arg82, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1936 = stablehlo.multiply %1934, %1935 : tensor<8x100x4096xf32>
%1937 = stablehlo.reshape %1936 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1938 = stablehlo.transpose %arg293, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1939 = stablehlo.dot_general %1937, %1938, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1940 = stablehlo.reshape %1939 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1941 = stablehlo.logistic %1940 : tensor<8x100x11008xf32>
%1942 = stablehlo.multiply %1940, %1941 : tensor<8x100x11008xf32>
%1943 = stablehlo.transpose %arg81, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%1944 = stablehlo.dot_general %1937, %1943, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%1945 = stablehlo.reshape %1944 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%1946 = stablehlo.multiply %1942, %1945 : tensor<8x100x11008xf32>
%1947 = stablehlo.reshape %1946 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%1948 = stablehlo.transpose %arg80, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%1949 = stablehlo.dot_general %1947, %1948, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%1950 = stablehlo.reshape %1949 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%1951 = stablehlo.add %1925, %1950 : tensor<8x100x4096xf32>
%1952 = stablehlo.power %1951, %cst_3 : tensor<8x100x4096xf32>
%1953 = stablehlo.reduce(%1952 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%1954 = stablehlo.multiply %1953, %cst_2 : tensor<8x100xf32>
%1955 = stablehlo.reshape %1954 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%1956 = stablehlo.add %1955, %cst_1 : tensor<8x100x1xf32>
%1957 = stablehlo.rsqrt %1956 : tensor<8x100x1xf32>
%1958 = stablehlo.reshape %1957 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%1959 = stablehlo.broadcast_in_dim %1958, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%1960 = stablehlo.multiply %1951, %1959 : tensor<8x100x4096xf32>
%1961 = stablehlo.broadcast_in_dim %arg79, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%1962 = stablehlo.multiply %1960, %1961 : tensor<8x100x4096xf32>
%1963 = stablehlo.reshape %1962 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%1964 = stablehlo.transpose %arg297, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1965 = stablehlo.dot_general %1963, %1964, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1966 = stablehlo.reshape %1965 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1967 = stablehlo.transpose %1966, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1968 = stablehlo.reshape %1967 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1969 = stablehlo.slice %1968 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1970 = stablehlo.reshape %1969 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1971 = stablehlo.slice %1968 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1972 = stablehlo.reshape %1971 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1973 = stablehlo.complex %1970, %1972 : tensor<256x100x64xcomplex<f32>>
%1974 = stablehlo.multiply %1973, %28 : tensor<256x100x64xcomplex<f32>>
%1975 = stablehlo.real %1974 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1976 = stablehlo.reshape %1975 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1977 = stablehlo.imag %1974 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1978 = stablehlo.reshape %1977 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1979 = stablehlo.concatenate %1976, %1978, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1980 = stablehlo.reshape %1979 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%1981 = stablehlo.transpose %arg295, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%1982 = stablehlo.dot_general %1963, %1981, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%1983 = stablehlo.reshape %1982 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%1984 = stablehlo.transpose %1983, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%1985 = stablehlo.reshape %1984 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%1986 = stablehlo.slice %1985 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1987 = stablehlo.reshape %1986 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1988 = stablehlo.slice %1985 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%1989 = stablehlo.reshape %1988 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%1990 = stablehlo.complex %1987, %1989 : tensor<256x100x64xcomplex<f32>>
%1991 = stablehlo.multiply %1990, %28 : tensor<256x100x64xcomplex<f32>>
%1992 = stablehlo.real %1991 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1993 = stablehlo.reshape %1992 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1994 = stablehlo.imag %1991 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%1995 = stablehlo.reshape %1994 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%1996 = stablehlo.concatenate %1993, %1995, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%1997 = stablehlo.reshape %1996 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%1998 = stablehlo.transpose %1997, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%1999 = "stablehlo.scatter"(%arg296, %39, %1998) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2000 = stablehlo.transpose %1999, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%2001 = stablehlo.reshape %2000 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%2002 = stablehlo.dot_general %1980, %2001, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%2003 = stablehlo.reshape %2002 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%2004 = stablehlo.divide %2003, %cst : tensor<8x32x100x1024xf32>
%2005 = stablehlo.add %2004, %66 : tensor<8x32x100x1024xf32>
%2006 = stablehlo.reduce(%2005 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2007 = stablehlo.broadcast_in_dim %2006, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2008 = stablehlo.subtract %2005, %2007 : tensor<8x32x100x1024xf32>
%2009 = stablehlo.exponential %2008 : tensor<8x32x100x1024xf32>
%2010 = stablehlo.reduce(%2009 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2011 = stablehlo.broadcast_in_dim %2010, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2012 = stablehlo.divide %2009, %2011 : tensor<8x32x100x1024xf32>
%2013 = stablehlo.reshape %2012 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%2014 = stablehlo.transpose %arg78, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2015 = stablehlo.dot_general %1963, %2014, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2016 = stablehlo.reshape %2015 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2017 = "stablehlo.scatter"(%arg294, %39, %2016) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2018 = stablehlo.transpose %2017, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%2019 = stablehlo.reshape %2018 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%2020 = stablehlo.dot_general %2013, %2019, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%2021 = stablehlo.reshape %2020 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%2022 = stablehlo.transpose %2021, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2023 = stablehlo.reshape %2022 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%2024 = stablehlo.transpose %arg77, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2025 = stablehlo.dot_general %2023, %2024, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2026 = stablehlo.reshape %2025 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2027 = stablehlo.add %1951, %2026 : tensor<8x100x4096xf32>
%2028 = stablehlo.power %2027, %cst_3 : tensor<8x100x4096xf32>
%2029 = stablehlo.reduce(%2028 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2030 = stablehlo.multiply %2029, %cst_2 : tensor<8x100xf32>
%2031 = stablehlo.reshape %2030 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2032 = stablehlo.add %2031, %cst_1 : tensor<8x100x1xf32>
%2033 = stablehlo.rsqrt %2032 : tensor<8x100x1xf32>
%2034 = stablehlo.reshape %2033 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2035 = stablehlo.broadcast_in_dim %2034, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2036 = stablehlo.multiply %2027, %2035 : tensor<8x100x4096xf32>
%2037 = stablehlo.broadcast_in_dim %arg76, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2038 = stablehlo.multiply %2036, %2037 : tensor<8x100x4096xf32>
%2039 = stablehlo.reshape %2038 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2040 = stablehlo.transpose %arg298, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2041 = stablehlo.dot_general %2039, %2040, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2042 = stablehlo.reshape %2041 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2043 = stablehlo.logistic %2042 : tensor<8x100x11008xf32>
%2044 = stablehlo.multiply %2042, %2043 : tensor<8x100x11008xf32>
%2045 = stablehlo.transpose %arg75, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2046 = stablehlo.dot_general %2039, %2045, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2047 = stablehlo.reshape %2046 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2048 = stablehlo.multiply %2044, %2047 : tensor<8x100x11008xf32>
%2049 = stablehlo.reshape %2048 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%2050 = stablehlo.transpose %arg74, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%2051 = stablehlo.dot_general %2049, %2050, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%2052 = stablehlo.reshape %2051 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2053 = stablehlo.add %2027, %2052 : tensor<8x100x4096xf32>
%2054 = stablehlo.power %2053, %cst_3 : tensor<8x100x4096xf32>
%2055 = stablehlo.reduce(%2054 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2056 = stablehlo.multiply %2055, %cst_2 : tensor<8x100xf32>
%2057 = stablehlo.reshape %2056 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2058 = stablehlo.add %2057, %cst_1 : tensor<8x100x1xf32>
%2059 = stablehlo.rsqrt %2058 : tensor<8x100x1xf32>
%2060 = stablehlo.reshape %2059 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2061 = stablehlo.broadcast_in_dim %2060, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2062 = stablehlo.multiply %2053, %2061 : tensor<8x100x4096xf32>
%2063 = stablehlo.broadcast_in_dim %arg73, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2064 = stablehlo.multiply %2062, %2063 : tensor<8x100x4096xf32>
%2065 = stablehlo.reshape %2064 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2066 = stablehlo.transpose %arg302, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2067 = stablehlo.dot_general %2065, %2066, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2068 = stablehlo.reshape %2067 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2069 = stablehlo.transpose %2068, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2070 = stablehlo.reshape %2069 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2071 = stablehlo.slice %2070 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2072 = stablehlo.reshape %2071 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2073 = stablehlo.slice %2070 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2074 = stablehlo.reshape %2073 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2075 = stablehlo.complex %2072, %2074 : tensor<256x100x64xcomplex<f32>>
%2076 = stablehlo.multiply %2075, %28 : tensor<256x100x64xcomplex<f32>>
%2077 = stablehlo.real %2076 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2078 = stablehlo.reshape %2077 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2079 = stablehlo.imag %2076 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2080 = stablehlo.reshape %2079 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2081 = stablehlo.concatenate %2078, %2080, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2082 = stablehlo.reshape %2081 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%2083 = stablehlo.transpose %arg300, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2084 = stablehlo.dot_general %2065, %2083, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2085 = stablehlo.reshape %2084 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2086 = stablehlo.transpose %2085, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2087 = stablehlo.reshape %2086 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2088 = stablehlo.slice %2087 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2089 = stablehlo.reshape %2088 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2090 = stablehlo.slice %2087 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2091 = stablehlo.reshape %2090 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2092 = stablehlo.complex %2089, %2091 : tensor<256x100x64xcomplex<f32>>
%2093 = stablehlo.multiply %2092, %28 : tensor<256x100x64xcomplex<f32>>
%2094 = stablehlo.real %2093 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2095 = stablehlo.reshape %2094 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2096 = stablehlo.imag %2093 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2097 = stablehlo.reshape %2096 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2098 = stablehlo.concatenate %2095, %2097, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2099 = stablehlo.reshape %2098 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%2100 = stablehlo.transpose %2099, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2101 = "stablehlo.scatter"(%arg301, %39, %2100) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2102 = stablehlo.transpose %2101, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%2103 = stablehlo.reshape %2102 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%2104 = stablehlo.dot_general %2082, %2103, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%2105 = stablehlo.reshape %2104 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%2106 = stablehlo.divide %2105, %cst : tensor<8x32x100x1024xf32>
%2107 = stablehlo.add %2106, %66 : tensor<8x32x100x1024xf32>
%2108 = stablehlo.reduce(%2107 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2109 = stablehlo.broadcast_in_dim %2108, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2110 = stablehlo.subtract %2107, %2109 : tensor<8x32x100x1024xf32>
%2111 = stablehlo.exponential %2110 : tensor<8x32x100x1024xf32>
%2112 = stablehlo.reduce(%2111 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2113 = stablehlo.broadcast_in_dim %2112, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2114 = stablehlo.divide %2111, %2113 : tensor<8x32x100x1024xf32>
%2115 = stablehlo.reshape %2114 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%2116 = stablehlo.transpose %arg72, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2117 = stablehlo.dot_general %2065, %2116, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2118 = stablehlo.reshape %2117 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2119 = "stablehlo.scatter"(%arg299, %39, %2118) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2120 = stablehlo.transpose %2119, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%2121 = stablehlo.reshape %2120 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%2122 = stablehlo.dot_general %2115, %2121, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%2123 = stablehlo.reshape %2122 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%2124 = stablehlo.transpose %2123, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2125 = stablehlo.reshape %2124 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%2126 = stablehlo.transpose %arg71, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2127 = stablehlo.dot_general %2125, %2126, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2128 = stablehlo.reshape %2127 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2129 = stablehlo.add %2053, %2128 : tensor<8x100x4096xf32>
%2130 = stablehlo.power %2129, %cst_3 : tensor<8x100x4096xf32>
%2131 = stablehlo.reduce(%2130 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2132 = stablehlo.multiply %2131, %cst_2 : tensor<8x100xf32>
%2133 = stablehlo.reshape %2132 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2134 = stablehlo.add %2133, %cst_1 : tensor<8x100x1xf32>
%2135 = stablehlo.rsqrt %2134 : tensor<8x100x1xf32>
%2136 = stablehlo.reshape %2135 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2137 = stablehlo.broadcast_in_dim %2136, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2138 = stablehlo.multiply %2129, %2137 : tensor<8x100x4096xf32>
%2139 = stablehlo.broadcast_in_dim %arg70, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2140 = stablehlo.multiply %2138, %2139 : tensor<8x100x4096xf32>
%2141 = stablehlo.reshape %2140 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2142 = stablehlo.transpose %arg303, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2143 = stablehlo.dot_general %2141, %2142, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2144 = stablehlo.reshape %2143 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2145 = stablehlo.logistic %2144 : tensor<8x100x11008xf32>
%2146 = stablehlo.multiply %2144, %2145 : tensor<8x100x11008xf32>
%2147 = stablehlo.transpose %arg69, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2148 = stablehlo.dot_general %2141, %2147, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2149 = stablehlo.reshape %2148 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2150 = stablehlo.multiply %2146, %2149 : tensor<8x100x11008xf32>
%2151 = stablehlo.reshape %2150 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%2152 = stablehlo.transpose %arg68, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%2153 = stablehlo.dot_general %2151, %2152, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%2154 = stablehlo.reshape %2153 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2155 = stablehlo.add %2129, %2154 : tensor<8x100x4096xf32>
%2156 = stablehlo.power %2155, %cst_3 : tensor<8x100x4096xf32>
%2157 = stablehlo.reduce(%2156 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2158 = stablehlo.multiply %2157, %cst_2 : tensor<8x100xf32>
%2159 = stablehlo.reshape %2158 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2160 = stablehlo.add %2159, %cst_1 : tensor<8x100x1xf32>
%2161 = stablehlo.rsqrt %2160 : tensor<8x100x1xf32>
%2162 = stablehlo.reshape %2161 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2163 = stablehlo.broadcast_in_dim %2162, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2164 = stablehlo.multiply %2155, %2163 : tensor<8x100x4096xf32>
%2165 = stablehlo.broadcast_in_dim %arg67, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2166 = stablehlo.multiply %2164, %2165 : tensor<8x100x4096xf32>
%2167 = stablehlo.reshape %2166 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2168 = stablehlo.transpose %arg307, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2169 = stablehlo.dot_general %2167, %2168, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2170 = stablehlo.reshape %2169 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2171 = stablehlo.transpose %2170, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2172 = stablehlo.reshape %2171 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2173 = stablehlo.slice %2172 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2174 = stablehlo.reshape %2173 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2175 = stablehlo.slice %2172 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2176 = stablehlo.reshape %2175 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2177 = stablehlo.complex %2174, %2176 : tensor<256x100x64xcomplex<f32>>
%2178 = stablehlo.multiply %2177, %28 : tensor<256x100x64xcomplex<f32>>
%2179 = stablehlo.real %2178 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2180 = stablehlo.reshape %2179 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2181 = stablehlo.imag %2178 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2182 = stablehlo.reshape %2181 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2183 = stablehlo.concatenate %2180, %2182, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2184 = stablehlo.reshape %2183 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%2185 = stablehlo.transpose %arg305, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2186 = stablehlo.dot_general %2167, %2185, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2187 = stablehlo.reshape %2186 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2188 = stablehlo.transpose %2187, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2189 = stablehlo.reshape %2188 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2190 = stablehlo.slice %2189 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2191 = stablehlo.reshape %2190 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2192 = stablehlo.slice %2189 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2193 = stablehlo.reshape %2192 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2194 = stablehlo.complex %2191, %2193 : tensor<256x100x64xcomplex<f32>>
%2195 = stablehlo.multiply %2194, %28 : tensor<256x100x64xcomplex<f32>>
%2196 = stablehlo.real %2195 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2197 = stablehlo.reshape %2196 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2198 = stablehlo.imag %2195 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2199 = stablehlo.reshape %2198 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2200 = stablehlo.concatenate %2197, %2199, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2201 = stablehlo.reshape %2200 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%2202 = stablehlo.transpose %2201, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2203 = "stablehlo.scatter"(%arg306, %39, %2202) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2204 = stablehlo.transpose %2203, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%2205 = stablehlo.reshape %2204 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%2206 = stablehlo.dot_general %2184, %2205, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%2207 = stablehlo.reshape %2206 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%2208 = stablehlo.divide %2207, %cst : tensor<8x32x100x1024xf32>
%2209 = stablehlo.add %2208, %66 : tensor<8x32x100x1024xf32>
%2210 = stablehlo.reduce(%2209 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2211 = stablehlo.broadcast_in_dim %2210, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2212 = stablehlo.subtract %2209, %2211 : tensor<8x32x100x1024xf32>
%2213 = stablehlo.exponential %2212 : tensor<8x32x100x1024xf32>
%2214 = stablehlo.reduce(%2213 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2215 = stablehlo.broadcast_in_dim %2214, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2216 = stablehlo.divide %2213, %2215 : tensor<8x32x100x1024xf32>
%2217 = stablehlo.reshape %2216 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%2218 = stablehlo.transpose %arg66, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2219 = stablehlo.dot_general %2167, %2218, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2220 = stablehlo.reshape %2219 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2221 = "stablehlo.scatter"(%arg304, %39, %2220) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2222 = stablehlo.transpose %2221, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%2223 = stablehlo.reshape %2222 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%2224 = stablehlo.dot_general %2217, %2223, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%2225 = stablehlo.reshape %2224 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%2226 = stablehlo.transpose %2225, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2227 = stablehlo.reshape %2226 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%2228 = stablehlo.transpose %arg65, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2229 = stablehlo.dot_general %2227, %2228, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2230 = stablehlo.reshape %2229 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2231 = stablehlo.add %2155, %2230 : tensor<8x100x4096xf32>
%2232 = stablehlo.power %2231, %cst_3 : tensor<8x100x4096xf32>
%2233 = stablehlo.reduce(%2232 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2234 = stablehlo.multiply %2233, %cst_2 : tensor<8x100xf32>
%2235 = stablehlo.reshape %2234 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2236 = stablehlo.add %2235, %cst_1 : tensor<8x100x1xf32>
%2237 = stablehlo.rsqrt %2236 : tensor<8x100x1xf32>
%2238 = stablehlo.reshape %2237 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2239 = stablehlo.broadcast_in_dim %2238, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2240 = stablehlo.multiply %2231, %2239 : tensor<8x100x4096xf32>
%2241 = stablehlo.broadcast_in_dim %arg64, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2242 = stablehlo.multiply %2240, %2241 : tensor<8x100x4096xf32>
%2243 = stablehlo.reshape %2242 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2244 = stablehlo.transpose %arg308, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2245 = stablehlo.dot_general %2243, %2244, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2246 = stablehlo.reshape %2245 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2247 = stablehlo.logistic %2246 : tensor<8x100x11008xf32>
%2248 = stablehlo.multiply %2246, %2247 : tensor<8x100x11008xf32>
%2249 = stablehlo.transpose %arg63, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2250 = stablehlo.dot_general %2243, %2249, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2251 = stablehlo.reshape %2250 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2252 = stablehlo.multiply %2248, %2251 : tensor<8x100x11008xf32>
%2253 = stablehlo.reshape %2252 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%2254 = stablehlo.transpose %arg62, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%2255 = stablehlo.dot_general %2253, %2254, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%2256 = stablehlo.reshape %2255 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2257 = stablehlo.add %2231, %2256 : tensor<8x100x4096xf32>
%2258 = stablehlo.power %2257, %cst_3 : tensor<8x100x4096xf32>
%2259 = stablehlo.reduce(%2258 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2260 = stablehlo.multiply %2259, %cst_2 : tensor<8x100xf32>
%2261 = stablehlo.reshape %2260 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2262 = stablehlo.add %2261, %cst_1 : tensor<8x100x1xf32>
%2263 = stablehlo.rsqrt %2262 : tensor<8x100x1xf32>
%2264 = stablehlo.reshape %2263 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2265 = stablehlo.broadcast_in_dim %2264, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2266 = stablehlo.multiply %2257, %2265 : tensor<8x100x4096xf32>
%2267 = stablehlo.broadcast_in_dim %arg61, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2268 = stablehlo.multiply %2266, %2267 : tensor<8x100x4096xf32>
%2269 = stablehlo.reshape %2268 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2270 = stablehlo.transpose %arg312, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2271 = stablehlo.dot_general %2269, %2270, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2272 = stablehlo.reshape %2271 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2273 = stablehlo.transpose %2272, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2274 = stablehlo.reshape %2273 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2275 = stablehlo.slice %2274 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2276 = stablehlo.reshape %2275 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2277 = stablehlo.slice %2274 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2278 = stablehlo.reshape %2277 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2279 = stablehlo.complex %2276, %2278 : tensor<256x100x64xcomplex<f32>>
%2280 = stablehlo.multiply %2279, %28 : tensor<256x100x64xcomplex<f32>>
%2281 = stablehlo.real %2280 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2282 = stablehlo.reshape %2281 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2283 = stablehlo.imag %2280 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2284 = stablehlo.reshape %2283 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2285 = stablehlo.concatenate %2282, %2284, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2286 = stablehlo.reshape %2285 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%2287 = stablehlo.transpose %arg310, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2288 = stablehlo.dot_general %2269, %2287, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2289 = stablehlo.reshape %2288 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2290 = stablehlo.transpose %2289, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2291 = stablehlo.reshape %2290 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2292 = stablehlo.slice %2291 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2293 = stablehlo.reshape %2292 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2294 = stablehlo.slice %2291 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2295 = stablehlo.reshape %2294 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2296 = stablehlo.complex %2293, %2295 : tensor<256x100x64xcomplex<f32>>
%2297 = stablehlo.multiply %2296, %28 : tensor<256x100x64xcomplex<f32>>
%2298 = stablehlo.real %2297 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2299 = stablehlo.reshape %2298 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2300 = stablehlo.imag %2297 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2301 = stablehlo.reshape %2300 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2302 = stablehlo.concatenate %2299, %2301, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2303 = stablehlo.reshape %2302 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%2304 = stablehlo.transpose %2303, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2305 = "stablehlo.scatter"(%arg311, %39, %2304) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2306 = stablehlo.transpose %2305, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%2307 = stablehlo.reshape %2306 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%2308 = stablehlo.dot_general %2286, %2307, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%2309 = stablehlo.reshape %2308 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%2310 = stablehlo.divide %2309, %cst : tensor<8x32x100x1024xf32>
%2311 = stablehlo.add %2310, %66 : tensor<8x32x100x1024xf32>
%2312 = stablehlo.reduce(%2311 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2313 = stablehlo.broadcast_in_dim %2312, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2314 = stablehlo.subtract %2311, %2313 : tensor<8x32x100x1024xf32>
%2315 = stablehlo.exponential %2314 : tensor<8x32x100x1024xf32>
%2316 = stablehlo.reduce(%2315 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2317 = stablehlo.broadcast_in_dim %2316, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2318 = stablehlo.divide %2315, %2317 : tensor<8x32x100x1024xf32>
%2319 = stablehlo.reshape %2318 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%2320 = stablehlo.transpose %arg60, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2321 = stablehlo.dot_general %2269, %2320, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2322 = stablehlo.reshape %2321 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2323 = "stablehlo.scatter"(%arg309, %39, %2322) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2324 = stablehlo.transpose %2323, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%2325 = stablehlo.reshape %2324 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%2326 = stablehlo.dot_general %2319, %2325, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%2327 = stablehlo.reshape %2326 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%2328 = stablehlo.transpose %2327, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2329 = stablehlo.reshape %2328 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%2330 = stablehlo.transpose %arg59, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2331 = stablehlo.dot_general %2329, %2330, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2332 = stablehlo.reshape %2331 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2333 = stablehlo.add %2257, %2332 : tensor<8x100x4096xf32>
%2334 = stablehlo.power %2333, %cst_3 : tensor<8x100x4096xf32>
%2335 = stablehlo.reduce(%2334 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2336 = stablehlo.multiply %2335, %cst_2 : tensor<8x100xf32>
%2337 = stablehlo.reshape %2336 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2338 = stablehlo.add %2337, %cst_1 : tensor<8x100x1xf32>
%2339 = stablehlo.rsqrt %2338 : tensor<8x100x1xf32>
%2340 = stablehlo.reshape %2339 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2341 = stablehlo.broadcast_in_dim %2340, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2342 = stablehlo.multiply %2333, %2341 : tensor<8x100x4096xf32>
%2343 = stablehlo.broadcast_in_dim %arg58, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2344 = stablehlo.multiply %2342, %2343 : tensor<8x100x4096xf32>
%2345 = stablehlo.reshape %2344 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2346 = stablehlo.transpose %arg313, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2347 = stablehlo.dot_general %2345, %2346, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2348 = stablehlo.reshape %2347 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2349 = stablehlo.logistic %2348 : tensor<8x100x11008xf32>
%2350 = stablehlo.multiply %2348, %2349 : tensor<8x100x11008xf32>
%2351 = stablehlo.transpose %arg57, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2352 = stablehlo.dot_general %2345, %2351, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2353 = stablehlo.reshape %2352 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2354 = stablehlo.multiply %2350, %2353 : tensor<8x100x11008xf32>
%2355 = stablehlo.reshape %2354 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%2356 = stablehlo.transpose %arg56, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%2357 = stablehlo.dot_general %2355, %2356, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%2358 = stablehlo.reshape %2357 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2359 = stablehlo.add %2333, %2358 : tensor<8x100x4096xf32>
%2360 = stablehlo.power %2359, %cst_3 : tensor<8x100x4096xf32>
%2361 = stablehlo.reduce(%2360 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2362 = stablehlo.multiply %2361, %cst_2 : tensor<8x100xf32>
%2363 = stablehlo.reshape %2362 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2364 = stablehlo.add %2363, %cst_1 : tensor<8x100x1xf32>
%2365 = stablehlo.rsqrt %2364 : tensor<8x100x1xf32>
%2366 = stablehlo.reshape %2365 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2367 = stablehlo.broadcast_in_dim %2366, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2368 = stablehlo.multiply %2359, %2367 : tensor<8x100x4096xf32>
%2369 = stablehlo.broadcast_in_dim %arg55, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2370 = stablehlo.multiply %2368, %2369 : tensor<8x100x4096xf32>
%2371 = stablehlo.reshape %2370 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2372 = stablehlo.transpose %arg317, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2373 = stablehlo.dot_general %2371, %2372, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2374 = stablehlo.reshape %2373 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2375 = stablehlo.transpose %2374, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2376 = stablehlo.reshape %2375 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2377 = stablehlo.slice %2376 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2378 = stablehlo.reshape %2377 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2379 = stablehlo.slice %2376 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2380 = stablehlo.reshape %2379 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2381 = stablehlo.complex %2378, %2380 : tensor<256x100x64xcomplex<f32>>
%2382 = stablehlo.multiply %2381, %28 : tensor<256x100x64xcomplex<f32>>
%2383 = stablehlo.real %2382 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2384 = stablehlo.reshape %2383 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2385 = stablehlo.imag %2382 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2386 = stablehlo.reshape %2385 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2387 = stablehlo.concatenate %2384, %2386, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2388 = stablehlo.reshape %2387 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%2389 = stablehlo.transpose %arg315, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2390 = stablehlo.dot_general %2371, %2389, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2391 = stablehlo.reshape %2390 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2392 = stablehlo.transpose %2391, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2393 = stablehlo.reshape %2392 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2394 = stablehlo.slice %2393 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2395 = stablehlo.reshape %2394 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2396 = stablehlo.slice %2393 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2397 = stablehlo.reshape %2396 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2398 = stablehlo.complex %2395, %2397 : tensor<256x100x64xcomplex<f32>>
%2399 = stablehlo.multiply %2398, %28 : tensor<256x100x64xcomplex<f32>>
%2400 = stablehlo.real %2399 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2401 = stablehlo.reshape %2400 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2402 = stablehlo.imag %2399 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2403 = stablehlo.reshape %2402 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2404 = stablehlo.concatenate %2401, %2403, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2405 = stablehlo.reshape %2404 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%2406 = stablehlo.transpose %2405, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2407 = "stablehlo.scatter"(%arg316, %39, %2406) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2408 = stablehlo.transpose %2407, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%2409 = stablehlo.reshape %2408 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%2410 = stablehlo.dot_general %2388, %2409, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%2411 = stablehlo.reshape %2410 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%2412 = stablehlo.divide %2411, %cst : tensor<8x32x100x1024xf32>
%2413 = stablehlo.add %2412, %66 : tensor<8x32x100x1024xf32>
%2414 = stablehlo.reduce(%2413 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2415 = stablehlo.broadcast_in_dim %2414, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2416 = stablehlo.subtract %2413, %2415 : tensor<8x32x100x1024xf32>
%2417 = stablehlo.exponential %2416 : tensor<8x32x100x1024xf32>
%2418 = stablehlo.reduce(%2417 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2419 = stablehlo.broadcast_in_dim %2418, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2420 = stablehlo.divide %2417, %2419 : tensor<8x32x100x1024xf32>
%2421 = stablehlo.reshape %2420 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%2422 = stablehlo.transpose %arg54, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2423 = stablehlo.dot_general %2371, %2422, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2424 = stablehlo.reshape %2423 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2425 = "stablehlo.scatter"(%arg314, %39, %2424) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2426 = stablehlo.transpose %2425, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%2427 = stablehlo.reshape %2426 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%2428 = stablehlo.dot_general %2421, %2427, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%2429 = stablehlo.reshape %2428 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%2430 = stablehlo.transpose %2429, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2431 = stablehlo.reshape %2430 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%2432 = stablehlo.transpose %arg53, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2433 = stablehlo.dot_general %2431, %2432, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2434 = stablehlo.reshape %2433 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2435 = stablehlo.add %2359, %2434 : tensor<8x100x4096xf32>
%2436 = stablehlo.power %2435, %cst_3 : tensor<8x100x4096xf32>
%2437 = stablehlo.reduce(%2436 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2438 = stablehlo.multiply %2437, %cst_2 : tensor<8x100xf32>
%2439 = stablehlo.reshape %2438 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2440 = stablehlo.add %2439, %cst_1 : tensor<8x100x1xf32>
%2441 = stablehlo.rsqrt %2440 : tensor<8x100x1xf32>
%2442 = stablehlo.reshape %2441 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2443 = stablehlo.broadcast_in_dim %2442, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2444 = stablehlo.multiply %2435, %2443 : tensor<8x100x4096xf32>
%2445 = stablehlo.broadcast_in_dim %arg52, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2446 = stablehlo.multiply %2444, %2445 : tensor<8x100x4096xf32>
%2447 = stablehlo.reshape %2446 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2448 = stablehlo.transpose %arg318, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2449 = stablehlo.dot_general %2447, %2448, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2450 = stablehlo.reshape %2449 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2451 = stablehlo.logistic %2450 : tensor<8x100x11008xf32>
%2452 = stablehlo.multiply %2450, %2451 : tensor<8x100x11008xf32>
%2453 = stablehlo.transpose %arg51, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2454 = stablehlo.dot_general %2447, %2453, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2455 = stablehlo.reshape %2454 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2456 = stablehlo.multiply %2452, %2455 : tensor<8x100x11008xf32>
%2457 = stablehlo.reshape %2456 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%2458 = stablehlo.transpose %arg50, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%2459 = stablehlo.dot_general %2457, %2458, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%2460 = stablehlo.reshape %2459 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2461 = stablehlo.add %2435, %2460 : tensor<8x100x4096xf32>
%2462 = stablehlo.power %2461, %cst_3 : tensor<8x100x4096xf32>
%2463 = stablehlo.reduce(%2462 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2464 = stablehlo.multiply %2463, %cst_2 : tensor<8x100xf32>
%2465 = stablehlo.reshape %2464 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2466 = stablehlo.add %2465, %cst_1 : tensor<8x100x1xf32>
%2467 = stablehlo.rsqrt %2466 : tensor<8x100x1xf32>
%2468 = stablehlo.reshape %2467 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2469 = stablehlo.broadcast_in_dim %2468, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2470 = stablehlo.multiply %2461, %2469 : tensor<8x100x4096xf32>
%2471 = stablehlo.broadcast_in_dim %arg49, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2472 = stablehlo.multiply %2470, %2471 : tensor<8x100x4096xf32>
%2473 = stablehlo.reshape %2472 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2474 = stablehlo.transpose %arg322, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2475 = stablehlo.dot_general %2473, %2474, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2476 = stablehlo.reshape %2475 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2477 = stablehlo.transpose %2476, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2478 = stablehlo.reshape %2477 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2479 = stablehlo.slice %2478 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2480 = stablehlo.reshape %2479 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2481 = stablehlo.slice %2478 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2482 = stablehlo.reshape %2481 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2483 = stablehlo.complex %2480, %2482 : tensor<256x100x64xcomplex<f32>>
%2484 = stablehlo.multiply %2483, %28 : tensor<256x100x64xcomplex<f32>>
%2485 = stablehlo.real %2484 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2486 = stablehlo.reshape %2485 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2487 = stablehlo.imag %2484 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2488 = stablehlo.reshape %2487 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2489 = stablehlo.concatenate %2486, %2488, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2490 = stablehlo.reshape %2489 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%2491 = stablehlo.transpose %arg320, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2492 = stablehlo.dot_general %2473, %2491, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2493 = stablehlo.reshape %2492 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2494 = stablehlo.transpose %2493, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2495 = stablehlo.reshape %2494 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2496 = stablehlo.slice %2495 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2497 = stablehlo.reshape %2496 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2498 = stablehlo.slice %2495 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2499 = stablehlo.reshape %2498 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2500 = stablehlo.complex %2497, %2499 : tensor<256x100x64xcomplex<f32>>
%2501 = stablehlo.multiply %2500, %28 : tensor<256x100x64xcomplex<f32>>
%2502 = stablehlo.real %2501 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2503 = stablehlo.reshape %2502 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2504 = stablehlo.imag %2501 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2505 = stablehlo.reshape %2504 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2506 = stablehlo.concatenate %2503, %2505, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2507 = stablehlo.reshape %2506 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%2508 = stablehlo.transpose %2507, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2509 = "stablehlo.scatter"(%arg321, %39, %2508) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2510 = stablehlo.transpose %2509, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%2511 = stablehlo.reshape %2510 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%2512 = stablehlo.dot_general %2490, %2511, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%2513 = stablehlo.reshape %2512 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%2514 = stablehlo.divide %2513, %cst : tensor<8x32x100x1024xf32>
%2515 = stablehlo.add %2514, %66 : tensor<8x32x100x1024xf32>
%2516 = stablehlo.reduce(%2515 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2517 = stablehlo.broadcast_in_dim %2516, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2518 = stablehlo.subtract %2515, %2517 : tensor<8x32x100x1024xf32>
%2519 = stablehlo.exponential %2518 : tensor<8x32x100x1024xf32>
%2520 = stablehlo.reduce(%2519 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2521 = stablehlo.broadcast_in_dim %2520, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2522 = stablehlo.divide %2519, %2521 : tensor<8x32x100x1024xf32>
%2523 = stablehlo.reshape %2522 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%2524 = stablehlo.transpose %arg48, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2525 = stablehlo.dot_general %2473, %2524, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2526 = stablehlo.reshape %2525 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2527 = "stablehlo.scatter"(%arg319, %39, %2526) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2528 = stablehlo.transpose %2527, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%2529 = stablehlo.reshape %2528 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%2530 = stablehlo.dot_general %2523, %2529, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%2531 = stablehlo.reshape %2530 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%2532 = stablehlo.transpose %2531, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2533 = stablehlo.reshape %2532 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%2534 = stablehlo.transpose %arg47, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2535 = stablehlo.dot_general %2533, %2534, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2536 = stablehlo.reshape %2535 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2537 = stablehlo.add %2461, %2536 : tensor<8x100x4096xf32>
%2538 = stablehlo.power %2537, %cst_3 : tensor<8x100x4096xf32>
%2539 = stablehlo.reduce(%2538 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2540 = stablehlo.multiply %2539, %cst_2 : tensor<8x100xf32>
%2541 = stablehlo.reshape %2540 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2542 = stablehlo.add %2541, %cst_1 : tensor<8x100x1xf32>
%2543 = stablehlo.rsqrt %2542 : tensor<8x100x1xf32>
%2544 = stablehlo.reshape %2543 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2545 = stablehlo.broadcast_in_dim %2544, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2546 = stablehlo.multiply %2537, %2545 : tensor<8x100x4096xf32>
%2547 = stablehlo.broadcast_in_dim %arg46, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2548 = stablehlo.multiply %2546, %2547 : tensor<8x100x4096xf32>
%2549 = stablehlo.reshape %2548 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2550 = stablehlo.transpose %arg323, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2551 = stablehlo.dot_general %2549, %2550, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2552 = stablehlo.reshape %2551 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2553 = stablehlo.logistic %2552 : tensor<8x100x11008xf32>
%2554 = stablehlo.multiply %2552, %2553 : tensor<8x100x11008xf32>
%2555 = stablehlo.transpose %arg45, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2556 = stablehlo.dot_general %2549, %2555, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2557 = stablehlo.reshape %2556 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2558 = stablehlo.multiply %2554, %2557 : tensor<8x100x11008xf32>
%2559 = stablehlo.reshape %2558 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%2560 = stablehlo.transpose %arg44, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%2561 = stablehlo.dot_general %2559, %2560, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%2562 = stablehlo.reshape %2561 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2563 = stablehlo.add %2537, %2562 : tensor<8x100x4096xf32>
%2564 = stablehlo.power %2563, %cst_3 : tensor<8x100x4096xf32>
%2565 = stablehlo.reduce(%2564 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2566 = stablehlo.multiply %2565, %cst_2 : tensor<8x100xf32>
%2567 = stablehlo.reshape %2566 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2568 = stablehlo.add %2567, %cst_1 : tensor<8x100x1xf32>
%2569 = stablehlo.rsqrt %2568 : tensor<8x100x1xf32>
%2570 = stablehlo.reshape %2569 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2571 = stablehlo.broadcast_in_dim %2570, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2572 = stablehlo.multiply %2563, %2571 : tensor<8x100x4096xf32>
%2573 = stablehlo.broadcast_in_dim %arg43, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2574 = stablehlo.multiply %2572, %2573 : tensor<8x100x4096xf32>
%2575 = stablehlo.reshape %2574 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2576 = stablehlo.transpose %arg327, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2577 = stablehlo.dot_general %2575, %2576, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2578 = stablehlo.reshape %2577 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2579 = stablehlo.transpose %2578, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2580 = stablehlo.reshape %2579 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2581 = stablehlo.slice %2580 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2582 = stablehlo.reshape %2581 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2583 = stablehlo.slice %2580 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2584 = stablehlo.reshape %2583 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2585 = stablehlo.complex %2582, %2584 : tensor<256x100x64xcomplex<f32>>
%2586 = stablehlo.multiply %2585, %28 : tensor<256x100x64xcomplex<f32>>
%2587 = stablehlo.real %2586 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2588 = stablehlo.reshape %2587 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2589 = stablehlo.imag %2586 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2590 = stablehlo.reshape %2589 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2591 = stablehlo.concatenate %2588, %2590, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2592 = stablehlo.reshape %2591 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%2593 = stablehlo.transpose %arg325, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2594 = stablehlo.dot_general %2575, %2593, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2595 = stablehlo.reshape %2594 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2596 = stablehlo.transpose %2595, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2597 = stablehlo.reshape %2596 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2598 = stablehlo.slice %2597 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2599 = stablehlo.reshape %2598 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2600 = stablehlo.slice %2597 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2601 = stablehlo.reshape %2600 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2602 = stablehlo.complex %2599, %2601 : tensor<256x100x64xcomplex<f32>>
%2603 = stablehlo.multiply %2602, %28 : tensor<256x100x64xcomplex<f32>>
%2604 = stablehlo.real %2603 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2605 = stablehlo.reshape %2604 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2606 = stablehlo.imag %2603 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2607 = stablehlo.reshape %2606 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2608 = stablehlo.concatenate %2605, %2607, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2609 = stablehlo.reshape %2608 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%2610 = stablehlo.transpose %2609, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2611 = "stablehlo.scatter"(%arg326, %39, %2610) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2612 = stablehlo.transpose %2611, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%2613 = stablehlo.reshape %2612 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%2614 = stablehlo.dot_general %2592, %2613, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%2615 = stablehlo.reshape %2614 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%2616 = stablehlo.divide %2615, %cst : tensor<8x32x100x1024xf32>
%2617 = stablehlo.add %2616, %66 : tensor<8x32x100x1024xf32>
%2618 = stablehlo.reduce(%2617 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2619 = stablehlo.broadcast_in_dim %2618, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2620 = stablehlo.subtract %2617, %2619 : tensor<8x32x100x1024xf32>
%2621 = stablehlo.exponential %2620 : tensor<8x32x100x1024xf32>
%2622 = stablehlo.reduce(%2621 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2623 = stablehlo.broadcast_in_dim %2622, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2624 = stablehlo.divide %2621, %2623 : tensor<8x32x100x1024xf32>
%2625 = stablehlo.reshape %2624 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%2626 = stablehlo.transpose %arg42, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2627 = stablehlo.dot_general %2575, %2626, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2628 = stablehlo.reshape %2627 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2629 = "stablehlo.scatter"(%arg324, %39, %2628) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2630 = stablehlo.transpose %2629, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%2631 = stablehlo.reshape %2630 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%2632 = stablehlo.dot_general %2625, %2631, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%2633 = stablehlo.reshape %2632 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%2634 = stablehlo.transpose %2633, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2635 = stablehlo.reshape %2634 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%2636 = stablehlo.transpose %arg41, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2637 = stablehlo.dot_general %2635, %2636, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2638 = stablehlo.reshape %2637 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2639 = stablehlo.add %2563, %2638 : tensor<8x100x4096xf32>
%2640 = stablehlo.power %2639, %cst_3 : tensor<8x100x4096xf32>
%2641 = stablehlo.reduce(%2640 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2642 = stablehlo.multiply %2641, %cst_2 : tensor<8x100xf32>
%2643 = stablehlo.reshape %2642 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2644 = stablehlo.add %2643, %cst_1 : tensor<8x100x1xf32>
%2645 = stablehlo.rsqrt %2644 : tensor<8x100x1xf32>
%2646 = stablehlo.reshape %2645 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2647 = stablehlo.broadcast_in_dim %2646, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2648 = stablehlo.multiply %2639, %2647 : tensor<8x100x4096xf32>
%2649 = stablehlo.broadcast_in_dim %arg40, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2650 = stablehlo.multiply %2648, %2649 : tensor<8x100x4096xf32>
%2651 = stablehlo.reshape %2650 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2652 = stablehlo.transpose %arg328, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2653 = stablehlo.dot_general %2651, %2652, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2654 = stablehlo.reshape %2653 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2655 = stablehlo.logistic %2654 : tensor<8x100x11008xf32>
%2656 = stablehlo.multiply %2654, %2655 : tensor<8x100x11008xf32>
%2657 = stablehlo.transpose %arg39, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2658 = stablehlo.dot_general %2651, %2657, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2659 = stablehlo.reshape %2658 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2660 = stablehlo.multiply %2656, %2659 : tensor<8x100x11008xf32>
%2661 = stablehlo.reshape %2660 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%2662 = stablehlo.transpose %arg38, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%2663 = stablehlo.dot_general %2661, %2662, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%2664 = stablehlo.reshape %2663 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2665 = stablehlo.add %2639, %2664 : tensor<8x100x4096xf32>
%2666 = stablehlo.power %2665, %cst_3 : tensor<8x100x4096xf32>
%2667 = stablehlo.reduce(%2666 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2668 = stablehlo.multiply %2667, %cst_2 : tensor<8x100xf32>
%2669 = stablehlo.reshape %2668 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2670 = stablehlo.add %2669, %cst_1 : tensor<8x100x1xf32>
%2671 = stablehlo.rsqrt %2670 : tensor<8x100x1xf32>
%2672 = stablehlo.reshape %2671 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2673 = stablehlo.broadcast_in_dim %2672, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2674 = stablehlo.multiply %2665, %2673 : tensor<8x100x4096xf32>
%2675 = stablehlo.broadcast_in_dim %arg37, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2676 = stablehlo.multiply %2674, %2675 : tensor<8x100x4096xf32>
%2677 = stablehlo.reshape %2676 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2678 = stablehlo.transpose %arg332, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2679 = stablehlo.dot_general %2677, %2678, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2680 = stablehlo.reshape %2679 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2681 = stablehlo.transpose %2680, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2682 = stablehlo.reshape %2681 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2683 = stablehlo.slice %2682 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2684 = stablehlo.reshape %2683 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2685 = stablehlo.slice %2682 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2686 = stablehlo.reshape %2685 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2687 = stablehlo.complex %2684, %2686 : tensor<256x100x64xcomplex<f32>>
%2688 = stablehlo.multiply %2687, %28 : tensor<256x100x64xcomplex<f32>>
%2689 = stablehlo.real %2688 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2690 = stablehlo.reshape %2689 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2691 = stablehlo.imag %2688 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2692 = stablehlo.reshape %2691 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2693 = stablehlo.concatenate %2690, %2692, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2694 = stablehlo.reshape %2693 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%2695 = stablehlo.transpose %arg330, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2696 = stablehlo.dot_general %2677, %2695, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2697 = stablehlo.reshape %2696 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2698 = stablehlo.transpose %2697, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2699 = stablehlo.reshape %2698 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2700 = stablehlo.slice %2699 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2701 = stablehlo.reshape %2700 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2702 = stablehlo.slice %2699 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2703 = stablehlo.reshape %2702 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2704 = stablehlo.complex %2701, %2703 : tensor<256x100x64xcomplex<f32>>
%2705 = stablehlo.multiply %2704, %28 : tensor<256x100x64xcomplex<f32>>
%2706 = stablehlo.real %2705 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2707 = stablehlo.reshape %2706 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2708 = stablehlo.imag %2705 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2709 = stablehlo.reshape %2708 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2710 = stablehlo.concatenate %2707, %2709, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2711 = stablehlo.reshape %2710 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%2712 = stablehlo.transpose %2711, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2713 = "stablehlo.scatter"(%arg331, %39, %2712) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2714 = stablehlo.transpose %2713, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%2715 = stablehlo.reshape %2714 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%2716 = stablehlo.dot_general %2694, %2715, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%2717 = stablehlo.reshape %2716 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%2718 = stablehlo.divide %2717, %cst : tensor<8x32x100x1024xf32>
%2719 = stablehlo.add %2718, %66 : tensor<8x32x100x1024xf32>
%2720 = stablehlo.reduce(%2719 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2721 = stablehlo.broadcast_in_dim %2720, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2722 = stablehlo.subtract %2719, %2721 : tensor<8x32x100x1024xf32>
%2723 = stablehlo.exponential %2722 : tensor<8x32x100x1024xf32>
%2724 = stablehlo.reduce(%2723 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2725 = stablehlo.broadcast_in_dim %2724, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2726 = stablehlo.divide %2723, %2725 : tensor<8x32x100x1024xf32>
%2727 = stablehlo.reshape %2726 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%2728 = stablehlo.transpose %arg36, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2729 = stablehlo.dot_general %2677, %2728, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2730 = stablehlo.reshape %2729 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2731 = "stablehlo.scatter"(%arg329, %39, %2730) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2732 = stablehlo.transpose %2731, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%2733 = stablehlo.reshape %2732 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%2734 = stablehlo.dot_general %2727, %2733, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%2735 = stablehlo.reshape %2734 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%2736 = stablehlo.transpose %2735, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2737 = stablehlo.reshape %2736 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%2738 = stablehlo.transpose %arg35, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2739 = stablehlo.dot_general %2737, %2738, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2740 = stablehlo.reshape %2739 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2741 = stablehlo.add %2665, %2740 : tensor<8x100x4096xf32>
%2742 = stablehlo.power %2741, %cst_3 : tensor<8x100x4096xf32>
%2743 = stablehlo.reduce(%2742 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2744 = stablehlo.multiply %2743, %cst_2 : tensor<8x100xf32>
%2745 = stablehlo.reshape %2744 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2746 = stablehlo.add %2745, %cst_1 : tensor<8x100x1xf32>
%2747 = stablehlo.rsqrt %2746 : tensor<8x100x1xf32>
%2748 = stablehlo.reshape %2747 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2749 = stablehlo.broadcast_in_dim %2748, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2750 = stablehlo.multiply %2741, %2749 : tensor<8x100x4096xf32>
%2751 = stablehlo.broadcast_in_dim %arg34, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2752 = stablehlo.multiply %2750, %2751 : tensor<8x100x4096xf32>
%2753 = stablehlo.reshape %2752 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2754 = stablehlo.transpose %arg333, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2755 = stablehlo.dot_general %2753, %2754, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2756 = stablehlo.reshape %2755 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2757 = stablehlo.logistic %2756 : tensor<8x100x11008xf32>
%2758 = stablehlo.multiply %2756, %2757 : tensor<8x100x11008xf32>
%2759 = stablehlo.transpose %arg33, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2760 = stablehlo.dot_general %2753, %2759, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2761 = stablehlo.reshape %2760 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2762 = stablehlo.multiply %2758, %2761 : tensor<8x100x11008xf32>
%2763 = stablehlo.reshape %2762 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%2764 = stablehlo.transpose %arg32, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%2765 = stablehlo.dot_general %2763, %2764, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%2766 = stablehlo.reshape %2765 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2767 = stablehlo.add %2741, %2766 : tensor<8x100x4096xf32>
%2768 = stablehlo.power %2767, %cst_3 : tensor<8x100x4096xf32>
%2769 = stablehlo.reduce(%2768 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2770 = stablehlo.multiply %2769, %cst_2 : tensor<8x100xf32>
%2771 = stablehlo.reshape %2770 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2772 = stablehlo.add %2771, %cst_1 : tensor<8x100x1xf32>
%2773 = stablehlo.rsqrt %2772 : tensor<8x100x1xf32>
%2774 = stablehlo.reshape %2773 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2775 = stablehlo.broadcast_in_dim %2774, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2776 = stablehlo.multiply %2767, %2775 : tensor<8x100x4096xf32>
%2777 = stablehlo.broadcast_in_dim %arg31, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2778 = stablehlo.multiply %2776, %2777 : tensor<8x100x4096xf32>
%2779 = stablehlo.reshape %2778 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2780 = stablehlo.transpose %arg337, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2781 = stablehlo.dot_general %2779, %2780, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2782 = stablehlo.reshape %2781 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2783 = stablehlo.transpose %2782, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2784 = stablehlo.reshape %2783 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2785 = stablehlo.slice %2784 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2786 = stablehlo.reshape %2785 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2787 = stablehlo.slice %2784 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2788 = stablehlo.reshape %2787 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2789 = stablehlo.complex %2786, %2788 : tensor<256x100x64xcomplex<f32>>
%2790 = stablehlo.multiply %2789, %28 : tensor<256x100x64xcomplex<f32>>
%2791 = stablehlo.real %2790 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2792 = stablehlo.reshape %2791 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2793 = stablehlo.imag %2790 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2794 = stablehlo.reshape %2793 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2795 = stablehlo.concatenate %2792, %2794, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2796 = stablehlo.reshape %2795 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%2797 = stablehlo.transpose %arg335, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2798 = stablehlo.dot_general %2779, %2797, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2799 = stablehlo.reshape %2798 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2800 = stablehlo.transpose %2799, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2801 = stablehlo.reshape %2800 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2802 = stablehlo.slice %2801 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2803 = stablehlo.reshape %2802 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2804 = stablehlo.slice %2801 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2805 = stablehlo.reshape %2804 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2806 = stablehlo.complex %2803, %2805 : tensor<256x100x64xcomplex<f32>>
%2807 = stablehlo.multiply %2806, %28 : tensor<256x100x64xcomplex<f32>>
%2808 = stablehlo.real %2807 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2809 = stablehlo.reshape %2808 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2810 = stablehlo.imag %2807 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2811 = stablehlo.reshape %2810 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2812 = stablehlo.concatenate %2809, %2811, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2813 = stablehlo.reshape %2812 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%2814 = stablehlo.transpose %2813, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2815 = "stablehlo.scatter"(%arg336, %39, %2814) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2816 = stablehlo.transpose %2815, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%2817 = stablehlo.reshape %2816 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%2818 = stablehlo.dot_general %2796, %2817, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%2819 = stablehlo.reshape %2818 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%2820 = stablehlo.divide %2819, %cst : tensor<8x32x100x1024xf32>
%2821 = stablehlo.add %2820, %66 : tensor<8x32x100x1024xf32>
%2822 = stablehlo.reduce(%2821 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2823 = stablehlo.broadcast_in_dim %2822, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2824 = stablehlo.subtract %2821, %2823 : tensor<8x32x100x1024xf32>
%2825 = stablehlo.exponential %2824 : tensor<8x32x100x1024xf32>
%2826 = stablehlo.reduce(%2825 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2827 = stablehlo.broadcast_in_dim %2826, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2828 = stablehlo.divide %2825, %2827 : tensor<8x32x100x1024xf32>
%2829 = stablehlo.reshape %2828 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%2830 = stablehlo.transpose %arg30, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2831 = stablehlo.dot_general %2779, %2830, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2832 = stablehlo.reshape %2831 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2833 = "stablehlo.scatter"(%arg334, %39, %2832) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2834 = stablehlo.transpose %2833, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%2835 = stablehlo.reshape %2834 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%2836 = stablehlo.dot_general %2829, %2835, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%2837 = stablehlo.reshape %2836 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%2838 = stablehlo.transpose %2837, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2839 = stablehlo.reshape %2838 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%2840 = stablehlo.transpose %arg29, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2841 = stablehlo.dot_general %2839, %2840, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2842 = stablehlo.reshape %2841 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2843 = stablehlo.add %2767, %2842 : tensor<8x100x4096xf32>
%2844 = stablehlo.power %2843, %cst_3 : tensor<8x100x4096xf32>
%2845 = stablehlo.reduce(%2844 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2846 = stablehlo.multiply %2845, %cst_2 : tensor<8x100xf32>
%2847 = stablehlo.reshape %2846 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2848 = stablehlo.add %2847, %cst_1 : tensor<8x100x1xf32>
%2849 = stablehlo.rsqrt %2848 : tensor<8x100x1xf32>
%2850 = stablehlo.reshape %2849 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2851 = stablehlo.broadcast_in_dim %2850, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2852 = stablehlo.multiply %2843, %2851 : tensor<8x100x4096xf32>
%2853 = stablehlo.broadcast_in_dim %arg28, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2854 = stablehlo.multiply %2852, %2853 : tensor<8x100x4096xf32>
%2855 = stablehlo.reshape %2854 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2856 = stablehlo.transpose %arg338, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2857 = stablehlo.dot_general %2855, %2856, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2858 = stablehlo.reshape %2857 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2859 = stablehlo.logistic %2858 : tensor<8x100x11008xf32>
%2860 = stablehlo.multiply %2858, %2859 : tensor<8x100x11008xf32>
%2861 = stablehlo.transpose %arg27, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2862 = stablehlo.dot_general %2855, %2861, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2863 = stablehlo.reshape %2862 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2864 = stablehlo.multiply %2860, %2863 : tensor<8x100x11008xf32>
%2865 = stablehlo.reshape %2864 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%2866 = stablehlo.transpose %arg26, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%2867 = stablehlo.dot_general %2865, %2866, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%2868 = stablehlo.reshape %2867 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2869 = stablehlo.add %2843, %2868 : tensor<8x100x4096xf32>
%2870 = stablehlo.power %2869, %cst_3 : tensor<8x100x4096xf32>
%2871 = stablehlo.reduce(%2870 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2872 = stablehlo.multiply %2871, %cst_2 : tensor<8x100xf32>
%2873 = stablehlo.reshape %2872 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2874 = stablehlo.add %2873, %cst_1 : tensor<8x100x1xf32>
%2875 = stablehlo.rsqrt %2874 : tensor<8x100x1xf32>
%2876 = stablehlo.reshape %2875 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2877 = stablehlo.broadcast_in_dim %2876, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2878 = stablehlo.multiply %2869, %2877 : tensor<8x100x4096xf32>
%2879 = stablehlo.broadcast_in_dim %arg25, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2880 = stablehlo.multiply %2878, %2879 : tensor<8x100x4096xf32>
%2881 = stablehlo.reshape %2880 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2882 = stablehlo.transpose %arg342, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2883 = stablehlo.dot_general %2881, %2882, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2884 = stablehlo.reshape %2883 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2885 = stablehlo.transpose %2884, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2886 = stablehlo.reshape %2885 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2887 = stablehlo.slice %2886 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2888 = stablehlo.reshape %2887 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2889 = stablehlo.slice %2886 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2890 = stablehlo.reshape %2889 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2891 = stablehlo.complex %2888, %2890 : tensor<256x100x64xcomplex<f32>>
%2892 = stablehlo.multiply %2891, %28 : tensor<256x100x64xcomplex<f32>>
%2893 = stablehlo.real %2892 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2894 = stablehlo.reshape %2893 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2895 = stablehlo.imag %2892 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2896 = stablehlo.reshape %2895 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2897 = stablehlo.concatenate %2894, %2896, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2898 = stablehlo.reshape %2897 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%2899 = stablehlo.transpose %arg340, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2900 = stablehlo.dot_general %2881, %2899, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2901 = stablehlo.reshape %2900 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2902 = stablehlo.transpose %2901, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2903 = stablehlo.reshape %2902 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2904 = stablehlo.slice %2903 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2905 = stablehlo.reshape %2904 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2906 = stablehlo.slice %2903 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2907 = stablehlo.reshape %2906 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2908 = stablehlo.complex %2905, %2907 : tensor<256x100x64xcomplex<f32>>
%2909 = stablehlo.multiply %2908, %28 : tensor<256x100x64xcomplex<f32>>
%2910 = stablehlo.real %2909 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2911 = stablehlo.reshape %2910 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2912 = stablehlo.imag %2909 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2913 = stablehlo.reshape %2912 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2914 = stablehlo.concatenate %2911, %2913, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%2915 = stablehlo.reshape %2914 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%2916 = stablehlo.transpose %2915, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2917 = "stablehlo.scatter"(%arg341, %39, %2916) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2918 = stablehlo.transpose %2917, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%2919 = stablehlo.reshape %2918 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%2920 = stablehlo.dot_general %2898, %2919, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%2921 = stablehlo.reshape %2920 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%2922 = stablehlo.divide %2921, %cst : tensor<8x32x100x1024xf32>
%2923 = stablehlo.add %2922, %66 : tensor<8x32x100x1024xf32>
%2924 = stablehlo.reduce(%2923 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2925 = stablehlo.broadcast_in_dim %2924, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2926 = stablehlo.subtract %2923, %2925 : tensor<8x32x100x1024xf32>
%2927 = stablehlo.exponential %2926 : tensor<8x32x100x1024xf32>
%2928 = stablehlo.reduce(%2927 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%2929 = stablehlo.broadcast_in_dim %2928, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%2930 = stablehlo.divide %2927, %2929 : tensor<8x32x100x1024xf32>
%2931 = stablehlo.reshape %2930 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%2932 = stablehlo.transpose %arg24, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2933 = stablehlo.dot_general %2881, %2932, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2934 = stablehlo.reshape %2933 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2935 = "stablehlo.scatter"(%arg339, %39, %2934) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%2936 = stablehlo.transpose %2935, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%2937 = stablehlo.reshape %2936 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%2938 = stablehlo.dot_general %2931, %2937, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%2939 = stablehlo.reshape %2938 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%2940 = stablehlo.transpose %2939, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%2941 = stablehlo.reshape %2940 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%2942 = stablehlo.transpose %arg23, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2943 = stablehlo.dot_general %2941, %2942, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2944 = stablehlo.reshape %2943 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2945 = stablehlo.add %2869, %2944 : tensor<8x100x4096xf32>
%2946 = stablehlo.power %2945, %cst_3 : tensor<8x100x4096xf32>
%2947 = stablehlo.reduce(%2946 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2948 = stablehlo.multiply %2947, %cst_2 : tensor<8x100xf32>
%2949 = stablehlo.reshape %2948 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2950 = stablehlo.add %2949, %cst_1 : tensor<8x100x1xf32>
%2951 = stablehlo.rsqrt %2950 : tensor<8x100x1xf32>
%2952 = stablehlo.reshape %2951 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2953 = stablehlo.broadcast_in_dim %2952, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2954 = stablehlo.multiply %2945, %2953 : tensor<8x100x4096xf32>
%2955 = stablehlo.broadcast_in_dim %arg22, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2956 = stablehlo.multiply %2954, %2955 : tensor<8x100x4096xf32>
%2957 = stablehlo.reshape %2956 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2958 = stablehlo.transpose %arg343, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2959 = stablehlo.dot_general %2957, %2958, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2960 = stablehlo.reshape %2959 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2961 = stablehlo.logistic %2960 : tensor<8x100x11008xf32>
%2962 = stablehlo.multiply %2960, %2961 : tensor<8x100x11008xf32>
%2963 = stablehlo.transpose %arg21, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%2964 = stablehlo.dot_general %2957, %2963, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%2965 = stablehlo.reshape %2964 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%2966 = stablehlo.multiply %2962, %2965 : tensor<8x100x11008xf32>
%2967 = stablehlo.reshape %2966 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%2968 = stablehlo.transpose %arg20, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%2969 = stablehlo.dot_general %2967, %2968, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%2970 = stablehlo.reshape %2969 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%2971 = stablehlo.add %2945, %2970 : tensor<8x100x4096xf32>
%2972 = stablehlo.power %2971, %cst_3 : tensor<8x100x4096xf32>
%2973 = stablehlo.reduce(%2972 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%2974 = stablehlo.multiply %2973, %cst_2 : tensor<8x100xf32>
%2975 = stablehlo.reshape %2974 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%2976 = stablehlo.add %2975, %cst_1 : tensor<8x100x1xf32>
%2977 = stablehlo.rsqrt %2976 : tensor<8x100x1xf32>
%2978 = stablehlo.reshape %2977 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%2979 = stablehlo.broadcast_in_dim %2978, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%2980 = stablehlo.multiply %2971, %2979 : tensor<8x100x4096xf32>
%2981 = stablehlo.broadcast_in_dim %arg19, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%2982 = stablehlo.multiply %2980, %2981 : tensor<8x100x4096xf32>
%2983 = stablehlo.reshape %2982 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%2984 = stablehlo.transpose %arg347, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%2985 = stablehlo.dot_general %2983, %2984, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%2986 = stablehlo.reshape %2985 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%2987 = stablehlo.transpose %2986, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%2988 = stablehlo.reshape %2987 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%2989 = stablehlo.slice %2988 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2990 = stablehlo.reshape %2989 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2991 = stablehlo.slice %2988 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%2992 = stablehlo.reshape %2991 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%2993 = stablehlo.complex %2990, %2992 : tensor<256x100x64xcomplex<f32>>
%2994 = stablehlo.multiply %2993, %28 : tensor<256x100x64xcomplex<f32>>
%2995 = stablehlo.real %2994 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2996 = stablehlo.reshape %2995 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2997 = stablehlo.imag %2994 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%2998 = stablehlo.reshape %2997 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%2999 = stablehlo.concatenate %2996, %2998, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%3000 = stablehlo.reshape %2999 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%3001 = stablehlo.transpose %arg345, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3002 = stablehlo.dot_general %2983, %3001, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3003 = stablehlo.reshape %3002 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%3004 = stablehlo.transpose %3003, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%3005 = stablehlo.reshape %3004 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%3006 = stablehlo.slice %3005 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%3007 = stablehlo.reshape %3006 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%3008 = stablehlo.slice %3005 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%3009 = stablehlo.reshape %3008 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%3010 = stablehlo.complex %3007, %3009 : tensor<256x100x64xcomplex<f32>>
%3011 = stablehlo.multiply %3010, %28 : tensor<256x100x64xcomplex<f32>>
%3012 = stablehlo.real %3011 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%3013 = stablehlo.reshape %3012 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%3014 = stablehlo.imag %3011 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%3015 = stablehlo.reshape %3014 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%3016 = stablehlo.concatenate %3013, %3015, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%3017 = stablehlo.reshape %3016 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%3018 = stablehlo.transpose %3017, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%3019 = "stablehlo.scatter"(%arg346, %39, %3018) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%3020 = stablehlo.transpose %3019, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%3021 = stablehlo.reshape %3020 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%3022 = stablehlo.dot_general %3000, %3021, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%3023 = stablehlo.reshape %3022 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%3024 = stablehlo.divide %3023, %cst : tensor<8x32x100x1024xf32>
%3025 = stablehlo.add %3024, %66 : tensor<8x32x100x1024xf32>
%3026 = stablehlo.reduce(%3025 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%3027 = stablehlo.broadcast_in_dim %3026, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%3028 = stablehlo.subtract %3025, %3027 : tensor<8x32x100x1024xf32>
%3029 = stablehlo.exponential %3028 : tensor<8x32x100x1024xf32>
%3030 = stablehlo.reduce(%3029 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%3031 = stablehlo.broadcast_in_dim %3030, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%3032 = stablehlo.divide %3029, %3031 : tensor<8x32x100x1024xf32>
%3033 = stablehlo.reshape %3032 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%3034 = stablehlo.transpose %arg18, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3035 = stablehlo.dot_general %2983, %3034, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3036 = stablehlo.reshape %3035 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%3037 = "stablehlo.scatter"(%arg344, %39, %3036) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%3038 = stablehlo.transpose %3037, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%3039 = stablehlo.reshape %3038 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%3040 = stablehlo.dot_general %3033, %3039, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%3041 = stablehlo.reshape %3040 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%3042 = stablehlo.transpose %3041, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%3043 = stablehlo.reshape %3042 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%3044 = stablehlo.transpose %arg17, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3045 = stablehlo.dot_general %3043, %3044, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3046 = stablehlo.reshape %3045 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%3047 = stablehlo.add %2971, %3046 : tensor<8x100x4096xf32>
%3048 = stablehlo.power %3047, %cst_3 : tensor<8x100x4096xf32>
%3049 = stablehlo.reduce(%3048 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%3050 = stablehlo.multiply %3049, %cst_2 : tensor<8x100xf32>
%3051 = stablehlo.reshape %3050 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%3052 = stablehlo.add %3051, %cst_1 : tensor<8x100x1xf32>
%3053 = stablehlo.rsqrt %3052 : tensor<8x100x1xf32>
%3054 = stablehlo.reshape %3053 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%3055 = stablehlo.broadcast_in_dim %3054, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%3056 = stablehlo.multiply %3047, %3055 : tensor<8x100x4096xf32>
%3057 = stablehlo.broadcast_in_dim %arg16, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%3058 = stablehlo.multiply %3056, %3057 : tensor<8x100x4096xf32>
%3059 = stablehlo.reshape %3058 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%3060 = stablehlo.transpose %arg348, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%3061 = stablehlo.dot_general %3059, %3060, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%3062 = stablehlo.reshape %3061 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%3063 = stablehlo.logistic %3062 : tensor<8x100x11008xf32>
%3064 = stablehlo.multiply %3062, %3063 : tensor<8x100x11008xf32>
%3065 = stablehlo.transpose %arg15, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%3066 = stablehlo.dot_general %3059, %3065, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%3067 = stablehlo.reshape %3066 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%3068 = stablehlo.multiply %3064, %3067 : tensor<8x100x11008xf32>
%3069 = stablehlo.reshape %3068 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%3070 = stablehlo.transpose %arg14, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%3071 = stablehlo.dot_general %3069, %3070, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%3072 = stablehlo.reshape %3071 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%3073 = stablehlo.add %3047, %3072 : tensor<8x100x4096xf32>
%3074 = stablehlo.power %3073, %cst_3 : tensor<8x100x4096xf32>
%3075 = stablehlo.reduce(%3074 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%3076 = stablehlo.multiply %3075, %cst_2 : tensor<8x100xf32>
%3077 = stablehlo.reshape %3076 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%3078 = stablehlo.add %3077, %cst_1 : tensor<8x100x1xf32>
%3079 = stablehlo.rsqrt %3078 : tensor<8x100x1xf32>
%3080 = stablehlo.reshape %3079 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%3081 = stablehlo.broadcast_in_dim %3080, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%3082 = stablehlo.multiply %3073, %3081 : tensor<8x100x4096xf32>
%3083 = stablehlo.broadcast_in_dim %arg13, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%3084 = stablehlo.multiply %3082, %3083 : tensor<8x100x4096xf32>
%3085 = stablehlo.reshape %3084 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%3086 = stablehlo.transpose %arg352, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3087 = stablehlo.dot_general %3085, %3086, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3088 = stablehlo.reshape %3087 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%3089 = stablehlo.transpose %3088, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%3090 = stablehlo.reshape %3089 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%3091 = stablehlo.slice %3090 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%3092 = stablehlo.reshape %3091 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%3093 = stablehlo.slice %3090 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%3094 = stablehlo.reshape %3093 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%3095 = stablehlo.complex %3092, %3094 : tensor<256x100x64xcomplex<f32>>
%3096 = stablehlo.multiply %3095, %28 : tensor<256x100x64xcomplex<f32>>
%3097 = stablehlo.real %3096 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%3098 = stablehlo.reshape %3097 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%3099 = stablehlo.imag %3096 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%3100 = stablehlo.reshape %3099 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%3101 = stablehlo.concatenate %3098, %3100, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%3102 = stablehlo.reshape %3101 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%3103 = stablehlo.transpose %arg350, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3104 = stablehlo.dot_general %3085, %3103, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3105 = stablehlo.reshape %3104 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%3106 = stablehlo.transpose %3105, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%3107 = stablehlo.reshape %3106 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%3108 = stablehlo.slice %3107 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%3109 = stablehlo.reshape %3108 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%3110 = stablehlo.slice %3107 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%3111 = stablehlo.reshape %3110 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%3112 = stablehlo.complex %3109, %3111 : tensor<256x100x64xcomplex<f32>>
%3113 = stablehlo.multiply %3112, %28 : tensor<256x100x64xcomplex<f32>>
%3114 = stablehlo.real %3113 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%3115 = stablehlo.reshape %3114 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%3116 = stablehlo.imag %3113 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%3117 = stablehlo.reshape %3116 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%3118 = stablehlo.concatenate %3115, %3117, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%3119 = stablehlo.reshape %3118 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%3120 = stablehlo.transpose %3119, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%3121 = "stablehlo.scatter"(%arg351, %39, %3120) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%3122 = stablehlo.transpose %3121, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%3123 = stablehlo.reshape %3122 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%3124 = stablehlo.dot_general %3102, %3123, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%3125 = stablehlo.reshape %3124 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%3126 = stablehlo.divide %3125, %cst : tensor<8x32x100x1024xf32>
%3127 = stablehlo.add %3126, %66 : tensor<8x32x100x1024xf32>
%3128 = stablehlo.reduce(%3127 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%3129 = stablehlo.broadcast_in_dim %3128, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%3130 = stablehlo.subtract %3127, %3129 : tensor<8x32x100x1024xf32>
%3131 = stablehlo.exponential %3130 : tensor<8x32x100x1024xf32>
%3132 = stablehlo.reduce(%3131 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%3133 = stablehlo.broadcast_in_dim %3132, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%3134 = stablehlo.divide %3131, %3133 : tensor<8x32x100x1024xf32>
%3135 = stablehlo.reshape %3134 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%3136 = stablehlo.transpose %arg12, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3137 = stablehlo.dot_general %3085, %3136, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3138 = stablehlo.reshape %3137 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%3139 = "stablehlo.scatter"(%arg349, %39, %3138) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%3140 = stablehlo.transpose %3139, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%3141 = stablehlo.reshape %3140 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%3142 = stablehlo.dot_general %3135, %3141, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%3143 = stablehlo.reshape %3142 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%3144 = stablehlo.transpose %3143, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%3145 = stablehlo.reshape %3144 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%3146 = stablehlo.transpose %arg11, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3147 = stablehlo.dot_general %3145, %3146, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3148 = stablehlo.reshape %3147 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%3149 = stablehlo.add %3073, %3148 : tensor<8x100x4096xf32>
%3150 = stablehlo.power %3149, %cst_3 : tensor<8x100x4096xf32>
%3151 = stablehlo.reduce(%3150 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%3152 = stablehlo.multiply %3151, %cst_2 : tensor<8x100xf32>
%3153 = stablehlo.reshape %3152 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%3154 = stablehlo.add %3153, %cst_1 : tensor<8x100x1xf32>
%3155 = stablehlo.rsqrt %3154 : tensor<8x100x1xf32>
%3156 = stablehlo.reshape %3155 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%3157 = stablehlo.broadcast_in_dim %3156, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%3158 = stablehlo.multiply %3149, %3157 : tensor<8x100x4096xf32>
%3159 = stablehlo.broadcast_in_dim %arg10, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%3160 = stablehlo.multiply %3158, %3159 : tensor<8x100x4096xf32>
%3161 = stablehlo.reshape %3160 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%3162 = stablehlo.transpose %arg353, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%3163 = stablehlo.dot_general %3161, %3162, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%3164 = stablehlo.reshape %3163 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%3165 = stablehlo.logistic %3164 : tensor<8x100x11008xf32>
%3166 = stablehlo.multiply %3164, %3165 : tensor<8x100x11008xf32>
%3167 = stablehlo.transpose %arg9, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%3168 = stablehlo.dot_general %3161, %3167, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%3169 = stablehlo.reshape %3168 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%3170 = stablehlo.multiply %3166, %3169 : tensor<8x100x11008xf32>
%3171 = stablehlo.reshape %3170 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%3172 = stablehlo.transpose %arg8, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%3173 = stablehlo.dot_general %3171, %3172, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%3174 = stablehlo.reshape %3173 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%3175 = stablehlo.add %3149, %3174 : tensor<8x100x4096xf32>
%3176 = stablehlo.power %3175, %cst_3 : tensor<8x100x4096xf32>
%3177 = stablehlo.reduce(%3176 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%3178 = stablehlo.multiply %3177, %cst_2 : tensor<8x100xf32>
%3179 = stablehlo.reshape %3178 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%3180 = stablehlo.add %3179, %cst_1 : tensor<8x100x1xf32>
%3181 = stablehlo.rsqrt %3180 : tensor<8x100x1xf32>
%3182 = stablehlo.reshape %3181 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%3183 = stablehlo.broadcast_in_dim %3182, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%3184 = stablehlo.multiply %3175, %3183 : tensor<8x100x4096xf32>
%3185 = stablehlo.broadcast_in_dim %arg7, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%3186 = stablehlo.multiply %3184, %3185 : tensor<8x100x4096xf32>
%3187 = stablehlo.reshape %3186 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%3188 = stablehlo.transpose %arg357, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3189 = stablehlo.dot_general %3187, %3188, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3190 = stablehlo.reshape %3189 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%3191 = stablehlo.transpose %3190, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%3192 = stablehlo.reshape %3191 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%3193 = stablehlo.slice %3192 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%3194 = stablehlo.reshape %3193 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%3195 = stablehlo.slice %3192 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%3196 = stablehlo.reshape %3195 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%3197 = stablehlo.complex %3194, %3196 : tensor<256x100x64xcomplex<f32>>
%3198 = stablehlo.multiply %3197, %28 : tensor<256x100x64xcomplex<f32>>
%3199 = stablehlo.real %3198 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%3200 = stablehlo.reshape %3199 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%3201 = stablehlo.imag %3198 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%3202 = stablehlo.reshape %3201 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%3203 = stablehlo.concatenate %3200, %3202, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%3204 = stablehlo.reshape %3203 : (tensor<256x100x64x2xf32>) -> tensor<256x100x128xf32>
%3205 = stablehlo.transpose %arg355, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3206 = stablehlo.dot_general %3187, %3205, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3207 = stablehlo.reshape %3206 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%3208 = stablehlo.transpose %3207, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,100,128]{3,1,2,0}"} : (tensor<8x100x32x128xf32>) -> tensor<8x32x100x128xf32>
%3209 = stablehlo.reshape %3208 : (tensor<8x32x100x128xf32>) -> tensor<256x100x64x2xf32>
%3210 = stablehlo.slice %3209 [0:256, 0:100, 0:64, 0:1] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%3211 = stablehlo.reshape %3210 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%3212 = stablehlo.slice %3209 [0:256, 0:100, 0:64, 1:2] : (tensor<256x100x64x2xf32>) -> tensor<256x100x64x1xf32>
%3213 = stablehlo.reshape %3212 : (tensor<256x100x64x1xf32>) -> tensor<256x100x64xf32>
%3214 = stablehlo.complex %3211, %3213 : tensor<256x100x64xcomplex<f32>>
%3215 = stablehlo.multiply %3214, %28 : tensor<256x100x64xcomplex<f32>>
%3216 = stablehlo.real %3215 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%3217 = stablehlo.reshape %3216 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%3218 = stablehlo.imag %3215 : (tensor<256x100x64xcomplex<f32>>) -> tensor<256x100x64xf32>
%3219 = stablehlo.reshape %3218 : (tensor<256x100x64xf32>) -> tensor<256x100x64x1xf32>
%3220 = stablehlo.concatenate %3217, %3219, dim = 3 : (tensor<256x100x64x1xf32>, tensor<256x100x64x1xf32>) -> tensor<256x100x64x2xf32>
%3221 = stablehlo.reshape %3220 : (tensor<256x100x64x2xf32>) -> tensor<8x32x100x128xf32>
%3222 = stablehlo.transpose %3221, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%3223 = "stablehlo.scatter"(%arg356, %39, %3222) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%3224 = stablehlo.transpose %3223, dims = [0, 2, 3, 1] : (tensor<8x1024x32x128xf32>) -> tensor<8x32x128x1024xf32>
%3225 = stablehlo.reshape %3224 : (tensor<8x32x128x1024xf32>) -> tensor<256x128x1024xf32>
%3226 = stablehlo.dot_general %3204, %3225, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x128xf32>, tensor<256x128x1024xf32>) -> tensor<256x100x1024xf32>
%3227 = stablehlo.reshape %3226 : (tensor<256x100x1024xf32>) -> tensor<8x32x100x1024xf32>
%3228 = stablehlo.divide %3227, %cst : tensor<8x32x100x1024xf32>
%3229 = stablehlo.add %3228, %66 : tensor<8x32x100x1024xf32>
%3230 = stablehlo.reduce(%3229 init: %cst_4) applies stablehlo.maximum across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%3231 = stablehlo.broadcast_in_dim %3230, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%3232 = stablehlo.subtract %3229, %3231 : tensor<8x32x100x1024xf32>
%3233 = stablehlo.exponential %3232 : tensor<8x32x100x1024xf32>
%3234 = stablehlo.reduce(%3233 init: %cst_5) applies stablehlo.add across dimensions = [3] : (tensor<8x32x100x1024xf32>, tensor<f32>) -> tensor<8x32x100xf32>
%3235 = stablehlo.broadcast_in_dim %3234, dims = [0, 1, 2] : (tensor<8x32x100xf32>) -> tensor<8x32x100x1024xf32>
%3236 = stablehlo.divide %3233, %3235 : tensor<8x32x100x1024xf32>
%3237 = stablehlo.reshape %3236 : (tensor<8x32x100x1024xf32>) -> tensor<256x100x1024xf32>
%3238 = stablehlo.transpose %arg6, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3239 = stablehlo.dot_general %3187, %3238, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3240 = stablehlo.reshape %3239 : (tensor<800x4096xf32>) -> tensor<8x100x32x128xf32>
%3241 = "stablehlo.scatter"(%arg354, %39, %3240) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0, 2, 3], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1], index_vector_dim = 1>, unique_indices = false}> ({
^bb0(%arg359: tensor<f32>, %arg360: tensor<f32>):
stablehlo.return %arg360 : tensor<f32>
}) : (tensor<8x1024x32x128xf32>, tensor<100x1xi64>, tensor<8x100x32x128xf32>) -> tensor<8x1024x32x128xf32>
%3242 = stablehlo.transpose %3241, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,32,1024,128]{3,1,2,0}"} : (tensor<8x1024x32x128xf32>) -> tensor<8x32x1024x128xf32>
%3243 = stablehlo.reshape %3242 : (tensor<8x32x1024x128xf32>) -> tensor<256x1024x128xf32>
%3244 = stablehlo.dot_general %3237, %3243, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<256x100x1024xf32>, tensor<256x1024x128xf32>) -> tensor<256x100x128xf32>
%3245 = stablehlo.reshape %3244 : (tensor<256x100x128xf32>) -> tensor<8x32x100x128xf32>
%3246 = stablehlo.transpose %3245, dims = [0, 2, 1, 3] {result_layout = dense<[3, 1, 2, 0]> : tensor<4xindex>, xla_shape = "f32[8,100,32,128]{3,1,2,0}"} : (tensor<8x32x100x128xf32>) -> tensor<8x100x32x128xf32>
%3247 = stablehlo.reshape %3246 : (tensor<8x100x32x128xf32>) -> tensor<800x4096xf32>
%3248 = stablehlo.transpose %arg5, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,4096]{0,1}"} : (tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
%3249 = stablehlo.dot_general %3247, %3248, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x4096xf32>) -> tensor<800x4096xf32>
%3250 = stablehlo.reshape %3249 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%3251 = stablehlo.add %3175, %3250 : tensor<8x100x4096xf32>
%3252 = stablehlo.power %3251, %cst_3 : tensor<8x100x4096xf32>
%3253 = stablehlo.reduce(%3252 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%3254 = stablehlo.multiply %3253, %cst_2 : tensor<8x100xf32>
%3255 = stablehlo.reshape %3254 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%3256 = stablehlo.add %3255, %cst_1 : tensor<8x100x1xf32>
%3257 = stablehlo.rsqrt %3256 : tensor<8x100x1xf32>
%3258 = stablehlo.reshape %3257 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%3259 = stablehlo.broadcast_in_dim %3258, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%3260 = stablehlo.multiply %3251, %3259 : tensor<8x100x4096xf32>
%3261 = stablehlo.broadcast_in_dim %arg4, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%3262 = stablehlo.multiply %3260, %3261 : tensor<8x100x4096xf32>
%3263 = stablehlo.reshape %3262 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%3264 = stablehlo.transpose %arg358, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%3265 = stablehlo.dot_general %3263, %3264, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%3266 = stablehlo.reshape %3265 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%3267 = stablehlo.logistic %3266 : tensor<8x100x11008xf32>
%3268 = stablehlo.multiply %3266, %3267 : tensor<8x100x11008xf32>
%3269 = stablehlo.transpose %arg3, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,11008]{0,1}"} : (tensor<11008x4096xf32>) -> tensor<4096x11008xf32>
%3270 = stablehlo.dot_general %3263, %3269, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x11008xf32>) -> tensor<800x11008xf32>
%3271 = stablehlo.reshape %3270 : (tensor<800x11008xf32>) -> tensor<8x100x11008xf32>
%3272 = stablehlo.multiply %3268, %3271 : tensor<8x100x11008xf32>
%3273 = stablehlo.reshape %3272 : (tensor<8x100x11008xf32>) -> tensor<800x11008xf32>
%3274 = stablehlo.transpose %arg2, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[11008,4096]{0,1}"} : (tensor<4096x11008xf32>) -> tensor<11008x4096xf32>
%3275 = stablehlo.dot_general %3273, %3274, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x11008xf32>, tensor<11008x4096xf32>) -> tensor<800x4096xf32>
%3276 = stablehlo.reshape %3275 : (tensor<800x4096xf32>) -> tensor<8x100x4096xf32>
%3277 = stablehlo.add %3251, %3276 : tensor<8x100x4096xf32>
%3278 = stablehlo.power %3277, %cst_3 : tensor<8x100x4096xf32>
%3279 = stablehlo.reduce(%3278 init: %cst_5) applies stablehlo.add across dimensions = [2] : (tensor<8x100x4096xf32>, tensor<f32>) -> tensor<8x100xf32>
%3280 = stablehlo.multiply %3279, %cst_2 : tensor<8x100xf32>
%3281 = stablehlo.reshape %3280 : (tensor<8x100xf32>) -> tensor<8x100x1xf32>
%3282 = stablehlo.add %3281, %cst_1 : tensor<8x100x1xf32>
%3283 = stablehlo.rsqrt %3282 : tensor<8x100x1xf32>
%3284 = stablehlo.reshape %3283 : (tensor<8x100x1xf32>) -> tensor<8x100xf32>
%3285 = stablehlo.broadcast_in_dim %3284, dims = [0, 1] : (tensor<8x100xf32>) -> tensor<8x100x4096xf32>
%3286 = stablehlo.multiply %3277, %3285 : tensor<8x100x4096xf32>
%3287 = stablehlo.broadcast_in_dim %arg1, dims = [2] : (tensor<4096xf32>) -> tensor<8x100x4096xf32>
%3288 = stablehlo.multiply %3286, %3287 : tensor<8x100x4096xf32>
%3289 = stablehlo.reshape %3288 : (tensor<8x100x4096xf32>) -> tensor<800x4096xf32>
%3290 = stablehlo.transpose %arg0, dims = [1, 0] {result_layout = dense<[0, 1]> : tensor<2xindex>, xla_shape = "f32[4096,32000]{0,1}"} : (tensor<32000x4096xf32>) -> tensor<4096x32000xf32>
%3291 = stablehlo.dot_general %3289, %3290, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<800x4096xf32>, tensor<4096x32000xf32>) -> tensor<800x32000xf32>
%3292 = stablehlo.reshape %3291 : (tensor<800x32000xf32>) -> tensor<8x100x32000xf32>
return %3292 : tensor<8x100x32000xf32>
}
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment